@@ -28,8 +28,6 @@ use rand_chacha::ChaCha20Rng;
28
28
/// every record in the stream. (One exception: when the reservoir's capacity is zero, it does not
29
29
/// retain any records, even the latest one.)
30
30
///
31
- /// **Note:** Preemption support is not yet implemented.
32
- ///
33
31
/// # Preemption
34
32
///
35
33
/// All records stored in this reservoir have a *step* and a *payload*. The step, a non-negative
@@ -93,6 +91,8 @@ pub struct StageReservoir<T, C = ChaCha20Rng> {
93
91
/// preempted from the stream is estimated linearly from the proportion of records preempted
94
92
/// from the reservoir.
95
93
///
94
+ /// It always holds that `self.len() <= self.seen`, even when `self.seen` is not exact.
95
+ ///
96
96
/// Exception: when `capacity == 0`, `seen` is always `0` as well. A reservoir with no capacity
97
97
/// is inert and has no need to track `seen`.
98
98
seen : usize ,
@@ -184,16 +184,22 @@ impl<T, C: ReservoirControl> StageReservoir<T, C> {
184
184
if self . capacity == 0 {
185
185
return ;
186
186
}
187
+ self . preempt ( step) ;
187
188
self . seen += 1 ;
188
- let dst = self . ctl . destination ( self . seen ) ;
189
-
190
- if dst >= self . capacity {
191
- // Didn't make the cut? Keep-last only.
192
- self . pop ( ) ;
193
- } else if self . len ( ) >= self . capacity {
194
- // No room? Evict the destination.
195
- // From `if`-guards, we know `dst < self.capacity <= self.len()`, so this is safe.
196
- self . remove ( dst) ;
189
+
190
+ // If we can hold every record that we've seen, we can add this record unconditionally.
191
+ // Otherwise, we need to roll a destination---even if there's available space, to avoid
192
+ // bias right after a preemption.
193
+ if self . seen > self . capacity {
194
+ let dst = self . ctl . destination ( self . seen ) ;
195
+ if dst >= self . capacity {
196
+ // Didn't make the cut? Keep-last only.
197
+ self . pop ( ) ;
198
+ } else if self . len ( ) >= self . capacity {
199
+ // No room? Evict the destination.
200
+ // From `if`-guards, we know `dst < self.capacity <= self.len()`, so this is safe.
201
+ self . remove ( dst) ;
202
+ }
197
203
}
198
204
// In any case, add to end.
199
205
self . staged_items . push ( ( step, v) ) ;
@@ -234,6 +240,62 @@ impl<T, C: ReservoirControl> StageReservoir<T, C> {
234
240
& self . staged_items [ ..]
235
241
}
236
242
243
+ /// Preempts any records whose step does not precede the given step.
244
+ fn preempt ( & mut self , step : Step ) {
245
+ let old_len = self . len ( ) ;
246
+ let staged_preempted = self
247
+ . staged_items
248
+ . iter ( )
249
+ . rev ( )
250
+ . take_while ( |( s, _) | * s >= step)
251
+ . count ( ) ;
252
+ if staged_preempted > 0 {
253
+ self . staged_items
254
+ . truncate ( self . staged_items . len ( ) - staged_preempted) ;
255
+ }
256
+ if self . staged_items . is_empty ( ) {
257
+ // Committed steps may have been preempted as well. Note: we can hit this case even if
258
+ // `staged_preempted == 0`, since `staged_items` may have been empty to begin with.
259
+ let committed_preempted = self
260
+ . committed_steps
261
+ . iter ( )
262
+ . rev ( )
263
+ . take_while ( |s| * * s >= step)
264
+ . count ( ) ;
265
+ if committed_preempted > 0 {
266
+ self . committed_steps
267
+ . truncate ( self . committed_steps . len ( ) - committed_preempted) ;
268
+ }
269
+ }
270
+ let new_len = self . len ( ) ;
271
+ if new_len == old_len {
272
+ return ; // No need to adjust `seen`.
273
+ }
274
+ // Update our estimate of `seen` assuming that the fraction of sampled-records preempted is
275
+ // the same as the fraction of seen-records preempted. Note: when preempting to or before
276
+ // the earliest-written step, `self.len()` will now be `0`, so we will reset `seen` to
277
+ // exactly `0`, as desired.
278
+ //
279
+ // This preserves that `self.len() <= self.seen` because:
280
+ //
281
+ // ```none
282
+ // old_seen >= old_len (by induction)
283
+ // old_seen * new_len >= old_len * new_len
284
+ // (old_seen * new_len) / old_len >= (old_len * new_len) / old_len
285
+ // (old_seen * new_len) / old_len >= new_len (since the above integer division is exact)
286
+ // new_seen >= new_len
287
+ // ```
288
+ let old_seen = self . seen ;
289
+ let new_seen = ( ( old_seen as u64 * new_len as u64 ) / old_len as u64 ) as usize ;
290
+ debug_assert ! (
291
+ new_seen >= new_len,
292
+ "old (len, seen) = {:?}, new (len, seen) = {:?}; wanted seen >= len" ,
293
+ ( old_len, old_seen) ,
294
+ ( new_len, new_seen) ,
295
+ ) ;
296
+ self . seen = new_seen;
297
+ }
298
+
237
299
/// Commits pending changes from this reservoir into a basin.
238
300
///
239
301
/// The basin should initially be empty and should be modified only by calls to
@@ -294,16 +356,20 @@ mod tests {
294
356
}
295
357
}
296
358
359
+ /// Extracts the steps from a basin. Convenient for tests.
360
+ fn steps < T > ( basin : & Basin < T > ) -> Vec < Step > {
361
+ basin. as_slice ( ) . iter ( ) . map ( |( s, _) | * s) . collect ( )
362
+ }
363
+
297
364
#[ test]
298
- fn test ( ) {
365
+ fn test_sampling_no_preemptions ( ) {
299
366
let mut rsv = StageReservoir :: with_control ( 7 , ScriptedControl :: new ( ) ) ;
300
367
let mut head = Basin :: new ( ) ;
301
368
fn mapper ( s : & str ) -> & str {
302
369
// leak, for test convenience
303
370
Box :: leak ( format ! ( ":{}:" , s) . into_boxed_str ( ) )
304
371
}
305
372
306
- rsv. ctl . extend ( vec ! [ 0 , 1 , 1 , 2 ] ) ;
307
373
rsv. offer ( Step ( 0 ) , "zero" ) ;
308
374
rsv. offer ( Step ( 1 ) , "one" ) ;
309
375
rsv. offer ( Step ( 2 ) , "two" ) ;
@@ -319,10 +385,10 @@ mod tests {
319
385
] ,
320
386
) ;
321
387
322
- rsv. ctl . extend ( vec ! [ 1 , 2 , 1 , 3 ] ) ;
323
388
rsv. offer ( Step ( 4 ) , "four" ) ;
324
389
rsv. offer ( Step ( 5 ) , "five" ) ;
325
390
rsv. offer ( Step ( 6 ) , "six" ) ;
391
+ rsv. ctl . extend ( vec ! [ 3 ] ) ;
326
392
rsv. offer ( Step ( 7 ) , "seven" ) ; // this one exceeds capacity, evicting index 3
327
393
rsv. commit_map ( & mut head, mapper) ;
328
394
assert_eq ! (
@@ -357,8 +423,10 @@ mod tests {
357
423
) ;
358
424
}
359
425
426
+ /// Tests some desired properties about sampling and preemption. This uses seeded RNG, but the
427
+ /// seed is arbitrary.
360
428
#[ test]
361
- fn test_random ( ) {
429
+ fn test_sampling_preemption ( ) {
362
430
// Seeded RNG, with tests for some invariants.
363
431
let mut rsv = StageReservoir :: new ( 10 ) ;
364
432
let mut head = Basin :: new ( ) ;
@@ -385,6 +453,73 @@ mod tests {
385
453
assert_eq ! ( head. as_slice( ) . len( ) , 10 ) ;
386
454
assert_eq ! ( head. as_slice( ) . last( ) , Some ( & ( Step ( i * i) , ( ) ) ) ) ;
387
455
}
456
+
457
+ // Seen 16 records, keeping 10. Preempt to invalidate records 9..=16, so that the reservoir
458
+ // must have between 2 and 8 old records before the new one is added.
459
+ rsv. offer ( Step ( 70 ) , ( ) ) ; // 8 * 8 < 70 < 9 * 9
460
+ rsv. commit ( & mut head) ;
461
+ assert ! (
462
+ ( 2 ..=9 ) . contains( & head. as_slice( ) . len( ) ) ,
463
+ "want 2 <= {} <= 9: {:?}" ,
464
+ head. as_slice( ) . len( ) ,
465
+ head
466
+ ) ;
467
+ assert ! (
468
+ head. as_slice( ) . iter( ) . all( |( s, _) | * s <= Step ( 70 ) ) ,
469
+ "want all <= 70: {:?}" ,
470
+ head
471
+ ) ;
472
+ assert_eq ! ( head. as_slice( ) . last( ) , Some ( & ( Step ( 70 ) , ( ) ) ) ) ;
473
+
474
+ // One more sanity check: add another record. The "70" preemption may or may not be
475
+ // evicted, but this new record should be the last.
476
+ rsv. offer ( Step ( 71 ) , ( ) ) ;
477
+ rsv. commit ( & mut head) ;
478
+ assert_eq ! ( head. as_slice( ) . last( ) , Some ( & ( Step ( 71 ) , ( ) ) ) ) ;
479
+ }
480
+
481
/// Checks that, once the estimated number of records seen exceeds the capacity, the reservoir
/// may still reject a record (modulo keep-last) even though it has free slots.
///
/// Accepting unconditionally would bias the sample. Picture a reservoir of capacity 1000 that
/// has been offered 1 million records at sequential steps. Preempt back to step 900_000 and
/// suppose 900/1000 slots remain filled. If the next record were added unconditionally, it
/// would be sampled with probability 1 while every other record has probability 1/1000, which
/// is not uniform.
///
/// The same idea is exercised here with smaller numbers (`n = 4`, `N = 16`).
#[test]
fn test_unbiased_even_with_space_after_preemption() {
    let mut rsv = StageReservoir::with_control(4, ScriptedControl::new());
    let mut head = Basin::new();

    // Steps 0..4 fill the reservoir to capacity.
    for i in 0..4 {
        rsv.offer(Step(i), ());
    }
    // Steps 4..8 evict, leaving [1, 3, 5, 7].
    rsv.ctl.extend(vec![0, 1, 2, 3]);
    for i in 4..8 {
        rsv.offer(Step(i), ());
    }
    // Steps 8..16 evict, leaving [3, 7, 11, 15].
    rsv.ctl.extend(vec![0, 8, 1, 9, 2, 10, 3, 11]);
    for i in 8..16 {
        rsv.offer(Step(i), ());
    }
    rsv.commit(&mut head);
    let after_fill = steps(&head);
    assert_eq!(after_fill, vec![Step(3), Step(7), Step(11), Step(15)]);
    assert_eq!(rsv.seen, 16);

    // Offering step 4 preempts 12/16 of the stream (the estimate happens to be exact here).
    rsv.ctl.extend(vec![3]);
    rsv.offer(Step(4), ());
    assert_eq!(rsv.seen, 5); // had 16, preempted 12, offered 1
    rsv.commit(&mut head);
    let after_preemption = steps(&head);
    assert_eq!(after_preemption, vec![Step(3), Step(4)]);

    // The next record triggers an eviction despite the spare capacity.
    rsv.ctl.extend(vec![4]);
    rsv.offer(Step(5), ());
    rsv.commit(&mut head);
    let after_next = steps(&head);
    assert_eq!(after_next, vec![Step(3), Step(5)]); // kept last only
}
389
524
390
525
#[ test]
@@ -413,7 +548,10 @@ mod tests {
413
548
fn test_empty ( ) {
414
549
let mut rsv = StageReservoir :: new ( 0 ) ;
415
550
let mut head = Basin :: new ( ) ;
416
- for i in 0 ..100 {
551
+ for mut i in 0 ..100 {
552
+ if i > 60 {
553
+ i -= 20 ; // "preemption", but not really, since no records
554
+ }
417
555
rsv. offer ( Step ( i) , ( ) ) ;
418
556
assert_eq ! ( rsv. staged_items( ) , & [ ] ) ;
419
557
if i % 5 == 0 {
@@ -422,4 +560,47 @@ mod tests {
422
560
}
423
561
}
424
562
}
563
+
564
/// Checks that preempting a reservoir all the way back to its first-read record resets `seen`
/// to exactly zero, so the following `capacity - 1` records are accepted unconditionally.
///
/// A `seen` estimate that rounded differently could break this, which would be confusing for
/// training jobs that get preempted before reaching their first checkpoint.
#[test]
fn test_preempt_to_start() {
    // Keep the seed line directly below: `line!() + 1` must point at it for the message.
    let (rng_line, rng_file) = (line!() + 1, file!()); // for error message below
    let rng = ChaCha20Rng::seed_from_u64(0);
    let mut rsv = StageReservoir::with_control(10, rng);
    let mut head = Basin::new();
    (0..10_000).for_each(|i| rsv.offer(Step(i), ()));
    rsv.commit(&mut head);
    assert_eq!(head.as_slice().len(), 10);

    // The test is only meaningful if step 0 was not sampled; fail loudly otherwise so that
    // whoever sees the failure knows to fix the test.
    let earliest = head.as_slice().first().expect("head empty after commit").0;
    assert_ne!(
        earliest,
        Step(0),
        "step 0 happened to stay in the reservoir, which is unlikely but \
         possible (0.1% chance); if changing the RNG seed on {}:{} doesn't \
         fix this, there may be a bug (committed steps: {:?})",
        rng_file,
        rng_line,
        steps(&head),
    );

    // Rewind to step 0.
    assert_eq!(rsv.seen, 10_000);
    rsv.offer(Step(0), ());
    assert_eq!(rsv.seen, 1); // just the newest record

    // Every subsequent point should be accepted.
    (1..8).for_each(|i| rsv.offer(Step(i), ()));
    rsv.commit(&mut head);
    let expected: Vec<Step> = (0..8).map(Step).collect();
    assert_eq!(steps(&head), expected);
    assert_eq!(rsv.seen, 8);
}
425
606
}
0 commit comments