diff --git a/common/src/main/java/org/opensearch/ml/common/input/execute/tool/ToolMLInput.java b/common/src/main/java/org/opensearch/ml/common/input/execute/tool/ToolMLInput.java index 5acf865ff9..df04db90a3 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/execute/tool/ToolMLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/execute/tool/ToolMLInput.java @@ -33,6 +33,9 @@ public class ToolMLInput extends MLInput { @Setter private String toolName; + @Getter + private Map originalParameters; + public ToolMLInput(StreamInput in) throws IOException { super(in); this.toolName = in.readString(); @@ -66,7 +69,9 @@ public ToolMLInput(XContentParser parser, FunctionName functionName) throws IOEx toolName = parser.text(); break; case PARAMETERS_FIELD: - Map parameters = StringUtils.getParameterMap(parser.map()); + Map rawParams = parser.map(); + originalParameters = rawParams; + Map parameters = StringUtils.getParameterMap(rawParams); inputDataset = new RemoteInferenceInputDataSet(parameters); break; default: diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index 7de547127a..ef4cfadd5e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -996,7 +996,10 @@ public static Tool createTool(Map toolFactories, Map toolParams = new HashMap<>(); - toolParams.putAll(executeParams); + // Parse JSON strings back to original type since we need to validate each parameter type when creating tool + for (Map.Entry entry : executeParams.entrySet()) { + toolParams.put(entry.getKey(), parseValue(entry.getValue())); + } Map runtimeResources = toolSpec.getRuntimeResources(); if (runtimeResources != null) { toolParams.putAll(runtimeResources); @@ -1014,4 +1017,32 @@ public static Tool createTool(Map toolFactories, Map listener) { try { Map mutableParams = new HashMap<>(parameters); - Tool tool = toolFactory.create(mutableParams); + Map originalParams = toolMLInput.getOriginalParameters(); + Tool tool = toolFactory.create(originalParams); + if (!tool.validate(mutableParams)) { listener.onFailure(new IllegalArgumentException("Invalid parameters for tool: " + toolName)); return; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java index 39f10530b8..856b835805 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/AgentTool.java @@ -11,6 +11,7 @@ import org.opensearch.action.ActionRequest; import org.opensearch.core.action.ActionListener; +import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.execute.agent.AgentMLInput; @@ -36,6 +37,9 @@ @ToolAnnotation(AgentTool.TYPE) public class AgentTool implements Tool { public static final String TYPE = "AgentTool"; + public static final int DEFAULT_MAX_QUESTION_LENGTH = 10000; + public static final String MAX_QUESTION_LENGTH_FIELD = "max_question_length"; + private int maxQuestionLength = DEFAULT_MAX_QUESTION_LENGTH; private final Client client; @Setter @@ -117,6 +121,15 @@ public void setName(String s) { @Override public boolean validate(Map parameters) { + if (parameters == null || parameters.isEmpty()) { + return false; + } + + // Validate question length + String question = parameters.get("question"); + if (question != null && question.length() > maxQuestionLength) { + throw new IllegalArgumentException("question length cannot exceed " + maxQuestionLength + " characters"); + } return true; } @@ -144,8 +157,11 @@ public void init(Client client) { @Override public AgentTool create(Map params) { + ConfigurationUtils.readStringProperty(TYPE, null, params, "question"); AgentTool agentTool = new AgentTool(client, (String) params.get("agent_id")); agentTool.setOutputParser(ToolParser.createFromToolParams(params)); + agentTool.maxQuestionLength = ConfigurationUtils + .readIntProperty(TYPE, null, params, MAX_QUESTION_LENGTH_FIELD, DEFAULT_MAX_QUESTION_LENGTH); return agentTool; } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ConnectorTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ConnectorTool.java index 7f165de28c..b090ee8a27 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ConnectorTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ConnectorTool.java @@ -10,6 +10,7 @@ import org.apache.commons.lang3.StringUtils; import org.opensearch.action.ActionRequest; import org.opensearch.core.action.ActionListener; +import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; @@ -125,6 +126,7 @@ public void init(Client client) { @Override public ConnectorTool create(Map params) { + ConfigurationUtils.readOptionalStringProperty(TYPE, null, params, "response_filter"); ConnectorTool connectorTool = new ConnectorTool(client, (String) params.get(CONNECTOR_ID)); connectorTool.setOutputParser(ToolParser.createFromToolParams(params)); return connectorTool; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/IndexMappingTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/IndexMappingTool.java index 6b359dfa19..bc97686b62 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/IndexMappingTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/IndexMappingTool.java @@ -24,6 +24,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.core.action.ActionListener; +import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.ml.common.spi.tools.Parser; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; @@ -55,6 +56,9 @@ public class IndexMappingTool implements Tool { + "\"required\":[\"index\"]," + "\"additionalProperties\":false}"; public static final Map DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, STRICT_FIELD, true); + public static final int DEFAULT_MAX_QUESTION_LENGTH = 10000; + public static final String MAX_QUESTION_LENGTH_FIELD = "max_question_length"; + private int maxQuestionLength = DEFAULT_MAX_QUESTION_LENGTH; @Setter @Getter @@ -175,7 +179,17 @@ public String getType() { @Override public boolean validate(Map parameters) { - return parameters != null && !parameters.isEmpty() && parameters.containsKey("index"); + if (parameters == null || parameters.isEmpty() || !parameters.containsKey("index")) { + return false; + } + + // Validate question length + String question = parameters.get("question"); + if (question != null && question.length() > maxQuestionLength) { + throw new IllegalArgumentException("question length cannot exceed " + maxQuestionLength + " characters"); + } + + return true; } /** @@ -212,8 +226,14 @@ public void init(Client client) { @Override public IndexMappingTool create(Map params) { + ConfigurationUtils.readOptionalStringProperty(TYPE, null, params, "question"); + ConfigurationUtils.readOptionalList(TYPE, null, params, "index"); + ConfigurationUtils.readBooleanProperty(TYPE, null, params, "local", false); + IndexMappingTool indexMappingTool = new IndexMappingTool(client); indexMappingTool.setOutputParser(ToolParser.createFromToolParams(params)); + indexMappingTool.maxQuestionLength = ConfigurationUtils + .readIntProperty(TYPE, null, params, MAX_QUESTION_LENGTH_FIELD, DEFAULT_MAX_QUESTION_LENGTH); return indexMappingTool; } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ListIndexTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ListIndexTool.java index b2de15ac2e..c547428a8a 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ListIndexTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ListIndexTool.java @@ -54,6 +54,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; import org.opensearch.index.IndexSettings; +import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.ml.common.spi.tools.Parser; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; @@ -86,6 +87,9 @@ public class ListIndexTool implements Tool { + "for example: [\\\"index1\\\", \\\"index2\\\"], use empty array [] to list all indices in the cluster\"}}," + "\"additionalProperties\":false}"; public static final Map DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, STRICT_FIELD, false); + public static final int DEFAULT_MAX_QUESTION_LENGTH = 10000; + public static final String MAX_QUESTION_LENGTH_FIELD = "max_question_length"; + private int maxQuestionLength = DEFAULT_MAX_QUESTION_LENGTH; @Setter @Getter @@ -415,7 +419,16 @@ public void onFailure(final Exception e) { @Override public boolean validate(Map parameters) { - return parameters != null && !parameters.isEmpty(); + if (parameters == null || parameters.isEmpty()) { + return false; + } + + // Validate question length + String question = parameters.get("question"); + if (question != null && question.length() > maxQuestionLength) { + throw new IllegalArgumentException("question length cannot exceed " + maxQuestionLength + " characters"); + } + return true; } /** @@ -455,8 +468,15 @@ public void init(Client client, ClusterService clusterService) { @Override public ListIndexTool create(Map params) { + ConfigurationUtils.readOptionalStringProperty(TYPE, null, params, "question"); + ConfigurationUtils.readOptionalList(TYPE, null, params, "indices"); + ConfigurationUtils.readBooleanProperty(TYPE, null, params, "local", false); + ConfigurationUtils.readIntProperty(TYPE, null, params, "page_size", 100); + ListIndexTool tool = new ListIndexTool(client, clusterService); tool.setOutputParser(ToolParser.createFromToolParams(params)); + tool.maxQuestionLength = ConfigurationUtils + .readIntProperty(TYPE, null, params, MAX_QUESTION_LENGTH_FIELD, DEFAULT_MAX_QUESTION_LENGTH); return tool; } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java index cbd3c3e21b..5e66060ec4 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java @@ -12,6 +12,7 @@ import org.opensearch.action.ActionRequest; import org.opensearch.core.action.ActionListener; +import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; @@ -172,6 +173,8 @@ public void init(Client client) { @Override public MLModelTool create(Map map) { + ConfigurationUtils.readOptionalStringProperty(TYPE, null, map, "prompt"); + ConfigurationUtils.readOptionalStringProperty(TYPE, null, map, "response_field"); String modelId = (String) map.get(MODEL_ID_FIELD); String responseField = (String) map.getOrDefault(RESPONSE_FIELD, DEFAULT_RESPONSE_FIELD); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java index 26bc77a053..d3bf1b6691 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java @@ -36,6 +36,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.ml.common.spi.tools.Parser; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.spi.tools.WithModelTool; @@ -121,6 +122,9 @@ public class QueryPlanningTool implements WithModelTool { + "}"; public static final Map DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, STRICT_FIELD, false); + public static final int DEFAULT_MAX_QUESTION_LENGTH = 10000; + public static final String MAX_QUESTION_LENGTH_FIELD = "max_question_length"; + private int maxQuestionLength = DEFAULT_MAX_QUESTION_LENGTH; @Getter @Setter @@ -394,6 +398,12 @@ public boolean validate(Map parameters) { || !parameters.containsKey(INDEX_NAME_FIELD)) { return false; } + + // Validate question length + String question = parameters.get(QUESTION_FIELD); + if (question != null && question.length() > maxQuestionLength) { + throw new IllegalArgumentException("question length cannot exceed " + maxQuestionLength + " characters"); + } return true; } @@ -420,6 +430,14 @@ public void init(Client client) { @Override public QueryPlanningTool create(Map params) { + ConfigurationUtils.readStringProperty(TYPE, null, params, QUESTION_FIELD); + ConfigurationUtils.readStringProperty(TYPE, null, params, INDEX_NAME_FIELD); + ConfigurationUtils.readOptionalStringProperty(TYPE, null, params, GENERATION_TYPE_FIELD); + ConfigurationUtils.readOptionalStringProperty(TYPE, null, params, QUERY_PLANNER_SYSTEM_PROMPT_FIELD); + ConfigurationUtils.readOptionalStringProperty(TYPE, null, params, QUERY_PLANNER_USER_PROMPT_FIELD); + ConfigurationUtils.readOptionalStringProperty(TYPE, null, params, "embedding_model_id"); + ConfigurationUtils.readOptionalStringProperty(TYPE, null, params, "response_filter"); + ConfigurationUtils.readOptionalList(TYPE, null, params, SEARCH_TEMPLATES_FIELD); MLModelTool queryGenerationTool = MLModelTool.Factory.getInstance().create(params); @@ -455,6 +473,8 @@ public QueryPlanningTool create(Map params) { // Create parser with default extract_json processor + any custom processors queryPlanningTool.setOutputParser(createParserWithDefaultExtractJson(params)); + queryPlanningTool.maxQuestionLength = ConfigurationUtils + .readIntProperty(TYPE, null, params, MAX_QUESTION_LENGTH_FIELD, DEFAULT_MAX_QUESTION_LENGTH); return queryPlanningTool; } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ReadFromScratchPadTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ReadFromScratchPadTool.java index 70ca8a1072..780e8c9404 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ReadFromScratchPadTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/ReadFromScratchPadTool.java @@ -13,6 +13,7 @@ import java.util.Map; import org.opensearch.core.action.ActionListener; +import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.utils.StringUtils; @@ -137,6 +138,7 @@ public void init() {} @Override public ReadFromScratchPadTool create(Map params) { + ConfigurationUtils.readOptionalStringProperty(TYPE, null, params, PERSISTENT_NOTES_KEY); return new ReadFromScratchPadTool(); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java index 736cdc5a53..8fda79e9f5 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java @@ -27,6 +27,7 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -296,6 +297,7 @@ public void init(Client client, NamedXContentRegistry xContentRegistry) { @Override public SearchIndexTool create(Map params) { + ConfigurationUtils.readStringProperty(TYPE, null, params, INPUT_FIELD); SearchIndexTool tool = new SearchIndexTool(client, xContentRegistry); // Enhance the output parser with processors if configured tool.setOutputParser(ToolParser.createFromToolParams(params)); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VisualizationsTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VisualizationsTool.java index e2ba02c99d..15e02aa6ff 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VisualizationsTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/VisualizationsTool.java @@ -21,6 +21,7 @@ import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.utils.ToolUtils; @@ -43,6 +44,9 @@ public class VisualizationsTool implements Tool { public static final String SAVED_OBJECT_TYPE = "visualization"; public static final String STRICT_FIELD = "strict"; + public static final String INPUT_FIELD = "input"; + public static final String INDEX_FIELD = "index"; + public static final String SIZE_FIELD = "size"; /** * default number of visualizations returned @@ -55,6 +59,10 @@ public class VisualizationsTool implements Tool { + "\"required\":[\"input\"]," + "\"additionalProperties\":false}"; public static final Map DEFAULT_ATTRIBUTES = Map.of(TOOL_INPUT_SCHEMA_FIELD, DEFAULT_INPUT_SCHEMA, STRICT_FIELD, false); + public static final int DEFAULT_MAX_INPUT_LENGTH = 10000; + public static final String MAX_INPUT_LENGTH_FIELD = "max_input_length"; + private int maxInputLength = DEFAULT_MAX_INPUT_LENGTH; + @Setter @Getter private String description = DEFAULT_DESCRIPTION; @@ -93,7 +101,7 @@ public void run(Map originalParameters, ActionListener li Map parameters = ToolUtils.extractInputParameters(originalParameters, attributes); BoolQueryBuilder boolQueryBuilder = QueryBuilders.boolQuery(); boolQueryBuilder.must().add(QueryBuilders.termQuery("type", SAVED_OBJECT_TYPE)); - boolQueryBuilder.must().add(QueryBuilders.matchQuery(SAVED_OBJECT_TYPE + ".title", parameters.get("input"))); + boolQueryBuilder.must().add(QueryBuilders.matchQuery(SAVED_OBJECT_TYPE + ".title", parameters.get(INPUT_FIELD))); SearchSourceBuilder searchSourceBuilder = SearchSourceBuilder.searchSource().query(boolQueryBuilder); searchSourceBuilder.from(0).size(size); @@ -146,10 +154,19 @@ String trimIdPrefix(String id) { @Override public boolean validate(Map parameters) { - return parameters != null - && !parameters.isEmpty() - && !Strings.isNullOrEmpty(parameters.get("input")) - && parameters.containsKey("input"); + if (parameters == null + || parameters.isEmpty() + || Strings.isNullOrEmpty(parameters.get(INPUT_FIELD)) + || !parameters.containsKey(INPUT_FIELD)) { + return false; + } + + // Validate input length + String input = parameters.get(INPUT_FIELD); + if (input != null && input.length() > maxInputLength) { + throw new IllegalArgumentException("input length cannot exceed " + maxInputLength + " characters"); + } + return true; } public static class Factory implements Tool.Factory { @@ -176,15 +193,20 @@ public void init(Client client) { @Override public VisualizationsTool create(Map params) { - String index = params.get("index") == null ? ".kibana" : (String) params.get("index"); - String sizeStr = params.get("size") == null ? "3" : (String) params.get("size"); + ConfigurationUtils.readStringProperty(TYPE, null, params, INPUT_FIELD); + ConfigurationUtils.readOptionalStringProperty(TYPE, null, params, INDEX_FIELD); + ConfigurationUtils.readIntProperty(TYPE, null, params, SIZE_FIELD, 3); + String index = params.get(INDEX_FIELD) == null ? ".kibana" : (String) params.get(INDEX_FIELD); + String sizeStr = params.get(SIZE_FIELD) == null ? "3" : (String) params.get(SIZE_FIELD); int size; try { size = Integer.parseInt(sizeStr); } catch (NumberFormatException ignored) { size = DEFAULT_SIZE; } - return VisualizationsTool.builder().client(client).index(index).size(size).build(); + VisualizationsTool tool = VisualizationsTool.builder().client(client).index(index).size(size).build(); + tool.maxInputLength = ConfigurationUtils.readIntProperty(TYPE, null, params, MAX_INPUT_LENGTH_FIELD, DEFAULT_MAX_INPUT_LENGTH); + return tool; } @Override diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/WriteToScratchPadTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/WriteToScratchPadTool.java index ba15fbb8b3..cd82283e25 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/WriteToScratchPadTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/WriteToScratchPadTool.java @@ -13,6 +13,7 @@ import java.util.Map; import org.opensearch.core.action.ActionListener; +import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; import org.opensearch.ml.common.utils.StringUtils; @@ -143,6 +144,8 @@ public void init() {} @Override public WriteToScratchPadTool create(Map params) { + ConfigurationUtils.readStringProperty(TYPE, null, params, NOTES_KEY); + ConfigurationUtils.readBooleanProperty(TYPE, null, params, RETURN_HISTORY_KEY, false); return new WriteToScratchPadTool(); }