diff --git a/docs/changelog/129320.yaml b/docs/changelog/129320.yaml new file mode 100644 index 0000000000000..ccd2ebcaa4379 --- /dev/null +++ b/docs/changelog/129320.yaml @@ -0,0 +1,5 @@ +pr: 129320 +summary: Refine indexing pressure accounting in semantic bulk inference filter +area: Relevance +type: enhancement +issues: [] diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 082ece347208a..10a31cee64dac 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -25,7 +25,9 @@ import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.cluster.metadata.ProjectMetadata; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.bytes.CompositeBytesReference; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.concurrent.AtomicArray; @@ -52,6 +54,7 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.xcontent.XContent; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; @@ -469,17 +472,17 @@ private void recordRequestCountMetrics(Model model, int incrementBy, Throwable t * Adds all inference requests associated with their respective inference IDs to the given {@code requestsMap} * for the specified {@code item}. * - * @param item The bulk request item to process. - * @param itemIndex The position of the item within the original bulk request. + * @param item The bulk request item to process. + * @param itemIndex The position of the item within the original bulk request. * @param requestsMap A map storing inference requests, where each key is an inference ID, * and the value is a list of associated {@link FieldInferenceRequest} objects. * @return The total content length of all newly added requests, or {@code 0} if no requests were added. */ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map> requestsMap) { boolean isUpdateRequest = false; - final IndexRequestWithIndexingPressure indexRequest; + final IndexRequest indexRequest; if (item.request() instanceof IndexRequest ir) { - indexRequest = new IndexRequestWithIndexingPressure(ir); + indexRequest = ir; } else if (item.request() instanceof UpdateRequest updateRequest) { isUpdateRequest = true; if (updateRequest.script() != null) { @@ -493,13 +496,13 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< ); return 0; } - indexRequest = new IndexRequestWithIndexingPressure(updateRequest.doc()); + indexRequest = updateRequest.doc(); } else { // ignore delete request return 0; } - final Map docMap = indexRequest.getIndexRequest().sourceAsMap(); + final Map docMap = indexRequest.sourceAsMap(); long inputLength = 0; for (var entry : fieldInferenceMap.values()) { String field = entry.getName(); @@ -535,10 +538,6 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< * This ensures that the field is treated as intentionally cleared, * preventing any unintended carryover of prior inference results. */ - if (incrementIndexingPressure(indexRequest, itemIndex) == false) { - return inputLength; - } - var slot = ensureResponseAccumulatorSlot(itemIndex); slot.addOrUpdateResponse( new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE) @@ -578,10 +577,6 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< List requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); int offsetAdjustment = 0; for (String v : values) { - if (incrementIndexingPressure(indexRequest, itemIndex) == false) { - return inputLength; - } - if (v.isBlank()) { slot.addOrUpdateResponse( new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE) @@ -604,50 +599,6 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< return inputLength; } - private static class IndexRequestWithIndexingPressure { - private final IndexRequest indexRequest; - private boolean indexingPressureIncremented; - - private IndexRequestWithIndexingPressure(IndexRequest indexRequest) { - this.indexRequest = indexRequest; - this.indexingPressureIncremented = false; - } - - private IndexRequest getIndexRequest() { - return indexRequest; - } - - private boolean isIndexingPressureIncremented() { - return indexingPressureIncremented; - } - - private void setIndexingPressureIncremented() { - this.indexingPressureIncremented = true; - } - } - - private boolean incrementIndexingPressure(IndexRequestWithIndexingPressure indexRequest, int itemIndex) { - boolean success = true; - if (indexRequest.isIndexingPressureIncremented() == false) { - try { - // Track operation count as one operation per document source update - coordinatingIndexingPressure.increment(1, indexRequest.getIndexRequest().source().ramBytesUsed()); - indexRequest.setIndexingPressureIncremented(); - } catch (EsRejectedExecutionException e) { - addInferenceResponseFailure( - itemIndex, - new InferenceException( - "Insufficient memory available to update source on document [" + indexRequest.getIndexRequest().id() + "]", - e - ) - ); - success = false; - } - } - - return success; - } - private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) { FieldInferenceResponseAccumulator acc = inferenceResults.get(id); if (acc == null) { @@ -723,70 +674,152 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons ); inferenceFieldsMap.put(fieldName, result); } - - BytesReference originalSource = indexRequest.source(); if (useLegacyFormat) { var newDocMap = indexRequest.sourceAsMap(); for (var entry : inferenceFieldsMap.entrySet()) { SemanticTextUtils.insertValue(entry.getKey(), newDocMap, entry.getValue()); } - indexRequest.source(newDocMap, indexRequest.getContentType()); + XContentBuilder builder = XContentFactory.contentBuilder(indexRequest.getContentType()); + builder.map(newDocMap); + var newSource = BytesReference.bytes(builder); + if (incrementIndexingPressure(item, indexRequest, newSource.length())) { + indexRequest.source(newSource, indexRequest.getContentType()); + } + } else { + updateSourceWithInferenceFields(item, indexRequest, inferenceFieldsMap); + } + } + + /** + * Updates the {@link IndexRequest}'s source to include additional inference fields. + *

+ * If the original source uses an array-backed {@link BytesReference}, this method attempts an in-place update, + * reusing the existing array where possible and appending additional bytes only if needed. + *

+ * If the original source is not array-backed, the entire source is replaced with the new source that includes + * the inference fields. In this case, the full size of the new source is accounted for in indexing pressure. + *

+ * Note: We do not subtract the indexing pressure of the original source since its bytes may be pooled and not + * reclaimable by the garbage collector during the request lifecycle. + * + * @param item The {@link BulkItemRequest} being processed. + * @param indexRequest The {@link IndexRequest} whose source will be updated. + * @param inferenceFieldsMap A map of additional fields to append to the source. + * @throws IOException if building the new source fails. + */ + private void updateSourceWithInferenceFields( + BulkItemRequest item, + IndexRequest indexRequest, + Map inferenceFieldsMap + ) throws IOException { + var originalSource = indexRequest.source(); + final BytesReference newSource; + + // Build a new source by appending the inference fields to the existing source. + try (XContentBuilder builder = XContentBuilder.builder(indexRequest.getContentType().xContent())) { + appendSourceAndInferenceMetadata(builder, originalSource, indexRequest.getContentType(), inferenceFieldsMap); + newSource = BytesReference.bytes(builder); + } + + // Calculate the additional size to account for in indexing pressure. + final int additionalSize = originalSource.hasArray() ? newSource.length() - originalSource.length() : newSource.length(); + + // If we exceed the indexing pressure limit, do not proceed with the update. + if (incrementIndexingPressure(item, indexRequest, additionalSize) == false) { + return; + } + + // Apply the updated source to the index request. + if (originalSource.hasArray()) { + // If the original source is backed by an array, perform in-place update: + // - Copy as much of the new source as fits into the original array. + System.arraycopy( + newSource.array(), + newSource.arrayOffset(), + originalSource.array(), + originalSource.arrayOffset(), + originalSource.length() + ); + + int remainingSize = newSource.length() - originalSource.length(); + if (remainingSize > 0) { + // If there are additional bytes, append them as a new BytesArray segment. + byte[] remainingBytes = new byte[remainingSize]; + System.arraycopy( + newSource.array(), + newSource.arrayOffset() + originalSource.length(), + remainingBytes, + 0, + remainingSize + ); + indexRequest.source( + CompositeBytesReference.of(originalSource, new BytesArray(remainingBytes)), + indexRequest.getContentType() + ); + } else { + // No additional bytes; just adjust the slice length. + indexRequest.source(originalSource.slice(0, newSource.length())); + } } else { - try (XContentBuilder builder = XContentBuilder.builder(indexRequest.getContentType().xContent())) { - appendSourceAndInferenceMetadata(builder, indexRequest.source(), indexRequest.getContentType(), inferenceFieldsMap); - indexRequest.source(builder); + // If the original source is not array-backed, replace it entirely. + indexRequest.source(newSource, indexRequest.getContentType()); + } + } + + /** + * Appends the original source and the new inference metadata field directly to the provided + * {@link XContentBuilder}, avoiding the need to materialize the original source as a {@link Map}. + */ + private void appendSourceAndInferenceMetadata( + XContentBuilder builder, + BytesReference source, + XContentType xContentType, + Map inferenceFieldsMap + ) throws IOException { + builder.startObject(); + + // append the original source + try ( + XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, source, xContentType) + ) { + // skip start object + parser.nextToken(); + while (parser.nextToken() != XContentParser.Token.END_OBJECT) { + builder.copyCurrentStructure(parser); } } - long modifiedSourceSize = indexRequest.source().ramBytesUsed(); - // Add the indexing pressure from the source modifications. - // Don't increment operation count because we count one source update as one operation, and we already accounted for those - // in addFieldInferenceRequests. + // add the inference metadata field + builder.field(InferenceMetadataFieldsMapper.NAME); + try (XContentParser parser = XContentHelper.mapToXContentParser(XContentParserConfiguration.EMPTY, inferenceFieldsMap)) { + builder.copyCurrentStructure(parser); + } + + builder.endObject(); + } + + private boolean incrementIndexingPressure(BulkItemRequest item, IndexRequest indexRequest, int inc) { try { - coordinatingIndexingPressure.increment(0, modifiedSourceSize - originalSource.ramBytesUsed()); + if (inc > 0) { + coordinatingIndexingPressure.increment(1, inc); + } + return true; } catch (EsRejectedExecutionException e) { - indexRequest.source(originalSource, indexRequest.getContentType()); + inferenceStats.bulkRejection().incrementBy(1); item.abort( item.index(), new InferenceException( - "Insufficient memory available to insert inference results into document [" + indexRequest.id() + "]", + "Unable to insert inference results into document [" + + indexRequest.id() + + "] due to memory pressure. Please retry the bulk request with fewer documents or smaller document sizes.", e ) ); + return false; } } } - /** - * Appends the original source and the new inference metadata field directly to the provided - * {@link XContentBuilder}, avoiding the need to materialize the original source as a {@link Map}. - */ - private static void appendSourceAndInferenceMetadata( - XContentBuilder builder, - BytesReference source, - XContentType xContentType, - Map inferenceFieldsMap - ) throws IOException { - builder.startObject(); - - // append the original source - try (XContentParser parser = XContentHelper.createParserNotCompressed(XContentParserConfiguration.EMPTY, source, xContentType)) { - // skip start object - parser.nextToken(); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - builder.copyCurrentStructure(parser); - } - } - - // add the inference metadata field - builder.field(InferenceMetadataFieldsMapper.NAME); - try (XContentParser parser = XContentHelper.mapToXContentParser(XContentParserConfiguration.EMPTY, inferenceFieldsMap)) { - builder.copyCurrentStructure(parser); - } - - builder.endObject(); - } - static IndexRequest getIndexRequestOrNull(DocWriteRequest docWriteRequest) { if (docWriteRequest instanceof IndexRequest indexRequest) { return indexRequest; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java index 17c91b81233fb..215d12e27f833 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/telemetry/InferenceStats.java @@ -20,11 +20,12 @@ import java.util.Map; import java.util.Objects; -public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration) { +public record InferenceStats(LongCounter requestCount, LongHistogram inferenceDuration, LongCounter bulkRejection) { public InferenceStats { Objects.requireNonNull(requestCount); Objects.requireNonNull(inferenceDuration); + Objects.requireNonNull(bulkRejection); } public static InferenceStats create(MeterRegistry meterRegistry) { @@ -38,6 +39,11 @@ public static InferenceStats create(MeterRegistry meterRegistry) { "es.inference.requests.time", "Inference API request counts for a particular service, task type, model ID", "ms" + ), + meterRegistry.registerLongCounter( + "es.inference.bulk.rejection.total", + "Count of bulk request rejections for semantic text processing due to insufficient available memory", + "operations" ) ); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java index 70499c7987965..501f239be3032 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java @@ -88,7 +88,7 @@ public void setUp() throws Exception { licenseState = mock(); modelRegistry = mock(); serviceRegistry = mock(); - inferenceStats = new InferenceStats(mock(), mock()); + inferenceStats = new InferenceStats(mock(), mock(), mock()); streamingTaskManager = mock(); action = createAction( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index a7cb0234aee59..4539fae5a5bc8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -131,9 +131,7 @@ public ShardBulkInferenceActionFilterTests(boolean useLegacyFormat) { @ParametersFactory public static Iterable parameters() throws Exception { - List lst = new ArrayList<>(); - lst.add(new Object[] { true }); - return lst; + return List.of(new Boolean[] { true }, new Boolean[] { false }); } @Before @@ -148,7 +146,7 @@ public void tearDownThreadPool() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testFilterNoop() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = new InferenceStats(mock(), mock(), mock()); ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(), @@ -181,7 +179,7 @@ public void testFilterNoop() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testLicenseInvalidForInference() throws InterruptedException { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = new InferenceStats(mock(), mock(), mock()); StaticModel model = StaticModel.createRandomInstance(); ShardBulkInferenceActionFilter filter = createFilter( threadPool, @@ -227,7 +225,7 @@ public void testLicenseInvalidForInference() throws InterruptedException { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testInferenceNotFound() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = new InferenceStats(mock(), mock(), mock()); StaticModel model = StaticModel.createRandomInstance(); ShardBulkInferenceActionFilter filter = createFilter( threadPool, @@ -275,7 +273,7 @@ public void testInferenceNotFound() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testItemFailures() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = new InferenceStats(mock(), mock(), mock()); StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); ShardBulkInferenceActionFilter filter = createFilter( threadPool, @@ -364,7 +362,7 @@ public void testItemFailures() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testExplicitNull() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = new InferenceStats(mock(), mock(), mock()); StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom"))); model.putResult("I am a success", randomChunkedInferenceEmbedding(model, List.of("I am a success"))); @@ -440,7 +438,7 @@ public void testExplicitNull() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testHandleEmptyInput() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = new InferenceStats(mock(), mock(), mock()); StaticModel model = StaticModel.createRandomInstance(); ShardBulkInferenceActionFilter filter = createFilter( threadPool, @@ -495,7 +493,7 @@ public void testHandleEmptyInput() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testManyRandomDocs() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = new InferenceStats(mock(), mock(), mock()); Map inferenceModelMap = new HashMap<>(); int numModels = randomIntBetween(1, 3); for (int i = 0; i < numModels; i++) { @@ -559,7 +557,7 @@ public void testManyRandomDocs() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testIndexingPressure() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = new InferenceStats(mock(), mock(), mock()); final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(Settings.EMPTY); final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); final StaticModel denseModel = StaticModel.createRandomInstance(TaskType.TEXT_EMBEDDING); @@ -616,20 +614,10 @@ public void testIndexingPressure() throws Exception { IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating(); assertThat(coordinatingIndexingPressure, notNullValue()); - verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc0Source)); - verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source)); - verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc2Source)); - verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc3Source)); - verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc4Source)); - verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc0UpdateSource)); - if (useLegacyFormat == false) { - verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1UpdateSource)); - } - - verify(coordinatingIndexingPressure, times(useLegacyFormat ? 6 : 7)).increment(eq(0), longThat(l -> l > 0)); + verify(coordinatingIndexingPressure, times(useLegacyFormat ? 6 : 7)).increment(eq(1), longThat(l -> l > 0)); // Verify that the only times that increment is called are the times verified above - verify(coordinatingIndexingPressure, times(useLegacyFormat ? 12 : 14)).increment(anyInt(), anyLong()); + verify(coordinatingIndexingPressure, times(useLegacyFormat ? 6 : 7)).increment(anyInt(), anyLong()); // Verify that the coordinating indexing pressure is maintained through downstream action filters verify(coordinatingIndexingPressure, never()).close(); @@ -675,86 +663,6 @@ public void testIndexingPressure() throws Exception { verify(coordinatingIndexingPressure).close(); } - @SuppressWarnings("unchecked") - public void testIndexingPressureTripsOnInferenceRequestGeneration() throws Exception { - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); - final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure( - Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), "1b").build() - ); - final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); - final ShardBulkInferenceActionFilter filter = createFilter( - threadPool, - Map.of(sparseModel.getInferenceEntityId(), sparseModel), - indexingPressure, - useLegacyFormat, - true, - inferenceStats - ); - - XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar"); - - CountDownLatch chainExecuted = new CountDownLatch(1); - ActionFilterChain actionFilterChain = (task, action, request, listener) -> { - try { - assertNull(request.getInferenceFieldMap()); - assertThat(request.items().length, equalTo(3)); - - assertNull(request.items()[0].getPrimaryResponse()); - assertNull(request.items()[2].getPrimaryResponse()); - - BulkItemRequest doc1Request = request.items()[1]; - BulkItemResponse doc1Response = doc1Request.getPrimaryResponse(); - assertNotNull(doc1Response); - assertTrue(doc1Response.isFailed()); - BulkItemResponse.Failure doc1Failure = doc1Response.getFailure(); - assertThat( - doc1Failure.getCause().getMessage(), - containsString("Insufficient memory available to update source on document [doc_1]") - ); - assertThat(doc1Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class)); - assertThat(doc1Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS)); - - IndexRequest doc1IndexRequest = getIndexRequestOrNull(doc1Request.request()); - assertThat(doc1IndexRequest, notNullValue()); - assertThat(doc1IndexRequest.source(), equalTo(BytesReference.bytes(doc1Source))); - - IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating(); - assertThat(coordinatingIndexingPressure, notNullValue()); - verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source)); - verify(coordinatingIndexingPressure, times(1)).increment(anyInt(), anyLong()); - - // Verify that the coordinating indexing pressure is maintained through downstream action filters - verify(coordinatingIndexingPressure, never()).close(); - - // Call the listener once the request is successfully processed, like is done in the production code path - listener.onResponse(null); - } finally { - chainExecuted.countDown(); - } - }; - ActionListener actionListener = (ActionListener) mock(ActionListener.class); - Task task = mock(Task.class); - - Map inferenceFieldMap = Map.of( - "sparse_field", - new InferenceFieldMetadata("sparse_field", sparseModel.getInferenceEntityId(), new String[] { "sparse_field" }, null) - ); - - BulkItemRequest[] items = new BulkItemRequest[3]; - items[0] = new BulkItemRequest(0, new IndexRequest("index").id("doc_0").source("non_inference_field", "foo")); - items[1] = new BulkItemRequest(1, new IndexRequest("index").id("doc_1").source(doc1Source)); - items[2] = new BulkItemRequest(2, new IndexRequest("index").id("doc_2").source("non_inference_field", "baz")); - - BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); - request.setInferenceFieldMap(inferenceFieldMap); - filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); - awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); - - IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating(); - assertThat(coordinatingIndexingPressure, notNullValue()); - verify(coordinatingIndexingPressure).close(); - } - @SuppressWarnings("unchecked") public void testIndexingPressureTripsOnInferenceResponseHandling() throws Exception { final XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar"); @@ -762,7 +670,7 @@ public void testIndexingPressureTripsOnInferenceResponseHandling() throws Except Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), (bytesUsed(doc1Source) + 1) + "b").build() ); - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = new InferenceStats(mock(), mock(), mock()); final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); sparseModel.putResult("bar", randomChunkedInferenceEmbedding(sparseModel, List.of("bar"))); @@ -791,7 +699,7 @@ public void testIndexingPressureTripsOnInferenceResponseHandling() throws Except BulkItemResponse.Failure doc1Failure = doc1Response.getFailure(); assertThat( doc1Failure.getCause().getMessage(), - containsString("Insufficient memory available to insert inference results into document [doc_1]") + containsString("Unable to insert inference results into document [doc_1] due to memory pressure.") ); assertThat(doc1Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class)); assertThat(doc1Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS)); @@ -802,9 +710,8 @@ public void testIndexingPressureTripsOnInferenceResponseHandling() throws Except IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating(); assertThat(coordinatingIndexingPressure, notNullValue()); - verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source)); - verify(coordinatingIndexingPressure).increment(eq(0), longThat(l -> l > 0)); - verify(coordinatingIndexingPressure, times(2)).increment(anyInt(), anyLong()); + verify(coordinatingIndexingPressure).increment(eq(1), longThat(l -> l > 0)); + verify(coordinatingIndexingPressure, times(1)).increment(anyInt(), anyLong()); // Verify that the coordinating indexing pressure is maintained through downstream action filters verify(coordinatingIndexingPressure, never()).close(); @@ -875,7 +782,7 @@ public void testIndexingPressurePartialFailure() throws Exception { .build() ); - final InferenceStats inferenceStats = new InferenceStats(mock(), mock()); + final InferenceStats inferenceStats = new InferenceStats(mock(), mock(), mock()); final ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(sparseModel.getInferenceEntityId(), sparseModel), @@ -900,10 +807,7 @@ public void testIndexingPressurePartialFailure() throws Exception { assertNotNull(doc2Response); assertTrue(doc2Response.isFailed()); BulkItemResponse.Failure doc2Failure = doc2Response.getFailure(); - assertThat( - doc2Failure.getCause().getMessage(), - containsString("Insufficient memory available to insert inference results into document [doc_2]") - ); + assertThat(doc2Failure.getCause().getMessage(), containsString("Unable to insert inference results into document [doc_2]")); assertThat(doc2Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class)); assertThat(doc2Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS)); @@ -913,10 +817,8 @@ public void testIndexingPressurePartialFailure() throws Exception { IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating(); assertThat(coordinatingIndexingPressure, notNullValue()); - verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source)); - verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc2Source)); - verify(coordinatingIndexingPressure, times(2)).increment(eq(0), longThat(l -> l > 0)); - verify(coordinatingIndexingPressure, times(4)).increment(anyInt(), anyLong()); + verify(coordinatingIndexingPressure, times(2)).increment(eq(1), longThat(l -> l > 0)); + verify(coordinatingIndexingPressure, times(2)).increment(anyInt(), anyLong()); // Verify that the coordinating indexing pressure is maintained through downstream action filters verify(coordinatingIndexingPressure, never()).close(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceStatsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceStatsTests.java index f3800f91d9a54..352ced4a22127 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceStatsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/telemetry/InferenceStatsTests.java @@ -37,7 +37,7 @@ public class InferenceStatsTests extends ESTestCase { public void testRecordWithModel() { var longCounter = mock(LongCounter.class); - var stats = new InferenceStats(longCounter, mock()); + var stats = new InferenceStats(longCounter, mock(), mock()); stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, "modelId"))); @@ -49,7 +49,7 @@ public void testRecordWithModel() { public void testRecordWithoutModel() { var longCounter = mock(LongCounter.class); - var stats = new InferenceStats(longCounter, mock()); + var stats = new InferenceStats(longCounter, mock(), mock()); stats.requestCount().incrementBy(1, modelAttributes(model("service", TaskType.ANY, null))); @@ -63,7 +63,7 @@ public void testCreation() { public void testRecordDurationWithoutError() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); Map metricAttributes = new HashMap<>(); metricAttributes.putAll(modelAttributes(model("service", TaskType.ANY, "modelId"))); @@ -88,7 +88,7 @@ public void testRecordDurationWithoutError() { public void testRecordDurationWithElasticsearchStatusException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var statusCode = RestStatus.BAD_REQUEST; var exception = new ElasticsearchStatusException("hello", statusCode); var expectedError = String.valueOf(statusCode.getStatus()); @@ -116,7 +116,7 @@ public void testRecordDurationWithElasticsearchStatusException() { public void testRecordDurationWithOtherException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var exception = new IllegalStateException("ahh"); var expectedError = exception.getClass().getSimpleName(); @@ -138,7 +138,7 @@ public void testRecordDurationWithOtherException() { public void testRecordDurationWithUnparsedModelAndElasticsearchStatusException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var statusCode = RestStatus.BAD_REQUEST; var exception = new ElasticsearchStatusException("hello", statusCode); var expectedError = String.valueOf(statusCode.getStatus()); @@ -163,7 +163,7 @@ public void testRecordDurationWithUnparsedModelAndElasticsearchStatusException() public void testRecordDurationWithUnparsedModelAndOtherException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var exception = new IllegalStateException("ahh"); var expectedError = exception.getClass().getSimpleName(); @@ -187,7 +187,7 @@ public void testRecordDurationWithUnparsedModelAndOtherException() { public void testRecordDurationWithUnknownModelAndElasticsearchStatusException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var statusCode = RestStatus.BAD_REQUEST; var exception = new ElasticsearchStatusException("hello", statusCode); var expectedError = String.valueOf(statusCode.getStatus()); @@ -206,7 +206,7 @@ public void testRecordDurationWithUnknownModelAndElasticsearchStatusException() public void testRecordDurationWithUnknownModelAndOtherException() { var expectedLong = randomLong(); var histogramCounter = mock(LongHistogram.class); - var stats = new InferenceStats(mock(), histogramCounter); + var stats = new InferenceStats(mock(), histogramCounter, mock()); var exception = new IllegalStateException("ahh"); var expectedError = exception.getClass().getSimpleName();