diff --git a/api/docs.go b/api/docs.go index 007ef68..2918be4 100644 --- a/api/docs.go +++ b/api/docs.go @@ -2831,7 +2831,8 @@ const docTemplate = `{ "bearer", "basic", "oidc_client_credentials", - "oidc_user" + "oidc_user", + "oauth2_token_endpoint" ] } } diff --git a/api/openapi.yaml b/api/openapi.yaml index 982382e..57f8c71 100644 --- a/api/openapi.yaml +++ b/api/openapi.yaml @@ -2398,6 +2398,7 @@ components: - basic - oidc_client_credentials - oidc_user + - oauth2_token_endpoint type: string required: - type diff --git a/api/swagger.json b/api/swagger.json index b7c9cbf..54d4b89 100644 --- a/api/swagger.json +++ b/api/swagger.json @@ -2825,7 +2825,8 @@ "bearer", "basic", "oidc_client_credentials", - "oidc_user" + "oidc_user", + "oauth2_token_endpoint" ] } } diff --git a/api/swagger.yaml b/api/swagger.yaml index ad96ccd..98124d6 100644 --- a/api/swagger.yaml +++ b/api/swagger.yaml @@ -252,6 +252,7 @@ definitions: - basic - oidc_client_credentials - oidc_user + - oauth2_token_endpoint type: string required: - type diff --git a/pkg/executors/http/auth/errors.go b/pkg/executors/http/auth/errors.go index 35f0546..1b59d3a 100644 --- a/pkg/executors/http/auth/errors.go +++ b/pkg/executors/http/auth/errors.go @@ -23,6 +23,15 @@ var ( ErrOIDCUserClientRequired = pkg.ValidationError{EntityType: "OIDCUserConfig", Message: "oidc_user config: client_id is required"} ErrOIDCUserUsernameRequired = pkg.ValidationError{EntityType: "OIDCUserConfig", Message: "oidc_user config: username is required"} ErrOIDCUserPasswordRequired = pkg.ValidationError{EntityType: "OIDCUserConfig", Message: "oidc_user config: password is required"} + + ErrOAuth2TokenEndpointConfigRequired = pkg.ValidationError{EntityType: "OAuth2TokenEndpointConfig", Message: "oauth2_token_endpoint config is required"} + ErrOAuth2TokenEndpointURLRequired = pkg.ValidationError{EntityType: "OAuth2TokenEndpointConfig", Message: "oauth2_token_endpoint config: token_url is required"} + ErrOAuth2TokenEndpointClientRequired = pkg.ValidationError{EntityType: "OAuth2TokenEndpointConfig", Message: "oauth2_token_endpoint config: client_id is required"} + ErrOAuth2TokenEndpointSecretRequired = pkg.ValidationError{EntityType: "OAuth2TokenEndpointConfig", Message: "oauth2_token_endpoint config: client_secret is required"} + ErrOAuth2TokenEndpointLocationInvalid = pkg.ValidationError{EntityType: "OAuth2TokenEndpointConfig", Message: "oauth2_token_endpoint config: credentials_location must be 'body' or 'basic_header'"} + ErrOAuth2TokenEndpointReservedExtraParam = pkg.ValidationError{EntityType: "OAuth2TokenEndpointConfig", Message: "oauth2_token_endpoint config: extra_params cannot contain reserved keys (client_id, client_secret, scope, audience)"} + ErrOAuth2TokenEndpointUnsupportedTokenType = pkg.ValidationError{EntityType: "OAuth2TokenEndpointConfig", Message: "oauth2_token_endpoint: token endpoint returned unsupported token_type (only Bearer is accepted)"} + ErrOAuth2TokenEndpointEmptyAccessToken = pkg.ValidationError{EntityType: "OAuth2TokenEndpointConfig", Message: "oauth2_token_endpoint: token endpoint returned empty access_token"} ) // Provider constructor errors - returned when creating providers directly. diff --git a/pkg/executors/http/auth/factory.go b/pkg/executors/http/auth/factory.go index 6fc9173..4dfc37d 100644 --- a/pkg/executors/http/auth/factory.go +++ b/pkg/executors/http/auth/factory.go @@ -69,6 +69,16 @@ func NewFromConfig(authConfig map[string]any, httpClient *http.Client) (Provider return NewOIDCUserProvider(cfg, cacheCfg, httpClient) + case TypeOAuth2TokenEndpoint: + cfg, err := parseOAuth2TokenEndpointConfig(configData) + if err != nil { + return nil, err + } + + cacheCfg := parseCacheConfig(cacheData) + + return NewOAuth2TokenEndpointProvider(cfg, cacheCfg, httpClient) + default: return nil, fmt.Errorf("%w: %s", ErrUnknownAuthType, authType) } @@ -168,6 +178,40 @@ func parseOIDCUserConfig(data map[string]any) (*OIDCUserConfig, error) { return cfg, nil } +func parseOAuth2TokenEndpointConfig(data map[string]any) (*OAuth2TokenEndpointConfig, error) { + cfg := &OAuth2TokenEndpointConfig{} + + if err := mapToStruct(data, cfg); err != nil { + return nil, fmt.Errorf("parse oauth2_token_endpoint config: %w", err) + } + + if cfg.TokenURL == "" { + return nil, ErrOAuth2TokenEndpointURLRequired + } + + if cfg.ClientID == "" { + return nil, ErrOAuth2TokenEndpointClientRequired + } + + if cfg.ClientSecret == "" { + return nil, ErrOAuth2TokenEndpointSecretRequired + } + + if cfg.CredentialsLocation != "" && + cfg.CredentialsLocation != "body" && + cfg.CredentialsLocation != "basic_header" { + return nil, ErrOAuth2TokenEndpointLocationInvalid + } + + for k := range cfg.ExtraParams { + if reservedOAuth2ExtraParams[k] { + return nil, ErrOAuth2TokenEndpointReservedExtraParam + } + } + + return cfg, nil +} + func parseCacheConfig(data map[string]any) *CacheCfg { if data == nil { return nil diff --git a/pkg/executors/http/auth/factory_test.go b/pkg/executors/http/auth/factory_test.go index f230f61..b962da8 100644 --- a/pkg/executors/http/auth/factory_test.go +++ b/pkg/executors/http/auth/factory_test.go @@ -194,6 +194,159 @@ func TestNewFromConfigOIDCUserMissingUsername(t *testing.T) { require.ErrorIs(t, err, ErrOIDCUserUsernameRequired) } +func TestNewFromConfigOAuth2TokenEndpoint(t *testing.T) { + tests := []struct { + name string + config map[string]any + wantErr error + }{ + { + name: "full valid config", + config: map[string]any{ + "token_url": "https://idp.example.com/token", + "client_id": "my-client", + "client_secret": "my-secret", + "credentials_location": "body", + "scopes": []any{"api:read"}, + "audience": "my-audience", + "extra_params": map[string]any{"grant_type": "client_credentials"}, + }, + wantErr: nil, + }, + { + name: "minimal valid config (Delorean-style)", + config: map[string]any{ + "token_url": "https://idp.example.com/token", + "client_id": "my-client", + "client_secret": "my-secret", + }, + wantErr: nil, + }, + { + name: "basic_header credentials_location is valid", + config: map[string]any{ + "token_url": "https://idp.example.com/token", + "client_id": "my-client", + "client_secret": "my-secret", + "credentials_location": "basic_header", + }, + wantErr: nil, + }, + { + name: "missing token_url", + config: map[string]any{ + "client_id": "my-client", + "client_secret": "my-secret", + }, + wantErr: ErrOAuth2TokenEndpointURLRequired, + }, + { + name: "missing client_id", + config: map[string]any{ + "token_url": "https://idp.example.com/token", + "client_secret": "my-secret", + }, + wantErr: ErrOAuth2TokenEndpointClientRequired, + }, + { + name: "missing client_secret", + config: map[string]any{ + "token_url": "https://idp.example.com/token", + "client_id": "my-client", + }, + wantErr: ErrOAuth2TokenEndpointSecretRequired, + }, + { + name: "invalid credentials_location", + config: map[string]any{ + "token_url": "https://idp.example.com/token", + "client_id": "my-client", + "client_secret": "my-secret", + "credentials_location": "query", + }, + wantErr: ErrOAuth2TokenEndpointLocationInvalid, + }, + { + name: "reserved key client_id in extra_params", + config: map[string]any{ + "token_url": "https://idp.example.com/token", + "client_id": "my-client", + "client_secret": "my-secret", + "extra_params": map[string]any{"client_id": "override"}, + }, + wantErr: ErrOAuth2TokenEndpointReservedExtraParam, + }, + { + name: "reserved key client_secret in extra_params", + config: map[string]any{ + "token_url": "https://idp.example.com/token", + "client_id": "my-client", + "client_secret": "my-secret", + "extra_params": map[string]any{"client_secret": "override"}, + }, + wantErr: ErrOAuth2TokenEndpointReservedExtraParam, + }, + { + name: "reserved key scope in extra_params", + config: map[string]any{ + "token_url": "https://idp.example.com/token", + "client_id": "my-client", + "client_secret": "my-secret", + "extra_params": map[string]any{"scope": "override"}, + }, + wantErr: ErrOAuth2TokenEndpointReservedExtraParam, + }, + { + name: "reserved key audience in extra_params", + config: map[string]any{ + "token_url": "https://idp.example.com/token", + "client_id": "my-client", + "client_secret": "my-secret", + "extra_params": map[string]any{"audience": "override"}, + }, + wantErr: ErrOAuth2TokenEndpointReservedExtraParam, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := map[string]any{ + "type": "oauth2_token_endpoint", + "config": tt.config, + } + + provider, err := NewFromConfig(cfg, nil) + + if tt.wantErr != nil { + require.ErrorIs(t, err, tt.wantErr) + return + } + + require.NoError(t, err) + assert.Equal(t, TypeOAuth2TokenEndpoint, provider.Type()) + + // For the "full valid config" case, also verify that every field + // from the input map was actually parsed onto the underlying + // OAuth2TokenEndpointConfig. A regression in factory.go's field + // extraction would otherwise pass the type check while silently + // dropping scopes/audience/extra_params. + if tt.name == "full valid config" { + typed, ok := provider.(*OAuth2TokenEndpointProvider) + require.True(t, ok, "provider must be *OAuth2TokenEndpointProvider") + snap := typed.testConfigSnapshot() + require.NotNil(t, snap) + assert.Equal(t, "https://idp.example.com/token", snap.TokenURL) + assert.Equal(t, "my-client", snap.ClientID) + assert.Equal(t, "my-secret", snap.ClientSecret) + assert.Equal(t, "body", snap.CredentialsLocation) + assert.Equal(t, []string{"api:read"}, snap.Scopes) + assert.Equal(t, "my-audience", snap.Audience) + assert.Equal(t, "client_credentials", snap.ExtraParams["grant_type"]) + } + }) + } +} + func TestNewFromConfigUnknownType(t *testing.T) { cfg := map[string]any{ "type": "unknown", diff --git a/pkg/executors/http/auth/oauth2_token_endpoint.go b/pkg/executors/http/auth/oauth2_token_endpoint.go new file mode 100644 index 0000000..fc30fd7 --- /dev/null +++ b/pkg/executors/http/auth/oauth2_token_endpoint.go @@ -0,0 +1,125 @@ +// Copyright (c) 2026 Lerian Studio. All rights reserved. +// Use of this source code is governed by the Elastic License 2.0 +// that can be found in the LICENSE file. + +package auth + +import ( + "context" + "fmt" + "net/http" + "strings" +) + +// reservedOAuth2ExtraParams holds keys that cannot appear in extra_params. +// Prevents extra_params from silently overriding canonical fields (client_id, +// client_secret) or smuggling alternate scope/audience values that don't match +// the validated config. +var reservedOAuth2ExtraParams = map[string]bool{ + "client_id": true, + "client_secret": true, + "scope": true, + "audience": true, +} + +// OAuth2TokenEndpointProvider provides OAuth2 authentication against a custom +// token endpoint (no OIDC discovery). +type OAuth2TokenEndpointProvider struct { + config *OAuth2TokenEndpointConfig + cacheConfig *CacheCfg + tokenFetcher *TokenFetcher + cacheKey string +} + +// NewOAuth2TokenEndpointProvider creates a new OAuth2 token endpoint provider. +// Returns an error if cfg is nil, required fields are missing, credentials_location +// is invalid, or extra_params contains reserved keys. +func NewOAuth2TokenEndpointProvider(cfg *OAuth2TokenEndpointConfig, cacheCfg *CacheCfg, httpClient *http.Client) (*OAuth2TokenEndpointProvider, error) { + if cfg == nil { + return nil, ErrOAuth2TokenEndpointConfigRequired + } + + if cfg.TokenURL == "" { + return nil, ErrOAuth2TokenEndpointURLRequired + } + + if cfg.ClientID == "" { + return nil, ErrOAuth2TokenEndpointClientRequired + } + + if cfg.ClientSecret == "" { + return nil, ErrOAuth2TokenEndpointSecretRequired + } + + if cfg.CredentialsLocation != "" && + cfg.CredentialsLocation != "body" && + cfg.CredentialsLocation != "basic_header" { + return nil, ErrOAuth2TokenEndpointLocationInvalid + } + + for k := range cfg.ExtraParams { + if reservedOAuth2ExtraParams[k] { + return nil, ErrOAuth2TokenEndpointReservedExtraParam + } + } + + // NOTE: credentials_location default ("body") is NOT mutated on the caller's + // *cfg here — that would surprise callers who reuse the struct or whose + // config originates from a shared/cached source. Instead, the empty value + // is treated as "body" wherever it is consumed (buildOAuth2TokenEndpointData + // and fetchNewOAuth2TokenEndpointToken). + + if cacheCfg == nil { + cacheCfg = &CacheCfg{ + Enabled: true, + RefreshBeforeExpirySeconds: 60, + } + } + + // Distinct prefix "oauth2te" to avoid collision with OIDC ("cc", "user"). + cacheKey := generateCacheKey("oauth2te", cfg.TokenURL, cfg.ClientID, cfg.Scopes) + + return &OAuth2TokenEndpointProvider{ + config: cfg, + cacheConfig: cacheCfg, + tokenFetcher: NewTokenFetcher(httpClient, nil), + cacheKey: cacheKey, + }, nil +} + +// Apply implements Provider interface. +func (p *OAuth2TokenEndpointProvider) Apply(ctx context.Context, req *http.Request) error { + token, err := p.tokenFetcher.FetchOAuth2TokenEndpointToken(ctx, p.config, p.cacheKey, p.cacheConfig) + if err != nil { + return fmt.Errorf("fetch oauth2 token endpoint token: %w", err) + } + + // token_type allowlist: Bearer only (case-insensitive). Rejects Basic/MAC/DPoP + // to prevent a malicious token endpoint from forcing a non-Bearer scheme that + // would expose the access token via an unintended Authorization header form. + if token.TokenType != "" && !strings.EqualFold(token.TokenType, "Bearer") { + return ErrOAuth2TokenEndpointUnsupportedTokenType + } + + // Empty access_token guard: a token endpoint returning 200 OK with an + // empty access_token would otherwise produce the literal header + // "Authorization: Bearer " on the downstream request, which most servers + // either accept as an unauthenticated call or surface as a misleading 401. + // Fail fast at fetch time so the caller sees the actual contract + // violation. + if token.AccessToken == "" { + return ErrOAuth2TokenEndpointEmptyAccessToken + } + + req.Header.Set("Authorization", "Bearer "+token.AccessToken) + + return nil +} + +// Type implements Provider interface. +func (p *OAuth2TokenEndpointProvider) Type() Type { + return TypeOAuth2TokenEndpoint +} + +// Verify OAuth2TokenEndpointProvider implements Provider interface. +var _ Provider = (*OAuth2TokenEndpointProvider)(nil) diff --git a/pkg/executors/http/auth/oauth2_token_endpoint_export_test.go b/pkg/executors/http/auth/oauth2_token_endpoint_export_test.go new file mode 100644 index 0000000..e8b1056 --- /dev/null +++ b/pkg/executors/http/auth/oauth2_token_endpoint_export_test.go @@ -0,0 +1,42 @@ +// Copyright (c) 2026 Lerian Studio. All rights reserved. +// Use of this source code is governed by the Elastic License 2.0 +// that can be found in the LICENSE file. + +package auth + +// This file lives only in the _test.go set so the helper is never compiled +// into the production binary. It exposes a defensive snapshot of an +// OAuth2TokenEndpointProvider's parsed config so same-package tests can +// verify factory.NewFromConfig populated every field (not just Type()). +// +// The Scopes slice and ExtraParams map are copied so test mutation cannot +// race with concurrent Apply calls reading the live config. + +// testConfigSnapshot returns a defensive copy of the provider's config for +// same-package test assertions. +func (p *OAuth2TokenEndpointProvider) testConfigSnapshot() *OAuth2TokenEndpointConfig { + if p == nil || p.config == nil { + return nil + } + + scopes := make([]string, len(p.config.Scopes)) + copy(scopes, p.config.Scopes) + + var extraParams map[string]string + if p.config.ExtraParams != nil { + extraParams = make(map[string]string, len(p.config.ExtraParams)) + for k, v := range p.config.ExtraParams { + extraParams[k] = v + } + } + + return &OAuth2TokenEndpointConfig{ + TokenURL: p.config.TokenURL, + ClientID: p.config.ClientID, + ClientSecret: p.config.ClientSecret, + CredentialsLocation: p.config.CredentialsLocation, + Scopes: scopes, + Audience: p.config.Audience, + ExtraParams: extraParams, + } +} diff --git a/pkg/executors/http/auth/oauth2_token_endpoint_test.go b/pkg/executors/http/auth/oauth2_token_endpoint_test.go new file mode 100644 index 0000000..51556d4 --- /dev/null +++ b/pkg/executors/http/auth/oauth2_token_endpoint_test.go @@ -0,0 +1,611 @@ +// Copyright (c) 2026 Lerian Studio. All rights reserved. +// Use of this source code is governed by the Elastic License 2.0 +// that can be found in the LICENSE file. + +package auth + +import ( + "context" + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "sync/atomic" + "testing" + + "github.com/LerianStudio/flowker/pkg/safehttp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// tokenEndpointHandler returns an httptest server that serves a single POST +// token endpoint. The handler captures the last received form values and +// headers so tests can assert on them. tokenType controls the token_type +// returned in the body — empty means the field is omitted entirely. +func tokenEndpointServer(t *testing.T, accessToken, tokenType string) (*httptest.Server, *atomic.Pointer[capturedTokenRequest]) { + t.Helper() + + captured := &atomic.Pointer[capturedTokenRequest]{} + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + if err := r.ParseForm(); err != nil { + http.Error(w, "invalid form", http.StatusBadRequest) + return + } + + captured.Store(&capturedTokenRequest{ + Form: r.Form, + AuthHdr: r.Header.Get("Authorization"), + }) + + resp := map[string]any{ + "access_token": accessToken, + "expires_in": 3600, + } + if tokenType != "" { + resp["token_type"] = tokenType + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) + })) + + return srv, captured +} + +type capturedTokenRequest struct { + Form url.Values + AuthHdr string +} + +func TestOAuth2TokenEndpointProvider_BodyCredentials(t *testing.T) { + srv, captured := tokenEndpointServer(t, "abc123", "Bearer") + defer srv.Close() + + cfg := &OAuth2TokenEndpointConfig{ + TokenURL: srv.URL, + ClientID: "client", + ClientSecret: "secret", + } + + provider, err := NewOAuth2TokenEndpointProvider(cfg, nil, srv.Client()) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "http://target.example.com", nil) + require.NoError(t, err) + + require.NoError(t, provider.Apply(context.Background(), req)) + + assert.Equal(t, "Bearer abc123", req.Header.Get("Authorization")) + + got := captured.Load() + require.NotNil(t, got) + assert.Equal(t, "client", got.Form.Get("client_id")) + assert.Equal(t, "secret", got.Form.Get("client_secret")) + assert.Empty(t, got.AuthHdr, "no Basic auth header should be sent when credentials are in the body") +} + +func TestOAuth2TokenEndpointProvider_BasicHeaderCredentials(t *testing.T) { + srv, captured := tokenEndpointServer(t, "tok-basic", "Bearer") + defer srv.Close() + + cfg := &OAuth2TokenEndpointConfig{ + TokenURL: srv.URL, + ClientID: "client", + ClientSecret: "secret", + CredentialsLocation: "basic_header", + } + + provider, err := NewOAuth2TokenEndpointProvider(cfg, nil, srv.Client()) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "http://target.example.com", nil) + require.NoError(t, err) + + require.NoError(t, provider.Apply(context.Background(), req)) + + assert.Equal(t, "Bearer tok-basic", req.Header.Get("Authorization")) + + got := captured.Load() + require.NotNil(t, got) + expectedBasic := "Basic " + base64.StdEncoding.EncodeToString([]byte("client:secret")) + assert.Equal(t, expectedBasic, got.AuthHdr) + assert.Empty(t, got.Form.Get("client_id"), "client_id must not be in body when using basic_header") + assert.Empty(t, got.Form.Get("client_secret"), "client_secret must not be in body when using basic_header") +} + +func TestOAuth2TokenEndpointProvider_ExtraParamsAndScopes(t *testing.T) { + srv, captured := tokenEndpointServer(t, "tok-extra", "Bearer") + defer srv.Close() + + cfg := &OAuth2TokenEndpointConfig{ + TokenURL: srv.URL, + ClientID: "client", + ClientSecret: "secret", + Scopes: []string{"api:read", "api:write"}, + Audience: "my-audience", + ExtraParams: map[string]string{"grant_type": "client_credentials", "tenant": "acme"}, + } + + provider, err := NewOAuth2TokenEndpointProvider(cfg, nil, srv.Client()) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "http://target.example.com", nil) + require.NoError(t, err) + + require.NoError(t, provider.Apply(context.Background(), req)) + + got := captured.Load() + require.NotNil(t, got) + assert.Equal(t, "api:read api:write", got.Form.Get("scope")) + assert.Equal(t, "my-audience", got.Form.Get("audience")) + assert.Equal(t, "client_credentials", got.Form.Get("grant_type")) + assert.Equal(t, "acme", got.Form.Get("tenant")) +} + +func TestOAuth2TokenEndpointProvider_TokenCached(t *testing.T) { + var hits int32 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&hits, 1) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "cached-token", + "token_type": "Bearer", + "expires_in": 3600, + }) + })) + defer srv.Close() + + cfg := &OAuth2TokenEndpointConfig{ + TokenURL: srv.URL, + ClientID: "client", + ClientSecret: "secret", + } + + provider, err := NewOAuth2TokenEndpointProvider(cfg, nil, srv.Client()) + require.NoError(t, err) + + for i := 0; i < 3; i++ { + req, err := http.NewRequest(http.MethodGet, "http://target.example.com", nil) + require.NoError(t, err) + require.NoError(t, provider.Apply(context.Background(), req)) + assert.Equal(t, "Bearer cached-token", req.Header.Get("Authorization")) + } + + assert.Equal(t, int32(1), atomic.LoadInt32(&hits), "token endpoint should be hit only once due to cache") +} + +func TestOAuth2TokenEndpointProvider_ConstructorValidation(t *testing.T) { + tests := []struct { + name string + cfg *OAuth2TokenEndpointConfig + wantErr error + }{ + { + name: "nil config", + cfg: nil, + wantErr: ErrOAuth2TokenEndpointConfigRequired, + }, + { + name: "missing token_url", + cfg: &OAuth2TokenEndpointConfig{ + ClientID: "c", + ClientSecret: "s", + }, + wantErr: ErrOAuth2TokenEndpointURLRequired, + }, + { + name: "missing client_id", + cfg: &OAuth2TokenEndpointConfig{ + TokenURL: "https://idp.example.com/token", + ClientSecret: "s", + }, + wantErr: ErrOAuth2TokenEndpointClientRequired, + }, + { + name: "missing client_secret", + cfg: &OAuth2TokenEndpointConfig{ + TokenURL: "https://idp.example.com/token", + ClientID: "c", + }, + wantErr: ErrOAuth2TokenEndpointSecretRequired, + }, + { + name: "invalid credentials_location", + cfg: &OAuth2TokenEndpointConfig{ + TokenURL: "https://idp.example.com/token", + ClientID: "c", + ClientSecret: "s", + CredentialsLocation: "query", + }, + wantErr: ErrOAuth2TokenEndpointLocationInvalid, + }, + { + name: "reserved extra_params key", + cfg: &OAuth2TokenEndpointConfig{ + TokenURL: "https://idp.example.com/token", + ClientID: "c", + ClientSecret: "s", + ExtraParams: map[string]string{"scope": "override"}, + }, + wantErr: ErrOAuth2TokenEndpointReservedExtraParam, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewOAuth2TokenEndpointProvider(tt.cfg, nil, nil) + require.ErrorIs(t, err, tt.wantErr) + }) + } +} + +func TestOAuth2TokenEndpointProvider_DefaultCredentialsLocation(t *testing.T) { + // Contract: an empty credentials_location must be treated as "body" at + // request time WITHOUT mutating the caller's *cfg. Two assertions: + // + // 1. The input cfg pointer is preserved verbatim (no side effects on the + // caller). This protects callers that reuse the struct, derive it from + // a shared source, or hash it. + // 2. The provider's behavior matches what the explicit "body" case emits — + // client_id and client_secret appear in the form body, and no Basic + // auth header is sent. + srv, captured := tokenEndpointServer(t, "tok-default", "Bearer") + defer srv.Close() + + cfg := &OAuth2TokenEndpointConfig{ + TokenURL: srv.URL, + ClientID: "client", + ClientSecret: "secret", + // CredentialsLocation intentionally left empty. + } + + provider, err := NewOAuth2TokenEndpointProvider(cfg, nil, srv.Client()) + require.NoError(t, err) + + assert.Equal(t, "", cfg.CredentialsLocation, + "constructor must NOT mutate caller's cfg.CredentialsLocation; "+ + "empty must remain empty so callers reusing the struct are not surprised") + + req, err := http.NewRequest(http.MethodGet, "http://target.example.com", nil) + require.NoError(t, err) + + require.NoError(t, provider.Apply(context.Background(), req)) + + assert.Equal(t, "Bearer tok-default", req.Header.Get("Authorization")) + + got := captured.Load() + require.NotNil(t, got) + assert.Equal(t, "client", got.Form.Get("client_id"), + "empty credentials_location must behave like 'body' — client_id in form") + assert.Equal(t, "secret", got.Form.Get("client_secret"), + "empty credentials_location must behave like 'body' — client_secret in form") + assert.Empty(t, got.AuthHdr, + "no Basic auth header should be sent when credentials_location is empty (defaults to body)") +} + +func TestOAuth2TokenEndpointProvider_Type(t *testing.T) { + cfg := &OAuth2TokenEndpointConfig{ + TokenURL: "https://idp.example.com/token", + ClientID: "client", + ClientSecret: "secret", + } + + provider, err := NewOAuth2TokenEndpointProvider(cfg, nil, nil) + require.NoError(t, err) + + assert.Equal(t, TypeOAuth2TokenEndpoint, provider.Type()) +} + +func TestOAuth2TokenEndpointProvider_EmptyTokenTypeDefaultsToBearer(t *testing.T) { + srv, _ := tokenEndpointServer(t, "no-token-type", "") + defer srv.Close() + + cfg := &OAuth2TokenEndpointConfig{ + TokenURL: srv.URL, + ClientID: "client", + ClientSecret: "secret", + } + + provider, err := NewOAuth2TokenEndpointProvider(cfg, nil, srv.Client()) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "http://target.example.com", nil) + require.NoError(t, err) + + require.NoError(t, provider.Apply(context.Background(), req)) + + assert.Equal(t, "Bearer no-token-type", req.Header.Get("Authorization")) +} + +func TestOAuth2TokenEndpointProvider_BearerTokenTypeCaseInsensitive(t *testing.T) { + cases := []string{"Bearer", "bearer", "BEARER", "BeArEr"} + + for _, tt := range cases { + t.Run(tt, func(t *testing.T) { + srv, _ := tokenEndpointServer(t, "tok-case", tt) + defer srv.Close() + + cfg := &OAuth2TokenEndpointConfig{ + TokenURL: srv.URL, + ClientID: "client", + ClientSecret: "secret", + } + + provider, err := NewOAuth2TokenEndpointProvider(cfg, nil, srv.Client()) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "http://target.example.com", nil) + require.NoError(t, err) + + require.NoError(t, provider.Apply(context.Background(), req)) + + // Always emits canonical "Bearer", regardless of upstream casing. + assert.Equal(t, "Bearer tok-case", req.Header.Get("Authorization")) + }) + } +} + +func TestOAuth2TokenEndpointProvider_RejectsNonBearerTokenType(t *testing.T) { + maliciousTokenTypes := []string{"Basic", "MAC", "DPoP"} + + for _, tt := range maliciousTokenTypes { + t.Run(tt, func(t *testing.T) { + srv, _ := tokenEndpointServer(t, "secret-leaked-via-header", tt) + defer srv.Close() + + cfg := &OAuth2TokenEndpointConfig{ + TokenURL: srv.URL, + ClientID: "client", + ClientSecret: "secret", + } + + provider, err := NewOAuth2TokenEndpointProvider(cfg, nil, srv.Client()) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "http://target.example.com", nil) + require.NoError(t, err) + + err = provider.Apply(context.Background(), req) + require.ErrorIs(t, err, ErrOAuth2TokenEndpointUnsupportedTokenType) + + // The access token must NOT have leaked into any header on the + // downstream request — this is the entire point of the allowlist. + assert.Empty(t, req.Header.Get("Authorization")) + + for _, vals := range req.Header { + for _, v := range vals { + assert.False(t, strings.Contains(v, "secret-leaked-via-header"), + "access_token must not appear in any downstream header") + } + } + }) + } +} + +// TestOAuth2TokenEndpointProvider_SSRFBlocksPrivateNetwork verifies the +// suite-wide SetAllowPrivateForTest(true) (see main_test.go) is locally +// flipped off and that the OAuth2 token endpoint provider correctly delegates +// to safehttp's SSRF policy when the configured TokenURL points at a +// private/loopback address. The expected error wraps safehttp.ErrBlocked so +// callers can branch via errors.Is — same contract the OIDC provider obeys. +func TestOAuth2TokenEndpointProvider_SSRFBlocksPrivateNetwork(t *testing.T) { + previous := safehttp.SetAllowPrivateForTest(false) + t.Cleanup(func() { safehttp.SetAllowPrivateForTest(previous) }) + + cfg := &OAuth2TokenEndpointConfig{ + TokenURL: "http://127.0.0.1:9999/token", + ClientID: "client", + ClientSecret: "secret", + } + + provider, err := NewOAuth2TokenEndpointProvider(cfg, nil, nil) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "https://example.com/probe", nil) + require.NoError(t, err) + + err = provider.Apply(context.Background(), req) + require.Error(t, err, "Apply must reject when TokenURL resolves to a blocked range") + require.ErrorIs(t, err, safehttp.ErrBlocked, + "the SSRF rejection must wrap safehttp.ErrBlocked so callers can branch via errors.Is") + assert.Empty(t, req.Header.Get("Authorization"), + "no Authorization header must be set on the downstream request when the token fetch is blocked") +} + +// TestOAuth2TokenEndpointProvider_CacheKeyIsolation verifies that two +// providers configured against the same token endpoint with different +// ClientIDs do NOT share a cache entry. Each provider must hit the endpoint +// exactly once on first Apply — a regression here would cause provider A to +// authenticate downstream calls with provider B's token, a cross-tenant +// credential bleed. +func TestOAuth2TokenEndpointProvider_CacheKeyIsolation(t *testing.T) { + var hits int32 + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&hits, 1) + + if err := r.ParseForm(); err != nil { + http.Error(w, "invalid form", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "application/json") + // Echo the client_id back as the access_token so the test can verify + // each provider received its own credential. + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "tok-for-" + r.Form.Get("client_id"), + "token_type": "Bearer", + "expires_in": 3600, + }) + })) + defer srv.Close() + + cfgA := &OAuth2TokenEndpointConfig{ + TokenURL: srv.URL, + ClientID: "client-a", + ClientSecret: "secret", + } + cfgB := &OAuth2TokenEndpointConfig{ + TokenURL: srv.URL, + ClientID: "client-b", + ClientSecret: "secret", + } + + providerA, err := NewOAuth2TokenEndpointProvider(cfgA, nil, srv.Client()) + require.NoError(t, err) + providerB, err := NewOAuth2TokenEndpointProvider(cfgB, nil, srv.Client()) + require.NoError(t, err) + + require.NotEqual(t, providerA.cacheKey, providerB.cacheKey, + "cache keys must differ when ClientID differs — otherwise cross-credential bleed") + + reqA, _ := http.NewRequest(http.MethodGet, "http://target.example.com", nil) + require.NoError(t, providerA.Apply(context.Background(), reqA)) + assert.Equal(t, "Bearer tok-for-client-a", reqA.Header.Get("Authorization")) + + reqB, _ := http.NewRequest(http.MethodGet, "http://target.example.com", nil) + require.NoError(t, providerB.Apply(context.Background(), reqB)) + assert.Equal(t, "Bearer tok-for-client-b", reqB.Header.Get("Authorization")) + + assert.Equal(t, int32(2), atomic.LoadInt32(&hits), + "each provider must hit the endpoint once (separate cache entries)") +} + +// TestOAuth2TokenEndpointProvider_TokenEndpointReturns401 verifies that when +// the token endpoint rejects the credentials with HTTP 401, Apply surfaces an +// error and does NOT set an Authorization header on the downstream request. +func TestOAuth2TokenEndpointProvider_TokenEndpointReturns401(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, `{"error":"invalid_client"}`, http.StatusUnauthorized) + })) + defer srv.Close() + + cfg := &OAuth2TokenEndpointConfig{ + TokenURL: srv.URL, + ClientID: "client", + ClientSecret: "wrong", + } + + provider, err := NewOAuth2TokenEndpointProvider(cfg, nil, srv.Client()) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "http://target.example.com", nil) + require.NoError(t, err) + + err = provider.Apply(context.Background(), req) + require.Error(t, err, "Apply must error when token endpoint returns 401") + assert.Contains(t, err.Error(), "401", + "error must identify the upstream HTTP status so a future refactor that swallows the cause is caught") + assert.Empty(t, req.Header.Get("Authorization"), + "no Authorization header must be set when token fetch fails") +} + +// TestOAuth2TokenEndpointProvider_TokenEndpointReturnsMalformedJSON verifies +// that a token endpoint returning 200 OK with a body that is not valid JSON +// causes Apply to fail, with no Authorization header on the downstream +// request. This protects against silent credential omission when an +// upstream proxy returns plaintext error pages with a 200 status. +func TestOAuth2TokenEndpointProvider_TokenEndpointReturnsMalformedJSON(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte("not-json")) + })) + defer srv.Close() + + cfg := &OAuth2TokenEndpointConfig{ + TokenURL: srv.URL, + ClientID: "client", + ClientSecret: "secret", + } + + provider, err := NewOAuth2TokenEndpointProvider(cfg, nil, srv.Client()) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "http://target.example.com", nil) + require.NoError(t, err) + + err = provider.Apply(context.Background(), req) + require.Error(t, err, "Apply must error when token endpoint returns malformed JSON") + assert.Contains(t, err.Error(), "decode token response", + "error must mention the decode failure so the operator can diagnose it") + assert.Empty(t, req.Header.Get("Authorization"), + "no Authorization header must be set when token decode fails") +} + +// TestOAuth2TokenEndpointProvider_RejectsEmptyAccessToken verifies the empty +// access_token guard: a token endpoint returning 200 OK with an empty +// access_token would otherwise produce the literal header +// "Authorization: Bearer " on the downstream request — a silent +// authentication failure that most servers misreport as 401 or accept as +// unauthenticated. The provider must fail fast in this case. +func TestOAuth2TokenEndpointProvider_RejectsEmptyAccessToken(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "", + "token_type": "Bearer", + "expires_in": 3600, + }) + })) + defer srv.Close() + + cfg := &OAuth2TokenEndpointConfig{ + TokenURL: srv.URL, + ClientID: "client", + ClientSecret: "secret", + } + + provider, err := NewOAuth2TokenEndpointProvider(cfg, nil, srv.Client()) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "http://target.example.com", nil) + require.NoError(t, err) + + err = provider.Apply(context.Background(), req) + require.ErrorIs(t, err, ErrOAuth2TokenEndpointEmptyAccessToken, + "empty access_token must trigger the sentinel; otherwise the downstream "+ + "request would carry the literal 'Authorization: Bearer ' header") + assert.Empty(t, req.Header.Get("Authorization"), + "the empty-token guard must prevent setting a malformed Authorization header") +} + +// TestOAuth2TokenEndpointProvider_NoGrantTypeByDefault locks in the +// Delorean-style contract: when extra_params does NOT set grant_type, the +// emitted token request must omit the grant_type field entirely. A +// regression here would silently break custom OAuth2-like endpoints that +// reject unexpected fields (Delorean is the canonical example). +func TestOAuth2TokenEndpointProvider_NoGrantTypeByDefault(t *testing.T) { + srv, captured := tokenEndpointServer(t, "tok-no-grant", "Bearer") + defer srv.Close() + + cfg := &OAuth2TokenEndpointConfig{ + TokenURL: srv.URL, + ClientID: "client", + ClientSecret: "secret", + ExtraParams: map[string]string{"tenant": "acme"}, // intentionally NO grant_type + } + + provider, err := NewOAuth2TokenEndpointProvider(cfg, nil, srv.Client()) + require.NoError(t, err) + + req, err := http.NewRequest(http.MethodGet, "http://target.example.com", nil) + require.NoError(t, err) + + require.NoError(t, provider.Apply(context.Background(), req)) + + got := captured.Load() + require.NotNil(t, got) + assert.False(t, got.Form.Has("grant_type"), + "grant_type must not appear in the token request body when not supplied via extra_params") + assert.Equal(t, "acme", got.Form.Get("tenant"), + "other extra_params must still be forwarded") +} diff --git a/pkg/executors/http/auth/token_fetcher.go b/pkg/executors/http/auth/token_fetcher.go index b3686b4..9f04000 100644 --- a/pkg/executors/http/auth/token_fetcher.go +++ b/pkg/executors/http/auth/token_fetcher.go @@ -384,3 +384,70 @@ func (f *TokenFetcher) InvalidateCache(cacheKey string) { delete(f.cache, cacheKey) f.cacheMu.Unlock() } + +// FetchOAuth2TokenEndpointToken fetches a token from an explicit OAuth2 token +// endpoint without performing OIDC discovery. +func (f *TokenFetcher) FetchOAuth2TokenEndpointToken( + ctx context.Context, + cfg *OAuth2TokenEndpointConfig, + cacheKey string, + cacheConfig *CacheCfg, +) (*TokenResponse, error) { + if token := f.getCachedToken(cacheKey, cacheConfig); token != nil { + return token, nil + } + + token, err := f.fetchNewOAuth2TokenEndpointToken(ctx, cfg) + if err != nil { + return nil, err + } + + f.cacheToken(cacheKey, token, cacheConfig) + + return token, nil +} + +func (f *TokenFetcher) fetchNewOAuth2TokenEndpointToken(ctx context.Context, cfg *OAuth2TokenEndpointConfig) (*TokenResponse, error) { + data := f.buildOAuth2TokenEndpointData(cfg) + + req, err := newPinnedTokenRequest(ctx, cfg.TokenURL, data.Encode()) + if err != nil { + return nil, err + } + + if cfg.CredentialsLocation == "basic_header" { + auth := base64.StdEncoding.EncodeToString([]byte(cfg.ClientID + ":" + cfg.ClientSecret)) + req.Header.Set("Authorization", "Basic "+auth) + } + + return f.executeTokenRequest(req) +} + +func (f *TokenFetcher) buildOAuth2TokenEndpointData(cfg *OAuth2TokenEndpointConfig) url.Values { + data := url.Values{} + + // No grant_type by default — Delorean-style endpoints don't require it. + // Callers that need a specific grant_type set it via extra_params. + + // Empty credentials_location is treated as "body" (the default) — the + // constructor deliberately does not mutate the caller's *cfg, so the + // empty-string case must be handled here. + if cfg.CredentialsLocation == "" || cfg.CredentialsLocation == "body" { + data.Set("client_id", cfg.ClientID) + data.Set("client_secret", cfg.ClientSecret) + } + + if len(cfg.Scopes) > 0 { + data.Set("scope", strings.Join(cfg.Scopes, " ")) + } + + if cfg.Audience != "" { + data.Set("audience", cfg.Audience) + } + + for k, v := range cfg.ExtraParams { + data.Set(k, v) + } + + return data +} diff --git a/pkg/executors/http/auth/types.go b/pkg/executors/http/auth/types.go index 68f5e1f..b4a663c 100644 --- a/pkg/executors/http/auth/types.go +++ b/pkg/executors/http/auth/types.go @@ -33,6 +33,11 @@ const ( // TypeOIDCUser represents OIDC resource owner password credentials flow. TypeOIDCUser Type = "oidc_user" + + // TypeOAuth2TokenEndpoint represents OAuth2-style client credentials with an explicit + // token endpoint URL (skips OIDC discovery). For OAuth2-like providers that don't + // publish .well-known/openid-configuration metadata. + TypeOAuth2TokenEndpoint Type = "oauth2_token_endpoint" // #nosec G101 -- auth flow identifier, not a credential value ) // Config represents the authentication configuration. @@ -92,6 +97,23 @@ type OIDCUserConfig struct { ExtraParams map[string]string `json:"extra_params"` } +// OAuth2TokenEndpointConfig represents OAuth2 client credentials with a custom +// token endpoint URL (no OIDC discovery). Credentials can be sent in the form +// body or as a Basic Authorization header. +// +// Validation runs in NewOAuth2TokenEndpointProvider via explicit field checks; +// struct tags are intentionally omitted to avoid signaling validation that +// never executes (the package does not invoke validator.Struct on this config). +type OAuth2TokenEndpointConfig struct { + TokenURL string `json:"token_url"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + CredentialsLocation string `json:"credentials_location"` + Scopes []string `json:"scopes"` + Audience string `json:"audience"` + ExtraParams map[string]string `json:"extra_params"` +} + // Provider defines the interface for authentication providers. type Provider interface { // Apply applies authentication to the HTTP request. diff --git a/pkg/executors/http/http.go b/pkg/executors/http/http.go index 9fbec37..b53272f 100644 --- a/pkg/executors/http/http.go +++ b/pkg/executors/http/http.go @@ -178,7 +178,7 @@ const schema = `{ "properties": { "type": { "type": "string", - "enum": ["none", "api_key", "bearer", "basic", "oidc_client_credentials", "oidc_user"], + "enum": ["none", "api_key", "bearer", "basic", "oidc_client_credentials", "oidc_user", "oauth2_token_endpoint"], "default": "none", "description": "Authentication type" }, @@ -307,6 +307,28 @@ const schema = `{ } } } + }, + { + "if": { + "properties": { "type": { "const": "oauth2_token_endpoint" } } + }, + "then": { + "properties": { + "config": { + "type": "object", + "required": ["token_url", "client_id", "client_secret"], + "properties": { + "token_url": { "type": "string", "format": "uri", "description": "OAuth2 token endpoint URL (no OIDC discovery is performed)" }, + "client_id": { "type": "string", "minLength": 1, "description": "OAuth2 Client ID" }, + "client_secret": { "type": "string", "minLength": 1, "description": "OAuth2 Client Secret" }, + "credentials_location": { "type": "string", "enum": ["body", "basic_header"], "default": "body", "description": "Where to send client_id/client_secret" }, + "scopes": { "type": "array", "items": { "type": "string" }, "description": "OAuth2 scopes to request" }, + "audience": { "type": "string", "description": "Optional audience parameter" }, + "extra_params": { "type": "object", "additionalProperties": { "type": "string" }, "description": "Extra form parameters for token request (must not include client_id, client_secret, scope, audience)" } + } + } + } + } } ] }, diff --git a/pkg/model/executor_configuration.go b/pkg/model/executor_configuration.go index 1646fb3..40f1825 100644 --- a/pkg/model/executor_configuration.go +++ b/pkg/model/executor_configuration.go @@ -111,7 +111,7 @@ var ( } ErrExecutorConfigAuthTypeInvalid = pkg.ValidationError{ Code: constant.ErrExecutorConfigAuthTypeInvalid.Error(), - Message: "authentication type must be one of: none, api_key, bearer, basic, oidc_client_credentials, oidc_user", + Message: "authentication type must be one of: none, api_key, bearer, basic, oidc_client_credentials, oidc_user, oauth2_token_endpoint", } ) @@ -201,6 +201,7 @@ func NewExecutorAuthentication(authType string, config map[string]any) (*Executo validTypes := map[string]bool{ "none": true, "api_key": true, "bearer": true, "basic": true, "oidc_client_credentials": true, "oidc_user": true, + "oauth2_token_endpoint": true, } if !validTypes[authType] { return nil, ErrExecutorConfigAuthTypeInvalid diff --git a/pkg/model/executor_configuration_input.go b/pkg/model/executor_configuration_input.go index cacb270..1e2bcc4 100644 --- a/pkg/model/executor_configuration_input.go +++ b/pkg/model/executor_configuration_input.go @@ -37,7 +37,7 @@ type ExecutorEndpointInput struct { // ExecutorAuthenticationInput is the input DTO for executor authentication configuration. type ExecutorAuthenticationInput struct { - Type string `json:"type" validate:"required,oneof=none api_key bearer basic oidc_client_credentials oidc_user"` + Type string `json:"type" validate:"required,oneof=none api_key bearer basic oidc_client_credentials oidc_user oauth2_token_endpoint"` Config map[string]any `json:"config,omitempty"` } diff --git a/pkg/model/executor_configuration_test.go b/pkg/model/executor_configuration_test.go index 38f75bd..5c7f5c1 100644 --- a/pkg/model/executor_configuration_test.go +++ b/pkg/model/executor_configuration_test.go @@ -550,7 +550,11 @@ func TestNewExecutorAuthentication_InvalidType(t *testing.T) { } func TestNewExecutorAuthentication_ValidTypes(t *testing.T) { - validTypes := []string{"none", "api_key", "bearer", "basic", "oidc_client_credentials", "oidc_user"} + validTypes := []string{ + "none", "api_key", "bearer", "basic", + "oidc_client_credentials", "oidc_user", + "oauth2_token_endpoint", + } for _, authType := range validTypes { auth, err := model.NewExecutorAuthentication(authType, nil) @@ -559,6 +563,20 @@ func TestNewExecutorAuthentication_ValidTypes(t *testing.T) { } } +func TestNewExecutorAuthenticationAcceptsOAuth2TokenEndpoint(t *testing.T) { + authConfig := map[string]any{ + "token_url": "https://idp.example.com/token", + "client_id": "my-client", + "client_secret": "my-secret", + } + + auth, err := model.NewExecutorAuthentication("oauth2_token_endpoint", authConfig) + + require.NoError(t, err) + assert.Equal(t, "oauth2_token_endpoint", auth.Type()) + assert.Equal(t, authConfig, auth.Config()) +} + func TestNewExecutorAuthentication_NilConfig(t *testing.T) { auth, err := model.NewExecutorAuthentication("none", nil) diff --git a/tests/integration/http_provider_auth_test.go b/tests/integration/http_provider_auth_test.go index b26b1aa..28f07c3 100644 --- a/tests/integration/http_provider_auth_test.go +++ b/tests/integration/http_provider_auth_test.go @@ -632,3 +632,231 @@ func TestHTTPProviderAuthMissingConfig(t *testing.T) { assert.Equal(t, executor.ExecutionStatusError, result.Status) assert.Contains(t, result.Error, "key is required") } + +// mockTokenEndpointServer creates a mock OAuth2 token endpoint server WITHOUT +// OIDC discovery. Only POST /token is served — the absence of a +// .well-known/openid-configuration handler proves the oauth2_token_endpoint +// provider does not perform discovery. +func mockTokenEndpointServer(t *testing.T) *httptest.Server { + t.Helper() + + mux := http.NewServeMux() + + mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) { + // Fail loudly if discovery is attempted — proves the provider does + // not call this endpoint. + t.Errorf("OIDC discovery must NOT be called for oauth2_token_endpoint") + w.WriteHeader(http.StatusInternalServerError) + }) + + mux.HandleFunc("/token", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + if err := r.ParseForm(); err != nil { + http.Error(w, "invalid form", http.StatusBadRequest) + return + } + + var clientID, clientSecret string + + authHeader := r.Header.Get("Authorization") + if strings.HasPrefix(authHeader, "Basic ") { + decoded, err := base64.StdEncoding.DecodeString(strings.TrimPrefix(authHeader, "Basic ")) + if err == nil { + parts := strings.SplitN(string(decoded), ":", 2) + if len(parts) == 2 { + clientID = parts[0] + clientSecret = parts[1] + } + } + } + + if clientID == "" { + clientID = r.Form.Get("client_id") + clientSecret = r.Form.Get("client_secret") + } + + if clientID != "test-client" || clientSecret != "test-secret" { + w.WriteHeader(http.StatusUnauthorized) + writeJSON(t, w, map[string]string{ + "error": "invalid_client", + "error_description": "Invalid client credentials", + }) + + return + } + + // Mirror the tenant extra_param back in scope so tests can assert it + // reached the token endpoint. + tenant := r.Form.Get("tenant") + + token := map[string]any{ + "access_token": "mock-access-token-client-credentials", + "token_type": "Bearer", + "expires_in": 3600, + "scope": r.Form.Get("scope"), + "tenant": tenant, + } + + w.Header().Set("Content-Type", "application/json") + writeJSON(t, w, token) + }) + + return httptest.NewServer(mux) +} + +func TestHTTPProviderAuthOAuth2TokenEndpoint(t *testing.T) { + tokenSrv := mockTokenEndpointServer(t) + defer tokenSrv.Close() + + target := mockTargetServer(t) + defer target.Close() + + runner := httpExecutor.NewRunner() + + input := executor.ExecutionInput{ + Config: map[string]any{ + "method": "GET", + "url": target.URL + "/protected", + "auth": map[string]any{ + "type": "oauth2_token_endpoint", + "config": map[string]any{ + "token_url": tokenSrv.URL + "/token", + "client_id": "test-client", + "client_secret": "test-secret", + }, + }, + }, + HTTPClient: &http.Client{Timeout: 10 * time.Second}, + } + + result, err := runner.Execute(context.Background(), input) + require.NoError(t, err) + assert.Equal(t, executor.ExecutionStatusSuccess, result.Status) + assert.Equal(t, 200, result.Data["status"]) + + body := result.Data["body"].(map[string]any) + assert.Equal(t, "oidc auth success", body["message"]) +} + +func TestHTTPProviderAuthOAuth2TokenEndpointBasicHeader(t *testing.T) { + tokenSrv := mockTokenEndpointServer(t) + defer tokenSrv.Close() + + target := mockTargetServer(t) + defer target.Close() + + runner := httpExecutor.NewRunner() + + input := executor.ExecutionInput{ + Config: map[string]any{ + "method": "GET", + "url": target.URL + "/protected", + "auth": map[string]any{ + "type": "oauth2_token_endpoint", + "config": map[string]any{ + "token_url": tokenSrv.URL + "/token", + "client_id": "test-client", + "client_secret": "test-secret", + "credentials_location": "basic_header", + }, + }, + }, + HTTPClient: &http.Client{Timeout: 10 * time.Second}, + } + + result, err := runner.Execute(context.Background(), input) + require.NoError(t, err) + assert.Equal(t, executor.ExecutionStatusSuccess, result.Status) + assert.Equal(t, 200, result.Data["status"]) +} + +func TestHTTPProviderAuthOAuth2TokenEndpointWithExtraParams(t *testing.T) { + var capturedTenant string + + tokenSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "invalid form", http.StatusBadRequest) + return + } + + capturedTenant = r.Form.Get("tenant") + + clientID := r.Form.Get("client_id") + clientSecret := r.Form.Get("client_secret") + + if clientID != "test-client" || clientSecret != "test-secret" { + w.WriteHeader(http.StatusUnauthorized) + return + } + + w.Header().Set("Content-Type", "application/json") + writeJSON(t, w, map[string]any{ + "access_token": "mock-access-token-client-credentials", + "token_type": "Bearer", + "expires_in": 3600, + }) + })) + defer tokenSrv.Close() + + target := mockTargetServer(t) + defer target.Close() + + runner := httpExecutor.NewRunner() + + input := executor.ExecutionInput{ + Config: map[string]any{ + "method": "GET", + "url": target.URL + "/protected", + "auth": map[string]any{ + "type": "oauth2_token_endpoint", + "config": map[string]any{ + "token_url": tokenSrv.URL + "/token", + "client_id": "test-client", + "client_secret": "test-secret", + "extra_params": map[string]any{"tenant": "acme"}, + }, + }, + }, + HTTPClient: &http.Client{Timeout: 10 * time.Second}, + } + + result, err := runner.Execute(context.Background(), input) + require.NoError(t, err) + assert.Equal(t, executor.ExecutionStatusSuccess, result.Status) + assert.Equal(t, "acme", capturedTenant, "extra_params must reach the token endpoint") +} + +func TestHTTPProviderAuthOAuth2TokenEndpointInvalidCredentials(t *testing.T) { + tokenSrv := mockTokenEndpointServer(t) + defer tokenSrv.Close() + + target := mockTargetServer(t) + defer target.Close() + + runner := httpExecutor.NewRunner() + + input := executor.ExecutionInput{ + Config: map[string]any{ + "method": "GET", + "url": target.URL + "/protected", + "auth": map[string]any{ + "type": "oauth2_token_endpoint", + "config": map[string]any{ + "token_url": tokenSrv.URL + "/token", + "client_id": "wrong-client", + "client_secret": "wrong-secret", + }, + }, + }, + HTTPClient: &http.Client{Timeout: 10 * time.Second}, + } + + result, err := runner.Execute(context.Background(), input) + require.NoError(t, err) + assert.Equal(t, executor.ExecutionStatusError, result.Status) + assert.Contains(t, result.Error, "failed to apply authentication") +}