Skip to content

Commit 3b51dd5

Browse files
authored
[EIS] Dense Text Embedding task type integration (#129847)
1 parent 0e23624 commit 3b51dd5

File tree

26 files changed

+1860
-266
lines changed

26 files changed

+1860
-266
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ static TransportVersion def(int id) {
206206
public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER_8_19 = def(8_841_0_56);
207207
public static final TransportVersion ML_INFERENCE_SAGEMAKER_ELASTIC_8_19 = def(8_841_0_57);
208208
public static final TransportVersion SPARSE_VECTOR_FIELD_PRUNING_OPTIONS_8_19 = def(8_841_0_58);
209+
public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED_8_19 = def(8_841_0_59);
209210
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
210211
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
211212
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
@@ -318,6 +319,7 @@ static TransportVersion def(int id) {
318319
public static final TransportVersion ML_INFERENCE_SAGEMAKER_ELASTIC = def(9_106_0_00);
319320
public static final TransportVersion SPARSE_VECTOR_FIELD_PRUNING_OPTIONS = def(9_107_0_00);
320321
public static final TransportVersion CLUSTER_STATE_PROJECTS_SETTINGS = def(9_108_0_00);
322+
public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED = def(9_109_0_00);
321323

322324
/*
323325
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetModelsWithElasticInferenceServiceIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public void testGetDefaultEndpoints() throws IOException {
3333
var allModels = getAllModels();
3434
var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);
3535

36-
assertThat(allModels, hasSize(6));
36+
assertThat(allModels, hasSize(7));
3737
assertThat(chatCompletionModels, hasSize(1));
3838

3939
for (var model : chatCompletionModels) {
@@ -42,6 +42,7 @@ public void testGetDefaultEndpoints() throws IOException {
4242

4343
assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION);
4444
assertInferenceIdTaskType(allModels, ".elser-v2-elastic", TaskType.SPARSE_EMBEDDING);
45+
assertInferenceIdTaskType(allModels, ".multilingual-embed-v1-elastic", TaskType.TEXT_EMBEDDING);
4546
assertInferenceIdTaskType(allModels, ".rerank-v1-elastic", TaskType.RERANK);
4647
}
4748

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceGetServicesIT.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
2222
import static org.hamcrest.Matchers.containsInAnyOrder;
23+
import static org.hamcrest.Matchers.equalTo;
2324

2425
public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
2526

@@ -76,16 +77,21 @@ private Iterable<String> providers(List<Object> services) {
7677
}
7778

7879
public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
80+
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
81+
assertThat(services.size(), equalTo(18));
82+
7983
assertThat(
8084
providersFor(TaskType.TEXT_EMBEDDING),
8185
containsInAnyOrder(
8286
List.of(
8387
"alibabacloud-ai-search",
8488
"amazonbedrock",
89+
"amazon_sagemaker",
8590
"azureaistudio",
8691
"azureopenai",
8792
"cohere",
8893
"custom",
94+
"elastic",
8995
"elasticsearch",
9096
"googleaistudio",
9197
"googlevertexai",
@@ -95,8 +101,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
95101
"openai",
96102
"text_embedding_test_service",
97103
"voyageai",
98-
"watsonxai",
99-
"amazon_sagemaker"
104+
"watsonxai"
100105
).toArray()
101106
)
102107
);

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/MockElasticInferenceServiceAuthorizationServer.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ public void enqueueAuthorizeAllModelsResponse() {
4343
"task_types": ["embed/text/sparse"]
4444
},
4545
{
46+
"model_name": "multilingual-embed-v1",
47+
"task_types": ["embed/text/dense"]
48+
},
49+
{
4650
"model_name": "rerank-v1",
4751
"task_types": ["rerank/text/text-similarity"]
4852
}

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/InferenceRevokeDefaultEndpointsIT.java

Lines changed: 66 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import org.elasticsearch.action.support.PlainActionFuture;
1212
import org.elasticsearch.common.settings.Settings;
1313
import org.elasticsearch.core.TimeValue;
14+
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
1415
import org.elasticsearch.inference.InferenceService;
1516
import org.elasticsearch.inference.MinimalServiceSettings;
1617
import org.elasticsearch.inference.Model;
@@ -43,6 +44,7 @@
4344
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
4445
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
4546
import static org.hamcrest.CoreMatchers.is;
47+
import static org.hamcrest.Matchers.containsInAnyOrder;
4648
import static org.mockito.Mockito.mock;
4749

4850
public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
@@ -190,13 +192,17 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
190192
String responseJson = """
191193
{
192194
"models": [
195+
{
196+
"model_name": "elser-v2",
197+
"task_types": ["embed/text/sparse"]
198+
},
193199
{
194200
"model_name": "rainbow-sprinkles",
195201
"task_types": ["chat"]
196202
},
197203
{
198-
"model_name": "elser-v2",
199-
"task_types": ["embed/text/sparse"]
204+
"model_name": "multilingual-embed-v1",
205+
"task_types": ["embed/text/dense"]
200206
},
201207
{
202208
"model_name": "rerank-v1",
@@ -214,36 +220,48 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
214220
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
215221
assertThat(
216222
service.defaultConfigIds(),
217-
is(
218-
List.of(
219-
new InferenceService.DefaultConfigId(
220-
".elser-v2-elastic",
221-
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
222-
service
223-
),
224-
new InferenceService.DefaultConfigId(
225-
".rainbow-sprinkles-elastic",
226-
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
227-
service
223+
containsInAnyOrder(
224+
new InferenceService.DefaultConfigId(
225+
".elser-v2-elastic",
226+
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
227+
service
228+
),
229+
new InferenceService.DefaultConfigId(
230+
".rainbow-sprinkles-elastic",
231+
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
232+
service
233+
),
234+
new InferenceService.DefaultConfigId(
235+
".multilingual-embed-v1-elastic",
236+
MinimalServiceSettings.textEmbedding(
237+
ElasticInferenceService.NAME,
238+
ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
239+
ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
240+
DenseVectorFieldMapper.ElementType.FLOAT
228241
),
229-
new InferenceService.DefaultConfigId(
230-
".rerank-v1-elastic",
231-
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
232-
service
233-
)
242+
service
243+
),
244+
new InferenceService.DefaultConfigId(
245+
".rerank-v1-elastic",
246+
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
247+
service
234248
)
235249
)
236250
);
237251
assertThat(
238252
service.supportedTaskTypes(),
239-
is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
253+
is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.TEXT_EMBEDDING))
240254
);
241255

242256
PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
243257
service.defaultConfigs(listener);
244258
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
245-
assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
246-
assertThat(listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic"));
259+
assertThat(
260+
listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(),
261+
is(".multilingual-embed-v1-elastic")
262+
);
263+
assertThat(listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
264+
assertThat(listener.actionGet(TIMEOUT).get(3).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic"));
247265

248266
var getModelListener = new PlainActionFuture<UnparsedModel>();
249267
// persists the default endpoints
@@ -265,6 +283,10 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
265283
{
266284
"model_name": "rerank-v1",
267285
"task_types": ["rerank/text/text-similarity"]
286+
},
287+
{
288+
"model_name": "multilingual-embed-v1",
289+
"task_types": ["embed/text/dense"]
268290
}
269291
]
270292
}
@@ -278,22 +300,33 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
278300
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
279301
assertThat(
280302
service.defaultConfigIds(),
281-
is(
282-
List.of(
283-
new InferenceService.DefaultConfigId(
284-
".elser-v2-elastic",
285-
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
286-
service
303+
containsInAnyOrder(
304+
new InferenceService.DefaultConfigId(
305+
".elser-v2-elastic",
306+
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
307+
service
308+
),
309+
new InferenceService.DefaultConfigId(
310+
".multilingual-embed-v1-elastic",
311+
MinimalServiceSettings.textEmbedding(
312+
ElasticInferenceService.NAME,
313+
ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
314+
ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
315+
DenseVectorFieldMapper.ElementType.FLOAT
287316
),
288-
new InferenceService.DefaultConfigId(
289-
".rerank-v1-elastic",
290-
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
291-
service
292-
)
317+
service
318+
),
319+
new InferenceService.DefaultConfigId(
320+
".rerank-v1-elastic",
321+
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
322+
service
293323
)
294324
)
295325
);
296-
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.RERANK)));
326+
assertThat(
327+
service.supportedTaskTypes(),
328+
is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
329+
);
297330

298331
var getModelListener = new PlainActionFuture<UnparsedModel>();
299332
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.response.elastic;
9+
10+
import org.elasticsearch.common.xcontent.XContentParserUtils;
11+
import org.elasticsearch.xcontent.ConstructingObjectParser;
12+
import org.elasticsearch.xcontent.ParseField;
13+
import org.elasticsearch.xcontent.XContentFactory;
14+
import org.elasticsearch.xcontent.XContentParserConfiguration;
15+
import org.elasticsearch.xcontent.XContentType;
16+
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
17+
import org.elasticsearch.xpack.inference.external.http.HttpResult;
18+
import org.elasticsearch.xpack.inference.external.request.Request;
19+
20+
import java.io.IOException;
21+
import java.util.List;
22+
23+
import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
24+
25+
public class ElasticInferenceServiceDenseTextEmbeddingsResponseEntity {
26+
27+
/**
28+
* Parses the Elastic Inference Service Dense Text Embeddings response.
29+
*
30+
* For a request like:
31+
*
32+
* <pre>
33+
* <code>
34+
* {
35+
* "inputs": ["Embed this text", "Embed this text, too"]
36+
* }
37+
* </code>
38+
* </pre>
39+
*
40+
* The response would look like:
41+
*
42+
* <pre>
43+
* <code>
44+
* {
45+
* "data": [
46+
* [
47+
* 2.1259406,
48+
* 1.7073475,
49+
* 0.9020516
50+
* ],
51+
* (...)
52+
* ],
53+
* "meta": {
54+
* "usage": {...}
55+
* }
56+
* }
57+
* </code>
58+
* </pre>
59+
*/
60+
public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
61+
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
62+
return EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults();
63+
}
64+
}
65+
66+
public record EmbeddingFloatResult(List<EmbeddingFloatResultEntry> embeddingResults) {
67+
@SuppressWarnings("unchecked")
68+
public static final ConstructingObjectParser<EmbeddingFloatResult, Void> PARSER = new ConstructingObjectParser<>(
69+
EmbeddingFloatResult.class.getSimpleName(),
70+
true,
71+
args -> new EmbeddingFloatResult((List<EmbeddingFloatResultEntry>) args[0])
72+
);
73+
74+
static {
75+
// Custom field declaration to handle array of arrays format
76+
PARSER.declareField(constructorArg(), (parser, context) -> {
77+
return XContentParserUtils.parseList(parser, (p, index) -> {
78+
List<Float> embedding = XContentParserUtils.parseList(p, (innerParser, innerIndex) -> innerParser.floatValue());
79+
return EmbeddingFloatResultEntry.fromFloatArray(embedding);
80+
});
81+
}, new ParseField("data"), org.elasticsearch.xcontent.ObjectParser.ValueType.OBJECT_ARRAY);
82+
}
83+
84+
public TextEmbeddingFloatResults toTextEmbeddingFloatResults() {
85+
return new TextEmbeddingFloatResults(
86+
embeddingResults.stream().map(entry -> TextEmbeddingFloatResults.Embedding.of(entry.embedding)).toList()
87+
);
88+
}
89+
}
90+
91+
/**
92+
* Represents a single embedding entry in the response.
93+
* For the Elastic Inference Service, each entry is just an array of floats (no wrapper object).
94+
* This is a simpler wrapper that just holds the float array.
95+
*/
96+
public record EmbeddingFloatResultEntry(List<Float> embedding) {
97+
public static EmbeddingFloatResultEntry fromFloatArray(List<Float> floats) {
98+
return new EmbeddingFloatResultEntry(floats);
99+
}
100+
}
101+
102+
private ElasticInferenceServiceDenseTextEmbeddingsResponseEntity() {}
103+
}

0 commit comments

Comments
 (0)