Track bytes used by in-memory postings #129969

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 10 commits into main
19 changes: 15 additions & 4 deletions server/src/main/java/org/elasticsearch/common/lucene/Lucene.java
@@ -739,15 +739,26 @@ public static Version parseVersionLenient(String toParse, Version defaultValue)
* If no SegmentReader can be extracted an {@link IllegalStateException} is thrown.
*/
public static SegmentReader segmentReader(LeafReader reader) {
SegmentReader segmentReader = tryUnwrapSegmentReader(reader);
if (segmentReader == null) {
throw new IllegalStateException("Can not extract segment reader from given index reader [" + reader + "]");
}
return segmentReader;
}

/**
* Tries to extract a segment reader from the given index reader. Unlike {@link #segmentReader(LeafReader)}, this method returns
* null instead of throwing an exception if no SegmentReader can be unwrapped.
*/
public static SegmentReader tryUnwrapSegmentReader(LeafReader reader) {
if (reader instanceof SegmentReader) {
return (SegmentReader) reader;
} else if (reader instanceof final FilterLeafReader fReader) {
- return segmentReader(FilterLeafReader.unwrap(fReader));
+ return tryUnwrapSegmentReader(FilterLeafReader.unwrap(fReader));
} else if (reader instanceof final FilterCodecReader fReader) {
- return segmentReader(FilterCodecReader.unwrap(fReader));
+ return tryUnwrapSegmentReader(FilterCodecReader.unwrap(fReader));
}
- // hard fail - we can't get a SegmentReader
- throw new IllegalStateException("Can not extract segment reader from given index reader [" + reader + "]");
+ return null;
}

@SuppressForbidden(reason = "Version#parseLeniently() used in a central place")
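A minimal usage sketch of the new non-throwing variant (hypothetical caller, not part of this diff); code that may receive wrapped or synthetic readers can branch on null instead of catching IllegalStateException:

// Hypothetical caller, assuming leafReader is a LeafReader obtained elsewhere.
SegmentReader segmentReader = Lucene.tryUnwrapSegmentReader(leafReader);
if (segmentReader == null) {
// e.g. a reader that does not wrap a SegmentReader; fall back gracefully
} else {
String attribute = segmentReader.getSegmentInfo().info.getAttribute("some.segment.attribute"); // hypothetical key
}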
160 changes: 160 additions & 0 deletions server/src/main/java/org/elasticsearch/index/codec/TrackingPostingsInMemoryBytesCodec.java
@@ -0,0 +1,160 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.index.codec;

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FieldsConsumer;
import org.apache.lucene.codecs.FieldsProducer;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.PostingsFormat;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.Fields;
import org.apache.lucene.index.FilterLeafReader;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.internal.hppc.IntIntHashMap;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.common.util.FeatureFlag;

import java.io.IOException;
import java.util.function.IntConsumer;

/**
* A codec that tracks the length of the min and max written terms. Used to improve memory usage estimates in serverless, since
* {@link org.apache.lucene.codecs.lucene90.blocktree.FieldReader} keeps an in-memory reference to the min and max term.
*/
public class TrackingPostingsInMemoryBytesCodec extends FilterCodec {
Review comment (Member): Maybe add class-level javadocs explaining the purpose of this class?

public static final FeatureFlag TRACK_POSTINGS_IN_MEMORY_BYTES = new FeatureFlag("track_postings_in_memory_bytes");

public static final String IN_MEMORY_POSTINGS_BYTES_KEY = "es.postings.in_memory_bytes";

public TrackingPostingsInMemoryBytesCodec(Codec delegate) {
super(delegate.getName(), delegate);
}

@Override
public PostingsFormat postingsFormat() {
PostingsFormat format = super.postingsFormat();

return new PostingsFormat(format.getName()) {
@Override
public FieldsConsumer fieldsConsumer(SegmentWriteState state) throws IOException {
FieldsConsumer consumer = format.fieldsConsumer(state);
return new TrackingLengthFieldsConsumer(state, consumer);
}

@Override
public FieldsProducer fieldsProducer(SegmentReadState state) throws IOException {
return format.fieldsProducer(state);
}
};
}

static final class TrackingLengthFieldsConsumer extends FieldsConsumer {
final SegmentWriteState state;
final FieldsConsumer in;
final IntIntHashMap termsBytesPerField;

TrackingLengthFieldsConsumer(SegmentWriteState state, FieldsConsumer in) {
this.state = state;
this.in = in;
this.termsBytesPerField = new IntIntHashMap(state.fieldInfos.size());
}

@Override
public void write(Fields fields, NormsProducer norms) throws IOException {
in.write(new TrackingLengthFields(fields, termsBytesPerField, state.fieldInfos), norms);
long totalBytes = 0;
for (int bytes : termsBytesPerField.values) {
totalBytes += bytes;
}
state.segmentInfo.putAttribute(IN_MEMORY_POSTINGS_BYTES_KEY, Long.toString(totalBytes));
}

@Override
public void close() throws IOException {
in.close();
}
}

static final class TrackingLengthFields extends FilterLeafReader.FilterFields {
final IntIntHashMap termsBytesPerField;
final FieldInfos fieldInfos;

TrackingLengthFields(Fields in, IntIntHashMap termsBytesPerField, FieldInfos fieldInfos) {
super(in);
this.termsBytesPerField = termsBytesPerField;
this.fieldInfos = fieldInfos;
}

@Override
public Terms terms(String field) throws IOException {
Terms terms = super.terms(field);
if (terms == null) {
return null;
}
int fieldNum = fieldInfos.fieldInfo(field).number;
return new TrackingLengthTerms(
terms,
bytes -> termsBytesPerField.put(fieldNum, Math.max(termsBytesPerField.getOrDefault(fieldNum, 0), bytes))
);
}
}

static final class TrackingLengthTerms extends FilterLeafReader.FilterTerms {
final IntConsumer onFinish;

TrackingLengthTerms(Terms in, IntConsumer onFinish) {
super(in);
this.onFinish = onFinish;
}

@Override
public TermsEnum iterator() throws IOException {
return new TrackingLengthTermsEnum(super.iterator(), onFinish);
}
}

static final class TrackingLengthTermsEnum extends FilterLeafReader.FilterTermsEnum {
int maxTermLength = 0;
int minTermLength = 0;
int termCount = 0;
final IntConsumer onFinish;

TrackingLengthTermsEnum(TermsEnum in, IntConsumer onFinish) {
super(in);
this.onFinish = onFinish;
}

@Override
public BytesRef next() throws IOException {
final BytesRef term = super.next();
if (term != null) {
if (termCount == 0) {
minTermLength = term.length;
}
maxTermLength = term.length;
termCount++;
} else {
if (termCount == 1) {
// If the minTerm and maxTerm are the same, only one instance is kept on the heap.
assert minTermLength == maxTermLength;
onFinish.accept(maxTermLength);
} else {
onFinish.accept(maxTermLength + minTermLength);
}
}
return term;
}
}
}
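To illustrate how the tracked value surfaces, here is a self-contained sketch (plain Lucene APIs plus the two classes touched by this diff; an illustration, not part of the change): it wraps the default codec, flushes one segment, and reads the attribute back.

import org.apache.lucene.analysis.standard.StandardAnalyzer;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SegmentReader;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;
import org.elasticsearch.common.lucene.Lucene;
import org.elasticsearch.index.codec.TrackingPostingsInMemoryBytesCodec;

public class TrackingCodecDemo {
    public static void main(String[] args) throws Exception {
        IndexWriterConfig iwc = new IndexWriterConfig(new StandardAnalyzer())
            .setCodec(new TrackingPostingsInMemoryBytesCodec(Codec.getDefault()));
        try (Directory dir = new ByteBuffersDirectory(); IndexWriter writer = new IndexWriter(dir, iwc)) {
            Document doc = new Document();
            doc.add(new TextField("f1", "foo bar", Field.Store.NO));
            writer.addDocument(doc);
            writer.commit();
            try (DirectoryReader reader = DirectoryReader.open(dir)) {
                for (LeafReaderContext leaf : reader.leaves()) {
                    SegmentReader segmentReader = Lucene.tryUnwrapSegmentReader(leaf.reader());
                    // "f1" has min term "bar" (3 bytes) and max term "foo" (3 bytes),
                    // so the attribute should read "6".
                    System.out.println(segmentReader.getSegmentInfo().info.getAttribute(
                        TrackingPostingsInMemoryBytesCodec.IN_MEMORY_POSTINGS_BYTES_KEY));
                }
            }
        }
    }
}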
15 changes: 14 additions & 1 deletion server/src/main/java/org/elasticsearch/index/engine/Engine.java
@@ -62,6 +62,7 @@
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.VersionType;
import org.elasticsearch.index.codec.FieldInfosWithUsages;
import org.elasticsearch.index.codec.TrackingPostingsInMemoryBytesCodec;
import org.elasticsearch.index.codec.vectors.reflect.OffHeapByteSizeUtils;
import org.elasticsearch.index.mapper.DocumentParser;
import org.elasticsearch.index.mapper.LuceneDocument;
@@ -275,6 +276,7 @@ protected static ShardFieldStats shardFieldStats(List<LeafReaderContext> leaves)
int numSegments = 0;
int totalFields = 0;
long usages = 0;
long totalPostingBytes = 0;
for (LeafReaderContext leaf : leaves) {
numSegments++;
var fieldInfos = leaf.reader().getFieldInfos();
@@ -286,8 +288,19 @@ protected static ShardFieldStats shardFieldStats(List<LeafReaderContext> leaves)
} else {
usages = -1;
}
if (TrackingPostingsInMemoryBytesCodec.TRACK_POSTINGS_IN_MEMORY_BYTES.isEnabled()) {
SegmentReader segmentReader = Lucene.tryUnwrapSegmentReader(leaf.reader());
if (segmentReader != null) {
String postingBytes = segmentReader.getSegmentInfo().info.getAttribute(
TrackingPostingsInMemoryBytesCodec.IN_MEMORY_POSTINGS_BYTES_KEY
);
if (postingBytes != null) {
totalPostingBytes += Long.parseLong(postingBytes);
}
}
}
}
- return new ShardFieldStats(numSegments, totalFields, usages);
+ return new ShardFieldStats(numSegments, totalFields, usages, totalPostingBytes);
}

/**
server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java
@@ -10,6 +10,7 @@
package org.elasticsearch.index.engine;

import org.apache.logging.log4j.Logger;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexCommit;
@@ -79,6 +80,7 @@
import org.elasticsearch.index.IndexVersions;
import org.elasticsearch.index.VersionType;
import org.elasticsearch.index.cache.query.TrivialQueryCachingPolicy;
import org.elasticsearch.index.codec.TrackingPostingsInMemoryBytesCodec;
import org.elasticsearch.index.mapper.DocumentParser;
import org.elasticsearch.index.mapper.IdFieldMapper;
import org.elasticsearch.index.mapper.LuceneDocument;
@@ -2778,7 +2780,13 @@ private IndexWriterConfig getIndexWriterConfig() {
iwc.setMaxFullFlushMergeWaitMillis(-1);
iwc.setSimilarity(engineConfig.getSimilarity());
iwc.setRAMBufferSizeMB(engineConfig.getIndexingBufferSize().getMbFrac());
- iwc.setCodec(engineConfig.getCodec());
+
+ Codec codec = engineConfig.getCodec();
+ if (TrackingPostingsInMemoryBytesCodec.TRACK_POSTINGS_IN_MEMORY_BYTES.isEnabled()) {
+ codec = new TrackingPostingsInMemoryBytesCodec(codec);
+ }
+ iwc.setCodec(codec);

boolean useCompoundFile = engineConfig.getUseCompoundFile();
iwc.setUseCompoundFile(useCompoundFile);
if (useCompoundFile == false) {
server/src/main/java/org/elasticsearch/index/shard/ShardFieldStats.java
@@ -17,7 +17,8 @@
* @param totalFields the total number of fields across the segments
* @param fieldUsages the number of usages for segment-level fields (e.g., doc_values, postings, norms, points)
* -1 if unavailable
* @param postingsInMemoryBytes the total bytes in memory used for postings across all fields
*/
- public record ShardFieldStats(int numSegments, int totalFields, long fieldUsages) {
+ public record ShardFieldStats(int numSegments, int totalFields, long fieldUsages, long postingsInMemoryBytes) {

}
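The record simply gains a fourth component, so existing call sites only need to pass the extra argument. A tiny illustration, using the values from the force-merge assertion in the test below:

ShardFieldStats stats = new ShardFieldStats(1, 12, 18L, 66L);
long postingsBytes = stats.postingsInMemoryBytes(); // heap pinned by min/max terms across segments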
server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java
@@ -78,6 +78,7 @@
import org.elasticsearch.index.IndexSettings;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.index.codec.CodecService;
import org.elasticsearch.index.codec.TrackingPostingsInMemoryBytesCodec;
import org.elasticsearch.index.engine.CommitStats;
import org.elasticsearch.index.engine.DocIdSeqNoAndSource;
import org.elasticsearch.index.engine.Engine;
@@ -1882,8 +1883,12 @@ public void testShardFieldStats() throws IOException {
assertThat(stats.numSegments(), equalTo(0));
assertThat(stats.totalFields(), equalTo(0));
assertThat(stats.fieldUsages(), equalTo(0L));
assertThat(stats.postingsInMemoryBytes(), equalTo(0L));

boolean postingsBytesTrackingEnabled = TrackingPostingsInMemoryBytesCodec.TRACK_POSTINGS_IN_MEMORY_BYTES.isEnabled();

// index some documents
- int numDocs = between(1, 10);
+ int numDocs = between(2, 10);
for (int i = 0; i < numDocs; i++) {
indexDoc(shard, "_doc", "first_" + i, """
{
@@ -1901,6 +1906,9 @@ public void testShardFieldStats() throws IOException {
// _id(term), _source(0), _version(dv), _primary_term(dv), _seq_no(point,dv), f1(postings,norms),
// f1.keyword(term,dv), f2(postings,norms), f2.keyword(term,dv),
assertThat(stats.fieldUsages(), equalTo(13L));
// (min,max) term lengths in bytes; a field with a single unique term keeps only one instance, counted once:
// _id: (5,8), f1: 3, f1.keyword: 3, f2: 3, f2.keyword: 3
// 5 + 8 + 3 + 3 + 3 + 3 = 25
assertThat(stats.postingsInMemoryBytes(), equalTo(postingsBytesTrackingEnabled ? 25L : 0L));
// don't re-compute on refresh without change
if (randomBoolean()) {
shard.refresh("test");
@@ -1919,11 +1927,18 @@ public void testShardFieldStats() throws IOException {
assertThat(shard.getShardFieldStats(), sameInstance(stats));
// index more docs
numDocs = between(1, 10);
indexDoc(shard, "_doc", "first_0", """
{
"f1": "lorem",
"f2": "bar",
"f3": "sit amet"
}
""");
for (int i = 0; i < numDocs; i++) {
- indexDoc(shard, "_doc", "first_" + i, """
+ indexDoc(shard, "_doc", "first_" + i + 1, """
{
"f1": "foo",
"f2": "bar",
"f2": "ipsum",
"f3": "foobar"
}
""");
@@ -1948,13 +1963,20 @@ public void testShardFieldStats() throws IOException {
assertThat(stats.totalFields(), equalTo(21));
// first segment: 13, second segment: 13 + f3(postings,norms) + f3.keyword(term,dv), and __soft_deletes to previous segment
assertThat(stats.fieldUsages(), equalTo(31L));
// segment 1: 25 (see above)
// segment 2: _id: (5,6), f1: (3,5), f1.keyword: (3,5), f2: (3,5), f2.keyword: (3,5), f3: (4,3), f3.keyword: (6,8)
// (5+6) + (3+5) + (3+5) + (3+5) + (3+5) + (4+3) + (6+8) = 64
// 25 + 64 = 89
assertThat(stats.postingsInMemoryBytes(), equalTo(postingsBytesTrackingEnabled ? 89L : 0L));
shard.forceMerge(new ForceMergeRequest().maxNumSegments(1).flush(true));
stats = shard.getShardFieldStats();
assertThat(stats.numSegments(), equalTo(1));
assertThat(stats.totalFields(), equalTo(12));
// _id(term), _source(0), _version(dv), _primary_term(dv), _seq_no(point,dv), f1(postings,norms),
// f1.keyword(term,dv), f2(postings,norms), f2.keyword(term,dv), f3(postings,norms), f3.keyword(term,dv), __soft_deletes
assertThat(stats.fieldUsages(), equalTo(18L));
// _id: (5,8), f1: (3,5), f1.keyword: (3,5), f2: (3,5), f2.keyword: (3,5), f3: (4,3), f3.keyword: (6,8)
assertThat(stats.postingsInMemoryBytes(), equalTo(postingsBytesTrackingEnabled ? 66L : 0L));
closeShards(shard);
}
