From dd9fefecdef1000ca9ca5e23a43c3ab98fb64e7f Mon Sep 17 00:00:00 2001
From: Matthew de Detrich <mdedetrich@gmail.com>
Date: Mon, 9 Sep 2024 09:15:21 +0200
Subject: [PATCH] Add overridden duration timeout to StreamTestKit

---
 .../testkit/javadsl/StreamTestKit.scala       | 20 +++++++++-
 .../testkit/scaladsl/StreamTestKit.scala      | 31 ++++++++++++--
 .../stream/testkit/StreamConfiguration.scala  | 40 +++++++++++++++++++
 .../pekko/stream/testkit/StreamSpec.scala     |  8 +++-
 4 files changed, 92 insertions(+), 7 deletions(-)
 create mode 100644 stream-testkit/src/test/scala/org/apache/pekko/stream/testkit/StreamConfiguration.scala

diff --git a/stream-testkit/src/main/scala/org/apache/pekko/stream/testkit/javadsl/StreamTestKit.scala b/stream-testkit/src/main/scala/org/apache/pekko/stream/testkit/javadsl/StreamTestKit.scala
index 5f42371bf38..306082a7573 100644
--- a/stream-testkit/src/main/scala/org/apache/pekko/stream/testkit/javadsl/StreamTestKit.scala
+++ b/stream-testkit/src/main/scala/org/apache/pekko/stream/testkit/javadsl/StreamTestKit.scala
@@ -19,6 +19,10 @@ import pekko.stream.{ Materializer, SystemMaterializer }
 import pekko.stream.impl.PhasedFusingActorMaterializer
 import pekko.stream.testkit.scaladsl
 
+import java.time.Duration
+import java.util.concurrent.TimeUnit
+import scala.concurrent.duration.FiniteDuration
+
 object StreamTestKit {
 
   /**
@@ -29,7 +33,21 @@ object StreamTestKit {
   def assertAllStagesStopped(mat: Materializer): Unit =
     mat match {
       case impl: PhasedFusingActorMaterializer =>
-        scaladsl.StreamTestKit.assertNoChildren(impl.system, impl.supervisor)
+        scaladsl.StreamTestKit.assertNoChildren(impl.system, impl.supervisor, None)
+      case _ =>
+    }
+
+  /**
+   * Assert that there are no stages running under a given materializer.
+   * Usually this assertion is run after a test-case to check that all of the
+   * stages have terminated successfully with an overridden duration that ignores
+   * `stream.testkit.all-stages-stopped-timeout`.
+   */
+  def assertAllStagesStopped(mat: Materializer, overrideTimeout: Duration): Unit =
+    mat match {
+      case impl: PhasedFusingActorMaterializer =>
+        scaladsl.StreamTestKit.assertNoChildren(impl.system, impl.supervisor,
+          Some(FiniteDuration(overrideTimeout.toMillis, TimeUnit.MILLISECONDS)))
       case _ =>
     }
 
diff --git a/stream-testkit/src/main/scala/org/apache/pekko/stream/testkit/scaladsl/StreamTestKit.scala b/stream-testkit/src/main/scala/org/apache/pekko/stream/testkit/scaladsl/StreamTestKit.scala
index 85d5d5623b8..3c8b3eef5fd 100644
--- a/stream-testkit/src/main/scala/org/apache/pekko/stream/testkit/scaladsl/StreamTestKit.scala
+++ b/stream-testkit/src/main/scala/org/apache/pekko/stream/testkit/scaladsl/StreamTestKit.scala
@@ -35,12 +35,30 @@ object StreamTestKit {
    * This assertion is useful to check that all of the stages have
    * terminated successfully.
    */
+  def assertAllStagesStopped[T](block: => T, overrideTimeout: FiniteDuration)(implicit materializer: Materializer): T =
+    materializer match {
+      case impl: PhasedFusingActorMaterializer =>
+        stopAllChildren(impl.system, impl.supervisor)
+        val result = block
+        assertNoChildren(impl.system, impl.supervisor, Some(overrideTimeout))
+        result
+      case _ => block
+    }
+
+  /**
+   * Asserts that after the given code block is ran, no stages are left over
+   * that were created by the given materializer with an overridden duration
+   * that ignores `stream.testkit.all-stages-stopped-timeout`.
+   *
+   * This assertion is useful to check that all of the stages have
+   * terminated successfully.
+   */
   def assertAllStagesStopped[T](block: => T)(implicit materializer: Materializer): T =
     materializer match {
       case impl: PhasedFusingActorMaterializer =>
         stopAllChildren(impl.system, impl.supervisor)
         val result = block
-        assertNoChildren(impl.system, impl.supervisor)
+        assertNoChildren(impl.system, impl.supervisor, None)
         result
       case _ => block
     }
@@ -53,10 +71,15 @@ object StreamTestKit {
   }
 
   /** INTERNAL API */
-  @InternalApi private[testkit] def assertNoChildren(sys: ActorSystem, supervisor: ActorRef): Unit = {
+  @InternalApi private[testkit] def assertNoChildren(sys: ActorSystem, supervisor: ActorRef,
+      overrideTimeout: Option[FiniteDuration]): Unit = {
     val probe = TestProbe()(sys)
-    val c = sys.settings.config.getConfig("pekko.stream.testkit")
-    val timeout = c.getDuration("all-stages-stopped-timeout", MILLISECONDS).millis
+
+    val timeout = overrideTimeout.getOrElse {
+      val c = sys.settings.config.getConfig("pekko.stream.testkit")
+      c.getDuration("all-stages-stopped-timeout", MILLISECONDS).millis
+    }
+
     probe.within(timeout) {
       try probe.awaitAssert {
           supervisor.tell(StreamSupervisor.GetChildren, probe.ref)
diff --git a/stream-testkit/src/test/scala/org/apache/pekko/stream/testkit/StreamConfiguration.scala b/stream-testkit/src/test/scala/org/apache/pekko/stream/testkit/StreamConfiguration.scala
new file mode 100644
index 00000000000..bd4cf74e7c8
--- /dev/null
+++ b/stream-testkit/src/test/scala/org/apache/pekko/stream/testkit/StreamConfiguration.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.pekko.stream.testkit
+
+import org.apache.pekko.testkit.TestKitBase
+import org.scalatest.time.{ Millis, Span }
+
+import java.util.concurrent.TimeUnit
+
+trait StreamConfiguration extends TestKitBase {
+  final case class StreamConfig(allStagesStoppedTimeout: Span = Span({
+          val c = system.settings.config.getConfig("pekko.stream.testkit")
+          c.getDuration("all-stages-stopped-timeout", TimeUnit.MILLISECONDS)
+        }, Millis))
+
+  private val defaultStreamConfig = StreamConfig()
+
+  /**
+   * The default `StreamConfig` which is derived from the Actor System's `pekko.stream.testkit.all-stages-stopped-timeout`
+   * configuration value. If you want to provide a different StreamConfig for specific tests without having to re-specify
+   * `pekko.stream.testkit.all-stages-stopped-timeout` then you can override this value.
+   */
+  implicit def streamConfig: StreamConfig = defaultStreamConfig
+
+}
diff --git a/stream-testkit/src/test/scala/org/apache/pekko/stream/testkit/StreamSpec.scala b/stream-testkit/src/test/scala/org/apache/pekko/stream/testkit/StreamSpec.scala
index 827a3c4918c..2fc883026aa 100644
--- a/stream-testkit/src/test/scala/org/apache/pekko/stream/testkit/StreamSpec.scala
+++ b/stream-testkit/src/test/scala/org/apache/pekko/stream/testkit/StreamSpec.scala
@@ -30,7 +30,10 @@ import org.scalatest.Failed
 
 import com.typesafe.config.{ Config, ConfigFactory }
 
-abstract class StreamSpec(_system: ActorSystem) extends PekkoSpec(_system) {
+import java.util.concurrent.TimeUnit
+
+abstract class StreamSpec(_system: ActorSystem) extends PekkoSpec(_system) with StreamConfiguration {
+
   def this(config: Config) =
     this(
       ActorSystem(
@@ -73,7 +76,8 @@ abstract class StreamSpec(_system: ActorSystem) extends PekkoSpec(_system) {
           case impl: PhasedFusingActorMaterializer =>
             stopAllChildren(impl.system, impl.supervisor)
             val result = test.apply()
-            assertNoChildren(impl.system, impl.supervisor)
+            assertNoChildren(impl.system, impl.supervisor,
+              Some(FiniteDuration(streamConfig.allStagesStoppedTimeout.millisPart, TimeUnit.MILLISECONDS)))
             result
           case _ => other
         }