|
43 | 43 | import java.util.Collection; |
44 | 44 | import java.util.Collections; |
45 | 45 | import java.util.HashMap; |
| 46 | +import java.util.HashSet; |
46 | 47 | import java.util.List; |
47 | 48 | import java.util.Locale; |
48 | 49 | import java.util.Map; |
|
69 | 70 | import org.opensearch.core.xcontent.NamedXContentRegistry; |
70 | 71 | import org.opensearch.core.xcontent.XContentParser; |
71 | 72 | import org.opensearch.index.IndexNotFoundException; |
| 73 | +import org.opensearch.ml.common.CommonValue; |
72 | 74 | import org.opensearch.ml.common.agent.MLAgent; |
73 | 75 | import org.opensearch.ml.common.agent.MLToolSpec; |
74 | 76 | import org.opensearch.ml.common.connector.Connector; |
| 77 | +import org.opensearch.ml.common.connector.ConnectorClientConfig; |
75 | 78 | import org.opensearch.ml.common.connector.McpConnector; |
76 | 79 | import org.opensearch.ml.common.connector.McpStreamableHttpConnector; |
77 | 80 | import org.opensearch.ml.common.output.model.ModelTensor; |
|
95 | 98 | import com.jayway.jsonpath.JsonPath; |
96 | 99 | import com.jayway.jsonpath.PathNotFoundException; |
97 | 100 |
|
| 101 | +import io.modelcontextprotocol.client.McpSyncClient; |
98 | 102 | import lombok.extern.log4j.Log4j2; |
99 | 103 |
|
100 | 104 | @Log4j2 |
@@ -144,6 +148,56 @@ public class AgentUtils { |
144 | 148 | public static final String DEFAULT_DATETIME_PREFIX = "Current date and time: "; |
145 | 149 | private static final ZoneId UTC_ZONE = ZoneId.of("UTC"); |
146 | 150 |
|
| 151 | + private static void storeConnectorConfigInToolSpecs( |
| 152 | + List<MLToolSpec> toolSpecs, |
| 153 | + Connector connector, |
| 154 | + ConnectorClientConfig config, |
| 155 | + String connectorId |
| 156 | + ) { |
| 157 | + for (MLToolSpec toolSpec : toolSpecs) { |
| 158 | + toolSpec.addRuntimeResource(CommonValue.MCP_CONNECTOR, connector); |
| 159 | + toolSpec.addRuntimeResource(CommonValue.MCP_CONNECTOR_CONFIG, config); |
| 160 | + toolSpec.addRuntimeResource(CommonValue.MCP_CONNECTOR_ID_FIELD, connectorId); |
| 161 | + } |
| 162 | + } |
| 163 | + |
| 164 | + public static Map<String, String> extractRequestHeaders(Map<String, String> parameters) { |
| 165 | + if (parameters == null) { |
| 166 | + return Collections.emptyMap(); |
| 167 | + } |
| 168 | + |
| 169 | + String headersJson = parameters.get(CommonValue.MCP_REQUEST_HEADERS); |
| 170 | + if (headersJson == null || headersJson.trim().isEmpty()) { |
| 171 | + return Collections.emptyMap(); |
| 172 | + } |
| 173 | + |
| 174 | + try { |
| 175 | + Type mapType = new TypeToken<Map<String, String>>() { |
| 176 | + }.getType(); |
| 177 | + Map<String, String> headers = gson.fromJson(headersJson, mapType); |
| 178 | + return headers != null ? headers : Collections.emptyMap(); |
| 179 | + } catch (Exception e) { |
| 180 | + log.warn("Failed to parse request headers from JSON: {}", headersJson, e); |
| 181 | + return Collections.emptyMap(); |
| 182 | + } |
| 183 | + } |
| 184 | + |
| 185 | + public static Map<String, String> mergeHeaders(Map<String, String> staticHeaders, Map<String, String> requestHeaders) { |
| 186 | + Map<String, String> mergedHeaders = new HashMap<>(); |
| 187 | + |
| 188 | + // Add static headers first |
| 189 | + if (staticHeaders != null && !staticHeaders.isEmpty()) { |
| 190 | + mergedHeaders.putAll(staticHeaders); |
| 191 | + } |
| 192 | + |
| 193 | + // Add request headers, overriding any static headers with the same key |
| 194 | + if (requestHeaders != null && !requestHeaders.isEmpty()) { |
| 195 | + mergedHeaders.putAll(requestHeaders); |
| 196 | + } |
| 197 | + |
| 198 | + return mergedHeaders; |
| 199 | + } |
| 200 | + |
147 | 201 | public static String addExamplesToPrompt(Map<String, String> parameters, String prompt) { |
148 | 202 | Map<String, String> examplesMap = new HashMap<>(); |
149 | 203 | if (parameters.containsKey(EXAMPLES)) { |
@@ -762,13 +816,15 @@ private static void getMCPToolSpecsFromConnector( |
762 | 816 | McpConnectorExecutor connectorExecutor = MLEngineClassLoader |
763 | 817 | .initInstance(connector.getProtocol(), connector, Connector.class); |
764 | 818 | mcpToolSpecs = connectorExecutor.getMcpToolSpecs(); |
| 819 | + storeConnectorConfigInToolSpecs(mcpToolSpecs, connector, connectorExecutor.getConnectorClientConfig(), connectorId); |
765 | 820 | toolListener.onResponse(mcpToolSpecs); |
766 | 821 | return; |
767 | 822 | } |
768 | 823 | if (connector instanceof McpStreamableHttpConnector) { |
769 | 824 | McpStreamableHttpConnectorExecutor connectorExecutor = MLEngineClassLoader |
770 | 825 | .initInstance(connector.getProtocol(), connector, Connector.class); |
771 | 826 | mcpToolSpecs = connectorExecutor.getMcpToolSpecs(); |
| 827 | + storeConnectorConfigInToolSpecs(mcpToolSpecs, connector, connectorExecutor.getConnectorClientConfig(), connectorId); |
772 | 828 | toolListener.onResponse(mcpToolSpecs); |
773 | 829 | return; |
774 | 830 | } |
@@ -857,19 +913,52 @@ public static void createTools( |
857 | 913 | if (toolSpecs == null) { |
858 | 914 | return; |
859 | 915 | } |
| 916 | + |
| 917 | + Map<String, String> requestHeaders = extractRequestHeaders(params); |
| 918 | + Map<String, McpSyncClient> mcpClients = new HashMap<>(); |
| 919 | + |
860 | 920 | for (MLToolSpec toolSpec : toolSpecs) { |
861 | 921 | Map<String, String> toolParams = buildToolParameters(params, toolSpec, mlAgent.getTenantId()); |
862 | 922 | Tool tool = createTool(toolFactories, toolParams, toolSpec); |
863 | 923 | tools.put(tool.getName(), tool); |
| 924 | + |
864 | 925 | if (toolSpec.getAttributes() != null) { |
865 | 926 | if (tool.getAttributes() == null) { |
866 | | - Map<String, Object> attributes = new HashMap<>(); |
867 | | - attributes.putAll(toolSpec.getAttributes()); |
868 | | - tool.setAttributes(attributes); |
| 927 | + tool.setAttributes(new HashMap<>(toolSpec.getAttributes())); |
869 | 928 | } else { |
870 | 929 | tool.getAttributes().putAll(toolSpec.getAttributes()); |
871 | 930 | } |
872 | 931 | } |
| 932 | + Map<String, Object> runtimeResources = toolSpec.getRuntimeResources(); |
| 933 | + if (runtimeResources != null) { |
| 934 | + if (tool.getAttributes() == null) { |
| 935 | + tool.setAttributes(new HashMap<>(runtimeResources)); |
| 936 | + } else { |
| 937 | + tool.getAttributes().putAll(runtimeResources); |
| 938 | + } |
| 939 | + |
| 940 | + Connector connector = (Connector) runtimeResources.get(CommonValue.MCP_CONNECTOR); |
| 941 | + ConnectorClientConfig config = (ConnectorClientConfig) runtimeResources.get(CommonValue.MCP_CONNECTOR_CONFIG); |
| 942 | + String connectorId = (String) runtimeResources.get(CommonValue.MCP_CONNECTOR_ID_FIELD); |
| 943 | + if (connector != null && config != null && connectorId != null) { |
| 944 | + McpSyncClient mcpClient = mcpClients.get(connectorId); |
| 945 | + if (mcpClient == null) { |
| 946 | + if (connector instanceof McpStreamableHttpConnector) { |
| 947 | + mcpClient = McpStreamableHttpTool |
| 948 | + .createMcpClient((McpStreamableHttpConnector) connector, config, requestHeaders); |
| 949 | + } else if (connector instanceof McpConnector) { |
| 950 | + mcpClient = McpSseTool.createMcpClient((McpConnector) connector, config, requestHeaders); |
| 951 | + } |
| 952 | + if (mcpClient != null) { |
| 953 | + mcpClients.put(connectorId, mcpClient); |
| 954 | + log.info("Created MCP client for connector: {}", connectorId); |
| 955 | + } |
| 956 | + } |
| 957 | + if (mcpClient != null) { |
| 958 | + tool.getAttributes().put(CommonValue.MCP_SYNC_CLIENT, mcpClient); |
| 959 | + } |
| 960 | + } |
| 961 | + } |
873 | 962 | toolSpecMap.put(tool.getName(), toolSpec); |
874 | 963 | } |
875 | 964 | } |
@@ -928,19 +1017,6 @@ public static Map<String, String> constructToolParams( |
928 | 1017 | return toolParams; |
929 | 1018 | } |
930 | 1019 |
|
931 | | - public static void cleanUpResource(Map<String, Tool> tools) { |
932 | | - for (Map.Entry<String, Tool> entry : tools.entrySet()) { |
933 | | - Tool tool = entry.getValue(); |
934 | | - if (tool instanceof McpSseTool) { |
935 | | - // TODO: make this more general, avoid checking specific tool type |
936 | | - ((McpSseTool) tool).getMcpSyncClient().closeGracefully(); |
937 | | - } else if (tool instanceof McpStreamableHttpTool) { |
938 | | - // TODO: make this more general, avoid checking specific tool type |
939 | | - ((McpStreamableHttpTool) tool).getMcpSyncClient().closeGracefully(); |
940 | | - } |
941 | | - } |
942 | | - } |
943 | | - |
944 | 1020 | /** |
945 | 1021 | * Generates a formatted current date and time string in UTC timezone. |
946 | 1022 | * |
@@ -1014,4 +1090,48 @@ public static Tool createTool(Map<String, Tool.Factory> toolFactories, Map<Strin |
1014 | 1090 |
|
1015 | 1091 | return tool; |
1016 | 1092 | } |
| 1093 | + |
| 1094 | + public static void cleanupMcpClients(Map<String, Tool> tools) { |
| 1095 | + if (tools == null || tools.isEmpty()) { |
| 1096 | + return; |
| 1097 | + } |
| 1098 | + Set<McpSyncClient> closedClients = new HashSet<>(); |
| 1099 | + for (Tool tool : tools.values()) { |
| 1100 | + if (tool.getAttributes() != null) { |
| 1101 | + Object clientObj = tool.getAttributes().get(CommonValue.MCP_SYNC_CLIENT); |
| 1102 | + if (clientObj instanceof McpSyncClient) { |
| 1103 | + McpSyncClient client = (McpSyncClient) clientObj; |
| 1104 | + if (!closedClients.contains(client)) { |
| 1105 | + try { |
| 1106 | + client.closeGracefully(); |
| 1107 | + closedClients.add(client); |
| 1108 | + log.info("Successfully closed MCP client for tool: {}", tool.getName()); |
| 1109 | + } catch (Exception e) { |
| 1110 | + log.warn("Failed to close MCP client gracefully for tool: {}. Error: {}", tool.getName(), e.getMessage()); |
| 1111 | + } |
| 1112 | + } |
| 1113 | + } |
| 1114 | + } |
| 1115 | + } |
| 1116 | + if (!closedClients.isEmpty()) { |
| 1117 | + log.info("Cleaned up {} MCP client(s)", closedClients.size()); |
| 1118 | + } |
| 1119 | + } |
| 1120 | + |
| 1121 | + public static <T> ActionListener<T> wrapListenerWithMcpCleanup(ActionListener<T> delegate, Map<String, Tool> tools) { |
| 1122 | + return ActionListener.wrap(response -> { |
| 1123 | + try { |
| 1124 | + cleanupMcpClients(tools); |
| 1125 | + } finally { |
| 1126 | + delegate.onResponse(response); |
| 1127 | + } |
| 1128 | + }, exception -> { |
| 1129 | + try { |
| 1130 | + cleanupMcpClients(tools); |
| 1131 | + } finally { |
| 1132 | + delegate.onFailure(exception); |
| 1133 | + } |
| 1134 | + }); |
| 1135 | + } |
| 1136 | + |
1017 | 1137 | } |
0 commit comments