From b861f4c782d3dda88e776ddbefcd27906ea81a5d Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 14 Jun 2024 10:45:16 -0700 Subject: [PATCH] [FEAT] fill_nan and not_nan expressions (#2313) Adds expressions for `fill_nan` and `not_nan` Todo: - add expression for `fill_na`, which is a convenience method for doing fill_null and fill_nan together, see: https://github.com/Eventual-Inc/Daft/issues/571 --- daft/daft.pyi | 4 ++ daft/expressions/expressions.py | 42 +++++++++++++++ daft/series.py | 9 ++++ docs/source/api_docs/expressions.rst | 3 ++ src/daft-core/src/array/ops/float.rs | 32 +++++++++++- src/daft-core/src/array/ops/mod.rs | 5 ++ src/daft-core/src/python/series.rs | 8 +++ src/daft-core/src/series/ops/float.rs | 12 +++++ src/daft-dsl/src/functions/float/fill_nan.rs | 48 +++++++++++++++++ src/daft-dsl/src/functions/float/mod.rs | 24 +++++++++ src/daft-dsl/src/functions/float/not_nan.rs | 51 ++++++++++++++++++ src/daft-dsl/src/python.rs | 10 ++++ tests/expressions/test_expressions.py | 7 +++ tests/expressions/typing/test_float.py | 20 +++++++ tests/series/test_float.py | 52 +++++++++++++++++++ .../table/{test_fill_null.py => test_fill.py} | 29 ++++++++++- tests/table/test_filter.py | 7 +++ 17 files changed, 360 insertions(+), 3 deletions(-) create mode 100644 src/daft-dsl/src/functions/float/fill_nan.rs create mode 100644 src/daft-dsl/src/functions/float/not_nan.rs rename tests/table/{test_fill_null.py => test_fill.py} (60%) diff --git a/daft/daft.pyi b/daft/daft.pyi index 80e01e0b34..ba8ba7213d 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -1043,6 +1043,8 @@ class PyExpr: def __reduce__(self) -> tuple: ... def is_nan(self) -> PyExpr: ... def is_inf(self) -> PyExpr: ... + def not_nan(self) -> PyExpr: ... + def fill_nan(self, fill_value: PyExpr) -> PyExpr: ... def dt_date(self) -> PyExpr: ... def dt_day(self) -> PyExpr: ... def dt_hour(self) -> PyExpr: ... @@ -1209,6 +1211,8 @@ class PySeries: def utf8_substr(self, start: PySeries, length: PySeries | None = None) -> PySeries: ... def is_nan(self) -> PySeries: ... def is_inf(self) -> PySeries: ... + def not_nan(self) -> PySeries: ... + def fill_nan(self, fill_value: PySeries) -> PySeries: ... def dt_date(self) -> PySeries: ... def dt_day(self) -> PySeries: ... def dt_hour(self) -> PySeries: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 957b80692d..fc039b29cf 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -835,6 +835,48 @@ def is_inf(self) -> Expression: """ return Expression._from_pyexpr(self._expr.is_inf()) + def not_nan(self) -> Expression: + """Checks if values are not NaN (a special float value indicating not-a-number) + + .. NOTE:: + Nulls will be propagated! I.e. this operation will return a null for null values. + + Example: + >>> # [1., None, NaN] -> [True, None, False] + >>> col("x").not_nan() + + Returns: + Expression: Boolean Expression indicating whether values are not invalid. + """ + return Expression._from_pyexpr(self._expr.not_nan()) + + def fill_nan(self, fill_value: Expression) -> Expression: + """Fills NaN values in the Expression with the provided fill_value + + Example: + >>> df = daft.from_pydict({"data": [1.1, float("nan"), 3.3]}) + >>> df = df.with_column("filled", df["data"].float.fill_nan(2.2)) + >>> df.show() + ╭─────────┬─────────╮ + │ data ┆ filled │ + │ --- ┆ --- │ + │ Float64 ┆ Float64 │ + ╞═════════╪═════════╡ + │ 1.1 ┆ 1.1 │ + ├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤ + │ NaN ┆ 2.2 │ + ├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤ + │ 3.3 ┆ 3.3 │ + ╰─────────┴─────────╯ + + Returns: + Expression: Expression with Nan values filled with the provided fill_value + """ + + fill_value = Expression._to_expression(fill_value) + expr = self._expr.fill_nan(fill_value._expr) + return Expression._from_pyexpr(expr) + class ExpressionDatetimeNamespace(ExpressionNamespace): def date(self) -> Expression: diff --git a/daft/series.py b/daft/series.py index e36d64edfe..bf1c0105c8 100644 --- a/daft/series.py +++ b/daft/series.py @@ -638,6 +638,15 @@ def is_nan(self) -> Series: def is_inf(self) -> Series: return Series._from_pyseries(self._series.is_inf()) + def not_nan(self) -> Series: + return Series._from_pyseries(self._series.not_nan()) + + def fill_nan(self, fill_value: Series) -> Series: + if not isinstance(fill_value, Series): + raise ValueError(f"expected another Series but got {type(fill_value)}") + assert self._series is not None and fill_value._series is not None + return Series._from_pyseries(self._series.fill_nan(fill_value._series)) + class SeriesStringNamespace(SeriesNamespace): def endswith(self, suffix: Series) -> Series: diff --git a/docs/source/api_docs/expressions.rst b/docs/source/api_docs/expressions.rst index bff11275a4..a8c06cf6f3 100644 --- a/docs/source/api_docs/expressions.rst +++ b/docs/source/api_docs/expressions.rst @@ -29,6 +29,7 @@ Generic Expression.if_else Expression.is_null Expression.not_null + Expression.fill_null Expression.apply .. _api-numeric-expression-operations: @@ -160,6 +161,8 @@ The following methods are available under the ``expr.float`` attribute. Expression.float.is_inf Expression.float.is_nan + Expression.float.not_nan + Expression.float.fill_nan .. _api-expressions-temporal: diff --git a/src/daft-core/src/array/ops/float.rs b/src/daft-core/src/array/ops/float.rs index 7fddedd555..afd0838563 100644 --- a/src/daft-core/src/array/ops/float.rs +++ b/src/daft-core/src/array/ops/float.rs @@ -6,7 +6,7 @@ use common_error::DaftResult; use num_traits::Float; use super::DaftIsInf; -use super::DaftIsNan; +use super::{DaftIsNan, DaftNotNan}; use super::as_arrow::AsArrow; @@ -68,3 +68,33 @@ impl DaftIsInf for DataArray { ))) } } + +impl DaftNotNan for DataArray +where + T: DaftFloatType, + ::Native: Float, +{ + type Output = DaftResult>; + + fn not_nan(&self) -> Self::Output { + let arrow_array = self.as_arrow(); + let result_arrow_array = arrow2::array::BooleanArray::from_trusted_len_values_iter( + arrow_array.values_iter().map(|v| !v.is_nan()), + ) + .with_validity(arrow_array.validity().cloned()); + Ok(BooleanArray::from((self.name(), result_arrow_array))) + } +} + +impl DaftNotNan for DataArray { + type Output = DaftResult>; + + fn not_nan(&self) -> Self::Output { + // Entire array is null; since we don't consider nulls to be NaNs, return an all null (invalid) boolean array. + Ok(BooleanArray::from(( + self.name(), + arrow2::array::BooleanArray::from_slice(vec![false; self.len()]) + .with_validity(Some(arrow2::bitmap::Bitmap::from(vec![false; self.len()]))), + ))) + } +} diff --git a/src/daft-core/src/array/ops/mod.rs b/src/daft-core/src/array/ops/mod.rs index 49bfdec6cb..ad1e98f68c 100644 --- a/src/daft-core/src/array/ops/mod.rs +++ b/src/daft-core/src/array/ops/mod.rs @@ -125,6 +125,11 @@ pub trait DaftIsInf { fn is_inf(&self) -> Self::Output; } +pub trait DaftNotNan { + type Output; + fn not_nan(&self) -> Self::Output; +} + pub type VecIndices = Vec; pub type GroupIndices = Vec; pub type GroupIndicesPair = (VecIndices, GroupIndices); diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index 446e9822cd..b305a635aa 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -465,6 +465,14 @@ impl PySeries { Ok(self.series.is_inf()?.into()) } + pub fn not_nan(&self) -> PyResult { + Ok(self.series.not_nan()?.into()) + } + + pub fn fill_nan(&self, fill_value: &Self) -> PyResult { + Ok(self.series.fill_nan(&fill_value.series)?.into()) + } + pub fn dt_date(&self) -> PyResult { Ok(self.series.dt_date()?.into()) } diff --git a/src/daft-core/src/series/ops/float.rs b/src/daft-core/src/series/ops/float.rs index 890b024fd1..05af1c7a45 100644 --- a/src/daft-core/src/series/ops/float.rs +++ b/src/daft-core/src/series/ops/float.rs @@ -18,4 +18,16 @@ impl Series { Ok(DaftIsInf::is_inf(self.downcast::<<$T as DaftDataType>::ArrayType>()?)?.into_series()) }) } + + pub fn not_nan(&self) -> DaftResult { + use crate::array::ops::DaftNotNan; + with_match_float_and_null_daft_types!(self.data_type(), |$T| { + Ok(DaftNotNan::not_nan(self.downcast::<<$T as DaftDataType>::ArrayType>()?)?.into_series()) + }) + } + + pub fn fill_nan(&self, fill_value: &Self) -> DaftResult { + let predicate = self.not_nan()?; + self.if_else(fill_value, &predicate) + } } diff --git a/src/daft-dsl/src/functions/float/fill_nan.rs b/src/daft-dsl/src/functions/float/fill_nan.rs new file mode 100644 index 0000000000..9c6c44a93e --- /dev/null +++ b/src/daft-dsl/src/functions/float/fill_nan.rs @@ -0,0 +1,48 @@ +use daft_core::{ + datatypes::Field, schema::Schema, series::Series, utils::supertype::try_get_supertype, +}; + +use crate::ExprRef; + +use crate::functions::FunctionExpr; +use common_error::{DaftError, DaftResult}; + +use super::super::FunctionEvaluator; + +pub(super) struct FillNanEvaluator {} + +impl FunctionEvaluator for FillNanEvaluator { + fn fn_name(&self) -> &'static str { + "fill_nan" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + match inputs { + [data, fill_value] => match (data.to_field(schema), fill_value.to_field(schema)) { + (Ok(data_field), Ok(fill_value_field)) => { + match (&data_field.dtype.is_floating(), &fill_value_field.dtype.is_floating(), try_get_supertype(&data_field.dtype, &fill_value_field.dtype)) { + (true, true, Ok(dtype)) => Ok(Field::new(data_field.name, dtype)), + _ => Err(DaftError::TypeError(format!( + "Expects input to fill_nan to be float, but received {data_field} and {fill_value_field}", + ))), + } + } + (Err(e), _) | (_, Err(e)) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + match inputs { + [data, fill_value] => data.fill_nan(fill_value), + _ => Err(DaftError::ValueError(format!( + "Expected 2 input args, got {}", + inputs.len() + ))), + } + } +} diff --git a/src/daft-dsl/src/functions/float/mod.rs b/src/daft-dsl/src/functions/float/mod.rs index d5c9218b94..5ef47f004c 100644 --- a/src/daft-dsl/src/functions/float/mod.rs +++ b/src/daft-dsl/src/functions/float/mod.rs @@ -1,8 +1,12 @@ +mod fill_nan; mod is_inf; mod is_nan; +mod not_nan; +use fill_nan::FillNanEvaluator; use is_inf::IsInfEvaluator; use is_nan::IsNanEvaluator; +use not_nan::NotNanEvaluator; use serde::{Deserialize, Serialize}; use crate::{Expr, ExprRef}; @@ -13,6 +17,8 @@ use super::FunctionEvaluator; pub enum FloatExpr { IsNan, IsInf, + NotNan, + FillNan, } impl FloatExpr { @@ -22,6 +28,8 @@ impl FloatExpr { match self { IsNan => &IsNanEvaluator {}, IsInf => &IsInfEvaluator {}, + NotNan => &NotNanEvaluator {}, + FillNan => &FillNanEvaluator {}, } } } @@ -41,3 +49,19 @@ pub fn is_inf(data: ExprRef) -> ExprRef { } .into() } + +pub fn not_nan(data: ExprRef) -> ExprRef { + Expr::Function { + func: super::FunctionExpr::Float(FloatExpr::NotNan), + inputs: vec![data], + } + .into() +} + +pub fn fill_nan(data: ExprRef, fill_value: ExprRef) -> ExprRef { + Expr::Function { + func: super::FunctionExpr::Float(FloatExpr::FillNan), + inputs: vec![data, fill_value], + } + .into() +} diff --git a/src/daft-dsl/src/functions/float/not_nan.rs b/src/daft-dsl/src/functions/float/not_nan.rs new file mode 100644 index 0000000000..45d7e643a9 --- /dev/null +++ b/src/daft-dsl/src/functions/float/not_nan.rs @@ -0,0 +1,51 @@ +use daft_core::{ + datatypes::{DataType, Field}, + schema::Schema, + series::Series, +}; + +use crate::ExprRef; + +use crate::functions::FunctionExpr; +use common_error::{DaftError, DaftResult}; + +use super::super::FunctionEvaluator; + +pub(super) struct NotNanEvaluator {} + +impl FunctionEvaluator for NotNanEvaluator { + fn fn_name(&self) -> &'static str { + "not_nan" + } + + fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult { + match inputs { + [data] => match data.to_field(schema) { + Ok(data_field) => match &data_field.dtype { + // DataType::Float16 | + DataType::Float32 | DataType::Float64 => { + Ok(Field::new(data_field.name, DataType::Boolean)) + } + _ => Err(DaftError::TypeError(format!( + "Expects input to is_nan to be float, but received {data_field}", + ))), + }, + Err(e) => Err(e), + }, + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult { + match inputs { + [data] => data.not_nan(), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input args, got {}", + inputs.len() + ))), + } + } +} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 335168a253..29cb0b0952 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -480,6 +480,16 @@ impl PyExpr { Ok(is_inf(self.into()).into()) } + pub fn not_nan(&self) -> PyResult { + use functions::float::not_nan; + Ok(not_nan(self.into()).into()) + } + + pub fn fill_nan(&self, fill_value: &Self) -> PyResult { + use functions::float::fill_nan; + Ok(fill_nan(self.into(), fill_value.expr.clone()).into()) + } + pub fn dt_date(&self) -> PyResult { use functions::temporal::date; Ok(date(self.into()).into()) diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index 8b3bba34b8..589491b957 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -271,6 +271,13 @@ def test_float_is_inf() -> None: assert output == "is_inf(col(a))" +def test_float_not_nan() -> None: + a = col("a") + c = a.float.not_nan() + output = repr(c) + assert output == "not_nan(col(a))" + + def test_date_lit_post_epoch() -> None: d = lit(date(2022, 1, 1)) output = repr(d) diff --git a/tests/expressions/typing/test_float.py b/tests/expressions/typing/test_float.py index 8b641ae097..88d576764c 100644 --- a/tests/expressions/typing/test_float.py +++ b/tests/expressions/typing/test_float.py @@ -21,3 +21,23 @@ def test_float_is_inf(unary_data_fixture): run_kernel=unary_data_fixture.float.is_inf, resolvable=unary_data_fixture.datatype() in (DataType.float32(), DataType.float64()), ) + + +def test_float_not_nan(unary_data_fixture): + assert_typing_resolve_vs_runtime_behavior( + data=[unary_data_fixture], + expr=col(unary_data_fixture.name()).float.not_nan(), + run_kernel=unary_data_fixture.float.not_nan, + resolvable=unary_data_fixture.datatype() in (DataType.float32(), DataType.float64()), + ) + + +def test_fill_nan(binary_data_fixture): + lhs, rhs = binary_data_fixture + assert_typing_resolve_vs_runtime_behavior( + data=binary_data_fixture, + expr=col(lhs.name()).float.fill_nan(rhs), + run_kernel=lambda: lhs.float.fill_nan(rhs), + resolvable=lhs.datatype() in (DataType.float32(), DataType.float64()) + and rhs.datatype() in (DataType.float32(), DataType.float64()), + ) diff --git a/tests/series/test_float.py b/tests/series/test_float.py index 32296f0905..fdc0696c5d 100644 --- a/tests/series/test_float.py +++ b/tests/series/test_float.py @@ -52,3 +52,55 @@ def test_float_is_inf_all_null() -> None: s = Series.from_arrow(pa.array([None, None, None])) result = s.float.is_inf() assert result.to_pylist() == [None, None, None] + + +def test_float_not_nan() -> None: + s = Series.from_arrow(pa.array([1.0, np.nan, 3.0, float("nan")])) + result = s.float.not_nan() + assert result.to_pylist() == [True, False, True, False] + + +def test_float_not_nan_with_nulls() -> None: + s = Series.from_arrow(pa.array([1.0, None, np.nan, 3.0, None, float("nan")])) + result = s.float.not_nan() + assert result.to_pylist() == [True, None, False, True, None, False] + + +def test_float_not_nan_empty() -> None: + s = Series.from_arrow(pa.array([], type=pa.float64())) + result = s.float.not_nan() + assert result.to_pylist() == [] + + +def test_float_not_nan_all_null() -> None: + s = Series.from_arrow(pa.array([None, None, None])) + result = s.float.not_nan() + assert result.to_pylist() == [None, None, None] + + +def test_float_fill_nan() -> None: + s = Series.from_arrow(pa.array([1.0, np.nan, 3.0, float("nan")])) + fill = Series.from_arrow(pa.array([2.0])) + result = s.float.fill_nan(fill) + assert result.to_pylist() == [1.0, 2.0, 3.0, 2.0] + + +def test_float_fill_nan_with_nulls() -> None: + s = Series.from_arrow(pa.array([1.0, None, np.nan, 3.0, None, float("nan")])) + fill = Series.from_arrow(pa.array([2.0])) + result = s.float.fill_nan(fill) + assert result.to_pylist() == [1.0, None, 2.0, 3.0, None, 2.0] + + +def test_float_fill_nan_empty() -> None: + s = Series.from_arrow(pa.array([], type=pa.float64())) + fill = Series.from_arrow(pa.array([2.0])) + result = s.float.fill_nan(fill) + assert result.to_pylist() == [] + + +def test_float_fill_nan_all_null() -> None: + s = Series.from_arrow(pa.array([None, None, None])) + fill = Series.from_arrow(pa.array([2.0])) + result = s.float.fill_nan(fill) + assert result.to_pylist() == [None, None, None] diff --git a/tests/table/test_fill_null.py b/tests/table/test_fill.py similarity index 60% rename from tests/table/test_fill_null.py rename to tests/table/test_fill.py index a61d7601a4..fc873fb193 100644 --- a/tests/table/test_fill_null.py +++ b/tests/table/test_fill.py @@ -4,6 +4,7 @@ import pytest +from daft.datatype import DataType from daft.expressions.expressions import col from daft.table.micropartition import MicroPartition @@ -20,12 +21,20 @@ pytest.param( [datetime.date.today(), None, datetime.date(2023, 1, 1)], datetime.date(2022, 1, 1), - [datetime.date.today(), datetime.date(2022, 1, 1), datetime.date(2023, 1, 1)], + [ + datetime.date.today(), + datetime.date(2022, 1, 1), + datetime.date(2023, 1, 1), + ], ), pytest.param( [datetime.datetime(2022, 1, 1), None, datetime.datetime(2023, 1, 1)], datetime.datetime(2022, 1, 1), - [datetime.datetime(2022, 1, 1), datetime.datetime(2022, 1, 1), datetime.datetime(2023, 1, 1)], + [ + datetime.datetime(2022, 1, 1), + datetime.datetime(2022, 1, 1), + datetime.datetime(2023, 1, 1), + ], ), ], ) @@ -35,3 +44,19 @@ def test_table_expr_fill_null(input, fill_value, expected) -> None: pydict = daft_table.to_pydict() assert pydict["input"] == expected + + +@pytest.mark.parametrize( + "float_dtype", + [DataType.float32(), DataType.float64()], +) +def test_table_expr_fill_nan(float_dtype) -> None: + input = [1.0, None, 3.0, float("nan")] + fill_value = 2.0 + expected = [1.0, None, 3.0, 2.0] + + daft_table = MicroPartition.from_pydict({"input": input}) + daft_table = daft_table.eval_expression_list([col("input").cast(float_dtype).float.fill_nan(fill_value)]) + pydict = daft_table.to_pydict() + + assert pydict["input"] == expected diff --git a/tests/table/test_filter.py b/tests/table/test_filter.py index 37dc6b2ef5..81a345d752 100644 --- a/tests/table/test_filter.py +++ b/tests/table/test_filter.py @@ -219,6 +219,13 @@ def test_table_float_is_inf() -> None: assert result_table.to_pydict() == {"a": [True, False, None, True]} +def test_table_float_not_nan() -> None: + table = MicroPartition.from_pydict({"a": [1.0, np.nan, 3.0, None, float("nan")]}) + result_table = table.eval_expression_list([col("a").float.not_nan()]) + # Note that null entries are _not_ treated as float NaNs. + assert result_table.to_pydict() == {"a": [True, False, True, None, False]} + + def test_table_if_else() -> None: table = MicroPartition.from_arrow( pa.Table.from_pydict({"ones": [1, 1, 1], "zeros": [0, 0, 0], "pred": [True, False, None]})