diff --git a/lib/multi_channel.ml b/lib/multi_channel.ml index 7bbe3e1..b6c7e24 100644 --- a/lib/multi_channel.ml +++ b/lib/multi_channel.ml @@ -139,7 +139,7 @@ let rec recv_poll_loop mchan dls cur_offset = let recv_poll_with_dls mchan dls = try - Ws_deque.pop (Array.unsafe_get mchan.channels dls.id) + Ws_deque.steal (Array.unsafe_get mchan.channels dls.id) with | Exit -> recv_poll_loop mchan dls 0 [@@inline] diff --git a/lib/task.ml b/lib/task.ml index 3599925..beb4f6b 100644 --- a/lib/task.ml +++ b/lib/task.ml @@ -24,7 +24,9 @@ type 'a promise_state = type 'a promise = 'a promise_state Atomic.t -type _ t += Wait : 'a promise * task_chan -> 'a t +type _ t += + | Wait : 'a promise * task_chan -> 'a t + | Yield : task_chan -> unit t let get_pool_data p = match Atomic.get p with @@ -61,6 +63,10 @@ let await pool promise = | Raised (e, bt) -> Printexc.raise_with_backtrace e bt | Pending _ -> perform (Wait (promise, pd.task_chan)) +let yield pool = + let pd = get_pool_data pool in + perform (Yield pd.task_chan) + let step (type a) (f : a -> unit) (v : a) : unit = try_with f v { effc = fun (type a) (e : a t) -> @@ -76,6 +82,7 @@ let step (type a) (f : a -> unit) (v : a) : unit = | Raised (e,bt) -> discontinue_with_backtrace k e bt in loop ()) + | Yield c -> Some (fun (k : (a, _) continuation) -> cont () (k, c)) | _ -> None } let rec worker task_chan = diff --git a/lib/task.mli b/lib/task.mli index 16baeac..c2bfa06 100644 --- a/lib/task.mli +++ b/lib/task.mli @@ -54,6 +54,12 @@ val await : pool -> 'a promise -> 'a Must be called with a call to {!run} in the dynamic scope to handle the internal algebraic effects for task synchronization. *) +val yield : pool -> unit +(** [yield p] suspends the current task momentarily, to be continued later. + This function should be called in place of {!Domain.cpu_relax ()} when the + current task is stuck and waiting on others tasks from the pool [p] to make + progress. *) + val parallel_for : ?chunk_size:int -> start:int -> finish:int -> body:(int -> unit) -> pool -> unit (** [parallel_for c s f b p] behaves similar to [for i=s to f do b i done], but diff --git a/test/dune b/test/dune index 346e8e1..b377c9f 100644 --- a/test/dune +++ b/test/dune @@ -110,3 +110,8 @@ (libraries domainslib) (modules off_by_one) (modes native)) + +(test + (name test_yield) + (libraries domainslib) + (modules test_yield)) diff --git a/test/test_yield.ml b/test/test_yield.ml new file mode 100644 index 0000000..a06b671 --- /dev/null +++ b/test/test_yield.ml @@ -0,0 +1,44 @@ +(* Test gets stuck if [Task.yield] is missing. *) + +module T = Domainslib.Task + +module Cell : sig + + type 'a t + val make : T.pool -> 'a t + val push : 'a t -> 'a -> unit + val pop : 'a t -> 'a + +end = struct + + type 'a t = { pool : T.pool ; cell : 'a option Atomic.t } + + let make pool = { pool ; cell = Atomic.make None } + + let rec push t x = + if not (Atomic.compare_and_set t.cell None (Some x)) + then begin + T.yield t.pool ; + push t x + end + + let rec pop t = + match Atomic.get t.cell with + | (Some x) as old when Atomic.compare_and_set t.cell old None -> x + | _ -> + T.yield t.pool ; (* try commenting *) + pop t +end + +let test pool () = + let t = Cell.make pool in + T.parallel_for pool ~start:1 ~finish:100 ~body:(fun i -> + let p = T.async pool (fun () -> Cell.push t i) in + let _ = Cell.pop t in + T.await pool p + ) + +let () = + let pool = T.setup_pool ~num_domains:6 () in + T.run pool (test pool) ; + T.teardown_pool pool