diff --git a/src/connectors/snowflake/trulens/connectors/snowflake/connector.py b/src/connectors/snowflake/trulens/connectors/snowflake/connector.py index ea4334dae..00f34e529 100644 --- a/src/connectors/snowflake/trulens/connectors/snowflake/connector.py +++ b/src/connectors/snowflake/trulens/connectors/snowflake/connector.py @@ -155,6 +155,20 @@ def _validate_snowpark_session_with_connection_parameters( self.password_known = True return snowpark_session_connection_parameters + @staticmethod + def _validate_snowpark_session_paramstyle( + snowpark_session: Session, + ) -> None: + if snowpark_session.connection._paramstyle == "pyformat": + # If this is the case, sql executions with bindings will fail later + # on so we fail fast here. + raise ValueError( + "The Snowpark session must have paramstyle 'qmark'! To ensure" + " this, during `snowflake.connector.connect` pass in" + " `paramstyle='qmark'` or set" + " `snowflake.connector.paramstyle = 'qmark'` beforehand." + ) + def _init_with_snowpark_session( self, snowpark_session: Session, @@ -166,6 +180,7 @@ def _init_with_snowpark_session( database_check_revision: bool, connection_parameters: Dict[str, str], ): + self._validate_snowpark_session_paramstyle(snowpark_session) database_args = self._set_up_database_args( database_args, snowpark_session, diff --git a/tests/e2e/test_snowflake_connection.py b/tests/e2e/test_snowflake_connection.py index a5d6fdca1..b437c4fb0 100644 --- a/tests/e2e/test_snowflake_connection.py +++ b/tests/e2e/test_snowflake_connection.py @@ -5,6 +5,9 @@ from unittest import main import uuid +import snowflake.connector +from snowflake.snowpark import Session +from trulens.connectors.snowflake import SnowflakeConnector from trulens.dashboard import run_dashboard from trulens.dashboard import stop_dashboard @@ -68,6 +71,37 @@ def test_run_leaderboard_without_password(self): except Exception: pass + @optional_test + def test_paramstyle_pyformat(self): + default_paramstyle = snowflake.connector.paramstyle + try: + # pyformat paramstyle should fail fast. + snowflake.connector.paramstyle = "pyformat" + schema_name = self.create_and_use_schema( + "test_paramstyle_pyformat", append_uuid=True + ) + snowflake_connection = snowflake.connector.connect( + **self._snowflake_connection_parameters, schema=schema_name + ) + snowpark_session = Session.builder.configs({ + "connection": snowflake_connection + }).create() + with self.assertRaisesRegex( + ValueError, "The Snowpark session must have paramstyle 'qmark'!" + ): + SnowflakeConnector(snowpark_session=snowpark_session) + # qmark paramstyle should be fine. + snowflake.connector.paramstyle = "qmark" + snowflake_connection = snowflake.connector.connect( + **self._snowflake_connection_parameters, schema=schema_name + ) + snowpark_session = Session.builder.configs({ + "connection": snowflake_connection + }).create() + SnowflakeConnector(snowpark_session=snowpark_session) + finally: + snowflake.connector.paramstyle = default_paramstyle + if __name__ == "__main__": main() diff --git a/tests/e2e/test_snowflake_notebooks.py b/tests/e2e/test_snowflake_notebooks.py index 015e823e8..a656f67ee 100644 --- a/tests/e2e/test_snowflake_notebooks.py +++ b/tests/e2e/test_snowflake_notebooks.py @@ -1,7 +1,6 @@ import tempfile from typing import Sequence from unittest import main -import uuid from trulens.connectors.snowflake.utils.server_side_evaluation_artifacts import ( _TRULENS_PACKAGES, @@ -18,8 +17,7 @@ class TestSnowflakeNotebooks(SnowflakeTestCase): def test_simple(self) -> None: - schema_name = f"test_simple_{str(uuid.uuid4()).replace('-', '_')}" - self.create_and_use_schema(schema_name) + self.create_and_use_schema("test_simple", append_uuid=True) self._upload_and_run_notebook("simple", _TRULENS_PACKAGES) def test_staged_packages(self) -> None: diff --git a/tests/util/snowflake_test_case.py b/tests/util/snowflake_test_case.py index 53a2f5f43..da1905185 100644 --- a/tests/util/snowflake_test_case.py +++ b/tests/util/snowflake_test_case.py @@ -117,10 +117,19 @@ def run_query( ) -> List[Row]: return self._snowpark_session.sql(q, bindings).collect() - def create_and_use_schema(self, schema_name: str) -> None: + def create_and_use_schema( + self, schema_name: str, append_uuid: bool = False + ) -> str: + schema_name = schema_name.upper() + if append_uuid: + schema_name = ( + f"{schema_name}__{str(uuid.uuid4()).replace('-', '_')}" + ) + self._schema = schema_name self.run_query("CREATE SCHEMA IDENTIFIER(?)", [schema_name]) self._snowflake_schemas_to_delete.add(schema_name) self._snowpark_session.use_schema(schema_name) + return schema_name if __name__ == "__main__":