Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/google/adk/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .llm_agent import Agent
from .llm_agent import LlmAgent
from .loop_agent import LoopAgent
from .map_agent import MapAgent
from .mcp_instruction_provider import McpInstructionProvider
from .parallel_agent import ParallelAgent
from .run_config import RunConfig
Expand All @@ -29,6 +30,7 @@
'BaseAgent',
'LlmAgent',
'LoopAgent',
'MapAgent',
'McpInstructionProvider',
'ParallelAgent',
'SequentialAgent',
Expand Down
158 changes: 158 additions & 0 deletions src/google/adk/agents/map_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
from typing import Annotated
from typing import AsyncGenerator

from annotated_types import Len
from google.adk.agents import BaseAgent
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.parallel_agent import _create_branch_ctx_for_sub_agent
from google.adk.agents.parallel_agent import _merge_agent_run
from google.adk.events import Event
from google.adk.flows.llm_flows.contents import _should_include_event_in_context
from google.genai import types
from pydantic import Field
from pydantic import RootModel
from typing_extensions import override


class MapAgent(BaseAgent):
sub_agents: Annotated[list[BaseAgent], Len(1, 1)] = Field(
min_length=1,
max_length=1,
default_factory=list,
description=(
"A single base agent that will be copied and invoked for each prompt"
),
)

@override
async def _run_async_impl(
self, invocation_context: InvocationContext
) -> AsyncGenerator[Event, None]:
"""Core logic of this workflow agent.

Args:
invocation_context: InvocationContext, provides access to the input prompts.

Yields:
Event: the events generated by the sub-agent for each input prompt.
"""

# Create a branch string if it doesn't exist, to ensure parallel invocations don't interfere with each other
prompts, invoker = self._extract_input_prompts(invocation_context)

# for agent naming - e.g. if there are 100-999 prompts, sub-agent copies are named 001, 002, 003 and so on
number_field_width = len(str(len(prompts)))

# Create a separate invocation context for each prompt, each with a numbered copy of the sub-agent.
contexts = [
self._branch_context(
invocation_context,
idx=i,
prompt=prompt,
invoker=invoker,
width=number_field_width,
)
for i, prompt in enumerate(prompts)
]

async for event in _merge_agent_run(
[ctx.agent.run_async(ctx) for ctx in contexts]
):
yield event

def _extract_input_prompts(
self, ctx: InvocationContext
) -> tuple[list[str], str]:
"""
The input to the map agent is a list of strings.
We extract the text content from the latest event, and assume it is a list of strings serialized as a json string.
"""
invoker = "user"

for i in range(len(ctx.session.events) - 1, -1, -1):
event = ctx.session.events[i]
if _should_include_event_in_context(ctx.branch, event):
break
else:
return [], "user"

invoker: str = event.author
input_message: str = (
(event.content or types.Content()).parts or [types.Part()]
)[0].text or ""

# Remove the event which has the prompt list, so that a sub agent does not
# see the prompts of its siblings, which may confuse it.
# The event is removed only for this invocation.
ctx.session.events.pop(i)

agent_input = RootModel[list[str]].model_validate_json(input_message).root

return agent_input, invoker

@staticmethod
def _get_unique_name(name: str, idx: int, width: int) -> str:
"""e.g. my_sub_agent_046"""
return f"{name}_{idx:0{width}d}"

def _branch_context(
self,
ctx: InvocationContext,
*,
prompt: str,
invoker: str,
idx: int,
width: int,
) -> InvocationContext:
"""Creates a numbered copy of the sub-agent that sees a single prompt, and can run separately from its siblings.

Args:
ctx: The current invocation context of the map agent. To be copied and edited for the sub-agent copy.
prompt: the prompt on which the sub-agent copy should be invoked
invoker: the invoker of the map agent in this invocation.
idx: index of the prompt in the input prompts, serves as a unique postfix to the agent name
width: number of digits in the total number of prompts, to ensure naming is consistent in field width
(e.g. 001, 002, ... 010, 011, ... 100, 101; and not 1, 2, ... 10, 11, ... 100, 101)

Returns:
InvocationContext: A new invocation context ready to run with the unique sub-agent copy and the prompt
"""

agent = self._branch_agent_tree(self.sub_agents[0], idx, width)

prompt_part = [types.Part(text=prompt)]

# Add the prompt to the user_content of this branch to easily access agent input in callbacks
user_content = types.Content(
role="user",
parts=((ctx.user_content or types.Content()).parts or []) + prompt_part,
)

new_ctx = _create_branch_ctx_for_sub_agent(self, agent, ctx).model_copy(
update=dict(agent=agent, user_content=user_content)
)

# Add the prompt as a temporary event of this branch in place of the prompt list as the natural input of the sub-agent.
prompt_content = types.Content(
role="user" if invoker == "user" else "model", parts=prompt_part
)
new_ctx.session.events.append(
Event(author=invoker, branch=new_ctx.branch, content=prompt_content)
)

return new_ctx

def _branch_agent_tree(
self, agent: BaseAgent, idx: int, width: int
) -> BaseAgent:
"""
Clone and rename an agent and its sub-tree to create a thread-safe branch.
"""
new_agent = agent.model_copy(
update={"name": self._get_unique_name(agent.name, idx=idx, width=width)}
)

new_agent.sub_agents = [
self._branch_agent_tree(a, idx, width) for a in agent.sub_agents
]
return new_agent
Loading