Skip to content

Commit ec3ccd9

Browse files
committed
Add test for session state on authorization view
1 parent f587442 commit ec3ccd9

File tree

6 files changed

+60
-47
lines changed

6 files changed

+60
-47
lines changed

oauth2_provider/settings.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,7 @@ def import_from_string(val, setting_name):
173173
try:
174174
return import_string(val)
175175
except ImportError as e:
176-
msg = "Could not import %r for setting %r. %s: %s." % (
177-
val,
178-
setting_name,
179-
e.__class__.__name__,
180-
e,
181-
)
176+
msg = "Could not import %r for setting %r. %s: %s." % (val, setting_name, e.__class__.__name__, e)
182177
raise ImportError(msg)
183178

184179

oauth2_provider/urls.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,7 @@
1717
management_urlpatterns = [
1818
# Application management views
1919
path("applications/", views.ApplicationList.as_view(), name="list"),
20-
path(
21-
"applications/register/",
22-
views.ApplicationRegistration.as_view(),
23-
name="register",
24-
),
20+
path("applications/register/", views.ApplicationRegistration.as_view(), name="register"),
2521
path("applications/<slug:pk>/", views.ApplicationDetail.as_view(), name="detail"),
2622
path("applications/<slug:pk>/delete/", views.ApplicationDelete.as_view(), name="delete"),
2723
path("applications/<slug:pk>/update/", views.ApplicationUpdate.as_view(), name="update"),

oauth2_provider/views/base.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,7 @@ def form_valid(self, form):
137137

138138
try:
139139
uri, headers, body, status = self.create_authorization_response(
140-
request=self.request,
141-
scopes=scopes,
142-
credentials=credentials,
143-
allow=allow,
140+
request=self.request, scopes=scopes, credentials=credentials, allow=allow
144141
)
145142
except OAuthToolkitError as error:
146143
return self.error_response(error, application)
@@ -160,7 +157,7 @@ def form_valid(self, form):
160157
salt = secrets.token_urlsafe(16)
161158
encoded = " ".join(
162159
[
163-
self.client.client_id,
160+
credentials["client_id"],
164161
client_origin,
165162
session_management_state_key(self.request),
166163
salt,
@@ -231,20 +228,15 @@ def get(self, request, *args, **kwargs):
231228
# are already approved.
232229
if application.skip_authorization:
233230
uri, headers, body, status = self.create_authorization_response(
234-
request=self.request,
235-
scopes=" ".join(scopes),
236-
credentials=credentials,
237-
allow=True,
231+
request=self.request, scopes=" ".join(scopes), credentials=credentials, allow=True
238232
)
239233
return self.redirect(uri, application)
240234

241235
elif require_approval == "auto":
242236
tokens = (
243237
get_access_token_model()
244238
.objects.filter(
245-
user=request.user,
246-
application=kwargs["application"],
247-
expires__gt=timezone.now(),
239+
user=request.user, application=kwargs["application"], expires__gt=timezone.now()
248240
)
249241
.all()
250242
)
@@ -253,10 +245,7 @@ def get(self, request, *args, **kwargs):
253245
for token in tokens:
254246
if token.allow_scopes(scopes):
255247
uri, headers, body, status = self.create_authorization_response(
256-
request=self.request,
257-
scopes=" ".join(scopes),
258-
credentials=credentials,
259-
allow=True,
248+
request=self.request, scopes=" ".join(scopes), credentials=credentials, allow=True
260249
)
261250
return self.redirect(uri, application)
262251

oauth2_provider/views/oidc.py

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,7 @@ def get(self, request, *args, **kwargs):
8383

8484
signing_algorithms = [Application.HS256_ALGORITHM]
8585
if oauth2_settings.OIDC_RSA_PRIVATE_KEY:
86-
signing_algorithms = [
87-
Application.RS256_ALGORITHM,
88-
Application.HS256_ALGORITHM,
89-
]
86+
signing_algorithms = [Application.RS256_ALGORITHM, Application.HS256_ALGORITHM]
9087

9188
validator_class = oauth2_settings.OAUTH2_VALIDATOR_CLASS
9289
validator = validator_class()
@@ -251,10 +248,7 @@ class RPInitiatedLogoutView(OIDCLogoutOnlyMixin, FormView):
251248
form_class = ConfirmLogoutForm
252249
# Only delete tokens for Application whose client type and authorization
253250
# grant type are in the respective lists.
254-
token_deletion_client_types = [
255-
Application.CLIENT_PUBLIC,
256-
Application.CLIENT_CONFIDENTIAL,
257-
]
251+
token_deletion_client_types = [Application.CLIENT_PUBLIC, Application.CLIENT_CONFIDENTIAL]
258252
token_deletion_grant_types = [
259253
Application.GRANT_AUTHORIZATION_CODE,
260254
Application.GRANT_IMPLICIT,
@@ -458,13 +452,7 @@ def must_prompt(self, token_user):
458452
""" We didn't find a reason to prompt the user """
459453
return False
460454

461-
def do_logout(
462-
self,
463-
application=None,
464-
post_logout_redirect_uri=None,
465-
state=None,
466-
token_user=None,
467-
):
455+
def do_logout(self, application=None, post_logout_redirect_uri=None, state=None, token_user=None):
468456
user = token_user or self.request.user
469457
# Delete Access Tokens if a user was found
470458
if oauth2_settings.OIDC_RP_INITIATED_LOGOUT_DELETE_TOKENS and not isinstance(user, AnonymousUser):
@@ -501,8 +489,7 @@ def do_logout(
501489
return OAuth2ResponseRedirect(post_logout_redirect_uri, application.get_allowed_schemes())
502490
else:
503491
return OAuth2ResponseRedirect(
504-
self.request.build_absolute_uri("/"),
505-
oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES,
492+
self.request.build_absolute_uri("/"), oauth2_settings.ALLOWED_REDIRECT_URI_SCHEMES
506493
)
507494

508495
def error_response(self, error):

tests/presets.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
OIDC_SETTINGS_RP_LOGOUT_DENY_EXPIRED["OIDC_RP_INITIATED_LOGOUT_ACCEPT_EXPIRED_TOKENS"] = False
3838
OIDC_SETTINGS_RP_LOGOUT_KEEP_TOKENS = deepcopy(OIDC_SETTINGS_RP_LOGOUT)
3939
OIDC_SETTINGS_RP_LOGOUT_KEEP_TOKENS["OIDC_RP_INITIATED_LOGOUT_DELETE_TOKENS"] = False
40+
OIDC_SETTINGS_SESSION_MANAGEMENT = deepcopy(OIDC_SETTINGS_RW)
41+
OIDC_SETTINGS_SESSION_MANAGEMENT["OIDC_SESSION_MANAGEMENT_ENABLED"] = True
4042
REST_FRAMEWORK_SCOPES = {
4143
"SCOPES": {
4244
"read": "Read scope",

tests/test_oidc_views.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from django.contrib.auth import get_user
2+
from django.contrib.auth import get_user, get_user_model
33
from django.contrib.auth.models import AnonymousUser
44
from django.test import RequestFactory
55
from django.urls import reverse
@@ -12,7 +12,12 @@
1212
InvalidOIDCClientError,
1313
InvalidOIDCRedirectURIError,
1414
)
15-
from oauth2_provider.models import get_access_token_model, get_id_token_model, get_refresh_token_model
15+
from oauth2_provider.models import (
16+
get_access_token_model,
17+
get_application_model,
18+
get_id_token_model,
19+
get_refresh_token_model,
20+
)
1621
from oauth2_provider.oauth2_validators import OAuth2Validator
1722
from oauth2_provider.settings import oauth2_settings
1823
from oauth2_provider.views.oidc import RPInitiatedLogoutView, _load_id_token, _validate_claims
@@ -132,7 +137,10 @@ def test_get_connect_discovery_info_without_issuer_url(self):
132137
],
133138
"subject_types_supported": ["public"],
134139
"id_token_signing_alg_values_supported": ["RS256", "HS256"],
135-
"token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
140+
"token_endpoint_auth_methods_supported": [
141+
"client_secret_post",
142+
"client_secret_basic",
143+
],
136144
"code_challenge_methods_supported": ["plain", "S256"],
137145
"claims_supported": ["sub"],
138146
}
@@ -206,6 +214,42 @@ def test_get_jwks_info_multiple_rsa_keys(self):
206214
assert response.json() == expected_response
207215

208216

217+
@pytest.mark.usefixtures("oauth2_settings")
218+
@pytest.mark.oauth2_settings(presets.OIDC_SETTINGS_SESSION_MANAGEMENT)
219+
class TestAuthorizationView(TestCase):
220+
def test_session_state_is_present_in_url(self):
221+
User = get_user_model()
222+
Application = get_application_model()
223+
224+
User.objects.create_user("test_user", "[email protected]", "123456")
225+
dev_user = User.objects.create_user("dev_user", "[email protected]", "123456")
226+
227+
application = Application.objects.create(
228+
name="Test Application",
229+
redirect_uris=(
230+
"http://localhost http://example.com http://example.org custom-scheme://example.com"
231+
),
232+
user=dev_user,
233+
client_type=Application.CLIENT_CONFIDENTIAL,
234+
authorization_grant_type=Application.GRANT_AUTHORIZATION_CODE,
235+
client_secret="1234567890qwertyuiop",
236+
)
237+
self.client.login(username="test_user", password="123456")
238+
response = self.client.post(
239+
reverse("oauth2_provider:authorize"),
240+
{
241+
"client_id": application.client_id,
242+
"response_type": "code",
243+
"state": "random_state_string",
244+
"scope": "read write",
245+
"redirect_uri": "http://example.org",
246+
"allow": True,
247+
},
248+
)
249+
self.assertEqual(response.status_code, 302)
250+
self.assertTrue("session_state" in response["Location"])
251+
252+
209253
def mock_request():
210254
"""
211255
Dummy request with an AnonymousUser attached.

0 commit comments

Comments
 (0)