Skip to content

Commit 47b7a7d

Browse files
committed
Introduce a CancellationToken for cancelling specific computations
1 parent ef9f932 commit 47b7a7d

File tree

8 files changed

+165
-8
lines changed

8 files changed

+165
-8
lines changed

src/cancelled.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,16 @@ use std::panic::{self, UnwindSafe};
77
/// *
88
/// *
99
/// *
10-
#[derive(Debug)]
10+
#[derive(Debug, PartialEq, Eq)]
1111
#[non_exhaustive]
1212
pub enum Cancelled {
13+
/// The query was operating but the local database execution has been cancelled.
14+
Cancelled,
15+
1316
/// The query was operating on revision R, but there is a pending write to move to revision R+1.
14-
#[non_exhaustive]
1517
PendingWrite,
1618

1719
/// The query was blocked on another thread, and that thread panicked.
18-
#[non_exhaustive]
1920
PropagatedPanic,
2021
}
2122

@@ -45,6 +46,7 @@ impl Cancelled {
4546
impl std::fmt::Display for Cancelled {
4647
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4748
let why = match self {
49+
Cancelled::Cancelled => "canellation request",
4850
Cancelled::PendingWrite => "pending write",
4951
Cancelled::PropagatedPanic => "propagated panic",
5052
};

src/database.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::ptr::NonNull;
33

44
use crate::views::DatabaseDownCaster;
55
use crate::zalsa::{IngredientIndex, ZalsaDatabase};
6+
use crate::zalsa_local::CancellationToken;
67
use crate::{Durability, Revision};
78

89
#[derive(Copy, Clone)]
@@ -59,14 +60,19 @@ pub trait Database: Send + ZalsaDatabase + AsDynDatabase {
5960
zalsa_mut.runtime_mut().report_tracked_write(durability);
6061
}
6162

62-
/// This method triggers cancellation.
63+
/// This method cancels all outstanding computations.
6364
/// If you invoke it while a snapshot exists, it
6465
/// will block until that snapshot is dropped -- if that snapshot
6566
/// is owned by the current thread, this could trigger deadlock.
6667
fn trigger_cancellation(&mut self) {
6768
let _ = self.zalsa_mut();
6869
}
6970

71+
/// Retrives a [`CancellationToken`] for the current database.
72+
fn cancellation_token(&self) -> CancellationToken {
73+
self.zalsa_local().cancellation_token()
74+
}
75+
7076
/// Reports that the query depends on some state unknown to salsa.
7177
///
7278
/// Queries which report untracked reads will be re-executed in the next

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ pub use self::runtime::Runtime;
6767
pub use self::storage::{Storage, StorageHandle};
6868
pub use self::update::Update;
6969
pub use self::zalsa::IngredientIndex;
70+
pub use self::zalsa_local::CancellationToken;
7071
pub use crate::attach::{attach, attach_allow_change, with_attached_database};
7172

7273
pub mod prelude {

src/zalsa.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,11 @@ impl Zalsa {
301301
#[inline]
302302
pub(crate) fn unwind_if_revision_cancelled(&self, zalsa_local: &ZalsaLocal) {
303303
self.event(&|| crate::Event::new(crate::EventKind::WillCheckCancellation));
304+
if zalsa_local.take_cancellation() {
305+
zalsa_local.unwind_cancelled();
306+
}
304307
if self.runtime().load_cancellation_flag() {
305-
zalsa_local.unwind_cancelled(self.current_revision());
308+
zalsa_local.unwind_pending_write();
306309
}
307310
}
308311

src/zalsa_local.rs

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use std::fmt;
33
use std::fmt::Formatter;
44
use std::panic::UnwindSafe;
55
use std::ptr::{self, NonNull};
6+
use std::sync::atomic::Ordering;
7+
use std::sync::Arc;
68

79
use rustc_hash::FxHashMap;
810
use thin_vec::ThinVec;
@@ -39,13 +41,32 @@ pub struct ZalsaLocal {
3941
/// Stores the most recent page for a given ingredient.
4042
/// This is thread-local to avoid contention.
4143
most_recent_pages: UnsafeCell<FxHashMap<IngredientIndex, PageIndex>>,
44+
45+
cancelled: CancellationToken,
46+
}
47+
48+
/// A cancellation token that can be used to cancel a query computation for a specific local `Database`.
49+
#[derive(Default, Clone, Debug)]
50+
pub struct CancellationToken(Arc<AtomicBool>);
51+
52+
impl CancellationToken {
53+
/// Inform the database to cancel the current query computation.
54+
pub fn cancel(&self) {
55+
self.0.store(true, Ordering::Relaxed);
56+
}
57+
58+
/// Check if the query computation has been requested to be cancelled.
59+
pub fn is_cancelled(&self) -> bool {
60+
self.0.load(Ordering::Relaxed)
61+
}
4262
}
4363

4464
impl ZalsaLocal {
4565
pub(crate) fn new() -> Self {
4666
ZalsaLocal {
4767
query_stack: RefCell::new(QueryStack::default()),
4868
most_recent_pages: UnsafeCell::new(FxHashMap::default()),
69+
cancelled: CancellationToken::default(),
4970
}
5071
}
5172

@@ -401,12 +422,24 @@ impl ZalsaLocal {
401422
}
402423
}
403424

425+
pub(crate) fn cancellation_token(&self) -> CancellationToken {
426+
self.cancelled.clone()
427+
}
428+
429+
#[inline]
430+
pub fn take_cancellation(&self) -> bool {
431+
self.cancelled.0.swap(false, Ordering::Relaxed)
432+
}
433+
404434
#[cold]
405-
pub(crate) fn unwind_cancelled(&self, current_revision: Revision) {
406-
// Why is this reporting an untracked read? We do not store the query revisions on unwind do we?
407-
self.report_untracked_read(current_revision);
435+
pub(crate) fn unwind_pending_write(&self) {
408436
Cancelled::PendingWrite.throw();
409437
}
438+
439+
#[cold]
440+
pub(crate) fn unwind_cancelled(&self) {
441+
Cancelled::Cancelled.throw();
442+
}
410443
}
411444

412445
// Okay to implement as `ZalsaLocal`` is !Sync

tests/cancellation_token.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
//! Test that `DeriveWithDb` is correctly derived.
2+
3+
mod common;
4+
5+
use std::{sync::Barrier, thread};
6+
7+
use expect_test::expect;
8+
use salsa::{Cancelled, Database};
9+
10+
use crate::common::LogDatabase;
11+
12+
#[salsa::input(debug)]
13+
struct MyInput {
14+
field: u32,
15+
}
16+
17+
#[salsa::tracked]
18+
fn a(db: &dyn Database, input: MyInput) -> u32 {
19+
BARRIER.wait();
20+
BARRIER2.wait();
21+
b(db, input)
22+
}
23+
#[salsa::tracked]
24+
fn b(db: &dyn Database, input: MyInput) -> u32 {
25+
input.field(db)
26+
}
27+
28+
static BARRIER: Barrier = Barrier::new(2);
29+
static BARRIER2: Barrier = Barrier::new(2);
30+
31+
#[test]
32+
fn cancellation_token() {
33+
let db = common::EventLoggerDatabase::default();
34+
let token = db.cancellation_token();
35+
let input = MyInput::new(&db, 22);
36+
let res = Cancelled::catch(|| {
37+
thread::scope(|s| {
38+
s.spawn(|| {
39+
BARRIER.wait();
40+
token.cancel();
41+
BARRIER2.wait();
42+
});
43+
a(&db, input)
44+
})
45+
});
46+
assert_eq!(res, Err(Cancelled::Cancelled));
47+
db.assert_logs(expect![[r#"
48+
[
49+
"WillCheckCancellation",
50+
"WillExecute { database_key: a(Id(0)) }",
51+
"WillCheckCancellation",
52+
]"#]]);
53+
thread::spawn(|| {
54+
BARRIER.wait();
55+
BARRIER2.wait();
56+
});
57+
a(&db, input);
58+
db.assert_logs(expect![[r#"
59+
[
60+
"WillCheckCancellation",
61+
"WillExecute { database_key: a(Id(0)) }",
62+
"WillCheckCancellation",
63+
"WillExecute { database_key: b(Id(0)) }",
64+
]"#]]);
65+
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// Shuttle doesn't like panics inside of its runtime.
2+
#![cfg(not(feature = "shuttle"))]
3+
4+
//! Test for cancellation when another query is blocked on the cancelled thread.
5+
use salsa::{Cancelled, Database};
6+
7+
use crate::setup::{Knobs, KnobsDatabase};
8+
9+
#[salsa::tracked]
10+
fn query_a(db: &dyn KnobsDatabase) -> u32 {
11+
query_b(db)
12+
}
13+
14+
#[salsa::tracked]
15+
fn query_b(db: &dyn KnobsDatabase) -> u32 {
16+
db.signal(1);
17+
db.wait_for(3);
18+
query_c(db)
19+
}
20+
21+
#[salsa::tracked]
22+
fn query_c(_db: &dyn KnobsDatabase) -> u32 {
23+
1
24+
}
25+
26+
#[test]
27+
fn execute() {
28+
let db = Knobs::default();
29+
let db2 = db.clone();
30+
let db_signaler = db.clone();
31+
let token = db.cancellation_token();
32+
33+
let t1 = std::thread::spawn(move || query_a(&db));
34+
db_signaler.wait_for(1);
35+
db2.signal_on_will_block(2);
36+
let t2 = std::thread::spawn(move || query_a(&db2));
37+
db_signaler.wait_for(2);
38+
token.cancel();
39+
db_signaler.signal(3);
40+
let (r1, r2) = (t1.join(), t2.join());
41+
assert_eq!(
42+
*r1.unwrap_err().downcast::<salsa::Cancelled>().unwrap(),
43+
Cancelled::Cancelled
44+
);
45+
assert_eq!(r2.unwrap(), 1);
46+
}

tests/parallel/main.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
mod setup;
44
mod signal;
55

6+
mod cancellation_token_recomputes;
67
mod cycle_a_t1_b_t2;
78
mod cycle_a_t1_b_t2_fallback;
89
mod cycle_ab_peeping_c;

0 commit comments

Comments
 (0)