From d12de089a0bc8c837f7a3fde95aba58b4d8b3343 Mon Sep 17 00:00:00 2001 From: Savely Krendelhoff Date: Sat, 23 Aug 2025 12:04:38 +0700 Subject: [PATCH] [PHASE-5] Delegate encoding to the objects themselves --- internal/protocol/codec.go | 43 +----- internal/protocol/codec_test.go | 265 +++++++++++--------------------- internal/protocol/responses.go | 82 ++++++++++ 3 files changed, 173 insertions(+), 217 deletions(-) create mode 100644 internal/protocol/responses.go diff --git a/internal/protocol/codec.go b/internal/protocol/codec.go index e5d83cd..d4c0d20 100644 --- a/internal/protocol/codec.go +++ b/internal/protocol/codec.go @@ -15,38 +15,6 @@ func NewCodec() *Codec { return &Codec{} } -// Encode writes a message to the writer using the protocol format -func (c *Codec) Encode(w io.Writer, msg *Message) error { - if msg == nil { - return fmt.Errorf("message cannot be nil") - } - - // Validate payload size - if len(msg.Payload) > MaxPayloadSize { - return fmt.Errorf("payload size %d exceeds maximum %d", len(msg.Payload), MaxPayloadSize) - } - - // Write message type (1 byte) - if err := binary.Write(w, binary.BigEndian, msg.Type); err != nil { - return fmt.Errorf("failed to write message type: %w", err) - } - - // Write payload length (4 bytes, big-endian) - payloadLength := uint32(len(msg.Payload)) - if err := binary.Write(w, binary.BigEndian, payloadLength); err != nil { - return fmt.Errorf("failed to write payload length: %w", err) - } - - // Write payload if present - if len(msg.Payload) > 0 { - if _, err := w.Write(msg.Payload); err != nil { - return fmt.Errorf("failed to write payload: %w", err) - } - } - - return nil -} - // Decode reads a message from the reader using the protocol format func (c *Codec) Decode(r io.Reader) (*Message, error) { // Read message type (1 byte) @@ -78,12 +46,9 @@ func (c *Codec) Decode(r io.Reader) (*Message, error) { var payload []byte if payloadLength > 0 { payload = make([]byte, payloadLength) - // Use LimitReader to ensure we don't read more than payloadLength bytes - // even if the underlying reader has more data available - limitedReader := io.LimitReader(r, int64(payloadLength)) - // Note: ReadFull may block waiting for data. The connection handler - // MUST set appropriate read deadlines to prevent slowloris attacks - if _, err := io.ReadFull(limitedReader, payload); err != nil { + // ReadFull reads exactly payloadLength bytes + // The server MUST use LimitReader and set read deadlines to prevent attacks + if _, err := io.ReadFull(r, payload); err != nil { return nil, fmt.Errorf("failed to read payload: %w", err) } @@ -103,7 +68,7 @@ func (c *Codec) Decode(r io.Reader) (*Message, error) { // isValidMessageType checks if the message type is defined in the protocol func isValidMessageType(msgType MessageType) bool { switch msgType { - case ChallengeRequest, ChallengeResponse, SolutionRequest, QuoteResponse, ErrorResponse: + case ChallengeRequestType, ChallengeResponseType, SolutionRequestType, QuoteResponseType, ErrorResponseType: return true default: return false diff --git a/internal/protocol/codec_test.go b/internal/protocol/codec_test.go index 6e5d6d5..f0eb749 100644 --- a/internal/protocol/codec_test.go +++ b/internal/protocol/codec_test.go @@ -2,102 +2,18 @@ package protocol import ( "bytes" - "encoding/json" "io" "testing" "time" + "hash-of-wisdom/internal/pow/challenge" + "hash-of-wisdom/internal/quotes" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestCodec_Encode_Decode(t *testing.T) { - codec := NewCodec() - - tests := []struct { - name string - message *Message - }{ - { - name: "challenge request (empty payload)", - message: &Message{ - Type: ChallengeRequest, - Payload: []byte{}, - }, - }, - { - name: "challenge response with payload", - message: &Message{ - Type: ChallengeResponse, - Payload: []byte(`{"challenge":{"timestamp":1640995200,"difficulty":4}}`), - }, - }, - { - name: "error response", - message: &Message{ - Type: ErrorResponse, - Payload: []byte(`{"code":"INVALID_SOLUTION","message":"Invalid nonce"}`), - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var buf bytes.Buffer - - // Encode message - err := codec.Encode(&buf, tt.message) - require.NoError(t, err) - - // Decode message - decoded, err := codec.Decode(&buf) - require.NoError(t, err) - - assert.Equal(t, tt.message.Type, decoded.Type) - if len(tt.message.Payload) == 0 && len(decoded.Payload) == 0 { - // Both are empty (nil or empty slice) - assert.True(t, true) - } else { - assert.Equal(t, tt.message.Payload, decoded.Payload) - } - }) - } -} - -func TestCodec_Encode_Errors(t *testing.T) { - codec := NewCodec() - - tests := []struct { - name string - message *Message - wantErr string - }{ - { - name: "nil message", - message: nil, - wantErr: "message cannot be nil", - }, - { - name: "payload too large", - message: &Message{ - Type: ChallengeRequest, - Payload: make([]byte, MaxPayloadSize+1), - }, - wantErr: "payload size", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - var buf bytes.Buffer - err := codec.Encode(&buf, tt.message) - assert.Error(t, err) - assert.Contains(t, err.Error(), tt.wantErr) - }) - } -} - -func TestCodec_Decode_Errors(t *testing.T) { +func TestCodec_Decode(t *testing.T) { codec := NewCodec() tests := []struct { @@ -147,96 +63,6 @@ func TestCodec_Decode_Errors(t *testing.T) { } } - -func TestCodec_RoundTrip_RealPayloads(t *testing.T) { - codec := NewCodec() - - t.Run("challenge response round trip", func(t *testing.T) { - original := &ChallengeResponsePayload{ - Timestamp: time.Now().Unix(), - Difficulty: 4, - Resource: "quotes", - Random: []byte("random123"), - HMAC: []byte("hmac_signature"), - } - - // Marshal payload - jsonData, err := json.Marshal(original) - require.NoError(t, err) - - msg := &Message{ - Type: ChallengeResponse, - Payload: jsonData, - } - - // Simulate network transmission - var buf bytes.Buffer - err = codec.Encode(&buf, msg) - require.NoError(t, err) - - // Decode message - decoded, err := codec.Decode(&buf) - require.NoError(t, err) - - // Unmarshal payload - var result ChallengeResponsePayload - err = json.Unmarshal(decoded.Payload, &result) - require.NoError(t, err) - - assert.Equal(t, original.Timestamp, result.Timestamp) - assert.Equal(t, original.Difficulty, result.Difficulty) - assert.Equal(t, original.Resource, result.Resource) - assert.Equal(t, original.Random, result.Random) - assert.Equal(t, original.HMAC, result.HMAC) - }) - - t.Run("quote response round trip", func(t *testing.T) { - original := &QuoteResponsePayload{ - Text: "Test quote", - Author: "Test author", - } - - // Marshal payload - jsonData, err := json.Marshal(original) - require.NoError(t, err) - - msg := &Message{ - Type: QuoteResponse, - Payload: jsonData, - } - - var buf bytes.Buffer - err = codec.Encode(&buf, msg) - require.NoError(t, err) - - decoded, err := codec.Decode(&buf) - require.NoError(t, err) - - var result QuoteResponsePayload - err = json.Unmarshal(decoded.Payload, &result) - require.NoError(t, err) - - assert.Equal(t, original.Text, result.Text) - assert.Equal(t, original.Author, result.Author) - }) -} - -func TestCodec_WriteError_Handling(t *testing.T) { - codec := NewCodec() - - // Create a writer that fails after a certain number of bytes - failAfter := 3 - writer := &failingWriter{failAfter: failAfter} - - msg := &Message{ - Type: ChallengeResponse, - Payload: []byte("test payload"), - } - - err := codec.Encode(writer, msg) - assert.Error(t, err) -} - func TestCodec_ReadError_Handling(t *testing.T) { codec := NewCodec() @@ -251,6 +77,89 @@ func TestCodec_ReadError_Handling(t *testing.T) { assert.Contains(t, err.Error(), "failed to read payload") } +func TestChallengeResponse_Encode(t *testing.T) { + challenge := &challenge.Challenge{ + Timestamp: time.Now().Unix(), + Difficulty: 4, + Resource: "quotes", + Random: []byte("random123"), + HMAC: []byte("hmac_signature"), + } + + response := &ChallengeResponse{Challenge: challenge} + + var buf bytes.Buffer + err := response.Encode(&buf) + require.NoError(t, err) + + // Verify the encoded message can be decoded + codec := NewCodec() + decoded, err := codec.Decode(&buf) + require.NoError(t, err) + + assert.Equal(t, ChallengeResponseType, decoded.Type) + assert.Contains(t, string(decoded.Payload), "quotes") + assert.Contains(t, string(decoded.Payload), "cmFuZG9tMTIz") // "random123" base64 encoded +} + +func TestSolutionResponse_Encode(t *testing.T) { + quote := "es.Quote{ + Text: "Test quote", + Author: "Test author", + } + + response := &SolutionResponse{Quote: quote} + + var buf bytes.Buffer + err := response.Encode(&buf) + require.NoError(t, err) + + // Verify the encoded message can be decoded + codec := NewCodec() + decoded, err := codec.Decode(&buf) + require.NoError(t, err) + + assert.Equal(t, QuoteResponseType, decoded.Type) + assert.Contains(t, string(decoded.Payload), "Test quote") + assert.Contains(t, string(decoded.Payload), "Test author") +} + +func TestErrorResponse_Encode(t *testing.T) { + errorResp := &ErrorResponse{ + Code: "INVALID_SOLUTION", + Message: "The provided PoW solution is incorrect", + RetryAfter: 30, + Details: map[string]string{"attempt": "1"}, + } + + var buf bytes.Buffer + err := errorResp.Encode(&buf) + require.NoError(t, err) + + // Verify the encoded message can be decoded + codec := NewCodec() + decoded, err := codec.Decode(&buf) + require.NoError(t, err) + + assert.Equal(t, ErrorResponseType, decoded.Type) + assert.Contains(t, string(decoded.Payload), "INVALID_SOLUTION") + assert.Contains(t, string(decoded.Payload), "The provided PoW solution is incorrect") + assert.Contains(t, string(decoded.Payload), "30") +} + +func TestResponse_WriteError_Handling(t *testing.T) { + response := &ErrorResponse{ + Code: "TEST_ERROR", + Message: "Test message", + } + + // Create a writer that fails immediately + writer := &failingWriter{failAfter: 1} + + err := response.Encode(writer) + assert.Error(t, err) +} + // Helper functions and types for testing func encodeBigEndianUint32(val uint32) []byte { diff --git a/internal/protocol/responses.go b/internal/protocol/responses.go new file mode 100644 index 0000000..8521d4f --- /dev/null +++ b/internal/protocol/responses.go @@ -0,0 +1,82 @@ +package protocol + +import ( + "encoding/binary" + "encoding/json" + "fmt" + "io" + + "hash-of-wisdom/internal/pow/challenge" + "hash-of-wisdom/internal/quotes" +) + +// writeHeader writes the message type and payload length to the writer +func writeHeader(w io.Writer, msgType MessageType, payloadLength uint32) error { + // Write message type (1 byte) + if err := binary.Write(w, binary.BigEndian, msgType); err != nil { + return fmt.Errorf("failed to write message type: %w", err) + } + + // Write payload length (4 bytes, big-endian) + if err := binary.Write(w, binary.BigEndian, payloadLength); err != nil { + return fmt.Errorf("failed to write payload length: %w", err) + } + + return nil +} + +// encodeResponse is a helper function that encodes any response with the given message type +func encodeResponse(w io.Writer, msgType MessageType, payload interface{}) error { + // Marshal to get exact payload size + payloadBytes, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to encode payload: %w", err) + } + + // Write header + if err := writeHeader(w, msgType, uint32(len(payloadBytes))); err != nil { + return err + } + + // Write JSON payload directly to stream + if len(payloadBytes) > 0 { + if _, err := w.Write(payloadBytes); err != nil { + return fmt.Errorf("failed to write payload: %w", err) + } + } + + return nil +} + +// ChallengeResponse represents a challenge response +type ChallengeResponse struct { + Challenge *challenge.Challenge +} + +// SolutionResponse represents a successful solution response (contains quote) +type SolutionResponse struct { + Quote *quotes.Quote +} + +// ErrorResponse represents an error response +type ErrorResponse struct { + Code string `json:"code"` + Message string `json:"message"` + RetryAfter int `json:"retry_after,omitempty"` + Details map[string]string `json:"details,omitempty"` +} + +// Encode writes the challenge response to the writer +func (r *ChallengeResponse) Encode(w io.Writer) error { + return encodeResponse(w, ChallengeResponseType, r.Challenge) +} + +// Encode writes the solution response to the writer +func (r *SolutionResponse) Encode(w io.Writer) error { + return encodeResponse(w, QuoteResponseType, r.Quote) +} + +// Encode writes the error response to the writer +func (r *ErrorResponse) Encode(w io.Writer) error { + return encodeResponse(w, ErrorResponseType, r) +}