Skip to content

Commit a47164e

Browse files
author
Marta Mularczyk
committed
fix: Verify that unmerged leaves are sorted
1 parent 738e250 commit a47164e

File tree

2 files changed

+32
-6
lines changed

2 files changed

+32
-6
lines changed

mls-rs/src/tree_kem/mod.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,11 +327,15 @@ impl TreeKemPublic {
327327

328328
fn update_unmerged(&mut self, index: LeafIndex) -> Result<(), MlsError> {
329329
// For a given leaf index, find parent nodes and add the leaf to the unmerged leaf
330-
self.nodes.direct_copath(index).into_iter().for_each(|i| {
330+
for i in self.nodes.direct_copath(index) {
331331
if let Ok(p) = self.nodes.borrow_as_parent_mut(i.path) {
332-
p.unmerged_leaves.push(index)
332+
// Unmerged leaves MUST be sorted and some of our mechanisms rely on this.
333+
match p.unmerged_leaves.binary_search(&index) {
334+
Ok(_) => return Err(MlsError::ParentHashMismatch),
335+
Err(to_insert) => p.unmerged_leaves.insert(to_insert, index),
336+
}
333337
}
334-
});
338+
}
335339

336340
Ok(())
337341
}

mls-rs/src/tree_kem/tree_validator.rs

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,17 +117,28 @@ impl<'a, C: IdentityProvider, CSP: CipherSuiteProvider> TreeValidator<'a, C, CSP
117117
}
118118

119119
fn validate_unmerged(tree: &TreeKemPublic) -> Result<(), MlsError> {
120+
// The entries in the unmerged_leaves vector MUST be sorted in increasing order.
121+
tree.nodes
122+
.iter()
123+
.flatten()
124+
.all(|n| match n {
125+
Node::Leaf(_) => true,
126+
Node::Parent(p) => p.unmerged_leaves.is_sorted(),
127+
})
128+
.then_some(())
129+
.ok_or(MlsError::ParentHashMismatch)?;
130+
120131
let unmerged_sets = tree.nodes.iter().map(|n| {
121132
#[cfg(feature = "std")]
122133
if let Some(Node::Parent(p)) = n {
123-
HashSet::from_iter(p.unmerged_leaves.iter().cloned())
134+
HashSet::from_iter(p.unmerged_leaves.iter())
124135
} else {
125136
HashSet::new()
126137
}
127138

128139
#[cfg(not(feature = "std"))]
129140
if let Some(Node::Parent(p)) = n {
130-
p.unmerged_leaves.clone()
141+
&p.unmerged_leaves
131142
} else {
132143
vec![]
133144
}
@@ -152,7 +163,7 @@ fn validate_unmerged(tree: &TreeKemPublic) -> Result<(), MlsError> {
152163
let parent_node = tree.nodes.borrow_as_parent(ps.parent)?;
153164

154165
if parent_node.unmerged_leaves.contains(&index) {
155-
unmerged_sets[ps.parent as usize].retain(|i| i != &index);
166+
unmerged_sets[ps.parent as usize].retain(|i| **i != index);
156167

157168
n = ps.parent;
158169
} else {
@@ -366,4 +377,15 @@ mod tests {
366377
Err(MlsError::UnmergedLeavesMismatch)
367378
);
368379
}
380+
381+
#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
382+
async fn verify_unmerged_leaves_sorted() {
383+
let mut tree = get_test_tree_fig_12(TEST_CIPHER_SUITE).await;
384+
385+
// Set unsorted unmerged leaves
386+
tree.nodes.borrow_as_parent_mut(3).unwrap().unmerged_leaves =
387+
vec![LeafIndex::unchecked(3), LeafIndex::unchecked(1)];
388+
389+
assert_matches!(validate_unmerged(&tree), Err(MlsError::ParentHashMismatch));
390+
}
369391
}

0 commit comments

Comments
 (0)