Commit 9fae68e

fix(llmrails): handle LLM models without model_kwargs field in isolation (#1336)

- Only copy model_kwargs if it exists and is not None.
- Prevents an AttributeError for models like ChatNVIDIA that don't have model_kwargs.
- Fix FakeLLM to share its counter across copied instances for test consistency.

1 parent e6ac009 commit 9fae68e

4 files changed: +234 / -14 lines
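
The core change: _create_action_llm_copy no longer force-assigns model_kwargs = {} when the field is absent, because pydantic models that forbid extra attributes (e.g. ChatNVIDIA) reject the assignment. Below is a minimal sketch of the guard the commit settles on, simplified from the diff that follows; the helper name create_isolated_llm_copy is a stand-in, and the real method's surrounding copy and error handling are omitted:

    import copy
    import logging

    log = logging.getLogger(__name__)

    def create_isolated_llm_copy(llm, action_name):
        """Simplified stand-in for LLMRails._create_action_llm_copy."""
        isolated_llm = copy.copy(llm)
        # Deep-copy model_kwargs only when the field exists and is set;
        # unconditionally assigning {} would fail on strict pydantic models.
        if (
            hasattr(isolated_llm, "model_kwargs")
            and isolated_llm.model_kwargs is not None
        ):
            isolated_llm.model_kwargs = isolated_llm.model_kwargs.copy()
        log.debug("Successfully created isolated LLM copy for action: %s", action_name)
        return isolated_llm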

nemoguardrails/rails/llm/llmrails.py

Lines changed: 0 additions & 2 deletions

@@ -596,8 +596,6 @@ def _create_action_llm_copy(
             and isolated_llm.model_kwargs is not None
         ):
             isolated_llm.model_kwargs = isolated_llm.model_kwargs.copy()
-        else:
-            isolated_llm.model_kwargs = {}

         log.debug(
             "Successfully created isolated LLM copy for action: %s", action_name

tests/test_llm_isolation.py

Lines changed: 1 addition & 2 deletions

@@ -225,8 +225,7 @@ def test_create_action_llm_copy_with_none_model_kwargs(self, rails_with_mock_llm

         isolated_llm = rails._create_action_llm_copy(original_llm, "test_action")

-        assert isolated_llm.model_kwargs == {}
-        assert isinstance(isolated_llm.model_kwargs, dict)
+        assert isolated_llm.model_kwargs is None

     def test_create_action_llm_copy_handles_copy_failure(self, rails_with_mock_llm):
         """Test that copy failures raise detailed error message."""

Lines changed: 192 additions & 0 deletions

@@ -0,0 +1,192 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for LLM isolation with models that don't have model_kwargs field."""
+
+from typing import Any, Dict, List, Optional
+from unittest.mock import Mock
+
+import pytest
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import BaseMessage
+from langchain_core.outputs import ChatGeneration, ChatResult
+from pydantic import BaseModel, Field
+
+from nemoguardrails.rails.llm.config import RailsConfig
+from nemoguardrails.rails.llm.llmrails import LLMRails
+
+
+class StrictPydanticLLM(BaseModel):
+    """Mock Pydantic LLM that doesn't allow arbitrary attributes (like ChatNVIDIA)."""
+
+    class Config:
+        extra = "forbid"
+
+    temperature: float = Field(default=0.7)
+    max_tokens: Optional[int] = Field(default=None)
+
+
+class MockChatNVIDIA(BaseChatModel):
+    """Mock ChatNVIDIA-like model that doesn't have model_kwargs."""
+
+    model: str = "nvidia-model"
+    temperature: float = 0.7
+
+    class Config:
+        extra = "forbid"
+
+    def _generate(
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> ChatResult:
+        """Mock generation method."""
+        return ChatResult(generations=[ChatGeneration(message=Mock())])
+
+    @property
+    def _llm_type(self) -> str:
+        """Return the type of language model."""
+        return "nvidia"
+
+
+class FlexibleLLMWithModelKwargs(BaseModel):
+    """Mock LLM that has model_kwargs and allows modifications."""
+
+    class Config:
+        extra = "allow"
+
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    temperature: float = 0.7
+
+
+class FlexibleLLMWithoutModelKwargs(BaseModel):
+    """Mock LLM that doesn't have model_kwargs but allows adding attributes."""
+
+    class Config:
+        extra = "allow"
+
+    temperature: float = 0.7
+    # no model_kwargs field
+
+
+@pytest.fixture
+def test_config():
+    """Create test configuration."""
+    return RailsConfig.from_content(
+        """
+        models:
+          - type: main
+            engine: openai
+            model: gpt-3.5-turbo
+        """
+    )
+
+
+class TestLLMIsolationModelKwargsFix:
+    """Test LLM isolation with different model types."""
+
+    def test_strict_pydantic_model_without_model_kwargs(self, test_config):
+        """Test isolation with strict Pydantic model that doesn't have model_kwargs."""
+        rails = LLMRails(config=test_config, verbose=False)
+
+        strict_llm = StrictPydanticLLM(temperature=0.5)
+
+        isolated_llm = rails._create_action_llm_copy(strict_llm, "test_action")
+
+        assert isolated_llm is not None
+        assert isolated_llm is not strict_llm
+        assert isolated_llm.temperature == 0.5
+        assert not hasattr(isolated_llm, "model_kwargs")
+
+    def test_mock_chat_nvidia_without_model_kwargs(self, test_config):
+        """Test with a ChatNVIDIA-like model that doesn't allow arbitrary attributes."""
+        rails = LLMRails(config=test_config, verbose=False)
+
+        nvidia_llm = MockChatNVIDIA()
+
+        isolated_llm = rails._create_action_llm_copy(nvidia_llm, "self_check_output")
+
+        assert isolated_llm is not None
+        assert isolated_llm is not nvidia_llm
+        assert isolated_llm.model == "nvidia-model"
+        assert isolated_llm.temperature == 0.7
+        assert not hasattr(isolated_llm, "model_kwargs")
+
+    def test_flexible_llm_with_model_kwargs(self, test_config):
+        """Test with LLM that has model_kwargs field."""
+        rails = LLMRails(config=test_config, verbose=False)
+
+        llm_with_kwargs = FlexibleLLMWithModelKwargs(
+            model_kwargs={"custom_param": "value"}, temperature=0.3
+        )
+
+        isolated_llm = rails._create_action_llm_copy(llm_with_kwargs, "test_action")
+
+        assert isolated_llm is not None
+        assert isolated_llm is not llm_with_kwargs
+        assert hasattr(isolated_llm, "model_kwargs")
+        assert isolated_llm.model_kwargs == {"custom_param": "value"}
+        assert isolated_llm.model_kwargs is not llm_with_kwargs.model_kwargs
+
+        isolated_llm.model_kwargs["new_param"] = "new_value"
+        assert "new_param" not in llm_with_kwargs.model_kwargs
+
+    def test_flexible_llm_without_model_kwargs_but_allows_adding(self, test_config):
+        """Test with LLM that doesn't have model_kwargs but allows adding attributes."""
+        rails = LLMRails(config=test_config, verbose=False)
+
+        flexible_llm = FlexibleLLMWithoutModelKwargs(temperature=0.8)
+
+        isolated_llm = rails._create_action_llm_copy(flexible_llm, "test_action")
+
+        assert isolated_llm is not None
+        assert isolated_llm is not flexible_llm
+        assert isolated_llm.temperature == 0.8
+        # since it allows extra attributes, model_kwargs might have been added
+        # but it shouldn't cause an error either way
+
+    def test_llm_with_none_model_kwargs(self, test_config):
+        """Test with LLM that has model_kwargs set to None."""
+        rails = LLMRails(config=test_config, verbose=False)
+
+        llm_with_none = FlexibleLLMWithModelKwargs(temperature=0.6)
+        llm_with_none.model_kwargs = None
+
+        isolated_llm = rails._create_action_llm_copy(llm_with_none, "test_action")
+
+        assert isolated_llm is not None
+        assert isolated_llm is not llm_with_none
+        if hasattr(isolated_llm, "model_kwargs"):
+            assert isolated_llm.model_kwargs in (None, {})
+
+    def test_copy_preserves_other_attributes(self, test_config):
+        """Test that copy preserves other attributes correctly."""
+        rails = LLMRails(config=test_config, verbose=False)
+
+        strict_llm = StrictPydanticLLM(temperature=0.2, max_tokens=100)
+        isolated_strict = rails._create_action_llm_copy(strict_llm, "action1")
+
+        assert isolated_strict.temperature == 0.2
+        assert isolated_strict.max_tokens == 100
+
+        flexible_llm = FlexibleLLMWithModelKwargs(
+            model_kwargs={"key": "value"}, temperature=0.9
+        )
+        isolated_flexible = rails._create_action_llm_copy(flexible_llm, "action2")
+
+        assert isolated_flexible.temperature == 0.9
+        assert isolated_flexible.model_kwargs == {"key": "value"}

tests/utils.py

Lines changed: 41 additions & 10 deletions

@@ -45,11 +45,41 @@ class FakeLLM(LLM):
     """Fake LLM wrapper for testing purposes."""

     responses: List
-    i: int = 0
     streaming: bool = False
     exception: Optional[Exception] = None
     token_usage: Optional[List[Dict[str, int]]] = None  # Token usage per response
     should_enable_stream_usage: bool = False
+    _shared_state: Optional[Dict] = None  # Shared state for isolated copies
+
+    def __init__(self, **kwargs):
+        """Initialize FakeLLM."""
+        # Extract initial counter value before parent init
+        initial_i = kwargs.pop("i", 0)
+        super().__init__(**kwargs)
+        # If no shared state, create one with initial counter
+        if self._shared_state is None:
+            self._shared_state = {"counter": initial_i}
+
+    def __copy__(self):
+        """Create a shallow copy that shares state with the original."""
+        new_instance = self.__class__.__new__(self.__class__)
+        new_instance.__dict__.update(self.__dict__)
+        # Share the same state dict so counter is synchronized
+        new_instance._shared_state = self._shared_state
+        return new_instance
+
+    @property
+    def i(self) -> int:
+        """Get current counter value from shared state."""
+        if self._shared_state:
+            return self._shared_state["counter"]
+        return 0
+
+    @i.setter
+    def i(self, value: int):
+        """Set counter value in shared state."""
+        if self._shared_state:
+            self._shared_state["counter"] = value

     @property
     def _llm_type(self) -> str:
@@ -67,14 +97,15 @@ def _call(
         if self.exception:
             raise self.exception

-        if self.i >= len(self.responses):
+        current_i = self.i
+        if current_i >= len(self.responses):
             raise RuntimeError(
-                f"No responses available for query number {self.i + 1} in FakeLLM. "
+                f"No responses available for query number {current_i + 1} in FakeLLM. "
                 "Most likely, too many LLM calls are made or additional responses need to be provided."
             )

-        response = self.responses[self.i]
-        self.i += 1
+        response = self.responses[current_i]
+        self.i = current_i + 1
         return response

     async def _acall(
@@ -88,15 +119,15 @@ async def _acall(
         if self.exception:
             raise self.exception

-        if self.i >= len(self.responses):
+        current_i = self.i
+        if current_i >= len(self.responses):
             raise RuntimeError(
-                f"No responses available for query number {self.i + 1} in FakeLLM. "
+                f"No responses available for query number {current_i + 1} in FakeLLM. "
                 "Most likely, too many LLM calls are made or additional responses need to be provided."
             )

-        response = self.responses[self.i]
-
-        self.i += 1
+        response = self.responses[current_i]
+        self.i = current_i + 1

         if self.streaming and run_manager:
             # To mock streaming, we just split in chunk by spaces
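
With __copy__ sharing _shared_state, a copied FakeLLM advances the same counter as the original, so per-action isolated copies consume responses in one global order. A quick sketch of the intended behavior (assuming the patched tests/utils.FakeLLM above and a recent langchain-core):

    import copy

    from tests.utils import FakeLLM

    llm = FakeLLM(responses=["first", "second"])
    clone = copy.copy(llm)  # roughly what per-action LLM isolation does

    assert clone.invoke("hi") == "first"  # clone consumes response 0...
    assert llm.i == 1                     # ...and the original sees the shared counter
    assert llm.invoke("hi") == "second"   # the original continues from index 1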
