Skip to content

Commit 3e342e3

Browse files
committed
Fixed Publisher/Observable/Flowable.openSubscription in presence of selects;
support an optional `request` parameter in openSubscription to specify how many elements are requested from publisher in advance on subscription. Fixes #197
1 parent b5d4286 commit 3e342e3

File tree

6 files changed

+292
-58
lines changed

6 files changed

+292
-58
lines changed

core/kotlinx-coroutines-core/src/main/kotlin/kotlinx/coroutines/experimental/channels/AbstractChannel.kt

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ public abstract class AbstractSendChannel<E> : SendChannel<E> {
283283
* Retrieves first receiving waiter from the queue or returns closed token.
284284
* @suppress **This is unstable API and it is subject to change.**
285285
*/
286-
protected fun takeFirstReceiveOrPeekClosed(): ReceiveOrClosed<E>? =
286+
protected open fun takeFirstReceiveOrPeekClosed(): ReceiveOrClosed<E>? =
287287
queue.removeFirstIfIsInstanceOfOrPeekIf<ReceiveOrClosed<E>>({ it is Closed<*> })
288288

289289
// ------ registerSelectSend ------
@@ -520,7 +520,7 @@ public abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E>
520520
val result = if (isBufferAlwaysEmpty)
521521
queue.addLastIfPrev(receive, { it !is Send }) else
522522
queue.addLastIfPrevAndIf(receive, { it !is Send }, { isBufferEmpty })
523-
if (result) onEnqueuedReceive()
523+
if (result) onReceiveEnqueued()
524524
return result
525525
}
526526

@@ -640,7 +640,7 @@ public abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E>
640640
override fun finishOnSuccess(affected: LockFreeLinkedListNode, next: LockFreeLinkedListNode) {
641641
super.finishOnSuccess(affected, next)
642642
// notify the there is one more receiver
643-
onEnqueuedReceive()
643+
onReceiveEnqueued()
644644
// we can actually remove on select start, but this is also Ok (it'll get removed if discovered there)
645645
node.removeOnSelectCompletion()
646646
}
@@ -724,29 +724,36 @@ public abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E>
724724

725725
// ------ protected ------
726726

727+
override fun takeFirstReceiveOrPeekClosed(): ReceiveOrClosed<E>? =
728+
super.takeFirstReceiveOrPeekClosed().also {
729+
if (it != null && it !is Closed<*>) onReceiveDequeued()
730+
}
731+
727732
/**
728733
* Invoked when receiver is successfully enqueued to the queue of waiting receivers.
734+
* @suppress **This is unstable API and it is subject to change.**
729735
*/
730-
protected open fun onEnqueuedReceive() {}
736+
protected open fun onReceiveEnqueued() {}
731737

732738
/**
733-
* Invoked when enqueued receiver was successfully cancelled.
739+
* Invoked when enqueued receiver was successfully removed from the queue of waiting receivers.
740+
* @suppress **This is unstable API and it is subject to change.**
734741
*/
735-
protected open fun onCancelledReceive() {}
742+
protected open fun onReceiveDequeued() {}
736743

737744
// ------ private ------
738745

739746
private fun removeReceiveOnCancel(cont: CancellableContinuation<*>, receive: Receive<*>) {
740747
cont.invokeOnCompletion {
741748
if (cont.isCancelled && receive.remove())
742-
onCancelledReceive()
749+
onReceiveDequeued()
743750
}
744751
}
745752

746753
private class Itr<E>(val channel: AbstractChannel<E>) : ChannelIterator<E> {
747754
var result: Any? = POLL_FAILED // E | POLL_FAILED | Closed
748755

749-
suspend override fun hasNext(): Boolean {
756+
override suspend fun hasNext(): Boolean {
750757
// check for repeated hasNext
751758
if (result !== POLL_FAILED) return hasNextResult(result)
752759
// fast path -- try poll non-blocking
@@ -790,7 +797,7 @@ public abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E>
790797
}
791798

792799
@Suppress("UNCHECKED_CAST")
793-
suspend override fun next(): E {
800+
override suspend fun next(): E {
794801
val result = this.result
795802
if (result is Closed<*>) throw result.receiveException
796803
if (result !== POLL_FAILED) {
@@ -889,7 +896,7 @@ public abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E>
889896

890897
override fun dispose() { // invoked on select completion
891898
if (remove())
892-
onCancelledReceive() // notify cancellation of receive
899+
onReceiveDequeued() // notify cancellation of receive
893900
}
894901

895902
override fun toString(): String = "ReceiveSelect[$select,nullOnClose=$nullOnClose]"

reactive/kotlinx-coroutines-reactive/src/main/kotlin/kotlinx/coroutines/experimental/reactive/Channel.kt

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@ import org.reactivestreams.Subscription
2727
/**
2828
* Subscribes to this [Publisher] and returns a channel to receive elements emitted by it.
2929
* The resulting channel shall be [closed][SubscriptionReceiveChannel.close] to unsubscribe from this publisher.
30+
* @param request how many items to request from publisher in advance (optional, on-demand request by default).
3031
*/
31-
public fun <T> Publisher<T>.openSubscription(): SubscriptionReceiveChannel<T> {
32-
val channel = SubscriptionChannel<T>()
32+
@JvmOverloads // for binary compatibility
33+
public fun <T> Publisher<T>.openSubscription(request: Int = 0): SubscriptionReceiveChannel<T> {
34+
val channel = SubscriptionChannel<T>(request)
3335
subscribe(channel)
3436
return channel
3537
}
@@ -47,7 +49,6 @@ public fun <T> Publisher<T>.open(): SubscriptionReceiveChannel<T> = openSubscrip
4749
* This is a shortcut for `open().iterator()`. See [openSubscription] if you need an ability to manually
4850
* unsubscribe from the observable.
4951
*/
50-
5152
@Suppress("DeprecatedCallableAddReplaceWith")
5253
@Deprecated(message =
5354
"This iteration operator for `for (x in source) { ... }` loop is deprecated, " +
@@ -58,7 +59,7 @@ public operator fun <T> Publisher<T>.iterator() = openSubscription().iterator()
5859
/**
5960
* Subscribes to this [Publisher] and performs the specified action for each received element.
6061
*/
61-
public inline suspend fun <T> Publisher<T>.consumeEach(action: (T) -> Unit) {
62+
public suspend inline fun <T> Publisher<T>.consumeEach(action: (T) -> Unit) {
6263
openSubscription().use { channel ->
6364
for (x in channel) action(x)
6465
}
@@ -71,36 +72,39 @@ public inline suspend fun <T> Publisher<T>.consumeEach(action: (T) -> Unit) {
7172
public suspend fun <T> Publisher<T>.consumeEach(action: suspend (T) -> Unit) =
7273
consumeEach { action(it) }
7374

74-
private class SubscriptionChannel<T> : LinkedListChannel<T>(), SubscriptionReceiveChannel<T>, Subscriber<T> {
75+
private class SubscriptionChannel<T>(
76+
private val request: Int
77+
) : LinkedListChannel<T>(), SubscriptionReceiveChannel<T>, Subscriber<T> {
78+
init {
79+
require(request >= 0) { "Invalid request size: $request" }
80+
}
81+
7582
@Volatile
76-
@JvmField
77-
var subscription: Subscription? = null
83+
private var subscription: Subscription? = null
7884

79-
// request balance from cancelled receivers, balance is negative if we have receivers, but no subscription yet
80-
val _balance = atomic(0)
85+
// requested from subscription minus number of received minus number of enqueued receivers,
86+
// can be negative if we have receivers, but no subscription yet
87+
private val _requested = atomic(0)
8188

8289
// AbstractChannel overrides
83-
override fun onEnqueuedReceive() {
84-
_balance.loop { balance ->
90+
override fun onReceiveEnqueued() {
91+
_requested.loop { wasRequested ->
8592
val subscription = this.subscription
86-
if (subscription != null) {
87-
if (balance < 0) { // receivers came before we had subscription
88-
// try to fixup by making request
89-
if (!_balance.compareAndSet(balance, 0)) return@loop // continue looping
90-
subscription.request(-balance.toLong())
91-
return
92-
}
93-
if (balance == 0) { // normal story
94-
subscription.request(1)
95-
return
96-
}
93+
val needRequested = wasRequested - 1
94+
if (subscription != null && needRequested < 0) { // need to request more from subscription
95+
// try to fixup by making request
96+
if (wasRequested != request && !_requested.compareAndSet(wasRequested, request))
97+
return@loop // continue looping if failed
98+
subscription.request((request - needRequested).toLong())
99+
return
97100
}
98-
if (_balance.compareAndSet(balance, balance - 1)) return
101+
// just do book-keeping
102+
if (_requested.compareAndSet(wasRequested, needRequested)) return
99103
}
100104
}
101105

102-
override fun onCancelledReceive() {
103-
_balance.incrementAndGet()
106+
override fun onReceiveDequeued() {
107+
_requested.incrementAndGet()
104108
}
105109

106110
override fun afterClose(cause: Throwable?) {
@@ -110,22 +114,23 @@ private class SubscriptionChannel<T> : LinkedListChannel<T>(), SubscriptionRecei
110114
// Subscriber overrides
111115
override fun onSubscribe(s: Subscription) {
112116
subscription = s
113-
while (true) { // lock-free loop on balance
117+
while (true) { // lock-free loop on _requested
114118
if (isClosedForSend) {
115119
s.cancel()
116120
return
117121
}
118-
val balance = _balance.value
119-
if (balance >= 0) return // ok -- normal story
120-
// otherwise, receivers came before we had subscription
122+
val wasRequested = _requested.value
123+
if (wasRequested >= request) return // ok -- normal story
124+
// otherwise, receivers came before we had subscription or need to make initial request
121125
// try to fixup by making request
122-
if (!_balance.compareAndSet(balance, 0)) continue
123-
s.request(-balance.toLong())
126+
if (!_requested.compareAndSet(wasRequested, request)) continue
127+
s.request((request - wasRequested).toLong())
124128
return
125129
}
126130
}
127131

128132
override fun onNext(t: T) {
133+
_requested.decrementAndGet()
129134
offer(t)
130135
}
131136

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
* Copyright 2016-2017 JetBrains s.r.o.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package kotlinx.coroutines.experimental.reactive
18+
19+
import kotlinx.coroutines.experimental.*
20+
import kotlinx.coroutines.experimental.selects.*
21+
import org.junit.*
22+
import org.junit.Assert.*
23+
import org.junit.runner.*
24+
import org.junit.runners.*
25+
26+
@RunWith(Parameterized::class)
27+
class PublisherSubscriptionSelectTest(val request: Int) : TestBase() {
28+
companion object {
29+
@Parameterized.Parameters(name = "request = {0}")
30+
@JvmStatic
31+
fun params(): Collection<Array<Any>> = listOf(0, 1, 10).map { arrayOf<Any>(it) }
32+
}
33+
34+
@Test
35+
fun testSelect() = runTest {
36+
// source with n ints
37+
val n = 1000 * stressTestMultiplier
38+
val source = publish(coroutineContext) { repeat(n) { send(it) } }
39+
var a = 0
40+
var b = 0
41+
// open two subs
42+
source.openSubscription(request).use { channelA ->
43+
source.openSubscription(request).use { channelB ->
44+
loop@ while (true) {
45+
val done: Int = select {
46+
channelA.onReceiveOrNull {
47+
if (it != null) assertEquals(a++, it)
48+
if (it == null) 0 else 1
49+
}
50+
channelB.onReceiveOrNull {
51+
if (it != null) assertEquals(b++, it)
52+
if (it == null) 0 else 2
53+
}
54+
}
55+
when (done) {
56+
0 -> break@loop
57+
1 -> {
58+
val r = channelB.receiveOrNull()
59+
if (r != null) assertEquals(b++, r)
60+
}
61+
2 -> {
62+
val r = channelA.receiveOrNull()
63+
if (r != null) assertEquals(a++, r)
64+
}
65+
}
66+
}
67+
}
68+
}
69+
// should receive one of them fully
70+
assertTrue(a == n || b == n)
71+
}
72+
}

reactive/kotlinx-coroutines-rx1/src/main/kotlin/kotlinx/coroutines/experimental/rx1/RxChannel.kt

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,11 @@ import rx.Subscription
2727
/**
2828
* Subscribes to this [Observable] and returns a channel to receive elements emitted by it.
2929
* The resulting channel shall be [closed][SubscriptionReceiveChannel.close] to unsubscribe from this observable.
30+
* @param request how many items to request from publisher in advance (optional, on-demand request by default).
3031
*/
31-
public fun <T> Observable<T>.openSubscription(): SubscriptionReceiveChannel<T> {
32-
val channel = SubscriptionChannel<T>()
32+
@JvmOverloads // for binary compatibility
33+
public fun <T> Observable<T>.openSubscription(request: Int = 0): SubscriptionReceiveChannel<T> {
34+
val channel = SubscriptionChannel<T>(request)
3335
val subscription = subscribe(channel.subscriber)
3436
channel.subscription = subscription
3537
if (channel.isClosedForSend) subscription.unsubscribe()
@@ -59,7 +61,7 @@ public operator fun <T> Observable<T>.iterator() = openSubscription().iterator()
5961
/**
6062
* Subscribes to this [Observable] and performs the specified action for each received element.
6163
*/
62-
public inline suspend fun <T> Observable<T>.consumeEach(action: (T) -> Unit) {
64+
public suspend inline fun <T> Observable<T>.consumeEach(action: (T) -> Unit) {
6365
openSubscription().use { channel ->
6466
for (x in channel) action(x)
6567
}
@@ -72,45 +74,58 @@ public inline suspend fun <T> Observable<T>.consumeEach(action: (T) -> Unit) {
7274
public suspend fun <T> Observable<T>.consumeEach(action: suspend (T) -> Unit) =
7375
consumeEach { action(it) }
7476

75-
private class SubscriptionChannel<T> : LinkedListChannel<T>(), SubscriptionReceiveChannel<T> {
77+
private class SubscriptionChannel<T>(
78+
private val request: Int
79+
) : LinkedListChannel<T>(), SubscriptionReceiveChannel<T> {
80+
init {
81+
require(request >= 0) { "Invalid request size: $request" }
82+
}
83+
7684
@JvmField
77-
val subscriber: ChannelSubscriber = ChannelSubscriber()
85+
val subscriber: ChannelSubscriber = ChannelSubscriber(request)
7886

7987
@Volatile
8088
@JvmField
8189
var subscription: Subscription? = null
8290

83-
val _balance = atomic(0) // request balance from cancelled receivers
91+
// requested from subscription minus number of received minus number of enqueued receivers,
92+
private val _requested = atomic(request)
8493

8594
// AbstractChannel overrides
86-
override fun onEnqueuedReceive() {
87-
_balance.loop { balance ->
88-
if (balance == 0) {
89-
subscriber.requestOne()
95+
override fun onReceiveEnqueued() {
96+
_requested.loop { wasRequested ->
97+
val needRequested = wasRequested - 1
98+
if (needRequested < 0) { // need to request more from subscriber
99+
// try to fixup by making request
100+
if (wasRequested != request && !_requested.compareAndSet(wasRequested, request))
101+
return@loop // continue looping if failed
102+
subscriber.makeRequest((request - needRequested).toLong())
90103
return
91104
}
92-
if (_balance.compareAndSet(balance, balance - 1)) return
105+
// just do book-keeping
106+
if (_requested.compareAndSet(wasRequested, needRequested)) return
93107
}
94108
}
95109

96-
override fun onCancelledReceive() {
97-
_balance.incrementAndGet()
110+
override fun onReceiveDequeued() {
111+
_requested.incrementAndGet()
98112
}
99113

100114
override fun afterClose(cause: Throwable?) {
101115
subscription?.unsubscribe()
102116
}
103117

104-
inner class ChannelSubscriber: Subscriber<T>() {
105-
fun requestOne() {
106-
request(1)
118+
inner class ChannelSubscriber(private val request: Int): Subscriber<T>() {
119+
fun makeRequest(n: Long) {
120+
request(n)
107121
}
108122

109123
override fun onStart() {
110-
request(0) // init backpressure, but don't request anything yet
124+
request(request.toLong()) // init backpressure
111125
}
112126

113127
override fun onNext(t: T) {
128+
_requested.decrementAndGet()
114129
offer(t)
115130
}
116131

0 commit comments

Comments
 (0)