Skip to content

Commit 4d01f7e

Browse files
author
Sergey Mashkov
committed
IO: make reusable cancellable continuation implement DispatchedTask
1 parent 883e57e commit 4d01f7e

File tree

4 files changed

+207
-34
lines changed

4 files changed

+207
-34
lines changed

core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/ByteBufferChannel.kt

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ internal class ByteBufferChannel(
107107
}
108108

109109
if (cause != null) attachedJob?.cancel(cause)
110-
readSuspendContinuationCache.close()
111-
writeSuspendContinuationCache.close()
110+
// readSuspendContinuationCache.close()
111+
// writeSuspendContinuationCache.close()
112112

113113
return true
114114
}
@@ -155,8 +155,8 @@ internal class ByteBufferChannel(
155155
}
156156

157157
private fun setupStateForWrite(): ByteBuffer? {
158-
if (writeOp != null) {
159-
throw IllegalStateException("Write operation is already in progress")
158+
writeOp?.let { existing ->
159+
throw IllegalStateException("Write operation is already in progress: $existing")
160160
}
161161

162162
var _allocated: ReadWriteBufferState.Initial? = null
@@ -183,7 +183,6 @@ internal class ByteBufferChannel(
183183
}
184184
}
185185

186-
// joining?.let { restoreStateAfterWrite(); return null }
187186
if (closed != null) {
188187
restoreStateAfterWrite()
189188
tryTerminate()
@@ -348,12 +347,16 @@ internal class ByteBufferChannel(
348347
val current = joining?.let { resolveDelegation(this, it) } ?: this
349348
val buffer = current.setupStateForWrite() ?: return
350349
val capacity = current.state.capacity
350+
val before = current.totalBytesWritten
351351

352352
try {
353353
current.closed?.let { throw it.sendException }
354354
block(current, buffer, capacity)
355355
} finally {
356356
if (capacity.isFull() || current.autoFlush) current.flush()
357+
if (current !== this) {
358+
totalBytesWritten += current.totalBytesWritten - before
359+
}
357360
current.restoreStateAfterWrite()
358361
current.tryTerminate()
359362
}
@@ -692,7 +695,7 @@ internal class ByteBufferChannel(
692695
return readShort()
693696
}
694697

695-
final suspend override fun readInt(): Int {
698+
final override suspend fun readInt(): Int {
696699
var i = 0
697700

698701
val rc = reading {
@@ -716,7 +719,7 @@ internal class ByteBufferChannel(
716719
return readInt()
717720
}
718721

719-
final suspend override fun readLong(): Long {
722+
final override suspend fun readLong(): Long {
720723
var i = 0L
721724

722725
val rc = reading {
@@ -740,7 +743,7 @@ internal class ByteBufferChannel(
740743
return readLong()
741744
}
742745

743-
final suspend override fun readDouble(): Double {
746+
final override suspend fun readDouble(): Double {
744747
var d = 0.0
745748

746749
val rc = reading {
@@ -764,7 +767,7 @@ internal class ByteBufferChannel(
764767
return readDouble()
765768
}
766769

767-
final suspend override fun readFloat(): Float {
770+
final override suspend fun readFloat(): Float {
768771
var f = 0.0f
769772

770773
val rc = reading {
@@ -1182,7 +1185,6 @@ internal class ByteBufferChannel(
11821185
}
11831186

11841187
internal suspend fun copyDirect(src: ByteBufferChannel, limit: Long, joined: JoiningState?): Long {
1185-
// println("enter copyDirect")
11861188
if (limit == 0L) return 0L
11871189
if (src.isClosedForRead) {
11881190
if (joined != null) {
@@ -1234,13 +1236,11 @@ internal class ByteBufferChannel(
12341236
true
12351237
}
12361238

1237-
// println("rc = $rc, partSize = $partSize")
12381239
if (rc) {
12391240
dstBuffer.bytesWritten(state, partSize)
12401241
copied += partSize
12411242

12421243
if (avWBefore - partSize == 0 || autoFlush) {
1243-
// println("flush")
12441244
flush()
12451245
}
12461246
} else {
@@ -1257,30 +1257,23 @@ internal class ByteBufferChannel(
12571257
}
12581258
}
12591259

1260-
if (joining != null) break // TODO think of joining chain
12611260
if (copied >= limit) break
12621261

1263-
// println("readSuspend?")
12641262
flush()
12651263

12661264
if (src.availableForRead == 0 && !src.readSuspendImpl(1)) {
1267-
// println("readSuspend failed")
12681265
if (joined == null || src.tryCompleteJoining(joined)) break
12691266
}
1270-
// println("next loop")
1267+
1268+
if (joining != null) {
1269+
yield()
1270+
}
12711271
}
12721272

12731273
if (autoFlush) {
1274-
// println("final flush")
12751274
flush()
12761275
}
12771276

1278-
if (joined == null) {
1279-
joining?.let { thisJoined ->
1280-
return copied + thisJoined.delegatedTo.copyDirect(src, limit - copied, null)
1281-
}
1282-
}
1283-
12841277
return copied
12851278
} catch (t: Throwable) {
12861279
close(t)
@@ -1293,7 +1286,17 @@ internal class ByteBufferChannel(
12931286
this.joining = null
12941287

12951288
if (joined.delegateClose) {
1296-
joined.delegatedTo.close(closed.cause)
1289+
// writing state could be if we are inside of copyDirect loop
1290+
// so in this case we shouldn't close channel
1291+
// otherwise few bytes could be lost
1292+
// it will be closed later in copyDirect's finalization
1293+
// so we only do flush
1294+
val writing = joined.delegatedTo.state.let { it is ReadWriteBufferState.Writing || it is ReadWriteBufferState.ReadingWriting }
1295+
if (closed.cause != null || !writing) {
1296+
joined.delegatedTo.close(closed.cause)
1297+
} else {
1298+
joined.delegatedTo.flush()
1299+
}
12971300
} else {
12981301
joined.delegatedTo.flush()
12991302
}

core/kotlinx-coroutines-io/src/main/kotlin/kotlinx/coroutines/experimental/io/internal/MutableDelegateContinuation.kt

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,20 @@ import kotlin.coroutines.experimental.intrinsics.*
1111
* - [T] should be neither [Throwable] nor [Continuation]
1212
* - value shouldn't be null
1313
*/
14-
internal class MutableDelegateContinuation<T : Any> : Continuation<T> {
14+
internal class MutableDelegateContinuation<T : Any> : Continuation<T>, DispatchedTask<T> {
15+
private var _delegate: Continuation<T>? = null
1516
private val state = atomic<Any?>(null)
1617
private val handler = atomic<JobRelation?>(null)
1718

19+
override val delegate: Continuation<T>
20+
get() = _delegate!!
21+
22+
override fun takeState(): Any? {
23+
val value = state.getAndSet(null)
24+
_delegate = null
25+
return value
26+
}
27+
1828
fun swap(actual: Continuation<T>): Any {
1929
loop@while (true) {
2030
val before = state.value
@@ -72,10 +82,11 @@ internal class MutableDelegateContinuation<T : Any> : Continuation<T> {
7282
return
7383
}
7484
is Continuation<*> -> {
75-
if (!state.compareAndSet(before, null)) continue@loop
85+
if (!state.compareAndSet(before, value)) continue@loop
7686
@Suppress("UNCHECKED_CAST")
7787
val cont = before as Continuation<T>
78-
return cont.resume(value)
88+
_delegate = cont
89+
return dispatch(1)
7990
}
8091
else -> return
8192
}
@@ -92,10 +103,11 @@ internal class MutableDelegateContinuation<T : Any> : Continuation<T> {
92103
return
93104
}
94105
is Continuation<*> -> {
95-
if (!state.compareAndSet(before, null)) continue@loop
106+
if (!state.compareAndSet(before, CompletedExceptionally(exception))) continue@loop
96107
@Suppress("UNCHECKED_CAST")
97108
val cont = before as Continuation<T>
98-
return cont.resumeWithException(exception)
109+
_delegate = cont
110+
return dispatch(1)
99111
}
100112
else -> return
101113
}

0 commit comments

Comments
 (0)