Skip to content

Commit

Permalink
[FEAT] Sign expression implemtation (#2037)
Browse files Browse the repository at this point in the history
implementation for sign resolves #1917

---------

Co-authored-by: Colin <[email protected]>
  • Loading branch information
sherlockbeard and colin-ho authored Mar 25, 2024
1 parent fe53733 commit b43c3c8
Show file tree
Hide file tree
Showing 15 changed files with 184 additions and 0 deletions.
2 changes: 2 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,7 @@ class PyExpr:
def cast(self, dtype: PyDataType) -> PyExpr: ...
def ceil(self) -> PyExpr: ...
def floor(self) -> PyExpr: ...
def sign(self) -> PyExpr: ...
def if_else(self, if_true: PyExpr, if_false: PyExpr) -> PyExpr: ...
def count(self, mode: CountMode) -> PyExpr: ...
def sum(self) -> PyExpr: ...
Expand Down Expand Up @@ -1007,6 +1008,7 @@ class PySeries:
def cast(self, dtype: PyDataType) -> PySeries: ...
def ceil(self) -> PySeries: ...
def floor(self) -> PySeries: ...
def sign(self) -> PySeries: ...
@staticmethod
def concat(series: list[PySeries]) -> PySeries: ...
def __len__(self) -> int: ...
Expand Down
5 changes: 5 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,11 @@ def floor(self) -> Expression:
expr = self._expr.floor()
return Expression._from_pyexpr(expr)

def sign(self) -> Expression:
"""The sign of a numeric expression (``expr.sign()``)"""
expr = self._expr.sign()
return Expression._from_pyexpr(expr)

def count(self, mode: CountMode = CountMode.Valid) -> Expression:
"""Counts the number of values in the expression.
Expand Down
3 changes: 3 additions & 0 deletions daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,9 @@ def ceil(self) -> Series:
def floor(self) -> Series:
return Series._from_pyseries(self._series.floor())

def sign(self) -> Series:
return Series._from_pyseries(self._series.sign())

def __add__(self, other: object) -> Series:
if not isinstance(other, Series):
raise TypeError(f"expected another Series but got {type(other)}")
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_docs/expressions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ Numeric
Expression.__mod__
Expression.ceil
Expression.floor
Expression.sign

.. _api-comparison-expression:

Expand Down
1 change: 1 addition & 0 deletions src/daft-core/src/array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ mod null;
mod pairwise;
mod repr;
mod search_sorted;
mod sign;
mod sort;
mod struct_;
mod sum;
Expand Down
30 changes: 30 additions & 0 deletions src/daft-core/src/array/ops/sign.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use crate::{array::DataArray, datatypes::DaftNumericType};
use num_traits::Signed;
use num_traits::Unsigned;
use num_traits::{One, Zero};

use common_error::DaftResult;

impl<T: DaftNumericType> DataArray<T>
where
T::Native: Signed,
{
pub fn sign(&self) -> DaftResult<Self> {
self.apply(|v| v.signum())
}
}

impl<T: DaftNumericType> DataArray<T>
where
T::Native: Unsigned,
{
pub fn sign_unsigned(&self) -> DaftResult<Self> {
self.apply(|v| {
if v.is_zero() {
T::Native::zero()
} else {
T::Native::one()
}
})
}
}
4 changes: 4 additions & 0 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ impl PySeries {
Ok(self.series.floor()?.into())
}

pub fn sign(&self) -> PyResult<Self> {
Ok(self.series.sign()?.into())
}

pub fn take(&self, idx: &Self) -> PyResult<Self> {
Ok(self.series.take(&idx.series)?.into())
}
Expand Down
1 change: 1 addition & 0 deletions src/daft-core/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ pub mod not;
pub mod null;
pub mod partitioning;
pub mod search_sorted;
pub mod sign;
pub mod sort;
pub mod struct_;
pub mod take;
Expand Down
27 changes: 27 additions & 0 deletions src/daft-core/src/series/ops/sign.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
use crate::datatypes::DataType;
use crate::series::Series;
use common_error::DaftError;
use common_error::DaftResult;

impl Series {
pub fn sign(&self) -> DaftResult<Series> {
use crate::series::array_impl::IntoSeries;
use DataType::*;
match self.data_type() {
UInt8 => Ok(self.u8().unwrap().sign_unsigned()?.into_series()),
UInt16 => Ok(self.u16().unwrap().sign_unsigned()?.into_series()),
UInt32 => Ok(self.u32().unwrap().sign_unsigned()?.into_series()),
UInt64 => Ok(self.u64().unwrap().sign_unsigned()?.into_series()),
Int8 => Ok(self.i8().unwrap().sign()?.into_series()),
Int16 => Ok(self.i16().unwrap().sign()?.into_series()),
Int32 => Ok(self.i32().unwrap().sign()?.into_series()),
Int64 => Ok(self.i64().unwrap().sign()?.into_series()),
Float32 => Ok(self.f32().unwrap().sign()?.into_series()),
Float64 => Ok(self.f64().unwrap().sign()?.into_series()),
dt => Err(DaftError::TypeError(format!(
"sign not implemented for {}",
dt
))),
}
}
}
11 changes: 11 additions & 0 deletions src/daft-dsl/src/functions/numeric/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
mod abs;
mod ceil;
mod floor;
mod sign;

use abs::AbsEvaluator;
use ceil::CeilEvaluator;
use floor::FloorEvaluator;
use sign::SignEvaluator;

use serde::{Deserialize, Serialize};

Expand All @@ -17,6 +19,7 @@ pub enum NumericExpr {
Abs,
Ceil,
Floor,
Sign,
}

impl NumericExpr {
Expand All @@ -27,6 +30,7 @@ impl NumericExpr {
Abs => &AbsEvaluator {},
Ceil => &CeilEvaluator {},
Floor => &FloorEvaluator {},
Sign => &SignEvaluator {},
}
}
}
Expand All @@ -51,3 +55,10 @@ pub fn floor(input: &Expr) -> Expr {
inputs: vec![input.clone()],
}
}

pub fn sign(input: &Expr) -> Expr {
Expr::Function {
func: super::FunctionExpr::Numeric(NumericExpr::Sign),
inputs: vec![input.clone()],
}
}
40 changes: 40 additions & 0 deletions src/daft-dsl/src/functions/numeric/sign.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use common_error::{DaftError, DaftResult};
use daft_core::{datatypes::Field, schema::Schema, series::Series};

use super::super::FunctionEvaluator;
use crate::Expr;

pub(super) struct SignEvaluator {}

impl FunctionEvaluator for SignEvaluator {
fn fn_name(&self) -> &'static str {
"sign"
}

fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult<Field> {
if inputs.len() != 1 {
return Err(DaftError::SchemaMismatch(format!(
"Expected 1 input arg, got {}",
inputs.len()
)));
}
let field = inputs.first().unwrap().to_field(schema)?;
if !field.dtype.is_numeric() {
return Err(DaftError::TypeError(format!(
"Expected input to sign to be numeric, got {}",
field.dtype
)));
}
Ok(field)
}

fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult<Series> {
if inputs.len() != 1 {
return Err(DaftError::SchemaMismatch(format!(
"Expected 1 input arg, got {}",
inputs.len()
)));
}
inputs.first().unwrap().sign()
}
}
5 changes: 5 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,11 @@ impl PyExpr {
Ok(floor(&self.expr).into())
}

pub fn sign(&self) -> PyResult<Self> {
use functions::numeric::sign;
Ok(sign(&self.expr).into())
}

pub fn if_else(&self, if_true: &Self, if_false: &Self) -> PyResult<Self> {
Ok(self.expr.if_else(&if_true.expr, &if_false.expr).into())
}
Expand Down
9 changes: 9 additions & 0 deletions tests/expressions/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,15 @@ def test_repr_functions_floor() -> None:
assert repr_out == repr(copied)


def test_repr_functions_sign() -> None:
a = col("a")
y = a.sign()
repr_out = repr(y)
assert repr_out == "sign(col(a))"
copied = copy.deepcopy(y)
assert repr_out == repr(copied)


def test_repr_functions_day() -> None:
a = col("a")
y = a.dt.day()
Expand Down
10 changes: 10 additions & 0 deletions tests/expressions/typing/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,13 @@ def test_floor(unary_data_fixture):
run_kernel=lambda: arg.floor(),
resolvable=is_numeric(arg.datatype()),
)


def test_sign(unary_data_fixture):
arg = unary_data_fixture
assert_typing_resolve_vs_runtime_behavior(
data=(unary_data_fixture,),
expr=col(arg.name()).sign(),
run_kernel=lambda: arg.sign(),
resolvable=is_numeric(arg.datatype()),
)
35 changes: 35 additions & 0 deletions tests/table/test_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,38 @@ def test_table_floor_bad_input() -> None:

with pytest.raises(ValueError, match="Expected input to floor to be numeric"):
table.eval_expression_list([col("a").floor()])


def test_table_numeric_sign() -> None:
table = MicroPartition.from_pydict(
{"a": [None, -1, -5, 0, 5, 2, None], "b": [-1.7, -1.5, -1.3, 0.3, 0.7, None, None]}
)
my_schema = pa.schema([pa.field("uint8", pa.uint8())])
table_Unsign = MicroPartition.from_arrow(pa.Table.from_arrays([pa.array([None, 0, 1, 2, 3])], schema=my_schema))

sign_table = table.eval_expression_list([col("a").sign(), col("b").sign()])
unsign_sign_table = table_Unsign.eval_expression_list([col("uint8").sign()])

def checkSign(val):
if val < 0:
return -1
if val > 0:
return 1
return 0

assert [checkSign(v) if v is not None else v for v in table.get_column("a").to_pylist()] == sign_table.get_column(
"a"
).to_pylist()
assert [checkSign(v) if v is not None else v for v in table.get_column("b").to_pylist()] == sign_table.get_column(
"b"
).to_pylist()
assert [
checkSign(v) if v is not None else v for v in table_Unsign.get_column("uint8").to_pylist()
] == unsign_sign_table.get_column("uint8").to_pylist()


def test_table_sign_bad_input() -> None:
table = MicroPartition.from_pydict({"a": ["a", "b", "c"]})

with pytest.raises(ValueError, match="Expected input to sign to be numeric"):
table.eval_expression_list([col("a").sign()])

0 comments on commit b43c3c8

Please sign in to comment.