Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/strands/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from . import agent, models, telemetry, types
from .agent.agent import Agent
from .agent.base import AgentBase
from .tools.decorator import tool
from .types.tools import ToolContext

__all__ = [
"Agent",
"AgentBase",
"agent",
"models",
"tool",
Expand Down
5 changes: 4 additions & 1 deletion src/strands/agent/__init__.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you update the PR description indicating the purpose of the base class - e.g. is it just for agents or is it meant to represent things in-place of an agent (like humans) as the tracking issue #573 calls out.

Specifically, if we were starting fresh, what would implement this? MultiAgent primitives? A2AAgent, etc?

Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

It includes:

- Agent: The main interface for interacting with AI models and tools
- AgentBase: Abstract interface for all agent types
- Agent: The main implementation for interacting with AI models and tools
- ConversationManager: Classes for managing conversation history and context windows
"""

from .agent import Agent
from .agent_result import AgentResult
from .base import AgentBase
from .conversation_manager import (
ConversationManager,
NullConversationManager,
Expand All @@ -17,6 +19,7 @@

__all__ = [
"Agent",
"AgentBase",
"AgentResult",
"ConversationManager",
"NullConversationManager",
Expand Down
69 changes: 62 additions & 7 deletions src/strands/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
from ..types.tools import ToolResult, ToolUse
from ..types.traces import AttributeValue
from .agent_result import AgentResult
from .base import AgentBase
from .conversation_manager import (
ConversationManager,
SlidingWindowConversationManager,
Expand All @@ -88,8 +89,8 @@ class _DefaultCallbackHandlerSentinel:
_DEFAULT_AGENT_ID = "default"


class Agent:
"""Core Agent interface.
class Agent(AgentBase):
"""Core Agent implementation.

An agent orchestrates the following workflow:

Expand Down Expand Up @@ -289,8 +290,8 @@ def __init__(
self.messages = messages if messages is not None else []
self.system_prompt = system_prompt
self._default_structured_output_model = structured_output_model
self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT)
self.name = name or _DEFAULT_AGENT_NAME
self._agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT)
self._name = name or _DEFAULT_AGENT_NAME
self.description = description

# If not provided, create a new PrintingCallbackHandler instance
Expand Down Expand Up @@ -338,13 +339,13 @@ def __init__(
# Initialize agent state management
if state is not None:
if isinstance(state, dict):
self.state = AgentState(state)
self._state = AgentState(state)
elif isinstance(state, AgentState):
self.state = state
self._state = state
else:
raise ValueError("state must be an AgentState object or a dict")
else:
self.state = AgentState()
self._state = AgentState()

self.tool_caller = Agent.ToolCaller(self)

Expand Down Expand Up @@ -389,6 +390,60 @@ def tool_names(self) -> list[str]:
all_tools = self.tool_registry.get_all_tools_config()
return list(all_tools.keys())

@property
def agent_id(self) -> str:
"""Unique identifier for the agent.

Returns:
Unique string identifier for this agent instance.
"""
return self._agent_id

@agent_id.setter
def agent_id(self, value: str) -> None:
"""Set the agent identifier.

Args:
value: New agent identifier.
"""
self._agent_id = value

@property
def name(self) -> str:
"""Human-readable name of the agent.

Returns:
Display name for the agent.
"""
return self._name

@name.setter
def name(self, value: str) -> None:
"""Set the agent name.

Args:
value: New agent name.
"""
self._name = value

@property
def state(self) -> AgentState:
"""Current state of the agent.

Returns:
AgentState object containing stateful information.
"""
return self._state

@state.setter
def state(self, value: AgentState) -> None:
"""Set the agent state.

Args:
value: New agent state.
"""
self._state = value

def __call__(
self,
prompt: AgentInput = None,
Expand Down
117 changes: 117 additions & 0 deletions src/strands/agent/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""Agent Interface.

Defines the minimal interface that all agent types must implement.
"""

from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Type

from pydantic import BaseModel

from ..types.agent import AgentInput
from .agent_result import AgentResult
from .state import AgentState


class AgentBase(ABC):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we use Protocol instead?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pgrayy has also called out that Protocols should be preferred in general over base-classes; it would also enable us to specify the properties as normal fields instead of getters

"""Abstract interface for all agent types in Strands.

This interface defines the minimal contract that all agent implementations
must satisfy.
"""

@property
@abstractmethod
def agent_id(self) -> str:
"""Unique identifier for the agent.

Returns:
Unique string identifier for this agent instance.
"""
pass

@property
@abstractmethod
def name(self) -> str:
"""Human-readable name of the agent.

Returns:
Display name for the agent.
"""
pass

@property
@abstractmethod
def state(self) -> AgentState:
"""Current state of the agent.

Returns:
AgentState object containing stateful information.
"""
pass

@abstractmethod
async def invoke_async(
self,
prompt: AgentInput = None,
*,
invocation_state: dict[str, Any] | None = None,
structured_output_model: Type[BaseModel] | None = None,
**kwargs: Any,
) -> AgentResult:
"""Asynchronously invoke the agent with the given prompt.

Args:
prompt: Input to the agent.
invocation_state: Optional state to pass to the agent invocation.
structured_output_model: Optional Pydantic model for structured output.
**kwargs: Additional provider-specific arguments.

Returns:
AgentResult containing the agent's response.
"""
pass

@abstractmethod
def __call__(
self,
prompt: AgentInput = None,
*,
invocation_state: dict[str, Any] | None = None,
structured_output_model: Type[BaseModel] | None = None,
**kwargs: Any,
) -> AgentResult:
"""Synchronously invoke the agent with the given prompt.

Args:
prompt: Input to the agent.
invocation_state: Optional state to pass to the agent invocation.
structured_output_model: Optional Pydantic model for structured output.
**kwargs: Additional provider-specific arguments.

Returns:
AgentResult containing the agent's response.
"""
pass

@abstractmethod
def stream_async(
self,
prompt: AgentInput = None,
*,
invocation_state: dict[str, Any] | None = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the use-cases that we're targeting, will invocation_state and structured_output_model be supported? For instance, for multi-agent, I think structured_output is not supported but invocation_state is?

structured_output_model: Type[BaseModel] | None = None,
**kwargs: Any,
) -> AsyncIterator[Any]:
"""Stream agent execution asynchronously.

Args:
prompt: Input to the agent.
invocation_state: Optional state to pass to the agent invocation.
structured_output_model: Optional Pydantic model for structured output.
**kwargs: Additional provider-specific arguments.

Yields:
Events representing the streaming execution.
"""
pass
8 changes: 8 additions & 0 deletions tests/strands/multiagent/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen
"""Create a mock Agent with specified properties."""
agent = Mock(spec=Agent)
agent.name = name
agent.agent_id = agent_id or f"{name}_id"
agent.id = agent_id or f"{name}_id"
agent._session_manager = None
agent.hooks = HookRegistry()
agent.state = AgentState()

if metrics is None:
metrics = Mock(
Expand Down Expand Up @@ -280,12 +282,14 @@ async def test_graph_execution_with_failures(mock_strands_tracer, mock_use_span)
"""Test graph execution error handling and failure propagation."""
failing_agent = Mock(spec=Agent)
failing_agent.name = "failing_agent"
failing_agent.agent_id = "fail_node"
failing_agent.id = "fail_node"
failing_agent.__call__ = Mock(side_effect=Exception("Simulated failure"))

# Add required attributes for validation
failing_agent._session_manager = None
failing_agent.hooks = HookRegistry()
failing_agent.state = AgentState()

async def mock_invoke_failure(*args, **kwargs):
raise Exception("Simulated failure")
Expand Down Expand Up @@ -1524,9 +1528,11 @@ async def test_graph_streaming_with_failures(mock_strands_tracer, mock_use_span)
# Create a failing agent
failing_agent = Mock(spec=Agent)
failing_agent.name = "failing_agent"
failing_agent.agent_id = "fail_node"
failing_agent.id = "fail_node"
failing_agent._session_manager = None
failing_agent.hooks = HookRegistry()
failing_agent.state = AgentState()

async def failing_stream(*args, **kwargs):
yield {"agent_start": True}
Expand Down Expand Up @@ -1697,9 +1703,11 @@ async def test_graph_parallel_with_failures(mock_strands_tracer, mock_use_span):
# Create a failing agent
failing_agent = Mock(spec=Agent)
failing_agent.name = "failing_agent"
failing_agent.agent_id = "fail_node"
failing_agent.id = "fail_node"
failing_agent._session_manager = None
failing_agent.hooks = HookRegistry()
failing_agent.state = AgentState()

async def mock_invoke_failure(*args, **kwargs):
await asyncio.sleep(0.05) # Small delay
Expand Down
Loading