Skip to content

Commit acac26b

Browse files
committed
pass header to mcp connector
Signed-off-by: Jiaping Zeng <[email protected]>
1 parent 0ff1d76 commit acac26b

File tree

10 files changed

+286
-40
lines changed

10 files changed

+286
-40
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ public class CommonValue {
119119
public static final String MCP_TOOL_DESCRIPTION_FIELD = "description";
120120
public static final String MCP_TOOL_INPUT_SCHEMA_FIELD = "inputSchema";
121121
public static final String MCP_SYNC_CLIENT = "mcp_sync_client";
122+
public static final String MCP_CONNECTOR = "mcp_connector";
123+
public static final String MCP_CONNECTOR_CONFIG = "mcp_connector_config";
124+
public static final String MCP_REQUEST_HEADERS = "mcp_request_headers";
122125
public static final String MCP_TOOLS_FIELD = "tools";
123126
public static final String MCP_CONNECTORS_FIELD = "mcp_connectors";
124127
public static final String MCP_CONNECTOR_ID_FIELD = "mcp_connector_id";

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java

Lines changed: 62 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
import org.opensearch.core.xcontent.NamedXContentRegistry;
7070
import org.opensearch.core.xcontent.XContentParser;
7171
import org.opensearch.index.IndexNotFoundException;
72+
import org.opensearch.ml.common.CommonValue;
7273
import org.opensearch.ml.common.agent.MLAgent;
7374
import org.opensearch.ml.common.agent.MLToolSpec;
7475
import org.opensearch.ml.common.connector.Connector;
@@ -144,6 +145,27 @@ public class AgentUtils {
144145
public static final String DEFAULT_DATETIME_PREFIX = "Current date and time: ";
145146
private static final ZoneId UTC_ZONE = ZoneId.of("UTC");
146147

148+
public static Map<String, String> extractRequestHeaders(Map<String, String> parameters) {
149+
if (parameters == null) {
150+
return Collections.emptyMap();
151+
}
152+
153+
String headersJson = parameters.get(CommonValue.MCP_REQUEST_HEADERS);
154+
if (headersJson == null || headersJson.trim().isEmpty()) {
155+
return Collections.emptyMap();
156+
}
157+
158+
try {
159+
Type mapType = new TypeToken<Map<String, String>>() {
160+
}.getType();
161+
Map<String, String> headers = gson.fromJson(headersJson, mapType);
162+
return headers != null ? headers : Collections.emptyMap();
163+
} catch (Exception e) {
164+
log.warn("Failed to parse request headers from JSON: {}", headersJson, e);
165+
return Collections.emptyMap();
166+
}
167+
}
168+
147169
public static String addExamplesToPrompt(Map<String, String> parameters, String prompt) {
148170
Map<String, String> examplesMap = new HashMap<>();
149171
if (parameters.containsKey(EXAMPLES)) {
@@ -679,6 +701,7 @@ public static List<MLToolSpec> getMlToolSpecs(MLAgent mlAgent, Map<String, Strin
679701

680702
public static void getMcpToolSpecs(
681703
MLAgent mlAgent,
704+
Map<String, String> params,
682705
Client client,
683706
SdkClient sdkClient,
684707
Encryptor encryptor,
@@ -697,6 +720,8 @@ public static void getMcpToolSpecs(
697720
}.getType();
698721
List<Map<String, Object>> mcpConnectorConfigs = gson.fromJson(mcpConnectorConfigJSON, listType);
699722

723+
Map<String, String> requestHeaders = extractRequestHeaders(params);
724+
700725
// Use AtomicInteger to track completion of all async operations
701726
AtomicInteger remainingConnectors = new AtomicInteger(mcpConnectorConfigs.size());
702727
List<MLToolSpec> finalToolSpecs = Collections.synchronizedList(new ArrayList<>());
@@ -706,43 +731,52 @@ public static void getMcpToolSpecs(
706731
String connectorId = (String) mcpConnectorConfig.get(MCP_CONNECTOR_ID_FIELD);
707732
List<String> toolFilters = (List<String>) mcpConnectorConfig.get(TOOL_FILTERS_FIELD);
708733

709-
getMCPToolSpecsFromConnector(connectorId, tenantId, sdkClient, client, encryptor, ActionListener.wrap(mcpToolspecs -> {
710-
List<MLToolSpec> filteredTools;
711-
if (toolFilters == null || toolFilters.isEmpty()) {
712-
filteredTools = mcpToolspecs;
713-
} else {
714-
filteredTools = new ArrayList<>();
715-
List<Pattern> compiledPatterns = toolFilters.stream().map(Pattern::compile).collect(Collectors.toList());
716-
717-
for (MLToolSpec toolSpec : mcpToolspecs) {
718-
for (Pattern pattern : compiledPatterns) {
719-
if (pattern.matcher(toolSpec.getName()).matches()) {
720-
filteredTools.add(toolSpec);
721-
break;
734+
getMCPToolSpecsFromConnector(
735+
connectorId,
736+
tenantId,
737+
requestHeaders,
738+
sdkClient,
739+
client,
740+
encryptor,
741+
ActionListener.wrap(mcpToolspecs -> {
742+
List<MLToolSpec> filteredTools;
743+
if (toolFilters == null || toolFilters.isEmpty()) {
744+
filteredTools = mcpToolspecs;
745+
} else {
746+
filteredTools = new ArrayList<>();
747+
List<Pattern> compiledPatterns = toolFilters.stream().map(Pattern::compile).collect(Collectors.toList());
748+
749+
for (MLToolSpec toolSpec : mcpToolspecs) {
750+
for (Pattern pattern : compiledPatterns) {
751+
if (pattern.matcher(toolSpec.getName()).matches()) {
752+
filteredTools.add(toolSpec);
753+
break;
754+
}
722755
}
723756
}
724757
}
725-
}
726758

727-
finalToolSpecs.addAll(filteredTools);
759+
finalToolSpecs.addAll(filteredTools);
728760

729-
// If this is the last connector, send the final response
730-
if (remainingConnectors.decrementAndGet() == 0) {
731-
finalListener.onResponse(finalToolSpecs);
732-
}
733-
}, e -> {
734-
log.error("Error processing connector: " + connectorId, e);
735-
// Even on error, we need to check if this is the last connector
736-
if (remainingConnectors.decrementAndGet() == 0) {
737-
finalListener.onResponse(finalToolSpecs);
738-
}
739-
}));
761+
// If this is the last connector, send the final response
762+
if (remainingConnectors.decrementAndGet() == 0) {
763+
finalListener.onResponse(finalToolSpecs);
764+
}
765+
}, e -> {
766+
log.error("Error processing connector: " + connectorId, e);
767+
// Even on error, we need to check if this is the last connector
768+
if (remainingConnectors.decrementAndGet() == 0) {
769+
finalListener.onResponse(finalToolSpecs);
770+
}
771+
})
772+
);
740773
}
741774
}
742775

743776
private static void getMCPToolSpecsFromConnector(
744777
String connectorId,
745778
String tenantId,
779+
Map<String, String> requestHeaders,
746780
SdkClient sdkClient,
747781
Client client,
748782
Encryptor encryptor,
@@ -761,14 +795,14 @@ private static void getMCPToolSpecsFromConnector(
761795
if (connector instanceof McpConnector) {
762796
McpConnectorExecutor connectorExecutor = MLEngineClassLoader
763797
.initInstance(connector.getProtocol(), connector, Connector.class);
764-
mcpToolSpecs = connectorExecutor.getMcpToolSpecs();
798+
mcpToolSpecs = connectorExecutor.getMcpToolSpecs(requestHeaders);
765799
toolListener.onResponse(mcpToolSpecs);
766800
return;
767801
}
768802
if (connector instanceof McpStreamableHttpConnector) {
769803
McpStreamableHttpConnectorExecutor connectorExecutor = MLEngineClassLoader
770804
.initInstance(connector.getProtocol(), connector, Connector.class);
771-
mcpToolSpecs = connectorExecutor.getMcpToolSpecs();
805+
mcpToolSpecs = connectorExecutor.getMcpToolSpecs(requestHeaders);
772806
toolListener.onResponse(mcpToolSpecs);
773807
return;
774808
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ private void runAgent(
271271
};
272272

273273
// Fetch MCP tools and handle both success and failure cases
274-
getMcpToolSpecs(mlAgent, client, sdkClient, encryptor, ActionListener.wrap(mcpTools -> {
274+
getMcpToolSpecs(mlAgent, params, client, sdkClient, encryptor, ActionListener.wrap(mcpTools -> {
275275
toolSpecs.addAll(mcpTools);
276276
processTools.accept(toolSpecs);
277277
}, e -> {

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ private void setToolsAndRunAgent(
340340
};
341341

342342
// Fetch MCP tools and handle both success and failure cases
343-
getMcpToolSpecs(mlAgent, client, sdkClient, encryptor, ActionListener.wrap(mcpTools -> {
343+
getMcpToolSpecs(mlAgent, allParams, client, sdkClient, encryptor, ActionListener.wrap(mcpTools -> {
344344
toolSpecs.addAll(mcpTools);
345345
processTools.accept(toolSpecs);
346346
}, e -> {

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutor.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ public McpConnectorExecutor(Connector connector) {
6464
this.connector = (McpConnector) connector;
6565
}
6666

67-
public List<MLToolSpec> getMcpToolSpecs() {
67+
public List<MLToolSpec> getMcpToolSpecs(Map<String, String> requestHeaders) {
6868
String mcpServerUrl = connector.getUrl();
6969
String sseEndpoint = connector.getParameters() != null && connector.getParameters().containsKey(SSE_ENDPOINT_FIELD)
7070
? connector.getParameters().get(SSE_ENDPOINT_FIELD)
@@ -74,11 +74,17 @@ public List<MLToolSpec> getMcpToolSpecs() {
7474
Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout());
7575
Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout());
7676

77+
Map<String, String> mergedHeaders = new HashMap<>();
78+
if (connector.getDecryptedHeaders() != null) {
79+
mergedHeaders.putAll(connector.getDecryptedHeaders());
80+
}
81+
if (requestHeaders != null) {
82+
mergedHeaders.putAll(requestHeaders);
83+
}
84+
7785
Consumer<HttpRequest.Builder> headerConfig = builder -> {
78-
if (connector.getDecryptedHeaders() != null) {
79-
for (Map.Entry<String, String> entry : connector.getDecryptedHeaders().entrySet()) {
80-
builder.header(entry.getKey(), entry.getValue());
81-
}
86+
for (Map.Entry<String, String> entry : mergedHeaders.entrySet()) {
87+
builder.header(entry.getKey(), entry.getValue());
8288
}
8389
};
8490

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpStreamableHttpConnectorExecutor.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ public McpStreamableHttpConnectorExecutor(Connector connector) {
6666
this.connector = (McpStreamableHttpConnector) connector;
6767
}
6868

69-
public List<MLToolSpec> getMcpToolSpecs() {
69+
public List<MLToolSpec> getMcpToolSpecs(Map<String, String> requestHeaders) {
7070
String mcpServerUrl = connector.getUrl();
7171
String endpoint = Optional
7272
.ofNullable(connector.getParameters())
@@ -77,11 +77,17 @@ public List<MLToolSpec> getMcpToolSpecs() {
7777
Duration connectionTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getConnectionTimeout());
7878
Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout());
7979

80+
Map<String, String> mergedHeaders = new HashMap<>();
81+
if (connector.getDecryptedHeaders() != null) {
82+
mergedHeaders.putAll(connector.getDecryptedHeaders());
83+
}
84+
if (requestHeaders != null) {
85+
mergedHeaders.putAll(requestHeaders);
86+
}
87+
8088
Consumer<HttpRequest.Builder> headerConfig = builder -> {
81-
if (connector.getDecryptedHeaders() != null) {
82-
for (Map.Entry<String, String> entry : connector.getDecryptedHeaders().entrySet()) {
83-
builder.header(entry.getKey(), entry.getValue());
84-
}
89+
for (Map.Entry<String, String> entry : mergedHeaders.entrySet()) {
90+
builder.header(entry.getKey(), entry.getValue());
8591
}
8692
};
8793

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2000,4 +2000,147 @@ private void mockMcpStreamableHttpConnector(MockedStatic<Connector> connectorSta
20002000
doNothing().when(mockConnector).decrypt(anyString(), any(), anyString());
20012001
connectorStatic.when(() -> Connector.createConnector(any(XContentParser.class))).thenReturn(mockConnector);
20022002
}
2003+
2004+
@Test
2005+
public void testExtractRequestHeaders_WithValidHeaders() {
2006+
Map<String, String> parameters = new HashMap<>();
2007+
parameters
2008+
.put(
2009+
org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS,
2010+
"{\"Authorization\":\"Bearer token123\",\"Content-Type\":\"application/json\"}"
2011+
);
2012+
2013+
Map<String, String> result = AgentUtils.extractRequestHeaders(parameters);
2014+
2015+
assertEquals(2, result.size());
2016+
assertEquals("Bearer token123", result.get("Authorization"));
2017+
assertEquals("application/json", result.get("Content-Type"));
2018+
}
2019+
2020+
@Test
2021+
public void testExtractRequestHeaders_WithNullParameters() {
2022+
Map<String, String> result = AgentUtils.extractRequestHeaders(null);
2023+
2024+
assertEquals(0, result.size());
2025+
assertEquals(Collections.emptyMap(), result);
2026+
}
2027+
2028+
@Test
2029+
public void testExtractRequestHeaders_WithEmptyHeadersJson() {
2030+
Map<String, String> parameters = new HashMap<>();
2031+
parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, "");
2032+
2033+
Map<String, String> result = AgentUtils.extractRequestHeaders(parameters);
2034+
2035+
assertEquals(0, result.size());
2036+
assertEquals(Collections.emptyMap(), result);
2037+
}
2038+
2039+
@Test
2040+
public void testExtractRequestHeaders_WithNullHeadersJson() {
2041+
Map<String, String> parameters = new HashMap<>();
2042+
parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, null);
2043+
2044+
Map<String, String> result = AgentUtils.extractRequestHeaders(parameters);
2045+
2046+
assertEquals(0, result.size());
2047+
assertEquals(Collections.emptyMap(), result);
2048+
}
2049+
2050+
@Test
2051+
public void testExtractRequestHeaders_WithWhitespaceHeadersJson() {
2052+
Map<String, String> parameters = new HashMap<>();
2053+
parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, " ");
2054+
2055+
Map<String, String> result = AgentUtils.extractRequestHeaders(parameters);
2056+
2057+
assertEquals(0, result.size());
2058+
assertEquals(Collections.emptyMap(), result);
2059+
}
2060+
2061+
@Test
2062+
public void testExtractRequestHeaders_WithInvalidJson() {
2063+
Map<String, String> parameters = new HashMap<>();
2064+
parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, "{invalid json}");
2065+
2066+
Map<String, String> result = AgentUtils.extractRequestHeaders(parameters);
2067+
2068+
assertEquals(0, result.size());
2069+
assertEquals(Collections.emptyMap(), result);
2070+
}
2071+
2072+
@Test
2073+
public void testExtractRequestHeaders_WithEmptyJsonObject() {
2074+
Map<String, String> parameters = new HashMap<>();
2075+
parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, "{}");
2076+
2077+
Map<String, String> result = AgentUtils.extractRequestHeaders(parameters);
2078+
2079+
assertEquals(0, result.size());
2080+
}
2081+
2082+
@Test
2083+
public void testMergeHeaders_BothNonEmpty() {
2084+
Map<String, String> staticHeaders = new HashMap<>();
2085+
staticHeaders.put("X-Static-Header", "static-value");
2086+
staticHeaders.put("Authorization", "Bearer static-token");
2087+
2088+
Map<String, String> requestHeaders = new HashMap<>();
2089+
requestHeaders.put("X-Request-Header", "request-value");
2090+
requestHeaders.put("Authorization", "Bearer request-token");
2091+
2092+
Map<String, String> result = AgentUtils.mergeHeaders(staticHeaders, requestHeaders);
2093+
2094+
assertEquals(3, result.size());
2095+
assertEquals("static-value", result.get("X-Static-Header"));
2096+
assertEquals("request-value", result.get("X-Request-Header"));
2097+
assertEquals("Bearer request-token", result.get("Authorization")); // Request headers override static
2098+
}
2099+
2100+
@Test
2101+
public void testMergeHeaders_OnlyStaticHeaders() {
2102+
Map<String, String> staticHeaders = new HashMap<>();
2103+
staticHeaders.put("X-Static-Header", "static-value");
2104+
2105+
Map<String, String> result = AgentUtils.mergeHeaders(staticHeaders, null);
2106+
2107+
assertEquals(1, result.size());
2108+
assertEquals("static-value", result.get("X-Static-Header"));
2109+
}
2110+
2111+
@Test
2112+
public void testMergeHeaders_OnlyRequestHeaders() {
2113+
Map<String, String> requestHeaders = new HashMap<>();
2114+
requestHeaders.put("X-Request-Header", "request-value");
2115+
2116+
Map<String, String> result = AgentUtils.mergeHeaders(null, requestHeaders);
2117+
2118+
assertEquals(1, result.size());
2119+
assertEquals("request-value", result.get("X-Request-Header"));
2120+
}
2121+
2122+
@Test
2123+
public void testMergeHeaders_BothNull() {
2124+
Map<String, String> result = AgentUtils.mergeHeaders(null, null);
2125+
2126+
assertEquals(0, result.size());
2127+
}
2128+
2129+
@Test
2130+
public void testMergeHeaders_BothEmpty() {
2131+
Map<String, String> result = AgentUtils.mergeHeaders(new HashMap<>(), new HashMap<>());
2132+
2133+
assertEquals(0, result.size());
2134+
}
2135+
2136+
@Test
2137+
public void testMergeHeaders_EmptyStaticNonEmptyRequest() {
2138+
Map<String, String> requestHeaders = new HashMap<>();
2139+
requestHeaders.put("X-Request-Header", "request-value");
2140+
2141+
Map<String, String> result = AgentUtils.mergeHeaders(new HashMap<>(), requestHeaders);
2142+
2143+
assertEquals(1, result.size());
2144+
assertEquals("request-value", result.get("X-Request-Header"));
2145+
}
20032146
}

0 commit comments

Comments
 (0)