Skip to content

Commit 35171dc

Browse files
authored
rust: support reservoir preemption (#4299)
Summary: This patch extends the bipartite reservoir sampling data structure with support for preemption. A *preemption* occurs when a record is enqueued with a step that is not higher than the last-seen step. In a preemption, all records whose step is not lower than the new record’s step are evicted. Test Plan: Unit tests expanded. wchargin-branch: rust-reservoir-preemption
1 parent 80e33c1 commit 35171dc

File tree

1 file changed

+197
-16
lines changed

1 file changed

+197
-16
lines changed

tensorboard/data/server/reservoir.rs

Lines changed: 197 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ use rand_chacha::ChaCha20Rng;
2828
/// every record in the stream. (One exception: when the reservoir's capacity is zero, it does not
2929
/// retain any records, even the latest one.)
3030
///
31-
/// **Note:** Preemption support is not yet implemented.
32-
///
3331
/// # Preemption
3432
///
3533
/// All records stored in this reservoir have a *step* and a *payload*. The step, a non-negative
@@ -93,6 +91,8 @@ pub struct StageReservoir<T, C = ChaCha20Rng> {
9391
/// preempted from the stream is estimated linearly from the proportion of records preempted
9492
/// from the reservoir.
9593
///
94+
/// It always holds that `self.len() <= self.seen`, even when `self.seen` is not exact.
95+
///
9696
/// Exception: when `capacity == 0`, `seen` is always `0` as well. A reservoir with no capacity
9797
/// is inert and has no need to track `seen`.
9898
seen: usize,
@@ -184,16 +184,22 @@ impl<T, C: ReservoirControl> StageReservoir<T, C> {
184184
if self.capacity == 0 {
185185
return;
186186
}
187+
self.preempt(step);
187188
self.seen += 1;
188-
let dst = self.ctl.destination(self.seen);
189-
190-
if dst >= self.capacity {
191-
// Didn't make the cut? Keep-last only.
192-
self.pop();
193-
} else if self.len() >= self.capacity {
194-
// No room? Evict the destination.
195-
// From `if`-guards, we know `dst < self.capacity <= self.len()`, so this is safe.
196-
self.remove(dst);
189+
190+
// If we can hold every record that we've seen, we can add this record unconditionally.
191+
// Otherwise, we need to roll a destination---even if there's available space, to avoid
192+
// bias right after a preemption.
193+
if self.seen > self.capacity {
194+
let dst = self.ctl.destination(self.seen);
195+
if dst >= self.capacity {
196+
// Didn't make the cut? Keep-last only.
197+
self.pop();
198+
} else if self.len() >= self.capacity {
199+
// No room? Evict the destination.
200+
// From `if`-guards, we know `dst < self.capacity <= self.len()`, so this is safe.
201+
self.remove(dst);
202+
}
197203
}
198204
// In any case, add to end.
199205
self.staged_items.push((step, v));
@@ -234,6 +240,62 @@ impl<T, C: ReservoirControl> StageReservoir<T, C> {
234240
&self.staged_items[..]
235241
}
236242

243+
/// Preempts any records whose step does not precede the given step.
244+
fn preempt(&mut self, step: Step) {
245+
let old_len = self.len();
246+
let staged_preempted = self
247+
.staged_items
248+
.iter()
249+
.rev()
250+
.take_while(|(s, _)| *s >= step)
251+
.count();
252+
if staged_preempted > 0 {
253+
self.staged_items
254+
.truncate(self.staged_items.len() - staged_preempted);
255+
}
256+
if self.staged_items.is_empty() {
257+
// Committed steps may have been preempted as well. Note: we can hit this case even if
258+
// `staged_preempted == 0`, since `staged_items` may have been empty to begin with.
259+
let committed_preempted = self
260+
.committed_steps
261+
.iter()
262+
.rev()
263+
.take_while(|s| **s >= step)
264+
.count();
265+
if committed_preempted > 0 {
266+
self.committed_steps
267+
.truncate(self.committed_steps.len() - committed_preempted);
268+
}
269+
}
270+
let new_len = self.len();
271+
if new_len == old_len {
272+
return; // No need to adjust `seen`.
273+
}
274+
// Update our estimate of `seen` assuming that the fraction of sampled-records preempted is
275+
// the same as the fraction of seen-records preempted. Note: when preempting to or before
276+
// the earliest-written step, `self.len()` will now be `0`, so we will reset `seen` to
277+
// exactly `0`, as desired.
278+
//
279+
// This preserves that `self.len() <= self.seen` because:
280+
//
281+
// ```none
282+
// old_seen >= old_len (by induction)
283+
// old_seen * new_len >= old_len * new_len
284+
// (old_seen * new_len) / old_len >= (old_len * new_len) / old_len
285+
// (old_seen * new_len) / old_len >= new_len (since the above integer division is exact)
286+
// new_seen >= new_len
287+
// ```
288+
let old_seen = self.seen;
289+
let new_seen = ((old_seen as u64 * new_len as u64) / old_len as u64) as usize;
290+
debug_assert!(
291+
new_seen >= new_len,
292+
"old (len, seen) = {:?}, new (len, seen) = {:?}; wanted seen >= len",
293+
(old_len, old_seen),
294+
(new_len, new_seen),
295+
);
296+
self.seen = new_seen;
297+
}
298+
237299
/// Commits pending changes from this reservoir into a basin.
238300
///
239301
/// The basin should initially be empty and should be modified only by calls to
@@ -294,16 +356,20 @@ mod tests {
294356
}
295357
}
296358

359+
/// Extracts the steps from a basin. Convenient for tests.
360+
fn steps<T>(basin: &Basin<T>) -> Vec<Step> {
361+
basin.as_slice().iter().map(|(s, _)| *s).collect()
362+
}
363+
297364
#[test]
298-
fn test() {
365+
fn test_sampling_no_preemptions() {
299366
let mut rsv = StageReservoir::with_control(7, ScriptedControl::new());
300367
let mut head = Basin::new();
301368
fn mapper(s: &str) -> &str {
302369
// leak, for test convenience
303370
Box::leak(format!(":{}:", s).into_boxed_str())
304371
}
305372

306-
rsv.ctl.extend(vec![0, 1, 1, 2]);
307373
rsv.offer(Step(0), "zero");
308374
rsv.offer(Step(1), "one");
309375
rsv.offer(Step(2), "two");
@@ -319,10 +385,10 @@ mod tests {
319385
],
320386
);
321387

322-
rsv.ctl.extend(vec![1, 2, 1, 3]);
323388
rsv.offer(Step(4), "four");
324389
rsv.offer(Step(5), "five");
325390
rsv.offer(Step(6), "six");
391+
rsv.ctl.extend(vec![3]);
326392
rsv.offer(Step(7), "seven"); // this one exceeds capacity, evicting index 3
327393
rsv.commit_map(&mut head, mapper);
328394
assert_eq!(
@@ -357,8 +423,10 @@ mod tests {
357423
);
358424
}
359425

426+
/// Tests some desired properties about sampling and preemption. This uses seeded RNG, but the
427+
/// seed is arbitrary.
360428
#[test]
361-
fn test_random() {
429+
fn test_sampling_preemption() {
362430
// Seeded RNG, with tests for some invariants.
363431
let mut rsv = StageReservoir::new(10);
364432
let mut head = Basin::new();
@@ -385,6 +453,73 @@ mod tests {
385453
assert_eq!(head.as_slice().len(), 10);
386454
assert_eq!(head.as_slice().last(), Some(&(Step(i * i), ())));
387455
}
456+
457+
// Seen 16 records, keeping 10. Preempt to invalidate records 9..=16, so that the reservoir
458+
// must have between 2 and 8 old records before the new one is added.
459+
rsv.offer(Step(70), ()); // 8 * 8 < 70 < 9 * 9
460+
rsv.commit(&mut head);
461+
assert!(
462+
(2..=9).contains(&head.as_slice().len()),
463+
"want 2 <= {} <= 9: {:?}",
464+
head.as_slice().len(),
465+
head
466+
);
467+
assert!(
468+
head.as_slice().iter().all(|(s, _)| *s <= Step(70)),
469+
"want all <= 70: {:?}",
470+
head
471+
);
472+
assert_eq!(head.as_slice().last(), Some(&(Step(70), ())));
473+
474+
// One more sanity check: add another record. The "70" preemption may or may not be
475+
// evicted, but this new record should be the last.
476+
rsv.offer(Step(71), ());
477+
rsv.commit(&mut head);
478+
assert_eq!(head.as_slice().last(), Some(&(Step(71), ())));
479+
}
480+
481+
/// Checks that the reservoir can still reject a record (keep-last aside) while it has free
/// slots, whenever its estimate of the number of records seen exceeds the capacity.
///
/// Accepting unconditionally in that situation would bias the sample. Picture a reservoir of
/// capacity 1000 that has been offered 1 million records at sequential steps, and is then
/// preempted back to step 900_000, leaving 900/1000 slots occupied. A next record admitted
/// unconditionally is sampled with probability 1, while every other record was sampled with
/// probability 1/1000: not uniform.
///
/// The scenario below is the same idea at a smaller scale (`n = 4`, `N = 16`).
#[test]
fn test_unbiased_even_with_space_after_preemption() {
    let mut head = Basin::new();
    let mut rsv = StageReservoir::with_control(4, ScriptedControl::new());

    // Steps 0..4 fill the reservoir to capacity.
    for i in 0..4 {
        rsv.offer(Step(i), ());
    }
    // Steps 4..8 evict so that the reservoir holds [1, 3, 5, 7].
    rsv.ctl.extend(vec![0, 1, 2, 3]);
    for i in 4..8 {
        rsv.offer(Step(i), ());
    }
    // Steps 8..16 evict so that the reservoir holds [3, 7, 11, 15].
    rsv.ctl.extend(vec![0, 8, 1, 9, 2, 10, 3, 11]);
    for i in 8..16 {
        rsv.offer(Step(i), ());
    }
    rsv.commit(&mut head);
    let want_full = vec![3, 7, 11, 15].into_iter().map(Step).collect::<Vec<_>>();
    assert_eq!(steps(&head), want_full);
    assert_eq!(rsv.seen, 16);

    // Offering step 4 preempts 12/16 of the stream (the estimate happens to be exact here).
    rsv.ctl.extend(vec![3]);
    rsv.offer(Step(4), ());
    assert_eq!(rsv.seen, 5); // had 16, preempted 12, offered 1
    rsv.commit(&mut head);
    let want_after_preemption = vec![3, 4].into_iter().map(Step).collect::<Vec<_>>();
    assert_eq!(steps(&head), want_after_preemption);

    // The next record rolls a losing destination despite the free slots, so only the
    // keep-last rule retains it.
    rsv.ctl.extend(vec![4]);
    rsv.offer(Step(5), ());
    rsv.commit(&mut head);
    let want_keep_last = vec![3, 5].into_iter().map(Step).collect::<Vec<_>>();
    assert_eq!(steps(&head), want_keep_last); // kept last only
}
389524

390525
#[test]
@@ -413,7 +548,10 @@ mod tests {
413548
fn test_empty() {
414549
let mut rsv = StageReservoir::new(0);
415550
let mut head = Basin::new();
416-
for i in 0..100 {
551+
for mut i in 0..100 {
552+
if i > 60 {
553+
i -= 20; // "preemption", but not really, since no records
554+
}
417555
rsv.offer(Step(i), ());
418556
assert_eq!(rsv.staged_items(), &[]);
419557
if i % 5 == 0 {
@@ -422,4 +560,47 @@ mod tests {
422560
}
423561
}
424562
}
563+
564+
/// Checks that preempting a reservoir all the way back to its first-read record resets `seen`
/// to exactly zero, so the following `capacity - 1` records are admitted unconditionally.
///
/// A `seen` estimate that rounded differently could break this, which would be confusing for
/// training jobs that get preempted before reaching their first checkpoint.
#[test]
fn test_preempt_to_start() {
    // Keep the seed on the line right after this one: `line!() + 1` must point at it.
    let (seed_line, seed_file) = (line!() + 1, file!()); // for error message below
    let rng = ChaCha20Rng::seed_from_u64(0);
    let mut head = Basin::new();
    let mut rsv = StageReservoir::with_control(10, rng);

    (0..10_000).for_each(|i| rsv.offer(Step(i), ()));
    rsv.commit(&mut head);
    assert_eq!(head.as_slice().len(), 10);

    // The test is only meaningful if step 0 was *not* sampled. If it was, fail loudly so that
    // someone picks a different seed.
    let earliest_step = head.as_slice().first().expect("head empty after commit").0;
    assert_ne!(
        earliest_step,
        Step(0),
        "step 0 happened to stay in the reservoir, which is unlikely but \
         possible (0.1% chance); if changing the RNG seed on {}:{} doesn't \
         fix this, there may be a bug (committed steps: {:?})",
        seed_file,
        seed_line,
        steps(&head),
    );

    // Preempt back to step 0.
    assert_eq!(rsv.seen, 10_000);
    rsv.offer(Step(0), ());
    assert_eq!(rsv.seen, 1); // just the newest record

    // Every further record should now be accepted.
    (1..8).for_each(|i| rsv.offer(Step(i), ()));
    rsv.commit(&mut head);
    let want = (0..8).map(Step).collect::<Vec<_>>();
    assert_eq!(steps(&head), want);
    assert_eq!(rsv.seen, 8);
}
425606
}

0 commit comments

Comments
 (0)