diff --git a/docs/changelog/129884.yaml b/docs/changelog/129884.yaml new file mode 100644 index 0000000000000..a3ae373f2dbd0 --- /dev/null +++ b/docs/changelog/129884.yaml @@ -0,0 +1,5 @@ +pr: 129884 +summary: Move to the Cohere V2 API for new inference endpoints +area: Machine Learning +type: enhancement +issues: [] diff --git a/muted-tests.yml b/muted-tests.yml index 9fea98b7f3575..064b0e311210b 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -353,8 +353,6 @@ tests: - class: org.elasticsearch.xpack.sql.qa.single_node.JdbcDocCsvSpecIT method: test {docs.testFilterToday} issue: https://github.com/elastic/elasticsearch/issues/121474 - - class: org.elasticsearch.xpack.application.CohereServiceUpgradeIT - issue: https://github.com/elastic/elasticsearch/issues/121537 - class: org.elasticsearch.xpack.test.rest.XPackRestIT method: test {p0=transform/*} issue: https://github.com/elastic/elasticsearch/issues/120816 diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 2d89c335af917..ce1016cd01c74 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -251,6 +251,8 @@ static TransportVersion def(int id) { public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER_8_19 = def(8_841_0_56); public static final TransportVersion ML_INFERENCE_SAGEMAKER_ELASTIC_8_19 = def(8_841_0_57); public static final TransportVersion SPARSE_VECTOR_FIELD_PRUNING_OPTIONS_8_19 = def(8_841_0_58); + public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED_8_19 = def(8_841_0_59); + public static final TransportVersion ML_INFERENCE_COHERE_API_VERSION_8_19 = def(8_841_0_60); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java index 0acbc148515bd..699e6baea12d2 100644 --- a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java +++ b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java @@ -9,8 +9,10 @@ import com.carrotsearch.randomizedtesting.annotations.Name; +import org.elasticsearch.client.ResponseException; import org.elasticsearch.common.Strings; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.http.MockRequest; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; @@ -24,6 +26,7 @@ import static org.hamcrest.Matchers.anEmptyMap; import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasEntry; import static org.hamcrest.Matchers.hasSize; @@ -35,11 +38,16 @@ public class CohereServiceUpgradeIT extends InferenceUpgradeTestCase { private static final String COHERE_EMBEDDINGS_ADDED = "8.13.0"; private static final String COHERE_RERANK_ADDED = "8.14.0"; - private static final String BYTE_ALIAS_FOR_INT8_ADDED = "8.14.0"; + private static final String COHERE_V2_API_ADDED_TEST_FEATURE = "inference.cohere.v2"; private static MockWebServer cohereEmbeddingsServer; private static MockWebServer cohereRerankServer; + private enum ApiVersion { + V1, + V2 + } + public CohereServiceUpgradeIT(@Name("upgradedNodes") int upgradedNodes) { super(upgradedNodes); } @@ -64,7 +72,7 @@ public void 
testCohereEmbeddings() throws IOException { var embeddingsSupported = getOldClusterTestVersion().onOrAfter(COHERE_EMBEDDINGS_ADDED); // `gte_v` indicates that the cluster version is Greater Than or Equal to MODELS_RENAMED_TO_ENDPOINTS String oldClusterEndpointIdentifier = oldClusterHasFeature("gte_v" + MODELS_RENAMED_TO_ENDPOINTS) ? "endpoints" : "models"; - assumeTrue("Cohere embedding service added in " + COHERE_EMBEDDINGS_ADDED, embeddingsSupported); + ApiVersion oldClusterApiVersion = oldClusterHasFeature(COHERE_V2_API_ADDED_TEST_FEATURE) ? ApiVersion.V2 : ApiVersion.V1; final String oldClusterIdInt8 = "old-cluster-embeddings-int8"; final String oldClusterIdFloat = "old-cluster-embeddings-float"; @@ -72,6 +80,7 @@ public void testCohereEmbeddings() throws IOException { var testTaskType = TaskType.TEXT_EMBEDDING; if (isOldCluster()) { + // queue a response as PUT will call the service cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte())); put(oldClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), testTaskType); @@ -129,13 +138,29 @@ public void testCohereEmbeddings() throws IOException { // Inference on old cluster models assertEmbeddingInference(oldClusterIdInt8, CohereEmbeddingType.BYTE); + assertVersionInPath( + cohereEmbeddingsServer.requests().get(cohereEmbeddingsServer.requests().size() - 1), + "embed", + oldClusterApiVersion + ); assertEmbeddingInference(oldClusterIdFloat, CohereEmbeddingType.FLOAT); + assertVersionInPath( + cohereEmbeddingsServer.requests().get(cohereEmbeddingsServer.requests().size() - 1), + "embed", + oldClusterApiVersion + ); { final String upgradedClusterIdByte = "upgraded-cluster-embeddings-byte"; + // new endpoints use the V2 API cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte())); put(upgradedClusterIdByte, embeddingConfigByte(getUrl(cohereEmbeddingsServer)), testTaskType); + assertVersionInPath( + 
cohereEmbeddingsServer.requests().get(cohereEmbeddingsServer.requests().size() - 1), + "embed", + ApiVersion.V2 + ); configs = (List>) get(testTaskType, upgradedClusterIdByte).get("endpoints"); serviceSettings = (Map) configs.get(0).get("service_settings"); @@ -147,34 +172,86 @@ public void testCohereEmbeddings() throws IOException { { final String upgradedClusterIdInt8 = "upgraded-cluster-embeddings-int8"; + // new endpoints use the V2 API cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte())); put(upgradedClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), testTaskType); + assertVersionInPath( + cohereEmbeddingsServer.requests().get(cohereEmbeddingsServer.requests().size() - 1), + "embed", + ApiVersion.V2 + ); configs = (List>) get(testTaskType, upgradedClusterIdInt8).get("endpoints"); serviceSettings = (Map) configs.get(0).get("service_settings"); assertThat(serviceSettings, hasEntry("embedding_type", "byte")); // int8 rewritten to byte assertEmbeddingInference(upgradedClusterIdInt8, CohereEmbeddingType.INT8); + assertVersionInPath( + cohereEmbeddingsServer.requests().get(cohereEmbeddingsServer.requests().size() - 1), + "embed", + ApiVersion.V2 + ); delete(upgradedClusterIdInt8); } { final String upgradedClusterIdFloat = "upgraded-cluster-embeddings-float"; cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat())); put(upgradedClusterIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), testTaskType); + assertVersionInPath( + cohereEmbeddingsServer.requests().get(cohereEmbeddingsServer.requests().size() - 1), + "embed", + ApiVersion.V2 + ); configs = (List>) get(testTaskType, upgradedClusterIdFloat).get("endpoints"); serviceSettings = (Map) configs.get(0).get("service_settings"); assertThat(serviceSettings, hasEntry("embedding_type", "float")); assertEmbeddingInference(upgradedClusterIdFloat, CohereEmbeddingType.FLOAT); + 
assertVersionInPath( + cohereEmbeddingsServer.requests().get(cohereEmbeddingsServer.requests().size() - 1), + "embed", + ApiVersion.V2 + ); delete(upgradedClusterIdFloat); } + { + // new endpoints use the V2 API which require the model to be set + final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id"; + var jsonBody = Strings.format(""" + { + "service": "cohere", + "service_settings": { + "url": "%s", + "api_key": "XXXX", + "embedding_type": "int8" + } + } + """, getUrl(cohereEmbeddingsServer)); + + var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, testTaskType)); + assertThat( + e.getMessage(), + containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.") + ); + } delete(oldClusterIdFloat); delete(oldClusterIdInt8); } } + private void assertVersionInPath(MockRequest request, String endpoint, ApiVersion apiVersion) { + switch (apiVersion) { + case V2: + assertEquals("/v2/" + endpoint, request.getUri().getPath()); + break; + case V1: + assertEquals("/v1/" + endpoint, request.getUri().getPath()); + break; + } + } + void assertEmbeddingInference(String inferenceId, CohereEmbeddingType type) throws IOException { switch (type) { case INT8: @@ -195,6 +272,8 @@ public void testRerank() throws IOException { String old_cluster_endpoint_identifier = oldClusterHasFeature("gte_v" + MODELS_RENAMED_TO_ENDPOINTS) ? "endpoints" : "models"; assumeTrue("Cohere rerank service added in " + COHERE_RERANK_ADDED, rerankSupported); + ApiVersion oldClusterApiVersion = oldClusterHasFeature(COHERE_V2_API_ADDED_TEST_FEATURE) ? 
ApiVersion.V2 : ApiVersion.V1; + final String oldClusterId = "old-cluster-rerank"; final String upgradedClusterId = "upgraded-cluster-rerank"; @@ -217,7 +296,6 @@ public void testRerank() throws IOException { assertThat(taskSettings, hasEntry("top_n", 3)); assertRerank(oldClusterId); - } else if (isUpgradedCluster()) { // check old cluster model var configs = (List>) get(testTaskType, oldClusterId).get("endpoints"); @@ -228,6 +306,11 @@ public void testRerank() throws IOException { assertThat(taskSettings, hasEntry("top_n", 3)); assertRerank(oldClusterId); + assertVersionInPath( + cohereRerankServer.requests().get(cohereRerankServer.requests().size() - 1), + "rerank", + oldClusterApiVersion + ); // New endpoint cohereRerankServer.enqueue(new MockResponse().setResponseCode(200).setBody(rerankResponse())); @@ -236,6 +319,27 @@ public void testRerank() throws IOException { assertThat(configs, hasSize(1)); assertRerank(upgradedClusterId); + assertVersionInPath(cohereRerankServer.requests().get(cohereRerankServer.requests().size() - 1), "rerank", ApiVersion.V2); + + { + // new endpoints use the V2 API which require the model_id to be set + final String upgradedClusterNoModel = "upgraded-cluster-missing-model-id"; + var jsonBody = Strings.format(""" + { + "service": "cohere", + "service_settings": { + "url": "%s", + "api_key": "XXXX" + } + } + """, getUrl(cohereEmbeddingsServer)); + + var e = expectThrows(ResponseException.class, () -> put(upgradedClusterNoModel, jsonBody, testTaskType)); + assertThat( + e.getMessage(), + containsString("Validation Failed: 1: The [service_settings.model_id] field is required for the Cohere V2 API.") + ); + } delete(oldClusterId); delete(upgradedClusterId); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index 2744aa36b5933..7fdebbb85f521 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -46,6 +46,7 @@ public Set getFeatures() { "test_reranking_service.parse_text_as_score" ); private static final NodeFeature SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER = new NodeFeature("semantic_text.match_all_highlighter"); + private static final NodeFeature COHERE_V2_API = new NodeFeature("inference.cohere.v2"); @Override public Set getTestFeatures() { @@ -72,7 +73,8 @@ public Set getTestFeatures() { SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG, SEMANTIC_TEXT_MATCH_ALL_HIGHLIGHTER, SEMANTIC_TEXT_EXCLUDE_SUB_FIELDS_FROM_FIELD_CAPS, - SEMANTIC_TEXT_INDEX_OPTIONS + SEMANTIC_TEXT_INDEX_OPTIONS, + COHERE_V2_API ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java index a2526a2a293eb..850c96160dc44 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java @@ -27,10 +27,6 @@ public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) { private final Boolean returnDocuments; private final Integer topN; - public QueryAndDocsInputs(String query, List chunks) { - this(query, chunks, null, null, false); - } - public QueryAndDocsInputs( String query, List chunks, @@ -45,6 +41,10 @@ public QueryAndDocsInputs( this.topN = topN; } + public QueryAndDocsInputs(String query, List chunks) { + this(query, chunks, null, null, false); + } + public String getQuery() { return query; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereAccount.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereAccount.java index 297e918cac307..869357ef8fb17 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereAccount.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereAccount.java @@ -7,25 +7,35 @@ package org.elasticsearch.xpack.inference.services.cohere; -import org.elasticsearch.common.CheckedSupplier; +import org.apache.http.client.utils.URIBuilder; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; import java.net.URI; import java.net.URISyntaxException; import java.util.Objects; -import static org.elasticsearch.xpack.inference.external.request.RequestUtils.buildUri; - -public record CohereAccount(URI uri, SecureString apiKey) { - - public static CohereAccount of(CohereModel model, CheckedSupplier uriBuilder) { - var uri = buildUri(model.uri(), "Cohere", uriBuilder); - - return new CohereAccount(uri, model.apiKey()); +public record CohereAccount(URI baseUri, SecureString apiKey) { + + public static CohereAccount of(CohereModel model) { + try { + var uri = model.baseUri() != null ? 
model.baseUri() : new URIBuilder().setScheme("https").setHost(CohereUtils.HOST).build(); + return new CohereAccount(uri, model.apiKey()); + } catch (URISyntaxException e) { + // using bad request here so that potentially sensitive URL information does not get logged + throw new ElasticsearchStatusException( + Strings.format("Failed to construct %s URL", CohereService.NAME), + RestStatus.BAD_REQUEST, + e + ); + } } public CohereAccount { - Objects.requireNonNull(uri); + Objects.requireNonNull(baseUri); Objects.requireNonNull(apiKey); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereCompletionRequestManager.java deleted file mode 100644 index 2c6b4beb80c5b..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereCompletionRequestManager.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.services.cohere; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; -import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; -import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; -import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; -import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel; -import org.elasticsearch.xpack.inference.services.cohere.request.completion.CohereCompletionRequest; -import org.elasticsearch.xpack.inference.services.cohere.response.CohereCompletionResponseEntity; - -import java.util.Objects; -import java.util.function.Supplier; - -public class CohereCompletionRequestManager extends CohereRequestManager { - - private static final Logger logger = LogManager.getLogger(CohereCompletionRequestManager.class); - - private static final ResponseHandler HANDLER = createCompletionHandler(); - - private static ResponseHandler createCompletionHandler() { - return new CohereResponseHandler("cohere completion", CohereCompletionResponseEntity::fromResponse, true); - } - - public static CohereCompletionRequestManager of(CohereCompletionModel model, ThreadPool threadPool) { - return new CohereCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); - } - - private final CohereCompletionModel model; - - private CohereCompletionRequestManager(CohereCompletionModel model, ThreadPool threadPool) { - super(threadPool, model); - this.model = Objects.requireNonNull(model); - } - - @Override - public void execute( - 
InferenceInputs inferenceInputs, - RequestSender requestSender, - Supplier hasRequestCompletedFunction, - ActionListener listener - ) { - var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); - var inputs = chatCompletionInput.getInputs(); - var stream = chatCompletionInput.stream(); - CohereCompletionRequest request = new CohereCompletionRequest(inputs, model, stream); - - execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsRequestManager.java deleted file mode 100644 index e721c3e46cecf..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereEmbeddingsRequestManager.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.services.cohere; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; -import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; -import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; -import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.cohere.request.CohereEmbeddingsRequest; -import org.elasticsearch.xpack.inference.services.cohere.response.CohereEmbeddingsResponseEntity; - -import java.util.List; -import java.util.Objects; -import java.util.function.Supplier; - -public class CohereEmbeddingsRequestManager extends CohereRequestManager { - private static final Logger logger = LogManager.getLogger(CohereEmbeddingsRequestManager.class); - private static final ResponseHandler HANDLER = createEmbeddingsHandler(); - - private static ResponseHandler createEmbeddingsHandler() { - return new CohereResponseHandler("cohere text embedding", CohereEmbeddingsResponseEntity::fromResponse, false); - } - - public static CohereEmbeddingsRequestManager of(CohereEmbeddingsModel model, ThreadPool threadPool) { - return new CohereEmbeddingsRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); - } - - private final CohereEmbeddingsModel model; - - private CohereEmbeddingsRequestManager(CohereEmbeddingsModel model, ThreadPool threadPool) { - super(threadPool, model); - this.model = 
Objects.requireNonNull(model); - } - - @Override - public void execute( - InferenceInputs inferenceInputs, - RequestSender requestSender, - Supplier hasRequestCompletedFunction, - ActionListener listener - ) { - EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getStringInputs(); - InputType inputType = input.getInputType(); - - CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(docsInput, inputType, model); - - execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java index 5cb52bdb7f405..2457fb31c9c6e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereModel.java @@ -9,21 +9,23 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.cohere.action.CohereActionVisitor; import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.net.URI; import java.util.Map; import java.util.Objects; -public abstract class CohereModel 
extends Model { +public abstract class CohereModel extends RateLimitGroupingModel { + private final SecureString apiKey; private final CohereRateLimitServiceSettings rateLimitServiceSettings; @@ -63,5 +65,15 @@ public CohereRateLimitServiceSettings rateLimitServiceSettings() { public abstract ExecutableAction accept(CohereActionVisitor creator, Map taskSettings); - public abstract URI uri(); + public RateLimitSettings rateLimitSettings() { + return rateLimitServiceSettings.rateLimitSettings(); + } + + public int rateLimitGroupingHash() { + return apiKey().hashCode(); + } + + public URI baseUri() { + return rateLimitServiceSettings.uri(); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRateLimitServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRateLimitServiceSettings.java index 2607359c54c32..5b9fa3376a4f8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRateLimitServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRateLimitServiceSettings.java @@ -9,7 +9,12 @@ import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import java.net.URI; + public interface CohereRateLimitServiceSettings { RateLimitSettings rateLimitSettings(); + CohereServiceSettings.CohereApiVersion apiVersion(); + + URI uri(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRerankRequestManager.java deleted file mode 100644 index 134aab77530e1..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereRerankRequestManager.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright 
Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; -import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; -import org.elasticsearch.xpack.inference.external.http.sender.ExecutableInferenceRequest; -import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; -import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; -import org.elasticsearch.xpack.inference.services.cohere.request.CohereRerankRequest; -import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; -import org.elasticsearch.xpack.inference.services.cohere.response.CohereRankedResponseEntity; - -import java.util.Objects; -import java.util.function.Supplier; - -public class CohereRerankRequestManager extends CohereRequestManager { - private static final Logger logger = LogManager.getLogger(CohereRerankRequestManager.class); - private static final ResponseHandler HANDLER = createCohereResponseHandler(); - - private static ResponseHandler createCohereResponseHandler() { - return new CohereResponseHandler("cohere rerank", (request, response) -> CohereRankedResponseEntity.fromResponse(response), false); - } - - public static CohereRerankRequestManager of(CohereRerankModel model, ThreadPool threadPool) { - return new CohereRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); - } - - private final CohereRerankModel model; 
- - private CohereRerankRequestManager(CohereRerankModel model, ThreadPool threadPool) { - super(threadPool, model); - this.model = model; - } - - @Override - public void execute( - InferenceInputs inferenceInputs, - RequestSender requestSender, - Supplier hasRequestCompletedFunction, - ActionListener listener - ) { - var rerankInput = QueryAndDocsInputs.of(inferenceInputs); - CohereRerankRequest request = new CohereRerankRequest( - rerankInput.getQuery(), - rerankInput.getChunks(), - rerankInput.getReturnDocuments(), - rerankInput.getTopN(), - model - ); - - execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index bf6a0bd03122b..c2f1221763165 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -166,24 +166,14 @@ private static CohereModel createModel( return switch (taskType) { case TEXT_EMBEDDING -> new CohereEmbeddingsModel( inferenceEntityId, - taskType, - NAME, serviceSettings, taskSettings, chunkingSettings, secretSettings, context ); - case RERANK -> new CohereRerankModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context); - case COMPLETION -> new CohereCompletionModel( - inferenceEntityId, - taskType, - NAME, - serviceSettings, - taskSettings, - secretSettings, - context - ); + case RERANK -> new CohereRerankModel(inferenceEntityId, serviceSettings, taskSettings, secretSettings, context); + case COMPLETION -> new CohereCompletionModel(inferenceEntityId, serviceSettings, secretSettings, context); default -> throw new 
ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; } @@ -324,7 +314,8 @@ public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) { embeddingSize, serviceSettings.getCommonSettings().maxInputTokens(), serviceSettings.getCommonSettings().modelId(), - serviceSettings.getCommonSettings().rateLimitSettings() + serviceSettings.getCommonSettings().rateLimitSettings(), + serviceSettings.getCommonSettings().apiVersion() ), serviceSettings.getEmbeddingType() ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java index a1943b339a561..e31225d6eca8d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettings.java @@ -20,11 +20,14 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import java.io.IOException; import java.net.URI; +import java.util.EnumSet; +import java.util.Locale; import java.util.Map; import java.util.Objects; @@ -43,6 +46,18 @@ public class CohereServiceSettings extends FilteredXContentObject implements Ser public static final String NAME = "cohere_service_settings"; public static final String OLD_MODEL_ID_FIELD = "model"; public static final String MODEL_ID = "model_id"; + public static final String API_VERSION = "api_version"; + public static final String MODEL_REQUIRED_FOR_V2_API = "The 
[service_settings.model_id] field is required for the Cohere V2 API."; + + public enum CohereApiVersion { + V1, + V2; + + public static CohereApiVersion fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + } + private static final Logger logger = LogManager.getLogger(CohereServiceSettings.class); // Production key rate limits for all endpoints: https://docs.cohere.com/docs/going-live#production-key-specifications // 10K requests a minute @@ -72,11 +87,45 @@ public static CohereServiceSettings fromMap(Map map, Configurati logger.info("The cohere [service_settings.model] field is deprecated. Please use [service_settings.model_id] instead."); } + var resolvedModelId = modelId(oldModelId, modelId); + var apiVersion = apiVersionFromMap(map, context, validationException); + if (apiVersion == CohereApiVersion.V2) { + if (resolvedModelId == null) { + validationException.addValidationError(MODEL_REQUIRED_FOR_V2_API); + } + } + if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new CohereServiceSettings(uri, similarity, dims, maxInputTokens, modelId(oldModelId, modelId), rateLimitSettings); + return new CohereServiceSettings(uri, similarity, dims, maxInputTokens, resolvedModelId, rateLimitSettings, apiVersion); + } + + public static CohereApiVersion apiVersionFromMap( + Map map, + ConfigurationParseContext context, + ValidationException validationException + ) { + return switch (context) { + case REQUEST -> CohereApiVersion.V2; // new endpoints all use the V2 API. 
+ case PERSISTENT -> { + var apiVersion = ServiceUtils.extractOptionalEnum( + map, + API_VERSION, + ModelConfigurations.SERVICE_SETTINGS, + CohereApiVersion::fromString, + EnumSet.allOf(CohereApiVersion.class), + validationException + ); + + if (apiVersion == null) { + yield CohereApiVersion.V1; // If the API version is not persisted then it must be V1 + } else { + yield apiVersion; + } + } + }; } private static String modelId(@Nullable String model, @Nullable String modelId) { @@ -89,6 +138,7 @@ private static String modelId(@Nullable String model, @Nullable String modelId) private final Integer maxInputTokens; private final String modelId; private final RateLimitSettings rateLimitSettings; + private final CohereApiVersion apiVersion; public CohereServiceSettings( @Nullable URI uri, @@ -96,7 +146,8 @@ public CohereServiceSettings( @Nullable Integer dimensions, @Nullable Integer maxInputTokens, @Nullable String modelId, - @Nullable RateLimitSettings rateLimitSettings + @Nullable RateLimitSettings rateLimitSettings, + CohereApiVersion apiVersion ) { this.uri = uri; this.similarity = similarity; @@ -104,6 +155,7 @@ public CohereServiceSettings( this.maxInputTokens = maxInputTokens; this.modelId = modelId; this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + this.apiVersion = apiVersion; } public CohereServiceSettings( @@ -112,9 +164,10 @@ public CohereServiceSettings( @Nullable Integer dimensions, @Nullable Integer maxInputTokens, @Nullable String modelId, - @Nullable RateLimitSettings rateLimitSettings + @Nullable RateLimitSettings rateLimitSettings, + CohereApiVersion apiVersion ) { - this(createOptionalUri(url), similarity, dimensions, maxInputTokens, modelId, rateLimitSettings); + this(createOptionalUri(url), similarity, dimensions, maxInputTokens, modelId, rateLimitSettings, apiVersion); } public CohereServiceSettings(StreamInput in) throws IOException { @@ -129,11 +182,16 @@ public 
CohereServiceSettings(StreamInput in) throws IOException { } else { rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS; } + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_API_VERSION_8_19)) { + this.apiVersion = in.readEnum(CohereServiceSettings.CohereApiVersion.class); + } else { + this.apiVersion = CohereServiceSettings.CohereApiVersion.V1; + } } // should only be used for testing, public because it's accessed outside of the package - public CohereServiceSettings() { - this((URI) null, null, null, null, null, null); + public CohereServiceSettings(CohereApiVersion apiVersion) { + this((URI) null, null, null, null, null, null, apiVersion); } @Override @@ -141,6 +199,11 @@ public RateLimitSettings rateLimitSettings() { return rateLimitSettings; } + @Override + public CohereApiVersion apiVersion() { + return apiVersion; + } + public URI uri() { return uri; } @@ -172,15 +235,14 @@ public String getWriteableName() { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - toXContentFragment(builder, params); - builder.endObject(); return builder; } public XContentBuilder toXContentFragment(XContentBuilder builder, Params params) throws IOException { - return toXContentFragmentOfExposedFields(builder, params); + toXContentFragmentOfExposedFields(builder, params); + return builder.field(API_VERSION, apiVersion); // API version is persisted but not exposed to the user } @Override @@ -222,6 +284,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { rateLimitSettings.writeTo(out); } + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_API_VERSION_8_19)) { + out.writeEnum(apiVersion); + } } @Override @@ -234,11 +299,12 @@ public boolean equals(Object o) { && Objects.equals(dimensions, that.dimensions) && Objects.equals(maxInputTokens, that.maxInputTokens) && 
Objects.equals(modelId, that.modelId) - && Objects.equals(rateLimitSettings, that.rateLimitSettings); + && Objects.equals(rateLimitSettings, that.rateLimitSettings) + && apiVersion == that.apiVersion; } @Override public int hashCode() { - return Objects.hash(uri, similarity, dimensions, maxInputTokens, modelId, rateLimitSettings); + return Objects.hash(uri, similarity, dimensions, maxInputTokens, modelId, rateLimitSettings, apiVersion); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java index 83fbc5a8ad6e9..777ddc348bda6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreator.java @@ -7,20 +7,35 @@ package org.elasticsearch.xpack.inference.services.cohere.action; +import org.elasticsearch.inference.InputType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; +import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; +import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.ServiceComponents; -import 
org.elasticsearch.xpack.inference.services.cohere.CohereCompletionRequestManager; -import org.elasticsearch.xpack.inference.services.cohere.CohereEmbeddingsRequestManager; -import org.elasticsearch.xpack.inference.services.cohere.CohereRerankRequestManager; +import org.elasticsearch.xpack.inference.services.cohere.CohereResponseHandler; import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1CompletionRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1EmbeddingsRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.v1.CohereV1RerankRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.v2.CohereV2CompletionRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.v2.CohereV2EmbeddingsRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.v2.CohereV2RerankRequest; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; +import org.elasticsearch.xpack.inference.services.cohere.response.CohereCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.cohere.response.CohereEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.services.cohere.response.CohereRankedResponseEntity; import java.util.Map; import java.util.Objects; +import java.util.function.Function; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; @@ -28,12 +43,30 @@ * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the cohere model type. 
*/ public class CohereActionCreator implements CohereActionVisitor { + + private static final ResponseHandler EMBEDDINGS_HANDLER = new CohereResponseHandler( + "cohere text embedding", + CohereEmbeddingsResponseEntity::fromResponse, + false + ); + + private static final ResponseHandler RERANK_HANDLER = new CohereResponseHandler( + "cohere rerank", + (request, response) -> CohereRankedResponseEntity.fromResponse(response), + false + ); + + private static final ResponseHandler COMPLETION_HANDLER = new CohereResponseHandler( + "cohere completion", + CohereCompletionResponseEntity::fromResponse, + true + ); + private static final String COMPLETION_ERROR_PREFIX = "Cohere completion"; private final Sender sender; private final ServiceComponents serviceComponents; public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) { - // TODO Batching - accept a class that can handle batching this.sender = Objects.requireNonNull(sender); this.serviceComponents = Objects.requireNonNull(serviceComponents); } @@ -41,24 +74,80 @@ public CohereActionCreator(Sender sender, ServiceComponents serviceComponents) { @Override public ExecutableAction create(CohereEmbeddingsModel model, Map taskSettings) { var overriddenModel = CohereEmbeddingsModel.of(model, taskSettings); + + Function requestCreator = inferenceInputs -> { + var requestInputType = InputType.isSpecified(inferenceInputs.getInputType()) + ? 
inferenceInputs.getInputType() + : overriddenModel.getTaskSettings().getInputType(); + + return switch (overriddenModel.getServiceSettings().getCommonSettings().apiVersion()) { + case V1 -> new CohereV1EmbeddingsRequest(inferenceInputs.getStringInputs(), requestInputType, overriddenModel); + case V2 -> new CohereV2EmbeddingsRequest(inferenceInputs.getStringInputs(), requestInputType, overriddenModel); + }; + }; + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Cohere embeddings"); - // TODO - Batching pass the batching class on to the CohereEmbeddingsRequestManager - var requestCreator = CohereEmbeddingsRequestManager.of(overriddenModel, serviceComponents.threadPool()); - return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); + var requestManager = new GenericRequestManager<>( + serviceComponents.threadPool(), + model, + EMBEDDINGS_HANDLER, + requestCreator, + EmbeddingsInput.class + ); + return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage); } @Override public ExecutableAction create(CohereRerankModel model, Map taskSettings) { var overriddenModel = CohereRerankModel.of(model, taskSettings); - var requestCreator = CohereRerankRequestManager.of(overriddenModel, serviceComponents.threadPool()); + + Function requestCreator = inferenceInputs -> switch (overriddenModel.getServiceSettings() + .apiVersion()) { + case V1 -> new CohereV1RerankRequest( + inferenceInputs.getQuery(), + inferenceInputs.getChunks(), + inferenceInputs.getReturnDocuments(), + inferenceInputs.getTopN(), + overriddenModel + ); + case V2 -> new CohereV2RerankRequest( + inferenceInputs.getQuery(), + inferenceInputs.getChunks(), + inferenceInputs.getReturnDocuments(), + inferenceInputs.getTopN(), + overriddenModel + ); + }; + + var requestManager = new GenericRequestManager<>( + serviceComponents.threadPool(), + overriddenModel, + RERANK_HANDLER, + requestCreator, + QueryAndDocsInputs.class + ); + var 
failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Cohere rerank"); - return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); + return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage); } @Override public ExecutableAction create(CohereCompletionModel model, Map taskSettings) { // no overridden model as task settings are always empty for cohere completion model - var requestManager = CohereCompletionRequestManager.of(model, serviceComponents.threadPool()); + + Function requestCreator = completionInput -> switch (model.getServiceSettings().apiVersion()) { + case V1 -> new CohereV1CompletionRequest(completionInput.getInputs(), model, completionInput.stream()); + case V2 -> new CohereV2CompletionRequest(completionInput.getInputs(), model, completionInput.stream()); + }; + + var requestManager = new GenericRequestManager<>( + serviceComponents.threadPool(), + model, + COMPLETION_HANDLER, + requestCreator, + ChatCompletionInput.class + ); + var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(COMPLETION_ERROR_PREFIX); return new SingleInputSenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage, COMPLETION_ERROR_PREFIX); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java index a31a6ae290fea..120964393fd6a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModel.java @@ -16,27 +16,22 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import 
org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.cohere.CohereModel; +import org.elasticsearch.xpack.inference.services.cohere.CohereService; import org.elasticsearch.xpack.inference.services.cohere.action.CohereActionVisitor; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; -import java.net.URI; import java.util.Map; public class CohereCompletionModel extends CohereModel { public CohereCompletionModel( String modelId, - TaskType taskType, - String service, Map serviceSettings, - Map taskSettings, @Nullable Map secrets, ConfigurationParseContext context ) { this( modelId, - taskType, - service, CohereCompletionServiceSettings.fromMap(serviceSettings, context), EmptyTaskSettings.INSTANCE, DefaultSecretSettings.fromMap(secrets) @@ -46,14 +41,12 @@ public CohereCompletionModel( // should only be used for testing CohereCompletionModel( String modelId, - TaskType taskType, - String service, CohereCompletionServiceSettings serviceSettings, TaskSettings taskSettings, @Nullable DefaultSecretSettings secretSettings ) { super( - new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), + new ModelConfigurations(modelId, TaskType.COMPLETION, CohereService.NAME, serviceSettings, taskSettings), new ModelSecrets(secretSettings), secretSettings, serviceSettings @@ -79,9 +72,4 @@ public DefaultSecretSettings getSecretSettings() { public ExecutableAction accept(CohereActionVisitor visitor, Map taskSettings) { return visitor.create(this, taskSettings); } - - @Override - public URI uri() { - return getServiceSettings().uri(); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettings.java index be241f3aaa7fc..9e46c91711924 100644 
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettings.java @@ -19,6 +19,7 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.cohere.CohereRateLimitServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.CohereService; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -32,6 +33,9 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings.API_VERSION; +import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings.MODEL_REQUIRED_FOR_V2_API; +import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings.apiVersionFromMap; public class CohereCompletionServiceSettings extends FilteredXContentObject implements ServiceSettings, CohereRateLimitServiceSettings { @@ -54,34 +58,55 @@ public static CohereCompletionServiceSettings fromMap(Map map, C context ); String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + var apiVersion = apiVersionFromMap(map, context, validationException); + if (apiVersion == CohereServiceSettings.CohereApiVersion.V2) { + if (modelId == null) { + validationException.addValidationError(MODEL_REQUIRED_FOR_V2_API); + } + } if 
(validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new CohereCompletionServiceSettings(uri, modelId, rateLimitSettings); + return new CohereCompletionServiceSettings(uri, modelId, rateLimitSettings, apiVersion); } private final URI uri; - private final String modelId; - private final RateLimitSettings rateLimitSettings; - - public CohereCompletionServiceSettings(@Nullable URI uri, @Nullable String modelId, @Nullable RateLimitSettings rateLimitSettings) { + private final CohereServiceSettings.CohereApiVersion apiVersion; + + public CohereCompletionServiceSettings( + @Nullable URI uri, + @Nullable String modelId, + @Nullable RateLimitSettings rateLimitSettings, + CohereServiceSettings.CohereApiVersion apiVersion + ) { this.uri = uri; this.modelId = modelId; this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + this.apiVersion = apiVersion; } - public CohereCompletionServiceSettings(@Nullable String url, @Nullable String modelId, @Nullable RateLimitSettings rateLimitSettings) { - this(createOptionalUri(url), modelId, rateLimitSettings); + public CohereCompletionServiceSettings( + @Nullable String url, + @Nullable String modelId, + @Nullable RateLimitSettings rateLimitSettings, + CohereServiceSettings.CohereApiVersion apiVersion + ) { + this(createOptionalUri(url), modelId, rateLimitSettings, apiVersion); } public CohereCompletionServiceSettings(StreamInput in) throws IOException { uri = createOptionalUri(in.readOptionalString()); modelId = in.readOptionalString(); rateLimitSettings = new RateLimitSettings(in); + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_API_VERSION_8_19)) { + this.apiVersion = in.readEnum(CohereServiceSettings.CohereApiVersion.class); + } else { + this.apiVersion = CohereServiceSettings.CohereApiVersion.V1; + } } @Override @@ -89,6 +114,11 @@ public RateLimitSettings rateLimitSettings() { return rateLimitSettings; } + 
@Override + public CohereServiceSettings.CohereApiVersion apiVersion() { + return apiVersion; + } + public URI uri() { return uri; } @@ -102,6 +132,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); toXContentFragmentOfExposedFields(builder, params); + builder.field(API_VERSION, apiVersion); // API version is persisted but not exposed to the user builder.endObject(); return builder; @@ -123,6 +154,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalString(uriToWrite); out.writeOptionalString(modelId); rateLimitSettings.writeTo(out); + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_API_VERSION_8_19)) { + out.writeEnum(apiVersion); + } } @Override @@ -146,11 +180,12 @@ public boolean equals(Object object) { CohereCompletionServiceSettings that = (CohereCompletionServiceSettings) object; return Objects.equals(uri, that.uri) && Objects.equals(modelId, that.modelId) - && Objects.equals(rateLimitSettings, that.rateLimitSettings); + && Objects.equals(rateLimitSettings, that.rateLimitSettings) + && apiVersion == that.apiVersion; } @Override public int hashCode() { - return Objects.hash(uri, modelId, rateLimitSettings); + return Objects.hash(uri, modelId, rateLimitSettings, apiVersion); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java index 2edd365e66311..525674cc9b2ef 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModel.java @@ -15,6 +15,7 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import 
org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.cohere.CohereModel; +import org.elasticsearch.xpack.inference.services.cohere.CohereService; import org.elasticsearch.xpack.inference.services.cohere.action.CohereActionVisitor; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -29,8 +30,6 @@ public static CohereEmbeddingsModel of(CohereEmbeddingsModel model, Map serviceSettings, Map taskSettings, ChunkingSettings chunkingSettings, @@ -39,8 +38,6 @@ public CohereEmbeddingsModel( ) { this( inferenceId, - taskType, - service, CohereEmbeddingsServiceSettings.fromMap(serviceSettings, context), CohereEmbeddingsTaskSettings.fromMap(taskSettings), chunkingSettings, @@ -51,15 +48,13 @@ public CohereEmbeddingsModel( // should only be used for testing CohereEmbeddingsModel( String modelId, - TaskType taskType, - String service, CohereEmbeddingsServiceSettings serviceSettings, CohereEmbeddingsTaskSettings taskSettings, ChunkingSettings chunkingSettings, @Nullable DefaultSecretSettings secretSettings ) { super( - new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings, chunkingSettings), + new ModelConfigurations(modelId, TaskType.TEXT_EMBEDDING, CohereService.NAME, serviceSettings, taskSettings, chunkingSettings), new ModelSecrets(secretSettings), secretSettings, serviceSettings.getCommonSettings() @@ -95,7 +90,7 @@ public ExecutableAction accept(CohereActionVisitor visitor, Map } @Override - public URI uri() { + public URI baseUri() { return getServiceSettings().getCommonSettings().uri(); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java index b25b9fc8fd351..11cd6c2bcd75d 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettings.java @@ -9,6 +9,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -183,6 +184,11 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalEnum(CohereEmbeddingType.translateToVersion(embeddingType, out.getTransportVersion())); } + @Override + public String toString() { + return Strings.toString(this); + } + @Override public boolean equals(Object o) { if (this == o) return true; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequest.java deleted file mode 100644 index 7ce218c3a8fe8..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequest.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.services.cohere.request; - -import org.apache.http.client.methods.HttpPost; -import org.apache.http.client.utils.URIBuilder; -import org.apache.http.entity.ByteArrayEntity; -import org.elasticsearch.common.Strings; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.xpack.inference.external.request.HttpRequest; -import org.elasticsearch.xpack.inference.external.request.Request; -import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; - -import java.net.URI; -import java.net.URISyntaxException; -import java.nio.charset.StandardCharsets; -import java.util.List; -import java.util.Objects; - -public class CohereEmbeddingsRequest extends CohereRequest { - - private final CohereAccount account; - private final List input; - private final InputType inputType; - private final CohereEmbeddingsTaskSettings taskSettings; - private final String model; - private final CohereEmbeddingType embeddingType; - private final String inferenceEntityId; - - public CohereEmbeddingsRequest(List input, InputType inputType, CohereEmbeddingsModel embeddingsModel) { - Objects.requireNonNull(embeddingsModel); - - account = CohereAccount.of(embeddingsModel, CohereEmbeddingsRequest::buildDefaultUri); - this.input = Objects.requireNonNull(input); - this.inputType = inputType; - taskSettings = embeddingsModel.getTaskSettings(); - model = embeddingsModel.getServiceSettings().getCommonSettings().modelId(); - embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType(); - inferenceEntityId = embeddingsModel.getInferenceEntityId(); - } - - @Override - public HttpRequest createHttpRequest() { - HttpPost httpPost = new 
HttpPost(account.uri()); - - ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new CohereEmbeddingsRequestEntity(input, inputType, taskSettings, model, embeddingType)) - .getBytes(StandardCharsets.UTF_8) - ); - httpPost.setEntity(byteEntity); - - decorateWithAuthHeader(httpPost, account); - - return new HttpRequest(httpPost, getInferenceEntityId()); - } - - @Override - public String getInferenceEntityId() { - return inferenceEntityId; - } - - @Override - public URI getURI() { - return account.uri(); - } - - @Override - public Request truncate() { - return this; - } - - @Override - public boolean[] getTruncationInfo() { - return null; - } - - public static URI buildDefaultUri() throws URISyntaxException { - return new URIBuilder().setScheme("https") - .setHost(CohereUtils.HOST) - .setPathSegments(CohereUtils.VERSION_1, CohereUtils.EMBEDDINGS_PATH) - .build(); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequestEntity.java deleted file mode 100644 index e4de77cd56edd..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequestEntity.java +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.services.cohere.request; - -import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields; -import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; - -import java.io.IOException; -import java.util.List; -import java.util.Objects; - -import static org.elasticsearch.inference.InputType.invalidInputTypeMessage; - -public record CohereEmbeddingsRequestEntity( - List input, - InputType inputType, - CohereEmbeddingsTaskSettings taskSettings, - @Nullable String model, - @Nullable CohereEmbeddingType embeddingType -) implements ToXContentObject { - - private static final String SEARCH_DOCUMENT = "search_document"; - private static final String SEARCH_QUERY = "search_query"; - private static final String CLUSTERING = "clustering"; - private static final String CLASSIFICATION = "classification"; - private static final String TEXTS_FIELD = "texts"; - public static final String INPUT_TYPE_FIELD = "input_type"; - static final String EMBEDDING_TYPES_FIELD = "embedding_types"; - - public CohereEmbeddingsRequestEntity { - Objects.requireNonNull(input); - Objects.requireNonNull(taskSettings); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(TEXTS_FIELD, input); - if (model != null) { - builder.field(CohereServiceSettings.OLD_MODEL_ID_FIELD, model); - } - - // prefer the root level inputType over task settings input type - if (InputType.isSpecified(inputType)) { - builder.field(INPUT_TYPE_FIELD, 
convertToString(inputType)); - } else if (InputType.isSpecified(taskSettings.getInputType())) { - builder.field(INPUT_TYPE_FIELD, convertToString(taskSettings.getInputType())); - } - - if (embeddingType != null) { - builder.field(EMBEDDING_TYPES_FIELD, List.of(embeddingType.toRequestString())); - } - - if (taskSettings.getTruncation() != null) { - builder.field(CohereServiceFields.TRUNCATE, taskSettings.getTruncation()); - } - - builder.endObject(); - return builder; - } - - // default for testing - public static String convertToString(InputType inputType) { - return switch (inputType) { - case INGEST, INTERNAL_INGEST -> SEARCH_DOCUMENT; - case SEARCH, INTERNAL_SEARCH -> SEARCH_QUERY; - case CLASSIFICATION -> CLASSIFICATION; - case CLUSTERING -> CLUSTERING; - default -> { - assert false : invalidInputTypeMessage(inputType); - yield null; - } - }; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRequest.java index fda1661d02472..ae351976545a9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRequest.java @@ -9,13 +9,28 @@ import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.utils.URIBuilder; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; import 
org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; +import org.elasticsearch.xpack.inference.services.cohere.CohereService; + +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Objects; import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; -public abstract class CohereRequest implements Request { +public abstract class CohereRequest implements Request, ToXContentObject { public static void decorateWithAuthHeader(HttpPost request, CohereAccount account) { request.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); @@ -23,4 +38,76 @@ public static void decorateWithAuthHeader(HttpPost request, CohereAccount accoun request.setHeader(CohereUtils.createRequestSourceHeader()); } + protected final CohereAccount account; + private final String inferenceEntityId; + private final String modelId; + private final boolean stream; + + protected CohereRequest(CohereAccount account, String inferenceEntityId, @Nullable String modelId, boolean stream) { + this.account = account; + this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId); + this.modelId = modelId; // model is optional in the v1 api + this.stream = stream; + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(getURI()); + + ByteArrayEntity byteEntity = new ByteArrayEntity(Strings.toString(this).getBytes(StandardCharsets.UTF_8)); + httpPost.setEntity(byteEntity); + + decorateWithAuthHeader(httpPost, account); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public String getInferenceEntityId() { + return inferenceEntityId; + } + + @Override + public boolean isStreaming() { + return stream; + } + + @Override + public URI getURI() { + return buildUri(account.baseUri()); + } + + /** + * Returns the URL 
path segments. + * @return List of segments that make up the path of the request. + */ + protected abstract List pathSegments(); + + private URI buildUri(URI baseUri) { + try { + return new URIBuilder(baseUri).setPathSegments(pathSegments()).build(); + } catch (URISyntaxException e) { + throw new ElasticsearchStatusException( + Strings.format("Failed to construct %s URL", CohereService.NAME), + RestStatus.BAD_REQUEST, + e + ); + } + } + + public String getModelId() { + return modelId; + } + + @Override + public Request truncate() { + // no truncation + return this; + } + + @Override + public boolean[] getTruncationInfo() { + // no truncation + return null; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRerankRequest.java deleted file mode 100644 index ed2a7ea97925e..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRerankRequest.java +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.services.cohere.request; - -import org.apache.http.client.methods.HttpPost; -import org.apache.http.client.utils.URIBuilder; -import org.apache.http.entity.ByteArrayEntity; -import org.elasticsearch.common.Strings; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.xpack.inference.external.request.HttpRequest; -import org.elasticsearch.xpack.inference.external.request.Request; -import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; -import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; -import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; - -import java.net.URI; -import java.net.URISyntaxException; -import java.nio.charset.StandardCharsets; -import java.util.List; -import java.util.Objects; - -public class CohereRerankRequest extends CohereRequest { - - private final CohereAccount account; - private final String query; - private final List input; - private final Boolean returnDocuments; - private final Integer topN; - private final CohereRerankTaskSettings taskSettings; - private final String model; - private final String inferenceEntityId; - - public CohereRerankRequest( - String query, - List input, - @Nullable Boolean returnDocuments, - @Nullable Integer topN, - CohereRerankModel model - ) { - Objects.requireNonNull(model); - - this.account = CohereAccount.of(model, CohereRerankRequest::buildDefaultUri); - this.input = Objects.requireNonNull(input); - this.query = Objects.requireNonNull(query); - this.returnDocuments = returnDocuments; - this.topN = topN; - taskSettings = model.getTaskSettings(); - this.model = model.getServiceSettings().modelId(); - inferenceEntityId = model.getInferenceEntityId(); - } - - @Override - public HttpRequest createHttpRequest() { - HttpPost httpPost = new HttpPost(account.uri()); - - ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new 
CohereRerankRequestEntity(query, input, returnDocuments, topN, taskSettings, model)) - .getBytes(StandardCharsets.UTF_8) - ); - httpPost.setEntity(byteEntity); - - decorateWithAuthHeader(httpPost, account); - - return new HttpRequest(httpPost, getInferenceEntityId()); - } - - @Override - public String getInferenceEntityId() { - return inferenceEntityId; - } - - @Override - public URI getURI() { - return account.uri(); - } - - @Override - public Request truncate() { - return this; // TODO? - } - - @Override - public boolean[] getTruncationInfo() { - return null; - } - - public static URI buildDefaultUri() throws URISyntaxException { - return new URIBuilder().setScheme("https") - .setHost(CohereUtils.HOST) - .setPathSegments(CohereUtils.VERSION_1, CohereUtils.RERANK_PATH) - .build(); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java index 6eef2c67f5af0..f512444c6d6a4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtils.java @@ -9,19 +9,49 @@ import org.apache.http.Header; import org.apache.http.message.BasicHeader; +import org.elasticsearch.inference.InputType; + +import static org.elasticsearch.inference.InputType.invalidInputTypeMessage; public class CohereUtils { public static final String HOST = "api.cohere.ai"; public static final String VERSION_1 = "v1"; + public static final String VERSION_2 = "v2"; public static final String CHAT_PATH = "chat"; public static final String EMBEDDINGS_PATH = "embed"; public static final String RERANK_PATH = "rerank"; public static final String REQUEST_SOURCE_HEADER = "Request-Source"; public static final String ELASTIC_REQUEST_SOURCE = 
"unspecified:elasticsearch"; + public static final String CLUSTERING = "clustering"; + public static final String CLASSIFICATION = "classification"; + public static final String DOCUMENTS_FIELD = "documents"; + public static final String EMBEDDING_TYPES_FIELD = "embedding_types"; + public static final String INPUT_TYPE_FIELD = "input_type"; + public static final String MESSAGE_FIELD = "message"; + public static final String MODEL_FIELD = "model"; + public static final String QUERY_FIELD = "query"; + public static final String SEARCH_DOCUMENT = "search_document"; + public static final String SEARCH_QUERY = "search_query"; + public static final String TEXTS_FIELD = "texts"; + public static final String STREAM_FIELD = "stream"; + public static Header createRequestSourceHeader() { return new BasicHeader(REQUEST_SOURCE_HEADER, ELASTIC_REQUEST_SOURCE); } + public static String inputTypeToString(InputType inputType) { + return switch (inputType) { + case INGEST, INTERNAL_INGEST -> SEARCH_DOCUMENT; + case SEARCH, INTERNAL_SEARCH -> SEARCH_QUERY; + case CLASSIFICATION -> CLASSIFICATION; + case CLUSTERING -> CLUSTERING; + default -> { + assert false : invalidInputTypeMessage(inputType); + yield null; + } + }; + } + private CohereUtils() {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/completion/CohereCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/completion/CohereCompletionRequest.java deleted file mode 100644 index b477295afbc09..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/completion/CohereCompletionRequest.java +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. 
Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere.request.completion; - -import org.apache.http.client.methods.HttpPost; -import org.apache.http.client.utils.URIBuilder; -import org.apache.http.entity.ByteArrayEntity; -import org.elasticsearch.common.Strings; -import org.elasticsearch.xpack.inference.external.request.HttpRequest; -import org.elasticsearch.xpack.inference.external.request.Request; -import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; -import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel; -import org.elasticsearch.xpack.inference.services.cohere.request.CohereRequest; -import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; - -import java.net.URI; -import java.net.URISyntaxException; -import java.nio.charset.StandardCharsets; -import java.util.List; -import java.util.Objects; - -public class CohereCompletionRequest extends CohereRequest { - private final CohereAccount account; - private final List input; - private final String modelId; - private final String inferenceEntityId; - private final boolean stream; - - public CohereCompletionRequest(List input, CohereCompletionModel model, boolean stream) { - Objects.requireNonNull(model); - - this.account = CohereAccount.of(model, CohereCompletionRequest::buildDefaultUri); - this.input = Objects.requireNonNull(input); - this.modelId = model.getServiceSettings().modelId(); - this.inferenceEntityId = model.getInferenceEntityId(); - this.stream = stream; - } - - @Override - public HttpRequest createHttpRequest() { - HttpPost httpPost = new HttpPost(account.uri()); - - ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString(new CohereCompletionRequestEntity(input, modelId, isStreaming())).getBytes(StandardCharsets.UTF_8) - ); - httpPost.setEntity(byteEntity); - - 
decorateWithAuthHeader(httpPost, account); - - return new HttpRequest(httpPost, getInferenceEntityId()); - } - - @Override - public String getInferenceEntityId() { - return inferenceEntityId; - } - - @Override - public boolean isStreaming() { - return stream; - } - - @Override - public URI getURI() { - return account.uri(); - } - - @Override - public Request truncate() { - // no truncation - return this; - } - - @Override - public boolean[] getTruncationInfo() { - // no truncation - return null; - } - - public static URI buildDefaultUri() throws URISyntaxException { - return new URIBuilder().setScheme("https") - .setHost(CohereUtils.HOST) - .setPathSegments(CohereUtils.VERSION_1, CohereUtils.CHAT_PATH) - .build(); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/completion/CohereCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/completion/CohereCompletionRequestEntity.java deleted file mode 100644 index 7ab8d6753a0c4..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/completion/CohereCompletionRequestEntity.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.services.cohere.request.completion; - -import org.elasticsearch.core.Nullable; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContentBuilder; - -import java.io.IOException; -import java.util.List; -import java.util.Objects; - -public record CohereCompletionRequestEntity(List input, @Nullable String model, boolean stream) implements ToXContentObject { - - private static final String MESSAGE_FIELD = "message"; - private static final String MODEL = "model"; - private static final String STREAM = "stream"; - - public CohereCompletionRequestEntity { - Objects.requireNonNull(input); - Objects.requireNonNull(input.get(0)); - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - - // we only allow one input for completion, so always get the first one - builder.field(MESSAGE_FIELD, input.get(0)); - if (model != null) { - builder.field(MODEL, model); - } - - if (stream) { - builder.field(STREAM, true); - } - - builder.endObject(); - - return builder; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java new file mode 100644 index 0000000000000..80d0fec9eb7c3 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequest.java @@ -0,0 +1,48 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v1; + +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; +import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class CohereV1CompletionRequest extends CohereRequest { + private final List input; + + public CohereV1CompletionRequest(List input, CohereCompletionModel model, boolean stream) { + super(CohereAccount.of(model), model.getInferenceEntityId(), model.getServiceSettings().modelId(), stream); + + this.input = Objects.requireNonNull(input); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + // we only allow one input for completion, so always get the first one + builder.field(CohereUtils.MESSAGE_FIELD, input.get(0)); + if (getModelId() != null) { + builder.field(CohereUtils.MODEL_FIELD, getModelId()); + } + if (isStreaming()) { + builder.field(CohereUtils.STREAM_FIELD, true); + } + builder.endObject(); + return builder; + } + + @Override + protected List pathSegments() { + return List.of(CohereUtils.VERSION_1, CohereUtils.CHAT_PATH); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequest.java new file mode 100644 index 0000000000000..7c2d0b1fbf3f5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequest.java @@ -0,0 +1,77 @@ +/* + * 
Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v1; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class CohereV1EmbeddingsRequest extends CohereRequest { + + private final List input; + private final InputType inputType; + private final CohereEmbeddingsTaskSettings taskSettings; + private final CohereEmbeddingType embeddingType; + + public CohereV1EmbeddingsRequest(List input, InputType inputType, CohereEmbeddingsModel embeddingsModel) { + super( + CohereAccount.of(embeddingsModel), + embeddingsModel.getInferenceEntityId(), + embeddingsModel.getServiceSettings().getCommonSettings().modelId(), + false + ); + + this.input = Objects.requireNonNull(input); + this.inputType = inputType; + taskSettings = embeddingsModel.getTaskSettings(); + embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType(); + } + + @Override + protected List pathSegments() { + return 
List.of(CohereUtils.VERSION_1, CohereUtils.EMBEDDINGS_PATH); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CohereUtils.TEXTS_FIELD, input); + if (getModelId() != null) { + builder.field(CohereServiceSettings.OLD_MODEL_ID_FIELD, getModelId()); + } + + // prefer the root level inputType over task settings input type + if (InputType.isSpecified(inputType)) { + builder.field(CohereUtils.INPUT_TYPE_FIELD, CohereUtils.inputTypeToString(inputType)); + } else if (InputType.isSpecified(taskSettings.getInputType())) { + builder.field(CohereUtils.INPUT_TYPE_FIELD, CohereUtils.inputTypeToString(taskSettings.getInputType())); + } + + if (embeddingType != null) { + builder.field(CohereUtils.EMBEDDING_TYPES_FIELD, List.of(embeddingType.toRequestString())); + } + + if (taskSettings.getTruncation() != null) { + builder.field(CohereServiceFields.TRUNCATE, taskSettings.getTruncation()); + } + + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRerankRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequest.java similarity index 58% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRerankRequestEntity.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequest.java index ddddb9fa314e5..70b34368eda61 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRerankRequestEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequest.java @@ -5,54 +5,56 @@ * 2.0. 
*/ -package org.elasticsearch.xpack.inference.services.cohere.request; +package org.elasticsearch.xpack.inference.services.cohere.request.v1; import org.elasticsearch.core.Nullable; -import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; import java.io.IOException; import java.util.List; import java.util.Objects; -public record CohereRerankRequestEntity( - String model, - String query, - List documents, - @Nullable Boolean returnDocuments, - @Nullable Integer topN, - CohereRerankTaskSettings taskSettings -) implements ToXContentObject { +public class CohereV1RerankRequest extends CohereRequest { - private static final String DOCUMENTS_FIELD = "documents"; - private static final String QUERY_FIELD = "query"; - private static final String MODEL_FIELD = "model"; + private final String query; + private final List input; + private final Boolean returnDocuments; + private final Integer topN; + private final CohereRerankTaskSettings taskSettings; - public CohereRerankRequestEntity { - Objects.requireNonNull(query); - Objects.requireNonNull(documents); - Objects.requireNonNull(taskSettings); - } - - public CohereRerankRequestEntity( + public CohereV1RerankRequest( String query, List input, @Nullable Boolean returnDocuments, @Nullable Integer topN, - CohereRerankTaskSettings taskSettings, - String model + CohereRerankModel model ) { - this(model, query, input, returnDocuments, topN, taskSettings); + super(CohereAccount.of(model), model.getInferenceEntityId(), model.getServiceSettings().modelId(), false); + + this.input = 
Objects.requireNonNull(input); + this.query = Objects.requireNonNull(query); + this.returnDocuments = returnDocuments; + this.topN = topN; + taskSettings = model.getTaskSettings(); + } + + @Override + protected List pathSegments() { + return List.of(CohereUtils.VERSION_1, CohereUtils.RERANK_PATH); } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field(MODEL_FIELD, model); - builder.field(QUERY_FIELD, query); - builder.field(DOCUMENTS_FIELD, documents); + builder.field(CohereUtils.MODEL_FIELD, getModelId()); + builder.field(CohereUtils.QUERY_FIELD, query); + builder.field(CohereUtils.DOCUMENTS_FIELD, input); // prefer the root level return_documents over task settings if (returnDocuments != null) { @@ -75,5 +77,4 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.endObject(); return builder; } - } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequest.java new file mode 100644 index 0000000000000..6187f7b41862f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequest.java @@ -0,0 +1,44 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v2; + +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; +import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModel; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public class CohereV2CompletionRequest extends CohereRequest { + private final List input; + + public CohereV2CompletionRequest(List input, CohereCompletionModel model, boolean stream) { + super(CohereAccount.of(model), model.getInferenceEntityId(), Objects.requireNonNull(model.getServiceSettings().modelId()), stream); + + this.input = Objects.requireNonNull(input); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + // we only allow one input for completion, so always get the first one + builder.field(CohereUtils.MESSAGE_FIELD, input.get(0)); + builder.field(CohereUtils.MODEL_FIELD, getModelId()); + builder.field(CohereUtils.STREAM_FIELD, isStreaming()); + builder.endObject(); + return builder; + } + + @Override + protected List pathSegments() { + return List.of(CohereUtils.VERSION_2, CohereUtils.CHAT_PATH); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java new file mode 100644 index 0000000000000..6fb8eb5bec7b8 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequest.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v2; + +import org.elasticsearch.inference.InputType; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +public class CohereV2EmbeddingsRequest extends CohereRequest { + + private final List input; + private final InputType inputType; + private final CohereEmbeddingsTaskSettings taskSettings; + private final CohereEmbeddingType embeddingType; + + public CohereV2EmbeddingsRequest(List input, InputType inputType, CohereEmbeddingsModel embeddingsModel) { + super( + CohereAccount.of(embeddingsModel), + embeddingsModel.getInferenceEntityId(), + Objects.requireNonNull(embeddingsModel.getServiceSettings().getCommonSettings().modelId()), + false + ); + + this.input = Objects.requireNonNull(input); + this.inputType = Optional.ofNullable(inputType).orElse(InputType.SEARCH); // inputType is required in v2 + taskSettings = embeddingsModel.getTaskSettings(); + embeddingType = embeddingsModel.getServiceSettings().getEmbeddingType(); + } + + @Override + protected List pathSegments() { + return 
List.of(CohereUtils.VERSION_2, CohereUtils.EMBEDDINGS_PATH); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(CohereUtils.TEXTS_FIELD, input); + builder.field(CohereUtils.MODEL_FIELD, getModelId()); + // prefer the root level inputType over task settings input type + if (InputType.isSpecified(inputType)) { + builder.field(CohereUtils.INPUT_TYPE_FIELD, CohereUtils.inputTypeToString(inputType)); + } else if (InputType.isSpecified(taskSettings.getInputType())) { + builder.field(CohereUtils.INPUT_TYPE_FIELD, CohereUtils.inputTypeToString(taskSettings.getInputType())); + } + builder.field(CohereUtils.EMBEDDING_TYPES_FIELD, List.of(embeddingType.toRequestString())); + if (taskSettings.getTruncation() != null) { + builder.field(CohereServiceFields.TRUNCATE, taskSettings.getTruncation()); + } + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequest.java new file mode 100644 index 0000000000000..941e191bc1447 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequest.java @@ -0,0 +1,84 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v2; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.cohere.CohereAccount; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereRequest; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils.DOCUMENTS_FIELD; +import static org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils.MODEL_FIELD; +import static org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils.QUERY_FIELD; + +public class CohereV2RerankRequest extends CohereRequest { + + private final String query; + private final List input; + private final Boolean returnDocuments; + private final Integer topN; + private final CohereRerankTaskSettings taskSettings; + + public CohereV2RerankRequest( + String query, + List input, + @Nullable Boolean returnDocuments, + @Nullable Integer topN, + CohereRerankModel model + ) { + super(CohereAccount.of(model), model.getInferenceEntityId(), Objects.requireNonNull(model.getServiceSettings().modelId()), false); + + this.input = Objects.requireNonNull(input); + this.query = Objects.requireNonNull(query); + this.returnDocuments = returnDocuments; + this.topN = topN; + taskSettings = model.getTaskSettings(); + } + + @Override + protected List pathSegments() { + return List.of(CohereUtils.VERSION_2, CohereUtils.RERANK_PATH); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.field(MODEL_FIELD, 
getModelId()); + builder.field(QUERY_FIELD, query); + builder.field(DOCUMENTS_FIELD, input); + + // prefer the root level return_documents over task settings + if (returnDocuments != null) { + builder.field(CohereRerankTaskSettings.RETURN_DOCUMENTS, returnDocuments); + } else if (taskSettings.getDoesReturnDocuments() != null) { + builder.field(CohereRerankTaskSettings.RETURN_DOCUMENTS, taskSettings.getDoesReturnDocuments()); + } + + // prefer the root level top_n over task settings + if (topN != null) { + builder.field(CohereRerankTaskSettings.TOP_N_DOCS_ONLY, topN); + } else if (taskSettings.getTopNDocumentsOnly() != null) { + builder.field(CohereRerankTaskSettings.TOP_N_DOCS_ONLY, taskSettings.getTopNDocumentsOnly()); + } + + if (taskSettings.getMaxChunksPerDoc() != null) { + builder.field(CohereRerankTaskSettings.MAX_CHUNKS_PER_DOC, taskSettings.getMaxChunksPerDoc()); + } + + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankModel.java index ca853a2d28909..2244afc135582 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankModel.java @@ -14,10 +14,10 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.cohere.CohereModel; +import org.elasticsearch.xpack.inference.services.cohere.CohereService; import org.elasticsearch.xpack.inference.services.cohere.action.CohereActionVisitor; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; -import java.net.URI; import 
java.util.Map; public class CohereRerankModel extends CohereModel { @@ -28,8 +28,6 @@ public static CohereRerankModel of(CohereRerankModel model, Map public CohereRerankModel( String modelId, - TaskType taskType, - String service, Map serviceSettings, Map taskSettings, @Nullable Map secrets, @@ -37,25 +35,20 @@ public CohereRerankModel( ) { this( modelId, - taskType, - service, CohereRerankServiceSettings.fromMap(serviceSettings, context), CohereRerankTaskSettings.fromMap(taskSettings), DefaultSecretSettings.fromMap(secrets) ); } - // should only be used for testing - CohereRerankModel( + public CohereRerankModel( String modelId, - TaskType taskType, - String service, CohereRerankServiceSettings serviceSettings, CohereRerankTaskSettings taskSettings, @Nullable DefaultSecretSettings secretSettings ) { super( - new ModelConfigurations(modelId, taskType, service, serviceSettings, taskSettings), + new ModelConfigurations(modelId, TaskType.RERANK, CohereService.NAME, serviceSettings, taskSettings), new ModelSecrets(secretSettings), secretSettings, serviceSettings @@ -95,9 +88,4 @@ public DefaultSecretSettings getSecretSettings() { public ExecutableAction accept(CohereActionVisitor visitor, Map taskSettings) { return visitor.create(this, taskSettings); } - - @Override - public URI uri() { - return getServiceSettings().uri(); - } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java index 78178466f9f3a..abc2f194be028 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettings.java @@ -7,8 +7,6 @@ package 
org.elasticsearch.xpack.inference.services.cohere.rerank; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.common.ValidationException; @@ -22,6 +20,7 @@ import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.cohere.CohereRateLimitServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.CohereService; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -39,13 +38,14 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; +import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings.API_VERSION; import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS; +import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings.MODEL_REQUIRED_FOR_V2_API; +import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings.apiVersionFromMap; public class CohereRerankServiceSettings extends FilteredXContentObject implements ServiceSettings, CohereRateLimitServiceSettings { public static final String NAME = "cohere_rerank_service_settings"; - private static final Logger logger = LogManager.getLogger(CohereRerankServiceSettings.class); - public static CohereRerankServiceSettings fromMap(Map map, ConfigurationParseContext context) { ValidationException validationException = new ValidationException(); @@ -66,27 +66,44 @@ public 
static CohereRerankServiceSettings fromMap(Map map, Confi context ); + var apiVersion = apiVersionFromMap(map, context, validationException); + if (apiVersion == CohereServiceSettings.CohereApiVersion.V2) { + if (modelId == null) { + validationException.addValidationError(MODEL_REQUIRED_FOR_V2_API); + } + } + if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new CohereRerankServiceSettings(uri, modelId, rateLimitSettings); + return new CohereRerankServiceSettings(uri, modelId, rateLimitSettings, apiVersion); } private final URI uri; - private final String modelId; - private final RateLimitSettings rateLimitSettings; - - public CohereRerankServiceSettings(@Nullable URI uri, @Nullable String modelId, @Nullable RateLimitSettings rateLimitSettings) { + private final CohereServiceSettings.CohereApiVersion apiVersion; + + public CohereRerankServiceSettings( + @Nullable URI uri, + @Nullable String modelId, + @Nullable RateLimitSettings rateLimitSettings, + CohereServiceSettings.CohereApiVersion apiVersion + ) { this.uri = uri; this.modelId = modelId; this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + this.apiVersion = apiVersion; } - public CohereRerankServiceSettings(@Nullable String url, @Nullable String modelId, @Nullable RateLimitSettings rateLimitSettings) { - this(createOptionalUri(url), modelId, rateLimitSettings); + public CohereRerankServiceSettings( + @Nullable String url, + @Nullable String modelId, + @Nullable RateLimitSettings rateLimitSettings, + CohereServiceSettings.CohereApiVersion apiVersion + ) { + this(createOptionalUri(url), modelId, rateLimitSettings, apiVersion); } public CohereRerankServiceSettings(StreamInput in) throws IOException { @@ -106,8 +123,15 @@ public CohereRerankServiceSettings(StreamInput in) throws IOException { } else { this.rateLimitSettings = DEFAULT_RATE_LIMIT_SETTINGS; } + + if 
(in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_API_VERSION_8_19)) { + this.apiVersion = in.readEnum(CohereServiceSettings.CohereApiVersion.class); + } else { + this.apiVersion = CohereServiceSettings.CohereApiVersion.V1; + } } + @Override public URI uri() { return uri; } @@ -122,6 +146,11 @@ public RateLimitSettings rateLimitSettings() { return rateLimitSettings; } + @Override + public CohereServiceSettings.CohereApiVersion apiVersion() { + return apiVersion; + } + @Override public String getWriteableName() { return NAME; @@ -132,6 +161,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); toXContentFragmentOfExposedFields(builder, params); + builder.field(API_VERSION, apiVersion); // API version is persisted but not exposed to the user builder.endObject(); return builder; @@ -175,6 +205,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_15_0)) { rateLimitSettings.writeTo(out); } + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_COHERE_API_VERSION_8_19)) { + out.writeEnum(apiVersion); + } } @Override @@ -184,11 +217,12 @@ public boolean equals(Object object) { CohereRerankServiceSettings that = (CohereRerankServiceSettings) object; return Objects.equals(uri, that.uri) && Objects.equals(modelId, that.modelId) - && Objects.equals(rateLimitSettings, that.rateLimitSettings); + && Objects.equals(rateLimitSettings, that.rateLimitSettings) + && apiVersion == that.apiVersion; } @Override public int hashCode() { - return Objects.hash(uri, modelId, rateLimitSettings); + return Objects.hash(uri, modelId, rateLimitSettings, apiVersion); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java index 
dc0e2cc10501d..f8563aebe0764 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityTestPlugin.java @@ -34,6 +34,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.inference.services.cohere.CohereService; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; @@ -132,7 +133,7 @@ private void handleGetInferenceModelActionRequ request.getInferenceEntityId(), request.getTaskType(), CohereService.NAME, - new CohereRerankServiceSettings("uri", "model", null), + new CohereRerankServiceSettings("uri", "model", null, CohereServiceSettings.CohereApiVersion.V2), topN == null ? 
new EmptyTaskSettings() : new CohereRerankTaskSettings(topN, null, null) ) ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java index f4dad7546c8a2..2d9033222419f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceSettingsTests.java @@ -61,7 +61,8 @@ private static CohereServiceSettings createRandom(String url) { dims, maxInputTokens, model, - RateLimitSettingsTests.createRandom() + RateLimitSettingsTests.createRandom(), + randomFrom(CohereServiceSettings.CohereApiVersion.values()) ); } @@ -91,7 +92,17 @@ public void testFromMap() { MatcherAssert.assertThat( serviceSettings, - is(new CohereServiceSettings(ServiceUtils.createUri(url), SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens, model, null)) + is( + new CohereServiceSettings( + ServiceUtils.createUri(url), + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens, + model, + null, + CohereServiceSettings.CohereApiVersion.V2 + ) + ) ); } @@ -130,7 +141,8 @@ public void testFromMap_WithRateLimit() { dims, maxInputTokens, model, - new RateLimitSettings(3) + new RateLimitSettings(3), + CohereServiceSettings.CohereApiVersion.V2 ) ) ); @@ -154,7 +166,9 @@ public void testFromMap_WhenUsingModelId() { ServiceFields.MAX_INPUT_TOKENS, maxInputTokens, CohereServiceSettings.MODEL_ID, - model + model, + CohereServiceSettings.API_VERSION, + CohereServiceSettings.CohereApiVersion.V1.toString() ) ), ConfigurationParseContext.PERSISTENT @@ -162,10 +176,41 @@ public void testFromMap_WhenUsingModelId() { MatcherAssert.assertThat( serviceSettings, - is(new CohereServiceSettings(ServiceUtils.createUri(url), SimilarityMeasure.DOT_PRODUCT, 
dims, maxInputTokens, model, null)) + is( + new CohereServiceSettings( + ServiceUtils.createUri(url), + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens, + model, + null, + CohereServiceSettings.CohereApiVersion.V1 + ) + ) ); } + public void testFromMap_MissingModelId() { + var e = expectThrows( + ValidationException.class, + () -> CohereServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.SIMILARITY, + SimilarityMeasure.DOT_PRODUCT.toString(), + ServiceFields.DIMENSIONS, + 1536, + ServiceFields.MAX_INPUT_TOKENS, + 512 + ) + ), + ConfigurationParseContext.REQUEST + ) + ); + + assertThat(e.validationErrors().get(0), containsString("The [service_settings.model_id] field is required for the Cohere V2 API.")); + } + public void testFromMap_PrefersModelId_OverModel() { var url = "https://www.abc.com"; var similarity = SimilarityMeasure.DOT_PRODUCT.toString(); @@ -194,7 +239,17 @@ public void testFromMap_PrefersModelId_OverModel() { MatcherAssert.assertThat( serviceSettings, - is(new CohereServiceSettings(ServiceUtils.createUri(url), SimilarityMeasure.DOT_PRODUCT, dims, maxInputTokens, model, null)) + is( + new CohereServiceSettings( + ServiceUtils.createUri(url), + SimilarityMeasure.DOT_PRODUCT, + dims, + maxInputTokens, + model, + null, + CohereServiceSettings.CohereApiVersion.V1 + ) + ) ); } @@ -255,14 +310,22 @@ public void testFromMap_InvalidSimilarity_ThrowsError() { } public void testXContent_WritesModelId() throws IOException { - var entity = new CohereServiceSettings((String) null, null, null, null, "modelId", new RateLimitSettings(1)); + var entity = new CohereServiceSettings( + (String) null, + null, + null, + null, + "modelId", + new RateLimitSettings(1), + CohereServiceSettings.CohereApiVersion.V2 + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); entity.toXContent(builder, null); String xContentResult = Strings.toString(builder); assertThat(xContentResult, is(""" - 
{"model_id":"modelId","rate_limit":{"requests_per_minute":1}}""")); + {"model_id":"modelId","rate_limit":{"requests_per_minute":1},"api_version":"V2"}""")); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index f1045dc263547..ccd1c23fd3d58 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -263,11 +263,10 @@ private static ActionListener getModelListenerForException(Class excep public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createCohereService()) { - var config = getRequestConfigMap( - CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null), - getTaskSettingsMapEmpty(), - getSecretSettingsMap("secret") - ); + var serviceSettings = CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, null, null); + serviceSettings.put(CohereServiceSettings.MODEL_ID, "foo"); + + var config = getRequestConfigMap(serviceSettings, getTaskSettingsMapEmpty(), getSecretSettingsMap("secret")); config.put("extra_key", "value"); var failureListener = getModelListenerForException( @@ -318,11 +317,9 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap var secretSettingsMap = getSecretSettingsMap("secret"); secretSettingsMap.put("extra_key", "value"); - var config = getRequestConfigMap( - CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap("url", null, null), - getTaskSettingsMapEmpty(), - secretSettingsMap - ); + var serviceSettings = CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, null, null); + serviceSettings.put(CohereServiceSettings.MODEL_ID, 
"foo"); + var config = getRequestConfigMap(serviceSettings, getTaskSettingsMapEmpty(), secretSettingsMap); var failureListener = getModelListenerForException( ElasticsearchStatusException.class, @@ -343,14 +340,12 @@ public void testParseRequestConfig_CreatesACohereEmbeddingsModelWithoutUrl() thr MatcherAssert.assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); }, (e) -> fail("Model parsing should have succeeded " + e.getMessage())); + var serviceSettings = CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, null, null); + serviceSettings.put(CohereServiceSettings.MODEL_ID, "foo"); service.parseRequestConfig( "id", TaskType.TEXT_EMBEDDING, - getRequestConfigMap( - CohereEmbeddingsServiceSettingsTests.getServiceSettingsMap(null, null, null), - getTaskSettingsMapEmpty(), - getSecretSettingsMap("secret") - ), + getRequestConfigMap(serviceSettings, getTaskSettingsMapEmpty(), getSecretSettingsMap("secret")), modelListener ); @@ -953,7 +948,7 @@ public void testInfer_UnauthorisedResponse() throws IOException { CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, 1024, 1024, - null, + "coheremodel", null ); PlainActionFuture listener = new PlainActionFuture<>(); @@ -1127,7 +1122,7 @@ public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenModelSettingIs } } - public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest() throws IOException { + public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest_v1API() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { @@ -1166,7 +1161,8 @@ public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspec 1024, 1024, "model", - null + null, + CohereServiceSettings.CohereApiVersion.V1 ); PlainActionFuture listener = new 
PlainActionFuture<>(); service.infer( @@ -1201,6 +1197,73 @@ public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspec } } + public void testInfer_DefaultsInputType_WhenNotPresentInTaskSettings_AndUnspecifiedIsPassedInRequest_v2API() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new CohereService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "id": "de37399c-5df6-47cb-bc57-e3c5680c977b", + "texts": [ + "hello" + ], + "embeddings": { + "float": [ + [ + 0.123, + -0.123 + ] + ] + }, + "meta": { + "api_version": { + "version": "1" + }, + "billed_units": { + "input_tokens": 1 + } + }, + "response_type": "embeddings_by_type" + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = CohereEmbeddingsModelTests.createModel( + getUrl(webServer), + "secret", + new CohereEmbeddingsTaskSettings(null, null), + 1024, + 1024, + "model", + null, + CohereServiceSettings.CohereApiVersion.V2 + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + null, + null, + List.of("abc"), + false, + new HashMap<>(), + InputType.UNSPECIFIED, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + listener.actionGet(TIMEOUT); + + MatcherAssert.assertThat(webServer.requests(), hasSize(1)); + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat( + requestMap, + is(Map.of("texts", List.of("abc"), "model", "model", "embedding_types", List.of("float"), "input_type", "search_query")) + ); + } + } + public void testChunkedInfer_BatchesCallsChunkingSettingsSet() throws IOException { var model = CohereEmbeddingsModelTests.createModel( getUrl(webServer), @@ -1315,7 +1378,7 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { var requestMap = 
entityAsMap(webServer.requests().get(0).getBody()); MatcherAssert.assertThat( requestMap, - is(Map.of("texts", List.of("a", "bb"), "model", "model", "embedding_types", List.of("float"))) + is(Map.of("texts", List.of("a", "bb"), "model", "model", "embedding_types", List.of("float"), "input_type", "search_query")) ); } } @@ -1413,7 +1476,7 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { var requestMap = entityAsMap(webServer.requests().get(0).getBody()); MatcherAssert.assertThat( requestMap, - is(Map.of("texts", List.of("a", "bb"), "model", "model", "embedding_types", List.of("int8"))) + is(Map.of("texts", List.of("a", "bb"), "model", "model", "embedding_types", List.of("int8"), "input_type", "search_query")) ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java index b56a19c0af0f1..88d26d5d7eef1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java @@ -94,7 +94,7 @@ public void testCreate_CohereEmbeddingsModel() throws IOException { }, "meta": { "api_version": { - "version": "1" + "version": "2" }, "billed_units": { "input_tokens": 1 @@ -209,67 +209,7 @@ public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOExcep assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), is("Bearer secret")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, is(Map.of("message", "abc", "model", "model"))); - } - } - - public void testCreate_CohereCompletionModel_WithoutModelSpecified() throws IOException { - var senderFactory = 
HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var sender = createSender(senderFactory)) { - sender.start(); - - String responseJson = """ - { - "response_id": "some id", - "text": "result", - "generation_id": "some id", - "chat_history": [ - { - "role": "USER", - "message": "input" - }, - { - "role": "CHATBOT", - "message": "result" - } - ], - "finish_reason": "COMPLETE", - "meta": { - "api_version": { - "version": "1" - }, - "billed_units": { - "input_tokens": 4, - "output_tokens": 191 - }, - "tokens": { - "input_tokens": 70, - "output_tokens": 191 - } - } - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = CohereCompletionModelTests.createModel(getUrl(webServer), "secret", null); - var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool)); - var action = actionCreator.create(model, Map.of()); - - PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); - - var result = listener.actionGet(TIMEOUT); - - assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result")))); - assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), is(XContentType.JSON.mediaType())); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), is("Bearer secret")); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, is(Map.of("message", "abc"))); + assertThat(requestMap, is(Map.of("message", "abc", "model", "model", "stream", false))); } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java index 016207a4835dd..78b8b7bdeaf3e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereCompletionActionTests.java @@ -24,13 +24,11 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; -import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.services.cohere.CohereCompletionRequestManager; import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModelTests; import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; import org.junit.After; @@ -44,9 +42,9 @@ import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; -import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static 
org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -134,68 +132,7 @@ public void testExecute_ReturnsSuccessfulResponse_WithModelSpecified() throws IO ); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, is(Map.of("message", "abc", "model", "model"))); - } - } - - public void testExecute_ReturnsSuccessfulResponse_WithoutModelSpecified() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var sender = HttpRequestSenderTests.createSender(senderFactory)) { - sender.start(); - - String responseJson = """ - { - "response_id": "some id", - "text": "result", - "generation_id": "some id", - "chat_history": [ - { - "role": "USER", - "message": "input" - }, - { - "role": "CHATBOT", - "message": "result" - } - ], - "finish_reason": "COMPLETE", - "meta": { - "api_version": { - "version": "1" - }, - "billed_units": { - "input_tokens": 4, - "output_tokens": 191 - }, - "tokens": { - "input_tokens": 70, - "output_tokens": 191 - } - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var action = createAction(getUrl(webServer), "secret", null, sender); - - PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); - - var result = listener.actionGet(TIMEOUT); - - assertThat(result.asMap(), is(buildExpectationCompletion(List.of("result")))); - assertThat(webServer.requests(), hasSize(1)); - assertNull(webServer.requests().get(0).getUri().getQuery()); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); - 
assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer secret")); - assertThat( - webServer.requests().get(0).getHeader(CohereUtils.REQUEST_SOURCE_HEADER), - equalTo(CohereUtils.ELASTIC_REQUEST_SOURCE) - ); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, is(Map.of("message", "abc"))); + assertThat(requestMap, is(Map.of("message", "abc", "model", "model", "stream", false))); } } @@ -341,9 +278,8 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc } private ExecutableAction createAction(String url, String apiKey, @Nullable String modelName, Sender sender) { + var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool)); var model = CohereCompletionModelTests.createModel(url, apiKey, modelName); - var requestManager = CohereCompletionRequestManager.of(model, threadPool); - var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Cohere completion"); - return new SingleInputSenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage, "Cohere completion"); + return actionCreator.create(model, Map.of()); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java index 13d5191577d4c..05d69bae4903e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java @@ -24,14 +24,12 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.InputTypeTests; import 
org.elasticsearch.xpack.inference.external.action.ExecutableAction; -import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.HttpResult; import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; -import org.elasticsearch.xpack.inference.services.cohere.CohereEmbeddingsRequestManager; import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests; @@ -50,10 +48,9 @@ import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; -import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.services.cohere.request.CohereEmbeddingsRequestEntity.convertToString; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasSize; @@ -126,7 +123,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { ); PlainActionFuture listener = new 
PlainActionFuture<>(); - var inputType = InputTypeTests.randomWithNull(); + InputType inputType = InputTypeTests.randomWithNull(); action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -145,31 +142,25 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { ); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - if (inputType != null && inputType != InputType.UNSPECIFIED) { - var cohereInputType = convertToString(inputType); - MatcherAssert.assertThat( - requestMap, - is( - Map.of( - "texts", - List.of("abc"), - "model", - "model", - "input_type", - cohereInputType, - "embedding_types", - List.of("float"), - "truncate", - "start" - ) + var expectedInputType = InputType.isSpecified(inputType) ? inputType : InputType.SEARCH; + var cohereInputType = CohereUtils.inputTypeToString(expectedInputType); + MatcherAssert.assertThat( + requestMap, + is( + Map.of( + "texts", + List.of("abc"), + "model", + "model", + "input_type", + cohereInputType, + "embedding_types", + List.of("float"), + "truncate", + "start" ) - ); - } else { - MatcherAssert.assertThat( - requestMap, - is(Map.of("texts", List.of("abc"), "model", "model", "embedding_types", List.of("float"), "truncate", "start")) - ); - } + ) + ); } } @@ -354,10 +345,9 @@ private ExecutableAction createAction( @Nullable CohereEmbeddingType embeddingType, Sender sender ) { + var actionCreator = new CohereActionCreator(sender, createWithEmptySettings(threadPool)); var model = CohereEmbeddingsModelTests.createModel(url, apiKey, taskSettings, 1024, 1024, modelName, embeddingType); - var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Cohere embeddings"); - var requestCreator = CohereEmbeddingsRequestManager.of(model, threadPool); - return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage); + return actionCreator.create(model, 
null); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java index b9fc7ee7b9952..6ae12b96741ed 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionModelTests.java @@ -10,9 +10,9 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.EmptyTaskSettings; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import java.util.HashMap; @@ -24,24 +24,20 @@ public class CohereCompletionModelTests extends ESTestCase { public void testCreateModel_AlwaysWithEmptyTaskSettings() { var model = new CohereCompletionModel( - "model", - TaskType.COMPLETION, - "service", - new HashMap<>(Map.of()), - new HashMap<>(Map.of("model", "overridden model")), + "inference_id", + new HashMap<>(Map.of("model_id", "cohere completion model")), null, ConfigurationParseContext.PERSISTENT ); assertThat(model.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(model.getServiceSettings().modelId(), is("cohere completion model")); } public static CohereCompletionModel createModel(String url, String apiKey, @Nullable String model) { return new CohereCompletionModel( "id", - TaskType.COMPLETION, - "service", - new CohereCompletionServiceSettings(url, model, null), + new CohereCompletionServiceSettings(url, model, 
null, CohereServiceSettings.CohereApiVersion.V2), EmptyTaskSettings.INSTANCE, new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettingsTests.java index ed8bc90d32140..06ebdd158b92c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/completion/CohereCompletionServiceSettingsTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; @@ -27,7 +28,12 @@ public class CohereCompletionServiceSettingsTests extends AbstractWireSerializingTestCase { public static CohereCompletionServiceSettings createRandom() { - return new CohereCompletionServiceSettings(randomAlphaOfLength(8), randomAlphaOfLength(8), RateLimitSettingsTests.createRandom()); + return new CohereCompletionServiceSettings( + randomAlphaOfLength(8), + randomAlphaOfLength(8), + RateLimitSettingsTests.createRandom(), + randomFrom(CohereServiceSettings.CohereApiVersion.values()) + ); } public void testFromMap_WithRateLimitSettingsNull() { @@ -39,7 +45,7 @@ public void testFromMap_WithRateLimitSettingsNull() { ConfigurationParseContext.PERSISTENT ); - assertThat(serviceSettings, is(new CohereCompletionServiceSettings(url, model, 
null))); + assertThat(serviceSettings, is(new CohereCompletionServiceSettings(url, model, null, CohereServiceSettings.CohereApiVersion.V1))); } public void testFromMap_WithRateLimitSettings() { @@ -61,18 +67,33 @@ public void testFromMap_WithRateLimitSettings() { ConfigurationParseContext.PERSISTENT ); - assertThat(serviceSettings, is(new CohereCompletionServiceSettings(url, model, new RateLimitSettings(requestsPerMinute)))); + assertThat( + serviceSettings, + is( + new CohereCompletionServiceSettings( + url, + model, + new RateLimitSettings(requestsPerMinute), + CohereServiceSettings.CohereApiVersion.V1 + ) + ) + ); } public void testToXContent_WritesAllValues() throws IOException { - var serviceSettings = new CohereCompletionServiceSettings("url", "model", new RateLimitSettings(3)); + var serviceSettings = new CohereCompletionServiceSettings( + "url", + "model", + new RateLimitSettings(3), + CohereServiceSettings.CohereApiVersion.V1 + ); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); serviceSettings.toXContent(builder, null); String xContentResult = Strings.toString(builder); assertThat(xContentResult, is(""" - {"url":"url","model_id":"model","rate_limit":{"requests_per_minute":3}}""")); + {"url":"url","model_id":"model","rate_limit":{"requests_per_minute":3},"api_version":"V1"}""")); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java index b3e264fdf1ab7..fd380b8fd973d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsModelTests.java @@ -12,7 +12,6 @@ import 
org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.SimilarityMeasure; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; @@ -121,10 +120,16 @@ public static CohereEmbeddingsModel createModel( ) { return new CohereEmbeddingsModel( "id", - TaskType.TEXT_EMBEDDING, - "service", new CohereEmbeddingsServiceSettings( - new CohereServiceSettings(url, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit, model, null), + new CohereServiceSettings( + url, + SimilarityMeasure.DOT_PRODUCT, + dimensions, + tokenLimit, + model, + null, + CohereServiceSettings.CohereApiVersion.V2 + ), Objects.requireNonNullElse(embeddingType, CohereEmbeddingType.FLOAT) ), taskSettings, @@ -139,15 +144,35 @@ public static CohereEmbeddingsModel createModel( CohereEmbeddingsTaskSettings taskSettings, @Nullable Integer tokenLimit, @Nullable Integer dimensions, - @Nullable String model, + String model, @Nullable CohereEmbeddingType embeddingType + ) { + return createModel( + url, + apiKey, + taskSettings, + tokenLimit, + dimensions, + model, + embeddingType, + CohereServiceSettings.CohereApiVersion.V2 + ); + } + + public static CohereEmbeddingsModel createModel( + String url, + String apiKey, + CohereEmbeddingsTaskSettings taskSettings, + @Nullable Integer tokenLimit, + @Nullable Integer dimensions, + String model, + @Nullable CohereEmbeddingType embeddingType, + CohereServiceSettings.CohereApiVersion apiVersion ) { return new CohereEmbeddingsModel( "id", - TaskType.TEXT_EMBEDDING, - "service", new CohereEmbeddingsServiceSettings( - new CohereServiceSettings(url, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit, model, null), + new CohereServiceSettings(url, SimilarityMeasure.DOT_PRODUCT, dimensions, tokenLimit, model, null, 
apiVersion), Objects.requireNonNullElse(embeddingType, CohereEmbeddingType.FLOAT) ), taskSettings, @@ -168,10 +193,16 @@ public static CohereEmbeddingsModel createModel( ) { return new CohereEmbeddingsModel( "id", - TaskType.TEXT_EMBEDDING, - "service", new CohereEmbeddingsServiceSettings( - new CohereServiceSettings(url, similarityMeasure, dimensions, tokenLimit, model, null), + new CohereServiceSettings( + url, + similarityMeasure, + dimensions, + tokenLimit, + model, + null, + CohereServiceSettings.CohereApiVersion.V2 + ), Objects.requireNonNullElse(embeddingType, CohereEmbeddingType.FLOAT) ), taskSettings, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java index 544676cfa7cc7..b033bfa0db6e5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsServiceSettingsTests.java @@ -81,7 +81,8 @@ public void testFromMap() { dims, maxInputTokens, model, - null + null, + CohereServiceSettings.CohereApiVersion.V1 ), CohereEmbeddingType.BYTE ) @@ -125,7 +126,8 @@ public void testFromMap_WithModelId() { dims, maxInputTokens, model, - null + null, + CohereServiceSettings.CohereApiVersion.V2 ), CohereEmbeddingType.INT8 ) @@ -155,7 +157,9 @@ public void testFromMap_PrefersModelId_OverModel() { CohereServiceSettings.MODEL_ID, model, CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, - CohereEmbeddingType.BYTE.toString() + CohereEmbeddingType.BYTE.toString(), + CohereServiceSettings.API_VERSION, + CohereServiceSettings.CohereApiVersion.V1.toString() ) ), ConfigurationParseContext.PERSISTENT @@ -171,7 +175,8 @@ public void 
testFromMap_PrefersModelId_OverModel() { dims, maxInputTokens, model, - null + null, + CohereServiceSettings.CohereApiVersion.V1 ), CohereEmbeddingType.BYTE ) @@ -188,7 +193,7 @@ public void testFromMap_EmptyEmbeddingType_ThrowsError() { var thrownException = expectThrows( ValidationException.class, () -> CohereEmbeddingsServiceSettings.fromMap( - new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, "")), + new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, "", CohereServiceSettings.MODEL_ID, "model")), ConfigurationParseContext.REQUEST ) ); @@ -208,7 +213,7 @@ public void testFromMap_InvalidEmbeddingType_ThrowsError_ForRequest() { var thrownException = expectThrows( ValidationException.class, () -> CohereEmbeddingsServiceSettings.fromMap( - new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, "abc")), + new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, "abc", CohereServiceSettings.MODEL_ID, "model")), ConfigurationParseContext.REQUEST ) ); @@ -265,7 +270,12 @@ public void testFromMap_ConvertsElementTypeByte_ToCohereEmbeddingTypeByte() { new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, DenseVectorFieldMapper.ElementType.BYTE.toString())), ConfigurationParseContext.PERSISTENT ), - is(new CohereEmbeddingsServiceSettings(new CohereServiceSettings(), CohereEmbeddingType.BYTE)) + is( + new CohereEmbeddingsServiceSettings( + new CohereServiceSettings(CohereServiceSettings.CohereApiVersion.V1), + CohereEmbeddingType.BYTE + ) + ) ); } @@ -275,7 +285,12 @@ public void testFromMap_ConvertsElementTypeFloat_ToCohereEmbeddingTypeFloat() { new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, DenseVectorFieldMapper.ElementType.FLOAT.toString())), ConfigurationParseContext.PERSISTENT ), - is(new CohereEmbeddingsServiceSettings(new CohereServiceSettings(), CohereEmbeddingType.FLOAT)) + is( + new CohereEmbeddingsServiceSettings( + new 
CohereServiceSettings(CohereServiceSettings.CohereApiVersion.V1), + CohereEmbeddingType.FLOAT + ) + ) ); } @@ -283,29 +298,58 @@ public void testFromMap_ConvertsInt8_ToCohereEmbeddingTypeInt8() { assertThat( CohereEmbeddingsServiceSettings.fromMap( new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, CohereEmbeddingType.INT8.toString())), - ConfigurationParseContext.REQUEST + ConfigurationParseContext.PERSISTENT ), - is(new CohereEmbeddingsServiceSettings(new CohereServiceSettings(), CohereEmbeddingType.INT8)) + is( + new CohereEmbeddingsServiceSettings( + new CohereServiceSettings(CohereServiceSettings.CohereApiVersion.V1), + CohereEmbeddingType.INT8 + ) + ) ); } public void testFromMap_ConvertsBit_ToCohereEmbeddingTypeBit() { assertThat( CohereEmbeddingsServiceSettings.fromMap( - new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, CohereEmbeddingType.BIT.toString())), + new HashMap<>( + Map.of( + CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, + CohereEmbeddingType.BIT.toString(), + CohereServiceSettings.MODEL_ID, + "model" + ) + ), ConfigurationParseContext.REQUEST ), - is(new CohereEmbeddingsServiceSettings(new CohereServiceSettings(), CohereEmbeddingType.BIT)) + is( + new CohereEmbeddingsServiceSettings( + new CohereServiceSettings((String) null, null, null, null, "model", null, CohereServiceSettings.CohereApiVersion.V2), + CohereEmbeddingType.BIT + ) + ) ); } public void testFromMap_PreservesEmbeddingTypeFloat() { assertThat( CohereEmbeddingsServiceSettings.fromMap( - new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, CohereEmbeddingType.FLOAT.toString())), + new HashMap<>( + Map.of( + CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, + CohereEmbeddingType.FLOAT.toString(), + CohereServiceSettings.MODEL_ID, + "model" + ) + ), ConfigurationParseContext.REQUEST ), - is(new CohereEmbeddingsServiceSettings(new CohereServiceSettings(), CohereEmbeddingType.FLOAT)) + is( + new CohereEmbeddingsServiceSettings( + new 
CohereServiceSettings((String) null, null, null, null, "model", null, CohereServiceSettings.CohereApiVersion.V2), + CohereEmbeddingType.FLOAT + ) + ) ); } @@ -315,7 +359,12 @@ public void testFromMap_PersistentReadsInt8() { new HashMap<>(Map.of(CohereEmbeddingsServiceSettings.EMBEDDING_TYPE, "int8")), ConfigurationParseContext.PERSISTENT ), - is(new CohereEmbeddingsServiceSettings(new CohereServiceSettings(), CohereEmbeddingType.INT8)) + is( + new CohereEmbeddingsServiceSettings( + new CohereServiceSettings(CohereServiceSettings.CohereApiVersion.V1), + CohereEmbeddingType.INT8 + ) + ) ); } @@ -331,7 +380,15 @@ public void testFromCohereOrDenseVectorEnumValues() { public void testToXContent_WritesAllValues() throws IOException { var serviceSettings = new CohereEmbeddingsServiceSettings( - new CohereServiceSettings("url", SimilarityMeasure.COSINE, 5, 10, "model_id", new RateLimitSettings(3)), + new CohereServiceSettings( + "url", + SimilarityMeasure.COSINE, + 5, + 10, + "model_id", + new RateLimitSettings(3), + CohereServiceSettings.CohereApiVersion.V2 + ), CohereEmbeddingType.INT8 ); @@ -340,7 +397,7 @@ public void testToXContent_WritesAllValues() throws IOException { String xContentResult = Strings.toString(builder); assertThat(xContentResult, is(""" {"url":"url","similarity":"cosine","dimensions":5,"max_input_tokens":10,"model_id":"model_id",""" + """ - "rate_limit":{"requests_per_minute":3},"embedding_type":"byte"}""")); + "rate_limit":{"requests_per_minute":3},"api_version":"V2","embedding_type":"byte"}""")); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestEntityTests.java deleted file mode 100644 index 39247b6e93e77..0000000000000 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestEntityTests.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere.request; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.services.cohere.request.completion.CohereCompletionRequestEntity; - -import java.io.IOException; -import java.util.List; - -import static org.hamcrest.CoreMatchers.is; - -public class CohereCompletionRequestEntityTests extends ESTestCase { - - public void testXContent_WritesAllFields() throws IOException { - var entity = new CohereCompletionRequestEntity(List.of("some input"), "model", false); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, is(""" - {"message":"some input","model":"model"}""")); - } - - public void testXContent_DoesNotWriteModelIfNotSpecified() throws IOException { - var entity = new CohereCompletionRequestEntity(List.of("some input"), null, false); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, is(""" - {"message":"some input"}""")); - } - - public void testXContent_ThrowsIfInputIsNull() { - expectThrows(NullPointerException.class, () -> new CohereCompletionRequestEntity(null, 
null, false)); - } - - public void testXContent_ThrowsIfMessageInInputIsNull() { - expectThrows(NullPointerException.class, () -> new CohereCompletionRequestEntity(List.of((String) null), null, false)); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestTests.java deleted file mode 100644 index 930480af50fb4..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereCompletionRequestTests.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere.request; - -import org.apache.http.HttpHeaders; -import org.apache.http.client.methods.HttpPost; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModelTests; -import org.elasticsearch.xpack.inference.services.cohere.request.completion.CohereCompletionRequest; - -import java.io.IOException; -import java.util.List; -import java.util.Map; - -import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; -import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.sameInstance; - -public class CohereCompletionRequestTests extends ESTestCase { - - public void testCreateRequest_UrlDefined() throws IOException { - var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", null), false); - - var 
httpRequest = request.createHttpRequest(); - assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); - - var httpPost = (HttpPost) httpRequest.httpRequestBase(); - - assertThat(httpPost.getURI().toString(), is("url")); - assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); - assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); - assertThat(httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE)); - - var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, is(Map.of("message", "abc"))); - } - - public void testCreateRequest_ModelDefined() throws IOException { - var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model"), false); - - var httpRequest = request.createHttpRequest(); - assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); - - var httpPost = (HttpPost) httpRequest.httpRequestBase(); - - assertThat(httpPost.getURI().toString(), is("url")); - assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); - assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); - assertThat(httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE)); - - var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, is(Map.of("message", "abc", "model", "model"))); - } - - public void testTruncate_ReturnsSameInstance() { - var request = new CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model"), false); - var truncatedRequest = request.truncate(); - - assertThat(truncatedRequest, sameInstance(request)); - } - - public void testTruncationInfo_ReturnsNull() { - var request = new 
CohereCompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel("url", "secret", "model"), false); - - assertNull(request.getTruncationInfo()); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequestEntityTests.java deleted file mode 100644 index 30a01422f6f30..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequestEntityTests.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.cohere.request; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.inference.InputType; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; -import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; -import org.hamcrest.MatcherAssert; - -import java.io.IOException; -import java.util.List; - -import static org.hamcrest.CoreMatchers.is; - -public class CohereEmbeddingsRequestEntityTests extends ESTestCase { - public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { - var entity = new CohereEmbeddingsRequestEntity( - List.of("abc"), - InputType.INTERNAL_INGEST, - new 
CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.START), - "model", - CohereEmbeddingType.FLOAT - ); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - MatcherAssert.assertThat(xContentResult, is(""" - {"texts":["abc"],"model":"model","input_type":"search_document","embedding_types":["float"],"truncate":"start"}""")); - } - - public void testXContent_TaskSettingsInputType_EmbeddingTypesInt8_TruncateNone() throws IOException { - var entity = new CohereEmbeddingsRequestEntity( - List.of("abc"), - null, - new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), - "model", - CohereEmbeddingType.INT8 - ); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - MatcherAssert.assertThat(xContentResult, is(""" - {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}""")); - } - - public void testXContent_InternalInputType_EmbeddingTypesByte_TruncateNone() throws IOException { - var entity = new CohereEmbeddingsRequestEntity( - List.of("abc"), - InputType.INTERNAL_SEARCH, - new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE), - "model", - CohereEmbeddingType.BYTE - ); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - MatcherAssert.assertThat(xContentResult, is(""" - {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}""")); - } - - public void testXContent_InputTypeSearch_EmbeddingTypesBinary_TruncateNone() throws IOException { - var entity = new CohereEmbeddingsRequestEntity( - List.of("abc"), - InputType.SEARCH, - new CohereEmbeddingsTaskSettings(InputType.SEARCH, 
CohereTruncation.NONE), - "model", - CohereEmbeddingType.BINARY - ); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - MatcherAssert.assertThat(xContentResult, is(""" - {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}""")); - } - - public void testXContent_InputTypeSearch_EmbeddingTypesBit_TruncateNone() throws IOException { - var entity = new CohereEmbeddingsRequestEntity( - List.of("abc"), - null, - new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), - "model", - CohereEmbeddingType.BIT - ); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - MatcherAssert.assertThat(xContentResult, is(""" - {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}""")); - } - - public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { - var entity = new CohereEmbeddingsRequestEntity(List.of("abc"), null, CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - MatcherAssert.assertThat(xContentResult, is(""" - {"texts":["abc"]}""")); - } - - public void testConvertToString_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() { - var thrownException = expectThrows( - AssertionError.class, - () -> CohereEmbeddingsRequestEntity.convertToString(InputType.UNSPECIFIED) - ); - MatcherAssert.assertThat(thrownException.getMessage(), is("received invalid input type value [unspecified]")); - } -} diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRequestTests.java index 604509afdbd7d..81106764474cf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRequestTests.java @@ -32,5 +32,4 @@ public void testDecorateWithAuthHeader() { assertThat(request.getFirstHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer abc")); assertThat(request.getFirstHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE)); } - } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRerankRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRerankRequestEntityTests.java deleted file mode 100644 index 7c0fa143a56db..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereRerankRequestEntityTests.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.services.cohere.request; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; -import org.hamcrest.MatcherAssert; - -import java.io.IOException; -import java.util.List; - -import static org.hamcrest.CoreMatchers.is; - -public class CohereRerankRequestEntityTests extends ESTestCase { - public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { - var entity = new CohereRerankRequestEntity( - "query", - List.of("abc"), - Boolean.TRUE, - 22, - new CohereRerankTaskSettings(null, null, 3), - "model" - ); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - MatcherAssert.assertThat(xContentResult, is(""" - {"model":"model","query":"query","documents":["abc"],"return_documents":true,"top_n":22,"max_chunks_per_doc":3}""")); - } - - public void testXContent_WritesMinimalFields() throws IOException { - var entity = new CohereRerankRequestEntity( - "query", - List.of("abc"), - null, - null, - new CohereRerankTaskSettings(null, null, null), - "model" - ); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - MatcherAssert.assertThat(xContentResult, is(""" - {"model":"model","query":"query","documents":["abc"]}""")); - } - - public void testXContent_PrefersRootLevelReturnDocumentsAndTopN() throws IOException { - var entity = new CohereRerankRequestEntity( - "query", - List.of("abc"), - Boolean.FALSE, - 99, - new CohereRerankTaskSettings(33, Boolean.TRUE, null), - "model" - ); - - XContentBuilder 
builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - MatcherAssert.assertThat(xContentResult, is(""" - {"model":"model","query":"query","documents":["abc"],"return_documents":false,"top_n":99}""")); - } - - public void testXContent_UsesTaskSettingsIfNoRootOptionsDefined() throws IOException { - var entity = new CohereRerankRequestEntity( - "query", - List.of("abc"), - null, - null, - new CohereRerankTaskSettings(33, Boolean.TRUE, null), - "model" - ); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - MatcherAssert.assertThat(xContentResult, is(""" - {"model":"model","query":"query","documents":["abc"],"return_documents":true,"top_n":33}""")); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtilsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtilsTests.java index ad5c9c4e80330..ef2b29bbe9a2a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtilsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereUtilsTests.java @@ -7,7 +7,9 @@ package org.elasticsearch.xpack.inference.services.cohere.request; +import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; +import org.hamcrest.MatcherAssert; import static org.hamcrest.Matchers.is; @@ -20,4 +22,18 @@ public void testCreateRequestSourceHeader() { assertThat(requestSourceHeader.getValue(), is("unspecified:elasticsearch")); } + public void testInputTypeToString() { + assertThat(CohereUtils.inputTypeToString(InputType.INGEST), is("search_document")); + 
assertThat(CohereUtils.inputTypeToString(InputType.INTERNAL_INGEST), is("search_document")); + assertThat(CohereUtils.inputTypeToString(InputType.SEARCH), is("search_query")); + assertThat(CohereUtils.inputTypeToString(InputType.INTERNAL_SEARCH), is("search_query")); + assertThat(CohereUtils.inputTypeToString(InputType.CLASSIFICATION), is("classification")); + assertThat(CohereUtils.inputTypeToString(InputType.CLUSTERING), is("clustering")); + assertThat(InputType.values().length, is(7)); // includes unspecified. Fail if new values are added + } + + public void testInputTypeToString_ThrowsAssertionFailure_WhenInputTypeIsUnspecified() { + var thrownException = expectThrows(AssertionError.class, () -> CohereUtils.inputTypeToString(InputType.UNSPECIFIED)); + MatcherAssert.assertThat(thrownException.getMessage(), is("received invalid input type value [unspecified]")); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequestTests.java new file mode 100644 index 0000000000000..1f444ed3e8ce2 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1CompletionRequestTests.java @@ -0,0 +1,136 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v1; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModelTests; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; +import org.hamcrest.CoreMatchers; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public class CohereV1CompletionRequestTests extends ESTestCase { + + public void testCreateRequest_UrlDefined() throws IOException { + var request = new CohereV1CompletionRequest( + List.of("abc"), + CohereCompletionModelTests.createModel("http://localhost", "secret", null), + false + ); + + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getURI().toString(), is("http://localhost/v1/chat")); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + assertThat(httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, is(Map.of("message", "abc"))); + } + + public void testCreateRequest_ModelDefined() throws 
IOException { + var request = new CohereV1CompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel(null, "secret", "model"), false); + + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v1/chat")); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + assertThat(httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, is(Map.of("message", "abc", "model", "model"))); + } + + public void testDefaultUrl() { + var request = new CohereV1CompletionRequest(List.of("abc"), CohereCompletionModelTests.createModel(null, "secret", null), false); + + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v1/chat")); + } + + public void testTruncate_ReturnsSameInstance() { + var request = new CohereV1CompletionRequest( + List.of("abc"), + CohereCompletionModelTests.createModel("url", "secret", "model"), + false + ); + var truncatedRequest = request.truncate(); + + assertThat(truncatedRequest, sameInstance(request)); + } + + public void testTruncationInfo_ReturnsNull() { + var request = new CohereV1CompletionRequest( + List.of("abc"), + CohereCompletionModelTests.createModel("url", "secret", "model"), + false + ); + + assertNull(request.getTruncationInfo()); + } + + public void testXContent_WritesAllFields() throws IOException { + var request = new CohereV1CompletionRequest( + 
List.of("some input"), + CohereCompletionModelTests.createModel("url", "secret", "model"), + false + ); + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + request.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(""" + {"message":"some input","model":"model"}""")); + } + + public void testXContent_DoesNotWriteModelIfNotSpecified() throws IOException { + var request = new CohereV1CompletionRequest( + List.of("some input"), + CohereCompletionModelTests.createModel("url", "secret", null), + false + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + request.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(""" + {"message":"some input"}""")); + } + + public void testXContent_ThrowsIfInputIsNull() { + expectThrows( + NullPointerException.class, + () -> new CohereV1CompletionRequest(null, CohereCompletionModelTests.createModel("url", "secret", null), false) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequestTests.java similarity index 68% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequestTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequestTests.java index 508c81bb940cd..18af39004b8eb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/CohereEmbeddingsRequestTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1EmbeddingsRequestTests.java @@ -5,12 +5,15 @@ * 2.0. */ -package org.elasticsearch.xpack.inference.services.cohere.request; +package org.elasticsearch.xpack.inference.services.cohere.request.v1; import org.apache.http.HttpHeaders; import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.InputTypeTests; import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; @@ -18,6 +21,8 @@ import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; +import org.hamcrest.CoreMatchers; import org.hamcrest.MatcherAssert; import java.io.IOException; @@ -25,17 +30,16 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; -import static org.elasticsearch.xpack.inference.services.cohere.request.CohereEmbeddingsRequestEntity.convertToString; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; -public class CohereEmbeddingsRequestTests extends ESTestCase { - public void testCreateRequest_UrlDefined() throws IOException { +public class CohereV1EmbeddingsRequestTests extends ESTestCase { + public void testCreateRequest() throws IOException { var inputType = InputTypeTests.randomWithNull(); var request = createRequest( List.of("abc"), inputType, - 
CohereEmbeddingsModelTests.createModel("url", "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, null, null) + CohereEmbeddingsModelTests.createModel(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, null, null) ); var httpRequest = request.createHttpRequest(); @@ -43,7 +47,7 @@ public void testCreateRequest_UrlDefined() throws IOException { var httpPost = (HttpPost) httpRequest.httpRequestBase(); - MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v1/embed")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); MatcherAssert.assertThat( @@ -63,7 +67,7 @@ public void testCreateRequest_AllOptionsDefined() throws IOException { List.of("abc"), inputType, CohereEmbeddingsModelTests.createModel( - "url", + "http://localhost:8080", "secret", new CohereEmbeddingsTaskSettings(InputType.INGEST, CohereTruncation.START), null, @@ -78,7 +82,7 @@ public void testCreateRequest_AllOptionsDefined() throws IOException { var httpPost = (HttpPost) httpRequest.httpRequestBase(); - MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getURI().toString(), is("http://localhost:8080/v1/embed")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); MatcherAssert.assertThat( @@ -100,7 +104,7 @@ public void testCreateRequest_WithTaskSettingsInputType() throws IOException { List.of("abc"), null, CohereEmbeddingsModelTests.createModel( - "url", + null, "secret", new CohereEmbeddingsTaskSettings(inputType, CohereTruncation.END), null, @@ -115,7 +119,6 @@ 
public void testCreateRequest_WithTaskSettingsInputType() throws IOException { var httpPost = (HttpPost) httpRequest.httpRequestBase(); - MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); MatcherAssert.assertThat( @@ -137,7 +140,7 @@ public void testCreateRequest_RequestInputTypeTakesPrecedence() throws IOExcepti List.of("abc"), requestInputType, CohereEmbeddingsModelTests.createModel( - "url", + null, "secret", new CohereEmbeddingsTaskSettings(taskSettingInputType, CohereTruncation.END), null, @@ -152,7 +155,7 @@ public void testCreateRequest_RequestInputTypeTakesPrecedence() throws IOExcepti var httpPost = (HttpPost) httpRequest.httpRequestBase(); - MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v1/embed")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); MatcherAssert.assertThat( @@ -173,7 +176,7 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeInt8_TruncateEnd() th List.of("abc"), inputType, CohereEmbeddingsModelTests.createModel( - "url", + null, "secret", new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.END), null, @@ -188,7 +191,7 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeInt8_TruncateEnd() th var httpPost = (HttpPost) httpRequest.httpRequestBase(); - MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); + MatcherAssert.assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v1/embed")); 
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); MatcherAssert.assertThat( @@ -210,7 +213,7 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeBit_TruncateEnd() thr List.of("abc"), inputType, CohereEmbeddingsModelTests.createModel( - "url", + null, "secret", new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.END), null, @@ -225,7 +228,6 @@ public void testCreateRequest_InputTypeSearch_EmbeddingTypeBit_TruncateEnd() thr var httpPost = (HttpPost) httpRequest.httpRequestBase(); - MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); MatcherAssert.assertThat( @@ -247,7 +249,7 @@ public void testCreateRequest_TruncateNone() throws IOException { List.of("abc"), inputType, CohereEmbeddingsModelTests.createModel( - "url", + null, "secret", new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE), null, @@ -262,7 +264,6 @@ public void testCreateRequest_TruncateNone() throws IOException { var httpPost = (HttpPost) httpRequest.httpRequestBase(); - MatcherAssert.assertThat(httpPost.getURI().toString(), is("url")); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); MatcherAssert.assertThat( @@ -277,16 +278,123 @@ public void testCreateRequest_TruncateNone() throws IOException { validateInputType(requestMap, null, inputType); } - public static CohereEmbeddingsRequest createRequest(List input, InputType inputType, CohereEmbeddingsModel model) { - return 
new CohereEmbeddingsRequest(input, inputType, model); + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = createRequest( + "model", + List.of("abc"), + InputType.INTERNAL_INGEST, + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.START), + CohereEmbeddingType.FLOAT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"texts":["abc"],"model":"model","input_type":"search_document","embedding_types":["float"],"truncate":"start"}""")); + } + + public void testXContent_TaskSettingsInputType_EmbeddingTypesInt8_TruncateNone() throws IOException { + var entity = createRequest( + "model", + List.of("abc"), + null, + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), + CohereEmbeddingType.INT8 + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}""")); + } + + public void testXContent_InternalInputType_EmbeddingTypesByte_TruncateNone() throws IOException { + var entity = createRequest( + "model", + List.of("abc"), + InputType.INTERNAL_SEARCH, + new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE), + CohereEmbeddingType.BYTE + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}""")); + } + + public void 
testXContent_InputTypeSearch_EmbeddingTypesBinary_TruncateNone() throws IOException { + var entity = createRequest( + "model", + List.of("abc"), + InputType.SEARCH, + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), + CohereEmbeddingType.BINARY + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}""")); + } + + public void testXContent_InputTypeSearch_EmbeddingTypesBit_TruncateNone() throws IOException { + var entity = createRequest( + "model", + List.of("abc"), + null, + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), + CohereEmbeddingType.BIT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"texts":["abc"],"model":"model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}""")); + } + + public void testXContent_WritesNoOptionalFields_WhenTheyAreNotDefined() throws IOException { + var entity = createRequest(null, List.of("abc"), null, CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"texts":["abc"],"embedding_types":["float"]}""")); + } + + public static CohereV1EmbeddingsRequest createRequest(List input, InputType inputType, CohereEmbeddingsModel model) { + return new CohereV1EmbeddingsRequest(input, inputType, model); + } + + public static CohereV1EmbeddingsRequest 
createRequest( + String modelId, + List input, + InputType inputType, + CohereEmbeddingsTaskSettings taskSettings, + CohereEmbeddingType embeddingType + ) { + var model = CohereEmbeddingsModelTests.createModel(null, "secret", taskSettings, null, null, modelId, embeddingType); + return new CohereV1EmbeddingsRequest(input, inputType, model); } private void validateInputType(Map requestMap, InputType taskSettingsInputType, InputType requestInputType) { if (InputType.isSpecified(requestInputType)) { - var convertedInputType = convertToString(requestInputType); + var convertedInputType = CohereUtils.inputTypeToString(requestInputType); assertThat(requestMap.get("input_type"), is(convertedInputType)); } else if (InputType.isSpecified(taskSettingsInputType)) { - var convertedInputType = convertToString(taskSettingsInputType); + var convertedInputType = CohereUtils.inputTypeToString(taskSettingsInputType); assertThat(requestMap.get("input_type"), is(convertedInputType)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequestTests.java new file mode 100644 index 0000000000000..62ff4d599f6e6 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v1/CohereV1RerankRequestTests.java @@ -0,0 +1,134 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v1; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.net.URI; +import java.util.List; + +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class CohereV1RerankRequestTests extends ESTestCase { + public void testRequest() { + var request = new CohereV1RerankRequest( + "query", + List.of("abc"), + Boolean.TRUE, + 22, + createModel("model", new CohereRerankTaskSettings(null, null, 3)) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v1/rerank")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer 
secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), + is(CohereUtils.ELASTIC_REQUEST_SOURCE) + ); + } + + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = new CohereV1RerankRequest( + "query", + List.of("abc"), + Boolean.TRUE, + 22, + createModel("model", new CohereRerankTaskSettings(null, null, 3)) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"model":"model","query":"query","documents":["abc"],"return_documents":true,"top_n":22,"max_chunks_per_doc":3}""")); + } + + public void testXContent_WritesMinimalFields() throws IOException { + var entity = new CohereV1RerankRequest( + "query", + List.of("abc"), + null, + null, + createModel("model", new CohereRerankTaskSettings(null, null, null)) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"model":"model","query":"query","documents":["abc"]}""")); + } + + public void testXContent_PrefersRootLevelReturnDocumentsAndTopN() throws IOException { + var entity = new CohereV1RerankRequest( + "query", + List.of("abc"), + Boolean.FALSE, + 99, + createModel("model", new CohereRerankTaskSettings(33, Boolean.TRUE, null)) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"model":"model","query":"query","documents":["abc"],"return_documents":false,"top_n":99}""")); + } + + public void testXContent_UsesTaskSettingsIfNoRootOptionsDefined() throws 
IOException { + var entity = new CohereV1RerankRequest( + "query", + List.of("abc"), + null, + null, + createModel("model", new CohereRerankTaskSettings(33, Boolean.TRUE, null)) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"model":"model","query":"query","documents":["abc"],"return_documents":true,"top_n":33}""")); + } + + private CohereRerankModel createModel(String modelId, CohereRerankTaskSettings taskSettings) { + return new CohereRerankModel( + "inference_id", + new CohereRerankServiceSettings((URI) null, modelId, null, CohereServiceSettings.CohereApiVersion.V2), + taskSettings, + new DefaultSecretSettings(new SecureString("secret".toCharArray())) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java new file mode 100644 index 0000000000000..2fb51ca8ca457 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2CompletionRequestTests.java @@ -0,0 +1,93 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v2; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.cohere.completion.CohereCompletionModelTests; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; +import org.hamcrest.CoreMatchers; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class CohereV2CompletionRequestTests extends ESTestCase { + + public void testCreateRequest() throws IOException { + var request = new CohereV2CompletionRequest( + List.of("abc"), + CohereCompletionModelTests.createModel(null, "secret", "required model id"), + false + ); + + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v2/chat")); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + assertThat(httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), is(CohereUtils.ELASTIC_REQUEST_SOURCE)); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, is(Map.of("message", "abc", "model", "required model id", "stream", false))); + } + + public void testDefaultUrl() { + var request = new 
CohereV2CompletionRequest( + List.of("abc"), + CohereCompletionModelTests.createModel(null, "secret", "model id"), + false + ); + + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v2/chat")); + } + + public void testOverriddenUrl() { + var request = new CohereV2CompletionRequest( + List.of("abc"), + CohereCompletionModelTests.createModel("http://localhost", "secret", "model id"), + false + ); + + var httpRequest = request.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getURI().toString(), is("http://localhost/v2/chat")); + } + + public void testXContents() throws IOException { + var request = new CohereV2CompletionRequest( + List.of("some input"), + CohereCompletionModelTests.createModel(null, "secret", "model"), + false + ); + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + request.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(""" + {"message":"some input","model":"model","stream":false}""")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequestTests.java new file mode 100644 index 0000000000000..a7e009d63a903 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2EmbeddingsRequestTests.java @@ -0,0 +1,317 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. 
under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v2; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.InputTypeTests; +import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class CohereV2EmbeddingsRequestTests extends ESTestCase { + public void testCreateRequest() throws IOException { + var inputType = InputTypeTests.randomWithoutUnspecified(); + var request = createRequest( + List.of("abc"), + inputType, + CohereEmbeddingsModelTests.createModel( + null, + "secret", + new CohereEmbeddingsTaskSettings(inputType, CohereTruncation.START), + null, + null, + "model id", + null + ) + ); + + var httpRequest = 
request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v2/embed")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), + is(CohereUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestMap.get("texts"), is(List.of("abc"))); + MatcherAssert.assertThat(requestMap.get("embedding_types"), is(List.of("float"))); + MatcherAssert.assertThat(requestMap.get("model"), is("model id")); + MatcherAssert.assertThat(requestMap.get("truncate"), is("start")); + validateInputType(requestMap, inputType, inputType); + } + + public void testCreateRequest_WithTaskSettingsInputType() throws IOException { + var inputType = InputTypeTests.randomWithoutUnspecified(); + var request = createRequest( + List.of("abc"), + InputType.UNSPECIFIED, + CohereEmbeddingsModelTests.createModel( + "url", + "secret", + new CohereEmbeddingsTaskSettings(inputType, CohereTruncation.END), + null, + null, + "cohere model", + null + ) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + validateInputType(requestMap, inputType, null); + } + + public void testCreateRequest_InputTypeSearch_EmbeddingTypeInt8_TruncateEnd() throws IOException { + var inputType = InputTypeTests.randomWithoutUnspecified(); + var request = createRequest( + 
List.of("abc"), + inputType, + CohereEmbeddingsModelTests.createModel( + "http://localhost", + "secret", + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.END), + null, + null, + "model", + CohereEmbeddingType.INT8 + ) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("http://localhost/v2/embed")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), + is(CohereUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestMap.get("texts"), is(List.of("abc"))); + MatcherAssert.assertThat(requestMap.get("model"), is("model")); + MatcherAssert.assertThat(requestMap.get("embedding_types"), is(List.of("int8"))); + MatcherAssert.assertThat(requestMap.get("truncate"), is("end")); + validateInputType(requestMap, null, inputType); + } + + public void testCreateRequest_InputTypeSearch_EmbeddingTypeBit_TruncateEnd() throws IOException { + var inputType = InputTypeTests.randomWithoutUnspecified(); + var request = createRequest( + List.of("abc"), + inputType, + CohereEmbeddingsModelTests.createModel( + null, + "secret", + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.END), + null, + null, + "model", + CohereEmbeddingType.BIT + ) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + 
MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), + is(CohereUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestMap.get("texts"), is(List.of("abc"))); + MatcherAssert.assertThat(requestMap.get("model"), is("model")); + MatcherAssert.assertThat(requestMap.get("embedding_types"), is(List.of("binary"))); + MatcherAssert.assertThat(requestMap.get("truncate"), is("end")); + validateInputType(requestMap, null, inputType); + } + + public void testCreateRequest_TruncateNone() throws IOException { + var inputType = InputTypeTests.randomWithoutUnspecified(); + var request = createRequest( + List.of("abc"), + inputType, + CohereEmbeddingsModelTests.createModel( + null, + "secret", + new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE), + null, + null, + "cohere model", + null + ) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v2/embed")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + MatcherAssert.assertThat( + httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), + is(CohereUtils.ELASTIC_REQUEST_SOURCE) + ); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + MatcherAssert.assertThat(requestMap.get("texts"), is(List.of("abc"))); + 
MatcherAssert.assertThat(requestMap.get("embedding_types"), is(List.of("float"))); + MatcherAssert.assertThat(requestMap.get("truncate"), is("none")); + validateInputType(requestMap, null, inputType); + } + + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = createRequest( + "cohere model", + List.of("abc"), + InputType.INTERNAL_INGEST, + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.START), + CohereEmbeddingType.FLOAT + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"texts":["abc"],"model":"cohere model","input_type":"search_document","embedding_types":["float"],"truncate":"start"}""")); + } + + public void testXContent_TaskSettingsInputType_EmbeddingTypesInt8_TruncateNone() throws IOException { + var entity = createRequest( + "cohere model", + List.of("abc"), + InputType.INGEST, + new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE), + CohereEmbeddingType.INT8 + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"texts":["abc"],"model":"cohere model","input_type":"search_document","embedding_types":["int8"],"truncate":"none"}""")); + } + + public void testXContent_InternalInputType_EmbeddingTypesByte_TruncateNone() throws IOException { + var entity = createRequest( + "cohere model", + List.of("abc"), + InputType.INTERNAL_SEARCH, + new CohereEmbeddingsTaskSettings(null, CohereTruncation.NONE), + CohereEmbeddingType.BYTE + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + 
MatcherAssert.assertThat(xContentResult, CoreMatchers.is("""
+            {"texts":["abc"],"model":"cohere model","input_type":"search_query","embedding_types":["int8"],"truncate":"none"}"""));
+    }
+
+    public void testXContent_InputTypeSearch_EmbeddingTypesBinary_TruncateNone() throws IOException {
+        var entity = createRequest(
+            "cohere model",
+            List.of("abc"),
+            InputType.SEARCH,
+            new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE),
+            CohereEmbeddingType.BINARY
+        );
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        MatcherAssert.assertThat(xContentResult, CoreMatchers.is("""
+            {"texts":["abc"],"model":"cohere model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}"""));
+    }
+
+    public void testXContent_InputTypeSearch_EmbeddingTypesBit_TruncateNone() throws IOException {
+        var entity = createRequest(
+            "cohere model",
+            List.of("abc"),
+            InputType.SEARCH,
+            new CohereEmbeddingsTaskSettings(InputType.SEARCH, CohereTruncation.NONE),
+            CohereEmbeddingType.BIT
+        );
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        MatcherAssert.assertThat(xContentResult, CoreMatchers.is("""
+            {"texts":["abc"],"model":"cohere model","input_type":"search_query","embedding_types":["binary"],"truncate":"none"}"""));
+    }
+
+    /**
+     * Builds a {@link CohereV2EmbeddingsRequest} directly from an already constructed model.
+     *
+     * @param input     the texts to embed
+     * @param inputType the request-level input type
+     * @param model     the embeddings model the request is created for
+     * @return the request under test
+     */
+    public static CohereV2EmbeddingsRequest createRequest(List<String> input, InputType inputType, CohereEmbeddingsModel model) {
+        return new CohereV2EmbeddingsRequest(input, inputType, model);
+    }
+
+    /**
+     * Builds a {@link CohereV2EmbeddingsRequest} for a model created with no URL override and the API key {@code "secret"}.
+     *
+     * @param modelId       the model id written to the request
+     * @param input         the texts to embed
+     * @param inputType     the request-level input type
+     * @param taskSettings  the task settings stored on the model
+     * @param embeddingType the embedding type stored on the model
+     * @return the request under test
+     */
+    public static CohereV2EmbeddingsRequest createRequest(
+        String modelId,
+        List<String> input,
+        InputType inputType,
+        CohereEmbeddingsTaskSettings taskSettings,
+        CohereEmbeddingType embeddingType
+    ) {
+        var model = CohereEmbeddingsModelTests.createModel(null, "secret", taskSettings, null, null, modelId, embeddingType);
+        return new CohereV2EmbeddingsRequest(input, inputType, model);
+    }
+
+    /**
+     * Asserts the {@code input_type} entry of a parsed request body.
+     * A specified request-level input type is checked first; otherwise a specified task-settings input type is checked.
+     * When neither is specified, no assertion is made.
+     *
+     * @param requestMap            the request body parsed into a map
+     * @param taskSettingsInputType the input type configured in the task settings, may be {@code null}
+     * @param requestInputType      the input type passed with the request, may be {@code null}
+     */
+    private void validateInputType(Map<String, Object> requestMap, InputType taskSettingsInputType, InputType requestInputType) {
+        if (InputType.isSpecified(requestInputType)) {
+            var convertedInputType = CohereUtils.inputTypeToString(requestInputType);
+            assertThat(requestMap.get("input_type"), is(convertedInputType));
+        } else if (InputType.isSpecified(taskSettingsInputType)) {
+            var convertedInputType = CohereUtils.inputTypeToString(taskSettingsInputType);
+            assertThat(requestMap.get("input_type"), is(convertedInputType));
+        }
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequestTests.java
new file mode 100644
index 0000000000000..34cf019b6010a
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/request/v2/CohereV2RerankRequestTests.java
@@ -0,0 +1,132 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.services.cohere.request.v2; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; +import org.elasticsearch.xpack.inference.services.cohere.request.CohereUtils; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankServiceSettings; +import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankTaskSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class CohereV2RerankRequestTests extends ESTestCase { + public void testUrl() { + var request = new CohereV2RerankRequest( + "query", + List.of("abc"), + Boolean.TRUE, + 22, + createModel("model", null, new CohereRerankTaskSettings(null, null, 3)) + ); + + var httpRequest = request.createHttpRequest(); + MatcherAssert.assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + MatcherAssert.assertThat(httpPost.getURI().toString(), is("https://api.cohere.ai/v2/rerank")); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + MatcherAssert.assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); + 
MatcherAssert.assertThat( + httpPost.getLastHeader(CohereUtils.REQUEST_SOURCE_HEADER).getValue(), + is(CohereUtils.ELASTIC_REQUEST_SOURCE) + ); + } + + public void testXContent_WritesAllFields_WhenTheyAreDefined() throws IOException { + var entity = new CohereV2RerankRequest( + "query", + List.of("abc"), + Boolean.TRUE, + 22, + createModel("model", "uri", new CohereRerankTaskSettings(null, null, 3)) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"model":"model","query":"query","documents":["abc"],"return_documents":true,"top_n":22,"max_chunks_per_doc":3}""")); + } + + public void testXContent_WritesMinimalFields() throws IOException { + var entity = new CohereV2RerankRequest( + "query", + List.of("abc"), + null, + null, + createModel("model", "uri", new CohereRerankTaskSettings(null, null, null)) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"model":"model","query":"query","documents":["abc"]}""")); + } + + public void testXContent_PrefersRootLevelReturnDocumentsAndTopN() throws IOException { + var entity = new CohereV2RerankRequest( + "query", + List.of("abc"), + Boolean.FALSE, + 99, + createModel("model", "uri", new CohereRerankTaskSettings(33, Boolean.TRUE, null)) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"model":"model","query":"query","documents":["abc"],"return_documents":false,"top_n":99}""")); + } + + public void testXContent_UsesTaskSettingsIfNoRootOptionsDefined() throws 
IOException { + var entity = new CohereV2RerankRequest( + "query", + List.of("abc"), + null, + null, + createModel("model", "uri", new CohereRerankTaskSettings(33, Boolean.TRUE, null)) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + MatcherAssert.assertThat(xContentResult, CoreMatchers.is(""" + {"model":"model","query":"query","documents":["abc"],"return_documents":true,"top_n":33}""")); + } + + private CohereRerankModel createModel(String modelId, String uri, CohereRerankTaskSettings taskSettings) { + return new CohereRerankModel( + "inference_id", + new CohereRerankServiceSettings(uri, modelId, null, CohereServiceSettings.CohereApiVersion.V2), + taskSettings, + new DefaultSecretSettings(new SecureString("secret".toCharArray())) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettingsTests.java index e3401b74017f2..47e91590a22e1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/rerank/CohereRerankServiceSettingsTests.java @@ -36,7 +36,8 @@ public static CohereRerankServiceSettings createRandom(@Nullable RateLimitSettin return new CohereRerankServiceSettings( randomFrom(new String[] { null, Strings.format("http://%s.com", randomAlphaOfLength(8)) }), randomFrom(new String[] { null, randomAlphaOfLength(10) }), - rateLimitSettings + rateLimitSettings, + CohereServiceSettings.CohereApiVersion.V2 ); } @@ -44,7 +45,7 @@ public void testToXContent_WritesAllValues() throws IOException { var url = "http://www.abc.com"; var model = 
"model"; - var serviceSettings = new CohereRerankServiceSettings(url, model, null); + var serviceSettings = new CohereRerankServiceSettings(url, model, null, CohereServiceSettings.CohereApiVersion.V2); XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); serviceSettings.toXContent(builder, null); @@ -56,7 +57,8 @@ public void testToXContent_WritesAllValues() throws IOException { "model_id":"model", "rate_limit": { "requests_per_minute": 10000 - } + }, + "api_version": "V2" } """)); } @@ -80,7 +82,19 @@ protected CohereRerankServiceSettings mutateInstance(CohereRerankServiceSettings protected CohereRerankServiceSettings mutateInstanceForVersion(CohereRerankServiceSettings instance, TransportVersion version) { if (version.before(TransportVersions.V_8_15_0)) { // We always default to the same rate limit settings, if a node is on a version before rate limits were introduced - return new CohereRerankServiceSettings(instance.uri(), instance.modelId(), CohereServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS); + return new CohereRerankServiceSettings( + instance.uri(), + instance.modelId(), + CohereServiceSettings.DEFAULT_RATE_LIMIT_SETTINGS, + CohereServiceSettings.CohereApiVersion.V1 + ); + } else if (version.before(TransportVersions.ML_INFERENCE_COHERE_API_VERSION_8_19)) { + return new CohereRerankServiceSettings( + instance.uri(), + instance.modelId(), + instance.rateLimitSettings(), + CohereServiceSettings.CohereApiVersion.V1 + ); } return instance; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java index f5d700016bf81..b14cfcd14ec43 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java @@ -345,7 +345,11 @@ public void testSend_FailsFromInvalidResponseFormat_ForRerankAction() throws IOE var action = actionCreator.create(model); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new QueryAndDocsInputs("popular name", List.of("Luke")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new QueryAndDocsInputs("popular name", List.of("Luke"), null, null, false), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat(