Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 60 additions & 2 deletions authentik/enterprise/providers/scim/api.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,72 @@
from datetime import datetime

from django.urls import reverse
from django.utils.translation import gettext as _
from rest_framework.exceptions import ValidationError

from authentik.enterprise.license import LicenseKey
from authentik.providers.scim.models import SCIMAuthenticationMode
from authentik.providers.scim.models import SCIMAuthenticationMode, SCIMProvider
from authentik.sources.oauth.models import UserOAuthSourceConnection


class SCIMProviderSerializerMixin:

def _get_token(self, instance: SCIMProvider) -> UserOAuthSourceConnection | None:
user = instance.auth_oauth_user
conn = UserOAuthSourceConnection.objects.filter(
user=user, source=instance.auth_oauth
).first()
return conn

def get_auth_oauth_token_last_updated(self, instance: SCIMProvider) -> datetime | None:
conn = self._get_token(instance)
return conn.last_updated if conn else None

def get_auth_oauth_token_expires(self, instance: SCIMProvider) -> datetime | None:
conn = self._get_token(instance)
return conn.expires if conn else None

def get_auth_oauth_url_callback(self, instance: SCIMProvider) -> str | None:
if (
instance.auth_mode
in [
SCIMAuthenticationMode.TOKEN,
SCIMAuthenticationMode.OAUTH_SILENT,
]
or not instance.backchannel_application
):
return None
relative_url = reverse(
"authentik_enterprise_providers_scim:callback",
kwargs={"application_slug": instance.backchannel_application.slug},
)
if "request" not in self.context:
return relative_url
return self.context["request"].build_absolute_uri(relative_url)

def get_auth_oauth_url_start(self, instance: SCIMProvider) -> str | None:
if (
instance.auth_mode
in [
SCIMAuthenticationMode.TOKEN,
SCIMAuthenticationMode.OAUTH_SILENT,
]
or not instance.backchannel_application
):
return None
relative_url = reverse(
"authentik_enterprise_providers_scim:start",
kwargs={"application_slug": instance.backchannel_application.slug},
)
if "request" not in self.context:
return relative_url
return self.context["request"].build_absolute_uri(relative_url)

def validate_auth_mode(self, auth_mode: SCIMAuthenticationMode) -> SCIMAuthenticationMode:
if auth_mode == SCIMAuthenticationMode.OAUTH:
if auth_mode in [
SCIMAuthenticationMode.OAUTH_SILENT,
SCIMAuthenticationMode.OAUTH_INTERACTIVE,
]:
if not LicenseKey.cached_summary().status.is_valid:
raise ValidationError(_("Enterprise is required to use the OAuth mode."))
return auth_mode
1 change: 1 addition & 0 deletions authentik/enterprise/providers/scim/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ class AuthentikEnterpriseProviderSCIMConfig(EnterpriseConfig):
label = "authentik_enterprise_providers_scim"
verbose_name = "authentik Enterprise.Providers.SCIM"
default = True
mountpoint = "application/scim/"
34 changes: 21 additions & 13 deletions authentik/enterprise/providers/scim/auth_oauth2.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from django.utils.timezone import now
from requests import Request, RequestException
from structlog.stdlib import get_logger

from authentik.common.oauth.constants import GRANT_TYPE_PASSWORD, GRANT_TYPE_REFRESH_TOKEN
from authentik.providers.scim.clients.exceptions import SCIMRequestException
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
from authentik.providers.scim.models import SCIMAuthenticationMode
from authentik.sources.oauth.clients.base import BaseOAuthClient
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection

if TYPE_CHECKING:
Expand All @@ -18,23 +20,26 @@ class SCIMOAuthException(SCIMRequestException):


class SCIMOAuthAuth:

def __init__(self, provider: SCIMProvider):
self.provider = provider
self.user = provider.auth_oauth_user
self.logger = get_logger().bind()
self.connection = self.get_connection()

def retrieve_token(self):
if not self.provider.auth_oauth:
return None
def retrieve_token(self, conn: UserOAuthSourceConnection | None) -> dict[str, Any]:
source: OAuthSource = self.provider.auth_oauth
client = OAuth2Client(source, None)
client: BaseOAuthClient = source.source_type.callback_view(request=None).get_client(source)
access_token_url = source.source_type.access_token_url or ""
if source.source_type.urls_customizable and source.access_token_url:
access_token_url = source.access_token_url
data = client.get_access_token_args(None, None)
data["grant_type"] = "password"
if self.provider.auth_mode == SCIMAuthenticationMode.OAUTH_SILENT:
data["grant_type"] = GRANT_TYPE_PASSWORD
elif self.provider.auth_mode == SCIMAuthenticationMode.OAUTH_INTERACTIVE:
data["grant_type"] = GRANT_TYPE_REFRESH_TOKEN
if not conn:
raise SCIMOAuthException(None, "Could not refresh SCIM OAuth token")
data["refresh_token"] = conn.refresh_token
data.update(self.provider.auth_oauth_params)
try:
response = client.do_request(
Expand All @@ -54,19 +59,22 @@ def retrieve_token(self):
raise SCIMOAuthException(exc.response, message="Failed to get OAuth token") from exc

def get_connection(self):
token = UserOAuthSourceConnection.objects.filter(
source=self.provider.auth_oauth, user=self.user, expires__gt=now()
if not self.provider.auth_oauth:
return None
conn = UserOAuthSourceConnection.objects.filter(
source=self.provider.auth_oauth, user=self.user
).first()
if token and token.access_token:
return token
token = self.retrieve_token()
if conn and conn.access_token and conn.expires > now():
return conn
token = self.retrieve_token(conn)
access_token = token["access_token"]
expires_in = int(token.get("expires_in", 0))
token, _ = UserOAuthSourceConnection.objects.update_or_create(
source=self.provider.auth_oauth,
user=self.user,
defaults={
"access_token": access_token,
"refresh_token": token.get("refresh_token"),
"expires": now() + timedelta(seconds=expires_in),
},
)
Expand Down
5 changes: 4 additions & 1 deletion authentik/enterprise/providers/scim/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ def scim_provider_post_save(sender: type[Model], instance: SCIMProvider, created
"""Create service account before provider is saved"""
identifier = f"ak-providers-scim-{instance.pk}"
with audit_ignore():
if instance.auth_mode == SCIMAuthenticationMode.OAUTH:
if instance.auth_mode in [
SCIMAuthenticationMode.OAUTH_SILENT,
SCIMAuthenticationMode.OAUTH_INTERACTIVE,
]:
user, user_created = User.objects.update_or_create(
username=identifier,
defaults={
Expand Down
Empty file.
73 changes: 73 additions & 0 deletions authentik/enterprise/providers/scim/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""SCIM OAuth tests"""

from unittest.mock import MagicMock, PropertyMock, patch

from django.urls import reverse
from rest_framework.test import APITestCase

from authentik.core.tests.utils import create_test_admin_user
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import License
from authentik.enterprise.tests.test_license import expiry_valid
from authentik.lib.generators import generate_id
from authentik.sources.oauth.models import OAuthSource


class TestSCIMOAuthAPI(APITestCase):
"""SCIM User tests"""

def setUp(self):
self.source = OAuthSource.objects.create(
name=generate_id(),
slug=generate_id(),
access_token_url="http://localhost/token", # nosec
consumer_key=generate_id(),
consumer_secret=generate_id(),
provider_type="openidconnect",
)

@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
def test_api_create(self):
License.objects.create(key=generate_id())
self.client.force_login(create_test_admin_user())
res = self.client.post(
reverse("authentik_api:scimprovider-list"),
{
"name": generate_id(),
"url": "http://localhost",
"auth_mode": "oauth_silent",
"auth_oauth": str(self.source.pk),
},
)
self.assertEqual(res.status_code, 201)

@patch(
"authentik.enterprise.models.LicenseUsageStatus.is_valid",
PropertyMock(return_value=False),
)
def test_api_create_no_license(self):
self.client.force_login(create_test_admin_user())
res = self.client.post(
reverse("authentik_api:scimprovider-list"),
{
"name": generate_id(),
"url": "http://localhost",
"auth_mode": "oauth_silent",
"auth_oauth": str(self.source.pk),
},
)
self.assertEqual(res.status_code, 400)
self.assertJSONEqual(
res.content, {"auth_mode": ["Enterprise is required to use the OAuth mode."]}
)
100 changes: 100 additions & 0 deletions authentik/enterprise/providers/scim/tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""SCIM OAuth tests"""

from requests_mock import Mocker
from rest_framework.test import APITestCase

from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application, Group, User
from authentik.lib.generators import generate_id
from authentik.providers.scim.models import SCIMAuthenticationMode, SCIMMapping, SCIMProvider
from authentik.sources.oauth.models import OAuthSource
from authentik.tenants.models import Tenant


class TestSCIMOAuthAuth(APITestCase):
"""SCIM User tests"""

@apply_blueprint("system/providers-scim.yaml")
def setUp(self) -> None:
# Delete all users and groups as the mocked HTTP responses only return one ID
# which will cause errors with multiple users
Tenant.objects.update(avatars="none")
User.objects.all().exclude_anonymous().delete()
Group.objects.all().delete()
self.source = OAuthSource.objects.create(
name=generate_id(),
slug=generate_id(),
access_token_url="http://localhost/token", # nosec
consumer_key=generate_id(),
consumer_secret=generate_id(),
provider_type="openidconnect",
)
self.provider = SCIMProvider.objects.create(
name=generate_id(),
url="https://localhost",
auth_mode=SCIMAuthenticationMode.OAUTH_SILENT,
auth_oauth=self.source,
auth_oauth_params={
"foo": "bar",
},
exclude_users_service_account=True,
)
self.app: Application = Application.objects.create(
name=generate_id(),
slug=generate_id(),
)
self.app.backchannel_providers.add(self.provider)
self.provider.property_mappings.add(
SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/user")
)
self.provider.property_mappings_group.add(
SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/group")
)

@Mocker()
def test_user_create(self, mock: Mocker):
"""Test user creation"""
scim_id = generate_id()
token = generate_id()
mock.post("http://localhost/token", json={"access_token": token, "expires_in": 3600})
mock.get(
"https://localhost/ServiceProviderConfig",
json={},
)
mock.post(
"https://localhost/Users",
json={
"id": scim_id,
},
)
uid = generate_id()
user = User.objects.create(
username=uid,
name=f"{uid} {uid}",
email=f"{uid}@goauthentik.io",
)
self.assertEqual(mock.call_count, 3)
self.assertEqual(mock.request_history[1].method, "GET")
self.assertEqual(mock.request_history[2].method, "POST")
self.assertJSONEqual(
mock.request_history[2].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"active": True,
"emails": [
{
"primary": True,
"type": "other",
"value": f"{uid}@goauthentik.io",
}
],
"externalId": user.uid,
"name": {
"familyName": uid,
"formatted": f"{uid} {uid}",
"givenName": uid,
},
"displayName": f"{uid} {uid}",
"userName": uid,
},
)
Loading
Loading