From d5497886c0f58a2b9238eafcf36fccb7025e3d24 Mon Sep 17 00:00:00 2001
From: Nikita
Date: Fri, 24 Oct 2025 15:19:33 -0700
Subject: [PATCH] KAFKA-19831: Improved error handling in DefaultStateUpdater.

---
 .../StateUpdaterFailureIntegrationTest.java   | 136 ++++++++++++++++++
 .../internals/DefaultStateUpdater.java        |  83 +++++------
 .../processor/internals/TaskAndAction.java    |   7 +-
 .../processor/internals/TaskManager.java      |  19 ++-
 .../internals/DefaultStateUpdaterTest.java    | 134 +++++++++++++++--
 .../processor/internals/TaskManagerTest.java  |  39 ++++-
 6 files changed, 347 insertions(+), 71 deletions(-)
 create mode 100644 streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StateUpdaterFailureIntegrationTest.java

diff --git a/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StateUpdaterFailureIntegrationTest.java b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StateUpdaterFailureIntegrationTest.java
new file mode 100644
index 0000000000000..ffc21ab29525a
--- /dev/null
+++ b/streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/StateUpdaterFailureIntegrationTest.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.integration;
+
+import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.streams.KafkaStreams;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.TopologyWrapper;
+import org.apache.kafka.streams.errors.ProcessorStateException;
+import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.state.KeyValueStore;
+import org.apache.kafka.streams.state.StoreBuilder;
+import org.apache.kafka.streams.state.internals.AbstractStoreBuilder;
+import org.apache.kafka.test.MockApiProcessorSupplier;
+import org.apache.kafka.test.MockKeyValueStore;
+import org.apache.kafka.test.TestUtils;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInfo;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.Properties;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
+
+import static org.apache.kafka.streams.utils.TestUtils.safeUniqueTestName;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class StateUpdaterFailureIntegrationTest {
+
+    private static final int NUM_BROKERS = 1;
+    protected static final String INPUT_TOPIC_NAME = "input-topic";
+    private static final int NUM_PARTITIONS = 6;
+
+    private final EmbeddedKafkaCluster cluster = new EmbeddedKafkaCluster(NUM_BROKERS);
+
+    private Properties streamsConfiguration;
+    private final MockTime mockTime = cluster.time;
+    private KafkaStreams streams;
+
+    @BeforeEach
+    public void before(final TestInfo testInfo) throws InterruptedException, IOException {
+        cluster.start();
+        cluster.createTopic(INPUT_TOPIC_NAME, NUM_PARTITIONS, 1);
+        streamsConfiguration = new Properties();
+        final String safeTestName = safeUniqueTestName(testInfo);
+        streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, "app-" + safeTestName);
+        streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, cluster.bootstrapServers());
+        streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");
+        streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath());
+        streamsConfiguration.put(StreamsConfig.STATESTORE_CACHE_MAX_BYTES_CONFIG, 0);
+        streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L);
+        streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.IntegerSerde.class);
+        streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class);
+        streamsConfiguration.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 2);
+    }
+
+    @AfterEach
+    public void after() {
+        if (streams != null) {
+            streams.close(Duration.ofSeconds(30));
+        }
+        cluster.stop();
+    }
+
+    @Test
+    public void correctlyHandleFlushErrorsDuringRebalance() throws InterruptedException {
+        final AtomicInteger numberOfStoreInits = new AtomicInteger();
+        final AtomicReference<KafkaStreams.State> currentState = new AtomicReference<>();
+
+        final StoreBuilder<KeyValueStore<Object, Object>> storeBuilder =
+            new AbstractStoreBuilder<>("testStateStore", Serdes.Integer(), Serdes.ByteArray(), new MockTime()) {
+
+                @Override
+                public KeyValueStore<Object, Object> build() {
+                    return new MockKeyValueStore(name, false) {
+
+                        @Override
+                        public void init(final StateStoreContext stateStoreContext, final StateStore root) {
+                            super.init(stateStoreContext, root);
+                            numberOfStoreInits.incrementAndGet();
+                        }
+
+                        @Override
+                        public void flush() {
+                            if (numberOfStoreInits.get() == NUM_PARTITIONS * 1.5) {
+                                try {
+                                    TestUtils.waitForCondition(() -> currentState.get() == KafkaStreams.State.PENDING_SHUTDOWN,
+                                        "Streams never reached PENDING_SHUTDOWN state");
+                                } catch (final InterruptedException e) {
+                                    throw new RuntimeException(e);
+                                }
+                                throw new ProcessorStateException("flush");
+                            }
+                        }
+                    };
+                }
+            };
+
+        final TopologyWrapper topology = new TopologyWrapper();
+        topology.addSource("ingest", INPUT_TOPIC_NAME);
+        topology.addProcessor("my-processor", new MockApiProcessorSupplier<>(), "ingest");
+        topology.addStateStore(storeBuilder, "my-processor");
+
+        streams = new KafkaStreams(topology, streamsConfiguration);
+        streams.setStateListener((newState, oldState) -> currentState.set(newState));
+        streams.start();
+
+        TestUtils.waitForCondition(() -> currentState.get() == KafkaStreams.State.RUNNING, "Streams never reached RUNNING state");
+
+        streams.removeStreamThread();
+
+        TestUtils.waitForCondition(() -> numberOfStoreInits.get() == NUM_PARTITIONS * 1.5, "Streams never reinitialized the store enough times");
+
+        assertTrue(streams.close(Duration.ofSeconds(60)));
+    }
+}
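Review note, not part of the patch: the NUM_PARTITIONS * 1.5 trigger encodes the rebalance this test provokes. With 6 partitions and 2 stream threads, startup presumably initializes the store once per task (6 inits); removeStreamThread() migrates the removed thread's 3 tasks to the surviving thread, which re-initializes them during restoration. The 9th init therefore arms a flush that blocks until the client reaches PENDING_SHUTDOWN and then throws, so the failure lands exactly while the thread removal and shutdown are in flight. A sketch of the arithmetic, with hypothetical names:

    // Sketch only: derives the test's NUM_PARTITIONS * 1.5 trigger.
    static int expectedStoreInits(final int numPartitions, final int numStreamThreads) {
        final int initialInits = numPartitions;                     // one store init per task at startup
        final int migratedInits = numPartitions / numStreamThreads; // tasks re-initialized after removeStreamThread()
        return initialInits + migratedInits;                        // 6 + 3 == 9 == 6 * 1.5
    }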
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
index a3a44f6f02d31..fccb670c04caf 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdater.java
@@ -208,11 +208,7 @@ private void performActionsOnTasks() {
                         addTask(taskAndAction.task());
                         break;
                     case REMOVE:
-                        if (taskAndAction.futureForRemove() == null) {
-                            removeTask(taskAndAction.taskId());
-                        } else {
-                            removeTask(taskAndAction.taskId(), taskAndAction.futureForRemove());
-                        }
+                        removeTask(taskAndAction.taskId(), taskAndAction.futureForRemove());
                         break;
                     default:
                         throw new IllegalStateException("Unknown action type " + action);
@@ -349,23 +345,26 @@ private void handleTaskCorruptedException(final TaskCorruptedException taskCorru
         // TODO: we can let the exception encode the actual corrupted changelog partitions and only
         //       mark those instead of marking all changelogs
         private void removeCheckpointForCorruptedTask(final Task task) {
-            task.markChangelogAsCorrupted(task.changelogPartitions());
+            try {
+                task.markChangelogAsCorrupted(task.changelogPartitions());
 
-            // we need to enforce a checkpoint that removes the corrupted partitions
-            measureCheckpointLatency(() -> task.maybeCheckpoint(true));
+                // we need to enforce a checkpoint that removes the corrupted partitions
+                measureCheckpointLatency(() -> task.maybeCheckpoint(true));
+            } catch (final StreamsException e) {
+                log.warn("Checkpoint failed for corrupted task {}", task.id(), e);
+            }
         }
 
         private void handleStreamsException(final StreamsException streamsException) {
             log.info("Encountered streams exception: ", streamsException);
             if (streamsException.taskId().isPresent()) {
-                handleStreamsExceptionWithTask(streamsException);
+                handleStreamsExceptionWithTask(streamsException, streamsException.taskId().get());
             } else {
                 handleStreamsExceptionWithoutTask(streamsException);
             }
         }
 
-        private void handleStreamsExceptionWithTask(final StreamsException streamsException) {
-            final TaskId failedTaskId = streamsException.taskId().get();
+        private void handleStreamsExceptionWithTask(final StreamsException streamsException, final TaskId failedTaskId) {
             if (updatingTasks.containsKey(failedTaskId)) {
                 addToExceptionsAndFailedTasksThenRemoveFromUpdatingTasks(
                     new ExceptionAndTask(streamsException, updatingTasks.get(failedTaskId))
@@ -518,7 +517,7 @@ private void removeTask(final TaskId taskId,
                                 final CompletableFuture<StateUpdater.RemovedTaskResult> future) {
             measureCheckpointLatency(() -> task.maybeCheckpoint(true));
-            final Collection<TopicPartition> changelogPartitions = task.changelogPartitions();
-            changelogReader.unregister(changelogPartitions);
-            removedTasks.add(task);
+            pausedTasks.put(taskId, task);
             updatingTasks.remove(taskId);
             if (task.isActive()) {
                 transitToUpdateStandbysIfOnlyStandbysLeft();
             }
             log.info((task.isActive() ? "Active" : "Standby")
-                + " task " + task.id() + " was removed from the updating tasks and added to the removed tasks.");
-            } else if (pausedTasks.containsKey(taskId)) {
-                task = pausedTasks.get(taskId);
-                final Collection<TopicPartition> changelogPartitions = task.changelogPartitions();
-                changelogReader.unregister(changelogPartitions);
-                removedTasks.add(task);
-                pausedTasks.remove(taskId);
-                log.info((task.isActive() ? "Active" : "Standby")
-                    + " task " + task.id() + " was removed from the paused tasks and added to the removed tasks.");
-            } else {
-                log.info("Task " + taskId + " was not removed since it is not updating or paused.");
-            }
-        }
+                + " task " + task.id() + " was paused from the updating tasks and added to the paused tasks.");
 
-        private void pauseTask(final Task task) {
-            final TaskId taskId = task.id();
-            // do not need to unregister changelog partitions for paused tasks
-            measureCheckpointLatency(() -> task.maybeCheckpoint(true));
-            pausedTasks.put(taskId, task);
-            updatingTasks.remove(taskId);
-            if (task.isActive()) {
-                transitToUpdateStandbysIfOnlyStandbysLeft();
+            } catch (final StreamsException streamsException) {
+                handleStreamsExceptionWithTask(streamsException, taskId);
             }
-            log.info((task.isActive() ? "Active" : "Standby")
-                + " task " + task.id() + " was paused from the updating tasks and added to the paused tasks.");
         }
 
         private void resumeTask(final Task task) {
@@ -671,11 +648,15 @@ private void maybeCompleteRestoration(final StreamTask task,
                                               final Set<TopicPartition> restoredChangelogs) {
             final Collection<TopicPartition> changelogPartitions = task.changelogPartitions();
             if (restoredChangelogs.containsAll(changelogPartitions)) {
-                measureCheckpointLatency(() -> task.maybeCheckpoint(true));
-                changelogReader.unregister(changelogPartitions);
-                addToRestoredTasks(task);
-                log.info("Stateful active task " + task.id() + " completed restoration");
-                transitToUpdateStandbysIfOnlyStandbysLeft();
+                try {
+                    measureCheckpointLatency(() -> task.maybeCheckpoint(true));
+                    changelogReader.unregister(changelogPartitions);
+                    addToRestoredTasks(task);
+                    log.info("Stateful active task " + task.id() + " completed restoration");
+                    transitToUpdateStandbysIfOnlyStandbysLeft();
+                } catch (final StreamsException streamsException) {
+                    handleStreamsExceptionWithTask(streamsException, task.id());
+                }
             }
         }
 
@@ -707,8 +688,12 @@ private void maybeCheckpointTasks(final long now) {
             measureCheckpointLatency(() -> {
                 for (final Task task : updatingTasks.values()) {
-                    // do not enforce checkpointing during restoration if its position has not advanced much
-                    task.maybeCheckpoint(false);
+                    try {
+                        // do not enforce checkpointing during restoration if its position has not advanced much
+                        task.maybeCheckpoint(false);
+                    } catch (final StreamsException streamsException) {
+                        handleStreamsExceptionWithTask(streamsException, task.id());
+                    }
                 }
             });
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskAndAction.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskAndAction.java
index b9c07151cfa01..ec6c6830bbdf1 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskAndAction.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskAndAction.java
@@ -55,11 +55,6 @@ public static TaskAndAction createRemoveTask(final TaskId taskId,
         return new TaskAndAction(null, taskId, Action.REMOVE, future);
     }
 
-    public static TaskAndAction createRemoveTask(final TaskId taskId) {
-        Objects.requireNonNull(taskId, "Task ID of task to remove is null!");
-        return new TaskAndAction(null, taskId, Action.REMOVE, null);
-    }
-
     public Task task() {
         if (action != Action.ADD) {
             throw new IllegalStateException("Action type " + action + " cannot have a task!");
@@ -84,4 +79,4 @@ public CompletableFuture<StateUpdater.RemovedTaskResult> futureForRemove() {
     public Action action() {
         return action;
     }
-}
\ No newline at end of file
+}
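Review note: every task.maybeCheckpoint(...) call on the state-updater thread now follows the same guard shape, so a flush failure fails only the offending task instead of killing the thread. A self-contained sketch of the pattern with stub types standing in for the real classes (not the actual DefaultStateUpdater code):

    // Stub types; the real ones live in org.apache.kafka.streams.processor.internals.
    interface Task {
        String id();
        void maybeCheckpoint(boolean enforce); // may throw a StreamsException subtype
    }

    final class CheckpointGuard {
        private final java.util.Map<String, RuntimeException> failedTasks = new java.util.HashMap<>();

        void checkpointGuarded(final Task task) {
            try {
                task.maybeCheckpoint(true);
            } catch (final RuntimeException e) {     // StreamsException in the real code
                failedTasks.put(task.id(), e);       // fail one task, keep the updater thread alive
            }
        }
    }

A related consequence of the TaskAndAction change above: with the future-less overload gone, every REMOVE action is built through the single factory and always carries a non-null future, which is what let performActionsOnTasks drop its null check on futureForRemove().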
diff --git a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
index 67d009b037f78..20d97ae9f4621 100644
--- a/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
+++ b/streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java
@@ -65,6 +65,7 @@
 import java.util.TreeSet;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
@@ -772,7 +773,7 @@ private StateUpdater.RemovedTaskResult waitForFuture(final TaskId taskId,
                                                          final CompletableFuture<StateUpdater.RemovedTaskResult> future) {
         final StateUpdater.RemovedTaskResult removedTaskResult;
         try {
-            removedTaskResult = future.get();
+            removedTaskResult = future.get(1, TimeUnit.MINUTES);
             if (removedTaskResult == null) {
                 throw new IllegalStateException("Task " + taskId + " was not found in the state updater. "
                     + BUG_ERROR_MESSAGE);
@@ -787,6 +788,10 @@ private StateUpdater.RemovedTaskResult waitForFuture(final TaskId taskId,
             Thread.currentThread().interrupt();
             log.error(INTERRUPTED_ERROR_MESSAGE, shouldNotHappen);
             throw new IllegalStateException(INTERRUPTED_ERROR_MESSAGE, shouldNotHappen);
+        } catch (final java.util.concurrent.TimeoutException timeoutException) {
+            log.warn("The state updater wasn't able to remove task {} in time. The state updater thread may be dead. "
+                + BUG_ERROR_MESSAGE, taskId, timeoutException);
+            return null;
         }
     }
 
@@ -1567,6 +1572,12 @@ void shutdown(final boolean clean) {
 
     private void shutdownStateUpdater() {
         if (stateUpdater != null) {
+            // If there are failed tasks, handle them first
+            for (final StateUpdater.ExceptionAndTask exceptionAndTask : stateUpdater.drainExceptionsAndFailedTasks()) {
+                final Task failedTask = exceptionAndTask.task();
+                closeTaskDirty(failedTask, false);
+            }
+
             final Map<TaskId, CompletableFuture<StateUpdater.RemovedTaskResult>> futures = new LinkedHashMap<>();
             for (final Task task : stateUpdater.tasks()) {
                 final CompletableFuture<StateUpdater.RemovedTaskResult> future = stateUpdater.remove(task.id());
@@ -1583,10 +1594,16 @@ private void shutdownStateUpdater() {
             for (final Task task : tasksToCloseDirty) {
                 closeTaskDirty(task, false);
             }
+            // Handle all failures that occurred during the remove process
             for (final StateUpdater.ExceptionAndTask exceptionAndTask : stateUpdater.drainExceptionsAndFailedTasks()) {
                 final Task failedTask = exceptionAndTask.task();
                 closeTaskDirty(failedTask, false);
             }
+
+            // If anything was left unhandled due to timeouts, handle it now
+            for (final Task task : stateUpdater.tasks()) {
+                closeTaskDirty(task, false);
+            }
         }
     }
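Review note: shutdownStateUpdater() now proceeds in four passes: drain tasks that already failed and close them dirty, request removal of the remaining tasks, drain failures raised by those removals, and finally dirty-close anything still registered, such as a task whose removal future timed out because the updater thread died. The bounded wait is the key change; a condensed view of the new waitForFuture behavior, quoting the diff above:

    // Condensed from the patch: a dead updater thread can no longer hang shutdown forever.
    try {
        removedTaskResult = future.get(1, TimeUnit.MINUTES);
    } catch (final java.util.concurrent.TimeoutException timeoutException) {
        // Logged and tolerated; the final cleanup pass dirty-closes the task.
        return null;
    }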
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
index b6d41966257a5..c11dcf5189f25 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/DefaultStateUpdaterTest.java
@@ -23,6 +23,7 @@
 import org.apache.kafka.common.utils.MockTime;
 import org.apache.kafka.common.utils.Time;
 import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.errors.ProcessorStateException;
 import org.apache.kafka.streams.errors.StreamsException;
 import org.apache.kafka.streams.errors.TaskCorruptedException;
 import org.apache.kafka.streams.processor.TaskId;
@@ -73,6 +74,7 @@
 import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.ArgumentMatchers.anyMap;
 import static org.mockito.Mockito.atLeast;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.inOrder;
@@ -1717,6 +1719,114 @@ public void shouldRemoveMetricsWithoutInterference() {
         }
     }
 
+    @Test
+    public void shouldNotFailTheThreadIfMaybeCheckpointFails() throws Exception {
+        final StreamTask activeTask1 = statefulTask(TASK_0_0, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask activeTask2 = statefulTask(TASK_0_1, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask failedStatefulTask = statefulTask(TASK_0_2, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final ProcessorStateException processorStateException = new ProcessorStateException("flush");
+        doThrow(processorStateException).when(failedStatefulTask).maybeCheckpoint(anyBoolean());
+
+        stateUpdater.add(failedStatefulTask);
+        stateUpdater.add(activeTask1);
+        stateUpdater.start();
+        verifyExceptionsAndFailedTasks(new ExceptionAndTask(processorStateException, failedStatefulTask));
+        verifyUpdatingTasks(activeTask1);
+
+        stateUpdater.add(activeTask2);
+        verifyUpdatingTasks(activeTask1, activeTask2);
+    }
+
+    @Test
+    public void shouldNotFailTheThreadIfMaybeCheckpointFailsForCorruptedTask() throws Exception {
+        final StreamTask activeTask1 = statefulTask(TASK_0_0, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask activeTask2 = statefulTask(TASK_0_1, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask failedStatefulTask = statefulTask(TASK_0_2, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final ProcessorStateException processorStateException = new ProcessorStateException("flush");
+        doThrow(processorStateException).when(failedStatefulTask).maybeCheckpoint(anyBoolean());
+
+        final TaskCorruptedException taskCorruptedException = new TaskCorruptedException(Set.of(TASK_0_2));
+        when(changelogReader.restore(Map.of(
+            TASK_0_0, activeTask1,
+            TASK_0_2, failedStatefulTask))
+        ).thenThrow(taskCorruptedException);
+
+        stateUpdater.add(failedStatefulTask);
+        stateUpdater.add(activeTask1);
+        stateUpdater.start();
+        verifyExceptionsAndFailedTasks(new ExceptionAndTask(taskCorruptedException, failedStatefulTask));
+        verifyUpdatingTasks(activeTask1);
+
+        stateUpdater.add(activeTask2);
+        verifyUpdatingTasks(activeTask1, activeTask2);
+    }
+
+    @Test
+    public void shouldNotFailTheThreadIfMaybeCheckpointFailsDuringTaskRemoval() throws Exception {
+        final StreamTask activeTask1 = statefulTask(TASK_0_0, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask activeTask2 = statefulTask(TASK_0_1, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask failedStatefulTask = statefulTask(TASK_0_2, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final ProcessorStateException processorStateException = new ProcessorStateException("flush");
+        final AtomicBoolean throwException = new AtomicBoolean(false);
+        doAnswer(invocation -> {
+            if (throwException.get()) {
+                throw processorStateException;
+            }
+            return null;
+        }).when(failedStatefulTask).maybeCheckpoint(anyBoolean());
+        when(changelogReader.allChangelogsCompleted()).thenReturn(true);
+
+        stateUpdater.add(failedStatefulTask);
+        stateUpdater.add(activeTask1);
+        stateUpdater.start();
+        verifyUpdatingTasks(failedStatefulTask, activeTask1);
+
+        throwException.set(true);
+        final ExecutionException exception = assertThrows(ExecutionException.class, () -> stateUpdater.remove(TASK_0_2).get());
+        assertEquals(processorStateException, exception.getCause());
+
+        stateUpdater.add(activeTask2);
+        verifyUpdatingTasks(activeTask1, activeTask2);
+    }
+
+    @Test
+    public void shouldNotFailTheThreadIfMaybeCheckpointFailsDuringTaskPause() throws Exception {
+        final StreamTask activeTask1 = statefulTask(TASK_0_0, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask activeTask2 = statefulTask(TASK_0_1, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask failedStatefulTask = statefulTask(TASK_0_2, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final ProcessorStateException processorStateException = new ProcessorStateException("flush");
+        doThrow(processorStateException).when(failedStatefulTask).maybeCheckpoint(anyBoolean());
+        when(topologyMetadata.isPaused(null)).thenReturn(false).thenReturn(false).thenReturn(true);
+
+        stateUpdater.add(failedStatefulTask);
+        stateUpdater.add(activeTask1);
+        stateUpdater.start();
+        verifyExceptionsAndFailedTasks(new ExceptionAndTask(processorStateException, failedStatefulTask));
+        verifyPausedTasks(activeTask1);
+
+        stateUpdater.add(activeTask2);
+        verifyPausedTasks(activeTask1, activeTask2);
+    }
+
+    @Test
+    public void shouldNotFailTheThreadIfMaybeCheckpointFailsDuringTaskRestore() throws Exception {
+        final StreamTask activeTask1 = statefulTask(TASK_0_0, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask activeTask2 = statefulTask(TASK_0_1, Set.of(TOPIC_PARTITION_A_0)).inState(State.RESTORING).build();
+        final StreamTask failedStatefulTask = statefulTask(TASK_0_2, Set.of(TOPIC_PARTITION_B_0)).inState(State.RESTORING).build();
+        final ProcessorStateException processorStateException = new ProcessorStateException("flush");
+        doThrow(processorStateException).when(failedStatefulTask).maybeCheckpoint(anyBoolean());
+        when(changelogReader.completedChangelogs()).thenReturn(Set.of(TOPIC_PARTITION_B_0));
+
+        stateUpdater.add(failedStatefulTask);
+        stateUpdater.add(activeTask1);
+        stateUpdater.start();
+        verifyExceptionsAndFailedTasks(new ExceptionAndTask(processorStateException, failedStatefulTask));
+        verifyUpdatingTasks(activeTask1);
+
+        stateUpdater.add(activeTask2);
+        verifyUpdatingTasks(activeTask1, activeTask2);
+    }
+
     private static List<MetricName> getMetricNames(final String threadId) {
         final Map<String, String> tagMap = Map.of("thread-id", threadId);
         return List.of(
@@ -1779,7 +1889,8 @@ private void verifyRestoredActiveTasks(final StreamTask... tasks) throws Excepti
                         && restoredTasks.size() == expectedRestoredTasks.size();
                 },
                 VERIFICATION_TIMEOUT,
-                "Did not get all restored active task within the given timeout!"
+                () -> "Did not get all restored active tasks within the given timeout! Expected: "
+                    + expectedRestoredTasks + ", actual: " + restoredTasks
             );
         }
     }
@@ -1794,7 +1905,8 @@ private void verifyDrainingRestoredActiveTasks(final StreamTask... tasks) throws
                         && restoredTasks.size() == expectedRestoredTasks.size();
                 },
                 VERIFICATION_TIMEOUT,
-                "Did not get all restored active task within the given timeout!"
+                () -> "Did not get all restored active tasks within the given timeout! Expected: "
+                    + expectedRestoredTasks + ", actual: " + restoredTasks
            );
         assertTrue(stateUpdater.drainRestoredActiveTasks(Duration.ZERO).isEmpty());
     }
@@ -1816,7 +1928,8 @@ private void verifyUpdatingTasks(final Task... tasks) throws Exception {
                         && updatingTasks.size() == expectedUpdatingTasks.size();
                 },
                 VERIFICATION_TIMEOUT,
-                "Did not get all updating task within the given timeout!"
+                () -> "Did not get all updating tasks within the given timeout! Expected: "
+                    + expectedUpdatingTasks + ", actual: " + updatingTasks
             );
         }
     }
@@ -1831,7 +1944,8 @@ private void verifyUpdatingStandbyTasks(final StandbyTask... tasks) throws Excep
                         && standbyTasks.size() == expectedStandbyTasks.size();
                 },
                 VERIFICATION_TIMEOUT,
-                "Did not see all standby task within the given timeout!"
+                () -> "Did not see all standby tasks within the given timeout! Expected: "
+                    + expectedStandbyTasks + ", actual: " + standbyTasks
             );
     }
@@ -1860,7 +1974,8 @@ private void verifyPausedTasks(final Task... tasks) throws Exception {
                         && pausedTasks.size() == expectedPausedTasks.size();
                 },
                 VERIFICATION_TIMEOUT,
-                "Did not get all paused task within the given timeout!"
+                () -> "Did not get all paused tasks within the given timeout! Expected: "
+                    + expectedPausedTasks + ", actual: " + pausedTasks
             );
         }
     }
@@ -1875,7 +1990,8 @@ private void verifyExceptionsAndFailedTasks(final ExceptionAndTask... exceptions
                         && failedTasks.size() == expectedExceptionAndTasks.size();
                 },
                 VERIFICATION_TIMEOUT,
-                "Did not get all exceptions and failed tasks within the given timeout!"
+                () -> "Did not get all exceptions and failed tasks within the given timeout! Expected: "
+                    + expectedExceptionAndTasks + ", actual: " + failedTasks
             );
     }
@@ -1893,7 +2009,8 @@ private void verifyFailedTasks(final Class<? extends RuntimeException> clazz, fi
                         && failedTasks.size() == expectedFailedTasks.size();
                 },
                 VERIFICATION_TIMEOUT,
-                "Did not get all exceptions and failed tasks within the given timeout!"
+                () -> "Did not get all exceptions and failed tasks within the given timeout! Expected: "
+                    + expectedFailedTasks + ", actual: " + failedTasks
             );
     }
@@ -1911,7 +2028,8 @@ private void verifyDrainingExceptionsAndFailedTasks(final ExceptionAndTask... ex
                         && failedTasks.size() == expectedExceptionAndTasks.size();
                 },
                 VERIFICATION_TIMEOUT,
-                "Did not get all exceptions and failed tasks within the given timeout!"
+                () -> "Did not get all exceptions and failed tasks within the given timeout! Expected: "
+                    + expectedExceptionAndTasks + ", actual: " + failedTasks
             );
         assertFalse(stateUpdater.hasExceptionsAndFailedTasks());
         assertTrue(stateUpdater.drainExceptionsAndFailedTasks().isEmpty());
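Review note: shouldNotFailTheThreadIfMaybeCheckpointFailsDuringTaskRemoval arms its failure mid-test. An AtomicBoolean-gated doAnswer lets the task checkpoint cleanly while it joins the updater, and only the checkpoint triggered by remove() throws. The Mockito pattern in isolation:

    // Stub that starts throwing only once the test arms it.
    final AtomicBoolean throwException = new AtomicBoolean(false);
    doAnswer(invocation -> {
        if (throwException.get()) {
            throw new ProcessorStateException("flush");
        }
        return null;
    }).when(task).maybeCheckpoint(anyBoolean());

    // ... later, after setup has succeeded:
    throwException.set(true); // every subsequent checkpoint now fails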
diff --git a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
index 63cbc441f8a82..bdd244c8d62e9 100644
--- a/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/processor/internals/TaskManagerTest.java
@@ -127,6 +127,7 @@
 import static org.mockito.Mockito.lenient;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
@@ -3289,6 +3290,28 @@ public Set<TopicPartition> changelogPartitions() {
         verify(activeTaskCreator).close();
     }
 
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldCloseTasksIfStateUpdaterTimesOutOnRemove() throws Exception {
+        final TaskManager taskManager = setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, null, false);
+        final Map<TaskId, Set<TopicPartition>> assignment = mkMap(
+            mkEntry(taskId00, taskId00Partitions)
+        );
+        final Task task00 = spy(new StateMachineTask(taskId00, taskId00Partitions, true, stateManager));
+
+        when(activeTaskCreator.createTasks(any(), eq(assignment))).thenReturn(singletonList(task00));
+        taskManager.handleAssignment(assignment, emptyMap());
+
+        when(stateUpdater.tasks()).thenReturn(singleton(task00));
+        final CompletableFuture<StateUpdater.RemovedTaskResult> future = mock(CompletableFuture.class);
+        when(stateUpdater.remove(eq(taskId00))).thenReturn(future);
+        when(future.get(anyLong(), any())).thenThrow(new java.util.concurrent.TimeoutException());
+
+        taskManager.shutdown(true);
+
+        verify(task00).closeDirty();
+    }
+
     @Test
     public void shouldOnlyCommitRevokedStandbyTaskAndPropagatePrepareCommitException() {
         setUpTaskManagerWithoutStateUpdater(ProcessingMode.EXACTLY_ONCE_V2, null, false);
@@ -3456,7 +3479,8 @@ public void shouldShutDownStateUpdaterAndCloseFailedTasksDirty() {
             .thenReturn(Arrays.asList(
                 new ExceptionAndTask(new RuntimeException(), failedStatefulTask),
                 new ExceptionAndTask(new RuntimeException(), failedStandbyTask))
-            );
+            )
+            .thenReturn(Collections.emptyList());
         final TaskManager taskManager = setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
         taskManager.shutdown(true);
@@ -3500,8 +3524,8 @@ public void shouldShutDownStateUpdaterAndCloseDirtyTasksFailedDuringRemoval() {
             removedFailedStatefulTask,
             removedFailedStandbyTask,
             removedFailedStatefulTaskDuringRemoval,
-            removedFailedStandbyTaskDuringRemoval
-        ));
+            removedFailedStandbyTaskDuringRemoval)
+        ).thenReturn(Collections.emptySet());
         final CompletableFuture<StateUpdater.RemovedTaskResult> futureForRemovedStatefulTask = new CompletableFuture<>();
         final CompletableFuture<StateUpdater.RemovedTaskResult> futureForRemovedStandbyTask = new CompletableFuture<>();
         final CompletableFuture<StateUpdater.RemovedTaskResult> futureForRemovedFailedStatefulTask = new CompletableFuture<>();
@@ -3516,10 +3540,11 @@ public void shouldShutDownStateUpdaterAndCloseDirtyTasksFailedDuringRemoval() {
             .thenReturn(futureForRemovedFailedStatefulTaskDuringRemoval);
         when(stateUpdater.remove(removedFailedStandbyTaskDuringRemoval.id()))
             .thenReturn(futureForRemovedFailedStandbyTaskDuringRemoval);
-        when(stateUpdater.drainExceptionsAndFailedTasks()).thenReturn(Arrays.asList(
-            new ExceptionAndTask(new StreamsException("KABOOM!"), removedFailedStatefulTaskDuringRemoval),
-            new ExceptionAndTask(new StreamsException("KABOOM!"), removedFailedStandbyTaskDuringRemoval)
-        ));
+        when(stateUpdater.drainExceptionsAndFailedTasks())
+            .thenReturn(Arrays.asList(
+                new ExceptionAndTask(new StreamsException("KABOOM!"), removedFailedStatefulTaskDuringRemoval),
+                new ExceptionAndTask(new StreamsException("KABOOM!"), removedFailedStandbyTaskDuringRemoval))
+            ).thenReturn(Collections.emptyList());
         final TaskManager taskManager = setUpTaskManagerWithStateUpdater(ProcessingMode.AT_LEAST_ONCE, tasks);
         futureForRemovedStatefulTask.complete(new StateUpdater.RemovedTaskResult(removedStatefulTask));
         futureForRemovedStandbyTask.complete(new StateUpdater.RemovedTaskResult(removedStandbyTask));
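Review note: taken together, the patch changes how removal failures surface to callers of StateUpdater.remove(). A sketch of what a caller observes now, mirroring the new tests above (not a definitive API description):

    // Both failure modes are observable, per DefaultStateUpdaterTest and TaskManager:
    try {
        stateUpdater.remove(taskId).get(1, TimeUnit.MINUTES);
    } catch (final ExecutionException e) {
        // Cause is the StreamsException (e.g. ProcessorStateException) thrown while
        // checkpointing the task during removal.
    } catch (final java.util.concurrent.TimeoutException e) {
        // The updater thread may be dead; TaskManager logs a warning and dirty-closes
        // the task in its final shutdown pass.
    } catch (final InterruptedException e) {
        Thread.currentThread().interrupt();
    }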