diff --git a/daft/daft.pyi b/daft/daft.pyi index 8055a69940..eb30a512c3 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -871,6 +871,7 @@ class PyExpr: def alias(self, name: str) -> PyExpr: ... def cast(self, dtype: PyDataType) -> PyExpr: ... def ceil(self) -> PyExpr: ... + def floor(self) -> PyExpr: ... def if_else(self, if_true: PyExpr, if_false: PyExpr) -> PyExpr: ... def count(self, mode: CountMode) -> PyExpr: ... def sum(self) -> PyExpr: ... @@ -986,6 +987,7 @@ class PySeries: def _agg_list(self) -> PySeries: ... def cast(self, dtype: PyDataType) -> PySeries: ... def ceil(self) -> PySeries: ... + def floor(self) -> PySeries: ... @staticmethod def concat(series: list[PySeries]) -> PySeries: ... def __len__(self) -> int: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 80b4b13c21..449ea2bb23 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -327,6 +327,11 @@ def ceil(self) -> Expression: expr = self._expr.ceil() return Expression._from_pyexpr(expr) + def floor(self) -> Expression: + """The floor of a numeric expression (``expr.floor()``)""" + expr = self._expr.floor() + return Expression._from_pyexpr(expr) + def _count(self, mode: CountMode = CountMode.Valid) -> Expression: expr = self._expr.count(mode) return Expression._from_pyexpr(expr) diff --git a/daft/series.py b/daft/series.py index 95b9874de7..dac7aca5ef 100644 --- a/daft/series.py +++ b/daft/series.py @@ -352,6 +352,9 @@ def __abs__(self) -> Series: def ceil(self) -> Series: return Series._from_pyseries(self._series.ceil()) + def floor(self) -> Series: + return Series._from_pyseries(self._series.floor()) + def __add__(self, other: object) -> Series: if not isinstance(other, Series): raise TypeError(f"expected another Series but got {type(other)}") diff --git a/docs/source/api_docs/expressions.rst b/docs/source/api_docs/expressions.rst index 3623edece9..f5b9031936 100644 --- a/docs/source/api_docs/expressions.rst +++ b/docs/source/api_docs/expressions.rst @@ -47,6 +47,7 @@ Numeric Expression.__truediv__ Expression.__mod__ Expression.ceil + Expression.floor .. _api-comparison-expression: diff --git a/src/daft-core/src/array/ops/floor.rs b/src/daft-core/src/array/ops/floor.rs new file mode 100644 index 0000000000..047cbbf98d --- /dev/null +++ b/src/daft-core/src/array/ops/floor.rs @@ -0,0 +1,18 @@ +use num_traits::Float; + +use crate::{ + array::DataArray, + datatypes::{DaftFloatType, DaftNumericType}, +}; + +use common_error::DaftResult; + +impl DataArray +where + T: DaftNumericType, + T::Native: Float, +{ + pub fn floor(&self) -> DaftResult { + self.apply(|v| v.floor()) + } +} diff --git a/src/daft-core/src/array/ops/mod.rs b/src/daft-core/src/array/ops/mod.rs index d9d66cfec6..111bb8d6e6 100644 --- a/src/daft-core/src/array/ops/mod.rs +++ b/src/daft-core/src/array/ops/mod.rs @@ -15,6 +15,7 @@ mod count; mod date; mod filter; mod float; +mod floor; pub mod from_arrow; pub mod full; mod get; diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index 6773725b5e..10fd48a820 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -112,6 +112,10 @@ impl PySeries { Ok(self.series.ceil()?.into()) } + pub fn floor(&self) -> PyResult { + Ok(self.series.floor()?.into()) + } + pub fn take(&self, idx: &Self) -> PyResult { Ok(self.series.take(&idx.series)?.into()) } diff --git a/src/daft-core/src/series/ops/floor.rs b/src/daft-core/src/series/ops/floor.rs new file mode 100644 index 0000000000..59256fef5f --- /dev/null +++ b/src/daft-core/src/series/ops/floor.rs @@ -0,0 +1,20 @@ +use crate::datatypes::DataType; +use crate::series::Series; +use common_error::DaftError; +use common_error::DaftResult; + +impl Series { + pub fn floor(&self) -> DaftResult { + use crate::series::array_impl::IntoSeries; + use DataType::*; + match self.data_type() { + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => Ok(self.clone()), + Float32 => Ok(self.f32().unwrap().floor()?.into_series()), + Float64 => Ok(self.f64().unwrap().floor()?.into_series()), + dt => Err(DaftError::TypeError(format!( + "floor not implemented for {}", + dt + ))), + } + } +} diff --git a/src/daft-core/src/series/ops/mod.rs b/src/daft-core/src/series/ops/mod.rs index eb0f75b245..865a0c9edc 100644 --- a/src/daft-core/src/series/ops/mod.rs +++ b/src/daft-core/src/series/ops/mod.rs @@ -15,6 +15,7 @@ pub mod date; pub mod downcast; pub mod filter; pub mod float; +pub mod floor; pub mod groups; pub mod hash; pub mod if_else; diff --git a/src/daft-dsl/src/functions/numeric/floor.rs b/src/daft-dsl/src/functions/numeric/floor.rs new file mode 100644 index 0000000000..4275f4dfce --- /dev/null +++ b/src/daft-dsl/src/functions/numeric/floor.rs @@ -0,0 +1,40 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{datatypes::Field, schema::Schema, series::Series}; + +use super::super::FunctionEvaluator; +use crate::Expr; + +pub(super) struct FloorEvaluator {} + +impl FunctionEvaluator for FloorEvaluator { + fn fn_name(&self) -> &'static str { + "floor" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { + if inputs.len() != 1 { + return Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + } + let field = inputs.first().unwrap().to_field(schema)?; + if !field.dtype.is_numeric() { + return Err(DaftError::TypeError(format!( + "Expected input to floor to be numeric, got {}", + field.dtype + ))); + } + Ok(field) + } + + fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult { + if inputs.len() != 1 { + return Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + } + inputs.first().unwrap().floor() + } +} diff --git a/src/daft-dsl/src/functions/numeric/mod.rs b/src/daft-dsl/src/functions/numeric/mod.rs index 7e54d83c68..7b2eb8447b 100644 --- a/src/daft-dsl/src/functions/numeric/mod.rs +++ b/src/daft-dsl/src/functions/numeric/mod.rs @@ -1,8 +1,10 @@ mod abs; mod ceil; +mod floor; use abs::AbsEvaluator; use ceil::CeilEvaluator; +use floor::FloorEvaluator; use serde::{Deserialize, Serialize}; @@ -14,6 +16,7 @@ use super::FunctionEvaluator; pub enum NumericExpr { Abs, Ceil, + Floor, } impl NumericExpr { @@ -23,6 +26,7 @@ impl NumericExpr { match self { Abs => &AbsEvaluator {}, Ceil => &CeilEvaluator {}, + Floor => &FloorEvaluator {}, } } } @@ -40,3 +44,10 @@ pub fn ceil(input: &Expr) -> Expr { inputs: vec![input.clone()], } } + +pub fn floor(input: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::Numeric(NumericExpr::Floor), + inputs: vec![input.clone()], + } +} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 7f45c804e4..0628b0b47c 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -155,6 +155,11 @@ impl PyExpr { Ok(ceil(&self.expr).into()) } + pub fn floor(&self) -> PyResult { + use functions::numeric::floor; + Ok(floor(&self.expr).into()) + } + pub fn if_else(&self, if_true: &Self, if_false: &Self) -> PyResult { Ok(self.expr.if_else(&if_true.expr, &if_false.expr).into()) } diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index 92c80df665..809138ebb5 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -96,6 +96,15 @@ def test_repr_functions_ceil() -> None: assert repr_out == repr(copied) +def test_repr_functions_floor() -> None: + a = col("a") + y = a.floor() + repr_out = repr(y) + assert repr_out == "floor(col(a))" + copied = copy.deepcopy(y) + assert repr_out == repr(copied) + + def test_repr_functions_day() -> None: a = col("a") y = a.dt.day() diff --git a/tests/expressions/typing/test_arithmetic.py b/tests/expressions/typing/test_arithmetic.py index 4d47a069a3..c973265c87 100644 --- a/tests/expressions/typing/test_arithmetic.py +++ b/tests/expressions/typing/test_arithmetic.py @@ -82,3 +82,13 @@ def test_ceil(unary_data_fixture): run_kernel=lambda: arg.ceil(), resolvable=is_numeric(arg.datatype()), ) + + +def test_floor(unary_data_fixture): + arg = unary_data_fixture + assert_typing_resolve_vs_runtime_behavior( + data=(unary_data_fixture,), + expr=col(arg.name()).floor(), + run_kernel=lambda: arg.floor(), + resolvable=is_numeric(arg.datatype()), + ) diff --git a/tests/table/test_eval.py b/tests/table/test_eval.py index b87b3bd6e1..68c8ff0b4e 100644 --- a/tests/table/test_eval.py +++ b/tests/table/test_eval.py @@ -180,3 +180,25 @@ def test_table_ceil_bad_input() -> None: with pytest.raises(ValueError, match="Expected input to ceil to be numeric"): table.eval_expression_list([col("a").ceil()]) + + +def test_table_numeric_floor() -> None: + table = MicroPartition.from_pydict( + {"a": [None, -1.0, -0.5, 0.0, 0.5, 2, None], "b": [-1.7, -1.5, -1.3, 0.3, 0.7, None, None]} + ) + + floor_table = table.eval_expression_list([col("a").floor(), col("b").floor()]) + + assert [math.floor(v) if v is not None else v for v in table.get_column("a").to_pylist()] == floor_table.get_column( + "a" + ).to_pylist() + assert [math.floor(v) if v is not None else v for v in table.get_column("b").to_pylist()] == floor_table.get_column( + "b" + ).to_pylist() + + +def test_table_floor_bad_input() -> None: + table = MicroPartition.from_pydict({"a": ["a", "b", "c"]}) + + with pytest.raises(ValueError, match="Expected input to floor to be numeric"): + table.eval_expression_list([col("a").floor()])