diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 00000000..dd2aa46c --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,35 @@ +# Read the Docs configuration file for Sphinx projects +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the OS, Python version and other tools you might need +build: + os: ubuntu-22.04 + tools: + python: "3.12" + # You can also specify other tool versions: + # nodejs: "20" + # rust: "1.70" + # golang: "1.20" + +# Build documentation in the "docs/" directory with Sphinx +sphinx: + configuration: docs/conf.py + # You can configure Sphinx to use a different builder, for instance use the dirhtml builder for simpler URLs + # builder: "dirhtml" + # Fail on all warnings to avoid broken references + # fail_on_warning: true + +# Optionally build your docs in additional formats such as PDF and ePub +# formats: +# - pdf +# - epub + +# Optional but recommended, declare the Python requirements required +# to build your documentation +# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html +# python: +# install: +# - requirements: docs/requirements.txt diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..5519a401 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,20 @@ +ARG PYVERSION=3.9.19-bullseye + +FROM python:${PYVERSION} AS dev + +WORKDIR /app + +COPY requirements.txt /app/ + +RUN apt-get update \ + && apt-get install -q -y \ + jq \ + && apt-get clean + +RUN pip install -r requirements.txt + +FROM dev as prod + +COPY ./ /app/ + + diff --git a/Jenkinsfile b/Jenkinsfile new file mode 100644 index 00000000..81655a8a --- /dev/null +++ b/Jenkinsfile @@ -0,0 +1,47 @@ +pipeline { + agent { + label "python" + } + stages { + stage('Virtualenv'){ + steps { + sh '/usr/bin/virtualenv toxtest -p /usr/bin/python3' + sh 'toxtest/bin/pip install tox==3.28.0 pathlib2' + } + } + stage('Test'){ + parallel { + stage('Unit Test Django 3.1'){ + steps { + sh 'toxtest/bin/tox -e py3.8-django{3.1}' + } + } + stage('Unit Test Django 3.2'){ + steps { + sh 'toxtest/bin/tox -e py3.8-django{3.2}' + } + } + stage('Unit Test Django 4.0'){ + steps { + sh 'toxtest/bin/tox -e py3.8-django{4.0}' + } + } + stage('Unit Test Django 4.1'){ + steps { + sh 'toxtest/bin/tox -e py3.8-django{4.1}' + } + } + stage('Unit Test Django 4.2'){ + steps { + sh 'toxtest/bin/tox -e py3.8-django{4.2}' + } + } + } + } + } + post { + cleanup { + cleanWs() + } + } +} diff --git a/README.rst b/README.rst index 1eb97cad..25051c61 100644 --- a/README.rst +++ b/README.rst @@ -1,7 +1,9 @@ django-oauth2 ====================== -.. image:: https://travis-ci.org/stormsherpa/django-oauth2-provider.png?branch=master +This copy of the repository has been retired! It contains the pull request and release history +since it was forked from caffeinehit/django-oauth2-provider. The current authoritative copy of +this repo is in stormsherpa/django-oauth2-provider. *django-oauth2* is a Django application that provides customizable OAuth2\-authentication for your Django projects. @@ -12,3 +14,10 @@ License ======= *django-oauth2* is a fork of *django-oauth2-provider* which is released under the MIT License. Please see the LICENSE file for details. + + +Packaging +========= + + $ python -m build + diff --git a/aws_identity_example.py b/aws_identity_example.py new file mode 100644 index 00000000..6bc65199 --- /dev/null +++ b/aws_identity_example.py @@ -0,0 +1,68 @@ +import os +import sys +import json + +from datetime import datetime +from urllib import request, error +import requests + +import boto3 +# aws-v4-signature==2.0 +from awsv4sign import generate_http11_header + +service = 'sts' +region = 'us-west-2' + +session = boto3.Session() +creds = session.get_credentials() +access_key = creds.access_key +secret_key = creds.secret_key +session_token = creds.token + +print(f"access_key: {access_key[:10]}") +print(f"secret_key: {secret_key[:10]}") +print(f"session_token: {session_token[:20]}") +print(f"profile: {os.environ.get('AWS_PROFILE')}") + +url = 'https://sts.{region}.amazonaws.com/'.format(region=region) +httpMethod = 'post' +canonicalHeaders = { + 'host': f'sts.{region}.amazonaws.com', + 'x-amz-date': datetime.utcnow().strftime('%Y%m%dT%H%M%SZ'), + 'content-type': 'application/x-www-form-urlencoded; charset=utf-8', +} +if session_token: + canonicalHeaders['x-amz-security-token'] = session_token + +payload_str = "Action=GetCallerIdentity&Version=2011-06-15" + +headers = generate_http11_header( + service, region, access_key, secret_key, + url, 'post', canonicalHeaders, {}, + '', payload_str +) + +token_request_args = { + "grant_type": "aws_identity", + "region": region, + "post_body": payload_str, + "headers_json": json.dumps(headers), +} +print(payload_str) +print(json.dumps(headers, indent=4)) + +req = request.Request("https://sts.us-west-2.amazonaws.com/", data=payload_str.encode('utf-8'), headers=headers, method='POST') +try: + response = request.urlopen(req) + print(f"Local request test result: {response.read()}") +except error.HTTPError as e: + print(f"HTTPError: {e}: {e.fp.read()}") + sys.exit(1) + +print("Attempting access_token grant request with same signed request:\n") + +token_response = requests.post("http://localhost:8000/oauth2/access_token", + data=token_request_args) +token_info = token_response.json() + +print(token_info) diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 00000000..739aba32 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,24 @@ + +services: + test: + build: + context: . + target: dev + user: ${UID} + volumes: + - ${WORKSPACE:-.}:/app + environment: + - DJANGO_SETTINGS_MODULE=tests.settings + + web: + build: + context: . + target: dev + user: ${UID} + volumes: + - ${WORKSPACE:-.}:/app + ports: + - "8000:8000" + environment: + - DJANGO_SETTINGS_MODULE=tests.settings +# entrypoint: [ "python3", "manage.py", "runserver" ] diff --git a/docs/changes.rst b/docs/changes.rst index bc2f4904..b6d78eb9 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -1,3 +1,12 @@ +v 4.1 +----- +* Add aws_identity grant_type +* Update for Django 3.1-4.2 + +v 4.0 +----- +* Update for Django 3.0-4.1 + v 2.4 ----- * Add HTTP Authorization Bearer token support to Oauth2UserMiddleware diff --git a/docs/getting_started.rst b/docs/getting_started.rst index 4308e6fc..23f7b956 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -35,7 +35,7 @@ Add :attr:`provider.oauth2.urls` to your root ``urls.py`` file. :: - url(r'^oauth2/', include('provider.oauth2.urls', namespace = 'oauth2')), + path('oauth2/', include(('provider.oauth2.urls', 'oauth2'))), .. note:: The namespace argument is required. @@ -92,6 +92,27 @@ in :rfc:`4`. .. note:: Remember that you should always use HTTPS for all your OAuth 2 requests otherwise you won't be secured. +Request an Access Token using AWS credentials +--------------------------------------------- + +The new aws_identity grant_type uses the parameters for a signed GetCallerIdentity +request to prove the caller's identity. + +Your client needs to submit a :attr:`POST` request to +:attr:`/oauth2/access_token` including the following parameters: + +* ``region`` - AWS Region +* ``post_body`` - The post body used for signing the request. Usually ``Action=GetCallerIdentity&Version=2011-06-15`` +* ``headers_json`` - The headers produced by the AWSv4 signing process + +The region value is used to produce the standard https://sts.(region).amazonaws.com/ url used to +make the GetCallerIdentity request. The URL is generated server side to reduce the risk of an +attack based on sending an improperly crafted full URL. + +The aws-v4-signature library implements awsv4sign.generate_http11_header(). An example is +presented in the root of the repository in aws_identity_examply.py. + + Integrate with Django Authentication #################################### diff --git a/provider/__init__.py b/provider/__init__.py index 080e846a..dd49e708 100644 --- a/provider/__init__.py +++ b/provider/__init__.py @@ -1 +1,2 @@ -__version__ = "3.2" +__version__ = "4.3" +# The major version is expected to follow the current django major version:q diff --git a/provider/constants.py b/provider/constants.py index cdbe9306..87ed3099 100644 --- a/provider/constants.py +++ b/provider/constants.py @@ -3,10 +3,12 @@ CONFIDENTIAL = 0 PUBLIC = 1 +PKCE = 2 CLIENT_TYPES = ( (CONFIDENTIAL, "Confidential (Web applications)"), - (PUBLIC, "Public (Native and JS applications)") + (PUBLIC, "Public (Native and JS applications)"), + (PKCE, "RFC7636 PKCE (Native, JS, and Web applications)"), ) RESPONSE_TYPE_CHOICES = getattr(settings, 'OAUTH_RESPONSE_TYPE_CHOICES', ("code", "token")) diff --git a/provider/oauth2/admin.py b/provider/oauth2/admin.py index d4711999..00b1c58a 100644 --- a/provider/oauth2/admin.py +++ b/provider/oauth2/admin.py @@ -23,9 +23,15 @@ class AuthorizedClientAdmin(admin.ModelAdmin): raw_id_fields = ('user',) +class AwsAccountAdmin(admin.ModelAdmin): + list_display = ('arn', 'client', 'max_token_lifetime') + raw_id_fields = ('acting_user',) + + admin.site.register(models.AccessToken, AccessTokenAdmin) admin.site.register(models.Grant, GrantAdmin) admin.site.register(models.Client, ClientAdmin) admin.site.register(models.AuthorizedClient, AuthorizedClientAdmin) +admin.site.register(models.AwsAccount, AwsAccountAdmin) admin.site.register(models.RefreshToken) admin.site.register(models.Scope) diff --git a/provider/oauth2/apps.py b/provider/oauth2/apps.py index c9c50344..73b1ae77 100644 --- a/provider/oauth2/apps.py +++ b/provider/oauth2/apps.py @@ -4,3 +4,6 @@ class Oauth2(AppConfig): name = 'provider.oauth2' label = 'oauth2' verbose_name = "Provider Oauth2" + + def ready(self): + import provider.oauth2.signals diff --git a/provider/oauth2/backends.py b/provider/oauth2/backends.py index 3dee1517..52530ddf 100644 --- a/provider/oauth2/backends.py +++ b/provider/oauth2/backends.py @@ -1,11 +1,11 @@ import base64 from provider.utils import now -from provider.oauth2.forms import ClientAuthForm, PublicPasswordGrantForm, PublicClientForm +from provider.oauth2.forms import ClientAuthForm, PublicPasswordGrantForm, PublicClientForm, PkceClientAuthForm from provider.oauth2.models import AccessToken -class BaseBackend(object): +class BaseBackend: """ Base backend used to authenticate clients as defined in :rfc:`1` against our database. @@ -18,7 +18,7 @@ def authenticate(self, request=None): pass -class BasicClientBackend(object): +class BasicClientBackend: """ Backend that tries to authenticate a client through HTTP authorization headers as defined in :rfc:`2.3.1`. @@ -47,7 +47,7 @@ def authenticate(self, request=None): return None -class RequestParamsClientBackend(object): +class RequestParamsClientBackend: """ Backend that tries to authenticate a client through request parameters which might be in the request body or URI as defined in :rfc:`2.3.1`. @@ -68,7 +68,24 @@ def authenticate(self, request=None): return None -class PublicPasswordBackend(object): +class PkceRequestParamsClientBackend: + def authenticate(self, request=None): + if request is None: + return None + + if hasattr(request, 'REQUEST'): + args = request.REQUEST + else: + args = request.POST or request.GET + form = PkceClientAuthForm(args) + + if form.is_valid(): + return form.cleaned_data.get('client') + + return None + + +class PublicPasswordBackend: """ Backend that tries to authenticate a client using username, password and client ID. This is only available in specific circumstances: @@ -93,7 +110,7 @@ def authenticate(self, request=None): return None -class PublicClientBackend(object): +class PublicClientBackend: def authenticate(self, request=None): if request is None: return None @@ -110,7 +127,7 @@ def authenticate(self, request=None): return None -class AccessTokenBackend(object): +class AccessTokenBackend: """ Authenticate a user via access token and client object. """ diff --git a/provider/oauth2/fixtures/test_oauth2.json b/provider/oauth2/fixtures/test_oauth2.json index c8905acc..54abb86f 100644 --- a/provider/oauth2/fixtures/test_oauth2.json +++ b/provider/oauth2/fixtures/test_oauth2.json @@ -37,6 +37,38 @@ "model": "oauth2.client", "pk": 3 }, + { + "fields": { + "redirect_uri": "http://example.com/application/4/", + "client_id": "d6d7369c815d6d22e3dd", + "client_secret": "ebb3b29293b87b72306b71b2899672b76dc8735b", + "client_type": 2, + "url": "http://example.com/", + "user": 2, + "auto_authorize": true, + "allow_public_token": true, + "authorize_every_time": false, + "allow_plain_pkce": false + }, + "model": "oauth2.client", + "pk": 4 + }, + { + "fields": { + "redirect_uri": "http://example.com/application/5/", + "client_id": "3d3f0f3a05923de9d840", + "client_secret": "df3cb008dd5ec208be023472556906c91e50be60", + "client_type": 2, + "url": "http://example.com/", + "user": 2, + "auto_authorize": true, + "allow_public_token": true, + "authorize_every_time": false, + "allow_plain_pkce": true + }, + "model": "oauth2.client", + "pk": 5 + }, { "fields": { "date_joined": "2012-01-23 05:44:17", @@ -73,6 +105,24 @@ "model": "auth.user", "pk": 2 }, + { + "fields": { + "date_joined": "2012-01-23 05:53:31", + "email": "", + "first_name": "", + "groups": [], + "is_active": true, + "is_staff": false, + "is_superuser": false, + "last_login": "2012-01-23 05:53:31", + "last_name": "", + "password": "sha1$0cf1b$d66589690edd96b410170fcae5cc2bdfb68821e7", + "user_permissions": [], + "username": "test-user-aws" + }, + "model": "auth.user", + "pk": 3 + }, { "fields": { "name": "basic", @@ -88,5 +138,19 @@ }, "model": "oauth2.scope", "pk": 2 + }, + { + "fields": { + "arn": "arn:aws:iam::123456789012:role/testrole", + "account_id": "123456789012", + "name": "testrole", + "general_type": "role", + "client": 2, + "autoprovision_user": false, + "acting_user": 3, + "scope": ["basic", "advanced"] + }, + "model": "oauth2.awsaccount", + "pk": 1 } ] diff --git a/provider/oauth2/forms.py b/provider/oauth2/forms.py index f51a3c9f..5a82a985 100644 --- a/provider/oauth2/forms.py +++ b/provider/oauth2/forms.py @@ -1,14 +1,20 @@ -from six import string_types +import logging +from io import StringIO +from urllib import request +from urllib.error import HTTPError +from xml.etree import ElementTree + from django import forms from django.contrib.auth import authenticate from django.conf import settings -from django.utils.translation import ugettext as _ +from django.utils.translation import gettext as _ from django.utils import timezone -from provider.constants import RESPONSE_TYPE_CHOICES, SCOPES, PUBLIC +from provider.constants import RESPONSE_TYPE_CHOICES, PKCE, PUBLIC from provider.forms import OAuthForm, OAuthValidationError -from provider.utils import now +from provider.utils import now, ArnHelper from provider.oauth2.models import Client, Grant, RefreshToken, Scope +log = logging.getLogger('provider.oauth2') DEFAULT_SCOPE = getattr(settings, 'OAUTH2_DEFAULT_SCOPE', 'read') @@ -48,12 +54,40 @@ def clean(self): return data +class PkceClientAuthForm(forms.Form): + """ + Client authentication form. Required to make sure that we're dealing with a + real client. Form is used in :attr:`provider.oauth2.backends` to validate + the client. + """ + client_id = forms.CharField() + code_verifier = forms.CharField() + code = forms.CharField() + + def clean(self): + data = self.cleaned_data + try: + client = Client.objects.get(client_id=data.get('client_id'), client_type=PKCE) + grant = Grant.objects.get(client=client, code=data.get('code')) + if not grant.verify_code_challenge(data.get('code_verifier')): + raise forms.ValidationError(_("Invalid PKCE grant")) + + except Client.DoesNotExist: + raise forms.ValidationError(_("Client does not support PKCE")) + except Grant.DoesNotExist: + raise forms.ValidationError(_("Invalid PKCE grant")) + + data['client'] = client + return data + + + class ScopeModelChoiceField(forms.ModelMultipleChoiceField): # widget = forms.TextInput def to_python(self, value): - if isinstance(value, string_types): + if isinstance(value, str): return [s for s in value.split(' ') if s != ''] elif isinstance(value, list): value_list = list() @@ -151,6 +185,35 @@ def clean_redirect_uri(self): return redirect_uri +class AuthorizationPkceRequestForm(AuthorizationRequestForm): + code_challenge = forms.CharField(required=False) + code_challenge_method = forms.CharField(required=False) + + def clean_code_challenge(self): + code_challenge = self.cleaned_data.get('code_challenge') + if not code_challenge: + raise OAuthValidationError({ + 'error': 'invalid_request', + 'error_description': _("No 'code_challenge' supplied"), + }) + return code_challenge + + def clean_code_challenge_method(self): + method = self.cleaned_data.get('code_challenge_method') or 'plain' + if method not in ['plain', 'S256']: + raise OAuthValidationError({ + 'error': 'invalid_request', + 'error_description': f"{method} is not a supported code_challenge_method", + }) + if method == 'plain' and not self.client.allow_plain_pkce: + raise OAuthValidationError({ + 'error': 'invalid_request', + 'error_description': 'client does not allow code_challenge_method=plain', + }) + return method + + + class AuthorizationForm(ScopeModelMixin, OAuthForm): """ A form used to ask the resource owner for authorization of a given client. @@ -216,6 +279,7 @@ class AuthorizationCodeGrantForm(ScopeModelMixin, OAuthForm): """ code = forms.CharField(required=False) scope = ScopeModelChoiceField(queryset=Scope.objects.all(), required=False) + code_verifier = forms.CharField(required=False) def clean_code(self): code = self.cleaned_data.get('code') @@ -311,6 +375,46 @@ def clean(self): return data +class AwsGrantForm(OAuthForm): + grant_type = forms.CharField(required=True) + region = forms.CharField(required=True) + post_body = forms.CharField(required=True) + headers_json = forms.JSONField(required=True) + + def clean_grant_type(self): + grant_type = self.cleaned_data.get('grant_type') + + if grant_type != 'aws_identity': + raise OAuthValidationError({'error': 'invalid_grant'}) + + return grant_type + + def clean(self): + region = self.cleaned_data['region'] + + sts_url = f"https://sts.{region}.amazonaws.com/" + + post_body = self.cleaned_data['post_body'] + headers_json = self.cleaned_data['headers_json'] + + req = request.Request(sts_url, data=post_body.encode('utf-8'), headers=headers_json, method='POST') + try: + response = request.urlopen(req) + except HTTPError as e: + log.info("Error calling GetCallerIdentity for aws_identity grant: %s", e) + raise OAuthValidationError({'error': 'invalid_grant'}) + + xmldata = response.read() + + et = ElementTree.parse(StringIO(xmldata.decode('utf-8'))) + root = et.getroot() + result = root.find('{https://sts.amazonaws.com/doc/2011-06-15/}GetCallerIdentityResult') + caller_arn = result.find('{https://sts.amazonaws.com/doc/2011-06-15/}Arn').text + self.cleaned_data['arn_string'] = caller_arn + self.cleaned_data['arn'] = ArnHelper(caller_arn) + return self.cleaned_data + + class PublicClientForm(OAuthForm): client_id = forms.CharField(required=True) grant_type = forms.CharField(required=True) diff --git a/provider/oauth2/migrations/0004_awsaccount.py b/provider/oauth2/migrations/0004_awsaccount.py new file mode 100644 index 00000000..c1d50ccc --- /dev/null +++ b/provider/oauth2/migrations/0004_awsaccount.py @@ -0,0 +1,35 @@ +# Generated by Django 4.2 on 2024-08-07 19:03 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ('oauth2', '0003_public_client_options'), + ] + + operations = [ + migrations.CreateModel( + name='AwsAccount', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('arn', models.CharField(help_text='AWS User or Role ARN', max_length=255, unique=True)), + ('general_type', models.CharField(blank=True, max_length=15, null=True)), + ('account_id', models.CharField(blank=True, max_length=12, null=True)), + ('name', models.CharField(blank=True, max_length=255, null=True)), + ('autoprovision_user', models.BooleanField(default=True, help_text='Automatically create acting user on first use')), + ('acting_user', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.DO_NOTHING, to=settings.AUTH_USER_MODEL)), + ('max_token_lifetime', models.IntegerField(default=3600, blank=True, help_text="Maximum access token lifetime in seconds")), + ('client', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='oauth2.client')), + ('scope', models.ManyToManyField(help_text='Scopes to be applied to tokens', to='oauth2.scope')), + ], + options={ + 'db_table': 'oauth2_awsaccount', + 'unique_together': {('general_type', 'account_id', 'name')}, + }, + ), + ] diff --git a/provider/oauth2/migrations/0005_pkce.py b/provider/oauth2/migrations/0005_pkce.py new file mode 100644 index 00000000..3cc5ecec --- /dev/null +++ b/provider/oauth2/migrations/0005_pkce.py @@ -0,0 +1,38 @@ +# Generated by Django 4.2 on 2025-07-27 21:06 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('oauth2', '0004_awsaccount'), + ] + + operations = [ + migrations.AddField( + model_name='client', + name='allow_plain_pkce', + field=models.BooleanField(blank=True, default=False, help_text='Allow code_challenge_method=plain for PKCE'), + ), + migrations.AddField( + model_name='client', + name='token_expiry', + field=models.IntegerField(blank=True, help_text='Token expiration timeout. Defaults to OAUTH_EXPIRE_DELTA_PUBLIC or OAUTH_EXPIRE_DELTA', null=True), + ), + migrations.AddField( + model_name='grant', + name='code_challenge', + field=models.CharField(blank=True, max_length=255, null=True), + ), + migrations.AddField( + model_name='grant', + name='code_challenge_method', + field=models.CharField(blank=True, max_length=20, null=True), + ), + migrations.AlterField( + model_name='client', + name='client_type', + field=models.IntegerField(choices=[(0, 'Confidential (Web applications)'), (1, 'Public (Native and JS applications)'), (2, 'RFC7636 PKCE (Native, JS, and Web applications)')]), + ), + ] diff --git a/provider/oauth2/models.py b/provider/oauth2/models.py index f782ff4d..7d4ffc16 100644 --- a/provider/oauth2/models.py +++ b/provider/oauth2/models.py @@ -4,8 +4,12 @@ views in :attr:`provider.views`. """ +from base64 import urlsafe_b64encode +from hashlib import sha256 + from django.db import models from django.conf import settings +from django.contrib.auth import get_user_model from provider import constants from provider.constants import CLIENT_TYPES from provider.utils import now, short_token, long_token, get_code_expiry @@ -42,13 +46,17 @@ class Client(models.Model): authorize_every_time = models.BooleanField(default=False, blank=True) allow_public_token = models.BooleanField(default=False, blank=True, help_text="Allow public client tokens with only client_id and code") + allow_plain_pkce = models.BooleanField(default=False, blank=True, + help_text="Allow code_challenge_method=plain for PKCE") + token_expiry = models.IntegerField(blank=True, null=True, + help_text="Token expiration timeout. Defaults to OAUTH_EXPIRE_DELTA_PUBLIC or OAUTH_EXPIRE_DELTA") - def __unicode__(self): + def __str__(self): return self.redirect_uri def get_default_token_expiry(self): public = (self.client_type == constants.PUBLIC) - return get_token_expiry(public) + return self.token_expiry or get_token_expiry(public) class Meta: app_label = 'oauth2' @@ -59,7 +67,7 @@ class Scope(models.Model): name = models.CharField(max_length=50, primary_key=True) description = models.CharField(max_length=256, default='', blank=True) - def __unicode__(self): + def __str__(self): return self.name class Meta: @@ -129,14 +137,26 @@ class Grant(models.Model): expires = models.DateTimeField(default=get_code_expiry) redirect_uri = models.CharField(max_length=255, blank=True) scope = models.ManyToManyField('Scope') + code_challenge = models.CharField(max_length=255, blank=True, null=True) + code_challenge_method = models.CharField(max_length=20, blank=True, null=True) - def __unicode__(self): + def __str__(self): return self.code class Meta: app_label = 'oauth2' db_table = 'oauth2_grant' + def verify_code_challenge(self, code_verifier): + if not code_verifier: + return False + if self.code_challenge_method == 'plain': + return code_verifier == self.code_challenge + else: + verifier_digest = sha256(code_verifier.encode('ASCII')).digest() + expected = urlsafe_b64encode(verifier_digest).decode() + return expected == self.code_challenge + class AccessTokenManager(models.Manager): def get_token(self, token): @@ -188,7 +208,7 @@ class AccessToken(models.Model): objects = AccessTokenManager() - def __unicode__(self): + def __str__(self): return self.token def save(self, *args, **kwargs): @@ -258,9 +278,40 @@ class RefreshToken(models.Model): objects = RefreshTokenManager() - def __unicode__(self): + def __str__(self): return self.token class Meta: app_label = 'oauth2' db_table = 'oauth2_refreshtoken' + + +class AwsAccount(models.Model): + arn = models.CharField(max_length=255, unique=True, help_text="AWS User or Role ARN") + general_type = models.CharField(max_length=15, blank=True, null=True) + account_id = models.CharField(max_length=12, blank=True, null=True) + name = models.CharField(max_length=255, blank=True, null=True) + + client = models.ForeignKey('Client', models.DO_NOTHING) + autoprovision_user = models.BooleanField(default=True, help_text="Automatically create acting user on first use") + acting_user = models.ForeignKey(settings.AUTH_USER_MODEL, models.DO_NOTHING, blank=True, null=True) + max_token_lifetime = models.IntegerField(default=3600, blank=True, help_text="Maximum access token lifetime in seconds") + scope = models.ManyToManyField("Scope", help_text="Scopes to be applied to tokens") + + class Meta: + app_label = 'oauth2' + db_table = 'oauth2_awsaccount' + unique_together = ( + ('general_type', 'account_id', 'name'), + ) + + def get_or_create_user(self): + if self.acting_user is not None: + return self.acting_user + + if self.autoprovision_user: + username = f"{self.name}_{self.general_type}_{self.account_id}" + User = get_user_model() + self.acting_user, _ = User.objects.get_or_create(username=username) + self.save() + return self.acting_user diff --git a/provider/oauth2/signals.py b/provider/oauth2/signals.py new file mode 100644 index 00000000..fe975f17 --- /dev/null +++ b/provider/oauth2/signals.py @@ -0,0 +1,19 @@ +from django.contrib.auth.models import User +from django.db.models.signals import pre_save +from django.dispatch import receiver + +from provider.utils import ArnHelper +from provider.oauth2.models import AwsAccount + + +@receiver(pre_save, sender=AwsAccount) +def awsaccount_pre_save(sender, instance, **kwargs): + arn = ArnHelper(instance.arn) + if instance.general_type != arn.general_type: + instance.general_type = arn.general_type + + if instance.name != arn.name: + instance.name = arn.name + + if instance.account_id != arn.account_id: + instance.account_id = arn.account_id diff --git a/provider/oauth2/tests/test_middleware.py b/provider/oauth2/tests/test_middleware.py index e3509e84..eace624c 100644 --- a/provider/oauth2/tests/test_middleware.py +++ b/provider/oauth2/tests/test_middleware.py @@ -1,5 +1,5 @@ import json -from six.moves.urllib_parse import urlparse +from urllib.parse import urlparse from django.shortcuts import reverse from django.http import QueryDict diff --git a/provider/oauth2/tests/test_models.py b/provider/oauth2/tests/test_models.py new file mode 100644 index 00000000..7d571672 --- /dev/null +++ b/provider/oauth2/tests/test_models.py @@ -0,0 +1,26 @@ + +from django.test import TestCase + +from provider.oauth2.models import Client, AwsAccount + + +class ModelTests(TestCase): + fixtures = ['test_oauth2'] + + def test_aws_account(self): + client = Client.objects.get(id=2) + + account = AwsAccount.objects.create( + client=client, + arn="arn:aws:iam::123456789012:user/imauser" + ) + + self.assertEqual(account.account_id, "123456789012") + self.assertEqual(account.name, "imauser") + self.assertEqual(account.general_type, "user") + + new_account = AwsAccount.objects.get(pk=account.pk) + + self.assertEqual(new_account.account_id, "123456789012") + self.assertEqual(new_account.name, "imauser") + self.assertEqual(new_account.general_type, "user") diff --git a/provider/oauth2/tests/test_views.py b/provider/oauth2/tests/test_views.py index 78da3b26..9d968aec 100644 --- a/provider/oauth2/tests/test_views.py +++ b/provider/oauth2/tests/test_views.py @@ -1,9 +1,9 @@ import base64 import json import datetime -from six.moves.urllib_parse import urlparse, parse_qs, quote +from urllib.parse import urlparse, parse_qs, quote -from unittest import SkipTest +from unittest.mock import patch from django.http import QueryDict from django.conf import settings from django.shortcuts import reverse @@ -15,7 +15,7 @@ from provider.templatetags.scope import scopes from provider.utils import now as date_now from provider.oauth2.forms import ClientForm -from provider.oauth2.models import Client, Grant, AccessToken, RefreshToken, AuthorizedClient +from provider.oauth2.models import Client, Grant, AccessToken, RefreshToken, AuthorizedClient, AwsAccount from provider.oauth2.backends import BasicClientBackend, RequestParamsClientBackend from provider.oauth2.backends import AccessTokenBackend @@ -43,6 +43,12 @@ def get_client(self): def get_public_client(self): return Client.objects.get(id=3) + def get_pkce_client(self): + return Client.objects.get(id=4) + + def get_pkce_plain_client(self): + return Client.objects.get(id=5) + def get_grant(self): return Grant.objects.all()[0] @@ -52,6 +58,9 @@ def get_user(self): def get_password(self): return 'test' + def get_aws_role(self): + return AwsAccount.objects.get(id=1) + def _login_and_authorize(self, url_func=None): if url_func is None: def url_func(): @@ -203,6 +212,57 @@ def test_preserving_the_state_variable(self): self.assertTrue('code' in response['Location']) self.assertTrue('state=abc' in response['Location']) + def test_pkce_authorization_is_granted(self): + code_verifier = '8d6e7d6d3375bc536fb276ca1e3feebba6f29dd4' + code_challenge = 'cvbVJuj0KkRIpHMCkBOsZojlbqMhu-H7jkXx9U6E_3Y=' + + self.login() + def url_func(): + return self.auth_url() + '?client_id={}&response_type=code&state=abc&code_challenge={}&code_challenge_method=S256'.format( + self.get_pkce_client().client_id, code_challenge, + ) + + self._login_and_authorize(url_func) + + response = self.client.get(self.redirect_url()) + self.assertEqual(302, response.status_code) + self.assertFalse('error' in response['Location']) + self.assertTrue('code' in response['Location']) + + def test_pkce_authorization_plain_prohibited(self): + code_verifier = '8d6e7d6d3375bc536fb276ca1e3feebba6f29dd4' + code_challenge = 'cvbVJuj0KkRIpHMCkBOsZojlbqMhu-H7jkXx9U6E_3Y=' + + self.login() + def url_func(): + return self.auth_url() + '?client_id={}&response_type=code&state=abc&code_challenge={}'.format( + self.get_pkce_client().client_id, code_verifier, + ) + + response = self.client.get(url_func()) + response = self.client.get(self.auth_url2()) + response = self.client.post(self.auth_url2(), {'authorize': True, 'scope': 'read'}) + + self.assertEqual(400, response.status_code) + self.assertIn('client does not allow code_challenge_method=plain', response.content.decode()) + + def test_pkce_plain_authorization_is_granted(self): + code_verifier = '8d6e7d6d3375bc536fb276ca1e3feebba6f29dd4' + code_challenge = 'cvbVJuj0KkRIpHMCkBOsZojlbqMhu-H7jkXx9U6E_3Y=' + + self.login() + def url_func(): + return self.auth_url() + '?client_id={}&response_type=code&state=abc&code_challenge={}'.format( + self.get_pkce_plain_client().client_id, code_verifier, + ) + + self._login_and_authorize(url_func) + + response = self.client.get(self.redirect_url()) + self.assertEqual(302, response.status_code) + self.assertFalse('error' in response['Location']) + self.assertTrue('code' in response['Location']) + # # FIXME: Not sure what the error condition is that should exist here. # def test_redirect_requires_valid_data(self): # self.login() @@ -508,6 +568,177 @@ def test_access_token_response_valid_token_type(self): token = self._login_authorize_get_token() self.assertEqual(token['token_type'], constants.TOKEN_TYPE, token) + @patch('urllib.request.urlopen') + def test_aws_grant_invalid_caller_identity(self, urlopen): + headers = { + "header1": "a", + "header2": "b", + } + post_body = "mypostbody" + + caller_identity_result = """ + + + arn:aws:iam::123456789012:user/myuser + AIDA27 + 123456789012 + + + 00000000-3558-43b5-8157-07d0769322b5 + + """.strip("\n ").encode('utf-8') + + urlopen.return_value.read.return_value = caller_identity_result + + urlopen.return_value.code = 200 + + response = self.client.post(self.access_token_url(), { + 'grant_type': 'aws_identity', + 'region': "us-west-2", + 'post_body': post_body, + 'headers_json': json.dumps(headers), + }) + + self.assertEqual(400, response.status_code) + self.assertEqual('not_authorized', json.loads(response.content), + response.content) + + @patch('urllib.request.urlopen') + def test_aws_grant_valid_caller_identity(self, urlopen): + headers = { + "header1": "a", + "header2": "b", + } + post_body = "mypostbody" + + caller_identity_result = """ + + + arn:aws:iam::123456789012:assumed-role/testrole/testsession + AIDA27 + 123456789012 + + + 00000000-3558-43b5-8157-07d0769322b5 + + """.strip("\n ").encode('utf-8') + + urlopen.return_value.read.return_value = caller_identity_result + + urlopen.return_value.code = 200 + + response = self.client.post(self.access_token_url(), { + 'grant_type': 'aws_identity', + 'region': "us-west-2", + 'post_body': post_body, + 'headers_json': json.dumps(headers), + }) + + self.assertEqual(200, response.status_code) + self.assertNotIn('refresh_token', json.loads(response.content)) + + def test_access_token_pkce_client(self): + required_props = ['access_token', 'token_type', 'refresh_token'] + + code_verifier = '8d6e7d6d3375bc536fb276ca1e3feebba6f29dd4' + code_challenge = 'cvbVJuj0KkRIpHMCkBOsZojlbqMhu-H7jkXx9U6E_3Y=' + + self.login() + def url_func(): + return self.auth_url() + '?client_id={}&response_type=code&state=abc&code_challenge={}&code_challenge_method=S256'.format( + self.get_pkce_client().client_id, code_challenge, + ) + + self._login_and_authorize(url_func=url_func) + + response = self.client.get(self.redirect_url()) + query = QueryDict(urlparse(response['Location']).query) + code = query['code'] + + response = self.client.post(self.access_token_url(), { + 'grant_type': 'authorization_code', + 'client_id': self.get_pkce_client().client_id, + 'code_verifier': code_verifier, + 'code': code}) + + self.assertEqual(200, response.status_code, response.content) + + token = json.loads(response.content) + + for prop in required_props: + self.assertIn(prop, token, "Access token response missing " + "required property: %s" % prop) + + return token + + def test_access_token_pkce_plain_client(self): + required_props = ['access_token', 'token_type'] + + code_verifier = '8d6e7d6d3375bc536fb276ca1e3feebba6f29dd4' + code_challenge = 'cvbVJuj0KkRIpHMCkBOsZojlbqMhu-H7jkXx9U6E_3Y=' + + self.login() + def url_func(): + return self.auth_url() + '?client_id={}&response_type=code&state=abc&code_challenge={}'.format( + self.get_pkce_plain_client().client_id, code_verifier, + ) + + self._login_and_authorize(url_func=url_func) + + response = self.client.get(self.redirect_url()) + query = QueryDict(urlparse(response['Location']).query) + code = query['code'] + + response = self.client.post(self.access_token_url(), { + 'grant_type': 'authorization_code', + 'client_id': self.get_pkce_plain_client().client_id, + 'code_verifier': code_verifier, + 'code': code}) + + self.assertEqual(200, response.status_code, response.content) + + token = json.loads(response.content) + + for prop in required_props: + self.assertIn(prop, token, "Access token response missing " + "required property: %s" % prop) + + return token + + def test_access_token_pkce_plain_client_with_S256(self): + required_props = ['access_token', 'token_type'] + + code_verifier = '8d6e7d6d3375bc536fb276ca1e3feebba6f29dd4' + code_challenge = 'cvbVJuj0KkRIpHMCkBOsZojlbqMhu-H7jkXx9U6E_3Y=' + + self.login() + def url_func(): + return self.auth_url() + '?client_id={}&response_type=code&state=abc&code_challenge={}&code_challenge_method=S256'.format( + self.get_pkce_plain_client().client_id, code_challenge, + ) + + self._login_and_authorize(url_func=url_func) + + response = self.client.get(self.redirect_url()) + query = QueryDict(urlparse(response['Location']).query) + code = query['code'] + + response = self.client.post(self.access_token_url(), { + 'grant_type': 'authorization_code', + 'client_id': self.get_pkce_plain_client().client_id, + 'code_verifier': code_verifier, + 'code': code}) + + self.assertEqual(200, response.status_code, response.content) + + token = json.loads(response.content) + + for prop in required_props: + self.assertIn(prop, token, "Access token response missing " + "required property: %s" % prop) + + return token + class AuthBackendTest(BaseOAuth2TestCase): fixtures = ['test_oauth2'] @@ -651,7 +882,7 @@ def test_clear_expired(self): 'client_id': self.get_client().client_id, 'client_secret': self.get_client().client_secret, 'code': code}) - self.assertEquals(200, response.status_code) + self.assertEqual(200, response.status_code) token = json.loads(response.content) self.assertTrue('access_token' in token) access_token = token['access_token'] @@ -676,9 +907,9 @@ def test_clear_expired(self): self.assertEqual(200, response.status_code) token = json.loads(response.content) self.assertTrue('access_token' in token) - self.assertNotEquals(access_token, token['access_token']) + self.assertNotEqual(access_token, token['access_token']) self.assertTrue('refresh_token' in token) - self.assertNotEquals(refresh_token, token['refresh_token']) + self.assertNotEqual(refresh_token, token['refresh_token']) # make sure the orig AccessToken and RefreshToken are gone self.assertFalse(AccessToken.objects.filter(token=access_token) diff --git a/provider/oauth2/tests/urls.py b/provider/oauth2/tests/urls.py index 0eefd116..445ab379 100644 --- a/provider/oauth2/tests/urls.py +++ b/provider/oauth2/tests/urls.py @@ -1,4 +1,4 @@ -from django.conf.urls import url +from django.urls import path from django.http.response import JsonResponse from django.views.generic import View from django.contrib.auth.mixins import LoginRequiredMixin @@ -37,6 +37,6 @@ def get(self, request, *args, **kwargs): urlpatterns = [ - url('^badscope$', BadScopeView.as_view(), name='badscope'), - url('^user/(?P\d+)$', UserView.as_view(), name='user'), + path('badscope', BadScopeView.as_view(), name='badscope'), + path('user/', UserView.as_view(), name='user'), ] diff --git a/provider/oauth2/urls.py b/provider/oauth2/urls.py index 43abcc63..2a759219 100644 --- a/provider/oauth2/urls.py +++ b/provider/oauth2/urls.py @@ -35,22 +35,22 @@ from django.contrib.auth.decorators import login_required from django.views.decorators.csrf import csrf_exempt -from django.conf.urls import url, include +from django.urls import path from provider.oauth2 import views app_name = 'oauth2' urlpatterns = [ - url('^authorize/?$', + path('authorize', login_required(views.CaptureView.as_view()), name='capture'), - url('^authorize/confirm/?$', + path('authorize/confirm', login_required(views.AuthorizeView.as_view()), name='authorize'), - url('^redirect/?$', + path('redirect', login_required(views.RedirectView.as_view()), name='redirect'), - url('^access_token/?$', + path('access_token', csrf_exempt(views.AccessTokenView.as_view()), name='access_token'), ] diff --git a/provider/oauth2/views.py b/provider/oauth2/views.py index ccc9db86..ed5e028e 100644 --- a/provider/oauth2/views.py +++ b/provider/oauth2/views.py @@ -1,18 +1,107 @@ +import json +import logging from datetime import timedelta +from urllib.parse import urlparse, ParseResult + +from django.http import HttpResponse +from django.http import HttpResponseRedirect, QueryDict from django.shortcuts import reverse +from django.views.generic import TemplateView, View +from django.core.exceptions import ObjectDoesNotExist +from django.utils.translation import gettext as _ + from provider import constants -from provider.views import CaptureViewBase, AuthorizeViewBase, RedirectViewBase -from provider.views import AccessTokenViewBase, OAuthError -from provider.utils import now +from provider.utils import now, ArnHelper from provider.oauth2 import forms from provider.oauth2 import models from provider.oauth2 import backends -class CaptureView(CaptureViewBase): +log = logging.getLogger('provider.oauth2') + + +class OAuthError(Exception): """ - Implementation of :class:`provider.views.Capture`. + Exception to throw inside any views defined in :attr:`provider.views`. + + Any :attr:`OAuthError` thrown will be signalled to the API consumer. + + :attr:`OAuthError` expects a dictionary as its first argument outlining the + type of error that occured. + + :example: + + :: + + raise OAuthError({'error': 'invalid_request'}) + + The different types of errors are outlined in :rfc:`4.2.2.1` and + :rfc:`5.2`. + + """ + + +class AuthUtilMixin(object): + """ + Mixin providing common methods required in the OAuth view defined in + :attr:`provider.views`. """ + authentication = () + + def get_data(self, request, key='params'): + """ + Return stored data from the session store. + + :param key: `str` The key under which the data was stored. + """ + return request.session.get('%s:%s' % (constants.SESSION_KEY, key)) + + def cache_data(self, request, data, key='params'): + """ + Cache data in the session store. + + :param request: :attr:`django.http.HttpRequest` + :param data: Arbitrary data to store. + :param key: `str` The key under which to store the data. + """ + request.session['%s:%s' % (constants.SESSION_KEY, key)] = data + + def clear_data(self, request): + """ + Clear all OAuth related data from the session store. + """ + for key in list(request.session.keys()): + if key.startswith(constants.SESSION_KEY): + del request.session[key] + + def authenticate(self, request): + """ + Authenticate a client against all the backends configured in + :attr:`authentication`. + """ + for backend in self.authentication: + client = backend().authenticate(request) + if client is not None: + return client + return None + + +class CaptureView(AuthUtilMixin, TemplateView): + """ + As stated in section :rfc:`3.1.2.5` this view captures all the request + parameters and redirects to another URL to avoid any leakage of request + parameters to potentially harmful JavaScripts. + + This application assumes that whatever web-server is used as front-end will + handle SSL transport. + + If you want strict enforcement of secure communication at application + level, set :attr:`settings.OAUTH_ENFORCE_SECURE` to ``True``. + + """ + template_name = 'provider/authorize.html' + + def validate_scopes(self, scope_list): scopes = {s.name for s in models.Scope.objects.filter(name__in=scope_list)} @@ -21,12 +110,49 @@ def validate_scopes(self, scope_list): def get_redirect_url(self, request): return reverse('oauth2:authorize') + def handle(self, request, data): + self.cache_data(request, data) + + if constants.ENFORCE_SECURE and not request.is_secure(): + return self.render_to_response({'error': 'access_denied', + 'error_description': _("A secure connection is required."), + 'next': None}, + status=400) + + scope_list = [s for s in + data.get('scope', '').split(' ') if s != ''] + if self.validate_scopes(scope_list): + return HttpResponseRedirect(self.get_redirect_url(request)) + else: + return HttpResponse("Invalid scope.", status=400) -class AuthorizeView(AuthorizeViewBase): + def get(self, request, *args, **kwargs): + return self.handle(request, request.GET) + + def post(self, request, *args, **kwargs): + return self.handle(request, request.POST) + + +class AuthorizeView(AuthUtilMixin, TemplateView): """ - Implementation of :class:`provider.views.Authorize`. + View to handle the client authorization as outlined in :rfc:`4`. + + :attr:`Authorize` renders the ``provider/authorize.html`` template to + display the authorization form. + + On successful authorization, it redirects the user back to the defined + client callback as defined in :rfc:`4.1.2`. + + On authorization fail :attr:`Authorize` displays an error message to the + user with a modified redirect URL to the callback including the error + and possibly description of the error as defined in :rfc:`4.1.2.1`. """ + + template_name = 'provider/authorize.html' + def get_request_form(self, client, data): + if client.client_type == constants.PKCE: + return forms.AuthorizationPkceRequestForm(data, client=client) return forms.AuthorizationRequestForm(data, client=client) def get_authorization_form(self, request, client, data, client_data): @@ -62,26 +188,172 @@ def save_authorization(self, request, client, form, client_data): grant = form.save(user=request.user, client=client, - redirect_uri=client_data.get('redirect_uri', '')) + redirect_uri=client_data.get('redirect_uri', ''), + code_challenge=client_data.get('code_challenge'), + code_challenge_method=client_data.get('code_challenge_method')) if grant is None: return None - grant.user = request.user - grant.client = client - grant.redirect_uri = client_data.get('redirect_uri', '') - grant.save() return grant.code + def _validate_client(self, request, data): + """ + :return: ``tuple`` - ``(client or False, data or error)`` + """ + client = self.get_client(data.get('client_id')) + + if client is None: + raise OAuthError({ + 'error': 'unauthorized_client', + 'error_description': _("An unauthorized client tried to access" + " your resources.") + }) + + form = self.get_request_form(client, data) + + if not form.is_valid(): + raise OAuthError(form.errors) + + return client, form.cleaned_data + + def error_response(self, request, error, **kwargs): + """ + Return an error to be displayed to the resource owner if anything goes + awry. Errors can include invalid clients, authorization denials and + other edge cases such as a wrong ``redirect_uri`` in the authorization + request. + + :param request: :attr:`django.http.HttpRequest` + :param error: ``dict`` + The different types of errors are outlined in :rfc:`4.2.2.1` + """ + ctx = {} + ctx.update(error) + + # If we got a malicious redirect_uri or client_id, remove all the + # cached data and tell the resource owner. We will *not* redirect back + # to the URL. + + if error.get('error') in ['redirect_uri', 'unauthorized_client']: + ctx.update(next='/') + return self.render_to_response(ctx, **kwargs) + + ctx.update(next=self.get_redirect_url(request)) + + return self.render_to_response(ctx, **kwargs) + + def handle(self, request, post_data=None): + data = self.get_data(request) + + if data is None: + return self.error_response(request, { + 'error': 'expired_authorization', + 'error_description': _('Authorization session has expired.')}) + + try: + client, data = self._validate_client(request, data) + except OAuthError as e: + return self.error_response(request, e.args[0], status=400) + + scope_list = [s.name for s in + data.get('scope', [])] + if self.has_authorization(request, client, scope_list): + post_data = { + 'scope': scope_list, + 'authorize': u'Authorize', + } + + authorization_form = self.get_authorization_form(request, client, + post_data, data) + + if not authorization_form.is_bound or not authorization_form.is_valid(): + return self.render_to_response({ + 'client': client, + 'form': authorization_form, + 'oauth_data': data, + }) + + code = self.save_authorization(request, client, + authorization_form, data) + + # be sure to serialize any objects that aren't natively json + # serializable because these values are stored as session data + data['scope'] = scope_list + data['client_id'] = client.client_id # Add this back in, it gets lost sometimes + self.cache_data(request, data) + self.cache_data(request, code, "code") + self.cache_data(request, client.pk, "client_pk") + + return HttpResponseRedirect(self.get_redirect_url(request)) + + def get(self, request, *args, **kwargs): + return self.handle(request, None) + + def post(self, request, *args, **kwargs): + return self.handle(request, request.POST) -class RedirectView(RedirectViewBase): + + + +class RedirectView(AuthUtilMixin, View): """ - Implementation of :class:`provider.views.Redirect` + Redirect the user back to the client with the right query parameters set. + This can be either parameters indicating success or parameters indicating + an error. """ - pass + def error_response(self, error, mimetype='application/json', status=400, + **kwargs): + """ + Return an error response to the client with default status code of + *400* stating the error as outlined in :rfc:`5.2`. + """ + return HttpResponse(json.dumps(error), content_type=mimetype, + status=status, **kwargs) + + def get(self, request): + data = self.get_data(request) + code = self.get_data(request, "code") + error = self.get_data(request, "error") + client_pk = self.get_data(request, "client_pk") + + client = models.Client.objects.get(pk=client_pk) + + # this is an edge case that is caused by making a request with no data + # it should only happen if this view is called manually, out of the + # normal capture-authorize-redirect flow. + if data is None or client is None: + return self.error_response({ + 'error': 'invalid_data', + 'error_description': _('Data has not been captured')}) + + redirect_uri = data.get('redirect_uri', None) or client.redirect_uri + + parsed = urlparse(redirect_uri) + + query = QueryDict('', mutable=True) + + if 'state' in data: + query['state'] = data['state'] + + if error is not None: + query.update(error) + elif code is None: + query['error'] = 'access_denied' + else: + query['code'] = code + + parsed = parsed[:4] + (query.urlencode(), '') + + redirect_uri = ParseResult(*parsed).geturl() + + self.clear_data(request) + + return HttpResponseRedirect(redirect_uri) -class AccessTokenView(AccessTokenViewBase): + +class AccessTokenView(AuthUtilMixin, TemplateView): """ Implementation of :class:`provider.views.AccessToken`. @@ -89,14 +361,31 @@ class AccessTokenView(AccessTokenViewBase): in :attr:`provider.views.AccessToken.grant_types`. If you wish to disable any, you can override the :meth:`get_handler` method *or* the :attr:`grant_types` list. + + + According to :rfc:`4.4.2` this endpoint too must support secure + communication. For strict enforcement of secure communication at + application level set :attr:`settings.OAUTH_ENFORCE_SECURE` to ``True``. + + According to :rfc:`3.2` we can only accept POST requests. + + Returns with a status code of *400* in case of errors. *200* in case of + success. + """ authentication = ( backends.BasicClientBackend, backends.RequestParamsClientBackend, + backends.PkceRequestParamsClientBackend, backends.PublicPasswordBackend, backends.PublicClientBackend, ) + grant_types = ['authorization_code', 'refresh_token', 'password', 'aws_identity'] + """ + The default grant types supported by this view. + """ + def get_authorization_code_grant(self, request, data, client): form = forms.AuthorizationCodeGrantForm(data, client=client) if not form.is_valid(): @@ -115,6 +404,25 @@ def get_password_grant(self, request, data, client): raise OAuthError(form.errors) return form.cleaned_data + def get_aws_grant(self, request, data, _client): + form = forms.AwsGrantForm(data) + if not form.is_valid(): + raise OAuthError(form.errors) + data = form.cleaned_data + arn = data.get('arn') + try: + account = models.AwsAccount.objects.get( + account_id=arn.account_id, + general_type=arn.general_type, + name=arn.name, + ) + except models.AwsAccount.DoesNotExist: + log.info("No AwsAccount found for arn '%s'", arn.arn) + raise OAuthError("not_authorized") + + data['awsaccount'] = account + return data + def get_access_token(self, request, user, scope, client): try: # Attempt to fetch an existing access token. @@ -162,3 +470,163 @@ def invalidate_access_token(self, at): else: at.expires = now() - timedelta(days=1) at.save() + + def error_response(self, error, mimetype='application/json', status=400, + **kwargs): + """ + Return an error response to the client with default status code of + *400* stating the error as outlined in :rfc:`5.2`. + """ + return HttpResponse(json.dumps(error), content_type=mimetype, + status=status, **kwargs) + + def access_token_response(self, access_token): + """ + Returns a successful response after creating the access token + as defined in :rfc:`5.1`. + """ + + response_data = { + 'access_token': access_token.token, + 'token_type': constants.TOKEN_TYPE, + 'expires_in': access_token.get_expire_delta(), + 'scope': access_token.get_scope_string(), + } + + # Not all access_tokens are given a refresh_token + # (for example, public clients doing password auth) + try: + rt = access_token.refresh_token + response_data['refresh_token'] = rt.token + except ObjectDoesNotExist: + pass + + return HttpResponse( + json.dumps(response_data), content_type='application/json' + ) + + def authorization_code(self, request, data, client): + """ + Handle ``grant_type=authorization_code`` requests as defined in + :rfc:`4.1.3`. + """ + grant = self.get_authorization_code_grant(request, request.POST, + client) + at = self.create_access_token(request, grant.user, + list(grant.scope.all()), client) + + suppress_refresh_token = False + if client.client_type == constants.PUBLIC and client.allow_public_token: + if not request.POST.get('client_secret'): + suppress_refresh_token = True + + if not suppress_refresh_token: + rt = self.create_refresh_token(request, grant.user, + list(grant.scope.all()), at, client) + + self.invalidate_grant(grant) + + return self.access_token_response(at) + + def refresh_token(self, request, data, client): + """ + Handle ``grant_type=refresh_token`` requests as defined in :rfc:`6`. + """ + rt = self.get_refresh_token_grant(request, data, client) + + token_scope = list(rt.access_token.scope.all()) + + # this must be called first in case we need to purge expired tokens + self.invalidate_refresh_token(rt) + self.invalidate_access_token(rt.access_token) + + at = self.create_access_token(request, rt.user, + token_scope, + client) + rt = self.create_refresh_token(request, at.user, + at.scope.all(), at, client) + + return self.access_token_response(at) + + def password(self, request, data, client): + """ + Handle ``grant_type=password`` requests as defined in :rfc:`4.3`. + """ + + data = self.get_password_grant(request, data, client) + user = data.get('user') + scope = data.get('scope') + + at = self.create_access_token(request, user, scope, client) + # Public clients don't get refresh tokens + if client.client_type != constants.PUBLIC: + rt = self.create_refresh_token(request, user, scope, at, client) + + return self.access_token_response(at) + + def aws_identity(self, request, data, client): + data = self.get_aws_grant(request, data, client) + account = data.get('awsaccount') + scope = list(account.scope.all()) + + at = self.create_access_token(request, account.get_or_create_user(), scope, account.client) + at.expires = now() + timedelta(seconds=account.max_token_lifetime) + at.save() + return self.access_token_response(at) + + def get_handler(self, grant_type): + """ + Return a function or method that is capable handling the ``grant_type`` + requested by the client or return ``None`` to indicate that this type + of grant type is not supported, resulting in an error response. + """ + if grant_type == 'authorization_code': + return self.authorization_code + elif grant_type == 'refresh_token': + return self.refresh_token + elif grant_type == 'password': + return self.password + elif grant_type == 'aws_identity': + return self.aws_identity + return None + + def get(self, request, *args, **kwargs): + """ + As per :rfc:`3.2` the token endpoint *only* supports POST requests. + Returns an error response. + """ + return self.error_response({ + 'error': 'invalid_request', + 'error_description': _("Only POST requests allowed.")}) + + def post(self, request): + """ + As per :rfc:`3.2` the token endpoint *only* supports POST requests. + """ + if constants.ENFORCE_SECURE and not request.is_secure(): + return self.error_response({ + 'error': 'invalid_request', + 'error_description': _("A secure connection is required.")}) + + if not 'grant_type' in request.POST: + return self.error_response({ + 'error': 'invalid_request', + 'error_description': _("No 'grant_type' included in the " + "request.")}) + + grant_type = request.POST['grant_type'] + + if grant_type not in self.grant_types: + return self.error_response({'error': 'unsupported_grant_type'}) + + client = self.authenticate(request) + + if client is None and grant_type != 'aws_identity': + return self.error_response({'error': 'invalid_client'}) + + handler = self.get_handler(grant_type) + + try: + return handler(request, request.POST, client) + except OAuthError as e: + return self.error_response(e.args[0]) diff --git a/provider/tests/test_utils.py b/provider/tests/test_utils.py index da1ae28b..8de4527b 100644 --- a/provider/tests/test_utils.py +++ b/provider/tests/test_utils.py @@ -4,6 +4,66 @@ from django.test import TestCase +from provider.utils import ArnHelper, BadArn + class UtilsTestCase(TestCase): - pass + def test_arn_user_helper(self): + user_arn = "arn:aws:iam::123456789012:user/imauser" + + arn = ArnHelper(user_arn) + self.assertEqual(arn.account_id, "123456789012") + self.assertEqual(arn.type, "user") + self.assertEqual(arn.name, "imauser") + + def test_arn_user_equality(self): + user_arn = "arn:aws:iam::123456789012:user/imauser" + + arn = ArnHelper(user_arn) + + caller_identity_arn = ArnHelper("arn:aws:iam::123456789012:user/imauser") + + self.assertEqual(arn, caller_identity_arn) + + def test_arn_role_helper(self): + role_arn = "arn:aws:iam::123456789012:role/my-ec2-role" + arn = ArnHelper(role_arn) + + self.assertEqual(arn.account_id, "123456789012") + self.assertEqual(arn.type, "role") + self.assertEqual(arn.name, "my-ec2-role") + + def test_arn_role_caller_identity_helper(self): + role_arn = "arn:aws:sts::123456789012:assumed-role/my-ec2-role/sessionidentifier" + arn = ArnHelper(role_arn) + self.assertEqual(arn.account_id, "123456789012") + self.assertEqual(arn.type, "assumed-role") + self.assertEqual(arn.general_type, "role") + self.assertEqual(arn.name, "my-ec2-role") + self.assertEqual(arn.session, "sessionidentifier") + + def test_arn_role_equality(self): + role_arn = "arn:aws:iam::123456789012:role/my-ec2-role" + arn = ArnHelper(role_arn) + + caller_identity_arn = ArnHelper( + "arn:aws:sts::123456789012:assumed-role/my-ec2-role/sessionidentifier" + ) + + self.assertEqual(arn, caller_identity_arn) + + def test_invalid_arn_too_long(self): + with self.assertRaises(BadArn): + ArnHelper("arn:aws:iam::123456789012:role/my-ec2-role:invalidextra") + + def test_invalid_arn_too_short(self): + with self.assertRaises(BadArn): + ArnHelper("arn:aws:iam::123456789012") + + def test_invalid_arn_bad_prefix(self): + with self.assertRaises(BadArn): + ArnHelper("notarn:aws:iam::123456789012:role/my-ec2-role") + + def test_invalid_arn_bad_service(self): + with self.assertRaises(BadArn): + ArnHelper("arn:aws:s3::123456789012:role/my-ec2-role") diff --git a/provider/utils.py b/provider/utils.py index 25894901..3e5011b0 100644 --- a/provider/utils.py +++ b/provider/utils.py @@ -48,3 +48,38 @@ def get_code_expiry(): :attr:`datetime.timedelta` object. """ return now() + EXPIRE_CODE_DELTA + + +class BadArn(Exception): + pass + + +class ArnHelper: + def __init__(self, arn): + self.arn = arn + parts = arn.split(':') + if len(parts) != 6: + raise BadArn("Arn must have 6 parts") + if parts[:2] != ['arn', 'aws']: + raise BadArn("Arn must start with 'arn:aws:...'") + + if parts[2] not in ['iam', 'sts']: + raise BadArn("Arn must come from 'iam' or 'sts' service") + + self.service = parts[2] + self.account_id = parts[4] + self.entity_ref = parts[5] + entity_parts = self.entity_ref.split('/') + self.type = entity_parts[0] + self.general_type = self.type if self.type != "assumed-role" else "role" + self.name = entity_parts[1] + self.session = entity_parts[2] if len(entity_parts) > 2 else None + + def __eq__(self, other): + if not isinstance(other, ArnHelper): + return False + + if self.account_id == other.account_id and self.general_type == other.general_type and self.name == other.name: + return True + + return False diff --git a/provider/views.py b/provider/views.py index 94e82d2d..e69de29b 100644 --- a/provider/views.py +++ b/provider/views.py @@ -1,625 +0,0 @@ -from __future__ import absolute_import - -import json - -from six.moves.urllib_parse import urlparse, ParseResult - -from django.http import HttpResponse -from django.http import HttpResponseRedirect, QueryDict -from django.utils.translation import ugettext as _ -from django.views.generic.base import TemplateView, View -from django.core.exceptions import ObjectDoesNotExist -from provider.oauth2.models import Client, Scope -from provider import constants - - -class OAuthError(Exception): - """ - Exception to throw inside any views defined in :attr:`provider.views`. - - Any :attr:`OAuthError` thrown will be signalled to the API consumer. - - :attr:`OAuthError` expects a dictionary as its first argument outlining the - type of error that occured. - - :example: - - :: - - raise OAuthError({'error': 'invalid_request'}) - - The different types of errors are outlined in :rfc:`4.2.2.1` and - :rfc:`5.2`. - - """ - - -class AuthUtilMixin(object): - """ - Mixin providing common methods required in the OAuth view defined in - :attr:`provider.views`. - """ - def get_data(self, request, key='params'): - """ - Return stored data from the session store. - - :param key: `str` The key under which the data was stored. - """ - return request.session.get('%s:%s' % (constants.SESSION_KEY, key)) - - def cache_data(self, request, data, key='params'): - """ - Cache data in the session store. - - :param request: :attr:`django.http.HttpRequest` - :param data: Arbitrary data to store. - :param key: `str` The key under which to store the data. - """ - request.session['%s:%s' % (constants.SESSION_KEY, key)] = data - - def clear_data(self, request): - """ - Clear all OAuth related data from the session store. - """ - for key in list(request.session.keys()): - if key.startswith(constants.SESSION_KEY): - del request.session[key] - - def authenticate(self, request): - """ - Authenticate a client against all the backends configured in - :attr:`authentication`. - """ - for backend in self.authentication: - client = backend().authenticate(request) - if client is not None: - return client - return None - - -class CaptureViewBase(AuthUtilMixin, TemplateView): - """ - As stated in section :rfc:`3.1.2.5` this view captures all the request - parameters and redirects to another URL to avoid any leakage of request - parameters to potentially harmful JavaScripts. - - This application assumes that whatever web-server is used as front-end will - handle SSL transport. - - If you want strict enforcement of secure communication at application - level, set :attr:`settings.OAUTH_ENFORCE_SECURE` to ``True``. - - The actual implementation is required to override :meth:`get_redirect_url`. - """ - template_name = 'provider/authorize.html' - - def get_redirect_url(self, request): - """ - Return a redirect to a URL where the resource owner (see :rfc:`1`) - authorizes the client (also :rfc:`1`). - - :return: :class:`django.http.HttpResponseRedirect` - - """ - raise NotImplementedError - - def validate_scopes(self, scope_list): - raise NotImplementedError - - def handle(self, request, data): - self.cache_data(request, data) - - if constants.ENFORCE_SECURE and not request.is_secure(): - return self.render_to_response({'error': 'access_denied', - 'error_description': _("A secure connection is required."), - 'next': None}, - status=400) - - scope_list = [s for s in - data.get('scope', '').split(' ') if s != ''] - if self.validate_scopes(scope_list): - return HttpResponseRedirect(self.get_redirect_url(request)) - else: - return HttpResponse("Invalid scope.", status=400) - - def get(self, request): - return self.handle(request, request.GET) - - def post(self, request): - return self.handle(request, request.POST) - - -class AuthorizeViewBase(AuthUtilMixin, TemplateView): - """ - View to handle the client authorization as outlined in :rfc:`4`. - Implementation must override a set of methods: - - * :attr:`get_redirect_url` - * :attr:`get_request_form` - * :attr:`get_authorization_form` - * :attr:`get_client` - * :attr:`save_authorization` - - :attr:`Authorize` renders the ``provider/authorize.html`` template to - display the authorization form. - - On successful authorization, it redirects the user back to the defined - client callback as defined in :rfc:`4.1.2`. - - On authorization fail :attr:`Authorize` displays an error message to the - user with a modified redirect URL to the callback including the error - and possibly description of the error as defined in :rfc:`4.1.2.1`. - """ - template_name = 'provider/authorize.html' - - def get_redirect_url(self, request): - """ - :return: ``str`` - The client URL to display in the template after - authorization succeeded or failed. - """ - raise NotImplementedError - - def get_request_form(self, client, data): - """ - Return a form that is capable of validating the request data captured - by the :class:`Capture` view. - The form must accept a keyword argument ``client``. - """ - raise NotImplementedError - - def get_authorization_form(self, request, client, data, client_data): - """ - Return a form that is capable of authorizing the client to the resource - owner. - - :return: :attr:`django.forms.Form` - """ - raise NotImplementedError - - def get_client(self, client_id): - """ - Return a client object from a given client identifier. Return ``None`` - if no client is found. An error will be displayed to the resource owner - and presented to the client upon the final redirect. - """ - raise NotImplementedError - - def save_authorization(self, request, client, form, client_data): - """ - Save the authorization that the user granted to the client, involving - the creation of a time limited authorization code as outlined in - :rfc:`4.1.2`. - - Should return ``None`` in case authorization is not granted. - Should return a string representing the authorization code grant. - - :return: ``None``, ``str`` - """ - raise NotImplementedError - - def has_authorization(self, request, client, scope_list): - """ - Check to see if there is a previous authorization request with the - requested scope permissions. - - :param request: - :param client: - :param scope_list: - :return: ``False``, ``AuthorizedClient`` - """ - return False - - def _validate_client(self, request, data): - """ - :return: ``tuple`` - ``(client or False, data or error)`` - """ - client = self.get_client(data.get('client_id')) - - if client is None: - raise OAuthError({ - 'error': 'unauthorized_client', - 'error_description': _("An unauthorized client tried to access" - " your resources.") - }) - - form = self.get_request_form(client, data) - - if not form.is_valid(): - raise OAuthError(form.errors) - - return client, form.cleaned_data - - def error_response(self, request, error, **kwargs): - """ - Return an error to be displayed to the resource owner if anything goes - awry. Errors can include invalid clients, authorization denials and - other edge cases such as a wrong ``redirect_uri`` in the authorization - request. - - :param request: :attr:`django.http.HttpRequest` - :param error: ``dict`` - The different types of errors are outlined in :rfc:`4.2.2.1` - """ - ctx = {} - ctx.update(error) - - # If we got a malicious redirect_uri or client_id, remove all the - # cached data and tell the resource owner. We will *not* redirect back - # to the URL. - - if error.get('error') in ['redirect_uri', 'unauthorized_client']: - ctx.update(next='/') - return self.render_to_response(ctx, **kwargs) - - ctx.update(next=self.get_redirect_url(request)) - - return self.render_to_response(ctx, **kwargs) - - def handle(self, request, post_data=None): - data = self.get_data(request) - - if data is None: - return self.error_response(request, { - 'error': 'expired_authorization', - 'error_description': _('Authorization session has expired.')}) - - try: - client, data = self._validate_client(request, data) - except OAuthError as e: - return self.error_response(request, e.args[0], status=400) - - scope_list = [s.name for s in - data.get('scope', [])] - if self.has_authorization(request, client, scope_list): - post_data = { - 'scope': scope_list, - 'authorize': u'Authorize', - } - - authorization_form = self.get_authorization_form(request, client, - post_data, data) - - if not authorization_form.is_bound or not authorization_form.is_valid(): - return self.render_to_response({ - 'client': client, - 'form': authorization_form, - 'oauth_data': data, - }) - - code = self.save_authorization(request, client, - authorization_form, data) - - # be sure to serialize any objects that aren't natively json - # serializable because these values are stored as session data - data['scope'] = scope_list - self.cache_data(request, data) - self.cache_data(request, code, "code") - self.cache_data(request, client.pk, "client_pk") - - return HttpResponseRedirect(self.get_redirect_url(request)) - - def get(self, request): - return self.handle(request, None) - - def post(self, request): - return self.handle(request, request.POST) - - -class RedirectViewBase(AuthUtilMixin, View): - """ - Redirect the user back to the client with the right query parameters set. - This can be either parameters indicating success or parameters indicating - an error. - """ - - def error_response(self, error, mimetype='application/json', status=400, - **kwargs): - """ - Return an error response to the client with default status code of - *400* stating the error as outlined in :rfc:`5.2`. - """ - return HttpResponse(json.dumps(error), content_type=mimetype, - status=status, **kwargs) - - def get(self, request): - data = self.get_data(request) - code = self.get_data(request, "code") - error = self.get_data(request, "error") - client_pk = self.get_data(request, "client_pk") - - client = Client.objects.get(pk=client_pk) - - # this is an edge case that is caused by making a request with no data - # it should only happen if this view is called manually, out of the - # normal capture-authorize-redirect flow. - if data is None or client is None: - return self.error_response({ - 'error': 'invalid_data', - 'error_description': _('Data has not been captured')}) - - redirect_uri = data.get('redirect_uri', None) or client.redirect_uri - - parsed = urlparse(redirect_uri) - - query = QueryDict('', mutable=True) - - if 'state' in data: - query['state'] = data['state'] - - if error is not None: - query.update(error) - elif code is None: - query['error'] = 'access_denied' - else: - query['code'] = code - - parsed = parsed[:4] + (query.urlencode(), '') - - redirect_uri = ParseResult(*parsed).geturl() - - self.clear_data(request) - - return HttpResponseRedirect(redirect_uri) - - -class AccessTokenViewBase(AuthUtilMixin, TemplateView): - """ - :attr:`AccessToken` handles creation and refreshing of access tokens. - - Implementations must implement a number of methods: - - * :attr:`get_authorization_code_grant` - * :attr:`get_refresh_token_grant` - * :attr:`get_password_grant` - * :attr:`get_access_token` - * :attr:`create_access_token` - * :attr:`create_refresh_token` - * :attr:`invalidate_grant` - * :attr:`invalidate_access_token` - * :attr:`invalidate_refresh_token` - - The default implementation supports the grant types defined in - :attr:`grant_types`. - - According to :rfc:`4.4.2` this endpoint too must support secure - communication. For strict enforcement of secure communication at - application level set :attr:`settings.OAUTH_ENFORCE_SECURE` to ``True``. - - According to :rfc:`3.2` we can only accept POST requests. - - Returns with a status code of *400* in case of errors. *200* in case of - success. - """ - - authentication = () - """ - Authentication backends used to authenticate a particular client. - """ - - grant_types = ['authorization_code', 'refresh_token', 'password'] - """ - The default grant types supported by this view. - """ - - def get_authorization_code_grant(self, request, data, client): - """ - Return the grant associated with this request or an error dict. - - :return: ``tuple`` - ``(True or False, grant or error_dict)`` - """ - raise NotImplementedError - - def get_refresh_token_grant(self, request, data, client): - """ - Return the refresh token associated with this request or an error dict. - - :return: ``tuple`` - ``(True or False, token or error_dict)`` - """ - raise NotImplementedError - - def get_password_grant(self, request, data, client): - """ - Return a user associated with this request or an error dict. - - :return: ``tuple`` - ``(True or False, user or error_dict)`` - """ - raise NotImplementedError - - def get_access_token(self, request, user, scope, client): - """ - Override to handle fetching of an existing access token. - - :return: ``object`` - Access token - """ - raise NotImplementedError - - def create_access_token(self, request, user, scope, client): - """ - Override to handle access token creation. - - :return: ``object`` - Access token - """ - raise NotImplementedError - - def create_refresh_token(self, request, user, scope, access_token, client): - """ - Override to handle refresh token creation. - - :return: ``object`` - Refresh token - """ - raise NotImplementedError - - def invalidate_grant(self, grant): - """ - Override to handle grant invalidation. A grant is invalidated right - after creating an access token from it. - - :return None: - """ - raise NotImplementedError - - def invalidate_refresh_token(self, refresh_token): - """ - Override to handle refresh token invalidation. When requesting a new - access token from a refresh token, the old one is *always* invalidated. - - :return None: - """ - raise NotImplementedError - - def invalidate_access_token(self, access_token): - """ - Override to handle access token invalidation. When a new access token - is created from a refresh token, the old one is *always* invalidated. - - :return None: - """ - raise NotImplementedError - - def error_response(self, error, mimetype='application/json', status=400, - **kwargs): - """ - Return an error response to the client with default status code of - *400* stating the error as outlined in :rfc:`5.2`. - """ - return HttpResponse(json.dumps(error), content_type=mimetype, - status=status, **kwargs) - - def access_token_response(self, access_token): - """ - Returns a successful response after creating the access token - as defined in :rfc:`5.1`. - """ - - response_data = { - 'access_token': access_token.token, - 'token_type': constants.TOKEN_TYPE, - 'expires_in': access_token.get_expire_delta(), - 'scope': access_token.get_scope_string(), - } - - # Not all access_tokens are given a refresh_token - # (for example, public clients doing password auth) - try: - rt = access_token.refresh_token - response_data['refresh_token'] = rt.token - except ObjectDoesNotExist: - pass - - return HttpResponse( - json.dumps(response_data), content_type='application/json' - ) - - def authorization_code(self, request, data, client): - """ - Handle ``grant_type=authorization_code`` requests as defined in - :rfc:`4.1.3`. - """ - grant = self.get_authorization_code_grant(request, request.POST, - client) - at = self.create_access_token(request, grant.user, - list(grant.scope.all()), client) - - suppress_refresh_token = False - if client.client_type == constants.PUBLIC and client.allow_public_token: - if not request.POST.get('client_secret'): - suppress_refresh_token = True - - if not suppress_refresh_token: - rt = self.create_refresh_token(request, grant.user, - list(grant.scope.all()), at, client) - - self.invalidate_grant(grant) - - return self.access_token_response(at) - - def refresh_token(self, request, data, client): - """ - Handle ``grant_type=refresh_token`` requests as defined in :rfc:`6`. - """ - rt = self.get_refresh_token_grant(request, data, client) - - token_scope = list(rt.access_token.scope.all()) - - # this must be called first in case we need to purge expired tokens - self.invalidate_refresh_token(rt) - self.invalidate_access_token(rt.access_token) - - at = self.create_access_token(request, rt.user, - token_scope, - client) - rt = self.create_refresh_token(request, at.user, - at.scope.all(), at, client) - - return self.access_token_response(at) - - def password(self, request, data, client): - """ - Handle ``grant_type=password`` requests as defined in :rfc:`4.3`. - """ - - data = self.get_password_grant(request, data, client) - user = data.get('user') - scope = data.get('scope') - - at = self.create_access_token(request, user, scope, client) - # Public clients don't get refresh tokens - if client.client_type != constants.PUBLIC: - rt = self.create_refresh_token(request, user, scope, at, client) - - return self.access_token_response(at) - - def get_handler(self, grant_type): - """ - Return a function or method that is capable handling the ``grant_type`` - requested by the client or return ``None`` to indicate that this type - of grant type is not supported, resulting in an error response. - """ - if grant_type == 'authorization_code': - return self.authorization_code - elif grant_type == 'refresh_token': - return self.refresh_token - elif grant_type == 'password': - return self.password - return None - - def get(self, request): - """ - As per :rfc:`3.2` the token endpoint *only* supports POST requests. - Returns an error response. - """ - return self.error_response({ - 'error': 'invalid_request', - 'error_description': _("Only POST requests allowed.")}) - - def post(self, request): - """ - As per :rfc:`3.2` the token endpoint *only* supports POST requests. - """ - if constants.ENFORCE_SECURE and not request.is_secure(): - return self.error_response({ - 'error': 'invalid_request', - 'error_description': _("A secure connection is required.")}) - - if not 'grant_type' in request.POST: - return self.error_response({ - 'error': 'invalid_request', - 'error_description': _("No 'grant_type' included in the " - "request.")}) - - grant_type = request.POST['grant_type'] - - if grant_type not in self.grant_types: - return self.error_response({'error': 'unsupported_grant_type'}) - - client = self.authenticate(request) - - if client is None: - return self.error_response({'error': 'invalid_client'}) - - handler = self.get_handler(grant_type) - - try: - return handler(request, request.POST, client) - except OAuthError as e: - return self.error_response(e.args[0]) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..6adaff8d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,33 @@ +[build-system] +requires = ["setuptools", "setuptools-git-versioning>=2.0,<3",] +build-backend = "setuptools.build_meta" + +[tool.setuptools-git-versioning] +enabled = true + + +[project] +name = "django-oauth2" +dynamic = ["version"] +description = 'Provide OAuth2 access to your app (fork of django-oauth2-provider)' +readme = 'README.rst' +maintainers = [ + {name="Shaun Kruger", email="shaun.kruger@gmail.com"} +] +requires-python = ">=3.8" +classifiers=[ + 'Environment :: Web Environment', + 'Intended Audience :: Developers', + 'Operating System :: OS Independent', + 'Programming Language :: Python', + 'Framework :: Django', +] +dependencies=[ + "shortuuid>=1.0.11", + "sqlparse>=0.4.3", +] + +[project.urls] +Homepage = 'https://github.com/stormsherpa/django-oauth2-provider' + + diff --git a/requirements.txt b/requirements.txt index 7cf48965..b0235ead 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -Django==3.2 +Django==4.2 shortuuid==1.0.11 -six>=0.16.0 sqlparse>=0.4.3 diff --git a/setup.py b/setup.py index 1eb1e63b..b6ae4a90 100644 --- a/setup.py +++ b/setup.py @@ -1,32 +1,11 @@ #!/usr/bin/env python from setuptools import setup, find_packages -import provider setup( name='django-oauth2', - version=provider.__version__, - description='Provide OAuth2 access to your app (fork of django-oauth2-provider)', long_description=open('README.rst').read(), - author='Shaun Kruger', - author_email='shaun.kruger@gmail.com', - url = 'https://github.com/stormsherpa/django-oauth2-provider', packages=find_packages(exclude=('tests*',)), - license='The MIT License: http://www.opensource.org/licenses/mit-license.php', - platforms='all', - classifiers=[ - 'Environment :: Web Environment', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: MIT License', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - 'Framework :: Django', - ], - install_requires=[ - "shortuuid>=1.0.11", - "six>=0.16.0", - "sqlparse>=0.4.3", - ], include_package_data=True, zip_safe=False, ) diff --git a/tests/urls.py b/tests/urls.py index ea504bc0..bb1bd5c9 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -1,10 +1,10 @@ -from django.conf.urls import url, include +from django.urls import path, include from django.contrib import admin admin.autodiscover() urlpatterns = [ - url(r'^admin/', admin.site.urls), - url(r'^oauth2/', include('provider.oauth2.urls', namespace='oauth2')), - url(r'^tests/', include('provider.oauth2.tests.urls', namespace='tests')), + path('admin/', admin.site.urls), + path('oauth2/', include('provider.oauth2.urls', namespace='oauth2')), + path('tests/', include('provider.oauth2.tests.urls', namespace='tests')), ] diff --git a/tox.ini b/tox.ini index 923b96f5..3c713f66 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ [tox] toxworkdir={env:TOX_WORK_DIR:.tox} downloadcache = {toxworkdir}/cache/ -envlist = py{3.8,3.9,3.10}-django{3.0,3.1,3.2,4.0,4.1} +envlist = py{3.8,3.9,3.10}-django{3.1,3.2,4.0,4.1,4.2} [testenv] setenv = @@ -15,11 +15,6 @@ python = 3.8: py3.8-django{3.0,3.1,3.2,4.0,4.1} -[testenv:py3.8-django3.0] -basepython = python3.8 -deps = Django>=3.0,<3.1 - {[testenv]deps} - [testenv:py3.8-django3.1] basepython = python3.8 deps = Django>=3.1,<3.2 @@ -39,3 +34,8 @@ deps = Django>=4.0,<4.1 basepython = python3.8 deps = Django>=4.1,<4.2 {[testenv]deps} + +[testenv:py3.8-django4.2] +basepython = python3.8 +deps = Django>=4.2,<5.0 + {[testenv]deps}