@@ -299,7 +299,7 @@ fn merge_outputs(
299299 assert ! ( output_bytes <= 32 * inputs. len( ) ) ;
300300
301301 let bits = Array :: < U32 , 8 > :: SIZE * inputs. len ( ) ;
302- let id_circ = identity_circuit ( bits) ;
302+ let id_circ = merge_to_big_endian ( 4 , bits) ;
303303
304304 let mut builder = Call :: builder ( id_circ) ;
305305 for & input in inputs. iter ( ) {
@@ -312,14 +312,89 @@ fn merge_outputs(
312312 Ok ( output)
313313}
314314
315- fn identity_circuit ( size : usize ) -> Arc < Circuit > {
315+ fn merge_to_big_endian ( element_byte_size : usize , size : usize ) -> Arc < Circuit > {
316+ assert ! ( ( size / 8 ) % element_byte_size == 0 ) ;
317+
316318 let mut builder = CircuitBuilder :: new ( ) ;
317319 let inputs = ( 0 ..size) . map ( |_| builder. add_input ( ) ) . collect :: < Vec < _ > > ( ) ;
318320
319- for input in inputs. into_iter ( ) {
320- let output = builder. add_id_gate ( input) ;
321- builder. add_output ( output) ;
321+ for input in inputs. chunks_exact ( element_byte_size * 8 ) {
322+ for byte in input. chunks_exact ( 8 ) . rev ( ) {
323+ for & feed in byte. iter ( ) {
324+ let output = builder. add_id_gate ( feed) ;
325+ builder. add_output ( output) ;
326+ }
327+ }
322328 }
323329
324330 Arc :: new ( builder. build ( ) . expect ( "identity circuit is valid" ) )
325331}
332+
333+ #[ cfg( test) ]
334+ mod tests {
335+ use crate :: { convert_to_bytes, prf:: merge_outputs, test_utils:: mock_vm} ;
336+ use mpz_common:: context:: test_st_context;
337+ use mpz_vm_core:: {
338+ memory:: { binary:: U32 , Array , MemoryExt , ViewExt } ,
339+ Execute ,
340+ } ;
341+
342+ #[ tokio:: test]
343+ async fn test_merge_outputs ( ) {
344+ let ( mut ctx_a, mut ctx_b) = test_st_context ( 8 ) ;
345+ let ( mut leader, mut follower) = mock_vm ( ) ;
346+
347+ let input1: [ u32 ; 8 ] = std:: array:: from_fn ( |i| i as u32 ) ;
348+ let input2: [ u32 ; 8 ] = std:: array:: from_fn ( |i| i as u32 + 8 ) ;
349+
350+ let mut expected = convert_to_bytes ( input1) . to_vec ( ) ;
351+ expected. extend_from_slice ( & convert_to_bytes ( input2) ) ;
352+ expected. truncate ( 48 ) ;
353+
354+ // leader
355+ let input1_leader: Array < U32 , 8 > = leader. alloc ( ) . unwrap ( ) ;
356+ let input2_leader: Array < U32 , 8 > = leader. alloc ( ) . unwrap ( ) ;
357+
358+ leader. mark_public ( input1_leader) . unwrap ( ) ;
359+ leader. mark_public ( input2_leader) . unwrap ( ) ;
360+
361+ leader. assign ( input1_leader, input1) . unwrap ( ) ;
362+ leader. assign ( input2_leader, input2) . unwrap ( ) ;
363+
364+ leader. commit ( input1_leader) . unwrap ( ) ;
365+ leader. commit ( input2_leader) . unwrap ( ) ;
366+
367+ let merged_leader =
368+ merge_outputs ( & mut leader, vec ! [ input1_leader, input2_leader] , 48 ) . unwrap ( ) ;
369+ let mut merged_leader = leader. decode ( merged_leader) . unwrap ( ) ;
370+
371+ // follower
372+ let input1_follower: Array < U32 , 8 > = follower. alloc ( ) . unwrap ( ) ;
373+ let input2_follower: Array < U32 , 8 > = follower. alloc ( ) . unwrap ( ) ;
374+
375+ follower. mark_public ( input1_follower) . unwrap ( ) ;
376+ follower. mark_public ( input2_follower) . unwrap ( ) ;
377+
378+ follower. assign ( input1_follower, input1) . unwrap ( ) ;
379+ follower. assign ( input2_follower, input2) . unwrap ( ) ;
380+
381+ follower. commit ( input1_follower) . unwrap ( ) ;
382+ follower. commit ( input2_follower) . unwrap ( ) ;
383+
384+ let merged_follower =
385+ merge_outputs ( & mut follower, vec ! [ input1_follower, input2_follower] , 48 ) . unwrap ( ) ;
386+ let mut merged_follower = follower. decode ( merged_follower) . unwrap ( ) ;
387+
388+ tokio:: try_join!(
389+ leader. execute_all( & mut ctx_a) ,
390+ follower. execute_all( & mut ctx_b)
391+ )
392+ . unwrap ( ) ;
393+
394+ let merged_leader = merged_leader. try_recv ( ) . unwrap ( ) . unwrap ( ) ;
395+ let merged_follower = merged_follower. try_recv ( ) . unwrap ( ) . unwrap ( ) ;
396+
397+ assert_eq ! ( merged_leader, merged_follower) ;
398+ assert_eq ! ( merged_leader, expected) ;
399+ }
400+ }
0 commit comments