[PHASE-5] Rework codec

This commit is contained in:
Savely Krendelhoff 2025-08-23 12:18:45 +07:00
parent c147bc7fe4
commit 140099d6c2
No known key found for this signature in database
GPG key ID: F70DFD34F40238DE
5 changed files with 276 additions and 352 deletions

View file

@ -1,76 +0,0 @@
package protocol
import (
"encoding/binary"
"fmt"
"io"
"unicode/utf8"
)
// Codec handles encoding and decoding of protocol messages
type Codec struct{}
// NewCodec creates a new protocol codec
func NewCodec() *Codec {
return &Codec{}
}
// Decode reads a message from the reader using the protocol format
func (c *Codec) Decode(r io.Reader) (*Message, error) {
// Read message type (1 byte)
var msgType MessageType
if err := binary.Read(r, binary.BigEndian, &msgType); err != nil {
if err == io.EOF {
return nil, err
}
return nil, fmt.Errorf("failed to read message type: %w", err)
}
// Validate message type
if !isValidMessageType(msgType) {
return nil, fmt.Errorf("invalid message type: 0x%02x", msgType)
}
// Read payload length (4 bytes, big-endian)
var payloadLength uint32
if err := binary.Read(r, binary.BigEndian, &payloadLength); err != nil {
return nil, fmt.Errorf("failed to read payload length: %w", err)
}
// Validate payload length
if payloadLength > MaxPayloadSize {
return nil, fmt.Errorf("payload length %d exceeds maximum %d", payloadLength, MaxPayloadSize)
}
// Read payload if present
var payload []byte
if payloadLength > 0 {
payload = make([]byte, payloadLength)
// ReadFull reads exactly payloadLength bytes
// The server MUST use LimitReader and set read deadlines to prevent attacks
if _, err := io.ReadFull(r, payload); err != nil {
return nil, fmt.Errorf("failed to read payload: %w", err)
}
// Validate payload is valid UTF-8
if !utf8.Valid(payload) {
return nil, fmt.Errorf("payload contains invalid UTF-8")
}
}
return &Message{
Type: msgType,
Payload: payload,
}, nil
}
// isValidMessageType checks if the message type is defined in the protocol
func isValidMessageType(msgType MessageType) bool {
switch msgType {
case ChallengeRequestType, ChallengeResponseType, SolutionRequestType, QuoteResponseType, ErrorResponseType:
return true
default:
return false
}
}

View file

@ -1,217 +0,0 @@
package protocol
import (
"bytes"
"io"
"testing"
"time"
"hash-of-wisdom/internal/pow/challenge"
"hash-of-wisdom/internal/quotes"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCodec_Decode(t *testing.T) {
codec := NewCodec()
tests := []struct {
name string
data []byte
wantErr string
}{
{
name: "empty data",
data: []byte{},
wantErr: "EOF",
},
{
name: "invalid message type",
data: []byte{0xFF, 0x00, 0x00, 0x00, 0x00},
wantErr: "invalid message type",
},
{
name: "incomplete header",
data: []byte{0x01, 0x00, 0x00},
wantErr: "failed to read payload length",
},
{
name: "payload too large",
data: append([]byte{0x01}, encodeBigEndianUint32(MaxPayloadSize+1)...),
wantErr: "payload length",
},
{
name: "incomplete payload",
data: []byte{0x01, 0x00, 0x00, 0x00, 0x05, 0x01, 0x02},
wantErr: "failed to read payload",
},
{
name: "invalid UTF-8 in payload",
data: []byte{0x01, 0x00, 0x00, 0x00, 0x03, 0xFF, 0xFE, 0xFD},
wantErr: "invalid UTF-8",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := bytes.NewBuffer(tt.data)
_, err := codec.Decode(buf)
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.wantErr)
})
}
}
func TestCodec_ReadError_Handling(t *testing.T) {
codec := NewCodec()
// Create a reader that fails after reading header
reader := &failingReader{
data: []byte{0x01, 0x00, 0x00, 0x00, 0x05},
failAfter: 5,
}
_, err := codec.Decode(reader)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failed to read payload")
}
func TestChallengeResponse_Encode(t *testing.T) {
challenge := &challenge.Challenge{
Timestamp: time.Now().Unix(),
Difficulty: 4,
Resource: "quotes",
Random: []byte("random123"),
HMAC: []byte("hmac_signature"),
}
response := &ChallengeResponse{Challenge: challenge}
var buf bytes.Buffer
err := response.Encode(&buf)
require.NoError(t, err)
// Verify the encoded message can be decoded
codec := NewCodec()
decoded, err := codec.Decode(&buf)
require.NoError(t, err)
assert.Equal(t, ChallengeResponseType, decoded.Type)
assert.Contains(t, string(decoded.Payload), "quotes")
assert.Contains(t, string(decoded.Payload), "cmFuZG9tMTIz") // "random123" base64 encoded
}
func TestSolutionResponse_Encode(t *testing.T) {
quote := &quotes.Quote{
Text: "Test quote",
Author: "Test author",
}
response := &SolutionResponse{Quote: quote}
var buf bytes.Buffer
err := response.Encode(&buf)
require.NoError(t, err)
// Verify the encoded message can be decoded
codec := NewCodec()
decoded, err := codec.Decode(&buf)
require.NoError(t, err)
assert.Equal(t, QuoteResponseType, decoded.Type)
assert.Contains(t, string(decoded.Payload), "Test quote")
assert.Contains(t, string(decoded.Payload), "Test author")
}
func TestErrorResponse_Encode(t *testing.T) {
errorResp := &ErrorResponse{
Code: "INVALID_SOLUTION",
Message: "The provided PoW solution is incorrect",
RetryAfter: 30,
Details: map[string]string{"attempt": "1"},
}
var buf bytes.Buffer
err := errorResp.Encode(&buf)
require.NoError(t, err)
// Verify the encoded message can be decoded
codec := NewCodec()
decoded, err := codec.Decode(&buf)
require.NoError(t, err)
assert.Equal(t, ErrorResponseType, decoded.Type)
assert.Contains(t, string(decoded.Payload), "INVALID_SOLUTION")
assert.Contains(t, string(decoded.Payload), "The provided PoW solution is incorrect")
assert.Contains(t, string(decoded.Payload), "30")
}
func TestResponse_WriteError_Handling(t *testing.T) {
response := &ErrorResponse{
Code: "TEST_ERROR",
Message: "Test message",
}
// Create a writer that fails immediately
writer := &failingWriter{failAfter: 1}
err := response.Encode(writer)
assert.Error(t, err)
}
// Helper functions and types for testing
func encodeBigEndianUint32(val uint32) []byte {
return []byte{
byte(val >> 24),
byte(val >> 16),
byte(val >> 8),
byte(val),
}
}
type failingWriter struct {
written int
failAfter int
}
func (w *failingWriter) Write(data []byte) (int, error) {
if w.written >= w.failAfter {
return 0, io.ErrShortWrite
}
remaining := w.failAfter - w.written
if len(data) <= remaining {
w.written += len(data)
return len(data), nil
}
w.written = w.failAfter
return remaining, io.ErrShortWrite
}
type failingReader struct {
data []byte
pos int
failAfter int
}
func (r *failingReader) Read(buf []byte) (int, error) {
if r.pos >= r.failAfter {
return 0, io.ErrUnexpectedEOF
}
if r.pos >= len(r.data) {
return 0, io.EOF
}
n := copy(buf, r.data[r.pos:])
r.pos += n
if r.pos >= r.failAfter {
return n, io.ErrUnexpectedEOF
}
return n, nil
}

View file

@ -0,0 +1,65 @@
package protocol
import (
"encoding/binary"
"fmt"
"io"
)
// MessageDecoder handles decoding of protocol message headers
type MessageDecoder struct{}
// NewMessageDecoder creates a new message decoder
func NewMessageDecoder() *MessageDecoder {
return &MessageDecoder{}
}
// Decode reads the message header and returns a Message with the payload stream
func (d *MessageDecoder) Decode(r io.Reader) (*Message, error) {
// Read message type (1 byte)
var msgType MessageType
if err := binary.Read(r, binary.BigEndian, &msgType); err != nil {
if err == io.EOF {
return nil, err
}
return nil, fmt.Errorf("failed to read message type: %w", err)
}
// Validate message type (only request types are valid for server)
if !isValidRequestType(msgType) {
return nil, fmt.Errorf("invalid message type: 0x%02x", msgType)
}
// Read payload length (4 bytes, big-endian)
var payloadLength uint32
if err := binary.Read(r, binary.BigEndian, &payloadLength); err != nil {
return nil, fmt.Errorf("failed to read payload length: %w", err)
}
// Validate payload length
if payloadLength > MaxPayloadSize {
return nil, fmt.Errorf("payload length %d exceeds maximum %d", payloadLength, MaxPayloadSize)
}
// Create limited reader for the payload
var payloadStream io.Reader
if payloadLength > 0 {
payloadStream = io.LimitReader(r, int64(payloadLength))
}
return &Message{
Type: msgType,
PayloadLength: payloadLength,
PayloadStream: payloadStream,
}, nil
}
// isValidRequestType checks if the message type is a valid request type
func isValidRequestType(msgType MessageType) bool {
switch msgType {
case ChallengeRequestType, SolutionRequestType:
return true
default:
return false
}
}

View file

@ -0,0 +1,206 @@
package protocol
import (
"bytes"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMessageDecoder_Decode_Header(t *testing.T) {
decoder := NewMessageDecoder()
tests := []struct {
name string
data []byte
wantType MessageType
wantLength uint32
wantErr string
}{
{
name: "challenge request with empty payload",
data: []byte{0x01, 0x00, 0x00, 0x00, 0x00},
wantType: ChallengeRequestType,
wantLength: 0,
},
{
name: "solution request with payload",
data: append([]byte{0x03, 0x00, 0x00, 0x00, 0x05}, []byte("hello")...),
wantType: SolutionRequestType,
wantLength: 5,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := bytes.NewBuffer(tt.data)
msg, err := decoder.Decode(buf)
if tt.wantErr != "" {
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.wantErr)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantType, msg.Type)
assert.Equal(t, tt.wantLength, msg.PayloadLength)
if tt.wantLength > 0 {
assert.NotNil(t, msg.PayloadStream)
} else {
assert.Nil(t, msg.PayloadStream)
}
})
}
}
func TestMessageDecoder_Decode_Errors(t *testing.T) {
decoder := NewMessageDecoder()
tests := []struct {
name string
data []byte
wantErr string
}{
{
name: "empty data",
data: []byte{},
wantErr: "EOF",
},
{
name: "invalid message type",
data: []byte{0xFF, 0x00, 0x00, 0x00, 0x00},
wantErr: "invalid message type",
},
{
name: "response type not allowed",
data: []byte{0x02, 0x00, 0x00, 0x00, 0x00}, // ChallengeResponseType
wantErr: "invalid message type",
},
{
name: "incomplete header",
data: []byte{0x01, 0x00, 0x00},
wantErr: "failed to read payload length",
},
{
name: "payload too large",
data: append([]byte{0x01}, encodeBigEndianUint32(MaxPayloadSize+1)...),
wantErr: "payload length",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := bytes.NewBuffer(tt.data)
_, err := decoder.Decode(buf)
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.wantErr)
})
}
}
func TestChallengeRequest_Decode(t *testing.T) {
req := &ChallengeRequest{}
t.Run("always succeeds", func(t *testing.T) {
err := req.Decode(nil)
assert.NoError(t, err)
err = req.Decode(bytes.NewReader([]byte("ignored")))
assert.NoError(t, err)
})
}
func TestSolutionRequest_Decode(t *testing.T) {
req := &SolutionRequest{}
t.Run("valid solution request", func(t *testing.T) {
payload := `{"challenge":{"timestamp":1640995200,"difficulty":4,"resource":"quotes","random":"cmFuZG9tMTIz","hmac":"aG1hY19zaWduYXR1cmU="},"nonce":12345}`
reader := bytes.NewReader([]byte(payload))
err := req.Decode(reader)
require.NoError(t, err)
assert.Equal(t, int64(1640995200), req.Challenge.Timestamp)
assert.Equal(t, 4, req.Challenge.Difficulty)
assert.Equal(t, "quotes", req.Challenge.Resource)
assert.Equal(t, uint64(12345), req.Nonce)
})
t.Run("empty payload should error", func(t *testing.T) {
err := req.Decode(nil)
assert.Error(t, err)
})
t.Run("invalid JSON should error", func(t *testing.T) {
payload := `{invalid json}`
reader := bytes.NewReader([]byte(payload))
err := req.Decode(reader)
assert.Error(t, err)
})
t.Run("invalid UTF-8 should error", func(t *testing.T) {
payload := []byte{0xFF, 0xFE, 0xFD}
reader := bytes.NewReader(payload)
err := req.Decode(reader)
assert.Error(t, err)
})
}
func TestEndToEnd_RequestDecoding(t *testing.T) {
decoder := NewMessageDecoder()
t.Run("challenge request flow", func(t *testing.T) {
// Create message data: type=0x01, length=0
data := []byte{0x01, 0x00, 0x00, 0x00, 0x00}
buf := bytes.NewBuffer(data)
// Decode header
msg, err := decoder.Decode(buf)
require.NoError(t, err)
assert.Equal(t, ChallengeRequestType, msg.Type)
// Decode request
req := &ChallengeRequest{}
err = req.Decode(msg.PayloadStream)
require.NoError(t, err)
})
t.Run("solution request flow", func(t *testing.T) {
payload := `{"challenge":{"timestamp":1640995200,"difficulty":4,"resource":"quotes","random":"cmFuZG9tMTIz","hmac":"aG1hY19zaWduYXR1cmU="},"nonce":12345}`
// Create message data: type=0x03, length, payload
var buf bytes.Buffer
buf.WriteByte(0x03) // SolutionRequestType
length := uint32(len(payload))
buf.Write(encodeBigEndianUint32(length))
buf.WriteString(payload)
// Decode header
msg, err := decoder.Decode(&buf)
require.NoError(t, err)
assert.Equal(t, SolutionRequestType, msg.Type)
assert.Equal(t, length, msg.PayloadLength)
// Decode request
req := &SolutionRequest{}
err = req.Decode(msg.PayloadStream)
require.NoError(t, err)
assert.Equal(t, uint64(12345), req.Nonce)
})
}
// Helper functions for testing
func encodeBigEndianUint32(val uint32) []byte {
return []byte{
byte(val >> 24),
byte(val >> 16),
byte(val >> 8),
byte(val),
}
}

View file

@ -1,64 +1,10 @@
package protocol
import (
"hash-of-wisdom/internal/pow/challenge"
"hash-of-wisdom/internal/quotes"
)
import "io"
// MessageType represents the type of protocol message
type MessageType byte
const (
ChallengeRequest MessageType = 0x01
ChallengeResponse MessageType = 0x02
SolutionRequest MessageType = 0x03
QuoteResponse MessageType = 0x04
ErrorResponse MessageType = 0x05
)
// Message represents a protocol message with type and payload
// Message represents a protocol message with type and payload stream
type Message struct {
Type MessageType
Payload []byte
PayloadLength uint32
PayloadStream io.Reader
}
// ChallengeRequestPayload is empty (no payload for challenge requests)
type ChallengeRequestPayload struct{}
// ChallengeResponsePayload is the direct challenge object (not wrapped)
type ChallengeResponsePayload challenge.Challenge
// SolutionRequestPayload contains the client's solution attempt
type SolutionRequestPayload struct {
Challenge challenge.Challenge `json:"challenge"`
Nonce uint64 `json:"nonce"`
}
// QuoteResponsePayload is the direct quote object (not wrapped)
type QuoteResponsePayload quotes.Quote
// ErrorResponsePayload contains error information
type ErrorResponsePayload struct {
Code string `json:"code"`
Message string `json:"message"`
RetryAfter int `json:"retry_after,omitempty"`
Details map[string]string `json:"details,omitempty"`
}
// Error codes as defined in protocol specification
const (
ErrMalformedMessage = "MALFORMED_MESSAGE"
ErrInvalidChallenge = "INVALID_CHALLENGE"
ErrInvalidSolution = "INVALID_SOLUTION"
ErrExpiredChallenge = "EXPIRED_CHALLENGE"
ErrRateLimited = "RATE_LIMITED"
ErrServerError = "SERVER_ERROR"
ErrTooManyConnections = "TOO_MANY_CONNECTIONS"
ErrDifficultyTooHigh = "DIFFICULTY_TOO_HIGH"
)
// Protocol constants
const (
MaxPayloadSize = 8 * 1024 // 8KB maximum payload size
HeaderSize = 5 // 1 byte type + 4 bytes length
)