Skip to content

Commit a5e8b4c

Browse files
authored
feat(llmrails): isolate LLMs only for configured actions (#1342)
Update isolated LLM creation to only target actions defined in rails config flows. This ensures that LLMs are not unnecessarily created for actions not present in the configuration. Adds comprehensive tests to verify correct behavior, including handling of empty configs and skipping already-registered LLMs.
1 parent 94688a8 commit a5e8b4c

File tree

3 files changed

+287
-18
lines changed

3 files changed

+287
-18
lines changed

nemoguardrails/rails/llm/llmrails.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,21 @@ def _create_isolated_llms_for_actions(self):
523523
)
524524

525525
created_count = 0
526-
for action_name in actions_needing_llms:
526+
# Get the actions from flows defined in rails config
527+
get_action_details = partial(
528+
get_action_details_from_flow_id, flows=self.config.flows
529+
)
530+
configured_actions_names = []
531+
for flow_id in self.config.rails.input.flows:
532+
action_name, _ = get_action_details(flow_id)
533+
configured_actions_names.append(action_name)
534+
for flow_id in self.config.rails.output.flows:
535+
action_name, _ = get_action_details(flow_id)
536+
configured_actions_names.append(action_name)
537+
538+
for action_name in configured_actions_names:
539+
if action_name not in actions_needing_llms:
540+
continue
527541
if f"{action_name}_llm" not in self.runtime.registered_action_params:
528542
isolated_llm = self._create_action_llm_copy(self.llm, action_name)
529543
if isolated_llm:

tests/test_llm_isolation.py

Lines changed: 139 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -255,48 +255,89 @@ def test_create_isolated_llms_for_actions_integration(self, rails_with_mock_llm)
255255
"""Test the full isolated LLM creation process."""
256256
rails = rails_with_mock_llm
257257

258+
# Mock rails configuration with flows
259+
rails.config.rails = Mock()
260+
rails.config.rails.input = Mock()
261+
rails.config.rails.output = Mock()
262+
rails.config.rails.input.flows = ["input_flow_1", "input_flow_2"]
263+
rails.config.rails.output.flows = ["output_flow_1"]
264+
258265
rails.runtime = Mock()
259266
rails.runtime.action_dispatcher = MockActionDispatcher()
260267
rails.runtime.registered_action_params = {}
261268
rails.runtime.register_action_param = Mock()
262269

263-
rails._create_isolated_llms_for_actions()
264-
265-
expected_calls = [
270+
# Mock get_action_details_from_flow_id to return actions that need LLMs
271+
def mock_get_action_details(flow_id, flows):
272+
mapping = {
273+
"input_flow_1": ("action_with_llm", {}),
274+
"input_flow_2": ("generate_user_intent", {}),
275+
"output_flow_1": ("self_check_output", {}),
276+
}
277+
return mapping.get(flow_id, ("unknown_action", {}))
278+
279+
with patch(
280+
"nemoguardrails.rails.llm.llmrails.get_action_details_from_flow_id",
281+
side_effect=mock_get_action_details,
282+
):
283+
rails._create_isolated_llms_for_actions()
284+
285+
expected_llm_params = [
266286
"action_with_llm_llm",
267287
"generate_user_intent_llm",
268288
"self_check_output_llm",
269289
]
270290

271-
actual_calls = [
291+
registered_llm_params = [
272292
call[0][0] for call in rails.runtime.register_action_param.call_args_list
273293
]
274294

275-
for expected_call in expected_calls:
276-
assert expected_call in actual_calls
295+
for expected_param in expected_llm_params:
296+
assert expected_param in registered_llm_params
277297

278298
def test_create_isolated_llms_skips_existing_specialized_llms(
279299
self, rails_with_mock_llm
280300
):
281301
"""Test that existing specialized LLMs are not overridden."""
282302
rails = rails_with_mock_llm
283303

304+
# Mock rails configuration with flows
305+
rails.config.rails = Mock()
306+
rails.config.rails.input = Mock()
307+
rails.config.rails.output = Mock()
308+
rails.config.rails.input.flows = ["input_flow_1", "input_flow_2"]
309+
rails.config.rails.output.flows = ["output_flow_1"]
310+
284311
rails.runtime = Mock()
285312
rails.runtime.action_dispatcher = MockActionDispatcher()
286313
rails.runtime.registered_action_params = {"self_check_output_llm": Mock()}
287314
rails.runtime.register_action_param = Mock()
288315

289-
rails._create_isolated_llms_for_actions()
290-
291-
# verify self_check_output_llm was NOT re-registered
292-
actual_calls = [
316+
# Mock get_action_details_from_flow_id to return actions that need LLMs
317+
def mock_get_action_details(flow_id, flows):
318+
mapping = {
319+
"input_flow_1": ("action_with_llm", {}),
320+
"input_flow_2": ("generate_user_intent", {}),
321+
"output_flow_1": (
322+
"self_check_output",
323+
{},
324+
), # This one already has an LLM
325+
}
326+
return mapping.get(flow_id, ("unknown_action", {}))
327+
328+
with patch(
329+
"nemoguardrails.rails.llm.llmrails.get_action_details_from_flow_id",
330+
side_effect=mock_get_action_details,
331+
):
332+
rails._create_isolated_llms_for_actions()
333+
334+
registered_llm_params = [
293335
call[0][0] for call in rails.runtime.register_action_param.call_args_list
294336
]
295-
assert "self_check_output_llm" not in actual_calls
296337

297-
# but other actions should still get isolated LLMs
298-
assert "action_with_llm_llm" in actual_calls
299-
assert "generate_user_intent_llm" in actual_calls
338+
assert "self_check_output_llm" not in registered_llm_params
339+
assert "action_with_llm_llm" in registered_llm_params
340+
assert "generate_user_intent_llm" in registered_llm_params
300341

301342
def test_create_isolated_llms_handles_no_main_llm(self, mock_config):
302343
"""Test graceful handling when no main LLM is available."""
@@ -411,3 +452,87 @@ def test_action_detection_parametrized(
411452
assert action_name in actions_needing_llms
412453
else:
413454
assert action_name not in actions_needing_llms
455+
456+
def test_create_isolated_llms_for_configured_actions_only(
457+
self, rails_with_mock_llm
458+
):
459+
"""Test that isolated LLMs are created only for actions configured in rails flows."""
460+
rails = rails_with_mock_llm
461+
462+
rails.config.rails = Mock()
463+
rails.config.rails.input = Mock()
464+
rails.config.rails.output = Mock()
465+
rails.config.rails.input.flows = [
466+
"input_flow_1",
467+
"input_flow_2",
468+
"input_flow_3",
469+
]
470+
rails.config.rails.output.flows = ["output_flow_1", "output_flow_2"]
471+
472+
rails.runtime = Mock()
473+
rails.runtime.action_dispatcher = MockActionDispatcher()
474+
rails.runtime.registered_action_params = {}
475+
rails.runtime.register_action_param = Mock()
476+
477+
def mock_get_action_details(flow_id, flows):
478+
mapping = {
479+
"input_flow_1": ("action_with_llm", {}),
480+
"input_flow_2": ("action_without_llm", {}),
481+
"input_flow_3": ("self_check_output", {}),
482+
"output_flow_1": ("generate_user_intent", {}),
483+
"output_flow_2": ("non_configured_action", {}),
484+
}
485+
return mapping.get(flow_id, ("unknown_action", {}))
486+
487+
with patch(
488+
"nemoguardrails.rails.llm.llmrails.get_action_details_from_flow_id",
489+
side_effect=mock_get_action_details,
490+
):
491+
rails._create_isolated_llms_for_actions()
492+
493+
registered_llm_params = [
494+
call[0][0] for call in rails.runtime.register_action_param.call_args_list
495+
]
496+
497+
expected_isolated_llm_params = [
498+
"action_with_llm_llm",
499+
"generate_user_intent_llm",
500+
"self_check_output_llm",
501+
]
502+
503+
for expected_param in expected_isolated_llm_params:
504+
assert (
505+
expected_param in registered_llm_params
506+
), f"Expected {expected_param} to be registered as action param"
507+
508+
assert "action_without_llm_llm" not in registered_llm_params
509+
assert "non_configured_action_llm" not in registered_llm_params
510+
511+
assert len(registered_llm_params) == 3, (
512+
f"Should only create isolated LLMs for actions from config flows that need LLMs. "
513+
f"Got {registered_llm_params}"
514+
)
515+
516+
def test_create_isolated_llms_handles_empty_rails_config(self, rails_with_mock_llm):
517+
"""Test that the method handles empty rails configuration gracefully."""
518+
rails = rails_with_mock_llm
519+
520+
rails.config.rails = Mock()
521+
rails.config.rails.input = Mock()
522+
rails.config.rails.output = Mock()
523+
rails.config.rails.input.flows = []
524+
rails.config.rails.output.flows = []
525+
526+
rails.runtime = Mock()
527+
rails.runtime.action_dispatcher = MockActionDispatcher()
528+
rails.runtime.registered_action_params = {}
529+
rails.runtime.register_action_param = Mock()
530+
531+
with patch(
532+
"nemoguardrails.rails.llm.llmrails.get_action_details_from_flow_id"
533+
) as mock_get_action:
534+
rails._create_isolated_llms_for_actions()
535+
536+
mock_get_action.assert_not_called()
537+
538+
rails.runtime.register_action_param.assert_not_called()

tests/test_llm_isolation_e2e.py

Lines changed: 133 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from nemoguardrails import LLMRails
2626
from nemoguardrails.rails.llm.config import RailsConfig
2727

28-
TEST_LIVE_MODE = os.environ.get("TEST_LIVE_MODE")
28+
LIVE_TEST_MODE = os.environ.get("TEST_LIVE_MODE")
2929

3030

3131
@pytest.fixture
@@ -75,7 +75,7 @@ def test_config_path(test_config_content):
7575

7676

7777
@pytest.mark.skipif(
78-
not TEST_LIVE_MODE,
78+
not LIVE_TEST_MODE,
7979
reason="This test requires TEST_LIVE_MODE environment variable to be set for live testing",
8080
)
8181
class TestLLMIsolationE2E:
@@ -383,7 +383,7 @@ async def test_parameter_isolation_multiple_iterations(
383383

384384

385385
@pytest.mark.skipif(
386-
not TEST_LIVE_MODE,
386+
not LIVE_TEST_MODE,
387387
reason="This test requires TEST_LIVE_MODE environment variable to be set for live testing",
388388
)
389389
class TestLLMIsolationErrorHandling:
@@ -474,5 +474,135 @@ async def run_parameter_contamination_test():
474474
)
475475

476476

477+
@pytest.mark.skipif(
478+
not LIVE_TEST_MODE,
479+
reason="This test requires TEST_LIVE_MODE environment variable to be set for live testing",
480+
)
481+
class TestLLMIsolationConfiguredActionsOnly:
482+
"""Test that isolated LLMs are created only for actions configured in rails flows."""
483+
484+
@staticmethod
485+
def _create_rails_with_config(config_content: str) -> LLMRails:
486+
"""Helper to create LLMRails instance from config content."""
487+
with tempfile.TemporaryDirectory() as temp_dir:
488+
config_path = Path(temp_dir) / "config.yml"
489+
config_path.write_text(config_content)
490+
config = RailsConfig.from_path(str(temp_dir))
491+
return LLMRails(config, verbose=False)
492+
493+
@staticmethod
494+
def _get_isolated_llm_params(
495+
rails: LLMRails, exclude_specialized: bool = False
496+
) -> list:
497+
"""Helper to get isolated LLM parameters from rails instance."""
498+
registered_params = rails.runtime.registered_action_params
499+
isolated_llm_params = [
500+
key
501+
for key in registered_params.keys()
502+
if key.endswith("_llm") and key != "llm" and key != "llms"
503+
]
504+
505+
if exclude_specialized:
506+
specialized_llms = ["content_safety_llm", "topic_safety_llm"]
507+
isolated_llm_params = [
508+
param for param in isolated_llm_params if param not in specialized_llms
509+
]
510+
511+
return isolated_llm_params
512+
513+
def test_only_configured__rail_actions_get_isolated_llms(self):
514+
"""Test that only actions from output rails flows get isolated LLMs."""
515+
config_content = """
516+
models:
517+
- type: main
518+
engine: openai
519+
model: gpt-4o-mini
520+
521+
rails:
522+
output:
523+
flows:
524+
- self check output
525+
- self check input
526+
527+
prompts:
528+
- task: self_check_output
529+
content: |
530+
Check if output is safe.
531+
Output: {{ bot_message }}
532+
Safe? (Yes/No):
533+
- task: self_check_input
534+
content: |
535+
Check if input is safe.
536+
Input: {{ user_input }}
537+
Safe? (Yes/No):
538+
"""
539+
540+
rails = self._create_rails_with_config(config_content)
541+
isolated_llm_params = self._get_isolated_llm_params(rails)
542+
543+
assert "self_check_output_llm" in isolated_llm_params
544+
assert "self_check_input_llm" in isolated_llm_params
545+
assert "self_check_facts_llm" not in isolated_llm_params
546+
547+
def test_no_isolated_llms_when_no_rails_configured(self):
548+
"""Test that no isolated LLMs are created when no rails are configured."""
549+
config_content = """
550+
models:
551+
- type: main
552+
engine: openai
553+
model: gpt-4o-mini
554+
"""
555+
556+
rails = self._create_rails_with_config(config_content)
557+
isolated_llm_params = self._get_isolated_llm_params(
558+
rails, exclude_specialized=True
559+
)
560+
561+
assert (
562+
len(isolated_llm_params) == 0
563+
), f"Unexpected isolated LLMs created: {isolated_llm_params}"
564+
565+
def test_empty_rails_flows_creates_no_isolated_llms(self):
566+
"""Test that empty rails flows list creates no isolated LLMs."""
567+
config_content = """
568+
models:
569+
- type: main
570+
engine: openai
571+
model: gpt-4o-mini
572+
573+
rails:
574+
input:
575+
flows: []
576+
output:
577+
flows: []
578+
"""
579+
580+
rails = self._create_rails_with_config(config_content)
581+
isolated_llm_params = self._get_isolated_llm_params(
582+
rails, exclude_specialized=True
583+
)
584+
585+
assert (
586+
len(isolated_llm_params) == 0
587+
), f"Unexpected isolated LLMs created: {isolated_llm_params}"
588+
589+
def test_non_llm_requiring_actions_dont_get_isolated_llms(self):
590+
"""Test that even valid flows don't get isolated LLMs if actions don't require LLMs."""
591+
config_content = """
592+
models:
593+
- type: main
594+
engine: openai
595+
model: gpt-4o-mini
596+
"""
597+
598+
rails = self._create_rails_with_config(config_content)
599+
600+
# retrieve_relevant_chunks action exists but doesn't require LLM
601+
# so it should never get an isolated LLM even if it were configured
602+
assert (
603+
"retrieve_relevant_chunks_llm" not in rails.runtime.registered_action_params
604+
)
605+
606+
477607
if __name__ == "__main__":
478608
asyncio.run(run_parameter_contamination_test())

0 commit comments

Comments
 (0)