|
11 | 11 | import static org.mockito.ArgumentMatchers.anyString; |
12 | 12 | import static org.mockito.ArgumentMatchers.argThat; |
13 | 13 | import static org.mockito.Mockito.doNothing; |
| 14 | +import static org.mockito.Mockito.doThrow; |
14 | 15 | import static org.mockito.Mockito.mock; |
15 | 16 | import static org.mockito.Mockito.verify; |
16 | 17 | import static org.mockito.Mockito.when; |
@@ -2000,4 +2001,338 @@ private void mockMcpStreamableHttpConnector(MockedStatic<Connector> connectorSta |
2000 | 2001 | doNothing().when(mockConnector).decrypt(anyString(), any(), anyString()); |
2001 | 2002 | connectorStatic.when(() -> Connector.createConnector(any(XContentParser.class))).thenReturn(mockConnector); |
2002 | 2003 | } |
| 2004 | + |
| 2005 | + @Test |
| 2006 | + public void testExtractRequestHeaders_WithValidHeaders() { |
| 2007 | + Map<String, String> parameters = new HashMap<>(); |
| 2008 | + parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, "{\"Authorization\":\"Bearer token123\",\"Content-Type\":\"application/json\"}"); |
| 2009 | + |
| 2010 | + Map<String, String> result = AgentUtils.extractRequestHeaders(parameters); |
| 2011 | + |
| 2012 | + assertEquals(2, result.size()); |
| 2013 | + assertEquals("Bearer token123", result.get("Authorization")); |
| 2014 | + assertEquals("application/json", result.get("Content-Type")); |
| 2015 | + } |
| 2016 | + |
| 2017 | + @Test |
| 2018 | + public void testExtractRequestHeaders_WithNullParameters() { |
| 2019 | + Map<String, String> result = AgentUtils.extractRequestHeaders(null); |
| 2020 | + |
| 2021 | + assertEquals(0, result.size()); |
| 2022 | + assertEquals(Collections.emptyMap(), result); |
| 2023 | + } |
| 2024 | + |
| 2025 | + @Test |
| 2026 | + public void testExtractRequestHeaders_WithEmptyHeadersJson() { |
| 2027 | + Map<String, String> parameters = new HashMap<>(); |
| 2028 | + parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, ""); |
| 2029 | + |
| 2030 | + Map<String, String> result = AgentUtils.extractRequestHeaders(parameters); |
| 2031 | + |
| 2032 | + assertEquals(0, result.size()); |
| 2033 | + assertEquals(Collections.emptyMap(), result); |
| 2034 | + } |
| 2035 | + |
| 2036 | + @Test |
| 2037 | + public void testExtractRequestHeaders_WithNullHeadersJson() { |
| 2038 | + Map<String, String> parameters = new HashMap<>(); |
| 2039 | + parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, null); |
| 2040 | + |
| 2041 | + Map<String, String> result = AgentUtils.extractRequestHeaders(parameters); |
| 2042 | + |
| 2043 | + assertEquals(0, result.size()); |
| 2044 | + assertEquals(Collections.emptyMap(), result); |
| 2045 | + } |
| 2046 | + |
| 2047 | + @Test |
| 2048 | + public void testExtractRequestHeaders_WithWhitespaceHeadersJson() { |
| 2049 | + Map<String, String> parameters = new HashMap<>(); |
| 2050 | + parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, " "); |
| 2051 | + |
| 2052 | + Map<String, String> result = AgentUtils.extractRequestHeaders(parameters); |
| 2053 | + |
| 2054 | + assertEquals(0, result.size()); |
| 2055 | + assertEquals(Collections.emptyMap(), result); |
| 2056 | + } |
| 2057 | + |
| 2058 | + @Test |
| 2059 | + public void testExtractRequestHeaders_WithInvalidJson() { |
| 2060 | + Map<String, String> parameters = new HashMap<>(); |
| 2061 | + parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, "{invalid json}"); |
| 2062 | + |
| 2063 | + Map<String, String> result = AgentUtils.extractRequestHeaders(parameters); |
| 2064 | + |
| 2065 | + assertEquals(0, result.size()); |
| 2066 | + assertEquals(Collections.emptyMap(), result); |
| 2067 | + } |
| 2068 | + |
| 2069 | + @Test |
| 2070 | + public void testExtractRequestHeaders_WithEmptyJsonObject() { |
| 2071 | + Map<String, String> parameters = new HashMap<>(); |
| 2072 | + parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, "{}"); |
| 2073 | + |
| 2074 | + Map<String, String> result = AgentUtils.extractRequestHeaders(parameters); |
| 2075 | + |
| 2076 | + assertEquals(0, result.size()); |
| 2077 | + } |
| 2078 | + |
| 2079 | + @Test |
| 2080 | + public void testMergeHeaders_BothNonEmpty() { |
| 2081 | + Map<String, String> staticHeaders = new HashMap<>(); |
| 2082 | + staticHeaders.put("X-Static-Header", "static-value"); |
| 2083 | + staticHeaders.put("Authorization", "Bearer static-token"); |
| 2084 | + |
| 2085 | + Map<String, String> requestHeaders = new HashMap<>(); |
| 2086 | + requestHeaders.put("X-Request-Header", "request-value"); |
| 2087 | + requestHeaders.put("Authorization", "Bearer request-token"); |
| 2088 | + |
| 2089 | + Map<String, String> result = AgentUtils.mergeHeaders(staticHeaders, requestHeaders); |
| 2090 | + |
| 2091 | + assertEquals(3, result.size()); |
| 2092 | + assertEquals("static-value", result.get("X-Static-Header")); |
| 2093 | + assertEquals("request-value", result.get("X-Request-Header")); |
| 2094 | + assertEquals("Bearer request-token", result.get("Authorization")); // Request headers override static |
| 2095 | + } |
| 2096 | + |
| 2097 | + @Test |
| 2098 | + public void testMergeHeaders_OnlyStaticHeaders() { |
| 2099 | + Map<String, String> staticHeaders = new HashMap<>(); |
| 2100 | + staticHeaders.put("X-Static-Header", "static-value"); |
| 2101 | + |
| 2102 | + Map<String, String> result = AgentUtils.mergeHeaders(staticHeaders, null); |
| 2103 | + |
| 2104 | + assertEquals(1, result.size()); |
| 2105 | + assertEquals("static-value", result.get("X-Static-Header")); |
| 2106 | + } |
| 2107 | + |
| 2108 | + @Test |
| 2109 | + public void testMergeHeaders_OnlyRequestHeaders() { |
| 2110 | + Map<String, String> requestHeaders = new HashMap<>(); |
| 2111 | + requestHeaders.put("X-Request-Header", "request-value"); |
| 2112 | + |
| 2113 | + Map<String, String> result = AgentUtils.mergeHeaders(null, requestHeaders); |
| 2114 | + |
| 2115 | + assertEquals(1, result.size()); |
| 2116 | + assertEquals("request-value", result.get("X-Request-Header")); |
| 2117 | + } |
| 2118 | + |
| 2119 | + @Test |
| 2120 | + public void testMergeHeaders_BothNull() { |
| 2121 | + Map<String, String> result = AgentUtils.mergeHeaders(null, null); |
| 2122 | + |
| 2123 | + assertEquals(0, result.size()); |
| 2124 | + } |
| 2125 | + |
| 2126 | + @Test |
| 2127 | + public void testMergeHeaders_BothEmpty() { |
| 2128 | + Map<String, String> result = AgentUtils.mergeHeaders(new HashMap<>(), new HashMap<>()); |
| 2129 | + |
| 2130 | + assertEquals(0, result.size()); |
| 2131 | + } |
| 2132 | + |
| 2133 | + @Test |
| 2134 | + public void testMergeHeaders_EmptyStaticNonEmptyRequest() { |
| 2135 | + Map<String, String> requestHeaders = new HashMap<>(); |
| 2136 | + requestHeaders.put("X-Request-Header", "request-value"); |
| 2137 | + |
| 2138 | + Map<String, String> result = AgentUtils.mergeHeaders(new HashMap<>(), requestHeaders); |
| 2139 | + |
| 2140 | + assertEquals(1, result.size()); |
| 2141 | + assertEquals("request-value", result.get("X-Request-Header")); |
| 2142 | + } |
| 2143 | + |
| 2144 | + @Test |
| 2145 | + public void testCleanupMcpClients_WithNullTools() { |
| 2146 | + // Should not throw exception |
| 2147 | + AgentUtils.cleanupMcpClients(null); |
| 2148 | + } |
| 2149 | + |
| 2150 | + @Test |
| 2151 | + public void testCleanupMcpClients_WithEmptyTools() { |
| 2152 | + Map<String, Tool> tools = new HashMap<>(); |
| 2153 | + |
| 2154 | + // Should not throw exception |
| 2155 | + AgentUtils.cleanupMcpClients(tools); |
| 2156 | + } |
| 2157 | + |
| 2158 | + @Test |
| 2159 | + public void testCleanupMcpClients_WithToolsWithoutMcpClient() { |
| 2160 | + Map<String, Tool> tools = new HashMap<>(); |
| 2161 | + Tool tool = mock(Tool.class); |
| 2162 | + when(tool.getAttributes()).thenReturn(new HashMap<>()); |
| 2163 | + tools.put("tool1", tool); |
| 2164 | + |
| 2165 | + // Should not throw exception |
| 2166 | + AgentUtils.cleanupMcpClients(tools); |
| 2167 | + } |
| 2168 | + |
| 2169 | + @Test |
| 2170 | + public void testCleanupMcpClients_WithToolsWithNullAttributes() { |
| 2171 | + Map<String, Tool> tools = new HashMap<>(); |
| 2172 | + Tool tool = mock(Tool.class); |
| 2173 | + when(tool.getAttributes()).thenReturn(null); |
| 2174 | + tools.put("tool1", tool); |
| 2175 | + |
| 2176 | + // Should not throw exception |
| 2177 | + AgentUtils.cleanupMcpClients(tools); |
| 2178 | + } |
| 2179 | + |
| 2180 | + @Test |
| 2181 | + public void testCleanupMcpClients_WithMcpClient() { |
| 2182 | + Map<String, Tool> tools = new HashMap<>(); |
| 2183 | + Tool tool = mock(Tool.class); |
| 2184 | + io.modelcontextprotocol.client.McpSyncClient mcpClient = mock(io.modelcontextprotocol.client.McpSyncClient.class); |
| 2185 | + |
| 2186 | + Map<String, Object> attributes = new HashMap<>(); |
| 2187 | + attributes.put(org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT, mcpClient); |
| 2188 | + when(tool.getAttributes()).thenReturn(attributes); |
| 2189 | + when(tool.getName()).thenReturn("test-tool"); |
| 2190 | + tools.put("tool1", tool); |
| 2191 | + |
| 2192 | + AgentUtils.cleanupMcpClients(tools); |
| 2193 | + |
| 2194 | + verify(mcpClient).closeGracefully(); |
| 2195 | + } |
| 2196 | + |
| 2197 | + @Test |
| 2198 | + public void testCleanupMcpClients_WithMultipleToolsSameMcpClient() { |
| 2199 | + Map<String, Tool> tools = new HashMap<>(); |
| 2200 | + io.modelcontextprotocol.client.McpSyncClient mcpClient = mock(io.modelcontextprotocol.client.McpSyncClient.class); |
| 2201 | + |
| 2202 | + // Create two tools sharing the same MCP client |
| 2203 | + Tool tool1 = mock(Tool.class); |
| 2204 | + Map<String, Object> attributes1 = new HashMap<>(); |
| 2205 | + attributes1.put(org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT, mcpClient); |
| 2206 | + when(tool1.getAttributes()).thenReturn(attributes1); |
| 2207 | + when(tool1.getName()).thenReturn("test-tool-1"); |
| 2208 | + |
| 2209 | + Tool tool2 = mock(Tool.class); |
| 2210 | + Map<String, Object> attributes2 = new HashMap<>(); |
| 2211 | + attributes2.put(org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT, mcpClient); |
| 2212 | + when(tool2.getAttributes()).thenReturn(attributes2); |
| 2213 | + when(tool2.getName()).thenReturn("test-tool-2"); |
| 2214 | + |
| 2215 | + tools.put("tool1", tool1); |
| 2216 | + tools.put("tool2", tool2); |
| 2217 | + |
| 2218 | + AgentUtils.cleanupMcpClients(tools); |
| 2219 | + |
| 2220 | + // Should only close once even though two tools share the same client |
| 2221 | + verify(mcpClient).closeGracefully(); |
| 2222 | + } |
| 2223 | + |
| 2224 | + @Test |
| 2225 | + public void testCleanupMcpClients_WithMultipleToolsDifferentMcpClients() { |
| 2226 | + Map<String, Tool> tools = new HashMap<>(); |
| 2227 | + io.modelcontextprotocol.client.McpSyncClient mcpClient1 = mock(io.modelcontextprotocol.client.McpSyncClient.class); |
| 2228 | + io.modelcontextprotocol.client.McpSyncClient mcpClient2 = mock(io.modelcontextprotocol.client.McpSyncClient.class); |
| 2229 | + |
| 2230 | + Tool tool1 = mock(Tool.class); |
| 2231 | + Map<String, Object> attributes1 = new HashMap<>(); |
| 2232 | + attributes1.put(org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT, mcpClient1); |
| 2233 | + when(tool1.getAttributes()).thenReturn(attributes1); |
| 2234 | + when(tool1.getName()).thenReturn("test-tool-1"); |
| 2235 | + |
| 2236 | + Tool tool2 = mock(Tool.class); |
| 2237 | + Map<String, Object> attributes2 = new HashMap<>(); |
| 2238 | + attributes2.put(org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT, mcpClient2); |
| 2239 | + when(tool2.getAttributes()).thenReturn(attributes2); |
| 2240 | + when(tool2.getName()).thenReturn("test-tool-2"); |
| 2241 | + |
| 2242 | + tools.put("tool1", tool1); |
| 2243 | + tools.put("tool2", tool2); |
| 2244 | + |
| 2245 | + AgentUtils.cleanupMcpClients(tools); |
| 2246 | + |
| 2247 | + verify(mcpClient1).closeGracefully(); |
| 2248 | + verify(mcpClient2).closeGracefully(); |
| 2249 | + } |
| 2250 | + |
| 2251 | + @Test |
| 2252 | + public void testCleanupMcpClients_WithExceptionDuringClose() { |
| 2253 | + Map<String, Tool> tools = new HashMap<>(); |
| 2254 | + Tool tool = mock(Tool.class); |
| 2255 | + io.modelcontextprotocol.client.McpSyncClient mcpClient = mock(io.modelcontextprotocol.client.McpSyncClient.class); |
| 2256 | + |
| 2257 | + Map<String, Object> attributes = new HashMap<>(); |
| 2258 | + attributes.put(org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT, mcpClient); |
| 2259 | + when(tool.getAttributes()).thenReturn(attributes); |
| 2260 | + when(tool.getName()).thenReturn("test-tool"); |
| 2261 | + tools.put("tool1", tool); |
| 2262 | + |
| 2263 | + // Make closeGracefully throw an exception |
| 2264 | + doThrow(new RuntimeException("Close failed")).when(mcpClient).closeGracefully(); |
| 2265 | + |
| 2266 | + // Should not throw exception, just log warning |
| 2267 | + AgentUtils.cleanupMcpClients(tools); |
| 2268 | + |
| 2269 | + verify(mcpClient).closeGracefully(); |
| 2270 | + } |
| 2271 | + |
| 2272 | + @Test |
| 2273 | + public void testWrapListenerWithMcpCleanup_OnResponse() { |
| 2274 | + Map<String, Tool> tools = new HashMap<>(); |
| 2275 | + Tool tool = mock(Tool.class); |
| 2276 | + io.modelcontextprotocol.client.McpSyncClient mcpClient = mock(io.modelcontextprotocol.client.McpSyncClient.class); |
| 2277 | + |
| 2278 | + Map<String, Object> attributes = new HashMap<>(); |
| 2279 | + attributes.put(org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT, mcpClient); |
| 2280 | + when(tool.getAttributes()).thenReturn(attributes); |
| 2281 | + when(tool.getName()).thenReturn("test-tool"); |
| 2282 | + tools.put("tool1", tool); |
| 2283 | + |
| 2284 | + ActionListener<String> delegate = mock(ActionListener.class); |
| 2285 | + ActionListener<String> wrapped = AgentUtils.wrapListenerWithMcpCleanup(delegate, tools); |
| 2286 | + |
| 2287 | + wrapped.onResponse("test-response"); |
| 2288 | + |
| 2289 | + verify(mcpClient).closeGracefully(); |
| 2290 | + verify(delegate).onResponse("test-response"); |
| 2291 | + } |
| 2292 | + |
| 2293 | + @Test |
| 2294 | + public void testWrapListenerWithMcpCleanup_OnFailure() { |
| 2295 | + Map<String, Tool> tools = new HashMap<>(); |
| 2296 | + Tool tool = mock(Tool.class); |
| 2297 | + io.modelcontextprotocol.client.McpSyncClient mcpClient = mock(io.modelcontextprotocol.client.McpSyncClient.class); |
| 2298 | + |
| 2299 | + Map<String, Object> attributes = new HashMap<>(); |
| 2300 | + attributes.put(org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT, mcpClient); |
| 2301 | + when(tool.getAttributes()).thenReturn(attributes); |
| 2302 | + when(tool.getName()).thenReturn("test-tool"); |
| 2303 | + tools.put("tool1", tool); |
| 2304 | + |
| 2305 | + ActionListener<String> delegate = mock(ActionListener.class); |
| 2306 | + ActionListener<String> wrapped = AgentUtils.wrapListenerWithMcpCleanup(delegate, tools); |
| 2307 | + |
| 2308 | + Exception testException = new RuntimeException("test exception"); |
| 2309 | + wrapped.onFailure(testException); |
| 2310 | + |
| 2311 | + verify(mcpClient).closeGracefully(); |
| 2312 | + verify(delegate).onFailure(testException); |
| 2313 | + } |
| 2314 | + |
| 2315 | + @Test |
| 2316 | + public void testWrapListenerWithMcpCleanup_CleanupFailureDoesNotPreventDelegateCall() { |
| 2317 | + Map<String, Tool> tools = new HashMap<>(); |
| 2318 | + Tool tool = mock(Tool.class); |
| 2319 | + io.modelcontextprotocol.client.McpSyncClient mcpClient = mock(io.modelcontextprotocol.client.McpSyncClient.class); |
| 2320 | + |
| 2321 | + Map<String, Object> attributes = new HashMap<>(); |
| 2322 | + attributes.put(org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT, mcpClient); |
| 2323 | + when(tool.getAttributes()).thenReturn(attributes); |
| 2324 | + when(tool.getName()).thenReturn("test-tool"); |
| 2325 | + tools.put("tool1", tool); |
| 2326 | + |
| 2327 | + // Make cleanup fail |
| 2328 | + doThrow(new RuntimeException("Cleanup failed")).when(mcpClient).closeGracefully(); |
| 2329 | + |
| 2330 | + ActionListener<String> delegate = mock(ActionListener.class); |
| 2331 | + ActionListener<String> wrapped = AgentUtils.wrapListenerWithMcpCleanup(delegate, tools); |
| 2332 | + |
| 2333 | + wrapped.onResponse("test-response"); |
| 2334 | + |
| 2335 | + // Delegate should still be called even if cleanup fails |
| 2336 | + verify(delegate).onResponse("test-response"); |
| 2337 | + } |
2003 | 2338 | } |
0 commit comments