From e46c2b6ceb792d5e37099c98e3360a5917514142 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Wed, 30 Apr 2025 11:56:46 -0700 Subject: [PATCH] Add cancellation support to `session::tokio` functions --- CHANGELOG.md | 10 +++++ Cargo.lock | 34 +++++++++++++- manul/Cargo.toml | 5 ++- manul/src/dev/tokio.rs | 7 ++- manul/src/session/tokio.rs | 90 +++++++++++++++++++++----------------- 5 files changed, 100 insertions(+), 46 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a1b4d9..88d80fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,16 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.3.0] - in development + +### Changed + +- `session::tokio::run_session()` and `par_run_session()` take an additional `cancellation` argument to support external loop cancellation. ([#100]) + + +[#100]: https://github.com/entropyxyz/manul/pull/100 + + ## [0.2.1] - 2025-05-05 ### Added diff --git a/Cargo.lock b/Cargo.lock index 42e3aed..ee9846f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -144,6 +144,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "bytes" +version = "1.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" + [[package]] name = "cast" version = "0.3.0" @@ -393,6 +399,18 @@ dependencies = [ "typeid", ] +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + [[package]] name = "generic-array" version = "0.14.7" @@ -551,7 +569,7 @@ checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "manul" -version = "0.2.1" +version = "0.3.0-dev" dependencies = [ "criterion", "derive-where", @@ -570,6 +588,7 @@ dependencies = [ "signature", "tinyvec", "tokio", + "tokio-util", "tracing", ] @@ -1067,6 +1086,19 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-util" +version = "0.7.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "tracing" version = "0.1.40" diff --git a/manul/Cargo.toml b/manul/Cargo.toml index 314cb76..213faf7 100644 --- a/manul/Cargo.toml +++ b/manul/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "manul" -version = "0.2.1" +version = "0.3.0-dev" edition = "2021" rust-version = "1.81" authors = ['Entropy Cryptography '] @@ -28,6 +28,7 @@ serde-persistent-deserializer = { version = "0.3", optional = true } postcard = { version = "1", default-features = false, features = ["alloc"], optional = true } serde_json = { version = "1", default-features = false, features = ["alloc"], optional = true } tokio = { version = "1", default-features = false, features = ["sync", "rt", "macros", "time"], optional = true } +tokio-util = { version = "0.7", default-features = false, optional = true } [dev-dependencies] impls = "1" @@ -44,7 +45,7 @@ tracing = { version = "0.1", default-features = false, features = ["std"] } [features] dev = ["rand", "postcard", "serde_json", "tracing/std", "serde-persistent-deserializer"] -tokio = ["dep:tokio"] +tokio = ["dep:tokio", "tokio-util"] [package.metadata.docs.rs] all-features = true diff --git a/manul/src/dev/tokio.rs b/manul/src/dev/tokio.rs index 718b2b6..e96da80 100644 --- a/manul/src/dev/tokio.rs +++ b/manul/src/dev/tokio.rs @@ -6,6 +6,7 @@ use rand::Rng; use rand_core::CryptoRngCore; use signature::Keypair; use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; use super::run_sync::ExecutionResult; use crate::{ @@ -100,6 +101,7 @@ where let dispatcher_task = message_dispatcher(rng.clone(), tx_map, dispatcher_rx); let dispatcher = tokio::spawn(dispatcher_task); + let cancellation = CancellationToken::new(); let handles = rxs .into_iter() @@ -110,12 +112,13 @@ where let session = Session::<_, SP>::new(&mut rng, session_id.clone(), signer, entry_point)?; let id = session.verifier().clone(); + let cancellation = cancellation.clone(); let node_task = async move { if offload_processing { - par_run_session(&mut rng, &tx, &mut rx, session).await + par_run_session(&mut rng, &tx, &mut rx, cancellation, session).await } else { - run_session(&mut rng, &tx, &mut rx, session).await + run_session(&mut rng, &tx, &mut rx, cancellation, session).await } }; Ok((id, tokio::spawn(node_task))) diff --git a/manul/src/session/tokio.rs b/manul/src/session/tokio.rs index 748ac8a..18211aa 100644 --- a/manul/src/session/tokio.rs +++ b/manul/src/session/tokio.rs @@ -4,6 +4,7 @@ use alloc::{format, sync::Arc, vec::Vec}; use rand_core::CryptoRngCore; use tokio::{sync::mpsc, task::JoinHandle}; +use tokio_util::sync::CancellationToken; use tracing::{debug, trace}; use super::{ @@ -50,6 +51,7 @@ pub async fn run_session( rng: &mut impl CryptoRngCore, tx: &mpsc::Sender>, rx: &mut mpsc::Receiver>, + cancellation: CancellationToken, session: Session, ) -> Result, LocalError> where @@ -135,10 +137,14 @@ where } debug!("{my_id}: Waiting for a message"); - let message_in = rx - .recv() - .await - .ok_or_else(|| LocalError::new("Failed to receive a message"))?; + let message_in = tokio::select! { + message_in = rx.recv() => { + message_in.ok_or_else(|| LocalError::new("The incoming message channel was closed unexpectedly"))? + }, + _ = cancellation.cancelled() => { + return session.terminate_due_to_errors(accum); + } + }; // Perform quick checks before proceeding with the verification. match session @@ -184,6 +190,7 @@ pub async fn par_run_session( rng: &mut (impl 'static + Clone + CryptoRngCore + Send), tx: &mpsc::Sender>, rx: &mut mpsc::Receiver>, + cancellation: CancellationToken, session: Session, ) -> Result, LocalError> where @@ -280,49 +287,50 @@ where tokio::select! { processed = processed_rx.recv() => { - if let Some(processed) = processed { - session.add_processed_message(&mut accum, processed)?; - } + let processed = processed.ok_or_else(|| LocalError::new("The processed message channel was closed unexpectedly"))?; + session.add_processed_message(&mut accum, processed)?; } outgoing = outgoing_rx.recv() => { - if let Some((message_out, artifact)) = outgoing { - let from = message_out.from.clone(); - let to = message_out.to.clone(); - tx.send(message_out) - .await - .map_err(|err| { - LocalError::new(format!( - "Failed to send a message from {from:?} to {to:?}: {err}", - )) - })?; - - session.add_artifact(&mut accum, artifact)?; - } + let (message_out, artifact) = outgoing.ok_or_else(|| LocalError::new("The outgoing message channel was closed unexpectedly"))?; + + let from = message_out.from.clone(); + let to = message_out.to.clone(); + tx.send(message_out) + .await + .map_err(|err| { + LocalError::new(format!( + "Failed to send a message from {from:?} to {to:?}: {err}", + )) + })?; + + session.add_artifact(&mut accum, artifact)?; } message_in = rx.recv() => { - if let Some(message_in) = message_in { - match session - .preprocess_message(&mut accum, &message_in.from, message_in.message)? - .ok() - { - Some(preprocessed) => { - let session = session.clone(); - let processed_tx = processed_tx.clone(); - let my_id = my_id.clone(); - let message_processing = tokio::task::spawn_blocking(move || { - debug!("{my_id}: Applying a message from {:?}", message_in.from); - let processed = session.process_message(preprocessed); - processed_tx.blocking_send(processed).map_err(|_err| { - LocalError::new("Failed to send a processed message") - }) - }); - message_processing_tasks.push(message_processing); - } - None => { - trace!("{my_id} Pre-processing complete. Current state: {accum:?}") - } + let message_in = message_in.ok_or_else(|| LocalError::new("The incoming message channel was closed unexpectedly"))?; + match session + .preprocess_message(&mut accum, &message_in.from, message_in.message)? + .ok() + { + Some(preprocessed) => { + let session = session.clone(); + let processed_tx = processed_tx.clone(); + let my_id = my_id.clone(); + let message_processing = tokio::task::spawn_blocking(move || { + debug!("{my_id}: Applying a message from {:?}", message_in.from); + let processed = session.process_message(preprocessed); + processed_tx.blocking_send(processed).map_err(|_err| { + LocalError::new("Failed to send a processed message") + }) + }); + message_processing_tasks.push(message_processing); + } + None => { + trace!("{my_id} Pre-processing complete. Current state: {accum:?}") } } + }, + _ = cancellation.cancelled() => { + break false; } } };