Skip to content

Commit 1f48d34

Browse files
sonianuj287 authored and ylwu-amzn committed
Add complete Bedrock Claude V3 test coverage for DefaultLlmImpl
Signed-off-by: Anuj Soni <[email protected]>
1 parent 91c75f5 commit 1f48d34

File tree

1 file changed

+165
-0
lines changed

1 file changed

+165
-0
lines changed

search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/llm/DefaultLlmImplTests.java

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,171 @@ public void onFailure(Exception e) {
670670
assertTrue(mlInput.getInputDataset() instanceof RemoteInferenceInputDataSet);
671671
}
672672

673+
public void testChatCompletionBedrockV3ValidResponse() throws Exception {
674+
MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class);
675+
ArgumentCaptor<MLInput> captor = ArgumentCaptor.forClass(MLInput.class);
676+
DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client);
677+
connector.setMlClient(mlClient);
678+
679+
// Simulate valid Claude V3 response
680+
Map<String, Object> innerMap = Map.of("text", "Hello from Claude V3");
681+
Map<String, Object> dataAsMap = Map.of("content", List.of(innerMap));
682+
ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap);
683+
ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor))));
684+
ActionFuture<MLOutput> future = mock(ActionFuture.class);
685+
when(future.actionGet(anyLong())).thenReturn(mlOutput);
686+
when(mlClient.predict(any(), any())).thenReturn(future);
687+
688+
ChatCompletionInput input = new ChatCompletionInput(
689+
"model",
690+
"question",
691+
Collections.emptyList(),
692+
Collections.emptyList(),
693+
0,
694+
"prompt",
695+
"instructions",
696+
Llm.ModelProvider.BEDROCK,
697+
null,
698+
null
699+
);
700+
701+
doAnswer(invocation -> {
702+
((ActionListener<MLOutput>) invocation.getArguments()[2]).onResponse(mlOutput);
703+
return null;
704+
}).when(mlClient).predict(any(), any(), any());
705+
706+
connector.doChatCompletion(input, new ActionListener<>() {
707+
@Override
708+
public void onResponse(ChatCompletionOutput output) {
709+
assertFalse(output.isErrorOccurred());
710+
assertEquals("Hello from Claude V3", output.getAnswers().get(0));
711+
}
712+
713+
@Override
714+
public void onFailure(Exception e) {
715+
fail("Should not fail");
716+
}
717+
});
718+
}
719+
720+
public void testChatCompletionBedrockV3MissingTextField() throws Exception {
721+
MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class);
722+
DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client);
723+
connector.setMlClient(mlClient);
724+
725+
Map<String, Object> innerMap = Map.of("wrong_key", "oops");
726+
Map<String, Object> dataAsMap = Map.of("content", List.of(innerMap));
727+
ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap);
728+
ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor))));
729+
730+
doAnswer(invocation -> {
731+
((ActionListener<MLOutput>) invocation.getArguments()[2]).onResponse(mlOutput);
732+
return null;
733+
}).when(mlClient).predict(any(), any(), any());
734+
735+
ChatCompletionInput input = new ChatCompletionInput(
736+
"model",
737+
"question",
738+
Collections.emptyList(),
739+
Collections.emptyList(),
740+
0,
741+
"prompt",
742+
"instructions",
743+
Llm.ModelProvider.BEDROCK,
744+
null,
745+
null
746+
);
747+
748+
connector.doChatCompletion(input, new ActionListener<>() {
749+
@Override
750+
public void onResponse(ChatCompletionOutput output) {
751+
assertTrue(output.isErrorOccurred());
752+
assertTrue(output.getErrors().get(0).contains("missing 'text'"));
753+
}
754+
755+
@Override
756+
public void onFailure(Exception e) {}
757+
});
758+
}
759+
760+
public void testChatCompletionBedrockV3EmptyContentList() throws Exception {
761+
MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class);
762+
DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client);
763+
connector.setMlClient(mlClient);
764+
765+
Map<String, Object> dataAsMap = Map.of("content", List.of());
766+
ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap);
767+
ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor))));
768+
769+
doAnswer(invocation -> {
770+
((ActionListener<MLOutput>) invocation.getArguments()[2]).onResponse(mlOutput);
771+
return null;
772+
}).when(mlClient).predict(any(), any(), any());
773+
774+
ChatCompletionInput input = new ChatCompletionInput(
775+
"model",
776+
"question",
777+
Collections.emptyList(),
778+
Collections.emptyList(),
779+
0,
780+
"prompt",
781+
"instructions",
782+
Llm.ModelProvider.BEDROCK,
783+
null,
784+
null
785+
);
786+
787+
connector.doChatCompletion(input, new ActionListener<>() {
788+
@Override
789+
public void onResponse(ChatCompletionOutput output) {
790+
assertTrue(output.isErrorOccurred());
791+
assertTrue(output.getErrors().get(0).contains("Empty content list"));
792+
}
793+
794+
@Override
795+
public void onFailure(Exception e) {}
796+
});
797+
}
798+
799+
public void testChatCompletionBedrockV3UnexpectedType() throws Exception {
800+
MachineLearningInternalClient mlClient = mock(MachineLearningInternalClient.class);
801+
DefaultLlmImpl connector = new DefaultLlmImpl("model_id", client);
802+
connector.setMlClient(mlClient);
803+
804+
Map<String, Object> dataAsMap = Map.of("content", "not a list");
805+
ModelTensor tensor = new ModelTensor("tensor", new Number[0], new long[0], MLResultDataType.STRING, null, null, dataAsMap);
806+
ModelTensorOutput mlOutput = new ModelTensorOutput(List.of(new ModelTensors(List.of(tensor))));
807+
808+
doAnswer(invocation -> {
809+
((ActionListener<MLOutput>) invocation.getArguments()[2]).onResponse(mlOutput);
810+
return null;
811+
}).when(mlClient).predict(any(), any(), any());
812+
813+
ChatCompletionInput input = new ChatCompletionInput(
814+
"model",
815+
"question",
816+
Collections.emptyList(),
817+
Collections.emptyList(),
818+
0,
819+
"prompt",
820+
"instructions",
821+
Llm.ModelProvider.BEDROCK,
822+
null,
823+
null
824+
);
825+
826+
connector.doChatCompletion(input, new ActionListener<>() {
827+
@Override
828+
public void onResponse(ChatCompletionOutput output) {
829+
assertTrue(output.isErrorOccurred());
830+
assertTrue(output.getErrors().get(0).contains("Unexpected type"));
831+
}
832+
833+
@Override
834+
public void onFailure(Exception e) {}
835+
});
836+
}
837+
673838
public void testIllegalArgument1() {
674839
exceptionRule.expect(IllegalArgumentException.class);
675840
exceptionRule

0 commit comments

Comments (0)