@@ -85,35 +85,54 @@ where
85
85
Ok ( vector)
86
86
}
87
87
88
/// States of the byte-wise state machine that parses the sparse-vector text
/// form `{index:value, ...}/dims`.
///
/// Each variant names the token that was consumed most recently. The enum is
/// fieldless, so it is `Copy`: the parser can rebind the current state by
/// value without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ParseState {
    /// Nothing has been consumed yet; only `{` is accepted.
    Start,
    /// The opening `{` was consumed.
    LeftBracket,
    /// Inside the characters of an entry's index.
    Index,
    /// The `:` separating an index from its value was consumed.
    Colon,
    /// Inside the characters of an entry's value.
    Value,
    /// The `,` terminating an entry was consumed.
    Comma,
    /// The closing `}` was consumed.
    RightBracket,
    /// The `/` between the entries and the dimension count was consumed.
    Splitter,
    /// Inside the digits of the dimension count; the only accepting end state.
    Dims,
}
98
100
99
101
#[ inline( always) ]
100
- pub fn svector_filter_nonzero < T : Zero + Clone + PartialEq > (
102
+ pub fn svector_sorted < T : Zero + Clone + PartialEq > (
101
103
indexes : & [ u32 ] ,
102
104
values : & [ T ] ,
103
105
) -> ( Vec < u32 > , Vec < T > ) {
104
- let non_zero_indexes: Vec < u32 > = indexes
105
- . iter ( )
106
- . enumerate ( )
107
- . filter ( |( i, _) | values. get ( * i) . unwrap ( ) != & T :: zero ( ) )
108
- . map ( |( _, x) | * x)
109
- . collect ( ) ;
110
- let non_zero_values: Vec < T > = indexes
111
- . iter ( )
112
- . enumerate ( )
113
- . filter ( |( i, _) | values. get ( * i) . unwrap ( ) != & T :: zero ( ) )
114
- . map ( |( i, _) | values. get ( i) . unwrap ( ) . clone ( ) )
115
- . collect ( ) ;
116
- ( non_zero_indexes, non_zero_values)
106
+ let mut indices = ( 0 ..indexes. len ( ) ) . collect :: < Vec < _ > > ( ) ;
107
+ indices. sort_by_key ( |& i| & indexes[ i] ) ;
108
+
109
+ let mut sorted_indexes: Vec < u32 > = Vec :: with_capacity ( indexes. len ( ) ) ;
110
+ let mut sorted_values: Vec < T > = Vec :: with_capacity ( indexes. len ( ) ) ;
111
+ for i in indices {
112
+ sorted_indexes. push ( * indexes. get ( i) . unwrap ( ) ) ;
113
+ sorted_values. push ( values. get ( i) . unwrap ( ) . clone ( ) ) ;
114
+ }
115
+ ( sorted_indexes, sorted_values)
116
+ }
117
+
118
+ #[ inline( always) ]
119
+ pub fn svector_filter_nonzero < T : Zero + Clone + PartialEq > (
120
+ indexes : & mut Vec < u32 > ,
121
+ values : & mut Vec < T > ,
122
+ ) {
123
+ // Index must be sorted!
124
+ let mut i = 0 ;
125
+ let mut j = 0 ;
126
+ while j < values. len ( ) {
127
+ if !values[ j] . is_zero ( ) {
128
+ indexes[ i] = indexes[ j] ;
129
+ values[ i] = values[ j] . clone ( ) ;
130
+ i += 1 ;
131
+ }
132
+ j += 1 ;
133
+ }
134
+ indexes. truncate ( i) ;
135
+ values. truncate ( i) ;
117
136
}
118
137
119
138
#[ inline( always) ]
@@ -133,110 +152,82 @@ where
133
152
let mut values = Vec :: < T > :: new ( ) ;
134
153
135
154
let mut state = ParseState :: Start ;
136
- for ( position, char) in input. iter ( ) . enumerate ( ) {
137
- let c = * char;
138
- match ( & state, c) {
139
- ( _, b' ' ) => { }
140
- ( ParseState :: Start , b'{' ) => {
141
- state = ParseState :: LeftBracket ;
142
- }
155
+ for ( position, c) in input. iter ( ) . copied ( ) . enumerate ( ) {
156
+ state = match ( & state, c) {
157
+ ( _, b' ' ) => state,
158
+ ( ParseState :: Start , b'{' ) => ParseState :: LeftBracket ,
143
159
(
144
160
ParseState :: LeftBracket | ParseState :: Index | ParseState :: Comma ,
145
161
b'0' ..=b'9' | b'a' ..=b'z' | b'A' ..=b'Z' | b'.' | b'+' | b'-' ,
146
162
) => {
147
- if token. is_empty ( ) {
148
- token. push ( b'$' ) ;
149
- }
150
163
if token. try_push ( c) . is_err ( ) {
151
164
return Err ( ParseVectorError :: TooLongNumber { position } ) ;
152
165
}
153
- state = ParseState :: Index ;
166
+ ParseState :: Index
154
167
}
155
- ( ParseState :: LeftBracket | ParseState :: Comma , b'}' ) => {
156
- state = ParseState :: Splitter ;
168
+ ( ParseState :: Colon , b'0' ..=b'9' | b'a' ..=b'z' | b'A' ..=b'Z' | b'.' | b'+' | b'-' ) => {
169
+ if token. try_push ( c) . is_err ( ) {
170
+ return Err ( ParseVectorError :: TooLongNumber { position } ) ;
171
+ }
172
+ ParseState :: Value
157
173
}
174
+ ( ParseState :: LeftBracket | ParseState :: Comma , b'}' ) => ParseState :: RightBracket ,
158
175
( ParseState :: Index , b':' ) => {
159
- if token. is_empty ( ) {
160
- return Err ( ParseVectorError :: TooShortNumber { position } ) ;
161
- }
162
- let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ 1 ..] ) } ;
176
+ let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ ..] ) } ;
163
177
let index = s
164
178
. parse :: < u32 > ( )
165
179
. map_err ( |_| ParseVectorError :: BadParsing { position } ) ?;
166
180
indexes. push ( index) ;
167
181
token. clear ( ) ;
168
- state = ParseState :: Value ;
182
+ ParseState :: Colon
169
183
}
170
184
( ParseState :: Value , b'0' ..=b'9' | b'a' ..=b'z' | b'A' ..=b'Z' | b'.' | b'+' | b'-' ) => {
171
- if token. is_empty ( ) {
172
- token. push ( b'$' ) ;
173
- }
174
185
if token. try_push ( c) . is_err ( ) {
175
186
return Err ( ParseVectorError :: TooLongNumber { position } ) ;
176
187
}
188
+ ParseState :: Value
177
189
}
178
190
( ParseState :: Value , b',' ) => {
179
- if token. is_empty ( ) {
180
- return Err ( ParseVectorError :: TooShortNumber { position } ) ;
181
- }
182
- let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ 1 ..] ) } ;
191
+ let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ ..] ) } ;
183
192
let num = f ( s) . ok_or ( ParseVectorError :: BadParsing { position } ) ?;
184
193
values. push ( num) ;
185
194
token. clear ( ) ;
186
- state = ParseState :: Comma ;
195
+ ParseState :: Comma
187
196
}
188
197
( ParseState :: Value , b'}' ) => {
189
198
if token. is_empty ( ) {
190
199
return Err ( ParseVectorError :: TooShortNumber { position } ) ;
191
200
}
192
- let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ 1 ..] ) } ;
201
+ let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ ..] ) } ;
193
202
let num = f ( s) . ok_or ( ParseVectorError :: BadParsing { position } ) ?;
194
203
values. push ( num) ;
195
204
token. clear ( ) ;
196
- state = ParseState :: Splitter ;
197
- }
198
- ( ParseState :: Splitter , b'/' ) => {
199
- state = ParseState :: Length ;
205
+ ParseState :: RightBracket
200
206
}
201
- ( ParseState :: Length , b'0' ..=b'9' ) => {
202
- if token. is_empty ( ) {
203
- token. push ( b'$' ) ;
204
- }
207
+ ( ParseState :: RightBracket , b'/' ) => ParseState :: Splitter ,
208
+ ( ParseState :: Dims | ParseState :: Splitter , b'0' ..=b'9' ) => {
205
209
if token. try_push ( c) . is_err ( ) {
206
210
return Err ( ParseVectorError :: TooLongNumber { position } ) ;
207
211
}
212
+ ParseState :: Dims
208
213
}
209
214
( _, _) => {
210
215
return Err ( ParseVectorError :: BadCharacter { position } ) ;
211
216
}
212
217
}
213
218
}
214
- if state != ParseState :: Length {
219
+ if state != ParseState :: Dims {
215
220
return Err ( ParseVectorError :: BadParsing {
216
221
position : input. len ( ) ,
217
222
} ) ;
218
223
}
219
- if token. is_empty ( ) {
220
- return Err ( ParseVectorError :: TooShortNumber {
221
- position : input. len ( ) ,
222
- } ) ;
223
- }
224
- let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ 1 ..] ) } ;
224
+ let s = unsafe { std:: str:: from_utf8_unchecked ( & token[ ..] ) } ;
225
225
let dims = s
226
226
. parse :: < usize > ( )
227
227
. map_err ( |_| ParseVectorError :: BadParsing {
228
228
position : input. len ( ) ,
229
229
} ) ?;
230
-
231
- let mut indices = ( 0 ..indexes. len ( ) ) . collect :: < Vec < _ > > ( ) ;
232
- indices. sort_by_key ( |& i| & indexes[ i] ) ;
233
- let sorted_values: Vec < T > = indices
234
- . iter ( )
235
- . map ( |i| values. get ( * i) . unwrap ( ) . clone ( ) )
236
- . collect ( ) ;
237
- indexes. sort ( ) ;
238
-
239
- Ok ( ( indexes, sorted_values, dims) )
230
+ Ok ( ( indexes, values, dims) )
240
231
}
241
232
242
233
#[ cfg( test) ]
@@ -266,8 +257,8 @@ mod tests {
266
257
(
267
258
"{3:3, 2:2, 1:1, 0:0}/4" ,
268
259
(
269
- vec![ 0 , 1 , 2 , 3 ] ,
270
- vec![ F32 ( 0 .0) , F32 ( 1 .0) , F32 ( 2 .0) , F32 ( 3 .0) ] ,
260
+ vec![ 3 , 2 , 1 , 0 ] ,
261
+ vec![ F32 ( 3 .0) , F32 ( 2 .0) , F32 ( 1 .0) , F32 ( 0 .0) ] ,
271
262
4 ,
272
263
) ,
273
264
) ,
@@ -294,16 +285,13 @@ mod tests {
294
285
"{0:1, 1:2, 2:3" ,
295
286
ParseVectorError :: BadParsing { position: 14 } ,
296
287
) ,
297
- (
298
- "{0:1, 1:2}/" ,
299
- ParseVectorError :: TooShortNumber { position: 11 } ,
300
- ) ,
288
+ ( "{0:1, 1:2}/" , ParseVectorError :: BadParsing { position: 11 } ) ,
301
289
( "{0}/5" , ParseVectorError :: BadCharacter { position: 2 } ) ,
302
- ( "{0:}/5" , ParseVectorError :: TooShortNumber { position: 3 } ) ,
290
+ ( "{0:}/5" , ParseVectorError :: BadCharacter { position: 3 } ) ,
303
291
( "{:0}/5" , ParseVectorError :: BadCharacter { position: 1 } ) ,
304
292
(
305
293
"{0:, 1:2}/5" ,
306
- ParseVectorError :: TooShortNumber { position: 3 } ,
294
+ ParseVectorError :: BadCharacter { position: 3 } ,
307
295
) ,
308
296
( "{0:1, 1}/5" , ParseVectorError :: BadCharacter { position: 7 } ) ,
309
297
( "/2" , ParseVectorError :: BadCharacter { position: 0 } ) ,
@@ -347,23 +335,33 @@ mod tests {
347
335
) ,
348
336
(
349
337
"{2:0, 1:0}/2" ,
350
- ( vec![ 1 , 2 ] , vec![ F32 ( 0.0 ) , F32 ( 0.0 ) ] , 2 ) ,
338
+ ( vec![ 2 , 1 ] , vec![ F32 ( 0.0 ) , F32 ( 0.0 ) ] , 2 ) ,
351
339
( vec![ ] , vec![ ] ) ,
352
340
) ,
353
341
(
354
342
"{2:0, 1:0, }/2" ,
355
- ( vec![ 1 , 2 ] , vec![ F32 ( 0.0 ) , F32 ( 0.0 ) ] , 2 ) ,
343
+ ( vec![ 2 , 1 ] , vec![ F32 ( 0.0 ) , F32 ( 0.0 ) ] , 2 ) ,
356
344
( vec![ ] , vec![ ] ) ,
357
345
) ,
346
+ (
347
+ "{3:2, 2:1, 1:0, 0:-1}/4" ,
348
+ (
349
+ vec![ 3 , 2 , 1 , 0 ] ,
350
+ vec![ F32 ( 2.0 ) , F32 ( 1.0 ) , F32 ( 0.0 ) , F32 ( -1.0 ) ] ,
351
+ 4 ,
352
+ ) ,
353
+ ( vec![ 0 , 2 , 3 ] , vec![ F32 ( -1.0 ) , F32 ( 1.0 ) , F32 ( 2.0 ) ] ) ,
354
+ ) ,
358
355
] ;
359
356
for ( e, parsed, filtered) in exprs {
360
357
let ret = parse_pgvector_svector ( e. as_bytes ( ) , |s| s. parse :: < F32 > ( ) . ok ( ) ) ;
361
358
assert ! ( ret. is_ok( ) , "at expr {:?}: {:?}" , e, ret) ;
362
359
assert_eq ! ( ret. unwrap( ) , parsed, "parsed at expr {:?}" , e) ;
363
360
364
361
let ( indexes, values, _) = parsed;
365
- let nonzero = svector_filter_nonzero ( & indexes, & values) ;
366
- assert_eq ! ( nonzero, filtered, "filtered at expr {:?}" , e) ;
362
+ let ( mut indexes, mut values) = svector_sorted ( & indexes, & values) ;
363
+ svector_filter_nonzero ( & mut indexes, & mut values) ;
364
+ assert_eq ! ( ( indexes, values) , filtered, "filtered at expr {:?}" , e) ;
367
365
}
368
366
}
369
367
}
0 commit comments