diff --git a/internal/server/tcp.go b/internal/server/tcp.go new file mode 100644 index 0000000..6d1f4c8 --- /dev/null +++ b/internal/server/tcp.go @@ -0,0 +1,213 @@ +package server + +import ( + "context" + "fmt" + "io" + "log/slog" + "net" + "sync" + "time" + + "hash-of-wisdom/internal/application" + "hash-of-wisdom/internal/lib/sl" + "hash-of-wisdom/internal/protocol" + "hash-of-wisdom/internal/service" +) + +// TCPServer handles TCP connections for the Word of Wisdom protocol +type TCPServer struct { + config *Config + wisdomApplication *application.WisdomApplication + decoder *protocol.MessageDecoder + listener net.Listener + logger *slog.Logger + wg sync.WaitGroup + shutdown chan struct{} +} + +// Option is a functional option for configuring TCPServer +type option func(*TCPServer) + +// WithConfig sets a custom configuration +func WithConfig(config *Config) option { + return func(s *TCPServer) { + s.config = config + } +} + +// WithLogger sets a custom logger +func WithLogger(logger *slog.Logger) option { + return func(s *TCPServer) { + s.logger = logger + } +} + +// NewTCPServer creates a new TCP server with optional configuration +func NewTCPServer(wisdomService *service.WisdomService, opts ...option) *TCPServer { + server := &TCPServer{ + config: DefaultConfig(), + wisdomApplication: application.NewWisdomApplication(wisdomService), + decoder: protocol.NewMessageDecoder(), + logger: slog.Default(), + shutdown: make(chan struct{}), + } + + for _, opt := range opts { + opt(server) + } + + return server +} + + +// Start starts the TCP server +func (s *TCPServer) Start(ctx context.Context) error { + listener, err := net.Listen("tcp", s.config.Address) + if err != nil { + return fmt.Errorf("failed to listen on %s: %w", s.config.Address, err) + } + + s.listener = listener + s.logger.Info("tcp server started", "address", s.config.Address) + + go s.acceptLoop(ctx) + return nil +} + +// Stop gracefully stops the server +func (s *TCPServer) Stop() error { + s.logger.Info("stopping tcp server") + close(s.shutdown) + + if s.listener != nil { + s.listener.Close() + } + + s.wg.Wait() + s.logger.Info("tcp server stopped") + return nil +} + +// Address returns the server's listening address +func (s *TCPServer) Address() string { + if s.listener != nil { + return s.listener.Addr().String() + } + return s.config.Address +} + +// acceptLoop accepts and handles incoming connections +func (s *TCPServer) acceptLoop(ctx context.Context) { + for { + select { + case <-s.shutdown: + return + case <-ctx.Done(): + return + default: + } + + rawConn, err := s.listener.Accept() + if err != nil { + select { + case <-s.shutdown: + return + default: + s.logger.Error("accept error", sl.Err(err)) + continue + } + } + + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.handleConnection(ctx, rawConn) + }() + } +} + +// handleConnection handles a single client connection +func (s *TCPServer) handleConnection(ctx context.Context, rawConn net.Conn) { + defer rawConn.Close() + + connLogger := s.logger.With("remote_addr", rawConn.RemoteAddr().String()) + connLogger.Info("connection accepted") + + // Create connection-scoped context for overall timeout + connCtx, cancel := context.WithTimeout(ctx, s.config.Timeouts.Connection) + defer cancel() + + if err := s.processConnection(connCtx, rawConn, connLogger); err != nil { + connLogger.Error("connection error", sl.Err(err)) + } else { + connLogger.Debug("connection completed successfully") + } +} + +// deadlineConn wraps a connection to automatically set deadlines on each read/write +type deadlineConn struct { + net.Conn + rto, wto time.Duration +} + +func (d *deadlineConn) Read(p []byte) (int, error) { + _ = d.SetReadDeadline(time.Now().Add(d.rto)) + return d.Conn.Read(p) +} + +func (d *deadlineConn) Write(p []byte) (int, error) { + _ = d.SetWriteDeadline(time.Now().Add(d.wto)) + return d.Conn.Write(p) +} + +// processConnection handles the protocol message exchange +func (s *TCPServer) processConnection(ctx context.Context, conn net.Conn, logger *slog.Logger) error { + // Set overall connection deadline + globalDL := time.Now().Add(s.config.Timeouts.Connection) + if err := conn.SetDeadline(globalDL); err != nil { + return fmt.Errorf("failed to set connection deadline: %w", err) + } + + // Create deadline wrapper for automatic timeout handling + dc := &deadlineConn{ + Conn: conn, + rto: s.config.Timeouts.Read, + wto: s.config.Timeouts.Write, + } + + // Use LimitReader to prevent reading more than the protocol-defined maximum message size + // This protects against malicious clients that lie about payload length in headers + limitedReader := io.LimitReader(dc, int64(protocol.MaxPayloadSize)+protocol.HeaderSize) + + // Read incoming message through limited reader + logger.Debug("reading message from client") + msg, err := s.decoder.Decode(limitedReader) + if err != nil { + if err == io.EOF { + logger.Debug("client closed connection gracefully") + return nil + } + logger.Error("failed to decode message", sl.Err(err)) + return fmt.Errorf("decode error: %w", err) + } + + logger.Debug("message decoded", "type", msg.Type, "payload_length", msg.PayloadLength) + + // Process message through application layer + response, err := s.wisdomApplication.HandleMessage(ctx, msg) + if err != nil { + logger.Error("application error", sl.Err(err)) + return fmt.Errorf("application error: %w", err) + } + + logger.Debug("sending response to client") + // Send response using the response's own Encode method + if err := response.Encode(dc); err != nil { + logger.Error("failed to encode response", sl.Err(err)) + return fmt.Errorf("encode error: %w", err) + } + + logger.Debug("response sent successfully") + return nil +}