diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java index 736cdc5a53..c617951b4a 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/SearchIndexTool.java @@ -62,6 +62,7 @@ public class SearchIndexTool implements Tool { public static final String INPUT_FIELD = "input"; public static final String INDEX_FIELD = "index"; public static final String QUERY_FIELD = "query"; + public static final String SIZE_FIELD = "size"; public static final String INPUT_SCHEMA_FIELD = "input_schema"; public static final String STRICT_FIELD = "strict"; @@ -131,10 +132,13 @@ public boolean validate(Map parameters) { return true; } - private SearchRequest getSearchRequest(String index, String query) throws IOException { + private SearchRequest getSearchRequest(String index, String query, Integer size) throws IOException { SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, query); searchSourceBuilder.parseXContent(queryParser); + if (size != null && size > 0) { + searchSourceBuilder.size(size); + } return new SearchRequest().source(searchSourceBuilder).indices(index); } @@ -176,6 +180,7 @@ public void run(Map originalParameters, ActionListener li String input = parameters.get(INPUT_FIELD); String index = null; String query = null; + Integer size = null; boolean returnFullResponse = Boolean.parseBoolean(parameters.getOrDefault(RETURN_RAW_RESPONSE, "false")); if (!StringUtils.isEmpty(input)) { try { @@ -188,6 +193,10 @@ public void run(Map originalParameters, ActionListener li Object queryObject = PLAIN_NUMBER_GSON.fromJson(queryElement, Object.class); query = PLAIN_NUMBER_GSON.toJson(queryObject); } + + if (jsonObject.has(SIZE_FIELD)) { + size = jsonObject.get(SIZE_FIELD).getAsInt(); + } } } catch (JsonSyntaxException e) { log.error("Invalid JSON input: {}", input, e); @@ -202,6 +211,14 @@ public void run(Map originalParameters, ActionListener li query = parameters.get(QUERY_FIELD); } + if (StringUtils.isNotEmpty(parameters.get(SIZE_FIELD))) { + try { + size = Math.min(size == null ? 100 : size, Integer.parseInt(parameters.get(SIZE_FIELD))); + } catch (NumberFormatException e) { + log.warn("Invalid size parameter: {}", parameters.get(SIZE_FIELD)); + } + } + if (StringUtils.isEmpty(index) || StringUtils.isEmpty(query)) { listener .onFailure( @@ -212,7 +229,7 @@ public void run(Map originalParameters, ActionListener li return; } - SearchRequest searchRequest = getSearchRequest(index, query); + SearchRequest searchRequest = getSearchRequest(index, query, size); ActionListener actionListener = ActionListener.wrap(r -> { SearchHit[] hits = r.getHits().getHits(); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/SearchIndexToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/SearchIndexToolTests.java index 161c43a92f..97023d28fd 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/SearchIndexToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/SearchIndexToolTests.java @@ -520,4 +520,44 @@ public void testRun_withRangeQuery_triggersPlainDoubleGson() { assertArrayEquals(new String[] { "test-index" }, cap.getValue().indices()); } + + @Test + public void testRunWithSizeInInput() { + String inputString = "{\"index\": \"test-index\", \"query\": {\"query\": {\"match_all\": {}}}, \"size\": 5}"; + Map parameters = Map.of("input", inputString); + mockedSearchIndexTool.run(parameters, null); + + ArgumentCaptor cap = ArgumentCaptor.forClass(SearchRequest.class); + verify(client, times(1)).search(cap.capture(), any()); + assertEquals(5, cap.getValue().source().size()); + } + + @Test + public void testRunWithSizeAsParameter() { + Map parameters = Map.of( + "index", "test-index", + "query", "{\"query\": {\"match_all\": {}}}", + "size", "3" + ); + mockedSearchIndexTool.run(parameters, null); + + ArgumentCaptor cap = ArgumentCaptor.forClass(SearchRequest.class); + verify(client, times(1)).search(cap.capture(), any()); + assertEquals(3, cap.getValue().source().size()); + } + + @Test + public void testRunWithInvalidSizeParameter() { + Map parameters = Map.of( + "index", "test-index", + "query", "{\"query\": {\"match_all\": {}}}", + "size", "invalid" + ); + mockedSearchIndexTool.run(parameters, null); + + ArgumentCaptor cap = ArgumentCaptor.forClass(SearchRequest.class); + verify(client, times(1)).search(cap.capture(), any()); + // Size should not be set when invalid + assertEquals(-1, cap.getValue().source().size()); + } }