diff --git a/mls-rs/src/tree_kem/parent_hash.rs b/mls-rs/src/tree_kem/parent_hash.rs index c81da833..94720a7c 100644 --- a/mls-rs/src/tree_kem/parent_hash.rs +++ b/mls-rs/src/tree_kem/parent_hash.rs @@ -6,6 +6,7 @@ use crate::client::MlsError; use crate::crypto::{CipherSuiteProvider, HpkePublicKey}; use crate::tree_kem::math as tree_math; use crate::tree_kem::node::{LeafIndex, Node, NodeIndex}; +use crate::tree_kem::tree_hash::TreeHash; use crate::tree_kem::TreeKemPublic; use alloc::vec::Vec; use core::{ @@ -198,77 +199,14 @@ impl TreeKemPublic { // For each leaf l, validate all non-blank nodes on the chain from l up the tree. for (leaf_index, _) in self.nodes.non_empty_leaves() { - let mut n = NodeIndex::from(leaf_index); - - while let Some(mut ps) = n.parent_sibling(&num_leaves) { - // Find the first non-blank ancestor p of n and p's co-path child s. - while self.nodes.is_blank(ps.parent)? { - // If we reached the root, we're done with this chain. - let Some(ps_parent) = ps.parent.parent_sibling(&num_leaves) else { - return Ok(()); - }; - - ps = ps_parent; - } - - // Check is n's parent_hash field matches the parent hash of p with co-path child s. - let p_parent = self.nodes.borrow_as_parent(ps.parent)?; - - let n_node = self - .nodes - .borrow_node(n)? - .as_ref() - .ok_or(MlsError::ExpectedNode)?; - - let calculated = ParentHash::new( - cipher_suite_provider, - &p_parent.public_key, - &p_parent.parent_hash, - &original_hashes[ps.sibling as usize], - ) - .await?; - - if n_node.get_parent_hash() == Some(calculated) { - // Check that "n is in the resolution of c, and the intersection of p's unmerged_leaves with the subtree - // under c is equal to the resolution of c with n removed". - let Some(cp) = ps.sibling.parent_sibling(&num_leaves) else { - return Err(MlsError::ParentHashMismatch); - }; - - let c = cp.sibling; - let c_resolution = self.nodes.get_resolution_index(c)?.into_iter(); - - #[cfg(feature = "std")] - let mut c_resolution = c_resolution.collect::>(); - #[cfg(not(feature = "std"))] - let mut c_resolution = c_resolution.collect::>(); - - let p_unmerged_in_c_subtree = self - .unmerged_in_subtree(ps.parent, c)? - .iter() - .copied() - .map(|x| *x * 2); - - #[cfg(feature = "std")] - let p_unmerged_in_c_subtree = p_unmerged_in_c_subtree.collect::>(); - #[cfg(not(feature = "std"))] - let p_unmerged_in_c_subtree = p_unmerged_in_c_subtree.collect::>(); - - if c_resolution.remove(&n) - && c_resolution == p_unmerged_in_c_subtree - && nodes_to_validate.remove(&ps.parent) - { - // If n's parent_hash field matches and p has not been validated yet, mark p as validated and continue. - n = ps.parent; - } else { - // If p is validated for the second time, the check fails ("all non-blank parent nodes are covered by exactly one such chain"). - return Err(MlsError::ParentHashMismatch); - } - } else { - // If n's parent_hash field doesn't match, we're done with this chain. - break; - } - } + self.validate_chain( + leaf_index, + num_leaves, + cipher_suite_provider, + &original_hashes, + &mut nodes_to_validate, + ) + .await?; } // The check passes iff all non-blank nodes are validated. @@ -278,6 +216,91 @@ impl TreeKemPublic { Err(MlsError::ParentHashMismatch) } } + + #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] + async fn validate_chain( + &self, + leaf_index: LeafIndex, + num_leaves: u32, + cipher_suite_provider: &P, + original_hashes: &[TreeHash], + #[cfg(feature = "std")] nodes_to_validate: &mut HashSet, + #[cfg(not(feature = "std"))] nodes_to_validate: &mut BTreeSet, + ) -> Result<(), MlsError> { + let mut n = NodeIndex::from(leaf_index); + + while let Some(mut ps) = n.parent_sibling(&num_leaves) { + // Find the first non-blank ancestor p of n and p's co-path child s. + while self.nodes.is_blank(ps.parent)? { + // If we reached the root, we're done with this chain. + let Some(ps_parent) = ps.parent.parent_sibling(&num_leaves) else { + return Ok(()); + }; + + ps = ps_parent; + } + + // Check is n's parent_hash field matches the parent hash of p with co-path child s. + let p_parent = self.nodes.borrow_as_parent(ps.parent)?; + + let n_node = self + .nodes + .borrow_node(n)? + .as_ref() + .ok_or(MlsError::ExpectedNode)?; + + let calculated = ParentHash::new( + cipher_suite_provider, + &p_parent.public_key, + &p_parent.parent_hash, + &original_hashes[ps.sibling as usize], + ) + .await?; + + if n_node.get_parent_hash() == Some(calculated) { + // Check that "n is in the resolution of c, and the intersection of p's unmerged_leaves with the subtree + // under c is equal to the resolution of c with n removed". + let Some(cp) = ps.sibling.parent_sibling(&num_leaves) else { + return Err(MlsError::ParentHashMismatch); + }; + + let c = cp.sibling; + let c_resolution = self.nodes.get_resolution_index(c)?.into_iter(); + + #[cfg(feature = "std")] + let mut c_resolution = c_resolution.collect::>(); + #[cfg(not(feature = "std"))] + let mut c_resolution = c_resolution.collect::>(); + + let p_unmerged_in_c_subtree = self + .unmerged_in_subtree(ps.parent, c)? + .iter() + .copied() + .map(|x| *x * 2); + + #[cfg(feature = "std")] + let p_unmerged_in_c_subtree = p_unmerged_in_c_subtree.collect::>(); + #[cfg(not(feature = "std"))] + let p_unmerged_in_c_subtree = p_unmerged_in_c_subtree.collect::>(); + + if c_resolution.remove(&n) + && c_resolution == p_unmerged_in_c_subtree + && nodes_to_validate.remove(&ps.parent) + { + // If n's parent_hash field matches and p has not been validated yet, mark p as validated and continue. + n = ps.parent; + } else { + // If p is validated for the second time, the check fails ("all non-blank parent nodes are covered by exactly one such chain"). + return Err(MlsError::ParentHashMismatch); + } + } else { + // If n's parent_hash field doesn't match, we're done with this chain. + break; + } + } + + Ok(()) + } } #[cfg(test)] @@ -379,6 +402,9 @@ mod tests { use crate::tree_kem::MlsError; use assert_matches::assert_matches; + #[cfg(feature = "rfc_compliant")] + use alloc::vec; + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn test_missing_parent_hash() { let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE); @@ -440,4 +466,27 @@ mod tests { assert_matches!(res, Err(MlsError::ParentHashMismatch)); } + + #[cfg(feature = "rfc_compliant")] + #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] + async fn test_premature_validation_termination() { + let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE); + let mut test_tree = TreeWithSigners::make_full_tree(8, &cs).await; + test_tree.remove_member(6); + + // Corrupt a parent hash that should be validated but might be skipped due to early return + test_tree + .tree + .nodes + .borrow_as_parent_mut(9) + .unwrap() + .parent_hash = ParentHash::from(vec![0xFF; 32]); + + let res = test_tree + .tree + .validate_parent_hashes(&test_cipher_suite_provider(TEST_CIPHER_SUITE)) + .await; + + assert_matches!(res, Err(MlsError::ParentHashMismatch)); + } }