Skip to content

Commit

Permalink
Add range support to extract_if
Browse files Browse the repository at this point in the history
  • Loading branch information
cuviper committed Dec 8, 2024
1 parent 5bd12c4 commit ee03a3e
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 49 deletions.
15 changes: 11 additions & 4 deletions src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ impl<K, V, S> IndexMap<K, V, S> {
Drain::new(self.core.drain(range))
}

/// Creates an iterator which uses a closure to determine if an element should be removed.
/// Creates an iterator which uses a closure to determine if an element should be removed,
/// for all elements in the given range.
///
/// If the closure returns true, the element is removed from the map and yielded.
/// If the closure returns false, or panics, the element remains in the map and will not be
Expand All @@ -316,6 +317,11 @@ impl<K, V, S> IndexMap<K, V, S> {
/// Note that `extract_if` lets you mutate every value in the filter closure, regardless of
/// whether you choose to keep or remove it.
///
/// The range may be any type that implements [`RangeBounds<usize>`],
/// including all of the `std::ops::Range*` types, or even a tuple pair of
/// `Bound` start and end values. To check the entire map, use `RangeFull`
/// like `map.extract_if(.., predicate)`.
///
/// If the returned `ExtractIf` is not exhausted, e.g. because it is dropped without iterating
/// or the iteration short-circuits, then the remaining elements will be retained.
/// Use [`retain`] with a negated predicate if you do not need the returned iterator.
Expand All @@ -330,19 +336,20 @@ impl<K, V, S> IndexMap<K, V, S> {
/// use indexmap::IndexMap;
///
/// let mut map: IndexMap<i32, i32> = (0..8).map(|x| (x, x)).collect();
/// let extracted: IndexMap<i32, i32> = map.extract_if(|k, _v| k % 2 == 0).collect();
/// let extracted: IndexMap<i32, i32> = map.extract_if(.., |k, _v| k % 2 == 0).collect();
///
/// let evens = extracted.keys().copied().collect::<Vec<_>>();
/// let odds = map.keys().copied().collect::<Vec<_>>();
///
/// assert_eq!(evens, vec![0, 2, 4, 6]);
/// assert_eq!(odds, vec![1, 3, 5, 7]);
/// ```
pub fn extract_if<F>(&mut self, pred: F) -> ExtractIf<'_, K, V, F>
pub fn extract_if<R, F>(&mut self, range: R, pred: F) -> ExtractIf<'_, K, V, F>
where
R: RangeBounds<usize>,
F: FnMut(&K, &mut V) -> bool,
{
ExtractIf::new(&mut self.core, pred)
ExtractIf::new(&mut self.core, range, pred)
}

/// Splits the collection into two at the given index.
Expand Down
33 changes: 22 additions & 11 deletions src/map/core/extract.rs
Original file line number Diff line number Diff line change
@@ -1,34 +1,46 @@
#![allow(unsafe_code)]

use super::{Bucket, IndexMapCore};
use crate::util::simplify_range;

use core::ops::RangeBounds;

impl<K, V> IndexMapCore<K, V> {
pub(crate) fn extract(&mut self) -> ExtractCore<'_, K, V> {
pub(crate) fn extract<R>(&mut self, range: R) -> ExtractCore<'_, K, V>
where
R: RangeBounds<usize>,
{
let range = simplify_range(range, self.entries.len());

// SAFETY: We must have consistent lengths to start, so that's a hard assertion.
// Then the worst `set_len(0)` can do is leak items if `ExtractCore` doesn't drop.
// Then the worst `set_len` can do is leak items if `ExtractCore` doesn't drop.
assert_eq!(self.entries.len(), self.indices.len());
unsafe {
self.entries.set_len(0);
self.entries.set_len(range.start);
}
ExtractCore {
map: self,
current: 0,
new_len: 0,
new_len: range.start,
current: range.start,
end: range.end,
}
}
}

pub(crate) struct ExtractCore<'a, K, V> {
map: &'a mut IndexMapCore<K, V>,
current: usize,
new_len: usize,
current: usize,
end: usize,
}

impl<K, V> Drop for ExtractCore<'_, K, V> {
fn drop(&mut self) {
let old_len = self.map.indices.len();
let mut new_len = self.new_len;

debug_assert!(new_len <= self.current);
debug_assert!(self.current <= self.end);
debug_assert!(self.current <= old_len);
debug_assert!(old_len <= self.map.entries.capacity());

Expand Down Expand Up @@ -62,13 +74,12 @@ impl<K, V> ExtractCore<'_, K, V> {
where
F: FnMut(&mut Bucket<K, V>) -> bool,
{
let old_len = self.map.indices.len();
debug_assert!(old_len <= self.map.entries.capacity());
debug_assert!(self.end <= self.map.entries.capacity());

let base = self.map.entries.as_mut_ptr();
while self.current < old_len {
while self.current < self.end {
// SAFETY: We're maintaining both indices within bounds of the original entries, so
// 0..new_len and current..old_len are always valid items for our Drop to keep.
// 0..new_len and current..indices.len() are always valid items for our Drop to keep.
unsafe {
let item = base.add(self.current);
if pred(&mut *item) {
Expand All @@ -91,6 +102,6 @@ impl<K, V> ExtractCore<'_, K, V> {
}

pub(crate) fn remaining(&self) -> usize {
self.map.indices.len() - self.current
self.end - self.current
}
}
23 changes: 9 additions & 14 deletions src/map/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -777,21 +777,19 @@ where
///
/// This `struct` is created by [`IndexMap::extract_if()`].
/// See its documentation for more.
pub struct ExtractIf<'a, K, V, F>
where
F: FnMut(&K, &mut V) -> bool,
{
pub struct ExtractIf<'a, K, V, F> {
inner: ExtractCore<'a, K, V>,
pred: F,
}

impl<K, V, F> ExtractIf<'_, K, V, F>
where
F: FnMut(&K, &mut V) -> bool,
{
pub(super) fn new(core: &mut IndexMapCore<K, V>, pred: F) -> ExtractIf<'_, K, V, F> {
impl<K, V, F> ExtractIf<'_, K, V, F> {
pub(super) fn new<R>(core: &mut IndexMapCore<K, V>, range: R, pred: F) -> ExtractIf<'_, K, V, F>
where
R: RangeBounds<usize>,
F: FnMut(&K, &mut V) -> bool,
{
ExtractIf {
inner: core.extract(),
inner: core.extract(range),
pred,
}
}
Expand All @@ -817,10 +815,7 @@ where
}
}

impl<'a, K, V, F> fmt::Debug for ExtractIf<'a, K, V, F>
where
F: FnMut(&K, &mut V) -> bool,
{
impl<'a, K, V, F> fmt::Debug for ExtractIf<'a, K, V, F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ExtractIf").finish_non_exhaustive()
}
Expand Down
15 changes: 11 additions & 4 deletions src/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,12 +257,18 @@ impl<T, S> IndexSet<T, S> {
Drain::new(self.map.core.drain(range))
}

/// Creates an iterator which uses a closure to determine if a value should be removed.
/// Creates an iterator which uses a closure to determine if a value should be removed,
/// for all values in the given range.
///
/// If the closure returns true, then the value is removed and yielded.
/// If the closure returns false, the value will remain in the list and will not be yielded
/// by the iterator.
///
/// The range may be any type that implements [`RangeBounds<usize>`],
/// including all of the `std::ops::Range*` types, or even a tuple pair of
/// `Bound` start and end values. To check the entire set, use `RangeFull`
/// like `set.extract_if(.., predicate)`.
///
/// If the returned `ExtractIf` is not exhausted, e.g. because it is dropped without iterating
/// or the iteration short-circuits, then the remaining elements will be retained.
/// Use [`retain`] with a negated predicate if you do not need the returned iterator.
Expand All @@ -277,19 +283,20 @@ impl<T, S> IndexSet<T, S> {
/// use indexmap::IndexSet;
///
/// let mut set: IndexSet<i32> = (0..8).collect();
/// let extracted: IndexSet<i32> = set.extract_if(|v| v % 2 == 0).collect();
/// let extracted: IndexSet<i32> = set.extract_if(.., |v| v % 2 == 0).collect();
///
/// let evens = extracted.into_iter().collect::<Vec<_>>();
/// let odds = set.into_iter().collect::<Vec<_>>();
///
/// assert_eq!(evens, vec![0, 2, 4, 6]);
/// assert_eq!(odds, vec![1, 3, 5, 7]);
/// ```
pub fn extract_if<F>(&mut self, pred: F) -> ExtractIf<'_, T, F>
pub fn extract_if<R, F>(&mut self, range: R, pred: F) -> ExtractIf<'_, T, F>
where
R: RangeBounds<usize>,
F: FnMut(&T) -> bool,
{
ExtractIf::new(&mut self.map.core, pred)
ExtractIf::new(&mut self.map.core, range, pred)
}

/// Splits the collection into two at the given index.
Expand Down
23 changes: 9 additions & 14 deletions src/set/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -632,21 +632,19 @@ impl<I: fmt::Debug> fmt::Debug for UnitValue<I> {
///
/// This `struct` is created by [`IndexSet::extract_if()`].
/// See its documentation for more.
pub struct ExtractIf<'a, T, F>
where
F: FnMut(&T) -> bool,
{
pub struct ExtractIf<'a, T, F> {
inner: ExtractCore<'a, T, ()>,
pred: F,
}

impl<T, F> ExtractIf<'_, T, F>
where
F: FnMut(&T) -> bool,
{
pub(super) fn new(core: &mut IndexMapCore<T, ()>, pred: F) -> ExtractIf<'_, T, F> {
impl<T, F> ExtractIf<'_, T, F> {
pub(super) fn new<R>(core: &mut IndexMapCore<T, ()>, range: R, pred: F) -> ExtractIf<'_, T, F>
where
R: RangeBounds<usize>,
F: FnMut(&T) -> bool,
{
ExtractIf {
inner: core.extract(),
inner: core.extract(range),
pred,
}
}
Expand All @@ -669,10 +667,7 @@ where
}
}

impl<'a, T, F> fmt::Debug for ExtractIf<'a, T, F>
where
F: FnMut(&T) -> bool,
{
impl<'a, T, F> fmt::Debug for ExtractIf<'a, T, F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ExtractIf").finish_non_exhaustive()
}
Expand Down
4 changes: 2 additions & 2 deletions tests/quick.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ quickcheck_limit! {
let (odd, even): (Vec<_>, Vec<_>) = map.keys().copied().partition(|k| k % 2 == 1);

let extracted: Vec<_> = map
.extract_if(|k, _| k % 2 == 1)
.extract_if(.., |k, _| k % 2 == 1)
.map(|(k, _)| k)
.collect();

Expand All @@ -222,7 +222,7 @@ quickcheck_limit! {
});

let extracted: Vec<_> = map
.extract_if(|k, _| k % 2 == 1)
.extract_if(.., |k, _| k % 2 == 1)
.map(|(k, _)| k)
.take(limit)
.collect();
Expand Down

0 comments on commit ee03a3e

Please sign in to comment.