From def5a2b6cb8f8fa1902632614dd4df82eadad6ad Mon Sep 17 00:00:00 2001 From: Sean 'Shaleh' Perry Date: Wed, 7 Aug 2024 09:25:22 -0700 Subject: [PATCH 01/15] Improve multiple database support. The token models might not be stored in the default database. There might not _be_ a default database. Intead, the code now relies on Django's routers to determine the actual database to use when creating transactions. This required moving from decorators to context managers for those transactions. To test the multiple database scenario a new settings file as added which derives from settings.py and then defines different databases and the routers needed to access them. The commit is larger than might be expected because when there are multiple databases the Django tests have to be told which databases to work on. Rather than copying the various test cases or making multiple database specific ones the decision was made to add wrappers around the standard Django TestCase classes and programmatically define the databases for them. This enables all of the same test code to work for both the one database and the multi database scenarios with minimal maintenance costs. A tox environment that uses the multi db settings file has been added to ensure both scenarios are always tested. --- oauth2_provider/models.py | 24 ++++++++--- oauth2_provider/oauth2_validators.py | 13 ++++-- tests/common_testing.py | 35 +++++++++++++++ tests/db_router.py | 51 ++++++++++++++++++++++ tests/multi_db_settings.py | 19 +++++++++ tests/test_application_views.py | 2 +- tests/test_auth_backends.py | 4 +- tests/test_authorization_code.py | 3 +- tests/test_client_credential.py | 3 +- tests/test_commands.py | 2 +- tests/test_decorators.py | 4 +- tests/test_generator.py | 3 +- tests/test_hybrid.py | 8 ++-- tests/test_implicit.py | 3 +- tests/test_introspection_auth.py | 3 +- tests/test_introspection_view.py | 5 ++- tests/test_mixins.py | 3 +- tests/test_models.py | 15 ++++--- tests/test_oauth2_backends.py | 3 +- tests/test_oauth2_validators.py | 10 +++-- tests/test_oidc_views.py | 64 ++++++++++++++-------------- tests/test_password.py | 3 +- tests/test_rest_framework.py | 2 +- tests/test_scopes.py | 3 +- tests/test_settings.py | 2 +- tests/test_token_endpoint_cors.py | 3 +- tests/test_token_revocation.py | 4 +- tests/test_token_view.py | 3 +- tests/test_validators.py | 3 +- tox.ini | 7 +++ 30 files changed, 231 insertions(+), 76 deletions(-) create mode 100644 tests/common_testing.py create mode 100644 tests/db_router.py create mode 100644 tests/multi_db_settings.py diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index f979eef1c..2ec12b153 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -2,6 +2,7 @@ import logging import time import uuid +from contextlib import suppress from datetime import timedelta from urllib.parse import parse_qsl, urlparse @@ -9,7 +10,7 @@ from django.conf import settings from django.contrib.auth.hashers import identify_hasher, make_password from django.core.exceptions import ImproperlyConfigured -from django.db import models, transaction +from django.db import models, router, transaction from django.urls import reverse from django.utils import timezone from django.utils.translation import gettext_lazy as _ @@ -513,16 +514,27 @@ def revoke(self): """ access_token_model = get_access_token_model() refresh_token_model = get_refresh_token_model() - with transaction.atomic(): + + access_token_database = router.db_for_write(access_token_model) + refresh_token_database = router.db_for_write(refresh_token_model) + + # This is highly unlikely, but let's warn people just in case it does. + if access_token_database != refresh_token_database: + logger.warning( + "access token and refresh token are in separate databases but a transaction" + " is only used for the access token" + ) + + # Use the access_token_database instead of making the assumption it is in 'default'. + with transaction.atomic(using=access_token_database): token = refresh_token_model.objects.select_for_update().filter(pk=self.pk, revoked__isnull=True) if not token: return self = list(token)[0] - try: - access_token_model.objects.get(pk=self.access_token_id).revoke() - except access_token_model.DoesNotExist: - pass + with suppress(access_token_model.DoesNotExist): + access_token_model.objects.get(id=self.access_token_id).revoke() + self.access_token = None self.revoked = timezone.now() self.save() diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index 78667fa0e..808b02ae2 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -15,7 +15,7 @@ from django.contrib.auth import authenticate, get_user_model from django.contrib.auth.hashers import check_password, identify_hasher from django.core.exceptions import ObjectDoesNotExist -from django.db import transaction +from django.db import router, transaction from django.http import HttpRequest from django.utils import dateformat, timezone from django.utils.crypto import constant_time_compare @@ -562,8 +562,12 @@ def rotate_refresh_token(self, request): """ return oauth2_settings.ROTATE_REFRESH_TOKEN - @transaction.atomic def save_bearer_token(self, token, request, *args, **kwargs): + # Use the AccessToken's database instead of making the assumption it is in 'default'. + with transaction.atomic(using=router.db_for_write(AccessToken)): + return self._save_bearer_token_internals(token, request, *args, **kwargs) + + def _save_bearer_token_internals(self, token, request, *args, **kwargs): """ Save access and refresh token, If refresh token is issued, remove or reuse old refresh token as in rfc:`6` @@ -788,7 +792,6 @@ def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs return rt.application == client - @transaction.atomic def _save_id_token(self, jti, request, expires, *args, **kwargs): scopes = request.scope or " ".join(request.scopes) @@ -889,7 +892,9 @@ def finalize_id_token(self, id_token, token, token_handler, request): claims=json.dumps(id_token, default=str), ) jwt_token.make_signed_token(request.client.jwk_key) - id_token = self._save_id_token(id_token["jti"], request, expiration_time) + # Use the IDToken's database instead of making the assumption it is in 'default'. + with transaction.atomic(using=router.db_for_write(IDToken)): + id_token = self._save_id_token(id_token["jti"], request, expiration_time) # this is needed by django rest framework request.access_token = id_token request.id_token = id_token diff --git a/tests/common_testing.py b/tests/common_testing.py new file mode 100644 index 000000000..daffd056f --- /dev/null +++ b/tests/common_testing.py @@ -0,0 +1,35 @@ +from django.conf import settings +from django.test import TestCase as DjangoTestCase +from django.test import TransactionTestCase as DjangoTransactionTestCase + + +class OAuth2ProviderTestCase(DjangoTestCase): + """Place holder to allow overriding behaviors.""" + + +class OAuth2ProviderTransactionTestCase(DjangoTransactionTestCase): + """Place holder to allow overriding behaviors.""" + + +if len(settings.DATABASES) > 1: + # There are multiple databases defined. When this happens Django tests will not + # work unless they are told which database(s) to work with. The multiple + # database scenario setup for these tests purposefully defines 'default' as an + # empty database in order to catch any assumptions in this package about database + # names and in particular to ensure there is no assumption that 'default' is a + # valid database. + # For any test that would usually use Django's TestCase or TransactionTestCase + # using the classes defined here is all that is required. + # Any test that uses pytest's django_db need to base in a databases parameter + # using this definition of test_database_names. + # In test code, anywhere the default database is used the variable + # database_for_oauth2_provider must be used in its place. For instance, + # with self.assertNumQueries(1, using=database_for_oauth2_provider): + # without the using option this fails because default is used. + test_database_names = {name for name in settings.DATABASES if name != "default"} + database_for_oauth2_provider = "alpha" + OAuth2ProviderTestCase.databases = test_database_names + OAuth2ProviderTransactionTestCase.databases = test_database_names +else: + test_database_names = {"default"} + database_for_oauth2_provider = "default" diff --git a/tests/db_router.py b/tests/db_router.py new file mode 100644 index 000000000..461c60ef3 --- /dev/null +++ b/tests/db_router.py @@ -0,0 +1,51 @@ +apps_in_beta = {"some_other_app", "this_one_too"} + +# These are bare minimum routers to fake the scenario where there is actually a +# decision around where an application's models might live. +# alpha is where the core Django models are stored including user. To keep things +# simple this is where the oauth2 provider models are stored as well because they +# have a foreign key to User. + + +class AlphaRouter: + def db_for_read(self, model, **hints): + if model._meta.app_label not in apps_in_beta: + return "alpha" + return None + + def db_for_write(self, model, **hints): + if model._meta.app_label not in apps_in_beta: + return "alpha" + return None + + def allow_relation(self, obj1, obj2, **hints): + if obj1._state.db == "alpha" and obj2._state.db == "alpha": + return True + return None + + def allow_migrate(self, db, app_label, model_name=None, **hints): + if app_label not in apps_in_beta: + return db == "alpha" + return None + + +class BetaRouter: + def db_for_read(self, model, **hints): + if model._meta.app_label in apps_in_beta: + return "beta" + return None + + def db_for_write(self, model, **hints): + if model._meta.app_label in apps_in_beta: + return "beta" + return None + + def allow_relation(self, obj1, obj2, **hints): + if obj1._state.db == "beta" and obj2._state.db == "beta": + return True + return None + + def allow_migrate(self, db, app_label, model_name=None, **hints): + if app_label in apps_in_beta: + return db == "beta" + return None diff --git a/tests/multi_db_settings.py b/tests/multi_db_settings.py new file mode 100644 index 000000000..a6daf04a3 --- /dev/null +++ b/tests/multi_db_settings.py @@ -0,0 +1,19 @@ +# Import the test settings and then override DATABASES. + +from .settings import * # noqa: F401, F403 + + +DATABASES = { + "alpha": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": ":memory:", + }, + "beta": { + "ENGINE": "django.db.backends.sqlite3", + "NAME": ":memory:", + }, + # As https://docs.djangoproject.com/en/4.2/topics/db/multi-db/#defining-your-databases + # indicates, it is ok to have no default database. + "default": {}, +} +DATABASE_ROUTERS = ["tests.db_router.AlphaRouter", "tests.db_router.BetaRouter"] diff --git a/tests/test_application_views.py b/tests/test_application_views.py index c8c145d9b..88617807d 100644 --- a/tests/test_application_views.py +++ b/tests/test_application_views.py @@ -1,11 +1,11 @@ import pytest from django.contrib.auth import get_user_model -from django.test import TestCase from django.urls import reverse from oauth2_provider.models import get_application_model from oauth2_provider.views.application import ApplicationRegistration +from .common_testing import OAuth2ProviderTestCase as TestCase from .models import SampleApplication diff --git a/tests/test_auth_backends.py b/tests/test_auth_backends.py index b0ff145ab..49729b1c4 100644 --- a/tests/test_auth_backends.py +++ b/tests/test_auth_backends.py @@ -5,7 +5,7 @@ from django.contrib.auth.models import AnonymousUser from django.core.exceptions import SuspiciousOperation from django.http import HttpResponse -from django.test import RequestFactory, TestCase +from django.test import RequestFactory from django.test.utils import modify_settings, override_settings from django.utils.timezone import now, timedelta @@ -13,6 +13,8 @@ from oauth2_provider.middleware import OAuth2ExtraTokenMiddleware, OAuth2TokenMiddleware from oauth2_provider.models import get_access_token_model, get_application_model +from .common_testing import OAuth2ProviderTestCase as TestCase + UserModel = get_user_model() ApplicationModel = get_application_model() diff --git a/tests/test_authorization_code.py b/tests/test_authorization_code.py index ae6e7e76e..122474950 100644 --- a/tests/test_authorization_code.py +++ b/tests/test_authorization_code.py @@ -7,7 +7,7 @@ import pytest from django.conf import settings from django.contrib.auth import get_user_model -from django.test import RequestFactory, TestCase +from django.test import RequestFactory from django.urls import reverse from django.utils import timezone from django.utils.crypto import get_random_string @@ -23,6 +23,7 @@ from oauth2_provider.views import ProtectedResourceView from . import presets +from .common_testing import OAuth2ProviderTestCase as TestCase from .utils import get_basic_auth_header diff --git a/tests/test_client_credential.py b/tests/test_client_credential.py index 4c6e384d0..3572f432d 100644 --- a/tests/test_client_credential.py +++ b/tests/test_client_credential.py @@ -4,7 +4,7 @@ import pytest from django.contrib.auth import get_user_model from django.core.exceptions import SuspiciousOperation -from django.test import RequestFactory, TestCase +from django.test import RequestFactory from django.urls import reverse from django.views.generic import View from oauthlib.oauth2 import BackendApplicationServer @@ -16,6 +16,7 @@ from oauth2_provider.views.mixins import OAuthLibMixin from . import presets +from .common_testing import OAuth2ProviderTestCase as TestCase from .utils import get_basic_auth_header diff --git a/tests/test_commands.py b/tests/test_commands.py index 8861f5698..5204ebf77 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -5,11 +5,11 @@ from django.contrib.auth.hashers import check_password from django.core.management import call_command from django.core.management.base import CommandError -from django.test import TestCase from oauth2_provider.models import get_application_model from . import presets +from .common_testing import OAuth2ProviderTestCase as TestCase Application = get_application_model() diff --git a/tests/test_decorators.py b/tests/test_decorators.py index a8ee788b5..f91ada2ac 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -1,12 +1,14 @@ from datetime import timedelta from django.contrib.auth import get_user_model -from django.test import RequestFactory, TestCase +from django.test import RequestFactory from django.utils import timezone from oauth2_provider.decorators import protected_resource, rw_protected_resource from oauth2_provider.models import get_access_token_model, get_application_model +from .common_testing import OAuth2ProviderTestCase as TestCase + Application = get_application_model() AccessToken = get_access_token_model() diff --git a/tests/test_generator.py b/tests/test_generator.py index cc7928017..201200b00 100644 --- a/tests/test_generator.py +++ b/tests/test_generator.py @@ -1,8 +1,9 @@ import pytest -from django.test import TestCase from oauth2_provider.generators import BaseHashGenerator, generate_client_id, generate_client_secret +from .common_testing import OAuth2ProviderTestCase as TestCase + class MockHashGenerator(BaseHashGenerator): def hash(self): diff --git a/tests/test_hybrid.py b/tests/test_hybrid.py index 40cd8c56f..204be7671 100644 --- a/tests/test_hybrid.py +++ b/tests/test_hybrid.py @@ -5,7 +5,7 @@ import pytest from django.contrib.auth import get_user_model -from django.test import RequestFactory, TestCase +from django.test import RequestFactory from django.urls import reverse from django.utils import timezone from jwcrypto import jwt @@ -21,6 +21,8 @@ from oauth2_provider.views import ProtectedResourceView, ScopedProtectedResourceView from . import presets +from .common_testing import OAuth2ProviderTestCase as TestCase +from .common_testing import test_database_names from .utils import get_basic_auth_header, spy_on @@ -1318,7 +1320,7 @@ def test_pre_auth_default_scopes(self): self.assertEqual(form["client_id"].value(), self.application.client_id) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_id_token_nonce_in_token_response(oauth2_settings, test_user, hybrid_application, client, oidc_key): client.force_login(test_user) @@ -1367,7 +1369,7 @@ def test_id_token_nonce_in_token_response(oauth2_settings, test_user, hybrid_app assert claims["nonce"] == "random_nonce_string" -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_claims_passed_to_code_generation( oauth2_settings, test_user, hybrid_application, client, mocker, oidc_key diff --git a/tests/test_implicit.py b/tests/test_implicit.py index 3f16cf71f..85e773d22 100644 --- a/tests/test_implicit.py +++ b/tests/test_implicit.py @@ -3,7 +3,7 @@ import pytest from django.contrib.auth import get_user_model -from django.test import RequestFactory, TestCase +from django.test import RequestFactory from django.urls import reverse from jwcrypto import jwt @@ -11,6 +11,7 @@ from oauth2_provider.views import ProtectedResourceView from . import presets +from .common_testing import OAuth2ProviderTestCase as TestCase Application = get_application_model() diff --git a/tests/test_introspection_auth.py b/tests/test_introspection_auth.py index d96a013e3..e1a096428 100644 --- a/tests/test_introspection_auth.py +++ b/tests/test_introspection_auth.py @@ -6,7 +6,7 @@ from django.conf.urls import include from django.contrib.auth import get_user_model from django.http import HttpResponse -from django.test import TestCase, override_settings +from django.test import override_settings from django.urls import path from django.utils import timezone from oauthlib.common import Request @@ -18,6 +18,7 @@ from oauth2_provider.views import ScopedProtectedResourceView from . import presets +from .common_testing import OAuth2ProviderTestCase as TestCase try: diff --git a/tests/test_introspection_view.py b/tests/test_introspection_view.py index b82e922be..a1d1df493 100644 --- a/tests/test_introspection_view.py +++ b/tests/test_introspection_view.py @@ -3,13 +3,14 @@ import pytest from django.contrib.auth import get_user_model -from django.test import TestCase from django.urls import reverse from django.utils import timezone from oauth2_provider.models import get_access_token_model, get_application_model from . import presets +from .common_testing import OAuth2ProviderTestCase as TestCase +from .common_testing import database_for_oauth2_provider from .utils import get_basic_auth_header @@ -343,5 +344,5 @@ def test_view_post_invalid_client_creds_plaintext(self): self.assertEqual(response.status_code, 403) def test_select_related_in_view_for_less_db_queries(self): - with self.assertNumQueries(1): + with self.assertNumQueries(1, using=database_for_oauth2_provider): self.client.post(reverse("oauth2_provider:introspect")) diff --git a/tests/test_mixins.py b/tests/test_mixins.py index 327a99194..1cefa1334 100644 --- a/tests/test_mixins.py +++ b/tests/test_mixins.py @@ -3,7 +3,7 @@ import pytest from django.core.exceptions import ImproperlyConfigured from django.http import HttpResponse -from django.test import RequestFactory, TestCase +from django.test import RequestFactory from django.views.generic import View from oauthlib.oauth2 import Server @@ -18,6 +18,7 @@ ) from . import presets +from .common_testing import OAuth2ProviderTestCase as TestCase @pytest.mark.usefixtures("oauth2_settings") diff --git a/tests/test_models.py b/tests/test_models.py index 24e4ceafe..196bac25a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -6,7 +6,6 @@ from django.contrib.auth import get_user_model from django.contrib.auth.hashers import check_password from django.core.exceptions import ImproperlyConfigured, ValidationError -from django.test import TestCase from django.test.utils import override_settings from django.utils import timezone @@ -20,6 +19,8 @@ ) from . import presets +from .common_testing import OAuth2ProviderTestCase as TestCase +from .common_testing import test_database_names CLEARTEXT_SECRET = "1234567890abcdefghijklmnopqrstuvwxyz" @@ -466,7 +467,7 @@ def test_clear_expired_tokens_with_tokens(self): assert remaining_gt_count == initial_gt_count // 2, "half the remaining grants should still exist." -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_id_token_methods(oidc_tokens, rf): id_token = IDToken.objects.get() @@ -501,7 +502,7 @@ def test_id_token_methods(oidc_tokens, rf): assert IDToken.objects.filter(jti=id_token.jti).count() == 0 -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_clear_expired_id_tokens(oauth2_settings, oidc_tokens, rf): id_token = IDToken.objects.get() @@ -540,7 +541,7 @@ def test_clear_expired_id_tokens(oauth2_settings, oidc_tokens, rf): assert not IDToken.objects.filter(jti=id_token.jti).exists() -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_application_key(oauth2_settings, application): # RS256 key @@ -565,7 +566,7 @@ def test_application_key(oauth2_settings, application): assert "This application does not support signed tokens" == str(exc.value) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_application_clean(oauth2_settings, application): # RS256, RSA key is configured @@ -605,7 +606,7 @@ def test_application_clean(oauth2_settings, application): application.clean() -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) @pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_DEFAULT) def test_application_origin_allowed_default_https(oauth2_settings, cors_application): """Test that http schemes are not allowed because ALLOWED_SCHEMES allows only https""" @@ -613,7 +614,7 @@ def test_application_origin_allowed_default_https(oauth2_settings, cors_applicat assert not cors_application.origin_allowed("http://example.com") -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) @pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_HTTP) def test_application_origin_allowed_http(oauth2_settings, cors_application): """Test that http schemes are allowed because http was added to ALLOWED_SCHEMES""" diff --git a/tests/test_oauth2_backends.py b/tests/test_oauth2_backends.py index 21dd7a0c3..a4408f8e6 100644 --- a/tests/test_oauth2_backends.py +++ b/tests/test_oauth2_backends.py @@ -3,12 +3,13 @@ import pytest from django.contrib.auth import get_user_model -from django.test import RequestFactory, TestCase +from django.test import RequestFactory from django.utils.timezone import now, timedelta from oauth2_provider.backends import get_oauthlib_core from oauth2_provider.models import get_access_token_model, get_application_model, redirect_to_uri_allowed from oauth2_provider.oauth2_backends import JSONOAuthLibCore, OAuthLibCore +from tests.common_testing import OAuth2ProviderTestCase as TestCase try: diff --git a/tests/test_oauth2_validators.py b/tests/test_oauth2_validators.py index ca80aedb0..d4e53c37f 100644 --- a/tests/test_oauth2_validators.py +++ b/tests/test_oauth2_validators.py @@ -5,7 +5,6 @@ import pytest from django.contrib.auth import get_user_model from django.contrib.auth.hashers import make_password -from django.test import TestCase, TransactionTestCase from django.utils import timezone from jwcrypto import jwt from oauthlib.common import Request @@ -16,6 +15,9 @@ from oauth2_provider.oauth2_validators import OAuth2Validator from . import presets +from .common_testing import OAuth2ProviderTestCase as TestCase +from .common_testing import OAuth2ProviderTransactionTestCase as TransactionTestCase +from .common_testing import test_database_names from .utils import get_basic_auth_header @@ -545,7 +547,7 @@ def test_get_jwt_bearer_token(oauth2_settings, mocker): assert mock_get_id_token.call_args[1] == {} -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_validate_id_token_expired_jwt(oauth2_settings, mocker, oidc_tokens): mocker.patch("oauth2_provider.oauth2_validators.jwt.JWT", side_effect=jwt.JWTExpired) @@ -561,7 +563,7 @@ def test_validate_id_token_no_token(oauth2_settings, mocker): assert status is False -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_validate_id_token_app_removed(oauth2_settings, mocker, oidc_tokens): oidc_tokens.application.delete() @@ -570,7 +572,7 @@ def test_validate_id_token_app_removed(oauth2_settings, mocker, oidc_tokens): assert status is False -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_validate_id_token_bad_token_no_aud(oauth2_settings, mocker, oidc_key): token = jwt.JWT(header=json.dumps({"alg": "RS256"}), claims=json.dumps({"bad": "token"})) diff --git a/tests/test_oidc_views.py b/tests/test_oidc_views.py index f44a808e7..882b01b8c 100644 --- a/tests/test_oidc_views.py +++ b/tests/test_oidc_views.py @@ -1,7 +1,7 @@ import pytest from django.contrib.auth import get_user from django.contrib.auth.models import AnonymousUser -from django.test import RequestFactory, TestCase +from django.test import RequestFactory from django.urls import reverse from django.utils import timezone from pytest_django.asserts import assertRedirects @@ -18,6 +18,8 @@ from oauth2_provider.views.oidc import RPInitiatedLogoutView, _load_id_token, _validate_claims from . import presets +from .common_testing import OAuth2ProviderTestCase as TestCase +from .common_testing import test_database_names @pytest.mark.usefixtures("oauth2_settings") @@ -220,7 +222,7 @@ def mock_request_for(user): return request -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) def test_validate_logout_request(oidc_tokens, public_application, rp_settings): oidc_tokens = oidc_tokens application = oidc_tokens.application @@ -298,7 +300,7 @@ def test_validate_logout_request(oidc_tokens, public_application, rp_settings): ) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) @pytest.mark.parametrize("ALWAYS_PROMPT", [True, False]) def test_must_prompt(oidc_tokens, other_user, rp_settings, ALWAYS_PROMPT): rp_settings.OIDC_RP_INITIATED_LOGOUT_ALWAYS_PROMPT = ALWAYS_PROMPT @@ -319,14 +321,14 @@ def is_logged_in(client): return get_user(client).is_authenticated -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) def test_rp_initiated_logout_get(logged_in_client, rp_settings): rsp = logged_in_client.get(reverse("oauth2_provider:rp-initiated-logout"), data={}) assert rsp.status_code == 200 assert is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) def test_rp_initiated_logout_get_id_token(logged_in_client, oidc_tokens, rp_settings): rsp = logged_in_client.get( reverse("oauth2_provider:rp-initiated-logout"), data={"id_token_hint": oidc_tokens.id_token} @@ -336,7 +338,7 @@ def test_rp_initiated_logout_get_id_token(logged_in_client, oidc_tokens, rp_sett assert not is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) def test_rp_initiated_logout_get_revoked_id_token(logged_in_client, oidc_tokens, rp_settings): validator = oauth2_settings.OAUTH2_VALIDATOR_CLASS() validator._load_id_token(oidc_tokens.id_token).revoke() @@ -347,7 +349,7 @@ def test_rp_initiated_logout_get_revoked_id_token(logged_in_client, oidc_tokens, assert is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) def test_rp_initiated_logout_get_id_token_redirect(logged_in_client, oidc_tokens, rp_settings): rsp = logged_in_client.get( reverse("oauth2_provider:rp-initiated-logout"), @@ -358,7 +360,7 @@ def test_rp_initiated_logout_get_id_token_redirect(logged_in_client, oidc_tokens assert not is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) def test_rp_initiated_logout_get_id_token_redirect_with_state(logged_in_client, oidc_tokens, rp_settings): rsp = logged_in_client.get( reverse("oauth2_provider:rp-initiated-logout"), @@ -373,7 +375,7 @@ def test_rp_initiated_logout_get_id_token_redirect_with_state(logged_in_client, assert not is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) def test_rp_initiated_logout_get_id_token_missmatch_client_id( logged_in_client, oidc_tokens, public_application, rp_settings ): @@ -385,7 +387,7 @@ def test_rp_initiated_logout_get_id_token_missmatch_client_id( assert is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) def test_rp_initiated_logout_public_client_redirect_client_id( logged_in_client, oidc_non_confidential_tokens, public_application, rp_settings ): @@ -401,7 +403,7 @@ def test_rp_initiated_logout_public_client_redirect_client_id( assert not is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) def test_rp_initiated_logout_public_client_strict_redirect_client_id( logged_in_client, oidc_non_confidential_tokens, public_application, oauth2_settings ): @@ -418,7 +420,7 @@ def test_rp_initiated_logout_public_client_strict_redirect_client_id( assert is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) def test_rp_initiated_logout_get_client_id(logged_in_client, oidc_tokens, rp_settings): rsp = logged_in_client.get( reverse("oauth2_provider:rp-initiated-logout"), data={"client_id": oidc_tokens.application.client_id} @@ -427,7 +429,7 @@ def test_rp_initiated_logout_get_client_id(logged_in_client, oidc_tokens, rp_set assert is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) def test_rp_initiated_logout_post(logged_in_client, oidc_tokens, rp_settings): form_data = { "client_id": oidc_tokens.application.client_id, @@ -437,7 +439,7 @@ def test_rp_initiated_logout_post(logged_in_client, oidc_tokens, rp_settings): assert is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) def test_rp_initiated_logout_post_allowed(logged_in_client, oidc_tokens, rp_settings): form_data = {"client_id": oidc_tokens.application.client_id, "allow": True} rsp = logged_in_client.post(reverse("oauth2_provider:rp-initiated-logout"), form_data) @@ -446,7 +448,7 @@ def test_rp_initiated_logout_post_allowed(logged_in_client, oidc_tokens, rp_sett assert not is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) def test_rp_initiated_logout_post_no_session(client, oidc_tokens, rp_settings): form_data = {"client_id": oidc_tokens.application.client_id, "allow": True} rsp = client.post(reverse("oauth2_provider:rp-initiated-logout"), form_data) @@ -455,7 +457,7 @@ def test_rp_initiated_logout_post_no_session(client, oidc_tokens, rp_settings): assert not is_logged_in(client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT) def test_rp_initiated_logout_expired_tokens_accept(logged_in_client, application, expired_id_token): # Accepting expired (but otherwise valid and signed by us) tokens is enabled. Logout should go through. @@ -470,7 +472,7 @@ def test_rp_initiated_logout_expired_tokens_accept(logged_in_client, application assert not is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT_DENY_EXPIRED) def test_rp_initiated_logout_expired_tokens_deny(logged_in_client, application, expired_id_token): # Expired tokens should not be accepted by default. @@ -485,14 +487,14 @@ def test_rp_initiated_logout_expired_tokens_deny(logged_in_client, application, assert is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT) def test_load_id_token_accept_expired(expired_id_token): id_token, _ = _load_id_token(expired_id_token) assert isinstance(id_token, get_id_token_model()) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT) def test_load_id_token_wrong_aud(id_token_wrong_aud): id_token, claims = _load_id_token(id_token_wrong_aud) @@ -500,7 +502,7 @@ def test_load_id_token_wrong_aud(id_token_wrong_aud): assert claims is None -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT_DENY_EXPIRED) def test_load_id_token_deny_expired(expired_id_token): id_token, claims = _load_id_token(expired_id_token) @@ -508,7 +510,7 @@ def test_load_id_token_deny_expired(expired_id_token): assert claims is None -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT) def test_validate_claims_wrong_iss(id_token_wrong_iss): id_token, claims = _load_id_token(id_token_wrong_iss) @@ -517,7 +519,7 @@ def test_validate_claims_wrong_iss(id_token_wrong_iss): assert not _validate_claims(mock_request(), claims) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT) def test_validate_claims(oidc_tokens): id_token, claims = _load_id_token(oidc_tokens.id_token) @@ -525,7 +527,7 @@ def test_validate_claims(oidc_tokens): assert _validate_claims(mock_request_for(oidc_tokens.user), claims) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) @pytest.mark.parametrize("method", ["get", "post"]) def test_userinfo_endpoint(oidc_tokens, client, method): auth_header = "Bearer %s" % oidc_tokens.access_token @@ -538,7 +540,7 @@ def test_userinfo_endpoint(oidc_tokens, client, method): assert data["sub"] == str(oidc_tokens.user.pk) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) def test_userinfo_endpoint_bad_token(oidc_tokens, client): # No access token rsp = client.get(reverse("oauth2_provider:user-info")) @@ -551,7 +553,7 @@ def test_userinfo_endpoint_bad_token(oidc_tokens, client): assert rsp.status_code == 401 -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) def test_token_deletion_on_logout(oidc_tokens, logged_in_client, rp_settings): AccessToken = get_access_token_model() IDToken = get_id_token_model() @@ -574,7 +576,7 @@ def test_token_deletion_on_logout(oidc_tokens, logged_in_client, rp_settings): assert all([token.revoked <= timezone.now() for token in RefreshToken.objects.all()]) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) def test_token_deletion_on_logout_expired_session(oidc_tokens, client, rp_settings): AccessToken = get_access_token_model() IDToken = get_id_token_model() @@ -615,7 +617,7 @@ def test_token_deletion_on_logout_expired_session(oidc_tokens, client, rp_settin assert all(token.revoked <= timezone.now() for token in RefreshToken.objects.all()) -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT_KEEP_TOKENS) def test_token_deletion_on_logout_disabled(oidc_tokens, logged_in_client, rp_settings): rp_settings.OIDC_RP_INITIATED_LOGOUT_DELETE_TOKENS = False @@ -651,7 +653,7 @@ def claim_user_email(request): return EXAMPLE_EMAIL -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) def test_userinfo_endpoint_custom_claims_callable(oidc_tokens, client, oauth2_settings): class CustomValidator(OAuth2Validator): oidc_claim_scope = None @@ -679,7 +681,7 @@ def get_additional_claims(self): assert data["email"] == EXAMPLE_EMAIL -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) def test_userinfo_endpoint_custom_claims_email_scope_callable( oidc_email_scope_tokens, client, oauth2_settings ): @@ -706,7 +708,7 @@ def get_additional_claims(self): assert data["email"] == EXAMPLE_EMAIL -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) def test_userinfo_endpoint_custom_claims_plain(oidc_tokens, client, oauth2_settings): class CustomValidator(OAuth2Validator): oidc_claim_scope = None @@ -734,7 +736,7 @@ def get_additional_claims(self, request): assert data["email"] == EXAMPLE_EMAIL -@pytest.mark.django_db +@pytest.mark.django_db(databases=test_database_names) def test_userinfo_endpoint_custom_claims_email_scopeplain(oidc_email_scope_tokens, client, oauth2_settings): class CustomValidator(OAuth2Validator): def get_additional_claims(self, request): diff --git a/tests/test_password.py b/tests/test_password.py index ec9f17f54..65cf5a8b5 100644 --- a/tests/test_password.py +++ b/tests/test_password.py @@ -2,12 +2,13 @@ import pytest from django.contrib.auth import get_user_model -from django.test import RequestFactory, TestCase +from django.test import RequestFactory from django.urls import reverse from oauth2_provider.models import get_application_model from oauth2_provider.views import ProtectedResourceView +from .common_testing import OAuth2ProviderTestCase as TestCase from .utils import get_basic_auth_header diff --git a/tests/test_rest_framework.py b/tests/test_rest_framework.py index 84b4ad7d9..f8ff86f23 100644 --- a/tests/test_rest_framework.py +++ b/tests/test_rest_framework.py @@ -5,7 +5,6 @@ from django.contrib.auth import get_user_model from django.core.exceptions import ImproperlyConfigured from django.http import HttpResponse -from django.test import TestCase from django.test.utils import override_settings from django.urls import path, re_path from django.utils import timezone @@ -25,6 +24,7 @@ from oauth2_provider.models import get_access_token_model, get_application_model from . import presets +from .common_testing import OAuth2ProviderTestCase as TestCase Application = get_application_model() diff --git a/tests/test_scopes.py b/tests/test_scopes.py index ec36da418..4dae0d3c4 100644 --- a/tests/test_scopes.py +++ b/tests/test_scopes.py @@ -4,12 +4,13 @@ import pytest from django.contrib.auth import get_user_model from django.core.exceptions import ImproperlyConfigured -from django.test import RequestFactory, TestCase +from django.test import RequestFactory from django.urls import reverse from oauth2_provider.models import get_access_token_model, get_application_model, get_grant_model from oauth2_provider.views import ReadWriteScopedResourceView, ScopedProtectedResourceView +from .common_testing import OAuth2ProviderTestCase as TestCase from .utils import get_basic_auth_header diff --git a/tests/test_settings.py b/tests/test_settings.py index f9f540339..b64fc31db 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -1,6 +1,5 @@ import pytest from django.core.exceptions import ImproperlyConfigured -from django.test import TestCase from django.test.utils import override_settings from oauthlib.common import Request @@ -19,6 +18,7 @@ CustomIDTokenAdmin, CustomRefreshTokenAdmin, ) +from tests.common_testing import OAuth2ProviderTestCase as TestCase from . import presets diff --git a/tests/test_token_endpoint_cors.py b/tests/test_token_endpoint_cors.py index 791237b4a..6eaea6560 100644 --- a/tests/test_token_endpoint_cors.py +++ b/tests/test_token_endpoint_cors.py @@ -3,12 +3,13 @@ import pytest from django.contrib.auth import get_user_model -from django.test import RequestFactory, TestCase +from django.test import RequestFactory from django.urls import reverse from oauth2_provider.models import get_application_model from . import presets +from .common_testing import OAuth2ProviderTestCase as TestCase from .utils import get_basic_auth_header diff --git a/tests/test_token_revocation.py b/tests/test_token_revocation.py index 4883e850c..fa836b6a2 100644 --- a/tests/test_token_revocation.py +++ b/tests/test_token_revocation.py @@ -1,12 +1,14 @@ import datetime from django.contrib.auth import get_user_model -from django.test import RequestFactory, TestCase +from django.test import RequestFactory from django.urls import reverse from django.utils import timezone from oauth2_provider.models import get_access_token_model, get_application_model, get_refresh_token_model +from .common_testing import OAuth2ProviderTestCase as TestCase + Application = get_application_model() AccessToken = get_access_token_model() diff --git a/tests/test_token_view.py b/tests/test_token_view.py index fc73c2a66..63e76ed2f 100644 --- a/tests/test_token_view.py +++ b/tests/test_token_view.py @@ -1,12 +1,13 @@ import datetime from django.contrib.auth import get_user_model -from django.test import TestCase from django.urls import reverse from django.utils import timezone from oauth2_provider.models import get_access_token_model, get_application_model +from .common_testing import OAuth2ProviderTestCase as TestCase + Application = get_application_model() AccessToken = get_access_token_model() diff --git a/tests/test_validators.py b/tests/test_validators.py index a28e54a4d..eb382c154 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -1,9 +1,10 @@ import pytest from django.core.validators import ValidationError -from django.test import TestCase from oauth2_provider.validators import AllowedURIValidator +from .common_testing import OAuth2ProviderTestCase as TestCase + @pytest.mark.usefixtures("oauth2_settings") class TestAllowedURIValidator(TestCase): diff --git a/tox.ini b/tox.ini index 2372f044b..58fd222ef 100644 --- a/tox.ini +++ b/tox.ini @@ -9,6 +9,7 @@ envlist = py{310,311,312}-dj50, py{310,311,312}-dj51, py{310,311,312}-djmain, + py39-multi-db-dj-42, [gh-actions] python = @@ -107,6 +108,12 @@ setenv = PYTHONWARNINGS = all commands = django-admin makemigrations --dry-run --check +[testenv:py39-multi-db-dj42] +setenv = + DJANGO_SETTINGS_MODULE = tests.multi_db_settings + PYTHONPATH = {toxinidir} + PYTHONWARNINGS = all + [testenv:migrate_swapped] setenv = DJANGO_SETTINGS_MODULE = tests.settings_swapped From 14821baa2089a0a0a3f83a9ec84ac19e97c00f25 Mon Sep 17 00:00:00 2001 From: Sean 'Shaleh' Perry Date: Wed, 7 Aug 2024 09:41:24 -0700 Subject: [PATCH 02/15] changelog entry and authors update --- AUTHORS | 1 + CHANGELOG.md | 2 ++ 2 files changed, 3 insertions(+) diff --git a/AUTHORS b/AUTHORS index 584ecf59c..1647abbb4 100644 --- a/AUTHORS +++ b/AUTHORS @@ -102,6 +102,7 @@ Rodney Richardson Rustem Saiargaliev Rustem Saiargaliev Sandro Rodrigues +Sean 'Shaleh' Perry Shaheed Haque Shaun Stanworth Sayyid Hamid Mahdavi diff --git a/CHANGELOG.md b/CHANGELOG.md index 738927c5d..e8f1b7a9f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * Update token to TextField from CharField with 255 character limit and SHA-256 checksum in AbstractAccessToken model. Removing the 255 character limit enables supporting JWT tokens with additional claims * Update middleware, validators, and views to use token checksums instead of token for token retrieval and validation. * #1446 use generic models pk instead of id. +* Transactions wrapping writes of the Tokens now rely on Django's database routers to determine the correct + database to use instead of assuming that 'default' is the correct one. ### Deprecated ### Removed From 60437253488c1f8dfccf8b780fb0be814c014ad7 Mon Sep 17 00:00:00 2001 From: Sean 'Shaleh' Perry Date: Thu, 8 Aug 2024 17:28:25 -0700 Subject: [PATCH 03/15] PR review response. Document multiple database requires in advanced_topics.rst. Add an ImproperlyConfigured validator to the ready method of the DOTConfig app. Fix IDToken doc string. Document the use of _save_bearer_token. Define LocalIDToken and use it for validating the configuration test. Questionably, define py39-multi-db-invalid-token-configuration-dj42. This will consistently cause tox runs to fail until it is worked out how to mark this as an expected failure. --- docs/advanced_topics.rst | 11 ++++++ oauth2_provider/apps.py | 28 ++++++++++++++- oauth2_provider/models.py | 13 ++----- oauth2_provider/oauth2_validators.py | 16 ++++++--- tests/db_router.py | 9 ++--- tests/migrations/0006_add_localidtoken.py | 34 +++++++++++++++++++ tests/models.py | 7 ++++ ...db_settings_invalid_token_configuration.py | 8 +++++ tox.ini | 8 +++++ 9 files changed, 114 insertions(+), 20 deletions(-) create mode 100644 tests/migrations/0006_add_localidtoken.py create mode 100644 tests/multi_db_settings_invalid_token_configuration.py diff --git a/docs/advanced_topics.rst b/docs/advanced_topics.rst index 0b2ee20b0..204e3f860 100644 --- a/docs/advanced_topics.rst +++ b/docs/advanced_topics.rst @@ -65,6 +65,17 @@ That's all, now Django OAuth Toolkit will use your model wherever an Application is because of the way Django currently implements swappable models. See `issue #90 `_ for details. +Configuring multiple databases +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +There is no requirement that the tokens are stored in the default database or that there is a +default database provided the database routers can determine the correct Token locations. Because the +Tokens have foreign keys to the ``User`` model, you likely want to keep the tokens in the same database +as your User model. It is also important that all of the tokens are stored in the same database. +This could happen for instance if one of the Tokens is locally overridden and stored in a separate database. +The reason for this is transactions will only be made for the database where AccessToken is stored +even when writing to RefreshToken or other tokens. + Multiple Grants ~~~~~~~~~~~~~~~ diff --git a/oauth2_provider/apps.py b/oauth2_provider/apps.py index 887e4e3fb..818904626 100644 --- a/oauth2_provider/apps.py +++ b/oauth2_provider/apps.py @@ -1,6 +1,32 @@ -from django.apps import AppConfig +from django.apps import AppConfig, apps +from django.core.exceptions import ImproperlyConfigured +from django.db import router class DOTConfig(AppConfig): name = "oauth2_provider" verbose_name = "Django OAuth Toolkit" + + def _validate_token_configuration(self): + from .settings import oauth2_settings + + databases = set( + router.db_for_write(apps.get_model(model)) + for model in ( + oauth2_settings.ACCESS_TOKEN_MODEL, + oauth2_settings.ID_TOKEN_MODEL, + oauth2_settings.REFRESH_TOKEN_MODEL, + ) + ) + + # This is highly unlikely, but let's warn people just in case it does. + # If the tokens were allowed to be in different databases this would require all + # writes to have a transaction around each database. Instead, let's enforce that + # they all live together in one database. + # The tokens are not required to live in the default database provided the Django + # routers know the correct database for them. + if len(databases) > 1: + raise ImproperlyConfigured("the token models are expected to be stored in the same database.") + + def ready(self): + self._validate_token_configuration() diff --git a/oauth2_provider/models.py b/oauth2_provider/models.py index 2ec12b153..831fc551f 100644 --- a/oauth2_provider/models.py +++ b/oauth2_provider/models.py @@ -513,17 +513,8 @@ def revoke(self): Mark this refresh token revoked and revoke related access token """ access_token_model = get_access_token_model() - refresh_token_model = get_refresh_token_model() - access_token_database = router.db_for_write(access_token_model) - refresh_token_database = router.db_for_write(refresh_token_model) - - # This is highly unlikely, but let's warn people just in case it does. - if access_token_database != refresh_token_database: - logger.warning( - "access token and refresh token are in separate databases but a transaction" - " is only used for the access token" - ) + refresh_token_model = get_refresh_token_model() # Use the access_token_database instead of making the assumption it is in 'default'. with transaction.atomic(using=access_token_database): @@ -667,7 +658,7 @@ def get_access_token_model(): def get_id_token_model(): - """Return the AccessToken model that is active in this project.""" + """Return the IDToken model that is active in this project.""" return apps.get_model(oauth2_settings.ID_TOKEN_MODEL) diff --git a/oauth2_provider/oauth2_validators.py b/oauth2_provider/oauth2_validators.py index 808b02ae2..c5d8f3b58 100644 --- a/oauth2_provider/oauth2_validators.py +++ b/oauth2_provider/oauth2_validators.py @@ -563,14 +563,22 @@ def rotate_refresh_token(self, request): return oauth2_settings.ROTATE_REFRESH_TOKEN def save_bearer_token(self, token, request, *args, **kwargs): + """ + Save access and refresh token. + + Override _save_bearer_token and not this function when adding custom logic + for the storing of these token. This allows the transaction logic to be + separate from the token handling. + """ # Use the AccessToken's database instead of making the assumption it is in 'default'. with transaction.atomic(using=router.db_for_write(AccessToken)): - return self._save_bearer_token_internals(token, request, *args, **kwargs) + return self._save_bearer_token(token, request, *args, **kwargs) - def _save_bearer_token_internals(self, token, request, *args, **kwargs): + def _save_bearer_token(self, token, request, *args, **kwargs): """ - Save access and refresh token, If refresh token is issued, remove or - reuse old refresh token as in rfc:`6` + Save access and refresh token. + + If refresh token is issued, remove or reuse old refresh token as in rfc:`6`. @see: https://rfc-editor.org/rfc/rfc6749.html#section-6 """ diff --git a/tests/db_router.py b/tests/db_router.py index 461c60ef3..6ce9f0a65 100644 --- a/tests/db_router.py +++ b/tests/db_router.py @@ -1,13 +1,14 @@ -apps_in_beta = {"some_other_app", "this_one_too"} +apps_in_beta = {"tests", "some_other_app", "this_one_too"} # These are bare minimum routers to fake the scenario where there is actually a # decision around where an application's models might live. -# alpha is where the core Django models are stored including user. To keep things -# simple this is where the oauth2 provider models are stored as well because they -# have a foreign key to User. class AlphaRouter: + # alpha is where the core Django models are stored including user. To keep things + # simple this is where the oauth2 provider models are stored as well because they + # have a foreign key to User. + def db_for_read(self, model, **hints): if model._meta.app_label not in apps_in_beta: return "alpha" diff --git a/tests/migrations/0006_add_localidtoken.py b/tests/migrations/0006_add_localidtoken.py new file mode 100644 index 000000000..133af66e2 --- /dev/null +++ b/tests/migrations/0006_add_localidtoken.py @@ -0,0 +1,34 @@ +# Generated by Django 3.2.25 on 2024-08-08 22:47 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion +import uuid + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.OAUTH2_PROVIDER_APPLICATION_MODEL), + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('tests', '0005_basetestapplication_allowed_origins_and_more'), + ] + + operations = [ + migrations.CreateModel( + name='LocalIDToken', + fields=[ + ('id', models.BigAutoField(primary_key=True, serialize=False)), + ('jti', models.UUIDField(default=uuid.uuid4, editable=False, unique=True, verbose_name='JWT Token ID')), + ('expires', models.DateTimeField()), + ('scope', models.TextField(blank=True)), + ('created', models.DateTimeField(auto_now_add=True)), + ('updated', models.DateTimeField(auto_now=True)), + ('application', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, to=settings.OAUTH2_PROVIDER_APPLICATION_MODEL)), + ('user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='tests_localidtoken', to=settings.AUTH_USER_MODEL)), + ], + options={ + 'abstract': False, + }, + ), + ] diff --git a/tests/models.py b/tests/models.py index 355bc1b57..9f3643db8 100644 --- a/tests/models.py +++ b/tests/models.py @@ -4,6 +4,7 @@ AbstractAccessToken, AbstractApplication, AbstractGrant, + AbstractIDToken, AbstractRefreshToken, ) from oauth2_provider.settings import oauth2_settings @@ -54,3 +55,9 @@ class SampleRefreshToken(AbstractRefreshToken): class SampleGrant(AbstractGrant): custom_field = models.CharField(max_length=255) + + +class LocalIDToken(AbstractIDToken): + """Exists to be improperly configured for multiple databases.""" + + # The other token types will be in 'alpha' database. diff --git a/tests/multi_db_settings_invalid_token_configuration.py b/tests/multi_db_settings_invalid_token_configuration.py new file mode 100644 index 000000000..ed2804f79 --- /dev/null +++ b/tests/multi_db_settings_invalid_token_configuration.py @@ -0,0 +1,8 @@ +from .multi_db_settings import * # noqa: F401, F403 + + +OAUTH2_PROVIDER = { + # The other two tokens will be in alpha. This will cause a failure when the + # app's ready method is called. + "ID_TOKEN_MODEL": "tests.LocalIDToken", +} diff --git a/tox.ini b/tox.ini index 58fd222ef..c8013f537 100644 --- a/tox.ini +++ b/tox.ini @@ -10,6 +10,7 @@ envlist = py{310,311,312}-dj51, py{310,311,312}-djmain, py39-multi-db-dj-42, + py39-multi-db-invalid-token-configuration-dj42, [gh-actions] python = @@ -114,6 +115,13 @@ setenv = PYTHONPATH = {toxinidir} PYTHONWARNINGS = all +[testenv:py39-multi-db-invalid-token-configuration-dj42] +setenv = + DJANGO_SETTINGS_MODULE = tests.multi_db_settings_invalid_token_configuration + PYTHONPATH = {toxinidir} + PYTHONWARNINGS = all + ignore_errors = true + [testenv:migrate_swapped] setenv = DJANGO_SETTINGS_MODULE = tests.settings_swapped From 73a0406c493ca2a91b39da0ad9e1eb2db48767f0 Mon Sep 17 00:00:00 2001 From: Sean 'Shaleh' Perry Date: Thu, 15 Aug 2024 11:12:19 -0700 Subject: [PATCH 04/15] move migration --- .../{0006_add_localidtoken.py => 0007_add_localidtoken.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/migrations/{0006_add_localidtoken.py => 0007_add_localidtoken.py} (100%) diff --git a/tests/migrations/0006_add_localidtoken.py b/tests/migrations/0007_add_localidtoken.py similarity index 100% rename from tests/migrations/0006_add_localidtoken.py rename to tests/migrations/0007_add_localidtoken.py From 7a64859798bc1ea7443f95687863f5616ad88799 Mon Sep 17 00:00:00 2001 From: Sean 'Shaleh' Perry Date: Thu, 15 Aug 2024 11:13:10 -0700 Subject: [PATCH 05/15] update migration --- tests/migrations/0007_add_localidtoken.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/migrations/0007_add_localidtoken.py b/tests/migrations/0007_add_localidtoken.py index 133af66e2..f74cce5b6 100644 --- a/tests/migrations/0007_add_localidtoken.py +++ b/tests/migrations/0007_add_localidtoken.py @@ -11,7 +11,7 @@ class Migration(migrations.Migration): dependencies = [ migrations.swappable_dependency(settings.OAUTH2_PROVIDER_APPLICATION_MODEL), migrations.swappable_dependency(settings.AUTH_USER_MODEL), - ('tests', '0005_basetestapplication_allowed_origins_and_more'), + ('tests', '0006_basetestapplication_token_family'), ] operations = [ From 0c6feae2b303969f7625f05a1c255820b37267f8 Mon Sep 17 00:00:00 2001 From: Sean 'Shaleh' Perry Date: Thu, 15 Aug 2024 11:16:32 -0700 Subject: [PATCH 06/15] use django checks system --- oauth2_provider/checks.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 oauth2_provider/checks.py diff --git a/oauth2_provider/checks.py b/oauth2_provider/checks.py new file mode 100644 index 000000000..46468214d --- /dev/null +++ b/oauth2_provider/checks.py @@ -0,0 +1,29 @@ +from django.apps import apps +from django.core import checks +from django.db import router + +from .settings import oauth2_settings + + +@checks.register(checks.Tags.database) +def validate_token_configuration(app_configs, **kwargs): + breakpoint() + databases = set( + router.db_for_write(apps.get_model(model)) + for model in ( + oauth2_settings.ACCESS_TOKEN_MODEL, + oauth2_settings.ID_TOKEN_MODEL, + oauth2_settings.REFRESH_TOKEN_MODEL, + ) + ) + + # This is highly unlikely, but let's warn people just in case it does. + # If the tokens were allowed to be in different databases this would require all + # writes to have a transaction around each database. Instead, let's enforce that + # they all live together in one database. + # The tokens are not required to live in the default database provided the Django + # routers know the correct database for them. + if len(databases) > 1: + return [checks.Error("The token models are expected to be stored in the same database.")] + + return [] From a2b273ebcb8a2afab3eb405e9f38f183f42d5028 Mon Sep 17 00:00:00 2001 From: Sean 'Shaleh' Perry Date: Thu, 15 Aug 2024 11:48:29 -0700 Subject: [PATCH 07/15] drop misconfigured db check. Let's find a better way. --- tox.ini | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/tox.ini b/tox.ini index c8013f537..c94b17520 100644 --- a/tox.ini +++ b/tox.ini @@ -9,8 +9,7 @@ envlist = py{310,311,312}-dj50, py{310,311,312}-dj51, py{310,311,312}-djmain, - py39-multi-db-dj-42, - py39-multi-db-invalid-token-configuration-dj42, + py39-multi-db-dj-42 [gh-actions] python = @@ -115,13 +114,6 @@ setenv = PYTHONPATH = {toxinidir} PYTHONWARNINGS = all -[testenv:py39-multi-db-invalid-token-configuration-dj42] -setenv = - DJANGO_SETTINGS_MODULE = tests.multi_db_settings_invalid_token_configuration - PYTHONPATH = {toxinidir} - PYTHONWARNINGS = all - ignore_errors = true - [testenv:migrate_swapped] setenv = DJANGO_SETTINGS_MODULE = tests.settings_swapped From 90b2ff33bedc6853824c07913fbc99d886271e0b Mon Sep 17 00:00:00 2001 From: Sean 'Shaleh' Perry Date: Thu, 15 Aug 2024 13:35:48 -0700 Subject: [PATCH 08/15] run checks --- oauth2_provider/apps.py | 28 +++------------------------- 1 file changed, 3 insertions(+), 25 deletions(-) diff --git a/oauth2_provider/apps.py b/oauth2_provider/apps.py index 818904626..3ad08b715 100644 --- a/oauth2_provider/apps.py +++ b/oauth2_provider/apps.py @@ -1,32 +1,10 @@ -from django.apps import AppConfig, apps -from django.core.exceptions import ImproperlyConfigured -from django.db import router +from django.apps import AppConfig class DOTConfig(AppConfig): name = "oauth2_provider" verbose_name = "Django OAuth Toolkit" - def _validate_token_configuration(self): - from .settings import oauth2_settings - - databases = set( - router.db_for_write(apps.get_model(model)) - for model in ( - oauth2_settings.ACCESS_TOKEN_MODEL, - oauth2_settings.ID_TOKEN_MODEL, - oauth2_settings.REFRESH_TOKEN_MODEL, - ) - ) - - # This is highly unlikely, but let's warn people just in case it does. - # If the tokens were allowed to be in different databases this would require all - # writes to have a transaction around each database. Instead, let's enforce that - # they all live together in one database. - # The tokens are not required to live in the default database provided the Django - # routers know the correct database for them. - if len(databases) > 1: - raise ImproperlyConfigured("the token models are expected to be stored in the same database.") - def ready(self): - self._validate_token_configuration() + # Import checks to ensure they run. + from . import checks # noqa: F401 From 7637a491e4f77cf9b61d27fedf0bdad47c6a5f46 Mon Sep 17 00:00:00 2001 From: Sean 'Shaleh' Perry Date: Thu, 15 Aug 2024 14:19:38 -0700 Subject: [PATCH 09/15] maybe a better test definition --- tests/common_testing.py | 50 ++++++++++++++++---------------- tests/test_introspection_view.py | 5 ++-- 2 files changed, 28 insertions(+), 27 deletions(-) diff --git a/tests/common_testing.py b/tests/common_testing.py index daffd056f..11c9ffdde 100644 --- a/tests/common_testing.py +++ b/tests/common_testing.py @@ -3,33 +3,33 @@ from django.test import TransactionTestCase as DjangoTransactionTestCase -class OAuth2ProviderTestCase(DjangoTestCase): - """Place holder to allow overriding behaviors.""" +# When there are multiple databases defined, Django tests will not work unless they are +# told which database(s) to work with. The multiple database scenario setup for these +# tests purposefully defines 'default' as an empty database in order to catch any +# assumptions in this package about database names and in particular to ensure there is +# no assumption that 'default' is a valid database. +# For any test that would usually use Django's TestCase or TransactionTestCase using +# the classes defined here is all that is required. +# Any test that uses pytest's django_db need to include a databases parameter using +# test_database_names defined below. +# In test code, anywhere the database is referenced the Django router needs to be used +# exactly like the package's code. +# For instance: +# token_database = router.db_for_write(AccessToken) +# with self.assertNumQueries(1, using=token_database): +# Without the 'using' option, this test fails in the multiple database scenario because +# 'default' is used. + +test_database_names = ["alpha", "beta"] if len(settings.DATABASES) > 1 else ["default"] + +class OAuth2ProviderBase: + databases = test_database_names -class OAuth2ProviderTransactionTestCase(DjangoTransactionTestCase): + +class OAuth2ProviderTestCase(OAuth2ProviderBase, DjangoTestCase): """Place holder to allow overriding behaviors.""" -if len(settings.DATABASES) > 1: - # There are multiple databases defined. When this happens Django tests will not - # work unless they are told which database(s) to work with. The multiple - # database scenario setup for these tests purposefully defines 'default' as an - # empty database in order to catch any assumptions in this package about database - # names and in particular to ensure there is no assumption that 'default' is a - # valid database. - # For any test that would usually use Django's TestCase or TransactionTestCase - # using the classes defined here is all that is required. - # Any test that uses pytest's django_db need to base in a databases parameter - # using this definition of test_database_names. - # In test code, anywhere the default database is used the variable - # database_for_oauth2_provider must be used in its place. For instance, - # with self.assertNumQueries(1, using=database_for_oauth2_provider): - # without the using option this fails because default is used. - test_database_names = {name for name in settings.DATABASES if name != "default"} - database_for_oauth2_provider = "alpha" - OAuth2ProviderTestCase.databases = test_database_names - OAuth2ProviderTransactionTestCase.databases = test_database_names -else: - test_database_names = {"default"} - database_for_oauth2_provider = "default" +class OAuth2ProviderTransactionTestCase(OAuth2ProviderBase, DjangoTransactionTestCase): + """Place holder to allow overriding behaviors.""" diff --git a/tests/test_introspection_view.py b/tests/test_introspection_view.py index a1d1df493..3db23bbcd 100644 --- a/tests/test_introspection_view.py +++ b/tests/test_introspection_view.py @@ -3,6 +3,7 @@ import pytest from django.contrib.auth import get_user_model +from django.db import router from django.urls import reverse from django.utils import timezone @@ -10,7 +11,6 @@ from . import presets from .common_testing import OAuth2ProviderTestCase as TestCase -from .common_testing import database_for_oauth2_provider from .utils import get_basic_auth_header @@ -344,5 +344,6 @@ def test_view_post_invalid_client_creds_plaintext(self): self.assertEqual(response.status_code, 403) def test_select_related_in_view_for_less_db_queries(self): - with self.assertNumQueries(1, using=database_for_oauth2_provider): + token_database = router.db_for_write(AccessToken) + with self.assertNumQueries(1, using=token_database): self.client.post(reverse("oauth2_provider:introspect")) From 08f5021ccca49e17ec797c12df9c86a7c53853aa Mon Sep 17 00:00:00 2001 From: Sean 'Shaleh' Perry Date: Thu, 15 Aug 2024 14:25:18 -0700 Subject: [PATCH 10/15] listing tests was breaking things --- tests/db_router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/db_router.py b/tests/db_router.py index 6ce9f0a65..e7fdbd796 100644 --- a/tests/db_router.py +++ b/tests/db_router.py @@ -1,4 +1,4 @@ -apps_in_beta = {"tests", "some_other_app", "this_one_too"} +apps_in_beta = {"some_other_app", "this_one_too"} # These are bare minimum routers to fake the scenario where there is actually a # decision around where an application's models might live. From dbddebf6789a9f78158b653d7187543bbd5ef346 Mon Sep 17 00:00:00 2001 From: Sean 'Shaleh' Perry Date: Thu, 15 Aug 2024 16:21:17 -0700 Subject: [PATCH 11/15] No more magic. --- tests/common_testing.py | 14 +++++--- tests/test_hybrid.py | 5 ++- tests/test_models.py | 13 ++++--- tests/test_oauth2_validators.py | 7 ++-- tests/test_oidc_views.py | 61 ++++++++++++++++----------------- 5 files changed, 51 insertions(+), 49 deletions(-) diff --git a/tests/common_testing.py b/tests/common_testing.py index 11c9ffdde..f530d67da 100644 --- a/tests/common_testing.py +++ b/tests/common_testing.py @@ -10,8 +10,6 @@ # no assumption that 'default' is a valid database. # For any test that would usually use Django's TestCase or TransactionTestCase using # the classes defined here is all that is required. -# Any test that uses pytest's django_db need to include a databases parameter using -# test_database_names defined below. # In test code, anywhere the database is referenced the Django router needs to be used # exactly like the package's code. # For instance: @@ -20,11 +18,19 @@ # Without the 'using' option, this test fails in the multiple database scenario because # 'default' is used. -test_database_names = ["alpha", "beta"] if len(settings.DATABASES) > 1 else ["default"] + +def retrieve_current_databases(): + if len(settings.DATABASES) > 1: + return [name for name in settings.DATABASES if name != "default"] + else: + return ["default"] class OAuth2ProviderBase: - databases = test_database_names + @classmethod + def setUpClass(cls): + cls.databases = retrieve_current_databases() + super().setUpClass() class OAuth2ProviderTestCase(OAuth2ProviderBase, DjangoTestCase): diff --git a/tests/test_hybrid.py b/tests/test_hybrid.py index 204be7671..87c4b0ad9 100644 --- a/tests/test_hybrid.py +++ b/tests/test_hybrid.py @@ -22,7 +22,6 @@ from . import presets from .common_testing import OAuth2ProviderTestCase as TestCase -from .common_testing import test_database_names from .utils import get_basic_auth_header, spy_on @@ -1320,7 +1319,7 @@ def test_pre_auth_default_scopes(self): self.assertEqual(form["client_id"].value(), self.application.client_id) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_id_token_nonce_in_token_response(oauth2_settings, test_user, hybrid_application, client, oidc_key): client.force_login(test_user) @@ -1369,7 +1368,7 @@ def test_id_token_nonce_in_token_response(oauth2_settings, test_user, hybrid_app assert claims["nonce"] == "random_nonce_string" -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_claims_passed_to_code_generation( oauth2_settings, test_user, hybrid_application, client, mocker, oidc_key diff --git a/tests/test_models.py b/tests/test_models.py index 196bac25a..cd4b7342c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -20,7 +20,6 @@ from . import presets from .common_testing import OAuth2ProviderTestCase as TestCase -from .common_testing import test_database_names CLEARTEXT_SECRET = "1234567890abcdefghijklmnopqrstuvwxyz" @@ -467,7 +466,7 @@ def test_clear_expired_tokens_with_tokens(self): assert remaining_gt_count == initial_gt_count // 2, "half the remaining grants should still exist." -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_id_token_methods(oidc_tokens, rf): id_token = IDToken.objects.get() @@ -502,7 +501,7 @@ def test_id_token_methods(oidc_tokens, rf): assert IDToken.objects.filter(jti=id_token.jti).count() == 0 -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_clear_expired_id_tokens(oauth2_settings, oidc_tokens, rf): id_token = IDToken.objects.get() @@ -541,7 +540,7 @@ def test_clear_expired_id_tokens(oauth2_settings, oidc_tokens, rf): assert not IDToken.objects.filter(jti=id_token.jti).exists() -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_application_key(oauth2_settings, application): # RS256 key @@ -566,7 +565,7 @@ def test_application_key(oauth2_settings, application): assert "This application does not support signed tokens" == str(exc.value) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_application_clean(oauth2_settings, application): # RS256, RSA key is configured @@ -606,7 +605,7 @@ def test_application_clean(oauth2_settings, application): application.clean() -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db @pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_DEFAULT) def test_application_origin_allowed_default_https(oauth2_settings, cors_application): """Test that http schemes are not allowed because ALLOWED_SCHEMES allows only https""" @@ -614,7 +613,7 @@ def test_application_origin_allowed_default_https(oauth2_settings, cors_applicat assert not cors_application.origin_allowed("http://example.com") -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db @pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_HTTP) def test_application_origin_allowed_http(oauth2_settings, cors_application): """Test that http schemes are allowed because http was added to ALLOWED_SCHEMES""" diff --git a/tests/test_oauth2_validators.py b/tests/test_oauth2_validators.py index d4e53c37f..bf06b73a8 100644 --- a/tests/test_oauth2_validators.py +++ b/tests/test_oauth2_validators.py @@ -17,7 +17,6 @@ from . import presets from .common_testing import OAuth2ProviderTestCase as TestCase from .common_testing import OAuth2ProviderTransactionTestCase as TransactionTestCase -from .common_testing import test_database_names from .utils import get_basic_auth_header @@ -547,7 +546,7 @@ def test_get_jwt_bearer_token(oauth2_settings, mocker): assert mock_get_id_token.call_args[1] == {} -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_validate_id_token_expired_jwt(oauth2_settings, mocker, oidc_tokens): mocker.patch("oauth2_provider.oauth2_validators.jwt.JWT", side_effect=jwt.JWTExpired) @@ -563,7 +562,7 @@ def test_validate_id_token_no_token(oauth2_settings, mocker): assert status is False -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_validate_id_token_app_removed(oauth2_settings, mocker, oidc_tokens): oidc_tokens.application.delete() @@ -572,7 +571,7 @@ def test_validate_id_token_app_removed(oauth2_settings, mocker, oidc_tokens): assert status is False -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_validate_id_token_bad_token_no_aud(oauth2_settings, mocker, oidc_key): token = jwt.JWT(header=json.dumps({"alg": "RS256"}), claims=json.dumps({"bad": "token"})) diff --git a/tests/test_oidc_views.py b/tests/test_oidc_views.py index 882b01b8c..8949f41e7 100644 --- a/tests/test_oidc_views.py +++ b/tests/test_oidc_views.py @@ -19,7 +19,6 @@ from . import presets from .common_testing import OAuth2ProviderTestCase as TestCase -from .common_testing import test_database_names @pytest.mark.usefixtures("oauth2_settings") @@ -222,7 +221,7 @@ def mock_request_for(user): return request -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db def test_validate_logout_request(oidc_tokens, public_application, rp_settings): oidc_tokens = oidc_tokens application = oidc_tokens.application @@ -300,7 +299,7 @@ def test_validate_logout_request(oidc_tokens, public_application, rp_settings): ) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db @pytest.mark.parametrize("ALWAYS_PROMPT", [True, False]) def test_must_prompt(oidc_tokens, other_user, rp_settings, ALWAYS_PROMPT): rp_settings.OIDC_RP_INITIATED_LOGOUT_ALWAYS_PROMPT = ALWAYS_PROMPT @@ -321,14 +320,14 @@ def is_logged_in(client): return get_user(client).is_authenticated -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db def test_rp_initiated_logout_get(logged_in_client, rp_settings): rsp = logged_in_client.get(reverse("oauth2_provider:rp-initiated-logout"), data={}) assert rsp.status_code == 200 assert is_logged_in(logged_in_client) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db def test_rp_initiated_logout_get_id_token(logged_in_client, oidc_tokens, rp_settings): rsp = logged_in_client.get( reverse("oauth2_provider:rp-initiated-logout"), data={"id_token_hint": oidc_tokens.id_token} @@ -338,7 +337,7 @@ def test_rp_initiated_logout_get_id_token(logged_in_client, oidc_tokens, rp_sett assert not is_logged_in(logged_in_client) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db def test_rp_initiated_logout_get_revoked_id_token(logged_in_client, oidc_tokens, rp_settings): validator = oauth2_settings.OAUTH2_VALIDATOR_CLASS() validator._load_id_token(oidc_tokens.id_token).revoke() @@ -349,7 +348,7 @@ def test_rp_initiated_logout_get_revoked_id_token(logged_in_client, oidc_tokens, assert is_logged_in(logged_in_client) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db def test_rp_initiated_logout_get_id_token_redirect(logged_in_client, oidc_tokens, rp_settings): rsp = logged_in_client.get( reverse("oauth2_provider:rp-initiated-logout"), @@ -360,7 +359,7 @@ def test_rp_initiated_logout_get_id_token_redirect(logged_in_client, oidc_tokens assert not is_logged_in(logged_in_client) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db def test_rp_initiated_logout_get_id_token_redirect_with_state(logged_in_client, oidc_tokens, rp_settings): rsp = logged_in_client.get( reverse("oauth2_provider:rp-initiated-logout"), @@ -375,7 +374,7 @@ def test_rp_initiated_logout_get_id_token_redirect_with_state(logged_in_client, assert not is_logged_in(logged_in_client) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db def test_rp_initiated_logout_get_id_token_missmatch_client_id( logged_in_client, oidc_tokens, public_application, rp_settings ): @@ -387,7 +386,7 @@ def test_rp_initiated_logout_get_id_token_missmatch_client_id( assert is_logged_in(logged_in_client) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db def test_rp_initiated_logout_public_client_redirect_client_id( logged_in_client, oidc_non_confidential_tokens, public_application, rp_settings ): @@ -403,7 +402,7 @@ def test_rp_initiated_logout_public_client_redirect_client_id( assert not is_logged_in(logged_in_client) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db def test_rp_initiated_logout_public_client_strict_redirect_client_id( logged_in_client, oidc_non_confidential_tokens, public_application, oauth2_settings ): @@ -420,7 +419,7 @@ def test_rp_initiated_logout_public_client_strict_redirect_client_id( assert is_logged_in(logged_in_client) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db def test_rp_initiated_logout_get_client_id(logged_in_client, oidc_tokens, rp_settings): rsp = logged_in_client.get( reverse("oauth2_provider:rp-initiated-logout"), data={"client_id": oidc_tokens.application.client_id} @@ -429,7 +428,7 @@ def test_rp_initiated_logout_get_client_id(logged_in_client, oidc_tokens, rp_set assert is_logged_in(logged_in_client) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db def test_rp_initiated_logout_post(logged_in_client, oidc_tokens, rp_settings): form_data = { "client_id": oidc_tokens.application.client_id, @@ -439,7 +438,7 @@ def test_rp_initiated_logout_post(logged_in_client, oidc_tokens, rp_settings): assert is_logged_in(logged_in_client) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db def test_rp_initiated_logout_post_allowed(logged_in_client, oidc_tokens, rp_settings): form_data = {"client_id": oidc_tokens.application.client_id, "allow": True} rsp = logged_in_client.post(reverse("oauth2_provider:rp-initiated-logout"), form_data) @@ -448,7 +447,7 @@ def test_rp_initiated_logout_post_allowed(logged_in_client, oidc_tokens, rp_sett assert not is_logged_in(logged_in_client) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db def test_rp_initiated_logout_post_no_session(client, oidc_tokens, rp_settings): form_data = {"client_id": oidc_tokens.application.client_id, "allow": True} rsp = client.post(reverse("oauth2_provider:rp-initiated-logout"), form_data) @@ -457,7 +456,7 @@ def test_rp_initiated_logout_post_no_session(client, oidc_tokens, rp_settings): assert not is_logged_in(client) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT) def test_rp_initiated_logout_expired_tokens_accept(logged_in_client, application, expired_id_token): # Accepting expired (but otherwise valid and signed by us) tokens is enabled. Logout should go through. @@ -472,7 +471,7 @@ def test_rp_initiated_logout_expired_tokens_accept(logged_in_client, application assert not is_logged_in(logged_in_client) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT_DENY_EXPIRED) def test_rp_initiated_logout_expired_tokens_deny(logged_in_client, application, expired_id_token): # Expired tokens should not be accepted by default. @@ -487,14 +486,14 @@ def test_rp_initiated_logout_expired_tokens_deny(logged_in_client, application, assert is_logged_in(logged_in_client) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT) def test_load_id_token_accept_expired(expired_id_token): id_token, _ = _load_id_token(expired_id_token) assert isinstance(id_token, get_id_token_model()) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT) def test_load_id_token_wrong_aud(id_token_wrong_aud): id_token, claims = _load_id_token(id_token_wrong_aud) @@ -502,7 +501,7 @@ def test_load_id_token_wrong_aud(id_token_wrong_aud): assert claims is None -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT_DENY_EXPIRED) def test_load_id_token_deny_expired(expired_id_token): id_token, claims = _load_id_token(expired_id_token) @@ -510,7 +509,7 @@ def test_load_id_token_deny_expired(expired_id_token): assert claims is None -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT) def test_validate_claims_wrong_iss(id_token_wrong_iss): id_token, claims = _load_id_token(id_token_wrong_iss) @@ -519,7 +518,7 @@ def test_validate_claims_wrong_iss(id_token_wrong_iss): assert not _validate_claims(mock_request(), claims) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT) def test_validate_claims(oidc_tokens): id_token, claims = _load_id_token(oidc_tokens.id_token) @@ -527,7 +526,7 @@ def test_validate_claims(oidc_tokens): assert _validate_claims(mock_request_for(oidc_tokens.user), claims) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db @pytest.mark.parametrize("method", ["get", "post"]) def test_userinfo_endpoint(oidc_tokens, client, method): auth_header = "Bearer %s" % oidc_tokens.access_token @@ -540,7 +539,7 @@ def test_userinfo_endpoint(oidc_tokens, client, method): assert data["sub"] == str(oidc_tokens.user.pk) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db def test_userinfo_endpoint_bad_token(oidc_tokens, client): # No access token rsp = client.get(reverse("oauth2_provider:user-info")) @@ -553,7 +552,7 @@ def test_userinfo_endpoint_bad_token(oidc_tokens, client): assert rsp.status_code == 401 -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db def test_token_deletion_on_logout(oidc_tokens, logged_in_client, rp_settings): AccessToken = get_access_token_model() IDToken = get_id_token_model() @@ -576,7 +575,7 @@ def test_token_deletion_on_logout(oidc_tokens, logged_in_client, rp_settings): assert all([token.revoked <= timezone.now() for token in RefreshToken.objects.all()]) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db def test_token_deletion_on_logout_expired_session(oidc_tokens, client, rp_settings): AccessToken = get_access_token_model() IDToken = get_id_token_model() @@ -617,7 +616,7 @@ def test_token_deletion_on_logout_expired_session(oidc_tokens, client, rp_settin assert all(token.revoked <= timezone.now() for token in RefreshToken.objects.all()) -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT_KEEP_TOKENS) def test_token_deletion_on_logout_disabled(oidc_tokens, logged_in_client, rp_settings): rp_settings.OIDC_RP_INITIATED_LOGOUT_DELETE_TOKENS = False @@ -653,7 +652,7 @@ def claim_user_email(request): return EXAMPLE_EMAIL -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db def test_userinfo_endpoint_custom_claims_callable(oidc_tokens, client, oauth2_settings): class CustomValidator(OAuth2Validator): oidc_claim_scope = None @@ -681,7 +680,7 @@ def get_additional_claims(self): assert data["email"] == EXAMPLE_EMAIL -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db def test_userinfo_endpoint_custom_claims_email_scope_callable( oidc_email_scope_tokens, client, oauth2_settings ): @@ -708,7 +707,7 @@ def get_additional_claims(self): assert data["email"] == EXAMPLE_EMAIL -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db def test_userinfo_endpoint_custom_claims_plain(oidc_tokens, client, oauth2_settings): class CustomValidator(OAuth2Validator): oidc_claim_scope = None @@ -736,7 +735,7 @@ def get_additional_claims(self, request): assert data["email"] == EXAMPLE_EMAIL -@pytest.mark.django_db(databases=test_database_names) +@pytest.mark.django_db def test_userinfo_endpoint_custom_claims_email_scopeplain(oidc_email_scope_tokens, client, oauth2_settings): class CustomValidator(OAuth2Validator): def get_additional_claims(self, request): From 8072cc712f95237b3e5e280ebd5c7e534533a5c2 Mon Sep 17 00:00:00 2001 From: Sean 'Shaleh' Perry Date: Fri, 16 Aug 2024 07:23:07 -0700 Subject: [PATCH 12/15] Oops. Debugger. --- oauth2_provider/checks.py | 1 - 1 file changed, 1 deletion(-) diff --git a/oauth2_provider/checks.py b/oauth2_provider/checks.py index 46468214d..848ba1af7 100644 --- a/oauth2_provider/checks.py +++ b/oauth2_provider/checks.py @@ -7,7 +7,6 @@ @checks.register(checks.Tags.database) def validate_token_configuration(app_configs, **kwargs): - breakpoint() databases = set( router.db_for_write(apps.get_model(model)) for model in ( From 86b251989c0fb51804a28762db9a21d0c6f9728d Mon Sep 17 00:00:00 2001 From: Sean 'Shaleh' Perry Date: Fri, 16 Aug 2024 07:23:32 -0700 Subject: [PATCH 13/15] Use retrieven_current_databases in django_db marked tests. --- tests/test_hybrid.py | 5 +-- tests/test_models.py | 13 +++---- tests/test_oauth2_validators.py | 7 ++-- tests/test_oidc_views.py | 61 +++++++++++++++++---------------- 4 files changed, 45 insertions(+), 41 deletions(-) diff --git a/tests/test_hybrid.py b/tests/test_hybrid.py index 87c4b0ad9..67c29a54e 100644 --- a/tests/test_hybrid.py +++ b/tests/test_hybrid.py @@ -22,6 +22,7 @@ from . import presets from .common_testing import OAuth2ProviderTestCase as TestCase +from .common_testing import retrieve_current_databases from .utils import get_basic_auth_header, spy_on @@ -1319,7 +1320,7 @@ def test_pre_auth_default_scopes(self): self.assertEqual(form["client_id"].value(), self.application.client_id) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_id_token_nonce_in_token_response(oauth2_settings, test_user, hybrid_application, client, oidc_key): client.force_login(test_user) @@ -1368,7 +1369,7 @@ def test_id_token_nonce_in_token_response(oauth2_settings, test_user, hybrid_app assert claims["nonce"] == "random_nonce_string" -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_claims_passed_to_code_generation( oauth2_settings, test_user, hybrid_application, client, mocker, oidc_key diff --git a/tests/test_models.py b/tests/test_models.py index cd4b7342c..58765db69 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -20,6 +20,7 @@ from . import presets from .common_testing import OAuth2ProviderTestCase as TestCase +from .common_testing import retrieve_current_databases CLEARTEXT_SECRET = "1234567890abcdefghijklmnopqrstuvwxyz" @@ -466,7 +467,7 @@ def test_clear_expired_tokens_with_tokens(self): assert remaining_gt_count == initial_gt_count // 2, "half the remaining grants should still exist." -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_id_token_methods(oidc_tokens, rf): id_token = IDToken.objects.get() @@ -501,7 +502,7 @@ def test_id_token_methods(oidc_tokens, rf): assert IDToken.objects.filter(jti=id_token.jti).count() == 0 -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_clear_expired_id_tokens(oauth2_settings, oidc_tokens, rf): id_token = IDToken.objects.get() @@ -540,7 +541,7 @@ def test_clear_expired_id_tokens(oauth2_settings, oidc_tokens, rf): assert not IDToken.objects.filter(jti=id_token.jti).exists() -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_application_key(oauth2_settings, application): # RS256 key @@ -565,7 +566,7 @@ def test_application_key(oauth2_settings, application): assert "This application does not support signed tokens" == str(exc.value) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_application_clean(oauth2_settings, application): # RS256, RSA key is configured @@ -605,7 +606,7 @@ def test_application_clean(oauth2_settings, application): application.clean() -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_DEFAULT) def test_application_origin_allowed_default_https(oauth2_settings, cors_application): """Test that http schemes are not allowed because ALLOWED_SCHEMES allows only https""" @@ -613,7 +614,7 @@ def test_application_origin_allowed_default_https(oauth2_settings, cors_applicat assert not cors_application.origin_allowed("http://example.com") -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.ALLOWED_SCHEMES_HTTP) def test_application_origin_allowed_http(oauth2_settings, cors_application): """Test that http schemes are allowed because http was added to ALLOWED_SCHEMES""" diff --git a/tests/test_oauth2_validators.py b/tests/test_oauth2_validators.py index bf06b73a8..468e05598 100644 --- a/tests/test_oauth2_validators.py +++ b/tests/test_oauth2_validators.py @@ -17,6 +17,7 @@ from . import presets from .common_testing import OAuth2ProviderTestCase as TestCase from .common_testing import OAuth2ProviderTransactionTestCase as TransactionTestCase +from .common_testing import retrieve_current_databases from .utils import get_basic_auth_header @@ -546,7 +547,7 @@ def test_get_jwt_bearer_token(oauth2_settings, mocker): assert mock_get_id_token.call_args[1] == {} -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_validate_id_token_expired_jwt(oauth2_settings, mocker, oidc_tokens): mocker.patch("oauth2_provider.oauth2_validators.jwt.JWT", side_effect=jwt.JWTExpired) @@ -562,7 +563,7 @@ def test_validate_id_token_no_token(oauth2_settings, mocker): assert status is False -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_validate_id_token_app_removed(oauth2_settings, mocker, oidc_tokens): oidc_tokens.application.delete() @@ -571,7 +572,7 @@ def test_validate_id_token_app_removed(oauth2_settings, mocker, oidc_tokens): assert status is False -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RW) def test_validate_id_token_bad_token_no_aud(oauth2_settings, mocker, oidc_key): token = jwt.JWT(header=json.dumps({"alg": "RS256"}), claims=json.dumps({"bad": "token"})) diff --git a/tests/test_oidc_views.py b/tests/test_oidc_views.py index 8949f41e7..8bdf18360 100644 --- a/tests/test_oidc_views.py +++ b/tests/test_oidc_views.py @@ -19,6 +19,7 @@ from . import presets from .common_testing import OAuth2ProviderTestCase as TestCase +from .common_testing import retrieve_current_databases @pytest.mark.usefixtures("oauth2_settings") @@ -221,7 +222,7 @@ def mock_request_for(user): return request -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_validate_logout_request(oidc_tokens, public_application, rp_settings): oidc_tokens = oidc_tokens application = oidc_tokens.application @@ -299,7 +300,7 @@ def test_validate_logout_request(oidc_tokens, public_application, rp_settings): ) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.parametrize("ALWAYS_PROMPT", [True, False]) def test_must_prompt(oidc_tokens, other_user, rp_settings, ALWAYS_PROMPT): rp_settings.OIDC_RP_INITIATED_LOGOUT_ALWAYS_PROMPT = ALWAYS_PROMPT @@ -320,14 +321,14 @@ def is_logged_in(client): return get_user(client).is_authenticated -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_get(logged_in_client, rp_settings): rsp = logged_in_client.get(reverse("oauth2_provider:rp-initiated-logout"), data={}) assert rsp.status_code == 200 assert is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_get_id_token(logged_in_client, oidc_tokens, rp_settings): rsp = logged_in_client.get( reverse("oauth2_provider:rp-initiated-logout"), data={"id_token_hint": oidc_tokens.id_token} @@ -337,7 +338,7 @@ def test_rp_initiated_logout_get_id_token(logged_in_client, oidc_tokens, rp_sett assert not is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_get_revoked_id_token(logged_in_client, oidc_tokens, rp_settings): validator = oauth2_settings.OAUTH2_VALIDATOR_CLASS() validator._load_id_token(oidc_tokens.id_token).revoke() @@ -348,7 +349,7 @@ def test_rp_initiated_logout_get_revoked_id_token(logged_in_client, oidc_tokens, assert is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_get_id_token_redirect(logged_in_client, oidc_tokens, rp_settings): rsp = logged_in_client.get( reverse("oauth2_provider:rp-initiated-logout"), @@ -359,7 +360,7 @@ def test_rp_initiated_logout_get_id_token_redirect(logged_in_client, oidc_tokens assert not is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_get_id_token_redirect_with_state(logged_in_client, oidc_tokens, rp_settings): rsp = logged_in_client.get( reverse("oauth2_provider:rp-initiated-logout"), @@ -374,7 +375,7 @@ def test_rp_initiated_logout_get_id_token_redirect_with_state(logged_in_client, assert not is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_get_id_token_missmatch_client_id( logged_in_client, oidc_tokens, public_application, rp_settings ): @@ -386,7 +387,7 @@ def test_rp_initiated_logout_get_id_token_missmatch_client_id( assert is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_public_client_redirect_client_id( logged_in_client, oidc_non_confidential_tokens, public_application, rp_settings ): @@ -402,7 +403,7 @@ def test_rp_initiated_logout_public_client_redirect_client_id( assert not is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_public_client_strict_redirect_client_id( logged_in_client, oidc_non_confidential_tokens, public_application, oauth2_settings ): @@ -419,7 +420,7 @@ def test_rp_initiated_logout_public_client_strict_redirect_client_id( assert is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_get_client_id(logged_in_client, oidc_tokens, rp_settings): rsp = logged_in_client.get( reverse("oauth2_provider:rp-initiated-logout"), data={"client_id": oidc_tokens.application.client_id} @@ -428,7 +429,7 @@ def test_rp_initiated_logout_get_client_id(logged_in_client, oidc_tokens, rp_set assert is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_post(logged_in_client, oidc_tokens, rp_settings): form_data = { "client_id": oidc_tokens.application.client_id, @@ -438,7 +439,7 @@ def test_rp_initiated_logout_post(logged_in_client, oidc_tokens, rp_settings): assert is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_post_allowed(logged_in_client, oidc_tokens, rp_settings): form_data = {"client_id": oidc_tokens.application.client_id, "allow": True} rsp = logged_in_client.post(reverse("oauth2_provider:rp-initiated-logout"), form_data) @@ -447,7 +448,7 @@ def test_rp_initiated_logout_post_allowed(logged_in_client, oidc_tokens, rp_sett assert not is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_rp_initiated_logout_post_no_session(client, oidc_tokens, rp_settings): form_data = {"client_id": oidc_tokens.application.client_id, "allow": True} rsp = client.post(reverse("oauth2_provider:rp-initiated-logout"), form_data) @@ -456,7 +457,7 @@ def test_rp_initiated_logout_post_no_session(client, oidc_tokens, rp_settings): assert not is_logged_in(client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT) def test_rp_initiated_logout_expired_tokens_accept(logged_in_client, application, expired_id_token): # Accepting expired (but otherwise valid and signed by us) tokens is enabled. Logout should go through. @@ -471,7 +472,7 @@ def test_rp_initiated_logout_expired_tokens_accept(logged_in_client, application assert not is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT_DENY_EXPIRED) def test_rp_initiated_logout_expired_tokens_deny(logged_in_client, application, expired_id_token): # Expired tokens should not be accepted by default. @@ -486,14 +487,14 @@ def test_rp_initiated_logout_expired_tokens_deny(logged_in_client, application, assert is_logged_in(logged_in_client) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT) def test_load_id_token_accept_expired(expired_id_token): id_token, _ = _load_id_token(expired_id_token) assert isinstance(id_token, get_id_token_model()) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT) def test_load_id_token_wrong_aud(id_token_wrong_aud): id_token, claims = _load_id_token(id_token_wrong_aud) @@ -501,7 +502,7 @@ def test_load_id_token_wrong_aud(id_token_wrong_aud): assert claims is None -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT_DENY_EXPIRED) def test_load_id_token_deny_expired(expired_id_token): id_token, claims = _load_id_token(expired_id_token) @@ -509,7 +510,7 @@ def test_load_id_token_deny_expired(expired_id_token): assert claims is None -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT) def test_validate_claims_wrong_iss(id_token_wrong_iss): id_token, claims = _load_id_token(id_token_wrong_iss) @@ -518,7 +519,7 @@ def test_validate_claims_wrong_iss(id_token_wrong_iss): assert not _validate_claims(mock_request(), claims) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT) def test_validate_claims(oidc_tokens): id_token, claims = _load_id_token(oidc_tokens.id_token) @@ -526,7 +527,7 @@ def test_validate_claims(oidc_tokens): assert _validate_claims(mock_request_for(oidc_tokens.user), claims) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.parametrize("method", ["get", "post"]) def test_userinfo_endpoint(oidc_tokens, client, method): auth_header = "Bearer %s" % oidc_tokens.access_token @@ -539,7 +540,7 @@ def test_userinfo_endpoint(oidc_tokens, client, method): assert data["sub"] == str(oidc_tokens.user.pk) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_userinfo_endpoint_bad_token(oidc_tokens, client): # No access token rsp = client.get(reverse("oauth2_provider:user-info")) @@ -552,7 +553,7 @@ def test_userinfo_endpoint_bad_token(oidc_tokens, client): assert rsp.status_code == 401 -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_token_deletion_on_logout(oidc_tokens, logged_in_client, rp_settings): AccessToken = get_access_token_model() IDToken = get_id_token_model() @@ -575,7 +576,7 @@ def test_token_deletion_on_logout(oidc_tokens, logged_in_client, rp_settings): assert all([token.revoked <= timezone.now() for token in RefreshToken.objects.all()]) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_token_deletion_on_logout_expired_session(oidc_tokens, client, rp_settings): AccessToken = get_access_token_model() IDToken = get_id_token_model() @@ -616,7 +617,7 @@ def test_token_deletion_on_logout_expired_session(oidc_tokens, client, rp_settin assert all(token.revoked <= timezone.now() for token in RefreshToken.objects.all()) -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) @pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_RP_LOGOUT_KEEP_TOKENS) def test_token_deletion_on_logout_disabled(oidc_tokens, logged_in_client, rp_settings): rp_settings.OIDC_RP_INITIATED_LOGOUT_DELETE_TOKENS = False @@ -652,7 +653,7 @@ def claim_user_email(request): return EXAMPLE_EMAIL -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_userinfo_endpoint_custom_claims_callable(oidc_tokens, client, oauth2_settings): class CustomValidator(OAuth2Validator): oidc_claim_scope = None @@ -680,7 +681,7 @@ def get_additional_claims(self): assert data["email"] == EXAMPLE_EMAIL -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_userinfo_endpoint_custom_claims_email_scope_callable( oidc_email_scope_tokens, client, oauth2_settings ): @@ -707,7 +708,7 @@ def get_additional_claims(self): assert data["email"] == EXAMPLE_EMAIL -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_userinfo_endpoint_custom_claims_plain(oidc_tokens, client, oauth2_settings): class CustomValidator(OAuth2Validator): oidc_claim_scope = None @@ -735,7 +736,7 @@ def get_additional_claims(self, request): assert data["email"] == EXAMPLE_EMAIL -@pytest.mark.django_db +@pytest.mark.django_db(databases=retrieve_current_databases()) def test_userinfo_endpoint_custom_claims_email_scopeplain(oidc_email_scope_tokens, client, oauth2_settings): class CustomValidator(OAuth2Validator): def get_additional_claims(self, request): From ae4c4e6b781327751d78ed7d41814cbbddef6efa Mon Sep 17 00:00:00 2001 From: Sean 'Shaleh' Perry Date: Fri, 16 Aug 2024 09:00:10 -0700 Subject: [PATCH 14/15] Updates. Prove the checks work. Document test requirements. --- docs/contributing.rst | 20 ++++++++++++++++++++ tests/common_testing.py | 18 +++++------------- tests/db_router.py | 24 ++++++++++++++++++++++++ tests/test_django_checks.py | 20 ++++++++++++++++++++ 4 files changed, 69 insertions(+), 13 deletions(-) create mode 100644 tests/test_django_checks.py diff --git a/docs/contributing.rst b/docs/contributing.rst index 425008a62..8d621b413 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -252,6 +252,26 @@ Open :file:`mycoverage/index.html` in your browser and you can see a coverage su There's no need to wait for Codecov to complain after you submit your PR. +The tests are generic and written to work with both single database and multiple database configurations. tox will run +tests both ways. You can see the configurations used in tests/settings.py and tests/multi_db_settins.py. + +When there are multiple databases defined, Django tests will not work unless they are told which database(s) to work with. +For test writers this means any test must either: +- instead of Django's TestCase or TransactionTestCase use the versions of those + classes defined in tests/common_testing.py +- when using pytest's `django_db` mark, define it like this: + `@pytest.mark.django_db(databases=retrieve_current_databases())` + +In test code, anywhere the database is referenced the Django router needs to be used exactly like the package's code. + +.. code-block:: python + + token_database = router.db_for_write(AccessToken) + with self.assertNumQueries(1, using=token_database): + # call something using the database + +Without the 'using' option, this test fails in the multiple database scenario because 'default' will be used instead. + Code conventions matter ----------------------- diff --git a/tests/common_testing.py b/tests/common_testing.py index f530d67da..6f6a5b745 100644 --- a/tests/common_testing.py +++ b/tests/common_testing.py @@ -3,20 +3,12 @@ from django.test import TransactionTestCase as DjangoTransactionTestCase +# The multiple database scenario setup for these tests purposefully defines 'default' as +# an empty database in order to catch any assumptions in this package about database names +# and in particular to ensure there is no assumption that 'default' is a valid database. +# # When there are multiple databases defined, Django tests will not work unless they are -# told which database(s) to work with. The multiple database scenario setup for these -# tests purposefully defines 'default' as an empty database in order to catch any -# assumptions in this package about database names and in particular to ensure there is -# no assumption that 'default' is a valid database. -# For any test that would usually use Django's TestCase or TransactionTestCase using -# the classes defined here is all that is required. -# In test code, anywhere the database is referenced the Django router needs to be used -# exactly like the package's code. -# For instance: -# token_database = router.db_for_write(AccessToken) -# with self.assertNumQueries(1, using=token_database): -# Without the 'using' option, this test fails in the multiple database scenario because -# 'default' is used. +# told which database(s) to work with. def retrieve_current_databases(): diff --git a/tests/db_router.py b/tests/db_router.py index e7fdbd796..7aa354ed8 100644 --- a/tests/db_router.py +++ b/tests/db_router.py @@ -49,4 +49,28 @@ def allow_relation(self, obj1, obj2, **hints): def allow_migrate(self, db, app_label, model_name=None, **hints): if app_label in apps_in_beta: return db == "beta" + + +class CrossDatabaseRouter: + # alpha is where the core Django models are stored including user. To keep things + # simple this is where the oauth2 provider models are stored as well because they + # have a foreign key to User. + def db_for_read(self, model, **hints): + if model._meta.model_name == "accesstoken": + return "beta" + return None + + def db_for_write(self, model, **hints): + if model._meta.model_name == "accesstoken": + return "beta" + return None + + def allow_relation(self, obj1, obj2, **hints): + if obj1._state.db == "beta" and obj2._state.db == "beta": + return True + return None + + def allow_migrate(self, db, app_label, model_name=None, **hints): + if model_name == "accesstoken": + return db == "beta" return None diff --git a/tests/test_django_checks.py b/tests/test_django_checks.py new file mode 100644 index 000000000..77025b115 --- /dev/null +++ b/tests/test_django_checks.py @@ -0,0 +1,20 @@ +from django.core.management import call_command +from django.core.management.base import SystemCheckError +from django.test import override_settings + +from .common_testing import OAuth2ProviderTestCase as TestCase + + +class DjangoChecksTestCase(TestCase): + def test_checks_pass(self): + call_command("check") + + # CrossDatabaseRouter claims AccessToken is in beta while everything else is in alpha. + # This will cause the database checks to fail. + @override_settings( + DATABASE_ROUTERS=["tests.db_router.CrossDatabaseRouter", "tests.db_router.AlphaRouter"] + ) + def test_checks_fail_when_router_crosses_databases(self): + message = "The token models are expected to be stored in the same database." + with self.assertRaisesMessage(SystemCheckError, message): + call_command("check") From ceaebc9800fcdb6f6de1fcc740d02a560590d143 Mon Sep 17 00:00:00 2001 From: Alan Crosswell Date: Mon, 26 Aug 2024 12:02:34 -0400 Subject: [PATCH 15/15] fix typo --- docs/contributing.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/contributing.rst b/docs/contributing.rst index 8d621b413..648993024 100644 --- a/docs/contributing.rst +++ b/docs/contributing.rst @@ -253,7 +253,7 @@ Open :file:`mycoverage/index.html` in your browser and you can see a coverage su There's no need to wait for Codecov to complain after you submit your PR. The tests are generic and written to work with both single database and multiple database configurations. tox will run -tests both ways. You can see the configurations used in tests/settings.py and tests/multi_db_settins.py. +tests both ways. You can see the configurations used in tests/settings.py and tests/multi_db_settings.py. When there are multiple databases defined, Django tests will not work unless they are told which database(s) to work with. For test writers this means any test must either: