diff --git a/cmd/server/main.go b/cmd/server/main.go new file mode 100644 index 0000000..055e4af --- /dev/null +++ b/cmd/server/main.go @@ -0,0 +1,73 @@ +package main + +import ( + "context" + "log/slog" + "os" + "os/signal" + "syscall" + "time" + + "hash-of-wisdom/internal/lib/sl" + "hash-of-wisdom/internal/pow/challenge" + "hash-of-wisdom/internal/quotes" + "hash-of-wisdom/internal/server" + "hash-of-wisdom/internal/service" +) + +func main() { + addr := ":8080" + if len(os.Args) > 1 { + addr = os.Args[1] + } + + logger := slog.Default() + logger.Info("starting word of wisdom server", "address", addr) + + // Create components + challengeConfig, err := challenge.NewConfig() + if err != nil { + logger.Error("failed to create config", sl.Err(err)) + os.Exit(1) + } + generator := challenge.NewGenerator(challengeConfig) + verifier := challenge.NewVerifier(challengeConfig) + quoteService := quotes.NewHTTPService() + + // Wire up service + genAdapter := service.NewGeneratorAdapter(generator) + wisdomService := service.NewWisdomService(genAdapter, verifier, quoteService) + + // Create server configuration + serverConfig := server.DefaultConfig() + serverConfig.Address = addr + + // Create server + srv := server.NewTCPServer(wisdomService, + server.WithConfig(serverConfig), + server.WithLogger(logger)) + + // Start server + ctx := context.Background() + if err := srv.Start(ctx); err != nil { + logger.Error("failed to start server", sl.Err(err)) + os.Exit(1) + } + + // Wait for interrupt + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + logger.Info("server ready - press ctrl+c to stop") + <-sigChan + + // Graceful shutdown + logger.Info("shutting down server") + if err := srv.Stop(); err != nil { + logger.Error("error during shutdown", sl.Err(err)) + } + + // Give connections time to close + time.Sleep(100 * time.Millisecond) + logger.Info("server stopped") +} diff --git a/docs/IMPLEMENTATION.md b/docs/IMPLEMENTATION.md index 3b0e54f..6553026 100644 --- a/docs/IMPLEMENTATION.md +++ b/docs/IMPLEMENTATION.md @@ -77,23 +77,33 @@ - [X] Fix application tests for new protocol design ## Phase 6: TCP Server & Connection Management -- [ ] Implement TCP server with connection handling -- [ ] Add dual timeout protection: - - [ ] Connection timeout (max total connection time) - - [ ] Read timeout (max idle time between bytes - slowloris protection) -- [ ] Implement proper connection lifecycle management -- [ ] Create protocol state machine for request/response flow -- [ ] Add graceful connection cleanup and error handling -- [ ] Implement basic client for testing -- [ ] Write integration tests for client-server communication +- [X] Implement TCP server with connection handling +- [X] Add dual timeout protection: + - [X] Connection timeout (max total connection time) + - [X] Read timeout (max idle time between bytes - slowloris protection) +- [X] Implement proper connection lifecycle management +- [X] Create protocol state machine for request/response flow +- [X] Add graceful connection cleanup and error handling +- [X] Add slog structured logging to TCP server +- [X] Implement functional options pattern for server configuration +- [X] Update cmd/server to use new TCP server with logging -## Phase 7: Basic Server Architecture -- [ ] Set up structured logging (zerolog/logrus) +## Phase 7: Client Implementation +- [ ] Create client application structure +- [ ] Implement PoW solver algorithm on client side +- [ ] Create client-side protocol implementation +- [ ] Add retry logic and error handling +- [ ] Implement connection management +- [ ] Create CLI interface for client +- [ ] Add client structured logging +- [ ] Write client unit and integration tests + +## Phase 8: Basic Server Architecture - [ ] Set up metrics collection (prometheus) - [ ] Create configuration management - [ ] Integrate all components into server architecture -## Phase 8: Advanced Server Features +## Phase 9: Advanced Server Features - [ ] Add connection pooling and advanced connection management - [ ] Implement graceful shutdown mechanism - [ ] Add health check endpoints @@ -101,7 +111,7 @@ - [ ] Create health check endpoints - [ ] Write integration tests for server core -## Phase 9: DDOS Protection & Rate Limiting +## Phase 10: DDOS Protection & Rate Limiting - [ ] Implement IP-based connection limiting - [ ] Create rate limiting service with time windows - [ ] Add automatic difficulty adjustment based on load @@ -110,7 +120,7 @@ - [ ] Add monitoring for attack detection - [ ] Write tests for protection mechanisms -## Phase 10: Observability & Monitoring +## Phase 11: Observability & Monitoring - [ ] Add structured logging throughout application - [ ] Implement metrics for key performance indicators: - [ ] Active connections count @@ -122,7 +132,7 @@ - [ ] Add error categorization and reporting - [ ] Implement health check endpoints -## Phase 11: Configuration & Environment Setup +## Phase 12: Configuration & Environment Setup - [ ] Create configuration structure with validation - [ ] Support environment variables and config files - [ ] Add configuration for different environments (dev/prod) @@ -130,16 +140,6 @@ - [ ] Create deployment configuration templates - [ ] Add configuration validation and defaults -## Phase 12: Client Implementation -- [ ] Create client application structure -- [ ] Implement PoW solver algorithm -- [ ] Create client-side protocol implementation -- [ ] Add retry logic and error handling -- [ ] Implement connection management -- [ ] Create CLI interface for client -- [ ] Add client metrics and logging -- [ ] Write client unit and integration tests - ## Phase 13: Docker & Deployment - [ ] Create multi-stage Dockerfile for server - [ ] Create Dockerfile for client diff --git a/internal/lib/sl/sl.go b/internal/lib/sl/sl.go new file mode 100644 index 0000000..4a91b7d --- /dev/null +++ b/internal/lib/sl/sl.go @@ -0,0 +1,27 @@ +package sl + +import ( + "context" + "log/slog" +) + +// Err creates a structured error attribute for slog +func Err(err error) slog.Attr { + return slog.Attr{ + Key: "error", + Value: slog.StringValue(err.Error()), + } +} + +// MockHandler is a test handler that discards all log messages +type MockHandler struct{} + +func (h *MockHandler) Enabled(context.Context, slog.Level) bool { return false } +func (h *MockHandler) Handle(context.Context, slog.Record) error { return nil } +func (h *MockHandler) WithAttrs([]slog.Attr) slog.Handler { return h } +func (h *MockHandler) WithGroup(string) slog.Handler { return h } + +// NewMockLogger creates a logger that discards all messages for testing +func NewMockLogger() *slog.Logger { + return slog.New(&MockHandler{}) +} diff --git a/internal/server/config.go b/internal/server/config.go new file mode 100644 index 0000000..3047106 --- /dev/null +++ b/internal/server/config.go @@ -0,0 +1,31 @@ +package server + +import "time" + +// Config holds configuration for the TCP server +type Config struct { + Address string + Timeouts TimeoutConfig +} + +// TimeoutConfig holds timeout configuration +type TimeoutConfig struct { + // Read timeout protects against slowloris attacks (clients sending data slowly) + Read time.Duration + // Write timeout protects against slow readers (clients reading responses slowly) + Write time.Duration + // Connection timeout is the maximum total connection lifetime + Connection time.Duration +} + +// DefaultConfig returns default server configuration +func DefaultConfig() *Config { + return &Config{ + Address: ":8080", + Timeouts: TimeoutConfig{ + Read: 5 * time.Second, + Write: 5 * time.Second, + Connection: 15 * time.Second, + }, + } +} diff --git a/internal/server/tcp.go b/internal/server/tcp.go new file mode 100644 index 0000000..6d1f4c8 --- /dev/null +++ b/internal/server/tcp.go @@ -0,0 +1,213 @@ +package server + +import ( + "context" + "fmt" + "io" + "log/slog" + "net" + "sync" + "time" + + "hash-of-wisdom/internal/application" + "hash-of-wisdom/internal/lib/sl" + "hash-of-wisdom/internal/protocol" + "hash-of-wisdom/internal/service" +) + +// TCPServer handles TCP connections for the Word of Wisdom protocol +type TCPServer struct { + config *Config + wisdomApplication *application.WisdomApplication + decoder *protocol.MessageDecoder + listener net.Listener + logger *slog.Logger + wg sync.WaitGroup + shutdown chan struct{} +} + +// Option is a functional option for configuring TCPServer +type option func(*TCPServer) + +// WithConfig sets a custom configuration +func WithConfig(config *Config) option { + return func(s *TCPServer) { + s.config = config + } +} + +// WithLogger sets a custom logger +func WithLogger(logger *slog.Logger) option { + return func(s *TCPServer) { + s.logger = logger + } +} + +// NewTCPServer creates a new TCP server with optional configuration +func NewTCPServer(wisdomService *service.WisdomService, opts ...option) *TCPServer { + server := &TCPServer{ + config: DefaultConfig(), + wisdomApplication: application.NewWisdomApplication(wisdomService), + decoder: protocol.NewMessageDecoder(), + logger: slog.Default(), + shutdown: make(chan struct{}), + } + + for _, opt := range opts { + opt(server) + } + + return server +} + + +// Start starts the TCP server +func (s *TCPServer) Start(ctx context.Context) error { + listener, err := net.Listen("tcp", s.config.Address) + if err != nil { + return fmt.Errorf("failed to listen on %s: %w", s.config.Address, err) + } + + s.listener = listener + s.logger.Info("tcp server started", "address", s.config.Address) + + go s.acceptLoop(ctx) + return nil +} + +// Stop gracefully stops the server +func (s *TCPServer) Stop() error { + s.logger.Info("stopping tcp server") + close(s.shutdown) + + if s.listener != nil { + s.listener.Close() + } + + s.wg.Wait() + s.logger.Info("tcp server stopped") + return nil +} + +// Address returns the server's listening address +func (s *TCPServer) Address() string { + if s.listener != nil { + return s.listener.Addr().String() + } + return s.config.Address +} + +// acceptLoop accepts and handles incoming connections +func (s *TCPServer) acceptLoop(ctx context.Context) { + for { + select { + case <-s.shutdown: + return + case <-ctx.Done(): + return + default: + } + + rawConn, err := s.listener.Accept() + if err != nil { + select { + case <-s.shutdown: + return + default: + s.logger.Error("accept error", sl.Err(err)) + continue + } + } + + s.wg.Add(1) + go func() { + defer s.wg.Done() + s.handleConnection(ctx, rawConn) + }() + } +} + +// handleConnection handles a single client connection +func (s *TCPServer) handleConnection(ctx context.Context, rawConn net.Conn) { + defer rawConn.Close() + + connLogger := s.logger.With("remote_addr", rawConn.RemoteAddr().String()) + connLogger.Info("connection accepted") + + // Create connection-scoped context for overall timeout + connCtx, cancel := context.WithTimeout(ctx, s.config.Timeouts.Connection) + defer cancel() + + if err := s.processConnection(connCtx, rawConn, connLogger); err != nil { + connLogger.Error("connection error", sl.Err(err)) + } else { + connLogger.Debug("connection completed successfully") + } +} + +// deadlineConn wraps a connection to automatically set deadlines on each read/write +type deadlineConn struct { + net.Conn + rto, wto time.Duration +} + +func (d *deadlineConn) Read(p []byte) (int, error) { + _ = d.SetReadDeadline(time.Now().Add(d.rto)) + return d.Conn.Read(p) +} + +func (d *deadlineConn) Write(p []byte) (int, error) { + _ = d.SetWriteDeadline(time.Now().Add(d.wto)) + return d.Conn.Write(p) +} + +// processConnection handles the protocol message exchange +func (s *TCPServer) processConnection(ctx context.Context, conn net.Conn, logger *slog.Logger) error { + // Set overall connection deadline + globalDL := time.Now().Add(s.config.Timeouts.Connection) + if err := conn.SetDeadline(globalDL); err != nil { + return fmt.Errorf("failed to set connection deadline: %w", err) + } + + // Create deadline wrapper for automatic timeout handling + dc := &deadlineConn{ + Conn: conn, + rto: s.config.Timeouts.Read, + wto: s.config.Timeouts.Write, + } + + // Use LimitReader to prevent reading more than the protocol-defined maximum message size + // This protects against malicious clients that lie about payload length in headers + limitedReader := io.LimitReader(dc, int64(protocol.MaxPayloadSize)+protocol.HeaderSize) + + // Read incoming message through limited reader + logger.Debug("reading message from client") + msg, err := s.decoder.Decode(limitedReader) + if err != nil { + if err == io.EOF { + logger.Debug("client closed connection gracefully") + return nil + } + logger.Error("failed to decode message", sl.Err(err)) + return fmt.Errorf("decode error: %w", err) + } + + logger.Debug("message decoded", "type", msg.Type, "payload_length", msg.PayloadLength) + + // Process message through application layer + response, err := s.wisdomApplication.HandleMessage(ctx, msg) + if err != nil { + logger.Error("application error", sl.Err(err)) + return fmt.Errorf("application error: %w", err) + } + + logger.Debug("sending response to client") + // Send response using the response's own Encode method + if err := response.Encode(dc); err != nil { + logger.Error("failed to encode response", sl.Err(err)) + return fmt.Errorf("encode error: %w", err) + } + + logger.Debug("response sent successfully") + return nil +}