diff --git a/.azure_pipelines/ci-eval.yaml b/.azure_pipelines/ci-eval.yaml index 484923ac8..e35cbc367 100644 --- a/.azure_pipelines/ci-eval.yaml +++ b/.azure_pipelines/ci-eval.yaml @@ -98,10 +98,16 @@ jobs: cd ./trulens_eval pip install pytest==7.0.1 pytest-azurepipelines pip install -r trulens_eval/requirements.optional.txt + displayName: Install optional deps - echo '::group::piplist' + - bash: | + source activate $(condaEnvFileSuffix) echo "$(pip list)" - echo '::endgroup::' + displayName: Pip list + + - bash: | + source activate $(condaEnvFileSuffix) + cd ./trulens_eval python -m pytest $(testSubdirectory) displayName: Run notebook tests diff --git a/docs/trulens_eval/llama_index_instrumentation.ipynb b/docs/trulens_eval/llama_index_instrumentation.ipynb index fb9635f7d..b855c2725 100644 --- a/docs/trulens_eval/llama_index_instrumentation.ipynb +++ b/docs/trulens_eval/llama_index_instrumentation.ipynb @@ -63,7 +63,7 @@ "tru_query_engine_recorder = TruLlama(query_engine)\n", "\n", "with tru_query_engine_recorder as recording:\n", - " llm_response = query_engine.query(\"What did the author do growing up?\")" + " print(query_engine.query(\"What did the author do growing up?\"))" ] }, { @@ -142,7 +142,8 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index import VectorStoreIndex, SimpleWebPageReader\n", + "from llama_index import VectorStoreIndex\n", + "from llama_index.readers.web import SimpleWebPageReader\n", "from trulens_eval import TruLlama\n", "\n", "documents = SimpleWebPageReader(html_to_text=True).load_data(\n", @@ -201,7 +202,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.4" + "version": "3.11.6" }, "orig_nbformat": 4, "vscode": { diff --git a/docs/trulens_eval/logging.ipynb b/docs/trulens_eval/logging.ipynb index 35a12bc90..6f07697ea 100644 --- a/docs/trulens_eval/logging.ipynb +++ b/docs/trulens_eval/logging.ipynb @@ -21,6 +21,39 @@ "metadata": {}, "outputs": [], "source": [ + "from IPython.display import JSON\n", + "\n", + "# Imports main tools:\n", + "from trulens_eval import Feedback\n", + "from trulens_eval import Huggingface\n", + "from trulens_eval import Tru\n", + "from trulens_eval import TruChain\n", + "from trulens_eval.schema import FeedbackResult\n", + "\n", + "tru = Tru()\n", + "\n", + "Tru().migrate_database()\n", + "\n", + "from langchain.chains import LLMChain\n", + "from langchain.llms import OpenAI\n", + "from langchain.prompts import ChatPromptTemplate\n", + "from langchain.prompts import HumanMessagePromptTemplate\n", + "from langchain.prompts import PromptTemplate\n", + "\n", + "full_prompt = HumanMessagePromptTemplate(\n", + " prompt=PromptTemplate(\n", + " template=\n", + " \"Provide a helpful response with relevant background information for the following: {prompt}\",\n", + " input_variables=[\"prompt\"],\n", + " )\n", + ")\n", + "\n", + "chat_prompt_template = ChatPromptTemplate.from_messages([full_prompt])\n", + "\n", + "llm = OpenAI(temperature=0.9, max_tokens=128)\n", + "\n", + "chain = LLMChain(llm=llm, prompt=chat_prompt_template, verbose=True)\n", + "\n", "truchain = TruChain(\n", " chain,\n", " app_id='Chain1_ChatApplication',\n", @@ -37,6 +70,21 @@ "Feedback functions can also be logged automatically by providing them in a list to the feedbacks arg." ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize Huggingface-based feedback function collection class:\n", + "hugs = Huggingface()\n", + "\n", + "# Define a language match feedback function using HuggingFace.\n", + "f_lang_match = Feedback(hugs.language_match).on_input_output()\n", + "# By default this will check language match on the main app input and main app\n", + "# output." + ] + }, { "cell_type": "code", "execution_count": null, @@ -147,9 +195,11 @@ "outputs": [], "source": [ "thumb_result = True\n", - "tru.add_feedback(name=\"👍 (1) or 👎 (0)\", \n", - " record_id=record.record_id, \n", - " result=thumb_result)" + "tru.add_feedback(\n", + " name=\"👍 (1) or 👎 (0)\", \n", + " record_id=record.record_id, \n", + " result=thumb_result\n", + ")" ] }, { @@ -177,7 +227,8 @@ " record=record,\n", " feedback_functions=[f_lang_match]\n", ")\n", - "display(feedback_results)" + "for result in feedback_results:\n", + " display(result)" ] }, { @@ -225,9 +276,11 @@ " feedback_mode=\"deferred\"\n", ")\n", "\n", + "with truchain:\n", + " chain(\"This will be logged by deferred evaluator.\")\n", + "\n", "tru.start_evaluator()\n", - "truchain(\"This will be logged by deferred evaluator.\")\n", - "tru.stop_evaluator()" + "# tru.stop_evaluator()" ] } ], @@ -247,7 +300,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.3" + "version": "3.8.16" } }, "nbformat": 4, diff --git a/trulens_eval/DEPRECATION.md b/trulens_eval/DEPRECATION.md index 92dce5d2e..79aa559f8 100644 --- a/trulens_eval/DEPRECATION.md +++ b/trulens_eval/DEPRECATION.md @@ -3,10 +3,13 @@ ## Changes in 0.19.0 - Migrated from pydantic v1 to v2 incurring various changes. -- `ObjSerial` class removed. `Obj` now indicate whether they are loadable when - `init_bindings` is not None. - `SingletonPerName` field `instances` renamed to `_instances` due to possible shadowing of `instances` field in subclassed models. + +### Breaking DB changes (migration script should be able to take care of these) + +- `ObjSerial` class removed. `Obj` now indicate whether they are loadable when + `init_bindings` is not None. - `WithClassInfo` field `__tru_class_info` renamed to `tru_class_info` as pydantic does not allow underscore fields. diff --git a/trulens_eval/examples/experimental/dashboard_appui.ipynb b/trulens_eval/examples/experimental/dashboard_appui.ipynb index 11e7edd93..a8255a8fa 100644 --- a/trulens_eval/examples/experimental/dashboard_appui.ipynb +++ b/trulens_eval/examples/experimental/dashboard_appui.ipynb @@ -247,7 +247,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.8.16" } }, "nbformat": 4, diff --git a/trulens_eval/examples/experimental/db_populate.ipynb b/trulens_eval/examples/experimental/db_populate.ipynb new file mode 100644 index 000000000..8b1ebf420 --- /dev/null +++ b/trulens_eval/examples/experimental/db_populate.ipynb @@ -0,0 +1,383 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# DB Populate Notebook\n", + "\n", + "This notebook populates the database with a variety of apps, records, and\n", + "feedback results. It is used primarily for database migration testing." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "from pathlib import Path\n", + "import sys\n", + "\n", + "# If running from github repo, can use this:\n", + "sys.path.append(str(Path().cwd().parent.parent.resolve()))\n", + "\n", + "# Enables: Debugging printouts.\n", + "\"\"\"\n", + "import logging\n", + "root = logging.getLogger()\n", + "root.setLevel(logging.DEBUG)\n", + "\n", + "handler = logging.StreamHandler(sys.stdout)\n", + "handler.setLevel(logging.DEBUG)\n", + "formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')\n", + "handler.setFormatter(formatter)\n", + "root.addHandler(handler)\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# ! pip install llama_index==0.9.15.post2\n", + "# ! pip install pydantic==2.5.2 pydantic_core==2.14.5" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# To test out DB migrations, copy one of the older db dumps to this folder first:\n", + "\n", + "! ls ../../release_dbs/\n", + "! cp ../../release_dbs/0.1.2/default.sqlite default.sqlite\n", + "# ! rm default.sqlite" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from concurrent.futures import as_completed\n", + "import json\n", + "import os\n", + "from pathlib import Path\n", + "from time import sleep\n", + "\n", + "import dotenv\n", + "from tqdm.auto import tqdm\n", + "\n", + "from trulens_eval import Feedback\n", + "from trulens_eval import Tru\n", + "from trulens_eval.feedback.provider.endpoint.base import Endpoint\n", + "from trulens_eval.feedback.provider.hugs import Dummy\n", + "from trulens_eval.schema import Cost\n", + "from trulens_eval.schema import FeedbackMode\n", + "from trulens_eval.schema import Record\n", + "from trulens_eval.tru_custom_app import TruCustomApp\n", + "from trulens_eval.utils.threading import TP" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Setup Tru and/or dashboard.\n", + "\n", + "tru = Tru(database_redact_keys=True)\n", + "\n", + "# tru.reset_database()\n", + "\n", + "tru.start_dashboard(\n", + " force = True,\n", + " _dev=Path().cwd().parent.parent.resolve()\n", + ")\n", + "\n", + "Tru().migrate_database()\n", + "\n", + "from trulens_eval.database.migrations.db_data_migration import _sql_alchemy_serialization_asserts\n", + "_sql_alchemy_serialization_asserts(tru.db)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Feedbacks" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Dummy endpoint\n", + "\n", + "dummy = Dummy(\n", + " loading_prob=0.1,\n", + " freeze_prob=0.0, # we expect requests to have their own timeouts so freeze should never happen\n", + " error_prob=0.01,\n", + " overloaded_prob=0.1,\n", + " rpm=6000\n", + ")\n", + "\n", + "f_lang_match_dummy = Feedback(\n", + " dummy.language_match\n", + ").on_input_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Huggingface endpoint\n", + "from trulens_eval import Huggingface\n", + "\n", + "hugs = Huggingface()\n", + "\n", + "f_lang_match_hugs = Feedback(hugs.language_match).on_input_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# import inspect\n", + "# inspect.signature(Huggingface).bind()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Openai endpoint\n", + "from trulens_eval import OpenAI\n", + "openai = OpenAI()\n", + "\n", + "f_relevance_openai = Feedback(openai.relevance).on_input_output()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Bedrock endpoint\n", + "# Cohere as endpoint\n", + "# Langchain as endpoint\n", + "# Litellm as endpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "feedbacks = [f_lang_match_hugs, f_lang_match_dummy, f_relevance_openai]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Langchain app" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.llms import OpenAI\n", + "from langchain.chains import ConversationChain\n", + "from langchain.memory import ConversationSummaryBufferMemory\n", + "\n", + "llm = OpenAI(temperature=0.9, max_tokens=128)\n", + "\n", + "# Conversation memory.\n", + "memory = ConversationSummaryBufferMemory(\n", + " k=4,\n", + " max_token_limit=64,\n", + " llm=llm,\n", + ")\n", + "\n", + "# Conversational app puts it all together.\n", + "app_langchain = ConversationChain(\n", + " llm=llm,\n", + " memory=memory\n", + ")\n", + "\n", + "from langchain.prompts import PromptTemplate\n", + "from trulens_eval.instruments import instrument\n", + "instrument.method(PromptTemplate, \"format\")\n", + "\n", + "truchain = tru.Chain(app_langchain, app_id=\"langchain_app\", feedbacks=feedbacks)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with truchain as recs:\n", + " print(app_langchain(\"Hello?\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Llama-index app" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index import VectorStoreIndex\n", + "from llama_index.readers.web import SimpleWebPageReader\n", + "\n", + "documents = SimpleWebPageReader(\n", + " html_to_text=True\n", + ").load_data([\"http://paulgraham.com/worked.html\"])\n", + "index = VectorStoreIndex.from_documents(documents)\n", + "\n", + "query_engine = index.as_query_engine()\n", + "\n", + "trullama = tru.Llama(query_engine, app_id=\"llama_index_app\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with trullama as recs:\n", + " print(query_engine.query(\"Who is the author?\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Basic app" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from trulens_eval.tru_custom_app import instrument\n", + "\n", + "def custom_application(prompt: str) -> str:\n", + " return f\"a useful response to {prompt}\"\n", + "\n", + "trubasic = tru.Basic(custom_application, app_id=\"basic_app\", feedbacks=feedbacks)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with trubasic as recs:\n", + " print(trubasic.app(\"hello?\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Custom app" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from examples.expositional.end2end_apps.custom_app.custom_app import CustomApp # our custom app\n", + "\n", + "# Create custom app:\n", + "app_custom = CustomApp()\n", + "\n", + "# Create trulens wrapper:\n", + "trucustom = tru.Custom(\n", + " app=app_custom,\n", + " app_id=\"custom_app\",\n", + " \n", + " # Make sure to specify using the bound method, bound to self=app.\n", + " main_method=app_custom.respond_to_query,\n", + " feedbacks=feedbacks\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with trucustom as recs:\n", + " print(app_custom.respond_to_query(\"hello there\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "py38_trulens", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/trulens_eval/examples/expositional/frameworks/llama_index/llama_index_groundtruth.ipynb b/trulens_eval/examples/expositional/frameworks/llama_index/llama_index_groundtruth.ipynb index 1ec82aea3..a6fd88d71 100644 --- a/trulens_eval/examples/expositional/frameworks/llama_index/llama_index_groundtruth.ipynb +++ b/trulens_eval/examples/expositional/frameworks/llama_index/llama_index_groundtruth.ipynb @@ -37,7 +37,8 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index import VectorStoreIndex, SimpleWebPageReader\n", + "from llama_index import VectorStoreIndex\n", + "from llama_index.readers.web import SimpleWebPageReader\n", "import openai\n", "\n", "from trulens_eval import TruLlama, Feedback, Tru, feedback, FeedbackMode\n", diff --git a/trulens_eval/examples/expositional/vector-dbs/milvus/milvus_evals_build_better_rags.ipynb b/trulens_eval/examples/expositional/vector-dbs/milvus/milvus_evals_build_better_rags.ipynb index 71490860c..bc03f9dd0 100644 --- a/trulens_eval/examples/expositional/vector-dbs/milvus/milvus_evals_build_better_rags.ipynb +++ b/trulens_eval/examples/expositional/vector-dbs/milvus/milvus_evals_build_better_rags.ipynb @@ -73,10 +73,10 @@ "from llama_index.llms import OpenAI\n", "from llama_index import (\n", " VectorStoreIndex,\n", - " SimpleWebPageReader,\n", " LLMPredictor,\n", " ServiceContext\n", ")\n", + "from llama_index.readers.web import SimpleWebPageReader\n", "\n", "from langchain.embeddings import HuggingFaceEmbeddings\n", "from langchain.embeddings.openai import OpenAIEmbeddings\n", diff --git a/trulens_eval/examples/expositional/vector-dbs/milvus/milvus_simple.ipynb b/trulens_eval/examples/expositional/vector-dbs/milvus/milvus_simple.ipynb index a37955c46..54a89c1b9 100644 --- a/trulens_eval/examples/expositional/vector-dbs/milvus/milvus_simple.ipynb +++ b/trulens_eval/examples/expositional/vector-dbs/milvus/milvus_simple.ipynb @@ -75,10 +75,10 @@ "from llama_index.llms import OpenAI\n", "from llama_index import (\n", " VectorStoreIndex,\n", - " SimpleWebPageReader,\n", " LLMPredictor,\n", " ServiceContext\n", ")\n", + "from llama_index.readers.web import SimpleWebPageReader\n", "\n", "from trulens_eval import TruLlama, Feedback, Tru, feedback\n", "from trulens_eval.feedback import GroundTruthAgreement, Groundedness\n", diff --git a/trulens_eval/examples/expositional/vector-dbs/pinecone/pinecone_quickstart.ipynb b/trulens_eval/examples/expositional/vector-dbs/pinecone/pinecone_quickstart.ipynb index 35a38959e..8878cb7ec 100644 --- a/trulens_eval/examples/expositional/vector-dbs/pinecone/pinecone_quickstart.ipynb +++ b/trulens_eval/examples/expositional/vector-dbs/pinecone/pinecone_quickstart.ipynb @@ -73,10 +73,10 @@ "from llama_index.llms import OpenAI\n", "from llama_index import (\n", " VectorStoreIndex,\n", - " SimpleWebPageReader,\n", " LLMPredictor,\n", " ServiceContext\n", ")\n", + "from llama_index.readers.web import SimpleWebPageReader\n", "\n", "from trulens_eval import TruLlama, Feedback, Tru, feedback\n", "from trulens_eval.feedback import GroundTruthAgreement, Groundedness\n", diff --git a/trulens_eval/release_dbs/0.19.0/default.sqlite b/trulens_eval/release_dbs/0.19.0/default.sqlite new file mode 100644 index 000000000..c06f3549f Binary files /dev/null and b/trulens_eval/release_dbs/0.19.0/default.sqlite differ diff --git a/trulens_eval/tests/integration/test_database.py b/trulens_eval/tests/integration/test_database.py index 310b51bc1..dc9a75485 100644 --- a/trulens_eval/tests/integration/test_database.py +++ b/trulens_eval/tests/integration/test_database.py @@ -149,12 +149,12 @@ def test_migrate_legacy_sqlite_file(self): self.assertTrue(DbRevisions.load(db.engine).in_sync) # check that database is usable and no data was lost - self.assertEqual(db.get_app(app.app_id), json.loads(app.json())) + self.assertEqual(db.get_app(app.app_id), json.loads(app.model_dump_json())) df_recs, fb_cols = db.get_records_and_feedback([app.app_id]) self.assertTrue( set(df_recs.columns).issuperset(set(AppsExtractor.app_cols)) ) - self.assertEqual(df_recs["record_json"][0], rec.json()) + self.assertEqual(df_recs["record_json"][0], rec.model_dump_json()) self.assertEqual(list(fb_cols), [fb.name]) df_fb = db.get_feedback(record_id=rec.record_id) @@ -163,7 +163,7 @@ def test_migrate_legacy_sqlite_file(self): df_defs = db.get_feedback_defs( feedback_definition_id=fb.feedback_definition_id ) - self.assertEqual(df_defs["feedback_json"][0], json.loads(fb.json())) + self.assertEqual(df_defs["feedback_json"][0], json.loads(fb.model_dump_json())) class MockFeedback(Provider): diff --git a/trulens_eval/trulens_eval/app.py b/trulens_eval/trulens_eval/app.py index 223d6a491..cab1c0a34 100644 --- a/trulens_eval/trulens_eval/app.py +++ b/trulens_eval/trulens_eval/app.py @@ -686,9 +686,9 @@ def json(self, *args, **kwargs): self, *args, instrument=self.instrument, **kwargs ) - def model_dump(self): + def model_dump(self, redact_keys: bool = False): # Same problem as in json. - return jsonify(self, instrument=self.instrument) + return jsonify(self, instrument=self.instrument, redact_keys=redact_keys) # For use as a context manager. def __enter__(self): diff --git a/trulens_eval/trulens_eval/database/migrations/db_data_migration.py b/trulens_eval/trulens_eval/database/migrations/db_data_migration.py index f2f76b129..f30fc62b5 100644 --- a/trulens_eval/trulens_eval/database/migrations/db_data_migration.py +++ b/trulens_eval/trulens_eval/database/migrations/db_data_migration.py @@ -107,19 +107,19 @@ def _sql_alchemy_serialization_asserts(db: "DB") -> None: pass if attr_name == "record_json": - Record(**test_json) + Record.model_validate(test_json) elif attr_name == "cost_json": - Cost(**test_json) + Cost.model_validate(test_json) elif attr_name == "perf_json": - Perf(**test_json) + Perf.model_validate(test_json) elif attr_name == "calls_json": for record_app_call_json in test_json[ 'calls']: - FeedbackCall(**record_app_call_json) + FeedbackCall.model_validate(record_app_call_json) elif attr_name == "feedback_json": - FeedbackDefinition(**test_json) + FeedbackDefinition.model_validate(test_json) elif attr_name == "app_json": - AppDefinition(**test_json) + AppDefinition.model_validate(test_json) else: # If this happens, trulens needs to add a migration raise VersionException( diff --git a/trulens_eval/trulens_eval/database/orm.py b/trulens_eval/trulens_eval/database/orm.py index b055def4a..224502027 100644 --- a/trulens_eval/trulens_eval/database/orm.py +++ b/trulens_eval/trulens_eval/database/orm.py @@ -33,7 +33,7 @@ def parse( redact_keys: bool = False ) -> "AppDefinition": return cls( - app_id=obj.app_id, app_json=obj.json(redact_keys=redact_keys) + app_id=obj.app_id, app_json=obj.model_dump_json(redact_keys=redact_keys) ) diff --git a/trulens_eval/trulens_eval/database/sqlalchemy_db.py b/trulens_eval/trulens_eval/database/sqlalchemy_db.py index b586bdada..71c422a39 100644 --- a/trulens_eval/trulens_eval/database/sqlalchemy_db.py +++ b/trulens_eval/trulens_eval/database/sqlalchemy_db.py @@ -35,6 +35,7 @@ from trulens_eval.schema import FeedbackResultID from trulens_eval.schema import FeedbackResultStatus from trulens_eval.schema import RecordID +from trulens_eval.utils.pyschema import Class from trulens_eval.utils.serial import JSON from trulens_eval.utils.text import UNICODE_CHECK from trulens_eval.utils.text import UNICODE_CLOCK @@ -172,7 +173,7 @@ def insert_app(self, app: schema.AppDefinition) -> schema.AppID: if _app := session.query(orm.AppDefinition ).filter_by(app_id=app.app_id).first(): - _app.app_json = app.json() + _app.app_json = app.model_dump_json() else: _app = orm.AppDefinition.parse( app, redact_keys=self.redact_keys @@ -192,7 +193,7 @@ def insert_feedback_definition( if _fb_def := session.query(orm.FeedbackDefinition) \ .filter_by(feedback_definition_id=feedback_definition.feedback_definition_id) \ .first(): - _fb_def.app_json = feedback_definition.json() + _fb_def.app_json = feedback_definition.model_dump_json() else: _fb_def = orm.FeedbackDefinition.parse( feedback_definition, redact_keys=self.redact_keys @@ -298,6 +299,9 @@ def get_records_and_feedback( apps = (row[0] for row in session.execute(stmt)) return AppsExtractor().get_df_and_cols(apps) +# Use this Perf for missing Perfs. +# TODO: Migrate the database instead. +no_perf = schema.Perf(start_time=datetime.min, end_time=datetime.min).model_dump() def _extract_feedback_results( results: Iterable[orm.FeedbackResult] @@ -305,7 +309,8 @@ def _extract_feedback_results( def _extract(_result: orm.FeedbackResult): app_json = json.loads(_result.record.app.app_json) - _type = schema.AppDefinition(**app_json).root_class + _type = schema.AppDefinition.model_validate(app_json).root_class + return ( _result.record_id, _result.feedback_result_id, @@ -317,9 +322,9 @@ def _extract(_result: orm.FeedbackResult): _result.result, _result.multi_result, _result.cost_json, - json.loads(_result.record.perf_json), + json.loads(_result.record.perf_json) if _result.record.perf_json != MIGRATION_UNKNOWN_STR else no_perf, json.loads(_result.calls_json)["calls"], - json.loads(_result.feedback_definition.feedback_json), + json.loads(_result.feedback_definition.feedback_json) if _result.feedback_definition is not None else None, json.loads(_result.record.record_json), app_json, _type, @@ -361,7 +366,7 @@ def _extract(perf_json: Union[str, dict, schema.Perf]) -> int: if isinstance(perf_json, str): perf_json = json.loads(perf_json) if isinstance(perf_json, dict): - perf_json = schema.Perf(**perf_json) + perf_json = schema.Perf.model_validate(perf_json) if isinstance(perf_json, schema.Perf): return perf_json.latency.seconds raise ValueError(f"Failed to parse perf_json: {perf_json}") @@ -418,11 +423,11 @@ def extract_apps( for col in self.app_cols: if col == "type": - df[col] = str( - schema.AppDefinition.model_validate_json( - _app.app_json - ).root_class - ) + # Previous DBs did not contain entire app so we cannot + # deserialize AppDefinition here unless we fix prior DBs + # in migration. Because of this, loading just the + # `root_class` here. + df[col] = str(Class.model_validate(json.loads(_app.app_json).get('root_class'))) else: df[col] = getattr(_app, col) diff --git a/trulens_eval/trulens_eval/db.py b/trulens_eval/trulens_eval/db.py index 68ee6b92d..d858119be 100644 --- a/trulens_eval/trulens_eval/db.py +++ b/trulens_eval/trulens_eval/db.py @@ -382,7 +382,7 @@ def insert_record( # DB requirement def insert_app(self, app: AppDefinition) -> AppID: app_id = app.app_id - app_str = app.json() + app_str = app.model_dump_json() vals = (app_id, app_str) self._insert_or_replace_vals(table=self.TABLE_APPS, vals=vals) @@ -399,7 +399,7 @@ def insert_feedback_definition( """ feedback_definition_id = feedback.feedback_definition_id - feedback_str = feedback.json() + feedback_str = feedback.model_dump_json() vals = (feedback_definition_id, feedback_str) self._insert_or_replace_vals(table=self.TABLE_FEEDBACK_DEFS, vals=vals) diff --git a/trulens_eval/trulens_eval/db_migration.py b/trulens_eval/trulens_eval/db_migration.py index 2a3443420..ba96dc7e2 100644 --- a/trulens_eval/trulens_eval/db_migration.py +++ b/trulens_eval/trulens_eval/db_migration.py @@ -1,19 +1,30 @@ # This is pre-sqlalchemy db migration. This file should not need changes. It is here for backwards compatibility of oldest trulens-eval versions. import json +import logging import shutil import traceback -from typing import List +from typing import Callable, List import uuid +import pydantic from tqdm import tqdm +from trulens_eval.feedback.feedback import Feedback from trulens_eval.schema import AppDefinition from trulens_eval.schema import Cost from trulens_eval.schema import FeedbackCall from trulens_eval.schema import FeedbackDefinition from trulens_eval.schema import Perf from trulens_eval.schema import Record +from trulens_eval.utils.pyschema import Class +from trulens_eval.utils.pyschema import CLASS_INFO from trulens_eval.utils.pyschema import FunctionOrMethod +from trulens_eval.utils.pyschema import Method +from trulens_eval.utils.pyschema import Module +from trulens_eval.utils.pyschema import Obj + +logger = logging.getLogger(__name__) + ''' How to make a db migrations: @@ -55,7 +66,7 @@ class VersionException(Exception): MIGRATION_UNKNOWN_STR = "unknown[db_migration]" -migration_versions: List[str] = ["0.9.0", "0.3.0", "0.2.0", "0.1.2"] +migration_versions: List[str] = ["0.19.0", "0.9.0", "0.3.0", "0.2.0", "0.1.2"] def _update_db_json_col( @@ -76,6 +87,158 @@ def _update_db_json_col( db._insert_or_replace_vals(table=table, vals=migrate_record) +def jsonlike_map(fval=None, fkey=None, fkeyval=None): + if fval is None: + fval = lambda x:x + if fkey is None: + fkey = lambda x:x + if fkeyval is None: + fkeyval = lambda x,y: (x,y) + + def walk(obj): + if isinstance(obj, dict): + ret = {} + for k, v in obj.items(): + k = fkey(k) + v = fval(walk(v)) + k, v = fkeyval(k, v) + ret[k] = v + return fval(ret) + + if isinstance(obj, (list, tuple)): + return fval(type(obj)(fval(walk(v)) for v in obj)) + + else: + return fval(obj) + + return walk + +def jsonlike_rename_key(old_key, new_key) -> Callable: + def fkey(k): + if k == old_key: + logger.debug(f"key {old_key} -> {new_key}") + return new_key + else: + return k + + return jsonlike_map(fkey=fkey) + + +def jsonlike_rename_value(old_val, new_val) -> Callable: + def fval(v): + if v == old_val: + logger.debug(f"value {old_val} -> {new_val}") + return new_val + else: + return v + + return jsonlike_map(fval=fval) + + +class UnknownClass(pydantic.BaseModel): + def unknown_method(self): + """ + This is a placeholder put into the database in place of methods whose + information was not recorded in earlier versions of trulens. + """ + +def migrate_0_9_0(db): + rename_classinfo = jsonlike_rename_key("__tru_class_info", "tru_class_info") + rename_objserial = jsonlike_rename_value("ObjSerial", "Obj") + + def migrate_misc(obj): + + # Old Method format: + if isinstance(obj, dict) and "module_name" in obj and "method_name" in obj: + logger.debug(f"migrating RecordAppCallMethod {obj}") + # example: {'module_name': 'langchain.chains.llm', 'class_name': 'LLMChain', 'method_name': '_call'} + return Method( + obj = Obj( + cls = Class( + name=obj['class_name'], + module=Module(module_name=obj['module_name']) + ), + id=0 + ), + name = obj['method_name'] + ).model_dump() + + else: + return obj + + dummy_methods = jsonlike_map(fval=migrate_misc) + + all_migrate = lambda obj: dummy_methods(rename_classinfo(rename_objserial(obj))) + + conn, c = db._connect() + c.execute( + f"""SELECT * FROM records""" + ) # Use hardcode names as versions could go through name change + rows = c.fetchall() + json_db_col_idx = 4 + + for old_entry in tqdm(rows, desc="Migrating Records DB 0.9.0 to 0.19.0"): + new_json = all_migrate(json.loads(old_entry[json_db_col_idx])) + + _update_db_json_col( + db=db, + table= + "records", # Use hardcode names as versions could go through name change + old_entry=old_entry, + json_db_col_idx=json_db_col_idx, + new_json=new_json + ) + + c.execute(f"""SELECT * FROM feedback_defs""") + rows = c.fetchall() + json_db_col_idx = 1 + for old_entry in tqdm(rows, desc="Migrating FeedbackDefs DB 0.9.0 to 0.19.0"): + new_json = all_migrate(json.loads(old_entry[json_db_col_idx])) + + if CLASS_INFO not in new_json: + new_json[CLASS_INFO] = Class.of_class(Feedback).model_dump() + logger.debug(f"adding '{CLASS_INFO}'") + + if "initial_app_loader" not in new_json: + new_json['initial_app_loader'] = None + logger.debug(f"adding 'initial_app_loader'") + + if "initial_app_loader_dump" not in new_json: + new_json['initial_app_loader_dump'] = None + logger.debug(f"adding 'initial_app_loader_dump'") + + _update_db_json_col( + db=db, + table="feedback_defs", + old_entry=old_entry, + json_db_col_idx=json_db_col_idx, + new_json=new_json + ) + + c.execute(f"""SELECT * FROM apps""") + rows = c.fetchall() + json_db_col_idx = 1 + for old_entry in tqdm(rows, desc="Migrating Apps DB 0.9.0 to 0.19.0"): + new_json = all_migrate(json.loads(old_entry[json_db_col_idx])) + + if CLASS_INFO not in new_json: + new_json[CLASS_INFO] = Class.of_class(AppDefinition).model_dump() + logger.debug(f"adding `{CLASS_INFO}`") + + if "app" not in new_json: + new_json['app'] = dict() + logger.debug(f"adding `app`") + + _update_db_json_col( + db=db, + table="apps", + old_entry=old_entry, + json_db_col_idx=json_db_col_idx, + new_json=new_json + ) + + conn.commit() + def migrate_0_3_0(db): conn, c = db._connect() c.execute(f"""ALTER TABLE feedbacks @@ -240,7 +403,8 @@ def migrate_0_1_2(db): #"from_version":("to_version", migrate_method) "0.1.2": ("0.2.0", migrate_0_1_2), "0.2.0": ("0.3.0", migrate_0_2_0), - "0.3.0": ("0.9.0", migrate_0_3_0) + "0.3.0": ("0.9.0", migrate_0_3_0), + "0.9.0": ("0.19.0", migrate_0_9_0) } @@ -375,7 +539,9 @@ def _check_needs_migration(version: str, warn=False) -> None: def _serialization_asserts(db) -> None: - """After a successful migration, Do some checks if serialized jsons are loading properly + """ + After a successful migration, Do some checks if serialized jsons are loading + properly. Args: db (DB): the db object @@ -383,7 +549,12 @@ def _serialization_asserts(db) -> None: global saved_db_locations conn, c = db._connect() SAVED_DB_FILE_LOC = saved_db_locations[db.filename] - validation_fail_advice = f"Please open a ticket on trulens github page including details on the old and new trulens versions. The migration completed so you can still proceed; but stability is not guaranteed. Your original DB file is saved here: {SAVED_DB_FILE_LOC} and can be used with the previous version, or you can `tru.reset_database()`" + validation_fail_advice = ( + f"Please open a ticket on trulens github page including details on the old and new trulens versions. " + f"The migration completed so you can still proceed; but stability is not guaranteed. " + f"Your original DB file is saved here: {SAVED_DB_FILE_LOC} and can be used with the previous version, or you can `tru.reset_database()`" + ) + for table in db.TABLES: c.execute(f"""PRAGMA table_info({table}); """) @@ -415,18 +586,18 @@ def _serialization_asserts(db) -> None: pass if col_name == "record_json": - Record(**test_json) + Record.model_validate(test_json) elif col_name == "cost_json": - Cost(**test_json) + Cost.model_validate(test_json) elif col_name == "perf_json": - Perf(**test_json) + Perf.model_validate(test_json) elif col_name == "calls_json": for record_app_call_json in test_json['calls']: - FeedbackCall(**record_app_call_json) + FeedbackCall.model_validate(record_app_call_json) elif col_name == "feedback_json": - FeedbackDefinition(**test_json) + FeedbackDefinition.model_validate(test_json) elif col_name == "app_json": - AppDefinition(**test_json) + AppDefinition.model_validate(test_json) else: # If this happens, trulens needs to add a migration diff --git a/trulens_eval/trulens_eval/feedback/feedback.py b/trulens_eval/trulens_eval/feedback/feedback.py index e5f156f35..cc939cd75 100644 --- a/trulens_eval/trulens_eval/feedback/feedback.py +++ b/trulens_eval/trulens_eval/feedback/feedback.py @@ -252,6 +252,14 @@ def prepare_feedback(row): app_json = row.app_json + if row.get("feedback_json") is None: + logger.warning( + "Cannot evaluate feedback without `feedback_json`. " + "This might have come from an old database. \n" + f"{row}" + ) + return None, None + feedback = Feedback.model_validate(row.feedback_json) return feedback, feedback.run_and_log( diff --git a/trulens_eval/trulens_eval/feedback/groundtruth.py b/trulens_eval/trulens_eval/feedback/groundtruth.py index ecadff80f..4eee7c552 100644 --- a/trulens_eval/trulens_eval/feedback/groundtruth.py +++ b/trulens_eval/trulens_eval/feedback/groundtruth.py @@ -27,7 +27,7 @@ class GroundTruthAgreement(SerialModel, WithClassInfo): """Measures Agreement against a Ground Truth. """ - ground_truth: Union[List[str], FunctionOrMethod] + ground_truth: Union[List[Dict], FunctionOrMethod] provider: Provider # Note: the bert scorer object isn't serializable # It's a class member because creating it is expensive diff --git a/trulens_eval/trulens_eval/feedback/provider/base.py b/trulens_eval/trulens_eval/feedback/provider/base.py index 583151579..85d00dd67 100644 --- a/trulens_eval/trulens_eval/feedback/provider/base.py +++ b/trulens_eval/trulens_eval/feedback/provider/base.py @@ -20,7 +20,7 @@ class Config: endpoint: Optional[Endpoint] = None - def __init__(self, name: str = None, **kwargs): + def __init__(self, name: Optional[str] = None, **kwargs): # for WithClassInfo: kwargs['obj'] = self diff --git a/trulens_eval/trulens_eval/feedback/provider/endpoint/base.py b/trulens_eval/trulens_eval/feedback/provider/endpoint/base.py index e2e7fedf4..aa8eaa8dc 100644 --- a/trulens_eval/trulens_eval/feedback/provider/endpoint/base.py +++ b/trulens_eval/trulens_eval/feedback/provider/endpoint/base.py @@ -413,7 +413,6 @@ async def atrack_all_costs( for endpoint in Endpoint.ENDPOINT_SETUPS: if locals().get(endpoint.arg_flag): - print(f"tracking {endpoint.class_name}") mod = __import__( endpoint.module_name, fromlist=[endpoint.class_name] ) diff --git a/trulens_eval/trulens_eval/feedback/provider/endpoint/openai.py b/trulens_eval/trulens_eval/feedback/provider/endpoint/openai.py index b539c0cf8..21fe778a8 100644 --- a/trulens_eval/trulens_eval/feedback/provider/endpoint/openai.py +++ b/trulens_eval/trulens_eval/feedback/provider/endpoint/openai.py @@ -279,7 +279,7 @@ def handle_wrapped_call( if not counted_something: logger.warning( - f"Unregonized openai response format. It did not have usage information nor categories:\n" + f"Could not find usage information in openai response:\n" + pp.pformat(response) ) diff --git a/trulens_eval/trulens_eval/feedback/provider/hugs.py b/trulens_eval/trulens_eval/feedback/provider/hugs.py index 1ac35bd32..83f77ea70 100644 --- a/trulens_eval/trulens_eval/feedback/provider/hugs.py +++ b/trulens_eval/trulens_eval/feedback/provider/hugs.py @@ -71,7 +71,7 @@ class Huggingface(Provider): endpoint: Endpoint - def __init__(self, name: Optional[str] = None, endpoint=None, **kwargs): + def __init__(self, name: Optional[str] = None, endpoint: Optional[Endpoint] = None, **kwargs): # NOTE(piotrm): pydantic adds endpoint to the signature of this # constructor if we don't include it explicitly, even though we set it # down below. Adding it as None here as a temporary hack. diff --git a/trulens_eval/trulens_eval/instruments.py b/trulens_eval/trulens_eval/instruments.py index 62ded4105..227c4097a 100644 --- a/trulens_eval/trulens_eval/instruments.py +++ b/trulens_eval/trulens_eval/instruments.py @@ -408,8 +408,6 @@ def tracked_method_wrapper( existing_apps = getattr(func, Instrument.APPS) existing_apps.add(self.app) - # print(f"already instrumented for apps {list(existing_apps)}: {query}, path type is {type(query).__name__}, id={id(type(query))}") - return func # TODO: How to consistently address calls to chains that appear more @@ -418,8 +416,6 @@ def tracked_method_wrapper( else: # Notify the app instrumenting this method where it is located: - # print(f"instrumenting {query}, path type is {type(query).__name__}, id={id(type(query))}") - self.app._on_method_instrumented(obj, func, path=query) logger.debug(f"\t\t\t{query}: instrumenting {method_name}={func}") @@ -581,7 +577,7 @@ def find_instrumented(f): perf=Perf(start_time=start_time, end_time=end_time), pid=os.getpid(), tid=th.get_native_id(), - rets=rets, + rets=jsonify(rets), error=error_str if error is not None else None ) # End of run wrapped block. @@ -715,8 +711,6 @@ def find_instrumented(f): else: stack = ctx_stacks[ctx] - # print(f"creating frame info for path {path}, path type is {type(path).__name__}, id={id(type(path))}") - frame_ident = RecordAppCallMethod( path=path, method=Method.of_method(func, obj=obj, cls=cls) ) @@ -764,7 +758,7 @@ def find_instrumented(f): perf=Perf(start_time=start_time, end_time=end_time), pid=os.getpid(), tid=th.get_native_id(), - rets=rets, + rets=jsonify(rets), error=error_str if error is not None else None ) # End of run wrapped block. diff --git a/trulens_eval/trulens_eval/pages/Progress.py b/trulens_eval/trulens_eval/pages/Progress.py deleted file mode 100644 index 6201e85e1..000000000 --- a/trulens_eval/trulens_eval/pages/Progress.py +++ /dev/null @@ -1,51 +0,0 @@ -import asyncio - -# https://github.com/jerryjliu/llama_index/issues/7244: -asyncio.set_event_loop(asyncio.new_event_loop()) - -from st_aggrid import AgGrid -import streamlit as st -from ux.add_logo import add_logo_and_style_overrides - -from trulens_eval import Tru -from trulens_eval.feedback.provider.endpoint.base import DEFAULT_RPM -from trulens_eval.schema import FeedbackResultStatus - -st.set_page_config(page_title="Feedback Progress", layout="wide") - -st.title("Feedback Progress") - -st.runtime.legacy_caching.clear_cache() - -add_logo_and_style_overrides() - -tru = Tru() -lms = tru.db - -endpoints = ["OpenAI", "HuggingFace"] - -tab1, tab2, tab3 = st.tabs(["Progress", "Endpoints", "Feedback Functions"]) - -with tab1: - feedbacks = lms.get_feedback( - status=[ - FeedbackResultStatus.NONE, FeedbackResultStatus.RUNNING, - FeedbackResultStatus.FAILED - ] - ) - feedbacks = feedbacks.astype(str) - data = AgGrid( - feedbacks, allow_unsafe_jscode=True, fit_columns_on_grid_load=True - ) - -with tab2: - for e in endpoints: - st.header(e) - st.metric("RPM", DEFAULT_RPM) - -with tab3: - feedbacks = lms.get_feedback_defs() - feedbacks = feedbacks.astype(str) - data = AgGrid( - feedbacks, allow_unsafe_jscode=True, fit_columns_on_grid_load=True - ) diff --git a/trulens_eval/trulens_eval/schema.py b/trulens_eval/trulens_eval/schema.py index f48015267..1fc302563 100644 --- a/trulens_eval/trulens_eval/schema.py +++ b/trulens_eval/trulens_eval/schema.py @@ -683,7 +683,7 @@ def get_loadable_apps(): apps = tru.get_apps() for app in apps: - dump = app['initial_app_loader_dump'] + dump = app.get('initial_app_loader_dump') if dump is not None: rets.append(app) diff --git a/trulens_eval/trulens_eval/tru.py b/trulens_eval/trulens_eval/tru.py index 1e1d045ff..dd226128b 100644 --- a/trulens_eval/trulens_eval/tru.py +++ b/trulens_eval/trulens_eval/tru.py @@ -20,6 +20,7 @@ from trulens_eval.feedback import Feedback from trulens_eval.schema import AppDefinition from trulens_eval.schema import FeedbackResult +from trulens_eval.schema import FeedbackResultStatus from trulens_eval.schema import Record from trulens_eval.utils.notebook_utils import is_notebook from trulens_eval.utils.notebook_utils import setup_widget_stdout_stderr @@ -195,7 +196,7 @@ def _submit_feedback_functions( self.db: DB if app is None: - app = AppDefinition.model_validate_json(self.db.get_app(app_id=app_id)) + app = AppDefinition.model_validate(self.db.get_app(app_id=app_id)) if app is None: raise RuntimeError( "App {app_id} not present in db. " @@ -263,13 +264,17 @@ def add_app(self, app: AppDefinition) -> None: self.db.insert_app(app=app) def add_feedback( - self, feedback_result: FeedbackResult = None, **kwargs + self, feedback_result: Optional[FeedbackResult] = None, **kwargs ) -> None: """ Add a single feedback result to the database. """ if feedback_result is None: + if 'result' in kwargs and 'status' not in kwargs: + # If result already present, set status to done. + kwargs['status'] = FeedbackResultStatus.DONE + feedback_result = FeedbackResult(**kwargs) else: feedback_result.update(**kwargs) diff --git a/trulens_eval/trulens_eval/tru_custom_app.py b/trulens_eval/trulens_eval/tru_custom_app.py index 638eaa92a..dbc341d57 100644 --- a/trulens_eval/trulens_eval/tru_custom_app.py +++ b/trulens_eval/trulens_eval/tru_custom_app.py @@ -340,6 +340,7 @@ def __init__(self, app: Any, methods_to_instrument=None, **kwargs): else: main_name = main_method.__name__ main_method_loaded = main_method + main_method = Function.of_function(main_method_loaded) if not safe_hasattr(main_method_loaded, "__self__"): raise ValueError( @@ -353,6 +354,9 @@ def __init__(self, app: Any, methods_to_instrument=None, **kwargs): cls = app_self.__class__ mod = cls.__module__ + kwargs['main_method'] = main_method + kwargs['main_method_loaded'] = main_method_loaded + instrument.include_modules.add(mod) instrument.include_classes.add(cls) instrument.include_methods[main_name] = lambda o: isinstance(o, cls) diff --git a/trulens_eval/trulens_eval/utils/json.py b/trulens_eval/trulens_eval/utils/json.py index 8916de7e0..5a6c60af2 100644 --- a/trulens_eval/trulens_eval/utils/json.py +++ b/trulens_eval/trulens_eval/utils/json.py @@ -174,7 +174,7 @@ def jsonify( return str(obj) if type(obj) in pydantic.v1.json.ENCODERS_BY_TYPE: - return obj + return pydantic.v1.json.ENCODERS_BY_TYPE[type(obj)](obj) # TODO: should we include duplicates? If so, dicted needs to be adjusted. new_dicted = {k: v for k, v in dicted.items()} diff --git a/trulens_eval/trulens_eval/utils/pyschema.py b/trulens_eval/trulens_eval/utils/pyschema.py index 7259abb25..e945fea0c 100644 --- a/trulens_eval/trulens_eval/utils/pyschema.py +++ b/trulens_eval/trulens_eval/utils/pyschema.py @@ -19,13 +19,11 @@ import importlib import inspect import logging -import dill from pprint import PrettyPrinter from types import ModuleType -from typing import ( - Any, Callable, Dict, Optional, Sequence, Tuple -) +from typing import Any, Callable, Dict, Optional, Sequence, Tuple +import dill import pydantic from pydantic import Field @@ -153,7 +151,7 @@ def clean_attributes(obj, include_props: bool = False) -> Dict[str, Any]: class Module(SerialModel): - package_name: Optional[str] # some modules are not in a package + package_name: Optional[str] = None # some modules are not in a package module_name: str def of_module(mod: ModuleType, loadable: bool = False) -> 'Module': @@ -189,7 +187,7 @@ class Class(SerialModel): module: Module - bases: Optional[Sequence[Class]] + bases: Optional[Sequence[Class]] = None def __repr__(self): return self.module.module_name + "." + self.name @@ -387,7 +385,13 @@ def load(self) -> object: cls = self.cls.load() sig = _safe_init_sig(cls) - bindings = self.init_bindings.load(sig) + + if CLASS_INFO in sig.parameters and CLASS_INFO not in self.init_bindings.kwargs: + extra_kwargs = {CLASS_INFO: self.cls} + else: + extra_kwargs = {} + + bindings = self.init_bindings.load(sig, extra_kwargs=extra_kwargs) return cls(*bindings.args, **bindings.kwargs) @@ -416,11 +420,11 @@ def _handle_providers_load(self): if 'provider' in self.kwargs: del self.kwargs['provider'] - def load(self, sig: inspect.Signature): + def load(self, sig: inspect.Signature, extra_args=(), extra_kwargs={}): self._handle_providers_load() - return sig.bind(*self.args, **self.kwargs) + return sig.bind(*(self.args+extra_args), **self.kwargs, **extra_kwargs) class FunctionOrMethod(SerialModel): @@ -564,10 +568,24 @@ class WithClassInfo(pydantic.BaseModel): @classmethod def model_validate(cls, obj, **kwargs): - clsinfo = Class.model_validate(obj[CLASS_INFO]) - clsloaded = clsinfo.load() + if isinstance(obj, dict) and CLASS_INFO in obj: - return super(cls, clsloaded).model_validate(obj) + clsinfo = Class.model_validate(obj[CLASS_INFO]) + clsloaded = clsinfo.load() + + # NOTE(piotrm): even though we have a more specific class than + # AppDefinition, we load it as AppDefinition due to serialization + # issues in the wrapped app. Keeping it as AppDefinition means `app` + # field is just json. + from trulens_eval.schema import AppDefinition + + if issubclass(clsloaded, AppDefinition): + return super(cls, AppDefinition).model_validate(obj) + else: + return super(cls, clsloaded).model_validate(obj) + + else: + return super().model_validate(obj) def __init__( self, @@ -585,6 +603,7 @@ def __init__( class_info = Class.of_class(cls, with_bases=True) kwargs[CLASS_INFO] = class_info + super().__init__(*args, **kwargs) @staticmethod diff --git a/trulens_eval/trulens_eval/utils/serial.py b/trulens_eval/trulens_eval/utils/serial.py index 4f275c410..68c4ea7d5 100644 --- a/trulens_eval/trulens_eval/utils/serial.py +++ b/trulens_eval/trulens_eval/utils/serial.py @@ -22,6 +22,7 @@ from munch import Munch as Bunch import pydantic + from trulens_eval.utils.containers import iterable_peek logger = logging.getLogger(__name__) @@ -64,15 +65,24 @@ class SerialModel(pydantic.BaseModel): help serialization mostly. """ + def model_dump_json(self, **kwargs): + from trulens_eval.utils.json import json_str_of_obj + + return json_str_of_obj(self, **kwargs) + + def model_dump(self, **kwargs): + from trulens_eval.utils.json import jsonify + + return jsonify(self, **kwargs) + @classmethod def model_validate(cls, obj, **kwargs): # import hierarchy circle here - from trulens_eval.utils.pyschema import Class from trulens_eval.utils.pyschema import CLASS_INFO from trulens_eval.utils.pyschema import WithClassInfo if isinstance(obj, Dict) and CLASS_INFO in obj: - return WithClassInfo.model_validate(obj) + return WithClassInfo.model_validate(obj, **kwargs) return super(SerialModel, cls).model_validate(obj, **kwargs) @@ -118,13 +128,16 @@ class Step(pydantic.BaseModel, Hashable): A step in a selection path. """ + def __hash__(self): + raise TypeError(f"Should never be called, self={self.model_dump()}") + @classmethod def model_validate(cls, obj, **kwargs): if isinstance(obj, Step): - return obj + return super().model_validate(obj, **kwargs) - elif isinstance(obj, Dict): + elif isinstance(obj, dict): ATTRIBUTE_TYPE_MAP = { 'item': GetItem, @@ -548,7 +561,10 @@ def validate_from_string(cls, obj, handler): # different than obj. Might be a pydantic oversight/bug. if isinstance(obj, str): - return Lens.of_string(obj) + ret = Lens.of_string(obj) + return ret + elif isinstance(obj, dict): + return handler(dict(path=(Step.model_validate(step) for step in obj['path']))) else: return handler(obj)