Skip to content

Commit cc5e826

Browse files
sayantnAmanieu
authored andcommitted
Implemented the missing AVX512BF16 intrinsics
1 parent a896f8d commit cc5e826

File tree

3 files changed

+245
-16
lines changed

3 files changed

+245
-16
lines changed

crates/core_arch/missing-x86.md

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -147,27 +147,12 @@
147147
</p></details>
148148

149149

150-
<details><summary>["AVX512_BF16", "AVX512F"]</summary><p>
151-
152-
* [ ] [`_mm512_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm512_cvtpbh_ps)
153-
* [ ] [`_mm512_mask_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm512_mask_cvtpbh_ps)
154-
* [ ] [`_mm512_maskz_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm512_maskz_cvtpbh_ps)
155-
* [ ] [`_mm_cvtsbh_ss`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtsbh_ss)
156-
</p></details>
157-
158-
159150
<details><summary>["AVX512_BF16", "AVX512VL"]</summary><p>
160151

161-
* [ ] [`_mm256_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_cvtpbh_ps)
162-
* [ ] [`_mm256_mask_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_mask_cvtpbh_ps)
163-
* [ ] [`_mm256_maskz_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_maskz_cvtpbh_ps)
164152
* [ ] [`_mm_cvtneps_pbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtneps_pbh)
165153
* [ ] [`_mm_cvtness_sbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtness_sbh)
166-
* [ ] [`_mm_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtpbh_ps)
167154
* [ ] [`_mm_mask_cvtneps_pbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_mask_cvtneps_pbh)
168-
* [ ] [`_mm_mask_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_mask_cvtpbh_ps)
169155
* [ ] [`_mm_maskz_cvtneps_pbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_maskz_cvtneps_pbh)
170-
* [ ] [`_mm_maskz_cvtpbh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_maskz_cvtpbh_ps)
171156
</p></details>
172157

173158

crates/core_arch/src/x86/avx512bf16.rs

Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,131 @@ pub unsafe fn _mm512_maskz_dpbf16_ps(
365365
transmute(simd_select_bitmask(k, rst, zero))
366366
}
367367

368+
/// Converts packed BF16 (16-bit) floating-point elements in a to packed single-precision (32-bit)
369+
/// floating-point elements, and store the results in dst.
370+
///
371+
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm512_cvtpbh_ps)
372+
#[inline]
373+
#[target_feature(enable = "avx512bf16,avx512f")]
374+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
375+
pub unsafe fn _mm512_cvtpbh_ps(a: __m256bh) -> __m512 {
376+
_mm512_castsi512_ps(_mm512_slli_epi32::<16>(_mm512_cvtepi16_epi32(transmute(a))))
377+
}
378+
379+
/// Converts packed BF16 (16-bit) floating-point elements in a to packed single-precision (32-bit)
380+
/// floating-point elements, and store the results in dst using writemask k (elements are copied
381+
/// from src when the corresponding mask bit is not set).
382+
///
383+
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm512_mask_cvtpbh_ps)
384+
#[inline]
385+
#[target_feature(enable = "avx512bf16,avx512f")]
386+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
387+
pub unsafe fn _mm512_mask_cvtpbh_ps(src: __m512, k: __mmask16, a: __m256bh) -> __m512 {
388+
let cvt = _mm512_cvtpbh_ps(a);
389+
transmute(simd_select_bitmask(k, cvt.as_f32x16(), src.as_f32x16()))
390+
}
391+
392+
/// Converts packed BF16 (16-bit) floating-point elements in a to packed single-precision (32-bit)
393+
/// floating-point elements, and store the results in dst using zeromask k (elements are zeroed out
394+
/// when the corresponding mask bit is not set).
395+
///
396+
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm512_maskz_cvtpbh_ps)
397+
#[inline]
398+
#[target_feature(enable = "avx512bf16,avx512f")]
399+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
400+
pub unsafe fn _mm512_maskz_cvtpbh_ps(k: __mmask16, a: __m256bh) -> __m512 {
401+
let cvt = _mm512_cvtpbh_ps(a);
402+
let zero = _mm512_setzero_ps();
403+
transmute(simd_select_bitmask(k, cvt.as_f32x16(), zero.as_f32x16()))
404+
}
405+
406+
/// Converts packed BF16 (16-bit) floating-point elements in a to packed single-precision (32-bit)
407+
/// floating-point elements, and store the results in dst.
408+
///
409+
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_cvtpbh_ps)
410+
#[inline]
411+
#[target_feature(enable = "avx512bf16,avx512vl")]
412+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
413+
pub unsafe fn _mm256_cvtpbh_ps(a: __m128bh) -> __m256 {
414+
_mm256_castsi256_ps(_mm256_slli_epi32::<16>(_mm256_cvtepi16_epi32(transmute(a))))
415+
}
416+
417+
/// Converts packed BF16 (16-bit) floating-point elements in a to packed single-precision (32-bit)
418+
/// floating-point elements, and store the results in dst using writemask k (elements are copied
419+
/// from src when the corresponding mask bit is not set).
420+
///
421+
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_mask_cvtpbh_ps)
422+
#[inline]
423+
#[target_feature(enable = "avx512bf16,avx512vl")]
424+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
425+
pub unsafe fn _mm256_mask_cvtpbh_ps(src: __m256, k: __mmask8, a: __m128bh) -> __m256 {
426+
let cvt = _mm256_cvtpbh_ps(a);
427+
transmute(simd_select_bitmask(k, cvt.as_f32x8(), src.as_f32x8()))
428+
}
429+
430+
/// Converts packed BF16 (16-bit) floating-point elements in a to packed single-precision (32-bit)
431+
/// floating-point elements, and store the results in dst using zeromask k (elements are zeroed out
432+
/// when the corresponding mask bit is not set).
433+
///
434+
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_maskz_cvtpbh_ps)
435+
#[inline]
436+
#[target_feature(enable = "avx512bf16,avx512vl")]
437+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
438+
pub unsafe fn _mm256_maskz_cvtpbh_ps(k: __mmask8, a: __m128bh) -> __m256 {
439+
let cvt = _mm256_cvtpbh_ps(a);
440+
let zero = _mm256_setzero_ps();
441+
transmute(simd_select_bitmask(k, cvt.as_f32x8(), zero.as_f32x8()))
442+
}
443+
444+
/// Converts packed BF16 (16-bit) floating-point elements in a to single-precision (32-bit) floating-point
445+
/// elements, and store the results in dst.
446+
///
447+
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtpbh_ps)
448+
#[inline]
449+
#[target_feature(enable = "avx512bf16,avx512vl")]
450+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
451+
pub unsafe fn _mm_cvtpbh_ps(a: __m128bh) -> __m128 {
452+
_mm_castsi128_ps(_mm_slli_epi32::<16>(_mm_cvtepi16_epi32(transmute(a))))
453+
}
454+
455+
/// Converts packed BF16 (16-bit) floating-point elements in a to single-precision (32-bit) floating-point
456+
/// elements, and store the results in dst using writemask k (elements are copied from src when the corresponding
457+
/// mask bit is not set).
458+
///
459+
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_mask_cvtpbh_ps)
460+
#[inline]
461+
#[target_feature(enable = "avx512bf16,avx512vl")]
462+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
463+
pub unsafe fn _mm_mask_cvtpbh_ps(src: __m128, k: __mmask8, a: __m128bh) -> __m128 {
464+
let cvt = _mm_cvtpbh_ps(a);
465+
transmute(simd_select_bitmask(k, cvt.as_f32x4(), src.as_f32x4()))
466+
}
467+
468+
/// Converts packed BF16 (16-bit) floating-point elements in a to single-precision (32-bit) floating-point
469+
/// elements, and store the results in dst using zeromask k (elements are zeroed out when the corresponding
470+
/// mask bit is not set).
471+
///
472+
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_maskz_cvtpbh_ps)
473+
#[inline]
474+
#[target_feature(enable = "avx512bf16,avx512vl")]
475+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
476+
pub unsafe fn _mm_maskz_cvtpbh_ps(k: __mmask8, a: __m128bh) -> __m128 {
477+
let cvt = _mm_cvtpbh_ps(a);
478+
let zero = _mm_setzero_ps();
479+
transmute(simd_select_bitmask(k, cvt.as_f32x4(), zero.as_f32x4()))
480+
}
481+
482+
/// Converts a single BF16 (16-bit) floating-point element in a to a single-precision (32-bit) floating-point
483+
/// element, and store the result in dst.
484+
///
485+
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtsbh_ss)
486+
#[inline]
487+
#[target_feature(enable = "avx512bf16,avx512f")]
488+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
489+
pub unsafe fn _mm_cvtsbh_ss(a: u16) -> f32 {
490+
f32::from_bits((a as u32) << 16)
491+
}
492+
368493
#[cfg(test)]
369494
mod tests {
370495
use crate::{core_arch::x86::*, mem::transmute};
@@ -1592,4 +1717,123 @@ mod tests {
15921717
];
15931718
assert_eq!(result, expected_result);
15941719
}
1720+
1721+
const BF16_ONE: u16 = 0b0_01111111_0000000;
1722+
const BF16_TWO: u16 = 0b0_10000000_0000000;
1723+
const BF16_THREE: u16 = 0b0_10000000_1000000;
1724+
const BF16_FOUR: u16 = 0b0_10000001_0000000;
1725+
const BF16_FIVE: u16 = 0b0_10000001_0100000;
1726+
const BF16_SIX: u16 = 0b0_10000001_1000000;
1727+
const BF16_SEVEN: u16 = 0b0_10000001_1100000;
1728+
const BF16_EIGHT: u16 = 0b0_10000010_0000000;
1729+
1730+
#[simd_test(enable = "avx512bf16")]
1731+
unsafe fn test_mm512_cvtpbh_ps() {
1732+
let a = __m256bh(
1733+
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
1734+
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
1735+
);
1736+
let r = _mm512_cvtpbh_ps(a);
1737+
let e = _mm512_setr_ps(
1738+
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0,
1739+
);
1740+
assert_eq_m512(r, e);
1741+
}
1742+
1743+
#[simd_test(enable = "avx512bf16")]
1744+
unsafe fn test_mm512_mask_cvtpbh_ps() {
1745+
let a = __m256bh(
1746+
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
1747+
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
1748+
);
1749+
let src = _mm512_setr_ps(
1750+
9., 10., 11., 12., 13., 14., 15., 16., 9., 10., 11., 12., 13., 14., 15., 16.,
1751+
);
1752+
let k = 0b1010_1010_1010_1010;
1753+
let r = _mm512_mask_cvtpbh_ps(src, k, a);
1754+
let e = _mm512_setr_ps(
1755+
9., 2., 11., 4., 13., 6., 15., 8., 9., 2., 11., 4., 13., 6., 15., 8.,
1756+
);
1757+
assert_eq_m512(r, e);
1758+
}
1759+
1760+
#[simd_test(enable = "avx512bf16")]
1761+
unsafe fn test_mm512_maskz_cvtpbh_ps() {
1762+
let a = __m256bh(
1763+
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
1764+
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
1765+
);
1766+
let k = 0b1010_1010_1010_1010;
1767+
let r = _mm512_maskz_cvtpbh_ps(k, a);
1768+
let e = _mm512_setr_ps(
1769+
0., 2., 0., 4., 0., 6., 0., 8., 0., 2., 0., 4., 0., 6., 0., 8.,
1770+
);
1771+
assert_eq_m512(r, e);
1772+
}
1773+
1774+
#[simd_test(enable = "avx512bf16,avx512vl")]
1775+
unsafe fn test_mm256_cvtpbh_ps() {
1776+
let a = __m128bh(
1777+
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
1778+
);
1779+
let r = _mm256_cvtpbh_ps(a);
1780+
let e = _mm256_setr_ps(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0);
1781+
assert_eq_m256(r, e);
1782+
}
1783+
1784+
#[simd_test(enable = "avx512bf16,avx512vl")]
1785+
unsafe fn test_mm256_mask_cvtpbh_ps() {
1786+
let a = __m128bh(
1787+
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
1788+
);
1789+
let src = _mm256_setr_ps(9., 10., 11., 12., 13., 14., 15., 16.);
1790+
let k = 0b1010_1010;
1791+
let r = _mm256_mask_cvtpbh_ps(src, k, a);
1792+
let e = _mm256_setr_ps(9., 2., 11., 4., 13., 6., 15., 8.);
1793+
assert_eq_m256(r, e);
1794+
}
1795+
1796+
#[simd_test(enable = "avx512bf16,avx512vl")]
1797+
unsafe fn test_mm256_maskz_cvtpbh_ps() {
1798+
let a = __m128bh(
1799+
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
1800+
);
1801+
let k = 0b1010_1010;
1802+
let r = _mm256_maskz_cvtpbh_ps(k, a);
1803+
let e = _mm256_setr_ps(0., 2., 0., 4., 0., 6., 0., 8.);
1804+
assert_eq_m256(r, e);
1805+
}
1806+
1807+
#[simd_test(enable = "avx512bf16,avx512vl")]
1808+
unsafe fn test_mm_cvtpbh_ps() {
1809+
let a = __m128bh(BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, 0, 0, 0, 0);
1810+
let r = _mm_cvtpbh_ps(a);
1811+
let e = _mm_setr_ps(1.0, 2.0, 3.0, 4.0);
1812+
assert_eq_m128(r, e);
1813+
}
1814+
1815+
#[simd_test(enable = "avx512bf16,avx512vl")]
1816+
unsafe fn test_mm_mask_cvtpbh_ps() {
1817+
let a = __m128bh(BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, 0, 0, 0, 0);
1818+
let src = _mm_setr_ps(9., 10., 11., 12.);
1819+
let k = 0b1010;
1820+
let r = _mm_mask_cvtpbh_ps(src, k, a);
1821+
let e = _mm_setr_ps(9., 2., 11., 4.);
1822+
assert_eq_m128(r, e);
1823+
}
1824+
1825+
#[simd_test(enable = "avx512bf16,avx512vl")]
1826+
unsafe fn test_mm_maskz_cvtpbh_ps() {
1827+
let a = __m128bh(BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, 0, 0, 0, 0);
1828+
let k = 0b1010;
1829+
let r = _mm_maskz_cvtpbh_ps(k, a);
1830+
let e = _mm_setr_ps(0., 2., 0., 4.);
1831+
assert_eq_m128(r, e);
1832+
}
1833+
1834+
#[simd_test(enable = "avx512bf16")]
1835+
unsafe fn test_mm_cvtsbh_ss() {
1836+
let r = _mm_cvtsbh_ss(BF16_ONE);
1837+
assert_eq!(r, 1.);
1838+
}
15951839
}

crates/stdarch-verify/tests/x86-intel.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ fn equate(
699699
(&Type::PrimSigned(32), "__int32" | "const int" | "int") => {}
700700
(&Type::PrimSigned(64), "__int64" | "long long") => {}
701701
(&Type::PrimUnsigned(8), "unsigned char") => {}
702-
(&Type::PrimUnsigned(16), "unsigned short") => {}
702+
(&Type::PrimUnsigned(16), "unsigned short" | "__bfloat16") => {}
703703
(
704704
&Type::PrimUnsigned(32),
705705
"unsigned __int32" | "unsigned int" | "unsigned long" | "const unsigned int",

0 commit comments

Comments
 (0)