Skip to content

Commit d047851

Browse files
committed
feat(deap): address space mapping
1 parent b78bd31 commit d047851

File tree

2 files changed

+268
-60
lines changed

2 files changed

+268
-60
lines changed

crates/components/deap/src/lib.rs

+157-60
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,9 @@
44
#![deny(clippy::all)]
55
#![forbid(unsafe_code)]
66

7-
use std::{
8-
mem,
9-
sync::{
10-
atomic::{AtomicBool, Ordering},
11-
Arc,
12-
},
13-
};
7+
mod map;
8+
9+
use std::{mem, sync::Arc};
1410

1511
use async_trait::async_trait;
1612
use mpz_common::Context;
@@ -38,11 +34,15 @@ pub struct Deap<Mpc, Zk> {
3834
role: Role,
3935
mpc: Arc<Mutex<Mpc>>,
4036
zk: Arc<Mutex<Zk>>,
41-
/// Private inputs of the follower.
42-
follower_inputs: RangeSet<usize>,
37+
/// Mapping between the memories of the MPC and ZK VMs.
38+
memory_map: map::MemoryMap,
39+
/// Ranges of the follower's private inputs in the MPC VM.
40+
follower_input_ranges: RangeSet<usize>,
41+
/// Private inputs of the follower in the MPC VM.
42+
follower_inputs: Vec<Slice>,
43+
/// Outputs of the follower from the ZK VM. The references
44+
/// correspond to the MPC VM.
4345
outputs: Vec<(Slice, DecodeFuture<BitVec>)>,
44-
/// Whether the memories of the two VMs are potentially desynchronized.
45-
desync: AtomicBool,
4646
}
4747

4848
impl<Mpc, Zk> Deap<Mpc, Zk> {
@@ -52,9 +52,10 @@ impl<Mpc, Zk> Deap<Mpc, Zk> {
5252
role,
5353
mpc: Arc::new(Mutex::new(mpc)),
5454
zk: Arc::new(Mutex::new(zk)),
55-
follower_inputs: RangeSet::default(),
55+
memory_map: map::MemoryMap::default(),
56+
follower_input_ranges: RangeSet::default(),
57+
follower_inputs: Vec::default(),
5658
outputs: Vec::default(),
57-
desync: AtomicBool::new(false),
5859
}
5960
}
6061

@@ -68,34 +69,28 @@ impl<Mpc, Zk> Deap<Mpc, Zk> {
6869

6970
/// Returns a mutable reference to the ZK VM.
7071
///
71-
/// # Note
72-
///
73-
/// After calling this method, allocations will no longer be allowed in the
74-
/// DEAP VM as the memory will potentially be desynchronized.
75-
///
7672
/// # Panics
7773
///
7874
/// Panics if the mutex is locked by another thread.
7975
pub fn zk(&self) -> MutexGuard<'_, Zk> {
80-
self.desync.store(true, Ordering::Relaxed);
8176
self.zk.try_lock().unwrap()
8277
}
8378

8479
/// Returns an owned mutex guard to the ZK VM.
8580
///
86-
/// # Note
87-
///
88-
/// After calling this method, allocations will no longer be allowed in the
89-
/// DEAP VM as the memory will potentially be desynchronized.
90-
///
9181
/// # Panics
9282
///
9383
/// Panics if the mutex is locked by another thread.
9484
pub fn zk_owned(&self) -> OwnedMutexGuard<Zk> {
95-
self.desync.store(true, Ordering::Relaxed);
9685
self.zk.clone().try_lock_owned().unwrap()
9786
}
9887

88+
/// Translates a slice from the MPC VM address space to the ZK VM address
89+
/// space.
90+
pub fn translate_slice(&self, slice: Slice) -> Result<Slice, VmError> {
91+
self.memory_map.try_get(slice)
92+
}
93+
9994
#[cfg(test)]
10095
fn mpc(&self) -> MutexGuard<'_, Mpc> {
10196
self.mpc.try_lock().unwrap()
@@ -124,18 +119,15 @@ where
124119
// MACs.
125120
let input_futs = self
126121
.follower_inputs
127-
.iter_ranges()
128-
.map(|input| mpc.decode_raw(Slice::from_range_unchecked(input)))
122+
.iter()
123+
.map(|&input| mpc.decode_raw(input))
129124
.collect::<Result<Vec<_>, _>>()?;
130125

131126
mpc.execute_all(ctx).await?;
132127

133128
// Assign inputs to the ZK VM.
134-
for (mut decode, input) in input_futs
135-
.into_iter()
136-
.zip(self.follower_inputs.iter_ranges())
137-
{
138-
let input = Slice::from_range_unchecked(input);
129+
for (mut decode, &input) in input_futs.into_iter().zip(&self.follower_inputs) {
130+
let input = self.memory_map.try_get(input)?;
139131

140132
// Follower has already assigned the inputs.
141133
if let Role::Leader = self.role {
@@ -189,26 +181,28 @@ where
189181
}
190182

191183
fn alloc_raw(&mut self, size: usize) -> Result<Slice, VmError> {
192-
if self.desync.load(Ordering::Relaxed) {
193-
return Err(VmError::memory(
194-
"DEAP VM memories are potentially desynchronized",
195-
));
196-
}
184+
let mpc_slice = self.mpc.try_lock().unwrap().alloc_raw(size)?;
185+
let zk_slice = self.zk.try_lock().unwrap().alloc_raw(size)?;
197186

198-
self.zk.try_lock().unwrap().alloc_raw(size)?;
199-
self.mpc.try_lock().unwrap().alloc_raw(size)
187+
self.memory_map.insert(mpc_slice, zk_slice);
188+
189+
Ok(mpc_slice)
200190
}
201191

202192
fn is_assigned_raw(&self, slice: Slice) -> bool {
203193
self.mpc.try_lock().unwrap().is_assigned_raw(slice)
204194
}
205195

206196
fn assign_raw(&mut self, slice: Slice, data: BitVec) -> Result<(), VmError> {
207-
self.zk
197+
self.mpc
208198
.try_lock()
209199
.unwrap()
210200
.assign_raw(slice, data.clone())?;
211-
self.mpc.try_lock().unwrap().assign_raw(slice, data)
201+
202+
self.zk
203+
.try_lock()
204+
.unwrap()
205+
.assign_raw(self.memory_map.try_get(slice)?, data)
212206
}
213207

214208
fn is_committed_raw(&self, slice: Slice) -> bool {
@@ -217,10 +211,13 @@ where
217211

218212
fn commit_raw(&mut self, slice: Slice) -> Result<(), VmError> {
219213
// Follower's private inputs are not committed in the ZK VM until finalization.
220-
let input_minus_follower = slice.to_range().difference(&self.follower_inputs);
214+
let input_minus_follower = slice.to_range().difference(&self.follower_input_ranges);
221215
let mut zk = self.zk.try_lock().unwrap();
222216
for input in input_minus_follower.iter_ranges() {
223-
zk.commit_raw(Slice::from_range_unchecked(input))?;
217+
zk.commit_raw(
218+
self.memory_map
219+
.try_get(Slice::from_range_unchecked(input))?,
220+
)?;
224221
}
225222

226223
self.mpc.try_lock().unwrap().commit_raw(slice)
@@ -231,7 +228,11 @@ where
231228
}
232229

233230
fn decode_raw(&mut self, slice: Slice) -> Result<DecodeFuture<BitVec>, VmError> {
234-
let fut = self.zk.try_lock().unwrap().decode_raw(slice)?;
231+
let fut = self
232+
.zk
233+
.try_lock()
234+
.unwrap()
235+
.decode_raw(self.memory_map.try_get(slice)?)?;
235236
self.outputs.push((slice, fut));
236237

237238
self.mpc.try_lock().unwrap().decode_raw(slice)
@@ -246,23 +247,27 @@ where
246247
type Error = VmError;
247248

248249
fn mark_public_raw(&mut self, slice: Slice) -> Result<(), VmError> {
249-
self.zk.try_lock().unwrap().mark_public_raw(slice)?;
250-
self.mpc.try_lock().unwrap().mark_public_raw(slice)
250+
self.mpc.try_lock().unwrap().mark_public_raw(slice)?;
251+
self.zk
252+
.try_lock()
253+
.unwrap()
254+
.mark_public_raw(self.memory_map.try_get(slice)?)
251255
}
252256

253257
fn mark_private_raw(&mut self, slice: Slice) -> Result<(), VmError> {
254258
let mut zk = self.zk.try_lock().unwrap();
255259
let mut mpc = self.mpc.try_lock().unwrap();
256260
match self.role {
257261
Role::Leader => {
258-
zk.mark_private_raw(slice)?;
259262
mpc.mark_private_raw(slice)?;
263+
zk.mark_private_raw(self.memory_map.try_get(slice)?)?;
260264
}
261265
Role::Follower => {
262-
// Follower's private inputs will become public during finalization.
263-
zk.mark_public_raw(slice)?;
264266
mpc.mark_private_raw(slice)?;
265-
self.follower_inputs.union_mut(&slice.to_range());
267+
// Follower's private inputs will become public during finalization.
268+
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
269+
self.follower_input_ranges.union_mut(&slice.to_range());
270+
self.follower_inputs.push(slice);
266271
}
267272
}
268273

@@ -274,14 +279,15 @@ where
274279
let mut mpc = self.mpc.try_lock().unwrap();
275280
match self.role {
276281
Role::Leader => {
277-
// Follower's private inputs will become public during finalization.
278-
zk.mark_public_raw(slice)?;
279282
mpc.mark_blind_raw(slice)?;
280-
self.follower_inputs.union_mut(&slice.to_range());
283+
// Follower's private inputs will become public during finalization.
284+
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
285+
self.follower_input_ranges.union_mut(&slice.to_range());
286+
self.follower_inputs.push(slice);
281287
}
282288
Role::Follower => {
283-
zk.mark_blind_raw(slice)?;
284289
mpc.mark_blind_raw(slice)?;
290+
zk.mark_blind_raw(self.memory_map.try_get(slice)?)?;
285291
}
286292
}
287293

@@ -295,14 +301,21 @@ where
295301
Zk: Vm<Binary>,
296302
{
297303
fn call_raw(&mut self, call: Call) -> Result<Slice, VmError> {
298-
if self.desync.load(Ordering::Relaxed) {
299-
return Err(VmError::memory(
300-
"DEAP VM memories are potentially desynchronized",
301-
));
304+
let (circ, inputs) = call.clone().into_parts();
305+
let mut builder = Call::builder(circ);
306+
307+
for input in inputs {
308+
builder = builder.arg(self.memory_map.try_get(input)?);
302309
}
303310

304-
self.zk.try_lock().unwrap().call_raw(call.clone())?;
305-
self.mpc.try_lock().unwrap().call_raw(call)
311+
let zk_call = builder.build().expect("call should be valid");
312+
313+
let output = self.mpc.try_lock().unwrap().call_raw(call)?;
314+
let zk_output = self.zk.try_lock().unwrap().call_raw(zk_call)?;
315+
316+
self.memory_map.insert(output, zk_output);
317+
318+
Ok(output)
306319
}
307320
}
308321

@@ -463,6 +476,90 @@ mod tests {
463476
assert_eq!(ct_leader, ct_follower);
464477
}
465478

479+
#[tokio::test]
480+
async fn test_deap_desync_memory() {
481+
let mut rng = StdRng::seed_from_u64(0);
482+
let delta_mpc = Delta::random(&mut rng);
483+
let delta_zk = Delta::random(&mut rng);
484+
485+
let (mut ctx_a, mut ctx_b) = test_st_context(8);
486+
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta_zk.into_inner());
487+
let (cot_send, cot_recv) = ideal_cot(delta_mpc.into_inner());
488+
489+
let gb = Garbler::new(cot_send, [0u8; 16], delta_mpc);
490+
let ev = Evaluator::new(cot_recv);
491+
let prover = Prover::new(rcot_recv);
492+
let verifier = Verifier::new(delta_zk, rcot_send);
493+
494+
let mut leader = Deap::new(Role::Leader, gb, prover);
495+
let mut follower = Deap::new(Role::Follower, ev, verifier);
496+
497+
// Desynchronize the memories.
498+
let _ = leader.zk().alloc_raw(1).unwrap();
499+
let _ = follower.zk().alloc_raw(1).unwrap();
500+
501+
let (ct_leader, ct_follower) = futures::join!(
502+
async {
503+
let key: Array<U8, 16> = leader.alloc().unwrap();
504+
let msg: Array<U8, 16> = leader.alloc().unwrap();
505+
506+
leader.mark_private(key).unwrap();
507+
leader.mark_blind(msg).unwrap();
508+
leader.assign(key, [42u8; 16]).unwrap();
509+
leader.commit(key).unwrap();
510+
leader.commit(msg).unwrap();
511+
512+
let ct: Array<U8, 16> = leader
513+
.call(
514+
Call::builder(AES128.clone())
515+
.arg(key)
516+
.arg(msg)
517+
.build()
518+
.unwrap(),
519+
)
520+
.unwrap();
521+
let ct = leader.decode(ct).unwrap();
522+
523+
leader.flush(&mut ctx_a).await.unwrap();
524+
leader.execute(&mut ctx_a).await.unwrap();
525+
leader.flush(&mut ctx_a).await.unwrap();
526+
leader.finalize(&mut ctx_a).await.unwrap();
527+
528+
ct.await.unwrap()
529+
},
530+
async {
531+
let key: Array<U8, 16> = follower.alloc().unwrap();
532+
let msg: Array<U8, 16> = follower.alloc().unwrap();
533+
534+
follower.mark_blind(key).unwrap();
535+
follower.mark_private(msg).unwrap();
536+
follower.assign(msg, [69u8; 16]).unwrap();
537+
follower.commit(key).unwrap();
538+
follower.commit(msg).unwrap();
539+
540+
let ct: Array<U8, 16> = follower
541+
.call(
542+
Call::builder(AES128.clone())
543+
.arg(key)
544+
.arg(msg)
545+
.build()
546+
.unwrap(),
547+
)
548+
.unwrap();
549+
let ct = follower.decode(ct).unwrap();
550+
551+
follower.flush(&mut ctx_b).await.unwrap();
552+
follower.execute(&mut ctx_b).await.unwrap();
553+
follower.flush(&mut ctx_b).await.unwrap();
554+
follower.finalize(&mut ctx_b).await.unwrap();
555+
556+
ct.await.unwrap()
557+
}
558+
);
559+
560+
assert_eq!(ct_leader, ct_follower);
561+
}
562+
466563
// Tests that the leader can not use different inputs in each VM without
467564
// detection by the follower.
468565
#[tokio::test]

0 commit comments

Comments
 (0)