Skip to content

Commit 3ce2dff

Browse files
committed
simplify state logic for reduced prf even more
1 parent 93f4a08 commit 3ce2dff

File tree

2 files changed

+116
-122
lines changed

2 files changed

+116
-122
lines changed

crates/components/hmac-sha256/src/prf/function.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ mod tests {
234234
.unwrap();
235235
}
236236

237+
assert_eq!(prf_out_leader.len(), 2);
237238
assert_eq!(prf_out_leader.len(), prf_out_follower.len());
238239

239240
let prf_result_leader: Vec<u8> = prf_out_leader

crates/components/hmac-sha256/src/prf/function/reduced.rs

Lines changed: 115 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,31 @@ pub(crate) struct PrfFunction {
1717
label: &'static [u8],
1818
// The start seed and the label, e.g. client_random + server_random + "master_secret".
1919
start_seed_label: Vec<u8>,
20-
// The current HMAC message needed for a[i]
21-
a_msg: Vec<u8>,
22-
inner_partial: InnerPartial,
20+
iterations: usize,
21+
state: PrfState,
2322
a: Vec<PHash>,
2423
p: Vec<PHash>,
2524
}
2625

26+
#[derive(Debug)]
27+
enum PrfState {
28+
InnerPartial {
29+
inner_partial: DecodeFutureTyped<BitVec, [u32; 8]>,
30+
},
31+
ComputeA {
32+
iter: usize,
33+
inner_partial: [u32; 8],
34+
msg: Vec<u8>,
35+
},
36+
ComputeP {
37+
iter: usize,
38+
inner_partial: [u32; 8],
39+
a_output: DecodeFutureTyped<BitVec, [u8; 32]>,
40+
},
41+
ComputeLastP,
42+
Done,
43+
}
44+
2745
impl PrfFunction {
2846
const MS_LABEL: &[u8] = b"master secret";
2947
const KEY_LABEL: &[u8] = b"key expansion";
@@ -63,51 +81,68 @@ impl PrfFunction {
6381
}
6482

6583
pub(crate) fn wants_flush(&mut self) -> bool {
66-
let last_p = self.p.last().expect("Prf should be allocated");
67-
68-
if let State::Done = last_p.state {
84+
if let PrfState::Done = self.state {
6985
return false;
7086
}
7187
true
7288
}
7389

7490
pub(crate) fn flush(&mut self, vm: &mut dyn Vm<Binary>) -> Result<(), PrfError> {
75-
let inner_partial = self.inner_partial.try_recv()?;
76-
let Some(inner_partial) = inner_partial else {
77-
return Ok(());
78-
};
79-
80-
for (a, p) in self.a.iter_mut().zip(self.p.iter_mut()) {
81-
match &mut a.state {
82-
State::Init { .. } => {
83-
a.assign_inner_local(vm, inner_partial, &self.a_msg)?;
84-
break;
85-
}
86-
State::Assigned { output } => {
87-
if let Some(output) = output.try_recv().map_err(PrfError::vm)? {
88-
let output = output.to_vec();
89-
a.state = State::Decoded {
90-
output: output.clone(),
91-
};
92-
self.a_msg = output;
93-
}
94-
}
95-
_ => (),
91+
match &mut self.state {
92+
PrfState::InnerPartial { inner_partial } => {
93+
let Some(inner_partial) = inner_partial.try_recv().map_err(PrfError::vm)? else {
94+
return Ok(());
95+
};
96+
97+
self.state = PrfState::ComputeA {
98+
iter: 0,
99+
inner_partial,
100+
msg: self.start_seed_label.clone(),
101+
};
102+
self.flush(vm)?;
96103
}
97-
98-
match &mut p.state {
99-
State::Init { .. } => {
100-
if let State::Decoded { output } = &a.state {
101-
let mut p_msg = output.to_vec();
102-
p_msg.extend_from_slice(&self.start_seed_label);
103-
p.assign_inner_local(vm, inner_partial, &p_msg)?;
104+
PrfState::ComputeA {
105+
iter,
106+
inner_partial,
107+
msg,
108+
} => {
109+
let a = &self.a[*iter];
110+
assign_inner_local(vm, a.inner_local, *inner_partial, msg)?;
111+
112+
let a_output = vm.decode(a.output).map_err(PrfError::vm)?;
113+
self.state = PrfState::ComputeP {
114+
iter: *iter,
115+
inner_partial: *inner_partial,
116+
a_output,
117+
};
118+
}
119+
PrfState::ComputeP {
120+
iter,
121+
inner_partial,
122+
a_output,
123+
} => {
124+
let Some(output) = a_output.try_recv().map_err(PrfError::vm)? else {
125+
return Ok(());
126+
};
127+
let p = &self.p[*iter];
128+
129+
let mut msg = output.to_vec();
130+
msg.extend_from_slice(&self.start_seed_label);
131+
132+
assign_inner_local(vm, p.inner_local, *inner_partial, &msg)?;
133+
134+
if *iter == self.iterations {
135+
self.state = PrfState::ComputeLastP;
136+
} else {
137+
self.state = PrfState::ComputeA {
138+
iter: *iter + 1,
139+
inner_partial: *inner_partial,
140+
msg: output.to_vec(),
104141
}
105-
}
106-
State::Assigned { .. } => {
107-
p.state = State::Done;
108-
}
109-
_ => (),
142+
};
110143
}
144+
PrfState::ComputeLastP => self.state = PrfState::Done,
145+
_ => (),
111146
}
112147

113148
Ok(())
@@ -117,8 +152,7 @@ impl PrfFunction {
117152
let mut start_seed_label = self.label.to_vec();
118153
start_seed_label.extend_from_slice(&seed);
119154

120-
self.start_seed_label = start_seed_label.clone();
121-
self.a_msg = start_seed_label;
155+
self.start_seed_label = start_seed_label;
122156
}
123157

124158
pub(crate) fn output(&self) -> Vec<Array<U8, 32>> {
@@ -132,6 +166,10 @@ impl PrfFunction {
132166
inner_partial: Sha256,
133167
len: usize,
134168
) -> Result<Self, PrfError> {
169+
assert!(len > 0, "cannot compute 0 bytes for prf");
170+
171+
let iterations = len / 32 + ((len % 32) != 0) as usize;
172+
135173
let (inner_partial, _) = inner_partial
136174
.state()
137175
.expect("state should be set for inner_partial");
@@ -140,100 +178,55 @@ impl PrfFunction {
140178
let mut prf = Self {
141179
label,
142180
start_seed_label: vec![],
143-
a_msg: vec![],
144-
inner_partial: InnerPartial::Decoding(inner_partial),
181+
// used for indexing, so we need to subtract one here
182+
iterations: iterations - 1,
183+
state: PrfState::InnerPartial { inner_partial },
145184
a: vec![],
146185
p: vec![],
147186
};
148187

149-
assert!(len > 0, "cannot compute 0 bytes for prf");
150-
151-
let iterations = len / 32 + ((len % 32) != 0) as usize;
152-
153188
for _ in 0..iterations {
154-
let a = PHash::alloc(vm, outer_partial.clone())?;
155-
prf.a.push(a);
156-
157-
let p = PHash::alloc(vm, outer_partial.clone())?;
158-
prf.p.push(p);
189+
// setup A[i]
190+
let inner_local: Array<U8, 32> = vm.alloc().map_err(PrfError::vm)?;
191+
let output = hmac_sha256(vm, outer_partial.clone(), inner_local)?;
192+
let p_hash = PHash {
193+
inner_local,
194+
output,
195+
};
196+
prf.a.push(p_hash);
197+
198+
// setup P[i]
199+
let inner_local: Array<U8, 32> = vm.alloc().map_err(PrfError::vm)?;
200+
let output = hmac_sha256(vm, outer_partial.clone(), inner_local)?;
201+
let p_hash = PHash {
202+
inner_local,
203+
output,
204+
};
205+
prf.p.push(p_hash);
159206
}
160207

161208
Ok(prf)
162209
}
163210
}
164211

165-
#[derive(Debug)]
166-
struct PHash {
167-
output: Array<U8, 32>,
168-
state: State,
169-
}
170-
171-
impl PHash {
172-
fn alloc(vm: &mut dyn Vm<Binary>, outer_partial: Sha256) -> Result<Self, PrfError> {
173-
let inner_local: Array<U8, 32> = vm.alloc().map_err(PrfError::vm)?;
174-
let output = hmac_sha256(vm, outer_partial, inner_local)?;
175-
176-
let p_hash = Self {
177-
state: State::Init { inner_local },
178-
output,
179-
};
180-
181-
Ok(p_hash)
182-
}
183-
184-
fn assign_inner_local(
185-
&mut self,
186-
vm: &mut dyn Vm<Binary>,
187-
inner_partial: [u32; 8],
188-
msg: &[u8],
189-
) -> Result<(), PrfError> {
190-
if let State::Init { inner_local, .. } = self.state {
191-
let inner_local_value = sha256(inner_partial, 64, msg);
192-
193-
vm.mark_public(inner_local).map_err(PrfError::vm)?;
194-
vm.assign(inner_local, state_to_bytes(inner_local_value))
195-
.map_err(PrfError::vm)?;
196-
vm.commit(inner_local).map_err(PrfError::vm)?;
197-
198-
let output = vm.decode(self.output).map_err(PrfError::vm)?;
199-
self.state = State::Assigned { output };
200-
}
201-
202-
Ok(())
203-
}
204-
}
212+
fn assign_inner_local(
213+
vm: &mut dyn Vm<Binary>,
214+
inner_local: Array<U8, 32>,
215+
inner_partial: [u32; 8],
216+
msg: &[u8],
217+
) -> Result<(), PrfError> {
218+
let inner_local_value = sha256(inner_partial, 64, msg);
205219

206-
#[derive(Debug)]
207-
enum State {
208-
Init {
209-
inner_local: Array<U8, 32>,
210-
},
211-
Assigned {
212-
output: DecodeFutureTyped<BitVec, [u8; 32]>,
213-
},
214-
Decoded {
215-
output: Vec<u8>,
216-
},
217-
Done,
218-
}
220+
vm.mark_public(inner_local).map_err(PrfError::vm)?;
221+
vm.assign(inner_local, state_to_bytes(inner_local_value))
222+
.map_err(PrfError::vm)?;
223+
vm.commit(inner_local).map_err(PrfError::vm)?;
219224

220-
#[derive(Debug)]
221-
enum InnerPartial {
222-
Decoding(DecodeFutureTyped<BitVec, [u32; 8]>),
223-
Finished([u32; 8]),
225+
Ok(())
224226
}
225227

226-
impl InnerPartial {
227-
pub(crate) fn try_recv(&mut self) -> Result<Option<[u32; 8]>, PrfError> {
228-
match self {
229-
InnerPartial::Decoding(value) => {
230-
let value = value.try_recv().map_err(PrfError::vm)?;
231-
if let Some(value) = value {
232-
*self = InnerPartial::Finished(value);
233-
}
234-
Ok(value)
235-
}
236-
InnerPartial::Finished(value) => Ok(Some(*value)),
237-
}
238-
}
228+
#[derive(Debug, Clone, Copy)]
229+
struct PHash {
230+
inner_local: Array<U8, 32>,
231+
output: Array<U8, 32>,
239232
}

0 commit comments

Comments
 (0)