diff --git a/src/snowflake/connector/auth/oauth_code.py b/src/snowflake/connector/auth/oauth_code.py index 7ef63713e..46a990903 100644 --- a/src/snowflake/connector/auth/oauth_code.py +++ b/src/snowflake/connector/auth/oauth_code.py @@ -4,8 +4,11 @@ from __future__ import annotations +import base64 +import hashlib import json import logging +import re import secrets import socket import time @@ -57,6 +60,7 @@ def __init__( token_request_url: str, redirect_uri: str, scope: str, + pkce: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) @@ -74,6 +78,8 @@ def __init__( logger.debug("chose oauth state: %s", self._state) self._oauth_token = None self._protocol = "http" + self.pkce = pkce + self._verifier: str | None = None def reset_secrets(self) -> None: self._oauth_token = None @@ -104,6 +110,19 @@ def construct_url(self) -> str: "scope": self.scope, "state": self._state, } + if self.pkce: + self._verifier = secrets.token_urlsafe(43) + self._verifier = re.sub("[^a-zA-Z0-9]+", "", self._verifier) + # calculate challenge and verifier + challenge = ( + base64.urlsafe_b64encode( + hashlib.sha256(self._verifier.encode("utf-8")).digest() + ) + .decode("utf-8") + .replace("=", "") + ) + params["code_challenge"] = challenge + params["code_challenge_method"] = "S256" url_params = urllib.parse.urlencode(params) url = f"{self.authentication_url}?{url_params}" return url @@ -186,6 +205,10 @@ def prepare( } if self.client_secret: fields["client_secret"] = self.client_secret + if self.pkce: + assert self._verifier is not None + fields["code_verifier"] = self._verifier + resp = urllib3.PoolManager().request_encode_body( # TODO: use network pool to gain use of proxy settings and so on "POST", self.token_request_url, diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index 733187790..83782d836 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -6,6 +6,7 @@ from __future__ import annotations import atexit +import collections.abc import logging import os import pathlib @@ -338,6 +339,11 @@ def _get_private_bytes_from_file( str, # SNOW-1825621: OAUTH implementation ), + "oauth_security_features": ( + ("pkce",), + collections.abc.Iterable, # of strings + # SNOW-1825621: OAUTH PKCE + ), } APPLICATION_RE = re.compile(r"[\w\d_]+") @@ -1122,6 +1128,7 @@ def __open_connection(self): backoff_generator=self._backoff_generator, ) elif self._authenticator == OAUTH_AUTHORIZATION_CODE: + pkce = "pkce" in map(lambda e: e.lower(), self._oauth_security_features) if self._client_id is None: Error.errorhandler_wrapper( self, @@ -1154,6 +1161,7 @@ def __open_connection(self): ), redirect_uri=self._oauth_redirect_uri, scope=self._oauth_scope.format(role=self._role), + pkce=pkce, ) elif self._authenticator == USR_PWD_MFA_AUTHENTICATOR: self._session_parameters[PARAMETER_CLIENT_REQUEST_MFA_TOKEN] = (