Skip to content

Commit

Permalink
feat: add litestar integration
Browse files Browse the repository at this point in the history
  • Loading branch information
cofin committed Jul 31, 2024
1 parent 88e9bc5 commit 1268b9f
Show file tree
Hide file tree
Showing 5 changed files with 338 additions and 0 deletions.
54 changes: 54 additions & 0 deletions docs/litestar.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Litestar

Utilities are provided to ease the integration of an OAuth2 process in [Litestar](https://litestar.dev/).

## `OAuth2AuthorizeCallback`

Dependency callable to handle the authorization callback. It reads the query parameters and returns the access token and the state.

```py
from httpx_oauth.integrations.litestar import OAuth2AuthorizeCallback, AccessTokenState
from httpx_oauth.oauth2 import OAuth2
from litestar import Litestar, get
from litestar.di import Provide
from litestar.params import Dependency

client = OAuth2("CLIENT_ID", "CLIENT_SECRET", "AUTHORIZE_ENDPOINT", "ACCESS_TOKEN_ENDPOINT")
oauth2_authorize_callback = OAuth2AuthorizeCallback(client, "oauth-callback")

@get("/oauth-callback", name="oauth-callback")
async def oauth_callback(
access_token_state: AccessTokenState = Dependency(skip_validation=True),
) -> AccessTokenState:
token, state = access_token_state
# Do something useful

app = Litestar(route_handlers=[oauth_callback],dependencies={"access_token_state": Provide(oauth2_authorize_callback)})


```

[Reference](./reference/httpx_oauth.integrations.litestar.md){ .md-button }
{ .buttons }

### Custom exception handler

If an error occurs inside the callback logic (the user denied access, the authorization code is invalid...), the dependency will raise [OAuth2AuthorizeCallbackError][httpx_oauth.integrations.litestar.OAuth2AuthorizeCallbackError].

It inherits from Litestar's [HTTPException][litestar.exceptions.HTTPException], so it's automatically handled by the default Litestar exception handler. You can customize this behavior by implementing your own exception handler for `OAuth2AuthorizeCallbackError`.

```py
from httpx_oauth.integrations.litestar import OAuth2AuthorizeCallbackError
from litestar import Litestar
from litestar.response import Response

async def oauth2_authorize_callback_error_handler(request: Request, exc: OAuth2AuthorizeCallbackError) -> Response:
detail = exc.detail
status_code = exc.status_code
return Response(
status_code=status_code,
content={"message": "The OAuth2 callback failed", "detail": detail},
)

app = Litestar(exception_handlers={OAuth2AuthorizeCallbackError: oauth2_authorize_callback_error_handler})
```
6 changes: 6 additions & 0 deletions docs/reference/httpx_oauth.integrations.litestar.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Reference - Integrations - Litestar

::: httpx_oauth.integrations.litestar
options:
show_root_heading: false
show_source: false
120 changes: 120 additions & 0 deletions httpx_oauth/integrations/litestar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# pylint: disable=[invalid-name,import-outside-toplevel]
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, TypeAlias, Union # noqa: UP035

from litestar import status_codes as status
from litestar.exceptions import HTTPException
from litestar.params import Parameter

from httpx_oauth.oauth2 import BaseOAuth2, GetAccessTokenError, OAuth2Error, OAuth2Token

if TYPE_CHECKING:
import httpx
from litestar import Request


AccessTokenState: TypeAlias = tuple[OAuth2Token, str | None]


class OAuth2AuthorizeCallbackError(OAuth2Error, HTTPException):
"""Error raised when an error occurs during the OAuth2 authorization callback.
It inherits from [HTTPException][litestar.exceptions.HTTPException], so you can either keep
the default Litestar error handling or implement something dedicated.
!!! Note
Due to the way the base `LitestarException` handles the `detail` argument,
the `OAuth2Error` is ordered first here
"""

def __init__(
self,
status_code: int,
detail: Any = None,
headers: Union[Dict[str, str], None] = None, # noqa: UP007, UP006
response: Union[httpx.Response, None] = None, # noqa: UP007
extra: Union[Dict[str, Any], List[Any]] | None = None, # noqa: UP007, UP006
) -> None:
super().__init__(message=detail)
HTTPException.__init__(
self, detail=detail, status_code=status_code, extra=extra, headers=headers
)
self.response = response


class OAuth2AuthorizeCallback:
"""Dependency callable to handle the authorization callback. It reads the query parameters and returns the access token and the state.
Examples:
```py
from litestar import get
from httpx_oauth.integrations.litestar import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import OAuth2
client = OAuth2("CLIENT_ID", "CLIENT_SECRET", "AUTHORIZE_ENDPOINT", "ACCESS_TOKEN_ENDPOINT")
oauth2_authorize_callback = OAuth2AuthorizeCallback(client, "oauth-callback")
@get("/oauth-callback", name="oauth-callback", dependencies={"access_token_state": Provide(oauth2_authorize_callback)})
async def oauth_callback(access_token_state: AccessTokenState)) -> Response:
token, state = access_token_state
# Do something useful
```
"""

client: BaseOAuth2
route_name: str | None
redirect_url: str | None

def __init__(
self,
client: BaseOAuth2,
route_name: str | None = None,
redirect_url: str | None = None,
) -> None:
"""Args:
client: An [OAuth2][httpx_oauth.oauth2.BaseOAuth2] client.
route_name: Name of the callback route, as defined in the `name` parameter of the route decorator.
redirect_url: Full URL to the callback route.
"""
assert (route_name is not None and redirect_url is None) or (
route_name is None and redirect_url is not None
), "You should either set route_name or redirect_url"
self.client = client
self.route_name = route_name
self.redirect_url = redirect_url

async def __call__(
self,
request: Request,
code: str | None = Parameter(query="code", required=False),
code_verifier: str | None = Parameter(query="code_verifier", required=False),
callback_state: str | None = Parameter(query="state", required=False),
error: str | None = Parameter(query="error", required=False),
) -> AccessTokenState:
if code is None or error is not None:
raise OAuth2AuthorizeCallbackError(
status_code=status.HTTP_400_BAD_REQUEST,
detail=error if error is not None else None,
)

if self.route_name:
redirect_url = str(request.url_for(self.route_name))
elif self.redirect_url:
redirect_url = self.redirect_url

try:
access_token = await self.client.get_access_token(
code,
redirect_url,
code_verifier,
)
except GetAccessTokenError as e:
raise OAuth2AuthorizeCallbackError(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=e.message,
response=e.response,
extra={"message": e.message},
) from e

return access_token, callback_state
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"pytest-asyncio",
"respx",
"fastapi",
"litestar"
]

[tool.hatch.envs.default.scripts]
Expand Down
157 changes: 157 additions & 0 deletions tests/test_integrations_litestar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import pytest
from litestar import Litestar, get
from litestar import status_codes as status
from litestar.di import Provide
from litestar.params import Dependency
from litestar.testing import TestClient
from pytest_mock import MockerFixture

from httpx_oauth.integrations.litestar import AccessTokenState, OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import GetAccessTokenError, OAuth2

CLIENT_ID = "CLIENT_ID"
CLIENT_SECRET = "CLIENT_SECRET"
AUTHORIZE_ENDPOINT = "https://www.camelot.bt/authorize"
ACCESS_TOKEN_ENDPOINT = "https://www.camelot.bt/access-token"
REDIRECT_URL = "https://www.tintagel.bt/callback"
ROUTE_NAME = "callback"

client = OAuth2(CLIENT_ID, CLIENT_SECRET, AUTHORIZE_ENDPOINT, ACCESS_TOKEN_ENDPOINT)
oauth2_authorize_callback_route_name = OAuth2AuthorizeCallback(
client, route_name=ROUTE_NAME
)
oauth2_authorize_callback_redirect_url = OAuth2AuthorizeCallback(
client, redirect_url=REDIRECT_URL
)


@get(
"/authorize-route-name",
dependencies={"access_token_state": Provide(oauth2_authorize_callback_route_name)},
)
async def authorize_route_name(
access_token_state: AccessTokenState = Dependency(skip_validation=True),
) -> AccessTokenState:
return access_token_state


@get(
"/authorize-redirect-url",
dependencies={
"access_token_state": Provide(oauth2_authorize_callback_redirect_url)
},
)
async def authorize_redirect_url(
access_token_state: AccessTokenState = Dependency(skip_validation=True),
) -> AccessTokenState:
return access_token_state


@get("/callback", name="callback")
async def callback() -> dict:
return {}


app = Litestar(route_handlers=[authorize_route_name, authorize_redirect_url, callback])

test_client = TestClient(app=app)


@pytest.mark.parametrize(
"route,expected_redirect_url",
[
("/authorize-route-name", "http://testserver.local/callback"),
("/authorize-redirect-url", "https://www.tintagel.bt/callback"),
],
)
class TestOAuth2AuthorizeCallback:
def test_oauth2_authorize_missing_code(self, route, expected_redirect_url):
response = test_client.get(route)
assert response.status_code == status.HTTP_400_BAD_REQUEST

def test_oauth2_authorize_error(self, route, expected_redirect_url):
response = test_client.get(route, params={"error": "access_denied"})
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.json() == {"status_code": 400, "detail": "access_denied"}

def test_oauth2_authorize_get_access_token_error(
self, mocker: MockerFixture, route, expected_redirect_url
):
get_access_token_mock = mocker.patch.object(
client, "get_access_token", side_effect=GetAccessTokenError("ERROR")
)

response = test_client.get(route, params={"code": "CODE"})

get_access_token_mock.assert_called_once_with(
"CODE", expected_redirect_url, None
)
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
# by default, litestar will only return `Internal Server Error` as the detail on a response.
# we are adding the ERROR to the `extra` payload
assert response.json() == {
"status_code": 500,
"detail": "Internal Server Error",
"extra": {"message": "ERROR"},
}

def test_oauth2_authorize_without_state(
self, patch_async_method, route, expected_redirect_url
):
patch_async_method(client, "get_access_token", return_value="ACCESS_TOKEN")

response = test_client.get(route, params={"code": "CODE"})

client.get_access_token.assert_called()
client.get_access_token.assert_called_once_with(
"CODE", expected_redirect_url, None
)
assert response.status_code == status.HTTP_200_OK
assert response.json() == ["ACCESS_TOKEN", None]

def test_oauth2_authorize_code_verifier_without_state(
self, patch_async_method, route, expected_redirect_url
):
patch_async_method(client, "get_access_token", return_value="ACCESS_TOKEN")

response = test_client.get(
route, params={"code": "CODE", "code_verifier": "CODE_VERIFIER"}
)

client.get_access_token.assert_called()
client.get_access_token.assert_called_once_with(
"CODE", expected_redirect_url, "CODE_VERIFIER"
)
assert response.status_code == status.HTTP_200_OK
assert response.json() == ["ACCESS_TOKEN", None]

def test_oauth2_authorize_with_state(
self, patch_async_method, route, expected_redirect_url
):
patch_async_method(client, "get_access_token", return_value="ACCESS_TOKEN")

response = test_client.get(route, params={"code": "CODE", "state": "STATE"})

client.get_access_token.assert_called()
client.get_access_token.assert_called_once_with(
"CODE", expected_redirect_url, None
)
assert response.status_code == status.HTTP_200_OK
assert response.json() == ["ACCESS_TOKEN", "STATE"]

def test_oauth2_authorize_with_state_and_code_verifier(
self, patch_async_method, route, expected_redirect_url
):
patch_async_method(client, "get_access_token", return_value="ACCESS_TOKEN")

response = test_client.get(
route,
params={"code": "CODE", "state": "STATE", "code_verifier": "CODE_VERIFIER"},
)

client.get_access_token.assert_called()
client.get_access_token.assert_called_once_with(
"CODE", expected_redirect_url, "CODE_VERIFIER"
)
assert response.status_code == status.HTTP_200_OK
assert response.json() == ["ACCESS_TOKEN", "STATE"]

0 comments on commit 1268b9f

Please sign in to comment.