Skip to content

Commit fcf7610

Browse files
xingyaowwopenhands-agentsimonrosenberg
authored
Refactor ToolDefinition architecture to use subclass pattern for all tools (#971)
Co-authored-by: openhands <[email protected]> Co-authored-by: simonrosenberg <[email protected]>
1 parent 9bcaf5c commit fcf7610

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+739
-657
lines changed

examples/01_standalone_sdk/02_custom_tools.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@
2626
)
2727
from openhands.tools.execute_bash import (
2828
BashExecutor,
29+
BashTool,
2930
ExecuteBashAction,
30-
execute_bash_tool,
3131
)
3232
from openhands.tools.file_editor import FileEditorTool
3333

@@ -115,6 +115,42 @@ def __call__(self, action: GrepAction, conversation=None) -> GrepObservation: #
115115
* When you are doing an open ended search that may require multiple rounds of globbing and grepping, use the Agent tool instead
116116
""" # noqa: E501
117117

118+
119+
# --- Tool Definition ---
120+
121+
122+
class GrepTool(ToolDefinition[GrepAction, GrepObservation]):
123+
"""A custom grep tool that searches file contents using regular expressions."""
124+
125+
@classmethod
126+
def create(
127+
cls, conv_state, bash_executor: BashExecutor | None = None
128+
) -> Sequence[ToolDefinition]:
129+
"""Create GrepTool instance with a GrepExecutor.
130+
131+
Args:
132+
conv_state: Conversation state to get working directory from.
133+
bash_executor: Optional bash executor to reuse. If not provided,
134+
a new one will be created.
135+
136+
Returns:
137+
A sequence containing a single GrepTool instance.
138+
"""
139+
if bash_executor is None:
140+
bash_executor = BashExecutor(working_dir=conv_state.workspace.working_dir)
141+
grep_executor = GrepExecutor(bash_executor)
142+
143+
return [
144+
cls(
145+
name="grep",
146+
description=_GREP_DESCRIPTION,
147+
action_type=GrepAction,
148+
observation_type=GrepObservation,
149+
executor=grep_executor,
150+
)
151+
]
152+
153+
118154
# Configure LLM
119155
api_key = os.getenv("LLM_API_KEY")
120156
assert api_key is not None, "LLM_API_KEY environment variable is not set."
@@ -135,16 +171,11 @@ def _make_bash_and_grep_tools(conv_state) -> list[ToolDefinition]:
135171
"""Create execute_bash and custom grep tools sharing one executor."""
136172

137173
bash_executor = BashExecutor(working_dir=conv_state.workspace.working_dir)
138-
bash_tool = execute_bash_tool.set_executor(executor=bash_executor)
139-
140-
grep_executor = GrepExecutor(bash_executor)
141-
grep_tool = ToolDefinition(
142-
name="grep",
143-
description=_GREP_DESCRIPTION,
144-
action_type=GrepAction,
145-
observation_type=GrepObservation,
146-
executor=grep_executor,
147-
)
174+
# bash_tool = execute_bash_tool.set_executor(executor=bash_executor)
175+
bash_tool = BashTool.create(conv_state, executor=bash_executor)[0]
176+
177+
# Use the GrepTool.create() method with shared bash_executor
178+
grep_tool = GrepTool.create(conv_state, bash_executor=bash_executor)[0]
148179

149180
return [bash_tool, grep_tool]
150181

openhands-sdk/openhands/sdk/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
Action,
3838
Observation,
3939
Tool,
40-
ToolBase,
4140
ToolDefinition,
4241
list_registered_tools,
4342
register_tool,
@@ -67,7 +66,6 @@
6766
"RedactedThinkingBlock",
6867
"Tool",
6968
"ToolDefinition",
70-
"ToolBase",
7169
"AgentBase",
7270
"Agent",
7371
"Action",

openhands-sdk/openhands/sdk/agent/agent.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
from openhands.sdk.security.llm_analyzer import LLMSecurityAnalyzer
3838
from openhands.sdk.tool import (
3939
Action,
40-
FinishTool,
4140
Observation,
4241
)
4342
from openhands.sdk.tool.builtins import FinishAction, ThinkAction
@@ -431,6 +430,6 @@ def _execute_action_event(
431430
on_event(obs_event)
432431

433432
# Set conversation state
434-
if tool.name == FinishTool.name:
433+
if tool.name == "finish":
435434
state.agent_status = AgentExecutionStatus.FINISHED
436435
return obs_event

openhands-sdk/openhands/sdk/agent/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,9 @@ def _initialize(self, state: "ConversationState"):
223223
)
224224

225225
# Always include built-in tools; not subject to filtering
226-
tools.extend(BUILT_IN_TOOLS)
226+
# Instantiate built-in tools using their .create() method
227+
for tool_class in BUILT_IN_TOOLS:
228+
tools.extend(tool_class.create(state))
227229

228230
# Check tool types
229231
for tool in tools:

openhands-sdk/openhands/sdk/llm/llm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727

2828
if TYPE_CHECKING: # type hints only, avoid runtime import cycle
29-
from openhands.sdk.tool.tool import ToolBase
29+
from openhands.sdk.tool.tool import ToolDefinition
3030

3131
from openhands.sdk.utils.pydantic_diff import pretty_pydantic_diff
3232

@@ -425,7 +425,7 @@ def restore_metrics(self, metrics: Metrics) -> None:
425425
def completion(
426426
self,
427427
messages: list[Message],
428-
tools: Sequence[ToolBase] | None = None,
428+
tools: Sequence[ToolDefinition] | None = None,
429429
_return_metrics: bool = False,
430430
add_security_risk_prediction: bool = False,
431431
**kwargs,
@@ -562,7 +562,7 @@ def _one_attempt(**retry_kwargs) -> ModelResponse:
562562
def responses(
563563
self,
564564
messages: list[Message],
565-
tools: Sequence[ToolBase] | None = None,
565+
tools: Sequence[ToolDefinition] | None = None,
566566
include: list[str] | None = None,
567567
store: bool | None = None,
568568
_return_metrics: bool = False,

openhands-sdk/openhands/sdk/llm/router/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from openhands.sdk.llm.llm_response import LLMResponse
1212
from openhands.sdk.llm.message import Message
1313
from openhands.sdk.logger import get_logger
14-
from openhands.sdk.tool.tool import ToolBase
14+
from openhands.sdk.tool.tool import ToolDefinition
1515

1616

1717
logger = get_logger(__name__)
@@ -49,7 +49,7 @@ def validate_llms_not_empty(cls, v):
4949
def completion(
5050
self,
5151
messages: list[Message],
52-
tools: Sequence[ToolBase] | None = None,
52+
tools: Sequence[ToolDefinition] | None = None,
5353
return_metrics: bool = False,
5454
add_security_risk_prediction: bool = False,
5555
**kwargs,

openhands-sdk/openhands/sdk/mcp/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from openhands.sdk.logger import get_logger
1010
from openhands.sdk.mcp import MCPClient, MCPToolDefinition
11-
from openhands.sdk.tool.tool import ToolBase
11+
from openhands.sdk.tool.tool import ToolDefinition
1212

1313

1414
logger = get_logger(__name__)
@@ -30,9 +30,9 @@ async def log_handler(message: LogMessage):
3030
logger.log(level, msg, extra=extra)
3131

3232

33-
async def _list_tools(client: MCPClient) -> list[ToolBase]:
33+
async def _list_tools(client: MCPClient) -> list[ToolDefinition]:
3434
"""List tools from an MCP client."""
35-
tools: list[ToolBase] = []
35+
tools: list[ToolDefinition] = []
3636

3737
async with client:
3838
assert client.is_connected(), "MCP client is not connected."

openhands-sdk/openhands/sdk/tool/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from openhands.sdk.tool.tool import (
1515
ExecutableTool,
1616
ToolAnnotations,
17-
ToolBase,
1817
ToolDefinition,
1918
ToolExecutor,
2019
)
@@ -23,7 +22,6 @@
2322
__all__ = [
2423
"Tool",
2524
"ToolDefinition",
26-
"ToolBase",
2725
"ToolAnnotations",
2826
"ToolExecutor",
2927
"ExecutableTool",

openhands-sdk/openhands/sdk/tool/builtins/finish.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Sequence
2-
from typing import TYPE_CHECKING
2+
from typing import TYPE_CHECKING, Self
33

44
from pydantic import Field
55
from rich.text import Text
@@ -16,6 +16,7 @@
1616

1717
if TYPE_CHECKING:
1818
from openhands.sdk.conversation.base import BaseConversation
19+
from openhands.sdk.conversation.state import ConversationState
1920

2021

2122
class FinishAction(Action):
@@ -67,17 +68,42 @@ def __call__(
6768
return FinishObservation(message=action.message)
6869

6970

70-
FinishTool = ToolDefinition(
71-
name="finish",
72-
action_type=FinishAction,
73-
observation_type=FinishObservation,
74-
description=TOOL_DESCRIPTION,
75-
executor=FinishExecutor(),
76-
annotations=ToolAnnotations(
77-
title="finish",
78-
readOnlyHint=True,
79-
destructiveHint=False,
80-
idempotentHint=True,
81-
openWorldHint=False,
82-
),
83-
)
71+
class FinishTool(ToolDefinition[FinishAction, FinishObservation]):
72+
"""Tool for signaling the completion of a task or conversation."""
73+
74+
@classmethod
75+
def create(
76+
cls,
77+
conv_state: "ConversationState | None" = None, # noqa: ARG003
78+
**params,
79+
) -> Sequence[Self]:
80+
"""Create FinishTool instance.
81+
82+
Args:
83+
conv_state: Optional conversation state (not used by FinishTool).
84+
**params: Additional parameters (none supported).
85+
86+
Returns:
87+
A sequence containing a single FinishTool instance.
88+
89+
Raises:
90+
ValueError: If any parameters are provided.
91+
"""
92+
if params:
93+
raise ValueError("FinishTool doesn't accept parameters")
94+
return [
95+
cls(
96+
name="finish",
97+
action_type=FinishAction,
98+
observation_type=FinishObservation,
99+
description=TOOL_DESCRIPTION,
100+
executor=FinishExecutor(),
101+
annotations=ToolAnnotations(
102+
title="finish",
103+
readOnlyHint=True,
104+
destructiveHint=False,
105+
idempotentHint=True,
106+
openWorldHint=False,
107+
),
108+
)
109+
]

openhands-sdk/openhands/sdk/tool/builtins/think.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Sequence
2-
from typing import TYPE_CHECKING
2+
from typing import TYPE_CHECKING, Self
33

44
from pydantic import Field
55
from rich.text import Text
@@ -16,6 +16,7 @@
1616

1717
if TYPE_CHECKING:
1818
from openhands.sdk.conversation.base import BaseConversation
19+
from openhands.sdk.conversation.state import ConversationState
1920

2021

2122
class ThinkAction(Action):
@@ -83,16 +84,41 @@ def __call__(
8384
return ThinkObservation()
8485

8586

86-
ThinkTool = ToolDefinition(
87-
name="think",
88-
description=THINK_DESCRIPTION,
89-
action_type=ThinkAction,
90-
observation_type=ThinkObservation,
91-
executor=ThinkExecutor(),
92-
annotations=ToolAnnotations(
93-
readOnlyHint=True,
94-
destructiveHint=False,
95-
idempotentHint=True,
96-
openWorldHint=False,
97-
),
98-
)
87+
class ThinkTool(ToolDefinition[ThinkAction, ThinkObservation]):
88+
"""Tool for logging thoughts without making changes."""
89+
90+
@classmethod
91+
def create(
92+
cls,
93+
conv_state: "ConversationState | None" = None, # noqa: ARG003
94+
**params,
95+
) -> Sequence[Self]:
96+
"""Create ThinkTool instance.
97+
98+
Args:
99+
conv_state: Optional conversation state (not used by ThinkTool).
100+
**params: Additional parameters (none supported).
101+
102+
Returns:
103+
A sequence containing a single ThinkTool instance.
104+
105+
Raises:
106+
ValueError: If any parameters are provided.
107+
"""
108+
if params:
109+
raise ValueError("ThinkTool doesn't accept parameters")
110+
return [
111+
cls(
112+
name="think",
113+
description=THINK_DESCRIPTION,
114+
action_type=ThinkAction,
115+
observation_type=ThinkObservation,
116+
executor=ThinkExecutor(),
117+
annotations=ToolAnnotations(
118+
readOnlyHint=True,
119+
destructiveHint=False,
120+
idempotentHint=True,
121+
openWorldHint=False,
122+
),
123+
)
124+
]

0 commit comments

Comments
 (0)