@@ -4,6 +4,7 @@ use alloc::{format, sync::Arc, vec::Vec};
4
4
5
5
use rand_core:: CryptoRngCore ;
6
6
use tokio:: { sync:: mpsc, task:: JoinHandle } ;
7
+ use tokio_util:: sync:: CancellationToken ;
7
8
use tracing:: { debug, trace} ;
8
9
9
10
use super :: {
@@ -50,6 +51,7 @@ pub async fn run_session<P, SP>(
50
51
rng : & mut impl CryptoRngCore ,
51
52
tx : & mpsc:: Sender < MessageOut < SP > > ,
52
53
rx : & mut mpsc:: Receiver < MessageIn < SP > > ,
54
+ cancellation : CancellationToken ,
53
55
session : Session < P , SP > ,
54
56
) -> Result < SessionReport < P , SP > , LocalError >
55
57
where
@@ -135,10 +137,14 @@ where
135
137
}
136
138
137
139
debug ! ( "{my_id}: Waiting for a message" ) ;
138
- let message_in = rx
139
- . recv ( )
140
- . await
141
- . ok_or_else ( || LocalError :: new ( "Failed to receive a message" ) ) ?;
140
+ let message_in = tokio:: select! {
141
+ message_in = rx. recv( ) => {
142
+ message_in. ok_or_else( || LocalError :: new( "The incoming message channel was closed unexpectedly" ) ) ?
143
+ } ,
144
+ _ = cancellation. cancelled( ) => {
145
+ return session. terminate_due_to_errors( accum) ;
146
+ }
147
+ } ;
142
148
143
149
// Perform quick checks before proceeding with the verification.
144
150
match session
@@ -184,6 +190,7 @@ pub async fn par_run_session<P, SP>(
184
190
rng : & mut ( impl ' static + Clone + CryptoRngCore + Send ) ,
185
191
tx : & mpsc:: Sender < MessageOut < SP > > ,
186
192
rx : & mut mpsc:: Receiver < MessageIn < SP > > ,
193
+ cancellation : CancellationToken ,
187
194
session : Session < P , SP > ,
188
195
) -> Result < SessionReport < P , SP > , LocalError >
189
196
where
@@ -280,49 +287,50 @@ where
280
287
281
288
tokio:: select! {
282
289
processed = processed_rx. recv( ) => {
283
- if let Some ( processed) = processed {
284
- session. add_processed_message( & mut accum, processed) ?;
285
- }
290
+ let processed = processed. ok_or_else( || LocalError :: new( "The processed message channel was closed unexpectedly" ) ) ?;
291
+ session. add_processed_message( & mut accum, processed) ?;
286
292
}
287
293
outgoing = outgoing_rx. recv( ) => {
288
- if let Some ( ( message_out, artifact) ) = outgoing {
289
- let from = message_out . from . clone ( ) ;
290
- let to = message_out. to . clone( ) ;
291
- tx . send ( message_out )
292
- . await
293
- . map_err ( |err| {
294
- LocalError :: new ( format! (
295
- "Failed to send a message from {from:?} to {to:?}: {err}" ,
296
- ) )
297
- } ) ? ;
298
-
299
- session . add_artifact ( & mut accum , artifact ) ? ;
300
- }
294
+ let ( message_out, artifact) = outgoing. ok_or_else ( || LocalError :: new ( "The outgoing message channel was closed unexpectedly" ) ) ? ;
295
+
296
+ let from = message_out. from . clone( ) ;
297
+ let to = message_out . to . clone ( ) ;
298
+ tx . send ( message_out )
299
+ . await
300
+ . map_err ( |err| {
301
+ LocalError :: new ( format! (
302
+ "Failed to send a message from {from:?} to {to:?}: {err}" ,
303
+ ) )
304
+ } ) ? ;
305
+
306
+ session . add_artifact ( & mut accum , artifact ) ? ;
301
307
}
302
308
message_in = rx. recv( ) => {
303
- if let Some ( message_in) = message_in {
304
- match session
305
- . preprocess_message( & mut accum, & message_in. from, message_in. message) ?
306
- . ok( )
307
- {
308
- Some ( preprocessed) => {
309
- let session = session. clone( ) ;
310
- let processed_tx = processed_tx. clone( ) ;
311
- let my_id = my_id. clone( ) ;
312
- let message_processing = tokio:: task:: spawn_blocking( move || {
313
- debug!( "{my_id}: Applying a message from {:?}" , message_in. from) ;
314
- let processed = session. process_message( preprocessed) ;
315
- processed_tx. blocking_send( processed) . map_err( |_err| {
316
- LocalError :: new( "Failed to send a processed message" )
317
- } )
318
- } ) ;
319
- message_processing_tasks. push( message_processing) ;
320
- }
321
- None => {
322
- trace!( "{my_id} Pre-processing complete. Current state: {accum:?}" )
323
- }
309
+ let message_in = message_in. ok_or_else( || LocalError :: new( "The incoming message channel was closed unexpectedly" ) ) ?;
310
+ match session
311
+ . preprocess_message( & mut accum, & message_in. from, message_in. message) ?
312
+ . ok( )
313
+ {
314
+ Some ( preprocessed) => {
315
+ let session = session. clone( ) ;
316
+ let processed_tx = processed_tx. clone( ) ;
317
+ let my_id = my_id. clone( ) ;
318
+ let message_processing = tokio:: task:: spawn_blocking( move || {
319
+ debug!( "{my_id}: Applying a message from {:?}" , message_in. from) ;
320
+ let processed = session. process_message( preprocessed) ;
321
+ processed_tx. blocking_send( processed) . map_err( |_err| {
322
+ LocalError :: new( "Failed to send a processed message" )
323
+ } )
324
+ } ) ;
325
+ message_processing_tasks. push( message_processing) ;
326
+ }
327
+ None => {
328
+ trace!( "{my_id} Pre-processing complete. Current state: {accum:?}" )
324
329
}
325
330
}
331
+ } ,
332
+ _ = cancellation. cancelled( ) => {
333
+ break false ;
326
334
}
327
335
}
328
336
} ;
0 commit comments