Skip to content

Commit 6cd9552

Browse files
committed
pass header to mcp connector
Signed-off-by: Jiaping Zeng <[email protected]>
1 parent 305cb8e commit 6cd9552

File tree

10 files changed

+366
-49
lines changed

10 files changed

+366
-49
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ public class CommonValue {
119119
public static final String MCP_TOOL_DESCRIPTION_FIELD = "description";
120120
public static final String MCP_TOOL_INPUT_SCHEMA_FIELD = "inputSchema";
121121
public static final String MCP_SYNC_CLIENT = "mcp_sync_client";
122+
public static final String MCP_CONNECTOR = "mcp_connector";
123+
public static final String MCP_CONNECTOR_CONFIG = "mcp_connector_config";
124+
public static final String MCP_REQUEST_HEADERS = "mcp_request_headers";
122125
public static final String MCP_TOOLS_FIELD = "tools";
123126
public static final String MCP_CONNECTORS_FIELD = "mcp_connectors";
124127
public static final String MCP_CONNECTOR_ID_FIELD = "mcp_connector_id";

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

Lines changed: 136 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import java.util.Collection;
4444
import java.util.Collections;
4545
import java.util.HashMap;
46+
import java.util.HashSet;
4647
import java.util.List;
4748
import java.util.Locale;
4849
import java.util.Map;
@@ -69,9 +70,11 @@
6970
import org.opensearch.core.xcontent.NamedXContentRegistry;
7071
import org.opensearch.core.xcontent.XContentParser;
7172
import org.opensearch.index.IndexNotFoundException;
73+
import org.opensearch.ml.common.CommonValue;
7274
import org.opensearch.ml.common.agent.MLAgent;
7375
import org.opensearch.ml.common.agent.MLToolSpec;
7476
import org.opensearch.ml.common.connector.Connector;
77+
import org.opensearch.ml.common.connector.ConnectorClientConfig;
7578
import org.opensearch.ml.common.connector.McpConnector;
7679
import org.opensearch.ml.common.connector.McpStreamableHttpConnector;
7780
import org.opensearch.ml.common.output.model.ModelTensor;
@@ -95,6 +98,7 @@
9598
import com.jayway.jsonpath.JsonPath;
9699
import com.jayway.jsonpath.PathNotFoundException;
97100

101+
import io.modelcontextprotocol.client.McpSyncClient;
98102
import lombok.extern.log4j.Log4j2;
99103

100104
@Log4j2
@@ -144,6 +148,56 @@ public class AgentUtils {
144148
public static final String DEFAULT_DATETIME_PREFIX = "Current date and time: ";
145149
private static final ZoneId UTC_ZONE = ZoneId.of("UTC");
146150

151+
private static void storeConnectorConfigInToolSpecs(
152+
List<MLToolSpec> toolSpecs,
153+
Connector connector,
154+
ConnectorClientConfig config,
155+
String connectorId
156+
) {
157+
for (MLToolSpec toolSpec : toolSpecs) {
158+
toolSpec.addRuntimeResource(CommonValue.MCP_CONNECTOR, connector);
159+
toolSpec.addRuntimeResource(CommonValue.MCP_CONNECTOR_CONFIG, config);
160+
toolSpec.addRuntimeResource(CommonValue.MCP_CONNECTOR_ID_FIELD, connectorId);
161+
}
162+
}
163+
164+
public static Map<String, String> extractRequestHeaders(Map<String, String> parameters) {
165+
if (parameters == null) {
166+
return Collections.emptyMap();
167+
}
168+
169+
String headersJson = parameters.get(CommonValue.MCP_REQUEST_HEADERS);
170+
if (headersJson == null || headersJson.trim().isEmpty()) {
171+
return Collections.emptyMap();
172+
}
173+
174+
try {
175+
Type mapType = new TypeToken<Map<String, String>>() {
176+
}.getType();
177+
Map<String, String> headers = gson.fromJson(headersJson, mapType);
178+
return headers != null ? headers : Collections.emptyMap();
179+
} catch (Exception e) {
180+
log.warn("Failed to parse request headers from JSON: {}", headersJson, e);
181+
return Collections.emptyMap();
182+
}
183+
}
184+
185+
public static Map<String, String> mergeHeaders(Map<String, String> staticHeaders, Map<String, String> requestHeaders) {
186+
Map<String, String> mergedHeaders = new HashMap<>();
187+
188+
// Add static headers first
189+
if (staticHeaders != null && !staticHeaders.isEmpty()) {
190+
mergedHeaders.putAll(staticHeaders);
191+
}
192+
193+
// Add request headers, overriding any static headers with the same key
194+
if (requestHeaders != null && !requestHeaders.isEmpty()) {
195+
mergedHeaders.putAll(requestHeaders);
196+
}
197+
198+
return mergedHeaders;
199+
}
200+
147201
public static String addExamplesToPrompt(Map<String, String> parameters, String prompt) {
148202
Map<String, String> examplesMap = new HashMap<>();
149203
if (parameters.containsKey(EXAMPLES)) {
@@ -762,13 +816,15 @@ private static void getMCPToolSpecsFromConnector(
762816
McpConnectorExecutor connectorExecutor = MLEngineClassLoader
763817
.initInstance(connector.getProtocol(), connector, Connector.class);
764818
mcpToolSpecs = connectorExecutor.getMcpToolSpecs();
819+
storeConnectorConfigInToolSpecs(mcpToolSpecs, connector, connectorExecutor.getConnectorClientConfig(), connectorId);
765820
toolListener.onResponse(mcpToolSpecs);
766821
return;
767822
}
768823
if (connector instanceof McpStreamableHttpConnector) {
769824
McpStreamableHttpConnectorExecutor connectorExecutor = MLEngineClassLoader
770825
.initInstance(connector.getProtocol(), connector, Connector.class);
771826
mcpToolSpecs = connectorExecutor.getMcpToolSpecs();
827+
storeConnectorConfigInToolSpecs(mcpToolSpecs, connector, connectorExecutor.getConnectorClientConfig(), connectorId);
772828
toolListener.onResponse(mcpToolSpecs);
773829
return;
774830
}
@@ -857,19 +913,52 @@ public static void createTools(
857913
if (toolSpecs == null) {
858914
return;
859915
}
916+
917+
Map<String, String> requestHeaders = extractRequestHeaders(params);
918+
Map<String, McpSyncClient> mcpClients = new HashMap<>();
919+
860920
for (MLToolSpec toolSpec : toolSpecs) {
861921
Map<String, String> toolParams = buildToolParameters(params, toolSpec, mlAgent.getTenantId());
862922
Tool tool = createTool(toolFactories, toolParams, toolSpec);
863923
tools.put(tool.getName(), tool);
924+
864925
if (toolSpec.getAttributes() != null) {
865926
if (tool.getAttributes() == null) {
866-
Map<String, Object> attributes = new HashMap<>();
867-
attributes.putAll(toolSpec.getAttributes());
868-
tool.setAttributes(attributes);
927+
tool.setAttributes(new HashMap<>(toolSpec.getAttributes()));
869928
} else {
870929
tool.getAttributes().putAll(toolSpec.getAttributes());
871930
}
872931
}
932+
Map<String, Object> runtimeResources = toolSpec.getRuntimeResources();
933+
if (runtimeResources != null) {
934+
if (tool.getAttributes() == null) {
935+
tool.setAttributes(new HashMap<>(runtimeResources));
936+
} else {
937+
tool.getAttributes().putAll(runtimeResources);
938+
}
939+
940+
Connector connector = (Connector) runtimeResources.get(CommonValue.MCP_CONNECTOR);
941+
ConnectorClientConfig config = (ConnectorClientConfig) runtimeResources.get(CommonValue.MCP_CONNECTOR_CONFIG);
942+
String connectorId = (String) runtimeResources.get(CommonValue.MCP_CONNECTOR_ID_FIELD);
943+
if (connector != null && config != null && connectorId != null) {
944+
McpSyncClient mcpClient = mcpClients.get(connectorId);
945+
if (mcpClient == null) {
946+
if (connector instanceof McpStreamableHttpConnector) {
947+
mcpClient = McpStreamableHttpTool
948+
.createMcpClient((McpStreamableHttpConnector) connector, config, requestHeaders);
949+
} else if (connector instanceof McpConnector) {
950+
mcpClient = McpSseTool.createMcpClient((McpConnector) connector, config, requestHeaders);
951+
}
952+
if (mcpClient != null) {
953+
mcpClients.put(connectorId, mcpClient);
954+
log.info("Created MCP client for connector: {}", connectorId);
955+
}
956+
}
957+
if (mcpClient != null) {
958+
tool.getAttributes().put(CommonValue.MCP_SYNC_CLIENT, mcpClient);
959+
}
960+
}
961+
}
873962
toolSpecMap.put(tool.getName(), toolSpec);
874963
}
875964
}
@@ -928,19 +1017,6 @@ public static Map<String, String> constructToolParams(
9281017
return toolParams;
9291018
}
9301019

931-
public static void cleanUpResource(Map<String, Tool> tools) {
932-
for (Map.Entry<String, Tool> entry : tools.entrySet()) {
933-
Tool tool = entry.getValue();
934-
if (tool instanceof McpSseTool) {
935-
// TODO: make this more general, avoid checking specific tool type
936-
((McpSseTool) tool).getMcpSyncClient().closeGracefully();
937-
} else if (tool instanceof McpStreamableHttpTool) {
938-
// TODO: make this more general, avoid checking specific tool type
939-
((McpStreamableHttpTool) tool).getMcpSyncClient().closeGracefully();
940-
}
941-
}
942-
}
943-
9441020
/**
9451021
* Generates a formatted current date and time string in UTC timezone.
9461022
*
@@ -1014,4 +1090,48 @@ public static Tool createTool(Map<String, Tool.Factory> toolFactories, Map<Strin
10141090

10151091
return tool;
10161092
}
1093+
1094+
public static void cleanupMcpClients(Map<String, Tool> tools) {
1095+
if (tools == null || tools.isEmpty()) {
1096+
return;
1097+
}
1098+
Set<McpSyncClient> closedClients = new HashSet<>();
1099+
for (Tool tool : tools.values()) {
1100+
if (tool.getAttributes() != null) {
1101+
Object clientObj = tool.getAttributes().get(CommonValue.MCP_SYNC_CLIENT);
1102+
if (clientObj instanceof McpSyncClient) {
1103+
McpSyncClient client = (McpSyncClient) clientObj;
1104+
if (!closedClients.contains(client)) {
1105+
try {
1106+
client.closeGracefully();
1107+
closedClients.add(client);
1108+
log.info("Successfully closed MCP client for tool: {}", tool.getName());
1109+
} catch (Exception e) {
1110+
log.warn("Failed to close MCP client gracefully for tool: {}. Error: {}", tool.getName(), e.getMessage());
1111+
}
1112+
}
1113+
}
1114+
}
1115+
}
1116+
if (!closedClients.isEmpty()) {
1117+
log.info("Cleaned up {} MCP client(s)", closedClients.size());
1118+
}
1119+
}
1120+
1121+
public static <T> ActionListener<T> wrapListenerWithMcpCleanup(ActionListener<T> delegate, Map<String, Tool> tools) {
1122+
return ActionListener.wrap(response -> {
1123+
try {
1124+
cleanupMcpClients(tools);
1125+
} finally {
1126+
delegate.onResponse(response);
1127+
}
1128+
}, exception -> {
1129+
try {
1130+
cleanupMcpClients(tools);
1131+
} finally {
1132+
delegate.onFailure(exception);
1133+
}
1134+
});
1135+
}
1136+
10171137
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_RESPONSE;
2323
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.TOOL_RESULT;
2424
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.VERBOSE;
25-
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.cleanUpResource;
2625
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.constructToolParams;
2726
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTools;
2827
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getCurrentDateTime;
@@ -267,7 +266,19 @@ private void runAgent(
267266
Map<String, Tool> tools = new HashMap<>();
268267
Map<String, MLToolSpec> toolSpecMap = new HashMap<>();
269268
createTools(toolFactories, params, allToolSpecs, tools, toolSpecMap, mlAgent);
270-
runReAct(mlAgent.getLlm(), tools, toolSpecMap, params, memory, sessionId, mlAgent.getTenantId(), listener, functionCalling);
269+
ActionListener<Object> wrappedListener = AgentUtils.wrapListenerWithMcpCleanup(listener, tools);
270+
271+
runReAct(
272+
mlAgent.getLlm(),
273+
tools,
274+
toolSpecMap,
275+
params,
276+
memory,
277+
sessionId,
278+
mlAgent.getTenantId(),
279+
wrappedListener,
280+
functionCalling
281+
);
271282
};
272283

273284
// Fetch MCP tools and handle both success and failure cases
@@ -369,7 +380,6 @@ private void runReAct(
369380
additionalInfo,
370381
finalAnswer
371382
);
372-
cleanUpResource(tools);
373383
return;
374384
}
375385

@@ -905,7 +915,6 @@ private void handleMaxIterationsReached(
905915
additionalInfo,
906916
incompleteResponse
907917
);
908-
cleanUpResource(tools);
909918
}
910919

911920
private void saveMessage(

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_BEDROCK_CONVERSE_DEEPSEEK_R1;
1616
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS;
1717
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_RESPONSE_FILTER;
18-
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.cleanUpResource;
1918
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.createTools;
2019
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getCurrentDateTime;
2120
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.getMcpToolSpecs;
@@ -620,7 +619,6 @@ void addToolsToPrompt(Map<String, Tool> tools, Map<String, String> allParams) {
620619
toolsPrompt.append("No other tools are available. Do not invent tools. Only use tools to create the plan.\n\n");
621620
allParams.put(DEFAULT_PROMPT_TOOLS_FIELD, toolsPrompt.toString());
622621
populatePrompt(allParams);
623-
cleanUpResource(tools);
624622
}
625623

626624
@VisibleForTesting

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutor.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
package org.opensearch.ml.engine.algorithms.remote;
77

88
import static org.opensearch.ml.common.CommonValue.MCP_DEFAULT_SSE_ENDPOINT;
9-
import static org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT;
109
import static org.opensearch.ml.common.CommonValue.MCP_TOOLS_FIELD;
1110
import static org.opensearch.ml.common.CommonValue.MCP_TOOL_DESCRIPTION_FIELD;
1211
import static org.opensearch.ml.common.CommonValue.MCP_TOOL_INPUT_SCHEMA_FIELD;
@@ -124,7 +123,6 @@ public List<MLToolSpec> getMcpToolSpecs() {
124123
.description(description)
125124
.attributes(attributes)
126125
.build();
127-
mlToolSpec.addRuntimeResource(MCP_SYNC_CLIENT, client);
128126
mcpToolSpecs.add(mlToolSpec);
129127
}
130128

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpStreamableHttpConnectorExecutor.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import static org.opensearch.ml.common.CommonValue.ENDPOINT_FIELD;
99
import static org.opensearch.ml.common.CommonValue.MCP_DEFAULT_STREAMABLE_HTTP_ENDPOINT;
10-
import static org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT;
1110
import static org.opensearch.ml.common.CommonValue.MCP_TOOLS_FIELD;
1211
import static org.opensearch.ml.common.CommonValue.MCP_TOOL_DESCRIPTION_FIELD;
1312
import static org.opensearch.ml.common.CommonValue.MCP_TOOL_INPUT_SCHEMA_FIELD;
@@ -131,7 +130,6 @@ public List<MLToolSpec> getMcpToolSpecs() {
131130
.description(description)
132131
.attributes(attributes)
133132
.build();
134-
mlToolSpec.addRuntimeResource(MCP_SYNC_CLIENT, client);
135133
mcpToolSpecs.add(mlToolSpec);
136134
}
137135

0 commit comments

Comments
 (0)