Skip to content

Commit aa69ee9

Browse files
committed
Cleaned integrations directory
1 parent 71d00f0 commit aa69ee9

File tree

1 file changed

+33
-19
lines changed

1 file changed

+33
-19
lines changed

nemoguardrails/integrations/langchain/runnable_rails.py

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@
1515

1616
from __future__ import annotations
1717

18-
from typing import Any, List, Optional
18+
from typing import Any, List, Optional, Union, cast
1919

20-
from langchain_core.language_models import BaseLanguageModel
20+
from langchain_core.language_models import BaseChatModel
21+
from langchain_core.language_models.llms import BaseLLM
2122
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
2223
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
2324
from langchain_core.runnables import Runnable
@@ -27,14 +28,14 @@
2728

2829
from nemoguardrails import LLMRails, RailsConfig
2930
from nemoguardrails.integrations.langchain.utils import async_wrap
30-
from nemoguardrails.rails.llm.options import GenerationOptions
31+
from nemoguardrails.rails.llm.options import GenerationOptions, GenerationResponse
3132

3233

3334
class RunnableRails(Runnable[Input, Output]):
3435
def __init__(
3536
self,
3637
config: RailsConfig,
37-
llm: Optional[BaseLanguageModel] = None,
38+
llm: Optional[Union[BaseLLM, BaseChatModel]] = None,
3839
tools: Optional[List[Tool]] = None,
3940
passthrough: bool = True,
4041
runnable: Optional[Runnable] = None,
@@ -67,12 +68,14 @@ def __init__(
6768
if self.passthrough_runnable:
6869
self._init_passthrough_fn()
6970

70-
def _init_passthrough_fn(self):
71+
def _init_passthrough_fn(self) -> None:
7172
"""Initialize the passthrough function for the LLM rails instance."""
7273

7374
async def passthrough_fn(context: dict, events: List[dict]):
7475
# First, we fetch the input from the context
7576
_input = context.get("passthrough_input")
77+
if self.passthrough_runnable is None:
78+
raise ValueError("No passthrough runnable provided")
7679
async_wrapped_invoke = async_wrap(self.passthrough_runnable.invoke)
7780
_output = await async_wrapped_invoke(_input, self.config, **self.kwargs)
7881

@@ -84,10 +87,11 @@ async def passthrough_fn(context: dict, events: List[dict]):
8487

8588
return text, _output
8689

87-
self.rails.llm_generation_actions.passthrough_fn = passthrough_fn
90+
# Dynamically assign passthrough_fn to avoid type checker issues
91+
setattr(self.rails.llm_generation_actions, "passthrough_fn", passthrough_fn)
8892

89-
def __or__(self, other):
90-
if isinstance(other, BaseLanguageModel):
93+
def __or__(self, other) -> "RunnableRails[Input, Output]": # type: ignore[override]
94+
if isinstance(other, (BaseLLM, BaseChatModel)):
9195
self.llm = other
9296
self.rails.update_llm(other)
9397

@@ -193,6 +197,9 @@ def invoke(
193197
res = self.rails.generate(
194198
messages=input_messages, options=GenerationOptions(output_vars=True)
195199
)
200+
# When using output_vars=True, rails.generate returns a GenerationResponse
201+
if not isinstance(res, GenerationResponse):
202+
raise Exception(f"Expected GenerationResponse, got {type(res)}")
196203
context = res.output_data
197204
result = res.response
198205

@@ -203,17 +210,16 @@ def invoke(
203210
result = result[0]
204211

205212
if self.passthrough and self.passthrough_runnable:
206-
passthrough_output = context.get("passthrough_output")
213+
passthrough_output = context.get("passthrough_output") if context else None
207214

208215
# If a rail was triggered (input or dialog), the passthrough_output
209216
# will not be set. In this case, we only set the output key to the
210217
# message that was received from the guardrail configuration.
211218
if passthrough_output is None:
212-
passthrough_output = {
213-
self.passthrough_bot_output_key: result["content"]
214-
}
219+
content = result.get("content") if isinstance(result, dict) else result
220+
passthrough_output = {self.passthrough_bot_output_key: content}
215221

216-
bot_message = context.get("bot_message")
222+
bot_message = context.get("bot_message") if context else None
217223

218224
# We make sure that, if the output rails altered the bot message, we
219225
# replace it in the passthrough_output
@@ -222,20 +228,28 @@ def invoke(
222228
elif isinstance(passthrough_output, dict):
223229
passthrough_output[self.passthrough_bot_output_key] = bot_message
224230

225-
return passthrough_output
231+
return cast(Output, passthrough_output)
226232
else:
227233
if isinstance(input, ChatPromptValue):
228-
return AIMessage(content=result["content"])
234+
content = result.get("content") if isinstance(result, dict) else result
235+
# Ensure content is a string for AIMessage
236+
content_str = str(content) if content is not None else ""
237+
return cast(Output, AIMessage(content=content_str))
229238
elif isinstance(input, StringPromptValue):
230239
if isinstance(result, dict):
231-
return result["content"]
240+
return cast(Output, result.get("content", ""))
232241
else:
233-
return result
242+
return cast(Output, result)
234243
elif isinstance(input, dict):
235244
user_input = input["input"]
236245
if isinstance(user_input, str):
237-
return {"output": result["content"]}
246+
content = (
247+
result.get("content") if isinstance(result, dict) else result
248+
)
249+
return cast(Output, {"output": content})
238250
elif isinstance(user_input, list):
239-
return {"output": result}
251+
return cast(Output, {"output": result})
252+
else:
253+
raise ValueError(f"Unexpected user_input type: {type(user_input)}")
240254
else:
241255
raise ValueError(f"Unexpected input type: {type(input)}")

0 commit comments

Comments
 (0)