// Source: hash-of-wisdom/internal/protocol/message_decoder_test.go

package protocol
import (
"bytes"
"io"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestMessageDecoder_Decode_Header checks that Decode parses a request frame
// header (one message-type byte followed by a 4-byte big-endian payload
// length). It also checks two things about the payload stream:
//   - it is exposed only when the declared length is non-zero;
//   - it yields exactly the bytes that followed the header.
//
// Error cases live in TestMessageDecoder_Decode_Errors.
func TestMessageDecoder_Decode_Header(t *testing.T) {
	decoder := NewMessageDecoder()

	tests := []struct {
		name        string
		data        []byte
		wantType    MessageType
		wantLength  uint32
		wantPayload []byte // bytes expected from PayloadStream; unused when wantLength == 0
	}{
		{
			name:       "challenge request with empty payload",
			data:       []byte{0x01, 0x00, 0x00, 0x00, 0x00},
			wantType:   ChallengeRequestType,
			wantLength: 0,
		},
		{
			name:        "solution request with payload",
			data:        append([]byte{0x03, 0x00, 0x00, 0x00, 0x05}, []byte("hello")...),
			wantType:    SolutionRequestType,
			wantLength:  5,
			wantPayload: []byte("hello"),
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			msg, err := decoder.Decode(bytes.NewBuffer(tt.data))
			require.NoError(t, err)
			assert.Equal(t, tt.wantType, msg.Type)
			assert.Equal(t, tt.wantLength, msg.PayloadLength)

			if tt.wantLength == 0 {
				assert.Nil(t, msg.PayloadStream)
				return
			}

			// Checking non-nil alone would not catch a stream that yields
			// the wrong bytes, so read it and compare against the payload.
			require.NotNil(t, msg.PayloadStream)
			got, err := io.ReadAll(msg.PayloadStream)
			require.NoError(t, err)
			assert.Equal(t, tt.wantPayload, got)
		})
	}
}
// TestMessageDecoder_Decode_Errors checks that Decode rejects malformed or
// disallowed frames and that each error message contains the expected text.
// The cases are:
//   - an empty stream;
//   - an unknown type byte;
//   - a response type sent where a request is expected;
//   - a truncated header;
//   - a declared payload length above MaxPayloadSize.
func TestMessageDecoder_Decode_Errors(t *testing.T) {
	decoder := NewMessageDecoder()

	tests := []struct {
		name    string
		data    []byte
		wantErr string
	}{
		{
			name:    "empty data",
			data:    []byte{},
			wantErr: "EOF",
		},
		{
			name:    "invalid message type",
			data:    []byte{0xFF, 0x00, 0x00, 0x00, 0x00},
			wantErr: "invalid message type",
		},
		{
			name:    "response type not allowed",
			data:    []byte{0x02, 0x00, 0x00, 0x00, 0x00}, // ChallengeResponseType
			wantErr: "invalid message type",
		},
		{
			name:    "incomplete header",
			data:    []byte{0x01, 0x00, 0x00},
			wantErr: "failed to read payload length",
		},
		{
			name:    "payload too large",
			data:    append([]byte{0x01}, encodeBigEndianUint32(MaxPayloadSize+1)...),
			wantErr: "payload length",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			_, err := decoder.Decode(bytes.NewBuffer(tt.data))
			// require, not assert: if Decode wrongly returns nil, err.Error()
			// below would panic instead of reporting a clean test failure.
			require.Error(t, err)
			assert.Contains(t, err.Error(), tt.wantErr)
		})
	}
}
// TestChallengeRequest_Decode verifies that ChallengeRequest.Decode succeeds
// both when it is handed no stream and when the stream holds arbitrary bytes.
func TestChallengeRequest_Decode(t *testing.T) {
	cases := []struct {
		name   string
		stream io.Reader
	}{
		{name: "nil stream", stream: nil},
		{name: "non-empty stream", stream: bytes.NewReader([]byte("ignored"))},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			var req ChallengeRequest
			assert.NoError(t, req.Decode(tc.stream))
		})
	}
}
// TestSolutionRequest_Decode verifies two things about SolutionRequest.Decode:
//   - it reads the nonce from a JSON payload that also carries a challenge
//     object;
//   - it reports an error for malformed JSON, non-UTF-8 bytes, and a nil
//     stream.
func TestSolutionRequest_Decode(t *testing.T) {
	cases := []struct {
		name      string
		payload   []byte // nil means Decode receives a nil io.Reader
		wantErr   bool
		wantNonce uint64
	}{
		{
			name:      "valid solution request",
			payload:   []byte(`{"challenge":{"timestamp":1640995200,"difficulty":4,"resource":"quotes","random":"cmFuZG9tMTIz","hmac":"aG1hY19zaWduYXR1cmU="},"nonce":12345}`),
			wantNonce: 12345,
		},
		{
			name:    "invalid JSON",
			payload: []byte(`{invalid json}`),
			wantErr: true,
		},
		{
			name:    "invalid UTF-8",
			payload: []byte{0xFF, 0xFE, 0xFD},
			wantErr: true,
		},
		{
			name:    "nil stream should error",
			wantErr: true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			var stream io.Reader
			if tc.payload != nil {
				stream = bytes.NewReader(tc.payload)
			}

			var req SolutionRequest
			err := req.Decode(stream)
			if tc.wantErr {
				assert.Error(t, err)
				return
			}
			require.NoError(t, err)
			assert.Equal(t, tc.wantNonce, req.Nonce)
		})
	}
}
// TestEndToEnd_RequestDecoding runs raw request bytes through the full
// pipeline in two steps:
//  1. MessageDecoder.Decode parses the frame header.
//  2. The matching request type decodes the payload stream that Decode
//     returned.
func TestEndToEnd_RequestDecoding(t *testing.T) {
	decoder := NewMessageDecoder()

	t.Run("challenge request flow", func(t *testing.T) {
		// Frame: type 0x01 (ChallengeRequestType), payload length 0.
		frame := bytes.NewBuffer([]byte{0x01, 0x00, 0x00, 0x00, 0x00})

		msg, err := decoder.Decode(frame)
		require.NoError(t, err)
		assert.Equal(t, ChallengeRequestType, msg.Type)

		var req ChallengeRequest
		require.NoError(t, req.Decode(msg.PayloadStream))
	})

	t.Run("solution request flow", func(t *testing.T) {
		const payload = `{"challenge":{"timestamp":1640995200,"difficulty":4,"resource":"quotes","random":"cmFuZG9tMTIz","hmac":"aG1hY19zaWduYXR1cmU="},"nonce":12345}`
		size := uint32(len(payload))

		// Frame: type 0x03 (SolutionRequestType), big-endian length, JSON payload.
		raw := append([]byte{0x03}, encodeBigEndianUint32(size)...)
		raw = append(raw, payload...)

		msg, err := decoder.Decode(bytes.NewBuffer(raw))
		require.NoError(t, err)
		assert.Equal(t, SolutionRequestType, msg.Type)
		assert.Equal(t, size, msg.PayloadLength)

		var req SolutionRequest
		require.NoError(t, req.Decode(msg.PayloadStream))
		assert.Equal(t, uint64(12345), req.Nonce)
	})
}
// Helper functions for testing

// encodeBigEndianUint32 returns the 4-byte big-endian encoding of val (most
// significant byte first). The tests use it to build the payload-length
// field of a request frame.
func encodeBigEndianUint32(val uint32) []byte {
	const width = 4
	out := make([]byte, width)
	for i := 0; i < width; i++ {
		// Byte i from the low end lands at index width-1-i.
		out[width-1-i] = byte(val >> (8 * i))
	}
	return out
}