diff --git a/Cargo.lock b/Cargo.lock index db21af1ce5..13b8a10b91 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7409,6 +7409,7 @@ version = "0.1.0-alpha.13-pre" dependencies = [ "criterion", "hex", + "hmac", "mpz-circuits", "mpz-common", "mpz-core", @@ -7418,6 +7419,7 @@ dependencies = [ "mpz-vm-core", "rand 0.9.2", "ring 0.17.14", + "rstest", "sha2", "thiserror 1.0.69", "tokio", diff --git a/Cargo.toml b/Cargo.toml index 73e23cdb1c..33fa2ef91b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,19 +66,19 @@ tlsn-harness-runner = { path = "crates/harness/runner" } tlsn-wasm = { path = "crates/wasm" } tlsn = { path = "crates/tlsn" } -mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "3d90b6c" } -mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "3d90b6c" } -mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "3d90b6c" } -mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "3d90b6c" } -mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "3d90b6c" } -mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "3d90b6c" } -mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "3d90b6c" } -mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "3d90b6c" } -mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "3d90b6c" } -mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "3d90b6c" } -mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "3d90b6c" } -mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "3d90b6c" } -mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "3d90b6c" } +mpz-circuits = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" } +mpz-memory-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" } +mpz-common = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" } +mpz-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" } +mpz-vm-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" } +mpz-garble = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" } +mpz-garble-core = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" } +mpz-ole = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" } +mpz-ot = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" } +mpz-share-conversion = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" } +mpz-fields = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" } +mpz-zk = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" } +mpz-hash = { git = "https://github.com/privacy-ethereum/mpz", rev = "8a57d98" } rangeset = { version = "0.2" } serio = { version = "0.2" } diff --git a/crates/components/hmac-sha256/Cargo.toml b/crates/components/hmac-sha256/Cargo.toml index 1213e442be..d90f748192 100644 --- a/crates/components/hmac-sha256/Cargo.toml +++ b/crates/components/hmac-sha256/Cargo.toml @@ -20,9 +20,10 @@ mpz-core = { workspace = true } mpz-circuits = { workspace = true } mpz-hash = { workspace = true } +rand = { workspace = true } +sha2 = { workspace = true } thiserror = { workspace = true } tracing = { workspace = true } -sha2 = { workspace = true } [dev-dependencies] mpz-ot = { workspace = true, features = ["ideal"] } @@ -30,11 +31,17 @@ mpz-garble = { workspace = true } mpz-common = { workspace = true, features = ["test-utils"] } criterion = { workspace = true, features = ["async_tokio"] } -tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } -rand = { workspace = true } hex = { workspace = true } +hmac = { workspace = true } ring = { workspace = true } +rstest = { workspace = true } +sha2 = { workspace = true } +tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread"] } [[bench]] -name = "prf" +name = "tls12" harness = false + +[[bench]] +name = "tls13" +harness = false \ No newline at end of file diff --git a/crates/components/hmac-sha256/benches/prf.rs b/crates/components/hmac-sha256/benches/tls12.rs similarity index 91% rename from crates/components/hmac-sha256/benches/prf.rs rename to crates/components/hmac-sha256/benches/tls12.rs index cc966741e4..02f74a7c0a 100644 --- a/crates/components/hmac-sha256/benches/prf.rs +++ b/crates/components/hmac-sha256/benches/tls12.rs @@ -2,7 +2,7 @@ use criterion::{criterion_group, criterion_main, Criterion}; -use hmac_sha256::{Mode, MpcPrf}; +use hmac_sha256::{Mode, Tls12Prf}; use mpz_common::context::test_mt_context; use mpz_garble::protocol::semihonest::{Evaluator, Garbler}; use mpz_ot::ideal::cot::ideal_cot; @@ -15,20 +15,22 @@ use rand::{rngs::StdRng, SeedableRng}; #[allow(clippy::unit_arg)] fn criterion_benchmark(c: &mut Criterion) { - let mut group = c.benchmark_group("prf"); + let mut group = c.benchmark_group("tls12"); group.sample_size(10); let rt = tokio::runtime::Runtime::new().unwrap(); - group.bench_function("prf_normal", |b| b.to_async(&rt).iter(|| prf(Mode::Normal))); - group.bench_function("prf_reduced", |b| { - b.to_async(&rt).iter(|| prf(Mode::Reduced)) + group.bench_function("tls12_normal", |b| { + b.to_async(&rt).iter(|| tls12(Mode::Normal)) + }); + group.bench_function("tls12_reduced", |b| { + b.to_async(&rt).iter(|| tls12(Mode::Reduced)) }); } criterion_group!(benches, criterion_benchmark); criterion_main!(benches); -async fn prf(mode: Mode) { +async fn tls12(mode: Mode) { let mut rng = StdRng::seed_from_u64(0); let pms = [42u8; 32]; @@ -55,8 +57,8 @@ async fn prf(mode: Mode) { follower_vm.assign(follower_pms, pms).unwrap(); follower_vm.commit(follower_pms).unwrap(); - let mut leader = MpcPrf::new(mode); - let mut follower = MpcPrf::new(mode); + let mut leader = Tls12Prf::new(mode); + let mut follower = Tls12Prf::new(mode); let leader_output = leader.alloc(&mut leader_vm, leader_pms).unwrap(); let follower_output = follower.alloc(&mut follower_vm, follower_pms).unwrap(); diff --git a/crates/components/hmac-sha256/benches/tls13.rs b/crates/components/hmac-sha256/benches/tls13.rs new file mode 100644 index 0000000000..b4d948d11d --- /dev/null +++ b/crates/components/hmac-sha256/benches/tls13.rs @@ -0,0 +1,139 @@ +#![allow(clippy::let_underscore_future)] + +use criterion::{criterion_group, criterion_main, Criterion}; + +use hmac_sha256::{Mode, Role, Tls13KeySched}; +use mpz_common::context::test_mt_context; +use mpz_garble::protocol::semihonest::{Evaluator, Garbler}; +use mpz_ot::ideal::cot::ideal_cot; +use mpz_vm_core::{ + memory::{ + binary::{Binary, U8}, + correlated::Delta, + Array, + }, + prelude::*, + Execute, Vm, +}; +use rand::{rngs::StdRng, SeedableRng}; + +#[allow(clippy::unit_arg)] +fn criterion_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("tls13"); + group.sample_size(10); + let rt = tokio::runtime::Runtime::new().unwrap(); + + group.bench_function("tls13_normal", |b| { + b.to_async(&rt).iter(|| tls13(Mode::Normal)) + }); + group.bench_function("tls13_reduced", |b| { + b.to_async(&rt).iter(|| tls13(Mode::Reduced)) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); + +async fn tls13(mode: Mode) { + let mut rng = StdRng::seed_from_u64(0); + + let pms = [42u8; 32]; + + let (mut leader_exec, mut follower_exec) = test_mt_context(8); + let mut leader_ctx = leader_exec.new_context().await.unwrap(); + let mut follower_ctx = follower_exec.new_context().await.unwrap(); + + let delta = Delta::random(&mut rng); + let (ot_send, ot_recv) = ideal_cot(delta.into_inner()); + + let mut leader_vm = Garbler::new(ot_send, [0u8; 16], delta); + let mut follower_vm = Evaluator::new(ot_recv); + + fn setup_ks( + vm: &mut (dyn Vm + Send), + pms: [u8; 32], + mode: Mode, + role: Role, + ) -> Tls13KeySched { + let secret: Array = vm.alloc().unwrap(); + vm.mark_public(secret).unwrap(); + vm.assign(secret, pms).unwrap(); + vm.commit(secret).unwrap(); + + let mut ks = Tls13KeySched::new(mode, role); + ks.alloc(vm, secret).unwrap(); + ks + } + + let mut leader_ks = setup_ks(&mut leader_vm, pms, mode, Role::Leader); + let mut follower_ks = setup_ks(&mut follower_vm, pms, mode, Role::Follower); + + while leader_ks.wants_flush() || follower_ks.wants_flush() { + tokio::try_join!( + async { + leader_ks.flush(&mut leader_vm).unwrap(); + leader_vm.execute_all(&mut leader_ctx).await + }, + async { + follower_ks.flush(&mut follower_vm).unwrap(); + follower_vm.execute_all(&mut follower_ctx).await + } + ) + .unwrap(); + } + + let hello_hash = [1u8; 32]; + + leader_ks.set_hello_hash(hello_hash).unwrap(); + follower_ks.set_hello_hash(hello_hash).unwrap(); + + while leader_ks.wants_flush() || follower_ks.wants_flush() { + tokio::try_join!( + async { + leader_ks.flush(&mut leader_vm).unwrap(); + leader_vm.execute_all(&mut leader_ctx).await + }, + async { + follower_ks.flush(&mut follower_vm).unwrap(); + follower_vm.execute_all(&mut follower_ctx).await + } + ) + .unwrap(); + } + + leader_ks.continue_to_app_keys().unwrap(); + follower_ks.continue_to_app_keys().unwrap(); + + while leader_ks.wants_flush() || follower_ks.wants_flush() { + tokio::try_join!( + async { + leader_ks.flush(&mut leader_vm).unwrap(); + leader_vm.execute_all(&mut leader_ctx).await + }, + async { + follower_ks.flush(&mut follower_vm).unwrap(); + follower_vm.execute_all(&mut follower_ctx).await + } + ) + .unwrap(); + } + + let handshake_hash = [2u8; 32]; + + leader_ks.set_handshake_hash(handshake_hash).unwrap(); + follower_ks.set_handshake_hash(handshake_hash).unwrap(); + + while leader_ks.wants_flush() || follower_ks.wants_flush() { + tokio::try_join!( + async { + leader_ks.flush(&mut leader_vm).unwrap(); + leader_vm.execute_all(&mut leader_ctx).await + }, + async { + follower_ks.flush(&mut follower_vm).unwrap(); + follower_vm.execute_all(&mut follower_ctx).await + } + ) + .unwrap(); + } +} diff --git a/crates/components/hmac-sha256/src/config.rs b/crates/components/hmac-sha256/src/config.rs index 90834ac646..6627248957 100644 --- a/crates/components/hmac-sha256/src/config.rs +++ b/crates/components/hmac-sha256/src/config.rs @@ -1,10 +1,10 @@ -//! PRF modes. +//! Modes of operation. -/// Modes for the PRF. -#[derive(Debug, Clone, Copy)] +/// Modes for the TLS 1.2 PRF and the TLS 1.3 key schedule. +#[derive(Debug, Clone, Copy, PartialEq)] pub enum Mode { /// Computes some hashes locally. Reduced, - /// Computes the whole PRF in MPC. + /// Computes the whole function in MPC. Normal, } diff --git a/crates/components/hmac-sha256/src/error.rs b/crates/components/hmac-sha256/src/error.rs index 6ed809f1bb..89b93e3075 100644 --- a/crates/components/hmac-sha256/src/error.rs +++ b/crates/components/hmac-sha256/src/error.rs @@ -3,15 +3,15 @@ use std::error::Error; use mpz_hash::sha256::Sha256Error; -/// A PRF error. +/// An error type used by the functionalities of this crate. #[derive(Debug, thiserror::Error)] -pub struct PrfError { +pub struct FError { kind: ErrorKind, #[source] source: Option>, } -impl PrfError { +impl FError { pub(crate) fn new(kind: ErrorKind, source: E) -> Self where E: Into>, @@ -34,7 +34,7 @@ impl PrfError { } } -impl From for PrfError { +impl From for FError { fn from(value: Sha256Error) -> Self { Self::new(ErrorKind::Hash, value) } @@ -47,7 +47,7 @@ pub(crate) enum ErrorKind { Hash, } -impl fmt::Display for PrfError { +impl fmt::Display for FError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self.kind { ErrorKind::Vm => write!(f, "vm error")?, diff --git a/crates/components/hmac-sha256/src/hmac.rs b/crates/components/hmac-sha256/src/hmac.rs index a9cd1abcd1..a185627d9d 100644 --- a/crates/components/hmac-sha256/src/hmac.rs +++ b/crates/components/hmac-sha256/src/hmac.rs @@ -2,7 +2,7 @@ //! //! HMAC-SHA256 is defined as //! -//! HMAC(m) = H((key' xor opad) || H((key' xor ipad) || m)) +//! HMAC(key, m) = H((key' xor opad) || H((key' xor ipad) || m)) //! //! * H - SHA256 hash function //! * key' - key padded with zero bytes to 64 bytes (we do not support longer @@ -11,167 +11,307 @@ //! * ipad - 64 bytes of 0x36 //! * m - message //! -//! This implementation computes HMAC-SHA256 using intermediate results -//! `outer_partial` and `inner_local`. Then HMAC(m) = H(outer_partial || -//! inner_local) +//! We describe HMAC in terms of the SHA-256 compression function +//! C(IV, m), where `IV` is the hash state, `m` is the input block, +//! and the output is the updated state. //! -//! * `outer_partial` - key' xor opad -//! * `inner_local` - H((key' xor ipad) || m) +//! HMAC(m) = C( C(IV, key' xor opad), C( C(IV, key' xor ipad), m) ) +//! +//! Throughout this crate we use the following terminology for +//! intermediate states: +//! +//! * `outer_partial` — C(IV, key' ⊕ opad) +//! * `inner_partial` — C(IV, key' ⊕ ipad) +//! * `inner_local` — C(inner_partial, m) +//! +//! The final value is then computed as: +//! +//! HMAC(m) = C(outer_partial, inner_local) + +use std::sync::Arc; +use crate::{ + hmac::{normal::HmacNormal, reduced::HmacReduced}, + sha256, state_to_bytes, Mode, +}; +use mpz_circuits::circuits::xor; use mpz_hash::sha256::Sha256; use mpz_vm_core::{ memory::{ binary::{Binary, U8}, - Array, + Array, MemoryExt, Vector, ViewExt, }, - Vm, + Call, CallableExt, Vm, }; -use crate::PrfError; +use crate::FError; + +pub(crate) mod clear; +pub(crate) mod normal; +pub(crate) mod reduced; +/// Inner padding of HMAC. pub(crate) const IPAD: [u8; 64] = [0x36; 64]; +/// Outer padding of HMAC. pub(crate) const OPAD: [u8; 64] = [0x5c; 64]; +/// Initial IV of SHA256. +pub(crate) const SHA256_IV: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +/// Functionality for HMAC computation with a private key and a public message. +#[derive(Debug)] +#[allow(dead_code)] +pub(crate) enum Hmac { + Reduced(reduced::HmacReduced), + Normal(normal::HmacNormal), +} -/// Computes HMAC-SHA256 +impl Hmac { + /// Allocates a new HMAC with the given `key`. + pub(crate) fn alloc( + vm: &mut dyn Vm, + key: Vector, + mode: Mode, + ) -> Result { + match mode { + Mode::Reduced => Ok(Hmac::Reduced(HmacReduced::alloc(vm, key)?)), + Mode::Normal => Ok(Hmac::Normal(HmacNormal::alloc(vm, key)?)), + } + } + + /// Whether this functionality needs to be flushed. + #[allow(dead_code)] + pub(crate) fn wants_flush(&self) -> bool { + match self { + Hmac::Reduced(hmac) => hmac.wants_flush(), + Hmac::Normal(hmac) => hmac.wants_flush(), + } + } + + /// Flushes the functionality. + #[allow(dead_code)] + pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), FError> { + match self { + Hmac::Reduced(hmac) => hmac.flush(vm), + Hmac::Normal(hmac) => hmac.flush(), + } + } + + /// Returns HMAC output. + #[allow(dead_code)] + pub(crate) fn output(&self) -> Result, FError> { + match self { + Hmac::Reduced(hmac) => Ok(hmac.output()), + Hmac::Normal(hmac) => hmac.output(), + } + } + + /// Creates a new allocated instance of HMAC from another instance. + pub(crate) fn from_other(vm: &mut dyn Vm, other: &Self) -> Result { + match other { + Hmac::Reduced(hmac) => Ok(Hmac::Reduced(HmacReduced::from_other(vm, hmac)?)), + Hmac::Normal(hmac) => Ok(Hmac::Normal(HmacNormal::from_other(hmac)?)), + } + } +} + +/// Computes HMAC-SHA256. /// /// # Arguments /// /// * `vm` - The virtual machine. -/// * `outer_partial` - (key' xor opad) -/// * `inner_local` - H((key' xor ipad) || m) +/// * `outer_partial` - outer_partial. +/// * `inner_local` - inner_local. pub(crate) fn hmac_sha256( vm: &mut dyn Vm, mut outer_partial: Sha256, inner_local: Array, -) -> Result, PrfError> { +) -> Result, FError> { outer_partial.update(&inner_local.into()); outer_partial.compress(vm)?; - outer_partial.finalize(vm).map_err(PrfError::from) + outer_partial.finalize(vm).map_err(FError::from) +} + +/// Depending on the provided `mask` computes and returns outer_partial or +/// inner_partial for HMAC-SHA256. +/// +/// # Arguments +/// +/// * `vm` - Virtual machine. +/// * `key` - Key to pad and xor. +/// * `mask`- Mask used for padding. +fn compute_partial( + vm: &mut dyn Vm, + key: Vector, + mask: [u8; 64], +) -> Result { + let xor = Arc::new(xor(8 * 64)); + + let additional_len = 64 - key.len(); + let padding = vec![0_u8; additional_len]; + + let padding_ref: Vector = vm.alloc_vec(additional_len).map_err(FError::vm)?; + vm.mark_public(padding_ref).map_err(FError::vm)?; + vm.assign(padding_ref, padding).map_err(FError::vm)?; + vm.commit(padding_ref).map_err(FError::vm)?; + + let mask_ref: Array = vm.alloc().map_err(FError::vm)?; + vm.mark_public(mask_ref).map_err(FError::vm)?; + vm.assign(mask_ref, mask).map_err(FError::vm)?; + vm.commit(mask_ref).map_err(FError::vm)?; + + let xor = Call::builder(xor) + .arg(key) + .arg(padding_ref) + .arg(mask_ref) + .build() + .map_err(FError::vm)?; + let key_padded: Vector = vm.call(xor).map_err(FError::vm)?; + + let mut sha = Sha256::new_with_init(vm)?; + sha.update(&key_padded); + sha.compress(vm)?; + Ok(sha) +} + +/// Computes and assigns inner_local. +/// +/// # Arguments +/// +/// * `vm` - Virtual machine. +/// * `inner_local` - VM reference to assign to. +/// * `inner_partial` - inner_partial. +/// * `msg` - Message to be compressed. +pub(crate) fn assign_inner_local( + vm: &mut dyn Vm, + inner_local: Array, + inner_partial: [u32; 8], + msg: &[u8], +) -> Result<(), FError> { + let inner_local_value = sha256(inner_partial, 64, msg); + + vm.assign(inner_local, state_to_bytes(inner_local_value)) + .map_err(FError::vm)?; + vm.commit(inner_local).map_err(FError::vm)?; + + Ok(()) } #[cfg(test)] mod tests { - use crate::{ - hmac::hmac_sha256, - sha256, state_to_bytes, - test_utils::{compute_inner_local, compute_outer_partial, mock_vm}, - }; + use super::*; + use crate::test_utils::mock_vm; + use hmac::{Hmac as HmacReference, Mac}; use mpz_common::context::test_st_context; - use mpz_hash::sha256::Sha256; use mpz_vm_core::{ - memory::{ - binary::{U32, U8}, - Array, MemoryExt, ViewExt, - }, + memory::{MemoryExt, ViewExt}, Execute, }; + use rand::{rngs::StdRng, Rng, SeedableRng}; + use rstest::*; + use sha2::Sha256; - #[test] - fn test_hmac_reference() { - let (inputs, references) = test_fixtures(); + #[rstest] + #[case::normal(Mode::Normal)] + #[case::reduced(Mode::Reduced)] + #[tokio::test] + async fn test_hmac(#[case] mode: Mode) { + let mut rng = StdRng::from_seed([2; 32]); - for (input, &reference) in inputs.iter().zip(references.iter()) { - let outer_partial = compute_outer_partial(input.0.clone()); - let inner_local = compute_inner_local(input.0.clone(), &input.1); + for _ in 0..10 { + let key: [u8; 32] = rng.random(); + let msg: [u8; 32] = rng.random(); - let hmac = sha256(outer_partial, 64, &state_to_bytes(inner_local)); + let (mut ctx_a, mut ctx_b) = test_st_context(8); + let (mut leader, mut follower) = mock_vm(); - assert_eq!(state_to_bytes(hmac), reference); - } - } + let vm = &mut leader; + let key_ref = vm.alloc_vec(32).unwrap(); + vm.mark_public(key_ref).unwrap(); + vm.assign(key_ref, key.to_vec()).unwrap(); + vm.commit(key_ref).unwrap(); + let mut hmac_leader = Hmac::alloc(vm, key_ref, mode).unwrap(); - #[tokio::test] - async fn test_hmac_circuit() { - let (mut ctx_a, mut ctx_b) = test_st_context(8); - let (mut leader, mut follower) = mock_vm(); - - let (inputs, references) = test_fixtures(); - for (input, &reference) in inputs.iter().zip(references.iter()) { - let outer_partial = compute_outer_partial(input.0.clone()); - let inner_local = compute_inner_local(input.0.clone(), &input.1); - - let outer_partial_leader: Array = leader.alloc().unwrap(); - leader.mark_public(outer_partial_leader).unwrap(); - leader.assign(outer_partial_leader, outer_partial).unwrap(); - leader.commit(outer_partial_leader).unwrap(); - - let inner_local_leader: Array = leader.alloc().unwrap(); - leader.mark_public(inner_local_leader).unwrap(); - leader - .assign(inner_local_leader, state_to_bytes(inner_local)) - .unwrap(); - leader.commit(inner_local_leader).unwrap(); - - let hmac_leader = hmac_sha256( - &mut leader, - Sha256::new_from_state(outer_partial_leader, 1), - inner_local_leader, - ) - .unwrap(); - let hmac_leader = leader.decode(hmac_leader).unwrap(); - - let outer_partial_follower: Array = follower.alloc().unwrap(); - follower.mark_public(outer_partial_follower).unwrap(); - follower - .assign(outer_partial_follower, outer_partial) - .unwrap(); - follower.commit(outer_partial_follower).unwrap(); - - let inner_local_follower: Array = follower.alloc().unwrap(); - follower.mark_public(inner_local_follower).unwrap(); - follower - .assign(inner_local_follower, state_to_bytes(inner_local)) - .unwrap(); - follower.commit(inner_local_follower).unwrap(); - - let hmac_follower = hmac_sha256( - &mut follower, - Sha256::new_from_state(outer_partial_follower, 1), - inner_local_follower, - ) - .unwrap(); - let hmac_follower = follower.decode(hmac_follower).unwrap(); + if mode == Mode::Reduced { + if let Hmac::Reduced(ref mut hmac) = hmac_leader { + hmac.set_msg(&msg).unwrap(); + }; + } else if let Hmac::Normal(ref mut hmac) = hmac_leader { + let msg_ref = vm.alloc_vec(msg.len()).unwrap(); + vm.mark_public(msg_ref).unwrap(); + vm.assign(msg_ref, msg.to_vec()).unwrap(); + vm.commit(msg_ref).unwrap(); + hmac.set_msg(vm, &[msg_ref]).unwrap(); + } + let leader_out = hmac_leader.output().unwrap(); + let mut leader_out = vm.decode(leader_out).unwrap(); - let (hmac_leader, hmac_follower) = tokio::try_join!( + let vm = &mut follower; + let key_ref = vm.alloc_vec(32).unwrap(); + vm.mark_public(key_ref).unwrap(); + vm.assign(key_ref, key.to_vec()).unwrap(); + vm.commit(key_ref).unwrap(); + let mut hmac_follower = Hmac::alloc(vm, key_ref, mode).unwrap(); + + if mode == Mode::Reduced { + if let Hmac::Reduced(ref mut hmac) = hmac_follower { + hmac.set_msg(&msg).unwrap(); + }; + } else if let Hmac::Normal(ref mut hmac) = hmac_follower { + let msg_ref = vm.alloc_vec(msg.len()).unwrap(); + vm.mark_public(msg_ref).unwrap(); + vm.assign(msg_ref, msg.to_vec()).unwrap(); + vm.commit(msg_ref).unwrap(); + hmac.set_msg(vm, &[msg_ref]).unwrap(); + } + let follower_out = hmac_follower.output().unwrap(); + let mut follower_out = vm.decode(follower_out).unwrap(); + + tokio::try_join!( async { + assert!(hmac_leader.wants_flush()); + hmac_leader.flush(&mut leader).unwrap(); leader.execute_all(&mut ctx_a).await.unwrap(); - hmac_leader.await + + // In reduced mode two flushes are required. + if mode == Mode::Reduced { + assert!(hmac_leader.wants_flush()); + hmac_leader.flush(&mut leader).unwrap(); + leader.execute_all(&mut ctx_a).await.unwrap(); + } + + assert!(!hmac_leader.wants_flush()); + + Ok::<(), Box>(()) }, async { + assert!(hmac_follower.wants_flush()); + hmac_follower.flush(&mut follower).unwrap(); follower.execute_all(&mut ctx_b).await.unwrap(); - hmac_follower.await + + // On reduced mode two flushes are required. + if mode == Mode::Reduced { + assert!(hmac_follower.wants_flush()); + hmac_follower.flush(&mut follower).unwrap(); + follower.execute_all(&mut ctx_b).await.unwrap(); + } + + assert!(!hmac_follower.wants_flush()); + + Ok::<(), Box>(()) } ) .unwrap(); - assert_eq!(hmac_leader, hmac_follower); - assert_eq!(hmac_leader, reference); - } - } + let leader_out = leader_out.try_recv().unwrap().unwrap(); + let follower_out = follower_out.try_recv().unwrap().unwrap(); - #[allow(clippy::type_complexity)] - fn test_fixtures() -> (Vec<(Vec, Vec)>, Vec<[u8; 32]>) { - let test_vectors: Vec<(Vec, Vec)> = vec![ - ( - hex::decode("0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b").unwrap(), - hex::decode("4869205468657265").unwrap(), - ), - ( - hex::decode("4a656665").unwrap(), - hex::decode("7768617420646f2079612077616e7420666f72206e6f7468696e673f").unwrap(), - ), - ]; - let expected: Vec<[u8; 32]> = vec![ - hex::decode("b0344c61d8db38535ca8afceaf0bf12b881dc200c9833da726e9376c2e32cff7") - .unwrap() - .try_into() - .unwrap(), - hex::decode("5bdcc146bf60754e6a042426089575c75a003f089d2739839dec58b964ec3843") - .unwrap() - .try_into() - .unwrap(), - ]; - - (test_vectors, expected) + let mut hmac_ref = HmacReference::::new_from_slice(&key).unwrap(); + hmac_ref.update(&msg); + + assert_eq!(leader_out, follower_out); + assert_eq!(leader_out, *hmac_ref.finalize().into_bytes()); + } } } diff --git a/crates/components/hmac-sha256/src/hmac/clear.rs b/crates/components/hmac-sha256/src/hmac/clear.rs new file mode 100644 index 0000000000..5da0ed7c9c --- /dev/null +++ b/crates/components/hmac-sha256/src/hmac/clear.rs @@ -0,0 +1,72 @@ +//! Computation of HMAC-SHA256 on cleartext values. +use crate::{ + compress_256, + hmac::{IPAD, OPAD, SHA256_IV}, + sha256, state_to_bytes, +}; + +/// Depending on the provided `mask` computes and returns outer_partial or +/// inner_partial for HMAC-SHA256. +fn compute_partial(key: &[u8], mask: &[u8; 64]) -> [u32; 8] { + assert!(key.len() <= 64); + let mut key = key.to_vec(); + + key.resize(64, 0_u8); + let key_padded: [u8; 64] = key + .into_iter() + .zip(mask) + .map(|(b, mask)| b ^ mask) + .collect::>() + .try_into() + .expect("output length is 64 bytes"); + + compress_256(SHA256_IV, &key_padded) +} + +/// Computes and returns inner_partial for HMAC-SHA256. +pub(crate) fn compute_inner_partial(key: &[u8]) -> [u32; 8] { + compute_partial(key, &IPAD) +} + +/// Computes and returns outer_partial for HMAC-SHA256. +pub(crate) fn compute_outer_partial(key: &[u8]) -> [u32; 8] { + compute_partial(key, &OPAD) +} + +/// Computes and returns inner_local for HMAC-SHA256. +fn compute_inner_local(key: &[u8], msg: &[u8]) -> [u32; 8] { + sha256(compute_inner_partial(key), 64, msg) +} + +/// Computes and returns the HMAC-SHA256 output. +pub(crate) fn hmac_sha256(key: &[u8], msg: &[u8]) -> [u8; 32] { + let outer_partial = compute_outer_partial(key); + let inner_local = compute_inner_local(key, msg); + + let hmac = sha256(outer_partial, 64, &state_to_bytes(inner_local)); + state_to_bytes(hmac) +} + +#[cfg(test)] +mod tests { + use super::*; + use hmac::{Hmac, Mac}; + use rand::{rngs::StdRng, Rng, SeedableRng}; + use sha2::Sha256; + + #[test] + fn test_hmac_sha256() { + let mut rng = StdRng::from_seed([1; 32]); + + for _ in 0..10 { + let key: [u8; 32] = rng.random(); + let msg: [u8; 32] = rng.random(); + + let mut mac = + Hmac::::new_from_slice(&key).expect("HMAC can take key of any size"); + mac.update(&msg); + + assert_eq!(hmac_sha256(&key, &msg), *mac.finalize().into_bytes()) + } + } +} diff --git a/crates/components/hmac-sha256/src/hmac/normal.rs b/crates/components/hmac-sha256/src/hmac/normal.rs new file mode 100644 index 0000000000..15b6f33913 --- /dev/null +++ b/crates/components/hmac-sha256/src/hmac/normal.rs @@ -0,0 +1,141 @@ +use mpz_hash::sha256::Sha256; +use mpz_vm_core::{ + memory::{ + binary::{Binary, U32, U8}, + Array, MemoryExt, Vector, ViewExt, + }, + Vm, +}; + +use crate::{ + hmac::{compute_partial, hmac_sha256, IPAD, OPAD}, + FError, +}; + +/// Functionality for HMAC computation with a private key and a public message. +/// +/// Used in conjunction with [crate::Mode::Normal]. +#[derive(Debug, Clone)] +pub(crate) struct HmacNormal { + inner_partial: Sha256, + outer_partial: Sha256, + output: Option>, + state: State, +} + +impl HmacNormal { + /// Allocates a new HMAC with the given `key`. + pub(crate) fn alloc(vm: &mut dyn Vm, key: Vector) -> Result { + Ok(Self { + inner_partial: compute_partial(vm, key, IPAD)?, + outer_partial: compute_partial(vm, key, OPAD)?, + output: None, + state: State::WantsMsg, + }) + } + + /// Allocates a new HMAC with the given `inner_partial` and + /// `outer_partial`. + pub(crate) fn alloc_with_state( + vm: &mut dyn Vm, + inner_partial: [u32; 8], + outer_partial: [u32; 8], + ) -> Result { + let inner_p: Array = vm.alloc().map_err(FError::vm)?; + vm.mark_public(inner_p).map_err(FError::vm)?; + vm.assign(inner_p, inner_partial).map_err(FError::vm)?; + vm.commit(inner_p).map_err(FError::vm)?; + let inner = Sha256::new_from_state(inner_p, 1); + + let outer_p: Array = vm.alloc().map_err(FError::vm)?; + vm.mark_public(outer_p).map_err(FError::vm)?; + vm.assign(outer_p, outer_partial).map_err(FError::vm)?; + vm.commit(outer_p).map_err(FError::vm)?; + let outer = Sha256::new_from_state(outer_p, 1); + + Ok(Self { + inner_partial: inner, + outer_partial: outer, + output: None, + state: State::WantsMsg, + }) + } + + /// Whether this functionality needs to be flushed. + pub(crate) fn wants_flush(&self) -> bool { + matches!(self.state, State::MsgSet) + } + + /// Flushes the functionality. + pub(crate) fn flush(&mut self) -> Result<(), FError> { + if let State::MsgSet = self.state { + self.state = State::Complete; + } + + Ok(()) + } + + /// Sets an HMAC message `msg`. + /// + /// The message is a slice of vectors which will be concatenated. + pub(crate) fn set_msg( + &mut self, + vm: &mut dyn Vm, + msg: &[Vector], + ) -> Result<(), FError> { + match self.state { + State::WantsMsg => (), + _ => return Err(FError::state("must be in WantsMsg state to set message")), + } + + msg.iter().for_each(|m| self.inner_partial.update(m)); + self.inner_partial.compress(vm).map_err(FError::vm)?; + let inner_local = self.inner_partial.finalize(vm).map_err(FError::vm)?; + let out = hmac_sha256(vm, self.outer_partial.clone(), inner_local)?; + + self.output = Some(out); + self.state = State::MsgSet; + + Ok(()) + } + + /// Returns HMAC output. + pub(crate) fn output(&self) -> Result, FError> { + match self.state { + State::MsgSet | State::Complete => Ok(self + .output + .expect("output is available when message is set")), + _ => Err(FError::state( + "must be in MsgSet or Complete state to return output", + )), + } + } + + /// Creates a new allocated instance of HMAC from another instance. + pub(crate) fn from_other(other: &Self) -> Result { + match other.state { + State::WantsMsg => Ok(Self { + inner_partial: other.inner_partial.clone(), + outer_partial: other.outer_partial.clone(), + output: None, + state: State::WantsMsg, + }), + + _ => Err(FError::state("other must be in WantsMsg state")), + } + } + + /// Whether this functionality is complete. + pub(crate) fn is_complete(&self) -> bool { + matches!(self.state, State::Complete) + } +} + +/// State of [HmacNormal]. +#[derive(Debug, Clone)] +pub(crate) enum State { + WantsMsg, + /// The state after the message has been set. + MsgSet, + Complete, +} diff --git a/crates/components/hmac-sha256/src/hmac/reduced.rs b/crates/components/hmac-sha256/src/hmac/reduced.rs new file mode 100644 index 0000000000..31c21d92d3 --- /dev/null +++ b/crates/components/hmac-sha256/src/hmac/reduced.rs @@ -0,0 +1,163 @@ +use crate::hmac::{assign_inner_local, compute_partial, hmac_sha256, IPAD, OPAD}; +use mpz_hash::sha256::Sha256; +use mpz_vm_core::{ + memory::{ + binary::{Binary, U32, U8}, + Array, MemoryExt, Vector, ViewExt, + }, + Vm, +}; + +use crate::FError; + +/// Functionality for HMAC computation with a private key and a public message. +/// +/// Used in conjunction with [crate::Mode::Reduced]. +#[derive(Debug)] +pub(crate) struct HmacReduced { + outer_partial: Sha256, + inner_local: Array, + inner_partial: Array, + msg: Option>, + output: Array, + state: State, +} + +impl HmacReduced { + /// Allocates a new HMAC with the given `key`. + pub(crate) fn alloc(vm: &mut dyn Vm, key: Vector) -> Result { + let inner_partial = compute_partial(vm, key, IPAD)?; + let outer_partial = compute_partial(vm, key, OPAD)?; + + let (inner_partial, _) = inner_partial + .state() + .expect("state should be set for inner_partial"); + // Decode as soon as the value is computed. + std::mem::drop(vm.decode(inner_partial).map_err(FError::vm)?); + + let inner_local: Array = vm.alloc().map_err(FError::vm)?; + vm.mark_public(inner_local).map_err(FError::vm)?; + let out = hmac_sha256(vm, outer_partial.clone(), inner_local)?; + + Ok(Self { + outer_partial, + inner_local, + inner_partial, + msg: None, + output: out, + state: State::WantsInnerPartial, + }) + } + + /// Whether this functionality needs to be flushed. + pub(crate) fn wants_flush(&self) -> bool { + match self.state { + State::WantsInnerPartial => true, + State::WantsMsg { .. } => self.msg.is_some(), + _ => false, + } + } + + /// Flushes the functionality. + pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), FError> { + let state = self.state.take(); + + match state { + State::WantsInnerPartial => { + let mut inner_partial = vm.decode(self.inner_partial).map_err(FError::vm)?; + let Some(inner_partial) = inner_partial.try_recv().map_err(FError::vm)? else { + self.state = State::WantsInnerPartial; + return Ok(()); + }; + + self.state = State::WantsMsg { inner_partial }; + // Recurse. + self.flush(vm)?; + } + State::WantsMsg { inner_partial } => { + // output is Some after msg was set + if self.msg.is_some() { + assign_inner_local( + vm, + self.inner_local, + inner_partial, + &self.msg.clone().unwrap(), + )?; + + self.state = State::Complete; + } else { + self.state = State::WantsMsg { inner_partial }; + } + } + _ => self.state = state, + } + + Ok(()) + } + + /// Sets the HMAC message. + pub(crate) fn set_msg(&mut self, msg: &[u8]) -> Result<(), FError> { + match self.msg { + None => self.msg = Some(msg.to_vec()), + Some(_) => return Err(FError::state("message has already been set")), + } + + Ok(()) + } + + /// Whether the HMAC message has been set. + pub(crate) fn is_msg_set(&mut self) -> bool { + self.msg.is_some() + } + + /// Returns the HMAC output. + pub(crate) fn output(&self) -> Array { + self.output + } + + /// Creates a new allocated instance of HMAC from another instance. + pub(crate) fn from_other(vm: &mut dyn Vm, other: &Self) -> Result { + match other.state { + State::WantsInnerPartial => { + let inner_local: Array = vm.alloc().map_err(FError::vm)?; + vm.mark_public(inner_local).map_err(FError::vm)?; + + let out = hmac_sha256(vm, other.outer_partial.clone(), inner_local)?; + + Ok(Self { + outer_partial: other.outer_partial.clone(), + inner_local, + inner_partial: other.inner_partial, + msg: None, + output: out, + state: State::WantsInnerPartial, + }) + } + _ => Err(FError::state("other must be in WantsInnerPartial state")), + } + } + + /// Whether this functionality is complete. + pub(crate) fn is_complete(&self) -> bool { + matches!(self.state, State::Complete) + } +} + +/// State of [HmacReduced]. +#[derive(Debug, Clone)] +pub(crate) enum State { + /// Wants the decoded inner_partial plaintext. + WantsInnerPartial, + /// Wants the message to be set. + WantsMsg { + inner_partial: [u32; 8], + }, + Complete, + Error, +} + +impl State { + pub(crate) fn take(&mut self) -> State { + std::mem::replace(self, State::Error) + } +} diff --git a/crates/components/hmac-sha256/src/kdf/expand.rs b/crates/components/hmac-sha256/src/kdf/expand.rs new file mode 100644 index 0000000000..a4e55150ca --- /dev/null +++ b/crates/components/hmac-sha256/src/kdf/expand.rs @@ -0,0 +1,292 @@ +//! `HKDF-Expand-Label` function as defined in TLS 1.3. + +use mpz_vm_core::{ + memory::{ + binary::{Binary, U8}, + Vector, + }, + Vm, +}; + +use crate::{ + hmac::{clear, Hmac}, + kdf::expand::label::make_hkdf_label, + FError, Mode, +}; + +pub(crate) mod label; +pub(crate) mod normal; +pub(crate) mod reduced; + +/// A zero_length HKDF-Expand-Label context. +pub(crate) const EMPTY_CTX: [u8; 0] = []; + +/// Functionality for computing `HKDF-Expand-Label` with a private secret +/// and public label and context. +#[derive(Debug)] +pub(crate) enum HkdfExpand { + Reduced(reduced::HkdfExpand), + Normal(normal::HkdfExpand), +} + +impl HkdfExpand { + /// Allocates a new HKDF-Expand-Label with the `hmac` + /// instantiated with the secret. + pub(crate) fn alloc( + mode: Mode, + vm: &mut dyn Vm, + // Partial hash states of the secret. + hmac: Hmac, + // Human-readable label. + label: &'static [u8], + // Context. + ctx: Option<&[u8]>, + // Context length. + ctx_len: usize, + // Output length. + out_len: usize, + ) -> Result { + let prf = match mode { + Mode::Reduced => { + if let Hmac::Reduced(hmac) = hmac { + let mut hkdf = reduced::HkdfExpand::alloc(hmac, label, out_len)?; + if let Some(ctx) = ctx { + hkdf.set_ctx(ctx)?; + } + Self::Reduced(hkdf) + } else { + unreachable!("modes always match"); + } + } + Mode::Normal => { + if let Hmac::Normal(hmac) = hmac { + let mut hkdf = normal::HkdfExpand::alloc(vm, hmac, label, ctx_len, out_len)?; + if let Some(ctx) = ctx { + hkdf.set_ctx(ctx)?; + } + Self::Normal(hkdf) + } else { + unreachable!("modes always match"); + } + } + }; + Ok(prf) + } + + /// Whether this functionality needs to be flushed. + pub(crate) fn wants_flush(&self) -> bool { + match self { + HkdfExpand::Reduced(hkdf) => hkdf.wants_flush(), + HkdfExpand::Normal(hkdf) => hkdf.wants_flush(), + } + } + + /// Flushes the functionality. + pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), FError> { + match self { + HkdfExpand::Reduced(hkdf) => hkdf.flush(vm), + HkdfExpand::Normal(hkdf) => hkdf.flush(vm), + } + } + + /// Sets the HKDF-Expand-Label context. + pub(crate) fn set_ctx(&mut self, ctx: &[u8]) -> Result<(), FError> { + match self { + HkdfExpand::Reduced(hkdf) => hkdf.set_ctx(ctx), + HkdfExpand::Normal(hkdf) => hkdf.set_ctx(ctx), + } + } + + /// Whether the context has been set. + pub(crate) fn is_ctx_set(&self) -> bool { + match self { + HkdfExpand::Reduced(hkdf) => hkdf.is_ctx_set(), + HkdfExpand::Normal(hkdf) => hkdf.is_ctx_set(), + } + } + + /// Returns the HKDF-Expand-Label output. + pub(crate) fn output(&self) -> Vector { + match self { + HkdfExpand::Reduced(hkdf) => hkdf.output(), + HkdfExpand::Normal(hkdf) => hkdf.output(), + } + } + + /// Whether this functionality is complete. + pub(crate) fn is_complete(&self) -> bool { + match self { + HkdfExpand::Reduced(hkdf) => hkdf.is_complete(), + HkdfExpand::Normal(hkdf) => hkdf.is_complete(), + } + } +} + +/// Computes `HKDF-Expand-Label` as defined in TLS 1.3. +pub(crate) fn hkdf_expand_label(key: &[u8], label: &[u8], ctx: &[u8], len: usize) -> Vec { + hkdf_expand(key, &make_hkdf_label(label, ctx, len), len) +} + +/// Computes `HKDF-Expand` as defined in https://datatracker.ietf.org/doc/html/rfc5869 +fn hkdf_expand(prk: &[u8], info: &[u8], len: usize) -> Vec { + assert!(len <= 32, "output length larger than 32 is not supported"); + let mut info = info.to_vec(); + info.push(0x01); + clear::hmac_sha256(prk, &info)[..len].to_vec() +} + +#[cfg(test)] +mod tests { + use crate::{ + hmac::{normal::HmacNormal, Hmac}, + kdf::expand::{hkdf_expand_label, HkdfExpand}, + test_utils::mock_vm, + Mode, + }; + use mpz_common::context::test_st_context; + use mpz_vm_core::{ + memory::{binary::Binary, MemoryExt, ViewExt}, + Execute, Vm, + }; + use rstest::*; + + #[rstest] + #[case::normal(Mode::Normal)] + #[case::reduced(Mode::Reduced)] + #[tokio::test] + async fn test_hkdf_expand(#[case] mode: Mode) { + for fixture in test_fixtures() { + let (label, prk, ctx, output) = fixture; + + let (mut ctx_a, mut ctx_b) = test_st_context(8); + let (mut leader, mut follower) = mock_vm(); + + fn setup_hkdf( + vm: &mut (dyn Vm + Send), + prk: [u8; 32], + label: &'static [u8], + ctx: Option<&[u8]>, + ctx_len: usize, + out_len: usize, + mode: Mode, + ) -> HkdfExpand { + let secret = vm.alloc_vec(32).unwrap(); + vm.mark_public(secret).unwrap(); + vm.assign(secret, prk.to_vec()).unwrap(); + vm.commit(secret).unwrap(); + + let hmac = if mode == Mode::Normal { + Hmac::Normal(HmacNormal::alloc(vm, secret).unwrap()) + } else { + use crate::hmac::reduced::HmacReduced; + + Hmac::Reduced(HmacReduced::alloc(vm, secret).unwrap()) + }; + + HkdfExpand::alloc(mode, vm, hmac, label, ctx, ctx_len, out_len).unwrap() + } + + let mut hkdf_leader = setup_hkdf( + &mut leader, + prk.clone().try_into().unwrap(), + label, + Some(&ctx), + ctx.len(), + output.len(), + mode, + ); + + let mut hkdf_follower = setup_hkdf( + &mut follower, + prk.clone().try_into().unwrap(), + label, + Some(&ctx), + ctx.len(), + output.len(), + mode, + ); + + let out_leader = hkdf_leader.output(); + let mut leader_decode_fut = leader.decode(out_leader).unwrap(); + + let out_follower = hkdf_follower.output(); + let mut follower_decode_fut = follower.decode(out_follower).unwrap(); + + tokio::try_join!( + async { + leader.execute_all(&mut ctx_a).await.unwrap(); + assert!(hkdf_leader.wants_flush()); + hkdf_leader.flush(&mut leader).unwrap(); + assert!(!hkdf_leader.wants_flush()); + leader.execute_all(&mut ctx_a).await.unwrap(); + + Ok::<(), Box>(()) + }, + async { + follower.execute_all(&mut ctx_b).await.unwrap(); + assert!(hkdf_follower.wants_flush()); + hkdf_follower.flush(&mut follower).unwrap(); + assert!(!hkdf_follower.wants_flush()); + follower.execute_all(&mut ctx_b).await.unwrap(); + + Ok::<(), Box>(()) + } + ) + .unwrap(); + + let out_leader = leader_decode_fut.try_recv().unwrap().unwrap(); + let out_follower = follower_decode_fut.try_recv().unwrap().unwrap(); + assert_eq!(out_leader, out_follower); + assert_eq!(out_leader, output); + } + } + + #[test] + fn test_hkdf_expand_label() { + for fixture in test_fixtures() { + let (label, prk, ctx, output) = fixture; + let out = hkdf_expand_label(&prk, label, &ctx, output.len()); + assert_eq!(out, output); + } + } + + // Reference values from https://datatracker.ietf.org/doc/html/draft-ietf-tls-tls13-vectors-06 + #[allow(clippy::type_complexity)] + fn test_fixtures() -> Vec<(&'static [u8], Vec, Vec, Vec)> { + vec![( + // LABEL + b"c hs traffic", + // PRK + from_hex_str("5b 4f 96 5d f0 3c 68 2c 46 e6 ee 86 c3 11 63 66 15 a1 d2 bb b2 43 45 c2 52 05 95 3c 87 9e 8d 06").to_vec(), + // CTX + from_hex_str("c6 c9 18 ad 2f 41 99 d5 59 8e af 01 16 cb 7a 5c 2c 14 cb 54 78 12 18 88 8d b7 03 0d d5 0d 5e 6d").to_vec(), + // OUTPUT + from_hex_str("e2 e2 32 07 bd 93 fb 7f e4 fc 2e 29 7a fe ab 16 0e 52 2b 5a b7 5d 64 a8 6e 75 bc ac 3f 3e 51 03").to_vec(), + ), + ( + // LABEL + b"s hs traffic", + // PRK + from_hex_str("5b 4f 96 5d f0 3c 68 2c 46 e6 ee 86 c3 11 63 66 15 a1 d2 bb b2 43 45 c2 52 05 95 3c 87 9e 8d 06").to_vec(), + // CTX + from_hex_str("c6 c9 18 ad 2f 41 99 d5 59 8e af 01 16 cb 7a 5c 2c 14 cb 54 78 12 18 88 8d b7 03 0d d5 0d 5e 6d").to_vec(), + // OUTPUT + from_hex_str("3b 7a 83 9c 23 9e f2 bf 0b 73 05 a0 e0 c4 e5 a8 c6 c6 93 30 a7 53 b3 08 f5 e3 a8 3a a2 ef 69 79").to_vec(), + ), + ( + // LABEL + b"c ap traffic", + // PRK + from_hex_str("5c 79 d1 69 42 4e 26 2b 56 32 03 62 7b e4 eb 51 03 3f 58 8c 43 c9 ce 03 73 37 2d bc bc 01 85 a7").to_vec(), + // CTX + from_hex_str("f8 c1 9e 8c 77 c0 38 79 bb c8 eb 6d 56 e0 0d d5 d8 6e f5 59 27 ee fc 08 e1 b0 02 b6 ec e0 5d bf").to_vec(), + // OUTPUT + from_hex_str("e2 f0 db 6a 82 e8 82 80 fc 26 f7 3c 89 85 4e e8 61 5e 25 df 28 b2 20 79 62 fa 78 22 26 b2 36 26").to_vec(), + ) + ] + } + + fn from_hex_str(s: &str) -> Vec { + hex::decode(s.split_whitespace().collect::()).unwrap() + } +} diff --git a/crates/components/hmac-sha256/src/kdf/expand/label.rs b/crates/components/hmac-sha256/src/kdf/expand/label.rs new file mode 100644 index 0000000000..144d6c0f57 --- /dev/null +++ b/crates/components/hmac-sha256/src/kdf/expand/label.rs @@ -0,0 +1,267 @@ +//! Computation of HkdfLabel as specified in TLS 1.3. + +use crate::FError; + +use mpz_vm_core::{ + memory::{ + binary::{Binary, U8}, + MemoryExt, Vector, ViewExt, + }, + Vm, +}; + +/// Functionality for HkdfLabel computation. +#[derive(Debug)] +pub(crate) struct HkdfLabel { + /// Cleartext label. + label: HkdfLabelClear, + // VM reference for the HKDF label. + output: Vector, + // Label context. + ctx: Option>, + state: State, +} + +impl HkdfLabel { + /// Allocates a new HkdfLabel. + pub(crate) fn alloc( + vm: &mut dyn Vm, + label: &'static [u8], + ctx_len: usize, + out_len: usize, + ) -> Result { + let label_ref = vm + .alloc_vec::(hkdf_label_length(label.len(), ctx_len)) + .map_err(FError::vm)?; + vm.mark_public(label_ref).map_err(FError::vm)?; + + Ok(Self { + label: HkdfLabelClear::new(label, out_len), + output: label_ref, + ctx: None, + state: State::WantsContext, + }) + } + + /// Whether this functionality needs to be flushed. + pub(crate) fn wants_flush(&self) -> bool { + match self.state { + State::WantsContext => self.is_ctx_set(), + _ => false, + } + } + + /// Flushes the functionality. + pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), FError> { + if let State::WantsContext = &mut self.state { + if let Some(ctx) = &self.ctx { + self.label.set_ctx(ctx)?; + + vm.assign(self.output, self.label.output()?) + .map_err(FError::vm)?; + vm.commit(self.output).map_err(FError::vm)?; + + self.state = State::Complete; + } + } + + Ok(()) + } + + /// Sets label context. + pub(crate) fn set_ctx(&mut self, ctx: &[u8]) -> Result<(), FError> { + if self.is_ctx_set() { + return Err(FError::state("context has already been set")); + } + + self.ctx = Some(ctx.to_vec()); + + Ok(()) + } + + /// Returns the HkdfLabel output. + pub(crate) fn output(&self) -> Vector { + self.output + } + + /// Whether this functionality is complete. + pub(crate) fn is_complete(&self) -> bool { + matches!(self.state, State::Complete) + } + + /// Returns whether context has been set. + fn is_ctx_set(&self) -> bool { + self.ctx.is_some() + } +} + +#[derive(Debug)] +enum State { + /// Wants the context to be set. + WantsContext, + Complete, +} + +/// Functionality for HkdfLabel computation on cleartext values. +#[derive(Debug)] +pub(crate) struct HkdfLabelClear { + /// Human-readable label. + label: &'static [u8], + /// Context. + ctx: Option>, + /// Output length. + out_len: usize, +} + +impl HkdfLabelClear { + /// Creates a new label. + pub(crate) fn new(label: &'static [u8], out_len: usize) -> Self { + Self { + label, + ctx: None, + out_len, + } + } + + /// Sets label context. + pub(crate) fn set_ctx(&mut self, ctx: &[u8]) -> Result<(), FError> { + if self.ctx.is_some() { + return Err(FError::state("context has already been set")); + } + + self.ctx = Some(ctx.to_vec()); + Ok(()) + } + + /// Returns the byte representation of the label. + pub(crate) fn output(&self) -> Result, FError> { + match &self.ctx { + Some(ctx) => Ok(make_hkdf_label(self.label, ctx, self.out_len)), + _ => Err(FError::state("context was not set")), + } + } +} + +/// Returns the byte representation of an HKDF label. +pub(crate) fn make_hkdf_label(label: &[u8], ctx: &[u8], out_len: usize) -> Vec { + assert!( + out_len <= 256, + "output length larger than 256 not supported" + ); + + const LABEL_PREFIX: &[u8] = b"tls13 "; + + let mut hkdf_label = Vec::new(); + let output_len = u16::to_be_bytes(out_len as u16); + let label_len = u8::to_be_bytes((LABEL_PREFIX.len() + label.len()) as u8); + let context_len = u8::to_be_bytes(ctx.len() as u8); + + hkdf_label.extend_from_slice(&output_len); + hkdf_label.extend_from_slice(&label_len); + hkdf_label.extend_from_slice(LABEL_PREFIX); + hkdf_label.extend_from_slice(label); + hkdf_label.extend_from_slice(&context_len); + hkdf_label.extend_from_slice(ctx); + hkdf_label +} + +/// Returns the length of an HKDF label. +fn hkdf_label_length(label_len: usize, ctx_len: usize) -> usize { + // 2 : output length as u16 + // 1 : label length as u8 + // 6 : length of "tls13 " + // 1 : context length as u8 + // see `make_hkdf_label` + 2 + 1 + 6 + label_len + 1 + ctx_len +} + +#[cfg(test)] +mod tests { + use crate::kdf::expand::label::make_hkdf_label; + + #[test] + fn test_make_hkdf_label() { + for fixture in test_fixtures() { + let (label, ctx, hkdf_label, out_len) = fixture; + assert_eq!(make_hkdf_label(label, &ctx, out_len), hkdf_label); + } + } + + // Test vectors from https://datatracker.ietf.org/doc/html/draft-ietf-tls-tls13-vectors-06 + // (in that ref, `hash` is the context, `info` is the hkdf label). + #[allow(clippy::type_complexity)] + fn test_fixtures() -> Vec<(&'static [u8], Vec, Vec, usize)> { + vec![ + ( + b"derived", + from_hex_str("e3 b0 c4 42 98 fc 1c 14 9a fb f4 c8 99 6f b9 24 27 ae 41 e4 64 9b 93 4c a4 95 99 1b 78 52 b8 55"), + from_hex_str("00 20 0d 74 6c 73 31 33 20 64 65 72 69 76 65 64 20 e3 b0 c4 42 98 fc 1c 14 9a fb f4 c8 99 6f b9 24 27 ae 41 e4 64 9b 93 4c a4 95 99 1b 78 52 b8 55"), + 32, + ), + ( + b"c hs traffic", + from_hex_str("c6 c9 18 ad 2f 41 99 d5 59 8e af 01 16 cb 7a 5c 2c 14 cb 54 78 12 18 88 8d b7 03 0d d5 0d 5e 6d"), + from_hex_str("00 20 12 74 6c 73 31 33 20 63 20 68 73 20 74 72 61 66 66 69 63 20 c6 c9 18 ad 2f 41 99 d5 59 8e af 01 16 cb 7a 5c 2c 14 cb 54 78 12 18 88 8d b7 03 0d d5 0d 5e 6d"), + 32, + ), + ( + b"s hs traffic", + from_hex_str("c6 c9 18 ad 2f 41 99 d5 59 8e af 01 16 cb 7a 5c 2c 14 cb 54 78 12 18 88 8d b7 03 0d d5 0d 5e 6d"), + from_hex_str("00 20 12 74 6c 73 31 33 20 73 20 68 73 20 74 72 61 66 66 69 63 20 c6 c9 18 ad 2f 41 99 d5 59 8e af 01 16 cb 7a 5c 2c 14 cb 54 78 12 18 88 8d b7 03 0d d5 0d 5e 6d"), + 32, + ), + ( + b"key", + from_hex_str(""), + from_hex_str("00 10 09 74 6c 73 31 33 20 6b 65 79 00"), + 16, + ), + ( + b"iv", + from_hex_str(""), + from_hex_str("00 0c 08 74 6c 73 31 33 20 69 76 00"), + 12, + ), + ( + b"finished", + from_hex_str(""), + from_hex_str("00 20 0e 74 6c 73 31 33 20 66 69 6e 69 73 68 65 64 00"), + 32, + ), + ( + b"c ap traffic", + from_hex_str("f8 c1 9e 8c 77 c0 38 79 bb c8 eb 6d 56 e0 0d d5 d8 6e f5 59 27 ee fc 08 e1 b0 02 b6 ec e0 5d bf"), + from_hex_str("00 20 12 74 6c 73 31 33 20 63 20 61 70 20 74 72 61 66 66 69 63 20 f8 c1 9e 8c 77 c0 38 79 bb c8 eb 6d 56 e0 0d d5 d8 6e f5 59 27 ee fc 08 e1 b0 02 b6 ec e0 5d bf"), + 32, + ), + ( + b"s ap traffic", + from_hex_str("f8 c1 9e 8c 77 c0 38 79 bb c8 eb 6d 56 e0 0d d5 d8 6e f5 59 27 ee fc 08 e1 b0 02 b6 ec e0 5d bf"), + from_hex_str("00 20 12 74 6c 73 31 33 20 73 20 61 70 20 74 72 61 66 66 69 63 20 f8 c1 9e 8c 77 c0 38 79 bb c8 eb 6d 56 e0 0d d5 d8 6e f5 59 27 ee fc 08 e1 b0 02 b6 ec e0 5d bf"), + 32, + ), + ( + b"exp master", + from_hex_str("f8 c1 9e 8c 77 c0 38 79 bb c8 eb 6d 56 e0 0d d5 d8 6e f5 59 27 ee fc 08 e1 b0 02 b6 ec e0 5d bf"), + from_hex_str("00 20 10 74 6c 73 31 33 20 65 78 70 20 6d 61 73 74 65 72 20 f8 c1 9e 8c 77 c0 38 79 bb c8 eb 6d 56 e0 0d d5 d8 6e f5 59 27 ee fc 08 e1 b0 02 b6 ec e0 5d bf"), + 32, + ), + ( + b"res master", + from_hex_str("50 2f 86 b9 57 9e c0 53 d3 28 24 e2 78 0e f6 5c c4 37 a3 56 43 45 35 6b df 79 13 ec 3b 87 96 14"), + from_hex_str("00 20 10 74 6c 73 31 33 20 72 65 73 20 6d 61 73 74 65 72 20 50 2f 86 b9 57 9e c0 53 d3 28 24 e2 78 0e f6 5c c4 37 a3 56 43 45 35 6b df 79 13 ec 3b 87 96 14"), + 32, + ), + ( + b"resumption", + from_hex_str("00 00"), + from_hex_str("00 20 10 74 6c 73 31 33 20 72 65 73 75 6d 70 74 69 6f 6e 02 00 00"), + 32, + ), + ] + } + + fn from_hex_str(s: &str) -> Vec { + hex::decode(s.split_whitespace().collect::()).unwrap() + } +} diff --git a/crates/components/hmac-sha256/src/kdf/expand/normal.rs b/crates/components/hmac-sha256/src/kdf/expand/normal.rs new file mode 100644 index 0000000000..b3d5876418 --- /dev/null +++ b/crates/components/hmac-sha256/src/kdf/expand/normal.rs @@ -0,0 +1,134 @@ +use crate::{ + hmac::normal::HmacNormal, kdf::expand::label::HkdfLabel, tls12::merge_vectors, FError, +}; + +use mpz_vm_core::{ + memory::{ + binary::{Binary, U8}, + MemoryExt, Vector, ViewExt, + }, + Vm, +}; + +#[derive(Debug)] +enum State { + /// Wants the context to be set. + WantsContext, + /// Context has been set. + ContextSet, + Complete, +} + +/// Functionality for computing `HKDF-Expand-Label` with a private secret +/// and public label and context. +#[derive(Debug)] +pub(crate) struct HkdfExpand { + label: HkdfLabel, + state: State, + ctx: Option>, + output: Vector, +} + +impl HkdfExpand { + /// Allocates a new HKDF-Expand-Label with the `hmac` + /// instantiated with the secret. + pub(crate) fn alloc( + vm: &mut dyn Vm, + mut hmac: HmacNormal, + // Human-readable label. + label: &'static [u8], + // Context length. + ctx_len: usize, + // Output length. + out_len: usize, + ) -> Result { + assert!( + out_len <= 32, + "output length larger than 32 is not supported" + ); + + let hkdf_label = HkdfLabel::alloc(vm, label, ctx_len, out_len)?; + let info = hkdf_label.output(); + + // HKDF-Expand requires 0x01 to be concatenated. + // see line: T(1) = HMAC-Hash(PRK, T(0) | info | 0x01) in + // https://datatracker.ietf.org/doc/html/rfc5869 + let constant = vm.alloc_vec::(1).map_err(FError::vm)?; + vm.mark_public(constant).map_err(FError::vm)?; + vm.assign(constant, vec![0x01]).map_err(FError::vm)?; + vm.commit(constant).map_err(FError::vm)?; + + let msg = merge_vectors(vm, vec![info, constant], info.len() + constant.len())?; + + hmac.set_msg(vm, &[msg])?; + + let mut output: Vector = hmac.output()?.into(); + output.truncate(out_len); + + Ok(Self { + output, + label: hkdf_label, + ctx: None, + state: State::WantsContext, + }) + } + + /// Whether this functionality needs to be flushed. + pub(crate) fn wants_flush(&self) -> bool { + let state_wants_flush = match self.state { + State::WantsContext => self.is_ctx_set(), + _ => false, + }; + state_wants_flush || self.label.wants_flush() + } + + /// Flushes the functionality. + pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), FError> { + self.label.flush(vm)?; + + match &mut self.state { + State::WantsContext => { + if let Some(ctx) = &self.ctx { + self.label.set_ctx(ctx)?; + self.label.flush(vm)?; + self.state = State::ContextSet; + // Recurse. + self.flush(vm)?; + } + } + State::ContextSet => { + if self.label.is_complete() { + self.state = State::Complete; + } + } + _ => (), + } + + Ok(()) + } + + /// Sets the HKDF-Expand-Label context. + pub(crate) fn set_ctx(&mut self, ctx: &[u8]) -> Result<(), FError> { + if self.is_ctx_set() { + return Err(FError::state("context has already been set")); + } + + self.ctx = Some(ctx.to_vec()); + Ok(()) + } + + /// Returns the HKDF-Expand-Label output. + pub(crate) fn output(&self) -> Vector { + self.output + } + + /// Whether this functionality is complete. + pub(crate) fn is_complete(&self) -> bool { + matches!(self.state, State::Complete) + } + + /// Whether the context has been set. + pub(crate) fn is_ctx_set(&self) -> bool { + self.ctx.is_some() + } +} diff --git a/crates/components/hmac-sha256/src/kdf/expand/reduced.rs b/crates/components/hmac-sha256/src/kdf/expand/reduced.rs new file mode 100644 index 0000000000..0b49508662 --- /dev/null +++ b/crates/components/hmac-sha256/src/kdf/expand/reduced.rs @@ -0,0 +1,125 @@ +use mpz_vm_core::{ + memory::{ + binary::{Binary, U8}, + Vector, + }, + Vm, +}; + +use crate::{hmac::reduced::HmacReduced, kdf::expand::label::HkdfLabelClear, FError}; + +/// Functionality for computing `HKDF-Expand-Label` with a private secret +/// and public label and context. +#[derive(Debug)] +pub(crate) struct HkdfExpand { + label: HkdfLabelClear, + hmac: HmacReduced, + ctx: Option>, + output: Vector, + state: State, +} + +impl HkdfExpand { + /// Allocates a new HKDF-Expand-Label with the `hmac` + /// instantiated with the secret. + pub(crate) fn alloc( + hmac: HmacReduced, + // Human-readable label. + label: &'static [u8], + // Output length. + out_len: usize, + ) -> Result { + assert!( + out_len <= 32, + "output length larger than 32 is not supported" + ); + + let hkdf_label = HkdfLabelClear::new(label, out_len); + + let mut output: Vector = hmac.output().into(); + output.truncate(out_len); + + Ok(Self { + label: hkdf_label, + hmac, + ctx: None, + output, + state: State::WantsContext, + }) + } + + /// Whether this functionality needs to be flushed. + pub(crate) fn wants_flush(&self) -> bool { + let state_wants_flush = match self.state { + State::WantsContext => self.is_ctx_set(), + _ => false, + }; + + state_wants_flush || self.hmac.wants_flush() + } + + /// Flushes the functionality. + pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), FError> { + self.hmac.flush(vm)?; + + match self.state { + State::WantsContext => { + if let Some(ctx) = &self.ctx { + // HKDF-Expand requires 0x01 to be concatenated. + // see line: T(1) = HMAC-Hash(PRK, T(0) | info | 0x01) in + // https://datatracker.ietf.org/doc/html/rfc5869 + self.label.set_ctx(ctx)?; + let mut label = self.label.output()?; + label.push(0x01); + + self.hmac.set_msg(&label)?; + self.hmac.flush(vm)?; + + self.state = State::ContextSet; + // Recurse. + self.flush(vm)?; + } + } + State::ContextSet => { + if self.hmac.is_complete() { + self.state = State::Complete; + } + } + _ => (), + } + + Ok(()) + } + + /// Sets the HKDF-Expand-Label context. + pub(crate) fn set_ctx(&mut self, ctx: &[u8]) -> Result<(), FError> { + if self.is_ctx_set() { + return Err(FError::state("context has already been set")); + } + + self.ctx = Some(ctx.to_vec()); + Ok(()) + } + + /// Returns the HKDF-Expand-Label output. + pub(crate) fn output(&self) -> Vector { + self.output + } + + /// Whether the context has been set. + pub(crate) fn is_ctx_set(&self) -> bool { + self.ctx.is_some() + } + + /// Whether this functionality is complete. + pub(crate) fn is_complete(&self) -> bool { + matches!(self.state, State::Complete) + } +} + +#[derive(Debug)] +enum State { + WantsContext, + ContextSet, + Complete, +} diff --git a/crates/components/hmac-sha256/src/kdf/extract.rs b/crates/components/hmac-sha256/src/kdf/extract.rs new file mode 100644 index 0000000000..946990a451 --- /dev/null +++ b/crates/components/hmac-sha256/src/kdf/extract.rs @@ -0,0 +1,334 @@ +//! `HKDF-Extract` function as defined in https://datatracker.ietf.org/doc/html/rfc5869 + +use crate::{ + hmac::{normal::HmacNormal, Hmac}, + FError, Mode, +}; +use mpz_vm_core::{ + memory::{ + binary::{Binary, U8}, + Array, Vector, + }, + Vm, +}; + +pub(crate) mod normal; +pub(crate) mod reduced; + +/// Functionality for computing `HKDF-Extract` with private salt and public +/// IKM. +#[derive(Debug)] +pub(crate) enum HkdfExtract { + Reduced(reduced::HkdfExtract), + Normal(normal::HkdfExtract), +} + +impl HkdfExtract { + /// Allocates a new HKDF-Extract with the given `ikm` and `hmac` + /// instantiated with the salt. + pub(crate) fn alloc( + mode: Mode, + vm: &mut dyn Vm, + ikm: [u8; 32], + hmac: Hmac, + ) -> Result { + let prf = match mode { + Mode::Reduced => { + if let Hmac::Reduced(hmac) = hmac { + Self::Reduced(reduced::HkdfExtract::alloc(ikm, hmac)?) + } else { + unreachable!("modes always match"); + } + } + Mode::Normal => { + if let Hmac::Normal(hmac) = hmac { + Self::Normal(normal::HkdfExtract::alloc(vm, ikm, hmac)?) + } else { + unreachable!("modes always match"); + } + } + }; + Ok(prf) + } + + /// Whether this functionality needs to be flushed. + pub(crate) fn wants_flush(&self) -> bool { + match self { + HkdfExtract::Reduced(hkdf) => hkdf.wants_flush(), + HkdfExtract::Normal(hkdf) => hkdf.wants_flush(), + } + } + + /// Flushes the functionality. + pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), FError> { + match self { + HkdfExtract::Reduced(hkdf) => hkdf.flush(vm), + HkdfExtract::Normal(hkdf) => hkdf.flush(), + } + } + + /// Returns HKDF-Extract output. + pub(crate) fn output(&self) -> Vector { + match self { + HkdfExtract::Reduced(hkdf) => hkdf.output(), + HkdfExtract::Normal(hkdf) => hkdf.output(), + } + } + + /// Whether this functionality is complete. + pub(crate) fn is_complete(&self) -> bool { + match self { + HkdfExtract::Reduced(hkdf) => hkdf.is_complete(), + HkdfExtract::Normal(hkdf) => hkdf.is_complete(), + } + } +} + +/// Functionality for computing `HKDF-Extract` with private IKM and public +/// salt. +#[derive(Debug)] +pub(crate) struct HkdfExtractPrivIkm { + output: Vector, + state: State, +} + +impl HkdfExtractPrivIkm { + /// Allocates a new HKDF-Extract with the given `ikm` and `hmac` + /// instantiated with the salt. + pub(crate) fn alloc( + vm: &mut dyn Vm, + ikm: Array, + mut hmac: HmacNormal, + ) -> Result { + hmac.set_msg(vm, &[ikm.into()])?; + + Ok(Self { + output: hmac.output()?.into(), + state: State::Setup, + }) + } + + /// Whether this functionality needs to be flushed. + pub(crate) fn wants_flush(&self) -> bool { + matches!(self.state, State::Setup) + } + + /// Flushes the functionality. + pub(crate) fn flush(&mut self) { + if let State::Setup = self.state { + self.state = State::Complete; + } + } + + /// Returns HKDF-Extract output. + pub(crate) fn output(&self) -> Vector { + self.output + } + + pub(crate) fn is_complete(&self) -> bool { + matches!(self.state, State::Complete) + } +} + +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub(crate) enum State { + Setup, + Complete, +} + +#[cfg(test)] +mod tests { + use crate::{ + hmac::{clear, normal::HmacNormal, Hmac}, + kdf::extract::{HkdfExtract, HkdfExtractPrivIkm}, + test_utils::mock_vm, + Mode, + }; + use mpz_common::context::test_st_context; + use mpz_vm_core::{ + memory::{binary::U8, Array, MemoryExt, ViewExt}, + Execute, + }; + use rstest::*; + + #[tokio::test] + async fn test_hkdf_extract_priv_ikm() { + for fixture in test_fixtures() { + let (salt, ikm, secret) = fixture; + + let (mut ctx_a, mut ctx_b) = test_st_context(8); + let (mut leader, mut follower) = mock_vm(); + + let ikm: [u8; 32] = ikm.try_into().unwrap(); + + let inner_state = clear::compute_inner_partial(&salt); + let outer_state = clear::compute_outer_partial(&salt); + + // ------------------ LEADER + + let vm = &mut leader; + + let ikm_ref: Array = vm.alloc().unwrap(); + vm.mark_public(ikm_ref).unwrap(); + vm.assign(ikm_ref, ikm).unwrap(); + vm.commit(ikm_ref).unwrap(); + + let hmac = HmacNormal::alloc_with_state(vm, inner_state, outer_state).unwrap(); + + let mut hkdf_leader = HkdfExtractPrivIkm::alloc(vm, ikm_ref, hmac).unwrap(); + let out_leader = hkdf_leader.output(); + let mut leader_decode_fut = vm.decode(out_leader).unwrap(); + + // ------------------ FOLLOWER + + let vm = &mut follower; + + let ikm_ref: Array = vm.alloc().unwrap(); + vm.mark_public(ikm_ref).unwrap(); + vm.assign(ikm_ref, ikm).unwrap(); + vm.commit(ikm_ref).unwrap(); + + let hmac = HmacNormal::alloc_with_state(vm, inner_state, outer_state).unwrap(); + + let mut hkdf_follower = HkdfExtractPrivIkm::alloc(vm, ikm_ref, hmac).unwrap(); + let out_follower = hkdf_follower.output(); + let mut follower_decode_fut = vm.decode(out_follower).unwrap(); + + tokio::try_join!( + async { + leader.execute_all(&mut ctx_a).await.unwrap(); + assert!(hkdf_leader.wants_flush()); + hkdf_leader.flush(); + assert!(!hkdf_leader.wants_flush()); + + Ok::<(), Box>(()) + }, + async { + follower.execute_all(&mut ctx_b).await.unwrap(); + assert!(hkdf_follower.wants_flush()); + hkdf_follower.flush(); + assert!(!hkdf_follower.wants_flush()); + + Ok::<(), Box>(()) + } + ) + .unwrap(); + + let leader_out = leader_decode_fut.try_recv().unwrap().unwrap(); + let follower_out = follower_decode_fut.try_recv().unwrap().unwrap(); + assert_eq!(leader_out, follower_out); + assert_eq!(leader_out, secret); + } + } + + #[rstest] + #[case::normal(Mode::Normal)] + #[case::reduced(Mode::Reduced)] + #[tokio::test] + async fn test_hkdf_extract(#[case] mode: Mode) { + for fixture in test_fixtures() { + let (salt, ikm, secret) = fixture; + + let (mut ctx_a, mut ctx_b) = test_st_context(8); + let (mut leader, mut follower) = mock_vm(); + + let salt: [u8; 32] = salt.try_into().unwrap(); + + // ------------------ LEADER + + let vm = &mut leader; + + let salt_ref = vm.alloc_vec(32).unwrap(); + vm.mark_public(salt_ref).unwrap(); + vm.assign(salt_ref, salt.to_vec()).unwrap(); + vm.commit(salt_ref).unwrap(); + + let hmac = Hmac::alloc(vm, salt_ref, mode).unwrap(); + + let mut hkdf_leader = + HkdfExtract::alloc(mode, vm, ikm.clone().try_into().unwrap(), hmac).unwrap(); + let out_leader = hkdf_leader.output(); + let mut leader_decode_fut = leader.decode(out_leader).unwrap(); + + // ------------------ FOLLOWER + + let vm = &mut follower; + + let salt_ref = vm.alloc_vec(32).unwrap(); + vm.mark_public(salt_ref).unwrap(); + vm.assign(salt_ref, salt.to_vec()).unwrap(); + vm.commit(salt_ref).unwrap(); + + let hmac = Hmac::alloc(vm, salt_ref, mode).unwrap(); + + let mut hkdf_follower = + HkdfExtract::alloc(mode, vm, ikm.try_into().unwrap(), hmac).unwrap(); + let out_follower = hkdf_follower.output(); + let mut follower_decode_fut = follower.decode(out_follower).unwrap(); + + tokio::try_join!( + async { + leader.execute_all(&mut ctx_a).await.unwrap(); + assert!(hkdf_leader.wants_flush()); + hkdf_leader.flush(&mut leader).unwrap(); + assert!(!hkdf_leader.wants_flush()); + leader.execute_all(&mut ctx_a).await.unwrap(); + + Ok::<(), Box>(()) + }, + async { + follower.execute_all(&mut ctx_b).await.unwrap(); + assert!(hkdf_follower.wants_flush()); + hkdf_follower.flush(&mut follower).unwrap(); + assert!(!hkdf_follower.wants_flush()); + follower.execute_all(&mut ctx_b).await.unwrap(); + + Ok::<(), Box>(()) + } + ) + .unwrap(); + + let out_leader = leader_decode_fut.try_recv().unwrap().unwrap(); + let out_follower = follower_decode_fut.try_recv().unwrap().unwrap(); + assert_eq!(out_leader, out_follower); + assert_eq!(out_leader, secret); + } + } + + // Reference values from https://datatracker.ietf.org/doc/html/draft-ietf-tls-tls13-vectors-06 + fn test_fixtures() -> Vec<(Vec, Vec, Vec)> { + vec![( + // SALT + from_hex_str::<32>("6f 26 15 a1 08 c7 02 c5 67 8f 54 fc 9d ba b6 97 16 c0 76 18 9c 48 25 0c eb ea c3 57 6c 36 11 ba").to_vec(), + // IKM + from_hex_str::<32>("81 51 d1 46 4c 1b 55 53 36 23 b9 c2 24 6a 6a 0e 6e 7e 18 50 63 e1 4a fd af f0 b6 e1 c6 1a 86 42").to_vec(), + // SECRET + from_hex_str::<32>("5b 4f 96 5d f0 3c 68 2c 46 e6 ee 86 c3 11 63 66 15 a1 d2 bb b2 43 45 c2 52 05 95 3c 87 9e 8d 06").to_vec(), + ), + ( + // SALT + from_hex_str::<32>("c8 61 57 19 e2 40 37 47 b6 10 76 2c 72 b8 f4 da 5c 60 99 57 65 d4 04 a9 d0 06 b9 b0 72 7b a5 83").to_vec(), + // IKM + from_hex_str::<32>("00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00").to_vec(), + // SECRET + from_hex_str::<32>("5c 79 d1 69 42 4e 26 2b 56 32 03 62 7b e4 eb 51 03 3f 58 8c 43 c9 ce 03 73 37 2d bc bc 01 85 a7").to_vec(), + ), + ( + // SALT + from_hex_str::<32>("9e fc 79 87 0b 08 c4 c6 51 20 52 50 af 9b 83 04 79 11 b7 83 d5 d7 67 8d 7c cc e7 18 18 9e a2 ec").to_vec(), + // IKM + from_hex_str::<32>("b0 66 a1 5b c1 aa ee f8 79 0e 0b 02 e6 2f 82 dc 44 64 46 e3 7d 6d 61 22 b0 d3 b9 94 ef 11 dd 3c").to_vec(), + // SECRET + from_hex_str::<32>("ea d8 b8 c5 9a 15 df 29 d7 9f a4 ac 31 d5 f7 c9 0e 2e 5c 87 d9 ea fe d1 fe 69 16 cf 2f 29 37 34").to_vec(), + ) + ] + } + + fn from_hex_str(s: &str) -> [u8; N] { + let bytes: Vec = hex::decode(s.split_whitespace().collect::()).unwrap(); + bytes + .try_into() + .expect("Hex string length does not match array size") + } +} diff --git a/crates/components/hmac-sha256/src/kdf/extract/normal.rs b/crates/components/hmac-sha256/src/kdf/extract/normal.rs new file mode 100644 index 0000000000..c6a80bb356 --- /dev/null +++ b/crates/components/hmac-sha256/src/kdf/extract/normal.rs @@ -0,0 +1,76 @@ +use mpz_vm_core::{ + memory::{ + binary::{Binary, U8}, + Array, MemoryExt, Vector, ViewExt, + }, + Vm, +}; + +use crate::{hmac::normal::HmacNormal, FError}; + +/// Functionality for HKDF-Extract computation with private salt and public +/// IKM. +#[derive(Debug)] +pub(crate) struct HkdfExtract { + hmac: HmacNormal, + output: Vector, + state: State, +} + +impl HkdfExtract { + /// Allocates a new HKDF-Extract with the given `ikm` and `hmac` + /// instantiated with the salt. + pub(crate) fn alloc( + vm: &mut dyn Vm, + ikm: [u8; 32], + mut hmac: HmacNormal, + ) -> Result { + let msg: Array = vm.alloc().map_err(FError::vm)?; + vm.mark_public(msg).map_err(FError::vm)?; + vm.assign(msg, ikm).map_err(FError::vm)?; + vm.commit(msg).map_err(FError::vm)?; + + hmac.set_msg(vm, &[msg.into()])?; + + Ok(Self { + output: hmac.output()?.into(), + hmac, + state: State::Setup {}, + }) + } + + /// Whether this functionality needs to be flushed. + pub(crate) fn wants_flush(&self) -> bool { + matches!(self.state, State::Setup) || self.hmac.wants_flush() + } + + /// Flushes the functionality. + pub(crate) fn flush(&mut self) -> Result<(), FError> { + self.hmac.flush()?; + + if let State::Setup = &mut self.state { + if self.hmac.is_complete() { + self.state = State::Complete; + } + } + + Ok(()) + } + + /// Returns HKDF-Extract output. + pub(crate) fn output(&self) -> Vector { + self.output + } + + /// Whether this functionality is complete. + pub(crate) fn is_complete(&self) -> bool { + matches!(self.state, State::Complete) + } +} + +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub(crate) enum State { + Setup, + Complete, +} diff --git a/crates/components/hmac-sha256/src/kdf/extract/reduced.rs b/crates/components/hmac-sha256/src/kdf/extract/reduced.rs new file mode 100644 index 0000000000..2008f41120 --- /dev/null +++ b/crates/components/hmac-sha256/src/kdf/extract/reduced.rs @@ -0,0 +1,67 @@ +use crate::{hmac::reduced::HmacReduced, FError}; + +use mpz_vm_core::{ + memory::{ + binary::{Binary, U8}, + Vector, + }, + Vm, +}; + +/// Functionality for HKDF-Extract computation with private salt and public +/// IKM. +#[derive(Debug)] +pub(crate) struct HkdfExtract { + hmac: HmacReduced, + output: Vector, + state: State, +} + +impl HkdfExtract { + /// Allocates a new HKDF-Extract with the given `ikm` and `hmac` + /// instantiated with the salt. + pub(crate) fn alloc(ikm: [u8; 32], mut hmac: HmacReduced) -> Result { + hmac.set_msg(&ikm)?; + + Ok(Self { + output: hmac.output().into(), + hmac, + state: State::Setup, + }) + } + + /// Whether this functionality needs to be flushed. + pub(crate) fn wants_flush(&self) -> bool { + matches!(self.state, State::Setup) || self.hmac.wants_flush() + } + + /// Flushes the functionality. + pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), FError> { + self.hmac.flush(vm)?; + + if let State::Setup = &mut self.state { + if self.hmac.is_complete() { + self.state = State::Complete; + } + } + + Ok(()) + } + + /// Returns HKDF-Extract output. + pub(crate) fn output(&self) -> Vector { + self.output + } + + /// Whether this functionality is complete. + pub(crate) fn is_complete(&self) -> bool { + matches!(self.state, State::Complete) + } +} + +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub(crate) enum State { + Setup, + Complete, +} diff --git a/crates/components/hmac-sha256/src/kdf/mod.rs b/crates/components/hmac-sha256/src/kdf/mod.rs new file mode 100644 index 0000000000..807cbf93cf --- /dev/null +++ b/crates/components/hmac-sha256/src/kdf/mod.rs @@ -0,0 +1,2 @@ +pub(crate) mod expand; +pub(crate) mod extract; diff --git a/crates/components/hmac-sha256/src/lib.rs b/crates/components/hmac-sha256/src/lib.rs index 2e460bf43d..9856743e47 100644 --- a/crates/components/hmac-sha256/src/lib.rs +++ b/crates/components/hmac-sha256/src/lib.rs @@ -1,4 +1,5 @@ -//! This crate contains the protocol for computing TLS 1.2 SHA-256 HMAC PRF. +//! MPC protocols for computing HMAC-SHA-256-based PRF for TLS 1.2 and key +//! schedule for TLS 1.3. #![deny(missing_docs, unreachable_pub, unused_must_use)] #![deny(clippy::all)] @@ -12,36 +13,15 @@ mod config; pub use config::Mode; mod error; -pub use error::PrfError; +pub use error::FError; +mod kdf; mod prf; -pub use prf::MpcPrf; +mod tls12; +mod tls13; -use mpz_vm_core::memory::{binary::U8, Array}; - -/// PRF output. -#[derive(Debug, Clone, Copy)] -pub struct PrfOutput { - /// TLS session keys. - pub keys: SessionKeys, - /// Client finished verify data. - pub cf_vd: Array, - /// Server finished verify data. - pub sf_vd: Array, -} - -/// Session keys computed by the PRF. -#[derive(Debug, Clone, Copy)] -pub struct SessionKeys { - /// Client write key. - pub client_write_key: Array, - /// Server write key. - pub server_write_key: Array, - /// Client IV. - pub client_iv: Array, - /// Server IV. - pub server_iv: Array, -} +pub use tls12::{PrfOutput, SessionKeys, Tls12Prf}; +pub use tls13::{ApplicationKeys, HandshakeKeys, Role, Tls13KeySched}; fn sha256(mut state: [u32; 8], pos: usize, msg: &[u8]) -> [u32; 8] { use sha2::{ @@ -60,6 +40,20 @@ fn sha256(mut state: [u32; 8], pos: usize, msg: &[u8]) -> [u32; 8] { state } +pub(crate) fn compress_256(mut state: [u32; 8], msg: &[u8]) -> [u32; 8] { + use sha2::{ + compress256, + digest::{ + block_buffer::{BlockBuffer, Eager}, + generic_array::typenum::U64, + }, + }; + + let mut buffer = BlockBuffer::::default(); + buffer.digest_blocks(msg, |b| compress256(&mut state, b)); + state +} + fn state_to_bytes(input: [u32; 8]) -> [u8; 32] { let mut output = [0_u8; 32]; for (k, byte_chunk) in input.iter().enumerate() { @@ -68,202 +62,3 @@ fn state_to_bytes(input: [u32; 8]) -> [u8; 32] { } output } - -#[cfg(test)] -mod tests { - use crate::{ - test_utils::{mock_vm, prf_cf_vd, prf_keys, prf_ms, prf_sf_vd}, - Mode, MpcPrf, SessionKeys, - }; - use mpz_common::context::test_st_context; - use mpz_vm_core::{ - memory::{binary::U8, Array, MemoryExt, ViewExt}, - Execute, - }; - use rand::{rngs::StdRng, Rng, SeedableRng}; - - #[tokio::test] - async fn test_prf_reduced() { - let mode = Mode::Reduced; - test_prf(mode).await; - } - - #[tokio::test] - async fn test_prf_normal() { - let mode = Mode::Normal; - test_prf(mode).await; - } - - async fn test_prf(mode: Mode) { - let mut rng = StdRng::seed_from_u64(1); - // Test input - let pms: [u8; 32] = rng.random(); - let client_random: [u8; 32] = rng.random(); - let server_random: [u8; 32] = rng.random(); - - let cf_hs_hash: [u8; 32] = rng.random(); - let sf_hs_hash: [u8; 32] = rng.random(); - - // Expected output - let ms_expected = prf_ms(pms, client_random, server_random); - - let [cwk_expected, swk_expected, civ_expected, siv_expected] = - prf_keys(ms_expected, client_random, server_random); - - let cwk_expected: [u8; 16] = cwk_expected.try_into().unwrap(); - let swk_expected: [u8; 16] = swk_expected.try_into().unwrap(); - let civ_expected: [u8; 4] = civ_expected.try_into().unwrap(); - let siv_expected: [u8; 4] = siv_expected.try_into().unwrap(); - - let cf_vd_expected = prf_cf_vd(ms_expected, cf_hs_hash); - let sf_vd_expected = prf_sf_vd(ms_expected, sf_hs_hash); - - let cf_vd_expected: [u8; 12] = cf_vd_expected.try_into().unwrap(); - let sf_vd_expected: [u8; 12] = sf_vd_expected.try_into().unwrap(); - - // Set up vm and prf - let (mut ctx_a, mut ctx_b) = test_st_context(128); - let (mut leader, mut follower) = mock_vm(); - - let leader_pms: Array = leader.alloc().unwrap(); - leader.mark_public(leader_pms).unwrap(); - leader.assign(leader_pms, pms).unwrap(); - leader.commit(leader_pms).unwrap(); - - let follower_pms: Array = follower.alloc().unwrap(); - follower.mark_public(follower_pms).unwrap(); - follower.assign(follower_pms, pms).unwrap(); - follower.commit(follower_pms).unwrap(); - - let mut prf_leader = MpcPrf::new(mode); - let mut prf_follower = MpcPrf::new(mode); - - let leader_prf_out = prf_leader.alloc(&mut leader, leader_pms).unwrap(); - let follower_prf_out = prf_follower.alloc(&mut follower, follower_pms).unwrap(); - - // client_random and server_random - prf_leader.set_client_random(client_random).unwrap(); - prf_follower.set_client_random(client_random).unwrap(); - - prf_leader.set_server_random(server_random).unwrap(); - prf_follower.set_server_random(server_random).unwrap(); - - let SessionKeys { - client_write_key: cwk_leader, - server_write_key: swk_leader, - client_iv: civ_leader, - server_iv: siv_leader, - } = leader_prf_out.keys; - - let mut cwk_leader = leader.decode(cwk_leader).unwrap(); - let mut swk_leader = leader.decode(swk_leader).unwrap(); - let mut civ_leader = leader.decode(civ_leader).unwrap(); - let mut siv_leader = leader.decode(siv_leader).unwrap(); - - let SessionKeys { - client_write_key: cwk_follower, - server_write_key: swk_follower, - client_iv: civ_follower, - server_iv: siv_follower, - } = follower_prf_out.keys; - - let mut cwk_follower = follower.decode(cwk_follower).unwrap(); - let mut swk_follower = follower.decode(swk_follower).unwrap(); - let mut civ_follower = follower.decode(civ_follower).unwrap(); - let mut siv_follower = follower.decode(siv_follower).unwrap(); - - while prf_leader.wants_flush() || prf_follower.wants_flush() { - tokio::try_join!( - async { - prf_leader.flush(&mut leader).unwrap(); - leader.execute_all(&mut ctx_a).await - }, - async { - prf_follower.flush(&mut follower).unwrap(); - follower.execute_all(&mut ctx_b).await - } - ) - .unwrap(); - } - - let cwk_leader = cwk_leader.try_recv().unwrap().unwrap(); - let swk_leader = swk_leader.try_recv().unwrap().unwrap(); - let civ_leader = civ_leader.try_recv().unwrap().unwrap(); - let siv_leader = siv_leader.try_recv().unwrap().unwrap(); - - let cwk_follower = cwk_follower.try_recv().unwrap().unwrap(); - let swk_follower = swk_follower.try_recv().unwrap().unwrap(); - let civ_follower = civ_follower.try_recv().unwrap().unwrap(); - let siv_follower = siv_follower.try_recv().unwrap().unwrap(); - - assert_eq!(cwk_leader, cwk_follower); - assert_eq!(swk_leader, swk_follower); - assert_eq!(civ_leader, civ_follower); - assert_eq!(siv_leader, siv_follower); - - assert_eq!(cwk_leader, cwk_expected); - assert_eq!(swk_leader, swk_expected); - assert_eq!(civ_leader, civ_expected); - assert_eq!(siv_leader, siv_expected); - - // client finished - prf_leader.set_cf_hash(cf_hs_hash).unwrap(); - prf_follower.set_cf_hash(cf_hs_hash).unwrap(); - - let cf_vd_leader = leader_prf_out.cf_vd; - let cf_vd_follower = follower_prf_out.cf_vd; - - let mut cf_vd_leader = leader.decode(cf_vd_leader).unwrap(); - let mut cf_vd_follower = follower.decode(cf_vd_follower).unwrap(); - - while prf_leader.wants_flush() || prf_follower.wants_flush() { - tokio::try_join!( - async { - prf_leader.flush(&mut leader).unwrap(); - leader.execute_all(&mut ctx_a).await - }, - async { - prf_follower.flush(&mut follower).unwrap(); - follower.execute_all(&mut ctx_b).await - } - ) - .unwrap(); - } - - let cf_vd_leader = cf_vd_leader.try_recv().unwrap().unwrap(); - let cf_vd_follower = cf_vd_follower.try_recv().unwrap().unwrap(); - - assert_eq!(cf_vd_leader, cf_vd_follower); - assert_eq!(cf_vd_leader, cf_vd_expected); - - // server finished - prf_leader.set_sf_hash(sf_hs_hash).unwrap(); - prf_follower.set_sf_hash(sf_hs_hash).unwrap(); - - let sf_vd_leader = leader_prf_out.sf_vd; - let sf_vd_follower = follower_prf_out.sf_vd; - - let mut sf_vd_leader = leader.decode(sf_vd_leader).unwrap(); - let mut sf_vd_follower = follower.decode(sf_vd_follower).unwrap(); - - while prf_leader.wants_flush() || prf_follower.wants_flush() { - tokio::try_join!( - async { - prf_leader.flush(&mut leader).unwrap(); - leader.execute_all(&mut ctx_a).await - }, - async { - prf_follower.flush(&mut follower).unwrap(); - follower.execute_all(&mut ctx_b).await - } - ) - .unwrap(); - } - - let sf_vd_leader = sf_vd_leader.try_recv().unwrap().unwrap(); - let sf_vd_follower = sf_vd_follower.try_recv().unwrap().unwrap(); - - assert_eq!(sf_vd_leader, sf_vd_follower); - assert_eq!(sf_vd_leader, sf_vd_expected); - } -} diff --git a/crates/components/hmac-sha256/src/prf.rs b/crates/components/hmac-sha256/src/prf.rs deleted file mode 100644 index f928f48bdb..0000000000 --- a/crates/components/hmac-sha256/src/prf.rs +++ /dev/null @@ -1,407 +0,0 @@ -use crate::{ - hmac::{IPAD, OPAD}, - Mode, PrfError, PrfOutput, -}; -use mpz_circuits::{circuits::xor, Circuit, CircuitBuilder}; -use mpz_hash::sha256::Sha256; -use mpz_vm_core::{ - memory::{ - binary::{Binary, U8}, - Array, MemoryExt, StaticSize, Vector, ViewExt, - }, - Call, CallableExt, Vm, -}; -use std::{fmt::Debug, sync::Arc}; -use tracing::instrument; - -mod state; -use state::State; - -mod function; -use function::Prf; - -/// MPC PRF for computing TLS 1.2 HMAC-SHA256 PRF. -#[derive(Debug)] -pub struct MpcPrf { - mode: Mode, - state: State, -} - -impl MpcPrf { - /// Creates a new instance of the PRF. - /// - /// # Arguments - /// - /// `mode` - The PRF mode. - pub fn new(mode: Mode) -> MpcPrf { - Self { - mode, - state: State::Initialized, - } - } - - /// Allocates resources for the PRF. - /// - /// # Arguments - /// - /// * `vm` - Virtual machine. - /// * `pms` - The pre-master secret. - #[instrument(level = "debug", skip_all, err)] - pub fn alloc( - &mut self, - vm: &mut dyn Vm, - pms: Array, - ) -> Result { - let State::Initialized = self.state.take() else { - return Err(PrfError::state("PRF not in initialized state")); - }; - - let mode = self.mode; - let pms: Vector = pms.into(); - - let outer_partial_pms = compute_partial(vm, pms, OPAD)?; - let inner_partial_pms = compute_partial(vm, pms, IPAD)?; - - let master_secret = - Prf::alloc_master_secret(mode, vm, outer_partial_pms, inner_partial_pms)?; - let ms = master_secret.output(); - let ms = merge_outputs(vm, ms, 48)?; - - let outer_partial_ms = compute_partial(vm, ms, OPAD)?; - let inner_partial_ms = compute_partial(vm, ms, IPAD)?; - - let key_expansion = - Prf::alloc_key_expansion(mode, vm, outer_partial_ms.clone(), inner_partial_ms.clone())?; - let client_finished = Prf::alloc_client_finished( - mode, - vm, - outer_partial_ms.clone(), - inner_partial_ms.clone(), - )?; - let server_finished = Prf::alloc_server_finished( - mode, - vm, - outer_partial_ms.clone(), - inner_partial_ms.clone(), - )?; - - self.state = State::SessionKeys { - client_random: None, - master_secret, - key_expansion, - client_finished, - server_finished, - }; - - self.state.prf_output(vm) - } - - /// Sets the client random. - /// - /// # Arguments - /// - /// * `random` - The client random. - #[instrument(level = "debug", skip_all, err)] - pub fn set_client_random(&mut self, random: [u8; 32]) -> Result<(), PrfError> { - let State::SessionKeys { client_random, .. } = &mut self.state else { - return Err(PrfError::state("PRF not set up")); - }; - - *client_random = Some(random); - Ok(()) - } - - /// Sets the server random. - /// - /// # Arguments - /// - /// * `random` - The server random. - #[instrument(level = "debug", skip_all, err)] - pub fn set_server_random(&mut self, random: [u8; 32]) -> Result<(), PrfError> { - let State::SessionKeys { - client_random, - master_secret, - key_expansion, - .. - } = &mut self.state - else { - return Err(PrfError::state("PRF not set up")); - }; - - let client_random = client_random.expect("Client random should have been set by now"); - let server_random = random; - - let mut seed_ms = client_random.to_vec(); - seed_ms.extend_from_slice(&server_random); - master_secret.set_start_seed(seed_ms); - - let mut seed_ke = server_random.to_vec(); - seed_ke.extend_from_slice(&client_random); - key_expansion.set_start_seed(seed_ke); - - Ok(()) - } - - /// Sets the client finished handshake hash. - /// - /// # Arguments - /// - /// * `handshake_hash` - The handshake transcript hash. - #[instrument(level = "debug", skip_all, err)] - pub fn set_cf_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), PrfError> { - let State::ClientFinished { - client_finished, .. - } = &mut self.state - else { - return Err(PrfError::state("PRF not in client finished state")); - }; - - let seed_cf = handshake_hash.to_vec(); - client_finished.set_start_seed(seed_cf); - - Ok(()) - } - - /// Sets the server finished handshake hash. - /// - /// # Arguments - /// - /// * `handshake_hash` - The handshake transcript hash. - #[instrument(level = "debug", skip_all, err)] - pub fn set_sf_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), PrfError> { - let State::ServerFinished { server_finished } = &mut self.state else { - return Err(PrfError::state("PRF not in server finished state")); - }; - - let seed_sf = handshake_hash.to_vec(); - server_finished.set_start_seed(seed_sf); - - Ok(()) - } - - /// Returns if the PRF needs to be flushed. - pub fn wants_flush(&self) -> bool { - match &self.state { - State::Initialized => false, - State::SessionKeys { - master_secret, - key_expansion, - .. - } => master_secret.wants_flush() || key_expansion.wants_flush(), - State::ClientFinished { - client_finished, .. - } => client_finished.wants_flush(), - State::ServerFinished { server_finished } => server_finished.wants_flush(), - State::Complete => false, - State::Error => false, - } - } - - /// Flushes the PRF. - pub fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { - self.state = match self.state.take() { - State::SessionKeys { - client_random, - mut master_secret, - mut key_expansion, - client_finished, - server_finished, - } => { - master_secret.flush(vm)?; - key_expansion.flush(vm)?; - - if !master_secret.wants_flush() && !key_expansion.wants_flush() { - State::ClientFinished { - client_finished, - server_finished, - } - } else { - State::SessionKeys { - client_random, - master_secret, - key_expansion, - client_finished, - server_finished, - } - } - } - State::ClientFinished { - mut client_finished, - server_finished, - } => { - client_finished.flush(vm)?; - - if !client_finished.wants_flush() { - State::ServerFinished { server_finished } - } else { - State::ClientFinished { - client_finished, - server_finished, - } - } - } - State::ServerFinished { - mut server_finished, - } => { - server_finished.flush(vm)?; - - if !server_finished.wants_flush() { - State::Complete - } else { - State::ServerFinished { server_finished } - } - } - other => other, - }; - - Ok(()) - } -} - -/// Depending on the provided `mask` computes and returns `outer_partial` or -/// `inner_partial` for HMAC-SHA256. -/// -/// # Arguments -/// -/// * `vm` - Virtual machine. -/// * `key` - Key to pad and xor. -/// * `mask`- Mask used for padding. -fn compute_partial( - vm: &mut dyn Vm, - key: Vector, - mask: [u8; 64], -) -> Result { - let xor = Arc::new(xor(8 * 64)); - - let additional_len = 64 - key.len(); - let padding = vec![0_u8; additional_len]; - - let padding_ref: Vector = vm.alloc_vec(additional_len).map_err(PrfError::vm)?; - vm.mark_public(padding_ref).map_err(PrfError::vm)?; - vm.assign(padding_ref, padding).map_err(PrfError::vm)?; - vm.commit(padding_ref).map_err(PrfError::vm)?; - - let mask_ref: Array = vm.alloc().map_err(PrfError::vm)?; - vm.mark_public(mask_ref).map_err(PrfError::vm)?; - vm.assign(mask_ref, mask).map_err(PrfError::vm)?; - vm.commit(mask_ref).map_err(PrfError::vm)?; - - let xor = Call::builder(xor) - .arg(key) - .arg(padding_ref) - .arg(mask_ref) - .build() - .map_err(PrfError::vm)?; - let key_padded: Vector = vm.call(xor).map_err(PrfError::vm)?; - - let mut sha = Sha256::new_with_init(vm)?; - sha.update(&key_padded); - sha.compress(vm)?; - Ok(sha) -} - -fn merge_outputs( - vm: &mut dyn Vm, - inputs: Vec>, - output_bytes: usize, -) -> Result, PrfError> { - assert!(output_bytes <= 32 * inputs.len()); - - let bits = Array::::SIZE * inputs.len(); - let circ = gen_merge_circ(bits); - - let mut builder = Call::builder(circ); - for &input in inputs.iter() { - builder = builder.arg(input); - } - let call = builder.build().map_err(PrfError::vm)?; - - let mut output: Vector = vm.call(call).map_err(PrfError::vm)?; - output.truncate(output_bytes); - Ok(output) -} - -fn gen_merge_circ(size: usize) -> Arc { - let mut builder = CircuitBuilder::new(); - let inputs = (0..size).map(|_| builder.add_input()).collect::>(); - - for input in inputs.chunks_exact(8) { - for byte in input.chunks_exact(8) { - for &feed in byte.iter() { - let output = builder.add_id_gate(feed); - builder.add_output(output); - } - } - } - - Arc::new(builder.build().expect("merge circuit is valid")) -} - -#[cfg(test)] -mod tests { - use crate::{prf::merge_outputs, test_utils::mock_vm}; - use mpz_common::context::test_st_context; - use mpz_vm_core::{ - memory::{binary::U8, Array, MemoryExt, ViewExt}, - Execute, - }; - - #[tokio::test] - async fn test_merge_outputs() { - let (mut ctx_a, mut ctx_b) = test_st_context(8); - let (mut leader, mut follower) = mock_vm(); - - let input1: [u8; 32] = std::array::from_fn(|i| i as u8); - let input2: [u8; 32] = std::array::from_fn(|i| i as u8 + 32); - - let mut expected = input1.to_vec(); - expected.extend_from_slice(&input2); - expected.truncate(48); - - // leader - let input1_leader: Array = leader.alloc().unwrap(); - let input2_leader: Array = leader.alloc().unwrap(); - - leader.mark_public(input1_leader).unwrap(); - leader.mark_public(input2_leader).unwrap(); - - leader.assign(input1_leader, input1).unwrap(); - leader.assign(input2_leader, input2).unwrap(); - - leader.commit(input1_leader).unwrap(); - leader.commit(input2_leader).unwrap(); - - let merged_leader = - merge_outputs(&mut leader, vec![input1_leader, input2_leader], 48).unwrap(); - let mut merged_leader = leader.decode(merged_leader).unwrap(); - - // follower - let input1_follower: Array = follower.alloc().unwrap(); - let input2_follower: Array = follower.alloc().unwrap(); - - follower.mark_public(input1_follower).unwrap(); - follower.mark_public(input2_follower).unwrap(); - - follower.assign(input1_follower, input1).unwrap(); - follower.assign(input2_follower, input2).unwrap(); - - follower.commit(input1_follower).unwrap(); - follower.commit(input2_follower).unwrap(); - - let merged_follower = - merge_outputs(&mut follower, vec![input1_follower, input2_follower], 48).unwrap(); - let mut merged_follower = follower.decode(merged_follower).unwrap(); - - tokio::try_join!( - leader.execute_all(&mut ctx_a), - follower.execute_all(&mut ctx_b) - ) - .unwrap(); - - let merged_leader = merged_leader.try_recv().unwrap().unwrap(); - let merged_follower = merged_follower.try_recv().unwrap().unwrap(); - - assert_eq!(merged_leader, merged_follower); - assert_eq!(merged_leader, expected); - } -} diff --git a/crates/components/hmac-sha256/src/prf/function.rs b/crates/components/hmac-sha256/src/prf/function.rs index e1e932a02f..5c90fefd47 100644 --- a/crates/components/hmac-sha256/src/prf/function.rs +++ b/crates/components/hmac-sha256/src/prf/function.rs @@ -1,11 +1,10 @@ //! Provides [`Prf`], for computing the TLS 1.2 PRF. -use crate::{Mode, PrfError}; -use mpz_hash::sha256::Sha256; +use crate::{hmac::Hmac, FError, Mode}; use mpz_vm_core::{ memory::{ binary::{Binary, U8}, - Array, + Vector, }, Vm, }; @@ -20,90 +19,107 @@ pub(crate) enum Prf { } impl Prf { + /// Allocates master secret. pub(crate) fn alloc_master_secret( mode: Mode, vm: &mut dyn Vm, - outer_partial: Sha256, - inner_partial: Sha256, - ) -> Result { + hmac: Hmac, + ) -> Result { let prf = match mode { - Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_master_secret( - vm, - outer_partial, - inner_partial, - )?), - Mode::Normal => Self::Normal(normal::PrfFunction::alloc_master_secret( - vm, - outer_partial, - inner_partial, - )?), + Mode::Reduced => { + if let Hmac::Reduced(hmac) = hmac { + Self::Reduced(reduced::PrfFunction::alloc_master_secret(vm, hmac)?) + } else { + unreachable!("modes always match"); + } + } + Mode::Normal => { + if let Hmac::Normal(hmac) = hmac { + Self::Normal(normal::PrfFunction::alloc_master_secret(vm, hmac)?) + } else { + unreachable!("modes always match"); + } + } }; Ok(prf) } + /// Allocates key expansion. pub(crate) fn alloc_key_expansion( mode: Mode, vm: &mut dyn Vm, - outer_partial: Sha256, - inner_partial: Sha256, - ) -> Result { + hmac: Hmac, + ) -> Result { let prf = match mode { - Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_key_expansion( - vm, - outer_partial, - inner_partial, - )?), - Mode::Normal => Self::Normal(normal::PrfFunction::alloc_key_expansion( - vm, - outer_partial, - inner_partial, - )?), + Mode::Reduced => { + if let Hmac::Reduced(hmac) = hmac { + Self::Reduced(reduced::PrfFunction::alloc_key_expansion(vm, hmac)?) + } else { + unreachable!("modes always match"); + } + } + Mode::Normal => { + if let Hmac::Normal(hmac) = hmac { + Self::Normal(normal::PrfFunction::alloc_key_expansion(vm, hmac)?) + } else { + unreachable!("modes always match"); + } + } }; Ok(prf) } + /// Allocates client finished. pub(crate) fn alloc_client_finished( config: Mode, vm: &mut dyn Vm, - outer_partial: Sha256, - inner_partial: Sha256, - ) -> Result { + hmac: Hmac, + ) -> Result { let prf = match config { - Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_client_finished( - vm, - outer_partial, - inner_partial, - )?), - Mode::Normal => Self::Normal(normal::PrfFunction::alloc_client_finished( - vm, - outer_partial, - inner_partial, - )?), + Mode::Reduced => { + if let Hmac::Reduced(hmac) = hmac { + Self::Reduced(reduced::PrfFunction::alloc_client_finished(vm, hmac)?) + } else { + unreachable!("modes always match"); + } + } + Mode::Normal => { + if let Hmac::Normal(hmac) = hmac { + Self::Normal(normal::PrfFunction::alloc_client_finished(vm, hmac)?) + } else { + unreachable!("modes always match"); + } + } }; Ok(prf) } + /// Allocates server finished. pub(crate) fn alloc_server_finished( config: Mode, vm: &mut dyn Vm, - outer_partial: Sha256, - inner_partial: Sha256, - ) -> Result { + hmac: Hmac, + ) -> Result { let prf = match config { - Mode::Reduced => Self::Reduced(reduced::PrfFunction::alloc_server_finished( - vm, - outer_partial, - inner_partial, - )?), - Mode::Normal => Self::Normal(normal::PrfFunction::alloc_server_finished( - vm, - outer_partial, - inner_partial, - )?), + Mode::Reduced => { + if let Hmac::Reduced(hmac) = hmac { + Self::Reduced(reduced::PrfFunction::alloc_server_finished(vm, hmac)?) + } else { + unreachable!("modes always match"); + } + } + Mode::Normal => { + if let Hmac::Normal(hmac) = hmac { + Self::Normal(normal::PrfFunction::alloc_server_finished(vm, hmac)?) + } else { + unreachable!("modes always match"); + } + } }; Ok(prf) } + /// Whether this functionality needs to be flushed. pub(crate) fn wants_flush(&self) -> bool { match self { Prf::Reduced(prf) => prf.wants_flush(), @@ -111,13 +127,15 @@ impl Prf { } } - pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { + /// Flushes the functionality. + pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), FError> { match self { Prf::Reduced(prf) => prf.flush(vm), Prf::Normal(prf) => prf.flush(vm), } } + /// Sets the seed. pub(crate) fn set_start_seed(&mut self, seed: Vec) { match self { Prf::Reduced(prf) => prf.set_start_seed(seed), @@ -125,18 +143,28 @@ impl Prf { } } - pub(crate) fn output(&self) -> Vec> { + /// Returns the PRF output. + pub(crate) fn output(&self) -> Vector { match self { Prf::Reduced(prf) => prf.output(), Prf::Normal(prf) => prf.output(), } } + + /// Whether this functionality is complete. + pub(crate) fn is_complete(&self) -> bool { + match self { + Prf::Reduced(prf) => prf.is_complete(), + Prf::Normal(prf) => prf.is_complete(), + } + } } #[cfg(test)] mod tests { use crate::{ - prf::{compute_partial, function::Prf}, + hmac::Hmac, + prf::function::Prf, test_utils::{mock_vm, phash}, Mode, }; @@ -146,23 +174,13 @@ mod tests { Execute, }; use rand::{rngs::ThreadRng, Rng}; + use rstest::*; - const IPAD: [u8; 64] = [0x36; 64]; - const OPAD: [u8; 64] = [0x5c; 64]; - - #[tokio::test] - async fn test_phash_reduced() { - let mode = Mode::Reduced; - test_phash(mode).await; - } - + #[rstest] + #[case::normal(Mode::Normal)] + #[case::reduced(Mode::Reduced)] #[tokio::test] - async fn test_phash_normal() { - let mode = Mode::Normal; - test_phash(mode).await; - } - - async fn test_phash(mode: Mode) { + async fn test_phash(#[case] mode: Mode) { let mut rng = ThreadRng::default(); let (mut ctx_a, mut ctx_b) = test_st_context(8); @@ -170,6 +188,7 @@ mod tests { let key: [u8; 32] = rng.random(); let start_seed: Vec = vec![42; 64]; + let output_len = 48; let mut label_seed = b"master secret".to_vec(); label_seed.extend_from_slice(&start_seed); @@ -180,48 +199,25 @@ mod tests { leader.assign(leader_key, key).unwrap(); leader.commit(leader_key).unwrap(); - let outer_partial_leader = compute_partial(&mut leader, leader_key.into(), OPAD).unwrap(); - let inner_partial_leader = compute_partial(&mut leader, leader_key.into(), IPAD).unwrap(); + let leader_hmac = Hmac::alloc(&mut leader, leader_key.into(), mode).unwrap(); - let mut prf_leader = Prf::alloc_master_secret( - mode, - &mut leader, - outer_partial_leader, - inner_partial_leader, - ) - .unwrap(); + let mut prf_leader = Prf::alloc_master_secret(mode, &mut leader, leader_hmac).unwrap(); prf_leader.set_start_seed(start_seed.clone()); - let mut prf_out_leader = vec![]; - for p in prf_leader.output() { - let p_out = leader.decode(p).unwrap(); - prf_out_leader.push(p_out) - } + let mut prf_out_leader = leader.decode(prf_leader.output()).unwrap(); let follower_key: Array = follower.alloc().unwrap(); follower.mark_public(follower_key).unwrap(); follower.assign(follower_key, key).unwrap(); follower.commit(follower_key).unwrap(); - let outer_partial_follower = - compute_partial(&mut follower, follower_key.into(), OPAD).unwrap(); - let inner_partial_follower = - compute_partial(&mut follower, follower_key.into(), IPAD).unwrap(); + let follower_hmac = Hmac::alloc(&mut follower, follower_key.into(), mode).unwrap(); - let mut prf_follower = Prf::alloc_master_secret( - mode, - &mut follower, - outer_partial_follower, - inner_partial_follower, - ) - .unwrap(); + let mut prf_follower = + Prf::alloc_master_secret(mode, &mut follower, follower_hmac).unwrap(); prf_follower.set_start_seed(start_seed.clone()); - let mut prf_out_follower = vec![]; - for p in prf_follower.output() { - let p_out = follower.decode(p).unwrap(); - prf_out_follower.push(p_out) - } + let mut prf_out_follower = follower.decode(prf_follower.output()).unwrap(); while prf_leader.wants_flush() || prf_follower.wants_flush() { tokio::try_join!( @@ -237,19 +233,10 @@ mod tests { .unwrap(); } - assert_eq!(prf_out_leader.len(), 2); - assert_eq!(prf_out_leader.len(), prf_out_follower.len()); - - let prf_result_leader: Vec = prf_out_leader - .iter_mut() - .flat_map(|p| p.try_recv().unwrap().unwrap()) - .collect(); - let prf_result_follower: Vec = prf_out_follower - .iter_mut() - .flat_map(|p| p.try_recv().unwrap().unwrap()) - .collect(); + let prf_result_leader: Vec = prf_out_leader.try_recv().unwrap().unwrap(); + let prf_result_follower: Vec = prf_out_follower.try_recv().unwrap().unwrap(); - let expected = phash(key.to_vec(), &label_seed, iterations); + let expected = &phash(key.to_vec(), &label_seed, iterations)[..output_len]; assert_eq!(prf_result_leader, prf_result_follower); assert_eq!(prf_result_leader, expected) diff --git a/crates/components/hmac-sha256/src/prf/function/normal.rs b/crates/components/hmac-sha256/src/prf/function/normal.rs index ec931f24ef..9b811e5184 100644 --- a/crates/components/hmac-sha256/src/prf/function/normal.rs +++ b/crates/components/hmac-sha256/src/prf/function/normal.rs @@ -1,24 +1,30 @@ -//! Computes the whole PRF in MPC. +//! TLS 1.2 PRF function. -use crate::{hmac::hmac_sha256, PrfError}; -use mpz_hash::sha256::Sha256; +use crate::{hmac::normal::HmacNormal, tls12::merge_vectors, FError}; use mpz_vm_core::{ memory::{ binary::{Binary, U8}, - Array, MemoryExt, Vector, ViewExt, + MemoryExt, Vector, ViewExt, }, Vm, }; #[derive(Debug)] pub(crate) struct PrfFunction { - // The label, e.g. "master secret". + // The human-readable label, e.g. "master secret". label: &'static [u8], state: State, - // The start seed and the label, e.g. client_random + server_random + "master_secret". + /// The start seed and the label, e.g. client_random + server_random + + /// "master_secret". start_seed_label: Option>, - a: Vec, - p: Vec, + seed_label_ref: Vector, + /// A_Hash functionalities for each iteration instantiated with the PRF + /// secret. + a_hash: Vec, + /// P_Hash functionalities for each iteration instantiated with the PRF + /// secret. + p_hash: Vec, + output: Vector, } impl PrfFunction { @@ -27,64 +33,128 @@ impl PrfFunction { const CF_LABEL: &[u8] = b"client finished"; const SF_LABEL: &[u8] = b"server finished"; + /// Allocates master secret. pub(crate) fn alloc_master_secret( vm: &mut dyn Vm, - outer_partial: Sha256, - inner_partial: Sha256, - ) -> Result { - Self::alloc(vm, Self::MS_LABEL, outer_partial, inner_partial, 48, 64) + hmac: HmacNormal, + ) -> Result { + Self::alloc(vm, Self::MS_LABEL, hmac, 48, 64) } + /// Allocates key expansion. pub(crate) fn alloc_key_expansion( vm: &mut dyn Vm, - outer_partial: Sha256, - inner_partial: Sha256, - ) -> Result { - Self::alloc(vm, Self::KEY_LABEL, outer_partial, inner_partial, 40, 64) + hmac: HmacNormal, + ) -> Result { + Self::alloc(vm, Self::KEY_LABEL, hmac, 40, 64) } + /// Allocates client finished. pub(crate) fn alloc_client_finished( vm: &mut dyn Vm, - outer_partial: Sha256, - inner_partial: Sha256, - ) -> Result { - Self::alloc(vm, Self::CF_LABEL, outer_partial, inner_partial, 12, 32) + hmac: HmacNormal, + ) -> Result { + Self::alloc(vm, Self::CF_LABEL, hmac, 12, 32) } + /// Allocates server finished. pub(crate) fn alloc_server_finished( vm: &mut dyn Vm, - outer_partial: Sha256, - inner_partial: Sha256, - ) -> Result { - Self::alloc(vm, Self::SF_LABEL, outer_partial, inner_partial, 12, 32) + hmac: HmacNormal, + ) -> Result { + Self::alloc(vm, Self::SF_LABEL, hmac, 12, 32) } - pub(crate) fn wants_flush(&self) -> bool { - let is_computing = match self.state { - State::Computing => true, - State::Finished => false, - }; - is_computing && self.start_seed_label.is_some() - } + /// Allocates a new PRF with the given `hmac` instantiated with the PRF + /// secret. + fn alloc( + vm: &mut dyn Vm, + label: &'static [u8], + hmac: HmacNormal, + output_len: usize, + seed_len: usize, + ) -> Result { + assert!(output_len > 0, "cannot compute 0 bytes for prf"); + + let iterations = output_len.div_ceil(32); + + let msg_len_a = label.len() + seed_len; + let seed_label_ref: Vector = vm.alloc_vec(msg_len_a).map_err(FError::vm)?; + vm.mark_public(seed_label_ref).map_err(FError::vm)?; - pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { - if let State::Computing = self.state { - let a = self.a.first().expect("prf should be allocated"); - let msg = *a.msg.first().expect("message for prf should be present"); + let mut msg_a = seed_label_ref; + + let mut p_out: Vec> = Vec::with_capacity(iterations); + let mut a_hash = Vec::with_capacity(iterations); + let mut p_hash = Vec::with_capacity(iterations); - let msg_value = self - .start_seed_label - .clone() - .expect("Start seed should have been set"); + for _ in 0..iterations { + let mut a = HmacNormal::from_other(&hmac)?; + a.set_msg(vm, &[msg_a])?; + let a_out: Vector = a.output()?.into(); + msg_a = a_out; + a_hash.push(a); + + let mut p = HmacNormal::from_other(&hmac)?; + p.set_msg(vm, &[a_out, seed_label_ref])?; + p_out.push(p.output()?.into()); + p_hash.push(p); + } - vm.assign(msg, msg_value).map_err(PrfError::vm)?; - vm.commit(msg).map_err(PrfError::vm)?; + Ok(Self { + label, + state: State::WantsSeed, + start_seed_label: None, + seed_label_ref, + a_hash, + p_hash, + output: merge_vectors(vm, p_out, output_len)?, + }) + } - self.state = State::Finished; + /// Whether this functionality needs to be flushed. + pub(crate) fn wants_flush(&self) -> bool { + let state_wants_flush = match self.state { + State::WantsSeed => self.start_seed_label.is_some(), + _ => false, + }; + state_wants_flush + || self.a_hash.iter().any(|h| h.wants_flush()) + || self.p_hash.iter().any(|h| h.wants_flush()) + } + + /// Flushes the functionality. + pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), FError> { + // Flush every HMAC functionality. + self.a_hash.iter_mut().try_for_each(|h| h.flush())?; + self.p_hash.iter_mut().try_for_each(|h| h.flush())?; + + match self.state { + State::WantsSeed => { + if let Some(seed) = &self.start_seed_label { + vm.assign(self.seed_label_ref, seed.clone()) + .map_err(FError::vm)?; + vm.commit(self.seed_label_ref).map_err(FError::vm)?; + + self.state = State::SeedSet; + // Recurse. + self.flush(vm)?; + } + } + State::SeedSet => { + // We are complete when all HMACs are complete. + if self.a_hash.iter().all(|h| h.is_complete()) + && self.p_hash.iter().all(|h| h.is_complete()) + { + self.state = State::Complete; + } + } + _ => (), } Ok(()) } + /// Sets the seed. pub(crate) fn set_start_seed(&mut self, seed: Vec) { let mut start_seed_label = self.label.to_vec(); start_seed_label.extend_from_slice(&seed); @@ -92,83 +162,20 @@ impl PrfFunction { self.start_seed_label = Some(start_seed_label); } - pub(crate) fn output(&self) -> Vec> { - self.p.iter().map(|p| p.output).collect() + /// Returns the PRF output. + pub(crate) fn output(&self) -> Vector { + self.output } - fn alloc( - vm: &mut dyn Vm, - label: &'static [u8], - outer_partial: Sha256, - inner_partial: Sha256, - output_len: usize, - seed_len: usize, - ) -> Result { - let mut prf = Self { - label, - state: State::Computing, - start_seed_label: None, - a: vec![], - p: vec![], - }; - - assert!(output_len > 0, "cannot compute 0 bytes for prf"); - - let iterations = output_len.div_ceil(32); - - let msg_len_a = label.len() + seed_len; - let seed_label_ref: Vector = vm.alloc_vec(msg_len_a).map_err(PrfError::vm)?; - vm.mark_public(seed_label_ref).map_err(PrfError::vm)?; - - let mut msg_a = seed_label_ref; - for _ in 0..iterations { - let a = PHash::alloc(vm, outer_partial.clone(), inner_partial.clone(), &[msg_a])?; - msg_a = Vector::::from(a.output); - prf.a.push(a); - - let p = PHash::alloc( - vm, - outer_partial.clone(), - inner_partial.clone(), - &[msg_a, seed_label_ref], - )?; - prf.p.push(p); - } - - Ok(prf) + /// Whether this functionality is complete. + pub(crate) fn is_complete(&self) -> bool { + matches!(self.state, State::Complete) } } #[derive(Debug, Clone, Copy)] enum State { - Computing, - Finished, -} - -#[derive(Debug, Clone)] -struct PHash { - msg: Vec>, - output: Array, -} - -impl PHash { - fn alloc( - vm: &mut dyn Vm, - outer_partial: Sha256, - inner_partial: Sha256, - msg: &[Vector], - ) -> Result { - let mut inner_local = inner_partial; - - msg.iter().for_each(|m| inner_local.update(m)); - inner_local.compress(vm)?; - let inner_local = inner_local.finalize(vm)?; - - let output = hmac_sha256(vm, outer_partial, inner_local)?; - let p_hash = Self { - msg: msg.to_vec(), - output, - }; - Ok(p_hash) - } + WantsSeed, + SeedSet, + Complete, } diff --git a/crates/components/hmac-sha256/src/prf/function/reduced.rs b/crates/components/hmac-sha256/src/prf/function/reduced.rs index 403d4c5295..119a492ea9 100644 --- a/crates/components/hmac-sha256/src/prf/function/reduced.rs +++ b/crates/components/hmac-sha256/src/prf/function/reduced.rs @@ -1,46 +1,30 @@ -//! Computes some hashes of the PRF locally. +//! TLS 1.2 PRF function. use std::collections::VecDeque; -use crate::{hmac::hmac_sha256, sha256, state_to_bytes, PrfError}; -use mpz_core::bitvec::BitVec; -use mpz_hash::sha256::Sha256; +use crate::{hmac::reduced::HmacReduced, tls12::merge_vectors, FError}; use mpz_vm_core::{ memory::{ binary::{Binary, U8}, - Array, DecodeFutureTyped, MemoryExt, ViewExt, + MemoryExt, Vector, }, Vm, }; #[derive(Debug)] pub(crate) struct PrfFunction { - // The label, e.g. "master secret". + // The human-readable label, e.g. "master secret". label: &'static [u8], // The start seed and the label, e.g. client_random + server_random + "master_secret". start_seed_label: Option>, - iterations: usize, - state: PrfState, - a: VecDeque, - p: VecDeque, -} - -#[derive(Debug)] -enum PrfState { - InnerPartial { - inner_partial: DecodeFutureTyped, - }, - ComputeA { - iter: usize, - inner_partial: [u32; 8], - msg: Vec, - }, - ComputeP { - iter: usize, - inner_partial: [u32; 8], - a_output: DecodeFutureTyped, - }, - Done, + state: State, + /// A_Hash functionalities for each iteration instantiated with the PRF + /// secret. + a_hash: VecDeque, + /// P_Hash functionalities for each iteration instantiated with the PRF + /// secret. + p_hash: VecDeque, + output: Vector, } impl PrfFunction { @@ -49,103 +33,213 @@ impl PrfFunction { const CF_LABEL: &[u8] = b"client finished"; const SF_LABEL: &[u8] = b"server finished"; + /// Allocates master secret. pub(crate) fn alloc_master_secret( vm: &mut dyn Vm, - outer_partial: Sha256, - inner_partial: Sha256, - ) -> Result { - Self::alloc(vm, Self::MS_LABEL, outer_partial, inner_partial, 48) + hmac: HmacReduced, + ) -> Result { + Self::alloc(vm, Self::MS_LABEL, hmac, 48) } + /// Allocates key expansion. pub(crate) fn alloc_key_expansion( vm: &mut dyn Vm, - outer_partial: Sha256, - inner_partial: Sha256, - ) -> Result { - Self::alloc(vm, Self::KEY_LABEL, outer_partial, inner_partial, 40) + hmac: HmacReduced, + ) -> Result { + Self::alloc(vm, Self::KEY_LABEL, hmac, 40) } + /// Allocates client finished. pub(crate) fn alloc_client_finished( vm: &mut dyn Vm, - outer_partial: Sha256, - inner_partial: Sha256, - ) -> Result { - Self::alloc(vm, Self::CF_LABEL, outer_partial, inner_partial, 12) + hmac: HmacReduced, + ) -> Result { + Self::alloc(vm, Self::CF_LABEL, hmac, 12) } + /// Allocates server finished. pub(crate) fn alloc_server_finished( vm: &mut dyn Vm, - outer_partial: Sha256, - inner_partial: Sha256, - ) -> Result { - Self::alloc(vm, Self::SF_LABEL, outer_partial, inner_partial, 12) + hmac: HmacReduced, + ) -> Result { + Self::alloc(vm, Self::SF_LABEL, hmac, 12) } - pub(crate) fn wants_flush(&self) -> bool { - !matches!(self.state, PrfState::Done) && self.start_seed_label.is_some() + /// Allocates a new PRF with the given `hmac` instantiated with the PRF + /// secret. + fn alloc( + vm: &mut dyn Vm, + label: &'static [u8], + hmac: HmacReduced, + output_len: usize, + ) -> Result { + assert!(output_len > 0, "cannot compute 0 bytes for prf"); + + let iterations = output_len.div_ceil(32); + let mut a_hash = VecDeque::with_capacity(iterations); + let mut p_hash = VecDeque::with_capacity(iterations); + + // Create the required amount of HMAC instances. + let mut hmacs = vec![hmac]; + for _ in 0..iterations * 2 - 1 { + hmacs.push(HmacReduced::from_other(vm, &hmacs[0])?); + } + + let mut p_out: Vec> = Vec::with_capacity(iterations); + for _ in 0..iterations { + let a = hmacs.pop().expect("enough instances"); + let p = hmacs.pop().expect("enough instances"); + // Decode output as soon as it becomes available. + std::mem::drop(vm.decode(a.output()).map_err(FError::vm)?); + p_out.push(p.output().into()); + + a_hash.push_back(a); + p_hash.push_back(p); + } + + Ok(Self { + label, + start_seed_label: None, + state: State::WantsSeed, + a_hash, + p_hash, + output: merge_vectors(vm, p_out, output_len)?, + }) } - pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), PrfError> { - match &mut self.state { - PrfState::InnerPartial { inner_partial } => { - let Some(inner_partial) = inner_partial.try_recv().map_err(PrfError::vm)? else { - return Ok(()); - }; + /// Whether this functionality needs to be flushed. + pub(crate) fn wants_flush(&self) -> bool { + let state_wants_flush = match self.state { + State::WantsSeed => self.start_seed_label.is_some(), + State::ComputeFirstCycle { .. } => true, + State::ComputeCycle { .. } => true, + State::ComputeLastCycle { .. } => true, + _ => false, + }; - self.state = PrfState::ComputeA { - iter: 1, - inner_partial, - msg: self - .start_seed_label - .clone() - .expect("Start seed should have been set"), - }; - self.flush(vm)?; + state_wants_flush + || self.a_hash.iter().any(|h| h.wants_flush()) + || self.p_hash.iter().any(|h| h.wants_flush()) + } + + /// Flushes the functionality. + pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), FError> { + // Flush every HMAC functionality. + self.a_hash.iter_mut().try_for_each(|h| h.flush(vm))?; + self.p_hash.iter_mut().try_for_each(|h| h.flush(vm))?; + + match &self.state { + State::WantsSeed => { + if let Some(seed) = &self.start_seed_label { + self.state = State::ComputeFirstCycle { msg: seed.to_vec() }; + // recurse. + self.flush(vm)?; + } } - PrfState::ComputeA { - iter, - inner_partial, - msg, - } => { - let a = self.a.pop_front().expect("Prf AHash should be present"); - assign_inner_local(vm, a.inner_local, *inner_partial, msg)?; - - self.state = PrfState::ComputeP { - iter: *iter, - inner_partial: *inner_partial, - a_output: a.output, + State::ComputeFirstCycle { msg } => { + let mut a = self.a_hash.pop_front().expect("not empty"); + + if !a.is_msg_set() { + a.set_msg(msg)?; + a.flush(vm)?; + } + + let out = if a.is_complete() { + let mut a_out = vm.decode(a.output()).map_err(FError::vm)?; + a_out.try_recv().map_err(FError::vm)? + } else { + None }; + + match out { + Some(out) => { + self.state = State::ComputeCycle { msg: out.to_vec() }; + // Recurse to the next cycle. + self.flush(vm)?; + } + None => { + // Prepare to process this cycle again after VM executes. + self.a_hash.push_front(a); + self.state = State::ComputeFirstCycle { msg: msg.to_vec() }; + } + } } - PrfState::ComputeP { - iter, - inner_partial, - a_output, - } => { - let Some(output) = a_output.try_recv().map_err(PrfError::vm)? else { + State::ComputeCycle { msg } => { + if self.p_hash.len() == 1 { + // Recurse to the last cycle. + self.state = State::ComputeLastCycle { msg: msg.to_vec() }; + self.flush(vm)?; return Ok(()); - }; - let p = self.p.pop_front().expect("Prf PHash should be present"); + } + + let mut a = self.a_hash.pop_front().expect("not empty"); + let mut p = self.p_hash.pop_front().expect("not empty"); + + if !a.is_msg_set() { + a.set_msg(msg)?; + a.flush(vm)?; + } - let mut msg = output.to_vec(); - msg.extend_from_slice( - self.start_seed_label - .as_ref() - .expect("Start seed should have been set"), - ); + if !p.is_msg_set() { + let mut p_msg = msg.clone(); + p_msg.extend_from_slice( + self.start_seed_label + .as_ref() + .expect("Start seed should have been set"), + ); + p.set_msg(&p_msg)?; + p.flush(vm)?; + } - assign_inner_local(vm, p.inner_local, *inner_partial, &msg)?; + if !p.is_complete() { + // Prepare to process this cycle again after VM executes. + self.a_hash.push_front(a); + self.p_hash.push_front(p); + self.state = State::ComputeCycle { msg: msg.to_vec() }; + return Ok(()); + } - if *iter == self.iterations { - self.state = PrfState::Done; + let out = if a.is_complete() { + let mut a_out = vm.decode(a.output()).map_err(FError::vm)?; + a_out.try_recv().map_err(FError::vm)? } else { - self.state = PrfState::ComputeA { - iter: *iter + 1, - inner_partial: *inner_partial, - msg: output.to_vec(), - }; - // We recurse, so that this PHash and the next AHash could - // be computed in a single VM execute call. - self.flush(vm)?; + None + }; + + match out { + Some(out) => { + // Recurse to the next cycle. + self.state = State::ComputeCycle { msg: out.to_vec() }; + self.flush(vm)?; + } + None => { + // Prepare to process this cycle again after VM executes. + self.a_hash.push_front(a); + self.p_hash.push_front(p); + self.state = State::ComputeCycle { msg: msg.to_vec() }; + } + } + } + State::ComputeLastCycle { msg } => { + let mut p = self.p_hash.pop_front().expect("not empty"); + + if !p.is_msg_set() { + let mut p_msg = msg.clone(); + p_msg.extend_from_slice( + self.start_seed_label + .as_ref() + .expect("Start seed should have been set"), + ); + p.set_msg(&p_msg)?; + p.flush(vm)?; + } + + if !p.is_complete() { + // Prepare to process this cycle again after VM executes. + self.p_hash.push_front(p); + self.state = State::ComputeLastCycle { msg: msg.to_vec() }; + } else { + self.state = State::Complete; } } _ => (), @@ -154,6 +248,7 @@ impl PrfFunction { Ok(()) } + /// Sets the seed. pub(crate) fn set_start_seed(&mut self, seed: Vec) { let mut start_seed_label = self.label.to_vec(); start_seed_label.extend_from_slice(&seed); @@ -161,88 +256,33 @@ impl PrfFunction { self.start_seed_label = Some(start_seed_label); } - pub(crate) fn output(&self) -> Vec> { - self.p.iter().map(|p| p.output).collect() + /// Returns the PRF output. + pub(crate) fn output(&self) -> Vector { + self.output } - fn alloc( - vm: &mut dyn Vm, - label: &'static [u8], - outer_partial: Sha256, - inner_partial: Sha256, - len: usize, - ) -> Result { - assert!(len > 0, "cannot compute 0 bytes for prf"); - - let iterations = len.div_ceil(32); - - let (inner_partial, _) = inner_partial - .state() - .expect("state should be set for inner_partial"); - let inner_partial = vm.decode(inner_partial).map_err(PrfError::vm)?; - - let mut prf = Self { - label, - start_seed_label: None, - iterations, - state: PrfState::InnerPartial { inner_partial }, - a: VecDeque::new(), - p: VecDeque::new(), - }; - - for _ in 0..iterations { - // setup A[i] - let inner_local: Array = vm.alloc().map_err(PrfError::vm)?; - let output = hmac_sha256(vm, outer_partial.clone(), inner_local)?; - - let output = vm.decode(output).map_err(PrfError::vm)?; - let a_hash = AHash { - inner_local, - output, - }; - - prf.a.push_front(a_hash); - - // setup P[i] - let inner_local: Array = vm.alloc().map_err(PrfError::vm)?; - let output = hmac_sha256(vm, outer_partial.clone(), inner_local)?; - let p_hash = PHash { - inner_local, - output, - }; - prf.p.push_front(p_hash); - } - - Ok(prf) + /// Whether this functionality is complete. + pub(crate) fn is_complete(&self) -> bool { + matches!(self.state, State::Complete) } } -fn assign_inner_local( - vm: &mut dyn Vm, - inner_local: Array, - inner_partial: [u32; 8], - msg: &[u8], -) -> Result<(), PrfError> { - let inner_local_value = sha256(inner_partial, 64, msg); - - vm.mark_public(inner_local).map_err(PrfError::vm)?; - vm.assign(inner_local, state_to_bytes(inner_local_value)) - .map_err(PrfError::vm)?; - vm.commit(inner_local).map_err(PrfError::vm)?; - - Ok(()) -} - -/// Like PHash but stores the output as the decoding future because in the -/// reduced Prf we need to decode this output. -#[derive(Debug)] -struct AHash { - inner_local: Array, - output: DecodeFutureTyped, -} - -#[derive(Debug, Clone, Copy)] -struct PHash { - inner_local: Array, - output: Array, +#[derive(Debug, PartialEq)] +enum State { + WantsSeed, + /// To minimize the amount of VM execute calls, the PRF iterations are + /// divided into cycles. + /// Starting with iteration count i == 1, each cycle computes a tuple + /// (A_Hash(i), P_Hash(i-1)). Thus, during the first cycle, only A_Hash(1) + /// is computed and during the last cycle only P_Hash(i) is computed. + ComputeFirstCycle { + msg: Vec, + }, + ComputeCycle { + msg: Vec, + }, + ComputeLastCycle { + msg: Vec, + }, + Complete, } diff --git a/crates/components/hmac-sha256/src/prf/mod.rs b/crates/components/hmac-sha256/src/prf/mod.rs new file mode 100644 index 0000000000..92151817f8 --- /dev/null +++ b/crates/components/hmac-sha256/src/prf/mod.rs @@ -0,0 +1,2 @@ +pub(crate) mod function; +pub(crate) use function::Prf; diff --git a/crates/components/hmac-sha256/src/prf/state.rs b/crates/components/hmac-sha256/src/prf/state.rs deleted file mode 100644 index 6299c6d07d..0000000000 --- a/crates/components/hmac-sha256/src/prf/state.rs +++ /dev/null @@ -1,103 +0,0 @@ -use crate::{ - prf::{function::Prf, merge_outputs}, - PrfError, PrfOutput, SessionKeys, -}; -use mpz_vm_core::{ - memory::{ - binary::{Binary, U8}, - Array, FromRaw, ToRaw, - }, - Vm, -}; - -#[allow(clippy::large_enum_variant)] -#[derive(Debug)] -pub(crate) enum State { - Initialized, - SessionKeys { - client_random: Option<[u8; 32]>, - master_secret: Prf, - key_expansion: Prf, - client_finished: Prf, - server_finished: Prf, - }, - ClientFinished { - client_finished: Prf, - server_finished: Prf, - }, - ServerFinished { - server_finished: Prf, - }, - Complete, - Error, -} - -impl State { - pub(crate) fn take(&mut self) -> State { - std::mem::replace(self, State::Error) - } - - pub(crate) fn prf_output(&self, vm: &mut dyn Vm) -> Result { - let State::SessionKeys { - key_expansion, - client_finished, - server_finished, - .. - } = self - else { - return Err(PrfError::state( - "Prf output can only be computed while in \"SessionKeys\" state", - )); - }; - - let keys = get_session_keys(key_expansion.output(), vm)?; - let cf_vd = get_client_finished_vd(client_finished.output(), vm)?; - let sf_vd = get_server_finished_vd(server_finished.output(), vm)?; - - let output = PrfOutput { keys, cf_vd, sf_vd }; - - Ok(output) - } -} - -fn get_session_keys( - output: Vec>, - vm: &mut dyn Vm, -) -> Result { - let mut keys = merge_outputs(vm, output, 40)?; - debug_assert!(keys.len() == 40, "session keys len should be 40"); - - let server_iv = Array::::try_from(keys.split_off(36)).unwrap(); - let client_iv = Array::::try_from(keys.split_off(32)).unwrap(); - let server_write_key = Array::::try_from(keys.split_off(16)).unwrap(); - let client_write_key = Array::::try_from(keys).unwrap(); - - let session_keys = SessionKeys { - client_write_key, - server_write_key, - client_iv, - server_iv, - }; - - Ok(session_keys) -} - -fn get_client_finished_vd( - output: Vec>, - vm: &mut dyn Vm, -) -> Result, PrfError> { - let cf_vd = merge_outputs(vm, output, 12)?; - let cf_vd = as FromRaw>::from_raw(cf_vd.to_raw()); - - Ok(cf_vd) -} - -fn get_server_finished_vd( - output: Vec>, - vm: &mut dyn Vm, -) -> Result, PrfError> { - let sf_vd = merge_outputs(vm, output, 12)?; - let sf_vd = as FromRaw>::from_raw(sf_vd.to_raw()); - - Ok(sf_vd) -} diff --git a/crates/components/hmac-sha256/src/test_utils.rs b/crates/components/hmac-sha256/src/test_utils.rs index 65f340f178..edf81405dc 100644 --- a/crates/components/hmac-sha256/src/test_utils.rs +++ b/crates/components/hmac-sha256/src/test_utils.rs @@ -1,13 +1,9 @@ -use crate::{sha256, state_to_bytes}; +use crate::hmac::clear; use mpz_garble::protocol::semihonest::{Evaluator, Garbler}; use mpz_ot::ideal::cot::{ideal_cot, IdealCOTReceiver, IdealCOTSender}; use mpz_vm_core::memory::correlated::Delta; use rand::{rngs::StdRng, Rng, SeedableRng}; -pub(crate) const SHA256_IV: [u32; 8] = [ - 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, -]; - pub(crate) fn mock_vm() -> (Garbler, Evaluator) { let mut rng = StdRng::seed_from_u64(0); let delta = Delta::random(&mut rng); @@ -72,7 +68,7 @@ pub(crate) fn phash(key: Vec, seed: &[u8], iterations: usize) -> Vec { a_cache.push(seed.to_vec()); for i in 0..iterations { - let a_i = hmac_sha256(key.clone(), &a_cache[i]); + let a_i = clear::hmac_sha256(&key, &a_cache[i]); a_cache.push(a_i.to_vec()); } @@ -82,64 +78,13 @@ pub(crate) fn phash(key: Vec, seed: &[u8], iterations: usize) -> Vec { let mut a_i_seed = a_cache[i + 1].clone(); a_i_seed.extend_from_slice(seed); - let hash = hmac_sha256(key.clone(), &a_i_seed); + let hash = clear::hmac_sha256(&key, &a_i_seed); output.extend_from_slice(&hash); } output } -pub(crate) fn hmac_sha256(key: Vec, msg: &[u8]) -> [u8; 32] { - let outer_partial = compute_outer_partial(key.clone()); - let inner_local = compute_inner_local(key, msg); - - let hmac = sha256(outer_partial, 64, &state_to_bytes(inner_local)); - state_to_bytes(hmac) -} - -pub(crate) fn compute_outer_partial(mut key: Vec) -> [u32; 8] { - assert!(key.len() <= 64); - - key.resize(64, 0_u8); - let key_padded: [u8; 64] = key - .into_iter() - .map(|b| b ^ 0x5c) - .collect::>() - .try_into() - .unwrap(); - - compress_256(SHA256_IV, &key_padded) -} - -pub(crate) fn compute_inner_local(mut key: Vec, msg: &[u8]) -> [u32; 8] { - assert!(key.len() <= 64); - - key.resize(64, 0_u8); - let key_padded: [u8; 64] = key - .into_iter() - .map(|b| b ^ 0x36) - .collect::>() - .try_into() - .unwrap(); - - let state = compress_256(SHA256_IV, &key_padded); - sha256(state, 64, msg) -} - -pub(crate) fn compress_256(mut state: [u32; 8], msg: &[u8]) -> [u32; 8] { - use sha2::{ - compress256, - digest::{ - block_buffer::{BlockBuffer, Eager}, - generic_array::typenum::U64, - }, - }; - - let mut buffer = BlockBuffer::::default(); - buffer.digest_blocks(msg, |b| compress256(&mut state, b)); - state -} - // Borrowed from Rustls for testing // https://github.com/rustls/rustls/blob/main/rustls/src/tls12/prf.rs mod ring_prf { @@ -259,3 +204,21 @@ fn test_prf_reference_sf() { assert_eq!(sf_vd, expected_sf_vd); } + +#[test] +fn test_key_schedule_reference_sf() { + use ring_prf::prf as prf_ref; + + let mut rng = StdRng::from_seed([4; 32]); + + let ms: [u8; 48] = rng.random(); + let label: &[u8] = b"server finished"; + let handshake_hash: [u8; 32] = rng.random(); + + let sf_vd = prf_sf_vd(ms, handshake_hash); + + let mut expected_sf_vd: [u8; 12] = [0; 12]; + prf_ref(&mut expected_sf_vd, &ms, label, &handshake_hash); + + assert_eq!(sf_vd, expected_sf_vd); +} diff --git a/crates/components/hmac-sha256/src/tls12.rs b/crates/components/hmac-sha256/src/tls12.rs new file mode 100644 index 0000000000..dab8a25460 --- /dev/null +++ b/crates/components/hmac-sha256/src/tls12.rs @@ -0,0 +1,556 @@ +//! Functionality for computing HMAC-SHA-256-based TLS 1.2 PRF. + +use std::{fmt::Debug, sync::Arc}; + +use mpz_circuits::{Circuit, CircuitBuilder}; +use mpz_vm_core::{ + memory::{ + binary::{Binary, U8}, + Array, StaticSize, Vector, + }, + Call, CallableExt, Vm, +}; +use tracing::instrument; + +use crate::{hmac::Hmac, prf::Prf, tls12::state::State, FError, Mode}; + +mod state; + +/// Functionality for computing HMAC-SHA-256-based TLS 1.2 PRF. +#[derive(Debug)] +pub struct Tls12Prf { + mode: Mode, + state: State, +} + +impl Tls12Prf { + /// Creates a new instance of the PRF. + /// + /// # Arguments + /// + /// `mode` - The PRF mode. + pub fn new(mode: Mode) -> Tls12Prf { + Self { + mode, + state: State::Initialized, + } + } + + /// Allocates resources for the PRF. + /// + /// # Arguments + /// + /// * `vm` - Virtual machine. + /// * `pms` - The pre-master secret. + #[instrument(level = "debug", skip_all, err)] + pub fn alloc( + &mut self, + vm: &mut dyn Vm, + pms: Array, + ) -> Result { + let State::Initialized = self.state.take() else { + return Err(FError::state("PRF not in initialized state")); + }; + + let mode = self.mode; + + let hmac_pms = Hmac::alloc(vm, pms.into(), mode)?; + + let master_secret = Prf::alloc_master_secret(mode, vm, hmac_pms)?; + + let hmac_ms1: Hmac = Hmac::alloc(vm, master_secret.output(), mode)?; + let hmac_ms2 = Hmac::from_other(vm, &hmac_ms1)?; + let hmac_ms3 = Hmac::from_other(vm, &hmac_ms1)?; + + let key_expansion = Prf::alloc_key_expansion(mode, vm, hmac_ms1)?; + + let client_finished = Prf::alloc_client_finished(mode, vm, hmac_ms2)?; + + let server_finished = Prf::alloc_server_finished(mode, vm, hmac_ms3)?; + + self.state = State::SessionKeys { + client_random: None, + master_secret, + key_expansion, + client_finished, + server_finished, + }; + + self.state.prf_output() + } + + /// Sets the client random. + /// + /// # Arguments + /// + /// * `random` - The client random. + #[instrument(level = "debug", skip_all, err)] + pub fn set_client_random(&mut self, random: [u8; 32]) -> Result<(), FError> { + let State::SessionKeys { client_random, .. } = &mut self.state else { + return Err(FError::state("PRF not set up")); + }; + + *client_random = Some(random); + Ok(()) + } + + /// Sets the server random. + /// + /// # Arguments + /// + /// * `random` - The server random. + #[instrument(level = "debug", skip_all, err)] + pub fn set_server_random(&mut self, random: [u8; 32]) -> Result<(), FError> { + let State::SessionKeys { + client_random, + master_secret, + key_expansion, + .. + } = &mut self.state + else { + return Err(FError::state("PRF not set up")); + }; + + let client_random = client_random.expect("Client random should have been set by now"); + let server_random = random; + + let mut seed_ms = client_random.to_vec(); + seed_ms.extend_from_slice(&server_random); + master_secret.set_start_seed(seed_ms); + + let mut seed_ke = server_random.to_vec(); + seed_ke.extend_from_slice(&client_random); + key_expansion.set_start_seed(seed_ke); + + Ok(()) + } + + /// Sets the client finished handshake hash. + /// + /// # Arguments + /// + /// * `handshake_hash` - The handshake transcript hash. + #[instrument(level = "debug", skip_all, err)] + pub fn set_cf_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), FError> { + let State::ClientFinished { + client_finished, .. + } = &mut self.state + else { + return Err(FError::state("PRF not in client finished state")); + }; + + let seed_cf = handshake_hash.to_vec(); + client_finished.set_start_seed(seed_cf); + + Ok(()) + } + + /// Sets the server finished handshake hash. + /// + /// # Arguments + /// + /// * `handshake_hash` - The handshake transcript hash. + #[instrument(level = "debug", skip_all, err)] + pub fn set_sf_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), FError> { + let State::ServerFinished { server_finished } = &mut self.state else { + return Err(FError::state("PRF not in server finished state")); + }; + + let seed_sf = handshake_hash.to_vec(); + server_finished.set_start_seed(seed_sf); + + Ok(()) + } + + /// Returns if the PRF needs to be flushed. + pub fn wants_flush(&self) -> bool { + match &self.state { + State::SessionKeys { + master_secret, + key_expansion, + .. + } => master_secret.wants_flush() || key_expansion.wants_flush(), + State::ClientFinished { + client_finished, .. + } => client_finished.wants_flush(), + State::ServerFinished { server_finished } => server_finished.wants_flush(), + _ => false, + } + } + + /// Flushes the PRF. + pub fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), FError> { + self.state = match self.state.take() { + State::SessionKeys { + client_random, + mut master_secret, + mut key_expansion, + client_finished, + server_finished, + } => { + master_secret.flush(vm)?; + key_expansion.flush(vm)?; + + if master_secret.is_complete() && key_expansion.is_complete() { + State::ClientFinished { + client_finished, + server_finished, + } + } else { + State::SessionKeys { + client_random, + master_secret, + key_expansion, + client_finished, + server_finished, + } + } + } + State::ClientFinished { + mut client_finished, + server_finished, + } => { + client_finished.flush(vm)?; + + if client_finished.is_complete() { + State::ServerFinished { server_finished } + } else { + State::ClientFinished { + client_finished, + server_finished, + } + } + } + State::ServerFinished { + mut server_finished, + } => { + server_finished.flush(vm)?; + + if server_finished.is_complete() { + State::Complete + } else { + State::ServerFinished { server_finished } + } + } + other => other, + }; + + Ok(()) + } +} + +/// PRF output. +#[derive(Debug, Clone, Copy)] +pub struct PrfOutput { + /// TLS session keys. + pub keys: SessionKeys, + /// Client finished verify data. + pub cf_vd: Array, + /// Server finished verify data. + pub sf_vd: Array, +} + +/// Session keys computed by the PRF. +#[derive(Debug, Clone, Copy)] +pub struct SessionKeys { + /// Client write key. + pub client_write_key: Array, + /// Server write key. + pub server_write_key: Array, + /// Client IV. + pub client_iv: Array, + /// Server IV. + pub server_iv: Array, +} + +/// Merges vectors, returning the merged vector truncated to the `output_bytes` +/// length. +pub(crate) fn merge_vectors( + vm: &mut dyn Vm, + inputs: Vec>, + output_bytes: usize, +) -> Result, FError> { + let len = inputs.iter().map(|inp| inp.len()).sum(); + assert!(output_bytes <= len); + + let bits = len * U8::SIZE; + let circ = gen_merge_circ(bits); + + let mut builder = Call::builder(circ); + for &input in inputs.iter() { + builder = builder.arg(input); + } + let call = builder.build().map_err(FError::vm)?; + + let mut output: Vector = vm.call(call).map_err(FError::vm)?; + output.truncate(output_bytes); + Ok(output) +} + +fn gen_merge_circ(size: usize) -> Arc { + let mut builder = CircuitBuilder::new(); + let inputs = (0..size).map(|_| builder.add_input()).collect::>(); + + for input in inputs.chunks_exact(8) { + for byte in input.chunks_exact(8) { + for &feed in byte.iter() { + let output = builder.add_id_gate(feed); + builder.add_output(output); + } + } + } + + Arc::new(builder.build().expect("merge circuit is valid")) +} + +#[cfg(test)] +mod tests { + use crate::{ + test_utils::{mock_vm, prf_cf_vd, prf_keys, prf_ms, prf_sf_vd}, + tls12::merge_vectors, + Mode, SessionKeys, Tls12Prf, + }; + use mpz_common::context::test_st_context; + use mpz_vm_core::{ + memory::{binary::U8, Array, MemoryExt, Vector, ViewExt}, + Execute, + }; + use rand::{rngs::StdRng, Rng, SeedableRng}; + use rstest::*; + + #[rstest] + #[case::normal(Mode::Normal)] + #[case::reduced(Mode::Reduced)] + #[tokio::test] + async fn test_tls12prf(#[case] mode: Mode) { + let mut rng = StdRng::seed_from_u64(1); + // Test input. + let pms: [u8; 32] = rng.random(); + let client_random: [u8; 32] = rng.random(); + let server_random: [u8; 32] = rng.random(); + + let cf_hs_hash: [u8; 32] = rng.random(); + let sf_hs_hash: [u8; 32] = rng.random(); + + // Expected output. + let ms_expected = prf_ms(pms, client_random, server_random); + + let [cwk_expected, swk_expected, civ_expected, siv_expected] = + prf_keys(ms_expected, client_random, server_random); + + let cwk_expected: [u8; 16] = cwk_expected.try_into().unwrap(); + let swk_expected: [u8; 16] = swk_expected.try_into().unwrap(); + let civ_expected: [u8; 4] = civ_expected.try_into().unwrap(); + let siv_expected: [u8; 4] = siv_expected.try_into().unwrap(); + + let cf_vd_expected = prf_cf_vd(ms_expected, cf_hs_hash); + let sf_vd_expected = prf_sf_vd(ms_expected, sf_hs_hash); + + let cf_vd_expected: [u8; 12] = cf_vd_expected.try_into().unwrap(); + let sf_vd_expected: [u8; 12] = sf_vd_expected.try_into().unwrap(); + + // Set up vm and prf. + let (mut ctx_a, mut ctx_b) = test_st_context(128); + let (mut leader, mut follower) = mock_vm(); + + let leader_pms: Array = leader.alloc().unwrap(); + leader.mark_public(leader_pms).unwrap(); + leader.assign(leader_pms, pms).unwrap(); + leader.commit(leader_pms).unwrap(); + + let follower_pms: Array = follower.alloc().unwrap(); + follower.mark_public(follower_pms).unwrap(); + follower.assign(follower_pms, pms).unwrap(); + follower.commit(follower_pms).unwrap(); + + let mut prf_leader = Tls12Prf::new(mode); + let mut prf_follower = Tls12Prf::new(mode); + + let leader_prf_out = prf_leader.alloc(&mut leader, leader_pms).unwrap(); + let follower_prf_out = prf_follower.alloc(&mut follower, follower_pms).unwrap(); + + // client_random and server_random. + prf_leader.set_client_random(client_random).unwrap(); + prf_follower.set_client_random(client_random).unwrap(); + + prf_leader.set_server_random(server_random).unwrap(); + prf_follower.set_server_random(server_random).unwrap(); + + let SessionKeys { + client_write_key: cwk_leader, + server_write_key: swk_leader, + client_iv: civ_leader, + server_iv: siv_leader, + } = leader_prf_out.keys; + + let mut cwk_leader = leader.decode(cwk_leader).unwrap(); + let mut swk_leader = leader.decode(swk_leader).unwrap(); + let mut civ_leader = leader.decode(civ_leader).unwrap(); + let mut siv_leader = leader.decode(siv_leader).unwrap(); + + let SessionKeys { + client_write_key: cwk_follower, + server_write_key: swk_follower, + client_iv: civ_follower, + server_iv: siv_follower, + } = follower_prf_out.keys; + + let mut cwk_follower = follower.decode(cwk_follower).unwrap(); + let mut swk_follower = follower.decode(swk_follower).unwrap(); + let mut civ_follower = follower.decode(civ_follower).unwrap(); + let mut siv_follower = follower.decode(siv_follower).unwrap(); + + while prf_leader.wants_flush() || prf_follower.wants_flush() { + tokio::try_join!( + async { + prf_leader.flush(&mut leader).unwrap(); + leader.execute_all(&mut ctx_a).await + }, + async { + prf_follower.flush(&mut follower).unwrap(); + follower.execute_all(&mut ctx_b).await + } + ) + .unwrap(); + } + + let cwk_leader = cwk_leader.try_recv().unwrap().unwrap(); + let swk_leader = swk_leader.try_recv().unwrap().unwrap(); + let civ_leader = civ_leader.try_recv().unwrap().unwrap(); + let siv_leader = siv_leader.try_recv().unwrap().unwrap(); + + let cwk_follower = cwk_follower.try_recv().unwrap().unwrap(); + let swk_follower = swk_follower.try_recv().unwrap().unwrap(); + let civ_follower = civ_follower.try_recv().unwrap().unwrap(); + let siv_follower = siv_follower.try_recv().unwrap().unwrap(); + + assert_eq!(cwk_leader, cwk_follower); + assert_eq!(swk_leader, swk_follower); + assert_eq!(civ_leader, civ_follower); + assert_eq!(siv_leader, siv_follower); + + assert_eq!(cwk_leader, cwk_expected); + assert_eq!(swk_leader, swk_expected); + assert_eq!(civ_leader, civ_expected); + assert_eq!(siv_leader, siv_expected); + + // client finished. + prf_leader.set_cf_hash(cf_hs_hash).unwrap(); + prf_follower.set_cf_hash(cf_hs_hash).unwrap(); + + let cf_vd_leader = leader_prf_out.cf_vd; + let cf_vd_follower = follower_prf_out.cf_vd; + + let mut cf_vd_leader = leader.decode(cf_vd_leader).unwrap(); + let mut cf_vd_follower = follower.decode(cf_vd_follower).unwrap(); + + while prf_leader.wants_flush() || prf_follower.wants_flush() { + tokio::try_join!( + async { + prf_leader.flush(&mut leader).unwrap(); + leader.execute_all(&mut ctx_a).await + }, + async { + prf_follower.flush(&mut follower).unwrap(); + follower.execute_all(&mut ctx_b).await + } + ) + .unwrap(); + } + + let cf_vd_leader = cf_vd_leader.try_recv().unwrap().unwrap(); + let cf_vd_follower = cf_vd_follower.try_recv().unwrap().unwrap(); + + assert_eq!(cf_vd_leader, cf_vd_follower); + assert_eq!(cf_vd_leader, cf_vd_expected); + + // server finished. + prf_leader.set_sf_hash(sf_hs_hash).unwrap(); + prf_follower.set_sf_hash(sf_hs_hash).unwrap(); + + let sf_vd_leader = leader_prf_out.sf_vd; + let sf_vd_follower = follower_prf_out.sf_vd; + + let mut sf_vd_leader = leader.decode(sf_vd_leader).unwrap(); + let mut sf_vd_follower = follower.decode(sf_vd_follower).unwrap(); + + while prf_leader.wants_flush() || prf_follower.wants_flush() { + tokio::try_join!( + async { + prf_leader.flush(&mut leader).unwrap(); + leader.execute_all(&mut ctx_a).await + }, + async { + prf_follower.flush(&mut follower).unwrap(); + follower.execute_all(&mut ctx_b).await + } + ) + .unwrap(); + } + + let sf_vd_leader = sf_vd_leader.try_recv().unwrap().unwrap(); + let sf_vd_follower = sf_vd_follower.try_recv().unwrap().unwrap(); + + assert_eq!(sf_vd_leader, sf_vd_follower); + assert_eq!(sf_vd_leader, sf_vd_expected); + } + + #[tokio::test] + async fn test_merge_outputs() { + let (mut ctx_a, mut ctx_b) = test_st_context(8); + let (mut leader, mut follower) = mock_vm(); + + let input1: [u8; 32] = std::array::from_fn(|i| i as u8); + let input2: [u8; 32] = std::array::from_fn(|i| i as u8 + 32); + + let mut expected = input1.to_vec(); + expected.extend_from_slice(&input2); + expected.truncate(48); + + // leader + let input1_leader: Vector = leader.alloc_vec(32).unwrap(); + let input2_leader: Vector = leader.alloc_vec(32).unwrap(); + + leader.mark_public(input1_leader).unwrap(); + leader.mark_public(input2_leader).unwrap(); + + leader.assign(input1_leader, input1.to_vec()).unwrap(); + leader.assign(input2_leader, input2.to_vec()).unwrap(); + + leader.commit(input1_leader).unwrap(); + leader.commit(input2_leader).unwrap(); + + let merged_leader = + merge_vectors(&mut leader, vec![input1_leader, input2_leader], 48).unwrap(); + let mut merged_leader = leader.decode(merged_leader).unwrap(); + + // follower + let input1_follower: Vector = follower.alloc_vec(32).unwrap(); + let input2_follower: Vector = follower.alloc_vec(32).unwrap(); + + follower.mark_public(input1_follower).unwrap(); + follower.mark_public(input2_follower).unwrap(); + + follower.assign(input1_follower, input1.to_vec()).unwrap(); + follower.assign(input2_follower, input2.to_vec()).unwrap(); + + follower.commit(input1_follower).unwrap(); + follower.commit(input2_follower).unwrap(); + + let merged_follower = + merge_vectors(&mut follower, vec![input1_follower, input2_follower], 48).unwrap(); + let mut merged_follower = follower.decode(merged_follower).unwrap(); + + tokio::try_join!( + leader.execute_all(&mut ctx_a), + follower.execute_all(&mut ctx_b) + ) + .unwrap(); + + let merged_leader = merged_leader.try_recv().unwrap().unwrap(); + let merged_follower = merged_follower.try_recv().unwrap().unwrap(); + + assert_eq!(merged_leader, merged_follower); + assert_eq!(merged_leader, expected); + } +} diff --git a/crates/components/hmac-sha256/src/tls12/state.rs b/crates/components/hmac-sha256/src/tls12/state.rs new file mode 100644 index 0000000000..3c216300ce --- /dev/null +++ b/crates/components/hmac-sha256/src/tls12/state.rs @@ -0,0 +1,79 @@ +use crate::{prf::Prf, FError, PrfOutput, SessionKeys}; +use mpz_vm_core::memory::{binary::U8, Array}; + +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub(crate) enum State { + Initialized, + SessionKeys { + client_random: Option<[u8; 32]>, + master_secret: Prf, + key_expansion: Prf, + client_finished: Prf, + server_finished: Prf, + }, + ClientFinished { + client_finished: Prf, + server_finished: Prf, + }, + ServerFinished { + server_finished: Prf, + }, + Complete, + Error, +} + +impl State { + pub(crate) fn take(&mut self) -> State { + std::mem::replace(self, State::Error) + } + + pub(crate) fn prf_output(&self) -> Result { + let State::SessionKeys { + key_expansion, + client_finished, + server_finished, + .. + } = self + else { + return Err(FError::state( + "Prf output can only be computed while in \"SessionKeys\" state", + )); + }; + + let keys = get_session_keys( + key_expansion + .output() + .try_into() + .expect("session keys are 40 bytes"), + ); + + let output = PrfOutput { + keys, + cf_vd: client_finished + .output() + .try_into() + .expect("client finished is 12 bytes"), + sf_vd: server_finished + .output() + .try_into() + .expect("server finished is 12 bytes"), + }; + + Ok(output) + } +} + +fn get_session_keys(keys: Array) -> SessionKeys { + let client_write_key = keys.get::<16>(0).expect("within bounds"); + let server_write_key = keys.get::<16>(16).expect("within bounds"); + let client_iv = keys.get::<4>(32).expect("within bounds"); + let server_iv = keys.get::<4>(36).expect("within bounds"); + + SessionKeys { + client_write_key, + server_write_key, + client_iv, + server_iv, + } +} diff --git a/crates/components/hmac-sha256/src/tls13.rs b/crates/components/hmac-sha256/src/tls13.rs new file mode 100644 index 0000000000..7bac110b44 --- /dev/null +++ b/crates/components/hmac-sha256/src/tls13.rs @@ -0,0 +1,605 @@ +//! Functionality for computing HMAC-SHA256-based TLS 1.3 key schedule. + +use std::mem; + +use mpz_vm_core::{ + memory::{ + binary::{Binary, U8}, + Array, MemoryExt, + }, + OneTimePad, Vm, +}; +use rand::RngCore; + +use crate::{ + hmac::Hmac, + kdf::{expand::hkdf_expand_label, extract::HkdfExtract}, + tls13::{application::ApplicationSecrets, handshake::HandshakeSecrets}, + FError, Mode, +}; + +mod application; +mod handshake; + +/// Functionality role. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum Role { + /// Leader. + /// + /// The leader learns handshake secrets and locally finishes the handshake. + Leader, + /// Follower. + Follower, +} + +/// Functionality for computing HMAC-SHA-256-based TLS 1.3 key schedule. +pub struct Tls13KeySched { + mode: Mode, + role: Role, + // Allocated master secret. + master_secret: Option, + // Allocated application secrets. + application: Option, + state: State, +} + +impl Tls13KeySched { + /// Creates a new functionality. + pub fn new(mode: Mode, role: Role) -> Tls13KeySched { + Self { + mode, + role, + application: None, + master_secret: None, + state: State::Initialized, + } + } + + /// Allocates the functionality with the given pre-master secret. + pub fn alloc(&mut self, vm: &mut dyn Vm, pms: Array) -> Result<(), FError> { + let State::Initialized = self.state.take() else { + return Err(FError::state("not in initialized state")); + }; + + let mut hs_secrets = HandshakeSecrets::new(self.mode); + let (cs, ss, derived_secret) = hs_secrets.alloc(vm, pms)?; + + let (masked_cs, cs_otp, masked_ss, ss_otp) = match self.role { + Role::Leader => { + let mut cs_otp = [0u8; 32]; + let mut ss_otp = [0u8; 32]; + rand::rng().fill_bytes(&mut cs_otp); + rand::rng().fill_bytes(&mut ss_otp); + let masked_cs = vm.mask_private(cs, cs_otp).map_err(FError::vm)?; + let masked_ss = vm.mask_private(ss, ss_otp).map_err(FError::vm)?; + (masked_cs, Some(cs_otp), masked_ss, Some(ss_otp)) + } + Role::Follower => { + let masked_cs = vm.mask_blind(cs).map_err(FError::vm)?; + let masked_ss = vm.mask_blind(ss).map_err(FError::vm)?; + (masked_cs, None, masked_ss, None) + } + }; + + // Decode as soon as values are known. + std::mem::drop(vm.decode(masked_cs).map_err(FError::vm)?); + std::mem::drop(vm.decode(masked_ss).map_err(FError::vm)?); + + let hmac_derived = Hmac::alloc(vm, derived_secret, self.mode)?; + let master_secret = HkdfExtract::alloc(self.mode, vm, [0u8; 32], hmac_derived)?; + + let mut aps = ApplicationSecrets::new(self.mode); + aps.alloc(vm, master_secret.output())?; + + self.master_secret = Some(master_secret); + self.application = Some(aps); + self.state = State::Handshake { + secrets: hs_secrets, + masked_cs, + masked_ss, + cs_otp, + ss_otp, + }; + + Ok(()) + } + + /// Whether this functionality needs to be flushed. + pub fn wants_flush(&self) -> bool { + match &self.state { + State::Handshake { secrets, .. } => secrets.wants_flush(), + State::WantsDecodedKeys { .. } => true, + State::MasterSecret(ms) => ms.wants_flush(), + State::Application(app) => app.wants_flush(), + _ => false, + } + } + + /// Flushes the functionality. + pub fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), FError> { + match &mut self.state { + State::Handshake { secrets, .. } => { + secrets.flush(vm)?; + + if secrets.is_complete() { + match self.state.take() { + State::Handshake { + masked_cs, + masked_ss, + cs_otp, + ss_otp, + .. + } => { + self.state = State::WantsDecodedKeys { + masked_cs, + masked_ss, + cs_otp, + ss_otp, + }; + // Recurse. + self.flush(vm)?; + return Ok(()); + } + _ => unreachable!(), + } + } + } + State::WantsDecodedKeys { + masked_cs, + masked_ss, + cs_otp, + ss_otp, + } => { + let mut masked_cs = vm.decode(*masked_cs).map_err(FError::vm)?; + let Some(masked_cs) = masked_cs.try_recv().map_err(FError::vm)? else { + return Ok(()); + }; + let mut masked_ss = vm.decode(*masked_ss).map_err(FError::vm)?; + let Some(masked_ss) = masked_ss.try_recv().map_err(FError::vm)? else { + return Ok(()); + }; + + let (ckey, civ, skey, siv) = if self.role == Role::Leader { + let cs_otp = cs_otp.expect("leader knows cs otp"); + let ss_otp = ss_otp.expect("leader knows ss otp"); + + let mut cs = masked_cs; + let mut ss = masked_ss; + + cs.iter_mut().zip(cs_otp).for_each(|(cs, otp)| { + *cs ^= otp; + }); + + ss.iter_mut().zip(ss_otp).for_each(|(ss, otp)| { + *ss ^= otp; + }); + let ckey: [u8; 16] = hkdf_expand_label(&cs, b"key", &[], 16) + .try_into() + .expect("output is 16 bytes"); + let civ: [u8; 12] = hkdf_expand_label(&cs, b"iv", &[], 12) + .try_into() + .expect("output is 12 bytes"); + let skey: [u8; 16] = hkdf_expand_label(&ss, b"key", &[], 16) + .try_into() + .expect("output is 16 bytes"); + let siv: [u8; 12] = hkdf_expand_label(&ss, b"iv", &[], 12) + .try_into() + .expect("output is 12 bytes"); + + (Some(ckey), Some(civ), Some(skey), Some(siv)) + } else { + (None, None, None, None) + }; + + self.state = State::KeysDecoded { + ckey, + civ, + skey, + siv, + } + } + State::MasterSecret(ms) => { + ms.flush(vm)?; + + if ms.is_complete() { + self.state = State::WantsHandshakeHash; + } + } + State::Application(app) => { + app.flush(vm)?; + + if app.is_complete() { + self.state = State::Complete(app.keys()?); + } + } + _ => (), + } + + Ok(()) + } + + /// Sets the hash of the ClientHello message. + pub fn set_hello_hash(&mut self, hello_hash: [u8; 32]) -> Result<(), FError> { + match &mut self.state { + State::Handshake { secrets, .. } => { + secrets.set_hello_hash(hello_hash)?; + Ok(()) + } + _ => Err(FError::state("not in Handshake state")), + } + } + + /// Returns handshake keys. + pub fn handshake_keys(&mut self) -> Result { + if self.role != Role::Leader { + return Err(FError::state("only leader can access handshake keys")); + } + match self.state { + State::KeysDecoded { + ckey, + civ, + skey, + siv, + } => Ok(HandshakeKeys { + client_write_key: ckey.expect("leader knows key"), + client_iv: civ.expect("leader knows key"), + server_write_key: skey.expect("leader knows key"), + server_iv: siv.expect("leader knows key"), + }), + _ => Err(FError::state("not in HandshakeComplete state")), + } + } + + /// Continues the key schedule to derive application keys. + /// + /// Used after the handshake keys are computed and before the handshake + /// hash is set. + pub fn continue_to_app_keys(&mut self) -> Result<(), FError> { + match self.state { + State::KeysDecoded { .. } => { + let ms = mem::take(&mut self.master_secret).expect("master secret is set"); + self.state = State::MasterSecret(ms); + Ok(()) + } + _ => Err(FError::state("not in KeysDecoded state")), + } + } + + /// Sets the handshake hash. + pub fn set_handshake_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), FError> { + match &mut self.state { + State::WantsHandshakeHash => { + let mut app = + mem::take(&mut self.application).expect("application secrets are set"); + app.set_handshake_hash(handshake_hash)?; + self.state = State::Application(app); + Ok(()) + } + _ => Err(FError::state("not in WantsHandshakeHash state")), + } + } + + /// Returns VM references to the application keys. + pub fn application_keys(&mut self) -> Result { + match self.state { + State::Complete(keys) => Ok(keys), + _ => Err(FError::state("not in Complete state")), + } + } +} + +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] +pub(crate) enum State { + Initialized, + /// The state in which some of the handshake secrets are computed in MPC. + Handshake { + secrets: HandshakeSecrets, + masked_cs: Array, + masked_ss: Array, + cs_otp: Option<[u8; 32]>, + ss_otp: Option<[u8; 32]>, + }, + /// The state after all handshake-related MPC operations were completed + /// and the keys need to be decoded. + WantsDecodedKeys { + masked_cs: Array, + masked_ss: Array, + cs_otp: Option<[u8; 32]>, + ss_otp: Option<[u8; 32]>, + }, + /// The state after the handshake keys were decoded and made known to the + /// leader. + KeysDecoded { + ckey: Option<[u8; 16]>, + civ: Option<[u8; 12]>, + skey: Option<[u8; 16]>, + siv: Option<[u8; 12]>, + }, + /// The state in which the master secret is computed. + /// + /// Computing master secret before handshake hash is set can potentially + /// improve overall performance. + MasterSecret(HkdfExtract), + /// The state in which the master secret has been computed and the + /// handshake hash is expected to be set. + WantsHandshakeHash, + /// The state in which the application secrets are derived. + Application(ApplicationSecrets), + Complete(ApplicationKeys), + Error, +} + +impl State { + pub(crate) fn take(&mut self) -> State { + std::mem::replace(self, State::Error) + } +} + +/// Handshake keys computed by the key schedule. +#[derive(Debug, Clone, Copy)] +pub struct HandshakeKeys { + /// Client write key. + pub client_write_key: [u8; 16], + /// Server write key. + pub server_write_key: [u8; 16], + /// Client IV. + pub client_iv: [u8; 12], + /// Server IV. + pub server_iv: [u8; 12], +} + +/// Application keys computed by the key schedule. +#[derive(Debug, Clone, Copy)] +pub struct ApplicationKeys { + /// Client write key. + pub client_write_key: Array, + /// Server write key. + pub server_write_key: Array, + /// Client IV. + pub client_iv: Array, + /// Server IV. + pub server_iv: Array, +} + +#[cfg(test)] +mod tests { + use crate::{ + test_utils::mock_vm, + tls13::{Role, Tls13KeySched}, + ApplicationKeys, HandshakeKeys, Mode, + }; + use mpz_common::{context::test_st_context, Context}; + use mpz_vm_core::{ + memory::{ + binary::{Binary, U8}, + Array, MemoryExt, ViewExt, + }, + Vm, + }; + use rstest::*; + + #[rstest] + #[case::normal(Mode::Normal)] + #[case::reduced(Mode::Reduced)] + #[tokio::test] + async fn test_tls13_key_sched(#[case] mode: Mode) { + let ( + pms, + hello_hash, + handshake_hash, + ckey_hs, + civ_hs, + skey_hs, + siv_hs, + ckey_app, + civ_app, + skey_app, + siv_app, + ) = test_fixtures(); + + let (mut ctx_a, mut ctx_b) = test_st_context(8); + let (mut leader, mut follower) = mock_vm(); + + // PMS is a private output from previous MPC computations not known + // to either party. For simplicity, it is marked public in this test. + let pms: [u8; 32] = pms.try_into().unwrap(); + + fn setup_ks( + vm: &mut (dyn Vm + Send), + pms: [u8; 32], + mode: Mode, + role: Role, + ) -> Tls13KeySched { + let secret: Array = vm.alloc().unwrap(); + vm.mark_public(secret).unwrap(); + vm.assign(secret, pms).unwrap(); + vm.commit(secret).unwrap(); + + let mut ks = Tls13KeySched::new(mode, role); + ks.alloc(vm, secret).unwrap(); + ks + } + + let mut leader_ks = setup_ks(&mut leader, pms, mode, Role::Leader); + let mut follower_ks = setup_ks(&mut follower, pms, mode, Role::Follower); + + async fn run_ks( + vm: &mut (dyn Vm + Send), + ks: &mut Tls13KeySched, + ctx: &mut Context, + role: Role, + mode: Mode, + hello_hash: Vec, + handshake_hash: Vec, + ) -> Result< + ( + Option, + ([u8; 16], [u8; 12], [u8; 16], [u8; 12]), + ), + Box, + > { + let res = async move { + vm.execute_all(ctx).await.unwrap(); + + flush_execute(ks, vm, ctx, false).await; + + ks.set_hello_hash(hello_hash.try_into().unwrap()).unwrap(); + + // One extra flush to process decoded handshake secrets. + flush_execute(ks, vm, ctx, true).await; + + let hs_keys = if role == Role::Leader { + Some(ks.handshake_keys().unwrap()) + } else { + None + }; + + ks.continue_to_app_keys().unwrap(); + + flush_execute(ks, vm, ctx, false).await; + + ks.set_handshake_hash(handshake_hash.try_into().unwrap()) + .unwrap(); + + if mode == Mode::Reduced { + // One extra flush to process decoded inner_partial. + flush_execute(ks, vm, ctx, true).await; + } else { + flush_execute(ks, vm, ctx, false).await; + } + + let ApplicationKeys { + client_write_key, + client_iv, + server_write_key, + server_iv, + } = ks.application_keys().unwrap(); + let mut ckey_fut = vm.decode(client_write_key).unwrap(); + let mut civ_fut = vm.decode(client_iv).unwrap(); + let mut skey_fut = vm.decode(server_write_key).unwrap(); + let mut siv_fut = vm.decode(server_iv).unwrap(); + vm.execute_all(ctx).await.unwrap(); + let ckey = ckey_fut.try_recv().unwrap().unwrap(); + let civ = civ_fut.try_recv().unwrap().unwrap(); + let skey = skey_fut.try_recv().unwrap().unwrap(); + let siv = siv_fut.try_recv().unwrap().unwrap(); + + (hs_keys, (ckey, civ, skey, siv)) + } + .await; + + Ok(res) + } + + let (out_leader, out_follower) = tokio::try_join!( + run_ks( + &mut leader, + &mut leader_ks, + &mut ctx_a, + Role::Leader, + mode, + hello_hash.clone(), + handshake_hash.clone() + ), + run_ks( + &mut follower, + &mut follower_ks, + &mut ctx_b, + Role::Follower, + mode, + hello_hash, + handshake_hash + ) + ) + .unwrap(); + + let hs_keys_leader = out_leader.0.unwrap(); + + assert_eq!( + ( + hs_keys_leader.client_write_key.to_vec(), + hs_keys_leader.client_iv.to_vec(), + hs_keys_leader.server_write_key.to_vec(), + hs_keys_leader.server_iv.to_vec() + ), + (ckey_hs, civ_hs, skey_hs, siv_hs) + ); + + let app_keys_leader = out_leader.1; + let app_keys_follower = out_follower.1; + assert_eq!(app_keys_leader, app_keys_follower); + + assert_eq!( + app_keys_leader, + ( + ckey_app.try_into().unwrap(), + civ_app.try_into().unwrap(), + skey_app.try_into().unwrap(), + siv_app.try_into().unwrap() + ) + ); + } + + async fn flush_execute( + ks: &mut Tls13KeySched, + vm: &mut (dyn Vm + Send), + ctx: &mut Context, + // Whether after executing the VM, one extra flush is required. + extra_flush: bool, + ) { + assert!(ks.wants_flush()); + ks.flush(vm).unwrap(); + vm.execute_all(ctx).await.unwrap(); + if extra_flush { + assert!(ks.wants_flush()); + ks.flush(vm).unwrap(); + } + assert!(!ks.wants_flush()) + } + + // Reference values from https://datatracker.ietf.org/doc/html/draft-ietf-tls-tls13-vectors-06 + #[allow(clippy::type_complexity)] + fn test_fixtures() -> ( + Vec, + Vec, + Vec, + Vec, + Vec, + Vec, + Vec, + Vec, + Vec, + Vec, + Vec, + ) { + ( + // PMS + from_hex_str("81 51 d1 46 4c 1b 55 53 36 23 b9 c2 24 6a 6a 0e 6e 7e 18 50 63 e1 4a fd af f0 b6 e1 c6 1a 86 42"), + // HELLO_HASH + from_hex_str("c6 c9 18 ad 2f 41 99 d5 59 8e af 01 16 cb 7a 5c 2c 14 cb 54 78 12 18 88 8d b7 03 0d d5 0d 5e 6d"), + // HANDSHAKE_HASH + from_hex_str("f8 c1 9e 8c 77 c0 38 79 bb c8 eb 6d 56 e0 0d d5 d8 6e f5 59 27 ee fc 08 e1 b0 02 b6 ec e0 5d bf"), + // CKEY_HS + from_hex_str("26 79 a4 3e 1d 76 78 40 34 ea 17 97 d5 ad 26 49"), + // CIV_HS + from_hex_str("54 82 40 52 90 dd 0d 2f 81 c0 d9 42"), + // SKEY_HS + from_hex_str("c6 6c b1 ae c5 19 df 44 c9 1e 10 99 55 11 ac 8b"), + // SIV_HS + from_hex_str("f7 f6 88 4c 49 81 71 6c 2d 0d 29 a4"), + // CKEY_APP + from_hex_str("88 b9 6a d6 86 c8 4b e5 5a ce 18 a5 9c ce 5c 87"), + // CIV_APP + from_hex_str("b9 9d c5 8c d5 ff 5a b0 82 fd ad 19"), + // SKEY_APP + from_hex_str("a6 88 eb b5 ac 82 6d 6f 42 d4 5c 0c c4 4b 9b 7d"), + // SIV_APP + from_hex_str("c1 ca d4 42 5a 43 8b 5d e7 14 83 0a"), + ) + } + + fn from_hex_str(s: &str) -> Vec { + hex::decode(s.split_whitespace().collect::()).unwrap() + } +} diff --git a/crates/components/hmac-sha256/src/tls13/application.rs b/crates/components/hmac-sha256/src/tls13/application.rs new file mode 100644 index 0000000000..a8f091d3f8 --- /dev/null +++ b/crates/components/hmac-sha256/src/tls13/application.rs @@ -0,0 +1,233 @@ +use crate::{ + hmac::Hmac, + kdf::expand::{HkdfExpand, EMPTY_CTX}, + ApplicationKeys, FError, Mode, +}; + +use mpz_vm_core::{ + memory::{ + binary::{Binary, U8}, + Vector, + }, + Vm, +}; + +/// Functionality for computing application secrets of TLS 1.3 key schedule. +#[derive(Debug)] +pub(crate) struct ApplicationSecrets { + mode: Mode, + state: State, + client_secret: Option, + server_secret: Option, + client_application_key: Option, + client_application_iv: Option, + server_application_key: Option, + server_application_iv: Option, +} + +impl ApplicationSecrets { + /// Creates a new functionality. + pub(crate) fn new(mode: Mode) -> ApplicationSecrets { + Self { + mode, + state: State::Initialized, + client_secret: None, + server_secret: None, + client_application_key: None, + client_application_iv: None, + server_application_key: None, + server_application_iv: None, + } + } + + /// Allocates the functionality with the given `master_secret`. + pub(crate) fn alloc( + &mut self, + vm: &mut dyn Vm, + master_secret: Vector, + ) -> Result<(), FError> { + let State::Initialized = self.state.take() else { + return Err(FError::state("not in Initialized state")); + }; + + let mode = self.mode; + + let hmac_ms1 = Hmac::alloc(vm, master_secret, mode)?; + let hmac_ms2 = Hmac::from_other(vm, &hmac_ms1)?; + + let client_secret = HkdfExpand::alloc(mode, vm, hmac_ms1, b"c ap traffic", None, 32, 32)?; + + let server_secret = HkdfExpand::alloc(mode, vm, hmac_ms2, b"s ap traffic", None, 32, 32)?; + + let hmac_cs1 = Hmac::alloc(vm, client_secret.output(), mode)?; + let hmac_cs2 = Hmac::from_other(vm, &hmac_cs1)?; + + let hmac_ss1 = Hmac::alloc(vm, server_secret.output(), mode)?; + let hmac_ss2 = Hmac::from_other(vm, &hmac_ss1)?; + + let client_application_key = + HkdfExpand::alloc(mode, vm, hmac_cs1, b"key", Some(&EMPTY_CTX), 0, 16)?; + + let client_application_iv = + HkdfExpand::alloc(mode, vm, hmac_cs2, b"iv", Some(&EMPTY_CTX), 0, 12)?; + + let server_application_key = + HkdfExpand::alloc(mode, vm, hmac_ss1, b"key", Some(&EMPTY_CTX), 0, 16)?; + + let server_application_iv = + HkdfExpand::alloc(mode, vm, hmac_ss2, b"iv", Some(&EMPTY_CTX), 0, 12)?; + + self.state = State::WantsHandshakeHash; + self.client_secret = Some(client_secret); + self.server_secret = Some(server_secret); + self.client_application_key = Some(client_application_key); + self.client_application_iv = Some(client_application_iv); + self.server_application_key = Some(server_application_key); + self.server_application_iv = Some(server_application_iv); + + Ok(()) + } + + /// Whether this functionality needs to be flushed. + pub(crate) fn wants_flush(&self) -> bool { + let client_secret = self.client_secret.as_ref().expect("functionality was set"); + let server_secret = self.server_secret.as_ref().expect("functionality was set"); + let client_application_key = self + .client_application_key + .as_ref() + .expect("functionality was set"); + let client_application_iv = self + .client_application_iv + .as_ref() + .expect("functionality was set"); + let server_application_key = self + .server_application_key + .as_ref() + .expect("functionality was set"); + let server_application_iv = self + .server_application_iv + .as_ref() + .expect("functionality was set"); + + let state_wants_flush = matches!(&self.state, State::HandshakeHashSet(..)); + + state_wants_flush + || client_secret.wants_flush() + || server_secret.wants_flush() + || client_application_key.wants_flush() + || client_application_iv.wants_flush() + || server_application_key.wants_flush() + || server_application_iv.wants_flush() + } + + /// Flushes the functionality. + pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), FError> { + let client_secret = self.client_secret.as_mut().expect("functionality was set"); + let server_secret = self.server_secret.as_mut().expect("functionality was set"); + let client_application_key = self + .client_application_key + .as_mut() + .expect("functionality was set"); + let client_application_iv = self + .client_application_iv + .as_mut() + .expect("functionality was set"); + let server_application_key = self + .server_application_key + .as_mut() + .expect("functionality was set"); + let server_application_iv = self + .server_application_iv + .as_mut() + .expect("functionality was set"); + + client_secret.flush(vm)?; + server_secret.flush(vm)?; + client_application_key.flush(vm)?; + client_application_iv.flush(vm)?; + server_application_key.flush(vm)?; + server_application_iv.flush(vm)?; + + if let State::HandshakeHashSet(hash) = &self.state { + if !client_secret.is_ctx_set() { + client_secret.set_ctx(hash)?; + client_secret.flush(vm)?; + } + if !server_secret.is_ctx_set() { + server_secret.set_ctx(hash)?; + server_secret.flush(vm)?; + } + + if client_application_iv.is_complete() + && client_application_key.is_complete() + && client_secret.is_complete() + && server_application_iv.is_complete() + && server_application_key.is_complete() + && server_secret.is_complete() + { + self.state = State::Complete(ApplicationKeys { + client_write_key: client_application_key + .output() + .try_into() + .expect("key length is 16 bytes"), + client_iv: client_application_iv + .output() + .try_into() + .expect("iv length is 12 bytes"), + server_write_key: server_application_key + .output() + .try_into() + .expect("key length is 16 bytes"), + server_iv: server_application_iv + .output() + .try_into() + .expect("iv length is 12 bytes"), + }); + } + } + + Ok(()) + } + + /// Sets the handshake hash. + pub(crate) fn set_handshake_hash(&mut self, handshake_hash: [u8; 32]) -> Result<(), FError> { + match &mut self.state { + State::WantsHandshakeHash => { + self.state = State::HandshakeHashSet(handshake_hash); + Ok(()) + } + _ => Err(FError::state("not in WantsHandshakeHash state")), + } + } + + /// Returns the application keys. + pub(crate) fn keys(&mut self) -> Result { + match self.state { + State::Complete(keys) => Ok(keys), + _ => Err(FError::state("not in Complete state")), + } + } + + /// Whether this functionality is complete. + pub(crate) fn is_complete(&self) -> bool { + matches!(self.state, State::Complete { .. }) + } +} + +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub(crate) enum State { + Initialized, + /// Wants handshake hash to be set. + WantsHandshakeHash, + /// Handshake hash has been set. + HandshakeHashSet([u8; 32]), + Complete(ApplicationKeys), + Error, +} + +impl State { + pub(crate) fn take(&mut self) -> State { + std::mem::replace(self, State::Error) + } +} diff --git a/crates/components/hmac-sha256/src/tls13/handshake.rs b/crates/components/hmac-sha256/src/tls13/handshake.rs new file mode 100644 index 0000000000..86ec044431 --- /dev/null +++ b/crates/components/hmac-sha256/src/tls13/handshake.rs @@ -0,0 +1,206 @@ +use crate::{ + hmac::{normal::HmacNormal, Hmac}, + kdf::{expand::HkdfExpand, extract::HkdfExtractPrivIkm}, + FError, Mode, +}; + +use mpz_vm_core::{ + memory::{ + binary::{Binary, U8}, + Array, Vector, + }, + Vm, +}; + +// INNER_PARTIAL and OUTER_PARTIAL were computed using the code below: +// +// // A deterministic derived secret for handshake for SHA-256 ciphersuites. +// // see https://datatracker.ietf.org/doc/html/draft-ietf-tls-tls13-vectors-06 +// let derived_secret: Vec = vec![ +// 0x6f, 0x26, 0x15, 0xa1, 0x08, 0xc7, 0x02, 0xc5, 0x67, 0x8f, 0x54, +// 0xfc, 0x9d, 0xba, 0xb6, 0x97, 0x16, 0xc0, 0x76, 0x18, 0x9c, 0x48, +// 0x25, 0x0c, 0xeb, 0xea, 0xc3, 0x57, 0x6c, 0x36, 0x11, 0xba]; +// +// let inner_partial = clear::compute_inner_partial(derived_secret.clone()); +// let outer_partial = clear::compute_outer_partial(derived_secret); + +/// A deterministic inner partial hash state of the derived secret for +/// handshake for SHA-256 ciphersuites. +const INNER_PARTIAL: [u32; 8] = [ + 2335507740, 2200227439, 3546272834, 83913483, 301355998, 2266431524, 1402092146, 439257589, +]; + +/// A deterministic inner partial hash state of the derived secret for +/// handshake for SHA-256 ciphersuites. +const OUTER_PARTIAL: [u32; 8] = [ + 582556975, 2818161237, 3127925320, 2797531207, 4122647441, 3290806166, 3682628262, 2419579842, +]; + +/// The digest of SHA256(""). +const EMPTY_HASH: [u8; 32] = [ + 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24, + 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55, +]; + +/// Functionality for computing handshake secrets of TLS 1.3 key schedule. +#[derive(Debug)] +pub(crate) struct HandshakeSecrets { + mode: Mode, + state: State, + handshake_secret: Option, + client_secret: Option, + server_secret: Option, + derived_secret: Option, +} + +impl HandshakeSecrets { + /// Creates a new functionality. + pub(crate) fn new(mode: Mode) -> HandshakeSecrets { + Self { + mode, + state: State::Initialized, + handshake_secret: None, + client_secret: None, + server_secret: None, + derived_secret: None, + } + } + + /// Allocates the functionality with the given pre-master secret. + /// + /// Returns client_handshake_traffic_secret, + /// server_handshake_traffic_secret, and derived_secret for master_secret. + #[allow(clippy::type_complexity)] + pub(crate) fn alloc( + &mut self, + vm: &mut dyn Vm, + pms: Array, + ) -> Result<(Array, Array, Vector), FError> { + let State::Initialized = self.state.take() else { + return Err(FError::state("not in Initialized state")); + }; + + let mode = self.mode; + let hmac = HmacNormal::alloc_with_state(vm, INNER_PARTIAL, OUTER_PARTIAL)?; + + let handshake_secret = HkdfExtractPrivIkm::alloc(vm, pms, hmac)?; + + let hmac_hs1 = Hmac::alloc(vm, handshake_secret.output(), mode)?; + let hmac_hs2 = Hmac::from_other(vm, &hmac_hs1)?; + let hmac_hs3 = Hmac::from_other(vm, &hmac_hs1)?; + + let client_secret = HkdfExpand::alloc(mode, vm, hmac_hs1, b"c hs traffic", None, 32, 32)?; + + let server_secret = HkdfExpand::alloc(mode, vm, hmac_hs2, b"s hs traffic", None, 32, 32)?; + + // Optimization: by computing now the derived_secret for + // master_secret in parallel with cs and ss, we save communication + // rounds when we are in the reduced mode. + let derived_secret = + HkdfExpand::alloc(mode, vm, hmac_hs3, b"derived", Some(&EMPTY_HASH), 32, 32)?; + + let cs_out: Array = client_secret + .output() + .try_into() + .expect("client secret is 32 bytes"); + let ss_out = server_secret + .output() + .try_into() + .expect("server secret is 32 bytes"); + + let derived_output = derived_secret.output(); + + self.handshake_secret = Some(handshake_secret); + self.client_secret = Some(client_secret); + self.server_secret = Some(server_secret); + self.derived_secret = Some(derived_secret); + self.state = State::WantsHelloHash; + + Ok((cs_out, ss_out, derived_output)) + } + + /// Whether this functionality needs to be flushed. + pub(crate) fn wants_flush(&self) -> bool { + let client_secret = self.client_secret.as_ref().expect("functionality was set"); + let server_secret = self.server_secret.as_ref().expect("functionality was set"); + let derived_secret = self.derived_secret.as_ref().expect("functionality was set"); + let handshake_secret = self + .handshake_secret + .as_ref() + .expect("functionality was set"); + + let state_wants_flush = matches!(&self.state, State::HelloHashSet(..)); + + state_wants_flush + || client_secret.wants_flush() + || server_secret.wants_flush() + || derived_secret.wants_flush() + || handshake_secret.wants_flush() + } + + /// Flushes the functionality. + pub(crate) fn flush(&mut self, vm: &mut dyn Vm) -> Result<(), FError> { + let client_secret = self.client_secret.as_mut().expect("functionality was set"); + let server_secret = self.server_secret.as_mut().expect("functionality was set"); + let derived_secret = self.derived_secret.as_mut().expect("functionality was set"); + let handshake_secret = self + .handshake_secret + .as_mut() + .expect("functionality was set"); + + client_secret.flush(vm)?; + derived_secret.flush(vm)?; + handshake_secret.flush(); + server_secret.flush(vm)?; + + if let State::HelloHashSet(hash) = &mut self.state { + client_secret.set_ctx(hash)?; + client_secret.flush(vm)?; + + server_secret.set_ctx(hash)?; + server_secret.flush(vm)?; + + if handshake_secret.is_complete() + && client_secret.is_complete() + && server_secret.is_complete() + && derived_secret.is_complete() + { + self.state = State::Complete; + } + } + + Ok(()) + } + + /// Sets the hash of the ClientHello message. + pub(crate) fn set_hello_hash(&mut self, hello_hash: [u8; 32]) -> Result<(), FError> { + match &mut self.state { + State::WantsHelloHash => { + self.state = State::HelloHashSet(hello_hash); + Ok(()) + } + _ => Err(FError::state("not in WantsHelloHash state")), + } + } + + /// Whether this functionality is complete. + pub(crate) fn is_complete(&self) -> bool { + matches!(self.state, State::Complete) + } +} + +#[allow(clippy::large_enum_variant)] +#[derive(Debug)] +pub(crate) enum State { + Initialized, + WantsHelloHash, + HelloHashSet([u8; 32]), + Complete, + Error, +} + +impl State { + pub(crate) fn take(&mut self) -> State { + std::mem::replace(self, State::Error) + } +} diff --git a/crates/mpc-tls/src/error.rs b/crates/mpc-tls/src/error.rs index 71b1bd2f03..5ce3f81129 100644 --- a/crates/mpc-tls/src/error.rs +++ b/crates/mpc-tls/src/error.rs @@ -1,4 +1,4 @@ -use hmac_sha256::PrfError; +use hmac_sha256::FError; use key_exchange::KeyExchangeError; use tls_backend::BackendError; @@ -106,8 +106,8 @@ impl From for MpcTlsError { } } -impl From for MpcTlsError { - fn from(value: PrfError) -> Self { +impl From for MpcTlsError { + fn from(value: FError) -> Self { MpcTlsError::hs(value) } } diff --git a/crates/mpc-tls/src/follower.rs b/crates/mpc-tls/src/follower.rs index 241820013f..46b6995c84 100644 --- a/crates/mpc-tls/src/follower.rs +++ b/crates/mpc-tls/src/follower.rs @@ -3,7 +3,7 @@ use crate::{ record_layer::{aead::MpcAesGcm, RecordLayer}, Config, MpcTlsError, Role, SessionKeys, Vm, }; -use hmac_sha256::{MpcPrf, PrfOutput}; +use hmac_sha256::{PrfOutput, Tls12Prf}; use ke::KeyExchange; use key_exchange::{self as ke, MpcKeyExchange}; use mpz_common::{Context, Flush}; @@ -64,7 +64,7 @@ impl MpcTlsFollower { )), )) as Box; - let prf = MpcPrf::new(config.prf); + let prf = Tls12Prf::new(config.prf); let encrypter = MpcAesGcm::new( ShareConversionReceiver::new(OLEReceiver::new(AnyReceiver::new( @@ -436,13 +436,13 @@ enum State { Init { vm: Vm, ke: Box, - prf: MpcPrf, + prf: Tls12Prf, record_layer: RecordLayer, }, Setup { vm: Vm, ke: Box, - prf: MpcPrf, + prf: Tls12Prf, record_layer: RecordLayer, cf_vd: DecodeFutureTyped, sf_vd: DecodeFutureTyped, @@ -450,7 +450,7 @@ enum State { Ready { vm: Vm, ke: Box, - prf: MpcPrf, + prf: Tls12Prf, record_layer: RecordLayer, cf_vd: DecodeFutureTyped, sf_vd: DecodeFutureTyped, diff --git a/crates/mpc-tls/src/leader.rs b/crates/mpc-tls/src/leader.rs index 926fda9936..2528c5a711 100644 --- a/crates/mpc-tls/src/leader.rs +++ b/crates/mpc-tls/src/leader.rs @@ -11,7 +11,7 @@ use crate::{ Config, Role, SessionKeys, Vm, }; use async_trait::async_trait; -use hmac_sha256::{MpcPrf, PrfOutput}; +use hmac_sha256::{PrfOutput, Tls12Prf}; use ke::KeyExchange; use key_exchange::{self as ke, MpcKeyExchange}; use ludi::Context as LudiContext; @@ -92,7 +92,7 @@ impl MpcTlsLeader { ))), )) as Box; - let prf = MpcPrf::new(config.prf); + let prf = Tls12Prf::new(config.prf); let encrypter = MpcAesGcm::new( ShareConversionSender::new(OLESender::new( @@ -1039,14 +1039,14 @@ enum State { ctx: Context, vm: Vm, ke: Box, - prf: MpcPrf, + prf: Tls12Prf, record_layer: RecordLayer, }, Setup { ctx: Context, vm: Vm, ke: Box, - prf: MpcPrf, + prf: Tls12Prf, record_layer: RecordLayer, cf_vd_fut: DecodeFutureTyped, sf_vd_fut: DecodeFutureTyped, @@ -1056,7 +1056,7 @@ enum State { ctx: Context, vm: Vm, ke: Box, - prf: MpcPrf, + prf: Tls12Prf, record_layer: RecordLayer, cf_vd_fut: DecodeFutureTyped, sf_vd_fut: DecodeFutureTyped, @@ -1073,7 +1073,7 @@ enum State { ctx: Context, vm: Vm, _ke: Box, - prf: MpcPrf, + prf: Tls12Prf, record_layer: RecordLayer, cf_vd_fut: DecodeFutureTyped, sf_vd_fut: DecodeFutureTyped,