Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added rest_old.py
Empty file.
21 changes: 17 additions & 4 deletions src/a2a/client/errors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Custom exceptions for the A2A client."""

from collections.abc import Mapping

from a2a.types import JSONRPCErrorResponse


Expand All @@ -10,16 +12,27 @@ class A2AClientError(Exception):
class A2AClientHTTPError(A2AClientError):
"""Client exception for HTTP errors received from the server."""

def __init__(self, status_code: int, message: str):
def __init__(
self,
status: int,
message: str,
body: str | None = None,
headers: Mapping[str, str] | None = None,
):
"""Initializes the A2AClientHTTPError.

Args:
status_code: The HTTP status code of the response.
status: The HTTP status code of the response.
message: A descriptive error message.
body: The raw response body, if available.
headers: The HTTP response headers.
"""
self.status_code = status_code
self.status = status
self.status_code = status
self.message = message
super().__init__(f'HTTP Error {status_code}: {message}')
self.body = body
self.headers = dict(headers or {})
super().__init__(f'HTTP {status} - {message}')
Comment on lines +15 to +35
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

low

For consistency and to avoid redundancy, it would be clearer to use a single attribute for the status code. The httpx library uses status_code, which is a common convention. Using both status and status_code for the same value can be confusing. I suggest standardizing on status_code.

Suggested change
def __init__(
self,
status: int,
message: str,
body: str | None = None,
headers: Mapping[str, str] | None = None,
):
"""Initializes the A2AClientHTTPError.
Args:
status_code: The HTTP status code of the response.
status: The HTTP status code of the response.
message: A descriptive error message.
body: The raw response body, if available.
headers: The HTTP response headers.
"""
self.status_code = status_code
self.status = status
self.status_code = status
self.message = message
super().__init__(f'HTTP Error {status_code}: {message}')
self.body = body
self.headers = dict(headers or {})
super().__init__(f'HTTP {status} - {message}')
def __init__(
self,
status_code: int,
message: str,
body: str | None = None,
headers: Mapping[str, str] | None = None,
):
"""Initializes the A2AClientHTTPError.
Args:
status_code: The HTTP status code of the response.
message: A descriptive error message.
body: The raw response body, if available.
headers: The HTTP response headers.
"""
self.status_code = status_code
self.message = message
self.body = body
self.headers = dict(headers or {})
super().__init__(f'HTTP {status_code} - {message}')



class A2AClientJSONError(A2AClientError):
Expand Down
116 changes: 116 additions & 0 deletions src/a2a/client/transports/_streaming_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""Shared helpers for handling streaming HTTP responses."""

from __future__ import annotations

import json

from typing import Any

import httpx # noqa: TC002

from httpx_sse import EventSource # noqa: TC002

from a2a.client.errors import A2AClientHTTPError


SUCCESS_STATUS_MIN = 200
SUCCESS_STATUS_MAX = 300


async def ensure_streaming_response(event_source: EventSource) -> None:
"""Validate the initial streaming response before attempting SSE parsing."""
response = event_source.response
if not SUCCESS_STATUS_MIN <= response.status_code < SUCCESS_STATUS_MAX:
error = await _build_http_error(response)
raise error

if not _has_event_stream_content_type(response):
error = await _build_content_type_error(response)
raise error


async def _build_http_error(response: httpx.Response) -> A2AClientHTTPError:
body_text = await _read_body(response)
json_payload: Any | None
try:
json_payload = response.json()
except (json.JSONDecodeError, ValueError):
json_payload = None
Comment on lines +35 to +38
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The response body is already read into the body_text variable. Instead of calling response.json(), which might re-process the response body, it's more direct and potentially more efficient to parse body_text directly. This also makes the flow of data clearer.

Suggested change
try:
json_payload = response.json()
except (json.JSONDecodeError, ValueError):
json_payload = None
try:
json_payload = json.loads(body_text) if body_text else None
except json.JSONDecodeError:
json_payload = None


message = _extract_message(response, json_payload, body_text)
return A2AClientHTTPError(
response.status_code,
message,
body=body_text,
headers=dict(response.headers),
)


async def _build_content_type_error(
response: httpx.Response,
) -> A2AClientHTTPError:
body_text = await _read_body(response)
content_type = response.headers.get('content-type', None)
descriptor = content_type or 'missing'
message = f'Unexpected Content-Type {descriptor!r} for streaming response'
return A2AClientHTTPError(
response.status_code,
message,
body=body_text,
headers=dict(response.headers),
)


async def _read_body(response: httpx.Response) -> str | None:
await response.aread()
text = response.text
return text if text else None


def _extract_message(
response: httpx.Response,
json_payload: Any | None,
body_text: str | None,
) -> str:
message: str | None = None
if isinstance(json_payload, dict):
title = _coerce_str(json_payload.get('title'))
detail = _coerce_str(json_payload.get('detail'))
if title and detail:
message = f'{title}: {detail}'
else:
for key in ('message', 'detail', 'error', 'title'):
value = _coerce_str(json_payload.get(key))
if value:
message = value
break
elif isinstance(json_payload, list):
# Some APIs return a list of error descriptions—prefer the first string entry.
for item in json_payload:
value = _coerce_str(item)
if value:
message = value
break

if not message and body_text:
stripped = body_text.strip()
if stripped:
message = stripped

if not message:
reason = getattr(response, 'reason_phrase', '') or ''
message = reason or 'HTTP error'

return message

Check failure on line 104 in src/a2a/client/transports/_streaming_utils.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Type "str | None" is not assignable to return type "str"   Type "str | None" is not assignable to type "str"     "None" is not assignable to "str" (reportReturnType)


def _coerce_str(value: Any) -> str | None:
if isinstance(value, str):
stripped = value.strip()
return stripped or None
return None


def _has_event_stream_content_type(response: httpx.Response) -> bool:
content_type = response.headers.get('content-type', '')
return 'text/event-stream' in content_type.lower()
22 changes: 15 additions & 7 deletions src/a2a/client/transports/jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
A2AClientTimeoutError,
)
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.client.transports._streaming_utils import ensure_streaming_response
from a2a.client.transports.base import ClientTransport
from a2a.types import (
AgentCard,
Expand Down Expand Up @@ -161,28 +162,35 @@
json=payload,
**modified_kwargs,
) as event_source:
http_response = event_source.response
try:
await ensure_streaming_response(event_source)
async for sse in event_source.aiter_sse():
response = SendStreamingMessageResponse.model_validate(
json.loads(sse.data)
stream_response = (
SendStreamingMessageResponse.model_validate(
json.loads(sse.data)
)
)
if isinstance(response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(response.root)
yield response.root.result
if isinstance(stream_response.root, JSONRPCErrorResponse):
raise A2AClientJSONRPCError(stream_response.root)
yield stream_response.root.result
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
http_response.status_code,
f'Invalid SSE response or protocol error: {e}',
headers=dict(http_response.headers),
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
503,
f'Network communication error: {e}',
) from e

async def _send_request(
self,
rpc_request_payload: dict[str, Any],

Check notice on line 193 in src/a2a/client/transports/jsonrpc.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/rest.py (150-164)
http_kwargs: dict[str, Any] | None = None,
) -> dict[str, Any]:
try:
Expand Down
10 changes: 8 additions & 2 deletions src/a2a/client/transports/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,30 @@
from a2a.client.card_resolver import A2ACardResolver
from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
from a2a.client.transports._streaming_utils import ensure_streaming_response
from a2a.client.transports.base import ClientTransport
from a2a.grpc import a2a_pb2
from a2a.types import (
AgentCard,
GetTaskPushNotificationConfigParams,
Message,
MessageSendParams,
Task,
TaskArtifactUpdateEvent,
TaskIdParams,
TaskPushNotificationConfig,
TaskQueryParams,
TaskStatusUpdateEvent,
)
from a2a.utils import proto_utils
from a2a.utils.telemetry import SpanKind, trace_class


logger = logging.getLogger(__name__)


@trace_class(kind=SpanKind.CLIENT)
class RestTransport(ClientTransport):

Check notice on line 38 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/grpc.py (20-40)
"""A REST transport for the A2A client."""

def __init__(
Expand Down Expand Up @@ -139,23 +140,28 @@
json=payload,
**modified_kwargs,
) as event_source:
http_response = event_source.response
try:
await ensure_streaming_response(event_source)
async for sse in event_source.aiter_sse():
event = a2a_pb2.StreamResponse()
Parse(sse.data, event)
yield proto_utils.FromProto.stream_response(event)
except SSEError as e:
raise A2AClientHTTPError(
400, f'Invalid SSE response or protocol error: {e}'
http_response.status_code,
f'Invalid SSE response or protocol error: {e}',
headers=dict(http_response.headers),
) from e
except json.JSONDecodeError as e:
raise A2AClientJSONError(str(e)) from e
except httpx.RequestError as e:
raise A2AClientHTTPError(
503, f'Network communication error: {e}'
503,
f'Network communication error: {e}',
) from e

async def _send_request(self, request: httpx.Request) -> dict[str, Any]:

Check notice on line 164 in src/a2a/client/transports/rest.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/jsonrpc.py (177-193)
try:
response = await self.httpx_client.send(request)
response.raise_for_status()
Expand Down
16 changes: 9 additions & 7 deletions tests/client/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ def test_instantiation(self):
assert isinstance(error, A2AClientError)
assert error.status_code == 404
assert error.message == 'Not Found'
assert error.body is None
assert error.headers == {}

def test_message_formatting(self):
"""Test that the error message is formatted correctly."""
error = A2AClientHTTPError(500, 'Internal Server Error')
assert str(error) == 'HTTP Error 500: Internal Server Error'
assert str(error) == 'HTTP 500 - Internal Server Error'

def test_inheritance(self):
"""Test that A2AClientHTTPError inherits from A2AClientError."""
Expand All @@ -43,7 +45,7 @@ def test_with_empty_message(self):
error = A2AClientHTTPError(403, '')
assert error.status_code == 403
assert error.message == ''
assert str(error) == 'HTTP Error 403: '
assert str(error) == 'HTTP 403 - '

def test_with_various_status_codes(self):
"""Test with different HTTP status codes."""
Expand All @@ -62,7 +64,7 @@ def test_with_various_status_codes(self):
error = A2AClientHTTPError(status_code, message)
assert error.status_code == status_code
assert error.message == message
assert str(error) == f'HTTP Error {status_code}: {message}'
assert str(error) == f'HTTP {status_code} - {message}'


class TestA2AClientJSONError:
Expand Down Expand Up @@ -148,7 +150,7 @@ def test_raising_http_error(self):
error = excinfo.value
assert error.status_code == 429
assert error.message == 'Too Many Requests'
assert str(error) == 'HTTP Error 429: Too Many Requests'
assert str(error) == 'HTTP 429 - Too Many Requests'

def test_raising_json_error(self):
"""Test raising a JSON error and checking its properties."""
Expand All @@ -173,9 +175,9 @@ def test_raising_base_error(self):
@pytest.mark.parametrize(
'status_code,message,expected',
[
(400, 'Bad Request', 'HTTP Error 400: Bad Request'),
(404, 'Not Found', 'HTTP Error 404: Not Found'),
(500, 'Server Error', 'HTTP Error 500: Server Error'),
(400, 'Bad Request', 'HTTP 400 - Bad Request'),
(404, 'Not Found', 'HTTP 404 - Not Found'),
(500, 'Server Error', 'HTTP 500 - Server Error'),
],
)
def test_http_error_parametrized(status_code, message, expected):
Expand Down
Loading
Loading