Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions src/google/adk/agents/map_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from typing import Annotated
from typing import AsyncGenerator

from annotated_types import Len
from google.adk.agents import BaseAgent
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.parallel_agent import _merge_agent_run
from google.adk.events import Event
from google.genai import types
from pydantic import Field
from pydantic import RootModel
from typing_extensions import override


class MapAgent(BaseAgent):
  """Workflow agent that maps a list of prompts over copies of one sub-agent.

  The latest event visible on this agent's branch is expected to carry a
  JSON-serialized list of strings. Each string is dispatched, in parallel, to
  a uniquely renamed clone of the single configured sub-agent, and the clones'
  event streams are merged into this agent's output.
  """

  # Exactly one template agent; a numbered clone of it is created per prompt.
  # NOTE(review): pydantic does not validate defaults unless validate_default
  # is set, so the empty list from default_factory is not rejected by
  # min_length=1 — confirm this is intended.
  sub_agents: Annotated[list[BaseAgent], Len(1, 1)] = Field(
      min_length=1,
      max_length=1,
      default_factory=list,
      description=(
          "A single base agent that will be copied and invoked for each prompt"
      ),
  )

  @override
  async def _run_async_impl(
      self, invocation_context: InvocationContext
  ) -> AsyncGenerator[Event, None]:
    """Core logic of this workflow agent.

    Args:
      invocation_context: InvocationContext, provides access to the input
        prompts.

    Yields:
      Event: the events generated by the sub-agent for each input prompt.
    """
    # Create a branch string if it doesn't exist, to ensure parallel
    # invocations don't interfere with each other.
    invocation_context.branch = invocation_context.branch or self.name

    prompts, invoker = self._extract_input_prompts(invocation_context)

    # For agent naming — e.g. with 100-999 prompts, sub-agent copies are
    # suffixed 001, 002, 003 and so on.
    number_field_width = len(str(len(prompts)))

    # Create a separate invocation context for each prompt, each with a
    # numbered copy of the sub-agent.
    contexts = [
        self._branch_context(
            invocation_context,
            idx=i,
            prompt=prompt,
            invoker=invoker,
            width=number_field_width,
        )
        for i, prompt in enumerate(prompts)
    ]

    async for event in _merge_agent_run(
        [ctx.agent.run_async(ctx) for ctx in contexts]
    ):
      yield event

  def _extract_input_prompts(
      self, ctx: InvocationContext
  ) -> tuple[list[str], str]:
    """Extracts the list of input prompts from the session events.

    The input to the map agent is the text of the latest event visible on
    this agent's branch, assumed to be a list of strings serialized as a
    JSON string.

    Args:
      ctx: The current invocation context. Its session events are searched,
        and the consumed prompt-list event is removed for this invocation.

    Returns:
      A ``(prompts, invoker)`` tuple, or ``([], "user")`` when no suitable
      event exists.
    """
    # Walk events newest-first looking for the most recent event visible on
    # this branch (events without a branch are visible everywhere).
    for i in range(len(ctx.session.events) - 1, -1, -1):
      event = ctx.session.events[i]
      if event.branch is None or (
          ctx.branch is not None and event.branch.startswith(ctx.branch)
      ):
        break
    else:
      # No matching event found (for/else): nothing to map over.
      return [], "user"

    invoker: str = event.author
    input_message: str = (
        (event.content or types.Content()).parts or [types.Part()]
    )[0].text or ""

    # Remove the event which has the prompt list, so that a sub-agent does
    # not see the prompts of its siblings, which may confuse it. The event
    # is removed only for this invocation.
    ctx.session.events.pop(i)

    agent_input = RootModel[list[str]].model_validate_json(input_message).root

    return agent_input, invoker

  @staticmethod
  def _get_unique_name(name: str, idx: int, width: int) -> str:
    """Returns the zero-padded unique agent name, e.g. ``my_sub_agent_046``."""
    return f"{name}_{idx:0{width}d}"

  def _branch_context(
      self,
      ctx: InvocationContext,
      *,
      prompt: str,
      invoker: str,
      idx: int,
      width: int,
  ) -> InvocationContext:
    """Creates a numbered copy of the sub-agent that sees a single prompt, and can run separately from its siblings.

    Args:
      ctx: The current invocation context of the map agent. To be copied and
        edited for the sub-agent copy.
      prompt: the prompt on which the sub-agent copy should be invoked
      invoker: the invoker of the map agent in this invocation.
      idx: index of the prompt in the input prompts, serves as a unique
        postfix to the agent name
      width: number of digits in the total number of prompts, to ensure naming
        is consistent in field width (e.g. 001, 002, ... 010, 011, ... 100,
        101; and not 1, 2, ... 10, 11, ... 100, 101)

    Returns:
      InvocationContext: A new invocation context ready to run with the
        unique sub-agent copy and the prompt
    """
    agent = self._branch_agent_tree(self.sub_agents[0], idx, width)

    branch = f"{ctx.branch}.{agent.name}"
    prompt_part = [types.Part(text=prompt)]

    # Add the prompt to the user_content of this branch to easily access
    # agent input in callbacks.
    user_content = types.Content(
        role="user",
        parts=((ctx.user_content or types.Content()).parts or []) + prompt_part,
    )
    new_ctx = ctx.model_copy(
        update=dict(branch=branch, agent=agent, user_content=user_content)
    )

    # Add the prompt as a temporary event of this branch in place of the
    # prompt list as the natural input of the sub-agent.
    prompt_content = types.Content(
        role="user" if invoker == "user" else "model", parts=prompt_part
    )
    new_ctx.session.events.append(
        Event(author=invoker, branch=branch, content=prompt_content)
    )

    return new_ctx

  def _branch_agent_tree(
      self, agent: BaseAgent, idx: int, width: int
  ) -> BaseAgent:
    """Clones and renames an agent and its sub-tree to create a thread-safe branch.

    Args:
      agent: the root of the sub-tree to clone.
      idx: numeric suffix applied to every agent name in the cloned tree.
      width: zero-padding width for the suffix.

    Returns:
      BaseAgent: the renamed clone with a recursively cloned sub-tree.
    """
    new_agent = agent.model_copy(
        update={"name": self._get_unique_name(agent.name, idx=idx, width=width)}
    )
    # Recurse so every descendant also carries the same numeric suffix.
    new_agent.sub_agents = [
        self._branch_agent_tree(a, idx, width) for a in agent.sub_agents
    ]
    return new_agent
229 changes: 229 additions & 0 deletions tests/unittests/agents/test_map_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
import json
import re
from typing import AsyncGenerator

from google.adk.agents import LlmAgent
from google.adk.agents import LoopAgent
from google.adk.agents import MapAgent
from google.adk.agents import ParallelAgent
from google.adk.agents import SequentialAgent
from google.adk.agents.callback_context import CallbackContext
from google.adk.events import Event
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.genai import types
import pytest

from ..testing_utils import MockModel
from ..testing_utils import ModelContent
from ..testing_utils import TestInMemoryRunner


class OneTwoThreeModel(MockModel):
  """Mock model answering a numeric prompt 'i' with "['i', 'i+1', 'i+2']".

  E.g. an input of '5' produces "['5', '6', '7']".
  """

  responses: list[LlmResponse] = []

  async def generate_content_async(
      self, llm_request: LlmRequest, stream: bool = False
  ) -> AsyncGenerator[LlmResponse, None]:
    # Pull the text of the last part of the last content in the request.
    last_content = llm_request.contents[-1] or types.Content()
    last_part = (last_content.parts or [types.Part()])[-1]
    text = last_part.text
    assert text is not None
    # Strip any "[agent] said: " attribution prefix before parsing.
    text = re.sub(r"\[\w+\] said: ", "", text)
    assert text.isnumeric()
    start = int(text)
    payload = json.dumps([str(start + offset) for offset in range(3)])
    yield LlmResponse(content=ModelContent([types.Part(text=payload)]))


def extract_event_text(events: list[Event], agent_prefix: str) -> list[str]:
  """Collects the first-part text of events whose author matches a prefix.

  Events are ordered by (author, text) so parallel outputs compare
  deterministically.
  """

  def first_text(event: Event) -> str:
    # Defensive chain: any of content/parts/text may be missing.
    content = event.content or types.Content()
    parts = content.parts or [types.Part()]
    return parts[0].text or ""

  matching = (e for e in events if e.author.startswith(agent_prefix))
  ordered = sorted(matching, key=lambda e: (e.author, first_text(e)))
  return [first_text(e) for e in ordered]


@pytest.mark.asyncio
async def test_map_agent_empty_input():
  """MapAgent completes without error when the session holds no prompt event."""

  def delete_events(callback_context: CallbackContext) -> None:
    # Simulate a missing prompt list by wiping the session before the run.
    callback_context._invocation_context.session.events.clear()

  # Renamed local from `map` — it shadowed the builtin. The sub-agent's model
  # raises if invoked; with no input prompts it must never be reached.
  map_agent = MapAgent(
      name="map_agent",
      sub_agents=[
          LlmAgent(
              name="test", model=MockModel.create([], error=RuntimeError())
          )
      ],
      before_agent_callback=delete_events,
  )

  runner = TestInMemoryRunner(map_agent)
  await runner.run_async_with_new_session("")


@pytest.mark.asyncio
async def test_map_agent_text_input():
  """Fans 100 prompts out in parallel and checks every output comes back."""
  # Renamed local from `map` — it shadowed the builtin.
  map_agent = MapAgent(
      name="map_agent",
      sub_agents=[LlmAgent(name="mock_agent", model=OneTwoThreeModel())],
  )

  runner = TestInMemoryRunner(map_agent)

  n_runs = 100

  input_data = json.dumps([str(i) for i in range(n_runs)])
  expected_output = [
      json.dumps([str(j) for j in range(i, i + 3)]) for i in range(n_runs)
  ]

  events = await runner.run_async_with_new_session(input_data)
  res = extract_event_text(events, "mock_agent")

  assert res == expected_output


@pytest.mark.asyncio
async def test_map_agent_with_loop_agent_parent():
  """A LoopAgent re-runs a MapAgent, feeding outputs back as the next input."""
  map_agent = MapAgent(
      name="map_agent",
      sub_agents=[LlmAgent(name="mock_agent", model=OneTwoThreeModel())],
  )

  loop_agent = LoopAgent(
      name="test_loop",
      sub_agents=[map_agent],
      max_iterations=2,
  )

  runner = TestInMemoryRunner(loop_agent)

  input_data = json.dumps(["0"])
  # Iteration 1 maps ["0"]; iteration 2 maps its output ["0", "1", "2"].
  first_pass = [json.dumps(["0", "1", "2"])]
  second_pass = [
      json.dumps([str(j) for j in range(i, i + 3)]) for i in range(3)
  ]
  expected_output = first_pass + second_pass

  events = await runner.run_async_with_new_session(input_data)
  res = extract_event_text(events, "mock_agent")
  assert res == expected_output


@pytest.mark.parametrize("SubagentClass", [ParallelAgent, SequentialAgent])
@pytest.mark.asyncio
async def test_map_agent_with_sequential_or_parallel_agent(SubagentClass):
  """test map agent with a parallel / sequential sub-agent whose sub-agents don't communicate"""

  # A lone parallel agent wrapper hides mock_1's output from its 'cousin' mock_2
  mock1 = ParallelAgent(
      name="seq_1",
      sub_agents=[LlmAgent(name="mock_1", model=OneTwoThreeModel())],
  )
  mock2 = LlmAgent(name="mock_2", model=OneTwoThreeModel())

  subagent = SubagentClass(
      name="subagent",
      sub_agents=[mock1, mock2],
  )

  # Renamed local from `map` — it shadowed the builtin.
  map_agent = MapAgent(
      name="map_agent",
      sub_agents=[subagent],
  )

  runner = TestInMemoryRunner(map_agent)

  input_data = json.dumps(["0", "1"])
  # Both mock_1 and mock_2 see each prompt unchanged, so after sorting by
  # (author, text) each of "0" and "1" yields two identical outputs.
  expected_output = [
      json.dumps([str(j) for j in range(i, i + 3)]) for i in [0, 1, 0, 1]
  ]

  events = await runner.run_async_with_new_session(input_data)
  res = extract_event_text(events, "mock_")
  assert res == expected_output


@pytest.mark.asyncio
async def test_map_agent_with_map_agent():
  """A MapAgent nested inside another MapAgent maps lists of lists."""
  mock_leaf = LlmAgent(name="nested_mock", model=OneTwoThreeModel())

  inner_map = MapAgent(
      name="inner_map",
      sub_agents=[mock_leaf],
  )

  outer_map = MapAgent(
      name="outer_map",
      sub_agents=[inner_map],
  )

  runner = TestInMemoryRunner(outer_map)

  # Each outer prompt is itself a JSON list consumed by the inner map agent.
  input_data = json.dumps(
      [json.dumps([str(i), str(i + 1)]) for i in [10, 20, 30]]
  )
  expected_output = [
      json.dumps([str(j) for j in range(i, i + 3)])
      for i in [10, 11, 20, 21, 30, 31]
  ]

  events = await runner.run_async_with_new_session(input_data)

  # Use the shared helper instead of duplicating its sort/extract logic
  # inline, keeping this test consistent with its siblings.
  res = extract_event_text(events, "nested_mock")

  assert len(res) == 6
  assert res == expected_output


@pytest.mark.asyncio
async def test_map_agent_tree():
  """An outer MapAgent over a loop that itself contains a nested MapAgent."""
  inner_map = MapAgent(
      name="map_inner",
      sub_agents=[LlmAgent(name="mock_1", model=OneTwoThreeModel())],
  )

  main_loop = LoopAgent(
      name="main_sequential",
      sub_agents=[
          LlmAgent(name="mock_0", model=OneTwoThreeModel()),
          inner_map,
      ],
      max_iterations=1,
  )

  outer_map = MapAgent(
      name="map_outer",
      sub_agents=[main_loop],
  )

  runner = TestInMemoryRunner(outer_map)

  prompts = ["0", "1"]
  starts = [0, 1, 0, 1, 2, 1, 2, 3]
  expected = [json.dumps([str(s), str(s + 1), str(s + 2)]) for s in starts]

  events = await runner.run_async_with_new_session(json.dumps(prompts))
  observed = extract_event_text(events, "mock_")

  assert observed == expected