207 lines
5 KiB
Go
207 lines
5 KiB
Go
|
|
package protocol
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bytes"
|
||
|
|
"testing"
|
||
|
|
|
||
|
|
"github.com/stretchr/testify/assert"
|
||
|
|
"github.com/stretchr/testify/require"
|
||
|
|
)
|
||
|
|
|
||
|
|
func TestMessageDecoder_Decode_Header(t *testing.T) {
|
||
|
|
decoder := NewMessageDecoder()
|
||
|
|
|
||
|
|
tests := []struct {
|
||
|
|
name string
|
||
|
|
data []byte
|
||
|
|
wantType MessageType
|
||
|
|
wantLength uint32
|
||
|
|
wantErr string
|
||
|
|
}{
|
||
|
|
{
|
||
|
|
name: "challenge request with empty payload",
|
||
|
|
data: []byte{0x01, 0x00, 0x00, 0x00, 0x00},
|
||
|
|
wantType: ChallengeRequestType,
|
||
|
|
wantLength: 0,
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "solution request with payload",
|
||
|
|
data: append([]byte{0x03, 0x00, 0x00, 0x00, 0x05}, []byte("hello")...),
|
||
|
|
wantType: SolutionRequestType,
|
||
|
|
wantLength: 5,
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, tt := range tests {
|
||
|
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
|
buf := bytes.NewBuffer(tt.data)
|
||
|
|
msg, err := decoder.Decode(buf)
|
||
|
|
|
||
|
|
if tt.wantErr != "" {
|
||
|
|
assert.Error(t, err)
|
||
|
|
assert.Contains(t, err.Error(), tt.wantErr)
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
require.NoError(t, err)
|
||
|
|
assert.Equal(t, tt.wantType, msg.Type)
|
||
|
|
assert.Equal(t, tt.wantLength, msg.PayloadLength)
|
||
|
|
|
||
|
|
if tt.wantLength > 0 {
|
||
|
|
assert.NotNil(t, msg.PayloadStream)
|
||
|
|
} else {
|
||
|
|
assert.Nil(t, msg.PayloadStream)
|
||
|
|
}
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestMessageDecoder_Decode_Errors(t *testing.T) {
|
||
|
|
decoder := NewMessageDecoder()
|
||
|
|
|
||
|
|
tests := []struct {
|
||
|
|
name string
|
||
|
|
data []byte
|
||
|
|
wantErr string
|
||
|
|
}{
|
||
|
|
{
|
||
|
|
name: "empty data",
|
||
|
|
data: []byte{},
|
||
|
|
wantErr: "EOF",
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "invalid message type",
|
||
|
|
data: []byte{0xFF, 0x00, 0x00, 0x00, 0x00},
|
||
|
|
wantErr: "invalid message type",
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "response type not allowed",
|
||
|
|
data: []byte{0x02, 0x00, 0x00, 0x00, 0x00}, // ChallengeResponseType
|
||
|
|
wantErr: "invalid message type",
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "incomplete header",
|
||
|
|
data: []byte{0x01, 0x00, 0x00},
|
||
|
|
wantErr: "failed to read payload length",
|
||
|
|
},
|
||
|
|
{
|
||
|
|
name: "payload too large",
|
||
|
|
data: append([]byte{0x01}, encodeBigEndianUint32(MaxPayloadSize+1)...),
|
||
|
|
wantErr: "payload length",
|
||
|
|
},
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, tt := range tests {
|
||
|
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
|
buf := bytes.NewBuffer(tt.data)
|
||
|
|
_, err := decoder.Decode(buf)
|
||
|
|
assert.Error(t, err)
|
||
|
|
assert.Contains(t, err.Error(), tt.wantErr)
|
||
|
|
})
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestChallengeRequest_Decode(t *testing.T) {
|
||
|
|
req := &ChallengeRequest{}
|
||
|
|
|
||
|
|
t.Run("always succeeds", func(t *testing.T) {
|
||
|
|
err := req.Decode(nil)
|
||
|
|
assert.NoError(t, err)
|
||
|
|
|
||
|
|
err = req.Decode(bytes.NewReader([]byte("ignored")))
|
||
|
|
assert.NoError(t, err)
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestSolutionRequest_Decode(t *testing.T) {
|
||
|
|
req := &SolutionRequest{}
|
||
|
|
|
||
|
|
t.Run("valid solution request", func(t *testing.T) {
|
||
|
|
payload := `{"challenge":{"timestamp":1640995200,"difficulty":4,"resource":"quotes","random":"cmFuZG9tMTIz","hmac":"aG1hY19zaWduYXR1cmU="},"nonce":12345}`
|
||
|
|
reader := bytes.NewReader([]byte(payload))
|
||
|
|
|
||
|
|
err := req.Decode(reader)
|
||
|
|
require.NoError(t, err)
|
||
|
|
|
||
|
|
assert.Equal(t, int64(1640995200), req.Challenge.Timestamp)
|
||
|
|
assert.Equal(t, 4, req.Challenge.Difficulty)
|
||
|
|
assert.Equal(t, "quotes", req.Challenge.Resource)
|
||
|
|
assert.Equal(t, uint64(12345), req.Nonce)
|
||
|
|
})
|
||
|
|
|
||
|
|
t.Run("empty payload should error", func(t *testing.T) {
|
||
|
|
err := req.Decode(nil)
|
||
|
|
assert.Error(t, err)
|
||
|
|
})
|
||
|
|
|
||
|
|
t.Run("invalid JSON should error", func(t *testing.T) {
|
||
|
|
payload := `{invalid json}`
|
||
|
|
reader := bytes.NewReader([]byte(payload))
|
||
|
|
|
||
|
|
err := req.Decode(reader)
|
||
|
|
assert.Error(t, err)
|
||
|
|
})
|
||
|
|
|
||
|
|
t.Run("invalid UTF-8 should error", func(t *testing.T) {
|
||
|
|
payload := []byte{0xFF, 0xFE, 0xFD}
|
||
|
|
reader := bytes.NewReader(payload)
|
||
|
|
|
||
|
|
err := req.Decode(reader)
|
||
|
|
assert.Error(t, err)
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestEndToEnd_RequestDecoding(t *testing.T) {
|
||
|
|
decoder := NewMessageDecoder()
|
||
|
|
|
||
|
|
t.Run("challenge request flow", func(t *testing.T) {
|
||
|
|
// Create message data: type=0x01, length=0
|
||
|
|
data := []byte{0x01, 0x00, 0x00, 0x00, 0x00}
|
||
|
|
buf := bytes.NewBuffer(data)
|
||
|
|
|
||
|
|
// Decode header
|
||
|
|
msg, err := decoder.Decode(buf)
|
||
|
|
require.NoError(t, err)
|
||
|
|
assert.Equal(t, ChallengeRequestType, msg.Type)
|
||
|
|
|
||
|
|
// Decode request
|
||
|
|
req := &ChallengeRequest{}
|
||
|
|
err = req.Decode(msg.PayloadStream)
|
||
|
|
require.NoError(t, err)
|
||
|
|
})
|
||
|
|
|
||
|
|
t.Run("solution request flow", func(t *testing.T) {
|
||
|
|
payload := `{"challenge":{"timestamp":1640995200,"difficulty":4,"resource":"quotes","random":"cmFuZG9tMTIz","hmac":"aG1hY19zaWduYXR1cmU="},"nonce":12345}`
|
||
|
|
|
||
|
|
// Create message data: type=0x03, length, payload
|
||
|
|
var buf bytes.Buffer
|
||
|
|
buf.WriteByte(0x03) // SolutionRequestType
|
||
|
|
length := uint32(len(payload))
|
||
|
|
buf.Write(encodeBigEndianUint32(length))
|
||
|
|
buf.WriteString(payload)
|
||
|
|
|
||
|
|
// Decode header
|
||
|
|
msg, err := decoder.Decode(&buf)
|
||
|
|
require.NoError(t, err)
|
||
|
|
assert.Equal(t, SolutionRequestType, msg.Type)
|
||
|
|
assert.Equal(t, length, msg.PayloadLength)
|
||
|
|
|
||
|
|
// Decode request
|
||
|
|
req := &SolutionRequest{}
|
||
|
|
err = req.Decode(msg.PayloadStream)
|
||
|
|
require.NoError(t, err)
|
||
|
|
assert.Equal(t, uint64(12345), req.Nonce)
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
// Helper functions for testing

// encodeBigEndianUint32 returns the 4-byte big-endian encoding of val (most
// significant byte first), as used for the payload-length field when the
// tests build frames by hand.
func encodeBigEndianUint32(val uint32) []byte {
	out := make([]byte, 4)
	for i := range out {
		shift := uint(8 * (len(out) - 1 - i))
		out[i] = byte(val >> shift)
	}
	return out
}
|