From 21cb2b574b9e257de92ab321a582213b96a22d24 Mon Sep 17 00:00:00 2001 From: Colin Date: Thu, 25 Jan 2024 12:30:33 -0800 Subject: [PATCH] [FEAT] is_in expression (#1811) Closes #993 The `is_in` expression checks whether the values of a series are contained in a given list of items, and produces a series of boolean values as the results of this membership test. Changes: - Added a Literal Series so that Series can be passed into the expression - Added `is_in` expression and kernel - Added tests --- Cargo.lock | 1 + daft/daft.pyi | 2 + daft/expressions/expressions.py | 25 ++- daft/series.py | 18 +- daft/table/table.py | 15 +- daft/utils.py | 11 ++ src/daft-core/Cargo.toml | 1 + src/daft-core/src/array/from_iter.rs | 16 +- src/daft-core/src/array/ops/is_in.rs | 93 ++++++++++ src/daft-core/src/array/ops/mod.rs | 6 + src/daft-core/src/datatypes/binary_ops.rs | 7 + src/daft-core/src/series/mod.rs | 11 +- src/daft-core/src/series/ops/is_in.rs | 53 ++++++ src/daft-core/src/series/ops/mod.rs | 34 ++++ src/daft-core/src/utils/display_table.rs | 15 ++ .../src/utils/hashable_float_wrapper.rs | 58 +++++- src/daft-dsl/src/expr.rs | 14 ++ src/daft-dsl/src/lib.rs | 1 + src/daft-dsl/src/lit.rs | 30 +++- src/daft-dsl/src/optimization.rs | 1 + src/daft-dsl/src/python.rs | 11 ++ src/daft-dsl/src/treenode.rs | 5 + src/daft-plan/src/logical_ops/project.rs | 13 ++ src/daft-plan/src/physical_ops/project.rs | 11 ++ src/daft-table/src/lib.rs | 3 + tests/expressions/test_expressions.py | 8 + tests/table/test_is_in.py | 169 ++++++++++++++++++ 27 files changed, 599 insertions(+), 33 deletions(-) create mode 100644 src/daft-core/src/array/ops/is_in.rs create mode 100644 src/daft-core/src/series/ops/is_in.rs create mode 100644 tests/table/test_is_in.py diff --git a/Cargo.lock b/Cargo.lock index e6a439995f..bc4c010136 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1153,6 +1153,7 @@ dependencies = [ "html-escape", "image", "indexmap 2.1.0", + "itertools", "lazy_static", "log", "mur3", diff --git a/daft/daft.pyi b/daft/daft.pyi index 080ae5694c..89348cc66f 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -839,6 +839,7 @@ class PyExpr: def __ne__(self, other: PyExpr) -> PyExpr: ... # type: ignore[override] def is_null(self) -> PyExpr: ... def not_null(self) -> PyExpr: ... + def is_in(self, other: PyExpr) -> PyExpr: ... def name(self) -> str: ... def to_field(self, schema: PySchema) -> PyField: ... def __repr__(self) -> str: ... @@ -879,6 +880,7 @@ def col(name: str) -> PyExpr: ... def lit(item: Any) -> PyExpr: ... def date_lit(item: int) -> PyExpr: ... def timestamp_lit(item: int, tu: PyTimeUnit, tz: str | None) -> PyExpr: ... +def series_lit(item: PySeries) -> PyExpr: ... def udf(func: Callable, expressions: list[PyExpr], return_dtype: PyDataType) -> PyExpr: ... class PySeries: diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 85d846b37d..4ad6f9ab22 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -3,7 +3,7 @@ import builtins import sys from datetime import date, datetime -from typing import TYPE_CHECKING, Callable, Iterable, Iterator, TypeVar, overload +from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, TypeVar, overload import pyarrow as pa @@ -13,11 +13,13 @@ from daft.daft import col as _col from daft.daft import date_lit as _date_lit from daft.daft import lit as _lit +from daft.daft import series_lit as _series_lit from daft.daft import timestamp_lit as _timestamp_lit from daft.daft import udf as _udf from daft.datatype import DataType, TimeUnit from daft.expressions.testing import expr_structurally_equal from daft.logical.schema import Field, Schema +from daft.series import Series, item_to_series if sys.version_info < (3, 8): from typing_extensions import Literal @@ -51,6 +53,8 @@ def lit(value: object) -> Expression: # pyo3 date (PyDate) is not available when running in abi3 mode, workaround epoch_time = value - date(1970, 1, 1) lit_value = _date_lit(epoch_time.days) + elif isinstance(value, Series): + lit_value = _series_lit(value._series) else: lit_value = _lit(value) return Expression._from_pyexpr(lit_value) @@ -392,6 +396,24 @@ def not_null(self) -> Expression: expr = self._expr.not_null() return Expression._from_pyexpr(expr) + def is_in(self, other: Any) -> Expression: + """Checks if values in the Expression are in the provided list + + Example: + >>> # [1, 2, 3] -> [True, False, True] + >>> col("x").is_in([1, 3]) + + Returns: + Expression: Boolean Expression indicating whether values are in the provided list + """ + + if not isinstance(other, Expression): + series = item_to_series("items", other) + other = Expression._to_expression(series) + + expr = self._expr.is_in(other._expr) + return Expression._from_pyexpr(expr) + def name(self) -> builtins.str: return self._expr.name() @@ -469,7 +491,6 @@ def download( Expression: a Binary expression which is the bytes contents of the URL, or None if an error occured during download """ if use_native_downloader: - raise_on_error = False if on_error == "raise": raise_on_error = True diff --git a/daft/series.py b/daft/series.py index ee091adc10..ffcaa8d74e 100644 --- a/daft/series.py +++ b/daft/series.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TypeVar +from typing import Any, TypeVar import pyarrow as pa @@ -515,6 +515,22 @@ def _debug_bincode_deserialize(cls, b: bytes) -> Series: return Series._from_pyseries(PySeries._debug_bincode_deserialize(b)) +def item_to_series(name: str, item: Any) -> Series: + if isinstance(item, list): + series = Series.from_pylist(item, name) + elif _NUMPY_AVAILABLE and isinstance(item, np.ndarray): + series = Series.from_numpy(item, name) + elif isinstance(item, Series): + series = item + elif isinstance(item, (pa.Array, pa.ChunkedArray)): + series = Series.from_arrow(item, name) + elif _PANDAS_AVAILABLE and isinstance(item, pd.Series): + series = Series.from_pandas(item, name) + else: + raise ValueError(f"Creating a Series from data of type {type(item)} not implemented") + return series + + SomeSeriesNamespace = TypeVar("SomeSeriesNamespace", bound="SeriesNamespace") diff --git a/daft/table/table.py b/daft/table/table.py index b7b09add90..6f7e0e972c 100644 --- a/daft/table/table.py +++ b/daft/table/table.py @@ -27,7 +27,7 @@ from daft.datatype import DataType, TimeUnit from daft.expressions import Expression, ExpressionsProjection from daft.logical.schema import Schema -from daft.series import Series +from daft.series import Series, item_to_series _NUMPY_AVAILABLE = True try: @@ -148,18 +148,7 @@ def from_pandas(pd_df: pd.DataFrame) -> Table: def from_pydict(data: dict) -> Table: series_dict = dict() for k, v in data.items(): - if isinstance(v, list): - series = Series.from_pylist(v, name=k) - elif _NUMPY_AVAILABLE and isinstance(v, np.ndarray): - series = Series.from_numpy(v, name=k) - elif isinstance(v, Series): - series = v - elif isinstance(v, (pa.Array, pa.ChunkedArray)): - series = Series.from_arrow(v, name=k) - elif _PANDAS_AVAILABLE and isinstance(v, pd.Series): - series = Series.from_pandas(v, name=k) - else: - raise ValueError(f"Creating a Series from data of type {type(v)} not implemented") + series = item_to_series(k, v) series_dict[k] = series._series return Table._from_pytable(_PyTable.from_pylist_series(series_dict)) diff --git a/daft/utils.py b/daft/utils.py index 796587aafc..a8efc9bf2e 100644 --- a/daft/utils.py +++ b/daft/utils.py @@ -87,6 +87,17 @@ def map_operator_arrow_semantics_bool( ] +def python_list_membership_check( + left_pylist: list, + right_pylist: list, +) -> list: + try: + right_pyset = set(right_pylist) + return [elem in right_pyset for elem in left_pylist] + except TypeError: + return [elem in right_pylist for elem in left_pylist] + + def map_operator_arrow_semantics( operator: Callable[[Any, Any], Any], left_pylist: list, diff --git a/src/daft-core/Cargo.toml b/src/daft-core/Cargo.toml index bae1484074..83be2e57ad 100644 --- a/src/daft-core/Cargo.toml +++ b/src/daft-core/Cargo.toml @@ -10,6 +10,7 @@ dyn-clone = "1.0.16" fnv = "1.0.7" html-escape = {workspace = true} indexmap = {workspace = true, features = ["serde"]} +itertools = {workspace = true} lazy_static = {workspace = true} log = {workspace = true} mur3 = "0.1.0" diff --git a/src/daft-core/src/array/from_iter.rs b/src/daft-core/src/array/from_iter.rs index 7fb3b5ad21..2440d8a0a4 100644 --- a/src/daft-core/src/array/from_iter.rs +++ b/src/daft-core/src/array/from_iter.rs @@ -1,4 +1,4 @@ -use crate::datatypes::{BinaryArray, DaftNumericType, Field, Utf8Array}; +use crate::datatypes::{BinaryArray, BooleanArray, DaftNumericType, Field, Utf8Array}; use super::DataArray; @@ -41,3 +41,17 @@ impl BinaryArray { .unwrap() } } + +impl BooleanArray { + pub fn from_iter( + name: &str, + iter: impl Iterator> + arrow2::trusted_len::TrustedLen, + ) -> Self { + let arrow_array = Box::new(arrow2::array::BooleanArray::from_trusted_len_iter(iter)); + DataArray::new( + Field::new(name, crate::DataType::Boolean).into(), + arrow_array, + ) + .unwrap() + } +} diff --git a/src/daft-core/src/array/ops/is_in.rs b/src/daft-core/src/array/ops/is_in.rs new file mode 100644 index 0000000000..725fad4ea2 --- /dev/null +++ b/src/daft-core/src/array/ops/is_in.rs @@ -0,0 +1,93 @@ +use crate::{ + array::DataArray, + datatypes::{ + BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, Float32Array, Float64Array, + NullArray, Utf8Array, + }, + DataType, +}; + +use super::as_arrow::AsArrow; +use super::{full::FullNull, DaftIsIn}; +use crate::utils::hashable_float_wrapper::FloatWrapper; +use common_error::DaftResult; +use std::collections::{BTreeSet, HashSet}; + +macro_rules! collect_to_set_and_check_membership { + ($self:expr, $rhs:expr) => {{ + let set = $rhs + .as_arrow() + .iter() + .filter_map(|item| item) + .collect::>(); + let result = $self + .as_arrow() + .iter() + .map(|option| option.and_then(|value| Some(set.contains(&value)))); + Ok(BooleanArray::from_iter($self.name(), result)) + }}; +} + +impl DaftIsIn<&DataArray> for DataArray +where + T: DaftIntegerType, + ::Native: Ord, + ::Native: std::hash::Hash, + ::Native: std::cmp::Eq, +{ + type Output = DaftResult; + + fn is_in(&self, rhs: &DataArray) -> Self::Output { + collect_to_set_and_check_membership!(self, rhs) + } +} + +macro_rules! impl_is_in_floating_array { + ($arr:ident, $T:ident) => { + impl DaftIsIn<&$arr> for $arr { + type Output = DaftResult; + + fn is_in(&self, rhs: &$arr) -> Self::Output { + let set = rhs + .as_arrow() + .iter() + .filter_map(|item| item.map(|value| FloatWrapper(*value))) + .collect::>>(); + let result = self.as_arrow().iter().map(|option| { + option.and_then(|value| Some(set.contains(&FloatWrapper(*value)))) + }); + Ok(BooleanArray::from_iter(self.name(), result)) + } + } + }; +} +impl_is_in_floating_array!(Float32Array, f32); +impl_is_in_floating_array!(Float64Array, f64); + +macro_rules! impl_is_in_non_numeric_array { + ($arr:ident) => { + impl DaftIsIn<&$arr> for $arr { + type Output = DaftResult; + + fn is_in(&self, rhs: &$arr) -> Self::Output { + collect_to_set_and_check_membership!(self, rhs) + } + } + }; +} +impl_is_in_non_numeric_array!(BooleanArray); +impl_is_in_non_numeric_array!(Utf8Array); +impl_is_in_non_numeric_array!(BinaryArray); + +impl DaftIsIn<&NullArray> for NullArray { + type Output = DaftResult; + + fn is_in(&self, _rhs: &NullArray) -> Self::Output { + // If self and rhs are null array then return a full null array + Ok(BooleanArray::full_null( + self.name(), + &DataType::Boolean, + self.len(), + )) + } +} diff --git a/src/daft-core/src/array/ops/mod.rs b/src/daft-core/src/array/ops/mod.rs index 3608629f93..ba69e727b6 100644 --- a/src/daft-core/src/array/ops/mod.rs +++ b/src/daft-core/src/array/ops/mod.rs @@ -21,6 +21,7 @@ pub(crate) mod groups; mod hash; mod if_else; pub(crate) mod image; +mod is_in; mod len; mod list; mod list_agg; @@ -78,6 +79,11 @@ pub trait DaftLogical { fn xor(&self, rhs: Rhs) -> Self::Output; } +pub trait DaftIsIn { + type Output; + fn is_in(&self, rhs: Rhs) -> Self::Output; +} + pub trait DaftIsNull { type Output; fn is_null(&self) -> Self::Output; diff --git a/src/daft-core/src/datatypes/binary_ops.rs b/src/daft-core/src/datatypes/binary_ops.rs index 7ef1818c30..f20f9fe662 100644 --- a/src/daft-core/src/datatypes/binary_ops.rs +++ b/src/daft-core/src/datatypes/binary_ops.rs @@ -64,6 +64,13 @@ impl DataType { )) }) } + pub fn membership_op( + &self, + other: &Self, + ) -> DaftResult<(DataType, Option, DataType)> { + // membership checks (is_in) use equality checks, so we can use the same logic as comparison ops. + self.comparison_op(other) + } } impl Add for &DataType { diff --git a/src/daft-core/src/series/mod.rs b/src/daft-core/src/series/mod.rs index da266eb42e..1722014e0c 100644 --- a/src/daft-core/src/series/mod.rs +++ b/src/daft-core/src/series/mod.rs @@ -10,7 +10,7 @@ use std::{ }; use crate::{ - array::ops::{from_arrow::FromArrow, full::FullNull}, + array::ops::{from_arrow::FromArrow, full::FullNull, DaftCompare}, datatypes::{DataType, Field, FieldRef}, utils::display_table::make_comfy_table, with_match_daft_types, @@ -26,6 +26,15 @@ pub struct Series { pub inner: Arc, } +impl PartialEq for Series { + fn eq(&self, other: &Self) -> bool { + match self.equal(other) { + Ok(arr) => arr.into_iter().all(|x| x.unwrap_or(false)), + Err(_) => false, + } + } +} + impl Series { pub fn to_arrow(&self) -> Box { self.inner.to_arrow() diff --git a/src/daft-core/src/series/ops/is_in.rs b/src/daft-core/src/series/ops/is_in.rs new file mode 100644 index 0000000000..06e2f64937 --- /dev/null +++ b/src/daft-core/src/series/ops/is_in.rs @@ -0,0 +1,53 @@ +use common_error::DaftResult; + +use crate::{ + array::ops::DaftIsIn, datatypes::BooleanArray, with_match_comparable_daft_types, DataType, + IntoSeries, Series, +}; + +#[cfg(feature = "python")] +use crate::series::ops::py_membership_op_utilfn; + +fn default(name: &str, size: usize) -> DaftResult { + Ok(BooleanArray::from((name, vec![false; size].as_slice())).into_series()) +} + +impl Series { + pub fn is_in(&self, items: &Self) -> DaftResult { + if items.is_empty() { + return default(self.name(), self.len()); + } + + let (output_type, intermediate, comp_type) = + match self.data_type().membership_op(items.data_type()) { + Ok(types) => types, + Err(_) => return default(self.name(), self.len()), + }; + + let (lhs, rhs) = if let Some(ref it) = intermediate { + (self.cast(it)?, items.cast(it)?) + } else { + (self.clone(), items.clone()) + }; + + if let DataType::Boolean = output_type { + match comp_type { + #[cfg(feature = "python")] + DataType::Python => Ok(py_membership_op_utilfn(self, items)? + .downcast::()? + .clone() + .into_series()), + _ => with_match_comparable_daft_types!(comp_type, |$T| { + let casted_lhs = lhs.cast(&comp_type)?; + let casted_rhs = rhs.cast(&comp_type)?; + let lhs = casted_lhs.downcast::<<$T as DaftDataType>::ArrayType>()?; + let rhs = casted_rhs.downcast::<<$T as DaftDataType>::ArrayType>()?; + + Ok(lhs.is_in(rhs)?.into_series()) + }), + } + } else { + unreachable!() + } + } +} diff --git a/src/daft-core/src/series/ops/mod.rs b/src/daft-core/src/series/ops/mod.rs index 7e3386dc00..e328be0cd3 100644 --- a/src/daft-core/src/series/ops/mod.rs +++ b/src/daft-core/src/series/ops/mod.rs @@ -18,6 +18,7 @@ pub mod groups; pub mod hash; pub mod if_else; pub mod image; +pub mod is_in; pub mod len; pub mod list; pub mod not; @@ -88,3 +89,36 @@ macro_rules! py_binary_op_utilfn { } #[cfg(feature = "python")] pub(super) use py_binary_op_utilfn; + +#[cfg(feature = "python")] +pub(super) fn py_membership_op_utilfn(lhs: &Series, rhs: &Series) -> DaftResult { + use crate::python::PySeries; + use crate::DataType; + use pyo3::prelude::*; + + let lhs_casted = lhs.cast(&DataType::Python)?; + let rhs_casted = rhs.cast(&DataType::Python)?; + + let left_pylist = PySeries::from(lhs_casted.clone()).to_pylist()?; + let right_pylist = PySeries::from(rhs_casted.clone()).to_pylist()?; + + let result_series: Series = Python::with_gil(|py| -> PyResult { + let result_pylist = PyModule::import(py, pyo3::intern!(py, "daft.utils"))? + .getattr(pyo3::intern!(py, "python_list_membership_check"))? + .call1((left_pylist, right_pylist))?; + + PyModule::import(py, pyo3::intern!(py, "daft.series"))? + .getattr(pyo3::intern!(py, "Series"))? + .getattr(pyo3::intern!(py, "from_pylist"))? + .call1(( + result_pylist, + lhs_casted.name(), + pyo3::intern!(py, "disallow"), + ))? + .getattr(pyo3::intern!(py, "_series"))? + .extract() + })? + .into(); + + Ok(result_series) +} diff --git a/src/daft-core/src/utils/display_table.rs b/src/daft-core/src/utils/display_table.rs index 3c59cbe1c8..f16dbd1344 100644 --- a/src/daft-core/src/utils/display_table.rs +++ b/src/daft-core/src/utils/display_table.rs @@ -3,6 +3,8 @@ use crate::{ Series, }; +use itertools::Itertools; + pub fn display_date32(val: i32) -> String { let epoch_date = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); let date = if val.is_positive() { @@ -35,6 +37,19 @@ pub fn display_timestamp(val: i64, unit: &TimeUnit, timezone: &Option) - ) } +pub fn display_series_literal(series: &Series) -> String { + if !series.is_empty() { + format!( + "[{}]", + (0..series.len()) + .map(|i| series.str_value(i).unwrap()) + .join(", ") + ) + } else { + "[]".to_string() + } +} + pub fn make_comfy_table>( fields: &[F], columns: Option<&[&Series]>, diff --git a/src/daft-core/src/utils/hashable_float_wrapper.rs b/src/daft-core/src/utils/hashable_float_wrapper.rs index cd1ddd716f..78280715f9 100644 --- a/src/daft-core/src/utils/hashable_float_wrapper.rs +++ b/src/daft-core/src/utils/hashable_float_wrapper.rs @@ -1,13 +1,53 @@ -use std::hash::{Hash, Hasher}; +use std::{ + cmp::Ordering, + hash::{Hash, Hasher}, +}; -// An f64 newtype wrapper that implements basic hashability. -pub struct FloatWrapper(pub f64); +// An float newtype wrapper that implements basic hashability. +pub struct FloatWrapper(pub T); -impl Hash for FloatWrapper { - fn hash(&self, state: &mut H) { - // This is a super basic hash function that could lead to e.g. different hashes for different - // NaN representations. Look to crates like https://docs.rs/ordered-float/latest/ordered_float/index.html - // for a more advanced Hash implementation, if we end up needing it. - state.write(&u64::from_ne_bytes(self.0.to_ne_bytes()).to_ne_bytes()) +macro_rules! impl_hash_for_float_wrapper { + ($T:ident, $UintEquivalent:ident) => { + impl Hash for FloatWrapper<$T> { + fn hash(&self, state: &mut H) { + // This is a super basic hash function that could lead to e.g. different hashes for different + // NaN representations. Look to crates like https://docs.rs/ordered-float/latest/ordered_float/index.html + // for a more advanced Hash implementation, if we end up needing it. + state.write(&$UintEquivalent::from_ne_bytes(self.0.to_ne_bytes()).to_ne_bytes()) + } + } + }; +} +impl_hash_for_float_wrapper!(f32, u32); +impl_hash_for_float_wrapper!(f64, u64); + +impl PartialEq for FloatWrapper { + fn eq(&self, other: &Self) -> bool { + self.0.eq(&other.0) + } +} +impl Eq for FloatWrapper {} +impl Eq for FloatWrapper {} + +impl PartialOrd for FloatWrapper { + fn partial_cmp(&self, other: &Self) -> Option { + self.0.partial_cmp(&other.0) } } +macro_rules! impl_ord_for_float_wrapper { + ($T:ident) => { + impl Ord for FloatWrapper<$T> { + fn cmp(&self, other: &Self) -> Ordering { + // This implementation of cmp considers NaNs to be equal to each other, and less than any other value. + match (self.0.is_nan(), other.0.is_nan()) { + (true, true) => Ordering::Equal, + (true, false) => Ordering::Less, + (false, true) => Ordering::Greater, + (false, false) => self.0.partial_cmp(&other.0).unwrap(), + } + } + } + }; +} +impl_ord_for_float_wrapper!(f32); +impl_ord_for_float_wrapper!(f64); diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index 0e6bcdc932..770af0a049 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -42,6 +42,7 @@ pub enum Expr { Not(ExprRef), IsNull(ExprRef), NotNull(ExprRef), + IsIn(ExprRef, ExprRef), Literal(lit::LiteralValue), IfElse { if_true: ExprRef, @@ -288,6 +289,10 @@ impl Expr { Expr::NotNull(self.clone().into()) } + pub fn is_in(&self, items: &Self) -> Self { + Expr::IsIn(self.clone().into(), items.clone().into()) + } + pub fn eq(&self, other: &Self) -> Self { binary_op(Operator::Eq, self, other) } @@ -347,6 +352,11 @@ impl Expr { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.not_null()")) } + IsIn(expr, items) => { + let child_id = expr.semantic_id(schema); + let items_id = items.semantic_id(schema); + FieldID::new(format!("{child_id}.is_in({items_id})")) + } Function { func, inputs } => { let inputs = inputs .iter() @@ -400,6 +410,7 @@ impl Expr { BinaryOp { left, right, .. } => { vec![left.clone(), right.clone()] } + IsIn(expr, items) => vec![expr.clone(), items.clone()], IfElse { if_true, if_false, @@ -428,6 +439,7 @@ impl Expr { } IsNull(expr) => Ok(Field::new(expr.name()?, DataType::Boolean)), NotNull(expr) => Ok(Field::new(expr.name()?, DataType::Boolean)), + IsIn(expr, ..) => Ok(Field::new(expr.name()?, DataType::Boolean)), Literal(value) => Ok(Field::new("literal", value.get_type())), Function { func, inputs } => func.to_field(inputs.as_slice(), schema, self), BinaryOp { op, left, right } => { @@ -510,6 +522,7 @@ impl Expr { Not(expr) => expr.name(), IsNull(expr) => expr.name(), NotNull(expr) => expr.name(), + IsIn(expr, ..) => expr.name(), Literal(..) => Ok("literal"), Function { func: _, inputs } => inputs.first().unwrap().name(), BinaryOp { @@ -563,6 +576,7 @@ impl Display for Expr { Not(expr) => write!(f, "not({expr})"), IsNull(expr) => write!(f, "is_null({expr})"), NotNull(expr) => write!(f, "not_null({expr})"), + IsIn(expr, items) => write!(f, "{expr} in {items}"), Literal(val) => write!(f, "lit({val})"), Function { func, inputs } => { write!(f, "{}(", func.fn_name())?; diff --git a/src/daft-dsl/src/lib.rs b/src/daft-dsl/src/lib.rs index 387acc53e9..78696a72e1 100644 --- a/src/daft-dsl/src/lib.rs +++ b/src/daft-dsl/src/lib.rs @@ -25,6 +25,7 @@ pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> { parent.add_wrapped(wrap_pyfunction!(python::lit))?; parent.add_wrapped(wrap_pyfunction!(python::date_lit))?; parent.add_wrapped(wrap_pyfunction!(python::timestamp_lit))?; + parent.add_wrapped(wrap_pyfunction!(python::series_lit))?; parent.add_wrapped(wrap_pyfunction!(python::udf))?; parent.add_wrapped(wrap_pyfunction!(python::eq))?; diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index fd231d48ac..ec8923cff9 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -1,9 +1,5 @@ -use std::{ - fmt::{Display, Formatter, Result}, - hash::{Hash, Hasher}, -}; - use crate::expr::Expr; + use daft_core::utils::hashable_float_wrapper::FloatWrapper; use daft_core::{array::ops::full::FullNull, datatypes::DataType}; use daft_core::{ @@ -12,9 +8,13 @@ use daft_core::{ TimeUnit, }, series::Series, - utils::display_table::{display_date32, display_timestamp}, + utils::display_table::{display_date32, display_series_literal, display_timestamp}, }; use serde::{Deserialize, Serialize}; +use std::{ + fmt::{Display, Formatter, Result}, + hash::{Hash, Hasher}, +}; #[cfg(feature = "python")] use crate::pyobject::DaftPyObject; @@ -57,6 +57,8 @@ pub enum LiteralValue { Date(i32), /// A 64-bit floating point number. Float64(f64), + /// A list + Series(Series), /// Python object. #[cfg(feature = "python")] Python(DaftPyObject), @@ -86,6 +88,13 @@ impl Hash for LiteralValue { } // Wrap float64 in hashable newtype. Float64(n) => FloatWrapper(*n).hash(state), + Series(series) => { + let hash_result = series.hash(None); + match hash_result { + Ok(hash) => hash.into_iter().for_each(|i| i.hash(state)), + Err(_) => panic!("Cannot hash series"), + } + } #[cfg(feature = "python")] Python(py_obj) => py_obj.hash(state), } @@ -108,6 +117,7 @@ impl Display for LiteralValue { Date(val) => write!(f, "{}", display_date32(*val)), Timestamp(val, tu, tz) => write!(f, "{}", display_timestamp(*val, tu, tz)), Float64(val) => write!(f, "{val:.1}"), + Series(series) => write!(f, "{}", display_series_literal(series)), #[cfg(feature = "python")] Python(pyobj) => write!(f, "PyObject({})", { use pyo3::prelude::*; @@ -137,6 +147,7 @@ impl LiteralValue { Date(_) => DataType::Date, Timestamp(_, tu, tz) => DataType::Timestamp(*tu, tz.clone()), Float64(_) => DataType::Float64, + Series(series) => series.data_type().clone(), #[cfg(feature = "python")] Python(_) => DataType::Python, } @@ -164,6 +175,7 @@ impl LiteralValue { TimestampArray::new(Field::new("literal", self.get_type()), physical).into_series() } Float64(val) => Float64Array::from(("literal", [*val].as_slice())).into_series(), + Series(series) => series.clone().rename("literal"), #[cfg(feature = "python")] Python(val) => PythonArray::from(("literal", vec![val.pyobject.clone()])).into_series(), }; @@ -204,6 +216,12 @@ impl<'a> Literal for &'a [u8] { } } +impl Literal for Series { + fn lit(self) -> Expr { + Expr::Literal(LiteralValue::Series(self)) + } +} + #[cfg(feature = "python")] impl Literal for pyo3::PyObject { fn lit(self) -> Expr { diff --git a/src/daft-dsl/src/optimization.rs b/src/daft-dsl/src/optimization.rs index cd7213bc1a..796289f5f7 100644 --- a/src/daft-dsl/src/optimization.rs +++ b/src/daft-dsl/src/optimization.rs @@ -30,6 +30,7 @@ pub fn requires_computation(e: &Expr) -> bool { | Expr::Not(..) | Expr::IsNull(..) | Expr::NotNull(..) + | Expr::IsIn { .. } | Expr::IfElse { .. } => true, } } diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 98b90ef888..eff9bc2738 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -3,6 +3,7 @@ use std::collections::HashSet; use std::hash::{Hash, Hasher}; use daft_core::python::datatype::PyTimeUnit; +use daft_core::python::PySeries; use serde::{Deserialize, Serialize}; use crate::{functions, optimization, Expr, LiteralValue}; @@ -39,6 +40,12 @@ pub fn timestamp_lit(val: i64, tu: PyTimeUnit, tz: Option) -> PyResult

PyResult { + let expr = Expr::Literal(LiteralValue::Series(series.series)); + Ok(expr.into()) +} + #[pyfunction] pub fn lit(item: &PyAny) -> PyResult { if item.is_instance_of::() { @@ -239,6 +246,10 @@ impl PyExpr { Ok(self.expr.not_null().into()) } + pub fn is_in(&self, other: &Self) -> PyResult { + Ok(self.expr.is_in(&other.expr).into()) + } + pub fn name(&self) -> PyResult<&str> { Ok(self.expr.name()?) } diff --git a/src/daft-dsl/src/treenode.rs b/src/daft-dsl/src/treenode.rs index 6bfde1b2d3..8e86abb4f5 100644 --- a/src/daft-dsl/src/treenode.rs +++ b/src/daft-dsl/src/treenode.rs @@ -26,6 +26,7 @@ impl TreeNode for Expr { } } BinaryOp { op: _, left, right } => vec![left.as_ref(), right.as_ref()], + IsIn(expr, items) => vec![expr.as_ref(), items.as_ref()], Column(_) | Literal(_) => vec![], Function { func: _, inputs } => inputs.iter().collect::>(), IfElse { @@ -70,6 +71,10 @@ impl TreeNode for Expr { Not(expr) => Not(transform(expr.as_ref().clone())?.into()), IsNull(expr) => IsNull(transform(expr.as_ref().clone())?.into()), NotNull(expr) => NotNull(transform(expr.as_ref().clone())?.into()), + IsIn(expr, items) => IsIn( + transform(expr.as_ref().clone())?.into(), + transform(items.as_ref().clone())?.into(), + ), IfElse { if_true, if_false, diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs index 9ae2890129..f05b3d79a0 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -254,6 +254,19 @@ fn replace_column_with_semantic_id( |_| e, ) } + Expr::IsIn(child, items) => { + let child = + replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema); + let items = + replace_column_with_semantic_id(items.clone(), subexprs_to_replace, schema); + if child.is_no() && items.is_no() { + Transformed::No(e) + } else { + Transformed::Yes( + Expr::IsIn(child.unwrap().clone(), items.unwrap().clone()).into(), + ) + } + } Expr::BinaryOp { op, left, right } => { let left = replace_column_with_semantic_id(left.clone(), subexprs_to_replace, schema); diff --git a/src/daft-plan/src/physical_ops/project.rs b/src/daft-plan/src/physical_ops/project.rs index f9c8617dcd..00dd5def53 100644 --- a/src/daft-plan/src/physical_ops/project.rs +++ b/src/daft-plan/src/physical_ops/project.rs @@ -162,6 +162,17 @@ impl Project { )?; Ok(Expr::NotNull(newchild.into())) } + Expr::IsIn(child, items) => { + let newchild = Self::translate_partition_spec_expr( + child.as_ref(), + old_colname_to_new_colname, + )?; + let newitems = Self::translate_partition_spec_expr( + items.as_ref(), + old_colname_to_new_colname, + )?; + Ok(Expr::IsIn(newchild.into(), newitems.into())) + } Expr::IfElse { if_true, if_false, diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 39f59c4eda..9cb4bb96a7 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -326,6 +326,9 @@ impl Table { Not(child) => !(self.eval_expression(child)?), IsNull(child) => self.eval_expression(child)?.is_null(), NotNull(child) => self.eval_expression(child)?.not_null(), + IsIn(child, items) => self + .eval_expression(child)? + .is_in(&self.eval_expression(items)?), BinaryOp { op, left, right } => { let lhs = self.eval_expression(left)?; let rhs = self.eval_expression(right)?; diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index 8dac113fc2..0884690963 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -9,6 +9,7 @@ from daft.datatype import DataType, TimeUnit from daft.expressions import col, lit from daft.expressions.testing import expr_structurally_equal +from daft.series import Series from daft.table import MicroPartition @@ -23,6 +24,7 @@ (b"a", DataType.binary()), (True, DataType.bool()), (None, DataType.null()), + (Series.from_pylist([1, 2, 3]), DataType.int64()), (date(2023, 1, 1), DataType.date()), (datetime(2023, 1, 1), DataType.timestamp(timeunit=TimeUnit.from_str("us"))), (datetime(2022, 1, 1, tzinfo=pytz.utc), DataType.timestamp(timeunit=TimeUnit.from_str("us"), timezone="UTC")), @@ -173,3 +175,9 @@ def test_datetime_lit_pre_epoch() -> None: d = lit(datetime(1950, 1, 1)) output = repr(d) assert output == "lit(1950-01-01T00:00:00.000000)" + + +def test_repr_series_lit() -> None: + s = lit(Series.from_pylist([1, 2, 3])) + output = repr(s) + assert output == "lit([1, 2, 3])" diff --git a/tests/table/test_is_in.py b/tests/table/test_is_in.py new file mode 100644 index 0000000000..a351937505 --- /dev/null +++ b/tests/table/test_is_in.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +import datetime + +import numpy as np +import pandas as pd +import pyarrow as pa +import pytest + +from daft import col +from daft.series import Series +from daft.table import MicroPartition + + +class CustomClass: + def __init__(self, x): + self.x = x + + def __eq__(self, other): + return self.x == other.x + + def __hash__(self): + return hash(self.x) + + +class CustomClassWithoutHash: + def __init__(self, x): + self.x = x + + def __eq__(self, other): + return self.x == other.x + + +@pytest.mark.parametrize( + "input,items,expected", + [ + pytest.param([None, None], [None], [None, None], id="NullColumn"), + pytest.param([True, False], [False], [False, True], id="BooleanColumn"), + pytest.param(["a", "b", "c", "d"], ["a", "b"], [True, True, False, False], id="StringColumn"), + pytest.param([b"a", b"b", b"c", b"d"], [b"a", b"b"], [True, True, False, False], id="BinaryColumn"), + pytest.param([-1, 2, 3, 4], [-1, 2], [True, True, False, False], id="IntColumn"), + pytest.param([-1.0, 2.0, 3.0, 4.0], [-1.0, 2.0], [True, True, False, False], id="FloatColumn"), + pytest.param( + [datetime.date.today(), datetime.date.today() - datetime.timedelta(days=1)], + [datetime.date.today()], + [True, False], + id="DateColumn", + ), + pytest.param( + [datetime.datetime(2022, 1, 1), datetime.datetime(2023, 1, 1)], + [datetime.datetime(2022, 1, 1)], + [True, False], + id="TimestampColumn", + ), + pytest.param([CustomClass(1), CustomClass(2)], [CustomClass(1)], [True, False], id="ObjectColumn"), + pytest.param( + [CustomClassWithoutHash(1), CustomClassWithoutHash(2)], + [CustomClassWithoutHash(1)], + [True, False], + id="ObjectWithoutHashColumn", + ), + ], +) +def test_table_expr_is_in_same_types(input, items, expected) -> None: + daft_table = MicroPartition.from_pydict({"input": input}) + daft_table = daft_table.eval_expression_list([col("input").is_in(items)]) + pydict = daft_table.to_pydict() + + assert pydict["input"] == expected + + +@pytest.mark.parametrize( + "input,items,expected", + [ + # Int + pytest.param([-1, 2, 3, 4], ["-1", "2"], [True, True, False, False], id="IntWithString"), + pytest.param([1, 2, 3, 4], [1.0, 2.0], [True, True, False, False], id="IntWithFloat"), + pytest.param([0, 1, 2, 3], [True], [False, True, False, False], id="IntWithBool"), + # Float + pytest.param([-1.0, 2.0, 3.0, 4.0], ["-1.0", "2.0"], [True, True, False, False], id="FloatWithString"), + pytest.param([1.0, 2.0, 3.0, 4.0], [1, 2], [True, True, False, False], id="FloatWithInt"), + pytest.param([0.0, 1.0, 2.0, 3.0], [True], [False, True, False, False], id="FloatWithBool"), + # String + pytest.param(["1", "2", "3", "4"], [1, 2], [True, True, False, False], id="StringWithInt"), + pytest.param(["1.0", "2.0", "3.0", "4.0"], [1.0, 2.0], [True, True, False, False], id="StringWithFloat"), + # Bool + pytest.param([True, False, None], [1, 0], [True, True, None], id="BoolWithInt"), + pytest.param([True, False, None], [1.0], [True, False, None], id="BoolWithFloat"), + # Date + pytest.param( + [datetime.date.today(), datetime.date.today() - datetime.timedelta(days=1)], + [datetime.datetime.today()], + [True, False], + id="DateWithTimestamp", + ), + # Timestamp + pytest.param( + [datetime.datetime(2022, 1, 1), datetime.datetime(2023, 1, 1)], + [datetime.date(2022, 1, 1)], + [True, False], + id="TimestampWithDate", + ), + ], +) +def test_table_expr_is_in_different_types_castable(input, items, expected) -> None: + daft_table = MicroPartition.from_pydict({"input": input}) + daft_table = daft_table.eval_expression_list([col("input").is_in(items)]) + pydict = daft_table.to_pydict() + + assert pydict["input"] == expected + + +@pytest.mark.parametrize( + "input,items,expected", + [ + pytest.param([None, None, None], [None], [None, None, None], id="NullColumn"), + pytest.param([True, False, None], [None], [False, False, None], id="BooleanColumn"), + pytest.param(["a", "b", None], [None], [False, False, None], id="StringColumn"), + pytest.param([b"a", b"b", None], [None], [False, False, None], id="BinaryColumn"), + pytest.param([1, 2, None], [None], [False, False, None], id="IntColumn"), + pytest.param([1.0, 2.0, None], [None], [False, False, None], id="FloatColumn"), + ], +) +def test_table_expr_is_in_items_is_none(input, items, expected) -> None: + daft_table = MicroPartition.from_pydict({"input": input}) + daft_table = daft_table.eval_expression_list([col("input").is_in(items)]) + pydict = daft_table.to_pydict() + + assert pydict["input"] == expected + + +@pytest.mark.parametrize( + "input,items,expected", + [ + pytest.param([1, 2, 3, 4], np.array([1, 2]), [True, True, False, False], id="NumpyArray"), + pytest.param([1, 2, 3, 4], pa.array([1, 2], type=pa.int8()), [True, True, False, False], id="PyArrowArray"), + pytest.param([1, 2, 3, 4], pd.Series([1, 2]), [True, True, False, False], id="PandasSeries"), + pytest.param([1, 2, 3, 4], Series.from_pylist([1, 2]), [True, True, False, False], id="DaftSeries"), + ], +) +def test_table_expr_is_in_different_input_types(input, items, expected) -> None: + daft_table = MicroPartition.from_pydict({"input": input}) + daft_table = daft_table.eval_expression_list([col("input").is_in(items)]) + pydict = daft_table.to_pydict() + + assert pydict["input"] == expected + + +def test_table_expr_is_in_with_another_df_column() -> None: + daft_table = MicroPartition.from_pydict({"input": [1, 2, 3, 4], "items": [3, 4, 5, 6]}) + daft_table = daft_table.eval_expression_list([col("input").is_in(col("items"))]) + pydict = daft_table.to_pydict() + + assert pydict["input"] == [False, False, True, True] + + +def test_table_expr_is_in_empty_items() -> None: + daft_table = MicroPartition.from_pydict({"input": [1, 2, 3, 4]}) + daft_table = daft_table.eval_expression_list([col("input").is_in([])]) + pydict = daft_table.to_pydict() + + assert pydict["input"] == [False, False, False, False] + + +def test_table_expr_is_in_items_invalid_input() -> None: + daft_table = MicroPartition.from_pydict({"input": [1, 2, 3, 4]}) + + with pytest.raises(ValueError, match="Creating a Series from data of type"): + daft_table.eval_expression_list([col("input").is_in(1)])