From dfedc10c8a05ffcfe51ac168ace6acada85a3ea7 Mon Sep 17 00:00:00 2001 From: Tom Leavy Date: Tue, 30 Sep 2025 16:10:22 -0400 Subject: [PATCH 1/3] fix: improve edge case detection in HashMap and BTreeMap codec decoding --- mls-rs-codec/src/lib.rs | 2 + mls-rs-codec/src/map.rs | 241 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 241 insertions(+), 2 deletions(-) diff --git a/mls-rs-codec/src/lib.rs b/mls-rs-codec/src/lib.rs index c6d754b80..0973a12d8 100644 --- a/mls-rs-codec/src/lib.rs +++ b/mls-rs-codec/src/lib.rs @@ -50,6 +50,8 @@ pub enum Error { UnsupportedEnumDiscriminant, #[cfg_attr(feature = "std", error("Expected UTF-8 string"))] Utf8, + #[cfg_attr(feature = "std", error("Invalid map content"))] + InvalidMapContent, #[cfg_attr(feature = "std", error("mls codec error: {0}"))] Custom(u8), } diff --git a/mls-rs-codec/src/map.rs b/mls-rs-codec/src/map.rs index 9c7f3bec6..a5250d5b7 100644 --- a/mls-rs-codec/src/map.rs +++ b/mls-rs-codec/src/map.rs @@ -43,7 +43,13 @@ where let mut items = HashMap::new(); while !data.is_empty() { - items.insert(K::mls_decode(data)?, V::mls_decode(data)?); + let before = data.len(); + let key = K::mls_decode(data)?; + let value = V::mls_decode(data)?; + + if data.len() == before || items.insert(key, value).is_some() { + return Err(crate::Error::InvalidMapContent); + } } Ok(items) @@ -81,10 +87,241 @@ where let mut items = BTreeMap::new(); while !data.is_empty() { - items.insert(K::mls_decode(data)?, V::mls_decode(data)?); + let before = data.len(); + let key = K::mls_decode(data)?; + let value = V::mls_decode(data)?; + + if data.len() == before || items.insert(key, value).is_some() { + return Err(crate::Error::InvalidMapContent); + } } Ok(items) }) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{MlsDecode, MlsEncode}; + use assert_matches::assert_matches; + + #[cfg(feature = "std")] + #[test] + fn test_basic_hashmap_roundtrip() { + let mut original = HashMap::new(); + original.insert(1u32, 100); + original.insert(2u32, 200); + original.insert(3u32, 300); + original.insert(4u32, 100); + + let mut encoded = Vec::new(); + original.mls_encode(&mut encoded).unwrap(); + + let mut slice = encoded.as_slice(); + let decoded = HashMap::::mls_decode(&mut slice).unwrap(); + + assert_eq!(original, decoded); + assert!(slice.is_empty()); + } + + #[test] + fn test_basic_btreemap_roundtrip() { + let mut original = BTreeMap::new(); + original.insert(1u32, 100); + original.insert(2u32, 200); + original.insert(3u32, 300); + original.insert(4u32, 100); + + let mut encoded = Vec::new(); + original.mls_encode(&mut encoded).unwrap(); + + let mut slice = encoded.as_slice(); + let decoded = BTreeMap::::mls_decode(&mut slice).unwrap(); + + assert_eq!(original, decoded); + assert!(slice.is_empty()); + } + + #[cfg(feature = "std")] + #[test] + fn test_empty_structure_in_hashmap() { + let mut original: HashMap = HashMap::new(); + original.insert(1u8, []); + original.insert(2u8, []); + + let mut encoded = Vec::new(); + original.mls_encode(&mut encoded).unwrap(); + + let mut slice = encoded.as_slice(); + let decoded = HashMap::::mls_decode(&mut slice).unwrap(); + assert_eq!(original, decoded); + assert!(slice.is_empty()); + } + + #[cfg(feature = "std")] + #[test] + fn hashmap_zero_length_structure() { + let res = HashMap::<[u8; 0], [u8; 0]>::mls_decode(&mut &[0x01, 0xff][..]); + assert_matches!(res, Err(crate::Error::InvalidMapContent)) + } + + #[cfg(feature = "std")] + #[test] + fn hashmap_will_not_allow_duplicate_keys() { + let mut encoded = Vec::new(); + + vec![(1u8, 2u8), (3u8, 4u8), (1u8, 5u8)] + .mls_encode(&mut encoded) + .unwrap(); + + let res = HashMap::::mls_decode(&mut &*encoded); + assert_matches!(res, Err(crate::Error::InvalidMapContent)) + } + + #[test] + fn btree_map_will_not_allow_duplicate_keys() { + let mut encoded = Vec::new(); + + vec![(1u8, 2u8), (3u8, 4u8), (1u8, 5u8)] + .mls_encode(&mut encoded) + .unwrap(); + + let res = BTreeMap::::mls_decode(&mut &*encoded); + assert_matches!(res, Err(crate::Error::InvalidMapContent)) + } + + #[test] + fn btree_map_zero_length_structure() { + let res = BTreeMap::<[u8; 0], [u8; 0]>::mls_decode(&mut &[0x01, 0xff][..]); + assert_matches!(res, Err(crate::Error::InvalidMapContent)) + } + + #[cfg(feature = "std")] + #[test] + fn test_hashmap_encoding_order() { + let mut hash = HashMap::new(); + hash.insert(3u32, "c".to_string()); + hash.insert(1u32, "a".to_string()); + hash.insert(2u32, "b".to_string()); + + let mut btree = BTreeMap::new(); + btree.insert(3u32, "c".to_string()); + btree.insert(1u32, "a".to_string()); + btree.insert(2u32, "b".to_string()); + + let mut hash_encoded = Vec::new(); + hash.mls_encode(&mut hash_encoded).unwrap(); + + let mut btree_encoded = Vec::new(); + btree.mls_encode(&mut btree_encoded).unwrap(); + + assert_eq!(hash_encoded, btree_encoded); + } + + #[cfg(feature = "std")] + #[test] + fn test_empty_hashmap() { + let empty_hash: HashMap = HashMap::new(); + let mut encoded = Vec::new(); + empty_hash.mls_encode(&mut encoded).unwrap(); + + let mut slice = encoded.as_slice(); + let decoded = HashMap::::mls_decode(&mut slice).unwrap(); + assert!(decoded.is_empty()); + assert!(slice.is_empty()); + } + + #[test] + fn test_empty_btreemap() { + let empty_btree: BTreeMap = BTreeMap::new(); + let mut encoded = Vec::new(); + empty_btree.mls_encode(&mut encoded).unwrap(); + + let mut slice = encoded.as_slice(); + let decoded = BTreeMap::::mls_decode(&mut slice).unwrap(); + assert!(decoded.is_empty()); + assert!(slice.is_empty()); + } + + #[cfg(feature = "std")] + #[test] + fn test_large_hashmap() { + let mut large_map = HashMap::new(); + for i in 0..1000u32 { + large_map.insert(i, i * 2); + } + + let mut encoded = Vec::new(); + large_map.mls_encode(&mut encoded).unwrap(); + + let mut slice = encoded.as_slice(); + let decoded = HashMap::::mls_decode(&mut slice).unwrap(); + + assert_eq!(large_map, decoded); + assert!(slice.is_empty()); + } + + #[test] + fn test_large_btreemap() { + let mut large_map = BTreeMap::new(); + for i in 0..1000u32 { + large_map.insert(i, i * 2); + } + + let mut encoded = Vec::new(); + large_map.mls_encode(&mut encoded).unwrap(); + + let mut slice = encoded.as_slice(); + let decoded = BTreeMap::::mls_decode(&mut slice).unwrap(); + + assert_eq!(large_map, decoded); + assert!(slice.is_empty()); + } + + #[test] + fn test_invalid_btreemap_decode() { + // Test with invalid data + let invalid_data = vec![0xFF, 0xFF, 0xFF, 0xFF]; // Invalid length prefix + let mut slice = invalid_data.as_slice(); + + let result = BTreeMap::::mls_decode(&mut slice); + assert!(result.is_err()); + + // Test with truncated data + let mut valid_map = BTreeMap::new(); + valid_map.insert(1u32, 100u32); + + let mut encoded = Vec::new(); + valid_map.mls_encode(&mut encoded).unwrap(); + encoded.truncate(encoded.len() - 1); // Remove last byte + + let mut slice = encoded.as_slice(); + let result = BTreeMap::::mls_decode(&mut slice); + assert!(result.is_err()); + } + + #[cfg(feature = "std")] + #[test] + fn test_invalid_hashmap_decode() { + // Test with invalid data + let invalid_data = vec![0xFF, 0xFF, 0xFF, 0xFF]; // Invalid length prefix + let mut slice = invalid_data.as_slice(); + + let result = HashMap::::mls_decode(&mut slice); + assert!(result.is_err()); + + // Test with truncated data + let mut valid_map = HashMap::new(); + valid_map.insert(1u32, 100u32); + + let mut encoded = Vec::new(); + valid_map.mls_encode(&mut encoded).unwrap(); + encoded.truncate(encoded.len() - 1); // Remove last byte + + let mut slice = encoded.as_slice(); + let result = HashMap::::mls_decode(&mut slice); + assert!(result.is_err()); + } +} From 9ed3c15a6c7b22ad5133abfc70d9ee5abd2a374c Mon Sep 17 00:00:00 2001 From: Tom Leavy Date: Wed, 1 Oct 2025 13:21:38 -0400 Subject: [PATCH 2/3] apply this fix to Vec as well --- mls-rs-codec/src/lib.rs | 4 ++-- mls-rs-codec/src/map.rs | 12 ++++++------ mls-rs-codec/src/vec.rs | 15 ++++++++++++++- 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/mls-rs-codec/src/lib.rs b/mls-rs-codec/src/lib.rs index 0973a12d8..72e66c231 100644 --- a/mls-rs-codec/src/lib.rs +++ b/mls-rs-codec/src/lib.rs @@ -50,8 +50,8 @@ pub enum Error { UnsupportedEnumDiscriminant, #[cfg_attr(feature = "std", error("Expected UTF-8 string"))] Utf8, - #[cfg_attr(feature = "std", error("Invalid map content"))] - InvalidMapContent, + #[cfg_attr(feature = "std", error("Invalid content"))] + InvalidContent, #[cfg_attr(feature = "std", error("mls codec error: {0}"))] Custom(u8), } diff --git a/mls-rs-codec/src/map.rs b/mls-rs-codec/src/map.rs index a5250d5b7..c71244ff4 100644 --- a/mls-rs-codec/src/map.rs +++ b/mls-rs-codec/src/map.rs @@ -48,7 +48,7 @@ where let value = V::mls_decode(data)?; if data.len() == before || items.insert(key, value).is_some() { - return Err(crate::Error::InvalidMapContent); + return Err(crate::Error::InvalidContent); } } @@ -92,7 +92,7 @@ where let value = V::mls_decode(data)?; if data.len() == before || items.insert(key, value).is_some() { - return Err(crate::Error::InvalidMapContent); + return Err(crate::Error::InvalidContent); } } @@ -164,7 +164,7 @@ mod tests { #[test] fn hashmap_zero_length_structure() { let res = HashMap::<[u8; 0], [u8; 0]>::mls_decode(&mut &[0x01, 0xff][..]); - assert_matches!(res, Err(crate::Error::InvalidMapContent)) + assert_matches!(res, Err(crate::Error::InvalidContent)) } #[cfg(feature = "std")] @@ -177,7 +177,7 @@ mod tests { .unwrap(); let res = HashMap::::mls_decode(&mut &*encoded); - assert_matches!(res, Err(crate::Error::InvalidMapContent)) + assert_matches!(res, Err(crate::Error::InvalidContent)) } #[test] @@ -189,13 +189,13 @@ mod tests { .unwrap(); let res = BTreeMap::::mls_decode(&mut &*encoded); - assert_matches!(res, Err(crate::Error::InvalidMapContent)) + assert_matches!(res, Err(crate::Error::InvalidContent)) } #[test] fn btree_map_zero_length_structure() { let res = BTreeMap::<[u8; 0], [u8; 0]>::mls_decode(&mut &[0x01, 0xff][..]); - assert_matches!(res, Err(crate::Error::InvalidMapContent)) + assert_matches!(res, Err(crate::Error::InvalidContent)) } #[cfg(feature = "std")] diff --git a/mls-rs-codec/src/vec.rs b/mls-rs-codec/src/vec.rs index 3d7f1a8a4..05a47a363 100644 --- a/mls-rs-codec/src/vec.rs +++ b/mls-rs-codec/src/vec.rs @@ -53,7 +53,14 @@ where let mut items = Vec::new(); while !data.is_empty() { - items.push(T::mls_decode(data)?); + let before = data.len(); + let value = T::mls_decode(data)?; + + if data.len() == before { + return Err(crate::Error::InvalidContent); + } + + items.push(value); } Ok(items) @@ -97,4 +104,10 @@ mod tests { Err(Error::UnexpectedEOF) ); } + + #[test] + fn vec_zero_length_structure() { + let res = Vec::<[u8; 0]>::mls_decode(&mut &[0x01, 0xff][..]); + assert_matches!(res, Err(crate::Error::InvalidContent)) + } } From 15e03632493cc75ad90ce10928f22514425eebb3 Mon Sep 17 00:00:00 2001 From: Tom Leavy Date: Mon, 13 Oct 2025 16:38:02 -0400 Subject: [PATCH 3/3] fix generic-array version --- mls-rs-crypto-rustcrypto/Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/mls-rs-crypto-rustcrypto/Cargo.toml b/mls-rs-crypto-rustcrypto/Cargo.toml index 4d2c814a6..cb2fe324f 100644 --- a/mls-rs-crypto-rustcrypto/Cargo.toml +++ b/mls-rs-crypto-rustcrypto/Cargo.toml @@ -43,6 +43,7 @@ rand_core = { version = "0.6", default-features = false, features = ["alloc"] } # AEAD aes-gcm = { version = "0.10", features = ["zeroize"] } chacha20poly1305 = { version = "0.10", default-features = false, features = ["alloc", "getrandom"] } +generic-array = { version = "=0.14.7", default-features = false } aead = { version = "0.5", default-features = false, features = ["alloc", "getrandom"] } # Hash