Skip to content
This repository was archived by the owner on Mar 3, 2020. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ venv
*.sqlite
.coverage
.idea/
htmlcov/
26 changes: 8 additions & 18 deletions provider/oauth2/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from django.http import QueryDict
from django.test import TestCase
from django.utils.html import escape
from mock import patch

from provider import constants, scope
from provider.oauth2.backends import AccessTokenBackend, BasicClientBackend, RequestParamsClientBackend
Expand Down Expand Up @@ -166,8 +167,8 @@ def test_token_authorization_redirects_to_correct_uri(self):
self.assertEqual(url, self.get_client().redirect_uri)
self.assertTrue('access_token' in urlparse.parse_qs(fragment))

@patch('provider.constants.SINGLE_ACCESS_TOKEN', True)
def test_token_ignores_expired_tokens(self):
constants.SINGLE_ACCESS_TOKEN = True
AccessToken.objects.create(
user=self.get_user(),
client=self.get_client(),
Expand All @@ -179,11 +180,9 @@ def test_token_ignores_expired_tokens(self):
self.client.post(self.auth_url2(), data={'authorize': 'Authorize'})

self.assertEqual(AccessToken.objects.count(), 2)
constants.SINGLE_ACCESS_TOKEN = False

@patch('provider.constants.SINGLE_ACCESS_TOKEN', True)
def test_token_doesnt_return_tokens_from_another_client(self):
constants.SINGLE_ACCESS_TOKEN = True

# Different client than we'll be submitting an RPC for.
AccessToken.objects.create(
user=self.get_user(),
Expand All @@ -195,10 +194,9 @@ def test_token_doesnt_return_tokens_from_another_client(self):
self.client.post(self.auth_url2(), data={'authorize': 'Authorize'})

self.assertEqual(AccessToken.objects.count(), 2)
constants.SINGLE_ACCESS_TOKEN = False

@patch('provider.constants.SINGLE_ACCESS_TOKEN', True)
def test_token_authorization_respects_single_access_token_constant(self):
constants.SINGLE_ACCESS_TOKEN = True
self.login()
self.client.get(self.auth_url(), data=self.get_auth_params(response_type="token"))
self.client.post(self.auth_url2(), data={'authorize': 'Authorize'})
Expand All @@ -210,10 +208,9 @@ def test_token_authorization_respects_single_access_token_constant(self):
self.client.post(self.auth_url2(), data={'authorize': 'Authorize'})

self.assertEqual(AccessToken.objects.count(), 1)
constants.SINGLE_ACCESS_TOKEN = False

@patch('provider.constants.SINGLE_ACCESS_TOKEN', False)
def test_token_authorization_can_do_multi_access_tokens(self):
constants.SINGLE_ACCESS_TOKEN = False
self.login()
self.client.get(self.auth_url(), data=self.get_auth_params(response_type="token"))
self.client.post(self.auth_url2(), data={'authorize': 'Authorize'})
Expand All @@ -226,8 +223,8 @@ def test_token_authorization_can_do_multi_access_tokens(self):

self.assertEqual(AccessToken.objects.count(), 2)

@patch('provider.constants.SINGLE_ACCESS_TOKEN', False)
def test_token_authorization_cancellation(self):
constants.SINGLE_ACCESS_TOKEN = False
self.login()
self.client.get(self.auth_url(), data=self.get_auth_params(response_type="token"))
self.client.post(self.auth_url2())
Expand Down Expand Up @@ -436,19 +433,14 @@ def test_fetching_access_token_with_invalid_grant_type(self):
self.assertEqual(400, response.status_code)
self.assertEqual('unsupported_grant_type', json.loads(response.content)['error'], response.content)

@patch('provider.constants.SINGLE_ACCESS_TOKEN', True)
def test_fetching_single_access_token(self):
constants.SINGLE_ACCESS_TOKEN = True

result1 = self._login_authorize_get_token()
result2 = self._login_authorize_get_token()

self.assertEqual(result1['access_token'], result2['access_token'])

constants.SINGLE_ACCESS_TOKEN = False

def test_fetching_single_access_token_after_refresh(self):
constants.SINGLE_ACCESS_TOKEN = True

token = self._login_authorize_get_token()

self.client.post(self.access_token_url(), {
Expand All @@ -461,8 +453,6 @@ def test_fetching_single_access_token_after_refresh(self):
new_token = self._login_authorize_get_token()
self.assertNotEqual(token['access_token'], new_token['access_token'])

constants.SINGLE_ACCESS_TOKEN = False

def test_fetching_access_token_multiple_times(self):
self._login_authorize_get_token()
code = self.get_grant().code
Expand Down Expand Up @@ -534,7 +524,7 @@ def test_password_grant_public(self):

def test_password_grant_confidential(self):
c = self.get_client()
c.client_type = 0 # confidential
c.client_type = constants.CONFIDENTIAL
c.save()

response = self.client.post(self.access_token_url(), {
Expand Down
11 changes: 4 additions & 7 deletions provider/oauth2/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,19 @@

from provider import constants
from provider.oauth2.backends import BasicClientBackend, RequestParamsClientBackend, PublicPasswordBackend
from provider.oauth2.forms import AuthorizationCodeGrantForm
from provider.oauth2.forms import AuthorizationRequestForm, AuthorizationForm
from provider.oauth2.forms import PasswordGrantForm, RefreshTokenGrantForm
from provider.oauth2.forms import (AuthorizationCodeGrantForm, AuthorizationRequestForm, AuthorizationForm,
PasswordGrantForm, RefreshTokenGrantForm)
from provider.oauth2.models import Client, RefreshToken, AccessToken
from provider.utils import now
from provider.views import AccessToken as AccessTokenView, OAuthError, AccessTokenMixin
from provider.views import Capture, Authorize, Redirect
from provider.views import AccessToken as AccessTokenView, OAuthError, AccessTokenMixin, Capture, Authorize, Redirect


class OAuth2AccessTokenMixin(AccessTokenMixin):

def get_access_token(self, request, user, scope, client):
try:
# Attempt to fetch an existing access token.
at = AccessToken.objects.get(user=user, client=client,
scope=scope, expires__gt=now())
at = AccessToken.objects.get(user=user, client=client, scope=scope, expires__gt=now())
except AccessToken.DoesNotExist:
# None found... make a new one!
at = self.create_access_token(request, user, scope, client)
Expand Down
2 changes: 1 addition & 1 deletion provider/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def password(self, request, data, client):
else:
at = self.create_access_token(request, user, scope, client)
# Public clients don't get refresh tokens
if client.client_type != 1:
if client.client_type == constants.CONFIDENTIAL:
rt = self.create_refresh_token(request, user, scope, at, client)

return self.access_token_response(at)
Expand Down