Skip to content

Commit ca9e119

Browse files
authored
Add IndexedRandom::choose_multiple_array, index::sample_array (#1453)
* New private module rand::seq::iterator * New private module rand::seq::slice * Add index::sample_array and IndexedRandom::choose_multiple_array
1 parent ef75e56 commit ca9e119

File tree

5 files changed

+1484
-1389
lines changed

5 files changed

+1484
-1389
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.
1010

1111
## [Unreleased]
1212
- Add `rand::distributions::WeightedIndex::{weight, weights, total_weight}` (#1420)
13+
- Add `IndexedRandom::choose_multiple_array`, `index::sample_array` (#1453)
1314
- Bump the MSRV to 1.61.0
1415
- Rename `Rng::gen` to `Rng::random` to avoid conflict with the new `gen` keyword in Rust 2024 (#1435)
1516
- Move all benchmarks to new `benches` crate (#1439)

src/seq/index.rs

+53-19
Original file line numberDiff line numberDiff line change
@@ -7,35 +7,29 @@
77
// except according to those terms.
88

99
//! Low-level API for sampling indices
10-
use core::{cmp::Ordering, hash::Hash, ops::AddAssign};
11-
12-
#[cfg(feature = "alloc")]
13-
use core::slice;
14-
1510
#[cfg(feature = "alloc")]
1611
use alloc::vec::{self, Vec};
12+
use core::slice;
13+
use core::{hash::Hash, ops::AddAssign};
1714
// BTreeMap is not as fast in tests, but better than nothing.
18-
#[cfg(all(feature = "alloc", not(feature = "std")))]
19-
use alloc::collections::BTreeSet;
20-
#[cfg(feature = "std")]
21-
use std::collections::HashSet;
22-
2315
#[cfg(feature = "std")]
2416
use super::WeightError;
25-
17+
use crate::distributions::uniform::SampleUniform;
2618
#[cfg(feature = "alloc")]
27-
use crate::{
28-
distributions::{uniform::SampleUniform, Distribution, Uniform},
29-
Rng,
30-
};
31-
19+
use crate::distributions::{Distribution, Uniform};
20+
use crate::Rng;
21+
#[cfg(all(feature = "alloc", not(feature = "std")))]
22+
use alloc::collections::BTreeSet;
3223
#[cfg(feature = "serde1")]
3324
use serde::{Deserialize, Serialize};
25+
#[cfg(feature = "std")]
26+
use std::collections::HashSet;
3427

3528
/// A vector of indices.
3629
///
3730
/// Multiple internal representations are possible.
3831
#[derive(Clone, Debug)]
32+
#[cfg(feature = "alloc")]
3933
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
4034
pub enum IndexVec {
4135
#[doc(hidden)]
@@ -44,6 +38,7 @@ pub enum IndexVec {
4438
USize(Vec<usize>),
4539
}
4640

41+
#[cfg(feature = "alloc")]
4742
impl IndexVec {
4843
/// Returns the number of indices
4944
#[inline]
@@ -94,6 +89,7 @@ impl IndexVec {
9489
}
9590
}
9691

92+
#[cfg(feature = "alloc")]
9793
impl IntoIterator for IndexVec {
9894
type IntoIter = IndexVecIntoIter;
9995
type Item = usize;
@@ -108,6 +104,7 @@ impl IntoIterator for IndexVec {
108104
}
109105
}
110106

107+
#[cfg(feature = "alloc")]
111108
impl PartialEq for IndexVec {
112109
fn eq(&self, other: &IndexVec) -> bool {
113110
use self::IndexVec::*;
@@ -124,13 +121,15 @@ impl PartialEq for IndexVec {
124121
}
125122
}
126123

124+
#[cfg(feature = "alloc")]
127125
impl From<Vec<u32>> for IndexVec {
128126
#[inline]
129127
fn from(v: Vec<u32>) -> Self {
130128
IndexVec::U32(v)
131129
}
132130
}
133131

132+
#[cfg(feature = "alloc")]
134133
impl From<Vec<usize>> for IndexVec {
135134
#[inline]
136135
fn from(v: Vec<usize>) -> Self {
@@ -171,6 +170,7 @@ impl<'a> Iterator for IndexVecIter<'a> {
171170
impl<'a> ExactSizeIterator for IndexVecIter<'a> {}
172171

173172
/// Return type of `IndexVec::into_iter`.
173+
#[cfg(feature = "alloc")]
174174
#[derive(Clone, Debug)]
175175
pub enum IndexVecIntoIter {
176176
#[doc(hidden)]
@@ -179,6 +179,7 @@ pub enum IndexVecIntoIter {
179179
USize(vec::IntoIter<usize>),
180180
}
181181

182+
#[cfg(feature = "alloc")]
182183
impl Iterator for IndexVecIntoIter {
183184
type Item = usize;
184185

@@ -201,6 +202,7 @@ impl Iterator for IndexVecIntoIter {
201202
}
202203
}
203204

205+
#[cfg(feature = "alloc")]
204206
impl ExactSizeIterator for IndexVecIntoIter {}
205207

206208
/// Randomly sample exactly `amount` distinct indices from `0..length`, and
@@ -225,6 +227,7 @@ impl ExactSizeIterator for IndexVecIntoIter {}
225227
/// to adapt the internal `sample_floyd` implementation.
226228
///
227229
/// Panics if `amount > length`.
230+
#[cfg(feature = "alloc")]
228231
#[track_caller]
229232
pub fn sample<R>(rng: &mut R, length: usize, amount: usize) -> IndexVec
230233
where
@@ -267,6 +270,33 @@ where
267270
}
268271
}
269272

273+
/// Randomly sample exactly `N` distinct indices from `0..len`, and
274+
/// return them in random order (fully shuffled).
275+
///
276+
/// This is implemented via Floyd's algorithm. Time complexity is `O(N^2)`
277+
/// and memory complexity is `O(N)`.
278+
///
279+
/// Returns `None` if (and only if) `N > len`.
280+
pub fn sample_array<R, const N: usize>(rng: &mut R, len: usize) -> Option<[usize; N]>
281+
where
282+
R: Rng + ?Sized,
283+
{
284+
if N > len {
285+
return None;
286+
}
287+
288+
// Floyd's algorithm
289+
let mut indices = [0; N];
290+
for (i, j) in (len - N..len).enumerate() {
291+
let t = rng.gen_range(0..=j);
292+
if let Some(pos) = indices[0..i].iter().position(|&x| x == t) {
293+
indices[pos] = j;
294+
}
295+
indices[i] = t;
296+
}
297+
Some(indices)
298+
}
299+
270300
/// Randomly sample exactly `amount` distinct indices from `0..length`, and
271301
/// return them in an arbitrary order (there is no guarantee of shuffling or
272302
/// ordering). The weights are to be provided by the input function `weights`,
@@ -329,6 +359,8 @@ where
329359
N: UInt,
330360
IndexVec: From<Vec<N>>,
331361
{
362+
use std::cmp::Ordering;
363+
332364
if amount == N::zero() {
333365
return Ok(IndexVec::U32(Vec::new()));
334366
}
@@ -399,6 +431,7 @@ where
399431
/// The output values are fully shuffled. (Overhead is under 50%.)
400432
///
401433
/// This implementation uses `O(amount)` memory and `O(amount^2)` time.
434+
#[cfg(feature = "alloc")]
402435
fn sample_floyd<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec
403436
where
404437
R: Rng + ?Sized,
@@ -430,6 +463,7 @@ where
430463
/// performance in all cases).
431464
///
432465
/// Set-up is `O(length)` time and memory and shuffling is `O(amount)` time.
466+
#[cfg(feature = "alloc")]
433467
fn sample_inplace<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec
434468
where
435469
R: Rng + ?Sized,
@@ -495,6 +529,7 @@ impl UInt for usize {
495529
///
496530
/// This function is generic over X primarily so that results are value-stable
497531
/// over 32-bit and 64-bit platforms.
532+
#[cfg(feature = "alloc")]
498533
fn sample_rejection<X: UInt, R>(rng: &mut R, length: X, amount: X) -> IndexVec
499534
where
500535
R: Rng + ?Sized,
@@ -519,9 +554,11 @@ where
519554
IndexVec::from(indices)
520555
}
521556

557+
#[cfg(feature = "alloc")]
522558
#[cfg(test)]
523559
mod test {
524560
use super::*;
561+
use alloc::vec;
525562

526563
#[test]
527564
#[cfg(feature = "serde1")]
@@ -542,9 +579,6 @@ mod test {
542579
}
543580
}
544581

545-
#[cfg(feature = "alloc")]
546-
use alloc::vec;
547-
548582
#[test]
549583
fn test_sample_boundaries() {
550584
let mut r = crate::test::rng(404);

0 commit comments

Comments
 (0)