Skip to content

Commit 7b37c15

Browse files
Add methods weight, weights, and total_weight to weighted_index.rs (#1420)
1 parent 0518975 commit 7b37c15

File tree

2 files changed

+189
-0
lines changed

2 files changed

+189
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ A [separate changelog is kept for rand_core](rand_core/CHANGELOG.md).
99
You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.html) useful.
1010

1111
## [Unreleased]
12+
- Add `rand::distributions::WeightedIndex::{weight, weights, total_weight}` (#1420)
1213
- Bump the MSRV to 1.61.0
1314

1415
## [0.9.0-alpha.1] - 2024-03-18

src/distributions/weighted_index.rs

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ use core::fmt;
1515

1616
// Note that this whole module is only imported if feature="alloc" is enabled.
1717
use alloc::vec::Vec;
18+
use core::fmt::Debug;
1819

1920
#[cfg(feature = "serde1")]
2021
use serde::{Deserialize, Serialize};
@@ -243,6 +244,124 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
243244
}
244245
}
245246

247+
/// A lazy-loading iterator over the weights of a `WeightedIndex` distribution.
248+
/// This is returned by [`WeightedIndex::weights`].
249+
pub struct WeightedIndexIter<'a, X: SampleUniform + PartialOrd> {
250+
weighted_index: &'a WeightedIndex<X>,
251+
index: usize,
252+
}
253+
254+
impl<'a, X> Debug for WeightedIndexIter<'a, X>
255+
where
256+
X: SampleUniform + PartialOrd + Debug,
257+
X::Sampler: Debug,
258+
{
259+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
260+
f.debug_struct("WeightedIndexIter")
261+
.field("weighted_index", &self.weighted_index)
262+
.field("index", &self.index)
263+
.finish()
264+
}
265+
}
266+
267+
impl<'a, X> Clone for WeightedIndexIter<'a, X>
268+
where
269+
X: SampleUniform + PartialOrd,
270+
{
271+
fn clone(&self) -> Self {
272+
WeightedIndexIter {
273+
weighted_index: self.weighted_index,
274+
index: self.index,
275+
}
276+
}
277+
}
278+
279+
impl<'a, X> Iterator for WeightedIndexIter<'a, X>
280+
where
281+
X: for<'b> ::core::ops::SubAssign<&'b X>
282+
+ SampleUniform
283+
+ PartialOrd
284+
+ Clone,
285+
{
286+
type Item = X;
287+
288+
fn next(&mut self) -> Option<Self::Item> {
289+
match self.weighted_index.weight(self.index) {
290+
None => None,
291+
Some(weight) => {
292+
self.index += 1;
293+
Some(weight)
294+
}
295+
}
296+
}
297+
}
298+
299+
impl<X: SampleUniform + PartialOrd + Clone> WeightedIndex<X> {
300+
/// Returns the weight at the given index, if it exists.
301+
///
302+
/// If the index is out of bounds, this will return `None`.
303+
///
304+
/// # Example
305+
///
306+
/// ```
307+
/// use rand::distributions::WeightedIndex;
308+
///
309+
/// let weights = [0, 1, 2];
310+
/// let dist = WeightedIndex::new(&weights).unwrap();
311+
/// assert_eq!(dist.weight(0), Some(0));
312+
/// assert_eq!(dist.weight(1), Some(1));
313+
/// assert_eq!(dist.weight(2), Some(2));
314+
/// assert_eq!(dist.weight(3), None);
315+
/// ```
316+
pub fn weight(&self, index: usize) -> Option<X>
317+
where
318+
X: for<'a> ::core::ops::SubAssign<&'a X>
319+
{
320+
let mut weight = if index < self.cumulative_weights.len() {
321+
self.cumulative_weights[index].clone()
322+
} else if index == self.cumulative_weights.len() {
323+
self.total_weight.clone()
324+
} else {
325+
return None;
326+
};
327+
if index > 0 {
328+
weight -= &self.cumulative_weights[index - 1];
329+
}
330+
Some(weight)
331+
}
332+
333+
/// Returns a lazy-loading iterator containing the current weights of this distribution.
334+
///
335+
/// If this distribution has not been updated since its creation, this will return the
336+
/// same weights as were passed to `new`.
337+
///
338+
/// # Example
339+
///
340+
/// ```
341+
/// use rand::distributions::WeightedIndex;
342+
///
343+
/// let weights = [1, 2, 3];
344+
/// let mut dist = WeightedIndex::new(&weights).unwrap();
345+
/// assert_eq!(dist.weights().collect::<Vec<_>>(), vec![1, 2, 3]);
346+
/// dist.update_weights(&[(0, &2)]).unwrap();
347+
/// assert_eq!(dist.weights().collect::<Vec<_>>(), vec![2, 2, 3]);
348+
/// ```
349+
pub fn weights(&self) -> WeightedIndexIter<'_, X>
350+
where
351+
X: for<'a> ::core::ops::SubAssign<&'a X>
352+
{
353+
WeightedIndexIter {
354+
weighted_index: self,
355+
index: 0,
356+
}
357+
}
358+
359+
/// Returns the sum of all weights in this distribution.
360+
pub fn total_weight(&self) -> X {
361+
self.total_weight.clone()
362+
}
363+
}
364+
246365
impl<X> Distribution<usize> for WeightedIndex<X>
247366
where
248367
X: SampleUniform + PartialOrd,
@@ -458,6 +577,75 @@ mod test {
458577
}
459578
}
460579

580+
#[test]
581+
fn test_update_weights_errors() {
582+
let data = [
583+
(
584+
&[1i32, 0, 0][..],
585+
&[(0, &0)][..],
586+
WeightError::InsufficientNonZero,
587+
),
588+
(
589+
&[10, 10, 10, 10][..],
590+
&[(1, &-11)][..],
591+
WeightError::InvalidWeight, // A weight is negative
592+
),
593+
(
594+
&[1, 2, 3, 4, 5][..],
595+
&[(1, &5), (0, &5)][..], // Wrong order
596+
WeightError::InvalidInput,
597+
),
598+
(
599+
&[1][..],
600+
&[(1, &1)][..], // Index too large
601+
WeightError::InvalidInput,
602+
),
603+
];
604+
605+
for (weights, update, err) in data.iter() {
606+
let total_weight = weights.iter().sum::<i32>();
607+
let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
608+
assert_eq!(distr.total_weight, total_weight);
609+
match distr.update_weights(update) {
610+
Ok(_) => panic!("Expected update_weights to fail, but it succeeded"),
611+
Err(e) => assert_eq!(e, *err),
612+
}
613+
}
614+
}
615+
616+
#[test]
617+
fn test_weight_at() {
618+
let data = [
619+
&[1][..],
620+
&[10, 2, 3, 4][..],
621+
&[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
622+
&[u32::MAX][..],
623+
];
624+
625+
for weights in data.iter() {
626+
let distr = WeightedIndex::new(weights.to_vec()).unwrap();
627+
for (i, weight) in weights.iter().enumerate() {
628+
assert_eq!(distr.weight(i), Some(*weight));
629+
}
630+
assert_eq!(distr.weight(weights.len()), None);
631+
}
632+
}
633+
634+
#[test]
635+
fn test_weights() {
636+
let data = [
637+
&[1][..],
638+
&[10, 2, 3, 4][..],
639+
&[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
640+
&[u32::MAX][..],
641+
];
642+
643+
for weights in data.iter() {
644+
let distr = WeightedIndex::new(weights.to_vec()).unwrap();
645+
assert_eq!(distr.weights().collect::<Vec<_>>(), weights.to_vec());
646+
}
647+
}
648+
461649
#[test]
462650
fn value_stability() {
463651
fn test_samples<X: Weight + SampleUniform + PartialOrd, I>(

0 commit comments

Comments
 (0)