Skip to content

Commit 2e7b3ae

Browse files
committed
feat(mpz-common): Context::blocking
1 parent 5cb1aec commit 2e7b3ae

File tree

10 files changed

+413
-50
lines changed

10 files changed

+413
-50
lines changed

crates/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ futures-util = "0.3"
7474
tokio = "1.23"
7575
tokio-util = "0.7"
7676
scoped-futures = "0.1.3"
77+
pollster = "0.3"
7778

7879
# serialization
7980
ark-serialize = "0.4"

crates/mpz-common/Cargo.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ default = ["sync"]
88
sync = []
99
test-utils = ["uid-mux/test-utils"]
1010
ideal = []
11+
rayon = ["dep:rayon"]
12+
force-st = []
1113

1214
[dependencies]
1315
mpz-core.workspace = true
@@ -19,6 +21,9 @@ thiserror.workspace = true
1921
serio.workspace = true
2022
uid-mux.workspace = true
2123
serde = { workspace = true, features = ["derive"] }
24+
pollster.workspace = true
25+
rayon = { workspace = true, optional = true }
26+
cfg-if.workspace = true
2227

2328
[dev-dependencies]
2429
tokio = { workspace = true, features = [
@@ -29,3 +34,9 @@ tokio = { workspace = true, features = [
2934
tokio-util = { workspace = true, features = ["compat"] }
3035
uid-mux = { workspace = true, features = ["test-utils"] }
3136
tracing-subscriber = { workspace = true, features = ["fmt"] }
37+
criterion.workspace = true
38+
39+
[[bench]]
40+
name = "context"
41+
harness = false
42+
required-features = ["test-utils", "rayon"]

crates/mpz-common/benches/context.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
use criterion::{black_box, criterion_group, criterion_main, Criterion};
2+
use mpz_common::{
3+
executor::{test_mt_executor, test_st_executor},
4+
Context,
5+
};
6+
use pollster::block_on;
7+
use scoped_futures::ScopedFutureExt;
8+
9+
fn criterion_benchmark(c: &mut Criterion) {
10+
let mut group = c.benchmark_group("context");
11+
12+
// Measures the overhead of making a `Context::blocking` call, which
13+
// moves the context to a worker thread and back.
14+
group.bench_function("st/blocking", |b| {
15+
let (mut ctx, _) = test_st_executor(1024);
16+
b.iter(|| {
17+
block_on(async {
18+
ctx.blocking(|ctx| {
19+
async move {
20+
black_box(ctx.id());
21+
}
22+
.scope_boxed()
23+
})
24+
.await
25+
.unwrap();
26+
});
27+
})
28+
});
29+
30+
// Measures the overhead of making a `Context::blocking` call, which
31+
// moves the context to a worker thread and back.
32+
group.bench_function("mt/blocking", |b| {
33+
let (mut exec_a, _) = test_mt_executor(8);
34+
35+
let mut ctx = block_on(exec_a.new_thread()).unwrap();
36+
37+
b.iter(|| {
38+
block_on(async {
39+
ctx.blocking(|ctx| {
40+
async move {
41+
black_box(ctx.id());
42+
}
43+
.scope_boxed()
44+
})
45+
.await
46+
.unwrap();
47+
});
48+
})
49+
});
50+
}
51+
52+
criterion_group!(benches, criterion_benchmark);
53+
criterion_main!(benches);

crates/mpz-common/src/context.rs

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,26 @@ pub trait Context: Send + Sync {
5858
/// Returns a mutable reference to the thread's I/O channel.
5959
fn io_mut(&mut self) -> &mut Self::Io;
6060

61+
/// Executes a task that may block the thread.
62+
///
63+
/// If CPU multi-threading is available, the task is executed on a separate thread. Otherwise,
64+
/// the task is executed on the current thread and can block the executor.
65+
///
66+
/// # Deadlocks
67+
///
68+
/// This method may cause deadlocks if the task blocks and the executor can not make progress.
69+
/// Generally, one should *never* block across an await point. This method is intended for operations
70+
/// that are CPU-bound but require access to a thread context.
71+
///
72+
/// # Overhead
73+
///
74+
/// This method has an inherent overhead and should only be used for tasks that are CPU-bound. Otherwise,
75+
/// prefer using [`Context::queue`] or [`Context::join`] to execute tasks concurrently.
76+
async fn blocking<F, R>(&mut self, f: F) -> Result<R, ContextError>
77+
where
78+
F: for<'a> FnOnce(&'a mut Self) -> ScopedBoxFuture<'static, 'a, R> + Send + 'static,
79+
R: Send + 'static;
80+
6181
/// Forks the thread and executes the provided closures concurrently.
6282
///
6383
/// Implementations may not be able to fork, in which case the closures are executed
@@ -119,32 +139,34 @@ macro_rules! try_join {
119139

120140
#[cfg(test)]
121141
mod tests {
122-
use crate::executor::test_st_executor;
142+
use crate::{executor::test_st_executor, Context};
143+
use futures::executor::block_on;
123144

124145
#[test]
125146
fn test_join_macro() {
126147
let (mut ctx, _) = test_st_executor(1);
127148

128-
futures::executor::block_on(async {
129-
join!(ctx, async { println!("{:?}", ctx.id()) }, async {
130-
println!("{:?}", ctx.id())
131-
})
132-
.unwrap()
149+
let (id_0, id_1) = block_on(async {
150+
join!(ctx, async { ctx.id().clone() }, async { ctx.id().clone() }).unwrap()
133151
});
152+
153+
assert_eq!(&id_0, ctx.id());
154+
assert_eq!(&id_1, ctx.id());
134155
}
135156

136157
#[test]
137158
fn test_try_join_macro() {
138159
let (mut ctx, _) = test_st_executor(1);
139160

140-
futures::executor::block_on(async {
141-
try_join!(
142-
ctx,
143-
async { Ok::<_, ()>(println!("{:?}", ctx.id())) },
144-
async { Ok::<_, ()>(println!("{:?}", ctx.id())) }
145-
)
161+
let (id_0, id_1) = block_on(async {
162+
try_join!(ctx, async { Ok::<_, ()>(ctx.id().clone()) }, async {
163+
Ok::<_, ()>(ctx.id().clone())
164+
})
165+
.unwrap()
146166
.unwrap()
147-
.unwrap();
148167
});
168+
169+
assert_eq!(&id_0, ctx.id());
170+
assert_eq!(&id_1, ctx.id());
149171
}
150172
}

crates/mpz-common/src/cpu.rs

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
//! CPU backend shim.
2+
3+
use cfg_if::cfg_if;
4+
5+
cfg_if! {
6+
if #[cfg(feature = "force-st")] {
7+
pub use st::SingleThreadedBackend as CpuBackend;
8+
} else if #[cfg(feature = "rayon")] {
9+
pub use rayon_backend::RayonBackend as CpuBackend;
10+
} else {
11+
pub use st::SingleThreadedBackend as CpuBackend;
12+
}
13+
}
14+
15+
#[cfg(any(feature = "force-st", not(feature = "rayon")))]
16+
mod st {
17+
use futures::Future;
18+
19+
/// A single-threaded CPU backend.
20+
#[derive(Debug)]
21+
pub struct SingleThreadedBackend;
22+
23+
impl SingleThreadedBackend {
24+
/// Execute a future on the CPU backend.
25+
#[inline]
26+
pub fn blocking_async<F>(fut: F) -> impl Future<Output = F::Output> + Send
27+
where
28+
F: Future + Send + 'static,
29+
F::Output: Send,
30+
{
31+
fut
32+
}
33+
34+
/// Execute a closure on the CPU backend.
35+
#[inline]
36+
pub fn blocking<F, R>(f: F) -> impl Future<Output = R> + Send
37+
where
38+
F: FnOnce() -> R + Send + 'static,
39+
R: Send + 'static,
40+
{
41+
async move { f() }
42+
}
43+
}
44+
45+
#[cfg(test)]
46+
mod tests {
47+
use super::*;
48+
use pollster::block_on;
49+
50+
#[test]
51+
fn test_st_backend_blocking() {
52+
let output = block_on(SingleThreadedBackend::blocking(|| 42));
53+
assert_eq!(output, 42);
54+
}
55+
56+
#[test]
57+
fn test_st_backend_blocking_async() {
58+
let output = block_on(SingleThreadedBackend::blocking_async(async { 42 }));
59+
assert_eq!(output, 42);
60+
}
61+
}
62+
}
63+
64+
#[cfg(all(feature = "rayon", not(feature = "force-st")))]
65+
mod rayon_backend {
66+
use futures::{channel::oneshot, Future};
67+
use pollster::block_on;
68+
69+
/// A Rayon CPU backend.
70+
#[derive(Debug)]
71+
pub struct RayonBackend;
72+
73+
impl RayonBackend {
74+
/// Execute a future on the CPU backend.
75+
pub fn blocking_async<F>(fut: F) -> impl Future<Output = F::Output> + Send
76+
where
77+
F: Future + Send + 'static,
78+
F::Output: Send,
79+
{
80+
async move {
81+
let (sender, receiver) = oneshot::channel();
82+
rayon::spawn(move || {
83+
let output = block_on(fut);
84+
_ = sender.send(output);
85+
});
86+
receiver.await.expect("worker thread does not drop channel")
87+
}
88+
}
89+
90+
/// Execute a closure on the CPU backend.
91+
pub fn blocking<F, R>(f: F) -> impl Future<Output = R> + Send
92+
where
93+
F: FnOnce() -> R + Send + 'static,
94+
R: Send + 'static,
95+
{
96+
async move {
97+
let (sender, receiver) = oneshot::channel();
98+
rayon::spawn(move || {
99+
_ = sender.send(f());
100+
});
101+
receiver.await.expect("worker thread does not drop channel")
102+
}
103+
}
104+
}
105+
106+
#[cfg(test)]
107+
mod tests {
108+
use super::*;
109+
110+
#[test]
111+
fn test_rayon_backend_blocking() {
112+
let output = block_on(RayonBackend::blocking(|| 42));
113+
assert_eq!(output, 42);
114+
}
115+
116+
#[test]
117+
fn test_rayon_backend_blocking_async() {
118+
let output = block_on(RayonBackend::blocking_async(async { 42 }));
119+
assert_eq!(output, 42);
120+
}
121+
}
122+
}

crates/mpz-common/src/executor/dummy.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
use async_trait::async_trait;
2+
23
use scoped_futures::ScopedBoxFuture;
34
use serio::{Sink, Stream};
45

5-
use crate::{context::Context, ContextError, ThreadId};
6+
use crate::{context::Context, cpu::CpuBackend, ContextError, ThreadId};
67

78
/// A dummy executor.
89
#[derive(Debug, Default)]
@@ -74,6 +75,19 @@ impl Context for DummyExecutor {
7475
&mut self.io
7576
}
7677

78+
async fn blocking<F, R>(&mut self, f: F) -> Result<R, ContextError>
79+
where
80+
F: for<'a> FnOnce(&'a mut Self) -> ScopedBoxFuture<'static, 'a, R> + Send + 'static,
81+
R: Send + 'static,
82+
{
83+
let mut ctx = Self {
84+
id: self.id.clone(),
85+
io: DummyIo,
86+
};
87+
88+
Ok(CpuBackend::blocking_async(async move { f(&mut ctx).await }).await)
89+
}
90+
7791
async fn join<'a, A, B, RA, RB>(&'a mut self, a: A, b: B) -> Result<(RA, RB), ContextError>
7892
where
7993
A: for<'b> FnOnce(&'b mut Self) -> ScopedBoxFuture<'a, 'b, RA> + Send + 'a,

crates/mpz-common/src/executor/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ mod mt;
55
mod st;
66

77
pub use dummy::{DummyExecutor, DummyIo};
8-
pub use mt::MTExecutor;
8+
pub use mt::{MTContext, MTExecutor};
99
pub use st::STExecutor;
1010

1111
#[cfg(any(test, feature = "test-utils"))]

0 commit comments

Comments
 (0)