Skip to content

Commit bd95634

Browse files
Fix Concurrency: Use Async OpenSearch client to improve concurrency (#125)
* Fix Concurrency: Use Async OpenSearch client to improve concurrency Signed-off-by: rithin-pullela-aws <[email protected]> * Add changelog Signed-off-by: rithin-pullela-aws <[email protected]> --------- Signed-off-by: rithin-pullela-aws <[email protected]>
1 parent 1e1ff72 commit bd95634

File tree

17 files changed

+290
-225
lines changed

17 files changed

+290
-225
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
99
- Add header-based authentication + Code Clean up ([#117](https://github.com/opensearch-project/opensearch-mcp-server-py/pull/117))
1010

1111
### Fixed
12+
- Fix Concurrency: Use Async OpenSearch client to improve concurrency ([#125](https://github.com/opensearch-project/opensearch-mcp-server-py/pull/125))
1213

1314
### Removed
1415

src/mcp_server_opensearch/clusters_information.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def get_cluster(name: str) -> Optional[ClusterInfo]:
5252
return cluster_registry.get(name)
5353

5454

55-
def load_clusters_from_yaml(file_path: str) -> None:
55+
async def load_clusters_from_yaml(file_path: str) -> None:
5656
"""Load cluster configurations from a YAML file and populate the global registry.
5757
5858
Args:
@@ -117,7 +117,7 @@ def load_clusters_from_yaml(file_path: str) -> None:
117117
opensearch_header_auth=cluster_config.get('opensearch_header_auth', None),
118118
)
119119
# Check if possible to connect to the cluster
120-
is_connected, error_message = check_cluster_connection(cluster_info)
120+
is_connected, error_message = await check_cluster_connection(cluster_info)
121121
if not is_connected:
122122
result['errors'].append(
123123
f"Error connecting to cluster '{cluster_name}': {error_message}"
@@ -143,7 +143,7 @@ def load_clusters_from_yaml(file_path: str) -> None:
143143
raise yaml.YAMLError(f'Invalid YAML format in {file_path}: {str(e)}')
144144

145145

146-
def check_cluster_connection(cluster_info: ClusterInfo) -> tuple[bool, str]:
146+
async def check_cluster_connection(cluster_info: ClusterInfo) -> tuple[bool, str]:
147147
"""Check if the cluster is reachable by attempting to connect.
148148
149149
Args:
@@ -157,7 +157,7 @@ def check_cluster_connection(cluster_info: ClusterInfo) -> tuple[bool, str]:
157157
from opensearch.client import _initialize_client_multi_mode
158158

159159
client = _initialize_client_multi_mode(cluster_info)
160-
client.ping()
160+
await client.ping()
161161
return True, ''
162162
except Exception as e:
163163
return False, str(e)

src/mcp_server_opensearch/stdio_server.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ async def serve(
3434
server = Server('opensearch-mcp-server')
3535
# Load clusters from YAML file
3636
if mode == 'multi':
37-
load_clusters_from_yaml(config_file_path)
37+
await load_clusters_from_yaml(config_file_path)
3838

3939
# Call tool generator
4040
await generate_tools_from_openapi()
@@ -43,7 +43,9 @@ async def serve(
4343
TOOL_REGISTRY, config_file_path, cli_tool_overrides or {}
4444
)
4545
# Get enabled tools (tool filter)
46-
enabled_tools = get_tools(tool_registry=customized_registry, config_file_path=config_file_path)
46+
enabled_tools = await get_tools(
47+
tool_registry=customized_registry, config_file_path=config_file_path
48+
)
4749
logging.info(f'Enabled tools: {list(enabled_tools.keys())}')
4850

4951
@server.list_tools()

src/mcp_server_opensearch/streaming_server.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ async def create_mcp_server(
4141

4242
# Load clusters from YAML file
4343
if mode == 'multi':
44-
load_clusters_from_yaml(config_file_path)
44+
await load_clusters_from_yaml(config_file_path)
4545

4646
server = Server('opensearch-mcp-server')
4747
# Call tool generator
@@ -51,7 +51,9 @@ async def create_mcp_server(
5151
TOOL_REGISTRY, config_file_path, cli_tool_overrides or {}
5252
)
5353
# Get enabled tools (tool filter)
54-
enabled_tools = get_tools(tool_registry=customized_registry, config_file_path=config_file_path)
54+
enabled_tools = await get_tools(
55+
tool_registry=customized_registry, config_file_path=config_file_path
56+
)
5557
logging.info(f'Enabled tools: {list(enabled_tools.keys())}')
5658

5759
@server.list_tools()
@@ -141,6 +143,7 @@ def create_app(self) -> Starlette:
141143
Route('/health', endpoint=self.handle_health, methods=['GET']),
142144
Mount('/messages/', app=self.sse.handle_post_message),
143145
Mount('/mcp', app=self.handle_streamable_http),
146+
Mount('/mcp/', app=self.handle_streamable_http),
144147
],
145148
lifespan=self.lifespan,
146149
)

src/opensearch/client.py

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919

2020
from mcp_server_opensearch.clusters_information import ClusterInfo, get_cluster
2121
from mcp_server_opensearch.global_state import get_mode, get_profile
22-
from opensearchpy import OpenSearch, RequestsHttpConnection
23-
from requests_aws4auth import AWS4Auth
22+
from opensearchpy import AsyncOpenSearch, AsyncHttpConnection, AWSV4SignerAsyncAuth
2423
from tools.tool_params import baseToolArgs
24+
from botocore.credentials import Credentials
2525

2626
# Configure logging
2727
logger = logging.getLogger(__name__)
@@ -53,7 +53,7 @@ class ConfigurationError(OpenSearchClientError):
5353

5454

5555
# Public API Functions
56-
def initialize_client(args: baseToolArgs) -> OpenSearch:
56+
def initialize_client(args: baseToolArgs) -> AsyncOpenSearch:
5757
"""Initialize and return an OpenSearch client based on the current mode.
5858
5959
Behavior depends on the global mode:
@@ -100,7 +100,7 @@ def initialize_client(args: baseToolArgs) -> OpenSearch:
100100

101101

102102
# Private Implementation Functions
103-
def _initialize_client_single_mode() -> OpenSearch:
103+
def _initialize_client_single_mode() -> AsyncOpenSearch:
104104
"""Initialize OpenSearch client for single mode using environment variables.
105105
106106
Single mode uses environment variables for connection, with optional header-based auth:
@@ -200,7 +200,7 @@ def _initialize_client_single_mode() -> OpenSearch:
200200
raise ConfigurationError(f'Failed to initialize single mode client: {e}')
201201

202202

203-
def _initialize_client_multi_mode(cluster_info: ClusterInfo) -> OpenSearch:
203+
def _initialize_client_multi_mode(cluster_info: ClusterInfo) -> AsyncOpenSearch:
204204
"""Initialize OpenSearch client for multi mode using cluster configuration.
205205
206206
Multi mode uses cluster configuration from the provided ClusterInfo object.
@@ -303,7 +303,7 @@ def _create_opensearch_client(
303303
aws_access_key_id: Optional[str] = None,
304304
aws_secret_access_key: Optional[str] = None,
305305
aws_session_token: Optional[str] = None,
306-
) -> OpenSearch:
306+
) -> AsyncOpenSearch:
307307
"""Common function to create OpenSearch client with authentication.
308308
309309
This function handles the common authentication logic used by both
@@ -362,7 +362,7 @@ def _create_opensearch_client(
362362
'hosts': [opensearch_url],
363363
'use_ssl': (parsed_url.scheme == 'https'),
364364
'verify_certs': ssl_verify,
365-
'connection_class': RequestsHttpConnection,
365+
'connection_class': AsyncHttpConnection,
366366
'timeout': timeout,
367367
}
368368

@@ -379,7 +379,7 @@ def _create_opensearch_client(
379379
if opensearch_no_auth:
380380
logger.info('[NO AUTH] Attempting connection without authentication')
381381
try:
382-
return OpenSearch(**client_kwargs)
382+
return AsyncOpenSearch(**client_kwargs)
383383
except Exception as e:
384384
logger.error(f'[NO AUTH] Failed to connect without authentication: {e}')
385385
raise AuthenticationError(f'Failed to connect without authentication: {e}')
@@ -392,16 +392,16 @@ def _create_opensearch_client(
392392
raise AuthenticationError(
393393
'AWS region is required for header-based authentication'
394394
)
395-
396-
aws_auth = AWS4Auth(
397-
aws_access_key_id,
398-
aws_secret_access_key,
399-
aws_region.strip(),
400-
service_name,
401-
session_token=aws_session_token,
395+
credentials = Credentials(
396+
access_key=aws_access_key_id,
397+
secret_key=aws_secret_access_key,
398+
token=aws_session_token,
399+
)
400+
aws_auth = AWSV4SignerAsyncAuth(
401+
credentials=credentials, region=aws_region.strip(), service=service_name
402402
)
403403
client_kwargs['http_auth'] = aws_auth
404-
return OpenSearch(**client_kwargs)
404+
return AsyncOpenSearch(**client_kwargs)
405405
except Exception as e:
406406
logger.error(f'[HEADER AUTH] Failed to authenticate with header credentials: {e}')
407407
raise AuthenticationError(f'Failed to authenticate with header credentials: {e}')
@@ -419,15 +419,11 @@ def _create_opensearch_client(
419419
)
420420
credentials = assumed_role['Credentials']
421421

422-
aws_auth = AWS4Auth(
423-
credentials['AccessKeyId'],
424-
credentials['SecretAccessKey'],
425-
aws_region,
426-
service_name,
427-
session_token=credentials['SessionToken'],
422+
aws_auth = AWSV4SignerAsyncAuth(
423+
credentials=credentials, region=aws_region.strip(), service=service_name
428424
)
429425
client_kwargs['http_auth'] = aws_auth
430-
return OpenSearch(**client_kwargs)
426+
return AsyncOpenSearch(**client_kwargs)
431427
except Exception as e:
432428
logger.error(f'[IAM AUTH] Failed to assume IAM role {iam_arn}: {e}')
433429
raise AuthenticationError(f'Failed to assume IAM role {iam_arn}: {e}')
@@ -437,7 +433,7 @@ def _create_opensearch_client(
437433
logger.info(f'[BASIC AUTH] Using basic authentication for user: {opensearch_username}')
438434
try:
439435
client_kwargs['http_auth'] = (opensearch_username.strip(), opensearch_password)
440-
return OpenSearch(**client_kwargs)
436+
return AsyncOpenSearch(**client_kwargs)
441437
except Exception as e:
442438
logger.error(f'[BASIC AUTH] Failed to connect with basic authentication: {e}')
443439
raise AuthenticationError(f'Failed to connect with basic authentication: {e}')
@@ -454,13 +450,11 @@ def _create_opensearch_client(
454450
if not credentials:
455451
raise AuthenticationError('No AWS credentials found in session')
456452

457-
aws_auth = AWS4Auth(
458-
refreshable_credentials=credentials,
459-
service=service_name,
460-
region=aws_region,
453+
aws_auth = AWSV4SignerAsyncAuth(
454+
credentials=credentials, region=aws_region.strip(), service=service_name
461455
)
462456
client_kwargs['http_auth'] = aws_auth
463-
return OpenSearch(**client_kwargs)
457+
return AsyncOpenSearch(**client_kwargs)
464458
except Exception as e:
465459
logger.error(f'[AWS CREDS] Failed to authenticate with AWS credentials: {e}')
466460
raise AuthenticationError(f'Failed to authenticate with AWS credentials: {e}')

0 commit comments

Comments
 (0)