1
1
//! Computes some hashes of the PRF locally.
2
2
3
+ use std:: collections:: VecDeque ;
4
+
3
5
use crate :: { hmac:: hmac_sha256, sha256, state_to_bytes, PrfError } ;
4
6
use mpz_core:: bitvec:: BitVec ;
5
7
use mpz_hash:: sha256:: Sha256 ;
@@ -19,8 +21,8 @@ pub(crate) struct PrfFunction {
19
21
start_seed_label : Vec < u8 > ,
20
22
iterations : usize ,
21
23
state : PrfState ,
22
- a : Vec < PHash > ,
23
- p : Vec < PHash > ,
24
+ a : VecDeque < AHash > ,
25
+ p : VecDeque < PHash > ,
24
26
}
25
27
26
28
#[ derive( Debug ) ]
@@ -38,7 +40,7 @@ enum PrfState {
38
40
inner_partial : [ u32 ; 8 ] ,
39
41
a_output : DecodeFutureTyped < BitVec , [ u8 ; 32 ] > ,
40
42
} ,
41
- ComputeLastP ,
43
+ FinishLastP ,
42
44
Done ,
43
45
}
44
46
@@ -95,7 +97,7 @@ impl PrfFunction {
95
97
} ;
96
98
97
99
self . state = PrfState :: ComputeA {
98
- iter : 0 ,
100
+ iter : 1 ,
99
101
inner_partial,
100
102
msg : self . start_seed_label . clone ( ) ,
101
103
} ;
@@ -106,14 +108,13 @@ impl PrfFunction {
106
108
inner_partial,
107
109
msg,
108
110
} => {
109
- let a = & self . a [ * iter ] ;
111
+ let a = self . a . pop_front ( ) . expect ( "Prf AHash should be present" ) ;
110
112
assign_inner_local ( vm, a. inner_local , * inner_partial, msg) ?;
111
113
112
- let a_output = vm. decode ( a. output ) . map_err ( PrfError :: vm) ?;
113
114
self . state = PrfState :: ComputeP {
114
115
iter : * iter,
115
116
inner_partial : * inner_partial,
116
- a_output,
117
+ a_output : a . output ,
117
118
} ;
118
119
}
119
120
PrfState :: ComputeP {
@@ -124,15 +125,15 @@ impl PrfFunction {
124
125
let Some ( output) = a_output. try_recv ( ) . map_err ( PrfError :: vm) ? else {
125
126
return Ok ( ( ) ) ;
126
127
} ;
127
- let p = & self . p [ * iter ] ;
128
+ let p = self . p . pop_front ( ) . expect ( "Prf PHash should be present" ) ;
128
129
129
130
let mut msg = output. to_vec ( ) ;
130
131
msg. extend_from_slice ( & self . start_seed_label ) ;
131
132
132
133
assign_inner_local ( vm, p. inner_local , * inner_partial, & msg) ?;
133
134
134
135
if * iter == self . iterations {
135
- self . state = PrfState :: ComputeLastP ;
136
+ self . state = PrfState :: FinishLastP ;
136
137
} else {
137
138
self . state = PrfState :: ComputeA {
138
139
iter : * iter + 1 ,
@@ -141,7 +142,7 @@ impl PrfFunction {
141
142
}
142
143
} ;
143
144
}
144
- PrfState :: ComputeLastP => self . state = PrfState :: Done ,
145
+ PrfState :: FinishLastP => self . state = PrfState :: Done ,
145
146
_ => ( ) ,
146
147
}
147
148
@@ -178,22 +179,24 @@ impl PrfFunction {
178
179
let mut prf = Self {
179
180
label,
180
181
start_seed_label : vec ! [ ] ,
181
- // used for indexing, so we need to subtract one here
182
- iterations : iterations - 1 ,
182
+ iterations,
183
183
state : PrfState :: InnerPartial { inner_partial } ,
184
- a : vec ! [ ] ,
185
- p : vec ! [ ] ,
184
+ a : VecDeque :: new ( ) ,
185
+ p : VecDeque :: new ( ) ,
186
186
} ;
187
187
188
188
for _ in 0 ..iterations {
189
189
// setup A[i]
190
190
let inner_local: Array < U8 , 32 > = vm. alloc ( ) . map_err ( PrfError :: vm) ?;
191
191
let output = hmac_sha256 ( vm, outer_partial. clone ( ) , inner_local) ?;
192
- let p_hash = PHash {
192
+
193
+ let output = vm. decode ( output) . map_err ( PrfError :: vm) ?;
194
+ let a_hash = AHash {
193
195
inner_local,
194
196
output,
195
197
} ;
196
- prf. a . push ( p_hash) ;
198
+
199
+ prf. a . push_front ( a_hash) ;
197
200
198
201
// setup P[i]
199
202
let inner_local: Array < U8 , 32 > = vm. alloc ( ) . map_err ( PrfError :: vm) ?;
@@ -202,7 +205,7 @@ impl PrfFunction {
202
205
inner_local,
203
206
output,
204
207
} ;
205
- prf. p . push ( p_hash) ;
208
+ prf. p . push_front ( p_hash) ;
206
209
}
207
210
208
211
Ok ( prf)
@@ -225,6 +228,14 @@ fn assign_inner_local(
225
228
Ok ( ( ) )
226
229
}
227
230
231
+ /// Like PHash but stores the output as the decoding future because in the reduced Prf we need to
232
+ /// decode this output.
233
+ #[ derive( Debug ) ]
234
+ struct AHash {
235
+ inner_local : Array < U8 , 32 > ,
236
+ output : DecodeFutureTyped < BitVec , [ u8 ; 32 ] > ,
237
+ }
238
+
228
239
#[ derive( Debug , Clone , Copy ) ]
229
240
struct PHash {
230
241
inner_local : Array < U8 , 32 > ,
0 commit comments