Skip to content

Commit 0eb1158

Browse files
committed
feat(deap): address space mapping
1 parent 26ce60c commit 0eb1158

File tree

2 files changed

+268
-60
lines changed

2 files changed

+268
-60
lines changed

crates/components/deap/src/lib.rs

+157-60
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,9 @@
44
#![deny(clippy::all)]
55
#![forbid(unsafe_code)]
66

7-
use std::{
8-
mem,
9-
sync::{
10-
atomic::{AtomicBool, Ordering},
11-
Arc,
12-
},
13-
};
7+
mod map;
8+
9+
use std::{mem, sync::Arc};
1410

1511
use async_trait::async_trait;
1612
use mpz_common::Context;
@@ -38,11 +34,15 @@ pub struct Deap<Mpc, Zk> {
3834
role: Role,
3935
mpc: Arc<Mutex<Mpc>>,
4036
zk: Arc<Mutex<Zk>>,
41-
/// Private inputs of the follower.
42-
follower_inputs: RangeSet<usize>,
37+
/// Mapping between the memories of the MPC and ZK VMs.
38+
memory_map: map::MemoryMap,
39+
/// Ranges of the follower's private inputs in the MPC VM.
40+
follower_input_ranges: RangeSet<usize>,
41+
/// Private inputs of the follower in the MPC VM.
42+
follower_inputs: Vec<Slice>,
43+
/// Outputs of the follower from the ZK VM. The references
44+
/// correspond to the MPC VM.
4345
outputs: Vec<(Slice, DecodeFuture<BitVec>)>,
44-
/// Whether the memories of the two VMs are potentially desynchronized.
45-
desync: AtomicBool,
4646
}
4747

4848
impl<Mpc, Zk> Deap<Mpc, Zk> {
@@ -52,9 +52,10 @@ impl<Mpc, Zk> Deap<Mpc, Zk> {
5252
role,
5353
mpc: Arc::new(Mutex::new(mpc)),
5454
zk: Arc::new(Mutex::new(zk)),
55-
follower_inputs: RangeSet::default(),
55+
memory_map: map::MemoryMap::default(),
56+
follower_input_ranges: RangeSet::default(),
57+
follower_inputs: Vec::default(),
5658
outputs: Vec::default(),
57-
desync: AtomicBool::new(false),
5859
}
5960
}
6061

@@ -68,34 +69,28 @@ impl<Mpc, Zk> Deap<Mpc, Zk> {
6869

6970
/// Returns a mutable reference to the ZK VM.
7071
///
71-
/// # Note
72-
///
73-
/// After calling this method, allocations will no longer be allowed in the
74-
/// DEAP VM as the memory will potentially be desynchronized.
75-
///
7672
/// # Panics
7773
///
7874
/// Panics if the mutex is locked by another thread.
7975
pub fn zk(&self) -> MutexGuard<'_, Zk> {
80-
self.desync.store(true, Ordering::Relaxed);
8176
self.zk.try_lock().unwrap()
8277
}
8378

8479
/// Returns an owned mutex guard to the ZK VM.
8580
///
86-
/// # Note
87-
///
88-
/// After calling this method, allocations will no longer be allowed in the
89-
/// DEAP VM as the memory will potentially be desynchronized.
90-
///
9181
/// # Panics
9282
///
9383
/// Panics if the mutex is locked by another thread.
9484
pub fn zk_owned(&self) -> OwnedMutexGuard<Zk> {
95-
self.desync.store(true, Ordering::Relaxed);
9685
self.zk.clone().try_lock_owned().unwrap()
9786
}
9887

88+
/// Translates a slice from the MPC VM address space to the ZK VM address
89+
/// space.
90+
pub fn translate_slice(&self, slice: Slice) -> Result<Slice, VmError> {
91+
self.memory_map.try_get(slice)
92+
}
93+
9994
#[cfg(test)]
10095
fn mpc(&self) -> MutexGuard<'_, Mpc> {
10196
self.mpc.try_lock().unwrap()
@@ -124,18 +119,15 @@ where
124119
// MACs.
125120
let input_futs = self
126121
.follower_inputs
127-
.iter_ranges()
128-
.map(|input| mpc.decode_raw(Slice::from_range_unchecked(input)))
122+
.iter()
123+
.map(|&input| mpc.decode_raw(input))
129124
.collect::<Result<Vec<_>, _>>()?;
130125

131126
mpc.execute_all(ctx).await?;
132127

133128
// Assign inputs to the ZK VM.
134-
for (mut decode, input) in input_futs
135-
.into_iter()
136-
.zip(self.follower_inputs.iter_ranges())
137-
{
138-
let input = Slice::from_range_unchecked(input);
129+
for (mut decode, &input) in input_futs.into_iter().zip(&self.follower_inputs) {
130+
let input = self.memory_map.try_get(input)?;
139131

140132
// Follower has already assigned the inputs.
141133
if let Role::Leader = self.role {
@@ -185,30 +177,35 @@ where
185177
type Error = VmError;
186178

187179
fn alloc_raw(&mut self, size: usize) -> Result<Slice, VmError> {
188-
if self.desync.load(Ordering::Relaxed) {
189-
return Err(VmError::memory(
190-
"DEAP VM memories are potentially desynchronized",
191-
));
192-
}
180+
let mpc_slice = self.mpc.try_lock().unwrap().alloc_raw(size)?;
181+
let zk_slice = self.zk.try_lock().unwrap().alloc_raw(size)?;
193182

194-
self.zk.try_lock().unwrap().alloc_raw(size)?;
195-
self.mpc.try_lock().unwrap().alloc_raw(size)
183+
self.memory_map.insert(mpc_slice, zk_slice);
184+
185+
Ok(mpc_slice)
196186
}
197187

198188
fn assign_raw(&mut self, slice: Slice, data: BitVec) -> Result<(), VmError> {
199-
self.zk
189+
self.mpc
200190
.try_lock()
201191
.unwrap()
202192
.assign_raw(slice, data.clone())?;
203-
self.mpc.try_lock().unwrap().assign_raw(slice, data)
193+
194+
self.zk
195+
.try_lock()
196+
.unwrap()
197+
.assign_raw(self.memory_map.try_get(slice)?, data)
204198
}
205199

206200
fn commit_raw(&mut self, slice: Slice) -> Result<(), VmError> {
207201
// Follower's private inputs are not committed in the ZK VM until finalization.
208-
let input_minus_follower = slice.to_range().difference(&self.follower_inputs);
202+
let input_minus_follower = slice.to_range().difference(&self.follower_input_ranges);
209203
let mut zk = self.zk.try_lock().unwrap();
210204
for input in input_minus_follower.iter_ranges() {
211-
zk.commit_raw(Slice::from_range_unchecked(input))?;
205+
zk.commit_raw(
206+
self.memory_map
207+
.try_get(Slice::from_range_unchecked(input))?,
208+
)?;
212209
}
213210

214211
self.mpc.try_lock().unwrap().commit_raw(slice)
@@ -219,7 +216,11 @@ where
219216
}
220217

221218
fn decode_raw(&mut self, slice: Slice) -> Result<DecodeFuture<BitVec>, VmError> {
222-
let fut = self.zk.try_lock().unwrap().decode_raw(slice)?;
219+
let fut = self
220+
.zk
221+
.try_lock()
222+
.unwrap()
223+
.decode_raw(self.memory_map.try_get(slice)?)?;
223224
self.outputs.push((slice, fut));
224225

225226
self.mpc.try_lock().unwrap().decode_raw(slice)
@@ -234,23 +235,27 @@ where
234235
type Error = VmError;
235236

236237
fn mark_public_raw(&mut self, slice: Slice) -> Result<(), VmError> {
237-
self.zk.try_lock().unwrap().mark_public_raw(slice)?;
238-
self.mpc.try_lock().unwrap().mark_public_raw(slice)
238+
self.mpc.try_lock().unwrap().mark_public_raw(slice)?;
239+
self.zk
240+
.try_lock()
241+
.unwrap()
242+
.mark_public_raw(self.memory_map.try_get(slice)?)
239243
}
240244

241245
fn mark_private_raw(&mut self, slice: Slice) -> Result<(), VmError> {
242246
let mut zk = self.zk.try_lock().unwrap();
243247
let mut mpc = self.mpc.try_lock().unwrap();
244248
match self.role {
245249
Role::Leader => {
246-
zk.mark_private_raw(slice)?;
247250
mpc.mark_private_raw(slice)?;
251+
zk.mark_private_raw(self.memory_map.try_get(slice)?)?;
248252
}
249253
Role::Follower => {
250-
// Follower's private inputs will become public during finalization.
251-
zk.mark_public_raw(slice)?;
252254
mpc.mark_private_raw(slice)?;
253-
self.follower_inputs.union_mut(&slice.to_range());
255+
// Follower's private inputs will become public during finalization.
256+
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
257+
self.follower_input_ranges.union_mut(&slice.to_range());
258+
self.follower_inputs.push(slice);
254259
}
255260
}
256261

@@ -262,14 +267,15 @@ where
262267
let mut mpc = self.mpc.try_lock().unwrap();
263268
match self.role {
264269
Role::Leader => {
265-
// Follower's private inputs will become public during finalization.
266-
zk.mark_public_raw(slice)?;
267270
mpc.mark_blind_raw(slice)?;
268-
self.follower_inputs.union_mut(&slice.to_range());
271+
// Follower's private inputs will become public during finalization.
272+
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
273+
self.follower_input_ranges.union_mut(&slice.to_range());
274+
self.follower_inputs.push(slice);
269275
}
270276
Role::Follower => {
271-
zk.mark_blind_raw(slice)?;
272277
mpc.mark_blind_raw(slice)?;
278+
zk.mark_blind_raw(self.memory_map.try_get(slice)?)?;
273279
}
274280
}
275281

@@ -283,14 +289,21 @@ where
283289
Zk: Vm<Binary>,
284290
{
285291
fn call_raw(&mut self, call: Call) -> Result<Slice, VmError> {
286-
if self.desync.load(Ordering::Relaxed) {
287-
return Err(VmError::memory(
288-
"DEAP VM memories are potentially desynchronized",
289-
));
292+
let (circ, inputs) = call.clone().into_parts();
293+
let mut builder = Call::builder(circ);
294+
295+
for input in inputs {
296+
builder = builder.arg(self.memory_map.try_get(input)?);
290297
}
291298

292-
self.zk.try_lock().unwrap().call_raw(call.clone())?;
293-
self.mpc.try_lock().unwrap().call_raw(call)
299+
let zk_call = builder.build().expect("call should be valid");
300+
301+
let output = self.mpc.try_lock().unwrap().call_raw(call)?;
302+
let zk_output = self.zk.try_lock().unwrap().call_raw(zk_call)?;
303+
304+
self.memory_map.insert(output, zk_output);
305+
306+
Ok(output)
294307
}
295308
}
296309

@@ -451,6 +464,90 @@ mod tests {
451464
assert_eq!(ct_leader, ct_follower);
452465
}
453466

467+
#[tokio::test]
468+
async fn test_deap_desync_memory() {
469+
let mut rng = StdRng::seed_from_u64(0);
470+
let delta_mpc = Delta::random(&mut rng);
471+
let delta_zk = Delta::random(&mut rng);
472+
473+
let (mut ctx_a, mut ctx_b) = test_st_context(8);
474+
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta_zk.into_inner());
475+
let (cot_send, cot_recv) = ideal_cot(delta_mpc.into_inner());
476+
477+
let gb = Garbler::new(cot_send, [0u8; 16], delta_mpc);
478+
let ev = Evaluator::new(cot_recv);
479+
let prover = Prover::new(rcot_recv);
480+
let verifier = Verifier::new(delta_zk, rcot_send);
481+
482+
let mut leader = Deap::new(Role::Leader, gb, prover);
483+
let mut follower = Deap::new(Role::Follower, ev, verifier);
484+
485+
// Desynchronize the memories.
486+
let _ = leader.zk().alloc_raw(1).unwrap();
487+
let _ = follower.zk().alloc_raw(1).unwrap();
488+
489+
let (ct_leader, ct_follower) = futures::join!(
490+
async {
491+
let key: Array<U8, 16> = leader.alloc().unwrap();
492+
let msg: Array<U8, 16> = leader.alloc().unwrap();
493+
494+
leader.mark_private(key).unwrap();
495+
leader.mark_blind(msg).unwrap();
496+
leader.assign(key, [42u8; 16]).unwrap();
497+
leader.commit(key).unwrap();
498+
leader.commit(msg).unwrap();
499+
500+
let ct: Array<U8, 16> = leader
501+
.call(
502+
Call::builder(AES128.clone())
503+
.arg(key)
504+
.arg(msg)
505+
.build()
506+
.unwrap(),
507+
)
508+
.unwrap();
509+
let ct = leader.decode(ct).unwrap();
510+
511+
leader.flush(&mut ctx_a).await.unwrap();
512+
leader.execute(&mut ctx_a).await.unwrap();
513+
leader.flush(&mut ctx_a).await.unwrap();
514+
leader.finalize(&mut ctx_a).await.unwrap();
515+
516+
ct.await.unwrap()
517+
},
518+
async {
519+
let key: Array<U8, 16> = follower.alloc().unwrap();
520+
let msg: Array<U8, 16> = follower.alloc().unwrap();
521+
522+
follower.mark_blind(key).unwrap();
523+
follower.mark_private(msg).unwrap();
524+
follower.assign(msg, [69u8; 16]).unwrap();
525+
follower.commit(key).unwrap();
526+
follower.commit(msg).unwrap();
527+
528+
let ct: Array<U8, 16> = follower
529+
.call(
530+
Call::builder(AES128.clone())
531+
.arg(key)
532+
.arg(msg)
533+
.build()
534+
.unwrap(),
535+
)
536+
.unwrap();
537+
let ct = follower.decode(ct).unwrap();
538+
539+
follower.flush(&mut ctx_b).await.unwrap();
540+
follower.execute(&mut ctx_b).await.unwrap();
541+
follower.flush(&mut ctx_b).await.unwrap();
542+
follower.finalize(&mut ctx_b).await.unwrap();
543+
544+
ct.await.unwrap()
545+
}
546+
);
547+
548+
assert_eq!(ct_leader, ct_follower);
549+
}
550+
454551
// Tests that the leader can not use different inputs in each VM without
455552
// detection by the follower.
456553
#[tokio::test]

0 commit comments

Comments
 (0)