Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions authentik/lib/sync/outgoing/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,12 @@ def _sync_cleanup(self, provider: OutgoingSyncProvider, task: Task):
client = provider.client_for_model(object_type)
except TransientSyncException:
continue
in_scope_pks = set(provider.get_object_qs(object_type).values_list("pk", flat=True))
# Pass the in-scope queryset as a subquery rather than materializing every
# PK into a Python set, which scales better and produces a single SQL
# statement rather than a NOT IN (...) clause with thousands of params.
in_scope_qs = provider.get_object_qs(object_type).values("pk")
stale = client.connection_type.objects.filter(provider=provider).exclude(
**{f"{client.connection_type_query}__pk__in": in_scope_pks}
**{f"{client.connection_type_query}__pk__in": in_scope_qs}
)
for connection in stale:
try:
Expand Down Expand Up @@ -256,11 +259,13 @@ def sync_signal_direct(
task.warning("No provider found. Is it assigned to an application?")
return
client = provider.client_for_model(instance.__class__)
# Check if the object is allowed within the provider's restrictions
queryset = provider.get_object_qs(instance.__class__, pk=instance.pk)
# Check if the object is allowed within the provider's restrictions.
# The queryset we get from the provider must include the instance we were given,
# otherwise ignore this provider
if not queryset or not queryset.exists():
# otherwise ignore this provider. We use .exists() rather than `not queryset`
# because `bool(queryset)` materializes the entire queryset (calls _fetch_all),
# which is wasteful when we only need to know whether any row matches.
queryset = provider.get_object_qs(instance.__class__, pk=instance.pk)
if not queryset.exists():
return

try:
Expand Down Expand Up @@ -377,11 +382,13 @@ def sync_signal_m2m(
task.warning("No provider found. Is it assigned to an application?")
return

# Check if the object is allowed within the provider's restrictions
queryset: QuerySet = provider.get_object_qs(Group, pk=group_pk)
# Check if the object is allowed within the provider's restrictions.
# The queryset we get from the provider must include the instance we were given,
# otherwise ignore this provider
if not queryset or not queryset.filter().exists():
# otherwise ignore this provider. We use .exists() rather than `not queryset`
# because `bool(queryset)` materializes the entire queryset (calls _fetch_all),
# which is wasteful when we only need to know whether any row matches.
queryset: QuerySet = provider.get_object_qs(Group, pk=group_pk)
if not queryset.exists():
return

client = provider.client_for_model(Group)
Expand Down
130 changes: 119 additions & 11 deletions authentik/providers/scim/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from uuid import uuid4

from django.db import models
from django.db.models import QuerySet
from django.db.models import Q, QuerySet
from django.templatetags.static import static
from django.utils.translation import gettext_lazy as _
from dramatiq.actor import Actor
Expand Down Expand Up @@ -191,21 +191,18 @@ def get_object_qs(self, type: type[User | Group], **kwargs) -> QuerySet[User | G
)

# Filter users by their access to the backchannel application if an application is set
# This handles both policy bindings and group_filters
if self.backchannel_application:
base = base.filter(
pk__in=[
user.pk
for user in base
if PolicyEngine(self.backchannel_application, user, None).build().passing
]
)
base = self._filter_users_by_application_access(base)
return base.order_by("pk")

if type == Group:
# Get queryset of all groups with consistent ordering
# according to the provider's settings
base = Group.objects.prefetch_related("scimprovidergroup_set").all().filter(**kwargs)
# according to the provider's settings.
# Note: prefetch_related on scimprovidergroup_set was previously used here
# but the data was never accessed during sync iteration; .write() does its
# own per-object connection lookup, and .values_list() (used by cleanup)
# bypasses prefetches entirely.
base = Group.objects.all().filter(**kwargs)

# Filter groups by group_filters if set
if self.group_filters.exists():
Expand All @@ -214,6 +211,117 @@ def get_object_qs(self, type: type[User | Group], **kwargs) -> QuerySet[User | G
return base.order_by("pk")
raise ValueError(f"Invalid type {type}")

def _filter_users_by_application_access(self, base: QuerySet[User]) -> QuerySet[User]:
    """Restrict ``base`` to users who may access the backchannel application.

    The enabled ``PolicyBinding`` rows targeting ``self.backchannel_application``
    are loaded once and split into two groups:

    * bindings that reference a ``Policy`` (``policy_id`` is set), and
    * static bindings (no policy, but a ``group_id`` or ``user_id``).

    If no binding references a policy, the static bindings are translated into
    a SQL filter by ``_apply_static_bindings`` and ``PolicyEngine`` is never
    instantiated. Otherwise every candidate user is evaluated with
    ``PolicyEngine``; in ``MODE_ALL`` the candidates are first narrowed by the
    static bindings.

    NOTE(review): the shortcuts below (no bindings -> everyone passes; static
    bindings pre-filter only in MODE_ALL) are meant to mirror PolicyEngine's
    own handling of static bindings -- confirm against PolicyEngine.

    :param base: user queryset to narrow down.
    :returns: ``base`` unchanged, or a filtered queryset derived from it.
    """
    from authentik.policies.models import PolicyBinding, PolicyEngineMode

    application = self.backchannel_application
    # Evaluate the bindings a single time; related rows are joined in up front.
    enabled_bindings = list(
        PolicyBinding.objects.filter(target=application, enabled=True)
        .select_related("group", "user", "policy")
        .order_by("order")
    )

    # Nothing bound to the application: no restriction is applied.
    if not enabled_bindings:
        return base

    # Classify each binding: policy-backed ones need per-user evaluation,
    # group/user ones can be expressed in SQL.
    needs_policy_evaluation = False
    static_bindings = []
    for binding in enabled_bindings:
        if binding.policy_id is not None:
            needs_policy_evaluation = True
        elif binding.group_id or binding.user_id:
            static_bindings.append(binding)

    if not needs_policy_evaluation:
        # SQL-only path: no PolicyEngine instances are created here.
        if static_bindings:
            return self._apply_static_bindings(
                base, static_bindings, application.policy_engine_mode
            )
        # Bindings exist, but none carries a policy, group or user.
        return base

    # Per-user path: Policy objects have to be evaluated for each user.
    # Static bindings are only used to shrink the candidate set in MODE_ALL;
    # in any other mode all users in ``base`` are evaluated.
    candidates = base
    if application.policy_engine_mode == PolicyEngineMode.MODE_ALL and static_bindings:
        candidates = self._apply_static_bindings(
            base, static_bindings, application.policy_engine_mode
        )

    passing_pks = []
    for user in candidates.iterator():
        if PolicyEngine(application, user, None).build().passing:
            passing_pks.append(user.pk)
    return base.filter(pk__in=passing_pks)

@staticmethod
def _apply_static_bindings(base: QuerySet[User], bindings, mode) -> QuerySet[User]:
    """Filter ``base`` by static (group/user) bindings without leaving SQL.

    Each binding is converted to a ``Q`` expression by ``_binding_to_q`` and
    the expressions are combined according to ``mode``:

    * ``MODE_ALL``: one ``.filter()`` call per binding, so a user has to
      satisfy every binding. Separate ``.filter()`` calls are used rather
      than AND-ing the expressions into a single call so that each
      multi-valued ``groups`` lookup is matched independently.
    * any other mode (including ``MODE_ANY``): the expressions are OR-ed
      together and applied in a single ``.filter()`` call, so matching one
      binding is enough.

    ``.distinct()`` is applied to the result in both cases because the
    ``groups`` joins may yield the same user more than once.

    :param base: user queryset to narrow down.
    :param bindings: static bindings to apply; an empty sequence leaves
        ``base`` untouched.
    :param mode: the application's policy engine mode.
    :returns: the filtered (and de-duplicated) queryset, or ``base`` itself
        when there are no bindings.
    """
    from authentik.policies.models import PolicyEngineMode

    if not bindings:
        return base

    predicates = [SCIMProvider._binding_to_q(binding) for binding in bindings]
    require_all = mode == PolicyEngineMode.MODE_ALL

    if not require_all:
        # OR-combine: passing any single binding admits the user.
        any_match = predicates[0]
        for predicate in predicates[1:]:
            any_match = any_match | predicate
        return base.filter(any_match).distinct()

    # AND via chained filters: every binding must match on its own.
    narrowed = base
    for predicate in predicates:
        narrowed = narrowed.filter(predicate)
    return narrowed.distinct()

@staticmethod
def _binding_to_q(binding) -> Q:
    """Build the ``Q`` expression selecting users that satisfy ``binding``.

    * A binding with a ``user_id`` matches exactly that user by primary key.
    * Otherwise the binding is treated as a group binding: it matches users
      whose ``groups`` include the bound group or any group returned by
      ``with_descendants()`` for it.
    * If ``binding.negate`` is set, the expression is inverted.

    NOTE(review): this is intended to agree with how PolicyEngine decides
    whether a static binding passes -- confirm against PolicyEngine.

    :param binding: a static binding exposing ``user_id``, ``group_id`` and
        ``negate``.
    :returns: a ``Q`` object usable in ``User`` queryset filters.
    """
    if not binding.user_id:
        # Group binding: the bound group plus everything beneath it.
        bound_groups = Group.objects.filter(pk=binding.group_id).with_descendants()
        condition = Q(groups__in=bound_groups)
    else:
        condition = Q(pk=binding.user_id)
    return ~condition if binding.negate else condition

@classmethod
def get_object_mappings(cls, obj: User | Group) -> list[tuple[str, str]]:
if isinstance(obj, User):
Expand Down
Loading
Loading