Skip to content

Commit

Permalink
[FEAT] fill_nan and not_nan expressions (#2313)
Browse files Browse the repository at this point in the history
Adds expressions for `fill_nan` and `not_nan`

Todo: 
- add expression for `fill_na`, which is a convenience method for doing
fill_null and fill_nan together, see:
#571
  • Loading branch information
colin-ho authored Jun 14, 2024
1 parent fa6a413 commit b861f4c
Show file tree
Hide file tree
Showing 17 changed files with 360 additions and 3 deletions.
4 changes: 4 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,8 @@ class PyExpr:
def __reduce__(self) -> tuple: ...
def is_nan(self) -> PyExpr: ...
def is_inf(self) -> PyExpr: ...
def not_nan(self) -> PyExpr: ...
def fill_nan(self, fill_value: PyExpr) -> PyExpr: ...
def dt_date(self) -> PyExpr: ...
def dt_day(self) -> PyExpr: ...
def dt_hour(self) -> PyExpr: ...
Expand Down Expand Up @@ -1209,6 +1211,8 @@ class PySeries:
def utf8_substr(self, start: PySeries, length: PySeries | None = None) -> PySeries: ...
def is_nan(self) -> PySeries: ...
def is_inf(self) -> PySeries: ...
def not_nan(self) -> PySeries: ...
def fill_nan(self, fill_value: PySeries) -> PySeries: ...
def dt_date(self) -> PySeries: ...
def dt_day(self) -> PySeries: ...
def dt_hour(self) -> PySeries: ...
Expand Down
42 changes: 42 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,48 @@ def is_inf(self) -> Expression:
"""
return Expression._from_pyexpr(self._expr.is_inf())

def not_nan(self) -> Expression:
"""Checks if values are not NaN (a special float value indicating not-a-number)
.. NOTE::
Nulls will be propagated! I.e. this operation will return a null for null values.
Example:
>>> # [1., None, NaN] -> [True, None, False]
>>> col("x").not_nan()
Returns:
Expression: Boolean Expression indicating whether values are not invalid.
"""
return Expression._from_pyexpr(self._expr.not_nan())

def fill_nan(self, fill_value: Expression) -> Expression:
"""Fills NaN values in the Expression with the provided fill_value
Example:
>>> df = daft.from_pydict({"data": [1.1, float("nan"), 3.3]})
>>> df = df.with_column("filled", df["data"].float.fill_nan(2.2))
>>> df.show()
╭─────────┬─────────╮
│ data ┆ filled │
│ --- ┆ --- │
│ Float64 ┆ Float64 │
╞═════════╪═════════╡
│ 1.1 ┆ 1.1 │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤
│ NaN ┆ 2.2 │
├╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌┤
│ 3.3 ┆ 3.3 │
╰─────────┴─────────╯
Returns:
Expression: Expression with Nan values filled with the provided fill_value
"""

fill_value = Expression._to_expression(fill_value)
expr = self._expr.fill_nan(fill_value._expr)
return Expression._from_pyexpr(expr)


class ExpressionDatetimeNamespace(ExpressionNamespace):
def date(self) -> Expression:
Expand Down
9 changes: 9 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,15 @@ def is_nan(self) -> Series:
def is_inf(self) -> Series:
return Series._from_pyseries(self._series.is_inf())

def not_nan(self) -> Series:
return Series._from_pyseries(self._series.not_nan())

def fill_nan(self, fill_value: Series) -> Series:
if not isinstance(fill_value, Series):
raise ValueError(f"expected another Series but got {type(fill_value)}")
assert self._series is not None and fill_value._series is not None
return Series._from_pyseries(self._series.fill_nan(fill_value._series))


class SeriesStringNamespace(SeriesNamespace):
def endswith(self, suffix: Series) -> Series:
Expand Down
3 changes: 3 additions & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ Generic
Expression.if_else
Expression.is_null
Expression.not_null
Expression.fill_null
Expression.apply

.. _api-numeric-expression-operations:
Expand Down Expand Up @@ -160,6 +161,8 @@ The following methods are available under the ``expr.float`` attribute.

Expression.float.is_inf
Expression.float.is_nan
Expression.float.not_nan
Expression.float.fill_nan

.. _api-expressions-temporal:

Expand Down
32 changes: 31 additions & 1 deletion src/daft-core/src/array/ops/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use common_error::DaftResult;
use num_traits::Float;

use super::DaftIsInf;
use super::DaftIsNan;
use super::{DaftIsNan, DaftNotNan};

use super::as_arrow::AsArrow;

Expand Down Expand Up @@ -68,3 +68,33 @@ impl DaftIsInf for DataArray<NullType> {
)))
}
}

impl<T> DaftNotNan for DataArray<T>
where
T: DaftFloatType,
<T as DaftNumericType>::Native: Float,
{
type Output = DaftResult<DataArray<BooleanType>>;

fn not_nan(&self) -> Self::Output {
let arrow_array = self.as_arrow();
let result_arrow_array = arrow2::array::BooleanArray::from_trusted_len_values_iter(
arrow_array.values_iter().map(|v| !v.is_nan()),
)
.with_validity(arrow_array.validity().cloned());
Ok(BooleanArray::from((self.name(), result_arrow_array)))
}
}

impl DaftNotNan for DataArray<NullType> {
type Output = DaftResult<DataArray<BooleanType>>;

fn not_nan(&self) -> Self::Output {
// Entire array is null; since we don't consider nulls to be NaNs, return an all null (invalid) boolean array.
Ok(BooleanArray::from((
self.name(),
arrow2::array::BooleanArray::from_slice(vec![false; self.len()])
.with_validity(Some(arrow2::bitmap::Bitmap::from(vec![false; self.len()]))),
)))
}
}
5 changes: 5 additions & 0 deletions src/daft-core/src/array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ pub trait DaftIsInf {
fn is_inf(&self) -> Self::Output;
}

pub trait DaftNotNan {
type Output;
fn not_nan(&self) -> Self::Output;
}

pub type VecIndices = Vec<u64>;
pub type GroupIndices = Vec<VecIndices>;
pub type GroupIndicesPair = (VecIndices, GroupIndices);
Expand Down
8 changes: 8 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,14 @@ impl PySeries {
Ok(self.series.is_inf()?.into())
}

pub fn not_nan(&self) -> PyResult<Self> {
Ok(self.series.not_nan()?.into())
}

pub fn fill_nan(&self, fill_value: &Self) -> PyResult<Self> {
Ok(self.series.fill_nan(&fill_value.series)?.into())
}

pub fn dt_date(&self) -> PyResult<Self> {
Ok(self.series.dt_date()?.into())
}
Expand Down
12 changes: 12 additions & 0 deletions src/daft-core/src/series/ops/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,16 @@ impl Series {
Ok(DaftIsInf::is_inf(self.downcast::<<$T as DaftDataType>::ArrayType>()?)?.into_series())
})
}

pub fn not_nan(&self) -> DaftResult<Series> {
use crate::array::ops::DaftNotNan;
with_match_float_and_null_daft_types!(self.data_type(), |$T| {
Ok(DaftNotNan::not_nan(self.downcast::<<$T as DaftDataType>::ArrayType>()?)?.into_series())
})
}

pub fn fill_nan(&self, fill_value: &Self) -> DaftResult<Self> {
let predicate = self.not_nan()?;
self.if_else(fill_value, &predicate)
}
}
48 changes: 48 additions & 0 deletions src/daft-dsl/src/functions/float/fill_nan.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
use daft_core::{
datatypes::Field, schema::Schema, series::Series, utils::supertype::try_get_supertype,
};

use crate::ExprRef;

use crate::functions::FunctionExpr;
use common_error::{DaftError, DaftResult};

use super::super::FunctionEvaluator;

pub(super) struct FillNanEvaluator {}

impl FunctionEvaluator for FillNanEvaluator {
fn fn_name(&self) -> &'static str {
"fill_nan"
}

fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult<Field> {
match inputs {
[data, fill_value] => match (data.to_field(schema), fill_value.to_field(schema)) {
(Ok(data_field), Ok(fill_value_field)) => {
match (&data_field.dtype.is_floating(), &fill_value_field.dtype.is_floating(), try_get_supertype(&data_field.dtype, &fill_value_field.dtype)) {
(true, true, Ok(dtype)) => Ok(Field::new(data_field.name, dtype)),
_ => Err(DaftError::TypeError(format!(
"Expects input to fill_nan to be float, but received {data_field} and {fill_value_field}",
))),
}
}
(Err(e), _) | (_, Err(e)) => Err(e),
},
_ => Err(DaftError::SchemaMismatch(format!(
"Expected 2 input args, got {}",
inputs.len()
))),
}
}

fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult<Series> {
match inputs {
[data, fill_value] => data.fill_nan(fill_value),
_ => Err(DaftError::ValueError(format!(
"Expected 2 input args, got {}",
inputs.len()
))),
}
}
}
24 changes: 24 additions & 0 deletions src/daft-dsl/src/functions/float/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
mod fill_nan;
mod is_inf;
mod is_nan;
mod not_nan;

use fill_nan::FillNanEvaluator;
use is_inf::IsInfEvaluator;
use is_nan::IsNanEvaluator;
use not_nan::NotNanEvaluator;
use serde::{Deserialize, Serialize};

use crate::{Expr, ExprRef};
Expand All @@ -13,6 +17,8 @@ use super::FunctionEvaluator;
pub enum FloatExpr {
IsNan,
IsInf,
NotNan,
FillNan,
}

impl FloatExpr {
Expand All @@ -22,6 +28,8 @@ impl FloatExpr {
match self {
IsNan => &IsNanEvaluator {},
IsInf => &IsInfEvaluator {},
NotNan => &NotNanEvaluator {},
FillNan => &FillNanEvaluator {},
}
}
}
Expand All @@ -41,3 +49,19 @@ pub fn is_inf(data: ExprRef) -> ExprRef {
}
.into()
}

pub fn not_nan(data: ExprRef) -> ExprRef {
Expr::Function {
func: super::FunctionExpr::Float(FloatExpr::NotNan),
inputs: vec![data],
}
.into()
}

pub fn fill_nan(data: ExprRef, fill_value: ExprRef) -> ExprRef {
Expr::Function {
func: super::FunctionExpr::Float(FloatExpr::FillNan),
inputs: vec![data, fill_value],
}
.into()
}
51 changes: 51 additions & 0 deletions src/daft-dsl/src/functions/float/not_nan.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use daft_core::{
datatypes::{DataType, Field},
schema::Schema,
series::Series,
};

use crate::ExprRef;

use crate::functions::FunctionExpr;
use common_error::{DaftError, DaftResult};

use super::super::FunctionEvaluator;

pub(super) struct NotNanEvaluator {}

impl FunctionEvaluator for NotNanEvaluator {
fn fn_name(&self) -> &'static str {
"not_nan"
}

fn to_field(&self, inputs: &[ExprRef], schema: &Schema, _: &FunctionExpr) -> DaftResult<Field> {
match inputs {
[data] => match data.to_field(schema) {
Ok(data_field) => match &data_field.dtype {
// DataType::Float16 |
DataType::Float32 | DataType::Float64 => {
Ok(Field::new(data_field.name, DataType::Boolean))
}
_ => Err(DaftError::TypeError(format!(
"Expects input to is_nan to be float, but received {data_field}",
))),
},
Err(e) => Err(e),
},
_ => Err(DaftError::SchemaMismatch(format!(
"Expected 1 input args, got {}",
inputs.len()
))),
}
}

fn evaluate(&self, inputs: &[Series], _: &FunctionExpr) -> DaftResult<Series> {
match inputs {
[data] => data.not_nan(),
_ => Err(DaftError::ValueError(format!(
"Expected 1 input args, got {}",
inputs.len()
))),
}
}
}
10 changes: 10 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,16 @@ impl PyExpr {
Ok(is_inf(self.into()).into())
}

pub fn not_nan(&self) -> PyResult<Self> {
use functions::float::not_nan;
Ok(not_nan(self.into()).into())
}

pub fn fill_nan(&self, fill_value: &Self) -> PyResult<Self> {
use functions::float::fill_nan;
Ok(fill_nan(self.into(), fill_value.expr.clone()).into())
}

pub fn dt_date(&self) -> PyResult<Self> {
use functions::temporal::date;
Ok(date(self.into()).into())
Expand Down
7 changes: 7 additions & 0 deletions tests/expressions/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,13 @@ def test_float_is_inf() -> None:
assert output == "is_inf(col(a))"


def test_float_not_nan() -> None:
a = col("a")
c = a.float.not_nan()
output = repr(c)
assert output == "not_nan(col(a))"


def test_date_lit_post_epoch() -> None:
d = lit(date(2022, 1, 1))
output = repr(d)
Expand Down
20 changes: 20 additions & 0 deletions tests/expressions/typing/test_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,23 @@ def test_float_is_inf(unary_data_fixture):
run_kernel=unary_data_fixture.float.is_inf,
resolvable=unary_data_fixture.datatype() in (DataType.float32(), DataType.float64()),
)


def test_float_not_nan(unary_data_fixture):
assert_typing_resolve_vs_runtime_behavior(
data=[unary_data_fixture],
expr=col(unary_data_fixture.name()).float.not_nan(),
run_kernel=unary_data_fixture.float.not_nan,
resolvable=unary_data_fixture.datatype() in (DataType.float32(), DataType.float64()),
)


def test_fill_nan(binary_data_fixture):
lhs, rhs = binary_data_fixture
assert_typing_resolve_vs_runtime_behavior(
data=binary_data_fixture,
expr=col(lhs.name()).float.fill_nan(rhs),
run_kernel=lambda: lhs.float.fill_nan(rhs),
resolvable=lhs.datatype() in (DataType.float32(), DataType.float64())
and rhs.datatype() in (DataType.float32(), DataType.float64()),
)
Loading

0 comments on commit b861f4c

Please sign in to comment.