diff --git a/internal/api/oauthserver/client_auth.go b/internal/api/oauthserver/client_auth.go index 6ed505e84..cb27ccf88 100644 --- a/internal/api/oauthserver/client_auth.go +++ b/internal/api/oauthserver/client_auth.go @@ -108,3 +108,24 @@ func GetAllValidAuthMethods() []string { models.TokenEndpointAuthMethodClientSecretPost, } } + +// DetermineTokenEndpointAuthMethod determines the final token endpoint auth method using: +// 1. Explicit token_endpoint_auth_method if provided +// 2. Default based on client type (none for public, client_secret_basic for confidential) +func DetermineTokenEndpointAuthMethod(explicitAuthMethod, clientType string) string { + // Priority 1: Explicit token_endpoint_auth_method + if explicitAuthMethod != "" { + return explicitAuthMethod + } + + // Priority 2: Default based on client type + switch clientType { + case models.OAuthServerClientTypePublic: + return models.TokenEndpointAuthMethodNone + case models.OAuthServerClientTypeConfidential: + return models.TokenEndpointAuthMethodClientSecretBasic + default: + // Default to client_secret_basic for unknown/empty client type + return models.TokenEndpointAuthMethodClientSecretBasic + } +} diff --git a/internal/api/oauthserver/client_auth_test.go b/internal/api/oauthserver/client_auth_test.go index 40fec9c40..dda1fd2ae 100644 --- a/internal/api/oauthserver/client_auth_test.go +++ b/internal/api/oauthserver/client_auth_test.go @@ -395,6 +395,67 @@ func TestGetAllValidAuthMethods(t *testing.T) { } } +func TestDetermineTokenEndpointAuthMethod(t *testing.T) { + tests := []struct { + name string + explicitAuthMethod string + clientType string + expected string + }{ + { + name: "explicit none overrides client type", + explicitAuthMethod: models.TokenEndpointAuthMethodNone, + clientType: models.OAuthServerClientTypeConfidential, + expected: models.TokenEndpointAuthMethodNone, + }, + { + name: "explicit basic overrides client type", + explicitAuthMethod: models.TokenEndpointAuthMethodClientSecretBasic, + clientType: models.OAuthServerClientTypePublic, + expected: models.TokenEndpointAuthMethodClientSecretBasic, + }, + { + name: "explicit post overrides client type", + explicitAuthMethod: models.TokenEndpointAuthMethodClientSecretPost, + clientType: models.OAuthServerClientTypePublic, + expected: models.TokenEndpointAuthMethodClientSecretPost, + }, + { + name: "default none for public client", + explicitAuthMethod: "", + clientType: models.OAuthServerClientTypePublic, + expected: models.TokenEndpointAuthMethodNone, + }, + { + name: "default basic for confidential client", + explicitAuthMethod: "", + clientType: models.OAuthServerClientTypeConfidential, + expected: models.TokenEndpointAuthMethodClientSecretBasic, + }, + { + name: "default basic for empty client type", + explicitAuthMethod: "", + clientType: "", + expected: models.TokenEndpointAuthMethodClientSecretBasic, + }, + { + name: "default basic for unknown client type", + explicitAuthMethod: "", + clientType: "unknown_type", + expected: models.TokenEndpointAuthMethodClientSecretBasic, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := DetermineTokenEndpointAuthMethod(tt.explicitAuthMethod, tt.clientType) + if result != tt.expected { + t.Errorf("DetermineTokenEndpointAuthMethod() = %v, expected %v", result, tt.expected) + } + }) + } +} + // Helper function to check if a string contains a substring func containsString(s, substr string) bool { return len(s) >= len(substr) && (s == substr || (len(s) > len(substr) && diff --git a/internal/api/oauthserver/handlers.go b/internal/api/oauthserver/handlers.go index c61de4085..403e4c65e 100644 --- a/internal/api/oauthserver/handlers.go +++ b/internal/api/oauthserver/handlers.go @@ -51,24 +51,13 @@ type OAuthServerClientListResponse struct { // oauthServerClientToResponse converts a model to response format func oauthServerClientToResponse(client *models.OAuthServerClient) *OAuthServerClientResponse { - // Set token endpoint auth methods based on client type - var tokenEndpointAuthMethods string - // TODO(cemal) :: Remove this once we have the token endpoint auth method stored in the database - if client.IsPublic() { - // Public clients don't use client authentication - tokenEndpointAuthMethods = models.TokenEndpointAuthMethodNone - } else { - // Confidential clients use client secret authentication - tokenEndpointAuthMethods = models.TokenEndpointAuthMethodClientSecretBasic - } - response := &OAuthServerClientResponse{ ClientID: client.ID.String(), ClientType: client.ClientType, // OAuth 2.1 DCR fields RedirectURIs: client.GetRedirectURIs(), - TokenEndpointAuthMethod: tokenEndpointAuthMethods, + TokenEndpointAuthMethod: client.TokenEndpointAuthMethod, GrantTypes: client.GetGrantTypes(), ResponseTypes: []string{"code"}, // Always "code" in OAuth 2.1 ClientName: utilities.StringValue(client.ClientName), diff --git a/internal/api/oauthserver/service.go b/internal/api/oauthserver/service.go index dff27c774..503ddde01 100644 --- a/internal/api/oauthserver/service.go +++ b/internal/api/oauthserver/service.go @@ -263,15 +263,19 @@ func (s *Server) registerOAuthServerClient(ctx context.Context, params *OAuthSer // Determine client type using centralized logic clientType := DetermineClientType(params.ClientType, params.TokenEndpointAuthMethod) + // Determine token endpoint auth method based on explicit value or client type + tokenEndpointAuthMethod := DetermineTokenEndpointAuthMethod(params.TokenEndpointAuthMethod, clientType) + db := s.db.WithContext(ctx) client := &models.OAuthServerClient{ - ID: uuid.Must(uuid.NewV4()), - RegistrationType: params.RegistrationType, - ClientType: clientType, - ClientName: utilities.StringPtr(params.ClientName), - ClientURI: utilities.StringPtr(params.ClientURI), - LogoURI: utilities.StringPtr(params.LogoURI), + ID: uuid.Must(uuid.NewV4()), + RegistrationType: params.RegistrationType, + ClientType: clientType, + TokenEndpointAuthMethod: tokenEndpointAuthMethod, + ClientName: utilities.StringPtr(params.ClientName), + ClientURI: utilities.StringPtr(params.ClientURI), + LogoURI: utilities.StringPtr(params.LogoURI), } client.SetRedirectURIs(params.RedirectURIs) diff --git a/internal/models/oauth_client.go b/internal/models/oauth_client.go index c53a7ee62..0a22dd1e4 100644 --- a/internal/models/oauth_client.go +++ b/internal/models/oauth_client.go @@ -33,14 +33,15 @@ type OAuthServerClient struct { RegistrationType string `json:"registration_type" db:"registration_type"` ClientType string `json:"client_type" db:"client_type"` - RedirectURIs string `json:"-" db:"redirect_uris"` - GrantTypes string `json:"grant_types" db:"grant_types"` - ClientName *string `json:"client_name,omitempty" db:"client_name"` - ClientURI *string `json:"client_uri,omitempty" db:"client_uri"` - LogoURI *string `json:"logo_uri,omitempty" db:"logo_uri"` - CreatedAt time.Time `json:"created_at" db:"created_at"` - UpdatedAt time.Time `json:"updated_at" db:"updated_at"` - DeletedAt *time.Time `json:"deleted_at,omitempty" db:"deleted_at"` + RedirectURIs string `json:"-" db:"redirect_uris"` + GrantTypes string `json:"grant_types" db:"grant_types"` + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method" db:"token_endpoint_auth_method"` + ClientName *string `json:"client_name,omitempty" db:"client_name"` + ClientURI *string `json:"client_uri,omitempty" db:"client_uri"` + LogoURI *string `json:"logo_uri,omitempty" db:"logo_uri"` + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` + DeletedAt *time.Time `json:"deleted_at,omitempty" db:"deleted_at"` } // TableName returns the table name for the OAuthServerClient model diff --git a/internal/models/oauth_client_test.go b/internal/models/oauth_client_test.go index 6cee43611..3388b1834 100644 --- a/internal/models/oauth_client_test.go +++ b/internal/models/oauth_client_test.go @@ -49,13 +49,14 @@ func (ts *OAuthServerClientTestSuite) TestOAuthServerClientValidation() { testClientName := "Test Client" testSecretHash, _ := testHashClientSecret("test_secret") validClient := &OAuthServerClient{ - ID: uuid.Must(uuid.NewV4()), - ClientName: &testClientName, - RegistrationType: "dynamic", - ClientType: OAuthServerClientTypeConfidential, - ClientSecretHash: testSecretHash, - RedirectURIs: "https://example.com/callback", - GrantTypes: "authorization_code,refresh_token", + ID: uuid.Must(uuid.NewV4()), + ClientName: &testClientName, + RegistrationType: "dynamic", + ClientType: OAuthServerClientTypeConfidential, + ClientSecretHash: testSecretHash, + RedirectURIs: "https://example.com/callback", + GrantTypes: "authorization_code,refresh_token", + TokenEndpointAuthMethod: TokenEndpointAuthMethodClientSecretBasic, } // Test valid client @@ -147,13 +148,14 @@ func (ts *OAuthServerClientTestSuite) TestCreateOAuthServerClient() { testAppName := "Test Application" testSecretHash, _ := testHashClientSecret("test_secret") client := &OAuthServerClient{ - ID: uuid.Must(uuid.NewV4()), - ClientName: &testAppName, - GrantTypes: "authorization_code,refresh_token", - RegistrationType: "dynamic", - ClientType: OAuthServerClientTypeConfidential, - ClientSecretHash: testSecretHash, - RedirectURIs: "https://example.com/callback", + ID: uuid.Must(uuid.NewV4()), + ClientName: &testAppName, + GrantTypes: "authorization_code,refresh_token", + RegistrationType: "dynamic", + ClientType: OAuthServerClientTypeConfidential, + ClientSecretHash: testSecretHash, + RedirectURIs: "https://example.com/callback", + TokenEndpointAuthMethod: TokenEndpointAuthMethodClientSecretBasic, } err := CreateOAuthServerClient(ts.db, client) @@ -181,13 +183,14 @@ func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByID() { testName := "Find By ID Test" testSecretHash, _ := testHashClientSecret("test_secret") client := &OAuthServerClient{ - ID: uuid.Must(uuid.NewV4()), - ClientName: &testName, - GrantTypes: "authorization_code,refresh_token", - RegistrationType: "dynamic", - ClientType: OAuthServerClientTypeConfidential, - ClientSecretHash: testSecretHash, - RedirectURIs: "https://example.com/callback", + ID: uuid.Must(uuid.NewV4()), + ClientName: &testName, + GrantTypes: "authorization_code,refresh_token", + RegistrationType: "dynamic", + ClientType: OAuthServerClientTypeConfidential, + ClientSecretHash: testSecretHash, + RedirectURIs: "https://example.com/callback", + TokenEndpointAuthMethod: TokenEndpointAuthMethodClientSecretBasic, } err := CreateOAuthServerClient(ts.db, client) @@ -210,13 +213,14 @@ func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByClientID() { testName := "Find By Client ID Test" testSecretHash, _ := testHashClientSecret("test_secret") client := &OAuthServerClient{ - ID: uuid.Must(uuid.NewV4()), - ClientName: &testName, - GrantTypes: "authorization_code,refresh_token", - RegistrationType: "manual", - ClientType: OAuthServerClientTypeConfidential, - ClientSecretHash: testSecretHash, - RedirectURIs: "https://example.com/callback", + ID: uuid.Must(uuid.NewV4()), + ClientName: &testName, + GrantTypes: "authorization_code,refresh_token", + RegistrationType: "manual", + ClientType: OAuthServerClientTypeConfidential, + ClientSecretHash: testSecretHash, + RedirectURIs: "https://example.com/callback", + TokenEndpointAuthMethod: TokenEndpointAuthMethodClientSecretBasic, } err := CreateOAuthServerClient(ts.db, client) @@ -239,13 +243,14 @@ func (ts *OAuthServerClientTestSuite) TestUpdateOAuthServerClient() { originalName := "Original Name" testSecretHash, _ := testHashClientSecret("test_secret") client := &OAuthServerClient{ - ID: uuid.Must(uuid.NewV4()), - ClientName: &originalName, - GrantTypes: "authorization_code,refresh_token", - RegistrationType: "dynamic", - ClientType: OAuthServerClientTypeConfidential, - ClientSecretHash: testSecretHash, - RedirectURIs: "https://example.com/callback", + ID: uuid.Must(uuid.NewV4()), + ClientName: &originalName, + GrantTypes: "authorization_code,refresh_token", + RegistrationType: "dynamic", + ClientType: OAuthServerClientTypeConfidential, + ClientSecretHash: testSecretHash, + RedirectURIs: "https://example.com/callback", + TokenEndpointAuthMethod: TokenEndpointAuthMethodClientSecretBasic, } err := CreateOAuthServerClient(ts.db, client) @@ -293,13 +298,14 @@ func (ts *OAuthServerClientTestSuite) TestSoftDelete() { testName := "Soft Delete Test" testSecretHash, _ := testHashClientSecret("test_secret") client := &OAuthServerClient{ - ID: uuid.Must(uuid.NewV4()), - ClientName: &testName, - GrantTypes: "authorization_code,refresh_token", - RegistrationType: "dynamic", - ClientType: OAuthServerClientTypeConfidential, - ClientSecretHash: testSecretHash, - RedirectURIs: "https://example.com/callback", + ID: uuid.Must(uuid.NewV4()), + ClientName: &testName, + GrantTypes: "authorization_code,refresh_token", + RegistrationType: "dynamic", + ClientType: OAuthServerClientTypeConfidential, + ClientSecretHash: testSecretHash, + RedirectURIs: "https://example.com/callback", + TokenEndpointAuthMethod: TokenEndpointAuthMethodClientSecretBasic, } err := CreateOAuthServerClient(ts.db, client) diff --git a/migrations/20260102000000_add_oauth_token_endpoint_auth_method.up.sql b/migrations/20260102000000_add_oauth_token_endpoint_auth_method.up.sql new file mode 100644 index 000000000..78b7b6442 --- /dev/null +++ b/migrations/20260102000000_add_oauth_token_endpoint_auth_method.up.sql @@ -0,0 +1,29 @@ +-- Add token_endpoint_auth_method column to oauth_clients table +-- This stores the authentication method used at the token endpoint: +-- - 'none' for public clients (PKCE only) +-- - 'client_secret_basic' for confidential clients (HTTP Basic Auth) +-- - 'client_secret_post' for confidential clients (POST body) + +-- Create enum for token endpoint auth method +do $$ begin + create type {{ index .Options "Namespace" }}.oauth_token_endpoint_auth_method as enum('none', 'client_secret_basic', 'client_secret_post'); +exception + when duplicate_object then null; +end $$; + +-- Add token_endpoint_auth_method column with a default value based on client_type +-- This uses a CASE expression to set the correct default based on existing client_type +alter table {{ index .Options "Namespace" }}.oauth_clients + add column if not exists token_endpoint_auth_method {{ index .Options "Namespace" }}.oauth_token_endpoint_auth_method; + +-- Update existing rows to have the correct token_endpoint_auth_method based on client_type +update {{ index .Options "Namespace" }}.oauth_clients +set token_endpoint_auth_method = case + when client_type = 'public' then 'none'::{{ index .Options "Namespace" }}.oauth_token_endpoint_auth_method + else 'client_secret_basic'::{{ index .Options "Namespace" }}.oauth_token_endpoint_auth_method +end +where token_endpoint_auth_method is null; + +-- Now make the column NOT NULL since all rows have values +alter table {{ index .Options "Namespace" }}.oauth_clients + alter column token_endpoint_auth_method set not null;