From cddeae53c32fe1882503f7cdcb5baa22c59b4f9e Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Jun 2025 20:43:50 +0000 Subject: [PATCH 1/3] Initial plan for issue From 0782632b2c6e1bd7f4581a5c9c47f59eb4ed9108 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Jun 2025 20:50:22 +0000 Subject: [PATCH 2/3] feat: initial test case to demonstrate safe cancellation issue Co-authored-by: xerial <57538+xerial@users.noreply.github.com> --- .../airframe/rx/RxSafeCancellationTest.scala | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 airframe-rx/src/test/scala/wvlet/airframe/rx/RxSafeCancellationTest.scala diff --git a/airframe-rx/src/test/scala/wvlet/airframe/rx/RxSafeCancellationTest.scala b/airframe-rx/src/test/scala/wvlet/airframe/rx/RxSafeCancellationTest.scala new file mode 100644 index 000000000..a46548d4e --- /dev/null +++ b/airframe-rx/src/test/scala/wvlet/airframe/rx/RxSafeCancellationTest.scala @@ -0,0 +1,52 @@ +package wvlet.airframe.rx + +import wvlet.airspec.AirSpec + +class RxSafeCancellationTest extends AirSpec { + + test("demonstrate self-cancellation scenario") { + test("cancel within operator function should work safely") { + val results = scala.collection.mutable.ListBuffer[Int]() + val rx = Rx.sequence(1, 2, 3, 4, 5) + + var c = Cancelable.empty + c = rx.map { x => + results += x + if (x == 3) { + // User wants to cancel the stream at this point + c.cancel + } + x * 2 + }.run() + + // We should be able to process up to the point of cancellation + // and not process further elements + assert(results.size <= 3) + results.contains(1) shouldBe true + results.contains(2) shouldBe true + results.contains(3) shouldBe true + } + + test("request completion within operator should work") { + val results = scala.collection.mutable.ListBuffer[Int]() + val rx = Rx.sequence(1, 2, 3, 4, 5) + + // What we'd like: a way to request completion safely + var requestCompletion: () => Unit = () => {} + + val c = rx.map { x => + results += x + if (x == 3) { + // Request completion instead of cancelling + requestCompletion() + } + x * 2 + }.run() + + // After implementing the safe completion mechanism, + // we should process 1, 2, 3 but not 4, 5 + assert(results.size <= 3) + } + } + +} \ No newline at end of file From b20711023bff27d61ba9ee649e6aabc77e13c40d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 2 Jun 2025 20:57:45 +0000 Subject: [PATCH 3/3] feat: add mapWithCompletion and flatMapWithCompletion for safe cancellation Co-authored-by: xerial <57538+xerial@users.noreply.github.com> --- .../src/main/scala/wvlet/airframe/rx/Rx.scala | 28 +++- .../scala/wvlet/airframe/rx/RxRunner.scala | 51 ++++++++ .../airframe/rx/RxSafeCancellationTest.scala | 123 +++++++++++++----- 3 files changed, 165 insertions(+), 37 deletions(-) diff --git a/airframe-rx/src/main/scala/wvlet/airframe/rx/Rx.scala b/airframe-rx/src/main/scala/wvlet/airframe/rx/Rx.scala index a08ce540c..7ceba39fa 100644 --- a/airframe-rx/src/main/scala/wvlet/airframe/rx/Rx.scala +++ b/airframe-rx/src/main/scala/wvlet/airframe/rx/Rx.scala @@ -159,6 +159,16 @@ trait Rx[+A] extends RxOps[A] { */ def map[B](f: A => B): Rx[B] = MapOp(this, f) + /** + * Applies `f` to the input value and return the result. The function can signal completion by returning None. When + * None is returned, the stream will complete gracefully. + * @param f + * function that returns Some(result) to continue, or None to complete the stream + * @tparam B + * @return + */ + def mapWithCompletion[B](f: A => Option[B]): Rx[B] = MapWithCompletionOp(this, f) + /** * Applies `f` to the input value that produces another Rx stream. This method is an alias of flatMap(f) * @param f @@ -175,6 +185,16 @@ trait Rx[+A] extends RxOps[A] { */ def flatMap[B](f: A => RxOps[B]): Rx[B] = FlatMapOp(this, f) + /** + * Applies `f` to the input value that produces another Rx stream. The function can signal completion by returning + * None. When None is returned, the stream will complete gracefully. + * @param f + * function that returns Some(Rx) to continue, or None to complete the stream + * @tparam B + * @return + */ + def flatMapWithCompletion[B](f: A => Option[RxOps[B]]): Rx[B] = FlatMapWithCompletionOp(this, f) + /** * Applies the given filter and emit the value only when the filter condition matches * @param f @@ -880,9 +900,11 @@ object Rx extends LogSupport { override def parents: Seq[RxOps[_]] = Seq(input) } - case class MapOp[A, B](input: Rx[A], f: A => B) extends UnaryRx[A, B] - case class FlatMapOp[A, B](input: Rx[A], f: A => RxOps[B]) extends UnaryRx[A, B] - case class FilterOp[A](input: Rx[A], cond: A => Boolean) extends UnaryRx[A, A] + case class MapOp[A, B](input: Rx[A], f: A => B) extends UnaryRx[A, B] + case class MapWithCompletionOp[A, B](input: Rx[A], f: A => Option[B]) extends UnaryRx[A, B] + case class FlatMapOp[A, B](input: Rx[A], f: A => RxOps[B]) extends UnaryRx[A, B] + case class FlatMapWithCompletionOp[A, B](input: Rx[A], f: A => Option[RxOps[B]]) extends UnaryRx[A, B] + case class FilterOp[A](input: Rx[A], cond: A => Boolean) extends UnaryRx[A, A] case class ZipOp[A, B](a: RxOps[A], b: RxOps[B]) extends Rx[(A, B)] { override def parents: Seq[RxOps[_]] = Seq(a, b) } diff --git a/airframe-rx/src/main/scala/wvlet/airframe/rx/RxRunner.scala b/airframe-rx/src/main/scala/wvlet/airframe/rx/RxRunner.scala index 23e9a1102..d5c998bb0 100644 --- a/airframe-rx/src/main/scala/wvlet/airframe/rx/RxRunner.scala +++ b/airframe-rx/src/main/scala/wvlet/airframe/rx/RxRunner.scala @@ -116,6 +116,22 @@ class RxRunner( case other => effect(other) } + case MapWithCompletionOp(in, f) => + run(in) { + case OnNext(v) => + Try(f.asInstanceOf[Any => Option[A]](v)) match { + case Success(Some(x)) => + effect(OnNext(x)) + case Success(None) => + // The function requested completion + effect(OnCompletion) + RxResult.Stop + case Failure(e) => + effect(OnError(e)) + } + case other => + effect(other) + } case fm @ FlatMapOp(in, f) => // This var is a placeholder to remember the preceding Cancelable operator, which will be updated later var c1 = Cancelable.empty @@ -147,6 +163,41 @@ class RxRunner( Cancelable { () => c1.cancel; c2.cancel } + case fmwc @ FlatMapWithCompletionOp(in, f) => + // This var is a placeholder to remember the preceding Cancelable operator, which will be updated later + var c1 = Cancelable.empty + val c2 = run(fmwc.input) { + case OnNext(x) => + var toContinue: RxResult = RxResult.Continue + Try(fmwc.f.asInstanceOf[Function[Any, Option[RxOps[_]]]](x)) match { + case Success(Some(rxb)) => + // This code is necessary to properly cancel the effect if this operator is evaluated before + c1.cancel + c1 = run(rxb) { + case n @ OnNext(x) => + toContinue = effect(n) + toContinue + case OnCompletion => + // skip the end of the nested flatMap body stream + RxResult.Continue + case ev @ OnError(e) => + toContinue = effect(ev) + toContinue + } + toContinue + case Success(None) => + // The function requested completion + effect(OnCompletion) + RxResult.Stop + case Failure(e) => + effect(OnError(e)) + } + case other => + effect(other) + } + Cancelable { () => + c1.cancel; c2.cancel + } case FilterOp(in, cond) => run(in) { ev => ev match { diff --git a/airframe-rx/src/test/scala/wvlet/airframe/rx/RxSafeCancellationTest.scala b/airframe-rx/src/test/scala/wvlet/airframe/rx/RxSafeCancellationTest.scala index a46548d4e..4df1564ed 100644 --- a/airframe-rx/src/test/scala/wvlet/airframe/rx/RxSafeCancellationTest.scala +++ b/airframe-rx/src/test/scala/wvlet/airframe/rx/RxSafeCancellationTest.scala @@ -4,49 +4,104 @@ import wvlet.airspec.AirSpec class RxSafeCancellationTest extends AirSpec { - test("demonstrate self-cancellation scenario") { - test("cancel within operator function should work safely") { - val results = scala.collection.mutable.ListBuffer[Int]() - val rx = Rx.sequence(1, 2, 3, 4, 5) - - var c = Cancelable.empty - c = rx.map { x => + test("mapWithCompletion should allow safe completion signaling") { + val results = scala.collection.mutable.ListBuffer[Int]() + val rx = Rx.sequence(1, 2, 3, 4, 5) + + val c = rx + .mapWithCompletion { x => results += x if (x == 3) { - // User wants to cancel the stream at this point - c.cancel + // Request completion by returning None + None + } else { + Some(x * 2) + } + }.run() + + // Should process 1, 2, 3 but not 4, 5 + results.toList shouldBe List(1, 2, 3) + } + + test("mapWithCompletion should work with early completion") { + val results = scala.collection.mutable.ListBuffer[Int]() + + val c = Rx + .sequence(1, 2, 3, 4, 5).mapWithCompletion { x => + if (x == 2) { + None // Complete immediately when seeing 2 + } else { + results += x + Some(x * 10) } - x * 2 }.run() - // We should be able to process up to the point of cancellation - // and not process further elements - assert(results.size <= 3) - results.contains(1) shouldBe true - results.contains(2) shouldBe true - results.contains(3) shouldBe true + // Should only process 1, then complete when seeing 2 + results.toList shouldBe List(1) + } + + test("mapWithCompletion should handle empty sequence") { + val results = scala.collection.mutable.ListBuffer[Int]() + + val c = Rx + .empty[Int].mapWithCompletion { x => + results += x + Some(x * 2) + }.run() + + // Should have no results for empty sequence + results.toList shouldBe List() + } + + test("mapWithCompletion should propagate errors") { + val ex = new RuntimeException("test error") + + val thrown = intercept[RuntimeException] { + Rx.sequence(1, 2, 3).mapWithCompletion { x => + if (x == 2) { + throw ex + } + Some(x * 2) + }.run() } - - test("request completion within operator should work") { - val results = scala.collection.mutable.ListBuffer[Int]() - val rx = Rx.sequence(1, 2, 3, 4, 5) - - // What we'd like: a way to request completion safely - var requestCompletion: () => Unit = () => {} - - val c = rx.map { x => + + thrown shouldBe ex + } + + test("flatMapWithCompletion should allow safe completion signaling") { + val results = scala.collection.mutable.ListBuffer[Int]() + val rx = Rx.sequence(1, 2, 3, 4, 5) + + val c = rx + .flatMapWithCompletion { x => results += x if (x == 3) { - // Request completion instead of cancelling - requestCompletion() + // Request completion by returning None + None + } else { + Some(Rx.single(x * 10)) } - x * 2 }.run() - - // After implementing the safe completion mechanism, - // we should process 1, 2, 3 but not 4, 5 - assert(results.size <= 3) - } + + // Should process 1, 2, 3 but not 4, 5 + results.toList shouldBe List(1, 2, 3) + } + + test("flatMapWithCompletion should work with early completion") { + val results = scala.collection.mutable.ListBuffer[Int]() + + val c = Rx + .sequence(1, 2, 3, 4, 5).flatMapWithCompletion { x => + if (x == 2) { + None // Complete immediately when seeing 2 + } else { + results += x + Some(Rx.single(x * 100)) + } + }.run() + + // Should only process 1, then complete when seeing 2 + results.toList shouldBe List(1) } -} \ No newline at end of file +}