From dc9f2b24d64b4865eba9ac01d96f61cd0703d639 Mon Sep 17 00:00:00 2001 From: Savely Krendelhoff Date: Fri, 22 Aug 2025 21:03:54 +0700 Subject: [PATCH] Implement codec --- internal/protocol/codec.go | 111 +++++++++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) create mode 100644 internal/protocol/codec.go diff --git a/internal/protocol/codec.go b/internal/protocol/codec.go new file mode 100644 index 0000000..e5d83cd --- /dev/null +++ b/internal/protocol/codec.go @@ -0,0 +1,111 @@ +package protocol + +import ( + "encoding/binary" + "fmt" + "io" + "unicode/utf8" +) + +// Codec handles encoding and decoding of protocol messages +type Codec struct{} + +// NewCodec creates a new protocol codec +func NewCodec() *Codec { + return &Codec{} +} + +// Encode writes a message to the writer using the protocol format +func (c *Codec) Encode(w io.Writer, msg *Message) error { + if msg == nil { + return fmt.Errorf("message cannot be nil") + } + + // Validate payload size + if len(msg.Payload) > MaxPayloadSize { + return fmt.Errorf("payload size %d exceeds maximum %d", len(msg.Payload), MaxPayloadSize) + } + + // Write message type (1 byte) + if err := binary.Write(w, binary.BigEndian, msg.Type); err != nil { + return fmt.Errorf("failed to write message type: %w", err) + } + + // Write payload length (4 bytes, big-endian) + payloadLength := uint32(len(msg.Payload)) + if err := binary.Write(w, binary.BigEndian, payloadLength); err != nil { + return fmt.Errorf("failed to write payload length: %w", err) + } + + // Write payload if present + if len(msg.Payload) > 0 { + if _, err := w.Write(msg.Payload); err != nil { + return fmt.Errorf("failed to write payload: %w", err) + } + } + + return nil +} + +// Decode reads a message from the reader using the protocol format +func (c *Codec) Decode(r io.Reader) (*Message, error) { + // Read message type (1 byte) + var msgType MessageType + if err := binary.Read(r, binary.BigEndian, &msgType); err != nil { + if err == io.EOF { + return nil, err + } + return nil, fmt.Errorf("failed to read message type: %w", err) + } + + // Validate message type + if !isValidMessageType(msgType) { + return nil, fmt.Errorf("invalid message type: 0x%02x", msgType) + } + + // Read payload length (4 bytes, big-endian) + var payloadLength uint32 + if err := binary.Read(r, binary.BigEndian, &payloadLength); err != nil { + return nil, fmt.Errorf("failed to read payload length: %w", err) + } + + // Validate payload length + if payloadLength > MaxPayloadSize { + return nil, fmt.Errorf("payload length %d exceeds maximum %d", payloadLength, MaxPayloadSize) + } + + // Read payload if present + var payload []byte + if payloadLength > 0 { + payload = make([]byte, payloadLength) + // Use LimitReader to ensure we don't read more than payloadLength bytes + // even if the underlying reader has more data available + limitedReader := io.LimitReader(r, int64(payloadLength)) + // Note: ReadFull may block waiting for data. The connection handler + // MUST set appropriate read deadlines to prevent slowloris attacks + if _, err := io.ReadFull(limitedReader, payload); err != nil { + return nil, fmt.Errorf("failed to read payload: %w", err) + } + + // Validate payload is valid UTF-8 + if !utf8.Valid(payload) { + return nil, fmt.Errorf("payload contains invalid UTF-8") + } + } + + return &Message{ + Type: msgType, + Payload: payload, + }, nil +} + + +// isValidMessageType checks if the message type is defined in the protocol +func isValidMessageType(msgType MessageType) bool { + switch msgType { + case ChallengeRequest, ChallengeResponse, SolutionRequest, QuoteResponse, ErrorResponse: + return true + default: + return false + } +}