Skip to content

Commit e1894f8

Browse files
authored
[ML] Move to the Cohere V2 API for new inference endpoints (#129884) (#129988)
(cherry picked from commit 3a1551e) # Conflicts: # muted-tests.yml # server/src/main/java/org/elasticsearch/TransportVersions.java # x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java
1 parent d7f2113 commit e1894f8

File tree

58 files changed

+2226
-1359
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+2226
-1359
lines changed

docs/changelog/129884.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 129884
2+
summary: Move to the Cohere V2 API for new inference endpoints
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

muted-tests.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,6 @@ tests:
353353
- class: org.elasticsearch.xpack.sql.qa.single_node.JdbcDocCsvSpecIT
354354
method: test {docs.testFilterToday}
355355
issue: https://github.com/elastic/elasticsearch/issues/121474
356-
- class: org.elasticsearch.xpack.application.CohereServiceUpgradeIT
357-
issue: https://github.com/elastic/elasticsearch/issues/121537
358356
- class: org.elasticsearch.xpack.test.rest.XPackRestIT
359357
method: test {p0=transform/*}
360358
issue: https://github.com/elastic/elasticsearch/issues/120816

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,8 @@ static TransportVersion def(int id) {
251251
public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER_8_19 = def(8_841_0_56);
252252
public static final TransportVersion ML_INFERENCE_SAGEMAKER_ELASTIC_8_19 = def(8_841_0_57);
253253
public static final TransportVersion SPARSE_VECTOR_FIELD_PRUNING_OPTIONS_8_19 = def(8_841_0_58);
254+
public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED_8_19 = def(8_841_0_59);
255+
public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION_8_19 = def(8_841_0_60);
254256

255257
/*
256258
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java

Lines changed: 107 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99

1010
import com.carrotsearch.randomizedtesting.annotations.Name;
1111

12+
import org.elasticsearch.client.ResponseException;
1213
import org.elasticsearch.common.Strings;
1314
import org.elasticsearch.inference.TaskType;
15+
import org.elasticsearch.test.http.MockRequest;
1416
import org.elasticsearch.test.http.MockResponse;
1517
import org.elasticsearch.test.http.MockWebServer;
1618
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
@@ -24,6 +26,7 @@
2426

2527
import static org.hamcrest.Matchers.anEmptyMap;
2628
import static org.hamcrest.Matchers.anyOf;
29+
import static org.hamcrest.Matchers.containsString;
2730
import static org.hamcrest.Matchers.empty;
2831
import static org.hamcrest.Matchers.hasEntry;
2932
import static org.hamcrest.Matchers.hasSize;
@@ -35,11 +38,16 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {
3538

3639
private static final String COHERE_EMBEDDINGS_ADDED = "8.13.0";
3740
private static final String COHERE_RERANK_ADDED = "8.14.0";
38-
private static final String BYTE_ALIAS_FOR_INT8_ADDED = "8.14.0";
41+
private static final String COHERE_V2_API_ADDED_TEST_FEATURE = "inference.cohere.v2";
3942

4043
private static MockWebServer cohereEmbeddingsServer;
4144
private static MockWebServer cohereRerankServer;
4245

46+
private enum ApiVersion {
47+
V1,
48+
V2
49+
}
50+
4351
public CohereServiceUpgradeIT(@Name("upgradedNodes") int upgradedNodes) {
4452
super(upgradedNodes);
4553
}
@@ -64,14 +72,15 @@ public void testCohereEmbeddings() throws IOException {
6472
var embeddingsSupported = getOldClusterTestVersion().onOrAfter(COHERE_EMBEDDINGS_ADDED);
6573
// `gte_v` indicates that the cluster version is Greater Than or Equal to MODELS_RENAMED_TO_ENDPOINTS
6674
String oldClusterEndpointIdentifier = oldClusterHasFeature("gte_v" + MODELS_RENAMED_TO_ENDPOINTS) ? "endpoints" : "models";
67-
assumeTrue("Cohere embedding service added in " + COHERE_EMBEDDINGS_ADDED, embeddingsSupported);
75+
ApiVersion oldClusterApiVersion = oldClusterHasFeature(COHERE_V2_API_ADDED_TEST_FEATURE) ? ApiVersion.V2 : ApiVersion.V1;
6876

6977
final String oldClusterIdInt8 = "old-cluster-embeddings-int8";
7078
final String oldClusterIdFloat = "old-cluster-embeddings-float";
7179

7280
var testTaskType = TaskType.TEXT_EMBEDDING;
7381

7482
if (isOldCluster()) {
83+
7584
// queue a response as PUT will call the service
7685
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
7786
put(oldClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), testTaskType);
@@ -129,13 +138,29 @@ public void testCohereEmbeddings() throws IOException {
129138

130139
// Inference on old cluster models
131140
assertEmbeddingInference(oldClusterIdInt8, CohereEmbeddingType.BYTE);
141+
assertVersionInPath(
142+
cohereEmbeddingsServer.requests().get(cohereEmbeddingsServer.requests().size() - 1),
143+
"embed",
144+
oldClusterApiVersion
145+
);
132146
assertEmbeddingInference(oldClusterIdFloat, CohereEmbeddingType.FLOAT);
147+
assertVersionInPath(
148+
cohereEmbeddingsServer.requests().get(cohereEmbeddingsServer.requests().size() - 1),
149+
"embed",
150+
oldClusterApiVersion
151+
);
133152

134153
{
135154
final String upgradedClusterIdByte = "upgraded-cluster-embeddings-byte";
136155

156+
// new endpoints use the V2 API
137157
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
138158
put(upgradedClusterIdByte, embeddingConfigByte(getUrl(cohereEmbeddingsServer)), testTaskType);
159+
assertVersionInPath(
160+
cohereEmbeddingsServer.requests().get(cohereEmbeddingsServer.requests().size() - 1),
161+
"embed",
162+
ApiVersion.V2
163+
);
139164

140165
configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterIdByte).get("endpoints");
141166
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
@@ -147,34 +172,86 @@ public void testCohereEmbeddings() throws IOException {
147172
{
148173
final String upgradedClusterIdInt8 = "upgraded-cluster-embeddings-int8";
149174

175+
// new endpoints use the V2 API
150176
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
151177
put(upgradedClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), testTaskType);
178+
assertVersionInPath(
179+
cohereEmbeddingsServer.requests().get(cohereEmbeddingsServer.requests().size() - 1),
180+
"embed",
181+
ApiVersion.V2
182+
);
152183

153184
configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterIdInt8).get("endpoints");
154185
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
155186
assertThat(serviceSettings, hasEntry("embedding_type", "byte")); // int8 rewritten to byte
156187

157188
assertEmbeddingInference(upgradedClusterIdInt8, CohereEmbeddingType.INT8);
189+
assertVersionInPath(
190+
cohereEmbeddingsServer.requests().get(cohereEmbeddingsServer.requests().size() - 1),
191+
"embed",
192+
ApiVersion.V2
193+
);
158194
delete(upgradedClusterIdInt8);
159195
}
160196
{
161197
final String upgradedClusterIdFloat = "upgraded-cluster-embeddings-float";
162198
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat()));
163199
put(upgradedClusterIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), testTaskType);
200+
assertVersionInPath(
201+
cohereEmbeddingsServer.requests().get(cohereEmbeddingsServer.requests().size() - 1),
202+
"embed",
203+
ApiVersion.V2
204+
);
164205

165206
configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterIdFloat).get("endpoints");
166207
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
167208
assertThat(serviceSettings, hasEntry("embedding_type", "float"));
168209

169210
assertEmbeddingInference(upgradedClusterIdFloat, CohereEmbeddingType.FLOAT);
211+
assertVersionInPath(
212+
cohereEmbeddingsServer.requests().get(cohereEmbeddingsServer.requests().size() - 1),
213+
"embed",
214+
ApiVersion.V2
215+
);
170216
delete(upgradedClusterIdFloat);
171217
}
218+
{
219+
// new endpoints use the V2 API, which requires the model_id to be set
220+
final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id";
221+
var jsonBody = Strings.format("""
222+
{
223+
"service": "cohere",
224+
"service_settings": {
225+
"url": "%s",
226+
"api_key": "XXXX",
227+
"embedding_type": "int8"
228+
}
229+
}
230+
""", getUrl(cohereEmbeddingsServer));
231+
232+
var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, testTaskType));
233+
assertThat(
234+
e.getMessage(),
235+
containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.")
236+
);
237+
}
172238

173239
delete(oldClusterIdFloat);
174240
delete(oldClusterIdInt8);
175241
}
176242
}
177243

244+
private void assertVersionInPath(MockRequest request, String endpoint, ApiVersion apiVersion) {
245+
switch (apiVersion) {
246+
case V2:
247+
assertEquals("/v2/" + endpoint, request.getUri().getPath());
248+
break;
249+
case V1:
250+
assertEquals("/v1/" + endpoint, request.getUri().getPath());
251+
break;
252+
}
253+
}
254+
178255
void assertEmbeddingInference(String inferenceId, CohereEmbeddingType type) throws IOException {
179256
switch (type) {
180257
case INT8:
@@ -195,6 +272,8 @@ public void testRerank() throws IOException {
195272
String old_cluster_endpoint_identifier = oldClusterHasFeature("gte_v" + MODELS_RENAMED_TO_ENDPOINTS) ? "endpoints" : "models";
196273
assumeTrue("Cohere rerank service added in " + COHERE_RERANK_ADDED, rerankSupported);
197274

275+
ApiVersion oldClusterApiVersion = oldClusterHasFeature(COHERE_V2_API_ADDED_TEST_FEATURE) ? ApiVersion.V2 : ApiVersion.V1;
276+
198277
final String oldClusterId = "old-cluster-rerank";
199278
final String upgradedClusterId = "upgraded-cluster-rerank";
200279

@@ -217,7 +296,6 @@ public void testRerank() throws IOException {
217296
assertThat(taskSettings, hasEntry("top_n", 3));
218297

219298
assertRerank(oldClusterId);
220-
221299
} else if (isUpgradedCluster()) {
222300
// check old cluster model
223301
var configs = (List<Map<String, Object>>) get(testTaskType, oldClusterId).get("endpoints");
@@ -228,6 +306,11 @@ public void testRerank() throws IOException {
228306
assertThat(taskSettings, hasEntry("top_n", 3));
229307

230308
assertRerank(oldClusterId);
309+
assertVersionInPath(
310+
cohereRerankServer.requests().get(cohereRerankServer.requests().size() - 1),
311+
"rerank",
312+
oldClusterApiVersion
313+
);
231314

232315
// New endpoint
233316
cohereRerankServer.enqueue(new MockResponse().setResponseCode(200).setBody(rerankResponse()));
@@ -236,6 +319,27 @@ public void testRerank() throws IOException {
236319
assertThat(configs, hasSize(1));
237320

238321
assertRerank(upgradedClusterId);
322+
assertVersionInPath(cohereRerankServer.requests().get(cohereRerankServer.requests().size() - 1), "rerank", ApiVersion.V2);
323+
324+
{
    // New endpoints use the V2 API, which requires the model_id to be set.
    // The PUT below omits model_id and is expected to fail with a validation error.
    final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id";
    // Use the rerank mock server's URL so this request is consistent with the rest of the rerank test
    // (the original pointed at the embeddings mock server).
    var jsonBody = Strings.format("""
        {
            "service": "cohere",
            "service_settings": {
                "url": "%s",
                "api_key": "XXXX"
            }
        }
        """, getUrl(cohereRerankServer));

    var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, testTaskType));
    assertThat(
        e.getMessage(),
        containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.")
    );
}
239343

240344
delete(oldClusterId);
241345
delete(upgradedClusterId);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ public Set<NodeFeature> getFeatures() {
4646
"test_reranking_service.parse_text_as_score"
4747
);
4848
private static final NodeFeature SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER = new NodeFeature("semantic_text.match_all_highlighter");
49+
private static final NodeFeature COHERE_V2_API = new NodeFeature("inference.cohere.v2");
4950

5051
@Override
5152
public Set<NodeFeature> getTestFeatures() {
@@ -72,7 +73,8 @@ public Set<NodeFeature> getTestFeatures() {
7273
SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG,
7374
SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER,
7475
SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS,
75-
SEMANTIC_TEXT_INDEX_OPTIONS
76+
SEMANTIC_TEXT_INDEX_OPTIONS,
77+
COHERE_V2_API
7678
);
7779
}
7880
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,6 @@ public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) {
2727
private final Boolean returnDocuments;
2828
private final Integer topN;
2929

30-
public QueryAndDocsInputs(String query, List<String> chunks) {
31-
this(query, chunks, null, null, false);
32-
}
33-
3430
public QueryAndDocsInputs(
3531
String query,
3632
List<String> chunks,
@@ -45,6 +41,10 @@ public QueryAndDocsInputs(
4541
this.topN = topN;
4642
}
4743

44+
/**
 * Convenience constructor that takes only the query and the chunks to score against it.
 * Delegates to the five-argument constructor, passing {@code null} for the third and fourth
 * arguments and {@code false} for the fifth.
 * NOTE(review): the null arguments presumably map to {@code returnDocuments} and {@code topN};
 * confirm against the full constructor signature.
 *
 * @param query  the query text
 * @param chunks the documents or chunks associated with the query
 */
public QueryAndDocsInputs(String query, List<String> chunks) {
    this(query, chunks, null, null, false);
}
47+
4848
public String getQuery() {
4949
return query;
5050
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereAccount.java

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,35 @@
77

88
package org.elasticsearch.xpack.inference.services.cohere;
99

10-
import org.elasticsearch.common.CheckedSupplier;
10+
import org.apache.http.client.utils.URIBuilder;
11+
import org.elasticsearch.ElasticsearchStatusException;
12+
import org.elasticsearch.common.Strings;
1113
import org.elasticsearch.common.settings.SecureString;
14+
import org.elasticsearch.rest.RestStatus;
15+
import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils;
1216

1317
import java.net.URI;
1418
import java.net.URISyntaxException;
1519
import java.util.Objects;
1620

17-
import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri;
18-
19-
public record CohereAccount(URI uri, SecureString apiKey) {
20-
21-
public static CohereAccount of(CohereModel model, CheckedSupplier<URI, URISyntaxException> uriBuilder) {
22-
var uri = buildUri(model.uri(), "Cohere", uriBuilder);
23-
24-
return new CohereAccount(uri, model.apiKey());
21+
public record CohereAccount(URI baseUri, SecureString apiKey) {
22+
23+
public static CohereAccount of(CohereModel model) {
24+
try {
25+
var uri = model.baseUri() != null ? model.baseUri() : new URIBuilder().setScheme("https").setHost(CohereUtils.HOST).build();
26+
return new CohereAccount(uri, model.apiKey());
27+
} catch (URISyntaxException e) {
28+
// using bad request here so that potentially sensitive URL information does not get logged
29+
throw new ElasticsearchStatusException(
30+
Strings.format("Failed to construct %s URL", CohereService.NAME),
31+
RestStatus.BAD_REQUEST,
32+
e
33+
);
34+
}
2535
}
2636

2737
public CohereAccount {
28-
Objects.requireNonNull(uri);
38+
Objects.requireNonNull(baseUri);
2939
Objects.requireNonNull(apiKey);
3040
}
3141
}

0 commit comments

Comments
 (0)