diff --git a/mls-rs-codec/src/lib.rs b/mls-rs-codec/src/lib.rs index c6d754b8..72e66c23 100644 --- a/mls-rs-codec/src/lib.rs +++ b/mls-rs-codec/src/lib.rs @@ -50,6 +50,8 @@ pub enum Error { UnsupportedEnumDiscriminant, #[cfg_attr(feature = "std", error("Expected UTF-8 string"))] Utf8, + #[cfg_attr(feature = "std", error("Invalid content"))] + InvalidContent, #[cfg_attr(feature = "std", error("mls codec error: {0}"))] Custom(u8), } diff --git a/mls-rs-codec/src/map.rs b/mls-rs-codec/src/map.rs index 9c7f3bec..c71244ff 100644 --- a/mls-rs-codec/src/map.rs +++ b/mls-rs-codec/src/map.rs @@ -43,7 +43,13 @@ where let mut items = HashMap::new(); while !data.is_empty() { - items.insert(K::mls_decode(data)?, V::mls_decode(data)?); + let before = data.len(); + let key = K::mls_decode(data)?; + let value = V::mls_decode(data)?; + + if data.len() == before || items.insert(key, value).is_some() { + return Err(crate::Error::InvalidContent); + } } Ok(items) @@ -81,10 +87,241 @@ where let mut items = BTreeMap::new(); while !data.is_empty() { - items.insert(K::mls_decode(data)?, V::mls_decode(data)?); + let before = data.len(); + let key = K::mls_decode(data)?; + let value = V::mls_decode(data)?; + + if data.len() == before || items.insert(key, value).is_some() { + return Err(crate::Error::InvalidContent); + } } Ok(items) }) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{MlsDecode, MlsEncode}; + use assert_matches::assert_matches; + + #[cfg(feature = "std")] + #[test] + fn test_basic_hashmap_roundtrip() { + let mut original = HashMap::new(); + original.insert(1u32, 100); + original.insert(2u32, 200); + original.insert(3u32, 300); + original.insert(4u32, 100); + + let mut encoded = Vec::new(); + original.mls_encode(&mut encoded).unwrap(); + + let mut slice = encoded.as_slice(); + let decoded = HashMap::::mls_decode(&mut slice).unwrap(); + + assert_eq!(original, decoded); + assert!(slice.is_empty()); + } + + #[test] + fn test_basic_btreemap_roundtrip() { + let mut original = BTreeMap::new(); + original.insert(1u32, 100); + original.insert(2u32, 200); + original.insert(3u32, 300); + original.insert(4u32, 100); + + let mut encoded = Vec::new(); + original.mls_encode(&mut encoded).unwrap(); + + let mut slice = encoded.as_slice(); + let decoded = BTreeMap::::mls_decode(&mut slice).unwrap(); + + assert_eq!(original, decoded); + assert!(slice.is_empty()); + } + + #[cfg(feature = "std")] + #[test] + fn test_empty_structure_in_hashmap() { + let mut original: HashMap = HashMap::new(); + original.insert(1u8, []); + original.insert(2u8, []); + + let mut encoded = Vec::new(); + original.mls_encode(&mut encoded).unwrap(); + + let mut slice = encoded.as_slice(); + let decoded = HashMap::::mls_decode(&mut slice).unwrap(); + assert_eq!(original, decoded); + assert!(slice.is_empty()); + } + + #[cfg(feature = "std")] + #[test] + fn hashmap_zero_length_structure() { + let res = HashMap::<[u8; 0], [u8; 0]>::mls_decode(&mut &[0x01, 0xff][..]); + assert_matches!(res, Err(crate::Error::InvalidContent)) + } + + #[cfg(feature = "std")] + #[test] + fn hashmap_will_not_allow_duplicate_keys() { + let mut encoded = Vec::new(); + + vec![(1u8, 2u8), (3u8, 4u8), (1u8, 5u8)] + .mls_encode(&mut encoded) + .unwrap(); + + let res = HashMap::::mls_decode(&mut &*encoded); + assert_matches!(res, Err(crate::Error::InvalidContent)) + } + + #[test] + fn btree_map_will_not_allow_duplicate_keys() { + let mut encoded = Vec::new(); + + vec![(1u8, 2u8), (3u8, 4u8), (1u8, 5u8)] + .mls_encode(&mut encoded) + .unwrap(); + + let res = BTreeMap::::mls_decode(&mut &*encoded); + assert_matches!(res, Err(crate::Error::InvalidContent)) + } + + #[test] + fn btree_map_zero_length_structure() { + let res = BTreeMap::<[u8; 0], [u8; 0]>::mls_decode(&mut &[0x01, 0xff][..]); + assert_matches!(res, Err(crate::Error::InvalidContent)) + } + + #[cfg(feature = "std")] + #[test] + fn test_hashmap_encoding_order() { + let mut hash = HashMap::new(); + hash.insert(3u32, "c".to_string()); + hash.insert(1u32, "a".to_string()); + hash.insert(2u32, "b".to_string()); + + let mut btree = BTreeMap::new(); + btree.insert(3u32, "c".to_string()); + btree.insert(1u32, "a".to_string()); + btree.insert(2u32, "b".to_string()); + + let mut hash_encoded = Vec::new(); + hash.mls_encode(&mut hash_encoded).unwrap(); + + let mut btree_encoded = Vec::new(); + btree.mls_encode(&mut btree_encoded).unwrap(); + + assert_eq!(hash_encoded, btree_encoded); + } + + #[cfg(feature = "std")] + #[test] + fn test_empty_hashmap() { + let empty_hash: HashMap = HashMap::new(); + let mut encoded = Vec::new(); + empty_hash.mls_encode(&mut encoded).unwrap(); + + let mut slice = encoded.as_slice(); + let decoded = HashMap::::mls_decode(&mut slice).unwrap(); + assert!(decoded.is_empty()); + assert!(slice.is_empty()); + } + + #[test] + fn test_empty_btreemap() { + let empty_btree: BTreeMap = BTreeMap::new(); + let mut encoded = Vec::new(); + empty_btree.mls_encode(&mut encoded).unwrap(); + + let mut slice = encoded.as_slice(); + let decoded = BTreeMap::::mls_decode(&mut slice).unwrap(); + assert!(decoded.is_empty()); + assert!(slice.is_empty()); + } + + #[cfg(feature = "std")] + #[test] + fn test_large_hashmap() { + let mut large_map = HashMap::new(); + for i in 0..1000u32 { + large_map.insert(i, i * 2); + } + + let mut encoded = Vec::new(); + large_map.mls_encode(&mut encoded).unwrap(); + + let mut slice = encoded.as_slice(); + let decoded = HashMap::::mls_decode(&mut slice).unwrap(); + + assert_eq!(large_map, decoded); + assert!(slice.is_empty()); + } + + #[test] + fn test_large_btreemap() { + let mut large_map = BTreeMap::new(); + for i in 0..1000u32 { + large_map.insert(i, i * 2); + } + + let mut encoded = Vec::new(); + large_map.mls_encode(&mut encoded).unwrap(); + + let mut slice = encoded.as_slice(); + let decoded = BTreeMap::::mls_decode(&mut slice).unwrap(); + + assert_eq!(large_map, decoded); + assert!(slice.is_empty()); + } + + #[test] + fn test_invalid_btreemap_decode() { + // Test with invalid data + let invalid_data = vec![0xFF, 0xFF, 0xFF, 0xFF]; // Invalid length prefix + let mut slice = invalid_data.as_slice(); + + let result = BTreeMap::::mls_decode(&mut slice); + assert!(result.is_err()); + + // Test with truncated data + let mut valid_map = BTreeMap::new(); + valid_map.insert(1u32, 100u32); + + let mut encoded = Vec::new(); + valid_map.mls_encode(&mut encoded).unwrap(); + encoded.truncate(encoded.len() - 1); // Remove last byte + + let mut slice = encoded.as_slice(); + let result = BTreeMap::::mls_decode(&mut slice); + assert!(result.is_err()); + } + + #[cfg(feature = "std")] + #[test] + fn test_invalid_hashmap_decode() { + // Test with invalid data + let invalid_data = vec![0xFF, 0xFF, 0xFF, 0xFF]; // Invalid length prefix + let mut slice = invalid_data.as_slice(); + + let result = HashMap::::mls_decode(&mut slice); + assert!(result.is_err()); + + // Test with truncated data + let mut valid_map = HashMap::new(); + valid_map.insert(1u32, 100u32); + + let mut encoded = Vec::new(); + valid_map.mls_encode(&mut encoded).unwrap(); + encoded.truncate(encoded.len() - 1); // Remove last byte + + let mut slice = encoded.as_slice(); + let result = HashMap::::mls_decode(&mut slice); + assert!(result.is_err()); + } +} diff --git a/mls-rs-codec/src/vec.rs b/mls-rs-codec/src/vec.rs index 3d7f1a8a..05a47a36 100644 --- a/mls-rs-codec/src/vec.rs +++ b/mls-rs-codec/src/vec.rs @@ -53,7 +53,14 @@ where let mut items = Vec::new(); while !data.is_empty() { - items.push(T::mls_decode(data)?); + let before = data.len(); + let value = T::mls_decode(data)?; + + if data.len() == before { + return Err(crate::Error::InvalidContent); + } + + items.push(value); } Ok(items) @@ -97,4 +104,10 @@ mod tests { Err(Error::UnexpectedEOF) ); } + + #[test] + fn vec_zero_length_structure() { + let res = Vec::<[u8; 0]>::mls_decode(&mut &[0x01, 0xff][..]); + assert_matches!(res, Err(crate::Error::InvalidContent)) + } } diff --git a/mls-rs-crypto-rustcrypto/Cargo.toml b/mls-rs-crypto-rustcrypto/Cargo.toml index 4d2c814a..cb2fe324 100644 --- a/mls-rs-crypto-rustcrypto/Cargo.toml +++ b/mls-rs-crypto-rustcrypto/Cargo.toml @@ -43,6 +43,7 @@ rand_core = { version = "0.6", default-features = false, features = ["alloc"] } # AEAD aes-gcm = { version = "0.10", features = ["zeroize"] } chacha20poly1305 = { version = "0.10", default-features = false, features = ["alloc", "getrandom"] } +generic-array = { version = "=0.14.7", default-features = false } aead = { version = "0.5", default-features = false, features = ["alloc", "getrandom"] } # Hash