Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 66 additions & 5 deletions src/strands/tools/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@ def my_tool(param1: str, param2: int = 42) -> dict:
import functools
import inspect
import logging
from copy import copy
from typing import (
Annotated,
Any,
Callable,
Generic,
Expand All @@ -54,12 +56,15 @@ def my_tool(param1: str, param2: int = 42) -> dict:
TypeVar,
Union,
cast,
get_args,
get_origin,
get_type_hints,
overload,
)

import docstring_parser
from pydantic import BaseModel, Field, create_model
from pydantic.fields import FieldInfo
from typing_extensions import override

from ..interrupt import InterruptException
Expand Down Expand Up @@ -97,7 +102,12 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) -
"""
self.func = func
self.signature = inspect.signature(func)
self.type_hints = get_type_hints(func)
# Preserve Annotated extras when possible (Python 3.9+ / 3.10+ support include_extras)
try:
self.type_hints = get_type_hints(func, include_extras=True)
except TypeError:
# Older Python versions / typing implementations may not accept include_extras
self.type_hints = get_type_hints(func)
self._context_param = context_param

self._validate_signature()
Expand All @@ -114,6 +124,32 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) -
# Create a Pydantic model for validation
self.input_model = self._create_input_model()

def _extract_annotated_metadata(self, annotation: Any) -> tuple[Any, Optional[Any]]:
"""Extract type and metadata from Annotated type hint.

Returns:
(actual_type, metadata) where metadata is either:
- a string description
- a pydantic.fields.FieldInfo instance (from Field(...))
- None if no Annotated extras were found
"""
if get_origin(annotation) is Annotated:
args = get_args(annotation)
actual_type = args[0] # Keep the type as-is (including Optional[T])

# Look through metadata for description
for meta in args[1:]:
if isinstance(meta, str):
return actual_type, meta
if isinstance(meta, FieldInfo):
return actual_type, meta

# Annotated but no useful metadata
return actual_type, None

# Not annotated
return annotation, None

def _validate_signature(self) -> None:
"""Verify that ToolContext is used correctly in the function signature."""
for param in self.signature.parameters.values():
Expand Down Expand Up @@ -146,13 +182,38 @@ def _create_input_model(self) -> Type[BaseModel]:
if self._is_special_parameter(name):
continue

# Get parameter type and default
# Get parameter type hint and any Annotated metadata
param_type = self.type_hints.get(name, Any)
actual_type, annotated_meta = self._extract_annotated_metadata(param_type)

# Determine parameter default value
default = ... if param.default is inspect.Parameter.empty else param.default
description = self.param_descriptions.get(name, f"Parameter {name}")

# Create Field with description and default
field_definitions[name] = (param_type, Field(default=default, description=description))
# Determine description (priority: Annotated > docstring > generic)
description: str
if isinstance(annotated_meta, str):
description = annotated_meta
elif isinstance(annotated_meta, FieldInfo) and annotated_meta.description is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are conditioning on FieldInfo several times: L144, L196, L204.

I'm going to tinker a bit to see if we can condense such that if we encounter FieldInfo we return all relevant information to reduce the number of if statements we are seeing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, also would processing Annotated metadata once, then extracting all the relevant details (like description, default, and other constraints) into a single place, and then using that to build the field_definitions, would this avoid the repeated isinstance checks?

Copy link
Member

@dbschmigelski dbschmigelski Oct 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey, tried pushing to a branch but was getting blocked by a rule.

I believe the following is semantically the same as what you have

   def _extract_annotated_metadata(
        self, annotation: Any, param_name: str, param_default: Any
    ) -> tuple[Any, FieldInfo]:
        """Extract type and create FieldInfo from Annotated type hint.

        Returns:
            (actual_type, field_info) where field_info is always a FieldInfo instance
        """
        if get_origin(annotation) is Annotated:
            args = get_args(annotation)
            actual_type = args[0]  # Keep the type as-is (including Optional[T])

            field_info: FieldInfo | None = None
            description: str | None = None

            # Look through metadata for FieldInfo and string descriptions
            for meta in args[1:]:
                if isinstance(meta, FieldInfo):
                    field_info = meta
                    if meta.description is not None:
                        description = meta.description
                    break

                if isinstance(meta, str):
                    description = meta
                    break

            final_description = (
                description, # TODO: need to exit if description is empty string since we should honor that
                or self.param_descriptions.get(param_name)  # Docstring description
                or f"Parameter {param_name}"  # Generic fallback
            )

            # Create final FieldInfo
            if field_info:
                field_info_copy = copy(field_info)
                field_info_copy.description = final_description
                # Override default if function signature has one (... means required field)
                if param_default is not ...:
                    field_info_copy.default = param_default
                return actual_type, field_info_copy
            else:
                return actual_type, Field(default=param_default, description=final_description)
    def _create_input_model(self) -> Type[BaseModel]:
        """Create a Pydantic model from function signature for input validation.

        This method analyzes the function's signature, type hints, and docstring to create a Pydantic model that can
        validate input data before passing it to the function.

        Special parameters that can be automatically injected are excluded from the model.

        Returns:
            A Pydantic BaseModel class customized for the function's parameters.
        """
        field_definitions: dict[str, Any] = {}

        for name, param in self.signature.parameters.items():
            # Skip parameters that will be automatically injected
            if self._is_special_parameter(name):
                continue

            # Get parameter type hint and create FieldInfo
            param_type = self.type_hints.get(name, Any)
            # Use ... (Ellipsis) to indicate required field, actual default otherwise
            default = ... if param.default is inspect.Parameter.empty else param.default

            actual_type, field_info = self._extract_annotated_metadata(param_type, name, default)
            field_definitions[name] = (actual_type, field_info)

        # Create model name based on function name
        model_name = f"{self.func.__name__.capitalize()}Tool"

        # Create and return the model
        if field_definitions:
            return create_model(model_name, **field_definitions)
        else:
            # Handle case with no parameters
            return create_model(model_name)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for helping @dbschmigelski , I've pushed a new commit that implements the refactor you proposed. It should now correctly handle the description fallback and the default value precedence that was causing the test failures. I also added the new test case with the inner-default

description = annotated_meta.description
elif name in self.param_descriptions:
description = self.param_descriptions[name]
else:
description = f"Parameter {name}"

# Create Field definition for create_model
if isinstance(annotated_meta, FieldInfo):
# Create a defensive copy to avoid mutating a shared FieldInfo instance.
field_info_copy = copy(annotated_meta)
field_info_copy.description = description

# Update default if specified in the function signature.
if default is not ...:
field_info_copy.default = default

field_definitions[name] = (actual_type, field_info_copy)
else:
# For non-FieldInfo metadata, create a new Field.
field_definitions[name] = (actual_type, Field(default=default, description=description))

# Create model name based on function name
model_name = f"{self.func.__name__.capitalize()}Tool"
Expand Down
232 changes: 231 additions & 1 deletion tests/strands/tools/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
"""

from asyncio import Queue
from typing import Any, AsyncGenerator, Dict, Optional, Union
from typing import Annotated, Any, AsyncGenerator, Dict, List, Optional, Union
from unittest.mock import MagicMock

import pytest
from pydantic import Field

import strands
from strands import Agent
Expand Down Expand Up @@ -1450,3 +1451,232 @@ def test_function_tool_metadata_validate_signature_missing_context_config():
@strands.tool
def my_tool(tool_context: ToolContext):
pass


def test_tool_decorator_annotated_string_description():
"""Test tool decorator with Annotated type hints for descriptions."""

@strands.tool
def annotated_tool(
name: Annotated[str, "The user's full name"],
age: Annotated[int, "The user's age in years"],
city: str, # No annotation - should use docstring or generic
) -> str:
"""Tool with annotated parameters.

Args:
city: The user's city (from docstring)
"""
return f"{name}, {age}, {city}"

spec = annotated_tool.tool_spec
schema = spec["inputSchema"]["json"]

# Check that annotated descriptions are used
assert schema["properties"]["name"]["description"] == "The user's full name"
assert schema["properties"]["age"]["description"] == "The user's age in years"

# Check that docstring is still used for non-annotated params
assert schema["properties"]["city"]["description"] == "The user's city (from docstring)"

# Verify all are required
assert set(schema["required"]) == {"name", "age", "city"}


def test_tool_decorator_annotated_pydantic_field_constraints():
"""Test tool decorator with Pydantic Field in Annotated."""

@strands.tool
def field_annotated_tool(
email: Annotated[str, Field(description="User's email address", pattern=r"^[\w\.-]+@[\w\.-]+\.\w+$")],
score: Annotated[int, Field(description="Score between 0-100", ge=0, le=100)] = 50,
) -> str:
"""Tool with Pydantic Field annotations."""
return f"{email}: {score}"

spec = field_annotated_tool.tool_spec
schema = spec["inputSchema"]["json"]

# Check descriptions from Field
assert schema["properties"]["email"]["description"] == "User's email address"
assert schema["properties"]["score"]["description"] == "Score between 0-100"

# Check that constraints are preserved
assert schema["properties"]["score"]["minimum"] == 0
assert schema["properties"]["score"]["maximum"] == 100

# Check required fields
assert "email" in schema["required"]
assert "score" not in schema["required"] # Has default


def test_tool_decorator_annotated_overrides_docstring():
"""Test that Annotated descriptions override docstring descriptions."""

@strands.tool
def override_tool(param: Annotated[str, "Description from annotation"]) -> str:
"""Tool with both annotation and docstring.

Args:
param: Description from docstring (should be overridden)
"""
return param

spec = override_tool.tool_spec
schema = spec["inputSchema"]["json"]

# Annotated description should win
assert schema["properties"]["param"]["description"] == "Description from annotation"


def test_tool_decorator_annotated_optional_type():
"""Test tool with Optional types in Annotated."""

@strands.tool
def optional_annotated_tool(
required: Annotated[str, "Required parameter"], optional: Annotated[Optional[str], "Optional parameter"] = None
) -> str:
"""Tool with optional annotated parameter."""
return f"{required}, {optional}"

spec = optional_annotated_tool.tool_spec
schema = spec["inputSchema"]["json"]

# Check descriptions
assert schema["properties"]["required"]["description"] == "Required parameter"
assert schema["properties"]["optional"]["description"] == "Optional parameter"

# Check required list
assert "required" in schema["required"]
assert "optional" not in schema["required"]


def test_tool_decorator_annotated_complex_types():
"""Test tool with complex types in Annotated."""

@strands.tool
def complex_annotated_tool(
tags: Annotated[List[str], "List of tag strings"], config: Annotated[Dict[str, Any], "Configuration dictionary"]
) -> str:
"""Tool with complex annotated types."""
return f"Tags: {len(tags)}, Config: {len(config)}"

spec = complex_annotated_tool.tool_spec
schema = spec["inputSchema"]["json"]

# Check descriptions
assert schema["properties"]["tags"]["description"] == "List of tag strings"
assert schema["properties"]["config"]["description"] == "Configuration dictionary"

# Check types are preserved
assert schema["properties"]["tags"]["type"] == "array"
assert schema["properties"]["config"]["type"] == "object"


def test_tool_decorator_annotated_mixed_styles():
"""Test tool with mixed annotation styles."""

@strands.tool
def mixed_tool(
plain: str,
annotated_str: Annotated[str, "String description"],
annotated_field: Annotated[int, Field(description="Field description", ge=0)],
docstring_only: int,
) -> str:
"""Tool with mixed parameter styles.

Args:
plain: Plain parameter description
docstring_only: Docstring description for this param
"""
return "mixed"

spec = mixed_tool.tool_spec
schema = spec["inputSchema"]["json"]

# Check each style works correctly
assert schema["properties"]["plain"]["description"] == "Plain parameter description"
assert schema["properties"]["annotated_str"]["description"] == "String description"
assert schema["properties"]["annotated_field"]["description"] == "Field description"
assert schema["properties"]["docstring_only"]["description"] == "Docstring description for this param"


@pytest.mark.asyncio
async def test_tool_decorator_annotated_execution(alist):
"""Test that annotated tools execute correctly."""

@strands.tool
def execution_test(name: Annotated[str, "User name"], count: Annotated[int, "Number of times"] = 1) -> str:
"""Test execution with annotations."""
return f"Hello {name} " * count

# Test tool use
tool_use = {"toolUseId": "test-id", "input": {"name": "Alice", "count": 2}}
stream = execution_test.stream(tool_use, {})

result = (await alist(stream))[-1]
assert result["tool_result"]["status"] == "success"
assert "Hello Alice Hello Alice" in result["tool_result"]["content"][0]["text"]

# Test direct call
direct_result = execution_test("Bob", 3)
assert direct_result == "Hello Bob Hello Bob Hello Bob "


def test_tool_decorator_annotated_no_description_fallback():
"""Test that Annotated without description falls back to docstring."""

@strands.tool
def no_desc_annotated(
param: Annotated[str, Field()], # Field without description
) -> str:
"""Tool with Annotated but no description.

Args:
param: Docstring description
"""
return param

spec = no_desc_annotated.tool_spec
schema = spec["inputSchema"]["json"]

# Should fall back to docstring
assert schema["properties"]["param"]["description"] == "Docstring description"


def test_tool_decorator_annotated_empty_string_description():
"""Test handling of empty string descriptions in Annotated."""

@strands.tool
def empty_desc_tool(
param: Annotated[str, ""], # Empty string description
) -> str:
"""Tool with empty annotation description.

Args:
param: Docstring description
"""
return param

spec = empty_desc_tool.tool_spec
schema = spec["inputSchema"]["json"]

# Empty string is still a valid description, should not fall back
assert schema["properties"]["param"]["description"] == ""


@pytest.mark.asyncio
async def test_tool_decorator_annotated_validation_error(alist):
"""Test that validation works correctly with annotated parameters."""

@strands.tool
def validation_tool(age: Annotated[int, "User age"]) -> str:
"""Tool for validation testing."""
return f"Age: {age}"

# Test with wrong type
tool_use = {"toolUseId": "test-id", "input": {"age": "not an int"}}
stream = validation_tool.stream(tool_use, {})

result = (await alist(stream))[-1]
assert result["tool_result"]["status"] == "error"