Skip to content

Commit dc00192

Browse files
committed
Add BlockRng64 wrapper
1 parent 602ff6f commit dc00192

1 file changed

Lines changed: 161 additions & 0 deletions

File tree

rand_core/src/impls.rs

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,167 @@ impl<R: BlockRngCore + SeedableRng> SeedableRng for BlockRng<R> {
335335
}
336336
}
337337

338+
339+
340+
/// Wrapper around PRNGs that implement [`BlockRngCore`] to keep a results
341+
/// buffer and offer the methods from [`RngCore`].
342+
///
343+
/// This is similar to [`BlockRng`], but specialized for algorithms that operate
344+
/// on `u64` values.
345+
///
346+
/// [`BlockRngCore`]: ../BlockRngCore.t.html
347+
/// [`RngCore`]: ../RngCore.t.html
348+
/// [`BlockRng`]: struct.BlockRng.html
349+
#[derive(Clone)]
350+
pub struct BlockRng64<R: BlockRngCore + ?Sized> {
351+
pub results: R::Results,
352+
pub index: usize,
353+
pub half_used: bool, // true if only half of the previous result is used
354+
pub core: R,
355+
}
356+
357+
// Custom Debug implementation that does not expose the contents of `results`.
358+
impl<R: BlockRngCore + fmt::Debug> fmt::Debug for BlockRng64<R> {
359+
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
360+
fmt.debug_struct("BlockRng64")
361+
.field("core", &self.core)
362+
.field("result_len", &self.results.as_ref().len())
363+
.field("index", &self.index)
364+
.field("half_used", &self.half_used)
365+
.finish()
366+
}
367+
}
368+
369+
impl<R: BlockRngCore<Item=u64>> RngCore for BlockRng64<R>
370+
where <R as BlockRngCore>::Results: AsRef<[u64]>
371+
{
372+
#[inline(always)]
373+
fn next_u32(&mut self) -> u32 {
374+
let mut index = self.index * 2 - self.half_used as usize;
375+
if index >= self.results.as_ref().len() * 2 {
376+
self.core.generate(&mut self.results);
377+
self.index = 0;
378+
// `self.half_used` is by definition `false`
379+
self.half_used = false;
380+
index = 0;
381+
}
382+
383+
self.half_used = !self.half_used;
384+
self.index += self.half_used as usize;
385+
386+
// Index as if this is a u32 slice.
387+
unsafe {
388+
let results =
389+
&*(self.results.as_ref() as *const [u64] as *const [u32]);
390+
if cfg!(target_endian = "little") {
391+
*results.get_unchecked(index)
392+
} else {
393+
*results.get_unchecked(index ^ 1)
394+
}
395+
}
396+
}
397+
398+
#[inline(always)]
399+
fn next_u64(&mut self) -> u64 {
400+
if self.index >= self.results.as_ref().len() {
401+
self.core.generate(&mut self.results);
402+
self.index = 0;
403+
}
404+
405+
let value = self.results.as_ref()[self.index];
406+
self.index += 1;
407+
self.half_used = false;
408+
value
409+
}
410+
411+
// As an optimization we try to write directly into the output buffer.
412+
// This is only enabled for little-endian platforms where unaligned writes
413+
// are known to be safe and fast.
414+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
415+
fn fill_bytes(&mut self, dest: &mut [u8]) {
416+
let mut filled = 0;
417+
418+
// Continue filling from the current set of results
419+
if self.index < self.results.as_ref().len() {
420+
let (consumed_u64, filled_u8) =
421+
fill_via_u64_chunks(&self.results.as_ref()[self.index..],
422+
dest);
423+
424+
self.index += consumed_u64;
425+
filled += filled_u8;
426+
}
427+
428+
let len_remainder =
429+
(dest.len() - filled) % (self.results.as_ref().len() * 8);
430+
let end_direct = dest.len() - len_remainder;
431+
432+
while filled < end_direct {
433+
let dest_u64: &mut R::Results = unsafe {
434+
::core::mem::transmute(dest[filled..].as_mut_ptr())
435+
};
436+
self.core.generate(dest_u64);
437+
filled += self.results.as_ref().len() * 8;
438+
}
439+
self.index = self.results.as_ref().len();
440+
441+
if len_remainder > 0 {
442+
self.core.generate(&mut self.results);
443+
let (consumed_u64, _) =
444+
fill_via_u64_chunks(&mut self.results.as_ref(),
445+
&mut dest[filled..]);
446+
447+
self.index = consumed_u64;
448+
}
449+
}
450+
451+
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
452+
fn fill_bytes(&mut self, dest: &mut [u8]) {
453+
let mut read_len = 0;
454+
while read_len < dest.len() {
455+
if self.index as usize >= self.results.as_ref().len() {
456+
self.core.generate(&mut self.results);
457+
self.index = 0;
458+
self.half_used = false;
459+
}
460+
461+
let (consumed_u64, filled_u8) =
462+
fill_via_u64_chunks(&self.results.as_ref()[self.index as usize..],
463+
&mut dest[read_len..]);
464+
465+
self.index += consumed_u64;
466+
read_len += filled_u8;
467+
}
468+
}
469+
470+
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> {
471+
Ok(self.fill_bytes(dest))
472+
}
473+
}
474+
475+
impl<R: BlockRngCore + SeedableRng> SeedableRng for BlockRng64<R> {
476+
type Seed = R::Seed;
477+
478+
fn from_seed(seed: Self::Seed) -> Self {
479+
let results_empty = R::Results::default();
480+
Self {
481+
core: R::from_seed(seed),
482+
index: results_empty.as_ref().len(), // generate on first use
483+
half_used: false,
484+
results: results_empty,
485+
}
486+
}
487+
488+
fn from_rng<S: RngCore>(rng: S) -> Result<Self, Error> {
489+
let results_empty = R::Results::default();
490+
Ok(Self {
491+
core: R::from_rng(rng)?,
492+
index: results_empty.as_ref().len(), // generate on first use
493+
half_used: false,
494+
results: results_empty,
495+
})
496+
}
497+
}
498+
338499
impl<R: BlockRngCore + CryptoRng> CryptoRng for BlockRng<R> {}
339500

340501
// TODO: implement tests for the above

0 commit comments

Comments
 (0)