From cd34774c7ff4350117fd94fdfd233a327feb38a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Voron?= Date: Fri, 6 Dec 2024 11:22:21 +0100 Subject: [PATCH] Implement a more generic get_profile method on clients --- docs/usage.md | 10 +++- httpx_oauth/clients/discord.py | 24 +++++--- httpx_oauth/clients/facebook.py | 15 +++-- httpx_oauth/clients/franceconnect.py | 16 +++-- httpx_oauth/clients/github.py | 88 ++++++++++++++++++++-------- httpx_oauth/clients/google.py | 28 +++++---- httpx_oauth/clients/kakao.py | 21 ++++--- httpx_oauth/clients/linkedin.py | 36 ++++++++---- httpx_oauth/clients/microsoft.py | 16 +++-- httpx_oauth/clients/naver.py | 19 ++++-- httpx_oauth/clients/openid.py | 16 +++-- httpx_oauth/clients/reddit.py | 24 ++++---- httpx_oauth/clients/shopify.py | 53 +++++++++++++---- httpx_oauth/exceptions.py | 13 +++- httpx_oauth/oauth2.py | 26 +++++++- tests/test_oauth2.py | 7 +++ 16 files changed, 294 insertions(+), 118 deletions(-) diff --git a/docs/usage.md b/docs/usage.md index 1881c78..354c149 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -54,9 +54,15 @@ access_token = await client.refresh_token("REFRESH_TOKEN") For providers supporting it, you can ask to revoke an access or refresh token. For this, use the [revoke_token][httpx_oauth.oauth2.BaseOAuth2.revoke_token] method. -## Get authenticated user ID and email +## Get profile -For convenience, we provide a method that'll use a valid access token to query the provider API and get the ID and the email (if available) of the authenticated user. For this, use the [get_id_email][httpx_oauth.oauth2.BaseOAuth2.get_id_email] method. +For convenience, we provide a method that'll use a valid access token to query the provider API and get the profile of the authenticated user. For this, use the [get_profile][httpx_oauth.oauth2.BaseOAuth2.get_profile] method. + +This method is implemented specifically on each provider. Please note it's a raw JSON output from the provider API, so it might vary greatly. + +### Get authenticated user ID and email + +Often, what you need is only the ID and the email. We offer another convenience method that'll do the heavy lifting of retrieving them from the profile output: the [get_id_email][httpx_oauth.oauth2.BaseOAuth2.get_id_email] method. This method is implemented specifically on each provider. diff --git a/httpx_oauth/clients/discord.py b/httpx_oauth/clients/discord.py index 380645f..e6c3c04 100644 --- a/httpx_oauth/clients/discord.py +++ b/httpx_oauth/clients/discord.py @@ -1,6 +1,6 @@ from typing import Any, Optional, cast -from httpx_oauth.exceptions import GetIdEmailError +from httpx_oauth.exceptions import GetIdEmailError, GetProfileError from httpx_oauth.oauth2 import BaseOAuth2 AUTHORIZE_ENDPOINT = "https://discord.com/api/oauth2/authorize" @@ -52,7 +52,7 @@ def __init__( revocation_endpoint_auth_method="client_secret_basic", ) - async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + async def get_profile(self, token: str) -> dict[str, Any]: async with self.get_httpx_client() as client: response = await client.get( PROFILE_ENDPOINT, @@ -60,14 +60,20 @@ async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: ) if response.status_code >= 400: - raise GetIdEmailError(response=response) + raise GetProfileError(response=response) + + return cast(dict[str, Any], response.json()) - data = cast(dict[str, Any], response.json()) + async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + try: + profile = await self.get_profile(token) + except GetProfileError as e: + raise GetIdEmailError(response=e.response) from e - user_id = data["id"] - user_email = data.get("email") + user_id = profile["id"] + user_email = profile.get("email") - if not data.get("verified", False): - user_email = None + if not profile.get("verified", False): + user_email = None - return user_id, user_email + return user_id, user_email diff --git a/httpx_oauth/clients/facebook.py b/httpx_oauth/clients/facebook.py index 4a0baa3..eebfe31 100644 --- a/httpx_oauth/clients/facebook.py +++ b/httpx_oauth/clients/facebook.py @@ -1,6 +1,6 @@ from typing import Any, Optional, cast -from httpx_oauth.exceptions import GetIdEmailError +from httpx_oauth.exceptions import GetIdEmailError, GetProfileError from httpx_oauth.oauth2 import BaseOAuth2, OAuth2RequestError, OAuth2Token AUTHORIZE_ENDPOINT = "https://www.facebook.com/v5.0/dialog/oauth" @@ -89,7 +89,7 @@ async def get_long_lived_access_token(self, token: str) -> OAuth2Token: data = self.get_json(response, exc_class=GetLongLivedAccessTokenError) return OAuth2Token(data) - async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + async def get_profile(self, token: str) -> dict[str, Any]: async with self.get_httpx_client() as client: response = await client.get( PROFILE_ENDPOINT, @@ -97,8 +97,13 @@ async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: ) if response.status_code >= 400: - raise GetIdEmailError(response=response) + raise GetProfileError(response=response) - data = cast(dict[str, Any], response.json()) + return cast(dict[str, Any], response.json()) - return data["id"], data.get("email") + async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + try: + profile = await self.get_profile(token) + except GetProfileError as e: + raise GetIdEmailError(response=e.response) from e + return profile["id"], profile.get("email") diff --git a/httpx_oauth/clients/franceconnect.py b/httpx_oauth/clients/franceconnect.py index c8e91b9..a5430a4 100644 --- a/httpx_oauth/clients/franceconnect.py +++ b/httpx_oauth/clients/franceconnect.py @@ -1,7 +1,7 @@ import secrets from typing import Any, Literal, Optional, TypedDict -from httpx_oauth.exceptions import GetIdEmailError +from httpx_oauth.exceptions import GetIdEmailError, GetProfileError from httpx_oauth.oauth2 import BaseOAuth2 ENDPOINTS = { @@ -72,7 +72,7 @@ async def get_authorization_url( redirect_uri, state, scope, extras_params=_extras_params ) - async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + async def get_profile(self, token: str) -> dict[str, Any]: async with self.get_httpx_client() as client: response = await client.get( self.profile_endpoint, @@ -80,8 +80,14 @@ async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: ) if response.status_code >= 400: - raise GetIdEmailError(response=response) + raise GetProfileError(response=response) + + return response.json() - data: dict[str, Any] = response.json() + async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + try: + profile = await self.get_profile(token) + except GetProfileError as e: + raise GetIdEmailError(response=e.response) from e - return str(data["sub"]), data.get("email") + return str(profile["sub"]), profile.get("email") diff --git a/httpx_oauth/clients/github.py b/httpx_oauth/clients/github.py index 6b9a80b..2c8beb2 100644 --- a/httpx_oauth/clients/github.py +++ b/httpx_oauth/clients/github.py @@ -2,7 +2,7 @@ import httpx -from httpx_oauth.exceptions import GetIdEmailError +from httpx_oauth.exceptions import GetIdEmailError, GetProfileError from httpx_oauth.oauth2 import BaseOAuth2, OAuth2Token, RefreshTokenError AUTHORIZE_ENDPOINT = "https://github.com/login/oauth/authorize" @@ -106,10 +106,20 @@ async def refresh_token(self, refresh_token: str) -> OAuth2Token: return OAuth2Token(data) - async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + async def get_profile(self, token: str) -> dict[str, Any]: + async with httpx.AsyncClient( + headers={**self.request_headers, "Authorization": f"token {token}"} + ) as client: + response = await client.get(PROFILE_ENDPOINT) + + if response.status_code >= 400: + raise GetProfileError(response=response) + + return cast(dict[str, Any], response.json()) + + async def get_emails(self, token: str) -> list[dict[str, Any]]: """ - Returns the id and the email (if available) of the authenticated user - from the API provider. + Return the emails of the authenticated user from the API provider. !!! tip You should enable **Email addresses** permission @@ -120,43 +130,71 @@ async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: token: The access token. Returns: - A tuple with the id and the email of the authenticated user. - + A list of emails as described in the [GitHub API](https://docs.github.com/en/rest/users/emails?apiVersion=2022-11-28#list-email-addresses-for-the-authenticated-user). Raises: - httpx_oauth.exceptions.GetIdEmailError: - An error occurred while getting the id and email. + httpx_oauth.exceptions.GetProfileError: + An error occurred while getting the emails. Examples: ```py - user_id, user_email = await client.get_id_email("TOKEN") + emails = await client.get_emails("TOKEN") ``` """ async with httpx.AsyncClient( headers={**self.request_headers, "Authorization": f"token {token}"} ) as client: - response = await client.get(PROFILE_ENDPOINT) + response = await client.get(EMAILS_ENDPOINT) if response.status_code >= 400: - raise GetIdEmailError(response=response) + raise GetProfileError(response=response) + + return cast(list[dict[str, Any]], response.json()) + + async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + """ + Returns the id and the email (if available) of the authenticated user + from the API provider. - data = cast(dict[str, Any], response.json()) + !!! tip + You should enable **Email addresses** permission + in the **Permissions & events** section of your GitHub app parameters. + You can find it at [https://github.com/settings/apps/{YOUR_APP}/permissions](https://github.com/settings/apps/{YOUR_APP}/permissions). - id = data["id"] - email = data.get("email") + Args: + token: The access token. - # No public email, make a separate call to /user/emails - if email is None: - response = await client.get(EMAILS_ENDPOINT) + Returns: + A tuple with the id and the email of the authenticated user. - if response.status_code >= 400: - raise GetIdEmailError(response=response) - emails = cast(list[dict[str, Any]], response.json()) + Raises: + httpx_oauth.exceptions.GetIdEmailError: + An error occurred while getting the id and email. - # Use the primary email if it exists, otherwise the first - email = next( - (e["email"] for e in emails if e.get("primary")), emails[0]["email"] - ) + Examples: + ```py + user_id, user_email = await client.get_id_email("TOKEN") + ``` + """ + try: + profile = await self.get_profile(token) + except GetProfileError as e: + raise GetIdEmailError(response=e.response) from e + + id = profile["id"] + email = profile.get("email") + + # No public email, make a separate call to /user/emails + if email is None: + try: + emails = await self.get_emails(token) + except GetProfileError as e: + raise GetIdEmailError(response=e.response) from e + + # Use the primary email if it exists, otherwise the first + email = next( + (e["email"] for e in emails if e.get("primary")), emails[0]["email"] + ) - return str(id), email + return str(id), email diff --git a/httpx_oauth/clients/google.py b/httpx_oauth/clients/google.py index 1aad5fa..e7299e7 100644 --- a/httpx_oauth/clients/google.py +++ b/httpx_oauth/clients/google.py @@ -1,6 +1,6 @@ from typing import Any, Literal, Optional, TypedDict, cast -from httpx_oauth.exceptions import GetIdEmailError +from httpx_oauth.exceptions import GetIdEmailError, GetProfileError from httpx_oauth.oauth2 import BaseOAuth2 AUTHORIZE_ENDPOINT = "https://accounts.google.com/o/oauth2/v2/auth" @@ -65,7 +65,7 @@ def __init__( revocation_endpoint_auth_method="client_secret_post", ) - async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + async def get_profile(self, token: str) -> dict[str, Any]: async with self.get_httpx_client() as client: response = await client.get( PROFILE_ENDPOINT, @@ -74,15 +74,21 @@ async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: ) if response.status_code >= 400: - raise GetIdEmailError(response=response) + raise GetProfileError(response=response) - data = cast(dict[str, Any], response.json()) + return cast(dict[str, Any], response.json()) - user_id = data["resourceName"] - user_email = next( - email["value"] - for email in data["emailAddresses"] - if email["metadata"]["primary"] - ) + async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + try: + profile = await self.get_profile(token) + except GetProfileError as e: + raise GetIdEmailError(response=e.response) from e + + user_id = profile["resourceName"] + user_email = next( + email["value"] + for email in profile["emailAddresses"] + if email["metadata"]["primary"] + ) - return user_id, user_email + return user_id, user_email diff --git a/httpx_oauth/clients/kakao.py b/httpx_oauth/clients/kakao.py index 1aecf73..e6e8f0d 100644 --- a/httpx_oauth/clients/kakao.py +++ b/httpx_oauth/clients/kakao.py @@ -1,7 +1,7 @@ import json from typing import Any, Optional, cast -from httpx_oauth.exceptions import GetIdEmailError +from httpx_oauth.exceptions import GetIdEmailError, GetProfileError from httpx_oauth.oauth2 import BaseOAuth2 AUTHORIZE_ENDPOINT = "https://kauth.kakao.com/oauth/authorize" @@ -52,7 +52,7 @@ def __init__( revocation_endpoint_auth_method="client_secret_post", ) - async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + async def get_profile(self, token: str) -> dict[str, Any]: async with self.get_httpx_client() as client: response = await client.post( PROFILE_ENDPOINT, @@ -61,9 +61,16 @@ async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: ) if response.status_code >= 400: - raise GetIdEmailError(response=response) + raise GetProfileError(response=response) + + return cast(dict[str, Any], response.json()) + + async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + try: + profile = await self.get_profile(token) + except GetProfileError as e: + raise GetIdEmailError(response=e.response) from e - payload = cast(dict[str, Any], response.json()) - account_id = str(payload["id"]) - email = payload["kakao_account"].get("email") - return account_id, email + account_id = str(profile["id"]) + email = profile["kakao_account"].get("email") + return account_id, email diff --git a/httpx_oauth/clients/linkedin.py b/httpx_oauth/clients/linkedin.py index bc958b8..6619a6d 100644 --- a/httpx_oauth/clients/linkedin.py +++ b/httpx_oauth/clients/linkedin.py @@ -1,6 +1,6 @@ from typing import Any, Optional, cast -from httpx_oauth.exceptions import GetIdEmailError +from httpx_oauth.exceptions import GetIdEmailError, GetProfileError from httpx_oauth.oauth2 import BaseOAuth2, OAuth2Token AUTHORIZE_ENDPOINT = "https://www.linkedin.com/oauth/v2/authorization" @@ -74,30 +74,40 @@ async def refresh_token(self, refresh_token: str) -> OAuth2Token: """ return await super().refresh_token(refresh_token) # pragma: no cover - async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + async def get_profile(self, token: str) -> dict[str, Any]: async with self.get_httpx_client() as client: - profile_response = await client.get( + response = await client.get( PROFILE_ENDPOINT, headers={"Authorization": f"Bearer {token}"}, params={"projection": "(id)"}, ) - if profile_response.status_code >= 400: - raise GetIdEmailError(response=profile_response) + if response.status_code >= 400: + raise GetProfileError(response=response) + + return cast(dict[str, Any], response.json()) - email_response = await client.get( + async def get_email(self, token: str) -> dict[str, Any]: + async with self.get_httpx_client() as client: + response = await client.get( EMAIL_ENDPOINT, headers={"Authorization": f"Bearer {token}"}, params={"q": "members", "projection": "(elements*(handle~))"}, ) - if email_response.status_code >= 400: - raise GetIdEmailError(response=email_response) + if response.status_code >= 400: + raise GetProfileError(response=response) - profile_data = cast(dict[str, Any], profile_response.json()) - user_id = profile_data["id"] + return cast(dict[str, Any], response.json()) + + async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + try: + profile = await self.get_profile(token) + email = await self.get_email(token) + except GetProfileError as e: + raise GetIdEmailError(response=e.response) from e - email_data = cast(dict[str, Any], email_response.json()) - user_email = email_data["elements"][0]["handle~"]["emailAddress"] + user_id = profile["id"] + user_email = email["elements"][0]["handle~"]["emailAddress"] - return user_id, user_email + return user_id, user_email diff --git a/httpx_oauth/clients/microsoft.py b/httpx_oauth/clients/microsoft.py index f20433e..e05fc7f 100644 --- a/httpx_oauth/clients/microsoft.py +++ b/httpx_oauth/clients/microsoft.py @@ -1,6 +1,6 @@ from typing import Any, Optional, cast -from httpx_oauth.exceptions import GetIdEmailError +from httpx_oauth.exceptions import GetIdEmailError, GetProfileError from httpx_oauth.oauth2 import BaseOAuth2 AUTHORIZE_ENDPOINT = "https://login.microsoftonline.com/{tenant}/oauth2/v2.0/authorize" @@ -63,7 +63,7 @@ def get_authorization_url( redirect_uri, state=state, scope=scope, extras_params=extras_params ) - async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + async def get_profile(self, token: str) -> dict[str, Any]: async with self.get_httpx_client() as client: response = await client.get( PROFILE_ENDPOINT, @@ -71,8 +71,14 @@ async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: ) if response.status_code >= 400: - raise GetIdEmailError(response=response) + raise GetProfileError(response=response) + + return cast(dict[str, Any], response.json()) - data = cast(dict[str, Any], response.json()) + async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + try: + profile = await self.get_profile(token) + except GetProfileError as e: + raise GetIdEmailError(response=e.response) from e - return data["id"], data["userPrincipalName"] + return profile["id"], profile["userPrincipalName"] diff --git a/httpx_oauth/clients/naver.py b/httpx_oauth/clients/naver.py index ea51125..3c7b06c 100644 --- a/httpx_oauth/clients/naver.py +++ b/httpx_oauth/clients/naver.py @@ -1,6 +1,6 @@ from typing import Any, Optional, cast -from httpx_oauth.exceptions import GetIdEmailError +from httpx_oauth.exceptions import GetIdEmailError, GetProfileError from httpx_oauth.oauth2 import BaseOAuth2, RevokeTokenError AUTHORIZE_ENDPOINT = "https://nid.naver.com/oauth2.0/authorize" @@ -80,7 +80,7 @@ async def revoke_token( return None - async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + async def get_profile(self, token: str) -> dict[str, Any]: async with self.get_httpx_client() as client: response = await client.post( PROFILE_ENDPOINT, @@ -88,8 +88,15 @@ async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: ) if response.status_code >= 400: - raise GetIdEmailError(response=response) + raise GetProfileError(response=response) + + json = response.json() + return cast(dict[str, Any], json["response"]) + + async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + try: + profile = await self.get_profile(token) + except GetProfileError as e: + raise GetIdEmailError(response=e.response) from e - json = cast(dict[str, Any], response.json()) - account_info: dict[str, Any] = json["response"] - return account_info["id"], account_info.get("email") + return profile["id"], profile.get("email") diff --git a/httpx_oauth/clients/openid.py b/httpx_oauth/clients/openid.py index 12fe5e4..e6c3979 100644 --- a/httpx_oauth/clients/openid.py +++ b/httpx_oauth/clients/openid.py @@ -2,7 +2,7 @@ import httpx -from httpx_oauth.exceptions import GetIdEmailError +from httpx_oauth.exceptions import GetIdEmailError, GetProfileError from httpx_oauth.oauth2 import BaseOAuth2, OAuth2ClientAuthMethod, OAuth2RequestError BASE_SCOPES = ["openid", "email"] @@ -100,7 +100,7 @@ def __init__( ), ) - async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + async def get_profile(self, token: str) -> dict[str, Any]: async with self.get_httpx_client() as client: response = await client.get( self.openid_configuration["userinfo_endpoint"], @@ -108,8 +108,14 @@ async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: ) if response.status_code >= 400: - raise GetIdEmailError(response=response) + raise GetProfileError(response=response) + + return response.json() - data: dict[str, Any] = response.json() + async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + try: + profile = await self.get_profile(token) + except GetProfileError as e: + raise GetIdEmailError(response=e.response) from e - return str(data["sub"]), data.get("email") + return str(profile["sub"]), profile.get("email") diff --git a/httpx_oauth/clients/reddit.py b/httpx_oauth/clients/reddit.py index d9fdc4c..b67bf33 100644 --- a/httpx_oauth/clients/reddit.py +++ b/httpx_oauth/clients/reddit.py @@ -2,7 +2,7 @@ import httpx -from httpx_oauth.exceptions import GetIdEmailError +from httpx_oauth.exceptions import GetIdEmailError, GetProfileError from httpx_oauth.oauth2 import ( BaseOAuth2, GetAccessTokenError, @@ -74,20 +74,22 @@ async def get_access_token( return oauth2_token - async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + async def get_profile(self, token: str) -> dict[str, Any]: async with self.get_httpx_client() as client: - headers = self.request_headers.copy() - headers["Authorization"] = f"Bearer {token}" - response = await client.get( IDENTITY_ENDPOINT, - headers=headers, + headers={**self.request_headers, "Authorization": f"Bearer {token}"}, ) - # Reddit doesn't return any useful JSON in case of auth failures - # on oauth.reddit.com endpoints, so we simulate our own if response.status_code != httpx.codes.OK: - raise GetIdEmailError(response=response) + raise GetProfileError(response=response) + + return cast(dict[str, Any], response.json()) + + async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: + try: + profile = await self.get_profile(token) + except GetProfileError as e: + raise GetIdEmailError(response=e.response) from e - data = cast(dict[str, Any], response.json()) - return data["name"], None + return profile["name"], None diff --git a/httpx_oauth/clients/shopify.py b/httpx_oauth/clients/shopify.py index 1e986e3..38c846e 100644 --- a/httpx_oauth/clients/shopify.py +++ b/httpx_oauth/clients/shopify.py @@ -1,6 +1,6 @@ from typing import Any, Literal, Optional, TypedDict, cast -from httpx_oauth.exceptions import GetIdEmailError +from httpx_oauth.exceptions import GetIdEmailError, GetProfileError from httpx_oauth.oauth2 import BaseOAuth2 AUTHORIZE_ENDPOINT = "https://{shop}.myshopify.com/admin/oauth/authorize" @@ -68,6 +68,40 @@ def __init__( token_endpoint_auth_method="client_secret_post", ) + async def get_profile(self, token: str) -> dict[str, Any]: + """ + Returns the profile of the authenticated user from the API provider. + + !!! warning "`get_profile` is based on the `Shop` resource" + The implementation of `get_profile` calls the [Get Shop endpoint](https://shopify.dev/docs/api/admin-rest/2023-04/resources/shop#get-shop) of the Shopify Admin API. + It means that it'll return you the **profile of the shop**. + + Args: + token: The access token. + + Returns: + The profile of the authenticated shop. + + Raises: + httpx_oauth.exceptions.GetProfileError: + An error occurred while getting the profile + + Examples: + ```py + profile = await client.get_profile("TOKEN") + ``` + """ + async with self.get_httpx_client() as client: + response = await client.get( + self.profile_endpoint, + headers={"X-Shopify-Access-Token": token}, + ) + + if response.status_code >= 400: + raise GetProfileError(response=response) + + return cast(dict[str, Any], response.json()) + async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: """ Returns the id and the email (if available) of the authenticated user @@ -93,15 +127,10 @@ async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: user_id, user_email = await client.get_id_email("TOKEN") ``` """ - async with self.get_httpx_client() as client: - response = await client.get( - self.profile_endpoint, - headers={"X-Shopify-Access-Token": token}, - ) - - if response.status_code >= 400: - raise GetIdEmailError(response=response) + try: + profile = await self.get_profile(token) + except GetProfileError as e: + raise GetIdEmailError(response=e.response) from e - data = cast(dict[str, Any], response.json()) - shop = data["shop"] - return str(shop["id"]), shop["email"] + shop = profile["shop"] + return str(shop["id"]), shop["email"] diff --git a/httpx_oauth/exceptions.py b/httpx_oauth/exceptions.py index 024f730..30bed1a 100644 --- a/httpx_oauth/exceptions.py +++ b/httpx_oauth/exceptions.py @@ -13,7 +13,7 @@ def __init__(self, message: str) -> None: super().__init__(message) -class GetIdEmailError(HTTPXOAuthError): +class GetProfileError(HTTPXOAuthError): """Error raised while retrieving user profile from provider API.""" def __init__( @@ -23,3 +23,14 @@ def __init__( ) -> None: self.response = response super().__init__(message) + + +class GetIdEmailError(GetProfileError): + """Error raised while retrieving id and email from provider API.""" + + def __init__( + self, + message: str = "Error while retrieving id and email.", + response: Union[httpx.Response, None] = None, + ) -> None: + super().__init__(message, response) diff --git a/httpx_oauth/oauth2.py b/httpx_oauth/oauth2.py index 99dc1c7..7b52c4d 100644 --- a/httpx_oauth/oauth2.py +++ b/httpx_oauth/oauth2.py @@ -241,7 +241,7 @@ async def get_authorization_url( authorization_url = await client.get_authorization_url( "https://www.tintagel.bt/oauth-callback", scope=["SCOPE1", "SCOPE2", "SCOPE3"], ) - ```py + ``` """ params = { "response_type": "code", @@ -393,6 +393,30 @@ async def revoke_token( return None + async def get_profile(self, token: str) -> dict[str, Any]: + """ + Returns the profile of the authenticated user + from the API provider. + + **It assumes you have asked for the required scopes**. + + Args: + token: The access token. + + Returns: + The profile of the authenticated user. + + Raises: + httpx_oauth.exceptions.GetProfileError: + An error occurred while getting the profile. + + Examples: + ```py + profile = await client.get_profile("TOKEN") + ``` + """ + raise NotImplementedError() + async def get_id_email(self, token: str) -> tuple[str, Optional[str]]: """ Returns the id and the email (if available) of the authenticated user diff --git a/tests/test_oauth2.py b/tests/test_oauth2.py index 4bae710..04dbf17 100644 --- a/tests/test_oauth2.py +++ b/tests/test_oauth2.py @@ -331,6 +331,13 @@ async def test_revoke_token_http_error(self, client_revoke: OAuth2): assert excinfo.value.response is None +@pytest.mark.asyncio +class TestGetProfile: + async def test_not_implemented(self, client: OAuth2): + with pytest.raises(NotImplementedError): + await client.get_profile("TOKEN") + + @pytest.mark.asyncio class TestGetIdEmail: async def test_not_implemented(self, client: OAuth2):