diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 7edcae7158..eda8f03156 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1001,6 +1001,7 @@ class PyDataType: def python() -> PyDataType: ... def to_arrow(self, cast_tensor_type_for_ray: builtins.bool | None = None) -> pa.DataType: ... def is_numeric(self) -> builtins.bool: ... + def is_integer(self) -> builtins.bool: ... def is_image(self) -> builtins.bool: ... def is_fixed_shape_image(self) -> builtins.bool: ... def is_list(self) -> builtins.bool: ... diff --git a/daft/datatype.py b/daft/datatype.py index 3039121e96..5244231bd4 100644 --- a/daft/datatype.py +++ b/daft/datatype.py @@ -501,6 +501,9 @@ def _is_fixed_shape_image_type(self) -> builtins.bool: def _is_numeric_type(self) -> builtins.bool: return self._dtype.is_numeric() + def _is_integer(self) -> builtins.bool: + return self._dtype.is_integer() + def _is_list(self) -> builtins.bool: return self._dtype.is_list() diff --git a/src/daft-core/src/datatypes/infer_datatype.rs b/src/daft-core/src/datatypes/infer_datatype.rs index ac175b27af..51b28a665b 100644 --- a/src/daft-core/src/datatypes/infer_datatype.rs +++ b/src/daft-core/src/datatypes/infer_datatype.rs @@ -280,7 +280,7 @@ impl<'a> Rem for InferDataType<'a> { type Output = DaftResult; fn rem(self, other: Self) -> Self::Output { - try_numeric_supertype(self.0, other.0) + try_integer_widen_for_rem(self.0, other.0) .or_else(|_| { try_fixed_shape_numeric_datatype(self.0, other.0, |l, r| { InferDataType::from(l) % InferDataType::from(r) @@ -290,7 +290,7 @@ impl<'a> Rem for InferDataType<'a> { #[cfg(feature = "python")] (DataType::Python, _) | (_, DataType::Python) => Ok(DataType::Python), _ => Err(DaftError::TypeError(format!( - "Cannot multiply types: {}, {}", + "Cannot modulus types: {}, {}", self, other ))), }) @@ -435,6 +435,87 @@ pub fn try_numeric_supertype(l: &DataType, r: &DataType) -> DaftResult ))) } +pub fn try_integer_widen_for_rem(l: &DataType, r: &DataType) -> DaftResult { + // If given two integer data types, + // get the integer type that they should both be casted to + // for the purpose of performing widening. + + fn inner(l: &DataType, r: &DataType) -> Option { + match (l, r) { + (DataType::Float64, other) | (other, DataType::Float64) if other.is_numeric() => { + Some(DataType::Float64) + } + (DataType::Float32, other) | (other, DataType::Float32) if other.is_numeric() => { + Some(DataType::Float32) + } + + (DataType::Int8, DataType::Int8) => Some(DataType::Int8), + (DataType::Int8, DataType::Int16) => Some(DataType::Int16), + (DataType::Int8, DataType::Int32) => Some(DataType::Int32), + (DataType::Int8, DataType::Int64) => Some(DataType::Int64), + (DataType::Int8, DataType::UInt8) => Some(DataType::UInt8), + (DataType::Int8, DataType::UInt16) => Some(DataType::UInt16), + (DataType::Int8, DataType::UInt32) => Some(DataType::UInt32), + (DataType::Int8, DataType::UInt64) => Some(DataType::UInt64), + + (DataType::Int16, DataType::Int8) => Some(DataType::Int16), + (DataType::Int16, DataType::Int16) => Some(DataType::Int16), + (DataType::Int16, DataType::Int32) => Some(DataType::Int32), + (DataType::Int16, DataType::Int64) => Some(DataType::Int64), + (DataType::Int16, DataType::UInt8) => Some(DataType::UInt16), + (DataType::Int16, DataType::UInt16) => Some(DataType::UInt16), + (DataType::Int16, DataType::UInt32) => Some(DataType::UInt32), + (DataType::Int16, DataType::UInt64) => Some(DataType::UInt64), + + (DataType::Int32, DataType::Int8) => Some(DataType::Int32), + (DataType::Int32, DataType::Int16) => Some(DataType::Int32), + (DataType::Int32, DataType::Int32) => Some(DataType::Int32), + (DataType::Int32, DataType::Int64) => Some(DataType::Int64), + (DataType::Int32, DataType::UInt8) => Some(DataType::UInt32), + (DataType::Int32, DataType::UInt16) => Some(DataType::UInt32), + (DataType::Int32, DataType::UInt32) => Some(DataType::UInt32), + (DataType::Int32, DataType::UInt64) => Some(DataType::UInt64), + + (DataType::Int64, DataType::Int8) => Some(DataType::Int64), + (DataType::Int64, DataType::Int16) => Some(DataType::Int64), + (DataType::Int64, DataType::Int32) => Some(DataType::Int64), + (DataType::Int64, DataType::Int64) => Some(DataType::Int64), + (DataType::Int64, DataType::UInt8) => Some(DataType::UInt64), + (DataType::Int64, DataType::UInt16) => Some(DataType::UInt64), + (DataType::Int64, DataType::UInt32) => Some(DataType::UInt64), + (DataType::Int64, DataType::UInt64) => Some(DataType::UInt64), + + (DataType::UInt8, DataType::UInt8) => Some(DataType::UInt8), + (DataType::UInt8, DataType::UInt16) => Some(DataType::UInt16), + (DataType::UInt8, DataType::UInt32) => Some(DataType::UInt32), + (DataType::UInt8, DataType::UInt64) => Some(DataType::UInt64), + + (DataType::UInt16, DataType::UInt8) => Some(DataType::UInt16), + (DataType::UInt16, DataType::UInt16) => Some(DataType::UInt16), + (DataType::UInt16, DataType::UInt32) => Some(DataType::UInt32), + (DataType::UInt16, DataType::UInt64) => Some(DataType::UInt64), + + (DataType::UInt32, DataType::UInt8) => Some(DataType::UInt32), + (DataType::UInt32, DataType::UInt16) => Some(DataType::UInt32), + (DataType::UInt32, DataType::UInt32) => Some(DataType::UInt32), + (DataType::UInt32, DataType::UInt64) => Some(DataType::UInt64), + + (DataType::UInt64, DataType::UInt8) => Some(DataType::UInt64), + (DataType::UInt64, DataType::UInt16) => Some(DataType::UInt64), + (DataType::UInt64, DataType::UInt32) => Some(DataType::UInt64), + (DataType::UInt64, DataType::UInt64) => Some(DataType::UInt64), + _ => None, + } + } + + inner(l, r) + .or_else(|| inner(r, l)) + .ok_or(DaftError::TypeError(format!( + "Invalid arguments to integer widening: {}, {}", + l, r + ))) +} + pub fn try_fixed_shape_numeric_datatype( l: &DataType, r: &DataType, diff --git a/src/daft-core/src/series/ops/arithmetic.rs b/src/daft-core/src/series/ops/arithmetic.rs index ff92df1023..baf5087fb6 100644 --- a/src/daft-core/src/series/ops/arithmetic.rs +++ b/src/daft-core/src/series/ops/arithmetic.rs @@ -292,6 +292,7 @@ impl Rem for &Series { fn rem(self, rhs: Self) -> Self::Output { let output_type = InferDataType::from(self.data_type()).rem(InferDataType::from(rhs.data_type()))?; + let lhs = self; match &output_type { #[cfg(feature = "python")] diff --git a/src/daft-schema/src/python/datatype.rs b/src/daft-schema/src/python/datatype.rs index dcd772a516..2aad609ad7 100644 --- a/src/daft-schema/src/python/datatype.rs +++ b/src/daft-schema/src/python/datatype.rs @@ -349,6 +349,10 @@ impl PyDataType { Ok(self.dtype.is_numeric()) } + pub fn is_integer(&self) -> PyResult { + Ok(self.dtype.is_integer()) + } + pub fn is_image(&self) -> PyResult { Ok(self.dtype.is_image()) } diff --git a/tests/series/test_arithmetic.py b/tests/series/test_arithmetic.py index fcf60b0b05..eb369f232c 100644 --- a/tests/series/test_arithmetic.py +++ b/tests/series/test_arithmetic.py @@ -37,14 +37,14 @@ def test_arithmetic_numbers_array(l_dtype, r_dtype) -> None: assert div.name() == left.name() assert div.to_pylist() == [1.0, 0.5, 3.0, None, None, None] + mod = left % right + assert mod.name() == left.name() + assert mod.to_pylist() == [0, 2, 0, None, None, None] + floor_div = left // right assert floor_div.name() == left.name() assert floor_div.to_pylist() == [1, 0, 3, None, None, None] - # mod = (l % r) - # assert mod.name() == l.name() - # assert mod.to_pylist() == [0, 2, 0, None, None, None] - @pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_float_types, repeat=2)) def test_arithmetic_numbers_left_scalar(l_dtype, r_dtype) -> None: @@ -309,3 +309,17 @@ def test_arithmetic_pyobjects(op, expected_datatype, expected, expected_self) -> assert op(fake_fives, values).datatype() == expected_datatype assert op(fake_fives, values).to_pylist() == expected assert op(fake_fives, fake_fives).to_pylist() == expected_self + + +@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types, repeat=2)) +def test_mod_series(l_dtype, r_dtype) -> None: + l_arrow = pa.array([1, 2, 3, None, 5, None]) + r_arrow = pa.array([1, 4, 1, 5, None, None]) + + left = Series.from_arrow(l_arrow.cast(l_dtype), name="left") + right = Series.from_arrow(r_arrow.cast(r_dtype), name="right") + + mod = left % right + assert mod.name() == left.name() + assert mod.datatype()._is_integer() + assert mod.to_pylist() == [0, 2, 0, None, None, None]