diff --git a/src/map.rs b/src/map.rs index ce755514..761e2797 100644 --- a/src/map.rs +++ b/src/map.rs @@ -307,7 +307,8 @@ impl IndexMap { Drain::new(self.core.drain(range)) } - /// Creates an iterator which uses a closure to determine if an element should be removed. + /// Creates an iterator which uses a closure to determine if an element should be removed, + /// for all elements in the given range. /// /// If the closure returns true, the element is removed from the map and yielded. /// If the closure returns false, or panics, the element remains in the map and will not be @@ -316,6 +317,11 @@ impl IndexMap { /// Note that `extract_if` lets you mutate every value in the filter closure, regardless of /// whether you choose to keep or remove it. /// + /// The range may be any type that implements [`RangeBounds`], + /// including all of the `std::ops::Range*` types, or even a tuple pair of + /// `Bound` start and end values. To check the entire map, use `RangeFull` + /// like `map.extract_if(.., predicate)`. + /// /// If the returned `ExtractIf` is not exhausted, e.g. because it is dropped without iterating /// or the iteration short-circuits, then the remaining elements will be retained. /// Use [`retain`] with a negated predicate if you do not need the returned iterator. @@ -330,7 +336,7 @@ impl IndexMap { /// use indexmap::IndexMap; /// /// let mut map: IndexMap = (0..8).map(|x| (x, x)).collect(); - /// let extracted: IndexMap = map.extract_if(|k, _v| k % 2 == 0).collect(); + /// let extracted: IndexMap = map.extract_if(.., |k, _v| k % 2 == 0).collect(); /// /// let evens = extracted.keys().copied().collect::>(); /// let odds = map.keys().copied().collect::>(); @@ -338,11 +344,12 @@ impl IndexMap { /// assert_eq!(evens, vec![0, 2, 4, 6]); /// assert_eq!(odds, vec![1, 3, 5, 7]); /// ``` - pub fn extract_if(&mut self, pred: F) -> ExtractIf<'_, K, V, F> + pub fn extract_if(&mut self, range: R, pred: F) -> ExtractIf<'_, K, V, F> where + R: RangeBounds, F: FnMut(&K, &mut V) -> bool, { - ExtractIf::new(&mut self.core, pred) + ExtractIf::new(&mut self.core, range, pred) } /// Splits the collection into two at the given index. diff --git a/src/map/core/extract.rs b/src/map/core/extract.rs index db35d96d..ab7aa64d 100644 --- a/src/map/core/extract.rs +++ b/src/map/core/extract.rs @@ -1,34 +1,46 @@ #![allow(unsafe_code)] use super::{Bucket, IndexMapCore}; +use crate::util::simplify_range; + +use core::ops::RangeBounds; impl IndexMapCore { - pub(crate) fn extract(&mut self) -> ExtractCore<'_, K, V> { + pub(crate) fn extract(&mut self, range: R) -> ExtractCore<'_, K, V> + where + R: RangeBounds, + { + let range = simplify_range(range, self.entries.len()); + // SAFETY: We must have consistent lengths to start, so that's a hard assertion. - // Then the worst `set_len(0)` can do is leak items if `ExtractCore` doesn't drop. + // Then the worst `set_len` can do is leak items if `ExtractCore` doesn't drop. assert_eq!(self.entries.len(), self.indices.len()); unsafe { - self.entries.set_len(0); + self.entries.set_len(range.start); } ExtractCore { map: self, - current: 0, - new_len: 0, + new_len: range.start, + current: range.start, + end: range.end, } } } pub(crate) struct ExtractCore<'a, K, V> { map: &'a mut IndexMapCore, - current: usize, new_len: usize, + current: usize, + end: usize, } impl Drop for ExtractCore<'_, K, V> { fn drop(&mut self) { let old_len = self.map.indices.len(); let mut new_len = self.new_len; + debug_assert!(new_len <= self.current); + debug_assert!(self.current <= self.end); debug_assert!(self.current <= old_len); debug_assert!(old_len <= self.map.entries.capacity()); @@ -62,13 +74,12 @@ impl ExtractCore<'_, K, V> { where F: FnMut(&mut Bucket) -> bool, { - let old_len = self.map.indices.len(); - debug_assert!(old_len <= self.map.entries.capacity()); + debug_assert!(self.end <= self.map.entries.capacity()); let base = self.map.entries.as_mut_ptr(); - while self.current < old_len { + while self.current < self.end { // SAFETY: We're maintaining both indices within bounds of the original entries, so - // 0..new_len and current..old_len are always valid items for our Drop to keep. + // 0..new_len and current..indices.len() are always valid items for our Drop to keep. unsafe { let item = base.add(self.current); if pred(&mut *item) { @@ -91,6 +102,6 @@ impl ExtractCore<'_, K, V> { } pub(crate) fn remaining(&self) -> usize { - self.map.indices.len() - self.current + self.end - self.current } } diff --git a/src/map/iter.rs b/src/map/iter.rs index 41e61a18..2fc85503 100644 --- a/src/map/iter.rs +++ b/src/map/iter.rs @@ -777,21 +777,19 @@ where /// /// This `struct` is created by [`IndexMap::extract_if()`]. /// See its documentation for more. -pub struct ExtractIf<'a, K, V, F> -where - F: FnMut(&K, &mut V) -> bool, -{ +pub struct ExtractIf<'a, K, V, F> { inner: ExtractCore<'a, K, V>, pred: F, } -impl ExtractIf<'_, K, V, F> -where - F: FnMut(&K, &mut V) -> bool, -{ - pub(super) fn new(core: &mut IndexMapCore, pred: F) -> ExtractIf<'_, K, V, F> { +impl ExtractIf<'_, K, V, F> { + pub(super) fn new(core: &mut IndexMapCore, range: R, pred: F) -> ExtractIf<'_, K, V, F> + where + R: RangeBounds, + F: FnMut(&K, &mut V) -> bool, + { ExtractIf { - inner: core.extract(), + inner: core.extract(range), pred, } } @@ -817,10 +815,7 @@ where } } -impl<'a, K, V, F> fmt::Debug for ExtractIf<'a, K, V, F> -where - F: FnMut(&K, &mut V) -> bool, -{ +impl<'a, K, V, F> fmt::Debug for ExtractIf<'a, K, V, F> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ExtractIf").finish_non_exhaustive() } diff --git a/src/set.rs b/src/set.rs index 9314ff1e..f87ad0c3 100644 --- a/src/set.rs +++ b/src/set.rs @@ -257,12 +257,18 @@ impl IndexSet { Drain::new(self.map.core.drain(range)) } - /// Creates an iterator which uses a closure to determine if a value should be removed. + /// Creates an iterator which uses a closure to determine if a value should be removed, + /// for all values in the given range. /// /// If the closure returns true, then the value is removed and yielded. /// If the closure returns false, the value will remain in the list and will not be yielded /// by the iterator. /// + /// The range may be any type that implements [`RangeBounds`], + /// including all of the `std::ops::Range*` types, or even a tuple pair of + /// `Bound` start and end values. To check the entire set, use `RangeFull` + /// like `set.extract_if(.., predicate)`. + /// /// If the returned `ExtractIf` is not exhausted, e.g. because it is dropped without iterating /// or the iteration short-circuits, then the remaining elements will be retained. /// Use [`retain`] with a negated predicate if you do not need the returned iterator. @@ -277,7 +283,7 @@ impl IndexSet { /// use indexmap::IndexSet; /// /// let mut set: IndexSet = (0..8).collect(); - /// let extracted: IndexSet = set.extract_if(|v| v % 2 == 0).collect(); + /// let extracted: IndexSet = set.extract_if(.., |v| v % 2 == 0).collect(); /// /// let evens = extracted.into_iter().collect::>(); /// let odds = set.into_iter().collect::>(); @@ -285,11 +291,12 @@ impl IndexSet { /// assert_eq!(evens, vec![0, 2, 4, 6]); /// assert_eq!(odds, vec![1, 3, 5, 7]); /// ``` - pub fn extract_if(&mut self, pred: F) -> ExtractIf<'_, T, F> + pub fn extract_if(&mut self, range: R, pred: F) -> ExtractIf<'_, T, F> where + R: RangeBounds, F: FnMut(&T) -> bool, { - ExtractIf::new(&mut self.map.core, pred) + ExtractIf::new(&mut self.map.core, range, pred) } /// Splits the collection into two at the given index. diff --git a/src/set/iter.rs b/src/set/iter.rs index d34ff54e..17b700e8 100644 --- a/src/set/iter.rs +++ b/src/set/iter.rs @@ -632,21 +632,19 @@ impl fmt::Debug for UnitValue { /// /// This `struct` is created by [`IndexSet::extract_if()`]. /// See its documentation for more. -pub struct ExtractIf<'a, T, F> -where - F: FnMut(&T) -> bool, -{ +pub struct ExtractIf<'a, T, F> { inner: ExtractCore<'a, T, ()>, pred: F, } -impl ExtractIf<'_, T, F> -where - F: FnMut(&T) -> bool, -{ - pub(super) fn new(core: &mut IndexMapCore, pred: F) -> ExtractIf<'_, T, F> { +impl ExtractIf<'_, T, F> { + pub(super) fn new(core: &mut IndexMapCore, range: R, pred: F) -> ExtractIf<'_, T, F> + where + R: RangeBounds, + F: FnMut(&T) -> bool, + { ExtractIf { - inner: core.extract(), + inner: core.extract(range), pred, } } @@ -669,10 +667,7 @@ where } } -impl<'a, T, F> fmt::Debug for ExtractIf<'a, T, F> -where - F: FnMut(&T) -> bool, -{ +impl<'a, T, F> fmt::Debug for ExtractIf<'a, T, F> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("ExtractIf").finish_non_exhaustive() } diff --git a/tests/quick.rs b/tests/quick.rs index 9394a4a4..e9682115 100644 --- a/tests/quick.rs +++ b/tests/quick.rs @@ -200,7 +200,7 @@ quickcheck_limit! { let (odd, even): (Vec<_>, Vec<_>) = map.keys().copied().partition(|k| k % 2 == 1); let extracted: Vec<_> = map - .extract_if(|k, _| k % 2 == 1) + .extract_if(.., |k, _| k % 2 == 1) .map(|(k, _)| k) .collect(); @@ -222,7 +222,7 @@ quickcheck_limit! { }); let extracted: Vec<_> = map - .extract_if(|k, _| k % 2 == 1) + .extract_if(.., |k, _| k % 2 == 1) .map(|(k, _)| k) .take(limit) .collect();