diff --git a/Cargo.toml b/Cargo.toml index 9202626a8..c0d4491ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,19 +69,19 @@ tlsn-tls-core = { path = "crates/tls/core" } tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" } tlsn-verifier = { path = "crates/verifier" } -mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } -mpz-memory-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } -mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } -mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } -mpz-vm-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } -mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } -mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } -mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } -mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } -mpz-zk = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } -mpz-hash = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" } +mpz-memory-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" } +mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" } +mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" } +mpz-vm-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" } +mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" } +mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" } +mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" } +mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" } +mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" } +mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" } +mpz-zk = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" } +mpz-hash = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "75928f7" } rangeset = { version = "0.2" } serio = { version = "0.2" } diff --git a/crates/components/deap/src/lib.rs b/crates/components/deap/src/lib.rs index 9d3faa95d..38c669ab9 100644 --- a/crates/components/deap/src/lib.rs +++ b/crates/components/deap/src/lib.rs @@ -4,13 +4,9 @@ #![deny(clippy::all)] #![forbid(unsafe_code)] -use std::{ - mem, - sync::{ - atomic::{AtomicBool, Ordering}, - Arc, - }, -}; +mod map; + +use std::{mem, sync::Arc}; use async_trait::async_trait; use mpz_common::Context; @@ -38,11 +34,15 @@ pub struct Deap { role: Role, mpc: Arc>, zk: Arc>, - /// Private inputs of the follower. - follower_inputs: RangeSet, + /// Mapping between the memories of the MPC and ZK VMs. + memory_map: map::MemoryMap, + /// Ranges of the follower's private inputs in the MPC VM. + follower_input_ranges: RangeSet, + /// Private inputs of the follower in the MPC VM. + follower_inputs: Vec, + /// Outputs of the follower from the ZK VM. The references + /// correspond to the MPC VM. outputs: Vec<(Slice, DecodeFuture)>, - /// Whether the memories of the two VMs are potentially desynchronized. - desync: AtomicBool, } impl Deap { @@ -52,9 +52,10 @@ impl Deap { role, mpc: Arc::new(Mutex::new(mpc)), zk: Arc::new(Mutex::new(zk)), - follower_inputs: RangeSet::default(), + memory_map: map::MemoryMap::default(), + follower_input_ranges: RangeSet::default(), + follower_inputs: Vec::default(), outputs: Vec::default(), - desync: AtomicBool::new(false), } } @@ -68,34 +69,28 @@ impl Deap { /// Returns a mutable reference to the ZK VM. /// - /// # Note - /// - /// After calling this method, allocations will no longer be allowed in the - /// DEAP VM as the memory will potentially be desynchronized. - /// /// # Panics /// /// Panics if the mutex is locked by another thread. pub fn zk(&self) -> MutexGuard<'_, Zk> { - self.desync.store(true, Ordering::Relaxed); self.zk.try_lock().unwrap() } /// Returns an owned mutex guard to the ZK VM. /// - /// # Note - /// - /// After calling this method, allocations will no longer be allowed in the - /// DEAP VM as the memory will potentially be desynchronized. - /// /// # Panics /// /// Panics if the mutex is locked by another thread. pub fn zk_owned(&self) -> OwnedMutexGuard { - self.desync.store(true, Ordering::Relaxed); self.zk.clone().try_lock_owned().unwrap() } + /// Translates a slice from the MPC VM address space to the ZK VM address + /// space. + pub fn translate_slice(&self, slice: Slice) -> Result { + self.memory_map.try_get(slice) + } + #[cfg(test)] fn mpc(&self) -> MutexGuard<'_, Mpc> { self.mpc.try_lock().unwrap() @@ -124,18 +119,15 @@ where // MACs. let input_futs = self .follower_inputs - .iter_ranges() - .map(|input| mpc.decode_raw(Slice::from_range_unchecked(input))) + .iter() + .map(|&input| mpc.decode_raw(input)) .collect::, _>>()?; mpc.execute_all(ctx).await?; // Assign inputs to the ZK VM. - for (mut decode, input) in input_futs - .into_iter() - .zip(self.follower_inputs.iter_ranges()) - { - let input = Slice::from_range_unchecked(input); + for (mut decode, &input) in input_futs.into_iter().zip(&self.follower_inputs) { + let input = self.memory_map.try_get(input)?; // Follower has already assigned the inputs. if let Role::Leader = self.role { @@ -185,30 +177,35 @@ where type Error = VmError; fn alloc_raw(&mut self, size: usize) -> Result { - if self.desync.load(Ordering::Relaxed) { - return Err(VmError::memory( - "DEAP VM memories are potentially desynchronized", - )); - } + let mpc_slice = self.mpc.try_lock().unwrap().alloc_raw(size)?; + let zk_slice = self.zk.try_lock().unwrap().alloc_raw(size)?; - self.zk.try_lock().unwrap().alloc_raw(size)?; - self.mpc.try_lock().unwrap().alloc_raw(size) + self.memory_map.insert(mpc_slice, zk_slice); + + Ok(mpc_slice) } fn assign_raw(&mut self, slice: Slice, data: BitVec) -> Result<(), VmError> { - self.zk + self.mpc .try_lock() .unwrap() .assign_raw(slice, data.clone())?; - self.mpc.try_lock().unwrap().assign_raw(slice, data) + + self.zk + .try_lock() + .unwrap() + .assign_raw(self.memory_map.try_get(slice)?, data) } fn commit_raw(&mut self, slice: Slice) -> Result<(), VmError> { // Follower's private inputs are not committed in the ZK VM until finalization. - let input_minus_follower = slice.to_range().difference(&self.follower_inputs); + let input_minus_follower = slice.to_range().difference(&self.follower_input_ranges); let mut zk = self.zk.try_lock().unwrap(); for input in input_minus_follower.iter_ranges() { - zk.commit_raw(Slice::from_range_unchecked(input))?; + zk.commit_raw( + self.memory_map + .try_get(Slice::from_range_unchecked(input))?, + )?; } self.mpc.try_lock().unwrap().commit_raw(slice) @@ -219,7 +216,11 @@ where } fn decode_raw(&mut self, slice: Slice) -> Result, VmError> { - let fut = self.zk.try_lock().unwrap().decode_raw(slice)?; + let fut = self + .zk + .try_lock() + .unwrap() + .decode_raw(self.memory_map.try_get(slice)?)?; self.outputs.push((slice, fut)); self.mpc.try_lock().unwrap().decode_raw(slice) @@ -234,8 +235,11 @@ where type Error = VmError; fn mark_public_raw(&mut self, slice: Slice) -> Result<(), VmError> { - self.zk.try_lock().unwrap().mark_public_raw(slice)?; - self.mpc.try_lock().unwrap().mark_public_raw(slice) + self.mpc.try_lock().unwrap().mark_public_raw(slice)?; + self.zk + .try_lock() + .unwrap() + .mark_public_raw(self.memory_map.try_get(slice)?) } fn mark_private_raw(&mut self, slice: Slice) -> Result<(), VmError> { @@ -243,14 +247,15 @@ where let mut mpc = self.mpc.try_lock().unwrap(); match self.role { Role::Leader => { - zk.mark_private_raw(slice)?; mpc.mark_private_raw(slice)?; + zk.mark_private_raw(self.memory_map.try_get(slice)?)?; } Role::Follower => { - // Follower's private inputs will become public during finalization. - zk.mark_public_raw(slice)?; mpc.mark_private_raw(slice)?; - self.follower_inputs.union_mut(&slice.to_range()); + // Follower's private inputs will become public during finalization. + zk.mark_public_raw(self.memory_map.try_get(slice)?)?; + self.follower_input_ranges.union_mut(&slice.to_range()); + self.follower_inputs.push(slice); } } @@ -262,14 +267,15 @@ where let mut mpc = self.mpc.try_lock().unwrap(); match self.role { Role::Leader => { - // Follower's private inputs will become public during finalization. - zk.mark_public_raw(slice)?; mpc.mark_blind_raw(slice)?; - self.follower_inputs.union_mut(&slice.to_range()); + // Follower's private inputs will become public during finalization. + zk.mark_public_raw(self.memory_map.try_get(slice)?)?; + self.follower_input_ranges.union_mut(&slice.to_range()); + self.follower_inputs.push(slice); } Role::Follower => { - zk.mark_blind_raw(slice)?; mpc.mark_blind_raw(slice)?; + zk.mark_blind_raw(self.memory_map.try_get(slice)?)?; } } @@ -283,14 +289,21 @@ where Zk: Vm, { fn call_raw(&mut self, call: Call) -> Result { - if self.desync.load(Ordering::Relaxed) { - return Err(VmError::memory( - "DEAP VM memories are potentially desynchronized", - )); + let (circ, inputs) = call.clone().into_parts(); + let mut builder = Call::builder(circ); + + for input in inputs { + builder = builder.arg(self.memory_map.try_get(input)?); } - self.zk.try_lock().unwrap().call_raw(call.clone())?; - self.mpc.try_lock().unwrap().call_raw(call) + let zk_call = builder.build().expect("call should be valid"); + + let output = self.mpc.try_lock().unwrap().call_raw(call)?; + let zk_output = self.zk.try_lock().unwrap().call_raw(zk_call)?; + + self.memory_map.insert(output, zk_output); + + Ok(output) } } @@ -451,6 +464,90 @@ mod tests { assert_eq!(ct_leader, ct_follower); } + #[tokio::test] + async fn test_deap_desync_memory() { + let mut rng = StdRng::seed_from_u64(0); + let delta_mpc = Delta::random(&mut rng); + let delta_zk = Delta::random(&mut rng); + + let (mut ctx_a, mut ctx_b) = test_st_context(8); + let (rcot_send, rcot_recv) = ideal_rcot(Block::ZERO, delta_zk.into_inner()); + let (cot_send, cot_recv) = ideal_cot(delta_mpc.into_inner()); + + let gb = Garbler::new(cot_send, [0u8; 16], delta_mpc); + let ev = Evaluator::new(cot_recv); + let prover = Prover::new(rcot_recv); + let verifier = Verifier::new(delta_zk, rcot_send); + + let mut leader = Deap::new(Role::Leader, gb, prover); + let mut follower = Deap::new(Role::Follower, ev, verifier); + + // Desynchronize the memories. + let _ = leader.zk().alloc_raw(1).unwrap(); + let _ = follower.zk().alloc_raw(1).unwrap(); + + let (ct_leader, ct_follower) = futures::join!( + async { + let key: Array = leader.alloc().unwrap(); + let msg: Array = leader.alloc().unwrap(); + + leader.mark_private(key).unwrap(); + leader.mark_blind(msg).unwrap(); + leader.assign(key, [42u8; 16]).unwrap(); + leader.commit(key).unwrap(); + leader.commit(msg).unwrap(); + + let ct: Array = leader + .call( + Call::builder(AES128.clone()) + .arg(key) + .arg(msg) + .build() + .unwrap(), + ) + .unwrap(); + let ct = leader.decode(ct).unwrap(); + + leader.flush(&mut ctx_a).await.unwrap(); + leader.execute(&mut ctx_a).await.unwrap(); + leader.flush(&mut ctx_a).await.unwrap(); + leader.finalize(&mut ctx_a).await.unwrap(); + + ct.await.unwrap() + }, + async { + let key: Array = follower.alloc().unwrap(); + let msg: Array = follower.alloc().unwrap(); + + follower.mark_blind(key).unwrap(); + follower.mark_private(msg).unwrap(); + follower.assign(msg, [69u8; 16]).unwrap(); + follower.commit(key).unwrap(); + follower.commit(msg).unwrap(); + + let ct: Array = follower + .call( + Call::builder(AES128.clone()) + .arg(key) + .arg(msg) + .build() + .unwrap(), + ) + .unwrap(); + let ct = follower.decode(ct).unwrap(); + + follower.flush(&mut ctx_b).await.unwrap(); + follower.execute(&mut ctx_b).await.unwrap(); + follower.flush(&mut ctx_b).await.unwrap(); + follower.finalize(&mut ctx_b).await.unwrap(); + + ct.await.unwrap() + } + ); + + assert_eq!(ct_leader, ct_follower); + } + // Tests that the leader can not use different inputs in each VM without // detection by the follower. #[tokio::test] diff --git a/crates/components/deap/src/map.rs b/crates/components/deap/src/map.rs new file mode 100644 index 000000000..7d9807bc7 --- /dev/null +++ b/crates/components/deap/src/map.rs @@ -0,0 +1,111 @@ +use std::ops::Range; + +use mpz_vm_core::{memory::Slice, VmError}; +use rangeset::Subset; + +/// A mapping between the memories of the MPC and ZK VMs. +#[derive(Debug, Default)] +pub(crate) struct MemoryMap { + mpc: Vec>, + zk: Vec>, +} + +impl MemoryMap { + /// Inserts a new allocation into the map. + /// + /// # Panics + /// + /// - If the slices are not inserted in the order they are allocated. + /// - If the slices are not the same length. + pub(crate) fn insert(&mut self, mpc: Slice, zk: Slice) { + let mpc = mpc.to_range(); + let zk = zk.to_range(); + + assert_eq!(mpc.len(), zk.len(), "slices must be the same length"); + + if let Some(last) = self.mpc.last() { + if last.end > mpc.start { + panic!("slices must be provided in ascending order"); + } + } + + self.mpc.push(mpc); + self.zk.push(zk); + } + + /// Returns the corresponding allocation in the ZK VM. + pub(crate) fn try_get(&self, mpc: Slice) -> Result { + let mpc_range = mpc.to_range(); + let pos = match self + .mpc + .binary_search_by_key(&mpc_range.start, |range| range.start) + { + Ok(pos) => pos, + Err(0) => return Err(VmError::memory(format!("invalid memory slice: {mpc}"))), + Err(pos) => pos - 1, + }; + + let candidate = &self.mpc[pos]; + if mpc_range.is_subset(candidate) { + let offset = mpc_range.start - candidate.start; + let start = self.zk[pos].start + offset; + let slice = Slice::from_range_unchecked(start..start + mpc_range.len()); + + Ok(slice) + } else { + Err(VmError::memory(format!("invalid memory slice: {mpc}"))) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_map() { + let mut map = MemoryMap::default(); + map.insert( + Slice::from_range_unchecked(0..10), + Slice::from_range_unchecked(10..20), + ); + + // Range is fully contained. + assert_eq!( + map.try_get(Slice::from_range_unchecked(0..10)).unwrap(), + Slice::from_range_unchecked(10..20) + ); + // Range is subset. + assert_eq!( + map.try_get(Slice::from_range_unchecked(1..9)).unwrap(), + Slice::from_range_unchecked(11..19) + ); + // Range is not subset. + assert!(map.try_get(Slice::from_range_unchecked(0..11)).is_err()); + + // Insert another range. + map.insert( + Slice::from_range_unchecked(20..30), + Slice::from_range_unchecked(30..40), + ); + assert_eq!( + map.try_get(Slice::from_range_unchecked(20..30)).unwrap(), + Slice::from_range_unchecked(30..40) + ); + assert_eq!( + map.try_get(Slice::from_range_unchecked(21..29)).unwrap(), + Slice::from_range_unchecked(31..39) + ); + assert!(map.try_get(Slice::from_range_unchecked(19..21)).is_err()); + } + + #[test] + #[should_panic] + fn test_map_length_mismatch() { + let mut map = MemoryMap::default(); + map.insert( + Slice::from_range_unchecked(5..10), + Slice::from_range_unchecked(20..30), + ); + } +} diff --git a/crates/mpc-tls/tests/test.rs b/crates/mpc-tls/tests/test.rs index 89ac59f1b..57062c216 100644 --- a/crates/mpc-tls/tests/test.rs +++ b/crates/mpc-tls/tests/test.rs @@ -129,42 +129,34 @@ fn build_pair(config: Config) -> (MpcTlsLeader, MpcTlsFollower) { let (rcot_send_a, rcot_recv_b) = ideal_rcot(Block::random(&mut rng), delta_a.into_inner()); let (rcot_send_b, rcot_recv_a) = ideal_rcot(Block::random(&mut rng), delta_b.into_inner()); - let mut rcot_send_a = SharedRCOTSender::new(4, rcot_send_a); - let mut rcot_send_b = SharedRCOTSender::new(1, rcot_send_b); - let mut rcot_recv_a = SharedRCOTReceiver::new(1, rcot_recv_a); - let mut rcot_recv_b = SharedRCOTReceiver::new(4, rcot_recv_b); + let rcot_send_a = SharedRCOTSender::new(rcot_send_a); + let rcot_send_b = SharedRCOTSender::new(rcot_send_b); + let rcot_recv_a = SharedRCOTReceiver::new(rcot_recv_a); + let rcot_recv_b = SharedRCOTReceiver::new(rcot_recv_b); let mpc_a = Arc::new(Mutex::new(Garbler::new( - DerandCOTSender::new(rcot_send_a.next().unwrap()), + DerandCOTSender::new(rcot_send_a.clone()), rand::rng().random(), delta_a, ))); let mpc_b = Arc::new(Mutex::new(Evaluator::new(DerandCOTReceiver::new( - rcot_recv_b.next().unwrap(), + rcot_recv_b.clone(), )))); let leader = MpcTlsLeader::new( config.clone(), ctx_a, mpc_a, - ( - rcot_send_a.next().unwrap(), - rcot_send_a.next().unwrap(), - rcot_send_a.next().unwrap(), - ), - rcot_recv_a.next().unwrap(), + (rcot_send_a.clone(), rcot_send_a.clone(), rcot_send_a), + rcot_recv_a, ); let follower = MpcTlsFollower::new( config, ctx_b, mpc_b, - rcot_send_b.next().unwrap(), - ( - rcot_recv_b.next().unwrap(), - rcot_recv_b.next().unwrap(), - rcot_recv_b.next().unwrap(), - ), + rcot_send_b, + (rcot_recv_b.clone(), rcot_recv_b.clone(), rcot_recv_b), ); (leader, follower) diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index e33f37648..f7804e969 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -364,16 +364,16 @@ fn build_mpc_tls(config: &ProverConfig, ctx: Context) -> (Arc (Arc