diff --git a/daft/daft.pyi b/daft/daft.pyi index 2dae790022..a9f4d38721 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -883,6 +883,7 @@ class PyExpr: def cast(self, dtype: PyDataType) -> PyExpr: ... def ceil(self) -> PyExpr: ... def floor(self) -> PyExpr: ... + def sign(self) -> PyExpr: ... def if_else(self, if_true: PyExpr, if_false: PyExpr) -> PyExpr: ... def count(self, mode: CountMode) -> PyExpr: ... def sum(self) -> PyExpr: ... @@ -1007,6 +1008,7 @@ class PySeries: def cast(self, dtype: PyDataType) -> PySeries: ... def ceil(self) -> PySeries: ... def floor(self) -> PySeries: ... + def sign(self) -> PySeries: ... @staticmethod def concat(series: list[PySeries]) -> PySeries: ... def __len__(self) -> int: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index a395346e3e..8b46f90285 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -342,6 +342,11 @@ def floor(self) -> Expression: expr = self._expr.floor() return Expression._from_pyexpr(expr) + def sign(self) -> Expression: + """The sign of a numeric expression (``expr.sign()``)""" + expr = self._expr.sign() + return Expression._from_pyexpr(expr) + def count(self, mode: CountMode = CountMode.Valid) -> Expression: """Counts the number of values in the expression. diff --git a/daft/series.py b/daft/series.py index 53644e0802..73670307b1 100644 --- a/daft/series.py +++ b/daft/series.py @@ -361,6 +361,9 @@ def ceil(self) -> Series: def floor(self) -> Series: return Series._from_pyseries(self._series.floor()) + def sign(self) -> Series: + return Series._from_pyseries(self._series.sign()) + def __add__(self, other: object) -> Series: if not isinstance(other, Series): raise TypeError(f"expected another Series but got {type(other)}") diff --git a/docs/source/api_docs/expressions.rst b/docs/source/api_docs/expressions.rst index 65ece7f10b..7fdb8185a7 100644 --- a/docs/source/api_docs/expressions.rst +++ b/docs/source/api_docs/expressions.rst @@ -48,6 +48,7 @@ Numeric Expression.__mod__ Expression.ceil Expression.floor + Expression.sign .. _api-comparison-expression: diff --git a/src/daft-core/src/array/ops/mod.rs b/src/daft-core/src/array/ops/mod.rs index c3a9ee1385..3cf1e17d91 100644 --- a/src/daft-core/src/array/ops/mod.rs +++ b/src/daft-core/src/array/ops/mod.rs @@ -33,6 +33,7 @@ mod null; mod pairwise; mod repr; mod search_sorted; +mod sign; mod sort; mod struct_; mod sum; diff --git a/src/daft-core/src/array/ops/sign.rs b/src/daft-core/src/array/ops/sign.rs new file mode 100644 index 0000000000..c3be34bb92 --- /dev/null +++ b/src/daft-core/src/array/ops/sign.rs @@ -0,0 +1,30 @@ +use crate::{array::DataArray, datatypes::DaftNumericType}; +use num_traits::Signed; +use num_traits::Unsigned; +use num_traits::{One, Zero}; + +use common_error::DaftResult; + +impl DataArray +where + T::Native: Signed, +{ + pub fn sign(&self) -> DaftResult { + self.apply(|v| v.signum()) + } +} + +impl DataArray +where + T::Native: Unsigned, +{ + pub fn sign_unsigned(&self) -> DaftResult { + self.apply(|v| { + if v.is_zero() { + T::Native::zero() + } else { + T::Native::one() + } + }) + } +} diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index 7af7507bfe..9b264aadc6 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -116,6 +116,10 @@ impl PySeries { Ok(self.series.floor()?.into()) } + pub fn sign(&self) -> PyResult { + Ok(self.series.sign()?.into()) + } + pub fn take(&self, idx: &Self) -> PyResult { Ok(self.series.take(&idx.series)?.into()) } diff --git a/src/daft-core/src/series/ops/mod.rs b/src/daft-core/src/series/ops/mod.rs index 4b34aef5f8..09ff8c13ea 100644 --- a/src/daft-core/src/series/ops/mod.rs +++ b/src/daft-core/src/series/ops/mod.rs @@ -28,6 +28,7 @@ pub mod not; pub mod null; pub mod partitioning; pub mod search_sorted; +pub mod sign; pub mod sort; pub mod struct_; pub mod take; diff --git a/src/daft-core/src/series/ops/sign.rs b/src/daft-core/src/series/ops/sign.rs new file mode 100644 index 0000000000..e12de0726b --- /dev/null +++ b/src/daft-core/src/series/ops/sign.rs @@ -0,0 +1,27 @@ +use crate::datatypes::DataType; +use crate::series::Series; +use common_error::DaftError; +use common_error::DaftResult; + +impl Series { + pub fn sign(&self) -> DaftResult { + use crate::series::array_impl::IntoSeries; + use DataType::*; + match self.data_type() { + UInt8 => Ok(self.u8().unwrap().sign_unsigned()?.into_series()), + UInt16 => Ok(self.u16().unwrap().sign_unsigned()?.into_series()), + UInt32 => Ok(self.u32().unwrap().sign_unsigned()?.into_series()), + UInt64 => Ok(self.u64().unwrap().sign_unsigned()?.into_series()), + Int8 => Ok(self.i8().unwrap().sign()?.into_series()), + Int16 => Ok(self.i16().unwrap().sign()?.into_series()), + Int32 => Ok(self.i32().unwrap().sign()?.into_series()), + Int64 => Ok(self.i64().unwrap().sign()?.into_series()), + Float32 => Ok(self.f32().unwrap().sign()?.into_series()), + Float64 => Ok(self.f64().unwrap().sign()?.into_series()), + dt => Err(DaftError::TypeError(format!( + "sign not implemented for {}", + dt + ))), + } + } +} diff --git a/src/daft-dsl/src/functions/numeric/mod.rs b/src/daft-dsl/src/functions/numeric/mod.rs index 7b2eb8447b..91888ed59a 100644 --- a/src/daft-dsl/src/functions/numeric/mod.rs +++ b/src/daft-dsl/src/functions/numeric/mod.rs @@ -1,10 +1,12 @@ mod abs; mod ceil; mod floor; +mod sign; use abs::AbsEvaluator; use ceil::CeilEvaluator; use floor::FloorEvaluator; +use sign::SignEvaluator; use serde::{Deserialize, Serialize}; @@ -17,6 +19,7 @@ pub enum NumericExpr { Abs, Ceil, Floor, + Sign, } impl NumericExpr { @@ -27,6 +30,7 @@ impl NumericExpr { Abs => &AbsEvaluator {}, Ceil => &CeilEvaluator {}, Floor => &FloorEvaluator {}, + Sign => &SignEvaluator {}, } } } @@ -51,3 +55,10 @@ pub fn floor(input: &Expr) -> Expr { inputs: vec![input.clone()], } } + +pub fn sign(input: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::Numeric(NumericExpr::Sign), + inputs: vec![input.clone()], + } +} diff --git a/src/daft-dsl/src/functions/numeric/sign.rs b/src/daft-dsl/src/functions/numeric/sign.rs new file mode 100644 index 0000000000..e801b300a3 --- /dev/null +++ b/src/daft-dsl/src/functions/numeric/sign.rs @@ -0,0 +1,40 @@ +use common_error::{DaftError, DaftResult}; +use daft_core::{datatypes::Field, schema::Schema, series::Series}; + +use super::super::FunctionEvaluator; +use crate::Expr; + +pub(super) struct SignEvaluator {} + +impl FunctionEvaluator for SignEvaluator { + fn fn_name(&self) -> &'static str { + "sign" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { + if inputs.len() != 1 { + return Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + } + let field = inputs.first().unwrap().to_field(schema)?; + if !field.dtype.is_numeric() { + return Err(DaftError::TypeError(format!( + "Expected input to sign to be numeric, got {}", + field.dtype + ))); + } + Ok(field) + } + + fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult { + if inputs.len() != 1 { + return Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))); + } + inputs.first().unwrap().sign() + } +} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 508dc2c471..9cbd3ac418 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -210,6 +210,11 @@ impl PyExpr { Ok(floor(&self.expr).into()) } + pub fn sign(&self) -> PyResult { + use functions::numeric::sign; + Ok(sign(&self.expr).into()) + } + pub fn if_else(&self, if_true: &Self, if_false: &Self) -> PyResult { Ok(self.expr.if_else(&if_true.expr, &if_false.expr).into()) } diff --git a/tests/expressions/test_expressions.py b/tests/expressions/test_expressions.py index 61e7f498c1..811f9ae41b 100644 --- a/tests/expressions/test_expressions.py +++ b/tests/expressions/test_expressions.py @@ -105,6 +105,15 @@ def test_repr_functions_floor() -> None: assert repr_out == repr(copied) +def test_repr_functions_sign() -> None: + a = col("a") + y = a.sign() + repr_out = repr(y) + assert repr_out == "sign(col(a))" + copied = copy.deepcopy(y) + assert repr_out == repr(copied) + + def test_repr_functions_day() -> None: a = col("a") y = a.dt.day() diff --git a/tests/expressions/typing/test_arithmetic.py b/tests/expressions/typing/test_arithmetic.py index c973265c87..c4b1c07f39 100644 --- a/tests/expressions/typing/test_arithmetic.py +++ b/tests/expressions/typing/test_arithmetic.py @@ -92,3 +92,13 @@ def test_floor(unary_data_fixture): run_kernel=lambda: arg.floor(), resolvable=is_numeric(arg.datatype()), ) + + +def test_sign(unary_data_fixture): + arg = unary_data_fixture + assert_typing_resolve_vs_runtime_behavior( + data=(unary_data_fixture,), + expr=col(arg.name()).sign(), + run_kernel=lambda: arg.sign(), + resolvable=is_numeric(arg.datatype()), + ) diff --git a/tests/table/test_eval.py b/tests/table/test_eval.py index 68c8ff0b4e..0a4b37e335 100644 --- a/tests/table/test_eval.py +++ b/tests/table/test_eval.py @@ -202,3 +202,38 @@ def test_table_floor_bad_input() -> None: with pytest.raises(ValueError, match="Expected input to floor to be numeric"): table.eval_expression_list([col("a").floor()]) + + +def test_table_numeric_sign() -> None: + table = MicroPartition.from_pydict( + {"a": [None, -1, -5, 0, 5, 2, None], "b": [-1.7, -1.5, -1.3, 0.3, 0.7, None, None]} + ) + my_schema = pa.schema([pa.field("uint8", pa.uint8())]) + table_Unsign = MicroPartition.from_arrow(pa.Table.from_arrays([pa.array([None, 0, 1, 2, 3])], schema=my_schema)) + + sign_table = table.eval_expression_list([col("a").sign(), col("b").sign()]) + unsign_sign_table = table_Unsign.eval_expression_list([col("uint8").sign()]) + + def checkSign(val): + if val < 0: + return -1 + if val > 0: + return 1 + return 0 + + assert [checkSign(v) if v is not None else v for v in table.get_column("a").to_pylist()] == sign_table.get_column( + "a" + ).to_pylist() + assert [checkSign(v) if v is not None else v for v in table.get_column("b").to_pylist()] == sign_table.get_column( + "b" + ).to_pylist() + assert [ + checkSign(v) if v is not None else v for v in table_Unsign.get_column("uint8").to_pylist() + ] == unsign_sign_table.get_column("uint8").to_pylist() + + +def test_table_sign_bad_input() -> None: + table = MicroPartition.from_pydict({"a": ["a", "b", "c"]}) + + with pytest.raises(ValueError, match="Expected input to sign to be numeric"): + table.eval_expression_list([col("a").sign()])