diff --git a/authentik/enterprise/providers/scim/api.py b/authentik/enterprise/providers/scim/api.py index ba065304cff3..6f9ab2d5ede1 100644 --- a/authentik/enterprise/providers/scim/api.py +++ b/authentik/enterprise/providers/scim/api.py @@ -1,14 +1,72 @@ +from datetime import datetime + +from django.urls import reverse from django.utils.translation import gettext as _ from rest_framework.exceptions import ValidationError from authentik.enterprise.license import LicenseKey -from authentik.providers.scim.models import SCIMAuthenticationMode +from authentik.providers.scim.models import SCIMAuthenticationMode, SCIMProvider +from authentik.sources.oauth.models import UserOAuthSourceConnection class SCIMProviderSerializerMixin: + def _get_token(self, instance: SCIMProvider) -> UserOAuthSourceConnection | None: + user = instance.auth_oauth_user + conn = UserOAuthSourceConnection.objects.filter( + user=user, source=instance.auth_oauth + ).first() + return conn + + def get_auth_oauth_token_last_updated(self, instance: SCIMProvider) -> datetime | None: + conn = self._get_token(instance) + return conn.last_updated if conn else None + + def get_auth_oauth_token_expires(self, instance: SCIMProvider) -> datetime | None: + conn = self._get_token(instance) + return conn.expires if conn else None + + def get_auth_oauth_url_callback(self, instance: SCIMProvider) -> str | None: + if ( + instance.auth_mode + in [ + SCIMAuthenticationMode.TOKEN, + SCIMAuthenticationMode.OAUTH_SILENT, + ] + or not instance.backchannel_application + ): + return None + relative_url = reverse( + "authentik_enterprise_providers_scim:callback", + kwargs={"application_slug": instance.backchannel_application.slug}, + ) + if "request" not in self.context: + return relative_url + return self.context["request"].build_absolute_uri(relative_url) + + def get_auth_oauth_url_start(self, instance: SCIMProvider) -> str | None: + if ( + instance.auth_mode + in [ + SCIMAuthenticationMode.TOKEN, + SCIMAuthenticationMode.OAUTH_SILENT, + ] + or not instance.backchannel_application + ): + return None + relative_url = reverse( + "authentik_enterprise_providers_scim:start", + kwargs={"application_slug": instance.backchannel_application.slug}, + ) + if "request" not in self.context: + return relative_url + return self.context["request"].build_absolute_uri(relative_url) + def validate_auth_mode(self, auth_mode: SCIMAuthenticationMode) -> SCIMAuthenticationMode: - if auth_mode == SCIMAuthenticationMode.OAUTH: + if auth_mode in [ + SCIMAuthenticationMode.OAUTH_SILENT, + SCIMAuthenticationMode.OAUTH_INTERACTIVE, + ]: if not LicenseKey.cached_summary().status.is_valid: raise ValidationError(_("Enterprise is required to use the OAuth mode.")) return auth_mode diff --git a/authentik/enterprise/providers/scim/apps.py b/authentik/enterprise/providers/scim/apps.py index 032d1e77eea0..8e064fe237f5 100644 --- a/authentik/enterprise/providers/scim/apps.py +++ b/authentik/enterprise/providers/scim/apps.py @@ -7,3 +7,4 @@ class AuthentikEnterpriseProviderSCIMConfig(EnterpriseConfig): label = "authentik_enterprise_providers_scim" verbose_name = "authentik Enterprise.Providers.SCIM" default = True + mountpoint = "application/scim/" diff --git a/authentik/enterprise/providers/scim/auth_oauth2.py b/authentik/enterprise/providers/scim/auth_oauth2.py index a5ab7dae96ff..42ab871f56b8 100644 --- a/authentik/enterprise/providers/scim/auth_oauth2.py +++ b/authentik/enterprise/providers/scim/auth_oauth2.py @@ -1,12 +1,14 @@ from datetime import timedelta -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from django.utils.timezone import now from requests import Request, RequestException from structlog.stdlib import get_logger +from authentik.common.oauth.constants import GRANT_TYPE_PASSWORD, GRANT_TYPE_REFRESH_TOKEN from authentik.providers.scim.clients.exceptions import SCIMRequestException -from authentik.sources.oauth.clients.oauth2 import OAuth2Client +from authentik.providers.scim.models import SCIMAuthenticationMode +from authentik.sources.oauth.clients.base import BaseOAuthClient from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection if TYPE_CHECKING: @@ -18,23 +20,26 @@ class SCIMOAuthException(SCIMRequestException): class SCIMOAuthAuth: - def __init__(self, provider: SCIMProvider): self.provider = provider self.user = provider.auth_oauth_user self.logger = get_logger().bind() self.connection = self.get_connection() - def retrieve_token(self): - if not self.provider.auth_oauth: - return None + def retrieve_token(self, conn: UserOAuthSourceConnection | None) -> dict[str, Any]: source: OAuthSource = self.provider.auth_oauth - client = OAuth2Client(source, None) + client: BaseOAuthClient = source.source_type.callback_view(request=None).get_client(source) access_token_url = source.source_type.access_token_url or "" if source.source_type.urls_customizable and source.access_token_url: access_token_url = source.access_token_url data = client.get_access_token_args(None, None) - data["grant_type"] = "password" + if self.provider.auth_mode == SCIMAuthenticationMode.OAUTH_SILENT: + data["grant_type"] = GRANT_TYPE_PASSWORD + elif self.provider.auth_mode == SCIMAuthenticationMode.OAUTH_INTERACTIVE: + data["grant_type"] = GRANT_TYPE_REFRESH_TOKEN + if not conn: + raise SCIMOAuthException(None, "Could not refresh SCIM OAuth token") + data["refresh_token"] = conn.refresh_token data.update(self.provider.auth_oauth_params) try: response = client.do_request( @@ -54,12 +59,14 @@ def retrieve_token(self): raise SCIMOAuthException(exc.response, message="Failed to get OAuth token") from exc def get_connection(self): - token = UserOAuthSourceConnection.objects.filter( - source=self.provider.auth_oauth, user=self.user, expires__gt=now() + if not self.provider.auth_oauth: + return None + conn = UserOAuthSourceConnection.objects.filter( + source=self.provider.auth_oauth, user=self.user ).first() - if token and token.access_token: - return token - token = self.retrieve_token() + if conn and conn.access_token and conn.expires > now(): + return conn + token = self.retrieve_token(conn) access_token = token["access_token"] expires_in = int(token.get("expires_in", 0)) token, _ = UserOAuthSourceConnection.objects.update_or_create( @@ -67,6 +74,7 @@ def get_connection(self): user=self.user, defaults={ "access_token": access_token, + "refresh_token": token.get("refresh_token"), "expires": now() + timedelta(seconds=expires_in), }, ) diff --git a/authentik/enterprise/providers/scim/signals.py b/authentik/enterprise/providers/scim/signals.py index d150da21785c..c6d159bc520d 100644 --- a/authentik/enterprise/providers/scim/signals.py +++ b/authentik/enterprise/providers/scim/signals.py @@ -14,7 +14,10 @@ def scim_provider_post_save(sender: type[Model], instance: SCIMProvider, created """Create service account before provider is saved""" identifier = f"ak-providers-scim-{instance.pk}" with audit_ignore(): - if instance.auth_mode == SCIMAuthenticationMode.OAUTH: + if instance.auth_mode in [ + SCIMAuthenticationMode.OAUTH_SILENT, + SCIMAuthenticationMode.OAUTH_INTERACTIVE, + ]: user, user_created = User.objects.update_or_create( username=identifier, defaults={ diff --git a/authentik/enterprise/providers/scim/tests/__init__.py b/authentik/enterprise/providers/scim/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/authentik/enterprise/providers/scim/tests/test_api.py b/authentik/enterprise/providers/scim/tests/test_api.py new file mode 100644 index 000000000000..bc005238b4b3 --- /dev/null +++ b/authentik/enterprise/providers/scim/tests/test_api.py @@ -0,0 +1,73 @@ +"""SCIM OAuth tests""" + +from unittest.mock import MagicMock, PropertyMock, patch + +from django.urls import reverse +from rest_framework.test import APITestCase + +from authentik.core.tests.utils import create_test_admin_user +from authentik.enterprise.license import LicenseKey +from authentik.enterprise.models import License +from authentik.enterprise.tests.test_license import expiry_valid +from authentik.lib.generators import generate_id +from authentik.sources.oauth.models import OAuthSource + + +class TestSCIMOAuthAPI(APITestCase): + """SCIM User tests""" + + def setUp(self): + self.source = OAuthSource.objects.create( + name=generate_id(), + slug=generate_id(), + access_token_url="http://localhost/token", # nosec + consumer_key=generate_id(), + consumer_secret=generate_id(), + provider_type="openidconnect", + ) + + @patch( + "authentik.enterprise.license.LicenseKey.validate", + MagicMock( + return_value=LicenseKey( + aud="", + exp=expiry_valid, + name=generate_id(), + internal_users=100, + external_users=100, + ) + ), + ) + def test_api_create(self): + License.objects.create(key=generate_id()) + self.client.force_login(create_test_admin_user()) + res = self.client.post( + reverse("authentik_api:scimprovider-list"), + { + "name": generate_id(), + "url": "http://localhost", + "auth_mode": "oauth_silent", + "auth_oauth": str(self.source.pk), + }, + ) + self.assertEqual(res.status_code, 201) + + @patch( + "authentik.enterprise.models.LicenseUsageStatus.is_valid", + PropertyMock(return_value=False), + ) + def test_api_create_no_license(self): + self.client.force_login(create_test_admin_user()) + res = self.client.post( + reverse("authentik_api:scimprovider-list"), + { + "name": generate_id(), + "url": "http://localhost", + "auth_mode": "oauth_silent", + "auth_oauth": str(self.source.pk), + }, + ) + self.assertEqual(res.status_code, 400) + self.assertJSONEqual( + res.content, {"auth_mode": ["Enterprise is required to use the OAuth mode."]} + ) diff --git a/authentik/enterprise/providers/scim/tests/test_auth.py b/authentik/enterprise/providers/scim/tests/test_auth.py new file mode 100644 index 000000000000..1ea0687fb22c --- /dev/null +++ b/authentik/enterprise/providers/scim/tests/test_auth.py @@ -0,0 +1,100 @@ +"""SCIM OAuth tests""" + +from requests_mock import Mocker +from rest_framework.test import APITestCase + +from authentik.blueprints.tests import apply_blueprint +from authentik.core.models import Application, Group, User +from authentik.lib.generators import generate_id +from authentik.providers.scim.models import SCIMAuthenticationMode, SCIMMapping, SCIMProvider +from authentik.sources.oauth.models import OAuthSource +from authentik.tenants.models import Tenant + + +class TestSCIMOAuthAuth(APITestCase): + """SCIM User tests""" + + @apply_blueprint("system/providers-scim.yaml") + def setUp(self) -> None: + # Delete all users and groups as the mocked HTTP responses only return one ID + # which will cause errors with multiple users + Tenant.objects.update(avatars="none") + User.objects.all().exclude_anonymous().delete() + Group.objects.all().delete() + self.source = OAuthSource.objects.create( + name=generate_id(), + slug=generate_id(), + access_token_url="http://localhost/token", # nosec + consumer_key=generate_id(), + consumer_secret=generate_id(), + provider_type="openidconnect", + ) + self.provider = SCIMProvider.objects.create( + name=generate_id(), + url="https://localhost", + auth_mode=SCIMAuthenticationMode.OAUTH_SILENT, + auth_oauth=self.source, + auth_oauth_params={ + "foo": "bar", + }, + exclude_users_service_account=True, + ) + self.app: Application = Application.objects.create( + name=generate_id(), + slug=generate_id(), + ) + self.app.backchannel_providers.add(self.provider) + self.provider.property_mappings.add( + SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/user") + ) + self.provider.property_mappings_group.add( + SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/group") + ) + + @Mocker() + def test_user_create(self, mock: Mocker): + """Test user creation""" + scim_id = generate_id() + token = generate_id() + mock.post("http://localhost/token", json={"access_token": token, "expires_in": 3600}) + mock.get( + "https://localhost/ServiceProviderConfig", + json={}, + ) + mock.post( + "https://localhost/Users", + json={ + "id": scim_id, + }, + ) + uid = generate_id() + user = User.objects.create( + username=uid, + name=f"{uid} {uid}", + email=f"{uid}@goauthentik.io", + ) + self.assertEqual(mock.call_count, 3) + self.assertEqual(mock.request_history[1].method, "GET") + self.assertEqual(mock.request_history[2].method, "POST") + self.assertJSONEqual( + mock.request_history[2].body, + { + "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], + "active": True, + "emails": [ + { + "primary": True, + "type": "other", + "value": f"{uid}@goauthentik.io", + } + ], + "externalId": user.uid, + "name": { + "familyName": uid, + "formatted": f"{uid} {uid}", + "givenName": uid, + }, + "displayName": f"{uid} {uid}", + "userName": uid, + }, + ) diff --git a/authentik/enterprise/providers/scim/tests.py b/authentik/enterprise/providers/scim/tests/test_token.py similarity index 50% rename from authentik/enterprise/providers/scim/tests.py rename to authentik/enterprise/providers/scim/tests/test_token.py index 0680c53d0f36..1693ca195a4a 100644 --- a/authentik/enterprise/providers/scim/tests.py +++ b/authentik/enterprise/providers/scim/tests/test_token.py @@ -2,7 +2,7 @@ from base64 import b64encode from datetime import timedelta -from unittest.mock import MagicMock, PropertyMock, patch +from urllib.parse import parse_qs, urlencode, urlparse from django.urls import reverse from django.utils.timezone import now @@ -11,17 +11,14 @@ from authentik.blueprints.tests import apply_blueprint from authentik.core.models import Application, Group, User -from authentik.core.tests.utils import create_test_admin_user -from authentik.enterprise.license import LicenseKey -from authentik.enterprise.models import License -from authentik.enterprise.tests.test_license import expiry_valid from authentik.lib.generators import generate_id from authentik.providers.scim.models import SCIMAuthenticationMode, SCIMMapping, SCIMProvider from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.tenants.models import Tenant +from tests.live import create_test_admin_user -class SCIMOAuthTests(APITestCase): +class TestSCIMOAuthToken(APITestCase): """SCIM User tests""" @apply_blueprint("system/providers-scim.yaml") @@ -42,7 +39,7 @@ def setUp(self) -> None: self.provider = SCIMProvider.objects.create( name=generate_id(), url="https://localhost", - auth_mode=SCIMAuthenticationMode.OAUTH, + auth_mode=SCIMAuthenticationMode.OAUTH_SILENT, auth_oauth=self.source, auth_oauth_params={ "foo": "bar", @@ -60,8 +57,9 @@ def setUp(self) -> None: self.provider.property_mappings_group.add( SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/group") ) + self.admin = create_test_admin_user() - def test_retrieve_token(self): + def test_retrieve_token_silent(self): """Test token retrieval""" with Mocker() as mocker: token = generate_id() @@ -86,6 +84,44 @@ def test_retrieve_token(self): ) self.assertEqual(mocker.request_history[0].body, "grant_type=password&foo=bar") + def test_retrieve_token_interactive(self): + """Test token retrieval""" + self.provider.auth_mode = SCIMAuthenticationMode.OAUTH_INTERACTIVE + self.provider.save() + refresh_token = generate_id() + access_token = generate_id() + UserOAuthSourceConnection.objects.create( + user=self.provider.auth_oauth_user, + source=self.source, + refresh_token=refresh_token, + access_token=access_token, + ) + with Mocker() as mocker: + token = generate_id() + mocker.post("http://localhost/token", json={"access_token": token, "expires_in": 3600}) + self.provider.scim_auth() + conn = UserOAuthSourceConnection.objects.filter( + source=self.source, + user=self.provider.auth_oauth_user, + ).first() + self.assertIsNotNone(conn) + self.assertTrue(conn.is_valid) + auth = ( + b64encode( + b":".join((self.source.consumer_key.encode(), self.source.consumer_secret.encode())) + ) + .strip() + .decode() + ) + self.assertEqual( + mocker.request_history[0].headers["Authorization"], + f"Basic {auth}", + ) + self.assertEqual( + mocker.request_history[0].body, + f"grant_type=refresh_token&refresh_token={refresh_token}&foo=bar", + ) + def test_existing_token(self): """Test existing token""" UserOAuthSourceConnection.objects.create( @@ -98,96 +134,54 @@ def test_existing_token(self): self.provider.scim_auth() self.assertEqual(len(mocker.request_history), 0) - @Mocker() - def test_user_create(self, mock: Mocker): - """Test user creation""" - scim_id = generate_id() - token = generate_id() - mock.post("http://localhost/token", json={"access_token": token, "expires_in": 3600}) - mock.get( - "https://localhost/ServiceProviderConfig", - json={}, - ) - mock.post( - "https://localhost/Users", - json={ - "id": scim_id, - }, + def test_interactive_start(self): + self.client.force_login(self.admin) + res = self.client.get( + reverse( + "authentik_enterprise_providers_scim:start", + kwargs={ + "application_slug": self.app.slug, + }, + ) ) - uid = generate_id() - user = User.objects.create( - username=uid, - name=f"{uid} {uid}", - email=f"{uid}@goauthentik.io", + self.assertEqual(res.status_code, 302) + query = parse_qs(urlparse(res.url).query) + self.assertEqual(query["client_id"], [self.source.consumer_key]) + self.assertEqual( + query["redirect_uri"], + [f"http://testserver/application/scim/{self.app.slug}/oauth2/callback/"], ) - self.assertEqual(mock.call_count, 3) - self.assertEqual(mock.request_history[1].method, "GET") - self.assertEqual(mock.request_history[2].method, "POST") - self.assertJSONEqual( - mock.request_history[2].body, - { - "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"], - "active": True, - "emails": [ - { - "primary": True, - "type": "other", - "value": f"{uid}@goauthentik.io", - } - ], - "externalId": user.uid, - "name": { - "familyName": uid, - "formatted": f"{uid} {uid}", - "givenName": uid, + self.assertEqual(query["response_type"], ["code"]) + + def test_interactive_callback(self): + self.client.force_login(self.admin) + res = self.client.get( + reverse( + "authentik_enterprise_providers_scim:start", + kwargs={ + "application_slug": self.app.slug, }, - "displayName": f"{uid} {uid}", - "userName": uid, - }, + ) ) + self.assertEqual(res.status_code, 302) + query = parse_qs(urlparse(res.url).query) + + with Mocker() as mock: + token = generate_id() + mock.post("http://localhost/token", json={"access_token": token, "expires_in": 3600}) - @patch( - "authentik.enterprise.license.LicenseKey.validate", - MagicMock( - return_value=LicenseKey( - aud="", - exp=expiry_valid, - name=generate_id(), - internal_users=100, - external_users=100, + res = self.client.get( + reverse( + "authentik_enterprise_providers_scim:callback", + kwargs={ + "application_slug": self.app.slug, + }, + ) + + "?" + + urlencode({"state": query["state"][0], "code": generate_id()}) ) - ), - ) - def test_api_create(self): - License.objects.create(key=generate_id()) - self.client.force_login(create_test_admin_user()) - res = self.client.post( - reverse("authentik_api:scimprovider-list"), - { - "name": generate_id(), - "url": "http://localhost", - "auth_mode": "oauth", - "auth_oauth": str(self.source.pk), - }, - ) - self.assertEqual(res.status_code, 201) + self.assertEqual(res.status_code, 302) - @patch( - "authentik.enterprise.models.LicenseUsageStatus.is_valid", - PropertyMock(return_value=False), - ) - def test_api_create_no_license(self): - self.client.force_login(create_test_admin_user()) - res = self.client.post( - reverse("authentik_api:scimprovider-list"), - { - "name": generate_id(), - "url": "http://localhost", - "auth_mode": "oauth", - "auth_oauth": str(self.source.pk), - }, - ) - self.assertEqual(res.status_code, 400) - self.assertJSONEqual( - res.content, {"auth_mode": ["Enterprise is required to use the OAuth mode."]} - ) + conn = UserOAuthSourceConnection.objects.filter(source=self.source).first() + self.assertIsNotNone(conn) + self.assertTrue(conn.is_valid) diff --git a/authentik/enterprise/providers/scim/urls.py b/authentik/enterprise/providers/scim/urls.py new file mode 100644 index 000000000000..f51998dcb536 --- /dev/null +++ b/authentik/enterprise/providers/scim/urls.py @@ -0,0 +1,10 @@ +from django.urls import path + +from authentik.enterprise.providers.scim.views import SCIMOAuthStart, SCIMRedirectCallback + +urlpatterns = [ + path("/oauth2/start/", SCIMOAuthStart.as_view(), name="start"), + path( + "/oauth2/callback/", SCIMRedirectCallback.as_view(), name="callback" + ), +] diff --git a/authentik/enterprise/providers/scim/views.py b/authentik/enterprise/providers/scim/views.py new file mode 100644 index 000000000000..91db2377ba1e --- /dev/null +++ b/authentik/enterprise/providers/scim/views.py @@ -0,0 +1,70 @@ +from datetime import timedelta + +from django.core.exceptions import PermissionDenied +from django.http import HttpRequest +from django.shortcuts import redirect +from django.urls import reverse +from django.utils.timezone import now + +from authentik.core.models import Application +from authentik.providers.scim.models import SCIMProvider +from authentik.sources.oauth.clients.base import BaseOAuthClient +from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection +from authentik.sources.oauth.types.registry import RequestKind, registry +from authentik.sources.oauth.views.callback import OAuthCallback +from authentik.sources.oauth.views.redirect import OAuthRedirect + + +class SCIMOAuthViewMixin: + + provider: SCIMProvider + + def get_client(self, source: OAuthSource, **kwargs) -> BaseOAuthClient: + source: OAuthSource = self.provider.auth_oauth + source_cls = registry.find(source.provider_type, kind=RequestKind.CALLBACK) + if not source_cls.client_class: + return super().get_client(source, **kwargs) + return source_cls.client_class(source, self.request, **kwargs) + + def _get_scim_provider(self, app_slug: str): + app = Application.objects.filter(slug=app_slug).first() + if not app: + return None + provider = SCIMProvider.objects.filter(backchannel_application=app) + return provider.first() + + def dispatch(self, request: HttpRequest, application_slug: str): + if not request.user.is_authenticated: + raise PermissionDenied() + provider = self._get_scim_provider(application_slug) + if not provider or not provider.auth_oauth: + raise PermissionDenied() + if not request.user.has_perm( + "authentik_providers_scim.change_scimprovider", + provider, + ): + raise PermissionDenied() + self.provider = provider + return super().dispatch(request, source_slug=provider.auth_oauth.slug) + + +class SCIMOAuthStart(SCIMOAuthViewMixin, OAuthRedirect): + + def get_callback_url(self, source: OAuthSource): + return reverse("authentik_enterprise_providers_scim:callback", kwargs=self.kwargs) + + +class SCIMRedirectCallback(SCIMOAuthViewMixin, OAuthCallback): + + def redirect_flow_manager(self, client: BaseOAuthClient): + expires_in = int(self.token.get("expires_in", 0)) + UserOAuthSourceConnection.objects.update_or_create( + source=self.provider.auth_oauth, + user=self.provider.auth_oauth_user, + defaults={ + "access_token": self.token.get("access_token"), + "refresh_token": self.token.get("refresh_token"), + "expires": now() + timedelta(seconds=expires_in), + }, + ) + return redirect("authentik_core:if-admin") diff --git a/authentik/providers/scim/api/providers.py b/authentik/providers/scim/api/providers.py index a931836b7b67..f4120d363cd2 100644 --- a/authentik/providers/scim/api/providers.py +++ b/authentik/providers/scim/api/providers.py @@ -1,5 +1,6 @@ """SCIM Provider API Views""" +from rest_framework.fields import SerializerMethodField from rest_framework.viewsets import ModelViewSet from authentik.core.api.providers import ProviderSerializer @@ -16,6 +17,11 @@ class SCIMProviderSerializer( ): """SCIMProvider Serializer""" + auth_oauth_token_last_updated = SerializerMethodField() + auth_oauth_token_expires = SerializerMethodField() + auth_oauth_url_callback = SerializerMethodField() + auth_oauth_url_start = SerializerMethodField() + class Meta: model = SCIMProvider fields = [ @@ -35,6 +41,10 @@ class Meta: "auth_mode", "auth_oauth", "auth_oauth_params", + "auth_oauth_token_last_updated", + "auth_oauth_token_expires", + "auth_oauth_url_callback", + "auth_oauth_url_start", "compatibility_mode", "service_provider_config_cache_timeout", "exclude_users_service_account", diff --git a/authentik/providers/scim/migrations/0020_alter_scimprovider_auth_mode.py b/authentik/providers/scim/migrations/0020_alter_scimprovider_auth_mode.py new file mode 100644 index 000000000000..024797fd435e --- /dev/null +++ b/authentik/providers/scim/migrations/0020_alter_scimprovider_auth_mode.py @@ -0,0 +1,36 @@ +# Generated by Django 5.2.14 on 2026-05-05 22:11 + +from django.db import migrations, models +from django.apps.registry import Apps + +from django.db.backends.base.schema import BaseDatabaseSchemaEditor + + +def update_oauth(apps: Apps, schema_editor: BaseDatabaseSchemaEditor): + db_alias = schema_editor.connection.alias + + SCIMProvider = apps.get("authentik_providers_scim", "scimprovider") + + SCIMProvider.objects.using(db_alias).filter(auth_mode="oauth").update(auth_mode="oauth_silent") + + +class Migration(migrations.Migration): + + dependencies = [ + ("authentik_providers_scim", "0019_scimprovider_group_filters_and_more"), + ] + + operations = [ + migrations.AlterField( + model_name="scimprovider", + name="auth_mode", + field=models.TextField( + choices=[ + ("token", "Token"), + ("oauth_silent", "OAuth (Silent)"), + ("oauth_interactive", "OAuth (interactive)"), + ], + default="token", + ), + ), + ] diff --git a/authentik/providers/scim/models.py b/authentik/providers/scim/models.py index 225ec00b474c..a7ae02d1181b 100644 --- a/authentik/providers/scim/models.py +++ b/authentik/providers/scim/models.py @@ -72,7 +72,8 @@ class SCIMAuthenticationMode(models.TextChoices): """SCIM authentication modes""" TOKEN = "token", _("Token") - OAUTH = "oauth", _("OAuth") + OAUTH_SILENT = "oauth_silent", _("OAuth (Silent)") + OAUTH_INTERACTIVE = "oauth_interactive", _("OAuth (interactive)") class SCIMCompatibilityMode(models.TextChoices): @@ -144,7 +145,10 @@ class SCIMProvider(OutgoingSyncProvider, BackchannelProvider): ) def scim_auth(self) -> AuthBase: - if self.auth_mode == SCIMAuthenticationMode.OAUTH: + if self.auth_mode in [ + SCIMAuthenticationMode.OAUTH_SILENT, + SCIMAuthenticationMode.OAUTH_INTERACTIVE, + ]: try: from authentik.enterprise.providers.scim.auth_oauth2 import SCIMOAuthAuth diff --git a/authentik/sources/oauth/types/registry.py b/authentik/sources/oauth/types/registry.py index d501ffb2e175..2f5657f7cb0e 100644 --- a/authentik/sources/oauth/types/registry.py +++ b/authentik/sources/oauth/types/registry.py @@ -1,6 +1,5 @@ """Source type manager""" -from collections.abc import Callable from enum import Enum from typing import Any @@ -114,7 +113,7 @@ def find_type(self, type_name: str) -> type[SourceType]: ) return found_type - def find(self, type_name: str, kind: RequestKind) -> Callable: + def find(self, type_name: str, kind: RequestKind) -> type[OAuthCallback | OAuthRedirect]: """Find fitting Source Type""" found_type = self.find_type(type_name) if kind == RequestKind.CALLBACK: diff --git a/authentik/sources/oauth/views/callback.py b/authentik/sources/oauth/views/callback.py index 3426a599ff07..839ca81c0218 100644 --- a/authentik/sources/oauth/views/callback.py +++ b/authentik/sources/oauth/views/callback.py @@ -15,6 +15,7 @@ from authentik.core.sources.flow_manager import SourceFlowManager from authentik.events.models import Event, EventAction +from authentik.sources.oauth.clients.base import BaseOAuthClient from authentik.sources.oauth.models import ( GroupOAuthSourceConnection, OAuthSource, @@ -29,7 +30,7 @@ class OAuthCallback(OAuthClientMixin, View): "Base OAuth callback view." source: OAuthSource - token: dict | None = None + token: dict[str, Any] | None = None def dispatch(self, request: HttpRequest, *_, **kwargs) -> HttpResponse: """View Get handler""" @@ -49,20 +50,31 @@ def dispatch(self, request: HttpRequest, *_, **kwargs) -> HttpResponse: if "error" in self.token: return self.handle_login_failure(self.token["error"]) # Fetch profile info + try: + res = self.redirect_flow_manager(client) + except ValueError as exc: + # if we're authenticated and not in a source stage and this new flag is enabled, + # just continue + if self.request.user.is_authenticated: + pass + return self.handle_login_failure(exc.args[0]) + return res + + def redirect_flow_manager(self, client: BaseOAuthClient) -> HttpResponse: try: raw_info = client.get_profile_info(self.token) if raw_info is None: - return self.handle_login_failure("Could not retrieve profile.") + raise ValueError("Could not retrieve profile.") except JSONDecodeError as exc: Event.new( EventAction.CONFIGURATION_ERROR, message="Failed to JSON-decode profile.", raw_profile=exc.doc, ).from_http(self.request) - return self.handle_login_failure("Could not retrieve profile.") + raise ValueError("Could not retrieve profile.") from None identifier = self.get_user_id(info=raw_info) if identifier is None: - return self.handle_login_failure("Could not determine id.") + raise ValueError("Could not determine id.") sfm = OAuthSourceFlowManager( source=self.source, request=self.request, diff --git a/blueprints/schema.json b/blueprints/schema.json index 22938bfcc1fe..57a46320924c 100644 --- a/blueprints/schema.json +++ b/blueprints/schema.json @@ -11203,7 +11203,8 @@ "type": "string", "enum": [ "token", - "oauth" + "oauth_silent", + "oauth_interactive" ], "title": "Auth mode" }, diff --git a/packages/client-ts/package.json b/packages/client-ts/package.json index aee985f00101..f86c7e74a511 100644 --- a/packages/client-ts/package.json +++ b/packages/client-ts/package.json @@ -8,8 +8,8 @@ "url": "https://github.com/goauthentik/authentik.git" }, "scripts": { - "clean": "tsc -b --clean tsconfig.json tsconfig.esm.json", "build": "npm run clean && tsc -b tsconfig.json tsconfig.esm.json", + "clean": "tsc -b --clean tsconfig.json tsconfig.esm.json", "prepare": "npm run build" }, "main": "./dist/index.js", diff --git a/packages/client-ts/src/models/SCIMAuthenticationModeEnum.ts b/packages/client-ts/src/models/SCIMAuthenticationModeEnum.ts index 20ba8c8e74c5..5d9d0668db35 100644 --- a/packages/client-ts/src/models/SCIMAuthenticationModeEnum.ts +++ b/packages/client-ts/src/models/SCIMAuthenticationModeEnum.ts @@ -18,7 +18,8 @@ */ export const SCIMAuthenticationModeEnum = { Token: "token", - Oauth: "oauth", + OauthSilent: "oauth_silent", + OauthInteractive: "oauth_interactive", UnknownDefaultOpenApi: "11184809", } as const; export type SCIMAuthenticationModeEnum = diff --git a/packages/client-ts/src/models/SCIMProvider.ts b/packages/client-ts/src/models/SCIMProvider.ts index 4b95d47facd5..752349fcd977 100644 --- a/packages/client-ts/src/models/SCIMProvider.ts +++ b/packages/client-ts/src/models/SCIMProvider.ts @@ -125,6 +125,30 @@ export interface SCIMProvider { * @memberof SCIMProvider */ authOauthParams?: { [key: string]: any }; + /** + * + * @type {Date} + * @memberof SCIMProvider + */ + readonly authOauthTokenLastUpdated: Date | null; + /** + * + * @type {Date} + * @memberof SCIMProvider + */ + readonly authOauthTokenExpires: Date | null; + /** + * + * @type {string} + * @memberof SCIMProvider + */ + readonly authOauthUrlCallback: string | null; + /** + * + * @type {string} + * @memberof SCIMProvider + */ + readonly authOauthUrlStart: string | null; /** * Alter authentik behavior for vendor-specific SCIM implementations. * @type {CompatibilityModeEnum} @@ -190,6 +214,13 @@ export function instanceOfSCIMProvider(value: object): value is SCIMProvider { if (!("verboseNamePlural" in value) || value["verboseNamePlural"] === undefined) return false; if (!("metaModelName" in value) || value["metaModelName"] === undefined) return false; if (!("url" in value) || value["url"] === undefined) return false; + if (!("authOauthTokenLastUpdated" in value) || value["authOauthTokenLastUpdated"] === undefined) + return false; + if (!("authOauthTokenExpires" in value) || value["authOauthTokenExpires"] === undefined) + return false; + if (!("authOauthUrlCallback" in value) || value["authOauthUrlCallback"] === undefined) + return false; + if (!("authOauthUrlStart" in value) || value["authOauthUrlStart"] === undefined) return false; return true; } @@ -223,6 +254,16 @@ export function SCIMProviderFromJSONTyped(json: any, ignoreDiscriminator: boolea : SCIMAuthenticationModeEnumFromJSON(json["auth_mode"]), authOauth: json["auth_oauth"] == null ? undefined : json["auth_oauth"], authOauthParams: json["auth_oauth_params"] == null ? undefined : json["auth_oauth_params"], + authOauthTokenLastUpdated: + json["auth_oauth_token_last_updated"] == null + ? null + : new Date(json["auth_oauth_token_last_updated"]), + authOauthTokenExpires: + json["auth_oauth_token_expires"] == null + ? null + : new Date(json["auth_oauth_token_expires"]), + authOauthUrlCallback: json["auth_oauth_url_callback"], + authOauthUrlStart: json["auth_oauth_url_start"], compatibilityMode: json["compatibility_mode"] == null ? undefined @@ -256,6 +297,10 @@ export function SCIMProviderToJSONTyped( | "verbose_name" | "verbose_name_plural" | "meta_model_name" + | "auth_oauth_token_last_updated" + | "auth_oauth_token_expires" + | "auth_oauth_url_callback" + | "auth_oauth_url_start" > | null, ignoreDiscriminator: boolean = false, ): any { diff --git a/schema.yml b/schema.yml index cf6551d61942..e9cab586ae16 100644 --- a/schema.yml +++ b/schema.yml @@ -54864,7 +54864,8 @@ components: SCIMAuthenticationModeEnum: enum: - token - - oauth + - oauth_silent + - oauth_interactive type: string SCIMMapping: type: object @@ -54999,6 +55000,24 @@ components: type: object additionalProperties: {} description: Additional OAuth parameters, such as grant_type + auth_oauth_token_last_updated: + type: string + format: date-time + nullable: true + readOnly: true + auth_oauth_token_expires: + type: string + format: date-time + nullable: true + readOnly: true + auth_oauth_url_callback: + type: string + nullable: true + readOnly: true + auth_oauth_url_start: + type: string + nullable: true + readOnly: true compatibility_mode: allOf: - $ref: '#/components/schemas/CompatibilityModeEnum' @@ -55031,6 +55050,10 @@ components: required: - assigned_backchannel_application_name - assigned_backchannel_application_slug + - auth_oauth_token_expires + - auth_oauth_token_last_updated + - auth_oauth_url_callback + - auth_oauth_url_start - component - meta_model_name - name diff --git a/web/src/admin/providers/scim/SCIMProviderFormForm.ts b/web/src/admin/providers/scim/SCIMProviderFormForm.ts index 3a503f3f59f9..7a102bcc5266 100644 --- a/web/src/admin/providers/scim/SCIMProviderFormForm.ts +++ b/web/src/admin/providers/scim/SCIMProviderFormForm.ts @@ -92,7 +92,8 @@ export function renderAuth(provider?: Partial, errors: ValidationE default: case SCIMAuthenticationModeEnum.Token: return renderAuthToken(provider, errors); - case SCIMAuthenticationModeEnum.Oauth: + case SCIMAuthenticationModeEnum.OauthSilent: + case SCIMAuthenticationModeEnum.OauthInteractive: return renderAuthOAuth(provider, errors); } } @@ -160,12 +161,18 @@ export function renderForm({ provider, errors, update }: SCIMProviderFormProps) )}`, }, { - label: msg("OAuth"), - value: SCIMAuthenticationModeEnum.Oauth, - default: true, + label: msg("OAuth (Silent)"), + value: SCIMAuthenticationModeEnum.OauthSilent, description: html`${msg("Authenticate SCIM requests using OAuth.")} `, }, + { + label: msg("OAuth (Interactive)"), + value: SCIMAuthenticationModeEnum.OauthInteractive, + description: html`${msg( + "Authenticate SCIM requests using OAuth, interactively authorized.", + )} `, + }, ]} > diff --git a/web/src/admin/providers/scim/SCIMProviderViewPage.ts b/web/src/admin/providers/scim/SCIMProviderViewPage.ts index 74832656f169..64fd9fc27e94 100644 --- a/web/src/admin/providers/scim/SCIMProviderViewPage.ts +++ b/web/src/admin/providers/scim/SCIMProviderViewPage.ts @@ -13,6 +13,7 @@ import "#elements/buttons/ModalButton"; import "#elements/sync/SyncStatusCard"; import "#elements/tasks/ScheduleList"; import "#elements/tasks/TaskList"; +import "#elements/timestamp/ak-timestamp"; import { DEFAULT_CONFIG } from "#common/api/config"; import { EVENT_REFRESH } from "#common/constants"; @@ -20,7 +21,14 @@ import { EVENT_REFRESH } from "#common/constants"; import { AKElement } from "#elements/Base"; import { SlottedTemplateResult } from "#elements/types"; -import { ModelEnum, ProvidersApi, SCIMProvider } from "@goauthentik/api"; +import renderDescriptionList from "#components/DescriptionList"; + +import { + ModelEnum, + ProvidersApi, + SCIMAuthenticationModeEnum, + SCIMProvider, +} from "@goauthentik/api"; import MDSCIMProvider from "~docs/add-secure-apps/providers/scim/index.md"; @@ -154,6 +162,42 @@ export class SCIMProviderViewPage extends AKElement { `; } + renderSyncStatusExtra() { + if ( + this.provider?.authMode !== SCIMAuthenticationModeEnum.OauthSilent && + this.provider?.authMode !== SCIMAuthenticationModeEnum.OauthInteractive + ) + return nothing; + return html` +
+
+ ${msg("OAuth Token last updated")} +
+
+
+ +
+
+
+
+
+ ${msg("OAuth Token expires")} +
+
+
+ +
+
+
+ `; + } + renderTabOverview(): SlottedTemplateResult { if (!this.provider) { return nothing; @@ -168,91 +212,94 @@ export class SCIMProviderViewPage extends AKElement { : nothing}
+
${msg("Info")}
-
-
-
- ${msg("Name")} -
-
-
- ${this.provider.name} -
-
-
-
-
- ${msg("Assigned to application")}`, + ], + [ + msg("Dry-run"), + html``, + ], + [msg("URL"), this.provider.url], + [ + msg("Service Provider Config cache timeout"), + this.provider.serviceProviderConfigCacheTimeout, + ], + [ + msg("Related actions"), + html` + ${msg("Save Changes")} + ${msg("Update SCIM Provider")} + -
-
-
- -
-
-
-
-
- ${msg("Dry-run")} +
-
-
- -
-
-
-
-
- ${msg("URL")} -
-
-
- ${this.provider.url} -
-
-
-
-
- - ${msg("Service Provider Config cache timeout")} - -
-
-
- ${this.provider.serviceProviderConfigCacheTimeout} -
-
-
-
-
-
-
+
+ ${this.provider.authMode === SCIMAuthenticationModeEnum.OauthInteractive + ? html` +
+
+ ${renderDescriptionList( + [ + [ + msg("OAuth Status"), + html` + ${msg("(Re-)authenticate")}`, + ], + [ + msg("OAuth Callback URL"), + html``, + ], + ], + { horizontal: true }, + )} +
+
+ ` + : nothing} { return new ProvidersApi(DEFAULT_CONFIG).providersScimSyncStatusRetrieve( @@ -261,7 +308,9 @@ export class SCIMProviderViewPage extends AKElement { }, ); }} - > + > + ${this.renderSyncStatusExtra()} +
diff --git a/web/src/elements/sync/SyncStatusCard.ts b/web/src/elements/sync/SyncStatusCard.ts index ec5939a7e6bd..dfe30f86b40b 100644 --- a/web/src/elements/sync/SyncStatusCard.ts +++ b/web/src/elements/sync/SyncStatusCard.ts @@ -90,6 +90,7 @@ export class SyncStatusCard extends AKElement {
+ `; }