Skip to content

Commit 57a3abc

Browse files
authored
refactor(llm): move get_action_details_from_flow_id from llmrails.py to utils.py (#1341)
1 parent 626fb6c commit 57a3abc

File tree

5 files changed

+446
-181
lines changed

5 files changed

+446
-181
lines changed

nemoguardrails/rails/llm/llmrails.py

Lines changed: 5 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@
7373
GenerationOptions,
7474
GenerationResponse,
7575
)
76-
from nemoguardrails.rails.llm.utils import get_history_cache_key
76+
from nemoguardrails.rails.llm.utils import (
77+
get_action_details_from_flow_id,
78+
get_history_cache_key,
79+
)
7780
from nemoguardrails.streaming import END_OF_STREAM, StreamingHandler
7881
from nemoguardrails.utils import (
7982
extract_error_json,
@@ -1590,7 +1593,7 @@ def _prepare_params(
15901593
output_rails_flows_id = self.config.rails.output.flows
15911594
stream_first = stream_first or output_rails_streaming_config.stream_first
15921595
get_action_details = partial(
1593-
_get_action_details_from_flow_id, flows=self.config.flows
1596+
get_action_details_from_flow_id, flows=self.config.flows
15941597
)
15951598

15961599
parallel_mode = getattr(self.config.rails.output, "parallel", False)
@@ -1746,58 +1749,3 @@ def _prepare_params(
17461749
# yield the individual chunks directly from the buffer strategy
17471750
for chunk in user_output_chunks:
17481751
yield chunk
1749-
1750-
1751-
def _get_action_details_from_flow_id(
1752-
flow_id: str,
1753-
flows: List[Union[Dict, Any]],
1754-
prefixes: Optional[List[str]] = None,
1755-
) -> Tuple[str, Any]:
1756-
"""Get the action name and parameters from the flow id.
1757-
1758-
First, try to find an exact match.
1759-
If not found, then if the provided flow_id starts with one of the special prefixes,
1760-
return the first flow whose id starts with that same prefix.
1761-
"""
1762-
1763-
supported_prefixes = [
1764-
"content safety check output",
1765-
"topic safety check output",
1766-
]
1767-
if prefixes:
1768-
supported_prefixes.extend(prefixes)
1769-
1770-
candidate_flow = None
1771-
1772-
for flow in flows:
1773-
# If exact match, use it
1774-
if flow["id"] == flow_id:
1775-
candidate_flow = flow
1776-
break
1777-
1778-
# If no exact match, check if both the provided flow_id and this flow's id share a special prefix
1779-
for prefix in supported_prefixes:
1780-
if flow_id.startswith(prefix) and flow["id"].startswith(prefix):
1781-
candidate_flow = flow
1782-
# We don't break immediately here because an exact match would have been preferred,
1783-
# but since we're in the else branch it's fine to choose the first matching candidate.
1784-
# TODO:we should avoid having multiple matchin prefixes
1785-
break
1786-
1787-
if candidate_flow is not None:
1788-
break
1789-
1790-
if candidate_flow is None:
1791-
raise ValueError(f"No action found for flow_id: {flow_id}")
1792-
1793-
# we have identified a candidate, look for the run_action element.
1794-
for element in candidate_flow["elements"]:
1795-
if (
1796-
element["_type"] == "run_action"
1797-
and element["_source_mapping"]["filename"].endswith(".co")
1798-
and "execute" in element["_source_mapping"]["line_text"]
1799-
and "action_name" in element
1800-
):
1801-
return element["action_name"], element["action_params"]
1802-
1803-
raise ValueError(f"No run_action element found for flow_id: {flow_id}")

nemoguardrails/rails/llm/utils.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import json
16-
from typing import List
16+
from typing import Any, Dict, List, Optional, Tuple, Union
1717

1818

1919
def get_history_cache_key(messages: List[dict]) -> str:
@@ -56,3 +56,57 @@ def get_history_cache_key(messages: List[dict]) -> str:
5656
history_cache_key = ":".join(key_items)
5757

5858
return history_cache_key
59+
60+
61+
def get_action_details_from_flow_id(
62+
flow_id: str,
63+
flows: List[Union[Dict, Any]],
64+
prefixes: Optional[List[str]] = None,
65+
) -> Tuple[str, Any]:
66+
"""Get the action name and parameters from the flow id.
67+
68+
First, try to find an exact match.
69+
If not found, then if the provided flow_id starts with one of the special prefixes,
70+
return the first flow whose id starts with that same prefix.
71+
"""
72+
supported_prefixes = [
73+
"content safety check output",
74+
"topic safety check output",
75+
]
76+
if prefixes:
77+
supported_prefixes.extend(prefixes)
78+
79+
candidate_flow = None
80+
81+
for flow in flows:
82+
# If exact match, use it
83+
if flow["id"] == flow_id:
84+
candidate_flow = flow
85+
break
86+
87+
# If no exact match, check if both the provided flow_id and this flow's id share a special prefix
88+
for prefix in supported_prefixes:
89+
if flow_id.startswith(prefix) and flow["id"].startswith(prefix):
90+
candidate_flow = flow
91+
# We don't break immediately here because an exact match would have been preferred,
92+
# but since we're in the else branch it's fine to choose the first matching candidate.
93+
# TODO:we should avoid having multiple matchin prefixes
94+
break
95+
96+
if candidate_flow is not None:
97+
break
98+
99+
if candidate_flow is None:
100+
raise ValueError(f"No action found for flow_id: {flow_id}")
101+
102+
# we have identified a candidate, look for the run_action element.
103+
for element in candidate_flow["elements"]:
104+
if (
105+
element["_type"] == "run_action"
106+
and element["_source_mapping"]["filename"].endswith(".co")
107+
and "execute" in element["_source_mapping"]["line_text"]
108+
and "action_name" in element
109+
):
110+
return element["action_name"], element["action_params"]
111+
112+
raise ValueError(f"No run_action element found for flow_id: {flow_id}")

tests/test_llmrails.py

Lines changed: 1 addition & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from nemoguardrails import LLMRails, RailsConfig
2323
from nemoguardrails.rails.llm.config import Model
24-
from nemoguardrails.rails.llm.llmrails import _get_action_details_from_flow_id
24+
from nemoguardrails.rails.llm.llmrails import get_action_details_from_flow_id
2525
from tests.utils import FakeLLM, clean_events, event_sequence_conforms
2626

2727

@@ -627,124 +627,6 @@ async def compute(what: Optional[str] = "2 + 3"):
627627
}
628628

629629

630-
# get_action_details_from_flow_id used in llmrails.py
631-
632-
633-
@pytest.fixture
634-
def dummy_flows() -> List[Union[Dict, Any]]:
635-
return [
636-
{
637-
"id": "test_flow",
638-
"elements": [
639-
{
640-
"_type": "run_action",
641-
"_source_mapping": {
642-
"filename": "flows.v1.co",
643-
"line_text": "execute something",
644-
},
645-
"action_name": "test_action",
646-
"action_params": {"param1": "value1"},
647-
}
648-
],
649-
},
650-
# Additional flow that should match on a prefix
651-
{
652-
"id": "other_flow is prefix",
653-
"elements": [
654-
{
655-
"_type": "run_action",
656-
"_source_mapping": {
657-
"filename": "flows.v1.co",
658-
"line_text": "execute something else",
659-
},
660-
"action_name": "other_action",
661-
"action_params": {"param2": "value2"},
662-
}
663-
],
664-
},
665-
{
666-
"id": "test_rails_co",
667-
"elements": [
668-
{
669-
"_type": "run_action",
670-
"_source_mapping": {
671-
"filename": "rails.co",
672-
"line_text": "execute something",
673-
},
674-
"action_name": "test_action_supported",
675-
"action_params": {"param1": "value1"},
676-
}
677-
],
678-
},
679-
{
680-
"id": "test_rails_co_v2",
681-
"elements": [
682-
{
683-
"_type": "run_action",
684-
"_source_mapping": {
685-
"filename": "rails.co",
686-
"line_text": "await something", # in colang 2 we use await
687-
},
688-
"action_name": "test_action_not_supported",
689-
"action_params": {"param1": "value1"},
690-
}
691-
],
692-
},
693-
]
694-
695-
696-
def test_get_action_details_exact_match(dummy_flows):
697-
action_name, action_params = _get_action_details_from_flow_id(
698-
"test_flow", dummy_flows
699-
)
700-
assert action_name == "test_action"
701-
assert action_params == {"param1": "value1"}
702-
703-
704-
def test_get_action_details_exact_match_any_co_file(dummy_flows):
705-
action_name, action_params = _get_action_details_from_flow_id(
706-
"test_rails_co", dummy_flows
707-
)
708-
assert action_name == "test_action_supported"
709-
assert action_params == {"param1": "value1"}
710-
711-
712-
def test_get_action_details_exact_match_not_colang_2(dummy_flows):
713-
with pytest.raises(ValueError) as exc_info:
714-
_get_action_details_from_flow_id("test_rails_co_v2", dummy_flows)
715-
716-
assert "No run_action element found for flow_id" in str(exc_info.value)
717-
718-
719-
def test_get_action_details_prefix_match(dummy_flows):
720-
# For a flow_id that starts with the prefix "other_flow",
721-
# we expect to retrieve the action details from the flow whose id starts with that prefix.
722-
# we expect a result since we are passing the prefixes argument.
723-
action_name, action_params = _get_action_details_from_flow_id(
724-
"other_flow", dummy_flows, prefixes=["other_flow"]
725-
)
726-
assert action_name == "other_action"
727-
assert action_params == {"param2": "value2"}
728-
729-
730-
def test_get_action_details_prefix_match_unsupported_prefix(dummy_flows):
731-
# For a flow_id that starts with the prefix "other_flow",
732-
# we expect to retrieve the action details from the flow whose id starts with that prefix.
733-
# but as the prefix is not supported, we expect a ValueError.
734-
735-
with pytest.raises(ValueError) as exc_info:
736-
_get_action_details_from_flow_id("other_flow", dummy_flows)
737-
738-
assert "No action found for flow_id" in str(exc_info.value)
739-
740-
741-
def test_get_action_details_no_match(dummy_flows):
742-
# Tests that a non matching flow_id raises a ValueError
743-
with pytest.raises(ValueError) as exc_info:
744-
_get_action_details_from_flow_id("non_existing_flow", dummy_flows)
745-
assert "No action found for flow_id" in str(exc_info.value)
746-
747-
748630
@pytest.fixture
749631
def llm_config_with_main():
750632
"""Fixture providing a basic config with a main LLM."""

tests/test_llmrails_reasoning.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,12 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
from typing import Any, Dict, List, Optional, Union
16+
from typing import Optional
1717

1818
import pytest
1919

2020
from nemoguardrails import LLMRails, RailsConfig
21-
from nemoguardrails.rails.llm.llmrails import _get_action_details_from_flow_id
22-
from tests.utils import FakeLLM, clean_events, event_sequence_conforms
21+
from tests.utils import FakeLLM
2322

2423

2524
@pytest.fixture

0 commit comments

Comments
 (0)