-
-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
338 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Litestar | ||
|
||
Utilities are provided to ease the integration of an OAuth2 process in [Litestar](https://litestar.dev/). | ||
|
||
## `OAuth2AuthorizeCallback` | ||
|
||
Dependency callable to handle the authorization callback. It reads the query parameters and returns the access token and the state. | ||
|
||
```py | ||
from httpx_oauth.integrations.litestar import OAuth2AuthorizeCallback, AccessTokenState | ||
from httpx_oauth.oauth2 import OAuth2 | ||
from litestar import Litestar, get | ||
from litestar.di import Provide | ||
from litestar.params import Dependency | ||
|
||
client = OAuth2("CLIENT_ID", "CLIENT_SECRET", "AUTHORIZE_ENDPOINT", "ACCESS_TOKEN_ENDPOINT") | ||
oauth2_authorize_callback = OAuth2AuthorizeCallback(client, "oauth-callback") | ||
|
||
@get("/oauth-callback", name="oauth-callback") | ||
async def oauth_callback( | ||
access_token_state: AccessTokenState = Dependency(skip_validation=True), | ||
) -> AccessTokenState: | ||
token, state = access_token_state | ||
# Do something useful | ||
|
||
app = Litestar(route_handlers=[oauth_callback],dependencies={"access_token_state": Provide(oauth2_authorize_callback)}) | ||
|
||
|
||
``` | ||
|
||
[Reference](./reference/httpx_oauth.integrations.litestar.md){ .md-button } | ||
{ .buttons } | ||
|
||
### Custom exception handler | ||
|
||
If an error occurs inside the callback logic (the user denied access, the authorization code is invalid...), the dependency will raise [OAuth2AuthorizeCallbackError][httpx_oauth.integrations.litestar.OAuth2AuthorizeCallbackError]. | ||
|
||
It inherits from Litestar's [HTTPException][litestar.exceptions.HTTPException], so it's automatically handled by the default Litestar exception handler. You can customize this behavior by implementing your own exception handler for `OAuth2AuthorizeCallbackError`. | ||
|
||
```py | ||
from httpx_oauth.integrations.litestar import OAuth2AuthorizeCallbackError | ||
from litestar import Litestar | ||
from litestar.response import Response | ||
|
||
async def oauth2_authorize_callback_error_handler(request: Request, exc: OAuth2AuthorizeCallbackError) -> Response: | ||
detail = exc.detail | ||
status_code = exc.status_code | ||
return Response( | ||
status_code=status_code, | ||
content={"message": "The OAuth2 callback failed", "detail": detail}, | ||
) | ||
|
||
app = Litestar(exception_handlers={OAuth2AuthorizeCallbackError: oauth2_authorize_callback_error_handler}) | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Reference - Integrations - Litestar | ||
|
||
::: httpx_oauth.integrations.litestar | ||
options: | ||
show_root_heading: false | ||
show_source: false |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
# pylint: disable=[invalid-name,import-outside-toplevel] | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Any, Dict, List, TypeAlias, Union # noqa: UP035 | ||
|
||
from litestar import status_codes as status | ||
from litestar.exceptions import HTTPException | ||
from litestar.params import Parameter | ||
|
||
from httpx_oauth.oauth2 import BaseOAuth2, GetAccessTokenError, OAuth2Error, OAuth2Token | ||
|
||
if TYPE_CHECKING: | ||
import httpx | ||
from litestar import Request | ||
|
||
|
||
AccessTokenState: TypeAlias = tuple[OAuth2Token, str | None] | ||
|
||
|
||
class OAuth2AuthorizeCallbackError(OAuth2Error, HTTPException): | ||
"""Error raised when an error occurs during the OAuth2 authorization callback. | ||
It inherits from [HTTPException][litestar.exceptions.HTTPException], so you can either keep | ||
the default Litestar error handling or implement something dedicated. | ||
!!! Note | ||
Due to the way the base `LitestarException` handles the `detail` argument, | ||
the `OAuth2Error` is ordered first here | ||
""" | ||
|
||
def __init__( | ||
self, | ||
status_code: int, | ||
detail: Any = None, | ||
headers: Union[Dict[str, str], None] = None, # noqa: UP007, UP006 | ||
response: Union[httpx.Response, None] = None, # noqa: UP007 | ||
extra: Union[Dict[str, Any], List[Any]] | None = None, # noqa: UP007, UP006 | ||
) -> None: | ||
super().__init__(message=detail) | ||
HTTPException.__init__( | ||
self, detail=detail, status_code=status_code, extra=extra, headers=headers | ||
) | ||
self.response = response | ||
|
||
|
||
class OAuth2AuthorizeCallback: | ||
"""Dependency callable to handle the authorization callback. It reads the query parameters and returns the access token and the state. | ||
Examples: | ||
```py | ||
from litestar import get | ||
from httpx_oauth.integrations.litestar import OAuth2AuthorizeCallback | ||
from httpx_oauth.oauth2 import OAuth2 | ||
client = OAuth2("CLIENT_ID", "CLIENT_SECRET", "AUTHORIZE_ENDPOINT", "ACCESS_TOKEN_ENDPOINT") | ||
oauth2_authorize_callback = OAuth2AuthorizeCallback(client, "oauth-callback") | ||
@get("/oauth-callback", name="oauth-callback", dependencies={"access_token_state": Provide(oauth2_authorize_callback)}) | ||
async def oauth_callback(access_token_state: AccessTokenState)) -> Response: | ||
token, state = access_token_state | ||
# Do something useful | ||
``` | ||
""" | ||
|
||
client: BaseOAuth2 | ||
route_name: str | None | ||
redirect_url: str | None | ||
|
||
def __init__( | ||
self, | ||
client: BaseOAuth2, | ||
route_name: str | None = None, | ||
redirect_url: str | None = None, | ||
) -> None: | ||
"""Args: | ||
client: An [OAuth2][httpx_oauth.oauth2.BaseOAuth2] client. | ||
route_name: Name of the callback route, as defined in the `name` parameter of the route decorator. | ||
redirect_url: Full URL to the callback route. | ||
""" | ||
assert (route_name is not None and redirect_url is None) or ( | ||
route_name is None and redirect_url is not None | ||
), "You should either set route_name or redirect_url" | ||
self.client = client | ||
self.route_name = route_name | ||
self.redirect_url = redirect_url | ||
|
||
async def __call__( | ||
self, | ||
request: Request, | ||
code: str | None = Parameter(query="code", required=False), | ||
code_verifier: str | None = Parameter(query="code_verifier", required=False), | ||
callback_state: str | None = Parameter(query="state", required=False), | ||
error: str | None = Parameter(query="error", required=False), | ||
) -> AccessTokenState: | ||
if code is None or error is not None: | ||
raise OAuth2AuthorizeCallbackError( | ||
status_code=status.HTTP_400_BAD_REQUEST, | ||
detail=error if error is not None else None, | ||
) | ||
|
||
if self.route_name: | ||
redirect_url = str(request.url_for(self.route_name)) | ||
elif self.redirect_url: | ||
redirect_url = self.redirect_url | ||
|
||
try: | ||
access_token = await self.client.get_access_token( | ||
code, | ||
redirect_url, | ||
code_verifier, | ||
) | ||
except GetAccessTokenError as e: | ||
raise OAuth2AuthorizeCallbackError( | ||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, | ||
detail=e.message, | ||
response=e.response, | ||
extra={"message": e.message}, | ||
) from e | ||
|
||
return access_token, callback_state |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
import pytest | ||
from litestar import Litestar, get | ||
from litestar import status_codes as status | ||
from litestar.di import Provide | ||
from litestar.params import Dependency | ||
from litestar.testing import TestClient | ||
from pytest_mock import MockerFixture | ||
|
||
from httpx_oauth.integrations.litestar import AccessTokenState, OAuth2AuthorizeCallback | ||
from httpx_oauth.oauth2 import GetAccessTokenError, OAuth2 | ||
|
||
CLIENT_ID = "CLIENT_ID" | ||
CLIENT_SECRET = "CLIENT_SECRET" | ||
AUTHORIZE_ENDPOINT = "https://www.camelot.bt/authorize" | ||
ACCESS_TOKEN_ENDPOINT = "https://www.camelot.bt/access-token" | ||
REDIRECT_URL = "https://www.tintagel.bt/callback" | ||
ROUTE_NAME = "callback" | ||
|
||
client = OAuth2(CLIENT_ID, CLIENT_SECRET, AUTHORIZE_ENDPOINT, ACCESS_TOKEN_ENDPOINT) | ||
oauth2_authorize_callback_route_name = OAuth2AuthorizeCallback( | ||
client, route_name=ROUTE_NAME | ||
) | ||
oauth2_authorize_callback_redirect_url = OAuth2AuthorizeCallback( | ||
client, redirect_url=REDIRECT_URL | ||
) | ||
|
||
|
||
@get( | ||
"/authorize-route-name", | ||
dependencies={"access_token_state": Provide(oauth2_authorize_callback_route_name)}, | ||
) | ||
async def authorize_route_name( | ||
access_token_state: AccessTokenState = Dependency(skip_validation=True), | ||
) -> AccessTokenState: | ||
return access_token_state | ||
|
||
|
||
@get( | ||
"/authorize-redirect-url", | ||
dependencies={ | ||
"access_token_state": Provide(oauth2_authorize_callback_redirect_url) | ||
}, | ||
) | ||
async def authorize_redirect_url( | ||
access_token_state: AccessTokenState = Dependency(skip_validation=True), | ||
) -> AccessTokenState: | ||
return access_token_state | ||
|
||
|
||
@get("/callback", name="callback") | ||
async def callback() -> dict: | ||
return {} | ||
|
||
|
||
app = Litestar(route_handlers=[authorize_route_name, authorize_redirect_url, callback]) | ||
|
||
test_client = TestClient(app=app) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"route,expected_redirect_url", | ||
[ | ||
("/authorize-route-name", "http://testserver.local/callback"), | ||
("/authorize-redirect-url", "https://www.tintagel.bt/callback"), | ||
], | ||
) | ||
class TestOAuth2AuthorizeCallback: | ||
def test_oauth2_authorize_missing_code(self, route, expected_redirect_url): | ||
response = test_client.get(route) | ||
assert response.status_code == status.HTTP_400_BAD_REQUEST | ||
|
||
def test_oauth2_authorize_error(self, route, expected_redirect_url): | ||
response = test_client.get(route, params={"error": "access_denied"}) | ||
assert response.status_code == status.HTTP_400_BAD_REQUEST | ||
assert response.json() == {"status_code": 400, "detail": "access_denied"} | ||
|
||
def test_oauth2_authorize_get_access_token_error( | ||
self, mocker: MockerFixture, route, expected_redirect_url | ||
): | ||
get_access_token_mock = mocker.patch.object( | ||
client, "get_access_token", side_effect=GetAccessTokenError("ERROR") | ||
) | ||
|
||
response = test_client.get(route, params={"code": "CODE"}) | ||
|
||
get_access_token_mock.assert_called_once_with( | ||
"CODE", expected_redirect_url, None | ||
) | ||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR | ||
# by default, litestar will only return `Internal Server Error` as the detail on a response. | ||
# we are adding the ERROR to the `extra` payload | ||
assert response.json() == { | ||
"status_code": 500, | ||
"detail": "Internal Server Error", | ||
"extra": {"message": "ERROR"}, | ||
} | ||
|
||
def test_oauth2_authorize_without_state( | ||
self, patch_async_method, route, expected_redirect_url | ||
): | ||
patch_async_method(client, "get_access_token", return_value="ACCESS_TOKEN") | ||
|
||
response = test_client.get(route, params={"code": "CODE"}) | ||
|
||
client.get_access_token.assert_called() | ||
client.get_access_token.assert_called_once_with( | ||
"CODE", expected_redirect_url, None | ||
) | ||
assert response.status_code == status.HTTP_200_OK | ||
assert response.json() == ["ACCESS_TOKEN", None] | ||
|
||
def test_oauth2_authorize_code_verifier_without_state( | ||
self, patch_async_method, route, expected_redirect_url | ||
): | ||
patch_async_method(client, "get_access_token", return_value="ACCESS_TOKEN") | ||
|
||
response = test_client.get( | ||
route, params={"code": "CODE", "code_verifier": "CODE_VERIFIER"} | ||
) | ||
|
||
client.get_access_token.assert_called() | ||
client.get_access_token.assert_called_once_with( | ||
"CODE", expected_redirect_url, "CODE_VERIFIER" | ||
) | ||
assert response.status_code == status.HTTP_200_OK | ||
assert response.json() == ["ACCESS_TOKEN", None] | ||
|
||
def test_oauth2_authorize_with_state( | ||
self, patch_async_method, route, expected_redirect_url | ||
): | ||
patch_async_method(client, "get_access_token", return_value="ACCESS_TOKEN") | ||
|
||
response = test_client.get(route, params={"code": "CODE", "state": "STATE"}) | ||
|
||
client.get_access_token.assert_called() | ||
client.get_access_token.assert_called_once_with( | ||
"CODE", expected_redirect_url, None | ||
) | ||
assert response.status_code == status.HTTP_200_OK | ||
assert response.json() == ["ACCESS_TOKEN", "STATE"] | ||
|
||
def test_oauth2_authorize_with_state_and_code_verifier( | ||
self, patch_async_method, route, expected_redirect_url | ||
): | ||
patch_async_method(client, "get_access_token", return_value="ACCESS_TOKEN") | ||
|
||
response = test_client.get( | ||
route, | ||
params={"code": "CODE", "state": "STATE", "code_verifier": "CODE_VERIFIER"}, | ||
) | ||
|
||
client.get_access_token.assert_called() | ||
client.get_access_token.assert_called_once_with( | ||
"CODE", expected_redirect_url, "CODE_VERIFIER" | ||
) | ||
assert response.status_code == status.HTTP_200_OK | ||
assert response.json() == ["ACCESS_TOKEN", "STATE"] |