diff --git a/rest_old.py b/rest_old.py new file mode 100644 index 00000000..e69de29b diff --git a/src/a2a/client/errors.py b/src/a2a/client/errors.py index 890c3726..f571f224 100644 --- a/src/a2a/client/errors.py +++ b/src/a2a/client/errors.py @@ -1,5 +1,7 @@ """Custom exceptions for the A2A client.""" +from collections.abc import Mapping + from a2a.types import JSONRPCErrorResponse @@ -10,16 +12,27 @@ class A2AClientError(Exception): class A2AClientHTTPError(A2AClientError): """Client exception for HTTP errors received from the server.""" - def __init__(self, status_code: int, message: str): + def __init__( + self, + status: int, + message: str, + body: str | None = None, + headers: Mapping[str, str] | None = None, + ): """Initializes the A2AClientHTTPError. Args: - status_code: The HTTP status code of the response. + status: The HTTP status code of the response. message: A descriptive error message. + body: The raw response body, if available. + headers: The HTTP response headers. """ - self.status_code = status_code + self.status = status + self.status_code = status self.message = message - super().__init__(f'HTTP Error {status_code}: {message}') + self.body = body + self.headers = dict(headers or {}) + super().__init__(f'HTTP {status} - {message}') class A2AClientJSONError(A2AClientError): diff --git a/src/a2a/client/transports/_streaming_utils.py b/src/a2a/client/transports/_streaming_utils.py new file mode 100644 index 00000000..b5b1cda6 --- /dev/null +++ b/src/a2a/client/transports/_streaming_utils.py @@ -0,0 +1,116 @@ +"""Shared helpers for handling streaming HTTP responses.""" + +from __future__ import annotations + +import json + +from typing import Any + +import httpx # noqa: TC002 + +from httpx_sse import EventSource # noqa: TC002 + +from a2a.client.errors import A2AClientHTTPError + + +SUCCESS_STATUS_MIN = 200 +SUCCESS_STATUS_MAX = 300 + + +async def ensure_streaming_response(event_source: EventSource) -> None: + """Validate the initial streaming response before attempting SSE parsing.""" + response = event_source.response + if not SUCCESS_STATUS_MIN <= response.status_code < SUCCESS_STATUS_MAX: + error = await _build_http_error(response) + raise error + + if not _has_event_stream_content_type(response): + error = await _build_content_type_error(response) + raise error + + +async def _build_http_error(response: httpx.Response) -> A2AClientHTTPError: + body_text = await _read_body(response) + json_payload: Any | None + try: + json_payload = response.json() + except (json.JSONDecodeError, ValueError): + json_payload = None + + message = _extract_message(response, json_payload, body_text) + return A2AClientHTTPError( + response.status_code, + message, + body=body_text, + headers=dict(response.headers), + ) + + +async def _build_content_type_error( + response: httpx.Response, +) -> A2AClientHTTPError: + body_text = await _read_body(response) + content_type = response.headers.get('content-type', None) + descriptor = content_type or 'missing' + message = f'Unexpected Content-Type {descriptor!r} for streaming response' + return A2AClientHTTPError( + response.status_code, + message, + body=body_text, + headers=dict(response.headers), + ) + + +async def _read_body(response: httpx.Response) -> str | None: + await response.aread() + text = response.text + return text if text else None + + +def _extract_message( + response: httpx.Response, + json_payload: Any | None, + body_text: str | None, +) -> str: + message: str | None = None + if isinstance(json_payload, dict): + title = _coerce_str(json_payload.get('title')) + detail = _coerce_str(json_payload.get('detail')) + if title and detail: + message = f'{title}: {detail}' + else: + for key in ('message', 'detail', 'error', 'title'): + value = _coerce_str(json_payload.get(key)) + if value: + message = value + break + elif isinstance(json_payload, list): + # Some APIs return a list of error descriptions—prefer the first string entry. + for item in json_payload: + value = _coerce_str(item) + if value: + message = value + break + + if not message and body_text: + stripped = body_text.strip() + if stripped: + message = stripped + + if not message: + reason = getattr(response, 'reason_phrase', '') or '' + message = reason or 'HTTP error' + + return message + + +def _coerce_str(value: Any) -> str | None: + if isinstance(value, str): + stripped = value.strip() + return stripped or None + return None + + +def _has_event_stream_content_type(response: httpx.Response) -> bool: + content_type = response.headers.get('content-type', '') + return 'text/event-stream' in content_type.lower() diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index bfba09d7..16973d6e 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -17,6 +17,7 @@ A2AClientTimeoutError, ) from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.transports._streaming_utils import ensure_streaming_response from a2a.client.transports.base import ClientTransport from a2a.types import ( AgentCard, @@ -161,23 +162,30 @@ async def send_message_streaming( json=payload, **modified_kwargs, ) as event_source: + http_response = event_source.response try: + await ensure_streaming_response(event_source) async for sse in event_source.aiter_sse(): - response = SendStreamingMessageResponse.model_validate( - json.loads(sse.data) + stream_response = ( + SendStreamingMessageResponse.model_validate( + json.loads(sse.data) + ) ) - if isinstance(response.root, JSONRPCErrorResponse): - raise A2AClientJSONRPCError(response.root) - yield response.root.result + if isinstance(stream_response.root, JSONRPCErrorResponse): + raise A2AClientJSONRPCError(stream_response.root) + yield stream_response.root.result except SSEError as e: raise A2AClientHTTPError( - 400, f'Invalid SSE response or protocol error: {e}' + http_response.status_code, + f'Invalid SSE response or protocol error: {e}', + headers=dict(http_response.headers), ) from e except json.JSONDecodeError as e: raise A2AClientJSONError(str(e)) from e except httpx.RequestError as e: raise A2AClientHTTPError( - 503, f'Network communication error: {e}' + 503, + f'Network communication error: {e}', ) from e async def _send_request( diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index eef7b0f2..a991fdcf 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -12,6 +12,7 @@ from a2a.client.card_resolver import A2ACardResolver from a2a.client.errors import A2AClientHTTPError, A2AClientJSONError from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.transports._streaming_utils import ensure_streaming_response from a2a.client.transports.base import ClientTransport from a2a.grpc import a2a_pb2 from a2a.types import ( @@ -139,20 +140,25 @@ async def send_message_streaming( json=payload, **modified_kwargs, ) as event_source: + http_response = event_source.response try: + await ensure_streaming_response(event_source) async for sse in event_source.aiter_sse(): event = a2a_pb2.StreamResponse() Parse(sse.data, event) yield proto_utils.FromProto.stream_response(event) except SSEError as e: raise A2AClientHTTPError( - 400, f'Invalid SSE response or protocol error: {e}' + http_response.status_code, + f'Invalid SSE response or protocol error: {e}', + headers=dict(http_response.headers), ) from e except json.JSONDecodeError as e: raise A2AClientJSONError(str(e)) from e except httpx.RequestError as e: raise A2AClientHTTPError( - 503, f'Network communication error: {e}' + 503, + f'Network communication error: {e}', ) from e async def _send_request(self, request: httpx.Request) -> dict[str, Any]: diff --git a/tests/client/test_errors.py b/tests/client/test_errors.py index 30c4468d..0ab85535 100644 --- a/tests/client/test_errors.py +++ b/tests/client/test_errors.py @@ -27,11 +27,13 @@ def test_instantiation(self): assert isinstance(error, A2AClientError) assert error.status_code == 404 assert error.message == 'Not Found' + assert error.body is None + assert error.headers == {} def test_message_formatting(self): """Test that the error message is formatted correctly.""" error = A2AClientHTTPError(500, 'Internal Server Error') - assert str(error) == 'HTTP Error 500: Internal Server Error' + assert str(error) == 'HTTP 500 - Internal Server Error' def test_inheritance(self): """Test that A2AClientHTTPError inherits from A2AClientError.""" @@ -43,7 +45,7 @@ def test_with_empty_message(self): error = A2AClientHTTPError(403, '') assert error.status_code == 403 assert error.message == '' - assert str(error) == 'HTTP Error 403: ' + assert str(error) == 'HTTP 403 - ' def test_with_various_status_codes(self): """Test with different HTTP status codes.""" @@ -62,7 +64,7 @@ def test_with_various_status_codes(self): error = A2AClientHTTPError(status_code, message) assert error.status_code == status_code assert error.message == message - assert str(error) == f'HTTP Error {status_code}: {message}' + assert str(error) == f'HTTP {status_code} - {message}' class TestA2AClientJSONError: @@ -148,7 +150,7 @@ def test_raising_http_error(self): error = excinfo.value assert error.status_code == 429 assert error.message == 'Too Many Requests' - assert str(error) == 'HTTP Error 429: Too Many Requests' + assert str(error) == 'HTTP 429 - Too Many Requests' def test_raising_json_error(self): """Test raising a JSON error and checking its properties.""" @@ -173,9 +175,9 @@ def test_raising_base_error(self): @pytest.mark.parametrize( 'status_code,message,expected', [ - (400, 'Bad Request', 'HTTP Error 400: Bad Request'), - (404, 'Not Found', 'HTTP Error 404: Not Found'), - (500, 'Server Error', 'HTTP Error 500: Server Error'), + (400, 'Bad Request', 'HTTP 400 - Bad Request'), + (404, 'Not Found', 'HTTP 404 - Not Found'), + (500, 'Server Error', 'HTTP 500 - Server Error'), ], ) def test_http_error_parametrized(status_code, message, expected): diff --git a/tests/client/test_streaming_http_errors.py b/tests/client/test_streaming_http_errors.py new file mode 100644 index 00000000..64e95453 --- /dev/null +++ b/tests/client/test_streaming_http_errors.py @@ -0,0 +1,184 @@ +"""Tests for a2a.client.errors module.""" + +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass +from typing import Any +from unittest.mock import AsyncMock, patch + +import httpx +import pytest + +from a2a.client import create_text_message_object +from a2a.client.errors import A2AClientHTTPError +from a2a.client.transports.rest import RestTransport +from a2a.types import MessageSendParams + + +@dataclass +class DummyServerSentEvent: + data: str + + +class MockEventSource: + def __init__( + self, + response: httpx.Response, + events: list[Any] | None = None, + error: Exception | None = None, + ): + self.response = response + self._events = events or [] + self._error = error + + async def __aenter__(self) -> 'MockEventSource': + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + return None + + async def aiter_sse(self) -> AsyncIterator[Any]: + if self._error: + raise self._error + for event in self._events: + yield event + + +def make_response( + status: int, + *, + json_body: dict[str, Any] | None = None, + text_body: str | None = None, + headers: dict[str, str] | None = None, +) -> httpx.Response: + request = httpx.Request('POST', 'https://api.example.com/v1/message:stream') + if json_body is not None: + response = httpx.Response( + status, + json=json_body, + headers=headers, + request=request, + ) + else: + response = httpx.Response( + status, + content=text_body.encode() if text_body else b'', + headers=headers, + request=request, + ) + return response + + +def make_transport() -> RestTransport: + httpx_client = AsyncMock(spec=httpx.AsyncClient) + transport = RestTransport( + httpx_client=httpx_client, url='https://api.example.com' + ) + transport._prepare_send_message = AsyncMock(return_value=({}, {})) + return transport + + +async def collect_stream(transport: RestTransport, params: MessageSendParams): + return [item async for item in transport.send_message_streaming(params)] + + +def patch_stream_context(event_source: MockEventSource): + @asynccontextmanager + async def fake_aconnect_sse(*_: Any, **__: Any): + yield event_source + + return patch( + 'a2a.client.transports.rest.aconnect_sse', new=fake_aconnect_sse + ) + + +@pytest.mark.parametrize( + ('status', 'body', 'expected'), + [ + (401, {'error': 'invalid_token'}, 'invalid_token'), + (500, {'message': 'DB down'}, 'DB down'), + (503, {'detail': 'Service unavailable'}, 'Service unavailable'), + ( + 404, + {'title': 'Not Found', 'detail': 'No such task'}, + 'Not Found: No such task', + ), + ], +) +@pytest.mark.asyncio +async def test_streaming_surfaces_http_errors( + status: int, body: dict[str, Any], expected: str +): + transport = make_transport() + params = MessageSendParams( + message=create_text_message_object(content='Hello') + ) + response = make_response(status, json_body=body) + event_source = MockEventSource(response) + + with patch_stream_context(event_source), pytest.raises( + A2AClientHTTPError + ) as exc_info: + await collect_stream(transport, params) + + error = exc_info.value + assert error.status == status + assert expected in error.message + assert error.body is not None + assert str(status) in str(error) + + +@pytest.mark.asyncio +async def test_streaming_rejects_wrong_content_type(): + transport = make_transport() + params = MessageSendParams( + message=create_text_message_object(content='Hello') + ) + response = make_response( + 200, + json_body={'message': 'not a stream'}, + headers={'content-type': 'application/json'}, + ) + event_source = MockEventSource(response) + + with patch_stream_context(event_source), pytest.raises( + A2AClientHTTPError + ) as exc_info: + await collect_stream(transport, params) + + error = exc_info.value + assert error.status == 200 + assert 'Unexpected Content-Type' in error.message + assert 'application/json' in error.message + + +@pytest.mark.asyncio +async def test_streaming_success_path_unchanged(): + transport = make_transport() + params = MessageSendParams( + message=create_text_message_object(content='Hello') + ) + response = make_response( + 200, + text_body='event-stream', + headers={'content-type': 'text/event-stream'}, + ) + events = [DummyServerSentEvent(data='{"foo":"bar"}')] + event_source = MockEventSource(response, events=events) + + with ( + patch_stream_context(event_source), + patch( + 'a2a.client.transports.rest.Parse', + side_effect=lambda data, obj: obj, + ) as mock_parse, + patch( + 'a2a.client.transports.rest.proto_utils.FromProto.stream_response', + return_value={'result': 'ok'}, + ) as mock_from_proto, + ): + results = await collect_stream(transport, params) + + assert results == [{'result': 'ok'}] + mock_parse.assert_called() + mock_from_proto.assert_called() diff --git a/tmp_old_jsonrpc.py b/tmp_old_jsonrpc.py new file mode 100644 index 00000000..e69de29b