Skip to content

Commit

Permalink
feat: Add index_of() function to Series and Expr (#19894)
Browse files Browse the repository at this point in the history
Co-authored-by: Itamar Turner-Trauring <[email protected]>
Co-authored-by: coastalwhite <[email protected]>
  • Loading branch information
3 people authored Jan 7, 2025
1 parent 791dee5 commit 785bb1e
Show file tree
Hide file tree
Showing 20 changed files with 619 additions and 0 deletions.
4 changes: 4 additions & 0 deletions crates/polars-core/src/datatypes/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,10 @@ impl DataType {
return Some(true);
}

if self.is_null() {
return Some(true);
}

use DataType as D;
Some(match (self, to) {
#[cfg(feature = "dtype-categorical")]
Expand Down
3 changes: 3 additions & 0 deletions crates/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ string_pad = ["polars-plan/string_pad"]
string_reverse = ["polars-plan/string_reverse"]
string_to_integer = ["polars-plan/string_to_integer"]
arg_where = ["polars-plan/arg_where"]
index_of = ["polars-plan/index_of"]
search_sorted = ["polars-plan/search_sorted"]
merge_sorted = ["polars-plan/merge_sorted", "polars-stream?/merge_sorted"]
meta = ["polars-plan/meta"]
Expand Down Expand Up @@ -314,6 +315,7 @@ test_all = [
"row_hash",
"string_pad",
"string_to_integer",
"index_of",
"search_sorted",
"top_k",
"pivot",
Expand Down Expand Up @@ -360,6 +362,7 @@ features = [
"fused",
"futures",
"hist",
"index_of",
"interpolate",
"interpolate_by",
"ipc",
Expand Down
1 change: 1 addition & 0 deletions crates/polars-ops/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ rolling_window = ["polars-core/rolling_window"]
rolling_window_by = ["polars-core/rolling_window_by"]
moment = []
mode = []
index_of = []
search_sorted = []
merge_sorted = []
top_k = []
Expand Down
121 changes: 121 additions & 0 deletions crates/polars-ops/src/series/ops/index_of.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
use arrow::array::{BinaryArray, PrimitiveArray};
use polars_core::downcast_as_macro_arg_physical;
use polars_core::prelude::*;
use polars_utils::total_ord::TotalEq;
use row_encode::encode_rows_unordered;

/// Linearly scan every chunk of `ca` for the first element that is totally
/// equal to `value`.
///
/// Returns the position of that element counted across all chunks, or `None`
/// when no element matches. Null entries never match.
///
/// # Panics
///
/// Panics if a chunk of `ca` cannot be downcast to the array type `AR`.
fn index_of_value<'a, DT, AR>(ca: &'a ChunkedArray<DT>, value: AR::ValueT<'a>) -> Option<usize>
where
    DT: PolarsDataType,
    AR: StaticArray,
    AR::ValueT<'a>: TotalEq,
{
    let needle = &value;
    // Number of elements in the chunks that were already scanned.
    let mut offset = 0;
    for arr in ca.chunks() {
        let arr = arr.as_any().downcast_ref::<AR>().unwrap();
        let hit = if arr.validity().is_some() {
            // Nulls may be present: only valid entries can be equal.
            arr.iter()
                .position(|entry| matches!(entry, Some(v) if v.tot_eq(needle)))
        } else {
            // No validity bitmap means no nulls, so the values can be
            // compared directly on a faster path.
            arr.values_iter().position(|v| v.tot_eq(needle))
        };
        if let Some(pos) = hit {
            return Some(offset + pos);
        }
        offset += arr.len();
    }
    None
}

/// Find the index of `value` in a numeric chunked array by delegating to
/// index_of_value() with the matching primitive array type.
fn index_of_numeric_value<T>(ca: &ChunkedArray<T>, value: T::Native) -> Option<usize>
where
    T: PolarsNumericType,
{
    index_of_value::<T, PrimitiveArray<T::Native>>(ca, value)
}

/// Try extracting the needle as the chunked array's native numeric type, then
/// call index_of_numeric_value().
///
/// `$ca` is a reference to a numeric chunked array and `$value` is the needle
/// scalar. The macro evaluates to an `Option<usize>`.
macro_rules! try_index_of_numeric_ca {
    ($ca:expr, $value:expr) => {{
        let ca = $ca;
        let value = $value;
        // extract() returns None if casting failed, so treat an extract()
        // failure as not finding the value rather than panicking. Nulls
        // should have been handled earlier.
        value
            .value()
            .extract()
            .and_then(|native| index_of_numeric_value(ca, native))
    }};
}

/// Find the index of the first occurrence of `needle` within `series`.
///
/// Returns `Ok(Some(index))` for the first matching position, counted across
/// all chunks, and `Ok(None)` if the value does not occur. A null `needle`
/// searches for the first null entry.
///
/// # Errors
///
/// Returns an `InvalidOperation` error if the dtypes of `series` and `needle`
/// differ, or if `series` is categorical.
pub fn index_of(series: &Series, needle: Scalar) -> PolarsResult<Option<usize>> {
    polars_ensure!(
        series.dtype() == needle.dtype(),
        InvalidOperation: "Cannot perform index_of with mismatching datatypes: {:?} and {:?}",
        series.dtype(),
        needle.dtype(),
    );

    // Series is null:
    if series.dtype().is_null() {
        if needle.is_null() {
            return Ok((series.len() > 0).then_some(0));
        } else {
            return Ok(None);
        }
    }

    // Series is not null, and the value is null:
    if needle.is_null() {
        // Number of elements in the chunks that were already scanned.
        let mut offset = 0;
        for chunk in series.chunks() {
            let length = chunk.len();
            if let Some(bitmap) = chunk.validity() {
                // The first unset validity bit marks the first null.
                let leading_ones = bitmap.leading_ones();
                if leading_ones < length {
                    return Ok(Some(offset + leading_ones));
                }
            }
            // Always advance past the chunk: a chunk may carry a validity
            // bitmap without containing any nulls, and skipping it would
            // shift the index reported for nulls in later chunks.
            offset += length;
        }
        return Ok(None);
    }

    if series.dtype().is_primitive_numeric() {
        return Ok(downcast_as_macro_arg_physical!(
            series,
            try_index_of_numeric_ca,
            needle
        ));
    }

    if series.dtype().is_categorical() {
        // See https://github.com/pola-rs/polars/issues/20318
        polars_bail!(InvalidOperation: "index_of() on Categoricals is not supported");
    }

    // For non-numeric dtypes, we convert to row-encoding, which essentially has
    // us searching the physical representation of the data as a series of
    // bytes.
    let value_as_column = Column::new_scalar(PlSmallStr::EMPTY, needle, 1);
    let value_as_row_encoded_ca = encode_rows_unordered(&[value_as_column])?;
    let value = value_as_row_encoded_ca
        .first()
        .expect("Shouldn't have nulls in a row-encoded result");
    let ca = encode_rows_unordered(&[series.clone().into()])?;
    Ok(index_of_value::<_, BinaryArray<i64>>(&ca, value))
}
4 changes: 4 additions & 0 deletions crates/polars-ops/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ mod floor_divide;
mod fused;
mod horizontal;
mod index;
#[cfg(feature = "index_of")]
mod index_of;
mod int_range;
#[cfg(any(feature = "interpolate_by", feature = "interpolate"))]
mod interpolation;
Expand Down Expand Up @@ -84,6 +86,8 @@ pub use floor_divide::*;
pub use fused::*;
pub use horizontal::*;
pub use index::*;
#[cfg(feature = "index_of")]
pub use index_of::*;
pub use int_range::*;
#[cfg(feature = "interpolate")]
pub use interpolation::interpolate::*;
Expand Down
1 change: 1 addition & 0 deletions crates/polars-ops/src/series/ops/search_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub fn search_sorted(
let original_dtype = s.dtype();

if s.dtype().is_categorical() {
// See https://github.com/pola-rs/polars/issues/20171
polars_bail!(InvalidOperation: "'search_sorted' is not supported on dtype: {}", s.dtype())
}

Expand Down
2 changes: 2 additions & 0 deletions crates/polars-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ string_pad = ["polars-ops/string_pad"]
string_reverse = ["polars-ops/string_reverse"]
string_to_integer = ["polars-ops/string_to_integer"]
arg_where = []
index_of = ["polars-ops/index_of"]
search_sorted = ["polars-ops/search_sorted"]
merge_sorted = ["polars-ops/merge_sorted"]
meta = []
Expand Down Expand Up @@ -263,6 +264,7 @@ features = [
"find_many",
"string_encoding",
"ipc",
"index_of",
"search_sorted",
"unique_counts",
"dtype-u8",
Expand Down
61 changes: 61 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/index_of.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
use polars_ops::series::index_of as index_of_op;

use super::*;

/// Look up the position of a single value (the one-element second column)
/// inside the first column.
///
/// When the first column carries a sorted flag, the needle is non-null and the
/// dtype is not decimal, a binary search is used; otherwise the generic
/// `index_of` operation does the search. The result is a one-row scalar column
/// of `IDX_DTYPE` named after the searched series, holding null when the value
/// was not found.
///
/// Returns an `InvalidOperation` error when the needle column does not contain
/// exactly one value.
pub(super) fn index_of(s: &mut [Column]) -> PolarsResult<Column> {
    // Keeps the single-value series alive when the haystack is a scalar column.
    let single_value_series;
    let series = match &s[0] {
        Column::Scalar(sc) => {
            // We only care about the first value:
            single_value_series = sc.as_single_value_series();
            &single_value_series
        },
        other => other.as_materialized_series(),
    };

    let needle_s = &s[1];
    polars_ensure!(
        needle_s.len() == 1,
        InvalidOperation: "needle of `index_of` can only contain a single value, found {} values",
        needle_s.len()
    );
    let needle_value = needle_s.get(0).unwrap().into_static();
    let needle = Scalar::new(needle_s.dtype().clone(), needle_value);

    let sorted = series.is_sorted_flag();
    // Binary search needs a sorted haystack and a non-null needle, and
    // search_sorted() doesn't support decimals at the moment.
    let can_binary_search = matches!(sorted, IsSorted::Ascending | IsSorted::Descending)
        && !needle.is_null()
        && !series.dtype().is_decimal();

    let found = if can_binary_search {
        let insertion_point = search_sorted(
            series,
            needle_s.as_materialized_series(),
            SearchSortedSide::Left,
            IsSorted::Descending == sorted,
        )?
        .get(0);
        // search_sorted() gives an index even if it's not an exact match, so
        // confirm that the element at that index really is the needle.
        insertion_point
            .map(|idx| idx as usize)
            .filter(|&idx| matches!(series.get(idx), Ok(v) if v == needle.as_any_value()))
    } else {
        index_of_op(series, needle)?
    };

    let av = found.map_or(AnyValue::Null, |idx| AnyValue::from(idx as IdxSize));
    Ok(Column::new_scalar(
        series.name().clone(),
        Scalar::new(IDX_DTYPE, av),
        1,
    ))
}
12 changes: 12 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ mod ewm_by;
mod fill_null;
#[cfg(feature = "fused")]
mod fused;
#[cfg(feature = "index_of")]
mod index_of;
mod list;
#[cfg(feature = "log")]
mod log;
Expand Down Expand Up @@ -154,6 +156,8 @@ pub enum FunctionExpr {
Hash(u64, u64, u64, u64),
#[cfg(feature = "arg_where")]
ArgWhere,
#[cfg(feature = "index_of")]
IndexOf,
#[cfg(feature = "search_sorted")]
SearchSorted(SearchSortedSide),
#[cfg(feature = "range")]
Expand Down Expand Up @@ -395,6 +399,8 @@ impl Hash for FunctionExpr {
#[cfg(feature = "business")]
Business(f) => f.hash(state),
Pow(f) => f.hash(state),
#[cfg(feature = "index_of")]
IndexOf => {},
#[cfg(feature = "search_sorted")]
SearchSorted(f) => f.hash(state),
#[cfg(feature = "random")]
Expand Down Expand Up @@ -640,6 +646,8 @@ impl Display for FunctionExpr {
Hash(_, _, _, _) => "hash",
#[cfg(feature = "arg_where")]
ArgWhere => "arg_where",
#[cfg(feature = "index_of")]
IndexOf => "index_of",
#[cfg(feature = "search_sorted")]
SearchSorted(_) => "search_sorted",
#[cfg(feature = "range")]
Expand Down Expand Up @@ -929,6 +937,10 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn ColumnsUdf>> {
ArgWhere => {
wrap!(arg_where::arg_where)
},
#[cfg(feature = "index_of")]
IndexOf => {
map_as_slice!(index_of::index_of)
},
#[cfg(feature = "search_sorted")]
SearchSorted(side) => {
map_as_slice!(search_sorted::search_sorted_impl, side)
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ impl FunctionExpr {
Hash(..) => mapper.with_dtype(DataType::UInt64),
#[cfg(feature = "arg_where")]
ArgWhere => mapper.with_dtype(IDX_DTYPE),
#[cfg(feature = "index_of")]
IndexOf => mapper.with_dtype(IDX_DTYPE),
#[cfg(feature = "search_sorted")]
SearchSorted(_) => mapper.with_dtype(IDX_DTYPE),
#[cfg(feature = "range")]
Expand Down
16 changes: 16 additions & 0 deletions crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,22 @@ impl Expr {
)
}

#[cfg(feature = "index_of")]
/// Find the index of a value.
///
/// Builds an `IndexOf` function expression whose inputs are `self` followed by
/// `element`. The expression is flagged as returning a scalar and uses the
/// `FirstArgLossless` casting rule.
pub fn index_of<E: Into<Expr>>(self, element: E) -> Expr {
    let needle: Expr = element.into();
    let options = FunctionOptions {
        flags: FunctionFlags::default() | FunctionFlags::RETURNS_SCALAR,
        fmt_str: "index_of",
        cast_options: Some(CastingRules::FirstArgLossless),
        ..Default::default()
    };
    Expr::Function {
        input: vec![self, needle],
        function: FunctionExpr::IndexOf,
        options,
    }
}

#[cfg(feature = "search_sorted")]
/// Find indices where elements should be inserted to maintain order.
pub fn search_sorted<E: Into<Expr>>(self, element: E, side: SearchSortedSide) -> Expr {
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ repeat_by = ["polars/repeat_by"]

streaming = ["polars/streaming"]
meta = ["polars/meta"]
index_of = ["polars/index_of"]
search_sorted = ["polars/search_sorted"]
decompress = ["polars/decompress-fast"]
regex = ["polars/regex"]
Expand Down Expand Up @@ -211,6 +212,7 @@ operations = [
"asof_join",
"cross_join",
"pct_change",
"index_of",
"search_sorted",
"merge_sorted",
"top_k",
Expand Down
6 changes: 6 additions & 0 deletions crates/polars-python/src/expr/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,13 +318,19 @@ impl PyExpr {
self.inner.clone().arg_min().into()
}

#[cfg(feature = "index_of")]
fn index_of(&self, element: Self) -> Self {
    // Clone the wrapped expression and build an `index_of` expression on it.
    let haystack = self.inner.clone();
    haystack.index_of(element.inner).into()
}

#[cfg(feature = "search_sorted")]
fn search_sorted(&self, element: Self, side: Wrap<SearchSortedSide>) -> Self {
    // Clone the wrapped expression and build a `search_sorted` expression on
    // it, unwrapping the side argument.
    let haystack = self.inner.clone();
    haystack.search_sorted(element.inner, side.0).into()
}

fn gather(&self, idx: Self) -> Self {
    // Clone the wrapped expression and build a `gather` expression on it.
    let source = self.inner.clone();
    source.gather(idx.inner).into()
}
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-python/src/lazyframe/visitor/expr_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1099,6 +1099,8 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
("hash", seed, seed_1, seed_2, seed_3).into_py_any(py)
},
FunctionExpr::ArgWhere => ("argwhere",).into_py_any(py),
#[cfg(feature = "index_of")]
FunctionExpr::IndexOf => ("index_of",).into_py_any(py),
#[cfg(feature = "search_sorted")]
FunctionExpr::SearchSorted(side) => (
"search_sorted",
Expand Down
Loading

0 comments on commit 785bb1e

Please sign in to comment.