Skip to content

Commit f97346e

Browse files
committed
Introduce a CancellationToken for cancelling specific computations
1 parent e8ddb4d commit f97346e

File tree

15 files changed

+243
-33
lines changed

15 files changed

+243
-33
lines changed

src/cancelled.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
use std::fmt;
22
use std::panic::{self, UnwindSafe};
33

4+
use crate::CancellationToken;
5+
46
/// A panic payload indicating that execution of a salsa query was cancelled.
57
///
68
/// This can occur for a few reasons:
@@ -10,15 +12,25 @@ use std::panic::{self, UnwindSafe};
1012
#[derive(Debug)]
1113
#[non_exhaustive]
1214
pub enum Cancelled {
15+
/// The query was operating but the local database execution has been cancelled.
16+
Cancelled(UncancelGuard),
17+
1318
/// The query was operating on revision R, but there is a pending write to move to revision R+1.
14-
#[non_exhaustive]
1519
PendingWrite,
1620

1721
/// The query was blocked on another thread, and that thread panicked.
18-
#[non_exhaustive]
1922
PropagatedPanic,
2023
}
2124

25+
#[derive(Debug)]
26+
pub struct UncancelGuard(pub(crate) CancellationToken);
27+
28+
impl Drop for UncancelGuard {
29+
fn drop(&mut self) {
30+
self.0.uncancel();
31+
}
32+
}
33+
2234
impl Cancelled {
2335
#[cold]
2436
pub(crate) fn throw(self) -> ! {
@@ -45,6 +57,7 @@ impl Cancelled {
4557
impl std::fmt::Display for Cancelled {
4658
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4759
let why = match self {
60+
Cancelled::Cancelled(_) => "canellation request",
4861
Cancelled::PendingWrite => "pending write",
4962
Cancelled::PropagatedPanic => "propagated panic",
5063
};

src/database.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::ptr::NonNull;
33

44
use crate::views::DatabaseDownCaster;
55
use crate::zalsa::{IngredientIndex, ZalsaDatabase};
6+
use crate::zalsa_local::CancellationToken;
67
use crate::{Durability, Revision};
78

89
#[derive(Copy, Clone)]
@@ -59,14 +60,19 @@ pub trait Database: Send + ZalsaDatabase + AsDynDatabase {
5960
zalsa_mut.runtime_mut().report_tracked_write(durability);
6061
}
6162

62-
/// This method triggers cancellation.
63+
/// This method cancels all outstanding computations.
6364
/// If you invoke it while a snapshot exists, it
6465
/// will block until that snapshot is dropped -- if that snapshot
6566
/// is owned by the current thread, this could trigger deadlock.
6667
fn trigger_cancellation(&mut self) {
6768
let _ = self.zalsa_mut();
6869
}
6970

71+
/// Retrives a [`CancellationToken`] for the current database.
72+
fn cancellation_token(&self) -> CancellationToken {
73+
self.zalsa_local().cancellation_token()
74+
}
75+
7076
/// Reports that the query depends on some state unknown to salsa.
7177
///
7278
/// Queries which report untracked reads will be re-executed in the next

src/function/fetch.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,10 @@ where
105105
) -> Option<&'db Memo<'db, C>> {
106106
let database_key_index = self.database_key_index(id);
107107
// Try to claim this query: if someone else has claimed it already, go back and start again.
108-
let claim_guard = match self.sync_table.try_claim(zalsa, id, Reentrancy::Allow) {
108+
let claim_guard = match self
109+
.sync_table
110+
.try_claim(zalsa, zalsa_local, id, Reentrancy::Allow)
111+
{
109112
ClaimResult::Claimed(guard) => guard,
110113
ClaimResult::Running(blocked_on) => {
111114
blocked_on.block_on(zalsa);

src/function/maybe_changed_after.rs

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -141,23 +141,24 @@ where
141141
) -> Option<VerifyResult> {
142142
let database_key_index = self.database_key_index(key_index);
143143

144-
let claim_guard = match self
145-
.sync_table
146-
.try_claim(zalsa, key_index, Reentrancy::Deny)
147-
{
148-
ClaimResult::Claimed(guard) => guard,
149-
ClaimResult::Running(blocked_on) => {
150-
blocked_on.block_on(zalsa);
151-
return None;
152-
}
153-
ClaimResult::Cycle { .. } => {
154-
return Some(self.maybe_changed_after_cold_cycle(
155-
zalsa_local,
156-
database_key_index,
157-
cycle_heads,
158-
))
159-
}
160-
};
144+
let claim_guard =
145+
match self
146+
.sync_table
147+
.try_claim(zalsa, zalsa_local, key_index, Reentrancy::Deny)
148+
{
149+
ClaimResult::Claimed(guard) => guard,
150+
ClaimResult::Running(blocked_on) => {
151+
blocked_on.block_on(zalsa);
152+
return None;
153+
}
154+
ClaimResult::Cycle { .. } => {
155+
return Some(self.maybe_changed_after_cold_cycle(
156+
zalsa_local,
157+
database_key_index,
158+
cycle_heads,
159+
))
160+
}
161+
};
161162
// Load the current memo, if any.
162163
let Some(old_memo) = self.get_memo_from_table_for(zalsa, key_index, memo_ingredient_index)
163164
else {

src/function/sync.rs

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use rustc_hash::FxHashMap;
22
use std::collections::hash_map::OccupiedEntry;
33

44
use crate::key::DatabaseKeyIndex;
5+
use crate::plumbing::ZalsaLocal;
56
use crate::runtime::{
67
BlockOnTransferredOwner, BlockResult, BlockTransferredResult, Running, WaitResult,
78
};
@@ -68,6 +69,7 @@ impl SyncTable {
6869
pub(crate) fn try_claim<'me>(
6970
&'me self,
7071
zalsa: &'me Zalsa,
72+
zalsa_local: &'me ZalsaLocal,
7173
key_index: Id,
7274
reentrant: Reentrancy,
7375
) -> ClaimResult<'me> {
@@ -77,7 +79,12 @@ impl SyncTable {
7779
let id = match occupied_entry.get().id {
7880
SyncOwner::Thread(id) => id,
7981
SyncOwner::Transferred => {
80-
return match self.try_claim_transferred(zalsa, occupied_entry, reentrant) {
82+
return match self.try_claim_transferred(
83+
zalsa,
84+
zalsa_local,
85+
occupied_entry,
86+
reentrant,
87+
) {
8188
Ok(claimed) => claimed,
8289
Err(other_thread) => match other_thread.block(write) {
8390
BlockResult::Cycle => ClaimResult::Cycle { inner: false },
@@ -115,6 +122,7 @@ impl SyncTable {
115122
ClaimResult::Claimed(ClaimGuard {
116123
key_index,
117124
zalsa,
125+
zalsa_local,
118126
sync_table: self,
119127
mode: ReleaseMode::Default,
120128
})
@@ -172,6 +180,7 @@ impl SyncTable {
172180
fn try_claim_transferred<'me>(
173181
&'me self,
174182
zalsa: &'me Zalsa,
183+
zalsa_local: &'me ZalsaLocal,
175184
mut entry: OccupiedEntry<Id, SyncState>,
176185
reentrant: Reentrancy,
177186
) -> Result<ClaimResult<'me>, Box<BlockOnTransferredOwner<'me>>> {
@@ -195,6 +204,7 @@ impl SyncTable {
195204
Ok(ClaimResult::Claimed(ClaimGuard {
196205
key_index,
197206
zalsa,
207+
zalsa_local,
198208
sync_table: self,
199209
mode: ReleaseMode::SelfOnly,
200210
}))
@@ -214,6 +224,7 @@ impl SyncTable {
214224
Ok(ClaimResult::Claimed(ClaimGuard {
215225
key_index,
216226
zalsa,
227+
zalsa_local,
217228
sync_table: self,
218229
mode: ReleaseMode::Default,
219230
}))
@@ -295,6 +306,7 @@ pub(crate) struct ClaimGuard<'me> {
295306
zalsa: &'me Zalsa,
296307
sync_table: &'me SyncTable,
297308
mode: ReleaseMode,
309+
zalsa_local: &'me ZalsaLocal,
298310
}
299311

300312
impl<'me> ClaimGuard<'me> {
@@ -319,10 +331,21 @@ impl<'me> ClaimGuard<'me> {
319331
"Release claim on {:?} due to panic",
320332
self.database_key_index()
321333
);
322-
323334
self.release(state, WaitResult::Panicked);
324335
}
325336

337+
#[cold]
338+
#[inline(never)]
339+
fn release_cancelled(&self) {
340+
let mut syncs = self.sync_table.syncs.lock();
341+
let state = syncs.remove(&self.key_index).expect("key claimed twice?");
342+
tracing::debug!(
343+
"Release claim on {:?} due to cancellation",
344+
self.database_key_index()
345+
);
346+
self.release(state, WaitResult::Canceled);
347+
}
348+
326349
#[inline(always)]
327350
fn release(&self, state: SyncState, wait_result: WaitResult) {
328351
let SyncState {
@@ -446,7 +469,11 @@ impl<'me> ClaimGuard<'me> {
446469
impl Drop for ClaimGuard<'_> {
447470
fn drop(&mut self) {
448471
if thread::panicking() {
449-
self.release_panicking();
472+
if self.zalsa_local.is_cancelled() {
473+
self.release_cancelled();
474+
} else {
475+
self.release_panicking();
476+
}
450477
return;
451478
}
452479

src/interned.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,7 @@ pub struct StructEntry<'db, C>
850850
where
851851
C: Configuration,
852852
{
853+
#[allow(dead_code)]
853854
value: &'db Value<C>,
854855
key: DatabaseKeyIndex,
855856
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ pub use self::runtime::Runtime;
6262
pub use self::storage::{Storage, StorageHandle};
6363
pub use self::update::Update;
6464
pub use self::zalsa::IngredientIndex;
65+
pub use self::zalsa_local::CancellationToken;
6566
pub use crate::attach::{attach, attach_allow_change, with_attached_database};
6667

6768
pub mod prelude {

src/runtime.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ pub struct Runtime {
4444
pub(super) enum WaitResult {
4545
Completed,
4646
Panicked,
47+
Canceled,
4748
}
4849

4950
#[derive(Debug)]
@@ -121,7 +122,7 @@ struct BlockedOnInner<'me> {
121122

122123
impl Running<'_> {
123124
/// Blocks on the other thread to complete the computation.
124-
pub(crate) fn block_on(self, zalsa: &Zalsa) {
125+
pub(crate) fn block_on(self, zalsa: &Zalsa) -> bool {
125126
let BlockedOnInner {
126127
dg,
127128
query_mutex_guard,
@@ -151,7 +152,8 @@ impl Running<'_> {
151152
// by the other thread and responded to appropriately.
152153
Cancelled::PropagatedPanic.throw()
153154
}
154-
WaitResult::Completed => {}
155+
WaitResult::Canceled => true,
156+
WaitResult::Completed => false,
155157
}
156158
}
157159
}

src/storage.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,15 @@ impl<Db: Database> Storage<Db> {
124124
.record_unfilled_pages(self.handle.zalsa_impl.table());
125125
let Self {
126126
handle,
127-
zalsa_local: _,
128-
} = &self;
127+
zalsa_local,
128+
} = &mut self;
129129
// Avoid rust's annoying destructure prevention rules for `Drop` types
130130
// SAFETY: We forget `Self` afterwards to discard the original copy, and the destructure
131131
// above makes sure we won't forget to take into account newly added fields.
132132
let handle = unsafe { std::ptr::read(handle) };
133+
// SAFETY: We forget `Self` afterwards to discard the original copy, and the destructure
134+
// above makes sure we won't forget to take into account newly added fields.
135+
unsafe { std::ptr::drop_in_place(zalsa_local) };
133136
std::mem::forget::<Self>(self);
134137
handle
135138
}

src/zalsa.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,11 @@ impl Zalsa {
301301
#[inline]
302302
pub(crate) fn unwind_if_revision_cancelled(&self, zalsa_local: &ZalsaLocal) {
303303
self.event(&|| crate::Event::new(crate::EventKind::WillCheckCancellation));
304+
if zalsa_local.is_cancelled() {
305+
zalsa_local.unwind_cancelled();
306+
}
304307
if self.runtime().load_cancellation_flag() {
305-
zalsa_local.unwind_cancelled(self.current_revision());
308+
zalsa_local.unwind_pending_write();
306309
}
307310
}
308311

0 commit comments

Comments
 (0)