Skip to content

Commit b287215

Browse files
GWeale and copybara-github
authored and committed
feat: make LlmAgent.model optional with a default fallback
LlmAgent now resolves the model from ancestors or a system default (gemini-2.5-flash) when unset. Added LlmAgent.set_default_model() to override the default globally. Co-authored-by: George Weale <[email protected]> PiperOrigin-RevId: 853006116
1 parent 742c926 commit b287215

File tree

10 files changed

+73
-28
lines changed

10 files changed

+73
-28
lines changed

src/google/adk/agents/config_schemas/AgentConfig.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2461,7 +2461,7 @@
24612461
}
24622462
],
24632463
"default": null,
2464-
"description": "Optional. LlmAgent.model. If not set, the model will be inherited from the ancestor.",
2464+
"description": "Optional. LlmAgent.model. Provide a model name string (e.g. \"gemini-2.0-flash\"). If not set, the model will be inherited from the ancestor or fall back to the system default (gemini-2.5-flash unless overridden via LlmAgent.set_default_model). To construct a model instance from code, use model_code.",
24652465
"title": "Model"
24662466
},
24672467
"instruction": {
@@ -4601,4 +4601,4 @@
46014601
}
46024602
],
46034603
"title": "AgentConfig"
4604-
}
4604+
}

src/google/adk/agents/llm_agent.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,18 @@ async def _convert_tool_union_to_tools(
183183
class LlmAgent(BaseAgent):
184184
"""LLM-based Agent."""
185185

186+
DEFAULT_MODEL: ClassVar[str] = 'gemini-2.5-flash'
187+
"""System default model used when no model is set on an agent."""
188+
189+
_default_model: ClassVar[Union[str, BaseLlm]] = DEFAULT_MODEL
190+
"""Current default model used when an agent has no model set."""
191+
186192
model: Union[str, BaseLlm] = ''
187193
"""The model to use for the agent.
188194
189-
When not set, the agent will inherit the model from its ancestor.
195+
When not set, the agent will inherit the model from its ancestor. If no
196+
ancestor provides a model, the agent uses the default model configured via
197+
LlmAgent.set_default_model. The built-in default is gemini-2.5-flash.
190198
"""
191199

192200
config_type: ClassVar[Type[BaseAgentConfig]] = LlmAgentConfig
@@ -503,7 +511,24 @@ def canonical_model(self) -> BaseLlm:
503511
if isinstance(ancestor_agent, LlmAgent):
504512
return ancestor_agent.canonical_model
505513
ancestor_agent = ancestor_agent.parent_agent
506-
raise ValueError(f'No model found for {self.name}.')
514+
return self._resolve_default_model()
515+
516+
@classmethod
517+
def set_default_model(cls, model: Union[str, BaseLlm]) -> None:
518+
"""Overrides the default model used when an agent has no model set."""
519+
if not isinstance(model, (str, BaseLlm)):
520+
raise TypeError('Default model must be a model name or BaseLlm.')
521+
if isinstance(model, str) and not model:
522+
raise ValueError('Default model must be a non-empty string.')
523+
cls._default_model = model
524+
525+
@classmethod
526+
def _resolve_default_model(cls) -> BaseLlm:
527+
"""Resolves the current default model to a BaseLlm instance."""
528+
default_model = cls._default_model
529+
if isinstance(default_model, BaseLlm):
530+
return default_model
531+
return LLMRegistry.new_llm(default_model)
507532

508533
async def canonical_instruction(
509534
self, ctx: ReadonlyContext
@@ -575,10 +600,11 @@ async def canonical_tools(
575600
# because the built-in tools cannot be used together with other tools.
576601
# TODO(b/448114567): Remove once the workaround is no longer needed.
577602
multiple_tools = len(self.tools) > 1
603+
model = self.canonical_model
578604
for tool_union in self.tools:
579605
resolved_tools.extend(
580606
await _convert_tool_union_to_tools(
581-
tool_union, ctx, self.model, multiple_tools
607+
tool_union, ctx, model, multiple_tools
582608
)
583609
)
584610
return resolved_tools

src/google/adk/agents/llm_agent_config.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,9 @@ class LlmAgentConfig(BaseAgentConfig):
5656
description=(
5757
'Optional. LlmAgent.model. Provide a model name string (e.g.'
5858
' "gemini-2.0-flash"). If not set, the model will be inherited from'
59-
' the ancestor. To construct a model instance from code, use'
60-
' model_code.'
59+
' the ancestor or fall back to the system default (gemini-2.5-flash'
60+
' unless overridden via LlmAgent.set_default_model). To construct a'
61+
' model instance from code, use model_code.'
6162
),
6263
)
6364

src/google/adk/flows/llm_flows/_output_schema_processor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ async def run_async(
4545
if (
4646
not agent.output_schema
4747
or not agent.tools
48-
or can_use_output_schema_with_tools(agent.model)
48+
or can_use_output_schema_with_tools(agent.canonical_model)
4949
):
5050
return
5151

src/google/adk/flows/llm_flows/base_llm_flow.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,11 @@ async def _preprocess_async(
476476
# We may need to wrap some built-in tools if there are other tools
477477
# because the built-in tools cannot be used together with other tools.
478478
# TODO(b/448114567): Remove once the workaround is no longer needed.
479+
if not agent.tools:
480+
return
481+
479482
multiple_tools = len(agent.tools) > 1
483+
model = agent.canonical_model
480484
for tool_union in agent.tools:
481485
tool_context = ToolContext(invocation_context)
482486

@@ -492,7 +496,7 @@ async def _preprocess_async(
492496
tools = await _convert_tool_union_to_tools(
493497
tool_union,
494498
ReadonlyContext(invocation_context),
495-
agent.model,
499+
model,
496500
multiple_tools,
497501
)
498502
for tool in tools:

src/google/adk/flows/llm_flows/basic.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,9 @@ class _BasicLlmRequestProcessor(BaseLlmRequestProcessor):
3535
async def run_async(
3636
self, invocation_context: InvocationContext, llm_request: LlmRequest
3737
) -> AsyncGenerator[Event, None]:
38-
from ...agents.llm_agent import LlmAgent
39-
4038
agent = invocation_context.agent
41-
42-
llm_request.model = (
43-
agent.canonical_model
44-
if isinstance(agent.canonical_model, str)
45-
else agent.canonical_model.model
46-
)
39+
model = agent.canonical_model
40+
llm_request.model = model if isinstance(model, str) else model.model
4741
llm_request.config = (
4842
agent.generate_content_config.model_copy(deep=True)
4943
if agent.generate_content_config
@@ -54,7 +48,7 @@ async def run_async(
5448
# both output_schema and tools at the same time. see
5549
# _output_schema_processor.py for details
5650
if agent.output_schema:
57-
if not agent.tools or can_use_output_schema_with_tools(agent.model):
51+
if not agent.tools or can_use_output_schema_with_tools(model):
5852
llm_request.set_output_schema(agent.output_schema)
5953

6054
llm_request.live_connect_config.response_modalities = (

src/google/adk/flows/llm_flows/interactions_processor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,10 @@ async def run_async(
5353
# Only process if using Gemini with interactions API
5454
if not isinstance(agent, LlmAgent):
5555
return
56-
if not isinstance(agent.model, Gemini):
56+
model = agent.canonical_model
57+
if not isinstance(model, Gemini):
5758
return
58-
if not agent.model.use_interactions_api:
59+
if not model.use_interactions_api:
5960
return
6061
# Extract previous interaction ID from session events
6162
previous_interaction_id = self._find_previous_interaction_id(

tests/unittests/agents/test_llm_agent_fields.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,24 @@ async def _create_readonly_context(
5252
return ReadonlyContext(invocation_context)
5353

5454

55-
def test_canonical_model_empty():
56-
agent = LlmAgent(name='test_agent')
57-
58-
with pytest.raises(ValueError):
59-
_ = agent.canonical_model
55+
@pytest.mark.parametrize(
56+
('default_model', 'expected_model_name', 'expected_model_type'),
57+
[
58+
(LlmAgent.DEFAULT_MODEL, LlmAgent.DEFAULT_MODEL, Gemini),
59+
('gemini-2.0-flash', 'gemini-2.0-flash', Gemini),
60+
],
61+
)
62+
def test_canonical_model_default_fallback(
63+
default_model, expected_model_name, expected_model_type
64+
):
65+
original_default = LlmAgent._default_model
66+
LlmAgent.set_default_model(default_model)
67+
try:
68+
agent = LlmAgent(name='test_agent')
69+
assert isinstance(agent.canonical_model, expected_model_type)
70+
assert agent.canonical_model.model == expected_model_name
71+
finally:
72+
LlmAgent.set_default_model(original_default)
6073

6174

6275
def test_canonical_model_str():

tests/unittests/flows/llm_flows/test_basic_processor.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,9 @@ async def test_skips_output_schema_when_tools_present(self, mocker):
110110
assert llm_request.config.response_mime_type != 'application/json'
111111

112112
# Should have checked if output schema can be used with tools
113-
can_use_output_schema_with_tools.assert_called_once_with(agent.model)
113+
can_use_output_schema_with_tools.assert_called_once_with(
114+
agent.canonical_model
115+
)
114116

115117
@pytest.mark.asyncio
116118
async def test_sets_output_schema_when_tools_present(self, mocker):
@@ -141,7 +143,9 @@ async def test_sets_output_schema_when_tools_present(self, mocker):
141143
assert llm_request.config.response_mime_type == 'application/json'
142144

143145
# Should have checked if output schema can be used with tools
144-
can_use_output_schema_with_tools.assert_called_once_with(agent.model)
146+
can_use_output_schema_with_tools.assert_called_once_with(
147+
agent.canonical_model
148+
)
145149

146150
@pytest.mark.asyncio
147151
async def test_no_output_schema_no_tools(self):

tests/unittests/flows/llm_flows/test_output_schema_processor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,9 @@ async def test_output_schema_request_processor(
191191
assert not llm_request.config.system_instruction
192192

193193
# Should have checked if output schema can be used with tools
194-
can_use_output_schema_with_tools.assert_called_once_with(agent.model)
194+
can_use_output_schema_with_tools.assert_called_once_with(
195+
agent.canonical_model
196+
)
195197

196198

197199
@pytest.mark.asyncio

0 commit comments

Comments
 (0)