diff --git a/src/snowflake/connector/__init__.py b/src/snowflake/connector/__init__.py index 41b5288ac..b41f3c29d 100644 --- a/src/snowflake/connector/__init__.py +++ b/src/snowflake/connector/__init__.py @@ -11,6 +11,9 @@ import logging from logging import NullHandler +from typing import TYPE_CHECKING + +from typing_extensions import Unpack from snowflake.connector.externals_utils.externals_setup import setup_external_libraries @@ -45,13 +48,26 @@ from .log_configuration import EasyLoggingConfigPython from .version import VERSION +if TYPE_CHECKING: + from os import PathLike + + from .connection import SnowflakeConnectionConfig + logging.getLogger(__name__).addHandler(NullHandler()) setup_external_libraries() @wraps(SnowflakeConnection.__init__) -def Connect(**kwargs) -> SnowflakeConnection: - return SnowflakeConnection(**kwargs) +def Connect( + connection_name: str | None = None, + connections_file_path: PathLike[str] | None = None, + **kwargs: Unpack[SnowflakeConnectionConfig], +) -> SnowflakeConnection: + return SnowflakeConnection( + connection_name=connection_name, + connections_file_path=connections_file_path, + **kwargs, + ) connect = Connect diff --git a/src/snowflake/connector/backoff_policies.py b/src/snowflake/connector/backoff_policies.py index 8e6b1010b..7cf389c07 100644 --- a/src/snowflake/connector/backoff_policies.py +++ b/src/snowflake/connector/backoff_policies.py @@ -1,7 +1,7 @@ from __future__ import annotations import random -from typing import Callable, Iterator +from typing import Callable, Generator """This module provides common implementations of backoff policies @@ -38,7 +38,7 @@ def mixed_backoff( base: int = DEFAULT_BACKOFF_BASE, cap: int = DEFAULT_BACKOFF_CAP, enable_jitter: bool = DEFAULT_ENABLE_JITTER, -) -> Callable[..., Iterator[int]]: +) -> Callable[[], Generator[int]]: """Randomly chooses between exponential and constant backoff. Uses equal jitter. Args: @@ -52,7 +52,7 @@ def mixed_backoff( Callable: generator function implementing the mixed backoff policy """ - def generator(): + def generator() -> Generator[int]: cnt = 0 sleep = base @@ -80,7 +80,7 @@ def linear_backoff( base: int = DEFAULT_BACKOFF_BASE, cap: int = DEFAULT_BACKOFF_CAP, enable_jitter: bool = DEFAULT_ENABLE_JITTER, -) -> Callable[..., Iterator[int]]: +) -> Callable[[], Generator[int]]: """Standard linear backoff. Uses full jitter. Args: @@ -94,7 +94,7 @@ def linear_backoff( Callable: generator function implementing the linear backoff policy """ - def generator(): + def generator() -> Generator[int]: sleep = base yield sleep @@ -113,7 +113,7 @@ def exponential_backoff( base: int = DEFAULT_BACKOFF_BASE, cap: int = DEFAULT_BACKOFF_CAP, enable_jitter: bool = DEFAULT_ENABLE_JITTER, -) -> Callable[..., Iterator[int]]: +) -> Callable[[], Generator[int]]: """Standard exponential backoff. Uses full jitter. Args: @@ -127,7 +127,7 @@ def exponential_backoff( Callable: generator function implementing the exponential backoff policy """ - def generator(): + def generator() -> Generator[int]: sleep = base yield sleep diff --git a/src/snowflake/connector/connection.py b/src/snowflake/connector/connection.py index f89e6dfc0..574f29615 100644 --- a/src/snowflake/connector/connection.py +++ b/src/snowflake/connector/connection.py @@ -21,12 +21,23 @@ from logging import getLogger from threading import Lock from types import TracebackType -from typing import Any, Callable, Generator, Iterable, Iterator, NamedTuple, Sequence +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generator, + Iterable, + Iterator, + NamedTuple, + Sequence, + TypedDict, +) from uuid import UUID from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey +from typing_extensions import Unpack from . import errors, proxy from ._query_context_cache import QueryContextCache @@ -123,6 +134,10 @@ from .util_text import construct_hostname, parse_account, split_statements from .wif_util import AttestationProvider +if TYPE_CHECKING: + from os import PathLike + + DEFAULT_CLIENT_PREFETCH_THREADS = 4 MAX_CLIENT_PREFETCH_THREADS = 10 MAX_CLIENT_FETCH_THREADS = 1024 @@ -378,6 +393,53 @@ class TypeAndBinding(NamedTuple): binding: str | None +class SnowflakeConnectionConfig(TypedDict): + """Configuration type for the SnowflakeConnection.""" + + insecure_mode: bool + disable_ocsp_checks: bool + ocsp_fail_open: bool + session_id: int + user: str + host: str + port: int + region: str + proxy_host: str + proxy_port: str + proxy_user: str + proxy_password: str + account: str + database: str + schema: str + warehouse: str + role: str + login_timeout: int + network_timeout: int + socket_timeout: int + backoff_policy: Callable[[], Generator[int]] + client_session_keep_alive_heartbeat_frequency: int + client_prefetch_threads: int + client_fetch_threads: int + rest: SnowflakeRestful + application: str + errorhandler: Callable + converter_class: type[SnowflakeConverter] + validate_default_parameters: bool + is_pyformat: bool + consent_cache_id_token: str + enable_stage_s3_privatelink_for_us_east_1: bool + enable_connection_diag: bool + connection_diag_log_path: PathLike[str] | str + connection_diag_whitelist_path: PathLike[str] | str + connection_diag_allowlist_path: PathLike[str] | str + json_result_force_utf8_decoding: bool + server_session_keep_alive: bool + token_file_path: PathLike[str] | str + unsafe_file_write: bool + gcs_use_virtual_endpoints: bool + check_arrow_conversion_error_on_every_column: bool + + class SnowflakeConnection: """Implementation of the connection object for the Snowflake Database. @@ -448,8 +510,8 @@ class SnowflakeConnection: def __init__( self, connection_name: str | None = None, - connections_file_path: pathlib.Path | None = None, - **kwargs, + connections_file_path: PathLike[str] | None = None, + **kwargs: Unpack[SnowflakeConnectionConfig], ) -> None: """Create a new SnowflakeConnection. @@ -651,7 +713,7 @@ def socket_timeout(self) -> int | None: return int(self._socket_timeout) if self._socket_timeout is not None else None @property - def _backoff_generator(self) -> Iterator: + def _backoff_generator(self) -> Generator[int]: return self._backoff_policy() @property @@ -983,7 +1045,7 @@ def autocommit(self, mode) -> None: except Error as e: if e.sqlstate == SQLSTATE_FEATURE_NOT_SUPPORTED: logger.debug( - "Autocommit feature is not enabled for this " "connection. Ignored" + "Autocommit feature is not enabled for this connection. Ignored" ) def commit(self) -> None: @@ -1166,7 +1228,7 @@ def __open_connection(self): elif self._authenticator == EXTERNAL_BROWSER_AUTHENTICATOR: self._session_parameters[ PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL - ] = (self._client_store_temporary_credential if IS_LINUX else True) + ] = self._client_store_temporary_credential if IS_LINUX else True auth.read_temporary_credentials( self.host, self.user,