Skip to content
This repository was archived by the owner on Apr 28, 2025. It is now read-only.

Commit 4451a20

Browse files
committed
WIP f16 fma
1 parent 0f6b1bb commit 4451a20

File tree

15 files changed

+193
-21
lines changed

15 files changed

+193
-21
lines changed

crates/libm-macros/src/shared.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,13 @@ const ALL_OPERATIONS_NESTED: &[(FloatTy, Signature, Option<Signature>, &[&str])]
9292
None,
9393
&["copysignf128"],
9494
),
95+
(
96+
// `(f16, f16, f16) -> f16`
97+
FloatTy::F16,
98+
Signature { args: &[Ty::F16, Ty::F16, Ty::F16], returns: &[Ty::F16] },
99+
None,
100+
&["fmaf16"],
101+
),
95102
(
96103
// `(f32, f32, f32) -> f32`
97104
FloatTy::F32,

crates/libm-test/src/f8_impl.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ pub struct f8(u8);
2020
impl Float for f8 {
2121
type Int = u8;
2222
type SignedInt = i8;
23-
type ExpInt = i8;
2423

2524
const ZERO: Self = Self(0b0_0000_000);
2625
const NEG_ZERO: Self = Self(0b1_0000_000);
@@ -62,8 +61,8 @@ impl Float for f8 {
6261
self.0 & Self::SIGN_MASK != 0
6362
}
6463

65-
fn exp(self) -> Self::ExpInt {
66-
unimplemented!()
64+
fn exp(self) -> i32 {
65+
((self.to_bits() & Self::EXP_MASK) >> Self::SIG_BITS) as i32
6766
}
6867

6968
fn from_bits(a: Self::Int) -> Self {

crates/libm-test/src/mpfloat.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ libm_macros::for_each_function! {
147147
expm1 | expm1f => exp_m1,
148148
fabs | fabsf => abs,
149149
fdim | fdimf => positive_diff,
150-
fma | fmaf => mul_add,
150+
fma | fmaf | fmaf16 => mul_add,
151151
fmax | fmaxf => max,
152152
fmin | fminf => min,
153153
lgamma | lgammaf => ln_gamma,

crates/libm-test/src/precision.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,12 @@ fn bessel_prec_dropoff<F: Float>(
485485
None
486486
}
487487

488+
#[cfg(f16_enabled)]
489+
impl MaybeOverride<(f16, f16, f16)> for SpecialCase {}
488490
impl MaybeOverride<(f32, f32, f32)> for SpecialCase {}
489491
impl MaybeOverride<(f64, f64, f64)> for SpecialCase {}
492+
#[cfg(f128_enabled)]
493+
impl MaybeOverride<(f128, f128, f128)> for SpecialCase {}
494+
490495
impl MaybeOverride<(f32, i32)> for SpecialCase {}
491496
impl MaybeOverride<(f64, i32)> for SpecialCase {}

crates/libm-test/tests/multiprecision.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ libm_macros::for_each_function! {
122122
fdimf,
123123
fma,
124124
fmaf,
125+
fmaf16,
125126
fmax,
126127
fmaxf,
127128
fmin,

etc/function-definitions.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,12 @@
328328
],
329329
"type": "f32"
330330
},
331+
"fmaf16": {
332+
"sources": [
333+
"src/math/fmaf16.rs"
334+
],
335+
"type": "f16"
336+
},
331337
"fmax": {
332338
"sources": [
333339
"src/libm_helper.rs",

etc/function-list.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ floor
4747
floorf
4848
fma
4949
fmaf
50+
fmaf16
5051
fmax
5152
fmaxf
5253
fmin

src/math/fmaf.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ use super::fenv::{
4747
/// according to the rounding mode characterized by the value of FLT_ROUNDS.
4848
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
4949
pub fn fmaf(x: f32, y: f32, mut z: f32) -> f32 {
50+
if true {
51+
return super::generic::fma_big::<f32, f64>(x, y, z);
52+
}
53+
5054
let xy: f64;
5155
let mut result: f64;
5256
let mut ui: u64;

src/math/fmaf16.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
2+
pub fn fmaf16(x: f16, y: f16, z: f16) -> f16 {
3+
super::generic::fma_big::<f16, f32>(x, y, z)
4+
}

src/math/generic/fma.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
use super::super::fenv::{
2+
FE_INEXACT, FE_TONEAREST, FE_UNDERFLOW, feclearexcept, fegetround, feraiseexcept, fetestexcept,
3+
};
4+
use super::super::{CastFrom, CastInto, DFloat, Float, HFloat, IntTy, MinInt};
5+
6+
/// FMA implementation when a hardware-backed larger float type is available.
7+
pub fn fma_big<F, B>(x: F, y: F, z: F) -> F
8+
where
9+
F: Float + HFloat<D = B>,
10+
B: Float + DFloat<H = F>,
11+
// F: Float + CastInto<B>,
12+
// B: Float + CastInto<F> + CastFrom<F>,
13+
B::Int: CastInto<i32>,
14+
i32: CastFrom<i32>,
15+
{
16+
let one = IntTy::<B>::ONE;
17+
18+
let xy: B;
19+
let mut result: B;
20+
let mut ui: B::Int;
21+
let e: i32;
22+
23+
xy = x.widen() * y.widen();
24+
result = xy + z.widen();
25+
ui = result.to_bits();
26+
e = i32::cast_from(ui >> F::SIG_BITS) & F::EXP_MAX as i32;
27+
let zb: B = z.widen();
28+
29+
let prec_diff = B::SIG_BITS - F::SIG_BITS;
30+
let excess_prec = ui & ((one << prec_diff) - one);
31+
let x = one << (prec_diff - 1);
32+
33+
// Common case: the larger precision is fine
34+
if excess_prec != x
35+
|| e == i32::cast_from(F::EXP_MAX)
36+
|| (result - xy == zb && result - zb == xy)
37+
|| fegetround() != FE_TONEAREST
38+
{
39+
// TODO: feclearexcept
40+
41+
return result.narrow();
42+
}
43+
44+
let neg = ui & B::SIGN_MASK > IntTy::<B>::ZERO;
45+
let err = if neg == (zb > xy) { xy - result + zb } else { zb - result + xy };
46+
if neg == (err < B::ZERO) {
47+
ui += one;
48+
} else {
49+
ui -= one;
50+
}
51+
52+
B::from_bits(ui).narrow()
53+
}

0 commit comments

Comments
 (0)