diff --git a/Cargo.toml b/Cargo.toml index 4c26eef810..9202626a84 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,8 +8,7 @@ members = [ "crates/common", "crates/components/deap", "crates/components/cipher", - #"crates/components/hmac-sha256", - #"crates/components/hmac-sha256-circuits", + "crates/components/hmac-sha256", "crates/components/key-exchange", "crates/core", "crates/data-fixtures", @@ -57,8 +56,7 @@ tlsn-core = { path = "crates/core" } tlsn-data-fixtures = { path = "crates/data-fixtures" } tlsn-deap = { path = "crates/components/deap" } tlsn-formats = { path = "crates/formats" } -#tlsn-hmac-sha256 = { path = "crates/components/hmac-sha256" } -#tlsn-hmac-sha256-circuits = { path = "crates/components/hmac-sha256-circuits" } +tlsn-hmac-sha256 = { path = "crates/components/hmac-sha256" } tlsn-key-exchange = { path = "crates/components/key-exchange" } tlsn-mpc-tls = { path = "crates/mpc-tls" } tlsn-prover = { path = "crates/prover" } @@ -71,18 +69,19 @@ tlsn-tls-core = { path = "crates/tls/core" } tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" } tlsn-verifier = { path = "crates/verifier" } -mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-memory-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-vm-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-zk = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } +mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-memory-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-vm-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-zk = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-hash = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } rangeset = { version = "0.2" } serio = { version = "0.2" } diff --git a/crates/benches/binary/Cargo.toml b/crates/benches/binary/Cargo.toml index 3d9cfc2611..8ba61c5c54 100644 --- a/crates/benches/binary/Cargo.toml +++ b/crates/benches/binary/Cargo.toml @@ -18,10 +18,10 @@ mpz-core = { workspace = true } mpz-garble = { workspace = true } mpz-ot = { workspace = true, features = ["ideal"] } tlsn-benches-library = { workspace = true } -tlsn-benches-browser-native = { workspace = true, optional = true} +tlsn-benches-browser-native = { workspace = true, optional = true } tlsn-common = { workspace = true } tlsn-core = { workspace = true } -#tlsn-hmac-sha256 = { workspace = true } +tlsn-hmac-sha256 = { workspace = true } tlsn-prover = { workspace = true } tlsn-server-fixture = { workspace = true } tlsn-server-fixture-certs = { workspace = true } @@ -30,7 +30,7 @@ tlsn-verifier = { workspace = true } anyhow = { workspace = true } async-trait = { workspace = true } -charming = {version = "0.3.1", features = ["ssr"]} +charming = { version = "0.3.1", features = ["ssr"] } csv = "1.3.0" dhat = { version = "0.3.3" } env_logger = { version = "0.6.0", default-features = false } @@ -46,7 +46,8 @@ tokio = { workspace = true, features = [ ] } tokio-util = { workspace = true } toml = "0.8.11" -tracing-subscriber = {workspace = true, features = ["env-filter"]} +tracing-subscriber = { workspace = true, features = ["env-filter"] } +rand = { workspace = true } [[bin]] name = "bench" diff --git a/crates/benches/binary/src/lib.rs b/crates/benches/binary/src/lib.rs index ab1c9dce2f..9225e78ca0 100644 --- a/crates/benches/binary/src/lib.rs +++ b/crates/benches/binary/src/lib.rs @@ -1,6 +1,5 @@ pub mod config; pub mod metrics; -mod preprocess; pub mod prover; pub mod prover_main; pub mod verifier_main; diff --git a/crates/benches/binary/src/preprocess.rs b/crates/benches/binary/src/preprocess.rs deleted file mode 100644 index 8e0e9e080f..0000000000 --- a/crates/benches/binary/src/preprocess.rs +++ /dev/null @@ -1,5 +0,0 @@ -use hmac_sha256::build_circuits; - -pub async fn preprocess_prf_circuits() { - build_circuits().await; -} diff --git a/crates/benches/binary/src/prover_main.rs b/crates/benches/binary/src/prover_main.rs index a72330d996..dc1e1d23e0 100644 --- a/crates/benches/binary/src/prover_main.rs +++ b/crates/benches/binary/src/prover_main.rs @@ -14,7 +14,6 @@ use std::{ use crate::{ config::{BenchInstance, Config}, metrics::Metrics, - preprocess::preprocess_prf_circuits, set_interface, PROVER_INTERFACE, }; use anyhow::Context; @@ -58,10 +57,6 @@ pub async fn prover_main(is_memory_profiling: bool) -> anyhow::Result<()> { .open("metrics.csv") .context("failed to open metrics file")?; - // Preprocess the PRF circuits as they are allocating a lot of memory, which - // don't need to be accounted for in the benchmarks. - preprocess_prf_circuits().await; - { let mut metric_wrt = WriterBuilder::new() // If file is not empty, assume that the CSV header is already present in the file. diff --git a/crates/benches/binary/src/verifier_main.rs b/crates/benches/binary/src/verifier_main.rs index d76752975c..f14c41f9ea 100644 --- a/crates/benches/binary/src/verifier_main.rs +++ b/crates/benches/binary/src/verifier_main.rs @@ -4,7 +4,6 @@ use crate::{ config::{BenchInstance, Config}, - preprocess::preprocess_prf_circuits, set_interface, VERIFIER_INTERFACE, }; use tls_core::verify::WebPkiVerifier; @@ -40,10 +39,6 @@ pub async fn verifier_main(is_memory_profiling: bool) -> anyhow::Result<()> { .await .context("failed to bind to port")?; - // Preprocess the PRF circuits as they are allocating a lot of memory, which - // don't need to be accounted for in the benchmarks. - preprocess_prf_circuits().await; - for bench in config.benches { for instance in bench.flatten() { if is_memory_profiling && !instance.memory_profile { diff --git a/crates/common/src/config.rs b/crates/common/src/config.rs index bf9dd0a25f..2c327bfd2f 100644 --- a/crates/common/src/config.rs +++ b/crates/common/src/config.rs @@ -41,6 +41,9 @@ pub struct ProtocolConfig { /// of the MPC-TLS connection. #[builder(default = "true")] defer_decryption_from_start: bool, + /// Network settings. + #[builder(default)] + network: NetworkSetting, /// Version that is being run by prover/verifier. #[builder(setter(skip), default = "VERSION.clone()")] version: Version, @@ -95,6 +98,11 @@ impl ProtocolConfig { pub fn defer_decryption_from_start(&self) -> bool { self.defer_decryption_from_start } + + /// Returns the network settings. + pub fn network(&self) -> NetworkSetting { + self.network + } } /// Protocol configuration validator used by checker (i.e. verifier) to perform @@ -216,6 +224,24 @@ impl ProtocolConfigValidator { } } +/// Settings for the network environment. +/// +/// Provides optimization options to adapt the protocol to different network +/// situations. +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub enum NetworkSetting { + /// Prefers a bandwidth-heavy protocol. + Bandwidth, + /// Prefers a latency-heavy protocol. + Latency, +} + +impl Default for NetworkSetting { + fn default() -> Self { + Self::Bandwidth + } +} + /// A ProtocolConfig error. #[derive(thiserror::Error, Debug)] pub struct ProtocolConfigError { diff --git a/crates/components/hmac-sha256-circuits/Cargo.toml b/crates/components/hmac-sha256-circuits/Cargo.toml deleted file mode 100644 index 55df9ebad5..0000000000 --- a/crates/components/hmac-sha256-circuits/Cargo.toml +++ /dev/null @@ -1,22 +0,0 @@ -[package] -name = "tlsn-hmac-sha256-circuits" -authors = ["TLSNotary Team"] -description = "The 2PC circuits for TLS HMAC-SHA256 PRF" -keywords = ["tls", "mpc", "2pc", "hmac", "sha256"] -categories = ["cryptography"] -license = "MIT OR Apache-2.0" -version = "0.1.0-alpha.11-pre" -edition = "2021" - -[lints] -workspace = true - -[lib] -name = "hmac_sha256_circuits" - -[dependencies] -mpz-circuits = { workspace = true } -tracing = { workspace = true } - -[dev-dependencies] -ring = { workspace = true } diff --git a/crates/components/hmac-sha256-circuits/src/hmac_sha256.rs b/crates/components/hmac-sha256-circuits/src/hmac_sha256.rs deleted file mode 100644 index 5f135cea68..0000000000 --- a/crates/components/hmac-sha256-circuits/src/hmac_sha256.rs +++ /dev/null @@ -1,159 +0,0 @@ -use std::cell::RefCell; - -use mpz_circuits::{ - circuits::{sha256, sha256_compress, sha256_compress_trace, sha256_trace}, - types::{U32, U8}, - BuilderState, Tracer, -}; - -static SHA256_INITIAL_STATE: [u32; 8] = [ - 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, -]; - -/// Returns the outer and inner states of HMAC-SHA256 with the provided key. -/// -/// Outer state is H(key ⊕ opad) -/// -/// Inner state is H(key ⊕ ipad) -/// -/// # Arguments -/// -/// * `builder_state` - Reference to builder state. -/// * `key` - N-byte key (must be <= 64 bytes). -pub fn hmac_sha256_partial_trace<'a>( - builder_state: &'a RefCell, - key: &[Tracer<'a, U8>], -) -> ([Tracer<'a, U32>; 8], [Tracer<'a, U32>; 8]) { - assert!(key.len() <= 64); - - let mut opad = [Tracer::new( - builder_state, - builder_state.borrow_mut().get_constant(0x5cu8), - ); 64]; - - let mut ipad = [Tracer::new( - builder_state, - builder_state.borrow_mut().get_constant(0x36u8), - ); 64]; - - key.iter().enumerate().for_each(|(i, k)| { - opad[i] = opad[i] ^ *k; - ipad[i] = ipad[i] ^ *k; - }); - - let sha256_initial_state: [_; 8] = SHA256_INITIAL_STATE - .map(|v| Tracer::new(builder_state, builder_state.borrow_mut().get_constant(v))); - - let outer_state = sha256_compress_trace(builder_state, sha256_initial_state, opad); - let inner_state = sha256_compress_trace(builder_state, sha256_initial_state, ipad); - - (outer_state, inner_state) -} - -/// Reference implementation of HMAC-SHA256 partial function. -/// -/// Returns the outer and inner states of HMAC-SHA256 with the provided key. -/// -/// Outer state is H(key ⊕ opad) -/// -/// Inner state is H(key ⊕ ipad) -/// -/// # Arguments -/// -/// * `key` - N-byte key (must be <= 64 bytes). -pub fn hmac_sha256_partial(key: &[u8]) -> ([u32; 8], [u32; 8]) { - assert!(key.len() <= 64); - - let mut opad = [0x5cu8; 64]; - let mut ipad = [0x36u8; 64]; - - key.iter().enumerate().for_each(|(i, k)| { - opad[i] ^= k; - ipad[i] ^= k; - }); - - let outer_state = sha256_compress(SHA256_INITIAL_STATE, opad); - let inner_state = sha256_compress(SHA256_INITIAL_STATE, ipad); - - (outer_state, inner_state) -} - -/// HMAC-SHA256 finalization function. -/// -/// Returns the HMAC-SHA256 digest of the provided message using existing outer -/// and inner states. -/// -/// # Arguments -/// -/// * `outer_state` - 256-bit outer state. -/// * `inner_state` - 256-bit inner state. -/// * `msg` - N-byte message. -pub fn hmac_sha256_finalize_trace<'a>( - builder_state: &'a RefCell, - outer_state: [Tracer<'a, U32>; 8], - inner_state: [Tracer<'a, U32>; 8], - msg: &[Tracer<'a, U8>], -) -> [Tracer<'a, U8>; 32] { - sha256_trace( - builder_state, - outer_state, - 64, - &sha256_trace(builder_state, inner_state, 64, msg), - ) -} - -/// Reference implementation of the HMAC-SHA256 finalization function. -/// -/// Returns the HMAC-SHA256 digest of the provided message using existing outer -/// and inner states. -/// -/// # Arguments -/// -/// * `outer_state` - 256-bit outer state. -/// * `inner_state` - 256-bit inner state. -/// * `msg` - N-byte message. -pub fn hmac_sha256_finalize(outer_state: [u32; 8], inner_state: [u32; 8], msg: &[u8]) -> [u8; 32] { - sha256(outer_state, 64, &sha256(inner_state, 64, msg)) -} - -#[cfg(test)] -mod tests { - use mpz_circuits::{test_circ, CircuitBuilder}; - - use super::*; - - #[test] - fn test_hmac_sha256_partial() { - let builder = CircuitBuilder::new(); - let key = builder.add_array_input::(); - let (outer_state, inner_state) = hmac_sha256_partial_trace(builder.state(), &key); - builder.add_output(outer_state); - builder.add_output(inner_state); - let circ = builder.build().unwrap(); - - let key = [69u8; 48]; - - test_circ!(circ, hmac_sha256_partial, fn(&key) -> ([u32; 8], [u32; 8])); - } - - #[test] - fn test_hmac_sha256_finalize() { - let builder = CircuitBuilder::new(); - let outer_state = builder.add_array_input::(); - let inner_state = builder.add_array_input::(); - let msg = builder.add_array_input::(); - let hash = hmac_sha256_finalize_trace(builder.state(), outer_state, inner_state, &msg); - builder.add_output(hash); - let circ = builder.build().unwrap(); - - let key = [69u8; 32]; - let (outer_state, inner_state) = hmac_sha256_partial(&key); - let msg = [42u8; 47]; - - test_circ!( - circ, - hmac_sha256_finalize, - fn(outer_state, inner_state, &msg) -> [u8; 32] - ); - } -} diff --git a/crates/components/hmac-sha256-circuits/src/lib.rs b/crates/components/hmac-sha256-circuits/src/lib.rs deleted file mode 100644 index 2892a5bcf0..0000000000 --- a/crates/components/hmac-sha256-circuits/src/lib.rs +++ /dev/null @@ -1,61 +0,0 @@ -//! HMAC-SHA256 circuits. - -#![deny(missing_docs, unreachable_pub, unused_must_use)] -#![deny(clippy::all)] -#![forbid(unsafe_code)] - -mod hmac_sha256; -mod prf; -mod session_keys; -mod verify_data; - -pub use hmac_sha256::{ - hmac_sha256_finalize, hmac_sha256_finalize_trace, hmac_sha256_partial, - hmac_sha256_partial_trace, -}; - -pub use prf::{prf, prf_trace}; -pub use session_keys::{session_keys, session_keys_trace}; -pub use verify_data::{verify_data, verify_data_trace}; - -use mpz_circuits::{Circuit, CircuitBuilder, Tracer}; -use std::sync::Arc; - -/// Builds session key derivation circuit. -#[tracing::instrument(level = "trace")] -pub fn build_session_keys() -> Arc { - let builder = CircuitBuilder::new(); - let pms = builder.add_array_input::(); - let client_random = builder.add_array_input::(); - let server_random = builder.add_array_input::(); - let (cwk, swk, civ, siv, outer_state, inner_state) = - session_keys_trace(builder.state(), pms, client_random, server_random); - builder.add_output(cwk); - builder.add_output(swk); - builder.add_output(civ); - builder.add_output(siv); - builder.add_output(outer_state); - builder.add_output(inner_state); - Arc::new(builder.build().expect("session keys should build")) -} - -/// Builds a verify data circuit. -#[tracing::instrument(level = "trace")] -pub fn build_verify_data(label: &[u8]) -> Arc { - let builder = CircuitBuilder::new(); - let outer_state = builder.add_array_input::(); - let inner_state = builder.add_array_input::(); - let handshake_hash = builder.add_array_input::(); - let vd = verify_data_trace( - builder.state(), - outer_state, - inner_state, - &label - .iter() - .map(|v| Tracer::new(builder.state(), builder.get_constant(*v).to_inner())) - .collect::>(), - handshake_hash, - ); - builder.add_output(vd); - Arc::new(builder.build().expect("verify data should build")) -} diff --git a/crates/components/hmac-sha256-circuits/src/prf.rs b/crates/components/hmac-sha256-circuits/src/prf.rs deleted file mode 100644 index 664a93393f..0000000000 --- a/crates/components/hmac-sha256-circuits/src/prf.rs +++ /dev/null @@ -1,227 +0,0 @@ -//! This module provides an implementation of the HMAC-SHA256 PRF defined in [RFC 5246](https://www.rfc-editor.org/rfc/rfc5246#section-5). - -use std::cell::RefCell; - -use mpz_circuits::{ - types::{U32, U8}, - BuilderState, Tracer, -}; - -use crate::hmac_sha256::{hmac_sha256_finalize, hmac_sha256_finalize_trace}; - -fn p_hash_trace<'a>( - builder_state: &'a RefCell, - outer_state: [Tracer<'a, U32>; 8], - inner_state: [Tracer<'a, U32>; 8], - seed: &[Tracer<'a, U8>], - iterations: usize, -) -> Vec> { - // A() is defined as: - // - // A(0) = seed - // A(i) = HMAC_hash(secret, A(i-1)) - let mut a_cache: Vec<_> = Vec::with_capacity(iterations + 1); - a_cache.push(seed.to_vec()); - - for i in 0..iterations { - let a_i = hmac_sha256_finalize_trace(builder_state, outer_state, inner_state, &a_cache[i]); - a_cache.push(a_i.to_vec()); - } - - // HMAC_hash(secret, A(i) + seed) - let mut output: Vec<_> = Vec::with_capacity(iterations * 32); - for i in 0..iterations { - let mut a_i_seed = a_cache[i + 1].clone(); - a_i_seed.extend_from_slice(seed); - - let hash = hmac_sha256_finalize_trace(builder_state, outer_state, inner_state, &a_i_seed); - output.extend_from_slice(&hash); - } - - output -} - -fn p_hash(outer_state: [u32; 8], inner_state: [u32; 8], seed: &[u8], iterations: usize) -> Vec { - // A() is defined as: - // - // A(0) = seed - // A(i) = HMAC_hash(secret, A(i-1)) - let mut a_cache: Vec<_> = Vec::with_capacity(iterations + 1); - a_cache.push(seed.to_vec()); - - for i in 0..iterations { - let a_i = hmac_sha256_finalize(outer_state, inner_state, &a_cache[i]); - a_cache.push(a_i.to_vec()); - } - - // HMAC_hash(secret, A(i) + seed) - let mut output: Vec<_> = Vec::with_capacity(iterations * 32); - for i in 0..iterations { - let mut a_i_seed = a_cache[i + 1].clone(); - a_i_seed.extend_from_slice(seed); - - let hash = hmac_sha256_finalize(outer_state, inner_state, &a_i_seed); - output.extend_from_slice(&hash); - } - - output -} - -/// Computes PRF(secret, label, seed). -/// -/// # Arguments -/// -/// * `builder_state` - Reference to builder state. -/// * `outer_state` - The outer state of HMAC-SHA256. -/// * `inner_state` - The inner state of HMAC-SHA256. -/// * `seed` - The seed to use. -/// * `label` - The label to use. -/// * `bytes` - The number of bytes to output. -pub fn prf_trace<'a>( - builder_state: &'a RefCell, - outer_state: [Tracer<'a, U32>; 8], - inner_state: [Tracer<'a, U32>; 8], - seed: &[Tracer<'a, U8>], - label: &[Tracer<'a, U8>], - bytes: usize, -) -> Vec> { - let iterations = bytes / 32 + (bytes % 32 != 0) as usize; - let mut label_seed = label.to_vec(); - label_seed.extend_from_slice(seed); - - let mut output = p_hash_trace( - builder_state, - outer_state, - inner_state, - &label_seed, - iterations, - ); - output.truncate(bytes); - - output -} - -/// Reference implementation of PRF(secret, label, seed). -/// -/// # Arguments -/// -/// * `outer_state` - The outer state of HMAC-SHA256. -/// * `inner_state` - The inner state of HMAC-SHA256. -/// * `seed` - The seed to use. -/// * `label` - The label to use. -/// * `bytes` - The number of bytes to output. -pub fn prf( - outer_state: [u32; 8], - inner_state: [u32; 8], - seed: &[u8], - label: &[u8], - bytes: usize, -) -> Vec { - let iterations = bytes / 32 + (bytes % 32 != 0) as usize; - let mut label_seed = label.to_vec(); - label_seed.extend_from_slice(seed); - - let mut output = p_hash(outer_state, inner_state, &label_seed, iterations); - output.truncate(bytes); - - output -} - -#[cfg(test)] -mod tests { - use mpz_circuits::{evaluate, CircuitBuilder}; - - use crate::hmac_sha256::hmac_sha256_partial; - - use super::*; - - #[test] - fn test_p_hash() { - let builder = CircuitBuilder::new(); - let outer_state = builder.add_array_input::(); - let inner_state = builder.add_array_input::(); - let seed = builder.add_array_input::(); - let output = p_hash_trace(builder.state(), outer_state, inner_state, &seed, 2); - builder.add_output(output); - let circ = builder.build().unwrap(); - - let outer_state = [0u32; 8]; - let inner_state = [1u32; 8]; - let seed = [42u8; 64]; - - let expected = p_hash(outer_state, inner_state, &seed, 2); - let actual = evaluate!(circ, fn(outer_state, inner_state, &seed) -> Vec).unwrap(); - - assert_eq!(actual, expected); - } - - #[test] - fn test_prf() { - let builder = CircuitBuilder::new(); - let outer_state = builder.add_array_input::(); - let inner_state = builder.add_array_input::(); - let seed = builder.add_array_input::(); - let label = builder.add_array_input::(); - let output = prf_trace(builder.state(), outer_state, inner_state, &seed, &label, 48); - builder.add_output(output); - let circ = builder.build().unwrap(); - - let master_secret = [0u8; 48]; - let seed = [43u8; 64]; - let label = b"master secret"; - - let (outer_state, inner_state) = hmac_sha256_partial(&master_secret); - - let expected = prf(outer_state, inner_state, &seed, label, 48); - let actual = - evaluate!(circ, fn(outer_state, inner_state, &seed, label) -> Vec).unwrap(); - - assert_eq!(actual, expected); - - let mut expected_ring = [0u8; 48]; - ring_prf::prf(&mut expected_ring, &master_secret, label, &seed); - - assert_eq!(actual, expected_ring); - } - - // Borrowed from Rustls for testing - // https://github.com/rustls/rustls/blob/main/rustls/src/tls12/prf.rs - mod ring_prf { - use ring::{hmac, hmac::HMAC_SHA256}; - - fn concat_sign(key: &hmac::Key, a: &[u8], b: &[u8]) -> hmac::Tag { - let mut ctx = hmac::Context::with_key(key); - ctx.update(a); - ctx.update(b); - ctx.sign() - } - - fn p(out: &mut [u8], secret: &[u8], seed: &[u8]) { - let hmac_key = hmac::Key::new(HMAC_SHA256, secret); - - // A(1) - let mut current_a = hmac::sign(&hmac_key, seed); - let chunk_size = HMAC_SHA256.digest_algorithm().output_len(); - for chunk in out.chunks_mut(chunk_size) { - // P_hash[i] = HMAC_hash(secret, A(i) + seed) - let p_term = concat_sign(&hmac_key, current_a.as_ref(), seed); - chunk.copy_from_slice(&p_term.as_ref()[..chunk.len()]); - - // A(i+1) = HMAC_hash(secret, A(i)) - current_a = hmac::sign(&hmac_key, current_a.as_ref()); - } - } - - fn concat(a: &[u8], b: &[u8]) -> Vec { - let mut ret = Vec::new(); - ret.extend_from_slice(a); - ret.extend_from_slice(b); - ret - } - - pub(crate) fn prf(out: &mut [u8], secret: &[u8], label: &[u8], seed: &[u8]) { - let joined_seed = concat(label, seed); - p(out, secret, &joined_seed); - } - } -} diff --git a/crates/components/hmac-sha256-circuits/src/session_keys.rs b/crates/components/hmac-sha256-circuits/src/session_keys.rs deleted file mode 100644 index 4ebe7f0841..0000000000 --- a/crates/components/hmac-sha256-circuits/src/session_keys.rs +++ /dev/null @@ -1,200 +0,0 @@ -use std::cell::RefCell; - -use mpz_circuits::{ - types::{U32, U8}, - BuilderState, Tracer, -}; - -use crate::{ - hmac_sha256::{hmac_sha256_partial, hmac_sha256_partial_trace}, - prf::{prf, prf_trace}, -}; - -/// Session Keys. -/// -/// Computes expanded p1 which consists of client_write_key + server_write_key. -/// Computes expanded p2 which consists of client_IV + server_IV. -/// -/// # Arguments -/// -/// * `builder_state` - Reference to builder state. -/// * `pms` - 32-byte premaster secret. -/// * `client_random` - 32-byte client random. -/// * `server_random` - 32-byte server random. -/// -/// # Returns -/// -/// * `client_write_key` - 16-byte client write key. -/// * `server_write_key` - 16-byte server write key. -/// * `client_IV` - 4-byte client IV. -/// * `server_IV` - 4-byte server IV. -/// * `outer_hash_state` - 256-bit master-secret outer HMAC state. -/// * `inner_hash_state` - 256-bit master-secret inner HMAC state. -#[allow(clippy::type_complexity)] -pub fn session_keys_trace<'a>( - builder_state: &'a RefCell, - pms: [Tracer<'a, U8>; 32], - client_random: [Tracer<'a, U8>; 32], - server_random: [Tracer<'a, U8>; 32], -) -> ( - [Tracer<'a, U8>; 16], - [Tracer<'a, U8>; 16], - [Tracer<'a, U8>; 4], - [Tracer<'a, U8>; 4], - [Tracer<'a, U32>; 8], - [Tracer<'a, U32>; 8], -) { - let (pms_outer_state, pms_inner_state) = hmac_sha256_partial_trace(builder_state, &pms); - - let master_secret = { - let seed = client_random - .iter() - .chain(&server_random) - .copied() - .collect::>(); - - let label = b"master secret" - .map(|v| Tracer::new(builder_state, builder_state.borrow_mut().get_constant(v))); - - prf_trace( - builder_state, - pms_outer_state, - pms_inner_state, - &seed, - &label, - 48, - ) - }; - - let (master_secret_outer_state, master_secret_inner_state) = - hmac_sha256_partial_trace(builder_state, &master_secret); - - let key_material = { - let seed = server_random - .iter() - .chain(&client_random) - .copied() - .collect::>(); - - let label = b"key expansion" - .map(|v| Tracer::new(builder_state, builder_state.borrow_mut().get_constant(v))); - - prf_trace( - builder_state, - master_secret_outer_state, - master_secret_inner_state, - &seed, - &label, - 40, - ) - }; - - let cwk = key_material[0..16].try_into().unwrap(); - let swk = key_material[16..32].try_into().unwrap(); - let civ = key_material[32..36].try_into().unwrap(); - let siv = key_material[36..40].try_into().unwrap(); - - ( - cwk, - swk, - civ, - siv, - master_secret_outer_state, - master_secret_inner_state, - ) -} - -/// Reference implementation of session keys derivation. -pub fn session_keys( - pms: [u8; 32], - client_random: [u8; 32], - server_random: [u8; 32], -) -> ([u8; 16], [u8; 16], [u8; 4], [u8; 4]) { - let (pms_outer_state, pms_inner_state) = hmac_sha256_partial(&pms); - - let master_secret = { - let seed = client_random - .iter() - .chain(&server_random) - .copied() - .collect::>(); - - let label = b"master secret"; - - prf(pms_outer_state, pms_inner_state, &seed, label, 48) - }; - - let (master_secret_outer_state, master_secret_inner_state) = - hmac_sha256_partial(&master_secret); - - let key_material = { - let seed = server_random - .iter() - .chain(&client_random) - .copied() - .collect::>(); - - let label = b"key expansion"; - - prf( - master_secret_outer_state, - master_secret_inner_state, - &seed, - label, - 40, - ) - }; - - let cwk = key_material[0..16].try_into().unwrap(); - let swk = key_material[16..32].try_into().unwrap(); - let civ = key_material[32..36].try_into().unwrap(); - let siv = key_material[36..40].try_into().unwrap(); - - (cwk, swk, civ, siv) -} - -#[cfg(test)] -mod tests { - use mpz_circuits::{evaluate, CircuitBuilder}; - - use super::*; - - #[test] - fn test_session_keys() { - let builder = CircuitBuilder::new(); - let pms = builder.add_array_input::(); - let client_random = builder.add_array_input::(); - let server_random = builder.add_array_input::(); - let (cwk, swk, civ, siv, outer_state, inner_state) = - session_keys_trace(builder.state(), pms, client_random, server_random); - builder.add_output(cwk); - builder.add_output(swk); - builder.add_output(civ); - builder.add_output(siv); - builder.add_output(outer_state); - builder.add_output(inner_state); - let circ = builder.build().unwrap(); - - let pms = [0u8; 32]; - let client_random = [42u8; 32]; - let server_random = [69u8; 32]; - - let (expected_cwk, expected_swk, expected_civ, expected_siv) = - session_keys(pms, client_random, server_random); - - let (cwk, swk, civ, siv, _, _) = evaluate!( - circ, - fn( - pms, - client_random, - server_random, - ) -> ([u8; 16], [u8; 16], [u8; 4], [u8; 4], [u32; 8], [u32; 8]) - ) - .unwrap(); - - assert_eq!(cwk, expected_cwk); - assert_eq!(swk, expected_swk); - assert_eq!(civ, expected_civ); - assert_eq!(siv, expected_siv); - } -} diff --git a/crates/components/hmac-sha256-circuits/src/verify_data.rs b/crates/components/hmac-sha256-circuits/src/verify_data.rs deleted file mode 100644 index 01f10b7469..0000000000 --- a/crates/components/hmac-sha256-circuits/src/verify_data.rs +++ /dev/null @@ -1,88 +0,0 @@ -use std::cell::RefCell; - -use mpz_circuits::{ - types::{U32, U8}, - BuilderState, Tracer, -}; - -use crate::prf::{prf, prf_trace}; - -/// Computes verify_data as specified in RFC 5246, Section 7.4.9. -/// -/// verify_data -/// PRF(master_secret, finished_label, -/// Hash(handshake_messages))[0..verify_data_length-1]; -/// -/// # Arguments -/// -/// * `builder_state` - The builder state. -/// * `outer_state` - The outer HMAC state of the master secret. -/// * `inner_state` - The inner HMAC state of the master secret. -/// * `label` - The label to use. -/// * `hs_hash` - The handshake hash. -pub fn verify_data_trace<'a>( - builder_state: &'a RefCell, - outer_state: [Tracer<'a, U32>; 8], - inner_state: [Tracer<'a, U32>; 8], - label: &[Tracer<'a, U8>], - hs_hash: [Tracer<'a, U8>; 32], -) -> [Tracer<'a, U8>; 12] { - let vd = prf_trace(builder_state, outer_state, inner_state, &hs_hash, label, 12); - - vd.try_into().expect("vd is 12 bytes") -} - -/// Reference implementation of verify_data as specified in RFC 5246, Section -/// 7.4.9. -/// -/// # Arguments -/// -/// * `outer_state` - The outer HMAC state of the master secret. -/// * `inner_state` - The inner HMAC state of the master secret. -/// * `label` - The label to use. -/// * `hs_hash` - The handshake hash. -pub fn verify_data( - outer_state: [u32; 8], - inner_state: [u32; 8], - label: &[u8], - hs_hash: [u8; 32], -) -> [u8; 12] { - let vd = prf(outer_state, inner_state, &hs_hash, label, 12); - - vd.try_into().expect("vd is 12 bytes") -} - -#[cfg(test)] -mod tests { - use super::*; - - use mpz_circuits::{evaluate, CircuitBuilder}; - - const CF_LABEL: &[u8; 15] = b"client finished"; - - #[test] - fn test_verify_data() { - let builder = CircuitBuilder::new(); - let outer_state = builder.add_array_input::(); - let inner_state = builder.add_array_input::(); - let label = builder.add_array_input::(); - let hs_hash = builder.add_array_input::(); - let vd = verify_data_trace(builder.state(), outer_state, inner_state, &label, hs_hash); - builder.add_output(vd); - let circ = builder.build().unwrap(); - - let outer_state = [0u32; 8]; - let inner_state = [1u32; 8]; - let hs_hash = [42u8; 32]; - - let expected = prf(outer_state, inner_state, &hs_hash, CF_LABEL, 12); - - let actual = evaluate!( - circ, - fn(outer_state, inner_state, CF_LABEL, hs_hash) -> [u8; 12] - ) - .unwrap(); - - assert_eq!(actual.to_vec(), expected); - } -} diff --git a/crates/components/hmac-sha256/Cargo.toml b/crates/components/hmac-sha256/Cargo.toml index 07ecc6b6c6..fa1fd4d7dd 100644 --- a/crates/components/hmac-sha256/Cargo.toml +++ b/crates/components/hmac-sha256/Cargo.toml @@ -14,22 +14,15 @@ workspace = true [lib] name = "hmac_sha256" -[features] -default = ["mock"] -rayon = ["mpz-common/rayon"] -mock = [] - [dependencies] -tlsn-hmac-sha256-circuits = { workspace = true } - mpz-vm-core = { workspace = true } +mpz-core = { workspace = true } mpz-circuits = { workspace = true } -mpz-common = { workspace = true, features = ["cpu"] } +mpz-hash = { workspace = true } -derive_builder = { workspace = true } thiserror = { workspace = true } tracing = { workspace = true } -futures = { workspace = true } +sha2 = { workspace = true } [dev-dependencies] mpz-ot = { workspace = true, features = ["ideal"] } @@ -39,7 +32,8 @@ mpz-common = { workspace = true, features = ["test-utils"] } criterion = { workspace = true, features = ["async_tokio"] } tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } rand = { workspace = true } -rand06-compat = { workspace = true } +hex = { workspace = true } +ring = { workspace = true } [[bench]] name = "prf" diff --git a/crates/components/hmac-sha256/benches/prf.rs b/crates/components/hmac-sha256/benches/prf.rs index 9e494a93c2..926f503718 100644 --- a/crates/components/hmac-sha256/benches/prf.rs +++ b/crates/components/hmac-sha256/benches/prf.rs @@ -2,16 +2,16 @@ use criterion::{criterion_group, criterion_main, Criterion}; -use hmac_sha256::{MpcPrf, PrfConfig, Role}; +use hmac_sha256::{Mode, MpcPrf}; use mpz_common::context::test_mt_context; -use mpz_garble::protocol::semihonest::{Evaluator, Generator}; +use mpz_garble::protocol::semihonest::{Evaluator, Garbler}; use mpz_ot::ideal::cot::ideal_cot; use mpz_vm_core::{ memory::{binary::U8, correlated::Delta, Array}, prelude::*, + Execute, }; use rand::{rngs::StdRng, SeedableRng}; -use rand06_compat::Rand0_6CompatExt; #[allow(clippy::unit_arg)] fn criterion_benchmark(c: &mut Criterion) { @@ -36,10 +36,10 @@ async fn prf() { let mut leader_ctx = leader_exec.new_context().await.unwrap(); let mut follower_ctx = follower_exec.new_context().await.unwrap(); - let delta = Delta::random(&mut rng.compat_by_ref()); + let delta = Delta::random(&mut rng); let (ot_send, ot_recv) = ideal_cot(delta.into_inner()); - let mut leader_vm = Generator::new(ot_send, [0u8; 16], delta); + let mut leader_vm = Garbler::new(ot_send, [0u8; 16], delta); let mut follower_vm = Evaluator::new(ot_recv); let leader_pms: Array = leader_vm.alloc().unwrap(); @@ -52,23 +52,17 @@ async fn prf() { follower_vm.assign(follower_pms, pms).unwrap(); follower_vm.commit(follower_pms).unwrap(); - let mut leader = MpcPrf::new(PrfConfig::builder().role(Role::Leader).build().unwrap()); - let mut follower = MpcPrf::new(PrfConfig::builder().role(Role::Follower).build().unwrap()); + let mut leader = MpcPrf::new(Mode::default()); + let mut follower = MpcPrf::new(Mode::default()); let leader_output = leader.alloc(&mut leader_vm, leader_pms).unwrap(); let follower_output = follower.alloc(&mut follower_vm, follower_pms).unwrap(); - leader - .set_client_random(&mut leader_vm, Some(client_random)) - .unwrap(); - follower.set_client_random(&mut follower_vm, None).unwrap(); + leader.set_client_random(client_random).unwrap(); + follower.set_client_random(client_random).unwrap(); - leader - .set_server_random(&mut leader_vm, server_random) - .unwrap(); - follower - .set_server_random(&mut follower_vm, server_random) - .unwrap(); + leader.set_server_random(server_random).unwrap(); + follower.set_server_random(server_random).unwrap(); let _ = leader_vm .decode(leader_output.keys.client_write_key) @@ -88,44 +82,61 @@ async fn prf() { let _ = follower_vm.decode(follower_output.keys.client_iv).unwrap(); let _ = follower_vm.decode(follower_output.keys.server_iv).unwrap(); - futures::join!( - async { - leader_vm.flush(&mut leader_ctx).await.unwrap(); - leader_vm.execute(&mut leader_ctx).await.unwrap(); - leader_vm.flush(&mut leader_ctx).await.unwrap(); - }, - async { - follower_vm.flush(&mut follower_ctx).await.unwrap(); - follower_vm.execute(&mut follower_ctx).await.unwrap(); - follower_vm.flush(&mut follower_ctx).await.unwrap(); - } - ); + while leader.wants_flush() || follower.wants_flush() { + tokio::try_join!( + async { + leader.flush(&mut leader_vm).unwrap(); + leader_vm.execute_all(&mut leader_ctx).await + }, + async { + follower.flush(&mut follower_vm).unwrap(); + follower_vm.execute_all(&mut follower_ctx).await + } + ) + .unwrap(); + } let cf_hs_hash = [1u8; 32]; - let sf_hs_hash = [2u8; 32]; - - leader.set_cf_hash(&mut leader_vm, cf_hs_hash).unwrap(); - leader.set_sf_hash(&mut leader_vm, sf_hs_hash).unwrap(); - follower.set_cf_hash(&mut follower_vm, cf_hs_hash).unwrap(); - follower.set_sf_hash(&mut follower_vm, sf_hs_hash).unwrap(); + leader.set_cf_hash(cf_hs_hash).unwrap(); + follower.set_cf_hash(cf_hs_hash).unwrap(); + + while leader.wants_flush() || follower.wants_flush() { + tokio::try_join!( + async { + leader.flush(&mut leader_vm).unwrap(); + leader_vm.execute_all(&mut leader_ctx).await + }, + async { + follower.flush(&mut follower_vm).unwrap(); + follower_vm.execute_all(&mut follower_ctx).await + } + ) + .unwrap(); + } let _ = leader_vm.decode(leader_output.cf_vd).unwrap(); - let _ = leader_vm.decode(leader_output.sf_vd).unwrap(); - let _ = follower_vm.decode(follower_output.cf_vd).unwrap(); - let _ = follower_vm.decode(follower_output.sf_vd).unwrap(); - futures::join!( - async { - leader_vm.flush(&mut leader_ctx).await.unwrap(); - leader_vm.execute(&mut leader_ctx).await.unwrap(); - leader_vm.flush(&mut leader_ctx).await.unwrap(); - }, - async { - follower_vm.flush(&mut follower_ctx).await.unwrap(); - follower_vm.execute(&mut follower_ctx).await.unwrap(); - follower_vm.flush(&mut follower_ctx).await.unwrap(); - } - ); + let sf_hs_hash = [2u8; 32]; + + leader.set_sf_hash(sf_hs_hash).unwrap(); + follower.set_sf_hash(sf_hs_hash).unwrap(); + + while leader.wants_flush() || follower.wants_flush() { + tokio::try_join!( + async { + leader.flush(&mut leader_vm).unwrap(); + leader_vm.execute_all(&mut leader_ctx).await + }, + async { + follower.flush(&mut follower_vm).unwrap(); + follower_vm.execute_all(&mut follower_ctx).await + } + ) + .unwrap(); + } + + let _ = leader_vm.decode(leader_output.sf_vd).unwrap(); + let _ = follower_vm.decode(follower_output.sf_vd).unwrap(); } diff --git a/crates/components/hmac-sha256/src/config.rs b/crates/components/hmac-sha256/src/config.rs index c9e96c9cd4..85914b0f52 100644 --- a/crates/components/hmac-sha256/src/config.rs +++ b/crates/components/hmac-sha256/src/config.rs @@ -1,24 +1,16 @@ -use derive_builder::Builder; +//! PRF modes. -/// Role of this party in the PRF. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Role { - /// The leader provides the private inputs to the PRF. - Leader, - /// The follower is blind to the inputs to the PRF. - Follower, +/// Modes for the PRF. +#[derive(Debug, Clone, Copy)] +pub enum Mode { + /// Computes some hashes locally. + Reduced, + /// Computes the whole PRF in MPC. + Normal, } -/// Configuration for the PRF. -#[derive(Debug, Builder)] -pub struct PrfConfig { - /// The role of this party in the PRF. - pub(crate) role: Role, -} - -impl PrfConfig { - /// Creates a new builder. - pub fn builder() -> PrfConfigBuilder { - PrfConfigBuilder::default() +impl Default for Mode { + fn default() -> Self { + Self::Reduced } } diff --git a/crates/components/hmac-sha256/src/error.rs b/crates/components/hmac-sha256/src/error.rs index d22f754947..a7cf225edf 100644 --- a/crates/components/hmac-sha256/src/error.rs +++ b/crates/components/hmac-sha256/src/error.rs @@ -1,6 +1,8 @@ use core::fmt; use std::error::Error; +use mpz_hash::sha256::Sha256Error; + /// A PRF error. #[derive(Debug, thiserror::Error)] pub struct PrfError { @@ -20,22 +22,21 @@ impl PrfError { } } - pub(crate) fn state(msg: impl Into) -> Self { - Self { - kind: ErrorKind::State, - source: Some(msg.into().into()), - } + pub(crate) fn vm>>(err: E) -> Self { + Self::new(ErrorKind::Vm, err) } - pub(crate) fn role(msg: impl Into) -> Self { + pub(crate) fn state(msg: impl Into) -> Self { Self { - kind: ErrorKind::Role, + kind: ErrorKind::State, source: Some(msg.into().into()), } } +} - pub(crate) fn vm>>(err: E) -> Self { - Self::new(ErrorKind::Vm, err) +impl From for PrfError { + fn from(value: Sha256Error) -> Self { + Self::new(ErrorKind::Hash, value) } } @@ -43,7 +44,7 @@ impl PrfError { pub(crate) enum ErrorKind { Vm, State, - Role, + Hash, } impl fmt::Display for PrfError { @@ -51,7 +52,7 @@ impl fmt::Display for PrfError { match self.kind { ErrorKind::Vm => write!(f, "vm error")?, ErrorKind::State => write!(f, "state error")?, - ErrorKind::Role => write!(f, "role error")?, + ErrorKind::Hash => write!(f, "hash error")?, } if let Some(ref source) = self.source { @@ -61,9 +62,3 @@ impl fmt::Display for PrfError { Ok(()) } } - -impl From for PrfError { - fn from(error: mpz_common::ContextError) -> Self { - Self::new(ErrorKind::Vm, error) - } -} diff --git a/crates/components/hmac-sha256/src/hmac.rs b/crates/components/hmac-sha256/src/hmac.rs new file mode 100644 index 0000000000..1103dc6c07 --- /dev/null +++ b/crates/components/hmac-sha256/src/hmac.rs @@ -0,0 +1,177 @@ +//! Computation of HMAC-SHA256. +//! +//! HMAC-SHA256 is defined as +//! +//! HMAC(m) = H((key' xor opad) || H((key' xor ipad) || m)) +//! +//! * H - SHA256 hash function +//! * key' - key padded with zero bytes to 64 bytes (we do not support longer +//! keys) +//! * opad - 64 bytes of 0x5c +//! * ipad - 64 bytes of 0x36 +//! * m - message +//! +//! This implementation computes HMAC-SHA256 using intermediate results +//! `outer_partial` and `inner_local`. Then HMAC(m) = H(outer_partial || +//! inner_local) +//! +//! * `outer_partial` - key' xor opad +//! * `inner_local` - H((key' xor ipad) || m) + +use mpz_hash::sha256::Sha256; +use mpz_vm_core::{ + memory::{ + binary::{Binary, U8}, + Array, + }, + Vm, +}; + +use crate::PrfError; + +pub(crate) const IPAD: [u8; 64] = [0x36; 64]; +pub(crate) const OPAD: [u8; 64] = [0x5c; 64]; + +/// Computes HMAC-SHA256 +/// +/// # Arguments +/// +/// * `vm` - The virtual machine. +/// * `outer_partial` - (key' xor opad) +/// * `inner_local` - H((key' xor ipad) || m) +pub(crate) fn hmac_sha256( + vm: &mut dyn Vm, + mut outer_partial: Sha256, + inner_local: Array, +) -> Result, PrfError> { + outer_partial.update(&inner_local); + outer_partial.compress(vm)?; + outer_partial.finalize(vm).map_err(PrfError::from) +} + +#[cfg(test)] +mod tests { + use crate::{ + hmac::hmac_sha256, + sha256, state_to_bytes, + test_utils::{compute_inner_local, compute_outer_partial, mock_vm}, + }; + use mpz_common::context::test_st_context; + use mpz_hash::sha256::Sha256; + use mpz_vm_core::{ + memory::{ + binary::{U32, U8}, + Array, MemoryExt, ViewExt, + }, + Execute, + }; + + #[test] + fn test_hmac_reference() { + let (inputs, references) = test_fixtures(); + + for (input, &reference) in inputs.iter().zip(references.iter()) { + let outer_partial = compute_outer_partial(input.0.clone()); + let inner_local = compute_inner_local(input.0.clone(), &input.1); + + let hmac = sha256(outer_partial, 64, &state_to_bytes(inner_local)); + + assert_eq!(state_to_bytes(hmac), reference); + } + } + + #[tokio::test] + async fn test_hmac_circuit() { + let (mut ctx_a, mut ctx_b) = test_st_context(8); + let (mut leader, mut follower) = mock_vm(); + + let (inputs, references) = test_fixtures(); + for (input, &reference) in inputs.iter().zip(references.iter()) { + let outer_partial = compute_outer_partial(input.0.clone()); + let inner_local = compute_inner_local(input.0.clone(), &input.1); + + let outer_partial_leader: Array = leader.alloc().unwrap(); + leader.mark_public(outer_partial_leader).unwrap(); + leader.assign(outer_partial_leader, outer_partial).unwrap(); + leader.commit(outer_partial_leader).unwrap(); + + let inner_local_leader: Array = leader.alloc().unwrap(); + leader.mark_public(inner_local_leader).unwrap(); + leader + .assign(inner_local_leader, state_to_bytes(inner_local)) + .unwrap(); + leader.commit(inner_local_leader).unwrap(); + + let hmac_leader = hmac_sha256( + &mut leader, + Sha256::new_from_state(outer_partial_leader, 1), + inner_local_leader, + ) + .unwrap(); + let hmac_leader = leader.decode(hmac_leader).unwrap(); + + let outer_partial_follower: Array = follower.alloc().unwrap(); + follower.mark_public(outer_partial_follower).unwrap(); + follower + .assign(outer_partial_follower, outer_partial) + .unwrap(); + follower.commit(outer_partial_follower).unwrap(); + + let inner_local_follower: Array = follower.alloc().unwrap(); + follower.mark_public(inner_local_follower).unwrap(); + follower + .assign(inner_local_follower, state_to_bytes(inner_local)) + .unwrap(); + follower.commit(inner_local_follower).unwrap(); + + let hmac_follower = hmac_sha256( + &mut follower, + Sha256::new_from_state(outer_partial_follower, 1), + inner_local_follower, + ) + .unwrap(); + let hmac_follower = follower.decode(hmac_follower).unwrap(); + + let (hmac_leader, hmac_follower) = tokio::try_join!( + async { + leader.execute_all(&mut ctx_a).await.unwrap(); + hmac_leader.await + }, + async { + follower.execute_all(&mut ctx_b).await.unwrap(); + hmac_follower.await + } + ) + .unwrap(); + + assert_eq!(hmac_leader, hmac_follower); + assert_eq!(hmac_leader, reference); + } + } + + #[allow(clippy::type_complexity)] + fn test_fixtures() -> (Vec<(Vec, Vec)>, Vec<[u8; 32]>) { + let test_vectors: Vec<(Vec, Vec)> = vec![ + ( + hex::decode("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b").unwrap(), + hex::decode("4869205468657265").unwrap(), + ), + ( + hex::decode("4a656665").unwrap(), + hex::decode("7768617420646f2079612077616e7420666f72206e6f7468696e673f").unwrap(), + ), + ]; + let expected: Vec<[u8; 32]> = vec![ + hex::decode("b0344c61d8db38535ca8afceaf0bf12b881dc200c9833da726e9376c2e32cff7") + .unwrap() + .try_into() + .unwrap(), + hex::decode("5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843") + .unwrap() + .try_into() + .unwrap(), + ]; + + (test_vectors, expected) + } +} diff --git a/crates/components/hmac-sha256/src/lib.rs b/crates/components/hmac-sha256/src/lib.rs index 2082113831..2e460bf43d 100644 --- a/crates/components/hmac-sha256/src/lib.rs +++ b/crates/components/hmac-sha256/src/lib.rs @@ -1,30 +1,24 @@ -//! This module contains the protocol for computing TLS SHA-256 HMAC PRF. +//! This crate contains the protocol for computing TLS 1.2 SHA-256 HMAC PRF. #![deny(missing_docs, unreachable_pub, unused_must_use)] #![deny(clippy::all)] #![forbid(unsafe_code)] +mod hmac; +#[cfg(test)] +mod test_utils; + mod config; -mod error; -mod prf; +pub use config::Mode; -pub use config::{PrfConfig, PrfConfigBuilder, PrfConfigBuilderError, Role}; +mod error; pub use error::PrfError; + +mod prf; pub use prf::MpcPrf; use mpz_vm_core::memory::{binary::U8, Array}; -pub(crate) static CF_LABEL: &[u8] = b"client finished"; -pub(crate) static SF_LABEL: &[u8] = b"server finished"; - -/// Builds the circuits for the PRF. -/// -/// This function can be used ahead of time to build the circuits for the PRF, -/// which at the moment is CPU and memory intensive. -pub async fn build_circuits() { - prf::Circuits::get().await; -} - /// PRF output. #[derive(Debug, Clone, Copy)] pub struct PrfOutput { @@ -49,176 +43,227 @@ pub struct SessionKeys { pub server_iv: Array, } +fn sha256(mut state: [u32; 8], pos: usize, msg: &[u8]) -> [u32; 8] { + use sha2::{ + compress256, + digest::{ + block_buffer::{BlockBuffer, Eager}, + generic_array::typenum::U64, + }, + }; + + let mut buffer = BlockBuffer::::default(); + buffer.digest_blocks(msg, |b| compress256(&mut state, b)); + buffer.digest_pad(0x80, &(((msg.len() + pos) * 8) as u64).to_be_bytes(), |b| { + compress256(&mut state, &[*b]) + }); + state +} + +fn state_to_bytes(input: [u32; 8]) -> [u8; 32] { + let mut output = [0_u8; 32]; + for (k, byte_chunk) in input.iter().enumerate() { + let byte_chunk = byte_chunk.to_be_bytes(); + output[4 * k..4 * (k + 1)].copy_from_slice(&byte_chunk); + } + output +} + #[cfg(test)] mod tests { + use crate::{ + test_utils::{mock_vm, prf_cf_vd, prf_keys, prf_ms, prf_sf_vd}, + Mode, MpcPrf, SessionKeys, + }; use mpz_common::context::test_st_context; - use mpz_garble::protocol::semihonest::{Evaluator, Generator}; - - use hmac_sha256_circuits::{hmac_sha256_partial, prf, session_keys}; - use mpz_ot::ideal::cot::ideal_cot; - use mpz_vm_core::{memory::correlated::Delta, prelude::*}; - use rand::{rngs::StdRng, SeedableRng}; - use rand06_compat::Rand0_6CompatExt; - - use super::*; - - fn compute_ms(pms: [u8; 32], client_random: [u8; 32], server_random: [u8; 32]) -> [u8; 48] { - let (outer_state, inner_state) = hmac_sha256_partial(&pms); - let seed = client_random - .iter() - .chain(&server_random) - .copied() - .collect::>(); - let ms = prf(outer_state, inner_state, &seed, b"master secret", 48); - ms.try_into().unwrap() - } + use mpz_vm_core::{ + memory::{binary::U8, Array, MemoryExt, ViewExt}, + Execute, + }; + use rand::{rngs::StdRng, Rng, SeedableRng}; - fn compute_vd(ms: [u8; 48], label: &[u8], hs_hash: [u8; 32]) -> [u8; 12] { - let (outer_state, inner_state) = hmac_sha256_partial(&ms); - let vd = prf(outer_state, inner_state, &hs_hash, label, 12); - vd.try_into().unwrap() + #[tokio::test] + async fn test_prf_reduced() { + let mode = Mode::Reduced; + test_prf(mode).await; } - #[ignore = "expensive"] #[tokio::test] - async fn test_prf() { - let mut rng = StdRng::seed_from_u64(0); - - let pms = [42u8; 32]; - let client_random = [69u8; 32]; - let server_random: [u8; 32] = [96u8; 32]; - let ms = compute_ms(pms, client_random, server_random); - - let (mut leader_ctx, mut follower_ctx) = test_st_context(128); - - let delta = Delta::random(&mut rng.compat_by_ref()); - let (ot_send, ot_recv) = ideal_cot(delta.into_inner()); - - let mut leader_vm = Generator::new(ot_send, [0u8; 16], delta); - let mut follower_vm = Evaluator::new(ot_recv); - - let leader_pms: Array = leader_vm.alloc().unwrap(); - leader_vm.mark_public(leader_pms).unwrap(); - leader_vm.assign(leader_pms, pms).unwrap(); - leader_vm.commit(leader_pms).unwrap(); - - let follower_pms: Array = follower_vm.alloc().unwrap(); - follower_vm.mark_public(follower_pms).unwrap(); - follower_vm.assign(follower_pms, pms).unwrap(); - follower_vm.commit(follower_pms).unwrap(); - - let mut leader = MpcPrf::new(PrfConfig::builder().role(Role::Leader).build().unwrap()); - let mut follower = MpcPrf::new(PrfConfig::builder().role(Role::Follower).build().unwrap()); - - let leader_output = leader.alloc(&mut leader_vm, leader_pms).unwrap(); - let follower_output = follower.alloc(&mut follower_vm, follower_pms).unwrap(); + async fn test_prf_normal() { + let mode = Mode::Normal; + test_prf(mode).await; + } - leader - .set_client_random(&mut leader_vm, Some(client_random)) + async fn test_prf(mode: Mode) { + let mut rng = StdRng::seed_from_u64(1); + // Test input + let pms: [u8; 32] = rng.random(); + let client_random: [u8; 32] = rng.random(); + let server_random: [u8; 32] = rng.random(); + + let cf_hs_hash: [u8; 32] = rng.random(); + let sf_hs_hash: [u8; 32] = rng.random(); + + // Expected output + let ms_expected = prf_ms(pms, client_random, server_random); + + let [cwk_expected, swk_expected, civ_expected, siv_expected] = + prf_keys(ms_expected, client_random, server_random); + + let cwk_expected: [u8; 16] = cwk_expected.try_into().unwrap(); + let swk_expected: [u8; 16] = swk_expected.try_into().unwrap(); + let civ_expected: [u8; 4] = civ_expected.try_into().unwrap(); + let siv_expected: [u8; 4] = siv_expected.try_into().unwrap(); + + let cf_vd_expected = prf_cf_vd(ms_expected, cf_hs_hash); + let sf_vd_expected = prf_sf_vd(ms_expected, sf_hs_hash); + + let cf_vd_expected: [u8; 12] = cf_vd_expected.try_into().unwrap(); + let sf_vd_expected: [u8; 12] = sf_vd_expected.try_into().unwrap(); + + // Set up vm and prf + let (mut ctx_a, mut ctx_b) = test_st_context(128); + let (mut leader, mut follower) = mock_vm(); + + let leader_pms: Array = leader.alloc().unwrap(); + leader.mark_public(leader_pms).unwrap(); + leader.assign(leader_pms, pms).unwrap(); + leader.commit(leader_pms).unwrap(); + + let follower_pms: Array = follower.alloc().unwrap(); + follower.mark_public(follower_pms).unwrap(); + follower.assign(follower_pms, pms).unwrap(); + follower.commit(follower_pms).unwrap(); + + let mut prf_leader = MpcPrf::new(mode); + let mut prf_follower = MpcPrf::new(mode); + + let leader_prf_out = prf_leader.alloc(&mut leader, leader_pms).unwrap(); + let follower_prf_out = prf_follower.alloc(&mut follower, follower_pms).unwrap(); + + // client_random and server_random + prf_leader.set_client_random(client_random).unwrap(); + prf_follower.set_client_random(client_random).unwrap(); + + prf_leader.set_server_random(server_random).unwrap(); + prf_follower.set_server_random(server_random).unwrap(); + + let SessionKeys { + client_write_key: cwk_leader, + server_write_key: swk_leader, + client_iv: civ_leader, + server_iv: siv_leader, + } = leader_prf_out.keys; + + let mut cwk_leader = leader.decode(cwk_leader).unwrap(); + let mut swk_leader = leader.decode(swk_leader).unwrap(); + let mut civ_leader = leader.decode(civ_leader).unwrap(); + let mut siv_leader = leader.decode(siv_leader).unwrap(); + + let SessionKeys { + client_write_key: cwk_follower, + server_write_key: swk_follower, + client_iv: civ_follower, + server_iv: siv_follower, + } = follower_prf_out.keys; + + let mut cwk_follower = follower.decode(cwk_follower).unwrap(); + let mut swk_follower = follower.decode(swk_follower).unwrap(); + let mut civ_follower = follower.decode(civ_follower).unwrap(); + let mut siv_follower = follower.decode(siv_follower).unwrap(); + + while prf_leader.wants_flush() || prf_follower.wants_flush() { + tokio::try_join!( + async { + prf_leader.flush(&mut leader).unwrap(); + leader.execute_all(&mut ctx_a).await + }, + async { + prf_follower.flush(&mut follower).unwrap(); + follower.execute_all(&mut ctx_b).await + } + ) .unwrap(); - follower.set_client_random(&mut follower_vm, None).unwrap(); - - leader - .set_server_random(&mut leader_vm, server_random) + } + + let cwk_leader = cwk_leader.try_recv().unwrap().unwrap(); + let swk_leader = swk_leader.try_recv().unwrap().unwrap(); + let civ_leader = civ_leader.try_recv().unwrap().unwrap(); + let siv_leader = siv_leader.try_recv().unwrap().unwrap(); + + let cwk_follower = cwk_follower.try_recv().unwrap().unwrap(); + let swk_follower = swk_follower.try_recv().unwrap().unwrap(); + let civ_follower = civ_follower.try_recv().unwrap().unwrap(); + let siv_follower = siv_follower.try_recv().unwrap().unwrap(); + + assert_eq!(cwk_leader, cwk_follower); + assert_eq!(swk_leader, swk_follower); + assert_eq!(civ_leader, civ_follower); + assert_eq!(siv_leader, siv_follower); + + assert_eq!(cwk_leader, cwk_expected); + assert_eq!(swk_leader, swk_expected); + assert_eq!(civ_leader, civ_expected); + assert_eq!(siv_leader, siv_expected); + + // client finished + prf_leader.set_cf_hash(cf_hs_hash).unwrap(); + prf_follower.set_cf_hash(cf_hs_hash).unwrap(); + + let cf_vd_leader = leader_prf_out.cf_vd; + let cf_vd_follower = follower_prf_out.cf_vd; + + let mut cf_vd_leader = leader.decode(cf_vd_leader).unwrap(); + let mut cf_vd_follower = follower.decode(cf_vd_follower).unwrap(); + + while prf_leader.wants_flush() || prf_follower.wants_flush() { + tokio::try_join!( + async { + prf_leader.flush(&mut leader).unwrap(); + leader.execute_all(&mut ctx_a).await + }, + async { + prf_follower.flush(&mut follower).unwrap(); + follower.execute_all(&mut ctx_b).await + } + ) .unwrap(); - follower - .set_server_random(&mut follower_vm, server_random) + } + + let cf_vd_leader = cf_vd_leader.try_recv().unwrap().unwrap(); + let cf_vd_follower = cf_vd_follower.try_recv().unwrap().unwrap(); + + assert_eq!(cf_vd_leader, cf_vd_follower); + assert_eq!(cf_vd_leader, cf_vd_expected); + + // server finished + prf_leader.set_sf_hash(sf_hs_hash).unwrap(); + prf_follower.set_sf_hash(sf_hs_hash).unwrap(); + + let sf_vd_leader = leader_prf_out.sf_vd; + let sf_vd_follower = follower_prf_out.sf_vd; + + let mut sf_vd_leader = leader.decode(sf_vd_leader).unwrap(); + let mut sf_vd_follower = follower.decode(sf_vd_follower).unwrap(); + + while prf_leader.wants_flush() || prf_follower.wants_flush() { + tokio::try_join!( + async { + prf_leader.flush(&mut leader).unwrap(); + leader.execute_all(&mut ctx_a).await + }, + async { + prf_follower.flush(&mut follower).unwrap(); + follower.execute_all(&mut ctx_b).await + } + ) .unwrap(); + } - let leader_cwk = leader_vm - .decode(leader_output.keys.client_write_key) - .unwrap(); - let leader_swk = leader_vm - .decode(leader_output.keys.server_write_key) - .unwrap(); - let leader_civ = leader_vm.decode(leader_output.keys.client_iv).unwrap(); - let leader_siv = leader_vm.decode(leader_output.keys.server_iv).unwrap(); + let sf_vd_leader = sf_vd_leader.try_recv().unwrap().unwrap(); + let sf_vd_follower = sf_vd_follower.try_recv().unwrap().unwrap(); - let follower_cwk = follower_vm - .decode(follower_output.keys.client_write_key) - .unwrap(); - let follower_swk = follower_vm - .decode(follower_output.keys.server_write_key) - .unwrap(); - let follower_civ = follower_vm.decode(follower_output.keys.client_iv).unwrap(); - let follower_siv = follower_vm.decode(follower_output.keys.server_iv).unwrap(); - - futures::join!( - async { - leader_vm.flush(&mut leader_ctx).await.unwrap(); - leader_vm.execute(&mut leader_ctx).await.unwrap(); - leader_vm.flush(&mut leader_ctx).await.unwrap(); - }, - async { - follower_vm.flush(&mut follower_ctx).await.unwrap(); - follower_vm.execute(&mut follower_ctx).await.unwrap(); - follower_vm.flush(&mut follower_ctx).await.unwrap(); - } - ); - - let leader_cwk = leader_cwk.await.unwrap(); - let leader_swk = leader_swk.await.unwrap(); - let leader_civ = leader_civ.await.unwrap(); - let leader_siv = leader_siv.await.unwrap(); - - let follower_cwk = follower_cwk.await.unwrap(); - let follower_swk = follower_swk.await.unwrap(); - let follower_civ = follower_civ.await.unwrap(); - let follower_siv = follower_siv.await.unwrap(); - - let (expected_cwk, expected_swk, expected_civ, expected_siv) = - session_keys(pms, client_random, server_random); - - assert_eq!(leader_cwk, expected_cwk); - assert_eq!(leader_swk, expected_swk); - assert_eq!(leader_civ, expected_civ); - assert_eq!(leader_siv, expected_siv); - - assert_eq!(follower_cwk, expected_cwk); - assert_eq!(follower_swk, expected_swk); - assert_eq!(follower_civ, expected_civ); - assert_eq!(follower_siv, expected_siv); - - let cf_hs_hash = [1u8; 32]; - let sf_hs_hash = [2u8; 32]; - - leader.set_cf_hash(&mut leader_vm, cf_hs_hash).unwrap(); - leader.set_sf_hash(&mut leader_vm, sf_hs_hash).unwrap(); - - follower.set_cf_hash(&mut follower_vm, cf_hs_hash).unwrap(); - follower.set_sf_hash(&mut follower_vm, sf_hs_hash).unwrap(); - - let leader_cf_vd = leader_vm.decode(leader_output.cf_vd).unwrap(); - let leader_sf_vd = leader_vm.decode(leader_output.sf_vd).unwrap(); - - let follower_cf_vd = follower_vm.decode(follower_output.cf_vd).unwrap(); - let follower_sf_vd = follower_vm.decode(follower_output.sf_vd).unwrap(); - - futures::join!( - async { - leader_vm.flush(&mut leader_ctx).await.unwrap(); - leader_vm.execute(&mut leader_ctx).await.unwrap(); - leader_vm.flush(&mut leader_ctx).await.unwrap(); - }, - async { - follower_vm.flush(&mut follower_ctx).await.unwrap(); - follower_vm.execute(&mut follower_ctx).await.unwrap(); - follower_vm.flush(&mut follower_ctx).await.unwrap(); - } - ); - - let leader_cf_vd = leader_cf_vd.await.unwrap(); - let leader_sf_vd = leader_sf_vd.await.unwrap(); - - let follower_cf_vd = follower_cf_vd.await.unwrap(); - let follower_sf_vd = follower_sf_vd.await.unwrap(); - - let expected_cf_vd = compute_vd(ms, b"client finished", cf_hs_hash); - let expected_sf_vd = compute_vd(ms, b"server finished", sf_hs_hash); - - assert_eq!(leader_cf_vd, expected_cf_vd); - assert_eq!(leader_sf_vd, expected_sf_vd); - assert_eq!(follower_cf_vd, expected_cf_vd); - assert_eq!(follower_sf_vd, expected_sf_vd); + assert_eq!(sf_vd_leader, sf_vd_follower); + assert_eq!(sf_vd_leader, sf_vd_expected); } } diff --git a/crates/components/hmac-sha256/src/prf.rs b/crates/components/hmac-sha256/src/prf.rs index 5081ce6639..f928f48bdb 100644 --- a/crates/components/hmac-sha256/src/prf.rs +++ b/crates/components/hmac-sha256/src/prf.rs @@ -1,98 +1,41 @@ -use std::{ - fmt::Debug, - sync::{Arc, OnceLock}, +use crate::{ + hmac::{IPAD, OPAD}, + Mode, PrfError, PrfOutput, }; - -use hmac_sha256_circuits::{build_session_keys, build_verify_data}; -use mpz_circuits::Circuit; -use mpz_common::cpu::CpuBackend; +use mpz_circuits::{circuits::xor, Circuit, CircuitBuilder}; +use mpz_hash::sha256::Sha256; use mpz_vm_core::{ memory::{ - binary::{Binary, U32, U8}, - Array, + binary::{Binary, U8}, + Array, MemoryExt, StaticSize, Vector, ViewExt, }, - prelude::*, - Call, Vm, + Call, CallableExt, Vm, }; +use std::{fmt::Debug, sync::Arc}; use tracing::instrument; -use crate::{PrfConfig, PrfError, PrfOutput, Role, SessionKeys, CF_LABEL, SF_LABEL}; - -pub(crate) struct Circuits { - session_keys: Arc, - client_vd: Arc, - server_vd: Arc, -} - -impl Circuits { - pub(crate) async fn get() -> &'static Self { - static CIRCUITS: OnceLock = OnceLock::new(); - if let Some(circuits) = CIRCUITS.get() { - return circuits; - } - - let (session_keys, client_vd, server_vd) = futures::join!( - CpuBackend::blocking(build_session_keys), - CpuBackend::blocking(|| build_verify_data(CF_LABEL)), - CpuBackend::blocking(|| build_verify_data(SF_LABEL)), - ); +mod state; +use state::State; - _ = CIRCUITS.set(Circuits { - session_keys, - client_vd, - server_vd, - }); - - CIRCUITS.get().unwrap() - } -} +mod function; +use function::Prf; +/// MPC PRF for computing TLS 1.2 HMAC-SHA256 PRF. #[derive(Debug)] -pub(crate) enum State { - Initialized, - SessionKeys { - client_random: Array, - server_random: Array, - cf_hash: Array, - sf_hash: Array, - }, - ClientFinished { - cf_hash: Array, - sf_hash: Array, - }, - ServerFinished { - sf_hash: Array, - }, - Complete, - Error, -} - -impl State { - fn take(&mut self) -> State { - std::mem::replace(self, State::Error) - } -} - -/// MPC PRF for computing TLS HMAC-SHA256 PRF. pub struct MpcPrf { - config: PrfConfig, + mode: Mode, state: State, } -impl Debug for MpcPrf { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("MpcPrf") - .field("config", &self.config) - .field("state", &self.state) - .finish() - } -} - impl MpcPrf { /// Creates a new instance of the PRF. - pub fn new(config: PrfConfig) -> MpcPrf { - MpcPrf { - config, + /// + /// # Arguments + /// + /// `mode` - The PRF mode. + pub fn new(mode: Mode) -> MpcPrf { + Self { + mode, state: State::Initialized, } } @@ -113,122 +56,58 @@ impl MpcPrf { return Err(PrfError::state("PRF not in initialized state")); }; - let circuits = futures::executor::block_on(Circuits::get()); - - let client_random = vm.alloc().map_err(PrfError::vm)?; - let server_random = vm.alloc().map_err(PrfError::vm)?; - - // The client random is kept private so that the handshake transcript - // hashes do not leak information about the server's identity. - match self.config.role { - Role::Leader => vm.mark_private(client_random), - Role::Follower => vm.mark_blind(client_random), - } - .map_err(PrfError::vm)?; - - vm.mark_public(server_random).map_err(PrfError::vm)?; - - #[allow(clippy::type_complexity)] - let ( - client_write_key, - server_write_key, - client_iv, - server_iv, - ms_outer_hash_state, - ms_inner_hash_state, - ): ( - Array, - Array, - Array, - Array, - Array, - Array, - ) = vm - .call( - Call::builder(circuits.session_keys.clone()) - .arg(pms) - .arg(client_random) - .arg(server_random) - .build() - .map_err(PrfError::vm)?, - ) - .map_err(PrfError::vm)?; - - let keys = SessionKeys { - client_write_key, - server_write_key, - client_iv, - server_iv, - }; - - let cf_hash = vm.alloc().map_err(PrfError::vm)?; - vm.mark_public(cf_hash).map_err(PrfError::vm)?; - - let cf_vd = vm - .call( - Call::builder(circuits.client_vd.clone()) - .arg(ms_outer_hash_state) - .arg(ms_inner_hash_state) - .arg(cf_hash) - .build() - .map_err(PrfError::vm)?, - ) - .map_err(PrfError::vm)?; - - let sf_hash = vm.alloc().map_err(PrfError::vm)?; - vm.mark_public(sf_hash).map_err(PrfError::vm)?; - - let sf_vd = vm - .call( - Call::builder(circuits.server_vd.clone()) - .arg(ms_outer_hash_state) - .arg(ms_inner_hash_state) - .arg(sf_hash) - .build() - .map_err(PrfError::vm)?, - ) - .map_err(PrfError::vm)?; + let mode = self.mode; + let pms: Vector = pms.into(); + + let outer_partial_pms = compute_partial(vm, pms, OPAD)?; + let inner_partial_pms = compute_partial(vm, pms, IPAD)?; + + let master_secret = + Prf::alloc_master_secret(mode, vm, outer_partial_pms, inner_partial_pms)?; + let ms = master_secret.output(); + let ms = merge_outputs(vm, ms, 48)?; + + let outer_partial_ms = compute_partial(vm, ms, OPAD)?; + let inner_partial_ms = compute_partial(vm, ms, IPAD)?; + + let key_expansion = + Prf::alloc_key_expansion(mode, vm, outer_partial_ms.clone(), inner_partial_ms.clone())?; + let client_finished = Prf::alloc_client_finished( + mode, + vm, + outer_partial_ms.clone(), + inner_partial_ms.clone(), + )?; + let server_finished = Prf::alloc_server_finished( + mode, + vm, + outer_partial_ms.clone(), + inner_partial_ms.clone(), + )?; self.state = State::SessionKeys { - client_random, - server_random, - cf_hash, - sf_hash, + client_random: None, + master_secret, + key_expansion, + client_finished, + server_finished, }; - Ok(PrfOutput { keys, cf_vd, sf_vd }) + self.state.prf_output(vm) } /// Sets the client random. /// - /// Only the leader can provide the client random. - /// /// # Arguments /// - /// * `vm` - Virtual machine. - /// * `client_random` - The client random. + /// * `random` - The client random. #[instrument(level = "debug", skip_all, err)] - pub fn set_client_random( - &mut self, - vm: &mut dyn Vm, - random: Option<[u8; 32]>, - ) -> Result<(), PrfError> { - let State::SessionKeys { client_random, .. } = &self.state else { + pub fn set_client_random(&mut self, random: [u8; 32]) -> Result<(), PrfError> { + let State::SessionKeys { client_random, .. } = &mut self.state else { return Err(PrfError::state("PRF not set up")); }; - if self.config.role == Role::Leader { - let Some(random) = random else { - return Err(PrfError::role("leader must provide client random")); - }; - - vm.assign(*client_random, random).map_err(PrfError::vm)?; - } else if random.is_some() { - return Err(PrfError::role("only leader can set client random")); - } - - vm.commit(*client_random).map_err(PrfError::vm)?; - + *client_random = Some(random); Ok(()) } @@ -236,28 +115,29 @@ impl MpcPrf { /// /// # Arguments /// - /// * `vm` - Virtual machine. - /// * `server_random` - The server random. + /// * `random` - The server random. #[instrument(level = "debug", skip_all, err)] - pub fn set_server_random( - &mut self, - vm: &mut dyn Vm, - random: [u8; 32], - ) -> Result<(), PrfError> { + pub fn set_server_random(&mut self, random: [u8; 32]) -> Result<(), PrfError> { let State::SessionKeys { - server_random, - cf_hash, - sf_hash, + client_random, + master_secret, + key_expansion, .. - } = self.state.take() + } = &mut self.state else { return Err(PrfError::state("PRF not set up")); }; - vm.assign(server_random, random).map_err(PrfError::vm)?; - vm.commit(server_random).map_err(PrfError::vm)?; + let client_random = client_random.expect("Client random should have been set by now"); + let server_random = random; + + let mut seed_ms = client_random.to_vec(); + seed_ms.extend_from_slice(&server_random); + master_secret.set_start_seed(seed_ms); - self.state = State::ClientFinished { cf_hash, sf_hash }; + let mut seed_ke = server_random.to_vec(); + seed_ke.extend_from_slice(&client_random); + key_expansion.set_start_seed(seed_ke); Ok(()) } @@ -266,22 +146,18 @@ impl MpcPrf { /// /// # Arguments /// - /// * `vm` - Virtual machine. /// * `handshake_hash` - The handshake transcript hash. #[instrument(level = "debug", skip_all, err)] - pub fn set_cf_hash( - &mut self, - vm: &mut dyn Vm, - handshake_hash: [u8; 32], - ) -> Result<(), PrfError> { - let State::ClientFinished { cf_hash, sf_hash } = self.state.take() else { + pub fn set_cf_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), PrfError> { + let State::ClientFinished { + client_finished, .. + } = &mut self.state + else { return Err(PrfError::state("PRF not in client finished state")); }; - vm.assign(cf_hash, handshake_hash).map_err(PrfError::vm)?; - vm.commit(cf_hash).map_err(PrfError::vm)?; - - self.state = State::ServerFinished { sf_hash }; + let seed_cf = handshake_hash.to_vec(); + client_finished.set_start_seed(seed_cf); Ok(()) } @@ -290,23 +166,242 @@ impl MpcPrf { /// /// # Arguments /// - /// * `vm` - Virtual machine. /// * `handshake_hash` - The handshake transcript hash. #[instrument(level = "debug", skip_all, err)] - pub fn set_sf_hash( - &mut self, - vm: &mut dyn Vm, - handshake_hash: [u8; 32], - ) -> Result<(), PrfError> { - let State::ServerFinished { sf_hash } = self.state.take() else { + pub fn set_sf_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), PrfError> { + let State::ServerFinished { server_finished } = &mut self.state else { return Err(PrfError::state("PRF not in server finished state")); }; - vm.assign(sf_hash, handshake_hash).map_err(PrfError::vm)?; - vm.commit(sf_hash).map_err(PrfError::vm)?; + let seed_sf = handshake_hash.to_vec(); + server_finished.set_start_seed(seed_sf); + + Ok(()) + } + + /// Returns if the PRF needs to be flushed. + pub fn wants_flush(&self) -> bool { + match &self.state { + State::Initialized => false, + State::SessionKeys { + master_secret, + key_expansion, + .. + } => master_secret.wants_flush() || key_expansion.wants_flush(), + State::ClientFinished { + client_finished, .. + } => client_finished.wants_flush(), + State::ServerFinished { server_finished } => server_finished.wants_flush(), + State::Complete => false, + State::Error => false, + } + } - self.state = State::Complete; + /// Flushes the PRF. + pub fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { + self.state = match self.state.take() { + State::SessionKeys { + client_random, + mut master_secret, + mut key_expansion, + client_finished, + server_finished, + } => { + master_secret.flush(vm)?; + key_expansion.flush(vm)?; + + if !master_secret.wants_flush() && !key_expansion.wants_flush() { + State::ClientFinished { + client_finished, + server_finished, + } + } else { + State::SessionKeys { + client_random, + master_secret, + key_expansion, + client_finished, + server_finished, + } + } + } + State::ClientFinished { + mut client_finished, + server_finished, + } => { + client_finished.flush(vm)?; + + if !client_finished.wants_flush() { + State::ServerFinished { server_finished } + } else { + State::ClientFinished { + client_finished, + server_finished, + } + } + } + State::ServerFinished { + mut server_finished, + } => { + server_finished.flush(vm)?; + + if !server_finished.wants_flush() { + State::Complete + } else { + State::ServerFinished { server_finished } + } + } + other => other, + }; Ok(()) } } + +/// Depending on the provided `mask` computes and returns `outer_partial` or +/// `inner_partial` for HMAC-SHA256. +/// +/// # Arguments +/// +/// * `vm` - Virtual machine. +/// * `key` - Key to pad and xor. +/// * `mask`- Mask used for padding. +fn compute_partial( + vm: &mut dyn Vm, + key: Vector, + mask: [u8; 64], +) -> Result { + let xor = Arc::new(xor(8 * 64)); + + let additional_len = 64 - key.len(); + let padding = vec![0_u8; additional_len]; + + let padding_ref: Vector = vm.alloc_vec(additional_len).map_err(PrfError::vm)?; + vm.mark_public(padding_ref).map_err(PrfError::vm)?; + vm.assign(padding_ref, padding).map_err(PrfError::vm)?; + vm.commit(padding_ref).map_err(PrfError::vm)?; + + let mask_ref: Array = vm.alloc().map_err(PrfError::vm)?; + vm.mark_public(mask_ref).map_err(PrfError::vm)?; + vm.assign(mask_ref, mask).map_err(PrfError::vm)?; + vm.commit(mask_ref).map_err(PrfError::vm)?; + + let xor = Call::builder(xor) + .arg(key) + .arg(padding_ref) + .arg(mask_ref) + .build() + .map_err(PrfError::vm)?; + let key_padded: Vector = vm.call(xor).map_err(PrfError::vm)?; + + let mut sha = Sha256::new_with_init(vm)?; + sha.update(&key_padded); + sha.compress(vm)?; + Ok(sha) +} + +fn merge_outputs( + vm: &mut dyn Vm, + inputs: Vec>, + output_bytes: usize, +) -> Result, PrfError> { + assert!(output_bytes <= 32 * inputs.len()); + + let bits = Array::::SIZE * inputs.len(); + let circ = gen_merge_circ(bits); + + let mut builder = Call::builder(circ); + for &input in inputs.iter() { + builder = builder.arg(input); + } + let call = builder.build().map_err(PrfError::vm)?; + + let mut output: Vector = vm.call(call).map_err(PrfError::vm)?; + output.truncate(output_bytes); + Ok(output) +} + +fn gen_merge_circ(size: usize) -> Arc { + let mut builder = CircuitBuilder::new(); + let inputs = (0..size).map(|_| builder.add_input()).collect::>(); + + for input in inputs.chunks_exact(8) { + for byte in input.chunks_exact(8) { + for &feed in byte.iter() { + let output = builder.add_id_gate(feed); + builder.add_output(output); + } + } + } + + Arc::new(builder.build().expect("merge circuit is valid")) +} + +#[cfg(test)] +mod tests { + use crate::{prf::merge_outputs, test_utils::mock_vm}; + use mpz_common::context::test_st_context; + use mpz_vm_core::{ + memory::{binary::U8, Array, MemoryExt, ViewExt}, + Execute, + }; + + #[tokio::test] + async fn test_merge_outputs() { + let (mut ctx_a, mut ctx_b) = test_st_context(8); + let (mut leader, mut follower) = mock_vm(); + + let input1: [u8; 32] = std::array::from_fn(|i| i as u8); + let input2: [u8; 32] = std::array::from_fn(|i| i as u8 + 32); + + let mut expected = input1.to_vec(); + expected.extend_from_slice(&input2); + expected.truncate(48); + + // leader + let input1_leader: Array = leader.alloc().unwrap(); + let input2_leader: Array = leader.alloc().unwrap(); + + leader.mark_public(input1_leader).unwrap(); + leader.mark_public(input2_leader).unwrap(); + + leader.assign(input1_leader, input1).unwrap(); + leader.assign(input2_leader, input2).unwrap(); + + leader.commit(input1_leader).unwrap(); + leader.commit(input2_leader).unwrap(); + + let merged_leader = + merge_outputs(&mut leader, vec![input1_leader, input2_leader], 48).unwrap(); + let mut merged_leader = leader.decode(merged_leader).unwrap(); + + // follower + let input1_follower: Array = follower.alloc().unwrap(); + let input2_follower: Array = follower.alloc().unwrap(); + + follower.mark_public(input1_follower).unwrap(); + follower.mark_public(input2_follower).unwrap(); + + follower.assign(input1_follower, input1).unwrap(); + follower.assign(input2_follower, input2).unwrap(); + + follower.commit(input1_follower).unwrap(); + follower.commit(input2_follower).unwrap(); + + let merged_follower = + merge_outputs(&mut follower, vec![input1_follower, input2_follower], 48).unwrap(); + let mut merged_follower = follower.decode(merged_follower).unwrap(); + + tokio::try_join!( + leader.execute_all(&mut ctx_a), + follower.execute_all(&mut ctx_b) + ) + .unwrap(); + + let merged_leader = merged_leader.try_recv().unwrap().unwrap(); + let merged_follower = merged_follower.try_recv().unwrap().unwrap(); + + assert_eq!(merged_leader, merged_follower); + assert_eq!(merged_leader, expected); + } +} diff --git a/crates/components/hmac-sha256/src/prf/function.rs b/crates/components/hmac-sha256/src/prf/function.rs new file mode 100644 index 0000000000..e1e932a02f --- /dev/null +++ b/crates/components/hmac-sha256/src/prf/function.rs @@ -0,0 +1,257 @@ +//! Provides [`Prf`], for computing the TLS 1.2 PRF. + +use crate::{Mode, PrfError}; +use mpz_hash::sha256::Sha256; +use mpz_vm_core::{ + memory::{ + binary::{Binary, U8}, + Array, + }, + Vm, +}; + +mod normal; +mod reduced; + +#[derive(Debug)] +pub(crate) enum Prf { + Reduced(reduced::PrfFunction), + Normal(normal::PrfFunction), +} + +impl Prf { + pub(crate) fn alloc_master_secret( + mode: Mode, + vm: &mut dyn Vm, + outer_partial: Sha256, + inner_partial: Sha256, + ) -> Result { + let prf = match mode { + Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_master_secret( + vm, + outer_partial, + inner_partial, + )?), + Mode::Normal => Self::Normal(normal::PrfFunction::alloc_master_secret( + vm, + outer_partial, + inner_partial, + )?), + }; + Ok(prf) + } + + pub(crate) fn alloc_key_expansion( + mode: Mode, + vm: &mut dyn Vm, + outer_partial: Sha256, + inner_partial: Sha256, + ) -> Result { + let prf = match mode { + Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_key_expansion( + vm, + outer_partial, + inner_partial, + )?), + Mode::Normal => Self::Normal(normal::PrfFunction::alloc_key_expansion( + vm, + outer_partial, + inner_partial, + )?), + }; + Ok(prf) + } + + pub(crate) fn alloc_client_finished( + config: Mode, + vm: &mut dyn Vm, + outer_partial: Sha256, + inner_partial: Sha256, + ) -> Result { + let prf = match config { + Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_client_finished( + vm, + outer_partial, + inner_partial, + )?), + Mode::Normal => Self::Normal(normal::PrfFunction::alloc_client_finished( + vm, + outer_partial, + inner_partial, + )?), + }; + Ok(prf) + } + + pub(crate) fn alloc_server_finished( + config: Mode, + vm: &mut dyn Vm, + outer_partial: Sha256, + inner_partial: Sha256, + ) -> Result { + let prf = match config { + Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_server_finished( + vm, + outer_partial, + inner_partial, + )?), + Mode::Normal => Self::Normal(normal::PrfFunction::alloc_server_finished( + vm, + outer_partial, + inner_partial, + )?), + }; + Ok(prf) + } + + pub(crate) fn wants_flush(&self) -> bool { + match self { + Prf::Reduced(prf) => prf.wants_flush(), + Prf::Normal(prf) => prf.wants_flush(), + } + } + + pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { + match self { + Prf::Reduced(prf) => prf.flush(vm), + Prf::Normal(prf) => prf.flush(vm), + } + } + + pub(crate) fn set_start_seed(&mut self, seed: Vec) { + match self { + Prf::Reduced(prf) => prf.set_start_seed(seed), + Prf::Normal(prf) => prf.set_start_seed(seed), + } + } + + pub(crate) fn output(&self) -> Vec> { + match self { + Prf::Reduced(prf) => prf.output(), + Prf::Normal(prf) => prf.output(), + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + prf::{compute_partial, function::Prf}, + test_utils::{mock_vm, phash}, + Mode, + }; + use mpz_common::context::test_st_context; + use mpz_vm_core::{ + memory::{binary::U8, Array, MemoryExt, ViewExt}, + Execute, + }; + use rand::{rngs::ThreadRng, Rng}; + + const IPAD: [u8; 64] = [0x36; 64]; + const OPAD: [u8; 64] = [0x5c; 64]; + + #[tokio::test] + async fn test_phash_reduced() { + let mode = Mode::Reduced; + test_phash(mode).await; + } + + #[tokio::test] + async fn test_phash_normal() { + let mode = Mode::Normal; + test_phash(mode).await; + } + + async fn test_phash(mode: Mode) { + let mut rng = ThreadRng::default(); + + let (mut ctx_a, mut ctx_b) = test_st_context(8); + let (mut leader, mut follower) = mock_vm(); + + let key: [u8; 32] = rng.random(); + let start_seed: Vec = vec![42; 64]; + + let mut label_seed = b"master secret".to_vec(); + label_seed.extend_from_slice(&start_seed); + let iterations = 2; + + let leader_key: Array = leader.alloc().unwrap(); + leader.mark_public(leader_key).unwrap(); + leader.assign(leader_key, key).unwrap(); + leader.commit(leader_key).unwrap(); + + let outer_partial_leader = compute_partial(&mut leader, leader_key.into(), OPAD).unwrap(); + let inner_partial_leader = compute_partial(&mut leader, leader_key.into(), IPAD).unwrap(); + + let mut prf_leader = Prf::alloc_master_secret( + mode, + &mut leader, + outer_partial_leader, + inner_partial_leader, + ) + .unwrap(); + prf_leader.set_start_seed(start_seed.clone()); + + let mut prf_out_leader = vec![]; + for p in prf_leader.output() { + let p_out = leader.decode(p).unwrap(); + prf_out_leader.push(p_out) + } + + let follower_key: Array = follower.alloc().unwrap(); + follower.mark_public(follower_key).unwrap(); + follower.assign(follower_key, key).unwrap(); + follower.commit(follower_key).unwrap(); + + let outer_partial_follower = + compute_partial(&mut follower, follower_key.into(), OPAD).unwrap(); + let inner_partial_follower = + compute_partial(&mut follower, follower_key.into(), IPAD).unwrap(); + + let mut prf_follower = Prf::alloc_master_secret( + mode, + &mut follower, + outer_partial_follower, + inner_partial_follower, + ) + .unwrap(); + prf_follower.set_start_seed(start_seed.clone()); + + let mut prf_out_follower = vec![]; + for p in prf_follower.output() { + let p_out = follower.decode(p).unwrap(); + prf_out_follower.push(p_out) + } + + while prf_leader.wants_flush() || prf_follower.wants_flush() { + tokio::try_join!( + async { + prf_leader.flush(&mut leader).unwrap(); + leader.execute_all(&mut ctx_a).await + }, + async { + prf_follower.flush(&mut follower).unwrap(); + follower.execute_all(&mut ctx_b).await + } + ) + .unwrap(); + } + + assert_eq!(prf_out_leader.len(), 2); + assert_eq!(prf_out_leader.len(), prf_out_follower.len()); + + let prf_result_leader: Vec = prf_out_leader + .iter_mut() + .flat_map(|p| p.try_recv().unwrap().unwrap()) + .collect(); + let prf_result_follower: Vec = prf_out_follower + .iter_mut() + .flat_map(|p| p.try_recv().unwrap().unwrap()) + .collect(); + + let expected = phash(key.to_vec(), &label_seed, iterations); + + assert_eq!(prf_result_leader, prf_result_follower); + assert_eq!(prf_result_leader, expected) + } +} diff --git a/crates/components/hmac-sha256/src/prf/function/normal.rs b/crates/components/hmac-sha256/src/prf/function/normal.rs new file mode 100644 index 0000000000..ec931f24ef --- /dev/null +++ b/crates/components/hmac-sha256/src/prf/function/normal.rs @@ -0,0 +1,174 @@ +//! Computes the whole PRF in MPC. + +use crate::{hmac::hmac_sha256, PrfError}; +use mpz_hash::sha256::Sha256; +use mpz_vm_core::{ + memory::{ + binary::{Binary, U8}, + Array, MemoryExt, Vector, ViewExt, + }, + Vm, +}; + +#[derive(Debug)] +pub(crate) struct PrfFunction { + // The label, e.g. "master secret". + label: &'static [u8], + state: State, + // The start seed and the label, e.g. client_random + server_random + "master_secret". + start_seed_label: Option>, + a: Vec, + p: Vec, +} + +impl PrfFunction { + const MS_LABEL: &[u8] = b"master secret"; + const KEY_LABEL: &[u8] = b"key expansion"; + const CF_LABEL: &[u8] = b"client finished"; + const SF_LABEL: &[u8] = b"server finished"; + + pub(crate) fn alloc_master_secret( + vm: &mut dyn Vm, + outer_partial: Sha256, + inner_partial: Sha256, + ) -> Result { + Self::alloc(vm, Self::MS_LABEL, outer_partial, inner_partial, 48, 64) + } + + pub(crate) fn alloc_key_expansion( + vm: &mut dyn Vm, + outer_partial: Sha256, + inner_partial: Sha256, + ) -> Result { + Self::alloc(vm, Self::KEY_LABEL, outer_partial, inner_partial, 40, 64) + } + + pub(crate) fn alloc_client_finished( + vm: &mut dyn Vm, + outer_partial: Sha256, + inner_partial: Sha256, + ) -> Result { + Self::alloc(vm, Self::CF_LABEL, outer_partial, inner_partial, 12, 32) + } + + pub(crate) fn alloc_server_finished( + vm: &mut dyn Vm, + outer_partial: Sha256, + inner_partial: Sha256, + ) -> Result { + Self::alloc(vm, Self::SF_LABEL, outer_partial, inner_partial, 12, 32) + } + + pub(crate) fn wants_flush(&self) -> bool { + let is_computing = match self.state { + State::Computing => true, + State::Finished => false, + }; + is_computing && self.start_seed_label.is_some() + } + + pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { + if let State::Computing = self.state { + let a = self.a.first().expect("prf should be allocated"); + let msg = *a.msg.first().expect("message for prf should be present"); + + let msg_value = self + .start_seed_label + .clone() + .expect("Start seed should have been set"); + + vm.assign(msg, msg_value).map_err(PrfError::vm)?; + vm.commit(msg).map_err(PrfError::vm)?; + + self.state = State::Finished; + } + Ok(()) + } + + pub(crate) fn set_start_seed(&mut self, seed: Vec) { + let mut start_seed_label = self.label.to_vec(); + start_seed_label.extend_from_slice(&seed); + + self.start_seed_label = Some(start_seed_label); + } + + pub(crate) fn output(&self) -> Vec> { + self.p.iter().map(|p| p.output).collect() + } + + fn alloc( + vm: &mut dyn Vm, + label: &'static [u8], + outer_partial: Sha256, + inner_partial: Sha256, + output_len: usize, + seed_len: usize, + ) -> Result { + let mut prf = Self { + label, + state: State::Computing, + start_seed_label: None, + a: vec![], + p: vec![], + }; + + assert!(output_len > 0, "cannot compute 0 bytes for prf"); + + let iterations = output_len.div_ceil(32); + + let msg_len_a = label.len() + seed_len; + let seed_label_ref: Vector = vm.alloc_vec(msg_len_a).map_err(PrfError::vm)?; + vm.mark_public(seed_label_ref).map_err(PrfError::vm)?; + + let mut msg_a = seed_label_ref; + for _ in 0..iterations { + let a = PHash::alloc(vm, outer_partial.clone(), inner_partial.clone(), &[msg_a])?; + msg_a = Vector::::from(a.output); + prf.a.push(a); + + let p = PHash::alloc( + vm, + outer_partial.clone(), + inner_partial.clone(), + &[msg_a, seed_label_ref], + )?; + prf.p.push(p); + } + + Ok(prf) + } +} + +#[derive(Debug, Clone, Copy)] +enum State { + Computing, + Finished, +} + +#[derive(Debug, Clone)] +struct PHash { + msg: Vec>, + output: Array, +} + +impl PHash { + fn alloc( + vm: &mut dyn Vm, + outer_partial: Sha256, + inner_partial: Sha256, + msg: &[Vector], + ) -> Result { + let mut inner_local = inner_partial; + + msg.iter().for_each(|m| inner_local.update(m)); + inner_local.compress(vm)?; + let inner_local = inner_local.finalize(vm)?; + + let output = hmac_sha256(vm, outer_partial, inner_local)?; + let p_hash = Self { + msg: msg.to_vec(), + output, + }; + Ok(p_hash) + } +} diff --git a/crates/components/hmac-sha256/src/prf/function/reduced.rs b/crates/components/hmac-sha256/src/prf/function/reduced.rs new file mode 100644 index 0000000000..2120a82937 --- /dev/null +++ b/crates/components/hmac-sha256/src/prf/function/reduced.rs @@ -0,0 +1,247 @@ +//! Computes some hashes of the PRF locally. + +use std::collections::VecDeque; + +use crate::{hmac::hmac_sha256, sha256, state_to_bytes, PrfError}; +use mpz_core::bitvec::BitVec; +use mpz_hash::sha256::Sha256; +use mpz_vm_core::{ + memory::{ + binary::{Binary, U8}, + Array, DecodeFutureTyped, MemoryExt, ViewExt, + }, + Vm, +}; + +#[derive(Debug)] +pub(crate) struct PrfFunction { + // The label, e.g. "master secret". + label: &'static [u8], + // The start seed and the label, e.g. client_random + server_random + "master_secret". + start_seed_label: Option>, + iterations: usize, + state: PrfState, + a: VecDeque, + p: VecDeque, +} + +#[derive(Debug)] +enum PrfState { + InnerPartial { + inner_partial: DecodeFutureTyped, + }, + ComputeA { + iter: usize, + inner_partial: [u32; 8], + msg: Vec, + }, + ComputeP { + iter: usize, + inner_partial: [u32; 8], + a_output: DecodeFutureTyped, + }, + FinishLastP, + Done, +} + +impl PrfFunction { + const MS_LABEL: &[u8] = b"master secret"; + const KEY_LABEL: &[u8] = b"key expansion"; + const CF_LABEL: &[u8] = b"client finished"; + const SF_LABEL: &[u8] = b"server finished"; + + pub(crate) fn alloc_master_secret( + vm: &mut dyn Vm, + outer_partial: Sha256, + inner_partial: Sha256, + ) -> Result { + Self::alloc(vm, Self::MS_LABEL, outer_partial, inner_partial, 48) + } + + pub(crate) fn alloc_key_expansion( + vm: &mut dyn Vm, + outer_partial: Sha256, + inner_partial: Sha256, + ) -> Result { + Self::alloc(vm, Self::KEY_LABEL, outer_partial, inner_partial, 40) + } + + pub(crate) fn alloc_client_finished( + vm: &mut dyn Vm, + outer_partial: Sha256, + inner_partial: Sha256, + ) -> Result { + Self::alloc(vm, Self::CF_LABEL, outer_partial, inner_partial, 12) + } + + pub(crate) fn alloc_server_finished( + vm: &mut dyn Vm, + outer_partial: Sha256, + inner_partial: Sha256, + ) -> Result { + Self::alloc(vm, Self::SF_LABEL, outer_partial, inner_partial, 12) + } + + pub(crate) fn wants_flush(&self) -> bool { + !matches!(self.state, PrfState::Done) && self.start_seed_label.is_some() + } + + pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { + match &mut self.state { + PrfState::InnerPartial { inner_partial } => { + let Some(inner_partial) = inner_partial.try_recv().map_err(PrfError::vm)? else { + return Ok(()); + }; + + self.state = PrfState::ComputeA { + iter: 1, + inner_partial, + msg: self + .start_seed_label + .clone() + .expect("Start seed should have been set"), + }; + self.flush(vm)?; + } + PrfState::ComputeA { + iter, + inner_partial, + msg, + } => { + let a = self.a.pop_front().expect("Prf AHash should be present"); + assign_inner_local(vm, a.inner_local, *inner_partial, msg)?; + + self.state = PrfState::ComputeP { + iter: *iter, + inner_partial: *inner_partial, + a_output: a.output, + }; + } + PrfState::ComputeP { + iter, + inner_partial, + a_output, + } => { + let Some(output) = a_output.try_recv().map_err(PrfError::vm)? else { + return Ok(()); + }; + let p = self.p.pop_front().expect("Prf PHash should be present"); + + let mut msg = output.to_vec(); + msg.extend_from_slice( + self.start_seed_label + .as_ref() + .expect("Start seed should have been set"), + ); + + assign_inner_local(vm, p.inner_local, *inner_partial, &msg)?; + + if *iter == self.iterations { + self.state = PrfState::FinishLastP; + } else { + self.state = PrfState::ComputeA { + iter: *iter + 1, + inner_partial: *inner_partial, + msg: output.to_vec(), + } + }; + } + PrfState::FinishLastP => self.state = PrfState::Done, + _ => (), + } + + Ok(()) + } + + pub(crate) fn set_start_seed(&mut self, seed: Vec) { + let mut start_seed_label = self.label.to_vec(); + start_seed_label.extend_from_slice(&seed); + + self.start_seed_label = Some(start_seed_label); + } + + pub(crate) fn output(&self) -> Vec> { + self.p.iter().map(|p| p.output).collect() + } + + fn alloc( + vm: &mut dyn Vm, + label: &'static [u8], + outer_partial: Sha256, + inner_partial: Sha256, + len: usize, + ) -> Result { + assert!(len > 0, "cannot compute 0 bytes for prf"); + + let iterations = len.div_ceil(32); + + let (inner_partial, _) = inner_partial + .state() + .expect("state should be set for inner_partial"); + let inner_partial = vm.decode(inner_partial).map_err(PrfError::vm)?; + + let mut prf = Self { + label, + start_seed_label: None, + iterations, + state: PrfState::InnerPartial { inner_partial }, + a: VecDeque::new(), + p: VecDeque::new(), + }; + + for _ in 0..iterations { + // setup A[i] + let inner_local: Array = vm.alloc().map_err(PrfError::vm)?; + let output = hmac_sha256(vm, outer_partial.clone(), inner_local)?; + + let output = vm.decode(output).map_err(PrfError::vm)?; + let a_hash = AHash { + inner_local, + output, + }; + + prf.a.push_front(a_hash); + + // setup P[i] + let inner_local: Array = vm.alloc().map_err(PrfError::vm)?; + let output = hmac_sha256(vm, outer_partial.clone(), inner_local)?; + let p_hash = PHash { + inner_local, + output, + }; + prf.p.push_front(p_hash); + } + + Ok(prf) + } +} + +fn assign_inner_local( + vm: &mut dyn Vm, + inner_local: Array, + inner_partial: [u32; 8], + msg: &[u8], +) -> Result<(), PrfError> { + let inner_local_value = sha256(inner_partial, 64, msg); + + vm.mark_public(inner_local).map_err(PrfError::vm)?; + vm.assign(inner_local, state_to_bytes(inner_local_value)) + .map_err(PrfError::vm)?; + vm.commit(inner_local).map_err(PrfError::vm)?; + + Ok(()) +} + +/// Like PHash but stores the output as the decoding future because in the +/// reduced Prf we need to decode this output. +#[derive(Debug)] +struct AHash { + inner_local: Array, + output: DecodeFutureTyped, +} + +#[derive(Debug, Clone, Copy)] +struct PHash { + inner_local: Array, + output: Array, +} diff --git a/crates/components/hmac-sha256/src/prf/state.rs b/crates/components/hmac-sha256/src/prf/state.rs new file mode 100644 index 0000000000..6299c6d07d --- /dev/null +++ b/crates/components/hmac-sha256/src/prf/state.rs @@ -0,0 +1,103 @@ +use crate::{ + prf::{function::Prf, merge_outputs}, + PrfError, PrfOutput, SessionKeys, +}; +use mpz_vm_core::{ + memory::{ + binary::{Binary, U8}, + Array, FromRaw, ToRaw, + }, + Vm, +}; + +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub(crate) enum State { + Initialized, + SessionKeys { + client_random: Option<[u8; 32]>, + master_secret: Prf, + key_expansion: Prf, + client_finished: Prf, + server_finished: Prf, + }, + ClientFinished { + client_finished: Prf, + server_finished: Prf, + }, + ServerFinished { + server_finished: Prf, + }, + Complete, + Error, +} + +impl State { + pub(crate) fn take(&mut self) -> State { + std::mem::replace(self, State::Error) + } + + pub(crate) fn prf_output(&self, vm: &mut dyn Vm) -> Result { + let State::SessionKeys { + key_expansion, + client_finished, + server_finished, + .. + } = self + else { + return Err(PrfError::state( + "Prf output can only be computed while in \"SessionKeys\" state", + )); + }; + + let keys = get_session_keys(key_expansion.output(), vm)?; + let cf_vd = get_client_finished_vd(client_finished.output(), vm)?; + let sf_vd = get_server_finished_vd(server_finished.output(), vm)?; + + let output = PrfOutput { keys, cf_vd, sf_vd }; + + Ok(output) + } +} + +fn get_session_keys( + output: Vec>, + vm: &mut dyn Vm, +) -> Result { + let mut keys = merge_outputs(vm, output, 40)?; + debug_assert!(keys.len() == 40, "session keys len should be 40"); + + let server_iv = Array::::try_from(keys.split_off(36)).unwrap(); + let client_iv = Array::::try_from(keys.split_off(32)).unwrap(); + let server_write_key = Array::::try_from(keys.split_off(16)).unwrap(); + let client_write_key = Array::::try_from(keys).unwrap(); + + let session_keys = SessionKeys { + client_write_key, + server_write_key, + client_iv, + server_iv, + }; + + Ok(session_keys) +} + +fn get_client_finished_vd( + output: Vec>, + vm: &mut dyn Vm, +) -> Result, PrfError> { + let cf_vd = merge_outputs(vm, output, 12)?; + let cf_vd = as FromRaw>::from_raw(cf_vd.to_raw()); + + Ok(cf_vd) +} + +fn get_server_finished_vd( + output: Vec>, + vm: &mut dyn Vm, +) -> Result, PrfError> { + let sf_vd = merge_outputs(vm, output, 12)?; + let sf_vd = as FromRaw>::from_raw(sf_vd.to_raw()); + + Ok(sf_vd) +} diff --git a/crates/components/hmac-sha256/src/test_utils.rs b/crates/components/hmac-sha256/src/test_utils.rs new file mode 100644 index 0000000000..65f340f178 --- /dev/null +++ b/crates/components/hmac-sha256/src/test_utils.rs @@ -0,0 +1,261 @@ +use crate::{sha256, state_to_bytes}; +use mpz_garble::protocol::semihonest::{Evaluator, Garbler}; +use mpz_ot::ideal::cot::{ideal_cot, IdealCOTReceiver, IdealCOTSender}; +use mpz_vm_core::memory::correlated::Delta; +use rand::{rngs::StdRng, Rng, SeedableRng}; + +pub(crate) const SHA256_IV: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +pub(crate) fn mock_vm() -> (Garbler, Evaluator) { + let mut rng = StdRng::seed_from_u64(0); + let delta = Delta::random(&mut rng); + + let (cot_send, cot_recv) = ideal_cot(delta.into_inner()); + + let gen = Garbler::new(cot_send, [0u8; 16], delta); + let ev = Evaluator::new(cot_recv); + + (gen, ev) +} + +pub(crate) fn prf_ms(pms: [u8; 32], client_random: [u8; 32], server_random: [u8; 32]) -> [u8; 48] { + let mut label_start_seed = b"master secret".to_vec(); + label_start_seed.extend_from_slice(&client_random); + label_start_seed.extend_from_slice(&server_random); + + let ms = phash(pms.to_vec(), &label_start_seed, 2)[..48].to_vec(); + + ms.try_into().unwrap() +} + +pub(crate) fn prf_keys( + ms: [u8; 48], + client_random: [u8; 32], + server_random: [u8; 32], +) -> [Vec; 4] { + let mut label_start_seed = b"key expansion".to_vec(); + label_start_seed.extend_from_slice(&server_random); + label_start_seed.extend_from_slice(&client_random); + + let mut session_keys = phash(ms.to_vec(), &label_start_seed, 2)[..40].to_vec(); + + let server_iv = session_keys.split_off(36); + let client_iv = session_keys.split_off(32); + let server_write_key = session_keys.split_off(16); + let client_write_key = session_keys; + + [client_write_key, server_write_key, client_iv, server_iv] +} + +pub(crate) fn prf_cf_vd(ms: [u8; 48], hanshake_hash: [u8; 32]) -> Vec { + let mut label_start_seed = b"client finished".to_vec(); + label_start_seed.extend_from_slice(&hanshake_hash); + + phash(ms.to_vec(), &label_start_seed, 1)[..12].to_vec() +} + +pub(crate) fn prf_sf_vd(ms: [u8; 48], hanshake_hash: [u8; 32]) -> Vec { + let mut label_start_seed = b"server finished".to_vec(); + label_start_seed.extend_from_slice(&hanshake_hash); + + phash(ms.to_vec(), &label_start_seed, 1)[..12].to_vec() +} + +pub(crate) fn phash(key: Vec, seed: &[u8], iterations: usize) -> Vec { + // A() is defined as: + // + // A(0) = seed + // A(i) = HMAC_hash(secret, A(i-1)) + let mut a_cache: Vec<_> = Vec::with_capacity(iterations + 1); + a_cache.push(seed.to_vec()); + + for i in 0..iterations { + let a_i = hmac_sha256(key.clone(), &a_cache[i]); + a_cache.push(a_i.to_vec()); + } + + // HMAC_hash(secret, A(i) + seed) + let mut output: Vec<_> = Vec::with_capacity(iterations * 32); + for i in 0..iterations { + let mut a_i_seed = a_cache[i + 1].clone(); + a_i_seed.extend_from_slice(seed); + + let hash = hmac_sha256(key.clone(), &a_i_seed); + output.extend_from_slice(&hash); + } + + output +} + +pub(crate) fn hmac_sha256(key: Vec, msg: &[u8]) -> [u8; 32] { + let outer_partial = compute_outer_partial(key.clone()); + let inner_local = compute_inner_local(key, msg); + + let hmac = sha256(outer_partial, 64, &state_to_bytes(inner_local)); + state_to_bytes(hmac) +} + +pub(crate) fn compute_outer_partial(mut key: Vec) -> [u32; 8] { + assert!(key.len() <= 64); + + key.resize(64, 0_u8); + let key_padded: [u8; 64] = key + .into_iter() + .map(|b| b ^ 0x5c) + .collect::>() + .try_into() + .unwrap(); + + compress_256(SHA256_IV, &key_padded) +} + +pub(crate) fn compute_inner_local(mut key: Vec, msg: &[u8]) -> [u32; 8] { + assert!(key.len() <= 64); + + key.resize(64, 0_u8); + let key_padded: [u8; 64] = key + .into_iter() + .map(|b| b ^ 0x36) + .collect::>() + .try_into() + .unwrap(); + + let state = compress_256(SHA256_IV, &key_padded); + sha256(state, 64, msg) +} + +pub(crate) fn compress_256(mut state: [u32; 8], msg: &[u8]) -> [u32; 8] { + use sha2::{ + compress256, + digest::{ + block_buffer::{BlockBuffer, Eager}, + generic_array::typenum::U64, + }, + }; + + let mut buffer = BlockBuffer::::default(); + buffer.digest_blocks(msg, |b| compress256(&mut state, b)); + state +} + +// Borrowed from Rustls for testing +// https://github.com/rustls/rustls/blob/main/rustls/src/tls12/prf.rs +mod ring_prf { + use ring::{hmac, hmac::HMAC_SHA256}; + + fn concat_sign(key: &hmac::Key, a: &[u8], b: &[u8]) -> hmac::Tag { + let mut ctx = hmac::Context::with_key(key); + ctx.update(a); + ctx.update(b); + ctx.sign() + } + + fn p(out: &mut [u8], secret: &[u8], seed: &[u8]) { + let hmac_key = hmac::Key::new(HMAC_SHA256, secret); + + // A(1) + let mut current_a = hmac::sign(&hmac_key, seed); + let chunk_size = HMAC_SHA256.digest_algorithm().output_len(); + for chunk in out.chunks_mut(chunk_size) { + // P_hash[i] = HMAC_hash(secret, A(i) + seed) + let p_term = concat_sign(&hmac_key, current_a.as_ref(), seed); + chunk.copy_from_slice(&p_term.as_ref()[..chunk.len()]); + + // A(i+1) = HMAC_hash(secret, A(i)) + current_a = hmac::sign(&hmac_key, current_a.as_ref()); + } + } + + fn concat(a: &[u8], b: &[u8]) -> Vec { + let mut ret = Vec::new(); + ret.extend_from_slice(a); + ret.extend_from_slice(b); + ret + } + + pub(crate) fn prf(out: &mut [u8], secret: &[u8], label: &[u8], seed: &[u8]) { + let joined_seed = concat(label, seed); + p(out, secret, &joined_seed); + } +} + +#[test] +fn test_prf_reference_ms() { + use ring_prf::prf as prf_ref; + + let mut rng = StdRng::from_seed([1; 32]); + + let pms: [u8; 32] = rng.random(); + let label: &[u8] = b"master secret"; + let client_random: [u8; 32] = rng.random(); + let server_random: [u8; 32] = rng.random(); + let mut seed = Vec::from(client_random); + seed.extend_from_slice(&server_random); + + let ms = prf_ms(pms, client_random, server_random); + + let mut expected_ms: [u8; 48] = [0; 48]; + prf_ref(&mut expected_ms, &pms, label, &seed); + + assert_eq!(ms, expected_ms); +} + +#[test] +fn test_prf_reference_ke() { + use ring_prf::prf as prf_ref; + + let mut rng = StdRng::from_seed([2; 32]); + + let ms: [u8; 48] = rng.random(); + let label: &[u8] = b"key expansion"; + let client_random: [u8; 32] = rng.random(); + let server_random: [u8; 32] = rng.random(); + let mut seed = Vec::from(server_random); + seed.extend_from_slice(&client_random); + + let keys = prf_keys(ms, client_random, server_random); + let keys: Vec = keys.into_iter().flatten().collect(); + + let mut expected_keys: [u8; 40] = [0; 40]; + prf_ref(&mut expected_keys, &ms, label, &seed); + + assert_eq!(keys, expected_keys); +} + +#[test] +fn test_prf_reference_cf() { + use ring_prf::prf as prf_ref; + + let mut rng = StdRng::from_seed([3; 32]); + + let ms: [u8; 48] = rng.random(); + let label: &[u8] = b"client finished"; + let handshake_hash: [u8; 32] = rng.random(); + + let cf_vd = prf_cf_vd(ms, handshake_hash); + + let mut expected_cf_vd: [u8; 12] = [0; 12]; + prf_ref(&mut expected_cf_vd, &ms, label, &handshake_hash); + + assert_eq!(cf_vd, expected_cf_vd); +} + +#[test] +fn test_prf_reference_sf() { + use ring_prf::prf as prf_ref; + + let mut rng = StdRng::from_seed([4; 32]); + + let ms: [u8; 48] = rng.random(); + let label: &[u8] = b"server finished"; + let handshake_hash: [u8; 32] = rng.random(); + + let sf_vd = prf_sf_vd(ms, handshake_hash); + + let mut expected_sf_vd: [u8; 12] = [0; 12]; + prf_ref(&mut expected_sf_vd, &ms, label, &handshake_hash); + + assert_eq!(sf_vd, expected_sf_vd); +} diff --git a/crates/mpc-tls/Cargo.toml b/crates/mpc-tls/Cargo.toml index 6572967640..bc09300619 100644 --- a/crates/mpc-tls/Cargo.toml +++ b/crates/mpc-tls/Cargo.toml @@ -20,7 +20,7 @@ default = [] [dependencies] tlsn-cipher = { workspace = true } tlsn-common = { workspace = true } -#tlsn-hmac-sha256 = { workspace = true } +tlsn-hmac-sha256 = { workspace = true } tlsn-key-exchange = { workspace = true } tlsn-tls-backend = { workspace = true } tlsn-tls-core = { workspace = true, features = ["serde"] } diff --git a/crates/mpc-tls/src/config.rs b/crates/mpc-tls/src/config.rs index 24f32189d6..4e5e8b1dab 100644 --- a/crates/mpc-tls/src/config.rs +++ b/crates/mpc-tls/src/config.rs @@ -1,4 +1,5 @@ use derive_builder::Builder; +use hmac_sha256::Mode as PrfMode; /// Number of TLS protocol bytes that will be sent. const PROTOCOL_DATA_SENT: usize = 32; @@ -55,6 +56,9 @@ pub struct Config { /// Maximum number of received bytes. #[allow(unused)] pub(crate) max_recv: usize, + /// Configuration options for the PRF. + #[builder(setter(custom))] + pub(crate) prf: PrfMode, } impl Config { @@ -65,6 +69,12 @@ impl Config { } impl ConfigBuilder { + /// Optimizes the protocol for low bandwidth networks. + pub fn low_bandwidth(&mut self) -> &mut Self { + self.prf = Some(PrfMode::Reduced); + self + } + /// Builds the configuration. pub fn build(&self) -> Result { let defer_decryption = self.defer_decryption.unwrap_or(true); @@ -95,6 +105,8 @@ impl ConfigBuilder { .max_recv_records .unwrap_or_else(|| PROTOCOL_RECORD_COUNT_RECV + default_record_count(max_recv)); + let prf = self.prf.unwrap_or_default(); + Ok(Config { defer_decryption, max_sent_records, @@ -102,6 +114,7 @@ impl ConfigBuilder { max_recv_records, max_recv_online, max_recv, + prf, }) } } diff --git a/crates/mpc-tls/src/follower.rs b/crates/mpc-tls/src/follower.rs index dd56f65600..b3dcc0c0a0 100644 --- a/crates/mpc-tls/src/follower.rs +++ b/crates/mpc-tls/src/follower.rs @@ -3,7 +3,7 @@ use crate::{ record_layer::{aead::MpcAesGcm, RecordLayer}, Config, FollowerData, MpcTlsError, Role, SessionKeys, Vm, }; -use hmac_sha256::{MpcPrf, PrfConfig, PrfOutput}; +use hmac_sha256::{MpcPrf, PrfOutput}; use ke::KeyExchange; use key_exchange::{self as ke, MpcKeyExchange}; use mpz_common::{Context, Flush}; @@ -63,12 +63,7 @@ impl MpcTlsFollower { )), )) as Box; - let prf = MpcPrf::new( - PrfConfig::builder() - .role(hmac_sha256::Role::Follower) - .build() - .expect("PRF config is valid"), - ); + let prf = MpcPrf::new(config.prf); let encrypter = MpcAesGcm::new( ShareConversionReceiver::new(OLEReceiver::new(AnyReceiver::new( @@ -123,8 +118,6 @@ impl MpcTlsFollower { keys.server_iv, )?; - prf.set_client_random(vm, None)?; - let cf_vd = vm.decode(cf_vd).map_err(MpcTlsError::alloc)?; let sf_vd = vm.decode(sf_vd).map_err(MpcTlsError::alloc)?; @@ -230,6 +223,7 @@ impl MpcTlsFollower { return Err(MpcTlsError::state("must be in ready state to run")); }; + let mut client_random = None; let mut server_random = None; let mut server_key = None; let mut cf_vd = None; @@ -237,17 +231,20 @@ impl MpcTlsFollower { loop { let msg: Message = self.ctx.io_mut().expect_next().await?; match msg { + Message::SetClientRandom(random) => { + if client_random.is_some() { + return Err(MpcTlsError::hs("client random already set")); + } + + prf.set_client_random(random.random)?; + client_random = Some(random); + } Message::SetServerRandom(random) => { if server_random.is_some() { return Err(MpcTlsError::hs("server random already set")); } - let mut vm = vm - .try_lock() - .map_err(|_| MpcTlsError::other("VM lock is held"))?; - - prf.set_server_random(&mut (*vm), random.random)?; - + prf.set_server_random(random.random)?; server_random = Some(random); } Message::SetServerKey(key) => { @@ -274,9 +271,12 @@ impl MpcTlsFollower { ke.compute_shares(&mut self.ctx).await?; ke.assign(&mut (*vm))?; - vm.execute_all(&mut self.ctx) - .await - .map_err(MpcTlsError::hs)?; + while prf.wants_flush() { + prf.flush(&mut *vm)?; + vm.execute_all(&mut self.ctx) + .await + .map_err(MpcTlsError::hs)?; + } ke.finalize().await?; record_layer.setup(&mut self.ctx).await?; @@ -290,11 +290,14 @@ impl MpcTlsFollower { .try_lock() .map_err(|_| MpcTlsError::other("VM lock is held"))?; - prf.set_cf_hash(&mut (*vm), vd.handshake_hash)?; + prf.set_cf_hash(vd.handshake_hash)?; - vm.execute_all(&mut self.ctx) - .await - .map_err(MpcTlsError::hs)?; + while prf.wants_flush() { + prf.flush(&mut *vm)?; + vm.execute_all(&mut self.ctx) + .await + .map_err(MpcTlsError::hs)?; + } cf_vd = Some( cf_vd_fut @@ -312,11 +315,14 @@ impl MpcTlsFollower { .try_lock() .map_err(|_| MpcTlsError::other("VM lock is held"))?; - prf.set_sf_hash(&mut (*vm), vd.handshake_hash)?; + prf.set_sf_hash(vd.handshake_hash)?; - vm.execute_all(&mut self.ctx) - .await - .map_err(MpcTlsError::hs)?; + while prf.wants_flush() { + prf.flush(&mut *vm)?; + vm.execute_all(&mut self.ctx) + .await + .map_err(MpcTlsError::hs)?; + } sf_vd = Some( sf_vd_fut diff --git a/crates/mpc-tls/src/leader.rs b/crates/mpc-tls/src/leader.rs index d1cb1cff9b..5ea16be1e7 100644 --- a/crates/mpc-tls/src/leader.rs +++ b/crates/mpc-tls/src/leader.rs @@ -3,15 +3,15 @@ mod actor; use crate::{ error::MpcTlsError, msg::{ - ClientFinishedVd, Decrypt, Encrypt, Message, ServerFinishedVd, SetServerKey, - SetServerRandom, + ClientFinishedVd, Decrypt, Encrypt, Message, ServerFinishedVd, SetClientRandom, + SetServerKey, SetServerRandom, }, record_layer::{aead::MpcAesGcm, DecryptMode, EncryptMode, RecordLayer}, utils::opaque_into_parts, Config, LeaderOutput, Role, SessionKeys, Vm, }; use async_trait::async_trait; -use hmac_sha256::{MpcPrf, PrfConfig, PrfOutput}; +use hmac_sha256::{MpcPrf, PrfOutput}; use ke::KeyExchange; use key_exchange::{self as ke, MpcKeyExchange}; use ludi::Context as LudiContext; @@ -87,12 +87,7 @@ impl MpcTlsLeader { ))), )) as Box; - let prf = MpcPrf::new( - PrfConfig::builder() - .role(hmac_sha256::Role::Leader) - .build() - .expect("prf config is valid"), - ); + let prf = MpcPrf::new(config.prf); let encrypter = MpcAesGcm::new( ShareConversionSender::new(OLESender::new( @@ -157,8 +152,6 @@ impl MpcTlsLeader { keys.server_iv, )?; - prf.set_client_random(&mut (*vm_lock), Some(client_random.0))?; - let cf_vd = vm_lock.decode(cf_vd).map_err(MpcTlsError::alloc)?; let sf_vd = vm_lock.decode(sf_vd).map_err(MpcTlsError::alloc)?; @@ -194,7 +187,7 @@ impl MpcTlsLeader { vm, keys, mut ke, - prf, + mut prf, mut record_layer, cf_vd, sf_vd, @@ -238,6 +231,15 @@ impl MpcTlsLeader { .await .map_err(MpcTlsError::preprocess)??; + ctx.io_mut() + .send(Message::SetClientRandom(SetClientRandom { + random: client_random.0, + })) + .await + .map_err(MpcTlsError::from)?; + + prf.set_client_random(client_random.0)?; + self.state = State::Handshake { ctx, vm, @@ -408,7 +410,6 @@ impl Backend for MpcTlsLeader { async fn set_server_random(&mut self, random: Random) -> Result<(), BackendError> { let State::Handshake { ctx, - vm, prf, server_random, .. @@ -426,13 +427,7 @@ impl Backend for MpcTlsLeader { .await .map_err(MpcTlsError::from)?; - let mut vm = vm - .try_lock() - .map_err(|_| MpcTlsError::other("VM lock is held"))?; - - prf.set_server_random(&mut (*vm), random.0) - .map_err(MpcTlsError::hs)?; - + prf.set_server_random(random.0).map_err(MpcTlsError::hs)?; *server_random = Some(random); Ok(()) @@ -543,9 +538,12 @@ impl Backend for MpcTlsLeader { let mut vm = vm .try_lock() .map_err(|_| MpcTlsError::other("VM lock is held"))?; - prf.set_sf_hash(&mut (*vm), hash).map_err(MpcTlsError::hs)?; + prf.set_sf_hash(hash).map_err(MpcTlsError::hs)?; - vm.execute_all(ctx).await.map_err(MpcTlsError::hs)?; + while prf.wants_flush() { + prf.flush(&mut *vm).map_err(MpcTlsError::hs)?; + vm.execute_all(ctx).await.map_err(MpcTlsError::hs)?; + } let sf_vd = sf_vd .try_recv() @@ -586,9 +584,12 @@ impl Backend for MpcTlsLeader { let mut vm = vm .try_lock() .map_err(|_| MpcTlsError::hs("VM lock is held"))?; - prf.set_cf_hash(&mut (*vm), hash).map_err(MpcTlsError::hs)?; + prf.set_cf_hash(hash).map_err(MpcTlsError::hs)?; - vm.execute_all(ctx).await.map_err(MpcTlsError::hs)?; + while prf.wants_flush() { + prf.flush(&mut *vm).map_err(MpcTlsError::hs)?; + vm.execute_all(ctx).await.map_err(MpcTlsError::hs)?; + } let cf_vd = cf_vd .try_recv() @@ -605,7 +606,7 @@ impl Backend for MpcTlsLeader { vm, keys, mut ke, - prf, + mut prf, mut record_layer, cf_vd, sf_vd, @@ -650,10 +651,15 @@ impl Backend for MpcTlsLeader { .map_err(|_| MpcTlsError::other("VM lock is held"))?; ke.assign(&mut (*vm_lock)).map_err(MpcTlsError::hs)?; - vm_lock - .execute_all(&mut ctx) - .await - .map_err(MpcTlsError::hs)?; + + while prf.wants_flush() { + prf.flush(&mut *vm_lock).map_err(MpcTlsError::hs)?; + vm_lock + .execute_all(&mut ctx) + .await + .map_err(MpcTlsError::hs)?; + } + ke.finalize().await.map_err(MpcTlsError::hs)?; record_layer.setup(&mut ctx).await?; } diff --git a/crates/mpc-tls/src/msg.rs b/crates/mpc-tls/src/msg.rs index c5c86e4162..a150f0eb37 100644 --- a/crates/mpc-tls/src/msg.rs +++ b/crates/mpc-tls/src/msg.rs @@ -9,6 +9,7 @@ use crate::record_layer::{DecryptMode, EncryptMode}; /// MPC-TLS protocol message. #[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) enum Message { + SetClientRandom(SetClientRandom), SetServerRandom(SetServerRandom), SetServerKey(SetServerKey), ClientFinishedVd(ClientFinishedVd), @@ -20,6 +21,11 @@ pub(crate) enum Message { CloseConnection, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct SetClientRandom { + pub(crate) random: [u8; 32], +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) struct SetServerRandom { pub(crate) random: [u8; 32], diff --git a/crates/prover/src/config.rs b/crates/prover/src/config.rs index 67cedd7ffe..c7f9a2a25e 100644 --- a/crates/prover/src/config.rs +++ b/crates/prover/src/config.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use mpc_tls::Config; -use tlsn_common::config::ProtocolConfig; +use tlsn_common::config::{NetworkSetting, ProtocolConfig}; use tlsn_core::{connection::ServerName, CryptoProvider}; /// Configuration for the prover @@ -55,6 +55,10 @@ impl ProverConfig { builder.max_recv_records(max_recv_records); } + if let NetworkSetting::Latency = self.protocol_config.network() { + builder.low_bandwidth(); + } + builder.build().unwrap() } } diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index f422d866bc..e33f37648d 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -17,7 +17,6 @@ pub use future::ProverFuture; use mpz_common::Context; use mpz_core::Block; use mpz_garble_core::Delta; -use rand06_compat::Rand0_6CompatExt; use state::{Notarize, Prove}; use futures::{AsyncRead, AsyncWrite, TryFutureExt}; diff --git a/crates/verifier/src/config.rs b/crates/verifier/src/config.rs index 4328d5e9a1..8f2abc1942 100644 --- a/crates/verifier/src/config.rs +++ b/crates/verifier/src/config.rs @@ -4,7 +4,7 @@ use std::{ }; use mpc_tls::Config; -use tlsn_common::config::{ProtocolConfig, ProtocolConfigValidator}; +use tlsn_common::config::{NetworkSetting, ProtocolConfig, ProtocolConfigValidator}; use tlsn_core::CryptoProvider; /// Configuration for the [`Verifier`](crate::tls::Verifier). @@ -58,6 +58,10 @@ impl VerifierConfig { builder.max_recv_records(max_recv_records); } + if let NetworkSetting::Latency = protocol_config.network() { + builder.low_bandwidth(); + } + builder.build().unwrap() } } diff --git a/crates/verifier/src/lib.rs b/crates/verifier/src/lib.rs index 51f79f58a1..ae5d228159 100644 --- a/crates/verifier/src/lib.rs +++ b/crates/verifier/src/lib.rs @@ -20,7 +20,6 @@ use mpc_tls::{FollowerData, MpcTlsFollower}; use mpz_common::Context; use mpz_core::Block; use mpz_garble_core::Delta; -use rand06_compat::Rand0_6CompatExt; use serio::stream::IoStreamExt; use state::{Notarize, Verify}; use tls_core::msgs::enums::ContentType; diff --git a/crates/wasm/src/prover/config.rs b/crates/wasm/src/prover/config.rs index 72352b526f..99028510b4 100644 --- a/crates/wasm/src/prover/config.rs +++ b/crates/wasm/src/prover/config.rs @@ -1,5 +1,5 @@ use serde::Deserialize; -use tlsn_common::config::ProtocolConfig; +use tlsn_common::config::{NetworkSetting, ProtocolConfig}; use tsify_next::Tsify; #[derive(Debug, Tsify, Deserialize)] @@ -12,6 +12,7 @@ pub struct ProverConfig { pub defer_decryption_from_start: Option, pub max_sent_records: Option, pub max_recv_records: Option, + pub network: NetworkSetting, } impl From for tlsn_prover::ProverConfig { @@ -37,6 +38,7 @@ impl From for tlsn_prover::ProverConfig { builder.defer_decryption_from_start(value); } + builder.network(value.network); let protocol_config = builder.build().unwrap(); let mut builder = tlsn_prover::ProverConfig::builder();