diff --git a/docs/changelog/129990.yaml b/docs/changelog/129990.yaml new file mode 100644 index 0000000000000..aa5ca924f2932 --- /dev/null +++ b/docs/changelog/129990.yaml @@ -0,0 +1,5 @@ +pr: 129990 +summary: Make forecast write load accurate when shard numbers change +area: Allocation +type: bug +issues: [] diff --git a/server/src/main/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingService.java b/server/src/main/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingService.java index 61eb2867790a2..334e0c7564f7c 100644 --- a/server/src/main/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingService.java +++ b/server/src/main/java/org/elasticsearch/action/datastreams/autosharding/DataStreamAutoShardingService.java @@ -362,7 +362,7 @@ public String toString() { *

If the recommendation is to INCREASE/DECREASE shards the reported cooldown period will be TimeValue.ZERO. * If the auto sharding service thinks the number of shards must be changed but it can't recommend a change due to the cooldown * period not lapsing, the result will be of type {@link AutoShardingType#COOLDOWN_PREVENTED_INCREASE} or - * {@link AutoShardingType#COOLDOWN_PREVENTED_INCREASE} with the remaining cooldown configured and the number of shards that should + * {@link AutoShardingType#COOLDOWN_PREVENTED_DECREASE} with the remaining cooldown configured and the number of shards that should * be configured for the data stream once the remaining cooldown lapses as the target number of shards. * *

The NOT_APPLICABLE type result will report a cooldown period of TimeValue.MAX_VALUE. diff --git a/x-pack/plugin/write-load-forecaster/src/main/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecaster.java b/x-pack/plugin/write-load-forecaster/src/main/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecaster.java index b9b900836c6fc..db68694af4143 100644 --- a/x-pack/plugin/write-load-forecaster/src/main/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecaster.java +++ b/x-pack/plugin/write-load-forecaster/src/main/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecaster.java @@ -108,7 +108,12 @@ public ProjectMetadata.Builder withWriteLoadForecastForWriteIndex(String dataStr } final IndexMetadata writeIndex = metadata.getSafe(dataStream.getWriteIndex()); - metadata.put(IndexMetadata.builder(writeIndex).indexWriteLoadForecast(forecastIndexWriteLoad.getAsDouble()).build(), false); + metadata.put( + IndexMetadata.builder(writeIndex) + .indexWriteLoadForecast(forecastIndexWriteLoad.getAsDouble() / writeIndex.getNumberOfShards()) + .build(), + false + ); return metadata; } @@ -129,11 +134,20 @@ private static void clearPreviousForecast(DataStream dataStream, ProjectMetadata } } + /** + * This calculates the weighted average total write-load for all recent indices. + * + * @param indicesWriteLoadWithinMaxAgeRange The indices considered "recent" + * @return The weighted average total write-load. To get the per-shard write load, this number must be divided by the number of shards + */ // Visible for testing static OptionalDouble forecastIndexWriteLoad(List indicesWriteLoadWithinMaxAgeRange) { - double totalWeightedWriteLoad = 0; - long totalShardUptime = 0; + double allIndicesWriteLoad = 0; + long allIndicesUptime = 0; for (IndexWriteLoad writeLoad : indicesWriteLoadWithinMaxAgeRange) { + double totalShardWriteLoad = 0; + long totalShardUptimeInMillis = 0; + long maxShardUptimeInMillis = 0; for (int shardId = 0; shardId < writeLoad.numberOfShards(); shardId++) { final OptionalDouble writeLoadForShard = writeLoad.getWriteLoadForShard(shardId); final OptionalLong uptimeInMillisForShard = writeLoad.getUptimeInMillisForShard(shardId); @@ -141,13 +155,27 @@ static OptionalDouble forecastIndexWriteLoad(List indicesWriteLo assert uptimeInMillisForShard.isPresent(); double shardWriteLoad = writeLoadForShard.getAsDouble(); long shardUptimeInMillis = uptimeInMillisForShard.getAsLong(); - totalWeightedWriteLoad += shardWriteLoad * shardUptimeInMillis; - totalShardUptime += shardUptimeInMillis; + totalShardWriteLoad += shardWriteLoad * shardUptimeInMillis; + totalShardUptimeInMillis += shardUptimeInMillis; + maxShardUptimeInMillis = Math.max(maxShardUptimeInMillis, shardUptimeInMillis); } } + double weightedAverageShardWriteLoad = totalShardWriteLoad / totalShardUptimeInMillis; + double totalIndexWriteLoad = weightedAverageShardWriteLoad * writeLoad.numberOfShards(); + // We need to weight the contribution from each index somehow, but we only know + // the write-load from the final allocation of each shard at rollover time. It's + // possible the index is much older than any of those shards, but we don't have + // any write-load data beyond their lifetime. + // To avoid making assumptions about periods for which we have no data, we'll weight + // each index's contribution to the forecast by the maximum shard uptime observed in + // that index. It should be safe to extrapolate our weighted average out to the + // maximum uptime observed, based on the assumption that write-load is roughly + // evenly distributed across shards of a datastream index. + allIndicesWriteLoad += totalIndexWriteLoad * maxShardUptimeInMillis; + allIndicesUptime += maxShardUptimeInMillis; } - return totalShardUptime == 0 ? OptionalDouble.empty() : OptionalDouble.of(totalWeightedWriteLoad / totalShardUptime); + return allIndicesUptime == 0 ? OptionalDouble.empty() : OptionalDouble.of(allIndicesWriteLoad / allIndicesUptime); } @Override diff --git a/x-pack/plugin/write-load-forecaster/src/test/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecasterTests.java b/x-pack/plugin/write-load-forecaster/src/test/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecasterTests.java index d3db529f3cda7..3062e9a82b2ac 100644 --- a/x-pack/plugin/write-load-forecaster/src/test/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecasterTests.java +++ b/x-pack/plugin/write-load-forecaster/src/test/java/org/elasticsearch/xpack/writeloadforecaster/LicensedWriteLoadForecasterTests.java @@ -9,6 +9,7 @@ import org.apache.logging.log4j.Level; import org.apache.logging.log4j.core.LogEvent; +import org.apache.lucene.util.hnsw.IntToIntFunction; import org.elasticsearch.cluster.metadata.DataStream; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.IndexMetadataStats; @@ -24,16 +25,19 @@ import org.elasticsearch.test.MockLog; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; +import org.hamcrest.Matcher; import org.junit.After; import org.junit.Before; import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.OptionalDouble; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; import static org.elasticsearch.xpack.writeloadforecaster.LicensedWriteLoadForecaster.forecastIndexWriteLoad; import static org.hamcrest.Matchers.closeTo; @@ -42,6 +46,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThan; public class LicensedWriteLoadForecasterTests extends ESTestCase { ThreadPool threadPool; @@ -67,33 +72,15 @@ public void testWriteLoadForecastIsAddedToWriteIndex() { writeLoadForecaster.refreshLicense(); - final ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(randomProjectIdOrDefault()); final String dataStreamName = "logs-es"; final int numberOfBackingIndices = 10; - final int numberOfShards = randomIntBetween(1, 5); - final List backingIndices = new ArrayList<>(); - for (int i = 0; i < numberOfBackingIndices; i++) { - final IndexMetadata indexMetadata = createIndexMetadata( - DataStream.getDefaultBackingIndexName(dataStreamName, i), - numberOfShards, - randomIndexWriteLoad(numberOfShards), - System.currentTimeMillis() - (maxIndexAge.millis() / 2) - ); - backingIndices.add(indexMetadata.getIndex()); - metadataBuilder.put(indexMetadata, false); - } - - final IndexMetadata writeIndexMetadata = createIndexMetadata( - DataStream.getDefaultBackingIndexName(dataStreamName, numberOfBackingIndices), - numberOfShards, - null, - System.currentTimeMillis() + final ProjectMetadata.Builder metadataBuilder = createMetadataBuilderWithDataStream( + dataStreamName, + numberOfBackingIndices, + randomIntBetween(1, 5), + maxIndexAge ); - backingIndices.add(writeIndexMetadata.getIndex()); - metadataBuilder.put(writeIndexMetadata, false); - - final DataStream dataStream = createDataStream(dataStreamName, backingIndices); - metadataBuilder.put(dataStream); + final DataStream dataStream = metadataBuilder.dataStream(dataStreamName); final ProjectMetadata.Builder updatedMetadataBuilder = writeLoadForecaster.withWriteLoadForecastForWriteIndex( dataStream.getName(), @@ -253,7 +240,7 @@ public void testWriteLoadForecast() { ) ); assertThat(writeLoadForecast.isPresent(), is(true)); - assertThat(writeLoadForecast.getAsDouble(), is(equalTo(14.4))); + assertThat(writeLoadForecast.getAsDouble(), is(equalTo(72.0))); } { @@ -264,14 +251,14 @@ public void testWriteLoadForecast() { .withShardWriteLoad(1, 24, 999, 999, 5) .withShardWriteLoad(2, 24, 999, 999, 5) .withShardWriteLoad(3, 24, 999, 999, 5) - .withShardWriteLoad(4, 24, 999, 999, 4) + .withShardWriteLoad(4, 24, 999, 999, 5) .build(), // Since this shard uptime is really low, it doesn't add much to the avg IndexWriteLoad.builder(1).withShardWriteLoad(0, 120, 999, 999, 1).build() ) ); assertThat(writeLoadForecast.isPresent(), is(true)); - assertThat(writeLoadForecast.getAsDouble(), is(equalTo(15.36))); + assertThat(writeLoadForecast.getAsDouble(), is(closeTo(72.59, 0.01))); } { @@ -283,7 +270,7 @@ public void testWriteLoadForecast() { ) ); assertThat(writeLoadForecast.isPresent(), is(true)); - assertThat(writeLoadForecast.getAsDouble(), is(equalTo(12.0))); + assertThat(writeLoadForecast.getAsDouble(), is(equalTo(16.0))); } { @@ -302,7 +289,7 @@ public void testWriteLoadForecast() { ) ); assertThat(writeLoadForecast.isPresent(), is(true)); - assertThat(writeLoadForecast.getAsDouble(), is(closeTo(15.83, 0.01))); + assertThat(writeLoadForecast.getAsDouble(), is(closeTo(31.66, 0.01))); } } @@ -404,4 +391,163 @@ public boolean innerMatch(LogEvent event) { ); }, LicensedWriteLoadForecaster.class, collectingLoggingAssertion); } + + public void testShardIncreaseDoesNotIncreaseTotalLoad() { + testShardChangeDoesNotChangeTotalForecastLoad(ShardCountChange.INCREASE); + } + + public void testShardDecreaseDoesNotDecreaseTotalLoad() { + testShardChangeDoesNotChangeTotalForecastLoad(ShardCountChange.DECREASE); + } + + private void testShardChangeDoesNotChangeTotalForecastLoad(ShardCountChange shardCountChange) { + final TimeValue maxIndexAge = TimeValue.timeValueDays(7); + final AtomicBoolean hasValidLicense = new AtomicBoolean(true); + final AtomicInteger licenseCheckCount = new AtomicInteger(); + final WriteLoadForecaster writeLoadForecaster = new LicensedWriteLoadForecaster(() -> { + licenseCheckCount.incrementAndGet(); + return hasValidLicense.get(); + }, threadPool, maxIndexAge); + writeLoadForecaster.refreshLicense(); + + final String dataStreamName = randomIdentifier(); + final ProjectMetadata.Builder originalMetadata = writeLoadForecaster.withWriteLoadForecastForWriteIndex( + dataStreamName, + createMetadataBuilderWithDataStream(dataStreamName, randomIntBetween(5, 15), shardCountChange.originalShardCount(), maxIndexAge) + ); + + // Generate the same data stream, but with a different number of shards in the write index + final ProjectMetadata.Builder changedShardCountMetadata = writeLoadForecaster.withWriteLoadForecastForWriteIndex( + dataStreamName, + updateWriteIndexShardCount(dataStreamName, originalMetadata, shardCountChange) + ); + + IndexMetadata originalWriteIndexMetadata = originalMetadata.getSafe(originalMetadata.dataStream(dataStreamName).getWriteIndex()); + IndexMetadata changedShardCountWriteIndexMetadata = changedShardCountMetadata.getSafe( + changedShardCountMetadata.dataStream(dataStreamName).getWriteIndex() + ); + + // The shard count changed + assertThat( + changedShardCountWriteIndexMetadata.getNumberOfShards(), + shardCountChange.expectedChangeFromOriginal(originalWriteIndexMetadata.getNumberOfShards()) + ); + // But the total write-load did not + assertThat( + changedShardCountWriteIndexMetadata.getNumberOfShards() * writeLoadForecaster.getForecastedWriteLoad( + changedShardCountWriteIndexMetadata + ).getAsDouble(), + closeTo( + originalWriteIndexMetadata.getNumberOfShards() * writeLoadForecaster.getForecastedWriteLoad(originalWriteIndexMetadata) + .getAsDouble(), + 0.01 + ) + ); + } + + public enum ShardCountChange implements IntToIntFunction { + INCREASE(1, 15) { + @Override + public int apply(int originalShardCount) { + return randomIntBetween(originalShardCount + 1, originalShardCount * 3); + } + + public Matcher expectedChangeFromOriginal(int originalShardCount) { + return greaterThan(originalShardCount); + } + }, + DECREASE(10, 30) { + @Override + public int apply(int originalShardCount) { + return randomIntBetween(1, originalShardCount - 1); + } + + public Matcher expectedChangeFromOriginal(int originalShardCount) { + return lessThan(originalShardCount); + } + }; + + private final int originalMinimumShardCount; + private final int originalMaximumShardCount; + + ShardCountChange(int originalMinimumShardCount, int originalMaximumShardCount) { + this.originalMinimumShardCount = originalMinimumShardCount; + this.originalMaximumShardCount = originalMaximumShardCount; + } + + public int originalShardCount() { + return randomIntBetween(originalMinimumShardCount, originalMaximumShardCount); + } + + abstract Matcher expectedChangeFromOriginal(int originalShardCount); + } + + private ProjectMetadata.Builder updateWriteIndexShardCount( + String dataStreamName, + ProjectMetadata.Builder originalMetadata, + ShardCountChange shardCountChange + ) { + final ProjectMetadata.Builder updatedShardCountMetadata = ProjectMetadata.builder(originalMetadata.getId()); + + final DataStream originalDataStream = originalMetadata.dataStream(dataStreamName); + final Index existingWriteIndex = Objects.requireNonNull(originalDataStream.getWriteIndex()); + final IndexMetadata originalWriteIndexMetadata = originalMetadata.getSafe(existingWriteIndex); + + // Copy all non-write indices over unchanged + final List backingIndexMetadatas = originalDataStream.getIndices() + .stream() + .filter(index -> index != existingWriteIndex) + .map(originalMetadata::getSafe) + .collect(Collectors.toList()); + + // Create a new write index with an updated shard count + final IndexMetadata writeIndexMetadata = createIndexMetadata( + DataStream.getDefaultBackingIndexName(dataStreamName, backingIndexMetadatas.size()), + shardCountChange.apply(originalWriteIndexMetadata.getNumberOfShards()), + null, + System.currentTimeMillis() + ); + backingIndexMetadatas.add(writeIndexMetadata); + backingIndexMetadatas.forEach(indexMetadata -> updatedShardCountMetadata.put(indexMetadata, false)); + + final DataStream dataStream = createDataStream( + dataStreamName, + backingIndexMetadatas.stream().map(IndexMetadata::getIndex).toList() + ); + updatedShardCountMetadata.put(dataStream); + return updatedShardCountMetadata; + } + + private ProjectMetadata.Builder createMetadataBuilderWithDataStream( + String dataStreamName, + int numberOfBackingIndices, + int numberOfShards, + TimeValue maxIndexAge + ) { + final ProjectMetadata.Builder metadataBuilder = ProjectMetadata.builder(randomProjectIdOrDefault()); + final List backingIndices = new ArrayList<>(); + for (int i = 0; i < numberOfBackingIndices; i++) { + final IndexMetadata indexMetadata = createIndexMetadata( + DataStream.getDefaultBackingIndexName(dataStreamName, i), + numberOfShards, + randomIndexWriteLoad(numberOfShards), + System.currentTimeMillis() - (maxIndexAge.millis() / 2) + ); + backingIndices.add(indexMetadata.getIndex()); + metadataBuilder.put(indexMetadata, false); + } + + final IndexMetadata writeIndexMetadata = createIndexMetadata( + DataStream.getDefaultBackingIndexName(dataStreamName, numberOfBackingIndices), + numberOfShards, + null, + System.currentTimeMillis() + ); + backingIndices.add(writeIndexMetadata.getIndex()); + metadataBuilder.put(writeIndexMetadata, false); + + final DataStream dataStream = createDataStream(dataStreamName, backingIndices); + metadataBuilder.put(dataStream); + return metadataBuilder; + } }