|
15 | 15 |
|
16 | 16 | import unittest |
17 | 17 | from typing import Any, Dict |
| 18 | +from unittest.mock import AsyncMock, MagicMock |
18 | 19 |
|
| 20 | +import pytest |
19 | 21 | from pydantic import BaseModel |
20 | 22 |
|
| 23 | +from nemoguardrails.actions.llm.utils import llm_call |
21 | 24 | from nemoguardrails.llm.params import LLMParams, llm_params, register_param_manager |
22 | 25 |
|
23 | 26 |
|
@@ -219,3 +222,48 @@ class UnregisteredLLM(BaseModel): |
219 | 222 | pass |
220 | 223 |
|
221 | 224 | self.assertIsInstance(llm_params(UnregisteredLLM()), LLMParams) |
| 225 | + |
| 226 | + |
class TestLLMParamsMigration(unittest.IsolatedAsyncioTestCase):
    """Test migration from the ``llm_params`` context manager to direct parameter passing.

    NOTE: this class subclasses ``unittest.IsolatedAsyncioTestCase`` (not plain
    ``TestCase``) because ``@pytest.mark.asyncio`` has no effect on methods of a
    ``unittest.TestCase`` subclass — pytest-asyncio does not collect them, so the
    async coroutine would never actually be awaited. ``IsolatedAsyncioTestCase``
    runs both the sync and the async test methods under any runner.
    """

    def test_context_manager_equivalent_to_direct_params(self):
        """Overrides are visible inside the ``with`` block and restored on exit."""
        llm = FakeLLM(param3="original", model_kwargs={"temperature": 0.5})

        with llm_params(llm, temperature=0.8, param3="modified"):
            # Capture the values while the overrides are active.
            context_temp = llm.model_kwargs.get("temperature")
            context_param3 = llm.param3

        # Overrides were applied inside the context...
        assert context_temp == 0.8
        assert context_param3 == "modified"
        # ...and the original values are restored afterwards.
        assert llm.model_kwargs.get("temperature") == 0.5
        assert llm.param3 == "original"

    async def test_llm_call_params_vs_context_manager(self):
        """``llm_call`` with explicit params binds them onto the LLM instead of mutating it."""
        mock_llm = AsyncMock()
        mock_bound_llm = AsyncMock()
        mock_response = MagicMock()
        mock_response.content = "Response content"

        mock_llm.bind.return_value = mock_bound_llm
        mock_bound_llm.ainvoke.return_value = mock_response

        params = {"temperature": 0.7, "max_tokens": 100}

        result = await llm_call(mock_llm, "Test prompt", llm_params=params)

        assert result == "Response content"
        # Params must be forwarded via `bind(**params)` rather than set on the
        # original LLM object, so the original is never mutated.
        mock_llm.bind.assert_called_once_with(**params)
        mock_bound_llm.ainvoke.assert_called_once()

    def test_parameter_isolation_after_migration(self):
        """Construction-time parameters remain untouched by the migration path.

        NOTE(review): this only pins the values set in the constructor; it does
        not exercise ``llm_call`` itself — consider extending it to invoke
        ``llm_call`` and re-check the values afterwards.
        """
        llm = FakeLLM(param3="original", model_kwargs={"temperature": 0.5})
        original_temp = llm.model_kwargs.get("temperature")
        original_param3 = llm.param3

        assert original_temp == 0.5
        assert original_param3 == "original"
0 commit comments