@@ -15,6 +15,7 @@ use core::fmt;
15
15
16
16
// Note that this whole module is only imported if feature="alloc" is enabled.
17
17
use alloc:: vec:: Vec ;
18
+ use core:: fmt:: Debug ;
18
19
19
20
#[ cfg( feature = "serde1" ) ]
20
21
use serde:: { Deserialize , Serialize } ;
@@ -243,6 +244,124 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
243
244
}
244
245
}
245
246
247
+ /// A lazy-loading iterator over the weights of a `WeightedIndex` distribution.
248
+ /// This is returned by [`WeightedIndex::weights`].
249
+ pub struct WeightedIndexIter < ' a , X : SampleUniform + PartialOrd > {
250
+ weighted_index : & ' a WeightedIndex < X > ,
251
+ index : usize ,
252
+ }
253
+
254
+ impl < ' a , X > Debug for WeightedIndexIter < ' a , X >
255
+ where
256
+ X : SampleUniform + PartialOrd + Debug ,
257
+ X :: Sampler : Debug ,
258
+ {
259
+ fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
260
+ f. debug_struct ( "WeightedIndexIter" )
261
+ . field ( "weighted_index" , & self . weighted_index )
262
+ . field ( "index" , & self . index )
263
+ . finish ( )
264
+ }
265
+ }
266
+
267
+ impl < ' a , X > Clone for WeightedIndexIter < ' a , X >
268
+ where
269
+ X : SampleUniform + PartialOrd ,
270
+ {
271
+ fn clone ( & self ) -> Self {
272
+ WeightedIndexIter {
273
+ weighted_index : self . weighted_index ,
274
+ index : self . index ,
275
+ }
276
+ }
277
+ }
278
+
279
+ impl < ' a , X > Iterator for WeightedIndexIter < ' a , X >
280
+ where
281
+ X : for < ' b > :: core:: ops:: SubAssign < & ' b X >
282
+ + SampleUniform
283
+ + PartialOrd
284
+ + Clone ,
285
+ {
286
+ type Item = X ;
287
+
288
+ fn next ( & mut self ) -> Option < Self :: Item > {
289
+ match self . weighted_index . weight ( self . index ) {
290
+ None => None ,
291
+ Some ( weight) => {
292
+ self . index += 1 ;
293
+ Some ( weight)
294
+ }
295
+ }
296
+ }
297
+ }
298
+
299
+ impl < X : SampleUniform + PartialOrd + Clone > WeightedIndex < X > {
300
+ /// Returns the weight at the given index, if it exists.
301
+ ///
302
+ /// If the index is out of bounds, this will return `None`.
303
+ ///
304
+ /// # Example
305
+ ///
306
+ /// ```
307
+ /// use rand::distributions::WeightedIndex;
308
+ ///
309
+ /// let weights = [0, 1, 2];
310
+ /// let dist = WeightedIndex::new(&weights).unwrap();
311
+ /// assert_eq!(dist.weight(0), Some(0));
312
+ /// assert_eq!(dist.weight(1), Some(1));
313
+ /// assert_eq!(dist.weight(2), Some(2));
314
+ /// assert_eq!(dist.weight(3), None);
315
+ /// ```
316
+ pub fn weight ( & self , index : usize ) -> Option < X >
317
+ where
318
+ X : for < ' a > :: core:: ops:: SubAssign < & ' a X >
319
+ {
320
+ let mut weight = if index < self . cumulative_weights . len ( ) {
321
+ self . cumulative_weights [ index] . clone ( )
322
+ } else if index == self . cumulative_weights . len ( ) {
323
+ self . total_weight . clone ( )
324
+ } else {
325
+ return None ;
326
+ } ;
327
+ if index > 0 {
328
+ weight -= & self . cumulative_weights [ index - 1 ] ;
329
+ }
330
+ Some ( weight)
331
+ }
332
+
333
+ /// Returns a lazy-loading iterator containing the current weights of this distribution.
334
+ ///
335
+ /// If this distribution has not been updated since its creation, this will return the
336
+ /// same weights as were passed to `new`.
337
+ ///
338
+ /// # Example
339
+ ///
340
+ /// ```
341
+ /// use rand::distributions::WeightedIndex;
342
+ ///
343
+ /// let weights = [1, 2, 3];
344
+ /// let mut dist = WeightedIndex::new(&weights).unwrap();
345
+ /// assert_eq!(dist.weights().collect::<Vec<_>>(), vec![1, 2, 3]);
346
+ /// dist.update_weights(&[(0, &2)]).unwrap();
347
+ /// assert_eq!(dist.weights().collect::<Vec<_>>(), vec![2, 2, 3]);
348
+ /// ```
349
+ pub fn weights ( & self ) -> WeightedIndexIter < ' _ , X >
350
+ where
351
+ X : for < ' a > :: core:: ops:: SubAssign < & ' a X >
352
+ {
353
+ WeightedIndexIter {
354
+ weighted_index : self ,
355
+ index : 0 ,
356
+ }
357
+ }
358
+
359
+ /// Returns the sum of all weights in this distribution.
360
+ pub fn total_weight ( & self ) -> X {
361
+ self . total_weight . clone ( )
362
+ }
363
+ }
364
+
246
365
impl < X > Distribution < usize > for WeightedIndex < X >
247
366
where
248
367
X : SampleUniform + PartialOrd ,
@@ -458,6 +577,75 @@ mod test {
458
577
}
459
578
}
460
579
580
+ #[ test]
581
+ fn test_update_weights_errors ( ) {
582
+ let data = [
583
+ (
584
+ & [ 1i32 , 0 , 0 ] [ ..] ,
585
+ & [ ( 0 , & 0 ) ] [ ..] ,
586
+ WeightError :: InsufficientNonZero ,
587
+ ) ,
588
+ (
589
+ & [ 10 , 10 , 10 , 10 ] [ ..] ,
590
+ & [ ( 1 , & -11 ) ] [ ..] ,
591
+ WeightError :: InvalidWeight , // A weight is negative
592
+ ) ,
593
+ (
594
+ & [ 1 , 2 , 3 , 4 , 5 ] [ ..] ,
595
+ & [ ( 1 , & 5 ) , ( 0 , & 5 ) ] [ ..] , // Wrong order
596
+ WeightError :: InvalidInput ,
597
+ ) ,
598
+ (
599
+ & [ 1 ] [ ..] ,
600
+ & [ ( 1 , & 1 ) ] [ ..] , // Index too large
601
+ WeightError :: InvalidInput ,
602
+ ) ,
603
+ ] ;
604
+
605
+ for ( weights, update, err) in data. iter ( ) {
606
+ let total_weight = weights. iter ( ) . sum :: < i32 > ( ) ;
607
+ let mut distr = WeightedIndex :: new ( weights. to_vec ( ) ) . unwrap ( ) ;
608
+ assert_eq ! ( distr. total_weight, total_weight) ;
609
+ match distr. update_weights ( update) {
610
+ Ok ( _) => panic ! ( "Expected update_weights to fail, but it succeeded" ) ,
611
+ Err ( e) => assert_eq ! ( e, * err) ,
612
+ }
613
+ }
614
+ }
615
+
616
+ #[ test]
617
+ fn test_weight_at ( ) {
618
+ let data = [
619
+ & [ 1 ] [ ..] ,
620
+ & [ 10 , 2 , 3 , 4 ] [ ..] ,
621
+ & [ 1 , 2 , 3 , 0 , 5 , 6 , 7 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ] [ ..] ,
622
+ & [ u32:: MAX ] [ ..] ,
623
+ ] ;
624
+
625
+ for weights in data. iter ( ) {
626
+ let distr = WeightedIndex :: new ( weights. to_vec ( ) ) . unwrap ( ) ;
627
+ for ( i, weight) in weights. iter ( ) . enumerate ( ) {
628
+ assert_eq ! ( distr. weight( i) , Some ( * weight) ) ;
629
+ }
630
+ assert_eq ! ( distr. weight( weights. len( ) ) , None ) ;
631
+ }
632
+ }
633
+
634
+ #[ test]
635
+ fn test_weights ( ) {
636
+ let data = [
637
+ & [ 1 ] [ ..] ,
638
+ & [ 10 , 2 , 3 , 4 ] [ ..] ,
639
+ & [ 1 , 2 , 3 , 0 , 5 , 6 , 7 , 1 , 2 , 3 , 4 , 5 , 6 , 7 ] [ ..] ,
640
+ & [ u32:: MAX ] [ ..] ,
641
+ ] ;
642
+
643
+ for weights in data. iter ( ) {
644
+ let distr = WeightedIndex :: new ( weights. to_vec ( ) ) . unwrap ( ) ;
645
+ assert_eq ! ( distr. weights( ) . collect:: <Vec <_>>( ) , weights. to_vec( ) ) ;
646
+ }
647
+ }
648
+
461
649
#[ test]
462
650
fn value_stability ( ) {
463
651
fn test_samples < X : Weight + SampleUniform + PartialOrd , I > (
0 commit comments