package protocol import ( "bytes" "io" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestMessageDecoder_Decode_Header(t *testing.T) { decoder := NewMessageDecoder() tests := []struct { name string data []byte wantType MessageType wantLength uint32 wantErr string }{ { name: "challenge request with empty payload", data: []byte{0x01, 0x00, 0x00, 0x00, 0x00}, wantType: ChallengeRequestType, wantLength: 0, }, { name: "solution request with payload", data: append([]byte{0x03, 0x00, 0x00, 0x00, 0x05}, []byte("hello")...), wantType: SolutionRequestType, wantLength: 5, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { buf := bytes.NewBuffer(tt.data) msg, err := decoder.Decode(buf) if tt.wantErr != "" { assert.Error(t, err) assert.Contains(t, err.Error(), tt.wantErr) return } require.NoError(t, err) assert.Equal(t, tt.wantType, msg.Type) assert.Equal(t, tt.wantLength, msg.PayloadLength) if tt.wantLength > 0 { assert.NotNil(t, msg.PayloadStream) } else { assert.Nil(t, msg.PayloadStream) } }) } } func TestMessageDecoder_Decode_Errors(t *testing.T) { decoder := NewMessageDecoder() tests := []struct { name string data []byte wantErr string }{ { name: "empty data", data: []byte{}, wantErr: "EOF", }, { name: "invalid message type", data: []byte{0xFF, 0x00, 0x00, 0x00, 0x00}, wantErr: "invalid message type", }, { name: "response type not allowed", data: []byte{0x02, 0x00, 0x00, 0x00, 0x00}, // ChallengeResponseType wantErr: "invalid message type", }, { name: "incomplete header", data: []byte{0x01, 0x00, 0x00}, wantErr: "failed to read payload length", }, { name: "payload too large", data: append([]byte{0x01}, encodeBigEndianUint32(MaxPayloadSize+1)...), wantErr: "payload length", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { buf := bytes.NewBuffer(tt.data) _, err := decoder.Decode(buf) assert.Error(t, err) assert.Contains(t, err.Error(), tt.wantErr) }) } } func TestChallengeRequest_Decode(t *testing.T) { tests := []struct { name string stream io.Reader }{ {"nil stream", nil}, {"non-empty stream", bytes.NewReader([]byte("ignored"))}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req := &ChallengeRequest{} err := req.Decode(tt.stream) assert.NoError(t, err) }) } } func TestSolutionRequest_Decode(t *testing.T) { tests := []struct { name string payload []byte wantErr bool wantNonce uint64 }{ { name: "valid solution request", payload: []byte(`{"challenge":{"timestamp":1640995200,"difficulty":4,"resource":"quotes","random":"cmFuZG9tMTIz","hmac":"aG1hY19zaWduYXR1cmU="},"nonce":12345}`), wantNonce: 12345, }, { name: "invalid JSON", payload: []byte(`{invalid json}`), wantErr: true, }, { name: "invalid UTF-8", payload: []byte{0xFF, 0xFE, 0xFD}, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { req := &SolutionRequest{} var reader io.Reader if tt.payload != nil { reader = bytes.NewReader(tt.payload) } err := req.Decode(reader) if tt.wantErr { assert.Error(t, err) } else { require.NoError(t, err) assert.Equal(t, tt.wantNonce, req.Nonce) } }) } t.Run("nil stream should error", func(t *testing.T) { req := &SolutionRequest{} err := req.Decode(nil) assert.Error(t, err) }) } func TestEndToEnd_RequestDecoding(t *testing.T) { decoder := NewMessageDecoder() t.Run("challenge request flow", func(t *testing.T) { // Create message data: type=0x01, length=0 data := []byte{0x01, 0x00, 0x00, 0x00, 0x00} buf := bytes.NewBuffer(data) // Decode header msg, err := decoder.Decode(buf) require.NoError(t, err) assert.Equal(t, ChallengeRequestType, msg.Type) // Decode request req := &ChallengeRequest{} err = req.Decode(msg.PayloadStream) require.NoError(t, err) }) t.Run("solution request flow", func(t *testing.T) { payload := `{"challenge":{"timestamp":1640995200,"difficulty":4,"resource":"quotes","random":"cmFuZG9tMTIz","hmac":"aG1hY19zaWduYXR1cmU="},"nonce":12345}` // Create message data: type=0x03, length, payload var buf bytes.Buffer buf.WriteByte(0x03) // SolutionRequestType length := uint32(len(payload)) buf.Write(encodeBigEndianUint32(length)) buf.WriteString(payload) // Decode header msg, err := decoder.Decode(&buf) require.NoError(t, err) assert.Equal(t, SolutionRequestType, msg.Type) assert.Equal(t, length, msg.PayloadLength) // Decode request req := &SolutionRequest{} err = req.Decode(msg.PayloadStream) require.NoError(t, err) assert.Equal(t, uint64(12345), req.Nonce) }) } // Helper functions for testing func encodeBigEndianUint32(val uint32) []byte { return []byte{ byte(val >> 24), byte(val >> 16), byte(val >> 8), byte(val), } }