1515
1616from  __future__ import  annotations 
1717
18- from  typing  import  Any , List , Optional 
18+ from  typing  import  Any , List , Optional ,  Union ,  cast 
1919
20- from  langchain_core .language_models  import  BaseLanguageModel 
20+ from  langchain_core .language_models  import  BaseChatModel 
21+ from  langchain_core .language_models .llms  import  BaseLLM 
2122from  langchain_core .messages  import  AIMessage , HumanMessage , SystemMessage 
2223from  langchain_core .prompt_values  import  ChatPromptValue , StringPromptValue 
2324from  langchain_core .runnables  import  Runnable 
2728
2829from  nemoguardrails  import  LLMRails , RailsConfig 
2930from  nemoguardrails .integrations .langchain .utils  import  async_wrap 
30- from  nemoguardrails .rails .llm .options  import  GenerationOptions 
31+ from  nemoguardrails .rails .llm .options  import  GenerationOptions ,  GenerationResponse 
3132
3233
3334class  RunnableRails (Runnable [Input , Output ]):
3435    def  __init__ (
3536        self ,
3637        config : RailsConfig ,
37-         llm : Optional [BaseLanguageModel ] =  None ,
38+         llm : Optional [Union [ BaseLLM ,  BaseChatModel ] ] =  None ,
3839        tools : Optional [List [Tool ]] =  None ,
3940        passthrough : bool  =  True ,
4041        runnable : Optional [Runnable ] =  None ,
@@ -67,12 +68,14 @@ def __init__(
6768        if  self .passthrough_runnable :
6869            self ._init_passthrough_fn ()
6970
70-     def  _init_passthrough_fn (self ):
71+     def  _init_passthrough_fn (self )  ->   None :
7172        """Initialize the passthrough function for the LLM rails instance.""" 
7273
7374        async  def  passthrough_fn (context : dict , events : List [dict ]):
7475            # First, we fetch the input from the context 
7576            _input  =  context .get ("passthrough_input" )
77+             if  self .passthrough_runnable  is  None :
78+                 raise  ValueError ("No passthrough runnable provided" )
7679            async_wrapped_invoke  =  async_wrap (self .passthrough_runnable .invoke )
7780            _output  =  await  async_wrapped_invoke (_input , self .config , ** self .kwargs )
7881
@@ -84,10 +87,11 @@ async def passthrough_fn(context: dict, events: List[dict]):
8487
8588            return  text , _output 
8689
87-         self .rails .llm_generation_actions .passthrough_fn  =  passthrough_fn 
90+         # Dynamically assign passthrough_fn to avoid type checker issues 
91+         setattr (self .rails .llm_generation_actions , "passthrough_fn" , passthrough_fn )
8892
89-     def  __or__ (self , other ): 
90-         if  isinstance (other , BaseLanguageModel ):
93+     def  __or__ (self , other )  ->   "RunnableRails[Input, Output]" :   # type: ignore[override] 
94+         if  isinstance (other , ( BaseLLM ,  BaseChatModel ) ):
9195            self .llm  =  other 
9296            self .rails .update_llm (other )
9397
@@ -193,6 +197,9 @@ def invoke(
193197        res  =  self .rails .generate (
194198            messages = input_messages , options = GenerationOptions (output_vars = True )
195199        )
200+         # When using output_vars=True, rails.generate returns a GenerationResponse 
201+         if  not  isinstance (res , GenerationResponse ):
202+             raise  Exception (f"Expected GenerationResponse, got { type (res )}  )
196203        context  =  res .output_data 
197204        result  =  res .response 
198205
@@ -203,17 +210,16 @@ def invoke(
203210            result  =  result [0 ]
204211
205212        if  self .passthrough  and  self .passthrough_runnable :
206-             passthrough_output  =  context .get ("passthrough_output" )
213+             passthrough_output  =  context .get ("passthrough_output" )  if   context   else   None 
207214
208215            # If a rail was triggered (input or dialog), the passthrough_output 
209216            # will not be set. In this case, we only set the output key to the 
210217            # message that was received from the guardrail configuration. 
211218            if  passthrough_output  is  None :
212-                 passthrough_output  =  {
213-                     self .passthrough_bot_output_key : result ["content" ]
214-                 }
219+                 content  =  result .get ("content" ) if  isinstance (result , dict ) else  result 
220+                 passthrough_output  =  {self .passthrough_bot_output_key : content }
215221
216-             bot_message  =  context .get ("bot_message" )
222+             bot_message  =  context .get ("bot_message" )  if   context   else   None 
217223
218224            # We make sure that, if the output rails altered the bot message, we 
219225            # replace it in the passthrough_output 
@@ -222,20 +228,28 @@ def invoke(
222228            elif  isinstance (passthrough_output , dict ):
223229                passthrough_output [self .passthrough_bot_output_key ] =  bot_message 
224230
225-             return  passthrough_output 
231+             return  cast ( Output ,  passthrough_output ) 
226232        else :
227233            if  isinstance (input , ChatPromptValue ):
228-                 return  AIMessage (content = result ["content" ])
234+                 content  =  result .get ("content" ) if  isinstance (result , dict ) else  result 
235+                 # Ensure content is a string for AIMessage 
236+                 content_str  =  str (content ) if  content  is  not None  else  "" 
237+                 return  cast (Output , AIMessage (content = content_str ))
229238            elif  isinstance (input , StringPromptValue ):
230239                if  isinstance (result , dict ):
231-                     return  result [ "content" ] 
240+                     return  cast ( Output ,  result . get ( "content" ,  "" )) 
232241                else :
233-                     return  result 
242+                     return  cast ( Output ,  result ) 
234243            elif  isinstance (input , dict ):
235244                user_input  =  input ["input" ]
236245                if  isinstance (user_input , str ):
237-                     return  {"output" : result ["content" ]}
246+                     content  =  (
247+                         result .get ("content" ) if  isinstance (result , dict ) else  result 
248+                     )
249+                     return  cast (Output , {"output" : content })
238250                elif  isinstance (user_input , list ):
239-                     return  {"output" : result }
251+                     return  cast (Output , {"output" : result })
252+                 else :
253+                     raise  ValueError (f"Unexpected user_input type: { type (user_input )}  )
240254            else :
241255                raise  ValueError (f"Unexpected input type: { type (input )}  )
0 commit comments