Skip to content

Commit

Permalink
Use common binary_search_by() method
Browse files Browse the repository at this point in the history
  • Loading branch information
codetheweb committed Oct 24, 2024
1 parent db61125 commit 7a6c762
Showing 1 changed file with 124 additions and 93 deletions.
217 changes: 124 additions & 93 deletions rust/blockstore/src/arrow/block/types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::cmp::Ordering::{Equal, Greater, Less};
use std::cmp::Ordering;
use std::collections::HashMap;
use std::io::SeekFrom;
use std::ops::{Bound, RangeBounds};
Expand Down Expand Up @@ -119,15 +119,35 @@ impl Block {
delta
}

/// Binary search the block to find the last index of the specified prefix.
/// Returns None if prefix does not exist in the block.
/// Binary searches this slice with a comparator function.
///
/// Partly based on `std::slice::binary_search_by`: https://doc.rust-lang.org/src/core/slice/mod.rs.html#2770
/// The comparator function should return an order code that indicates
/// whether its argument is `Less` than, `Equal` to, or `Greater` than the
/// desired target.
/// If the slice is not sorted or if the comparator function does not
/// implement an order consistent with the sort order of the underlying
/// slice, the returned result is unspecified and meaningless.
///
/// If the value is found then [`Result::Ok`] is returned, containing the
/// index of the matching element. If there are multiple matches, then any
/// one of the matches could be returned. The index is chosen
/// deterministically, but is subject to change in future versions of Rust.
/// If the value is not found then [`Result::Err`] is returned, containing
/// the index where a matching element could be inserted while maintaining
/// sorted order.
///
/// Based on std::slice::binary_search_by with minimal modifications (https://doc.rust-lang.org/src/core/slice/mod.rs.html#2770).
#[inline]
fn binary_search_last_index(&self, prefix: &str) -> Option<usize> {
fn binary_search_by<'me, K: ArrowReadableKey<'me>, F>(
&'me self,
mut f: F,
) -> Result<usize, usize>
where
F: FnMut((&'me str, K)) -> Ordering,
{
let mut size = self.len();
if size == 0 {
return None;
return Err(0);
}

let prefix_array = self
Expand All @@ -136,59 +156,66 @@ impl Block {
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let mut base = self.len() - 1;
let mut base = 0;

// This loop intentionally doesn't have an early exit if the comparison
// returns Equal. We want the number of loop iterations to depend *only*
// on the size of the input slice so that the CPU can reliably predict
// the loop count.
while size > 1 {
let half = size / 2;
let mid = base - half;
let mid = base + half;

// SAFETY: the call is made safe by the following invariants:
// - `mid < size`: by definition
// - `mid >= 0`: `mid = size - 1 - size / 2 - size / 4 ...`
let cmp = prefix_array.value(mid).cmp(prefix);

base = if cmp == Greater { mid } else { base };
// - `mid >= 0`: by definition
// - `mid < size`: `mid = size / 2 + size / 4 + size / 8 ...`
let prefix = prefix_array.value(mid);
let key = K::get(self.data.column(1), mid);
let cmp = f((prefix, key));

base = if cmp == Ordering::Greater { base } else { mid };

// This is imprecise in the case where `size` is odd and the
// comparison returns Greater: the mid element still gets included
// by `size` even though it's known to be larger than the element
// being searched for.
//
// This is fine though: we gain more performance by keeping the
// loop iteration count invariant (and thus predictable) than we
// lose from considering one additional element.
size -= half;
}

// SAFETY: `base` is always in [0, size) because `base < size` by init.
// `base` should be the last index where the element matches the target prefix,
// or 0 if the first element is already larger than the target prefix.
match prefix_array.value(base).cmp(prefix) {
Less => None,
Equal => Some(base),
Greater => {
if base > 0 && prefix_array.value(base - 1) == prefix {
Some(base - 1)
} else {
None
}
}
// SAFETY: base is always in [0, size) because base <= mid.
let prefix = prefix_array.value(base);
let key = K::get(self.data.column(1), base);
let cmp = f((prefix, key));
if cmp == Ordering::Equal {
Ok(base)
} else {
let result = base + (cmp == Ordering::Less) as usize;
Err(result)
}
}

/// Binary search the blockfile to find the partition point of the specified prefix and key.
///
/// `(prefix, key)` serves as the search key, and it is sorted in ascending order.
/// The partition predicate is defined by: `|x| x < (prefix, key)`.
/// The partition point is the first index where the partition predicate evaluates to `false`.
/// The code is a result of inlining this predicate in [`std::slice::partition_point`].
/// If the key is unspecified (i.e. `None`), we find the first index of the prefix.
///
/// Partly based on `std::slice::binary_search_by`: https://doc.rust-lang.org/src/core/slice/mod.rs.html#2770
/// Returns the largest index where `prefixes[index] == prefix` or None if the provided prefix does not exist in the block.
#[inline]
fn binary_search_index<'me, K: ArrowReadableKey<'me>>(
fn find_largest_index_of_prefix<'me, K: ArrowReadableKey<'me>>(
&'me self,
prefix: &str,
key: Option<&K>,
) -> usize {
let mut size = self.len();
if size == 0 {
return 0;
) -> Option<usize> {
// By design, will never find an exact match (comparator never evaluates to Equal). This finds the index of the first element that is greater than the prefix. If no element is greater, it returns the length of the array.
let result = self
.binary_search_by::<K, _>(|(p, _)| match p.cmp(prefix) {
Ordering::Less => Ordering::Less,
Ordering::Equal => Ordering::Less,
Ordering::Greater => Ordering::Greater,
})
.expect_err("Never returns Ok because the comparator never evaluates to Equal.");

if result == 0 {
// The first element is greater than the target prefix, so the target prefix does not exist in the block.
return None;
}

let prefix_array = self
Expand All @@ -197,49 +224,48 @@ impl Block {
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let mut base = 0;

// This loop intentionally doesn't have an early exit if the comparison
// returns Equal. We want the number of loop iterations to depend *only*
// on the size of the input slice so that the CPU can reliably predict
// the loop count.
while size > 1 {
let half = size / 2;
let mid = base + half;

// SAFETY: the call is made safe by the following inconstants:
// - `mid >= 0`: by definition
// - `mid < size`: `mid = size / 2 + size / 4 + size / 8 ...`
let mut cmp = prefix_array.value(mid).cmp(prefix);

// Continue to compare the key if prefix matches
if let (Equal, Some(k)) = (cmp, key) {
cmp = K::get(self.data.column(1), mid)
// Key type do not have total order because of floating point values
// But in our case NaN should not be allowed so we should always have total order
.partial_cmp(k)
.expect("Array values should be comparable.");
}

base = if cmp == Less { mid } else { base };
size -= half;
// `result` is the first index where the prefix is larger than the input (or the length of the array) so we want one element before this.
match prefix_array.value(result - 1).cmp(prefix) {
// We're at the end of the array, so the prefix does not exist in the block (all values are less than the prefix)
Ordering::Less => None,
// The prefix exists
Ordering::Equal => Some(result - 1),
// This is impossible
Ordering::Greater => None,
}
}

// SAFETY: `base` is always in [0, size) because `base <= mid`.
// `base` should be the last index where the element is smaller than the target,
// or 0 if the first element is already larger than the target.
match prefix_array.value(base).cmp(prefix) {
Less => base + 1,
Equal => match key {
// Key type do not have total order because of floating point values
// But in our case NaN should not be allowed so we should always have total order
Some(k) => match K::get(self.data.column(1), base).partial_cmp(k) {
Some(Less) => base + 1,
_ => base,
},
None => base,
},
Greater => base,
/// Locates the partition point of `(prefix, key)` within the block.
///
/// The result is the index of the first element that is not less than the
/// target: the first matching element when one exists, otherwise the index at
/// which the target could be inserted while maintaining sorted order.
/// When `key` is `None` only the prefixes are compared, so the result is the
/// first index whose prefix is not less than `prefix`.
///
/// # Panics
///
/// Panics if a stored key cannot be compared with `key` (`partial_cmp`
/// returns `None`).
#[inline]
fn get_key_prefix_partition_point<'me, K: ArrowReadableKey<'me>>(
    &'me self,
    prefix: &str,
    key: Option<&K>,
) -> usize {
    let search = self.binary_search_by::<K, _>(|(candidate_prefix, candidate_key)| {
        // Order by prefix first; consult the key only on a prefix tie, and only
        // when the caller actually supplied one.
        let ordering = candidate_prefix.cmp(prefix).then_with(|| match key {
            // The key type does not have a total order because of floating point values.
            // But in our case NaN is not allowed, so we should always have total order.
            Some(target_key) => candidate_key
                .partial_cmp(target_key)
                .expect("Array values should be comparable."),
            None => Ordering::Equal,
        });

        // An exact match is deliberately reported as `Greater`, so the search
        // never stops on a match and instead yields the leftmost position that
        // is not less than the target.
        if ordering == Ordering::Less {
            Ordering::Less
        } else {
            Ordering::Greater
        }
    });

    search.expect_err("Never returns Ok because the comparator never evaluates to Equal.")
}

Expand All @@ -262,7 +288,7 @@ impl Block {
prefix_array.value(index).cmp(prefix),
K::get(self.data.column(1), index).partial_cmp(key),
),
(Equal, Some(Equal))
(Ordering::Equal, Some(Ordering::Equal))
)
}

Expand All @@ -278,11 +304,16 @@ impl Block {
prefix: &str,
key: K,
) -> Option<V> {
let index = self.binary_search_index(prefix, Some(&key));
if self.match_prefix_key_at_index(prefix, &key, index) {
Some(V::get(self.data.column(2), index))
} else {
None
match self.binary_search_by::<K, _>(|(p, k)| {
p.cmp(prefix).then_with(|| {
k.partial_cmp(&key)
// The key type does not have a total order because of floating point values.
// But in our case NaN is not allowed, so we should always have total order.
.expect("Array values should be comparable.")
})
}) {
Ok(index) => Some(V::get(self.data.column(2), index)),
Err(_) => None,
}
}

Expand All @@ -308,16 +339,16 @@ impl Block {
{
let start_index = match prefix_range.start_bound() {
Bound::Included(prefix) => match key_range.start_bound() {
Bound::Included(key) => self.binary_search_index(prefix, Some(key)),
Bound::Included(key) => self.get_key_prefix_partition_point(prefix, Some(key)),
Bound::Excluded(key) => {
let index = self.binary_search_index(prefix, Some(key));
let index = self.get_key_prefix_partition_point(prefix, Some(key));
if self.match_prefix_key_at_index(prefix, key, index) {
index + 1
} else {
index
}
}
Bound::Unbounded => self.binary_search_index::<K>(prefix, None),
Bound::Unbounded => self.get_key_prefix_partition_point::<K>(prefix, None),
},
Bound::Excluded(_) => {
unimplemented!("Excluded prefix range is not currently supported")
Expand All @@ -328,15 +359,15 @@ impl Block {
let end_index = match prefix_range.end_bound() {
Bound::Included(prefix) => match key_range.end_bound() {
Bound::Included(key) => {
let index = self.binary_search_index(prefix, Some(key));
let index = self.get_key_prefix_partition_point(prefix, Some(key));
if self.match_prefix_key_at_index(prefix, key, index) {
index + 1
} else {
index
}
}
Bound::Excluded(key) => self.binary_search_index(prefix, Some(key)),
Bound::Unbounded => match self.binary_search_last_index(prefix) {
Bound::Excluded(key) => self.get_key_prefix_partition_point(prefix, Some(key)),
Bound::Unbounded => match self.find_largest_index_of_prefix::<K>(prefix) {
Some(last_index_of_prefix) => last_index_of_prefix + 1, // (add 1 because end_index is exclusive below)
None => start_index, // prefix does not exist in the block so we shouldn't return anything
},
Expand Down

0 comments on commit 7a6c762

Please sign in to comment.