Skip to content

Commit

Permalink
Add ceil function
Browse files Browse the repository at this point in the history
  • Loading branch information
NormallyGaussian committed Feb 11, 2024
1 parent 9b6cb94 commit 1d1076e
Show file tree
Hide file tree
Showing 15 changed files with 155 additions and 0 deletions.
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,7 @@ class PyExpr:
def _is_column(self) -> bool: ...
def alias(self, name: str) -> PyExpr: ...
def cast(self, dtype: PyDataType) -> PyExpr: ...
def ceil(self) -> PyExpr: ...
def if_else(self, if_true: PyExpr, if_false: PyExpr) -> PyExpr: ...
def count(self, mode: CountMode) -> PyExpr: ...
def sum(self) -> PyExpr: ...
Expand Down Expand Up @@ -940,6 +941,7 @@ class PySeries:
def _max(self) -> PySeries: ...
def _agg_list(self) -> PySeries: ...
def cast(self, dtype: PyDataType) -> PySeries: ...
def ceil(self) -> PySeries: ...
@staticmethod
def concat(series: list[PySeries]) -> PySeries: ...
def __len__(self) -> int: ...
Expand Down
5 changes: 5 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,11 @@ def cast(self, dtype: DataType) -> Expression:
expr = self._expr.cast(dtype._dtype)
return Expression._from_pyexpr(expr)

def ceil(self) -> Expression:
    """Build an expression computing the ceiling of this numeric expression (``expr.ceil()``).

    Returns:
        Expression: a new expression wrapping the ``ceil`` of the underlying expression.
    """
    return Expression._from_pyexpr(self._expr.ceil())

def _count(self, mode: CountMode = CountMode.Valid) -> Expression:
expr = self._expr.count(mode)
return Expression._from_pyexpr(expr)
Expand Down
3 changes: 3 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,9 @@ def size_bytes(self) -> int:
def __abs__(self) -> Series:
return Series._from_pyseries(abs(self._series))

def ceil(self) -> Series:
    """Return a new Series built from the ``ceil`` of the underlying series."""
    ceiled = self._series.ceil()
    return Series._from_pyseries(ceiled)

def __add__(self, other: object) -> Series:
if not isinstance(other, Series):
raise TypeError(f"expected another Series but got {type(other)}")
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Numeric
Expression.__mul__
Expression.__truediv__
Expression.__mod__
Expression.ceil

.. _api-comparison-expression:

Expand Down
18 changes: 18 additions & 0 deletions src/daft-core/src/array/ops/ceil.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use num_traits::Float;

use crate::{
array::DataArray,
datatypes::{DaftFloatType, DaftNumericType},
};

use common_error::DaftResult;

impl<T> DataArray<T>
where
    T: DaftFloatType + DaftNumericType,
    T::Native: Float,
{
    /// Returns a new array produced by passing `Float::ceil` over this array
    /// through [`DataArray::apply`].
    ///
    /// Any error reported by `apply` is returned unchanged.
    pub fn ceil(&self) -> DaftResult<Self> {
        self.apply(|value| value.ceil())
    }
}
1 change: 1 addition & 0 deletions src/daft-core/src/array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub mod arrow2;
pub mod as_arrow;
pub(crate) mod broadcast;
pub(crate) mod cast;
mod ceil;
mod compare_agg;
mod comparison;
mod concat;
Expand Down
4 changes: 4 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ impl PySeries {
Ok(self.series.xor(&other.series)?.into_series().into())
}

/// Python-facing wrapper around `Series::ceil`; errors from the underlying
/// call are propagated with `?`.
pub fn ceil(&self) -> PyResult<Self> {
    let ceiled = self.series.ceil()?;
    Ok(ceiled.into())
}

pub fn take(&self, idx: &Self) -> PyResult<Self> {
Ok(self.series.take(&idx.series)?.into())
}
Expand Down
20 changes: 20 additions & 0 deletions src/daft-core/src/series/ops/ceil.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use crate::datatypes::DataType;
use crate::series::Series;
use common_error::DaftError;
use common_error::DaftResult;
impl Series {
pub fn ceil(&self) -> DaftResult<Series> {
use crate::series::array_impl::IntoSeries;

use DataType::*;
match self.data_type() {
Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 => Ok(self.clone()),
Float32 => Ok(self.f32().unwrap().ceil()?.into_series()),
Float64 => Ok(self.f64().unwrap().ceil()?.into_series()),
dt => Err(DaftError::TypeError(format!(
"ceil not implemented for {}",
dt
))),
}
}
}
1 change: 1 addition & 0 deletions src/daft-core/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub mod agg;
pub mod arithmetic;
pub mod broadcast;
pub mod cast;
pub mod ceil;
pub mod comparison;
pub mod concat;
pub mod date;
Expand Down
41 changes: 41 additions & 0 deletions src/daft-dsl/src/functions/numeric/ceil.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use common_error::{DaftError, DaftResult};
use daft_core::{datatypes::Field, schema::Schema, series::Series};

use crate::Expr;

use super::super::FunctionEvaluator;

/// Evaluator for the `ceil` numeric function.
pub(super) struct CeilEvaluator {}

impl FunctionEvaluator for CeilEvaluator {
    /// Name reported for this function.
    fn fn_name(&self) -> &'static str {
        "ceil"
    }

    /// Resolves the output field: exactly one input is required and its dtype
    /// must be numeric; the input's own field is returned unchanged.
    ///
    /// # Errors
    ///
    /// `DaftError::SchemaMismatch` when the number of inputs is not 1,
    /// `DaftError::TypeError` when the input dtype is not numeric.
    fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult<Field> {
        match inputs {
            [input] => {
                let field = input.to_field(schema)?;
                if field.dtype.is_numeric() {
                    Ok(field)
                } else {
                    Err(DaftError::TypeError(format!(
                        "Expected input to ceil to be numeric, got {}",
                        field.dtype
                    )))
                }
            }
            _ => Err(DaftError::SchemaMismatch(format!(
                "Expected 1 input arg, got {}",
                inputs.len()
            ))),
        }
    }

    /// Evaluates `ceil` on the single input series.
    ///
    /// # Errors
    ///
    /// `DaftError::ValueError` when the number of inputs is not 1; otherwise
    /// whatever `Series::ceil` returns.
    fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult<Series> {
        match inputs {
            [series] => series.ceil(),
            _ => Err(DaftError::ValueError(format!(
                "Expected 1 input arg, got {}",
                inputs.len()
            ))),
        }
    }
}
12 changes: 12 additions & 0 deletions src/daft-dsl/src/functions/numeric/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
mod abs;
mod ceil;

use abs::AbsEvaluator;
use ceil::CeilEvaluator;

use serde::{Deserialize, Serialize};

use crate::Expr;
Expand All @@ -10,6 +13,7 @@ use super::FunctionEvaluator;
/// Numeric functions that can appear in a function expression.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum NumericExpr {
/// Evaluated by `AbsEvaluator`.
Abs,
/// Evaluated by `CeilEvaluator`.
Ceil,
}

impl NumericExpr {
Expand All @@ -18,6 +22,7 @@ impl NumericExpr {
use NumericExpr::*;
match self {
Abs => &AbsEvaluator {},
Ceil => &CeilEvaluator {},
}
}
}
Expand All @@ -28,3 +33,10 @@ pub fn abs(input: &Expr) -> Expr {
inputs: vec![input.clone()],
}
}

/// Builds an `Expr::Function` that applies `NumericExpr::Ceil` to a clone of `input`.
pub fn ceil(input: &Expr) -> Expr {
    let func = super::FunctionExpr::Numeric(NumericExpr::Ceil);
    let inputs = vec![input.clone()];
    Expr::Function { func, inputs }
}
5 changes: 5 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@ impl PyExpr {
Ok(self.expr.cast(&dtype.into()).into())
}

/// Wraps this expression in the numeric `ceil` function and returns it as a new `PyExpr`.
pub fn ceil(&self) -> PyResult<Self> {
    let ceiled = functions::numeric::ceil(&self.expr);
    Ok(ceiled.into())
}

pub fn if_else(&self, if_true: &Self, if_false: &Self) -> PyResult<Self> {
Ok(self.expr.if_else(&if_true.expr, &if_false.expr).into())
}
Expand Down
9 changes: 9 additions & 0 deletions tests/expressions/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,15 @@ def test_repr_functions_abs() -> None:
assert repr_out == repr(copied)


def test_repr_functions_ceil() -> None:
    """``repr`` of a ceil expression is ``ceil(col(a))`` and survives a deep copy."""
    ceil_expr = col("a").ceil()
    rendered = repr(ceil_expr)
    assert rendered == "ceil(col(a))"
    # A deep copy must render exactly like the original expression.
    duplicate = copy.deepcopy(ceil_expr)
    assert rendered == repr(duplicate)


def test_repr_functions_day() -> None:
a = col("a")
y = a.dt.day()
Expand Down
10 changes: 10 additions & 0 deletions tests/expressions/typing/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,13 @@ def test_abs(unary_data_fixture):
run_kernel=lambda: abs(arg),
resolvable=is_numeric(arg.datatype()),
)


def test_ceil(unary_data_fixture):
    """Type resolution of ``ceil`` must agree with runtime behavior: resolvable only for numeric input."""
    series = unary_data_fixture
    ceil_expr = col(series.name()).ceil()
    assert_typing_resolve_vs_runtime_behavior(
        data=(series,),
        expr=ceil_expr,
        run_kernel=lambda: series.ceil(),
        resolvable=is_numeric(series.datatype()),
    )
23 changes: 23 additions & 0 deletions tests/table/test_eval.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import itertools
import math
import operator as ops

import pyarrow as pa
Expand Down Expand Up @@ -157,3 +158,25 @@ def test_table_abs_bad_input() -> None:

with pytest.raises(ValueError, match="Expected input to abs to be numeric"):
table.eval_expression_list([abs(col("a"))])


def test_table_numeric_ceil() -> None:
    """``ceil`` on numeric columns matches ``math.ceil`` element-wise and keeps ``None`` entries."""
    table = MicroPartition.from_pydict(
        {"a": [None, -1.0, -0.5, 0, 0.5, 2, None], "b": [-1.7, -1.5, -1.3, 0.3, 0.7, None, None]}
    )

    ceil_table = table.eval_expression_list([col("a").ceil(), col("b").ceil()])

    # Check each column in turn against the Python reference implementation.
    for column in ("a", "b"):
        expected = [math.ceil(v) if v is not None else v for v in table.get_column(column).to_pylist()]
        assert expected == ceil_table.get_column(column).to_pylist()


def test_table_ceil_bad_input() -> None:
    """Evaluating ``ceil`` on a string column raises ``ValueError`` with the numeric-input message."""
    string_table = MicroPartition.from_pydict({"a": ["a", "b", "c"]})
    expected_message = "Expected input to ceil to be numeric"

    with pytest.raises(ValueError, match=expected_message):
        string_table.eval_expression_list([col("a").ceil()])

0 comments on commit 1d1076e

Please sign in to comment.