Skip to content

Commit d2edc9b

Browse files
committed
Introduce a CancellationToken for cancelling specific computations
1 parent e8ddb4d commit d2edc9b

17 files changed

+310
-66
lines changed

src/attach.rs

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,10 @@ impl Attached {
7070
fn drop(&mut self) {
7171
// Reset database to null if we did anything in `DbGuard::new`.
7272
if let Some(attached) = self.state {
73-
attached.database.set(None);
73+
if let Some(prev) = attached.database.replace(None) {
74+
// SAFETY: `prev` is a valid pointer to a database.
75+
unsafe { prev.as_ref().zalsa_local().uncancel() };
76+
}
7477
}
7578
}
7679
}
@@ -85,25 +88,50 @@ impl Attached {
8588
Db: ?Sized + Database,
8689
{
8790
struct DbGuard<'s> {
88-
state: &'s Attached,
91+
state: Option<&'s Attached>,
8992
prev: Option<NonNull<dyn Database>>,
9093
}
9194

9295
impl<'s> DbGuard<'s> {
9396
#[inline]
9497
fn new(attached: &'s Attached, db: &dyn Database) -> Self {
95-
let prev = attached.database.replace(Some(NonNull::from(db)));
96-
Self {
97-
state: attached,
98-
prev,
98+
let db = NonNull::from(db);
99+
match attached.database.replace(Some(db)) {
100+
Some(prev) => {
101+
if std::ptr::eq(db.as_ptr(), prev.as_ptr()) {
102+
Self {
103+
state: None,
104+
prev: None,
105+
}
106+
} else {
107+
Self {
108+
state: Some(attached),
109+
prev: Some(prev),
110+
}
111+
}
112+
}
113+
None => {
114+
// Otherwise, set the database.
115+
attached.database.set(Some(db));
116+
Self {
117+
state: Some(attached),
118+
prev: None,
119+
}
120+
}
99121
}
100122
}
101123
}
102124

103125
impl Drop for DbGuard<'_> {
104126
#[inline]
105127
fn drop(&mut self) {
106-
self.state.database.set(self.prev);
128+
// Reset database to null if we did anything in `DbGuard::new`.
129+
if let Some(attached) = self.state {
130+
if let Some(prev) = attached.database.replace(self.prev) {
131+
// SAFETY: `prev` is a valid pointer to a database.
132+
unsafe { prev.as_ref().zalsa_local().uncancel() };
133+
}
134+
}
107135
}
108136
}
109137

src/cancelled.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@ use std::panic::{self, UnwindSafe};
1010
#[derive(Debug)]
1111
#[non_exhaustive]
1212
pub enum Cancelled {
13+
/// The query was operating but the local database execution has been cancelled.
14+
Cancelled,
15+
1316
/// The query was operating on revision R, but there is a pending write to move to revision R+1.
14-
#[non_exhaustive]
1517
PendingWrite,
1618

1719
/// The query was blocked on another thread, and that thread panicked.
18-
#[non_exhaustive]
1920
PropagatedPanic,
2021
}
2122

@@ -45,6 +46,7 @@ impl Cancelled {
4546
impl std::fmt::Display for Cancelled {
4647
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
4748
let why = match self {
49+
Cancelled::Cancelled => "canellation request",
4850
Cancelled::PendingWrite => "pending write",
4951
Cancelled::PropagatedPanic => "propagated panic",
5052
};

src/database.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use std::ptr::NonNull;
33

44
use crate::views::DatabaseDownCaster;
55
use crate::zalsa::{IngredientIndex, ZalsaDatabase};
6+
use crate::zalsa_local::CancellationToken;
67
use crate::{Durability, Revision};
78

89
#[derive(Copy, Clone)]
@@ -59,14 +60,19 @@ pub trait Database: Send + ZalsaDatabase + AsDynDatabase {
5960
zalsa_mut.runtime_mut().report_tracked_write(durability);
6061
}
6162

62-
/// This method triggers cancellation.
63+
/// This method cancels all outstanding computations.
6364
/// If you invoke it while a snapshot exists, it
6465
/// will block until that snapshot is dropped -- if that snapshot
6566
/// is owned by the current thread, this could trigger deadlock.
6667
fn trigger_cancellation(&mut self) {
6768
let _ = self.zalsa_mut();
6869
}
6970

71+
/// Retrives a [`CancellationToken`] for the current database.
72+
fn cancellation_token(&self) -> CancellationToken {
73+
self.zalsa_local().cancellation_token()
74+
}
75+
7076
/// Reports that the query depends on some state unknown to salsa.
7177
///
7278
/// Queries which report untracked reads will be re-executed in the next

src/function/fetch.rs

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -105,32 +105,39 @@ where
105105
) -> Option<&'db Memo<'db, C>> {
106106
let database_key_index = self.database_key_index(id);
107107
// Try to claim this query: if someone else has claimed it already, go back and start again.
108-
let claim_guard = match self.sync_table.try_claim(zalsa, id, Reentrancy::Allow) {
109-
ClaimResult::Claimed(guard) => guard,
110-
ClaimResult::Running(blocked_on) => {
111-
blocked_on.block_on(zalsa);
108+
let claim_guard = loop {
109+
match self
110+
.sync_table
111+
.try_claim(zalsa, zalsa_local, id, Reentrancy::Allow)
112+
{
113+
ClaimResult::Claimed(guard) => break guard,
114+
ClaimResult::Running(blocked_on) => {
115+
if !blocked_on.block_on(zalsa) {
116+
continue;
117+
}
112118

113-
if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate {
114-
let memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index);
119+
if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate {
120+
let memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index);
115121

116-
if let Some(memo) = memo {
117-
if memo.value.is_some() {
118-
memo.block_on_heads(zalsa);
122+
if let Some(memo) = memo {
123+
if memo.value.is_some() {
124+
memo.block_on_heads(zalsa);
125+
}
119126
}
120127
}
121-
}
122128

123-
return None;
124-
}
125-
ClaimResult::Cycle { .. } => {
126-
return Some(self.fetch_cold_cycle(
127-
zalsa,
128-
zalsa_local,
129-
db,
130-
id,
131-
database_key_index,
132-
memo_ingredient_index,
133-
));
129+
return None;
130+
}
131+
ClaimResult::Cycle { .. } => {
132+
return Some(self.fetch_cold_cycle(
133+
zalsa,
134+
zalsa_local,
135+
db,
136+
id,
137+
database_key_index,
138+
memo_ingredient_index,
139+
));
140+
}
134141
}
135142
};
136143

src/function/maybe_changed_after.rs

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -141,21 +141,24 @@ where
141141
) -> Option<VerifyResult> {
142142
let database_key_index = self.database_key_index(key_index);
143143

144-
let claim_guard = match self
145-
.sync_table
146-
.try_claim(zalsa, key_index, Reentrancy::Deny)
147-
{
148-
ClaimResult::Claimed(guard) => guard,
149-
ClaimResult::Running(blocked_on) => {
150-
blocked_on.block_on(zalsa);
151-
return None;
152-
}
153-
ClaimResult::Cycle { .. } => {
154-
return Some(self.maybe_changed_after_cold_cycle(
155-
zalsa_local,
156-
database_key_index,
157-
cycle_heads,
158-
))
144+
let claim_guard = loop {
145+
match self
146+
.sync_table
147+
.try_claim(zalsa, zalsa_local, key_index, Reentrancy::Deny)
148+
{
149+
ClaimResult::Claimed(guard) => break guard,
150+
ClaimResult::Running(blocked_on) => {
151+
if blocked_on.block_on(zalsa) {
152+
return None;
153+
}
154+
}
155+
ClaimResult::Cycle { .. } => {
156+
return Some(self.maybe_changed_after_cold_cycle(
157+
zalsa_local,
158+
database_key_index,
159+
cycle_heads,
160+
))
161+
}
159162
}
160163
};
161164
// Load the current memo, if any.

src/function/memo.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,10 @@ impl<'db, C: Configuration> Memo<'db, C> {
180180
}
181181
TryClaimHeadsResult::Running(running) => {
182182
all_cycles = false;
183-
running.block_on(zalsa);
183+
if !running.block_on(zalsa) {
184+
// FIXME: Handle cancellation properly?
185+
crate::Cancelled::PropagatedPanic.throw();
186+
}
184187
}
185188
}
186189
}

src/function/sync.rs

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use rustc_hash::FxHashMap;
22
use std::collections::hash_map::OccupiedEntry;
33

44
use crate::key::DatabaseKeyIndex;
5+
use crate::plumbing::ZalsaLocal;
56
use crate::runtime::{
67
BlockOnTransferredOwner, BlockResult, BlockTransferredResult, Running, WaitResult,
78
};
@@ -68,6 +69,7 @@ impl SyncTable {
6869
pub(crate) fn try_claim<'me>(
6970
&'me self,
7071
zalsa: &'me Zalsa,
72+
zalsa_local: &'me ZalsaLocal,
7173
key_index: Id,
7274
reentrant: Reentrancy,
7375
) -> ClaimResult<'me> {
@@ -77,7 +79,12 @@ impl SyncTable {
7779
let id = match occupied_entry.get().id {
7880
SyncOwner::Thread(id) => id,
7981
SyncOwner::Transferred => {
80-
return match self.try_claim_transferred(zalsa, occupied_entry, reentrant) {
82+
return match self.try_claim_transferred(
83+
zalsa,
84+
zalsa_local,
85+
occupied_entry,
86+
reentrant,
87+
) {
8188
Ok(claimed) => claimed,
8289
Err(other_thread) => match other_thread.block(write) {
8390
BlockResult::Cycle => ClaimResult::Cycle { inner: false },
@@ -115,6 +122,7 @@ impl SyncTable {
115122
ClaimResult::Claimed(ClaimGuard {
116123
key_index,
117124
zalsa,
125+
zalsa_local,
118126
sync_table: self,
119127
mode: ReleaseMode::Default,
120128
})
@@ -172,6 +180,7 @@ impl SyncTable {
172180
fn try_claim_transferred<'me>(
173181
&'me self,
174182
zalsa: &'me Zalsa,
183+
zalsa_local: &'me ZalsaLocal,
175184
mut entry: OccupiedEntry<Id, SyncState>,
176185
reentrant: Reentrancy,
177186
) -> Result<ClaimResult<'me>, Box<BlockOnTransferredOwner<'me>>> {
@@ -195,6 +204,7 @@ impl SyncTable {
195204
Ok(ClaimResult::Claimed(ClaimGuard {
196205
key_index,
197206
zalsa,
207+
zalsa_local,
198208
sync_table: self,
199209
mode: ReleaseMode::SelfOnly,
200210
}))
@@ -214,6 +224,7 @@ impl SyncTable {
214224
Ok(ClaimResult::Claimed(ClaimGuard {
215225
key_index,
216226
zalsa,
227+
zalsa_local,
217228
sync_table: self,
218229
mode: ReleaseMode::Default,
219230
}))
@@ -295,6 +306,7 @@ pub(crate) struct ClaimGuard<'me> {
295306
zalsa: &'me Zalsa,
296307
sync_table: &'me SyncTable,
297308
mode: ReleaseMode,
309+
zalsa_local: &'me ZalsaLocal,
298310
}
299311

300312
impl<'me> ClaimGuard<'me> {
@@ -319,10 +331,21 @@ impl<'me> ClaimGuard<'me> {
319331
"Release claim on {:?} due to panic",
320332
self.database_key_index()
321333
);
322-
323334
self.release(state, WaitResult::Panicked);
324335
}
325336

337+
#[cold]
338+
#[inline(never)]
339+
fn release_cancelled(&self) {
340+
let mut syncs = self.sync_table.syncs.lock();
341+
let state = syncs.remove(&self.key_index).expect("key claimed twice?");
342+
tracing::debug!(
343+
"Release claim on {:?} due to cancellation",
344+
self.database_key_index()
345+
);
346+
self.release(state, WaitResult::Cancelled);
347+
}
348+
326349
#[inline(always)]
327350
fn release(&self, state: SyncState, wait_result: WaitResult) {
328351
let SyncState {
@@ -446,7 +469,11 @@ impl<'me> ClaimGuard<'me> {
446469
impl Drop for ClaimGuard<'_> {
447470
fn drop(&mut self) {
448471
if thread::panicking() {
449-
self.release_panicking();
472+
if self.zalsa_local.is_cancelled() {
473+
self.release_cancelled();
474+
} else {
475+
self.release_panicking();
476+
}
450477
return;
451478
}
452479

src/interned.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,7 @@ pub struct StructEntry<'db, C>
850850
where
851851
C: Configuration,
852852
{
853+
#[allow(dead_code)]
853854
value: &'db Value<C>,
854855
key: DatabaseKeyIndex,
855856
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ pub use self::runtime::Runtime;
6262
pub use self::storage::{Storage, StorageHandle};
6363
pub use self::update::Update;
6464
pub use self::zalsa::IngredientIndex;
65+
pub use self::zalsa_local::CancellationToken;
6566
pub use crate::attach::{attach, attach_allow_change, with_attached_database};
6667

6768
pub mod prelude {

0 commit comments

Comments
 (0)