Skip to content

Commit fe8f300

Browse files
sayantn authored and Amanieu committed
Implemented some missing functions
These intrinsics cannot be linked with LLVM because Rust lacks the `bfloat16` and `i1` types, so inline asm was the only way to implement them.
1 parent cc5e826 commit fe8f300

File tree

3 files changed

+176
-13
lines changed

3 files changed

+176
-13
lines changed

crates/core_arch/missing-x86.md

-12
Original file line numberDiff line numberDiff line change
@@ -147,15 +147,6 @@
147147
</p></details>
148148

149149

150-
<details><summary>["AVX512_BF16", "AVX512VL"]</summary><p>
151-
152-
* [ ] [`_mm_cvtneps_pbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtneps_pbh)
153-
* [ ] [`_mm_cvtness_sbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtness_sbh)
154-
* [ ] [`_mm_mask_cvtneps_pbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_mask_cvtneps_pbh)
155-
* [ ] [`_mm_maskz_cvtneps_pbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_maskz_cvtneps_pbh)
156-
</p></details>
157-
158-
159150
<details><summary>["AVX512_FP16"]</summary><p>
160151

161152
* [ ] [`_mm256_castpd_ph`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_castpd_ph)
@@ -1125,12 +1116,9 @@
11251116
* [ ] [`_mm256_bcstnesh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_bcstnesh_ps)
11261117
* [ ] [`_mm256_cvtneeph_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_cvtneeph_ps)
11271118
* [ ] [`_mm256_cvtneoph_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_cvtneoph_ps)
1128-
* [ ] [`_mm256_cvtneps_avx_pbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_cvtneps_avx_pbh)
11291119
* [ ] [`_mm_bcstnesh_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_bcstnesh_ps)
11301120
* [ ] [`_mm_cvtneeph_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtneeph_ps)
11311121
* [ ] [`_mm_cvtneoph_ps`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtneoph_ps)
1132-
* [ ] [`_mm_cvtneps_avx_pbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtneps_avx_pbh)
1133-
* [ ] [`_mm_cvtneps_pbh`](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtneps_pbh)
11341122
</p></details>
11351123

11361124

crates/core_arch/src/x86/avx512bf16.rs

+111-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
//!
33
//! [AVX512BF16 intrinsics]: https://software.intel.com/sites/landingpage/IntrinsicsGuide/#expand=1769&avx512techs=AVX512_BF16
44
5+
use crate::arch::asm;
56
use crate::core_arch::{simd::*, x86::*};
67
use crate::intrinsics::simd::*;
78

@@ -490,9 +491,85 @@ pub unsafe fn _mm_cvtsbh_ss(a: u16) -> f32 {
490491
f32::from_bits((a as u32) << 16)
491492
}
492493

494+
/// Converts packed single-precision (32-bit) floating-point elements in a to packed BF16 (16-bit)
495+
/// floating-point elements, and store the results in dst.
496+
///
497+
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtneps_pbh)
498+
#[inline]
499+
#[target_feature(enable = "avx512bf16,avx512vl,sse")]
500+
#[cfg_attr(test, assert_instr("vcvtneps2bf16"))]
501+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
502+
pub unsafe fn _mm_cvtneps_pbh(a: __m128) -> __m128bh {
503+
let mut dst: __m128bh;
504+
asm!(
505+
"vcvtneps2bf16 {dst}, {src}",
506+
dst = lateout(xmm_reg) dst,
507+
src = in(xmm_reg) a,
508+
options(pure, nomem, nostack, preserves_flags)
509+
);
510+
dst
511+
}
512+
513+
/// Converts packed single-precision (32-bit) floating-point elements in a to packed BF16 (16-bit)
514+
/// floating-point elements, and store the results in dst using writemask k (elements are copied
515+
/// from src when the corresponding mask bit is not set).
516+
///
517+
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_mask_cvtneps_pbh)
518+
#[inline]
519+
#[target_feature(enable = "avx512bf16,avx512vl,sse,avx512f")]
520+
#[cfg_attr(test, assert_instr("vcvtneps2bf16"))]
521+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
522+
pub unsafe fn _mm_mask_cvtneps_pbh(src: __m128bh, k: __mmask8, a: __m128) -> __m128bh {
523+
let mut dst = src;
524+
asm!(
525+
"vcvtneps2bf16 {dst}{{{k}}},{src}",
526+
dst = inlateout(xmm_reg) dst,
527+
src = in(xmm_reg) a,
528+
k = in(kreg) k,
529+
options(pure, nomem, nostack, preserves_flags)
530+
);
531+
dst
532+
}
533+
534+
/// Converts packed single-precision (32-bit) floating-point elements in a to packed BF16 (16-bit)
535+
/// floating-point elements, and store the results in dst using zeromask k (elements are zeroed out
536+
/// when the corresponding mask bit is not set).
537+
///
538+
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_maskz_cvtneps_pbh)
539+
#[inline]
540+
#[target_feature(enable = "avx512bf16,avx512vl,sse,avx512f")]
541+
#[cfg_attr(test, assert_instr("vcvtneps2bf16"))]
542+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
543+
pub unsafe fn _mm_maskz_cvtneps_pbh(k: __mmask8, a: __m128) -> __m128bh {
544+
let mut dst: __m128bh;
545+
asm!(
546+
"vcvtneps2bf16 {dst}{{{k}}}{{z}},{src}",
547+
dst = lateout(xmm_reg) dst,
548+
src = in(xmm_reg) a,
549+
k = in(kreg) k,
550+
options(pure, nomem, nostack, preserves_flags)
551+
);
552+
dst
553+
}
554+
555+
/// Converts a single-precision (32-bit) floating-point element in a to a BF16 (16-bit) floating-point
556+
/// element, and store the result in dst.
557+
///
558+
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtness_sbh)
559+
#[inline]
560+
#[target_feature(enable = "avx512bf16,avx512vl")]
561+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
562+
pub unsafe fn _mm_cvtness_sbh(a: f32) -> u16 {
563+
simd_extract!(_mm_cvtneps_pbh(_mm_set_ss(a)), 0)
564+
}
565+
493566
#[cfg(test)]
494567
mod tests {
495-
use crate::{core_arch::x86::*, mem::transmute};
568+
use crate::core_arch::simd::u16x4;
569+
use crate::{
570+
core_arch::x86::*,
571+
mem::{transmute, transmute_copy},
572+
};
496573
use stdarch_test::simd_test;
497574

498575
#[simd_test(enable = "avx512bf16,avx512vl")]
@@ -1836,4 +1913,37 @@ mod tests {
18361913
let r = _mm_cvtsbh_ss(BF16_ONE);
18371914
assert_eq!(r, 1.);
18381915
}
1916+
1917+
#[simd_test(enable = "avx512bf16,avx512vl")]
1918+
unsafe fn test_mm_cvtneps_pbh() {
1919+
let a = _mm_setr_ps(1.0, 2.0, 3.0, 4.0);
1920+
let r: u16x4 = transmute_copy(&_mm_cvtneps_pbh(a));
1921+
let e = u16x4::new(BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR);
1922+
assert_eq!(r, e);
1923+
}
1924+
1925+
#[simd_test(enable = "avx512bf16,avx512vl")]
1926+
unsafe fn test_mm_mask_cvtneps_pbh() {
1927+
let a = _mm_setr_ps(1.0, 2.0, 3.0, 4.0);
1928+
let src = __m128bh(5, 6, 7, 8, !0, !0, !0, !0);
1929+
let k = 0b1010;
1930+
let r: u16x4 = transmute_copy(&_mm_mask_cvtneps_pbh(src, k, a));
1931+
let e = u16x4::new(5, BF16_TWO, 7, BF16_FOUR);
1932+
assert_eq!(r, e);
1933+
}
1934+
1935+
#[simd_test(enable = "avx512bf16,avx512vl")]
1936+
unsafe fn test_mm_maskz_cvtneps_pbh() {
1937+
let a = _mm_setr_ps(1.0, 2.0, 3.0, 4.0);
1938+
let k = 0b1010;
1939+
let r: u16x4 = transmute_copy(&_mm_maskz_cvtneps_pbh(k, a));
1940+
let e = u16x4::new(0, BF16_TWO, 0, BF16_FOUR);
1941+
assert_eq!(r, e);
1942+
}
1943+
1944+
#[simd_test(enable = "avx512bf16,avx512vl")]
1945+
unsafe fn test_mm_cvtness_sbh() {
1946+
let r = _mm_cvtness_sbh(1.);
1947+
assert_eq!(r, BF16_ONE);
1948+
}
18391949
}

crates/core_arch/src/x86/avxneconvert.rs

+65
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use crate::arch::asm;
12
use crate::core_arch::{simd::*, x86::*};
23

34
#[cfg(test)]
@@ -95,6 +96,50 @@ pub unsafe fn _mm256_cvtneobf16_ps(a: *const __m256bh) -> __m256 {
9596
transmute(cvtneobf162ps_256(a))
9697
}
9798

99+
/// Convert packed single precision (32-bit) floating-point elements in a to packed BF16 (16-bit) floating-point
100+
/// elements, and store the results in dst.
101+
///
102+
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm_cvtneps_avx_pbh)
103+
#[inline]
104+
#[target_feature(enable = "avxneconvert,sse")]
105+
#[cfg_attr(
106+
all(test, any(target_os = "linux", target_env = "msvc")),
107+
assert_instr(vcvtneps2bf16)
108+
)]
109+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
110+
pub unsafe fn _mm_cvtneps_avx_pbh(a: __m128) -> __m128bh {
111+
let mut dst: __m128bh;
112+
asm!(
113+
"{{vex}}vcvtneps2bf16 {dst},{src}",
114+
dst = lateout(xmm_reg) dst,
115+
src = in(xmm_reg) a,
116+
options(pure, nomem, nostack, preserves_flags)
117+
);
118+
dst
119+
}
120+
121+
/// Convert packed single precision (32-bit) floating-point elements in a to packed BF16 (16-bit) floating-point
122+
/// elements, and store the results in dst.
123+
///
124+
/// [Intel's documentation](https://software.intel.com/sites/landingpage/IntrinsicsGuide/#text=_mm256_cvtneps_avx_pbh)
125+
#[inline]
126+
#[target_feature(enable = "avxneconvert,sse,avx")]
127+
#[cfg_attr(
128+
all(test, any(target_os = "linux", target_env = "msvc")),
129+
assert_instr(vcvtneps2bf16)
130+
)]
131+
#[unstable(feature = "stdarch_x86_avx512", issue = "111137")]
132+
pub unsafe fn _mm256_cvtneps_avx_pbh(a: __m256) -> __m128bh {
133+
let mut dst: __m128bh;
134+
asm!(
135+
"{{vex}}vcvtneps2bf16 {dst},{src}",
136+
dst = lateout(xmm_reg) dst,
137+
src = in(ymm_reg) a,
138+
options(pure, nomem, nostack, preserves_flags)
139+
);
140+
dst
141+
}
142+
98143
#[allow(improper_ctypes)]
99144
extern "C" {
100145
#[link_name = "llvm.x86.vbcstnebf162ps128"]
@@ -115,7 +160,9 @@ extern "C" {
115160

116161
#[cfg(test)]
117162
mod tests {
163+
use crate::core_arch::simd::{u16x4, u16x8};
118164
use crate::core_arch::x86::*;
165+
use crate::mem::transmute_copy;
119166
use std::ptr::addr_of;
120167
use stdarch_test::simd_test;
121168

@@ -185,4 +232,22 @@ mod tests {
185232
let e = _mm256_setr_ps(2., 4., 6., 8., 2., 4., 6., 8.);
186233
assert_eq_m256(r, e);
187234
}
235+
236+
#[simd_test(enable = "avxneconvert")]
237+
unsafe fn test_mm_cvtneps_avx_pbh() {
238+
let a = _mm_setr_ps(1., 2., 3., 4.);
239+
let r: u16x4 = transmute_copy(&_mm_cvtneps_avx_pbh(a));
240+
let e = u16x4::new(BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR);
241+
assert_eq!(r, e);
242+
}
243+
244+
#[simd_test(enable = "avxneconvert")]
245+
unsafe fn test_mm256_cvtneps_avx_pbh() {
246+
let a = _mm256_setr_ps(1., 2., 3., 4., 5., 6., 7., 8.);
247+
let r: u16x8 = transmute(_mm256_cvtneps_avx_pbh(a));
248+
let e = u16x8::new(
249+
BF16_ONE, BF16_TWO, BF16_THREE, BF16_FOUR, BF16_FIVE, BF16_SIX, BF16_SEVEN, BF16_EIGHT,
250+
);
251+
assert_eq!(r, e);
252+
}
188253
}

0 commit comments

Comments
 (0)