Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions src/attach.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ impl Attached {
fn drop(&mut self) {
// Reset database to null if we did anything in `DbGuard::new`.
if let Some(attached) = self.state {
attached.database.set(None);
if let Some(prev) = attached.database.replace(None) {
// SAFETY: `prev` is a valid pointer to a database.
unsafe { prev.as_ref().zalsa_local().uncancel() };
}
}
}
}
Expand All @@ -85,25 +88,50 @@ impl Attached {
Db: ?Sized + Database,
{
struct DbGuard<'s> {
state: &'s Attached,
state: Option<&'s Attached>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reason to make this optional? It's not evident to me how it's related to the change (but it probably it is, just not evident to me)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to track when the database actually changes which this helps in doing, since when it changes thats a signal for it having exited its outermost scope allowing us to reset the cancellation state.

prev: Option<NonNull<dyn Database>>,
}

impl<'s> DbGuard<'s> {
#[inline]
fn new(attached: &'s Attached, db: &dyn Database) -> Self {
let prev = attached.database.replace(Some(NonNull::from(db)));
Self {
state: attached,
prev,
let db = NonNull::from(db);
match attached.database.replace(Some(db)) {
Some(prev) => {
if std::ptr::eq(db.as_ptr(), prev.as_ptr()) {
Self {
state: None,
prev: None,
}
} else {
Self {
state: Some(attached),
prev: Some(prev),
}
}
}
None => {
// Otherwise, set the database.
attached.database.set(Some(db));
Self {
state: Some(attached),
prev: None,
}
}
}
}
}

impl Drop for DbGuard<'_> {
#[inline]
fn drop(&mut self) {
self.state.database.set(self.prev);
// Reset database to null if we did anything in `DbGuard::new`.
if let Some(attached) = self.state {
if let Some(prev) = attached.database.replace(self.prev) {
// SAFETY: `prev` is a valid pointer to a database.
unsafe { prev.as_ref().zalsa_local().uncancel() };
}
}
}
}

Expand Down
6 changes: 4 additions & 2 deletions src/cancelled.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ use std::panic::{self, UnwindSafe};
#[derive(Debug)]
#[non_exhaustive]
pub enum Cancelled {
/// The query was operating but the local database execution has been cancelled.
Local,

/// The query was operating on revision R, but there is a pending write to move to revision R+1.
#[non_exhaustive]
PendingWrite,

/// The query was blocked on another thread, and that thread panicked.
#[non_exhaustive]
PropagatedPanic,
}

Expand Down Expand Up @@ -45,6 +46,7 @@ impl Cancelled {
impl std::fmt::Display for Cancelled {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let why = match self {
Cancelled::Local => "local canellation request",
Cancelled::PendingWrite => "pending write",
Cancelled::PropagatedPanic => "propagated panic",
};
Expand Down
8 changes: 7 additions & 1 deletion src/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::ptr::NonNull;

use crate::views::DatabaseDownCaster;
use crate::zalsa::{IngredientIndex, ZalsaDatabase};
use crate::zalsa_local::CancellationToken;
use crate::{Durability, Revision};

#[derive(Copy, Clone)]
Expand Down Expand Up @@ -59,14 +60,19 @@ pub trait Database: Send + ZalsaDatabase + AsDynDatabase {
zalsa_mut.runtime_mut().report_tracked_write(durability);
}

/// This method triggers cancellation.
/// This method cancels all outstanding computations.
/// If you invoke it while a snapshot exists, it
/// will block until that snapshot is dropped -- if that snapshot
/// is owned by the current thread, this could trigger deadlock.
fn trigger_cancellation(&mut self) {
let _ = self.zalsa_mut();
}

/// Retrieves a [`CancellationToken`] for the current database handle.
fn cancellation_token(&self) -> CancellationToken {
self.zalsa_local().cancellation_token()
}

/// Reports that the query depends on some state unknown to salsa.
///
/// Queries which report untracked reads will be re-executed in the next
Expand Down
9 changes: 7 additions & 2 deletions src/function/fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,10 +105,15 @@ where
) -> Option<&'db Memo<'db, C>> {
let database_key_index = self.database_key_index(id);
// Try to claim this query: if someone else has claimed it already, go back and start again.
let claim_guard = match self.sync_table.try_claim(zalsa, id, Reentrancy::Allow) {
let claim_guard = match self
.sync_table
.try_claim(zalsa, zalsa_local, id, Reentrancy::Allow)
{
ClaimResult::Claimed(guard) => guard,
ClaimResult::Running(blocked_on) => {
blocked_on.block_on(zalsa);
if !blocked_on.block_on(zalsa) {
return None;
}

if C::CYCLE_STRATEGY == CycleRecoveryStrategy::FallbackImmediate {
let memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index);
Expand Down
35 changes: 18 additions & 17 deletions src/function/maybe_changed_after.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,23 +141,24 @@ where
) -> Option<VerifyResult> {
let database_key_index = self.database_key_index(key_index);

let claim_guard = match self
.sync_table
.try_claim(zalsa, key_index, Reentrancy::Deny)
{
ClaimResult::Claimed(guard) => guard,
ClaimResult::Running(blocked_on) => {
blocked_on.block_on(zalsa);
return None;
}
ClaimResult::Cycle { .. } => {
return Some(self.maybe_changed_after_cold_cycle(
zalsa_local,
database_key_index,
cycle_heads,
))
}
};
let claim_guard =
match self
.sync_table
.try_claim(zalsa, zalsa_local, key_index, Reentrancy::Deny)
{
ClaimResult::Claimed(guard) => guard,
ClaimResult::Running(blocked_on) => {
_ = blocked_on.block_on(zalsa);
return None;
}
ClaimResult::Cycle { .. } => {
return Some(self.maybe_changed_after_cold_cycle(
zalsa_local,
database_key_index,
cycle_heads,
))
}
};
// Load the current memo, if any.
let Some(old_memo) = self.get_memo_from_table_for(zalsa, key_index, memo_ingredient_index)
else {
Expand Down
7 changes: 5 additions & 2 deletions src/function/memo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,10 @@ impl<'db, C: Configuration> Memo<'db, C> {
}
TryClaimHeadsResult::Running(running) => {
all_cycles = false;
running.block_on(zalsa);
if !running.block_on(zalsa) {
// FIXME: Handle cancellation properly?
crate::Cancelled::PropagatedPanic.throw();
}
}
}
}
Expand Down Expand Up @@ -507,7 +510,7 @@ mod _memory_usage {
use std::any::TypeId;
use std::num::NonZeroUsize;

// Memo's are stored a lot, make sure their size is doesn't randomly increase.
// Memo's are stored a lot, make sure their size doesn't randomly increase.
const _: [(); std::mem::size_of::<super::Memo<DummyConfiguration>>()] =
[(); std::mem::size_of::<[usize; 6]>()];

Expand Down
37 changes: 32 additions & 5 deletions src/function/sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use rustc_hash::FxHashMap;
use std::collections::hash_map::OccupiedEntry;

use crate::key::DatabaseKeyIndex;
use crate::plumbing::ZalsaLocal;
use crate::runtime::{
BlockOnTransferredOwner, BlockResult, BlockTransferredResult, Running, WaitResult,
};
Expand All @@ -21,6 +22,8 @@ pub(crate) struct SyncTable {
}

pub(crate) enum ClaimResult<'a, Guard = ClaimGuard<'a>> {
/// Successfully claimed the query.
Claimed(Guard),
/// Can't claim the query because it is running on an other thread.
Running(Running<'a>),
/// Claiming the query results in a cycle.
Expand All @@ -30,8 +33,6 @@ pub(crate) enum ClaimResult<'a, Guard = ClaimGuard<'a>> {
/// [`SyncTable::try_claim`] with [`Reentrant::Allow`].
inner: bool,
},
/// Successfully claimed the query.
Claimed(Guard),
}

pub(crate) struct SyncState {
Expand Down Expand Up @@ -68,6 +69,7 @@ impl SyncTable {
pub(crate) fn try_claim<'me>(
&'me self,
zalsa: &'me Zalsa,
zalsa_local: &'me ZalsaLocal,
key_index: Id,
reentrant: Reentrancy,
) -> ClaimResult<'me> {
Expand All @@ -77,7 +79,12 @@ impl SyncTable {
let id = match occupied_entry.get().id {
SyncOwner::Thread(id) => id,
SyncOwner::Transferred => {
return match self.try_claim_transferred(zalsa, occupied_entry, reentrant) {
return match self.try_claim_transferred(
zalsa,
zalsa_local,
occupied_entry,
reentrant,
) {
Ok(claimed) => claimed,
Err(other_thread) => match other_thread.block(write) {
BlockResult::Cycle => ClaimResult::Cycle { inner: false },
Expand Down Expand Up @@ -115,6 +122,7 @@ impl SyncTable {
ClaimResult::Claimed(ClaimGuard {
key_index,
zalsa,
zalsa_local,
sync_table: self,
mode: ReleaseMode::Default,
})
Expand Down Expand Up @@ -172,6 +180,7 @@ impl SyncTable {
fn try_claim_transferred<'me>(
&'me self,
zalsa: &'me Zalsa,
zalsa_local: &'me ZalsaLocal,
mut entry: OccupiedEntry<Id, SyncState>,
reentrant: Reentrancy,
) -> Result<ClaimResult<'me>, Box<BlockOnTransferredOwner<'me>>> {
Expand All @@ -195,6 +204,7 @@ impl SyncTable {
Ok(ClaimResult::Claimed(ClaimGuard {
key_index,
zalsa,
zalsa_local,
sync_table: self,
mode: ReleaseMode::SelfOnly,
}))
Expand All @@ -214,6 +224,7 @@ impl SyncTable {
Ok(ClaimResult::Claimed(ClaimGuard {
key_index,
zalsa,
zalsa_local,
sync_table: self,
mode: ReleaseMode::Default,
}))
Expand Down Expand Up @@ -295,6 +306,7 @@ pub(crate) struct ClaimGuard<'me> {
zalsa: &'me Zalsa,
sync_table: &'me SyncTable,
mode: ReleaseMode,
zalsa_local: &'me ZalsaLocal,
}

impl<'me> ClaimGuard<'me> {
Expand All @@ -319,10 +331,21 @@ impl<'me> ClaimGuard<'me> {
"Release claim on {:?} due to panic",
self.database_key_index()
);

self.release(state, WaitResult::Panicked);
}

#[cold]
#[inline(never)]
fn release_cancelled(&self) {
let mut syncs = self.sync_table.syncs.lock();
let state = syncs.remove(&self.key_index).expect("key claimed twice?");
tracing::debug!(
"Release claim on {:?} due to cancellation",
self.database_key_index()
);
self.release(state, WaitResult::Cancelled);
}

#[inline(always)]
fn release(&self, state: SyncState, wait_result: WaitResult) {
let SyncState {
Expand Down Expand Up @@ -446,7 +469,11 @@ impl<'me> ClaimGuard<'me> {
impl Drop for ClaimGuard<'_> {
fn drop(&mut self) {
if thread::panicking() {
self.release_panicking();
if self.zalsa_local.is_cancelled() {
self.release_cancelled();
} else {
self.release_panicking();
}
return;
}

Expand Down
1 change: 1 addition & 0 deletions src/interned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,7 @@ pub struct StructEntry<'db, C>
where
C: Configuration,
{
#[allow(dead_code)]
value: &'db Value<C>,
key: DatabaseKeyIndex,
}
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ pub use self::runtime::Runtime;
pub use self::storage::{Storage, StorageHandle};
pub use self::update::Update;
pub use self::zalsa::IngredientIndex;
pub use self::zalsa_local::CancellationToken;
pub use crate::attach::{attach, attach_allow_change, with_attached_database};

pub mod prelude {
Expand Down
Loading