Skip to content

Commit 8821c44

Browse files
committed
add/fix tests
Signed-off-by: Jiaping Zeng <[email protected]>
1 parent 6cd9552 commit 8821c44

File tree

1 file changed

+335
-0
lines changed

1 file changed

+335
-0
lines changed

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java

Lines changed: 335 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import static org.mockito.ArgumentMatchers.anyString;
1212
import static org.mockito.ArgumentMatchers.argThat;
1313
import static org.mockito.Mockito.doNothing;
14+
import static org.mockito.Mockito.doThrow;
1415
import static org.mockito.Mockito.mock;
1516
import static org.mockito.Mockito.verify;
1617
import static org.mockito.Mockito.when;
@@ -2000,4 +2001,338 @@ private void mockMcpStreamableHttpConnector(MockedStatic<Connector> connectorSta
20002001
doNothing().when(mockConnector).decrypt(anyString(), any(), anyString());
20012002
connectorStatic.when(() -> Connector.createConnector(any(XContentParser.class))).thenReturn(mockConnector);
20022003
}
2004+
2005+
@Test
2006+
public void testExtractRequestHeaders_WithValidHeaders() {
2007+
Map<String, String> parameters = new HashMap<>();
2008+
parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, "{\"Authorization\":\"Bearer token123\",\"Content-Type\":\"application/json\"}");
2009+
2010+
Map<String, String> result = AgentUtils.extractRequestHeaders(parameters);
2011+
2012+
assertEquals(2, result.size());
2013+
assertEquals("Bearer token123", result.get("Authorization"));
2014+
assertEquals("application/json", result.get("Content-Type"));
2015+
}
2016+
2017+
@Test
2018+
public void testExtractRequestHeaders_WithNullParameters() {
2019+
Map<String, String> result = AgentUtils.extractRequestHeaders(null);
2020+
2021+
assertEquals(0, result.size());
2022+
assertEquals(Collections.emptyMap(), result);
2023+
}
2024+
2025+
@Test
2026+
public void testExtractRequestHeaders_WithEmptyHeadersJson() {
2027+
Map<String, String> parameters = new HashMap<>();
2028+
parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, "");
2029+
2030+
Map<String, String> result = AgentUtils.extractRequestHeaders(parameters);
2031+
2032+
assertEquals(0, result.size());
2033+
assertEquals(Collections.emptyMap(), result);
2034+
}
2035+
2036+
@Test
2037+
public void testExtractRequestHeaders_WithNullHeadersJson() {
2038+
Map<String, String> parameters = new HashMap<>();
2039+
parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, null);
2040+
2041+
Map<String, String> result = AgentUtils.extractRequestHeaders(parameters);
2042+
2043+
assertEquals(0, result.size());
2044+
assertEquals(Collections.emptyMap(), result);
2045+
}
2046+
2047+
@Test
2048+
public void testExtractRequestHeaders_WithWhitespaceHeadersJson() {
2049+
Map<String, String> parameters = new HashMap<>();
2050+
parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, " ");
2051+
2052+
Map<String, String> result = AgentUtils.extractRequestHeaders(parameters);
2053+
2054+
assertEquals(0, result.size());
2055+
assertEquals(Collections.emptyMap(), result);
2056+
}
2057+
2058+
@Test
2059+
public void testExtractRequestHeaders_WithInvalidJson() {
2060+
Map<String, String> parameters = new HashMap<>();
2061+
parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, "{invalid json}");
2062+
2063+
Map<String, String> result = AgentUtils.extractRequestHeaders(parameters);
2064+
2065+
assertEquals(0, result.size());
2066+
assertEquals(Collections.emptyMap(), result);
2067+
}
2068+
2069+
@Test
2070+
public void testExtractRequestHeaders_WithEmptyJsonObject() {
2071+
Map<String, String> parameters = new HashMap<>();
2072+
parameters.put(org.opensearch.ml.common.CommonValue.MCP_REQUEST_HEADERS, "{}");
2073+
2074+
Map<String, String> result = AgentUtils.extractRequestHeaders(parameters);
2075+
2076+
assertEquals(0, result.size());
2077+
}
2078+
2079+
@Test
2080+
public void testMergeHeaders_BothNonEmpty() {
2081+
Map<String, String> staticHeaders = new HashMap<>();
2082+
staticHeaders.put("X-Static-Header", "static-value");
2083+
staticHeaders.put("Authorization", "Bearer static-token");
2084+
2085+
Map<String, String> requestHeaders = new HashMap<>();
2086+
requestHeaders.put("X-Request-Header", "request-value");
2087+
requestHeaders.put("Authorization", "Bearer request-token");
2088+
2089+
Map<String, String> result = AgentUtils.mergeHeaders(staticHeaders, requestHeaders);
2090+
2091+
assertEquals(3, result.size());
2092+
assertEquals("static-value", result.get("X-Static-Header"));
2093+
assertEquals("request-value", result.get("X-Request-Header"));
2094+
assertEquals("Bearer request-token", result.get("Authorization")); // Request headers override static
2095+
}
2096+
2097+
@Test
2098+
public void testMergeHeaders_OnlyStaticHeaders() {
2099+
Map<String, String> staticHeaders = new HashMap<>();
2100+
staticHeaders.put("X-Static-Header", "static-value");
2101+
2102+
Map<String, String> result = AgentUtils.mergeHeaders(staticHeaders, null);
2103+
2104+
assertEquals(1, result.size());
2105+
assertEquals("static-value", result.get("X-Static-Header"));
2106+
}
2107+
2108+
@Test
2109+
public void testMergeHeaders_OnlyRequestHeaders() {
2110+
Map<String, String> requestHeaders = new HashMap<>();
2111+
requestHeaders.put("X-Request-Header", "request-value");
2112+
2113+
Map<String, String> result = AgentUtils.mergeHeaders(null, requestHeaders);
2114+
2115+
assertEquals(1, result.size());
2116+
assertEquals("request-value", result.get("X-Request-Header"));
2117+
}
2118+
2119+
@Test
2120+
public void testMergeHeaders_BothNull() {
2121+
Map<String, String> result = AgentUtils.mergeHeaders(null, null);
2122+
2123+
assertEquals(0, result.size());
2124+
}
2125+
2126+
@Test
2127+
public void testMergeHeaders_BothEmpty() {
2128+
Map<String, String> result = AgentUtils.mergeHeaders(new HashMap<>(), new HashMap<>());
2129+
2130+
assertEquals(0, result.size());
2131+
}
2132+
2133+
@Test
2134+
public void testMergeHeaders_EmptyStaticNonEmptyRequest() {
2135+
Map<String, String> requestHeaders = new HashMap<>();
2136+
requestHeaders.put("X-Request-Header", "request-value");
2137+
2138+
Map<String, String> result = AgentUtils.mergeHeaders(new HashMap<>(), requestHeaders);
2139+
2140+
assertEquals(1, result.size());
2141+
assertEquals("request-value", result.get("X-Request-Header"));
2142+
}
2143+
2144+
@Test
2145+
public void testCleanupMcpClients_WithNullTools() {
2146+
// Should not throw exception
2147+
AgentUtils.cleanupMcpClients(null);
2148+
}
2149+
2150+
@Test
2151+
public void testCleanupMcpClients_WithEmptyTools() {
2152+
Map<String, Tool> tools = new HashMap<>();
2153+
2154+
// Should not throw exception
2155+
AgentUtils.cleanupMcpClients(tools);
2156+
}
2157+
2158+
@Test
2159+
public void testCleanupMcpClients_WithToolsWithoutMcpClient() {
2160+
Map<String, Tool> tools = new HashMap<>();
2161+
Tool tool = mock(Tool.class);
2162+
when(tool.getAttributes()).thenReturn(new HashMap<>());
2163+
tools.put("tool1", tool);
2164+
2165+
// Should not throw exception
2166+
AgentUtils.cleanupMcpClients(tools);
2167+
}
2168+
2169+
@Test
2170+
public void testCleanupMcpClients_WithToolsWithNullAttributes() {
2171+
Map<String, Tool> tools = new HashMap<>();
2172+
Tool tool = mock(Tool.class);
2173+
when(tool.getAttributes()).thenReturn(null);
2174+
tools.put("tool1", tool);
2175+
2176+
// Should not throw exception
2177+
AgentUtils.cleanupMcpClients(tools);
2178+
}
2179+
2180+
@Test
2181+
public void testCleanupMcpClients_WithMcpClient() {
2182+
Map<String, Tool> tools = new HashMap<>();
2183+
Tool tool = mock(Tool.class);
2184+
io.modelcontextprotocol.client.McpSyncClient mcpClient = mock(io.modelcontextprotocol.client.McpSyncClient.class);
2185+
2186+
Map<String, Object> attributes = new HashMap<>();
2187+
attributes.put(org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT, mcpClient);
2188+
when(tool.getAttributes()).thenReturn(attributes);
2189+
when(tool.getName()).thenReturn("test-tool");
2190+
tools.put("tool1", tool);
2191+
2192+
AgentUtils.cleanupMcpClients(tools);
2193+
2194+
verify(mcpClient).closeGracefully();
2195+
}
2196+
2197+
@Test
2198+
public void testCleanupMcpClients_WithMultipleToolsSameMcpClient() {
2199+
Map<String, Tool> tools = new HashMap<>();
2200+
io.modelcontextprotocol.client.McpSyncClient mcpClient = mock(io.modelcontextprotocol.client.McpSyncClient.class);
2201+
2202+
// Create two tools sharing the same MCP client
2203+
Tool tool1 = mock(Tool.class);
2204+
Map<String, Object> attributes1 = new HashMap<>();
2205+
attributes1.put(org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT, mcpClient);
2206+
when(tool1.getAttributes()).thenReturn(attributes1);
2207+
when(tool1.getName()).thenReturn("test-tool-1");
2208+
2209+
Tool tool2 = mock(Tool.class);
2210+
Map<String, Object> attributes2 = new HashMap<>();
2211+
attributes2.put(org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT, mcpClient);
2212+
when(tool2.getAttributes()).thenReturn(attributes2);
2213+
when(tool2.getName()).thenReturn("test-tool-2");
2214+
2215+
tools.put("tool1", tool1);
2216+
tools.put("tool2", tool2);
2217+
2218+
AgentUtils.cleanupMcpClients(tools);
2219+
2220+
// Should only close once even though two tools share the same client
2221+
verify(mcpClient).closeGracefully();
2222+
}
2223+
2224+
@Test
2225+
public void testCleanupMcpClients_WithMultipleToolsDifferentMcpClients() {
2226+
Map<String, Tool> tools = new HashMap<>();
2227+
io.modelcontextprotocol.client.McpSyncClient mcpClient1 = mock(io.modelcontextprotocol.client.McpSyncClient.class);
2228+
io.modelcontextprotocol.client.McpSyncClient mcpClient2 = mock(io.modelcontextprotocol.client.McpSyncClient.class);
2229+
2230+
Tool tool1 = mock(Tool.class);
2231+
Map<String, Object> attributes1 = new HashMap<>();
2232+
attributes1.put(org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT, mcpClient1);
2233+
when(tool1.getAttributes()).thenReturn(attributes1);
2234+
when(tool1.getName()).thenReturn("test-tool-1");
2235+
2236+
Tool tool2 = mock(Tool.class);
2237+
Map<String, Object> attributes2 = new HashMap<>();
2238+
attributes2.put(org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT, mcpClient2);
2239+
when(tool2.getAttributes()).thenReturn(attributes2);
2240+
when(tool2.getName()).thenReturn("test-tool-2");
2241+
2242+
tools.put("tool1", tool1);
2243+
tools.put("tool2", tool2);
2244+
2245+
AgentUtils.cleanupMcpClients(tools);
2246+
2247+
verify(mcpClient1).closeGracefully();
2248+
verify(mcpClient2).closeGracefully();
2249+
}
2250+
2251+
@Test
2252+
public void testCleanupMcpClients_WithExceptionDuringClose() {
2253+
Map<String, Tool> tools = new HashMap<>();
2254+
Tool tool = mock(Tool.class);
2255+
io.modelcontextprotocol.client.McpSyncClient mcpClient = mock(io.modelcontextprotocol.client.McpSyncClient.class);
2256+
2257+
Map<String, Object> attributes = new HashMap<>();
2258+
attributes.put(org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT, mcpClient);
2259+
when(tool.getAttributes()).thenReturn(attributes);
2260+
when(tool.getName()).thenReturn("test-tool");
2261+
tools.put("tool1", tool);
2262+
2263+
// Make closeGracefully throw an exception
2264+
doThrow(new RuntimeException("Close failed")).when(mcpClient).closeGracefully();
2265+
2266+
// Should not throw exception, just log warning
2267+
AgentUtils.cleanupMcpClients(tools);
2268+
2269+
verify(mcpClient).closeGracefully();
2270+
}
2271+
2272+
@Test
2273+
public void testWrapListenerWithMcpCleanup_OnResponse() {
2274+
Map<String, Tool> tools = new HashMap<>();
2275+
Tool tool = mock(Tool.class);
2276+
io.modelcontextprotocol.client.McpSyncClient mcpClient = mock(io.modelcontextprotocol.client.McpSyncClient.class);
2277+
2278+
Map<String, Object> attributes = new HashMap<>();
2279+
attributes.put(org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT, mcpClient);
2280+
when(tool.getAttributes()).thenReturn(attributes);
2281+
when(tool.getName()).thenReturn("test-tool");
2282+
tools.put("tool1", tool);
2283+
2284+
ActionListener<String> delegate = mock(ActionListener.class);
2285+
ActionListener<String> wrapped = AgentUtils.wrapListenerWithMcpCleanup(delegate, tools);
2286+
2287+
wrapped.onResponse("test-response");
2288+
2289+
verify(mcpClient).closeGracefully();
2290+
verify(delegate).onResponse("test-response");
2291+
}
2292+
2293+
@Test
2294+
public void testWrapListenerWithMcpCleanup_OnFailure() {
2295+
Map<String, Tool> tools = new HashMap<>();
2296+
Tool tool = mock(Tool.class);
2297+
io.modelcontextprotocol.client.McpSyncClient mcpClient = mock(io.modelcontextprotocol.client.McpSyncClient.class);
2298+
2299+
Map<String, Object> attributes = new HashMap<>();
2300+
attributes.put(org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT, mcpClient);
2301+
when(tool.getAttributes()).thenReturn(attributes);
2302+
when(tool.getName()).thenReturn("test-tool");
2303+
tools.put("tool1", tool);
2304+
2305+
ActionListener<String> delegate = mock(ActionListener.class);
2306+
ActionListener<String> wrapped = AgentUtils.wrapListenerWithMcpCleanup(delegate, tools);
2307+
2308+
Exception testException = new RuntimeException("test exception");
2309+
wrapped.onFailure(testException);
2310+
2311+
verify(mcpClient).closeGracefully();
2312+
verify(delegate).onFailure(testException);
2313+
}
2314+
2315+
@Test
2316+
public void testWrapListenerWithMcpCleanup_CleanupFailureDoesNotPreventDelegateCall() {
2317+
Map<String, Tool> tools = new HashMap<>();
2318+
Tool tool = mock(Tool.class);
2319+
io.modelcontextprotocol.client.McpSyncClient mcpClient = mock(io.modelcontextprotocol.client.McpSyncClient.class);
2320+
2321+
Map<String, Object> attributes = new HashMap<>();
2322+
attributes.put(org.opensearch.ml.common.CommonValue.MCP_SYNC_CLIENT, mcpClient);
2323+
when(tool.getAttributes()).thenReturn(attributes);
2324+
when(tool.getName()).thenReturn("test-tool");
2325+
tools.put("tool1", tool);
2326+
2327+
// Make cleanup fail
2328+
doThrow(new RuntimeException("Cleanup failed")).when(mcpClient).closeGracefully();
2329+
2330+
ActionListener<String> delegate = mock(ActionListener.class);
2331+
ActionListener<String> wrapped = AgentUtils.wrapListenerWithMcpCleanup(delegate, tools);
2332+
2333+
wrapped.onResponse("test-response");
2334+
2335+
// Delegate should still be called even if cleanup fails
2336+
verify(delegate).onResponse("test-response");
2337+
}
20032338
}

0 commit comments

Comments
 (0)