Skip to content

Commit

Permalink
[FEAT] dec128 math (#3143)
Browse files Browse the repository at this point in the history
* Removes Int128 Type
* Refactors Decimal128 to be backed by a DataArray rather than a LogicalArray
* Implements math operations for Decimal
* Implements comparison operations for Decimal
  • Loading branch information
samster25 authored Oct 30, 2024
1 parent eaf4d03 commit c78fef4
Show file tree
Hide file tree
Showing 43 changed files with 810 additions and 382 deletions.
1 change: 0 additions & 1 deletion src/arrow2/src/compute/arithmetics/decimal/mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ use super::{adjusted_precision_scale, get_parameters, max_value, number_digits};
/// ```
pub fn mul(lhs: &PrimitiveArray<i128>, rhs: &PrimitiveArray<i128>) -> PrimitiveArray<i128> {
let (precision, scale) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();

let scale = 10i128.pow(scale as u32);
let max = max_value(precision);

Expand Down
41 changes: 33 additions & 8 deletions src/daft-core/src/array/from_iter.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,44 @@
use arrow2::types::months_days_ns;
use std::sync::Arc;

use arrow2::{
array::{MutablePrimitiveArray, PrimitiveArray},
types::months_days_ns,
};

use super::DataArray;
use crate::{array::prelude::*, datatypes::prelude::*};
use crate::{
array::prelude::*,
datatypes::{prelude::*, DaftPrimitiveType},
};

impl<T> DataArray<T>
where
T: DaftNumericType,
T: DaftPrimitiveType,
{
pub fn from_iter(
name: &str,
pub fn from_iter<F: Into<Arc<Field>>>(
field: F,
iter: impl arrow2::trusted_len::TrustedLen<Item = Option<T::Native>>,
) -> Self {
let arrow_array =
Box::new(arrow2::array::PrimitiveArray::<T::Native>::from_trusted_len_iter(iter));
Self::new(Field::new(name, T::get_dtype()).into(), arrow_array).unwrap()
// Workaround: build the array with the field's own arrow dtype to prevent overflow issues
// when dealing with i128 and decimal. Without it, the result array would always be Decimal(32, 32).
let field = field.into();
let mut array = MutablePrimitiveArray::<T::Native>::from(field.dtype.to_arrow().unwrap());
array.extend_trusted_len(iter);
let data_array: PrimitiveArray<_> = array.into();
Self::new(field, data_array.boxed()).unwrap()
}

/// Builds a `DataArray` from an iterator of non-null native values.
///
/// The arrow array is created with the arrow dtype derived from `field` rather than
/// the native type's default dtype. For `i128` the default would always be
/// `Decimal(32, 32)`, which leads to overflow issues for decimal data.
///
/// # Panics
///
/// Panics if the field's dtype cannot be converted to an arrow dtype, or if
/// `Self::new` rejects the resulting field/array pair.
pub fn from_values_iter<F: Into<Arc<Field>>>(
    field: F,
    iter: impl arrow2::trusted_len::TrustedLen<Item = T::Native>,
) -> Self {
    let field: Arc<Field> = field.into();
    let arrow_dtype = field.dtype.to_arrow().unwrap();

    // Seed the builder with the field's dtype so precision/scale are preserved.
    let mut builder = MutablePrimitiveArray::<T::Native>::from(arrow_dtype);
    builder.extend_trusted_len_values(iter);

    let frozen: PrimitiveArray<T::Native> = builder.into();
    Self::new(field, frozen.boxed()).unwrap()
}
}

Expand Down
11 changes: 6 additions & 5 deletions src/daft-core/src/array/growable/arrow_growable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,6 @@ impl_arrow_backed_data_array_growable!(
Int64Type,
arrow2::array::growable::GrowablePrimitive<'a, i64>
);
impl_arrow_backed_data_array_growable!(
ArrowInt128Growable,
Int128Type,
arrow2::array::growable::GrowablePrimitive<'a, i128>
);
impl_arrow_backed_data_array_growable!(
ArrowUInt8Growable,
UInt8Type,
Expand Down Expand Up @@ -168,6 +163,12 @@ impl_arrow_backed_data_array_growable!(
arrow2::array::growable::GrowablePrimitive<'a, months_days_ns>
);

// Growable for `Decimal128Type` arrays, built on arrow2's `GrowablePrimitive` over `i128`.
impl_arrow_backed_data_array_growable!(
ArrowDecimal128Growable,
Decimal128Type,
arrow2::array::growable::GrowablePrimitive<'a, i128>
);

/// ExtensionTypes are slightly different, because they have a dynamic inner type
pub struct ArrowExtensionGrowable<'a> {
name: String,
Expand Down
1 change: 0 additions & 1 deletion src/daft-core/src/array/growable/logical_growable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,5 @@ impl_logical_growable!(
FixedShapeSparseTensorType
);
impl_logical_growable!(LogicalImageGrowable, ImageType);
impl_logical_growable!(LogicalDecimal128Growable, Decimal128Type);
impl_logical_growable!(LogicalTensorGrowable, TensorType);
impl_logical_growable!(LogicalMapGrowable, MapType);
6 changes: 1 addition & 5 deletions src/daft-core/src/array/growable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ impl_growable_array!(Int8Array, arrow_growable::ArrowInt8Growable<'a>);
impl_growable_array!(Int16Array, arrow_growable::ArrowInt16Growable<'a>);
impl_growable_array!(Int32Array, arrow_growable::ArrowInt32Growable<'a>);
impl_growable_array!(Int64Array, arrow_growable::ArrowInt64Growable<'a>);
impl_growable_array!(Int128Array, arrow_growable::ArrowInt128Growable<'a>);
impl_growable_array!(Decimal128Array, arrow_growable::ArrowDecimal128Growable<'a>);
impl_growable_array!(UInt8Array, arrow_growable::ArrowUInt8Growable<'a>);
impl_growable_array!(UInt16Array, arrow_growable::ArrowUInt16Growable<'a>);
impl_growable_array!(UInt32Array, arrow_growable::ArrowUInt32Growable<'a>);
Expand Down Expand Up @@ -218,8 +218,4 @@ impl_growable_array!(
);
impl_growable_array!(ImageArray, logical_growable::LogicalImageGrowable<'a>);
impl_growable_array!(TensorArray, logical_growable::LogicalTensorGrowable<'a>);
impl_growable_array!(
Decimal128Array,
logical_growable::LogicalDecimal128Growable<'a>
);
impl_growable_array!(MapArray, logical_growable::LogicalMapGrowable<'a>);
31 changes: 15 additions & 16 deletions src/daft-core/src/array/ops/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,25 @@ use arrow2::array::PrimitiveArray;
use common_error::{DaftError, DaftResult};

use super::full::FullNull;
use crate::{array::DataArray, datatypes::DaftNumericType, utils::arrow::arrow_bitmap_and_helper};
use crate::{
array::DataArray,
datatypes::{DaftNumericType, DaftPrimitiveType},
utils::arrow::arrow_bitmap_and_helper,
};

impl<T> DataArray<T>
where
T: DaftNumericType,
T: DaftPrimitiveType,
{
// applies a native function to a numeric DataArray maintaining validity of the source array.
pub fn apply<F>(&self, func: F) -> DaftResult<Self>
where
F: Fn(T::Native) -> T::Native + Copy,
{
let arr: &PrimitiveArray<T::Native> = self.data().as_any().downcast_ref().unwrap();
let result_arr =
PrimitiveArray::from_trusted_len_values_iter(arr.values_iter().map(|v| func(*v)))
.with_validity(arr.validity().cloned());
let iter = arr.values_iter().map(|v| func(*v));

Ok(Self::from((self.name(), Box::new(result_arr))))
Self::from_values_iter(self.field.clone(), iter).with_validity(arr.validity().cloned())
}

// applies a native binary function to two DataArrays, maintaining validity.
Expand All @@ -40,11 +42,10 @@ where
rhs.data().as_any().downcast_ref().unwrap();

let validity = arrow_bitmap_and_helper(lhs_arr.validity(), rhs_arr.validity());
let result_arr = PrimitiveArray::from_trusted_len_values_iter(
zip(lhs_arr.values_iter(), rhs_arr.values_iter()).map(|(a, b)| func(*a, *b)),
)
.with_validity(validity);
Ok(Self::from((self.name(), Box::new(result_arr))))

let iter =
zip(lhs_arr.values_iter(), rhs_arr.values_iter()).map(|(a, b)| func(*a, *b));
Self::from_values_iter(self.field.clone(), iter).with_validity(validity)
}
(l_size, 1) => {
if let Some(value) = rhs.get(0) {
Expand All @@ -57,11 +58,9 @@ where
let rhs_arr: &PrimitiveArray<R::Native> =
rhs.data().as_any().downcast_ref().unwrap();
if let Some(value) = self.get(0) {
let result_arr = PrimitiveArray::from_trusted_len_values_iter(
rhs_arr.values_iter().map(|v| func(value, *v)),
)
.with_validity(rhs_arr.validity().cloned());
Ok(Self::from((self.name(), Box::new(result_arr))))
let iter = rhs_arr.values_iter().map(|v| func(value, *v));
Self::from_values_iter(self.field.clone(), iter)
.with_validity(rhs_arr.validity().cloned())
} else {
Ok(Self::full_null(self.name(), self.data_type(), r_size))
}
Expand Down
121 changes: 115 additions & 6 deletions src/daft-core/src/array/ops/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ use common_error::{DaftError, DaftResult};
use super::{as_arrow::AsArrow, full::FullNull};
use crate::{
array::{DataArray, FixedSizeListArray},
datatypes::{DaftNumericType, DataType, Field, Utf8Array},
datatypes::{DaftNumericType, DaftPrimitiveType, DataType, Field, Utf8Array},
kernels::utf8::add_utf8_arrays,
prelude::Decimal128Array,
series::Series,
};
// Permission is hereby granted, free of charge, to any person obtaining a copy
Expand Down Expand Up @@ -38,15 +39,16 @@ fn arithmetic_helper<T, Kernel, F>(
operation: F,
) -> DaftResult<DataArray<T>>
where
T: DaftNumericType,
Kernel: Fn(&PrimitiveArray<T::Native>, &PrimitiveArray<T::Native>) -> PrimitiveArray<T::Native>,
T: DaftPrimitiveType,
Kernel:
FnOnce(&PrimitiveArray<T::Native>, &PrimitiveArray<T::Native>) -> PrimitiveArray<T::Native>,
F: Fn(T::Native, T::Native) -> T::Native,
{
match (lhs.len(), rhs.len()) {
(a, b) if a == b => Ok(DataArray::from((
lhs.name(),
(a, b) if a == b => DataArray::new(
lhs.field.clone(),
Box::new(kernel(lhs.as_arrow(), rhs.as_arrow())),
))),
),
// broadcast right path
(_, 1) => {
let opt_rhs = rhs.get(0);
Expand Down Expand Up @@ -79,6 +81,50 @@ where
}
}

/// Addition of two `Decimal128Array`s.
///
/// Both operands must share the same data type; the assertion panics otherwise.
/// The arrow2 decimal `add` kernel and a scalar `l + r` closure are both handed to
/// `arithmetic_helper`, which decides which one to use based on the operand lengths.
impl Add for &Decimal128Array {
    type Output = DaftResult<Decimal128Array>;

    fn add(self, rhs: Self) -> Self::Output {
        assert_eq!(self.data_type(), rhs.data_type());

        let kernel = arrow2::compute::arithmetics::decimal::add;
        let scalar_add = |l: i128, r: i128| l + r;
        arithmetic_helper(self, rhs, kernel, scalar_add)
    }
}

/// Subtraction of two `Decimal128Array`s.
///
/// Both operands must share the same data type; the assertion panics otherwise.
/// The arrow2 decimal `sub` kernel and a scalar `l - r` closure are both handed to
/// `arithmetic_helper`, which decides which one to use based on the operand lengths.
impl Sub for &Decimal128Array {
    type Output = DaftResult<Decimal128Array>;

    fn sub(self, rhs: Self) -> Self::Output {
        assert_eq!(self.data_type(), rhs.data_type());

        let kernel = arrow2::compute::arithmetics::decimal::sub;
        let scalar_sub = |l: i128, r: i128| l - r;
        arithmetic_helper(self, rhs, kernel, scalar_sub)
    }
}

/// Multiplication of two `Decimal128Array`s.
///
/// Both operands must share the same data type; the assertion panics otherwise.
/// The raw product of two values stored at scale `s` carries scale `2 * s`, so the
/// scalar closure divides by `10^s` to bring the result back to scale `s`.
impl Mul for &Decimal128Array {
    type Output = DaftResult<Decimal128Array>;

    fn mul(self, rhs: Self) -> Self::Output {
        assert_eq!(self.data_type(), rhs.data_type());

        // 10^scale, used to rescale the raw integer product.
        let scale = match self.data_type() {
            DataType::Decimal128(_, s) => 10i128.pow(*s as u32),
            _ => unreachable!("This should always be a Decimal128"),
        };

        let kernel = arrow2::compute::arithmetics::decimal::mul;
        let scalar_mul = |l: i128, r: i128| (l * r) / scale;
        arithmetic_helper(self, rhs, kernel, scalar_mul)
    }
}

impl Add for &Utf8Array {
type Output = DaftResult<Utf8Array>;
fn add(self, rhs: Self) -> Self::Output {
Expand Down Expand Up @@ -236,6 +282,69 @@ where
}
}

/// Division of two `Decimal128Array`s.
///
/// Both operands must share the same data type; the assertion panics otherwise.
/// The numerator is pre-multiplied by `10^s` (with `s` the decimal scale) so that the
/// integer quotient keeps scale `s`.
///
/// When `rhs` contains no nulls the work is delegated to `arithmetic_helper` with the
/// arrow2 decimal `div` kernel. Otherwise a null-aware path is used that only divides
/// slots where both sides are valid.
/// NOTE(review): presumably this avoids dividing by whatever value sits under a null
/// slot (possibly zero) — confirm.
///
/// # Errors
///
/// Returns `DaftError::ValueError` when the lengths differ and neither side has
/// length 1.
///
/// # Panics
///
/// In the null-aware path a valid zero divisor panics (integer division by zero).
impl Div for &Decimal128Array {
    type Output = DaftResult<Decimal128Array>;

    fn div(self, rhs: Self) -> Self::Output {
        assert_eq!(self.data_type(), rhs.data_type());
        let DataType::Decimal128(_, s) = self.data_type() else {
            unreachable!("This should always be a Decimal128")
        };
        let scale = 10i128.pow(*s as u32);

        if rhs.data().null_count() == 0 {
            return arithmetic_helper(
                self,
                rhs,
                arrow2::compute::arithmetics::decimal::div,
                |l, r| ((l * scale) / r),
            );
        }

        match (self.len(), rhs.len()) {
            (a, b) if a == b => {
                // `Option::zip` yields `Some` only when both slots are valid.
                let values = self
                    .as_arrow()
                    .iter()
                    .zip(rhs.as_arrow().iter())
                    .map(|(l, r)| l.zip(r).map(|(l, r)| (l * scale) / r));
                Ok(Decimal128Array::from_iter(self.field.clone(), values))
            }
            // broadcast right path
            (_, 1) => match rhs.get(0) {
                None => Ok(DataArray::full_null(
                    self.name(),
                    self.data_type(),
                    self.len(),
                )),
                Some(rhs) => self.apply(|lhs| ((lhs * scale) / rhs)),
            },
            // broadcast left path
            (1, _) => Ok(match self.get(0) {
                // Name the all-null result after `self`, like every other branch,
                // so the output name does not depend on whether the scalar is null.
                None => DataArray::full_null(self.name(), self.data_type(), rhs.len()),
                Some(lhs) => {
                    let values_iter = rhs
                        .as_arrow()
                        .iter()
                        .map(|v| v.map(|v| ((lhs * scale) / *v)));
                    Decimal128Array::from_iter(self.field.clone(), values_iter)
                }
            }),
            (a, b) => Err(DaftError::ValueError(format!(
                "Cannot apply operation on arrays of different lengths: {a} vs {b}"
            ))),
        }
    }
}

fn fixed_sized_list_arithmetic_helper<Kernel>(
lhs: &FixedSizeListArray,
rhs: &FixedSizeListArray,
Expand Down
9 changes: 4 additions & 5 deletions src/daft-core/src/array/ops/as_arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use crate::datatypes::PythonArray;
use crate::{
array::DataArray,
datatypes::{
logical::{DateArray, Decimal128Array, DurationArray, TimeArray, TimestampArray},
BinaryArray, BooleanArray, DaftNumericType, FixedSizeBinaryArray, IntervalArray, NullArray,
Utf8Array,
logical::{DateArray, DurationArray, TimeArray, TimestampArray},
BinaryArray, BooleanArray, DaftPrimitiveType, FixedSizeBinaryArray, IntervalArray,
NullArray, Utf8Array,
},
};

Expand All @@ -24,7 +24,7 @@ pub trait AsArrow {

impl<T> AsArrow for DataArray<T>
where
T: DaftNumericType,
T: DaftPrimitiveType,
{
type Output = array::PrimitiveArray<T::Native>;

Expand Down Expand Up @@ -66,7 +66,6 @@ impl_asarrow_dataarray!(IntervalArray, array::PrimitiveArray<months_days_ns>);
#[cfg(feature = "python")]
impl_asarrow_dataarray!(PythonArray, PseudoArrowArray<pyo3::PyObject>);

impl_asarrow_logicalarray!(Decimal128Array, array::PrimitiveArray<i128>);
impl_asarrow_logicalarray!(DateArray, array::PrimitiveArray<i32>);
impl_asarrow_logicalarray!(TimeArray, array::PrimitiveArray<i64>);
impl_asarrow_logicalarray!(DurationArray, array::PrimitiveArray<i64>);
Expand Down
Loading

0 comments on commit c78fef4

Please sign in to comment.