Skip to content

Commit 8f4425e

Browse files
dtmeadowsclaudeRobertCraigie
authored
add test to make sure runner overloads are in sync (#1414)
* add ability to test overloads are in sync on tool runner * formt * fix * use typing_extensions.get_overloads() for broader python version support Co-Authored-By: Claude Code (${CLAUDE_PROJECT_DIR}) <noreply@anthropic.com> * fix import ordering for ruff Co-Authored-By: Claude Code (${CLAUDE_PROJECT_DIR}) <noreply@anthropic.com> * Update tests/lib/streaming/test_beta_messages.py Co-authored-by: Robert Craigie <robert@craigie.dev> * format * format again --------- Co-authored-by: Claude Code (${CLAUDE_PROJECT_DIR}) <noreply@anthropic.com> Co-authored-by: Robert Craigie <robert@craigie.dev>
1 parent e44886c commit 8f4425e

File tree

6 files changed

+70
-62
lines changed

6 files changed

+70
-62
lines changed

src/anthropic/_utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
)
6161
from ._reflection import (
6262
function_has_argument as function_has_argument,
63+
assert_overloads_in_sync as assert_overloads_in_sync,
6364
assert_signatures_in_sync as assert_signatures_in_sync,
6465
)
6566
from ._datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime

src/anthropic/_utils/_reflection.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import inspect
4+
import typing_extensions
45
from typing import Any, Callable
56

67

@@ -40,3 +41,40 @@ def assert_signatures_in_sync(
4041

4142
if errors:
4243
raise AssertionError(f"{len(errors)} errors encountered when comparing signatures:\n\n" + "\n\n".join(errors))
44+
45+
46+
def assert_overloads_in_sync(
47+
source_func: Callable[..., Any],
48+
overloaded_func: Callable[..., Any],
49+
*,
50+
exclude_params: set[str] = set(),
51+
) -> None:
52+
"""Ensure that every @overload of overloaded_func contains all params from source_func."""
53+
source_sig = inspect.signature(source_func)
54+
overloads = typing_extensions.get_overloads(overloaded_func)
55+
56+
if not overloads:
57+
raise AssertionError(f"No @overload definitions found for {overloaded_func!r}")
58+
59+
errors: list[str] = []
60+
61+
for i, overload_fn in enumerate(overloads):
62+
overload_sig = inspect.signature(overload_fn)
63+
for name, source_param in source_sig.parameters.items():
64+
if name in exclude_params:
65+
continue
66+
67+
overload_param = overload_sig.parameters.get(name)
68+
if not overload_param:
69+
errors.append(f"overload {i}: `{name}` param is missing")
70+
continue
71+
72+
if overload_param.annotation != source_param.annotation:
73+
errors.append(
74+
f"overload {i}: types for `{name}` do not match; source={repr(source_param.annotation)} overload={repr(overload_param.annotation)}"
75+
)
76+
77+
if errors:
78+
raise AssertionError(
79+
f"{len(errors)} errors encountered when comparing overload signatures:\n\n" + "\n\n".join(errors)
80+
)

src/anthropic/lib/_parse/_transform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def transform_schema(
112112

113113
enum = json_schema.pop("enum", None)
114114
if is_list(enum):
115-
strict_schema['enum'] = enum
115+
strict_schema["enum"] = enum
116116

117117
description = json_schema.pop("description", None)
118118
if description is not None:

src/anthropic/resources/beta/messages/batches.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -406,9 +406,7 @@ def results(
406406
}
407407
extra_headers = {"anthropic-beta": "message-batches-2024-09-24", **(extra_headers or {})}
408408
return self._get(
409-
path_template(
410-
batch.results_url, message_batch_id=message_batch_id
411-
),
409+
path_template(batch.results_url, message_batch_id=message_batch_id),
412410
options=make_request_options(
413411
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
414412
),
@@ -797,9 +795,7 @@ async def results(
797795
}
798796
extra_headers = {"anthropic-beta": "message-batches-2024-09-24", **(extra_headers or {})}
799797
return await self._get(
800-
path_template(
801-
batch.results_url, message_batch_id=message_batch_id
802-
),
798+
path_template(batch.results_url, message_batch_id=message_batch_id),
803799
options=make_request_options(
804800
extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
805801
),

tests/lib/streaming/test_beta_messages.py

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import os
44
import json
5-
import inspect
65
from typing import Any, Set, Dict, TypeVar, cast
76
from unittest import TestCase
87

@@ -11,6 +10,7 @@
1110
from respx import MockRouter
1211

1312
from anthropic import Anthropic, AsyncAnthropic
13+
from anthropic._utils import assert_overloads_in_sync, assert_signatures_in_sync
1414
from anthropic._compat import PYDANTIC_V1
1515
from anthropic.types.beta.beta_message import BetaMessage
1616
from anthropic.lib.streaming._beta_types import ParsedBetaMessageStreamEvent
@@ -376,36 +376,31 @@ async def test_incomplete_response(self, respx_mock: MockRouter) -> None:
376376
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
377377
def test_stream_method_definition_in_sync(sync: bool) -> None:
378378
client: Anthropic | AsyncAnthropic = sync_client if sync else async_client
379+
assert_signatures_in_sync(
380+
client.beta.messages.create,
381+
client.beta.messages.stream,
382+
exclude_params={"stream", "output_format"},
383+
)
379384

380-
sig = inspect.signature(client.beta.messages.stream)
381-
generated_sig = inspect.signature(client.beta.messages.create)
382-
383-
errors: list[str] = []
384-
385-
for name, generated_param in generated_sig.parameters.items():
386-
if name == "stream":
387-
# intentionally excluded
388-
continue
389-
390-
if name == "output_format":
391-
continue
392385

393-
custom_param = sig.parameters.get(name)
394-
if not custom_param:
395-
errors.append(f"the `{name}` param is missing")
396-
continue
386+
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
387+
def test_parse_method_definition_in_sync(sync: bool) -> None:
388+
client: Anthropic | AsyncAnthropic = sync_client if sync else async_client
389+
assert_signatures_in_sync(
390+
client.beta.messages.create,
391+
client.beta.messages.parse,
392+
exclude_params={"stream", "output_format"},
393+
)
397394

398-
if custom_param.annotation != generated_param.annotation:
399-
errors.append(
400-
f"types for the `{name}` param are do not match; generated={repr(generated_param.annotation)} custom={repr(custom_param.annotation)}"
401-
)
402-
continue
403395

404-
if errors:
405-
raise AssertionError(
406-
f"{len(errors)} errors encountered with the {'sync' if sync else 'async'} client `messages.stream()` method:\n\n"
407-
+ "\n\n".join(errors)
408-
)
396+
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
397+
def test_tool_runner_method_definition_in_sync(sync: bool) -> None:
398+
client: Anthropic | AsyncAnthropic = sync_client if sync else async_client
399+
assert_overloads_in_sync(
400+
client.beta.messages.create,
401+
client.beta.messages.tool_runner,
402+
exclude_params={"stream", "tools", "max_iterations", "compaction_control", "output_format"},
403+
)
409404

410405

411406
# go through all the ContentBlock types to make sure the type alias is up to date

tests/lib/streaming/test_messages.py

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from __future__ import annotations
22

33
import os
4-
import inspect
54
from typing import Any, Set, TypeVar, cast
65

76
import httpx
87
import pytest
98
from respx import MockRouter
109

1110
from anthropic import Stream, Anthropic, AsyncStream, AsyncAnthropic
11+
from anthropic._utils import assert_signatures_in_sync
1212
from anthropic._compat import PYDANTIC_V1
1313
from anthropic.lib.streaming import ParsedMessageStreamEvent
1414
from anthropic.types.message import Message
@@ -274,33 +274,11 @@ async def test_tool_use(self, respx_mock: MockRouter) -> None:
274274
@pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
275275
def test_stream_method_definition_in_sync(sync: bool) -> None:
276276
client: Anthropic | AsyncAnthropic = sync_client if sync else async_client
277-
278-
sig = inspect.signature(client.messages.stream)
279-
generated_sig = inspect.signature(client.messages.create)
280-
281-
errors: list[str] = []
282-
283-
for name, generated_param in generated_sig.parameters.items():
284-
if name == "stream":
285-
# intentionally excluded
286-
continue
287-
288-
custom_param = sig.parameters.get(name)
289-
if not custom_param:
290-
errors.append(f"the `{name}` param is missing")
291-
continue
292-
293-
if custom_param.annotation != generated_param.annotation:
294-
errors.append(
295-
f"types for the `{name}` param are do not match; generated={repr(generated_param.annotation)} custom={repr(custom_param.annotation)}"
296-
)
297-
continue
298-
299-
if errors:
300-
raise AssertionError(
301-
f"{len(errors)} errors encountered with the {'sync' if sync else 'async'} client `messages.stream()` method:\n\n"
302-
+ "\n\n".join(errors)
303-
)
277+
assert_signatures_in_sync(
278+
client.messages.create,
279+
client.messages.stream,
280+
exclude_params={"stream"},
281+
)
304282

305283

306284
# go through all the ContentBlock types to make sure the type alias is up to date

0 commit comments

Comments
 (0)