Skip to content

feat(deap): address space mapping #809

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: feat/prf-local
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,19 @@ tlsn-tls-core = { path = "crates/tls/core" }
tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" }
tlsn-verifier = { path = "crates/verifier" }

mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-memory-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-vm-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-zk = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-hash = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" }
mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" }
mpz-memory-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" }
mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" }
mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" }
mpz-vm-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" }
mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" }
mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" }
mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" }
mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" }
mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" }
mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" }
mpz-zk = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" }
mpz-hash = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" }

rangeset = { version = "0.2" }
serio = { version = "0.2" }
Expand Down
217 changes: 157 additions & 60 deletions crates/components/deap/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,9 @@
#![deny(clippy::all)]
#![forbid(unsafe_code)]

use std::{
mem,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
mod map;

use std::{mem, sync::Arc};

use async_trait::async_trait;
use mpz_common::Context;
Expand Down Expand Up @@ -38,11 +34,15 @@ pub struct Deap<Mpc, Zk> {
role: Role,
mpc: Arc<Mutex<Mpc>>,
zk: Arc<Mutex<Zk>>,
/// Private inputs of the follower.
follower_inputs: RangeSet<usize>,
/// Mapping between the memories of the MPC and ZK VMs.
memory_map: map::MemoryMap,
/// Ranges of the follower's private inputs in the MPC VM.
follower_input_ranges: RangeSet<usize>,
/// Private inputs of the follower in the MPC VM.
follower_inputs: Vec<Slice>,
/// Outputs of the follower from the ZK VM. The references
/// correspond to the MPC VM.
outputs: Vec<(Slice, DecodeFuture<BitVec>)>,
/// Whether the memories of the two VMs are potentially desynchronized.
desync: AtomicBool,
}

impl<Mpc, Zk> Deap<Mpc, Zk> {
Expand All @@ -52,9 +52,10 @@ impl<Mpc, Zk> Deap<Mpc, Zk> {
role,
mpc: Arc::new(Mutex::new(mpc)),
zk: Arc::new(Mutex::new(zk)),
follower_inputs: RangeSet::default(),
memory_map: map::MemoryMap::default(),
follower_input_ranges: RangeSet::default(),
follower_inputs: Vec::default(),
outputs: Vec::default(),
desync: AtomicBool::new(false),
}
}

Expand All @@ -68,34 +69,28 @@ impl<Mpc, Zk> Deap<Mpc, Zk> {

/// Returns a mutable reference to the ZK VM.
///
/// # Note
///
/// After calling this method, allocations will no longer be allowed in the
/// DEAP VM as the memory will potentially be desynchronized.
///
/// # Panics
///
/// Panics if the mutex is locked by another thread.
pub fn zk(&self) -> MutexGuard<'_, Zk> {
self.desync.store(true, Ordering::Relaxed);
self.zk.try_lock().unwrap()
}

/// Returns an owned mutex guard to the ZK VM.
///
/// # Note
///
/// After calling this method, allocations will no longer be allowed in the
/// DEAP VM as the memory will potentially be desynchronized.
///
/// # Panics
///
/// Panics if the mutex is locked by another thread.
pub fn zk_owned(&self) -> OwnedMutexGuard<Zk> {
self.desync.store(true, Ordering::Relaxed);
self.zk.clone().try_lock_owned().unwrap()
}

/// Translates a slice from the MPC VM address space to the ZK VM address
/// space.
pub fn translate_slice(&self, slice: Slice) -> Result<Slice, VmError> {
self.memory_map.try_get(slice)
}

#[cfg(test)]
fn mpc(&self) -> MutexGuard<'_, Mpc> {
self.mpc.try_lock().unwrap()
Expand Down Expand Up @@ -124,18 +119,15 @@ where
// MACs.
let input_futs = self
.follower_inputs
.iter_ranges()
.map(|input| mpc.decode_raw(Slice::from_range_unchecked(input)))
.iter()
.map(|&input| mpc.decode_raw(input))
.collect::<Result<Vec<_>, _>>()?;

mpc.execute_all(ctx).await?;

// Assign inputs to the ZK VM.
for (mut decode, input) in input_futs
.into_iter()
.zip(self.follower_inputs.iter_ranges())
{
let input = Slice::from_range_unchecked(input);
for (mut decode, &input) in input_futs.into_iter().zip(&self.follower_inputs) {
let input = self.memory_map.try_get(input)?;

// Follower has already assigned the inputs.
if let Role::Leader = self.role {
Expand Down Expand Up @@ -185,30 +177,35 @@ where
type Error = VmError;

fn alloc_raw(&mut self, size: usize) -> Result<Slice, VmError> {
if self.desync.load(Ordering::Relaxed) {
return Err(VmError::memory(
"DEAP VM memories are potentially desynchronized",
));
}
let mpc_slice = self.mpc.try_lock().unwrap().alloc_raw(size)?;
let zk_slice = self.zk.try_lock().unwrap().alloc_raw(size)?;

self.zk.try_lock().unwrap().alloc_raw(size)?;
self.mpc.try_lock().unwrap().alloc_raw(size)
self.memory_map.insert(mpc_slice, zk_slice);

Ok(mpc_slice)
}

fn assign_raw(&mut self, slice: Slice, data: BitVec) -> Result<(), VmError> {
self.zk
self.mpc
.try_lock()
.unwrap()
.assign_raw(slice, data.clone())?;
self.mpc.try_lock().unwrap().assign_raw(slice, data)

self.zk
.try_lock()
.unwrap()
.assign_raw(self.memory_map.try_get(slice)?, data)
}

fn commit_raw(&mut self, slice: Slice) -> Result<(), VmError> {
// Follower's private inputs are not committed in the ZK VM until finalization.
let input_minus_follower = slice.to_range().difference(&self.follower_inputs);
let input_minus_follower = slice.to_range().difference(&self.follower_input_ranges);
let mut zk = self.zk.try_lock().unwrap();
for input in input_minus_follower.iter_ranges() {
zk.commit_raw(Slice::from_range_unchecked(input))?;
zk.commit_raw(
self.memory_map
.try_get(Slice::from_range_unchecked(input))?,
)?;
}

self.mpc.try_lock().unwrap().commit_raw(slice)
Expand All @@ -219,7 +216,11 @@ where
}

fn decode_raw(&mut self, slice: Slice) -> Result<DecodeFuture<BitVec>, VmError> {
let fut = self.zk.try_lock().unwrap().decode_raw(slice)?;
let fut = self
.zk
.try_lock()
.unwrap()
.decode_raw(self.memory_map.try_get(slice)?)?;
self.outputs.push((slice, fut));

self.mpc.try_lock().unwrap().decode_raw(slice)
Expand All @@ -234,23 +235,27 @@ where
type Error = VmError;

fn mark_public_raw(&mut self, slice: Slice) -> Result<(), VmError> {
self.zk.try_lock().unwrap().mark_public_raw(slice)?;
self.mpc.try_lock().unwrap().mark_public_raw(slice)
self.mpc.try_lock().unwrap().mark_public_raw(slice)?;
self.zk
.try_lock()
.unwrap()
.mark_public_raw(self.memory_map.try_get(slice)?)
}

fn mark_private_raw(&mut self, slice: Slice) -> Result<(), VmError> {
let mut zk = self.zk.try_lock().unwrap();
let mut mpc = self.mpc.try_lock().unwrap();
match self.role {
Role::Leader => {
zk.mark_private_raw(slice)?;
mpc.mark_private_raw(slice)?;
zk.mark_private_raw(self.memory_map.try_get(slice)?)?;
}
Role::Follower => {
// Follower's private inputs will become public during finalization.
zk.mark_public_raw(slice)?;
mpc.mark_private_raw(slice)?;
self.follower_inputs.union_mut(&slice.to_range());
// Follower's private inputs will become public during finalization.
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
self.follower_input_ranges.union_mut(&slice.to_range());
self.follower_inputs.push(slice);
}
}

Expand All @@ -262,14 +267,15 @@ where
let mut mpc = self.mpc.try_lock().unwrap();
match self.role {
Role::Leader => {
// Follower's private inputs will become public during finalization.
zk.mark_public_raw(slice)?;
mpc.mark_blind_raw(slice)?;
self.follower_inputs.union_mut(&slice.to_range());
// Follower's private inputs will become public during finalization.
zk.mark_public_raw(self.memory_map.try_get(slice)?)?;
self.follower_input_ranges.union_mut(&slice.to_range());
self.follower_inputs.push(slice);
}
Role::Follower => {
zk.mark_blind_raw(slice)?;
mpc.mark_blind_raw(slice)?;
zk.mark_blind_raw(self.memory_map.try_get(slice)?)?;
}
}

Expand All @@ -283,14 +289,21 @@ where
Zk: Vm<Binary>,
{
fn call_raw(&mut self, call: Call) -> Result<Slice, VmError> {
if self.desync.load(Ordering::Relaxed) {
return Err(VmError::memory(
"DEAP VM memories are potentially desynchronized",
));
let (circ, inputs) = call.clone().into_parts();
let mut builder = Call::builder(circ);

for input in inputs {
builder = builder.arg(self.memory_map.try_get(input)?);
}

self.zk.try_lock().unwrap().call_raw(call.clone())?;
self.mpc.try_lock().unwrap().call_raw(call)
let zk_call = builder.build().expect("call should be valid");

let output = self.mpc.try_lock().unwrap().call_raw(call)?;
let zk_output = self.zk.try_lock().unwrap().call_raw(zk_call)?;

self.memory_map.insert(output, zk_output);

Ok(output)
}
}

Expand Down Expand Up @@ -451,6 +464,90 @@ mod tests {
assert_eq!(ct_leader, ct_follower);
}

#[tokio::test]
async fn test_deap_desync_memory() {
let mut rng = StdRng::seed_from_u64(0);
let delta_mpc = Delta::random(&mut rng);
let delta_zk = Delta::random(&mut rng);

let (mut ctx_a, mut ctx_b) = test_st_context(8);
let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta_zk.into_inner());
let (cot_send, cot_recv) = ideal_cot(delta_mpc.into_inner());

let gb = Garbler::new(cot_send, [0u8; 16], delta_mpc);
let ev = Evaluator::new(cot_recv);
let prover = Prover::new(rcot_recv);
let verifier = Verifier::new(delta_zk, rcot_send);

let mut leader = Deap::new(Role::Leader, gb, prover);
let mut follower = Deap::new(Role::Follower, ev, verifier);

// Desynchronize the memories.
let _ = leader.zk().alloc_raw(1).unwrap();
let _ = follower.zk().alloc_raw(1).unwrap();

let (ct_leader, ct_follower) = futures::join!(
async {
let key: Array<U8, 16> = leader.alloc().unwrap();
let msg: Array<U8, 16> = leader.alloc().unwrap();

leader.mark_private(key).unwrap();
leader.mark_blind(msg).unwrap();
leader.assign(key, [42u8; 16]).unwrap();
leader.commit(key).unwrap();
leader.commit(msg).unwrap();

let ct: Array<U8, 16> = leader
.call(
Call::builder(AES128.clone())
.arg(key)
.arg(msg)
.build()
.unwrap(),
)
.unwrap();
let ct = leader.decode(ct).unwrap();

leader.flush(&mut ctx_a).await.unwrap();
leader.execute(&mut ctx_a).await.unwrap();
leader.flush(&mut ctx_a).await.unwrap();
leader.finalize(&mut ctx_a).await.unwrap();

ct.await.unwrap()
},
async {
let key: Array<U8, 16> = follower.alloc().unwrap();
let msg: Array<U8, 16> = follower.alloc().unwrap();

follower.mark_blind(key).unwrap();
follower.mark_private(msg).unwrap();
follower.assign(msg, [69u8; 16]).unwrap();
follower.commit(key).unwrap();
follower.commit(msg).unwrap();

let ct: Array<U8, 16> = follower
.call(
Call::builder(AES128.clone())
.arg(key)
.arg(msg)
.build()
.unwrap(),
)
.unwrap();
let ct = follower.decode(ct).unwrap();

follower.flush(&mut ctx_b).await.unwrap();
follower.execute(&mut ctx_b).await.unwrap();
follower.flush(&mut ctx_b).await.unwrap();
follower.finalize(&mut ctx_b).await.unwrap();

ct.await.unwrap()
}
);

assert_eq!(ct_leader, ct_follower);
}

// Tests that the leader can not use different inputs in each VM without
// detection by the follower.
#[tokio::test]
Expand Down
Loading
Loading