@@ -670,6 +670,171 @@ public void onFailure(Exception e) {
670670 assertTrue (mlInput .getInputDataset () instanceof RemoteInferenceInputDataSet );
671671 }
672672
673+ public void testChatCompletionBedrockV3ValidResponse () throws Exception {
674+ MachineLearningInternalClient mlClient = mock (MachineLearningInternalClient .class );
675+ ArgumentCaptor <MLInput > captor = ArgumentCaptor .forClass (MLInput .class );
676+ DefaultLlmImpl connector = new DefaultLlmImpl ("model_id" , client );
677+ connector .setMlClient (mlClient );
678+
679+ // Simulate valid Claude V3 response
680+ Map <String , Object > innerMap = Map .of ("text" , "Hello from Claude V3" );
681+ Map <String , Object > dataAsMap = Map .of ("content" , List .of (innerMap ));
682+ ModelTensor tensor = new ModelTensor ("tensor" , new Number [0 ], new long [0 ], MLResultDataType .STRING , null , null , dataAsMap );
683+ ModelTensorOutput mlOutput = new ModelTensorOutput (List .of (new ModelTensors (List .of (tensor ))));
684+ ActionFuture <MLOutput > future = mock (ActionFuture .class );
685+ when (future .actionGet (anyLong ())).thenReturn (mlOutput );
686+ when (mlClient .predict (any (), any ())).thenReturn (future );
687+
688+ ChatCompletionInput input = new ChatCompletionInput (
689+ "model" ,
690+ "question" ,
691+ Collections .emptyList (),
692+ Collections .emptyList (),
693+ 0 ,
694+ "prompt" ,
695+ "instructions" ,
696+ Llm .ModelProvider .BEDROCK ,
697+ null ,
698+ null
699+ );
700+
701+ doAnswer (invocation -> {
702+ ((ActionListener <MLOutput >) invocation .getArguments ()[2 ]).onResponse (mlOutput );
703+ return null ;
704+ }).when (mlClient ).predict (any (), any (), any ());
705+
706+ connector .doChatCompletion (input , new ActionListener <>() {
707+ @ Override
708+ public void onResponse (ChatCompletionOutput output ) {
709+ assertFalse (output .isErrorOccurred ());
710+ assertEquals ("Hello from Claude V3" , output .getAnswers ().get (0 ));
711+ }
712+
713+ @ Override
714+ public void onFailure (Exception e ) {
715+ fail ("Should not fail" );
716+ }
717+ });
718+ }
719+
720+ public void testChatCompletionBedrockV3MissingTextField () throws Exception {
721+ MachineLearningInternalClient mlClient = mock (MachineLearningInternalClient .class );
722+ DefaultLlmImpl connector = new DefaultLlmImpl ("model_id" , client );
723+ connector .setMlClient (mlClient );
724+
725+ Map <String , Object > innerMap = Map .of ("wrong_key" , "oops" );
726+ Map <String , Object > dataAsMap = Map .of ("content" , List .of (innerMap ));
727+ ModelTensor tensor = new ModelTensor ("tensor" , new Number [0 ], new long [0 ], MLResultDataType .STRING , null , null , dataAsMap );
728+ ModelTensorOutput mlOutput = new ModelTensorOutput (List .of (new ModelTensors (List .of (tensor ))));
729+
730+ doAnswer (invocation -> {
731+ ((ActionListener <MLOutput >) invocation .getArguments ()[2 ]).onResponse (mlOutput );
732+ return null ;
733+ }).when (mlClient ).predict (any (), any (), any ());
734+
735+ ChatCompletionInput input = new ChatCompletionInput (
736+ "model" ,
737+ "question" ,
738+ Collections .emptyList (),
739+ Collections .emptyList (),
740+ 0 ,
741+ "prompt" ,
742+ "instructions" ,
743+ Llm .ModelProvider .BEDROCK ,
744+ null ,
745+ null
746+ );
747+
748+ connector .doChatCompletion (input , new ActionListener <>() {
749+ @ Override
750+ public void onResponse (ChatCompletionOutput output ) {
751+ assertTrue (output .isErrorOccurred ());
752+ assertTrue (output .getErrors ().get (0 ).contains ("missing 'text'" ));
753+ }
754+
755+ @ Override
756+ public void onFailure (Exception e ) {}
757+ });
758+ }
759+
760+ public void testChatCompletionBedrockV3EmptyContentList () throws Exception {
761+ MachineLearningInternalClient mlClient = mock (MachineLearningInternalClient .class );
762+ DefaultLlmImpl connector = new DefaultLlmImpl ("model_id" , client );
763+ connector .setMlClient (mlClient );
764+
765+ Map <String , Object > dataAsMap = Map .of ("content" , List .of ());
766+ ModelTensor tensor = new ModelTensor ("tensor" , new Number [0 ], new long [0 ], MLResultDataType .STRING , null , null , dataAsMap );
767+ ModelTensorOutput mlOutput = new ModelTensorOutput (List .of (new ModelTensors (List .of (tensor ))));
768+
769+ doAnswer (invocation -> {
770+ ((ActionListener <MLOutput >) invocation .getArguments ()[2 ]).onResponse (mlOutput );
771+ return null ;
772+ }).when (mlClient ).predict (any (), any (), any ());
773+
774+ ChatCompletionInput input = new ChatCompletionInput (
775+ "model" ,
776+ "question" ,
777+ Collections .emptyList (),
778+ Collections .emptyList (),
779+ 0 ,
780+ "prompt" ,
781+ "instructions" ,
782+ Llm .ModelProvider .BEDROCK ,
783+ null ,
784+ null
785+ );
786+
787+ connector .doChatCompletion (input , new ActionListener <>() {
788+ @ Override
789+ public void onResponse (ChatCompletionOutput output ) {
790+ assertTrue (output .isErrorOccurred ());
791+ assertTrue (output .getErrors ().get (0 ).contains ("Empty content list" ));
792+ }
793+
794+ @ Override
795+ public void onFailure (Exception e ) {}
796+ });
797+ }
798+
799+ public void testChatCompletionBedrockV3UnexpectedType () throws Exception {
800+ MachineLearningInternalClient mlClient = mock (MachineLearningInternalClient .class );
801+ DefaultLlmImpl connector = new DefaultLlmImpl ("model_id" , client );
802+ connector .setMlClient (mlClient );
803+
804+ Map <String , Object > dataAsMap = Map .of ("content" , "not a list" );
805+ ModelTensor tensor = new ModelTensor ("tensor" , new Number [0 ], new long [0 ], MLResultDataType .STRING , null , null , dataAsMap );
806+ ModelTensorOutput mlOutput = new ModelTensorOutput (List .of (new ModelTensors (List .of (tensor ))));
807+
808+ doAnswer (invocation -> {
809+ ((ActionListener <MLOutput >) invocation .getArguments ()[2 ]).onResponse (mlOutput );
810+ return null ;
811+ }).when (mlClient ).predict (any (), any (), any ());
812+
813+ ChatCompletionInput input = new ChatCompletionInput (
814+ "model" ,
815+ "question" ,
816+ Collections .emptyList (),
817+ Collections .emptyList (),
818+ 0 ,
819+ "prompt" ,
820+ "instructions" ,
821+ Llm .ModelProvider .BEDROCK ,
822+ null ,
823+ null
824+ );
825+
826+ connector .doChatCompletion (input , new ActionListener <>() {
827+ @ Override
828+ public void onResponse (ChatCompletionOutput output ) {
829+ assertTrue (output .isErrorOccurred ());
830+ assertTrue (output .getErrors ().get (0 ).contains ("Unexpected type" ));
831+ }
832+
833+ @ Override
834+ public void onFailure (Exception e ) {}
835+ });
836+ }
837+
673838 public void testIllegalArgument1 () {
674839 exceptionRule .expect (IllegalArgumentException .class );
675840 exceptionRule
0 commit comments