@@ -70,6 +70,7 @@ use core::ops::{Mul, MulAssign};
70
70
use elliptic_curve:: subtle:: { Choice , ConditionallySelectable , ConstantTimeEq } ;
71
71
72
72
/// Lookup table containing precomputed values `[p, 2p, 3p, ..., 8p]`
73
+ #[ derive( Copy , Clone , Default ) ]
73
74
struct LookupTable ( [ ProjectivePoint ; 8 ] ) ;
74
75
75
76
impl From < & ProjectivePoint > for LookupTable {
@@ -147,94 +148,218 @@ fn decompose_scalar(k: &Scalar) -> (Scalar, Scalar) {
147
148
( r1, r2)
148
149
}
149
150
150
- /// Returns `[a_0, ..., a_32]` such that `sum(a_j * 2^(j * 4)) == x`,
151
- /// and `-8 <= a_j <= 7`.
152
- /// Assumes `x < 2^128`.
153
- fn to_radix_16_half ( x : & Scalar ) -> [ i8 ; 33 ] {
154
- // `x` can have up to 256 bits, so we need an additional byte to store the carry.
155
- let mut output = [ 0i8 ; 33 ] ;
156
-
157
- // Step 1: change radix.
158
- // Convert from radix 256 (bytes) to radix 16 (nibbles)
159
- let bytes = x. to_bytes ( ) ;
160
- for i in 0 ..16 {
161
- output[ 2 * i] = ( bytes[ 31 - i] & 0xf ) as i8 ;
162
- output[ 2 * i + 1 ] = ( ( bytes[ 31 - i] >> 4 ) & 0xf ) as i8 ;
163
- }
151
+ // This needs to be an object to have Default implemented for it
152
+ // (required because it's used in static_map later)
153
+ // Otherwise we could just have a function returning an array.
154
+ #[ derive( Copy , Clone ) ]
155
+ struct Radix16Decomposition ( [ i8 ; 33 ] ) ;
156
+
157
+ impl Radix16Decomposition {
158
+ /// Returns an object containing a decomposition
159
+ /// `[a_0, ..., a_32]` such that `sum(a_j * 2^(j * 4)) == x`,
160
+ /// and `-8 <= a_j <= 7`.
161
+ /// Assumes `x < 2^128`.
162
+ fn new ( x : & Scalar ) -> Self {
163
+ debug_assert ! ( ( x >> 128 ) . is_zero( ) . unwrap_u8( ) == 1 ) ;
164
+
165
+ // The resulting decomposition can be negative, so, despite the limit on `x`,
166
+ // it can have up to 256 bits, and we need an additional byte to store the carry.
167
+ let mut output = [ 0i8 ; 33 ] ;
168
+
169
+ // Step 1: change radix.
170
+ // Convert from radix 256 (bytes) to radix 16 (nibbles)
171
+ let bytes = x. to_bytes ( ) ;
172
+ for i in 0 ..16 {
173
+ output[ 2 * i] = ( bytes[ 31 - i] & 0xf ) as i8 ;
174
+ output[ 2 * i + 1 ] = ( ( bytes[ 31 - i] >> 4 ) & 0xf ) as i8 ;
175
+ }
164
176
165
- debug_assert ! ( ( x >> 128 ) . is_zero( ) . unwrap_u8( ) == 1 ) ;
177
+ // Step 2: recenter coefficients from [0,16) to [-8,8)
178
+ for i in 0 ..32 {
179
+ let carry = ( output[ i] + 8 ) >> 4 ;
180
+ output[ i] -= carry << 4 ;
181
+ output[ i + 1 ] += carry;
182
+ }
166
183
167
- // Step 2: recenter coefficients from [0,16) to [-8,8)
168
- for i in 0 ..32 {
169
- let carry = ( output[ i] + 8 ) >> 4 ;
170
- output[ i] -= carry << 4 ;
171
- output[ i + 1 ] += carry;
184
+ Self ( output)
172
185
}
173
-
174
- output
175
186
}
176
187
177
- fn mul_windowed ( x : & ProjectivePoint , k : & Scalar ) -> ProjectivePoint {
178
- let ( r1, r2) = decompose_scalar ( k) ;
179
- let x_beta = x. endomorphism ( ) ;
188
+ impl Default for Radix16Decomposition {
189
+ fn default ( ) -> Self {
190
+ Self ( [ 0i8 ; 33 ] )
191
+ }
192
+ }
180
193
181
- let r1_sign = r1. is_high ( ) ;
182
- let r1_c = Scalar :: conditional_select ( & r1, & -r1, r1_sign) ;
183
- let r2_sign = r2. is_high ( ) ;
184
- let r2_c = Scalar :: conditional_select ( & r2, & -r2, r2_sign) ;
194
+ /// Maps an array `x` to an array using the predicate `f`.
195
+ /// We can't use the standard `map()` because as of Rust 1.51 we cannot collect into arrays.
196
+ /// Consequently, since we cannot have an uninitialized array (without `unsafe`),
197
+ /// a default value needs to be provided.
198
+ fn static_map < T : Copy , V : Copy , const N : usize > (
199
+ f : impl Fn ( T ) -> V ,
200
+ x : & [ T ; N ] ,
201
+ default : V ,
202
+ ) -> [ V ; N ] {
203
+ let mut res = [ default; N ] ;
204
+ for i in 0 ..N {
205
+ res[ i] = f ( x[ i] ) ;
206
+ }
207
+ res
208
+ }
185
209
186
- let table1 = LookupTable :: from ( & ProjectivePoint :: conditional_select ( x, & -x, r1_sign) ) ;
187
- let table2 = LookupTable :: from ( & ProjectivePoint :: conditional_select (
188
- & x_beta, & -x_beta, r2_sign,
189
- ) ) ;
210
+ /// Maps two arrays `x` and `y` into an array using a predicate `f` that takes two arguments.
211
+ fn static_zip_map < T : Copy , S : Copy , V : Copy , const N : usize > (
212
+ f : impl Fn ( T , S ) -> V ,
213
+ x : & [ T ; N ] ,
214
+ y : & [ S ; N ] ,
215
+ default : V ,
216
+ ) -> [ V ; N ] {
217
+ let mut res = [ default; N ] ;
218
+ for i in 0 ..N {
219
+ res[ i] = f ( x[ i] , y[ i] ) ;
220
+ }
221
+ res
222
+ }
190
223
191
- let digits1 = to_radix_16_half ( & r1_c) ;
192
- let digits2 = to_radix_16_half ( & r2_c) ;
224
+ /// Calculates a linear combination `sum(x[i] * k[i])`, `i = 0..N`
225
+ #[ inline( always) ]
226
+ fn lincomb_generic < const N : usize > ( xs : & [ ProjectivePoint ; N ] , ks : & [ Scalar ; N ] ) -> ProjectivePoint {
227
+ let rs = static_map (
228
+ |k| decompose_scalar ( & k) ,
229
+ ks,
230
+ ( Scalar :: default ( ) , Scalar :: default ( ) ) ,
231
+ ) ;
232
+ let r1s = static_map ( |( r1, _r2) | r1, & rs, Scalar :: default ( ) ) ;
233
+ let r2s = static_map ( |( _r1, r2) | r2, & rs, Scalar :: default ( ) ) ;
234
+
235
+ let xs_beta = static_map ( |x| x. endomorphism ( ) , xs, ProjectivePoint :: default ( ) ) ;
236
+
237
+ let r1_signs = static_map ( |r| r. is_high ( ) , & r1s, Choice :: from ( 0u8 ) ) ;
238
+ let r2_signs = static_map ( |r| r. is_high ( ) , & r2s, Choice :: from ( 0u8 ) ) ;
239
+
240
+ let r1s_c = static_zip_map (
241
+ |r, r_sign| Scalar :: conditional_select ( & r, & -r, r_sign) ,
242
+ & r1s,
243
+ & r1_signs,
244
+ Scalar :: default ( ) ,
245
+ ) ;
246
+ let r2s_c = static_zip_map (
247
+ |r, r_sign| Scalar :: conditional_select ( & r, & -r, r_sign) ,
248
+ & r2s,
249
+ & r2_signs,
250
+ Scalar :: default ( ) ,
251
+ ) ;
252
+
253
+ let tables1 = static_zip_map (
254
+ |x, r_sign| LookupTable :: from ( & ProjectivePoint :: conditional_select ( & x, & -x, r_sign) ) ,
255
+ & xs,
256
+ & r1_signs,
257
+ LookupTable :: default ( ) ,
258
+ ) ;
259
+ let tables2 = static_zip_map (
260
+ |x, r_sign| LookupTable :: from ( & ProjectivePoint :: conditional_select ( & x, & -x, r_sign) ) ,
261
+ & xs_beta,
262
+ & r2_signs,
263
+ LookupTable :: default ( ) ,
264
+ ) ;
265
+
266
+ let digits1 = static_map (
267
+ |r| Radix16Decomposition :: new ( & r) ,
268
+ & r1s_c,
269
+ Radix16Decomposition :: default ( ) ,
270
+ ) ;
271
+ let digits2 = static_map (
272
+ |r| Radix16Decomposition :: new ( & r) ,
273
+ & r2s_c,
274
+ Radix16Decomposition :: default ( ) ,
275
+ ) ;
276
+
277
+ let mut acc = ProjectivePoint :: identity ( ) ;
278
+ for component in 0 ..N {
279
+ acc += & tables1[ component] . select ( digits1[ component] . 0 [ 32 ] ) ;
280
+ acc += & tables2[ component] . select ( digits2[ component] . 0 [ 32 ] ) ;
281
+ }
193
282
194
- let mut acc = table1. select ( digits1[ 32 ] ) + table2. select ( digits2[ 32 ] ) ;
195
283
for i in ( 0 ..32 ) . rev ( ) {
196
284
for _j in 0 ..4 {
197
285
acc = acc. double ( ) ;
198
286
}
199
287
200
- acc += & table1. select ( digits1[ i] ) ;
201
- acc += & table2. select ( digits2[ i] ) ;
288
+ for component in 0 ..N {
289
+ acc += & tables1[ component] . select ( digits1[ component] . 0 [ i] ) ;
290
+ acc += & tables2[ component] . select ( digits2[ component] . 0 [ i] ) ;
291
+ }
202
292
}
203
293
acc
204
294
}
205
295
296
+ #[ inline( always) ]
297
+ fn mul ( x : & ProjectivePoint , k : & Scalar ) -> ProjectivePoint {
298
+ lincomb_generic ( & [ * x] , & [ * k] )
299
+ }
300
+
301
+ /// Calculates `x * k + y * l`.
302
+ pub fn lincomb (
303
+ x : & ProjectivePoint ,
304
+ k : & Scalar ,
305
+ y : & ProjectivePoint ,
306
+ l : & Scalar ,
307
+ ) -> ProjectivePoint {
308
+ lincomb_generic ( & [ * x, * y] , & [ * k, * l] )
309
+ }
310
+
206
311
impl Mul < Scalar > for ProjectivePoint {
207
312
type Output = ProjectivePoint ;
208
313
209
314
fn mul ( self , other : Scalar ) -> ProjectivePoint {
210
- mul_windowed ( & self , & other)
315
+ mul ( & self , & other)
211
316
}
212
317
}
213
318
214
319
impl Mul < & Scalar > for & ProjectivePoint {
215
320
type Output = ProjectivePoint ;
216
321
217
322
fn mul ( self , other : & Scalar ) -> ProjectivePoint {
218
- mul_windowed ( self , other)
323
+ mul ( self , other)
219
324
}
220
325
}
221
326
222
327
impl Mul < & Scalar > for ProjectivePoint {
223
328
type Output = ProjectivePoint ;
224
329
225
330
fn mul ( self , other : & Scalar ) -> ProjectivePoint {
226
- mul_windowed ( & self , other)
331
+ mul ( & self , other)
227
332
}
228
333
}
229
334
230
335
impl MulAssign < Scalar > for ProjectivePoint {
231
336
fn mul_assign ( & mut self , rhs : Scalar ) {
232
- * self = mul_windowed ( self , & rhs) ;
337
+ * self = mul ( self , & rhs) ;
233
338
}
234
339
}
235
340
236
341
impl MulAssign < & Scalar > for ProjectivePoint {
237
342
fn mul_assign ( & mut self , rhs : & Scalar ) {
238
- * self = mul_windowed ( self , rhs) ;
343
+ * self = mul ( self , rhs) ;
344
+ }
345
+ }
346
+
347
+ #[ cfg( test) ]
348
+ mod tests {
349
+ use super :: lincomb;
350
+ use crate :: arithmetic:: { ProjectivePoint , Scalar } ;
351
+ use elliptic_curve:: rand_core:: OsRng ;
352
+ use elliptic_curve:: { Field , Group } ;
353
+
354
+ #[ test]
355
+ fn test_lincomb ( ) {
356
+ let x = ProjectivePoint :: random ( & mut OsRng ) ;
357
+ let y = ProjectivePoint :: random ( & mut OsRng ) ;
358
+ let k = Scalar :: random ( & mut OsRng ) ;
359
+ let l = Scalar :: random ( & mut OsRng ) ;
360
+
361
+ let reference = & x * & k + & y * & l;
362
+ let test = lincomb ( & x, & k, & y, & l) ;
363
+ assert_eq ! ( reference, test) ;
239
364
}
240
365
}
0 commit comments