hash-of-wisdom/internal/server/tcp.go

213 lines
5.4 KiB
Go

package server
import (
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net"
	"sync"
	"time"

	"hash-of-wisdom/internal/application"
	"hash-of-wisdom/internal/lib/sl"
	"hash-of-wisdom/internal/protocol"
	"hash-of-wisdom/internal/service"
)
// TCPServer handles TCP connections for the Word of Wisdom protocol.
// It listens on config.Address (see Start), decodes one message per
// connection, passes it to the application layer and writes back the
// response. Construct it with NewTCPServer; Start begins accepting and Stop
// shuts it down and waits for in-flight handlers.
type TCPServer struct {
// config supplies the listen address and the connection/read/write timeouts.
config *Config
// wisdomApplication turns a decoded request message into a response.
wisdomApplication *application.WisdomApplication
// decoder parses incoming messages. It is shared by every connection
// goroutine. NOTE(review): assumes MessageDecoder is safe for concurrent
// use — confirm.
decoder *protocol.MessageDecoder
// listener is set by Start; nil before the server has been started.
listener net.Listener
// logger is slog.Default() unless overridden with WithLogger.
logger *slog.Logger
// wg tracks background goroutines started by the server (at least the
// per-connection handlers); Stop waits on it.
wg sync.WaitGroup
// cancel ends the server context created by Start; nil before Start.
cancel context.CancelFunc
}
// Option is a functional option for configuring a TCPServer. Pass values such
// as WithLogger to NewTCPServer. The type is exported so callers outside this
// package can name it, e.g. to build a []Option slice.
type Option func(*TCPServer)

// option is an alias of Option. Signatures in this package that spell the
// type in lower case (such as NewTCPServer's variadic opts) therefore keep
// compiling and refer to exactly the same type.
type option = Option
// WithLogger makes the server log through logger instead of the default
// slog.Default(). A nil logger is ignored and the default is kept. Storing
// nil would make the first logging call panic on a nil *slog.Logger.
func WithLogger(logger *slog.Logger) option {
	return func(s *TCPServer) {
		if logger != nil {
			s.logger = logger
		}
	}
}
// NewTCPServer builds a TCPServer from config without opening any socket;
// listening starts only when Start is called. The wisdom service is wrapped in
// the application layer, a message decoder is allocated, and slog.Default() is
// installed as the logger. The opts are then applied in the order given, so
// they can override any of these defaults.
func NewTCPServer(wisdomService *service.WisdomService, config *Config, opts ...option) *TCPServer {
	s := new(TCPServer)
	s.config = config
	s.wisdomApplication = application.NewWisdomApplication(wisdomService)
	s.decoder = protocol.NewMessageDecoder()
	s.logger = slog.Default()

	for i := range opts {
		opts[i](s)
	}
	return s
}
// Start opens the TCP listener on config.Address and accepts connections in a
// background goroutine. It returns as soon as the listener is open, or with an
// error if listening fails.
//
// Accepting continues until Stop is called or ctx is cancelled. Either event
// ends the server context, which closes the listener. That makes a blocked
// Accept return immediately instead of waiting for the next client. The accept
// goroutine is registered in s.wg, so:
//   - Stop waits for it, and
//   - its own s.wg.Add calls for new connections always happen while the
//     counter is positive, as sync.WaitGroup requires when Wait may run
//     concurrently.
func (s *TCPServer) Start(ctx context.Context) error {
	listener, err := net.Listen("tcp", s.config.Address)
	if err != nil {
		return fmt.Errorf("failed to listen on %s: %w", s.config.Address, err)
	}
	s.listener = listener
	// Log the bound address; it differs from the configured one for ":0".
	s.logger.Info("tcp server started", "address", listener.Addr().String())

	serverCtx, cancel := context.WithCancel(ctx)
	s.cancel = cancel

	// Unblock Accept once the server context ends (Stop or parent cancel).
	// If Stop already closed the listener this returns net.ErrClosed, which
	// is expected and safe to drop.
	context.AfterFunc(serverCtx, func() {
		_ = listener.Close()
	})

	s.wg.Add(1)
	go func() {
		defer s.wg.Done()
		s.acceptLoop(serverCtx)
	}()
	return nil
}
// Stop gracefully stops the server
func (s *TCPServer) Stop() error {
s.logger.Info("stopping tcp server")
// Cancel server context to stop accept loop and active connections
if s.cancel != nil {
s.cancel()
}
if s.listener != nil {
s.listener.Close()
}
s.wg.Wait()
s.logger.Info("tcp server stopped")
return nil
}
// Address reports where the server can be reached. After Start has opened the
// listener, this is the address the socket is actually bound to (e.g. the
// concrete port when config.Address asked for port 0). Before Start, it falls
// back to the configured address string.
func (s *TCPServer) Address() string {
	if s.listener == nil {
		return s.config.Address
	}
	return s.listener.Addr().String()
}
// acceptLoop accepts and handles incoming connections
func (s *TCPServer) acceptLoop(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
default:
}
rawConn, err := s.listener.Accept()
if err != nil {
select {
case <-ctx.Done():
return
default:
s.logger.Error("accept error", sl.Err(err))
continue
}
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.handleConnection(ctx, rawConn)
}()
}
}
// handleConnection serves a single accepted client connection and always
// closes it before returning.
//
// The exchange runs under connCtx, which is derived from the server context
// and bounded by Timeouts.Connection. When connCtx ends, the socket is closed,
// so any read or write blocked in processConnection returns at once instead of
// waiting out its own I/O timeout. connCtx ends either on the overall timeout
// or on server shutdown, i.e. cancellation of ctx.
//
// Errors that occur because the server is shutting down are logged at info
// level; all other errors are logged once, at error level.
func (s *TCPServer) handleConnection(ctx context.Context, rawConn net.Conn) {
	defer rawConn.Close()

	connLogger := s.logger.With("remote_addr", rawConn.RemoteAddr().String())
	connLogger.Info("connection accepted")

	// Create connection-scoped context for overall timeout.
	connCtx, cancel := context.WithTimeout(ctx, s.config.Timeouts.Connection)
	defer cancel()

	// Force-close on cancellation/timeout. The deferred stop runs before
	// cancel(), so a normal return closes the conn only via the Close above.
	// Closing twice merely returns an ignored error.
	stopCloser := context.AfterFunc(connCtx, func() { _ = rawConn.Close() })
	defer stopCloser()

	err := s.processConnection(connCtx, rawConn, connLogger)
	switch {
	case err == nil:
		connLogger.Debug("connection completed successfully")
	case ctx.Err() != nil:
		connLogger.Info("connection aborted by server shutdown", sl.Err(err))
	default:
		connLogger.Error("connection error", sl.Err(err))
	}
}
// deadlineConn wraps a net.Conn and arms a fresh deadline before every Read
// and Write, so a peer that stalls for longer than rto/wto is cut off.
//
// Every per-operation deadline is also capped at limit, when limit is
// non-zero. Without that cap, each Read would push the deadline past the
// connection's overall budget. A client that trickles one byte just under rto
// could then keep the connection open far longer than intended.
type deadlineConn struct {
	net.Conn
	rto, wto time.Duration
	// limit is an absolute deadline no operation may exceed; zero means no cap.
	limit time.Time
}

// deadline returns the deadline for an operation that starts now and may idle
// for timeout. A non-positive timeout disables the per-operation limit, leaving
// only d.limit, or no deadline at all when d.limit is zero. Without this
// guard, time.Now()+0 would make every operation fail immediately.
func (d *deadlineConn) deadline(timeout time.Duration) time.Time {
	var dl time.Time
	if timeout > 0 {
		dl = time.Now().Add(timeout)
	}
	if !d.limit.IsZero() && (dl.IsZero() || d.limit.Before(dl)) {
		dl = d.limit
	}
	return dl
}

// Read arms the read deadline, then reads from the wrapped connection.
// Arming is best effort: a conn that rejects deadlines still reads, just
// without the per-operation timeout.
func (d *deadlineConn) Read(p []byte) (int, error) {
	_ = d.SetReadDeadline(d.deadline(d.rto))
	return d.Conn.Read(p)
}

// Write arms the write deadline, then writes to the wrapped connection.
// Arming is best effort, as for Read.
func (d *deadlineConn) Write(p []byte) (int, error) {
	_ = d.SetWriteDeadline(d.deadline(d.wto))
	return d.Conn.Write(p)
}

// processConnection performs one request/response exchange on conn: it
// decodes a single message, hands it to the application layer, and encodes the
// reply back to the client.
//
// Deadlines:
//   - The overall deadline is ctx's deadline when it has one; otherwise it is
//     now + Timeouts.Connection.
//   - Individual reads and writes are additionally bounded by Timeouts.Read and
//     Timeouts.Write, but never beyond the overall deadline.
//
// Return value:
//   - A client that disconnects before sending anything (a bare io.EOF from the
//     decoder) is not an error.
//   - All other failures are returned wrapped and are not logged here, so the
//     caller can log each one exactly once.
func (s *TCPServer) processConnection(ctx context.Context, conn net.Conn, logger *slog.Logger) error {
	globalDL, ok := ctx.Deadline()
	if !ok {
		globalDL = time.Now().Add(s.config.Timeouts.Connection)
	}
	if err := conn.SetDeadline(globalDL); err != nil {
		return fmt.Errorf("failed to set connection deadline: %w", err)
	}

	dc := &deadlineConn{
		Conn:  conn,
		rto:   s.config.Timeouts.Read,
		wto:   s.config.Timeouts.Write,
		limit: globalDL,
	}

	// Never read more than one maximal message, even if a malicious client
	// lies about the payload length in its header.
	limitedReader := io.LimitReader(dc, int64(protocol.MaxPayloadSize)+protocol.HeaderSize)

	logger.Debug("reading message from client")
	msg, err := s.decoder.Decode(limitedReader)
	if err != nil {
		// Exact comparison on purpose: an error that merely wraps io.EOF may
		// come from a truncated message rather than a clean close.
		// NOTE(review): confirm how MessageDecoder reports partial reads.
		if err == io.EOF {
			logger.Debug("client closed connection gracefully")
			return nil
		}
		return fmt.Errorf("decode error: %w", err)
	}
	logger.Debug("message decoded", "type", msg.Type, "payload_length", msg.PayloadLength)

	response, err := s.wisdomApplication.HandleMessage(ctx, msg)
	if err != nil {
		return fmt.Errorf("application error: %w", err)
	}

	logger.Debug("sending response to client")
	// Writes go through dc so each one is bounded by the write timeout.
	if err := response.Encode(dc); err != nil {
		return fmt.Errorf("encode error: %w", err)
	}
	logger.Debug("response sent successfully")
	return nil
}