From 96d6765635ca5aad32408900e96f39544f4bee96 Mon Sep 17 00:00:00 2001 From: Josh Reini <60949774+joshreini1@users.noreply.github.com> Date: Mon, 29 Apr 2024 18:16:57 -0400 Subject: [PATCH] Show OSS models (and tracking) in LiteLLM application (#1109) * oss models in app * import change * fix dims issue, pass embeddings manually * keep some output * remove settings print * link model list for cost tracking --- .../models/litellm_quickstart.ipynb | 1249 ++++++++++++++++- 1 file changed, 1210 insertions(+), 39 deletions(-) diff --git a/trulens_eval/examples/expositional/models/litellm_quickstart.ipynb b/trulens_eval/examples/expositional/models/litellm_quickstart.ipynb index faf8f6b13..2814798c4 100644 --- a/trulens_eval/examples/expositional/models/litellm_quickstart.ipynb +++ b/trulens_eval/examples/expositional/models/litellm_quickstart.ipynb @@ -10,28 +10,32 @@ "\n", "[LiteLLM](https://github.com/BerriAI/litellm) is a consistent way to access 100+ LLMs such as those from OpenAI, HuggingFace, Anthropic, and Cohere. Using LiteLLM dramatically expands the model availability for feedback functions. Please be cautious in trusting the results of evaluations from models that have not yet been tested.\n", "\n", - "Specifically in this example we'll show how to use TogetherAI, but the LiteLLM provider can be used to run feedback functions using any LiteLLM suppported model.\n", + "Specifically in this example we'll show how to use TogetherAI, but the LiteLLM provider can be used to run feedback functions using any LiteLLM suppported model. We'll also use Mistral for the embedding and completion model also accessed via LiteLLM. The token usage and cost metrics for models used by LiteLLM will be also tracked by TruLens.\n", "\n", - "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/truera/trulens/blob/main/trulens_eval/examples/expositional/models/litellm_quickstart.ipynb)" + "Note: LiteLLM costs are tracked for models included in this [litellm community-maintained list](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json).\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/truera/trulens/blob/main/trulens_eval/examples/expositional/models/litellm_quickstart.ipynb)\n", + "\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "# ! pip install trulens_eval chromadb openai" + "# ! pip install trulens_eval chromadb mistralai" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import os\n", - "os.environ[\"TOGETHERAI_API_KEY\"] = \"...\"" + "os.environ[\"TOGETHERAI_API_KEY\"] = \"...\"\n", + "os.environ['MISTRAL_API_KEY'] = \"...\"" ] }, { @@ -45,7 +49,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -69,35 +73,1049 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ - "from openai import OpenAI\n", - "oai_client = OpenAI()\n", + "from litellm import embedding\n", + "import os\n", "\n", - "oai_client.embeddings.create(\n", - " model=\"text-embedding-ada-002\",\n", - " input=university_info\n", - " )" + "embedding_response = embedding(\n", + " model=\"mistral/mistral-embed\",\n", + " input=university_info,\n", + ")" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[-0.0302734375,\n", + " 0.01617431640625,\n", + " 0.028350830078125,\n", + " -0.017974853515625,\n", + " 0.05322265625,\n", + " -0.01155853271484375,\n", + " 0.053466796875,\n", + " 0.0017957687377929688,\n", + " -0.00824737548828125,\n", + " 0.0037555694580078125,\n", + " -0.037750244140625,\n", + " 0.0171966552734375,\n", + " 0.0099029541015625,\n", + " 0.0010271072387695312,\n", + " -0.06402587890625,\n", + " 0.023681640625,\n", + " -0.0029296875,\n", + " 0.0113677978515625,\n", + " 0.04144287109375,\n", + " 0.01119232177734375,\n", + " -0.031890869140625,\n", + " -0.03778076171875,\n", + " -0.0233917236328125,\n", + " 0.0240020751953125,\n", + " -0.01018524169921875,\n", + " -0.0157623291015625,\n", + " -0.021636962890625,\n", + " -0.0692138671875,\n", + " -0.04681396484375,\n", + " -0.00518035888671875,\n", + " 0.0244140625,\n", + " -0.0034770965576171875,\n", + " 0.0118560791015625,\n", + " 0.0124969482421875,\n", + " -0.003833770751953125,\n", + " -0.0194244384765625,\n", + " -0.00225830078125,\n", + " -0.04669189453125,\n", + " 0.0265350341796875,\n", + " -0.0079803466796875,\n", + " -0.02178955078125,\n", + " -0.0103302001953125,\n", + " -0.0426025390625,\n", + " -0.034881591796875,\n", + " 0.0002834796905517578,\n", + " -0.037384033203125,\n", + " -0.0142364501953125,\n", + " -0.036956787109375,\n", + " -0.0185699462890625,\n", + " -0.0213470458984375,\n", + " 0.004390716552734375,\n", + " 0.00279998779296875,\n", + " 0.0300445556640625,\n", + " -0.0154266357421875,\n", + " -0.00665283203125,\n", + " 0.021514892578125,\n", + " 0.03765869140625,\n", + " -0.0235595703125,\n", + " -0.048248291015625,\n", + " 0.042388916015625,\n", + " -0.034332275390625,\n", + " -0.026947021484375,\n", + " -0.05242919921875,\n", + " -0.001308441162109375,\n", + " 0.0234375,\n", + " 0.003143310546875,\n", + " 0.00907135009765625,\n", + " -0.042236328125,\n", + " -0.005313873291015625,\n", + " 0.036529541015625,\n", + " 0.0338134765625,\n", + " 0.00955963134765625,\n", + " 3.153085708618164e-05,\n", + " 0.027801513671875,\n", + " -0.041839599609375,\n", + " -0.023712158203125,\n", + " 0.0246429443359375,\n", + " 0.01393890380859375,\n", + " 0.04193115234375,\n", + " -0.01053619384765625,\n", + " -0.042999267578125,\n", + " -0.0033550262451171875,\n", + " 0.06304931640625,\n", + " -0.060699462890625,\n", + " -0.00756072998046875,\n", + " 0.0223236083984375,\n", + " 0.0115203857421875,\n", + " 0.0038013458251953125,\n", + " -0.003421783447265625,\n", + " 0.00727081298828125,\n", + " 0.053741455078125,\n", + " -0.0287017822265625,\n", + " 0.005245208740234375,\n", + " -0.018463134765625,\n", + " 0.04534912109375,\n", + " 0.05615234375,\n", + " -0.024261474609375,\n", + " -0.041168212890625,\n", + " -0.001064300537109375,\n", + " -0.01384735107421875,\n", + " -0.004367828369140625,\n", + " -0.0225982666015625,\n", + " 0.056854248046875,\n", + " -0.014190673828125,\n", + " 0.04400634765625,\n", + " -0.0184783935546875,\n", + " -0.006565093994140625,\n", + " -0.01007080078125,\n", + " 0.0005826950073242188,\n", + " -0.0254364013671875,\n", + " -0.09381103515625,\n", + " -0.035186767578125,\n", + " 0.02978515625,\n", + " -0.0595703125,\n", + " -0.033935546875,\n", + " 0.0074615478515625,\n", + " -0.034210205078125,\n", + " 0.0247955322265625,\n", + " -0.057159423828125,\n", + " -0.02911376953125,\n", + " 0.033538818359375,\n", + " 0.002536773681640625,\n", + " 0.00922393798828125,\n", + " 0.038787841796875,\n", + " -0.036834716796875,\n", + " -0.05084228515625,\n", + " -0.0016632080078125,\n", + " 0.0158538818359375,\n", + " -0.0032291412353515625,\n", + " -0.004863739013671875,\n", + " -0.0186614990234375,\n", + " -0.0272674560546875,\n", + " -0.036834716796875,\n", + " -0.01058197021484375,\n", + " -0.018585205078125,\n", + " -0.0009102821350097656,\n", + " 0.03826904296875,\n", + " -0.0099029541015625,\n", + " -0.0228118896484375,\n", + " 0.01885986328125,\n", + " 0.00411224365234375,\n", + " -0.018829345703125,\n", + " -0.02911376953125,\n", + " -0.0002244710922241211,\n", + " -0.04461669921875,\n", + " -0.0006680488586425781,\n", + " 0.0028514862060546875,\n", + " 0.030670166015625,\n", + " -0.037384033203125,\n", + " -0.004169464111328125,\n", + " 0.01107025146484375,\n", + " 0.0460205078125,\n", + " 0.059967041015625,\n", + " -0.0139617919921875,\n", + " -0.004695892333984375,\n", + " -0.0323486328125,\n", + " 0.01361846923828125,\n", + " -0.0302886962890625,\n", + " 0.014190673828125,\n", + " 0.00502777099609375,\n", + " -0.01064300537109375,\n", + " 0.0057830810546875,\n", + " -0.00299835205078125,\n", + " 0.0418701171875,\n", + " -0.0187225341796875,\n", + " -0.01285552978515625,\n", + " -0.0268707275390625,\n", + " 0.032318115234375,\n", + " -0.02362060546875,\n", + " 0.0262603759765625,\n", + " 0.060333251953125,\n", + " 0.00931549072265625,\n", + " 0.036956787109375,\n", + " 0.07586669921875,\n", + " -0.0256500244140625,\n", + " -0.0191650390625,\n", + " 0.005096435546875,\n", + " -0.0052337646484375,\n", + " 0.048370361328125,\n", + " 0.0379638671875,\n", + " -0.00521087646484375,\n", + " -0.0275421142578125,\n", + " 0.034271240234375,\n", + " -0.019134521484375,\n", + " -0.0124969482421875,\n", + " -0.02215576171875,\n", + " -0.0340576171875,\n", + " -0.02752685546875,\n", + " -0.01617431640625,\n", + " 0.01751708984375,\n", + " 0.0030117034912109375,\n", + " -0.071044921875,\n", + " -0.01113128662109375,\n", + " -0.0064697265625,\n", + " -0.0304412841796875,\n", + " 0.0318603515625,\n", + " 0.0262908935546875,\n", + " -0.0122222900390625,\n", + " 0.026336669921875,\n", + " 0.00785064697265625,\n", + " 0.0111846923828125,\n", + " -0.004241943359375,\n", + " -0.01486968994140625,\n", + " 0.056488037109375,\n", + " 0.0180511474609375,\n", + " -0.0090484619140625,\n", + " -0.00653839111328125,\n", + " -0.00824737548828125,\n", + " 0.038055419921875,\n", + " -0.00913238525390625,\n", + " -0.0241241455078125,\n", + " 0.00873565673828125,\n", + " -0.0291595458984375,\n", + " -0.009033203125,\n", + " -0.0278167724609375,\n", + " -0.0114288330078125,\n", + " 0.018646240234375,\n", + " -0.006195068359375,\n", + " 0.002780914306640625,\n", + " 0.01448822021484375,\n", + " 0.0143890380859375,\n", + " -0.0758056640625,\n", + " 0.01200103759765625,\n", + " 0.01334381103515625,\n", + " 0.013946533203125,\n", + " 0.0355224609375,\n", + " 0.018829345703125,\n", + " -0.01739501953125,\n", + " 0.006412506103515625,\n", + " 0.0042572021484375,\n", + " 0.03204345703125,\n", + " -0.01108551025390625,\n", + " -0.0184478759765625,\n", + " 0.0247955322265625,\n", + " -0.0189208984375,\n", + " -0.020111083984375,\n", + " 0.0215301513671875,\n", + " 0.01195526123046875,\n", + " 0.006072998046875,\n", + " -0.0030059814453125,\n", + " -0.0210418701171875,\n", + " 0.02227783203125,\n", + " -0.02288818359375,\n", + " -0.00208282470703125,\n", + " 0.012664794921875,\n", + " -0.01303863525390625,\n", + " 0.03643798828125,\n", + " 0.01007080078125,\n", + " 0.003108978271484375,\n", + " 0.046905517578125,\n", + " -0.056060791015625,\n", + " -0.0241851806640625,\n", + " -0.04766845703125,\n", + " -0.0035858154296875,\n", + " -0.05755615234375,\n", + " -0.032135009765625,\n", + " -0.03448486328125,\n", + " -0.0491943359375,\n", + " 0.0635986328125,\n", + " -0.0217132568359375,\n", + " -0.0192108154296875,\n", + " -0.0305938720703125,\n", + " 0.0301361083984375,\n", + " -0.0230560302734375,\n", + " 0.029693603515625,\n", + " 0.01239013671875,\n", + " -0.03509521484375,\n", + " -0.037109375,\n", + " 0.108642578125,\n", + " -0.007785797119140625,\n", + " -0.01291656494140625,\n", + " -0.0069427490234375,\n", + " 0.035430908203125,\n", + " 0.01904296875,\n", + " 0.031219482421875,\n", + " -0.0257110595703125,\n", + " -0.0087738037109375,\n", + " 0.047088623046875,\n", + " 0.00843048095703125,\n", + " -0.01224517822265625,\n", + " -0.0146331787109375,\n", + " 0.0223846435546875,\n", + " 0.00943756103515625,\n", + " 0.053131103515625,\n", + " -0.060943603515625,\n", + " 0.00433349609375,\n", + " 0.01392364501953125,\n", + " 0.0212860107421875,\n", + " -0.0171661376953125,\n", + " -0.07049560546875,\n", + " -0.00359344482421875,\n", + " 0.035614013671875,\n", + " 0.003993988037109375,\n", + " -0.007427215576171875,\n", + " -0.0180206298828125,\n", + " -0.0101165771484375,\n", + " 0.02435302734375,\n", + " 0.02496337890625,\n", + " -0.021575927734375,\n", + " 0.049285888671875,\n", + " 0.0126800537109375,\n", + " -0.00266265869140625,\n", + " -0.0282745361328125,\n", + " 0.0247802734375,\n", + " 0.01336669921875,\n", + " -0.04107666015625,\n", + " -0.06805419921875,\n", + " -0.0227813720703125,\n", + " 0.0113525390625,\n", + " -0.0655517578125,\n", + " -0.0281982421875,\n", + " 0.02325439453125,\n", + " 0.00467681884765625,\n", + " -0.002475738525390625,\n", + " 0.005615234375,\n", + " -0.0054168701171875,\n", + " -0.051483154296875,\n", + " -0.0445556640625,\n", + " 0.02374267578125,\n", + " -0.0504150390625,\n", + " -0.059326171875,\n", + " -0.00893402099609375,\n", + " 0.03741455078125,\n", + " 0.0238189697265625,\n", + " 0.002716064453125,\n", + " 0.01123809814453125,\n", + " -0.0155487060546875,\n", + " -0.0300445556640625,\n", + " 0.0185394287109375,\n", + " -0.00966644287109375,\n", + " -0.0026645660400390625,\n", + " -0.033416748046875,\n", + " -0.0094146728515625,\n", + " 0.0112152099609375,\n", + " 0.013397216796875,\n", + " 0.00481414794921875,\n", + " 0.03399658203125,\n", + " 0.0386962890625,\n", + " -0.05609130859375,\n", + " -0.0020580291748046875,\n", + " -0.003955841064453125,\n", + " -0.01514434814453125,\n", + " -0.004581451416015625,\n", + " -0.0218505859375,\n", + " -0.0191650390625,\n", + " 0.0222320556640625,\n", + " -0.0138092041015625,\n", + " -0.003833770751953125,\n", + " 0.01146697998046875,\n", + " 0.0294342041015625,\n", + " 0.01666259765625,\n", + " -0.044677734375,\n", + " 0.0010833740234375,\n", + " 0.06488037109375,\n", + " -0.0231475830078125,\n", + " 0.11651611328125,\n", + " -0.0477294921875,\n", + " -0.0235595703125,\n", + " 0.009307861328125,\n", + " 0.04229736328125,\n", + " 0.010162353515625,\n", + " 0.0154876708984375,\n", + " 0.019805908203125,\n", + " 0.002567291259765625,\n", + " -0.0321044921875,\n", + " 0.03204345703125,\n", + " -0.058074951171875,\n", + " 0.01092529296875,\n", + " 0.006603240966796875,\n", + " -0.0210113525390625,\n", + " -0.01084136962890625,\n", + " 0.004161834716796875,\n", + " 0.0247955322265625,\n", + " 0.061248779296875,\n", + " 0.038787841796875,\n", + " 0.02606201171875,\n", + " -0.01549530029296875,\n", + " -0.02923583984375,\n", + " -0.004367828369140625,\n", + " -0.020172119140625,\n", + " -0.0494384765625,\n", + " 0.01407623291015625,\n", + " 0.0146636962890625,\n", + " 0.006526947021484375,\n", + " 0.006916046142578125,\n", + " 0.00458526611328125,\n", + " -0.0282745361328125,\n", + " -0.003810882568359375,\n", + " -0.0264434814453125,\n", + " 0.1046142578125,\n", + " 0.08697509765625,\n", + " 0.07684326171875,\n", + " 0.0419921875,\n", + " 0.0054931640625,\n", + " -0.0016603469848632812,\n", + " 0.02532958984375,\n", + " 0.0130157470703125,\n", + " 0.018768310546875,\n", + " 0.0223541259765625,\n", + " 0.007762908935546875,\n", + " 0.0078277587890625,\n", + " -0.0318603515625,\n", + " 0.0557861328125,\n", + " 0.025482177734375,\n", + " 0.0276641845703125,\n", + " 0.0253753662109375,\n", + " 0.046051025390625,\n", + " 0.03582763671875,\n", + " 0.01108551025390625,\n", + " -0.032501220703125,\n", + " 0.0092010498046875,\n", + " 0.02838134765625,\n", + " -0.01226043701171875,\n", + " 0.0168914794921875,\n", + " -0.0027446746826171875,\n", + " 0.014923095703125,\n", + " -0.047332763671875,\n", + " 0.012939453125,\n", + " 0.0298919677734375,\n", + " -0.00014722347259521484,\n", + " -0.0091400146484375,\n", + " -0.004497528076171875,\n", + " -0.057769775390625,\n", + " -0.00437164306640625,\n", + " 0.05755615234375,\n", + " -0.061798095703125,\n", + " 0.0255584716796875,\n", + " 0.035369873046875,\n", + " 0.00023627281188964844,\n", + " 0.0300445556640625,\n", + " -0.018463134765625,\n", + " -0.05291748046875,\n", + " 0.035369873046875,\n", + " -0.01873779296875,\n", + " -0.06341552734375,\n", + " 0.0131072998046875,\n", + " 0.005413055419921875,\n", + " -0.038604736328125,\n", + " -0.0244140625,\n", + " -0.0018014907836914062,\n", + " 0.039520263671875,\n", + " 0.024078369140625,\n", + " 0.006099700927734375,\n", + " 0.048919677734375,\n", + " -0.033935546875,\n", + " -0.0079345703125,\n", + " 0.0036296844482421875,\n", + " 0.0098876953125,\n", + " 0.0160369873046875,\n", + " -0.0484619140625,\n", + " 0.02178955078125,\n", + " -0.0618896484375,\n", + " -0.0465087890625,\n", + " -0.01361083984375,\n", + " -0.0021228790283203125,\n", + " 0.01849365234375,\n", + " -0.061431884765625,\n", + " -0.012298583984375,\n", + " 0.018524169921875,\n", + " -0.018524169921875,\n", + " 0.00844573974609375,\n", + " -0.0200958251953125,\n", + " -0.0222015380859375,\n", + " -0.072509765625,\n", + " -0.0411376953125,\n", + " -0.00012600421905517578,\n", + " 0.0271148681640625,\n", + " 0.046234130859375,\n", + " 0.006591796875,\n", + " -0.0833740234375,\n", + " 0.031463623046875,\n", + " -0.055755615234375,\n", + " -0.0128326416015625,\n", + " -0.00267791748046875,\n", + " 0.007904052734375,\n", + " -0.0662841796875,\n", + " 0.057708740234375,\n", + " 0.019134521484375,\n", + " -0.004459381103515625,\n", + " -0.003093719482421875,\n", + " 0.0247802734375,\n", + " 0.0033512115478515625,\n", + " 0.01654052734375,\n", + " -0.028076171875,\n", + " 0.041046142578125,\n", + " 0.0159759521484375,\n", + " -0.0902099609375,\n", + " -0.04376220703125,\n", + " 0.00431060791015625,\n", + " 0.0232391357421875,\n", + " 0.06298828125,\n", + " -0.017791748046875,\n", + " -0.0433349609375,\n", + " -0.03338623046875,\n", + " -0.0297393798828125,\n", + " -0.004673004150390625,\n", + " -0.040496826171875,\n", + " -0.0158538818359375,\n", + " -0.034637451171875,\n", + " -0.031402587890625,\n", + " 0.01456451416015625,\n", + " -0.0100555419921875,\n", + " 0.00965118408203125,\n", + " 0.0007476806640625,\n", + " 0.042449951171875,\n", + " 0.01300048828125,\n", + " -0.005397796630859375,\n", + " -0.03216552734375,\n", + " 0.0044403076171875,\n", + " -0.041168212890625,\n", + " -0.0245513916015625,\n", + " -0.031524658203125,\n", + " 0.0247039794921875,\n", + " -0.053436279296875,\n", + " 0.024169921875,\n", + " 0.003513336181640625,\n", + " -0.036041259765625,\n", + " 0.00797271728515625,\n", + " -0.0291595458984375,\n", + " 0.008880615234375,\n", + " -0.04254150390625,\n", + " 0.0018520355224609375,\n", + " -0.005695343017578125,\n", + " -0.047088623046875,\n", + " 0.030792236328125,\n", + " 0.014739990234375,\n", + " 0.00440216064453125,\n", + " -0.005950927734375,\n", + " 0.023895263671875,\n", + " -0.055450439453125,\n", + " 0.022857666015625,\n", + " -0.0103607177734375,\n", + " -0.034393310546875,\n", + " 0.0171051025390625,\n", + " -0.028350830078125,\n", + " 0.0191802978515625,\n", + " -0.006282806396484375,\n", + " 0.058013916015625,\n", + " -0.0283966064453125,\n", + " -0.01318359375,\n", + " -0.0328369140625,\n", + " 0.05267333984375,\n", + " -0.0308990478515625,\n", + " -0.0057525634765625,\n", + " 0.00325775146484375,\n", + " 0.004566192626953125,\n", + " -0.0736083984375,\n", + " 0.010040283203125,\n", + " 0.0194854736328125,\n", + " -0.0057220458984375,\n", + " -0.01258087158203125,\n", + " -0.04376220703125,\n", + " -0.01371002197265625,\n", + " 0.007785797119140625,\n", + " -0.0262603759765625,\n", + " 0.0176849365234375,\n", + " -0.0017185211181640625,\n", + " -0.0128021240234375,\n", + " -0.00899505615234375,\n", + " 0.0006489753723144531,\n", + " 0.002262115478515625,\n", + " 0.005229949951171875,\n", + " -0.0011425018310546875,\n", + " 0.0212249755859375,\n", + " 0.04217529296875,\n", + " -0.02606201171875,\n", + " -0.00763702392578125,\n", + " 0.03240966796875,\n", + " -0.033111572265625,\n", + " -0.0220947265625,\n", + " -0.0175628662109375,\n", + " 0.0009794235229492188,\n", + " -0.01265716552734375,\n", + " -0.0301361083984375,\n", + " 0.03509521484375,\n", + " 0.007724761962890625,\n", + " 0.0083770751953125,\n", + " -0.0167388916015625,\n", + " -0.0017766952514648438,\n", + " 0.004486083984375,\n", + " 0.011199951171875,\n", + " 0.0291595458984375,\n", + " -0.025421142578125,\n", + " -0.040618896484375,\n", + " -0.00024700164794921875,\n", + " 0.008544921875,\n", + " 0.06744384765625,\n", + " 0.031524658203125,\n", + " -0.00023317337036132812,\n", + " -0.0117950439453125,\n", + " 0.006153106689453125,\n", + " 0.03009033203125,\n", + " -0.01513671875,\n", + " -0.0007104873657226562,\n", + " -0.06597900390625,\n", + " 0.046722412109375,\n", + " -0.004730224609375,\n", + " 0.04779052734375,\n", + " 0.02947998046875,\n", + " 0.058013916015625,\n", + " -0.0098419189453125,\n", + " -0.0170135498046875,\n", + " 0.023223876953125,\n", + " 0.08184814453125,\n", + " 0.0178985595703125,\n", + " -0.012786865234375,\n", + " -0.0445556640625,\n", + " -0.0161590576171875,\n", + " 0.01552581787109375,\n", + " -0.053009033203125,\n", + " -0.031768798828125,\n", + " 0.04925537109375,\n", + " 0.007106781005859375,\n", + " -0.067138671875,\n", + " -0.0010423660278320312,\n", + " -0.0208740234375,\n", + " -0.019439697265625,\n", + " -0.003414154052734375,\n", + " 0.035369873046875,\n", + " 0.0204620361328125,\n", + " 0.0458984375,\n", + " -0.006603240966796875,\n", + " -0.026763916015625,\n", + " 0.01291656494140625,\n", + " -0.019683837890625,\n", + " -0.0280303955078125,\n", + " 0.01270294189453125,\n", + " -0.00634002685546875,\n", + " -0.02978515625,\n", + " -0.00811004638671875,\n", + " -0.01092529296875,\n", + " 0.03143310546875,\n", + " 0.0007624626159667969,\n", + " 0.049041748046875,\n", + " 0.01274871826171875,\n", + " 0.0295562744140625,\n", + " 0.03790283203125,\n", + " 0.054443359375,\n", + " -0.02142333984375,\n", + " -0.0457763671875,\n", + " -0.026031494140625,\n", + " 0.046966552734375,\n", + " -0.00402069091796875,\n", + " 0.048492431640625,\n", + " 0.0095367431640625,\n", + " 0.02056884765625,\n", + " 0.0250244140625,\n", + " -0.019073486328125,\n", + " -0.01326751708984375,\n", + " 0.0350341796875,\n", + " -0.0160064697265625,\n", + " -0.02496337890625,\n", + " -0.04132080078125,\n", + " 0.01763916015625,\n", + " -0.045379638671875,\n", + " 0.044342041015625,\n", + " 0.04083251953125,\n", + " 0.006076812744140625,\n", + " -0.0218353271484375,\n", + " 0.060577392578125,\n", + " -0.04296875,\n", + " -0.0513916015625,\n", + " 0.0084075927734375,\n", + " -0.01556396484375,\n", + " -0.0226898193359375,\n", + " -0.044189453125,\n", + " -0.0595703125,\n", + " 0.026458740234375,\n", + " 0.003025054931640625,\n", + " -0.06378173828125,\n", + " -0.041290283203125,\n", + " 0.0237579345703125,\n", + " -0.0023975372314453125,\n", + " 0.00211334228515625,\n", + " -0.00015115737915039062,\n", + " -0.0247802734375,\n", + " -0.004795074462890625,\n", + " -0.0220184326171875,\n", + " -0.06439208984375,\n", + " -0.02630615234375,\n", + " -0.039306640625,\n", + " -0.0080108642578125,\n", + " -0.029632568359375,\n", + " 0.0162811279296875,\n", + " -0.0186004638671875,\n", + " 0.0272216796875,\n", + " 0.0157318115234375,\n", + " -0.033966064453125,\n", + " 0.0010089874267578125,\n", + " -0.030242919921875,\n", + " 0.0231170654296875,\n", + " -0.0038623809814453125,\n", + " -0.0204925537109375,\n", + " 0.051239013671875,\n", + " 0.06329345703125,\n", + " -0.0116729736328125,\n", + " -0.0194091796875,\n", + " -0.0158843994140625,\n", + " -0.0679931640625,\n", + " -0.0086212158203125,\n", + " 0.0123138427734375,\n", + " 0.0226593017578125,\n", + " -0.0130767822265625,\n", + " 0.00115966796875,\n", + " 0.08587646484375,\n", + " -0.0295562744140625,\n", + " 0.02587890625,\n", + " 0.005741119384765625,\n", + " -0.020965576171875,\n", + " -0.0204925537109375,\n", + " 0.0081787109375,\n", + " 0.0175933837890625,\n", + " -0.00223541259765625,\n", + " 0.053985595703125,\n", + " 0.01320648193359375,\n", + " 0.0005278587341308594,\n", + " 0.01934814453125,\n", + " -0.0286865234375,\n", + " 0.051666259765625,\n", + " 0.011016845703125,\n", + " 0.00782012939453125,\n", + " -0.05291748046875,\n", + " -0.00917816162109375,\n", + " 0.033355712890625,\n", + " -0.01148223876953125,\n", + " -0.043304443359375,\n", + " -0.0465087890625,\n", + " -0.01393890380859375,\n", + " 0.040924072265625,\n", + " 0.0006461143493652344,\n", + " 0.0227508544921875,\n", + " 0.0157012939453125,\n", + " 0.0002834796905517578,\n", + " 0.003940582275390625,\n", + " -0.0288238525390625,\n", + " 0.0272979736328125,\n", + " 0.0171356201171875,\n", + " -0.0088958740234375,\n", + " -0.037872314453125,\n", + " -0.01032257080078125,\n", + " 0.0020999908447265625,\n", + " -0.0289764404296875,\n", + " -0.0192108154296875,\n", + " -0.032379150390625,\n", + " 0.041168212890625,\n", + " 0.0219573974609375,\n", + " -0.047332763671875,\n", + " 0.0184173583984375,\n", + " -0.02276611328125,\n", + " 0.02508544921875,\n", + " 0.005527496337890625,\n", + " 0.029541015625,\n", + " -0.01291656494140625,\n", + " 0.0093536376953125,\n", + " -0.02545166015625,\n", + " 0.04998779296875,\n", + " 0.028533935546875,\n", + " 1.5735626220703125e-05,\n", + " -0.006298065185546875,\n", + " 0.0011272430419921875,\n", + " -0.0172576904296875,\n", + " -0.033172607421875,\n", + " 0.0338134765625,\n", + " 0.039337158203125,\n", + " 0.0079498291015625,\n", + " -0.0567626953125,\n", + " -0.03759765625,\n", + " -0.057708740234375,\n", + " 0.010040283203125,\n", + " -0.0033855438232421875,\n", + " 0.036285400390625,\n", + " -0.0034656524658203125,\n", + " -0.0189971923828125,\n", + " -0.06585693359375,\n", + " 0.051513671875,\n", + " -0.01027679443359375,\n", + " 0.0269622802734375,\n", + " -0.031646728515625,\n", + " -0.0156707763671875,\n", + " -0.044952392578125,\n", + " -0.009674072265625,\n", + " -0.037689208984375,\n", + " 0.0204315185546875,\n", + " -0.013153076171875,\n", + " 0.025421142578125,\n", + " -0.0173187255859375,\n", + " -0.02947998046875,\n", + " -0.002391815185546875,\n", + " -0.01141357421875,\n", + " 0.01364898681640625,\n", + " -0.0020160675048828125,\n", + " 0.0111083984375,\n", + " -0.02630615234375,\n", + " 0.0599365234375,\n", + " -0.002490997314453125,\n", + " -0.006988525390625,\n", + " 0.017242431640625,\n", + " 0.00949859619140625,\n", + " 0.00360107421875,\n", + " -0.024566650390625,\n", + " -0.02386474609375,\n", + " 0.0008535385131835938,\n", + " 0.0440673828125,\n", + " 0.059326171875,\n", + " -0.0174713134765625,\n", + " 0.02325439453125,\n", + " 0.030364990234375,\n", + " 0.0013360977172851562,\n", + " 0.003276824951171875,\n", + " -0.040679931640625,\n", + " 0.0050811767578125,\n", + " 0.0113677978515625,\n", + " -0.0019435882568359375,\n", + " -0.038970947265625,\n", + " -0.015625,\n", + " -0.1220703125,\n", + " -0.0167999267578125,\n", + " -0.044403076171875,\n", + " -0.008087158203125,\n", + " 0.0021209716796875,\n", + " 0.01355743408203125,\n", + " 0.011016845703125,\n", + " -0.0013494491577148438,\n", + " 0.03692626953125,\n", + " 0.0316162109375,\n", + " -0.0245208740234375,\n", + " -0.0086669921875,\n", + " 0.0126953125,\n", + " -0.047607421875,\n", + " 0.0343017578125,\n", + " -0.0032291412353515625,\n", + " -0.03900146484375,\n", + " 0.07135009765625,\n", + " -0.003345489501953125,\n", + " -0.0205230712890625,\n", + " -0.024810791015625,\n", + " 0.06280517578125,\n", + " 0.00487518310546875,\n", + " -0.0026988983154296875,\n", + " -0.035491943359375,\n", + " -0.028076171875,\n", + " -0.0014324188232421875,\n", + " 0.00742340087890625,\n", + " -0.0036163330078125,\n", + " -0.0010461807250976562,\n", + " 0.0399169921875,\n", + " -0.04376220703125,\n", + " -0.049835205078125,\n", + " 0.0411376953125,\n", + " -0.004642486572265625,\n", + " -0.0299835205078125,\n", + " -0.0012035369873046875,\n", + " -0.01702880859375,\n", + " 0.004367828369140625,\n", + " 0.001789093017578125,\n", + " 0.050262451171875,\n", + " 0.047454833984375,\n", + " 0.025634765625,\n", + " -0.0186767578125,\n", + " 0.004329681396484375,\n", + " 0.0288543701171875,\n", + " -0.01214599609375,\n", + " 0.050018310546875,\n", + " 0.052154541015625,\n", + " 0.0131072998046875,\n", + " 0.03326416015625,\n", + " -0.0121917724609375,\n", + " -0.01551055908203125,\n", + " -0.0513916015625,\n", + " 0.0400390625,\n", + " -0.0141143798828125,\n", + " -0.08465576171875,\n", + " -0.040496826171875,\n", + " 0.079833984375,\n", + " 0.03912353515625,\n", + " 0.018341064453125,\n", + " 0.01049041748046875,\n", + " 0.0297698974609375,\n", + " 0.052459716796875,\n", + " 0.005542755126953125,\n", + " -0.030242919921875,\n", + " -0.0433349609375,\n", + " -0.0167388916015625,\n", + " 0.035797119140625,\n", + " -0.0021038055419921875,\n", + " -0.0379638671875,\n", + " 0.0301971435546875,\n", + " 0.09130859375,\n", + " -0.045074462890625,\n", + " -0.034912109375,\n", + " 0.0113677978515625,\n", + " 0.038360595703125,\n", + " 0.0447998046875,\n", + " 0.048431396484375,\n", + " -0.023590087890625,\n", + " -0.058929443359375,\n", + " 0.0196075439453125,\n", + " 0.039276123046875,\n", + " 0.020843505859375,\n", + " -0.0268402099609375,\n", + " -0.0286102294921875,\n", + " -0.055084228515625,\n", + " 0.02752685546875,\n", + " -0.0426025390625,\n", + " -0.0233917236328125,\n", + " -0.005435943603515625,\n", + " 0.07830810546875,\n", + " 0.007007598876953125,\n", + " -0.08465576171875,\n", + " -0.016693115234375,\n", + " 0.03265380859375,\n", + " 0.025604248046875,\n", + " -0.021148681640625,\n", + " -0.0108489990234375,\n", + " 0.02789306640625,\n", + " -0.0146026611328125,\n", + " -0.0025272369384765625,\n", + " -6.93202018737793e-05,\n", + " -0.0035877227783203125,\n", + " 0.058258056640625,\n", + " -0.004970550537109375,\n", + " -0.053619384765625,\n", + " 0.00989532470703125,\n", + " 0.01007080078125,\n", + " -0.01363372802734375,\n", + " 0.0067596435546875,\n", + " -0.050506591796875,\n", + " -0.0024318695068359375,\n", + " -0.0256500244140625,\n", + " -0.0005860328674316406,\n", + " 0.0266571044921875,\n", + " 0.006595611572265625,\n", + " 0.0311737060546875,\n", + " -0.05389404296875,\n", + " -0.0168304443359375,\n", + " -0.015350341796875,\n", + " 0.0274658203125,\n", + " 0.022796630859375,\n", + " 0.0078887939453125,\n", + " -0.009674072265625,\n", + " -0.0261077880859375,\n", + " 0.06256103515625,\n", + " -0.016815185546875,\n", + " -0.03863525390625,\n", + " -0.01320648193359375,\n", + " -0.0384521484375,\n", + " 0.0197906494140625,\n", + " -0.02734375,\n", + " -0.0085906982421875,\n", + " -0.0162353515625,\n", + " 0.017333984375,\n", + " 0.0211639404296875,\n", + " 0.00862884521484375,\n", + " 0.053619384765625,\n", + " 0.007144927978515625,\n", + " -0.0205841064453125,\n", + " -0.001682281494140625,\n", + " -0.003360748291015625,\n", + " -0.032440185546875,\n", + " 0.0178985595703125,\n", + " -0.002193450927734375,\n", + " -0.01265716552734375,\n", + " 0.034515380859375,\n", + " -0.093505859375,\n", + " 0.06134033203125,\n", + " 0.0161590576171875,\n", + " 0.0596923828125,\n", + " 0.041107177734375,\n", + " 0.035888671875,\n", + " 0.03533935546875,\n", + " 5.984306335449219e-05,\n", + " -0.0002205371856689453,\n", + " 0.0179290771484375,\n", + " 0.042694091796875,\n", + " 0.039276123046875,\n", + " 0.00992584228515625,\n", + " 0.006435394287109375,\n", + " -0.0369873046875,\n", + " 0.0162506103515625,\n", + " -0.012176513671875,\n", + " -0.0496826171875,\n", + " 0.023651123046875,\n", + " 0.035308837890625,\n", + " 0.0053253173828125,\n", + " 0.007244110107421875,\n", + " -0.0158843994140625,\n", + " -0.0276947021484375,\n", + " -0.03594970703125,\n", + " 0.03509521484375,\n", + " 0.006572723388671875,\n", + " -0.0243377685546875,\n", + " 0.02606201171875,\n", + " -0.033050537109375,\n", + " 0.0186920166015625,\n", + " 0.01274871826171875,\n", + " 0.053680419921875,\n", + " -0.040130615234375,\n", + " 0.0355224609375,\n", + " -0.043060302734375,\n", + " 0.005634307861328125,\n", + " ...]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "embedding_response.data[0]['embedding']" + ] + }, + { + "cell_type": "code", + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "import chromadb\n", - "from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction\n", - "\n", - "embedding_function = OpenAIEmbeddingFunction(api_key=os.environ.get('OPENAI_API_KEY'),\n", - " model_name=\"text-embedding-ada-002\")\n", - "\n", "\n", "chroma_client = chromadb.Client()\n", - "vector_store = chroma_client.get_or_create_collection(name=\"Universities\",\n", - " embedding_function=embedding_function)" + "vector_store = chroma_client.get_or_create_collection(name=\"Universities\")" ] }, { @@ -111,13 +1129,15 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": { "collapsed": false }, "outputs": [], "source": [ - "vector_store.add(\"uni_info\", documents=university_info)" + "vector_store.add(\"uni_info\",\n", + " documents=university_info,\n", + " embeddings=embedding_response.data[0]['embedding'])" ] }, { @@ -131,9 +1151,26 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/jreini/opt/anaconda3/envs/trulens_dev_empty/lib/python3.11/site-packages/_distutils_hack/__init__.py:33: UserWarning: Setuptools is replacing distutils.\n", + " warnings.warn(\"Setuptools is replacing distutils.\")\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "🦑 Tru initialized with db url sqlite:///default.sqlite .\n", + "🛑 Secret keys may be written to the database. See the `database_redact_keys` option of Tru` to prevent this.\n" + ] + } + ], "source": [ "from trulens_eval import Tru\n", "from trulens_eval.tru_custom_app import instrument\n", @@ -143,10 +1180,12 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ + "import litellm\n", + "\n", "class RAG_from_scratch:\n", " @instrument\n", " def retrieve(self, query: str) -> list:\n", @@ -154,7 +1193,9 @@ " Retrieve relevant text from vector store.\n", " \"\"\"\n", " results = vector_store.query(\n", - " query_texts=query,\n", + " query_embeddings=embedding(\n", + " model=\"mistral/mistral-embed\",\n", + " input=query).data[0]['embedding'],\n", " n_results=2\n", " )\n", " return results['documents'][0]\n", @@ -164,8 +1205,8 @@ " \"\"\"\n", " Generate answer from context.\n", " \"\"\"\n", - " completion = oai_client.chat.completions.create(\n", - " model=\"gpt-3.5-turbo\",\n", + " completion = litellm.completion(\n", + " model=\"mistral/mistral-small\",\n", " temperature=0,\n", " messages=\n", " [\n", @@ -201,9 +1242,31 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ In Groundedness, input source will be set to __record__.app.retrieve.rets.collect() .\n", + "✅ In Groundedness, input statement will be set to __record__.main_output or `Select.RecordOutput` .\n", + "✅ In Answer Relevance, input prompt will be set to __record__.app.retrieve.args.query .\n", + "✅ In Answer Relevance, input response will be set to __record__.main_output or `Select.RecordOutput` .\n", + "✅ In Context Relevance, input question will be set to __record__.app.retrieve.args.query .\n", + "✅ In Context Relevance, input context will be set to __record__.app.retrieve.rets.collect() .\n", + "✅ In coherence, input text will be set to __record__.main_output or `Select.RecordOutput` .\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[nltk_data] Downloading package punkt to /Users/jreini/nltk_data...\n", + "[nltk_data] Package punkt is already up-to-date!\n" + ] + } + ], "source": [ "from trulens_eval import Feedback, Select\n", "from trulens_eval.feedback import Groundedness\n", @@ -247,9 +1310,35 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "182762ed570e4d42a62b36241c9d71e2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Groundedness per statement in source: 0%| | 0/2 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "({'statement_0': 1.0, 'statement_1': 0.8},\n", + " {'reasons': '\\nSTATEMENT 0:\\n Statement Sentence: The University of Washington was founded in 1861.\\nSupporting Evidence: The University of Washington, founded in 1861 in Seattle, is a public research university.\\nScore: 10\\n\\n\\nSTATEMENT 1:\\n Statement Sentence: It is the flagship institution of the state of Washington.\\nSupporting Evidence: As the flagship institution of the six public universities in Washington state,\\nScore: 8\\n\\n'})" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "grounded.groundedness_measure_with_cot_reasons(\"\"\"e University of Washington, founded in 1861 in Seattle, is a public '\n", " 'research university\\n'\n", @@ -271,7 +1360,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -291,9 +1380,24 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b0124b1f9f5045b7a53449ff4160975f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Groundedness per statement in source: 0%| | 0/9 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "with tru_rag as recording:\n", " rag.query(\"Give me a long history of U Dub\")" @@ -301,9 +1405,76 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " | Answer Relevance | \n", + "Context Relevance | \n", + "Groundedness | \n", + "coherence | \n", + "latency | \n", + "total_cost | \n", + "
---|---|---|---|---|---|---|
app_id | \n", + "\n", + " | \n", + " | \n", + " | \n", + " | \n", + " | \n", + " |
RAG v1 | \n", + "0.8 | \n", + "0.8 | \n", + "0.866667 | \n", + "0.8 | \n", + "4.0 | \n", + "0.001942 | \n", + "