hash-of-wisdom/internal/protocol/codec.go

112 lines
3 KiB
Go

package protocol
import (
"encoding/binary"
"fmt"
"io"
"unicode/utf8"
)
// Codec converts protocol messages to and from their binary wire format.
// It has no fields, so every Codec value behaves the same.
type Codec struct{}

// NewCodec returns a pointer to a new, ready-to-use Codec.
func NewCodec() *Codec {
	return new(Codec)
}
// Encode writes msg to w as a single protocol frame:
//
//	[1 byte message type][4 byte big-endian payload length][payload]
//
// The frame is assembled in memory and passed to w in exactly one Write call.
// A failed write therefore cannot leave a header on the wire without its
// payload. An unbuffered connection also sees one write per message instead
// of three.
//
// It returns an error if msg is nil, if the payload is longer than
// MaxPayloadSize, or if the write fails (the writer's error is wrapped with %w).
// Nothing is written when validation fails.
func (c *Codec) Encode(w io.Writer, msg *Message) error {
	if msg == nil {
		return fmt.Errorf("message cannot be nil")
	}
	// Validate payload size before touching the writer.
	if len(msg.Payload) > MaxPayloadSize {
		return fmt.Errorf("payload size %d exceeds maximum %d", len(msg.Payload), MaxPayloadSize)
	}

	// Header: 1 type byte followed by the 4-byte big-endian payload length.
	const headerSize = 1 + 4
	frame := make([]byte, headerSize, headerSize+len(msg.Payload))
	frame[0] = byte(msg.Type)
	binary.BigEndian.PutUint32(frame[1:headerSize], uint32(len(msg.Payload)))
	frame = append(frame, msg.Payload...)

	// io.Writer guarantees a non-nil error on a short write, so checking
	// err alone is sufficient.
	if _, err := w.Write(frame); err != nil {
		return fmt.Errorf("failed to write message: %w", err)
	}
	return nil
}
// Decode reads one protocol frame from r and returns it as a Message.
//
// Frame layout: [1 byte message type][4 byte big-endian payload length][payload].
//
// Error behaviour:
//   - A bare, unwrapped io.EOF is returned only when r is exhausted before the
//     first byte of a frame, so callers can treat it as a clean end of stream.
//   - If the stream ends part-way through a frame, the returned error wraps
//     io.ErrUnexpectedEOF, never io.EOF.
//   - An error is also returned when the type is rejected by
//     isValidMessageType, when the declared length exceeds MaxPayloadSize, or
//     when the payload is not valid UTF-8.
//
// A zero-length payload yields a Message with a nil Payload.
//
// Reads may block waiting for data. The connection handler MUST set
// appropriate read deadlines to prevent slowloris attacks.
func (c *Codec) Decode(r io.Reader) (*Message, error) {
	// Message type (1 byte). EOF here means the peer closed between frames.
	var msgType MessageType
	if err := binary.Read(r, binary.BigEndian, &msgType); err != nil {
		if err == io.EOF {
			return nil, err
		}
		return nil, fmt.Errorf("failed to read message type: %w", err)
	}
	// Reject unknown types before reading anything else from the peer.
	if !isValidMessageType(msgType) {
		return nil, fmt.Errorf("invalid message type: 0x%02x", msgType)
	}

	// Payload length (4 bytes, big-endian). The type byte is already consumed,
	// so running out of input from here on means the frame was truncated.
	var payloadLength uint32
	if err := binary.Read(r, binary.BigEndian, &payloadLength); err != nil {
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		return nil, fmt.Errorf("failed to read payload length: %w", err)
	}
	// Check the declared length before allocating, bounding memory per frame.
	if payloadLength > MaxPayloadSize {
		return nil, fmt.Errorf("payload length %d exceeds maximum %d", payloadLength, MaxPayloadSize)
	}

	var payload []byte
	if payloadLength > 0 {
		payload = make([]byte, payloadLength)
		// io.ReadFull fills exactly len(payload) bytes and never reads past
		// them, so bytes of the next frame are left in r.
		if _, err := io.ReadFull(r, payload); err != nil {
			if err == io.EOF {
				err = io.ErrUnexpectedEOF
			}
			return nil, fmt.Errorf("failed to read payload: %w", err)
		}
		if !utf8.Valid(payload) {
			return nil, fmt.Errorf("payload contains invalid UTF-8")
		}
	}

	return &Message{
		Type:    msgType,
		Payload: payload,
	}, nil
}
// isValidMessageType reports whether msgType is one of the five message
// types this protocol defines: ChallengeRequest, ChallengeResponse,
// SolutionRequest, QuoteResponse or ErrorResponse.
func isValidMessageType(msgType MessageType) bool {
	return msgType == ChallengeRequest ||
		msgType == ChallengeResponse ||
		msgType == SolutionRequest ||
		msgType == QuoteResponse ||
		msgType == ErrorResponse
}