Skip to content

Commit 860a1e5

Browse files
committed
fix: fix identity circuit endianness and add test
1 parent cd479f2 commit 860a1e5

File tree

1 file changed

+80
-5
lines changed
  • crates/components/hmac-sha256/src/prf

1 file changed

+80
-5
lines changed

crates/components/hmac-sha256/src/prf/mod.rs

+80-5
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ fn merge_outputs(
299299
assert!(output_bytes <= 32 * inputs.len());
300300

301301
let bits = Array::<U32, 8>::SIZE * inputs.len();
302-
let id_circ = identity_circuit(bits);
302+
let id_circ = merge_to_big_endian(4, bits);
303303

304304
let mut builder = Call::builder(id_circ);
305305
for &input in inputs.iter() {
@@ -312,14 +312,89 @@ fn merge_outputs(
312312
Ok(output)
313313
}
314314

315-
fn identity_circuit(size: usize) -> Arc<Circuit> {
315+
fn merge_to_big_endian(element_byte_size: usize, size: usize) -> Arc<Circuit> {
316+
assert!((size / 8) % element_byte_size == 0);
317+
316318
let mut builder = CircuitBuilder::new();
317319
let inputs = (0..size).map(|_| builder.add_input()).collect::<Vec<_>>();
318320

319-
for input in inputs.into_iter() {
320-
let output = builder.add_id_gate(input);
321-
builder.add_output(output);
321+
for input in inputs.chunks_exact(element_byte_size * 8) {
322+
for byte in input.chunks_exact(8).rev() {
323+
for &feed in byte.iter() {
324+
let output = builder.add_id_gate(feed);
325+
builder.add_output(output);
326+
}
327+
}
322328
}
323329

324330
Arc::new(builder.build().expect("identity circuit is valid"))
325331
}
332+
333+
#[cfg(test)]
334+
mod tests {
335+
use crate::{convert_to_bytes, prf::merge_outputs, test_utils::mock_vm};
336+
use mpz_common::context::test_st_context;
337+
use mpz_vm_core::{
338+
memory::{binary::U32, Array, MemoryExt, ViewExt},
339+
Execute,
340+
};
341+
342+
#[tokio::test]
343+
async fn test_merge_outputs() {
344+
let (mut ctx_a, mut ctx_b) = test_st_context(8);
345+
let (mut leader, mut follower) = mock_vm();
346+
347+
let input1: [u32; 8] = std::array::from_fn(|i| i as u32);
348+
let input2: [u32; 8] = std::array::from_fn(|i| i as u32 + 8);
349+
350+
let mut expected = convert_to_bytes(input1).to_vec();
351+
expected.extend_from_slice(&convert_to_bytes(input2));
352+
expected.truncate(48);
353+
354+
// leader
355+
let input1_leader: Array<U32, 8> = leader.alloc().unwrap();
356+
let input2_leader: Array<U32, 8> = leader.alloc().unwrap();
357+
358+
leader.mark_public(input1_leader).unwrap();
359+
leader.mark_public(input2_leader).unwrap();
360+
361+
leader.assign(input1_leader, input1).unwrap();
362+
leader.assign(input2_leader, input2).unwrap();
363+
364+
leader.commit(input1_leader).unwrap();
365+
leader.commit(input2_leader).unwrap();
366+
367+
let merged_leader =
368+
merge_outputs(&mut leader, vec![input1_leader, input2_leader], 48).unwrap();
369+
let mut merged_leader = leader.decode(merged_leader).unwrap();
370+
371+
// follower
372+
let input1_follower: Array<U32, 8> = follower.alloc().unwrap();
373+
let input2_follower: Array<U32, 8> = follower.alloc().unwrap();
374+
375+
follower.mark_public(input1_follower).unwrap();
376+
follower.mark_public(input2_follower).unwrap();
377+
378+
follower.assign(input1_follower, input1).unwrap();
379+
follower.assign(input2_follower, input2).unwrap();
380+
381+
follower.commit(input1_follower).unwrap();
382+
follower.commit(input2_follower).unwrap();
383+
384+
let merged_follower =
385+
merge_outputs(&mut follower, vec![input1_follower, input2_follower], 48).unwrap();
386+
let mut merged_follower = follower.decode(merged_follower).unwrap();
387+
388+
tokio::try_join!(
389+
leader.execute_all(&mut ctx_a),
390+
follower.execute_all(&mut ctx_b)
391+
)
392+
.unwrap();
393+
394+
let merged_leader = merged_leader.try_recv().unwrap().unwrap();
395+
let merged_follower = merged_follower.try_recv().unwrap().unwrap();
396+
397+
assert_eq!(merged_leader, merged_follower);
398+
assert_eq!(merged_leader, expected);
399+
}
400+
}

0 commit comments

Comments
 (0)