Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add client credential oauth integration support + related databricks helpers to SDK #348

Merged
merged 22 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
67c7af3
initial proposal for client credentials support in python sdk
zackverham Nov 27, 2024
67edfbb
a bit of refactoring + some more documentation
zackverham Nov 27, 2024
08796e2
test coverage, linting
zackverham Dec 2, 2024
e654b0b
Update src/posit/connect/external/databricks.py
zackverham Dec 3, 2024
dd02020
Update src/posit/connect/external/databricks.py
zackverham Dec 3, 2024
36634c7
responding to PR comments
zackverham Dec 3, 2024
d7ba5e6
fix comment typo
zackverham Dec 3, 2024
e146ddd
Wrap examples in markdown blocks
schloerke Dec 4, 2024
d58a743
Expose `external` on docs
schloerke Dec 4, 2024
b229b4b
updating docstrings in response to PR comments
zackverham Dec 4, 2024
3ec143d
Merge branch 'zack-client-creds' of github.com:posit-dev/posit-sdk-py…
zackverham Dec 4, 2024
89497bd
docstring polish - backticks
zackverham Dec 4, 2024
3f42531
fix linting issues
zackverham Dec 4, 2024
fb80f1d
Copy in existing example app and add title to each file
schloerke Dec 4, 2024
73569a3
add patched-in local strategy for client credential access tokens
zackverham Dec 5, 2024
91877ee
add implementing strategy that can use local credentials provider
zackverham Dec 5, 2024
c38fecd
missing params
zackverham Dec 5, 2024
4a1a512
commit to force develop dependency to update
zackverham Dec 5, 2024
54a47c9
add raise_for_status
zackverham Dec 5, 2024
144e895
adding docstrings, tests
zackverham Dec 5, 2024
2042938
sweep through docstrings in response to PR comments
zackverham Dec 6, 2024
59862c4
linting
zackverham Dec 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 97 additions & 21 deletions src/posit/connect/external/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@
from typing import Callable, Dict, Optional

from ..client import Client
from ..oauth import Credentials
from .external import is_local

"""
NOTE: These APIs are provided as a convenience and are subject to breaking changes:
https://github.com/databricks/databricks-sdk-py#interface-stability
"""

POSIT_OAUTH_INTEGRATION_AUTH_TYPE = "posit-oauth-integration"

# The Databricks SDK CredentialsProvider == Databricks SQL HeaderFactory
CredentialsProvider = Callable[[], Dict[str, str]]


class CredentialsStrategy(abc.ABC):
"""Maintain compatibility with the Databricks SQL/SDK client libraries.

Expand All @@ -29,29 +31,80 @@ def __call__(self, *args, **kwargs) -> CredentialsProvider:
raise NotImplementedError


def _new_bearer_authorization_header(credentials: Credentials) -> Dict[str, str]:
"""Helper to transform an Credentials object into the Bearer auth header consumed by databricks.

Raises
------
ValueError: If provided Credentials object does not contain an access token

Returns
-------
Dict[str, str]
"""
access_token = credentials.get("access_token")
if access_token is None:
raise ValueError("Missing value for field 'access_token' in credentials.")
return {"Authorization": f"Bearer {access_token}"}

def _get_auth_type(local_auth_type: str) -> str:
"""Returns the auth type currently in use.

The databricks-sdk client uses the configurated auth_type to create
a user-agent string which is used for attribution. We should only
overwrite the auth_type if we are using the PositCredentialsStrategy (non-local),
otherwise, we should return the auth_type of the configured local_strategy instead
to avoid breaking someone elses attribution.

https://github.com/databricks/databricks-sdk-py/blob/v0.29.0/databricks/sdk/config.py#L261-L269

NOTE: The databricks-sql client does not use auth_type to set the user-agent.
https://github.com/databricks/databricks-sql-python/blob/v3.3.0/src/databricks/sql/client.py#L214-L219

Returns
-------
str
"""
if is_local():
return local_auth_type

return POSIT_OAUTH_INTEGRATION_AUTH_TYPE



class PositContentCredentialsProvider:
"""CredentialsProvider implementation which initiates a credential exchange using a content-session-token."""

def __init__(self, client: Client):
self._client = client

def __call__(self) -> Dict[str, str]:
credentials = self._client.oauth.get_content_credentials()
return _new_bearer_authorization_header(credentials)


class PositCredentialsProvider:
"""CredentialsProvider implementation which initiates a credential exchange using a user-session-token."""

def __init__(self, client: Client, user_session_token: str):
self._client = client
self._user_session_token = user_session_token

def __call__(self) -> Dict[str, str]:
credentials = self._client.oauth.get_credentials(self._user_session_token)
access_token = credentials.get("access_token")
if access_token is None:
raise ValueError("Missing value for field 'access_token' in credentials.")
return {"Authorization": f"Bearer {access_token}"}
return _new_bearer_authorization_header(credentials)


class PositCredentialsStrategy(CredentialsStrategy):
class PositContentCredentialsStrategy(CredentialsStrategy):
"""CredentialsStrategy implementation which returns a PositContentCredentialsProvider when called."""

def __init__(
self,
local_strategy: CredentialsStrategy,
client: Optional[Client] = None,
user_session_token: Optional[str] = None,
):
self._local_strategy = local_strategy
self._client = client
self._user_session_token = user_session_token

def sql_credentials_provider(self, *args, **kwargs):
"""The sql connector attempts to call the credentials provider w/o any args.
Expand All @@ -66,23 +119,46 @@ def sql_credentials_provider(self, *args, **kwargs):
return lambda: self.__call__(*args, **kwargs)

def auth_type(self) -> str:
"""Returns the auth type currently in use.
return _get_auth_type(self._local_strategy.auth_type())

def __call__(self, *args, **kwargs) -> CredentialsProvider:
# If the content is not running on Connect then fall back to local_strategy
if is_local():
return self._local_strategy(*args, **kwargs)

if self._client is None:
self._client = Client()

The databricks-sdk client uses the configurated auth_type to create
a user-agent string which is used for attribution. We should only
overwrite the auth_type if we are using the PositCredentialsStrategy (non-local),
otherwise, we should return the auth_type of the configured local_strategy instead
to avoid breaking someone elses attribution.
return PositContentCredentialsProvider(self._client)

https://github.com/databricks/databricks-sdk-py/blob/v0.29.0/databricks/sdk/config.py#L261-L269

NOTE: The databricks-sql client does not use auth_type to set the user-agent.
https://github.com/databricks/databricks-sql-python/blob/v3.3.0/src/databricks/sql/client.py#L214-L219
class PositCredentialsStrategy(CredentialsStrategy):
"""CredentialsStrategy implementation which returns a PositContentCredentialsProvider when called."""

def __init__(
self,
local_strategy: CredentialsStrategy,
client: Optional[Client] = None,
user_session_token: Optional[str] = None,
):
self._local_strategy = local_strategy
self._client = client
self._user_session_token = user_session_token

def sql_credentials_provider(self, *args, **kwargs):
"""The sql connector attempts to call the credentials provider w/o any args.

The SQL client's `ExternalAuthProvider` is not compatible w/ the SDK's implementation of
`CredentialsProvider`, so create a no-arg lambda that wraps the args defined by the real caller.
This way we can pass in a databricks `Config` object required by most of the SDK's `CredentialsProvider`
implementations from where `sql.connect` is called.

https://github.com/databricks/databricks-sql-python/issues/148#issuecomment-2271561365
"""
if is_local():
return self._local_strategy.auth_type()
else:
return "posit-oauth-integration"
return lambda: self.__call__(*args, **kwargs)

def auth_type(self) -> str:
return _get_auth_type(self._local_strategy.auth_type())

def __call__(self, *args, **kwargs) -> CredentialsProvider:
# If the content is not running on Connect then fall back to local_strategy
Expand Down
1 change: 1 addition & 0 deletions src/posit/connect/oauth/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .oauth import Credentials as Credentials
from .oauth import OAuth as OAuth
44 changes: 39 additions & 5 deletions src/posit/connect/oauth/oauth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import os
from typing import Optional

from typing_extensions import TypedDict
Expand All @@ -8,12 +9,36 @@
from .integrations import Integrations
from .sessions import Sessions

GRANT_TYPE = "urn:ietf:params:oauth:grant-type:token-exchange"
USER_SESSION_TOKEN_TYPE = "urn:posit:connect:user-session-token"
CONTENT_SESSION_TOKEN_TYPE = "urn:posit:connect:content-session-token"

def _get_content_session_token() -> str:
"""Return the content session token.

Reads the environment variable 'CONNECT_CONTENT_SESSION_TOKEN'.

Raises
------
ValueError: If CONNECT_CONTENT_SESSION_TOKEN is not set or invalid

Returns
-------
str
"""
value = os.environ.get("CONNECT_CONTENT_SESSION_TOKEN")
if not value:
raise ValueError("Invalid value for 'CONNECT_CONTENT_SESSION_TOKEN': Must be a non-empty string.")
return value

class OAuth(Resources):
def __init__(self, params: ResourceParameters, api_key: str) -> None:
super().__init__(params)
self.api_key = api_key

def _get_credentials_url(self) -> str:
return self.params.url + "v1/oauth/integrations/credentials"

@property
def integrations(self):
return Integrations(self.params)
Expand All @@ -23,18 +48,27 @@ def sessions(self):
return Sessions(self.params)

def get_credentials(self, user_session_token: Optional[str] = None) -> Credentials:
url = self.params.url + "v1/oauth/integrations/credentials"

"""Perform an oauth credential exchange with a user-session-token."""
# craft a credential exchange request
data = {}
data["grant_type"] = "urn:ietf:params:oauth:grant-type:token-exchange"
data["subject_token_type"] = "urn:posit:connect:user-session-token"
data["grant_type"] = GRANT_TYPE
data["subject_token_type"] = USER_SESSION_TOKEN_TYPE
if user_session_token:
data["subject_token"] = user_session_token

response = self.params.session.post(url, data=data)
response = self.params.session.post(self._get_credentials_url(), data=data)
return Credentials(**response.json())

def get_content_credentials(self, content_session_token: Optional[str] = None) -> Credentials:
"""Perform an oauth credential exchange with a content-session-token."""
# craft a credential exchange request
data = {}
data["grant_type"] = GRANT_TYPE
data["subject_token_type"] = CONTENT_SESSION_TOKEN_TYPE
data["subject_token"] = content_session_token or _get_content_session_token()

response = self.params.session.post(self._get_credentials_url(), data=data)
return Credentials(**response.json())

class Credentials(TypedDict, total=False):
access_token: str
Expand Down
86 changes: 86 additions & 0 deletions tests/posit/connect/external/test_databricks.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
from typing import Dict
from unittest.mock import patch

import pytest
import responses

from posit.connect import Client
from posit.connect.external.databricks import (
POSIT_OAUTH_INTEGRATION_AUTH_TYPE,
CredentialsProvider,
CredentialsStrategy,
PositContentCredentialsProvider,
PositContentCredentialsStrategy,
PositCredentialsProvider,
PositCredentialsStrategy,
_get_auth_type,
_new_bearer_authorization_header,
)
from posit.connect.oauth import Credentials


class mock_strategy(CredentialsStrategy):
Expand Down Expand Up @@ -42,8 +49,59 @@ def register_mocks():
},
)

responses.post(
"https://connect.example/__api__/v1/oauth/integrations/credentials",
match=[
responses.matchers.urlencoded_params_matcher(
{
"grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
"subject_token_type": "urn:posit:connect:content-session-token",
"subject_token": "cit",
},
),
],
json={
"access_token": "content-access-token",
"issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
"token_type": "Bearer",
},
)




class TestPositCredentialsHelpers:

def test_new_bearer_authorization_header(self):
credential = Credentials()
credential["token_type"] = "token_type"
credential["issued_token_type"] = "issued_token_type"

with pytest.raises(ValueError):
_new_bearer_authorization_header(credential)

credential["access_token"] = "access_token"
result = _new_bearer_authorization_header(credential)
assert result == {"Authorization": "Bearer access_token"}

def test_get_auth_type_local(self):
assert _get_auth_type("local-auth") == "local-auth"


@patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"})
def test_get_auth_type_connect(self):
assert _get_auth_type("local-auth") == POSIT_OAUTH_INTEGRATION_AUTH_TYPE

@patch.dict("os.environ", {"CONNECT_CONTENT_SESSION_TOKEN": "cit"})
@responses.activate
def test_posit_content_credentials_provider(self):
register_mocks()

client = Client(api_key="12345", url="https://connect.example/")
client._ctx.version = None
cp = PositContentCredentialsProvider(client=client)
assert cp() == {"Authorization": "Bearer content-access-token"}

@responses.activate
def test_posit_credentials_provider(self):
register_mocks()
Expand All @@ -53,6 +111,23 @@ def test_posit_credentials_provider(self):
cp = PositCredentialsProvider(client=client, user_session_token="cit")
assert cp() == {"Authorization": "Bearer dynamic-viewer-access-token"}

@patch.dict("os.environ", {"CONNECT_CONTENT_SESSION_TOKEN": "cit"})
@responses.activate
@patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"})
def test_posit_content_credentials_strategy(self):
register_mocks()

client = Client(api_key="12345", url="https://connect.example/")
client._ctx.version = None
cs = PositContentCredentialsStrategy(
local_strategy=mock_strategy(),
client=client,
)
cp = cs()
assert cs.auth_type() == "posit-oauth-integration"
assert cp() == {"Authorization": "Bearer content-access-token"}


@responses.activate
@patch.dict("os.environ", {"RSTUDIO_PRODUCT": "CONNECT"})
def test_posit_credentials_strategy(self):
Expand All @@ -69,6 +144,17 @@ def test_posit_credentials_strategy(self):
assert cs.auth_type() == "posit-oauth-integration"
assert cp() == {"Authorization": "Bearer dynamic-viewer-access-token"}

def test_posit_content_credentials_strategy_fallback(self):
# local_strategy is used when the content is running locally
client = Client(api_key="12345", url="https://connect.example/")
cs = PositContentCredentialsStrategy(
local_strategy=mock_strategy(),
client=client,
)
cp = cs()
assert cs.auth_type() == "local"
assert cp() == {"Authorization": "Bearer static-pat-token"}

def test_posit_credentials_strategy_fallback(self):
# local_strategy is used when the content is running locally
client = Client(api_key="12345", url="https://connect.example/")
Expand Down
Loading
Loading