@@ -17,13 +17,31 @@ pub(crate) struct PrfFunction {
17
17
label : & ' static [ u8 ] ,
18
18
// The start seed and the label, e.g. client_random + server_random + "master_secret".
19
19
start_seed_label : Vec < u8 > ,
20
- // The current HMAC message needed for a[i]
21
- a_msg : Vec < u8 > ,
22
- inner_partial : InnerPartial ,
20
+ iterations : usize ,
21
+ state : PrfState ,
23
22
a : Vec < PHash > ,
24
23
p : Vec < PHash > ,
25
24
}
26
25
26
+ #[ derive( Debug ) ]
27
+ enum PrfState {
28
+ InnerPartial {
29
+ inner_partial : DecodeFutureTyped < BitVec , [ u32 ; 8 ] > ,
30
+ } ,
31
+ ComputeA {
32
+ iter : usize ,
33
+ inner_partial : [ u32 ; 8 ] ,
34
+ msg : Vec < u8 > ,
35
+ } ,
36
+ ComputeP {
37
+ iter : usize ,
38
+ inner_partial : [ u32 ; 8 ] ,
39
+ a_output : DecodeFutureTyped < BitVec , [ u8 ; 32 ] > ,
40
+ } ,
41
+ ComputeLastP ,
42
+ Done ,
43
+ }
44
+
27
45
impl PrfFunction {
28
46
const MS_LABEL : & [ u8 ] = b"master secret" ;
29
47
const KEY_LABEL : & [ u8 ] = b"key expansion" ;
@@ -63,51 +81,68 @@ impl PrfFunction {
63
81
}
64
82
65
83
pub ( crate ) fn wants_flush ( & mut self ) -> bool {
66
- let last_p = self . p . last ( ) . expect ( "Prf should be allocated" ) ;
67
-
68
- if let State :: Done = last_p. state {
84
+ if let PrfState :: Done = self . state {
69
85
return false ;
70
86
}
71
87
true
72
88
}
73
89
74
90
pub ( crate ) fn flush ( & mut self , vm : & mut dyn Vm < Binary > ) -> Result < ( ) , PrfError > {
75
- let inner_partial = self . inner_partial . try_recv ( ) ?;
76
- let Some ( inner_partial) = inner_partial else {
77
- return Ok ( ( ) ) ;
78
- } ;
79
-
80
- for ( a, p) in self . a . iter_mut ( ) . zip ( self . p . iter_mut ( ) ) {
81
- match & mut a. state {
82
- State :: Init { .. } => {
83
- a. assign_inner_local ( vm, inner_partial, & self . a_msg ) ?;
84
- break ;
85
- }
86
- State :: Assigned { output } => {
87
- if let Some ( output) = output. try_recv ( ) . map_err ( PrfError :: vm) ? {
88
- let output = output. to_vec ( ) ;
89
- a. state = State :: Decoded {
90
- output : output. clone ( ) ,
91
- } ;
92
- self . a_msg = output;
93
- }
94
- }
95
- _ => ( ) ,
91
+ match & mut self . state {
92
+ PrfState :: InnerPartial { inner_partial } => {
93
+ let Some ( inner_partial) = inner_partial. try_recv ( ) . map_err ( PrfError :: vm) ? else {
94
+ return Ok ( ( ) ) ;
95
+ } ;
96
+
97
+ self . state = PrfState :: ComputeA {
98
+ iter : 0 ,
99
+ inner_partial,
100
+ msg : self . start_seed_label . clone ( ) ,
101
+ } ;
102
+ self . flush ( vm) ?;
96
103
}
97
-
98
- match & mut p. state {
99
- State :: Init { .. } => {
100
- if let State :: Decoded { output } = & a. state {
101
- let mut p_msg = output. to_vec ( ) ;
102
- p_msg. extend_from_slice ( & self . start_seed_label ) ;
103
- p. assign_inner_local ( vm, inner_partial, & p_msg) ?;
104
+ PrfState :: ComputeA {
105
+ iter,
106
+ inner_partial,
107
+ msg,
108
+ } => {
109
+ let a = & self . a [ * iter] ;
110
+ assign_inner_local ( vm, a. inner_local , * inner_partial, msg) ?;
111
+
112
+ let a_output = vm. decode ( a. output ) . map_err ( PrfError :: vm) ?;
113
+ self . state = PrfState :: ComputeP {
114
+ iter : * iter,
115
+ inner_partial : * inner_partial,
116
+ a_output,
117
+ } ;
118
+ }
119
+ PrfState :: ComputeP {
120
+ iter,
121
+ inner_partial,
122
+ a_output,
123
+ } => {
124
+ let Some ( output) = a_output. try_recv ( ) . map_err ( PrfError :: vm) ? else {
125
+ return Ok ( ( ) ) ;
126
+ } ;
127
+ let p = & self . p [ * iter] ;
128
+
129
+ let mut msg = output. to_vec ( ) ;
130
+ msg. extend_from_slice ( & self . start_seed_label ) ;
131
+
132
+ assign_inner_local ( vm, p. inner_local , * inner_partial, & msg) ?;
133
+
134
+ if * iter == self . iterations {
135
+ self . state = PrfState :: ComputeLastP ;
136
+ } else {
137
+ self . state = PrfState :: ComputeA {
138
+ iter : * iter + 1 ,
139
+ inner_partial : * inner_partial,
140
+ msg : output. to_vec ( ) ,
104
141
}
105
- }
106
- State :: Assigned { .. } => {
107
- p. state = State :: Done ;
108
- }
109
- _ => ( ) ,
142
+ } ;
110
143
}
144
+ PrfState :: ComputeLastP => self . state = PrfState :: Done ,
145
+ _ => ( ) ,
111
146
}
112
147
113
148
Ok ( ( ) )
@@ -117,8 +152,7 @@ impl PrfFunction {
117
152
let mut start_seed_label = self . label . to_vec ( ) ;
118
153
start_seed_label. extend_from_slice ( & seed) ;
119
154
120
- self . start_seed_label = start_seed_label. clone ( ) ;
121
- self . a_msg = start_seed_label;
155
+ self . start_seed_label = start_seed_label;
122
156
}
123
157
124
158
pub ( crate ) fn output ( & self ) -> Vec < Array < U8 , 32 > > {
@@ -132,6 +166,10 @@ impl PrfFunction {
132
166
inner_partial : Sha256 ,
133
167
len : usize ,
134
168
) -> Result < Self , PrfError > {
169
+ assert ! ( len > 0 , "cannot compute 0 bytes for prf" ) ;
170
+
171
+ let iterations = len / 32 + ( ( len % 32 ) != 0 ) as usize ;
172
+
135
173
let ( inner_partial, _) = inner_partial
136
174
. state ( )
137
175
. expect ( "state should be set for inner_partial" ) ;
@@ -140,100 +178,55 @@ impl PrfFunction {
140
178
let mut prf = Self {
141
179
label,
142
180
start_seed_label : vec ! [ ] ,
143
- a_msg : vec ! [ ] ,
144
- inner_partial : InnerPartial :: Decoding ( inner_partial) ,
181
+ // used for indexing, so we need to subtract one here
182
+ iterations : iterations - 1 ,
183
+ state : PrfState :: InnerPartial { inner_partial } ,
145
184
a : vec ! [ ] ,
146
185
p : vec ! [ ] ,
147
186
} ;
148
187
149
- assert ! ( len > 0 , "cannot compute 0 bytes for prf" ) ;
150
-
151
- let iterations = len / 32 + ( ( len % 32 ) != 0 ) as usize ;
152
-
153
188
for _ in 0 ..iterations {
154
- let a = PHash :: alloc ( vm, outer_partial. clone ( ) ) ?;
155
- prf. a . push ( a) ;
156
-
157
- let p = PHash :: alloc ( vm, outer_partial. clone ( ) ) ?;
158
- prf. p . push ( p) ;
189
+ // setup A[i]
190
+ let inner_local: Array < U8 , 32 > = vm. alloc ( ) . map_err ( PrfError :: vm) ?;
191
+ let output = hmac_sha256 ( vm, outer_partial. clone ( ) , inner_local) ?;
192
+ let p_hash = PHash {
193
+ inner_local,
194
+ output,
195
+ } ;
196
+ prf. a . push ( p_hash) ;
197
+
198
+ // setup P[i]
199
+ let inner_local: Array < U8 , 32 > = vm. alloc ( ) . map_err ( PrfError :: vm) ?;
200
+ let output = hmac_sha256 ( vm, outer_partial. clone ( ) , inner_local) ?;
201
+ let p_hash = PHash {
202
+ inner_local,
203
+ output,
204
+ } ;
205
+ prf. p . push ( p_hash) ;
159
206
}
160
207
161
208
Ok ( prf)
162
209
}
163
210
}
164
211
165
- #[ derive( Debug ) ]
166
- struct PHash {
167
- output : Array < U8 , 32 > ,
168
- state : State ,
169
- }
170
-
171
- impl PHash {
172
- fn alloc ( vm : & mut dyn Vm < Binary > , outer_partial : Sha256 ) -> Result < Self , PrfError > {
173
- let inner_local: Array < U8 , 32 > = vm. alloc ( ) . map_err ( PrfError :: vm) ?;
174
- let output = hmac_sha256 ( vm, outer_partial, inner_local) ?;
175
-
176
- let p_hash = Self {
177
- state : State :: Init { inner_local } ,
178
- output,
179
- } ;
180
-
181
- Ok ( p_hash)
182
- }
183
-
184
- fn assign_inner_local (
185
- & mut self ,
186
- vm : & mut dyn Vm < Binary > ,
187
- inner_partial : [ u32 ; 8 ] ,
188
- msg : & [ u8 ] ,
189
- ) -> Result < ( ) , PrfError > {
190
- if let State :: Init { inner_local, .. } = self . state {
191
- let inner_local_value = sha256 ( inner_partial, 64 , msg) ;
192
-
193
- vm. mark_public ( inner_local) . map_err ( PrfError :: vm) ?;
194
- vm. assign ( inner_local, state_to_bytes ( inner_local_value) )
195
- . map_err ( PrfError :: vm) ?;
196
- vm. commit ( inner_local) . map_err ( PrfError :: vm) ?;
197
-
198
- let output = vm. decode ( self . output ) . map_err ( PrfError :: vm) ?;
199
- self . state = State :: Assigned { output } ;
200
- }
201
-
202
- Ok ( ( ) )
203
- }
204
- }
212
+ fn assign_inner_local (
213
+ vm : & mut dyn Vm < Binary > ,
214
+ inner_local : Array < U8 , 32 > ,
215
+ inner_partial : [ u32 ; 8 ] ,
216
+ msg : & [ u8 ] ,
217
+ ) -> Result < ( ) , PrfError > {
218
+ let inner_local_value = sha256 ( inner_partial, 64 , msg) ;
205
219
206
- #[ derive( Debug ) ]
207
- enum State {
208
- Init {
209
- inner_local : Array < U8 , 32 > ,
210
- } ,
211
- Assigned {
212
- output : DecodeFutureTyped < BitVec , [ u8 ; 32 ] > ,
213
- } ,
214
- Decoded {
215
- output : Vec < u8 > ,
216
- } ,
217
- Done ,
218
- }
220
+ vm. mark_public ( inner_local) . map_err ( PrfError :: vm) ?;
221
+ vm. assign ( inner_local, state_to_bytes ( inner_local_value) )
222
+ . map_err ( PrfError :: vm) ?;
223
+ vm. commit ( inner_local) . map_err ( PrfError :: vm) ?;
219
224
220
- #[ derive( Debug ) ]
221
- enum InnerPartial {
222
- Decoding ( DecodeFutureTyped < BitVec , [ u32 ; 8 ] > ) ,
223
- Finished ( [ u32 ; 8 ] ) ,
225
+ Ok ( ( ) )
224
226
}
225
227
226
- impl InnerPartial {
227
- pub ( crate ) fn try_recv ( & mut self ) -> Result < Option < [ u32 ; 8 ] > , PrfError > {
228
- match self {
229
- InnerPartial :: Decoding ( value) => {
230
- let value = value. try_recv ( ) . map_err ( PrfError :: vm) ?;
231
- if let Some ( value) = value {
232
- * self = InnerPartial :: Finished ( value) ;
233
- }
234
- Ok ( value)
235
- }
236
- InnerPartial :: Finished ( value) => Ok ( Some ( * value) ) ,
237
- }
238
- }
228
+ #[ derive( Debug , Clone , Copy ) ]
229
+ struct PHash {
230
+ inner_local : Array < U8 , 32 > ,
231
+ output : Array < U8 , 32 > ,
239
232
}
0 commit comments