diff --git a/internal/protocol/codec.go b/internal/protocol/codec.go deleted file mode 100644 index d4c0d20..0000000 --- a/internal/protocol/codec.go +++ /dev/null @@ -1,76 +0,0 @@ -package protocol - -import ( - "encoding/binary" - "fmt" - "io" - "unicode/utf8" -) - -// Codec handles encoding and decoding of protocol messages -type Codec struct{} - -// NewCodec creates a new protocol codec -func NewCodec() *Codec { - return &Codec{} -} - -// Decode reads a message from the reader using the protocol format -func (c *Codec) Decode(r io.Reader) (*Message, error) { - // Read message type (1 byte) - var msgType MessageType - if err := binary.Read(r, binary.BigEndian, &msgType); err != nil { - if err == io.EOF { - return nil, err - } - return nil, fmt.Errorf("failed to read message type: %w", err) - } - - // Validate message type - if !isValidMessageType(msgType) { - return nil, fmt.Errorf("invalid message type: 0x%02x", msgType) - } - - // Read payload length (4 bytes, big-endian) - var payloadLength uint32 - if err := binary.Read(r, binary.BigEndian, &payloadLength); err != nil { - return nil, fmt.Errorf("failed to read payload length: %w", err) - } - - // Validate payload length - if payloadLength > MaxPayloadSize { - return nil, fmt.Errorf("payload length %d exceeds maximum %d", payloadLength, MaxPayloadSize) - } - - // Read payload if present - var payload []byte - if payloadLength > 0 { - payload = make([]byte, payloadLength) - // ReadFull reads exactly payloadLength bytes - // The server MUST use LimitReader and set read deadlines to prevent attacks - if _, err := io.ReadFull(r, payload); err != nil { - return nil, fmt.Errorf("failed to read payload: %w", err) - } - - // Validate payload is valid UTF-8 - if !utf8.Valid(payload) { - return nil, fmt.Errorf("payload contains invalid UTF-8") - } - } - - return &Message{ - Type: msgType, - Payload: payload, - }, nil -} - - -// isValidMessageType checks if the message type is defined in the protocol -func isValidMessageType(msgType MessageType) bool { - switch msgType { - case ChallengeRequestType, ChallengeResponseType, SolutionRequestType, QuoteResponseType, ErrorResponseType: - return true - default: - return false - } -} diff --git a/internal/protocol/codec_test.go b/internal/protocol/codec_test.go deleted file mode 100644 index f0eb749..0000000 --- a/internal/protocol/codec_test.go +++ /dev/null @@ -1,217 +0,0 @@ -package protocol - -import ( - "bytes" - "io" - "testing" - "time" - - "hash-of-wisdom/internal/pow/challenge" - "hash-of-wisdom/internal/quotes" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestCodec_Decode(t *testing.T) { - codec := NewCodec() - - tests := []struct { - name string - data []byte - wantErr string - }{ - { - name: "empty data", - data: []byte{}, - wantErr: "EOF", - }, - { - name: "invalid message type", - data: []byte{0xFF, 0x00, 0x00, 0x00, 0x00}, - wantErr: "invalid message type", - }, - { - name: "incomplete header", - data: []byte{0x01, 0x00, 0x00}, - wantErr: "failed to read payload length", - }, - { - name: "payload too large", - data: append([]byte{0x01}, encodeBigEndianUint32(MaxPayloadSize+1)...), - wantErr: "payload length", - }, - { - name: "incomplete payload", - data: []byte{0x01, 0x00, 0x00, 0x00, 0x05, 0x01, 0x02}, - wantErr: "failed to read payload", - }, - { - name: "invalid UTF-8 in payload", - data: []byte{0x01, 0x00, 0x00, 0x00, 0x03, 0xFF, 0xFE, 0xFD}, - wantErr: "invalid UTF-8", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - buf := bytes.NewBuffer(tt.data) - _, err := codec.Decode(buf) - assert.Error(t, err) - assert.Contains(t, err.Error(), tt.wantErr) - }) - } -} - -func TestCodec_ReadError_Handling(t *testing.T) { - codec := NewCodec() - - // Create a reader that fails after reading header - reader := &failingReader{ - data: []byte{0x01, 0x00, 0x00, 0x00, 0x05}, - failAfter: 5, - } - - _, err := codec.Decode(reader) - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to read payload") -} - -func TestChallengeResponse_Encode(t *testing.T) { - challenge := &challenge.Challenge{ - Timestamp: time.Now().Unix(), - Difficulty: 4, - Resource: "quotes", - Random: []byte("random123"), - HMAC: []byte("hmac_signature"), - } - - response := &ChallengeResponse{Challenge: challenge} - - var buf bytes.Buffer - err := response.Encode(&buf) - require.NoError(t, err) - - // Verify the encoded message can be decoded - codec := NewCodec() - decoded, err := codec.Decode(&buf) - require.NoError(t, err) - - assert.Equal(t, ChallengeResponseType, decoded.Type) - assert.Contains(t, string(decoded.Payload), "quotes") - assert.Contains(t, string(decoded.Payload), "cmFuZG9tMTIz") // "random123" base64 encoded -} - -func TestSolutionResponse_Encode(t *testing.T) { - quote := "es.Quote{ - Text: "Test quote", - Author: "Test author", - } - - response := &SolutionResponse{Quote: quote} - - var buf bytes.Buffer - err := response.Encode(&buf) - require.NoError(t, err) - - // Verify the encoded message can be decoded - codec := NewCodec() - decoded, err := codec.Decode(&buf) - require.NoError(t, err) - - assert.Equal(t, QuoteResponseType, decoded.Type) - assert.Contains(t, string(decoded.Payload), "Test quote") - assert.Contains(t, string(decoded.Payload), "Test author") -} - -func TestErrorResponse_Encode(t *testing.T) { - errorResp := &ErrorResponse{ - Code: "INVALID_SOLUTION", - Message: "The provided PoW solution is incorrect", - RetryAfter: 30, - Details: map[string]string{"attempt": "1"}, - } - - var buf bytes.Buffer - err := errorResp.Encode(&buf) - require.NoError(t, err) - - // Verify the encoded message can be decoded - codec := NewCodec() - decoded, err := codec.Decode(&buf) - require.NoError(t, err) - - assert.Equal(t, ErrorResponseType, decoded.Type) - assert.Contains(t, string(decoded.Payload), "INVALID_SOLUTION") - assert.Contains(t, string(decoded.Payload), "The provided PoW solution is incorrect") - assert.Contains(t, string(decoded.Payload), "30") -} - -func TestResponse_WriteError_Handling(t *testing.T) { - response := &ErrorResponse{ - Code: "TEST_ERROR", - Message: "Test message", - } - - // Create a writer that fails immediately - writer := &failingWriter{failAfter: 1} - - err := response.Encode(writer) - assert.Error(t, err) -} - -// Helper functions and types for testing - -func encodeBigEndianUint32(val uint32) []byte { - return []byte{ - byte(val >> 24), - byte(val >> 16), - byte(val >> 8), - byte(val), - } -} - -type failingWriter struct { - written int - failAfter int -} - -func (w *failingWriter) Write(data []byte) (int, error) { - if w.written >= w.failAfter { - return 0, io.ErrShortWrite - } - - remaining := w.failAfter - w.written - if len(data) <= remaining { - w.written += len(data) - return len(data), nil - } - - w.written = w.failAfter - return remaining, io.ErrShortWrite -} - -type failingReader struct { - data []byte - pos int - failAfter int -} - -func (r *failingReader) Read(buf []byte) (int, error) { - if r.pos >= r.failAfter { - return 0, io.ErrUnexpectedEOF - } - - if r.pos >= len(r.data) { - return 0, io.EOF - } - - n := copy(buf, r.data[r.pos:]) - r.pos += n - - if r.pos >= r.failAfter { - return n, io.ErrUnexpectedEOF - } - - return n, nil -} diff --git a/internal/protocol/message_decoder.go b/internal/protocol/message_decoder.go new file mode 100644 index 0000000..72eea1f --- /dev/null +++ b/internal/protocol/message_decoder.go @@ -0,0 +1,65 @@ +package protocol + +import ( + "encoding/binary" + "fmt" + "io" +) + +// MessageDecoder handles decoding of protocol message headers +type MessageDecoder struct{} + +// NewMessageDecoder creates a new message decoder +func NewMessageDecoder() *MessageDecoder { + return &MessageDecoder{} +} + +// Decode reads the message header and returns a Message with the payload stream +func (d *MessageDecoder) Decode(r io.Reader) (*Message, error) { + // Read message type (1 byte) + var msgType MessageType + if err := binary.Read(r, binary.BigEndian, &msgType); err != nil { + if err == io.EOF { + return nil, err + } + return nil, fmt.Errorf("failed to read message type: %w", err) + } + + // Validate message type (only request types are valid for server) + if !isValidRequestType(msgType) { + return nil, fmt.Errorf("invalid message type: 0x%02x", msgType) + } + + // Read payload length (4 bytes, big-endian) + var payloadLength uint32 + if err := binary.Read(r, binary.BigEndian, &payloadLength); err != nil { + return nil, fmt.Errorf("failed to read payload length: %w", err) + } + + // Validate payload length + if payloadLength > MaxPayloadSize { + return nil, fmt.Errorf("payload length %d exceeds maximum %d", payloadLength, MaxPayloadSize) + } + + // Create limited reader for the payload + var payloadStream io.Reader + if payloadLength > 0 { + payloadStream = io.LimitReader(r, int64(payloadLength)) + } + + return &Message{ + Type: msgType, + PayloadLength: payloadLength, + PayloadStream: payloadStream, + }, nil +} + +// isValidRequestType checks if the message type is a valid request type +func isValidRequestType(msgType MessageType) bool { + switch msgType { + case ChallengeRequestType, SolutionRequestType: + return true + default: + return false + } +} diff --git a/internal/protocol/message_decoder_test.go b/internal/protocol/message_decoder_test.go new file mode 100644 index 0000000..e4c4680 --- /dev/null +++ b/internal/protocol/message_decoder_test.go @@ -0,0 +1,206 @@ +package protocol + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMessageDecoder_Decode_Header(t *testing.T) { + decoder := NewMessageDecoder() + + tests := []struct { + name string + data []byte + wantType MessageType + wantLength uint32 + wantErr string + }{ + { + name: "challenge request with empty payload", + data: []byte{0x01, 0x00, 0x00, 0x00, 0x00}, + wantType: ChallengeRequestType, + wantLength: 0, + }, + { + name: "solution request with payload", + data: append([]byte{0x03, 0x00, 0x00, 0x00, 0x05}, []byte("hello")...), + wantType: SolutionRequestType, + wantLength: 5, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := bytes.NewBuffer(tt.data) + msg, err := decoder.Decode(buf) + + if tt.wantErr != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.wantType, msg.Type) + assert.Equal(t, tt.wantLength, msg.PayloadLength) + + if tt.wantLength > 0 { + assert.NotNil(t, msg.PayloadStream) + } else { + assert.Nil(t, msg.PayloadStream) + } + }) + } +} + +func TestMessageDecoder_Decode_Errors(t *testing.T) { + decoder := NewMessageDecoder() + + tests := []struct { + name string + data []byte + wantErr string + }{ + { + name: "empty data", + data: []byte{}, + wantErr: "EOF", + }, + { + name: "invalid message type", + data: []byte{0xFF, 0x00, 0x00, 0x00, 0x00}, + wantErr: "invalid message type", + }, + { + name: "response type not allowed", + data: []byte{0x02, 0x00, 0x00, 0x00, 0x00}, // ChallengeResponseType + wantErr: "invalid message type", + }, + { + name: "incomplete header", + data: []byte{0x01, 0x00, 0x00}, + wantErr: "failed to read payload length", + }, + { + name: "payload too large", + data: append([]byte{0x01}, encodeBigEndianUint32(MaxPayloadSize+1)...), + wantErr: "payload length", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := bytes.NewBuffer(tt.data) + _, err := decoder.Decode(buf) + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + }) + } +} + +func TestChallengeRequest_Decode(t *testing.T) { + req := &ChallengeRequest{} + + t.Run("always succeeds", func(t *testing.T) { + err := req.Decode(nil) + assert.NoError(t, err) + + err = req.Decode(bytes.NewReader([]byte("ignored"))) + assert.NoError(t, err) + }) +} + +func TestSolutionRequest_Decode(t *testing.T) { + req := &SolutionRequest{} + + t.Run("valid solution request", func(t *testing.T) { + payload := `{"challenge":{"timestamp":1640995200,"difficulty":4,"resource":"quotes","random":"cmFuZG9tMTIz","hmac":"aG1hY19zaWduYXR1cmU="},"nonce":12345}` + reader := bytes.NewReader([]byte(payload)) + + err := req.Decode(reader) + require.NoError(t, err) + + assert.Equal(t, int64(1640995200), req.Challenge.Timestamp) + assert.Equal(t, 4, req.Challenge.Difficulty) + assert.Equal(t, "quotes", req.Challenge.Resource) + assert.Equal(t, uint64(12345), req.Nonce) + }) + + t.Run("empty payload should error", func(t *testing.T) { + err := req.Decode(nil) + assert.Error(t, err) + }) + + t.Run("invalid JSON should error", func(t *testing.T) { + payload := `{invalid json}` + reader := bytes.NewReader([]byte(payload)) + + err := req.Decode(reader) + assert.Error(t, err) + }) + + t.Run("invalid UTF-8 should error", func(t *testing.T) { + payload := []byte{0xFF, 0xFE, 0xFD} + reader := bytes.NewReader(payload) + + err := req.Decode(reader) + assert.Error(t, err) + }) +} + +func TestEndToEnd_RequestDecoding(t *testing.T) { + decoder := NewMessageDecoder() + + t.Run("challenge request flow", func(t *testing.T) { + // Create message data: type=0x01, length=0 + data := []byte{0x01, 0x00, 0x00, 0x00, 0x00} + buf := bytes.NewBuffer(data) + + // Decode header + msg, err := decoder.Decode(buf) + require.NoError(t, err) + assert.Equal(t, ChallengeRequestType, msg.Type) + + // Decode request + req := &ChallengeRequest{} + err = req.Decode(msg.PayloadStream) + require.NoError(t, err) + }) + + t.Run("solution request flow", func(t *testing.T) { + payload := `{"challenge":{"timestamp":1640995200,"difficulty":4,"resource":"quotes","random":"cmFuZG9tMTIz","hmac":"aG1hY19zaWduYXR1cmU="},"nonce":12345}` + + // Create message data: type=0x03, length, payload + var buf bytes.Buffer + buf.WriteByte(0x03) // SolutionRequestType + length := uint32(len(payload)) + buf.Write(encodeBigEndianUint32(length)) + buf.WriteString(payload) + + // Decode header + msg, err := decoder.Decode(&buf) + require.NoError(t, err) + assert.Equal(t, SolutionRequestType, msg.Type) + assert.Equal(t, length, msg.PayloadLength) + + // Decode request + req := &SolutionRequest{} + err = req.Decode(msg.PayloadStream) + require.NoError(t, err) + assert.Equal(t, uint64(12345), req.Nonce) + }) +} + +// Helper functions for testing + +func encodeBigEndianUint32(val uint32) []byte { + return []byte{ + byte(val >> 24), + byte(val >> 16), + byte(val >> 8), + byte(val), + } +} diff --git a/internal/protocol/types.go b/internal/protocol/types.go index 4afd17f..bb2a8ee 100644 --- a/internal/protocol/types.go +++ b/internal/protocol/types.go @@ -1,64 +1,10 @@ package protocol -import ( - "hash-of-wisdom/internal/pow/challenge" - "hash-of-wisdom/internal/quotes" -) +import "io" -// MessageType represents the type of protocol message -type MessageType byte - -const ( - ChallengeRequest MessageType = 0x01 - ChallengeResponse MessageType = 0x02 - SolutionRequest MessageType = 0x03 - QuoteResponse MessageType = 0x04 - ErrorResponse MessageType = 0x05 -) - -// Message represents a protocol message with type and payload +// Message represents a protocol message with type and payload stream type Message struct { - Type MessageType - Payload []byte + Type MessageType + PayloadLength uint32 + PayloadStream io.Reader } - -// ChallengeRequestPayload is empty (no payload for challenge requests) -type ChallengeRequestPayload struct{} - -// ChallengeResponsePayload is the direct challenge object (not wrapped) -type ChallengeResponsePayload challenge.Challenge - -// SolutionRequestPayload contains the client's solution attempt -type SolutionRequestPayload struct { - Challenge challenge.Challenge `json:"challenge"` - Nonce uint64 `json:"nonce"` -} - -// QuoteResponsePayload is the direct quote object (not wrapped) -type QuoteResponsePayload quotes.Quote - -// ErrorResponsePayload contains error information -type ErrorResponsePayload struct { - Code string `json:"code"` - Message string `json:"message"` - RetryAfter int `json:"retry_after,omitempty"` - Details map[string]string `json:"details,omitempty"` -} - -// Error codes as defined in protocol specification -const ( - ErrMalformedMessage = "MALFORMED_MESSAGE" - ErrInvalidChallenge = "INVALID_CHALLENGE" - ErrInvalidSolution = "INVALID_SOLUTION" - ErrExpiredChallenge = "EXPIRED_CHALLENGE" - ErrRateLimited = "RATE_LIMITED" - ErrServerError = "SERVER_ERROR" - ErrTooManyConnections = "TOO_MANY_CONNECTIONS" - ErrDifficultyTooHigh = "DIFFICULTY_TOO_HIGH" -) - -// Protocol constants -const ( - MaxPayloadSize = 8 * 1024 // 8KB maximum payload size - HeaderSize = 5 // 1 byte type + 4 bytes length -)