Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Add ceil function #1867

Merged
merged 1 commit into from
Feb 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,7 @@ class PyExpr:
def _is_column(self) -> bool: ...
def alias(self, name: str) -> PyExpr: ...
def cast(self, dtype: PyDataType) -> PyExpr: ...
def ceil(self) -> PyExpr: ...
def if_else(self, if_true: PyExpr, if_false: PyExpr) -> PyExpr: ...
def count(self, mode: CountMode) -> PyExpr: ...
def sum(self) -> PyExpr: ...
Expand Down Expand Up @@ -940,6 +941,7 @@ class PySeries:
def _max(self) -> PySeries: ...
def _agg_list(self) -> PySeries: ...
def cast(self, dtype: PyDataType) -> PySeries: ...
def ceil(self) -> PySeries: ...
@staticmethod
def concat(series: list[PySeries]) -> PySeries: ...
def __len__(self) -> int: ...
Expand Down
5 changes: 5 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,11 @@ def cast(self, dtype: DataType) -> Expression:
expr = self._expr.cast(dtype._dtype)
return Expression._from_pyexpr(expr)

def ceil(self) -> Expression:
"""The ceiling of a numeric expression (``expr.ceil()``)"""
expr = self._expr.ceil()
return Expression._from_pyexpr(expr)

def _count(self, mode: CountMode = CountMode.Valid) -> Expression:
expr = self._expr.count(mode)
return Expression._from_pyexpr(expr)
Expand Down
3 changes: 3 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,9 @@ def size_bytes(self) -> int:
def __abs__(self) -> Series:
return Series._from_pyseries(abs(self._series))

def ceil(self) -> Series:
return Series._from_pyseries(self._series.ceil())

def __add__(self, other: object) -> Series:
if not isinstance(other, Series):
raise TypeError(f"expected another Series but got {type(other)}")
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Numeric
Expression.__mul__
Expression.__truediv__
Expression.__mod__
Expression.ceil

.. _api-comparison-expression:

Expand Down
18 changes: 18 additions & 0 deletions src/daft-core/src/array/ops/ceil.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use num_traits::Float;

use crate::{
array::DataArray,
datatypes::{DaftFloatType, DaftNumericType},
};

use common_error::DaftResult;

impl<T: DaftFloatType> DataArray<T>
where
T: DaftNumericType,
T::Native: Float,
{
pub fn ceil(&self) -> DaftResult<Self> {
self.apply(|v| v.ceil())
}
}
1 change: 1 addition & 0 deletions src/daft-core/src/array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub mod arrow2;
pub mod as_arrow;
pub(crate) mod broadcast;
pub(crate) mod cast;
mod ceil;
mod compare_agg;
mod comparison;
mod concat;
Expand Down
4 changes: 4 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ impl PySeries {
Ok(self.series.xor(&other.series)?.into_series().into())
}

pub fn ceil(&self) -> PyResult<Self> {
Ok(self.series.ceil()?.into())
}

pub fn take(&self, idx: &Self) -> PyResult<Self> {
Ok(self.series.take(&idx.series)?.into())
}
Expand Down
20 changes: 20 additions & 0 deletions src/daft-core/src/series/ops/ceil.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use crate::datatypes::DataType;
use crate::series::Series;
use common_error::DaftError;
use common_error::DaftResult;
impl Series {
pub fn ceil(&self) -> DaftResult<Series> {
use crate::series::array_impl::IntoSeries;

use DataType::*;
match self.data_type() {
Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => Ok(self.clone()),
Float32 => Ok(self.f32().unwrap().ceil()?.into_series()),
Float64 => Ok(self.f64().unwrap().ceil()?.into_series()),
dt => Err(DaftError::TypeError(format!(
"ceil not implemented for {}",
dt
))),
}
}
}
1 change: 1 addition & 0 deletions src/daft-core/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub mod agg;
pub mod arithmetic;
pub mod broadcast;
pub mod cast;
pub mod ceil;
pub mod comparison;
pub mod concat;
pub mod date;
Expand Down
41 changes: 41 additions & 0 deletions src/daft-dsl/src/functions/numeric/ceil.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use common_error::{DaftError, DaftResult};
use daft_core::{datatypes::Field, schema::Schema, series::Series};

use crate::Expr;

use super::super::FunctionEvaluator;

pub(super) struct CeilEvaluator {}

impl FunctionEvaluator for CeilEvaluator {
fn fn_name(&self) -> &'static str {
"ceil"
}

fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult<Field> {
if inputs.len() != 1 {
return Err(DaftError::SchemaMismatch(format!(
"Expected 1 input arg, got {}",
inputs.len()
)));
}
let field = inputs.first().unwrap().to_field(schema)?;
if !field.dtype.is_numeric() {
return Err(DaftError::TypeError(format!(
"Expected input to ceil to be numeric, got {}",
field.dtype
)));
}
Ok(field)
}

fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult<Series> {
if inputs.len() != 1 {
return Err(DaftError::ValueError(format!(
"Expected 1 input arg, got {}",
inputs.len()
)));
}
inputs.first().unwrap().ceil()
}
}
12 changes: 12 additions & 0 deletions src/daft-dsl/src/functions/numeric/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
mod abs;
mod ceil;

use abs::AbsEvaluator;
use ceil::CeilEvaluator;

use serde::{Deserialize, Serialize};

use crate::Expr;
Expand All @@ -10,6 +13,7 @@ use super::FunctionEvaluator;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum NumericExpr {
Abs,
Ceil,
}

impl NumericExpr {
Expand All @@ -18,6 +22,7 @@ impl NumericExpr {
use NumericExpr::*;
match self {
Abs => &AbsEvaluator {},
Ceil => &CeilEvaluator {},
}
}
}
Expand All @@ -28,3 +33,10 @@ pub fn abs(input: &Expr) -> Expr {
inputs: vec![input.clone()],
}
}

pub fn ceil(input: &Expr) -> Expr {
Expr::Function {
func: super::FunctionExpr::Numeric(NumericExpr::Ceil),
inputs: vec![input.clone()],
}
}
5 changes: 5 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@ impl PyExpr {
Ok(self.expr.cast(&dtype.into()).into())
}

pub fn ceil(&self) -> PyResult<Self> {
use functions::numeric::ceil;
Ok(ceil(&self.expr).into())
}

pub fn if_else(&self, if_true: &Self, if_false: &Self) -> PyResult<Self> {
Ok(self.expr.if_else(&if_true.expr, &if_false.expr).into())
}
Expand Down
9 changes: 9 additions & 0 deletions tests/expressions/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,15 @@ def test_repr_functions_abs() -> None:
assert repr_out == repr(copied)


def test_repr_functions_ceil() -> None:
a = col("a")
y = a.ceil()
repr_out = repr(y)
assert repr_out == "ceil(col(a))"
copied = copy.deepcopy(y)
assert repr_out == repr(copied)


def test_repr_functions_day() -> None:
a = col("a")
y = a.dt.day()
Expand Down
10 changes: 10 additions & 0 deletions tests/expressions/typing/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,13 @@ def test_abs(unary_data_fixture):
run_kernel=lambda: abs(arg),
resolvable=is_numeric(arg.datatype()),
)


def test_ceil(unary_data_fixture):
arg = unary_data_fixture
assert_typing_resolve_vs_runtime_behavior(
data=(unary_data_fixture,),
expr=col(arg.name()).ceil(),
run_kernel=lambda: arg.ceil(),
resolvable=is_numeric(arg.datatype()),
)
23 changes: 23 additions & 0 deletions tests/table/test_eval.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import itertools
import math
import operator as ops

import pyarrow as pa
Expand Down Expand Up @@ -157,3 +158,25 @@ def test_table_abs_bad_input() -> None:

with pytest.raises(ValueError, match="Expected input to abs to be numeric"):
table.eval_expression_list([abs(col("a"))])


def test_table_numeric_ceil() -> None:
table = MicroPartition.from_pydict(
{"a": [None, -1.0, -0.5, 0, 0.5, 2, None], "b": [-1.7, -1.5, -1.3, 0.3, 0.7, None, None]}
)

ceil_table = table.eval_expression_list([col("a").ceil(), col("b").ceil()])

assert [math.ceil(v) if v is not None else v for v in table.get_column("a").to_pylist()] == ceil_table.get_column(
"a"
).to_pylist()
assert [math.ceil(v) if v is not None else v for v in table.get_column("b").to_pylist()] == ceil_table.get_column(
"b"
).to_pylist()


def test_table_ceil_bad_input() -> None:
table = MicroPartition.from_pydict({"a": ["a", "b", "c"]})

with pytest.raises(ValueError, match="Expected input to ceil to be numeric"):
table.eval_expression_list([col("a").ceil()])
Loading