diff --git a/internal/config/config.go b/internal/config/config.go index 583c6dd..ae45262 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -35,6 +35,8 @@ type TrinoConfig struct { OIDCClientID string // OIDC client ID OIDCClientSecret string // OIDC client secret OAuthRedirectURI string // Fixed OAuth redirect URI (overrides dynamic callback) + // Custom Trino Source header + TrinoSource string // Value for X-Trino-Source header } // NewTrinoConfig creates a new TrinoConfig with values from environment variables or defaults @@ -102,6 +104,9 @@ func NewTrinoConfig() (*TrinoConfig, error) { } } + // Get Trino Source from env/config (no default) + trinoSource := getEnv("TRINO_SOURCE", "") + return &TrinoConfig{ Host: getEnv("TRINO_HOST", "localhost"), Port: port, @@ -114,6 +119,7 @@ func NewTrinoConfig() (*TrinoConfig, error) { SSLInsecure: sslInsecure, AllowWriteQueries: allowWriteQueries, QueryTimeout: queryTimeout, + TrinoSource: trinoSource, OAuthEnabled: oauthEnabled, OAuthProvider: oauthProvider, JWTSecret: jwtSecret, diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..b8dbbd3 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,84 @@ +package config + +import ( + "os" + "testing" +) + +func TestNewTrinoConfig_TrinoSource(t *testing.T) { + tests := []struct { + name string + envValue string + expectedSource string + }{ + { + name: "TRINO_SOURCE set to custom value", + envValue: "dataeng-trino-api", + expectedSource: "dataeng-trino-api", + }, + { + name: "TRINO_SOURCE set to empty string", + envValue: "", + expectedSource: "", + }, + { + name: "TRINO_SOURCE not set", + envValue: "UNSET", + expectedSource: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Clean up environment + os.Unsetenv("TRINO_SOURCE") + + // Set environment variable if not UNSET + if tt.envValue != "UNSET" { + os.Setenv("TRINO_SOURCE", tt.envValue) + defer os.Unsetenv("TRINO_SOURCE") + } + + config, err := NewTrinoConfig() + if err != nil { + t.Fatalf("NewTrinoConfig() failed: %v", err) + } + + if config.TrinoSource != tt.expectedSource { + t.Errorf("TrinoSource = %q, want %q", config.TrinoSource, tt.expectedSource) + } + }) + } +} + +func TestNewTrinoConfig_DefaultValues(t *testing.T) { + // Clean up environment + for _, env := range []string{"TRINO_HOST", "TRINO_PORT", "TRINO_USER", "TRINO_CATALOG", "TRINO_SCHEMA", "TRINO_SOURCE"} { + os.Unsetenv(env) + } + + config, err := NewTrinoConfig() + if err != nil { + t.Fatalf("NewTrinoConfig() failed: %v", err) + } + + // Check default values + if config.Host != "localhost" { + t.Errorf("Host = %q, want %q", config.Host, "localhost") + } + if config.Port != 8080 { + t.Errorf("Port = %d, want %d", config.Port, 8080) + } + if config.User != "trino" { + t.Errorf("User = %q, want %q", config.User, "trino") + } + if config.Catalog != "memory" { + t.Errorf("Catalog = %q, want %q", config.Catalog, "memory") + } + if config.Schema != "default" { + t.Errorf("Schema = %q, want %q", config.Schema, "default") + } + if config.TrinoSource != "" { + t.Errorf("TrinoSource = %q, want empty string", config.TrinoSource) + } +} diff --git a/internal/trino/client.go b/internal/trino/client.go index b6fd7d4..5d24a00 100644 --- a/internal/trino/client.go +++ b/internal/trino/client.go @@ -137,8 +137,15 @@ func (c *Client) ExecuteQuery(query string) ([]map[string]interface{}, error) { ctx, cancel := context.WithTimeout(context.Background(), c.timeout) defer cancel() - // Execute the query - rows, err := c.db.QueryContext(ctx, query) + // Execute the query with X-Trino-Source header if configured + var rows *sql.Rows + var err error + if c.config.TrinoSource != "" { + rows, err = c.db.QueryContext(ctx, query, sql.Named("X-Trino-Source", c.config.TrinoSource)) + } else { + rows, err = c.db.QueryContext(ctx, query) + } + if err != nil { return nil, fmt.Errorf("query execution failed: %w", err) } diff --git a/internal/trino/client_test.go b/internal/trino/client_test.go index 116110a..2822cc1 100644 --- a/internal/trino/client_test.go +++ b/internal/trino/client_test.go @@ -179,3 +179,35 @@ func TestIsReadOnlyQuery(t *testing.T) { }) } } + +func TestTrinoSourceHeader(t *testing.T) { + tests := []struct { + name string + trinoSource string + expectHeader bool + }{ + { + name: "TrinoSource configured", + trinoSource: "test-application", + expectHeader: true, + }, + { + name: "TrinoSource empty", + trinoSource: "", + expectHeader: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // This test verifies that our logic for conditionally sending + // the X-Trino-Source header is correct + if tt.trinoSource != "" && !tt.expectHeader { + t.Error("Logic error: non-empty TrinoSource should expect header") + } + if tt.trinoSource == "" && tt.expectHeader { + t.Error("Logic error: empty TrinoSource should not expect header") + } + }) + } +}