Skip to content

Commit

Permalink
langchain provider fix (#759)
Browse files Browse the repository at this point in the history
* first

* temp format

* minor

* more notes

* minor fixes

* more fixes
  • Loading branch information
piotrm0 authored Jan 4, 2024
1 parent 3d58eee commit 4989a14
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -954,7 +954,6 @@ class DummyEndpoint(Endpoint):
overloaded_prob: float

def __new__(cls, *args, **kwargs):

return super(Endpoint, cls).__new__(cls, name="dummyendpoint")

def __init__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM
import pydantic

from trulens_eval.feedback.provider.endpoint.base import Endpoint
from trulens_eval.feedback.provider.endpoint.base import EndpointCallback
Expand All @@ -28,7 +29,10 @@ class LangchainEndpoint(Endpoint):
Langchain endpoint.
"""

chain: Union[BaseLLM, BaseChatModel]
# Cannot validate BaseLLM / BaseChatModel as they are pydantic v1 and there
# is some bug involving their use within pydantic v2.
# https://github.com/langchain-ai/langchain/issues/10112
chain: Any # Union[BaseLLM, BaseChatModel]

def __new__(cls, *args, **kwargs):
return super(Endpoint, cls).__new__(cls, name="langchain")
Expand All @@ -51,7 +55,8 @@ def __init__(self, chain: Union[BaseLLM, BaseChatModel], *args, **kwargs):

if not (isinstance(chain, BaseLLM) or isinstance(chain, BaseChatModel)):
raise ValueError(
f"`chain` must be of type {BaseLLM.__name__} or {BaseChatModel.__name__}"
f"`chain` must be of type {BaseLLM.__name__} or {BaseChatModel.__name__}. "
f"If you are using DEFERRED mode, this may be due to our inability to serialize `chain`."
)

kwargs["chain"] = chain
Expand Down
6 changes: 3 additions & 3 deletions trulens_eval/trulens_eval/feedback/provider/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ class Langchain(LLMProvider):
def __init__(
self,
chain: Union[BaseLLM, BaseChatModel],
model_engine: str = "",
*args,
model_engine: str = "",
**kwargs
):
"""
Expand All @@ -38,10 +38,10 @@ def __init__(
Args:
chain (Union[BaseLLM, BaseChatModel]): Langchain LLMs or chat models
"""
self_kwargs = kwargs.copy()
self_kwargs = dict(kwargs)
self_kwargs["model_engine"] = model_engine or type(chain).__name__
self_kwargs["endpoint"] = LangchainEndpoint(
*args, chain=chain, **kwargs.copy()
*args, chain=chain, **kwargs
)

super().__init__(**self_kwargs)
Expand Down

0 comments on commit 4989a14

Please sign in to comment.