Skip to content

RWMutex/Semaphore Redux #1182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,26 @@ public final class kotlinx/coroutines/sync/MutexKt {
public static synthetic fun withLock$default (Lkotlinx/coroutines/sync/Mutex;Ljava/lang/Object;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
}

public abstract interface class kotlinx/coroutines/sync/ReadWriteMutex {
public abstract fun getRead ()Lkotlinx/coroutines/sync/Mutex;
public abstract fun getWrite ()Lkotlinx/coroutines/sync/Mutex;
}

public final class kotlinx/coroutines/sync/ReadWriteMutexKt {
public static final fun ReadWriteMutex ()Lkotlinx/coroutines/sync/ReadWriteMutex;
}

public abstract interface class kotlinx/coroutines/sync/Semaphore {
public abstract fun acquire (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public abstract fun getPermits ()I
public abstract fun release ()V
public abstract fun tryAcquire ()Z
}

public final class kotlinx/coroutines/sync/SemaphoreKt {
public static final fun Semaphore (I)Lkotlinx/coroutines/sync/Semaphore;
}

public final class kotlinx/coroutines/test/TestCoroutineContext : kotlin/coroutines/CoroutineContext {
public fun <init> ()V
public fun <init> (Ljava/lang/String;)V
Expand Down
1 change: 1 addition & 0 deletions knit/src/Knit.kt
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ fun makeSectionRef(name: String): String = name
.replace(",", "")
.replace("(", "")
.replace(")", "")
.replace("`", "")
.toLowerCase()

class Include(val regex: Regex, val lines: MutableList<String> = arrayListOf())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ public expect open class LockFreeLinkedListNode() {
public val nextNode: LockFreeLinkedListNode
public val prevNode: LockFreeLinkedListNode
public fun addLast(node: LockFreeLinkedListNode)
public fun <T : LockFreeLinkedListNode> describeAddLast(node: T): AddLastDesc<T>
public fun addOneIfEmpty(node: LockFreeLinkedListNode): Boolean
public inline fun addLastIf(node: LockFreeLinkedListNode, crossinline condition: () -> Boolean): Boolean
public inline fun addLastIfPrev(
Expand Down
267 changes: 267 additions & 0 deletions kotlinx-coroutines-core/common/src/sync/ReadWriteMutex.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
package kotlinx.coroutines.sync

import kotlinx.atomicfu.atomic
import kotlinx.atomicfu.loop
import kotlinx.coroutines.internal.*
import kotlinx.coroutines.selects.SelectClause2
import kotlinx.coroutines.sync.Mutex
import kotlin.coroutines.Continuation
import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED
import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn
import kotlin.coroutines.resume

public interface ReadWriteMutex {
val read: Mutex
val write: Mutex
}

/**
* Read-write lock for coroutines. This implementation is fair, non-reentrant,
* and non-suspending when there are no writers.
*/
@Suppress("FunctionName")
public fun ReadWriteMutex(): ReadWriteMutex = ReadWriteMutexImpl()

internal class ReadWriteMutexImpl : ReadWriteMutex {

// count -1: write lock held
// count 0: no locks held
// count x > 0: x read locks held
// waiters: number of coroutines waiting for a lock
private class State(val count: Int, val waiters: Int, val owner: Any?)
private val _state = atomic<Any>(State(0, 0, null)) // S | OpDescriptor

private class Waiter(val cont: Continuation<Unit>, val isWriter: Boolean, val owner: Any?) : LockFreeLinkedListNode()
val queue = LockFreeLinkedListHead() // queue of waiters

// create an op to atomically enqueue a waiter and increment the waiter count in the state
private fun createAddWaiterOp(state: State, cont: Continuation<Unit>, isWriter: Boolean, owner: Any?) : OpDescriptor {
val waiter = Waiter(cont, isWriter, owner)
val addLastDesc = queue.describeAddLast(waiter)
return object : AtomicOp<Any?>() {
override fun prepare(affected: Any?): Any? {
return addLastDesc.prepare(this)
}

override fun complete(affected: Any?, failure: Any?) {
addLastDesc.complete(this, failure)
_state.compareAndSet(this, State(state.count, state.waiters + 1, state.owner))
}
}
}

override val read = object : Mutex {
override val isLocked: Boolean get() {
_state.loop { state ->
when (state) {
is OpDescriptor -> state.perform(this) // help
is State -> return state.count > 0
else -> error("unexpected state $state")
}
}
}

override fun tryLock(owner: Any?): Boolean {
require(owner == null) { "owners not supported for read mutex" }
_state.loop { state ->
when (state) {
is OpDescriptor -> state.perform(this) // help
is State -> {
if (state.count >= 0 && state.waiters == 0) {
if (_state.compareAndSet(state, State(state.count + 1, 0, null))) {
return true
}
} else {
return false
}
}
else -> error("unexpected state $state")
}
}
}

override suspend fun lock(owner: Any?) {
require(owner == null) { "owners not supported for read mutex" }
if (tryLock()) return
else return suspendCoroutineUninterceptedOrReturn { lockCont(it) }
}

private fun lockCont(cont: Continuation<Unit>): Any {
_state.loop { state ->
when (state) {
is OpDescriptor -> state.perform(this) // help
is State -> {
if (state.count >= 0 && state.waiters == 0) {
if (_state.compareAndSet(state, State(state.count + 1, 0, null))) {
return Unit
}
} else {
val op = createAddWaiterOp(state, cont, false, null)
if (_state.compareAndSet(state, op)) {
op.perform(this)
return COROUTINE_SUSPENDED
}
}
}
else -> error("unexpected state $state")
}
}
}

override fun unlock(owner: Any?) {
require(owner == null) { "owners not supported for read mutex" }
_state.loop { state ->
when (state) {
is OpDescriptor -> state.perform(this) // help
is State -> {
check(state.count > 0) { "read mutex is not locked" }
if (state.count > 1 || state.waiters == 0) {
if (_state.compareAndSet(state, State(state.count - 1, state.waiters, null))) {
return
}
} else {
// this seems to be the easiest way to peek the queue
val waiter = (queue.removeFirstIfIsInstanceOfOrPeekIf<Waiter> { true })!!
if (_state.compareAndSet(state, State(-1, state.waiters - 1, waiter.owner))) {
val writer = queue.removeFirstOrNull()!! as Waiter
writer.cont.resume(Unit)
return
}
}
}
else -> error("unexpected state $state")
}
}
}

override val onLock: SelectClause2<Any?, Mutex>
get() = TODO("support for select is not yet implemented")

override fun holdsLock(owner: Any): Boolean {
throw UnsupportedOperationException("owners not supported for read mutex")
}
}

override val write = object : Mutex {
override val isLocked: Boolean get() {
_state.loop { state ->
when (state) {
is OpDescriptor -> state.perform(this) // help
is State -> return state.count < 0
else -> error("unexpected state $state")
}
}
}

override fun tryLock(owner: Any?): Boolean {
_state.loop { state ->
when (state) {
is OpDescriptor -> state.perform(this) // help
is State -> {
if (state.count == 0 && state.waiters == 0) {
if (_state.compareAndSet(state, State(-1, 0, owner))) {
return true
}
} else {
check(state.owner !== owner) { "already locked by $owner" }
return false
}
}
else -> error("unexpected state $state")
}
}
}

override suspend fun lock(owner: Any?) {
if (tryLock(owner)) return
else return suspendCoroutineUninterceptedOrReturn { lockCont(owner, it) }
}

private fun lockCont(owner: Any?, cont: Continuation<Unit>): Any {
_state.loop { state ->
when (state) {
is OpDescriptor -> state.perform(this) // help
is State -> {
if (state.count == 0 && state.waiters == 0) {
if (_state.compareAndSet(state, State(-1, 0, owner))) {
return Unit
}
} else {
check(state.owner !== owner) { "already locked by $owner" }
val op = createAddWaiterOp(state, cont, true, owner)
if (_state.compareAndSet(state, op)) {
op.perform(this)
return COROUTINE_SUSPENDED
}
}
}
else -> error("unexpected state $state")
}
}
}

override fun unlock(owner: Any?) {
_state.loop { state ->
when (state) {
is OpDescriptor -> state.perform(this) // help
is State -> {
check(state.count == -1) { "write mutex is not locked" }
check(state.owner === owner) { "write mutex is locked by ${state.owner} but $owner is unlocking it" }
if (state.waiters == 0) {
if (_state.compareAndSet(state, State(0, 0, null))) {
return
}
} else {
// this seems to be the easiest way to peek the queue
val waiter = (queue.removeFirstIfIsInstanceOfOrPeekIf<Waiter> { true })!!

if (waiter.isWriter) {
if (_state.compareAndSet(state, State(-1, state.waiters - 1, waiter.owner))) {
val writer = queue.removeFirstOrNull()!! as Waiter
writer.cont.resume(Unit)
return
}
} else {
var readers = consecutiveWaitingReaderCount()
if (_state.compareAndSet(state, State(readers, state.waiters - readers, null))) {
while (readers-- > 0) {
val reader = queue.removeFirstOrNull()!! as Waiter
reader.cont.resume(Unit)
}
return
}
}
}
}
else -> error("unexpected state $state")
}
}
}

private fun consecutiveWaitingReaderCount(): Int {
var readers = 0
var done = false
queue.forEach<Waiter> { waiter ->
if (waiter.isWriter) {
done = true // TODO: terminate iteration early somehow
} else if (!done) {
readers++
}
}
return readers
}

override val onLock: SelectClause2<Any?, Mutex>
get() = TODO("support for select is not yet implemented")

override fun holdsLock(owner: Any): Boolean {
_state.loop { state ->
when (state) {
is OpDescriptor -> state.perform(this) // help
is State -> return state.owner === owner
else -> error("unexpected $state")
}
}
}
}
}
Loading