2025-08-23 08:50:42 +03:00
|
|
|
package server
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"context"
|
|
|
|
|
"fmt"
|
|
|
|
|
"io"
|
|
|
|
|
"log/slog"
|
|
|
|
|
"net"
|
|
|
|
|
"sync"
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
"hash-of-wisdom/internal/application"
|
|
|
|
|
"hash-of-wisdom/internal/lib/sl"
|
|
|
|
|
"hash-of-wisdom/internal/protocol"
|
|
|
|
|
"hash-of-wisdom/internal/service"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// TCPServer handles TCP connections for the Word of Wisdom protocol
|
|
|
|
|
type TCPServer struct {
|
|
|
|
|
config *Config
|
|
|
|
|
wisdomApplication *application.WisdomApplication
|
|
|
|
|
decoder *protocol.MessageDecoder
|
|
|
|
|
listener net.Listener
|
|
|
|
|
logger *slog.Logger
|
|
|
|
|
wg sync.WaitGroup
|
2025-08-23 13:18:22 +03:00
|
|
|
cancel context.CancelFunc
|
2025-08-23 08:50:42 +03:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Option is a functional option for configuring TCPServer
|
|
|
|
|
type option func(*TCPServer)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// WithLogger sets a custom logger
|
|
|
|
|
func WithLogger(logger *slog.Logger) option {
|
|
|
|
|
return func(s *TCPServer) {
|
|
|
|
|
s.logger = logger
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2025-08-23 12:12:45 +03:00
|
|
|
// NewTCPServer creates a new TCP server with required configuration
|
|
|
|
|
func NewTCPServer(wisdomService *service.WisdomService, config *Config, opts ...option) *TCPServer {
|
2025-08-23 08:50:42 +03:00
|
|
|
server := &TCPServer{
|
2025-08-23 12:12:45 +03:00
|
|
|
config: config,
|
2025-08-23 08:50:42 +03:00
|
|
|
wisdomApplication: application.NewWisdomApplication(wisdomService),
|
|
|
|
|
decoder: protocol.NewMessageDecoder(),
|
|
|
|
|
logger: slog.Default(),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for _, opt := range opts {
|
|
|
|
|
opt(server)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return server
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Start starts the TCP server
|
|
|
|
|
func (s *TCPServer) Start(ctx context.Context) error {
|
|
|
|
|
listener, err := net.Listen("tcp", s.config.Address)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return fmt.Errorf("failed to listen on %s: %w", s.config.Address, err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
s.listener = listener
|
|
|
|
|
s.logger.Info("tcp server started", "address", s.config.Address)
|
|
|
|
|
|
2025-08-23 13:18:22 +03:00
|
|
|
// Create cancellable context for server lifecycle
|
|
|
|
|
serverCtx, cancel := context.WithCancel(ctx)
|
|
|
|
|
s.cancel = cancel
|
|
|
|
|
|
|
|
|
|
go s.acceptLoop(serverCtx)
|
2025-08-23 08:50:42 +03:00
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Stop gracefully stops the server
|
|
|
|
|
func (s *TCPServer) Stop() error {
|
|
|
|
|
s.logger.Info("stopping tcp server")
|
2025-08-23 13:18:22 +03:00
|
|
|
|
|
|
|
|
// Cancel server context to stop accept loop and active connections
|
|
|
|
|
if s.cancel != nil {
|
|
|
|
|
s.cancel()
|
|
|
|
|
}
|
2025-08-23 08:50:42 +03:00
|
|
|
|
|
|
|
|
if s.listener != nil {
|
|
|
|
|
s.listener.Close()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
s.wg.Wait()
|
|
|
|
|
s.logger.Info("tcp server stopped")
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Address returns the server's listening address
|
|
|
|
|
func (s *TCPServer) Address() string {
|
|
|
|
|
if s.listener != nil {
|
|
|
|
|
return s.listener.Addr().String()
|
|
|
|
|
}
|
|
|
|
|
return s.config.Address
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// acceptLoop accepts and handles incoming connections
|
|
|
|
|
func (s *TCPServer) acceptLoop(ctx context.Context) {
|
|
|
|
|
for {
|
|
|
|
|
select {
|
|
|
|
|
case <-ctx.Done():
|
|
|
|
|
return
|
|
|
|
|
default:
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
rawConn, err := s.listener.Accept()
|
|
|
|
|
if err != nil {
|
|
|
|
|
select {
|
2025-08-23 13:18:22 +03:00
|
|
|
case <-ctx.Done():
|
2025-08-23 08:50:42 +03:00
|
|
|
return
|
|
|
|
|
default:
|
|
|
|
|
s.logger.Error("accept error", sl.Err(err))
|
|
|
|
|
continue
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
s.wg.Add(1)
|
|
|
|
|
go func() {
|
|
|
|
|
defer s.wg.Done()
|
|
|
|
|
s.handleConnection(ctx, rawConn)
|
|
|
|
|
}()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// handleConnection handles a single client connection
|
|
|
|
|
func (s *TCPServer) handleConnection(ctx context.Context, rawConn net.Conn) {
|
|
|
|
|
defer rawConn.Close()
|
|
|
|
|
|
|
|
|
|
connLogger := s.logger.With("remote_addr", rawConn.RemoteAddr().String())
|
|
|
|
|
connLogger.Info("connection accepted")
|
|
|
|
|
|
|
|
|
|
// Create connection-scoped context for overall timeout
|
|
|
|
|
connCtx, cancel := context.WithTimeout(ctx, s.config.Timeouts.Connection)
|
|
|
|
|
defer cancel()
|
|
|
|
|
|
|
|
|
|
if err := s.processConnection(connCtx, rawConn, connLogger); err != nil {
|
|
|
|
|
connLogger.Error("connection error", sl.Err(err))
|
|
|
|
|
} else {
|
|
|
|
|
connLogger.Debug("connection completed successfully")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// deadlineConn decorates a net.Conn so that every Read and Write first
// re-arms the matching socket deadline. Each individual I/O call therefore
// gets a fresh budget (rto for reads, wto for writes), measured from the
// moment the call starts. Re-arming replaces whatever read/write deadline
// was previously set on the underlying connection.
type deadlineConn struct {
	net.Conn

	rto time.Duration // per-Read time budget
	wto time.Duration // per-Write time budget
}

// Read sets the read deadline to now+rto, then reads from the wrapped
// connection. A failure to set the deadline is deliberately ignored; any
// real connection problem surfaces from the Read itself.
func (c *deadlineConn) Read(buf []byte) (int, error) {
	readBy := time.Now().Add(c.rto)
	_ = c.Conn.SetReadDeadline(readBy)
	return c.Conn.Read(buf)
}

// Write sets the write deadline to now+wto, then writes to the wrapped
// connection. A failure to set the deadline is deliberately ignored; any
// real connection problem surfaces from the Write itself.
func (c *deadlineConn) Write(buf []byte) (int, error) {
	writeBy := time.Now().Add(c.wto)
	_ = c.Conn.SetWriteDeadline(writeBy)
	return c.Conn.Write(buf)
}
|
|
|
|
|
|
|
|
|
|
// processConnection handles the protocol message exchange
|
|
|
|
|
func (s *TCPServer) processConnection(ctx context.Context, conn net.Conn, logger *slog.Logger) error {
|
|
|
|
|
// Set overall connection deadline
|
|
|
|
|
globalDL := time.Now().Add(s.config.Timeouts.Connection)
|
|
|
|
|
if err := conn.SetDeadline(globalDL); err != nil {
|
|
|
|
|
return fmt.Errorf("failed to set connection deadline: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Create deadline wrapper for automatic timeout handling
|
|
|
|
|
dc := &deadlineConn{
|
|
|
|
|
Conn: conn,
|
|
|
|
|
rto: s.config.Timeouts.Read,
|
|
|
|
|
wto: s.config.Timeouts.Write,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Use LimitReader to prevent reading more than the protocol-defined maximum message size
|
|
|
|
|
// This protects against malicious clients that lie about payload length in headers
|
|
|
|
|
limitedReader := io.LimitReader(dc, int64(protocol.MaxPayloadSize)+protocol.HeaderSize)
|
|
|
|
|
|
|
|
|
|
// Read incoming message through limited reader
|
|
|
|
|
logger.Debug("reading message from client")
|
|
|
|
|
msg, err := s.decoder.Decode(limitedReader)
|
|
|
|
|
if err != nil {
|
|
|
|
|
if err == io.EOF {
|
|
|
|
|
logger.Debug("client closed connection gracefully")
|
|
|
|
|
return nil
|
|
|
|
|
}
|
|
|
|
|
logger.Error("failed to decode message", sl.Err(err))
|
|
|
|
|
return fmt.Errorf("decode error: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
logger.Debug("message decoded", "type", msg.Type, "payload_length", msg.PayloadLength)
|
|
|
|
|
|
|
|
|
|
// Process message through application layer
|
|
|
|
|
response, err := s.wisdomApplication.HandleMessage(ctx, msg)
|
|
|
|
|
if err != nil {
|
|
|
|
|
logger.Error("application error", sl.Err(err))
|
|
|
|
|
return fmt.Errorf("application error: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
logger.Debug("sending response to client")
|
|
|
|
|
// Send response using the response's own Encode method
|
|
|
|
|
if err := response.Encode(dc); err != nil {
|
|
|
|
|
logger.Error("failed to encode response", sl.Err(err))
|
|
|
|
|
return fmt.Errorf("encode error: %w", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
logger.Debug("response sent successfully")
|
|
|
|
|
return nil
|
|
|
|
|
}
|