diff --git a/docs/trulens_eval/feedback_function_guide.md b/docs/trulens_eval/feedback_function_guide.md index 92f2fb7e2..f5fd091fc 100644 --- a/docs/trulens_eval/feedback_function_guide.md +++ b/docs/trulens_eval/feedback_function_guide.md @@ -298,7 +298,7 @@ class Record(SerialModel): For an App: ```python -class AppDefinition(SerialModel, WithClassInfo, ABC): +class AppDefinition(WithClassInfo, SerialModel, ABC): ... app_id: AppID diff --git a/trulens_eval/examples/experimental/dev_notebook.ipynb b/trulens_eval/examples/experimental/dev_notebook.ipynb index 09e60e54a..1cb5055d3 100644 --- a/trulens_eval/examples/experimental/dev_notebook.ipynb +++ b/trulens_eval/examples/experimental/dev_notebook.ipynb @@ -16,7 +16,8 @@ "outputs": [], "source": [ "# ! pip uninstall -y trulens_eval # ==0.18.2\n", - "# ! pip list | grep trulens" + "# ! pip list | grep trulens\n", + "# ! pip install --upgrade pydantic" ] }, { @@ -91,6 +92,124 @@ ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Pydantic testing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Any\n", + "\n", + "from pydantic import BaseModel\n", + "from pydantic import Field\n", + "from pydantic import field_validator\n", + "from pydantic import model_validator\n", + "from pydantic import PydanticUndefinedAnnotation\n", + "from pydantic import SerializeAsAny\n", + "from pydantic import ValidationInfo\n", + "from pydantic import validator\n", + "from pydantic_core import PydanticUndefined\n", + "\n", + "\n", + "class CustomLoader(BaseModel):\n", + " cls: Any\n", + "\n", + " def __init__(self, *args, **kwargs):\n", + " kwargs['cls'] = type(self)\n", + " super().__init__(*args, **kwargs)\n", + "\n", + " @model_validator(mode='before')\n", + " @staticmethod\n", + " def my_model_validate(obj, info: ValidationInfo):\n", + " if not isinstance(obj, dict):\n", + " return obj\n", + "\n", + " cls = obj['cls']\n", + " # print(cls, subcls, obj, info)\n", + "\n", + " validated = dict()\n", + " for k, finfo in cls.model_fields.items():\n", + " print(k, finfo)\n", + " typ = finfo.annotation\n", + " val = finfo.get_default()\n", + "\n", + " if val is PydanticUndefined:\n", + " val = obj[k]\n", + "\n", + " print(typ, type(typ))\n", + " if isinstance(typ, type) \\\n", + " and issubclass(typ, CustomLoader) \\\n", + " and isinstance(val, dict) and \"cls\" in val:\n", + " subcls = val['cls']\n", + " val = subcls.model_validate(val)\n", + " \n", + " validated[k] = val\n", + " \n", + " return validated\n", + "\n", + "class SubModel(CustomLoader):\n", + " sm: int = 3\n", + "\n", + "class Model(CustomLoader):\n", + " m: int = 2\n", + " sub: SubModel\n", + "\n", + "class SubSubModelA(SubModel):\n", + " ssma: int = 42\n", + "\n", + "class SubModelA(SubModel):\n", + " sma: int = 0\n", + " subsub: SubSubModelA\n", + "\n", + "class SubModelB(SubModel):\n", + " smb: int = 1\n", + "\n", + "c = Model(sub=SubModelA(subsub=SubSubModelA()))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "c" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "c.model_dump()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Model.model_validate({'cls': Model, 'm': 2, 'sub': {'cls': SubModelA, 'sma':3, 'subsub': {'cls': SubSubModelA, 'ssma': 42}}})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Model.model_validate({'c': 2, 'sub': {}})" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -202,6 +321,36 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.llms import AzureOpenAI as AzureOpenAIChat\n", + "import os\n", + "\n", + "gpt_35_turbo = AzureOpenAIChat(\n", + " deployment_name=\"gpt-35-turbo\",\n", + " model=\"gpt-35-turbo\",\n", + " api_key=os.getenv(\"AZURE_OPENAI_API_KEY\"),\n", + " api_version=\"2023-05-15\",\n", + " model_version=\"0613\",\n", + " temperature=0.0,\n", + ")\n", + "c = gpt_35_turbo._get_client()\n", + "gpt_35_turbo._get_credential_kwargs()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "c.base_url" + ] + }, { "cell_type": "code", "execution_count": null, @@ -210,7 +359,189 @@ "source": [ "import os\n", "from trulens_eval import feedback\n", - "azopenai = feedback.AzureOpenAI(deployment_name=os.environ['AZURE_OPENAI_DEPLYOMENT_NAME'])" + "azopenai = feedback.AzureOpenAI(\n", + " deployment_name=os.environ['AZURE_OPENAI_DEPLOYMENT_NAME']\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "azopenai.endpoint.client.client_kwargs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# azopenai.relevance(prompt=\"Where is Germany?\", response=\"Germany is in Europe.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# reval = feedback.AzureOpenAI.model_validate(azopenai.model_dump())\n", + "# reval.relevance(prompt=\"Where is Germany?\", response=\"Poland is in Europe.\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "azureOpenAI = azopenai\n", + "\n", + "from trulens_eval.feedback.provider import AzureOpenAI\n", + "from trulens_eval.feedback import Groundedness, GroundTruthAgreement\n", + "from trulens_eval import TruLlama, Feedback\n", + "from trulens_eval.app import App\n", + "import numpy as np\n", + "# Initialize provider class\n", + "#azureOpenAI = AzureOpenAI(deployment_name=\"gpt-35-turbo\")\n", + "\n", + "grounded = Groundedness(groundedness_provider=azureOpenAI)\n", + "# Define a groundedness feedback function\n", + "f_groundedness = (\n", + " Feedback(grounded.groundedness_measure_with_cot_reasons)\n", + " .on_input_output()\n", + " .aggregate(grounded.grounded_statements_aggregator)\n", + ")\n", + "\n", + "# Question/answer relevance between overall question and answer.\n", + "f_answer_relevance = Feedback(azureOpenAI.relevance_with_cot_reasons).on_input_output()\n", + "# Question/statement relevance between question and each context chunk.\n", + "f_context_relevance = (\n", + " Feedback(azureOpenAI.qs_relevance_with_cot_reasons)\n", + " .on_input_output()\n", + " .aggregate(np.mean)\n", + ")\n", + "\n", + "# GroundTruth for comparing the Answer to the Ground-Truth Answer\n", + "#ground_truth_collection = GroundTruthAgreement(golden_set, provider=azureOpenAI)\n", + "#f_answer_correctness = (\n", + "# Feedback(ground_truth_collection.agreement_measure)\n", + "# .on_input_output()\n", + "#)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "f_groundedness.model_dump()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from trulens_eval.utils.pyschema import WithClassInfo\n", + "from trulens_eval.utils.serial import SerialModel\n", + "from trulens_eval.feedback.groundedness import Groundedness\n", + "import pydantic" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "Groundedness.model_validate(grounded.model_dump())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "f2 = Feedback.model_validate(f_groundedness.model_dump())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "f2.implementation.obj.init_bindings.kwargs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "f2.imp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def test_serial(f):\n", + " print(\"Before serialization:\")\n", + " print(f.imp(\"Where is Poland?\", \"Poland is in Europe\"))\n", + " f_dump = f.model_dump()\n", + " f = Feedback.model_validate(f_dump)\n", + " print(\"After serialization:\")\n", + " print(f.imp(\"Where is Poland?\", \"Germany is in Europe\"))\n", + " return f\n", + "\n", + "f2 = test_serial(f_groundedness)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "f_groundedness.imp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "f2.imp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "f_answer_relevance = Feedback(azureOpenAI.relevance_with_cot_reasons).on_input_output()\n", + "\n", + "# test without serialization\n", + "print(f_answer_relevance.imp(prompt=\"Where is Germany?\", response=\"Germany is in Europe.\"))\n", + "\n", + "# serialize/deserialize\n", + "f_answer_relevance2 = Feedback.model_validate(f_answer_relevance.model_dump())\n", + "\n", + "# test after deserialization\n", + "print(f_answer_relevance2.imp(prompt=\"Where is Germany?\", response=\"Poland is in Europe.\"))" ] }, { @@ -219,7 +550,7 @@ "metadata": {}, "outputs": [], "source": [ - "azopenai.relevance(prompt=\"Where is Germany?\", response=\"Germany is in Europe.\")" + "fr = feedback.Feedback.model_validate(f.model_dump())" ] }, { @@ -228,8 +559,7 @@ "metadata": {}, "outputs": [], "source": [ - "reval = feedback.AzureOpenAI.model_validate(azopenai.model_dump())\n", - "reval.relevance(prompt=\"Where is Germany?\", response=\"Poland is in Europe.\")" + "fr.imp(prompt=\"Where is Germany?\", response=\"Germany is in Europe.\")" ] }, { diff --git a/trulens_eval/examples/expositional/frameworks/langchain/langchain_agents.ipynb b/trulens_eval/examples/expositional/frameworks/langchain/langchain_agents.ipynb index 698e4c127..7514e66bb 100644 --- a/trulens_eval/examples/expositional/frameworks/langchain/langchain_agents.ipynb +++ b/trulens_eval/examples/expositional/frameworks/langchain/langchain_agents.ipynb @@ -39,23 +39,32 @@ "metadata": {}, "outputs": [], "source": [ - "from IPython.display import JSON\n", - "from trulens_eval import TruChain, Feedback, Huggingface, Tru\n", + "from trulens_eval import Feedback\n", + "from trulens_eval import Huggingface\n", + "from trulens_eval import Tru\n", + "from trulens_eval import TruChain\n", "from trulens_eval.feedback import OpenAI as fOpenAI\n", + "\n", "tru = Tru()\n", "\n", - "from langchain.chat_models import ChatOpenAI\n", + "from datetime import datetime\n", + "from datetime import timedelta\n", + "from typing import Type\n", + "\n", "from langchain import SerpAPIWrapper\n", + "from langchain.agents import AgentExecutor\n", + "from langchain.agents import AgentType\n", + "from langchain.agents import BaseSingleActionAgent\n", + "from langchain.agents import initialize_agent\n", + "from langchain.agents import load_tools\n", + "from langchain.agents import Tool\n", + "from langchain.chat_models import ChatOpenAI\n", "from langchain.llms import OpenAI as langchainOpenAI\n", - "from langchain.agents import AgentType, initialize_agent, load_tools, Tool, AgentExecutor, BaseSingleActionAgent\n", "from langchain.tools import BaseTool\n", - "import yfinance as yf\n", - "from datetime import datetime, timedelta\n", "import openai\n", - "\n", - "\n", - "from typing import Type\n", - "from pydantic import BaseModel, Field" + "from pydantic import BaseModel\n", + "from pydantic import Field\n", + "import yfinance as yf" ] }, { diff --git a/trulens_eval/examples/expositional/frameworks/langchain/langchain_model_comparison.ipynb b/trulens_eval/examples/expositional/frameworks/langchain/langchain_model_comparison.ipynb index cb9ac82b6..6a6b06c44 100644 --- a/trulens_eval/examples/expositional/frameworks/langchain/langchain_model_comparison.ipynb +++ b/trulens_eval/examples/expositional/frameworks/langchain/langchain_model_comparison.ipynb @@ -39,30 +39,27 @@ "source": [ "import os\n", "\n", - "from IPython.display import JSON\n", - "\n", - "import numpy as np\n", - "\n", + "from langchain import LLMChain\n", "# Imports from langchain to build app. You may need to install langchain first\n", "# with the following:\n", "# ! pip install langchain>=0.0.170\n", "from langchain.chains import LLMChain\n", "from langchain.llms import OpenAI\n", - "from langchain.prompts import ChatPromptTemplate, PromptTemplate\n", + "from langchain.prompts import ChatPromptTemplate\n", "from langchain.prompts import HumanMessagePromptTemplate\n", "from langchain.prompts import PromptTemplate\n", - "from langchain.llms import OpenAI\n", - "from langchain import LLMChain\n", + "import numpy as np\n", "\n", "# Imports main tools:\n", - "from trulens_eval import TruChain, Feedback, Huggingface, Tru\n", "# Imports main tools:\n", "from trulens_eval import Feedback\n", "from trulens_eval import feedback\n", "from trulens_eval import FeedbackMode\n", + "from trulens_eval import Huggingface\n", "from trulens_eval import Select\n", "from trulens_eval import TP\n", "from trulens_eval import Tru\n", + "from trulens_eval import TruChain\n", "from trulens_eval.utils.langchain import WithFeedbackFilterDocuments\n", "\n", "tru = Tru()" diff --git a/trulens_eval/examples/expositional/models/bedrock_finetuning_experiments.ipynb b/trulens_eval/examples/expositional/models/bedrock_finetuning_experiments.ipynb index 1e3420005..ba1d39c29 100644 --- a/trulens_eval/examples/expositional/models/bedrock_finetuning_experiments.ipynb +++ b/trulens_eval/examples/expositional/models/bedrock_finetuning_experiments.ipynb @@ -402,8 +402,9 @@ }, "outputs": [], "source": [ + "from IPython.display import display\n", + "from IPython.display import HTML\n", "import pandas as pd\n", - "from IPython.display import display, HTML\n", "\n", "test_dataset = train_and_test_dataset[\"test\"]\n", "\n", diff --git a/trulens_eval/examples/expositional/models/google_vertex_quickstart.ipynb b/trulens_eval/examples/expositional/models/google_vertex_quickstart.ipynb index 0efd86090..8c06b7aa8 100644 --- a/trulens_eval/examples/expositional/models/google_vertex_quickstart.ipynb +++ b/trulens_eval/examples/expositional/models/google_vertex_quickstart.ipynb @@ -64,10 +64,12 @@ "metadata": {}, "outputs": [], "source": [ - "from IPython.display import JSON\n", - "\n", "# Imports main tools:\n", - "from trulens_eval import TruChain, Feedback, Tru, LiteLLM\n", + "from trulens_eval import Feedback\n", + "from trulens_eval import LiteLLM\n", + "from trulens_eval import Tru\n", + "from trulens_eval import TruChain\n", + "\n", "tru = Tru()\n", "tru.reset_database()\n", "\n", @@ -78,7 +80,8 @@ "from langchain.chains import LLMChain\n", "from langchain.llms import VertexAI\n", "from langchain.prompts import PromptTemplate\n", - "from langchain.prompts.chat import HumanMessagePromptTemplate, ChatPromptTemplate" + "from langchain.prompts.chat import ChatPromptTemplate\n", + "from langchain.prompts.chat import HumanMessagePromptTemplate" ] }, { diff --git a/trulens_eval/examples/expositional/models/litellm_quickstart.ipynb b/trulens_eval/examples/expositional/models/litellm_quickstart.ipynb index 4f7547aae..3c9f9f88c 100644 --- a/trulens_eval/examples/expositional/models/litellm_quickstart.ipynb +++ b/trulens_eval/examples/expositional/models/litellm_quickstart.ipynb @@ -57,10 +57,12 @@ "metadata": {}, "outputs": [], "source": [ - "from IPython.display import JSON\n", - "\n", "# Imports main tools:\n", - "from trulens_eval import TruChain, Feedback, Tru, LiteLLM\n", + "from trulens_eval import Feedback\n", + "from trulens_eval import LiteLLM\n", + "from trulens_eval import Tru\n", + "from trulens_eval import TruChain\n", + "\n", "tru = Tru()\n", "tru.reset_database()\n", "\n", @@ -71,7 +73,8 @@ "from langchain.chains import LLMChain\n", "from langchain.llms import OpenAI\n", "from langchain.prompts import PromptTemplate\n", - "from langchain.prompts.chat import HumanMessagePromptTemplate, ChatPromptTemplate" + "from langchain.prompts.chat import ChatPromptTemplate\n", + "from langchain.prompts.chat import HumanMessagePromptTemplate" ] }, { diff --git a/trulens_eval/examples/expositional/models/ollama_quickstart.ipynb b/trulens_eval/examples/expositional/models/ollama_quickstart.ipynb index 075565235..6b3f622e2 100644 --- a/trulens_eval/examples/expositional/models/ollama_quickstart.ipynb +++ b/trulens_eval/examples/expositional/models/ollama_quickstart.ipynb @@ -47,10 +47,11 @@ "metadata": {}, "outputs": [], "source": [ - "from IPython.display import JSON\n", - "\n", "# Imports main tools:\n", - "from trulens_eval import TruChain, Feedback, Tru\n", + "from trulens_eval import Feedback\n", + "from trulens_eval import Tru\n", + "from trulens_eval import TruChain\n", + "\n", "tru = Tru()\n", "tru.reset_database()\n", "\n", @@ -60,7 +61,8 @@ "# ! pip install langchain>=0.0.170\n", "from langchain.chains import LLMChain\n", "from langchain.prompts import PromptTemplate\n", - "from langchain.prompts.chat import HumanMessagePromptTemplate, ChatPromptTemplate" + "from langchain.prompts.chat import ChatPromptTemplate\n", + "from langchain.prompts.chat import HumanMessagePromptTemplate" ] }, { diff --git a/trulens_eval/examples/expositional/use_cases/language_verification.ipynb b/trulens_eval/examples/expositional/use_cases/language_verification.ipynb index 90cf41878..a15f71a73 100644 --- a/trulens_eval/examples/expositional/use_cases/language_verification.ipynb +++ b/trulens_eval/examples/expositional/use_cases/language_verification.ipynb @@ -56,8 +56,6 @@ "metadata": {}, "outputs": [], "source": [ - "from IPython.display import JSON\n", - "\n", "# Imports main tools:\n", "from trulens_eval import Feedback, Huggingface, Tru\n", "tru = Tru()\n", diff --git a/trulens_eval/examples/expositional/use_cases/model_comparison.ipynb b/trulens_eval/examples/expositional/use_cases/model_comparison.ipynb index 2c4fa1093..639f41b91 100644 --- a/trulens_eval/examples/expositional/use_cases/model_comparison.ipynb +++ b/trulens_eval/examples/expositional/use_cases/model_comparison.ipynb @@ -59,10 +59,11 @@ "metadata": {}, "outputs": [], "source": [ - "from IPython.display import JSON\n", - "\n", "# Imports main tools:\n", - "from trulens_eval import Feedback, Tru, OpenAI\n", + "from trulens_eval import Feedback\n", + "from trulens_eval import OpenAI\n", + "from trulens_eval import Tru\n", + "\n", "tru = Tru()\n", "tru.reset_database()" ] diff --git a/trulens_eval/examples/expositional/use_cases/moderation.ipynb b/trulens_eval/examples/expositional/use_cases/moderation.ipynb index 7512d53e2..94e4b3387 100644 --- a/trulens_eval/examples/expositional/use_cases/moderation.ipynb +++ b/trulens_eval/examples/expositional/use_cases/moderation.ipynb @@ -66,10 +66,11 @@ } ], "source": [ - "from IPython.display import JSON\n", - "\n", "# Imports main tools:\n", - "from trulens_eval import Feedback, Tru, OpenAI\n", + "from trulens_eval import Feedback\n", + "from trulens_eval import OpenAI\n", + "from trulens_eval import Tru\n", + "\n", "tru = Tru()\n", "tru.reset_database()" ] diff --git a/trulens_eval/examples/quickstart/langchain_quickstart.ipynb b/trulens_eval/examples/quickstart/langchain_quickstart.ipynb index 8e7798c5d..f19ab1fdf 100644 --- a/trulens_eval/examples/quickstart/langchain_quickstart.ipynb +++ b/trulens_eval/examples/quickstart/langchain_quickstart.ipynb @@ -55,8 +55,6 @@ "metadata": {}, "outputs": [], "source": [ - "from IPython.display import JSON\n", - "\n", "# Imports main tools:\n", "from trulens_eval import TruChain, Feedback, Huggingface, Tru\n", "from trulens_eval.schema import FeedbackResult\n", diff --git a/trulens_eval/examples/quickstart/text2text_quickstart.ipynb b/trulens_eval/examples/quickstart/text2text_quickstart.ipynb index b49273a17..0222b668f 100644 --- a/trulens_eval/examples/quickstart/text2text_quickstart.ipynb +++ b/trulens_eval/examples/quickstart/text2text_quickstart.ipynb @@ -55,8 +55,6 @@ "metadata": {}, "outputs": [], "source": [ - "from IPython.display import JSON\n", - "\n", "# Create openai client\n", "from openai import OpenAI\n", "client = OpenAI()\n", @@ -222,7 +220,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.5" + "version": "3.8.16" } }, "nbformat": 4, diff --git a/trulens_eval/generated_files/all_tools.ipynb b/trulens_eval/generated_files/all_tools.ipynb index 8bb55d4db..ffe01e520 100644 --- a/trulens_eval/generated_files/all_tools.ipynb +++ b/trulens_eval/generated_files/all_tools.ipynb @@ -55,11 +55,13 @@ "metadata": {}, "outputs": [], "source": [ - "from IPython.display import JSON\n", - "\n", "# Imports main tools:\n", - "from trulens_eval import TruChain, Feedback, Huggingface, Tru\n", + "from trulens_eval import Feedback\n", + "from trulens_eval import Huggingface\n", + "from trulens_eval import Tru\n", + "from trulens_eval import TruChain\n", "from trulens_eval.schema import FeedbackResult\n", + "\n", "tru = Tru()\n", "tru.reset_database()\n", "\n", diff --git a/trulens_eval/tests/unit/feedbacks.py b/trulens_eval/tests/unit/feedbacks.py index 130dbfbf7..4051745a9 100644 --- a/trulens_eval/tests/unit/feedbacks.py +++ b/trulens_eval/tests/unit/feedbacks.py @@ -73,7 +73,7 @@ def make_nonglobal_feedbacks(): # incorrectly. class NG: # "non-global" - + @staticmethod def NGcustom_feedback_function(t1: str) -> float: return 0.1 diff --git a/trulens_eval/trulens_eval/app.py b/trulens_eval/trulens_eval/app.py index c5ba36693..8f614d056 100644 --- a/trulens_eval/trulens_eval/app.py +++ b/trulens_eval/trulens_eval/app.py @@ -385,7 +385,7 @@ def finish_record( return record -class App(AppDefinition, SerialModel, WithInstrumentCallbacks, Hashable): +class App(AppDefinition, WithInstrumentCallbacks, Hashable): """ Generalization of a wrapped model. """ @@ -415,13 +415,13 @@ class App(AppDefinition, SerialModel, WithInstrumentCallbacks, Hashable): # Instrumentation class. This is needed for serialization as it tells us # which objects we want to be included in the json representation of this # app. - instrument: Instrument = Field(exclude=True) + instrument: Instrument = Field(None, exclude=True) # Sequnces of records produced by the this class used as a context manager # are stpred om a RecordingContext. Using a context var so that context # managers can be nested. recording_contexts: contextvars.ContextVar[RecordingContext] \ - = Field(exclude=True) + = Field(None, exclude=True) # Mapping of instrumented methods (by id(.) of owner object and the # function) to their path in this app: @@ -446,13 +446,13 @@ def __init__( "recording_contexts" ) - # Cannot use this to set app. AppDefinition has app as JSON type. - # TODO: Figure out a better design to avoid this. super().__init__(**kwargs) app = kwargs['app'] self.app = app + assert self.instrument is not None, "App class cannot be instantiated. Use one of the subclasses." + self.instrument.instrument_object( obj=self.app, query=Select.Query().app ) diff --git a/trulens_eval/trulens_eval/feedback/README.md b/trulens_eval/trulens_eval/feedback/README.md index 603f6ec2d..a3effff7f 100644 --- a/trulens_eval/trulens_eval/feedback/README.md +++ b/trulens_eval/trulens_eval/feedback/README.md @@ -320,7 +320,7 @@ The other non-excluded fields accessible outside of the wrapped app are listed in the `AppDefinition` class in `schema.py`: ```python -class AppDefinition(SerialModel, WithClassInfo, ABC): +class AppDefinition(WithClassInfo, SerialModel, ABC): ... app_id: AppID diff --git a/trulens_eval/trulens_eval/feedback/embeddings.py b/trulens_eval/trulens_eval/feedback/embeddings.py index de166d6cd..99eb3e109 100644 --- a/trulens_eval/trulens_eval/feedback/embeddings.py +++ b/trulens_eval/trulens_eval/feedback/embeddings.py @@ -15,8 +15,7 @@ with OptionalImports(messages=REQUIREMENT_LLAMA): from llama_index import ServiceContext - -class Embeddings(SerialModel, WithClassInfo): +class Embeddings(WithClassInfo, SerialModel): """Embedding related feedback function implementations. """ _embed_model: 'Embedder' = PrivateAttr() @@ -33,7 +32,7 @@ def __init__(self, embed_model: 'Embedder' = None): service_context = ServiceContext.from_defaults(embed_model=embed_model) self._embed_model = service_context.embed_model - super().__init__(obj=self) + super().__init__() def cosine_distance( self, query: str, document: str diff --git a/trulens_eval/trulens_eval/feedback/groundedness.py b/trulens_eval/trulens_eval/feedback/groundedness.py index ca407d2a7..a38e50113 100644 --- a/trulens_eval/trulens_eval/feedback/groundedness.py +++ b/trulens_eval/trulens_eval/feedback/groundedness.py @@ -1,11 +1,12 @@ import logging from typing import Dict, List, Optional +import pydantic import numpy as np from tqdm.auto import tqdm from trulens_eval.feedback import prompts -from trulens_eval.feedback.provider import Provider +from trulens_eval.feedback.provider.base import Provider from trulens_eval.feedback.provider.bedrock import Bedrock from trulens_eval.feedback.provider.hugs import Huggingface from trulens_eval.feedback.provider.litellm import LiteLLM @@ -18,20 +19,22 @@ logger = logging.getLogger(__name__) -class Groundedness(SerialModel, WithClassInfo): - """Measures Groundedness. +class Groundedness(WithClassInfo, SerialModel): """ + Measures Groundedness. + """ + groundedness_provider: Provider - def __init__( - self, groundedness_provider: Optional[Provider] = None, **kwargs - ): - """Instantiates the groundedness providers. Currently the groundedness functions work well with a summarizer. - This class will use an LLM to find the relevant strings in a text. The groundedness_provider can + def __init__(self, groundedness_provider: Optional[Provider] = None, **kwargs): + """ + Instantiates the groundedness providers. Currently the groundedness + functions work well with a summarizer. This class will use an LLM to + find the relevant strings in a text. The groundedness_provider can either be an LLM provider (such as OpenAI) or NLI with huggingface. Usage 1: - ``` + ```python from trulens_eval.feedback import Groundedness from trulens_eval.feedback.provider.openai import OpenAI openai_provider = OpenAI() @@ -39,7 +42,7 @@ def __init__( ``` Usage 2: - ``` + ```python from trulens_eval.feedback import Groundedness from trulens_eval.feedback.provider.hugs import Huggingface huggingface_provider = Huggingface() @@ -47,15 +50,18 @@ def __init__( ``` Args: - groundedness_provider (Provider, optional): groundedness provider options: OpenAI LLM or HuggingFace NLI. Defaults to OpenAI(). - summarize_provider (Provider, optional): Internal Usage for DB serialization. + - groundedness_provider (Provider, optional): groundedness provider + options: OpenAI LLM or HuggingFace NLI. Defaults to OpenAI(). + - summarize_provider (Provider, optional): Internal Usage for DB + serialization. """ if groundedness_provider is None: + logger.warning("Provider not provided. Using OpenAI.") groundedness_provider = OpenAI() + super().__init__( groundedness_provider=groundedness_provider, - obj=self, # for WithClassInfo **kwargs ) @@ -64,26 +70,27 @@ def groundedness_measure(self, source: str, statement: str) -> float: This groundedness measure is faster; but less accurate than `groundedness_measure_with_summarize_step` Usage on RAG Contexts: - ``` + ```python from trulens_eval import Feedback from trulens_eval.feedback import Groundedness from trulens_eval.feedback.provider.openai import OpenAI grounded = feedback.Groundedness(groundedness_provider=OpenAI()) - f_groundedness = feedback.Feedback(grounded.groundedness_measure).on( Select.Record.app.combine_documents_chain._call.args.inputs.input_documents[:].page_content # See note below ).on_output().aggregate(grounded.grounded_statements_aggregator) ``` - The `on(...)` selector can be changed. See [Feedback Function Guide : Selectors](https://www.trulens.org/trulens_eval/feedback_function_guide/#selector-details) + The `on(...)` selector can be changed. See [Feedback Function Guide : + Selectors](https://www.trulens.org/trulens_eval/feedback_function_guide/#selector-details) Args: source (str): The source that should support the statement statement (str): The statement to check groundedness Returns: - float: A measure between 0 and 1, where 1 means each sentence is grounded in the source. + float: A measure between 0 and 1, where 1 means each sentence is + grounded in the source. """ logger.warning( "Feedback function `groundedness_measure` was renamed to `groundedness_measure_with_cot_reasons`. The new functionality of `groundedness_measure` function will no longer emit reasons as a lower cost option. It may have reduced accuracy due to not using Chain of Thought reasoning in the scoring." @@ -120,7 +127,8 @@ def groundedness_measure(self, source: str, statement: str) -> float: def groundedness_measure_with_cot_reasons( self, source: str, statement: str ) -> float: - """A measure to track if the source material supports each sentence in the statement. + """ + A measure to track if the source material supports each sentence in the statement. This groundedness measure is faster; but less accurate than `groundedness_measure_with_summarize_step`. Also uses chain of thought methodology and emits the reasons. @@ -223,7 +231,6 @@ def grounded_statements_aggregator( ) -> float: """Aggregates multi-input, mulit-output information from the groundedness_measure methods. - Args: source_statements_multi_output (List[Dict]): A list of scores. Each list index is a context. The Dict is a per statement score. diff --git a/trulens_eval/trulens_eval/feedback/groundtruth.py b/trulens_eval/trulens_eval/feedback/groundtruth.py index d72be6c70..e24eda5c8 100644 --- a/trulens_eval/trulens_eval/feedback/groundtruth.py +++ b/trulens_eval/trulens_eval/feedback/groundtruth.py @@ -24,11 +24,13 @@ # TODEP -class GroundTruthAgreement(SerialModel, WithClassInfo): - """Measures Agreement against a Ground Truth. +class GroundTruthAgreement(WithClassInfo, SerialModel): + """ + Measures Agreement against a Ground Truth. """ ground_truth: Union[List[Dict], FunctionOrMethod] provider: Provider + # Note: the bert scorer object isn't serializable # It's a class member because creating it is expensive bert_scorer: object @@ -93,7 +95,6 @@ def __init__( ground_truth_imp=ground_truth_imp, provider=provider, bert_scorer=bert_scorer, - obj=self, # for WithClassInfo **kwargs ) diff --git a/trulens_eval/trulens_eval/feedback/provider/base.py b/trulens_eval/trulens_eval/feedback/provider/base.py index 9b77897c0..32cb7426f 100644 --- a/trulens_eval/trulens_eval/feedback/provider/base.py +++ b/trulens_eval/trulens_eval/feedback/provider/base.py @@ -4,6 +4,8 @@ from typing import ClassVar, Dict, Optional, Sequence, Tuple, Union import warnings +import pydantic + from trulens_eval.feedback import prompts from trulens_eval.feedback.provider.endpoint.base import Endpoint from trulens_eval.utils.generated import re_0_10_rating @@ -13,20 +15,16 @@ logger = logging.getLogger(__name__) -class Provider(SerialModel, WithClassInfo): +class Provider(WithClassInfo, SerialModel): model_config: ClassVar[dict] = dict(arbitrary_types_allowed=True) endpoint: Optional[Endpoint] = None def __init__(self, name: Optional[str] = None, **kwargs): - # for WithClassInfo: - kwargs['obj'] = self - super().__init__(name=name, **kwargs) - -class LLMProvider(Provider, ABC): +class LLMProvider(Provider): # NOTE(piotrm): "model_" prefix for attributes is "protected" by pydantic v2 # by default. Need the below adjustment but this means we don't get any @@ -41,14 +39,13 @@ def __init__(self, *args, **kwargs): # down below. Adding it as None here as a temporary hack # TODO: why was self_kwargs required here independently of kwargs? - self_kwargs = dict() - self_kwargs.update(**kwargs) + self_kwargs = dict(kwargs) super().__init__( **self_kwargs ) # need to include pydantic.BaseModel.__init__ - @abstractmethod + #@abstractmethod def _create_chat_completion( self, prompt: Optional[str] = None, @@ -62,6 +59,7 @@ def _create_chat_completion( str: Completion model response. """ # text + raise NotImplementedError() pass def _find_relevant_string(self, full_source: str, hypothesis: str) -> str: diff --git a/trulens_eval/trulens_eval/feedback/provider/endpoint/base.py b/trulens_eval/trulens_eval/feedback/provider/endpoint/base.py index 559596fae..1cf23dbb3 100644 --- a/trulens_eval/trulens_eval/feedback/provider/endpoint/base.py +++ b/trulens_eval/trulens_eval/feedback/provider/endpoint/base.py @@ -21,7 +21,7 @@ import requests from trulens_eval.schema import Cost -from trulens_eval.utils.pyschema import safe_getattr +from trulens_eval.utils.pyschema import WithClassInfo, safe_getattr from trulens_eval.utils.python import get_first_local_in_call_stack from trulens_eval.utils.python import locals_except from trulens_eval.utils.python import safe_hasattr @@ -65,7 +65,7 @@ def handle_classification(self, response: Any) -> None: self.handle(response) -class Endpoint(SerialModel, SingletonPerName): +class Endpoint(WithClassInfo, SerialModel, SingletonPerName): model_config: ClassVar[dict] = dict(arbitrary_types_allowed=True) @@ -153,22 +153,27 @@ def __new__(cls, *args, name: Optional[str] = None, **kwargs): name = name or cls.__name__ return super().__new__(cls, *args, name=name, **kwargs) - def __init__(self, *args, name: str, callback_class: Any, **kwargs): + def __init__(self, *args, name: str, callback_class: Any = None, **kwargs): """ API usage, pacing, and utilities for API endpoints. + + - `callback_class` should be set by subclass. """ if safe_hasattr(self, "rpm"): # already initialized via the SingletonPerName mechanism return + if callback_class is None: + raise ValueError("Endpoint has to be extended by class that can set `callback_class`.") + kwargs['name'] = name kwargs['callback_class'] = callback_class kwargs['global_callback'] = callback_class() kwargs['callback_name'] = f"callback_{name}" kwargs['pace_thread'] = Thread() # temporary kwargs['pace_thread'].daemon = True - super(SerialModel, self).__init__(*args, **kwargs) + super().__init__(*args, **kwargs) def keep_pace(): while True: diff --git a/trulens_eval/trulens_eval/feedback/provider/endpoint/bedrock.py b/trulens_eval/trulens_eval/feedback/provider/endpoint/bedrock.py index 2858f9d9d..61fc391a2 100644 --- a/trulens_eval/trulens_eval/feedback/provider/endpoint/bedrock.py +++ b/trulens_eval/trulens_eval/feedback/provider/endpoint/bedrock.py @@ -129,7 +129,7 @@ def handle_generation(self, response: Any) -> None: ) -class BedrockEndpoint(Endpoint, WithClassInfo): +class BedrockEndpoint(Endpoint): """ Bedrock endpoint. Instruments "completion" methods in bedrock.* classes. """ @@ -171,9 +171,6 @@ def __init__( kwargs['name'] = name kwargs['callback_class'] = BedrockCallback - # for WithClassInfo: - kwargs['obj'] = self - super().__init__(*args, **kwargs) # Note here was are instrumenting a method that outputs a function which diff --git a/trulens_eval/trulens_eval/feedback/provider/endpoint/hugs.py b/trulens_eval/trulens_eval/feedback/provider/endpoint/hugs.py index f1d0c3e6f..3a433bc17 100644 --- a/trulens_eval/trulens_eval/feedback/provider/endpoint/hugs.py +++ b/trulens_eval/trulens_eval/feedback/provider/endpoint/hugs.py @@ -31,7 +31,7 @@ def handle_classification(self, response: requests.Response) -> None: self.cost.n_classes += len(item) -class HuggingfaceEndpoint(Endpoint, WithClassInfo): +class HuggingfaceEndpoint(Endpoint): """ Huggingface. Instruments the requests.post method for requests to "https://api-inference.huggingface.co". @@ -75,9 +75,6 @@ def __init__(self, *args, **kwargs): if _check_key("HUGGINGFACE_API_KEY", silent=True, warn=True): kwargs['post_headers'] = get_huggingface_headers() - # for WithClassInfo: - kwargs['obj'] = self - super().__init__(*args, **kwargs) self._instrument_class(requests, "post") diff --git a/trulens_eval/trulens_eval/feedback/provider/endpoint/langchain.py b/trulens_eval/trulens_eval/feedback/provider/endpoint/langchain.py index 35a3d0875..7bf532013 100644 --- a/trulens_eval/trulens_eval/feedback/provider/endpoint/langchain.py +++ b/trulens_eval/trulens_eval/feedback/provider/endpoint/langchain.py @@ -23,7 +23,7 @@ def handle_generation(self, response: Any) -> None: super().handle_generation(response) -class LangchainEndpoint(Endpoint, WithClassInfo): +class LangchainEndpoint(Endpoint): """ Langchain endpoint. """ diff --git a/trulens_eval/trulens_eval/feedback/provider/endpoint/litellm.py b/trulens_eval/trulens_eval/feedback/provider/endpoint/litellm.py index b25b6f984..5532081e7 100644 --- a/trulens_eval/trulens_eval/feedback/provider/endpoint/litellm.py +++ b/trulens_eval/trulens_eval/feedback/provider/endpoint/litellm.py @@ -23,7 +23,7 @@ def handle_generation(self, response: Any) -> None: super().handle_generation(response) -class LiteLLMEndpoint(Endpoint, WithClassInfo): +class LiteLLMEndpoint(Endpoint): """ LiteLLM endpoint. Instruments "completion" methods in litellm.* classes. """ @@ -62,7 +62,4 @@ def __init__(self, *args, **kwargs): kwargs['name'] = "litellm" kwargs['callback_class'] = LiteLLMCallback - # for WithClassInfo: - kwargs['obj'] = self - super().__init__(*args, **kwargs) diff --git a/trulens_eval/trulens_eval/feedback/provider/endpoint/openai.py b/trulens_eval/trulens_eval/feedback/provider/endpoint/openai.py index 27e291235..c5d41da51 100644 --- a/trulens_eval/trulens_eval/feedback/provider/endpoint/openai.py +++ b/trulens_eval/trulens_eval/feedback/provider/endpoint/openai.py @@ -194,7 +194,7 @@ def handle_generation(self, response: LLMResult) -> None: ) -class OpenAIEndpoint(Endpoint, WithClassInfo): +class OpenAIEndpoint(Endpoint): """ OpenAI endpoint. Instruments "create" methods in openai client. """ @@ -285,6 +285,7 @@ def handle_wrapped_call( def __init__( self, rpm: float = DEFAULT_RPM, + name: str = "openai", client: Optional[Union[oai.OpenAI, oai.AzureOpenAI, OpenAIClient]] = None, **kwargs @@ -298,13 +299,13 @@ def __init__( return self_kwargs = dict( - name="openai", # for SingletonPerName + name=name, # for SingletonPerName rpm=rpm, # for Endpoint - callback_class=OpenAICallback, - obj=self, # for WithClassInfo: **kwargs ) + self_kwargs['callback_class'] = OpenAICallback + if CLASS_INFO in kwargs: del kwargs[CLASS_INFO] diff --git a/trulens_eval/trulens_eval/feedback/provider/openai.py b/trulens_eval/trulens_eval/feedback/provider/openai.py index c7e37db2e..b5419817f 100644 --- a/trulens_eval/trulens_eval/feedback/provider/openai.py +++ b/trulens_eval/trulens_eval/feedback/provider/openai.py @@ -20,7 +20,9 @@ class OpenAI(LLMProvider): # model_engine: str # LLMProvider - endpoint: Endpoint + # Endpoint cannot presently be serialized but is constructed in __init__ + # below so it is ok. + endpoint: Endpoint = pydantic.Field(exclude=True) def __init__( self, *args, endpoint=None, model_engine="gpt-3.5-turbo", **kwargs @@ -31,21 +33,20 @@ def __init__( """ Create an OpenAI Provider with out of the box feedback functions. - **Usage:** - ```python - from trulens_eval.feedback.provider.openai import OpenAI - openai_provider = OpenAI() - ``` + **Usage:** ```python from trulens_eval.feedback.provider.openai import + OpenAI openai_provider = OpenAI() ``` Args: model_engine (str): The OpenAI completion model. Defaults to - `gpt-3.5-turbo` - endpoint (Endpoint): Internal Usage for DB serialization + `gpt-3.5-turbo` + endpoint (Endpoint): Internal Usage for DB serialization. This + argument is intentionally ignored. """ # TODO: why was self_kwargs required here independently of kwargs? self_kwargs = dict() self_kwargs.update(**kwargs) self_kwargs['model_engine'] = model_engine + self_kwargs['endpoint'] = OpenAIEndpoint(*args, **kwargs) super().__init__( @@ -364,8 +365,9 @@ def moderation_harassment_threatening(self, text: str) -> float: class AzureOpenAI(OpenAI): - """Out of the box feedback functions calling AzureOpenAI APIs. - Has the same functionality as OpenAI out of the box feedback functions. + """ + Out of the box feedback functions calling AzureOpenAI APIs. Has the same + functionality as OpenAI out of the box feedback functions. """ # Sent to our openai client wrapper but need to keep here as well so that it @@ -401,7 +403,7 @@ def __init__(self, deployment_name: str, endpoint=None, **kwargs): - deployment_name (str, required): The name of the deployment. - endpoint (Optional[Endpoint]): Internal Usage for DB - serialization. + serialization. This argument is intentionally ignored. """ # Make a dict of args to pass to AzureOpenAI client. Remove any we use @@ -421,7 +423,8 @@ def __init__(self, deployment_name: str, endpoint=None, **kwargs): kwargs["client"] = OpenAIClient(client=oai.AzureOpenAI(**client_kwargs)) super().__init__( - endpoint=endpoint, **kwargs + endpoint=None, + **kwargs ) # need to include pydantic.BaseModel.__init__ def _create_chat_completion(self, *args, **kwargs): diff --git a/trulens_eval/trulens_eval/feedback/v2/provider/base.py b/trulens_eval/trulens_eval/feedback/v2/provider/base.py index b164e2fa2..7db03adac 100644 --- a/trulens_eval/trulens_eval/feedback/v2/provider/base.py +++ b/trulens_eval/trulens_eval/feedback/v2/provider/base.py @@ -14,16 +14,13 @@ # Level 4 feedback abstraction -class Provider(SerialModel, WithClassInfo): +class Provider(WithClassInfo, SerialModel): model_config: ClassVar[dict] = dict(arbitrary_types_allowed=True) endpoint: Optional[Endpoint] - def __init__(self, name: str = None, **kwargs): - # for WithClassInfo: - kwargs['obj'] = self - + def __init__(self, *args, name: Optional[str] = None, **kwargs): super().__init__(*args, **kwargs) @abstractmethod diff --git a/trulens_eval/trulens_eval/schema.py b/trulens_eval/trulens_eval/schema.py index 16ed7e2bf..e62349fc9 100644 --- a/trulens_eval/trulens_eval/schema.py +++ b/trulens_eval/trulens_eval/schema.py @@ -436,7 +436,7 @@ def __init__( TFeedbackResultFuture = Future -class FeedbackDefinition(SerialModel, WithClassInfo): +class FeedbackDefinition(WithClassInfo, SerialModel): # Serialized parts of a feedback function. The non-serialized parts are in # the feedback.py:Feedback class. @@ -481,9 +481,6 @@ def __init__( selectors = selectors or dict() - # for WithClassInfo: - kwargs['obj'] = self - super().__init__( feedback_definition_id="temporary", selectors=selectors, @@ -524,7 +521,7 @@ class FeedbackMode(str, Enum): DEFERRED = "deferred" -class AppDefinition(SerialModel, WithClassInfo): +class AppDefinition(WithClassInfo, SerialModel): # Serialized fields here whereas app.py:App contains # non-serialized fields. @@ -562,65 +559,6 @@ class AppDefinition(SerialModel, WithClassInfo): # whatever the user might want to see about the app. app_extra_json: JSON - @staticmethod - def continue_session( - app_definition_json: JSON, app: Any - ) -> 'AppDefinition': - # initial_app_loader: Optional[Callable] = None) -> 'AppDefinition': - """ - Create a copy of the json serialized app with the enclosed app being - initialized to its initial state before any records are produced (i.e. - blank memory). - """ - - app_definition_json['app'] = app - - cls = WithClassInfo.get_class(app_definition_json) - - return cls(**app_definition_json) - - @staticmethod - def new_session( - app_definition_json: JSON, - initial_app_loader: Optional[Callable] = None - ) -> 'AppDefinition': - """ - Create a copy of the json serialized app with the enclosed app being - initialized to its initial state before any records are produced (i.e. - blank memory). - """ - - serial_bytes_json: Optional[JSON] = app_definition_json[ - 'initial_app_loader_dump'] - - if initial_app_loader is None: - assert serial_bytes_json is not None, "Cannot create new session without `initial_app_loader`." - - serial_bytes = SerialBytes.model_validate(serial_bytes_json) - - app = dill.loads(serial_bytes.data)() - - else: - app = initial_app_loader() - data = dill.dumps(initial_app_loader, recurse=True) - serial_bytes = SerialBytes(data=data) - serial_bytes_json = serial_bytes.model_dump() - - app_definition_json['app'] = app - app_definition_json['initial_app_loader_dump'] = serial_bytes_json - - cls: Type[App] = WithClassInfo.get_class(app_definition_json) - - return cls.model_validate_json(app_definition_json) - - def jsonify_extra(self, content): - # Called by jsonify for us to add any data we might want to add to the - # serialization of `app`. - if self.app_extra_json is not None: - content['app'].update(self.app_extra_json) - - return content - def __init__( self, app_id: Optional[AppID] = None, @@ -637,10 +575,7 @@ def __init__( kwargs['tags'] = "" kwargs['metadata'] = {} kwargs['app_extra_json'] = app_extra_json or dict() - - # for WithClassInfo: - kwargs['obj'] = self - + super().__init__(**kwargs) if app_id is None: @@ -692,6 +627,66 @@ def __init__( f"Some trulens features may not be available: {e}" ) + + @staticmethod + def continue_session( + app_definition_json: JSON, app: Any + ) -> 'AppDefinition': + # initial_app_loader: Optional[Callable] = None) -> 'AppDefinition': + """ + Create a copy of the json serialized app with the enclosed app being + initialized to its initial state before any records are produced (i.e. + blank memory). + """ + + app_definition_json['app'] = app + + cls = WithClassInfo.get_class(app_definition_json) + + return cls(**app_definition_json) + + @staticmethod + def new_session( + app_definition_json: JSON, + initial_app_loader: Optional[Callable] = None + ) -> 'AppDefinition': + """ + Create a copy of the json serialized app with the enclosed app being + initialized to its initial state before any records are produced (i.e. + blank memory). + """ + + serial_bytes_json: Optional[JSON] = app_definition_json[ + 'initial_app_loader_dump'] + + if initial_app_loader is None: + assert serial_bytes_json is not None, "Cannot create new session without `initial_app_loader`." + + serial_bytes = SerialBytes.model_validate(serial_bytes_json) + + app = dill.loads(serial_bytes.data)() + + else: + app = initial_app_loader() + data = dill.dumps(initial_app_loader, recurse=True) + serial_bytes = SerialBytes(data=data) + serial_bytes_json = serial_bytes.model_dump() + + app_definition_json['app'] = app + app_definition_json['initial_app_loader_dump'] = serial_bytes_json + + cls: Type[App] = WithClassInfo.get_class(app_definition_json) + + return cls.model_validate_json(app_definition_json) + + def jsonify_extra(self, content): + # Called by jsonify for us to add any data we might want to add to the + # serialization of `app`. + if self.app_extra_json is not None: + content['app'].update(self.app_extra_json) + + return content + @staticmethod def get_loadable_apps(): # EXPERIMENTAL diff --git a/trulens_eval/trulens_eval/tru_custom_app.py b/trulens_eval/trulens_eval/tru_custom_app.py index 8ff56f5fd..3be376629 100644 --- a/trulens_eval/trulens_eval/tru_custom_app.py +++ b/trulens_eval/trulens_eval/tru_custom_app.py @@ -459,8 +459,6 @@ def __getattr__(self, __name: str) -> Any: # A message for cases where a user calls something that the wrapped # app has but we do not wrap yet. - print(__name) - if safe_hasattr(self.app, __name): return RuntimeError( f"TruCustomApp has no attribute {__name} but the wrapped app ({type(self.app)}) does. ", diff --git a/trulens_eval/trulens_eval/utils/pyschema.py b/trulens_eval/trulens_eval/utils/pyschema.py index d3a5a8560..061feb109 100644 --- a/trulens_eval/trulens_eval/utils/pyschema.py +++ b/trulens_eval/trulens_eval/utils/pyschema.py @@ -22,10 +22,9 @@ from pprint import PrettyPrinter from types import ModuleType from typing import Any, Callable, Dict, Optional, Sequence, Tuple +import warnings -import dill import pydantic -from pydantic import Field from trulens_eval.utils.python import safe_hasattr from trulens_eval.utils.serial import SerialModel @@ -366,6 +365,7 @@ def of_object( #if isinstance(cls, type): # sig = _safe_init_sig(cls) #else: + sig = _safe_init_sig(cls.__call__) b = sig.bind(*init_args, **init_kwargs) @@ -385,25 +385,32 @@ def load(self) -> object: ) cls = self.cls.load() - sig = _safe_init_sig(cls) - if CLASS_INFO in sig.parameters and CLASS_INFO not in self.init_bindings.kwargs: - extra_kwargs = {CLASS_INFO: self.cls} + if issubclass(cls, pydantic.BaseModel): + # For pydantic Models, use model_validate to reconstruct object: + return cls.model_validate(self.init_bindings.kwargs) + else: - extra_kwargs = {} - try: - bindings = self.init_bindings.load(sig, extra_kwargs=extra_kwargs) + sig = _safe_init_sig(cls) - except Exception as e: - msg = f"Error binding constructor args for object:\n" - msg += str(e) + "\n" - msg += f"\tobj={self}\n" - msg += f"\targs={self.init_bindings.args}\n" - msg += f"\tkwargs={self.init_bindings.kwargs}\n" - raise type(e)(msg) + if CLASS_INFO in sig.parameters and CLASS_INFO not in self.init_bindings.kwargs: + extra_kwargs = {CLASS_INFO: self.cls} + else: + extra_kwargs = {} - return cls(*bindings.args, **bindings.kwargs) + try: + bindings = self.init_bindings.load(sig, extra_kwargs=extra_kwargs) + + except Exception as e: + msg = f"Error binding constructor args for object:\n" + msg += str(e) + "\n" + msg += f"\tobj={self}\n" + msg += f"\targs={self.init_bindings.args}\n" + msg += f"\tkwargs={self.init_bindings.kwargs}\n" + raise type(e)(msg) + + return cls(*bindings.args, **bindings.kwargs) class Bindings(SerialModel): @@ -423,6 +430,7 @@ def _handle_providers_load(self): ## But should not be a user supplied input kwarg. # `groundedness_provider` and `provider` explanation ## The rest of the providers need to be instantiated, but are currently in circular dependency if done from util.py + if 'summarize_provider' in self.kwargs: del self.kwargs['summarize_provider'] if 'groundedness_provider' in self.kwargs: @@ -432,7 +440,10 @@ def _handle_providers_load(self): def load(self, sig: inspect.Signature, extra_args=(), extra_kwargs={}): - self._handle_providers_load() + # Disabling this hack as we now have different providers that may need + # to be selected from (i.e. OpenAI vs AzureOpenAI). + + # self._handle_providers_load() return sig.bind( *(self.args + extra_args), **self.kwargs, **extra_kwargs @@ -576,28 +587,55 @@ class WithClassInfo(pydantic.BaseModel): # Using this odd key to not pollute attribute names in whatever class we mix # this into. Should be the same as CLASS_INFO. - tru_class_info: Class = Field(exclude=False) - - @classmethod - def model_validate(cls, obj, **kwargs): - if isinstance(obj, dict) and CLASS_INFO in obj: - - clsinfo = Class.model_validate(obj[CLASS_INFO]) - clsloaded = clsinfo.load() - - # NOTE(piotrm): even though we have a more specific class than - # AppDefinition, we load it as AppDefinition due to serialization - # issues in the wrapped app. Keeping it as AppDefinition means `app` - # field is just json. - from trulens_eval.schema import AppDefinition + tru_class_info: Class # = Field(None, exclude=False) + + # NOTE(piotrm): for some reason, model_validate is not called for nested + # models but the method decorated as such below is called. We use this to + # load an object which includes our class information instead of using + # pydantic for this loading as it would always load the object as per its + # declared field. For example, `Provider` includes `endpoint: Endpoint` but + # we want to load one of the `Endpoint` subclasses. We add the subclass + # information using `WithClassInfo` meaning we can then use this method + # below to load the subclass. Pydantic would only give us `Endpoint`, the + # parent class. + @pydantic.model_validator(mode='before') + @staticmethod + def load(obj, **kwargs): + + if not isinstance(obj, dict): + return obj - if issubclass(clsloaded, AppDefinition): - return super(cls, AppDefinition).model_validate(obj) - else: - return super(cls, clsloaded).model_validate(obj) + if CLASS_INFO not in obj: + raise ValueError("No class info present in object.") - else: - return super().model_validate(obj) + clsinfo = Class.model_validate(obj[CLASS_INFO]) + try: + # If class cannot be loaded, usually because it is not importable, + # return obj as is. + cls = clsinfo.load() + except RuntimeError: + return obj + + validated = dict() + for k, finfo in cls.model_fields.items(): + typ = finfo.annotation + val = finfo.get_default(call_default_factory=True) + + if k in obj: + val = obj[k] + + if isinstance(typ, type) \ + and issubclass(typ, WithClassInfo) \ + and isinstance(val, dict) and CLASS_INFO in val: + subcls = Class.model_validate(val[CLASS_INFO]).load() + val = subcls.model_validate(val) + + validated[k] = val + + # Note that the rest of the validation/conversions for things which are + # not serialized WithClassInfo will be done by pydantic after we return + # this: + return validated def __init__( self, @@ -607,6 +645,15 @@ def __init__( cls: Optional[type] = None, **kwargs ): + if obj is not None: + warnings.warn( + "`obj` does not need to be provided to WithClassInfo any more", + DeprecationWarning + ) + + if obj is None: + obj = self + if obj is not None: cls = type(obj) diff --git a/trulens_eval/trulens_eval/utils/serial.py b/trulens_eval/trulens_eval/utils/serial.py index ddebdb810..ac55db75d 100644 --- a/trulens_eval/trulens_eval/utils/serial.py +++ b/trulens_eval/trulens_eval/utils/serial.py @@ -74,16 +74,8 @@ def model_dump(self, **kwargs): return jsonify(self, **kwargs) - @classmethod - def model_validate(cls, obj, **kwargs): - # import hierarchy circle here - from trulens_eval.utils.pyschema import CLASS_INFO - from trulens_eval.utils.pyschema import WithClassInfo - - if isinstance(obj, Dict) and CLASS_INFO in obj: - return WithClassInfo.model_validate(obj, **kwargs) - - return super(SerialModel, cls).model_validate(obj, **kwargs) + # NOTE(piotrm): regaring model_validate: custom deserialization is done in + # WithClassInfo class but only for classes that mix it in. def update(self, **d): for k, v in d.items():