From 195aa3c0a30d67a78683c60240a1cf70500ceb1d Mon Sep 17 00:00:00 2001 From: Ryan Pesek Date: Fri, 1 May 2026 09:35:19 -0500 Subject: [PATCH] fix per-user PolicyEngine evaluation in get_object_qs, fix other related inefficiencies for 2026.2 --- authentik/lib/sync/outgoing/tasks.py | 27 ++- authentik/providers/scim/models.py | 130 ++++++++++- .../scim/tests/test_application_policies.py | 202 +++++++++++++++++- 3 files changed, 337 insertions(+), 22 deletions(-) diff --git a/authentik/lib/sync/outgoing/tasks.py b/authentik/lib/sync/outgoing/tasks.py index fd1140d2f1aa..32e755792c2a 100644 --- a/authentik/lib/sync/outgoing/tasks.py +++ b/authentik/lib/sync/outgoing/tasks.py @@ -119,9 +119,12 @@ def _sync_cleanup(self, provider: OutgoingSyncProvider, task: Task): client = provider.client_for_model(object_type) except TransientSyncException: continue - in_scope_pks = set(provider.get_object_qs(object_type).values_list("pk", flat=True)) + # Pass the in-scope queryset as a subquery rather than materializing every + # PK into a Python set, which scales better and produces a single SQL + # statement rather than a NOT IN (...) clause with thousands of params. + in_scope_qs = provider.get_object_qs(object_type).values("pk") stale = client.connection_type.objects.filter(provider=provider).exclude( - **{f"{client.connection_type_query}__pk__in": in_scope_pks} + **{f"{client.connection_type_query}__pk__in": in_scope_qs} ) for connection in stale: try: @@ -256,11 +259,13 @@ def sync_signal_direct( task.warning("No provider found. Is it assigned to an application?") return client = provider.client_for_model(instance.__class__) - # Check if the object is allowed within the provider's restrictions - queryset = provider.get_object_qs(instance.__class__, pk=instance.pk) + # Check if the object is allowed within the provider's restrictions. 
# The queryset we get from the provider must include the instance we've got given - # otherwise ignore this provider - if not queryset or not queryset.exists(): + # otherwise ignore this provider. We use .exists() rather than `not queryset` + # because `bool(queryset)` materializes the entire queryset (calls _fetch_all), + # which is wasteful when we only need to know whether any row matches. + queryset = provider.get_object_qs(instance.__class__, pk=instance.pk) + if not queryset.exists(): return try: @@ -377,11 +382,13 @@ def sync_signal_m2m( task.warning("No provider found. Is it assigned to an application?") return - # Check if the object is allowed within the provider's restrictions - queryset: QuerySet = provider.get_object_qs(Group, pk=group_pk) + # Check if the object is allowed within the provider's restrictions. # The queryset we get from the provider must include the instance we've got given - # otherwise ignore this provider - if not queryset or not queryset.filter().exists(): + # otherwise ignore this provider. We use .exists() rather than `not queryset` + # because `bool(queryset)` materializes the entire queryset (calls _fetch_all), + # which is wasteful when we only need to know whether any row matches. 
+ queryset: QuerySet = provider.get_object_qs(Group, pk=group_pk) + if not queryset.exists(): return client = provider.client_for_model(Group) diff --git a/authentik/providers/scim/models.py b/authentik/providers/scim/models.py index 0339b3d776c7..4eb70645381c 100644 --- a/authentik/providers/scim/models.py +++ b/authentik/providers/scim/models.py @@ -4,7 +4,7 @@ from uuid import uuid4 from django.db import models -from django.db.models import QuerySet +from django.db.models import Q, QuerySet from django.templatetags.static import static from django.utils.translation import gettext_lazy as _ from dramatiq.actor import Actor @@ -191,21 +191,18 @@ def get_object_qs(self, type: type[User | Group], **kwargs) -> QuerySet[User | G ) # Filter users by their access to the backchannel application if an application is set - # This handles both policy bindings and group_filters if self.backchannel_application: - base = base.filter( - pk__in=[ - user.pk - for user in base - if PolicyEngine(self.backchannel_application, user, None).build().passing - ] - ) + base = self._filter_users_by_application_access(base) return base.order_by("pk") if type == Group: # Get queryset of all groups with consistent ordering - # according to the provider's settings - base = Group.objects.prefetch_related("scimprovidergroup_set").all().filter(**kwargs) + # according to the provider's settings. + # Note: prefetch_related on scimprovidergroup_set was previously used here + # but the data was never accessed during sync iteration; .write() does its + # own per-object connection lookup, and cleanup only embeds this queryset + # as a .values("pk") subquery, which never triggers a prefetch.
+ base = Group.objects.all().filter(**kwargs) # Filter groups by group_filters if set if self.group_filters.exists(): @@ -214,6 +211,117 @@ def get_object_qs(self, type: type[User | Group], **kwargs) -> QuerySet[User | G return base.order_by("pk") raise ValueError(f"Invalid type {type}") + def _filter_users_by_application_access(self, base: QuerySet[User]) -> QuerySet[User]: + """Filter users by access to the backchannel application. + + This replicates ``PolicyEngine``'s static binding logic but applies it + as a single SQL filter when no ``Policy`` objects are bound. This avoids + the previous behavior of instantiating ``PolicyEngine`` once per user + across the entire database, which was O(N) per page (and called once + per page across N/page_size tasks, plus on cleanup), producing + catastrophic performance and millions of "P_ENG: Found static bindings" + log lines for instances with thousands of users. + """ + from authentik.policies.models import PolicyBinding, PolicyEngineMode + + app = self.backchannel_application + # Pull all bindings once (with select_related to avoid N+1 on .group/.user/.policy) + bindings = list( + PolicyBinding.objects.filter(target=app, enabled=True) + .select_related("group", "user", "policy") + .order_by("order") + ) + + # No bindings configured: PolicyEngine returns empty_result=True, so all users pass + if not bindings: + return base + + has_policy_bindings = any(b.policy_id is not None for b in bindings) + static_relevant = [b for b in bindings if b.policy_id is None and (b.group_id or b.user_id)] + + if not has_policy_bindings: + # Fast path: pure static bindings -> SQL only, zero PolicyEngine instantiations + if not static_relevant: + # Bindings exist but none are static-relevant. PolicyEngine would + # return empty_result=True and let policy bindings decide; with no + # policy bindings here, all users pass. 
+ return base + return self._apply_static_bindings(base, static_relevant, app.policy_engine_mode) + + # Slow path: there are real Policy objects that need per-user evaluation. + # We can't replicate Policy logic in SQL (it's user-defined Python). However, + # we can shrink the candidate set using static bindings first when the policy + # engine mode allows it (MODE_ALL: all must pass, so static-failing users + # can never pass overall). For MODE_ANY we must check every user since a + # user might pass via policy alone. + if app.policy_engine_mode == PolicyEngineMode.MODE_ALL and static_relevant: + candidates = self._apply_static_bindings(base, static_relevant, app.policy_engine_mode) + else: + candidates = base + return base.filter( + pk__in=[ + user.pk + for user in candidates.iterator() + if PolicyEngine(app, user, None).build().passing + ] + ) + + @staticmethod + def _apply_static_bindings(base: QuerySet[User], bindings, mode) -> QuerySet[User]: + """Apply static (group/user) bindings to ``base`` using SQL only. + + Mirrors :meth:`PolicyEngine.compute_static_bindings` but operates on + the whole user queryset at once: + + * MODE_ANY (default): a user passes if ANY binding matches -> combine + per-binding Q expressions with OR (a single M2M JOIN is sufficient + because we're checking whether the user has *any* group matching + *any* binding). + * MODE_ALL: a user passes only if EVERY binding matches -> chain + ``.filter()`` calls so each M2M reference becomes a *separate* JOIN. + Combining ``Q(groups__in=g1) & Q(groups__in=g2)`` would always be + False because a single through-table row can only have one + ``group_id``; chained filters force Django to add a fresh alias for + each predicate. 
+ """ + from authentik.policies.models import PolicyEngineMode + + if not bindings: + return base + + if mode == PolicyEngineMode.MODE_ALL: + qs = base + for binding in bindings: + qs = qs.filter(SCIMProvider._binding_to_q(binding)) + # Each group filter JOINs the membership table, so a user who is in + # several of a binding's matching groups comes back once per match; + # .distinct() collapses those duplicate rows. + return qs.distinct() + + # MODE_ANY (and any unknown mode) -> OR-combine per-binding Q + combined = SCIMProvider._binding_to_q(bindings[0]) + for binding in bindings[1:]: + combined |= SCIMProvider._binding_to_q(binding) + return base.filter(combined).distinct() + + @staticmethod + def _binding_to_q(binding) -> Q: + """Translate a single static :class:`PolicyBinding` into a Q expression + matching users for whom the binding "passes" (per + :meth:`PolicyEngine.compute_static_bindings`). + """ + if binding.user_id: + match = Q(pk=binding.user_id) + else: + # Group binding: match users in the bound group OR any descendant. + # "binding.group in user.all_groups()" is equivalent to + # "user is in binding.group.with_descendants()".
+ descendants = Group.objects.filter(pk=binding.group_id).with_descendants() + match = Q(groups__in=descendants) + if binding.negate: + match = ~match + return match + @classmethod def get_object_mappings(cls, obj: User | Group) -> list[tuple[str, str]]: if isinstance(obj, User): diff --git a/authentik/providers/scim/tests/test_application_policies.py b/authentik/providers/scim/tests/test_application_policies.py index 4b8a03276492..46e554fb829e 100644 --- a/authentik/providers/scim/tests/test_application_policies.py +++ b/authentik/providers/scim/tests/test_application_policies.py @@ -1,11 +1,14 @@ """SCIM Application Policies tests""" +from unittest.mock import patch + from django.test import TestCase from authentik.blueprints.tests import apply_blueprint from authentik.core.models import Application, Group, User from authentik.lib.generators import generate_id -from authentik.policies.models import PolicyBinding +from authentik.policies.dummy.models import DummyPolicy +from authentik.policies.models import PolicyBinding, PolicyEngineMode from authentik.providers.scim.models import SCIMMapping, SCIMProvider from authentik.tenants.models import Tenant @@ -88,3 +91,200 @@ def test_multiple_group_policies(self): set([self.users[1].pk, self.users[2].pk, self.users[4].pk]), set(user_qs.values_list("pk", flat=True)), ) + + def test_user_binding(self): + """Test with a direct user binding (no group, no policy)""" + PolicyBinding.objects.create(target=self.app, user=self.users[3], order=0) + + user_qs = self.provider.get_object_qs(User) + + self.assertEqual( + set([self.users[3].pk]), + set(user_qs.values_list("pk", flat=True)), + ) + + def test_user_and_group_binding_mode_any(self): + """Test with mixed user + group bindings, MODE_ANY (default)""" + PolicyBinding.objects.create(target=self.app, group=self.group1, order=0) + PolicyBinding.objects.create(target=self.app, user=self.users[3], order=1) + + user_qs = self.provider.get_object_qs(User) + + # Users 1, 4 in group1 
+ user 3 directly bound + self.assertEqual( + set([self.users[1].pk, self.users[3].pk, self.users[4].pk]), + set(user_qs.values_list("pk", flat=True)), + ) + + def test_negated_group_binding(self): + """Test with a negated group binding (block users in group1)""" + PolicyBinding.objects.create(target=self.app, group=self.group1, negate=True, order=0) + + user_qs = self.provider.get_object_qs(User) + + # Users NOT in group1: user 2 (group2 only) and user 3 (no groups) + # Users 1 and 4 are excluded because they're in group1 + self.assertEqual( + set([self.users[2].pk, self.users[3].pk]), + set(user_qs.values_list("pk", flat=True)), + ) + + def test_hierarchical_group_binding(self): + """Test that group binding includes users in descendant groups (via group.parents)""" + # Make group1 a parent of group2: users in group2 should match a binding on group1 + self.group2.parents.add(self.group1) + + PolicyBinding.objects.create(target=self.app, group=self.group1, order=0) + + user_qs = self.provider.get_object_qs(User) + + # User 1: directly in group1 -> match + # User 2: in group2 (descendant of group1) -> match + # User 3: no groups -> no match + # User 4: in group1 and group2 -> match + self.assertEqual( + set([self.users[1].pk, self.users[2].pk, self.users[4].pk]), + set(user_qs.values_list("pk", flat=True)), + ) + + def test_mode_all_two_group_bindings(self): + """MODE_ALL: user must be in ALL bound groups to sync""" + self.app.policy_engine_mode = PolicyEngineMode.MODE_ALL + self.app.save() + + PolicyBinding.objects.create(target=self.app, group=self.group1, order=0) + PolicyBinding.objects.create(target=self.app, group=self.group2, order=1) + + user_qs = self.provider.get_object_qs(User) + + # Only user 4 is in BOTH group1 AND group2 + self.assertEqual( + set([self.users[4].pk]), + set(user_qs.values_list("pk", flat=True)), + ) + + def test_disabled_binding_ignored(self): + """Disabled bindings should not affect the queryset""" + 
PolicyBinding.objects.create(target=self.app, group=self.group1, enabled=False, order=0) + + user_qs = self.provider.get_object_qs(User) + + # Disabled binding -> treated like no bindings -> all users sync + self.assertEqual( + set([self.users[1].pk, self.users[2].pk, self.users[3].pk, self.users[4].pk]), + set(user_qs.values_list("pk", flat=True)), + ) + + def test_static_bindings_dont_invoke_policy_engine_per_user(self): + """Performance regression test (PR #13947 / issue: 15-20 min per task). + + When the application's bindings are purely static (group/user, no Policy + object), ``get_object_qs`` must filter via SQL and NOT instantiate + ``PolicyEngine`` per user. The previous implementation iterated every + user in the database and built a ``PolicyEngine`` per user, which on a + production instance with ~6000 users multiplied across ~60 paginated + sync tasks produced ~360,000 PolicyEngine evaluations per provider per + full sync. + """ + PolicyBinding.objects.create(target=self.app, group=self.group1, order=0) + + # Add a few more users to make any per-user iteration obvious + for _ in range(20): + uid = generate_id() + user = User.objects.create(username=uid, name=uid, email=f"{uid}@goauthentik.io") + user.ak_groups.add(self.group1) + + with patch("authentik.providers.scim.models.PolicyEngine") as mock_policy_engine: + # Materialize the queryset + list(self.provider.get_object_qs(User)) + + self.assertEqual( + mock_policy_engine.call_count, + 0, + ( + "PolicyEngine was instantiated " + f"{mock_policy_engine.call_count} times during get_object_qs(). " + "Static bindings (no Policy object) must be evaluated as a " + "single SQL filter, not per-user. This is the regression that " + "caused 15-20 min sync tasks for instances with thousands of " + "users." + ), + ) + + def test_static_bindings_query_count_independent_of_user_count(self): + """Performance regression test: query count for ``get_object_qs`` must + not scale with the number of users in the database. 
+ + Construction issues O(1) queries (one to fetch the application's + bindings); materialization is one further query. Crucially, neither + count grows with the user count -- the previous (broken) implementation + ran a PolicyEngine evaluation per user, multiplying the query count by + the user count. + """ + PolicyBinding.objects.create(target=self.app, group=self.group1, order=0) + + # Baseline with the existing 4 users + from django.db import connection + from django.test.utils import CaptureQueriesContext + + with CaptureQueriesContext(connection) as ctx_small: + list(self.provider.get_object_qs(User)) + baseline_queries = len(ctx_small.captured_queries) + + # Add many more users to the same group + for _ in range(50): + uid = generate_id() + user = User.objects.create(username=uid, name=uid, email=f"{uid}@goauthentik.io") + user.ak_groups.add(self.group1) + + with CaptureQueriesContext(connection) as ctx_large: + list(self.provider.get_object_qs(User)) + + self.assertEqual( + len(ctx_large.captured_queries), + baseline_queries, + ( + f"Query count grew from {baseline_queries} to " + f"{len(ctx_large.captured_queries)} after adding 50 users. " + "Query count must be O(1) with respect to user count for " + "static bindings." + ), + ) + + def test_actual_policy_binding_falls_back_to_per_user_evaluation(self): + """When a real ``Policy`` is bound, we cannot translate it to SQL and + must fall back to per-user evaluation. 
Verify correctness is preserved + on the slow path (this test does NOT assert performance).""" + # Create a policy that always passes + policy = DummyPolicy.objects.create(name=generate_id(), result=True, wait_min=0, wait_max=1) + PolicyBinding.objects.create(target=self.app, policy=policy, order=0) + + user_qs = self.provider.get_object_qs(User) + + # Policy passes for everyone -> all users sync + self.assertEqual( + set([self.users[1].pk, self.users[2].pk, self.users[3].pk, self.users[4].pk]), + set(user_qs.values_list("pk", flat=True)), + ) + + def test_mixed_static_and_policy_binding_mode_all_prefilters(self): + """When MODE_ALL is set with both a static binding and a real policy, + the static binding is used as a SQL pre-filter to shrink the candidate + set before per-user policy evaluation.""" + self.app.policy_engine_mode = PolicyEngineMode.MODE_ALL + self.app.save() + + # Static binding restricts to group1 (users 1, 4) + PolicyBinding.objects.create(target=self.app, group=self.group1, order=0) + # Policy that always passes + policy = DummyPolicy.objects.create(name=generate_id(), result=True, wait_min=0, wait_max=1) + PolicyBinding.objects.create(target=self.app, policy=policy, order=1) + + user_qs = self.provider.get_object_qs(User) + + # MODE_ALL: must pass both static (in group1) AND policy (always True) + # -> users 1 and 4 + self.assertEqual( + set([self.users[1].pk, self.users[4].pk]), + set(user_qs.values_list("pk", flat=True)), + )