diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java index 8439dfb6af3ca..42c1237cdcdb4 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java @@ -33,7 +33,7 @@ public void testGetDefaultEndpoints() throws IOException { var allModels = getAllModels(); var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION); - assertThat(allModels, hasSize(5)); + assertThat(allModels, hasSize(6)); assertThat(chatCompletionModels, hasSize(1)); for (var model : chatCompletionModels) { @@ -42,6 +42,7 @@ public void testGetDefaultEndpoints() throws IOException { assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION); assertInferenceIdTaskType(allModels, ".elser-v2-elastic", TaskType.SPARSE_EMBEDDING); + assertInferenceIdTaskType(allModels, ".rerank-v1-elastic", TaskType.RERANK); } private static void assertInferenceIdTaskType(List> models, String inferenceId, TaskType taskType) { diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java index ecf89dff104a0..ef1e1ce769e60 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java @@ -111,7 +111,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException { public void testGetServicesWithRerankTaskType() throws IOException { List services = getServices(TaskType.RERANK); - assertThat(services.size(), equalTo(9)); + assertThat(services.size(), equalTo(10)); var providers = providers(services); @@ -127,7 +127,8 @@ public void testGetServicesWithRerankTaskType() throws IOException { "jinaai", "test_reranking_service", "voyageai", - "hugging_face" + "hugging_face", + "elastic" ).toArray() ) ); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java index cf798d5c94364..4bdf9aa40b2c8 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java @@ -41,6 +41,10 @@ public void enqueueAuthorizeAllModelsResponse() { { "model_name": "elser-v2", "task_types": ["embed/text/sparse"] + }, + { + "model_name": "rerank-v1", + "task_types": ["rerank/text/text-similarity"] } ] } diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java index 86c1b549d9de5..3f6c66151cf4d 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java @@ -197,6 +197,10 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA { "model_name": "elser-v2", "task_types": ["embed/text/sparse"] + }, + { + "model_name": "rerank-v1", + "task_types": ["rerank/text/text-similarity"] } ] } @@ -221,16 +225,25 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA ".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), service + ), + new InferenceService.DefaultConfigId( + ".rerank-v1-elastic", + MinimalServiceSettings.rerank(ElasticInferenceService.NAME), + service ) ) ) ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING))); + assertThat( + service.supportedTaskTypes(), + is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)) + ); PlainActionFuture> listener = new PlainActionFuture<>(); service.defaultConfigs(listener); assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic")); assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); + assertThat(listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic")); var getModelListener = new PlainActionFuture(); // persists the default endpoints @@ -248,6 +261,10 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA { "model_name": "elser-v2", "task_types": ["embed/text/sparse"] + }, + { + "model_name": "rerank-v1", + "task_types": ["rerank/text/text-similarity"] } ] } @@ -267,11 +284,16 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA ".elser-v2-elastic", MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME), service + ), + new InferenceService.DefaultConfigId( + ".rerank-v1-elastic", + MinimalServiceSettings.rerank(ElasticInferenceService.NAME), + service ) ) ) ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING))); + assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.RERANK))); var getModelListener = new PlainActionFuture(); modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index d3ef7c97489fc..8da1229a528ea 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -52,6 +52,7 @@ import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel; import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel; +import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel; import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; @@ -95,6 +96,10 @@ public class ElasticInferenceService extends SenderService { static final String DEFAULT_ELSER_MODEL_ID_V2 = "elser-v2"; static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId(DEFAULT_ELSER_MODEL_ID_V2); + // rerank-v1 + static final String DEFAULT_RERANK_MODEL_ID_V1 = "rerank-v1"; + static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_RERANK_MODEL_ID_V1); + /** * The task types that the {@link InferenceAction.Request} can accept. */ @@ -159,6 +164,19 @@ private static Map initDefaultEndpoints( elasticInferenceServiceComponents ), MinimalServiceSettings.sparseEmbedding(NAME) + ), + DEFAULT_RERANK_MODEL_ID_V1, + new DefaultModelConfig( + new ElasticInferenceServiceRerankModel( + DEFAULT_RERANK_ENDPOINT_ID_V1, + TaskType.RERANK, + NAME, + new ElasticInferenceServiceRerankServiceSettings(DEFAULT_RERANK_MODEL_ID_V1, null), + EmptyTaskSettings.INSTANCE, + EmptySecretSettings.INSTANCE, + elasticInferenceServiceComponents + ), + MinimalServiceSettings.rerank(NAME) ) ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java index 7e592406a718a..38e71d74b1716 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java @@ -87,7 +87,7 @@ public URI uri() { private URI createUri() throws ElasticsearchStatusException { try { // TODO, consider transforming the base URL into a URI for better error handling. - return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/rerank"); + return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/rerank/text/text-similarity"); } catch (URISyntaxException e) { throw new ElasticsearchStatusException( "Failed to create URI for service [" diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java index 63e60fa83f8b0..bb34ac202bd59 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java @@ -43,7 +43,9 @@ public class ElasticInferenceServiceAuthorizationResponseEntity implements Infer "embed/text/sparse", TaskType.SPARSE_EMBEDDING, "chat", - TaskType.CHAT_COMPLETION + TaskType.CHAT_COMPLETION, + "rerank/text/text-similarity", + TaskType.RERANK ); @SuppressWarnings("unchecked") diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index b3dab72c1410d..71a073c02e02b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -1294,6 +1294,10 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect() { "model_name": "elser-v2", "task_types": ["embed/text/sparse"] + }, + { + "model_name": "rerank-v1", + "task_types": ["rerank/text/text-similarity"] } ] } @@ -1319,18 +1323,25 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect() ".rainbow-sprinkles-elastic", MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME), service + ), + new InferenceService.DefaultConfigId( + ".rerank-v1-elastic", + MinimalServiceSettings.rerank(ElasticInferenceService.NAME), + service ) ) ) ); - assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING))); + assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))); PlainActionFuture> listener = new PlainActionFuture<>(); service.defaultConfigs(listener); var models = listener.actionGet(TIMEOUT); - assertThat(models.size(), is(2)); + assertThat(models.size(), is(3)); assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic")); assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic")); + assertThat(models.get(2).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic")); + } }