Skip to content

Commit cef717c

Browse files
add default inference endpoint for Elastic Inference Service rerank (#129681)
* add Elastic Inference Service rerank default inference endpoint
* [CI] Auto commit changes from spotless
* fix integ tests
* update mock Elastic Inference Service authorization response
* fix rerank service test

---------

Co-authored-by: elasticsearchmachine <[email protected]>
1 parent 648c5ad commit cef717c

File tree

8 files changed

+68
-9
lines changed

8 files changed

+68
-9
lines changed

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public void testGetDefaultEndpoints() throws IOException {
3333
var allModels = getAllModels();
3434
var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);
3535

36-
assertThat(allModels, hasSize(5));
36+
assertThat(allModels, hasSize(6));
3737
assertThat(chatCompletionModels, hasSize(1));
3838

3939
for (var model : chatCompletionModels) {
@@ -42,6 +42,7 @@ public void testGetDefaultEndpoints() throws IOException {
4242

4343
assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION);
4444
assertInferenceIdTaskType(allModels, ".elser-v2-elastic", TaskType.SPARSE_EMBEDDING);
45+
assertInferenceIdTaskType(allModels, ".rerank-v1-elastic", TaskType.RERANK);
4546
}
4647

4748
private static void assertInferenceIdTaskType(List<Map<String, Object>> models, String inferenceId, TaskType taskType) {

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
111111

112112
public void testGetServicesWithRerankTaskType() throws IOException {
113113
List<Object> services = getServices(TaskType.RERANK);
114-
assertThat(services.size(), equalTo(9));
114+
assertThat(services.size(), equalTo(10));
115115

116116
var providers = providers(services);
117117

@@ -127,7 +127,8 @@ public void testGetServicesWithRerankTaskType() throws IOException {
127127
"jinaai",
128128
"test_reranking_service",
129129
"voyageai",
130-
"hugging_face"
130+
"hugging_face",
131+
"elastic"
131132
).toArray()
132133
)
133134
);

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ public void enqueueAuthorizeAllModelsResponse() {
4141
{
4242
"model_name": "elser-v2",
4343
"task_types": ["embed/text/sparse"]
44+
},
45+
{
46+
"model_name": "rerank-v1",
47+
"task_types": ["rerank/text/text-similarity"]
4448
}
4549
]
4650
}

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,10 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
197197
{
198198
"model_name": "elser-v2",
199199
"task_types": ["embed/text/sparse"]
200+
},
201+
{
202+
"model_name": "rerank-v1",
203+
"task_types": ["rerank/text/text-similarity"]
200204
}
201205
]
202206
}
@@ -221,16 +225,25 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
221225
".rainbow-sprinkles-elastic",
222226
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
223227
service
228+
),
229+
new InferenceService.DefaultConfigId(
230+
".rerank-v1-elastic",
231+
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
232+
service
224233
)
225234
)
226235
)
227236
);
228-
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));
237+
assertThat(
238+
service.supportedTaskTypes(),
239+
is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
240+
);
229241

230242
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
231243
service.defaultConfigs(listener);
232244
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
233245
assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
246+
assertThat(listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic"));
234247

235248
var getModelListener = new PlainActionFuture<UnparsedModel>();
236249
// persists the default endpoints
@@ -248,6 +261,10 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
248261
{
249262
"model_name": "elser-v2",
250263
"task_types": ["embed/text/sparse"]
264+
},
265+
{
266+
"model_name": "rerank-v1",
267+
"task_types": ["rerank/text/text-similarity"]
251268
}
252269
]
253270
}
@@ -267,11 +284,16 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
267284
".elser-v2-elastic",
268285
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
269286
service
287+
),
288+
new InferenceService.DefaultConfigId(
289+
".rerank-v1-elastic",
290+
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
291+
service
270292
)
271293
)
272294
)
273295
);
274-
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING)));
296+
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.RERANK)));
275297

276298
var getModelListener = new PlainActionFuture<UnparsedModel>();
277299
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionModel;
5353
import org.elasticsearch.xpack.inference.services.elastic.completion.ElasticInferenceServiceCompletionServiceSettings;
5454
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankModel;
55+
import org.elasticsearch.xpack.inference.services.elastic.rerank.ElasticInferenceServiceRerankServiceSettings;
5556
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsModel;
5657
import org.elasticsearch.xpack.inference.services.elastic.sparseembeddings.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
5758
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
@@ -95,6 +96,10 @@ public class ElasticInferenceService extends SenderService {
9596
static final String DEFAULT_ELSER_MODEL_ID_V2 = "elser-v2";
9697
static final String DEFAULT_ELSER_ENDPOINT_ID_V2 = defaultEndpointId(DEFAULT_ELSER_MODEL_ID_V2);
9798

99+
// rerank-v1
100+
static final String DEFAULT_RERANK_MODEL_ID_V1 = "rerank-v1";
101+
static final String DEFAULT_RERANK_ENDPOINT_ID_V1 = defaultEndpointId(DEFAULT_RERANK_MODEL_ID_V1);
102+
98103
/**
99104
* The task types that the {@link InferenceAction.Request} can accept.
100105
*/
@@ -159,6 +164,19 @@ private static Map<String, DefaultModelConfig> initDefaultEndpoints(
159164
elasticInferenceServiceComponents
160165
),
161166
MinimalServiceSettings.sparseEmbedding(NAME)
167+
),
168+
DEFAULT_RERANK_MODEL_ID_V1,
169+
new DefaultModelConfig(
170+
new ElasticInferenceServiceRerankModel(
171+
DEFAULT_RERANK_ENDPOINT_ID_V1,
172+
TaskType.RERANK,
173+
NAME,
174+
new ElasticInferenceServiceRerankServiceSettings(DEFAULT_RERANK_MODEL_ID_V1, null),
175+
EmptyTaskSettings.INSTANCE,
176+
EmptySecretSettings.INSTANCE,
177+
elasticInferenceServiceComponents
178+
),
179+
MinimalServiceSettings.rerank(NAME)
162180
)
163181
);
164182
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/rerank/ElasticInferenceServiceRerankModel.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public URI uri() {
8787
private URI createUri() throws ElasticsearchStatusException {
8888
try {
8989
// TODO, consider transforming the base URL into a URI for better error handling.
90-
return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/rerank");
90+
return new URI(elasticInferenceServiceComponents().elasticInferenceServiceUrl() + "/api/v1/rerank/text/text-similarity");
9191
} catch (URISyntaxException e) {
9292
throw new ElasticsearchStatusException(
9393
"Failed to create URI for service ["

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/response/ElasticInferenceServiceAuthorizationResponseEntity.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ public class ElasticInferenceServiceAuthorizationResponseEntity implements Infer
4343
"embed/text/sparse",
4444
TaskType.SPARSE_EMBEDDING,
4545
"chat",
46-
TaskType.CHAT_COMPLETION
46+
TaskType.CHAT_COMPLETION,
47+
"rerank/text/text-similarity",
48+
TaskType.RERANK
4749
);
4850

4951
@SuppressWarnings("unchecked")

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1294,6 +1294,10 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect()
12941294
{
12951295
"model_name": "elser-v2",
12961296
"task_types": ["embed/text/sparse"]
1297+
},
1298+
{
1299+
"model_name": "rerank-v1",
1300+
"task_types": ["rerank/text/text-similarity"]
12971301
}
12981302
]
12991303
}
@@ -1319,18 +1323,25 @@ public void testDefaultConfigs_Returns_DefaultEndpoints_WhenTaskTypeIsCorrect()
13191323
".rainbow-sprinkles-elastic",
13201324
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
13211325
service
1326+
),
1327+
new InferenceService.DefaultConfigId(
1328+
".rerank-v1-elastic",
1329+
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
1330+
service
13221331
)
13231332
)
13241333
)
13251334
);
1326-
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING)));
1335+
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK)));
13271336

13281337
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
13291338
service.defaultConfigs(listener);
13301339
var models = listener.actionGet(TIMEOUT);
1331-
assertThat(models.size(), is(2));
1340+
assertThat(models.size(), is(3));
13321341
assertThat(models.get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
13331342
assertThat(models.get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
1343+
assertThat(models.get(2).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic"));
1344+
13341345
}
13351346
}
13361347

0 commit comments

Comments
 (0)