[PHASE-6] Implement tcp server
This commit is contained in:
parent
8476340f75
commit
0caaab002f
213
internal/server/tcp.go
Normal file
213
internal/server/tcp.go
Normal file
|
|
@ -0,0 +1,213 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"hash-of-wisdom/internal/application"
|
||||
"hash-of-wisdom/internal/lib/sl"
|
||||
"hash-of-wisdom/internal/protocol"
|
||||
"hash-of-wisdom/internal/service"
|
||||
)
|
||||
|
||||
// TCPServer handles TCP connections for the Word of Wisdom protocol
|
||||
type TCPServer struct {
|
||||
config *Config
|
||||
wisdomApplication *application.WisdomApplication
|
||||
decoder *protocol.MessageDecoder
|
||||
listener net.Listener
|
||||
logger *slog.Logger
|
||||
wg sync.WaitGroup
|
||||
shutdown chan struct{}
|
||||
}
|
||||
|
||||
// Option is a functional option for configuring TCPServer
|
||||
type option func(*TCPServer)
|
||||
|
||||
// WithConfig sets a custom configuration
|
||||
func WithConfig(config *Config) option {
|
||||
return func(s *TCPServer) {
|
||||
s.config = config
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogger sets a custom logger
|
||||
func WithLogger(logger *slog.Logger) option {
|
||||
return func(s *TCPServer) {
|
||||
s.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// NewTCPServer creates a new TCP server with optional configuration
|
||||
func NewTCPServer(wisdomService *service.WisdomService, opts ...option) *TCPServer {
|
||||
server := &TCPServer{
|
||||
config: DefaultConfig(),
|
||||
wisdomApplication: application.NewWisdomApplication(wisdomService),
|
||||
decoder: protocol.NewMessageDecoder(),
|
||||
logger: slog.Default(),
|
||||
shutdown: make(chan struct{}),
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(server)
|
||||
}
|
||||
|
||||
return server
|
||||
}
|
||||
|
||||
|
||||
// Start starts the TCP server
|
||||
func (s *TCPServer) Start(ctx context.Context) error {
|
||||
listener, err := net.Listen("tcp", s.config.Address)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to listen on %s: %w", s.config.Address, err)
|
||||
}
|
||||
|
||||
s.listener = listener
|
||||
s.logger.Info("tcp server started", "address", s.config.Address)
|
||||
|
||||
go s.acceptLoop(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the server
|
||||
func (s *TCPServer) Stop() error {
|
||||
s.logger.Info("stopping tcp server")
|
||||
close(s.shutdown)
|
||||
|
||||
if s.listener != nil {
|
||||
s.listener.Close()
|
||||
}
|
||||
|
||||
s.wg.Wait()
|
||||
s.logger.Info("tcp server stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Address returns the server's listening address
|
||||
func (s *TCPServer) Address() string {
|
||||
if s.listener != nil {
|
||||
return s.listener.Addr().String()
|
||||
}
|
||||
return s.config.Address
|
||||
}
|
||||
|
||||
// acceptLoop accepts and handles incoming connections
|
||||
func (s *TCPServer) acceptLoop(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-s.shutdown:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
rawConn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.shutdown:
|
||||
return
|
||||
default:
|
||||
s.logger.Error("accept error", sl.Err(err))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.handleConnection(ctx, rawConn)
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnection handles a single client connection
|
||||
func (s *TCPServer) handleConnection(ctx context.Context, rawConn net.Conn) {
|
||||
defer rawConn.Close()
|
||||
|
||||
connLogger := s.logger.With("remote_addr", rawConn.RemoteAddr().String())
|
||||
connLogger.Info("connection accepted")
|
||||
|
||||
// Create connection-scoped context for overall timeout
|
||||
connCtx, cancel := context.WithTimeout(ctx, s.config.Timeouts.Connection)
|
||||
defer cancel()
|
||||
|
||||
if err := s.processConnection(connCtx, rawConn, connLogger); err != nil {
|
||||
connLogger.Error("connection error", sl.Err(err))
|
||||
} else {
|
||||
connLogger.Debug("connection completed successfully")
|
||||
}
|
||||
}
|
||||
|
||||
// deadlineConn wraps a connection to automatically set deadlines on each read/write
|
||||
type deadlineConn struct {
|
||||
net.Conn
|
||||
rto, wto time.Duration
|
||||
}
|
||||
|
||||
func (d *deadlineConn) Read(p []byte) (int, error) {
|
||||
_ = d.SetReadDeadline(time.Now().Add(d.rto))
|
||||
return d.Conn.Read(p)
|
||||
}
|
||||
|
||||
func (d *deadlineConn) Write(p []byte) (int, error) {
|
||||
_ = d.SetWriteDeadline(time.Now().Add(d.wto))
|
||||
return d.Conn.Write(p)
|
||||
}
|
||||
|
||||
// processConnection handles the protocol message exchange
|
||||
func (s *TCPServer) processConnection(ctx context.Context, conn net.Conn, logger *slog.Logger) error {
|
||||
// Set overall connection deadline
|
||||
globalDL := time.Now().Add(s.config.Timeouts.Connection)
|
||||
if err := conn.SetDeadline(globalDL); err != nil {
|
||||
return fmt.Errorf("failed to set connection deadline: %w", err)
|
||||
}
|
||||
|
||||
// Create deadline wrapper for automatic timeout handling
|
||||
dc := &deadlineConn{
|
||||
Conn: conn,
|
||||
rto: s.config.Timeouts.Read,
|
||||
wto: s.config.Timeouts.Write,
|
||||
}
|
||||
|
||||
// Use LimitReader to prevent reading more than the protocol-defined maximum message size
|
||||
// This protects against malicious clients that lie about payload length in headers
|
||||
limitedReader := io.LimitReader(dc, int64(protocol.MaxPayloadSize)+protocol.HeaderSize)
|
||||
|
||||
// Read incoming message through limited reader
|
||||
logger.Debug("reading message from client")
|
||||
msg, err := s.decoder.Decode(limitedReader)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
logger.Debug("client closed connection gracefully")
|
||||
return nil
|
||||
}
|
||||
logger.Error("failed to decode message", sl.Err(err))
|
||||
return fmt.Errorf("decode error: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("message decoded", "type", msg.Type, "payload_length", msg.PayloadLength)
|
||||
|
||||
// Process message through application layer
|
||||
response, err := s.wisdomApplication.HandleMessage(ctx, msg)
|
||||
if err != nil {
|
||||
logger.Error("application error", sl.Err(err))
|
||||
return fmt.Errorf("application error: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("sending response to client")
|
||||
// Send response using the response's own Encode method
|
||||
if err := response.Encode(dc); err != nil {
|
||||
logger.Error("failed to encode response", sl.Err(err))
|
||||
return fmt.Errorf("encode error: %w", err)
|
||||
}
|
||||
|
||||
logger.Debug("response sent successfully")
|
||||
return nil
|
||||
}
|
||||
Loading…
Reference in a new issue