4
4
#![ deny( clippy:: all) ]
5
5
#![ forbid( unsafe_code) ]
6
6
7
- use std:: {
8
- mem,
9
- sync:: {
10
- atomic:: { AtomicBool , Ordering } ,
11
- Arc ,
12
- } ,
13
- } ;
7
+ mod map;
8
+
9
+ use std:: { mem, sync:: Arc } ;
14
10
15
11
use async_trait:: async_trait;
16
12
use mpz_common:: Context ;
@@ -38,11 +34,13 @@ pub struct Deap<Mpc, Zk> {
38
34
role : Role ,
39
35
mpc : Arc < Mutex < Mpc > > ,
40
36
zk : Arc < Mutex < Zk > > ,
37
+ /// Mapping between the memories of the MPC and ZK VMs.
38
+ memory_map : map:: MemoryMap ,
39
+ /// Ranges of the follower's private inputs.
40
+ follower_input_ranges : RangeSet < usize > ,
41
41
/// Private inputs of the follower.
42
- follower_inputs : RangeSet < usize > ,
42
+ follower_inputs : Vec < Slice > ,
43
43
outputs : Vec < ( Slice , DecodeFuture < BitVec > ) > ,
44
- /// Whether the memories of the two VMs are potentially desynchronized.
45
- desync : AtomicBool ,
46
44
}
47
45
48
46
impl < Mpc , Zk > Deap < Mpc , Zk > {
@@ -52,9 +50,10 @@ impl<Mpc, Zk> Deap<Mpc, Zk> {
52
50
role,
53
51
mpc : Arc :: new ( Mutex :: new ( mpc) ) ,
54
52
zk : Arc :: new ( Mutex :: new ( zk) ) ,
55
- follower_inputs : RangeSet :: default ( ) ,
53
+ memory_map : map:: MemoryMap :: default ( ) ,
54
+ follower_input_ranges : RangeSet :: default ( ) ,
55
+ follower_inputs : Vec :: default ( ) ,
56
56
outputs : Vec :: default ( ) ,
57
- desync : AtomicBool :: new ( false ) ,
58
57
}
59
58
}
60
59
@@ -68,34 +67,28 @@ impl<Mpc, Zk> Deap<Mpc, Zk> {
68
67
69
68
/// Returns a mutable reference to the ZK VM.
70
69
///
71
- /// # Note
72
- ///
73
- /// After calling this method, allocations will no longer be allowed in the
74
- /// DEAP VM as the memory will potentially be desynchronized.
75
- ///
76
70
/// # Panics
77
71
///
78
72
/// Panics if the mutex is locked by another thread.
79
73
pub fn zk ( & self ) -> MutexGuard < ' _ , Zk > {
80
- self . desync . store ( true , Ordering :: Relaxed ) ;
81
74
self . zk . try_lock ( ) . unwrap ( )
82
75
}
83
76
84
77
/// Returns an owned mutex guard to the ZK VM.
85
78
///
86
- /// # Note
87
- ///
88
- /// After calling this method, allocations will no longer be allowed in the
89
- /// DEAP VM as the memory will potentially be desynchronized.
90
- ///
91
79
/// # Panics
92
80
///
93
81
/// Panics if the mutex is locked by another thread.
94
82
pub fn zk_owned ( & self ) -> OwnedMutexGuard < Zk > {
95
- self . desync . store ( true , Ordering :: Relaxed ) ;
96
83
self . zk . clone ( ) . try_lock_owned ( ) . unwrap ( )
97
84
}
98
85
86
+ /// Translates a slice from the MPC VM address space to the ZK VM address
87
+ /// space.
88
+ pub fn translate_slice ( & self , slice : Slice ) -> Result < Slice , VmError > {
89
+ self . memory_map . try_get ( slice)
90
+ }
91
+
99
92
#[ cfg( test) ]
100
93
fn mpc ( & self ) -> MutexGuard < ' _ , Mpc > {
101
94
self . mpc . try_lock ( ) . unwrap ( )
@@ -124,18 +117,15 @@ where
124
117
// MACs.
125
118
let input_futs = self
126
119
. follower_inputs
127
- . iter_ranges ( )
128
- . map ( |input| mpc. decode_raw ( Slice :: from_range_unchecked ( input) ) )
120
+ . iter ( )
121
+ . map ( |& input| mpc. decode_raw ( input) )
129
122
. collect :: < Result < Vec < _ > , _ > > ( ) ?;
130
123
131
124
mpc. execute_all ( ctx) . await ?;
132
125
133
126
// Assign inputs to the ZK VM.
134
- for ( mut decode, input) in input_futs
135
- . into_iter ( )
136
- . zip ( self . follower_inputs . iter_ranges ( ) )
137
- {
138
- let input = Slice :: from_range_unchecked ( input) ;
127
+ for ( mut decode, & input) in input_futs. into_iter ( ) . zip ( & self . follower_inputs ) {
128
+ let input = self . memory_map . try_get ( input) ?;
139
129
140
130
// Follower has already assigned the inputs.
141
131
if let Role :: Leader = self . role {
@@ -185,30 +175,35 @@ where
185
175
type Error = VmError ;
186
176
187
177
fn alloc_raw ( & mut self , size : usize ) -> Result < Slice , VmError > {
188
- if self . desync . load ( Ordering :: Relaxed ) {
189
- return Err ( VmError :: memory (
190
- "DEAP VM memories are potentially desynchronized" ,
191
- ) ) ;
192
- }
178
+ let mpc_slice = self . mpc . try_lock ( ) . unwrap ( ) . alloc_raw ( size) ?;
179
+ let zk_slice = self . zk . try_lock ( ) . unwrap ( ) . alloc_raw ( size) ?;
193
180
194
- self . zk . try_lock ( ) . unwrap ( ) . alloc_raw ( size) ?;
195
- self . mpc . try_lock ( ) . unwrap ( ) . alloc_raw ( size)
181
+ self . memory_map . insert ( mpc_slice, zk_slice) ;
182
+
183
+ Ok ( mpc_slice)
196
184
}
197
185
198
186
fn assign_raw ( & mut self , slice : Slice , data : BitVec ) -> Result < ( ) , VmError > {
199
- self . zk
187
+ self . mpc
200
188
. try_lock ( )
201
189
. unwrap ( )
202
190
. assign_raw ( slice, data. clone ( ) ) ?;
203
- self . mpc . try_lock ( ) . unwrap ( ) . assign_raw ( slice, data)
191
+
192
+ self . zk
193
+ . try_lock ( )
194
+ . unwrap ( )
195
+ . assign_raw ( self . memory_map . try_get ( slice) ?, data)
204
196
}
205
197
206
198
fn commit_raw ( & mut self , slice : Slice ) -> Result < ( ) , VmError > {
207
199
// Follower's private inputs are not committed in the ZK VM until finalization.
208
- let input_minus_follower = slice. to_range ( ) . difference ( & self . follower_inputs ) ;
200
+ let input_minus_follower = slice. to_range ( ) . difference ( & self . follower_input_ranges ) ;
209
201
let mut zk = self . zk . try_lock ( ) . unwrap ( ) ;
210
202
for input in input_minus_follower. iter_ranges ( ) {
211
- zk. commit_raw ( Slice :: from_range_unchecked ( input) ) ?;
203
+ zk. commit_raw (
204
+ self . memory_map
205
+ . try_get ( Slice :: from_range_unchecked ( input) ) ?,
206
+ ) ?;
212
207
}
213
208
214
209
self . mpc . try_lock ( ) . unwrap ( ) . commit_raw ( slice)
@@ -219,7 +214,11 @@ where
219
214
}
220
215
221
216
fn decode_raw ( & mut self , slice : Slice ) -> Result < DecodeFuture < BitVec > , VmError > {
222
- let fut = self . zk . try_lock ( ) . unwrap ( ) . decode_raw ( slice) ?;
217
+ let fut = self
218
+ . zk
219
+ . try_lock ( )
220
+ . unwrap ( )
221
+ . decode_raw ( self . memory_map . try_get ( slice) ?) ?;
223
222
self . outputs . push ( ( slice, fut) ) ;
224
223
225
224
self . mpc . try_lock ( ) . unwrap ( ) . decode_raw ( slice)
@@ -234,23 +233,27 @@ where
234
233
type Error = VmError ;
235
234
236
235
fn mark_public_raw ( & mut self , slice : Slice ) -> Result < ( ) , VmError > {
237
- self . zk . try_lock ( ) . unwrap ( ) . mark_public_raw ( slice) ?;
238
- self . mpc . try_lock ( ) . unwrap ( ) . mark_public_raw ( slice)
236
+ self . mpc . try_lock ( ) . unwrap ( ) . mark_public_raw ( slice) ?;
237
+ self . zk
238
+ . try_lock ( )
239
+ . unwrap ( )
240
+ . mark_public_raw ( self . memory_map . try_get ( slice) ?)
239
241
}
240
242
241
243
fn mark_private_raw ( & mut self , slice : Slice ) -> Result < ( ) , VmError > {
242
244
let mut zk = self . zk . try_lock ( ) . unwrap ( ) ;
243
245
let mut mpc = self . mpc . try_lock ( ) . unwrap ( ) ;
244
246
match self . role {
245
247
Role :: Leader => {
246
- zk. mark_private_raw ( slice) ?;
247
248
mpc. mark_private_raw ( slice) ?;
249
+ zk. mark_private_raw ( self . memory_map . try_get ( slice) ?) ?;
248
250
}
249
251
Role :: Follower => {
250
- // Follower's private inputs will become public during finalization.
251
- zk. mark_public_raw ( slice) ?;
252
252
mpc. mark_private_raw ( slice) ?;
253
- self . follower_inputs . union_mut ( & slice. to_range ( ) ) ;
253
+ // Follower's private inputs will become public during finalization.
254
+ zk. mark_public_raw ( self . memory_map . try_get ( slice) ?) ?;
255
+ self . follower_input_ranges . union_mut ( & slice. to_range ( ) ) ;
256
+ self . follower_inputs . push ( slice) ;
254
257
}
255
258
}
256
259
@@ -262,14 +265,15 @@ where
262
265
let mut mpc = self . mpc . try_lock ( ) . unwrap ( ) ;
263
266
match self . role {
264
267
Role :: Leader => {
265
- // Follower's private inputs will become public during finalization.
266
- zk. mark_public_raw ( slice) ?;
267
268
mpc. mark_blind_raw ( slice) ?;
268
- self . follower_inputs . union_mut ( & slice. to_range ( ) ) ;
269
+ // Follower's private inputs will become public during finalization.
270
+ zk. mark_public_raw ( self . memory_map . try_get ( slice) ?) ?;
271
+ self . follower_input_ranges . union_mut ( & slice. to_range ( ) ) ;
272
+ self . follower_inputs . push ( slice) ;
269
273
}
270
274
Role :: Follower => {
271
- zk. mark_blind_raw ( slice) ?;
272
275
mpc. mark_blind_raw ( slice) ?;
276
+ zk. mark_blind_raw ( self . memory_map . try_get ( slice) ?) ?;
273
277
}
274
278
}
275
279
@@ -283,14 +287,21 @@ where
283
287
Zk : Vm < Binary > ,
284
288
{
285
289
fn call_raw ( & mut self , call : Call ) -> Result < Slice , VmError > {
286
- if self . desync . load ( Ordering :: Relaxed ) {
287
- return Err ( VmError :: memory (
288
- "DEAP VM memories are potentially desynchronized" ,
289
- ) ) ;
290
+ let ( circ, inputs) = call. clone ( ) . into_parts ( ) ;
291
+ let mut builder = Call :: builder ( circ) ;
292
+
293
+ for input in inputs {
294
+ builder = builder. arg ( self . memory_map . try_get ( input) ?) ;
290
295
}
291
296
292
- self . zk . try_lock ( ) . unwrap ( ) . call_raw ( call. clone ( ) ) ?;
293
- self . mpc . try_lock ( ) . unwrap ( ) . call_raw ( call)
297
+ let zk_call = builder. build ( ) . expect ( "call should be valid" ) ;
298
+
299
+ let output = self . mpc . try_lock ( ) . unwrap ( ) . call_raw ( call) ?;
300
+ let zk_output = self . zk . try_lock ( ) . unwrap ( ) . call_raw ( zk_call) ?;
301
+
302
+ self . memory_map . insert ( output, zk_output) ;
303
+
304
+ Ok ( output)
294
305
}
295
306
}
296
307
@@ -451,6 +462,90 @@ mod tests {
451
462
assert_eq ! ( ct_leader, ct_follower) ;
452
463
}
453
464
465
+ #[ tokio:: test]
466
+ async fn test_deap_desync_memory ( ) {
467
+ let mut rng = StdRng :: seed_from_u64 ( 0 ) ;
468
+ let delta_mpc = Delta :: random ( & mut rng) ;
469
+ let delta_zk = Delta :: random ( & mut rng) ;
470
+
471
+ let ( mut ctx_a, mut ctx_b) = test_st_context ( 8 ) ;
472
+ let ( rcot_send, rcot_recv) = ideal_rcot ( Block :: ZERO , delta_zk. into_inner ( ) ) ;
473
+ let ( cot_send, cot_recv) = ideal_cot ( delta_mpc. into_inner ( ) ) ;
474
+
475
+ let gb = Garbler :: new ( cot_send, [ 0u8 ; 16 ] , delta_mpc) ;
476
+ let ev = Evaluator :: new ( cot_recv) ;
477
+ let prover = Prover :: new ( rcot_recv) ;
478
+ let verifier = Verifier :: new ( delta_zk, rcot_send) ;
479
+
480
+ let mut leader = Deap :: new ( Role :: Leader , gb, prover) ;
481
+ let mut follower = Deap :: new ( Role :: Follower , ev, verifier) ;
482
+
483
+ // Desynchronize the memories.
484
+ let _ = leader. zk ( ) . alloc_raw ( 1 ) . unwrap ( ) ;
485
+ let _ = follower. zk ( ) . alloc_raw ( 1 ) . unwrap ( ) ;
486
+
487
+ let ( ct_leader, ct_follower) = futures:: join!(
488
+ async {
489
+ let key: Array <U8 , 16 > = leader. alloc( ) . unwrap( ) ;
490
+ let msg: Array <U8 , 16 > = leader. alloc( ) . unwrap( ) ;
491
+
492
+ leader. mark_private( key) . unwrap( ) ;
493
+ leader. mark_blind( msg) . unwrap( ) ;
494
+ leader. assign( key, [ 42u8 ; 16 ] ) . unwrap( ) ;
495
+ leader. commit( key) . unwrap( ) ;
496
+ leader. commit( msg) . unwrap( ) ;
497
+
498
+ let ct: Array <U8 , 16 > = leader
499
+ . call(
500
+ Call :: builder( AES128 . clone( ) )
501
+ . arg( key)
502
+ . arg( msg)
503
+ . build( )
504
+ . unwrap( ) ,
505
+ )
506
+ . unwrap( ) ;
507
+ let ct = leader. decode( ct) . unwrap( ) ;
508
+
509
+ leader. flush( & mut ctx_a) . await . unwrap( ) ;
510
+ leader. execute( & mut ctx_a) . await . unwrap( ) ;
511
+ leader. flush( & mut ctx_a) . await . unwrap( ) ;
512
+ leader. finalize( & mut ctx_a) . await . unwrap( ) ;
513
+
514
+ ct. await . unwrap( )
515
+ } ,
516
+ async {
517
+ let key: Array <U8 , 16 > = follower. alloc( ) . unwrap( ) ;
518
+ let msg: Array <U8 , 16 > = follower. alloc( ) . unwrap( ) ;
519
+
520
+ follower. mark_blind( key) . unwrap( ) ;
521
+ follower. mark_private( msg) . unwrap( ) ;
522
+ follower. assign( msg, [ 69u8 ; 16 ] ) . unwrap( ) ;
523
+ follower. commit( key) . unwrap( ) ;
524
+ follower. commit( msg) . unwrap( ) ;
525
+
526
+ let ct: Array <U8 , 16 > = follower
527
+ . call(
528
+ Call :: builder( AES128 . clone( ) )
529
+ . arg( key)
530
+ . arg( msg)
531
+ . build( )
532
+ . unwrap( ) ,
533
+ )
534
+ . unwrap( ) ;
535
+ let ct = follower. decode( ct) . unwrap( ) ;
536
+
537
+ follower. flush( & mut ctx_b) . await . unwrap( ) ;
538
+ follower. execute( & mut ctx_b) . await . unwrap( ) ;
539
+ follower. flush( & mut ctx_b) . await . unwrap( ) ;
540
+ follower. finalize( & mut ctx_b) . await . unwrap( ) ;
541
+
542
+ ct. await . unwrap( )
543
+ }
544
+ ) ;
545
+
546
+ assert_eq ! ( ct_leader, ct_follower) ;
547
+ }
548
+
454
549
// Tests that the leader can not use different inputs in each VM without
455
550
// detection by the follower.
456
551
#[ tokio:: test]
0 commit comments