hash-of-wisdom/internal/server/tcp.go

237 lines
6.2 KiB
Go
Raw Normal View History

2025-08-23 08:50:42 +03:00
package server
import (
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net"
	"sync"
	"time"

	"hash-of-wisdom/internal/application"
	"hash-of-wisdom/internal/lib/sl"
	"hash-of-wisdom/internal/metrics"
	"hash-of-wisdom/internal/protocol"
	"hash-of-wisdom/internal/service"
)
// TCPServer handles TCP connections for the Word of Wisdom protocol
type TCPServer struct {
config *Config
wisdomApplication *application.WisdomApplication
decoder *protocol.MessageDecoder
listener net.Listener
logger *slog.Logger
wg sync.WaitGroup
cancel context.CancelFunc
2025-08-23 08:50:42 +03:00
}
// option is a functional option for configuring a TCPServer; pass any number
// of them to NewTCPServer, which applies them in order.
type option func(*TCPServer)

// WithLogger makes the server log through logger instead of slog.Default().
// A nil logger is ignored and the current logger is kept: the server calls
// logger methods unconditionally, so installing nil would panic on first use.
func WithLogger(logger *slog.Logger) option {
	return func(s *TCPServer) {
		if logger != nil {
			s.logger = logger
		}
	}
}
2025-08-23 12:12:45 +03:00
// NewTCPServer creates a new TCP server with required configuration
func NewTCPServer(wisdomService *service.WisdomService, config *Config, opts ...option) *TCPServer {
2025-08-23 08:50:42 +03:00
server := &TCPServer{
2025-08-23 12:12:45 +03:00
config: config,
2025-08-23 08:50:42 +03:00
wisdomApplication: application.NewWisdomApplication(wisdomService),
decoder: protocol.NewMessageDecoder(),
logger: slog.Default(),
}
for _, opt := range opts {
opt(server)
}
return server
}
// Start starts the TCP server
func (s *TCPServer) Start(ctx context.Context) error {
listener, err := net.Listen("tcp", s.config.Address)
if err != nil {
return fmt.Errorf("failed to listen on %s: %w", s.config.Address, err)
}
s.listener = listener
s.logger.Info("tcp server started", "address", s.config.Address)
// Create cancellable context for server lifecycle
serverCtx, cancel := context.WithCancel(ctx)
s.cancel = cancel
go s.acceptLoop(serverCtx)
2025-08-23 08:50:42 +03:00
return nil
}
// Stop gracefully stops the server
func (s *TCPServer) Stop() error {
s.logger.Info("stopping tcp server")
// Cancel server context to stop accept loop and active connections
if s.cancel != nil {
s.cancel()
}
2025-08-23 08:50:42 +03:00
if s.listener != nil {
s.listener.Close()
}
s.wg.Wait()
s.logger.Info("tcp server stopped")
return nil
}
// Address reports where the server listens. Once Start has bound the
// listener, this is the listener's actual address (which matters when the
// configured port is 0). Before that, it falls back to the configured address.
func (s *TCPServer) Address() string {
	if s.listener == nil {
		return s.config.Address
	}
	return s.listener.Addr().String()
}
// acceptLoop accepts and handles incoming connections
func (s *TCPServer) acceptLoop(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
default:
}
rawConn, err := s.listener.Accept()
if err != nil {
select {
case <-ctx.Done():
2025-08-23 08:50:42 +03:00
return
default:
s.logger.Error("accept error", sl.Err(err))
continue
}
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
s.handleConnection(ctx, rawConn)
}()
}
}
// handleConnection handles a single client connection
func (s *TCPServer) handleConnection(ctx context.Context, rawConn net.Conn) {
defer rawConn.Close()
2025-08-23 13:25:04 +03:00
// Track active connections
metrics.ActiveConnections.Inc()
defer metrics.ActiveConnections.Dec()
2025-08-23 08:50:42 +03:00
connLogger := s.logger.With("remote_addr", rawConn.RemoteAddr().String())
connLogger.Info("connection accepted")
// Create connection-scoped context for overall timeout
connCtx, cancel := context.WithTimeout(ctx, s.config.Timeouts.Connection)
defer cancel()
if err := s.processConnection(connCtx, rawConn, connLogger); err != nil {
connLogger.Error("connection error", sl.Err(err))
} else {
connLogger.Debug("connection completed successfully")
}
}
// deadlineConn wraps a net.Conn and arms a fresh deadline right before every
// I/O call: Read gets now+rto and Write gets now+wto. Each call replaces any
// deadline previously set on the wrapped connection, so the timeouts apply
// per operation; idle time between calls is not counted.
type deadlineConn struct {
	net.Conn
	rto time.Duration // timeout for each Read
	wto time.Duration // timeout for each Write
}

// Read arms the read deadline, then reads from the wrapped connection.
// A failure to set the deadline is deliberately ignored and the read proceeds.
func (d *deadlineConn) Read(p []byte) (int, error) {
	deadline := time.Now().Add(d.rto)
	_ = d.Conn.SetReadDeadline(deadline)
	return d.Conn.Read(p)
}

// Write arms the write deadline, then writes to the wrapped connection.
// A failure to set the deadline is deliberately ignored and the write proceeds.
func (d *deadlineConn) Write(p []byte) (int, error) {
	deadline := time.Now().Add(d.wto)
	_ = d.Conn.SetWriteDeadline(deadline)
	return d.Conn.Write(p)
}
// processConnection handles the protocol message exchange
func (s *TCPServer) processConnection(ctx context.Context, conn net.Conn, logger *slog.Logger) error {
// Set overall connection deadline
globalDL := time.Now().Add(s.config.Timeouts.Connection)
if err := conn.SetDeadline(globalDL); err != nil {
return fmt.Errorf("failed to set connection deadline: %w", err)
}
// Create deadline wrapper for automatic timeout handling
dc := &deadlineConn{
Conn: conn,
rto: s.config.Timeouts.Read,
wto: s.config.Timeouts.Write,
}
// Use LimitReader to prevent reading more than the protocol-defined maximum message size
// This protects against malicious clients that lie about payload length in headers
limitedReader := io.LimitReader(dc, int64(protocol.MaxPayloadSize)+protocol.HeaderSize)
// Read incoming message through limited reader
logger.Debug("reading message from client")
msg, err := s.decoder.Decode(limitedReader)
if err != nil {
if err == io.EOF {
logger.Debug("client closed connection gracefully")
return nil
}
2025-08-23 13:25:04 +03:00
metrics.RequestErrors.WithLabelValues("decode_error").Inc()
2025-08-23 08:50:42 +03:00
logger.Error("failed to decode message", sl.Err(err))
return fmt.Errorf("decode error: %w", err)
}
logger.Debug("message decoded", "type", msg.Type, "payload_length", msg.PayloadLength)
2025-08-23 13:25:04 +03:00
// Track all requests
metrics.RequestsTotal.Inc()
// Process message through application layer with timing
start := time.Now()
2025-08-23 08:50:42 +03:00
response, err := s.wisdomApplication.HandleMessage(ctx, msg)
2025-08-23 13:25:04 +03:00
duration := time.Since(start)
metrics.RequestDuration.Observe(duration.Seconds())
2025-08-23 08:50:42 +03:00
if err != nil {
2025-08-23 13:25:04 +03:00
metrics.RequestErrors.WithLabelValues("internal_error").Inc()
2025-08-23 08:50:42 +03:00
logger.Error("application error", sl.Err(err))
return fmt.Errorf("application error: %w", err)
}
2025-08-23 13:25:04 +03:00
// Check if response is an error response
if errorResp, isError := response.(*protocol.ErrorResponse); isError {
metrics.RequestErrors.WithLabelValues(string(errorResp.Code)).Inc()
} else {
// Track quotes served for successful solution requests
if msg.Type == protocol.SolutionRequestType {
metrics.QuotesServed.Inc()
}
}
2025-08-23 08:50:42 +03:00
logger.Debug("sending response to client")
// Send response using the response's own Encode method
if err := response.Encode(dc); err != nil {
logger.Error("failed to encode response", sl.Err(err))
return fmt.Errorf("encode error: %w", err)
}
logger.Debug("response sent successfully")
return nil
}