From 7a6c76212cc71e23e76c83270892c9b94ecdc570 Mon Sep 17 00:00:00 2001 From: Max Isom Date: Thu, 24 Oct 2024 15:13:36 -0700 Subject: [PATCH] Use common binary_search_by() method --- rust/blockstore/src/arrow/block/types.rs | 217 +++++++++++++---------- 1 file changed, 124 insertions(+), 93 deletions(-) diff --git a/rust/blockstore/src/arrow/block/types.rs b/rust/blockstore/src/arrow/block/types.rs index 6e4871469130..eef18065efe4 100644 --- a/rust/blockstore/src/arrow/block/types.rs +++ b/rust/blockstore/src/arrow/block/types.rs @@ -1,4 +1,4 @@ -use std::cmp::Ordering::{Equal, Greater, Less}; +use std::cmp::Ordering; use std::collections::HashMap; use std::io::SeekFrom; use std::ops::{Bound, RangeBounds}; @@ -119,15 +119,35 @@ impl Block { delta } - /// Binary search the block to find the last index of the specified prefix. - /// Returns None if prefix does not exist in the block. + /// Binary searches this slice with a comparator function. /// - /// Partly based on `std::slice::binary_search_by`: https://doc.rust-lang.org/src/core/slice/mod.rs.html#2770 + /// The comparator function should return an order code that indicates + /// whether its argument is `Less`, `Equal` or `Greater` the desired + /// target. + /// If the slice is not sorted or if the comparator function does not + /// implement an order consistent with the sort order of the underlying + /// slice, the returned result is unspecified and meaningless. + /// + /// If the value is found then [`Result::Ok`] is returned, containing the + /// index of the matching element. If there are multiple matches, then any + /// one of the matches could be returned. The index is chosen + /// deterministically, but is subject to change in future versions of Rust. + /// If the value is not found then [`Result::Err`] is returned, containing + /// the index where a matching element could be inserted while maintaining + /// sorted order. + /// + /// Based on std::slice::binary_search_by with minimal modifications (https://doc.rust-lang.org/src/core/slice/mod.rs.html#2770). #[inline] - fn binary_search_last_index(&self, prefix: &str) -> Option { + fn binary_search_by<'me, K: ArrowReadableKey<'me>, F>( + &'me self, + mut f: F, + ) -> Result + where + F: FnMut((&'me str, K)) -> Ordering, + { let mut size = self.len(); if size == 0 { - return None; + return Err(0); } let prefix_array = self @@ -136,7 +156,7 @@ impl Block { .as_any() .downcast_ref::() .unwrap(); - let mut base = self.len() - 1; + let mut base = 0; // This loop intentionally doesn't have an early exit if the comparison // returns Equal. We want the number of loop iterations to depend *only* @@ -144,51 +164,58 @@ impl Block { // the loop count. while size > 1 { let half = size / 2; - let mid = base - half; + let mid = base + half; // SAFETY: the call is made safe by the following inconstants: - // - `mid < size`: by definition - // - `mid >= 0`: `mid = size - 1 - size / 2 - size / 4 ...` - let cmp = prefix_array.value(mid).cmp(prefix); - - base = if cmp == Greater { mid } else { base }; + // - `mid >= 0`: by definition + // - `mid < size`: `mid = size / 2 + size / 4 + size / 8 ...` + let prefix = prefix_array.value(mid); + let key = K::get(self.data.column(1), mid); + let cmp = f((prefix, key)); + + base = if cmp == Ordering::Greater { base } else { mid }; + + // This is imprecise in the case where `size` is odd and the + // comparison returns Greater: the mid element still gets included + // by `size` even though it's known to be larger than the element + // being searched for. + // + // This is fine though: we gain more performance by keeping the + // loop iteration count invariant (and thus predictable) than we + // lose from considering one additional element. size -= half; } - // SAFETY: `base` is always in [0, size) because `base < size` by init. - // `base` should be the last index where the element matches the target prefix, - // or 0 if the first element is already larger than the target prefix. - match prefix_array.value(base).cmp(prefix) { - Less => None, - Equal => Some(base), - Greater => { - if base > 0 && prefix_array.value(base - 1) == prefix { - Some(base - 1) - } else { - None - } - } + // SAFETY: base is always in [0, size) because base <= mid. + let prefix = prefix_array.value(base); + let key = K::get(self.data.column(1), base); + let cmp = f((prefix, key)); + if cmp == Ordering::Equal { + Ok(base) + } else { + let result = base + (cmp == Ordering::Less) as usize; + Err(result) } } - /// Binary search the blockfile to find the partition point of the specified prefix and key. - /// - /// `(prefix, key)` serves as the search key, and it is sorted in ascending order. - /// The partition predicate is defined by: `|x| x < (prefix, key)`. - /// The partition point is the first index where the partition precidate evaluates to `false` - /// The code is a result of inlining this predicate in [`std::slice::partition_point`]. - /// If the key is unspecified (i.e. `None`), we find the first index of the prefix. - /// - /// Partly based on `std::slice::binary_search_by`: https://doc.rust-lang.org/src/core/slice/mod.rs.html#2770 + /// Returns the largest index where `prefixes[index] == prefix` or None if the provided prefix does not exist in the block. #[inline] - fn binary_search_index<'me, K: ArrowReadableKey<'me>>( + fn find_largest_index_of_prefix<'me, K: ArrowReadableKey<'me>>( &'me self, prefix: &str, - key: Option<&K>, - ) -> usize { - let mut size = self.len(); - if size == 0 { - return 0; + ) -> Option { + // By design, will never find an exact match (comparator never evaluates to Equal). This finds the index of the first element that is greater than the prefix. If no element is greater, it returns the length of the array. + let result = self + .binary_search_by::(|(p, _)| match p.cmp(prefix) { + Ordering::Less => Ordering::Less, + Ordering::Equal => Ordering::Less, + Ordering::Greater => Ordering::Greater, + }) + .expect_err("Never returns Ok because the comparator never evaluates to Equal."); + + if result == 0 { + // The first element is greater than the target prefix, so the target prefix does not exist in the block. + return None; } let prefix_array = self @@ -197,49 +224,48 @@ impl Block { .as_any() .downcast_ref::() .unwrap(); - let mut base = 0; - // This loop intentionally doesn't have an early exit if the comparison - // returns Equal. We want the number of loop iterations to depend *only* - // on the size of the input slice so that the CPU can reliably predict - // the loop count. - while size > 1 { - let half = size / 2; - let mid = base + half; - - // SAFETY: the call is made safe by the following inconstants: - // - `mid >= 0`: by definition - // - `mid < size`: `mid = size / 2 + size / 4 + size / 8 ...` - let mut cmp = prefix_array.value(mid).cmp(prefix); - - // Continue to compare the key if prefix matches - if let (Equal, Some(k)) = (cmp, key) { - cmp = K::get(self.data.column(1), mid) - // Key type do not have total order because of floating point values - // But in our case NaN should not be allowed so we should always have total order - .partial_cmp(k) - .expect("Array values should be comparable."); - } - - base = if cmp == Less { mid } else { base }; - size -= half; + // `result` is the first index where the prefix is larger than the input (or the length of the array) so we want one element before this. + match prefix_array.value(result - 1).cmp(prefix) { + // We're at the end of the array, so the prefix does not exist in the block (all values are less than the prefix) + Ordering::Less => None, + // The prefix exists + Ordering::Equal => Some(result - 1), + // This is impossible + Ordering::Greater => None, } + } - // SAFETY: `base` is always in [0, size) because `base <= mid`. - // `base` should be the last index where the element is smaller than the target, - // or 0 if the first element is already larger than the target. - match prefix_array.value(base).cmp(prefix) { - Less => base + 1, - Equal => match key { - // Key type do not have total order because of floating point values - // But in our case NaN should not be allowed so we should always have total order - Some(k) => match K::get(self.data.column(1), base).partial_cmp(k) { - Some(Less) => base + 1, - _ => base, - }, - None => base, - }, - Greater => base, + /// Finds the partition point of the prefix and key. + /// Returns the index of the first element that matches the target prefix and key. If no element matches, returns the index at which the target prefix and key could be inserted to maintain sorted order. + #[inline] + fn get_key_prefix_partition_point<'me, K: ArrowReadableKey<'me>>( + &'me self, + prefix: &str, + key: Option<&K>, + ) -> usize { + // By design, will never find an exact match (comparator never evaluates to Equal). This finds the index of the first element that matches the target prefix and key. If no element matches, it returns the index at which the target prefix and key could be inserted to maintain sorted order. + if let Some(key) = key { + self.binary_search_by::(|(p, k)| { + match p.cmp(prefix).then_with(|| { + k.partial_cmp(key) + // The key type does not have a total order because of floating point values. + // But in our case NaN is not allowed, so we should always have total order. + .expect("Array values should be comparable.") + }) { + Ordering::Less => Ordering::Less, + Ordering::Equal => Ordering::Greater, + Ordering::Greater => Ordering::Greater, + } + }) + .expect_err("Never returns Ok because the comparator never evaluates to Equal.") + } else { + self.binary_search_by::(|(p, _)| match p.cmp(prefix) { + Ordering::Less => Ordering::Less, + Ordering::Equal => Ordering::Greater, + Ordering::Greater => Ordering::Greater, + }) + .expect_err("Never returns Ok because the comparator never evaluates to Equal.") } } @@ -262,7 +288,7 @@ impl Block { prefix_array.value(index).cmp(prefix), K::get(self.data.column(1), index).partial_cmp(key), ), - (Equal, Some(Equal)) + (Ordering::Equal, Some(Ordering::Equal)) ) } @@ -278,11 +304,16 @@ impl Block { prefix: &str, key: K, ) -> Option { - let index = self.binary_search_index(prefix, Some(&key)); - if self.match_prefix_key_at_index(prefix, &key, index) { - Some(V::get(self.data.column(2), index)) - } else { - None + match self.binary_search_by::(|(p, k)| { + p.cmp(prefix).then_with(|| { + k.partial_cmp(&key) + // The key type does not have a total order because of floating point values. + // But in our case NaN is not allowed, so we should always have total order. + .expect("Array values should be comparable.") + }) + }) { + Ok(index) => Some(V::get(self.data.column(2), index)), + Err(_) => None, } } @@ -308,16 +339,16 @@ impl Block { { let start_index = match prefix_range.start_bound() { Bound::Included(prefix) => match key_range.start_bound() { - Bound::Included(key) => self.binary_search_index(prefix, Some(key)), + Bound::Included(key) => self.get_key_prefix_partition_point(prefix, Some(key)), Bound::Excluded(key) => { - let index = self.binary_search_index(prefix, Some(key)); + let index = self.get_key_prefix_partition_point(prefix, Some(key)); if self.match_prefix_key_at_index(prefix, key, index) { index + 1 } else { index } } - Bound::Unbounded => self.binary_search_index::(prefix, None), + Bound::Unbounded => self.get_key_prefix_partition_point::(prefix, None), }, Bound::Excluded(_) => { unimplemented!("Excluded prefix range is not currently supported") @@ -328,15 +359,15 @@ impl Block { let end_index = match prefix_range.end_bound() { Bound::Included(prefix) => match key_range.end_bound() { Bound::Included(key) => { - let index = self.binary_search_index(prefix, Some(key)); + let index = self.get_key_prefix_partition_point(prefix, Some(key)); if self.match_prefix_key_at_index(prefix, key, index) { index + 1 } else { index } } - Bound::Excluded(key) => self.binary_search_index(prefix, Some(key)), - Bound::Unbounded => match self.binary_search_last_index(prefix) { + Bound::Excluded(key) => self.get_key_prefix_partition_point(prefix, Some(key)), + Bound::Unbounded => match self.find_largest_index_of_prefix::(prefix) { Some(last_index_of_prefix) => last_index_of_prefix + 1, // (add 1 because end_index is exclusive below) None => start_index, // prefix does not exist in the block so we shouldn't return anything },