Skip to content

Commit c95f1f1

Browse files
Fix tool filtering with rename tool (#79)
Signed-off-by: Nathalie Jonathan <[email protected]>
1 parent 9be2bd6 commit c95f1f1

File tree

3 files changed

+51
-23
lines changed

3 files changed

+51
-23
lines changed

src/tools/tool_filter.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,11 @@ def process_tool_filter(
6767
tool_registry: The tool registry to filter.
6868
"""
6969
try:
70-
# Create case-insensitive lookup
71-
tool_registry_lower = {k.lower(): k for k in tool_registry.keys()}
70+
# Create display name lookup
71+
display_name = {
72+
tool_info.get('display_name', '').lower(): k
73+
for k, tool_info in tool_registry.items()
74+
}
7275

7376
# Initialize collections
7477
category_to_tools = {}
@@ -120,7 +123,7 @@ def process_tool_filter(
120123
)
121124

122125
# Get current tool names after allow_write filtering
123-
current_tool_names = list(tool_registry.keys())
126+
current_tool_names = [tool['display_name'] for tool in tool_registry.values()]
124127
disabled_tools_from_regex = process_regex_patterns(
125128
disabled_tools_regex_list, current_tool_names
126129
)
@@ -130,17 +133,13 @@ def process_tool_filter(
130133
# Validate and collect all disabled tools
131134
all_disabled_tools = set()
132135
all_disabled_tools.update(
133-
validate_tools(disabled_tools_list, tool_registry_lower, 'disabled_tools')
136+
validate_tools(disabled_tools_list, display_name, 'disabled_tools')
134137
)
135138
all_disabled_tools.update(
136-
validate_tools(
137-
disabled_tools_from_categories, tool_registry_lower, 'disabled_categories'
138-
)
139+
validate_tools(disabled_tools_from_categories, display_name, 'disabled_categories')
139140
)
140141
all_disabled_tools.update(
141-
validate_tools(
142-
disabled_tools_from_regex, tool_registry_lower, 'disabled_tools_regex'
143-
)
142+
validate_tools(disabled_tools_from_regex, display_name, 'disabled_tools_regex')
144143
)
145144

146145
# Remove tools in the disabled list
@@ -150,8 +149,9 @@ def process_tool_filter(
150149

151150
# Log results
152151
source = filter_path if filter_path else 'environment variables'
152+
tool_display_names = [tool['display_name'] for tool in tool_registry.values()]
153153
logging.info(f'Applied tool filter from {source}')
154-
logging.info(f'Available tools after filtering: {list(tool_registry.keys())}')
154+
logging.info(f'Available tools after filtering: {tool_display_names}')
155155

156156
except Exception as e:
157157
logging.error(f'Error processing tool filter: {str(e)}')
@@ -201,9 +201,10 @@ def get_tools(tool_registry: dict, mode: str = 'single', config_file_path: str =
201201
**{k: v for k, v in env_config.items() if not config_file_path},
202202
)
203203

204-
for name, info in TOOL_REGISTRY.items():
204+
for name, info in tool_registry.items():
205205
# Create a copy to avoid modifying the original tool info
206206
tool_info = info.copy()
207+
tool_name = tool_info['display_name']
207208

208209
# If tool is not compatible with the current OpenSearch version, skip, don't enable
209210
if not is_tool_compatible(version, info):
@@ -218,6 +219,6 @@ def get_tools(tool_registry: dict, mode: str = 'single', config_file_path: str =
218219
schema['properties'].pop(field, None)
219220
tool_info['input_schema'] = schema
220221

221-
enabled[name] = tool_info
222+
enabled[tool_name] = tool_info
222223

223224
return enabled

src/tools/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,15 @@ def load_yaml_config(filter_path):
5151
return None
5252

5353

54-
def validate_tools(tool_list, registry_lookup, source_name):
54+
def validate_tools(tool_list, display_lookup, source_name):
5555
"""Validate tools against registry and return valid tools."""
5656
valid_tools = set()
5757
for tool in tool_list:
5858
tool_lower = tool.lower()
59-
if tool_lower in registry_lookup:
60-
valid_tools.add(tool_lower)
59+
# Check if it matches tool display name
60+
if tool_lower in display_lookup:
61+
actual_tool = display_lookup[tool_lower]
62+
valid_tools.add(actual_tool.lower())
6163
else:
6264
logging.warning(f"Ignoring invalid tool from '{source_name}': '{tool}'")
6365
return valid_tools

tests/tools/test_tool_filters.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# A dictionary for mocking TOOL_REGISTRY
1010
MOCK_TOOL_REGISTRY = {
1111
'ListIndexTool': {
12+
'display_name': 'ListIndexTool',
1213
'description': 'List indices',
1314
'input_schema': {'type': 'object', 'properties': {'param1': {'type': 'string'}}},
1415
'function': MagicMock(),
@@ -17,6 +18,7 @@
1718
'max_version': '3.0.0',
1819
},
1920
'SearchIndexTool': {
21+
'display_name': 'SearchIndexTool',
2022
'description': 'Search an index',
2123
'input_schema': {
2224
'type': 'object',
@@ -162,6 +164,7 @@ def test_get_tools_single_mode_handles_missing_properties(self, mock_patches):
162164
# Create tool with missing properties
163165
tool_without_properties = {
164166
'ListIndexTool': {
167+
'display_name': 'ListIndexTool',
165168
'description': 'List indices',
166169
'input_schema': {'type': 'object', 'title': 'ListIndexArgs'},
167170
'function': MagicMock(),
@@ -213,13 +216,15 @@ class TestProcessToolFilter:
213216
def setup_method(self):
214217
"""Set up a fresh copy of the tool registry for each test."""
215218
self.tool_registry = {
216-
'ListIndexTool': {'http_methods': 'GET'},
217-
'SearchIndexTool': {'http_methods': 'GET, POST'},
218-
'MsearchTool': {'http_methods': 'GET, POST'},
219-
'ExplainTool': {'http_methods': 'GET, POST'},
220-
'ClusterHealthTool': {'http_methods': 'GET'},
221-
'IndicesCreateTool': {'http_methods': 'PUT'},
222-
'IndicesStatsTool': {'http_methods': 'GET'},
219+
'ListIndexTool': {'display_name': 'ListIndexTool', 'http_methods': 'GET'},
220+
'SearchIndexTool': {'display_name': 'SearchIndexTool', 'http_methods': 'GET, POST'},
221+
'MsearchTool': {'display_name': 'MsearchTool', 'http_methods': 'GET, POST'},
222+
'ExplainTool': {'display_name': 'ExplainTool', 'http_methods': 'GET, POST'},
223+
'ClusterHealthTool': {'display_name': 'ClusterHealthTool', 'http_methods': 'GET'},
224+
'IndicesCreateTool': {'display_name': 'IndicesCreateTool', 'http_methods': 'PUT'},
225+
'IndicesStatsTool': {'display_name': 'IndicesStatsTool', 'http_methods': 'GET'},
226+
'CountTool': {'display_name': 'CustomCountTool', 'http_methods': 'GET'},
227+
'ListModelTool': {'display_name': 'ModelListTool', 'http_methods': 'GET'},
223228
}
224229
self.category_to_tools = {
225230
'critical': ['SearchIndexTool', 'ExplainTool'],
@@ -269,3 +274,23 @@ def test_process_tool_filter_env(self, caplog):
269274
assert 'MsearchTool' in self.tool_registry
270275
assert 'SearchIndexTool' not in self.tool_registry # In disabled_tools_regex
271276
assert 'ExplainTool' not in self.tool_registry # In disabled_tools
277+
278+
def test_process_tool_filter_rename_tool(self):
279+
"""Test processing tool filtering with tool renaming feature"""
280+
process_tool_filter(
281+
tool_registry=self.tool_registry,
282+
disabled_tools='CountTool',
283+
disabled_tools_regex='list.*',
284+
allow_write=True
285+
)
286+
assert 'CountTool' in self.tool_registry # Renamed to CustomCountTool
287+
assert 'ListModelTool' in self.tool_registry # Renamed to ModelListTool
288+
289+
process_tool_filter(
290+
tool_registry=self.tool_registry,
291+
disabled_tools='CustomCountTool',
292+
disabled_tools_regex='model.*',
293+
allow_write=True
294+
)
295+
assert 'CustomCountTool' not in self.tool_registry
296+
assert 'ModelListTool' not in self.tool_registry

0 commit comments

Comments
 (0)