Skip to content

Commit c14d0c5

Browse files
committed
fix tests
Signed-off-by: Jiaping Zeng <[email protected]>
1 parent a00ed66 commit c14d0c5

File tree

3 files changed

+22
-84
lines changed

3 files changed

+22
-84
lines changed

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/AgentUtilsTest.java

Lines changed: 14 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,7 +1305,7 @@ public void testGetMcpToolSpecs_NoMcpJsonConfig() {
13051305
when(mlAgent.getParameters()).thenReturn(null);
13061306

13071307
ActionListener<List<MLToolSpec>> listener = mock(ActionListener.class);
1308-
AgentUtils.getMcpToolSpecs(mlAgent, client, sdkClient, encryptor, listener);
1308+
AgentUtils.getMcpToolSpecs(mlAgent, new HashMap<>(), client, sdkClient, encryptor, listener);
13091309

13101310
verify(listener).onResponse(Collections.emptyList());
13111311
}
@@ -1322,14 +1322,14 @@ public void testGetMcpToolSpecs_SingleConnectorSuccess() throws Exception {
13221322
// mock McpConnector, McpConnectorExecutor, agent, and listener
13231323
mockMcpConnector(connStatic);
13241324
McpConnectorExecutor exec = mock(McpConnectorExecutor.class);
1325-
when(exec.getMcpToolSpecs()).thenReturn(expected);
1325+
when(exec.getMcpToolSpecs(any())).thenReturn(expected);
13261326
loadStatic.when(() -> MLEngineClassLoader.initInstance(anyString(), any(), any())).thenReturn(exec);
13271327

13281328
MLAgent mlAgent = mockAgent("[{\"" + MCP_CONNECTOR_ID_FIELD + "\":\"c1\"}]", "tenant");
13291329
ActionListener<List<MLToolSpec>> listener = mock(ActionListener.class);
13301330

13311331
// run and verify
1332-
AgentUtils.getMcpToolSpecs(mlAgent, client, sdkClient, null, listener);
1332+
AgentUtils.getMcpToolSpecs(mlAgent, new HashMap<>(), client, sdkClient, null, listener);
13331333
verify(listener).onResponse(expected);
13341334
}
13351335
}
@@ -1348,7 +1348,7 @@ public void testGetMcpToolSpecs_ToolFilterApplied() throws Exception {
13481348
mockMcpConnector(connStatic);
13491349

13501350
McpConnectorExecutor exec = mock(McpConnectorExecutor.class);
1351-
when(exec.getMcpToolSpecs()).thenReturn(repo);
1351+
when(exec.getMcpToolSpecs(any())).thenReturn(repo);
13521352
loadStatic.when(() -> MLEngineClassLoader.initInstance(anyString(), any(), any())).thenReturn(exec);
13531353

13541354
String mcpJsonConfig = "[{\""
@@ -1361,7 +1361,7 @@ public void testGetMcpToolSpecs_ToolFilterApplied() throws Exception {
13611361
ActionListener<List<MLToolSpec>> listener = mock(ActionListener.class);
13621362

13631363
// run and verify
1364-
AgentUtils.getMcpToolSpecs(agent, client, sdkClient, null, listener);
1364+
AgentUtils.getMcpToolSpecs(agent, new HashMap<>(), client, sdkClient, null, listener);
13651365
verify(listener).onResponse(expected);
13661366
}
13671367
}
@@ -1384,7 +1384,7 @@ public void testGetMcpToolSpecs_MultipleConnectorsMerged() throws Exception {
13841384
mockMcpConnector(connStatic);
13851385

13861386
McpConnectorExecutor exec = mock(McpConnectorExecutor.class);
1387-
when(exec.getMcpToolSpecs()).thenReturn(aTools, bTools);
1387+
when(exec.getMcpToolSpecs(any())).thenReturn(aTools, bTools);
13881388
loadStatic.when(() -> MLEngineClassLoader.initInstance(anyString(), any(), any())).thenReturn(exec);
13891389

13901390
String mcpJsonConfig = "[{\"" + MCP_CONNECTOR_ID_FIELD + "\":\"A\"}," + "{\"" + MCP_CONNECTOR_ID_FIELD + "\":\"B\"}]";
@@ -1393,7 +1393,7 @@ public void testGetMcpToolSpecs_MultipleConnectorsMerged() throws Exception {
13931393
ActionListener<List<MLToolSpec>> listener = mock(ActionListener.class);
13941394

13951395
// run and verify
1396-
AgentUtils.getMcpToolSpecs(agent, client, sdkClient, null, listener);
1396+
AgentUtils.getMcpToolSpecs(agent, new HashMap<>(), client, sdkClient, null, listener);
13971397
verify(listener).onResponse(expected);
13981398
}
13991399
}
@@ -1407,7 +1407,7 @@ public void testGetMcpToolSpecs_NonMcpConnectorReturnsEmpty() throws Exception {
14071407
MLAgent agent = mockAgent("[{\"" + MCP_CONNECTOR_ID_FIELD + "\":\"c1\"}]", "tenant");
14081408

14091409
ActionListener<List<MLToolSpec>> listener = mock(ActionListener.class);
1410-
AgentUtils.getMcpToolSpecs(agent, client, sdkClient, null, listener);
1410+
AgentUtils.getMcpToolSpecs(agent, new HashMap<>(), client, sdkClient, null, listener);
14111411

14121412
verify(listener).onResponse(Collections.emptyList());
14131413
}
@@ -1916,14 +1916,14 @@ public void testGetMcpToolSpecs_McpStreamableHttpConnectorSuccess() throws Excep
19161916
// mock McpStreamableHttpConnector, McpStreamableHttpConnectorExecutor, agent, and listener
19171917
mockMcpStreamableHttpConnector(connStatic);
19181918
McpStreamableHttpConnectorExecutor exec = mock(McpStreamableHttpConnectorExecutor.class);
1919-
when(exec.getMcpToolSpecs()).thenReturn(expected);
1919+
when(exec.getMcpToolSpecs(any())).thenReturn(expected);
19201920
loadStatic.when(() -> MLEngineClassLoader.initInstance(anyString(), any(), any())).thenReturn(exec);
19211921

19221922
MLAgent mlAgent = mockAgent("[{\"" + MCP_CONNECTOR_ID_FIELD + "\":\"c1\"}]", "tenant");
19231923
ActionListener<List<MLToolSpec>> listener = mock(ActionListener.class);
19241924

19251925
// run and verify
1926-
AgentUtils.getMcpToolSpecs(mlAgent, client, sdkClient, null, listener);
1926+
AgentUtils.getMcpToolSpecs(mlAgent, new HashMap<>(), client, sdkClient, null, listener);
19271927
verify(listener).onResponse(expected);
19281928
}
19291929
}
@@ -1941,7 +1941,7 @@ public void testGetMcpToolSpecs_UnsupportedConnectorType() throws Exception {
19411941
MLAgent agent = mockAgent("[{\"" + MCP_CONNECTOR_ID_FIELD + "\":\"c1\"}]", "tenant");
19421942
ActionListener<List<MLToolSpec>> listener = mock(ActionListener.class);
19431943

1944-
AgentUtils.getMcpToolSpecs(agent, client, sdkClient, null, listener);
1944+
AgentUtils.getMcpToolSpecs(agent, new HashMap<>(), client, sdkClient, null, listener);
19451945
verify(listener).onResponse(Collections.emptyList());
19461946
}
19471947
}
@@ -1957,14 +1957,14 @@ public void testGetMcpToolSpecs_ExceptionInGetMcpToolSpecs() throws Exception {
19571957
// mock McpConnector, McpConnectorExecutor, agent, and listener
19581958
mockMcpConnector(connStatic);
19591959
McpConnectorExecutor exec = mock(McpConnectorExecutor.class);
1960-
when(exec.getMcpToolSpecs()).thenThrow(new RuntimeException("Test exception"));
1960+
when(exec.getMcpToolSpecs(any())).thenThrow(new RuntimeException("Test exception"));
19611961
loadStatic.when(() -> MLEngineClassLoader.initInstance(anyString(), any(), any())).thenReturn(exec);
19621962

19631963
MLAgent mlAgent = mockAgent("[{\"" + MCP_CONNECTOR_ID_FIELD + "\":\"c1\"}]", "tenant");
19641964
ActionListener<List<MLToolSpec>> listener = mock(ActionListener.class);
19651965

19661966
// run and verify
1967-
AgentUtils.getMcpToolSpecs(mlAgent, client, sdkClient, null, listener);
1967+
AgentUtils.getMcpToolSpecs(mlAgent, new HashMap<>(), client, sdkClient, null, listener);
19681968
verify(listener).onResponse(Collections.emptyList());
19691969
}
19701970
}
@@ -1989,7 +1989,7 @@ public void testGetMcpToolSpecs_ExceptionInGetConnector() throws Exception {
19891989
MLAgent mlAgent = mockAgent("[{\"" + MCP_CONNECTOR_ID_FIELD + "\":\"c1\"}]", "tenant");
19901990
ActionListener<List<MLToolSpec>> listener = mock(ActionListener.class);
19911991

1992-
AgentUtils.getMcpToolSpecs(mlAgent, client, sdkClient, null, listener);
1992+
AgentUtils.getMcpToolSpecs(mlAgent, new HashMap<>(), client, sdkClient, null, listener);
19931993
verify(listener).onResponse(Collections.emptyList());
19941994
}
19951995

@@ -2079,68 +2079,4 @@ public void testExtractRequestHeaders_WithEmptyJsonObject() {
20792079
assertEquals(0, result.size());
20802080
}
20812081

2082-
@Test
2083-
public void testMergeHeaders_BothNonEmpty() {
2084-
Map<String, String> staticHeaders = new HashMap<>();
2085-
staticHeaders.put("X-Static-Header", "static-value");
2086-
staticHeaders.put("Authorization", "Bearer static-token");
2087-
2088-
Map<String, String> requestHeaders = new HashMap<>();
2089-
requestHeaders.put("X-Request-Header", "request-value");
2090-
requestHeaders.put("Authorization", "Bearer request-token");
2091-
2092-
Map<String, String> result = AgentUtils.mergeHeaders(staticHeaders, requestHeaders);
2093-
2094-
assertEquals(3, result.size());
2095-
assertEquals("static-value", result.get("X-Static-Header"));
2096-
assertEquals("request-value", result.get("X-Request-Header"));
2097-
assertEquals("Bearer request-token", result.get("Authorization")); // Request headers override static
2098-
}
2099-
2100-
@Test
2101-
public void testMergeHeaders_OnlyStaticHeaders() {
2102-
Map<String, String> staticHeaders = new HashMap<>();
2103-
staticHeaders.put("X-Static-Header", "static-value");
2104-
2105-
Map<String, String> result = AgentUtils.mergeHeaders(staticHeaders, null);
2106-
2107-
assertEquals(1, result.size());
2108-
assertEquals("static-value", result.get("X-Static-Header"));
2109-
}
2110-
2111-
@Test
2112-
public void testMergeHeaders_OnlyRequestHeaders() {
2113-
Map<String, String> requestHeaders = new HashMap<>();
2114-
requestHeaders.put("X-Request-Header", "request-value");
2115-
2116-
Map<String, String> result = AgentUtils.mergeHeaders(null, requestHeaders);
2117-
2118-
assertEquals(1, result.size());
2119-
assertEquals("request-value", result.get("X-Request-Header"));
2120-
}
2121-
2122-
@Test
2123-
public void testMergeHeaders_BothNull() {
2124-
Map<String, String> result = AgentUtils.mergeHeaders(null, null);
2125-
2126-
assertEquals(0, result.size());
2127-
}
2128-
2129-
@Test
2130-
public void testMergeHeaders_BothEmpty() {
2131-
Map<String, String> result = AgentUtils.mergeHeaders(new HashMap<>(), new HashMap<>());
2132-
2133-
assertEquals(0, result.size());
2134-
}
2135-
2136-
@Test
2137-
public void testMergeHeaders_EmptyStaticNonEmptyRequest() {
2138-
Map<String, String> requestHeaders = new HashMap<>();
2139-
requestHeaders.put("X-Request-Header", "request-value");
2140-
2141-
Map<String, String> result = AgentUtils.mergeHeaders(new HashMap<>(), requestHeaders);
2142-
2143-
assertEquals(1, result.size());
2144-
assertEquals("request-value", result.get("X-Request-Header"));
2145-
}
21462082
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutorTest.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import static org.mockito.Mockito.verify;
1212
import static org.mockito.Mockito.when;
1313

14+
import java.util.HashMap;
1415
import java.util.List;
1516
import java.util.Map;
1617

@@ -67,7 +68,7 @@ public void getMcpToolSpecs_returnsExpectedSpecs() {
6768
try (MockedStatic<McpClient> mocked = mockStatic(McpClient.class)) {
6869
mocked.when(() -> McpClient.sync(any(McpClientTransport.class))).thenReturn(builder);
6970
McpConnectorExecutor exec = new McpConnectorExecutor(mockConnector);
70-
List<MLToolSpec> specs = exec.getMcpToolSpecs();
71+
List<MLToolSpec> specs = exec.getMcpToolSpecs(new HashMap<>());
7172

7273
Assert.assertEquals(1, specs.size());
7374
MLToolSpec spec = specs.get(0);
@@ -90,7 +91,7 @@ public void getMcpToolSpecs_throwsOnInitError() {
9091
mocked.when(() -> McpClient.sync(any(McpClientTransport.class))).thenReturn(builder);
9192
McpConnectorExecutor exec = new McpConnectorExecutor(mockConnector);
9293

93-
assertThrows(RuntimeException.class, () -> exec.getMcpToolSpecs());
94+
assertThrows(RuntimeException.class, () -> exec.getMcpToolSpecs(new HashMap<>()));
9495
}
9596
}
9697

@@ -101,7 +102,7 @@ public void getMcpToolSpecs_throwsOnListToolsError() {
101102
try (MockedStatic<McpClient> mocked = mockStatic(McpClient.class)) {
102103
mocked.when(() -> McpClient.sync(any(McpClientTransport.class))).thenReturn(builder);
103104
McpConnectorExecutor exec = new McpConnectorExecutor(mockConnector);
104-
assertThrows(RuntimeException.class, () -> exec.getMcpToolSpecs());
105+
assertThrows(RuntimeException.class, () -> exec.getMcpToolSpecs(new HashMap<>()));
105106
}
106107
}
107108

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/McpStreamableHttpConnectorExecutorTest.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import static org.mockito.Mockito.verify;
1212
import static org.mockito.Mockito.when;
1313

14+
import java.util.HashMap;
1415
import java.util.List;
1516
import java.util.Map;
1617

@@ -67,7 +68,7 @@ public void getMcpToolSpecs_returnsExpectedSpecs() {
6768
try (MockedStatic<McpClient> mocked = mockStatic(McpClient.class)) {
6869
mocked.when(() -> McpClient.sync(any(McpClientTransport.class))).thenReturn(builder);
6970
McpStreamableHttpConnectorExecutor exec = new McpStreamableHttpConnectorExecutor(mockConnector);
70-
List<MLToolSpec> specs = exec.getMcpToolSpecs();
71+
List<MLToolSpec> specs = exec.getMcpToolSpecs(new HashMap<>());
7172

7273
Assert.assertEquals(1, specs.size());
7374
MLToolSpec spec = specs.get(0);
@@ -90,7 +91,7 @@ public void getMcpToolSpecs_throwsOnInitError() {
9091
mocked.when(() -> McpClient.sync(any(McpClientTransport.class))).thenReturn(builder);
9192
McpStreamableHttpConnectorExecutor exec = new McpStreamableHttpConnectorExecutor(mockConnector);
9293

93-
assertThrows(RuntimeException.class, () -> exec.getMcpToolSpecs());
94+
assertThrows(RuntimeException.class, () -> exec.getMcpToolSpecs(new HashMap<>()));
9495
}
9596
}
9697

@@ -103,7 +104,7 @@ public void getMcpToolSpecs_throwsOnListToolsError() {
103104
mocked.when(() -> McpClient.sync(any(McpClientTransport.class))).thenReturn(builder);
104105
McpStreamableHttpConnectorExecutor exec = new McpStreamableHttpConnectorExecutor(mockConnector);
105106

106-
assertThrows(RuntimeException.class, () -> exec.getMcpToolSpecs());
107+
assertThrows(RuntimeException.class, () -> exec.getMcpToolSpecs(new HashMap<>()));
107108
}
108109
}
109110

0 commit comments

Comments
 (0)