Skip to content

[8.19] [ML] Move to the Cohere V2 API for new inference endpoints (#129884) #129988

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/129884.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 129884
summary: Move to the Cohere V2 API for new inference endpoints
area: Machine Learning
type: enhancement
issues: []
2 changes: 0 additions & 2 deletions muted-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,6 @@ tests:
- class: org.elasticsearch.xpack.sql.qa.single_node.JdbcDocCsvSpecIT
method: test {docs.testFilterToday}
issue: https://github.com/elastic/elasticsearch/issues/121474
- class: org.elasticsearch.xpack.application.CohereServiceUpgradeIT
issue: https://github.com/elastic/elasticsearch/issues/121537
- class: org.elasticsearch.xpack.test.rest.XPackRestIT
method: test {p0=transform/*}
issue: https://github.com/elastic/elasticsearch/issues/120816
Expand Down
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ static TransportVersion def(int id) {
public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER_8_19 = def(8_841_0_56);
public static final TransportVersion ML_INFERENCE_SAGEMAKER_ELASTIC_8_19 = def(8_841_0_57);
public static final TransportVersion SPARSE_VECTOR_FIELD_PRUNING_OPTIONS_8_19 = def(8_841_0_58);
public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED_8_19 = def(8_841_0_59);
public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION_8_19 = def(8_841_0_60);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@

import com.carrotsearch.randomizedtesting.annotations.Name;

import org.elasticsearch.client.ResponseException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.http.MockRequest;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType;
Expand All @@ -24,6 +26,7 @@

import static org.hamcrest.Matchers.anEmptyMap;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasSize;
Expand All @@ -35,11 +38,16 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase {

private static final String COHERE_EMBEDDINGS_ADDED = "8.13.0";
private static final String COHERE_RERANK_ADDED = "8.14.0";
private static final String BYTE_ALIAS_FOR_INT8_ADDED = "8.14.0";
private static final String COHERE_V2_API_ADDED_TEST_FEATURE = "inference.cohere.v2";

private static MockWebServer cohereEmbeddingsServer;
private static MockWebServer cohereRerankServer;

/**
 * Cohere API version a request is expected to target. Used by {@code assertVersionInPath}
 * to pick the path prefix that the mock server should have received.
 */
private enum ApiVersion {
// requests are expected under the "/v1/" path prefix
V1,
// requests are expected under the "/v2/" path prefix
V2
}

/**
 * Forwards the {@code upgradedNodes} test parameter to the {@code InferenceUpgradeTestCase} base class.
 *
 * @param upgradedNodes parameter supplied under the name "upgradedNodes"
 */
public CohereServiceUpgradeIT(@Name("upgradedNodes") int upgradedNodes) {
super(upgradedNodes);
}
Expand All @@ -64,14 +72,15 @@ public void testCohereEmbeddings() throws IOException {
var embeddingsSupported = getOldClusterTestVersion().onOrAfter(COHERE_EMBEDDINGS_ADDED);
// `gte_v` indicates that the cluster version is Greater Than or Equal to MODELS_RENAMED_TO_ENDPOINTS
String oldClusterEndpointIdentifier = oldClusterHasFeature("gte_v" + MODELS_RENAMED_TO_ENDPOINTS) ? "endpoints" : "models";
assumeTrue("Cohere embedding service added in " + COHERE_EMBEDDINGS_ADDED, embeddingsSupported);
ApiVersion oldClusterApiVersion = oldClusterHasFeature(COHERE_V2_API_ADDED_TEST_FEATURE) ? ApiVersion.V2 : ApiVersion.V1;

final String oldClusterIdInt8 = "old-cluster-embeddings-int8";
final String oldClusterIdFloat = "old-cluster-embeddings-float";

var testTaskType = TaskType.TEXT_EMBEDDING;

if (isOldCluster()) {

// queue a response as PUT will call the service
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
put(oldClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), testTaskType);
Expand Down Expand Up @@ -129,13 +138,29 @@ public void testCohereEmbeddings() throws IOException {

// Inference on old cluster models
assertEmbeddingInference(oldClusterIdInt8, CohereEmbeddingType.BYTE);
assertVersionInPath(
cohereEmbeddingsServer.requests().get(cohereEmbeddingsServer.requests().size() - 1),
"embed",
oldClusterApiVersion
);
assertEmbeddingInference(oldClusterIdFloat, CohereEmbeddingType.FLOAT);
assertVersionInPath(
cohereEmbeddingsServer.requests().get(cohereEmbeddingsServer.requests().size() - 1),
"embed",
oldClusterApiVersion
);

{
final String upgradedClusterIdByte = "upgraded-cluster-embeddings-byte";

// new endpoints use the V2 API
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
put(upgradedClusterIdByte, embeddingConfigByte(getUrl(cohereEmbeddingsServer)), testTaskType);
assertVersionInPath(
cohereEmbeddingsServer.requests().get(cohereEmbeddingsServer.requests().size() - 1),
"embed",
ApiVersion.V2
);

configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterIdByte).get("endpoints");
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
Expand All @@ -147,34 +172,86 @@ public void testCohereEmbeddings() throws IOException {
{
final String upgradedClusterIdInt8 = "upgraded-cluster-embeddings-int8";

// new endpoints use the V2 API
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte()));
put(upgradedClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), testTaskType);
assertVersionInPath(
cohereEmbeddingsServer.requests().get(cohereEmbeddingsServer.requests().size() - 1),
"embed",
ApiVersion.V2
);

configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterIdInt8).get("endpoints");
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
assertThat(serviceSettings, hasEntry("embedding_type", "byte")); // int8 rewritten to byte

assertEmbeddingInference(upgradedClusterIdInt8, CohereEmbeddingType.INT8);
assertVersionInPath(
cohereEmbeddingsServer.requests().get(cohereEmbeddingsServer.requests().size() - 1),
"embed",
ApiVersion.V2
);
delete(upgradedClusterIdInt8);
}
{
final String upgradedClusterIdFloat = "upgraded-cluster-embeddings-float";
cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat()));
put(upgradedClusterIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), testTaskType);
assertVersionInPath(
cohereEmbeddingsServer.requests().get(cohereEmbeddingsServer.requests().size() - 1),
"embed",
ApiVersion.V2
);

configs = (List<Map<String, Object>>) get(testTaskType, upgradedClusterIdFloat).get("endpoints");
serviceSettings = (Map<String, Object>) configs.get(0).get("service_settings");
assertThat(serviceSettings, hasEntry("embedding_type", "float"));

assertEmbeddingInference(upgradedClusterIdFloat, CohereEmbeddingType.FLOAT);
assertVersionInPath(
cohereEmbeddingsServer.requests().get(cohereEmbeddingsServer.requests().size() - 1),
"embed",
ApiVersion.V2
);
delete(upgradedClusterIdFloat);
}
{
// new endpoints use the V2 API, which requires the model_id to be set
final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id";
var jsonBody = Strings.format("""
{
"service": "cohere",
"service_settings": {
"url": "%s",
"api_key": "XXXX",
"embedding_type": "int8"
}
}
""", getUrl(cohereEmbeddingsServer));

var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, testTaskType));
assertThat(
e.getMessage(),
containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.")
);
}

delete(oldClusterIdFloat);
delete(oldClusterIdInt8);
}
}

/**
 * Asserts that the given mock-server request was sent to the expected versioned Cohere path,
 * i.e. {@code /v1/<endpoint>} or {@code /v2/<endpoint>}.
 *
 * @param request    the recorded request whose URI path is checked
 * @param endpoint   the path segment following the version prefix, e.g. {@code "embed"} or {@code "rerank"}
 * @param apiVersion the API version the request is expected to target
 */
private void assertVersionInPath(MockRequest request, String endpoint, ApiVersion apiVersion) {
    // Only the version prefix differs between the two cases, so select it first and assert once.
    final String versionPrefix = switch (apiVersion) {
        case V2 -> "/v2/";
        case V1 -> "/v1/";
    };
    assertEquals(versionPrefix + endpoint, request.getUri().getPath());
}

void assertEmbeddingInference(String inferenceId, CohereEmbeddingType type) throws IOException {
switch (type) {
case INT8:
Expand All @@ -195,6 +272,8 @@ public void testRerank() throws IOException {
String old_cluster_endpoint_identifier = oldClusterHasFeature("gte_v" + MODELS_RENAMED_TO_ENDPOINTS) ? "endpoints" : "models";
assumeTrue("Cohere rerank service added in " + COHERE_RERANK_ADDED, rerankSupported);

ApiVersion oldClusterApiVersion = oldClusterHasFeature(COHERE_V2_API_ADDED_TEST_FEATURE) ? ApiVersion.V2 : ApiVersion.V1;

final String oldClusterId = "old-cluster-rerank";
final String upgradedClusterId = "upgraded-cluster-rerank";

Expand All @@ -217,7 +296,6 @@ public void testRerank() throws IOException {
assertThat(taskSettings, hasEntry("top_n", 3));

assertRerank(oldClusterId);

} else if (isUpgradedCluster()) {
// check old cluster model
var configs = (List<Map<String, Object>>) get(testTaskType, oldClusterId).get("endpoints");
Expand All @@ -228,6 +306,11 @@ public void testRerank() throws IOException {
assertThat(taskSettings, hasEntry("top_n", 3));

assertRerank(oldClusterId);
assertVersionInPath(
cohereRerankServer.requests().get(cohereRerankServer.requests().size() - 1),
"rerank",
oldClusterApiVersion
);

// New endpoint
cohereRerankServer.enqueue(new MockResponse().setResponseCode(200).setBody(rerankResponse()));
Expand All @@ -236,6 +319,27 @@ public void testRerank() throws IOException {
assertThat(configs, hasSize(1));

assertRerank(upgradedClusterId);
assertVersionInPath(cohereRerankServer.requests().get(cohereRerankServer.requests().size() - 1), "rerank", ApiVersion.V2);

{
    // New endpoints use the V2 API, which requires the model_id to be set,
    // so creating an endpoint without it must be rejected with a validation error.
    final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id";
    // Use the rerank mock server here: the rest of this test exercises cohereRerankServer,
    // and pointing a rerank endpoint at the embeddings server was a copy-paste slip.
    var jsonBody = Strings.format("""
        {
        "service": "cohere",
        "service_settings": {
        "url": "%s",
        "api_key": "XXXX"
        }
        }
        """, getUrl(cohereRerankServer));

    var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, testTaskType));
    assertThat(
        e.getMessage(),
        containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.")
    );
}

delete(oldClusterId);
delete(upgradedClusterId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public Set<NodeFeature> getFeatures() {
"test_reranking_service.parse_text_as_score"
);
private static final NodeFeature SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER = new NodeFeature("semantic_text.match_all_highlighter");
private static final NodeFeature COHERE_V2_API = new NodeFeature("inference.cohere.v2");

@Override
public Set<NodeFeature> getTestFeatures() {
Expand All @@ -72,7 +73,8 @@ public Set<NodeFeature> getTestFeatures() {
SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG,
SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER,
SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS,
SEMANTIC_TEXT_INDEX_OPTIONS
SEMANTIC_TEXT_INDEX_OPTIONS,
COHERE_V2_API
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) {
private final Boolean returnDocuments;
private final Integer topN;

public QueryAndDocsInputs(String query, List<String> chunks) {
this(query, chunks, null, null, false);
}

public QueryAndDocsInputs(
String query,
List<String> chunks,
Expand All @@ -45,6 +41,10 @@ public QueryAndDocsInputs(
this.topN = topN;
}

/**
 * Convenience constructor that delegates to the five-argument constructor, passing
 * {@code null}, {@code null} and {@code false} for the remaining arguments.
 * NOTE(review): those presumably correspond to returnDocuments, topN and a streaming flag;
 * confirm against the full constructor signature.
 *
 * @param query  the query text
 * @param chunks the documents/chunks associated with the query
 */
public QueryAndDocsInputs(String query, List<String> chunks) {
this(query, chunks, null, null, false);
}

/** Returns the {@code query} field of these inputs. */
public String getQuery() {
return query;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,35 @@

package org.elasticsearch.xpack.inference.services.cohere;

import org.elasticsearch.common.CheckedSupplier;
import org.apache.http.client.utils.URIBuilder;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri;

public record CohereAccount(URI uri, SecureString apiKey) {

public static CohereAccount of(CohereModel model, CheckedSupplier<URI, URISyntaxException> uriBuilder) {
var uri = buildUri(model.uri(), "Cohere", uriBuilder);

return new CohereAccount(uri, model.apiKey());
public record CohereAccount(URI baseUri, SecureString apiKey) {

public static CohereAccount of(CohereModel model) {
try {
var uri = model.baseUri() != null ? model.baseUri() : new URIBuilder().setScheme("https").setHost(CohereUtils.HOST).build();
return new CohereAccount(uri, model.apiKey());
} catch (URISyntaxException e) {
// using bad request here so that potentially sensitive URL information does not get logged
throw new ElasticsearchStatusException(
Strings.format("Failed to construct %s URL", CohereService.NAME),
RestStatus.BAD_REQUEST,
e
);
}
}

public CohereAccount {
Objects.requireNonNull(uri);
Objects.requireNonNull(baseUri);
Objects.requireNonNull(apiKey);
}
}
Loading