Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 120 additions & 71 deletions mls-rs/src/tree_kem/parent_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::client::MlsError;
use crate::crypto::{CipherSuiteProvider, HpkePublicKey};
use crate::tree_kem::math as tree_math;
use crate::tree_kem::node::{LeafIndex, Node, NodeIndex};
use crate::tree_kem::tree_hash::TreeHash;
use crate::tree_kem::TreeKemPublic;
use alloc::vec::Vec;
use core::{
Expand Down Expand Up @@ -198,77 +199,14 @@ impl TreeKemPublic {

// For each leaf l, validate all non-blank nodes on the chain from l up the tree.
for (leaf_index, _) in self.nodes.non_empty_leaves() {
let mut n = NodeIndex::from(leaf_index);

while let Some(mut ps) = n.parent_sibling(&num_leaves) {
// Find the first non-blank ancestor p of n and p's co-path child s.
while self.nodes.is_blank(ps.parent)? {
// If we reached the root, we're done with this chain.
let Some(ps_parent) = ps.parent.parent_sibling(&num_leaves) else {
return Ok(());
};

ps = ps_parent;
}

// Check is n's parent_hash field matches the parent hash of p with co-path child s.
let p_parent = self.nodes.borrow_as_parent(ps.parent)?;

let n_node = self
.nodes
.borrow_node(n)?
.as_ref()
.ok_or(MlsError::ExpectedNode)?;

let calculated = ParentHash::new(
cipher_suite_provider,
&p_parent.public_key,
&p_parent.parent_hash,
&original_hashes[ps.sibling as usize],
)
.await?;

if n_node.get_parent_hash() == Some(calculated) {
// Check that "n is in the resolution of c, and the intersection of p's unmerged_leaves with the subtree
// under c is equal to the resolution of c with n removed".
let Some(cp) = ps.sibling.parent_sibling(&num_leaves) else {
return Err(MlsError::ParentHashMismatch);
};

let c = cp.sibling;
let c_resolution = self.nodes.get_resolution_index(c)?.into_iter();

#[cfg(feature = "std")]
let mut c_resolution = c_resolution.collect::<HashSet<_>>();
#[cfg(not(feature = "std"))]
let mut c_resolution = c_resolution.collect::<BTreeSet<_>>();

let p_unmerged_in_c_subtree = self
.unmerged_in_subtree(ps.parent, c)?
.iter()
.copied()
.map(|x| *x * 2);

#[cfg(feature = "std")]
let p_unmerged_in_c_subtree = p_unmerged_in_c_subtree.collect::<HashSet<_>>();
#[cfg(not(feature = "std"))]
let p_unmerged_in_c_subtree = p_unmerged_in_c_subtree.collect::<BTreeSet<_>>();

if c_resolution.remove(&n)
&& c_resolution == p_unmerged_in_c_subtree
&& nodes_to_validate.remove(&ps.parent)
{
// If n's parent_hash field matches and p has not been validated yet, mark p as validated and continue.
n = ps.parent;
} else {
// If p is validated for the second time, the check fails ("all non-blank parent nodes are covered by exactly one such chain").
return Err(MlsError::ParentHashMismatch);
}
} else {
// If n's parent_hash field doesn't match, we're done with this chain.
break;
}
}
self.validate_chain(
leaf_index,
num_leaves,
cipher_suite_provider,
&original_hashes,
&mut nodes_to_validate,
)
.await?;
}

// The check passes iff all non-blank nodes are validated.
Expand All @@ -278,6 +216,91 @@ impl TreeKemPublic {
Err(MlsError::ParentHashMismatch)
}
}

#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
async fn validate_chain<P: CipherSuiteProvider>(
&self,
leaf_index: LeafIndex,
num_leaves: u32,
cipher_suite_provider: &P,
original_hashes: &[TreeHash],
#[cfg(feature = "std")] nodes_to_validate: &mut HashSet<u32>,
#[cfg(not(feature = "std"))] nodes_to_validate: &mut BTreeSet<u32>,
) -> Result<(), MlsError> {
let mut n = NodeIndex::from(leaf_index);

while let Some(mut ps) = n.parent_sibling(&num_leaves) {
// Find the first non-blank ancestor p of n and p's co-path child s.
while self.nodes.is_blank(ps.parent)? {
// If we reached the root, we're done with this chain.
let Some(ps_parent) = ps.parent.parent_sibling(&num_leaves) else {
return Ok(());
};

ps = ps_parent;
}

// Check is n's parent_hash field matches the parent hash of p with co-path child s.
let p_parent = self.nodes.borrow_as_parent(ps.parent)?;

let n_node = self
.nodes
.borrow_node(n)?
.as_ref()
.ok_or(MlsError::ExpectedNode)?;

let calculated = ParentHash::new(
cipher_suite_provider,
&p_parent.public_key,
&p_parent.parent_hash,
&original_hashes[ps.sibling as usize],
)
.await?;

if n_node.get_parent_hash() == Some(calculated) {
// Check that "n is in the resolution of c, and the intersection of p's unmerged_leaves with the subtree
// under c is equal to the resolution of c with n removed".
let Some(cp) = ps.sibling.parent_sibling(&num_leaves) else {
return Err(MlsError::ParentHashMismatch);
};

let c = cp.sibling;
let c_resolution = self.nodes.get_resolution_index(c)?.into_iter();

#[cfg(feature = "std")]
let mut c_resolution = c_resolution.collect::<HashSet<_>>();
#[cfg(not(feature = "std"))]
let mut c_resolution = c_resolution.collect::<BTreeSet<_>>();

let p_unmerged_in_c_subtree = self
.unmerged_in_subtree(ps.parent, c)?
.iter()
.copied()
.map(|x| *x * 2);

#[cfg(feature = "std")]
let p_unmerged_in_c_subtree = p_unmerged_in_c_subtree.collect::<HashSet<_>>();
#[cfg(not(feature = "std"))]
let p_unmerged_in_c_subtree = p_unmerged_in_c_subtree.collect::<BTreeSet<_>>();

if c_resolution.remove(&n)
&& c_resolution == p_unmerged_in_c_subtree
&& nodes_to_validate.remove(&ps.parent)
{
// If n's parent_hash field matches and p has not been validated yet, mark p as validated and continue.
n = ps.parent;
} else {
// If p is validated for the second time, the check fails ("all non-blank parent nodes are covered by exactly one such chain").
return Err(MlsError::ParentHashMismatch);
}
} else {
// If n's parent_hash field doesn't match, we're done with this chain.
break;
}
}

Ok(())
}
}

#[cfg(test)]
Expand Down Expand Up @@ -379,6 +402,9 @@ mod tests {
use crate::tree_kem::MlsError;
use assert_matches::assert_matches;

#[cfg(feature = "rfc_compliant")]
use alloc::vec;

#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn test_missing_parent_hash() {
let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
Expand Down Expand Up @@ -440,4 +466,27 @@ mod tests {

assert_matches!(res, Err(MlsError::ParentHashMismatch));
}

#[cfg(feature = "rfc_compliant")]
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
async fn test_premature_validation_termination() {
let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
let mut test_tree = TreeWithSigners::make_full_tree(8, &cs).await;
test_tree.remove_member(6);

// Corrupt a parent hash that should be validated but might be skipped due to early return
test_tree
.tree
.nodes
.borrow_as_parent_mut(9)
.unwrap()
.parent_hash = ParentHash::from(vec![0xFF; 32]);

let res = test_tree
.tree
.validate_parent_hashes(&test_cipher_suite_provider(TEST_CIPHER_SUITE))
.await;

assert_matches!(res, Err(MlsError::ParentHashMismatch));
}
}
Loading