Skip to content

Commit 94312ef

Browse files
committed
fix: by comments
Signed-off-by: cutecutecat <[email protected]>
1 parent 8a8561e commit 94312ef

File tree

3 files changed

+89
-90
lines changed

3 files changed

+89
-90
lines changed

src/datatype/text_svecf32.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,20 +9,21 @@ use std::fmt::Write;
99

1010
#[pgrx::pg_extern(immutable, strict, parallel_safe)]
1111
fn _vectors_svecf32_in(input: &CStr, _oid: Oid, _typmod: i32) -> SVecf32Output {
12-
use crate::utils::parse::{parse_pgvector_svector, svector_filter_nonzero};
12+
use crate::utils::parse::{parse_pgvector_svector, svector_filter_nonzero, svector_sorted};
1313
let v = parse_pgvector_svector(input.to_bytes(), |s| s.parse::<F32>().ok());
1414
match v {
1515
Err(e) => {
1616
bad_literal(&e.to_string());
1717
}
1818
Ok((indexes, values, dims)) => {
19+
let (mut sorted_indexes, mut sorted_values) = svector_sorted(&indexes, &values);
1920
check_value_dims_1048575(dims);
20-
check_index_in_bound(&indexes, dims);
21-
let (non_zero_indexes, non_zero_values) = svector_filter_nonzero(&indexes, &values);
21+
check_index_in_bound(&sorted_indexes, dims);
22+
svector_filter_nonzero(&mut sorted_indexes, &mut sorted_values);
2223
SVecf32Output::new(SVecf32Borrowed::new(
2324
dims as u32,
24-
&non_zero_indexes,
25-
&non_zero_values,
25+
&sorted_indexes,
26+
&sorted_values,
2627
))
2728
}
2829
}

src/error.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,15 @@ ADVICE: Check if dimensions of the vector are among 1 and 1_048_575."
7070
}
7171

7272
pub fn check_index_in_bound(indexes: &[u32], dims: usize) -> NonZeroU32 {
73-
let mut last: u32 = 0;
74-
for (i, index) in indexes.iter().enumerate() {
75-
if i > 0 && last == *index {
73+
let mut last: Option<u32> = None;
74+
for index in indexes {
75+
if last == Some(*index) {
7676
error!("Indexes need to be unique, but there are more than one same index {index}")
7777
}
7878
if *index >= dims as u32 {
7979
error!("Index out of bounds: the dim is {dims} but the index is {index}");
8080
}
81-
last = *index;
81+
last = Some(*index);
8282
}
8383
NonZeroU32::new(dims as u32).unwrap()
8484
}

src/utils/parse.rs

Lines changed: 79 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -85,35 +85,54 @@ where
8585
Ok(vector)
8686
}
8787

88-
#[derive(PartialEq, Debug)]
88+
#[derive(PartialEq, Debug, Clone)]
8989
enum ParseState {
9090
Start,
9191
LeftBracket,
9292
Index,
93+
Colon,
9394
Value,
94-
Splitter,
9595
Comma,
96-
Length,
96+
RightBracket,
97+
Splitter,
98+
Dims,
9799
}
98100

99101
#[inline(always)]
100-
pub fn svector_filter_nonzero<T: Zero + Clone + PartialEq>(
102+
pub fn svector_sorted<T: Zero + Clone + PartialEq>(
101103
indexes: &[u32],
102104
values: &[T],
103105
) -> (Vec<u32>, Vec<T>) {
104-
let non_zero_indexes: Vec<u32> = indexes
105-
.iter()
106-
.enumerate()
107-
.filter(|(i, _)| values.get(*i).unwrap() != &T::zero())
108-
.map(|(_, x)| *x)
109-
.collect();
110-
let non_zero_values: Vec<T> = indexes
111-
.iter()
112-
.enumerate()
113-
.filter(|(i, _)| values.get(*i).unwrap() != &T::zero())
114-
.map(|(i, _)| values.get(i).unwrap().clone())
115-
.collect();
116-
(non_zero_indexes, non_zero_values)
106+
let mut indices = (0..indexes.len()).collect::<Vec<_>>();
107+
indices.sort_by_key(|&i| &indexes[i]);
108+
109+
let mut sorted_indexes: Vec<u32> = Vec::with_capacity(indexes.len());
110+
let mut sorted_values: Vec<T> = Vec::with_capacity(indexes.len());
111+
for i in indices {
112+
sorted_indexes.push(*indexes.get(i).unwrap());
113+
sorted_values.push(values.get(i).unwrap().clone());
114+
}
115+
(sorted_indexes, sorted_values)
116+
}
117+
118+
#[inline(always)]
119+
pub fn svector_filter_nonzero<T: Zero + Clone + PartialEq>(
120+
indexes: &mut Vec<u32>,
121+
values: &mut Vec<T>,
122+
) {
123+
// Index must be sorted!
124+
let mut i = 0;
125+
let mut j = 0;
126+
while j < values.len() {
127+
if !values[j].is_zero() {
128+
indexes[i] = indexes[j];
129+
values[i] = values[j].clone();
130+
i += 1;
131+
}
132+
j += 1;
133+
}
134+
indexes.truncate(i);
135+
values.truncate(i);
117136
}
118137

119138
#[inline(always)]
@@ -133,110 +152,82 @@ where
133152
let mut values = Vec::<T>::new();
134153

135154
let mut state = ParseState::Start;
136-
for (position, char) in input.iter().enumerate() {
137-
let c = *char;
138-
match (&state, c) {
139-
(_, b' ') => {}
140-
(ParseState::Start, b'{') => {
141-
state = ParseState::LeftBracket;
142-
}
155+
for (position, c) in input.iter().copied().enumerate() {
156+
state = match (&state, c) {
157+
(_, b' ') => state,
158+
(ParseState::Start, b'{') => ParseState::LeftBracket,
143159
(
144160
ParseState::LeftBracket | ParseState::Index | ParseState::Comma,
145161
b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-',
146162
) => {
147-
if token.is_empty() {
148-
token.push(b'$');
149-
}
150163
if token.try_push(c).is_err() {
151164
return Err(ParseVectorError::TooLongNumber { position });
152165
}
153-
state = ParseState::Index;
166+
ParseState::Index
154167
}
155-
(ParseState::LeftBracket | ParseState::Comma, b'}') => {
156-
state = ParseState::Splitter;
168+
(ParseState::Colon, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-') => {
169+
if token.try_push(c).is_err() {
170+
return Err(ParseVectorError::TooLongNumber { position });
171+
}
172+
ParseState::Value
157173
}
174+
(ParseState::LeftBracket | ParseState::Comma, b'}') => ParseState::RightBracket,
158175
(ParseState::Index, b':') => {
159-
if token.is_empty() {
160-
return Err(ParseVectorError::TooShortNumber { position });
161-
}
162-
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
176+
let s = unsafe { std::str::from_utf8_unchecked(&token[..]) };
163177
let index = s
164178
.parse::<u32>()
165179
.map_err(|_| ParseVectorError::BadParsing { position })?;
166180
indexes.push(index);
167181
token.clear();
168-
state = ParseState::Value;
182+
ParseState::Colon
169183
}
170184
(ParseState::Value, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-') => {
171-
if token.is_empty() {
172-
token.push(b'$');
173-
}
174185
if token.try_push(c).is_err() {
175186
return Err(ParseVectorError::TooLongNumber { position });
176187
}
188+
ParseState::Value
177189
}
178190
(ParseState::Value, b',') => {
179-
if token.is_empty() {
180-
return Err(ParseVectorError::TooShortNumber { position });
181-
}
182-
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
191+
let s = unsafe { std::str::from_utf8_unchecked(&token[..]) };
183192
let num = f(s).ok_or(ParseVectorError::BadParsing { position })?;
184193
values.push(num);
185194
token.clear();
186-
state = ParseState::Comma;
195+
ParseState::Comma
187196
}
188197
(ParseState::Value, b'}') => {
189198
if token.is_empty() {
190199
return Err(ParseVectorError::TooShortNumber { position });
191200
}
192-
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
201+
let s = unsafe { std::str::from_utf8_unchecked(&token[..]) };
193202
let num = f(s).ok_or(ParseVectorError::BadParsing { position })?;
194203
values.push(num);
195204
token.clear();
196-
state = ParseState::Splitter;
197-
}
198-
(ParseState::Splitter, b'/') => {
199-
state = ParseState::Length;
205+
ParseState::RightBracket
200206
}
201-
(ParseState::Length, b'0'..=b'9') => {
202-
if token.is_empty() {
203-
token.push(b'$');
204-
}
207+
(ParseState::RightBracket, b'/') => ParseState::Splitter,
208+
(ParseState::Dims | ParseState::Splitter, b'0'..=b'9') => {
205209
if token.try_push(c).is_err() {
206210
return Err(ParseVectorError::TooLongNumber { position });
207211
}
212+
ParseState::Dims
208213
}
209214
(_, _) => {
210215
return Err(ParseVectorError::BadCharacter { position });
211216
}
212217
}
213218
}
214-
if state != ParseState::Length {
219+
if state != ParseState::Dims {
215220
return Err(ParseVectorError::BadParsing {
216221
position: input.len(),
217222
});
218223
}
219-
if token.is_empty() {
220-
return Err(ParseVectorError::TooShortNumber {
221-
position: input.len(),
222-
});
223-
}
224-
let s = unsafe { std::str::from_utf8_unchecked(&token[1..]) };
224+
let s = unsafe { std::str::from_utf8_unchecked(&token[..]) };
225225
let dims = s
226226
.parse::<usize>()
227227
.map_err(|_| ParseVectorError::BadParsing {
228228
position: input.len(),
229229
})?;
230-
231-
let mut indices = (0..indexes.len()).collect::<Vec<_>>();
232-
indices.sort_by_key(|&i| &indexes[i]);
233-
let sorted_values: Vec<T> = indices
234-
.iter()
235-
.map(|i| values.get(*i).unwrap().clone())
236-
.collect();
237-
indexes.sort();
238-
239-
Ok((indexes, sorted_values, dims))
230+
Ok((indexes, values, dims))
240231
}
241232

242233
#[cfg(test)]
@@ -266,8 +257,8 @@ mod tests {
266257
(
267258
"{3:3, 2:2, 1:1, 0:0}/4",
268259
(
269-
vec![0, 1, 2, 3],
270-
vec![F32(0.0), F32(1.0), F32(2.0), F32(3.0)],
260+
vec![3, 2, 1, 0],
261+
vec![F32(3.0), F32(2.0), F32(1.0), F32(0.0)],
271262
4,
272263
),
273264
),
@@ -294,16 +285,13 @@ mod tests {
294285
"{0:1, 1:2, 2:3",
295286
ParseVectorError::BadParsing { position: 14 },
296287
),
297-
(
298-
"{0:1, 1:2}/",
299-
ParseVectorError::TooShortNumber { position: 11 },
300-
),
288+
("{0:1, 1:2}/", ParseVectorError::BadParsing { position: 11 }),
301289
("{0}/5", ParseVectorError::BadCharacter { position: 2 }),
302-
("{0:}/5", ParseVectorError::TooShortNumber { position: 3 }),
290+
("{0:}/5", ParseVectorError::BadCharacter { position: 3 }),
303291
("{:0}/5", ParseVectorError::BadCharacter { position: 1 }),
304292
(
305293
"{0:, 1:2}/5",
306-
ParseVectorError::TooShortNumber { position: 3 },
294+
ParseVectorError::BadCharacter { position: 3 },
307295
),
308296
("{0:1, 1}/5", ParseVectorError::BadCharacter { position: 7 }),
309297
("/2", ParseVectorError::BadCharacter { position: 0 }),
@@ -347,23 +335,33 @@ mod tests {
347335
),
348336
(
349337
"{2:0, 1:0}/2",
350-
(vec![1, 2], vec![F32(0.0), F32(0.0)], 2),
338+
(vec![2, 1], vec![F32(0.0), F32(0.0)], 2),
351339
(vec![], vec![]),
352340
),
353341
(
354342
"{2:0, 1:0, }/2",
355-
(vec![1, 2], vec![F32(0.0), F32(0.0)], 2),
343+
(vec![2, 1], vec![F32(0.0), F32(0.0)], 2),
356344
(vec![], vec![]),
357345
),
346+
(
347+
"{3:2, 2:1, 1:0, 0:-1}/4",
348+
(
349+
vec![3, 2, 1, 0],
350+
vec![F32(2.0), F32(1.0), F32(0.0), F32(-1.0)],
351+
4,
352+
),
353+
(vec![0, 2, 3], vec![F32(-1.0), F32(1.0), F32(2.0)]),
354+
),
358355
];
359356
for (e, parsed, filtered) in exprs {
360357
let ret = parse_pgvector_svector(e.as_bytes(), |s| s.parse::<F32>().ok());
361358
assert!(ret.is_ok(), "at expr {:?}: {:?}", e, ret);
362359
assert_eq!(ret.unwrap(), parsed, "parsed at expr {:?}", e);
363360

364361
let (indexes, values, _) = parsed;
365-
let nonzero = svector_filter_nonzero(&indexes, &values);
366-
assert_eq!(nonzero, filtered, "filtered at expr {:?}", e);
362+
let (mut indexes, mut values) = svector_sorted(&indexes, &values);
363+
svector_filter_nonzero(&mut indexes, &mut values);
364+
assert_eq!((indexes, values), filtered, "filtered at expr {:?}", e);
367365
}
368366
}
369367
}

0 commit comments

Comments
 (0)