diff --git a/.gitignore b/.gitignore index 11760ec5..09c45338 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ venv *.sqlite .coverage .idea/ +htmlcov/ diff --git a/provider/oauth2/tests.py b/provider/oauth2/tests.py index 78e15b4c..5ea494fb 100644 --- a/provider/oauth2/tests.py +++ b/provider/oauth2/tests.py @@ -9,6 +9,7 @@ from django.http import QueryDict from django.test import TestCase from django.utils.html import escape +from mock import patch from provider import constants, scope from provider.oauth2.backends import AccessTokenBackend, BasicClientBackend, RequestParamsClientBackend @@ -166,8 +167,8 @@ def test_token_authorization_redirects_to_correct_uri(self): self.assertEqual(url, self.get_client().redirect_uri) self.assertTrue('access_token' in urlparse.parse_qs(fragment)) + @patch('provider.constants.SINGLE_ACCESS_TOKEN', True) def test_token_ignores_expired_tokens(self): - constants.SINGLE_ACCESS_TOKEN = True AccessToken.objects.create( user=self.get_user(), client=self.get_client(), @@ -179,11 +180,9 @@ def test_token_ignores_expired_tokens(self): self.client.post(self.auth_url2(), data={'authorize': 'Authorize'}) self.assertEqual(AccessToken.objects.count(), 2) - constants.SINGLE_ACCESS_TOKEN = False + @patch('provider.constants.SINGLE_ACCESS_TOKEN', True) def test_token_doesnt_return_tokens_from_another_client(self): - constants.SINGLE_ACCESS_TOKEN = True - # Different client than we'll be submitting an RPC for. AccessToken.objects.create( user=self.get_user(), @@ -195,10 +194,9 @@ def test_token_doesnt_return_tokens_from_another_client(self): self.client.post(self.auth_url2(), data={'authorize': 'Authorize'}) self.assertEqual(AccessToken.objects.count(), 2) - constants.SINGLE_ACCESS_TOKEN = False + @patch('provider.constants.SINGLE_ACCESS_TOKEN', True) def test_token_authorization_respects_single_access_token_constant(self): - constants.SINGLE_ACCESS_TOKEN = True self.login() self.client.get(self.auth_url(), data=self.get_auth_params(response_type="token")) self.client.post(self.auth_url2(), data={'authorize': 'Authorize'}) @@ -210,10 +208,9 @@ def test_token_authorization_respects_single_access_token_constant(self): self.client.post(self.auth_url2(), data={'authorize': 'Authorize'}) self.assertEqual(AccessToken.objects.count(), 1) - constants.SINGLE_ACCESS_TOKEN = False + @patch('provider.constants.SINGLE_ACCESS_TOKEN', False) def test_token_authorization_can_do_multi_access_tokens(self): - constants.SINGLE_ACCESS_TOKEN = False self.login() self.client.get(self.auth_url(), data=self.get_auth_params(response_type="token")) self.client.post(self.auth_url2(), data={'authorize': 'Authorize'}) @@ -226,8 +223,8 @@ def test_token_authorization_can_do_multi_access_tokens(self): self.assertEqual(AccessToken.objects.count(), 2) + @patch('provider.constants.SINGLE_ACCESS_TOKEN', False) def test_token_authorization_cancellation(self): - constants.SINGLE_ACCESS_TOKEN = False self.login() self.client.get(self.auth_url(), data=self.get_auth_params(response_type="token")) self.client.post(self.auth_url2()) @@ -436,19 +433,14 @@ def test_fetching_access_token_with_invalid_grant_type(self): self.assertEqual(400, response.status_code) self.assertEqual('unsupported_grant_type', json.loads(response.content)['error'], response.content) + @patch('provider.constants.SINGLE_ACCESS_TOKEN', True) def test_fetching_single_access_token(self): - constants.SINGLE_ACCESS_TOKEN = True - result1 = self._login_authorize_get_token() result2 = self._login_authorize_get_token() self.assertEqual(result1['access_token'], result2['access_token']) - constants.SINGLE_ACCESS_TOKEN = False - def test_fetching_single_access_token_after_refresh(self): - constants.SINGLE_ACCESS_TOKEN = True - token = self._login_authorize_get_token() self.client.post(self.access_token_url(), { @@ -461,8 +453,6 @@ def test_fetching_single_access_token_after_refresh(self): new_token = self._login_authorize_get_token() self.assertNotEqual(token['access_token'], new_token['access_token']) - constants.SINGLE_ACCESS_TOKEN = False - def test_fetching_access_token_multiple_times(self): self._login_authorize_get_token() code = self.get_grant().code @@ -534,7 +524,7 @@ def test_password_grant_public(self): def test_password_grant_confidential(self): c = self.get_client() - c.client_type = 0 # confidential + c.client_type = constants.CONFIDENTIAL c.save() response = self.client.post(self.access_token_url(), { diff --git a/provider/oauth2/views.py b/provider/oauth2/views.py index 6b393609..0d51a4df 100644 --- a/provider/oauth2/views.py +++ b/provider/oauth2/views.py @@ -8,13 +8,11 @@ from provider import constants from provider.oauth2.backends import BasicClientBackend, RequestParamsClientBackend, PublicPasswordBackend -from provider.oauth2.forms import AuthorizationCodeGrantForm -from provider.oauth2.forms import AuthorizationRequestForm, AuthorizationForm -from provider.oauth2.forms import PasswordGrantForm, RefreshTokenGrantForm +from provider.oauth2.forms import (AuthorizationCodeGrantForm, AuthorizationRequestForm, AuthorizationForm, + PasswordGrantForm, RefreshTokenGrantForm) from provider.oauth2.models import Client, RefreshToken, AccessToken from provider.utils import now -from provider.views import AccessToken as AccessTokenView, OAuthError, AccessTokenMixin -from provider.views import Capture, Authorize, Redirect +from provider.views import AccessToken as AccessTokenView, OAuthError, AccessTokenMixin, Capture, Authorize, Redirect class OAuth2AccessTokenMixin(AccessTokenMixin): @@ -22,8 +20,7 @@ class OAuth2AccessTokenMixin(AccessTokenMixin): def get_access_token(self, request, user, scope, client): try: # Attempt to fetch an existing access token. - at = AccessToken.objects.get(user=user, client=client, - scope=scope, expires__gt=now()) + at = AccessToken.objects.get(user=user, client=client, scope=scope, expires__gt=now()) except AccessToken.DoesNotExist: # None found... make a new one! at = self.create_access_token(request, user, scope, client) diff --git a/provider/views.py b/provider/views.py index 6cd3a0c3..8bca5d13 100644 --- a/provider/views.py +++ b/provider/views.py @@ -605,7 +605,7 @@ def password(self, request, data, client): else: at = self.create_access_token(request, user, scope, client) # Public clients don't get refresh tokens - if client.client_type != 1: + if client.client_type == constants.CONFIDENTIAL: rt = self.create_refresh_token(request, user, scope, at, client) return self.access_token_response(at)