Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 81 additions & 5 deletions src/easy_oauth/manager.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import asyncio
import secrets
import urllib.parse
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from functools import cached_property
from datetime import datetime, timedelta, timezone
from functools import cached_property, lru_cache

import httpx
from authlib.integrations.starlette_client import OAuth
from authlib.jose import JoseError, JWTClaims, jwt
from itsdangerous import BadData, URLSafeSerializer
from serieux import deserialize, serialize
from serieux.features.encrypt import Secret
Expand All @@ -17,6 +20,13 @@
from .structs import OpenIDConfiguration, Payload, UserInfo


@lru_cache(maxsize=100)
def _headless_store(session_token: str) -> dict[str, asyncio.Event | list[str]]:
assert session_token

return {"event": asyncio.Event(), "token": []}


@dataclass(kw_only=True)
class OAuthManager:
server_metadata_url: str
Expand Down Expand Up @@ -145,18 +155,59 @@ async def assimilate_payload(self, request):
request.session["access_token"] = payload.access_token
request.session["refresh_token"] = payload.refresh_token

def _init_headless(self, request: Request):
"""Detect and initialize headless session."""
headless = request.query_params.get("headless", "false").lower() == "true"

if headless_session := request.query_params.get("headless_session", None):
try:
decoded: JWTClaims = jwt.decode(headless_session, key=self.secret_key)
decoded.validate()
request.session["headless_session"] = decoded["session"]
except JoseError:
request.session["headless_session"] = None

return headless

##########
# Routes #
##########

async def route_login(self, request):
async def route_login(self, request: Request):
red = request.session.get("redirect_after_login", "/")
request.session.clear()

request.session["redirect_after_login"] = red
if self.force_user: # pragma: no cover
# Pages won't redirect to /login when force_user is True,
# so this won't happen unless the user directly goes to /login
return RedirectResponse(url=red)

if self._init_headless(request):
headless_session = request.session.get(
"headless_session",
jwt.encode(
{"alg": "HS256"},
payload={
"session": secrets.token_urlsafe(32),
"exp": datetime.now(timezone.utc) + timedelta(minutes=5),
},
key=self.secret_key,
),
)
login_params = urllib.parse.urlencode(
{"headless_session": headless_session, "headless": False}
)
token_params = urllib.parse.urlencode(
{"headless_session": headless_session, "headless": True}
)
return JSONResponse(
{
"login_url": f"{request.url_for('token')}?{login_params}",
"token_url": f"{request.url_for('token')}?{token_params}",
}
)

auth_route = request.query_params.get("redirect", "auth")
redirect_uri = request.url_for(auth_route)
params = {}
Expand All @@ -174,20 +225,45 @@ async def route_auth(self, request):
red = request.session.get("redirect_after_login", "/")
return RedirectResponse(url=red)

async def route_token(self, request):
async def route_token(self, request: Request):
if self.force_user:
return JSONResponse({"refresh_token": "XXX"})

headless = self._init_headless(request)
headless_session = request.session.get("headless_session", None)

if headless and headless_session:
headless_state = _headless_store(headless_session)
await headless_state["event"].wait()
try:
return JSONResponse({"refresh_token": headless_state["token"].pop()})
except IndexError:
# The token has been used, a new one will need to be generated
return PlainTextResponse("Unauthorized", status_code=401)

if state := request.query_params.get("state"):
await self.assimilate_payload(request)

if not (rt := request.session.get("refresh_token")):
if not state:
login_url = request.url_for("login")
return RedirectResponse(url=f"{login_url}?offline_token=true&redirect=token")
params: dict = {
**request.query_params,
"offline_token": "true",
"redirect": "token",
}
params.pop("state", None)
return RedirectResponse(url=f"{login_url}?{urllib.parse.urlencode(params)}")
else: # pragma: no cover
return PlainTextResponse("Unauthorized", status_code=401)

ert = self.secrets_serializer.dumps(rt)

if headless_session:
headless_state = _headless_store(headless_session)
headless_state["token"].insert(0, ert)
headless_state["event"].set()

return JSONResponse({"refresh_token": ert})

async def route_logout(self, request):
Expand Down
Loading