diff --git a/docs/changelog/129990.yaml b/docs/changelog/129990.yaml
new file mode 100644
index 0000000000000..aa5ca924f2932
--- /dev/null
+++ b/docs/changelog/129990.yaml
@@ -0,0 +1,5 @@
+pr: 129990
+summary: Make forecast write load accurate when shard numbers change
+area: Allocation
+type: bug
+issues: []
diff --git a/server/src/main/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingService.java b/server/src/main/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingService.java
index 61eb2867790a2..334e0c7564f7c 100644
--- a/server/src/main/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingService.java
+++ b/server/src/main/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingService.java
@@ -362,7 +362,7 @@ public String toString() {
*
If the recommendation is to INCREASE/DECREASE shards the reported cooldown period will be TimeValue.ZERO.
* If the auto sharding service thinks the number of shards must be changed but it can't recommend a change due to the cooldown
* period not lapsing, the result will be of type {@link AutoShardingType#COOLDOWN_PREVENTED_INCREASE} or
- * {@link AutoShardingType#COOLDOWN_PREVENTED_INCREASE} with the remaining cooldown configured and the number of shards that should
+ * {@link AutoShardingType#COOLDOWN_PREVENTED_DECREASE} with the remaining cooldown configured and the number of shards that should
* be configured for the data stream once the remaining cooldown lapses as the target number of shards.
*
*
The NOT_APPLICABLE type result will report a cooldown period of TimeValue.MAX_VALUE.
diff --git a/x-pack/plugin/write-load-forecaster/src/main/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecaster.java b/x-pack/plugin/write-load-forecaster/src/main/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecaster.java
index b9b900836c6fc..db68694af4143 100644
--- a/x-pack/plugin/write-load-forecaster/src/main/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecaster.java
+++ b/x-pack/plugin/write-load-forecaster/src/main/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecaster.java
@@ -108,7 +108,12 @@ public ProjectMetadata.Builder withWriteLoadForecastForWriteIndex(String dataStr
}
final IndexMetadata writeIndex = metadata.getSafe(dataStream.getWriteIndex());
- metadata.put(IndexMetadata.builder(writeIndex).indexWriteLoadForecast(forecastIndexWriteLoad.getAsDouble()).build(), false);
+ metadata.put(
+ IndexMetadata.builder(writeIndex)
+ .indexWriteLoadForecast(forecastIndexWriteLoad.getAsDouble() / writeIndex.getNumberOfShards())
+ .build(),
+ false
+ );
return metadata;
}
@@ -129,11 +134,20 @@ private static void clearPreviousForecast(DataStream dataStream, ProjectMetadata
}
}
+ /**
+ * This calculates the weighted average total write-load for all recent indices.
+ *
+ * @param indicesWriteLoadWithinMaxAgeRange The indices considered "recent"
+ * @return The weighted average total write-load. To get the per-shard write load, this number must be divided by the number of shards
+ */
// Visible for testing
static OptionalDouble forecastIndexWriteLoad(List indicesWriteLoadWithinMaxAgeRange) {
- double totalWeightedWriteLoad = 0;
- long totalShardUptime = 0;
+ double allIndicesWriteLoad = 0;
+ long allIndicesUptime = 0;
for (IndexWriteLoad writeLoad : indicesWriteLoadWithinMaxAgeRange) {
+ double totalShardWriteLoad = 0;
+ long totalShardUptimeInMillis = 0;
+ long maxShardUptimeInMillis = 0;
for (int shardId = 0; shardId < writeLoad.numberOfShards(); shardId++) {
final OptionalDouble writeLoadForShard = writeLoad.getWriteLoadForShard(shardId);
final OptionalLong uptimeInMillisForShard = writeLoad.getUptimeInMillisForShard(shardId);
@@ -141,13 +155,27 @@ static OptionalDouble forecastIndexWriteLoad(List indicesWriteLo
assert uptimeInMillisForShard.isPresent();
double shardWriteLoad = writeLoadForShard.getAsDouble();
long shardUptimeInMillis = uptimeInMillisForShard.getAsLong();
- totalWeightedWriteLoad += shardWriteLoad * shardUptimeInMillis;
- totalShardUptime += shardUptimeInMillis;
+ totalShardWriteLoad += shardWriteLoad * shardUptimeInMillis;
+ totalShardUptimeInMillis += shardUptimeInMillis;
+ maxShardUptimeInMillis = Math.max(maxShardUptimeInMillis, shardUptimeInMillis);
}
}
+ double weightedAverageShardWriteLoad = totalShardWriteLoad / totalShardUptimeInMillis;
+ double totalIndexWriteLoad = weightedAverageShardWriteLoad * writeLoad.numberOfShards();
+ // We need to weight the contribution from each index somehow, but we only know
+ // the write-load from the final allocation of each shard at rollover time. It's
+ // possible the index is much older than any of those shards, but we don't have
+ // any write-load data beyond their lifetime.
+ // To avoid making assumptions about periods for which we have no data, we'll weight
+ // each index's contribution to the forecast by the maximum shard uptime observed in
+ // that index. It should be safe to extrapolate our weighted average out to the
+ // maximum uptime observed, based on the assumption that write-load is roughly
+ // evenly distributed across shards of a datastream index.
+ allIndicesWriteLoad += totalIndexWriteLoad * maxShardUptimeInMillis;
+ allIndicesUptime += maxShardUptimeInMillis;
}
- return totalShardUptime == 0 ? OptionalDouble.empty() : OptionalDouble.of(totalWeightedWriteLoad / totalShardUptime);
+ return allIndicesUptime == 0 ? OptionalDouble.empty() : OptionalDouble.of(allIndicesWriteLoad / allIndicesUptime);
}
@Override
diff --git a/x-pack/plugin/write-load-forecaster/src/test/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecasterTests.java b/x-pack/plugin/write-load-forecaster/src/test/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecasterTests.java
index d3db529f3cda7..3062e9a82b2ac 100644
--- a/x-pack/plugin/write-load-forecaster/src/test/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecasterTests.java
+++ b/x-pack/plugin/write-load-forecaster/src/test/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecasterTests.java
@@ -9,6 +9,7 @@
import org.apache.logging.log4j.Level;
import org.apache.logging.log4j.core.LogEvent;
+import org.apache.lucene.util.hnsw.IntToIntFunction;
import org.elasticsearch.cluster.metadata.DataStream;
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.IndexMetadataStats;
@@ -24,16 +25,19 @@
import org.elasticsearch.test.MockLog;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
+import org.hamcrest.Matcher;
import org.junit.After;
import org.junit.Before;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import java.util.OptionalDouble;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
import static org.elasticsearch.xpack.writeloadforecaster.LicensedWriteLoadForecaster.forecastIndexWriteLoad;
import static org.hamcrest.Matchers.closeTo;
@@ -42,6 +46,7 @@
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.lessThan;
public class LicensedWriteLoadForecasterTests extends ESTestCase {
ThreadPool threadPool;
@@ -67,33 +72,15 @@ public void testWriteLoadForecastIsAddedToWriteIndex() {
writeLoadForecaster.refreshLicense();
- final ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(randomProjectIdOrDefault());
final String dataStreamName = "logs-es";
final int numberOfBackingIndices = 10;
- final int numberOfShards = randomIntBetween(1, 5);
- final List backingIndices = new ArrayList<>();
- for (int i = 0; i < numberOfBackingIndices; i++) {
- final IndexMetadata indexMetadata = createIndexMetadata(
- DataStream.getDefaultBackingIndexName(dataStreamName, i),
- numberOfShards,
- randomIndexWriteLoad(numberOfShards),
- System.currentTimeMillis() - (maxIndexAge.millis() / 2)
- );
- backingIndices.add(indexMetadata.getIndex());
- metadataBuilder.put(indexMetadata, false);
- }
-
- final IndexMetadata writeIndexMetadata = createIndexMetadata(
- DataStream.getDefaultBackingIndexName(dataStreamName, numberOfBackingIndices),
- numberOfShards,
- null,
- System.currentTimeMillis()
+ final ProjectMetadata.Builder metadataBuilder = createMetadataBuilderWithDataStream(
+ dataStreamName,
+ numberOfBackingIndices,
+ randomIntBetween(1, 5),
+ maxIndexAge
);
- backingIndices.add(writeIndexMetadata.getIndex());
- metadataBuilder.put(writeIndexMetadata, false);
-
- final DataStream dataStream = createDataStream(dataStreamName, backingIndices);
- metadataBuilder.put(dataStream);
+ final DataStream dataStream = metadataBuilder.dataStream(dataStreamName);
final ProjectMetadata.Builder updatedMetadataBuilder = writeLoadForecaster.withWriteLoadForecastForWriteIndex(
dataStream.getName(),
@@ -253,7 +240,7 @@ public void testWriteLoadForecast() {
)
);
assertThat(writeLoadForecast.isPresent(), is(true));
- assertThat(writeLoadForecast.getAsDouble(), is(equalTo(14.4)));
+ assertThat(writeLoadForecast.getAsDouble(), is(equalTo(72.0)));
}
{
@@ -264,14 +251,14 @@ public void testWriteLoadForecast() {
.withShardWriteLoad(1, 24, 999, 999, 5)
.withShardWriteLoad(2, 24, 999, 999, 5)
.withShardWriteLoad(3, 24, 999, 999, 5)
- .withShardWriteLoad(4, 24, 999, 999, 4)
+ .withShardWriteLoad(4, 24, 999, 999, 5)
.build(),
// Since this shard uptime is really low, it doesn't add much to the avg
IndexWriteLoad.builder(1).withShardWriteLoad(0, 120, 999, 999, 1).build()
)
);
assertThat(writeLoadForecast.isPresent(), is(true));
- assertThat(writeLoadForecast.getAsDouble(), is(equalTo(15.36)));
+ assertThat(writeLoadForecast.getAsDouble(), is(closeTo(72.59, 0.01)));
}
{
@@ -283,7 +270,7 @@ public void testWriteLoadForecast() {
)
);
assertThat(writeLoadForecast.isPresent(), is(true));
- assertThat(writeLoadForecast.getAsDouble(), is(equalTo(12.0)));
+ assertThat(writeLoadForecast.getAsDouble(), is(equalTo(16.0)));
}
{
@@ -302,7 +289,7 @@ public void testWriteLoadForecast() {
)
);
assertThat(writeLoadForecast.isPresent(), is(true));
- assertThat(writeLoadForecast.getAsDouble(), is(closeTo(15.83, 0.01)));
+ assertThat(writeLoadForecast.getAsDouble(), is(closeTo(31.66, 0.01)));
}
}
@@ -404,4 +391,163 @@ public boolean innerMatch(LogEvent event) {
);
}, LicensedWriteLoadForecaster.class, collectingLoggingAssertion);
}
+
+ public void testShardIncreaseDoesNotIncreaseTotalLoad() {
+ testShardChangeDoesNotChangeTotalForecastLoad(ShardCountChange.INCREASE);
+ }
+
+ public void testShardDecreaseDoesNotDecreaseTotalLoad() {
+ testShardChangeDoesNotChangeTotalForecastLoad(ShardCountChange.DECREASE);
+ }
+
+ private void testShardChangeDoesNotChangeTotalForecastLoad(ShardCountChange shardCountChange) {
+ final TimeValue maxIndexAge = TimeValue.timeValueDays(7);
+ final AtomicBoolean hasValidLicense = new AtomicBoolean(true);
+ final AtomicInteger licenseCheckCount = new AtomicInteger();
+ final WriteLoadForecaster writeLoadForecaster = new LicensedWriteLoadForecaster(() -> {
+ licenseCheckCount.incrementAndGet();
+ return hasValidLicense.get();
+ }, threadPool, maxIndexAge);
+ writeLoadForecaster.refreshLicense();
+
+ final String dataStreamName = randomIdentifier();
+ final ProjectMetadata.Builder originalMetadata = writeLoadForecaster.withWriteLoadForecastForWriteIndex(
+ dataStreamName,
+ createMetadataBuilderWithDataStream(dataStreamName, randomIntBetween(5, 15), shardCountChange.originalShardCount(), maxIndexAge)
+ );
+
+ // Generate the same data stream, but with a different number of shards in the write index
+ final ProjectMetadata.Builder changedShardCountMetadata = writeLoadForecaster.withWriteLoadForecastForWriteIndex(
+ dataStreamName,
+ updateWriteIndexShardCount(dataStreamName, originalMetadata, shardCountChange)
+ );
+
+ IndexMetadata originalWriteIndexMetadata = originalMetadata.getSafe(originalMetadata.dataStream(dataStreamName).getWriteIndex());
+ IndexMetadata changedShardCountWriteIndexMetadata = changedShardCountMetadata.getSafe(
+ changedShardCountMetadata.dataStream(dataStreamName).getWriteIndex()
+ );
+
+ // The shard count changed
+ assertThat(
+ changedShardCountWriteIndexMetadata.getNumberOfShards(),
+ shardCountChange.expectedChangeFromOriginal(originalWriteIndexMetadata.getNumberOfShards())
+ );
+ // But the total write-load did not
+ assertThat(
+ changedShardCountWriteIndexMetadata.getNumberOfShards() * writeLoadForecaster.getForecastedWriteLoad(
+ changedShardCountWriteIndexMetadata
+ ).getAsDouble(),
+ closeTo(
+ originalWriteIndexMetadata.getNumberOfShards() * writeLoadForecaster.getForecastedWriteLoad(originalWriteIndexMetadata)
+ .getAsDouble(),
+ 0.01
+ )
+ );
+ }
+
+ public enum ShardCountChange implements IntToIntFunction {
+ INCREASE(1, 15) {
+ @Override
+ public int apply(int originalShardCount) {
+ return randomIntBetween(originalShardCount + 1, originalShardCount * 3);
+ }
+
+ public Matcher expectedChangeFromOriginal(int originalShardCount) {
+ return greaterThan(originalShardCount);
+ }
+ },
+ DECREASE(10, 30) {
+ @Override
+ public int apply(int originalShardCount) {
+ return randomIntBetween(1, originalShardCount - 1);
+ }
+
+ public Matcher expectedChangeFromOriginal(int originalShardCount) {
+ return lessThan(originalShardCount);
+ }
+ };
+
+ private final int originalMinimumShardCount;
+ private final int originalMaximumShardCount;
+
+ ShardCountChange(int originalMinimumShardCount, int originalMaximumShardCount) {
+ this.originalMinimumShardCount = originalMinimumShardCount;
+ this.originalMaximumShardCount = originalMaximumShardCount;
+ }
+
+ public int originalShardCount() {
+ return randomIntBetween(originalMinimumShardCount, originalMaximumShardCount);
+ }
+
+ abstract Matcher expectedChangeFromOriginal(int originalShardCount);
+ }
+
+ private ProjectMetadata.Builder updateWriteIndexShardCount(
+ String dataStreamName,
+ ProjectMetadata.Builder originalMetadata,
+ ShardCountChange shardCountChange
+ ) {
+ final ProjectMetadata.Builder updatedShardCountMetadata = ProjectMetadata.builder(originalMetadata.getId());
+
+ final DataStream originalDataStream = originalMetadata.dataStream(dataStreamName);
+ final Index existingWriteIndex = Objects.requireNonNull(originalDataStream.getWriteIndex());
+ final IndexMetadata originalWriteIndexMetadata = originalMetadata.getSafe(existingWriteIndex);
+
+ // Copy all non-write indices over unchanged
+ final List backingIndexMetadatas = originalDataStream.getIndices()
+ .stream()
+ .filter(index -> index != existingWriteIndex)
+ .map(originalMetadata::getSafe)
+ .collect(Collectors.toList());
+
+ // Create a new write index with an updated shard count
+ final IndexMetadata writeIndexMetadata = createIndexMetadata(
+ DataStream.getDefaultBackingIndexName(dataStreamName, backingIndexMetadatas.size()),
+ shardCountChange.apply(originalWriteIndexMetadata.getNumberOfShards()),
+ null,
+ System.currentTimeMillis()
+ );
+ backingIndexMetadatas.add(writeIndexMetadata);
+ backingIndexMetadatas.forEach(indexMetadata -> updatedShardCountMetadata.put(indexMetadata, false));
+
+ final DataStream dataStream = createDataStream(
+ dataStreamName,
+ backingIndexMetadatas.stream().map(IndexMetadata::getIndex).toList()
+ );
+ updatedShardCountMetadata.put(dataStream);
+ return updatedShardCountMetadata;
+ }
+
+ private ProjectMetadata.Builder createMetadataBuilderWithDataStream(
+ String dataStreamName,
+ int numberOfBackingIndices,
+ int numberOfShards,
+ TimeValue maxIndexAge
+ ) {
+ final ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(randomProjectIdOrDefault());
+ final List backingIndices = new ArrayList<>();
+ for (int i = 0; i < numberOfBackingIndices; i++) {
+ final IndexMetadata indexMetadata = createIndexMetadata(
+ DataStream.getDefaultBackingIndexName(dataStreamName, i),
+ numberOfShards,
+ randomIndexWriteLoad(numberOfShards),
+ System.currentTimeMillis() - (maxIndexAge.millis() / 2)
+ );
+ backingIndices.add(indexMetadata.getIndex());
+ metadataBuilder.put(indexMetadata, false);
+ }
+
+ final IndexMetadata writeIndexMetadata = createIndexMetadata(
+ DataStream.getDefaultBackingIndexName(dataStreamName, numberOfBackingIndices),
+ numberOfShards,
+ null,
+ System.currentTimeMillis()
+ );
+ backingIndices.add(writeIndexMetadata.getIndex());
+ metadataBuilder.put(writeIndexMetadata, false);
+
+ final DataStream dataStream = createDataStream(dataStreamName, backingIndices);
+ metadataBuilder.put(dataStream);
+ return metadataBuilder;
+ }
}