-
Notifications
You must be signed in to change notification settings - Fork 461
feat(tools): Support string descriptions in Annotated parameters #1089
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
296853b
8637653
bdea552
787c0dd
9a4b2e6
616ec5a
662c168
f187b29
07b837b
c8818d1
96ed97c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -44,7 +44,9 @@ def my_tool(param1: str, param2: int = 42) -> dict: | |
| import functools | ||
| import inspect | ||
| import logging | ||
| from copy import copy | ||
| from typing import ( | ||
| Annotated, | ||
| Any, | ||
| Callable, | ||
| Generic, | ||
|
|
@@ -54,12 +56,15 @@ def my_tool(param1: str, param2: int = 42) -> dict: | |
| TypeVar, | ||
| Union, | ||
| cast, | ||
| get_args, | ||
| get_origin, | ||
| get_type_hints, | ||
| overload, | ||
| ) | ||
|
|
||
| import docstring_parser | ||
| from pydantic import BaseModel, Field, create_model | ||
| from pydantic.fields import FieldInfo | ||
| from typing_extensions import override | ||
|
|
||
| from ..interrupt import InterruptException | ||
|
|
@@ -97,7 +102,12 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) - | |
| """ | ||
| self.func = func | ||
| self.signature = inspect.signature(func) | ||
| self.type_hints = get_type_hints(func) | ||
| # Preserve Annotated extras when possible (Python 3.9+ / 3.10+ support include_extras) | ||
| try: | ||
| self.type_hints = get_type_hints(func, include_extras=True) | ||
| except TypeError: | ||
| # Older Python versions / typing implementations may not accept include_extras | ||
| self.type_hints = get_type_hints(func) | ||
| self._context_param = context_param | ||
|
|
||
| self._validate_signature() | ||
|
|
@@ -114,6 +124,32 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) - | |
| # Create a Pydantic model for validation | ||
| self.input_model = self._create_input_model() | ||
|
|
||
| def _extract_annotated_metadata(self, annotation: Any) -> tuple[Any, Optional[Any]]: | ||
| """Extract type and metadata from Annotated type hint. | ||
|
|
||
| Returns: | ||
| (actual_type, metadata) where metadata is either: | ||
| - a string description | ||
| - a pydantic.fields.FieldInfo instance (from Field(...)) | ||
| - None if no Annotated extras were found | ||
| """ | ||
| if get_origin(annotation) is Annotated: | ||
| args = get_args(annotation) | ||
| actual_type = args[0] # Keep the type as-is (including Optional[T]) | ||
|
|
||
| # Look through metadata for description | ||
| for meta in args[1:]: | ||
| if isinstance(meta, str): | ||
| return actual_type, meta | ||
| if isinstance(meta, FieldInfo): | ||
| return actual_type, meta | ||
|
|
||
| # Annotated but no useful metadata | ||
| return actual_type, None | ||
|
|
||
| # Not annotated | ||
| return annotation, None | ||
|
|
||
| def _validate_signature(self) -> None: | ||
| """Verify that ToolContext is used correctly in the function signature.""" | ||
| for param in self.signature.parameters.values(): | ||
|
|
@@ -146,13 +182,38 @@ def _create_input_model(self) -> Type[BaseModel]: | |
| if self._is_special_parameter(name): | ||
| continue | ||
|
|
||
| # Get parameter type and default | ||
| # Get parameter type hint and any Annotated metadata | ||
| param_type = self.type_hints.get(name, Any) | ||
| actual_type, annotated_meta = self._extract_annotated_metadata(param_type) | ||
|
|
||
| # Determine parameter default value | ||
| default = ... if param.default is inspect.Parameter.empty else param.default | ||
| description = self.param_descriptions.get(name, f"Parameter {name}") | ||
|
|
||
| # Create Field with description and default | ||
| field_definitions[name] = (param_type, Field(default=default, description=description)) | ||
| # Determine description (priority: Annotated > docstring > generic) | ||
| description: str | ||
| if isinstance(annotated_meta, str): | ||
| description = annotated_meta | ||
| elif isinstance(annotated_meta, FieldInfo) and annotated_meta.description is not None: | ||
|
||
| description = annotated_meta.description | ||
| elif name in self.param_descriptions: | ||
| description = self.param_descriptions[name] | ||
| else: | ||
| description = f"Parameter {name}" | ||
|
|
||
| # Create Field definition for create_model | ||
| if isinstance(annotated_meta, FieldInfo): | ||
| # Create a defensive copy to avoid mutating a shared FieldInfo instance. | ||
| field_info_copy = copy(annotated_meta) | ||
| field_info_copy.description = description | ||
|
|
||
| # Update default if specified in the function signature. | ||
| if default is not ...: | ||
| field_info_copy.default = default | ||
|
|
||
| field_definitions[name] = (actual_type, field_info_copy) | ||
| else: | ||
| # For non-FieldInfo metadata, create a new Field. | ||
| field_definitions[name] = (actual_type, Field(default=default, description=description)) | ||
|
|
||
| # Create model name based on function name | ||
| model_name = f"{self.func.__name__.capitalize()}Tool" | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.