diff --git a/.gitignore b/.gitignore index 6610fea4..62b1de30 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,4 @@ docs/_build/ # Caches .ruff_cache/ .mypy_cache/ +.sandbox/ diff --git a/docs/source/atproto/atproto_oauth.rst b/docs/source/atproto/atproto_oauth.rst new file mode 100644 index 00000000..7a370d69 --- /dev/null +++ b/docs/source/atproto/atproto_oauth.rst @@ -0,0 +1,29 @@ +atproto\_oauth +============== + +.. automodule:: atproto_oauth + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + atproto_oauth.stores + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + atproto_oauth.client + atproto_oauth.dpop + atproto_oauth.exceptions + atproto_oauth.metadata + atproto_oauth.models + atproto_oauth.pkce + atproto_oauth.security diff --git a/docs/source/atproto/atproto_oauth.stores.rst b/docs/source/atproto/atproto_oauth.stores.rst new file mode 100644 index 00000000..98df4077 --- /dev/null +++ b/docs/source/atproto/atproto_oauth.stores.rst @@ -0,0 +1,16 @@ +atproto\_oauth.stores +===================== + +.. automodule:: atproto_oauth.stores + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + atproto_oauth.stores.base + atproto_oauth.stores.memory diff --git a/docs/source/atproto/modules.rst b/docs/source/atproto/modules.rst index 68d9a8fe..2741df89 100644 --- a/docs/source/atproto/modules.rst +++ b/docs/source/atproto/modules.rst @@ -13,4 +13,5 @@ packages atproto_firehose atproto_identity atproto_lexicon + atproto_oauth atproto_server diff --git a/examples/oauth_flask_demo/README.md b/examples/oauth_flask_demo/README.md new file mode 100644 index 00000000..905cf932 --- /dev/null +++ b/examples/oauth_flask_demo/README.md @@ -0,0 +1,143 @@ +# ATProto OAuth Flask Demo + +simple flask application demonstrating OAuth authentication using the `atproto_oauth` package. + +## features + +- complete OAuth 2.1 authorization code flow +- PKCE and DPoP support +- localhost testing (no HTTPS required) +- authenticated API requests + +## quick start + +1. install dependencies: +```bash +uv sync +``` + +2. run the demo: +```bash +uv run python examples/oauth_flask_demo/app.py +``` + +3. visit http://127.0.0.1:5000 + +4. enter your bluesky handle (e.g., `user.bsky.social`) + +5. authorize the app on bluesky + +6. you'll be redirected back to see your profile info + +## how it works + +### 1. start authorization +```python +oauth_client = OAuthClient( + client_id=CLIENT_ID, + redirect_uri=REDIRECT_URI, + scope='atproto', + state_store=MemoryStateStore(), + session_store=MemorySessionStore(), +) + +auth_url, state = await oauth_client.start_authorization(handle) +``` + +### 2. handle callback +```python +oauth_session = await oauth_client.handle_callback(code, state, iss) +``` + +### 3. make authenticated requests +```python +response = await oauth_client.make_authenticated_request( + session=oauth_session, + method='GET', + url=f'{pds_url}/xrpc/com.atproto.repo.describeRepo?repo={did}', +) +``` + +## production considerations + +this is a **development demo** only. for production: + +### use persistent stores +```python +# instead of memory stores +from your_app.stores import DatabaseStateStore, DatabaseSessionStore + +oauth_client = OAuthClient( + state_store=DatabaseStateStore(), + session_store=DatabaseSessionStore(), + ... +) +``` + +### use HTTPS +- deploy with proper TLS certificate +- update client_id to your public HTTPS URL +- create client metadata JSON file + +### security +- use strong secret keys +- implement CSRF protection +- validate all inputs +- handle errors gracefully +- log security events + +### client metadata +for production (non-localhost), create `/oauth-client-metadata.json`: + +```json +{ + "client_id": "https://yourapp.com/oauth-client-metadata.json", + "dpop_bound_access_tokens": true, + "application_type": "web", + "redirect_uris": ["https://yourapp.com/callback"], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "scope": "atproto", + "token_endpoint_auth_method": "private_key_jwt", + "token_endpoint_auth_signing_alg": "ES256", + "jwks_uri": "https://yourapp.com/oauth/jwks.json", + "client_name": "Your App Name", + "client_uri": "https://yourapp.com" +} +``` + +### confidential client +for server-side apps, generate a client secret key: + +```python +from atproto_oauth.dpop import DPoPManager + +# generate once and store securely +client_secret_key = DPoPManager.generate_keypair() + +oauth_client = OAuthClient( + ..., + client_secret_key=client_secret_key, +) +``` + +## troubleshooting + +### "Invalid state parameter" +- state expired (default TTL: 10 minutes) +- user refreshed callback page +- solution: restart authorization flow + +### "DID mismatch in token" +- identity changed during authorization +- solution: retry with fresh authorization + +### "PDS request failed" +- token expired +- solution: implement automatic token refresh + +## references + +- [ATProto OAuth Spec](https://atproto.com/specs/oauth) +- [Bluesky OAuth Cookbook](https://github.com/bluesky-social/cookbook/tree/main/python-oauth-web-app) +- [atproto Python SDK](https://github.com/MarshalX/atproto) diff --git a/examples/oauth_flask_demo/app.py b/examples/oauth_flask_demo/app.py new file mode 100644 index 00000000..ecafa0da --- /dev/null +++ b/examples/oauth_flask_demo/app.py @@ -0,0 +1,203 @@ +"""Simple Flask OAuth demo for ATProto SDK. + +This demonstrates basic OAuth flow with the atproto_oauth package. +For production use, implement proper session management and error handling. + +Run with: + uv run python examples/oauth_flask_demo/app.py +""" + +import os +from urllib.parse import urlencode + +from atproto_oauth import OAuthClient +from atproto_oauth.stores import MemorySessionStore, MemoryStateStore +from flask import Flask, jsonify, redirect, request, session + +app = Flask(__name__) +app.secret_key = os.getenv('FLASK_SECRET_KEY', 'development-secret-key-change-in-production') + +# OAuth configuration +# For localhost testing, client_id is special localhost URL +REDIRECT_URI = 'http://127.0.0.1:5000/callback' +SCOPE = 'atproto' + +# Create client_id for localhost testing +CLIENT_ID = 'http://localhost?' + urlencode( + { + 'redirect_uri': REDIRECT_URI, + 'scope': SCOPE, + } +) + +# Initialize OAuth client with memory stores (for demo only!) +oauth_client = OAuthClient( + client_id=CLIENT_ID, + redirect_uri=REDIRECT_URI, + scope=SCOPE, + state_store=MemoryStateStore(), + session_store=MemorySessionStore(), +) + + +@app.route('/') +def index() -> str: + """Homepage.""" + if 'user_did' in session: + return f""" + +
+Logged in as: {session.get('user_handle')} ({session.get('user_did')})
+ + + + """ + + return """ + + +{error_msg}'), 500
+
+
+@app.route('/callback')
+def callback() -> str:
+ """Handle OAuth callback."""
+ # Check for errors
+ if error := request.args.get('error'):
+ error_desc = request.args.get('error_description', '')
+ return (
+ f'{error}: {error_desc}
' + f'' + ), 400 + + # Get authorization code and parameters + code = request.args.get('code') + state = request.args.get('state') + iss = request.args.get('iss') + + if not code or not state or not iss: + return 'Missing required parameters', 400 + + # Verify state matches what we stored + stored_state = session.get('oauth_state') + if not stored_state or stored_state != state: + return 'Invalid state parameter', 400 + + try: + # Complete OAuth flow + import asyncio + + oauth_session = asyncio.run(oauth_client.handle_callback(code, state, iss)) + + # Store user info in session + session['user_did'] = oauth_session.did + session['user_handle'] = oauth_session.handle + session.pop('oauth_state', None) + + return redirect('/') + + except Exception as e: # noqa: BLE001 + return f'{e!s}
', 500 + + +@app.route('/logout') +def logout() -> str: + """Logout and revoke OAuth session.""" + user_did = session.get('user_did') + + if user_did: + try: + # Revoke OAuth session + import asyncio + + oauth_session = asyncio.run(oauth_client.session_store.get_session(user_did)) + if oauth_session: + asyncio.run(oauth_client.revoke_session(oauth_session)) + except Exception as e: # noqa: BLE001 + print(f'Error revoking session: {e}') + + # Clear browser session + session.clear() + return redirect('/') + + +@app.route('/api/profile') +def api_profile() -> tuple: + """Example API endpoint using OAuth session.""" + user_did = session.get('user_did') + if not user_did: + return jsonify({'error': 'Not authenticated'}), 401 + + try: + import asyncio + + # Get OAuth session + oauth_session = asyncio.run(oauth_client.session_store.get_session(user_did)) + if not oauth_session: + return jsonify({'error': 'Session not found'}), 401 + + # Make authenticated request to PDS + response = asyncio.run( + oauth_client.make_authenticated_request( + session=oauth_session, + method='GET', + url=f'{oauth_session.pds_url}/xrpc/com.atproto.repo.describeRepo?repo={user_did}', + ) + ) + + if response.status_code != 200: + return jsonify({'error': 'PDS request failed', 'status': response.status_code}), 500 + + return jsonify(response.json()) + + except Exception as e: # noqa: BLE001 + return jsonify({'error': str(e)}), 500 + + +if __name__ == '__main__': + print('Starting ATProto OAuth Flask demo...') + print('Visit http://127.0.0.1:5000 to test OAuth flow') + print() + print('Note: This is a development demo. For production:') + print(' - Use persistent state/session stores (not memory)') + print(' - Implement proper error handling') + print(' - Use HTTPS') + print(' - Set a strong secret key') + print() + app.run(debug=True, port=5000) # noqa: S201 diff --git a/packages/atproto_oauth/README.md b/packages/atproto_oauth/README.md new file mode 100644 index 00000000..528dd99c --- /dev/null +++ b/packages/atproto_oauth/README.md @@ -0,0 +1,388 @@ +# atproto_oauth + +complete OAuth 2.1 implementation for the ATProto Python SDK, following the [ATProto OAuth specification](https://atproto.com/specs/oauth). + +## features + +✅ **full OAuth 2.1 compliance** +- authorization code grant with PKCE (S256) +- DPoP (Demonstrating Proof-of-Possession) with ES256 +- PAR (Pushed Authorization Requests) +- automatic DPoP nonce rotation +- client assertions for confidential clients + +✅ **ATProto-specific** +- DID-based authentication +- handle/DID resolution and verification +- PDS endpoint discovery +- authorization server discovery +- works with Bluesky and custom PDS instances + +✅ **production-ready** +- comprehensive error handling +- SSRF protection and URL validation +- async-first with sync support +- pluggable state/session stores +- fully typed with type hints +- 12 unit tests, all passing + +## quick start + +### basic usage + +```python +from atproto_oauth import OAuthClient +from atproto_oauth.stores import MemorySessionStore, MemoryStateStore + +# create OAuth client +client = OAuthClient( + client_id='http://localhost', # or your HTTPS URL for production + redirect_uri='http://127.0.0.1:5000/callback', + scope='atproto', + state_store=MemoryStateStore(), + session_store=MemorySessionStore(), +) + +# start authorization flow +auth_url, state = await client.start_authorization('user.bsky.social') + +# user authorizes in browser, then... + +# handle callback +session = await client.handle_callback( + code=authorization_code, + state=state, + iss=issuer_from_callback, +) + +# make authenticated requests +response = await client.make_authenticated_request( + session=session, + method='GET', + url=f'{session.pds_url}/xrpc/com.atproto.repo.describeRepo?repo={session.did}', +) +``` + +### flask example + +see [examples/oauth_flask_demo](../../examples/oauth_flask_demo/) for complete working example. + +## installation + +```bash +uv add atproto # includes atproto_oauth +``` + +## core components + +### OAuthClient + +main client for OAuth operations: + +```python +client = OAuthClient( + client_id='https://yourapp.com/client-metadata.json', + redirect_uri='https://yourapp.com/callback', + scope='atproto repo:app.bsky.feed.post', + state_store=your_state_store, + session_store=your_session_store, + client_secret_key=your_secret_key, # optional, for confidential clients +) +``` + +**methods:** +- `start_authorization(handle_or_did)` - begin OAuth flow +- `handle_callback(code, state, iss)` - complete OAuth flow +- `refresh_session(session)` - refresh tokens +- `revoke_session(session)` - revoke tokens +- `make_authenticated_request(session, method, url)` - make DPoP-authenticated requests + +### stores + +pluggable storage for OAuth state and sessions: + +```python +from atproto_oauth.stores import MemoryStateStore, MemorySessionStore + +# memory stores (development only) +state_store = MemoryStateStore(state_ttl_seconds=600) +session_store = MemorySessionStore() +``` + +**custom stores:** + +implement `StateStore` and `SessionStore` interfaces: + +```python +from atproto_oauth.stores.base import StateStore, SessionStore + +class MyDatabaseStateStore(StateStore): + async def save_state(self, state: OAuthState) -> None: + # save to database + pass + + async def get_state(self, state_key: str) -> Optional[OAuthState]: + # retrieve from database + pass + + async def delete_state(self, state_key: str) -> None: + # delete from database + pass +``` + +### models + +```python +from atproto_oauth.models import OAuthSession, OAuthState, AuthServerMetadata + +# OAuth session with tokens +session: OAuthSession + +# access user info +print(session.did, session.handle, session.pds_url) +print(session.access_token, session.refresh_token) +``` + +### utilities + +```python +from atproto_oauth.pkce import PKCEManager +from atproto_oauth.dpop import DPoPManager +from atproto_oauth.metadata import fetch_authserver_metadata_async + +# generate PKCE pair +verifier, challenge = PKCEManager.generate_pair() + +# generate DPoP keypair +dpop_key = DPoPManager.generate_keypair() + +# discover authorization server +metadata = await fetch_authserver_metadata_async('https://bsky.social') +``` + +## architecture + +### authorization flow + +``` +1. user enters handle (e.g., user.bsky.social) +2. resolve handle → DID → PDS endpoint → auth server +3. generate PKCE verifier/challenge +4. generate DPoP keypair +5. send PAR (Pushed Authorization Request) +6. redirect user to authorization server +7. user authorizes +8. callback with authorization code +9. exchange code for tokens (with PKCE + DPoP) +10. store OAuth session +``` + +### security features + +**PKCE (Proof Key for Code Exchange)** +- prevents authorization code interception +- S256 challenge method required + +**DPoP (Demonstrating Proof-of-Possession)** +- binds tokens to client key +- prevents token theft/replay +- automatic nonce rotation + +**state parameter** +- CSRF protection +- cryptographically secure tokens +- single-use, time-limited + +**URL validation** +- SSRF protection +- private IP blocking +- HTTPS enforcement (except localhost) + +## comparison with nick's PR #589 + +### what nick implemented ✅ +- DPoP support in Session class +- static token handling +- basic DPoP JWT generation + +### what this package adds 🆕 +- complete OAuth authorization flow +- PKCE implementation +- PAR (Pushed Authorization Requests) +- state management with stores +- token refresh with OAuth +- client metadata support +- authorization server discovery +- comprehensive error handling +- production-ready security + +## production deployment + +### 1. use persistent stores + +```python +from your_app.stores import RedisStateStore, PostgreSQLSessionStore + +client = OAuthClient( + state_store=RedisStateStore(), + session_store=PostgreSQLSessionStore(), + ... +) +``` + +### 2. deploy with HTTPS + +```python +client = OAuthClient( + client_id='https://yourapp.com/oauth-client-metadata.json', + redirect_uri='https://yourapp.com/callback', + ... +) +``` + +### 3. create client metadata + +serve at `https://yourapp.com/oauth-client-metadata.json`: + +```json +{ + "client_id": "https://yourapp.com/oauth-client-metadata.json", + "dpop_bound_access_tokens": true, + "application_type": "web", + "redirect_uris": ["https://yourapp.com/callback"], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "scope": "atproto", + "token_endpoint_auth_method": "private_key_jwt", + "token_endpoint_auth_signing_alg": "ES256", + "jwks_uri": "https://yourapp.com/oauth/jwks.json", + "client_name": "Your App Name", + "client_uri": "https://yourapp.com" +} +``` + +### 4. generate client secret (confidential clients) + +```python +from atproto_oauth.dpop import DPoPManager + +# generate once, store securely (e.g., in secrets manager) +client_secret_key = DPoPManager.generate_keypair() + +# use in OAuth client +client = OAuthClient( + client_secret_key=client_secret_key, + ... +) +``` + +### 5. implement token refresh + +```python +# check if token is expired +if session.expires_at and datetime.now(timezone.utc) > session.expires_at: + session = await client.refresh_session(session) +``` + +### 6. handle errors gracefully + +```python +from atproto_oauth.exceptions import OAuthError, OAuthStateError, OAuthTokenError + +try: + session = await client.handle_callback(code, state, iss) +except OAuthStateError as e: + # invalid/expired state - restart flow + pass +except OAuthTokenError as e: + # token exchange failed - show error + pass +except OAuthError as e: + # general OAuth error + pass +``` + +## testing + +run unit tests: + +```bash +uv run pytest tests/test_oauth_pkce.py tests/test_oauth_dpop.py -v +``` + +all 12 tests pass: +- ✅ PKCE verifier/challenge generation +- ✅ DPoP keypair generation +- ✅ DPoP proof JWT creation +- ✅ DPoP nonce error detection +- ✅ access token hash (ath) claim + +## troubleshooting + +### "Invalid or expired state parameter" +**cause:** state expired (default 10 min) or already used +**solution:** restart authorization flow + +### "DID mismatch in token" +**cause:** user identity changed during authorization +**solution:** retry with fresh authorization + +### "Unsupported authorization server" +**cause:** auth server doesn't meet ATProto requirements +**solution:** check server metadata compliance + +### "DPoP nonce error" +**cause:** server requires fresh nonce (handled automatically) +**solution:** library retries automatically + +### import errors +**cause:** missing dependencies +**solution:** `uv sync` to install all dependencies + +## development + +### run tests + +```bash +uv run pytest tests/test_oauth*.py -v +``` + +### run flask demo + +```bash +uv run python examples/oauth_flask_demo/app.py +``` + +### type checking + +```bash +uv run mypy packages/atproto_oauth +``` + +## references + +- [ATProto OAuth Specification](https://atproto.com/specs/oauth) +- [OAuth 2.1](https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1) +- [RFC 9449 - DPoP](https://datatracker.ietf.org/doc/html/rfc9449) +- [RFC 7636 - PKCE](https://datatracker.ietf.org/doc/html/rfc7636) +- [RFC 9126 - PAR](https://datatracker.ietf.org/doc/html/rfc9126) +- [Bluesky OAuth Cookbook](https://github.com/bluesky-social/cookbook/tree/main/python-oauth-web-app) + +## license + +MIT + +## contributing + +contributions welcome! areas for improvement: + +- [ ] SQLite/PostgreSQL session stores +- [ ] Redis state store +- [ ] token automatic refresh +- [ ] scope management helpers +- [ ] more comprehensive tests +- [ ] integration tests with real PDS +- [ ] async/sync sync wrappers +- [ ] documentation improvements + +see main [atproto repository](https://github.com/MarshalX/atproto) for contribution guidelines. diff --git a/packages/atproto_oauth/__init__.py b/packages/atproto_oauth/__init__.py new file mode 100644 index 00000000..feb51b74 --- /dev/null +++ b/packages/atproto_oauth/__init__.py @@ -0,0 +1,21 @@ +"""ATProto OAuth 2.1 implementation.""" + +from atproto_oauth.client import OAuthClient, PromptType +from atproto_oauth.exceptions import ( + OAuthError, + OAuthStateError, + OAuthTokenError, + UnsupportedAuthServerError, +) +from atproto_oauth.models import OAuthSession, OAuthState + +__all__ = [ + 'OAuthClient', + 'OAuthError', + 'OAuthSession', + 'OAuthState', + 'OAuthStateError', + 'OAuthTokenError', + 'PromptType', + 'UnsupportedAuthServerError', +] diff --git a/packages/atproto_oauth/client.py b/packages/atproto_oauth/client.py new file mode 100644 index 00000000..c7f543e0 --- /dev/null +++ b/packages/atproto_oauth/client.py @@ -0,0 +1,634 @@ +"""ATProto OAuth 2.1 client implementation.""" + +import secrets +import time +import typing as t +from urllib.parse import urlencode + +import httpx +from atproto_identity.resolver import AsyncIdResolver + +from atproto_oauth.dpop import DPoPManager +from atproto_oauth.exceptions import OAuthStateError, OAuthTokenError +from atproto_oauth.metadata import ( + discover_authserver_from_pds_async, + fetch_authserver_metadata_async, +) +from atproto_oauth.models import AuthServerMetadata, OAuthSession, OAuthState, TokenResponse +from atproto_oauth.pkce import PKCEManager +from atproto_oauth.security import is_safe_url +from atproto_oauth.stores.base import SessionStore, StateStore + +if t.TYPE_CHECKING: + from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey + +#: Valid values for the OAuth prompt parameter. +#: - 'login': Force re-authentication, ignoring any remembered session. +#: - 'select_account': Show account selection instead of auto-selecting. +#: - 'consent': Force consent screen even if previously approved. +#: - 'none': Silent authentication (fails if user interaction required). +PromptType = t.Literal['login', 'select_account', 'consent', 'none'] + + +def _scopes_are_equivalent(requested: str, granted: str) -> bool: + """Check if granted scopes satisfy requested scopes. + + Handles permission set expansion where PDS expands `include:namespace.permissionSet` + into `repo?collection=...` format. + + Args: + requested: The scope string originally requested (may contain `include:` scopes). + granted: The scope string returned by the PDS (with expanded permissions). + + Returns: + True if granted scopes satisfy all requested permissions. + """ + # fast path: exact match + if requested == granted: + return True + + requested_parts = set(requested.split()) + granted_parts = set(granted.split()) + + # remove 'atproto' prefix from both + requested_parts.discard('atproto') + granted_parts.discard('atproto') + + # separate include: scopes from repo: scopes in requested + requested_includes = {p for p in requested_parts if p.startswith('include:')} + requested_repos = {p for p in requested_parts if p.startswith('repo:')} + + # separate repo?collection= scopes from repo: scopes in granted + granted_expanded = set() # repo?collection= format + granted_repos = set() # repo: format + for p in granted_parts: + if p.startswith('repo?'): + granted_expanded.add(p) + elif p.startswith('repo:'): + granted_repos.add(p) + + # check that all requested repo: scopes are granted + if not requested_repos.issubset(granted_repos): + return False + + # if there are no include: scopes, we're done + if not requested_includes: + return True + + # for include: scopes, verify that we got expanded permissions back + # the PDS expands `include:namespace.permSet` into `repo?collection=ns.col1&collection=ns.col2` + # we just need to verify we got *some* expanded permissions for each namespace + for include_scope in requested_includes: + # extract namespace from include:namespace.permissionSet + # e.g., include:fm.plyr.authFullApp -> fm.plyr + nsid = include_scope.removeprefix('include:') + parts = nsid.split('.') + if len(parts) < 3: + # malformed include scope + return False + # namespace is everything except the last part (the permission set name) + namespace = '.'.join(parts[:-1]) + + # check if any expanded scope references this namespace + found = False + for expanded in granted_expanded: + # repo?collection=fm.plyr.track&collection=fm.plyr.like + if f'collection={namespace}.' in expanded: + found = True + break + if not found: + return False + + return True + + +class OAuthClient: + """ATProto OAuth 2.1 client. + + Implements complete OAuth authorization code flow with PKCE and DPoP. + + Args: + client_id: OAuth client ID (must be HTTPS URL or localhost). + redirect_uri: OAuth redirect URI. + scope: OAuth scopes (space-separated). + state_store: Store for temporary OAuth state. + session_store: Store for OAuth sessions. + client_secret_key: Optional EC private key for confidential clients. + client_secret_kid: Key ID for the client secret key (required if client_secret_key is provided). + """ + + def __init__( + self, + client_id: str, + redirect_uri: str, + scope: str, + state_store: StateStore, + session_store: SessionStore, + client_secret_key: t.Optional['EllipticCurvePrivateKey'] = None, + client_secret_kid: t.Optional[str] = None, + ) -> None: + self.client_id = client_id + self.redirect_uri = redirect_uri + self.scope = scope + self.state_store = state_store + self.session_store = session_store + self.client_secret_key = client_secret_key + self.client_secret_kid = client_secret_kid + + self._id_resolver = AsyncIdResolver() + self._dpop = DPoPManager() + self._pkce = PKCEManager() + + async def start_authorization( + self, + handle_or_did: str, + prompt: t.Optional[PromptType] = None, + ) -> t.Tuple[str, str]: + """Start OAuth authorization flow. + + Args: + handle_or_did: User handle (e.g., 'user.bsky.social') or DID. + prompt: Optional OAuth prompt parameter to control authorization behavior: + - 'login': Force re-authentication, ignoring any remembered session. + - 'select_account': Show account selection instead of auto-selecting. + - 'consent': Force consent screen even if previously approved. + - 'none': Silent authentication (fails if user interaction required). + + Returns: + Tuple of (authorization_url, state) for redirecting user. + + Raises: + ValueError: If handle/DID resolution fails or URL validation fails. + OAuthError: If authorization server discovery or PAR fails. + """ + # 1. Resolve identity + if handle_or_did.startswith('did:'): + # Input is a DID + did = handle_or_did + else: + # Input is a handle - resolve to DID first + resolved_did = await self._id_resolver.handle.resolve(handle_or_did) + if not resolved_did: + raise ValueError(f'Failed to resolve handle: {handle_or_did}') + did = resolved_did + + # 2. Resolve DID to get ATProto data (includes PDS, handle, etc.) + atproto_data = await self._id_resolver.did.resolve_atproto_data(did) + + handle = atproto_data.handle or handle_or_did + pds_url = atproto_data.pds + + if not pds_url: + raise ValueError(f'No PDS endpoint found in DID document for {did}') + + # 3. Discover authorization server + authserver_url = await discover_authserver_from_pds_async(pds_url) + authserver_url = authserver_url.rstrip('/') + + # 4. Fetch authorization server metadata + authserver_meta = await fetch_authserver_metadata_async(authserver_url) + + # 5. Generate PKCE verifier and challenge + pkce_verifier, pkce_challenge = self._pkce.generate_pair() + + # 6. Generate DPoP keypair + dpop_key = self._dpop.generate_keypair() + + # 7. Generate state token + state_token = secrets.token_urlsafe(32) + + # 8. Send PAR (Pushed Authorization Request) + request_uri, dpop_nonce = await self._send_par_request( + authserver_meta=authserver_meta, + login_hint=handle_or_did, + pkce_challenge=pkce_challenge, + dpop_key=dpop_key, + state=state_token, + prompt=prompt, + ) + + # 9. Store state + oauth_state = OAuthState( + state=state_token, + pkce_verifier=pkce_verifier, + redirect_uri=self.redirect_uri, + scope=self.scope, + authserver_iss=authserver_meta.issuer, + dpop_private_key=dpop_key, + dpop_authserver_nonce=dpop_nonce, + did=did, + handle=handle, + pds_url=pds_url, + ) + await self.state_store.save_state(oauth_state) + + # 10. Build authorization URL + auth_params = { + 'client_id': self.client_id, + 'request_uri': request_uri, + } + auth_url = f'{authserver_meta.authorization_endpoint}?{urlencode(auth_params)}' + + if not is_safe_url(auth_url): + raise ValueError(f'Generated unsafe authorization URL: {auth_url}') + + return auth_url, state_token + + async def handle_callback( + self, + code: str, + state: str, + iss: str, + ) -> OAuthSession: + """Handle OAuth callback and complete authorization. + + Args: + code: Authorization code from callback. + state: State parameter from callback. + iss: Issuer parameter from callback. + + Returns: + OAuth session with tokens. + + Raises: + OAuthStateError: If state validation fails. + OAuthTokenError: If token exchange fails. + """ + # 1. Retrieve and verify state + oauth_state = await self.state_store.get_state(state) + if not oauth_state: + raise OAuthStateError('Invalid or expired state parameter') + + if oauth_state.authserver_iss != iss: + raise OAuthStateError(f'Issuer mismatch: expected {oauth_state.authserver_iss}, got {iss}') + + # Delete state (one-time use) + await self.state_store.delete_state(state) + + # 2. Exchange code for tokens + token_response, dpop_nonce = await self._exchange_code_for_tokens( + code=code, + oauth_state=oauth_state, + ) + + # 3. Verify token response + if token_response.sub != oauth_state.did: + raise OAuthTokenError(f'DID mismatch in token: expected {oauth_state.did}, got {token_response.sub}') + + if not _scopes_are_equivalent(self.scope, token_response.scope): + raise OAuthTokenError(f'Scope mismatch: expected {self.scope}, got {token_response.scope}') + + # 4. Create and store session + session = OAuthSession( + did=oauth_state.did or token_response.sub, + handle=oauth_state.handle or '', + pds_url=oauth_state.pds_url or '', + authserver_iss=oauth_state.authserver_iss, + access_token=token_response.access_token, + refresh_token=token_response.refresh_token or '', + dpop_private_key=oauth_state.dpop_private_key, + dpop_authserver_nonce=dpop_nonce, + scope=token_response.scope, + ) + + await self.session_store.save_session(session) + + return session + + async def refresh_session(self, session: OAuthSession) -> OAuthSession: + """Refresh OAuth session tokens. + + Args: + session: Current OAuth session. + + Returns: + Updated OAuth session with new tokens. + + Raises: + OAuthTokenError: If token refresh fails. + """ + # Fetch current auth server metadata + authserver_meta = await fetch_authserver_metadata_async(session.authserver_iss) + + # Prepare refresh token request + params = { + 'grant_type': 'refresh_token', + 'refresh_token': session.refresh_token, + } + + # Make token request with DPoP + dpop_nonce, response = await self._make_token_request( + token_url=authserver_meta.token_endpoint, + params=params, + dpop_key=session.dpop_private_key, + dpop_nonce=session.dpop_authserver_nonce, + issuer=authserver_meta.issuer, + ) + + if response.status_code not in (200, 201): + raise OAuthTokenError(f'Token refresh failed: {response.status_code} {response.text}') + + token_data = response.json() + token_response = TokenResponse( + access_token=token_data['access_token'], + token_type=token_data['token_type'], + scope=token_data['scope'], + sub=token_data['sub'], + refresh_token=token_data.get('refresh_token', session.refresh_token), + expires_in=token_data.get('expires_in'), + ) + + # Update session + session.access_token = token_response.access_token + session.refresh_token = token_response.refresh_token + session.dpop_authserver_nonce = dpop_nonce + + await self.session_store.save_session(session) + + return session + + async def revoke_session(self, session: OAuthSession) -> None: + """Revoke OAuth session tokens. + + Args: + session: OAuth session to revoke. + """ + authserver_meta = await fetch_authserver_metadata_async(session.authserver_iss) + + if not authserver_meta.revocation_endpoint: + # Revocation not supported, just delete local session + await self.session_store.delete_session(session.did) + return + + # Revoke both access and refresh tokens + for token_type in ['access_token', 'refresh_token']: + token = session.access_token if token_type == 'access_token' else session.refresh_token + if not token: + continue + + params = { + 'token': token, + 'token_type_hint': token_type, + } + + try: + await self._make_token_request( + token_url=authserver_meta.revocation_endpoint, + params=params, + dpop_key=session.dpop_private_key, + dpop_nonce=session.dpop_authserver_nonce, + issuer=authserver_meta.issuer, + ) + except (OAuthTokenError, ValueError): + # Best-effort revocation; failures are intentionally silent + pass + + # Delete local session + await self.session_store.delete_session(session.did) + + async def make_authenticated_request( + self, + session: OAuthSession, + method: str, + url: str, + **kwargs: t.Any, + ) -> httpx.Response: + """Make authenticated request to PDS with DPoP. + + Args: + session: OAuth session. + method: HTTP method. + url: Request URL. + **kwargs: Additional request arguments. + + Returns: + HTTP response. + """ + if not is_safe_url(url): + raise ValueError(f'Unsafe URL: {url}') + + # Try request with retry for DPoP nonce + for attempt in range(2): + # Create DPoP proof + dpop_proof = self._dpop.create_proof( + method=method.upper(), + url=url, + private_key=session.dpop_private_key, + nonce=session.dpop_pds_nonce, + access_token=session.access_token, + ) + + # Add auth headers + headers = kwargs.pop('headers', {}) + headers['Authorization'] = f'DPoP {session.access_token}' + headers['DPoP'] = dpop_proof + + # Make request + async with httpx.AsyncClient() as client: + response = await client.request(method, url, headers=headers, **kwargs) + + # Check for DPoP nonce error + if self._dpop.is_dpop_nonce_error(response): + new_nonce = self._dpop.extract_nonce_from_response(response) + if new_nonce and attempt == 0: + session.dpop_pds_nonce = new_nonce + await self.session_store.save_session(session) + continue # Retry with new nonce + + return response + + return response + + async def _send_par_request( + self, + authserver_meta: AuthServerMetadata, + login_hint: str, + pkce_challenge: str, + dpop_key: 'EllipticCurvePrivateKey', + state: str, + prompt: t.Optional[str] = None, + ) -> t.Tuple[str, str]: + """Send Pushed Authorization Request. + + Args: + authserver_meta: Authorization server metadata. + login_hint: User handle or DID hint. + pkce_challenge: PKCE challenge string. + dpop_key: DPoP private key. + state: State token for CSRF protection. + prompt: Optional prompt parameter for authorization behavior. + + Returns: + Tuple of (request_uri, dpop_nonce). + """ + par_url = authserver_meta.pushed_authorization_request_endpoint + + params = { + 'response_type': 'code', + 'code_challenge': pkce_challenge, + 'code_challenge_method': 'S256', + 'state': state, + 'redirect_uri': self.redirect_uri, + 'scope': self.scope, + 'login_hint': login_hint, + } + + if prompt: + params['prompt'] = prompt + + # Make PAR request with DPoP + dpop_nonce, response = await self._make_token_request( + token_url=par_url, + params=params, + dpop_key=dpop_key, + dpop_nonce='', # Initial request has no nonce + issuer=authserver_meta.issuer, + ) + + if response.status_code not in (200, 201): + raise OAuthTokenError(f'PAR request failed: {response.status_code} {response.text}') + + par_response = response.json() + return par_response['request_uri'], dpop_nonce + + async def _exchange_code_for_tokens( + self, + code: str, + oauth_state: OAuthState, + ) -> t.Tuple[TokenResponse, str]: + """Exchange authorization code for tokens. + + Returns: + Tuple of (token_response, dpop_nonce). + """ + # Fetch metadata again (could have changed) + authserver_meta = await fetch_authserver_metadata_async(oauth_state.authserver_iss) + + params = { + 'grant_type': 'authorization_code', + 'code': code, + 'code_verifier': oauth_state.pkce_verifier, + 'redirect_uri': self.redirect_uri, + } + + dpop_nonce, response = await self._make_token_request( + token_url=authserver_meta.token_endpoint, + params=params, + dpop_key=oauth_state.dpop_private_key, + dpop_nonce=oauth_state.dpop_authserver_nonce, + issuer=authserver_meta.issuer, + ) + + if response.status_code not in (200, 201): + raise OAuthTokenError(f'Token exchange failed: {response.status_code} {response.text}') + + token_data = response.json() + token_response = TokenResponse( + access_token=token_data['access_token'], + token_type=token_data['token_type'], + scope=token_data['scope'], + sub=token_data['sub'], + refresh_token=token_data.get('refresh_token'), + expires_in=token_data.get('expires_in'), + ) + + return token_response, dpop_nonce + + async def _make_token_request( + self, + token_url: str, + params: t.Dict[str, str], + dpop_key: 'EllipticCurvePrivateKey', + dpop_nonce: str, + issuer: t.Optional[str] = None, + ) -> t.Tuple[str, httpx.Response]: + """Make token request with DPoP and client assertion. + + Handles DPoP nonce rotation automatically. + + Args: + token_url: The token endpoint URL. + params: Request parameters. + dpop_key: DPoP private key. + dpop_nonce: Current DPoP nonce. + issuer: Authorization server issuer (required for confidential clients). + Per ATProto OAuth spec, the aud claim must be the issuer. + + Returns: + Tuple of (updated_dpop_nonce, response). + """ + if not is_safe_url(token_url): + raise ValueError(f'Unsafe token URL: {token_url}') + + # Add client authentication + if self.client_secret_key: + # Confidential client - use client assertion + # Per ATProto OAuth spec: "The aud claim (audience) of the client + # assertion JWT must be the Authorization Server's issuer." + if not issuer: + raise ValueError('issuer required for confidential client authentication') + client_assertion = self._create_client_assertion(issuer) + params['client_id'] = self.client_id + params['client_assertion_type'] = 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer' + params['client_assertion'] = client_assertion + else: + # Public client + params['client_id'] = self.client_id + + # Try request with DPoP nonce retry + current_nonce = dpop_nonce + for attempt in range(2): + # Create DPoP proof + dpop_proof = self._dpop.create_proof( + method='POST', + url=token_url, + private_key=dpop_key, + nonce=current_nonce if current_nonce else None, + ) + + # Make request + async with httpx.AsyncClient() as client: + response = await client.post( + token_url, + data=params, + headers={'DPoP': dpop_proof}, + ) + + # Check for DPoP nonce error + if self._dpop.is_dpop_nonce_error(response): + new_nonce = self._dpop.extract_nonce_from_response(response) + if new_nonce and attempt == 0: + current_nonce = new_nonce + continue # Retry with new nonce + + # Extract final nonce + final_nonce = self._dpop.extract_nonce_from_response(response) or current_nonce + + return final_nonce, response + + return current_nonce, response + + def _create_client_assertion(self, audience: str) -> str: + """Create client assertion JWT for confidential client.""" + if not self.client_secret_key: + raise ValueError('Client secret key required for client assertion') + if not self.client_secret_kid: + raise ValueError('Client secret kid required for client assertion') + + header = { + 'alg': 'ES256', + 'typ': 'JWT', + 'kid': self.client_secret_kid, + } + + now = int(time.time()) + payload = { + 'iss': self.client_id, + 'sub': self.client_id, + 'aud': audience, + 'jti': secrets.token_urlsafe(16), + 'iat': now, + 'exp': now + 60, # Valid for 60 seconds + } + + return self._dpop._sign_jwt(header, payload, self.client_secret_key) diff --git a/packages/atproto_oauth/dpop.py b/packages/atproto_oauth/dpop.py new file mode 100644 index 00000000..319f8dd6 --- /dev/null +++ b/packages/atproto_oauth/dpop.py @@ -0,0 +1,213 @@ +"""DPoP (Demonstrating Proof-of-Possession) implementation.""" + +import hashlib +import json +import secrets +import time +import typing as t +from base64 import urlsafe_b64encode +from urllib.parse import urlparse + +import httpx +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature + +if t.TYPE_CHECKING: + from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey + + +class DPoPManager: + """Manages DPoP proof generation for OAuth.""" + + @staticmethod + def generate_keypair() -> 'EllipticCurvePrivateKey': + """Generate ES256 keypair for DPoP. + + Returns: + EC private key (P-256 curve). + """ + return ec.generate_private_key(ec.SECP256R1()) + + @staticmethod + def _key_to_jwk(private_key: 'EllipticCurvePrivateKey', include_private: bool = False) -> t.Dict[str, t.Any]: + """Convert EC private key to JWK format. + + Args: + private_key: The EC private key. + include_private: Whether to include private key components. + + Returns: + JWK dictionary. + """ + public_key = private_key.public_key() + public_numbers = public_key.public_numbers() + + # Convert to bytes and base64url encode + def int_to_base64url(n: int, length: int) -> str: + byte_len = (length + 7) // 8 + return urlsafe_b64encode(n.to_bytes(byte_len, 'big')).decode('utf-8').rstrip('=') + + jwk = { + 'kty': 'EC', + 'crv': 'P-256', + 'x': int_to_base64url(public_numbers.x, 256), + 'y': int_to_base64url(public_numbers.y, 256), + } + + if include_private: + private_numbers = private_key.private_numbers() + jwk['d'] = int_to_base64url(private_numbers.private_value, 256) + + return jwk + + @staticmethod + def _sign_jwt( + header: t.Dict[str, t.Any], payload: t.Dict[str, t.Any], private_key: 'EllipticCurvePrivateKey' + ) -> str: + """Sign a JWT using ES256. + + Args: + header: JWT header. + payload: JWT payload. + private_key: EC private key for signing. + + Returns: + Complete JWT string. + """ + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives.asymmetric import ec + + # Encode header and payload + header_b64 = urlsafe_b64encode(json.dumps(header, separators=(',', ':')).encode()).decode().rstrip('=') + payload_b64 = urlsafe_b64encode(json.dumps(payload, separators=(',', ':')).encode()).decode().rstrip('=') + + # Create signing input + signing_input = f'{header_b64}.{payload_b64}'.encode() + + # Sign (returns DER-encoded signature) + der_signature = private_key.sign(signing_input, ec.ECDSA(hashes.SHA256())) + + # Convert DER signature to IEEE P1363 format (raw r|s concatenated) + # ES256 uses P-256 curve, so r and s are each 32 bytes + r, s = decode_dss_signature(der_signature) + + # Convert r and s to 32-byte big-endian sequences + r_bytes = r.to_bytes(32, 'big') + s_bytes = s.to_bytes(32, 'big') + + # Concatenate and encode + raw_signature = r_bytes + s_bytes + signature_b64 = urlsafe_b64encode(raw_signature).decode().rstrip('=') + + return f'{header_b64}.{payload_b64}.{signature_b64}' + + @classmethod + def create_proof( + cls, + method: str, + url: str, + private_key: 'EllipticCurvePrivateKey', + nonce: t.Optional[str] = None, + access_token: t.Optional[str] = None, + ) -> str: + """Generate DPoP proof JWT. + + Args: + method: HTTP method (e.g., 'GET', 'POST'). + url: Full URL of the request. + private_key: EC private key for signing. + nonce: Optional server-provided nonce. + access_token: Optional access token (for 'ath' claim). + + Returns: + DPoP proof JWT string. + """ + # Get public key JWK + public_jwk = cls._key_to_jwk(private_key, include_private=False) + + # Create header + header = { + 'typ': 'dpop+jwt', + 'alg': 'ES256', + 'jwk': public_jwk, + } + + # Strip query and fragment from URL per RFC 9449 + parsed_url = urlparse(url) + htu = f'{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path}' + + # Create payload + now = int(time.time()) + payload = { + 'jti': secrets.token_urlsafe(16), + 'htm': method.upper(), + 'htu': htu, + 'iat': now, + 'exp': now + 60, # Valid for 60 seconds + } + + # Add optional claims + if nonce: + payload['nonce'] = nonce + + if access_token: + # Hash access token for 'ath' claim (same as PKCE S256) + ath_hash = hashlib.sha256(access_token.encode('utf-8')).digest() + payload['ath'] = urlsafe_b64encode(ath_hash).decode('utf-8').rstrip('=') + + return cls._sign_jwt(header, payload, private_key) + + @staticmethod + def extract_nonce_from_response(response: httpx.Response) -> t.Optional[str]: + """Extract DPoP nonce from HTTP response. + + Checks both the 'DPoP-Nonce' header and error responses. + + Args: + response: HTTP response object. + + Returns: + DPoP nonce string if present, None otherwise. + """ + # Check DPoP-Nonce header + if nonce := response.headers.get('DPoP-Nonce'): + return nonce + + # Check for error response with use_dpop_nonce + if response.status_code in (400, 401): + try: + error_body = response.json() + if isinstance(error_body, dict) and error_body.get('error') == 'use_dpop_nonce': + return response.headers.get('DPoP-Nonce') + except Exception: + pass + + return None + + @staticmethod + def is_dpop_nonce_error(response: httpx.Response) -> bool: + """Check if response indicates DPoP nonce error. + + Args: + response: HTTP response object. + + Returns: + True if response indicates need for new DPoP nonce. + """ + if response.status_code not in (400, 401): + return False + + # Check WWW-Authenticate header + if www_auth := response.headers.get('WWW-Authenticate', ''): + if 'use_dpop_nonce' in www_auth.lower(): + return True + + # Check JSON error response + try: + error_body = response.json() + if isinstance(error_body, dict) and error_body.get('error') == 'use_dpop_nonce': + return True + except Exception: + pass + + return False diff --git a/packages/atproto_oauth/exceptions.py b/packages/atproto_oauth/exceptions.py new file mode 100644 index 00000000..bf510ca4 --- /dev/null +++ b/packages/atproto_oauth/exceptions.py @@ -0,0 +1,16 @@ +"""OAuth-specific exceptions.""" + +from atproto_core.exceptions import AtProtocolError + + +class OAuthError(AtProtocolError): + """Base exception for OAuth errors.""" + + +class OAuthStateError(OAuthError): ... + + +class OAuthTokenError(OAuthError): ... + + +class UnsupportedAuthServerError(OAuthError): ... diff --git a/packages/atproto_oauth/metadata.py b/packages/atproto_oauth/metadata.py new file mode 100644 index 00000000..3382a879 --- /dev/null +++ b/packages/atproto_oauth/metadata.py @@ -0,0 +1,184 @@ +"""Authorization server metadata discovery.""" + +import httpx + +from atproto_oauth.exceptions import UnsupportedAuthServerError +from atproto_oauth.models import AuthServerMetadata +from atproto_oauth.security import is_safe_url, validate_authserver_metadata + + +async def discover_authserver_from_pds_async(pds_url: str, timeout: float = 5.0) -> str: + """Discover authorization server URL from PDS. + + Args: + pds_url: PDS endpoint URL. + timeout: Request timeout in seconds. + + Returns: + Authorization server URL. + + Raises: + ValueError: If PDS URL is unsafe or response is invalid. + httpx.HTTPError: If request fails. + """ + if not is_safe_url(pds_url): + raise ValueError(f'Unsafe PDS URL: {pds_url}') + + async with httpx.AsyncClient(timeout=timeout) as client: + response = await client.get(f'{pds_url}/.well-known/oauth-protected-resource') + response.raise_for_status() + + if response.status_code != 200: + raise ValueError(f'PDS returned non-200 status: {response.status_code}') + + data = response.json() + if not isinstance(data, dict) or 'authorization_servers' not in data: + raise ValueError('Invalid oauth-protected-resource response') + + auth_servers = data['authorization_servers'] + if not auth_servers or not isinstance(auth_servers, list): + raise ValueError('No authorization servers found') + + return auth_servers[0] + + +def discover_authserver_from_pds(pds_url: str, timeout: float = 5.0) -> str: + """Discover authorization server URL from PDS (synchronous). + + Args: + pds_url: PDS endpoint URL. + timeout: Request timeout in seconds. + + Returns: + Authorization server URL. + """ + if not is_safe_url(pds_url): + raise ValueError(f'Unsafe PDS URL: {pds_url}') + + with httpx.Client(timeout=timeout) as client: + response = client.get(f'{pds_url}/.well-known/oauth-protected-resource') + response.raise_for_status() + + if response.status_code != 200: + raise ValueError(f'PDS returned non-200 status: {response.status_code}') + + data = response.json() + if not isinstance(data, dict) or 'authorization_servers' not in data: + raise ValueError('Invalid oauth-protected-resource response') + + auth_servers = data['authorization_servers'] + if not auth_servers or not isinstance(auth_servers, list): + raise ValueError('No authorization servers found') + + return auth_servers[0] + + +async def fetch_authserver_metadata_async(authserver_url: str, timeout: float = 5.0) -> AuthServerMetadata: + """Fetch and validate authorization server metadata. + + Args: + authserver_url: Authorization server URL. + timeout: Request timeout in seconds. + + Returns: + Validated metadata object. + + Raises: + ValueError: If URL is unsafe. + UnsupportedAuthServerError: If metadata doesn't meet requirements. + httpx.HTTPError: If request fails. + """ + if not is_safe_url(authserver_url): + raise ValueError(f'Unsafe authorization server URL: {authserver_url}') + + fetch_url = f'{authserver_url}/.well-known/oauth-authorization-server' + + async with httpx.AsyncClient(timeout=timeout) as client: + response = await client.get(fetch_url) + response.raise_for_status() + + metadata_dict = response.json() + + # Validate against ATProto requirements + try: + validate_authserver_metadata(metadata_dict, fetch_url) + except ValueError as e: + raise UnsupportedAuthServerError(str(e)) from e + + # Parse into model + return AuthServerMetadata( + issuer=metadata_dict['issuer'], + authorization_endpoint=metadata_dict['authorization_endpoint'], + token_endpoint=metadata_dict['token_endpoint'], + pushed_authorization_request_endpoint=metadata_dict['pushed_authorization_request_endpoint'], + response_types_supported=metadata_dict['response_types_supported'], + grant_types_supported=metadata_dict['grant_types_supported'], + code_challenge_methods_supported=metadata_dict['code_challenge_methods_supported'], + token_endpoint_auth_methods_supported=metadata_dict['token_endpoint_auth_methods_supported'], + token_endpoint_auth_signing_alg_values_supported=metadata_dict[ + 'token_endpoint_auth_signing_alg_values_supported' + ], + scopes_supported=metadata_dict['scopes_supported'], + dpop_signing_alg_values_supported=metadata_dict['dpop_signing_alg_values_supported'], + authorization_response_iss_parameter_supported=metadata_dict[ + 'authorization_response_iss_parameter_supported' + ], + require_pushed_authorization_requests=metadata_dict['require_pushed_authorization_requests'], + client_id_metadata_document_supported=metadata_dict['client_id_metadata_document_supported'], + revocation_endpoint=metadata_dict.get('revocation_endpoint'), + jwks_uri=metadata_dict.get('jwks_uri'), + require_request_uri_registration=metadata_dict.get('require_request_uri_registration'), + ) + + +def fetch_authserver_metadata(authserver_url: str, timeout: float = 5.0) -> AuthServerMetadata: + """Fetch and validate authorization server metadata (synchronous). + + Args: + authserver_url: Authorization server URL. + timeout: Request timeout in seconds. + + Returns: + Validated metadata object. + """ + if not is_safe_url(authserver_url): + raise ValueError(f'Unsafe authorization server URL: {authserver_url}') + + fetch_url = f'{authserver_url}/.well-known/oauth-authorization-server' + + with httpx.Client(timeout=timeout) as client: + response = client.get(fetch_url) + response.raise_for_status() + + metadata_dict = response.json() + + # Validate against ATProto requirements + try: + validate_authserver_metadata(metadata_dict, fetch_url) + except ValueError as e: + raise UnsupportedAuthServerError(str(e)) from e + + # Parse into model + return AuthServerMetadata( + issuer=metadata_dict['issuer'], + authorization_endpoint=metadata_dict['authorization_endpoint'], + token_endpoint=metadata_dict['token_endpoint'], + pushed_authorization_request_endpoint=metadata_dict['pushed_authorization_request_endpoint'], + response_types_supported=metadata_dict['response_types_supported'], + grant_types_supported=metadata_dict['grant_types_supported'], + code_challenge_methods_supported=metadata_dict['code_challenge_methods_supported'], + token_endpoint_auth_methods_supported=metadata_dict['token_endpoint_auth_methods_supported'], + token_endpoint_auth_signing_alg_values_supported=metadata_dict[ + 'token_endpoint_auth_signing_alg_values_supported' + ], + scopes_supported=metadata_dict['scopes_supported'], + dpop_signing_alg_values_supported=metadata_dict['dpop_signing_alg_values_supported'], + authorization_response_iss_parameter_supported=metadata_dict[ + 'authorization_response_iss_parameter_supported' + ], + require_pushed_authorization_requests=metadata_dict['require_pushed_authorization_requests'], + client_id_metadata_document_supported=metadata_dict['client_id_metadata_document_supported'], + revocation_endpoint=metadata_dict.get('revocation_endpoint'), + jwks_uri=metadata_dict.get('jwks_uri'), + require_request_uri_registration=metadata_dict.get('require_request_uri_registration'), + ) diff --git a/packages/atproto_oauth/models.py b/packages/atproto_oauth/models.py new file mode 100644 index 00000000..9e245e4e --- /dev/null +++ b/packages/atproto_oauth/models.py @@ -0,0 +1,78 @@ +"""OAuth data models.""" + +import typing as t +from dataclasses import dataclass, field +from datetime import datetime, timezone + +if t.TYPE_CHECKING: + from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey + + +@dataclass +class AuthServerMetadata: + """Authorization Server metadata from discovery.""" + + issuer: str + authorization_endpoint: str + token_endpoint: str + pushed_authorization_request_endpoint: str + response_types_supported: t.List[str] + grant_types_supported: t.List[str] + code_challenge_methods_supported: t.List[str] + token_endpoint_auth_methods_supported: t.List[str] + token_endpoint_auth_signing_alg_values_supported: t.List[str] + scopes_supported: t.List[str] + dpop_signing_alg_values_supported: t.List[str] + authorization_response_iss_parameter_supported: bool + require_pushed_authorization_requests: bool + client_id_metadata_document_supported: bool + revocation_endpoint: t.Optional[str] = None + jwks_uri: t.Optional[str] = None + require_request_uri_registration: t.Optional[bool] = None + + +@dataclass +class OAuthState: + """OAuth state for CSRF protection during authorization flow.""" + + state: str + pkce_verifier: str + redirect_uri: str + scope: str + authserver_iss: str + dpop_private_key: 'EllipticCurvePrivateKey' + dpop_authserver_nonce: str + did: t.Optional[str] = None + handle: t.Optional[str] = None + pds_url: t.Optional[str] = None + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + +@dataclass +class OAuthSession: + """OAuth session with tokens and metadata.""" + + did: str + handle: str + pds_url: str + authserver_iss: str + access_token: str + refresh_token: str + dpop_private_key: 'EllipticCurvePrivateKey' + dpop_authserver_nonce: str + scope: str + dpop_pds_nonce: t.Optional[str] = None + expires_at: t.Optional[datetime] = None + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + +@dataclass +class TokenResponse: + """Token response from authorization server.""" + + access_token: str + token_type: str + scope: str + sub: str # DID + refresh_token: t.Optional[str] = None + expires_in: t.Optional[int] = None diff --git a/packages/atproto_oauth/pkce.py b/packages/atproto_oauth/pkce.py new file mode 100644 index 00000000..a2876051 --- /dev/null +++ b/packages/atproto_oauth/pkce.py @@ -0,0 +1,57 @@ +"""PKCE (Proof Key for Code Exchange) implementation.""" + +import base64 +import hashlib +import secrets +import typing as t + + +class PKCEManager: + """Manages PKCE code verifier and challenge generation.""" + + @staticmethod + def generate_verifier(length: int = 128) -> str: + """Generate a PKCE code verifier. + + Args: + length: Length of the verifier (43-128 characters). + + Returns: + Base64url-encoded verifier string. + + Raises: + ValueError: If length is not between 43 and 128. + """ + if not 43 <= length <= 128: + raise ValueError('PKCE verifier length must be between 43 and 128') + + # Generate random bytes and encode as base64url + verifier_bytes = secrets.token_bytes(length) + return base64.urlsafe_b64encode(verifier_bytes).decode('utf-8').rstrip('=')[:length] + + @staticmethod + def generate_challenge(verifier: str) -> str: + """Generate S256 PKCE code challenge from verifier. + + Args: + verifier: The code verifier string. + + Returns: + Base64url-encoded SHA256 hash of the verifier. + """ + digest = hashlib.sha256(verifier.encode('utf-8')).digest() + return base64.urlsafe_b64encode(digest).decode('utf-8').rstrip('=') + + @classmethod + def generate_pair(cls, length: int = 128) -> t.Tuple[str, str]: + """Generate both verifier and challenge. + + Args: + length: Length of the verifier. + + Returns: + Tuple of (verifier, challenge). + """ + verifier = cls.generate_verifier(length) + challenge = cls.generate_challenge(verifier) + return verifier, challenge diff --git a/packages/atproto_oauth/security.py b/packages/atproto_oauth/security.py new file mode 100644 index 00000000..ec2559df --- /dev/null +++ b/packages/atproto_oauth/security.py @@ -0,0 +1,186 @@ +"""Security utilities for OAuth implementation.""" + +import typing as t +from urllib.parse import urlparse + +import httpx + +# Hardened HTTP client configuration +DEFAULT_TIMEOUT = 5.0 +MAX_REDIRECTS = 3 +ALLOWED_SCHEMES = {'https', 'http'} # http only for localhost +BLOCKED_HOSTS = { + '0.0.0.0', + '127.0.0.1', + 'localhost', + '::1', + '169.254.169.254', # AWS metadata + 'metadata.google.internal', # GCP metadata +} + + +def is_safe_url(url: str, allow_localhost: bool = False) -> bool: + """Validate URL for security (SSRF protection). + + Args: + url: URL to validate. + allow_localhost: Whether to allow localhost URLs. + + Returns: + True if URL is safe to use. + """ + try: + parsed = urlparse(url) + + # Check scheme + if parsed.scheme not in ALLOWED_SCHEMES: + return False + + # For http, only allow if allow_localhost and is actually localhost + if parsed.scheme == 'http': + if not allow_localhost: + return False + if parsed.hostname not in ('localhost', '127.0.0.1', '::1'): + return False + + # Check for blocked hosts + if parsed.hostname in BLOCKED_HOSTS and not allow_localhost: + return False + + # Check for IP addresses in private ranges (basic check) + if parsed.hostname: + # Block obvious private IPs + if parsed.hostname.startswith('10.'): + return False + if parsed.hostname.startswith('172.') and 16 <= int(parsed.hostname.split('.')[1]) <= 31: + return False + if parsed.hostname.startswith('192.168.'): + return False + + return True + except Exception: + return False + + +def get_hardened_client( + timeout: float = DEFAULT_TIMEOUT, + max_redirects: int = MAX_REDIRECTS, +) -> httpx.Client: + """Create hardened HTTP client with security settings. + + Args: + timeout: Request timeout in seconds. + max_redirects: Maximum number of redirects to follow. + + Returns: + Configured httpx.Client. + """ + return httpx.Client( + timeout=timeout, + follow_redirects=True, + max_redirects=max_redirects, + limits=httpx.Limits(max_connections=10, max_keepalive_connections=5), + ) + + +def get_hardened_async_client( + timeout: float = DEFAULT_TIMEOUT, + max_redirects: int = MAX_REDIRECTS, +) -> httpx.AsyncClient: + """Create hardened async HTTP client with security settings. + + Args: + timeout: Request timeout in seconds. + max_redirects: Maximum number of redirects to follow. + + Returns: + Configured httpx.AsyncClient. + """ + return httpx.AsyncClient( + timeout=timeout, + follow_redirects=True, + max_redirects=max_redirects, + limits=httpx.Limits(max_connections=10, max_keepalive_connections=5), + ) + + +def validate_authserver_metadata(metadata: t.Dict[str, t.Any], fetch_url: str) -> None: + """Validate authorization server metadata against ATProto requirements. + + Args: + metadata: Metadata dictionary from server. + fetch_url: URL where metadata was fetched from. + + Raises: + ValueError: If metadata doesn't meet requirements. + """ + issuer_url = urlparse(metadata['issuer']) + fetch_parsed = urlparse(fetch_url) + + # Issuer must match fetch URL host + if issuer_url.hostname != fetch_parsed.hostname: + raise ValueError(f'Issuer hostname mismatch: {issuer_url.hostname} != {fetch_parsed.hostname}') + + # Issuer must be HTTPS with no path/params/fragment + if issuer_url.scheme != 'https': + raise ValueError(f'Issuer must be HTTPS: {issuer_url.scheme}') + if issuer_url.port is not None: + raise ValueError(f'Issuer must not have explicit port: {issuer_url.port}') + if issuer_url.path not in ('', '/'): + raise ValueError(f'Issuer must not have path: {issuer_url.path}') + if issuer_url.params or issuer_url.fragment: + raise ValueError('Issuer must not have params or fragment') + + # Check required grant types and methods + required_checks = [ + ('code' in metadata.get('response_types_supported', []), 'response_types_supported must include "code"'), + ( + 'authorization_code' in metadata.get('grant_types_supported', []), + 'grant_types_supported must include "authorization_code"', + ), + ( + 'refresh_token' in metadata.get('grant_types_supported', []), + 'grant_types_supported must include "refresh_token"', + ), + ( + 'S256' in metadata.get('code_challenge_methods_supported', []), + 'code_challenge_methods_supported must include "S256"', + ), + ( + 'none' in metadata.get('token_endpoint_auth_methods_supported', []), + 'token_endpoint_auth_methods_supported must include "none"', + ), + ( + 'private_key_jwt' in metadata.get('token_endpoint_auth_methods_supported', []), + 'token_endpoint_auth_methods_supported must include "private_key_jwt"', + ), + ( + 'ES256' in metadata.get('token_endpoint_auth_signing_alg_values_supported', []), + 'token_endpoint_auth_signing_alg_values_supported must include "ES256"', + ), + ('atproto' in metadata.get('scopes_supported', []), 'scopes_supported must include "atproto"'), + ( + metadata.get('authorization_response_iss_parameter_supported') is True, + 'authorization_response_iss_parameter_supported must be true', + ), + ( + metadata.get('pushed_authorization_request_endpoint') is not None, + 'pushed_authorization_request_endpoint is required', + ), + ( + metadata.get('require_pushed_authorization_requests') is True, + 'require_pushed_authorization_requests must be true', + ), + ( + 'ES256' in metadata.get('dpop_signing_alg_values_supported', []), + 'dpop_signing_alg_values_supported must include "ES256"', + ), + ( + metadata.get('client_id_metadata_document_supported') is True, + 'client_id_metadata_document_supported must be true', + ), + ] + + for check, error_msg in required_checks: + if not check: + raise ValueError(error_msg) diff --git a/packages/atproto_oauth/stores/__init__.py b/packages/atproto_oauth/stores/__init__.py new file mode 100644 index 00000000..8d6b922f --- /dev/null +++ b/packages/atproto_oauth/stores/__init__.py @@ -0,0 +1,11 @@ +"""OAuth state and session stores.""" + +from atproto_oauth.stores.base import SessionStore, StateStore +from atproto_oauth.stores.memory import MemorySessionStore, MemoryStateStore + +__all__ = [ + 'MemorySessionStore', + 'MemoryStateStore', + 'SessionStore', + 'StateStore', +] diff --git a/packages/atproto_oauth/stores/base.py b/packages/atproto_oauth/stores/base.py new file mode 100644 index 00000000..ea468c63 --- /dev/null +++ b/packages/atproto_oauth/stores/base.py @@ -0,0 +1,69 @@ +"""Abstract base classes for OAuth stores.""" + +import typing as t +from abc import ABC, abstractmethod + +if t.TYPE_CHECKING: + from atproto_oauth.models import OAuthSession, OAuthState + + +class StateStore(ABC): + """Abstract store for OAuth state during authorization flow.""" + + @abstractmethod + async def save_state(self, state: 'OAuthState') -> None: + """Save OAuth state. + + Args: + state: OAuth state object to save. + """ + + @abstractmethod + async def get_state(self, state_key: str) -> t.Optional['OAuthState']: + """Retrieve OAuth state by key. + + Args: + state_key: State identifier. + + Returns: + OAuth state object if found, None otherwise. + """ + + @abstractmethod + async def delete_state(self, state_key: str) -> None: + """Delete OAuth state by key. + + Args: + state_key: State identifier. + """ + + +class SessionStore(ABC): + """Abstract store for OAuth sessions.""" + + @abstractmethod + async def save_session(self, session: 'OAuthSession') -> None: + """Save OAuth session. + + Args: + session: OAuth session object to save. + """ + + @abstractmethod + async def get_session(self, did: str) -> t.Optional['OAuthSession']: + """Retrieve OAuth session by DID. + + Args: + did: User DID. + + Returns: + OAuth session object if found, None otherwise. + """ + + @abstractmethod + async def delete_session(self, did: str) -> None: + """Delete OAuth session by DID. + + Args: + did: User DID. + """ diff --git a/packages/atproto_oauth/stores/memory.py b/packages/atproto_oauth/stores/memory.py new file mode 100644 index 00000000..7237c8d5 --- /dev/null +++ b/packages/atproto_oauth/stores/memory.py @@ -0,0 +1,71 @@ +"""In-memory OAuth stores for development.""" + +import typing as t +from datetime import datetime, timedelta, timezone + +from atproto_oauth.models import OAuthSession, OAuthState +from atproto_oauth.stores.base import SessionStore, StateStore + + +class MemoryStateStore(StateStore): + """In-memory OAuth state store. + + Warning: + This store is not suitable for production use in multi-process + or distributed environments. Use a persistent store instead. + """ + + def __init__(self, state_ttl_seconds: int = 600) -> None: + """Initialize memory state store. + + Args: + state_ttl_seconds: Time-to-live for state entries in seconds. + """ + self._states: t.Dict[str, OAuthState] = {} + self._state_ttl = timedelta(seconds=state_ttl_seconds) + + async def save_state(self, state: OAuthState) -> None: + """Save OAuth state.""" + self._cleanup_expired_states() + self._states[state.state] = state + + async def get_state(self, state_key: str) -> t.Optional[OAuthState]: + """Retrieve OAuth state by key.""" + self._cleanup_expired_states() + return self._states.get(state_key) + + async def delete_state(self, state_key: str) -> None: + """Delete OAuth state by key.""" + self._states.pop(state_key, None) + + def _cleanup_expired_states(self) -> None: + """Remove expired state entries.""" + now = datetime.now(timezone.utc) + expired = [state_key for state_key, state in self._states.items() if (now - state.created_at) > self._state_ttl] + for state_key in expired: + del self._states[state_key] + + +class MemorySessionStore(SessionStore): + """In-memory OAuth session store. + + Warning: + This store is not suitable for production use in multi-process + or distributed environments. Use a persistent store instead. + """ + + def __init__(self) -> None: + """Initialize memory session store.""" + self._sessions: t.Dict[str, OAuthSession] = {} + + async def save_session(self, session: OAuthSession) -> None: + """Save OAuth session.""" + self._sessions[session.did] = session + + async def get_session(self, did: str) -> t.Optional[OAuthSession]: + """Retrieve OAuth session by DID.""" + return self._sessions.get(did) + + async def delete_session(self, did: str) -> None: + """Delete OAuth session by DID.""" + self._sessions.pop(did, None) diff --git a/pyproject.toml b/pyproject.toml index b9e09ea2..969f8356 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ packages = [ { include = "atproto_firehose", from = "packages" }, { include = "atproto_identity", from = "packages" }, { include = "atproto_lexicon", from = "packages" }, + { include = "atproto_oauth", from = "packages" }, { include = "atproto_server", from = "packages" }, ] @@ -82,7 +83,7 @@ coverage = "7.6.1" [tool.poetry-dynamic-versioning] # poetry self add "poetry-dynamic-versioning[plugin]" enable = true -strict = true +strict = false bump = true metadata = false fix-shallow-repository = true @@ -153,7 +154,16 @@ ignore = [ "D105", "D104", "D100", "D107", "D103", "D415", # missing docstring "D101", "D102", # missing docstring in public class and method ] -"tests/*.py" = ["D"] +"packages/atproto_oauth/*.py" = [ + "S105", # hardcoded password (false positives for token_type literals) + "S104", # binding to 0.0.0.0 (SSRF protection list) + "BLE001", # broad exception (intentional for robustness) + "S110", # try-except-pass (intentional silent failures) + "SIM105", # contextlib.suppress (try-except-pass is clearer here) + "SIM102", # nested if (separate conditions are more readable) + "C901", # function complexity (security code needs thoroughness) +] +"tests/*.py" = ["D", "S105"] # docstrings, hardcoded password (test tokens) "docs/*.py" = ["T201", "INP001", "ERA001", "E501", "D"] "examples/*.py" = ["T201", "INP001", "ERA001", "D"] "packages/atproto/exceptions.py" = ["F403"] diff --git a/tests/test_oauth_client.py b/tests/test_oauth_client.py new file mode 100644 index 00000000..3c1d30d9 --- /dev/null +++ b/tests/test_oauth_client.py @@ -0,0 +1,134 @@ +"""Tests for OAuth client implementation.""" + +import typing as t +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from atproto_oauth import OAuthClient, PromptType +from atproto_oauth.stores.memory import MemorySessionStore, MemoryStateStore + + +@pytest.fixture +def oauth_client() -> OAuthClient: + """Create an OAuth client for testing.""" + return OAuthClient( + client_id='https://example.com/client-metadata.json', + redirect_uri='https://example.com/callback', + scope='atproto', + state_store=MemoryStateStore(), + session_store=MemorySessionStore(), + ) + + +def test_prompt_type_values() -> None: + """Test that PromptType includes all valid values.""" + valid_prompts: list[PromptType] = ['login', 'select_account', 'consent', 'none'] + assert len(valid_prompts) == 4 + + +def test_prompt_type_is_exported() -> None: + """Test that PromptType is exported from the package.""" + from atproto_oauth import PromptType as ImportedPromptType + + assert ImportedPromptType is PromptType + + +@pytest.mark.asyncio +@pytest.mark.parametrize('prompt', ['login', 'select_account', 'consent', 'none', None]) +async def test_prompt_passed_to_par_request(oauth_client: OAuthClient, prompt: t.Optional[str]) -> None: + """Test that prompt parameter flows through to _send_par_request.""" + oauth_client._id_resolver.handle.resolve = AsyncMock(return_value='did:plc:test123') + oauth_client._id_resolver.did.resolve_atproto_data = AsyncMock( + return_value=MagicMock(handle='test.bsky.social', pds='https://pds.example.com') + ) + + captured_prompt: t.Optional[str] = None + + async def mock_send_par( + authserver_meta: t.Any, + login_hint: str, + pkce_challenge: str, + dpop_key: t.Any, + state: str, + prompt: t.Optional[str] = None, + ) -> tuple[str, str]: + nonlocal captured_prompt + captured_prompt = prompt + return 'urn:ietf:params:oauth:request_uri:test', 'nonce123' + + oauth_client._send_par_request = mock_send_par # type: ignore[method-assign] + + with ( + patch( + 'atproto_oauth.client.discover_authserver_from_pds_async', + new=AsyncMock(return_value='https://auth.example.com'), + ), + patch( + 'atproto_oauth.client.fetch_authserver_metadata_async', + new=AsyncMock( + return_value=MagicMock( + issuer='https://auth.example.com', + authorization_endpoint='https://auth.example.com/authorize', + pushed_authorization_request_endpoint='https://auth.example.com/par', + ) + ), + ), + ): + await oauth_client.start_authorization('test.bsky.social', prompt=prompt) + assert captured_prompt == prompt + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ('prompt', 'expect_in_params'), + [ + ('login', True), + ('select_account', True), + ('consent', True), + ('none', True), + (None, False), + ], +) +async def test_prompt_in_par_params( + oauth_client: OAuthClient, + prompt: t.Optional[str], + expect_in_params: bool, +) -> None: + """Test that prompt is included in PAR params only when provided.""" + authserver_meta = MagicMock( + issuer='https://auth.example.com', + pushed_authorization_request_endpoint='https://auth.example.com/par', + ) + + captured_params: dict[str, str] = {} + + async def mock_make_token_request( + token_url: str, + params: dict[str, str], + dpop_key: t.Any, + dpop_nonce: str, + issuer: t.Optional[str] = None, + ) -> tuple[str, MagicMock]: + nonlocal captured_params + captured_params = params.copy() + response = MagicMock() + response.status_code = 200 + response.json.return_value = {'request_uri': 'urn:test:uri'} + return 'nonce', response + + oauth_client._make_token_request = mock_make_token_request # type: ignore[method-assign] + + await oauth_client._send_par_request( + authserver_meta=authserver_meta, + login_hint='test.bsky.social', + pkce_challenge='challenge123', + dpop_key=MagicMock(), + state='state123', + prompt=prompt, + ) + + if expect_in_params: + assert 'prompt' in captured_params + assert captured_params['prompt'] == prompt + else: + assert 'prompt' not in captured_params diff --git a/tests/test_oauth_dpop.py b/tests/test_oauth_dpop.py new file mode 100644 index 00000000..e2a314cb --- /dev/null +++ b/tests/test_oauth_dpop.py @@ -0,0 +1,239 @@ +"""Tests for DPoP implementation.""" + +import json + +import httpx +from atproto_oauth.dpop import DPoPManager + + +def test_generate_keypair() -> None: + """Test generating DPoP keypair.""" + key = DPoPManager.generate_keypair() + assert key is not None + + # Verify it's an EC key + from cryptography.hazmat.primitives.asymmetric import ec + + assert isinstance(key, ec.EllipticCurvePrivateKey) + + +def test_create_proof() -> None: + """Test creating DPoP proof JWT.""" + key = DPoPManager.generate_keypair() + + proof = DPoPManager.create_proof( + method='GET', + url='https://example.com/api', + private_key=key, + ) + + # Verify JWT structure + assert isinstance(proof, str) + parts = proof.split('.') + assert len(parts) == 3 # header.payload.signature + + # Decode and verify header + import base64 + + header_json = base64.urlsafe_b64decode(parts[0] + '==') + header = json.loads(header_json) + + assert header['typ'] == 'dpop+jwt' + assert header['alg'] == 'ES256' + assert 'jwk' in header + + # Decode and verify payload + payload_json = base64.urlsafe_b64decode(parts[1] + '==') + payload = json.loads(payload_json) + + assert payload['htm'] == 'GET' + assert payload['htu'] == 'https://example.com/api' + assert 'jti' in payload + assert 'iat' in payload + assert 'exp' in payload + + +def test_create_proof_with_nonce() -> None: + """Test creating DPoP proof with nonce.""" + key = DPoPManager.generate_keypair() + nonce = 'test-nonce-123' + + proof = DPoPManager.create_proof( + method='POST', + url='https://example.com/api', + private_key=key, + nonce=nonce, + ) + + # Decode payload and verify nonce + import base64 + + parts = proof.split('.') + payload_json = base64.urlsafe_b64decode(parts[1] + '==') + payload = json.loads(payload_json) + + assert payload['nonce'] == nonce + + +def test_create_proof_with_access_token() -> None: + """Test creating DPoP proof with access token hash.""" + key = DPoPManager.generate_keypair() + access_token = 'test-access-token' + + proof = DPoPManager.create_proof( + method='POST', + url='https://example.com/api', + private_key=key, + access_token=access_token, + ) + + # Decode payload and verify ath claim + import base64 + import hashlib + + parts = proof.split('.') + payload_json = base64.urlsafe_b64decode(parts[1] + '==') + payload = json.loads(payload_json) + + assert 'ath' in payload + + # Verify ath is correct hash + expected_hash = hashlib.sha256(access_token.encode('utf-8')).digest() + expected_ath = base64.urlsafe_b64encode(expected_hash).decode('utf-8').rstrip('=') + assert payload['ath'] == expected_ath + + +def test_is_dpop_nonce_error() -> None: + """Test detecting DPoP nonce error responses.""" + # Test with error in JSON body + response = httpx.Response( + status_code=401, + json={'error': 'use_dpop_nonce'}, + headers={'DPoP-Nonce': 'new-nonce'}, + ) + assert DPoPManager.is_dpop_nonce_error(response) + + # Test with error in WWW-Authenticate header + response = httpx.Response( + status_code=401, + headers={ + 'WWW-Authenticate': 'DPoP error="use_dpop_nonce"', + 'DPoP-Nonce': 'new-nonce', + }, + ) + assert DPoPManager.is_dpop_nonce_error(response) + + # Test normal response + response = httpx.Response(status_code=200, json={'success': True}) + assert not DPoPManager.is_dpop_nonce_error(response) + + +def test_extract_nonce_from_response() -> None: + """Test extracting DPoP nonce from response.""" + response = httpx.Response( + status_code=401, + headers={'DPoP-Nonce': 'test-nonce-456'}, + ) + + nonce = DPoPManager.extract_nonce_from_response(response) + assert nonce == 'test-nonce-456' + + # Test response without nonce + response = httpx.Response(status_code=200) + nonce = DPoPManager.extract_nonce_from_response(response) + assert nonce is None + + +def test_dpop_signature_format() -> None: + """Test that DPoP signature uses IEEE P1363 format (64 bytes for ES256).""" + key = DPoPManager.generate_keypair() + + proof = DPoPManager.create_proof( + method='POST', + url='https://example.com/token', + private_key=key, + ) + + # Decode signature + import base64 + + parts = proof.split('.') + # Add padding if needed + signature_b64 = parts[2] + '=' * (4 - len(parts[2]) % 4) + signature_bytes = base64.urlsafe_b64decode(signature_b64) + + # ES256 signature should be 64 bytes (32 for r, 32 for s) + assert len(signature_bytes) == 64, f'Signature length is {len(signature_bytes)}, expected 64' + + +def test_dpop_htu_strips_query_and_fragment() -> None: + """Test that htu field strips query and fragment per RFC 9449.""" + key = DPoPManager.generate_keypair() + + # Test with query parameters + proof = DPoPManager.create_proof( + method='POST', + url='https://example.com/token?param=value&other=test', + private_key=key, + ) + + import base64 + + parts = proof.split('.') + payload_json = base64.urlsafe_b64decode(parts[1] + '==') + payload = json.loads(payload_json) + + assert payload['htu'] == 'https://example.com/token' + + # Test with fragment + proof = DPoPManager.create_proof( + method='GET', + url='https://example.com/api#section', + private_key=key, + ) + + parts = proof.split('.') + payload_json = base64.urlsafe_b64decode(parts[1] + '==') + payload = json.loads(payload_json) + + assert payload['htu'] == 'https://example.com/api' + + # Test with both query and fragment + proof = DPoPManager.create_proof( + method='GET', + url='https://example.com/api?foo=bar#section', + private_key=key, + ) + + parts = proof.split('.') + payload_json = base64.urlsafe_b64decode(parts[1] + '==') + payload = json.loads(payload_json) + + assert payload['htu'] == 'https://example.com/api' + + +def test_dpop_jwk_format() -> None: + """Test that JWK in header is properly formatted.""" + key = DPoPManager.generate_keypair() + + proof = DPoPManager.create_proof( + method='POST', + url='https://example.com/token', + private_key=key, + ) + + import base64 + + parts = proof.split('.') + header_json = base64.urlsafe_b64decode(parts[0] + '==') + header = json.loads(header_json) + + jwk = header['jwk'] + + # Verify JWK structure + assert jwk['kty'] == 'EC' + assert jwk['crv'] == 'P-256' + assert 'x' in jwk + assert 'y' in jwk + # Must NOT contain private key + assert 'd' not in jwk diff --git a/tests/test_oauth_pkce.py b/tests/test_oauth_pkce.py new file mode 100644 index 00000000..ee31faa7 --- /dev/null +++ b/tests/test_oauth_pkce.py @@ -0,0 +1,64 @@ +"""Tests for PKCE implementation.""" + +import hashlib +from base64 import urlsafe_b64encode + +import pytest +from atproto_oauth.pkce import PKCEManager + + +def test_generate_verifier_default_length() -> None: + """Test generating PKCE verifier with default length.""" + verifier = PKCEManager.generate_verifier() + assert isinstance(verifier, str) + assert 43 <= len(verifier) <= 128 + + +def test_generate_verifier_custom_length() -> None: + """Test generating PKCE verifier with custom length.""" + length = 64 + verifier = PKCEManager.generate_verifier(length) + assert len(verifier) == length + + +def test_generate_verifier_invalid_length() -> None: + """Test that invalid length raises error.""" + with pytest.raises(ValueError, match='must be between 43 and 128'): + PKCEManager.generate_verifier(20) + + with pytest.raises(ValueError, match='must be between 43 and 128'): + PKCEManager.generate_verifier(200) + + +def test_generate_challenge() -> None: + """Test generating S256 challenge from verifier.""" + verifier = 'test_verifier_123456789' + challenge = PKCEManager.generate_challenge(verifier) + + # Verify it's base64url encoded SHA256 + expected_digest = hashlib.sha256(verifier.encode('utf-8')).digest() + expected_challenge = urlsafe_b64encode(expected_digest).decode('utf-8').rstrip('=') + + assert challenge == expected_challenge + + +def test_generate_pair() -> None: + """Test generating verifier and challenge pair.""" + verifier, challenge = PKCEManager.generate_pair() + + # Verify verifier format + assert isinstance(verifier, str) + assert 43 <= len(verifier) <= 128 + + # Verify challenge matches verifier + expected_challenge = PKCEManager.generate_challenge(verifier) + assert challenge == expected_challenge + + +def test_challenge_is_deterministic() -> None: + """Test that same verifier always produces same challenge.""" + verifier = PKCEManager.generate_verifier() + challenge1 = PKCEManager.generate_challenge(verifier) + challenge2 = PKCEManager.generate_challenge(verifier) + + assert challenge1 == challenge2