Skip to content

Commit 2581c60

Browse files
committed
refactor: use sparse struct to parse
Signed-off-by: cutecutecat <[email protected]>
1 parent 2d6c196 commit 2581c60

File tree

5 files changed

+41
-56
lines changed

5 files changed

+41
-56
lines changed

src/datatype/text_svecf32.rs

Lines changed: 5 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,22 @@
11
use super::memory_svecf32::SVecf32Output;
22
use crate::datatype::memory_svecf32::SVecf32Input;
3-
use crate::datatype::typmod::Typmod;
43
use crate::error::*;
54
use base::scalar::*;
65
use base::vector::*;
7-
use num_traits::Zero;
86
use pgrx::pg_sys::Oid;
97
use std::ffi::{CStr, CString};
108

119
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
12-
fn _vectors_svecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> SVecf32Output {
10+
fn _vectors_svecf32_in(input: &CStr, _oid: Oid, _typmod: i32) -> SVecf32Output {
1311
use crate::utils::parse::parse_pgvector_svector;
14-
let reserve = Typmod::parse_from_i32(typmod)
15-
.unwrap()
16-
.dims()
17-
.map(|x| x.get())
18-
.unwrap_or(0);
19-
let v = parse_pgvector_svector(input.to_bytes(), reserve as usize, |s| {
20-
s.parse::<F32>().ok()
21-
});
12+
let v = parse_pgvector_svector(input.to_bytes(), |s| s.parse::<F32>().ok());
2213
match v {
2314
Err(e) => {
2415
bad_literal(&e.to_string());
2516
}
26-
Ok(vector) => {
27-
check_value_dims_1048575(vector.len());
28-
let mut indexes = Vec::<u32>::new();
29-
let mut values = Vec::<F32>::new();
30-
for (i, &x) in vector.iter().enumerate() {
31-
if !x.is_zero() {
32-
indexes.push(i as u32);
33-
values.push(x);
34-
}
35-
}
36-
SVecf32Output::new(SVecf32Borrowed::new(vector.len() as u32, &indexes, &values))
17+
Ok((indexes, values, dims)) => {
18+
check_value_dims_1048575(dims);
19+
SVecf32Output::new(SVecf32Borrowed::new(dims as u32, &indexes, &values))
3720
}
3821
}
3922
}

src/sql/finalize.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -732,7 +732,7 @@ CREATE AGGREGATE avg(svector) (
732732
STYPE = svector_accumulate_state,
733733
COMBINEFUNC = _vectors_svector_combine,
734734
FINALFUNC = _vectors_svector_final,
735-
INITCOND = '(0, [0])',
735+
INITCOND = '(0, {}/1)',
736736
PARALLEL = SAFE
737737
);
738738

src/utils/parse.rs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,16 @@ where
8888
#[inline(always)]
8989
pub fn parse_pgvector_svector<T: Zero + Clone, F>(
9090
input: &[u8],
91-
reserve: usize,
9291
f: F,
93-
) -> Result<Vec<T>, ParseVectorError>
92+
) -> Result<(Vec<u32>, Vec<T>, usize), ParseVectorError>
9493
where
9594
F: Fn(&str) -> Option<T>,
9695
{
9796
use arrayvec::ArrayVec;
9897
if input.is_empty() {
9998
return Err(ParseVectorError::EmptyString {});
10099
}
100+
let mut dims: usize = 0;
101101
let left = 'a: {
102102
for position in 0..input.len() - 1 {
103103
match input[position] {
@@ -109,7 +109,6 @@ where
109109
return Err(ParseVectorError::BadParentheses { character: '{' });
110110
};
111111
let mut token: ArrayVec<u8, 48> = ArrayVec::new();
112-
let mut capacity = reserve;
113112
let right = 'a: {
114113
for position in (1..input.len()).rev() {
115114
match input[position] {
@@ -121,7 +120,7 @@ where
121120
b'/' => {
122121
token.reverse();
123122
let s = unsafe { std::str::from_utf8_unchecked(&token[..]) };
124-
capacity = s
123+
dims = s
125124
.parse::<usize>()
126125
.map_err(|_| ParseVectorError::BadParsing { position })?;
127126
}
@@ -135,8 +134,9 @@ where
135134
}
136135
return Err(ParseVectorError::BadParentheses { character: '}' });
137136
};
138-
let mut vector = vec![T::zero(); capacity];
139-
let mut index: usize = 0;
137+
let mut indexes = Vec::<u32>::new();
138+
let mut values = Vec::<T>::new();
139+
let mut index: u32 = 0;
140140
for position in left + 1..right {
141141
let c = input[position];
142142
match c {
@@ -153,7 +153,8 @@ where
153153
// Safety: all bytes in `token` are ascii characters
154154
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
155155
let num = f(s).ok_or(ParseVectorError::BadParsing { position })?;
156-
vector[index] = num;
156+
indexes.push(index);
157+
values.push(num);
157158
token.clear();
158159
} else {
159160
return Err(ParseVectorError::TooShortNumber { position });
@@ -164,7 +165,7 @@ where
164165
// Safety: all bytes in `token` are ascii characters
165166
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
166167
index = s
167-
.parse::<usize>()
168+
.parse::<u32>()
168169
.map_err(|_| ParseVectorError::BadParsing { position })?;
169170
token.clear();
170171
} else {
@@ -180,8 +181,9 @@ where
180181
// Safety: all bytes in `token` are ascii characters
181182
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
182183
let num = f(s).ok_or(ParseVectorError::BadParsing { position })?;
183-
vector[index] = num;
184+
indexes.push(index);
185+
values.push(num);
184186
token.clear();
185187
}
186-
Ok(vector)
188+
Ok((indexes, values, dims))
187189
}

tests/sqllogictest/sparse.slt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ DROP TABLE t;
4040
query I
4141
SELECT to_svector(5, '{1,2}', '{1,2}');
4242
----
43-
{2:1, 3:2}/5
43+
{1:1, 2:2}/5
4444

4545
query I
4646
SELECT to_svector(5, '{1,2}', '{1,1}') * to_svector(5, '{1,3}', '{2,2}');

tests/sqllogictest/svector_subscript.slt

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,83 +2,83 @@ statement ok
22
SET search_path TO pg_temp, vectors;
33

44
query I
5-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[3:6];
5+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[3:6];
66
----
7-
{1:3, 2:4, 3:5}/3
7+
{0:3, 1:4, 2:5}/3
88

99
query I
10-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:4];
10+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[:4];
1111
----
12-
{2:1, 3:2, 4:3}/4
12+
{1:1, 2:2, 3:3}/4
1313

1414
query I
15-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[5:];
15+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[5:];
1616
----
17-
{1:5, 2:6, 3:7}/3
17+
{0:5, 1:6, 2:7}/3
1818

1919
query I
20-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[1:8];
20+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[1:8];
2121
----
22-
{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/7
22+
{0:1, 1:2, 2:3, 3:4, 4:5, 5:6, 6:7}/7
2323

2424
statement error type svector does only support one subscript
25-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[3:3][1:1];
25+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[3:3][1:1];
2626

2727
statement error type svector does only support slice fetch
28-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[3];
28+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[3];
2929

3030
query I
31-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[5:4];
31+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[5:4];
3232
----
3333
NULL
3434

3535
query I
36-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[9:];
36+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[9:];
3737
----
3838
NULL
3939

4040
query I
41-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:0];
41+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[:0];
4242
----
4343
NULL
4444

4545
query I
46-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:-1];
46+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[:-1];
4747
----
4848
NULL
4949

5050
query I
51-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[NULL:NULL];
51+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[NULL:NULL];
5252
----
5353
NULL
5454

5555
query I
56-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[NULL:8];
56+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[NULL:8];
5757
----
5858
NULL
5959

6060
query I
61-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[1:NULL];
61+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[1:NULL];
6262
----
6363
NULL
6464

6565
query I
66-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[NULL:];
66+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[NULL:];
6767
----
6868
NULL
6969

7070
query I
71-
SELECT ('{2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 8:7}/8'::svector)[:NULL];
71+
SELECT ('{1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7}/8'::svector)[:NULL];
7272
----
7373
NULL
7474

7575
query I
76-
SELECT ('{3:2, 5:4, 8:7}/8'::svector)[3:7];
76+
SELECT ('{2:2, 4:4, 7:7}/8'::svector)[3:7];
7777
----
78-
{2:4}/4
78+
{1:4}/4
7979

8080
query I
81-
SELECT ('{3:2, 5:4, 8:7}/8'::svector)[5:7];
81+
SELECT ('{2:2, 4:4, 7:7}/8'::svector)[5:7];
8282
----
8383
{}/2
8484

0 commit comments

Comments
 (0)