
Commit ffab1bf

sonianuj287 authored and ylwu-amzn committed
fixed comments suggested changes
Signed-off-by: Anuj Soni <[email protected]>
1 parent 1f48d34 commit ffab1bf

File tree

3 files changed: +144 -47 lines changed


plugin/src/test/java/org/opensearch/ml/rest/RestBedRockInferenceIT.java

Lines changed: 107 additions & 8 deletions
@@ -5,19 +5,40 @@
 
 package org.opensearch.ml.rest;
 
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doAnswer;
+import static org.mockito.Mockito.mock;
+
 import java.io.IOException;
+import java.lang.reflect.Field;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
 
 import org.junit.Before;
+import org.opensearch.core.action.ActionListener;
 import org.opensearch.ml.common.FunctionName;
 import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
 import org.opensearch.ml.common.input.MLInput;
+import org.opensearch.ml.common.output.MLOutput;
+import org.opensearch.ml.common.output.model.MLResultDataType;
+import org.opensearch.ml.common.output.model.ModelTensor;
+import org.opensearch.ml.common.output.model.ModelTensorOutput;
+import org.opensearch.ml.common.output.model.ModelTensors;
 import org.opensearch.ml.common.utils.StringUtils;
+import org.opensearch.searchpipelines.questionanswering.generative.client.MachineLearningInternalClient;
+import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionInput;
+import org.opensearch.searchpipelines.questionanswering.generative.llm.ChatCompletionOutput;
+import org.opensearch.searchpipelines.questionanswering.generative.llm.DefaultLlmImpl;
+import org.opensearch.searchpipelines.questionanswering.generative.llm.Llm;
 
 import lombok.SneakyThrows;
 import lombok.extern.log4j.Log4j2;
@@ -82,17 +103,95 @@ public void test_bedrock_embedding_model() throws Exception {
         }
     }
 
-    public void testChatCompletionBedrockErrorResponseFormats() throws Exception {
-        // Simulate Bedrock inference endpoint behavior
-        // You can mock or create sample response maps for two formats
+    public void testChatCompletionBedrockContentFormat() throws Exception {
+        Map<String, Object> response = Map.of("content", List.of(Map.of("text", "Claude V3 response text")));
+
+        Map<String, Object> result = invokeBedrockInference(response);
+
+        assertTrue(result.containsKey("answers"));
+        assertEquals("Claude V3 response text", ((List<?>) result.get("answers")).get(0));
+    }
+
+    private static void injectMlClient(DefaultLlmImpl connector, Object mlClient) {
+        try {
+            Field field = null;
+            // Try common field names. Adjust if the actual field is named differently.
+            try {
+                field = DefaultLlmImpl.class.getDeclaredField("mlClient");
+            } catch (NoSuchFieldException e) {
+                // fallback if different field name
+                field = DefaultLlmImpl.class.getDeclaredField("client");
+            }
+            field.setAccessible(true);
+            field.set(connector, mlClient);
+        } catch (ReflectiveOperationException e) {
+            throw new RuntimeException("Failed to inject mlClient into DefaultLlmImpl", e);
+        }
+    }
 
-        Map<String, Object> errorFormat1 = Map.of("error", Map.of("message", "Unsupported Claude response format"));
+    private Map<String, Object> invokeBedrockInference(Map<String, Object> mockResponse) throws Exception {
+        // Create DefaultLlmImpl and mock ML client
+        DefaultLlmImpl connector = new DefaultLlmImpl("model_id", null); // Use getClient() from MLCommonsRestTestCase
+        MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class);
+        injectMlClient(connector, mlClient);
 
-        Map<String, Object> errorFormat2 = Map.of("error", "InvalidRequest");
+        // Wrap mockResponse inside a ModelTensor -> ModelTensors -> ModelTensorOutput -> MLOutput
+        ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, mockResponse);
+        ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor))));
+        // Do NOT depend on ActionFuture return path; instead drive the async listener directly.
 
-        // Use the same validation style but inverted for errors
-        validateErrorOutput("Should detect error format 1 correctly", errorFormat1, "Unsupported Claude response format");
-        validateErrorOutput("Should detect error format 2 correctly", errorFormat2, "InvalidRequest");
+        // Make asynchronous predict(...) call invoke the ActionListener with our mlOutput
+        doAnswer(invocation -> {
+            @SuppressWarnings("unchecked")
+            ActionListener<MLOutput> listener = (ActionListener<MLOutput>) invocation.getArguments()[2];
+            // Simulate successful ML response
+            listener.onResponse(mlOutput);
+            return null;
+        }).when(mlClient).predict(any(), any(), any());
+
+        // Prepare input (use BEDROCK provider so bedrock branch is taken)
+        ChatCompletionInput input = new ChatCompletionInput(
+            "bedrock/model",
+            "question",
+            Collections.emptyList(),
+            Collections.emptyList(),
+            0,
+            "prompt",
+            "instructions",
+            Llm.ModelProvider.BEDROCK,
+            null,
+            null
+        );
+
+        // Synchronously wait for callback result
+        CountDownLatch latch = new CountDownLatch(1);
+        AtomicReference<Map<String, Object>> resultRef = new AtomicReference<>();
+
+        connector.doChatCompletion(input, new ActionListener<>() {
+            @Override
+            public void onResponse(ChatCompletionOutput output) {
+                Map<String, Object> map = new HashMap<>();
+                map.put("answers", output.getAnswers());
+                map.put("errors", output.getErrors());
+                resultRef.set(map);
+                latch.countDown();
+            }
+
+            @Override
+            public void onFailure(Exception e) {
+                Map<String, Object> map = new HashMap<>();
+                map.put("answers", Collections.emptyList());
+                map.put("errors", List.of(e.getMessage()));
+                resultRef.set(map);
+                latch.countDown();
+            }
+        });
+
+        boolean completed = latch.await(5, TimeUnit.SECONDS);
+        if (!completed) {
+            throw new RuntimeException("Timed out waiting for doChatCompletion callback");
+        }
+        return resultRef.get();
     }
 
     private void validateErrorOutput(String msg, Map<String, Object> output, String expectedError) {
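
Since invokeBedrockInference accepts an arbitrary mock response map, a companion check for the legacy completion-style payload could look like the sketch below. This is illustrative only and not part of this commit; the test name is hypothetical.

    public void testChatCompletionBedrockLegacyCompletionFormat() throws Exception {
        // Legacy Bedrock shape: a flat "completion" string instead of a content list
        Map<String, Object> response = Map.of("completion", "Legacy completion text");

        Map<String, Object> result = invokeBedrockInference(response);

        // The BEDROCK branch should surface the completion string as the single answer
        assertTrue(result.containsKey("answers"));
        assertEquals("Legacy completion text", ((List<?>) result.get("answers")).get(0));
    }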

search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImpl.java

Lines changed: 30 additions & 32 deletions
@@ -52,9 +52,9 @@ public class DefaultLlmImpl implements Llm {
     private static final String CONNECTOR_OUTPUT_MESSAGE_ROLE = "role";
     private static final String CONNECTOR_OUTPUT_MESSAGE_CONTENT = "content";
     private static final String CONNECTOR_OUTPUT_ERROR = "error";
-    private static final String CLAUDE_V2_COMPLETION_FIELD = "completion";
-    private static final String CLAUDE_V3_CONTENT_FIELD = "content";
-    private static final String CLAUDE_V3_TEXT_FIELD = "text";
+    private static final String BEDROCK_COMPLETION_FIELD = "completion";
+    private static final String BEDROCK_CONTENT_FIELD = "content";
+    private static final String BEDROCK_TEXT_FIELD = "text";
 
     private final String openSearchModelId;
 
@@ -194,39 +194,37 @@ protected ChatCompletionOutput buildChatCompletionOutput(ModelProvider provider,
                 answers = List.of(message.get(CONNECTOR_OUTPUT_MESSAGE_CONTENT));
             }
         } else if (provider == ModelProvider.BEDROCK) {
-            // Handle both Claude V2 and V3 response formats
-            if (dataAsMap.containsKey(CLAUDE_V2_COMPLETION_FIELD)) {
-                // Old Claude V2 format
-                answerField = CLAUDE_V2_COMPLETION_FIELD;
-                fillAnswersOrErrors(dataAsMap, answers, errors, answerField, errorField, defaultErrorMessageField);
-            } else if (dataAsMap.containsKey(CLAUDE_V3_CONTENT_FIELD)) {
-                // New Claude V3 format
-                Object contentObj = dataAsMap.get(CLAUDE_V3_CONTENT_FIELD);
-                if (contentObj instanceof List) {
-                    List<?> contentList = (List<?>) contentObj;
-                    if (!contentList.isEmpty()) {
-                        Object first = contentList.get(0);
-                        if (first instanceof Map) {
-                            Map<?, ?> firstMap = (Map<?, ?>) first;
-                            Object text = firstMap.get(CLAUDE_V3_TEXT_FIELD);
-                            if (text != null) {
-                                answers.add(text.toString());
-                            } else {
-                                errors.add("Claude V3 response missing '" + CLAUDE_V3_TEXT_FIELD + "' field.");
-                            }
+            // Handle Bedrock model responses (supports both legacy completion and newer content/text formats)
+
+            Object contentObj = dataAsMap.get(BEDROCK_CONTENT_FIELD);
+            if (contentObj == null) {
+                // Legacy completion-style format
+                Object completion = dataAsMap.get(BEDROCK_COMPLETION_FIELD);
+                if (completion != null) {
+                    answers.add(completion.toString());
+                } else {
+                    errors.add("Unsupported Bedrock response format: " + dataAsMap.keySet());
+                    log.error("Unknown Bedrock response format: {}", dataAsMap);
+                }
+            } else {
+                // Fail-fast checks for new content/text format
+                if (!(contentObj instanceof List<?> contentList)) {
+                    errors.add("Unexpected type for '" + BEDROCK_CONTENT_FIELD + "' in Bedrock response.");
+                } else if (contentList.isEmpty()) {
+                    errors.add("Empty content list in Bedrock response.");
+                } else {
+                    Object first = contentList.get(0);
+                    if (!(first instanceof Map<?, ?> firstMap)) {
+                        errors.add("Unexpected content format in Bedrock response.");
+                    } else {
+                        Object text = firstMap.get(BEDROCK_TEXT_FIELD);
+                        if (text == null) {
+                            errors.add("Bedrock content response missing '" + BEDROCK_TEXT_FIELD + "' field.");
                         } else {
-                            errors.add("Unexpected content format in Claude V3 response.");
+                            answers.add(text.toString());
                         }
-                    } else {
-                        errors.add("Empty content list in Claude V3 response.");
                     }
-                } else {
-                    errors.add("Unexpected type for '" + CLAUDE_V3_CONTENT_FIELD + "' in Claude V3 response.");
                 }
-            } else {
-                // Fallback error handling
-                errors.add("Unsupported Claude response format: " + dataAsMap.keySet());
-                log.error("Unknown Bedrock/Claude response format: {}", dataAsMap);
             }
         } else if (provider == ModelProvider.COHERE) {
             answerField = "text";
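
For reference, the payload shapes this branch distinguishes can be summarized as in the minimal sketch below. The holder class name is made up for illustration; the field names come from the constants above.

import java.util.List;
import java.util.Map;

class BedrockResponseShapes {
    // Newer content/text shape: the content branch extracts the "text" value as the answer
    static final Map<String, Object> CONTENT_STYLE =
        Map.of("content", List.of(Map.of("type", "text", "text", "Hello from Bedrock")));

    // Legacy shape: the flat "completion" string becomes the single answer
    static final Map<String, Object> COMPLETION_STYLE =
        Map.of("completion", "Hello from Bedrock");

    // Neither field present: the branch records an "Unsupported Bedrock response format" error
    static final Map<String, Object> UNSUPPORTED =
        Map.of("unexpected_field", "value");
}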

search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java

Lines changed: 7 additions & 7 deletions
@@ -143,14 +143,14 @@ public void onFailure(Exception e) {
         assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet);
     }
 
-    public void testChatCompletionApiForBedrockClaudeV3() throws Exception {
+    public void testChatCompletionApiForBedrockContentFormat() throws Exception {
         MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class);
         ArgumentCaptor<MLInput> captor = ArgumentCaptor.forClass(MLInput.class);
         DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client);
         connector.setMlClient(mlClient);
 
-        // Claude V3-style response
-        Map<String, Object> textPart = Map.of("type", "text", "text", "Hello from Claude V3");
+        // Bedrock content/text response (newer format)
+        Map<String, Object> textPart = Map.of("type", "text", "text", "Hello from Bedrock");
         Map<String, Object> dataAsMap = Map.of("content", List.of(textPart));
 
         ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap);
@@ -180,13 +180,13 @@ public void testChatCompletionApiForBedrockClaudeV3() throws Exception {
         connector.doChatCompletion(input, new ActionListener<>() {
             @Override
             public void onResponse(ChatCompletionOutput output) {
-                // Verify that we parsed the Claude V3 response correctly
-                assertEquals("Hello from Claude V3", output.getAnswers().get(0));
+                // Verify that we parsed the Bedrock content response correctly
+                assertEquals("Hello from Bedrock", output.getAnswers().get(0));
             }
 
             @Override
             public void onFailure(Exception e) {
-                fail("Claude V3 test failed: " + e.getMessage());
+                fail("Bedrock test failed: " + e.getMessage());
             }
         });
 
@@ -629,7 +629,7 @@ public void testChatCompletionBedrockThrowingError() throws Exception {
         DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client);
         connector.setMlClient(mlClient);
 
-        String errorMessage = "Unsupported Claude response format";
+        String errorMessage = "Unsupported Bedrock response format";
         Map<String, String> messageMap = Map.of("message", errorMessage);
         Map<String, ?> dataAsMap = Map.of("error", messageMap);
         ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap);
