Skip to content

Commit

Permalink
Implement get_many_mut
Browse files Browse the repository at this point in the history
  • Loading branch information
NiklasJonsson committed Jul 14, 2022
1 parent 9a5b1fa commit d3f1968
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 0 deletions.
46 changes: 46 additions & 0 deletions src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,52 @@ where
}
}

pub fn get_many_mut<'a, 'b, Q: ?Sized, const N: usize>(
&'a mut self,
keys: [&'b Q; N],
) -> Option<[&'a mut V; N]>
where
Q: Hash + Equivalent<K>,
{
let indices = keys.map(|key| self.get_index_of(key));
if indices.iter().any(Option::is_none) {
return None;
}
let indices = indices.map(Option::unwrap);

// SAFETY: Can't allow duplicate indices as we would return several mutable refs to the same data
for i in 0..N {
let idx = indices[i];
if indices[i + 1..N].contains(&idx) {
return None;
}
}

// Replace with MaybeUninit::uninit_array when that is stable
// SAFETY: Creating MaybeUninit from uninit is always safe
#[allow(unsafe_code)]
let mut out: [std::mem::MaybeUninit<&'a mut V>; N] =
unsafe { std::mem::MaybeUninit::uninit().assume_init() };

let entries = self.as_entries_mut();
for (elem, idx) in out.iter_mut().zip(indices) {
let v: &mut V = &mut entries[idx].value;
// SAFETY: As we know that each index is unique, it is OK to discard the mutable
// borrow lifetime of v, we will never mutably borrow an element twice.
// The pointer is valid and aligned as we get it from MaybeUninit.
#[allow(unsafe_code)]
unsafe { std::ptr::write(elem.as_mut_ptr(), &mut *(v as *mut V)) };
}

// Can't transmute a const-sized array:
// https://github.com/rust-lang/rust/issues/61956
// This is the workaround.
// SAFETY: This is fine as the references all are from unique entries that we own and all of
// them have been properly initialized by the above loop.
#[allow(unsafe_code)]
Some(unsafe { std::mem::transmute_copy::<_, [&'a mut V; N]>(&out) })
}

/// Remove the key-value pair equivalent to `key` and return
/// its value.
///
Expand Down
61 changes: 61 additions & 0 deletions src/map/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,3 +418,64 @@ fn from_array() {

assert_eq!(map, expected)
}

#[test]
fn many_mut_empty() {
let mut map: IndexMap<u32, u32> = IndexMap::default();
assert!(map.get_many_mut([&0, &1, &2, &3]).is_none());
}

#[test]
fn many_mut_single_fail() {
let mut map: IndexMap<u32, u32> = IndexMap::default();
map.insert(1, 10);
assert!(map.get_many_mut([&0]).is_none());
}

#[test]
fn many_mut_single_success() {
let mut map: IndexMap<u32, u32> = IndexMap::default();
map.insert(1, 10);
assert_eq!(map.get_many_mut([&1]), Some([&mut 10]));
}

#[test]
fn many_mut_multi_success() {
let mut map: IndexMap<u32, u32> = IndexMap::default();
map.insert(1, 10);
map.insert(1123, 100);
map.insert(321, 20);
map.insert(1337, 30);
assert_eq!(map.get_many_mut([&1, &1123]), Some([&mut 10, &mut 100]));
assert_eq!(map.get_many_mut([&1, &1337]), Some([&mut 10, &mut 30]));
assert_eq!(
map.get_many_mut([&1337, &321, &1, &1123]),
Some([&mut 30, &mut 20, &mut 10, &mut 100])
);
}

#[test]
fn many_mut_multi_fail_missing() {
let mut map: IndexMap<u32, u32> = IndexMap::default();
map.insert(1, 10);
map.insert(1123, 100);
map.insert(321, 20);
map.insert(1337, 30);
assert_eq!(map.get_many_mut([&121, &1123]), None);
assert_eq!(map.get_many_mut([&1, &1337, &56]), None);
assert_eq!(map.get_many_mut([&1337, &123, &321, &1, &1123]), None);
}

#[test]
fn many_mut_multi_fail_duplicate() {
let mut map: IndexMap<u32, u32> = IndexMap::default();
map.insert(1, 10);
map.insert(1123, 100);
map.insert(321, 20);
map.insert(1337, 30);
assert_eq!(map.get_many_mut([&1, &1]), None);
assert_eq!(
map.get_many_mut([&1337, &123, &321, &1337, &1, &1123]),
None
);
}

0 comments on commit d3f1968

Please sign in to comment.