Implement codec
This commit is contained in:
parent
ffc23c362b
commit
dc9f2b24d6
111
internal/protocol/codec.go
Normal file
111
internal/protocol/codec.go
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// Codec handles encoding and decoding of protocol messages
|
||||
type Codec struct{}
|
||||
|
||||
// NewCodec creates a new protocol codec
|
||||
func NewCodec() *Codec {
|
||||
return &Codec{}
|
||||
}
|
||||
|
||||
// Encode writes a message to the writer using the protocol format
|
||||
func (c *Codec) Encode(w io.Writer, msg *Message) error {
|
||||
if msg == nil {
|
||||
return fmt.Errorf("message cannot be nil")
|
||||
}
|
||||
|
||||
// Validate payload size
|
||||
if len(msg.Payload) > MaxPayloadSize {
|
||||
return fmt.Errorf("payload size %d exceeds maximum %d", len(msg.Payload), MaxPayloadSize)
|
||||
}
|
||||
|
||||
// Write message type (1 byte)
|
||||
if err := binary.Write(w, binary.BigEndian, msg.Type); err != nil {
|
||||
return fmt.Errorf("failed to write message type: %w", err)
|
||||
}
|
||||
|
||||
// Write payload length (4 bytes, big-endian)
|
||||
payloadLength := uint32(len(msg.Payload))
|
||||
if err := binary.Write(w, binary.BigEndian, payloadLength); err != nil {
|
||||
return fmt.Errorf("failed to write payload length: %w", err)
|
||||
}
|
||||
|
||||
// Write payload if present
|
||||
if len(msg.Payload) > 0 {
|
||||
if _, err := w.Write(msg.Payload); err != nil {
|
||||
return fmt.Errorf("failed to write payload: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode reads a message from the reader using the protocol format
|
||||
func (c *Codec) Decode(r io.Reader) (*Message, error) {
|
||||
// Read message type (1 byte)
|
||||
var msgType MessageType
|
||||
if err := binary.Read(r, binary.BigEndian, &msgType); err != nil {
|
||||
if err == io.EOF {
|
||||
return nil, err
|
||||
}
|
||||
return nil, fmt.Errorf("failed to read message type: %w", err)
|
||||
}
|
||||
|
||||
// Validate message type
|
||||
if !isValidMessageType(msgType) {
|
||||
return nil, fmt.Errorf("invalid message type: 0x%02x", msgType)
|
||||
}
|
||||
|
||||
// Read payload length (4 bytes, big-endian)
|
||||
var payloadLength uint32
|
||||
if err := binary.Read(r, binary.BigEndian, &payloadLength); err != nil {
|
||||
return nil, fmt.Errorf("failed to read payload length: %w", err)
|
||||
}
|
||||
|
||||
// Validate payload length
|
||||
if payloadLength > MaxPayloadSize {
|
||||
return nil, fmt.Errorf("payload length %d exceeds maximum %d", payloadLength, MaxPayloadSize)
|
||||
}
|
||||
|
||||
// Read payload if present
|
||||
var payload []byte
|
||||
if payloadLength > 0 {
|
||||
payload = make([]byte, payloadLength)
|
||||
// Use LimitReader to ensure we don't read more than payloadLength bytes
|
||||
// even if the underlying reader has more data available
|
||||
limitedReader := io.LimitReader(r, int64(payloadLength))
|
||||
// Note: ReadFull may block waiting for data. The connection handler
|
||||
// MUST set appropriate read deadlines to prevent slowloris attacks
|
||||
if _, err := io.ReadFull(limitedReader, payload); err != nil {
|
||||
return nil, fmt.Errorf("failed to read payload: %w", err)
|
||||
}
|
||||
|
||||
// Validate payload is valid UTF-8
|
||||
if !utf8.Valid(payload) {
|
||||
return nil, fmt.Errorf("payload contains invalid UTF-8")
|
||||
}
|
||||
}
|
||||
|
||||
return &Message{
|
||||
Type: msgType,
|
||||
Payload: payload,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
// isValidMessageType checks if the message type is defined in the protocol
|
||||
func isValidMessageType(msgType MessageType) bool {
|
||||
switch msgType {
|
||||
case ChallengeRequest, ChallengeResponse, SolutionRequest, QuoteResponse, ErrorResponse:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
Loading…
Reference in a new issue