Skip to content

Commit 5794c83

Browse files
Merge pull request #401 from rust-lang/std_float_improvements
Test StdFloat
2 parents eea6f77 + 278eb28 commit 5794c83

File tree

4 files changed

+174
-116
lines changed

4 files changed

+174
-116
lines changed

Cargo.lock

+3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/std_float/Cargo.toml

+7
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@ edition = "2021"
88
[dependencies]
99
core_simd = { path = "../core_simd", default-features = false }
1010

11+
[dev-dependencies.test_helpers]
12+
path = "../test_helpers"
13+
14+
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
15+
wasm-bindgen = "0.2"
16+
wasm-bindgen-test = "0.3"
17+
1118
[features]
1219
default = ["as_crate"]
1320
as_crate = []

crates/std_float/src/lib.rs

+90-116
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#![cfg_attr(feature = "as_crate", no_std)] // We are std!
21
#![cfg_attr(
32
feature = "as_crate",
43
feature(core_intrinsics),
@@ -44,7 +43,7 @@ use crate::sealed::Sealed;
4443
/// For now this trait is available to permit experimentation with SIMD float
4544
/// operations that may lack hardware support, such as `mul_add`.
4645
pub trait StdFloat: Sealed + Sized {
47-
/// Fused multiply-add. Computes `(self * a) + b` with only one rounding error,
46+
/// Elementwise fused multiply-add. Computes `(self * a) + b` with only one rounding error,
4847
/// yielding a more accurate result than an unfused multiply-add.
4948
///
5049
/// Using `mul_add` *may* be more performant than an unfused multiply-add if the target
@@ -57,78 +56,65 @@ pub trait StdFloat: Sealed + Sized {
5756
unsafe { intrinsics::simd_fma(self, a, b) }
5857
}
5958

60-
/// Produces a vector where every lane has the square root value
61-
/// of the equivalently-indexed lane in `self`
59+
/// Produces a vector where every element has the square root value
60+
/// of the equivalently-indexed element in `self`.
6261
#[inline]
6362
#[must_use = "method returns a new vector and does not mutate the original value"]
6463
fn sqrt(self) -> Self {
6564
unsafe { intrinsics::simd_fsqrt(self) }
6665
}
6766

68-
/// Produces a vector where every lane has the sine of the value
69-
/// in the equivalently-indexed lane in `self`.
70-
#[inline]
67+
/// Produces a vector where every element has the sine of the value
68+
/// in the equivalently-indexed element in `self`.
7169
#[must_use = "method returns a new vector and does not mutate the original value"]
72-
fn sin(self) -> Self {
73-
unsafe { intrinsics::simd_fsin(self) }
74-
}
70+
fn sin(self) -> Self;
7571

76-
/// Produces a vector where every lane has the cosine of the value
77-
/// in the equivalently-indexed lane in `self`.
78-
#[inline]
72+
/// Produces a vector where every element has the cosine of the value
73+
/// in the equivalently-indexed element in `self`.
7974
#[must_use = "method returns a new vector and does not mutate the original value"]
80-
fn cos(self) -> Self {
81-
unsafe { intrinsics::simd_fcos(self) }
82-
}
75+
fn cos(self) -> Self;
8376

84-
/// Produces a vector where every lane has the exponential (base e) of the value
85-
/// in the equivalently-indexed lane in `self`.
86-
#[inline]
77+
/// Produces a vector where every element has the exponential (base e) of the value
78+
/// in the equivalently-indexed element in `self`.
8779
#[must_use = "method returns a new vector and does not mutate the original value"]
88-
fn exp(self) -> Self {
89-
unsafe { intrinsics::simd_fexp(self) }
90-
}
80+
fn exp(self) -> Self;
9181

92-
/// Produces a vector where every lane has the exponential (base 2) of the value
93-
/// in the equivalently-indexed lane in `self`.
94-
#[inline]
82+
/// Produces a vector where every element has the exponential (base 2) of the value
83+
/// in the equivalently-indexed element in `self`.
9584
#[must_use = "method returns a new vector and does not mutate the original value"]
96-
fn exp2(self) -> Self {
97-
unsafe { intrinsics::simd_fexp2(self) }
98-
}
85+
fn exp2(self) -> Self;
9986

100-
/// Produces a vector where every lane has the natural logarithm of the value
101-
/// in the equivalently-indexed lane in `self`.
102-
#[inline]
87+
/// Produces a vector where every element has the natural logarithm of the value
88+
/// in the equivalently-indexed element in `self`.
10389
#[must_use = "method returns a new vector and does not mutate the original value"]
104-
fn log(self) -> Self {
105-
unsafe { intrinsics::simd_flog(self) }
106-
}
90+
fn ln(self) -> Self;
10791

108-
/// Produces a vector where every lane has the base-2 logarithm of the value
109-
/// in the equivalently-indexed lane in `self`.
92+
/// Produces a vector where every element has the logarithm, with respect to an arbitrary base, of the values
93+
/// in the equivalently-indexed elements in `self` and `base`.
11094
#[inline]
11195
#[must_use = "method returns a new vector and does not mutate the original value"]
112-
fn log2(self) -> Self {
113-
unsafe { intrinsics::simd_flog2(self) }
96+
fn log(self, base: Self) -> Self {
97+
unsafe { intrinsics::simd_div(self.ln(), base.ln()) }
11498
}
11599

116-
/// Produces a vector where every lane has the base-10 logarithm of the value
117-
/// in the equivalently-indexed lane in `self`.
118-
#[inline]
100+
/// Produces a vector where every element has the base-2 logarithm of the value
101+
/// in the equivalently-indexed element in `self`.
119102
#[must_use = "method returns a new vector and does not mutate the original value"]
120-
fn log10(self) -> Self {
121-
unsafe { intrinsics::simd_flog10(self) }
122-
}
103+
fn log2(self) -> Self;
104+
105+
/// Produces a vector where every element has the base-10 logarithm of the value
106+
/// in the equivalently-indexed element in `self`.
107+
#[must_use = "method returns a new vector and does not mutate the original value"]
108+
fn log10(self) -> Self;
123109

124-
/// Returns the smallest integer greater than or equal to each lane.
110+
/// Returns the smallest integer greater than or equal to each element.
125111
#[must_use = "method returns a new vector and does not mutate the original value"]
126112
#[inline]
127113
fn ceil(self) -> Self {
128114
unsafe { intrinsics::simd_ceil(self) }
129115
}
130116

131-
/// Returns the largest integer value less than or equal to each lane.
117+
/// Returns the largest integer value less than or equal to each element.
132118
#[must_use = "method returns a new vector and does not mutate the original value"]
133119
#[inline]
134120
fn floor(self) -> Self {
@@ -157,77 +143,65 @@ pub trait StdFloat: Sealed + Sized {
157143
impl<const N: usize> Sealed for Simd<f32, N> where LaneCount<N>: SupportedLaneCount {}
158144
impl<const N: usize> Sealed for Simd<f64, N> where LaneCount<N>: SupportedLaneCount {}
159145

160-
// We can safely just use all the defaults.
161-
impl<const N: usize> StdFloat for Simd<f32, N>
162-
where
163-
LaneCount<N>: SupportedLaneCount,
164-
{
165-
/// Returns the floating point's fractional value, with its integer part removed.
166-
#[must_use = "method returns a new vector and does not mutate the original value"]
167-
#[inline]
168-
fn fract(self) -> Self {
169-
self - self.trunc()
146+
macro_rules! impl_float {
147+
{
148+
$($fn:ident: $intrinsic:ident,)*
149+
} => {
150+
impl<const N: usize> StdFloat for Simd<f32, N>
151+
where
152+
LaneCount<N>: SupportedLaneCount,
153+
{
154+
#[inline]
155+
fn fract(self) -> Self {
156+
self - self.trunc()
157+
}
158+
159+
$(
160+
#[inline]
161+
fn $fn(self) -> Self {
162+
unsafe { intrinsics::$intrinsic(self) }
163+
}
164+
)*
165+
}
166+
167+
impl<const N: usize> StdFloat for Simd<f64, N>
168+
where
169+
LaneCount<N>: SupportedLaneCount,
170+
{
171+
#[inline]
172+
fn fract(self) -> Self {
173+
self - self.trunc()
174+
}
175+
176+
$(
177+
#[inline]
178+
fn $fn(self) -> Self {
179+
// Scalar per-element fallback on aarch64; see https://github.com/llvm/llvm-project/issues/83729
180+
#[cfg(target_arch = "aarch64")]
181+
{
182+
let mut ln = Self::splat(0f64);
183+
for i in 0..N {
184+
ln[i] = self[i].$fn()
185+
}
186+
ln
187+
}
188+
189+
#[cfg(not(target_arch = "aarch64"))]
190+
{
191+
unsafe { intrinsics::$intrinsic(self) }
192+
}
193+
}
194+
)*
195+
}
170196
}
171197
}
172198

173-
impl<const N: usize> StdFloat for Simd<f64, N>
174-
where
175-
LaneCount<N>: SupportedLaneCount,
176-
{
177-
/// Returns the floating point's fractional value, with its integer part removed.
178-
#[must_use = "method returns a new vector and does not mutate the original value"]
179-
#[inline]
180-
fn fract(self) -> Self {
181-
self - self.trunc()
182-
}
183-
}
184-
185-
#[cfg(test)]
186-
mod tests_simd_floats {
187-
use super::*;
188-
use simd::prelude::*;
189-
190-
#[test]
191-
fn everything_works_f32() {
192-
let x = f32x4::from_array([0.1, 0.5, 0.6, -1.5]);
193-
194-
let x2 = x + x;
195-
let _xc = x.ceil();
196-
let _xf = x.floor();
197-
let _xr = x.round();
198-
let _xt = x.trunc();
199-
let _xfma = x.mul_add(x, x);
200-
let _xsqrt = x.sqrt();
201-
let _abs_mul = x2.abs() * x2;
202-
203-
let _fexp = x.exp();
204-
let _fexp2 = x.exp2();
205-
let _flog = x.log();
206-
let _flog2 = x.log2();
207-
let _flog10 = x.log10();
208-
let _fsin = x.sin();
209-
let _fcos = x.cos();
210-
}
211-
212-
#[test]
213-
fn everything_works_f64() {
214-
let x = f64x4::from_array([0.1, 0.5, 0.6, -1.5]);
215-
216-
let x2 = x + x;
217-
let _xc = x.ceil();
218-
let _xf = x.floor();
219-
let _xr = x.round();
220-
let _xt = x.trunc();
221-
let _xfma = x.mul_add(x, x);
222-
let _xsqrt = x.sqrt();
223-
let _abs_mul = x2.abs() * x2;
224-
225-
let _fexp = x.exp();
226-
let _fexp2 = x.exp2();
227-
let _flog = x.log();
228-
let _flog2 = x.log2();
229-
let _flog10 = x.log10();
230-
let _fsin = x.sin();
231-
let _fcos = x.cos();
232-
}
199+
impl_float! {
200+
sin: simd_fsin,
201+
cos: simd_fcos,
202+
exp: simd_fexp,
203+
exp2: simd_fexp2,
204+
ln: simd_flog,
205+
log2: simd_flog2,
206+
log10: simd_flog10,
233207
}

crates/std_float/tests/float.rs

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#![feature(portable_simd)]
2+
3+
macro_rules! unary_test {
4+
{ $scalar:tt, $($func:tt),+ } => {
5+
test_helpers::test_lanes! {
6+
$(
7+
fn $func<const LANES: usize>() {
8+
test_helpers::test_unary_elementwise(
9+
&core_simd::simd::Simd::<$scalar, LANES>::$func,
10+
&$scalar::$func,
11+
&|_| true,
12+
)
13+
}
14+
)*
15+
}
16+
}
17+
}
18+
19+
macro_rules! binary_test {
20+
{ $scalar:tt, $($func:tt),+ } => {
21+
test_helpers::test_lanes! {
22+
$(
23+
fn $func<const LANES: usize>() {
24+
test_helpers::test_binary_elementwise(
25+
&core_simd::simd::Simd::<$scalar, LANES>::$func,
26+
&$scalar::$func,
27+
&|_, _| true,
28+
)
29+
}
30+
)*
31+
}
32+
}
33+
}
34+
35+
macro_rules! ternary_test {
36+
{ $scalar:tt, $($func:tt),+ } => {
37+
test_helpers::test_lanes! {
38+
$(
39+
fn $func<const LANES: usize>() {
40+
test_helpers::test_ternary_elementwise(
41+
&core_simd::simd::Simd::<$scalar, LANES>::$func,
42+
&$scalar::$func,
43+
&|_, _, _| true,
44+
)
45+
}
46+
)*
47+
}
48+
}
49+
}
50+
51+
macro_rules! impl_tests {
52+
{ $scalar:tt } => {
53+
mod $scalar {
54+
use std_float::StdFloat;
55+
56+
unary_test! { $scalar, sqrt, sin, cos, exp, exp2, ln, log2, log10, ceil, floor, round, trunc }
57+
binary_test! { $scalar, log }
58+
ternary_test! { $scalar, mul_add }
59+
60+
test_helpers::test_lanes! {
61+
fn fract<const LANES: usize>() {
62+
test_helpers::test_unary_elementwise_flush_subnormals(
63+
&core_simd::simd::Simd::<$scalar, LANES>::fract,
64+
&$scalar::fract,
65+
&|_| true,
66+
)
67+
}
68+
}
69+
}
70+
}
71+
}
72+
73+
impl_tests! { f32 }
74+
impl_tests! { f64 }

0 commit comments

Comments
 (0)