@@ -111,6 +111,52 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase {
111111 private static final String BEDROCK_ANTHROPIC_CLAUDE_3_5_SONNET = "anthropic.claude-3-5-sonnet-20240620-v1:0" ;
112112 private static final String BEDROCK_ANTHROPIC_CLAUDE_3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0" ;
113113
114+ private static final String BEDROCK_CONNECTOR_BLUEPRINT_INVOKE = "{\n "
115+ + " \" name\" : \" Bedrock Connector: claude 3.5\" ,\n "
116+ + " \" description\" : \" The connector to bedrock claude 3.5 model\" ,\n "
117+ + " \" version\" : 1,\n "
118+ + " \" protocol\" : \" aws_sigv4\" ,\n "
119+ + " \" parameters\" : {\n "
120+ + " \" region\" : \" "
121+ + GITHUB_CI_AWS_REGION
122+ + "\" ,\n "
123+ + " \" service_name\" : \" bedrock\" ,\n "
124+ + " \" model\" : \" "
125+ + "anthropic.claude-3-5-sonnet-20240620-v1:0"
126+ + "\" ,\n "
127+ + " \" system_prompt\" : \" You are a helpful assistant.\" ,\n "
128+ + "\" response_filter\" : \" $.content[0].text\" "
129+ + " },\n "
130+ + " \" credential\" : {\n "
131+ + " \" access_key\" : \" "
132+ + AWS_ACCESS_KEY_ID
133+ + "\" ,\n "
134+ + " \" secret_key\" : \" "
135+ + AWS_SECRET_ACCESS_KEY
136+ + "\" ,\n "
137+ + " \" session_token\" : \" "
138+ + AWS_SESSION_TOKEN
139+ + "\" \n "
140+ + " },\n "
141+ + " \" actions\" : [\n "
142+ + " {\n "
143+ + " \" action_type\" : \" "
144+ + "predict"
145+ + "\" ,\n "
146+ + " \" method\" : \" POST\" ,\n "
147+ + " \" headers\" : {\n "
148+ + " \" content-type\" : \" application/json\" \n "
149+ + " },\n "
150+ + " \" url\" : \" https://bedrock-runtime."
151+ + GITHUB_CI_AWS_REGION
152+ + ".amazonaws.com/model/"
153+ + "anthropic.claude-3-5-sonnet-20240620-v1:0"
154+ + "/invoke\" ,\n "
155+ + " \" request_body\" : \" {\\ \" messages\\ \" :[{\\ \" role\\ \" : \\ \" user\\ \" , \\ \" content\\ \" :[ {\\ \" type\\ \" : \\ \" text\\ \" , \\ \" text\\ \" :\\ \" ${parameters.inputs}\\ \" }]}], \\ \" max_tokens\\ \" :300, \\ \" temperature\\ \" :0.5, \\ \" anthropic_version\\ \" :\\ \" bedrock-2023-05-31\\ \" }\" \n "
156+ + " }\n "
157+ + " ]\n "
158+ + "}" ;
159+
114160 private static final String BEDROCK_CONNECTOR_BLUEPRINT1 = "{\n "
115161 + " \" name\" : \" Bedrock Connector: claude2\" ,\n "
116162 + " \" description\" : \" The connector to bedrock claude2 model\" ,\n "
@@ -181,7 +227,7 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase {
181227 + " ]\n "
182228 + "}" ;
183229
184- private static final String BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2 = "{\n "
230+ static final String BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2 = "{\n "
185231 + " \" name\" : \" Bedrock Connector: claude 3.5\" ,\n "
186232 + " \" description\" : \" The connector to bedrock claude 3.5 model\" ,\n "
187233 + " \" version\" : 1,\n "
@@ -268,8 +314,8 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase {
268314 + "}" ;
269315
270316 private static final String BEDROCK_CONNECTOR_BLUEPRINT = AWS_SESSION_TOKEN == null
271- ? BEDROCK_CONNECTOR_BLUEPRINT2
272- : BEDROCK_CONNECTOR_BLUEPRINT1 ;
317+ ? BEDROCK_CONNECTOR_BLUEPRINT_INVOKE
318+ : BEDROCK_CONNECTOR_BLUEPRINT_INVOKE ;
273319
274320 private static final String BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT = AWS_SESSION_TOKEN == null
275321 ? BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2
@@ -425,6 +471,26 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase {
425471 + " }\n "
426472 + "}" ;
427473
474+ private static final String BM25_SEARCH_REQUEST_WITH_CONVO_WITH_LLM_RESPONSE_TEMPLATE = "{\n "
475+ + " \" _source\" : [\" %s\" ],\n "
476+ + " \" query\" : {\n "
477+ + " \" match\" : {\" %s\" : \" %s\" }\n "
478+ + " },\n "
479+ + " \" ext\" : {\n "
480+ + " \" generative_qa_parameters\" : {\n "
481+ + " \" llm_model\" : \" %s\" ,\n "
482+ + " \" llm_question\" : \" %s\" ,\n "
483+ + " \" memory_id\" : \" %s\" ,\n "
484+ + " \" system_prompt\" : \" %s\" ,\n "
485+ + " \" user_instructions\" : \" %s\" ,\n "
486+ + " \" context_size\" : %d,\n "
487+ + " \" message_size\" : %d,\n "
488+ + " \" timeout\" : %d,\n "
489+ + " \" llm_response_field\" : \" %s\" \n "
490+ + " }\n "
491+ + " }\n "
492+ + "}" ;
493+
428494 private static final String BM25_SEARCH_REQUEST_WITH_CONVO_AND_IMAGE_TEMPLATE = "{\n "
429495 + " \" _source\" : [\" %s\" ],\n "
430496 + " \" query\" : {\n "
@@ -705,6 +771,7 @@ public void testBM25WithBedrock() throws Exception {
705771 requestParameters .contextSize = 5 ;
706772 requestParameters .interactionSize = 5 ;
707773 requestParameters .timeout = 60 ;
774+ requestParameters .llmResponseField = "response" ;
708775 Response response2 = performSearch (INDEX_NAME , "pipeline_test" , 5 , requestParameters );
709776 assertEquals (200 , response2 .getStatusLine ().getStatusCode ());
710777
@@ -1068,6 +1135,7 @@ public void testBM25WithBedrockWithConversation() throws Exception {
10681135 requestParameters .interactionSize = 5 ;
10691136 requestParameters .timeout = 60 ;
10701137 requestParameters .conversationId = conversationId ;
1138+ requestParameters .llmResponseField = "response" ;
10711139 Response response2 = performSearch (INDEX_NAME , "pipeline_test" , 5 , requestParameters );
10721140 assertEquals (200 , response2 .getStatusLine ().getStatusCode ());
10731141
@@ -1240,7 +1308,7 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
12401308 throws Exception {
12411309
12421310 // TODO build these templates dynamically
1243- String httpEntity = requestParameters .llmResponseField != null
1311+ String httpEntity = requestParameters .llmResponseField != null && requestParameters . conversationId == null
12441312 ? String
12451313 .format (
12461314 Locale .ROOT ,
@@ -1351,10 +1419,27 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
13511419 requestParameters .interactionSize ,
13521420 requestParameters .timeout
13531421 )
1422+ : (requestParameters .llmResponseField == null )
1423+ ? String
1424+ .format (
1425+ Locale .ROOT ,
1426+ BM25_SEARCH_REQUEST_WITH_CONVO_TEMPLATE ,
1427+ requestParameters .source ,
1428+ requestParameters .source ,
1429+ requestParameters .match ,
1430+ requestParameters .llmModel ,
1431+ requestParameters .llmQuestion ,
1432+ requestParameters .conversationId ,
1433+ requestParameters .systemPrompt ,
1434+ requestParameters .userInstructions ,
1435+ requestParameters .contextSize ,
1436+ requestParameters .interactionSize ,
1437+ requestParameters .timeout
1438+ )
13541439 : String
13551440 .format (
13561441 Locale .ROOT ,
1357- BM25_SEARCH_REQUEST_WITH_CONVO_TEMPLATE ,
1442+ BM25_SEARCH_REQUEST_WITH_CONVO_WITH_LLM_RESPONSE_TEMPLATE ,
13581443 requestParameters .source ,
13591444 requestParameters .source ,
13601445 requestParameters .match ,
@@ -1365,7 +1450,8 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
13651450 requestParameters .userInstructions ,
13661451 requestParameters .contextSize ,
13671452 requestParameters .interactionSize ,
1368- requestParameters .timeout
1453+ requestParameters .timeout ,
1454+ requestParameters .llmResponseField
13691455 );
13701456 return makeRequest (
13711457 client (),
0 commit comments