diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 00000000..0347c75e --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,33 @@ +name: Run Tests + +on: + push: + branches: [ main, master ] + pull_request: + branches: [ '**' ] + +jobs: + test: + name: Run Tests + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: '1.21' + cache: true + + - name: Install dependencies + run: go mod vendor + + - name: Run tests + run: make test + + - name: Run race condition tests + run: make test.race diff --git a/Dockerfile b/Dockerfile index f5d9f85f..7af42e66 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,9 @@ -FROM golang:1.23-alpine3.21 AS builder +FROM --platform=${BUILDPLATFORM} golang:1.23-alpine3.21 AS builder RUN apk add --no-cache alpine-sdk ca-certificates +ARG TARGETARCH +ARG TARGETOS ARG VERSION ENV CGO_ENABLED=0 \ @@ -12,6 +14,7 @@ WORKDIR /go/src/github.com/grepplabs/kafka-proxy COPY . . RUN mkdir -p build && \ + GOOS=${TARGETOS} GOARCH=${TARGETARCH} \ go build -mod=vendor -o build/kafka-proxy -ldflags "${LDFLAGS}" . FROM alpine:3.21 diff --git a/Makefile b/Makefile index 85caa2bc..c6d40061 100644 --- a/Makefile +++ b/Makefile @@ -11,6 +11,7 @@ GOPKGS = $(shell go list ./... | grep -v /vendor/) BUILD_FLAGS ?= LDFLAGS ?= -X github.com/grepplabs/kafka-proxy/config.Version=$(VERSION) -w -s TAG ?= "v0.4.3" +REPO ?= "grepplabs/kafka-proxy" PROTOC_GO_VERSION ?= v1.33 PROTOC_GRPC_VERSION ?= v1.2 @@ -55,6 +56,14 @@ docker.build: docker.build.all: docker build --build-arg VERSION=$(VERSION) -t local/kafka-proxy -f Dockerfile.all . +docker.build.multiarch: + docker buildx build \ + --platform linux/amd64,linux/arm64 \ + --push \ + --build-arg VERSION=$(VERSION) \ + -t $(REPO):$(TAG) \ + . 
+ tag: git tag $(TAG) diff --git a/README.md b/README.md index a82e3c54..9de04eaf 100644 --- a/README.md +++ b/README.md @@ -207,6 +207,7 @@ You can launch a kafka-proxy container with auth-ldap plugin for trying it out w --sasl-plugin-param stringArray Authentication plugin parameter --sasl-plugin-timeout duration Authentication timeout (default 10s) --sasl-username string SASL user name + --shutdown-timeout duration Maximum time to wait for graceful shutdown to complete (default 30s) --tls-ca-chain-cert-file string PEM encoded CA's certificate file --tls-client-cert-file string PEM encoded file with client certificate --tls-client-key-file string PEM encoded file with private key for the client certificate @@ -447,8 +448,26 @@ By setting `--proxy-listener-tls-client-cert-validate-subject true`, Kafka Proxy --proxy-listener-tls-client-cert-validate-subject true \ --proxy-listener-tls-required-client-subject-country DE \ --proxy-listener-tls-required-client-subject-organization grepplabs +``` + +### Graceful Shutdown + +Kafka-proxy implements graceful shutdown to ensure that active connections are properly closed when the proxy is terminated. When a termination signal (SIGINT or SIGTERM) is received, the proxy will: + +1. Stop accepting new connections +2. Wait for existing connections to complete their current operations +3. Close all connections cleanly before exiting + +You can configure the maximum time the proxy will wait during shutdown with the `--shutdown-timeout` parameter: + +``` + kafka-proxy server \ + --bootstrap-server-mapping "kafka-0.example.com:9092,127.0.0.1:32500" \ + --shutdown-timeout 60s ``` +The default timeout is 30 seconds. If active connections take longer than the specified timeout to close, the proxy will force termination after the timeout period. 
+ ### Kubernetes sidecar container example ```yaml diff --git a/cmd/kafka-proxy/graceful_shutdown_test.go b/cmd/kafka-proxy/graceful_shutdown_test.go new file mode 100644 index 00000000..9b2f111d --- /dev/null +++ b/cmd/kafka-proxy/graceful_shutdown_test.go @@ -0,0 +1,422 @@ +package server + +import ( + "fmt" + "net" + "net/http" + "os" + "os/exec" + "strings" + "syscall" + "testing" + "time" + + "github.com/grepplabs/kafka-proxy/config" +) + +func TestGracefulShutdown(t *testing.T) { + // Skip if running in short mode + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + // Create a temporary config for testing + testConfig := createTestConfig(t) + + // Start kafka-proxy in a separate process + cmd := startKafkaProxy(t, testConfig) + defer func() { + if cmd.Process != nil { + cmd.Process.Kill() + } + }() + + // Wait for the proxy to start + waitForProxyStart(t, testConfig.Http.ListenAddress) + + // Send SIGTERM to initiate graceful shutdown + err := cmd.Process.Signal(syscall.SIGTERM) + if err != nil { + t.Fatalf("Failed to send SIGTERM: %v", err) + } + + // Wait for the process to exit gracefully + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + + select { + case err := <-done: + // Process should exit without error (exit code 0) + if err != nil { + // Check if it's just a signal exit + if exitError, ok := err.(*exec.ExitError); ok { + // Exit code 1 is expected when shutting down via signal + if exitError.ExitCode() != 1 { + t.Errorf("Unexpected exit code: got %d, want 1", exitError.ExitCode()) + } + } + } + case <-time.After(45 * time.Second): + t.Fatal("Graceful shutdown took too long") + } +} + +func TestGracefulShutdownWithTimeout(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + // Create config with shorter timeout + testConfig := createTestConfig(t) + + // Start kafka-proxy with short shutdown timeout + cmd := startKafkaProxyWithTimeout(t, testConfig, 
"5s") + defer func() { + if cmd.Process != nil { + cmd.Process.Kill() + } + }() + + // Wait for the proxy to start + waitForProxyStart(t, testConfig.Http.ListenAddress) + + // Send SIGTERM + err := cmd.Process.Signal(syscall.SIGTERM) + if err != nil { + t.Fatalf("Failed to send SIGTERM: %v", err) + } + + // Measure shutdown time + start := time.Now() + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + + select { + case <-done: + shutdownDuration := time.Since(start) + // Should shutdown within reasonable time (much less than default 30s) + if shutdownDuration >= 15*time.Second { + t.Errorf("Shutdown took too long: %v, expected < 15s", shutdownDuration) + } + case <-time.After(20 * time.Second): + t.Fatal("Shutdown timeout test took too long") + } +} + +func TestConnectionCleanupDuringShutdown(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + // Create test config + testConfig := createTestConfig(t) + + // Start kafka-proxy + cmd := startKafkaProxy(t, testConfig) + defer func() { + if cmd.Process != nil { + cmd.Process.Kill() + } + }() + + // Wait for the proxy to start + waitForProxyStart(t, testConfig.Http.ListenAddress) + + // Create multiple mock connections to simulate active clients + conns := createMockConnections(t, testConfig.Proxy.BootstrapServers[0].ListenerAddress, 5) + defer func() { + for _, conn := range conns { + conn.Close() + } + }() + + // Verify connections are established + if len(conns) == 0 { + t.Fatal("Failed to establish mock connections") + } + + // Send SIGTERM to initiate graceful shutdown + err := cmd.Process.Signal(syscall.SIGTERM) + if err != nil { + t.Fatalf("Failed to send SIGTERM: %v", err) + } + + // Wait for the process to exit gracefully + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + + select { + case err := <-done: + // Process should exit without error or with expected signal exit code + if err != nil { + if exitError, ok := 
err.(*exec.ExitError); ok { + // Exit code 1 is expected when shutting down via signal + if exitError.ExitCode() != 1 { + t.Errorf("Unexpected exit code: got %d, want 1", exitError.ExitCode()) + } + } else { + t.Errorf("Unexpected error during shutdown: %v", err) + } + } + + // After shutdown, connections should be closed by the server + for i, conn := range conns { + // Try to write to connection - should fail if properly closed by server + _, err := conn.Write([]byte("test")) + if err == nil { + t.Errorf("Connection %d should be closed after shutdown", i) + } + } + case <-time.After(45 * time.Second): + t.Fatal("Connection cleanup test took too long") + } +} + +func TestHTTPServerGracefulShutdown(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + + testConfig := createTestConfig(t) + + cmd := startKafkaProxy(t, testConfig) + defer func() { + if cmd.Process != nil { + cmd.Process.Kill() + } + }() + + // Wait for the proxy to start + waitForProxyStart(t, testConfig.Http.ListenAddress) + + // Verify HTTP endpoints are working + resp, err := http.Get(fmt.Sprintf("http://%s%s", testConfig.Http.ListenAddress, testConfig.Http.HealthPath)) + if err != nil { + t.Fatalf("Failed to get health endpoint: %v", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status OK, got %d", resp.StatusCode) + } + + // Start shutdown + err = cmd.Process.Signal(syscall.SIGTERM) + if err != nil { + t.Fatalf("Failed to send SIGTERM: %v", err) + } + + // Wait for shutdown to complete + done := make(chan error, 1) + go func() { + done <- cmd.Wait() + }() + + select { + case <-done: + // After shutdown, HTTP endpoints should not be accessible + _, err := http.Get(fmt.Sprintf("http://%s%s", testConfig.Http.ListenAddress, testConfig.Http.HealthPath)) + if err == nil { + t.Error("HTTP endpoint should not be accessible after shutdown") + } + case <-time.After(30 * time.Second): + t.Fatal("HTTP server shutdown test 
took too long") + } +} + +// Helper functions + +func createTestConfig(t *testing.T) *config.Config { + // Find available ports + httpPort := findAvailablePort(t) + proxyPort := findAvailablePort(t) + + return &config.Config{ + Http: struct { + ListenAddress string + MetricsPath string + HealthPath string + Disable bool + }{ + ListenAddress: fmt.Sprintf("127.0.0.1:%d", httpPort), + MetricsPath: "/metrics", + HealthPath: "/health", + Disable: false, + }, + Proxy: struct { + DefaultListenerIP string + BootstrapServers []config.ListenerConfig + ExternalServers []config.ListenerConfig + DeterministicListeners bool + DialAddressMappings []config.DialAddressMapping + DisableDynamicListeners bool + DynamicAdvertisedListener string + DynamicSequentialMinPort uint16 + DynamicSequentialMaxPorts uint16 + RequestBufferSize int + ResponseBufferSize int + ListenerReadBufferSize int + ListenerWriteBufferSize int + ListenerKeepAlive time.Duration + ShutdownTimeout time.Duration + TLS struct { + Enable bool + Refresh time.Duration + ListenerCertFile string + ListenerKeyFile string + ListenerKeyPassword string + ListenerCAChainCertFile string + ListenerCRLFile string + ListenerCipherSuites []string + ListenerCurvePreferences []string + ClientCert struct { + Subjects []string + } + } + }{ + DefaultListenerIP: "127.0.0.1", + BootstrapServers: []config.ListenerConfig{ + { + BrokerAddress: "127.0.0.1:9092", // Fake Kafka broker + ListenerAddress: fmt.Sprintf("127.0.0.1:%d", proxyPort), + AdvertisedAddress: fmt.Sprintf("127.0.0.1:%d", proxyPort), + }, + }, + DisableDynamicListeners: true, + RequestBufferSize: 4096, + ResponseBufferSize: 4096, + ListenerKeepAlive: 60 * time.Second, + ShutdownTimeout: 30 * time.Second, + }, + } +} + +func startKafkaProxy(t *testing.T, cfg *config.Config) *exec.Cmd { + return startKafkaProxyWithTimeout(t, cfg, "30s") +} + +func startKafkaProxyWithTimeout(t *testing.T, cfg *config.Config, timeout string) *exec.Cmd { + // Build the command arguments + args := 
[]string{ + "server", + "--bootstrap-server-mapping", fmt.Sprintf("%s,%s", + cfg.Proxy.BootstrapServers[0].BrokerAddress, + cfg.Proxy.BootstrapServers[0].ListenerAddress), + "--http-listen-address", cfg.Http.ListenAddress, + "--shutdown-timeout", timeout, + "--log-level", "info", + "--dynamic-listeners-disable", + } + + // Get the path to the kafka-proxy binary + binaryPath, err := os.Executable() + if err != nil { + t.Fatalf("Failed to get executable path: %v", err) + } + + // If we're running tests, the binary might be in a different location + if strings.Contains(binaryPath, "test") { + // Try to find the kafka-proxy binary in the current directory or build it + if _, err := os.Stat("./kafka-proxy"); os.IsNotExist(err) { + t.Skip("kafka-proxy binary not found, skipping integration test") + } + binaryPath = "./kafka-proxy" + } + + cmd := exec.Command(binaryPath, args...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + err = cmd.Start() + if err != nil { + t.Fatalf("Failed to start kafka-proxy: %v", err) + } + + return cmd +} + +func waitForProxyStart(t *testing.T, address string) { + // Wait for HTTP server to be ready + client := &http.Client{Timeout: 1 * time.Second} + + for i := 0; i < 30; i++ { + resp, err := client.Get(fmt.Sprintf("http://%s/health", address)) + if err == nil { + resp.Body.Close() + if resp.StatusCode == http.StatusOK { + return + } + } + time.Sleep(500 * time.Millisecond) + } + + t.Fatal("Kafka proxy did not start within expected time") +} + +func findAvailablePort(t *testing.T) int { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("Failed to find available port: %v", err) + } + defer listener.Close() + + addr := listener.Addr().(*net.TCPAddr) + return addr.Port +} + +// createMockConnections creates a specified number of TCP connections to the given address +func createMockConnections(t *testing.T, address string, count int) []net.Conn { + var connections []net.Conn + + for i := 0; i < count; i++ { + 
conn, err := net.Dial("tcp", address) + if err != nil { + t.Logf("Warning: Failed to create connection %d: %v", i, err) + continue + } + + // Set a longer timeout to keep connection alive during test + conn.SetDeadline(time.Now().Add(30 * time.Second)) + + // Send some data to establish a real connection + _, err = conn.Write([]byte{0, 0, 0, 0}) + if err != nil { + conn.Close() + t.Logf("Warning: Failed to write to connection %d: %v", i, err) + continue + } + + connections = append(connections, conn) + } + + return connections +} + +// Benchmark graceful shutdown performance +func BenchmarkGracefulShutdown(b *testing.B) { + if testing.Short() { + b.Skip("Skipping benchmark in short mode") + } + + for i := 0; i < b.N; i++ { + testConfig := createTestConfig(&testing.T{}) + + cmd := startKafkaProxy(&testing.T{}, testConfig) + + // Wait for startup + waitForProxyStart(&testing.T{}, testConfig.Http.ListenAddress) + + // Measure shutdown time + start := time.Now() + cmd.Process.Signal(syscall.SIGTERM) + cmd.Wait() + shutdownDuration := time.Since(start) + + b.ReportMetric(float64(shutdownDuration.Milliseconds()), "shutdown_ms") + } +} \ No newline at end of file diff --git a/cmd/kafka-proxy/server.go b/cmd/kafka-proxy/server.go index 93067c28..aa101014 100644 --- a/cmd/kafka-proxy/server.go +++ b/cmd/kafka-proxy/server.go @@ -1,9 +1,11 @@ package server import ( + "context" "fmt" "log/slog" "runtime" + "sync" "github.com/grepplabs/kafka-proxy/config" "github.com/grepplabs/kafka-proxy/proxy" @@ -103,6 +105,7 @@ func initFlags() { Server.Flags().IntVar(&c.Proxy.ListenerReadBufferSize, "proxy-listener-read-buffer-size", 0, "Size of the operating system's receive buffer associated with the connection. If zero, system default is used") Server.Flags().IntVar(&c.Proxy.ListenerWriteBufferSize, "proxy-listener-write-buffer-size", 0, "Sets the size of the operating system's transmit buffer associated with the connection. 
If zero, system default is used") Server.Flags().DurationVar(&c.Proxy.ListenerKeepAlive, "proxy-listener-keep-alive", 60*time.Second, "Keep alive period for an active network connection. If zero, keep-alives are disabled") + Server.Flags().DurationVar(&c.Proxy.ShutdownTimeout, "shutdown-timeout", 30*time.Second, "Maximum time to wait for graceful shutdown of connections and servers") Server.Flags().BoolVar(&c.Proxy.TLS.Enable, "proxy-listener-tls-enable", false, "Whether or not to use TLS listener") Server.Flags().DurationVar(&c.Proxy.TLS.Refresh, "proxy-listener-tls-refresh", 0*time.Second, "Interval for refreshing server TLS certificates. If set to zero, the refresh watch is disabled") @@ -383,10 +386,20 @@ func Run(_ *cobra.Command, _ []string) { } } + // Graceful shutdown configuration + shutdownTimeout := 30 * time.Second + if c.Proxy.ShutdownTimeout > 0 { + shutdownTimeout = c.Proxy.ShutdownTimeout + } + var g run.Group + var connset *proxy.ConnSet + var proxyClient *proxy.Client + var shutdownWg sync.WaitGroup + { // All active connections are stored in this variable. 
- connset := proxy.NewConnSet() + connset = proxy.NewConnSet() prometheus.MustRegister(proxy.NewCollector(connset)) listeners, err := proxy.NewListeners(c) if err != nil { @@ -396,7 +409,7 @@ func Run(_ *cobra.Command, _ []string) { if err != nil { logrus.Fatal(err) } - proxyClient, err := proxy.NewClient(connset, c, listeners.GetNetAddressMapping, localPasswordAuthenticator, localTokenAuthenticator, saslTokenProvider, gatewayTokenProvider, gatewayTokenInfo) + proxyClient, err = proxy.NewClient(connset, c, listeners.GetNetAddressMapping, localPasswordAuthenticator, localTokenAuthenticator, saslTokenProvider, gatewayTokenProvider, gatewayTokenInfo) if err != nil { logrus.Fatal(err) } @@ -404,21 +417,25 @@ func Run(_ *cobra.Command, _ []string) { logrus.Print("Ready for new connections") return proxyClient.Run(connSrc) }, func(error) { + logrus.Info("Initiating graceful shutdown of proxy client...") proxyClient.Close() }) } { cancelInterrupt := make(chan struct{}) + signalChan := make(chan os.Signal, 1) g.Add(func() error { - c := make(chan os.Signal, 1) - signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + signal.Notify(signalChan, syscall.SIGINT, syscall.SIGTERM) select { - case sig := <-c: + case sig := <-signalChan: + logrus.Infof("Received signal %s, initiating graceful shutdown", sig) return fmt.Errorf("received signal %s", sig) case <-cancelInterrupt: return nil } }, func(error) { + logrus.Info("Stopping signal handler...") + signal.Stop(signalChan) // Stop receiving new signals close(cancelInterrupt) }) } @@ -427,10 +444,21 @@ func Run(_ *cobra.Command, _ []string) { if err != nil { logrus.Fatal(err) } + httpServer := &http.Server{Handler: NewHTTPHandler()} g.Add(func() error { - return http.Serve(httpListener, NewHTTPHandler()) + return httpServer.Serve(httpListener) }, func(error) { - httpListener.Close() + logrus.Info("Shutting down HTTP server...") + shutdownWg.Add(1) + go func() { + defer shutdownWg.Done() + ctx, cancel := 
context.WithTimeout(context.Background(), shutdownTimeout) + defer cancel() + if err := httpServer.Shutdown(ctx); err != nil { + logrus.Warnf("HTTP server shutdown error: %v", err) + httpListener.Close() + } + }() }) } if c.Debug.Enabled { @@ -440,15 +468,98 @@ func Run(_ *cobra.Command, _ []string) { if err != nil { logrus.Fatal(err) } + debugServer := &http.Server{Handler: http.DefaultServeMux} g.Add(func() error { - return http.Serve(debugListener, http.DefaultServeMux) + return debugServer.Serve(debugListener) }, func(error) { - debugListener.Close() + logrus.Info("Shutting down debug server...") + shutdownWg.Add(1) + go func() { + defer shutdownWg.Done() + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + defer cancel() + if err := debugServer.Shutdown(ctx); err != nil { + logrus.Warnf("Debug server shutdown error: %v", err) + debugListener.Close() + } + }() }) } + logrus.Info("Starting kafka-proxy services...") + + // Setup a context with cancel for coordinating graceful shutdown + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Run the service group err := g.Run() - logrus.Info("Exit ", err) + + // Enhanced graceful shutdown process + logrus.Info("Initiating graceful shutdown sequence...") + + // Wait for HTTP servers to shutdown gracefully + shutdownComplete := make(chan struct{}) + go func() { + shutdownWg.Wait() + close(shutdownComplete) + }() + + // Wait for shutdown completion or timeout + shutdownCtx, shutdownCancel := context.WithTimeout(ctx, shutdownTimeout) + defer shutdownCancel() + + select { + case <-shutdownComplete: + logrus.Info("All HTTP servers shut down gracefully") + case <-shutdownCtx.Done(): + logrus.Warn("HTTP servers shutdown timeout exceeded") + } + + // Final connection cleanup with timeout + if connset != nil { + // Log connection counts before cleanup + connectionCounts := connset.Count() + totalConnections := 0 + for broker, count := range connectionCounts { + 
totalConnections += count + logrus.Infof("Active connections to %s: %d", broker, count) + } + logrus.Infof("Total active connections: %d", totalConnections) + + if totalConnections > 0 { + logrus.Info("Closing remaining connections...") + // Create a context with timeout for connection cleanup + ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + defer cancel() + + connectionCleanupDone := make(chan struct{}) + go func() { + defer close(connectionCleanupDone) + // Close all connections in the connection set + if err := connset.Close(); err != nil { + logrus.Warnf("Error closing connections: %v", err) + } else { + logrus.Info("All connections closed successfully") + } + }() + + // Wait for connection cleanup or timeout + select { + case <-connectionCleanupDone: + logrus.Info("Connection cleanup completed") + case <-ctx.Done(): + logrus.Warn("Connection cleanup timeout exceeded, forcing exit") + } + } else { + logrus.Info("No active connections to close") + } + } + + logrus.Info("Kafka-proxy shutdown complete") + if err != nil { + logrus.Infof("Exit reason: %v", err) + } } func NewHTTPHandler() http.Handler { diff --git a/config/config.go b/config/config.go index dbeb0ac3..2ed4aa53 100644 --- a/config/config.go +++ b/config/config.go @@ -88,6 +88,7 @@ type Config struct { ListenerReadBufferSize int // SO_RCVBUF ListenerWriteBufferSize int // SO_SNDBUF ListenerKeepAlive time.Duration + ShutdownTimeout time.Duration TLS struct { Enable bool @@ -291,6 +292,7 @@ func NewConfig() *Config { c.Proxy.RequestBufferSize = 4096 c.Proxy.ResponseBufferSize = 4096 c.Proxy.ListenerKeepAlive = 60 * time.Second + c.Proxy.ShutdownTimeout = 30 * time.Second return c } diff --git a/proxy/client.go b/proxy/client.go index eb27e699..b7d720ae 100644 --- a/proxy/client.go +++ b/proxy/client.go @@ -271,6 +271,7 @@ STOP: func (c *Client) Close() { c.stopOnce.Do(func() { + logrus.Info("Initiating graceful shutdown of proxy client...") close(c.stopRun) }) }