From 80c53e7a8e2aa3fd6c968e8c814d08d3d0f27107 Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Fri, 25 Oct 2024 20:55:27 -0700 Subject: [PATCH] [BUG] fix type widening for rem --- src/daft-core/src/datatypes/infer_datatype.rs | 78 ++++++++++++++++++- src/daft-core/src/series/ops/arithmetic.rs | 1 + 2 files changed, 77 insertions(+), 2 deletions(-) diff --git a/src/daft-core/src/datatypes/infer_datatype.rs b/src/daft-core/src/datatypes/infer_datatype.rs index 020a36ceac..c5867d2d17 100644 --- a/src/daft-core/src/datatypes/infer_datatype.rs +++ b/src/daft-core/src/datatypes/infer_datatype.rs @@ -269,7 +269,7 @@ impl<'a> Rem for InferDataType<'a> { type Output = DaftResult; fn rem(self, other: Self) -> Self::Output { - try_numeric_supertype(self.0, other.0) + try_integer_widen_for_rem(self.0, other.0) .or_else(|_| { try_fixed_shape_numeric_datatype(self.0, other.0, |l, r| { InferDataType::from(l) % InferDataType::from(r) @@ -279,7 +279,7 @@ impl<'a> Rem for InferDataType<'a> { #[cfg(feature = "python")] (DataType::Python, _) | (_, DataType::Python) => Ok(DataType::Python), _ => Err(DaftError::TypeError(format!( - "Cannot multiply types: {}, {}", + "Cannot modulus types: {}, {}", self, other ))), }) @@ -424,6 +424,80 @@ pub fn try_numeric_supertype(l: &DataType, r: &DataType) -> DaftResult ))) } +pub fn try_integer_widen_for_rem(l: &DataType, r: &DataType) -> DaftResult { + // If given two integer data types, + // get the integer type that they should both be casted to + // for the purpose of performing widening. + + fn inner(l: &DataType, r: &DataType) -> Option { + match (l, r) { + (DataType::Int8, DataType::Int8) => Some(DataType::Int8), + (DataType::Int8, DataType::Int16) => Some(DataType::Int16), + (DataType::Int8, DataType::Int32) => Some(DataType::Int32), + (DataType::Int8, DataType::Int64) => Some(DataType::Int64), + (DataType::Int8, DataType::UInt8) => Some(DataType::UInt8), + (DataType::Int8, DataType::UInt16) => Some(DataType::UInt16), + (DataType::Int8, DataType::UInt32) => Some(DataType::UInt32), + (DataType::Int8, DataType::UInt64) => Some(DataType::UInt64), + + (DataType::Int16, DataType::Int8) => Some(DataType::Int16), + (DataType::Int16, DataType::Int16) => Some(DataType::Int16), + (DataType::Int16, DataType::Int32) => Some(DataType::Int32), + (DataType::Int16, DataType::Int64) => Some(DataType::Int64), + (DataType::Int16, DataType::UInt8) => Some(DataType::UInt16), + (DataType::Int16, DataType::UInt16) => Some(DataType::UInt16), + (DataType::Int16, DataType::UInt32) => Some(DataType::UInt32), + (DataType::Int16, DataType::UInt64) => Some(DataType::UInt64), + + (DataType::Int32, DataType::Int8) => Some(DataType::Int32), + (DataType::Int32, DataType::Int16) => Some(DataType::Int32), + (DataType::Int32, DataType::Int32) => Some(DataType::Int32), + (DataType::Int32, DataType::Int64) => Some(DataType::Int64), + (DataType::Int32, DataType::UInt8) => Some(DataType::UInt32), + (DataType::Int32, DataType::UInt16) => Some(DataType::UInt32), + (DataType::Int32, DataType::UInt32) => Some(DataType::UInt32), + (DataType::Int32, DataType::UInt64) => Some(DataType::UInt64), + + (DataType::Int64, DataType::Int8) => Some(DataType::Int64), + (DataType::Int64, DataType::Int16) => Some(DataType::Int64), + (DataType::Int64, DataType::Int32) => Some(DataType::Int64), + (DataType::Int64, DataType::Int64) => Some(DataType::Int64), + (DataType::Int64, DataType::UInt8) => Some(DataType::UInt64), + (DataType::Int64, DataType::UInt16) => Some(DataType::UInt64), + (DataType::Int64, DataType::UInt32) => Some(DataType::UInt64), + (DataType::Int64, DataType::UInt64) => Some(DataType::UInt64), + + (DataType::UInt8, DataType::UInt8) => Some(DataType::UInt8), + (DataType::UInt8, DataType::UInt16) => Some(DataType::UInt16), + (DataType::UInt8, DataType::UInt32) => Some(DataType::UInt32), + (DataType::UInt8, DataType::UInt64) => Some(DataType::UInt64), + + (DataType::UInt16, DataType::UInt8) => Some(DataType::UInt16), + (DataType::UInt16, DataType::UInt16) => Some(DataType::UInt16), + (DataType::UInt16, DataType::UInt32) => Some(DataType::UInt32), + (DataType::UInt16, DataType::UInt64) => Some(DataType::UInt64), + + (DataType::UInt32, DataType::UInt8) => Some(DataType::UInt32), + (DataType::UInt32, DataType::UInt16) => Some(DataType::UInt32), + (DataType::UInt32, DataType::UInt32) => Some(DataType::UInt32), + (DataType::UInt32, DataType::UInt64) => Some(DataType::UInt64), + + (DataType::UInt64, DataType::UInt8) => Some(DataType::UInt64), + (DataType::UInt64, DataType::UInt16) => Some(DataType::UInt64), + (DataType::UInt64, DataType::UInt32) => Some(DataType::UInt64), + (DataType::UInt64, DataType::UInt64) => Some(DataType::UInt64), + _ => None, + } + } + + inner(l, r) + .or_else(|| inner(r, l)) + .ok_or(DaftError::TypeError(format!( + "Invalid arguments to integer widening: {}, {}", + l, r + ))) +} + pub fn try_fixed_shape_numeric_datatype( l: &DataType, r: &DataType, diff --git a/src/daft-core/src/series/ops/arithmetic.rs b/src/daft-core/src/series/ops/arithmetic.rs index b36d730b81..812435b381 100644 --- a/src/daft-core/src/series/ops/arithmetic.rs +++ b/src/daft-core/src/series/ops/arithmetic.rs @@ -292,6 +292,7 @@ impl Rem for &Series { fn rem(self, rhs: Self) -> Self::Output { let output_type = InferDataType::from(self.data_type()).rem(InferDataType::from(rhs.data_type()))?; + let lhs = self; match &output_type { #[cfg(feature = "python")]