1515
1616from __future__ import annotations
1717
18+ from typing import Any , List , Optional , Union , cast
1819import logging
1920from typing import Any , AsyncIterator , Dict , Iterator , List , Optional , Union
2021
22+ from langchain_core .language_models import BaseChatModel
23+ from langchain_core .language_models .llms import BaseLLM
24+ from langchain_core .messages import AIMessage , HumanMessage , SystemMessage
2125from langchain_core .language_models import BaseLanguageModel
2226from langchain_core .prompt_values import ChatPromptValue , StringPromptValue
2327from langchain_core .runnables import Runnable , RunnableConfig
3337 message_to_dict ,
3438)
3539from nemoguardrails .integrations .langchain .utils import async_wrap
36- from nemoguardrails .rails .llm .options import GenerationOptions
40+ from nemoguardrails .rails .llm .options import GenerationOptions , GenerationResponse
3741
3842logger = logging .getLogger (__name__ )
3943
@@ -62,7 +66,7 @@ class RunnableRails(Runnable[Input, Output]):
6266 def __init__ (
6367 self ,
6468 config : RailsConfig ,
65- llm : Optional [BaseLanguageModel ] = None ,
69+ llm : Optional [Union [ BaseLLM , BaseChatModel ] ] = None ,
6670 tools : Optional [List [Tool ]] = None ,
6771 passthrough : bool = True ,
6872 runnable : Optional [Runnable ] = None ,
@@ -110,7 +114,7 @@ def __init__(
110114 if self .passthrough_runnable :
111115 self ._init_passthrough_fn ()
112116
113- def _init_passthrough_fn (self ):
117+ def _init_passthrough_fn (self ) -> None :
114118 """Initialize the passthrough function for the LLM rails instance."""
115119
116120 async def passthrough_fn (context : dict , events : List [dict ]):
@@ -134,7 +138,8 @@ async def passthrough_fn(context: dict, events: List[dict]):
134138
135139 return text , _output
136140
137- self .rails .llm_generation_actions .passthrough_fn = passthrough_fn
141+ # Dynamically assign passthrough_fn to avoid type checker issues
142+ setattr (self .rails .llm_generation_actions , "passthrough_fn" , passthrough_fn )
138143
139144 def __or__ (
140145 self , other : Union [BaseLanguageModel , Runnable [Any , Any ]]
@@ -687,6 +692,9 @@ def _full_rails_invoke(
687692 res = self .rails .generate (
688693 messages = input_messages , options = GenerationOptions (output_vars = True )
689694 )
695+ # When using output_vars=True, rails.generate returns a GenerationResponse
696+ if not isinstance (res , GenerationResponse ):
697+ raise Exception (f"Expected GenerationResponse, got { type (res )} " )
690698 context = res .output_data
691699 result = res .response
692700
0 commit comments