Skip to content

Commit be0a14b

Browse files
NiklasJonssoncuviper
authored andcommitted
Address review feedback
(cherry picked from commit 4e1d8cef470b4d96380ebbb8bae8994db1d79f51)
1 parent c9fd0af commit be0a14b

File tree

4 files changed

+205
-66
lines changed

4 files changed

+205
-66
lines changed

src/lib.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,3 +266,33 @@ impl core::fmt::Display for TryReserveError {
266266
#[cfg(feature = "std")]
267267
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
268268
impl std::error::Error for TryReserveError {}
269+
270+
// NOTE: This is copied from the slice module in the std lib.
271+
/// The error type returned by [`get_disjoint_indices_mut`][`IndexMap::get_disjoint_indices_mut`].
272+
///
273+
/// It indicates one of two possible errors:
274+
/// - An index is out-of-bounds.
275+
/// - The same index appeared multiple times in the array
276+
/// (or different but overlapping indices when ranges are provided).
277+
#[derive(Debug, Clone, PartialEq, Eq)]
278+
pub enum GetDisjointMutError {
279+
/// An index provided was out-of-bounds for the slice.
280+
IndexOutOfBounds,
281+
/// Two indices provided were overlapping.
282+
OverlappingIndices,
283+
}
284+
285+
impl core::fmt::Display for GetDisjointMutError {
286+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
287+
let msg = match self {
288+
GetDisjointMutError::IndexOutOfBounds => "an index is out of bounds",
289+
GetDisjointMutError::OverlappingIndices => "there were overlapping indices",
290+
};
291+
292+
core::fmt::Display::fmt(msg, f)
293+
}
294+
}
295+
296+
#[cfg(feature = "std")]
297+
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
298+
impl std::error::Error for GetDisjointMutError {}

src/map.rs

Lines changed: 18 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ use std::collections::hash_map::RandomState;
4141

4242
use self::core::RingMapCore;
4343
use crate::util::third;
44-
use crate::{Bucket, Entries, Equivalent, HashValue, TryReserveError};
44+
use crate::{Bucket, Entries, Equivalent, GetDisjointMutError, HashValue, TryReserveError};
4545

4646
/// A hash table where the iteration order of the key-value pairs is independent
4747
/// of the hash values of the keys.
@@ -825,35 +825,31 @@ where
825825
}
826826
}
827827

828-
/// Return the values for `N` keys. If any key is missing a value, or there
829-
/// are duplicate keys, `None` is returned.
828+
/// Return the values for `N` keys. If any key is duplicated, this function will panic.
830829
///
831830
/// # Examples
832831
///
833832
/// ```
834833
/// let mut map = ringmap::RingMap::from([(1, 'a'), (3, 'b'), (2, 'c')]);
835-
/// assert_eq!(map.get_disjoint_mut([&2, &1]), Some([&mut 'c', &mut 'a']));
834+
/// assert_eq!(map.get_disjoint_mut([&2, &1]), [Some(&mut 'c'), Some(&mut 'a')]);
836835
/// ```
837-
pub fn get_disjoint_mut<Q, const N: usize>(&mut self, keys: [&Q; N]) -> Option<[&mut V; N]>
836+
#[allow(unsafe_code)]
837+
pub fn get_disjoint_mut<Q, const N: usize>(&mut self, keys: [&Q; N]) -> [Option<&mut V>; N]
838838
where
839839
Q: Hash + Equivalent<K> + ?Sized,
840840
{
841-
let len = self.len();
842841
let indices = keys.map(|key| self.get_index_of(key));
843-
844-
// Handle out-of-bounds indices with panic as this is an internal error in get_index_of.
845-
for idx in indices {
846-
let idx = idx?;
847-
debug_assert!(
848-
idx < len,
849-
"Index is out of range! Got '{}' but length is '{}'",
850-
idx,
851-
len
852-
);
842+
match self.as_mut_slice().get_disjoint_opt_mut(indices) {
843+
Err(GetDisjointMutError::IndexOutOfBounds) => {
844+
unreachable!(
845+
"Internal error: indices should never be OOB as we got them from get_index_of"
846+
);
847+
}
848+
Err(GetDisjointMutError::OverlappingIndices) => {
849+
panic!("duplicate keys found");
850+
}
851+
Ok(key_values) => key_values.map(|kv_opt| kv_opt.map(|kv| kv.1)),
853852
}
854-
let indices = indices.map(Option::unwrap);
855-
let entries = self.get_disjoint_indices_mut(indices)?;
856-
Some(entries.map(|(_key, value)| value))
857853
}
858854

859855
/// Remove the key-value pair equivalent to `key` and return its value.
@@ -1321,38 +1317,17 @@ impl<K, V, S> RingMap<K, V, S> {
13211317
///
13221318
/// Valid indices are *0 <= index < self.len()* and each index needs to be unique.
13231319
///
1324-
/// Computes in **O(1)** time.
1325-
///
13261320
/// # Examples
13271321
///
13281322
/// ```
13291323
/// let mut map = ringmap::RingMap::from([(1, 'a'), (3, 'b'), (2, 'c')]);
1330-
/// assert_eq!(map.get_disjoint_indices_mut([2, 0]), Some([(&2, &mut 'c'), (&1, &mut 'a')]));
1324+
/// assert_eq!(map.get_disjoint_indices_mut([2, 0]), Ok([(&2, &mut 'c'), (&1, &mut 'a')]));
13311325
/// ```
13321326
pub fn get_disjoint_indices_mut<const N: usize>(
13331327
&mut self,
13341328
indices: [usize; N],
1335-
) -> Option<[(&K, &mut V); N]> {
1336-
// SAFETY: Can't allow duplicate indices as we would return several mutable refs to the same data.
1337-
let len = self.len();
1338-
for i in 0..N {
1339-
let idx = indices[i];
1340-
if idx >= len || indices[i + 1..N].contains(&idx) {
1341-
return None;
1342-
}
1343-
}
1344-
1345-
let entries_ptr = self.as_entries_mut().as_mut_ptr();
1346-
let out = indices.map(|i| {
1347-
// SAFETY: The base pointer is valid as it comes from a slice and the deref is always
1348-
// in-bounds as we've already checked the indices above.
1349-
#[allow(unsafe_code)]
1350-
unsafe {
1351-
(*(entries_ptr.add(i))).ref_mut()
1352-
}
1353-
});
1354-
1355-
Some(out)
1329+
) -> Result<[(&K, &mut V); N], GetDisjointMutError> {
1330+
self.as_mut_slice().get_disjoint_mut(indices)
13561331
}
13571332

13581333
/// Get the first key-value pair

src/map/slice.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use super::{Bucket, IntoIter, IntoKeys, IntoValues, Iter, IterMut, Keys, Values, ValuesMut};
22
use crate::util::{slice_eq, try_simplify_range};
3+
use crate::GetDisjointMutError;
34

45
use alloc::boxed::Box;
56
use alloc::collections::VecDeque;
@@ -264,6 +265,52 @@ impl<K, V> Slice<K, V> {
264265
self.entries
265266
.partition_point(move |a| pred(&a.key, &a.value))
266267
}
268+
269+
/// Get an array of `N` key-value pairs by `N` indices
270+
///
271+
/// Valid indices are *0 <= index < self.len()* and each index needs to be unique.
272+
pub fn get_disjoint_mut<const N: usize>(
273+
&mut self,
274+
indices: [usize; N],
275+
) -> Result<[(&K, &mut V); N], GetDisjointMutError> {
276+
let indices = indices.map(Some);
277+
let key_values = self.get_disjoint_opt_mut(indices)?;
278+
Ok(key_values.map(Option::unwrap))
279+
}
280+
281+
#[allow(unsafe_code)]
282+
pub(crate) fn get_disjoint_opt_mut<const N: usize>(
283+
&mut self,
284+
indices: [Option<usize>; N],
285+
) -> Result<[Option<(&K, &mut V)>; N], GetDisjointMutError> {
286+
// SAFETY: Can't allow duplicate indices as we would return several mutable refs to the same data.
287+
let len = self.len();
288+
for i in 0..N {
289+
let Some(idx) = indices[i] else {
290+
continue;
291+
};
292+
if idx >= len {
293+
return Err(GetDisjointMutError::IndexOutOfBounds);
294+
} else if indices[i + 1..N].contains(&Some(idx)) {
295+
return Err(GetDisjointMutError::OverlappingIndices);
296+
}
297+
}
298+
299+
let entries_ptr = self.entries.as_mut_ptr();
300+
let out = indices.map(|idx_opt| {
301+
match idx_opt {
302+
Some(idx) => {
303+
// SAFETY: The base pointer is valid as it comes from a slice and the reference is always
304+
// in-bounds & unique as we've already checked the indices above.
305+
let kv = unsafe { (*(entries_ptr.add(idx))).ref_mut() };
306+
Some(kv)
307+
}
308+
None => None,
309+
}
310+
});
311+
312+
Ok(out)
313+
}
267314
}
268315

269316
impl<'a, K, V> IntoIterator for &'a Slice<K, V> {

src/map/tests.rs

Lines changed: 110 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -927,28 +927,31 @@ move_index_oob!(test_move_index_out_of_bounds_max_0, usize::MAX, 0);
927927
#[test]
928928
fn disjoint_mut_empty_map() {
929929
let mut map: RingMap<u32, u32> = RingMap::default();
930-
assert!(map.get_disjoint_mut([&0, &1, &2, &3]).is_none());
930+
assert_eq!(
931+
map.get_disjoint_mut([&0, &1, &2, &3]),
932+
[None, None, None, None]
933+
);
931934
}
932935

933936
#[test]
934937
fn disjoint_mut_empty_param() {
935938
let mut map: RingMap<u32, u32> = RingMap::default();
936939
map.insert(1, 10);
937-
assert!(map.get_disjoint_mut([] as [&u32; 0]).is_some());
940+
assert_eq!(map.get_disjoint_mut([] as [&u32; 0]), []);
938941
}
939942

940943
#[test]
941944
fn disjoint_mut_single_fail() {
942945
let mut map: RingMap<u32, u32> = RingMap::default();
943946
map.insert(1, 10);
944-
assert!(map.get_disjoint_mut([&0]).is_none());
947+
assert_eq!(map.get_disjoint_mut([&0]), [None]);
945948
}
946949

947950
#[test]
948951
fn disjoint_mut_single_success() {
949952
let mut map: RingMap<u32, u32> = RingMap::default();
950953
map.insert(1, 10);
951-
assert_eq!(map.get_disjoint_mut([&1]), Some([&mut 10]));
954+
assert_eq!(map.get_disjoint_mut([&1]), [Some(&mut 10)]);
952955
}
953956

954957
#[test]
@@ -958,11 +961,22 @@ fn disjoint_mut_multi_success() {
958961
map.insert(2, 200);
959962
map.insert(3, 300);
960963
map.insert(4, 400);
961-
assert_eq!(map.get_disjoint_mut([&1, &2]), Some([&mut 100, &mut 200]));
962-
assert_eq!(map.get_disjoint_mut([&1, &3]), Some([&mut 100, &mut 300]));
964+
assert_eq!(
965+
map.get_disjoint_mut([&1, &2]),
966+
[Some(&mut 100), Some(&mut 200)]
967+
);
968+
assert_eq!(
969+
map.get_disjoint_mut([&1, &3]),
970+
[Some(&mut 100), Some(&mut 300)]
971+
);
963972
assert_eq!(
964973
map.get_disjoint_mut([&3, &1, &4, &2]),
965-
Some([&mut 300, &mut 100, &mut 400, &mut 200])
974+
[
975+
Some(&mut 300),
976+
Some(&mut 100),
977+
Some(&mut 400),
978+
Some(&mut 200)
979+
]
966980
);
967981
}
968982

@@ -973,44 +987,117 @@ fn disjoint_mut_multi_success_unsized_key() {
973987
map.insert("2", 200);
974988
map.insert("3", 300);
975989
map.insert("4", 400);
976-
assert_eq!(map.get_disjoint_mut(["1", "2"]), Some([&mut 100, &mut 200]));
977-
assert_eq!(map.get_disjoint_mut(["1", "3"]), Some([&mut 100, &mut 300]));
990+
991+
assert_eq!(
992+
map.get_disjoint_mut(["1", "2"]),
993+
[Some(&mut 100), Some(&mut 200)]
994+
);
995+
assert_eq!(
996+
map.get_disjoint_mut(["1", "3"]),
997+
[Some(&mut 100), Some(&mut 300)]
998+
);
978999
assert_eq!(
9791000
map.get_disjoint_mut(["3", "1", "4", "2"]),
980-
Some([&mut 300, &mut 100, &mut 400, &mut 200])
1001+
[
1002+
Some(&mut 300),
1003+
Some(&mut 100),
1004+
Some(&mut 400),
1005+
Some(&mut 200)
1006+
]
1007+
);
1008+
}
1009+
1010+
#[test]
1011+
fn disjoint_mut_multi_success_borrow_key() {
1012+
let mut map: RingMap<String, u32> = RingMap::default();
1013+
map.insert("1".into(), 100);
1014+
map.insert("2".into(), 200);
1015+
map.insert("3".into(), 300);
1016+
map.insert("4".into(), 400);
1017+
1018+
assert_eq!(
1019+
map.get_disjoint_mut(["1", "2"]),
1020+
[Some(&mut 100), Some(&mut 200)]
1021+
);
1022+
assert_eq!(
1023+
map.get_disjoint_mut(["1", "3"]),
1024+
[Some(&mut 100), Some(&mut 300)]
1025+
);
1026+
assert_eq!(
1027+
map.get_disjoint_mut(["3", "1", "4", "2"]),
1028+
[
1029+
Some(&mut 300),
1030+
Some(&mut 100),
1031+
Some(&mut 400),
1032+
Some(&mut 200)
1033+
]
9811034
);
9821035
}
9831036

9841037
#[test]
9851038
fn disjoint_mut_multi_fail_missing() {
1039+
let mut map: RingMap<u32, u32> = RingMap::default();
1040+
map.insert(1, 100);
1041+
map.insert(2, 200);
1042+
map.insert(3, 300);
1043+
map.insert(4, 400);
1044+
1045+
assert_eq!(map.get_disjoint_mut([&1, &5]), [Some(&mut 100), None]);
1046+
assert_eq!(map.get_disjoint_mut([&5, &6]), [None, None]);
1047+
assert_eq!(
1048+
map.get_disjoint_mut([&1, &5, &4]),
1049+
[Some(&mut 100), None, Some(&mut 400)]
1050+
);
1051+
}
1052+
1053+
#[test]
1054+
#[should_panic]
1055+
fn disjoint_mut_multi_fail_duplicate_panic() {
1056+
let mut map: RingMap<u32, u32> = RingMap::default();
1057+
map.insert(1, 100);
1058+
map.get_disjoint_mut([&1, &2, &1]);
1059+
}
1060+
1061+
#[test]
1062+
fn disjoint_indices_mut_fail_oob() {
1063+
let mut map: RingMap<u32, u32> = RingMap::default();
1064+
map.insert(1, 10);
1065+
map.insert(321, 20);
1066+
assert_eq!(
1067+
map.get_disjoint_indices_mut([1, 3]),
1068+
Err(crate::GetDisjointMutError::IndexOutOfBounds)
1069+
);
1070+
}
1071+
1072+
#[test]
1073+
fn disjoint_indices_mut_empty() {
9861074
let mut map: RingMap<u32, u32> = RingMap::default();
9871075
map.insert(1, 10);
988-
map.insert(1123, 100);
9891076
map.insert(321, 20);
990-
map.insert(1337, 30);
991-
assert_eq!(map.get_disjoint_mut([&121, &1123]), None);
992-
assert_eq!(map.get_disjoint_mut([&1, &1337, &56]), None);
993-
assert_eq!(map.get_disjoint_mut([&1337, &123, &321, &1, &1123]), None);
1077+
assert_eq!(map.get_disjoint_indices_mut([]), Ok([]));
9941078
}
9951079

9961080
#[test]
997-
fn disjoint_mut_multi_fail_duplicate() {
1081+
fn disjoint_indices_mut_success() {
9981082
let mut map: RingMap<u32, u32> = RingMap::default();
9991083
map.insert(1, 10);
1000-
map.insert(1123, 100);
10011084
map.insert(321, 20);
1002-
map.insert(1337, 30);
1003-
assert_eq!(map.get_disjoint_mut([&1, &1]), None);
1085+
assert_eq!(map.get_disjoint_indices_mut([0]), Ok([(&1, &mut 10)]));
1086+
1087+
assert_eq!(map.get_disjoint_indices_mut([1]), Ok([(&321, &mut 20)]));
10041088
assert_eq!(
1005-
map.get_disjoint_mut([&1337, &123, &321, &1337, &1, &1123]),
1006-
None
1089+
map.get_disjoint_indices_mut([0, 1]),
1090+
Ok([(&1, &mut 10), (&321, &mut 20)])
10071091
);
10081092
}
10091093

10101094
#[test]
1011-
fn many_index_mut_fail_oob() {
1095+
fn disjoint_indices_mut_fail_duplicate() {
10121096
let mut map: RingMap<u32, u32> = RingMap::default();
10131097
map.insert(1, 10);
10141098
map.insert(321, 20);
1015-
assert_eq!(map.get_disjoint_indices_mut([1, 3]), None);
1099+
assert_eq!(
1100+
map.get_disjoint_indices_mut([1, 2, 1]),
1101+
Err(crate::GetDisjointMutError::OverlappingIndices)
1102+
);
10161103
}

0 commit comments

Comments
 (0)