Skip to content

Commit c59972e

Browse files
committed
Add BlockRng64 wrapper
1 parent 9230307 commit c59972e

File tree

1 file changed

+162
-0
lines changed

1 file changed

+162
-0
lines changed

rand-core/src/impls.rs

+162
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,168 @@ impl<R: BlockRngCore + SeedableRng> SeedableRng for BlockRng<R> {
333333
}
334334
}
335335

336+
337+
338+
/// Wrapper around PRNGs that implement [`BlockRngCore`] to keep a results
339+
/// buffer and offer the methods from [`RngCore`].
340+
///
341+
/// This is similar to [`BlockRng`], but specialized for algorithms that operate
342+
/// on `u64` values (rare). Hopefully specialization will one day allow
343+
/// `BlockRng` to support both, so that this wrapper can be removed.
344+
///
345+
/// [`BlockRngCore`]: ../BlockRngCore.t.html
346+
/// [`RngCore`]: ../RngCore.t.html
347+
/// [`BlockRng`]: .BlockRng.s.html
348+
#[derive(Clone)]
349+
pub struct BlockRng64<R: BlockRngCore + ?Sized> {
350+
pub results: R::Results,
351+
pub index: usize,
352+
pub half_used: bool, // true if only half of the previous result is used
353+
pub core: R,
354+
}
355+
356+
// Custom Debug implementation that does not expose the contents of `results`.
357+
impl<R: BlockRngCore + fmt::Debug> fmt::Debug for BlockRng64<R> {
358+
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
359+
fmt.debug_struct("BlockRng64")
360+
.field("core", &self.core)
361+
.field("result_len", &self.results.as_ref().len())
362+
.field("index", &self.index)
363+
.field("half_used", &self.half_used)
364+
.finish()
365+
}
366+
}
367+
368+
impl<R: BlockRngCore<Item=u64>> RngCore for BlockRng64<R>
369+
where <R as BlockRngCore>::Results: AsRef<[u64]>
370+
{
371+
#[inline(always)]
372+
fn next_u32(&mut self) -> u32 {
373+
let mut index = self.index * 2 - self.half_used as usize;
374+
if index >= self.results.as_ref().len() * 2 {
375+
self.core.generate(&mut self.results);
376+
self.index = 0;
377+
// `self.half_used` is by definition `false`
378+
self.half_used = false;
379+
index = 0;
380+
}
381+
382+
self.half_used = !self.half_used;
383+
self.index += self.half_used as usize;
384+
385+
// Index as if this is a u32 slice.
386+
unsafe {
387+
let results =
388+
&*(self.results.as_ref() as *const [u64] as *const [u32]);
389+
if cfg!(target_endian = "little") {
390+
*results.get_unchecked(index)
391+
} else {
392+
*results.get_unchecked(index ^ 1)
393+
}
394+
}
395+
}
396+
397+
#[inline(always)]
398+
fn next_u64(&mut self) -> u64 {
399+
if self.index >= self.results.as_ref().len() {
400+
self.core.generate(&mut self.results);
401+
self.index = 0;
402+
}
403+
404+
let value = self.results.as_ref()[self.index];
405+
self.index += 1;
406+
self.half_used = false;
407+
value
408+
}
409+
410+
// As an optimization we try to write directly into the output buffer.
411+
// This is only enabled for little-endian platforms where unaligned writes
412+
// are known to be safe and fast.
413+
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
414+
fn fill_bytes(&mut self, dest: &mut [u8]) {
415+
let mut filled = 0;
416+
417+
// Continue filling from the current set of results
418+
if self.index < self.results.as_ref().len() {
419+
let (consumed_u64, filled_u8) =
420+
fill_via_u64_chunks(&self.results.as_ref()[self.index..],
421+
dest);
422+
423+
self.index += consumed_u64;
424+
filled += filled_u8;
425+
}
426+
427+
let len_remainder =
428+
(dest.len() - filled) % (self.results.as_ref().len() * 8);
429+
let end_direct = dest.len() - len_remainder;
430+
431+
while filled < end_direct {
432+
let dest_u64: &mut R::Results = unsafe {
433+
::core::mem::transmute(dest[filled..].as_mut_ptr())
434+
};
435+
self.core.generate(dest_u64);
436+
filled += self.results.as_ref().len() * 8;
437+
}
438+
self.index = self.results.as_ref().len();
439+
440+
if len_remainder > 0 {
441+
self.core.generate(&mut self.results);
442+
let (consumed_u64, _) =
443+
fill_via_u64_chunks(&mut self.results.as_ref(),
444+
&mut dest[filled..]);
445+
446+
self.index = consumed_u64;
447+
}
448+
}
449+
450+
#[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
451+
fn fill_bytes(&mut self, dest: &mut [u8]) {
452+
let mut read_len = 0;
453+
while read_len < dest.len() {
454+
if self.index as usize >= self.results.as_ref().len() {
455+
self.core.generate(&mut self.results);
456+
self.index = 0;
457+
self.half_used = false;
458+
}
459+
460+
let (consumed_u64, filled_u8) =
461+
fill_via_u64_chunks(&self.results.as_ref()[self.index as usize..],
462+
&mut dest[read_len..]);
463+
464+
self.index += consumed_u64;
465+
read_len += filled_u8;
466+
}
467+
}
468+
469+
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> {
470+
Ok(self.fill_bytes(dest))
471+
}
472+
}
473+
474+
impl<R: BlockRngCore + SeedableRng> SeedableRng for BlockRng64<R> {
475+
type Seed = R::Seed;
476+
477+
fn from_seed(seed: Self::Seed) -> Self {
478+
let results_empty = R::Results::default();
479+
Self {
480+
core: R::from_seed(seed),
481+
index: results_empty.as_ref().len(), // generate on first use
482+
half_used: false,
483+
results: results_empty,
484+
}
485+
}
486+
487+
fn from_rng<RNG: RngCore>(rng: &mut RNG) -> Result<Self, Error> {
488+
let results_empty = R::Results::default();
489+
Ok(Self {
490+
core: R::from_rng(rng)?,
491+
index: results_empty.as_ref().len(), // generate on first use
492+
half_used: false,
493+
results: results_empty,
494+
})
495+
}
496+
}
497+
336498
impl<R: BlockRngCore + CryptoRng> CryptoRng for BlockRng<R> {}
337499

338500
// TODO: implement tests for the above

0 commit comments

Comments
 (0)