From acb8203c6eabc2fbb4194a3c6431652b5105ded6 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Wed, 10 Apr 2024 21:09:56 -0700 Subject: [PATCH] [FEAT] fill_null expression (#2089) Closes #1904 --- daft/daft.pyi | 2 ++ daft/expressions/expressions.py | 27 +++++++++++++++++ daft/series.py | 6 ++++ src/daft-core/src/python/series.rs | 4 +++ src/daft-core/src/series/ops/null.rs | 5 +++ src/daft-dsl/src/expr.rs | 37 +++++++++++++++++++---- src/daft-dsl/src/optimization.rs | 1 + src/daft-dsl/src/python.rs | 4 +++ src/daft-dsl/src/treenode.rs | 5 +++ src/daft-plan/src/builder.rs | 2 +- src/daft-plan/src/logical_ops/project.rs | 16 ++++++++++ src/daft-plan/src/physical_ops/project.rs | 11 +++++++ src/daft-table/src/lib.rs | 4 +++ tests/expressions/typing/test_null.py | 15 ++++++++- tests/series/test_fill_null.py | 32 ++++++++++++++++++++ tests/table/test_fill_null.py | 37 +++++++++++++++++++++++ 16 files changed, 200 insertions(+), 8 deletions(-) create mode 100644 tests/series/test_fill_null.py create mode 100644 tests/table/test_fill_null.py diff --git a/daft/daft.pyi b/daft/daft.pyi index aa51d38dc7..91171522c2 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -917,6 +917,7 @@ class PyExpr: def __ne__(self, other: PyExpr) -> PyExpr: ... # type: ignore[override] def is_null(self) -> PyExpr: ... def not_null(self) -> PyExpr: ... + def fill_null(self, fill_value: PyExpr) -> PyExpr: ... def is_in(self, other: PyExpr) -> PyExpr: ... def name(self) -> str: ... def to_field(self, schema: PySchema) -> PyField: ... @@ -1068,6 +1069,7 @@ class PySeries: def if_else(self, other: PySeries, predicate: PySeries) -> PySeries: ... def is_null(self) -> PySeries: ... def not_null(self) -> PySeries: ... + def fill_null(self, fill_value: PySeries) -> PySeries: ... def murmur3_32(self) -> PySeries: ... def to_str_values(self) -> PySeries: ... def _debug_bincode_serialize(self) -> bytes: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 081fb453da..fa377b4a00 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -485,6 +485,33 @@ def not_null(self) -> Expression: expr = self._expr.not_null() return Expression._from_pyexpr(expr) + def fill_null(self, fill_value: Expression) -> Expression: + """Fills null values in the Expression with the provided fill_value + + Example: + >>> df = daft.from_pydict({"data": [1, None, 3]}) + >>> df = df.select(df["data"].fill_null(2)) + >>> df.collect() + ╭───────╮ + │ data │ + │ --- │ + │ Int64 │ + ╞═══════╡ + │ 1 │ + ├╌╌╌╌╌╌╌┤ + │ 2 │ + ├╌╌╌╌╌╌╌┤ + │ 3 │ + ╰───────╯ + + Returns: + Expression: Expression with null values filled with the provided fill_value + """ + + fill_value = Expression._to_expression(fill_value) + expr = self._expr.fill_null(fill_value._expr) + return Expression._from_pyexpr(expr) + def is_in(self, other: Any) -> Expression: """Checks if values in the Expression are in the provided list diff --git a/daft/series.py b/daft/series.py index 884980b6c1..b04b73f0d3 100644 --- a/daft/series.py +++ b/daft/series.py @@ -492,6 +492,12 @@ def not_null(self) -> Series: assert self._series is not None return Series._from_pyseries(self._series.not_null()) + def fill_null(self, fill_value: object) -> Series: + if not isinstance(fill_value, Series): + raise ValueError(f"expected another Series but got {type(fill_value)}") + assert self._series is not None and fill_value._series is not None + return Series._from_pyseries(self._series.fill_null(fill_value._series)) + def _to_str_values(self) -> Series: return Series._from_pyseries(self._series.to_str_values()) diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index c262f36612..495f160905 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -438,6 +438,10 @@ impl PySeries { Ok(self.series.not_null()?.into()) } + pub fn fill_null(&self, fill_value: &Self) -> PyResult { + Ok(self.series.fill_null(&fill_value.series)?.into()) + } + pub fn _debug_bincode_serialize(&self, py: Python) -> PyResult { let values = bincode::serialize(&self.series).unwrap(); Ok(PyBytes::new(py, &values).to_object(py)) diff --git a/src/daft-core/src/series/ops/null.rs b/src/daft-core/src/series/ops/null.rs index a1bbdb0419..685579f6c8 100644 --- a/src/daft-core/src/series/ops/null.rs +++ b/src/daft-core/src/series/ops/null.rs @@ -10,4 +10,9 @@ impl Series { pub fn not_null(&self) -> DaftResult { self.inner.not_null() } + + pub fn fill_null(&self, fill_value: &Series) -> DaftResult { + let predicate = self.not_null()?; + self.if_else(fill_value, &predicate) + } } diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index 1201956dea..181fe4195c 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -42,6 +42,7 @@ pub enum Expr { Not(ExprRef), IsNull(ExprRef), NotNull(ExprRef), + FillNull(ExprRef, ExprRef), IsIn(ExprRef, ExprRef), Literal(lit::LiteralValue), IfElse { @@ -279,6 +280,10 @@ impl Expr { Expr::NotNull(self.clone().into()) } + pub fn fill_null(&self, fill_value: &Self) -> Self { + Expr::FillNull(self.clone().into(), fill_value.clone().into()) + } + pub fn is_in(&self, items: &Self) -> Self { Expr::IsIn(self.clone().into(), items.clone().into()) } @@ -342,6 +347,11 @@ impl Expr { let child_id = expr.semantic_id(schema); FieldID::new(format!("{child_id}.not_null()")) } + FillNull(expr, fill_value) => { + let child_id = expr.semantic_id(schema); + let fill_value_id = fill_value.semantic_id(schema); + FieldID::new(format!("{child_id}.fill_null({fill_value_id})")) + } IsIn(expr, items) => { let child_id = expr.semantic_id(schema); let items_id = items.semantic_id(schema); @@ -400,6 +410,7 @@ impl Expr { } => { vec![predicate.clone(), if_true.clone(), if_false.clone()] } + FillNull(expr, fill_value) => vec![expr.clone(), fill_value.clone()], } } @@ -421,6 +432,16 @@ impl Expr { } IsNull(expr) => Ok(Field::new(expr.name()?, DataType::Boolean)), NotNull(expr) => Ok(Field::new(expr.name()?, DataType::Boolean)), + FillNull(expr, fill_value) => { + let expr_field = expr.to_field(schema)?; + let fill_value_field = fill_value.to_field(schema)?; + match try_get_supertype(&expr_field.dtype, &fill_value_field.dtype) { + Ok(supertype) => Ok(Field::new(expr_field.name.as_str(), supertype)), + Err(_) => Err(DaftError::TypeError(format!( + "Expected expr and fill_value arguments for fill_null to be castable to the same supertype, but received {expr_field} and {fill_value_field}", + ))) + } + } IsIn(left, right) => { let left_field = left.to_field(schema)?; let right_field = right.to_field(schema)?; @@ -510,6 +531,7 @@ impl Expr { Not(expr) => expr.name(), IsNull(expr) => expr.name(), NotNull(expr) => expr.name(), + FillNull(expr, ..) => expr.name(), IsIn(expr, ..) => expr.name(), Literal(..) => Ok("literal"), Function { func, inputs } => match func { @@ -598,12 +620,14 @@ impl Expr { write!(buffer, " END") } // TODO: Implement SQL translations for these expressions if possible - Expr::Agg(..) | Expr::Cast(..) | Expr::IsIn(..) | Expr::Function { .. } => { - Err(io::Error::new( - io::ErrorKind::Other, - "Unsupported expression for SQL translation", - )) - } + Expr::Agg(..) + | Expr::Cast(..) + | Expr::IsIn(..) + | Expr::Function { .. } + | Expr::FillNull(..) => Err(io::Error::new( + io::ErrorKind::Other, + "Unsupported expression for SQL translation", + )), } } @@ -638,6 +662,7 @@ impl Display for Expr { Not(expr) => write!(f, "not({expr})"), IsNull(expr) => write!(f, "is_null({expr})"), NotNull(expr) => write!(f, "not_null({expr})"), + FillNull(expr, fill_value) => write!(f, "fill_null({expr}, {fill_value})"), IsIn(expr, items) => write!(f, "{expr} in {items}"), Literal(val) => write!(f, "lit({val})"), Function { func, inputs } => function_display(f, func, inputs), diff --git a/src/daft-dsl/src/optimization.rs b/src/daft-dsl/src/optimization.rs index 465f2c1e84..2959a3b67f 100644 --- a/src/daft-dsl/src/optimization.rs +++ b/src/daft-dsl/src/optimization.rs @@ -30,6 +30,7 @@ pub fn requires_computation(e: &Expr) -> bool { | Expr::Not(..) | Expr::IsNull(..) | Expr::NotNull(..) + | Expr::FillNull(..) | Expr::IsIn { .. } | Expr::IfElse { .. } => true, } diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index bc68334b5b..ca22c7f5a9 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -331,6 +331,10 @@ impl PyExpr { Ok(self.expr.not_null().into()) } + pub fn fill_null(&self, fill_value: &Self) -> PyResult { + Ok(self.expr.fill_null(&fill_value.expr).into()) + } + pub fn is_in(&self, other: &Self) -> PyResult { Ok(self.expr.is_in(&other.expr).into()) } diff --git a/src/daft-dsl/src/treenode.rs b/src/daft-dsl/src/treenode.rs index 3f2d05453b..2ec7546a6f 100644 --- a/src/daft-dsl/src/treenode.rs +++ b/src/daft-dsl/src/treenode.rs @@ -29,6 +29,7 @@ impl TreeNode for Expr { } BinaryOp { op: _, left, right } => vec![left.as_ref(), right.as_ref()], IsIn(expr, items) => vec![expr.as_ref(), items.as_ref()], + FillNull(expr, fill_value) => vec![expr.as_ref(), fill_value.as_ref()], Column(_) | Literal(_) => vec![], Function { func: _, inputs } => inputs.iter().collect::>(), IfElse { @@ -83,6 +84,10 @@ impl TreeNode for Expr { Not(expr) => Not(transform(expr.as_ref().clone())?.into()), IsNull(expr) => IsNull(transform(expr.as_ref().clone())?.into()), NotNull(expr) => NotNull(transform(expr.as_ref().clone())?.into()), + FillNull(expr, fill_value) => FillNull( + transform(expr.as_ref().clone())?.into(), + transform(fill_value.as_ref().clone())?.into(), + ), IsIn(expr, items) => IsIn( transform(expr.as_ref().clone())?.into(), transform(items.as_ref().clone())?.into(), diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 7f3cdf5cc0..f2c6339d3f 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -56,7 +56,7 @@ fn check_for_agg(expr: &Expr) -> bool { Alias(e, _) | Cast(e, _) | Not(e) | IsNull(e) | NotNull(e) => check_for_agg(e), BinaryOp { left, right, .. } => check_for_agg(left) || check_for_agg(right), Function { inputs, .. } => inputs.iter().any(check_for_agg), - IsIn(l, r) => check_for_agg(l) || check_for_agg(r), + IsIn(l, r) | FillNull(l, r) => check_for_agg(l) || check_for_agg(r), IfElse { if_true, if_false, diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs index 124ff20f82..54777b1007 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -251,6 +251,22 @@ fn replace_column_with_semantic_id( |_| e, ) } + Expr::FillNull(child, fill_value) => { + let child = + replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema); + let fill_value = replace_column_with_semantic_id( + fill_value.clone(), + subexprs_to_replace, + schema, + ); + if child.is_no() && fill_value.is_no() { + Transformed::No(e) + } else { + Transformed::Yes( + Expr::FillNull(child.unwrap().clone(), fill_value.unwrap().clone()).into(), + ) + } + } Expr::IsIn(child, items) => { let child = replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema); diff --git a/src/daft-plan/src/physical_ops/project.rs b/src/daft-plan/src/physical_ops/project.rs index aefc850ce1..6f2b0401fe 100644 --- a/src/daft-plan/src/physical_ops/project.rs +++ b/src/daft-plan/src/physical_ops/project.rs @@ -173,6 +173,17 @@ impl Project { )?; Ok(Expr::NotNull(newchild.into())) } + Expr::FillNull(child, fill_value) => { + let newchild = Self::translate_clustering_spec_expr( + child.as_ref(), + old_colname_to_new_colname, + )?; + let newfill = Self::translate_clustering_spec_expr( + fill_value.as_ref(), + old_colname_to_new_colname, + )?; + Ok(Expr::FillNull(newchild.into(), newfill.into())) + } Expr::IsIn(child, items) => { let newchild = Self::translate_clustering_spec_expr( child.as_ref(), diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 6b6d8fabc1..4bd63735d3 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -347,6 +347,10 @@ impl Table { Not(child) => !(self.eval_expression(child)?), IsNull(child) => self.eval_expression(child)?.is_null(), NotNull(child) => self.eval_expression(child)?.not_null(), + FillNull(child, fill_value) => { + let fill_value = self.eval_expression(fill_value)?; + self.eval_expression(child)?.fill_null(&fill_value) + } IsIn(child, items) => self .eval_expression(child)? .is_in(&self.eval_expression(items)?), diff --git a/tests/expressions/typing/test_null.py b/tests/expressions/typing/test_null.py index 443f85996d..faadd935cf 100644 --- a/tests/expressions/typing/test_null.py +++ b/tests/expressions/typing/test_null.py @@ -1,7 +1,10 @@ from __future__ import annotations from daft.expressions import col -from tests.expressions.typing.conftest import assert_typing_resolve_vs_runtime_behavior +from tests.expressions.typing.conftest import ( + assert_typing_resolve_vs_runtime_behavior, + has_supertype, +) def test_is_null(unary_data_fixture): @@ -22,3 +25,13 @@ def test_not_null(unary_data_fixture): run_kernel=lambda: arg.not_null(), resolvable=True, ) + + +def test_fill_null(binary_data_fixture): + lhs, rhs = binary_data_fixture + assert_typing_resolve_vs_runtime_behavior( + data=binary_data_fixture, + expr=col(lhs.name()).fill_null(col(rhs.name())), + run_kernel=lambda: lhs.fill_null(rhs), + resolvable=has_supertype(lhs.datatype(), rhs.datatype()), + ) diff --git a/tests/series/test_fill_null.py b/tests/series/test_fill_null.py new file mode 100644 index 0000000000..11fcca3088 --- /dev/null +++ b/tests/series/test_fill_null.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import pyarrow as pa +import pytest + +from daft.series import Series + + +@pytest.mark.parametrize( + "input,fill_value,expected", + [ + # No broadcast + [[1, 2, None], [3, 3, 3], [1, 2, 3]], + # Broadcast input + [[None], [3, 3, 3], [3, 3, 3]], + # Broadcast fill_value + [[1, 2, None], [3], [1, 2, 3]], + # Empty + [[], [], []], + ], +) +def test_series_fill_null(input, fill_value, expected) -> None: + s = Series.from_arrow(pa.array(input, pa.int64())) + fill_value = Series.from_arrow(pa.array(fill_value, pa.int64())) + filled = s.fill_null(fill_value) + assert filled.to_pylist() == expected + + +def test_series_fill_null_bad_input() -> None: + s = Series.from_arrow(pa.array([1, 2, 3], pa.int64())) + with pytest.raises(ValueError, match="expected another Series but got"): + s.fill_null([1, 2, 3]) diff --git a/tests/table/test_fill_null.py b/tests/table/test_fill_null.py new file mode 100644 index 0000000000..a61d7601a4 --- /dev/null +++ b/tests/table/test_fill_null.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import datetime + +import pytest + +from daft.expressions.expressions import col +from daft.table.micropartition import MicroPartition + + +@pytest.mark.parametrize( + "input,fill_value,expected", + [ + pytest.param([None, None, None], "a", ["a", "a", "a"], id="NullColumn"), + pytest.param([True, False, None], False, [True, False, False], id="BoolColumn"), + pytest.param(["a", "b", None], "b", ["a", "b", "b"], id="StringColumn"), + pytest.param([b"a", None, b"c"], b"b", [b"a", b"b", b"c"], id="BinaryColumn"), + pytest.param([-1, None, 3], 0, [-1, 0, 3], id="IntColumn"), + pytest.param([-1.0, None, 3.0], 0.0, [-1.0, 0.0, 3.0], id="FloatColumn"), + pytest.param( + [datetime.date.today(), None, datetime.date(2023, 1, 1)], + datetime.date(2022, 1, 1), + [datetime.date.today(), datetime.date(2022, 1, 1), datetime.date(2023, 1, 1)], + ), + pytest.param( + [datetime.datetime(2022, 1, 1), None, datetime.datetime(2023, 1, 1)], + datetime.datetime(2022, 1, 1), + [datetime.datetime(2022, 1, 1), datetime.datetime(2022, 1, 1), datetime.datetime(2023, 1, 1)], + ), + ], +) +def test_table_expr_fill_null(input, fill_value, expected) -> None: + daft_table = MicroPartition.from_pydict({"input": input}) + daft_table = daft_table.eval_expression_list([col("input").fill_null(fill_value)]) + pydict = daft_table.to_pydict() + + assert pydict["input"] == expected