Skip to content

Commit

Permalink
implement PKCE
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-mkeller committed Jan 16, 2025
1 parent 6d025e5 commit dc687f4
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
23 changes: 23 additions & 0 deletions src/snowflake/connector/auth/oauth_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@

from __future__ import annotations

import base64
import hashlib
import json
import logging
import re
import secrets
import socket
import time
Expand Down Expand Up @@ -57,6 +60,7 @@ def __init__(
token_request_url: str,
redirect_uri: str,
scope: str,
pkce: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -74,6 +78,8 @@ def __init__(
logger.debug("chose oauth state: %s", self._state)
self._oauth_token = None
self._protocol = "http"
self.pkce = pkce
self._verifier: str | None = None

def reset_secrets(self) -> None:
self._oauth_token = None
Expand Down Expand Up @@ -104,6 +110,19 @@ def construct_url(self) -> str:
"scope": self.scope,
"state": self._state,
}
if self.pkce:
self._verifier = secrets.token_urlsafe(43)
self._verifier = re.sub("[^a-zA-Z0-9]+", "", self._verifier)
# calculate challenge and verifier
challenge = (
base64.urlsafe_b64encode(
hashlib.sha256(self._verifier.encode("utf-8")).digest()
)
.decode("utf-8")
.replace("=", "")
)
params["code_challenge"] = challenge
params["code_challenge_method"] = "S256"
url_params = urllib.parse.urlencode(params)
url = f"{self.authentication_url}?{url_params}"
return url
Expand Down Expand Up @@ -186,6 +205,10 @@ def prepare(
}
if self.client_secret:
fields["client_secret"] = self.client_secret
if self.pkce:
assert self._verifier is not None
fields["code_verifier"] = self._verifier

resp = urllib3.PoolManager().request_encode_body( # TODO: use network pool to gain use of proxy settings and so on
"POST",
self.token_request_url,
Expand Down
8 changes: 8 additions & 0 deletions src/snowflake/connector/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from __future__ import annotations

import atexit
import collections.abc
import logging
import os
import pathlib
Expand Down Expand Up @@ -338,6 +339,11 @@ def _get_private_bytes_from_file(
str,
# SNOW-1825621: OAUTH implementation
),
"oauth_security_features": (
("pkce",),
collections.abc.Iterable, # of strings
# SNOW-1825621: OAUTH PKCE
),
}

APPLICATION_RE = re.compile(r"[\w\d_]+")
Expand Down Expand Up @@ -1122,6 +1128,7 @@ def __open_connection(self):
backoff_generator=self._backoff_generator,
)
elif self._authenticator == OAUTH_AUTHORIZATION_CODE:
pkce = "pkce" in map(lambda e: e.lower(), self._oauth_security_features)
if self._client_id is None:
Error.errorhandler_wrapper(
self,
Expand Down Expand Up @@ -1154,6 +1161,7 @@ def __open_connection(self):
),
redirect_uri=self._oauth_redirect_uri,
scope=self._oauth_scope.format(role=self._role),
pkce=pkce,
)
elif self._authenticator == USR_PWD_MFA_AUTHENTICATOR:
self._session_parameters[PARAMETER_CLIENT_REQUEST_MFA_TOKEN] = (
Expand Down

0 comments on commit dc687f4

Please sign in to comment.