Skip to content

Commit 018516e

Browse files
committed
Add cancellation support to session::tokio functions
1 parent f2e444b commit 018516e

File tree

4 files changed

+88
-44
lines changed

4 files changed

+88
-44
lines changed

Cargo.lock

+32
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

manul/Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ serde-persistent-deserializer = { version = "0.3", optional = true }
2828
postcard = { version = "1", default-features = false, features = ["alloc"], optional = true }
2929
serde_json = { version = "1", default-features = false, features = ["alloc"], optional = true }
3030
tokio = { version = "1", default-features = false, features = ["sync", "rt", "macros", "time"], optional = true }
31+
tokio-util = { version = "0.7", default-features = false, optional = true }
3132

3233
[dev-dependencies]
3334
impls = "1"
@@ -44,7 +45,7 @@ tracing = { version = "0.1", default-features = false, features = ["std"] }
4445

4546
[features]
4647
dev = ["rand", "postcard", "serde_json", "tracing/std", "serde-persistent-deserializer"]
47-
tokio = ["dep:tokio"]
48+
tokio = ["dep:tokio", "tokio-util"]
4849

4950
[package.metadata.docs.rs]
5051
all-features = true

manul/src/dev/tokio.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use rand::Rng;
66
use rand_core::CryptoRngCore;
77
use signature::Keypair;
88
use tokio::sync::mpsc;
9+
use tokio_util::sync::CancellationToken;
910

1011
use super::run_sync::ExecutionResult;
1112
use crate::{
@@ -100,6 +101,7 @@ where
100101

101102
let dispatcher_task = message_dispatcher(rng.clone(), tx_map, dispatcher_rx);
102103
let dispatcher = tokio::spawn(dispatcher_task);
104+
let cancellation = CancellationToken::new();
103105

104106
let handles = rxs
105107
.into_iter()
@@ -110,12 +112,13 @@ where
110112

111113
let session = Session::<_, SP>::new(&mut rng, session_id.clone(), signer, entry_point)?;
112114
let id = session.verifier().clone();
115+
let cancellation = cancellation.clone();
113116

114117
let node_task = async move {
115118
if offload_processing {
116-
par_run_session(&mut rng, &tx, &mut rx, session).await
119+
par_run_session(&mut rng, &tx, &mut rx, cancellation, session).await
117120
} else {
118-
run_session(&mut rng, &tx, &mut rx, session).await
121+
run_session(&mut rng, &tx, &mut rx, cancellation, session).await
119122
}
120123
};
121124
Ok((id, tokio::spawn(node_task)))

manul/src/session/tokio.rs

+49-41
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use alloc::{format, sync::Arc, vec::Vec};
44

55
use rand_core::CryptoRngCore;
66
use tokio::{sync::mpsc, task::JoinHandle};
7+
use tokio_util::sync::CancellationToken;
78
use tracing::{debug, trace};
89

910
use super::{
@@ -50,6 +51,7 @@ pub async fn run_session<P, SP>(
5051
rng: &mut impl CryptoRngCore,
5152
tx: &mpsc::Sender<MessageOut<SP>>,
5253
rx: &mut mpsc::Receiver<MessageIn<SP>>,
54+
cancellation: CancellationToken,
5355
session: Session<P, SP>,
5456
) -> Result<SessionReport<P, SP>, LocalError>
5557
where
@@ -135,10 +137,14 @@ where
135137
}
136138

137139
debug!("{my_id}: Waiting for a message");
138-
let message_in = rx
139-
.recv()
140-
.await
141-
.ok_or_else(|| LocalError::new("Failed to receive a message"))?;
140+
let message_in = tokio::select! {
141+
message_in = rx.recv() => {
142+
message_in.ok_or_else(|| LocalError::new("The incoming message channel was closed unexpectedly"))?
143+
},
144+
_ = cancellation.cancelled() => {
145+
return session.terminate_due_to_errors(accum);
146+
}
147+
};
142148

143149
// Perform quick checks before proceeding with the verification.
144150
match session
@@ -184,6 +190,7 @@ pub async fn par_run_session<P, SP>(
184190
rng: &mut (impl 'static + Clone + CryptoRngCore + Send),
185191
tx: &mpsc::Sender<MessageOut<SP>>,
186192
rx: &mut mpsc::Receiver<MessageIn<SP>>,
193+
cancellation: CancellationToken,
187194
session: Session<P, SP>,
188195
) -> Result<SessionReport<P, SP>, LocalError>
189196
where
@@ -280,49 +287,50 @@ where
280287

281288
tokio::select! {
282289
processed = processed_rx.recv() => {
283-
if let Some(processed) = processed {
284-
session.add_processed_message(&mut accum, processed)?;
285-
}
290+
let processed = processed.ok_or_else(|| LocalError::new("The processed message channel was closed unexpectedly"))?;
291+
session.add_processed_message(&mut accum, processed)?;
286292
}
287293
outgoing = outgoing_rx.recv() => {
288-
if let Some((message_out, artifact)) = outgoing {
289-
let from = message_out.from.clone();
290-
let to = message_out.to.clone();
291-
tx.send(message_out)
292-
.await
293-
.map_err(|err| {
294-
LocalError::new(format!(
295-
"Failed to send a message from {from:?} to {to:?}: {err}",
296-
))
297-
})?;
298-
299-
session.add_artifact(&mut accum, artifact)?;
300-
}
294+
let (message_out, artifact) = outgoing.ok_or_else(|| LocalError::new("The outgoing message channel was closed unexpectedly"))?;
295+
296+
let from = message_out.from.clone();
297+
let to = message_out.to.clone();
298+
tx.send(message_out)
299+
.await
300+
.map_err(|err| {
301+
LocalError::new(format!(
302+
"Failed to send a message from {from:?} to {to:?}: {err}",
303+
))
304+
})?;
305+
306+
session.add_artifact(&mut accum, artifact)?;
301307
}
302308
message_in = rx.recv() => {
303-
if let Some(message_in) = message_in {
304-
match session
305-
.preprocess_message(&mut accum, &message_in.from, message_in.message)?
306-
.ok()
307-
{
308-
Some(preprocessed) => {
309-
let session = session.clone();
310-
let processed_tx = processed_tx.clone();
311-
let my_id = my_id.clone();
312-
let message_processing = tokio::task::spawn_blocking(move || {
313-
debug!("{my_id}: Applying a message from {:?}", message_in.from);
314-
let processed = session.process_message(preprocessed);
315-
processed_tx.blocking_send(processed).map_err(|_err| {
316-
LocalError::new("Failed to send a processed message")
317-
})
318-
});
319-
message_processing_tasks.push(message_processing);
320-
}
321-
None => {
322-
trace!("{my_id} Pre-processing complete. Current state: {accum:?}")
323-
}
309+
let message_in = message_in.ok_or_else(|| LocalError::new("The incoming message channel was closed unexpectedly"))?;
310+
match session
311+
.preprocess_message(&mut accum, &message_in.from, message_in.message)?
312+
.ok()
313+
{
314+
Some(preprocessed) => {
315+
let session = session.clone();
316+
let processed_tx = processed_tx.clone();
317+
let my_id = my_id.clone();
318+
let message_processing = tokio::task::spawn_blocking(move || {
319+
debug!("{my_id}: Applying a message from {:?}", message_in.from);
320+
let processed = session.process_message(preprocessed);
321+
processed_tx.blocking_send(processed).map_err(|_err| {
322+
LocalError::new("Failed to send a processed message")
323+
})
324+
});
325+
message_processing_tasks.push(message_processing);
326+
}
327+
None => {
328+
trace!("{my_id} Pre-processing complete. Current state: {accum:?}")
324329
}
325330
}
331+
},
332+
_ = cancellation.cancelled() => {
333+
break false;
326334
}
327335
}
328336
};

0 commit comments

Comments
 (0)