Skip to content

Add cancellation support to session::tokio functions #100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).


## [0.3.0] - in development

### Changed

- `session::tokio::run_session()` and `par_run_session()` take an additional `cancellation` argument to support external loop cancellation. ([#100])


[#100]: https://github.com/entropyxyz/manul/pull/100


## [0.2.1] - 2025-05-05

### Added
Expand Down
34 changes: 33 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions manul/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "manul"
version = "0.2.1"
version = "0.3.0-dev"
edition = "2021"
rust-version = "1.81"
authors = ['Entropy Cryptography <[email protected]>']
Expand Down Expand Up @@ -28,6 +28,7 @@ serde-persistent-deserializer = { version = "0.3", optional = true }
postcard = { version = "1", default-features = false, features = ["alloc"], optional = true }
serde_json = { version = "1", default-features = false, features = ["alloc"], optional = true }
tokio = { version = "1", default-features = false, features = ["sync", "rt", "macros", "time"], optional = true }
tokio-util = { version = "0.7", default-features = false, optional = true }

[dev-dependencies]
impls = "1"
Expand All @@ -44,7 +45,7 @@ tracing = { version = "0.1", default-features = false, features = ["std"] }

[features]
dev = ["rand", "postcard", "serde_json", "tracing/std", "serde-persistent-deserializer"]
tokio = ["dep:tokio"]
tokio = ["dep:tokio", "tokio-util"]

[package.metadata.docs.rs]
all-features = true
Expand Down
7 changes: 5 additions & 2 deletions manul/src/dev/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use rand::Rng;
use rand_core::CryptoRngCore;
use signature::Keypair;
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;

use super::run_sync::ExecutionResult;
use crate::{
Expand Down Expand Up @@ -100,6 +101,7 @@ where

let dispatcher_task = message_dispatcher(rng.clone(), tx_map, dispatcher_rx);
let dispatcher = tokio::spawn(dispatcher_task);
let cancellation = CancellationToken::new();

let handles = rxs
.into_iter()
Expand All @@ -110,12 +112,13 @@ where

let session = Session::<_, SP>::new(&mut rng, session_id.clone(), signer, entry_point)?;
let id = session.verifier().clone();
let cancellation = cancellation.clone();

let node_task = async move {
if offload_processing {
par_run_session(&mut rng, &tx, &mut rx, session).await
par_run_session(&mut rng, &tx, &mut rx, cancellation, session).await
} else {
run_session(&mut rng, &tx, &mut rx, session).await
run_session(&mut rng, &tx, &mut rx, cancellation, session).await
}
};
Ok((id, tokio::spawn(node_task)))
Expand Down
90 changes: 49 additions & 41 deletions manul/src/session/tokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use alloc::{format, sync::Arc, vec::Vec};

use rand_core::CryptoRngCore;
use tokio::{sync::mpsc, task::JoinHandle};
use tokio_util::sync::CancellationToken;
use tracing::{debug, trace};

use super::{
Expand Down Expand Up @@ -50,6 +51,7 @@ pub async fn run_session<P, SP>(
rng: &mut impl CryptoRngCore,
tx: &mpsc::Sender<MessageOut<SP>>,
rx: &mut mpsc::Receiver<MessageIn<SP>>,
cancellation: CancellationToken,
session: Session<P, SP>,
) -> Result<SessionReport<P, SP>, LocalError>
where
Expand Down Expand Up @@ -135,10 +137,14 @@ where
}

debug!("{my_id}: Waiting for a message");
let message_in = rx
.recv()
.await
.ok_or_else(|| LocalError::new("Failed to receive a message"))?;
let message_in = tokio::select! {
message_in = rx.recv() => {
message_in.ok_or_else(|| LocalError::new("The incoming message channel was closed unexpectedly"))?
},
_ = cancellation.cancelled() => {
return session.terminate_due_to_errors(accum);
}
};

// Perform quick checks before proceeding with the verification.
match session
Expand Down Expand Up @@ -184,6 +190,7 @@ pub async fn par_run_session<P, SP>(
rng: &mut (impl 'static + Clone + CryptoRngCore + Send),
tx: &mpsc::Sender<MessageOut<SP>>,
rx: &mut mpsc::Receiver<MessageIn<SP>>,
cancellation: CancellationToken,
session: Session<P, SP>,
) -> Result<SessionReport<P, SP>, LocalError>
where
Expand Down Expand Up @@ -280,49 +287,50 @@ where

tokio::select! {
processed = processed_rx.recv() => {
if let Some(processed) = processed {
session.add_processed_message(&mut accum, processed)?;
}
let processed = processed.ok_or_else(|| LocalError::new("The processed message channel was closed unexpectedly"))?;
session.add_processed_message(&mut accum, processed)?;
}
outgoing = outgoing_rx.recv() => {
if let Some((message_out, artifact)) = outgoing {
let from = message_out.from.clone();
let to = message_out.to.clone();
tx.send(message_out)
.await
.map_err(|err| {
LocalError::new(format!(
"Failed to send a message from {from:?} to {to:?}: {err}",
))
})?;

session.add_artifact(&mut accum, artifact)?;
}
let (message_out, artifact) = outgoing.ok_or_else(|| LocalError::new("The outgoing message channel was closed unexpectedly"))?;

let from = message_out.from.clone();
let to = message_out.to.clone();
tx.send(message_out)
.await
.map_err(|err| {
LocalError::new(format!(
"Failed to send a message from {from:?} to {to:?}: {err}",
))
})?;

session.add_artifact(&mut accum, artifact)?;
}
message_in = rx.recv() => {
if let Some(message_in) = message_in {
match session
.preprocess_message(&mut accum, &message_in.from, message_in.message)?
.ok()
{
Some(preprocessed) => {
let session = session.clone();
let processed_tx = processed_tx.clone();
let my_id = my_id.clone();
let message_processing = tokio::task::spawn_blocking(move || {
debug!("{my_id}: Applying a message from {:?}", message_in.from);
let processed = session.process_message(preprocessed);
processed_tx.blocking_send(processed).map_err(|_err| {
LocalError::new("Failed to send a processed message")
})
});
message_processing_tasks.push(message_processing);
}
None => {
trace!("{my_id} Pre-processing complete. Current state: {accum:?}")
}
let message_in = message_in.ok_or_else(|| LocalError::new("The incoming message channel was closed unexpectedly"))?;
match session
.preprocess_message(&mut accum, &message_in.from, message_in.message)?
.ok()
{
Some(preprocessed) => {
let session = session.clone();
let processed_tx = processed_tx.clone();
let my_id = my_id.clone();
let message_processing = tokio::task::spawn_blocking(move || {
debug!("{my_id}: Applying a message from {:?}", message_in.from);
let processed = session.process_message(preprocessed);
processed_tx.blocking_send(processed).map_err(|_err| {
LocalError::new("Failed to send a processed message")
})
});
message_processing_tasks.push(message_processing);
}
None => {
trace!("{my_id} Pre-processing complete. Current state: {accum:?}")
}
}
},
_ = cancellation.cancelled() => {
break false;
}
}
};
Expand Down
Loading