From 05e3e3fbbefdaa1780484e58b2c002e4f6b9ec6b Mon Sep 17 00:00:00 2001 From: Matt H Date: Sun, 11 Feb 2024 18:32:07 -0500 Subject: [PATCH] [FEAT] Add ceil function (#1867) Adding the `ceil` function to match https://ibis-project.org/reference/expression-numeric#ibis.expr.types.numeric.NumericValue.ceil. Added tests with example usage. This is my first PR here. Happy to take on more of the functions in #1037 after getting through this one. --- daft/daft.pyi | 2 + daft/expressions/expressions.py | 5 +++ daft/series.py | 3 ++ docs/source/api_docs/expressions.rst | 1 + src/daft-core/src/array/ops/ceil.rs | 18 +++++++++ src/daft-core/src/array/ops/mod.rs | 1 + src/daft-core/src/python/series.rs | 4 ++ src/daft-core/src/series/ops/ceil.rs | 20 ++++++++++ src/daft-core/src/series/ops/mod.rs | 1 + src/daft-dsl/src/functions/numeric/ceil.rs | 41 +++++++++++++++++++++ src/daft-dsl/src/functions/numeric/mod.rs | 12 ++++++ src/daft-dsl/src/python.rs | 5 +++ tests/expressions/test_expressions.py | 9 +++++ tests/expressions/typing/test_arithmetic.py | 10 +++++ tests/table/test_eval.py | 23 ++++++++++++ 15 files changed, 155 insertions(+) create mode 100644 src/daft-core/src/array/ops/ceil.rs create mode 100644 src/daft-core/src/series/ops/ceil.rs create mode 100644 src/daft-dsl/src/functions/numeric/ceil.rs diff --git a/daft/daft.pyi b/daft/daft.pyi index 6d2b51ebfe..9361906b84 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -832,6 +832,7 @@ class PyExpr: def _is_column(self) -> bool: ... def alias(self, name: str) -> PyExpr: ... def cast(self, dtype: PyDataType) -> PyExpr: ... + def ceil(self) -> PyExpr: ... def if_else(self, if_true: PyExpr, if_false: PyExpr) -> PyExpr: ... def count(self, mode: CountMode) -> PyExpr: ... def sum(self) -> PyExpr: ... @@ -940,6 +941,7 @@ class PySeries: def _max(self) -> PySeries: ... def _agg_list(self) -> PySeries: ... def cast(self, dtype: PyDataType) -> PySeries: ... + def ceil(self) -> PySeries: ... @staticmethod def concat(series: list[PySeries]) -> PySeries: ... def __len__(self) -> int: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 11d1277fa0..5ee05c4cd4 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -315,6 +315,11 @@ def cast(self, dtype: DataType) -> Expression: expr = self._expr.cast(dtype._dtype) return Expression._from_pyexpr(expr) + def ceil(self) -> Expression: + """The ceiling of a numeric expression (``expr.ceil()``)""" + expr = self._expr.ceil() + return Expression._from_pyexpr(expr) + def _count(self, mode: CountMode = CountMode.Valid) -> Expression: expr = self._expr.count(mode) return Expression._from_pyexpr(expr) diff --git a/daft/series.py b/daft/series.py index ffcaa8d74e..1f4782b4f0 100644 --- a/daft/series.py +++ b/daft/series.py @@ -349,6 +349,9 @@ def size_bytes(self) -> int: def __abs__(self) -> Series: return Series._from_pyseries(abs(self._series)) + def ceil(self) -> Series: + return Series._from_pyseries(self._series.ceil()) + def __add__(self, other: object) -> Series: if not isinstance(other, Series): raise TypeError(f"expected another Series but got {type(other)}") diff --git a/docs/source/api_docs/expressions.rst b/docs/source/api_docs/expressions.rst index e586a8c2b9..e977c849d8 100644 --- a/docs/source/api_docs/expressions.rst +++ b/docs/source/api_docs/expressions.rst @@ -46,6 +46,7 @@ Numeric Expression.__mul__ Expression.__truediv__ Expression.__mod__ + Expression.ceil .. _api-comparison-expression: diff --git a/src/daft-core/src/array/ops/ceil.rs b/src/daft-core/src/array/ops/ceil.rs new file mode 100644 index 0000000000..3d677fff65 --- /dev/null +++ b/src/daft-core/src/array/ops/ceil.rs @@ -0,0 +1,18 @@ +use num_traits::Float; + +use crate::{ + array::DataArray, + datatypes::{DaftFloatType, DaftNumericType}, +}; + +use common_error::DaftResult; + +impl DataArray +where + T: DaftNumericType, + T::Native: Float, +{ + pub fn ceil(&self) -> DaftResult { + self.apply(|v| v.ceil()) + } +} diff --git a/src/daft-core/src/array/ops/mod.rs b/src/daft-core/src/array/ops/mod.rs index ba69e727b6..d9d66cfec6 100644 --- a/src/daft-core/src/array/ops/mod.rs +++ b/src/daft-core/src/array/ops/mod.rs @@ -6,6 +6,7 @@ pub mod arrow2; pub mod as_arrow; pub(crate) mod broadcast; pub(crate) mod cast; +mod ceil; mod compare_agg; mod comparison; mod concat; diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index d21def6a8f..e89b206399 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -108,6 +108,10 @@ impl PySeries { Ok(self.series.xor(&other.series)?.into_series().into()) } + pub fn ceil(&self) -> PyResult { + Ok(self.series.ceil()?.into()) + } + pub fn take(&self, idx: &Self) -> PyResult { Ok(self.series.take(&idx.series)?.into()) } diff --git a/src/daft-core/src/series/ops/ceil.rs b/src/daft-core/src/series/ops/ceil.rs new file mode 100644 index 0000000000..3a01f045a6 --- /dev/null +++ b/src/daft-core/src/series/ops/ceil.rs @@ -0,0 +1,20 @@ +use crate::datatypes::DataType; +use crate::series::Series; +use common_error::DaftError; +use common_error::DaftResult; +impl Series { + pub fn ceil(&self) -> DaftResult { + use crate::series::array_impl::IntoSeries; + + use DataType::*; + match self.data_type() { + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => Ok(self.clone()), + Float32 => Ok(self.f32().unwrap().ceil()?.into_series()), + Float64 => Ok(self.f64().unwrap().ceil()?.into_series()), + dt => Err(DaftError::TypeError(format!( + "ceil not implemented for {}", + dt + ))), + } + } +} diff --git a/src/daft-core/src/series/ops/mod.rs b/src/daft-core/src/series/ops/mod.rs index e328be0cd3..eb0f75b245 100644 --- a/src/daft-core/src/series/ops/mod.rs +++ b/src/daft-core/src/series/ops/mod.rs @@ -8,6 +8,7 @@ pub mod agg; pub mod arithmetic; pub mod broadcast; pub mod cast; +pub mod ceil; pub mod comparison; pub mod concat; pub mod date; diff --git a/src/daft-dsl/src/functions/numeric/ceil.rs b/src/daft-dsl/src/functions/numeric/ceil.rs new file mode 100644 index 0000000000..9eec7d09f3 --- /dev/null +++ b/src/daft-dsl/src/functions/numeric/ceil.rs @@ -0,0 +1,41 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{datatypes::Field, schema::Schema, series::Series}; + +use crate::Expr; + +use super::super::FunctionEvaluator; + +pub(super) struct CeilEvaluator {} + +impl FunctionEvaluator for CeilEvaluator { + fn fn_name(&self) -> &'static str { + "ceil" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { + if inputs.len() != 1 { + return Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + } + let field = inputs.first().unwrap().to_field(schema)?; + if !field.dtype.is_numeric() { + return Err(DaftError::TypeError(format!( + "Expected input to ceil to be numeric, got {}", + field.dtype + ))); + } + Ok(field) + } + + fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult { + if inputs.len() != 1 { + return Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + } + inputs.first().unwrap().ceil() + } +} diff --git a/src/daft-dsl/src/functions/numeric/mod.rs b/src/daft-dsl/src/functions/numeric/mod.rs index e8706192aa..7e54d83c68 100644 --- a/src/daft-dsl/src/functions/numeric/mod.rs +++ b/src/daft-dsl/src/functions/numeric/mod.rs @@ -1,6 +1,9 @@ mod abs; +mod ceil; use abs::AbsEvaluator; +use ceil::CeilEvaluator; + use serde::{Deserialize, Serialize}; use crate::Expr; @@ -10,6 +13,7 @@ use super::FunctionEvaluator; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] pub enum NumericExpr { Abs, + Ceil, } impl NumericExpr { @@ -18,6 +22,7 @@ impl NumericExpr { use NumericExpr::*; match self { Abs => &AbsEvaluator {}, + Ceil => &CeilEvaluator {}, } } } @@ -28,3 +33,10 @@ pub fn abs(input: &Expr) -> Expr { inputs: vec![input.clone()], } } + +pub fn ceil(input: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::Numeric(NumericExpr::Ceil), + inputs: vec![input.clone()], + } +} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index eff9bc2738..f135b96bb4 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -144,6 +144,11 @@ impl PyExpr { Ok(self.expr.cast(&dtype.into()).into()) } + pub fn ceil(&self) -> PyResult { + use functions::numeric::ceil; + Ok(ceil(&self.expr).into()) + } + pub fn if_else(&self, if_true: &Self, if_false: &Self) -> PyResult { Ok(self.expr.if_else(&if_true.expr, &if_false.expr).into()) } diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index 0884690963..84102ef625 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -86,6 +86,15 @@ def test_repr_functions_abs() -> None: assert repr_out == repr(copied) +def test_repr_functions_ceil() -> None: + a = col("a") + y = a.ceil() + repr_out = repr(y) + assert repr_out == "ceil(col(a))" + copied = copy.deepcopy(y) + assert repr_out == repr(copied) + + def test_repr_functions_day() -> None: a = col("a") y = a.dt.day() diff --git a/tests/expressions/typing/test_arithmetic.py b/tests/expressions/typing/test_arithmetic.py index 827aa85d7a..4d47a069a3 100644 --- a/tests/expressions/typing/test_arithmetic.py +++ b/tests/expressions/typing/test_arithmetic.py @@ -72,3 +72,13 @@ def test_abs(unary_data_fixture): run_kernel=lambda: abs(arg), resolvable=is_numeric(arg.datatype()), ) + + +def test_ceil(unary_data_fixture): + arg = unary_data_fixture + assert_typing_resolve_vs_runtime_behavior( + data=(unary_data_fixture,), + expr=col(arg.name()).ceil(), + run_kernel=lambda: arg.ceil(), + resolvable=is_numeric(arg.datatype()), + ) diff --git a/tests/table/test_eval.py b/tests/table/test_eval.py index 5f6d3948b8..b87b3bd6e1 100644 --- a/tests/table/test_eval.py +++ b/tests/table/test_eval.py @@ -1,6 +1,7 @@ from __future__ import annotations import itertools +import math import operator as ops import pyarrow as pa @@ -157,3 +158,25 @@ def test_table_abs_bad_input() -> None: with pytest.raises(ValueError, match="Expected input to abs to be numeric"): table.eval_expression_list([abs(col("a"))]) + + +def test_table_numeric_ceil() -> None: + table = MicroPartition.from_pydict( + {"a": [None, -1.0, -0.5, 0, 0.5, 2, None], "b": [-1.7, -1.5, -1.3, 0.3, 0.7, None, None]} + ) + + ceil_table = table.eval_expression_list([col("a").ceil(), col("b").ceil()]) + + assert [math.ceil(v) if v is not None else v for v in table.get_column("a").to_pylist()] == ceil_table.get_column( + "a" + ).to_pylist() + assert [math.ceil(v) if v is not None else v for v in table.get_column("b").to_pylist()] == ceil_table.get_column( + "b" + ).to_pylist() + + +def test_table_ceil_bad_input() -> None: + table = MicroPartition.from_pydict({"a": ["a", "b", "c"]}) + + with pytest.raises(ValueError, match="Expected input to ceil to be numeric"): + table.eval_expression_list([col("a").ceil()])