Skip to content

Commit d0d23bb

Browse files
Address comments
Signed-off-by: Nathalie Jonathan <[email protected]>
1 parent d416d95 commit d0d23bb

File tree

16 files changed

+123
-308
lines changed

16 files changed

+123
-308
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/tool/MLToolExecutor.java

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.opensearch.common.settings.Settings;
1515
import org.opensearch.core.action.ActionListener;
1616
import org.opensearch.core.xcontent.NamedXContentRegistry;
17+
import org.opensearch.ingest.ConfigurationUtils;
1718
import org.opensearch.ml.common.FunctionName;
1819
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
1920
import org.opensearch.ml.common.input.Input;
@@ -84,14 +85,8 @@ public void execute(Input input, ActionListener<Output> listener) {
8485

8586
try {
8687
Map<String, String> mutableParams = new HashMap<>(parameters);
87-
Tool tool = toolFactory.create(mutableParams);
88-
89-
// Validate original parameter types
90-
Map<String, Object> originalParameters = toolMLInput.getOriginalParameters();
91-
if (originalParameters != null && !tool.validateParameterTypes(originalParameters)) {
92-
listener.onFailure(new IllegalArgumentException("Invalid parameter types for tool: " + toolName));
93-
return;
94-
}
88+
Map<String, Object> originalParams = toolMLInput.getOriginalParameters();
89+
Tool tool = toolFactory.create(originalParams != null ? originalParams : new HashMap<>(mutableParams));
9590

9691
if (!tool.validate(mutableParams)) {
9792
listener.onFailure(new IllegalArgumentException("Invalid parameters for tool: " + toolName));
@@ -136,4 +131,50 @@ private void processOutput(Object output, List<ModelTensor> modelTensors) {
136131
modelTensors.add(ModelTensor.builder().name("response").result(result).build());
137132
}
138133
}
134+
135+
private void validateParameterTypes(String toolName, Map<String, String> parameters, Tool.Factory toolFactory) {
136+
try {
137+
for (Map.Entry<String, String> entry : parameters.entrySet()) {
138+
String paramName = entry.getKey();
139+
String paramValue = entry.getValue();
140+
141+
if (!isValidParameterValue(paramName, paramValue)) {
142+
throw new IllegalArgumentException(
143+
"Invalid parameter value for '" + paramName + "' in tool '" + toolName + "': " + paramValue
144+
);
145+
}
146+
}
147+
} catch (Exception e) {
148+
log.error("Could not validate parameter types for tool: " + toolName, e);
149+
if (e instanceof IllegalArgumentException) {
150+
throw e;
151+
}
152+
}
153+
}
154+
155+
private boolean isValidParameterValue(String paramName, String paramValue) {
156+
if (paramValue == null || paramValue.trim().isEmpty()) {
157+
return false;
158+
}
159+
160+
try {
161+
Map<String, Object> testConfig = new HashMap<>();
162+
testConfig.put(paramName, paramValue);
163+
164+
if (paramValue.matches("^-?\\d+$")) {
165+
ConfigurationUtils.readIntProperty("tool", "test", testConfig, paramName, 0);
166+
} else if ("true".equalsIgnoreCase(paramValue) || "false".equalsIgnoreCase(paramValue)) {
167+
ConfigurationUtils.readBooleanProperty("tool", "test", testConfig, paramName, false);
168+
} else if (paramValue.matches("^-?\\d*\\.\\d+$")) {
169+
ConfigurationUtils.readDoubleProperty("tool", "test", testConfig, paramName);
170+
} else if (paramValue.startsWith("[") && paramValue.endsWith("]")) {
171+
ConfigurationUtils.readOptionalList("tool", "test", testConfig, paramName);
172+
} else if (paramValue.startsWith("{") && paramValue.endsWith("}")) {
173+
ConfigurationUtils.readOptionalMap("tool", "test", testConfig, paramName);
174+
}
175+
return true;
176+
} catch (Exception e) {
177+
return false;
178+
}
179+
}
139180
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import org.opensearch.action.ActionRequest;
1313
import org.opensearch.core.action.ActionListener;
14+
import org.opensearch.ingest.ConfigurationUtils;
1415
import org.opensearch.ml.common.FunctionName;
1516
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
1617
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
@@ -36,6 +37,9 @@
3637
@ToolAnnotation(AgentTool.TYPE)
3738
public class AgentTool implements Tool {
3839
public static final String TYPE = "AgentTool";
40+
public static final int DEFAULT_MAX_QUESTION_LENGTH = 10000;
41+
public static final String MAX_QUESTION_LENGTH_FIELD = "max_question_length";
42+
private int maxQuestionLength = DEFAULT_MAX_QUESTION_LENGTH;
3943
private final Client client;
4044

4145
@Setter
@@ -123,21 +127,8 @@ public boolean validate(Map<String, String> parameters) {
123127

124128
// Validate question length
125129
String question = parameters.get("question");
126-
if (question != null && question.length() > 10000) {
127-
throw new IllegalArgumentException("question length cannot exceed 10000 characters");
128-
}
129-
130-
return true;
131-
}
132-
133-
@Override
134-
public boolean validateParameterTypes(Map<String, Object> parameters) {
135-
// Validate question must be String
136-
Object questionObj = parameters.get("question");
137-
if (questionObj != null && !(questionObj instanceof String)) {
138-
throw new IllegalArgumentException(
139-
String.format("question must be a String type, but got %s", questionObj.getClass().getSimpleName())
140-
);
130+
if (question != null && question.length() > maxQuestionLength) {
131+
throw new IllegalArgumentException("question length cannot exceed " + maxQuestionLength + " characters");
141132
}
142133
return true;
143134
}
@@ -166,8 +157,11 @@ public void init(Client client) {
166157

167158
@Override
168159
public AgentTool create(Map<String, Object> params) {
160+
ConfigurationUtils.readStringProperty(TYPE, null, params, "question");
169161
AgentTool agentTool = new AgentTool(client, (String) params.get("agent_id"));
170162
agentTool.setOutputParser(ToolParser.createFromToolParams(params));
163+
agentTool.maxQuestionLength = ConfigurationUtils
164+
.readIntProperty(TYPE, null, params, MAX_QUESTION_LENGTH_FIELD, DEFAULT_MAX_QUESTION_LENGTH);
171165
return agentTool;
172166
}
173167

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ConnectorTool.java

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.apache.commons.lang3.StringUtils;
1111
import org.opensearch.action.ActionRequest;
1212
import org.opensearch.core.action.ActionListener;
13+
import org.opensearch.ingest.ConfigurationUtils;
1314
import org.opensearch.ml.common.FunctionName;
1415
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
1516
import org.opensearch.ml.common.input.MLInput;
@@ -100,18 +101,6 @@ public boolean validate(Map<String, String> parameters) {
100101
return parameters != null && !parameters.isEmpty();
101102
}
102103

103-
@Override
104-
public boolean validateParameterTypes(Map<String, Object> parameters) {
105-
// Validate response_filter must be String
106-
Object responseFilterObj = parameters.get("response_filter");
107-
if (responseFilterObj != null && !(responseFilterObj instanceof String)) {
108-
throw new IllegalArgumentException(
109-
String.format("response_filter must be a String type, but got %s", responseFilterObj.getClass().getSimpleName())
110-
);
111-
}
112-
return true;
113-
}
114-
115104
public static class Factory implements Tool.Factory<ConnectorTool> {
116105
public static final String TYPE = "ConnectorTool";
117106
public static final String DEFAULT_DESCRIPTION = "Invokes external service. Required: 'connector_id'. Returns: service response.";
@@ -137,6 +126,7 @@ public void init(Client client) {
137126

138127
@Override
139128
public ConnectorTool create(Map<String, Object> params) {
129+
ConfigurationUtils.readOptionalStringProperty(TYPE, null, params, "response_filter");
140130
ConnectorTool connectorTool = new ConnectorTool(client, (String) params.get(CONNECTOR_ID));
141131
connectorTool.setOutputParser(ToolParser.createFromToolParams(params));
142132
return connectorTool;

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/IndexInsightTool.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,6 @@ public boolean validate(Map<String, String> parameters) {
113113
return true;
114114
}
115115

116-
@Override
117-
public boolean validateParameterTypes(Map<String, Object> parameters) {
118-
return true;
119-
}
120-
121116
public static class Factory implements Tool.Factory<IndexInsightTool> {
122117
private Client client;
123118

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/IndexMappingTool.java

Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.opensearch.common.settings.Settings;
2525
import org.opensearch.common.unit.TimeValue;
2626
import org.opensearch.core.action.ActionListener;
27+
import org.opensearch.ingest.ConfigurationUtils;
2728
import org.opensearch.ml.common.spi.tools.Parser;
2829
import org.opensearch.ml.common.spi.tools.Tool;
2930
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
@@ -55,6 +56,9 @@ public class IndexMappingTool implements Tool {
5556
+ "\"required\":[\"index\"],"
5657
+ "\"additionalProperties\":false}";
5758
public static final Map<String, Object> DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, STRICT_FIELD, true);
59+
public static final int DEFAULT_MAX_QUESTION_LENGTH = 10000;
60+
public static final String MAX_QUESTION_LENGTH_FIELD = "max_question_length";
61+
private int maxQuestionLength = DEFAULT_MAX_QUESTION_LENGTH;
5862

5963
@Setter
6064
@Getter
@@ -181,41 +185,13 @@ public boolean validate(Map<String, String> parameters) {
181185

182186
// Validate question length
183187
String question = parameters.get("question");
184-
if (question != null && question.length() > 10000) {
185-
throw new IllegalArgumentException("question length cannot exceed 10000 characters");
188+
if (question != null && question.length() > maxQuestionLength) {
189+
throw new IllegalArgumentException("question length cannot exceed " + maxQuestionLength + " characters");
186190
}
187191

188192
return true;
189193
}
190194

191-
@Override
192-
public boolean validateParameterTypes(Map<String, Object> parameters) {
193-
// Validate question must be String
194-
Object questionObj = parameters.get("question");
195-
if (questionObj != null && !(questionObj instanceof String)) {
196-
throw new IllegalArgumentException(
197-
String.format("question must be a String type, but got %s", questionObj.getClass().getSimpleName())
198-
);
199-
}
200-
201-
// Validate index must be ArrayList
202-
Object indexObj = parameters.get("index");
203-
if (indexObj != null && !(indexObj instanceof ArrayList)) {
204-
throw new IllegalArgumentException(
205-
String.format("index must be an Array type, but got %s", indexObj.getClass().getSimpleName())
206-
);
207-
}
208-
209-
// Validate local must be Boolean
210-
Object localObj = parameters.get("local");
211-
if (localObj != null && !(localObj instanceof Boolean)) {
212-
throw new IllegalArgumentException(
213-
String.format("local must be a Boolean type, but got %s", localObj.getClass().getSimpleName())
214-
);
215-
}
216-
return true;
217-
}
218-
219195
/**
220196
* Factory for the {@link IndexMappingTool}
221197
*/
@@ -250,8 +226,14 @@ public void init(Client client) {
250226

251227
@Override
252228
public IndexMappingTool create(Map<String, Object> params) {
229+
ConfigurationUtils.readStringProperty(TYPE, null, params, "question");
230+
ConfigurationUtils.readOptionalList(TYPE, null, params, "index");
231+
ConfigurationUtils.readBooleanProperty(TYPE, null, params, "local", false);
232+
253233
IndexMappingTool indexMappingTool = new IndexMappingTool(client);
254234
indexMappingTool.setOutputParser(ToolParser.createFromToolParams(params));
235+
indexMappingTool.maxQuestionLength = ConfigurationUtils
236+
.readIntProperty(TYPE, null, params, MAX_QUESTION_LENGTH_FIELD, DEFAULT_MAX_QUESTION_LENGTH);
255237
return indexMappingTool;
256238
}
257239

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ListIndexTool.java

Lines changed: 13 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import org.opensearch.core.action.ActionListener;
5555
import org.opensearch.core.action.ActionResponse;
5656
import org.opensearch.index.IndexSettings;
57+
import org.opensearch.ingest.ConfigurationUtils;
5758
import org.opensearch.ml.common.spi.tools.Parser;
5859
import org.opensearch.ml.common.spi.tools.Tool;
5960
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
@@ -86,6 +87,9 @@ public class ListIndexTool implements Tool {
8687
+ "for example: [\\\"index1\\\", \\\"index2\\\"], use empty array [] to list all indices in the cluster\"}},"
8788
+ "\"additionalProperties\":false}";
8889
public static final Map<String, Object> DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, STRICT_FIELD, false);
90+
public static final int DEFAULT_MAX_QUESTION_LENGTH = 10000;
91+
public static final String MAX_QUESTION_LENGTH_FIELD = "max_question_length";
92+
private int maxQuestionLength = DEFAULT_MAX_QUESTION_LENGTH;
8993

9094
@Setter
9195
@Getter
@@ -421,45 +425,8 @@ public boolean validate(Map<String, String> parameters) {
421425

422426
// Validate question length
423427
String question = parameters.get("question");
424-
if (question != null && question.length() > 10000) {
425-
throw new IllegalArgumentException("question length cannot exceed 10000 characters");
426-
}
427-
428-
return true;
429-
}
430-
431-
@Override
432-
public boolean validateParameterTypes(Map<String, Object> parameters) {
433-
// Validate question must be String
434-
Object questionObj = parameters.get("question");
435-
if (questionObj != null && !(questionObj instanceof String)) {
436-
throw new IllegalArgumentException(
437-
String.format("question must be a String type, but got %s", questionObj.getClass().getSimpleName())
438-
);
439-
}
440-
441-
// Validate indices must be ArrayList
442-
Object indicesObj = parameters.get("indices");
443-
if (indicesObj != null && !(indicesObj instanceof ArrayList)) {
444-
throw new IllegalArgumentException(
445-
String.format("indices must be an Array type, but got %s", indicesObj.getClass().getSimpleName())
446-
);
447-
}
448-
449-
// Validate local must be Boolean
450-
Object localObj = parameters.get("local");
451-
if (localObj != null && !(localObj instanceof Boolean)) {
452-
throw new IllegalArgumentException(
453-
String.format("local must be a Boolean type, but got %s", localObj.getClass().getSimpleName())
454-
);
455-
}
456-
457-
// Validate page_size must be Integer
458-
Object pageSizeObj = parameters.get("page_size");
459-
if (pageSizeObj != null && !(pageSizeObj instanceof Integer)) {
460-
throw new IllegalArgumentException(
461-
String.format("page_size must be an Integer type, but got %s", pageSizeObj.getClass().getSimpleName())
462-
);
428+
if (question != null && question.length() > maxQuestionLength) {
429+
throw new IllegalArgumentException("question length cannot exceed " + maxQuestionLength + " characters");
463430
}
464431
return true;
465432
}
@@ -501,8 +468,15 @@ public void init(Client client, ClusterService clusterService) {
501468

502469
@Override
503470
public ListIndexTool create(Map<String, Object> params) {
471+
ConfigurationUtils.readStringProperty(TYPE, null, params, "question");
472+
ConfigurationUtils.readOptionalList(TYPE, null, params, "indices");
473+
ConfigurationUtils.readBooleanProperty(TYPE, null, params, "local", false);
474+
ConfigurationUtils.readIntProperty(TYPE, null, params, "page_size", 100);
475+
504476
ListIndexTool tool = new ListIndexTool(client, clusterService);
505477
tool.setOutputParser(ToolParser.createFromToolParams(params));
478+
tool.maxQuestionLength = ConfigurationUtils
479+
.readIntProperty(TYPE, null, params, MAX_QUESTION_LENGTH_FIELD, DEFAULT_MAX_QUESTION_LENGTH);
506480
return tool;
507481
}
508482

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import org.opensearch.action.ActionRequest;
1414
import org.opensearch.core.action.ActionListener;
15+
import org.opensearch.ingest.ConfigurationUtils;
1516
import org.opensearch.ml.common.FunctionName;
1617
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
1718
import org.opensearch.ml.common.input.MLInput;
@@ -148,26 +149,6 @@ public boolean validate(Map<String, String> parameters) {
148149
return parameters != null && !parameters.isEmpty();
149150
}
150151

151-
@Override
152-
public boolean validateParameterTypes(Map<String, Object> parameters) {
153-
// Validate prompt must be String
154-
Object promptObj = parameters.get("prompt");
155-
if (promptObj != null && !(promptObj instanceof String)) {
156-
throw new IllegalArgumentException(
157-
String.format("prompt must be a String type, but got %s", promptObj.getClass().getSimpleName())
158-
);
159-
}
160-
161-
// Validate response_field must be String
162-
Object responseFieldObj = parameters.get(RESPONSE_FIELD);
163-
if (responseFieldObj != null && !(responseFieldObj instanceof String)) {
164-
throw new IllegalArgumentException(
165-
String.format("%s must be a String type, but got %s", RESPONSE_FIELD, responseFieldObj.getClass().getSimpleName())
166-
);
167-
}
168-
return true;
169-
}
170-
171152
public static class Factory implements WithModelTool.Factory<MLModelTool> {
172153
private Client client;
173154

@@ -192,6 +173,8 @@ public void init(Client client) {
192173

193174
@Override
194175
public MLModelTool create(Map<String, Object> map) {
176+
ConfigurationUtils.readOptionalStringProperty(TYPE, null, map, "prompt");
177+
ConfigurationUtils.readOptionalStringProperty(TYPE, null, map, "response_field");
195178
String modelId = (String) map.get(MODEL_ID_FIELD);
196179
String responseField = (String) map.getOrDefault(RESPONSE_FIELD, DEFAULT_RESPONSE_FIELD);
197180

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/McpSseTool.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,6 @@ public boolean validate(Map<String, String> parameters) {
9999
return true;
100100
}
101101

102-
@Override
103-
public boolean validateParameterTypes(Map<String, Object> parameters) {
104-
return true;
105-
}
106-
107102
public static class Factory implements WithModelTool.Factory<McpSseTool> {
108103
private static Factory INSTANCE;
109104

0 commit comments

Comments
 (0)