Skip to content

Commit 0855419

Browse files
committed
fix: add rerun in build script
Signed-off-by: usamoi <[email protected]>
1 parent fe764af commit 0855419

File tree

6 files changed

+51
-65
lines changed

6 files changed

+51
-65
lines changed

crates/c/build.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
fn main() {
2+
println!("cargo:rerun-if-changed=src/c.h");
3+
println!("cargo:rerun-if-changed=src/c.c");
24
cc::Build::new()
35
.compiler("/usr/bin/clang-16")
46
.file("./src/c.c")

crates/c/src/c.c

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
#include <math.h>
44

55
__attribute__((target("avx512fp16,bmi2"))) extern float
6-
vectors_f16_cosine_axv512(_Float16 const *restrict a,
7-
_Float16 const *restrict b, size_t n) {
6+
v_f16_cosine_axv512(_Float16 const *restrict a, _Float16 const *restrict b,
7+
size_t n) {
88
__m512h xy = _mm512_set1_ph(0);
99
__m512h xx = _mm512_set1_ph(0);
1010
__m512h yy = _mm512_set1_ph(0);
@@ -30,8 +30,8 @@ vectors_f16_cosine_axv512(_Float16 const *restrict a,
3030
}
3131

3232
__attribute__((target("avx512fp16,bmi2"))) extern float
33-
vectors_f16_dot_axv512(_Float16 const *restrict a, _Float16 const *restrict b,
34-
size_t n) {
33+
v_f16_dot_axv512(_Float16 const *restrict a, _Float16 const *restrict b,
34+
size_t n) {
3535
__m512h xy = _mm512_set1_ph(0);
3636

3737
while (n >= 32) {
@@ -50,8 +50,8 @@ vectors_f16_dot_axv512(_Float16 const *restrict a, _Float16 const *restrict b,
5050
}
5151

5252
__attribute__((target("avx512fp16,bmi2"))) extern float
53-
vectors_f16_distance_squared_l2_axv512(_Float16 const *restrict a,
54-
_Float16 const *restrict b, size_t n) {
53+
v_f16_sl2_axv512(_Float16 const *restrict a, _Float16 const *restrict b,
54+
size_t n) {
5555
__m512h dd = _mm512_set1_ph(0);
5656

5757
while (n >= 32) {
@@ -73,8 +73,8 @@ vectors_f16_distance_squared_l2_axv512(_Float16 const *restrict a,
7373
}
7474

7575
__attribute__((target("avx2"))) extern float
76-
vectors_f16_cosine_axv2(_Float16 const *restrict a, _Float16 const *restrict b,
77-
size_t n) {
76+
v_f16_cosine_axv2(_Float16 const *restrict a, _Float16 const *restrict b,
77+
size_t n) {
7878
float xy = 0;
7979
float xx = 0;
8080
float yy = 0;
@@ -90,8 +90,8 @@ vectors_f16_cosine_axv2(_Float16 const *restrict a, _Float16 const *restrict b,
9090
}
9191

9292
__attribute__((target("avx2"))) extern float
93-
vectors_f16_dot_axv2(_Float16 const *restrict a, _Float16 const *restrict b,
94-
size_t n) {
93+
v_f16_dot_axv2(_Float16 const *restrict a, _Float16 const *restrict b,
94+
size_t n) {
9595
float xy = 0;
9696
#pragma clang loop vectorize_width(8)
9797
for (size_t i = 0; i < n; i++) {
@@ -103,8 +103,8 @@ vectors_f16_dot_axv2(_Float16 const *restrict a, _Float16 const *restrict b,
103103
}
104104

105105
__attribute__((target("avx2"))) extern float
106-
vectors_f16_distance_squared_l2_axv2(_Float16 const *restrict a,
107-
_Float16 const *restrict b, size_t n) {
106+
v_f16_sl2_axv2(_Float16 const *restrict a, _Float16 const *restrict b,
107+
size_t n) {
108108
float dd = 0;
109109
#pragma clang loop vectorize_width(8)
110110
for (size_t i = 0; i < n; i++) {

crates/c/src/c.h

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,9 @@
11
#include <stddef.h>
22
#include <stdint.h>
33

4-
extern float vectors_f16_cosine_axv512(_Float16 const *, _Float16 const *,
5-
size_t n);
6-
extern float vectors_f16_dot_axv512(_Float16 const *, _Float16 const *,
7-
size_t n);
8-
extern float vectors_f16_distance_squared_l2_axv512(_Float16 const *,
9-
_Float16 const *, size_t n);
10-
11-
extern float vectors_f16_cosine_axv2(_Float16 const *, _Float16 const *,
12-
size_t n);
13-
extern float vectors_f16_dot_axv2(_Float16 const *, _Float16 const *, size_t n);
14-
extern float vectors_f16_distance_squared_l2_axv2(_Float16 const *,
15-
_Float16 const *, size_t n);
4+
extern float v_f16_cosine_axv512(_Float16 const *, _Float16 const *, size_t n);
5+
extern float v_f16_dot_axv512(_Float16 const *, _Float16 const *, size_t n);
6+
extern float v_f16_sl2_axv512(_Float16 const *, _Float16 const *, size_t n);
7+
extern float v_f16_cosine_axv2(_Float16 const *, _Float16 const *, size_t n);
8+
extern float v_f16_dot_axv2(_Float16 const *, _Float16 const *, size_t n);
9+
extern float v_f16_sl2_axv2(_Float16 const *, _Float16 const *, size_t n);

crates/c/src/c.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#[link(name = "pgvectorsc", kind = "static")]
22
extern "C" {
3-
pub fn vectors_f16_cosine_axv512(a: *const u16, b: *const u16, n: usize) -> f32;
4-
pub fn vectors_f16_dot_axv512(a: *const u16, b: *const u16, n: usize) -> f32;
5-
pub fn vectors_f16_distance_squared_l2_axv512(a: *const u16, b: *const u16, n: usize) -> f32;
6-
pub fn vectors_f16_cosine_axv2(a: *const u16, b: *const u16, n: usize) -> f32;
7-
pub fn vectors_f16_dot_axv2(a: *const u16, b: *const u16, n: usize) -> f32;
8-
pub fn vectors_f16_distance_squared_l2_axv2(a: *const u16, b: *const u16, n: usize) -> f32;
3+
pub fn v_f16_cosine_axv512(a: *const u16, b: *const u16, n: usize) -> f32;
4+
pub fn v_f16_dot_axv512(a: *const u16, b: *const u16, n: usize) -> f32;
5+
pub fn v_f16_sl2_axv512(a: *const u16, b: *const u16, n: usize) -> f32;
6+
pub fn v_f16_cosine_axv2(a: *const u16, b: *const u16, n: usize) -> f32;
7+
pub fn v_f16_dot_axv2(a: *const u16, b: *const u16, n: usize) -> f32;
8+
pub fn v_f16_sl2_axv2(a: *const u16, b: *const u16, n: usize) -> f32;
99
}

crates/service/src/prelude/global/f16.rs

Lines changed: 21 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,18 @@ pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 {
1717
xy / (x2 * y2).sqrt()
1818
}
1919
if super::avx512fp16::detect() {
20+
assert!(lhs.len() == rhs.len());
21+
let n = lhs.len();
2022
unsafe {
21-
assert!(lhs.len() == rhs.len());
22-
let n = lhs.len();
23-
return c::vectors_f16_cosine_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n)
23+
return c::v_f16_cosine_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n)
2424
.into();
2525
}
2626
}
2727
if super::avx2::detect() {
28+
assert!(lhs.len() == rhs.len());
29+
let n = lhs.len();
2830
unsafe {
29-
assert!(lhs.len() == rhs.len());
30-
let n = lhs.len();
31-
return c::vectors_f16_cosine_axv2(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
31+
return c::v_f16_cosine_axv2(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
3232
}
3333
}
3434
cosine(lhs, rhs)
@@ -47,26 +47,26 @@ pub fn dot(lhs: &[F16], rhs: &[F16]) -> F32 {
4747
xy
4848
}
4949
if super::avx512fp16::detect() {
50+
assert!(lhs.len() == rhs.len());
51+
let n = lhs.len();
5052
unsafe {
51-
assert!(lhs.len() == rhs.len());
52-
let n = lhs.len();
53-
return c::vectors_f16_dot_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
53+
return c::v_f16_dot_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
5454
}
5555
}
5656
if super::avx2::detect() {
57+
assert!(lhs.len() == rhs.len());
58+
let n = lhs.len();
5759
unsafe {
58-
assert!(lhs.len() == rhs.len());
59-
let n = lhs.len();
60-
return c::vectors_f16_dot_axv2(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
60+
return c::v_f16_dot_axv2(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
6161
}
6262
}
6363
cosine(lhs, rhs)
6464
}
6565

66-
pub fn distance_squared_l2(lhs: &[F16], rhs: &[F16]) -> F32 {
66+
pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 {
6767
#[inline(always)]
6868
#[multiversion::multiversion(targets = "simd")]
69-
pub fn distance_squared_l2(lhs: &[F16], rhs: &[F16]) -> F32 {
69+
pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 {
7070
assert!(lhs.len() == rhs.len());
7171
let n = lhs.len();
7272
let mut d2 = F32::zero();
@@ -77,28 +77,18 @@ pub fn distance_squared_l2(lhs: &[F16], rhs: &[F16]) -> F32 {
7777
d2
7878
}
7979
if super::avx512fp16::detect() {
80+
assert!(lhs.len() == rhs.len());
81+
let n = lhs.len();
8082
unsafe {
81-
assert!(lhs.len() == rhs.len());
82-
let n = lhs.len();
83-
return c::vectors_f16_distance_squared_l2_axv512(
84-
lhs.as_ptr().cast(),
85-
rhs.as_ptr().cast(),
86-
n,
87-
)
88-
.into();
83+
return c::v_f16_sl2_axv512(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
8984
}
9085
}
9186
if super::avx2::detect() {
87+
assert!(lhs.len() == rhs.len());
88+
let n = lhs.len();
9289
unsafe {
93-
assert!(lhs.len() == rhs.len());
94-
let n = lhs.len();
95-
return c::vectors_f16_distance_squared_l2_axv2(
96-
lhs.as_ptr().cast(),
97-
rhs.as_ptr().cast(),
98-
n,
99-
)
100-
.into();
90+
return c::v_f16_sl2_axv2(lhs.as_ptr().cast(), rhs.as_ptr().cast(), n).into();
10191
}
10292
}
103-
distance_squared_l2(lhs, rhs)
93+
sl2(lhs, rhs)
10494
}

crates/service/src/prelude/global/f16_l2.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@ impl G for F16L2 {
1414
type L2 = F16L2;
1515

1616
fn distance(lhs: &[F16], rhs: &[F16]) -> F32 {
17-
super::f16::distance_squared_l2(lhs, rhs)
17+
super::f16::sl2(lhs, rhs)
1818
}
1919

2020
fn elkan_k_means_normalize(_: &mut [F16]) {}
2121

2222
fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 {
23-
super::f16::distance_squared_l2(lhs, rhs).sqrt()
23+
super::f16::sl2(lhs, rhs).sqrt()
2424
}
2525

2626
#[multiversion::multiversion(targets = "simd")]
@@ -72,7 +72,7 @@ impl G for F16L2 {
7272
let lhs = &lhs[(i * ratio) as usize..][..k as usize];
7373
let rhsp = rhs[i as usize] as usize * dims as usize;
7474
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
75-
result += super::f16::distance_squared_l2(lhs, rhs);
75+
result += super::f16::sl2(lhs, rhs);
7676
}
7777
result
7878
}
@@ -93,7 +93,7 @@ impl G for F16L2 {
9393
let lhs = &centroids[lhsp..][(i * ratio) as usize..][..k as usize];
9494
let rhsp = rhs[i as usize] as usize * dims as usize;
9595
let rhs = &centroids[rhsp..][(i * ratio) as usize..][..k as usize];
96-
result += super::f16::distance_squared_l2(lhs, rhs);
96+
result += super::f16::sl2(lhs, rhs);
9797
}
9898
result
9999
}

0 commit comments

Comments
 (0)