diff --git a/README.md b/README.md index 73337b0..623ff4c 100644 --- a/README.md +++ b/README.md @@ -77,6 +77,8 @@ The following tools are available but disabled by default. To enable them, see t - `opensearch_url` (optional): The OpenSearch cluster URL to connect to - `index` (required): The name of the index to search in - `query` (required): The search query in OpenSearch Query DSL format + - `size` (optional): Maximum number of hits to return (default: 10, max: 100). Limits response size to prevent token overflow + - `from` (optional): Starting offset for pagination (default: 0). Use with size for pagination - **GetShardsTool** - `opensearch_url` (optional): The OpenSearch cluster URL to connect to @@ -115,6 +117,7 @@ The following tools are available but disabled by default. To enable them, see t - `opensearch_url` (optional): The OpenSearch cluster URL to connect to - `index` (optional): Limit the information returned to the specified indices. If not provided, returns segments for all indices + - `limit` (optional): Maximum number of segments to return (default: 1000). Limits response size to prevent token overflow - **CatNodesTool** diff --git a/src/opensearch/helper.py b/src/opensearch/helper.py index 2687bee..c1baa48 100644 --- a/src/opensearch/helper.py +++ b/src/opensearch/helper.py @@ -45,10 +45,28 @@ def get_index_mapping(args: GetIndexMappingArgs) -> json: def search_index(args: SearchIndexArgs) -> json: + """Search an index with pagination support. + + Args: + args: SearchIndexArgs containing index, query, and optional pagination params + + Returns: + json: Search results from OpenSearch + """ from .client import initialize_client client = initialize_client(args) - response = client.search(index=args.index, body=args.query) + + # Ensure query is a dict for merging + query_body = args.query if isinstance(args.query, dict) else {} + + # Apply pagination parameters (override any user-provided values) + # Cap size at maximum of 100 to prevent token overflow + effective_size = min(args.size, 100) if args.size else 10 + query_body['size'] = effective_size + query_body['from'] = args.from_ if args.from_ is not None else 0 + + response = client.search(index=args.index, body=query_body) return response @@ -62,21 +80,26 @@ def get_shards(args: GetShardsArgs) -> json: def get_segments(args: GetSegmentsArgs) -> json: """Get information about Lucene segments in indices. - + Args: - args: GetSegmentsArgs containing optional index filter - + args: GetSegmentsArgs containing optional index filter and limit + Returns: json: Segment information for the specified indices or all indices """ from .client import initialize_client - + client = initialize_client(args) - + # If index is provided, filter by that index index_param = args.index if args.index else None - + response = client.cat.segments(index=index_param, format='json') + + # Apply limit to prevent token overflow + if args.limit and isinstance(response, list): + return response[:args.limit] + return response diff --git a/src/tools/tool_params.py b/src/tools/tool_params.py index 14418ec..5c3b063 100644 --- a/src/tools/tool_params.py +++ b/src/tools/tool_params.py @@ -31,6 +31,21 @@ class GetIndexMappingArgs(baseToolArgs): class SearchIndexArgs(baseToolArgs): index: str = Field(description='The name of the index to search in') query: Any = Field(description='The search query in OpenSearch query DSL format') + size: Optional[int] = Field( + default=10, + description='Maximum number of hits to return (default: 10, max: 100). Limits response size to prevent token overflow. Values exceeding 100 will be capped at 100.', + ge=1, + ) + from_: Optional[int] = Field( + default=0, + description='Starting offset for pagination (default: 0). Use with size for pagination.', + alias='from', + ge=0, + serialization_alias='from', + ) + + class Config: + populate_by_name = True class GetShardsArgs(baseToolArgs): @@ -65,12 +80,17 @@ class Config: class GetSegmentsArgs(baseToolArgs): """Arguments for the GetSegmentsTool.""" - + index: Optional[str] = Field( - default=None, + default=None, description='Limit the information returned to the specified indices. If not provided, returns segments for all indices.' ) - + limit: Optional[int] = Field( + default=1000, + description='Maximum number of segments to return (default: 1000). Limits response size to prevent token overflow.', + ge=1, + ) + class Config: json_schema_extra = { "examples": [ diff --git a/src/tools/tools.py b/src/tools/tools.py index 9b71327..829cfc0 100644 --- a/src/tools/tools.py +++ b/src/tools/tools.py @@ -499,7 +499,7 @@ async def get_long_running_tasks_tool(args: GetLongRunningTasksArgs) -> list[dic }, 'SearchIndexTool': { 'display_name': 'SearchIndexTool', - 'description': 'Searches an index using a query written in query domain-specific language (DSL) in OpenSearch', + 'description': 'Searches an index using a query written in query domain-specific language (DSL) in OpenSearch. Supports pagination with size (default: 10, max: 100) and from parameters to limit response size and prevent token overflow.', 'input_schema': SearchIndexArgs.model_json_schema(), 'function': search_index_tool, 'args_model': SearchIndexArgs, @@ -524,7 +524,7 @@ async def get_long_running_tasks_tool(args: GetLongRunningTasksArgs) -> list[dic }, 'GetSegmentsTool': { 'display_name': 'GetSegmentsTool', - 'description': 'Gets information about Lucene segments in indices, including memory usage, document counts, and segment sizes. Can be filtered by specific indices.', + 'description': 'Gets information about Lucene segments in indices, including memory usage, document counts, and segment sizes. Can be filtered by specific indices. Supports limit parameter (default: 1000) to prevent token overflow.', 'input_schema': GetSegmentsArgs.model_json_schema(), 'function': get_segments_tool, 'args_model': GetSegmentsArgs, diff --git a/tests/tools/test_tools.py b/tests/tools/test_tools.py index 621e9b2..0115e91 100644 --- a/tests/tools/test_tools.py +++ b/tests/tools/test_tools.py @@ -297,7 +297,10 @@ async def test_search_index_tool(self): assert result[0]['type'] == 'text' assert 'Search results from test-index' in result[0]['text'] assert json.loads(result[0]['text'].split('\n', 1)[1]) == mock_results - self.mock_client.search.assert_called_once_with(index='test-index', body={'match_all': {}}) + # Pagination params should be added with defaults + self.mock_client.search.assert_called_once_with( + index='test-index', body={'match_all': {}, 'size': 10, 'from': 0} + ) @pytest.mark.asyncio async def test_search_index_tool_error(self): @@ -311,7 +314,133 @@ async def test_search_index_tool_error(self): assert len(result) == 1 assert result[0]['type'] == 'text' assert 'Error searching index: Test error' in result[0]['text'] - self.mock_client.search.assert_called_once_with(index='test-index', body={'match_all': {}}) + # Pagination params should be added with defaults even on error path + self.mock_client.search.assert_called_once_with( + index='test-index', body={'match_all': {}, 'size': 10, 'from': 0} + ) + + @pytest.mark.asyncio + async def test_search_index_tool_with_default_size(self): + """Test search_index_tool applies default size of 10.""" + # Setup + mock_results = {'hits': {'total': {'value': 100}, 'hits': []}} + self.mock_client.search.return_value = mock_results + # Execute + args = self.SearchIndexArgs(index='test-index', query={'match_all': {}}) + result = await self._search_index_tool(args) + # Assert + assert len(result) == 1 + assert result[0]['type'] == 'text' + # Should call search with size=10 injected into body + expected_body = {'match_all': {}, 'size': 10, 'from': 0} + self.mock_client.search.assert_called_once_with(index='test-index', body=expected_body) + + @pytest.mark.asyncio + async def test_search_index_tool_with_custom_size(self): + """Test search_index_tool with custom size parameter.""" + # Setup + mock_results = {'hits': {'total': {'value': 50}, 'hits': []}} + self.mock_client.search.return_value = mock_results + # Execute + args = self.SearchIndexArgs(index='test-index', query={'match_all': {}}, size=25) + result = await self._search_index_tool(args) + # Assert + assert len(result) == 1 + assert result[0]['type'] == 'text' + # Should call search with size=25 injected into body + expected_body = {'match_all': {}, 'size': 25, 'from': 0} + self.mock_client.search.assert_called_once_with(index='test-index', body=expected_body) + + @pytest.mark.asyncio + async def test_search_index_tool_with_from_parameter(self): + """Test search_index_tool with from_ (offset) parameter.""" + # Setup + mock_results = {'hits': {'total': {'value': 100}, 'hits': []}} + self.mock_client.search.return_value = mock_results + # Execute + args = self.SearchIndexArgs(index='test-index', query={'match_all': {}}, from_=20) + result = await self._search_index_tool(args) + # Assert + assert len(result) == 1 + assert result[0]['type'] == 'text' + # Should call search with from=20 injected into body + expected_body = {'match_all': {}, 'size': 10, 'from': 20} + self.mock_client.search.assert_called_once_with(index='test-index', body=expected_body) + + @pytest.mark.asyncio + async def test_search_index_tool_with_size_and_from(self): + """Test search_index_tool with both size and from_ parameters.""" + # Setup + mock_results = {'hits': {'total': {'value': 100}, 'hits': []}} + self.mock_client.search.return_value = mock_results + # Execute + args = self.SearchIndexArgs( + index='test-index', query={'match_all': {}}, size=50, from_=30 + ) + result = await self._search_index_tool(args) + # Assert + assert len(result) == 1 + assert result[0]['type'] == 'text' + # Should call search with size=50 and from=30 injected into body + expected_body = {'match_all': {}, 'size': 50, 'from': 30} + self.mock_client.search.assert_called_once_with(index='test-index', body=expected_body) + + @pytest.mark.asyncio + async def test_search_index_tool_enforces_max_size(self): + """Test search_index_tool enforces maximum size of 100.""" + # Setup + mock_results = {'hits': {'total': {'value': 1000}, 'hits': []}} + self.mock_client.search.return_value = mock_results + # Execute - request 200 but should be capped at 100 + args = self.SearchIndexArgs(index='test-index', query={'match_all': {}}, size=200) + result = await self._search_index_tool(args) + # Assert + assert len(result) == 1 + assert result[0]['type'] == 'text' + # Should cap size at 100 + expected_body = {'match_all': {}, 'size': 100, 'from': 0} + self.mock_client.search.assert_called_once_with(index='test-index', body=expected_body) + + @pytest.mark.asyncio + async def test_search_index_tool_preserves_query_body_structure(self): + """Test search_index_tool preserves existing query body structure.""" + # Setup + mock_results = {'hits': {'total': {'value': 10}, 'hits': []}} + self.mock_client.search.return_value = mock_results + # Execute with complex query that already has some params + complex_query = { + 'query': {'match': {'field': 'value'}}, + 'sort': [{'timestamp': 'desc'}], + '_source': ['field1', 'field2'], + } + args = self.SearchIndexArgs(index='test-index', query=complex_query, size=20) + result = await self._search_index_tool(args) + # Assert + assert len(result) == 1 + # Should merge pagination params with existing query structure + expected_body = { + 'query': {'match': {'field': 'value'}}, + 'sort': [{'timestamp': 'desc'}], + '_source': ['field1', 'field2'], + 'size': 20, + 'from': 0, + } + self.mock_client.search.assert_called_once_with(index='test-index', body=expected_body) + + @pytest.mark.asyncio + async def test_search_index_tool_overrides_user_size_if_exceeds_max(self): + """Test search_index_tool overrides user-provided size in query body if it exceeds max.""" + # Setup + mock_results = {'hits': {'total': {'value': 1000}, 'hits': []}} + self.mock_client.search.return_value = mock_results + # Execute - query body has size=500, should be overridden by max + query_with_size = {'match_all': {}, 'size': 500} + args = self.SearchIndexArgs(index='test-index', query=query_with_size, size=50) + result = await self._search_index_tool(args) + # Assert + # Parameter size=50 should override query body size=500 + expected_body = {'match_all': {}, 'size': 50, 'from': 0} + self.mock_client.search.assert_called_once_with(index='test-index', body=expected_body) @pytest.mark.asyncio async def test_get_shards_tool(self): @@ -539,17 +668,110 @@ async def test_get_segments_tool_error(self): """Test get_segments_tool exception handling.""" # Setup self.mock_client.cat.segments.side_effect = Exception('Test error') - + # Execute args = self.GetSegmentsArgs() result = await self._get_segments_tool(args) - + # Assert assert len(result) == 1 assert result[0]['type'] == 'text' assert 'Error getting segment information: Test error' in result[0]['text'] self.mock_client.cat.segments.assert_called_once_with(index=None, format='json') - + + @pytest.mark.asyncio + async def test_get_segments_tool_with_limit(self): + """Test get_segments_tool with limit parameter to prevent token overflow.""" + # Setup - create many segments that would exceed token limit + mock_segments = [] + for i in range(100): + mock_segments.append({ + 'index': f'index-{i}', + 'shard': str(i % 5), + 'prirep': 'p', + 'segment': f's{i}', + 'generation': str(i), + 'docs.count': '100', + 'docs.deleted': '5', + 'size': '1mb', + 'memory.bookkeeping': '500b', + 'memory.vectors': '0b', + 'memory.docvalues': '200b', + 'memory.terms': '300b', + 'version': '8.0.0' + }) + self.mock_client.cat.segments.return_value = mock_segments + + # Execute with limit=50 + args = self.GetSegmentsArgs(limit=50) + result = await self._get_segments_tool(args) + + # Assert - should only return first 50 segments + assert len(result) == 1 + assert result[0]['type'] == 'text' + # Check that we have limited results + assert 'index-0' in result[0]['text'] + assert 'index-49' in result[0]['text'] + assert 'index-50' not in result[0]['text'] # Should be truncated + assert 'index-99' not in result[0]['text'] + + @pytest.mark.asyncio + async def test_get_segments_tool_default_limit(self): + """Test get_segments_tool applies default limit of 1000.""" + # Setup - create 1500 segments + mock_segments = [{'index': f'idx-{i}', 'shard': '0', 'prirep': 'p', 'segment': f's{i}'} for i in range(1500)] + self.mock_client.cat.segments.return_value = mock_segments + + # Execute without limit parameter + args = self.GetSegmentsArgs() + result = await self._get_segments_tool(args) + + # Assert - should cap at 1000 + assert len(result) == 1 + # Should have segments 0-999 but not 1000+ + assert 'idx-0' in result[0]['text'] + assert 'idx-999' in result[0]['text'] + assert 'idx-1000' not in result[0]['text'] + assert 'idx-1499' not in result[0]['text'] + + @pytest.mark.asyncio + async def test_get_segments_tool_with_limit_and_index(self): + """Test get_segments_tool with both limit and index parameters.""" + # Setup + mock_segments = [] + for i in range(200): + mock_segments.append({ + 'index': 'test-index', + 'shard': str(i % 10), + 'prirep': 'p' if i % 2 == 0 else 'r', + 'segment': f's{i}', + 'generation': str(i), + 'docs.count': '100', + 'docs.deleted': '5', + 'size': '1mb', + 'memory.bookkeeping': '500b', + 'memory.vectors': '0b', + 'memory.docvalues': '200b', + 'memory.terms': '300b', + 'version': '8.0.0' + }) + self.mock_client.cat.segments.return_value = mock_segments + + # Execute with both index and limit + args = self.GetSegmentsArgs(index='test-index', limit=100) + result = await self._get_segments_tool(args) + + # Assert + assert len(result) == 1 + assert result[0]['type'] == 'text' + assert 'Segment information for index: test-index' in result[0]['text'] + # Should have limited segments + assert 's0' in result[0]['text'] + assert 's99' in result[0]['text'] + assert 's100' not in result[0]['text'] + assert 's199' not in result[0]['text'] + self.mock_client.cat.segments.assert_called_once_with(index='test-index', format='json') + @pytest.mark.asyncio async def test_cat_nodes_tool(self): """Test cat_nodes_tool successful."""