Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions internal/api/oauthserver/client_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,24 @@ func GetAllValidAuthMethods() []string {
models.TokenEndpointAuthMethodClientSecretPost,
}
}

// DetermineTokenEndpointAuthMethod determines the final token endpoint auth method using:
// 1. Explicit token_endpoint_auth_method if provided
// 2. Default based on client type (none for public, client_secret_basic for confidential)
func DetermineTokenEndpointAuthMethod(explicitAuthMethod, clientType string) string {
// Priority 1: Explicit token_endpoint_auth_method
if explicitAuthMethod != "" {
return explicitAuthMethod
}

// Priority 2: Default based on client type
switch clientType {
case models.OAuthServerClientTypePublic:
return models.TokenEndpointAuthMethodNone
case models.OAuthServerClientTypeConfidential:
return models.TokenEndpointAuthMethodClientSecretBasic
default:
// Default to client_secret_basic for unknown/empty client type
return models.TokenEndpointAuthMethodClientSecretBasic
}
}
61 changes: 61 additions & 0 deletions internal/api/oauthserver/client_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,67 @@ func TestGetAllValidAuthMethods(t *testing.T) {
}
}

func TestDetermineTokenEndpointAuthMethod(t *testing.T) {
tests := []struct {
name string
explicitAuthMethod string
clientType string
expected string
}{
{
name: "explicit none overrides client type",
explicitAuthMethod: models.TokenEndpointAuthMethodNone,
clientType: models.OAuthServerClientTypeConfidential,
expected: models.TokenEndpointAuthMethodNone,
},
{
name: "explicit basic overrides client type",
explicitAuthMethod: models.TokenEndpointAuthMethodClientSecretBasic,
clientType: models.OAuthServerClientTypePublic,
expected: models.TokenEndpointAuthMethodClientSecretBasic,
},
{
name: "explicit post overrides client type",
explicitAuthMethod: models.TokenEndpointAuthMethodClientSecretPost,
clientType: models.OAuthServerClientTypePublic,
expected: models.TokenEndpointAuthMethodClientSecretPost,
},
{
name: "default none for public client",
explicitAuthMethod: "",
clientType: models.OAuthServerClientTypePublic,
expected: models.TokenEndpointAuthMethodNone,
},
{
name: "default basic for confidential client",
explicitAuthMethod: "",
clientType: models.OAuthServerClientTypeConfidential,
expected: models.TokenEndpointAuthMethodClientSecretBasic,
},
{
name: "default basic for empty client type",
explicitAuthMethod: "",
clientType: "",
expected: models.TokenEndpointAuthMethodClientSecretBasic,
},
{
name: "default basic for unknown client type",
explicitAuthMethod: "",
clientType: "unknown_type",
expected: models.TokenEndpointAuthMethodClientSecretBasic,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := DetermineTokenEndpointAuthMethod(tt.explicitAuthMethod, tt.clientType)
if result != tt.expected {
t.Errorf("DetermineTokenEndpointAuthMethod() = %v, expected %v", result, tt.expected)
}
})
}
}

// Helper function to check if a string contains a substring
func containsString(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || (len(s) > len(substr) &&
Expand Down
13 changes: 1 addition & 12 deletions internal/api/oauthserver/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,24 +51,13 @@ type OAuthServerClientListResponse struct {

// oauthServerClientToResponse converts a model to response format
func oauthServerClientToResponse(client *models.OAuthServerClient) *OAuthServerClientResponse {
// Set token endpoint auth methods based on client type
var tokenEndpointAuthMethods string
// TODO(cemal) :: Remove this once we have the token endpoint auth method stored in the database
if client.IsPublic() {
// Public clients don't use client authentication
tokenEndpointAuthMethods = models.TokenEndpointAuthMethodNone
} else {
// Confidential clients use client secret authentication
tokenEndpointAuthMethods = models.TokenEndpointAuthMethodClientSecretBasic
}

response := &OAuthServerClientResponse{
ClientID: client.ID.String(),
ClientType: client.ClientType,

// OAuth 2.1 DCR fields
RedirectURIs: client.GetRedirectURIs(),
TokenEndpointAuthMethod: tokenEndpointAuthMethods,
TokenEndpointAuthMethod: client.TokenEndpointAuthMethod,
GrantTypes: client.GetGrantTypes(),
ResponseTypes: []string{"code"}, // Always "code" in OAuth 2.1
ClientName: utilities.StringValue(client.ClientName),
Expand Down
16 changes: 10 additions & 6 deletions internal/api/oauthserver/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,15 +263,19 @@ func (s *Server) registerOAuthServerClient(ctx context.Context, params *OAuthSer
// Determine client type using centralized logic
clientType := DetermineClientType(params.ClientType, params.TokenEndpointAuthMethod)

// Determine token endpoint auth method based on explicit value or client type
tokenEndpointAuthMethod := DetermineTokenEndpointAuthMethod(params.TokenEndpointAuthMethod, clientType)

db := s.db.WithContext(ctx)

client := &models.OAuthServerClient{
ID: uuid.Must(uuid.NewV4()),
RegistrationType: params.RegistrationType,
ClientType: clientType,
ClientName: utilities.StringPtr(params.ClientName),
ClientURI: utilities.StringPtr(params.ClientURI),
LogoURI: utilities.StringPtr(params.LogoURI),
ID: uuid.Must(uuid.NewV4()),
RegistrationType: params.RegistrationType,
ClientType: clientType,
TokenEndpointAuthMethod: tokenEndpointAuthMethod,
ClientName: utilities.StringPtr(params.ClientName),
ClientURI: utilities.StringPtr(params.ClientURI),
LogoURI: utilities.StringPtr(params.LogoURI),
}

client.SetRedirectURIs(params.RedirectURIs)
Expand Down
17 changes: 9 additions & 8 deletions internal/models/oauth_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,15 @@ type OAuthServerClient struct {
RegistrationType string `json:"registration_type" db:"registration_type"`
ClientType string `json:"client_type" db:"client_type"`

RedirectURIs string `json:"-" db:"redirect_uris"`
GrantTypes string `json:"grant_types" db:"grant_types"`
ClientName *string `json:"client_name,omitempty" db:"client_name"`
ClientURI *string `json:"client_uri,omitempty" db:"client_uri"`
LogoURI *string `json:"logo_uri,omitempty" db:"logo_uri"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
DeletedAt *time.Time `json:"deleted_at,omitempty" db:"deleted_at"`
RedirectURIs string `json:"-" db:"redirect_uris"`
GrantTypes string `json:"grant_types" db:"grant_types"`
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method" db:"token_endpoint_auth_method"`
ClientName *string `json:"client_name,omitempty" db:"client_name"`
ClientURI *string `json:"client_uri,omitempty" db:"client_uri"`
LogoURI *string `json:"logo_uri,omitempty" db:"logo_uri"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
DeletedAt *time.Time `json:"deleted_at,omitempty" db:"deleted_at"`
}

// TableName returns the table name for the OAuthServerClient model
Expand Down
90 changes: 48 additions & 42 deletions internal/models/oauth_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,14 @@ func (ts *OAuthServerClientTestSuite) TestOAuthServerClientValidation() {
testClientName := "Test Client"
testSecretHash, _ := testHashClientSecret("test_secret")
validClient := &OAuthServerClient{
ID: uuid.Must(uuid.NewV4()),
ClientName: &testClientName,
RegistrationType: "dynamic",
ClientType: OAuthServerClientTypeConfidential,
ClientSecretHash: testSecretHash,
RedirectURIs: "https://example.com/callback",
GrantTypes: "authorization_code,refresh_token",
ID: uuid.Must(uuid.NewV4()),
ClientName: &testClientName,
RegistrationType: "dynamic",
ClientType: OAuthServerClientTypeConfidential,
ClientSecretHash: testSecretHash,
RedirectURIs: "https://example.com/callback",
GrantTypes: "authorization_code,refresh_token",
TokenEndpointAuthMethod: TokenEndpointAuthMethodClientSecretBasic,
}

// Test valid client
Expand Down Expand Up @@ -147,13 +148,14 @@ func (ts *OAuthServerClientTestSuite) TestCreateOAuthServerClient() {
testAppName := "Test Application"
testSecretHash, _ := testHashClientSecret("test_secret")
client := &OAuthServerClient{
ID: uuid.Must(uuid.NewV4()),
ClientName: &testAppName,
GrantTypes: "authorization_code,refresh_token",
RegistrationType: "dynamic",
ClientType: OAuthServerClientTypeConfidential,
ClientSecretHash: testSecretHash,
RedirectURIs: "https://example.com/callback",
ID: uuid.Must(uuid.NewV4()),
ClientName: &testAppName,
GrantTypes: "authorization_code,refresh_token",
RegistrationType: "dynamic",
ClientType: OAuthServerClientTypeConfidential,
ClientSecretHash: testSecretHash,
RedirectURIs: "https://example.com/callback",
TokenEndpointAuthMethod: TokenEndpointAuthMethodClientSecretBasic,
}

err := CreateOAuthServerClient(ts.db, client)
Expand Down Expand Up @@ -181,13 +183,14 @@ func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByID() {
testName := "Find By ID Test"
testSecretHash, _ := testHashClientSecret("test_secret")
client := &OAuthServerClient{
ID: uuid.Must(uuid.NewV4()),
ClientName: &testName,
GrantTypes: "authorization_code,refresh_token",
RegistrationType: "dynamic",
ClientType: OAuthServerClientTypeConfidential,
ClientSecretHash: testSecretHash,
RedirectURIs: "https://example.com/callback",
ID: uuid.Must(uuid.NewV4()),
ClientName: &testName,
GrantTypes: "authorization_code,refresh_token",
RegistrationType: "dynamic",
ClientType: OAuthServerClientTypeConfidential,
ClientSecretHash: testSecretHash,
RedirectURIs: "https://example.com/callback",
TokenEndpointAuthMethod: TokenEndpointAuthMethodClientSecretBasic,
}

err := CreateOAuthServerClient(ts.db, client)
Expand All @@ -210,13 +213,14 @@ func (ts *OAuthServerClientTestSuite) TestFindOAuthServerClientByClientID() {
testName := "Find By Client ID Test"
testSecretHash, _ := testHashClientSecret("test_secret")
client := &OAuthServerClient{
ID: uuid.Must(uuid.NewV4()),
ClientName: &testName,
GrantTypes: "authorization_code,refresh_token",
RegistrationType: "manual",
ClientType: OAuthServerClientTypeConfidential,
ClientSecretHash: testSecretHash,
RedirectURIs: "https://example.com/callback",
ID: uuid.Must(uuid.NewV4()),
ClientName: &testName,
GrantTypes: "authorization_code,refresh_token",
RegistrationType: "manual",
ClientType: OAuthServerClientTypeConfidential,
ClientSecretHash: testSecretHash,
RedirectURIs: "https://example.com/callback",
TokenEndpointAuthMethod: TokenEndpointAuthMethodClientSecretBasic,
}

err := CreateOAuthServerClient(ts.db, client)
Expand All @@ -239,13 +243,14 @@ func (ts *OAuthServerClientTestSuite) TestUpdateOAuthServerClient() {
originalName := "Original Name"
testSecretHash, _ := testHashClientSecret("test_secret")
client := &OAuthServerClient{
ID: uuid.Must(uuid.NewV4()),
ClientName: &originalName,
GrantTypes: "authorization_code,refresh_token",
RegistrationType: "dynamic",
ClientType: OAuthServerClientTypeConfidential,
ClientSecretHash: testSecretHash,
RedirectURIs: "https://example.com/callback",
ID: uuid.Must(uuid.NewV4()),
ClientName: &originalName,
GrantTypes: "authorization_code,refresh_token",
RegistrationType: "dynamic",
ClientType: OAuthServerClientTypeConfidential,
ClientSecretHash: testSecretHash,
RedirectURIs: "https://example.com/callback",
TokenEndpointAuthMethod: TokenEndpointAuthMethodClientSecretBasic,
}

err := CreateOAuthServerClient(ts.db, client)
Expand Down Expand Up @@ -293,13 +298,14 @@ func (ts *OAuthServerClientTestSuite) TestSoftDelete() {
testName := "Soft Delete Test"
testSecretHash, _ := testHashClientSecret("test_secret")
client := &OAuthServerClient{
ID: uuid.Must(uuid.NewV4()),
ClientName: &testName,
GrantTypes: "authorization_code,refresh_token",
RegistrationType: "dynamic",
ClientType: OAuthServerClientTypeConfidential,
ClientSecretHash: testSecretHash,
RedirectURIs: "https://example.com/callback",
ID: uuid.Must(uuid.NewV4()),
ClientName: &testName,
GrantTypes: "authorization_code,refresh_token",
RegistrationType: "dynamic",
ClientType: OAuthServerClientTypeConfidential,
ClientSecretHash: testSecretHash,
RedirectURIs: "https://example.com/callback",
TokenEndpointAuthMethod: TokenEndpointAuthMethodClientSecretBasic,
}

err := CreateOAuthServerClient(ts.db, client)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
-- Add token_endpoint_auth_method column to oauth_clients table
-- This stores the authentication method used at the token endpoint:
-- - 'none' for public clients (PKCE only)
-- - 'client_secret_basic' for confidential clients (HTTP Basic Auth)
-- - 'client_secret_post' for confidential clients (POST body)

-- Create enum for token endpoint auth method
do $$ begin
create type {{ index .Options "Namespace" }}.oauth_token_endpoint_auth_method as enum('none', 'client_secret_basic', 'client_secret_post');
exception
when duplicate_object then null;
end $$;

-- Add token_endpoint_auth_method column with a default value based on client_type
-- This uses a CASE expression to set the correct default based on existing client_type
alter table {{ index .Options "Namespace" }}.oauth_clients
add column if not exists token_endpoint_auth_method {{ index .Options "Namespace" }}.oauth_token_endpoint_auth_method;

-- Update existing rows to have the correct token_endpoint_auth_method based on client_type
update {{ index .Options "Namespace" }}.oauth_clients
set token_endpoint_auth_method = case
when client_type = 'public' then 'none'::{{ index .Options "Namespace" }}.oauth_token_endpoint_auth_method
else 'client_secret_basic'::{{ index .Options "Namespace" }}.oauth_token_endpoint_auth_method
end
where token_endpoint_auth_method is null;

-- Now make the column NOT NULL since all rows have values
alter table {{ index .Options "Namespace" }}.oauth_clients
alter column token_endpoint_auth_method set not null;