From 8f18d3673af6471c75770be146d24866b3f540a7 Mon Sep 17 00:00:00 2001 From: Seth Ford Date: Sat, 5 Apr 2025 09:46:48 -0400 Subject: [PATCH 1/3] Add compatibility layer for half types with rand 0.8.5 --- rust/moshi-core/src/compat/mod.rs | 3 + rust/moshi-core/src/compat/rand_half.rs | 188 ++++++++++++++++++++++++ rust/moshi-core/src/lib.rs | 3 + 3 files changed, 194 insertions(+) create mode 100644 rust/moshi-core/src/compat/mod.rs create mode 100644 rust/moshi-core/src/compat/rand_half.rs diff --git a/rust/moshi-core/src/compat/mod.rs b/rust/moshi-core/src/compat/mod.rs new file mode 100644 index 00000000..bd376c30 --- /dev/null +++ b/rust/moshi-core/src/compat/mod.rs @@ -0,0 +1,3 @@ +/// Module providing compatibility between dependencies +/// Specifically implements traits needed for half types to work with rand +pub mod rand_half; \ No newline at end of file diff --git a/rust/moshi-core/src/compat/rand_half.rs b/rust/moshi-core/src/compat/rand_half.rs new file mode 100644 index 00000000..f1838b9d --- /dev/null +++ b/rust/moshi-core/src/compat/rand_half.rs @@ -0,0 +1,188 @@ +//! Compatibility module for half types with rand +//! +//! This module implements the necessary traits for bf16 and f16 types +//! from the half crate to work with rand's uniform distribution. +//! +//! The implementation follows the same pattern as for f32 and f64 in rand. + +use half::{bf16, f16}; +use rand::distributions::uniform::{SampleUniform, SampleBorrow, UniformSampler}; +use rand::distributions::Uniform; +use std::ops::{Sub, Add}; + +/// Implementations for f16 +impl SampleUniform for f16 { + type Sampler = UniformF16; +} + +/// Uniform sampler for f16 +#[derive(Clone, Copy, Debug)] +pub struct UniformF16 { + low: f16, + range: f16, + // These are used by the distribution to ensure the range is covered + // properly and to enable distributions over ranges like 0..1. + scale: f16, + offset: f16, +} + +impl UniformSampler for UniformF16 { + type X = f16; + + fn new(low: B1, high: B2) -> Self + where + B1: SampleBorrow, + B2: SampleBorrow, + { + let low = *low.borrow(); + let high = *high.borrow(); + assert!(low < high, "Uniform::new called with low >= high"); + + let range = high - low; + + // Calculate offset and scale used to map from the half-open range + // [0, 1) to the target range + let scale = range; + let offset = low; + + Self { + low, + range, + scale, + offset, + } + } + + fn new_inclusive(low: B1, high: B2) -> Self + where + B1: SampleBorrow, + B2: SampleBorrow, + { + let low = *low.borrow(); + let high = *high.borrow(); + assert!(low <= high, "Uniform::new_inclusive called with low > high"); + + let range = high - low; + + let scale = range; + let offset = low; + + Self { + low, + range, + scale, + offset, + } + } + + fn sample(&self, rng: &mut R) -> Self::X { + // We use the same trick that rand uses internally for float sampling: + // Generate a value in the range [0, 1) and scale to our target range + let sampler = Uniform::new(0.0f32, 1.0f32); + let f = sampler.sample(rng); + + // Scale to target range + f16::from_f32(f) * self.scale + self.offset + } +} + +/// Implementations for bf16 +impl SampleUniform for bf16 { + type Sampler = UniformBF16; +} + +/// Uniform sampler for bf16 +#[derive(Clone, Copy, Debug)] +pub struct UniformBF16 { + low: bf16, + range: bf16, + scale: bf16, + offset: bf16, +} + +impl UniformSampler for UniformBF16 { + type X = bf16; + + fn new(low: B1, high: B2) -> Self + where + B1: SampleBorrow, + B2: SampleBorrow, + { + let low = *low.borrow(); + let high = *high.borrow(); + assert!(low < high, "Uniform::new called with low >= high"); + + let range = high - low; + + let scale = range; + let offset = low; + + Self { + low, + range, + scale, + offset, + } + } + + fn new_inclusive(low: B1, high: B2) -> Self + where + B1: SampleBorrow, + B2: SampleBorrow, + { + let low = *low.borrow(); + let high = *high.borrow(); + assert!(low <= high, "Uniform::new_inclusive called with low > high"); + + let range = high - low; + + let scale = range; + let offset = low; + + Self { + low, + range, + scale, + offset, + } + } + + fn sample(&self, rng: &mut R) -> Self::X { + let sampler = Uniform::new(0.0f32, 1.0f32); + let f = sampler.sample(rng); + + bf16::from_f32(f) * self.scale + self.offset + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::distributions::Distribution; + use rand::rngs::StdRng; + use rand::SeedableRng; + + #[test] + fn test_f16_uniform() { + let seed = [42u8; 32]; + let mut rng = StdRng::from_seed(seed); + + let dist = Uniform::new(f16::from_f32(0.0), f16::from_f32(1.0)); + let val = dist.sample(&mut rng); + + assert!(val >= f16::from_f32(0.0)); + assert!(val < f16::from_f32(1.0)); + } + + #[test] + fn test_bf16_uniform() { + let seed = [42u8; 32]; + let mut rng = StdRng::from_seed(seed); + + let dist = Uniform::new(bf16::from_f32(0.0), bf16::from_f32(1.0)); + let val = dist.sample(&mut rng); + + assert!(val >= bf16::from_f32(0.0)); + assert!(val < bf16::from_f32(1.0)); + } +} \ No newline at end of file diff --git a/rust/moshi-core/src/lib.rs b/rust/moshi-core/src/lib.rs index 73fc9592..cd9067b9 100644 --- a/rust/moshi-core/src/lib.rs +++ b/rust/moshi-core/src/lib.rs @@ -21,6 +21,9 @@ pub mod tts; pub mod tts_streaming; pub mod wav; +// Add compatibility module +pub mod compat; + #[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)] pub enum NormType { RmsNorm, From 0af70785888d492750460cd8350912b80449a313 Mon Sep 17 00:00:00 2001 From: Seth Ford Date: Sat, 5 Apr 2025 09:49:07 -0400 Subject: [PATCH 2/3] Add half and rand as explicit dependencies in moshi-core --- rust/moshi-core/Cargo.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/rust/moshi-core/Cargo.toml b/rust/moshi-core/Cargo.toml index 13bae99a..fb006d5c 100644 --- a/rust/moshi-core/Cargo.toml +++ b/rust/moshi-core/Cargo.toml @@ -19,6 +19,10 @@ rayon = { workspace = true } serde = { workspace = true } tracing = { workspace = true } +# Additional dependencies for compatibility +rand = { workspace = true } +half = "2.5.0" + [features] default = [] cuda = ["candle/cuda", "candle-nn/cuda"] From eb44a0a12ca3f2cb1e1f89b0ab681860c8881ed5 Mon Sep 17 00:00:00 2001 From: Seth Ford Date: Sat, 5 Apr 2025 09:50:38 -0400 Subject: [PATCH 3/3] Use wrapper types to comply with Rust's orphan rule --- rust/moshi-core/src/compat/rand_half.rs | 173 +++++++++++++++++------- 1 file changed, 123 insertions(+), 50 deletions(-) diff --git a/rust/moshi-core/src/compat/rand_half.rs b/rust/moshi-core/src/compat/rand_half.rs index f1838b9d..b966a2ad 100644 --- a/rust/moshi-core/src/compat/rand_half.rs +++ b/rust/moshi-core/src/compat/rand_half.rs @@ -1,33 +1,62 @@ //! Compatibility module for half types with rand //! -//! This module implements the necessary traits for bf16 and f16 types -//! from the half crate to work with rand's uniform distribution. +//! This module provides wrapper types that implement the necessary traits +//! for bf16 and f16 types from the half crate to work with rand's uniform distribution. //! -//! The implementation follows the same pattern as for f32 and f64 in rand. +//! Since we can't directly implement foreign traits for foreign types due to Rust's orphan rules, +//! we use the newtype pattern with transparent wrappers. use half::{bf16, f16}; use rand::distributions::uniform::{SampleUniform, SampleBorrow, UniformSampler}; -use rand::distributions::Uniform; -use std::ops::{Sub, Add}; +use rand::distributions::{Distribution, Uniform}; +use rand::Rng; +use std::ops::{Deref, DerefMut}; -/// Implementations for f16 -impl SampleUniform for f16 { - type Sampler = UniformF16; +/// A wrapper around f16 that implements SampleUniform +#[derive(Clone, Copy, Debug, PartialEq)] +#[repr(transparent)] +pub struct F16Wrapper(pub f16); + +impl Deref for F16Wrapper { + type Target = f16; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for F16Wrapper { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } } -/// Uniform sampler for f16 +impl From for F16Wrapper { + fn from(v: f16) -> Self { + F16Wrapper(v) + } +} + +impl From for f16 { + fn from(v: F16Wrapper) -> Self { + v.0 + } +} + +impl SampleUniform for F16Wrapper { + type Sampler = UniformF16Wrapper; +} + +/// Uniform sampler for F16Wrapper #[derive(Clone, Copy, Debug)] -pub struct UniformF16 { - low: f16, - range: f16, - // These are used by the distribution to ensure the range is covered - // properly and to enable distributions over ranges like 0..1. - scale: f16, - offset: f16, +pub struct UniformF16Wrapper { + low: F16Wrapper, + range: F16Wrapper, + scale: F16Wrapper, + offset: F16Wrapper, } -impl UniformSampler for UniformF16 { - type X = f16; +impl UniformSampler for UniformF16Wrapper { + type X = F16Wrapper; fn new(low: B1, high: B2) -> Self where @@ -36,12 +65,10 @@ impl UniformSampler for UniformF16 { { let low = *low.borrow(); let high = *high.borrow(); - assert!(low < high, "Uniform::new called with low >= high"); + assert!(low.0 < high.0, "Uniform::new called with low >= high"); - let range = high - low; + let range = F16Wrapper(high.0 - low.0); - // Calculate offset and scale used to map from the half-open range - // [0, 1) to the target range let scale = range; let offset = low; @@ -60,9 +87,9 @@ impl UniformSampler for UniformF16 { { let low = *low.borrow(); let high = *high.borrow(); - assert!(low <= high, "Uniform::new_inclusive called with low > high"); + assert!(low.0 <= high.0, "Uniform::new_inclusive called with low > high"); - let range = high - low; + let range = F16Wrapper(high.0 - low.0); let scale = range; let offset = low; @@ -75,33 +102,62 @@ impl UniformSampler for UniformF16 { } } - fn sample(&self, rng: &mut R) -> Self::X { + fn sample(&self, rng: &mut R) -> Self::X { // We use the same trick that rand uses internally for float sampling: // Generate a value in the range [0, 1) and scale to our target range let sampler = Uniform::new(0.0f32, 1.0f32); - let f = sampler.sample(rng); + let f = Distribution::sample(&sampler, rng); // Scale to target range - f16::from_f32(f) * self.scale + self.offset + F16Wrapper(f16::from_f32(f) * self.scale.0 + self.offset.0) } } -/// Implementations for bf16 -impl SampleUniform for bf16 { - type Sampler = UniformBF16; +/// A wrapper around bf16 that implements SampleUniform +#[derive(Clone, Copy, Debug, PartialEq)] +#[repr(transparent)] +pub struct BF16Wrapper(pub bf16); + +impl Deref for BF16Wrapper { + type Target = bf16; + fn deref(&self) -> &Self::Target { + &self.0 + } } -/// Uniform sampler for bf16 +impl DerefMut for BF16Wrapper { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl From for BF16Wrapper { + fn from(v: bf16) -> Self { + BF16Wrapper(v) + } +} + +impl From for bf16 { + fn from(v: BF16Wrapper) -> Self { + v.0 + } +} + +impl SampleUniform for BF16Wrapper { + type Sampler = UniformBF16Wrapper; +} + +/// Uniform sampler for BF16Wrapper #[derive(Clone, Copy, Debug)] -pub struct UniformBF16 { - low: bf16, - range: bf16, - scale: bf16, - offset: bf16, +pub struct UniformBF16Wrapper { + low: BF16Wrapper, + range: BF16Wrapper, + scale: BF16Wrapper, + offset: BF16Wrapper, } -impl UniformSampler for UniformBF16 { - type X = bf16; +impl UniformSampler for UniformBF16Wrapper { + type X = BF16Wrapper; fn new(low: B1, high: B2) -> Self where @@ -110,9 +166,9 @@ impl UniformSampler for UniformBF16 { { let low = *low.borrow(); let high = *high.borrow(); - assert!(low < high, "Uniform::new called with low >= high"); + assert!(low.0 < high.0, "Uniform::new called with low >= high"); - let range = high - low; + let range = BF16Wrapper(high.0 - low.0); let scale = range; let offset = low; @@ -132,9 +188,9 @@ impl UniformSampler for UniformBF16 { { let low = *low.borrow(); let high = *high.borrow(); - assert!(low <= high, "Uniform::new_inclusive called with low > high"); + assert!(low.0 <= high.0, "Uniform::new_inclusive called with low > high"); - let range = high - low; + let range = BF16Wrapper(high.0 - low.0); let scale = range; let offset = low; @@ -147,18 +203,33 @@ impl UniformSampler for UniformBF16 { } } - fn sample(&self, rng: &mut R) -> Self::X { + fn sample(&self, rng: &mut R) -> Self::X { let sampler = Uniform::new(0.0f32, 1.0f32); - let f = sampler.sample(rng); + let f = Distribution::sample(&sampler, rng); - bf16::from_f32(f) * self.scale + self.offset + BF16Wrapper(bf16::from_f32(f) * self.scale.0 + self.offset.0) + } +} + +// Extension traits to make using with Uniform more ergonomic +pub trait UniformHalfExt { + fn uniform_f16(low: f16, high: f16) -> Uniform; + fn uniform_bf16(low: bf16, high: bf16) -> Uniform; +} + +impl UniformHalfExt for Uniform { + fn uniform_f16(low: f16, high: f16) -> Uniform { + Uniform::new(F16Wrapper(low), F16Wrapper(high)) + } + + fn uniform_bf16(low: bf16, high: bf16) -> Uniform { + Uniform::new(BF16Wrapper(low), BF16Wrapper(high)) } } #[cfg(test)] mod tests { use super::*; - use rand::distributions::Distribution; use rand::rngs::StdRng; use rand::SeedableRng; @@ -167,8 +238,9 @@ mod tests { let seed = [42u8; 32]; let mut rng = StdRng::from_seed(seed); - let dist = Uniform::new(f16::from_f32(0.0), f16::from_f32(1.0)); - let val = dist.sample(&mut rng); + let dist = Uniform::uniform_f16(f16::from_f32(0.0), f16::from_f32(1.0)); + let wrapper = dist.sample(&mut rng); + let val: f16 = wrapper.into(); assert!(val >= f16::from_f32(0.0)); assert!(val < f16::from_f32(1.0)); @@ -179,8 +251,9 @@ mod tests { let seed = [42u8; 32]; let mut rng = StdRng::from_seed(seed); - let dist = Uniform::new(bf16::from_f32(0.0), bf16::from_f32(1.0)); - let val = dist.sample(&mut rng); + let dist = Uniform::uniform_bf16(bf16::from_f32(0.0), bf16::from_f32(1.0)); + let wrapper = dist.sample(&mut rng); + let val: bf16 = wrapper.into(); assert!(val >= bf16::from_f32(0.0)); assert!(val < bf16::from_f32(1.0));