diff --git a/langgraph_swarm/handoff.py b/langgraph_swarm/handoff.py index 620ee51..0eb30cc 100644 --- a/langgraph_swarm/handoff.py +++ b/langgraph_swarm/handoff.py @@ -32,12 +32,17 @@ def _get_field(obj: Any, key: str) -> Any: WHITESPACE_RE = re.compile(r"\s+") +NON_ALNUM_UNDERSCORE_RE = re.compile(r"[^a-z0-9_]+") +MULTI_UNDERSCORE_RE = re.compile(r"_+") METADATA_KEY_HANDOFF_DESTINATION = "__handoff_destination" def _normalize_agent_name(agent_name: str) -> str: """Normalize an agent name to be used inside the tool name.""" - return WHITESPACE_RE.sub("_", agent_name.strip()).lower() + normalized = WHITESPACE_RE.sub("_", agent_name.strip()).lower() + normalized = NON_ALNUM_UNDERSCORE_RE.sub("_", normalized) + normalized = MULTI_UNDERSCORE_RE.sub("_", normalized).strip("_") + return normalized or "agent" def create_handoff_tool( diff --git a/langgraph_swarm/swarm.py b/langgraph_swarm/swarm.py index c6b0609..f2d7d3e 100644 --- a/langgraph_swarm/swarm.py +++ b/langgraph_swarm/swarm.py @@ -161,7 +161,18 @@ def add(a: int, b: int) -> int: ) def route_to_active_agent(state: dict) -> str: - return cast("str", state.get("active_agent", default_active_agent)) + active_agent = state.get("active_agent", default_active_agent) + if not active_agent: + return default_active_agent + if active_agent not in route_to: + warn( + f"Active agent '{active_agent}' not found in routes {route_to}. " + f"Falling back to '{default_active_agent}'.", + RuntimeWarning, + stacklevel=2, + ) + return default_active_agent + return cast("str", active_agent) builder.add_conditional_edges(START, route_to_active_agent, path_map=route_to) return builder diff --git a/tests/test_swarm.py b/tests/test_swarm.py index 912d9a4..defab7a 100644 --- a/tests/test_swarm.py +++ b/tests/test_swarm.py @@ -1,6 +1,7 @@ from collections.abc import Callable, Sequence from typing import TYPE_CHECKING, Any +import pytest from langchain.agents import AgentState, create_agent from langchain.chat_models import BaseChatModel from langchain.messages import AIMessage @@ -45,6 +46,18 @@ def bind_tools( return self +def test_handoff_tool_default_name_sanitizes_special_characters() -> None: + handoff_tool = create_handoff_tool(agent_name="R&D Agent / v2") + + assert handoff_tool.name == "transfer_to_r_d_agent_v2" + + +def test_handoff_tool_default_name_falls_back_when_empty_after_sanitize() -> None: + handoff_tool = create_handoff_tool(agent_name="!!!") + + assert handoff_tool.name == "transfer_to_agent" + + def test_basic_swarm() -> None: # Create fake responses for the model recorded_messages = [ @@ -152,6 +165,63 @@ def add(a: int, b: int) -> int: assert turn_2["active_agent"] == "Alice" +def test_swarm_falls_back_on_unknown_active_agent() -> None: + recorded_messages = [ + AIMessage( + content="", + name="Alice", + tool_calls=[ + { + "name": "transfer_to_bob", + "args": {}, + "id": "call_1LlFyjm6iIhDjdn7juWuPYr4", + }, + ], + ), + AIMessage( + content="Ahoy, matey! Bob the pirate be at yer service.", + name="Bob", + ), + ] + + model = FakeChatModel(responses=recorded_messages) # type: ignore[arg-type] + + def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + alice: Any = create_agent( + model, + tools=[add, create_handoff_tool(agent_name="Bob")], + system_prompt="You are Alice, an addition expert.", + name="Alice", + ) + + bob: Any = create_agent( + model, + tools=[create_handoff_tool(agent_name="Alice")], + system_prompt="You are Bob, you speak like a pirate.", + name="Bob", + ) + + checkpointer = MemorySaver() + workflow = create_swarm([alice, bob], default_active_agent="Alice") # type: ignore[list-item] + app = workflow.compile(checkpointer=checkpointer) + + config: RunnableConfig = {"configurable": {"thread_id": "fallback-1"}} + with pytest.warns(RuntimeWarning, match="Active agent 'Charlie'"): + turn_1 = app.invoke( + { # type: ignore[arg-type] + "messages": [{"role": "user", "content": "i'd like to speak to Bob"}], + "active_agent": "Charlie", + }, + config, + ) + + assert turn_1["active_agent"] == "Bob" + assert turn_1["messages"][-2].content == "Successfully transferred to Bob" + + def test_basic_swarm_pydantic() -> None: """Test a basic swarm with Pydantic state schema."""