Skip to content

Commit 6ac107c

Browse files
committed
Improve documentation for Treiber stack.
1 parent b26099a commit 6ac107c

File tree

2 files changed

+68
-24
lines changed

2 files changed

+68
-24
lines changed

src/pool/treiber/cas.rs

Lines changed: 64 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,16 @@ where
6868
failure,
6969
)
7070
.map(drop)
71-
.map_err(NonNullPtr::from_inner)
71+
.map_err(|value| {
72+
// SAFETY: `value` cam from a `NonNullPtr::into_inner` call.
73+
unsafe { NonNullPtr::from_inner(value) }
74+
})
7275
}
7376

7477
#[inline]
7578
fn load(&self, order: Ordering) -> Option<NonNullPtr<N>> {
76-
InnerNonZero::new(self.inner.load(order)).map(|inner| NonNullPtr {
77-
inner,
79+
Some(NonNullPtr {
80+
inner: InnerNonZero::new(self.inner.load(order))?,
7881
_marker: PhantomData,
7982
})
8083
}
@@ -115,33 +118,41 @@ where
115118
}
116119

117120
#[inline]
118-
pub fn from_static_mut_ref(ref_: &'static mut N) -> NonNullPtr<N> {
119-
let non_null = NonNull::from(ref_);
120-
Self::from_non_null(non_null)
121+
pub fn from_static_mut_ref(reference: &'static mut N) -> NonNullPtr<N> {
122+
// SAFETY: `reference` is a static mutable reference, i.e. a valid pointer.
123+
unsafe { Self::new_unchecked(initial_tag(), NonNull::from(reference)) }
121124
}
122125

123-
fn from_non_null(ptr: NonNull<N>) -> Self {
124-
let address = ptr.as_ptr() as Address;
125-
let tag = initial_tag().get();
126-
127-
let value = (Inner::from(tag) << Address::BITS) | Inner::from(address);
126+
/// # Safety
127+
///
128+
/// - `ptr` must be a valid pointer.
129+
#[inline]
130+
unsafe fn new_unchecked(tag: Tag, ptr: NonNull<N>) -> Self {
131+
let value =
132+
(Inner::from(tag.get()) << Address::BITS) | Inner::from(ptr.as_ptr() as Address);
128133

129134
Self {
135+
// SAFETY: `value` is constructed from a `Tag` which is non-zero and half the
136+
// size of the `InnerNonZero` type, and a `NonNull<N>` pointer.
130137
inner: unsafe { InnerNonZero::new_unchecked(value) },
131138
_marker: PhantomData,
132139
}
133140
}
134141

142+
/// # Safety
143+
///
144+
/// - `value` must come from a `Self::into_inner` call.
135145
#[inline]
136-
fn from_inner(value: Inner) -> Option<Self> {
137-
InnerNonZero::new(value).map(|inner| Self {
138-
inner,
146+
unsafe fn from_inner(value: Inner) -> Option<Self> {
147+
Some(Self {
148+
inner: InnerNonZero::new(value)?,
139149
_marker: PhantomData,
140150
})
141151
}
142152

143153
#[inline]
144154
fn non_null(&self) -> NonNull<N> {
155+
// SAFETY: `Self` can only be constructed using a `NonNull<N>`.
145156
unsafe { NonNull::new_unchecked(self.as_ptr()) }
146157
}
147158

@@ -152,17 +163,15 @@ where
152163

153164
#[inline]
154165
fn tag(&self) -> Tag {
166+
// SAFETY: `self.inner` was constructed from a non-zero `Tag`.
155167
unsafe { Tag::new_unchecked((self.inner.get() >> Address::BITS) as Address) }
156168
}
157169

158-
fn increase_tag(&mut self) {
159-
let address = self.as_ptr() as Address;
160-
161-
let new_tag = self.tag().checked_add(1).unwrap_or_else(initial_tag).get();
162-
163-
let value = (Inner::from(new_tag) << Address::BITS) | Inner::from(address);
170+
fn increment_tag(&mut self) {
171+
let new_tag = self.tag().checked_add(1).unwrap_or_else(initial_tag);
164172

165-
self.inner = unsafe { InnerNonZero::new_unchecked(value) };
173+
// SAFETY: `self.non_null()` is a valid pointer.
174+
*self = unsafe { Self::new_unchecked(new_tag, self.non_null()) };
166175
}
167176
}
168177

@@ -210,7 +219,40 @@ where
210219
.compare_and_exchange_weak(Some(top), next, Ordering::Release, Ordering::Relaxed)
211220
.is_ok()
212221
{
213-
top.increase_tag();
222+
// Prevent the ABA problem (https://en.wikipedia.org/wiki/Treiber_stack#Correctness).
223+
//
224+
// Without this, the following would be possible:
225+
//
226+
// | Thread 1 | Thread 2 | Stack |
227+
// |-------------------------------|-------------------------|------------------------------|
228+
// | push((1, 1)) | | (1, 1) |
229+
// | push((1, 2)) | | (1, 2) -> (1, 1) |
230+
// | p = try_pop()::load // (1, 2) | | (1, 2) -> (1, 1) |
231+
// | | p = try_pop() // (1, 2) | (1, 1) |
232+
// | | push((1, 3)) | (1, 3) -> (1, 1) |
233+
// | | push(p) | (1, 2) -> (1, 3) -> (1, 1) |
234+
// | try_pop()::cas(p, p.next) | | (1, 1) |
235+
//
236+
// As can be seen, the `cas` operation succeeds, wrongly removing pointer `3` from the stack.
237+
//
238+
// By incrementing the tag before returning the pointer, it cannot be pushed again with the,
239+
// same tag, preventing the `try_pop()::cas(p, p.next)` operation from succeeding.
240+
//
241+
// With this fix, `try_pop()` in thread 2 returns `(2, 2)` and the comparison between
242+
// `(1, 2)` and `(2, 2)` fails, restarting the loop and correctly removing the new top:
243+
//
244+
// | Thread 1 | Thread 2 | Stack |
245+
// |-------------------------------|-------------------------|------------------------------|
246+
// | push((1, 1)) | | (1, 1) |
247+
// | push((1, 2)) | | (1, 2) -> (1, 1) |
248+
// | p = try_pop()::load // (1, 2) | | (1, 2) -> (1, 1) |
249+
// | | p = try_pop() // (2, 2) | (1, 1) |
250+
// | | push((1, 3)) | (1, 3) -> (1, 1) |
251+
// | | push(p) | (2, 2) -> (1, 3) -> (1, 1) |
252+
// | try_pop()::cas(p, p.next) | | (2, 2) -> (1, 3) -> (1, 1) |
253+
// | p = try_pop()::load // (2, 2) | | (2, 2) -> (1, 3) -> (1, 1) |
254+
// | try_pop()::cas(p, p.next) | | (1, 3) -> (1, 1) |
255+
top.increment_tag();
214256

215257
return Some(top);
216258
}

src/pool/treiber/llsc.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,8 @@ mod arch {
125125
}
126126

127127
/// # Safety
128-
/// - `addr` must be a valid pointer
128+
///
129+
/// - `addr` must be a valid pointer.
129130
#[inline(always)]
130131
pub unsafe fn load_link(addr: *const usize) -> usize {
131132
let value;
@@ -134,7 +135,8 @@ mod arch {
134135
}
135136

136137
/// # Safety
137-
/// - `addr` must be a valid pointer
138+
///
139+
/// - `addr` must be a valid pointer.
138140
#[inline(always)]
139141
pub unsafe fn store_conditional(value: usize, addr: *mut usize) -> Result<(), ()> {
140142
let outcome: usize;

0 commit comments

Comments
 (0)