From 96c578bd5e6b5f803152ee9949097995b1f340ce Mon Sep 17 00:00:00 2001 From: th4s Date: Wed, 23 Apr 2025 07:49:37 -0700 Subject: [PATCH 01/36] feat(prf): reduced MPC variant --- Cargo.toml | 6 +- crates/benches/binary/Cargo.toml | 9 +- crates/benches/binary/src/lib.rs | 1 - crates/benches/binary/src/preprocess.rs | 5 - crates/benches/binary/src/prover_main.rs | 5 - crates/benches/binary/src/verifier_main.rs | 5 - .../hmac-sha256-circuits/Cargo.toml | 22 - .../hmac-sha256-circuits/src/hmac_sha256.rs | 159 ------ .../hmac-sha256-circuits/src/lib.rs | 61 --- .../hmac-sha256-circuits/src/prf.rs | 227 --------- .../hmac-sha256-circuits/src/session_keys.rs | 200 -------- .../hmac-sha256-circuits/src/verify_data.rs | 88 ---- crates/components/hmac-sha256/Cargo.toml | 15 +- crates/components/hmac-sha256/benches/prf.rs | 99 ++-- crates/components/hmac-sha256/src/config.rs | 30 +- crates/components/hmac-sha256/src/error.rs | 15 - crates/components/hmac-sha256/src/hmac.rs | 188 +++++++ crates/components/hmac-sha256/src/lib.rs | 360 ++++++++------ crates/components/hmac-sha256/src/prf.rs | 312 ------------ .../hmac-sha256/src/prf/function/local.rs | 294 +++++++++++ .../hmac-sha256/src/prf/function/mod.rs | 247 ++++++++++ .../hmac-sha256/src/prf/function/mpc.rs | 205 ++++++++ crates/components/hmac-sha256/src/prf/mod.rs | 466 ++++++++++++++++++ .../components/hmac-sha256/src/prf/state.rs | 15 + crates/components/hmac-sha256/src/sha256.rs | 381 ++++++++++++++ .../components/hmac-sha256/src/test_utils.rs | 261 ++++++++++ crates/mpc-tls/Cargo.toml | 2 +- crates/mpc-tls/src/config.rs | 4 + crates/mpc-tls/src/follower.rs | 79 ++- crates/mpc-tls/src/leader.rs | 85 ++-- crates/mpc-tls/src/msg.rs | 6 + crates/mpc-tls/tests/test.rs | 2 +- crates/prover/Cargo.toml | 1 + crates/prover/src/config.rs | 7 +- crates/prover/src/lib.rs | 3 +- crates/verifier/Cargo.toml | 1 + crates/verifier/src/config.rs | 7 +- crates/verifier/src/lib.rs | 1 - 38 files changed, 2467 insertions(+), 1407 deletions(-) delete mode 100644 crates/benches/binary/src/preprocess.rs delete mode 100644 crates/components/hmac-sha256-circuits/Cargo.toml delete mode 100644 crates/components/hmac-sha256-circuits/src/hmac_sha256.rs delete mode 100644 crates/components/hmac-sha256-circuits/src/lib.rs delete mode 100644 crates/components/hmac-sha256-circuits/src/prf.rs delete mode 100644 crates/components/hmac-sha256-circuits/src/session_keys.rs delete mode 100644 crates/components/hmac-sha256-circuits/src/verify_data.rs create mode 100644 crates/components/hmac-sha256/src/hmac.rs delete mode 100644 crates/components/hmac-sha256/src/prf.rs create mode 100644 crates/components/hmac-sha256/src/prf/function/local.rs create mode 100644 crates/components/hmac-sha256/src/prf/function/mod.rs create mode 100644 crates/components/hmac-sha256/src/prf/function/mpc.rs create mode 100644 crates/components/hmac-sha256/src/prf/mod.rs create mode 100644 crates/components/hmac-sha256/src/prf/state.rs create mode 100644 crates/components/hmac-sha256/src/sha256.rs create mode 100644 crates/components/hmac-sha256/src/test_utils.rs diff --git a/Cargo.toml b/Cargo.toml index 4c26eef810..1a1d200454 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,8 +8,7 @@ members = [ "crates/common", "crates/components/deap", "crates/components/cipher", - #"crates/components/hmac-sha256", - #"crates/components/hmac-sha256-circuits", + "crates/components/hmac-sha256", "crates/components/key-exchange", "crates/core", "crates/data-fixtures", @@ -57,8 +56,7 @@ tlsn-core = { path = "crates/core" } tlsn-data-fixtures = { path = "crates/data-fixtures" } tlsn-deap = { path = "crates/components/deap" } tlsn-formats = { path = "crates/formats" } -#tlsn-hmac-sha256 = { path = "crates/components/hmac-sha256" } -#tlsn-hmac-sha256-circuits = { path = "crates/components/hmac-sha256-circuits" } +tlsn-hmac-sha256 = { path = "crates/components/hmac-sha256" } tlsn-key-exchange = { path = "crates/components/key-exchange" } tlsn-mpc-tls = { path = "crates/mpc-tls" } tlsn-prover = { path = "crates/prover" } diff --git a/crates/benches/binary/Cargo.toml b/crates/benches/binary/Cargo.toml index 3d9cfc2611..8ba61c5c54 100644 --- a/crates/benches/binary/Cargo.toml +++ b/crates/benches/binary/Cargo.toml @@ -18,10 +18,10 @@ mpz-core = { workspace = true } mpz-garble = { workspace = true } mpz-ot = { workspace = true, features = ["ideal"] } tlsn-benches-library = { workspace = true } -tlsn-benches-browser-native = { workspace = true, optional = true} +tlsn-benches-browser-native = { workspace = true, optional = true } tlsn-common = { workspace = true } tlsn-core = { workspace = true } -#tlsn-hmac-sha256 = { workspace = true } +tlsn-hmac-sha256 = { workspace = true } tlsn-prover = { workspace = true } tlsn-server-fixture = { workspace = true } tlsn-server-fixture-certs = { workspace = true } @@ -30,7 +30,7 @@ tlsn-verifier = { workspace = true } anyhow = { workspace = true } async-trait = { workspace = true } -charming = {version = "0.3.1", features = ["ssr"]} +charming = { version = "0.3.1", features = ["ssr"] } csv = "1.3.0" dhat = { version = "0.3.3" } env_logger = { version = "0.6.0", default-features = false } @@ -46,7 +46,8 @@ tokio = { workspace = true, features = [ ] } tokio-util = { workspace = true } toml = "0.8.11" -tracing-subscriber = {workspace = true, features = ["env-filter"]} +tracing-subscriber = { workspace = true, features = ["env-filter"] } +rand = { workspace = true } [[bin]] name = "bench" diff --git a/crates/benches/binary/src/lib.rs b/crates/benches/binary/src/lib.rs index ab1c9dce2f..9225e78ca0 100644 --- a/crates/benches/binary/src/lib.rs +++ b/crates/benches/binary/src/lib.rs @@ -1,6 +1,5 @@ pub mod config; pub mod metrics; -mod preprocess; pub mod prover; pub mod prover_main; pub mod verifier_main; diff --git a/crates/benches/binary/src/preprocess.rs b/crates/benches/binary/src/preprocess.rs deleted file mode 100644 index 8e0e9e080f..0000000000 --- a/crates/benches/binary/src/preprocess.rs +++ /dev/null @@ -1,5 +0,0 @@ -use hmac_sha256::build_circuits; - -pub async fn preprocess_prf_circuits() { - build_circuits().await; -} diff --git a/crates/benches/binary/src/prover_main.rs b/crates/benches/binary/src/prover_main.rs index a72330d996..dc1e1d23e0 100644 --- a/crates/benches/binary/src/prover_main.rs +++ b/crates/benches/binary/src/prover_main.rs @@ -14,7 +14,6 @@ use std::{ use crate::{ config::{BenchInstance, Config}, metrics::Metrics, - preprocess::preprocess_prf_circuits, set_interface, PROVER_INTERFACE, }; use anyhow::Context; @@ -58,10 +57,6 @@ pub async fn prover_main(is_memory_profiling: bool) -> anyhow::Result<()> { .open("metrics.csv") .context("failed to open metrics file")?; - // Preprocess the PRF circuits as they are allocating a lot of memory, which - // don't need to be accounted for in the benchmarks. - preprocess_prf_circuits().await; - { let mut metric_wrt = WriterBuilder::new() // If file is not empty, assume that the CSV header is already present in the file. diff --git a/crates/benches/binary/src/verifier_main.rs b/crates/benches/binary/src/verifier_main.rs index d76752975c..f14c41f9ea 100644 --- a/crates/benches/binary/src/verifier_main.rs +++ b/crates/benches/binary/src/verifier_main.rs @@ -4,7 +4,6 @@ use crate::{ config::{BenchInstance, Config}, - preprocess::preprocess_prf_circuits, set_interface, VERIFIER_INTERFACE, }; use tls_core::verify::WebPkiVerifier; @@ -40,10 +39,6 @@ pub async fn verifier_main(is_memory_profiling: bool) -> anyhow::Result<()> { .await .context("failed to bind to port")?; - // Preprocess the PRF circuits as they are allocating a lot of memory, which - // don't need to be accounted for in the benchmarks. - preprocess_prf_circuits().await; - for bench in config.benches { for instance in bench.flatten() { if is_memory_profiling && !instance.memory_profile { diff --git a/crates/components/hmac-sha256-circuits/Cargo.toml b/crates/components/hmac-sha256-circuits/Cargo.toml deleted file mode 100644 index 55df9ebad5..0000000000 --- a/crates/components/hmac-sha256-circuits/Cargo.toml +++ /dev/null @@ -1,22 +0,0 @@ -[package] -name = "tlsn-hmac-sha256-circuits" -authors = ["TLSNotary Team"] -description = "The 2PC circuits for TLS HMAC-SHA256 PRF" -keywords = ["tls", "mpc", "2pc", "hmac", "sha256"] -categories = ["cryptography"] -license = "MIT OR Apache-2.0" -version = "0.1.0-alpha.11-pre" -edition = "2021" - -[lints] -workspace = true - -[lib] -name = "hmac_sha256_circuits" - -[dependencies] -mpz-circuits = { workspace = true } -tracing = { workspace = true } - -[dev-dependencies] -ring = { workspace = true } diff --git a/crates/components/hmac-sha256-circuits/src/hmac_sha256.rs b/crates/components/hmac-sha256-circuits/src/hmac_sha256.rs deleted file mode 100644 index 5f135cea68..0000000000 --- a/crates/components/hmac-sha256-circuits/src/hmac_sha256.rs +++ /dev/null @@ -1,159 +0,0 @@ -use std::cell::RefCell; - -use mpz_circuits::{ - circuits::{sha256, sha256_compress, sha256_compress_trace, sha256_trace}, - types::{U32, U8}, - BuilderState, Tracer, -}; - -static SHA256_INITIAL_STATE: [u32; 8] = [ - 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, -]; - -/// Returns the outer and inner states of HMAC-SHA256 with the provided key. -/// -/// Outer state is H(key ⊕ opad) -/// -/// Inner state is H(key ⊕ ipad) -/// -/// # Arguments -/// -/// * `builder_state` - Reference to builder state. -/// * `key` - N-byte key (must be <= 64 bytes). -pub fn hmac_sha256_partial_trace<'a>( - builder_state: &'a RefCell, - key: &[Tracer<'a, U8>], -) -> ([Tracer<'a, U32>; 8], [Tracer<'a, U32>; 8]) { - assert!(key.len() <= 64); - - let mut opad = [Tracer::new( - builder_state, - builder_state.borrow_mut().get_constant(0x5cu8), - ); 64]; - - let mut ipad = [Tracer::new( - builder_state, - builder_state.borrow_mut().get_constant(0x36u8), - ); 64]; - - key.iter().enumerate().for_each(|(i, k)| { - opad[i] = opad[i] ^ *k; - ipad[i] = ipad[i] ^ *k; - }); - - let sha256_initial_state: [_; 8] = SHA256_INITIAL_STATE - .map(|v| Tracer::new(builder_state, builder_state.borrow_mut().get_constant(v))); - - let outer_state = sha256_compress_trace(builder_state, sha256_initial_state, opad); - let inner_state = sha256_compress_trace(builder_state, sha256_initial_state, ipad); - - (outer_state, inner_state) -} - -/// Reference implementation of HMAC-SHA256 partial function. -/// -/// Returns the outer and inner states of HMAC-SHA256 with the provided key. -/// -/// Outer state is H(key ⊕ opad) -/// -/// Inner state is H(key ⊕ ipad) -/// -/// # Arguments -/// -/// * `key` - N-byte key (must be <= 64 bytes). -pub fn hmac_sha256_partial(key: &[u8]) -> ([u32; 8], [u32; 8]) { - assert!(key.len() <= 64); - - let mut opad = [0x5cu8; 64]; - let mut ipad = [0x36u8; 64]; - - key.iter().enumerate().for_each(|(i, k)| { - opad[i] ^= k; - ipad[i] ^= k; - }); - - let outer_state = sha256_compress(SHA256_INITIAL_STATE, opad); - let inner_state = sha256_compress(SHA256_INITIAL_STATE, ipad); - - (outer_state, inner_state) -} - -/// HMAC-SHA256 finalization function. -/// -/// Returns the HMAC-SHA256 digest of the provided message using existing outer -/// and inner states. -/// -/// # Arguments -/// -/// * `outer_state` - 256-bit outer state. -/// * `inner_state` - 256-bit inner state. -/// * `msg` - N-byte message. -pub fn hmac_sha256_finalize_trace<'a>( - builder_state: &'a RefCell, - outer_state: [Tracer<'a, U32>; 8], - inner_state: [Tracer<'a, U32>; 8], - msg: &[Tracer<'a, U8>], -) -> [Tracer<'a, U8>; 32] { - sha256_trace( - builder_state, - outer_state, - 64, - &sha256_trace(builder_state, inner_state, 64, msg), - ) -} - -/// Reference implementation of the HMAC-SHA256 finalization function. -/// -/// Returns the HMAC-SHA256 digest of the provided message using existing outer -/// and inner states. -/// -/// # Arguments -/// -/// * `outer_state` - 256-bit outer state. -/// * `inner_state` - 256-bit inner state. -/// * `msg` - N-byte message. -pub fn hmac_sha256_finalize(outer_state: [u32; 8], inner_state: [u32; 8], msg: &[u8]) -> [u8; 32] { - sha256(outer_state, 64, &sha256(inner_state, 64, msg)) -} - -#[cfg(test)] -mod tests { - use mpz_circuits::{test_circ, CircuitBuilder}; - - use super::*; - - #[test] - fn test_hmac_sha256_partial() { - let builder = CircuitBuilder::new(); - let key = builder.add_array_input::(); - let (outer_state, inner_state) = hmac_sha256_partial_trace(builder.state(), &key); - builder.add_output(outer_state); - builder.add_output(inner_state); - let circ = builder.build().unwrap(); - - let key = [69u8; 48]; - - test_circ!(circ, hmac_sha256_partial, fn(&key) -> ([u32; 8], [u32; 8])); - } - - #[test] - fn test_hmac_sha256_finalize() { - let builder = CircuitBuilder::new(); - let outer_state = builder.add_array_input::(); - let inner_state = builder.add_array_input::(); - let msg = builder.add_array_input::(); - let hash = hmac_sha256_finalize_trace(builder.state(), outer_state, inner_state, &msg); - builder.add_output(hash); - let circ = builder.build().unwrap(); - - let key = [69u8; 32]; - let (outer_state, inner_state) = hmac_sha256_partial(&key); - let msg = [42u8; 47]; - - test_circ!( - circ, - hmac_sha256_finalize, - fn(outer_state, inner_state, &msg) -> [u8; 32] - ); - } -} diff --git a/crates/components/hmac-sha256-circuits/src/lib.rs b/crates/components/hmac-sha256-circuits/src/lib.rs deleted file mode 100644 index 2892a5bcf0..0000000000 --- a/crates/components/hmac-sha256-circuits/src/lib.rs +++ /dev/null @@ -1,61 +0,0 @@ -//! HMAC-SHA256 circuits. - -#![deny(missing_docs, unreachable_pub, unused_must_use)] -#![deny(clippy::all)] -#![forbid(unsafe_code)] - -mod hmac_sha256; -mod prf; -mod session_keys; -mod verify_data; - -pub use hmac_sha256::{ - hmac_sha256_finalize, hmac_sha256_finalize_trace, hmac_sha256_partial, - hmac_sha256_partial_trace, -}; - -pub use prf::{prf, prf_trace}; -pub use session_keys::{session_keys, session_keys_trace}; -pub use verify_data::{verify_data, verify_data_trace}; - -use mpz_circuits::{Circuit, CircuitBuilder, Tracer}; -use std::sync::Arc; - -/// Builds session key derivation circuit. -#[tracing::instrument(level = "trace")] -pub fn build_session_keys() -> Arc { - let builder = CircuitBuilder::new(); - let pms = builder.add_array_input::(); - let client_random = builder.add_array_input::(); - let server_random = builder.add_array_input::(); - let (cwk, swk, civ, siv, outer_state, inner_state) = - session_keys_trace(builder.state(), pms, client_random, server_random); - builder.add_output(cwk); - builder.add_output(swk); - builder.add_output(civ); - builder.add_output(siv); - builder.add_output(outer_state); - builder.add_output(inner_state); - Arc::new(builder.build().expect("session keys should build")) -} - -/// Builds a verify data circuit. -#[tracing::instrument(level = "trace")] -pub fn build_verify_data(label: &[u8]) -> Arc { - let builder = CircuitBuilder::new(); - let outer_state = builder.add_array_input::(); - let inner_state = builder.add_array_input::(); - let handshake_hash = builder.add_array_input::(); - let vd = verify_data_trace( - builder.state(), - outer_state, - inner_state, - &label - .iter() - .map(|v| Tracer::new(builder.state(), builder.get_constant(*v).to_inner())) - .collect::>(), - handshake_hash, - ); - builder.add_output(vd); - Arc::new(builder.build().expect("verify data should build")) -} diff --git a/crates/components/hmac-sha256-circuits/src/prf.rs b/crates/components/hmac-sha256-circuits/src/prf.rs deleted file mode 100644 index 664a93393f..0000000000 --- a/crates/components/hmac-sha256-circuits/src/prf.rs +++ /dev/null @@ -1,227 +0,0 @@ -//! This module provides an implementation of the HMAC-SHA256 PRF defined in [RFC 5246](https://www.rfc-editor.org/rfc/rfc5246#section-5). - -use std::cell::RefCell; - -use mpz_circuits::{ - types::{U32, U8}, - BuilderState, Tracer, -}; - -use crate::hmac_sha256::{hmac_sha256_finalize, hmac_sha256_finalize_trace}; - -fn p_hash_trace<'a>( - builder_state: &'a RefCell, - outer_state: [Tracer<'a, U32>; 8], - inner_state: [Tracer<'a, U32>; 8], - seed: &[Tracer<'a, U8>], - iterations: usize, -) -> Vec> { - // A() is defined as: - // - // A(0) = seed - // A(i) = HMAC_hash(secret, A(i-1)) - let mut a_cache: Vec<_> = Vec::with_capacity(iterations + 1); - a_cache.push(seed.to_vec()); - - for i in 0..iterations { - let a_i = hmac_sha256_finalize_trace(builder_state, outer_state, inner_state, &a_cache[i]); - a_cache.push(a_i.to_vec()); - } - - // HMAC_hash(secret, A(i) + seed) - let mut output: Vec<_> = Vec::with_capacity(iterations * 32); - for i in 0..iterations { - let mut a_i_seed = a_cache[i + 1].clone(); - a_i_seed.extend_from_slice(seed); - - let hash = hmac_sha256_finalize_trace(builder_state, outer_state, inner_state, &a_i_seed); - output.extend_from_slice(&hash); - } - - output -} - -fn p_hash(outer_state: [u32; 8], inner_state: [u32; 8], seed: &[u8], iterations: usize) -> Vec { - // A() is defined as: - // - // A(0) = seed - // A(i) = HMAC_hash(secret, A(i-1)) - let mut a_cache: Vec<_> = Vec::with_capacity(iterations + 1); - a_cache.push(seed.to_vec()); - - for i in 0..iterations { - let a_i = hmac_sha256_finalize(outer_state, inner_state, &a_cache[i]); - a_cache.push(a_i.to_vec()); - } - - // HMAC_hash(secret, A(i) + seed) - let mut output: Vec<_> = Vec::with_capacity(iterations * 32); - for i in 0..iterations { - let mut a_i_seed = a_cache[i + 1].clone(); - a_i_seed.extend_from_slice(seed); - - let hash = hmac_sha256_finalize(outer_state, inner_state, &a_i_seed); - output.extend_from_slice(&hash); - } - - output -} - -/// Computes PRF(secret, label, seed). -/// -/// # Arguments -/// -/// * `builder_state` - Reference to builder state. -/// * `outer_state` - The outer state of HMAC-SHA256. -/// * `inner_state` - The inner state of HMAC-SHA256. -/// * `seed` - The seed to use. -/// * `label` - The label to use. -/// * `bytes` - The number of bytes to output. -pub fn prf_trace<'a>( - builder_state: &'a RefCell, - outer_state: [Tracer<'a, U32>; 8], - inner_state: [Tracer<'a, U32>; 8], - seed: &[Tracer<'a, U8>], - label: &[Tracer<'a, U8>], - bytes: usize, -) -> Vec> { - let iterations = bytes / 32 + (bytes % 32 != 0) as usize; - let mut label_seed = label.to_vec(); - label_seed.extend_from_slice(seed); - - let mut output = p_hash_trace( - builder_state, - outer_state, - inner_state, - &label_seed, - iterations, - ); - output.truncate(bytes); - - output -} - -/// Reference implementation of PRF(secret, label, seed). -/// -/// # Arguments -/// -/// * `outer_state` - The outer state of HMAC-SHA256. -/// * `inner_state` - The inner state of HMAC-SHA256. -/// * `seed` - The seed to use. -/// * `label` - The label to use. -/// * `bytes` - The number of bytes to output. -pub fn prf( - outer_state: [u32; 8], - inner_state: [u32; 8], - seed: &[u8], - label: &[u8], - bytes: usize, -) -> Vec { - let iterations = bytes / 32 + (bytes % 32 != 0) as usize; - let mut label_seed = label.to_vec(); - label_seed.extend_from_slice(seed); - - let mut output = p_hash(outer_state, inner_state, &label_seed, iterations); - output.truncate(bytes); - - output -} - -#[cfg(test)] -mod tests { - use mpz_circuits::{evaluate, CircuitBuilder}; - - use crate::hmac_sha256::hmac_sha256_partial; - - use super::*; - - #[test] - fn test_p_hash() { - let builder = CircuitBuilder::new(); - let outer_state = builder.add_array_input::(); - let inner_state = builder.add_array_input::(); - let seed = builder.add_array_input::(); - let output = p_hash_trace(builder.state(), outer_state, inner_state, &seed, 2); - builder.add_output(output); - let circ = builder.build().unwrap(); - - let outer_state = [0u32; 8]; - let inner_state = [1u32; 8]; - let seed = [42u8; 64]; - - let expected = p_hash(outer_state, inner_state, &seed, 2); - let actual = evaluate!(circ, fn(outer_state, inner_state, &seed) -> Vec).unwrap(); - - assert_eq!(actual, expected); - } - - #[test] - fn test_prf() { - let builder = CircuitBuilder::new(); - let outer_state = builder.add_array_input::(); - let inner_state = builder.add_array_input::(); - let seed = builder.add_array_input::(); - let label = builder.add_array_input::(); - let output = prf_trace(builder.state(), outer_state, inner_state, &seed, &label, 48); - builder.add_output(output); - let circ = builder.build().unwrap(); - - let master_secret = [0u8; 48]; - let seed = [43u8; 64]; - let label = b"master secret"; - - let (outer_state, inner_state) = hmac_sha256_partial(&master_secret); - - let expected = prf(outer_state, inner_state, &seed, label, 48); - let actual = - evaluate!(circ, fn(outer_state, inner_state, &seed, label) -> Vec).unwrap(); - - assert_eq!(actual, expected); - - let mut expected_ring = [0u8; 48]; - ring_prf::prf(&mut expected_ring, &master_secret, label, &seed); - - assert_eq!(actual, expected_ring); - } - - // Borrowed from Rustls for testing - // https://github.com/rustls/rustls/blob/main/rustls/src/tls12/prf.rs - mod ring_prf { - use ring::{hmac, hmac::HMAC_SHA256}; - - fn concat_sign(key: &hmac::Key, a: &[u8], b: &[u8]) -> hmac::Tag { - let mut ctx = hmac::Context::with_key(key); - ctx.update(a); - ctx.update(b); - ctx.sign() - } - - fn p(out: &mut [u8], secret: &[u8], seed: &[u8]) { - let hmac_key = hmac::Key::new(HMAC_SHA256, secret); - - // A(1) - let mut current_a = hmac::sign(&hmac_key, seed); - let chunk_size = HMAC_SHA256.digest_algorithm().output_len(); - for chunk in out.chunks_mut(chunk_size) { - // P_hash[i] = HMAC_hash(secret, A(i) + seed) - let p_term = concat_sign(&hmac_key, current_a.as_ref(), seed); - chunk.copy_from_slice(&p_term.as_ref()[..chunk.len()]); - - // A(i+1) = HMAC_hash(secret, A(i)) - current_a = hmac::sign(&hmac_key, current_a.as_ref()); - } - } - - fn concat(a: &[u8], b: &[u8]) -> Vec { - let mut ret = Vec::new(); - ret.extend_from_slice(a); - ret.extend_from_slice(b); - ret - } - - pub(crate) fn prf(out: &mut [u8], secret: &[u8], label: &[u8], seed: &[u8]) { - let joined_seed = concat(label, seed); - p(out, secret, &joined_seed); - } - } -} diff --git a/crates/components/hmac-sha256-circuits/src/session_keys.rs b/crates/components/hmac-sha256-circuits/src/session_keys.rs deleted file mode 100644 index 4ebe7f0841..0000000000 --- a/crates/components/hmac-sha256-circuits/src/session_keys.rs +++ /dev/null @@ -1,200 +0,0 @@ -use std::cell::RefCell; - -use mpz_circuits::{ - types::{U32, U8}, - BuilderState, Tracer, -}; - -use crate::{ - hmac_sha256::{hmac_sha256_partial, hmac_sha256_partial_trace}, - prf::{prf, prf_trace}, -}; - -/// Session Keys. -/// -/// Computes expanded p1 which consists of client_write_key + server_write_key. -/// Computes expanded p2 which consists of client_IV + server_IV. -/// -/// # Arguments -/// -/// * `builder_state` - Reference to builder state. -/// * `pms` - 32-byte premaster secret. -/// * `client_random` - 32-byte client random. -/// * `server_random` - 32-byte server random. -/// -/// # Returns -/// -/// * `client_write_key` - 16-byte client write key. -/// * `server_write_key` - 16-byte server write key. -/// * `client_IV` - 4-byte client IV. -/// * `server_IV` - 4-byte server IV. -/// * `outer_hash_state` - 256-bit master-secret outer HMAC state. -/// * `inner_hash_state` - 256-bit master-secret inner HMAC state. -#[allow(clippy::type_complexity)] -pub fn session_keys_trace<'a>( - builder_state: &'a RefCell, - pms: [Tracer<'a, U8>; 32], - client_random: [Tracer<'a, U8>; 32], - server_random: [Tracer<'a, U8>; 32], -) -> ( - [Tracer<'a, U8>; 16], - [Tracer<'a, U8>; 16], - [Tracer<'a, U8>; 4], - [Tracer<'a, U8>; 4], - [Tracer<'a, U32>; 8], - [Tracer<'a, U32>; 8], -) { - let (pms_outer_state, pms_inner_state) = hmac_sha256_partial_trace(builder_state, &pms); - - let master_secret = { - let seed = client_random - .iter() - .chain(&server_random) - .copied() - .collect::>(); - - let label = b"master secret" - .map(|v| Tracer::new(builder_state, builder_state.borrow_mut().get_constant(v))); - - prf_trace( - builder_state, - pms_outer_state, - pms_inner_state, - &seed, - &label, - 48, - ) - }; - - let (master_secret_outer_state, master_secret_inner_state) = - hmac_sha256_partial_trace(builder_state, &master_secret); - - let key_material = { - let seed = server_random - .iter() - .chain(&client_random) - .copied() - .collect::>(); - - let label = b"key expansion" - .map(|v| Tracer::new(builder_state, builder_state.borrow_mut().get_constant(v))); - - prf_trace( - builder_state, - master_secret_outer_state, - master_secret_inner_state, - &seed, - &label, - 40, - ) - }; - - let cwk = key_material[0..16].try_into().unwrap(); - let swk = key_material[16..32].try_into().unwrap(); - let civ = key_material[32..36].try_into().unwrap(); - let siv = key_material[36..40].try_into().unwrap(); - - ( - cwk, - swk, - civ, - siv, - master_secret_outer_state, - master_secret_inner_state, - ) -} - -/// Reference implementation of session keys derivation. -pub fn session_keys( - pms: [u8; 32], - client_random: [u8; 32], - server_random: [u8; 32], -) -> ([u8; 16], [u8; 16], [u8; 4], [u8; 4]) { - let (pms_outer_state, pms_inner_state) = hmac_sha256_partial(&pms); - - let master_secret = { - let seed = client_random - .iter() - .chain(&server_random) - .copied() - .collect::>(); - - let label = b"master secret"; - - prf(pms_outer_state, pms_inner_state, &seed, label, 48) - }; - - let (master_secret_outer_state, master_secret_inner_state) = - hmac_sha256_partial(&master_secret); - - let key_material = { - let seed = server_random - .iter() - .chain(&client_random) - .copied() - .collect::>(); - - let label = b"key expansion"; - - prf( - master_secret_outer_state, - master_secret_inner_state, - &seed, - label, - 40, - ) - }; - - let cwk = key_material[0..16].try_into().unwrap(); - let swk = key_material[16..32].try_into().unwrap(); - let civ = key_material[32..36].try_into().unwrap(); - let siv = key_material[36..40].try_into().unwrap(); - - (cwk, swk, civ, siv) -} - -#[cfg(test)] -mod tests { - use mpz_circuits::{evaluate, CircuitBuilder}; - - use super::*; - - #[test] - fn test_session_keys() { - let builder = CircuitBuilder::new(); - let pms = builder.add_array_input::(); - let client_random = builder.add_array_input::(); - let server_random = builder.add_array_input::(); - let (cwk, swk, civ, siv, outer_state, inner_state) = - session_keys_trace(builder.state(), pms, client_random, server_random); - builder.add_output(cwk); - builder.add_output(swk); - builder.add_output(civ); - builder.add_output(siv); - builder.add_output(outer_state); - builder.add_output(inner_state); - let circ = builder.build().unwrap(); - - let pms = [0u8; 32]; - let client_random = [42u8; 32]; - let server_random = [69u8; 32]; - - let (expected_cwk, expected_swk, expected_civ, expected_siv) = - session_keys(pms, client_random, server_random); - - let (cwk, swk, civ, siv, _, _) = evaluate!( - circ, - fn( - pms, - client_random, - server_random, - ) -> ([u8; 16], [u8; 16], [u8; 4], [u8; 4], [u32; 8], [u32; 8]) - ) - .unwrap(); - - assert_eq!(cwk, expected_cwk); - assert_eq!(swk, expected_swk); - assert_eq!(civ, expected_civ); - assert_eq!(siv, expected_siv); - } -} diff --git a/crates/components/hmac-sha256-circuits/src/verify_data.rs b/crates/components/hmac-sha256-circuits/src/verify_data.rs deleted file mode 100644 index 01f10b7469..0000000000 --- a/crates/components/hmac-sha256-circuits/src/verify_data.rs +++ /dev/null @@ -1,88 +0,0 @@ -use std::cell::RefCell; - -use mpz_circuits::{ - types::{U32, U8}, - BuilderState, Tracer, -}; - -use crate::prf::{prf, prf_trace}; - -/// Computes verify_data as specified in RFC 5246, Section 7.4.9. -/// -/// verify_data -/// PRF(master_secret, finished_label, -/// Hash(handshake_messages))[0..verify_data_length-1]; -/// -/// # Arguments -/// -/// * `builder_state` - The builder state. -/// * `outer_state` - The outer HMAC state of the master secret. -/// * `inner_state` - The inner HMAC state of the master secret. -/// * `label` - The label to use. -/// * `hs_hash` - The handshake hash. -pub fn verify_data_trace<'a>( - builder_state: &'a RefCell, - outer_state: [Tracer<'a, U32>; 8], - inner_state: [Tracer<'a, U32>; 8], - label: &[Tracer<'a, U8>], - hs_hash: [Tracer<'a, U8>; 32], -) -> [Tracer<'a, U8>; 12] { - let vd = prf_trace(builder_state, outer_state, inner_state, &hs_hash, label, 12); - - vd.try_into().expect("vd is 12 bytes") -} - -/// Reference implementation of verify_data as specified in RFC 5246, Section -/// 7.4.9. -/// -/// # Arguments -/// -/// * `outer_state` - The outer HMAC state of the master secret. -/// * `inner_state` - The inner HMAC state of the master secret. -/// * `label` - The label to use. -/// * `hs_hash` - The handshake hash. -pub fn verify_data( - outer_state: [u32; 8], - inner_state: [u32; 8], - label: &[u8], - hs_hash: [u8; 32], -) -> [u8; 12] { - let vd = prf(outer_state, inner_state, &hs_hash, label, 12); - - vd.try_into().expect("vd is 12 bytes") -} - -#[cfg(test)] -mod tests { - use super::*; - - use mpz_circuits::{evaluate, CircuitBuilder}; - - const CF_LABEL: &[u8; 15] = b"client finished"; - - #[test] - fn test_verify_data() { - let builder = CircuitBuilder::new(); - let outer_state = builder.add_array_input::(); - let inner_state = builder.add_array_input::(); - let label = builder.add_array_input::(); - let hs_hash = builder.add_array_input::(); - let vd = verify_data_trace(builder.state(), outer_state, inner_state, &label, hs_hash); - builder.add_output(vd); - let circ = builder.build().unwrap(); - - let outer_state = [0u32; 8]; - let inner_state = [1u32; 8]; - let hs_hash = [42u8; 32]; - - let expected = prf(outer_state, inner_state, &hs_hash, CF_LABEL, 12); - - let actual = evaluate!( - circ, - fn(outer_state, inner_state, CF_LABEL, hs_hash) -> [u8; 12] - ) - .unwrap(); - - assert_eq!(actual.to_vec(), expected); - } -} diff --git a/crates/components/hmac-sha256/Cargo.toml b/crates/components/hmac-sha256/Cargo.toml index 07ecc6b6c6..5b78a71df7 100644 --- a/crates/components/hmac-sha256/Cargo.toml +++ b/crates/components/hmac-sha256/Cargo.toml @@ -14,22 +14,14 @@ workspace = true [lib] name = "hmac_sha256" -[features] -default = ["mock"] -rayon = ["mpz-common/rayon"] -mock = [] - [dependencies] -tlsn-hmac-sha256-circuits = { workspace = true } - mpz-vm-core = { workspace = true } +mpz-core = { workspace = true } mpz-circuits = { workspace = true } -mpz-common = { workspace = true, features = ["cpu"] } -derive_builder = { workspace = true } thiserror = { workspace = true } tracing = { workspace = true } -futures = { workspace = true } +sha2 = { workspace = true } [dev-dependencies] mpz-ot = { workspace = true, features = ["ideal"] } @@ -39,7 +31,8 @@ mpz-common = { workspace = true, features = ["test-utils"] } criterion = { workspace = true, features = ["async_tokio"] } tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } rand = { workspace = true } -rand06-compat = { workspace = true } +hex = { workspace = true } +ring = { workspace = true } [[bench]] name = "prf" diff --git a/crates/components/hmac-sha256/benches/prf.rs b/crates/components/hmac-sha256/benches/prf.rs index 9e494a93c2..b6db1e4ef4 100644 --- a/crates/components/hmac-sha256/benches/prf.rs +++ b/crates/components/hmac-sha256/benches/prf.rs @@ -2,16 +2,16 @@ use criterion::{criterion_group, criterion_main, Criterion}; -use hmac_sha256::{MpcPrf, PrfConfig, Role}; +use hmac_sha256::{Config, MpcPrf}; use mpz_common::context::test_mt_context; -use mpz_garble::protocol::semihonest::{Evaluator, Generator}; +use mpz_garble::protocol::semihonest::{Evaluator, Garbler}; use mpz_ot::ideal::cot::ideal_cot; use mpz_vm_core::{ memory::{binary::U8, correlated::Delta, Array}, prelude::*, + Execute, }; use rand::{rngs::StdRng, SeedableRng}; -use rand06_compat::Rand0_6CompatExt; #[allow(clippy::unit_arg)] fn criterion_benchmark(c: &mut Criterion) { @@ -36,10 +36,10 @@ async fn prf() { let mut leader_ctx = leader_exec.new_context().await.unwrap(); let mut follower_ctx = follower_exec.new_context().await.unwrap(); - let delta = Delta::random(&mut rng.compat_by_ref()); + let delta = Delta::random(&mut rng); let (ot_send, ot_recv) = ideal_cot(delta.into_inner()); - let mut leader_vm = Generator::new(ot_send, [0u8; 16], delta); + let mut leader_vm = Garbler::new(ot_send, [0u8; 16], delta); let mut follower_vm = Evaluator::new(ot_recv); let leader_pms: Array = leader_vm.alloc().unwrap(); @@ -52,23 +52,17 @@ async fn prf() { follower_vm.assign(follower_pms, pms).unwrap(); follower_vm.commit(follower_pms).unwrap(); - let mut leader = MpcPrf::new(PrfConfig::builder().role(Role::Leader).build().unwrap()); - let mut follower = MpcPrf::new(PrfConfig::builder().role(Role::Follower).build().unwrap()); + let mut leader = MpcPrf::new(Config::default()); + let mut follower = MpcPrf::new(Config::default()); let leader_output = leader.alloc(&mut leader_vm, leader_pms).unwrap(); let follower_output = follower.alloc(&mut follower_vm, follower_pms).unwrap(); - leader - .set_client_random(&mut leader_vm, Some(client_random)) - .unwrap(); - follower.set_client_random(&mut follower_vm, None).unwrap(); + leader.set_client_random(client_random).unwrap(); + follower.set_client_random(client_random).unwrap(); - leader - .set_server_random(&mut leader_vm, server_random) - .unwrap(); - follower - .set_server_random(&mut follower_vm, server_random) - .unwrap(); + leader.set_server_random(server_random).unwrap(); + follower.set_server_random(server_random).unwrap(); let _ = leader_vm .decode(leader_output.keys.client_write_key) @@ -88,27 +82,29 @@ async fn prf() { let _ = follower_vm.decode(follower_output.keys.client_iv).unwrap(); let _ = follower_vm.decode(follower_output.keys.server_iv).unwrap(); - futures::join!( - async { - leader_vm.flush(&mut leader_ctx).await.unwrap(); - leader_vm.execute(&mut leader_ctx).await.unwrap(); - leader_vm.flush(&mut leader_ctx).await.unwrap(); - }, - async { - follower_vm.flush(&mut follower_ctx).await.unwrap(); - follower_vm.execute(&mut follower_ctx).await.unwrap(); - follower_vm.flush(&mut follower_ctx).await.unwrap(); + loop { + let leader_finished = leader.drive_key_expansion(&mut leader_vm).unwrap(); + let follower_finished = follower.drive_key_expansion(&mut follower_vm).unwrap(); + + tokio::try_join!( + leader_vm.execute_all(&mut leader_ctx), + follower_vm.execute_all(&mut follower_ctx) + ) + .unwrap(); + + if leader_finished && follower_finished { + break; } - ); + } let cf_hs_hash = [1u8; 32]; let sf_hs_hash = [2u8; 32]; - leader.set_cf_hash(&mut leader_vm, cf_hs_hash).unwrap(); - leader.set_sf_hash(&mut leader_vm, sf_hs_hash).unwrap(); + leader.set_cf_hash(cf_hs_hash).unwrap(); + leader.set_sf_hash(sf_hs_hash).unwrap(); - follower.set_cf_hash(&mut follower_vm, cf_hs_hash).unwrap(); - follower.set_sf_hash(&mut follower_vm, sf_hs_hash).unwrap(); + follower.set_cf_hash(cf_hs_hash).unwrap(); + follower.set_sf_hash(sf_hs_hash).unwrap(); let _ = leader_vm.decode(leader_output.cf_vd).unwrap(); let _ = leader_vm.decode(leader_output.sf_vd).unwrap(); @@ -116,16 +112,33 @@ async fn prf() { let _ = follower_vm.decode(follower_output.cf_vd).unwrap(); let _ = follower_vm.decode(follower_output.sf_vd).unwrap(); - futures::join!( - async { - leader_vm.flush(&mut leader_ctx).await.unwrap(); - leader_vm.execute(&mut leader_ctx).await.unwrap(); - leader_vm.flush(&mut leader_ctx).await.unwrap(); - }, - async { - follower_vm.flush(&mut follower_ctx).await.unwrap(); - follower_vm.execute(&mut follower_ctx).await.unwrap(); - follower_vm.flush(&mut follower_ctx).await.unwrap(); + loop { + let leader_finished = leader.drive_client_finished(&mut leader_vm).unwrap(); + let follower_finished = follower.drive_client_finished(&mut follower_vm).unwrap(); + + tokio::try_join!( + leader_vm.execute_all(&mut leader_ctx), + follower_vm.execute_all(&mut follower_ctx) + ) + .unwrap(); + + if leader_finished && follower_finished { + break; + } + } + + loop { + let leader_finished = leader.drive_server_finished(&mut leader_vm).unwrap(); + let follower_finished = follower.drive_server_finished(&mut follower_vm).unwrap(); + + tokio::try_join!( + leader_vm.execute_all(&mut leader_ctx), + follower_vm.execute_all(&mut follower_ctx) + ) + .unwrap(); + + if leader_finished && follower_finished { + break; } - ); + } } diff --git a/crates/components/hmac-sha256/src/config.rs b/crates/components/hmac-sha256/src/config.rs index c9e96c9cd4..a0ba2df48f 100644 --- a/crates/components/hmac-sha256/src/config.rs +++ b/crates/components/hmac-sha256/src/config.rs @@ -1,24 +1,16 @@ -use derive_builder::Builder; +//! PRF Config. -/// Role of this party in the PRF. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Role { - /// The leader provides the private inputs to the PRF. - Leader, - /// The follower is blind to the inputs to the PRF. - Follower, +/// Configuration option for the PRF. +#[derive(Debug, Clone, Copy)] +pub enum Config { + /// Computes some hashes locally. + Local, + /// Computes the whole PRF in MPC. + Mpc, } -/// Configuration for the PRF. -#[derive(Debug, Builder)] -pub struct PrfConfig { - /// The role of this party in the PRF. - pub(crate) role: Role, -} - -impl PrfConfig { - /// Creates a new builder. - pub fn builder() -> PrfConfigBuilder { - PrfConfigBuilder::default() +impl Default for Config { + fn default() -> Self { + Self::Mpc } } diff --git a/crates/components/hmac-sha256/src/error.rs b/crates/components/hmac-sha256/src/error.rs index d22f754947..bad3acf5f6 100644 --- a/crates/components/hmac-sha256/src/error.rs +++ b/crates/components/hmac-sha256/src/error.rs @@ -27,13 +27,6 @@ impl PrfError { } } - pub(crate) fn role(msg: impl Into) -> Self { - Self { - kind: ErrorKind::Role, - source: Some(msg.into().into()), - } - } - pub(crate) fn vm>>(err: E) -> Self { Self::new(ErrorKind::Vm, err) } @@ -43,7 +36,6 @@ impl PrfError { pub(crate) enum ErrorKind { Vm, State, - Role, } impl fmt::Display for PrfError { @@ -51,7 +43,6 @@ impl fmt::Display for PrfError { match self.kind { ErrorKind::Vm => write!(f, "vm error")?, ErrorKind::State => write!(f, "state error")?, - ErrorKind::Role => write!(f, "role error")?, } if let Some(ref source) = self.source { @@ -61,9 +52,3 @@ impl fmt::Display for PrfError { Ok(()) } } - -impl From for PrfError { - fn from(error: mpz_common::ContextError) -> Self { - Self::new(ErrorKind::Vm, error) - } -} diff --git a/crates/components/hmac-sha256/src/hmac.rs b/crates/components/hmac-sha256/src/hmac.rs new file mode 100644 index 0000000000..67cc6cf801 --- /dev/null +++ b/crates/components/hmac-sha256/src/hmac.rs @@ -0,0 +1,188 @@ +//! Computation of HMAC-SHA256. +//! +//! HMAC-SHA256 is defined as +//! +//! HMAC(m) = H((key' xor opad) || H((key' xor ipad) || m)) +//! +//! * H - SHA256 hash function +//! * key' - key padded with zero bytes to 64 bytes (we do not support longer +//! keys) +//! * opad - 64 bytes of 0x5c +//! * ipad - 64 bytes of 0x36 +//! * m - message +//! +//! This implementation computes HMAC-SHA256 using intermediate results +//! `outer_partial` and `inner_local`. Then HMAC(m) = H(outer_partial || +//! inner_local) +//! +//! * `outer_partial` - key' xor opad +//! * `inner_local` - H((key' xor ipad) || m) + +use crate::{sha256::Sha256, PrfError}; +use mpz_vm_core::{ + memory::{ + binary::{Binary, U32, U8}, + Array, + }, + Vm, +}; + +/// Computes HMAC-SHA256. +#[derive(Debug)] +pub(crate) struct HmacSha256 { + outer_partial: Array, + inner_local: Array, +} + +impl HmacSha256 { + /// Creates a new instance. + /// + /// # Arguments + /// + /// * `outer_partial` - (key' xor opad) + /// * `inner_local` - H((key' xor ipad) || m) + pub(crate) fn new(outer_partial: Array, inner_local: Array) -> Self { + Self { + outer_partial, + inner_local, + } + } + + /// Adds the circuit to the [`Vm`] and returns the output. + /// + /// # Arguments + /// + /// * `vm` - The virtual machine. + pub(crate) fn alloc(self, vm: &mut dyn Vm) -> Result, PrfError> { + let inner_local = self.inner_local.into(); + + let mut outer = Sha256::default(); + outer + .set_state(self.outer_partial, 64) + .update(inner_local) + .add_padding(vm)?; + + outer.alloc(vm) + } +} + +#[cfg(test)] +mod tests { + use crate::{ + convert_to_bytes, + hmac::HmacSha256, + sha256::sha256, + test_utils::{compute_inner_local, compute_outer_partial, mock_vm}, + }; + use mpz_common::context::test_st_context; + use mpz_vm_core::{ + memory::{ + binary::{U32, U8}, + Array, MemoryExt, ViewExt, + }, + Execute, + }; + + #[test] + fn test_hmac_reference() { + let (inputs, references) = test_fixtures(); + + for (input, &reference) in inputs.iter().zip(references.iter()) { + let outer_partial = compute_outer_partial(input.0.clone()); + let inner_local = compute_inner_local(input.0.clone(), &input.1); + + let hmac = sha256(outer_partial, 64, &convert_to_bytes(inner_local)); + + assert_eq!(convert_to_bytes(hmac), reference); + } + } + + #[tokio::test] + async fn test_hmac_circuit() { + let (mut ctx_a, mut ctx_b) = test_st_context(8); + let (mut leader, mut follower) = mock_vm(); + + let (inputs, references) = test_fixtures(); + for (input, &reference) in inputs.iter().zip(references.iter()) { + let outer_partial = compute_outer_partial(input.0.clone()); + let inner_local = compute_inner_local(input.0.clone(), &input.1); + + let outer_partial_leader: Array = leader.alloc().unwrap(); + leader.mark_public(outer_partial_leader).unwrap(); + leader.assign(outer_partial_leader, outer_partial).unwrap(); + leader.commit(outer_partial_leader).unwrap(); + + let inner_local_leader: Array = leader.alloc().unwrap(); + leader.mark_public(inner_local_leader).unwrap(); + leader + .assign(inner_local_leader, convert_to_bytes(inner_local)) + .unwrap(); + leader.commit(inner_local_leader).unwrap(); + + let hmac_leader = HmacSha256::new(outer_partial_leader, inner_local_leader) + .alloc(&mut leader) + .unwrap(); + let hmac_leader = leader.decode(hmac_leader).unwrap(); + + let outer_partial_follower: Array = follower.alloc().unwrap(); + follower.mark_public(outer_partial_follower).unwrap(); + follower + .assign(outer_partial_follower, outer_partial) + .unwrap(); + follower.commit(outer_partial_follower).unwrap(); + + let inner_local_follower: Array = follower.alloc().unwrap(); + follower.mark_public(inner_local_follower).unwrap(); + follower + .assign(inner_local_follower, convert_to_bytes(inner_local)) + .unwrap(); + follower.commit(inner_local_follower).unwrap(); + + let hmac_follower = HmacSha256::new(outer_partial_follower, inner_local_follower) + .alloc(&mut follower) + .unwrap(); + let hmac_follower = follower.decode(hmac_follower).unwrap(); + + let (hmac_leader, hmac_follower) = tokio::try_join!( + async { + leader.execute_all(&mut ctx_a).await.unwrap(); + hmac_leader.await + }, + async { + follower.execute_all(&mut ctx_b).await.unwrap(); + hmac_follower.await + } + ) + .unwrap(); + + assert_eq!(hmac_leader, hmac_follower); + assert_eq!(convert_to_bytes(hmac_leader), reference); + } + } + + #[allow(clippy::type_complexity)] + fn test_fixtures() -> (Vec<(Vec, Vec)>, Vec<[u8; 32]>) { + let test_vectors: Vec<(Vec, Vec)> = vec![ + ( + hex::decode("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b").unwrap(), + hex::decode("4869205468657265").unwrap(), + ), + ( + hex::decode("4a656665").unwrap(), + hex::decode("7768617420646f2079612077616e7420666f72206e6f7468696e673f").unwrap(), + ), + ]; + let expected: Vec<[u8; 32]> = vec![ + hex::decode("b0344c61d8db38535ca8afceaf0bf12b881dc200c9833da726e9376c2e32cff7") + .unwrap() + .try_into() + .unwrap(), + hex::decode("5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843") + .unwrap() + .try_into() + .unwrap(), + ]; + + (test_vectors, expected) + } +} diff --git a/crates/components/hmac-sha256/src/lib.rs b/crates/components/hmac-sha256/src/lib.rs index 2082113831..d08aef5f2c 100644 --- a/crates/components/hmac-sha256/src/lib.rs +++ b/crates/components/hmac-sha256/src/lib.rs @@ -1,30 +1,25 @@ -//! This module contains the protocol for computing TLS SHA-256 HMAC PRF. +//! This crate contains the protocol for computing TLS 1.2 SHA-256 HMAC PRF. #![deny(missing_docs, unreachable_pub, unused_must_use)] #![deny(clippy::all)] #![forbid(unsafe_code)] +mod hmac; +mod sha256; +#[cfg(test)] +mod test_utils; + mod config; -mod error; -mod prf; +pub use config::Config; -pub use config::{PrfConfig, PrfConfigBuilder, PrfConfigBuilderError, Role}; +mod error; pub use error::PrfError; + +mod prf; pub use prf::MpcPrf; use mpz_vm_core::memory::{binary::U8, Array}; -pub(crate) static CF_LABEL: &[u8] = b"client finished"; -pub(crate) static SF_LABEL: &[u8] = b"server finished"; - -/// Builds the circuits for the PRF. -/// -/// This function can be used ahead of time to build the circuits for the PRF, -/// which at the moment is CPU and memory intensive. -pub async fn build_circuits() { - prf::Circuits::get().await; -} - /// PRF output. #[derive(Debug, Clone, Copy)] pub struct PrfOutput { @@ -49,176 +44,213 @@ pub struct SessionKeys { pub server_iv: Array, } +fn convert_to_bytes(input: [u32; 8]) -> [u8; 32] { + let mut output = [0_u8; 32]; + for (k, byte_chunk) in input.iter().enumerate() { + let byte_chunk = byte_chunk.to_be_bytes(); + output[4 * k..4 * (k + 1)].copy_from_slice(&byte_chunk); + } + output +} + #[cfg(test)] mod tests { + use crate::{ + test_utils::{mock_vm, prf_cf_vd, prf_keys, prf_ms, prf_sf_vd}, + Config, MpcPrf, SessionKeys, + }; use mpz_common::context::test_st_context; - use mpz_garble::protocol::semihonest::{Evaluator, Generator}; - - use hmac_sha256_circuits::{hmac_sha256_partial, prf, session_keys}; - use mpz_ot::ideal::cot::ideal_cot; - use mpz_vm_core::{memory::correlated::Delta, prelude::*}; - use rand::{rngs::StdRng, SeedableRng}; - use rand06_compat::Rand0_6CompatExt; - - use super::*; - - fn compute_ms(pms: [u8; 32], client_random: [u8; 32], server_random: [u8; 32]) -> [u8; 48] { - let (outer_state, inner_state) = hmac_sha256_partial(&pms); - let seed = client_random - .iter() - .chain(&server_random) - .copied() - .collect::>(); - let ms = prf(outer_state, inner_state, &seed, b"master secret", 48); - ms.try_into().unwrap() - } + use mpz_vm_core::{ + memory::{binary::U8, Array, MemoryExt, ViewExt}, + Execute, + }; + use rand::{rngs::StdRng, Rng, SeedableRng}; - fn compute_vd(ms: [u8; 48], label: &[u8], hs_hash: [u8; 32]) -> [u8; 12] { - let (outer_state, inner_state) = hmac_sha256_partial(&ms); - let vd = prf(outer_state, inner_state, &hs_hash, label, 12); - vd.try_into().unwrap() + #[tokio::test] + async fn test_prf_local() { + let config = Config::Local; + test_prf(config).await; } - #[ignore = "expensive"] #[tokio::test] - async fn test_prf() { - let mut rng = StdRng::seed_from_u64(0); + async fn test_prf_mpc() { + let config = Config::Mpc; + test_prf(config).await; + } - let pms = [42u8; 32]; - let client_random = [69u8; 32]; - let server_random: [u8; 32] = [96u8; 32]; - let ms = compute_ms(pms, client_random, server_random); + async fn test_prf(config: Config) { + let mut rng = StdRng::seed_from_u64(1); + // Test input + let pms: [u8; 32] = rng.random(); + let client_random: [u8; 32] = rng.random(); + let server_random: [u8; 32] = rng.random(); + + let cf_hs_hash: [u8; 32] = rng.random(); + let sf_hs_hash: [u8; 32] = rng.random(); + + // Expected output + let ms_expected = prf_ms(pms, client_random, server_random); + + let [cwk_expected, swk_expected, civ_expected, siv_expected] = + prf_keys(ms_expected, client_random, server_random); + + let cwk_expected: [u8; 16] = cwk_expected.try_into().unwrap(); + let swk_expected: [u8; 16] = swk_expected.try_into().unwrap(); + let civ_expected: [u8; 4] = civ_expected.try_into().unwrap(); + let siv_expected: [u8; 4] = siv_expected.try_into().unwrap(); + + let cf_vd_expected = prf_cf_vd(ms_expected, cf_hs_hash); + let sf_vd_expected = prf_sf_vd(ms_expected, sf_hs_hash); + + let cf_vd_expected: [u8; 12] = cf_vd_expected.try_into().unwrap(); + let sf_vd_expected: [u8; 12] = sf_vd_expected.try_into().unwrap(); + + // Set up vm and prf + let (mut ctx_a, mut ctx_b) = test_st_context(128); + let (mut leader, mut follower) = mock_vm(); + + let leader_pms: Array = leader.alloc().unwrap(); + leader.mark_public(leader_pms).unwrap(); + leader.assign(leader_pms, pms).unwrap(); + leader.commit(leader_pms).unwrap(); + + let follower_pms: Array = follower.alloc().unwrap(); + follower.mark_public(follower_pms).unwrap(); + follower.assign(follower_pms, pms).unwrap(); + follower.commit(follower_pms).unwrap(); + + let mut leader_prf = MpcPrf::new(config); + let mut follower_prf = MpcPrf::new(config); + + let leader_prf_out = leader_prf.alloc(&mut leader, leader_pms).unwrap(); + let follower_prf_out = follower_prf.alloc(&mut follower, follower_pms).unwrap(); + + // client_random and server_random + leader_prf.set_client_random(client_random).unwrap(); + follower_prf.set_client_random(client_random).unwrap(); + + leader_prf.set_server_random(server_random).unwrap(); + follower_prf.set_server_random(server_random).unwrap(); + + let SessionKeys { + client_write_key: cwk_leader, + server_write_key: swk_leader, + client_iv: civ_leader, + server_iv: siv_leader, + } = leader_prf_out.keys; + + let mut cwk_leader = leader.decode(cwk_leader).unwrap(); + let mut swk_leader = leader.decode(swk_leader).unwrap(); + let mut civ_leader = leader.decode(civ_leader).unwrap(); + let mut siv_leader = leader.decode(siv_leader).unwrap(); + + let SessionKeys { + client_write_key: cwk_follower, + server_write_key: swk_follower, + client_iv: civ_follower, + server_iv: siv_follower, + } = follower_prf_out.keys; + + let mut cwk_follower = follower.decode(cwk_follower).unwrap(); + let mut swk_follower = follower.decode(swk_follower).unwrap(); + let mut civ_follower = follower.decode(civ_follower).unwrap(); + let mut siv_follower = follower.decode(siv_follower).unwrap(); + + loop { + let leader_finished = leader_prf.drive_key_expansion(&mut leader).unwrap(); + let follower_finished = follower_prf.drive_key_expansion(&mut follower).unwrap(); + + tokio::try_join!( + leader.execute_all(&mut ctx_a), + follower.execute_all(&mut ctx_b) + ) + .unwrap(); - let (mut leader_ctx, mut follower_ctx) = test_st_context(128); + if leader_finished && follower_finished { + break; + } + } + + let cwk_leader = cwk_leader.try_recv().unwrap().unwrap(); + let swk_leader = swk_leader.try_recv().unwrap().unwrap(); + let civ_leader = civ_leader.try_recv().unwrap().unwrap(); + let siv_leader = siv_leader.try_recv().unwrap().unwrap(); + + let cwk_follower = cwk_follower.try_recv().unwrap().unwrap(); + let swk_follower = swk_follower.try_recv().unwrap().unwrap(); + let civ_follower = civ_follower.try_recv().unwrap().unwrap(); + let siv_follower = siv_follower.try_recv().unwrap().unwrap(); + + assert_eq!(cwk_leader, cwk_follower); + assert_eq!(swk_leader, swk_follower); + assert_eq!(civ_leader, civ_follower); + assert_eq!(siv_leader, siv_follower); + + assert_eq!(cwk_leader, cwk_expected); + assert_eq!(swk_leader, swk_expected); + assert_eq!(civ_leader, civ_expected); + assert_eq!(siv_leader, siv_expected); + + // client finished + leader_prf.set_cf_hash(cf_hs_hash).unwrap(); + follower_prf.set_cf_hash(cf_hs_hash).unwrap(); + + let cf_vd_leader = leader_prf_out.cf_vd; + let cf_vd_follower = follower_prf_out.cf_vd; + + let mut cf_vd_leader = leader.decode(cf_vd_leader).unwrap(); + let mut cf_vd_follower = follower.decode(cf_vd_follower).unwrap(); + + loop { + let leader_finished = leader_prf.drive_client_finished(&mut leader).unwrap(); + let follower_finished = follower_prf.drive_client_finished(&mut follower).unwrap(); + + tokio::try_join!( + leader.execute_all(&mut ctx_a), + follower.execute_all(&mut ctx_b) + ) + .unwrap(); - let delta = Delta::random(&mut rng.compat_by_ref()); - let (ot_send, ot_recv) = ideal_cot(delta.into_inner()); + if leader_finished && follower_finished { + break; + } + } - let mut leader_vm = Generator::new(ot_send, [0u8; 16], delta); - let mut follower_vm = Evaluator::new(ot_recv); + let cf_vd_leader = cf_vd_leader.try_recv().unwrap().unwrap(); + let cf_vd_follower = cf_vd_follower.try_recv().unwrap().unwrap(); - let leader_pms: Array = leader_vm.alloc().unwrap(); - leader_vm.mark_public(leader_pms).unwrap(); - leader_vm.assign(leader_pms, pms).unwrap(); - leader_vm.commit(leader_pms).unwrap(); + assert_eq!(cf_vd_leader, cf_vd_follower); + assert_eq!(cf_vd_leader, cf_vd_expected); - let follower_pms: Array = follower_vm.alloc().unwrap(); - follower_vm.mark_public(follower_pms).unwrap(); - follower_vm.assign(follower_pms, pms).unwrap(); - follower_vm.commit(follower_pms).unwrap(); + // server finished + leader_prf.set_sf_hash(sf_hs_hash).unwrap(); + follower_prf.set_sf_hash(sf_hs_hash).unwrap(); - let mut leader = MpcPrf::new(PrfConfig::builder().role(Role::Leader).build().unwrap()); - let mut follower = MpcPrf::new(PrfConfig::builder().role(Role::Follower).build().unwrap()); + let sf_vd_leader = leader_prf_out.sf_vd; + let sf_vd_follower = follower_prf_out.sf_vd; - let leader_output = leader.alloc(&mut leader_vm, leader_pms).unwrap(); - let follower_output = follower.alloc(&mut follower_vm, follower_pms).unwrap(); + let mut sf_vd_leader = leader.decode(sf_vd_leader).unwrap(); + let mut sf_vd_follower = follower.decode(sf_vd_follower).unwrap(); - leader - .set_client_random(&mut leader_vm, Some(client_random)) - .unwrap(); - follower.set_client_random(&mut follower_vm, None).unwrap(); + loop { + let leader_finished = leader_prf.drive_server_finished(&mut leader).unwrap(); + let follower_finished = follower_prf.drive_server_finished(&mut follower).unwrap(); - leader - .set_server_random(&mut leader_vm, server_random) - .unwrap(); - follower - .set_server_random(&mut follower_vm, server_random) + tokio::try_join!( + leader.execute_all(&mut ctx_a), + follower.execute_all(&mut ctx_b) + ) .unwrap(); - let leader_cwk = leader_vm - .decode(leader_output.keys.client_write_key) - .unwrap(); - let leader_swk = leader_vm - .decode(leader_output.keys.server_write_key) - .unwrap(); - let leader_civ = leader_vm.decode(leader_output.keys.client_iv).unwrap(); - let leader_siv = leader_vm.decode(leader_output.keys.server_iv).unwrap(); - - let follower_cwk = follower_vm - .decode(follower_output.keys.client_write_key) - .unwrap(); - let follower_swk = follower_vm - .decode(follower_output.keys.server_write_key) - .unwrap(); - let follower_civ = follower_vm.decode(follower_output.keys.client_iv).unwrap(); - let follower_siv = follower_vm.decode(follower_output.keys.server_iv).unwrap(); - - futures::join!( - async { - leader_vm.flush(&mut leader_ctx).await.unwrap(); - leader_vm.execute(&mut leader_ctx).await.unwrap(); - leader_vm.flush(&mut leader_ctx).await.unwrap(); - }, - async { - follower_vm.flush(&mut follower_ctx).await.unwrap(); - follower_vm.execute(&mut follower_ctx).await.unwrap(); - follower_vm.flush(&mut follower_ctx).await.unwrap(); + if leader_finished && follower_finished { + break; } - ); - - let leader_cwk = leader_cwk.await.unwrap(); - let leader_swk = leader_swk.await.unwrap(); - let leader_civ = leader_civ.await.unwrap(); - let leader_siv = leader_siv.await.unwrap(); - - let follower_cwk = follower_cwk.await.unwrap(); - let follower_swk = follower_swk.await.unwrap(); - let follower_civ = follower_civ.await.unwrap(); - let follower_siv = follower_siv.await.unwrap(); - - let (expected_cwk, expected_swk, expected_civ, expected_siv) = - session_keys(pms, client_random, server_random); - - assert_eq!(leader_cwk, expected_cwk); - assert_eq!(leader_swk, expected_swk); - assert_eq!(leader_civ, expected_civ); - assert_eq!(leader_siv, expected_siv); - - assert_eq!(follower_cwk, expected_cwk); - assert_eq!(follower_swk, expected_swk); - assert_eq!(follower_civ, expected_civ); - assert_eq!(follower_siv, expected_siv); - - let cf_hs_hash = [1u8; 32]; - let sf_hs_hash = [2u8; 32]; - - leader.set_cf_hash(&mut leader_vm, cf_hs_hash).unwrap(); - leader.set_sf_hash(&mut leader_vm, sf_hs_hash).unwrap(); - - follower.set_cf_hash(&mut follower_vm, cf_hs_hash).unwrap(); - follower.set_sf_hash(&mut follower_vm, sf_hs_hash).unwrap(); - - let leader_cf_vd = leader_vm.decode(leader_output.cf_vd).unwrap(); - let leader_sf_vd = leader_vm.decode(leader_output.sf_vd).unwrap(); - - let follower_cf_vd = follower_vm.decode(follower_output.cf_vd).unwrap(); - let follower_sf_vd = follower_vm.decode(follower_output.sf_vd).unwrap(); - - futures::join!( - async { - leader_vm.flush(&mut leader_ctx).await.unwrap(); - leader_vm.execute(&mut leader_ctx).await.unwrap(); - leader_vm.flush(&mut leader_ctx).await.unwrap(); - }, - async { - follower_vm.flush(&mut follower_ctx).await.unwrap(); - follower_vm.execute(&mut follower_ctx).await.unwrap(); - follower_vm.flush(&mut follower_ctx).await.unwrap(); - } - ); - - let leader_cf_vd = leader_cf_vd.await.unwrap(); - let leader_sf_vd = leader_sf_vd.await.unwrap(); - - let follower_cf_vd = follower_cf_vd.await.unwrap(); - let follower_sf_vd = follower_sf_vd.await.unwrap(); + } - let expected_cf_vd = compute_vd(ms, b"client finished", cf_hs_hash); - let expected_sf_vd = compute_vd(ms, b"server finished", sf_hs_hash); + let sf_vd_leader = sf_vd_leader.try_recv().unwrap().unwrap(); + let sf_vd_follower = sf_vd_follower.try_recv().unwrap().unwrap(); - assert_eq!(leader_cf_vd, expected_cf_vd); - assert_eq!(leader_sf_vd, expected_sf_vd); - assert_eq!(follower_cf_vd, expected_cf_vd); - assert_eq!(follower_sf_vd, expected_sf_vd); + assert_eq!(sf_vd_leader, sf_vd_follower); + assert_eq!(sf_vd_leader, sf_vd_expected); } } diff --git a/crates/components/hmac-sha256/src/prf.rs b/crates/components/hmac-sha256/src/prf.rs deleted file mode 100644 index 5081ce6639..0000000000 --- a/crates/components/hmac-sha256/src/prf.rs +++ /dev/null @@ -1,312 +0,0 @@ -use std::{ - fmt::Debug, - sync::{Arc, OnceLock}, -}; - -use hmac_sha256_circuits::{build_session_keys, build_verify_data}; -use mpz_circuits::Circuit; -use mpz_common::cpu::CpuBackend; -use mpz_vm_core::{ - memory::{ - binary::{Binary, U32, U8}, - Array, - }, - prelude::*, - Call, Vm, -}; -use tracing::instrument; - -use crate::{PrfConfig, PrfError, PrfOutput, Role, SessionKeys, CF_LABEL, SF_LABEL}; - -pub(crate) struct Circuits { - session_keys: Arc, - client_vd: Arc, - server_vd: Arc, -} - -impl Circuits { - pub(crate) async fn get() -> &'static Self { - static CIRCUITS: OnceLock = OnceLock::new(); - if let Some(circuits) = CIRCUITS.get() { - return circuits; - } - - let (session_keys, client_vd, server_vd) = futures::join!( - CpuBackend::blocking(build_session_keys), - CpuBackend::blocking(|| build_verify_data(CF_LABEL)), - CpuBackend::blocking(|| build_verify_data(SF_LABEL)), - ); - - _ = CIRCUITS.set(Circuits { - session_keys, - client_vd, - server_vd, - }); - - CIRCUITS.get().unwrap() - } -} - -#[derive(Debug)] -pub(crate) enum State { - Initialized, - SessionKeys { - client_random: Array, - server_random: Array, - cf_hash: Array, - sf_hash: Array, - }, - ClientFinished { - cf_hash: Array, - sf_hash: Array, - }, - ServerFinished { - sf_hash: Array, - }, - Complete, - Error, -} - -impl State { - fn take(&mut self) -> State { - std::mem::replace(self, State::Error) - } -} - -/// MPC PRF for computing TLS HMAC-SHA256 PRF. -pub struct MpcPrf { - config: PrfConfig, - state: State, -} - -impl Debug for MpcPrf { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("MpcPrf") - .field("config", &self.config) - .field("state", &self.state) - .finish() - } -} - -impl MpcPrf { - /// Creates a new instance of the PRF. - pub fn new(config: PrfConfig) -> MpcPrf { - MpcPrf { - config, - state: State::Initialized, - } - } - - /// Allocates resources for the PRF. - /// - /// # Arguments - /// - /// * `vm` - Virtual machine. - /// * `pms` - The pre-master secret. - #[instrument(level = "debug", skip_all, err)] - pub fn alloc( - &mut self, - vm: &mut dyn Vm, - pms: Array, - ) -> Result { - let State::Initialized = self.state.take() else { - return Err(PrfError::state("PRF not in initialized state")); - }; - - let circuits = futures::executor::block_on(Circuits::get()); - - let client_random = vm.alloc().map_err(PrfError::vm)?; - let server_random = vm.alloc().map_err(PrfError::vm)?; - - // The client random is kept private so that the handshake transcript - // hashes do not leak information about the server's identity. - match self.config.role { - Role::Leader => vm.mark_private(client_random), - Role::Follower => vm.mark_blind(client_random), - } - .map_err(PrfError::vm)?; - - vm.mark_public(server_random).map_err(PrfError::vm)?; - - #[allow(clippy::type_complexity)] - let ( - client_write_key, - server_write_key, - client_iv, - server_iv, - ms_outer_hash_state, - ms_inner_hash_state, - ): ( - Array, - Array, - Array, - Array, - Array, - Array, - ) = vm - .call( - Call::builder(circuits.session_keys.clone()) - .arg(pms) - .arg(client_random) - .arg(server_random) - .build() - .map_err(PrfError::vm)?, - ) - .map_err(PrfError::vm)?; - - let keys = SessionKeys { - client_write_key, - server_write_key, - client_iv, - server_iv, - }; - - let cf_hash = vm.alloc().map_err(PrfError::vm)?; - vm.mark_public(cf_hash).map_err(PrfError::vm)?; - - let cf_vd = vm - .call( - Call::builder(circuits.client_vd.clone()) - .arg(ms_outer_hash_state) - .arg(ms_inner_hash_state) - .arg(cf_hash) - .build() - .map_err(PrfError::vm)?, - ) - .map_err(PrfError::vm)?; - - let sf_hash = vm.alloc().map_err(PrfError::vm)?; - vm.mark_public(sf_hash).map_err(PrfError::vm)?; - - let sf_vd = vm - .call( - Call::builder(circuits.server_vd.clone()) - .arg(ms_outer_hash_state) - .arg(ms_inner_hash_state) - .arg(sf_hash) - .build() - .map_err(PrfError::vm)?, - ) - .map_err(PrfError::vm)?; - - self.state = State::SessionKeys { - client_random, - server_random, - cf_hash, - sf_hash, - }; - - Ok(PrfOutput { keys, cf_vd, sf_vd }) - } - - /// Sets the client random. - /// - /// Only the leader can provide the client random. - /// - /// # Arguments - /// - /// * `vm` - Virtual machine. - /// * `client_random` - The client random. - #[instrument(level = "debug", skip_all, err)] - pub fn set_client_random( - &mut self, - vm: &mut dyn Vm, - random: Option<[u8; 32]>, - ) -> Result<(), PrfError> { - let State::SessionKeys { client_random, .. } = &self.state else { - return Err(PrfError::state("PRF not set up")); - }; - - if self.config.role == Role::Leader { - let Some(random) = random else { - return Err(PrfError::role("leader must provide client random")); - }; - - vm.assign(*client_random, random).map_err(PrfError::vm)?; - } else if random.is_some() { - return Err(PrfError::role("only leader can set client random")); - } - - vm.commit(*client_random).map_err(PrfError::vm)?; - - Ok(()) - } - - /// Sets the server random. - /// - /// # Arguments - /// - /// * `vm` - Virtual machine. - /// * `server_random` - The server random. - #[instrument(level = "debug", skip_all, err)] - pub fn set_server_random( - &mut self, - vm: &mut dyn Vm, - random: [u8; 32], - ) -> Result<(), PrfError> { - let State::SessionKeys { - server_random, - cf_hash, - sf_hash, - .. - } = self.state.take() - else { - return Err(PrfError::state("PRF not set up")); - }; - - vm.assign(server_random, random).map_err(PrfError::vm)?; - vm.commit(server_random).map_err(PrfError::vm)?; - - self.state = State::ClientFinished { cf_hash, sf_hash }; - - Ok(()) - } - - /// Sets the client finished handshake hash. - /// - /// # Arguments - /// - /// * `vm` - Virtual machine. - /// * `handshake_hash` - The handshake transcript hash. - #[instrument(level = "debug", skip_all, err)] - pub fn set_cf_hash( - &mut self, - vm: &mut dyn Vm, - handshake_hash: [u8; 32], - ) -> Result<(), PrfError> { - let State::ClientFinished { cf_hash, sf_hash } = self.state.take() else { - return Err(PrfError::state("PRF not in client finished state")); - }; - - vm.assign(cf_hash, handshake_hash).map_err(PrfError::vm)?; - vm.commit(cf_hash).map_err(PrfError::vm)?; - - self.state = State::ServerFinished { sf_hash }; - - Ok(()) - } - - /// Sets the server finished handshake hash. - /// - /// # Arguments - /// - /// * `vm` - Virtual machine. - /// * `handshake_hash` - The handshake transcript hash. - #[instrument(level = "debug", skip_all, err)] - pub fn set_sf_hash( - &mut self, - vm: &mut dyn Vm, - handshake_hash: [u8; 32], - ) -> Result<(), PrfError> { - let State::ServerFinished { sf_hash } = self.state.take() else { - return Err(PrfError::state("PRF not in server finished state")); - }; - - vm.assign(sf_hash, handshake_hash).map_err(PrfError::vm)?; - vm.commit(sf_hash).map_err(PrfError::vm)?; - - self.state = State::Complete; - - Ok(()) - } -} diff --git a/crates/components/hmac-sha256/src/prf/function/local.rs b/crates/components/hmac-sha256/src/prf/function/local.rs new file mode 100644 index 0000000000..63e31f1782 --- /dev/null +++ b/crates/components/hmac-sha256/src/prf/function/local.rs @@ -0,0 +1,294 @@ +//! Computes some hashes of the PRF locally. + +use crate::{convert_to_bytes, hmac::HmacSha256, sha256::sha256, PrfError}; +use mpz_vm_core::{ + memory::{ + binary::{Binary, U32, U8}, + Array, DecodeFutureTyped, MemoryExt, MemoryType, Repr, ViewExt, + }, + Vm, +}; + +#[derive(Debug)] +pub(crate) struct PrfFunction { + label: &'static [u8], + start_seed_label: Option>, + a: Vec, + p: Vec, +} + +impl PrfFunction { + const MS_LABEL: &[u8] = b"master secret"; + const KEY_LABEL: &[u8] = b"key expansion"; + const CF_LABEL: &[u8] = b"client finished"; + const SF_LABEL: &[u8] = b"server finished"; + + pub(crate) fn alloc_master_secret( + vm: &mut dyn Vm, + outer_partial: Array, + inner_partial: Array, + ) -> Result { + Self::alloc(vm, Self::MS_LABEL, outer_partial, inner_partial, 48) + } + + pub(crate) fn alloc_key_expansion( + vm: &mut dyn Vm, + outer_partial: Array, + inner_partial: Array, + ) -> Result { + Self::alloc(vm, Self::KEY_LABEL, outer_partial, inner_partial, 40) + } + + pub(crate) fn alloc_client_finished( + vm: &mut dyn Vm, + outer_partial: Array, + inner_partial: Array, + ) -> Result { + Self::alloc(vm, Self::CF_LABEL, outer_partial, inner_partial, 12) + } + + pub(crate) fn alloc_server_finished( + vm: &mut dyn Vm, + outer_partial: Array, + inner_partial: Array, + ) -> Result { + Self::alloc(vm, Self::SF_LABEL, outer_partial, inner_partial, 12) + } + + pub(crate) fn make_progress(&mut self, vm: &mut dyn Vm) -> Result { + let a_assigned = self.is_a_assigned(); + let mut p_assigned = self.is_p_assigned(); + + if !a_assigned { + self.poll_a(vm)?; + } + + if !p_assigned { + self.poll_p(vm)?; + p_assigned = self.is_p_assigned(); + } + + Ok(p_assigned) + } + + pub(crate) fn set_start_seed(&mut self, seed: Vec) { + let mut start_seed_label = self.label.to_vec(); + start_seed_label.extend_from_slice(&seed); + + self.start_seed_label = Some(start_seed_label); + } + + pub(crate) fn output(&self) -> Vec> { + self.p.iter().map(|p| p.output.value()).collect() + } + + fn poll_a(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { + let Some(mut message) = self.start_seed_label.clone() else { + return Err(PrfError::state("Starting seed not set for PRF")); + }; + + for a in self.a.iter_mut() { + if let Some(output) = a.output.poll(vm)? { + message = convert_to_bytes(output).to_vec(); + continue; + }; + + let Some(inner_partial) = a.inner_partial.poll(vm)? else { + break; + }; + + a.assign_inner_local(vm, inner_partial, &message)?; + } + + Ok(()) + } + + fn poll_p(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { + let Some(ref start_seed) = self.start_seed_label else { + return Err(PrfError::state("Starting seed not set for PRF")); + }; + + for (i, p) in self.p.iter_mut().enumerate() { + if p.inner_local.1 { + continue; + } + + let Some(message) = self.a[i].output.poll(vm)? else { + break; + }; + + let mut message = convert_to_bytes(message).to_vec(); + message.extend_from_slice(start_seed); + + let Some(inner_partial) = p.inner_partial.poll(vm)? else { + break; + }; + + p.assign_inner_local(vm, inner_partial, &message)?; + } + + Ok(()) + } + + fn alloc( + vm: &mut dyn Vm, + label: &'static [u8], + outer_partial: Array, + inner_partial: Array, + len: usize, + ) -> Result { + let mut prf = Self { + label, + start_seed_label: None, + a: vec![], + p: vec![], + }; + + assert!(len > 0, "cannot compute 0 bytes for prf"); + + let iterations = len / 32 + ((len % 32) != 0) as usize; + + for _ in 0..iterations { + let a = PHash::alloc(vm, outer_partial, inner_partial)?; + prf.a.push(a); + + let p = PHash::alloc(vm, outer_partial, inner_partial)?; + prf.p.push(p); + } + + Ok(prf) + } + + fn is_p_assigned(&self) -> bool { + self.p + .last() + .expect("prf should be allocated") + .inner_local + .1 + } + + fn is_a_assigned(&self) -> bool { + self.a + .last() + .expect("prf should be allocated") + .inner_local + .1 + } +} + +#[derive(Debug)] +struct PHash { + pub(crate) inner_partial: DecodeOperation>, + // the bool tracks if assignment has already happened + pub(crate) inner_local: (Array, bool), + pub(crate) output: DecodeOperation>, +} + +impl PHash { + fn alloc( + vm: &mut dyn Vm, + outer_partial: Array, + inner_partial: Array, + ) -> Result { + let inner_local = vm.alloc().map_err(PrfError::vm)?; + let hmac = HmacSha256::new(outer_partial, inner_local); + + let output = hmac.alloc(vm).map_err(PrfError::vm)?; + + let p_hash = Self { + inner_partial: DecodeOperation::new(inner_partial), + inner_local: (inner_local, false), + output: DecodeOperation::new(output), + }; + + Ok(p_hash) + } + + fn assign_inner_local( + &mut self, + vm: &mut dyn Vm, + inner_partial: [u32; 8], + msg: &[u8], + ) -> Result<(), PrfError> { + if !self.inner_local.1 { + let inner_local_ref: Array = self.inner_local.0; + let inner_local = sha256(inner_partial, 64, msg); + + vm.mark_public(inner_local_ref).map_err(PrfError::vm)?; + vm.assign(inner_local_ref, convert_to_bytes(inner_local)) + .map_err(PrfError::vm)?; + vm.commit(inner_local_ref).map_err(PrfError::vm)?; + + self.inner_local.1 = true + } + + Ok(()) + } +} + +#[derive(Debug)] +struct DecodeOperation +where + T: Repr, +{ + value: T, + progress: DecodeProgress, +} + +impl DecodeOperation +where + T: Repr + Copy, +{ + pub(crate) fn new(value: T) -> Self { + Self { + value, + progress: DecodeProgress::Alloc, + } + } + + pub(crate) fn value(&self) -> T { + self.value + } + + pub(crate) fn poll(&mut self, vm: &mut dyn Vm) -> Result, PrfError> { + self.progress.poll(vm, self.value) + } +} + +#[derive(Debug)] +enum DecodeProgress +where + T: Repr, +{ + Alloc, + Decoded(DecodeFutureTyped<::Raw, T::Clear>), + Finished(T::Clear), +} + +impl DecodeProgress +where + T: Repr + Copy, +{ + pub(crate) fn poll( + &mut self, + vm: &mut dyn Vm, + value: T, + ) -> Result, PrfError> { + match self { + DecodeProgress::Alloc => { + let value = vm.decode(value).map_err(PrfError::vm)?; + *self = DecodeProgress::Decoded(value); + Ok(None) + } + DecodeProgress::Decoded(value) => { + if let Some(value) = value.try_recv().map_err(PrfError::vm)? { + *self = DecodeProgress::Finished(value); + Ok(Some(value)) + } else { + Ok(None) + } + } + DecodeProgress::Finished(value) => Ok(Some(*value)), + } + } +} diff --git a/crates/components/hmac-sha256/src/prf/function/mod.rs b/crates/components/hmac-sha256/src/prf/function/mod.rs new file mode 100644 index 0000000000..7c1b622ebc --- /dev/null +++ b/crates/components/hmac-sha256/src/prf/function/mod.rs @@ -0,0 +1,247 @@ +//! Provides [`Prf`], for computing the TLS 1.2 PRF. + +use crate::{Config, PrfError}; +use mpz_vm_core::{ + memory::{ + binary::{Binary, U32}, + Array, + }, + Vm, +}; + +mod local; +mod mpc; + +#[derive(Debug)] +pub(crate) enum Prf { + Local(local::PrfFunction), + Mpc(mpc::PrfFunction), +} + +impl Prf { + pub(crate) fn alloc_master_secret( + config: Config, + vm: &mut dyn Vm, + outer_partial: Array, + inner_partial: Array, + ) -> Result { + let prf = match config { + Config::Local => Self::Local(local::PrfFunction::alloc_master_secret( + vm, + outer_partial, + inner_partial, + )?), + Config::Mpc => Self::Mpc(mpc::PrfFunction::alloc_master_secret( + vm, + outer_partial, + inner_partial, + )?), + }; + Ok(prf) + } + + pub(crate) fn alloc_key_expansion( + config: Config, + vm: &mut dyn Vm, + outer_partial: Array, + inner_partial: Array, + ) -> Result { + let prf = match config { + Config::Local => Self::Local(local::PrfFunction::alloc_key_expansion( + vm, + outer_partial, + inner_partial, + )?), + Config::Mpc => Self::Mpc(mpc::PrfFunction::alloc_key_expansion( + vm, + outer_partial, + inner_partial, + )?), + }; + Ok(prf) + } + + pub(crate) fn alloc_client_finished( + config: Config, + vm: &mut dyn Vm, + outer_partial: Array, + inner_partial: Array, + ) -> Result { + let prf = match config { + Config::Local => Self::Local(local::PrfFunction::alloc_client_finished( + vm, + outer_partial, + inner_partial, + )?), + Config::Mpc => Self::Mpc(mpc::PrfFunction::alloc_client_finished( + vm, + outer_partial, + inner_partial, + )?), + }; + Ok(prf) + } + + pub(crate) fn alloc_server_finished( + config: Config, + vm: &mut dyn Vm, + outer_partial: Array, + inner_partial: Array, + ) -> Result { + let prf = match config { + Config::Local => Self::Local(local::PrfFunction::alloc_server_finished( + vm, + outer_partial, + inner_partial, + )?), + Config::Mpc => Self::Mpc(mpc::PrfFunction::alloc_server_finished( + vm, + outer_partial, + inner_partial, + )?), + }; + Ok(prf) + } + + pub(crate) fn make_progress(&mut self, vm: &mut dyn Vm) -> Result { + match self { + Prf::Local(prf) => prf.make_progress(vm), + Prf::Mpc(prf) => prf.make_progress(vm), + } + } + + pub(crate) fn set_start_seed(&mut self, seed: Vec) { + match self { + Prf::Local(prf) => prf.set_start_seed(seed), + Prf::Mpc(prf) => prf.set_start_seed(seed), + } + } + + pub(crate) fn output(&self) -> Vec> { + match self { + Prf::Local(prf) => prf.output(), + Prf::Mpc(prf) => prf.output(), + } + } +} + +#[cfg(test)] +mod tests { + use crate::{ + convert_to_bytes, + prf::{compute_partial, function::Prf}, + test_utils::{mock_vm, phash}, + Config, + }; + use mpz_common::context::test_st_context; + use mpz_vm_core::{ + memory::{binary::U8, Array, MemoryExt, ViewExt}, + Execute, + }; + + const IPAD: [u8; 64] = [0x36; 64]; + const OPAD: [u8; 64] = [0x5c; 64]; + + #[tokio::test] + async fn test_phash_local() { + let config = Config::Local; + test_phash(config).await; + } + + #[tokio::test] + async fn test_phash_mpc() { + let config = Config::Local; + test_phash(config).await; + } + + async fn test_phash(config: Config) { + let (mut ctx_a, mut ctx_b) = test_st_context(8); + let (mut leader, mut follower) = mock_vm(); + + let key: [u8; 32] = std::array::from_fn(|i| i as u8); + let start_seed: Vec = vec![42; 64]; + + let mut label_seed = b"master secret".to_vec(); + label_seed.extend_from_slice(&start_seed); + let iterations = 2; + + let leader_key: Array = leader.alloc().unwrap(); + leader.mark_public(leader_key).unwrap(); + leader.assign(leader_key, key).unwrap(); + leader.commit(leader_key).unwrap(); + + let outer_partial_leader = compute_partial(&mut leader, leader_key.into(), OPAD).unwrap(); + let inner_partial_leader = compute_partial(&mut leader, leader_key.into(), IPAD).unwrap(); + + let mut prf_leader = Prf::alloc_master_secret( + config, + &mut leader, + outer_partial_leader, + inner_partial_leader, + ) + .unwrap(); + prf_leader.set_start_seed(start_seed.clone()); + + let mut prf_out_leader = vec![]; + for p in prf_leader.output() { + let p_out = leader.decode(p).unwrap(); + prf_out_leader.push(p_out) + } + + let follower_key: Array = follower.alloc().unwrap(); + follower.mark_public(follower_key).unwrap(); + follower.assign(follower_key, key).unwrap(); + follower.commit(follower_key).unwrap(); + + let outer_partial_follower = + compute_partial(&mut follower, follower_key.into(), OPAD).unwrap(); + let inner_partial_follower = + compute_partial(&mut follower, follower_key.into(), IPAD).unwrap(); + + let mut prf_follower = Prf::alloc_master_secret( + config, + &mut follower, + outer_partial_follower, + inner_partial_follower, + ) + .unwrap(); + prf_follower.set_start_seed(start_seed.clone()); + + let mut prf_out_follower = vec![]; + for p in prf_follower.output() { + let p_out = follower.decode(p).unwrap(); + prf_out_follower.push(p_out) + } + + loop { + let leader_finished = prf_leader.make_progress(&mut leader).unwrap(); + let follower_finished = prf_follower.make_progress(&mut follower).unwrap(); + + tokio::try_join!( + leader.execute_all(&mut ctx_a), + follower.execute_all(&mut ctx_b) + ) + .unwrap(); + + if leader_finished && follower_finished { + break; + } + } + + assert_eq!(prf_out_leader.len(), prf_out_follower.len()); + + let prf_result_leader: Vec = prf_out_leader + .iter_mut() + .flat_map(|p| convert_to_bytes(p.try_recv().unwrap().unwrap())) + .collect(); + let prf_result_follower: Vec = prf_out_follower + .iter_mut() + .flat_map(|p| convert_to_bytes(p.try_recv().unwrap().unwrap())) + .collect(); + + let expected = phash(key.to_vec(), &label_seed, iterations); + + assert_eq!(prf_result_leader, prf_result_follower); + assert_eq!(prf_result_leader, expected) + } +} diff --git a/crates/components/hmac-sha256/src/prf/function/mpc.rs b/crates/components/hmac-sha256/src/prf/function/mpc.rs new file mode 100644 index 0000000000..9b1c04b7bb --- /dev/null +++ b/crates/components/hmac-sha256/src/prf/function/mpc.rs @@ -0,0 +1,205 @@ +//! Computes the whole PRF in MPC. + +use crate::{hmac::HmacSha256, sha256::Sha256, PrfError}; +use mpz_circuits::CircuitBuilder; +use mpz_vm_core::{ + memory::{ + binary::{Binary, U32, U8}, + Array, MemoryExt, Vector, ViewExt, + }, + Call, CallableExt, Vm, +}; +use std::sync::Arc; + +#[derive(Debug)] +pub(crate) struct PrfFunction { + label: &'static [u8], + start_seed_label: Option>, + a: Vec, + p: Vec, + assigned: bool, +} + +impl PrfFunction { + const MS_LABEL: &[u8] = b"master secret"; + const KEY_LABEL: &[u8] = b"key expansion"; + const CF_LABEL: &[u8] = b"client finished"; + const SF_LABEL: &[u8] = b"server finished"; + + pub(crate) fn alloc_master_secret( + vm: &mut dyn Vm, + outer_partial: Array, + inner_partial: Array, + ) -> Result { + Self::alloc(vm, Self::MS_LABEL, outer_partial, inner_partial, 48, 64) + } + + pub(crate) fn alloc_key_expansion( + vm: &mut dyn Vm, + outer_partial: Array, + inner_partial: Array, + ) -> Result { + Self::alloc(vm, Self::KEY_LABEL, outer_partial, inner_partial, 40, 64) + } + + pub(crate) fn alloc_client_finished( + vm: &mut dyn Vm, + outer_partial: Array, + inner_partial: Array, + ) -> Result { + Self::alloc(vm, Self::CF_LABEL, outer_partial, inner_partial, 12, 32) + } + + pub(crate) fn alloc_server_finished( + vm: &mut dyn Vm, + outer_partial: Array, + inner_partial: Array, + ) -> Result { + Self::alloc(vm, Self::SF_LABEL, outer_partial, inner_partial, 12, 32) + } + + pub(crate) fn make_progress(&mut self, vm: &mut dyn Vm) -> Result { + if !self.assigned { + let a = self.a.first_mut().expect("prf should be allocated"); + let msg = a.msg; + + let msg_value = self + .start_seed_label + .clone() + .expect("seed should be assigned by now"); + + vm.mark_public(msg).map_err(PrfError::vm)?; + vm.assign(msg, msg_value).map_err(PrfError::vm)?; + vm.commit(msg).map_err(PrfError::vm)?; + self.assigned = true; + } + + Ok(self.assigned) + } + + pub(crate) fn set_start_seed(&mut self, seed: Vec) { + let mut start_seed_label = self.label.to_vec(); + start_seed_label.extend_from_slice(&seed); + + self.start_seed_label = Some(start_seed_label); + } + + pub(crate) fn output(&self) -> Vec> { + self.p.iter().map(|p| p.output).collect() + } + + fn alloc( + vm: &mut dyn Vm, + label: &'static [u8], + outer_partial: Array, + inner_partial: Array, + output_len: usize, + seed_len: usize, + ) -> Result { + let mut prf = Self { + label, + start_seed_label: None, + a: vec![], + p: vec![], + assigned: false, + }; + + assert!(output_len > 0, "cannot compute 0 bytes for prf"); + + let iterations = output_len / 32 + ((output_len % 32) != 0) as usize; + + let msg_len_a = label.len() + seed_len; + let seed_label_ref: Vector = vm.alloc_vec(msg_len_a).map_err(PrfError::vm)?; + let mut msg_a = seed_label_ref; + + for _ in 0..iterations { + let a = PHash::alloc(vm, outer_partial, inner_partial, msg_a)?; + msg_a = convert_array(vm, a.output)?.into(); + prf.a.push(a); + + let msg_p = merge_vecs(vm, vec![msg_a, seed_label_ref])?; + let p = PHash::alloc(vm, outer_partial, inner_partial, msg_p)?; + prf.p.push(p); + } + + Ok(prf) + } +} + +#[derive(Debug, Clone)] +struct PHash { + pub(crate) msg: Vector, + pub(crate) output: Array, +} + +impl PHash { + fn alloc( + vm: &mut dyn Vm, + outer_partial: Array, + inner_partial: Array, + msg: Vector, + ) -> Result { + let mut sha = Sha256::default(); + sha.set_state(inner_partial, 64) + .update(msg) + .add_padding(vm)?; + + let inner_local = sha.alloc(vm)?; + let inner_local = convert_array(vm, inner_local)?; + + let hmac = HmacSha256::new(outer_partial, inner_local); + let output = hmac.alloc(vm).map_err(PrfError::vm)?; + + let p_hash = Self { msg, output }; + Ok(p_hash) + } +} + +fn convert_array(vm: &mut dyn Vm, input: Array) -> Result, PrfError> { + let circ = { + let mut builder = CircuitBuilder::new(); + let inputs = (0..32 * 8).map(|_| builder.add_input()).collect::>(); + + for input in inputs.chunks_exact(4 * 8) { + for byte in input.chunks_exact(8).rev() { + for &feed in byte.iter() { + let output = builder.add_id_gate(feed); + builder.add_output(output); + } + } + } + + Arc::new(builder.build().expect("conversion circuit is valid")) + }; + + let mut builder = Call::builder(circ); + builder = builder.arg(input); + let call = builder.build().map_err(PrfError::vm)?; + + vm.call(call).map_err(PrfError::vm) +} + +fn merge_vecs(vm: &mut dyn Vm, inputs: Vec>) -> Result, PrfError> { + let len: usize = inputs.iter().map(|inp| inp.len()).sum(); + let circ = { + let mut builder = CircuitBuilder::new(); + + let feeds = (0..len * 8) + .map(|_| builder.add_input()) + .collect::>(); + for feed in feeds { + let output = builder.add_id_gate(feed); + builder.add_output(output); + } + + Arc::new(builder.build().expect("merge circuit is valid")) + }; + + let mut builder = Call::builder(circ); + for input in inputs { + builder = builder.arg(input); + } + let call = builder.build().map_err(PrfError::vm)?; + + vm.call(call).map_err(PrfError::vm) +} diff --git a/crates/components/hmac-sha256/src/prf/mod.rs b/crates/components/hmac-sha256/src/prf/mod.rs new file mode 100644 index 0000000000..248e5f2830 --- /dev/null +++ b/crates/components/hmac-sha256/src/prf/mod.rs @@ -0,0 +1,466 @@ +use crate::{sha256::Sha256, Config, PrfError, PrfOutput, SessionKeys}; +use mpz_circuits::{circuits::xor, Circuit, CircuitBuilder}; +use mpz_vm_core::{ + memory::{ + binary::{Binary, U32, U8}, + Array, FromRaw, MemoryExt, StaticSize, ToRaw, Vector, ViewExt, + }, + Call, CallableExt, Vm, +}; +use std::{fmt::Debug, sync::Arc}; +use tracing::instrument; + +mod state; +use state::State; + +pub(crate) mod function; +use function::Prf; + +/// MPC PRF for computing TLS 1.2 HMAC-SHA256 PRF. +#[derive(Debug)] +pub struct MpcPrf { + config: Config, + state: State, + circuits: Option, +} + +impl MpcPrf { + /// Creates a new instance of the PRF. + /// + /// # Arguments + /// + /// `config` - The PRF config. + pub fn new(config: Config) -> MpcPrf { + Self { + config, + state: State::Initialized, + circuits: None, + } + } + + /// Allocates resources for the PRF. + /// + /// # Arguments + /// + /// * `vm` - Virtual machine. + /// * `pms` - The pre-master secret. + #[instrument(level = "debug", skip_all, err)] + pub fn alloc( + &mut self, + vm: &mut dyn Vm, + pms: Array, + ) -> Result { + let State::Initialized = self.state.take() else { + return Err(PrfError::state("PRF not in initialized state")); + }; + + let circuits = Circuits::alloc(self.config, vm, pms.into())?; + + let keys = circuits.get_session_keys(vm)?; + let cf_vd = circuits.get_client_finished_vd(vm)?; + let sf_vd = circuits.get_server_finished_vd(vm)?; + + let prf_output = PrfOutput { keys, cf_vd, sf_vd }; + + self.circuits = Some(circuits); + self.state = State::SessionKeys { + client_random: None, + }; + + Ok(prf_output) + } + + /// Sets the client random. + /// + /// # Arguments + /// + /// * `random` - The client random. + #[instrument(level = "debug", skip_all, err)] + pub fn set_client_random(&mut self, random: [u8; 32]) -> Result<(), PrfError> { + let State::SessionKeys { client_random } = &mut self.state else { + return Err(PrfError::state("PRF not set up")); + }; + + *client_random = Some(random); + Ok(()) + } + + /// Sets the server random. + /// + /// # Arguments + /// + /// * `random` - The server random. + #[instrument(level = "debug", skip_all, err)] + pub fn set_server_random(&mut self, random: [u8; 32]) -> Result<(), PrfError> { + let State::SessionKeys { client_random } = self.state.take() else { + return Err(PrfError::state("PRF not set up")); + }; + + let Some(ref mut circuits) = self.circuits else { + return Err(PrfError::state("Circuits should have been set for PRF")); + }; + + let client_random = client_random.expect("Client random should have been set by now"); + let server_random = random; + + let mut seed_ms = client_random.to_vec(); + seed_ms.extend_from_slice(&server_random); + circuits.master_secret.set_start_seed(seed_ms); + + let mut seed_ke = server_random.to_vec(); + seed_ke.extend_from_slice(&client_random); + circuits.key_expansion.set_start_seed(seed_ke); + + self.state = State::ClientFinished; + Ok(()) + } + + /// Sets the client finished handshake hash. + /// + /// # Arguments + /// + /// * `handshake_hash` - The handshake transcript hash. + #[instrument(level = "debug", skip_all, err)] + pub fn set_cf_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), PrfError> { + let State::ClientFinished = self.state.take() else { + return Err(PrfError::state("PRF not in client finished state")); + }; + + let Some(ref mut circuits) = self.circuits else { + return Err(PrfError::state("Circuits should have been set for PRF")); + }; + + let seed_cf = handshake_hash.to_vec(); + circuits.client_finished.set_start_seed(seed_cf); + + self.state = State::ServerFinished; + Ok(()) + } + + /// Sets the server finished handshake hash. + /// + /// # Arguments + /// + /// * `handshake_hash` - The handshake transcript hash. + #[instrument(level = "debug", skip_all, err)] + pub fn set_sf_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), PrfError> { + let State::ServerFinished = self.state.take() else { + return Err(PrfError::state("PRF not in server finished state")); + }; + + let Some(ref mut circuits) = self.circuits else { + return Err(PrfError::state("Circuits should have been set for PRF")); + }; + + let seed_sf = handshake_hash.to_vec(); + circuits.server_finished.set_start_seed(seed_sf); + + self.state = State::Complete; + Ok(()) + } + + /// Drives the computation of the session keys. + /// + /// Returns if all inputs have been assigned for the computation of the + /// final output. + /// + /// # Arguments + /// + /// * `vm` - Virtual machine. + #[instrument(level = "debug", skip_all, err)] + pub fn drive_key_expansion(&mut self, vm: &mut dyn Vm) -> Result { + let Some(ref mut circuits) = self.circuits else { + return Err(PrfError::state("Circuits should have been set for PRF")); + }; + + circuits.drive_key_expansion(vm) + } + + /// Drives the computation of the client_finished verify_data. + /// + /// Returns if all inputs have been assigned for the computation of the + /// final output. + /// + /// # Arguments + /// + /// * `vm` - Virtual machine. + #[instrument(level = "debug", skip_all, err)] + pub fn drive_client_finished(&mut self, vm: &mut dyn Vm) -> Result { + let Some(ref mut circuits) = self.circuits else { + return Err(PrfError::state("Circuits should have been set for PRF")); + }; + + circuits.drive_client_finished(vm) + } + + /// Drives the computation of the server_finished verify_data. + /// + /// Returns if all inputs have been assigned for the computation of the + /// final output. + /// + /// # Arguments + /// + /// * `vm` - Virtual machine. + #[instrument(level = "debug", skip_all, err)] + pub fn drive_server_finished(&mut self, vm: &mut dyn Vm) -> Result { + let Some(ref mut circuits) = self.circuits else { + return Err(PrfError::state("Circuits should have been set for PRF")); + }; + + circuits.drive_server_finished(vm) + } +} + +/// Contains the respective [`PrfFunction`]s. +#[derive(Debug)] +struct Circuits { + pub(crate) master_secret: Prf, + pub(crate) key_expansion: Prf, + pub(crate) client_finished: Prf, + pub(crate) server_finished: Prf, +} + +impl Circuits { + const IPAD: [u8; 64] = [0x36; 64]; + const OPAD: [u8; 64] = [0x5c; 64]; + + fn alloc(config: Config, vm: &mut dyn Vm, pms: Vector) -> Result { + let outer_partial_pms = compute_partial(vm, pms, Self::OPAD)?; + let inner_partial_pms = compute_partial(vm, pms, Self::IPAD)?; + + let master_secret = + Prf::alloc_master_secret(config, vm, outer_partial_pms, inner_partial_pms)?; + let ms = master_secret.output(); + let ms = merge_outputs(vm, ms, 48)?; + + let outer_partial_ms = compute_partial(vm, ms, Self::OPAD)?; + let inner_partial_ms = compute_partial(vm, ms, Self::IPAD)?; + + let circuits = Self { + master_secret, + key_expansion: Prf::alloc_key_expansion( + config, + vm, + outer_partial_ms, + inner_partial_ms, + )?, + client_finished: Prf::alloc_client_finished( + config, + vm, + outer_partial_ms, + inner_partial_ms, + )?, + server_finished: Prf::alloc_server_finished( + config, + vm, + outer_partial_ms, + inner_partial_ms, + )?, + }; + Ok(circuits) + } + + fn get_session_keys(&self, vm: &mut dyn Vm) -> Result { + let keys = self.key_expansion.output(); + let mut keys = merge_outputs(vm, keys, 40)?; + + let server_iv = as FromRaw>::from_raw(keys.split_off(36).to_raw()); + let client_iv = as FromRaw>::from_raw(keys.split_off(32).to_raw()); + let server_write_key = + as FromRaw>::from_raw(keys.split_off(16).to_raw()); + let client_write_key = as FromRaw>::from_raw(keys.to_raw()); + + let session_keys = SessionKeys { + client_write_key, + server_write_key, + client_iv, + server_iv, + }; + + Ok(session_keys) + } + + fn get_client_finished_vd(&self, vm: &mut dyn Vm) -> Result, PrfError> { + let client_finished = &self.client_finished; + let cf_vd = client_finished.output(); + + let cf_vd = merge_outputs(vm, cf_vd, 12)?; + let cf_vd = as FromRaw>::from_raw(cf_vd.to_raw()); + + Ok(cf_vd) + } + + fn get_server_finished_vd(&self, vm: &mut dyn Vm) -> Result, PrfError> { + let server_finished = &self.server_finished; + let sf_vd = server_finished.output(); + + let sf_vd = merge_outputs(vm, sf_vd, 12)?; + let sf_vd = as FromRaw>::from_raw(sf_vd.to_raw()); + + Ok(sf_vd) + } + + fn drive_key_expansion(&mut self, vm: &mut dyn Vm) -> Result { + let ms_finished = self.master_secret.make_progress(vm)?; + let ke_finished = self.key_expansion.make_progress(vm)?; + + Ok(ms_finished && ke_finished) + } + + fn drive_client_finished(&mut self, vm: &mut dyn Vm) -> Result { + self.client_finished.make_progress(vm) + } + + fn drive_server_finished(&mut self, vm: &mut dyn Vm) -> Result { + self.server_finished.make_progress(vm) + } +} + +/// Depending on the provided `mask` computes and returns `outer_partial` or +/// `inner_partial` for HMAC-SHA256. +/// +/// # Arguments +/// +/// * `vm` - Virtual machine. +/// * `key` - Key to pad and xor. +/// * `mask`- Mask used for padding. +fn compute_partial( + vm: &mut dyn Vm, + key: Vector, + mask: [u8; 64], +) -> Result, PrfError> { + let xor = Arc::new(xor(8 * 64)); + + let additional_len = 64 - key.len(); + let padding = vec![0_u8; additional_len]; + + let padding_ref: Vector = vm.alloc_vec(additional_len).map_err(PrfError::vm)?; + vm.mark_public(padding_ref).map_err(PrfError::vm)?; + vm.assign(padding_ref, padding).map_err(PrfError::vm)?; + vm.commit(padding_ref).map_err(PrfError::vm)?; + + let mask_ref: Array = vm.alloc().map_err(PrfError::vm)?; + vm.mark_public(mask_ref).map_err(PrfError::vm)?; + vm.assign(mask_ref, mask).map_err(PrfError::vm)?; + vm.commit(mask_ref).map_err(PrfError::vm)?; + + let xor = Call::builder(xor) + .arg(key) + .arg(padding_ref) + .arg(mask_ref) + .build() + .map_err(PrfError::vm)?; + let key_padded = vm.call(xor).map_err(PrfError::vm)?; + + let mut sha = Sha256::default(); + sha.update(key_padded); + sha.alloc(vm) +} + +fn merge_outputs( + vm: &mut dyn Vm, + inputs: Vec>, + output_bytes: usize, +) -> Result, PrfError> { + assert!(output_bytes <= 32 * inputs.len()); + + let bits = Array::::SIZE * inputs.len(); + let msb0_circ = gen_merge_circ(4, bits); + + let mut builder = Call::builder(msb0_circ); + for &input in inputs.iter() { + builder = builder.arg(input); + } + let call = builder.build().map_err(PrfError::vm)?; + + let mut output: Vector = vm.call(call).map_err(PrfError::vm)?; + output.truncate(output_bytes); + Ok(output) +} + +fn gen_merge_circ(element_byte_size: usize, size: usize) -> Arc { + assert!((size / 8) % element_byte_size == 0); + + let mut builder = CircuitBuilder::new(); + let inputs = (0..size).map(|_| builder.add_input()).collect::>(); + + for input in inputs.chunks_exact(element_byte_size * 8) { + for byte in input.chunks_exact(8).rev() { + for &feed in byte.iter() { + let output = builder.add_id_gate(feed); + builder.add_output(output); + } + } + } + + Arc::new(builder.build().expect("merge circuit is valid")) +} + +#[cfg(test)] +mod tests { + use crate::{convert_to_bytes, prf::merge_outputs, test_utils::mock_vm}; + use mpz_common::context::test_st_context; + use mpz_vm_core::{ + memory::{binary::U32, Array, MemoryExt, ViewExt}, + Execute, + }; + + #[tokio::test] + async fn test_merge_outputs() { + let (mut ctx_a, mut ctx_b) = test_st_context(8); + let (mut leader, mut follower) = mock_vm(); + + let input1: [u32; 8] = std::array::from_fn(|i| i as u32); + let input2: [u32; 8] = std::array::from_fn(|i| i as u32 + 8); + + let mut expected = convert_to_bytes(input1).to_vec(); + expected.extend_from_slice(&convert_to_bytes(input2)); + expected.truncate(48); + + // leader + let input1_leader: Array = leader.alloc().unwrap(); + let input2_leader: Array = leader.alloc().unwrap(); + + leader.mark_public(input1_leader).unwrap(); + leader.mark_public(input2_leader).unwrap(); + + leader.assign(input1_leader, input1).unwrap(); + leader.assign(input2_leader, input2).unwrap(); + + leader.commit(input1_leader).unwrap(); + leader.commit(input2_leader).unwrap(); + + let merged_leader = + merge_outputs(&mut leader, vec![input1_leader, input2_leader], 48).unwrap(); + let mut merged_leader = leader.decode(merged_leader).unwrap(); + + // follower + let input1_follower: Array = follower.alloc().unwrap(); + let input2_follower: Array = follower.alloc().unwrap(); + + follower.mark_public(input1_follower).unwrap(); + follower.mark_public(input2_follower).unwrap(); + + follower.assign(input1_follower, input1).unwrap(); + follower.assign(input2_follower, input2).unwrap(); + + follower.commit(input1_follower).unwrap(); + follower.commit(input2_follower).unwrap(); + + let merged_follower = + merge_outputs(&mut follower, vec![input1_follower, input2_follower], 48).unwrap(); + let mut merged_follower = follower.decode(merged_follower).unwrap(); + + tokio::try_join!( + leader.execute_all(&mut ctx_a), + follower.execute_all(&mut ctx_b) + ) + .unwrap(); + + let merged_leader = merged_leader.try_recv().unwrap().unwrap(); + let merged_follower = merged_follower.try_recv().unwrap().unwrap(); + + assert_eq!(merged_leader, merged_follower); + assert_eq!(merged_leader, expected); + } +} diff --git a/crates/components/hmac-sha256/src/prf/state.rs b/crates/components/hmac-sha256/src/prf/state.rs new file mode 100644 index 0000000000..43cdca42d1 --- /dev/null +++ b/crates/components/hmac-sha256/src/prf/state.rs @@ -0,0 +1,15 @@ +#[derive(Debug)] +pub(crate) enum State { + Initialized, + SessionKeys { client_random: Option<[u8; 32]> }, + ClientFinished, + ServerFinished, + Complete, + Error, +} + +impl State { + pub(crate) fn take(&mut self) -> State { + std::mem::replace(self, State::Error) + } +} diff --git a/crates/components/hmac-sha256/src/sha256.rs b/crates/components/hmac-sha256/src/sha256.rs new file mode 100644 index 0000000000..752a5f5961 --- /dev/null +++ b/crates/components/hmac-sha256/src/sha256.rs @@ -0,0 +1,381 @@ +//! Computation of SHA256. + +use crate::PrfError; +use mpz_circuits::circuits::SHA256_COMPRESS; +use mpz_vm_core::{ + memory::{ + binary::{Binary, U32, U8}, + Array, MemoryExt, Vector, ViewExt, + }, + Call, CallableExt, Vm, +}; + +/// Computes SHA256. +#[derive(Debug, Default)] +pub(crate) struct Sha256 { + state: Option>, + chunks: Vec>, + processed: usize, +} + +impl Sha256 { + /// The default initialization vector. + const IV: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, + 0x5be0cd19, + ]; + + /// Sets the state. + /// + /// # Arguments + /// + /// * `state` - The starting state for the SHA256 compression function. + /// * `processed` - The number of already processed bytes corresponding to + /// `state`. + pub(crate) fn set_state(&mut self, state: Array, processed: usize) -> &mut Self { + self.state = Some(state); + self.processed = processed; + self + } + + /// Feeds data into the hash function. + /// + /// # Arguments + /// + /// * `data` - The data to hash. + pub(crate) fn update(&mut self, data: Vector) -> &mut Self { + self.chunks.push(data); + self + } + + /// Computes the padding for SHA256. + /// + /// Padding is computed depending on [`Self::state`] and + /// [`Self::processed`]. + /// + /// # Arguments + /// + /// * `vm` - The virtual machine. + pub(crate) fn add_padding(&mut self, vm: &mut dyn Vm) -> Result<&mut Self, PrfError> { + let msg_len: usize = self.chunks.iter().map(|b| b.len()).sum(); + let pos = self.processed; + + let bit_len = msg_len * 8; + let processed_bit_len = (bit_len + (pos * 8)) as u64; + + // minimum length of padded message in bytes + let min_padded_len = msg_len + 9; + // number of 64-byte blocks rounded up + let block_count = (min_padded_len / 64) + (min_padded_len % 64 != 0) as usize; + // message is padded to a multiple of 64 bytes + let padded_len = block_count * 64; + // number of bytes to pad + let pad_len = padded_len - msg_len; + + // append a single '1' bit + // append K '0' bits, where K is the minimum number >= 0 such that (L + 1 + K + + // 64) is a multiple of 512 append L as a 64-bit big-endian integer, making + // the total post-processed length a multiple of 512 bits such that the bits + // in the message are: 1 , (the number of bits will be a multiple of 512) + let mut padding = Vec::new(); + padding.push(128_u8); + padding.extend((0..pad_len - 9).map(|_| 0_u8)); + padding.extend(processed_bit_len.to_be_bytes()); + + let padding_ref: Vector = vm.alloc_vec(padding.len()).map_err(PrfError::vm)?; + + vm.mark_public(padding_ref).map_err(PrfError::vm)?; + vm.assign(padding_ref, padding).map_err(PrfError::vm)?; + vm.commit(padding_ref).map_err(PrfError::vm)?; + + self.chunks.push(padding_ref); + Ok(self) + } + + /// Adds the [`Call`] to the [`Vm`], and returns the output. + /// + /// # Arguments + /// + /// * `vm` - The virtual machine. + pub(crate) fn alloc(self, vm: &mut dyn Vm) -> Result, PrfError> { + let mut state = if let Some(state) = self.state { + state + } else { + Self::assign_iv(vm)? + }; + + // SHA256 compression function takes 64 byte blocks as inputs but our blocks in + // `self.chunks` can have arbitrary size to simplify the api. So we need to + // repartition them to 64 byte blocks and feed those into the + // compression function. + let mut remainder = None; + let mut block: Vec> = vec![]; + let mut chunk_iter = self.chunks.iter().copied(); + + loop { + if let Some(remainder) = remainder.take() { + block.push(remainder); + } + let Some(mut chunk) = chunk_iter.next() else { + break; + }; + + let len_before: usize = block.iter().map(|b| b.len()).sum(); + let len_after = len_before + chunk.len(); + + if len_after <= 64 { + block.push(chunk); + } else { + let excess_len = len_after - 64; + remainder = Some(chunk.split_off(chunk.len() - excess_len)); + + block.push(chunk); + state = Self::compute_state(vm, state, &block)?; + block.clear(); + } + } + + Self::compute_state(vm, state, &block) + } + + fn assign_iv(vm: &mut dyn Vm) -> Result, PrfError> { + let iv: Array = vm.alloc().map_err(PrfError::vm)?; + + vm.mark_public(iv).map_err(PrfError::vm)?; + vm.assign(iv, Self::IV).map_err(PrfError::vm)?; + vm.commit(iv).map_err(PrfError::vm)?; + + Ok(iv) + } + + fn compute_state( + vm: &mut dyn Vm, + state: Array, + data: &[Vector], + ) -> Result, PrfError> { + let mut compress = Call::builder(SHA256_COMPRESS.clone()); + + for &block in data { + compress = compress.arg(block); + } + + let compress = compress.arg(state).build().map_err(PrfError::vm)?; + vm.call(compress).map_err(PrfError::vm) + } +} + +/// Reference SHA256 implementation. +/// +/// # Arguments +/// +/// * `state` - The SHA256 state. +/// * `pos` - The number of bytes processed in the current state. +/// * `msg` - The message to hash. +pub(crate) fn sha256(mut state: [u32; 8], pos: usize, msg: &[u8]) -> [u32; 8] { + use sha2::{ + compress256, + digest::{ + block_buffer::{BlockBuffer, Eager}, + generic_array::typenum::U64, + }, + }; + + let mut buffer = BlockBuffer::::default(); + buffer.digest_blocks(msg, |b| compress256(&mut state, b)); + buffer.digest_pad(0x80, &(((msg.len() + pos) * 8) as u64).to_be_bytes(), |b| { + compress256(&mut state, &[*b]) + }); + state +} + +#[cfg(test)] +mod tests { + use crate::{ + convert_to_bytes, + sha256::{sha256, Sha256}, + test_utils::{compress_256, mock_vm}, + }; + use mpz_common::context::test_st_context; + use mpz_vm_core::{ + memory::{ + binary::{U32, U8}, + Array, MemoryExt, Vector, ViewExt, + }, + Execute, + }; + + #[tokio::test] + async fn test_sha256_circuit() { + let (mut ctx_a, mut ctx_b) = test_st_context(8); + let (mut leader, mut follower) = mock_vm(); + + let (inputs, references) = test_fixtures(); + for (input, &reference) in inputs.iter().zip(references.iter()) { + let input_leader: Vector = leader.alloc_vec(input.len()).unwrap(); + leader.mark_public(input_leader).unwrap(); + leader.assign(input_leader, input.clone()).unwrap(); + leader.commit(input_leader).unwrap(); + + let mut sha_leader = Sha256::default(); + sha_leader + .update(input_leader) + .add_padding(&mut leader) + .unwrap(); + let sha_out_leader = sha_leader.alloc(&mut leader).unwrap(); + let sha_out_leader = leader.decode(sha_out_leader).unwrap(); + + let input_follower: Vector = follower.alloc_vec(input.len()).unwrap(); + follower.mark_public(input_follower).unwrap(); + follower.assign(input_follower, input.clone()).unwrap(); + follower.commit(input_follower).unwrap(); + + let mut sha_follower = Sha256::default(); + sha_follower + .update(input_follower) + .add_padding(&mut follower) + .unwrap(); + let sha_out_follower = sha_follower.alloc(&mut follower).unwrap(); + let sha_out_follower = follower.decode(sha_out_follower).unwrap(); + + let (sha_out_leader, sha_out_follower) = tokio::try_join!( + async { + leader.execute_all(&mut ctx_a).await.unwrap(); + sha_out_leader.await + }, + async { + follower.execute_all(&mut ctx_b).await.unwrap(); + sha_out_follower.await + } + ) + .unwrap(); + + assert_eq!(sha_out_leader, sha_out_follower); + assert_eq!(convert_to_bytes(sha_out_leader), reference); + } + } + + #[tokio::test] + async fn test_sha256_circuit_set_state() { + let (mut ctx_a, mut ctx_b) = test_st_context(8); + let (mut leader, mut follower) = mock_vm(); + + let (inputs, references) = test_fixtures(); + + // only take 3rd example because we need minimum 64 bits. + let input = &inputs[2]; + let reference = references[2]; + + // This has to be 64 bytes, because the sha256 compression function operates on + // 64 byte blocks. + let skip = 64; + + let state = compress_256(Sha256::IV, &input[..skip]); + let test = input[skip..].to_vec(); + + let input_leader: Vector = leader.alloc_vec(test.len()).unwrap(); + leader.mark_public(input_leader).unwrap(); + leader.assign(input_leader, test.clone()).unwrap(); + leader.commit(input_leader).unwrap(); + + let state_leader: Array = leader.alloc().unwrap(); + leader.mark_public(state_leader).unwrap(); + leader.assign(state_leader, state).unwrap(); + leader.commit(state_leader).unwrap(); + + let mut sha_leader = Sha256::default(); + sha_leader + .set_state(state_leader, skip) + .update(input_leader) + .add_padding(&mut leader) + .unwrap(); + let sha_out_leader = sha_leader.alloc(&mut leader).unwrap(); + let sha_out_leader = leader.decode(sha_out_leader).unwrap(); + + let input_follower: Vector = follower.alloc_vec(test.len()).unwrap(); + follower.mark_public(input_follower).unwrap(); + follower.assign(input_follower, test).unwrap(); + follower.commit(input_follower).unwrap(); + + let state_follower: Array = follower.alloc().unwrap(); + follower.mark_public(state_follower).unwrap(); + follower.assign(state_follower, state).unwrap(); + follower.commit(state_follower).unwrap(); + + let mut sha_follower = Sha256::default(); + sha_follower + .set_state(state_follower, skip) + .update(input_follower) + .add_padding(&mut follower) + .unwrap(); + let sha_out_follower = sha_follower.alloc(&mut follower).unwrap(); + let sha_out_follower = follower.decode(sha_out_follower).unwrap(); + + let (sha_out_leader, sha_out_follower) = tokio::try_join!( + async { + leader.execute_all(&mut ctx_a).await.unwrap(); + sha_out_leader.await + }, + async { + follower.execute_all(&mut ctx_b).await.unwrap(); + sha_out_follower.await + } + ) + .unwrap(); + + assert_eq!(sha_out_leader, sha_out_follower); + assert_eq!(convert_to_bytes(sha_out_leader), reference); + } + + #[test] + fn test_sha256_reference() { + let (inputs, references) = test_fixtures(); + for (input, &reference) in inputs.iter().zip(references.iter()) { + let sha = sha256(Sha256::IV, 0, input); + assert_eq!(convert_to_bytes(sha), reference); + } + } + + #[test] + fn test_sha256_reference_set_state() { + let (inputs, references) = test_fixtures(); + + // only take 3rd example because we need minimum 64 bits. + let input = &inputs[2]; + let reference = references[2]; + + // This has to be 64 bytes, because the sha256 compression function operates on + // 64 byte blocks. + let skip = 64; + + let state = compress_256(Sha256::IV, &input[..skip]); + let test = input[skip..].to_vec(); + + let sha = sha256(state, skip, &test); + assert_eq!(convert_to_bytes(sha), reference); + } + + fn test_fixtures() -> (Vec>, Vec<[u8; 32]>) { + let test_vectors: Vec> = vec![ + b"abc".to_vec(), + b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq".to_vec(), + b"abcdefghbcdefghicdefghijdefghijkefghijklfghijklmghijklmnhijklmnoijklmnopjklmnopqklmnopqrlmnopqrsmnopqrstnopqrstu".to_vec() + ]; + let expected: Vec<[u8; 32]> = vec![ + hex::decode("ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") + .unwrap() + .try_into() + .unwrap(), + hex::decode("248d6a61d20638b8e5c026930c3e6039a33ce45964ff2167f6ecedd419db06c1") + .unwrap() + .try_into() + .unwrap(), + hex::decode("cf5b16a778af8380036ce59e7b0492370b249b11e8f07a51afac45037afee9d1") + .unwrap() + .try_into() + .unwrap(), + ]; + + (test_vectors, expected) + } +} diff --git a/crates/components/hmac-sha256/src/test_utils.rs b/crates/components/hmac-sha256/src/test_utils.rs new file mode 100644 index 0000000000..d3da0e5833 --- /dev/null +++ b/crates/components/hmac-sha256/src/test_utils.rs @@ -0,0 +1,261 @@ +use crate::{convert_to_bytes, sha256::sha256}; +use mpz_garble::protocol::semihonest::{Evaluator, Garbler}; +use mpz_ot::ideal::cot::{ideal_cot, IdealCOTReceiver, IdealCOTSender}; +use mpz_vm_core::memory::correlated::Delta; +use rand::{rngs::StdRng, Rng, SeedableRng}; + +pub(crate) const SHA256_IV: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +pub(crate) fn mock_vm() -> (Garbler, Evaluator) { + let mut rng = StdRng::seed_from_u64(0); + let delta = Delta::random(&mut rng); + + let (cot_send, cot_recv) = ideal_cot(delta.into_inner()); + + let gen = Garbler::new(cot_send, [0u8; 16], delta); + let ev = Evaluator::new(cot_recv); + + (gen, ev) +} + +pub(crate) fn prf_ms(pms: [u8; 32], client_random: [u8; 32], server_random: [u8; 32]) -> [u8; 48] { + let mut label_start_seed = b"master secret".to_vec(); + label_start_seed.extend_from_slice(&client_random); + label_start_seed.extend_from_slice(&server_random); + + let ms = phash(pms.to_vec(), &label_start_seed, 2)[..48].to_vec(); + + ms.try_into().unwrap() +} + +pub(crate) fn prf_keys( + ms: [u8; 48], + client_random: [u8; 32], + server_random: [u8; 32], +) -> [Vec; 4] { + let mut label_start_seed = b"key expansion".to_vec(); + label_start_seed.extend_from_slice(&server_random); + label_start_seed.extend_from_slice(&client_random); + + let mut session_keys = phash(ms.to_vec(), &label_start_seed, 2)[..40].to_vec(); + + let server_iv = session_keys.split_off(36); + let client_iv = session_keys.split_off(32); + let server_write_key = session_keys.split_off(16); + let client_write_key = session_keys; + + [client_write_key, server_write_key, client_iv, server_iv] +} + +pub(crate) fn prf_cf_vd(ms: [u8; 48], hanshake_hash: [u8; 32]) -> Vec { + let mut label_start_seed = b"client finished".to_vec(); + label_start_seed.extend_from_slice(&hanshake_hash); + + phash(ms.to_vec(), &label_start_seed, 1)[..12].to_vec() +} + +pub(crate) fn prf_sf_vd(ms: [u8; 48], hanshake_hash: [u8; 32]) -> Vec { + let mut label_start_seed = b"server finished".to_vec(); + label_start_seed.extend_from_slice(&hanshake_hash); + + phash(ms.to_vec(), &label_start_seed, 1)[..12].to_vec() +} + +pub(crate) fn phash(key: Vec, seed: &[u8], iterations: usize) -> Vec { + // A() is defined as: + // + // A(0) = seed + // A(i) = HMAC_hash(secret, A(i-1)) + let mut a_cache: Vec<_> = Vec::with_capacity(iterations + 1); + a_cache.push(seed.to_vec()); + + for i in 0..iterations { + let a_i = hmac_sha256(key.clone(), &a_cache[i]); + a_cache.push(a_i.to_vec()); + } + + // HMAC_hash(secret, A(i) + seed) + let mut output: Vec<_> = Vec::with_capacity(iterations * 32); + for i in 0..iterations { + let mut a_i_seed = a_cache[i + 1].clone(); + a_i_seed.extend_from_slice(seed); + + let hash = hmac_sha256(key.clone(), &a_i_seed); + output.extend_from_slice(&hash); + } + + output +} + +pub(crate) fn hmac_sha256(key: Vec, msg: &[u8]) -> [u8; 32] { + let outer_partial = compute_outer_partial(key.clone()); + let inner_local = compute_inner_local(key, msg); + + let hmac = sha256(outer_partial, 64, &convert_to_bytes(inner_local)); + convert_to_bytes(hmac) +} + +pub(crate) fn compute_outer_partial(mut key: Vec) -> [u32; 8] { + assert!(key.len() <= 64); + + key.resize(64, 0_u8); + let key_padded: [u8; 64] = key + .into_iter() + .map(|b| b ^ 0x5c) + .collect::>() + .try_into() + .unwrap(); + + compress_256(SHA256_IV, &key_padded) +} + +pub(crate) fn compute_inner_local(mut key: Vec, msg: &[u8]) -> [u32; 8] { + assert!(key.len() <= 64); + + key.resize(64, 0_u8); + let key_padded: [u8; 64] = key + .into_iter() + .map(|b| b ^ 0x36) + .collect::>() + .try_into() + .unwrap(); + + let state = compress_256(SHA256_IV, &key_padded); + sha256(state, 64, msg) +} + +pub(crate) fn compress_256(mut state: [u32; 8], msg: &[u8]) -> [u32; 8] { + use sha2::{ + compress256, + digest::{ + block_buffer::{BlockBuffer, Eager}, + generic_array::typenum::U64, + }, + }; + + let mut buffer = BlockBuffer::::default(); + buffer.digest_blocks(msg, |b| compress256(&mut state, b)); + state +} + +// Borrowed from Rustls for testing +// https://github.com/rustls/rustls/blob/main/rustls/src/tls12/prf.rs +mod ring_prf { + use ring::{hmac, hmac::HMAC_SHA256}; + + fn concat_sign(key: &hmac::Key, a: &[u8], b: &[u8]) -> hmac::Tag { + let mut ctx = hmac::Context::with_key(key); + ctx.update(a); + ctx.update(b); + ctx.sign() + } + + fn p(out: &mut [u8], secret: &[u8], seed: &[u8]) { + let hmac_key = hmac::Key::new(HMAC_SHA256, secret); + + // A(1) + let mut current_a = hmac::sign(&hmac_key, seed); + let chunk_size = HMAC_SHA256.digest_algorithm().output_len(); + for chunk in out.chunks_mut(chunk_size) { + // P_hash[i] = HMAC_hash(secret, A(i) + seed) + let p_term = concat_sign(&hmac_key, current_a.as_ref(), seed); + chunk.copy_from_slice(&p_term.as_ref()[..chunk.len()]); + + // A(i+1) = HMAC_hash(secret, A(i)) + current_a = hmac::sign(&hmac_key, current_a.as_ref()); + } + } + + fn concat(a: &[u8], b: &[u8]) -> Vec { + let mut ret = Vec::new(); + ret.extend_from_slice(a); + ret.extend_from_slice(b); + ret + } + + pub(crate) fn prf(out: &mut [u8], secret: &[u8], label: &[u8], seed: &[u8]) { + let joined_seed = concat(label, seed); + p(out, secret, &joined_seed); + } +} + +#[test] +fn test_prf_reference_ms() { + use ring_prf::prf as prf_ref; + + let mut rng = StdRng::from_seed([1; 32]); + + let pms: [u8; 32] = rng.random(); + let label: &[u8] = b"master secret"; + let client_random: [u8; 32] = rng.random(); + let server_random: [u8; 32] = rng.random(); + let mut seed = Vec::from(client_random); + seed.extend_from_slice(&server_random); + + let ms = prf_ms(pms, client_random, server_random); + + let mut expected_ms: [u8; 48] = [0; 48]; + prf_ref(&mut expected_ms, &pms, label, &seed); + + assert_eq!(ms, expected_ms); +} + +#[test] +fn test_prf_reference_ke() { + use ring_prf::prf as prf_ref; + + let mut rng = StdRng::from_seed([2; 32]); + + let ms: [u8; 48] = rng.random(); + let label: &[u8] = b"key expansion"; + let client_random: [u8; 32] = rng.random(); + let server_random: [u8; 32] = rng.random(); + let mut seed = Vec::from(server_random); + seed.extend_from_slice(&client_random); + + let keys = prf_keys(ms, client_random, server_random); + let keys: Vec = keys.into_iter().flatten().collect(); + + let mut expected_keys: [u8; 40] = [0; 40]; + prf_ref(&mut expected_keys, &ms, label, &seed); + + assert_eq!(keys, expected_keys); +} + +#[test] +fn test_prf_reference_cf() { + use ring_prf::prf as prf_ref; + + let mut rng = StdRng::from_seed([3; 32]); + + let ms: [u8; 48] = rng.random(); + let label: &[u8] = b"client finished"; + let handshake_hash: [u8; 32] = rng.random(); + + let cf_vd = prf_cf_vd(ms, handshake_hash); + + let mut expected_cf_vd: [u8; 12] = [0; 12]; + prf_ref(&mut expected_cf_vd, &ms, label, &handshake_hash); + + assert_eq!(cf_vd, expected_cf_vd); +} + +#[test] +fn test_prf_reference_sf() { + use ring_prf::prf as prf_ref; + + let mut rng = StdRng::from_seed([4; 32]); + + let ms: [u8; 48] = rng.random(); + let label: &[u8] = b"server finished"; + let handshake_hash: [u8; 32] = rng.random(); + + let sf_vd = prf_sf_vd(ms, handshake_hash); + + let mut expected_sf_vd: [u8; 12] = [0; 12]; + prf_ref(&mut expected_sf_vd, &ms, label, &handshake_hash); + + assert_eq!(sf_vd, expected_sf_vd); +} diff --git a/crates/mpc-tls/Cargo.toml b/crates/mpc-tls/Cargo.toml index 6572967640..bc09300619 100644 --- a/crates/mpc-tls/Cargo.toml +++ b/crates/mpc-tls/Cargo.toml @@ -20,7 +20,7 @@ default = [] [dependencies] tlsn-cipher = { workspace = true } tlsn-common = { workspace = true } -#tlsn-hmac-sha256 = { workspace = true } +tlsn-hmac-sha256 = { workspace = true } tlsn-key-exchange = { workspace = true } tlsn-tls-backend = { workspace = true } tlsn-tls-core = { workspace = true, features = ["serde"] } diff --git a/crates/mpc-tls/src/config.rs b/crates/mpc-tls/src/config.rs index 24f32189d6..20d33abc87 100644 --- a/crates/mpc-tls/src/config.rs +++ b/crates/mpc-tls/src/config.rs @@ -1,4 +1,5 @@ use derive_builder::Builder; +use hmac_sha256::Config as PrfConfig; /// Number of TLS protocol bytes that will be sent. const PROTOCOL_DATA_SENT: usize = 32; @@ -55,6 +56,8 @@ pub struct Config { /// Maximum number of received bytes. #[allow(unused)] pub(crate) max_recv: usize, + /// Configuration options for the PRF. + pub(crate) prf: PrfConfig, } impl Config { @@ -102,6 +105,7 @@ impl ConfigBuilder { max_recv_records, max_recv_online, max_recv, + prf: self.prf.unwrap_or_default(), }) } } diff --git a/crates/mpc-tls/src/follower.rs b/crates/mpc-tls/src/follower.rs index dd56f65600..f932f376e0 100644 --- a/crates/mpc-tls/src/follower.rs +++ b/crates/mpc-tls/src/follower.rs @@ -3,7 +3,7 @@ use crate::{ record_layer::{aead::MpcAesGcm, RecordLayer}, Config, FollowerData, MpcTlsError, Role, SessionKeys, Vm, }; -use hmac_sha256::{MpcPrf, PrfConfig, PrfOutput}; +use hmac_sha256::{MpcPrf, PrfOutput}; use ke::KeyExchange; use key_exchange::{self as ke, MpcKeyExchange}; use mpz_common::{Context, Flush}; @@ -63,12 +63,7 @@ impl MpcTlsFollower { )), )) as Box; - let prf = MpcPrf::new( - PrfConfig::builder() - .role(hmac_sha256::Role::Follower) - .build() - .expect("PRF config is valid"), - ); + let prf = MpcPrf::new(config.prf); let encrypter = MpcAesGcm::new( ShareConversionReceiver::new(OLEReceiver::new(AnyReceiver::new( @@ -123,8 +118,6 @@ impl MpcTlsFollower { keys.server_iv, )?; - prf.set_client_random(vm, None)?; - let cf_vd = vm.decode(cf_vd).map_err(MpcTlsError::alloc)?; let sf_vd = vm.decode(sf_vd).map_err(MpcTlsError::alloc)?; @@ -230,6 +223,7 @@ impl MpcTlsFollower { return Err(MpcTlsError::state("must be in ready state to run")); }; + let mut client_random = None; let mut server_random = None; let mut server_key = None; let mut cf_vd = None; @@ -237,17 +231,20 @@ impl MpcTlsFollower { loop { let msg: Message = self.ctx.io_mut().expect_next().await?; match msg { + Message::SetClientRandom(random) => { + if client_random.is_some() { + return Err(MpcTlsError::hs("client random already set")); + } + + prf.set_client_random(random.random)?; + client_random = Some(random); + } Message::SetServerRandom(random) => { if server_random.is_some() { return Err(MpcTlsError::hs("server random already set")); } - let mut vm = vm - .try_lock() - .map_err(|_| MpcTlsError::other("VM lock is held"))?; - - prf.set_server_random(&mut (*vm), random.random)?; - + prf.set_server_random(random.random)?; server_random = Some(random); } Message::SetServerKey(key) => { @@ -274,9 +271,19 @@ impl MpcTlsFollower { ke.compute_shares(&mut self.ctx).await?; ke.assign(&mut (*vm))?; - vm.execute_all(&mut self.ctx) - .await - .map_err(MpcTlsError::hs)?; + loop { + let assigned = prf + .drive_key_expansion(&mut (*vm)) + .map_err(MpcTlsError::hs)?; + + vm.execute_all(&mut self.ctx) + .await + .map_err(MpcTlsError::hs)?; + + if assigned { + break; + } + } ke.finalize().await?; record_layer.setup(&mut self.ctx).await?; @@ -290,11 +297,21 @@ impl MpcTlsFollower { .try_lock() .map_err(|_| MpcTlsError::other("VM lock is held"))?; - prf.set_cf_hash(&mut (*vm), vd.handshake_hash)?; + prf.set_cf_hash(vd.handshake_hash)?; - vm.execute_all(&mut self.ctx) - .await - .map_err(MpcTlsError::hs)?; + loop { + let assigned = prf + .drive_client_finished(&mut (*vm)) + .map_err(MpcTlsError::hs)?; + + vm.execute_all(&mut self.ctx) + .await + .map_err(MpcTlsError::hs)?; + + if assigned { + break; + } + } cf_vd = Some( cf_vd_fut @@ -312,11 +329,21 @@ impl MpcTlsFollower { .try_lock() .map_err(|_| MpcTlsError::other("VM lock is held"))?; - prf.set_sf_hash(&mut (*vm), vd.handshake_hash)?; + prf.set_sf_hash(vd.handshake_hash)?; + + loop { + let assigned = prf + .drive_server_finished(&mut (*vm)) + .map_err(MpcTlsError::hs)?; - vm.execute_all(&mut self.ctx) - .await - .map_err(MpcTlsError::hs)?; + vm.execute_all(&mut self.ctx) + .await + .map_err(MpcTlsError::hs)?; + + if assigned { + break; + } + } sf_vd = Some( sf_vd_fut diff --git a/crates/mpc-tls/src/leader.rs b/crates/mpc-tls/src/leader.rs index d1cb1cff9b..72ddf180aa 100644 --- a/crates/mpc-tls/src/leader.rs +++ b/crates/mpc-tls/src/leader.rs @@ -3,15 +3,15 @@ mod actor; use crate::{ error::MpcTlsError, msg::{ - ClientFinishedVd, Decrypt, Encrypt, Message, ServerFinishedVd, SetServerKey, - SetServerRandom, + ClientFinishedVd, Decrypt, Encrypt, Message, ServerFinishedVd, SetClientRandom, + SetServerKey, SetServerRandom, }, record_layer::{aead::MpcAesGcm, DecryptMode, EncryptMode, RecordLayer}, utils::opaque_into_parts, Config, LeaderOutput, Role, SessionKeys, Vm, }; use async_trait::async_trait; -use hmac_sha256::{MpcPrf, PrfConfig, PrfOutput}; +use hmac_sha256::{MpcPrf, PrfOutput}; use ke::KeyExchange; use key_exchange::{self as ke, MpcKeyExchange}; use ludi::Context as LudiContext; @@ -87,12 +87,7 @@ impl MpcTlsLeader { ))), )) as Box; - let prf = MpcPrf::new( - PrfConfig::builder() - .role(hmac_sha256::Role::Leader) - .build() - .expect("prf config is valid"), - ); + let prf = MpcPrf::new(config.prf); let encrypter = MpcAesGcm::new( ShareConversionSender::new(OLESender::new( @@ -128,9 +123,9 @@ impl MpcTlsLeader { } /// Allocates resources for the connection. - pub fn alloc(&mut self) -> Result { + pub async fn alloc(&mut self) -> Result { let State::Init { - ctx, + mut ctx, vm, mut ke, mut prf, @@ -157,7 +152,14 @@ impl MpcTlsLeader { keys.server_iv, )?; - prf.set_client_random(&mut (*vm_lock), Some(client_random.0))?; + ctx.io_mut() + .send(Message::SetClientRandom(SetClientRandom { + random: client_random.0, + })) + .await + .map_err(MpcTlsError::from)?; + + prf.set_client_random(client_random.0)?; let cf_vd = vm_lock.decode(cf_vd).map_err(MpcTlsError::alloc)?; let sf_vd = vm_lock.decode(sf_vd).map_err(MpcTlsError::alloc)?; @@ -408,7 +410,6 @@ impl Backend for MpcTlsLeader { async fn set_server_random(&mut self, random: Random) -> Result<(), BackendError> { let State::Handshake { ctx, - vm, prf, server_random, .. @@ -426,13 +427,7 @@ impl Backend for MpcTlsLeader { .await .map_err(MpcTlsError::from)?; - let mut vm = vm - .try_lock() - .map_err(|_| MpcTlsError::other("VM lock is held"))?; - - prf.set_server_random(&mut (*vm), random.0) - .map_err(MpcTlsError::hs)?; - + prf.set_server_random(random.0).map_err(MpcTlsError::hs)?; *server_random = Some(random); Ok(()) @@ -543,9 +538,19 @@ impl Backend for MpcTlsLeader { let mut vm = vm .try_lock() .map_err(|_| MpcTlsError::other("VM lock is held"))?; - prf.set_sf_hash(&mut (*vm), hash).map_err(MpcTlsError::hs)?; + prf.set_sf_hash(hash).map_err(MpcTlsError::hs)?; + + loop { + let assigned = prf + .drive_server_finished(&mut (*vm)) + .map_err(MpcTlsError::hs)?; - vm.execute_all(ctx).await.map_err(MpcTlsError::hs)?; + vm.execute_all(ctx).await.map_err(MpcTlsError::hs)?; + + if assigned { + break; + } + } let sf_vd = sf_vd .try_recv() @@ -586,9 +591,19 @@ impl Backend for MpcTlsLeader { let mut vm = vm .try_lock() .map_err(|_| MpcTlsError::hs("VM lock is held"))?; - prf.set_cf_hash(&mut (*vm), hash).map_err(MpcTlsError::hs)?; + prf.set_cf_hash(hash).map_err(MpcTlsError::hs)?; + + loop { + let assigned = prf + .drive_client_finished(&mut (*vm)) + .map_err(MpcTlsError::hs)?; - vm.execute_all(ctx).await.map_err(MpcTlsError::hs)?; + vm.execute_all(ctx).await.map_err(MpcTlsError::hs)?; + + if assigned { + break; + } + } let cf_vd = cf_vd .try_recv() @@ -605,7 +620,7 @@ impl Backend for MpcTlsLeader { vm, keys, mut ke, - prf, + mut prf, mut record_layer, cf_vd, sf_vd, @@ -650,10 +665,22 @@ impl Backend for MpcTlsLeader { .map_err(|_| MpcTlsError::other("VM lock is held"))?; ke.assign(&mut (*vm_lock)).map_err(MpcTlsError::hs)?; - vm_lock - .execute_all(&mut ctx) - .await - .map_err(MpcTlsError::hs)?; + + loop { + let assigned = prf + .drive_key_expansion(&mut (*vm_lock)) + .map_err(MpcTlsError::hs)?; + + vm_lock + .execute_all(&mut ctx) + .await + .map_err(MpcTlsError::hs)?; + + if assigned { + break; + } + } + ke.finalize().await.map_err(MpcTlsError::hs)?; record_layer.setup(&mut ctx).await?; } diff --git a/crates/mpc-tls/src/msg.rs b/crates/mpc-tls/src/msg.rs index c5c86e4162..a150f0eb37 100644 --- a/crates/mpc-tls/src/msg.rs +++ b/crates/mpc-tls/src/msg.rs @@ -9,6 +9,7 @@ use crate::record_layer::{DecryptMode, EncryptMode}; /// MPC-TLS protocol message. #[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) enum Message { + SetClientRandom(SetClientRandom), SetServerRandom(SetServerRandom), SetServerKey(SetServerKey), ClientFinishedVd(ClientFinishedVd), @@ -20,6 +21,11 @@ pub(crate) enum Message { CloseConnection, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct SetClientRandom { + pub(crate) random: [u8; 32], +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub(crate) struct SetServerRandom { pub(crate) random: [u8; 32], diff --git a/crates/mpc-tls/tests/test.rs b/crates/mpc-tls/tests/test.rs index 89ac59f1b6..227e48198c 100644 --- a/crates/mpc-tls/tests/test.rs +++ b/crates/mpc-tls/tests/test.rs @@ -41,7 +41,7 @@ async fn mpc_tls_test() { } async fn leader_task(mut leader: MpcTlsLeader) { - leader.alloc().unwrap(); + leader.alloc().await.unwrap(); leader.preprocess().await.unwrap(); let (leader_ctrl, leader_fut) = leader.run(); diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index 1c86593819..c22fa1e157 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -24,6 +24,7 @@ tlsn-tls-client = { workspace = true } tlsn-tls-client-async = { workspace = true } tlsn-tls-core = { workspace = true } tlsn-mpc-tls = { workspace = true } +tlsn-hmac-sha256 = { workspace = true } serio = { workspace = true, features = ["compat"] } uid-mux = { workspace = true, features = ["serio"] } diff --git a/crates/prover/src/config.rs b/crates/prover/src/config.rs index 67cedd7ffe..e5ee8469bd 100644 --- a/crates/prover/src/config.rs +++ b/crates/prover/src/config.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use hmac_sha256::Config as PrfConfig; use mpc_tls::Config; use tlsn_common::config::ProtocolConfig; use tlsn_core::{connection::ServerName, CryptoProvider}; @@ -15,6 +16,9 @@ pub struct ProverConfig { /// Cryptography provider. #[builder(default, setter(into))] crypto_provider: Arc, + /// Configuration options for the PRF. + #[builder(default)] + prf: PrfConfig, } impl ProverConfig { @@ -45,7 +49,8 @@ impl ProverConfig { .defer_decryption(self.protocol_config.defer_decryption_from_start()) .max_sent(self.protocol_config.max_sent_data()) .max_recv_online(self.protocol_config.max_recv_data_online()) - .max_recv(self.protocol_config.max_recv_data()); + .max_recv(self.protocol_config.max_recv_data()) + .prf(self.prf); if let Some(max_sent_records) = self.protocol_config.max_sent_records() { builder.max_sent_records(max_sent_records); diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index f422d866bc..b1f8201859 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -17,7 +17,6 @@ pub use future::ProverFuture; use mpz_common::Context; use mpz_core::Block; use mpz_garble_core::Delta; -use rand06_compat::Rand0_6CompatExt; use state::{Notarize, Prove}; use futures::{AsyncRead, AsyncWrite, TryFutureExt}; @@ -104,7 +103,7 @@ impl Prover { let (vm, mut mpc_tls) = build_mpc_tls(&self.config, ctx); // Allocate resources for MPC-TLS in VM. - let keys = mpc_tls.alloc()?; + let keys = mpc_tls.alloc().await?; // Allocate for committing to plaintext. let mut zk_aes = ZkAesCtr::new(Role::Prover); zk_aes.set_key(keys.server_write_key, keys.server_write_iv); diff --git a/crates/verifier/Cargo.toml b/crates/verifier/Cargo.toml index 2ef0574577..e6d5138983 100644 --- a/crates/verifier/Cargo.toml +++ b/crates/verifier/Cargo.toml @@ -22,6 +22,7 @@ tlsn-core = { workspace = true } tlsn-deap = { workspace = true } tlsn-mpc-tls = { workspace = true } tlsn-tls-core = { workspace = true } +tlsn-hmac-sha256 = { workspace = true } serio = { workspace = true, features = ["compat"] } uid-mux = { workspace = true, features = ["serio"] } diff --git a/crates/verifier/src/config.rs b/crates/verifier/src/config.rs index 4328d5e9a1..edf4cf6e2d 100644 --- a/crates/verifier/src/config.rs +++ b/crates/verifier/src/config.rs @@ -3,6 +3,7 @@ use std::{ sync::Arc, }; +use hmac_sha256::Config as PrfConfig; use mpc_tls::Config; use tlsn_common::config::{ProtocolConfig, ProtocolConfigValidator}; use tlsn_core::CryptoProvider; @@ -16,6 +17,9 @@ pub struct VerifierConfig { /// Cryptography provider. #[builder(default, setter(into))] crypto_provider: Arc, + /// Configuration options for the PRF. + #[builder(default)] + prf: PrfConfig, } impl Debug for VerifierConfig { @@ -48,7 +52,8 @@ impl VerifierConfig { builder .max_sent(protocol_config.max_sent_data()) .max_recv_online(protocol_config.max_recv_data_online()) - .max_recv(protocol_config.max_recv_data()); + .max_recv(protocol_config.max_recv_data()) + .prf(self.prf); if let Some(max_sent_records) = protocol_config.max_sent_records() { builder.max_sent_records(max_sent_records); diff --git a/crates/verifier/src/lib.rs b/crates/verifier/src/lib.rs index 51f79f58a1..ae5d228159 100644 --- a/crates/verifier/src/lib.rs +++ b/crates/verifier/src/lib.rs @@ -20,7 +20,6 @@ use mpc_tls::{FollowerData, MpcTlsFollower}; use mpz_common::Context; use mpz_core::Block; use mpz_garble_core::Delta; -use rand06_compat::Rand0_6CompatExt; use serio::stream::IoStreamExt; use state::{Notarize, Verify}; use tls_core::msgs::enums::ContentType; From 08dcb698b4dd00066bb4b50a7e85d34275a753e4 Mon Sep 17 00:00:00 2001 From: th4s Date: Thu, 24 Apr 2025 13:08:46 +0200 Subject: [PATCH 02/36] move sending `client_random` from `alloc` to `preprocess` --- crates/mpc-tls/src/leader.rs | 24 ++++++++++++------------ crates/mpc-tls/tests/test.rs | 2 +- crates/prover/src/lib.rs | 2 +- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/crates/mpc-tls/src/leader.rs b/crates/mpc-tls/src/leader.rs index 72ddf180aa..8fcfbcf65d 100644 --- a/crates/mpc-tls/src/leader.rs +++ b/crates/mpc-tls/src/leader.rs @@ -123,9 +123,9 @@ impl MpcTlsLeader { } /// Allocates resources for the connection. - pub async fn alloc(&mut self) -> Result { + pub fn alloc(&mut self) -> Result { let State::Init { - mut ctx, + ctx, vm, mut ke, mut prf, @@ -152,15 +152,6 @@ impl MpcTlsLeader { keys.server_iv, )?; - ctx.io_mut() - .send(Message::SetClientRandom(SetClientRandom { - random: client_random.0, - })) - .await - .map_err(MpcTlsError::from)?; - - prf.set_client_random(client_random.0)?; - let cf_vd = vm_lock.decode(cf_vd).map_err(MpcTlsError::alloc)?; let sf_vd = vm_lock.decode(sf_vd).map_err(MpcTlsError::alloc)?; @@ -196,7 +187,7 @@ impl MpcTlsLeader { vm, keys, mut ke, - prf, + mut prf, mut record_layer, cf_vd, sf_vd, @@ -240,6 +231,15 @@ impl MpcTlsLeader { .await .map_err(MpcTlsError::preprocess)??; + ctx.io_mut() + .send(Message::SetClientRandom(SetClientRandom { + random: client_random.0, + })) + .await + .map_err(MpcTlsError::from)?; + + prf.set_client_random(client_random.0)?; + self.state = State::Handshake { ctx, vm, diff --git a/crates/mpc-tls/tests/test.rs b/crates/mpc-tls/tests/test.rs index 227e48198c..89ac59f1b6 100644 --- a/crates/mpc-tls/tests/test.rs +++ b/crates/mpc-tls/tests/test.rs @@ -41,7 +41,7 @@ async fn mpc_tls_test() { } async fn leader_task(mut leader: MpcTlsLeader) { - leader.alloc().await.unwrap(); + leader.alloc().unwrap(); leader.preprocess().await.unwrap(); let (leader_ctrl, leader_fut) = leader.run(); diff --git a/crates/prover/src/lib.rs b/crates/prover/src/lib.rs index b1f8201859..e33f37648d 100644 --- a/crates/prover/src/lib.rs +++ b/crates/prover/src/lib.rs @@ -103,7 +103,7 @@ impl Prover { let (vm, mut mpc_tls) = build_mpc_tls(&self.config, ctx); // Allocate resources for MPC-TLS in VM. - let keys = mpc_tls.alloc().await?; + let keys = mpc_tls.alloc()?; // Allocate for committing to plaintext. let mut zk_aes = ZkAesCtr::new(Role::Prover); zk_aes.set_key(keys.server_write_key, keys.server_write_iv); From 417621d27712f96891faf153545f71f73daac7d6 Mon Sep 17 00:00:00 2001 From: th4s Date: Thu, 24 Apr 2025 13:18:01 +0200 Subject: [PATCH 03/36] rename `Config` -> `Mode` and rename variants --- crates/components/hmac-sha256/benches/prf.rs | 6 +-- crates/components/hmac-sha256/src/config.rs | 14 +++---- crates/components/hmac-sha256/src/lib.rs | 10 ++--- .../hmac-sha256/src/prf/function/mod.rs | 42 +++++++++---------- .../src/prf/function/{mpc.rs => normal.rs} | 0 .../src/prf/function/{local.rs => reduced.rs} | 0 crates/components/hmac-sha256/src/prf/mod.rs | 8 ++-- crates/mpc-tls/src/config.rs | 2 +- crates/prover/src/config.rs | 2 +- crates/verifier/src/config.rs | 2 +- 10 files changed, 43 insertions(+), 43 deletions(-) rename crates/components/hmac-sha256/src/prf/function/{mpc.rs => normal.rs} (100%) rename crates/components/hmac-sha256/src/prf/function/{local.rs => reduced.rs} (100%) diff --git a/crates/components/hmac-sha256/benches/prf.rs b/crates/components/hmac-sha256/benches/prf.rs index b6db1e4ef4..92b48cbb4a 100644 --- a/crates/components/hmac-sha256/benches/prf.rs +++ b/crates/components/hmac-sha256/benches/prf.rs @@ -2,7 +2,7 @@ use criterion::{criterion_group, criterion_main, Criterion}; -use hmac_sha256::{Config, MpcPrf}; +use hmac_sha256::{Mode, MpcPrf}; use mpz_common::context::test_mt_context; use mpz_garble::protocol::semihonest::{Evaluator, Garbler}; use mpz_ot::ideal::cot::ideal_cot; @@ -52,8 +52,8 @@ async fn prf() { follower_vm.assign(follower_pms, pms).unwrap(); follower_vm.commit(follower_pms).unwrap(); - let mut leader = MpcPrf::new(Config::default()); - let mut follower = MpcPrf::new(Config::default()); + let mut leader = MpcPrf::new(Mode::default()); + let mut follower = MpcPrf::new(Mode::default()); let leader_output = leader.alloc(&mut leader_vm, leader_pms).unwrap(); let follower_output = follower.alloc(&mut follower_vm, follower_pms).unwrap(); diff --git a/crates/components/hmac-sha256/src/config.rs b/crates/components/hmac-sha256/src/config.rs index a0ba2df48f..9a9284c711 100644 --- a/crates/components/hmac-sha256/src/config.rs +++ b/crates/components/hmac-sha256/src/config.rs @@ -1,16 +1,16 @@ -//! PRF Config. +//! PRF modes. -/// Configuration option for the PRF. +/// Modes for the PRF. #[derive(Debug, Clone, Copy)] -pub enum Config { +pub enum Mode { /// Computes some hashes locally. - Local, + Reduced, /// Computes the whole PRF in MPC. - Mpc, + Normal, } -impl Default for Config { +impl Default for Mode { fn default() -> Self { - Self::Mpc + Self::Normal } } diff --git a/crates/components/hmac-sha256/src/lib.rs b/crates/components/hmac-sha256/src/lib.rs index d08aef5f2c..759d373495 100644 --- a/crates/components/hmac-sha256/src/lib.rs +++ b/crates/components/hmac-sha256/src/lib.rs @@ -10,7 +10,7 @@ mod sha256; mod test_utils; mod config; -pub use config::Config; +pub use config::Mode; mod error; pub use error::PrfError; @@ -57,7 +57,7 @@ fn convert_to_bytes(input: [u32; 8]) -> [u8; 32] { mod tests { use crate::{ test_utils::{mock_vm, prf_cf_vd, prf_keys, prf_ms, prf_sf_vd}, - Config, MpcPrf, SessionKeys, + Mode, MpcPrf, SessionKeys, }; use mpz_common::context::test_st_context; use mpz_vm_core::{ @@ -68,17 +68,17 @@ mod tests { #[tokio::test] async fn test_prf_local() { - let config = Config::Local; + let config = Mode::Reduced; test_prf(config).await; } #[tokio::test] async fn test_prf_mpc() { - let config = Config::Mpc; + let config = Mode::Normal; test_prf(config).await; } - async fn test_prf(config: Config) { + async fn test_prf(config: Mode) { let mut rng = StdRng::seed_from_u64(1); // Test input let pms: [u8; 32] = rng.random(); diff --git a/crates/components/hmac-sha256/src/prf/function/mod.rs b/crates/components/hmac-sha256/src/prf/function/mod.rs index 7c1b622ebc..5b5537204a 100644 --- a/crates/components/hmac-sha256/src/prf/function/mod.rs +++ b/crates/components/hmac-sha256/src/prf/function/mod.rs @@ -1,6 +1,6 @@ //! Provides [`Prf`], for computing the TLS 1.2 PRF. -use crate::{Config, PrfError}; +use crate::{Mode, PrfError}; use mpz_vm_core::{ memory::{ binary::{Binary, U32}, @@ -9,29 +9,29 @@ use mpz_vm_core::{ Vm, }; -mod local; -mod mpc; +mod normal; +mod reduced; #[derive(Debug)] pub(crate) enum Prf { - Local(local::PrfFunction), - Mpc(mpc::PrfFunction), + Local(reduced::PrfFunction), + Mpc(normal::PrfFunction), } impl Prf { pub(crate) fn alloc_master_secret( - config: Config, + config: Mode, vm: &mut dyn Vm, outer_partial: Array, inner_partial: Array, ) -> Result { let prf = match config { - Config::Local => Self::Local(local::PrfFunction::alloc_master_secret( + Mode::Reduced => Self::Local(reduced::PrfFunction::alloc_master_secret( vm, outer_partial, inner_partial, )?), - Config::Mpc => Self::Mpc(mpc::PrfFunction::alloc_master_secret( + Mode::Normal => Self::Mpc(normal::PrfFunction::alloc_master_secret( vm, outer_partial, inner_partial, @@ -41,18 +41,18 @@ impl Prf { } pub(crate) fn alloc_key_expansion( - config: Config, + config: Mode, vm: &mut dyn Vm, outer_partial: Array, inner_partial: Array, ) -> Result { let prf = match config { - Config::Local => Self::Local(local::PrfFunction::alloc_key_expansion( + Mode::Reduced => Self::Local(reduced::PrfFunction::alloc_key_expansion( vm, outer_partial, inner_partial, )?), - Config::Mpc => Self::Mpc(mpc::PrfFunction::alloc_key_expansion( + Mode::Normal => Self::Mpc(normal::PrfFunction::alloc_key_expansion( vm, outer_partial, inner_partial, @@ -62,18 +62,18 @@ impl Prf { } pub(crate) fn alloc_client_finished( - config: Config, + config: Mode, vm: &mut dyn Vm, outer_partial: Array, inner_partial: Array, ) -> Result { let prf = match config { - Config::Local => Self::Local(local::PrfFunction::alloc_client_finished( + Mode::Reduced => Self::Local(reduced::PrfFunction::alloc_client_finished( vm, outer_partial, inner_partial, )?), - Config::Mpc => Self::Mpc(mpc::PrfFunction::alloc_client_finished( + Mode::Normal => Self::Mpc(normal::PrfFunction::alloc_client_finished( vm, outer_partial, inner_partial, @@ -83,18 +83,18 @@ impl Prf { } pub(crate) fn alloc_server_finished( - config: Config, + config: Mode, vm: &mut dyn Vm, outer_partial: Array, inner_partial: Array, ) -> Result { let prf = match config { - Config::Local => Self::Local(local::PrfFunction::alloc_server_finished( + Mode::Reduced => Self::Local(reduced::PrfFunction::alloc_server_finished( vm, outer_partial, inner_partial, )?), - Config::Mpc => Self::Mpc(mpc::PrfFunction::alloc_server_finished( + Mode::Normal => Self::Mpc(normal::PrfFunction::alloc_server_finished( vm, outer_partial, inner_partial, @@ -131,7 +131,7 @@ mod tests { convert_to_bytes, prf::{compute_partial, function::Prf}, test_utils::{mock_vm, phash}, - Config, + Mode, }; use mpz_common::context::test_st_context; use mpz_vm_core::{ @@ -144,17 +144,17 @@ mod tests { #[tokio::test] async fn test_phash_local() { - let config = Config::Local; + let config = Mode::Reduced; test_phash(config).await; } #[tokio::test] async fn test_phash_mpc() { - let config = Config::Local; + let config = Mode::Reduced; test_phash(config).await; } - async fn test_phash(config: Config) { + async fn test_phash(config: Mode) { let (mut ctx_a, mut ctx_b) = test_st_context(8); let (mut leader, mut follower) = mock_vm(); diff --git a/crates/components/hmac-sha256/src/prf/function/mpc.rs b/crates/components/hmac-sha256/src/prf/function/normal.rs similarity index 100% rename from crates/components/hmac-sha256/src/prf/function/mpc.rs rename to crates/components/hmac-sha256/src/prf/function/normal.rs diff --git a/crates/components/hmac-sha256/src/prf/function/local.rs b/crates/components/hmac-sha256/src/prf/function/reduced.rs similarity index 100% rename from crates/components/hmac-sha256/src/prf/function/local.rs rename to crates/components/hmac-sha256/src/prf/function/reduced.rs diff --git a/crates/components/hmac-sha256/src/prf/mod.rs b/crates/components/hmac-sha256/src/prf/mod.rs index 248e5f2830..eff57ea2b9 100644 --- a/crates/components/hmac-sha256/src/prf/mod.rs +++ b/crates/components/hmac-sha256/src/prf/mod.rs @@ -1,4 +1,4 @@ -use crate::{sha256::Sha256, Config, PrfError, PrfOutput, SessionKeys}; +use crate::{sha256::Sha256, Mode, PrfError, PrfOutput, SessionKeys}; use mpz_circuits::{circuits::xor, Circuit, CircuitBuilder}; use mpz_vm_core::{ memory::{ @@ -19,7 +19,7 @@ use function::Prf; /// MPC PRF for computing TLS 1.2 HMAC-SHA256 PRF. #[derive(Debug)] pub struct MpcPrf { - config: Config, + config: Mode, state: State, circuits: Option, } @@ -30,7 +30,7 @@ impl MpcPrf { /// # Arguments /// /// `config` - The PRF config. - pub fn new(config: Config) -> MpcPrf { + pub fn new(config: Mode) -> MpcPrf { Self { config, state: State::Initialized, @@ -224,7 +224,7 @@ impl Circuits { const IPAD: [u8; 64] = [0x36; 64]; const OPAD: [u8; 64] = [0x5c; 64]; - fn alloc(config: Config, vm: &mut dyn Vm, pms: Vector) -> Result { + fn alloc(config: Mode, vm: &mut dyn Vm, pms: Vector) -> Result { let outer_partial_pms = compute_partial(vm, pms, Self::OPAD)?; let inner_partial_pms = compute_partial(vm, pms, Self::IPAD)?; diff --git a/crates/mpc-tls/src/config.rs b/crates/mpc-tls/src/config.rs index 20d33abc87..089c935b12 100644 --- a/crates/mpc-tls/src/config.rs +++ b/crates/mpc-tls/src/config.rs @@ -1,5 +1,5 @@ use derive_builder::Builder; -use hmac_sha256::Config as PrfConfig; +use hmac_sha256::Mode as PrfConfig; /// Number of TLS protocol bytes that will be sent. const PROTOCOL_DATA_SENT: usize = 32; diff --git a/crates/prover/src/config.rs b/crates/prover/src/config.rs index e5ee8469bd..7e1b6bdbd3 100644 --- a/crates/prover/src/config.rs +++ b/crates/prover/src/config.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use hmac_sha256::Config as PrfConfig; +use hmac_sha256::Mode as PrfConfig; use mpc_tls::Config; use tlsn_common::config::ProtocolConfig; use tlsn_core::{connection::ServerName, CryptoProvider}; diff --git a/crates/verifier/src/config.rs b/crates/verifier/src/config.rs index edf4cf6e2d..84733ebe35 100644 --- a/crates/verifier/src/config.rs +++ b/crates/verifier/src/config.rs @@ -3,7 +3,7 @@ use std::{ sync::Arc, }; -use hmac_sha256::Config as PrfConfig; +use hmac_sha256::Mode as PrfConfig; use mpc_tls::Config; use tlsn_common::config::{ProtocolConfig, ProtocolConfigValidator}; use tlsn_core::CryptoProvider; From f2bcbca4dc9e0218739531ddd162ee1e18bdad91 Mon Sep 17 00:00:00 2001 From: th4s Date: Thu, 24 Apr 2025 15:40:03 +0200 Subject: [PATCH 04/36] add feedback for handling of prf config --- crates/common/src/config.rs | 17 +++++++++++++++++ crates/mpc-tls/src/config.rs | 15 ++++++++++++--- crates/prover/Cargo.toml | 1 - crates/prover/src/config.rs | 14 ++++++++------ crates/verifier/Cargo.toml | 1 - crates/verifier/src/config.rs | 14 ++++++++------ crates/wasm/src/prover/config.rs | 9 ++++++++- 7 files changed, 53 insertions(+), 18 deletions(-) diff --git a/crates/common/src/config.rs b/crates/common/src/config.rs index bf9dd0a25f..5a22332aa1 100644 --- a/crates/common/src/config.rs +++ b/crates/common/src/config.rs @@ -216,6 +216,23 @@ impl ProtocolConfigValidator { } } +/// Settings for the network environment. +/// +/// Provides optimization options to adapt the protocol to different network situations. +#[derive(Debug, Clone, Copy, Serialize, Deserialize)] +pub enum NetworkSetting { + /// Prefers a bandwidth-heavy protocol. + Bandwidth, + /// Prefers a latency-heavy protocol. + Latency, +} + +impl Default for NetworkSetting { + fn default() -> Self { + Self::Bandwidth + } +} + /// A ProtocolConfig error. #[derive(thiserror::Error, Debug)] pub struct ProtocolConfigError { diff --git a/crates/mpc-tls/src/config.rs b/crates/mpc-tls/src/config.rs index 089c935b12..d7a9c8946e 100644 --- a/crates/mpc-tls/src/config.rs +++ b/crates/mpc-tls/src/config.rs @@ -1,5 +1,5 @@ use derive_builder::Builder; -use hmac_sha256::Mode as PrfConfig; +use hmac_sha256::Mode as PrfMode; /// Number of TLS protocol bytes that will be sent. const PROTOCOL_DATA_SENT: usize = 32; @@ -57,7 +57,8 @@ pub struct Config { #[allow(unused)] pub(crate) max_recv: usize, /// Configuration options for the PRF. - pub(crate) prf: PrfConfig, + #[builder(setter(custom))] + pub(crate) prf: PrfMode, } impl Config { @@ -68,6 +69,12 @@ impl Config { } impl ConfigBuilder { + /// Optimizes the protocol for low bandwidth networks. + pub fn low_bandwidth(&mut self) -> &mut Self { + self.prf = Some(PrfMode::Reduced); + self + } + /// Builds the configuration. pub fn build(&self) -> Result { let defer_decryption = self.defer_decryption.unwrap_or(true); @@ -98,6 +105,8 @@ impl ConfigBuilder { .max_recv_records .unwrap_or_else(|| PROTOCOL_RECORD_COUNT_RECV + default_record_count(max_recv)); + let prf = self.prf.unwrap_or(PrfMode::Normal); + Ok(Config { defer_decryption, max_sent_records, @@ -105,7 +114,7 @@ impl ConfigBuilder { max_recv_records, max_recv_online, max_recv, - prf: self.prf.unwrap_or_default(), + prf, }) } } diff --git a/crates/prover/Cargo.toml b/crates/prover/Cargo.toml index c22fa1e157..1c86593819 100644 --- a/crates/prover/Cargo.toml +++ b/crates/prover/Cargo.toml @@ -24,7 +24,6 @@ tlsn-tls-client = { workspace = true } tlsn-tls-client-async = { workspace = true } tlsn-tls-core = { workspace = true } tlsn-mpc-tls = { workspace = true } -tlsn-hmac-sha256 = { workspace = true } serio = { workspace = true, features = ["compat"] } uid-mux = { workspace = true, features = ["serio"] } diff --git a/crates/prover/src/config.rs b/crates/prover/src/config.rs index 7e1b6bdbd3..02bf0e71a2 100644 --- a/crates/prover/src/config.rs +++ b/crates/prover/src/config.rs @@ -1,8 +1,7 @@ use std::sync::Arc; -use hmac_sha256::Mode as PrfConfig; use mpc_tls::Config; -use tlsn_common::config::ProtocolConfig; +use tlsn_common::config::{NetworkSetting, ProtocolConfig}; use tlsn_core::{connection::ServerName, CryptoProvider}; /// Configuration for the prover @@ -16,9 +15,9 @@ pub struct ProverConfig { /// Cryptography provider. #[builder(default, setter(into))] crypto_provider: Arc, - /// Configuration options for the PRF. + /// Network settings. #[builder(default)] - prf: PrfConfig, + network: NetworkSetting, } impl ProverConfig { @@ -49,8 +48,7 @@ impl ProverConfig { .defer_decryption(self.protocol_config.defer_decryption_from_start()) .max_sent(self.protocol_config.max_sent_data()) .max_recv_online(self.protocol_config.max_recv_data_online()) - .max_recv(self.protocol_config.max_recv_data()) - .prf(self.prf); + .max_recv(self.protocol_config.max_recv_data()); if let Some(max_sent_records) = self.protocol_config.max_sent_records() { builder.max_sent_records(max_sent_records); @@ -60,6 +58,10 @@ impl ProverConfig { builder.max_recv_records(max_recv_records); } + if let NetworkSetting::Latency = self.network { + builder.low_bandwidth(); + } + builder.build().unwrap() } } diff --git a/crates/verifier/Cargo.toml b/crates/verifier/Cargo.toml index e6d5138983..2ef0574577 100644 --- a/crates/verifier/Cargo.toml +++ b/crates/verifier/Cargo.toml @@ -22,7 +22,6 @@ tlsn-core = { workspace = true } tlsn-deap = { workspace = true } tlsn-mpc-tls = { workspace = true } tlsn-tls-core = { workspace = true } -tlsn-hmac-sha256 = { workspace = true } serio = { workspace = true, features = ["compat"] } uid-mux = { workspace = true, features = ["serio"] } diff --git a/crates/verifier/src/config.rs b/crates/verifier/src/config.rs index 84733ebe35..85f4a8c548 100644 --- a/crates/verifier/src/config.rs +++ b/crates/verifier/src/config.rs @@ -3,9 +3,8 @@ use std::{ sync::Arc, }; -use hmac_sha256::Mode as PrfConfig; use mpc_tls::Config; -use tlsn_common::config::{ProtocolConfig, ProtocolConfigValidator}; +use tlsn_common::config::{NetworkSetting, ProtocolConfig, ProtocolConfigValidator}; use tlsn_core::CryptoProvider; /// Configuration for the [`Verifier`](crate::tls::Verifier). @@ -17,9 +16,9 @@ pub struct VerifierConfig { /// Cryptography provider. #[builder(default, setter(into))] crypto_provider: Arc, - /// Configuration options for the PRF. + /// Network settings. #[builder(default)] - prf: PrfConfig, + network: NetworkSetting, } impl Debug for VerifierConfig { @@ -52,8 +51,7 @@ impl VerifierConfig { builder .max_sent(protocol_config.max_sent_data()) .max_recv_online(protocol_config.max_recv_data_online()) - .max_recv(protocol_config.max_recv_data()) - .prf(self.prf); + .max_recv(protocol_config.max_recv_data()); if let Some(max_sent_records) = protocol_config.max_sent_records() { builder.max_sent_records(max_sent_records); @@ -63,6 +61,10 @@ impl VerifierConfig { builder.max_recv_records(max_recv_records); } + if let NetworkSetting::Latency = self.network { + builder.low_bandwidth(); + } + builder.build().unwrap() } } diff --git a/crates/wasm/src/prover/config.rs b/crates/wasm/src/prover/config.rs index 72352b526f..e23ee9fab7 100644 --- a/crates/wasm/src/prover/config.rs +++ b/crates/wasm/src/prover/config.rs @@ -1,5 +1,5 @@ use serde::Deserialize; -use tlsn_common::config::ProtocolConfig; +use tlsn_common::config::{NetworkSetting, ProtocolConfig}; use tsify_next::Tsify; #[derive(Debug, Tsify, Deserialize)] @@ -12,6 +12,7 @@ pub struct ProverConfig { pub defer_decryption_from_start: Option, pub max_sent_records: Option, pub max_recv_records: Option, + pub network: NetworkSetting, } impl From for tlsn_prover::ProverConfig { @@ -44,6 +45,12 @@ impl From for tlsn_prover::ProverConfig { .server_name(value.server_name.as_ref()) .protocol_config(protocol_config); + if let Some(value) = value.defer_decryption_from_start { + builder.defer_decryption_from_start(value); + } + + builder.network(value.network); + builder.build().unwrap() } } From cb0916bc206556aae041eda7e3dcf296c5b87bd5 Mon Sep 17 00:00:00 2001 From: th4s Date: Thu, 24 Apr 2025 15:41:43 +0200 Subject: [PATCH 05/36] fix formatting to nightly --- crates/common/src/config.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/common/src/config.rs b/crates/common/src/config.rs index 5a22332aa1..a0cc14b6c4 100644 --- a/crates/common/src/config.rs +++ b/crates/common/src/config.rs @@ -218,7 +218,8 @@ impl ProtocolConfigValidator { /// Settings for the network environment. /// -/// Provides optimization options to adapt the protocol to different network situations. +/// Provides optimization options to adapt the protocol to different network +/// situations. #[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub enum NetworkSetting { /// Prefers a bandwidth-heavy protocol. From b8c99e65f1c9c9d073b20258c913ffc61524dd15 Mon Sep 17 00:00:00 2001 From: th4s Date: Fri, 25 Apr 2025 12:15:33 +0200 Subject: [PATCH 06/36] simplify `MpcPrf` --- crates/components/hmac-sha256/src/hmac.rs | 3 + .../hmac-sha256/src/prf/function/mod.rs | 8 +- crates/components/hmac-sha256/src/prf/mod.rs | 240 ++++-------------- .../components/hmac-sha256/src/prf/state.rs | 93 ++++++- 4 files changed, 145 insertions(+), 199 deletions(-) diff --git a/crates/components/hmac-sha256/src/hmac.rs b/crates/components/hmac-sha256/src/hmac.rs index 67cc6cf801..beb48c9aae 100644 --- a/crates/components/hmac-sha256/src/hmac.rs +++ b/crates/components/hmac-sha256/src/hmac.rs @@ -35,6 +35,9 @@ pub(crate) struct HmacSha256 { } impl HmacSha256 { + pub(crate) const IPAD: [u8; 64] = [0x36; 64]; + pub(crate) const OPAD: [u8; 64] = [0x5c; 64]; + /// Creates a new instance. /// /// # Arguments diff --git a/crates/components/hmac-sha256/src/prf/function/mod.rs b/crates/components/hmac-sha256/src/prf/function/mod.rs index 5b5537204a..6f1c246ada 100644 --- a/crates/components/hmac-sha256/src/prf/function/mod.rs +++ b/crates/components/hmac-sha256/src/prf/function/mod.rs @@ -20,12 +20,12 @@ pub(crate) enum Prf { impl Prf { pub(crate) fn alloc_master_secret( - config: Mode, + mode: Mode, vm: &mut dyn Vm, outer_partial: Array, inner_partial: Array, ) -> Result { - let prf = match config { + let prf = match mode { Mode::Reduced => Self::Local(reduced::PrfFunction::alloc_master_secret( vm, outer_partial, @@ -41,12 +41,12 @@ impl Prf { } pub(crate) fn alloc_key_expansion( - config: Mode, + mode: Mode, vm: &mut dyn Vm, outer_partial: Array, inner_partial: Array, ) -> Result { - let prf = match config { + let prf = match mode { Mode::Reduced => Self::Local(reduced::PrfFunction::alloc_key_expansion( vm, outer_partial, diff --git a/crates/components/hmac-sha256/src/prf/mod.rs b/crates/components/hmac-sha256/src/prf/mod.rs index eff57ea2b9..4b2d3133e8 100644 --- a/crates/components/hmac-sha256/src/prf/mod.rs +++ b/crates/components/hmac-sha256/src/prf/mod.rs @@ -1,9 +1,9 @@ -use crate::{sha256::Sha256, Mode, PrfError, PrfOutput, SessionKeys}; +use crate::{hmac::HmacSha256, sha256::Sha256, Mode, PrfError, PrfOutput}; use mpz_circuits::{circuits::xor, Circuit, CircuitBuilder}; use mpz_vm_core::{ memory::{ binary::{Binary, U32, U8}, - Array, FromRaw, MemoryExt, StaticSize, ToRaw, Vector, ViewExt, + Array, MemoryExt, StaticSize, Vector, ViewExt, }, Call, CallableExt, Vm, }; @@ -19,9 +19,8 @@ use function::Prf; /// MPC PRF for computing TLS 1.2 HMAC-SHA256 PRF. #[derive(Debug)] pub struct MpcPrf { - config: Mode, + mode: Mode, state: State, - circuits: Option, } impl MpcPrf { @@ -29,12 +28,11 @@ impl MpcPrf { /// /// # Arguments /// - /// `config` - The PRF config. - pub fn new(config: Mode) -> MpcPrf { + /// `mode` - The PRF config. + pub fn new(mode: Mode) -> MpcPrf { Self { - config, + mode, state: State::Initialized, - circuits: None, } } @@ -54,20 +52,35 @@ impl MpcPrf { return Err(PrfError::state("PRF not in initialized state")); }; - let circuits = Circuits::alloc(self.config, vm, pms.into())?; + let mode = self.mode; + let pms: Vector = pms.into(); - let keys = circuits.get_session_keys(vm)?; - let cf_vd = circuits.get_client_finished_vd(vm)?; - let sf_vd = circuits.get_server_finished_vd(vm)?; + let outer_partial_pms = compute_partial(vm, pms, HmacSha256::OPAD)?; + let inner_partial_pms = compute_partial(vm, pms, HmacSha256::IPAD)?; - let prf_output = PrfOutput { keys, cf_vd, sf_vd }; + let master_secret = + Prf::alloc_master_secret(mode, vm, outer_partial_pms, inner_partial_pms)?; + let ms = master_secret.output(); + let ms = merge_outputs(vm, ms, 48)?; + + let outer_partial_ms = compute_partial(vm, ms, HmacSha256::OPAD)?; + let inner_partial_ms = compute_partial(vm, ms, HmacSha256::IPAD)?; + + let key_expansion = Prf::alloc_key_expansion(mode, vm, outer_partial_ms, inner_partial_ms)?; + let client_finished = + Prf::alloc_client_finished(mode, vm, outer_partial_ms, inner_partial_ms)?; + let server_finished = + Prf::alloc_server_finished(mode, vm, outer_partial_ms, inner_partial_ms)?; - self.circuits = Some(circuits); self.state = State::SessionKeys { client_random: None, + master_secret, + key_expansion, + client_finished, + server_finished, }; - Ok(prf_output) + self.state.prf_output(vm) } /// Sets the client random. @@ -77,7 +90,7 @@ impl MpcPrf { /// * `random` - The client random. #[instrument(level = "debug", skip_all, err)] pub fn set_client_random(&mut self, random: [u8; 32]) -> Result<(), PrfError> { - let State::SessionKeys { client_random } = &mut self.state else { + let State::SessionKeys { client_random, .. } = &mut self.state else { return Err(PrfError::state("PRF not set up")); }; @@ -92,26 +105,27 @@ impl MpcPrf { /// * `random` - The server random. #[instrument(level = "debug", skip_all, err)] pub fn set_server_random(&mut self, random: [u8; 32]) -> Result<(), PrfError> { - let State::SessionKeys { client_random } = self.state.take() else { + let State::SessionKeys { + client_random, + master_secret, + key_expansion, + .. + } = &mut self.state + else { return Err(PrfError::state("PRF not set up")); }; - let Some(ref mut circuits) = self.circuits else { - return Err(PrfError::state("Circuits should have been set for PRF")); - }; - let client_random = client_random.expect("Client random should have been set by now"); let server_random = random; let mut seed_ms = client_random.to_vec(); seed_ms.extend_from_slice(&server_random); - circuits.master_secret.set_start_seed(seed_ms); + master_secret.set_start_seed(seed_ms); let mut seed_ke = server_random.to_vec(); seed_ke.extend_from_slice(&client_random); - circuits.key_expansion.set_start_seed(seed_ke); + key_expansion.set_start_seed(seed_ke); - self.state = State::ClientFinished; Ok(()) } @@ -122,18 +136,16 @@ impl MpcPrf { /// * `handshake_hash` - The handshake transcript hash. #[instrument(level = "debug", skip_all, err)] pub fn set_cf_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), PrfError> { - let State::ClientFinished = self.state.take() else { + let State::ClientFinished { + client_finished, .. + } = &mut self.state + else { return Err(PrfError::state("PRF not in client finished state")); }; - let Some(ref mut circuits) = self.circuits else { - return Err(PrfError::state("Circuits should have been set for PRF")); - }; - let seed_cf = handshake_hash.to_vec(); - circuits.client_finished.set_start_seed(seed_cf); + client_finished.set_start_seed(seed_cf); - self.state = State::ServerFinished; Ok(()) } @@ -144,175 +156,19 @@ impl MpcPrf { /// * `handshake_hash` - The handshake transcript hash. #[instrument(level = "debug", skip_all, err)] pub fn set_sf_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), PrfError> { - let State::ServerFinished = self.state.take() else { + let State::ServerFinished { server_finished } = &mut self.state else { return Err(PrfError::state("PRF not in server finished state")); }; - let Some(ref mut circuits) = self.circuits else { - return Err(PrfError::state("Circuits should have been set for PRF")); - }; - let seed_sf = handshake_hash.to_vec(); - circuits.server_finished.set_start_seed(seed_sf); + server_finished.set_start_seed(seed_sf); - self.state = State::Complete; Ok(()) } - /// Drives the computation of the session keys. - /// - /// Returns if all inputs have been assigned for the computation of the - /// final output. - /// - /// # Arguments - /// - /// * `vm` - Virtual machine. - #[instrument(level = "debug", skip_all, err)] - pub fn drive_key_expansion(&mut self, vm: &mut dyn Vm) -> Result { - let Some(ref mut circuits) = self.circuits else { - return Err(PrfError::state("Circuits should have been set for PRF")); - }; - - circuits.drive_key_expansion(vm) - } - - /// Drives the computation of the client_finished verify_data. - /// - /// Returns if all inputs have been assigned for the computation of the - /// final output. - /// - /// # Arguments - /// - /// * `vm` - Virtual machine. - #[instrument(level = "debug", skip_all, err)] - pub fn drive_client_finished(&mut self, vm: &mut dyn Vm) -> Result { - let Some(ref mut circuits) = self.circuits else { - return Err(PrfError::state("Circuits should have been set for PRF")); - }; - - circuits.drive_client_finished(vm) - } - - /// Drives the computation of the server_finished verify_data. - /// - /// Returns if all inputs have been assigned for the computation of the - /// final output. - /// - /// # Arguments - /// - /// * `vm` - Virtual machine. - #[instrument(level = "debug", skip_all, err)] - pub fn drive_server_finished(&mut self, vm: &mut dyn Vm) -> Result { - let Some(ref mut circuits) = self.circuits else { - return Err(PrfError::state("Circuits should have been set for PRF")); - }; - - circuits.drive_server_finished(vm) - } -} - -/// Contains the respective [`PrfFunction`]s. -#[derive(Debug)] -struct Circuits { - pub(crate) master_secret: Prf, - pub(crate) key_expansion: Prf, - pub(crate) client_finished: Prf, - pub(crate) server_finished: Prf, -} - -impl Circuits { - const IPAD: [u8; 64] = [0x36; 64]; - const OPAD: [u8; 64] = [0x5c; 64]; - - fn alloc(config: Mode, vm: &mut dyn Vm, pms: Vector) -> Result { - let outer_partial_pms = compute_partial(vm, pms, Self::OPAD)?; - let inner_partial_pms = compute_partial(vm, pms, Self::IPAD)?; - - let master_secret = - Prf::alloc_master_secret(config, vm, outer_partial_pms, inner_partial_pms)?; - let ms = master_secret.output(); - let ms = merge_outputs(vm, ms, 48)?; - - let outer_partial_ms = compute_partial(vm, ms, Self::OPAD)?; - let inner_partial_ms = compute_partial(vm, ms, Self::IPAD)?; - - let circuits = Self { - master_secret, - key_expansion: Prf::alloc_key_expansion( - config, - vm, - outer_partial_ms, - inner_partial_ms, - )?, - client_finished: Prf::alloc_client_finished( - config, - vm, - outer_partial_ms, - inner_partial_ms, - )?, - server_finished: Prf::alloc_server_finished( - config, - vm, - outer_partial_ms, - inner_partial_ms, - )?, - }; - Ok(circuits) - } - - fn get_session_keys(&self, vm: &mut dyn Vm) -> Result { - let keys = self.key_expansion.output(); - let mut keys = merge_outputs(vm, keys, 40)?; - - let server_iv = as FromRaw>::from_raw(keys.split_off(36).to_raw()); - let client_iv = as FromRaw>::from_raw(keys.split_off(32).to_raw()); - let server_write_key = - as FromRaw>::from_raw(keys.split_off(16).to_raw()); - let client_write_key = as FromRaw>::from_raw(keys.to_raw()); - - let session_keys = SessionKeys { - client_write_key, - server_write_key, - client_iv, - server_iv, - }; - - Ok(session_keys) - } - - fn get_client_finished_vd(&self, vm: &mut dyn Vm) -> Result, PrfError> { - let client_finished = &self.client_finished; - let cf_vd = client_finished.output(); - - let cf_vd = merge_outputs(vm, cf_vd, 12)?; - let cf_vd = as FromRaw>::from_raw(cf_vd.to_raw()); - - Ok(cf_vd) - } - - fn get_server_finished_vd(&self, vm: &mut dyn Vm) -> Result, PrfError> { - let server_finished = &self.server_finished; - let sf_vd = server_finished.output(); - - let sf_vd = merge_outputs(vm, sf_vd, 12)?; - let sf_vd = as FromRaw>::from_raw(sf_vd.to_raw()); - - Ok(sf_vd) - } - - fn drive_key_expansion(&mut self, vm: &mut dyn Vm) -> Result { - let ms_finished = self.master_secret.make_progress(vm)?; - let ke_finished = self.key_expansion.make_progress(vm)?; - - Ok(ms_finished && ke_finished) - } - - fn drive_client_finished(&mut self, vm: &mut dyn Vm) -> Result { - self.client_finished.make_progress(vm) - } - - fn drive_server_finished(&mut self, vm: &mut dyn Vm) -> Result { - self.server_finished.make_progress(vm) + /// Returns if the PRF needs to be flushed. + pub fn wants_flush(&self) -> bool { + todo!() } } diff --git a/crates/components/hmac-sha256/src/prf/state.rs b/crates/components/hmac-sha256/src/prf/state.rs index 43cdca42d1..67dabae631 100644 --- a/crates/components/hmac-sha256/src/prf/state.rs +++ b/crates/components/hmac-sha256/src/prf/state.rs @@ -1,9 +1,32 @@ +use crate::{ + prf::{function::Prf, merge_outputs}, + PrfError, PrfOutput, SessionKeys, +}; +use mpz_vm_core::{ + memory::{ + binary::{Binary, U32, U8}, + Array, FromRaw, ToRaw, + }, + Vm, +}; + #[derive(Debug)] pub(crate) enum State { Initialized, - SessionKeys { client_random: Option<[u8; 32]> }, - ClientFinished, - ServerFinished, + SessionKeys { + client_random: Option<[u8; 32]>, + master_secret: Prf, + key_expansion: Prf, + client_finished: Prf, + server_finished: Prf, + }, + ClientFinished { + client_finished: Prf, + server_finished: Prf, + }, + ServerFinished { + server_finished: Prf, + }, Complete, Error, } @@ -12,4 +35,68 @@ impl State { pub(crate) fn take(&mut self) -> State { std::mem::replace(self, State::Error) } + + pub(crate) fn prf_output(&self, vm: &mut dyn Vm) -> Result { + let State::SessionKeys { + key_expansion, + client_finished, + server_finished, + .. + } = self + else { + return Err(PrfError::state( + "Prf output can only be computed while in \"SessionKeys\" state", + )); + }; + + let keys = get_session_keys(key_expansion.output(), vm)?; + let cf_vd = get_client_finished_vd(client_finished.output(), vm)?; + let sf_vd = get_server_finished_vd(server_finished.output(), vm)?; + + let output = PrfOutput { keys, cf_vd, sf_vd }; + + Ok(output) + } +} + +fn get_session_keys( + output: Vec>, + vm: &mut dyn Vm, +) -> Result { + let mut keys = merge_outputs(vm, output, 40)?; + + let server_iv = as FromRaw>::from_raw(keys.split_off(36).to_raw()); + let client_iv = as FromRaw>::from_raw(keys.split_off(32).to_raw()); + let server_write_key = + as FromRaw>::from_raw(keys.split_off(16).to_raw()); + let client_write_key = as FromRaw>::from_raw(keys.to_raw()); + + let session_keys = SessionKeys { + client_write_key, + server_write_key, + client_iv, + server_iv, + }; + + Ok(session_keys) +} + +fn get_client_finished_vd( + output: Vec>, + vm: &mut dyn Vm, +) -> Result, PrfError> { + let cf_vd = merge_outputs(vm, output, 12)?; + let cf_vd = as FromRaw>::from_raw(cf_vd.to_raw()); + + Ok(cf_vd) +} + +fn get_server_finished_vd( + output: Vec>, + vm: &mut dyn Vm, +) -> Result, PrfError> { + let sf_vd = merge_outputs(vm, output, 12)?; + let sf_vd = as FromRaw>::from_raw(sf_vd.to_raw()); + + Ok(sf_vd) } From ff362319faacbd7356a1ee112fdeacde45c65e6e Mon Sep 17 00:00:00 2001 From: th4s Date: Fri, 25 Apr 2025 13:39:17 +0200 Subject: [PATCH 07/36] improve external flush handling --- .../hmac-sha256/src/prf/function/mod.rs | 7 ++++ .../hmac-sha256/src/prf/function/normal.rs | 4 ++ .../hmac-sha256/src/prf/function/reduced.rs | 4 ++ crates/components/hmac-sha256/src/prf/mod.rs | 37 ++++++++++++++++++- .../components/hmac-sha256/src/prf/state.rs | 1 + crates/mpc-tls/src/follower.rs | 30 ++------------- crates/mpc-tls/src/leader.rs | 30 ++------------- 7 files changed, 57 insertions(+), 56 deletions(-) diff --git a/crates/components/hmac-sha256/src/prf/function/mod.rs b/crates/components/hmac-sha256/src/prf/function/mod.rs index 6f1c246ada..ecaa511d15 100644 --- a/crates/components/hmac-sha256/src/prf/function/mod.rs +++ b/crates/components/hmac-sha256/src/prf/function/mod.rs @@ -103,6 +103,13 @@ impl Prf { Ok(prf) } + pub(crate) fn wants_flush(&self) -> bool { + match self { + Prf::Local(prf) => prf.wants_flush(), + Prf::Mpc(prf) => prf.wants_flush(), + } + } + pub(crate) fn make_progress(&mut self, vm: &mut dyn Vm) -> Result { match self { Prf::Local(prf) => prf.make_progress(vm), diff --git a/crates/components/hmac-sha256/src/prf/function/normal.rs b/crates/components/hmac-sha256/src/prf/function/normal.rs index 9b1c04b7bb..f8ba561999 100644 --- a/crates/components/hmac-sha256/src/prf/function/normal.rs +++ b/crates/components/hmac-sha256/src/prf/function/normal.rs @@ -58,6 +58,10 @@ impl PrfFunction { Self::alloc(vm, Self::SF_LABEL, outer_partial, inner_partial, 12, 32) } + pub(crate) fn wants_flush(&self) -> bool { + todo!() + } + pub(crate) fn make_progress(&mut self, vm: &mut dyn Vm) -> Result { if !self.assigned { let a = self.a.first_mut().expect("prf should be allocated"); diff --git a/crates/components/hmac-sha256/src/prf/function/reduced.rs b/crates/components/hmac-sha256/src/prf/function/reduced.rs index 63e31f1782..0431e263f0 100644 --- a/crates/components/hmac-sha256/src/prf/function/reduced.rs +++ b/crates/components/hmac-sha256/src/prf/function/reduced.rs @@ -55,6 +55,10 @@ impl PrfFunction { Self::alloc(vm, Self::SF_LABEL, outer_partial, inner_partial, 12) } + pub(crate) fn wants_flush(&self) -> bool { + todo!() + } + pub(crate) fn make_progress(&mut self, vm: &mut dyn Vm) -> Result { let a_assigned = self.is_a_assigned(); let mut p_assigned = self.is_p_assigned(); diff --git a/crates/components/hmac-sha256/src/prf/mod.rs b/crates/components/hmac-sha256/src/prf/mod.rs index 4b2d3133e8..975d9a84f1 100644 --- a/crates/components/hmac-sha256/src/prf/mod.rs +++ b/crates/components/hmac-sha256/src/prf/mod.rs @@ -167,8 +167,41 @@ impl MpcPrf { } /// Returns if the PRF needs to be flushed. - pub fn wants_flush(&self) -> bool { - todo!() + /// + /// Also drives forward the inner state. + pub fn wants_flush(&mut self) -> bool { + let wants_flush = match &self.state { + State::Initialized => false, + State::SessionKeys { + master_secret, + key_expansion, + .. + } => master_secret.wants_flush() || key_expansion.wants_flush(), + State::ClientFinished { + client_finished, .. + } => client_finished.wants_flush(), + State::ServerFinished { server_finished } => server_finished.wants_flush(), + State::Complete => false, + State::Error => false, + }; + + self.state = match self.state.take() { + State::SessionKeys { + client_finished, + server_finished, + .. + } if !wants_flush => State::ClientFinished { + client_finished, + server_finished, + }, + State::ClientFinished { + server_finished, .. + } if !wants_flush => State::ServerFinished { server_finished }, + State::ServerFinished { .. } if !wants_flush => State::Complete, + other => other, + }; + + wants_flush } } diff --git a/crates/components/hmac-sha256/src/prf/state.rs b/crates/components/hmac-sha256/src/prf/state.rs index 67dabae631..52eaec68de 100644 --- a/crates/components/hmac-sha256/src/prf/state.rs +++ b/crates/components/hmac-sha256/src/prf/state.rs @@ -10,6 +10,7 @@ use mpz_vm_core::{ Vm, }; +#[allow(clippy::large_enum_variant)] #[derive(Debug)] pub(crate) enum State { Initialized, diff --git a/crates/mpc-tls/src/follower.rs b/crates/mpc-tls/src/follower.rs index f932f376e0..dbd618d3a9 100644 --- a/crates/mpc-tls/src/follower.rs +++ b/crates/mpc-tls/src/follower.rs @@ -271,18 +271,10 @@ impl MpcTlsFollower { ke.compute_shares(&mut self.ctx).await?; ke.assign(&mut (*vm))?; - loop { - let assigned = prf - .drive_key_expansion(&mut (*vm)) - .map_err(MpcTlsError::hs)?; - + while prf.wants_flush() { vm.execute_all(&mut self.ctx) .await .map_err(MpcTlsError::hs)?; - - if assigned { - break; - } } ke.finalize().await?; @@ -299,18 +291,10 @@ impl MpcTlsFollower { prf.set_cf_hash(vd.handshake_hash)?; - loop { - let assigned = prf - .drive_client_finished(&mut (*vm)) - .map_err(MpcTlsError::hs)?; - + while prf.wants_flush() { vm.execute_all(&mut self.ctx) .await .map_err(MpcTlsError::hs)?; - - if assigned { - break; - } } cf_vd = Some( @@ -331,18 +315,10 @@ impl MpcTlsFollower { prf.set_sf_hash(vd.handshake_hash)?; - loop { - let assigned = prf - .drive_server_finished(&mut (*vm)) - .map_err(MpcTlsError::hs)?; - + while prf.wants_flush() { vm.execute_all(&mut self.ctx) .await .map_err(MpcTlsError::hs)?; - - if assigned { - break; - } } sf_vd = Some( diff --git a/crates/mpc-tls/src/leader.rs b/crates/mpc-tls/src/leader.rs index 8fcfbcf65d..07c86d24e7 100644 --- a/crates/mpc-tls/src/leader.rs +++ b/crates/mpc-tls/src/leader.rs @@ -540,16 +540,8 @@ impl Backend for MpcTlsLeader { .map_err(|_| MpcTlsError::other("VM lock is held"))?; prf.set_sf_hash(hash).map_err(MpcTlsError::hs)?; - loop { - let assigned = prf - .drive_server_finished(&mut (*vm)) - .map_err(MpcTlsError::hs)?; - + while prf.wants_flush() { vm.execute_all(ctx).await.map_err(MpcTlsError::hs)?; - - if assigned { - break; - } } let sf_vd = sf_vd @@ -593,16 +585,8 @@ impl Backend for MpcTlsLeader { .map_err(|_| MpcTlsError::hs("VM lock is held"))?; prf.set_cf_hash(hash).map_err(MpcTlsError::hs)?; - loop { - let assigned = prf - .drive_client_finished(&mut (*vm)) - .map_err(MpcTlsError::hs)?; - + while prf.wants_flush() { vm.execute_all(ctx).await.map_err(MpcTlsError::hs)?; - - if assigned { - break; - } } let cf_vd = cf_vd @@ -666,19 +650,11 @@ impl Backend for MpcTlsLeader { ke.assign(&mut (*vm_lock)).map_err(MpcTlsError::hs)?; - loop { - let assigned = prf - .drive_key_expansion(&mut (*vm_lock)) - .map_err(MpcTlsError::hs)?; - + while prf.wants_flush() { vm_lock .execute_all(&mut ctx) .await .map_err(MpcTlsError::hs)?; - - if assigned { - break; - } } ke.finalize().await.map_err(MpcTlsError::hs)?; From 30446bdb02edeb3defa4389bc8bf76f817f526c7 Mon Sep 17 00:00:00 2001 From: th4s Date: Fri, 25 Apr 2025 14:28:48 +0200 Subject: [PATCH 08/36] improve control flow --- .../hmac-sha256/src/prf/function/mod.rs | 7 +++ .../hmac-sha256/src/prf/function/normal.rs | 4 ++ .../hmac-sha256/src/prf/function/reduced.rs | 4 ++ crates/components/hmac-sha256/src/prf/mod.rs | 57 +++++++++++++------ crates/mpc-tls/src/follower.rs | 3 + crates/mpc-tls/src/leader.rs | 3 + 6 files changed, 62 insertions(+), 16 deletions(-) diff --git a/crates/components/hmac-sha256/src/prf/function/mod.rs b/crates/components/hmac-sha256/src/prf/function/mod.rs index ecaa511d15..4c071222c8 100644 --- a/crates/components/hmac-sha256/src/prf/function/mod.rs +++ b/crates/components/hmac-sha256/src/prf/function/mod.rs @@ -110,6 +110,13 @@ impl Prf { } } + pub(crate) fn flush(&mut self) -> Result<(), PrfError> { + match self { + Prf::Local(prf) => prf.flush(), + Prf::Mpc(prf) => prf.flush(), + } + } + pub(crate) fn make_progress(&mut self, vm: &mut dyn Vm) -> Result { match self { Prf::Local(prf) => prf.make_progress(vm), diff --git a/crates/components/hmac-sha256/src/prf/function/normal.rs b/crates/components/hmac-sha256/src/prf/function/normal.rs index f8ba561999..ac2c5ae6a1 100644 --- a/crates/components/hmac-sha256/src/prf/function/normal.rs +++ b/crates/components/hmac-sha256/src/prf/function/normal.rs @@ -62,6 +62,10 @@ impl PrfFunction { todo!() } + pub(crate) fn flush(&mut self) -> Result<(), PrfError> { + todo!() + } + pub(crate) fn make_progress(&mut self, vm: &mut dyn Vm) -> Result { if !self.assigned { let a = self.a.first_mut().expect("prf should be allocated"); diff --git a/crates/components/hmac-sha256/src/prf/function/reduced.rs b/crates/components/hmac-sha256/src/prf/function/reduced.rs index 0431e263f0..a44a7cd5fc 100644 --- a/crates/components/hmac-sha256/src/prf/function/reduced.rs +++ b/crates/components/hmac-sha256/src/prf/function/reduced.rs @@ -59,6 +59,10 @@ impl PrfFunction { todo!() } + pub(crate) fn flush(&mut self) -> Result<(), PrfError> { + todo!() + } + pub(crate) fn make_progress(&mut self, vm: &mut dyn Vm) -> Result { let a_assigned = self.is_a_assigned(); let mut p_assigned = self.is_p_assigned(); diff --git a/crates/components/hmac-sha256/src/prf/mod.rs b/crates/components/hmac-sha256/src/prf/mod.rs index 975d9a84f1..e9439a5bf5 100644 --- a/crates/components/hmac-sha256/src/prf/mod.rs +++ b/crates/components/hmac-sha256/src/prf/mod.rs @@ -166,9 +166,7 @@ impl MpcPrf { Ok(()) } - /// Returns if the PRF needs to be flushed. - /// - /// Also drives forward the inner state. + /// Returns if the PRF needs to be flushed and drives the PRF. pub fn wants_flush(&mut self) -> bool { let wants_flush = match &self.state { State::Initialized => false, @@ -185,23 +183,50 @@ impl MpcPrf { State::Error => false, }; - self.state = match self.state.take() { + if !wants_flush { + self.state = match self.state.take() { + State::SessionKeys { + client_finished, + server_finished, + .. + } => State::ClientFinished { + client_finished, + server_finished, + }, + State::ClientFinished { + server_finished, .. + } => State::ServerFinished { server_finished }, + State::ServerFinished { .. } => State::Complete, + other => other, + }; + } + + wants_flush + } + + /// Flushes the PRF. + pub fn flush(&mut self) -> Result<(), PrfError> { + match &mut self.state { State::SessionKeys { - client_finished, - server_finished, + master_secret, + key_expansion, .. - } if !wants_flush => State::ClientFinished { - client_finished, - server_finished, - }, + } => { + master_secret.flush()?; + key_expansion.flush()?; + } State::ClientFinished { - server_finished, .. - } if !wants_flush => State::ServerFinished { server_finished }, - State::ServerFinished { .. } if !wants_flush => State::Complete, - other => other, - }; + client_finished, .. + } => { + client_finished.flush()?; + } + State::ServerFinished { server_finished } => { + server_finished.flush()?; + } + _ => (), + } - wants_flush + Ok(()) } } diff --git a/crates/mpc-tls/src/follower.rs b/crates/mpc-tls/src/follower.rs index dbd618d3a9..7024022334 100644 --- a/crates/mpc-tls/src/follower.rs +++ b/crates/mpc-tls/src/follower.rs @@ -275,6 +275,7 @@ impl MpcTlsFollower { vm.execute_all(&mut self.ctx) .await .map_err(MpcTlsError::hs)?; + prf.flush()?; } ke.finalize().await?; @@ -295,6 +296,7 @@ impl MpcTlsFollower { vm.execute_all(&mut self.ctx) .await .map_err(MpcTlsError::hs)?; + prf.flush()?; } cf_vd = Some( @@ -319,6 +321,7 @@ impl MpcTlsFollower { vm.execute_all(&mut self.ctx) .await .map_err(MpcTlsError::hs)?; + prf.flush()?; } sf_vd = Some( diff --git a/crates/mpc-tls/src/leader.rs b/crates/mpc-tls/src/leader.rs index 07c86d24e7..08ef8a45da 100644 --- a/crates/mpc-tls/src/leader.rs +++ b/crates/mpc-tls/src/leader.rs @@ -542,6 +542,7 @@ impl Backend for MpcTlsLeader { while prf.wants_flush() { vm.execute_all(ctx).await.map_err(MpcTlsError::hs)?; + prf.flush().map_err(MpcTlsError::hs)?; } let sf_vd = sf_vd @@ -587,6 +588,7 @@ impl Backend for MpcTlsLeader { while prf.wants_flush() { vm.execute_all(ctx).await.map_err(MpcTlsError::hs)?; + prf.flush().map_err(MpcTlsError::hs)?; } let cf_vd = cf_vd @@ -655,6 +657,7 @@ impl Backend for MpcTlsLeader { .execute_all(&mut ctx) .await .map_err(MpcTlsError::hs)?; + prf.flush().map_err(MpcTlsError::hs)?; } ke.finalize().await.map_err(MpcTlsError::hs)?; From 21c77de92c68b252d47e0038616116b009aeab3c Mon Sep 17 00:00:00 2001 From: th4s Date: Fri, 25 Apr 2025 15:03:37 +0200 Subject: [PATCH 09/36] improved inner control flow for normal prf version --- crates/components/hmac-sha256/src/lib.rs | 85 +++++++++---------- .../hmac-sha256/src/prf/function/mod.rs | 70 +++++++-------- .../hmac-sha256/src/prf/function/normal.rs | 29 ++++--- .../hmac-sha256/src/prf/function/reduced.rs | 2 +- crates/components/hmac-sha256/src/prf/mod.rs | 10 +-- crates/mpc-tls/src/follower.rs | 6 +- crates/mpc-tls/src/leader.rs | 6 +- 7 files changed, 101 insertions(+), 107 deletions(-) diff --git a/crates/components/hmac-sha256/src/lib.rs b/crates/components/hmac-sha256/src/lib.rs index 759d373495..105bd5f60f 100644 --- a/crates/components/hmac-sha256/src/lib.rs +++ b/crates/components/hmac-sha256/src/lib.rs @@ -67,13 +67,13 @@ mod tests { use rand::{rngs::StdRng, Rng, SeedableRng}; #[tokio::test] - async fn test_prf_local() { + async fn test_prf_reduced() { let config = Mode::Reduced; test_prf(config).await; } #[tokio::test] - async fn test_prf_mpc() { + async fn test_prf_normal() { let config = Mode::Normal; test_prf(config).await; } @@ -119,18 +119,18 @@ mod tests { follower.assign(follower_pms, pms).unwrap(); follower.commit(follower_pms).unwrap(); - let mut leader_prf = MpcPrf::new(config); - let mut follower_prf = MpcPrf::new(config); + let mut prf_leader = MpcPrf::new(config); + let mut prf_follower = MpcPrf::new(config); - let leader_prf_out = leader_prf.alloc(&mut leader, leader_pms).unwrap(); - let follower_prf_out = follower_prf.alloc(&mut follower, follower_pms).unwrap(); + let leader_prf_out = prf_leader.alloc(&mut leader, leader_pms).unwrap(); + let follower_prf_out = prf_follower.alloc(&mut follower, follower_pms).unwrap(); // client_random and server_random - leader_prf.set_client_random(client_random).unwrap(); - follower_prf.set_client_random(client_random).unwrap(); + prf_leader.set_client_random(client_random).unwrap(); + prf_follower.set_client_random(client_random).unwrap(); - leader_prf.set_server_random(server_random).unwrap(); - follower_prf.set_server_random(server_random).unwrap(); + prf_leader.set_server_random(server_random).unwrap(); + prf_follower.set_server_random(server_random).unwrap(); let SessionKeys { client_write_key: cwk_leader, @@ -156,19 +156,18 @@ mod tests { let mut civ_follower = follower.decode(civ_follower).unwrap(); let mut siv_follower = follower.decode(siv_follower).unwrap(); - loop { - let leader_finished = leader_prf.drive_key_expansion(&mut leader).unwrap(); - let follower_finished = follower_prf.drive_key_expansion(&mut follower).unwrap(); - + while prf_leader.wants_flush() || prf_follower.wants_flush() { tokio::try_join!( - leader.execute_all(&mut ctx_a), - follower.execute_all(&mut ctx_b) + async { + prf_leader.flush(&mut leader).unwrap(); + leader.execute_all(&mut ctx_a).await + }, + async { + prf_follower.flush(&mut follower).unwrap(); + follower.execute_all(&mut ctx_b).await + } ) .unwrap(); - - if leader_finished && follower_finished { - break; - } } let cwk_leader = cwk_leader.try_recv().unwrap().unwrap(); @@ -192,8 +191,8 @@ mod tests { assert_eq!(siv_leader, siv_expected); // client finished - leader_prf.set_cf_hash(cf_hs_hash).unwrap(); - follower_prf.set_cf_hash(cf_hs_hash).unwrap(); + prf_leader.set_cf_hash(cf_hs_hash).unwrap(); + prf_follower.set_cf_hash(cf_hs_hash).unwrap(); let cf_vd_leader = leader_prf_out.cf_vd; let cf_vd_follower = follower_prf_out.cf_vd; @@ -201,19 +200,18 @@ mod tests { let mut cf_vd_leader = leader.decode(cf_vd_leader).unwrap(); let mut cf_vd_follower = follower.decode(cf_vd_follower).unwrap(); - loop { - let leader_finished = leader_prf.drive_client_finished(&mut leader).unwrap(); - let follower_finished = follower_prf.drive_client_finished(&mut follower).unwrap(); - + while prf_leader.wants_flush() || prf_follower.wants_flush() { tokio::try_join!( - leader.execute_all(&mut ctx_a), - follower.execute_all(&mut ctx_b) + async { + prf_leader.flush(&mut leader).unwrap(); + leader.execute_all(&mut ctx_a).await + }, + async { + prf_follower.flush(&mut follower).unwrap(); + follower.execute_all(&mut ctx_b).await + } ) .unwrap(); - - if leader_finished && follower_finished { - break; - } } let cf_vd_leader = cf_vd_leader.try_recv().unwrap().unwrap(); @@ -223,8 +221,8 @@ mod tests { assert_eq!(cf_vd_leader, cf_vd_expected); // server finished - leader_prf.set_sf_hash(sf_hs_hash).unwrap(); - follower_prf.set_sf_hash(sf_hs_hash).unwrap(); + prf_leader.set_sf_hash(sf_hs_hash).unwrap(); + prf_follower.set_sf_hash(sf_hs_hash).unwrap(); let sf_vd_leader = leader_prf_out.sf_vd; let sf_vd_follower = follower_prf_out.sf_vd; @@ -232,19 +230,18 @@ mod tests { let mut sf_vd_leader = leader.decode(sf_vd_leader).unwrap(); let mut sf_vd_follower = follower.decode(sf_vd_follower).unwrap(); - loop { - let leader_finished = leader_prf.drive_server_finished(&mut leader).unwrap(); - let follower_finished = follower_prf.drive_server_finished(&mut follower).unwrap(); - + while prf_leader.wants_flush() || prf_follower.wants_flush() { tokio::try_join!( - leader.execute_all(&mut ctx_a), - follower.execute_all(&mut ctx_b) + async { + prf_leader.flush(&mut leader).unwrap(); + leader.execute_all(&mut ctx_a).await + }, + async { + prf_follower.flush(&mut follower).unwrap(); + follower.execute_all(&mut ctx_b).await + } ) .unwrap(); - - if leader_finished && follower_finished { - break; - } } let sf_vd_leader = sf_vd_leader.try_recv().unwrap().unwrap(); diff --git a/crates/components/hmac-sha256/src/prf/function/mod.rs b/crates/components/hmac-sha256/src/prf/function/mod.rs index 4c071222c8..49d72b7c1e 100644 --- a/crates/components/hmac-sha256/src/prf/function/mod.rs +++ b/crates/components/hmac-sha256/src/prf/function/mod.rs @@ -14,8 +14,8 @@ mod reduced; #[derive(Debug)] pub(crate) enum Prf { - Local(reduced::PrfFunction), - Mpc(normal::PrfFunction), + Reduced(reduced::PrfFunction), + Normal(normal::PrfFunction), } impl Prf { @@ -26,12 +26,12 @@ impl Prf { inner_partial: Array, ) -> Result { let prf = match mode { - Mode::Reduced => Self::Local(reduced::PrfFunction::alloc_master_secret( + Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_master_secret( vm, outer_partial, inner_partial, )?), - Mode::Normal => Self::Mpc(normal::PrfFunction::alloc_master_secret( + Mode::Normal => Self::Normal(normal::PrfFunction::alloc_master_secret( vm, outer_partial, inner_partial, @@ -47,12 +47,12 @@ impl Prf { inner_partial: Array, ) -> Result { let prf = match mode { - Mode::Reduced => Self::Local(reduced::PrfFunction::alloc_key_expansion( + Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_key_expansion( vm, outer_partial, inner_partial, )?), - Mode::Normal => Self::Mpc(normal::PrfFunction::alloc_key_expansion( + Mode::Normal => Self::Normal(normal::PrfFunction::alloc_key_expansion( vm, outer_partial, inner_partial, @@ -68,12 +68,12 @@ impl Prf { inner_partial: Array, ) -> Result { let prf = match config { - Mode::Reduced => Self::Local(reduced::PrfFunction::alloc_client_finished( + Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_client_finished( vm, outer_partial, inner_partial, )?), - Mode::Normal => Self::Mpc(normal::PrfFunction::alloc_client_finished( + Mode::Normal => Self::Normal(normal::PrfFunction::alloc_client_finished( vm, outer_partial, inner_partial, @@ -89,12 +89,12 @@ impl Prf { inner_partial: Array, ) -> Result { let prf = match config { - Mode::Reduced => Self::Local(reduced::PrfFunction::alloc_server_finished( + Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_server_finished( vm, outer_partial, inner_partial, )?), - Mode::Normal => Self::Mpc(normal::PrfFunction::alloc_server_finished( + Mode::Normal => Self::Normal(normal::PrfFunction::alloc_server_finished( vm, outer_partial, inner_partial, @@ -105,36 +105,29 @@ impl Prf { pub(crate) fn wants_flush(&self) -> bool { match self { - Prf::Local(prf) => prf.wants_flush(), - Prf::Mpc(prf) => prf.wants_flush(), + Prf::Reduced(prf) => prf.wants_flush(), + Prf::Normal(prf) => prf.wants_flush(), } } - pub(crate) fn flush(&mut self) -> Result<(), PrfError> { + pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { match self { - Prf::Local(prf) => prf.flush(), - Prf::Mpc(prf) => prf.flush(), - } - } - - pub(crate) fn make_progress(&mut self, vm: &mut dyn Vm) -> Result { - match self { - Prf::Local(prf) => prf.make_progress(vm), - Prf::Mpc(prf) => prf.make_progress(vm), + Prf::Reduced(prf) => prf.flush(vm), + Prf::Normal(prf) => prf.flush(vm), } } pub(crate) fn set_start_seed(&mut self, seed: Vec) { match self { - Prf::Local(prf) => prf.set_start_seed(seed), - Prf::Mpc(prf) => prf.set_start_seed(seed), + Prf::Reduced(prf) => prf.set_start_seed(seed), + Prf::Normal(prf) => prf.set_start_seed(seed), } } pub(crate) fn output(&self) -> Vec> { match self { - Prf::Local(prf) => prf.output(), - Prf::Mpc(prf) => prf.output(), + Prf::Reduced(prf) => prf.output(), + Prf::Normal(prf) => prf.output(), } } } @@ -157,14 +150,14 @@ mod tests { const OPAD: [u8; 64] = [0x5c; 64]; #[tokio::test] - async fn test_phash_local() { + async fn test_phash_reduced() { let config = Mode::Reduced; test_phash(config).await; } #[tokio::test] - async fn test_phash_mpc() { - let config = Mode::Reduced; + async fn test_phash_normal() { + let config = Mode::Normal; test_phash(config).await; } @@ -227,19 +220,18 @@ mod tests { prf_out_follower.push(p_out) } - loop { - let leader_finished = prf_leader.make_progress(&mut leader).unwrap(); - let follower_finished = prf_follower.make_progress(&mut follower).unwrap(); - + while prf_leader.wants_flush() || prf_follower.wants_flush() { tokio::try_join!( - leader.execute_all(&mut ctx_a), - follower.execute_all(&mut ctx_b) + async { + prf_leader.flush(&mut leader).unwrap(); + leader.execute_all(&mut ctx_a).await + }, + async { + prf_follower.flush(&mut follower).unwrap(); + follower.execute_all(&mut ctx_b).await + } ) .unwrap(); - - if leader_finished && follower_finished { - break; - } } assert_eq!(prf_out_leader.len(), prf_out_follower.len()); diff --git a/crates/components/hmac-sha256/src/prf/function/normal.rs b/crates/components/hmac-sha256/src/prf/function/normal.rs index ac2c5ae6a1..dcf517c5a0 100644 --- a/crates/components/hmac-sha256/src/prf/function/normal.rs +++ b/crates/components/hmac-sha256/src/prf/function/normal.rs @@ -14,10 +14,10 @@ use std::sync::Arc; #[derive(Debug)] pub(crate) struct PrfFunction { label: &'static [u8], + state: State, start_seed_label: Option>, a: Vec, p: Vec, - assigned: bool, } impl PrfFunction { @@ -59,15 +59,14 @@ impl PrfFunction { } pub(crate) fn wants_flush(&self) -> bool { - todo!() - } - - pub(crate) fn flush(&mut self) -> Result<(), PrfError> { - todo!() + match self.state { + State::Computing => true, + State::Finished => false, + } } - pub(crate) fn make_progress(&mut self, vm: &mut dyn Vm) -> Result { - if !self.assigned { + pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { + if let State::Computing = self.state { let a = self.a.first_mut().expect("prf should be allocated"); let msg = a.msg; @@ -79,10 +78,10 @@ impl PrfFunction { vm.mark_public(msg).map_err(PrfError::vm)?; vm.assign(msg, msg_value).map_err(PrfError::vm)?; vm.commit(msg).map_err(PrfError::vm)?; - self.assigned = true; - } - Ok(self.assigned) + self.state = State::Finished; + } + Ok(()) } pub(crate) fn set_start_seed(&mut self, seed: Vec) { @@ -106,10 +105,10 @@ impl PrfFunction { ) -> Result { let mut prf = Self { label, + state: State::Computing, start_seed_label: None, a: vec![], p: vec![], - assigned: false, }; assert!(output_len > 0, "cannot compute 0 bytes for prf"); @@ -134,6 +133,12 @@ impl PrfFunction { } } +#[derive(Debug, Clone, Copy)] +enum State { + Computing, + Finished, +} + #[derive(Debug, Clone)] struct PHash { pub(crate) msg: Vector, diff --git a/crates/components/hmac-sha256/src/prf/function/reduced.rs b/crates/components/hmac-sha256/src/prf/function/reduced.rs index a44a7cd5fc..426d40aca8 100644 --- a/crates/components/hmac-sha256/src/prf/function/reduced.rs +++ b/crates/components/hmac-sha256/src/prf/function/reduced.rs @@ -59,7 +59,7 @@ impl PrfFunction { todo!() } - pub(crate) fn flush(&mut self) -> Result<(), PrfError> { + pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { todo!() } diff --git a/crates/components/hmac-sha256/src/prf/mod.rs b/crates/components/hmac-sha256/src/prf/mod.rs index e9439a5bf5..0b7ad31cad 100644 --- a/crates/components/hmac-sha256/src/prf/mod.rs +++ b/crates/components/hmac-sha256/src/prf/mod.rs @@ -205,23 +205,23 @@ impl MpcPrf { } /// Flushes the PRF. - pub fn flush(&mut self) -> Result<(), PrfError> { + pub fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { match &mut self.state { State::SessionKeys { master_secret, key_expansion, .. } => { - master_secret.flush()?; - key_expansion.flush()?; + master_secret.flush(vm)?; + key_expansion.flush(vm)?; } State::ClientFinished { client_finished, .. } => { - client_finished.flush()?; + client_finished.flush(vm)?; } State::ServerFinished { server_finished } => { - server_finished.flush()?; + server_finished.flush(vm)?; } _ => (), } diff --git a/crates/mpc-tls/src/follower.rs b/crates/mpc-tls/src/follower.rs index 7024022334..b3dcc0c0a0 100644 --- a/crates/mpc-tls/src/follower.rs +++ b/crates/mpc-tls/src/follower.rs @@ -272,10 +272,10 @@ impl MpcTlsFollower { ke.assign(&mut (*vm))?; while prf.wants_flush() { + prf.flush(&mut *vm)?; vm.execute_all(&mut self.ctx) .await .map_err(MpcTlsError::hs)?; - prf.flush()?; } ke.finalize().await?; @@ -293,10 +293,10 @@ impl MpcTlsFollower { prf.set_cf_hash(vd.handshake_hash)?; while prf.wants_flush() { + prf.flush(&mut *vm)?; vm.execute_all(&mut self.ctx) .await .map_err(MpcTlsError::hs)?; - prf.flush()?; } cf_vd = Some( @@ -318,10 +318,10 @@ impl MpcTlsFollower { prf.set_sf_hash(vd.handshake_hash)?; while prf.wants_flush() { + prf.flush(&mut *vm)?; vm.execute_all(&mut self.ctx) .await .map_err(MpcTlsError::hs)?; - prf.flush()?; } sf_vd = Some( diff --git a/crates/mpc-tls/src/leader.rs b/crates/mpc-tls/src/leader.rs index 08ef8a45da..5ea16be1e7 100644 --- a/crates/mpc-tls/src/leader.rs +++ b/crates/mpc-tls/src/leader.rs @@ -541,8 +541,8 @@ impl Backend for MpcTlsLeader { prf.set_sf_hash(hash).map_err(MpcTlsError::hs)?; while prf.wants_flush() { + prf.flush(&mut *vm).map_err(MpcTlsError::hs)?; vm.execute_all(ctx).await.map_err(MpcTlsError::hs)?; - prf.flush().map_err(MpcTlsError::hs)?; } let sf_vd = sf_vd @@ -587,8 +587,8 @@ impl Backend for MpcTlsLeader { prf.set_cf_hash(hash).map_err(MpcTlsError::hs)?; while prf.wants_flush() { + prf.flush(&mut *vm).map_err(MpcTlsError::hs)?; vm.execute_all(ctx).await.map_err(MpcTlsError::hs)?; - prf.flush().map_err(MpcTlsError::hs)?; } let cf_vd = cf_vd @@ -653,11 +653,11 @@ impl Backend for MpcTlsLeader { ke.assign(&mut (*vm_lock)).map_err(MpcTlsError::hs)?; while prf.wants_flush() { + prf.flush(&mut *vm_lock).map_err(MpcTlsError::hs)?; vm_lock .execute_all(&mut ctx) .await .map_err(MpcTlsError::hs)?; - prf.flush().map_err(MpcTlsError::hs)?; } ke.finalize().await.map_err(MpcTlsError::hs)?; From 21b4a5465486323048812d7edc8d22dc24be4aa0 Mon Sep 17 00:00:00 2001 From: th4s Date: Fri, 25 Apr 2025 15:05:04 +0200 Subject: [PATCH 10/36] rename leftover `config` -> `mode` --- crates/components/hmac-sha256/src/lib.rs | 14 +++++++------- .../components/hmac-sha256/src/prf/function/mod.rs | 14 +++++++------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/crates/components/hmac-sha256/src/lib.rs b/crates/components/hmac-sha256/src/lib.rs index 105bd5f60f..050162f499 100644 --- a/crates/components/hmac-sha256/src/lib.rs +++ b/crates/components/hmac-sha256/src/lib.rs @@ -68,17 +68,17 @@ mod tests { #[tokio::test] async fn test_prf_reduced() { - let config = Mode::Reduced; - test_prf(config).await; + let mode = Mode::Reduced; + test_prf(mode).await; } #[tokio::test] async fn test_prf_normal() { - let config = Mode::Normal; - test_prf(config).await; + let mode = Mode::Normal; + test_prf(mode).await; } - async fn test_prf(config: Mode) { + async fn test_prf(mode: Mode) { let mut rng = StdRng::seed_from_u64(1); // Test input let pms: [u8; 32] = rng.random(); @@ -119,8 +119,8 @@ mod tests { follower.assign(follower_pms, pms).unwrap(); follower.commit(follower_pms).unwrap(); - let mut prf_leader = MpcPrf::new(config); - let mut prf_follower = MpcPrf::new(config); + let mut prf_leader = MpcPrf::new(mode); + let mut prf_follower = MpcPrf::new(mode); let leader_prf_out = prf_leader.alloc(&mut leader, leader_pms).unwrap(); let follower_prf_out = prf_follower.alloc(&mut follower, follower_pms).unwrap(); diff --git a/crates/components/hmac-sha256/src/prf/function/mod.rs b/crates/components/hmac-sha256/src/prf/function/mod.rs index 49d72b7c1e..049053de87 100644 --- a/crates/components/hmac-sha256/src/prf/function/mod.rs +++ b/crates/components/hmac-sha256/src/prf/function/mod.rs @@ -151,17 +151,17 @@ mod tests { #[tokio::test] async fn test_phash_reduced() { - let config = Mode::Reduced; - test_phash(config).await; + let mode = Mode::Reduced; + test_phash(mode).await; } #[tokio::test] async fn test_phash_normal() { - let config = Mode::Normal; - test_phash(config).await; + let mode = Mode::Normal; + test_phash(mode).await; } - async fn test_phash(config: Mode) { + async fn test_phash(mode: Mode) { let (mut ctx_a, mut ctx_b) = test_st_context(8); let (mut leader, mut follower) = mock_vm(); @@ -181,7 +181,7 @@ mod tests { let inner_partial_leader = compute_partial(&mut leader, leader_key.into(), IPAD).unwrap(); let mut prf_leader = Prf::alloc_master_secret( - config, + mode, &mut leader, outer_partial_leader, inner_partial_leader, @@ -206,7 +206,7 @@ mod tests { compute_partial(&mut follower, follower_key.into(), IPAD).unwrap(); let mut prf_follower = Prf::alloc_master_secret( - config, + mode, &mut follower, outer_partial_follower, inner_partial_follower, From e6fb9b633fa33a079428aa48b98daf9645db4b71 Mon Sep 17 00:00:00 2001 From: th4s Date: Fri, 25 Apr 2025 16:42:02 +0200 Subject: [PATCH 11/36] remove unnecessary pub(crate) --- .../hmac-sha256/src/prf/function/normal.rs | 4 +- .../hmac-sha256/src/prf/function/reduced.rs | 162 +++++------------- 2 files changed, 48 insertions(+), 118 deletions(-) diff --git a/crates/components/hmac-sha256/src/prf/function/normal.rs b/crates/components/hmac-sha256/src/prf/function/normal.rs index dcf517c5a0..3e8aa24fdf 100644 --- a/crates/components/hmac-sha256/src/prf/function/normal.rs +++ b/crates/components/hmac-sha256/src/prf/function/normal.rs @@ -141,8 +141,8 @@ enum State { #[derive(Debug, Clone)] struct PHash { - pub(crate) msg: Vector, - pub(crate) output: Array, + msg: Vector, + output: Array, } impl PHash { diff --git a/crates/components/hmac-sha256/src/prf/function/reduced.rs b/crates/components/hmac-sha256/src/prf/function/reduced.rs index 426d40aca8..83298ca668 100644 --- a/crates/components/hmac-sha256/src/prf/function/reduced.rs +++ b/crates/components/hmac-sha256/src/prf/function/reduced.rs @@ -1,10 +1,11 @@ //! Computes some hashes of the PRF locally. use crate::{convert_to_bytes, hmac::HmacSha256, sha256::sha256, PrfError}; +use mpz_core::bitvec::BitVec; use mpz_vm_core::{ memory::{ binary::{Binary, U32, U8}, - Array, DecodeFutureTyped, MemoryExt, MemoryType, Repr, ViewExt, + Array, DecodeFutureTyped, MemoryExt, ViewExt, }, Vm, }; @@ -12,6 +13,7 @@ use mpz_vm_core::{ #[derive(Debug)] pub(crate) struct PrfFunction { label: &'static [u8], + state: State, start_seed_label: Option>, a: Vec, p: Vec, @@ -55,28 +57,24 @@ impl PrfFunction { Self::alloc(vm, Self::SF_LABEL, outer_partial, inner_partial, 12) } - pub(crate) fn wants_flush(&self) -> bool { - todo!() - } - - pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { - todo!() - } + pub(crate) fn wants_flush(&mut self) -> bool { + let wants_flush = if let State::Finished = self.state { + false + } else { + true + }; - pub(crate) fn make_progress(&mut self, vm: &mut dyn Vm) -> Result { - let a_assigned = self.is_a_assigned(); - let mut p_assigned = self.is_p_assigned(); + // Drive state - if !a_assigned { - self.poll_a(vm)?; - } + wants_flush + } - if !p_assigned { - self.poll_p(vm)?; - p_assigned = self.is_p_assigned(); + pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { + if let State::Computing = self.state { + todo!() } - Ok(p_assigned) + Ok(()) } pub(crate) fn set_start_seed(&mut self, seed: Vec) { @@ -87,7 +85,7 @@ impl PrfFunction { } pub(crate) fn output(&self) -> Vec> { - self.p.iter().map(|p| p.output.value()).collect() + self.p.iter().map(|p| p.output).collect() } fn poll_a(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { @@ -96,7 +94,7 @@ impl PrfFunction { }; for a in self.a.iter_mut() { - if let Some(output) = a.output.poll(vm)? { + if let Some(output) = a.output_decoded.try_recv().map_err(PrfError::vm)? { message = convert_to_bytes(output).to_vec(); continue; }; @@ -146,6 +144,7 @@ impl PrfFunction { len: usize, ) -> Result { let mut prf = Self { + state: State::Computing, label, start_seed_label: None, a: vec![], @@ -166,30 +165,21 @@ impl PrfFunction { Ok(prf) } +} - fn is_p_assigned(&self) -> bool { - self.p - .last() - .expect("prf should be allocated") - .inner_local - .1 - } - - fn is_a_assigned(&self) -> bool { - self.a - .last() - .expect("prf should be allocated") - .inner_local - .1 - } +#[derive(Debug, Clone, Copy)] +enum State { + Computing, + Finished, } #[derive(Debug)] struct PHash { - pub(crate) inner_partial: DecodeOperation>, - // the bool tracks if assignment has already happened - pub(crate) inner_local: (Array, bool), - pub(crate) output: DecodeOperation>, + state: InnerState, + inner_partial: Array, + inner_local: Array, + output: Array, + output_decoded: DecodeFutureTyped, } impl PHash { @@ -202,11 +192,14 @@ impl PHash { let hmac = HmacSha256::new(outer_partial, inner_local); let output = hmac.alloc(vm).map_err(PrfError::vm)?; + let output_decoded = vm.decode(output).map_err(PrfError::vm)?; let p_hash = Self { - inner_partial: DecodeOperation::new(inner_partial), - inner_local: (inner_local, false), - output: DecodeOperation::new(output), + state: InnerState::Init, + inner_partial, + inner_local, + output, + output_decoded, }; Ok(p_hash) @@ -218,85 +211,22 @@ impl PHash { inner_partial: [u32; 8], msg: &[u8], ) -> Result<(), PrfError> { - if !self.inner_local.1 { - let inner_local_ref: Array = self.inner_local.0; - let inner_local = sha256(inner_partial, 64, msg); + let inner_local_ref: Array = self.inner_local; + let inner_local = sha256(inner_partial, 64, msg); - vm.mark_public(inner_local_ref).map_err(PrfError::vm)?; - vm.assign(inner_local_ref, convert_to_bytes(inner_local)) - .map_err(PrfError::vm)?; - vm.commit(inner_local_ref).map_err(PrfError::vm)?; + vm.mark_public(inner_local_ref).map_err(PrfError::vm)?; + vm.assign(inner_local_ref, convert_to_bytes(inner_local)) + .map_err(PrfError::vm)?; + vm.commit(inner_local_ref).map_err(PrfError::vm)?; - self.inner_local.1 = true - } + self.state = InnerState::Assigned; Ok(()) } } -#[derive(Debug)] -struct DecodeOperation -where - T: Repr, -{ - value: T, - progress: DecodeProgress, -} - -impl DecodeOperation -where - T: Repr + Copy, -{ - pub(crate) fn new(value: T) -> Self { - Self { - value, - progress: DecodeProgress::Alloc, - } - } - - pub(crate) fn value(&self) -> T { - self.value - } - - pub(crate) fn poll(&mut self, vm: &mut dyn Vm) -> Result, PrfError> { - self.progress.poll(vm, self.value) - } -} - -#[derive(Debug)] -enum DecodeProgress -where - T: Repr, -{ - Alloc, - Decoded(DecodeFutureTyped<::Raw, T::Clear>), - Finished(T::Clear), -} - -impl DecodeProgress -where - T: Repr + Copy, -{ - pub(crate) fn poll( - &mut self, - vm: &mut dyn Vm, - value: T, - ) -> Result, PrfError> { - match self { - DecodeProgress::Alloc => { - let value = vm.decode(value).map_err(PrfError::vm)?; - *self = DecodeProgress::Decoded(value); - Ok(None) - } - DecodeProgress::Decoded(value) => { - if let Some(value) = value.try_recv().map_err(PrfError::vm)? { - *self = DecodeProgress::Finished(value); - Ok(Some(value)) - } else { - Ok(None) - } - } - DecodeProgress::Finished(value) => Ok(Some(*value)), - } - } +#[derive(Debug, Clone, Copy)] +enum InnerState { + Init, + Assigned, } From 6cae7f5c02eb357747a89f2d326721dac179d167 Mon Sep 17 00:00:00 2001 From: th4s Date: Mon, 28 Apr 2025 15:49:56 +0200 Subject: [PATCH 12/36] rewrite state flow for reduced prf --- .../hmac-sha256/src/prf/function/mod.rs | 2 +- .../hmac-sha256/src/prf/function/reduced.rs | 168 +++++++++--------- crates/components/hmac-sha256/src/prf/mod.rs | 2 +- 3 files changed, 84 insertions(+), 88 deletions(-) diff --git a/crates/components/hmac-sha256/src/prf/function/mod.rs b/crates/components/hmac-sha256/src/prf/function/mod.rs index 049053de87..72de5cb1e4 100644 --- a/crates/components/hmac-sha256/src/prf/function/mod.rs +++ b/crates/components/hmac-sha256/src/prf/function/mod.rs @@ -103,7 +103,7 @@ impl Prf { Ok(prf) } - pub(crate) fn wants_flush(&self) -> bool { + pub(crate) fn wants_flush(&mut self) -> bool { match self { Prf::Reduced(prf) => prf.wants_flush(), Prf::Normal(prf) => prf.wants_flush(), diff --git a/crates/components/hmac-sha256/src/prf/function/reduced.rs b/crates/components/hmac-sha256/src/prf/function/reduced.rs index 83298ca668..5fb1ae7b11 100644 --- a/crates/components/hmac-sha256/src/prf/function/reduced.rs +++ b/crates/components/hmac-sha256/src/prf/function/reduced.rs @@ -13,7 +13,6 @@ use mpz_vm_core::{ #[derive(Debug)] pub(crate) struct PrfFunction { label: &'static [u8], - state: State, start_seed_label: Option>, a: Vec, p: Vec, @@ -58,20 +57,65 @@ impl PrfFunction { } pub(crate) fn wants_flush(&mut self) -> bool { - let wants_flush = if let State::Finished = self.state { - false - } else { - true - }; - - // Drive state + let last_p = self.p.last().expect("Prf should be allocated"); - wants_flush + if let State::Finished { .. } = last_p.state { + return false; + } + true } pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { - if let State::Computing = self.state { - todo!() + let mut message = self.start_seed_label.clone(); + + for (a, p) in self.a.iter_mut().zip(self.p.iter_mut()) { + match &mut a.state { + State::Init { inner_partial, .. } => { + if let (Some(msg), Some(inner_partial)) = ( + message.as_ref(), + inner_partial.try_recv().map_err(PrfError::vm)?, + ) { + a.assign_inner_local(vm, inner_partial, msg)?; + message = None; + } + } + State::Assigned { output } => { + if let Some(output) = output.try_recv().map_err(PrfError::vm)? { + let output = convert_to_bytes(output).to_vec(); + a.state = State::Finished { + output: output.clone(), + }; + message = Some(output); + } + } + State::Finished { output } => { + message = Some(output.clone()); + } + } + + match &mut p.state { + State::Init { inner_partial, .. } => { + if let (State::Finished { output }, Some(inner_partial)) = + (&a.state, inner_partial.try_recv().map_err(PrfError::vm)?) + { + let mut msg = output.to_vec(); + msg.extend_from_slice( + self.start_seed_label + .as_ref() + .expect("Start seed for PRF should be set"), + ); + + p.assign_inner_local(vm, inner_partial, &msg)?; + } + } + State::Assigned { output } => { + if let Some(output) = output.try_recv().map_err(PrfError::vm)? { + let output = convert_to_bytes(output).to_vec(); + a.state = State::Finished { output }; + } + } + _ => (), + } } Ok(()) @@ -88,54 +132,6 @@ impl PrfFunction { self.p.iter().map(|p| p.output).collect() } - fn poll_a(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { - let Some(mut message) = self.start_seed_label.clone() else { - return Err(PrfError::state("Starting seed not set for PRF")); - }; - - for a in self.a.iter_mut() { - if let Some(output) = a.output_decoded.try_recv().map_err(PrfError::vm)? { - message = convert_to_bytes(output).to_vec(); - continue; - }; - - let Some(inner_partial) = a.inner_partial.poll(vm)? else { - break; - }; - - a.assign_inner_local(vm, inner_partial, &message)?; - } - - Ok(()) - } - - fn poll_p(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { - let Some(ref start_seed) = self.start_seed_label else { - return Err(PrfError::state("Starting seed not set for PRF")); - }; - - for (i, p) in self.p.iter_mut().enumerate() { - if p.inner_local.1 { - continue; - } - - let Some(message) = self.a[i].output.poll(vm)? else { - break; - }; - - let mut message = convert_to_bytes(message).to_vec(); - message.extend_from_slice(start_seed); - - let Some(inner_partial) = p.inner_partial.poll(vm)? else { - break; - }; - - p.assign_inner_local(vm, inner_partial, &message)?; - } - - Ok(()) - } - fn alloc( vm: &mut dyn Vm, label: &'static [u8], @@ -144,7 +140,6 @@ impl PrfFunction { len: usize, ) -> Result { let mut prf = Self { - state: State::Computing, label, start_seed_label: None, a: vec![], @@ -167,19 +162,10 @@ impl PrfFunction { } } -#[derive(Debug, Clone, Copy)] -enum State { - Computing, - Finished, -} - #[derive(Debug)] struct PHash { - state: InnerState, - inner_partial: Array, - inner_local: Array, output: Array, - output_decoded: DecodeFutureTyped, + state: State, } impl PHash { @@ -192,14 +178,14 @@ impl PHash { let hmac = HmacSha256::new(outer_partial, inner_local); let output = hmac.alloc(vm).map_err(PrfError::vm)?; - let output_decoded = vm.decode(output).map_err(PrfError::vm)?; + let inner_partial = vm.decode(inner_partial).map_err(PrfError::vm)?; let p_hash = Self { - state: InnerState::Init, - inner_partial, - inner_local, + state: State::Init { + inner_partial, + inner_local, + }, output, - output_decoded, }; Ok(p_hash) @@ -211,22 +197,32 @@ impl PHash { inner_partial: [u32; 8], msg: &[u8], ) -> Result<(), PrfError> { - let inner_local_ref: Array = self.inner_local; - let inner_local = sha256(inner_partial, 64, msg); + if let State::Init { inner_local, .. } = self.state { + let inner_local_value = sha256(inner_partial, 64, msg); - vm.mark_public(inner_local_ref).map_err(PrfError::vm)?; - vm.assign(inner_local_ref, convert_to_bytes(inner_local)) - .map_err(PrfError::vm)?; - vm.commit(inner_local_ref).map_err(PrfError::vm)?; + vm.mark_public(inner_local).map_err(PrfError::vm)?; + vm.assign(inner_local, convert_to_bytes(inner_local_value)) + .map_err(PrfError::vm)?; + vm.commit(inner_local).map_err(PrfError::vm)?; - self.state = InnerState::Assigned; + let output = vm.decode(self.output).map_err(PrfError::vm)?; + self.state = State::Assigned { output }; + } Ok(()) } } -#[derive(Debug, Clone, Copy)] -enum InnerState { - Init, - Assigned, +#[derive(Debug)] +enum State { + Init { + inner_partial: DecodeFutureTyped, + inner_local: Array, + }, + Assigned { + output: DecodeFutureTyped, + }, + Finished { + output: Vec, + }, } diff --git a/crates/components/hmac-sha256/src/prf/mod.rs b/crates/components/hmac-sha256/src/prf/mod.rs index 0b7ad31cad..ce0b93823c 100644 --- a/crates/components/hmac-sha256/src/prf/mod.rs +++ b/crates/components/hmac-sha256/src/prf/mod.rs @@ -168,7 +168,7 @@ impl MpcPrf { /// Returns if the PRF needs to be flushed and drives the PRF. pub fn wants_flush(&mut self) -> bool { - let wants_flush = match &self.state { + let wants_flush = match &mut self.state { State::Initialized => false, State::SessionKeys { master_secret, From fec1e124e8b691b4ccfdfe539c002009220b0d85 Mon Sep 17 00:00:00 2001 From: th4s Date: Mon, 28 Apr 2025 21:16:02 +0200 Subject: [PATCH 13/36] improve state transition for reduced prf --- .../hmac-sha256/src/prf/function/normal.rs | 15 ++- .../hmac-sha256/src/prf/function/reduced.rs | 95 +++++++++++-------- 2 files changed, 60 insertions(+), 50 deletions(-) diff --git a/crates/components/hmac-sha256/src/prf/function/normal.rs b/crates/components/hmac-sha256/src/prf/function/normal.rs index 3e8aa24fdf..b37e65699e 100644 --- a/crates/components/hmac-sha256/src/prf/function/normal.rs +++ b/crates/components/hmac-sha256/src/prf/function/normal.rs @@ -13,9 +13,11 @@ use std::sync::Arc; #[derive(Debug)] pub(crate) struct PrfFunction { + // The label, e.g. "master secret". label: &'static [u8], state: State, - start_seed_label: Option>, + // The start seed and the label, e.g. client_random + server_random + "master_secret". + start_seed_label: Vec, a: Vec, p: Vec, } @@ -67,13 +69,10 @@ impl PrfFunction { pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { if let State::Computing = self.state { - let a = self.a.first_mut().expect("prf should be allocated"); + let a = self.a.first().expect("prf should be allocated"); let msg = a.msg; - let msg_value = self - .start_seed_label - .clone() - .expect("seed should be assigned by now"); + let msg_value = self.start_seed_label.clone(); vm.mark_public(msg).map_err(PrfError::vm)?; vm.assign(msg, msg_value).map_err(PrfError::vm)?; @@ -88,7 +87,7 @@ impl PrfFunction { let mut start_seed_label = self.label.to_vec(); start_seed_label.extend_from_slice(&seed); - self.start_seed_label = Some(start_seed_label); + self.start_seed_label = start_seed_label; } pub(crate) fn output(&self) -> Vec> { @@ -106,7 +105,7 @@ impl PrfFunction { let mut prf = Self { label, state: State::Computing, - start_seed_label: None, + start_seed_label: vec![], a: vec![], p: vec![], }; diff --git a/crates/components/hmac-sha256/src/prf/function/reduced.rs b/crates/components/hmac-sha256/src/prf/function/reduced.rs index 5fb1ae7b11..3eb67bfee7 100644 --- a/crates/components/hmac-sha256/src/prf/function/reduced.rs +++ b/crates/components/hmac-sha256/src/prf/function/reduced.rs @@ -12,8 +12,13 @@ use mpz_vm_core::{ #[derive(Debug)] pub(crate) struct PrfFunction { + // The label, e.g. "master secret". label: &'static [u8], - start_seed_label: Option>, + // The start seed and the label, e.g. client_random + server_random + "master_secret". + start_seed_label: Vec, + // The current HMAC message needed for a[i] + a_msg: Vec, + inner_partial: InnerPartial, a: Vec, p: Vec, } @@ -66,18 +71,16 @@ impl PrfFunction { } pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { - let mut message = self.start_seed_label.clone(); + let inner_partial = self.inner_partial.try_recv()?; + let Some(inner_partial) = inner_partial else { + return Ok(()); + }; for (a, p) in self.a.iter_mut().zip(self.p.iter_mut()) { match &mut a.state { - State::Init { inner_partial, .. } => { - if let (Some(msg), Some(inner_partial)) = ( - message.as_ref(), - inner_partial.try_recv().map_err(PrfError::vm)?, - ) { - a.assign_inner_local(vm, inner_partial, msg)?; - message = None; - } + State::Init { .. } => { + a.assign_inner_local(vm, inner_partial, &self.a_msg)?; + break; } State::Assigned { output } => { if let Some(output) = output.try_recv().map_err(PrfError::vm)? { @@ -85,33 +88,24 @@ impl PrfFunction { a.state = State::Finished { output: output.clone(), }; - message = Some(output); + self.a_msg = output; } } - State::Finished { output } => { - message = Some(output.clone()); - } + _ => (), } match &mut p.state { - State::Init { inner_partial, .. } => { - if let (State::Finished { output }, Some(inner_partial)) = - (&a.state, inner_partial.try_recv().map_err(PrfError::vm)?) - { - let mut msg = output.to_vec(); - msg.extend_from_slice( - self.start_seed_label - .as_ref() - .expect("Start seed for PRF should be set"), - ); - - p.assign_inner_local(vm, inner_partial, &msg)?; + State::Init { .. } => { + if let State::Finished { output } = &a.state { + let mut p_msg = output.to_vec(); + p_msg.extend_from_slice(&self.start_seed_label); + p.assign_inner_local(vm, inner_partial, &p_msg)?; } } State::Assigned { output } => { if let Some(output) = output.try_recv().map_err(PrfError::vm)? { let output = convert_to_bytes(output).to_vec(); - a.state = State::Finished { output }; + p.state = State::Finished { output }; } } _ => (), @@ -125,7 +119,8 @@ impl PrfFunction { let mut start_seed_label = self.label.to_vec(); start_seed_label.extend_from_slice(&seed); - self.start_seed_label = Some(start_seed_label); + self.start_seed_label = start_seed_label.clone(); + self.a_msg = start_seed_label; } pub(crate) fn output(&self) -> Vec> { @@ -139,9 +134,13 @@ impl PrfFunction { inner_partial: Array, len: usize, ) -> Result { + let inner_partial = vm.decode(inner_partial).map_err(PrfError::vm)?; + let mut prf = Self { label, - start_seed_label: None, + start_seed_label: vec![], + a_msg: vec![], + inner_partial: InnerPartial::Decoding(inner_partial), a: vec![], p: vec![], }; @@ -151,10 +150,10 @@ impl PrfFunction { let iterations = len / 32 + ((len % 32) != 0) as usize; for _ in 0..iterations { - let a = PHash::alloc(vm, outer_partial, inner_partial)?; + let a = PHash::alloc(vm, outer_partial)?; prf.a.push(a); - let p = PHash::alloc(vm, outer_partial, inner_partial)?; + let p = PHash::alloc(vm, outer_partial)?; prf.p.push(p); } @@ -169,22 +168,14 @@ struct PHash { } impl PHash { - fn alloc( - vm: &mut dyn Vm, - outer_partial: Array, - inner_partial: Array, - ) -> Result { + fn alloc(vm: &mut dyn Vm, outer_partial: Array) -> Result { let inner_local = vm.alloc().map_err(PrfError::vm)?; let hmac = HmacSha256::new(outer_partial, inner_local); let output = hmac.alloc(vm).map_err(PrfError::vm)?; - let inner_partial = vm.decode(inner_partial).map_err(PrfError::vm)?; let p_hash = Self { - state: State::Init { - inner_partial, - inner_local, - }, + state: State::Init { inner_local }, output, }; @@ -216,7 +207,6 @@ impl PHash { #[derive(Debug)] enum State { Init { - inner_partial: DecodeFutureTyped, inner_local: Array, }, Assigned { @@ -226,3 +216,24 @@ enum State { output: Vec, }, } + +#[derive(Debug)] +enum InnerPartial { + Decoding(DecodeFutureTyped), + Finished([u32; 8]), +} + +impl InnerPartial { + pub(crate) fn try_recv(&mut self) -> Result, PrfError> { + match self { + InnerPartial::Decoding(value) => { + let value = value.try_recv().map_err(PrfError::vm)?; + if let Some(value) = value { + *self = InnerPartial::Finished(value); + } + Ok(value) + } + InnerPartial::Finished(value) => Ok(Some(*value)), + } + } +} From 32eff72a00db569a373e632608d50bd6a8619283 Mon Sep 17 00:00:00 2001 From: th4s Date: Mon, 28 Apr 2025 21:39:38 +0200 Subject: [PATCH 14/36] repair prf bench --- crates/components/hmac-sha256/benches/prf.rs | 76 ++++++++++---------- 1 file changed, 37 insertions(+), 39 deletions(-) diff --git a/crates/components/hmac-sha256/benches/prf.rs b/crates/components/hmac-sha256/benches/prf.rs index 92b48cbb4a..926f503718 100644 --- a/crates/components/hmac-sha256/benches/prf.rs +++ b/crates/components/hmac-sha256/benches/prf.rs @@ -82,63 +82,61 @@ async fn prf() { let _ = follower_vm.decode(follower_output.keys.client_iv).unwrap(); let _ = follower_vm.decode(follower_output.keys.server_iv).unwrap(); - loop { - let leader_finished = leader.drive_key_expansion(&mut leader_vm).unwrap(); - let follower_finished = follower.drive_key_expansion(&mut follower_vm).unwrap(); - + while leader.wants_flush() || follower.wants_flush() { tokio::try_join!( - leader_vm.execute_all(&mut leader_ctx), - follower_vm.execute_all(&mut follower_ctx) + async { + leader.flush(&mut leader_vm).unwrap(); + leader_vm.execute_all(&mut leader_ctx).await + }, + async { + follower.flush(&mut follower_vm).unwrap(); + follower_vm.execute_all(&mut follower_ctx).await + } ) .unwrap(); - - if leader_finished && follower_finished { - break; - } } let cf_hs_hash = [1u8; 32]; - let sf_hs_hash = [2u8; 32]; leader.set_cf_hash(cf_hs_hash).unwrap(); - leader.set_sf_hash(sf_hs_hash).unwrap(); - follower.set_cf_hash(cf_hs_hash).unwrap(); - follower.set_sf_hash(sf_hs_hash).unwrap(); - - let _ = leader_vm.decode(leader_output.cf_vd).unwrap(); - let _ = leader_vm.decode(leader_output.sf_vd).unwrap(); - - let _ = follower_vm.decode(follower_output.cf_vd).unwrap(); - let _ = follower_vm.decode(follower_output.sf_vd).unwrap(); - - loop { - let leader_finished = leader.drive_client_finished(&mut leader_vm).unwrap(); - let follower_finished = follower.drive_client_finished(&mut follower_vm).unwrap(); + while leader.wants_flush() || follower.wants_flush() { tokio::try_join!( - leader_vm.execute_all(&mut leader_ctx), - follower_vm.execute_all(&mut follower_ctx) + async { + leader.flush(&mut leader_vm).unwrap(); + leader_vm.execute_all(&mut leader_ctx).await + }, + async { + follower.flush(&mut follower_vm).unwrap(); + follower_vm.execute_all(&mut follower_ctx).await + } ) .unwrap(); - - if leader_finished && follower_finished { - break; - } } - loop { - let leader_finished = leader.drive_server_finished(&mut leader_vm).unwrap(); - let follower_finished = follower.drive_server_finished(&mut follower_vm).unwrap(); + let _ = leader_vm.decode(leader_output.cf_vd).unwrap(); + let _ = follower_vm.decode(follower_output.cf_vd).unwrap(); + + let sf_hs_hash = [2u8; 32]; + + leader.set_sf_hash(sf_hs_hash).unwrap(); + follower.set_sf_hash(sf_hs_hash).unwrap(); + while leader.wants_flush() || follower.wants_flush() { tokio::try_join!( - leader_vm.execute_all(&mut leader_ctx), - follower_vm.execute_all(&mut follower_ctx) + async { + leader.flush(&mut leader_vm).unwrap(); + leader_vm.execute_all(&mut leader_ctx).await + }, + async { + follower.flush(&mut follower_vm).unwrap(); + follower_vm.execute_all(&mut follower_ctx).await + } ) .unwrap(); - - if leader_finished && follower_finished { - break; - } } + + let _ = leader_vm.decode(leader_output.sf_vd).unwrap(); + let _ = follower_vm.decode(follower_output.sf_vd).unwrap(); } From 07f82368a200224e14cbebc1984af8447be11f07 Mon Sep 17 00:00:00 2001 From: th4s Date: Tue, 29 Apr 2025 13:49:06 +0200 Subject: [PATCH 15/36] WIP: Adapting to new `Sha256` from mpz --- Cargo.toml | 25 +- crates/components/hmac-sha256/Cargo.toml | 1 + crates/components/hmac-sha256/src/error.rs | 14 +- crates/components/hmac-sha256/src/hmac.rs | 56 +-- crates/components/hmac-sha256/src/lib.rs | 29 +- .../hmac-sha256/src/prf/function/mod.rs | 26 +- .../hmac-sha256/src/prf/function/normal.rs | 48 ++- .../hmac-sha256/src/prf/function/reduced.rs | 55 +-- crates/components/hmac-sha256/src/prf/mod.rs | 64 +-- .../components/hmac-sha256/src/prf/state.rs | 8 +- crates/components/hmac-sha256/src/sha256.rs | 381 ------------------ .../components/hmac-sha256/src/test_utils.rs | 6 +- 12 files changed, 192 insertions(+), 521 deletions(-) delete mode 100644 crates/components/hmac-sha256/src/sha256.rs diff --git a/Cargo.toml b/Cargo.toml index 1a1d200454..2801b71f6a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,18 +69,19 @@ tlsn-tls-core = { path = "crates/tls/core" } tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" } tlsn-verifier = { path = "crates/verifier" } -mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-memory-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-vm-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } -mpz-zk = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "alpha.3" } +mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } +mpz-memory-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } +mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } +mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } +mpz-vm-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } +mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } +mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } +mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } +mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } +mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } +mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } +mpz-zk = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } +mpz-hash = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } rangeset = { version = "0.2" } serio = { version = "0.2" } diff --git a/crates/components/hmac-sha256/Cargo.toml b/crates/components/hmac-sha256/Cargo.toml index 5b78a71df7..fa1fd4d7dd 100644 --- a/crates/components/hmac-sha256/Cargo.toml +++ b/crates/components/hmac-sha256/Cargo.toml @@ -18,6 +18,7 @@ name = "hmac_sha256" mpz-vm-core = { workspace = true } mpz-core = { workspace = true } mpz-circuits = { workspace = true } +mpz-hash = { workspace = true } thiserror = { workspace = true } tracing = { workspace = true } diff --git a/crates/components/hmac-sha256/src/error.rs b/crates/components/hmac-sha256/src/error.rs index bad3acf5f6..a7cf225edf 100644 --- a/crates/components/hmac-sha256/src/error.rs +++ b/crates/components/hmac-sha256/src/error.rs @@ -1,6 +1,8 @@ use core::fmt; use std::error::Error; +use mpz_hash::sha256::Sha256Error; + /// A PRF error. #[derive(Debug, thiserror::Error)] pub struct PrfError { @@ -20,15 +22,21 @@ impl PrfError { } } + pub(crate) fn vm>>(err: E) -> Self { + Self::new(ErrorKind::Vm, err) + } + pub(crate) fn state(msg: impl Into) -> Self { Self { kind: ErrorKind::State, source: Some(msg.into().into()), } } +} - pub(crate) fn vm>>(err: E) -> Self { - Self::new(ErrorKind::Vm, err) +impl From for PrfError { + fn from(value: Sha256Error) -> Self { + Self::new(ErrorKind::Hash, value) } } @@ -36,6 +44,7 @@ impl PrfError { pub(crate) enum ErrorKind { Vm, State, + Hash, } impl fmt::Display for PrfError { @@ -43,6 +52,7 @@ impl fmt::Display for PrfError { match self.kind { ErrorKind::Vm => write!(f, "vm error")?, ErrorKind::State => write!(f, "state error")?, + ErrorKind::Hash => write!(f, "hash error")?, } if let Some(ref source) = self.source { diff --git a/crates/components/hmac-sha256/src/hmac.rs b/crates/components/hmac-sha256/src/hmac.rs index beb48c9aae..a407045238 100644 --- a/crates/components/hmac-sha256/src/hmac.rs +++ b/crates/components/hmac-sha256/src/hmac.rs @@ -18,19 +18,21 @@ //! * `outer_partial` - key' xor opad //! * `inner_local` - H((key' xor ipad) || m) -use crate::{sha256::Sha256, PrfError}; +use mpz_hash::sha256::Sha256; use mpz_vm_core::{ memory::{ - binary::{Binary, U32, U8}, + binary::{Binary, U8}, Array, }, Vm, }; +use crate::PrfError; + /// Computes HMAC-SHA256. #[derive(Debug)] pub(crate) struct HmacSha256 { - outer_partial: Array, + outer_partial: Sha256, inner_local: Array, } @@ -44,7 +46,7 @@ impl HmacSha256 { /// /// * `outer_partial` - (key' xor opad) /// * `inner_local` - H((key' xor ipad) || m) - pub(crate) fn new(outer_partial: Array, inner_local: Array) -> Self { + pub(crate) fn new(outer_partial: Sha256, inner_local: Array) -> Self { Self { outer_partial, inner_local, @@ -56,28 +58,22 @@ impl HmacSha256 { /// # Arguments /// /// * `vm` - The virtual machine. - pub(crate) fn alloc(self, vm: &mut dyn Vm) -> Result, PrfError> { - let inner_local = self.inner_local.into(); - - let mut outer = Sha256::default(); - outer - .set_state(self.outer_partial, 64) - .update(inner_local) - .add_padding(vm)?; - - outer.alloc(vm) + pub(crate) fn alloc(mut self, vm: &mut dyn Vm) -> Result, PrfError> { + self.outer_partial.update(&self.inner_local); + self.outer_partial.compress(vm)?; + self.outer_partial.finalize(vm).map_err(PrfError::from) } } #[cfg(test)] mod tests { use crate::{ - convert_to_bytes, hmac::HmacSha256, - sha256::sha256, + sha256, state_to_bytes, test_utils::{compute_inner_local, compute_outer_partial, mock_vm}, }; use mpz_common::context::test_st_context; + use mpz_hash::sha256::Sha256; use mpz_vm_core::{ memory::{ binary::{U32, U8}, @@ -94,9 +90,9 @@ mod tests { let outer_partial = compute_outer_partial(input.0.clone()); let inner_local = compute_inner_local(input.0.clone(), &input.1); - let hmac = sha256(outer_partial, 64, &convert_to_bytes(inner_local)); + let hmac = sha256(outer_partial, 64, &state_to_bytes(inner_local)); - assert_eq!(convert_to_bytes(hmac), reference); + assert_eq!(state_to_bytes(hmac), reference); } } @@ -118,13 +114,16 @@ mod tests { let inner_local_leader: Array = leader.alloc().unwrap(); leader.mark_public(inner_local_leader).unwrap(); leader - .assign(inner_local_leader, convert_to_bytes(inner_local)) + .assign(inner_local_leader, state_to_bytes(inner_local)) .unwrap(); leader.commit(inner_local_leader).unwrap(); - let hmac_leader = HmacSha256::new(outer_partial_leader, inner_local_leader) - .alloc(&mut leader) - .unwrap(); + let hmac_leader = HmacSha256::new( + Sha256::new_from_state(outer_partial_leader, 64), + inner_local_leader, + ) + .alloc(&mut leader) + .unwrap(); let hmac_leader = leader.decode(hmac_leader).unwrap(); let outer_partial_follower: Array = follower.alloc().unwrap(); @@ -137,13 +136,16 @@ mod tests { let inner_local_follower: Array = follower.alloc().unwrap(); follower.mark_public(inner_local_follower).unwrap(); follower - .assign(inner_local_follower, convert_to_bytes(inner_local)) + .assign(inner_local_follower, state_to_bytes(inner_local)) .unwrap(); follower.commit(inner_local_follower).unwrap(); - let hmac_follower = HmacSha256::new(outer_partial_follower, inner_local_follower) - .alloc(&mut follower) - .unwrap(); + let hmac_follower = HmacSha256::new( + Sha256::new_from_state(outer_partial_follower, 64), + inner_local_follower, + ) + .alloc(&mut follower) + .unwrap(); let hmac_follower = follower.decode(hmac_follower).unwrap(); let (hmac_leader, hmac_follower) = tokio::try_join!( @@ -159,7 +161,7 @@ mod tests { .unwrap(); assert_eq!(hmac_leader, hmac_follower); - assert_eq!(convert_to_bytes(hmac_leader), reference); + assert_eq!(hmac_leader, reference); } } diff --git a/crates/components/hmac-sha256/src/lib.rs b/crates/components/hmac-sha256/src/lib.rs index 050162f499..7f76543b38 100644 --- a/crates/components/hmac-sha256/src/lib.rs +++ b/crates/components/hmac-sha256/src/lib.rs @@ -5,7 +5,6 @@ #![forbid(unsafe_code)] mod hmac; -mod sha256; #[cfg(test)] mod test_utils; @@ -44,7 +43,24 @@ pub struct SessionKeys { pub server_iv: Array, } -fn convert_to_bytes(input: [u32; 8]) -> [u8; 32] { +fn sha256(mut state: [u32; 8], pos: usize, msg: &[u8]) -> [u32; 8] { + use sha2::{ + compress256, + digest::{ + block_buffer::{BlockBuffer, Eager}, + generic_array::typenum::U64, + }, + }; + + let mut buffer = BlockBuffer::::default(); + buffer.digest_blocks(msg, |b| compress256(&mut state, b)); + buffer.digest_pad(0x80, &(((msg.len() + pos) * 8) as u64).to_be_bytes(), |b| { + compress256(&mut state, &[*b]) + }); + state +} + +fn state_to_bytes(input: [u32; 8]) -> [u8; 32] { let mut output = [0_u8; 32]; for (k, byte_chunk) in input.iter().enumerate() { let byte_chunk = byte_chunk.to_be_bytes(); @@ -53,6 +69,15 @@ fn convert_to_bytes(input: [u32; 8]) -> [u8; 32] { output } +fn bytes_to_state(byte_input: [u8; 32]) -> [u32; 8] { + let mut output = [0_u32; 8]; + for (k, bytes) in byte_input.chunks_exact(4).rev().enumerate() { + let value = u32::from_le_bytes(bytes.try_into().unwrap()); + output[k] = value; + } + output +} + #[cfg(test)] mod tests { use crate::{ diff --git a/crates/components/hmac-sha256/src/prf/function/mod.rs b/crates/components/hmac-sha256/src/prf/function/mod.rs index 72de5cb1e4..9738b2a52c 100644 --- a/crates/components/hmac-sha256/src/prf/function/mod.rs +++ b/crates/components/hmac-sha256/src/prf/function/mod.rs @@ -1,9 +1,10 @@ //! Provides [`Prf`], for computing the TLS 1.2 PRF. use crate::{Mode, PrfError}; +use mpz_hash::sha256::Sha256; use mpz_vm_core::{ memory::{ - binary::{Binary, U32}, + binary::{Binary, U8}, Array, }, Vm, @@ -22,8 +23,8 @@ impl Prf { pub(crate) fn alloc_master_secret( mode: Mode, vm: &mut dyn Vm, - outer_partial: Array, - inner_partial: Array, + outer_partial: Sha256, + inner_partial: Sha256, ) -> Result { let prf = match mode { Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_master_secret( @@ -43,8 +44,8 @@ impl Prf { pub(crate) fn alloc_key_expansion( mode: Mode, vm: &mut dyn Vm, - outer_partial: Array, - inner_partial: Array, + outer_partial: Sha256, + inner_partial: Sha256, ) -> Result { let prf = match mode { Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_key_expansion( @@ -64,8 +65,8 @@ impl Prf { pub(crate) fn alloc_client_finished( config: Mode, vm: &mut dyn Vm, - outer_partial: Array, - inner_partial: Array, + outer_partial: Sha256, + inner_partial: Sha256, ) -> Result { let prf = match config { Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_client_finished( @@ -85,8 +86,8 @@ impl Prf { pub(crate) fn alloc_server_finished( config: Mode, vm: &mut dyn Vm, - outer_partial: Array, - inner_partial: Array, + outer_partial: Sha256, + inner_partial: Sha256, ) -> Result { let prf = match config { Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_server_finished( @@ -124,7 +125,7 @@ impl Prf { } } - pub(crate) fn output(&self) -> Vec> { + pub(crate) fn output(&self) -> Vec> { match self { Prf::Reduced(prf) => prf.output(), Prf::Normal(prf) => prf.output(), @@ -135,7 +136,6 @@ impl Prf { #[cfg(test)] mod tests { use crate::{ - convert_to_bytes, prf::{compute_partial, function::Prf}, test_utils::{mock_vm, phash}, Mode, @@ -238,11 +238,11 @@ mod tests { let prf_result_leader: Vec = prf_out_leader .iter_mut() - .flat_map(|p| convert_to_bytes(p.try_recv().unwrap().unwrap())) + .flat_map(|p| p.try_recv().unwrap().unwrap()) .collect(); let prf_result_follower: Vec = prf_out_follower .iter_mut() - .flat_map(|p| convert_to_bytes(p.try_recv().unwrap().unwrap())) + .flat_map(|p| p.try_recv().unwrap().unwrap()) .collect(); let expected = phash(key.to_vec(), &label_seed, iterations); diff --git a/crates/components/hmac-sha256/src/prf/function/normal.rs b/crates/components/hmac-sha256/src/prf/function/normal.rs index b37e65699e..6706977ae0 100644 --- a/crates/components/hmac-sha256/src/prf/function/normal.rs +++ b/crates/components/hmac-sha256/src/prf/function/normal.rs @@ -1,7 +1,8 @@ //! Computes the whole PRF in MPC. -use crate::{hmac::HmacSha256, sha256::Sha256, PrfError}; +use crate::{hmac::HmacSha256, PrfError}; use mpz_circuits::CircuitBuilder; +use mpz_hash::sha256::Sha256; use mpz_vm_core::{ memory::{ binary::{Binary, U32, U8}, @@ -30,32 +31,32 @@ impl PrfFunction { pub(crate) fn alloc_master_secret( vm: &mut dyn Vm, - outer_partial: Array, - inner_partial: Array, + outer_partial: Sha256, + inner_partial: Sha256, ) -> Result { Self::alloc(vm, Self::MS_LABEL, outer_partial, inner_partial, 48, 64) } pub(crate) fn alloc_key_expansion( vm: &mut dyn Vm, - outer_partial: Array, - inner_partial: Array, + outer_partial: Sha256, + inner_partial: Sha256, ) -> Result { Self::alloc(vm, Self::KEY_LABEL, outer_partial, inner_partial, 40, 64) } pub(crate) fn alloc_client_finished( vm: &mut dyn Vm, - outer_partial: Array, - inner_partial: Array, + outer_partial: Sha256, + inner_partial: Sha256, ) -> Result { Self::alloc(vm, Self::CF_LABEL, outer_partial, inner_partial, 12, 32) } pub(crate) fn alloc_server_finished( vm: &mut dyn Vm, - outer_partial: Array, - inner_partial: Array, + outer_partial: Sha256, + inner_partial: Sha256, ) -> Result { Self::alloc(vm, Self::SF_LABEL, outer_partial, inner_partial, 12, 32) } @@ -90,15 +91,15 @@ impl PrfFunction { self.start_seed_label = start_seed_label; } - pub(crate) fn output(&self) -> Vec> { + pub(crate) fn output(&self) -> Vec> { self.p.iter().map(|p| p.output).collect() } fn alloc( vm: &mut dyn Vm, label: &'static [u8], - outer_partial: Array, - inner_partial: Array, + outer_partial: Sha256, + inner_partial: Sha256, output_len: usize, seed_len: usize, ) -> Result { @@ -119,12 +120,12 @@ impl PrfFunction { let mut msg_a = seed_label_ref; for _ in 0..iterations { - let a = PHash::alloc(vm, outer_partial, inner_partial, msg_a)?; - msg_a = convert_array(vm, a.output)?.into(); + let a = PHash::alloc(vm, outer_partial.clone(), inner_partial.clone(), msg_a)?; + msg_a = Vector::::from(a.output); prf.a.push(a); let msg_p = merge_vecs(vm, vec![msg_a, seed_label_ref])?; - let p = PHash::alloc(vm, outer_partial, inner_partial, msg_p)?; + let p = PHash::alloc(vm, outer_partial.clone(), inner_partial.clone(), msg_p)?; prf.p.push(p); } @@ -141,23 +142,20 @@ enum State { #[derive(Debug, Clone)] struct PHash { msg: Vector, - output: Array, + output: Array, } impl PHash { fn alloc( vm: &mut dyn Vm, - outer_partial: Array, - inner_partial: Array, + outer_partial: Sha256, + inner_partial: Sha256, msg: Vector, ) -> Result { - let mut sha = Sha256::default(); - sha.set_state(inner_partial, 64) - .update(msg) - .add_padding(vm)?; - - let inner_local = sha.alloc(vm)?; - let inner_local = convert_array(vm, inner_local)?; + let mut inner_local = inner_partial; + inner_local.update(&msg); + inner_local.compress(vm)?; + let inner_local = inner_local.finalize(vm)?; let hmac = HmacSha256::new(outer_partial, inner_local); let output = hmac.alloc(vm).map_err(PrfError::vm)?; diff --git a/crates/components/hmac-sha256/src/prf/function/reduced.rs b/crates/components/hmac-sha256/src/prf/function/reduced.rs index 3eb67bfee7..3e4a0baa2a 100644 --- a/crates/components/hmac-sha256/src/prf/function/reduced.rs +++ b/crates/components/hmac-sha256/src/prf/function/reduced.rs @@ -1,10 +1,11 @@ //! Computes some hashes of the PRF locally. -use crate::{convert_to_bytes, hmac::HmacSha256, sha256::sha256, PrfError}; +use crate::{bytes_to_state, hmac::HmacSha256, sha256, state_to_bytes, PrfError}; use mpz_core::bitvec::BitVec; +use mpz_hash::sha256::Sha256; use mpz_vm_core::{ memory::{ - binary::{Binary, U32, U8}, + binary::{Binary, U8}, Array, DecodeFutureTyped, MemoryExt, ViewExt, }, Vm, @@ -31,32 +32,32 @@ impl PrfFunction { pub(crate) fn alloc_master_secret( vm: &mut dyn Vm, - outer_partial: Array, - inner_partial: Array, + outer_partial: Sha256, + inner_partial: Sha256, ) -> Result { Self::alloc(vm, Self::MS_LABEL, outer_partial, inner_partial, 48) } pub(crate) fn alloc_key_expansion( vm: &mut dyn Vm, - outer_partial: Array, - inner_partial: Array, + outer_partial: Sha256, + inner_partial: Sha256, ) -> Result { Self::alloc(vm, Self::KEY_LABEL, outer_partial, inner_partial, 40) } pub(crate) fn alloc_client_finished( vm: &mut dyn Vm, - outer_partial: Array, - inner_partial: Array, + outer_partial: Sha256, + inner_partial: Sha256, ) -> Result { Self::alloc(vm, Self::CF_LABEL, outer_partial, inner_partial, 12) } pub(crate) fn alloc_server_finished( vm: &mut dyn Vm, - outer_partial: Array, - inner_partial: Array, + outer_partial: Sha256, + inner_partial: Sha256, ) -> Result { Self::alloc(vm, Self::SF_LABEL, outer_partial, inner_partial, 12) } @@ -84,7 +85,7 @@ impl PrfFunction { } State::Assigned { output } => { if let Some(output) = output.try_recv().map_err(PrfError::vm)? { - let output = convert_to_bytes(output).to_vec(); + let output = output.to_vec(); a.state = State::Finished { output: output.clone(), }; @@ -104,7 +105,7 @@ impl PrfFunction { } State::Assigned { output } => { if let Some(output) = output.try_recv().map_err(PrfError::vm)? { - let output = convert_to_bytes(output).to_vec(); + let output = output.to_vec(); p.state = State::Finished { output }; } } @@ -123,17 +124,18 @@ impl PrfFunction { self.a_msg = start_seed_label; } - pub(crate) fn output(&self) -> Vec> { + pub(crate) fn output(&self) -> Vec> { self.p.iter().map(|p| p.output).collect() } fn alloc( vm: &mut dyn Vm, label: &'static [u8], - outer_partial: Array, - inner_partial: Array, + outer_partial: Sha256, + inner_partial: Sha256, len: usize, ) -> Result { + let inner_partial = inner_partial.finalize(vm)?; let inner_partial = vm.decode(inner_partial).map_err(PrfError::vm)?; let mut prf = Self { @@ -150,10 +152,10 @@ impl PrfFunction { let iterations = len / 32 + ((len % 32) != 0) as usize; for _ in 0..iterations { - let a = PHash::alloc(vm, outer_partial)?; + let a = PHash::alloc(vm, outer_partial.clone())?; prf.a.push(a); - let p = PHash::alloc(vm, outer_partial)?; + let p = PHash::alloc(vm, outer_partial.clone())?; prf.p.push(p); } @@ -163,13 +165,13 @@ impl PrfFunction { #[derive(Debug)] struct PHash { - output: Array, + output: Array, state: State, } impl PHash { - fn alloc(vm: &mut dyn Vm, outer_partial: Array) -> Result { - let inner_local = vm.alloc().map_err(PrfError::vm)?; + fn alloc(vm: &mut dyn Vm, outer_partial: Sha256) -> Result { + let inner_local: Array = vm.alloc().map_err(PrfError::vm)?; let hmac = HmacSha256::new(outer_partial, inner_local); let output = hmac.alloc(vm).map_err(PrfError::vm)?; @@ -185,14 +187,15 @@ impl PHash { fn assign_inner_local( &mut self, vm: &mut dyn Vm, - inner_partial: [u32; 8], + inner_partial: [u8; 32], msg: &[u8], ) -> Result<(), PrfError> { if let State::Init { inner_local, .. } = self.state { + let inner_partial = bytes_to_state(inner_partial); let inner_local_value = sha256(inner_partial, 64, msg); vm.mark_public(inner_local).map_err(PrfError::vm)?; - vm.assign(inner_local, convert_to_bytes(inner_local_value)) + vm.assign(inner_local, state_to_bytes(inner_local_value)) .map_err(PrfError::vm)?; vm.commit(inner_local).map_err(PrfError::vm)?; @@ -210,7 +213,7 @@ enum State { inner_local: Array, }, Assigned { - output: DecodeFutureTyped, + output: DecodeFutureTyped, }, Finished { output: Vec, @@ -219,12 +222,12 @@ enum State { #[derive(Debug)] enum InnerPartial { - Decoding(DecodeFutureTyped), - Finished([u32; 8]), + Decoding(DecodeFutureTyped), + Finished([u8; 32]), } impl InnerPartial { - pub(crate) fn try_recv(&mut self) -> Result, PrfError> { + pub(crate) fn try_recv(&mut self) -> Result, PrfError> { match self { InnerPartial::Decoding(value) => { let value = value.try_recv().map_err(PrfError::vm)?; diff --git a/crates/components/hmac-sha256/src/prf/mod.rs b/crates/components/hmac-sha256/src/prf/mod.rs index ce0b93823c..8b4009b96c 100644 --- a/crates/components/hmac-sha256/src/prf/mod.rs +++ b/crates/components/hmac-sha256/src/prf/mod.rs @@ -1,8 +1,9 @@ -use crate::{hmac::HmacSha256, sha256::Sha256, Mode, PrfError, PrfOutput}; +use crate::{hmac::HmacSha256, Mode, PrfError, PrfOutput}; use mpz_circuits::{circuits::xor, Circuit, CircuitBuilder}; +use mpz_hash::sha256::Sha256; use mpz_vm_core::{ memory::{ - binary::{Binary, U32, U8}, + binary::{Binary, U8}, Array, MemoryExt, StaticSize, Vector, ViewExt, }, Call, CallableExt, Vm, @@ -66,11 +67,20 @@ impl MpcPrf { let outer_partial_ms = compute_partial(vm, ms, HmacSha256::OPAD)?; let inner_partial_ms = compute_partial(vm, ms, HmacSha256::IPAD)?; - let key_expansion = Prf::alloc_key_expansion(mode, vm, outer_partial_ms, inner_partial_ms)?; - let client_finished = - Prf::alloc_client_finished(mode, vm, outer_partial_ms, inner_partial_ms)?; - let server_finished = - Prf::alloc_server_finished(mode, vm, outer_partial_ms, inner_partial_ms)?; + let key_expansion = + Prf::alloc_key_expansion(mode, vm, outer_partial_ms.clone(), inner_partial_ms.clone())?; + let client_finished = Prf::alloc_client_finished( + mode, + vm, + outer_partial_ms.clone(), + inner_partial_ms.clone(), + )?; + let server_finished = Prf::alloc_server_finished( + mode, + vm, + outer_partial_ms.clone(), + inner_partial_ms.clone(), + )?; self.state = State::SessionKeys { client_random: None, @@ -242,7 +252,7 @@ fn compute_partial( vm: &mut dyn Vm, key: Vector, mask: [u8; 64], -) -> Result, PrfError> { +) -> Result { let xor = Arc::new(xor(8 * 64)); let additional_len = 64 - key.len(); @@ -264,24 +274,25 @@ fn compute_partial( .arg(mask_ref) .build() .map_err(PrfError::vm)?; - let key_padded = vm.call(xor).map_err(PrfError::vm)?; + let key_padded: Vector = vm.call(xor).map_err(PrfError::vm)?; - let mut sha = Sha256::default(); - sha.update(key_padded); - sha.alloc(vm) + let mut sha = Sha256::new_with_init(vm)?; + sha.update(&key_padded); + sha.compress(vm)?; + Ok(sha) } fn merge_outputs( vm: &mut dyn Vm, - inputs: Vec>, + inputs: Vec>, output_bytes: usize, ) -> Result, PrfError> { assert!(output_bytes <= 32 * inputs.len()); - let bits = Array::::SIZE * inputs.len(); - let msb0_circ = gen_merge_circ(4, bits); + let bits = Array::::SIZE * inputs.len(); + let circ = gen_merge_circ(1, bits); - let mut builder = Call::builder(msb0_circ); + let mut builder = Call::builder(circ); for &input in inputs.iter() { builder = builder.arg(input); } @@ -299,6 +310,7 @@ fn gen_merge_circ(element_byte_size: usize, size: usize) -> Arc { let inputs = (0..size).map(|_| builder.add_input()).collect::>(); for input in inputs.chunks_exact(element_byte_size * 8) { + // TODO: .rev() removed here, correct? for byte in input.chunks_exact(8).rev() { for &feed in byte.iter() { let output = builder.add_id_gate(feed); @@ -312,10 +324,10 @@ fn gen_merge_circ(element_byte_size: usize, size: usize) -> Arc { #[cfg(test)] mod tests { - use crate::{convert_to_bytes, prf::merge_outputs, test_utils::mock_vm}; + use crate::{prf::merge_outputs, test_utils::mock_vm}; use mpz_common::context::test_st_context; use mpz_vm_core::{ - memory::{binary::U32, Array, MemoryExt, ViewExt}, + memory::{binary::U8, Array, MemoryExt, ViewExt}, Execute, }; @@ -324,16 +336,16 @@ mod tests { let (mut ctx_a, mut ctx_b) = test_st_context(8); let (mut leader, mut follower) = mock_vm(); - let input1: [u32; 8] = std::array::from_fn(|i| i as u32); - let input2: [u32; 8] = std::array::from_fn(|i| i as u32 + 8); + let input1: [u8; 32] = std::array::from_fn(|i| i as u8); + let input2: [u8; 32] = std::array::from_fn(|i| i as u8 + 32); - let mut expected = convert_to_bytes(input1).to_vec(); - expected.extend_from_slice(&convert_to_bytes(input2)); + let mut expected = input1.to_vec(); + expected.extend_from_slice(&input2); expected.truncate(48); // leader - let input1_leader: Array = leader.alloc().unwrap(); - let input2_leader: Array = leader.alloc().unwrap(); + let input1_leader: Array = leader.alloc().unwrap(); + let input2_leader: Array = leader.alloc().unwrap(); leader.mark_public(input1_leader).unwrap(); leader.mark_public(input2_leader).unwrap(); @@ -349,8 +361,8 @@ mod tests { let mut merged_leader = leader.decode(merged_leader).unwrap(); // follower - let input1_follower: Array = follower.alloc().unwrap(); - let input2_follower: Array = follower.alloc().unwrap(); + let input1_follower: Array = follower.alloc().unwrap(); + let input2_follower: Array = follower.alloc().unwrap(); follower.mark_public(input1_follower).unwrap(); follower.mark_public(input2_follower).unwrap(); diff --git a/crates/components/hmac-sha256/src/prf/state.rs b/crates/components/hmac-sha256/src/prf/state.rs index 52eaec68de..406de8efc3 100644 --- a/crates/components/hmac-sha256/src/prf/state.rs +++ b/crates/components/hmac-sha256/src/prf/state.rs @@ -4,7 +4,7 @@ use crate::{ }; use mpz_vm_core::{ memory::{ - binary::{Binary, U32, U8}, + binary::{Binary, U8}, Array, FromRaw, ToRaw, }, Vm, @@ -61,7 +61,7 @@ impl State { } fn get_session_keys( - output: Vec>, + output: Vec>, vm: &mut dyn Vm, ) -> Result { let mut keys = merge_outputs(vm, output, 40)?; @@ -83,7 +83,7 @@ fn get_session_keys( } fn get_client_finished_vd( - output: Vec>, + output: Vec>, vm: &mut dyn Vm, ) -> Result, PrfError> { let cf_vd = merge_outputs(vm, output, 12)?; @@ -93,7 +93,7 @@ fn get_client_finished_vd( } fn get_server_finished_vd( - output: Vec>, + output: Vec>, vm: &mut dyn Vm, ) -> Result, PrfError> { let sf_vd = merge_outputs(vm, output, 12)?; diff --git a/crates/components/hmac-sha256/src/sha256.rs b/crates/components/hmac-sha256/src/sha256.rs deleted file mode 100644 index 752a5f5961..0000000000 --- a/crates/components/hmac-sha256/src/sha256.rs +++ /dev/null @@ -1,381 +0,0 @@ -//! Computation of SHA256. - -use crate::PrfError; -use mpz_circuits::circuits::SHA256_COMPRESS; -use mpz_vm_core::{ - memory::{ - binary::{Binary, U32, U8}, - Array, MemoryExt, Vector, ViewExt, - }, - Call, CallableExt, Vm, -}; - -/// Computes SHA256. -#[derive(Debug, Default)] -pub(crate) struct Sha256 { - state: Option>, - chunks: Vec>, - processed: usize, -} - -impl Sha256 { - /// The default initialization vector. - const IV: [u32; 8] = [ - 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, - 0x5be0cd19, - ]; - - /// Sets the state. - /// - /// # Arguments - /// - /// * `state` - The starting state for the SHA256 compression function. - /// * `processed` - The number of already processed bytes corresponding to - /// `state`. - pub(crate) fn set_state(&mut self, state: Array, processed: usize) -> &mut Self { - self.state = Some(state); - self.processed = processed; - self - } - - /// Feeds data into the hash function. - /// - /// # Arguments - /// - /// * `data` - The data to hash. - pub(crate) fn update(&mut self, data: Vector) -> &mut Self { - self.chunks.push(data); - self - } - - /// Computes the padding for SHA256. - /// - /// Padding is computed depending on [`Self::state`] and - /// [`Self::processed`]. - /// - /// # Arguments - /// - /// * `vm` - The virtual machine. - pub(crate) fn add_padding(&mut self, vm: &mut dyn Vm) -> Result<&mut Self, PrfError> { - let msg_len: usize = self.chunks.iter().map(|b| b.len()).sum(); - let pos = self.processed; - - let bit_len = msg_len * 8; - let processed_bit_len = (bit_len + (pos * 8)) as u64; - - // minimum length of padded message in bytes - let min_padded_len = msg_len + 9; - // number of 64-byte blocks rounded up - let block_count = (min_padded_len / 64) + (min_padded_len % 64 != 0) as usize; - // message is padded to a multiple of 64 bytes - let padded_len = block_count * 64; - // number of bytes to pad - let pad_len = padded_len - msg_len; - - // append a single '1' bit - // append K '0' bits, where K is the minimum number >= 0 such that (L + 1 + K + - // 64) is a multiple of 512 append L as a 64-bit big-endian integer, making - // the total post-processed length a multiple of 512 bits such that the bits - // in the message are: 1 , (the number of bits will be a multiple of 512) - let mut padding = Vec::new(); - padding.push(128_u8); - padding.extend((0..pad_len - 9).map(|_| 0_u8)); - padding.extend(processed_bit_len.to_be_bytes()); - - let padding_ref: Vector = vm.alloc_vec(padding.len()).map_err(PrfError::vm)?; - - vm.mark_public(padding_ref).map_err(PrfError::vm)?; - vm.assign(padding_ref, padding).map_err(PrfError::vm)?; - vm.commit(padding_ref).map_err(PrfError::vm)?; - - self.chunks.push(padding_ref); - Ok(self) - } - - /// Adds the [`Call`] to the [`Vm`], and returns the output. - /// - /// # Arguments - /// - /// * `vm` - The virtual machine. - pub(crate) fn alloc(self, vm: &mut dyn Vm) -> Result, PrfError> { - let mut state = if let Some(state) = self.state { - state - } else { - Self::assign_iv(vm)? - }; - - // SHA256 compression function takes 64 byte blocks as inputs but our blocks in - // `self.chunks` can have arbitrary size to simplify the api. So we need to - // repartition them to 64 byte blocks and feed those into the - // compression function. - let mut remainder = None; - let mut block: Vec> = vec![]; - let mut chunk_iter = self.chunks.iter().copied(); - - loop { - if let Some(remainder) = remainder.take() { - block.push(remainder); - } - let Some(mut chunk) = chunk_iter.next() else { - break; - }; - - let len_before: usize = block.iter().map(|b| b.len()).sum(); - let len_after = len_before + chunk.len(); - - if len_after <= 64 { - block.push(chunk); - } else { - let excess_len = len_after - 64; - remainder = Some(chunk.split_off(chunk.len() - excess_len)); - - block.push(chunk); - state = Self::compute_state(vm, state, &block)?; - block.clear(); - } - } - - Self::compute_state(vm, state, &block) - } - - fn assign_iv(vm: &mut dyn Vm) -> Result, PrfError> { - let iv: Array = vm.alloc().map_err(PrfError::vm)?; - - vm.mark_public(iv).map_err(PrfError::vm)?; - vm.assign(iv, Self::IV).map_err(PrfError::vm)?; - vm.commit(iv).map_err(PrfError::vm)?; - - Ok(iv) - } - - fn compute_state( - vm: &mut dyn Vm, - state: Array, - data: &[Vector], - ) -> Result, PrfError> { - let mut compress = Call::builder(SHA256_COMPRESS.clone()); - - for &block in data { - compress = compress.arg(block); - } - - let compress = compress.arg(state).build().map_err(PrfError::vm)?; - vm.call(compress).map_err(PrfError::vm) - } -} - -/// Reference SHA256 implementation. -/// -/// # Arguments -/// -/// * `state` - The SHA256 state. -/// * `pos` - The number of bytes processed in the current state. -/// * `msg` - The message to hash. -pub(crate) fn sha256(mut state: [u32; 8], pos: usize, msg: &[u8]) -> [u32; 8] { - use sha2::{ - compress256, - digest::{ - block_buffer::{BlockBuffer, Eager}, - generic_array::typenum::U64, - }, - }; - - let mut buffer = BlockBuffer::::default(); - buffer.digest_blocks(msg, |b| compress256(&mut state, b)); - buffer.digest_pad(0x80, &(((msg.len() + pos) * 8) as u64).to_be_bytes(), |b| { - compress256(&mut state, &[*b]) - }); - state -} - -#[cfg(test)] -mod tests { - use crate::{ - convert_to_bytes, - sha256::{sha256, Sha256}, - test_utils::{compress_256, mock_vm}, - }; - use mpz_common::context::test_st_context; - use mpz_vm_core::{ - memory::{ - binary::{U32, U8}, - Array, MemoryExt, Vector, ViewExt, - }, - Execute, - }; - - #[tokio::test] - async fn test_sha256_circuit() { - let (mut ctx_a, mut ctx_b) = test_st_context(8); - let (mut leader, mut follower) = mock_vm(); - - let (inputs, references) = test_fixtures(); - for (input, &reference) in inputs.iter().zip(references.iter()) { - let input_leader: Vector = leader.alloc_vec(input.len()).unwrap(); - leader.mark_public(input_leader).unwrap(); - leader.assign(input_leader, input.clone()).unwrap(); - leader.commit(input_leader).unwrap(); - - let mut sha_leader = Sha256::default(); - sha_leader - .update(input_leader) - .add_padding(&mut leader) - .unwrap(); - let sha_out_leader = sha_leader.alloc(&mut leader).unwrap(); - let sha_out_leader = leader.decode(sha_out_leader).unwrap(); - - let input_follower: Vector = follower.alloc_vec(input.len()).unwrap(); - follower.mark_public(input_follower).unwrap(); - follower.assign(input_follower, input.clone()).unwrap(); - follower.commit(input_follower).unwrap(); - - let mut sha_follower = Sha256::default(); - sha_follower - .update(input_follower) - .add_padding(&mut follower) - .unwrap(); - let sha_out_follower = sha_follower.alloc(&mut follower).unwrap(); - let sha_out_follower = follower.decode(sha_out_follower).unwrap(); - - let (sha_out_leader, sha_out_follower) = tokio::try_join!( - async { - leader.execute_all(&mut ctx_a).await.unwrap(); - sha_out_leader.await - }, - async { - follower.execute_all(&mut ctx_b).await.unwrap(); - sha_out_follower.await - } - ) - .unwrap(); - - assert_eq!(sha_out_leader, sha_out_follower); - assert_eq!(convert_to_bytes(sha_out_leader), reference); - } - } - - #[tokio::test] - async fn test_sha256_circuit_set_state() { - let (mut ctx_a, mut ctx_b) = test_st_context(8); - let (mut leader, mut follower) = mock_vm(); - - let (inputs, references) = test_fixtures(); - - // only take 3rd example because we need minimum 64 bits. - let input = &inputs[2]; - let reference = references[2]; - - // This has to be 64 bytes, because the sha256 compression function operates on - // 64 byte blocks. - let skip = 64; - - let state = compress_256(Sha256::IV, &input[..skip]); - let test = input[skip..].to_vec(); - - let input_leader: Vector = leader.alloc_vec(test.len()).unwrap(); - leader.mark_public(input_leader).unwrap(); - leader.assign(input_leader, test.clone()).unwrap(); - leader.commit(input_leader).unwrap(); - - let state_leader: Array = leader.alloc().unwrap(); - leader.mark_public(state_leader).unwrap(); - leader.assign(state_leader, state).unwrap(); - leader.commit(state_leader).unwrap(); - - let mut sha_leader = Sha256::default(); - sha_leader - .set_state(state_leader, skip) - .update(input_leader) - .add_padding(&mut leader) - .unwrap(); - let sha_out_leader = sha_leader.alloc(&mut leader).unwrap(); - let sha_out_leader = leader.decode(sha_out_leader).unwrap(); - - let input_follower: Vector = follower.alloc_vec(test.len()).unwrap(); - follower.mark_public(input_follower).unwrap(); - follower.assign(input_follower, test).unwrap(); - follower.commit(input_follower).unwrap(); - - let state_follower: Array = follower.alloc().unwrap(); - follower.mark_public(state_follower).unwrap(); - follower.assign(state_follower, state).unwrap(); - follower.commit(state_follower).unwrap(); - - let mut sha_follower = Sha256::default(); - sha_follower - .set_state(state_follower, skip) - .update(input_follower) - .add_padding(&mut follower) - .unwrap(); - let sha_out_follower = sha_follower.alloc(&mut follower).unwrap(); - let sha_out_follower = follower.decode(sha_out_follower).unwrap(); - - let (sha_out_leader, sha_out_follower) = tokio::try_join!( - async { - leader.execute_all(&mut ctx_a).await.unwrap(); - sha_out_leader.await - }, - async { - follower.execute_all(&mut ctx_b).await.unwrap(); - sha_out_follower.await - } - ) - .unwrap(); - - assert_eq!(sha_out_leader, sha_out_follower); - assert_eq!(convert_to_bytes(sha_out_leader), reference); - } - - #[test] - fn test_sha256_reference() { - let (inputs, references) = test_fixtures(); - for (input, &reference) in inputs.iter().zip(references.iter()) { - let sha = sha256(Sha256::IV, 0, input); - assert_eq!(convert_to_bytes(sha), reference); - } - } - - #[test] - fn test_sha256_reference_set_state() { - let (inputs, references) = test_fixtures(); - - // only take 3rd example because we need minimum 64 bits. - let input = &inputs[2]; - let reference = references[2]; - - // This has to be 64 bytes, because the sha256 compression function operates on - // 64 byte blocks. - let skip = 64; - - let state = compress_256(Sha256::IV, &input[..skip]); - let test = input[skip..].to_vec(); - - let sha = sha256(state, skip, &test); - assert_eq!(convert_to_bytes(sha), reference); - } - - fn test_fixtures() -> (Vec>, Vec<[u8; 32]>) { - let test_vectors: Vec> = vec![ - b"abc".to_vec(), - b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq".to_vec(), - b"abcdefghbcdefghicdefghijdefghijkefghijklfghijklmghijklmnhijklmnoijklmnopjklmnopqklmnopqrlmnopqrsmnopqrstnopqrstu".to_vec() - ]; - let expected: Vec<[u8; 32]> = vec![ - hex::decode("ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") - .unwrap() - .try_into() - .unwrap(), - hex::decode("248d6a61d20638b8e5c026930c3e6039a33ce45964ff2167f6ecedd419db06c1") - .unwrap() - .try_into() - .unwrap(), - hex::decode("cf5b16a778af8380036ce59e7b0492370b249b11e8f07a51afac45037afee9d1") - .unwrap() - .try_into() - .unwrap(), - ]; - - (test_vectors, expected) - } -} diff --git a/crates/components/hmac-sha256/src/test_utils.rs b/crates/components/hmac-sha256/src/test_utils.rs index d3da0e5833..65f340f178 100644 --- a/crates/components/hmac-sha256/src/test_utils.rs +++ b/crates/components/hmac-sha256/src/test_utils.rs @@ -1,4 +1,4 @@ -use crate::{convert_to_bytes, sha256::sha256}; +use crate::{sha256, state_to_bytes}; use mpz_garble::protocol::semihonest::{Evaluator, Garbler}; use mpz_ot::ideal::cot::{ideal_cot, IdealCOTReceiver, IdealCOTSender}; use mpz_vm_core::memory::correlated::Delta; @@ -93,8 +93,8 @@ pub(crate) fn hmac_sha256(key: Vec, msg: &[u8]) -> [u8; 32] { let outer_partial = compute_outer_partial(key.clone()); let inner_local = compute_inner_local(key, msg); - let hmac = sha256(outer_partial, 64, &convert_to_bytes(inner_local)); - convert_to_bytes(hmac) + let hmac = sha256(outer_partial, 64, &state_to_bytes(inner_local)); + state_to_bytes(hmac) } pub(crate) fn compute_outer_partial(mut key: Vec) -> [u32; 8] { From 5283b1867c4170cb7174cb486e68410aa4086ef6 Mon Sep 17 00:00:00 2001 From: th4s Date: Tue, 29 Apr 2025 16:17:46 +0200 Subject: [PATCH 16/36] repair failing test --- crates/components/hmac-sha256/src/hmac.rs | 4 +-- .../hmac-sha256/src/prf/function/normal.rs | 26 +------------------ 2 files changed, 3 insertions(+), 27 deletions(-) diff --git a/crates/components/hmac-sha256/src/hmac.rs b/crates/components/hmac-sha256/src/hmac.rs index a407045238..64a6de6946 100644 --- a/crates/components/hmac-sha256/src/hmac.rs +++ b/crates/components/hmac-sha256/src/hmac.rs @@ -119,7 +119,7 @@ mod tests { leader.commit(inner_local_leader).unwrap(); let hmac_leader = HmacSha256::new( - Sha256::new_from_state(outer_partial_leader, 64), + Sha256::new_from_state(outer_partial_leader, 1), inner_local_leader, ) .alloc(&mut leader) @@ -141,7 +141,7 @@ mod tests { follower.commit(inner_local_follower).unwrap(); let hmac_follower = HmacSha256::new( - Sha256::new_from_state(outer_partial_follower, 64), + Sha256::new_from_state(outer_partial_follower, 1), inner_local_follower, ) .alloc(&mut follower) diff --git a/crates/components/hmac-sha256/src/prf/function/normal.rs b/crates/components/hmac-sha256/src/prf/function/normal.rs index 6706977ae0..1b3b151a33 100644 --- a/crates/components/hmac-sha256/src/prf/function/normal.rs +++ b/crates/components/hmac-sha256/src/prf/function/normal.rs @@ -5,7 +5,7 @@ use mpz_circuits::CircuitBuilder; use mpz_hash::sha256::Sha256; use mpz_vm_core::{ memory::{ - binary::{Binary, U32, U8}, + binary::{Binary, U8}, Array, MemoryExt, Vector, ViewExt, }, Call, CallableExt, Vm, @@ -165,30 +165,6 @@ impl PHash { } } -fn convert_array(vm: &mut dyn Vm, input: Array) -> Result, PrfError> { - let circ = { - let mut builder = CircuitBuilder::new(); - let inputs = (0..32 * 8).map(|_| builder.add_input()).collect::>(); - - for input in inputs.chunks_exact(4 * 8) { - for byte in input.chunks_exact(8).rev() { - for &feed in byte.iter() { - let output = builder.add_id_gate(feed); - builder.add_output(output); - } - } - } - - Arc::new(builder.build().expect("conversion circuit is valid")) - }; - - let mut builder = Call::builder(circ); - builder = builder.arg(input); - let call = builder.build().map_err(PrfError::vm)?; - - vm.call(call).map_err(PrfError::vm) -} - fn merge_vecs(vm: &mut dyn Vm, inputs: Vec>) -> Result, PrfError> { let len: usize = inputs.iter().map(|inp| inp.len()).sum(); let circ = { From 487af260004f290d328e13f38a7ecba2a17df321 Mon Sep 17 00:00:00 2001 From: th4s Date: Tue, 29 Apr 2025 17:05:22 +0200 Subject: [PATCH 17/36] fixed all tests --- crates/components/hmac-sha256/src/lib.rs | 9 --------- .../hmac-sha256/src/prf/function/reduced.rs | 15 ++++++++------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/crates/components/hmac-sha256/src/lib.rs b/crates/components/hmac-sha256/src/lib.rs index 7f76543b38..2e460bf43d 100644 --- a/crates/components/hmac-sha256/src/lib.rs +++ b/crates/components/hmac-sha256/src/lib.rs @@ -69,15 +69,6 @@ fn state_to_bytes(input: [u32; 8]) -> [u8; 32] { output } -fn bytes_to_state(byte_input: [u8; 32]) -> [u32; 8] { - let mut output = [0_u32; 8]; - for (k, bytes) in byte_input.chunks_exact(4).rev().enumerate() { - let value = u32::from_le_bytes(bytes.try_into().unwrap()); - output[k] = value; - } - output -} - #[cfg(test)] mod tests { use crate::{ diff --git a/crates/components/hmac-sha256/src/prf/function/reduced.rs b/crates/components/hmac-sha256/src/prf/function/reduced.rs index 3e4a0baa2a..097c4f581b 100644 --- a/crates/components/hmac-sha256/src/prf/function/reduced.rs +++ b/crates/components/hmac-sha256/src/prf/function/reduced.rs @@ -1,6 +1,6 @@ //! Computes some hashes of the PRF locally. -use crate::{bytes_to_state, hmac::HmacSha256, sha256, state_to_bytes, PrfError}; +use crate::{hmac::HmacSha256, sha256, state_to_bytes, PrfError}; use mpz_core::bitvec::BitVec; use mpz_hash::sha256::Sha256; use mpz_vm_core::{ @@ -135,7 +135,9 @@ impl PrfFunction { inner_partial: Sha256, len: usize, ) -> Result { - let inner_partial = inner_partial.finalize(vm)?; + let (inner_partial, _) = inner_partial + .state() + .expect("state should be set for inner_partial"); let inner_partial = vm.decode(inner_partial).map_err(PrfError::vm)?; let mut prf = Self { @@ -187,11 +189,10 @@ impl PHash { fn assign_inner_local( &mut self, vm: &mut dyn Vm, - inner_partial: [u8; 32], + inner_partial: [u32; 8], msg: &[u8], ) -> Result<(), PrfError> { if let State::Init { inner_local, .. } = self.state { - let inner_partial = bytes_to_state(inner_partial); let inner_local_value = sha256(inner_partial, 64, msg); vm.mark_public(inner_local).map_err(PrfError::vm)?; @@ -222,12 +223,12 @@ enum State { #[derive(Debug)] enum InnerPartial { - Decoding(DecodeFutureTyped), - Finished([u8; 32]), + Decoding(DecodeFutureTyped), + Finished([u32; 8]), } impl InnerPartial { - pub(crate) fn try_recv(&mut self) -> Result, PrfError> { + pub(crate) fn try_recv(&mut self) -> Result, PrfError> { match self { InnerPartial::Decoding(value) => { let value = value.try_recv().map_err(PrfError::vm)?; From f51ed3352cefde2269fd5562593acfa9b54b50b8 Mon Sep 17 00:00:00 2001 From: th4s Date: Tue, 29 Apr 2025 17:41:20 +0200 Subject: [PATCH 18/36] remove output decoding for p --- .../hmac-sha256/src/prf/function/reduced.rs | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/crates/components/hmac-sha256/src/prf/function/reduced.rs b/crates/components/hmac-sha256/src/prf/function/reduced.rs index 097c4f581b..2cf5d3b4b3 100644 --- a/crates/components/hmac-sha256/src/prf/function/reduced.rs +++ b/crates/components/hmac-sha256/src/prf/function/reduced.rs @@ -65,7 +65,7 @@ impl PrfFunction { pub(crate) fn wants_flush(&mut self) -> bool { let last_p = self.p.last().expect("Prf should be allocated"); - if let State::Finished { .. } = last_p.state { + if let State::Done = last_p.state { return false; } true @@ -86,7 +86,7 @@ impl PrfFunction { State::Assigned { output } => { if let Some(output) = output.try_recv().map_err(PrfError::vm)? { let output = output.to_vec(); - a.state = State::Finished { + a.state = State::Decoded { output: output.clone(), }; self.a_msg = output; @@ -97,17 +97,14 @@ impl PrfFunction { match &mut p.state { State::Init { .. } => { - if let State::Finished { output } = &a.state { + if let State::Decoded { output } = &a.state { let mut p_msg = output.to_vec(); p_msg.extend_from_slice(&self.start_seed_label); p.assign_inner_local(vm, inner_partial, &p_msg)?; } } - State::Assigned { output } => { - if let Some(output) = output.try_recv().map_err(PrfError::vm)? { - let output = output.to_vec(); - p.state = State::Finished { output }; - } + State::Assigned { .. } => { + p.state = State::Done; } _ => (), } @@ -216,9 +213,10 @@ enum State { Assigned { output: DecodeFutureTyped, }, - Finished { + Decoded { output: Vec, }, + Done, } #[derive(Debug)] From 08c0bc9c5e9fd6c6092bc486200e8cd898fa02b8 Mon Sep 17 00:00:00 2001 From: th4s Date: Fri, 2 May 2025 12:00:43 +0200 Subject: [PATCH 19/36] do not use mod.rs file hierarchy --- crates/components/hmac-sha256/src/{prf/mod.rs => prf.rs} | 0 .../hmac-sha256/src/prf/{function/mod.rs => function.rs} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename crates/components/hmac-sha256/src/{prf/mod.rs => prf.rs} (100%) rename crates/components/hmac-sha256/src/prf/{function/mod.rs => function.rs} (100%) diff --git a/crates/components/hmac-sha256/src/prf/mod.rs b/crates/components/hmac-sha256/src/prf.rs similarity index 100% rename from crates/components/hmac-sha256/src/prf/mod.rs rename to crates/components/hmac-sha256/src/prf.rs diff --git a/crates/components/hmac-sha256/src/prf/function/mod.rs b/crates/components/hmac-sha256/src/prf/function.rs similarity index 100% rename from crates/components/hmac-sha256/src/prf/function/mod.rs rename to crates/components/hmac-sha256/src/prf/function.rs From 8a9735cb10808445ae7a31fb2bd77f59a4f3adbd Mon Sep 17 00:00:00 2001 From: th4s Date: Fri, 2 May 2025 12:29:44 +0200 Subject: [PATCH 20/36] remove pub(crate) from function --- crates/components/hmac-sha256/src/prf.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/components/hmac-sha256/src/prf.rs b/crates/components/hmac-sha256/src/prf.rs index 8b4009b96c..ad8381c473 100644 --- a/crates/components/hmac-sha256/src/prf.rs +++ b/crates/components/hmac-sha256/src/prf.rs @@ -14,7 +14,7 @@ use tracing::instrument; mod state; use state::State; -pub(crate) mod function; +mod function; use function::Prf; /// MPC PRF for computing TLS 1.2 HMAC-SHA256 PRF. From 0d4f169b616f157438a954ab4a601bbcd71d012f Mon Sep 17 00:00:00 2001 From: th4s Date: Fri, 2 May 2025 12:51:50 +0200 Subject: [PATCH 21/36] improve config handling --- crates/common/src/config.rs | 8 ++++++++ crates/mpc-tls/src/config.rs | 2 +- crates/prover/src/config.rs | 5 +---- crates/verifier/src/config.rs | 5 +---- crates/wasm/src/prover/config.rs | 7 +------ 5 files changed, 12 insertions(+), 15 deletions(-) diff --git a/crates/common/src/config.rs b/crates/common/src/config.rs index a0cc14b6c4..2c327bfd2f 100644 --- a/crates/common/src/config.rs +++ b/crates/common/src/config.rs @@ -41,6 +41,9 @@ pub struct ProtocolConfig { /// of the MPC-TLS connection. #[builder(default = "true")] defer_decryption_from_start: bool, + /// Network settings. + #[builder(default)] + network: NetworkSetting, /// Version that is being run by prover/verifier. #[builder(setter(skip), default = "VERSION.clone()")] version: Version, @@ -95,6 +98,11 @@ impl ProtocolConfig { pub fn defer_decryption_from_start(&self) -> bool { self.defer_decryption_from_start } + + /// Returns the network settings. + pub fn network(&self) -> NetworkSetting { + self.network + } } /// Protocol configuration validator used by checker (i.e. verifier) to perform diff --git a/crates/mpc-tls/src/config.rs b/crates/mpc-tls/src/config.rs index d7a9c8946e..4e5e8b1dab 100644 --- a/crates/mpc-tls/src/config.rs +++ b/crates/mpc-tls/src/config.rs @@ -105,7 +105,7 @@ impl ConfigBuilder { .max_recv_records .unwrap_or_else(|| PROTOCOL_RECORD_COUNT_RECV + default_record_count(max_recv)); - let prf = self.prf.unwrap_or(PrfMode::Normal); + let prf = self.prf.unwrap_or_default(); Ok(Config { defer_decryption, diff --git a/crates/prover/src/config.rs b/crates/prover/src/config.rs index 02bf0e71a2..c7f9a2a25e 100644 --- a/crates/prover/src/config.rs +++ b/crates/prover/src/config.rs @@ -15,9 +15,6 @@ pub struct ProverConfig { /// Cryptography provider. #[builder(default, setter(into))] crypto_provider: Arc, - /// Network settings. - #[builder(default)] - network: NetworkSetting, } impl ProverConfig { @@ -58,7 +55,7 @@ impl ProverConfig { builder.max_recv_records(max_recv_records); } - if let NetworkSetting::Latency = self.network { + if let NetworkSetting::Latency = self.protocol_config.network() { builder.low_bandwidth(); } diff --git a/crates/verifier/src/config.rs b/crates/verifier/src/config.rs index 85f4a8c548..8f2abc1942 100644 --- a/crates/verifier/src/config.rs +++ b/crates/verifier/src/config.rs @@ -16,9 +16,6 @@ pub struct VerifierConfig { /// Cryptography provider. #[builder(default, setter(into))] crypto_provider: Arc, - /// Network settings. - #[builder(default)] - network: NetworkSetting, } impl Debug for VerifierConfig { @@ -61,7 +58,7 @@ impl VerifierConfig { builder.max_recv_records(max_recv_records); } - if let NetworkSetting::Latency = self.network { + if let NetworkSetting::Latency = protocol_config.network() { builder.low_bandwidth(); } diff --git a/crates/wasm/src/prover/config.rs b/crates/wasm/src/prover/config.rs index e23ee9fab7..e4e5aed76f 100644 --- a/crates/wasm/src/prover/config.rs +++ b/crates/wasm/src/prover/config.rs @@ -34,10 +34,7 @@ impl From for tlsn_prover::ProverConfig { builder.max_recv_records(value); } - if let Some(value) = value.defer_decryption_from_start { - builder.defer_decryption_from_start(value); - } - + builder.network(value.network); let protocol_config = builder.build().unwrap(); let mut builder = tlsn_prover::ProverConfig::builder(); @@ -49,8 +46,6 @@ impl From for tlsn_prover::ProverConfig { builder.defer_decryption_from_start(value); } - builder.network(value.network); - builder.build().unwrap() } } From c5d017ea66f59ef87f0e3efc793cdd57485022d2 Mon Sep 17 00:00:00 2001 From: th4s Date: Fri, 2 May 2025 13:00:14 +0200 Subject: [PATCH 22/36] use `Array::try_from` --- crates/components/hmac-sha256/src/prf/state.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/crates/components/hmac-sha256/src/prf/state.rs b/crates/components/hmac-sha256/src/prf/state.rs index 406de8efc3..6299c6d07d 100644 --- a/crates/components/hmac-sha256/src/prf/state.rs +++ b/crates/components/hmac-sha256/src/prf/state.rs @@ -65,12 +65,12 @@ fn get_session_keys( vm: &mut dyn Vm, ) -> Result { let mut keys = merge_outputs(vm, output, 40)?; + debug_assert!(keys.len() == 40, "session keys len should be 40"); - let server_iv = as FromRaw>::from_raw(keys.split_off(36).to_raw()); - let client_iv = as FromRaw>::from_raw(keys.split_off(32).to_raw()); - let server_write_key = - as FromRaw>::from_raw(keys.split_off(16).to_raw()); - let client_write_key = as FromRaw>::from_raw(keys.to_raw()); + let server_iv = Array::::try_from(keys.split_off(36)).unwrap(); + let client_iv = Array::::try_from(keys.split_off(32)).unwrap(); + let server_write_key = Array::::try_from(keys.split_off(16)).unwrap(); + let client_write_key = Array::::try_from(keys).unwrap(); let session_keys = SessionKeys { client_write_key, From b443d7adc58e549a5723b95ba1f0159d842637fc Mon Sep 17 00:00:00 2001 From: th4s Date: Fri, 2 May 2025 13:19:26 +0200 Subject: [PATCH 23/36] simplify hmac to function --- crates/components/hmac-sha256/src/hmac.rs | 60 +++++++------------ crates/components/hmac-sha256/src/prf.rs | 13 ++-- .../hmac-sha256/src/prf/function/normal.rs | 6 +- .../hmac-sha256/src/prf/function/reduced.rs | 6 +- 4 files changed, 34 insertions(+), 51 deletions(-) diff --git a/crates/components/hmac-sha256/src/hmac.rs b/crates/components/hmac-sha256/src/hmac.rs index 64a6de6946..1103dc6c07 100644 --- a/crates/components/hmac-sha256/src/hmac.rs +++ b/crates/components/hmac-sha256/src/hmac.rs @@ -29,46 +29,30 @@ use mpz_vm_core::{ use crate::PrfError; -/// Computes HMAC-SHA256. -#[derive(Debug)] -pub(crate) struct HmacSha256 { - outer_partial: Sha256, +pub(crate) const IPAD: [u8; 64] = [0x36; 64]; +pub(crate) const OPAD: [u8; 64] = [0x5c; 64]; + +/// Computes HMAC-SHA256 +/// +/// # Arguments +/// +/// * `vm` - The virtual machine. +/// * `outer_partial` - (key' xor opad) +/// * `inner_local` - H((key' xor ipad) || m) +pub(crate) fn hmac_sha256( + vm: &mut dyn Vm, + mut outer_partial: Sha256, inner_local: Array, -} - -impl HmacSha256 { - pub(crate) const IPAD: [u8; 64] = [0x36; 64]; - pub(crate) const OPAD: [u8; 64] = [0x5c; 64]; - - /// Creates a new instance. - /// - /// # Arguments - /// - /// * `outer_partial` - (key' xor opad) - /// * `inner_local` - H((key' xor ipad) || m) - pub(crate) fn new(outer_partial: Sha256, inner_local: Array) -> Self { - Self { - outer_partial, - inner_local, - } - } - - /// Adds the circuit to the [`Vm`] and returns the output. - /// - /// # Arguments - /// - /// * `vm` - The virtual machine. - pub(crate) fn alloc(mut self, vm: &mut dyn Vm) -> Result, PrfError> { - self.outer_partial.update(&self.inner_local); - self.outer_partial.compress(vm)?; - self.outer_partial.finalize(vm).map_err(PrfError::from) - } +) -> Result, PrfError> { + outer_partial.update(&inner_local); + outer_partial.compress(vm)?; + outer_partial.finalize(vm).map_err(PrfError::from) } #[cfg(test)] mod tests { use crate::{ - hmac::HmacSha256, + hmac::hmac_sha256, sha256, state_to_bytes, test_utils::{compute_inner_local, compute_outer_partial, mock_vm}, }; @@ -118,11 +102,11 @@ mod tests { .unwrap(); leader.commit(inner_local_leader).unwrap(); - let hmac_leader = HmacSha256::new( + let hmac_leader = hmac_sha256( + &mut leader, Sha256::new_from_state(outer_partial_leader, 1), inner_local_leader, ) - .alloc(&mut leader) .unwrap(); let hmac_leader = leader.decode(hmac_leader).unwrap(); @@ -140,11 +124,11 @@ mod tests { .unwrap(); follower.commit(inner_local_follower).unwrap(); - let hmac_follower = HmacSha256::new( + let hmac_follower = hmac_sha256( + &mut follower, Sha256::new_from_state(outer_partial_follower, 1), inner_local_follower, ) - .alloc(&mut follower) .unwrap(); let hmac_follower = follower.decode(hmac_follower).unwrap(); diff --git a/crates/components/hmac-sha256/src/prf.rs b/crates/components/hmac-sha256/src/prf.rs index ad8381c473..bada648a27 100644 --- a/crates/components/hmac-sha256/src/prf.rs +++ b/crates/components/hmac-sha256/src/prf.rs @@ -1,4 +1,7 @@ -use crate::{hmac::HmacSha256, Mode, PrfError, PrfOutput}; +use crate::{ + hmac::{IPAD, OPAD}, + Mode, PrfError, PrfOutput, +}; use mpz_circuits::{circuits::xor, Circuit, CircuitBuilder}; use mpz_hash::sha256::Sha256; use mpz_vm_core::{ @@ -56,16 +59,16 @@ impl MpcPrf { let mode = self.mode; let pms: Vector = pms.into(); - let outer_partial_pms = compute_partial(vm, pms, HmacSha256::OPAD)?; - let inner_partial_pms = compute_partial(vm, pms, HmacSha256::IPAD)?; + let outer_partial_pms = compute_partial(vm, pms, OPAD)?; + let inner_partial_pms = compute_partial(vm, pms, IPAD)?; let master_secret = Prf::alloc_master_secret(mode, vm, outer_partial_pms, inner_partial_pms)?; let ms = master_secret.output(); let ms = merge_outputs(vm, ms, 48)?; - let outer_partial_ms = compute_partial(vm, ms, HmacSha256::OPAD)?; - let inner_partial_ms = compute_partial(vm, ms, HmacSha256::IPAD)?; + let outer_partial_ms = compute_partial(vm, ms, OPAD)?; + let inner_partial_ms = compute_partial(vm, ms, IPAD)?; let key_expansion = Prf::alloc_key_expansion(mode, vm, outer_partial_ms.clone(), inner_partial_ms.clone())?; diff --git a/crates/components/hmac-sha256/src/prf/function/normal.rs b/crates/components/hmac-sha256/src/prf/function/normal.rs index 1b3b151a33..d775f782a7 100644 --- a/crates/components/hmac-sha256/src/prf/function/normal.rs +++ b/crates/components/hmac-sha256/src/prf/function/normal.rs @@ -1,6 +1,6 @@ //! Computes the whole PRF in MPC. -use crate::{hmac::HmacSha256, PrfError}; +use crate::{hmac::hmac_sha256, PrfError}; use mpz_circuits::CircuitBuilder; use mpz_hash::sha256::Sha256; use mpz_vm_core::{ @@ -157,9 +157,7 @@ impl PHash { inner_local.compress(vm)?; let inner_local = inner_local.finalize(vm)?; - let hmac = HmacSha256::new(outer_partial, inner_local); - let output = hmac.alloc(vm).map_err(PrfError::vm)?; - + let output = hmac_sha256(vm, outer_partial, inner_local)?; let p_hash = Self { msg, output }; Ok(p_hash) } diff --git a/crates/components/hmac-sha256/src/prf/function/reduced.rs b/crates/components/hmac-sha256/src/prf/function/reduced.rs index 2cf5d3b4b3..fc4ee8ec1a 100644 --- a/crates/components/hmac-sha256/src/prf/function/reduced.rs +++ b/crates/components/hmac-sha256/src/prf/function/reduced.rs @@ -1,6 +1,6 @@ //! Computes some hashes of the PRF locally. -use crate::{hmac::HmacSha256, sha256, state_to_bytes, PrfError}; +use crate::{hmac::hmac_sha256, sha256, state_to_bytes, PrfError}; use mpz_core::bitvec::BitVec; use mpz_hash::sha256::Sha256; use mpz_vm_core::{ @@ -171,9 +171,7 @@ struct PHash { impl PHash { fn alloc(vm: &mut dyn Vm, outer_partial: Sha256) -> Result { let inner_local: Array = vm.alloc().map_err(PrfError::vm)?; - let hmac = HmacSha256::new(outer_partial, inner_local); - - let output = hmac.alloc(vm).map_err(PrfError::vm)?; + let output = hmac_sha256(vm, outer_partial, inner_local)?; let p_hash = Self { state: State::Init { inner_local }, From b769c93016985f7b9d45ba95c047c4b34c34869b Mon Sep 17 00:00:00 2001 From: th4s Date: Fri, 2 May 2025 15:30:28 +0200 Subject: [PATCH 24/36] remove `merge_vecs` --- .../hmac-sha256/src/prf/function/normal.rs | 53 ++++++------------- 1 file changed, 17 insertions(+), 36 deletions(-) diff --git a/crates/components/hmac-sha256/src/prf/function/normal.rs b/crates/components/hmac-sha256/src/prf/function/normal.rs index d775f782a7..5f97b7f625 100644 --- a/crates/components/hmac-sha256/src/prf/function/normal.rs +++ b/crates/components/hmac-sha256/src/prf/function/normal.rs @@ -1,16 +1,14 @@ //! Computes the whole PRF in MPC. use crate::{hmac::hmac_sha256, PrfError}; -use mpz_circuits::CircuitBuilder; use mpz_hash::sha256::Sha256; use mpz_vm_core::{ memory::{ binary::{Binary, U8}, Array, MemoryExt, Vector, ViewExt, }, - Call, CallableExt, Vm, + Vm, }; -use std::sync::Arc; #[derive(Debug)] pub(crate) struct PrfFunction { @@ -71,7 +69,7 @@ impl PrfFunction { pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { if let State::Computing = self.state { let a = self.a.first().expect("prf should be allocated"); - let msg = a.msg; + let msg = *a.msg.first().expect("message for prf should be present"); let msg_value = self.start_seed_label.clone(); @@ -120,12 +118,16 @@ impl PrfFunction { let mut msg_a = seed_label_ref; for _ in 0..iterations { - let a = PHash::alloc(vm, outer_partial.clone(), inner_partial.clone(), msg_a)?; + let a = PHash::alloc(vm, outer_partial.clone(), inner_partial.clone(), &[msg_a])?; msg_a = Vector::::from(a.output); prf.a.push(a); - let msg_p = merge_vecs(vm, vec![msg_a, seed_label_ref])?; - let p = PHash::alloc(vm, outer_partial.clone(), inner_partial.clone(), msg_p)?; + let p = PHash::alloc( + vm, + outer_partial.clone(), + inner_partial.clone(), + &[msg_a, seed_label_ref], + )?; prf.p.push(p); } @@ -141,7 +143,7 @@ enum State { #[derive(Debug, Clone)] struct PHash { - msg: Vector, + msg: Vec>, output: Array, } @@ -150,40 +152,19 @@ impl PHash { vm: &mut dyn Vm, outer_partial: Sha256, inner_partial: Sha256, - msg: Vector, + msg: &[Vector], ) -> Result { let mut inner_local = inner_partial; - inner_local.update(&msg); + + msg.iter().for_each(|m| inner_local.update(m)); inner_local.compress(vm)?; let inner_local = inner_local.finalize(vm)?; let output = hmac_sha256(vm, outer_partial, inner_local)?; - let p_hash = Self { msg, output }; + let p_hash = Self { + msg: msg.to_vec(), + output, + }; Ok(p_hash) } } - -fn merge_vecs(vm: &mut dyn Vm, inputs: Vec>) -> Result, PrfError> { - let len: usize = inputs.iter().map(|inp| inp.len()).sum(); - let circ = { - let mut builder = CircuitBuilder::new(); - - let feeds = (0..len * 8) - .map(|_| builder.add_input()) - .collect::>(); - for feed in feeds { - let output = builder.add_id_gate(feed); - builder.add_output(output); - } - - Arc::new(builder.build().expect("merge circuit is valid")) - }; - - let mut builder = Call::builder(circ); - for input in inputs { - builder = builder.arg(input); - } - let call = builder.build().map_err(PrfError::vm)?; - - vm.call(call).map_err(PrfError::vm) -} From b1251ff538957d275be6ccaf76d234b7ac10f83b Mon Sep 17 00:00:00 2001 From: th4s Date: Fri, 2 May 2025 15:36:16 +0200 Subject: [PATCH 25/36] move `mark_public` to allocation --- crates/components/hmac-sha256/src/prf/function/normal.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/components/hmac-sha256/src/prf/function/normal.rs b/crates/components/hmac-sha256/src/prf/function/normal.rs index 5f97b7f625..734ad1b2d5 100644 --- a/crates/components/hmac-sha256/src/prf/function/normal.rs +++ b/crates/components/hmac-sha256/src/prf/function/normal.rs @@ -73,7 +73,6 @@ impl PrfFunction { let msg_value = self.start_seed_label.clone(); - vm.mark_public(msg).map_err(PrfError::vm)?; vm.assign(msg, msg_value).map_err(PrfError::vm)?; vm.commit(msg).map_err(PrfError::vm)?; @@ -115,8 +114,9 @@ impl PrfFunction { let msg_len_a = label.len() + seed_len; let seed_label_ref: Vector = vm.alloc_vec(msg_len_a).map_err(PrfError::vm)?; - let mut msg_a = seed_label_ref; + vm.mark_public(seed_label_ref).map_err(PrfError::vm)?; + let mut msg_a = seed_label_ref; for _ in 0..iterations { let a = PHash::alloc(vm, outer_partial.clone(), inner_partial.clone(), &[msg_a])?; msg_a = Vector::::from(a.output); From 2d1595364ebac638d14836d1b5c47c065c9839f6 Mon Sep 17 00:00:00 2001 From: th4s Date: Fri, 2 May 2025 18:40:12 +0200 Subject: [PATCH 26/36] minor fixes --- crates/components/hmac-sha256/src/prf.rs | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/crates/components/hmac-sha256/src/prf.rs b/crates/components/hmac-sha256/src/prf.rs index bada648a27..5ee96c6820 100644 --- a/crates/components/hmac-sha256/src/prf.rs +++ b/crates/components/hmac-sha256/src/prf.rs @@ -32,7 +32,7 @@ impl MpcPrf { /// /// # Arguments /// - /// `mode` - The PRF config. + /// `mode` - The PRF mode. pub fn new(mode: Mode) -> MpcPrf { Self { mode, @@ -293,7 +293,7 @@ fn merge_outputs( assert!(output_bytes <= 32 * inputs.len()); let bits = Array::::SIZE * inputs.len(); - let circ = gen_merge_circ(1, bits); + let circ = gen_merge_circ(bits); let mut builder = Call::builder(circ); for &input in inputs.iter() { @@ -306,15 +306,12 @@ fn merge_outputs( Ok(output) } -fn gen_merge_circ(element_byte_size: usize, size: usize) -> Arc { - assert!((size / 8) % element_byte_size == 0); - +fn gen_merge_circ(size: usize) -> Arc { let mut builder = CircuitBuilder::new(); let inputs = (0..size).map(|_| builder.add_input()).collect::>(); - for input in inputs.chunks_exact(element_byte_size * 8) { - // TODO: .rev() removed here, correct? - for byte in input.chunks_exact(8).rev() { + for input in inputs.chunks_exact(8) { + for byte in input.chunks_exact(8) { for &feed in byte.iter() { let output = builder.add_id_gate(feed); builder.add_output(output); From 27f8543abbe9a51f31428906ecaace60192a561d Mon Sep 17 00:00:00 2001 From: th4s Date: Fri, 2 May 2025 22:11:00 +0200 Subject: [PATCH 27/36] simplify state logic for reduced prf even more --- .../hmac-sha256/src/prf/function.rs | 1 + .../hmac-sha256/src/prf/function/reduced.rs | 237 +++++++++--------- 2 files changed, 116 insertions(+), 122 deletions(-) diff --git a/crates/components/hmac-sha256/src/prf/function.rs b/crates/components/hmac-sha256/src/prf/function.rs index 9738b2a52c..9830e60159 100644 --- a/crates/components/hmac-sha256/src/prf/function.rs +++ b/crates/components/hmac-sha256/src/prf/function.rs @@ -234,6 +234,7 @@ mod tests { .unwrap(); } + assert_eq!(prf_out_leader.len(), 2); assert_eq!(prf_out_leader.len(), prf_out_follower.len()); let prf_result_leader: Vec = prf_out_leader diff --git a/crates/components/hmac-sha256/src/prf/function/reduced.rs b/crates/components/hmac-sha256/src/prf/function/reduced.rs index fc4ee8ec1a..b56da46a05 100644 --- a/crates/components/hmac-sha256/src/prf/function/reduced.rs +++ b/crates/components/hmac-sha256/src/prf/function/reduced.rs @@ -17,13 +17,31 @@ pub(crate) struct PrfFunction { label: &'static [u8], // The start seed and the label, e.g. client_random + server_random + "master_secret". start_seed_label: Vec, - // The current HMAC message needed for a[i] - a_msg: Vec, - inner_partial: InnerPartial, + iterations: usize, + state: PrfState, a: Vec, p: Vec, } +#[derive(Debug)] +enum PrfState { + InnerPartial { + inner_partial: DecodeFutureTyped, + }, + ComputeA { + iter: usize, + inner_partial: [u32; 8], + msg: Vec, + }, + ComputeP { + iter: usize, + inner_partial: [u32; 8], + a_output: DecodeFutureTyped, + }, + ComputeLastP, + Done, +} + impl PrfFunction { const MS_LABEL: &[u8] = b"master secret"; const KEY_LABEL: &[u8] = b"key expansion"; @@ -63,51 +81,68 @@ impl PrfFunction { } pub(crate) fn wants_flush(&mut self) -> bool { - let last_p = self.p.last().expect("Prf should be allocated"); - - if let State::Done = last_p.state { + if let PrfState::Done = self.state { return false; } true } pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { - let inner_partial = self.inner_partial.try_recv()?; - let Some(inner_partial) = inner_partial else { - return Ok(()); - }; - - for (a, p) in self.a.iter_mut().zip(self.p.iter_mut()) { - match &mut a.state { - State::Init { .. } => { - a.assign_inner_local(vm, inner_partial, &self.a_msg)?; - break; - } - State::Assigned { output } => { - if let Some(output) = output.try_recv().map_err(PrfError::vm)? { - let output = output.to_vec(); - a.state = State::Decoded { - output: output.clone(), - }; - self.a_msg = output; - } - } - _ => (), + match &mut self.state { + PrfState::InnerPartial { inner_partial } => { + let Some(inner_partial) = inner_partial.try_recv().map_err(PrfError::vm)? else { + return Ok(()); + }; + + self.state = PrfState::ComputeA { + iter: 0, + inner_partial, + msg: self.start_seed_label.clone(), + }; + self.flush(vm)?; } - - match &mut p.state { - State::Init { .. } => { - if let State::Decoded { output } = &a.state { - let mut p_msg = output.to_vec(); - p_msg.extend_from_slice(&self.start_seed_label); - p.assign_inner_local(vm, inner_partial, &p_msg)?; + PrfState::ComputeA { + iter, + inner_partial, + msg, + } => { + let a = &self.a[*iter]; + assign_inner_local(vm, a.inner_local, *inner_partial, msg)?; + + let a_output = vm.decode(a.output).map_err(PrfError::vm)?; + self.state = PrfState::ComputeP { + iter: *iter, + inner_partial: *inner_partial, + a_output, + }; + } + PrfState::ComputeP { + iter, + inner_partial, + a_output, + } => { + let Some(output) = a_output.try_recv().map_err(PrfError::vm)? else { + return Ok(()); + }; + let p = &self.p[*iter]; + + let mut msg = output.to_vec(); + msg.extend_from_slice(&self.start_seed_label); + + assign_inner_local(vm, p.inner_local, *inner_partial, &msg)?; + + if *iter == self.iterations { + self.state = PrfState::ComputeLastP; + } else { + self.state = PrfState::ComputeA { + iter: *iter + 1, + inner_partial: *inner_partial, + msg: output.to_vec(), } - } - State::Assigned { .. } => { - p.state = State::Done; - } - _ => (), + }; } + PrfState::ComputeLastP => self.state = PrfState::Done, + _ => (), } Ok(()) @@ -117,8 +152,7 @@ impl PrfFunction { let mut start_seed_label = self.label.to_vec(); start_seed_label.extend_from_slice(&seed); - self.start_seed_label = start_seed_label.clone(); - self.a_msg = start_seed_label; + self.start_seed_label = start_seed_label; } pub(crate) fn output(&self) -> Vec> { @@ -132,6 +166,10 @@ impl PrfFunction { inner_partial: Sha256, len: usize, ) -> Result { + assert!(len > 0, "cannot compute 0 bytes for prf"); + + let iterations = len / 32 + ((len % 32) != 0) as usize; + let (inner_partial, _) = inner_partial .state() .expect("state should be set for inner_partial"); @@ -140,100 +178,55 @@ impl PrfFunction { let mut prf = Self { label, start_seed_label: vec![], - a_msg: vec![], - inner_partial: InnerPartial::Decoding(inner_partial), + // used for indexing, so we need to subtract one here + iterations: iterations - 1, + state: PrfState::InnerPartial { inner_partial }, a: vec![], p: vec![], }; - assert!(len > 0, "cannot compute 0 bytes for prf"); - - let iterations = len / 32 + ((len % 32) != 0) as usize; - for _ in 0..iterations { - let a = PHash::alloc(vm, outer_partial.clone())?; - prf.a.push(a); - - let p = PHash::alloc(vm, outer_partial.clone())?; - prf.p.push(p); + // setup A[i] + let inner_local: Array = vm.alloc().map_err(PrfError::vm)?; + let output = hmac_sha256(vm, outer_partial.clone(), inner_local)?; + let p_hash = PHash { + inner_local, + output, + }; + prf.a.push(p_hash); + + // setup P[i] + let inner_local: Array = vm.alloc().map_err(PrfError::vm)?; + let output = hmac_sha256(vm, outer_partial.clone(), inner_local)?; + let p_hash = PHash { + inner_local, + output, + }; + prf.p.push(p_hash); } Ok(prf) } } -#[derive(Debug)] -struct PHash { - output: Array, - state: State, -} - -impl PHash { - fn alloc(vm: &mut dyn Vm, outer_partial: Sha256) -> Result { - let inner_local: Array = vm.alloc().map_err(PrfError::vm)?; - let output = hmac_sha256(vm, outer_partial, inner_local)?; - - let p_hash = Self { - state: State::Init { inner_local }, - output, - }; - - Ok(p_hash) - } - - fn assign_inner_local( - &mut self, - vm: &mut dyn Vm, - inner_partial: [u32; 8], - msg: &[u8], - ) -> Result<(), PrfError> { - if let State::Init { inner_local, .. } = self.state { - let inner_local_value = sha256(inner_partial, 64, msg); - - vm.mark_public(inner_local).map_err(PrfError::vm)?; - vm.assign(inner_local, state_to_bytes(inner_local_value)) - .map_err(PrfError::vm)?; - vm.commit(inner_local).map_err(PrfError::vm)?; - - let output = vm.decode(self.output).map_err(PrfError::vm)?; - self.state = State::Assigned { output }; - } - - Ok(()) - } -} +fn assign_inner_local( + vm: &mut dyn Vm, + inner_local: Array, + inner_partial: [u32; 8], + msg: &[u8], +) -> Result<(), PrfError> { + let inner_local_value = sha256(inner_partial, 64, msg); -#[derive(Debug)] -enum State { - Init { - inner_local: Array, - }, - Assigned { - output: DecodeFutureTyped, - }, - Decoded { - output: Vec, - }, - Done, -} + vm.mark_public(inner_local).map_err(PrfError::vm)?; + vm.assign(inner_local, state_to_bytes(inner_local_value)) + .map_err(PrfError::vm)?; + vm.commit(inner_local).map_err(PrfError::vm)?; -#[derive(Debug)] -enum InnerPartial { - Decoding(DecodeFutureTyped), - Finished([u32; 8]), + Ok(()) } -impl InnerPartial { - pub(crate) fn try_recv(&mut self) -> Result, PrfError> { - match self { - InnerPartial::Decoding(value) => { - let value = value.try_recv().map_err(PrfError::vm)?; - if let Some(value) = value { - *self = InnerPartial::Finished(value); - } - Ok(value) - } - InnerPartial::Finished(value) => Ok(Some(*value)), - } - } +#[derive(Debug, Clone, Copy)] +struct PHash { + inner_local: Array, + output: Array, } From f2d7b76442a8c6d13b78c2bfe2d3f240e5aa3a96 Mon Sep 17 00:00:00 2001 From: th4s Date: Tue, 6 May 2025 15:58:36 +0200 Subject: [PATCH 28/36] simplify reduced prf even more --- .../hmac-sha256/src/prf/function/reduced.rs | 45 ++++++++++++------- 1 file changed, 28 insertions(+), 17 deletions(-) diff --git a/crates/components/hmac-sha256/src/prf/function/reduced.rs b/crates/components/hmac-sha256/src/prf/function/reduced.rs index b56da46a05..4c56d2db32 100644 --- a/crates/components/hmac-sha256/src/prf/function/reduced.rs +++ b/crates/components/hmac-sha256/src/prf/function/reduced.rs @@ -1,5 +1,7 @@ //! Computes some hashes of the PRF locally. +use std::collections::VecDeque; + use crate::{hmac::hmac_sha256, sha256, state_to_bytes, PrfError}; use mpz_core::bitvec::BitVec; use mpz_hash::sha256::Sha256; @@ -19,8 +21,8 @@ pub(crate) struct PrfFunction { start_seed_label: Vec, iterations: usize, state: PrfState, - a: Vec, - p: Vec, + a: VecDeque, + p: VecDeque, } #[derive(Debug)] @@ -38,7 +40,7 @@ enum PrfState { inner_partial: [u32; 8], a_output: DecodeFutureTyped, }, - ComputeLastP, + FinishLastP, Done, } @@ -95,7 +97,7 @@ impl PrfFunction { }; self.state = PrfState::ComputeA { - iter: 0, + iter: 1, inner_partial, msg: self.start_seed_label.clone(), }; @@ -106,14 +108,13 @@ impl PrfFunction { inner_partial, msg, } => { - let a = &self.a[*iter]; + let a = self.a.pop_front().expect("Prf AHash should be present"); assign_inner_local(vm, a.inner_local, *inner_partial, msg)?; - let a_output = vm.decode(a.output).map_err(PrfError::vm)?; self.state = PrfState::ComputeP { iter: *iter, inner_partial: *inner_partial, - a_output, + a_output: a.output, }; } PrfState::ComputeP { @@ -124,7 +125,7 @@ impl PrfFunction { let Some(output) = a_output.try_recv().map_err(PrfError::vm)? else { return Ok(()); }; - let p = &self.p[*iter]; + let p = self.p.pop_front().expect("Prf PHash should be present"); let mut msg = output.to_vec(); msg.extend_from_slice(&self.start_seed_label); @@ -132,7 +133,7 @@ impl PrfFunction { assign_inner_local(vm, p.inner_local, *inner_partial, &msg)?; if *iter == self.iterations { - self.state = PrfState::ComputeLastP; + self.state = PrfState::FinishLastP; } else { self.state = PrfState::ComputeA { iter: *iter + 1, @@ -141,7 +142,7 @@ impl PrfFunction { } }; } - PrfState::ComputeLastP => self.state = PrfState::Done, + PrfState::FinishLastP => self.state = PrfState::Done, _ => (), } @@ -178,22 +179,24 @@ impl PrfFunction { let mut prf = Self { label, start_seed_label: vec![], - // used for indexing, so we need to subtract one here - iterations: iterations - 1, + iterations, state: PrfState::InnerPartial { inner_partial }, - a: vec![], - p: vec![], + a: VecDeque::new(), + p: VecDeque::new(), }; for _ in 0..iterations { // setup A[i] let inner_local: Array = vm.alloc().map_err(PrfError::vm)?; let output = hmac_sha256(vm, outer_partial.clone(), inner_local)?; - let p_hash = PHash { + + let output = vm.decode(output).map_err(PrfError::vm)?; + let a_hash = AHash { inner_local, output, }; - prf.a.push(p_hash); + + prf.a.push_front(a_hash); // setup P[i] let inner_local: Array = vm.alloc().map_err(PrfError::vm)?; @@ -202,7 +205,7 @@ impl PrfFunction { inner_local, output, }; - prf.p.push(p_hash); + prf.p.push_front(p_hash); } Ok(prf) @@ -225,6 +228,14 @@ fn assign_inner_local( Ok(()) } +/// Like PHash but stores the output as the decoding future because in the reduced Prf we need to +/// decode this output. +#[derive(Debug)] +struct AHash { + inner_local: Array, + output: DecodeFutureTyped, +} + #[derive(Debug, Clone, Copy)] struct PHash { inner_local: Array, From b2b00d541345ae887bd6d9a6a85997add2a98260 Mon Sep 17 00:00:00 2001 From: th4s Date: Tue, 6 May 2025 15:58:58 +0200 Subject: [PATCH 29/36] set reduced prf as default --- crates/components/hmac-sha256/src/config.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/components/hmac-sha256/src/config.rs b/crates/components/hmac-sha256/src/config.rs index 9a9284c711..85914b0f52 100644 --- a/crates/components/hmac-sha256/src/config.rs +++ b/crates/components/hmac-sha256/src/config.rs @@ -11,6 +11,6 @@ pub enum Mode { impl Default for Mode { fn default() -> Self { - Self::Normal + Self::Reduced } } From f99d4260fd5dd81a239fa51622622a60f194cae1 Mon Sep 17 00:00:00 2001 From: th4s Date: Wed, 7 May 2025 13:26:26 +0200 Subject: [PATCH 30/36] temporarily fix commit for mpz --- Cargo.toml | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2801b71f6a..9202626a84 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,19 +69,19 @@ tlsn-tls-core = { path = "crates/tls/core" } tlsn-utils = { git = "https://github.com/tlsnotary/tlsn-utils", rev = "6168663" } tlsn-verifier = { path = "crates/verifier" } -mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } -mpz-memory-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } -mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } -mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } -mpz-vm-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } -mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } -mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } -mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } -mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } -mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } -mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } -mpz-zk = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } -mpz-hash = { git = "https://github.com/privacy-scaling-explorations/mpz", branch = "dev" } +mpz-circuits = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-memory-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-common = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-vm-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-garble = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-garble-core = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-ole = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-ot = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-share-conversion = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-fields = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-zk = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } +mpz-hash = { git = "https://github.com/privacy-scaling-explorations/mpz", rev = "39f64de" } rangeset = { version = "0.2" } serio = { version = "0.2" } From 67fa2a0a3618d030b47992c34e437af855709d7e Mon Sep 17 00:00:00 2001 From: th4s Date: Wed, 7 May 2025 13:55:15 +0200 Subject: [PATCH 31/36] add part of feedback --- crates/components/hmac-sha256/src/prf/function.rs | 5 ++++- .../components/hmac-sha256/src/prf/function/normal.rs | 2 +- .../hmac-sha256/src/prf/function/reduced.rs | 11 ++++------- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/crates/components/hmac-sha256/src/prf/function.rs b/crates/components/hmac-sha256/src/prf/function.rs index 9830e60159..7b830afe6f 100644 --- a/crates/components/hmac-sha256/src/prf/function.rs +++ b/crates/components/hmac-sha256/src/prf/function.rs @@ -145,6 +145,7 @@ mod tests { memory::{binary::U8, Array, MemoryExt, ViewExt}, Execute, }; + use rand::{rngs::ThreadRng, Rng}; const IPAD: [u8; 64] = [0x36; 64]; const OPAD: [u8; 64] = [0x5c; 64]; @@ -162,10 +163,12 @@ mod tests { } async fn test_phash(mode: Mode) { + let mut rng = ThreadRng::default(); + let (mut ctx_a, mut ctx_b) = test_st_context(8); let (mut leader, mut follower) = mock_vm(); - let key: [u8; 32] = std::array::from_fn(|i| i as u8); + let key: [u8; 32] = rng.random(); let start_seed: Vec = vec![42; 64]; let mut label_seed = b"master secret".to_vec(); diff --git a/crates/components/hmac-sha256/src/prf/function/normal.rs b/crates/components/hmac-sha256/src/prf/function/normal.rs index 734ad1b2d5..bbb9ed40f2 100644 --- a/crates/components/hmac-sha256/src/prf/function/normal.rs +++ b/crates/components/hmac-sha256/src/prf/function/normal.rs @@ -110,7 +110,7 @@ impl PrfFunction { assert!(output_len > 0, "cannot compute 0 bytes for prf"); - let iterations = output_len / 32 + ((output_len % 32) != 0) as usize; + let iterations = output_len.div_ceil(32); let msg_len_a = label.len() + seed_len; let seed_label_ref: Vector = vm.alloc_vec(msg_len_a).map_err(PrfError::vm)?; diff --git a/crates/components/hmac-sha256/src/prf/function/reduced.rs b/crates/components/hmac-sha256/src/prf/function/reduced.rs index 4c56d2db32..d92e245352 100644 --- a/crates/components/hmac-sha256/src/prf/function/reduced.rs +++ b/crates/components/hmac-sha256/src/prf/function/reduced.rs @@ -83,10 +83,7 @@ impl PrfFunction { } pub(crate) fn wants_flush(&mut self) -> bool { - if let PrfState::Done = self.state { - return false; - } - true + !matches!(self.state, PrfState::Done) } pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { @@ -169,7 +166,7 @@ impl PrfFunction { ) -> Result { assert!(len > 0, "cannot compute 0 bytes for prf"); - let iterations = len / 32 + ((len % 32) != 0) as usize; + let iterations = len.div_ceil(32); let (inner_partial, _) = inner_partial .state() @@ -228,8 +225,8 @@ fn assign_inner_local( Ok(()) } -/// Like PHash but stores the output as the decoding future because in the reduced Prf we need to -/// decode this output. +/// Like PHash but stores the output as the decoding future because in the +/// reduced Prf we need to decode this output. #[derive(Debug)] struct AHash { inner_local: Array, From 6f4884a7dea0a6d63467bdba1b915c7b70c8744b Mon Sep 17 00:00:00 2001 From: th4s Date: Wed, 7 May 2025 17:18:16 +0200 Subject: [PATCH 32/36] simplify state transition --- crates/components/hmac-sha256/src/prf.rs | 42 +++++++++---------- .../hmac-sha256/src/prf/function.rs | 2 +- .../hmac-sha256/src/prf/function/normal.rs | 16 ++++--- .../hmac-sha256/src/prf/function/reduced.rs | 21 ++++++---- 4 files changed, 45 insertions(+), 36 deletions(-) diff --git a/crates/components/hmac-sha256/src/prf.rs b/crates/components/hmac-sha256/src/prf.rs index 5ee96c6820..fd967a5398 100644 --- a/crates/components/hmac-sha256/src/prf.rs +++ b/crates/components/hmac-sha256/src/prf.rs @@ -180,8 +180,8 @@ impl MpcPrf { } /// Returns if the PRF needs to be flushed and drives the PRF. - pub fn wants_flush(&mut self) -> bool { - let wants_flush = match &mut self.state { + pub fn wants_flush(&self) -> bool { + match &self.state { State::Initialized => false, State::SessionKeys { master_secret, @@ -194,27 +194,7 @@ impl MpcPrf { State::ServerFinished { server_finished } => server_finished.wants_flush(), State::Complete => false, State::Error => false, - }; - - if !wants_flush { - self.state = match self.state.take() { - State::SessionKeys { - client_finished, - server_finished, - .. - } => State::ClientFinished { - client_finished, - server_finished, - }, - State::ClientFinished { - server_finished, .. - } => State::ServerFinished { server_finished }, - State::ServerFinished { .. } => State::Complete, - other => other, - }; } - - wants_flush } /// Flushes the PRF. @@ -239,6 +219,24 @@ impl MpcPrf { _ => (), } + if !self.wants_flush() { + self.state = match self.state.take() { + State::SessionKeys { + client_finished, + server_finished, + .. + } => State::ClientFinished { + client_finished, + server_finished, + }, + State::ClientFinished { + server_finished, .. + } => State::ServerFinished { server_finished }, + State::ServerFinished { .. } => State::Complete, + other => other, + }; + } + Ok(()) } } diff --git a/crates/components/hmac-sha256/src/prf/function.rs b/crates/components/hmac-sha256/src/prf/function.rs index 7b830afe6f..e1e932a02f 100644 --- a/crates/components/hmac-sha256/src/prf/function.rs +++ b/crates/components/hmac-sha256/src/prf/function.rs @@ -104,7 +104,7 @@ impl Prf { Ok(prf) } - pub(crate) fn wants_flush(&mut self) -> bool { + pub(crate) fn wants_flush(&self) -> bool { match self { Prf::Reduced(prf) => prf.wants_flush(), Prf::Normal(prf) => prf.wants_flush(), diff --git a/crates/components/hmac-sha256/src/prf/function/normal.rs b/crates/components/hmac-sha256/src/prf/function/normal.rs index bbb9ed40f2..ec931f24ef 100644 --- a/crates/components/hmac-sha256/src/prf/function/normal.rs +++ b/crates/components/hmac-sha256/src/prf/function/normal.rs @@ -16,7 +16,7 @@ pub(crate) struct PrfFunction { label: &'static [u8], state: State, // The start seed and the label, e.g. client_random + server_random + "master_secret". - start_seed_label: Vec, + start_seed_label: Option>, a: Vec, p: Vec, } @@ -60,10 +60,11 @@ impl PrfFunction { } pub(crate) fn wants_flush(&self) -> bool { - match self.state { + let is_computing = match self.state { State::Computing => true, State::Finished => false, - } + }; + is_computing && self.start_seed_label.is_some() } pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { @@ -71,7 +72,10 @@ impl PrfFunction { let a = self.a.first().expect("prf should be allocated"); let msg = *a.msg.first().expect("message for prf should be present"); - let msg_value = self.start_seed_label.clone(); + let msg_value = self + .start_seed_label + .clone() + .expect("Start seed should have been set"); vm.assign(msg, msg_value).map_err(PrfError::vm)?; vm.commit(msg).map_err(PrfError::vm)?; @@ -85,7 +89,7 @@ impl PrfFunction { let mut start_seed_label = self.label.to_vec(); start_seed_label.extend_from_slice(&seed); - self.start_seed_label = start_seed_label; + self.start_seed_label = Some(start_seed_label); } pub(crate) fn output(&self) -> Vec> { @@ -103,7 +107,7 @@ impl PrfFunction { let mut prf = Self { label, state: State::Computing, - start_seed_label: vec![], + start_seed_label: None, a: vec![], p: vec![], }; diff --git a/crates/components/hmac-sha256/src/prf/function/reduced.rs b/crates/components/hmac-sha256/src/prf/function/reduced.rs index d92e245352..2120a82937 100644 --- a/crates/components/hmac-sha256/src/prf/function/reduced.rs +++ b/crates/components/hmac-sha256/src/prf/function/reduced.rs @@ -18,7 +18,7 @@ pub(crate) struct PrfFunction { // The label, e.g. "master secret". label: &'static [u8], // The start seed and the label, e.g. client_random + server_random + "master_secret". - start_seed_label: Vec, + start_seed_label: Option>, iterations: usize, state: PrfState, a: VecDeque, @@ -82,8 +82,8 @@ impl PrfFunction { Self::alloc(vm, Self::SF_LABEL, outer_partial, inner_partial, 12) } - pub(crate) fn wants_flush(&mut self) -> bool { - !matches!(self.state, PrfState::Done) + pub(crate) fn wants_flush(&self) -> bool { + !matches!(self.state, PrfState::Done) && self.start_seed_label.is_some() } pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { @@ -96,7 +96,10 @@ impl PrfFunction { self.state = PrfState::ComputeA { iter: 1, inner_partial, - msg: self.start_seed_label.clone(), + msg: self + .start_seed_label + .clone() + .expect("Start seed should have been set"), }; self.flush(vm)?; } @@ -125,7 +128,11 @@ impl PrfFunction { let p = self.p.pop_front().expect("Prf PHash should be present"); let mut msg = output.to_vec(); - msg.extend_from_slice(&self.start_seed_label); + msg.extend_from_slice( + self.start_seed_label + .as_ref() + .expect("Start seed should have been set"), + ); assign_inner_local(vm, p.inner_local, *inner_partial, &msg)?; @@ -150,7 +157,7 @@ impl PrfFunction { let mut start_seed_label = self.label.to_vec(); start_seed_label.extend_from_slice(&seed); - self.start_seed_label = start_seed_label; + self.start_seed_label = Some(start_seed_label); } pub(crate) fn output(&self) -> Vec> { @@ -175,7 +182,7 @@ impl PrfFunction { let mut prf = Self { label, - start_seed_label: vec![], + start_seed_label: None, iterations, state: PrfState::InnerPartial { inner_partial }, a: VecDeque::new(), From 7a60ef08ed8afa4161d2d4fe0d21a992558bb92c Mon Sep 17 00:00:00 2001 From: th4s Date: Wed, 7 May 2025 17:27:42 +0200 Subject: [PATCH 33/36] adapt comment --- crates/components/hmac-sha256/src/prf.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/components/hmac-sha256/src/prf.rs b/crates/components/hmac-sha256/src/prf.rs index fd967a5398..949242fba7 100644 --- a/crates/components/hmac-sha256/src/prf.rs +++ b/crates/components/hmac-sha256/src/prf.rs @@ -179,7 +179,7 @@ impl MpcPrf { Ok(()) } - /// Returns if the PRF needs to be flushed and drives the PRF. + /// Returns if the PRF needs to be flushed. pub fn wants_flush(&self) -> bool { match &self.state { State::Initialized => false, From cf15673a1127b505a31021a49b3989b1cf966d10 Mon Sep 17 00:00:00 2001 From: th4s Date: Wed, 7 May 2025 17:50:45 +0200 Subject: [PATCH 34/36] improve state transition in flush --- crates/components/hmac-sha256/src/prf.rs | 70 +++++++++++++++--------- 1 file changed, 44 insertions(+), 26 deletions(-) diff --git a/crates/components/hmac-sha256/src/prf.rs b/crates/components/hmac-sha256/src/prf.rs index 949242fba7..eb2eea7a9f 100644 --- a/crates/components/hmac-sha256/src/prf.rs +++ b/crates/components/hmac-sha256/src/prf.rs @@ -199,44 +199,62 @@ impl MpcPrf { /// Flushes the PRF. pub fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { - match &mut self.state { + let state = match self.state.take() { State::SessionKeys { - master_secret, - key_expansion, - .. + client_random, + mut master_secret, + mut key_expansion, + client_finished, + server_finished, } => { master_secret.flush(vm)?; key_expansion.flush(vm)?; + + if !master_secret.wants_flush() && !key_expansion.wants_flush() { + State::ClientFinished { + client_finished, + server_finished, + } + } else { + State::SessionKeys { + client_random, + master_secret, + key_expansion, + client_finished, + server_finished, + } + } } State::ClientFinished { - client_finished, .. + mut client_finished, + server_finished, } => { client_finished.flush(vm)?; + + if !client_finished.wants_flush() { + State::ServerFinished { server_finished } + } else { + State::ClientFinished { + client_finished, + server_finished, + } + } } - State::ServerFinished { server_finished } => { + State::ServerFinished { + mut server_finished, + } => { server_finished.flush(vm)?; - } - _ => (), - } - if !self.wants_flush() { - self.state = match self.state.take() { - State::SessionKeys { - client_finished, - server_finished, - .. - } => State::ClientFinished { - client_finished, - server_finished, - }, - State::ClientFinished { - server_finished, .. - } => State::ServerFinished { server_finished }, - State::ServerFinished { .. } => State::Complete, - other => other, - }; - } + if !server_finished.wants_flush() { + State::Complete + } else { + State::ServerFinished { server_finished } + } + } + other => other, + }; + self.state = state; Ok(()) } } From 6b8ac8740923fc28d3de4d4b0b63d3f9aae9d748 Mon Sep 17 00:00:00 2001 From: th4s Date: Wed, 7 May 2025 17:55:50 +0200 Subject: [PATCH 35/36] simplify flush --- crates/components/hmac-sha256/src/prf.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crates/components/hmac-sha256/src/prf.rs b/crates/components/hmac-sha256/src/prf.rs index eb2eea7a9f..f928f48bdb 100644 --- a/crates/components/hmac-sha256/src/prf.rs +++ b/crates/components/hmac-sha256/src/prf.rs @@ -199,7 +199,7 @@ impl MpcPrf { /// Flushes the PRF. pub fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { - let state = match self.state.take() { + self.state = match self.state.take() { State::SessionKeys { client_random, mut master_secret, @@ -254,7 +254,6 @@ impl MpcPrf { other => other, }; - self.state = state; Ok(()) } } From c732b24cb254a7288c5bd4ca11fa88e66b35caf7 Mon Sep 17 00:00:00 2001 From: sinu <65924192+sinui0@users.noreply.github.com> Date: Tue, 13 May 2025 09:15:44 -0700 Subject: [PATCH 36/36] fix wasm prover config --- crates/wasm/src/prover/config.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/wasm/src/prover/config.rs b/crates/wasm/src/prover/config.rs index e4e5aed76f..99028510b4 100644 --- a/crates/wasm/src/prover/config.rs +++ b/crates/wasm/src/prover/config.rs @@ -34,6 +34,10 @@ impl From for tlsn_prover::ProverConfig { builder.max_recv_records(value); } + if let Some(value) = value.defer_decryption_from_start { + builder.defer_decryption_from_start(value); + } + builder.network(value.network); let protocol_config = builder.build().unwrap(); @@ -42,10 +46,6 @@ impl From for tlsn_prover::ProverConfig { .server_name(value.server_name.as_ref()) .protocol_config(protocol_config); - if let Some(value) = value.defer_decryption_from_start { - builder.defer_decryption_from_start(value); - } - builder.build().unwrap() } }