package protocol import ( "encoding/binary" "fmt" "io" "unicode/utf8" ) // Codec handles encoding and decoding of protocol messages type Codec struct{} // NewCodec creates a new protocol codec func NewCodec() *Codec { return &Codec{} } // Encode writes a message to the writer using the protocol format func (c *Codec) Encode(w io.Writer, msg *Message) error { if msg == nil { return fmt.Errorf("message cannot be nil") } // Validate payload size if len(msg.Payload) > MaxPayloadSize { return fmt.Errorf("payload size %d exceeds maximum %d", len(msg.Payload), MaxPayloadSize) } // Write message type (1 byte) if err := binary.Write(w, binary.BigEndian, msg.Type); err != nil { return fmt.Errorf("failed to write message type: %w", err) } // Write payload length (4 bytes, big-endian) payloadLength := uint32(len(msg.Payload)) if err := binary.Write(w, binary.BigEndian, payloadLength); err != nil { return fmt.Errorf("failed to write payload length: %w", err) } // Write payload if present if len(msg.Payload) > 0 { if _, err := w.Write(msg.Payload); err != nil { return fmt.Errorf("failed to write payload: %w", err) } } return nil } // Decode reads a message from the reader using the protocol format func (c *Codec) Decode(r io.Reader) (*Message, error) { // Read message type (1 byte) var msgType MessageType if err := binary.Read(r, binary.BigEndian, &msgType); err != nil { if err == io.EOF { return nil, err } return nil, fmt.Errorf("failed to read message type: %w", err) } // Validate message type if !isValidMessageType(msgType) { return nil, fmt.Errorf("invalid message type: 0x%02x", msgType) } // Read payload length (4 bytes, big-endian) var payloadLength uint32 if err := binary.Read(r, binary.BigEndian, &payloadLength); err != nil { return nil, fmt.Errorf("failed to read payload length: %w", err) } // Validate payload length if payloadLength > MaxPayloadSize { return nil, fmt.Errorf("payload length %d exceeds maximum %d", payloadLength, MaxPayloadSize) } // Read payload if present var payload []byte if payloadLength > 0 { payload = make([]byte, payloadLength) // Use LimitReader to ensure we don't read more than payloadLength bytes // even if the underlying reader has more data available limitedReader := io.LimitReader(r, int64(payloadLength)) // Note: ReadFull may block waiting for data. The connection handler // MUST set appropriate read deadlines to prevent slowloris attacks if _, err := io.ReadFull(limitedReader, payload); err != nil { return nil, fmt.Errorf("failed to read payload: %w", err) } // Validate payload is valid UTF-8 if !utf8.Valid(payload) { return nil, fmt.Errorf("payload contains invalid UTF-8") } } return &Message{ Type: msgType, Payload: payload, }, nil } // isValidMessageType checks if the message type is defined in the protocol func isValidMessageType(msgType MessageType) bool { switch msgType { case ChallengeRequest, ChallengeResponse, SolutionRequest, QuoteResponse, ErrorResponse: return true default: return false } }