Skip to content

Commit d1842c6

Browse files
author
Sergey Mashkov
committed
IO: fix joining sequence, avoid infinite joinTo suspension
1 parent 869bde5 commit d1842c6

File tree

1 file changed

+53
-16
lines changed

1 file changed

+53
-16
lines changed

core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/ByteBufferChannel.kt

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,14 @@ internal class ByteBufferChannel(
265265
if (newState === ReadWriteBufferState.IdleEmpty) {
266266
toRelease?.let { releaseBuffer(it.initial) }
267267
resumeWriteOp()
268+
} else {
269+
if (newState is ReadWriteBufferState.IdleNonEmpty && newState.capacity.isEmpty()) {
270+
if (newState.capacity.tryLockForRelease() && State.compareAndSet(this, newState, ReadWriteBufferState.IdleEmpty)) {
271+
newState.capacity.resetForWrite()
272+
releaseBuffer(newState.initial)
273+
resumeWriteOp()
274+
}
275+
}
268276
}
269277
}
270278

@@ -288,7 +296,7 @@ internal class ByteBufferChannel(
288296
}
289297

290298
private fun tryCompleteJoining(joined: JoiningState): Boolean {
291-
if (!tryReleaseBuffer()) return false
299+
if (!tryReleaseBuffer(true)) return false
292300
ensureClosedJoined(joined)
293301

294302
resumeReadOp { IllegalStateException("Joining is in progress") }
@@ -300,7 +308,7 @@ internal class ByteBufferChannel(
300308
private fun tryTerminate(): Boolean {
301309
if (closed == null) return false
302310

303-
if (!tryReleaseBuffer()) return false
311+
if (!tryReleaseBuffer(false)) return false
304312

305313
joining?.let { ensureClosedJoined(it) }
306314

@@ -310,7 +318,7 @@ internal class ByteBufferChannel(
310318
return true
311319
}
312320

313-
private fun tryReleaseBuffer(): Boolean {
321+
private fun tryReleaseBuffer(forceTermination: Boolean): Boolean {
314322
var toRelease: ReadWriteBufferState.Initial? = null
315323

316324
updateState { state ->
@@ -329,6 +337,10 @@ internal class ByteBufferChannel(
329337
toRelease = state.initial
330338
ReadWriteBufferState.Terminated
331339
}
340+
forceTermination && state is ReadWriteBufferState.IdleNonEmpty && state.capacity.tryLockForRelease() -> {
341+
toRelease = state.initial
342+
ReadWriteBufferState.Terminated
343+
}
332344
else -> return false
333345
}
334346
}
@@ -464,15 +476,15 @@ internal class ByteBufferChannel(
464476
return consumed
465477
}
466478

467-
final suspend override fun readFully(dst: ByteArray, offset: Int, length: Int) {
479+
final override suspend fun readFully(dst: ByteArray, offset: Int, length: Int) {
468480
val consumed = readAsMuchAsPossible(dst, offset, length)
469481

470482
if (consumed < length) {
471483
return readFullySuspend(dst, offset + consumed, length - consumed)
472484
}
473485
}
474486

475-
final suspend override fun readFully(dst: ByteBuffer): Int {
487+
final override suspend fun readFully(dst: ByteBuffer): Int {
476488
val rc = readAsMuchAsPossible(dst)
477489
if (!dst.hasRemaining()) return rc
478490

@@ -490,7 +502,7 @@ internal class ByteBufferChannel(
490502
return copied
491503
}
492504

493-
private suspend tailrec fun readFullySuspend(dst: ByteArray, offset: Int, length: Int) {
505+
private tailrec suspend fun readFullySuspend(dst: ByteArray, offset: Int, length: Int) {
494506
if (!readSuspend(1)) throw ClosedReceiveChannelException("Unexpected EOF: expected $length more bytes")
495507

496508
val consumed = readAsMuchAsPossible(dst, offset, length)
@@ -560,7 +572,7 @@ internal class ByteBufferChannel(
560572
return readAvailable(dst)
561573
}
562574

563-
final suspend override fun readPacket(size: Int, headerSizeHint: Int): ByteReadPacket {
575+
final override suspend fun readPacket(size: Int, headerSizeHint: Int): ByteReadPacket {
564576
closed?.cause?.let { throw it }
565577

566578
if (size == 0) return ByteReadPacket.Empty
@@ -626,7 +638,7 @@ internal class ByteBufferChannel(
626638
}
627639
}
628640

629-
final suspend override fun readByte(): Byte {
641+
final override suspend fun readByte(): Byte {
630642
var b: Byte = 0
631643

632644
val rc = reading {
@@ -649,7 +661,7 @@ internal class ByteBufferChannel(
649661
return readByte()
650662
}
651663

652-
final suspend override fun readBoolean(): Boolean {
664+
final override suspend fun readBoolean(): Boolean {
653665
var b = false
654666

655667
val rc = reading {
@@ -672,7 +684,7 @@ internal class ByteBufferChannel(
672684
return readBoolean()
673685
}
674686

675-
final suspend override fun readShort(): Short {
687+
final override suspend fun readShort(): Short {
676688
var sh: Short = 0
677689

678690
val rc = reading {
@@ -1262,8 +1274,10 @@ internal class ByteBufferChannel(
12621274

12631275
flush()
12641276

1265-
if (src.availableForRead == 0 && !src.readSuspendImpl(1)) {
1266-
if (joined == null || src.tryCompleteJoining(joined)) break
1277+
if (src.availableForRead == 0) {
1278+
if (src.readSuspendImpl(1)) {
1279+
if (joined != null && src.tryCompleteJoining(joined)) break
1280+
} else if (joined == null || src.tryCompleteJoining(joined)) break
12671281
}
12681282

12691283
if (joining != null) {
@@ -1618,7 +1632,7 @@ internal class ByteBufferChannel(
16181632
read(min, block)
16191633
}
16201634

1621-
suspend override fun writePacket(packet: ByteReadPacket) {
1635+
override suspend fun writePacket(packet: ByteReadPacket) {
16221636
joining?.let { resolveDelegation(this, it)?.let { return it.writePacket(packet) } }
16231637

16241638
try {
@@ -1753,6 +1767,7 @@ internal class ByteBufferChannel(
17531767
} finally {
17541768
if (locked > 0) {
17551769
ringBufferCapacity.completeRead(locked)
1770+
locked = 0
17561771
}
17571772
}
17581773
}
@@ -2101,9 +2116,20 @@ internal class ByteBufferChannel(
21012116

21022117
private val readSuspendContinuationCache = MutableDelegateContinuation<Boolean>()
21032118

2119+
@Suppress("NOTHING_TO_INLINE")
2120+
private inline fun readSuspendPredicate(size: Int): Boolean {
2121+
val state = state
2122+
2123+
if (state.capacity.availableForRead >= size) return false
2124+
if (joining != null && writeOp != null &&
2125+
(state === ReadWriteBufferState.IdleEmpty || state is ReadWriteBufferState.IdleNonEmpty)) return false
2126+
2127+
return true
2128+
}
2129+
21042130
private fun suspensionForSize(size: Int, c: Continuation<Boolean>): Any {
21052131
do {
2106-
if (this.state.capacity.availableForRead >= size) {
2132+
if (!readSuspendPredicate(size)) {
21072133
c.resume(true)
21082134
break
21092135
}
@@ -2116,13 +2142,13 @@ internal class ByteBufferChannel(
21162142
}
21172143
return COROUTINE_SUSPENDED
21182144
}
2119-
} while (!setContinuation({ readOp }, ReadOp, c, { closed == null && state.capacity.availableForRead < size }))
2145+
} while (!setContinuation({ readOp }, ReadOp, c, { closed == null && readSuspendPredicate(size) }))
21202146

21212147
return COROUTINE_SUSPENDED
21222148
}
21232149

21242150
private suspend fun readSuspendImpl(size: Int): Boolean {
2125-
if (state.capacity.availableForRead >= size) return true
2151+
if (!readSuspendPredicate(size)) return true
21262152

21272153
return suspendCoroutineOrReturn { raw ->
21282154
val c = readSuspendContinuationCache
@@ -2131,6 +2157,9 @@ internal class ByteBufferChannel(
21312157
}
21322158
}
21332159

2160+
private fun shouldResumeReadOp() = joining != null &&
2161+
(state === ReadWriteBufferState.IdleEmpty || state is ReadWriteBufferState.IdleNonEmpty)
2162+
21342163
private fun writeSuspendPredicate(size: Int): Boolean {
21352164
val joined = joining
21362165
val state = state
@@ -2159,6 +2188,10 @@ internal class ByteBufferChannel(
21592188

21602189
flushImpl(1, minWriteSize = size)
21612190

2191+
if (shouldResumeReadOp()) {
2192+
resumeReadOp()
2193+
}
2194+
21622195
COROUTINE_SUSPENDED
21632196
}
21642197

@@ -2192,6 +2225,10 @@ internal class ByteBufferChannel(
21922225
} while (!setContinuation({ writeOp }, WriteOp, c, { writeSuspendPredicate(size) }))
21932226

21942227
flushImpl(1, minWriteSize = size)
2228+
2229+
if (shouldResumeReadOp()) {
2230+
resumeReadOp()
2231+
}
21952232
}
21962233
}
21972234

0 commit comments

Comments
 (0)