Skip to content

Commit

Permalink
[BUG] fix type widening for rem
Browse files Browse the repository at this point in the history
  • Loading branch information
samster25 committed Oct 26, 2024
1 parent 5b450fb commit 80c53e7
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 2 deletions.
78 changes: 76 additions & 2 deletions src/daft-core/src/datatypes/infer_datatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ impl<'a> Rem for InferDataType<'a> {
type Output = DaftResult<DataType>;

fn rem(self, other: Self) -> Self::Output {
try_numeric_supertype(self.0, other.0)
try_integer_widen_for_rem(self.0, other.0)
.or_else(|_| {
try_fixed_shape_numeric_datatype(self.0, other.0, |l, r| {
InferDataType::from(l) % InferDataType::from(r)
Expand All @@ -279,7 +279,7 @@ impl<'a> Rem for InferDataType<'a> {
#[cfg(feature = "python")]
(DataType::Python, _) | (_, DataType::Python) => Ok(DataType::Python),
_ => Err(DaftError::TypeError(format!(
"Cannot multiply types: {}, {}",
"Cannot modulus types: {}, {}",
self, other
))),
})
Expand Down Expand Up @@ -424,6 +424,80 @@ pub fn try_numeric_supertype(l: &DataType, r: &DataType) -> DaftResult<DataType>
)))
}

pub fn try_integer_widen_for_rem(l: &DataType, r: &DataType) -> DaftResult<DataType> {
// If given two integer data types,
// get the integer type that they should both be casted to
// for the purpose of performing widening.

fn inner(l: &DataType, r: &DataType) -> Option<DataType> {
match (l, r) {
(DataType::Int8, DataType::Int8) => Some(DataType::Int8),
(DataType::Int8, DataType::Int16) => Some(DataType::Int16),
(DataType::Int8, DataType::Int32) => Some(DataType::Int32),
(DataType::Int8, DataType::Int64) => Some(DataType::Int64),
(DataType::Int8, DataType::UInt8) => Some(DataType::UInt8),
(DataType::Int8, DataType::UInt16) => Some(DataType::UInt16),
(DataType::Int8, DataType::UInt32) => Some(DataType::UInt32),
(DataType::Int8, DataType::UInt64) => Some(DataType::UInt64),

(DataType::Int16, DataType::Int8) => Some(DataType::Int16),
(DataType::Int16, DataType::Int16) => Some(DataType::Int16),
(DataType::Int16, DataType::Int32) => Some(DataType::Int32),
(DataType::Int16, DataType::Int64) => Some(DataType::Int64),
(DataType::Int16, DataType::UInt8) => Some(DataType::UInt16),
(DataType::Int16, DataType::UInt16) => Some(DataType::UInt16),
(DataType::Int16, DataType::UInt32) => Some(DataType::UInt32),
(DataType::Int16, DataType::UInt64) => Some(DataType::UInt64),

(DataType::Int32, DataType::Int8) => Some(DataType::Int32),
(DataType::Int32, DataType::Int16) => Some(DataType::Int32),
(DataType::Int32, DataType::Int32) => Some(DataType::Int32),
(DataType::Int32, DataType::Int64) => Some(DataType::Int64),
(DataType::Int32, DataType::UInt8) => Some(DataType::UInt32),
(DataType::Int32, DataType::UInt16) => Some(DataType::UInt32),
(DataType::Int32, DataType::UInt32) => Some(DataType::UInt32),
(DataType::Int32, DataType::UInt64) => Some(DataType::UInt64),

(DataType::Int64, DataType::Int8) => Some(DataType::Int64),
(DataType::Int64, DataType::Int16) => Some(DataType::Int64),
(DataType::Int64, DataType::Int32) => Some(DataType::Int64),
(DataType::Int64, DataType::Int64) => Some(DataType::Int64),
(DataType::Int64, DataType::UInt8) => Some(DataType::UInt64),
(DataType::Int64, DataType::UInt16) => Some(DataType::UInt64),
(DataType::Int64, DataType::UInt32) => Some(DataType::UInt64),
(DataType::Int64, DataType::UInt64) => Some(DataType::UInt64),

(DataType::UInt8, DataType::UInt8) => Some(DataType::UInt8),
(DataType::UInt8, DataType::UInt16) => Some(DataType::UInt16),
(DataType::UInt8, DataType::UInt32) => Some(DataType::UInt32),
(DataType::UInt8, DataType::UInt64) => Some(DataType::UInt64),

(DataType::UInt16, DataType::UInt8) => Some(DataType::UInt16),
(DataType::UInt16, DataType::UInt16) => Some(DataType::UInt16),
(DataType::UInt16, DataType::UInt32) => Some(DataType::UInt32),
(DataType::UInt16, DataType::UInt64) => Some(DataType::UInt64),

(DataType::UInt32, DataType::UInt8) => Some(DataType::UInt32),
(DataType::UInt32, DataType::UInt16) => Some(DataType::UInt32),
(DataType::UInt32, DataType::UInt32) => Some(DataType::UInt32),
(DataType::UInt32, DataType::UInt64) => Some(DataType::UInt64),

(DataType::UInt64, DataType::UInt8) => Some(DataType::UInt64),
(DataType::UInt64, DataType::UInt16) => Some(DataType::UInt64),
(DataType::UInt64, DataType::UInt32) => Some(DataType::UInt64),
(DataType::UInt64, DataType::UInt64) => Some(DataType::UInt64),
_ => None,
}
}

inner(l, r)
.or_else(|| inner(r, l))
.ok_or(DaftError::TypeError(format!(
"Invalid arguments to integer widening: {}, {}",
l, r
)))
}

pub fn try_fixed_shape_numeric_datatype<F>(
l: &DataType,
r: &DataType,
Expand Down
1 change: 1 addition & 0 deletions src/daft-core/src/series/ops/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ impl Rem for &Series {
fn rem(self, rhs: Self) -> Self::Output {
let output_type =
InferDataType::from(self.data_type()).rem(InferDataType::from(rhs.data_type()))?;

let lhs = self;
match &output_type {
#[cfg(feature = "python")]
Expand Down

0 comments on commit 80c53e7

Please sign in to comment.