Add vendor_id and finish_reason to Gemini/Google model responses #1800

Merged
10 commits merged on May 27, 2025
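
For orientation, here is a minimal sketch of what this change exposes to callers. It assumes the public `Agent` API exercised in the tests below, an illustrative `google-gla:` model string, and a configured Gemini API key; the snippet itself is not part of the diff.

```python
from pydantic_ai import Agent
from pydantic_ai.messages import ModelResponse

# Illustrative model string; requires a Gemini API key to actually run.
agent = Agent('google-gla:gemini-1.5-flash')
result = agent.run_sync('Hello World')

for message in result.all_messages():
    if isinstance(message, ModelResponse):
        # New with this PR: the Gemini response id and finish reason are surfaced.
        print(message.vendor_id)       # response id reported by Gemini, or None
        print(message.vendor_details)  # e.g. {'finish_reason': 'STOP'}
```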
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/messages.py
@@ -589,7 +589,7 @@ class ModelResponse:
kind: Literal['response'] = 'response'
"""Message type identifier, this is available on all parts as a discriminator."""

vendor_details: dict[str, Any] | None = field(default=None, repr=False)
vendor_details: dict[str, Any] | None = field(default=None)
"""Additional vendor-specific details in a serializable format.

This allows storing selected vendor-specific data that isn't mapped to standard ModelResponse fields.
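The only change above is dropping `repr=False`, so populated vendor details now show up when a response is printed. A minimal sketch, assuming `ModelResponse` and `TextPart` are constructed directly from `pydantic_ai.messages` as in the tests below:

```python
from pydantic_ai.messages import ModelResponse, TextPart

response = ModelResponse(
    parts=[TextPart(content='Hello world')],
    vendor_details={'finish_reason': 'STOP'},
)
# With repr=False removed, vendor_details is included in the dataclass repr:
print(repr(response))  # ... vendor_details={'finish_reason': 'STOP'} ...
```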
25 changes: 22 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -263,6 +263,8 @@ async def _make_request(
yield r

def _process_response(self, response: _GeminiResponse) -> ModelResponse:
vendor_details: dict[str, Any] | None = None

if len(response['candidates']) != 1:
raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response') # pragma: no cover
if 'content' not in response['candidates'][0]:
@@ -273,9 +275,19 @@ def _process_response(self, response: _GeminiResponse) -> ModelResponse:
'Content field missing from Gemini response', str(response)
)
parts = response['candidates'][0]['content']['parts']
vendor_id = response.get('vendor_id', None)
finish_reason = response['candidates'][0].get('finish_reason')
if finish_reason:
vendor_details = {'finish_reason': finish_reason}
usage = _metadata_as_usage(response)
usage.requests = 1
return _process_response_from_parts(parts, response.get('model_version', self._model_name), usage)
return _process_response_from_parts(
parts,
response.get('model_version', self._model_name),
usage,
vendor_id=vendor_id,
vendor_details=vendor_details,
)

async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
"""Process a streamed response, and prepare a streaming response to return."""
@@ -597,7 +609,11 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart


def _process_response_from_parts(
parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, usage: usage.Usage
parts: Sequence[_GeminiPartUnion],
model_name: GeminiModelName,
usage: usage.Usage,
vendor_id: str | None,
vendor_details: dict[str, Any] | None = None,
) -> ModelResponse:
items: list[ModelResponsePart] = []
for part in parts:
@@ -609,7 +625,9 @@ def _process_response_from_parts(
raise UnexpectedModelBehavior(
f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
)
return ModelResponse(parts=items, usage=usage, model_name=model_name)
return ModelResponse(
parts=items, usage=usage, model_name=model_name, vendor_id=vendor_id, vendor_details=vendor_details
)


class _GeminiFunctionCall(TypedDict):
@@ -721,6 +739,7 @@ class _GeminiResponse(TypedDict):
usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
model_version: NotRequired[Annotated[str, pydantic.Field(alias='modelVersion')]]
vendor_id: NotRequired[Annotated[str, pydantic.Field(alias='responseId')]]


class _GeminiCandidates(TypedDict):
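To make the new plumbing concrete, the sketch below shows a hypothetical raw Gemini payload and the fields the updated `_process_response` is expected to pull out of it. The `responseId` value is illustrative; the key names are the wire-format aliases declared on `_GeminiResponse` and `_GeminiCandidates`.

```python
# Hypothetical raw Gemini API body (wire-format key names).
raw_body = {
    'responseId': 'resp-123',                # parsed via the `responseId` alias -> ModelResponse.vendor_id
    'modelVersion': 'gemini-1.5-flash-123',  # -> ModelResponse.model_name
    'candidates': [
        {
            'content': {'role': 'model', 'parts': [{'text': 'Hello world'}]},
            'finishReason': 'STOP',          # -> ModelResponse.vendor_details == {'finish_reason': 'STOP'}
        }
    ],
}
# When no finishReason is present (see the new test below), vendor_details stays None.
```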
23 changes: 19 additions & 4 deletions pydantic_ai_slim/pydantic_ai/models/google.py
@@ -6,7 +6,7 @@
from contextlib import asynccontextmanager
from dataclasses import dataclass, field, replace
from datetime import datetime
from typing import Literal, Union, cast, overload
from typing import Any, Literal, Union, cast, overload
from uuid import uuid4

from typing_extensions import assert_never
@@ -287,9 +287,16 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
'Content field missing from Gemini response', str(response)
) # pragma: no cover
parts = response.candidates[0].content.parts or []
vendor_id = response.response_id or None
vendor_details: dict[str, Any] | None = None
finish_reason = response.candidates[0].finish_reason
if finish_reason: # pragma: no branch
vendor_details = {'finish_reason': finish_reason.value}
usage = _metadata_as_usage(response)
usage.requests = 1
return _process_response_from_parts(parts, response.model_version or self._model_name, usage)
return _process_response_from_parts(
parts, response.model_version or self._model_name, usage, vendor_id=vendor_id, vendor_details=vendor_details
)

async def _process_streamed_response(self, response: AsyncIterator[GenerateContentResponse]) -> StreamedResponse:
"""Process a streamed response, and prepare a streaming response to return."""
@@ -435,7 +442,13 @@ def _content_model_response(m: ModelResponse) -> ContentDict:
return ContentDict(role='model', parts=parts)


def _process_response_from_parts(parts: list[Part], model_name: GoogleModelName, usage: usage.Usage) -> ModelResponse:
def _process_response_from_parts(
parts: list[Part],
model_name: GoogleModelName,
usage: usage.Usage,
vendor_id: str | None,
vendor_details: dict[str, Any] | None = None,
) -> ModelResponse:
items: list[ModelResponsePart] = []
for part in parts:
if part.text:
@@ -450,7 +463,9 @@ def _process_response_from_parts(parts: list[Part], model_name: GoogleModelName,
raise UnexpectedModelBehavior(
f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
)
return ModelResponse(parts=items, model_name=model_name, usage=usage)
return ModelResponse(
parts=items, model_name=model_name, usage=usage, vendor_id=vendor_id, vendor_details=vendor_details
)


def _function_declaration_from_tool(tool: ToolDefinition) -> FunctionDeclarationDict:
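One practical payoff of surfacing the finish reason on both code paths: callers can detect a truncated completion without reaching into vendor SDK types. A small sketch; the `MAX_TOKENS` value is Gemini's usual truncation reason and is an assumption, not part of this diff.

```python
from pydantic_ai.messages import ModelResponse


def was_truncated(response: ModelResponse) -> bool:
    """Best-effort check: True if Gemini reported stopping at the output token limit."""
    details = response.vendor_details or {}
    return details.get('finish_reason') == 'MAX_TOKENS'
```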
25 changes: 25 additions & 0 deletions tests/models/test_gemini.py
@@ -540,6 +540,7 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
model_name='gemini-1.5-flash-123',
timestamp=IsNow(tz=timezone.utc),
vendor_details={'finish_reason': 'STOP'},
),
]
)
@@ -555,13 +556,15 @@ async def test_text_success(get_gemini_client: GetGeminiClient):
usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
model_name='gemini-1.5-flash-123',
timestamp=IsNow(tz=timezone.utc),
vendor_details={'finish_reason': 'STOP'},
),
ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
ModelResponse(
parts=[TextPart(content='Hello world')],
usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
model_name='gemini-1.5-flash-123',
timestamp=IsNow(tz=timezone.utc),
vendor_details={'finish_reason': 'STOP'},
),
]
)
@@ -585,6 +588,7 @@ async def test_request_structured_response(get_gemini_client: GetGeminiClient):
usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
model_name='gemini-1.5-flash-123',
timestamp=IsNow(tz=timezone.utc),
vendor_details={'finish_reason': 'STOP'},
),
ModelRequest(
parts=[
@@ -647,6 +651,7 @@ async def get_location(loc_name: str) -> str:
usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
model_name='gemini-1.5-flash-123',
timestamp=IsNow(tz=timezone.utc),
vendor_details={'finish_reason': 'STOP'},
),
ModelRequest(
parts=[
@@ -666,6 +671,7 @@ async def get_location(loc_name: str) -> str:
usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
model_name='gemini-1.5-flash-123',
timestamp=IsNow(tz=timezone.utc),
vendor_details={'finish_reason': 'STOP'},
),
ModelRequest(
parts=[
@@ -688,6 +694,7 @@ async def get_location(loc_name: str) -> str:
usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3, details={}),
model_name='gemini-1.5-flash-123',
timestamp=IsNow(tz=timezone.utc),
vendor_details={'finish_reason': 'STOP'},
),
]
)
@@ -1099,6 +1106,7 @@ async def get_image() -> BinaryContent:
usage=Usage(requests=1, request_tokens=38, response_tokens=28, total_tokens=427, details={}),
model_name='gemini-2.5-pro-preview-03-25',
timestamp=IsDatetime(),
vendor_details={'finish_reason': 'STOP'},
),
ModelRequest(
parts=[
@@ -1122,6 +1130,7 @@ async def get_image() -> BinaryContent:
usage=Usage(requests=1, request_tokens=360, response_tokens=11, total_tokens=572, details={}),
model_name='gemini-2.5-pro-preview-03-25',
timestamp=IsDatetime(),
vendor_details={'finish_reason': 'STOP'},
),
]
)
@@ -1232,6 +1241,7 @@ async def test_gemini_model_instructions(allow_model_requests: None, gemini_api_
usage=Usage(requests=1, request_tokens=13, response_tokens=8, total_tokens=21, details={}),
model_name='gemini-1.5-flash',
timestamp=IsDatetime(),
vendor_details={'finish_reason': 'STOP'},
),
]
)
@@ -1272,3 +1282,18 @@ async def get_temperature(location: dict[str, CurrentLocation]) -> float: # pra
assert result.output == snapshot(
'I need a location dictionary to use the `get_temperature` function. I cannot provide the temperature in Tokyo without more information.\n'
)


async def test_gemini_no_finish_reason(get_gemini_client: GetGeminiClient):
response = gemini_response(
_content_model_response(ModelResponse(parts=[TextPart('Hello world')])), finish_reason=None
)
gemini_client = get_gemini_client(response)
m = GeminiModel('gemini-1.5-flash', provider=GoogleGLAProvider(http_client=gemini_client))
agent = Agent(m)

result = await agent.run('Hello World')

for message in result.all_messages():
if isinstance(message, ModelResponse):
assert message.vendor_details is None
6 changes: 6 additions & 0 deletions tests/models/test_google.py
@@ -85,6 +85,7 @@ async def test_google_model(allow_model_requests: None, google_provider: GoogleP
usage=Usage(requests=1, request_tokens=7, response_tokens=11, total_tokens=18, details={}),
model_name='gemini-1.5-flash',
timestamp=IsDatetime(),
vendor_details={'finish_reason': 'STOP'},
),
]
)
@@ -138,6 +139,7 @@ async def temperature(city: str, date: datetime.date) -> str:
usage=Usage(requests=1, request_tokens=101, response_tokens=14, total_tokens=115, details={}),
model_name='gemini-1.5-flash',
timestamp=IsDatetime(),
vendor_details={'finish_reason': 'STOP'},
),
ModelRequest(
parts=[
@@ -157,6 +159,7 @@ async def temperature(city: str, date: datetime.date) -> str:
usage=Usage(requests=1, request_tokens=123, response_tokens=21, total_tokens=144, details={}),
model_name='gemini-1.5-flash',
timestamp=IsDatetime(),
vendor_details={'finish_reason': 'STOP'},
),
ModelRequest(
parts=[
@@ -215,6 +218,7 @@ async def get_capital(country: str) -> str:
),
model_name='models/gemini-2.5-pro-preview-05-06',
timestamp=IsDatetime(),
vendor_details={'finish_reason': 'STOP'},
),
ModelRequest(
parts=[
@@ -235,6 +239,7 @@ async def get_capital(country: str) -> str:
usage=Usage(requests=1, request_tokens=104, response_tokens=18, total_tokens=122, details={}),
model_name='models/gemini-2.5-pro-preview-05-06',
timestamp=IsDatetime(),
vendor_details={'finish_reason': 'STOP'},
),
]
)
Expand Down Expand Up @@ -469,6 +474,7 @@ def instructions() -> str:
usage=Usage(requests=1, request_tokens=13, response_tokens=8, total_tokens=21, details={}),
model_name='gemini-2.0-flash',
timestamp=IsDatetime(),
vendor_details={'finish_reason': 'STOP'},
),
]
)