Skip to content

Commit

Permalink
[BUG] Enable sorting bool columns (#2529)
Browse files Browse the repository at this point in the history
Co-authored-by: Colin Ho <[email protected]>
  • Loading branch information
colin-ho and Colin Ho authored Jul 18, 2024
1 parent 46d6717 commit 63518de
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 9 deletions.
82 changes: 81 additions & 1 deletion src/daft-core/src/kernels/search_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{cmp::Ordering, iter::zip};
use arrow2::{
array::{
ord::{build_compare, DynComparator},
Array, BinaryArray, FixedSizeBinaryArray, PrimitiveArray, Utf8Array,
Array, BinaryArray, BooleanArray, FixedSizeBinaryArray, PrimitiveArray, Utf8Array,
},
datatypes::{DataType, PhysicalType},
error::{Error, Result},
Expand Down Expand Up @@ -142,6 +142,81 @@ fn search_sorted_utf_array<O: Offset>(
PrimitiveArray::<u64>::new(DataType::UInt64, results.into(), None)
}

fn search_sorted_boolean_array(
sorted_array: &BooleanArray,
keys: &BooleanArray,
input_reversed: bool,
) -> PrimitiveArray<u64> {
let array_size = sorted_array.len();
let mut left = 0_usize;
let mut right = array_size;

// For boolean arrays, we know there can only be three possible values: true, false, and null.s
// We can pre-compute the results for these three values and then reuse them to compute the results for the keys.
let pre_computed_keys = &[Some(true), Some(false), None];
let mut pre_computed_results: [u64; 3] = [0, 0, 0];
let mut last_key = pre_computed_keys.iter().next().unwrap();
for (i, key_val) in pre_computed_keys.iter().enumerate() {
let is_last_key_lt = match (last_key, key_val) {
(None, None) => false,
(None, Some(_)) => input_reversed,
(Some(last_key), Some(key_val)) => {
if !input_reversed {
last_key.lt(key_val)
} else {
last_key.gt(key_val)
}
}
(Some(_), None) => !input_reversed,
};
if is_last_key_lt {
right = array_size;
} else {
left = 0;
right = if right < array_size {
right + 1
} else {
array_size
};
}
while left < right {
let mid_idx = left + ((right - left) >> 1);
let mid_val = unsafe { sorted_array.value_unchecked(mid_idx) };
let is_key_val_lt = match (key_val, sorted_array.is_valid(mid_idx)) {
(None, false) => false,
(None, true) => input_reversed,
(Some(key_val), true) => {
if !input_reversed {
key_val.lt(&mid_val)
} else {
mid_val.lt(key_val)
}
}
(Some(_), false) => !input_reversed,
};

if is_key_val_lt {
right = mid_idx;
} else {
left = mid_idx + 1;
}
}
pre_computed_results[i] = left.try_into().unwrap();
last_key = key_val;
}

let results = keys
.iter()
.map(|key_val| match key_val {
Some(true) => pre_computed_results[0],
Some(false) => pre_computed_results[1],
None => pre_computed_results[2],
})
.collect::<Vec<_>>();

PrimitiveArray::<u64>::new(DataType::UInt64, results.into(), None)
}

fn search_sorted_binary_array<O: Offset>(
sorted_array: &BinaryArray<O>,
keys: &BinaryArray<O>,
Expand Down Expand Up @@ -535,6 +610,11 @@ pub fn search_sorted(
keys.as_any().downcast_ref().unwrap(),
input_reversed,
),
Boolean => search_sorted_boolean_array(
sorted_array.as_any().downcast_ref().unwrap(),
keys.as_any().downcast_ref().unwrap(),
input_reversed,
),
t => {
return Err(Error::NotYetImplemented(format!(
"search_sorted not implemented for type {t:?}"
Expand Down
2 changes: 1 addition & 1 deletion src/daft-plan/src/logical_ops/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ impl Sort {
for (field, expr) in sort_by_resolved_schema.fields.values().zip(sort_by.iter()) {
// Disallow sorting by null, binary, and boolean columns.
// TODO(Clark): This is a port of an existing constraint, we should look at relaxing this.
if let dt @ (DataType::Null | DataType::Binary | DataType::Boolean) = &field.dtype {
if let dt @ (DataType::Null | DataType::Binary) = &field.dtype {
return Err(DaftError::ValueError(format!(
"Cannot sort on expression {expr} with type: {dt}",
)))
Expand Down
42 changes: 35 additions & 7 deletions tests/dataframe/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,6 @@
###


def test_disallowed_sort_bool(make_df):
df = make_df({"A": [True, False]})

with pytest.raises((ExpressionTypeError, ValueError)):
df.sort("A")


def test_disallowed_sort_null(make_df):
df = make_df({"A": [None, None]})

Expand Down Expand Up @@ -117,6 +110,41 @@ def test_single_string_col_sort(make_df, desc: bool, n_partitions: int):
assert sorted_data["A"] == expected


@pytest.mark.parametrize("desc", [True, False])
@pytest.mark.parametrize("n_partitions", [1, 3, 4])
def test_single_bool_col_sort(make_df, desc: bool, n_partitions: int):
df = make_df({"A": [True, None, False, True, False]}, repartition=n_partitions)
df = df.sort("A", desc=desc)
sorted_data = df.to_pydict()

expected = [False, False, True, True, None]
if desc:
expected = list(reversed(expected))

assert sorted_data["A"] == expected


@pytest.mark.parametrize("n_partitions", [1, 3, 4])
def test_multi_bool_col_sort(make_df, n_partitions: int):
df = make_df(
{
"A": [True, False, None, False, True],
"B": [None, True, False, True, None],
},
repartition=n_partitions,
)
df = df.sort(["A", "B"], desc=[True, False])
sorted_data = df.to_pydict()

expected = {
"A": [None, True, True, False, False],
"B": [False, None, None, True, True],
}

assert sorted_data["A"] == expected["A"]
assert sorted_data["B"] == expected["B"]


###
# Null tests
###
Expand Down

0 comments on commit 63518de

Please sign in to comment.