From a5e17044c26727c19f405c4083e77f45a6e995c3 Mon Sep 17 00:00:00 2001
From: Lukas Wirth
Date: Fri, 17 Oct 2025 19:01:29 +0200
Subject: [PATCH 1/2] Introduce a `CancellationToken` for cancelling specific
 computations

---
 src/attach.rs                                 | 42 ++++++++++--
 src/cancelled.rs                              |  6 +-
 src/database.rs                               |  8 ++-
 src/function/fetch.rs                         | 49 ++++++++------
 src/function/maybe_changed_after.rs           | 33 ++++-----
 src/function/memo.rs                          |  5 +-
 src/function/sync.rs                          | 37 ++++++++--
 src/interned.rs                               |  1 +
 src/lib.rs                                    |  1 +
 src/runtime.rs                                | 24 ++++---
 src/storage.rs                                |  7 +-
 src/zalsa.rs                                  |  5 +-
 src/zalsa_local.rs                            | 49 +++++++++++++-
 tests/cancellation_token.rs                   | 67 +++++++++++++++++++
 tests/interned-revisions.rs                   |  2 +-
 .../parallel/cancellation_token_recomputes.rs | 43 ++++++++++++
 tests/parallel/main.rs                        |  1 +
 17 files changed, 312 insertions(+), 68 deletions(-)
 create mode 100644 tests/cancellation_token.rs
 create mode 100644 tests/parallel/cancellation_token_recomputes.rs

diff --git a/src/attach.rs b/src/attach.rs
index 973da8959..3a6edc93f 100644
--- a/src/attach.rs
+++ b/src/attach.rs
@@ -70,7 +70,10 @@ impl Attached {
     fn drop(&mut self) {
         // Reset database to null if we did anything in `DbGuard::new`.
         if let Some(attached) = self.state {
-            attached.database.set(None);
+            if let Some(prev) = attached.database.replace(None) {
+                // SAFETY: `prev` is a valid pointer to a database.
+                unsafe { prev.as_ref().zalsa_local().uncancel() };
+            }
         }
     }
 }
@@ -85,17 +88,36 @@ impl Attached {
         Db: ?Sized + Database,
     {
         struct DbGuard<'s> {
-            state: &'s Attached,
+            state: Option<&'s Attached>,
             prev: Option<NonNull<dyn Database>>,
         }

         impl<'s> DbGuard<'s> {
             #[inline]
             fn new(attached: &'s Attached, db: &dyn Database) -> Self {
-                let prev = attached.database.replace(Some(NonNull::from(db)));
-                Self {
-                    state: attached,
-                    prev,
-                }
+                let db = NonNull::from(db);
+                match attached.database.replace(Some(db)) {
+                    Some(prev) => {
+                        if std::ptr::eq(db.as_ptr(), prev.as_ptr()) {
+                            Self {
+                                state: None,
+                                prev: None,
+                            }
+                        } else {
+                            Self {
+                                state: Some(attached),
+                                prev: Some(prev),
+                            }
+                        }
+                    }
+                    None => {
+                        // Otherwise, set the database.
+                        attached.database.set(Some(db));
+                        Self {
+                            state: Some(attached),
+                            prev: None,
+                        }
+                    }
+                }
             }
         }
@@ -103,7 +125,13 @@ impl Attached {
         impl Drop for DbGuard<'_> {
             #[inline]
             fn drop(&mut self) {
-                self.state.database.set(self.prev);
+                // Reset database to null if we did anything in `DbGuard::new`.
+                if let Some(attached) = self.state {
+                    if let Some(prev) = attached.database.replace(self.prev) {
+                        // SAFETY: `prev` is a valid pointer to a database.
+                        unsafe { prev.as_ref().zalsa_local().uncancel() };
+                    }
+                }
             }
         }
diff --git a/src/cancelled.rs b/src/cancelled.rs
index 3c31bae5a..1fa0edc59 100644
--- a/src/cancelled.rs
+++ b/src/cancelled.rs
@@ -10,12 +10,13 @@ use std::panic::{self, UnwindSafe};
 #[derive(Debug)]
 #[non_exhaustive]
 pub enum Cancelled {
+    /// The query was executing, but execution on the local database handle has been cancelled.
+    Cancelled,
+
     /// The query was operating on revision R, but there is a pending write to move to revision R+1.
-    #[non_exhaustive]
     PendingWrite,

     /// The query was blocked on another thread, and that thread panicked.
-    #[non_exhaustive]
     PropagatedPanic,
 }
@@ -45,6 +46,7 @@ impl Cancelled {
 impl std::fmt::Display for Cancelled {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         let why = match self {
+            Cancelled::Cancelled => "cancellation request",
             Cancelled::PendingWrite => "pending write",
             Cancelled::PropagatedPanic => "propagated panic",
         };
diff --git a/src/database.rs b/src/database.rs
index 0df83b03b..9cb70f917 100644
--- a/src/database.rs
+++ b/src/database.rs
@@ -3,6 +3,7 @@ use std::ptr::NonNull;

 use crate::views::DatabaseDownCaster;
 use crate::zalsa::{IngredientIndex, ZalsaDatabase};
+use crate::zalsa_local::CancellationToken;
 use crate::{Durability, Revision};

 #[derive(Copy, Clone)]
@@ -59,7 +60,7 @@ pub trait Database: Send + ZalsaDatabase + AsDynDatabase {
         zalsa_mut.runtime_mut().report_tracked_write(durability);
     }

-    /// This method triggers cancellation.
+    /// This method cancels all outstanding computations.
     /// If you invoke it while a snapshot exists, it
     /// will block until that snapshot is dropped -- if that snapshot
     /// is owned by the current thread, this could trigger deadlock.
@@ -67,6 +68,11 @@ pub trait Database: Send + ZalsaDatabase + AsDynDatabase {
         let _ = self.zalsa_mut();
     }

+    /// Retrives a [`CancellationToken`] for the current database.
+    fn cancellation_token(&self) -> CancellationToken {
+        self.zalsa_local().cancellation_token()
+    }
+
     /// Reports that the query depends on some state unknown to salsa.
     ///
     /// Queries which report untracked reads will be re-executed in the next
diff --git a/src/function/fetch.rs b/src/function/fetch.rs
index f1c58eda1..9adc01a07 100644
--- a/src/function/fetch.rs
+++ b/src/function/fetch.rs
@@ -105,32 +105,39 @@ where
     ) -> Option<&'db Memo<'db, C>> {
         let database_key_index = self.database_key_index(id);
         // Try to claim this query: if someone else has claimed it already, go back and start again.
-        let claim_guard = match self.sync_table.try_claim(zalsa, id, Reentrancy::Allow) {
-            ClaimResult::Claimed(guard) => guard,
-            ClaimResult::Running(blocked_on) => {
-                blocked_on.block_on(zalsa);
-
-                if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate {
-                    let memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index);
-
-                    if let Some(memo) = memo {
-                        if memo.value.is_some() {
-                            memo.block_on_heads(zalsa);
-                        }
-                    }
-                }
-
-                return None;
-            }
-            ClaimResult::Cycle { .. } => {
-                return Some(self.fetch_cold_cycle(
-                    zalsa,
-                    zalsa_local,
-                    db,
-                    id,
-                    database_key_index,
-                    memo_ingredient_index,
-                ));
-            }
-        };
+        let claim_guard = loop {
+            match self
+                .sync_table
+                .try_claim(zalsa, zalsa_local, id, Reentrancy::Allow)
+            {
+                ClaimResult::Claimed(guard) => break guard,
+                ClaimResult::Running(blocked_on) => {
+                    if !blocked_on.block_on(zalsa) {
+                        continue;
+                    }
+
+                    if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate {
+                        let memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index);
+
+                        if let Some(memo) = memo {
+                            if memo.value.is_some() {
+                                memo.block_on_heads(zalsa);
+                            }
+                        }
+                    }
+
+                    return None;
+                }
+                ClaimResult::Cycle { .. } => {
+                    return Some(self.fetch_cold_cycle(
+                        zalsa,
+                        zalsa_local,
+                        db,
+                        id,
+                        database_key_index,
+                        memo_ingredient_index,
+                    ));
+                }
+            }
+        };
diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs
index 4198631b9..a92ea1876 100644
--- a/src/function/maybe_changed_after.rs
+++ b/src/function/maybe_changed_after.rs
@@ -141,21 +141,24 @@ where
     ) -> Option<VerifyResult> {
         let database_key_index = self.database_key_index(key_index);

-        let claim_guard = match self
-            .sync_table
-            .try_claim(zalsa, key_index, Reentrancy::Deny)
-        {
-            ClaimResult::Claimed(guard) => guard,
-            ClaimResult::Running(blocked_on) => {
-                blocked_on.block_on(zalsa);
-                return None;
-            }
-            ClaimResult::Cycle { .. } => {
-                return Some(self.maybe_changed_after_cold_cycle(
-                    zalsa_local,
-                    database_key_index,
-                    cycle_heads,
-                ))
-            }
+        let claim_guard = loop {
+            match self
+                .sync_table
+                .try_claim(zalsa, zalsa_local, key_index, Reentrancy::Deny)
+            {
+                ClaimResult::Claimed(guard) => break guard,
+                ClaimResult::Running(blocked_on) => {
+                    if blocked_on.block_on(zalsa) {
+                        return None;
+                    }
+                }
+                ClaimResult::Cycle { .. } => {
+                    return Some(self.maybe_changed_after_cold_cycle(
+                        zalsa_local,
+                        database_key_index,
+                        cycle_heads,
+                    ))
+                }
+            }
         };
         // Load the current memo, if any.
diff --git a/src/function/memo.rs b/src/function/memo.rs
index fd830ced3..4c60c44fe 100644
--- a/src/function/memo.rs
+++ b/src/function/memo.rs
@@ -180,7 +180,10 @@ impl<'db, C: Configuration> Memo<'db, C> {
                 }
                 TryClaimHeadsResult::Running(running) => {
                     all_cycles = false;
-                    running.block_on(zalsa);
+                    if !running.block_on(zalsa) {
+                        // FIXME: Handle cancellation properly?
+                        crate::Cancelled::PropagatedPanic.throw();
+                    }
                 }
             }
         }
diff --git a/src/function/sync.rs b/src/function/sync.rs
index c9a74a307..edc881ffa 100644
--- a/src/function/sync.rs
+++ b/src/function/sync.rs
@@ -2,6 +2,7 @@ use rustc_hash::FxHashMap;
 use std::collections::hash_map::OccupiedEntry;

 use crate::key::DatabaseKeyIndex;
+use crate::plumbing::ZalsaLocal;
 use crate::runtime::{
     BlockOnTransferredOwner, BlockResult, BlockTransferredResult, Running, WaitResult,
 };
@@ -21,6 +22,8 @@ pub(crate) struct SyncTable {
 }

 pub(crate) enum ClaimResult<'a, Guard = ClaimGuard<'a>> {
+    /// Successfully claimed the query.
+    Claimed(Guard),
     /// Can't claim the query because it is running on another thread.
     Running(Running<'a>),
     /// Claiming the query results in a cycle.
@@ -30,8 +33,6 @@ pub(crate) enum ClaimResult<'a, Guard = ClaimGuard<'a>> {
         /// [`SyncTable::try_claim`] with [`Reentrancy::Allow`].
         inner: bool,
     },
-    /// Successfully claimed the query.
-    Claimed(Guard),
 }

 pub(crate) struct SyncState {
@@ -68,6 +69,7 @@ impl SyncTable {
     pub(crate) fn try_claim<'me>(
         &'me self,
         zalsa: &'me Zalsa,
+        zalsa_local: &'me ZalsaLocal,
         key_index: Id,
         reentrant: Reentrancy,
     ) -> ClaimResult<'me> {
@@ -77,7 +79,12 @@ impl SyncTable {
         let id = match occupied_entry.get().id {
             SyncOwner::Thread(id) => id,
             SyncOwner::Transferred => {
-                return match self.try_claim_transferred(zalsa, occupied_entry, reentrant) {
+                return match self.try_claim_transferred(
+                    zalsa,
+                    zalsa_local,
+                    occupied_entry,
+                    reentrant,
+                ) {
                     Ok(claimed) => claimed,
                     Err(other_thread) => match other_thread.block(write) {
                         BlockResult::Cycle => ClaimResult::Cycle { inner: false },
@@ -115,6 +122,7 @@ impl SyncTable {
         ClaimResult::Claimed(ClaimGuard {
             key_index,
             zalsa,
+            zalsa_local,
             sync_table: self,
             mode: ReleaseMode::Default,
         })
@@ -172,6 +180,7 @@ impl SyncTable {
     fn try_claim_transferred<'me>(
         &'me self,
         zalsa: &'me Zalsa,
+        zalsa_local: &'me ZalsaLocal,
        mut entry: OccupiedEntry<'me, Id, SyncState>,
         reentrant: Reentrancy,
     ) -> Result<ClaimResult<'me>, Box<BlockOnTransferredOwner<'me>>> {
@@ -195,6 +204,7 @@ impl SyncTable {
             Ok(ClaimResult::Claimed(ClaimGuard {
                 key_index,
                 zalsa,
+                zalsa_local,
                 sync_table: self,
                 mode: ReleaseMode::SelfOnly,
             }))
@@ -214,6 +224,7 @@ impl SyncTable {
             Ok(ClaimResult::Claimed(ClaimGuard {
                 key_index,
                 zalsa,
+                zalsa_local,
                 sync_table: self,
                 mode: ReleaseMode::Default,
             }))
@@ -295,6 +306,7 @@ pub(crate) struct ClaimGuard<'me> {
     zalsa: &'me Zalsa,
     sync_table: &'me SyncTable,
     mode: ReleaseMode,
+    zalsa_local: &'me ZalsaLocal,
 }
@@ -319,10 +331,21 @@ impl<'me> ClaimGuard<'me> {
             "Release claim on {:?} due to panic",
             self.database_key_index()
         );
-
         self.release(state, WaitResult::Panicked);
     }

+    #[cold]
+    #[inline(never)]
+    fn release_cancelled(&self) {
+        let mut syncs = self.sync_table.syncs.lock();
+        let state = syncs.remove(&self.key_index).expect("key claimed twice?");
+        tracing::debug!(
+            "Release claim on {:?} due to cancellation",
+            self.database_key_index()
+        );
+        self.release(state, WaitResult::Cancelled);
+    }
+
     #[inline(always)]
     fn release(&self, state: SyncState, wait_result: WaitResult) {
         let SyncState {
@@ -446,7 +469,11 @@ impl<'me> ClaimGuard<'me> {
 impl Drop for ClaimGuard<'_> {
     fn drop(&mut self) {
         if thread::panicking() {
-            self.release_panicking();
+            if self.zalsa_local.is_cancelled() {
+                self.release_cancelled();
+            } else {
+                self.release_panicking();
+            }
             return;
         }
diff --git a/src/interned.rs b/src/interned.rs
index 544a8d0ee..afd15f71e 100644
--- a/src/interned.rs
+++ b/src/interned.rs
@@ -850,6 +850,7 @@ pub struct StructEntry<'db, C>
 where
     C: Configuration,
 {
+    #[allow(dead_code)]
     value: &'db Value<C>,
     key: DatabaseKeyIndex,
 }
diff --git a/src/lib.rs b/src/lib.rs
index 8c50c9052..b4bcf5d28 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -62,6 +62,7 @@ pub use self::runtime::Runtime;
 pub use self::storage::{Storage, StorageHandle};
 pub use self::update::Update;
 pub use self::zalsa::IngredientIndex;
+pub use self::zalsa_local::CancellationToken;
 pub use crate::attach::{attach, attach_allow_change, with_attached_database};

 pub mod prelude {
diff --git a/src/runtime.rs b/src/runtime.rs
index 48caf53ec..a7feb2287 100644
--- a/src/runtime.rs
+++ b/src/runtime.rs
@@ -13,11 +13,11 @@ mod dependency_graph;
 #[cfg_attr(feature = "persistence", derive(serde::Serialize, serde::Deserialize))]
 pub struct Runtime {
-    /// Set to true when the current revision has been canceled.
+    /// Set to true when the current revision has been cancelled.
     /// This is done when an input is being changed. The flag
     /// is set back to false once the input has been changed.
     #[cfg_attr(feature = "persistence", serde(skip))]
-    revision_canceled: AtomicBool,
+    revision_cancelled: AtomicBool,

     /// Stores the "last change" revision for values of each durability.
     /// This vector is always of length at least 1 (for Durability 0)
@@ -44,6 +44,7 @@ pub struct Runtime {
 pub(super) enum WaitResult {
     Completed,
     Panicked,
+    Cancelled,
 }

 #[derive(Debug)]
@@ -121,7 +122,11 @@ struct BlockedOnInner<'me> {

 impl Running<'_> {
     /// Blocks on the other thread to complete the computation.
-    pub(crate) fn block_on(self, zalsa: &Zalsa) {
+    ///
+    /// Returns `true` if the computation was successful, and `false` if the other thread was cancelled.
+    #[must_use]
+    #[cold]
+    pub(crate) fn block_on(self, zalsa: &Zalsa) -> bool {
         let BlockedOnInner {
             dg,
             query_mutex_guard,
@@ -151,7 +156,8 @@ impl Running<'_> {
                 // by the other thread and responded to appropriately.
                 Cancelled::PropagatedPanic.throw()
             }
-            WaitResult::Completed => {}
+            WaitResult::Cancelled => false,
+            WaitResult::Completed => true,
         }
     }
 }
@@ -183,7 +189,7 @@ impl Default for Runtime {
     fn default() -> Self {
         Runtime {
             revisions: [Revision::start(); Durability::LEN],
-            revision_canceled: Default::default(),
+            revision_cancelled: Default::default(),
             dependency_graph: Default::default(),
             table: Default::default(),
         }
@@ -194,7 +200,7 @@ impl std::fmt::Debug for Runtime {
     fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         fmt.debug_struct("Runtime")
             .field("revisions", &self.revisions)
-            .field("revision_canceled", &self.revision_canceled)
+            .field("revision_cancelled", &self.revision_cancelled)
             .field("dependency_graph", &self.dependency_graph)
             .finish()
     }
@@ -227,16 +233,16 @@ impl Runtime {
     }

     pub(crate) fn load_cancellation_flag(&self) -> bool {
-        self.revision_canceled.load(Ordering::Acquire)
+        self.revision_cancelled.load(Ordering::Acquire)
     }

     pub(crate) fn set_cancellation_flag(&self) {
         crate::tracing::trace!("set_cancellation_flag");
-        self.revision_canceled.store(true, Ordering::Release);
+        self.revision_cancelled.store(true, Ordering::Release);
     }

     pub(crate) fn reset_cancellation_flag(&mut self) {
-        *self.revision_canceled.get_mut() = false;
+        *self.revision_cancelled.get_mut() = false;
     }

     /// Returns the [`Table`] used to store the value of salsa structs
diff --git a/src/storage.rs b/src/storage.rs
index 443b53221..bf3fccd92 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -124,12 +124,15 @@ impl<Db: Database> Storage<Db> {
             .record_unfilled_pages(self.handle.zalsa_impl.table());
         let Self {
             handle,
-            zalsa_local: _,
-        } = &self;
+            zalsa_local,
+        } = &mut self;
         // Avoid rust's annoying destructure prevention rules for `Drop` types
         // SAFETY: We forget `Self` afterwards to discard the original copy, and the destructure
         // above makes sure we won't forget to take into account newly added fields.
         let handle = unsafe { std::ptr::read(handle) };
+        // SAFETY: We forget `Self` afterwards to discard the original copy, and the destructure
+        // above makes sure we won't forget to take into account newly added fields.
+        unsafe { std::ptr::drop_in_place(zalsa_local) };
         std::mem::forget::<Self>(self);
         handle
     }
diff --git a/src/zalsa.rs b/src/zalsa.rs
index ee3c68ce0..0448c4ee7 100644
--- a/src/zalsa.rs
+++ b/src/zalsa.rs
@@ -301,8 +301,11 @@ impl Zalsa {
     #[inline]
     pub(crate) fn unwind_if_revision_cancelled(&self, zalsa_local: &ZalsaLocal) {
         self.event(&|| crate::Event::new(crate::EventKind::WillCheckCancellation));
+        if zalsa_local.is_cancelled() {
+            zalsa_local.unwind_cancelled();
+        }
         if self.runtime().load_cancellation_flag() {
-            zalsa_local.unwind_cancelled(self.current_revision());
+            zalsa_local.unwind_pending_write();
         }
     }
diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs
index 7b0399178..2c6031308 100644
--- a/src/zalsa_local.rs
+++ b/src/zalsa_local.rs
@@ -3,6 +3,8 @@ use std::fmt;
 use std::fmt::Formatter;
 use std::panic::UnwindSafe;
 use std::ptr::{self, NonNull};
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;

 use rustc_hash::FxHashMap;
 use thin_vec::ThinVec;
@@ -39,6 +41,28 @@ pub struct ZalsaLocal {
     /// Stores the most recent page for a given ingredient.
     /// This is thread-local to avoid contention.
     most_recent_pages: UnsafeCell<FxHashMap<IngredientIndex, PageIndex>>,
+
+    cancelled: CancellationToken,
+}
+
+/// A cancellation token that can be used to cancel query computations running on a specific `Database` handle.
+#[derive(Default, Clone, Debug)]
+pub struct CancellationToken(Arc<AtomicBool>);
+
+impl CancellationToken {
+    /// Inform the database to cancel the current query computation.
+    pub fn cancel(&self) {
+        self.0.store(true, Ordering::Relaxed);
+    }
+
+    /// Check if the query computation has been requested to be cancelled.
+    pub fn is_cancelled(&self) -> bool {
+        self.0.load(Ordering::Relaxed)
+    }
+
+    pub(crate) fn uncancel(&self) {
+        self.0.store(false, Ordering::Relaxed);
+    }
 }

 impl ZalsaLocal {
@@ -46,6 +70,7 @@ impl ZalsaLocal {
         ZalsaLocal {
             query_stack: RefCell::new(QueryStack::default()),
             most_recent_pages: UnsafeCell::new(FxHashMap::default()),
+            cancelled: CancellationToken::default(),
         }
     }
@@ -401,12 +426,30 @@ impl ZalsaLocal {
         }
     }

+    #[inline]
+    pub(crate) fn cancellation_token(&self) -> CancellationToken {
+        self.cancelled.clone()
+    }
+
+    #[inline]
+    pub(crate) fn uncancel(&self) {
+        self.cancelled.uncancel();
+    }
+
+    #[inline]
+    pub fn is_cancelled(&self) -> bool {
+        self.cancelled.0.load(Ordering::Relaxed)
+    }
+
     #[cold]
-    pub(crate) fn unwind_cancelled(&self, current_revision: Revision) {
-        // Why is this reporting an untracked read? We do not store the query revisions on unwind do we?
-        self.report_untracked_read(current_revision);
+    pub(crate) fn unwind_pending_write(&self) {
         Cancelled::PendingWrite.throw();
     }
+
+    #[cold]
+    pub(crate) fn unwind_cancelled(&self) {
+        Cancelled::Cancelled.throw();
+    }
 }

 // Okay to implement as `ZalsaLocal` is !Sync
diff --git a/tests/cancellation_token.rs b/tests/cancellation_token.rs
new file mode 100644
index 000000000..9aec792b6
--- /dev/null
+++ b/tests/cancellation_token.rs
@@ -0,0 +1,67 @@
+#![cfg(feature = "inventory")]
+//! Test that a `CancellationToken` cancels an in-flight query computation.
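+//!
+//! A helper thread requests cancellation through the token while query `a` is
+//! parked on its barriers; `a` then unwinds with a cancellation error at the
+//! next cancellation check instead of executing `b`. The follow-up call to `a`
+//! recomputes normally, because the handle's cancellation flag is reset (see
+//! the `attach.rs` changes above) once the cancelled computation has unwound.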
+
+mod common;
+
+use std::{sync::Barrier, thread};
+
+use expect_test::expect;
+use salsa::{Cancelled, Database};
+
+use crate::common::LogDatabase;
+
+#[salsa::input(debug)]
+struct MyInput {
+    field: u32,
+}
+
+#[salsa::tracked]
+fn a(db: &dyn Database, input: MyInput) -> u32 {
+    BARRIER.wait();
+    BARRIER2.wait();
+    b(db, input)
+}
+#[salsa::tracked]
+fn b(db: &dyn Database, input: MyInput) -> u32 {
+    input.field(db)
+}
+
+static BARRIER: Barrier = Barrier::new(2);
+static BARRIER2: Barrier = Barrier::new(2);
+
+#[test]
+fn cancellation_token() {
+    let db = common::EventLoggerDatabase::default();
+    let token = db.cancellation_token();
+    let input = MyInput::new(&db, 22);
+    let res = Cancelled::catch(|| {
+        thread::scope(|s| {
+            s.spawn(|| {
+                BARRIER.wait();
+                token.cancel();
+                BARRIER2.wait();
+            });
+            a(&db, input)
+        })
+    });
+    assert!(matches!(res, Err(Cancelled::Cancelled)), "{res:?}");
+    drop(res);
+    db.assert_logs(expect![[r#"
+        [
+            "WillCheckCancellation",
+            "WillExecute { database_key: a(Id(0)) }",
+            "WillCheckCancellation",
+        ]"#]]);
+    thread::spawn(|| {
+        BARRIER.wait();
+        BARRIER2.wait();
+    });
+    a(&db, input);
+    db.assert_logs(expect![[r#"
+        [
+            "WillCheckCancellation",
+            "WillExecute { database_key: a(Id(0)) }",
+            "WillCheckCancellation",
+            "WillExecute { database_key: b(Id(0)) }",
+        ]"#]]);
+}
diff --git a/tests/interned-revisions.rs b/tests/interned-revisions.rs
index bef1db61c..41f762895 100644
--- a/tests/interned-revisions.rs
+++ b/tests/interned-revisions.rs
@@ -156,7 +156,7 @@ fn test_immortal() {
     // Modify the input to bump the revision and intern a new value.
     //
     // No values should ever be reused with `durability = usize::MAX`.
-    for i in 1..100 {
+    for i in 1..if cfg!(miri) { 50 } else { 1000 } {
         input.set_field1(&mut db).to(i);
         let result = function(&db, input);
         assert_eq!(result.field1(&db).0, i);
diff --git a/tests/parallel/cancellation_token_recomputes.rs b/tests/parallel/cancellation_token_recomputes.rs
new file mode 100644
index 000000000..b7963afc5
--- /dev/null
+++ b/tests/parallel/cancellation_token_recomputes.rs
@@ -0,0 +1,43 @@
+// Shuttle doesn't like panics inside of its runtime.
+#![cfg(not(feature = "shuttle"))]
+
+//! Test for cancellation when another query is blocked on the cancelled thread.
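+//!
+//! Thread 1 claims `query_a` -> `query_b` and parks on a signal; thread 2 then
+//! blocks on thread 1's claim of the same query. Cancelling thread 1's token
+//! must release that claim with `WaitResult::Cancelled` so that thread 2 wakes
+//! up and computes the query itself instead of propagating a panic.
+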
+use salsa::{Cancelled, Database};
+
+use crate::setup::{Knobs, KnobsDatabase};
+
+#[salsa::tracked]
+fn query_a(db: &dyn KnobsDatabase) -> u32 {
+    query_b(db)
+}
+
+#[salsa::tracked]
+fn query_b(db: &dyn KnobsDatabase) -> u32 {
+    db.signal(1);
+    db.wait_for(3);
+    query_c(db)
+}
+
+#[salsa::tracked]
+fn query_c(_db: &dyn KnobsDatabase) -> u32 {
+    1
+}
+#[test]
+fn execute() {
+    let db = Knobs::default();
+    let db2 = db.clone();
+    let db_signaler = db.clone();
+    let token = db.cancellation_token();
+
+    let t1 = std::thread::spawn(move || query_a(&db));
+    db_signaler.wait_for(1);
+    db2.signal_on_will_block(2);
+    let t2 = std::thread::spawn(move || query_a(&db2));
+    db_signaler.wait_for(2);
+    token.cancel();
+    db_signaler.signal(3);
+    let (r1, r2) = (t1.join(), t2.join());
+    let r1 = *r1.unwrap_err().downcast::<Cancelled>().unwrap();
+    assert!(matches!(r1, Cancelled::Cancelled), "{r1:?}");
+    assert_eq!(r2.unwrap(), 1);
+}
diff --git a/tests/parallel/main.rs b/tests/parallel/main.rs
index 1062d4899..41a6f7453 100644
--- a/tests/parallel/main.rs
+++ b/tests/parallel/main.rs
@@ -3,6 +3,7 @@
 mod setup;
 mod signal;

+mod cancellation_token_recomputes;
 mod cycle_a_t1_b_t2;
 mod cycle_a_t1_b_t2_fallback;
 mod cycle_ab_peeping_c;

From 7ab946b30974a7b9a450d591d76baa3fb75d3ce1 Mon Sep 17 00:00:00 2001
From: Lukas Wirth
Date: Thu, 30 Oct 2025 08:38:10 +0100
Subject: [PATCH 2/2] Address reviews

---
 src/cancelled.rs                              |  4 +-
 src/database.rs                               |  2 +-
 src/function/fetch.rs                         | 54 +++++++++----------
 src/function/maybe_changed_after.rs           | 12 ++---
 src/function/memo.rs                          |  2 +-
 src/runtime.rs                                |  7 ++-
 src/zalsa_local.rs                            |  2 +-
 tests/cancellation_token.rs                   |  2 +-
 .../parallel/cancellation_token_recomputes.rs |  2 +-
 9 files changed, 43 insertions(+), 44 deletions(-)

diff --git a/src/cancelled.rs b/src/cancelled.rs
index 1fa0edc59..5fe69e7d1 100644
--- a/src/cancelled.rs
+++ b/src/cancelled.rs
@@ -11,7 +11,7 @@ use std::panic::{self, UnwindSafe};
 #[non_exhaustive]
 pub enum Cancelled {
     /// The query was executing, but execution on the local database handle has been cancelled.
-    Cancelled,
+    Local,

     /// The query was operating on revision R, but there is a pending write to move to revision R+1.
     PendingWrite,
@@ -46,7 +46,7 @@ impl Cancelled {
 impl std::fmt::Display for Cancelled {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         let why = match self {
-            Cancelled::Cancelled => "cancellation request",
+            Cancelled::Local => "local cancellation request",
             Cancelled::PendingWrite => "pending write",
             Cancelled::PropagatedPanic => "propagated panic",
         };
diff --git a/src/database.rs b/src/database.rs
index 9cb70f917..0831fd5bf 100644
--- a/src/database.rs
+++ b/src/database.rs
@@ -68,7 +68,7 @@ pub trait Database: Send + ZalsaDatabase + AsDynDatabase {
         let _ = self.zalsa_mut();
     }

-    /// Retrives a [`CancellationToken`] for the current database.
+    /// Retrieves a [`CancellationToken`] for the current database handle.
     fn cancellation_token(&self) -> CancellationToken {
         self.zalsa_local().cancellation_token()
     }
diff --git a/src/function/fetch.rs b/src/function/fetch.rs
index 9adc01a07..7548e3953 100644
--- a/src/function/fetch.rs
+++ b/src/function/fetch.rs
@@ -105,39 +105,37 @@ where
     ) -> Option<&'db Memo<'db, C>> {
         let database_key_index = self.database_key_index(id);
         // Try to claim this query: if someone else has claimed it already, go back and start again.
-        let claim_guard = loop {
-            match self
-                .sync_table
-                .try_claim(zalsa, zalsa_local, id, Reentrancy::Allow)
-            {
-                ClaimResult::Claimed(guard) => break guard,
-                ClaimResult::Running(blocked_on) => {
-                    if !blocked_on.block_on(zalsa) {
-                        continue;
-                    }
-
-                    if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate {
-                        let memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index);
-
-                        if let Some(memo) = memo {
-                            if memo.value.is_some() {
-                                memo.block_on_heads(zalsa);
-                            }
-                        }
-                    }
-
-                    return None;
-                }
-                ClaimResult::Cycle { .. } => {
-                    return Some(self.fetch_cold_cycle(
-                        zalsa,
-                        zalsa_local,
-                        db,
-                        id,
-                        database_key_index,
-                        memo_ingredient_index,
-                    ));
-                }
-            }
-        };
+        let claim_guard = match self
+            .sync_table
+            .try_claim(zalsa, zalsa_local, id, Reentrancy::Allow)
+        {
+            ClaimResult::Claimed(guard) => guard,
+            ClaimResult::Running(blocked_on) => {
+                if !blocked_on.block_on(zalsa) {
+                    return None;
+                }
+
+                if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate {
+                    let memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index);
+
+                    if let Some(memo) = memo {
+                        if memo.value.is_some() {
+                            memo.block_on_heads(zalsa);
+                        }
+                    }
+                }
+
+                return None;
+            }
+            ClaimResult::Cycle { .. } => {
+                return Some(self.fetch_cold_cycle(
+                    zalsa,
+                    zalsa_local,
+                    db,
+                    id,
+                    database_key_index,
+                    memo_ingredient_index,
+                ));
+            }
+        };
diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs
index a92ea1876..ee84d10e6 100644
--- a/src/function/maybe_changed_after.rs
+++ b/src/function/maybe_changed_after.rs
@@ -141,16 +141,15 @@ where
     ) -> Option<VerifyResult> {
         let database_key_index = self.database_key_index(key_index);

-        let claim_guard = loop {
-            match self
-                .sync_table
-                .try_claim(zalsa, zalsa_local, key_index, Reentrancy::Deny)
-            {
-                ClaimResult::Claimed(guard) => break guard,
-                ClaimResult::Running(blocked_on) => {
-                    if blocked_on.block_on(zalsa) {
-                        return None;
-                    }
-                }
-                ClaimResult::Cycle { .. } => {
-                    return Some(self.maybe_changed_after_cold_cycle(
-                        zalsa_local,
-                        database_key_index,
-                        cycle_heads,
-                    ))
-                }
-            }
-        };
+        let claim_guard = match self
+            .sync_table
+            .try_claim(zalsa, zalsa_local, key_index, Reentrancy::Deny)
+        {
+            ClaimResult::Claimed(guard) => guard,
+            ClaimResult::Running(blocked_on) => {
+                _ = blocked_on.block_on(zalsa);
+                return None;
+            }
+            ClaimResult::Cycle { .. } => {
+                return Some(self.maybe_changed_after_cold_cycle(
+                    zalsa_local,
+                    database_key_index,
+                    cycle_heads,
+                ))
+            }
+        };
         // Load the current memo, if any.
         let Some(old_memo) = self.get_memo_from_table_for(zalsa, key_index, memo_ingredient_index)
         else {
diff --git a/src/function/memo.rs b/src/function/memo.rs
index 4c60c44fe..fc03f8b29 100644
--- a/src/function/memo.rs
+++ b/src/function/memo.rs
@@ -510,7 +510,7 @@ mod _memory_usage {
     use std::any::TypeId;
     use std::num::NonZeroUsize;

-    // Memo's are stored a lot, make sure their size is doesn't randomly increase.
+    // Memo's are stored a lot, make sure their size doesn't randomly increase.
     const _: [(); std::mem::size_of::>()] = [(); std::mem::size_of::<[usize; 6]>()];
diff --git a/src/runtime.rs b/src/runtime.rs
index a7feb2287..5b36bf205 100644
--- a/src/runtime.rs
+++ b/src/runtime.rs
@@ -123,9 +123,12 @@ struct BlockedOnInner<'me> {
 impl Running<'_> {
     /// Blocks on the other thread to complete the computation.
     ///
-    /// Returns `true` if the computation was successful, and `false` if the other thread was cancelled.
+    /// Returns `true` if the computation was successful, and `false` if the other thread was locally cancelled.
+    ///
+    /// # Panics
+    ///
+    /// If the other thread panics, this function will panic as well.
     #[must_use]
-    #[cold]
     pub(crate) fn block_on(self, zalsa: &Zalsa) -> bool {
         let BlockedOnInner {
             dg,
diff --git a/src/zalsa_local.rs b/src/zalsa_local.rs
index 2c6031308..e6ba1860b 100644
--- a/src/zalsa_local.rs
+++ b/src/zalsa_local.rs
@@ -448,7 +448,7 @@ impl ZalsaLocal {

     #[cold]
     pub(crate) fn unwind_cancelled(&self) {
-        Cancelled::Cancelled.throw();
+        Cancelled::Local.throw();
     }
 }
diff --git a/tests/cancellation_token.rs b/tests/cancellation_token.rs
index 9aec792b6..f6a14930a 100644
--- a/tests/cancellation_token.rs
+++ b/tests/cancellation_token.rs
@@ -44,7 +44,7 @@ fn cancellation_token() {
             a(&db, input)
         })
     });
-    assert!(matches!(res, Err(Cancelled::Cancelled)), "{res:?}");
+    assert!(matches!(res, Err(Cancelled::Local)), "{res:?}");
     drop(res);
     db.assert_logs(expect![[r#"
         [
diff --git a/tests/parallel/cancellation_token_recomputes.rs b/tests/parallel/cancellation_token_recomputes.rs
index b7963afc5..0bbb67ef0 100644
--- a/tests/parallel/cancellation_token_recomputes.rs
+++ b/tests/parallel/cancellation_token_recomputes.rs
@@ -38,6 +38,6 @@ fn execute() {
     db_signaler.signal(3);
     let (r1, r2) = (t1.join(), t2.join());
     let r1 = *r1.unwrap_err().downcast::<Cancelled>().unwrap();
-    assert!(matches!(r1, Cancelled::Cancelled), "{r1:?}");
+    assert!(matches!(r1, Cancelled::Local), "{r1:?}");
     assert_eq!(r2.unwrap(), 1);
 }
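
For reference, a minimal usage sketch of the API this series introduces, not part of the patches themselves. The `File` input, `analyze` query, and `main` below are hypothetical stand-ins, and the exact interleaving of `cancel` and the query is timing-dependent:

    use salsa::{Cancelled, Database};

    #[salsa::input]
    struct File {
        text: String,
    }

    #[salsa::tracked]
    fn analyze(db: &dyn Database, file: File) -> usize {
        // Salsa checks for cancellation when entering queries, so a
        // computation made up of many queries unwinds shortly after
        // `cancel` is called.
        file.text(db).len()
    }

    fn main() {
        let db = salsa::DatabaseImpl::new();
        // The token is tied to this handle; cancelling it only affects
        // computations running against this handle's local state.
        let token = db.cancellation_token();

        std::thread::scope(|s| {
            s.spawn(|| {
                let file = File::new(&db, "fn main() {}".to_owned());
                // `Cancelled::catch` turns the cancellation unwind into a value.
                match Cancelled::catch(|| analyze(&db, file)) {
                    Ok(len) => println!("analyzed {len} bytes"),
                    Err(Cancelled::Local) => println!("cancelled by token"),
                    Err(other) => println!("cancelled: {other}"),
                }
            });
            // Request cancellation from another thread.
            token.cancel();
        });
    }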