Skip to content

[EIS] Dense Text Embedding task type integration #129847

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 31 commits into from
Jun 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
f054dca
Add working dense text embeddings integration with default endpoint. …
timgrein Jun 23, 2025
9ca7369
Merge branch 'main' into eis-text-embedding-task-type
timgrein Jun 23, 2025
6584dab
Fix merge conflicts, compilation errors and test failures
timgrein Jun 23, 2025
9d47176
Spotless apply
timgrein Jun 23, 2025
3e8c70a
Add ElasticInferenceServiceDenseTextEmbeddingsRequestTests
timgrein Jun 23, 2025
23e7595
Add ElasticInferenceServiceDenseTextEmbeddingsRequestEntityTests
timgrein Jun 23, 2025
5af7516
Add "-v1" to multilingual-embed
timgrein Jun 23, 2025
fddfd9d
Add ElasticInferenceServiceDenseTextEmbeddingsServiceSettingsTests.java
timgrein Jun 23, 2025
9b48dfb
Add dense text embedding test cases to ElasticInferenceServiceActionC…
timgrein Jun 23, 2025
dbdadbe
[CI] Auto commit changes from spotless
elasticsearchmachine Jun 23, 2025
e2f872e
Add ElasticInferenceServiceDenseTextEmbeddingsResponseEntityTests
timgrein Jun 23, 2025
485dd89
Merge remote-tracking branch 'origin/eis-text-embedding-task-type' in…
timgrein Jun 23, 2025
172070a
Merge branch 'main' into eis-text-embedding-task-type
timgrein Jun 23, 2025
6a35870
Fix compilation error after resolving merge conflict and spotlessAppl
timgrein Jun 23, 2025
a8b604b
Merge branch 'main' into eis-text-embedding-task-type
brendan-jugan-elastic Jun 23, 2025
3b486b7
remove dimensions_set_by_user
brendan-jugan-elastic Jun 23, 2025
6ffcc22
Merge branch 'main' into eis-text-embedding-task-type
brendan-jugan-elastic Jun 23, 2025
3489a09
[CI] Auto commit changes from spotless
elasticsearchmachine Jun 23, 2025
fb5dbc0
fix checkstyle
brendan-jugan-elastic Jun 23, 2025
1dcbcab
fix checkstyle
brendan-jugan-elastic Jun 23, 2025
dc6f320
[CI] Auto commit changes from spotless
elasticsearchmachine Jun 23, 2025
087d4e5
use ConstructingObjectParser for response parsing
brendan-jugan-elastic Jun 24, 2025
cd3e116
[CI] Auto commit changes from spotless
elasticsearchmachine Jun 24, 2025
aa24341
Merge branch 'main' into eis-text-embedding-task-type
timgrein Jun 24, 2025
7269c51
Some cleanup (removing unused vars etc.)
timgrein Jun 24, 2025
220e208
Fix integration test
timgrein Jun 24, 2025
27ca440
Do not set usage context, if it's null
timgrein Jun 24, 2025
b7d10b8
Pass through chunking settings and provide default for default endpoint
timgrein Jun 24, 2025
3164c6c
Merge branch 'main' into eis-text-embedding-task-type
timgrein Jun 24, 2025
fc11815
After merge conflict resolution clean-up
timgrein Jun 24, 2025
59f84a9
Merge branch 'main' into eis-text-embedding-task-type
timgrein Jun 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ static TransportVersion def(int id) {
public static final TransportVersion RANDOM_SAMPLER_QUERY_BUILDER_8_19 = def(8_841_0_56);
public static final TransportVersion ML_INFERENCE_SAGEMAKER_ELASTIC_8_19 = def(8_841_0_57);
public static final TransportVersion SPARSE_VECTOR_FIELD_PRUNING_OPTIONS_8_19 = def(8_841_0_58);
public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED_8_19 = def(8_841_0_59);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -315,6 +316,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ML_INFERENCE_SAGEMAKER_ELASTIC = def(9_106_0_00);
public static final TransportVersion SPARSE_VECTOR_FIELD_PRUNING_OPTIONS = def(9_107_0_00);
public static final TransportVersion CLUSTER_STATE_PROJECTS_SETTINGS = def(9_108_0_00);
public static final TransportVersion ML_INFERENCE_ELASTIC_DENSE_TEXT_EMBEDDINGS_ADDED = def(9_109_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public void testGetDefaultEndpoints() throws IOException {
var allModels = getAllModels();
var chatCompletionModels = getModels("_all", TaskType.CHAT_COMPLETION);

assertThat(allModels, hasSize(6));
assertThat(allModels, hasSize(7));
assertThat(chatCompletionModels, hasSize(1));

for (var model : chatCompletionModels) {
Expand All @@ -42,6 +42,7 @@ public void testGetDefaultEndpoints() throws IOException {

assertInferenceIdTaskType(allModels, ".rainbow-sprinkles-elastic", TaskType.CHAT_COMPLETION);
assertInferenceIdTaskType(allModels, ".elser-v2-elastic", TaskType.SPARSE_EMBEDDING);
assertInferenceIdTaskType(allModels, ".multilingual-embed-v1-elastic", TaskType.TEXT_EMBEDDING);
assertInferenceIdTaskType(allModels, ".rerank-v1-elastic", TaskType.RERANK);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.assertStatusOkOrCreated;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.equalTo;

public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {

Expand Down Expand Up @@ -76,16 +77,21 @@ private Iterable<String> providers(List<Object> services) {
}

public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
List<Object> services = getServices(TaskType.TEXT_EMBEDDING);
assertThat(services.size(), equalTo(18));

assertThat(
providersFor(TaskType.TEXT_EMBEDDING),
containsInAnyOrder(
List.of(
"alibabacloud-ai-search",
"amazonbedrock",
"amazon_sagemaker",
"azureaistudio",
"azureopenai",
"cohere",
"custom",
"elastic",
"elasticsearch",
"googleaistudio",
"googlevertexai",
Expand All @@ -95,8 +101,7 @@ public void testGetServicesWithTextEmbeddingTaskType() throws IOException {
"openai",
"text_embedding_test_service",
"voyageai",
"watsonxai",
"amazon_sagemaker"
"watsonxai"
).toArray()
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ public void enqueueAuthorizeAllModelsResponse() {
"task_types": ["embed/text/sparse"]
},
{
"model_name": "multilingual-embed-v1",
"task_types": ["embed/text/dense"]
},
{
"model_name": "rerank-v1",
"task_types": ["rerank/text/text-similarity"]
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.MinimalServiceSettings;
import org.elasticsearch.inference.Model;
Expand Down Expand Up @@ -43,6 +44,7 @@
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.mockito.Mockito.mock;

public class InferenceRevokeDefaultEndpointsIT extends ESSingleNodeTestCase {
Expand Down Expand Up @@ -190,13 +192,17 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
String responseJson = """
{
"models": [
{
"model_name": "elser-v2",
"task_types": ["embed/text/sparse"]
},
{
"model_name": "rainbow-sprinkles",
"task_types": ["chat"]
},
{
"model_name": "elser-v2",
"task_types": ["embed/text/sparse"]
"model_name": "multilingual-embed-v1",
"task_types": ["embed/text/dense"]
},
{
"model_name": "rerank-v1",
Expand All @@ -214,36 +220,48 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
assertThat(service.supportedStreamingTasks(), is(EnumSet.of(TaskType.CHAT_COMPLETION)));
assertThat(
service.defaultConfigIds(),
is(
List.of(
new InferenceService.DefaultConfigId(
".elser-v2-elastic",
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
service
),
new InferenceService.DefaultConfigId(
".rainbow-sprinkles-elastic",
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
service
containsInAnyOrder(
new InferenceService.DefaultConfigId(
".elser-v2-elastic",
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
service
),
new InferenceService.DefaultConfigId(
".rainbow-sprinkles-elastic",
MinimalServiceSettings.chatCompletion(ElasticInferenceService.NAME),
service
),
new InferenceService.DefaultConfigId(
".multilingual-embed-v1-elastic",
MinimalServiceSettings.textEmbedding(
ElasticInferenceService.NAME,
ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
DenseVectorFieldMapper.ElementType.FLOAT
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a blocker, but can you explain why the MinimalServiceSettings differ from other task types?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's just about the different purposes models/tasks:

  • Dense vector embeddings can have different element types (typically float, but they can also be quantized to bit vectors or int vectors for example) , therefore we need to specify the ElementType. Some models also allow you to specify a target number of dimensions (f.e. when using Matryoshka embeddings, therefore we need to specify the number of dimensions. Also vector embeddings can be compared using different similarity measures, therefore we need to specify the similarity measure.
  • A reranking model simply returns an ordered list of ranked documents, so it doesn't make sense to specify dimensions, an element type or a similarity measure

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense! Thanks for the background :)

),
new InferenceService.DefaultConfigId(
".rerank-v1-elastic",
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
service
)
service
),
new InferenceService.DefaultConfigId(
".rerank-v1-elastic",
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
service
)
)
);
assertThat(
service.supportedTaskTypes(),
is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
is(EnumSet.of(TaskType.CHAT_COMPLETION, TaskType.SPARSE_EMBEDDING, TaskType.RERANK, TaskType.TEXT_EMBEDDING))
);

PlainActionFuture<List<Model>> listener = new PlainActionFuture<>();
service.defaultConfigs(listener);
assertThat(listener.actionGet(TIMEOUT).get(0).getConfigurations().getInferenceEntityId(), is(".elser-v2-elastic"));
assertThat(listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
assertThat(listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic"));
assertThat(
listener.actionGet(TIMEOUT).get(1).getConfigurations().getInferenceEntityId(),
is(".multilingual-embed-v1-elastic")
);
assertThat(listener.actionGet(TIMEOUT).get(2).getConfigurations().getInferenceEntityId(), is(".rainbow-sprinkles-elastic"));
assertThat(listener.actionGet(TIMEOUT).get(3).getConfigurations().getInferenceEntityId(), is(".rerank-v1-elastic"));

var getModelListener = new PlainActionFuture<UnparsedModel>();
// persists the default endpoints
Expand All @@ -265,6 +283,10 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
{
"model_name": "rerank-v1",
"task_types": ["rerank/text/text-similarity"]
},
{
"model_name": "multilingual-embed-v1",
"task_types": ["embed/text/dense"]
}
]
}
Expand All @@ -278,22 +300,33 @@ public void testRemoves_DefaultChatCompletion_V1_WhenAuthorizationDoesNotReturnA
assertThat(service.supportedStreamingTasks(), is(EnumSet.noneOf(TaskType.class)));
assertThat(
service.defaultConfigIds(),
is(
List.of(
new InferenceService.DefaultConfigId(
".elser-v2-elastic",
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
service
containsInAnyOrder(
new InferenceService.DefaultConfigId(
".elser-v2-elastic",
MinimalServiceSettings.sparseEmbedding(ElasticInferenceService.NAME),
service
),
new InferenceService.DefaultConfigId(
".multilingual-embed-v1-elastic",
MinimalServiceSettings.textEmbedding(
ElasticInferenceService.NAME,
ElasticInferenceService.DENSE_TEXT_EMBEDDINGS_DIMENSIONS,
ElasticInferenceService.defaultDenseTextEmbeddingsSimilarity(),
DenseVectorFieldMapper.ElementType.FLOAT
),
new InferenceService.DefaultConfigId(
".rerank-v1-elastic",
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
service
)
service
),
new InferenceService.DefaultConfigId(
".rerank-v1-elastic",
MinimalServiceSettings.rerank(ElasticInferenceService.NAME),
service
)
)
);
assertThat(service.supportedTaskTypes(), is(EnumSet.of(TaskType.SPARSE_EMBEDDING, TaskType.RERANK)));
assertThat(
service.supportedTaskTypes(),
is(EnumSet.of(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING, TaskType.RERANK))
);

var getModelListener = new PlainActionFuture<UnparsedModel>();
modelRegistry.getModel(".rainbow-sprinkles-elastic", getModelListener);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.response.elastic;

import org.elasticsearch.common.xcontent.XContentParserUtils;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.request.Request;

import java.io.IOException;
import java.util.List;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;

public class ElasticInferenceServiceDenseTextEmbeddingsResponseEntity {

/**
* Parses the Elastic Inference Service Dense Text Embeddings response.
*
* For a request like:
*
* <pre>
* <code>
* {
* "inputs": ["Embed this text", "Embed this text, too"]
* }
* </code>
* </pre>
*
* The response would look like:
*
* <pre>
* <code>
* {
* "data": [
* [
* 2.1259406,
* 1.7073475,
* 0.9020516
* ],
* (...)
* ],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I vaguely remembered Tim's thread on this a couple weeks ago, but should we revisit the response format? Looking at OpenAI, Alibaba, and Mixedbread as quick references, it looks like they return a list of objects. I don't have a strong preference, but just wanted to bring this up since we might be differing from others here and wanted to confirm that this is what we want.
Thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Answered in the thread

* "meta": {
* "usage": {...}
* }
* }
* </code>
* </pre>
*/
public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
try (var p = XContentFactory.xContent(XContentType.JSON).createParser(XContentParserConfiguration.EMPTY, response.body())) {
return EmbeddingFloatResult.PARSER.apply(p, null).toTextEmbeddingFloatResults();
}
}

public record EmbeddingFloatResult(List<EmbeddingFloatResultEntry> embeddingResults) {
@SuppressWarnings("unchecked")
public static final ConstructingObjectParser<EmbeddingFloatResult, Void> PARSER = new ConstructingObjectParser<>(
EmbeddingFloatResult.class.getSimpleName(),
true,
args -> new EmbeddingFloatResult((List<EmbeddingFloatResultEntry>) args[0])
);

static {
// Custom field declaration to handle array of arrays format
PARSER.declareField(constructorArg(), (parser, context) -> {
return XContentParserUtils.parseList(parser, (p, index) -> {
List<Float> embedding = XContentParserUtils.parseList(p, (innerParser, innerIndex) -> innerParser.floatValue());
return EmbeddingFloatResultEntry.fromFloatArray(embedding);
});
}, new ParseField("data"), org.elasticsearch.xcontent.ObjectParser.ValueType.OBJECT_ARRAY);
}

public TextEmbeddingFloatResults toTextEmbeddingFloatResults() {
return new TextEmbeddingFloatResults(
embeddingResults.stream().map(entry -> TextEmbeddingFloatResults.Embedding.of(entry.embedding)).toList()
);
}
}

/**
* Represents a single embedding entry in the response.
* For the Elastic Inference Service, each entry is just an array of floats (no wrapper object).
* This is a simpler wrapper that just holds the float array.
*/
public record EmbeddingFloatResultEntry(List<Float> embedding) {
public static EmbeddingFloatResultEntry fromFloatArray(List<Float> floats) {
return new EmbeddingFloatResultEntry(floats);
}
}

private ElasticInferenceServiceDenseTextEmbeddingsResponseEntity() {}
}
Loading