Skip to content

Commit

Permalink
If the Snowflake connector supplied to us has a "pyformat" paramstyle…
Browse files Browse the repository at this point in the history
…, fail fast and give helpful advice to fix it. (#1637)
  • Loading branch information
sfc-gh-dkurokawa authored Nov 15, 2024
1 parent 959ba5e commit 4b7c920
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 4 deletions.
15 changes: 15 additions & 0 deletions src/connectors/snowflake/trulens/connectors/snowflake/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,20 @@ def _validate_snowpark_session_with_connection_parameters(
self.password_known = True
return snowpark_session_connection_parameters

@staticmethod
def _validate_snowpark_session_paramstyle(
snowpark_session: Session,
) -> None:
if snowpark_session.connection._paramstyle == "pyformat":
# If this is the case, sql executions with bindings will fail later
# on so we fail fast here.
raise ValueError(
"The Snowpark session must have paramstyle 'qmark'! To ensure"
" this, during `snowflake.connector.connect` pass in"
" `paramstyle='qmark'` or set"
" `snowflake.connector.paramstyle = 'qmark'` beforehand."
)

def _init_with_snowpark_session(
self,
snowpark_session: Session,
Expand All @@ -166,6 +180,7 @@ def _init_with_snowpark_session(
database_check_revision: bool,
connection_parameters: Dict[str, str],
):
self._validate_snowpark_session_paramstyle(snowpark_session)
database_args = self._set_up_database_args(
database_args,
snowpark_session,
Expand Down
34 changes: 34 additions & 0 deletions tests/e2e/test_snowflake_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from unittest import main
import uuid

import snowflake.connector
from snowflake.snowpark import Session
from trulens.connectors.snowflake import SnowflakeConnector
from trulens.dashboard import run_dashboard
from trulens.dashboard import stop_dashboard

Expand Down Expand Up @@ -68,6 +71,37 @@ def test_run_leaderboard_without_password(self):
except Exception:
pass

@optional_test
def test_paramstyle_pyformat(self):
default_paramstyle = snowflake.connector.paramstyle
try:
# pyformat paramstyle should fail fast.
snowflake.connector.paramstyle = "pyformat"
schema_name = self.create_and_use_schema(
"test_paramstyle_pyformat", append_uuid=True
)
snowflake_connection = snowflake.connector.connect(
**self._snowflake_connection_parameters, schema=schema_name
)
snowpark_session = Session.builder.configs({
"connection": snowflake_connection
}).create()
with self.assertRaisesRegex(
ValueError, "The Snowpark session must have paramstyle 'qmark'!"
):
SnowflakeConnector(snowpark_session=snowpark_session)
# qmark paramstyle should be fine.
snowflake.connector.paramstyle = "qmark"
snowflake_connection = snowflake.connector.connect(
**self._snowflake_connection_parameters, schema=schema_name
)
snowpark_session = Session.builder.configs({
"connection": snowflake_connection
}).create()
SnowflakeConnector(snowpark_session=snowpark_session)
finally:
snowflake.connector.paramstyle = default_paramstyle


if __name__ == "__main__":
main()
4 changes: 1 addition & 3 deletions tests/e2e/test_snowflake_notebooks.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import tempfile
from typing import Sequence
from unittest import main
import uuid

from trulens.connectors.snowflake.utils.server_side_evaluation_artifacts import (
_TRULENS_PACKAGES,
Expand All @@ -18,8 +17,7 @@

class TestSnowflakeNotebooks(SnowflakeTestCase):
def test_simple(self) -> None:
schema_name = f"test_simple_{str(uuid.uuid4()).replace('-', '_')}"
self.create_and_use_schema(schema_name)
self.create_and_use_schema("test_simple", append_uuid=True)
self._upload_and_run_notebook("simple", _TRULENS_PACKAGES)

def test_staged_packages(self) -> None:
Expand Down
11 changes: 10 additions & 1 deletion tests/util/snowflake_test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,10 +117,19 @@ def run_query(
) -> List[Row]:
return self._snowpark_session.sql(q, bindings).collect()

def create_and_use_schema(self, schema_name: str) -> None:
def create_and_use_schema(
self, schema_name: str, append_uuid: bool = False
) -> str:
schema_name = schema_name.upper()
if append_uuid:
schema_name = (
f"{schema_name}__{str(uuid.uuid4()).replace('-', '_')}"
)
self._schema = schema_name
self.run_query("CREATE SCHEMA IDENTIFIER(?)", [schema_name])
self._snowflake_schemas_to_delete.add(schema_name)
self._snowpark_session.use_schema(schema_name)
return schema_name


if __name__ == "__main__":
Expand Down

0 comments on commit 4b7c920

Please sign in to comment.