package server import ( "context" "fmt" "io" "log/slog" "net" "sync" "time" "hash-of-wisdom/internal/application" "hash-of-wisdom/internal/lib/sl" "hash-of-wisdom/internal/protocol" "hash-of-wisdom/internal/service" ) // TCPServer handles TCP connections for the Word of Wisdom protocol type TCPServer struct { config *Config wisdomApplication *application.WisdomApplication decoder *protocol.MessageDecoder listener net.Listener logger *slog.Logger wg sync.WaitGroup shutdown chan struct{} } // Option is a functional option for configuring TCPServer type option func(*TCPServer) // WithConfig sets a custom configuration func WithConfig(config *Config) option { return func(s *TCPServer) { s.config = config } } // WithLogger sets a custom logger func WithLogger(logger *slog.Logger) option { return func(s *TCPServer) { s.logger = logger } } // NewTCPServer creates a new TCP server with optional configuration func NewTCPServer(wisdomService *service.WisdomService, opts ...option) *TCPServer { server := &TCPServer{ config: DefaultConfig(), wisdomApplication: application.NewWisdomApplication(wisdomService), decoder: protocol.NewMessageDecoder(), logger: slog.Default(), shutdown: make(chan struct{}), } for _, opt := range opts { opt(server) } return server } // Start starts the TCP server func (s *TCPServer) Start(ctx context.Context) error { listener, err := net.Listen("tcp", s.config.Address) if err != nil { return fmt.Errorf("failed to listen on %s: %w", s.config.Address, err) } s.listener = listener s.logger.Info("tcp server started", "address", s.config.Address) go s.acceptLoop(ctx) return nil } // Stop gracefully stops the server func (s *TCPServer) Stop() error { s.logger.Info("stopping tcp server") close(s.shutdown) if s.listener != nil { s.listener.Close() } s.wg.Wait() s.logger.Info("tcp server stopped") return nil } // Address returns the server's listening address func (s *TCPServer) Address() string { if s.listener != nil { return s.listener.Addr().String() } return s.config.Address } // acceptLoop accepts and handles incoming connections func (s *TCPServer) acceptLoop(ctx context.Context) { for { select { case <-s.shutdown: return case <-ctx.Done(): return default: } rawConn, err := s.listener.Accept() if err != nil { select { case <-s.shutdown: return default: s.logger.Error("accept error", sl.Err(err)) continue } } s.wg.Add(1) go func() { defer s.wg.Done() s.handleConnection(ctx, rawConn) }() } } // handleConnection handles a single client connection func (s *TCPServer) handleConnection(ctx context.Context, rawConn net.Conn) { defer rawConn.Close() connLogger := s.logger.With("remote_addr", rawConn.RemoteAddr().String()) connLogger.Info("connection accepted") // Create connection-scoped context for overall timeout connCtx, cancel := context.WithTimeout(ctx, s.config.Timeouts.Connection) defer cancel() if err := s.processConnection(connCtx, rawConn, connLogger); err != nil { connLogger.Error("connection error", sl.Err(err)) } else { connLogger.Debug("connection completed successfully") } } // deadlineConn wraps a connection to automatically set deadlines on each read/write type deadlineConn struct { net.Conn rto, wto time.Duration } func (d *deadlineConn) Read(p []byte) (int, error) { _ = d.SetReadDeadline(time.Now().Add(d.rto)) return d.Conn.Read(p) } func (d *deadlineConn) Write(p []byte) (int, error) { _ = d.SetWriteDeadline(time.Now().Add(d.wto)) return d.Conn.Write(p) } // processConnection handles the protocol message exchange func (s *TCPServer) processConnection(ctx context.Context, conn net.Conn, logger *slog.Logger) error { // Set overall connection deadline globalDL := time.Now().Add(s.config.Timeouts.Connection) if err := conn.SetDeadline(globalDL); err != nil { return fmt.Errorf("failed to set connection deadline: %w", err) } // Create deadline wrapper for automatic timeout handling dc := &deadlineConn{ Conn: conn, rto: s.config.Timeouts.Read, wto: s.config.Timeouts.Write, } // Use LimitReader to prevent reading more than the protocol-defined maximum message size // This protects against malicious clients that lie about payload length in headers limitedReader := io.LimitReader(dc, int64(protocol.MaxPayloadSize)+protocol.HeaderSize) // Read incoming message through limited reader logger.Debug("reading message from client") msg, err := s.decoder.Decode(limitedReader) if err != nil { if err == io.EOF { logger.Debug("client closed connection gracefully") return nil } logger.Error("failed to decode message", sl.Err(err)) return fmt.Errorf("decode error: %w", err) } logger.Debug("message decoded", "type", msg.Type, "payload_length", msg.PayloadLength) // Process message through application layer response, err := s.wisdomApplication.HandleMessage(ctx, msg) if err != nil { logger.Error("application error", sl.Err(err)) return fmt.Errorf("application error: %w", err) } logger.Debug("sending response to client") // Send response using the response's own Encode method if err := response.Encode(dc); err != nil { logger.Error("failed to encode response", sl.Err(err)) return fmt.Errorf("encode error: %w", err) } logger.Debug("response sent successfully") return nil }