|
5 | 5 |
|
6 | 6 | package org.opensearch.ml.rest; |
7 | 7 |
|
| 8 | +import static org.mockito.ArgumentMatchers.any; |
| 9 | +import static org.mockito.Mockito.doAnswer; |
| 10 | +import static org.mockito.Mockito.mock; |
| 11 | + |
8 | 12 | import java.io.IOException; |
| 13 | +import java.lang.reflect.Field; |
9 | 14 | import java.nio.file.Files; |
10 | 15 | import java.nio.file.Path; |
11 | 16 | import java.util.ArrayList; |
| 17 | +import java.util.Collections; |
| 18 | +import java.util.HashMap; |
12 | 19 | import java.util.List; |
13 | 20 | import java.util.Locale; |
14 | 21 | import java.util.Map; |
| 22 | +import java.util.concurrent.CountDownLatch; |
| 23 | +import java.util.concurrent.TimeUnit; |
| 24 | +import java.util.concurrent.atomic.AtomicReference; |
15 | 25 |
|
16 | 26 | import org.junit.Before; |
| 27 | +import org.opensearch.core.action.ActionListener; |
17 | 28 | import org.opensearch.ml.common.FunctionName; |
18 | 29 | import org.opensearch.ml.common.dataset.TextDocsInputDataSet; |
19 | 30 | import org.opensearch.ml.common.input.MLInput; |
| 31 | +import org.opensearch.ml.common.output.MLOutput; |
| 32 | +import org.opensearch.ml.common.output.model.MLResultDataType; |
| 33 | +import org.opensearch.ml.common.output.model.ModelTensor; |
| 34 | +import org.opensearch.ml.common.output.model.ModelTensorOutput; |
| 35 | +import org.opensearch.ml.common.output.model.ModelTensors; |
20 | 36 | import org.opensearch.ml.common.utils.StringUtils; |
| 37 | +import org.opensearch.searchpipelines.questionanswering.generative.client.MachineLearningInternalClient; |
| 38 | +import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionInput; |
| 39 | +import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput; |
| 40 | +import org.opensearch.searchpipelines.questionanswering.generative.llm.DefaultLlmImpl; |
| 41 | +import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm; |
21 | 42 |
|
22 | 43 | import lombok.SneakyThrows; |
23 | 44 | import lombok.extern.log4j.Log4j2; |
@@ -82,17 +103,95 @@ public void test_bedrock_embedding_model() throws Exception { |
82 | 103 | } |
83 | 104 | } |
84 | 105 |
|
85 | | - public void testChatCompletionBedrockErrorResponseFormats() throws Exception { |
86 | | - // Simulate Bedrock inference endpoint behavior |
87 | | - // You can mock or create sample response maps for two formats |
| 106 | + public void testChatCompletionBedrockContentFormat() throws Exception { |
| 107 | + Map<String, Object> response = Map.of("content", List.of(Map.of("text", "Claude V3 response text"))); |
| 108 | + |
| 109 | + Map<String, Object> result = invokeBedrockInference(response); |
| 110 | + |
| 111 | + assertTrue(result.containsKey("answers")); |
| 112 | + assertEquals("Claude V3 response text", ((List<?>) result.get("answers")).get(0)); |
| 113 | + } |
| 114 | + |
| 115 | + private static void injectMlClient(DefaultLlmImpl connector, Object mlClient) { |
| 116 | + try { |
| 117 | + Field field = null; |
| 118 | + // Try common field names. Adjust if the actual field is named differently. |
| 119 | + try { |
| 120 | + field = DefaultLlmImpl.class.getDeclaredField("mlClient"); |
| 121 | + } catch (NoSuchFieldException e) { |
| 122 | + // fallback if different field name |
| 123 | + field = DefaultLlmImpl.class.getDeclaredField("client"); |
| 124 | + } |
| 125 | + field.setAccessible(true); |
| 126 | + field.set(connector, mlClient); |
| 127 | + } catch (ReflectiveOperationException e) { |
| 128 | + throw new RuntimeException("Failed to inject mlClient into DefaultLlmImpl", e); |
| 129 | + } |
| 130 | + } |
88 | 131 |
|
89 | | - Map<String, Object> errorFormat1 = Map.of("error", Map.of("message", "Unsupported Claude response format")); |
| 132 | + private Map<String, Object> invokeBedrockInference(Map<String, Object> mockResponse) throws Exception { |
| 133 | + // Create DefaultLlmImpl and mock ML client |
| 134 | + DefaultLlmImpl connector = new DefaultLlmImpl("model_id", null); // Use getClient() from MLCommonsRestTestCase |
| 135 | + MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class); |
| 136 | + injectMlClient(connector, mlClient); |
90 | 137 |
|
91 | | - Map<String, Object> errorFormat2 = Map.of("error", "InvalidRequest"); |
| 138 | + // Wrap mockResponse inside a ModelTensor -> ModelTensors -> ModelTensorOutput -> MLOutput |
| 139 | + ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, mockResponse); |
| 140 | + ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor)))); |
| 141 | + // Do NOT depend on ActionFuture return path; instead drive the async listener directly. |
92 | 142 |
|
93 | | - // Use the same validation style but inverted for errors |
94 | | - validateErrorOutput("Should detect error format 1 correctly", errorFormat1, "Unsupported Claude response format"); |
95 | | - validateErrorOutput("Should detect error format 2 correctly", errorFormat2, "InvalidRequest"); |
| 143 | + // Make asynchronous predict(...) call invoke the ActionListener with our mlOutput |
| 144 | + doAnswer(invocation -> { |
| 145 | + @SuppressWarnings("unchecked") |
| 146 | + ActionListener<MLOutput> listener = (ActionListener<MLOutput>) invocation.getArguments()[2]; |
| 147 | + // Simulate successful ML response |
| 148 | + listener.onResponse(mlOutput); |
| 149 | + return null; |
| 150 | + }).when(mlClient).predict(any(), any(), any()); |
| 151 | + |
| 152 | + // Prepare input (use BEDROCK provider so bedrock branch is taken) |
| 153 | + ChatCompletionInput input = new ChatCompletionInput( |
| 154 | + "bedrock/model", |
| 155 | + "question", |
| 156 | + Collections.emptyList(), |
| 157 | + Collections.emptyList(), |
| 158 | + 0, |
| 159 | + "prompt", |
| 160 | + "instructions", |
| 161 | + Llm.ModelProvider.BEDROCK, |
| 162 | + null, |
| 163 | + null |
| 164 | + ); |
| 165 | + |
| 166 | + // Synchronously wait for callback result |
| 167 | + CountDownLatch latch = new CountDownLatch(1); |
| 168 | + AtomicReference<Map<String, Object>> resultRef = new AtomicReference<>(); |
| 169 | + |
| 170 | + connector.doChatCompletion(input, new ActionListener<>() { |
| 171 | + @Override |
| 172 | + public void onResponse(ChatCompletionOutput output) { |
| 173 | + Map<String, Object> map = new HashMap<>(); |
| 174 | + map.put("answers", output.getAnswers()); |
| 175 | + map.put("errors", output.getErrors()); |
| 176 | + resultRef.set(map); |
| 177 | + latch.countDown(); |
| 178 | + } |
| 179 | + |
| 180 | + @Override |
| 181 | + public void onFailure(Exception e) { |
| 182 | + Map<String, Object> map = new HashMap<>(); |
| 183 | + map.put("answers", Collections.emptyList()); |
| 184 | + map.put("errors", List.of(e.getMessage())); |
| 185 | + resultRef.set(map); |
| 186 | + latch.countDown(); |
| 187 | + } |
| 188 | + }); |
| 189 | + |
| 190 | + boolean completed = latch.await(5, TimeUnit.SECONDS); |
| 191 | + if (!completed) { |
| 192 | + throw new RuntimeException("Timed out waiting for doChatCompletion callback"); |
| 193 | + } |
| 194 | + return resultRef.get(); |
96 | 195 | } |
97 | 196 |
|
98 | 197 | private void validateErrorOutput(String msg, Map<String, Object> output, String expectedError) { |
|
0 commit comments