From 8c8877282f01339bbb96656d2c4c037dddcc947e Mon Sep 17 00:00:00 2001 From: Arron Bailiss Date: Sat, 1 Nov 2025 17:12:57 -0400 Subject: [PATCH] feat(agent-interface): introduce AgentBase abstract class as the interface for agent classes to implement --- src/strands/__init__.py | 2 + src/strands/agent/__init__.py | 5 +- src/strands/agent/agent.py | 69 +++++++++++++-- src/strands/agent/base.py | 117 +++++++++++++++++++++++++ tests/strands/multiagent/test_graph.py | 8 ++ 5 files changed, 193 insertions(+), 8 deletions(-) create mode 100644 src/strands/agent/base.py diff --git a/src/strands/__init__.py b/src/strands/__init__.py index 3718a29c5..bc17497a0 100644 --- a/src/strands/__init__.py +++ b/src/strands/__init__.py @@ -2,11 +2,13 @@ from . import agent, models, telemetry, types from .agent.agent import Agent +from .agent.base import AgentBase from .tools.decorator import tool from .types.tools import ToolContext __all__ = [ "Agent", + "AgentBase", "agent", "models", "tool", diff --git a/src/strands/agent/__init__.py b/src/strands/agent/__init__.py index 6618d3328..8c8d15648 100644 --- a/src/strands/agent/__init__.py +++ b/src/strands/agent/__init__.py @@ -2,12 +2,14 @@ It includes: -- Agent: The main interface for interacting with AI models and tools +- AgentBase: Abstract interface for all agent types +- Agent: The main implementation for interacting with AI models and tools - ConversationManager: Classes for managing conversation history and context windows """ from .agent import Agent from .agent_result import AgentResult +from .base import AgentBase from .conversation_manager import ( ConversationManager, NullConversationManager, @@ -17,6 +19,7 @@ __all__ = [ "Agent", + "AgentBase", "AgentResult", "ConversationManager", "NullConversationManager", diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 7c63c1e89..6e284e0cd 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -63,6 +63,7 @@ from ..types.tools import ToolResult, ToolUse from ..types.traces import AttributeValue from .agent_result import AgentResult +from .base import AgentBase from .conversation_manager import ( ConversationManager, SlidingWindowConversationManager, @@ -88,8 +89,8 @@ class _DefaultCallbackHandlerSentinel: _DEFAULT_AGENT_ID = "default" -class Agent: - """Core Agent interface. +class Agent(AgentBase): + """Core Agent implementation. An agent orchestrates the following workflow: @@ -289,8 +290,8 @@ def __init__( self.messages = messages if messages is not None else [] self.system_prompt = system_prompt self._default_structured_output_model = structured_output_model - self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) - self.name = name or _DEFAULT_AGENT_NAME + self._agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) + self._name = name or _DEFAULT_AGENT_NAME self.description = description # If not provided, create a new PrintingCallbackHandler instance @@ -338,13 +339,13 @@ def __init__( # Initialize agent state management if state is not None: if isinstance(state, dict): - self.state = AgentState(state) + self._state = AgentState(state) elif isinstance(state, AgentState): - self.state = state + self._state = state else: raise ValueError("state must be an AgentState object or a dict") else: - self.state = AgentState() + self._state = AgentState() self.tool_caller = Agent.ToolCaller(self) @@ -389,6 +390,60 @@ def tool_names(self) -> list[str]: all_tools = self.tool_registry.get_all_tools_config() return list(all_tools.keys()) + @property + def agent_id(self) -> str: + """Unique identifier for the agent. + + Returns: + Unique string identifier for this agent instance. + """ + return self._agent_id + + @agent_id.setter + def agent_id(self, value: str) -> None: + """Set the agent identifier. + + Args: + value: New agent identifier. + """ + self._agent_id = value + + @property + def name(self) -> str: + """Human-readable name of the agent. + + Returns: + Display name for the agent. + """ + return self._name + + @name.setter + def name(self, value: str) -> None: + """Set the agent name. + + Args: + value: New agent name. + """ + self._name = value + + @property + def state(self) -> AgentState: + """Current state of the agent. + + Returns: + AgentState object containing stateful information. + """ + return self._state + + @state.setter + def state(self, value: AgentState) -> None: + """Set the agent state. + + Args: + value: New agent state. + """ + self._state = value + def __call__( self, prompt: AgentInput = None, diff --git a/src/strands/agent/base.py b/src/strands/agent/base.py new file mode 100644 index 000000000..c7e984759 --- /dev/null +++ b/src/strands/agent/base.py @@ -0,0 +1,117 @@ +"""Agent Interface. + +Defines the minimal interface that all agent types must implement. +""" + +from abc import ABC, abstractmethod +from typing import Any, AsyncIterator, Type + +from pydantic import BaseModel + +from ..types.agent import AgentInput +from .agent_result import AgentResult +from .state import AgentState + + +class AgentBase(ABC): + """Abstract interface for all agent types in Strands. + + This interface defines the minimal contract that all agent implementations + must satisfy. + """ + + @property + @abstractmethod + def agent_id(self) -> str: + """Unique identifier for the agent. + + Returns: + Unique string identifier for this agent instance. + """ + pass + + @property + @abstractmethod + def name(self) -> str: + """Human-readable name of the agent. + + Returns: + Display name for the agent. + """ + pass + + @property + @abstractmethod + def state(self) -> AgentState: + """Current state of the agent. + + Returns: + AgentState object containing stateful information. + """ + pass + + @abstractmethod + async def invoke_async( + self, + prompt: AgentInput = None, + *, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, + ) -> AgentResult: + """Asynchronously invoke the agent with the given prompt. + + Args: + prompt: Input to the agent. + invocation_state: Optional state to pass to the agent invocation. + structured_output_model: Optional Pydantic model for structured output. + **kwargs: Additional provider-specific arguments. + + Returns: + AgentResult containing the agent's response. + """ + pass + + @abstractmethod + def __call__( + self, + prompt: AgentInput = None, + *, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, + ) -> AgentResult: + """Synchronously invoke the agent with the given prompt. + + Args: + prompt: Input to the agent. + invocation_state: Optional state to pass to the agent invocation. + structured_output_model: Optional Pydantic model for structured output. + **kwargs: Additional provider-specific arguments. + + Returns: + AgentResult containing the agent's response. + """ + pass + + @abstractmethod + def stream_async( + self, + prompt: AgentInput = None, + *, + invocation_state: dict[str, Any] | None = None, + structured_output_model: Type[BaseModel] | None = None, + **kwargs: Any, + ) -> AsyncIterator[Any]: + """Stream agent execution asynchronously. + + Args: + prompt: Input to the agent. + invocation_state: Optional state to pass to the agent invocation. + structured_output_model: Optional Pydantic model for structured output. + **kwargs: Additional provider-specific arguments. + + Yields: + Events representing the streaming execution. + """ + pass diff --git a/tests/strands/multiagent/test_graph.py b/tests/strands/multiagent/test_graph.py index 07037a447..4de51e9aa 100644 --- a/tests/strands/multiagent/test_graph.py +++ b/tests/strands/multiagent/test_graph.py @@ -17,9 +17,11 @@ def create_mock_agent(name, response_text="Default response", metrics=None, agen """Create a mock Agent with specified properties.""" agent = Mock(spec=Agent) agent.name = name + agent.agent_id = agent_id or f"{name}_id" agent.id = agent_id or f"{name}_id" agent._session_manager = None agent.hooks = HookRegistry() + agent.state = AgentState() if metrics is None: metrics = Mock( @@ -280,12 +282,14 @@ async def test_graph_execution_with_failures(mock_strands_tracer, mock_use_span) """Test graph execution error handling and failure propagation.""" failing_agent = Mock(spec=Agent) failing_agent.name = "failing_agent" + failing_agent.agent_id = "fail_node" failing_agent.id = "fail_node" failing_agent.__call__ = Mock(side_effect=Exception("Simulated failure")) # Add required attributes for validation failing_agent._session_manager = None failing_agent.hooks = HookRegistry() + failing_agent.state = AgentState() async def mock_invoke_failure(*args, **kwargs): raise Exception("Simulated failure") @@ -1524,9 +1528,11 @@ async def test_graph_streaming_with_failures(mock_strands_tracer, mock_use_span) # Create a failing agent failing_agent = Mock(spec=Agent) failing_agent.name = "failing_agent" + failing_agent.agent_id = "fail_node" failing_agent.id = "fail_node" failing_agent._session_manager = None failing_agent.hooks = HookRegistry() + failing_agent.state = AgentState() async def failing_stream(*args, **kwargs): yield {"agent_start": True} @@ -1697,9 +1703,11 @@ async def test_graph_parallel_with_failures(mock_strands_tracer, mock_use_span): # Create a failing agent failing_agent = Mock(spec=Agent) failing_agent.name = "failing_agent" + failing_agent.agent_id = "fail_node" failing_agent.id = "fail_node" failing_agent._session_manager = None failing_agent.hooks = HookRegistry() + failing_agent.state = AgentState() async def mock_invoke_failure(*args, **kwargs): await asyncio.sleep(0.05) # Small delay