From 915467ba5230092ed191be51db8b89aa7219e024 Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Sat, 26 Oct 2024 23:07:06 +0800 Subject: [PATCH] [FEAT] Add floor division (#3064) Close #2418 --- daft/daft/__init__.pyi | 1 + daft/expressions/expressions.py | 10 ++ daft/series.py | 6 ++ src/daft-core/src/array/ops/arithmetic.rs | 71 +++++++++++--- src/daft-core/src/datatypes/infer_datatype.rs | 11 +++ src/daft-core/src/python/series.rs | 4 + src/daft-core/src/series/ops/arithmetic.rs | 51 +++++++++- src/daft-dsl/src/expr/mod.rs | 4 +- src/daft-sql/src/planner.rs | 1 + src/daft-table/src/lib.rs | 2 +- tests/series/test_arithmetic.py | 28 ++++++ tests/sql/test_binary_op_exprs.py | 94 +++++++++++++++++++ tests/table/numeric/test_numeric.py | 1 + 13 files changed, 265 insertions(+), 19 deletions(-) create mode 100644 tests/sql/test_binary_op_exprs.py diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 327b0de833..7edcae7158 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1339,6 +1339,7 @@ class PySeries: def __ne__(self, other: PySeries) -> PySeries: ... # type: ignore[override] def __rshift__(self, other: PySeries) -> PySeries: ... def __lshift__(self, other: PySeries) -> PySeries: ... + def __floordiv__(self, other: PySeries) -> PySeries: ... def take(self, idx: PySeries) -> PySeries: ... def slice(self, start: int, end: int) -> PySeries: ... def filter(self, mask: PySeries) -> PySeries: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 273a73a850..d1b52f6f95 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -486,6 +486,16 @@ def __invert__(self) -> Expression: expr = self._expr.__invert__() return Expression._from_pyexpr(expr) + def __floordiv__(self, other: Expression) -> Expression: + """Floor divides two numeric expressions (``e1 / e2``)""" + expr = Expression._to_expression(other) + return Expression._from_pyexpr(self._expr // expr._expr) + + def __rfloordiv__(self, other: object) -> Expression: + """Reverse floor divides two numeric expressions (``e2 / e1``)""" + expr = Expression._to_expression(other) + return Expression._from_pyexpr(expr._expr // self._expr) + def alias(self, name: builtins.str) -> Expression: """Gives the expression a new name, which is its column's name in the DataFrame schema and the name by which subsequent expressions can refer to the results of this expression. diff --git a/daft/series.py b/daft/series.py index fd85d33f13..97ac5aec9a 100644 --- a/daft/series.py +++ b/daft/series.py @@ -498,6 +498,12 @@ def __xor__(self, other: object) -> Series: assert self._series is not None and other._series is not None return Series._from_pyseries(self._series ^ other._series) + def __floordiv__(self, other: object) -> Series: + if not isinstance(other, Series): + raise TypeError(f"expected another Series but got {type(other)}") + assert self._series is not None and other._series is not None + return Series._from_pyseries(self._series // other._series) + def count(self, mode: CountMode = CountMode.Valid) -> Series: assert self._series is not None return Series._from_pyseries(self._series.count(mode)) diff --git a/src/daft-core/src/array/ops/arithmetic.rs b/src/daft-core/src/array/ops/arithmetic.rs index 365c178a28..c77f4722fa 100644 --- a/src/daft-core/src/array/ops/arithmetic.rs +++ b/src/daft-core/src/array/ops/arithmetic.rs @@ -6,7 +6,7 @@ use common_error::{DaftError, DaftResult}; use super::{as_arrow::AsArrow, full::FullNull}; use crate::{ array::{DataArray, FixedSizeListArray}, - datatypes::{DaftNumericType, DataType, Field, Float64Array, Int64Array, Utf8Array}, + datatypes::{DaftNumericType, DataType, Field, Utf8Array}, kernels::utf8::add_utf8_arrays, series::Series, }; @@ -108,20 +108,6 @@ where } } -impl Div for &Float64Array { - type Output = DaftResult; - fn div(self, rhs: Self) -> Self::Output { - arithmetic_helper(self, rhs, basic::div, |l, r| l / r) - } -} - -impl Div for &Int64Array { - type Output = DaftResult; - fn div(self, rhs: Self) -> Self::Output { - arithmetic_helper(self, rhs, basic::div, |l, r| l / r) - } -} - pub fn binary_with_nulls( lhs: &PrimitiveArray, rhs: &PrimitiveArray, @@ -195,6 +181,61 @@ where } } +fn div_with_nulls(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray +where + T: arrow2::types::NativeType + Div, +{ + binary_with_nulls(lhs, rhs, |a, b| a / b) +} + +impl Div for &DataArray +where + T: DaftNumericType, + T::Native: basic::NativeArithmetics, +{ + type Output = DaftResult>; + fn div(self, rhs: Self) -> Self::Output { + if rhs.data().null_count() == 0 { + arithmetic_helper(self, rhs, basic::div, |l, r| l / r) + } else { + match (self.len(), rhs.len()) { + (a, b) if a == b => Ok(DataArray::from(( + self.name(), + Box::new(div_with_nulls(self.as_arrow(), rhs.as_arrow())), + ))), + // broadcast right path + (_, 1) => { + let opt_rhs = rhs.get(0); + match opt_rhs { + None => Ok(DataArray::full_null( + self.name(), + self.data_type(), + self.len(), + )), + Some(rhs) => self.apply(|lhs| lhs / rhs), + } + } + (1, _) => { + let opt_lhs = self.get(0); + Ok(match opt_lhs { + None => DataArray::full_null(rhs.name(), rhs.data_type(), rhs.len()), + Some(lhs) => { + let values_iter = rhs.as_arrow().iter().map(|v| v.map(|v| lhs / *v)); + let arrow_array = unsafe { + PrimitiveArray::from_trusted_len_iter_unchecked(values_iter) + }; + DataArray::from((self.name(), Box::new(arrow_array))) + } + }) + } + (a, b) => Err(DaftError::ValueError(format!( + "Cannot apply operation on arrays of different lengths: {a} vs {b}" + ))), + } + } + } +} + fn fixed_sized_list_arithmetic_helper( lhs: &FixedSizeListArray, rhs: &FixedSizeListArray, diff --git a/src/daft-core/src/datatypes/infer_datatype.rs b/src/daft-core/src/datatypes/infer_datatype.rs index 020a36ceac..ac175b27af 100644 --- a/src/daft-core/src/datatypes/infer_datatype.rs +++ b/src/daft-core/src/datatypes/infer_datatype.rs @@ -114,6 +114,17 @@ impl<'a> InferDataType<'a> { // membership checks (is_in) use equality checks, so we can use the same logic as comparison ops. self.comparison_op(other) } + + pub fn floor_div(&self, other: &Self) -> DaftResult { + try_numeric_supertype(self.0, other.0).or(match (self.0, other.0) { + #[cfg(feature = "python")] + (DataType::Python, _) | (_, DataType::Python) => Ok(DataType::Python), + _ => Err(DaftError::TypeError(format!( + "Cannot perform floor divide on types: {}, {}", + self, other + ))), + }) + } } impl<'a> Add for InferDataType<'a> { diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index 87304b12d1..d173f18847 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -121,6 +121,10 @@ impl PySeries { Ok(self.series.shift_right(&other.series)?.into()) } + pub fn __floordiv__(&self, other: &Self) -> PyResult { + Ok(self.series.floor_div(&other.series)?.into()) + } + pub fn ceil(&self) -> PyResult { Ok(self.series.ceil()?.into()) } diff --git a/src/daft-core/src/series/ops/arithmetic.rs b/src/daft-core/src/series/ops/arithmetic.rs index b36d730b81..ff92df1023 100644 --- a/src/daft-core/src/series/ops/arithmetic.rs +++ b/src/daft-core/src/series/ops/arithmetic.rs @@ -9,7 +9,7 @@ use crate::{ array::prelude::*, datatypes::{InferDataType, Utf8Array}, series::{utils::cast::cast_downcast_op, IntoSeries, Series}, - with_match_numeric_daft_types, + with_match_integer_daft_types, with_match_numeric_daft_types, }; macro_rules! impl_arithmetic_ref_for_series { @@ -308,6 +308,29 @@ impl Rem for &Series { } } } + +impl Series { + pub fn floor_div(&self, rhs: &Self) -> DaftResult { + let output_type = InferDataType::from(self.data_type()) + .floor_div(&InferDataType::from(rhs.data_type()))?; + let lhs = self; + match &output_type { + #[cfg(feature = "python")] + DataType::Python => run_python_binary_operator_fn(lhs, rhs, "floordiv"), + output_type if output_type.is_integer() => { + with_match_integer_daft_types!(output_type, |$T| { + Ok(cast_downcast_op!(lhs, rhs, output_type, <$T as DaftDataType>::ArrayType, div)?.into_series()) + }) + } + output_type if output_type.is_numeric() => { + let div_floor = lhs.div(rhs)?.floor()?; + div_floor.cast(output_type) + } + _ => arithmetic_op_not_implemented!(self, "floor_div", rhs, output_type), + } + } +} + enum FixedSizeBinaryOp { Add, Sub, @@ -383,7 +406,7 @@ mod tests { use crate::{ array::ops::full::FullNull, - datatypes::{DataType, Float64Array, Int64Array, Utf8Array}, + datatypes::{DataType, Float32Array, Float64Array, Int32Array, Int64Array, Utf8Array}, series::IntoSeries, }; @@ -430,6 +453,30 @@ mod tests { Ok(()) } #[test] + fn floor_div_int_and_int() -> DaftResult<()> { + let a = Int32Array::from(("a", vec![1, 2, 3])); + let b = Int64Array::from(("b", vec![1, 2, 3])); + let c = a.into_series().floor_div(&(b.into_series())); + assert_eq!(*c?.data_type(), DataType::Int64); + Ok(()) + } + #[test] + fn floor_div_int_and_float() -> DaftResult<()> { + let a = Int64Array::from(("a", vec![1, 2, 3])); + let b = Float64Array::from(("b", vec![1., 2., 3.])); + let c = a.into_series().floor_div(&(b.into_series())); + assert_eq!(*c?.data_type(), DataType::Float64); + Ok(()) + } + #[test] + fn floor_div_float_and_float() -> DaftResult<()> { + let a = Float32Array::from(("b", vec![1., 2., 3.])); + let b = Float64Array::from(("b", vec![1., 2., 3.])); + let c = a.into_series().floor_div(&(b.into_series())); + assert_eq!(*c?.data_type(), DataType::Float64); + Ok(()) + } + #[test] fn rem_int_and_float() -> DaftResult<()> { let a = Int64Array::from(("a", vec![1, 2, 3])); let b = Float64Array::from(("b", vec![1., 2., 3.])); diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs index 567a2d35d8..9926309f5b 100644 --- a/src/daft-dsl/src/expr/mod.rs +++ b/src/daft-dsl/src/expr/mod.rs @@ -866,7 +866,9 @@ impl Expr { Ok(Field::new(left_field.name.as_str(), result_type)) } Operator::FloorDivide => { - unimplemented!() + let result_type = (InferDataType::from(&left_field.dtype) + .floor_div(&InferDataType::from(&right_field.dtype)))?; + Ok(Field::new(left_field.name.as_str(), result_type)) } } } diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index d14bc258da..8eb2f40fb0 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -941,6 +941,7 @@ impl SQLPlanner { BinaryOperator::NotEq => Ok(Operator::NotEq), BinaryOperator::And => Ok(Operator::And), BinaryOperator::Or => Ok(Operator::Or), + BinaryOperator::DuckIntegerDivide => Ok(Operator::FloorDivide), other => unsupported_sql_err!("Unsupported operator: '{other}'"), } } diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index e93f4d7a77..701791aaf9 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -530,6 +530,7 @@ impl Table { Plus => lhs + rhs, Minus => lhs - rhs, TrueDivide => lhs / rhs, + FloorDivide => lhs.floor_div(&rhs), Multiply => lhs * rhs, Modulus => lhs % rhs, Lt => Ok(lhs.lt(&rhs)?.into_series()), @@ -543,7 +544,6 @@ impl Table { Xor => lhs.xor(&rhs), ShiftLeft => lhs.shift_left(&rhs), ShiftRight => lhs.shift_right(&rhs), - _ => panic!("{op:?} not supported"), } } Function { func, inputs } => { diff --git a/tests/series/test_arithmetic.py b/tests/series/test_arithmetic.py index 692991f843..fcf60b0b05 100644 --- a/tests/series/test_arithmetic.py +++ b/tests/series/test_arithmetic.py @@ -37,6 +37,10 @@ def test_arithmetic_numbers_array(l_dtype, r_dtype) -> None: assert div.name() == left.name() assert div.to_pylist() == [1.0, 0.5, 3.0, None, None, None] + floor_div = left // right + assert floor_div.name() == left.name() + assert floor_div.to_pylist() == [1, 0, 3, None, None, None] + # mod = (l % r) # assert mod.name() == l.name() # assert mod.to_pylist() == [0, 2, 0, None, None, None] @@ -67,6 +71,10 @@ def test_arithmetic_numbers_left_scalar(l_dtype, r_dtype) -> None: assert div.name() == left.name() assert div.to_pylist() == [1.0, 0.25, 1.0, 0.2, None, None] + floor_div = left // right + assert floor_div.name() == left.name() + assert floor_div.to_pylist() == [1, 0, 1, 0, None, None] + mod = left % right assert mod.name() == left.name() assert mod.to_pylist() == [0, 1, 0, 1, None, None] @@ -97,6 +105,10 @@ def test_arithmetic_numbers_right_scalar(l_dtype, r_dtype) -> None: assert div.name() == left.name() assert div.to_pylist() == [1.0, 2.0, 3.0, None, 5.0, None] + floor_div = left // right + assert floor_div.name() == left.name() + assert floor_div.to_pylist() == [1, 2, 3, None, 5, None] + mod = left % right assert mod.name() == left.name() assert mod.to_pylist() == [0, 0, 0, None, 0, None] @@ -127,6 +139,10 @@ def test_arithmetic_numbers_null_scalar(l_dtype, r_dtype) -> None: assert div.name() == left.name() assert div.to_pylist() == [None, None, None, None, None, None] + floor_div = left / right + assert floor_div.name() == left.name() + assert floor_div.to_pylist() == [None, None, None, None, None, None] + mod = left % right assert mod.name() == left.name() assert mod.to_pylist() == [None, None, None, None, None, None] @@ -207,6 +223,9 @@ def test_comparisons_bad_right_value() -> None: with pytest.raises(TypeError, match="another Series"): left / right + with pytest.raises(TypeError, match="another Series"): + left // right + with pytest.raises(TypeError, match="another Series"): left * right @@ -233,6 +252,9 @@ def test_arithmetic_numbers_array_mismatch_length() -> None: with pytest.raises(ValueError, match="different lengths"): left / right + with pytest.raises(ValueError, match="different lengths"): + left // right + with pytest.raises(ValueError, match="different lengths"): left % right @@ -263,6 +285,11 @@ def __mod__(self, other): other = 5 return 5 % other + def __floordiv__(self, other): + if isinstance(other, FakeFive): + other = 5 + return 5 // other + @pytest.mark.parametrize( ["op", "expected_datatype", "expected", "expected_self"], @@ -272,6 +299,7 @@ def __mod__(self, other): (operator.mul, DataType.int64(), [10, None, None], [25, 25, None]), (operator.truediv, DataType.float64(), [2.5, None, None], [1.0, 1.0, None]), (operator.mod, DataType.int64(), [1, None, None], [0, 0, None]), + (operator.floordiv, DataType.int64(), [2, None, None], [1.0, 1.0, None]), ], ) def test_arithmetic_pyobjects(op, expected_datatype, expected, expected_self) -> None: diff --git a/tests/sql/test_binary_op_exprs.py b/tests/sql/test_binary_op_exprs.py new file mode 100644 index 0000000000..cfc47efb44 --- /dev/null +++ b/tests/sql/test_binary_op_exprs.py @@ -0,0 +1,94 @@ +from typing import Callable + +import pytest + +import daft +from daft.sql import SQLCatalog + + +def _assert_sql_raise(sql: str, catalog: SQLCatalog, error_msg: str) -> None: + with pytest.raises(Exception) as excinfo: + daft.sql(sql, catalog=catalog).collect() + + assert error_msg in str(excinfo.value) + + +def _assert_df_op_raise(func: Callable, error_msg: str) -> None: + with pytest.raises(Exception) as excinfo: + func() + + assert error_msg in str(excinfo.value) + + +def test_div_floor(): + df = daft.from_pydict({"A": [1, 2, 3, 4], "B": [1.5, 2.5, 3.5, 4.5], "C": [4, 5, 6, 7]}) + + actual1 = daft.sql("SELECT (A // 1) AS div_floor FROM df").collect().to_pydict() + actual2 = df.select((daft.col("A") // 1).alias("div_floor")).collect().to_pydict() + expected = { + "div_floor": [1, 2, 3, 4], + } + assert actual1 == actual2 == expected + + actual1 = daft.sql("SELECT (A // B) AS div_floor FROM df").collect().to_pydict() + actual2 = df.select((daft.col("A") // daft.col("B")).alias("div_floor")).collect().to_pydict() + expected = { + "div_floor": [0, 0, 0, 0], + } + assert actual1 == actual2 == expected + + actual1 = daft.sql("SELECT (B // A) AS div_floor FROM df").collect().to_pydict() + actual2 = df.select((daft.col("B") // daft.col("A")).alias("div_floor")).collect().to_pydict() + expected = { + "div_floor": [1.0, 1.0, 1.0, 1.0], + } + assert actual1 == actual2 == expected + + actual1 = daft.sql("SELECT (C // A) AS div_floor FROM df").collect().to_pydict() + actual2 = df.select((daft.col("C") // daft.col("A")).alias("div_floor")).collect().to_pydict() + expected = { + "div_floor": [4, 2, 2, 1], + } + assert actual1 == actual2 == expected + + +def test_unsupported_div_floor(): + df = daft.from_pydict({"A": [1, 2, 3, 4], "B": [1.5, 2.5, 3.5, 4.5], "C": [True, False, True, True]}) + + catalog = SQLCatalog({"df": df}) + + _assert_sql_raise( + "SELECT A // C FROM df", catalog, "TypeError Cannot perform floor divide on types: Int64, Boolean" + ) + + _assert_sql_raise( + "SELECT C // A FROM df", catalog, "TypeError Cannot perform floor divide on types: Boolean, Int64" + ) + + _assert_sql_raise( + "SELECT B // C FROM df", catalog, "TypeError Cannot perform floor divide on types: Float64, Boolean" + ) + + _assert_sql_raise( + "SELECT B // C FROM df", catalog, "TypeError Cannot perform floor divide on types: Float64, Boolean" + ) + + _assert_df_op_raise( + lambda: df.select(daft.col("A") // daft.col("C")).collect(), + "TypeError Cannot perform floor divide on types: Int64, Boolean", + ) + + _assert_df_op_raise( + lambda: df.select(daft.col("C") // daft.col("A")).collect(), + "TypeError Cannot perform floor divide on types: Boolean, Int64", + ) + + _assert_df_op_raise( + lambda: df.select(daft.col("B") // daft.col("C")).collect(), + "TypeError Cannot perform floor divide on types: Float64, Boolean", + ) + + _assert_df_op_raise( + lambda: df.select(daft.col("C") // daft.col("B")).collect(), + "TypeError Cannot perform floor divide on types: Boolean, Float64", + ) diff --git a/tests/table/numeric/test_numeric.py b/tests/table/numeric/test_numeric.py index 28003f89b8..5459399538 100644 --- a/tests/table/numeric/test_numeric.py +++ b/tests/table/numeric/test_numeric.py @@ -17,6 +17,7 @@ ops.sub, ops.mul, ops.truediv, + ops.floordiv, ops.mod, ops.lt, ops.le,