112 lines
3 KiB
Go
112 lines
3 KiB
Go
package protocol
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
"unicode/utf8"
|
|
)
|
|
|
|
// Codec is the receiver type for the protocol's Encode and Decode methods.
// It has no fields and therefore carries no state.
type Codec struct{}
|
|
|
|
// NewCodec creates a new protocol codec
|
|
func NewCodec() *Codec {
|
|
return &Codec{}
|
|
}
|
|
|
|
// Encode writes a message to the writer using the protocol format
|
|
func (c *Codec) Encode(w io.Writer, msg *Message) error {
|
|
if msg == nil {
|
|
return fmt.Errorf("message cannot be nil")
|
|
}
|
|
|
|
// Validate payload size
|
|
if len(msg.Payload) > MaxPayloadSize {
|
|
return fmt.Errorf("payload size %d exceeds maximum %d", len(msg.Payload), MaxPayloadSize)
|
|
}
|
|
|
|
// Write message type (1 byte)
|
|
if err := binary.Write(w, binary.BigEndian, msg.Type); err != nil {
|
|
return fmt.Errorf("failed to write message type: %w", err)
|
|
}
|
|
|
|
// Write payload length (4 bytes, big-endian)
|
|
payloadLength := uint32(len(msg.Payload))
|
|
if err := binary.Write(w, binary.BigEndian, payloadLength); err != nil {
|
|
return fmt.Errorf("failed to write payload length: %w", err)
|
|
}
|
|
|
|
// Write payload if present
|
|
if len(msg.Payload) > 0 {
|
|
if _, err := w.Write(msg.Payload); err != nil {
|
|
return fmt.Errorf("failed to write payload: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Decode reads a message from the reader using the protocol format
|
|
func (c *Codec) Decode(r io.Reader) (*Message, error) {
|
|
// Read message type (1 byte)
|
|
var msgType MessageType
|
|
if err := binary.Read(r, binary.BigEndian, &msgType); err != nil {
|
|
if err == io.EOF {
|
|
return nil, err
|
|
}
|
|
return nil, fmt.Errorf("failed to read message type: %w", err)
|
|
}
|
|
|
|
// Validate message type
|
|
if !isValidMessageType(msgType) {
|
|
return nil, fmt.Errorf("invalid message type: 0x%02x", msgType)
|
|
}
|
|
|
|
// Read payload length (4 bytes, big-endian)
|
|
var payloadLength uint32
|
|
if err := binary.Read(r, binary.BigEndian, &payloadLength); err != nil {
|
|
return nil, fmt.Errorf("failed to read payload length: %w", err)
|
|
}
|
|
|
|
// Validate payload length
|
|
if payloadLength > MaxPayloadSize {
|
|
return nil, fmt.Errorf("payload length %d exceeds maximum %d", payloadLength, MaxPayloadSize)
|
|
}
|
|
|
|
// Read payload if present
|
|
var payload []byte
|
|
if payloadLength > 0 {
|
|
payload = make([]byte, payloadLength)
|
|
// Use LimitReader to ensure we don't read more than payloadLength bytes
|
|
// even if the underlying reader has more data available
|
|
limitedReader := io.LimitReader(r, int64(payloadLength))
|
|
// Note: ReadFull may block waiting for data. The connection handler
|
|
// MUST set appropriate read deadlines to prevent slowloris attacks
|
|
if _, err := io.ReadFull(limitedReader, payload); err != nil {
|
|
return nil, fmt.Errorf("failed to read payload: %w", err)
|
|
}
|
|
|
|
// Validate payload is valid UTF-8
|
|
if !utf8.Valid(payload) {
|
|
return nil, fmt.Errorf("payload contains invalid UTF-8")
|
|
}
|
|
}
|
|
|
|
return &Message{
|
|
Type: msgType,
|
|
Payload: payload,
|
|
}, nil
|
|
}
|
|
|
|
|
|
// isValidMessageType checks if the message type is defined in the protocol
|
|
func isValidMessageType(msgType MessageType) bool {
|
|
switch msgType {
|
|
case ChallengeRequest, ChallengeResponse, SolutionRequest, QuoteResponse, ErrorResponse:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|