WithClassInfo bugfixes (#741)
* add instructions and text wrapping

* format

* debugging

* making obj arg no longer required

* remove obj and add documentation for WithClassInfo

* remove IPython from most notebooks and organize imports

* fix test errors

* forgot warning

---------

Co-authored-by: Josh Reini <[email protected]>
piotrm0 and joshreini1 authored Jan 4, 2024
1 parent caa9205 commit 7343252
Showing 33 changed files with 626 additions and 249 deletions.
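The core of the change is that `WithClassInfo` no longer requires an explicit `obj` argument: the mixin can record the concrete class itself at construction time, which is what lets a serialized object be re-validated into the correct subclass (the dev notebook diff below prototypes this with a pydantic `model_validator`). As a rough orientation only, here is a minimal sketch of that idea, assuming pydantic v2; the names `WithClassInfoSketch` and `AppLike` are invented for illustration and this is not the actual trulens_eval implementation.

```python
# Minimal sketch of the idea behind this commit, NOT the trulens_eval code:
# the mixin records the concrete class at construction time, so callers no
# longer need to pass an explicit `obj` argument for class info to be stored.
from typing import Any

from pydantic import BaseModel


class WithClassInfoSketch(BaseModel):  # hypothetical stand-in for WithClassInfo
    class_info: Any = None  # filled in automatically in __init__

    def __init__(self, **kwargs: Any):
        # type(self) is the concrete subclass; no extra argument is needed.
        kwargs.setdefault("class_info", type(self).__name__)
        super().__init__(**kwargs)


class AppLike(WithClassInfoSketch):  # hypothetical subclass for illustration
    app_id: str


dump = AppLike(app_id="app1").model_dump()
assert dump["class_info"] == "AppLike"  # class info captured without `obj`
```

Relatedly, the base reorder in `feedback_function_guide.md` (now `WithClassInfo, SerialModel, ABC`) puts the mixin first in the MRO, so its overrides resolve ahead of `SerialModel`'s.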
2 changes: 1 addition & 1 deletion docs/trulens_eval/feedback_function_guide.md
@@ -298,7 +298,7 @@ class Record(SerialModel):
For an App:

```python
class AppDefinition(SerialModel, WithClassInfo, ABC):
class AppDefinition(WithClassInfo, SerialModel, ABC):
...

app_id: AppID
340 changes: 335 additions & 5 deletions trulens_eval/examples/experimental/dev_notebook.ipynb
@@ -16,7 +16,8 @@
"outputs": [],
"source": [
"# ! pip uninstall -y trulens_eval # ==0.18.2\n",
"# ! pip list | grep trulens"
"# ! pip list | grep trulens\n",
"# ! pip install --upgrade pydantic"
]
},
{
@@ -91,6 +92,124 @@
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Pydantic testing"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import Any\n",
"\n",
"from pydantic import BaseModel\n",
"from pydantic import Field\n",
"from pydantic import field_validator\n",
"from pydantic import model_validator\n",
"from pydantic import PydanticUndefinedAnnotation\n",
"from pydantic import SerializeAsAny\n",
"from pydantic import ValidationInfo\n",
"from pydantic import validator\n",
"from pydantic_core import PydanticUndefined\n",
"\n",
"\n",
"class CustomLoader(BaseModel):\n",
" cls: Any\n",
"\n",
" def __init__(self, *args, **kwargs):\n",
" kwargs['cls'] = type(self)\n",
" super().__init__(*args, **kwargs)\n",
"\n",
" @model_validator(mode='before')\n",
" @staticmethod\n",
" def my_model_validate(obj, info: ValidationInfo):\n",
" if not isinstance(obj, dict):\n",
" return obj\n",
"\n",
" cls = obj['cls']\n",
" # print(cls, subcls, obj, info)\n",
"\n",
" validated = dict()\n",
" for k, finfo in cls.model_fields.items():\n",
" print(k, finfo)\n",
" typ = finfo.annotation\n",
" val = finfo.get_default()\n",
"\n",
" if val is PydanticUndefined:\n",
" val = obj[k]\n",
"\n",
" print(typ, type(typ))\n",
" if isinstance(typ, type) \\\n",
" and issubclass(typ, CustomLoader) \\\n",
" and isinstance(val, dict) and \"cls\" in val:\n",
" subcls = val['cls']\n",
" val = subcls.model_validate(val)\n",
" \n",
" validated[k] = val\n",
" \n",
" return validated\n",
"\n",
"class SubModel(CustomLoader):\n",
" sm: int = 3\n",
"\n",
"class Model(CustomLoader):\n",
" m: int = 2\n",
" sub: SubModel\n",
"\n",
"class SubSubModelA(SubModel):\n",
" ssma: int = 42\n",
"\n",
"class SubModelA(SubModel):\n",
" sma: int = 0\n",
" subsub: SubSubModelA\n",
"\n",
"class SubModelB(SubModel):\n",
" smb: int = 1\n",
"\n",
"c = Model(sub=SubModelA(subsub=SubSubModelA()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"c"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"c.model_dump()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"Model.model_validate({'cls': Model, 'm': 2, 'sub': {'cls': SubModelA, 'sma':3, 'subsub': {'cls': SubSubModelA, 'ssma': 42}}})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"Model.model_validate({'c': 2, 'sub': {}})"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -202,6 +321,36 @@
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from llama_index.llms import AzureOpenAI as AzureOpenAIChat\n",
"import os\n",
"\n",
"gpt_35_turbo = AzureOpenAIChat(\n",
" deployment_name=\"gpt-35-turbo\",\n",
" model=\"gpt-35-turbo\",\n",
" api_key=os.getenv(\"AZURE_OPENAI_API_KEY\"),\n",
" api_version=\"2023-05-15\",\n",
" model_version=\"0613\",\n",
" temperature=0.0,\n",
")\n",
"c = gpt_35_turbo._get_client()\n",
"gpt_35_turbo._get_credential_kwargs()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"c.base_url"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -210,7 +359,189 @@
"source": [
"import os\n",
"from trulens_eval import feedback\n",
"azopenai = feedback.AzureOpenAI(deployment_name=os.environ['AZURE_OPENAI_DEPLYOMENT_NAME'])"
"azopenai = feedback.AzureOpenAI(\n",
" deployment_name=os.environ['AZURE_OPENAI_DEPLOYMENT_NAME']\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"azopenai.endpoint.client.client_kwargs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# azopenai.relevance(prompt=\"Where is Germany?\", response=\"Germany is in Europe.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# reval = feedback.AzureOpenAI.model_validate(azopenai.model_dump())\n",
"# reval.relevance(prompt=\"Where is Germany?\", response=\"Poland is in Europe.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"azureOpenAI = azopenai\n",
"\n",
"from trulens_eval.feedback.provider import AzureOpenAI\n",
"from trulens_eval.feedback import Groundedness, GroundTruthAgreement\n",
"from trulens_eval import TruLlama, Feedback\n",
"from trulens_eval.app import App\n",
"import numpy as np\n",
"# Initialize provider class\n",
"#azureOpenAI = AzureOpenAI(deployment_name=\"gpt-35-turbo\")\n",
"\n",
"grounded = Groundedness(groundedness_provider=azureOpenAI)\n",
"# Define a groundedness feedback function\n",
"f_groundedness = (\n",
" Feedback(grounded.groundedness_measure_with_cot_reasons)\n",
" .on_input_output()\n",
" .aggregate(grounded.grounded_statements_aggregator)\n",
")\n",
"\n",
"# Question/answer relevance between overall question and answer.\n",
"f_answer_relevance = Feedback(azureOpenAI.relevance_with_cot_reasons).on_input_output()\n",
"# Question/statement relevance between question and each context chunk.\n",
"f_context_relevance = (\n",
" Feedback(azureOpenAI.qs_relevance_with_cot_reasons)\n",
" .on_input_output()\n",
" .aggregate(np.mean)\n",
")\n",
"\n",
"# GroundTruth for comparing the Answer to the Ground-Truth Answer\n",
"#ground_truth_collection = GroundTruthAgreement(golden_set, provider=azureOpenAI)\n",
"#f_answer_correctness = (\n",
"# Feedback(ground_truth_collection.agreement_measure)\n",
"# .on_input_output()\n",
"#)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"f_groundedness.model_dump()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from trulens_eval.utils.pyschema import WithClassInfo\n",
"from trulens_eval.utils.serial import SerialModel\n",
"from trulens_eval.feedback.groundedness import Groundedness\n",
"import pydantic"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"Groundedness.model_validate(grounded.model_dump())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"f2 = Feedback.model_validate(f_groundedness.model_dump())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"f2.implementation.obj.init_bindings.kwargs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"f2.imp"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def test_serial(f):\n",
" print(\"Before serialization:\")\n",
" print(f.imp(\"Where is Poland?\", \"Poland is in Europe\"))\n",
" f_dump = f.model_dump()\n",
" f = Feedback.model_validate(f_dump)\n",
" print(\"After serialization:\")\n",
" print(f.imp(\"Where is Poland?\", \"Germany is in Europe\"))\n",
" return f\n",
"\n",
"f2 = test_serial(f_groundedness)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"f_groundedness.imp"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"f2.imp"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"f_answer_relevance = Feedback(azureOpenAI.relevance_with_cot_reasons).on_input_output()\n",
"\n",
"# test without serialization\n",
"print(f_answer_relevance.imp(prompt=\"Where is Germany?\", response=\"Germany is in Europe.\"))\n",
"\n",
"# serialize/deserialize\n",
"f_answer_relevance2 = Feedback.model_validate(f_answer_relevance.model_dump())\n",
"\n",
"# test after deserialization\n",
"print(f_answer_relevance2.imp(prompt=\"Where is Germany?\", response=\"Poland is in Europe.\"))"
]
},
{
@@ -219,7 +550,7 @@
"metadata": {},
"outputs": [],
"source": [
"azopenai.relevance(prompt=\"Where is Germany?\", response=\"Germany is in Europe.\")"
"fr = feedback.Feedback.model_validate(f.model_dump())"
]
},
{
Expand All @@ -228,8 +559,7 @@
"metadata": {},
"outputs": [],
"source": [
"reval = feedback.AzureOpenAI.model_validate(azopenai.model_dump())\n",
"reval.relevance(prompt=\"Where is Germany?\", response=\"Poland is in Europe.\")"
"fr.imp(prompt=\"Where is Germany?\", response=\"Germany is in Europe.\")"
]
},
{