Skip to content

Commit e05cfb8

Browse files
committed
feat: hand-writing avx2
Signed-off-by: usamoi <[email protected]>
1 parent c27617b commit e05cfb8

File tree

7 files changed

+102
-5
lines changed

7 files changed

+102
-5
lines changed

crates/c/.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
*.s
2+
*.o

crates/c/src/c.c

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ vectors_f16_cosine_axv512(_Float16 const *restrict a,
2525
xx = _mm512_fmadd_ph(x, x, xx);
2626
yy = _mm512_fmadd_ph(y, y, yy);
2727
}
28-
return (float)(_mm512_reduce_add_ps(xy) /
29-
sqrt(_mm512_reduce_add_ps(xx) * _mm512_reduce_add_ps(yy)));
28+
return (float)(_mm512_reduce_add_ph(xy) /
29+
sqrt(_mm512_reduce_add_ph(xx) * _mm512_reduce_add_ph(yy)));
3030
}
3131

3232
__attribute__((target("avx512fp16,bmi2"))) extern float
@@ -71,3 +71,47 @@ vectors_f16_distance_squared_l2_axv512(_Float16 const *restrict a,
7171

7272
return (float)_mm512_reduce_add_ph(dd);
7373
}
74+
75+
__attribute__((target("avx2"))) extern float
76+
vectors_f16_cosine_axv2(_Float16 const *restrict a, _Float16 const *restrict b,
77+
size_t n) {
78+
float xy = 0;
79+
float xx = 0;
80+
float yy = 0;
81+
#pragma clang loop vectorize_width(8)
82+
for (size_t i = 0; i < n; i++) {
83+
float x = a[i];
84+
float y = b[i];
85+
xy += x * y;
86+
xx += x * x;
87+
yy += y * y;
88+
}
89+
return xy / sqrt(xx * yy);
90+
}
91+
92+
__attribute__((target("avx2"))) extern float
93+
vectors_f16_dot_axv2(_Float16 const *restrict a, _Float16 const *restrict b,
94+
size_t n) {
95+
float xy = 0;
96+
#pragma clang loop vectorize_width(8)
97+
for (size_t i = 0; i < n; i++) {
98+
float x = a[i];
99+
float y = b[i];
100+
xy += x * y;
101+
}
102+
return xy;
103+
}
104+
105+
__attribute__((target("avx2"))) extern float
106+
vectors_f16_distance_squared_l2_axv2(_Float16 const *restrict a,
107+
_Float16 const *restrict b, size_t n) {
108+
float dd = 0;
109+
#pragma clang loop vectorize_width(8)
110+
for (size_t i = 0; i < n; i++) {
111+
float x = a[i];
112+
float y = b[i];
113+
float d = x - y;
114+
dd += d * d;
115+
}
116+
return dd;
117+
}

crates/c/src/c.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,9 @@ extern float vectors_f16_dot_axv512(_Float16 const *, _Float16 const *,
77
size_t n);
88
extern float vectors_f16_distance_squared_l2_axv512(_Float16 const *,
99
_Float16 const *, size_t n);
10+
11+
/* f16 kernels with the `_axv2` suffix. Each takes two input pointers `a` and
   `b` plus an element count `n`, and returns a float. */
extern float vectors_f16_cosine_axv2(_Float16 const *a, _Float16 const *b,
                                     size_t n);
extern float vectors_f16_dot_axv2(_Float16 const *a, _Float16 const *b,
                                  size_t n);
extern float vectors_f16_distance_squared_l2_axv2(_Float16 const *a,
                                                  _Float16 const *b, size_t n);

crates/c/src/c.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,7 @@ extern "C" {
33
pub fn vectors_f16_cosine_axv512(a: *const u16, b: *const u16, n: usize) -> f32;
44
pub fn vectors_f16_dot_axv512(a: *const u16, b: *const u16, n: usize) -> f32;
55
pub fn vectors_f16_distance_squared_l2_axv512(a: *const u16, b: *const u16, n: usize) -> f32;
6+
pub fn vectors_f16_cosine_axv2(a: *const u16, b: *const u16, n: usize) -> f32;
7+
pub fn vectors_f16_dot_axv2(a: *const u16, b: *const u16, n: usize) -> f32;
8+
pub fn vectors_f16_distance_squared_l2_axv2(a: *const u16, b: *const u16, n: usize) -> f32;
69
}
crates/service/src/prelude/global/avx2.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#[cfg(not(target_arch = "x86_64"))]
2+
pub fn detect() -> bool {
3+
false
4+
}
5+
6+
#[cfg(target_arch = "x86_64")]
7+
pub fn detect() -> bool {
8+
std_detect::is_x86_feature_detected!("avx2")
9+
}

crates/service/src/prelude/global/f16.rs

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,15 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 {
2020
unsafe {
2121
assert!(lhs.len() == rhs.len());
2222
let n = lhs.len();
23-
c::vectors_f16_cosine_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n);
23+
return c::vectors_f16_cosine_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n)
24+
.into();
25+
}
26+
}
27+
if super::avx2::detect() {
28+
unsafe {
29+
assert!(lhs.len() == rhs.len());
30+
let n = lhs.len();
31+
return c::vectors_f16_cosine_axv2(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
2432
}
2533
}
2634
cosine(lhs, rhs)
@@ -42,7 +50,14 @@ pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 {
4250
unsafe {
4351
assert!(lhs.len() == rhs.len());
4452
let n = lhs.len();
45-
c::vectors_f16_dot_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n);
53+
return c::vectors_f16_dot_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
54+
}
55+
}
56+
if super::avx2::detect() {
57+
unsafe {
58+
assert!(lhs.len() == rhs.len());
59+
let n = lhs.len();
60+
return c::vectors_f16_dot_axv2(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
4661
}
4762
}
4863
cosine(lhs, rhs)
@@ -65,7 +80,24 @@ pub fn distance_squared_l2(lhs: &[F16], rhs: &[F16]) -> F32 {
6580
unsafe {
6681
assert!(lhs.len() == rhs.len());
6782
let n = lhs.len();
68-
c::vectors_f16_distance_squared_l2_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n);
83+
return c::vectors_f16_distance_squared_l2_axv512(
84+
lhs.as_ptr().cast(),
85+
rhs.as_ptr().cast(),
86+
n,
87+
)
88+
.into();
89+
}
90+
}
91+
if super::avx2::detect() {
92+
unsafe {
93+
assert!(lhs.len() == rhs.len());
94+
let n = lhs.len();
95+
return c::vectors_f16_distance_squared_l2_axv2(
96+
lhs.as_ptr().cast(),
97+
rhs.as_ptr().cast(),
98+
n,
99+
)
100+
.into();
69101
}
70102
}
71103
distance_squared_l2(lhs, rhs)

crates/service/src/prelude/global/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
mod avx2;
12
mod avx512fp16;
23
mod f16;
34
mod f16_cos;

0 commit comments

Comments
 (0)