Skip to content

Commit dfedc10

Browse files
committed
fix: improve edge case detection in HashMap and BTreeMap codec decoding
1 parent 8b8b521 commit dfedc10

File tree

2 files changed

+241
-2
lines changed

2 files changed

+241
-2
lines changed

mls-rs-codec/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ pub enum Error {
5050
UnsupportedEnumDiscriminant,
5151
#[cfg_attr(feature = "std", error("Expected UTF-8 string"))]
5252
Utf8,
53+
#[cfg_attr(feature = "std", error("Invalid map content"))]
54+
InvalidMapContent,
5355
#[cfg_attr(feature = "std", error("mls codec error: {0}"))]
5456
Custom(u8),
5557
}

mls-rs-codec/src/map.rs

Lines changed: 239 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,13 @@ where
4343
let mut items = HashMap::new();
4444

4545
while !data.is_empty() {
46-
items.insert(K::mls_decode(data)?, V::mls_decode(data)?);
46+
let before = data.len();
47+
let key = K::mls_decode(data)?;
48+
let value = V::mls_decode(data)?;
49+
50+
if data.len() == before || items.insert(key, value).is_some() {
51+
return Err(crate::Error::InvalidMapContent);
52+
}
4753
}
4854

4955
Ok(items)
@@ -81,10 +87,241 @@ where
8187
let mut items = BTreeMap::new();
8288

8389
while !data.is_empty() {
84-
items.insert(K::mls_decode(data)?, V::mls_decode(data)?);
90+
let before = data.len();
91+
let key = K::mls_decode(data)?;
92+
let value = V::mls_decode(data)?;
93+
94+
if data.len() == before || items.insert(key, value).is_some() {
95+
return Err(crate::Error::InvalidMapContent);
96+
}
8597
}
8698

8799
Ok(items)
88100
})
89101
}
90102
}
103+
104+
#[cfg(test)]
105+
mod tests {
106+
use super::*;
107+
use crate::{MlsDecode, MlsEncode};
108+
use assert_matches::assert_matches;
109+
110+
#[cfg(feature = "std")]
111+
#[test]
112+
fn test_basic_hashmap_roundtrip() {
113+
let mut original = HashMap::new();
114+
original.insert(1u32, 100);
115+
original.insert(2u32, 200);
116+
original.insert(3u32, 300);
117+
original.insert(4u32, 100);
118+
119+
let mut encoded = Vec::new();
120+
original.mls_encode(&mut encoded).unwrap();
121+
122+
let mut slice = encoded.as_slice();
123+
let decoded = HashMap::<u32, u32>::mls_decode(&mut slice).unwrap();
124+
125+
assert_eq!(original, decoded);
126+
assert!(slice.is_empty());
127+
}
128+
129+
#[test]
130+
fn test_basic_btreemap_roundtrip() {
131+
let mut original = BTreeMap::new();
132+
original.insert(1u32, 100);
133+
original.insert(2u32, 200);
134+
original.insert(3u32, 300);
135+
original.insert(4u32, 100);
136+
137+
let mut encoded = Vec::new();
138+
original.mls_encode(&mut encoded).unwrap();
139+
140+
let mut slice = encoded.as_slice();
141+
let decoded = BTreeMap::<u32, u32>::mls_decode(&mut slice).unwrap();
142+
143+
assert_eq!(original, decoded);
144+
assert!(slice.is_empty());
145+
}
146+
147+
#[cfg(feature = "std")]
148+
#[test]
149+
fn test_empty_structure_in_hashmap() {
150+
let mut original: HashMap<u8, [u8; 0]> = HashMap::new();
151+
original.insert(1u8, []);
152+
original.insert(2u8, []);
153+
154+
let mut encoded = Vec::new();
155+
original.mls_encode(&mut encoded).unwrap();
156+
157+
let mut slice = encoded.as_slice();
158+
let decoded = HashMap::<u8, [u8; 0]>::mls_decode(&mut slice).unwrap();
159+
assert_eq!(original, decoded);
160+
assert!(slice.is_empty());
161+
}
162+
163+
#[cfg(feature = "std")]
164+
#[test]
165+
fn hashmap_zero_length_structure() {
166+
let res = HashMap::<[u8; 0], [u8; 0]>::mls_decode(&mut &[0x01, 0xff][..]);
167+
assert_matches!(res, Err(crate::Error::InvalidMapContent))
168+
}
169+
170+
#[cfg(feature = "std")]
171+
#[test]
172+
fn hashmap_will_not_allow_duplicate_keys() {
173+
let mut encoded = Vec::new();
174+
175+
vec![(1u8, 2u8), (3u8, 4u8), (1u8, 5u8)]
176+
.mls_encode(&mut encoded)
177+
.unwrap();
178+
179+
let res = HashMap::<u8, u8>::mls_decode(&mut &*encoded);
180+
assert_matches!(res, Err(crate::Error::InvalidMapContent))
181+
}
182+
183+
#[test]
184+
fn btree_map_will_not_allow_duplicate_keys() {
185+
let mut encoded = Vec::new();
186+
187+
vec![(1u8, 2u8), (3u8, 4u8), (1u8, 5u8)]
188+
.mls_encode(&mut encoded)
189+
.unwrap();
190+
191+
let res = BTreeMap::<u8, u8>::mls_decode(&mut &*encoded);
192+
assert_matches!(res, Err(crate::Error::InvalidMapContent))
193+
}
194+
195+
#[test]
196+
fn btree_map_zero_length_structure() {
197+
let res = BTreeMap::<[u8; 0], [u8; 0]>::mls_decode(&mut &[0x01, 0xff][..]);
198+
assert_matches!(res, Err(crate::Error::InvalidMapContent))
199+
}
200+
201+
#[cfg(feature = "std")]
202+
#[test]
203+
fn test_hashmap_encoding_order() {
204+
let mut hash = HashMap::new();
205+
hash.insert(3u32, "c".to_string());
206+
hash.insert(1u32, "a".to_string());
207+
hash.insert(2u32, "b".to_string());
208+
209+
let mut btree = BTreeMap::new();
210+
btree.insert(3u32, "c".to_string());
211+
btree.insert(1u32, "a".to_string());
212+
btree.insert(2u32, "b".to_string());
213+
214+
let mut hash_encoded = Vec::new();
215+
hash.mls_encode(&mut hash_encoded).unwrap();
216+
217+
let mut btree_encoded = Vec::new();
218+
btree.mls_encode(&mut btree_encoded).unwrap();
219+
220+
assert_eq!(hash_encoded, btree_encoded);
221+
}
222+
223+
#[cfg(feature = "std")]
224+
#[test]
225+
fn test_empty_hashmap() {
226+
let empty_hash: HashMap<u32, u32> = HashMap::new();
227+
let mut encoded = Vec::new();
228+
empty_hash.mls_encode(&mut encoded).unwrap();
229+
230+
let mut slice = encoded.as_slice();
231+
let decoded = HashMap::<u32, u32>::mls_decode(&mut slice).unwrap();
232+
assert!(decoded.is_empty());
233+
assert!(slice.is_empty());
234+
}
235+
236+
#[test]
237+
fn test_empty_btreemap() {
238+
let empty_btree: BTreeMap<u32, u32> = BTreeMap::new();
239+
let mut encoded = Vec::new();
240+
empty_btree.mls_encode(&mut encoded).unwrap();
241+
242+
let mut slice = encoded.as_slice();
243+
let decoded = BTreeMap::<u32, u32>::mls_decode(&mut slice).unwrap();
244+
assert!(decoded.is_empty());
245+
assert!(slice.is_empty());
246+
}
247+
248+
#[cfg(feature = "std")]
249+
#[test]
250+
fn test_large_hashmap() {
251+
let mut large_map = HashMap::new();
252+
for i in 0..1000u32 {
253+
large_map.insert(i, i * 2);
254+
}
255+
256+
let mut encoded = Vec::new();
257+
large_map.mls_encode(&mut encoded).unwrap();
258+
259+
let mut slice = encoded.as_slice();
260+
let decoded = HashMap::<u32, u32>::mls_decode(&mut slice).unwrap();
261+
262+
assert_eq!(large_map, decoded);
263+
assert!(slice.is_empty());
264+
}
265+
266+
#[test]
267+
fn test_large_btreemap() {
268+
let mut large_map = BTreeMap::new();
269+
for i in 0..1000u32 {
270+
large_map.insert(i, i * 2);
271+
}
272+
273+
let mut encoded = Vec::new();
274+
large_map.mls_encode(&mut encoded).unwrap();
275+
276+
let mut slice = encoded.as_slice();
277+
let decoded = BTreeMap::<u32, u32>::mls_decode(&mut slice).unwrap();
278+
279+
assert_eq!(large_map, decoded);
280+
assert!(slice.is_empty());
281+
}
282+
283+
#[test]
284+
fn test_invalid_btreemap_decode() {
285+
// Test with invalid data
286+
let invalid_data = vec![0xFF, 0xFF, 0xFF, 0xFF]; // Invalid length prefix
287+
let mut slice = invalid_data.as_slice();
288+
289+
let result = BTreeMap::<u32, u32>::mls_decode(&mut slice);
290+
assert!(result.is_err());
291+
292+
// Test with truncated data
293+
let mut valid_map = BTreeMap::new();
294+
valid_map.insert(1u32, 100u32);
295+
296+
let mut encoded = Vec::new();
297+
valid_map.mls_encode(&mut encoded).unwrap();
298+
encoded.truncate(encoded.len() - 1); // Remove last byte
299+
300+
let mut slice = encoded.as_slice();
301+
let result = BTreeMap::<u32, u32>::mls_decode(&mut slice);
302+
assert!(result.is_err());
303+
}
304+
305+
#[cfg(feature = "std")]
306+
#[test]
307+
fn test_invalid_hashmap_decode() {
308+
// Test with invalid data
309+
let invalid_data = vec![0xFF, 0xFF, 0xFF, 0xFF]; // Invalid length prefix
310+
let mut slice = invalid_data.as_slice();
311+
312+
let result = HashMap::<u32, u32>::mls_decode(&mut slice);
313+
assert!(result.is_err());
314+
315+
// Test with truncated data
316+
let mut valid_map = HashMap::new();
317+
valid_map.insert(1u32, 100u32);
318+
319+
let mut encoded = Vec::new();
320+
valid_map.mls_encode(&mut encoded).unwrap();
321+
encoded.truncate(encoded.len() - 1); // Remove last byte
322+
323+
let mut slice = encoded.as_slice();
324+
let result = HashMap::<u32, u32>::mls_decode(&mut slice);
325+
assert!(result.is_err());
326+
}
327+
}

0 commit comments

Comments
 (0)