[PHASE-5] Delegate encoding to the objects themselves
This commit is contained in:
parent
94eb94e167
commit
d12de089a0
|
|
@ -15,38 +15,6 @@ func NewCodec() *Codec {
|
||||||
return &Codec{}
|
return &Codec{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encode writes a message to the writer using the protocol format
|
|
||||||
func (c *Codec) Encode(w io.Writer, msg *Message) error {
|
|
||||||
if msg == nil {
|
|
||||||
return fmt.Errorf("message cannot be nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate payload size
|
|
||||||
if len(msg.Payload) > MaxPayloadSize {
|
|
||||||
return fmt.Errorf("payload size %d exceeds maximum %d", len(msg.Payload), MaxPayloadSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write message type (1 byte)
|
|
||||||
if err := binary.Write(w, binary.BigEndian, msg.Type); err != nil {
|
|
||||||
return fmt.Errorf("failed to write message type: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write payload length (4 bytes, big-endian)
|
|
||||||
payloadLength := uint32(len(msg.Payload))
|
|
||||||
if err := binary.Write(w, binary.BigEndian, payloadLength); err != nil {
|
|
||||||
return fmt.Errorf("failed to write payload length: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write payload if present
|
|
||||||
if len(msg.Payload) > 0 {
|
|
||||||
if _, err := w.Write(msg.Payload); err != nil {
|
|
||||||
return fmt.Errorf("failed to write payload: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode reads a message from the reader using the protocol format
|
// Decode reads a message from the reader using the protocol format
|
||||||
func (c *Codec) Decode(r io.Reader) (*Message, error) {
|
func (c *Codec) Decode(r io.Reader) (*Message, error) {
|
||||||
// Read message type (1 byte)
|
// Read message type (1 byte)
|
||||||
|
|
@ -78,12 +46,9 @@ func (c *Codec) Decode(r io.Reader) (*Message, error) {
|
||||||
var payload []byte
|
var payload []byte
|
||||||
if payloadLength > 0 {
|
if payloadLength > 0 {
|
||||||
payload = make([]byte, payloadLength)
|
payload = make([]byte, payloadLength)
|
||||||
// Use LimitReader to ensure we don't read more than payloadLength bytes
|
// ReadFull reads exactly payloadLength bytes
|
||||||
// even if the underlying reader has more data available
|
// The server MUST use LimitReader and set read deadlines to prevent attacks
|
||||||
limitedReader := io.LimitReader(r, int64(payloadLength))
|
if _, err := io.ReadFull(r, payload); err != nil {
|
||||||
// Note: ReadFull may block waiting for data. The connection handler
|
|
||||||
// MUST set appropriate read deadlines to prevent slowloris attacks
|
|
||||||
if _, err := io.ReadFull(limitedReader, payload); err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to read payload: %w", err)
|
return nil, fmt.Errorf("failed to read payload: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -103,7 +68,7 @@ func (c *Codec) Decode(r io.Reader) (*Message, error) {
|
||||||
// isValidMessageType checks if the message type is defined in the protocol
|
// isValidMessageType checks if the message type is defined in the protocol
|
||||||
func isValidMessageType(msgType MessageType) bool {
|
func isValidMessageType(msgType MessageType) bool {
|
||||||
switch msgType {
|
switch msgType {
|
||||||
case ChallengeRequest, ChallengeResponse, SolutionRequest, QuoteResponse, ErrorResponse:
|
case ChallengeRequestType, ChallengeResponseType, SolutionRequestType, QuoteResponseType, ErrorResponseType:
|
||||||
return true
|
return true
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
|
|
|
||||||
|
|
@ -2,102 +2,18 @@ package protocol
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
|
||||||
"io"
|
"io"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"hash-of-wisdom/internal/pow/challenge"
|
||||||
|
"hash-of-wisdom/internal/quotes"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCodec_Encode_Decode(t *testing.T) {
|
func TestCodec_Decode(t *testing.T) {
|
||||||
codec := NewCodec()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
message *Message
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "challenge request (empty payload)",
|
|
||||||
message: &Message{
|
|
||||||
Type: ChallengeRequest,
|
|
||||||
Payload: []byte{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "challenge response with payload",
|
|
||||||
message: &Message{
|
|
||||||
Type: ChallengeResponse,
|
|
||||||
Payload: []byte(`{"challenge":{"timestamp":1640995200,"difficulty":4}}`),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "error response",
|
|
||||||
message: &Message{
|
|
||||||
Type: ErrorResponse,
|
|
||||||
Payload: []byte(`{"code":"INVALID_SOLUTION","message":"Invalid nonce"}`),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
|
|
||||||
// Encode message
|
|
||||||
err := codec.Encode(&buf, tt.message)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Decode message
|
|
||||||
decoded, err := codec.Decode(&buf)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, tt.message.Type, decoded.Type)
|
|
||||||
if len(tt.message.Payload) == 0 && len(decoded.Payload) == 0 {
|
|
||||||
// Both are empty (nil or empty slice)
|
|
||||||
assert.True(t, true)
|
|
||||||
} else {
|
|
||||||
assert.Equal(t, tt.message.Payload, decoded.Payload)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCodec_Encode_Errors(t *testing.T) {
|
|
||||||
codec := NewCodec()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
message *Message
|
|
||||||
wantErr string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "nil message",
|
|
||||||
message: nil,
|
|
||||||
wantErr: "message cannot be nil",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "payload too large",
|
|
||||||
message: &Message{
|
|
||||||
Type: ChallengeRequest,
|
|
||||||
Payload: make([]byte, MaxPayloadSize+1),
|
|
||||||
},
|
|
||||||
wantErr: "payload size",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
err := codec.Encode(&buf, tt.message)
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.Contains(t, err.Error(), tt.wantErr)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCodec_Decode_Errors(t *testing.T) {
|
|
||||||
codec := NewCodec()
|
codec := NewCodec()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
|
|
@ -147,96 +63,6 @@ func TestCodec_Decode_Errors(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
func TestCodec_RoundTrip_RealPayloads(t *testing.T) {
|
|
||||||
codec := NewCodec()
|
|
||||||
|
|
||||||
t.Run("challenge response round trip", func(t *testing.T) {
|
|
||||||
original := &ChallengeResponsePayload{
|
|
||||||
Timestamp: time.Now().Unix(),
|
|
||||||
Difficulty: 4,
|
|
||||||
Resource: "quotes",
|
|
||||||
Random: []byte("random123"),
|
|
||||||
HMAC: []byte("hmac_signature"),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Marshal payload
|
|
||||||
jsonData, err := json.Marshal(original)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
msg := &Message{
|
|
||||||
Type: ChallengeResponse,
|
|
||||||
Payload: jsonData,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simulate network transmission
|
|
||||||
var buf bytes.Buffer
|
|
||||||
err = codec.Encode(&buf, msg)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Decode message
|
|
||||||
decoded, err := codec.Decode(&buf)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Unmarshal payload
|
|
||||||
var result ChallengeResponsePayload
|
|
||||||
err = json.Unmarshal(decoded.Payload, &result)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, original.Timestamp, result.Timestamp)
|
|
||||||
assert.Equal(t, original.Difficulty, result.Difficulty)
|
|
||||||
assert.Equal(t, original.Resource, result.Resource)
|
|
||||||
assert.Equal(t, original.Random, result.Random)
|
|
||||||
assert.Equal(t, original.HMAC, result.HMAC)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("quote response round trip", func(t *testing.T) {
|
|
||||||
original := &QuoteResponsePayload{
|
|
||||||
Text: "Test quote",
|
|
||||||
Author: "Test author",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Marshal payload
|
|
||||||
jsonData, err := json.Marshal(original)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
msg := &Message{
|
|
||||||
Type: QuoteResponse,
|
|
||||||
Payload: jsonData,
|
|
||||||
}
|
|
||||||
|
|
||||||
var buf bytes.Buffer
|
|
||||||
err = codec.Encode(&buf, msg)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
decoded, err := codec.Decode(&buf)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
var result QuoteResponsePayload
|
|
||||||
err = json.Unmarshal(decoded.Payload, &result)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, original.Text, result.Text)
|
|
||||||
assert.Equal(t, original.Author, result.Author)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCodec_WriteError_Handling(t *testing.T) {
|
|
||||||
codec := NewCodec()
|
|
||||||
|
|
||||||
// Create a writer that fails after a certain number of bytes
|
|
||||||
failAfter := 3
|
|
||||||
writer := &failingWriter{failAfter: failAfter}
|
|
||||||
|
|
||||||
msg := &Message{
|
|
||||||
Type: ChallengeResponse,
|
|
||||||
Payload: []byte("test payload"),
|
|
||||||
}
|
|
||||||
|
|
||||||
err := codec.Encode(writer, msg)
|
|
||||||
assert.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCodec_ReadError_Handling(t *testing.T) {
|
func TestCodec_ReadError_Handling(t *testing.T) {
|
||||||
codec := NewCodec()
|
codec := NewCodec()
|
||||||
|
|
||||||
|
|
@ -251,6 +77,89 @@ func TestCodec_ReadError_Handling(t *testing.T) {
|
||||||
assert.Contains(t, err.Error(), "failed to read payload")
|
assert.Contains(t, err.Error(), "failed to read payload")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestChallengeResponse_Encode(t *testing.T) {
|
||||||
|
challenge := &challenge.Challenge{
|
||||||
|
Timestamp: time.Now().Unix(),
|
||||||
|
Difficulty: 4,
|
||||||
|
Resource: "quotes",
|
||||||
|
Random: []byte("random123"),
|
||||||
|
HMAC: []byte("hmac_signature"),
|
||||||
|
}
|
||||||
|
|
||||||
|
response := &ChallengeResponse{Challenge: challenge}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := response.Encode(&buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the encoded message can be decoded
|
||||||
|
codec := NewCodec()
|
||||||
|
decoded, err := codec.Decode(&buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, ChallengeResponseType, decoded.Type)
|
||||||
|
assert.Contains(t, string(decoded.Payload), "quotes")
|
||||||
|
assert.Contains(t, string(decoded.Payload), "cmFuZG9tMTIz") // "random123" base64 encoded
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSolutionResponse_Encode(t *testing.T) {
|
||||||
|
quote := "es.Quote{
|
||||||
|
Text: "Test quote",
|
||||||
|
Author: "Test author",
|
||||||
|
}
|
||||||
|
|
||||||
|
response := &SolutionResponse{Quote: quote}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := response.Encode(&buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the encoded message can be decoded
|
||||||
|
codec := NewCodec()
|
||||||
|
decoded, err := codec.Decode(&buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, QuoteResponseType, decoded.Type)
|
||||||
|
assert.Contains(t, string(decoded.Payload), "Test quote")
|
||||||
|
assert.Contains(t, string(decoded.Payload), "Test author")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestErrorResponse_Encode(t *testing.T) {
|
||||||
|
errorResp := &ErrorResponse{
|
||||||
|
Code: "INVALID_SOLUTION",
|
||||||
|
Message: "The provided PoW solution is incorrect",
|
||||||
|
RetryAfter: 30,
|
||||||
|
Details: map[string]string{"attempt": "1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := errorResp.Encode(&buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the encoded message can be decoded
|
||||||
|
codec := NewCodec()
|
||||||
|
decoded, err := codec.Decode(&buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, ErrorResponseType, decoded.Type)
|
||||||
|
assert.Contains(t, string(decoded.Payload), "INVALID_SOLUTION")
|
||||||
|
assert.Contains(t, string(decoded.Payload), "The provided PoW solution is incorrect")
|
||||||
|
assert.Contains(t, string(decoded.Payload), "30")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponse_WriteError_Handling(t *testing.T) {
|
||||||
|
response := &ErrorResponse{
|
||||||
|
Code: "TEST_ERROR",
|
||||||
|
Message: "Test message",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a writer that fails immediately
|
||||||
|
writer := &failingWriter{failAfter: 1}
|
||||||
|
|
||||||
|
err := response.Encode(writer)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
// Helper functions and types for testing
|
// Helper functions and types for testing
|
||||||
|
|
||||||
func encodeBigEndianUint32(val uint32) []byte {
|
func encodeBigEndianUint32(val uint32) []byte {
|
||||||
|
|
|
||||||
82
internal/protocol/responses.go
Normal file
82
internal/protocol/responses.go
Normal file
|
|
@ -0,0 +1,82 @@
|
||||||
|
package protocol
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"hash-of-wisdom/internal/pow/challenge"
|
||||||
|
"hash-of-wisdom/internal/quotes"
|
||||||
|
)
|
||||||
|
|
||||||
|
// writeHeader writes the message type and payload length to the writer
|
||||||
|
func writeHeader(w io.Writer, msgType MessageType, payloadLength uint32) error {
|
||||||
|
// Write message type (1 byte)
|
||||||
|
if err := binary.Write(w, binary.BigEndian, msgType); err != nil {
|
||||||
|
return fmt.Errorf("failed to write message type: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write payload length (4 bytes, big-endian)
|
||||||
|
if err := binary.Write(w, binary.BigEndian, payloadLength); err != nil {
|
||||||
|
return fmt.Errorf("failed to write payload length: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeResponse is a helper function that encodes any response with the given message type
|
||||||
|
func encodeResponse(w io.Writer, msgType MessageType, payload interface{}) error {
|
||||||
|
// Marshal to get exact payload size
|
||||||
|
payloadBytes, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to encode payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write header
|
||||||
|
if err := writeHeader(w, msgType, uint32(len(payloadBytes))); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write JSON payload directly to stream
|
||||||
|
if len(payloadBytes) > 0 {
|
||||||
|
if _, err := w.Write(payloadBytes); err != nil {
|
||||||
|
return fmt.Errorf("failed to write payload: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChallengeResponse represents a challenge response
|
||||||
|
type ChallengeResponse struct {
|
||||||
|
Challenge *challenge.Challenge
|
||||||
|
}
|
||||||
|
|
||||||
|
// SolutionResponse represents a successful solution response (contains quote)
|
||||||
|
type SolutionResponse struct {
|
||||||
|
Quote *quotes.Quote
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorResponse represents an error response
|
||||||
|
type ErrorResponse struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
RetryAfter int `json:"retry_after,omitempty"`
|
||||||
|
Details map[string]string `json:"details,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode writes the challenge response to the writer
|
||||||
|
func (r *ChallengeResponse) Encode(w io.Writer) error {
|
||||||
|
return encodeResponse(w, ChallengeResponseType, r.Challenge)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode writes the solution response to the writer
|
||||||
|
func (r *SolutionResponse) Encode(w io.Writer) error {
|
||||||
|
return encodeResponse(w, QuoteResponseType, r.Quote)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode writes the error response to the writer
|
||||||
|
func (r *ErrorResponse) Encode(w io.Writer) error {
|
||||||
|
return encodeResponse(w, ErrorResponseType, r)
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue