[PHASE-5] Delegate encoding to the objects themselves
This commit is contained in:
parent
94eb94e167
commit
d12de089a0
|
|
@ -15,38 +15,6 @@ func NewCodec() *Codec {
|
|||
return &Codec{}
|
||||
}
|
||||
|
||||
// Encode writes a message to the writer using the protocol format
|
||||
func (c *Codec) Encode(w io.Writer, msg *Message) error {
|
||||
if msg == nil {
|
||||
return fmt.Errorf("message cannot be nil")
|
||||
}
|
||||
|
||||
// Validate payload size
|
||||
if len(msg.Payload) > MaxPayloadSize {
|
||||
return fmt.Errorf("payload size %d exceeds maximum %d", len(msg.Payload), MaxPayloadSize)
|
||||
}
|
||||
|
||||
// Write message type (1 byte)
|
||||
if err := binary.Write(w, binary.BigEndian, msg.Type); err != nil {
|
||||
return fmt.Errorf("failed to write message type: %w", err)
|
||||
}
|
||||
|
||||
// Write payload length (4 bytes, big-endian)
|
||||
payloadLength := uint32(len(msg.Payload))
|
||||
if err := binary.Write(w, binary.BigEndian, payloadLength); err != nil {
|
||||
return fmt.Errorf("failed to write payload length: %w", err)
|
||||
}
|
||||
|
||||
// Write payload if present
|
||||
if len(msg.Payload) > 0 {
|
||||
if _, err := w.Write(msg.Payload); err != nil {
|
||||
return fmt.Errorf("failed to write payload: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decode reads a message from the reader using the protocol format
|
||||
func (c *Codec) Decode(r io.Reader) (*Message, error) {
|
||||
// Read message type (1 byte)
|
||||
|
|
@ -78,12 +46,9 @@ func (c *Codec) Decode(r io.Reader) (*Message, error) {
|
|||
var payload []byte
|
||||
if payloadLength > 0 {
|
||||
payload = make([]byte, payloadLength)
|
||||
// Use LimitReader to ensure we don't read more than payloadLength bytes
|
||||
// even if the underlying reader has more data available
|
||||
limitedReader := io.LimitReader(r, int64(payloadLength))
|
||||
// Note: ReadFull may block waiting for data. The connection handler
|
||||
// MUST set appropriate read deadlines to prevent slowloris attacks
|
||||
if _, err := io.ReadFull(limitedReader, payload); err != nil {
|
||||
// ReadFull reads exactly payloadLength bytes
|
||||
// The server MUST use LimitReader and set read deadlines to prevent attacks
|
||||
if _, err := io.ReadFull(r, payload); err != nil {
|
||||
return nil, fmt.Errorf("failed to read payload: %w", err)
|
||||
}
|
||||
|
||||
|
|
@ -103,7 +68,7 @@ func (c *Codec) Decode(r io.Reader) (*Message, error) {
|
|||
// isValidMessageType checks if the message type is defined in the protocol
|
||||
func isValidMessageType(msgType MessageType) bool {
|
||||
switch msgType {
|
||||
case ChallengeRequest, ChallengeResponse, SolutionRequest, QuoteResponse, ErrorResponse:
|
||||
case ChallengeRequestType, ChallengeResponseType, SolutionRequestType, QuoteResponseType, ErrorResponseType:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
|
|
|
|||
|
|
@ -2,102 +2,18 @@ package protocol
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"hash-of-wisdom/internal/pow/challenge"
|
||||
"hash-of-wisdom/internal/quotes"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCodec_Encode_Decode(t *testing.T) {
|
||||
codec := NewCodec()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
message *Message
|
||||
}{
|
||||
{
|
||||
name: "challenge request (empty payload)",
|
||||
message: &Message{
|
||||
Type: ChallengeRequest,
|
||||
Payload: []byte{},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "challenge response with payload",
|
||||
message: &Message{
|
||||
Type: ChallengeResponse,
|
||||
Payload: []byte(`{"challenge":{"timestamp":1640995200,"difficulty":4}}`),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "error response",
|
||||
message: &Message{
|
||||
Type: ErrorResponse,
|
||||
Payload: []byte(`{"code":"INVALID_SOLUTION","message":"Invalid nonce"}`),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
|
||||
// Encode message
|
||||
err := codec.Encode(&buf, tt.message)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Decode message
|
||||
decoded, err := codec.Decode(&buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tt.message.Type, decoded.Type)
|
||||
if len(tt.message.Payload) == 0 && len(decoded.Payload) == 0 {
|
||||
// Both are empty (nil or empty slice)
|
||||
assert.True(t, true)
|
||||
} else {
|
||||
assert.Equal(t, tt.message.Payload, decoded.Payload)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodec_Encode_Errors(t *testing.T) {
|
||||
codec := NewCodec()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
message *Message
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "nil message",
|
||||
message: nil,
|
||||
wantErr: "message cannot be nil",
|
||||
},
|
||||
{
|
||||
name: "payload too large",
|
||||
message: &Message{
|
||||
Type: ChallengeRequest,
|
||||
Payload: make([]byte, MaxPayloadSize+1),
|
||||
},
|
||||
wantErr: "payload size",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
err := codec.Encode(&buf, tt.message)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.wantErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCodec_Decode_Errors(t *testing.T) {
|
||||
func TestCodec_Decode(t *testing.T) {
|
||||
codec := NewCodec()
|
||||
|
||||
tests := []struct {
|
||||
|
|
@ -147,96 +63,6 @@ func TestCodec_Decode_Errors(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
func TestCodec_RoundTrip_RealPayloads(t *testing.T) {
|
||||
codec := NewCodec()
|
||||
|
||||
t.Run("challenge response round trip", func(t *testing.T) {
|
||||
original := &ChallengeResponsePayload{
|
||||
Timestamp: time.Now().Unix(),
|
||||
Difficulty: 4,
|
||||
Resource: "quotes",
|
||||
Random: []byte("random123"),
|
||||
HMAC: []byte("hmac_signature"),
|
||||
}
|
||||
|
||||
// Marshal payload
|
||||
jsonData, err := json.Marshal(original)
|
||||
require.NoError(t, err)
|
||||
|
||||
msg := &Message{
|
||||
Type: ChallengeResponse,
|
||||
Payload: jsonData,
|
||||
}
|
||||
|
||||
// Simulate network transmission
|
||||
var buf bytes.Buffer
|
||||
err = codec.Encode(&buf, msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Decode message
|
||||
decoded, err := codec.Decode(&buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Unmarshal payload
|
||||
var result ChallengeResponsePayload
|
||||
err = json.Unmarshal(decoded.Payload, &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, original.Timestamp, result.Timestamp)
|
||||
assert.Equal(t, original.Difficulty, result.Difficulty)
|
||||
assert.Equal(t, original.Resource, result.Resource)
|
||||
assert.Equal(t, original.Random, result.Random)
|
||||
assert.Equal(t, original.HMAC, result.HMAC)
|
||||
})
|
||||
|
||||
t.Run("quote response round trip", func(t *testing.T) {
|
||||
original := &QuoteResponsePayload{
|
||||
Text: "Test quote",
|
||||
Author: "Test author",
|
||||
}
|
||||
|
||||
// Marshal payload
|
||||
jsonData, err := json.Marshal(original)
|
||||
require.NoError(t, err)
|
||||
|
||||
msg := &Message{
|
||||
Type: QuoteResponse,
|
||||
Payload: jsonData,
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = codec.Encode(&buf, msg)
|
||||
require.NoError(t, err)
|
||||
|
||||
decoded, err := codec.Decode(&buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
var result QuoteResponsePayload
|
||||
err = json.Unmarshal(decoded.Payload, &result)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, original.Text, result.Text)
|
||||
assert.Equal(t, original.Author, result.Author)
|
||||
})
|
||||
}
|
||||
|
||||
func TestCodec_WriteError_Handling(t *testing.T) {
|
||||
codec := NewCodec()
|
||||
|
||||
// Create a writer that fails after a certain number of bytes
|
||||
failAfter := 3
|
||||
writer := &failingWriter{failAfter: failAfter}
|
||||
|
||||
msg := &Message{
|
||||
Type: ChallengeResponse,
|
||||
Payload: []byte("test payload"),
|
||||
}
|
||||
|
||||
err := codec.Encode(writer, msg)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCodec_ReadError_Handling(t *testing.T) {
|
||||
codec := NewCodec()
|
||||
|
||||
|
|
@ -251,6 +77,89 @@ func TestCodec_ReadError_Handling(t *testing.T) {
|
|||
assert.Contains(t, err.Error(), "failed to read payload")
|
||||
}
|
||||
|
||||
func TestChallengeResponse_Encode(t *testing.T) {
|
||||
challenge := &challenge.Challenge{
|
||||
Timestamp: time.Now().Unix(),
|
||||
Difficulty: 4,
|
||||
Resource: "quotes",
|
||||
Random: []byte("random123"),
|
||||
HMAC: []byte("hmac_signature"),
|
||||
}
|
||||
|
||||
response := &ChallengeResponse{Challenge: challenge}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := response.Encode(&buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the encoded message can be decoded
|
||||
codec := NewCodec()
|
||||
decoded, err := codec.Decode(&buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, ChallengeResponseType, decoded.Type)
|
||||
assert.Contains(t, string(decoded.Payload), "quotes")
|
||||
assert.Contains(t, string(decoded.Payload), "cmFuZG9tMTIz") // "random123" base64 encoded
|
||||
}
|
||||
|
||||
func TestSolutionResponse_Encode(t *testing.T) {
|
||||
quote := "es.Quote{
|
||||
Text: "Test quote",
|
||||
Author: "Test author",
|
||||
}
|
||||
|
||||
response := &SolutionResponse{Quote: quote}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := response.Encode(&buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the encoded message can be decoded
|
||||
codec := NewCodec()
|
||||
decoded, err := codec.Decode(&buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, QuoteResponseType, decoded.Type)
|
||||
assert.Contains(t, string(decoded.Payload), "Test quote")
|
||||
assert.Contains(t, string(decoded.Payload), "Test author")
|
||||
}
|
||||
|
||||
func TestErrorResponse_Encode(t *testing.T) {
|
||||
errorResp := &ErrorResponse{
|
||||
Code: "INVALID_SOLUTION",
|
||||
Message: "The provided PoW solution is incorrect",
|
||||
RetryAfter: 30,
|
||||
Details: map[string]string{"attempt": "1"},
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
err := errorResp.Encode(&buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify the encoded message can be decoded
|
||||
codec := NewCodec()
|
||||
decoded, err := codec.Decode(&buf)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, ErrorResponseType, decoded.Type)
|
||||
assert.Contains(t, string(decoded.Payload), "INVALID_SOLUTION")
|
||||
assert.Contains(t, string(decoded.Payload), "The provided PoW solution is incorrect")
|
||||
assert.Contains(t, string(decoded.Payload), "30")
|
||||
}
|
||||
|
||||
func TestResponse_WriteError_Handling(t *testing.T) {
|
||||
response := &ErrorResponse{
|
||||
Code: "TEST_ERROR",
|
||||
Message: "Test message",
|
||||
}
|
||||
|
||||
// Create a writer that fails immediately
|
||||
writer := &failingWriter{failAfter: 1}
|
||||
|
||||
err := response.Encode(writer)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
|
||||
// Helper functions and types for testing
|
||||
|
||||
func encodeBigEndianUint32(val uint32) []byte {
|
||||
|
|
|
|||
82
internal/protocol/responses.go
Normal file
82
internal/protocol/responses.go
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"hash-of-wisdom/internal/pow/challenge"
|
||||
"hash-of-wisdom/internal/quotes"
|
||||
)
|
||||
|
||||
// writeHeader writes the message type and payload length to the writer
|
||||
func writeHeader(w io.Writer, msgType MessageType, payloadLength uint32) error {
|
||||
// Write message type (1 byte)
|
||||
if err := binary.Write(w, binary.BigEndian, msgType); err != nil {
|
||||
return fmt.Errorf("failed to write message type: %w", err)
|
||||
}
|
||||
|
||||
// Write payload length (4 bytes, big-endian)
|
||||
if err := binary.Write(w, binary.BigEndian, payloadLength); err != nil {
|
||||
return fmt.Errorf("failed to write payload length: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// encodeResponse is a helper function that encodes any response with the given message type
|
||||
func encodeResponse(w io.Writer, msgType MessageType, payload interface{}) error {
|
||||
// Marshal to get exact payload size
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode payload: %w", err)
|
||||
}
|
||||
|
||||
// Write header
|
||||
if err := writeHeader(w, msgType, uint32(len(payloadBytes))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write JSON payload directly to stream
|
||||
if len(payloadBytes) > 0 {
|
||||
if _, err := w.Write(payloadBytes); err != nil {
|
||||
return fmt.Errorf("failed to write payload: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ChallengeResponse represents a challenge response
|
||||
type ChallengeResponse struct {
|
||||
Challenge *challenge.Challenge
|
||||
}
|
||||
|
||||
// SolutionResponse represents a successful solution response (contains quote)
|
||||
type SolutionResponse struct {
|
||||
Quote *quotes.Quote
|
||||
}
|
||||
|
||||
// ErrorResponse represents an error response
|
||||
type ErrorResponse struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
RetryAfter int `json:"retry_after,omitempty"`
|
||||
Details map[string]string `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// Encode writes the challenge response to the writer
|
||||
func (r *ChallengeResponse) Encode(w io.Writer) error {
|
||||
return encodeResponse(w, ChallengeResponseType, r.Challenge)
|
||||
}
|
||||
|
||||
// Encode writes the solution response to the writer
|
||||
func (r *SolutionResponse) Encode(w io.Writer) error {
|
||||
return encodeResponse(w, QuoteResponseType, r.Quote)
|
||||
}
|
||||
|
||||
// Encode writes the error response to the writer
|
||||
func (r *ErrorResponse) Encode(w io.Writer) error {
|
||||
return encodeResponse(w, ErrorResponseType, r)
|
||||
}
|
||||
Loading…
Reference in a new issue