
Commit 23d330e

wukath authored and copybara-github committed
feat: Include model ID with token usage for live events
This allows users to track token usage data per model and fixes #4084.

Co-authored-by: Kathy Wu <[email protected]>
PiperOrigin-RevId: 853925212
1 parent b8917bc · commit 23d330e
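
With the new `model_version` field on usage-bearing `LlmResponse` events, per-model token accounting becomes straightforward. A minimal consumer sketch, assuming a `GeminiLlmConnection` as patched in this commit; the `track_usage` helper and `usage_by_model` dict are illustrative, not part of the ADK API, and `total_token_count` is assumed from the `google.genai` usage metadata:

    import collections

    async def track_usage(connection):
      # Aggregate total tokens per model across live usage events.
      usage_by_model = collections.defaultdict(int)
      async for response in connection.receive():
        if response.usage_metadata:
          # `model_version` is the field this commit adds to usage events.
          usage_by_model[response.model_version] += (
              response.usage_metadata.total_token_count or 0
          )
      return usage_by_model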

File tree

4 files changed: +42 −18 lines changed


src/google/adk/models/gemini_llm_connection.py

Lines changed: 7 additions & 1 deletion
@@ -41,11 +41,13 @@ def __init__(
       self,
       gemini_session: live.AsyncSession,
       api_backend: GoogleLLMVariant = GoogleLLMVariant.VERTEX_AI,
+      model_version: str | None = None,
   ):
     self._gemini_session = gemini_session
     self._input_transcription_text: str = ''
     self._output_transcription_text: str = ''
     self._api_backend = api_backend
+    self._model_version = model_version
 
   async def send_history(self, history: list[types.Content]):
     """Sends the conversation history to the gemini model.
@@ -162,7 +164,11 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
     async for message in agen:
       logger.debug('Got LLM Live message: %s', message)
       if message.usage_metadata:
-        yield LlmResponse(usage_metadata=message.usage_metadata)
+        # Tracks token usage data per model.
+        yield LlmResponse(
+            usage_metadata=message.usage_metadata,
+            model_version=self._model_version,
+        )
       if message.server_content:
         content = message.server_content.model_turn
         if content and content.parts:
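
With this change the connection can be constructed with the model name directly. A sketch mirroring the test fixtures further down; `live_session` is a placeholder for a session obtained from the live API client, and the model name is assumed:

    # Sketch only: `live_session` would come from the live API client,
    # as wired up in google_llm.py below; 'gemini-2.5-pro' is a placeholder.
    connection = GeminiLlmConnection(
        live_session,
        api_backend=GoogleLLMVariant.VERTEX_AI,
        model_version='gemini-2.5-pro',
    )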

src/google/adk/models/google_llm.py

Lines changed: 5 additions & 1 deletion
@@ -402,7 +402,11 @@ async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
     async with self._live_api_client.aio.live.connect(
         model=llm_request.model, config=llm_request.live_connect_config
     ) as live_session:
-      yield GeminiLlmConnection(live_session, api_backend=self._api_backend)
+      yield GeminiLlmConnection(
+          live_session,
+          api_backend=self._api_backend,
+          model_version=llm_request.model,
+      )
 
   async def _adapt_computer_use_tool(self, llm_request: LlmRequest) -> None:
     """Adapt the google computer use predefined functions to the adk computer use toolset."""

tests/unittests/models/test_gemini_llm_connection.py

Lines changed: 9 additions & 2 deletions
@@ -19,6 +19,8 @@
 from google.genai import types
 import pytest
 
+MODEL_VERSION = 'gemini-2.5-pro'
+
 
 @pytest.fixture
 def mock_gemini_session():
@@ -30,15 +32,19 @@ def mock_gemini_session():
 def gemini_connection(mock_gemini_session):
   """GeminiLlmConnection instance with mocked session."""
   return GeminiLlmConnection(
-      mock_gemini_session, api_backend=GoogleLLMVariant.VERTEX_AI
+      mock_gemini_session,
+      api_backend=GoogleLLMVariant.VERTEX_AI,
+      model_version=MODEL_VERSION,
   )
 
 
 @pytest.fixture
 def gemini_api_connection(mock_gemini_session):
   """GeminiLlmConnection instance with mocked session for Gemini API."""
   return GeminiLlmConnection(
-      mock_gemini_session, api_backend=GoogleLLMVariant.GEMINI_API
+      mock_gemini_session,
+      api_backend=GoogleLLMVariant.GEMINI_API,
+      model_version=MODEL_VERSION,
   )
 
 
@@ -215,6 +221,7 @@ async def mock_receive_generator():
 
   usage_response = next((r for r in responses if r.usage_metadata), None)
   assert usage_response is not None
+  assert usage_response.model_version == MODEL_VERSION
   content_response = next((r for r in responses if r.content), None)
   assert content_response is not None
 

tests/unittests/models/test_google_llm.py

Lines changed: 21 additions & 14 deletions
@@ -705,20 +705,27 @@ async def __aexit__(self, *args):
 
   mock_live_client.aio.live.connect.return_value = MockLiveConnect()
 
-  async with gemini_llm.connect(llm_request) as connection:
-    # Verify that the connect method was called with the right config
-    mock_live_client.aio.live.connect.assert_called_once()
-    call_args = mock_live_client.aio.live.connect.call_args
-    config_arg = call_args.kwargs["config"]
-
-    # Verify that http_options remains None since no custom headers were provided
-    assert config_arg.http_options is None
-
-    # Verify that system instruction and tools were still set
-    assert config_arg.system_instruction is not None
-    assert config_arg.tools == llm_request.config.tools
-
-    assert isinstance(connection, GeminiLlmConnection)
+  with mock.patch(
+      "google.adk.models.google_llm.GeminiLlmConnection"
+  ) as MockGeminiLlmConnection:
+    async with gemini_llm.connect(llm_request) as connection:
+      # Verify that the connect method was called with the right config
+      mock_live_client.aio.live.connect.assert_called_once()
+      call_args = mock_live_client.aio.live.connect.call_args
+      config_arg = call_args.kwargs["config"]
+
+      # Verify that http_options remains None since no custom headers were provided
+      assert config_arg.http_options is None
+
+      # Verify that system instruction and tools were still set
+      assert config_arg.system_instruction is not None
+      assert config_arg.tools == llm_request.config.tools
+
+      MockGeminiLlmConnection.assert_called_once_with(
+          mock_live_session,
+          api_backend=gemini_llm._api_backend,
+          model_version=llm_request.model,
+      )
 
 
 @pytest.mark.parametrize(
