Skip to content

Commit ba0f832

Browse files
committed
Cancellation cleanup on un-attach
1 parent de32cab commit ba0f832

File tree

5 files changed

+45
-24
lines changed

5 files changed

+45
-24
lines changed

src/attach.rs

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,10 @@ impl Attached {
7070
fn drop(&mut self) {
7171
// Reset database to null if we did anything in `DbGuard::new`.
7272
if let Some(attached) = self.state {
73-
attached.database.set(None);
73+
if let Some(prev) = attached.database.replace(None) {
74+
// SAFETY: `prev` is a valid pointer to a database.
75+
unsafe { prev.as_ref().zalsa_local().uncancel() };
76+
}
7477
}
7578
}
7679
}
@@ -85,25 +88,50 @@ impl Attached {
8588
Db: ?Sized + Database,
8689
{
8790
struct DbGuard<'s> {
88-
state: &'s Attached,
91+
state: Option<&'s Attached>,
8992
prev: Option<NonNull<dyn Database>>,
9093
}
9194

9295
impl<'s> DbGuard<'s> {
9396
#[inline]
9497
fn new(attached: &'s Attached, db: &dyn Database) -> Self {
95-
let prev = attached.database.replace(Some(NonNull::from(db)));
96-
Self {
97-
state: attached,
98-
prev,
98+
let db = NonNull::from(db);
99+
match attached.database.replace(Some(db)) {
100+
Some(prev) => {
101+
if std::ptr::eq(db.as_ptr(), prev.as_ptr()) {
102+
Self {
103+
state: None,
104+
prev: None,
105+
}
106+
} else {
107+
Self {
108+
state: Some(attached),
109+
prev: Some(prev),
110+
}
111+
}
112+
}
113+
None => {
114+
// Otherwise, set the database.
115+
attached.database.set(Some(NonNull::from(db)));
116+
Self {
117+
state: Some(attached),
118+
prev: None,
119+
}
120+
}
99121
}
100122
}
101123
}
102124

103125
impl Drop for DbGuard<'_> {
104126
#[inline]
105127
fn drop(&mut self) {
106-
self.state.database.set(self.prev);
128+
// Reset database to null if we did anything in `DbGuard::new`.
129+
if let Some(attached) = self.state {
130+
if let Some(prev) = attached.database.replace(self.prev) {
131+
// SAFETY: `prev` is a valid pointer to a database.
132+
unsafe { prev.as_ref().zalsa_local().uncancel() };
133+
}
134+
}
107135
}
108136
}
109137

src/cancelled.rs

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
use std::fmt;
22
use std::panic::{self, UnwindSafe};
33

4-
use crate::CancellationToken;
5-
64
/// A panic payload indicating that execution of a salsa query was cancelled.
75
///
86
/// This can occur for a few reasons:
@@ -13,7 +11,7 @@ use crate::CancellationToken;
1311
#[non_exhaustive]
1412
pub enum Cancelled {
1513
/// The query was operating but the local database execution has been cancelled.
16-
Cancelled(UncancelGuard),
14+
Cancelled,
1715

1816
/// The query was operating on revision R, but there is a pending write to move to revision R+1.
1917
PendingWrite,
@@ -22,15 +20,6 @@ pub enum Cancelled {
2220
PropagatedPanic,
2321
}
2422

25-
#[derive(Debug)]
26-
pub struct UncancelGuard(pub(crate) CancellationToken);
27-
28-
impl Drop for UncancelGuard {
29-
fn drop(&mut self) {
30-
self.0.uncancel();
31-
}
32-
}
33-
3423
impl Cancelled {
3524
#[cold]
3625
pub(crate) fn throw(self) -> ! {
@@ -57,7 +46,7 @@ impl Cancelled {
5746
impl std::fmt::Display for Cancelled {
5847
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
5948
let why = match self {
60-
Cancelled::Cancelled(_) => "canellation request",
49+
Cancelled::Cancelled => "canellation request",
6150
Cancelled::PendingWrite => "pending write",
6251
Cancelled::PropagatedPanic => "propagated panic",
6352
};

src/zalsa_local.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ use crate::accumulator::{
1515
Accumulator,
1616
};
1717
use crate::active_query::{CompletedQuery, QueryStack};
18-
use crate::cancelled::UncancelGuard;
1918
use crate::cycle::{empty_cycle_heads, AtomicIterationCount, CycleHeads, IterationCount};
2019
use crate::durability::Durability;
2120
use crate::key::DatabaseKeyIndex;
@@ -432,6 +431,11 @@ impl ZalsaLocal {
432431
self.cancelled.clone()
433432
}
434433

434+
#[inline]
435+
pub(crate) fn uncancel(&self) {
436+
self.cancelled.uncancel();
437+
}
438+
435439
#[inline]
436440
pub fn is_cancelled(&self) -> bool {
437441
self.cancelled.0.load(Ordering::Relaxed)
@@ -444,7 +448,7 @@ impl ZalsaLocal {
444448

445449
#[cold]
446450
pub(crate) fn unwind_cancelled(&self) {
447-
Cancelled::Cancelled(UncancelGuard(self.cancellation_token())).throw();
451+
Cancelled::Cancelled.throw();
448452
}
449453
}
450454

tests/cancellation_token.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ fn cancellation_token() {
4444
a(&db, input)
4545
})
4646
});
47-
assert!(matches!(res, Err(Cancelled::Cancelled(_))), "{res:?}");
47+
assert!(matches!(res, Err(Cancelled::Cancelled)), "{res:?}");
4848
drop(res);
4949
db.assert_logs(expect![[r#"
5050
[

tests/parallel/cancellation_token_recomputes.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,6 @@ fn execute() {
3838
db_signaler.signal(3);
3939
let (r1, r2) = (t1.join(), t2.join());
4040
let r1 = *r1.unwrap_err().downcast::<salsa::Cancelled>().unwrap();
41-
assert!(matches!(r1, Cancelled::Cancelled(_)), "{r1:?}");
41+
assert!(matches!(r1, Cancelled::Cancelled), "{r1:?}");
4242
assert_eq!(r2.unwrap(), 1);
4343
}

0 commit comments

Comments
 (0)