diff --git a/akka-projection-grpc/src/main/scala/akka/projection/grpc/internal/FilterStage.scala b/akka-projection-grpc/src/main/scala/akka/projection/grpc/internal/FilterStage.scala index 8814c2120..b1db1b8ef 100644 --- a/akka-projection-grpc/src/main/scala/akka/projection/grpc/internal/FilterStage.scala +++ b/akka-projection-grpc/src/main/scala/akka/projection/grpc/internal/FilterStage.scala @@ -4,11 +4,12 @@ package akka.projection.grpc.internal +import akka.Done + import scala.util.Failure import scala.util.Success import scala.util.Try import scala.util.matching.Regex - import akka.NotUsed import akka.actor.typed.scaladsl.LoggerOps import akka.annotation.InternalApi @@ -28,6 +29,7 @@ import akka.stream.Attributes import akka.stream.BidiShape import akka.stream.Inlet import akka.stream.Outlet +import akka.stream.scaladsl.Keep import akka.stream.scaladsl.Sink import akka.stream.scaladsl.SinkQueueWithCancel import akka.stream.scaladsl.Source @@ -37,6 +39,8 @@ import akka.stream.stage.InHandler import akka.stream.stage.OutHandler import org.slf4j.LoggerFactory +import scala.concurrent.Future + /** * INTERNAL API */ @@ -145,7 +149,8 @@ import org.slf4j.LoggerFactory private case class ReplaySession( fromSeqNr: Long, filterAfterSeqNr: Long, - queue: SinkQueueWithCancel[EventEnvelope[Any]]) + queue: SinkQueueWithCancel[EventEnvelope[Any]], + replayStreamCompletion: Future[Done]) } @@ -356,11 +361,14 @@ import org.slf4j.LoggerFactory replayInProgress(pid).copy(filterAfterSeqNr = replayPersistenceId.filterAfterSeqNr)) } else if (replayInProgress.size < replayParallelism) { log.debugN("Stream [{}]: Starting replay of persistenceId [{}], from seqNr [{}]", logPrefix, pid, fromSeqNr) - val queue = + val (replayCompletion, queue) = currentEventsByPersistenceId(pid, fromSeqNr) - .runWith(Sink.queue())(materializer) - replayInProgress = - replayInProgress.updated(pid, ReplaySession(fromSeqNr, replayPersistenceId.filterAfterSeqNr, queue)) + .watchTermination()((_, done) => done) + .toMat(Sink.queue())(Keep.both) + .run()(materializer) + replayInProgress = replayInProgress.updated( + pid, + ReplaySession(fromSeqNr, replayPersistenceId.filterAfterSeqNr, queue, replayCompletion)) tryPullReplay(pid) } else { log.debugN("Stream [{}]: Queueing replay of persistenceId [{}], from seqNr [{}]", logPrefix, pid, fromSeqNr) @@ -457,12 +465,36 @@ import org.slf4j.LoggerFactory } }) - setHandler(outEnv, new OutHandler { - override def onPull(): Unit = { - log.trace("Stream [{}]: onPull outEnv", logPrefix) - pullInEnvOrReplay() - } - }) + setHandler( + outEnv, + new OutHandler { + override def onPull(): Unit = { + log.trace("Stream [{}]: onPull outEnv", logPrefix) + pullInEnvOrReplay() + } + + override def onDownstreamFinish(cause: Throwable): Unit = { + val runningSessions = replayInProgress.values.filterNot(_.replayStreamCompletion.isCompleted) + if (runningSessions.nonEmpty) { + // to avoid abrupt stage termination error logging, + // defer acting on cancel until any replays have completely cancelled + setKeepGoing(true) + val replayCompletedCallback = getAsyncCallback[Try[Done]] { _ => + val stillRunning = replayInProgress.values.filterNot(_.replayStreamCompletion.isCompleted) + if (stillRunning.isEmpty) { + super.onDownstreamFinish(cause) + } + }.invoke _ + runningSessions.foreach { runningSession => + runningSession.queue.cancel() + runningSession.replayStreamCompletion.onComplete(replayCompletedCallback)(ExecutionContexts.parasitic) + } + } else { + super.onDownstreamFinish(cause) + } + } + }) + } }