[PHASE-5] Delegate encoding to the objects themselves

This commit is contained in:
Savely Krendelhoff 2025-08-23 12:04:38 +07:00
parent 94eb94e167
commit d12de089a0
No known key found for this signature in database
GPG key ID: F70DFD34F40238DE
3 changed files with 173 additions and 217 deletions

View file

@ -15,38 +15,6 @@ func NewCodec() *Codec {
return &Codec{}
}
// Encode writes a message to the writer using the protocol format
func (c *Codec) Encode(w io.Writer, msg *Message) error {
if msg == nil {
return fmt.Errorf("message cannot be nil")
}
// Validate payload size
if len(msg.Payload) > MaxPayloadSize {
return fmt.Errorf("payload size %d exceeds maximum %d", len(msg.Payload), MaxPayloadSize)
}
// Write message type (1 byte)
if err := binary.Write(w, binary.BigEndian, msg.Type); err != nil {
return fmt.Errorf("failed to write message type: %w", err)
}
// Write payload length (4 bytes, big-endian)
payloadLength := uint32(len(msg.Payload))
if err := binary.Write(w, binary.BigEndian, payloadLength); err != nil {
return fmt.Errorf("failed to write payload length: %w", err)
}
// Write payload if present
if len(msg.Payload) > 0 {
if _, err := w.Write(msg.Payload); err != nil {
return fmt.Errorf("failed to write payload: %w", err)
}
}
return nil
}
// Decode reads a message from the reader using the protocol format
func (c *Codec) Decode(r io.Reader) (*Message, error) {
// Read message type (1 byte)
@ -78,12 +46,9 @@ func (c *Codec) Decode(r io.Reader) (*Message, error) {
var payload []byte
if payloadLength > 0 {
payload = make([]byte, payloadLength)
// Use LimitReader to ensure we don't read more than payloadLength bytes
// even if the underlying reader has more data available
limitedReader := io.LimitReader(r, int64(payloadLength))
// Note: ReadFull may block waiting for data. The connection handler
// MUST set appropriate read deadlines to prevent slowloris attacks
if _, err := io.ReadFull(limitedReader, payload); err != nil {
// ReadFull reads exactly payloadLength bytes
// The server MUST use LimitReader and set read deadlines to prevent attacks
if _, err := io.ReadFull(r, payload); err != nil {
return nil, fmt.Errorf("failed to read payload: %w", err)
}
@ -103,7 +68,7 @@ func (c *Codec) Decode(r io.Reader) (*Message, error) {
// isValidMessageType checks if the message type is defined in the protocol
func isValidMessageType(msgType MessageType) bool {
switch msgType {
case ChallengeRequest, ChallengeResponse, SolutionRequest, QuoteResponse, ErrorResponse:
case ChallengeRequestType, ChallengeResponseType, SolutionRequestType, QuoteResponseType, ErrorResponseType:
return true
default:
return false

View file

@ -2,102 +2,18 @@ package protocol
import (
"bytes"
"encoding/json"
"io"
"testing"
"time"
"hash-of-wisdom/internal/pow/challenge"
"hash-of-wisdom/internal/quotes"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestCodec_Encode_Decode(t *testing.T) {
codec := NewCodec()
tests := []struct {
name string
message *Message
}{
{
name: "challenge request (empty payload)",
message: &Message{
Type: ChallengeRequest,
Payload: []byte{},
},
},
{
name: "challenge response with payload",
message: &Message{
Type: ChallengeResponse,
Payload: []byte(`{"challenge":{"timestamp":1640995200,"difficulty":4}}`),
},
},
{
name: "error response",
message: &Message{
Type: ErrorResponse,
Payload: []byte(`{"code":"INVALID_SOLUTION","message":"Invalid nonce"}`),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var buf bytes.Buffer
// Encode message
err := codec.Encode(&buf, tt.message)
require.NoError(t, err)
// Decode message
decoded, err := codec.Decode(&buf)
require.NoError(t, err)
assert.Equal(t, tt.message.Type, decoded.Type)
if len(tt.message.Payload) == 0 && len(decoded.Payload) == 0 {
// Both are empty (nil or empty slice)
assert.True(t, true)
} else {
assert.Equal(t, tt.message.Payload, decoded.Payload)
}
})
}
}
func TestCodec_Encode_Errors(t *testing.T) {
codec := NewCodec()
tests := []struct {
name string
message *Message
wantErr string
}{
{
name: "nil message",
message: nil,
wantErr: "message cannot be nil",
},
{
name: "payload too large",
message: &Message{
Type: ChallengeRequest,
Payload: make([]byte, MaxPayloadSize+1),
},
wantErr: "payload size",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var buf bytes.Buffer
err := codec.Encode(&buf, tt.message)
assert.Error(t, err)
assert.Contains(t, err.Error(), tt.wantErr)
})
}
}
func TestCodec_Decode_Errors(t *testing.T) {
func TestCodec_Decode(t *testing.T) {
codec := NewCodec()
tests := []struct {
@ -147,96 +63,6 @@ func TestCodec_Decode_Errors(t *testing.T) {
}
}
func TestCodec_RoundTrip_RealPayloads(t *testing.T) {
codec := NewCodec()
t.Run("challenge response round trip", func(t *testing.T) {
original := &ChallengeResponsePayload{
Timestamp: time.Now().Unix(),
Difficulty: 4,
Resource: "quotes",
Random: []byte("random123"),
HMAC: []byte("hmac_signature"),
}
// Marshal payload
jsonData, err := json.Marshal(original)
require.NoError(t, err)
msg := &Message{
Type: ChallengeResponse,
Payload: jsonData,
}
// Simulate network transmission
var buf bytes.Buffer
err = codec.Encode(&buf, msg)
require.NoError(t, err)
// Decode message
decoded, err := codec.Decode(&buf)
require.NoError(t, err)
// Unmarshal payload
var result ChallengeResponsePayload
err = json.Unmarshal(decoded.Payload, &result)
require.NoError(t, err)
assert.Equal(t, original.Timestamp, result.Timestamp)
assert.Equal(t, original.Difficulty, result.Difficulty)
assert.Equal(t, original.Resource, result.Resource)
assert.Equal(t, original.Random, result.Random)
assert.Equal(t, original.HMAC, result.HMAC)
})
t.Run("quote response round trip", func(t *testing.T) {
original := &QuoteResponsePayload{
Text: "Test quote",
Author: "Test author",
}
// Marshal payload
jsonData, err := json.Marshal(original)
require.NoError(t, err)
msg := &Message{
Type: QuoteResponse,
Payload: jsonData,
}
var buf bytes.Buffer
err = codec.Encode(&buf, msg)
require.NoError(t, err)
decoded, err := codec.Decode(&buf)
require.NoError(t, err)
var result QuoteResponsePayload
err = json.Unmarshal(decoded.Payload, &result)
require.NoError(t, err)
assert.Equal(t, original.Text, result.Text)
assert.Equal(t, original.Author, result.Author)
})
}
func TestCodec_WriteError_Handling(t *testing.T) {
codec := NewCodec()
// Create a writer that fails after a certain number of bytes
failAfter := 3
writer := &failingWriter{failAfter: failAfter}
msg := &Message{
Type: ChallengeResponse,
Payload: []byte("test payload"),
}
err := codec.Encode(writer, msg)
assert.Error(t, err)
}
func TestCodec_ReadError_Handling(t *testing.T) {
codec := NewCodec()
@ -251,6 +77,89 @@ func TestCodec_ReadError_Handling(t *testing.T) {
assert.Contains(t, err.Error(), "failed to read payload")
}
func TestChallengeResponse_Encode(t *testing.T) {
challenge := &challenge.Challenge{
Timestamp: time.Now().Unix(),
Difficulty: 4,
Resource: "quotes",
Random: []byte("random123"),
HMAC: []byte("hmac_signature"),
}
response := &ChallengeResponse{Challenge: challenge}
var buf bytes.Buffer
err := response.Encode(&buf)
require.NoError(t, err)
// Verify the encoded message can be decoded
codec := NewCodec()
decoded, err := codec.Decode(&buf)
require.NoError(t, err)
assert.Equal(t, ChallengeResponseType, decoded.Type)
assert.Contains(t, string(decoded.Payload), "quotes")
assert.Contains(t, string(decoded.Payload), "cmFuZG9tMTIz") // "random123" base64 encoded
}
func TestSolutionResponse_Encode(t *testing.T) {
quote := &quotes.Quote{
Text: "Test quote",
Author: "Test author",
}
response := &SolutionResponse{Quote: quote}
var buf bytes.Buffer
err := response.Encode(&buf)
require.NoError(t, err)
// Verify the encoded message can be decoded
codec := NewCodec()
decoded, err := codec.Decode(&buf)
require.NoError(t, err)
assert.Equal(t, QuoteResponseType, decoded.Type)
assert.Contains(t, string(decoded.Payload), "Test quote")
assert.Contains(t, string(decoded.Payload), "Test author")
}
func TestErrorResponse_Encode(t *testing.T) {
errorResp := &ErrorResponse{
Code: "INVALID_SOLUTION",
Message: "The provided PoW solution is incorrect",
RetryAfter: 30,
Details: map[string]string{"attempt": "1"},
}
var buf bytes.Buffer
err := errorResp.Encode(&buf)
require.NoError(t, err)
// Verify the encoded message can be decoded
codec := NewCodec()
decoded, err := codec.Decode(&buf)
require.NoError(t, err)
assert.Equal(t, ErrorResponseType, decoded.Type)
assert.Contains(t, string(decoded.Payload), "INVALID_SOLUTION")
assert.Contains(t, string(decoded.Payload), "The provided PoW solution is incorrect")
assert.Contains(t, string(decoded.Payload), "30")
}
func TestResponse_WriteError_Handling(t *testing.T) {
response := &ErrorResponse{
Code: "TEST_ERROR",
Message: "Test message",
}
// Create a writer that fails immediately
writer := &failingWriter{failAfter: 1}
err := response.Encode(writer)
assert.Error(t, err)
}
// Helper functions and types for testing
func encodeBigEndianUint32(val uint32) []byte {

View file

@ -0,0 +1,82 @@
package protocol
import (
"encoding/binary"
"encoding/json"
"fmt"
"io"
"hash-of-wisdom/internal/pow/challenge"
"hash-of-wisdom/internal/quotes"
)
// writeHeader writes the message type and payload length to the writer
func writeHeader(w io.Writer, msgType MessageType, payloadLength uint32) error {
// Write message type (1 byte)
if err := binary.Write(w, binary.BigEndian, msgType); err != nil {
return fmt.Errorf("failed to write message type: %w", err)
}
// Write payload length (4 bytes, big-endian)
if err := binary.Write(w, binary.BigEndian, payloadLength); err != nil {
return fmt.Errorf("failed to write payload length: %w", err)
}
return nil
}
// encodeResponse is a helper function that encodes any response with the given message type
func encodeResponse(w io.Writer, msgType MessageType, payload interface{}) error {
// Marshal to get exact payload size
payloadBytes, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("failed to encode payload: %w", err)
}
// Write header
if err := writeHeader(w, msgType, uint32(len(payloadBytes))); err != nil {
return err
}
// Write JSON payload directly to stream
if len(payloadBytes) > 0 {
if _, err := w.Write(payloadBytes); err != nil {
return fmt.Errorf("failed to write payload: %w", err)
}
}
return nil
}
// ChallengeResponse represents a challenge response
type ChallengeResponse struct {
Challenge *challenge.Challenge
}
// SolutionResponse represents a successful solution response (contains quote)
type SolutionResponse struct {
Quote *quotes.Quote
}
// ErrorResponse represents an error response
type ErrorResponse struct {
Code string `json:"code"`
Message string `json:"message"`
RetryAfter int `json:"retry_after,omitempty"`
Details map[string]string `json:"details,omitempty"`
}
// Encode writes the challenge response to the writer
func (r *ChallengeResponse) Encode(w io.Writer) error {
return encodeResponse(w, ChallengeResponseType, r.Challenge)
}
// Encode writes the solution response to the writer
func (r *SolutionResponse) Encode(w io.Writer) error {
return encodeResponse(w, QuoteResponseType, r.Quote)
}
// Encode writes the error response to the writer
func (r *ErrorResponse) Encode(w io.Writer) error {
return encodeResponse(w, ErrorResponseType, r)
}