From 7db1a401d3b88653198d06d7832ab3b9b36ad866 Mon Sep 17 00:00:00 2001 From: Savely Krendelhoff Date: Fri, 22 Aug 2025 21:04:06 +0700 Subject: [PATCH] Implement codec tests --- internal/protocol/codec_test.go | 308 ++++++++++++++++++++++++++++++++ 1 file changed, 308 insertions(+) create mode 100644 internal/protocol/codec_test.go diff --git a/internal/protocol/codec_test.go b/internal/protocol/codec_test.go new file mode 100644 index 0000000..6e5d6d5 --- /dev/null +++ b/internal/protocol/codec_test.go @@ -0,0 +1,308 @@ +package protocol + +import ( + "bytes" + "encoding/json" + "io" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCodec_Encode_Decode(t *testing.T) { + codec := NewCodec() + + tests := []struct { + name string + message *Message + }{ + { + name: "challenge request (empty payload)", + message: &Message{ + Type: ChallengeRequest, + Payload: []byte{}, + }, + }, + { + name: "challenge response with payload", + message: &Message{ + Type: ChallengeResponse, + Payload: []byte(`{"challenge":{"timestamp":1640995200,"difficulty":4}}`), + }, + }, + { + name: "error response", + message: &Message{ + Type: ErrorResponse, + Payload: []byte(`{"code":"INVALID_SOLUTION","message":"Invalid nonce"}`), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + + // Encode message + err := codec.Encode(&buf, tt.message) + require.NoError(t, err) + + // Decode message + decoded, err := codec.Decode(&buf) + require.NoError(t, err) + + assert.Equal(t, tt.message.Type, decoded.Type) + if len(tt.message.Payload) == 0 && len(decoded.Payload) == 0 { + // Both are empty (nil or empty slice) + assert.True(t, true) + } else { + assert.Equal(t, tt.message.Payload, decoded.Payload) + } + }) + } +} + +func TestCodec_Encode_Errors(t *testing.T) { + codec := NewCodec() + + tests := []struct { + name string + message *Message + wantErr string + }{ + { + name: "nil message", + message: nil, + wantErr: "message cannot be nil", + }, + { + name: "payload too large", + message: &Message{ + Type: ChallengeRequest, + Payload: make([]byte, MaxPayloadSize+1), + }, + wantErr: "payload size", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + err := codec.Encode(&buf, tt.message) + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + }) + } +} + +func TestCodec_Decode_Errors(t *testing.T) { + codec := NewCodec() + + tests := []struct { + name string + data []byte + wantErr string + }{ + { + name: "empty data", + data: []byte{}, + wantErr: "EOF", + }, + { + name: "invalid message type", + data: []byte{0xFF, 0x00, 0x00, 0x00, 0x00}, + wantErr: "invalid message type", + }, + { + name: "incomplete header", + data: []byte{0x01, 0x00, 0x00}, + wantErr: "failed to read payload length", + }, + { + name: "payload too large", + data: append([]byte{0x01}, encodeBigEndianUint32(MaxPayloadSize+1)...), + wantErr: "payload length", + }, + { + name: "incomplete payload", + data: []byte{0x01, 0x00, 0x00, 0x00, 0x05, 0x01, 0x02}, + wantErr: "failed to read payload", + }, + { + name: "invalid UTF-8 in payload", + data: []byte{0x01, 0x00, 0x00, 0x00, 0x03, 0xFF, 0xFE, 0xFD}, + wantErr: "invalid UTF-8", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := bytes.NewBuffer(tt.data) + _, err := codec.Decode(buf) + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + }) + } +} + + +func TestCodec_RoundTrip_RealPayloads(t *testing.T) { + codec := NewCodec() + + t.Run("challenge response round trip", func(t *testing.T) { + original := &ChallengeResponsePayload{ + Timestamp: time.Now().Unix(), + Difficulty: 4, + Resource: "quotes", + Random: []byte("random123"), + HMAC: []byte("hmac_signature"), + } + + // Marshal payload + jsonData, err := json.Marshal(original) + require.NoError(t, err) + + msg := &Message{ + Type: ChallengeResponse, + Payload: jsonData, + } + + // Simulate network transmission + var buf bytes.Buffer + err = codec.Encode(&buf, msg) + require.NoError(t, err) + + // Decode message + decoded, err := codec.Decode(&buf) + require.NoError(t, err) + + // Unmarshal payload + var result ChallengeResponsePayload + err = json.Unmarshal(decoded.Payload, &result) + require.NoError(t, err) + + assert.Equal(t, original.Timestamp, result.Timestamp) + assert.Equal(t, original.Difficulty, result.Difficulty) + assert.Equal(t, original.Resource, result.Resource) + assert.Equal(t, original.Random, result.Random) + assert.Equal(t, original.HMAC, result.HMAC) + }) + + t.Run("quote response round trip", func(t *testing.T) { + original := &QuoteResponsePayload{ + Text: "Test quote", + Author: "Test author", + } + + // Marshal payload + jsonData, err := json.Marshal(original) + require.NoError(t, err) + + msg := &Message{ + Type: QuoteResponse, + Payload: jsonData, + } + + var buf bytes.Buffer + err = codec.Encode(&buf, msg) + require.NoError(t, err) + + decoded, err := codec.Decode(&buf) + require.NoError(t, err) + + var result QuoteResponsePayload + err = json.Unmarshal(decoded.Payload, &result) + require.NoError(t, err) + + assert.Equal(t, original.Text, result.Text) + assert.Equal(t, original.Author, result.Author) + }) +} + +func TestCodec_WriteError_Handling(t *testing.T) { + codec := NewCodec() + + // Create a writer that fails after a certain number of bytes + failAfter := 3 + writer := &failingWriter{failAfter: failAfter} + + msg := &Message{ + Type: ChallengeResponse, + Payload: []byte("test payload"), + } + + err := codec.Encode(writer, msg) + assert.Error(t, err) +} + +func TestCodec_ReadError_Handling(t *testing.T) { + codec := NewCodec() + + // Create a reader that fails after reading header + reader := &failingReader{ + data: []byte{0x01, 0x00, 0x00, 0x00, 0x05}, + failAfter: 5, + } + + _, err := codec.Decode(reader) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to read payload") +} + +// Helper functions and types for testing + +func encodeBigEndianUint32(val uint32) []byte { + return []byte{ + byte(val >> 24), + byte(val >> 16), + byte(val >> 8), + byte(val), + } +} + +type failingWriter struct { + written int + failAfter int +} + +func (w *failingWriter) Write(data []byte) (int, error) { + if w.written >= w.failAfter { + return 0, io.ErrShortWrite + } + + remaining := w.failAfter - w.written + if len(data) <= remaining { + w.written += len(data) + return len(data), nil + } + + w.written = w.failAfter + return remaining, io.ErrShortWrite +} + +type failingReader struct { + data []byte + pos int + failAfter int +} + +func (r *failingReader) Read(buf []byte) (int, error) { + if r.pos >= r.failAfter { + return 0, io.ErrUnexpectedEOF + } + + if r.pos >= len(r.data) { + return 0, io.EOF + } + + n := copy(buf, r.data[r.pos:]) + r.pos += n + + if r.pos >= r.failAfter { + return n, io.ErrUnexpectedEOF + } + + return n, nil +}