From 354cf54013b816267af7a51c75dd963ee715e26f Mon Sep 17 00:00:00 2001 From: David Kurokawa Date: Tue, 29 Oct 2024 21:40:40 -0700 Subject: [PATCH] Don't open extra Snowflake connections and don't recycle connections as quickly. (#1609) --- .../trulens/connectors/snowflake/connector.py | 75 ++++++++++++------- 1 file changed, 49 insertions(+), 26 deletions(-) diff --git a/src/connectors/snowflake/trulens/connectors/snowflake/connector.py b/src/connectors/snowflake/trulens/connectors/snowflake/connector.py index adbd7a1e6..f125ad9fb 100644 --- a/src/connectors/snowflake/trulens/connectors/snowflake/connector.py +++ b/src/connectors/snowflake/trulens/connectors/snowflake/connector.py @@ -166,32 +166,12 @@ def _init_with_snowpark_session( database_check_revision: bool, connection_parameters: Dict[str, str], ): - database_args = database_args or {} - if "engine_params" not in database_args: - database_args["engine_params"] = {} - if "creator" in database_args["engine_params"]: - raise ValueError( - "Cannot set `database_args['engine_params']['creator']!" - ) - database_args["engine_params"]["creator"] = ( - lambda: snowpark_session.connection - ) - if "paramstyle" in database_args["engine_params"]: - raise ValueError( - "Cannot set `database_args['engine_params']['paramstyle']!" - ) - database_args["engine_params"]["paramstyle"] = "qmark" - - database_args.update({ - k: v - for k, v in { - "database_url": URL(**connection_parameters), - "database_redact_keys": database_redact_keys, - }.items() - if v is not None - }) - database_args["database_prefix"] = ( - database_prefix or core_db.DEFAULT_DATABASE_PREFIX + database_args = self._set_up_database_args( + database_args, + snowpark_session, + connection_parameters, + database_redact_keys, + database_prefix, ) self._db: Union[SQLAlchemyDB, python_utils.OpaqueWrapper] = ( SQLAlchemyDB.from_tru_args(**database_args) @@ -233,6 +213,49 @@ def _init_with_snowpark_session( print(f"Set TruLens workspace version tag: {res}") + def _set_up_database_args( + self, + database_args: Dict[str, Any], + snowpark_session: Session, + connection_parameters: Dict[str, str], + database_redact_keys: bool, + database_prefix: Optional[str], + ) -> Dict[str, Any]: + database_args = database_args or {} + # Set engine_params. + default_engine_params = { + "creator": lambda: snowpark_session.connection, + "paramstyle": "qmark", + # The following parameters ensure the pool does not allocate new + # connections that it will close. This is a problem because the + # "creator" does not create new connections, it only passes around + # the single one it has. + "max_overflow": 0, + "pool_recycle": -1, + "pool_timeout": 120, + } + if "engine_params" not in database_args: + database_args["engine_params"] = default_engine_params + else: + for k, v in default_engine_params.items(): + if k in database_args["engine_params"]: + raise ValueError( + f"Cannot set `database_args['engine_params']['{k}']!" + ) + # Set remaining parameters. + database_args.update({ + k: v + for k, v in { + "database_url": URL(**connection_parameters), + "database_redact_keys": database_redact_keys, + }.items() + if v is not None + }) + database_args["database_prefix"] = ( + database_prefix or core_db.DEFAULT_DATABASE_PREFIX + ) + return database_args + @staticmethod def _run_query( snowpark_session: Session,