package protocol import ( "encoding/binary" "fmt" "io" "unicode/utf8" ) // Codec handles encoding and decoding of protocol messages type Codec struct{} // NewCodec creates a new protocol codec func NewCodec() *Codec { return &Codec{} } // Decode reads a message from the reader using the protocol format func (c *Codec) Decode(r io.Reader) (*Message, error) { // Read message type (1 byte) var msgType MessageType if err := binary.Read(r, binary.BigEndian, &msgType); err != nil { if err == io.EOF { return nil, err } return nil, fmt.Errorf("failed to read message type: %w", err) } // Validate message type if !isValidMessageType(msgType) { return nil, fmt.Errorf("invalid message type: 0x%02x", msgType) } // Read payload length (4 bytes, big-endian) var payloadLength uint32 if err := binary.Read(r, binary.BigEndian, &payloadLength); err != nil { return nil, fmt.Errorf("failed to read payload length: %w", err) } // Validate payload length if payloadLength > MaxPayloadSize { return nil, fmt.Errorf("payload length %d exceeds maximum %d", payloadLength, MaxPayloadSize) } // Read payload if present var payload []byte if payloadLength > 0 { payload = make([]byte, payloadLength) // ReadFull reads exactly payloadLength bytes // The server MUST use LimitReader and set read deadlines to prevent attacks if _, err := io.ReadFull(r, payload); err != nil { return nil, fmt.Errorf("failed to read payload: %w", err) } // Validate payload is valid UTF-8 if !utf8.Valid(payload) { return nil, fmt.Errorf("payload contains invalid UTF-8") } } return &Message{ Type: msgType, Payload: payload, }, nil } // isValidMessageType checks if the message type is defined in the protocol func isValidMessageType(msgType MessageType) bool { switch msgType { case ChallengeRequestType, ChallengeResponseType, SolutionRequestType, QuoteResponseType, ErrorResponseType: return true default: return false } }