Skip to content

Commit a4113cd

Browse files
Veykrilibraheemdev
andauthored
Simplify WaitGroup implementation (#958)
* Simplify `WaitGroup` implementation * Slightly cheaper `get_mut` Co-authored-by: Ibraheem Ahmed <[email protected]> --------- Co-authored-by: Ibraheem Ahmed <[email protected]>
1 parent 9cfe41c commit a4113cd

File tree

4 files changed

+31
-25
lines changed

4 files changed

+31
-25
lines changed

src/storage.rs

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ pub struct StorageHandle<Db> {
2525

2626
impl<Db> Clone for StorageHandle<Db> {
2727
fn clone(&self) -> Self {
28-
*self.coordinate.clones.lock() += 1;
29-
3028
Self {
3129
zalsa_impl: self.zalsa_impl.clone(),
3230
coordinate: CoordinateDrop(Arc::clone(&self.coordinate)),
@@ -53,7 +51,7 @@ impl<Db: Database> StorageHandle<Db> {
5351
Self {
5452
zalsa_impl: Arc::new(Zalsa::new::<Db>(event_callback, jars)),
5553
coordinate: CoordinateDrop(Arc::new(Coordinate {
56-
clones: Mutex::new(1),
54+
coordinate_lock: Mutex::default(),
5755
cvar: Default::default(),
5856
})),
5957
phantom: PhantomData,
@@ -95,17 +93,6 @@ impl<Db> Drop for Storage<Db> {
9593
}
9694
}
9795

98-
struct Coordinate {
99-
/// Counter of the number of clones of actor. Begins at 1.
100-
/// Incremented when cloned, decremented when dropped.
101-
clones: Mutex<usize>,
102-
cvar: Condvar,
103-
}
104-
105-
// We cannot panic while holding a lock to `clones: Mutex<usize>` and therefore we cannot enter an
106-
// inconsistent state.
107-
impl RefUnwindSafe for Coordinate {}
108-
10996
impl<Db: Database> Default for Storage<Db> {
11097
fn default() -> Self {
11198
Self::new(None)
@@ -168,12 +155,15 @@ impl<Db: Database> Storage<Db> {
168155
.zalsa_impl
169156
.event(&|| Event::new(EventKind::DidSetCancellationFlag));
170157

171-
let mut clones = self.handle.coordinate.clones.lock();
172-
while *clones != 1 {
173-
clones = self.handle.coordinate.cvar.wait(clones);
174-
}
175-
// The ref count on the `Arc` should now be 1
176-
let zalsa = Arc::get_mut(&mut self.handle.zalsa_impl).unwrap();
158+
let mut coordinate_lock = self.handle.coordinate.coordinate_lock.lock();
159+
let zalsa = loop {
160+
if Arc::strong_count(&self.handle.zalsa_impl) == 1 {
161+
// SAFETY: The strong count is 1, and we never create any weak pointers,
162+
// so we have a unique reference.
163+
break unsafe { &mut *(Arc::as_ptr(&self.handle.zalsa_impl).cast_mut()) };
164+
}
165+
coordinate_lock = self.handle.coordinate.cvar.wait(coordinate_lock);
166+
};
177167
// cancellation is done, so reset the flag
178168
zalsa.runtime_mut().reset_cancellation_flag();
179169
zalsa
@@ -260,6 +250,16 @@ impl<Db: Database> Clone for Storage<Db> {
260250
}
261251
}
262252

253+
/// A simplified `WaitGroup`, this is used together with `Arc<Zalsa>` as the actual counter
254+
struct Coordinate {
255+
coordinate_lock: Mutex<()>,
256+
cvar: Condvar,
257+
}
258+
259+
// We cannot panic while holding a lock to `clones: Mutex<usize>` and therefore we cannot enter an
260+
// inconsistent state.
261+
impl RefUnwindSafe for Coordinate {}
262+
263263
struct CoordinateDrop(Arc<Coordinate>);
264264

265265
impl std::ops::Deref for CoordinateDrop {
@@ -272,7 +272,6 @@ impl std::ops::Deref for CoordinateDrop {
272272

273273
impl Drop for CoordinateDrop {
274274
fn drop(&mut self) {
275-
*self.0.clones.lock() -= 1;
276275
self.0.cvar.notify_all();
277276
}
278277
}

src/table.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,10 @@ impl Table {
252252
}
253253

254254
let allocated_idx = self.push_page::<T>(ingredient, memo_types.clone());
255-
assert_eq!(allocated_idx, page_idx);
255+
assert_eq!(
256+
allocated_idx, page_idx,
257+
"allocated index does not match requested index"
258+
);
256259
}
257260
};
258261
}

src/views.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,11 @@ impl Views {
108108
&self,
109109
func: fn(NonNull<Concrete>) -> NonNull<DbView>,
110110
) -> &DatabaseDownCaster<DbView> {
111-
assert_eq!(self.source_type_id, TypeId::of::<Concrete>());
111+
assert_eq!(
112+
self.source_type_id,
113+
TypeId::of::<Concrete>(),
114+
"mismatched source type"
115+
);
112116
let target_type_id = TypeId::of::<DbView>();
113117
if let Some((_, caster)) = self
114118
.view_casters

src/zalsa_local.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1173,7 +1173,7 @@ impl ActiveQueryGuard<'_> {
11731173
unsafe {
11741174
self.local_state.with_query_stack_unchecked_mut(|stack| {
11751175
#[cfg(debug_assertions)]
1176-
assert_eq!(stack.len(), self.push_len);
1176+
assert_eq!(stack.len(), self.push_len, "mismatched push and pop");
11771177
let frame = stack.last_mut().unwrap();
11781178
frame.tracked_struct_ids_mut().seed(tracked_struct_ids);
11791179
})
@@ -1195,7 +1195,7 @@ impl ActiveQueryGuard<'_> {
11951195
unsafe {
11961196
self.local_state.with_query_stack_unchecked_mut(|stack| {
11971197
#[cfg(debug_assertions)]
1198-
assert_eq!(stack.len(), self.push_len);
1198+
assert_eq!(stack.len(), self.push_len, "mismatched push and pop");
11991199
let frame = stack.last_mut().unwrap();
12001200
frame.seed_iteration(durability, changed_at, edges, untracked_read, tracked_ids);
12011201
})

0 commit comments

Comments
 (0)