Skip to content

Commit 9c96a24

Browse files
committed
feat(runner): add metadata parameter to run(), run_live(), run_debug()
Add metadata support to all run methods for consistency: - run(): sync wrapper, passes metadata to run_async() - run_live(): live mode, passes metadata through invocation context - run_debug(): debug helper, passes metadata to run_async() Also update InvocationContext docstring to reflect all supported entry points.
1 parent a33506f commit 9c96a24

File tree

3 files changed

+136
-1
lines changed

3 files changed

+136
-1
lines changed

src/google/adk/agents/invocation_context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,12 +207,14 @@ class InvocationContext(BaseModel):
207207
"""The cache of canonical tools for this invocation."""
208208

209209
metadata: Optional[dict[str, Any]] = None
210-
"""Per-request metadata passed from Runner.run_async().
210+
"""Per-request metadata passed from Runner entry points.
211211
212212
This field allows passing arbitrary metadata that can be accessed during
213213
the invocation lifecycle, particularly in callbacks like before_model_callback.
214214
Common use cases include passing user_id, trace_id, memory context keys, or
215215
other request-specific context that needs to be available during processing.
216+
217+
Supported entry points: run(), run_async(), run_live(), run_debug().
216218
"""
217219

218220
_invocation_cost_manager: _InvocationCostManager = PrivateAttr(

src/google/adk/runners.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ def run(
346346
session_id: str,
347347
new_message: types.Content,
348348
run_config: Optional[RunConfig] = None,
349+
metadata: Optional[dict[str, Any]] = None,
349350
) -> Generator[Event, None, None]:
350351
"""Runs the agent.
351352
@@ -363,6 +364,7 @@ def run(
363364
session_id: The session ID of the session.
364365
new_message: A new message to append to the session.
365366
run_config: The run config for the agent.
367+
metadata: Optional per-request metadata that will be passed to callbacks.
366368
367369
Yields:
368370
The events generated by the agent.
@@ -378,6 +380,7 @@ async def _invoke_run_async():
378380
session_id=session_id,
379381
new_message=new_message,
380382
run_config=run_config,
383+
metadata=metadata,
381384
)
382385
) as agen:
383386
async for event in agen:
@@ -902,6 +905,7 @@ async def run_live(
902905
live_request_queue: LiveRequestQueue,
903906
run_config: Optional[RunConfig] = None,
904907
session: Optional[Session] = None,
908+
metadata: Optional[dict[str, Any]] = None,
905909
) -> AsyncGenerator[Event, None]:
906910
"""Runs the agent in live mode (experimental feature).
907911
@@ -943,6 +947,7 @@ async def run_live(
943947
run_config: The run config for the agent.
944948
session: The session to use. This parameter is deprecated, please use
945949
`user_id` and `session_id` instead.
950+
metadata: Optional per-request metadata that will be passed to callbacks.
946951
947952
Yields:
948953
AsyncGenerator[Event, None]: An asynchronous generator that yields
@@ -957,6 +962,7 @@ async def run_live(
957962
Either `session` or both `user_id` and `session_id` must be provided.
958963
"""
959964
run_config = run_config or RunConfig()
965+
metadata = metadata.copy() if metadata is not None else None
960966
# Some native audio models requires the modality to be set. So we set it to
961967
# AUDIO by default.
962968
if run_config.response_modalities is None:
@@ -982,6 +988,7 @@ async def run_live(
982988
session,
983989
live_request_queue=live_request_queue,
984990
run_config=run_config,
991+
metadata=metadata,
985992
)
986993

987994
root_agent = self.agent
@@ -1127,6 +1134,7 @@ async def run_debug(
11271134
run_config: RunConfig | None = None,
11281135
quiet: bool = False,
11291136
verbose: bool = False,
1137+
metadata: dict[str, Any] | None = None,
11301138
) -> list[Event]:
11311139
"""Debug helper for quick agent experimentation and testing.
11321140
@@ -1150,6 +1158,7 @@ async def run_debug(
11501158
shown).
11511159
verbose: If True, shows detailed tool calls and responses. Defaults to
11521160
False for cleaner output showing only final agent responses.
1161+
metadata: Optional per-request metadata that will be passed to callbacks.
11531162
11541163
Returns:
11551164
list[Event]: All events from all messages.
@@ -1212,6 +1221,7 @@ async def run_debug(
12121221
session_id=session.id,
12131222
new_message=types.UserContent(parts=[types.Part(text=message)]),
12141223
run_config=run_config,
1224+
metadata=metadata,
12151225
):
12161226
if not quiet:
12171227
print_event(event, verbose=verbose)
@@ -1401,6 +1411,7 @@ def _new_invocation_context_for_live(
14011411
*,
14021412
live_request_queue: Optional[LiveRequestQueue] = None,
14031413
run_config: Optional[RunConfig] = None,
1414+
metadata: Optional[dict[str, Any]] = None,
14041415
) -> InvocationContext:
14051416
"""Creates a new invocation context for live multi-agent."""
14061417
run_config = run_config or RunConfig()
@@ -1427,6 +1438,7 @@ def _new_invocation_context_for_live(
14271438
session,
14281439
live_request_queue=live_request_queue,
14291440
run_config=run_config,
1441+
metadata=metadata,
14301442
)
14311443

14321444
async def _handle_new_message(

tests/unittests/test_runners.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
from google.adk.agents.base_agent import BaseAgent
2424
from google.adk.agents.context_cache_config import ContextCacheConfig
25+
from google.adk.agents.live_request_queue import LiveRequestQueue
2526
from google.adk.agents.invocation_context import InvocationContext
2627
from google.adk.agents.llm_agent import LlmAgent
2728
from google.adk.agents.run_config import RunConfig
@@ -35,6 +36,7 @@
3536
from google.adk.plugins.base_plugin import BasePlugin
3637
from google.adk.runners import Runner
3738
from google.adk.sessions.in_memory_session_service import InMemorySessionService
39+
from tests.unittests import testing_utils
3840
from google.adk.sessions.session import Session
3941
from google.genai import types
4042
import pytest
@@ -1344,6 +1346,125 @@ def before_model_callback(callback_context, llm_request):
13441346
# Nested object changes in callback WILL affect original (shallow copy behavior)
13451347
assert original_metadata["nested"]["inner_key"] == "modified_nested"
13461348

1349+
def test_new_invocation_context_for_live_with_metadata(self):
1350+
"""Test that _new_invocation_context_for_live correctly passes metadata."""
1351+
mock_session = Session(
1352+
id=TEST_SESSION_ID,
1353+
app_name=TEST_APP_ID,
1354+
user_id=TEST_USER_ID,
1355+
events=[],
1356+
)
1357+
1358+
test_metadata = {"user_id": "live_user", "trace_id": "live_trace"}
1359+
invocation_context = self.runner._new_invocation_context_for_live(
1360+
mock_session, metadata=test_metadata
1361+
)
1362+
1363+
assert invocation_context.metadata == test_metadata
1364+
assert invocation_context.metadata["user_id"] == "live_user"
1365+
1366+
@pytest.mark.asyncio
1367+
async def test_run_sync_passes_metadata(self):
1368+
"""Test that sync run() correctly passes metadata to run_async()."""
1369+
captured_metadata = None
1370+
1371+
def before_model_callback(callback_context, llm_request):
1372+
nonlocal captured_metadata
1373+
captured_metadata = llm_request.metadata
1374+
return LlmResponse(
1375+
content=types.Content(
1376+
role="model", parts=[types.Part(text="Test response")]
1377+
)
1378+
)
1379+
1380+
agent_with_callback = LlmAgent(
1381+
name="callback_agent",
1382+
model="gemini-2.0-flash",
1383+
before_model_callback=before_model_callback,
1384+
)
1385+
1386+
runner_with_callback = Runner(
1387+
app_name="test_app",
1388+
agent=agent_with_callback,
1389+
session_service=self.session_service,
1390+
artifact_service=self.artifact_service,
1391+
)
1392+
1393+
await self.session_service.create_session(
1394+
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
1395+
)
1396+
1397+
test_metadata = {"sync_key": "sync_value"}
1398+
1399+
for event in runner_with_callback.run(
1400+
user_id=TEST_USER_ID,
1401+
session_id=TEST_SESSION_ID,
1402+
new_message=types.Content(
1403+
role="user", parts=[types.Part(text="Hello")]
1404+
),
1405+
metadata=test_metadata,
1406+
):
1407+
pass
1408+
1409+
assert captured_metadata is not None
1410+
assert captured_metadata["sync_key"] == "sync_value"
1411+
1412+
@pytest.mark.asyncio
1413+
async def test_run_live_passes_metadata_to_llm_request(self):
1414+
"""Test that run_live() passes metadata through live pipeline to LlmRequest."""
1415+
import asyncio
1416+
1417+
# Create MockModel to capture LlmRequest
1418+
mock_model = testing_utils.MockModel.create(
1419+
responses=[
1420+
LlmResponse(
1421+
content=types.Content(
1422+
role="model", parts=[types.Part(text="Live response")]
1423+
)
1424+
)
1425+
]
1426+
)
1427+
1428+
agent_with_mock = LlmAgent(
1429+
name="live_mock_agent",
1430+
model=mock_model,
1431+
)
1432+
1433+
runner_with_mock = Runner(
1434+
app_name="test_app",
1435+
agent=agent_with_mock,
1436+
session_service=self.session_service,
1437+
artifact_service=self.artifact_service,
1438+
)
1439+
1440+
await self.session_service.create_session(
1441+
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
1442+
)
1443+
1444+
test_metadata = {"live_key": "live_value", "trace_id": "live_trace_123"}
1445+
live_queue = LiveRequestQueue()
1446+
live_queue.close() # Close immediately to end the live session
1447+
1448+
async def consume_events():
1449+
async for event in runner_with_mock.run_live(
1450+
user_id=TEST_USER_ID,
1451+
session_id=TEST_SESSION_ID,
1452+
live_request_queue=live_queue,
1453+
metadata=test_metadata,
1454+
):
1455+
pass
1456+
1457+
try:
1458+
await asyncio.wait_for(consume_events(), timeout=2)
1459+
except asyncio.TimeoutError:
1460+
pass # Expected - live session may not terminate cleanly
1461+
1462+
# Verify MockModel received LlmRequest with correct metadata
1463+
assert len(mock_model.requests) > 0
1464+
assert mock_model.requests[0].metadata is not None
1465+
assert mock_model.requests[0].metadata["live_key"] == "live_value"
1466+
assert mock_model.requests[0].metadata["trace_id"] == "live_trace_123"
1467+
13471468

13481469
if __name__ == "__main__":
13491470
pytest.main([__file__])

0 commit comments

Comments
 (0)