Add snowflake sqlalchemy implementation (#1216)
* add snowflake table implementation for sqlalchemy

* fmt

* fmt

* joined load

* bug

* joined 2 deep

* add logging in snowflake doc

* fmt

* fmt?

* remove unneeded

* rm trulens source install instructions

* remove legacy streamlit cache clear

* fmt

* add snowflake doc to mkdocs

---------

Co-authored-by: Piotr Mardziel <[email protected]>
sfc-gh-chu and sfc-gh-pmardziel authored Jun 21, 2024
1 parent 62896cc commit 98cb4a0
Showing 6 changed files with 86 additions and 9 deletions.
@@ -0,0 +1,50 @@
# ❄️ Logging in Snowflake

Snowflake’s fully managed [data warehouse](https://www.snowflake.com/en/data-cloud/workloads/data-warehouse/?utm_cta=website-homepage-workload-card-data-warehouse) provides automatic provisioning, availability, tuning, data protection and more—across clouds and regions—for an unlimited number of users and jobs.

TruLens can read from and write to a Snowflake database over a SQLAlchemy connection, so you can persist and share _TruLens_ logs in _Snowflake_.

Here is a _working_ guide to logging in _Snowflake_.

## Install the [Snowflake SQLAlchemy toolkit](https://docs.snowflake.com/en/developer-guide/python-connector/sqlalchemy) with the Python Connector

For now, we need to use a working branch of snowflake-sqlalchemy that supports sqlalchemy 2.0.

!!! example "Install Snowflake-SQLAlchemy"

    ```bash
    # Clone the Snowflake GitHub repo and enter it:
    git clone git@github.com:snowflakedb/snowflake-sqlalchemy.git
    cd snowflake-sqlalchemy

    # Check out the sqlalchemy branch:
    git checkout SNOW-1058245-sqlalchemy-20-support

    # Install hatch:
    pip install hatch

    # Build snowflake-sqlalchemy via hatch:
    python -m hatch build --clean

    # Install snowflake-sqlalchemy:
    pip install dist/*.whl
    ```

## Connect TruLens to the Snowflake database

!!! example "Connect TruLens to the Snowflake database"

    ```python
    from trulens_eval import Tru

    tru = Tru(database_url=(
        'snowflake://{user}:{password}@{account_identifier}/'
        '{database}/{schema}?warehouse={warehouse}&role={role}'
    ).format(
        user='<user>',
        password='<password>',
        account_identifier='<account-identifier>',
        database='<database>',
        schema='<schema>',
        warehouse='<warehouse>',
        role='<role>'
    ))
    ```
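
One practical caveat with the URL template above: a password containing characters such as `@` or `/` will be misparsed unless it is percent-encoded first. The following is a minimal sketch of how the URL could be assembled safely. The helper name `snowflake_url` and the sample credentials are hypothetical and are not part of TruLens or this commit.

```python
from urllib.parse import quote_plus


def snowflake_url(user, password, account_identifier, database, schema,
                  warehouse, role):
    """Hypothetical helper: build a Snowflake SQLAlchemy URL.

    The user and password are percent-encoded so that characters like
    '@' or '/' cannot be confused with URL delimiters.
    """
    return (
        'snowflake://{user}:{password}@{account_identifier}/'
        '{database}/{schema}?warehouse={warehouse}&role={role}'
    ).format(
        user=quote_plus(user),
        password=quote_plus(password),
        account_identifier=account_identifier,
        database=database,
        schema=schema,
        warehouse=warehouse,
        role=role,
    )


# Placeholder values for illustration only.
url = snowflake_url(
    user='alice',
    password='p@ss/word',
    account_identifier='myorg-myaccount',
    database='mydb',
    schema='myschema',
    warehouse='wh',
    role='analyst',
)
```

The resulting string could then be passed as `Tru(database_url=url)`, in the same way as the formatted URL in the guide.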
3 changes: 2 additions & 1 deletion mkdocs.yml
@@ -224,7 +224,8 @@ nav:
  - trulens_eval/tracking/instrumentation/nemo.ipynb
  - Logging:
  # PLACEHOLDER: - trulens_eval/tracking/logging/index.md
- - Where to Log: trulens_eval/tracking/logging/where_to_log.md
+ - Where to Log: trulens_eval/tracking/logging/where_to_log/index.md
+ - ❄️ Logging in Snowflake: trulens_eval/tracking/logging/where_to_log/log_in_snowflake.md
  - 📓 Logging Methods: trulens_eval/tracking/logging/logging.ipynb
  - 🔍 Guides:
  # PLACEHOLDER: - trulens_eval/guides/index.md
8 changes: 6 additions & 2 deletions trulens_eval/trulens_eval/database/orm.py
@@ -13,6 +13,7 @@
 from sqlalchemy import VARCHAR
 from sqlalchemy.ext.declarative import declared_attr
 from sqlalchemy.orm import backref
+from sqlalchemy.orm import configure_mappers
 from sqlalchemy.orm import declarative_base
 from sqlalchemy.orm import relationship
 from sqlalchemy.schema import MetaData
@@ -315,8 +316,11 @@ def parse(
                     multi_result=obj.multi_result
                 )

-    #configure_mappers()
-    #base.registry.configure()
+    configure_mappers()  # IMPORTANT
+    # Without the above, orm class attributes which are defined using backref
+    # will not be visible, i.e. orm.AppDefinition.records.

+    # base.registry.configure()

     return NewORM
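
The comment in the hunk above says that backref-defined attributes are invisible until `configure_mappers()` runs. A minimal sketch of that behavior, using simplified stand-in models (`App` and `Record` here are hypothetical and are not the TruLens ORM classes):

```python
from sqlalchemy import Column, ForeignKey, Integer
from sqlalchemy.orm import (
    backref,
    configure_mappers,
    declarative_base,
    relationship,
)

Base = declarative_base()


class App(Base):
    __tablename__ = "apps"
    id = Column(Integer, primary_key=True)


class Record(Base):
    __tablename__ = "records"
    id = Column(Integer, primary_key=True)
    app_id = Column(Integer, ForeignKey("apps.id"))
    # The backref asks SQLAlchemy to create `App.records`, but only once
    # the mappers are configured.
    app = relationship(App, backref=backref("records"))


# Before configuration, the backref attribute has not been installed yet.
visible_before = hasattr(App, "records")

configure_mappers()

# After configuration, `App.records` can be referenced, for example
# inside `joinedload(App.records)`.
visible_after = hasattr(App, "records")
```

Mapper configuration normally happens lazily on first use of a mapped class, so calling it explicitly matters mainly when a class-level attribute such as `orm.AppDefinition.records` is referenced before any query or instance has triggered it.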

32 changes: 28 additions & 4 deletions trulens_eval/trulens_eval/database/sqlalchemy.py
@@ -10,13 +10,15 @@
 )
 import warnings

+from alembic.ddl.impl import DefaultImpl
 import numpy as np
 import pandas as pd
 from pydantic import Field
 from sqlalchemy import create_engine
 from sqlalchemy import Engine
 from sqlalchemy import func
 from sqlalchemy import select
+from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.sql import text as sql_text

@@ -51,6 +53,10 @@
 logger = logging.getLogger(__name__)


+class SnowflakeImpl(DefaultImpl):
+    __dialect__ = 'snowflake'


 class SQLAlchemyDB(DB):
     """Database implemented using sqlalchemy.
@@ -571,11 +577,22 @@ def get_records_and_feedback(
     ) -> Tuple[pd.DataFrame, Sequence[str]]:
         """See [DB.get_records_and_feedback][trulens_eval.database.base.DB.get_records_and_feedback]."""

+        # TODO: Add pagination to this method. Currently the joinedload in
+        # select below disables lazy loading of records which will be a problem
+        # for large databases without the use of pagination.

         with self.session.begin() as session:
-            stmt = select(self.orm.AppDefinition)
+            stmt = select(self.orm.AppDefinition).options(
+                joinedload(self.orm.AppDefinition.records)\
+                .joinedload(self.orm.Record.feedback_results)
+            )

             if app_ids:
                 stmt = stmt.where(self.orm.AppDefinition.app_id.in_(app_ids))
-            apps = (row[0] for row in session.execute(stmt))

+            ex = session.execute(stmt).unique()  # unique needed for joinedload
+            apps = (row[0] for row in ex)

             return AppsExtractor().get_df_and_cols(apps)
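
The `# unique needed for joinedload` comment in the hunk above can be illustrated with a small, self-contained sketch against an in-memory SQLite database. It assumes SQLAlchemy 2.0-style execution, and the `App` and `Record` models are simplified hypothetical stand-ins for the TruLens ORM classes.

```python
from sqlalchemy import Column, ForeignKey, Integer, create_engine, select
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.orm import (
    declarative_base,
    joinedload,
    relationship,
    sessionmaker,
)

Base = declarative_base()


class App(Base):
    __tablename__ = "apps"
    id = Column(Integer, primary_key=True)
    records = relationship("Record")


class Record(Base):
    __tablename__ = "records"
    id = Column(Integer, primary_key=True)
    app_id = Column(Integer, ForeignKey("apps.id"))


engine = create_engine("sqlite://")  # in-memory database
Base.metadata.create_all(engine)
Session = sessionmaker(engine)

# Store one app with two records.
with Session.begin() as session:
    session.add(App(id=1, records=[Record(id=1), Record(id=2)]))

# Eagerly load the `records` collection with a JOIN in the same query.
stmt = select(App).options(joinedload(App.records))

# The JOIN yields one row per (app, record) pair, so SQLAlchemy refuses to
# return rows until they are de-duplicated with `.unique()`.
with Session() as session:
    try:
        session.execute(stmt).all()
        needs_unique = False
    except InvalidRequestError:
        needs_unique = True

# With `.unique()`, each app appears once with its records already loaded.
with Session() as session:
    apps = [row[0] for row in session.execute(stmt).unique()]
    record_counts = [len(app.records) for app in apps]
```

Because the joined load pulls every related row in a single query, result size grows with the number of records per app, which is the concern behind the pagination TODO in the hunk.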


@@ -721,29 +738,34 @@ def extract_apps(
                             # deserialize AppDefinition here unless we fix prior DBs
                             # in migration. Because of this, loading just the
                             # `root_class` here.

                             df[col] = str(
                                 Class.model_validate(
                                     json.loads(_app.app_json).get('root_class')
                                 )
                             )

                         else:
                             df[col] = getattr(_app, col)

                     yield df
             except OperationalError as e:
                 print(
-                    "Error encountered while attempting to retrieve an app. This issue may stem from a corrupted database."
+                    "Error encountered while attempting to retrieve an app. "
+                    "This issue may stem from a corrupted database."
                 )
                 print(f"Error details: {e}")

     def extract_records(self,
                         records: Iterable[orm.Record]) -> Iterable[pd.Series]:

         for _rec in records:
             calls = defaultdict(list)
             values = defaultdict(list)

             try:
                 for _res in _rec.feedback_results:

                     calls[_res.name].append(
                         json.loads(_res.calls_json)["calls"]
                     )
@@ -771,10 +793,12 @@ def extract_records(self,
                     ).isoformat() if col == "ts" else getattr(_rec, col)

                 yield row

             except Exception as e:
                 # Handling unexpected errors, possibly due to database issues.
                 print(
-                    "Error encountered while attempting to retrieve feedback results. This issue may stem from a corrupted database."
+                    "Error encountered while attempting to retrieve feedback results. "
+                    "This issue may stem from a corrupted database."
                 )
                 print(f"Error details: {e}")
2 changes: 0 additions & 2 deletions trulens_eval/trulens_eval/pages/Evaluations.py
@@ -45,8 +45,6 @@
 from trulens_eval.ux.components import write_or_json
 from trulens_eval.ux.styles import cellstyle_jscode

-st.runtime.legacy_caching.clear_cache()

 set_page_config(page_title="Evaluations")
 st.title("Evaluations")