diff --git a/src/arrow2/src/compute/arithmetics/decimal/mul.rs b/src/arrow2/src/compute/arithmetics/decimal/mul.rs index f050952e95..b092834422 100644 --- a/src/arrow2/src/compute/arithmetics/decimal/mul.rs +++ b/src/arrow2/src/compute/arithmetics/decimal/mul.rs @@ -36,7 +36,6 @@ use super::{adjusted_precision_scale, get_parameters, max_value, number_digits}; /// ``` pub fn mul(lhs: &PrimitiveArray, rhs: &PrimitiveArray) -> PrimitiveArray { let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap(); - let scale = 10i128.pow(scale as u32); let max = max_value(precision); diff --git a/src/daft-core/src/array/from_iter.rs b/src/daft-core/src/array/from_iter.rs index ee1bb90b13..0867598c79 100644 --- a/src/daft-core/src/array/from_iter.rs +++ b/src/daft-core/src/array/from_iter.rs @@ -1,19 +1,44 @@ -use arrow2::types::months_days_ns; +use std::sync::Arc; + +use arrow2::{ + array::{MutablePrimitiveArray, PrimitiveArray}, + types::months_days_ns, +}; use super::DataArray; -use crate::{array::prelude::*, datatypes::prelude::*}; +use crate::{ + array::prelude::*, + datatypes::{prelude::*, DaftPrimitiveType}, +}; impl DataArray where - T: DaftNumericType, + T: DaftPrimitiveType, { - pub fn from_iter( - name: &str, + pub fn from_iter>>( + field: F, iter: impl arrow2::trusted_len::TrustedLen>, ) -> Self { - let arrow_array = - Box::new(arrow2::array::PrimitiveArray::::from_trusted_len_iter(iter)); - Self::new(Field::new(name, T::get_dtype()).into(), arrow_array).unwrap() + // this is a workaround to prevent overflow issues when dealing with i128 and decimal + // typical behavior would be the result array would always be Decimal(32, 32) + let field = field.into(); + let mut array = MutablePrimitiveArray::::from(field.dtype.to_arrow().unwrap()); + array.extend_trusted_len(iter); + let data_array: PrimitiveArray<_> = array.into(); + Self::new(field, data_array.boxed()).unwrap() + } + + pub fn from_values_iter>>( + field: F, + iter: impl arrow2::trusted_len::TrustedLen, + ) -> Self { + // this is a workaround to prevent overflow issues when dealing with i128 and decimal + // typical behavior would be the result array would always be Decimal(32, 32) + let field = field.into(); + let mut array = MutablePrimitiveArray::::from(field.dtype.to_arrow().unwrap()); + array.extend_trusted_len_values(iter); + let data_array: PrimitiveArray<_> = array.into(); + Self::new(field, data_array.boxed()).unwrap() } } diff --git a/src/daft-core/src/array/growable/arrow_growable.rs b/src/daft-core/src/array/growable/arrow_growable.rs index 4524c0143d..480b3f85ba 100644 --- a/src/daft-core/src/array/growable/arrow_growable.rs +++ b/src/daft-core/src/array/growable/arrow_growable.rs @@ -112,11 +112,6 @@ impl_arrow_backed_data_array_growable!( Int64Type, arrow2::array::growable::GrowablePrimitive<'a, i64> ); -impl_arrow_backed_data_array_growable!( - ArrowInt128Growable, - Int128Type, - arrow2::array::growable::GrowablePrimitive<'a, i128> -); impl_arrow_backed_data_array_growable!( ArrowUInt8Growable, UInt8Type, @@ -168,6 +163,12 @@ impl_arrow_backed_data_array_growable!( arrow2::array::growable::GrowablePrimitive<'a, months_days_ns> ); +impl_arrow_backed_data_array_growable!( + ArrowDecimal128Growable, + Decimal128Type, + arrow2::array::growable::GrowablePrimitive<'a, i128> +); + /// ExtensionTypes are slightly different, because they have a dynamic inner type pub struct ArrowExtensionGrowable<'a> { name: String, diff --git a/src/daft-core/src/array/growable/logical_growable.rs b/src/daft-core/src/array/growable/logical_growable.rs index aaab91dca4..92f7abac66 100644 --- a/src/daft-core/src/array/growable/logical_growable.rs +++ b/src/daft-core/src/array/growable/logical_growable.rs @@ -83,6 +83,5 @@ impl_logical_growable!( FixedShapeSparseTensorType ); impl_logical_growable!(LogicalImageGrowable, ImageType); -impl_logical_growable!(LogicalDecimal128Growable, Decimal128Type); impl_logical_growable!(LogicalTensorGrowable, TensorType); impl_logical_growable!(LogicalMapGrowable, MapType); diff --git a/src/daft-core/src/array/growable/mod.rs b/src/daft-core/src/array/growable/mod.rs index 8513820b17..6f6e044e01 100644 --- a/src/daft-core/src/array/growable/mod.rs +++ b/src/daft-core/src/array/growable/mod.rs @@ -164,7 +164,7 @@ impl_growable_array!(Int8Array, arrow_growable::ArrowInt8Growable<'a>); impl_growable_array!(Int16Array, arrow_growable::ArrowInt16Growable<'a>); impl_growable_array!(Int32Array, arrow_growable::ArrowInt32Growable<'a>); impl_growable_array!(Int64Array, arrow_growable::ArrowInt64Growable<'a>); -impl_growable_array!(Int128Array, arrow_growable::ArrowInt128Growable<'a>); +impl_growable_array!(Decimal128Array, arrow_growable::ArrowDecimal128Growable<'a>); impl_growable_array!(UInt8Array, arrow_growable::ArrowUInt8Growable<'a>); impl_growable_array!(UInt16Array, arrow_growable::ArrowUInt16Growable<'a>); impl_growable_array!(UInt32Array, arrow_growable::ArrowUInt32Growable<'a>); @@ -218,8 +218,4 @@ impl_growable_array!( ); impl_growable_array!(ImageArray, logical_growable::LogicalImageGrowable<'a>); impl_growable_array!(TensorArray, logical_growable::LogicalTensorGrowable<'a>); -impl_growable_array!( - Decimal128Array, - logical_growable::LogicalDecimal128Growable<'a> -); impl_growable_array!(MapArray, logical_growable::LogicalMapGrowable<'a>); diff --git a/src/daft-core/src/array/ops/apply.rs b/src/daft-core/src/array/ops/apply.rs index f5388bfbc6..904ab0a057 100644 --- a/src/daft-core/src/array/ops/apply.rs +++ b/src/daft-core/src/array/ops/apply.rs @@ -4,11 +4,15 @@ use arrow2::array::PrimitiveArray; use common_error::{DaftError, DaftResult}; use super::full::FullNull; -use crate::{array::DataArray, datatypes::DaftNumericType, utils::arrow::arrow_bitmap_and_helper}; +use crate::{ + array::DataArray, + datatypes::{DaftNumericType, DaftPrimitiveType}, + utils::arrow::arrow_bitmap_and_helper, +}; impl DataArray where - T: DaftNumericType, + T: DaftPrimitiveType, { // applies a native function to a numeric DataArray maintaining validity of the source array. pub fn apply(&self, func: F) -> DaftResult @@ -16,11 +20,9 @@ where F: Fn(T::Native) -> T::Native + Copy, { let arr: &PrimitiveArray = self.data().as_any().downcast_ref().unwrap(); - let result_arr = - PrimitiveArray::from_trusted_len_values_iter(arr.values_iter().map(|v| func(*v))) - .with_validity(arr.validity().cloned()); + let iter = arr.values_iter().map(|v| func(*v)); - Ok(Self::from((self.name(), Box::new(result_arr)))) + Self::from_values_iter(self.field.clone(), iter).with_validity(arr.validity().cloned()) } // applies a native binary function to two DataArrays, maintaining validity. @@ -40,11 +42,10 @@ where rhs.data().as_any().downcast_ref().unwrap(); let validity = arrow_bitmap_and_helper(lhs_arr.validity(), rhs_arr.validity()); - let result_arr = PrimitiveArray::from_trusted_len_values_iter( - zip(lhs_arr.values_iter(), rhs_arr.values_iter()).map(|(a, b)| func(*a, *b)), - ) - .with_validity(validity); - Ok(Self::from((self.name(), Box::new(result_arr)))) + + let iter = + zip(lhs_arr.values_iter(), rhs_arr.values_iter()).map(|(a, b)| func(*a, *b)); + Self::from_values_iter(self.field.clone(), iter).with_validity(validity) } (l_size, 1) => { if let Some(value) = rhs.get(0) { @@ -57,11 +58,9 @@ where let rhs_arr: &PrimitiveArray = rhs.data().as_any().downcast_ref().unwrap(); if let Some(value) = self.get(0) { - let result_arr = PrimitiveArray::from_trusted_len_values_iter( - rhs_arr.values_iter().map(|v| func(value, *v)), - ) - .with_validity(rhs_arr.validity().cloned()); - Ok(Self::from((self.name(), Box::new(result_arr)))) + let iter = rhs_arr.values_iter().map(|v| func(value, *v)); + Self::from_values_iter(self.field.clone(), iter) + .with_validity(rhs_arr.validity().cloned()) } else { Ok(Self::full_null(self.name(), self.data_type(), r_size)) } diff --git a/src/daft-core/src/array/ops/arithmetic.rs b/src/daft-core/src/array/ops/arithmetic.rs index c77f4722fa..388759c63f 100644 --- a/src/daft-core/src/array/ops/arithmetic.rs +++ b/src/daft-core/src/array/ops/arithmetic.rs @@ -6,8 +6,9 @@ use common_error::{DaftError, DaftResult}; use super::{as_arrow::AsArrow, full::FullNull}; use crate::{ array::{DataArray, FixedSizeListArray}, - datatypes::{DaftNumericType, DataType, Field, Utf8Array}, + datatypes::{DaftNumericType, DaftPrimitiveType, DataType, Field, Utf8Array}, kernels::utf8::add_utf8_arrays, + prelude::Decimal128Array, series::Series, }; // Permission is hereby granted, free of charge, to any person obtaining a copy @@ -38,15 +39,16 @@ fn arithmetic_helper( operation: F, ) -> DaftResult> where - T: DaftNumericType, - Kernel: Fn(&PrimitiveArray, &PrimitiveArray) -> PrimitiveArray, + T: DaftPrimitiveType, + Kernel: + FnOnce(&PrimitiveArray, &PrimitiveArray) -> PrimitiveArray, F: Fn(T::Native, T::Native) -> T::Native, { match (lhs.len(), rhs.len()) { - (a, b) if a == b => Ok(DataArray::from(( - lhs.name(), + (a, b) if a == b => DataArray::new( + lhs.field.clone(), Box::new(kernel(lhs.as_arrow(), rhs.as_arrow())), - ))), + ), // broadcast right path (_, 1) => { let opt_rhs = rhs.get(0); @@ -79,6 +81,50 @@ where } } +impl Add for &Decimal128Array { + type Output = DaftResult; + fn add(self, rhs: Self) -> Self::Output { + assert_eq!(self.data_type(), rhs.data_type()); + arithmetic_helper( + self, + rhs, + arrow2::compute::arithmetics::decimal::add, + |l, r| l + r, + ) + } +} + +impl Sub for &Decimal128Array { + type Output = DaftResult; + fn sub(self, rhs: Self) -> Self::Output { + assert_eq!(self.data_type(), rhs.data_type()); + arithmetic_helper( + self, + rhs, + arrow2::compute::arithmetics::decimal::sub, + |l, r| l - r, + ) + } +} + +impl Mul for &Decimal128Array { + type Output = DaftResult; + fn mul(self, rhs: Self) -> Self::Output { + assert_eq!(self.data_type(), rhs.data_type()); + + let DataType::Decimal128(_, s) = self.data_type() else { + unreachable!("This should always be a Decimal128") + }; + let scale = 10i128.pow(*s as u32); + arithmetic_helper( + self, + rhs, + arrow2::compute::arithmetics::decimal::mul, + |l, r| (l * r) / scale, + ) + } +} + impl Add for &Utf8Array { type Output = DaftResult; fn add(self, rhs: Self) -> Self::Output { @@ -236,6 +282,69 @@ where } } +impl Div for &Decimal128Array { + type Output = DaftResult; + fn div(self, rhs: Self) -> Self::Output { + assert_eq!(self.data_type(), rhs.data_type()); + let DataType::Decimal128(_, s) = self.data_type() else { + unreachable!("This should always be a Decimal128") + }; + let scale = 10i128.pow(*s as u32); + + if rhs.data().null_count() == 0 { + arithmetic_helper( + self, + rhs, + arrow2::compute::arithmetics::decimal::div, + |l, r| ((l * scale) / r), + ) + } else { + match (self.len(), rhs.len()) { + (a, b) if a == b => { + let values = self + .as_arrow() + .iter() + .zip(rhs.as_arrow().iter()) + .map(|(l, r)| match (l, r) { + (None, _) => None, + (_, None) => None, + (Some(l), Some(r)) => Some((l * scale) / r), + }); + Ok(Decimal128Array::from_iter(self.field.clone(), values)) + } + // broadcast right path + (_, 1) => { + let opt_rhs = rhs.get(0); + match opt_rhs { + None => Ok(DataArray::full_null( + self.name(), + self.data_type(), + self.len(), + )), + Some(rhs) => self.apply(|lhs| ((lhs * scale) / rhs)), + } + } + (1, _) => { + let opt_lhs = self.get(0); + Ok(match opt_lhs { + None => DataArray::full_null(rhs.name(), rhs.data_type(), rhs.len()), + Some(lhs) => { + let values_iter = rhs + .as_arrow() + .iter() + .map(|v| v.map(|v| ((lhs * scale) / *v))); + Decimal128Array::from_iter(self.field.clone(), values_iter) + } + }) + } + (a, b) => Err(DaftError::ValueError(format!( + "Cannot apply operation on arrays of different lengths: {a} vs {b}" + ))), + } + } + } +} + fn fixed_sized_list_arithmetic_helper( lhs: &FixedSizeListArray, rhs: &FixedSizeListArray, diff --git a/src/daft-core/src/array/ops/as_arrow.rs b/src/daft-core/src/array/ops/as_arrow.rs index 1672a95160..8964df3640 100644 --- a/src/daft-core/src/array/ops/as_arrow.rs +++ b/src/daft-core/src/array/ops/as_arrow.rs @@ -7,9 +7,9 @@ use crate::datatypes::PythonArray; use crate::{ array::DataArray, datatypes::{ - logical::{DateArray, Decimal128Array, DurationArray, TimeArray, TimestampArray}, - BinaryArray, BooleanArray, DaftNumericType, FixedSizeBinaryArray, IntervalArray, NullArray, - Utf8Array, + logical::{DateArray, DurationArray, TimeArray, TimestampArray}, + BinaryArray, BooleanArray, DaftPrimitiveType, FixedSizeBinaryArray, IntervalArray, + NullArray, Utf8Array, }, }; @@ -24,7 +24,7 @@ pub trait AsArrow { impl AsArrow for DataArray where - T: DaftNumericType, + T: DaftPrimitiveType, { type Output = array::PrimitiveArray; @@ -66,7 +66,6 @@ impl_asarrow_dataarray!(IntervalArray, array::PrimitiveArray); #[cfg(feature = "python")] impl_asarrow_dataarray!(PythonArray, PseudoArrowArray); -impl_asarrow_logicalarray!(Decimal128Array, array::PrimitiveArray); impl_asarrow_logicalarray!(DateArray, array::PrimitiveArray); impl_asarrow_logicalarray!(TimeArray, array::PrimitiveArray); impl_asarrow_logicalarray!(DurationArray, array::PrimitiveArray); diff --git a/src/daft-core/src/array/ops/cast.rs b/src/daft-core/src/array/ops/cast.rs index d495990d45..20981a23ca 100644 --- a/src/daft-core/src/array/ops/cast.rs +++ b/src/daft-core/src/array/ops/cast.rs @@ -37,7 +37,7 @@ use crate::{ }, datatypes::{ logical::{ - DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, + DateArray, DurationArray, EmbeddingArray, FixedShapeImageArray, FixedShapeSparseTensorArray, FixedShapeTensorArray, ImageArray, LogicalArray, MapArray, SparseTensorArray, TensorArray, TimeArray, TimestampArray, }, @@ -412,40 +412,40 @@ impl DurationArray { } } -impl Decimal128Array { - pub fn cast(&self, dtype: &DataType) -> DaftResult { - match dtype { - #[cfg(feature = "python")] - DataType::Python => cast_logical_to_python_array(self, dtype), - DataType::Int128 => Ok(self.physical.clone().into_series()), - dtype if dtype.is_numeric() => self.physical.cast(dtype), - DataType::Decimal128(_, _) => { - // Use the arrow2 Decimal128 casting logic. - let target_arrow_type = dtype.to_arrow()?; - let arrow_decimal_array = self - .as_arrow() - .clone() - .to(self.data_type().to_arrow()?) - .to_boxed(); - let casted_arrow_array = cast( - arrow_decimal_array.as_ref(), - &target_arrow_type, - CastOptions { - wrapped: true, - partial: false, - }, - )?; - - let new_field = Arc::new(Field::new(self.name(), dtype.clone())); - Series::from_arrow(new_field, casted_arrow_array) - } - _ => Err(DaftError::TypeError(format!( - "Cannot cast Decimal128 to {}", - dtype - ))), - } - } -} +// impl Decimal128Array { +// pub fn cast(&self, dtype: &DataType) -> DaftResult { +// match dtype { +// #[cfg(feature = "python")] +// DataType::Python => cast_logical_to_python_array(self, dtype), +// DataType::Int128 => Ok(self.physical.clone().into_series()), +// dtype if dtype.is_numeric() => self.physical.cast(dtype), +// DataType::Decimal128(_, _) => { +// // Use the arrow2 Decimal128 casting logic. +// let target_arrow_type = dtype.to_arrow()?; +// let arrow_decimal_array = self +// .as_arrow() +// .clone() +// .to(self.data_type().to_arrow()?) +// .to_boxed(); +// let casted_arrow_array = cast( +// arrow_decimal_array.as_ref(), +// &target_arrow_type, +// CastOptions { +// wrapped: true, +// partial: false, +// }, +// )?; + +// let new_field = Arc::new(Field::new(self.name(), dtype.clone())); +// Series::from_arrow(new_field, casted_arrow_array) +// } +// _ => Err(DaftError::TypeError(format!( +// "Cannot cast Decimal128 to {}", +// dtype +// ))), +// } +// } +// } #[cfg(feature = "python")] macro_rules! pycast_then_arrowcast { diff --git a/src/daft-core/src/array/ops/compare_agg.rs b/src/daft-core/src/array/ops/compare_agg.rs index 6fc4036b30..f6134078d5 100644 --- a/src/daft-core/src/array/ops/compare_agg.rs +++ b/src/daft-core/src/array/ops/compare_agg.rs @@ -1,4 +1,4 @@ -use arrow2::array::{Array, PrimitiveArray}; +use arrow2::array::Array; use common_error::DaftResult; use super::{full::FullNull, DaftCompareAggable, GroupIndices}; @@ -8,15 +8,15 @@ use crate::{ }; fn grouped_cmp_native( - data_array: &DataArray, + array: &DataArray, mut op: F, groups: &GroupIndices, ) -> DaftResult> where - T: DaftNumericType, + T: DaftPrimitiveType, F: Fn(T::Native, T::Native) -> T::Native, { - let arrow_array = data_array.as_arrow(); + let arrow_array = array.as_arrow(); let cmp_per_group = if arrow_array.null_count() > 0 { let cmp_values_iter = groups.iter().map(|g| { let reduced_val = g @@ -36,9 +36,10 @@ where }); reduced_val.unwrap_or_default() }); - Box::new(PrimitiveArray::from_trusted_len_iter(cmp_values_iter)) + DataArray::::from_iter(array.field.clone(), cmp_values_iter) } else { - Box::new(PrimitiveArray::from_trusted_len_values_iter( + DataArray::::from_values_iter( + array.field.clone(), groups.iter().map(|g| { g.iter() .map(|i| { @@ -48,19 +49,16 @@ where .reduce(&mut op) .unwrap() }), - )) + ) }; - Ok(DataArray::from(( - data_array.field.name.as_ref(), - cmp_per_group, - ))) + Ok(cmp_per_group) } use super::as_arrow::AsArrow; impl DaftCompareAggable for DataArray where - T: DaftNumericType, + T: DaftPrimitiveType, T::Native: PartialOrd, ::Simd: arrow2::compute::aggregate::SimdOrd, { @@ -70,18 +68,14 @@ where let primitive_arr = self.as_arrow(); let result = arrow2::compute::aggregate::min_primitive(primitive_arr); - let arrow_array = Box::new(arrow2::array::PrimitiveArray::from([result])); - - Self::new(self.field.clone(), arrow_array) + Ok(Self::from_iter(self.field.clone(), std::iter::once(result))) } fn max(&self) -> Self::Output { let primitive_arr = self.as_arrow(); let result = arrow2::compute::aggregate::max_primitive(primitive_arr); - let arrow_array = Box::new(arrow2::array::PrimitiveArray::from([result])); - - Self::new(self.field.clone(), arrow_array) + Ok(Self::from_iter(self.field.clone(), std::iter::once(result))) } fn grouped_min(&self, groups: &GroupIndices) -> Self::Output { grouped_cmp_native( diff --git a/src/daft-core/src/array/ops/comparison.rs b/src/daft-core/src/array/ops/comparison.rs index 8c941f8b2a..0a48557cdf 100644 --- a/src/daft-core/src/array/ops/comparison.rs +++ b/src/daft-core/src/array/ops/comparison.rs @@ -8,7 +8,7 @@ use super::{as_arrow::AsArrow, from_arrow::FromArrow, full::FullNull, DaftCompar use crate::{ array::DataArray, datatypes::{ - BinaryArray, BooleanArray, DaftArrowBackedType, DaftNumericType, DataType, Field, + BinaryArray, BooleanArray, DaftArrowBackedType, DaftPrimitiveType, DataType, Field, FixedSizeBinaryArray, NullArray, Utf8Array, }, utils::arrow::arrow_bitmap_and_helper, @@ -25,7 +25,7 @@ where impl DaftCompare<&Self> for DataArray where - T: DaftNumericType, + T: DaftPrimitiveType, { type Output = DaftResult; @@ -272,7 +272,7 @@ where impl DataArray where - T: DaftNumericType, + T: DaftPrimitiveType, { fn compare_to_scalar( &self, @@ -295,7 +295,7 @@ where impl DaftCompare for DataArray where - T: DaftNumericType, + T: DaftPrimitiveType, Scalar: ToPrimitive, { type Output = BooleanArray; diff --git a/src/daft-core/src/array/ops/get.rs b/src/daft-core/src/array/ops/get.rs index 524792ae03..b90062b8af 100644 --- a/src/daft-core/src/array/ops/get.rs +++ b/src/daft-core/src/array/ops/get.rs @@ -5,10 +5,9 @@ use crate::{ array::{DataArray, FixedSizeListArray, ListArray}, datatypes::{ logical::{ - DateArray, Decimal128Array, DurationArray, LogicalArrayImpl, MapArray, TimeArray, - TimestampArray, + DateArray, DurationArray, LogicalArrayImpl, MapArray, TimeArray, TimestampArray, }, - BinaryArray, BooleanArray, DaftLogicalType, DaftNumericType, ExtensionArray, + BinaryArray, BooleanArray, DaftLogicalType, DaftPrimitiveType, ExtensionArray, FixedSizeBinaryArray, IntervalArray, NullArray, Utf8Array, }, series::Series, @@ -16,7 +15,7 @@ use crate::{ impl DataArray where - T: DaftNumericType, + T: DaftPrimitiveType, { #[inline] pub fn get(&self, idx: usize) -> Option { @@ -72,7 +71,6 @@ impl_array_arrow_get!(Utf8Array, &str); impl_array_arrow_get!(BooleanArray, bool); impl_array_arrow_get!(BinaryArray, &[u8]); impl_array_arrow_get!(FixedSizeBinaryArray, &[u8]); -impl_array_arrow_get!(Decimal128Array, i128); impl_array_arrow_get!(DateArray, i32); impl_array_arrow_get!(TimeArray, i64); impl_array_arrow_get!(DurationArray, i64); diff --git a/src/daft-core/src/array/ops/groups.rs b/src/daft-core/src/array/ops/groups.rs index 6f053040c3..c184d54015 100644 --- a/src/daft-core/src/array/ops/groups.rs +++ b/src/daft-core/src/array/ops/groups.rs @@ -14,6 +14,7 @@ use crate::{ BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, FixedSizeBinaryArray, Float32Array, Float64Array, NullArray, Utf8Array, }, + prelude::Decimal128Array, }; /// Given a list of values, return a `(Vec, Vec>)`. @@ -84,6 +85,17 @@ where } } +impl IntoGroups for Decimal128Array { + fn make_groups(&self) -> DaftResult { + let array = self.as_arrow(); + if array.null_count() > 0 { + make_groups(array.iter()) + } else { + make_groups(array.values_iter()) + } + } +} + impl IntoGroups for Float32Array { fn make_groups(&self) -> DaftResult { let array = self.as_arrow(); diff --git a/src/daft-core/src/array/ops/hash.rs b/src/daft-core/src/array/ops/hash.rs index 278f70ce79..512598e788 100644 --- a/src/daft-core/src/array/ops/hash.rs +++ b/src/daft-core/src/array/ops/hash.rs @@ -1,15 +1,18 @@ +use std::sync::Arc; + use arrow2::types::Index; use common_error::{DaftError, DaftResult}; +use daft_schema::{dtype::DataType, field::Field}; use xxhash_rust::xxh3::{xxh3_64, xxh3_64_with_seed}; use super::as_arrow::AsArrow; use crate::{ array::{DataArray, FixedSizeListArray, ListArray, StructArray}, datatypes::{ - logical::{DateArray, Decimal128Array, TimeArray, TimestampArray}, - BinaryArray, BooleanArray, DaftNumericType, FixedSizeBinaryArray, Int16Array, Int32Array, - Int64Array, Int8Array, NullArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, - Utf8Array, + logical::{DateArray, TimeArray, TimestampArray}, + BinaryArray, BooleanArray, DaftPrimitiveType, Decimal128Array, FixedSizeBinaryArray, + Int16Array, Int32Array, Int64Array, Int8Array, NullArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, Utf8Array, }, kernels, series::Series, @@ -18,7 +21,7 @@ use crate::{ impl DataArray where - T: DaftNumericType, + T: DaftPrimitiveType, { pub fn hash(&self, seed: Option<&UInt64Array>) -> DaftResult { let as_arrowed = self.as_arrow(); @@ -106,14 +109,14 @@ fn hash_list( if let Some(seed_arr) = seed { let combined_validity = arrow_bitmap_and_helper(validity, seed.unwrap().validity()); UInt64Array::from_iter( - name, + Arc::new(Field::new(name, DataType::UInt64)), u64::range(0, offsets.len() - 1).unwrap().map(|i| { let start = offsets[i as usize] as usize; let end = offsets[i as usize + 1] as usize; // apply the current seed across this row let cur_seed_opt = seed_arr.get(i as usize); let flat_seed = UInt64Array::from_iter( - "seed", + Arc::new(Field::new("seed", DataType::UInt64)), std::iter::repeat(cur_seed_opt).take(end - start), ); let hashed_child = flat_child @@ -146,7 +149,7 @@ fn hash_list( const OFFSET: usize = (u64::BITS as usize) / 8; // how many bytes per u64 let combined_validity = validity.cloned(); UInt64Array::from_iter( - name, + Arc::new(Field::new(name, DataType::UInt64)), u64::range(0, offsets.len() - 1).unwrap().map(|i| { let start = (offsets[i as usize] as usize) * OFFSET; let end = (offsets[i as usize + 1] as usize) * OFFSET; @@ -319,7 +322,11 @@ impl TimestampArray { impl Decimal128Array { pub fn murmur3_32(&self) -> DaftResult { - let arr = self.physical.as_arrow(); + let arr = self + .data() + .as_any() + .downcast_ref::>() + .expect("this should be a decimal array"); let hashes = arr.into_iter().map(|d| { d.map(|d| { let twos_compliment = u128::from_ne_bytes(d.to_ne_bytes()); diff --git a/src/daft-core/src/array/ops/is_in.rs b/src/daft-core/src/array/ops/is_in.rs index 24e78e8f29..9e942d482f 100644 --- a/src/daft-core/src/array/ops/is_in.rs +++ b/src/daft-core/src/array/ops/is_in.rs @@ -71,6 +71,8 @@ macro_rules! impl_is_in_non_numeric_array { } }; } + +impl_is_in_non_numeric_array!(Decimal128Array); impl_is_in_non_numeric_array!(BooleanArray); impl_is_in_non_numeric_array!(Utf8Array); impl_is_in_non_numeric_array!(BinaryArray); diff --git a/src/daft-core/src/array/ops/repr.rs b/src/daft-core/src/array/ops/repr.rs index 0d40150cf9..7eba5c0ba1 100644 --- a/src/daft-core/src/array/ops/repr.rs +++ b/src/daft-core/src/array/ops/repr.rs @@ -5,12 +5,12 @@ use crate::{ array::{DataArray, FixedSizeListArray, ListArray, StructArray}, datatypes::{ logical::{ - DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, + DateArray, DurationArray, EmbeddingArray, FixedShapeImageArray, FixedShapeSparseTensorArray, FixedShapeTensorArray, ImageArray, MapArray, SparseTensorArray, TensorArray, TimeArray, TimestampArray, }, - BinaryArray, BooleanArray, DaftNumericType, DataType, ExtensionArray, FixedSizeBinaryArray, - IntervalArray, IntervalValue, NullArray, UInt64Array, Utf8Array, + BinaryArray, BooleanArray, DaftNumericType, DataType, Decimal128Array, ExtensionArray, + FixedSizeBinaryArray, IntervalArray, IntervalValue, NullArray, UInt64Array, Utf8Array, }, series::Series, utils::display::{ diff --git a/src/daft-core/src/array/ops/sort.rs b/src/daft-core/src/array/ops/sort.rs index 1831313e98..fa3ccad594 100644 --- a/src/daft-core/src/array/ops/sort.rs +++ b/src/daft-core/src/array/ops/sort.rs @@ -11,12 +11,13 @@ use crate::{ array::{DataArray, FixedSizeListArray, ListArray, StructArray}, datatypes::{ logical::{ - DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, + DateArray, DurationArray, EmbeddingArray, FixedShapeImageArray, FixedShapeSparseTensorArray, FixedShapeTensorArray, ImageArray, MapArray, SparseTensorArray, TensorArray, TimeArray, TimestampArray, }, - BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, ExtensionArray, - FixedSizeBinaryArray, Float32Array, Float64Array, IntervalArray, NullArray, Utf8Array, + BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, Decimal128Array, + ExtensionArray, FixedSizeBinaryArray, Float32Array, Float64Array, IntervalArray, NullArray, + Utf8Array, }, kernels::search_sorted::{build_compare_with_nulls, cmp_float}, series::Series, @@ -336,6 +337,98 @@ impl Float64Array { } } +impl Decimal128Array { + pub fn argsort(&self, descending: bool) -> DaftResult> + where + I: DaftIntegerType, + ::Native: arrow2::types::Index, + { + let arrow_array = self.as_arrow(); + + let result = + crate::array::ops::arrow2::sort::primitive::indices::indices_sorted_unstable_by::< + I::Native, + i128, + _, + >(arrow_array, ord::total_cmp, descending); + + Ok(DataArray::::from((self.name(), Box::new(result)))) + } + + pub fn argsort_multikey( + &self, + others: &[Series], + descending: &[bool], + ) -> DaftResult> + where + I: DaftIntegerType, + ::Native: arrow2::types::Index, + { + let arrow_array = self.as_arrow(); + let first_desc = *descending.first().unwrap(); + + let others_cmp = build_multi_array_compare(others, &descending[1..])?; + + let values = arrow_array.values().as_slice(); + + let result = if first_desc { + multi_column_idx_sort( + arrow_array.validity(), + |a: &I::Native, b: &I::Native| { + let a = a.to_usize(); + let b = b.to_usize(); + let l = unsafe { values.get_unchecked(a) }; + let r = unsafe { values.get_unchecked(b) }; + match ord::total_cmp(r, l) { + std::cmp::Ordering::Equal => others_cmp(a, b), + v => v, + } + }, + &others_cmp, + arrow_array.len(), + first_desc, + ) + } else { + multi_column_idx_sort( + arrow_array.validity(), + |a: &I::Native, b: &I::Native| { + let a = a.to_usize(); + let b = b.to_usize(); + let l = unsafe { values.get_unchecked(a) }; + let r = unsafe { values.get_unchecked(b) }; + match ord::total_cmp(l, r) { + std::cmp::Ordering::Equal => others_cmp(a, b), + v => v, + } + }, + &others_cmp, + arrow_array.len(), + first_desc, + ) + }; + + Ok(DataArray::::from((self.name(), Box::new(result)))) + } + + pub fn sort(&self, descending: bool) -> DaftResult { + let options = arrow2::compute::sort::SortOptions { + descending, + nulls_first: descending, + }; + + let arrow_array = self.as_arrow(); + + let result = crate::array::ops::arrow2::sort::primitive::sort::sort_by::( + arrow_array, + ord::total_cmp, + &options, + None, + ); + + Self::new(self.field.clone(), Box::new(result)) + } +} + impl NullArray { pub fn argsort(&self, _descending: bool) -> DaftResult> where @@ -626,13 +719,6 @@ impl PythonArray { } } -impl Decimal128Array { - pub fn sort(&self, descending: bool) -> DaftResult { - let new_array = self.physical.sort(descending)?; - Ok(Self::new(self.field.clone(), new_array)) - } -} - impl DateArray { pub fn sort(&self, descending: bool) -> DaftResult { let new_array = self.physical.sort(descending)?; diff --git a/src/daft-core/src/array/ops/sum.rs b/src/daft-core/src/array/ops/sum.rs index 90c4e512a6..c4d7e3157b 100644 --- a/src/daft-core/src/array/ops/sum.rs +++ b/src/daft-core/src/array/ops/sum.rs @@ -11,16 +11,18 @@ macro_rules! impl_daft_numeric_agg { fn sum(&self) -> Self::Output { let primitive_arr = self.as_arrow(); let sum_value = arrow2::compute::aggregate::sum_primitive(primitive_arr); - let arrow_array = Box::new(arrow2::array::PrimitiveArray::from([sum_value])); - DataArray::new(self.field.clone(), arrow_array) + Ok(DataArray::<$T>::from_iter( + self.field.clone(), + std::iter::once(sum_value), + )) } fn grouped_sum(&self, groups: &GroupIndices) -> Self::Output { - use arrow2::array::PrimitiveArray; let arrow_array = self.as_arrow(); let sum_per_group = if arrow_array.null_count() > 0 { - Box::new(PrimitiveArray::from_trusted_len_iter(groups.iter().map( - |g| { + DataArray::<$T>::from_iter( + self.field.clone(), + groups.iter().map(|g| { g.iter().fold(None, |acc, index| { let idx = *index as usize; match (acc, arrow_array.is_null(idx)) { @@ -29,27 +31,28 @@ macro_rules! impl_daft_numeric_agg { (Some(acc), false) => Some(acc + arrow_array.value(idx)), } }) - }, - ))) + }), + ) } else { - Box::new(PrimitiveArray::from_trusted_len_values_iter( + DataArray::<$T>::from_values_iter( + self.field.clone(), groups.iter().map(|g| { g.iter().fold(0 as $AggType, |acc, index| { let idx = *index as usize; acc + unsafe { arrow_array.value_unchecked(idx) } }) }), - )) + ) }; - Ok(DataArray::from((self.field.name.as_ref(), sum_per_group))) + Ok(sum_per_group) } } }; } impl_daft_numeric_agg!(Int64Type, i64); -impl_daft_numeric_agg!(Int128Type, i128); impl_daft_numeric_agg!(UInt64Type, u64); impl_daft_numeric_agg!(Float32Type, f32); impl_daft_numeric_agg!(Float64Type, f64); +impl_daft_numeric_agg!(Decimal128Type, i128); diff --git a/src/daft-core/src/array/ops/take.rs b/src/daft-core/src/array/ops/take.rs index ae3bac984f..0cceec49a0 100644 --- a/src/daft-core/src/array/ops/take.rs +++ b/src/daft-core/src/array/ops/take.rs @@ -61,8 +61,8 @@ impl_dataarray_take!(BinaryArray); impl_dataarray_take!(NullArray); impl_dataarray_take!(ExtensionArray); impl_dataarray_take!(IntervalArray); +impl_dataarray_take!(Decimal128Array); -impl_logicalarray_take!(Decimal128Array); impl_logicalarray_take!(DateArray); impl_logicalarray_take!(TimeArray); impl_logicalarray_take!(DurationArray); diff --git a/src/daft-core/src/array/ops/truncate.rs b/src/daft-core/src/array/ops/truncate.rs index c939cd89ac..ba290d559e 100644 --- a/src/daft-core/src/array/ops/truncate.rs +++ b/src/daft-core/src/array/ops/truncate.rs @@ -7,8 +7,8 @@ use super::as_arrow::AsArrow; use crate::{ array::DataArray, datatypes::{ - logical::Decimal128Array, DaftNumericType, Int16Type, Int32Type, Int64Type, Int8Type, - UInt16Type, UInt32Type, UInt64Type, UInt8Type, Utf8Array, + DaftNumericType, Decimal128Array, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, Utf8Array, }, prelude::BinaryArray, }; @@ -53,11 +53,7 @@ impl Decimal128Array { i - remainder }) }); - let array = Box::new(arrow2::array::PrimitiveArray::from_iter(trun_value)); - Ok(Decimal128Array::new( - self.field.clone(), - DataArray::from((self.name(), array)), - )) + Ok(Self::from_iter(self.field.clone(), trun_value)) } } diff --git a/src/daft-core/src/array/ops/utf8.rs b/src/daft-core/src/array/ops/utf8.rs index f67ee6977b..5b2be1969e 100644 --- a/src/daft-core/src/array/ops/utf8.rs +++ b/src/daft-core/src/array/ops/utf8.rs @@ -1,6 +1,7 @@ use std::{ borrow::Cow, iter::{self, Repeat, Take}, + sync::Arc, }; use aho_corasick::{AhoCorasickBuilder, MatchKind}; @@ -1368,8 +1369,11 @@ impl Utf8Array { ) -> DaftResult { if patterns.null_count() == patterns.len() { // no matches - return UInt64Array::from_iter(self.name(), iter::repeat(Some(0)).take(self.len())) - .with_validity(self.validity().cloned()); + return UInt64Array::from_iter( + Arc::new(Field::new(self.name(), DataType::UInt64)), + iter::repeat(Some(0)).take(self.len()), + ) + .with_validity(self.validity().cloned()); } let patterns = patterns.as_arrow().iter().flatten(); @@ -1400,7 +1404,10 @@ impl Utf8Array { } }) }); - Ok(UInt64Array::from_iter(self.name(), iter)) + Ok(UInt64Array::from_iter( + Arc::new(Field::new(self.name(), DataType::UInt64)), + iter, + )) } fn unary_broadcasted_op(&self, operation: ScalarKernel) -> DaftResult diff --git a/src/daft-core/src/array/prelude.rs b/src/daft-core/src/array/prelude.rs index 8e3b6f4031..1a082e6ea4 100644 --- a/src/daft-core/src/array/prelude.rs +++ b/src/daft-core/src/array/prelude.rs @@ -1,17 +1,17 @@ pub use super::{DataArray, FixedSizeListArray, ListArray, StructArray}; // Import logical array types pub use crate::datatypes::logical::{ - DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, - FixedShapeSparseTensorArray, FixedShapeTensorArray, ImageArray, LogicalArray, MapArray, - SparseTensorArray, TensorArray, TimeArray, TimestampArray, + DateArray, DurationArray, EmbeddingArray, FixedShapeImageArray, FixedShapeSparseTensorArray, + FixedShapeTensorArray, ImageArray, LogicalArray, MapArray, SparseTensorArray, TensorArray, + TimeArray, TimestampArray, }; pub use crate::{ array::ops::{ as_arrow::AsArrow, from_arrow::FromArrow, full::FullNull, DaftCompare, DaftLogical, }, datatypes::{ - BinaryArray, BooleanArray, ExtensionArray, FixedSizeBinaryArray, Float32Array, - Float64Array, Int128Array, Int16Array, Int32Array, Int64Array, Int8Array, IntervalArray, + BinaryArray, BooleanArray, Decimal128Array, ExtensionArray, FixedSizeBinaryArray, + Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, IntervalArray, NullArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, Utf8Array, }, }; diff --git a/src/daft-core/src/array/serdes.rs b/src/daft-core/src/array/serdes.rs index 4b57fc107b..6a8dc64df3 100644 --- a/src/daft-core/src/array/serdes.rs +++ b/src/daft-core/src/array/serdes.rs @@ -7,7 +7,7 @@ use super::{ops::as_arrow::AsArrow, DataArray, FixedSizeListArray, ListArray, St use crate::datatypes::PythonArray; use crate::{ datatypes::{ - logical::LogicalArray, BinaryArray, BooleanArray, DaftLogicalType, DaftNumericType, + logical::LogicalArray, BinaryArray, BooleanArray, DaftLogicalType, DaftPrimitiveType, DataType, ExtensionArray, FixedSizeBinaryArray, Int64Array, IntervalArray, NullArray, Utf8Array, }, @@ -51,7 +51,7 @@ where } } -impl serde::Serialize for DataArray { +impl serde::Serialize for DataArray { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, diff --git a/src/daft-core/src/datatypes/infer_datatype.rs b/src/daft-core/src/datatypes/infer_datatype.rs index 51b28a665b..0c1bfdf581 100644 --- a/src/daft-core/src/datatypes/infer_datatype.rs +++ b/src/daft-core/src/datatypes/infer_datatype.rs @@ -1,4 +1,5 @@ use std::{ + f64::consts::LOG10_2, fmt::Display, ops::{Add, Div, Mul, Rem, Shl, Shr, Sub}, }; @@ -80,9 +81,6 @@ impl<'a> InferDataType<'a> { left, other ))) } - (s, o) if s.is_physical() && o.is_physical() => { - Ok((DataType::Boolean, None, try_physical_supertype(s, o)?)) - } (DataType::Timestamp(..), DataType::Timestamp(..)) => { let intermediate_type = try_get_supertype(left, other)?; let pt = intermediate_type.to_physical(); @@ -94,6 +92,45 @@ impl<'a> InferDataType<'a> { let pt = intermediate_type.to_physical(); Ok((DataType::Boolean, Some(intermediate_type), pt)) } + + (DataType::Decimal128(..), other) if other.is_integer() => { + self.comparison_op(&InferDataType::from(&integer_to_decimal128(other)?)) + } + (left, DataType::Decimal128(..)) if left.is_integer() => { + InferDataType::from(&integer_to_decimal128(left)?) + .comparison_op(&InferDataType::from(*other)) + } + (DataType::Decimal128(..), DataType::Float32 | DataType::Float64) + | (DataType::Float32 | DataType::Float64, DataType::Decimal128(..)) => Ok(( + DataType::Boolean, + Some(DataType::Float64), + DataType::Float64, + )), + (DataType::Decimal128(p1, s1), DataType::Decimal128(p2, s2)) => { + let s_max = *std::cmp::max(s1, s2); + let p_prime = std::cmp::max(p1 - s1, p2 - s2) + s_max; + + let d_type = if !(1..=34).contains(&p_prime) { + Err(DaftError::TypeError( + format!("Cannot infer supertypes for comparison on types: {}, {} result precision: {p_prime} exceed bounds of [1, 34]", self, other) + )) + } else if s_max > 34 { + Err(DaftError::TypeError( + format!("Cannot infer supertypes for comparison on types: {}, {} result scale: {s_max} exceed bounds of [0, 34]", self, other) + )) + } else if s_max > p_prime { + Err(DaftError::TypeError( + format!("Cannot infer supertypes for comparison on types: {}, {} result scale: {s_max} exceed precision {p_prime}", self, other) + )) + } else { + Ok(DataType::Decimal128(p_prime, s_max)) + }?; + + Ok((DataType::Boolean, Some(d_type.clone()), d_type)) + } + (s, o) if s.is_physical() && o.is_physical() => { + Ok((DataType::Boolean, None, try_physical_supertype(s, o)?)) + } _ => Err(DaftError::TypeError(format!( "Cannot perform comparison on types: {}, {}", left, other @@ -191,6 +228,31 @@ impl<'a> Add for InferDataType<'a> { // ---- Boolean + other ---- (DataType::Boolean, other) | (other, DataType::Boolean) if other.is_numeric() => Ok(other.clone()), + + (DataType::Decimal128(..), other) if other.is_integer() => self.add(InferDataType::from(&integer_to_decimal128(other)?)), + (left, DataType::Decimal128(..)) if left.is_integer() => InferDataType::from(&integer_to_decimal128(left)?).add(other), + (DataType::Decimal128(..), DataType::Float32 | DataType::Float64 ) | (DataType::Float32 | DataType::Float64, DataType::Decimal128(..)) => Ok(DataType::Float64), + (DataType::Decimal128(p1, s1), DataType::Decimal128(p2, s2)) => { + let s_max = *std::cmp::max(s1, s2); + let p_prime = std::cmp::max(p1 - s1, p2 - s2) + s_max + 1; + + + if !(1..=34).contains(&p_prime) { + Err(DaftError::TypeError( + format!("Cannot infer supertypes for addition on types: {}, {} result precision: {p_prime} exceed bounds of [1, 34]", self, other) + )) + } else if s_max > 34 { + Err(DaftError::TypeError( + format!("Cannot infer supertypes for addition on types: {}, {} result scale: {s_max} exceed bounds of [0, 34]", self, other) + )) + } else if s_max > p_prime { + Err(DaftError::TypeError( + format!("Cannot infer supertypes for addition on types: {}, {} result scale: {s_max} exceed precision {p_prime}", self, other) + )) + } else { + Ok(DataType::Decimal128(p_prime, s_max)) + } + } _ => Err(DaftError::TypeError( format!("Cannot infer supertypes for addition on types: {}, {}", self, other) )) @@ -225,6 +287,28 @@ impl<'a> Sub for InferDataType<'a> { (du_self @ &DataType::Duration(..), du_other @ &DataType::Duration(..)) => Err(DaftError::TypeError( format!("Cannot subtract due to differing precision: {}, {}. Please explicitly cast to the precision you wish to add in.", du_self, du_other) )), + (DataType::Decimal128(..), other) if other.is_integer() => self.sub(InferDataType::from(&integer_to_decimal128(other)?)), + (left, DataType::Decimal128(..)) if left.is_integer() => InferDataType::from(&integer_to_decimal128(left)?).sub(other), + (DataType::Decimal128(..), DataType::Float32 | DataType::Float64 ) | (DataType::Float32 | DataType::Float64, DataType::Decimal128(..)) => Ok(DataType::Float64), + (DataType::Decimal128(p1, s1), DataType::Decimal128(p2, s2)) => { + let s_max = *std::cmp::max(s1, s2); + let p_prime = std::cmp::max(p1 - s1, p2 - s2) + s_max + 1; + if !(1..=34).contains(&p_prime) { + Err(DaftError::TypeError( + format!("Cannot infer supertypes for subtraction on types: {}, {} result precision: {p_prime} exceed bounds of [1, 34]", self, other) + )) + } else if s_max > 34 { + Err(DaftError::TypeError( + format!("Cannot infer supertypes for subtraction on types: {}, {} result scale: {s_max} exceed bounds of [0, 34]", self, other) + )) + } else if s_max > p_prime { + Err(DaftError::TypeError( + format!("Cannot infer supertypes for subtraction on types: {}, {} result scale: {s_max} exceed precision {p_prime}", self, other) + )) + } else { + Ok(DataType::Decimal128(p_prime, s_max)) + } + } (DataType::Interval, dtype) | (dtype, DataType::Interval) if dtype.is_temporal() => Ok(dtype.clone()), _ => Err(DaftError::TypeError( format!("Cannot subtract types: {}, {}", self, other) @@ -238,19 +322,44 @@ impl<'a> Div for InferDataType<'a> { type Output = DaftResult; fn div(self, other: Self) -> Self::Output { - match (&self.0, &other.0) { - #[cfg(feature = "python")] - (DataType::Python, _) | (_, DataType::Python) => Ok(DataType::Python), - (s, o) if s.is_numeric() && o.is_numeric() => Ok(DataType::Float64), - _ => Err(DaftError::TypeError(format!( - "Cannot divide types: {}, {}", - self, other - ))), - } - .or_else(|_| { - try_fixed_shape_numeric_datatype(self.0, other.0, |l, r| { - InferDataType::from(l) / InferDataType::from(r) - }) + try_fixed_shape_numeric_datatype(self.0, other.0, |l, r| { + InferDataType::from(l) / InferDataType::from(r) + }).or_else(|_| { + match (&self.0, &other.0) { + #[cfg(feature = "python")] + (DataType::Python, _) | (_, DataType::Python) => Ok(DataType::Python), + (DataType::Decimal128(..), right) if right.is_integer() => self.div(InferDataType::from(&integer_to_decimal128(right)?)), + (left, DataType::Decimal128(..)) if left.is_integer() => InferDataType::from(&integer_to_decimal128(left)?).div(other), + (DataType::Decimal128(..), DataType::Float32 | DataType::Float64 ) | (DataType::Float32 | DataType::Float64, DataType::Decimal128(..)) => Ok(DataType::Float64), + (DataType::Decimal128(p1, s1), DataType::Decimal128(_, s2)) => { + let s1 = *s1 as i64; + let s2 = *s2 as i64; + let p1 = *p1 as i64; + let s_prime = s1 - s2 + std::cmp::max(6, p1+s2+1); + let p_prime = p1 - s1 + s_prime; + if !(1..=34).contains(&p_prime) { + Err(DaftError::TypeError( + format!("Cannot infer supertypes for divide on types: {}, {} result precision: {p_prime} exceed bounds of [1, 34]. scale: {s_prime}", self, other) + )) + } else if !(0..=34).contains(&s_prime){ + Err(DaftError::TypeError( + format!("Cannot infer supertypes for divide on types: {}, {} result scale: {s_prime} exceed bounds of [0, 34]. precision: {p_prime}", self, other) + )) + } else if s_prime > p_prime { + Err(DaftError::TypeError( + format!("Cannot infer supertypes for divide on types: {}, {} result scale: {s_prime} exceed precision {p_prime}", self, other) + )) + } else { + Ok(DataType::Decimal128(p_prime as usize, s_prime as usize)) + } + } + (s, o) if s.is_numeric() && o.is_numeric() => Ok(DataType::Float64), + (l, r) => Err(DaftError::TypeError(format!( + "Cannot divide types: {}, {}", + l, r + ))), + } + }) } } @@ -268,6 +377,29 @@ impl<'a> Mul for InferDataType<'a> { .or(match (self.0, other.0) { #[cfg(feature = "python")] (DataType::Python, _) | (_, DataType::Python) => Ok(DataType::Python), + (DataType::Decimal128(..), other) if other.is_integer() => self.mul(InferDataType::from(&integer_to_decimal128(other)?)), + (left, DataType::Decimal128(..)) if left.is_integer() => InferDataType::from(&integer_to_decimal128(left)?).mul(other), + (DataType::Decimal128(..), DataType::Float32) | (DataType::Float32, DataType::Decimal128(..)) => Ok(DataType::Float32), + (DataType::Decimal128(..), DataType::Float64) | (DataType::Float64, DataType::Decimal128(..)) => Ok(DataType::Float64), + (DataType::Decimal128(p1, s1), DataType::Decimal128(p2, s2)) => { + let s_prime = s1 + s2; + let p_prime = p1 + p2; + if !(1..=34).contains(&p_prime) { + Err(DaftError::TypeError( + format!("Cannot infer supertypes for multiply on types: {}, {} result precision: {p_prime} exceed bounds of [1, 34]", self, other) + )) + } else if s_prime > 34 { + Err(DaftError::TypeError( + format!("Cannot infer supertypes for multiply on types: {}, {} result scale: {s_prime} exceed bounds of [0, 34]", self, other) + )) + } else if s_prime > p_prime { + Err(DaftError::TypeError( + format!("Cannot infer supertypes for multiply on types: {}, {} result scale: {s_prime} exceed precision {p_prime}", self, other) + )) + } else { + Ok(DataType::Decimal128(p_prime, s_prime)) + } + } _ => Err(DaftError::TypeError(format!( "Cannot multiply types: {}, {}", self, other @@ -290,7 +422,7 @@ impl<'a> Rem for InferDataType<'a> { #[cfg(feature = "python")] (DataType::Python, _) | (_, DataType::Python) => Ok(DataType::Python), _ => Err(DaftError::TypeError(format!( - "Cannot modulus types: {}, {}", + "Cannot modulo types: {}, {}", self, other ))), }) @@ -325,6 +457,24 @@ impl<'a> Shr for InferDataType<'a> { } } +pub fn integer_to_decimal128(dtype: &DataType) -> DaftResult { + let constant = LOG10_2; + + let num_bits = match dtype { + DataType::Int8 | DataType::UInt8 => Ok(8), + DataType::Int16 | DataType::UInt16 => Ok(16), + DataType::Int32 | DataType::UInt32 => Ok(32), + DataType::Int64 | DataType::UInt64 => Ok(64), + _ => Err(DaftError::TypeError(format!( + "We can't infer the number of digits for a decimal from a non integer: {}", + dtype + ))), + }?; + let num_digits = ((num_bits as f64) * constant).ceil() as usize; + + Ok(DataType::Decimal128(num_digits, 0)) +} + pub fn try_physical_supertype(l: &DataType, r: &DataType) -> DaftResult { // Given two physical data types, // get the physical data type that they can both be casted to. diff --git a/src/daft-core/src/datatypes/logical.rs b/src/daft-core/src/datatypes/logical.rs index 9704b3b76f..e16ef05336 100644 --- a/src/daft-core/src/datatypes/logical.rs +++ b/src/daft-core/src/datatypes/logical.rs @@ -3,7 +3,7 @@ use std::{marker::PhantomData, sync::Arc}; use common_error::DaftResult; use super::{ - DaftArrayType, DaftDataType, DataArray, DataType, Decimal128Type, DurationType, EmbeddingType, + DaftArrayType, DaftDataType, DataArray, DataType, DurationType, EmbeddingType, FixedShapeImageType, FixedShapeSparseTensorType, FixedShapeTensorType, FixedSizeListArray, ImageType, MapType, SparseTensorType, TensorType, TimeType, TimestampType, }; @@ -164,7 +164,7 @@ impl MapArray { pub type LogicalArray = LogicalArrayImpl::PhysicalType as DaftDataType>::ArrayType>; -pub type Decimal128Array = LogicalArray; +// pub type Decimal128Array = LogicalArray; pub type DateArray = LogicalArray; pub type TimeArray = LogicalArray; pub type DurationArray = LogicalArray; diff --git a/src/daft-core/src/datatypes/matching.rs b/src/daft-core/src/datatypes/matching.rs index 792896afc5..ce6e662b37 100644 --- a/src/daft-core/src/datatypes/matching.rs +++ b/src/daft-core/src/datatypes/matching.rs @@ -24,7 +24,6 @@ macro_rules! with_match_daft_types {( Float32 => __with_ty__! { Float32Type }, Float64 => __with_ty__! { Float64Type }, Image(..) => __with_ty__! { ImageType }, - Int128 => __with_ty__! { Int128Type }, Int16 => __with_ty__! { Int16Type }, Int32 => __with_ty__! { Int32Type }, Int64 => __with_ty__! { Int64Type }, @@ -68,7 +67,7 @@ macro_rules! with_match_physical_daft_types {( Int16 => __with_ty__! { Int16Type }, Int32 => __with_ty__! { Int32Type }, Int64 => __with_ty__! { Int64Type }, - Int128 => __with_ty__! { Int128Type }, + Decimal128(_, _) => __with_ty__! { Decimal128Type }, UInt8 => __with_ty__! { UInt8Type }, UInt16 => __with_ty__! { UInt16Type }, UInt32 => __with_ty__! { UInt32Type }, @@ -114,6 +113,8 @@ macro_rules! with_match_arrow_daft_types {( // Float16 => __with_ty__! { Float16Type }, Float32 => __with_ty__! { Float32Type }, Float64 => __with_ty__! { Float64Type }, + Decimal128(..) => __with_ty__! { Decimal128Type }, + // Date => __with_ty__! { DateType }, // Timestamp(_, _) => __with_ty__! { TimestampType }, List(_) => __with_ty__! { ListType }, @@ -139,7 +140,7 @@ macro_rules! with_match_comparable_daft_types {( Int16 => __with_ty__! { Int16Type }, Int32 => __with_ty__! { Int32Type }, Int64 => __with_ty__! { Int64Type }, - Int128 => __with_ty__! { Int128Type }, + Decimal128(..) => __with_ty__! { Decimal128Type }, UInt8 => __with_ty__! { UInt8Type }, UInt16 => __with_ty__! { UInt16Type }, UInt32 => __with_ty__! { UInt32Type }, @@ -170,7 +171,7 @@ macro_rules! with_match_hashable_daft_types {( Int16 => __with_ty__! { Int16Type }, Int32 => __with_ty__! { Int32Type }, Int64 => __with_ty__! { Int64Type }, - Int128 => __with_ty__! { Int128Type }, + Decimal128(..) => __with_ty__! { Decimal128Type }, UInt8 => __with_ty__! { UInt8Type }, UInt16 => __with_ty__! { UInt16Type }, UInt32 => __with_ty__! { UInt32Type }, diff --git a/src/daft-core/src/datatypes/mod.rs b/src/daft-core/src/datatypes/mod.rs index 41cb82e2a8..2fd8f17c1c 100644 --- a/src/daft-core/src/datatypes/mod.rs +++ b/src/daft-core/src/datatypes/mod.rs @@ -182,11 +182,14 @@ impl_daft_arrow_datatype!(Int8Type, Int8); impl_daft_arrow_datatype!(Int16Type, Int16); impl_daft_arrow_datatype!(Int32Type, Int32); impl_daft_arrow_datatype!(Int64Type, Int64); -impl_daft_arrow_datatype!(Int128Type, Int128); impl_daft_arrow_datatype!(UInt8Type, UInt8); impl_daft_arrow_datatype!(UInt16Type, UInt16); impl_daft_arrow_datatype!(UInt32Type, UInt32); impl_daft_arrow_datatype!(UInt64Type, UInt64); + +// This Type isn't actually used but has to be kept around to ensure that i128 is recognized as a primitive +impl_daft_arrow_datatype!(Int128Type, Unknown); + impl_daft_arrow_datatype!( IntervalType, Interval, @@ -225,12 +228,12 @@ impl_daft_arrow_datatype!(BinaryType, Binary); impl_daft_arrow_datatype!(FixedSizeBinaryType, Unknown); impl_daft_arrow_datatype!(Utf8Type, Utf8); impl_daft_arrow_datatype!(ExtensionType, Unknown); +impl_daft_arrow_datatype!(Decimal128Type, Unknown); impl_nested_datatype!(FixedSizeListType, FixedSizeListArray); impl_nested_datatype!(StructType, StructArray); impl_nested_datatype!(ListType, ListArray); -impl_daft_logical_data_array_datatype!(Decimal128Type, Unknown, Int128Type); impl_daft_logical_data_array_datatype!(TimestampType, Unknown, Int64Type); impl_daft_logical_data_array_datatype!(DateType, Date, Int32Type); impl_daft_logical_data_array_datatype!(TimeType, Unknown, Int64Type); @@ -334,9 +337,11 @@ impl DaftNumericType for Int32Type { impl DaftNumericType for Int64Type { type Native = i64; } + impl DaftNumericType for Int128Type { type Native = i128; } + impl DaftNumericType for Float32Type { type Native = f32; } @@ -350,6 +355,18 @@ where { } +pub trait DaftPrimitiveType: Send + Sync + DaftArrowBackedType + 'static { + type Native: NumericNative; +} + +impl DaftPrimitiveType for T { + type Native = T::Native; +} + +impl DaftPrimitiveType for Decimal128Type { + type Native = i128; +} + impl DaftIntegerType for UInt8Type {} impl DaftIntegerType for UInt16Type {} impl DaftIntegerType for UInt32Type {} @@ -358,7 +375,6 @@ impl DaftIntegerType for Int8Type {} impl DaftIntegerType for Int16Type {} impl DaftIntegerType for Int32Type {} impl DaftIntegerType for Int64Type {} -impl DaftIntegerType for Int128Type {} pub trait DaftFloatType: DaftNumericType where @@ -380,7 +396,6 @@ pub type Int8Array = DataArray; pub type Int16Array = DataArray; pub type Int32Array = DataArray; pub type Int64Array = DataArray; -pub type Int128Array = DataArray; pub type UInt8Array = DataArray; pub type UInt16Array = DataArray; pub type UInt32Array = DataArray; @@ -392,6 +407,7 @@ pub type FixedSizeBinaryArray = DataArray; pub type Utf8Array = DataArray; pub type ExtensionArray = DataArray; pub type IntervalArray = DataArray; +pub type Decimal128Array = DataArray; #[cfg(feature = "python")] pub type PythonArray = DataArray; diff --git a/src/daft-core/src/datatypes/prelude.rs b/src/daft-core/src/datatypes/prelude.rs index 27aaa044fa..9ae8ed5762 100644 --- a/src/daft-core/src/datatypes/prelude.rs +++ b/src/daft-core/src/datatypes/prelude.rs @@ -14,8 +14,8 @@ pub use daft_schema::{ pub use super::PythonArray; pub use super::{ BinaryType, BooleanType, DaftArrayType, ExtensionType, FixedSizeBinaryType, FixedSizeListType, - Float32Type, Float64Type, Int128Type, Int16Type, Int32Type, Int64Type, Int8Type, NullType, - UInt16Type, UInt32Type, UInt64Type, UInt8Type, Utf8Type, + Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, NullType, UInt16Type, + UInt32Type, UInt64Type, UInt8Type, Utf8Type, }; // Import trait definitions pub use super::{ diff --git a/src/daft-core/src/series/array_impl/data_array.rs b/src/daft-core/src/series/array_impl/data_array.rs index 6ed822ef44..1efa60f3a7 100644 --- a/src/daft-core/src/series/array_impl/data_array.rs +++ b/src/daft-core/src/series/array_impl/data_array.rs @@ -171,7 +171,6 @@ impl_series_like_for_data_array!(Int8Array); impl_series_like_for_data_array!(Int16Array); impl_series_like_for_data_array!(Int32Array); impl_series_like_for_data_array!(Int64Array); -impl_series_like_for_data_array!(Int128Array); impl_series_like_for_data_array!(UInt8Array); impl_series_like_for_data_array!(UInt16Array); impl_series_like_for_data_array!(UInt32Array); @@ -181,5 +180,6 @@ impl_series_like_for_data_array!(Float64Array); impl_series_like_for_data_array!(Utf8Array); impl_series_like_for_data_array!(ExtensionArray); impl_series_like_for_data_array!(IntervalArray); +impl_series_like_for_data_array!(Decimal128Array); #[cfg(feature = "python")] impl_series_like_for_data_array!(PythonArray); diff --git a/src/daft-core/src/series/array_impl/logical_array.rs b/src/daft-core/src/series/array_impl/logical_array.rs index ec7d5c5ea3..85316f0ec4 100644 --- a/src/daft-core/src/series/array_impl/logical_array.rs +++ b/src/daft-core/src/series/array_impl/logical_array.rs @@ -169,7 +169,6 @@ macro_rules! impl_series_like_for_logical_array { }; } -impl_series_like_for_logical_array!(Decimal128Array); impl_series_like_for_logical_array!(DateArray); impl_series_like_for_logical_array!(TimeArray); impl_series_like_for_logical_array!(DurationArray); diff --git a/src/daft-core/src/series/ops/agg.rs b/src/daft-core/src/series/ops/agg.rs index b3bfee765c..1c7ef7e759 100644 --- a/src/daft-core/src/series/ops/agg.rs +++ b/src/daft-core/src/series/ops/agg.rs @@ -1,6 +1,5 @@ use arrow2::array::PrimitiveArray; use common_error::{DaftError, DaftResult}; -use logical::Decimal128Array; use crate::{ array::{ @@ -67,27 +66,21 @@ impl Series { .into_series()), None => Ok(DaftSumAggable::sum(&self.downcast::()?)?.into_series()), }, - DataType::Decimal128(_, _) => match groups { - Some(groups) => Ok(Decimal128Array::new( - Field { - dtype: try_sum_supertype(self.data_type())?, - ..self.field().clone() - }, - DaftSumAggable::grouped_sum( - &self.as_physical()?.downcast::()?, + DataType::Decimal128(_, _) => { + let casted = self.cast(&try_sum_supertype(self.data_type())?)?; + + match groups { + Some(groups) => Ok(DaftSumAggable::grouped_sum( + &casted.downcast::()?, groups, - )?, - ) - .into_series()), - None => Ok(Decimal128Array::new( - Field { - dtype: try_sum_supertype(self.data_type())?, - ..self.field().clone() - }, - DaftSumAggable::sum(&self.as_physical()?.downcast::()?)?, - ) - .into_series()), - }, + )? + .into_series()), + None => { + Ok(DaftSumAggable::sum(&casted.downcast::()?)? + .into_series()) + } + } + } other => Err(DaftError::TypeError(format!( "Numeric sum is not implemented for type {}", other diff --git a/src/daft-core/src/series/ops/arithmetic.rs b/src/daft-core/src/series/ops/arithmetic.rs index baf5087fb6..7b5d789603 100644 --- a/src/daft-core/src/series/ops/arithmetic.rs +++ b/src/daft-core/src/series/ops/arithmetic.rs @@ -62,6 +62,12 @@ impl Add for &Series { }) } // ---------------- + // Decimal Types + // ---------------- + DataType::Decimal128(..) => { + Ok(cast_downcast_op!(lhs, rhs, &output_type, Decimal128Array, add)?.into_series()) + } + // ---------------- // FixedSizeLists of numeric types (fsl, embedding, tensor, etc.) // ---------------- output_type if output_type.is_fixed_size_numeric() => { @@ -160,6 +166,12 @@ impl Sub for &Series { fixed_size_binary_op(lhs, rhs, output_type, FixedSizeBinaryOp::Sub) } // ---------------- + // Decimal Types + // ---------------- + DataType::Decimal128(..) => { + Ok(cast_downcast_op!(lhs, rhs, &output_type, Decimal128Array, sub)?.into_series()) + } + // ---------------- // Temporal types // ---------------- output_type @@ -245,6 +257,12 @@ impl Mul for &Series { }) } // ---------------- + // Decimal Types + // ---------------- + DataType::Decimal128(..) => { + Ok(cast_downcast_op!(lhs, rhs, &output_type, Decimal128Array, mul)?.into_series()) + } + // ---------------- // FixedSizeLists of numeric types (fsl, embedding, tensor, etc.) // ---------------- output_type if output_type.is_fixed_size_numeric() => { @@ -277,6 +295,12 @@ impl Div for &Series { ) } // ---------------- + // Decimal Types + // ---------------- + DataType::Decimal128(..) => { + Ok(cast_downcast_op!(lhs, rhs, &output_type, Decimal128Array, div)?.into_series()) + } + // ---------------- // FixedSizeLists of numeric types (fsl, embedding, tensor, etc.) // ---------------- output_type if output_type.is_fixed_size_numeric() => { diff --git a/src/daft-core/src/series/ops/downcast.rs b/src/daft-core/src/series/ops/downcast.rs index 5cb3cdd18c..af19767deb 100644 --- a/src/daft-core/src/series/ops/downcast.rs +++ b/src/daft-core/src/series/ops/downcast.rs @@ -8,7 +8,7 @@ use self::logical::{DurationArray, ImageArray, MapArray}; use crate::{ array::{ListArray, StructArray}, datatypes::{ - logical::{DateArray, Decimal128Array, FixedShapeImageArray, TimeArray, TimestampArray}, + logical::{DateArray, FixedShapeImageArray, TimeArray, TimestampArray}, *, }, series::{array_impl::ArrayWrapper, Series}, diff --git a/src/daft-core/src/series/serdes.rs b/src/daft-core/src/series/serdes.rs index 843950cc23..4160636b7e 100644 --- a/src/daft-core/src/series/serdes.rs +++ b/src/daft-core/src/series/serdes.rs @@ -10,7 +10,7 @@ use crate::{ }, datatypes::{ logical::{ - DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, + DateArray, DurationArray, EmbeddingArray, FixedShapeImageArray, FixedShapeSparseTensorArray, FixedShapeTensorArray, ImageArray, MapArray, SparseTensorArray, TensorArray, TimeArray, TimestampArray, }, @@ -85,57 +85,52 @@ impl<'d> serde::Deserialize<'d> for Series { )) .into_series()), DataType::Int8 => Ok(Int8Array::from_iter( - field.name.as_str(), + field, map.next_value::>>()?.into_iter(), ) .into_series()), DataType::Int16 => Ok(Int16Array::from_iter( - field.name.as_str(), + field, map.next_value::>>()?.into_iter(), ) .into_series()), DataType::Int32 => Ok(Int32Array::from_iter( - field.name.as_str(), + field, map.next_value::>>()?.into_iter(), ) .into_series()), DataType::Int64 => Ok(Int64Array::from_iter( - field.name.as_str(), + field, map.next_value::>>()?.into_iter(), ) .into_series()), - DataType::Int128 => Ok(Int128Array::from_iter( - field.name.as_str(), - map.next_value::>>()?.into_iter(), - ) - .into_series()), DataType::UInt8 => Ok(UInt8Array::from_iter( - field.name.as_str(), + field, map.next_value::>>()?.into_iter(), ) .into_series()), DataType::UInt16 => Ok(UInt16Array::from_iter( - field.name.as_str(), + field, map.next_value::>>()?.into_iter(), ) .into_series()), DataType::UInt32 => Ok(UInt32Array::from_iter( - field.name.as_str(), + field, map.next_value::>>()?.into_iter(), ) .into_series()), DataType::UInt64 => Ok(UInt64Array::from_iter( - field.name.as_str(), + field, map.next_value::>>()?.into_iter(), ) .into_series()), DataType::Float32 => Ok(Float32Array::from_iter( - field.name.as_str(), + field, map.next_value::>>()?.into_iter(), ) .into_series()), DataType::Float64 => Ok(Float64Array::from_iter( - field.name.as_str(), + field, map.next_value::>>()?.into_iter(), ) .into_series()), @@ -219,15 +214,11 @@ impl<'d> serde::Deserialize<'d> for Series { let validity = validity.map(|v| v.bool().unwrap().as_bitmap().clone()); Ok(FixedSizeListArray::new(field, flat_child, validity).into_series()) } - DataType::Decimal128(..) => { - type PType = <::PhysicalType as DaftDataType>::ArrayType; - let physical = map.next_value::()?; - Ok(Decimal128Array::new( - field, - physical.downcast::().unwrap().clone(), - ) - .into_series()) - } + DataType::Decimal128(..) => Ok(Decimal128Array::from_iter( + Arc::new(field.clone()), + map.next_value::>>()?.into_iter(), + ) + .into_series()), DataType::Timestamp(..) => { type PType = <::PhysicalType as DaftDataType>::ArrayType; let physical = map.next_value::()?; diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index 47b4090be5..8f0fba1fec 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -238,9 +238,10 @@ impl LiteralValue { ) .into_series(), Self::Float64(val) => Float64Array::from(("literal", [*val].as_slice())).into_series(), - Self::Decimal(val, ..) => { - let physical = Int128Array::from(("literal", [*val].as_slice())); - Decimal128Array::new(Field::new("literal", self.get_type()), physical).into_series() + Self::Decimal(val, p, s) => { + let dtype = DataType::Decimal128(*p as usize, *s as usize); + let field = Field::new("literal", dtype); + Decimal128Array::from_values_iter(field, std::iter::once(*val)).into_series() } Self::Series(series) => series.clone().rename("literal"), #[cfg(feature = "python")] @@ -495,19 +496,19 @@ pub fn literals_to_series(values: &[LiteralValue]) -> DaftResult { } DataType::Int32 => { let data = values.iter().map(|lit| unwrap_unchecked!(lit, Int32)); - Int32Array::from_iter("literal", data).into_series() + Int32Array::from_iter(Field::new("literal", DataType::Int32), data).into_series() } DataType::UInt32 => { let data = values.iter().map(|lit| unwrap_unchecked!(lit, UInt32)); - UInt32Array::from_iter("literal", data).into_series() + UInt32Array::from_iter(Field::new("literal", DataType::UInt32), data).into_series() } DataType::Int64 => { let data = values.iter().map(|lit| unwrap_unchecked!(lit, Int64)); - Int64Array::from_iter("literal", data).into_series() + Int64Array::from_iter(Field::new("literal", DataType::Int64), data).into_series() } DataType::UInt64 => { let data = values.iter().map(|lit| unwrap_unchecked!(lit, UInt64)); - UInt64Array::from_iter("literal", data).into_series() + UInt64Array::from_iter(Field::new("literal", DataType::UInt64), data).into_series() } DataType::Interval => { let data = values.iter().map(|lit| match lit { @@ -519,29 +520,29 @@ pub fn literals_to_series(values: &[LiteralValue]) -> DaftResult { } dtype @ DataType::Timestamp(_, _) => { let data = values.iter().map(|lit| unwrap_unchecked!(lit, Timestamp)); - let physical = Int64Array::from_iter("literal", data); + let physical = Int64Array::from_iter(Field::new("literal", DataType::Int64), data); TimestampArray::new(Field::new("literal", dtype), physical).into_series() } dtype @ DataType::Date => { let data = values.iter().map(|lit| unwrap_unchecked!(lit, Date)); - let physical = Int32Array::from_iter("literal", data); + let physical = Int32Array::from_iter(Field::new("literal", DataType::Int32), data); DateArray::new(Field::new("literal", dtype), physical).into_series() } dtype @ DataType::Time(_) => { let data = values.iter().map(|lit| unwrap_unchecked!(lit, Time)); - let physical = Int64Array::from_iter("literal", data); + let physical = Int64Array::from_iter(Field::new("literal", DataType::Int64), data); + TimeArray::new(Field::new("literal", dtype), physical).into_series() } DataType::Float64 => { let data = values.iter().map(|lit| unwrap_unchecked!(lit, Float64)); - Float64Array::from_iter("literal", data).into_series() + Float64Array::from_iter(Field::new("literal", dtype), data).into_series() } dtype @ DataType::Decimal128 { .. } => { let data = values.iter().map(|lit| unwrap_unchecked!(lit, Decimal)); - let physical = Int128Array::from_iter("literal", data); - Decimal128Array::new(Field::new("literal", dtype), physical).into_series() + Decimal128Array::from_iter(Field::new("literal", dtype), data).into_series() } _ => { return Err(DaftError::ValueError(format!( @@ -580,7 +581,10 @@ mod test { LiteralValue::UInt64(3), ]; let expected = vec![None, Some(2), Some(3)]; - let expected = UInt64Array::from_iter("literal", expected.into_iter()); + let expected = UInt64Array::from_iter( + Field::new("literal", DataType::UInt64), + expected.into_iter(), + ); let expected = expected.into_series(); let actual = super::literals_to_series(&values).unwrap(); // Series.eq returns false for nulls diff --git a/src/daft-functions/src/distance/cosine.rs b/src/daft-functions/src/distance/cosine.rs index 11b3d1eef2..b56cb418f2 100644 --- a/src/daft-functions/src/distance/cosine.rs +++ b/src/daft-functions/src/distance/cosine.rs @@ -117,7 +117,10 @@ impl ScalarUDF for CosineDistanceFunction { } }?; - let output = Float64Array::from_iter(source_name, res.into_iter()); + let output = Float64Array::from_iter( + Field::new(source_name, DataType::Float64), + res.into_iter(), + ); Ok(output.into_series()) } diff --git a/src/daft-functions/src/hash.rs b/src/daft-functions/src/hash.rs index f7ab7a7a30..f517ce74b6 100644 --- a/src/daft-functions/src/hash.rs +++ b/src/daft-functions/src/hash.rs @@ -30,7 +30,7 @@ impl ScalarUDF for HashFunction { let seed = seed.u64().unwrap(); let seed = seed.get(0).unwrap(); let seed = UInt64Array::from_iter( - "seed", + Field::new("seed", DataType::UInt64), std::iter::repeat(Some(seed)).take(input.len()), ); input diff --git a/src/daft-parquet/src/statistics/column_range.rs b/src/daft-parquet/src/statistics/column_range.rs index ac62627ea4..a604c8875f 100644 --- a/src/daft-parquet/src/statistics/column_range.rs +++ b/src/daft-parquet/src/statistics/column_range.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use arrow2::array::PrimitiveArray; use daft_core::prelude::*; use daft_stats::ColumnRangeStatistics; @@ -127,17 +129,13 @@ fn make_decimal_column_range_statistics( } let l = convert_i128(lower, lower.len()); let u = convert_i128(upper, upper.len()); - let lower = Int128Array::from(("lower", [l].as_slice())); - let upper = Int128Array::from(("upper", [u].as_slice())); + let daft_type = daft_core::datatypes::DataType::Decimal128(p, s); + let lower_field = Arc::new(daft_core::datatypes::Field::new("lower", daft_type.clone())); + let upper_field = Arc::new(daft_core::datatypes::Field::new("upper", daft_type)); - let lower = Decimal128Array::new( - daft_core::datatypes::Field::new("lower", daft_type.clone()), - lower, - ) - .into_series(); - let upper = Decimal128Array::new(daft_core::datatypes::Field::new("upper", daft_type), upper) - .into_series(); + let lower = Decimal128Array::from_iter(lower_field, std::iter::once(Some(l))).into_series(); + let upper = Decimal128Array::from_iter(upper_field, std::iter::once(Some(u))).into_series(); Ok(ColumnRangeStatistics::new(Some(lower), Some(upper))?) } diff --git a/src/daft-schema/src/dtype.rs b/src/daft-schema/src/dtype.rs index f9650e3396..77e8fc173e 100644 --- a/src/daft-schema/src/dtype.rs +++ b/src/daft-schema/src/dtype.rs @@ -30,9 +30,6 @@ pub enum DataType { /// An [`i64`] Int64, - /// An [`i128`] - Int128, - /// An [`u8`] UInt8, @@ -212,10 +209,6 @@ impl DataType { Self::Int16 => Ok(ArrowType::Int16), Self::Int32 => Ok(ArrowType::Int32), Self::Int64 => Ok(ArrowType::Int64), - // Must maintain same default mapping as Arrow2, otherwise this will throw errors in - // DataArray::new() which makes strong assumptions about the arrow/Daft types - // https://github.com/jorgecarleitao/arrow2/blob/b0734542c2fef5d2d0c7b6ffce5d094de371168a/src/datatypes/mod.rs#L493 - Self::Int128 => Ok(ArrowType::Decimal(32, 32)), Self::UInt8 => Ok(ArrowType::UInt8), Self::UInt16 => Ok(ArrowType::UInt16), Self::UInt32 => Ok(ArrowType::UInt32), @@ -311,7 +304,6 @@ impl DataType { pub fn to_physical(&self) -> Self { use DataType::*; match self { - Decimal128(..) => Int128, Date => Int32, Duration(_) | Timestamp(..) | Time(_) => Int64, @@ -374,7 +366,6 @@ impl DataType { | Self::Int16 | Self::Int32 | Self::Int64 - | Self::Int128 | Self::UInt8 | Self::UInt16 | Self::UInt32 @@ -427,7 +418,6 @@ impl DataType { | Self::Int16 | Self::Int32 | Self::Int64 - | Self::Int128 | Self::UInt8 | Self::UInt16 | Self::UInt32 @@ -565,7 +555,7 @@ impl DataType { Self::Int16 => Some(2.), Self::Int32 => Some(4.), Self::Int64 => Some(8.), - Self::Int128 => Some(16.), + Self::Decimal128(..) => Some(16.), Self::UInt8 => Some(1.), Self::UInt16 => Some(2.), Self::UInt32 => Some(4.), @@ -596,8 +586,7 @@ impl DataType { pub fn is_logical(&self) -> bool { matches!( self, - Self::Decimal128(..) - | Self::Date + Self::Date | Self::Time(..) | Self::Timestamp(..) | Self::Duration(..) diff --git a/src/daft-stats/src/column_stats/mod.rs b/src/daft-stats/src/column_stats/mod.rs index eae64bbebd..491c63ec40 100644 --- a/src/daft-stats/src/column_stats/mod.rs +++ b/src/daft-stats/src/column_stats/mod.rs @@ -60,7 +60,7 @@ impl ColumnRangeStatistics { DataType::Null | // Numeric types - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 | DataType::Int128 | + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 | DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 | DataType::Float32 | DataType::Float64 | DataType::Decimal128(..) | DataType::Boolean | diff --git a/src/daft-table/src/repr_html.rs b/src/daft-table/src/repr_html.rs index 20e20e322f..a5cf9a822e 100644 --- a/src/daft-table/src/repr_html.rs +++ b/src/daft-table/src/repr_html.rs @@ -30,10 +30,6 @@ pub fn html_value(s: &Series, idx: usize) -> String { let arr = s.i64().unwrap(); arr.html_value(idx) } - DataType::Int128 => { - let arr = s.i128().unwrap(); - arr.html_value(idx) - } DataType::UInt8 => { let arr = s.u8().unwrap(); arr.html_value(idx) diff --git a/tests/series/test_arithmetic.py b/tests/series/test_arithmetic.py index eb369f232c..0199041bc9 100644 --- a/tests/series/test_arithmetic.py +++ b/tests/series/test_arithmetic.py @@ -9,17 +9,31 @@ from daft import DataType, Series arrow_int_types = [pa.int8(), pa.uint8(), pa.int16(), pa.uint16(), pa.int32(), pa.uint32(), pa.int64(), pa.uint64()] -arrow_string_types = [pa.string(), pa.large_string()] +arrow_decimal_types = [pa.decimal128(4, 0), pa.decimal128(5, 1)] arrow_float_types = [pa.float32(), pa.float64()] +arrow_number_types = arrow_int_types + arrow_decimal_types + arrow_float_types +arrow_string_types = [pa.string(), pa.large_string()] + + +def arrow_number_combinations(): + for left in arrow_number_types: + for right in arrow_number_types: + # we can't perform all ops on decimal and 64 bit ints + if pa.types.is_decimal(left) and (pa.types.is_int64(right) or pa.types.is_uint64(right)): + continue + if pa.types.is_decimal(right) and (pa.types.is_int64(left) or pa.types.is_uint64(left)): + continue + + yield (left, right) -@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_float_types, repeat=2)) +@pytest.mark.parametrize("l_dtype, r_dtype", arrow_number_combinations()) def test_arithmetic_numbers_array(l_dtype, r_dtype) -> None: - l_arrow = pa.array([1, 2, 3, None, 5, None]) - r_arrow = pa.array([1, 4, 1, 5, None, None]) + l_arrow = pa.array([1, 2, 3, None, 5, None], type=l_dtype) + r_arrow = pa.array([1, 4, 1, 5, None, None], type=r_dtype) - left = Series.from_arrow(l_arrow.cast(l_dtype), name="left") - right = Series.from_arrow(r_arrow.cast(r_dtype), name="right") + left = Series.from_arrow(l_arrow, name="left") + right = Series.from_arrow(r_arrow, name="right") add = left + right assert add.name() == left.name() @@ -35,24 +49,25 @@ def test_arithmetic_numbers_array(l_dtype, r_dtype) -> None: div = left / right assert div.name() == left.name() - assert div.to_pylist() == [1.0, 0.5, 3.0, None, None, None] + assert div.cast(DataType.float64()).to_pylist() == [1.0, 0.5, 3.0, None, None, None] - mod = left % right - assert mod.name() == left.name() - assert mod.to_pylist() == [0, 2, 0, None, None, None] + if not pa.types.is_decimal(l_dtype) and not pa.types.is_decimal(r_dtype): + floor_div = left // right + assert floor_div.name() == left.name() + assert floor_div.to_pylist() == [1, 0, 3, None, None, None] - floor_div = left // right - assert floor_div.name() == left.name() - assert floor_div.to_pylist() == [1, 0, 3, None, None, None] + mod = left % right + assert mod.name() == left.name() + assert mod.to_pylist() == [0, 2, 0, None, None, None] -@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_float_types, repeat=2)) +@pytest.mark.parametrize("l_dtype, r_dtype", arrow_number_combinations()) def test_arithmetic_numbers_left_scalar(l_dtype, r_dtype) -> None: - l_arrow = pa.array([1]) - r_arrow = pa.array([1, 4, 1, 5, None, None]) + l_arrow = pa.array([1], type=l_dtype) + r_arrow = pa.array([1, 4, 1, 5, None, None], type=r_dtype) - left = Series.from_arrow(l_arrow.cast(l_dtype), name="left") - right = Series.from_arrow(r_arrow.cast(r_dtype), name="right") + left = Series.from_arrow(l_arrow, name="left") + right = Series.from_arrow(r_arrow, name="right") add = left + right assert add.name() == left.name() @@ -69,24 +84,25 @@ def test_arithmetic_numbers_left_scalar(l_dtype, r_dtype) -> None: div = left / right assert div.name() == left.name() - assert div.to_pylist() == [1.0, 0.25, 1.0, 0.2, None, None] + assert div.cast(DataType.float64()).to_pylist() == [1.0, 0.25, 1.0, 0.2, None, None] - floor_div = left // right - assert floor_div.name() == left.name() - assert floor_div.to_pylist() == [1, 0, 1, 0, None, None] + if not pa.types.is_decimal(l_dtype) and not pa.types.is_decimal(r_dtype): + floor_div = left // right + assert floor_div.name() == left.name() + assert floor_div.to_pylist() == [1, 0, 1, 0, None, None] - mod = left % right - assert mod.name() == left.name() - assert mod.to_pylist() == [0, 1, 0, 1, None, None] + mod = left % right + assert mod.name() == left.name() + assert mod.to_pylist() == [0, 1, 0, 1, None, None] -@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_float_types, repeat=2)) +@pytest.mark.parametrize("l_dtype, r_dtype", arrow_number_combinations()) def test_arithmetic_numbers_right_scalar(l_dtype, r_dtype) -> None: - l_arrow = pa.array([1, 2, 3, None, 5, None]) - r_arrow = pa.array([1]) + l_arrow = pa.array([1, 2, 3, None, 5, None], type=l_dtype) + r_arrow = pa.array([1], type=r_dtype) - left = Series.from_arrow(l_arrow.cast(l_dtype), name="left") - right = Series.from_arrow(r_arrow.cast(r_dtype), name="right") + left = Series.from_arrow(l_arrow, name="left") + right = Series.from_arrow(r_arrow, name="right") add = left + right assert add.name() == left.name() @@ -103,23 +119,24 @@ def test_arithmetic_numbers_right_scalar(l_dtype, r_dtype) -> None: div = left / right assert div.name() == left.name() - assert div.to_pylist() == [1.0, 2.0, 3.0, None, 5.0, None] + assert div.cast(DataType.float64()).to_pylist() == [1.0, 2.0, 3.0, None, 5.0, None] - floor_div = left // right - assert floor_div.name() == left.name() - assert floor_div.to_pylist() == [1, 2, 3, None, 5, None] + if not pa.types.is_decimal(l_dtype) and not pa.types.is_decimal(r_dtype): + floor_div = left // right + assert floor_div.name() == left.name() + assert floor_div.to_pylist() == [1, 2, 3, None, 5, None] - mod = left % right - assert mod.name() == left.name() - assert mod.to_pylist() == [0, 0, 0, None, 0, None] + mod = left % right + assert mod.name() == left.name() + assert mod.to_pylist() == [0, 0, 0, None, 0, None] -@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_float_types, repeat=2)) +@pytest.mark.parametrize("l_dtype, r_dtype", arrow_number_combinations()) def test_arithmetic_numbers_null_scalar(l_dtype, r_dtype) -> None: - l_arrow = pa.array([1, 2, 3, None, 5, None]) + l_arrow = pa.array([1, 2, 3, None, 5, None], type=l_dtype) r_arrow = pa.array([None], type=r_dtype) - left = Series.from_arrow(l_arrow.cast(l_dtype), name="left") + left = Series.from_arrow(l_arrow, name="left") right = Series.from_arrow(r_arrow, name="right") add = left + right @@ -139,13 +156,14 @@ def test_arithmetic_numbers_null_scalar(l_dtype, r_dtype) -> None: assert div.name() == left.name() assert div.to_pylist() == [None, None, None, None, None, None] - floor_div = left / right - assert floor_div.name() == left.name() - assert floor_div.to_pylist() == [None, None, None, None, None, None] + if not pa.types.is_decimal(l_dtype) and not pa.types.is_decimal(r_dtype): + floor_div = left / right + assert floor_div.name() == left.name() + assert floor_div.to_pylist() == [None, None, None, None, None, None] - mod = left % right - assert mod.name() == left.name() - assert mod.to_pylist() == [None, None, None, None, None, None] + mod = left % right + assert mod.name() == left.name() + assert mod.to_pylist() == [None, None, None, None, None, None] @pytest.mark.parametrize( diff --git a/tests/series/test_comparisons.py b/tests/series/test_comparisons.py index bfbfd2b236..37a52330c6 100644 --- a/tests/series/test_comparisons.py +++ b/tests/series/test_comparisons.py @@ -11,24 +11,32 @@ from daft import DataType, Series arrow_int_types = [pa.int8(), pa.uint8(), pa.int16(), pa.uint16(), pa.int32(), pa.uint32(), pa.int64(), pa.uint64()] +arrow_decimal_types = [pa.decimal128(20, 5), pa.decimal128(15, 9)] arrow_string_types = [pa.string(), pa.large_string()] arrow_float_types = [pa.float32(), pa.float64()] arrow_binary_types = [pa.binary(), pa.large_binary()] +arrow_number_types = arrow_int_types + arrow_decimal_types + arrow_float_types -VALID_INT_STRING_COMPARISONS = list(itertools.product(arrow_int_types, repeat=2)) + list( +VALID_INT_STRING_COMPARISONS = list(itertools.product(arrow_int_types + arrow_decimal_types, repeat=2)) + list( itertools.product(arrow_string_types, repeat=2) ) +def make_array(data: list, type=None) -> pa.array: + if type is not None and (pa.types.is_string(type) or pa.types.is_large_string(type)): + data = [str(x) if x is not None else None for x in data] + return pa.array(data, type=type) + + @pytest.mark.parametrize("l_dtype, r_dtype", VALID_INT_STRING_COMPARISONS) def test_comparisons_int_and_str(l_dtype, r_dtype) -> None: - l_arrow = pa.array([1, 2, 3, None, 5, None]) - r_arrow = pa.array([1, 3, 1, 5, None, None]) + l_arrow = make_array([1, 2, 3, None, 5, None], type=l_dtype) + r_arrow = make_array([1, 3, 1, 5, None, None], type=r_dtype) # eq, lt, gt, None, None, None - left = Series.from_arrow(l_arrow.cast(l_dtype)) - right = Series.from_arrow(r_arrow.cast(r_dtype)) + left = Series.from_arrow(l_arrow) + right = Series.from_arrow(r_arrow) lt = (left < right).to_pylist() assert lt == [False, True, False, None, None, None] @@ -50,12 +58,12 @@ def test_comparisons_int_and_str(l_dtype, r_dtype) -> None: @pytest.mark.parametrize("l_dtype, r_dtype", VALID_INT_STRING_COMPARISONS) def test_comparisons_int_and_str_left_scalar(l_dtype, r_dtype) -> None: - l_arrow = pa.array([2]) - r_arrow = pa.array([1, 2, 3, None]) + l_arrow = make_array([2], type=l_dtype) + r_arrow = make_array([1, 2, 3, None], type=r_dtype) # gt, eq, lt - left = Series.from_arrow(l_arrow.cast(l_dtype)) - right = Series.from_arrow(r_arrow.cast(r_dtype)) + left = Series.from_arrow(l_arrow) + right = Series.from_arrow(r_arrow) lt = (left < right).to_pylist() assert lt == [False, False, True, None] @@ -78,12 +86,12 @@ def test_comparisons_int_and_str_left_scalar(l_dtype, r_dtype) -> None: @pytest.mark.parametrize("l_dtype, r_dtype", VALID_INT_STRING_COMPARISONS) def test_comparisons_int_and_str_right_scalar(l_dtype, r_dtype) -> None: - l_arrow = pa.array([1, 2, 3, None, 5, None]) - r_arrow = pa.array([2]) + l_arrow = make_array([1, 2, 3, None, 5, None], type=l_dtype) + r_arrow = make_array([2], type=r_dtype) # lt, eq, gt, None, gt, None - left = Series.from_arrow(l_arrow.cast(l_dtype)) - right = Series.from_arrow(r_arrow.cast(r_dtype)) + left = Series.from_arrow(l_arrow) + right = Series.from_arrow(r_arrow) lt = (left < right).to_pylist() assert lt == [True, False, False, None, False, None] @@ -105,11 +113,11 @@ def test_comparisons_int_and_str_right_scalar(l_dtype, r_dtype) -> None: @pytest.mark.parametrize("l_dtype, r_dtype", VALID_INT_STRING_COMPARISONS) def test_comparisons_int_and_str_right_null_scalar(l_dtype, r_dtype) -> None: - l_arrow = pa.array([1, 2, 3, None, 5, None]) - r_arrow = pa.array([None], type=r_dtype) + l_arrow = make_array([1, 2, 3, None, 5, None], type=l_dtype) + r_arrow = make_array([None], type=r_dtype) # lt, eq, gt, None, gt, None - left = Series.from_arrow(l_arrow.cast(l_dtype)) + left = Series.from_arrow(l_arrow) right = Series.from_arrow(r_arrow) lt = (left < right).to_pylist() assert lt == [None, None, None, None, None, None] @@ -130,14 +138,16 @@ def test_comparisons_int_and_str_right_null_scalar(l_dtype, r_dtype) -> None: assert gt == [None, None, None, None, None, None] -@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_float_types, repeat=2)) +@pytest.mark.parametrize( + "l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_float_types + arrow_decimal_types, repeat=2) +) def test_comparisons_int_and_float(l_dtype, r_dtype) -> None: - l_arrow = pa.array([1, 2, 3, None, 5, None]) - r_arrow = pa.array([1, 3, 1, 5, None, None]) + l_arrow = make_array([1, 2, 3, None, 5, None], type=l_dtype) + r_arrow = make_array([1, 3, 1, 5, None, None], type=r_dtype) # eq, lt, gt, None, None, None - left = Series.from_arrow(l_arrow.cast(l_dtype)) - right = Series.from_arrow(r_arrow.cast(r_dtype)) + left = Series.from_arrow(l_arrow) + right = Series.from_arrow(r_arrow) lt = (left < right).to_pylist() assert lt == [False, True, False, None, None, None] @@ -157,14 +167,16 @@ def test_comparisons_int_and_float(l_dtype, r_dtype) -> None: assert gt == [False, False, True, None, None, None] -@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_float_types, repeat=2)) +@pytest.mark.parametrize( + "l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_float_types + arrow_decimal_types, repeat=2) +) def test_comparisons_int_and_float_right_scalar(l_dtype, r_dtype) -> None: - l_arrow = pa.array([1, 2, 3, None, 5, None]) - r_arrow = pa.array([2]) + l_arrow = make_array([1, 2, 3, None, 5, None], type=l_dtype) + r_arrow = make_array([2], type=r_dtype) # lt, eq, gt, None, gt, None - left = Series.from_arrow(l_arrow.cast(l_dtype)) - right = Series.from_arrow(r_arrow.cast(r_dtype)) + left = Series.from_arrow(l_arrow) + right = Series.from_arrow(r_arrow) lt = (left < right).to_pylist() assert lt == [True, False, False, None, False, None] @@ -184,13 +196,15 @@ def test_comparisons_int_and_float_right_scalar(l_dtype, r_dtype) -> None: assert gt == [False, False, True, None, True, None] -@pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_float_types, repeat=2)) +@pytest.mark.parametrize( + "l_dtype, r_dtype", itertools.product(arrow_int_types + arrow_float_types + arrow_decimal_types, repeat=2) +) def test_comparisons_int_and_float_right_null_scalar(l_dtype, r_dtype) -> None: - l_arrow = pa.array([1, 2, 3, None, 5, None]) - r_arrow = pa.array([None], type=r_dtype) + l_arrow = make_array([1, 2, 3, None, 5, None], type=l_dtype) + r_arrow = make_array([None], type=r_dtype) # lt, eq, gt, None, gt, None - left = Series.from_arrow(l_arrow.cast(l_dtype)) + left = Series.from_arrow(l_arrow) right = Series.from_arrow(r_arrow) lt = (left < right).to_pylist() assert lt == [None, None, None, None, None, None] @@ -212,8 +226,8 @@ def test_comparisons_int_and_float_right_null_scalar(l_dtype, r_dtype) -> None: def test_comparisons_boolean_array() -> None: - l_arrow = pa.array([False, False, None, True, None]) - r_arrow = pa.array([True, False, True, None, None]) + l_arrow = make_array([False, False, None, True, None]) + r_arrow = make_array([True, False, True, None, None]) # lt, eq, lt, None left = Series.from_arrow(l_arrow) @@ -248,8 +262,8 @@ def test_comparisons_boolean_array() -> None: def test_comparisons_boolean_array_right_scalar() -> None: - l_arrow = pa.array([False, True, None]) - r_arrow = pa.array([True]) + l_arrow = make_array([False, True, None]) + r_arrow = make_array([True]) left = Series.from_arrow(l_arrow) right = Series.from_arrow(r_arrow) @@ -281,7 +295,7 @@ def test_comparisons_boolean_array_right_scalar() -> None: _xor = (left ^ right).to_pylist() assert _xor == [True, False, None] - r_arrow = pa.array([False]) + r_arrow = make_array([False]) right = Series.from_arrow(r_arrow) lt = (left < right).to_pylist() @@ -311,7 +325,7 @@ def test_comparisons_boolean_array_right_scalar() -> None: _xor = (left ^ right).to_pylist() assert _xor == [False, True, None] - r_arrow = pa.array([None], type=pa.bool_()) + r_arrow = make_array([None], type=pa.bool_()) right = Series.from_arrow(r_arrow) lt = (left < right).to_pylist() @@ -343,8 +357,8 @@ def test_comparisons_boolean_array_right_scalar() -> None: def test_comparisons_boolean_array_left_scalar() -> None: - l_arrow = pa.array([True]) - r_arrow = pa.array([False, True, None]) + l_arrow = make_array([True]) + r_arrow = make_array([False, True, None]) # lt, eq, lt, None left = Series.from_arrow(l_arrow) @@ -379,7 +393,7 @@ def test_comparisons_boolean_array_left_scalar() -> None: def test_comparisons_bad_right_value() -> None: - l_arrow = pa.array([1, 2, 3, None, 5, None]) + l_arrow = make_array([1, 2, 3, None, 5, None]) left = Series.from_arrow(l_arrow) right = [1, 2, 3, None, 5, None] @@ -413,8 +427,8 @@ def test_comparisons_bad_right_value() -> None: def test_boolean_array_mismatch_length() -> None: - l_arrow = pa.array([False, True, None, None]) - r_arrow = pa.array([False, True, False, True, None]) + l_arrow = make_array([False, True, None, None]) + r_arrow = make_array([False, True, False, True, None]) left = Series.from_arrow(l_arrow) right = Series.from_arrow(r_arrow) @@ -448,8 +462,8 @@ def test_boolean_array_mismatch_length() -> None: def test_logical_ops_with_non_boolean() -> None: - l_arrow = pa.array([False, True, None, None]) - r_arrow = pa.array([1, 2, 3, 4]) + l_arrow = make_array([False, True, None, None]) + r_arrow = make_array([1, 2, 3, 4]) left = Series.from_arrow(l_arrow) right = Series.from_arrow(r_arrow) @@ -502,8 +516,8 @@ def test_comparisons_dates() -> None: @pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_binary_types, repeat=2)) def test_comparisons_binary(l_dtype, r_dtype) -> None: - l_arrow = pa.array([b"1", b"22", b"333", None, b"55555", None]) - r_arrow = pa.array([b"1", b"333", b"1", b"55555", None, None]) + l_arrow = make_array([b"1", b"22", b"333", None, b"55555", None]) + r_arrow = make_array([b"1", b"333", b"1", b"55555", None, None]) # eq, lt, gt, None, None, None left = Series.from_arrow(l_arrow.cast(l_dtype)) @@ -529,8 +543,8 @@ def test_comparisons_binary(l_dtype, r_dtype) -> None: @pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_binary_types, repeat=2)) def test_comparisons_binary_left_scalar(l_dtype, r_dtype) -> None: - l_arrow = pa.array([b"22"]) - r_arrow = pa.array([b"1", b"22", b"333", None]) + l_arrow = make_array([b"22"]) + r_arrow = make_array([b"1", b"22", b"333", None]) # gt, eq, lt left = Series.from_arrow(l_arrow.cast(l_dtype)) @@ -557,8 +571,8 @@ def test_comparisons_binary_left_scalar(l_dtype, r_dtype) -> None: @pytest.mark.parametrize("l_dtype, r_dtype", itertools.product(arrow_binary_types, repeat=2)) def test_comparisons_binary_right_scalar(l_dtype, r_dtype) -> None: - l_arrow = pa.array([b"1", b"22", b"333", None, b"55555", None]) - r_arrow = pa.array([b"22"]) + l_arrow = make_array([b"1", b"22", b"333", None, b"55555", None]) + r_arrow = make_array([b"22"]) # lt, eq, gt, None, gt, None left = Series.from_arrow(l_arrow.cast(l_dtype)) @@ -583,8 +597,8 @@ def test_comparisons_binary_right_scalar(l_dtype, r_dtype) -> None: def test_comparisons_fixed_size_binary() -> None: - l_arrow = pa.array([b"11111", b"22222", b"33333", None, b"12345", None], type=pa.binary(5)) - r_arrow = pa.array([b"11111", b"33333", b"11111", b"12345", None, None], type=pa.binary(5)) + l_arrow = make_array([b"11111", b"22222", b"33333", None, b"12345", None], type=pa.binary(5)) + r_arrow = make_array([b"11111", b"33333", b"11111", b"12345", None, None], type=pa.binary(5)) # eq, lt, gt, None, None, None left = Series.from_arrow(l_arrow) @@ -609,8 +623,8 @@ def test_comparisons_fixed_size_binary() -> None: def test_comparisons_fixed_size_binary_left_scalar() -> None: - l_arrow = pa.array([b"222"], type=pa.binary(3)) - r_arrow = pa.array([b"111", b"222", b"333", None], type=pa.binary(3)) + l_arrow = make_array([b"222"], type=pa.binary(3)) + r_arrow = make_array([b"111", b"222", b"333", None], type=pa.binary(3)) # gt, eq, lt left = Series.from_arrow(l_arrow) @@ -636,8 +650,8 @@ def test_comparisons_fixed_size_binary_left_scalar() -> None: def test_comparisons_fixed_size_binary_right_scalar() -> None: - l_arrow = pa.array([b"111", b"222", b"333", None, b"555", None], type=pa.binary(3)) - r_arrow = pa.array([b"222"], type=pa.binary(3)) + l_arrow = make_array([b"111", b"222", b"333", None, b"555", None], type=pa.binary(3)) + r_arrow = make_array([b"222"], type=pa.binary(3)) # lt, eq, gt, None, gt, None left = Series.from_arrow(l_arrow)