Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mls-rs-codec/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ pub enum Error {
UnsupportedEnumDiscriminant,
#[cfg_attr(feature = "std", error("Expected UTF-8 string"))]
Utf8,
#[cfg_attr(feature = "std", error("Invalid content"))]
InvalidContent,
#[cfg_attr(feature = "std", error("mls codec error: {0}"))]
Custom(u8),
}
Expand Down
241 changes: 239 additions & 2 deletions mls-rs-codec/src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,13 @@ where
let mut items = HashMap::new();

while !data.is_empty() {
items.insert(K::mls_decode(data)?, V::mls_decode(data)?);
let before = data.len();
let key = K::mls_decode(data)?;
let value = V::mls_decode(data)?;

if data.len() == before || items.insert(key, value).is_some() {
return Err(crate::Error::InvalidContent);
}
}

Ok(items)
Expand Down Expand Up @@ -81,10 +87,241 @@ where
let mut items = BTreeMap::new();

while !data.is_empty() {
items.insert(K::mls_decode(data)?, V::mls_decode(data)?);
let before = data.len();
let key = K::mls_decode(data)?;
let value = V::mls_decode(data)?;

if data.len() == before || items.insert(key, value).is_some() {
return Err(crate::Error::InvalidContent);
}
}

Ok(items)
})
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{MlsDecode, MlsEncode};
use assert_matches::assert_matches;

#[cfg(feature = "std")]
#[test]
fn test_basic_hashmap_roundtrip() {
let mut original = HashMap::new();
original.insert(1u32, 100);
original.insert(2u32, 200);
original.insert(3u32, 300);
original.insert(4u32, 100);

let mut encoded = Vec::new();
original.mls_encode(&mut encoded).unwrap();

let mut slice = encoded.as_slice();
let decoded = HashMap::<u32, u32>::mls_decode(&mut slice).unwrap();

assert_eq!(original, decoded);
assert!(slice.is_empty());
}

#[test]
fn test_basic_btreemap_roundtrip() {
let mut original = BTreeMap::new();
original.insert(1u32, 100);
original.insert(2u32, 200);
original.insert(3u32, 300);
original.insert(4u32, 100);

let mut encoded = Vec::new();
original.mls_encode(&mut encoded).unwrap();

let mut slice = encoded.as_slice();
let decoded = BTreeMap::<u32, u32>::mls_decode(&mut slice).unwrap();

assert_eq!(original, decoded);
assert!(slice.is_empty());
}

#[cfg(feature = "std")]
#[test]
fn test_empty_structure_in_hashmap() {
let mut original: HashMap<u8, [u8; 0]> = HashMap::new();
original.insert(1u8, []);
original.insert(2u8, []);

let mut encoded = Vec::new();
original.mls_encode(&mut encoded).unwrap();

let mut slice = encoded.as_slice();
let decoded = HashMap::<u8, [u8; 0]>::mls_decode(&mut slice).unwrap();
assert_eq!(original, decoded);
assert!(slice.is_empty());
}

#[cfg(feature = "std")]
#[test]
fn hashmap_zero_length_structure() {
let res = HashMap::<[u8; 0], [u8; 0]>::mls_decode(&mut &[0x01, 0xff][..]);
assert_matches!(res, Err(crate::Error::InvalidContent))
}

#[cfg(feature = "std")]
#[test]
fn hashmap_will_not_allow_duplicate_keys() {
let mut encoded = Vec::new();

vec![(1u8, 2u8), (3u8, 4u8), (1u8, 5u8)]
.mls_encode(&mut encoded)
.unwrap();

let res = HashMap::<u8, u8>::mls_decode(&mut &*encoded);
assert_matches!(res, Err(crate::Error::InvalidContent))
}

#[test]
fn btree_map_will_not_allow_duplicate_keys() {
let mut encoded = Vec::new();

vec![(1u8, 2u8), (3u8, 4u8), (1u8, 5u8)]
.mls_encode(&mut encoded)
.unwrap();

let res = BTreeMap::<u8, u8>::mls_decode(&mut &*encoded);
assert_matches!(res, Err(crate::Error::InvalidContent))
}

#[test]
fn btree_map_zero_length_structure() {
let res = BTreeMap::<[u8; 0], [u8; 0]>::mls_decode(&mut &[0x01, 0xff][..]);
assert_matches!(res, Err(crate::Error::InvalidContent))
}

#[cfg(feature = "std")]
#[test]
fn test_hashmap_encoding_order() {
let mut hash = HashMap::new();
hash.insert(3u32, "c".to_string());
hash.insert(1u32, "a".to_string());
hash.insert(2u32, "b".to_string());

let mut btree = BTreeMap::new();
btree.insert(3u32, "c".to_string());
btree.insert(1u32, "a".to_string());
btree.insert(2u32, "b".to_string());

let mut hash_encoded = Vec::new();
hash.mls_encode(&mut hash_encoded).unwrap();

let mut btree_encoded = Vec::new();
btree.mls_encode(&mut btree_encoded).unwrap();

assert_eq!(hash_encoded, btree_encoded);
}

#[cfg(feature = "std")]
#[test]
fn test_empty_hashmap() {
let empty_hash: HashMap<u32, u32> = HashMap::new();
let mut encoded = Vec::new();
empty_hash.mls_encode(&mut encoded).unwrap();

let mut slice = encoded.as_slice();
let decoded = HashMap::<u32, u32>::mls_decode(&mut slice).unwrap();
assert!(decoded.is_empty());
assert!(slice.is_empty());
}

#[test]
fn test_empty_btreemap() {
let empty_btree: BTreeMap<u32, u32> = BTreeMap::new();
let mut encoded = Vec::new();
empty_btree.mls_encode(&mut encoded).unwrap();

let mut slice = encoded.as_slice();
let decoded = BTreeMap::<u32, u32>::mls_decode(&mut slice).unwrap();
assert!(decoded.is_empty());
assert!(slice.is_empty());
}

#[cfg(feature = "std")]
#[test]
fn test_large_hashmap() {
let mut large_map = HashMap::new();
for i in 0..1000u32 {
large_map.insert(i, i * 2);
}

let mut encoded = Vec::new();
large_map.mls_encode(&mut encoded).unwrap();

let mut slice = encoded.as_slice();
let decoded = HashMap::<u32, u32>::mls_decode(&mut slice).unwrap();

assert_eq!(large_map, decoded);
assert!(slice.is_empty());
}

#[test]
fn test_large_btreemap() {
let mut large_map = BTreeMap::new();
for i in 0..1000u32 {
large_map.insert(i, i * 2);
}

let mut encoded = Vec::new();
large_map.mls_encode(&mut encoded).unwrap();

let mut slice = encoded.as_slice();
let decoded = BTreeMap::<u32, u32>::mls_decode(&mut slice).unwrap();

assert_eq!(large_map, decoded);
assert!(slice.is_empty());
}

#[test]
fn test_invalid_btreemap_decode() {
// Test with invalid data
let invalid_data = vec![0xFF, 0xFF, 0xFF, 0xFF]; // Invalid length prefix
let mut slice = invalid_data.as_slice();

let result = BTreeMap::<u32, u32>::mls_decode(&mut slice);
assert!(result.is_err());

// Test with truncated data
let mut valid_map = BTreeMap::new();
valid_map.insert(1u32, 100u32);

let mut encoded = Vec::new();
valid_map.mls_encode(&mut encoded).unwrap();
encoded.truncate(encoded.len() - 1); // Remove last byte

let mut slice = encoded.as_slice();
let result = BTreeMap::<u32, u32>::mls_decode(&mut slice);
assert!(result.is_err());
}

#[cfg(feature = "std")]
#[test]
fn test_invalid_hashmap_decode() {
// Test with invalid data
let invalid_data = vec![0xFF, 0xFF, 0xFF, 0xFF]; // Invalid length prefix
let mut slice = invalid_data.as_slice();

let result = HashMap::<u32, u32>::mls_decode(&mut slice);
assert!(result.is_err());

// Test with truncated data
let mut valid_map = HashMap::new();
valid_map.insert(1u32, 100u32);

let mut encoded = Vec::new();
valid_map.mls_encode(&mut encoded).unwrap();
encoded.truncate(encoded.len() - 1); // Remove last byte

let mut slice = encoded.as_slice();
let result = HashMap::<u32, u32>::mls_decode(&mut slice);
assert!(result.is_err());
}
}
15 changes: 14 additions & 1 deletion mls-rs-codec/src/vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,14 @@ where
let mut items = Vec::new();

while !data.is_empty() {
items.push(T::mls_decode(data)?);
let before = data.len();
let value = T::mls_decode(data)?;

if data.len() == before {
return Err(crate::Error::InvalidContent);
}

items.push(value);
}

Ok(items)
Expand Down Expand Up @@ -97,4 +104,10 @@ mod tests {
Err(Error::UnexpectedEOF)
);
}

#[test]
fn vec_zero_length_structure() {
let res = Vec::<[u8; 0]>::mls_decode(&mut &[0x01, 0xff][..]);
assert_matches!(res, Err(crate::Error::InvalidContent))
}
}
1 change: 1 addition & 0 deletions mls-rs-crypto-rustcrypto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ rand_core = { version = "0.6", default-features = false, features = ["alloc"] }
# AEAD
aes-gcm = { version = "0.10", features = ["zeroize"] }
chacha20poly1305 = { version = "0.10", default-features = false, features = ["alloc", "getrandom"] }
generic-array = { version = "=0.14.7", default-features = false }
aead = { version = "0.5", default-features = false, features = ["alloc", "getrandom"] }

# Hash
Expand Down
Loading