diff --git a/src/connectors/snowflake/trulens/connectors/snowflake/connector.py b/src/connectors/snowflake/trulens/connectors/snowflake/connector.py
index 078a8fbae..5db2bbdd2 100644
--- a/src/connectors/snowflake/trulens/connectors/snowflake/connector.py
+++ b/src/connectors/snowflake/trulens/connectors/snowflake/connector.py
@@ -47,6 +47,7 @@ def __init__(
         database_prefix: Optional[str] = None,
         database_args: Optional[Dict[str, Any]] = None,
         database_check_revision: bool = True,
+        host: Optional[str] = None,
     ):
         connection_parameters = {
             "account": account,
@@ -56,6 +57,7 @@
             "schema": schema,
             "warehouse": warehouse,
             "role": role,
+            **({"host": host} if host else {}),
         }
 
         if snowpark_session is None:
diff --git a/src/dashboard/trulens/dashboard/run.py b/src/dashboard/trulens/dashboard/run.py
index e2d1682a0..a4d5f1633 100644
--- a/src/dashboard/trulens/dashboard/run.py
+++ b/src/dashboard/trulens/dashboard/run.py
@@ -35,6 +35,7 @@ def run_dashboard(
     address: Optional[str] = None,
     force: bool = False,
     _dev: Optional[Path] = None,
+    spcs_runtime: Optional[bool] = False,
     _watch_changes: bool = False,
 ) -> Process:
     """Run a streamlit dashboard to view logged results and apps.
@@ -120,7 +121,8 @@
         "--database-prefix",
         session.connector.db.table_prefix,
     ]
-
+    if spcs_runtime:
+        args.append("--spcs-runtime")
     proc = subprocess.Popen(
         args,
         stdout=subprocess.PIPE,
diff --git a/src/dashboard/trulens/dashboard/streamlit.py b/src/dashboard/trulens/dashboard/streamlit.py
index 80d29e60e..05887f750 100644
--- a/src/dashboard/trulens/dashboard/streamlit.py
+++ b/src/dashboard/trulens/dashboard/streamlit.py
@@ -32,6 +32,15 @@ class FeedbackDisplay(BaseModel):
     icon: str
 
 
+def get_spcs_login_token():
+    """
+    Read the login token supplied automatically by Snowflake. These tokens
+    are short lived and should always be read right before creating any new connection.
+    """
+    with open("/snowflake/session/token", "r") as f:
+        return f.read()
+
+
 def init_from_args():
     """Parse command line arguments and initialize Tru with them.
 
@@ -43,6 +52,7 @@
     parser.add_argument(
         "--database-prefix", default=core_db.DEFAULT_DATABASE_PREFIX
     )
+    parser.add_argument("--spcs-runtime", default=False)
 
     try:
         args = parser.parse_args()
@@ -54,9 +64,28 @@
         # so we have to do a hard exit.
         sys.exit(e.code)
 
-    core_session.TruSession(
-        database_url=args.database_url, database_prefix=args.database_prefix
-    )
+    if args.spcs_runtime:
+        import os
+
+        from snowflake.snowpark import Session
+        from trulens.connectors.snowflake import SnowflakeConnector
+
+        connection_params = {
+            "account": os.environ.get("SNOWFLAKE_ACCOUNT"),
+            "host": os.getenv("SNOWFLAKE_HOST"),
+            "authenticator": "oauth",
+            "token": get_spcs_login_token(),
+            "warehouse": os.environ.get("SNOWFLAKE_WAREHOUSE"),
+            "database": os.environ.get("SNOWFLAKE_DATABASE"),
+            "schema": os.environ.get("SNOWFLAKE_SCHEMA"),
+        }
+        snowpark_session = Session.builder.configs(connection_params).create()
+        connector = SnowflakeConnector(snowpark_session=snowpark_session)
+        core_session.TruSession(connector=connector)
+    else:
+        core_session.TruSession(
+            database_url=args.database_url, database_prefix=args.database_prefix
+        )
 
 
 def trulens_leaderboard(app_ids: Optional[List[str]] = None):
diff --git a/tools/snowflake/spcs_dashboard/Dockerfile b/tools/snowflake/spcs_dashboard/Dockerfile
new file mode 100644
index 000000000..c6a3a1b0e
--- /dev/null
+++ b/tools/snowflake/spcs_dashboard/Dockerfile
@@ -0,0 +1,12 @@
+ARG BASE_IMAGE=python:3.11.9-slim-bullseye
+FROM $BASE_IMAGE
+
+COPY ./ /trulens_dashboard/
+
+WORKDIR /trulens_dashboard
+
+RUN pip install -r requirements.txt
+RUN pip install trulens_connectors_snowflake-1.0.1-py3-none-any.whl
+RUN pip install trulens_dashboard-1.0.1-py3-none-any.whl
+
+CMD ["python", "run_dashboard.py"]
diff --git a/tools/snowflake/spcs_dashboard/requirements.txt b/tools/snowflake/spcs_dashboard/requirements.txt
new file mode 100644
index 000000000..a0b0d392e
--- /dev/null
+++ b/tools/snowflake/spcs_dashboard/requirements.txt
@@ -0,0 +1,9 @@
+python-dotenv
+pydantic
+snowflake[ml]
+snowflake-connector-python
+snowflake-sqlalchemy
+trulens
+trulens-connectors-snowflake
+# trulens-dashboard
+# trulens-feedback
diff --git a/tools/snowflake/spcs_dashboard/run_container.py b/tools/snowflake/spcs_dashboard/run_container.py
new file mode 100644
index 000000000..3a424126f
--- /dev/null
+++ b/tools/snowflake/spcs_dashboard/run_container.py
@@ -0,0 +1,145 @@
+from argparse import ArgumentParser
+
+from snowflake.snowpark import Session
+
+# get args from command line
+parser = ArgumentParser(description="Run container script")
+parser.add_argument(
+    "--build-docker",
+    action="store_true",
+    help="Build and push the Docker container",
+)
+args = parser.parse_args()
+
+session = Session.builder.create()
+account = session.get_current_account()
+user = session.get_current_user()
+database = session.get_current_database()
+schema = session.get_current_schema()
+warehouse = session.get_current_warehouse()
+role = session.get_current_role()
+
+
+def run_sql_command(command: str):
+    print(f"Running SQL command: {command}")
+    result = session.sql(command).collect()
+    print(f"Result: {result}")
+    return result
+
+
+# Check if the image repository exists, if not create it
+repository_name = "TRULENS_REPOSITORY"
+images = session.sql("SHOW IMAGE REPOSITORIES").collect()
+repository_exists = any(image["name"] == repository_name for image in images)
+
+if not repository_exists:
+    session.sql(f"CREATE IMAGE REPOSITORY {repository_name}").collect()
+    print(f"Image repository {repository_name} created.")
+else:
+    print(f"Image repository {repository_name} already exists.")
+
+# Retrieve the repository URL
+repository_url = (
+    session.sql(f"SHOW IMAGE REPOSITORIES LIKE '{repository_name}'")
+    .select('"repository_url"')
+    .collect()[0]["repository_url"]
+)
+
+image_name = "trulens_dashboard"
+image_tag = "latest"
+app_name = "trulens_dashboard"
+container_name = app_name + "_container"
+if args.build_docker:
+    # local build, with docker
+    import subprocess
+
+    subprocess.run(
+        [
+            "docker",
+            "build",
+            "--platform",
+            "linux/amd64",
+            "-t",
+            f"{repository_url}/{image_name}:{image_tag}",
+            ".",
+        ],
+        check=True,
+    )
+    subprocess.run(
+        ["docker", "push", f"{repository_url}/{image_name}:{image_tag}"],
+        check=True,
+    )
+
+
+# Create compute pool if it does not exist
+compute_pool_name = input("Enter compute pool name: ")
+compute_pools = session.sql("SHOW COMPUTE POOLS").collect()
+compute_pool_exists = any(
+    pool["name"] == compute_pool_name.upper() for pool in compute_pools
+)
+if compute_pool_exists:
+    print(f"Compute pool {compute_pool_name} already exists")
+else:
+    session.sql(
+        f"CREATE COMPUTE POOL {compute_pool_name} MIN_NODES = 1 MAX_NODES = 1 INSTANCE_FAMILY = CPU_X64_M"
+    ).collect()
+session.sql(f"DESCRIBE COMPUTE POOL {compute_pool_name}").collect()
+
+# Create network rule
+network_rule_name = f"{compute_pool_name}_allow_http_https"
+session.sql(
+    f"CREATE OR REPLACE NETWORK RULE {network_rule_name} TYPE = 'HOST_PORT' MODE = 'EGRESS' VALUE_LIST = ('0.0.0.0:443','0.0.0.0:80')"
+).collect()
+session.sql("SHOW NETWORK RULES").collect()
+
+# Create external access integration
+access_integration_name = f"{compute_pool_name}_access_integration"
+session.sql(
+    f"CREATE OR REPLACE EXTERNAL ACCESS INTEGRATION {access_integration_name} ALLOWED_NETWORK_RULES = ({network_rule_name}) ENABLED = true"
+).collect()
+session.sql("SHOW EXTERNAL ACCESS INTEGRATIONS").collect()
+
+service_name = compute_pool_name + "_trulens_dashboard"
+session.sql(
+    """
+CREATE SERVICE {service_name}
+  IN COMPUTE POOL {compute_pool_name}
+  EXTERNAL_ACCESS_INTEGRATIONS = ({access_integration_name})
+  FROM SPECIFICATION $$
+    spec:
+      containers:
+      - name: trulens-dashboard
+        image: /{database}/{schema}/{repository_name}/{app_name}:latest
+        env:
+          SNOWFLAKE_ACCOUNT: "{account}"
+          SNOWFLAKE_DATABASE: "{database}"
+          SNOWFLAKE_SCHEMA: "{schema}"
+          SNOWFLAKE_WAREHOUSE: "{warehouse}"
+          SNOWFLAKE_ROLE: "{role}"
+          RUN_DASHBOARD: "1"
+        endpoints:
+        - name: trulens-demo-dashboard-endpoint
+          port: 8484
+          public: true
+  $$
+""".format(
+        service_name=service_name,
+        compute_pool_name=compute_pool_name,
+        access_integration_name=access_integration_name,
+        repository_name=repository_name,
+        account=account,
+        database=database,
+        schema=schema,
+        warehouse=warehouse,
+        role=role,
+        app_name=app_name,
+    )
+).collect()
+
+# Show services and get their status
+run_sql_command(f"SHOW ENDPOINTS IN SERVICE {service_name}")
+run_sql_command(f"CALL SYSTEM$GET_SERVICE_STATUS('{service_name}')")
+run_sql_command(f"CALL SYSTEM$GET_SERVICE_STATUS('{service_name}')")
+
+# Close the session
+session.close()
diff --git a/tools/snowflake/spcs_dashboard/run_dashboard.py b/tools/snowflake/spcs_dashboard/run_dashboard.py
new file mode 100644
index 000000000..06e80348b
--- /dev/null
+++ b/tools/snowflake/spcs_dashboard/run_dashboard.py
@@ -0,0 +1,33 @@
+import os
+
+from snowflake.snowpark import Session
+from trulens.connectors.snowflake import SnowflakeConnector
+from trulens.core import TruSession
+from trulens.dashboard import run_dashboard
+
+
+def get_login_token():
+    """
+    Read the login token supplied automatically by Snowflake. These tokens
+    are short lived and should always be read right before creating any new connection.
+    """
+    with open("/snowflake/session/token", "r") as f:
+        return f.read()
+
+
+connection_params = {
+    "account": os.environ.get("SNOWFLAKE_ACCOUNT"),
+    "host": os.getenv("SNOWFLAKE_HOST"),
+    "authenticator": "oauth",
+    "token": get_login_token(),
+    "warehouse": os.environ.get("SNOWFLAKE_WAREHOUSE"),
+    "database": os.environ.get("SNOWFLAKE_DATABASE"),
+    "schema": os.environ.get("SNOWFLAKE_SCHEMA"),
+}
+snowpark_session = Session.builder.configs(connection_params).create()
+
+connector = SnowflakeConnector(snowpark_session=snowpark_session)
+tru_session = TruSession(connector=connector)
+tru_session.get_records_and_feedback()
+
+run_dashboard(tru_session, port=8484, spcs_runtime=True)