Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 53 additions & 7 deletions src/strands/tools/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,56 @@ def _create_input_model(self) -> Type[BaseModel]:
# Handle case with no parameters
return create_model(model_name)

def _extract_description_from_docstring(self) -> str:
"""Extract the docstring excluding only the Args section.
This method uses the parsed docstring to extract everything except
the Args/Arguments/Parameters section, preserving Returns, Raises,
Examples, and other sections.
Returns:
The description text, or the function name if no description is available.
"""
func_name = self.func.__name__

# Fallback: try to extract manually from raw docstring
raw_docstring = inspect.getdoc(self.func)
if raw_docstring:
lines = raw_docstring.strip().split("\n")
result_lines = []
skip_args_section = False

for line in lines:
stripped_line = line.strip()

# Check if we're starting the Args section
if stripped_line.lower().startswith(("args:", "arguments:", "parameters:", "param:", "params:")):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consideration: Not saying it needs to be done for this PR, but I'm wondering if there is a reliable third party library out there that we can use to parse the docstrings. From a cursory search, I came across https://pypi.org/project/docstring-parser/. It's in preview though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah - I also looked for built in ways to do this and didn't see any great ways. I think if we were starting today we'd do this differently so that we're not managing raw strings.

I'm a bit reluctant to add a library just this :(

skip_args_section = True
continue

# Check if we're starting a new section (not Args)
elif (
stripped_line.lower().startswith(("returns:", "return:", "yields:", "yield:"))
or stripped_line.lower().startswith(("raises:", "raise:", "except:", "exceptions:"))
or stripped_line.lower().startswith(("examples:", "example:", "note:", "notes:"))
or stripped_line.lower().startswith(("see also:", "seealso:", "references:", "ref:"))
):
skip_args_section = False
result_lines.append(line)
continue

# If we're not in the Args section, include the line
if not skip_args_section:
result_lines.append(line)

# Join and clean up the description
description = "\n".join(result_lines).strip()
if description:
return description

# Final fallback: use function name
return func_name

def extract_metadata(self) -> ToolSpec:
"""Extract metadata from the function to create a tool specification.
Expand All @@ -173,20 +223,16 @@ def extract_metadata(self) -> ToolSpec:
The specification includes:
- name: The function name (or custom override)
- description: The function's docstring
- description: The function's docstring description (excluding Args)
- inputSchema: A JSON schema describing the expected parameters
Returns:
A dictionary containing the tool specification.
"""
func_name = self.func.__name__

# Extract function description from docstring, preserving paragraph breaks
description = inspect.getdoc(self.func)
if description:
description = description.strip()
else:
description = func_name
# Extract function description from parsed docstring, excluding Args section and beyond
description = self._extract_description_from_docstring()

# Get schema directly from the Pydantic model
input_schema = self.input_model.model_json_schema()
Expand Down
4 changes: 2 additions & 2 deletions tests/strands/agent/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2240,8 +2240,8 @@ def test_agent_backwards_compatibility_single_text_block():

# Should extract text for backwards compatibility
assert agent.system_prompt == text


@pytest.mark.parametrize(
"content, expected",
[
Expand Down
177 changes: 169 additions & 8 deletions tests/strands/tools/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,7 @@ def test_tool(param1: str, param2: int) -> str:

# Check basic spec properties
assert spec["name"] == "test_tool"
assert (
spec["description"]
== """Test tool function.

Args:
param1: First parameter
param2: Second parameter"""
)
assert spec["description"] == "Test tool function."

# Check input schema
schema = spec["inputSchema"]["json"]
Expand Down Expand Up @@ -310,6 +303,174 @@ def test_tool(required: str, optional: Optional[int] = None) -> str:
exp_events = [
ToolResultEvent({"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]})
]
assert tru_events == exp_events


@pytest.mark.asyncio
async def test_docstring_description_extraction():
"""Test that docstring descriptions are extracted correctly, excluding Args section."""

@strands.tool
def tool_with_full_docstring(param1: str, param2: int) -> str:
"""This is the main description.

This is more description text.

Args:
param1: First parameter
param2: Second parameter

Returns:
A string result

Raises:
ValueError: If something goes wrong
"""
return f"{param1} {param2}"

spec = tool_with_full_docstring.tool_spec
assert (
spec["description"]
== """This is the main description.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: So that you can present this more cleanly in code, you can use textwrap.dedent:

import textwrap
...
description = textwrap.dedent("""
    This is the main description.
    ...
""")


This is more description text.

Returns:
A string result

Raises:
ValueError: If something goes wrong"""
)


def test_docstring_args_variations():
"""Test that various Args section formats are properly excluded."""

@strands.tool
def tool_with_args(param: str) -> str:
"""Main description.

Args:
param: Parameter description
"""
return param

@strands.tool
def tool_with_arguments(param: str) -> str:
"""Main description.

Arguments:
param: Parameter description
"""
return param

@strands.tool
def tool_with_parameters(param: str) -> str:
"""Main description.

Parameters:
param: Parameter description
"""
return param

@strands.tool
def tool_with_params(param: str) -> str:
"""Main description.

Params:
param: Parameter description
"""
return param

for tool in [tool_with_args, tool_with_arguments, tool_with_parameters, tool_with_params]:
spec = tool.tool_spec
assert spec["description"] == "Main description."


def test_docstring_no_args_section():
"""Test docstring extraction when there's no Args section."""

@strands.tool
def tool_no_args(param: str) -> str:
"""This is the complete description.

Returns:
A string result
"""
return param

spec = tool_no_args.tool_spec
expected_desc = """This is the complete description.

Returns:
A string result"""
assert spec["description"] == expected_desc


def test_docstring_only_args_section():
"""Test docstring extraction when there's only an Args section."""

@strands.tool
def tool_only_args(param: str) -> str:
"""Args:
param: Parameter description
"""
return param

spec = tool_only_args.tool_spec
# Should fall back to function name when no description remains
assert spec["description"] == "tool_only_args"


def test_docstring_empty():
"""Test docstring extraction when docstring is empty."""

@strands.tool
def tool_empty_docstring(param: str) -> str:
return param

spec = tool_empty_docstring.tool_spec
# Should fall back to function name
assert spec["description"] == "tool_empty_docstring"


def test_docstring_preserves_other_sections():
"""Test that non-Args sections are preserved in the description."""

@strands.tool
def tool_multiple_sections(param: str) -> str:
"""Main description here.

Args:
param: This should be excluded

Returns:
This should be included

Raises:
ValueError: This should be included

Examples:
This should be included

Note:
This should be included
"""
return param

spec = tool_multiple_sections.tool_spec
description = spec["description"]

# Should include main description and other sections
assert "Main description here." in description
assert "Returns:" in description
assert "This should be included" in description
assert "Raises:" in description
assert "Examples:" in description
assert "Note:" in description

# Should exclude Args section
assert "This should be excluded" not in description


@pytest.mark.asyncio
Expand Down
Loading