9
9
10
10
import com .carrotsearch .randomizedtesting .annotations .Name ;
11
11
12
+ import org .elasticsearch .client .ResponseException ;
12
13
import org .elasticsearch .common .Strings ;
13
14
import org .elasticsearch .inference .TaskType ;
15
+ import org .elasticsearch .test .http .MockRequest ;
14
16
import org .elasticsearch .test .http .MockResponse ;
15
17
import org .elasticsearch .test .http .MockWebServer ;
16
18
import org .elasticsearch .xpack .inference .services .cohere .embeddings .CohereEmbeddingType ;
24
26
25
27
import static org .hamcrest .Matchers .anEmptyMap ;
26
28
import static org .hamcrest .Matchers .anyOf ;
29
+ import static org .hamcrest .Matchers .containsString ;
27
30
import static org .hamcrest .Matchers .empty ;
28
31
import static org .hamcrest .Matchers .hasEntry ;
29
32
import static org .hamcrest .Matchers .hasSize ;
@@ -35,11 +38,16 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
35
38
36
39
private static final String COHERE_EMBEDDINGS_ADDED = "8.13.0" ;
37
40
private static final String COHERE_RERANK_ADDED = "8.14.0" ;
38
- private static final String BYTE_ALIAS_FOR_INT8_ADDED = "8.14.0 " ;
41
+ private static final String COHERE_V2_API_ADDED_TEST_FEATURE = "inference.cohere.v2 " ;
39
42
40
43
private static MockWebServer cohereEmbeddingsServer ;
41
44
private static MockWebServer cohereRerankServer ;
42
45
46
+ private enum ApiVersion {
47
+ V1 ,
48
+ V2
49
+ }
50
+
43
51
/**
 * Builds one parameterized instance of this upgrade test.
 *
 * @param upgradedNodes value injected under the {@code upgradedNodes} parameter name and handed unchanged to
 *                      {@link InferenceUpgradeTestCase}; presumably the number of nodes already upgraded (confirm
 *                      against the base class)
 */
public CohereServiceUpgradeIT(@Name("upgradedNodes") final int upgradedNodes) {
    super(upgradedNodes);
}
@@ -64,14 +72,15 @@ public void testCohereEmbeddings() throws IOException {
64
72
var embeddingsSupported = getOldClusterTestVersion ().onOrAfter (COHERE_EMBEDDINGS_ADDED );
65
73
// `gte_v` indicates that the cluster version is Greater Than or Equal to MODELS_RENAMED_TO_ENDPOINTS
66
74
String oldClusterEndpointIdentifier = oldClusterHasFeature ("gte_v" + MODELS_RENAMED_TO_ENDPOINTS ) ? "endpoints" : "models" ;
67
- assumeTrue ( "Cohere embedding service added in " + COHERE_EMBEDDINGS_ADDED , embeddingsSupported ) ;
75
+ ApiVersion oldClusterApiVersion = oldClusterHasFeature ( COHERE_V2_API_ADDED_TEST_FEATURE ) ? ApiVersion . V2 : ApiVersion . V1 ;
68
76
69
77
final String oldClusterIdInt8 = "old-cluster-embeddings-int8" ;
70
78
final String oldClusterIdFloat = "old-cluster-embeddings-float" ;
71
79
72
80
var testTaskType = TaskType .TEXT_EMBEDDING ;
73
81
74
82
if (isOldCluster ()) {
83
+
75
84
// queue a response as PUT will call the service
76
85
cohereEmbeddingsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (embeddingResponseByte ()));
77
86
put (oldClusterIdInt8 , embeddingConfigInt8 (getUrl (cohereEmbeddingsServer )), testTaskType );
@@ -129,13 +138,29 @@ public void testCohereEmbeddings() throws IOException {
129
138
130
139
// Inference on old cluster models
131
140
assertEmbeddingInference (oldClusterIdInt8 , CohereEmbeddingType .BYTE );
141
+ assertVersionInPath (
142
+ cohereEmbeddingsServer .requests ().get (cohereEmbeddingsServer .requests ().size () - 1 ),
143
+ "embed" ,
144
+ oldClusterApiVersion
145
+ );
132
146
assertEmbeddingInference (oldClusterIdFloat , CohereEmbeddingType .FLOAT );
147
+ assertVersionInPath (
148
+ cohereEmbeddingsServer .requests ().get (cohereEmbeddingsServer .requests ().size () - 1 ),
149
+ "embed" ,
150
+ oldClusterApiVersion
151
+ );
133
152
134
153
{
135
154
final String upgradedClusterIdByte = "upgraded-cluster-embeddings-byte" ;
136
155
156
+ // new endpoints use the V2 API
137
157
cohereEmbeddingsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (embeddingResponseByte ()));
138
158
put (upgradedClusterIdByte , embeddingConfigByte (getUrl (cohereEmbeddingsServer )), testTaskType );
159
+ assertVersionInPath (
160
+ cohereEmbeddingsServer .requests ().get (cohereEmbeddingsServer .requests ().size () - 1 ),
161
+ "embed" ,
162
+ ApiVersion .V2
163
+ );
139
164
140
165
configs = (List <Map <String , Object >>) get (testTaskType , upgradedClusterIdByte ).get ("endpoints" );
141
166
serviceSettings = (Map <String , Object >) configs .get (0 ).get ("service_settings" );
@@ -147,34 +172,86 @@ public void testCohereEmbeddings() throws IOException {
147
172
{
148
173
final String upgradedClusterIdInt8 = "upgraded-cluster-embeddings-int8" ;
149
174
175
+ // new endpoints use the V2 API
150
176
cohereEmbeddingsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (embeddingResponseByte ()));
151
177
put (upgradedClusterIdInt8 , embeddingConfigInt8 (getUrl (cohereEmbeddingsServer )), testTaskType );
178
+ assertVersionInPath (
179
+ cohereEmbeddingsServer .requests ().get (cohereEmbeddingsServer .requests ().size () - 1 ),
180
+ "embed" ,
181
+ ApiVersion .V2
182
+ );
152
183
153
184
configs = (List <Map <String , Object >>) get (testTaskType , upgradedClusterIdInt8 ).get ("endpoints" );
154
185
serviceSettings = (Map <String , Object >) configs .get (0 ).get ("service_settings" );
155
186
assertThat (serviceSettings , hasEntry ("embedding_type" , "byte" )); // int8 rewritten to byte
156
187
157
188
assertEmbeddingInference (upgradedClusterIdInt8 , CohereEmbeddingType .INT8 );
189
+ assertVersionInPath (
190
+ cohereEmbeddingsServer .requests ().get (cohereEmbeddingsServer .requests ().size () - 1 ),
191
+ "embed" ,
192
+ ApiVersion .V2
193
+ );
158
194
delete (upgradedClusterIdInt8 );
159
195
}
160
196
{
161
197
final String upgradedClusterIdFloat = "upgraded-cluster-embeddings-float" ;
162
198
cohereEmbeddingsServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (embeddingResponseFloat ()));
163
199
put (upgradedClusterIdFloat , embeddingConfigFloat (getUrl (cohereEmbeddingsServer )), testTaskType );
200
+ assertVersionInPath (
201
+ cohereEmbeddingsServer .requests ().get (cohereEmbeddingsServer .requests ().size () - 1 ),
202
+ "embed" ,
203
+ ApiVersion .V2
204
+ );
164
205
165
206
configs = (List <Map <String , Object >>) get (testTaskType , upgradedClusterIdFloat ).get ("endpoints" );
166
207
serviceSettings = (Map <String , Object >) configs .get (0 ).get ("service_settings" );
167
208
assertThat (serviceSettings , hasEntry ("embedding_type" , "float" ));
168
209
169
210
assertEmbeddingInference (upgradedClusterIdFloat , CohereEmbeddingType .FLOAT );
211
+ assertVersionInPath (
212
+ cohereEmbeddingsServer .requests ().get (cohereEmbeddingsServer .requests ().size () - 1 ),
213
+ "embed" ,
214
+ ApiVersion .V2
215
+ );
170
216
delete (upgradedClusterIdFloat );
171
217
}
218
+ {
219
+ // new endpoints use the V2 API, which requires the model to be set
220
+ final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id" ;
221
+ var jsonBody = Strings .format ("""
222
+ {
223
+ "service": "cohere",
224
+ "service_settings": {
225
+ "url": "%s",
226
+ "api_key": "XXXX",
227
+ "embedding_type": "int8"
228
+ }
229
+ }
230
+ """ , getUrl (cohereEmbeddingsServer ));
231
+
232
+ var e = expectThrows (ResponseException .class , () -> put (upgradedClusterNoModel , jsonBody , testTaskType ));
233
+ assertThat (
234
+ e .getMessage (),
235
+ containsString ("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API." )
236
+ );
237
+ }
172
238
173
239
delete (oldClusterIdFloat );
174
240
delete (oldClusterIdInt8 );
175
241
}
176
242
}
177
243
244
+ private void assertVersionInPath (MockRequest request , String endpoint , ApiVersion apiVersion ) {
245
+ switch (apiVersion ) {
246
+ case V2 :
247
+ assertEquals ("/v2/" + endpoint , request .getUri ().getPath ());
248
+ break ;
249
+ case V1 :
250
+ assertEquals ("/v1/" + endpoint , request .getUri ().getPath ());
251
+ break ;
252
+ }
253
+ }
254
+
178
255
void assertEmbeddingInference (String inferenceId , CohereEmbeddingType type ) throws IOException {
179
256
switch (type ) {
180
257
case INT8 :
@@ -195,6 +272,8 @@ public void testRerank() throws IOException {
195
272
String old_cluster_endpoint_identifier = oldClusterHasFeature ("gte_v" + MODELS_RENAMED_TO_ENDPOINTS ) ? "endpoints" : "models" ;
196
273
assumeTrue ("Cohere rerank service added in " + COHERE_RERANK_ADDED , rerankSupported );
197
274
275
+ ApiVersion oldClusterApiVersion = oldClusterHasFeature (COHERE_V2_API_ADDED_TEST_FEATURE ) ? ApiVersion .V2 : ApiVersion .V1 ;
276
+
198
277
final String oldClusterId = "old-cluster-rerank" ;
199
278
final String upgradedClusterId = "upgraded-cluster-rerank" ;
200
279
@@ -217,7 +296,6 @@ public void testRerank() throws IOException {
217
296
assertThat (taskSettings , hasEntry ("top_n" , 3 ));
218
297
219
298
assertRerank (oldClusterId );
220
-
221
299
} else if (isUpgradedCluster ()) {
222
300
// check old cluster model
223
301
var configs = (List <Map <String , Object >>) get (testTaskType , oldClusterId ).get ("endpoints" );
@@ -228,6 +306,11 @@ public void testRerank() throws IOException {
228
306
assertThat (taskSettings , hasEntry ("top_n" , 3 ));
229
307
230
308
assertRerank (oldClusterId );
309
+ assertVersionInPath (
310
+ cohereRerankServer .requests ().get (cohereRerankServer .requests ().size () - 1 ),
311
+ "rerank" ,
312
+ oldClusterApiVersion
313
+ );
231
314
232
315
// New endpoint
233
316
cohereRerankServer .enqueue (new MockResponse ().setResponseCode (200 ).setBody (rerankResponse ()));
@@ -236,6 +319,27 @@ public void testRerank() throws IOException {
236
319
assertThat (configs , hasSize (1 ));
237
320
238
321
assertRerank (upgradedClusterId );
322
+ assertVersionInPath (cohereRerankServer .requests ().get (cohereRerankServer .requests ().size () - 1 ), "rerank" , ApiVersion .V2 );
323
+
324
{
    // New endpoints use the V2 API, which requires the model_id to be set.
    final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id";
    // Target the rerank mock server, like every other request in this test, so that a PUT which
    // unexpectedly passes validation cannot disturb the embeddings server's recorded requests.
    var jsonBody = Strings.format("""
        {
            "service": "cohere",
            "service_settings": {
                "url": "%s",
                "api_key": "XXXX"
            }
        }
        """, getUrl(cohereRerankServer));

    // Creating the endpoint without a model_id must fail with the V2 validation error.
    var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, testTaskType));
    assertThat(
        e.getMessage(),
        containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.")
    );
}
239
343
240
344
delete (oldClusterId );
241
345
delete (upgradedClusterId );
0 commit comments