From a8876f0b522e96b680fbc1244be82516d3586d5d Mon Sep 17 00:00:00 2001 From: Conor Kennedy <32619800+Vince7778@users.noreply.github.com> Date: Wed, 19 Jun 2024 15:51:53 -0700 Subject: [PATCH] [FEAT] Fixed Size Binary Type v2 (#2403) Supersedes #2266. Additionally implements `binary` to `fixed_size_binary` cast if all values have the correct length. Original description: > Implements a fixed size binary type and array. > > Todo: > > * Optimize the kernels to leverage fixed size lengths --------- Co-authored-by: Colin Ho --- daft/daft.pyi | 2 + daft/datatype.py | 15 +- src/arrow2/src/array/fixed_size_binary/mod.rs | 3 +- src/arrow2/src/compute/cast/binary_to.rs | 62 +++- src/arrow2/src/compute/cast/mod.rs | 7 +- src/daft-core/src/array/from.rs | 15 +- src/daft-core/src/array/from_iter.rs | 19 +- .../src/array/growable/arrow_growable.rs | 9 +- src/daft-core/src/array/growable/mod.rs | 10 +- src/daft-core/src/array/ops/as_arrow.rs | 3 +- src/daft-core/src/array/ops/cast.rs | 3 + src/daft-core/src/array/ops/compare_agg.rs | 107 +++++++ src/daft-core/src/array/ops/comparison.rs | 293 +++++++++++++++++- src/daft-core/src/array/ops/get.rs | 5 +- src/daft-core/src/array/ops/groups.rs | 15 +- src/daft-core/src/array/ops/hash.rs | 32 +- src/daft-core/src/array/ops/is_in.rs | 5 +- src/daft-core/src/array/ops/repr.rs | 18 +- src/daft-core/src/array/ops/sort.rs | 28 +- src/daft-core/src/array/ops/take.rs | 36 ++- src/daft-core/src/array/serdes.rs | 14 +- src/daft-core/src/datatypes/binary_ops.rs | 8 +- src/daft-core/src/datatypes/dtype.rs | 9 +- src/daft-core/src/datatypes/matching.rs | 4 + src/daft-core/src/datatypes/mod.rs | 2 + src/daft-core/src/kernels/hashing.rs | 23 +- src/daft-core/src/kernels/search_sorted.rs | 72 ++++- src/daft-core/src/python/datatype.rs | 11 + .../src/series/array_impl/binary_ops.rs | 3 +- .../src/series/array_impl/data_array.rs | 2 + src/daft-core/src/series/ops/downcast.rs | 4 + src/daft-core/src/series/ops/hash.rs | 1 + src/daft-core/src/series/serdes.rs | 6 + src/daft-core/src/utils/arrow.rs | 4 +- src/daft-core/src/utils/supertype.rs | 2 +- src/daft-stats/src/column_stats/mod.rs | 2 +- tests/expressions/typing/conftest.py | 6 +- tests/series/test_cast.py | 18 ++ tests/series/test_comparisons.py | 79 +++++ tests/series/test_filter.py | 5 +- tests/series/test_hash.py | 31 ++ tests/series/test_if_else.py | 23 +- tests/series/test_series.py | 15 +- tests/series/test_size_bytes.py | 18 ++ tests/series/test_take.py | 5 +- tests/table/test_from_py.py | 7 + tests/table/test_sorting.py | 6 +- tests/table/test_table_aggs.py | 82 +++-- tests/test_datatypes.py | 6 +- 49 files changed, 1053 insertions(+), 102 deletions(-) diff --git a/daft/daft.pyi b/daft/daft.pyi index 50913cd567..e396e6f7c4 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -928,6 +928,8 @@ class PyDataType: @staticmethod def binary() -> PyDataType: ... @staticmethod + def fixed_size_binary(size: int) -> PyDataType: ... + @staticmethod def string() -> PyDataType: ... @staticmethod def decimal128(precision: int, size: int) -> PyDataType: ... diff --git a/daft/datatype.py b/daft/datatype.py index 4310b23f08..4077a7b0e3 100644 --- a/daft/datatype.py +++ b/daft/datatype.py @@ -178,6 +178,13 @@ def binary(cls) -> DataType: """Create a Binary DataType: A string of bytes""" return cls._from_pydatatype(PyDataType.binary()) + @classmethod + def fixed_size_binary(cls, size: int) -> DataType: + """Create a FixedSizeBinary DataType: A fixed-size string of bytes""" + if not isinstance(size, int) or size <= 0: + raise ValueError("The size for a fixed-size binary must be a positive integer, but got: ", size) + return cls._from_pydatatype(PyDataType.fixed_size_binary(size)) + @classmethod def null(cls) -> DataType: """Creates the Null DataType: Always the ``Null`` value""" @@ -364,12 +371,10 @@ def from_arrow_type(cls, arrow_type: pa.lib.DataType) -> DataType: return cls.float64() elif pa.types.is_string(arrow_type) or pa.types.is_large_string(arrow_type): return cls.string() - elif ( - pa.types.is_binary(arrow_type) - or pa.types.is_large_binary(arrow_type) - or pa.types.is_fixed_size_binary(arrow_type) - ): + elif pa.types.is_binary(arrow_type) or pa.types.is_large_binary(arrow_type): return cls.binary() + elif pa.types.is_fixed_size_binary(arrow_type): + return cls.fixed_size_binary(arrow_type.byte_width) elif pa.types.is_boolean(arrow_type): return cls.bool() elif pa.types.is_null(arrow_type): diff --git a/src/arrow2/src/array/fixed_size_binary/mod.rs b/src/arrow2/src/array/fixed_size_binary/mod.rs index 973f0f031b..88f0e2975a 100644 --- a/src/arrow2/src/array/fixed_size_binary/mod.rs +++ b/src/arrow2/src/array/fixed_size_binary/mod.rs @@ -47,7 +47,8 @@ impl FixedSizeBinaryArray { .map_or(false, |validity| validity.len() != len) { return Err(Error::oos( - "validity mask length must be equal to the number of values divided by size", + format!("validity mask length (got {}) must be equal to the number of values ({}) divided by size ({})", + validity.unwrap().len(), values.len(), size), )); } diff --git a/src/arrow2/src/compute/cast/binary_to.rs b/src/arrow2/src/compute/cast/binary_to.rs index 9a74c1f521..7faaf3ec8f 100644 --- a/src/arrow2/src/compute/cast/binary_to.rs +++ b/src/arrow2/src/compute/cast/binary_to.rs @@ -1,4 +1,4 @@ -use crate::error::Result; +use crate::error::{Error, Result}; use crate::offset::{Offset, Offsets}; use crate::{array::*, datatypes::DataType, types::NativeType}; @@ -155,6 +155,66 @@ pub fn fixed_size_binary_binary( ) } +pub fn binary_to_fixed_size_binary( + from: &BinaryArray, + size: usize, +) -> Result> { + if let Some(validity) = from.validity() { + // Ensure all valid elements have the right size + for (value, valid) in from.values_iter().zip(validity) { + if valid && value.len() != size { + return Err(Error::InvalidArgumentError( + format!( + "element has invalid length ({}, expected {})", + value.len(), + size + ) + .to_string(), + )); + } + } + + // Copy values to new buffer, accounting for validity + let mut values: Vec = Vec::new(); + let offsets = from.offsets().buffer().iter(); + let from_values = from.values(); + for (off, valid) in offsets.zip(validity) { + if valid { + let start = off.to_usize(); + let end = start + size; + values.extend(&from_values[start..end]); + } else { + values.extend(std::iter::repeat(0u8).take(size)); + } + } + Ok(Box::new(FixedSizeBinaryArray::try_new( + DataType::FixedSizeBinary(size), + values.into(), + from.validity().cloned(), + )?)) + } else { + // Ensure all elements have the right size + for value in from.values_iter() { + if value.len() != size { + return Err(Error::InvalidArgumentError( + format!( + "element has invalid length ({}, expected {})", + value.len(), + size + ) + .to_string(), + )); + } + } + + Ok(Box::new(FixedSizeBinaryArray::try_new( + DataType::FixedSizeBinary(size), + from.values().clone(), + from.validity().cloned(), + )?)) + } +} + /// Conversion of binary pub fn binary_to_list(from: &BinaryArray, to_data_type: DataType) -> ListArray { let values = from.values().clone(); diff --git a/src/arrow2/src/compute/cast/mod.rs b/src/arrow2/src/compute/cast/mod.rs index 96002ceb70..c5a389b32d 100644 --- a/src/arrow2/src/compute/cast/mod.rs +++ b/src/arrow2/src/compute/cast/mod.rs @@ -156,12 +156,13 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { } (Binary, to_type) => { - is_numeric(to_type) || matches!(to_type, LargeBinary | Utf8 | LargeUtf8) + is_numeric(to_type) + || matches!(to_type, LargeBinary | FixedSizeBinary(_) | Utf8 | LargeUtf8) } (LargeBinary, to_type) => { is_numeric(to_type) || match to_type { - Binary | LargeUtf8 => true, + Binary | FixedSizeBinary(_) | LargeUtf8 => true, LargeList(field) => matches!(field.data_type, UInt8), _ => false, } @@ -772,6 +773,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu Int64 => binary_to_primitive_dyn::(array, to_type, options), Float32 => binary_to_primitive_dyn::(array, to_type, options), Float64 => binary_to_primitive_dyn::(array, to_type, options), + FixedSizeBinary(size) => binary_to_fixed_size_binary::(array.as_any().downcast_ref().unwrap(), *size), LargeBinary => Ok(Box::new(binary_to_large_binary( array.as_any().downcast_ref().unwrap(), to_type.clone(), @@ -800,6 +802,7 @@ pub fn cast(array: &dyn Array, to_type: &DataType, options: CastOptions) -> Resu binary_large_to_binary(array.as_any().downcast_ref().unwrap(), to_type.clone()) .map(|x| x.boxed()) } + FixedSizeBinary(size) => binary_to_fixed_size_binary::(array.as_any().downcast_ref().unwrap(), *size), LargeUtf8 => { binary_to_utf8::(array.as_any().downcast_ref().unwrap(), to_type.clone()) .map(|x| x.boxed()) diff --git a/src/daft-core/src/array/from.rs b/src/daft-core/src/array/from.rs index e3ac0b5d5d..a4ddc49166 100644 --- a/src/daft-core/src/array/from.rs +++ b/src/daft-core/src/array/from.rs @@ -1,8 +1,8 @@ use std::sync::Arc; use crate::datatypes::{ - BinaryArray, BooleanArray, DaftNumericType, DaftPhysicalType, DataType, Field, NullArray, - Utf8Array, Utf8Type, + BinaryArray, BooleanArray, DaftNumericType, DaftPhysicalType, DataType, Field, + FixedSizeBinaryArray, NullArray, Utf8Array, Utf8Type, }; use crate::array::DataArray; @@ -38,6 +38,17 @@ impl From<(&str, Box>)> for BinaryArray { } } +impl From<(&str, Box)> for FixedSizeBinaryArray { + fn from(item: (&str, Box)) -> Self { + let (name, array) = item; + DataArray::new( + Field::new(name, DataType::FixedSizeBinary(array.size())).into(), + array, + ) + .unwrap() + } +} + impl From<(&str, &[T::Native])> for DataArray where T: DaftNumericType, diff --git a/src/daft-core/src/array/from_iter.rs b/src/daft-core/src/array/from_iter.rs index 19c5e292fd..2c300dc48b 100644 --- a/src/daft-core/src/array/from_iter.rs +++ b/src/daft-core/src/array/from_iter.rs @@ -1,4 +1,6 @@ -use crate::datatypes::{BinaryArray, BooleanArray, DaftNumericType, Field, Utf8Array}; +use crate::datatypes::{ + BinaryArray, BooleanArray, DaftNumericType, Field, FixedSizeBinaryArray, Utf8Array, +}; use super::DataArray; @@ -42,6 +44,21 @@ impl BinaryArray { } } +impl FixedSizeBinaryArray { + pub fn from_iter>( + name: &str, + iter: impl arrow2::trusted_len::TrustedLen>, + size: usize, + ) -> Self { + let arrow_array = Box::new(arrow2::array::FixedSizeBinaryArray::from_iter(iter, size)); + DataArray::new( + Field::new(name, crate::DataType::FixedSizeBinary(size)).into(), + arrow_array, + ) + .unwrap() + } +} + impl BooleanArray { pub fn from_iter( name: &str, diff --git a/src/daft-core/src/array/growable/arrow_growable.rs b/src/daft-core/src/array/growable/arrow_growable.rs index a964960058..cac34a81a3 100644 --- a/src/daft-core/src/array/growable/arrow_growable.rs +++ b/src/daft-core/src/array/growable/arrow_growable.rs @@ -9,8 +9,8 @@ use crate::{ }, datatypes::{ BinaryType, BooleanType, DaftArrowBackedType, DaftDataType, ExtensionArray, Field, - Float32Type, Float64Type, Int128Type, Int16Type, Int32Type, Int64Type, Int8Type, NullType, - UInt16Type, UInt32Type, UInt64Type, UInt8Type, Utf8Type, + FixedSizeBinaryType, Float32Type, Float64Type, Int128Type, Int16Type, Int32Type, Int64Type, + Int8Type, NullType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, Utf8Type, }, DataType, IntoSeries, Series, }; @@ -159,6 +159,11 @@ impl_arrow_backed_data_array_growable!( BinaryType, arrow2::array::growable::GrowableBinary<'a, i64> ); +impl_arrow_backed_data_array_growable!( + ArrowFixedSizeBinaryGrowable, + FixedSizeBinaryType, + arrow2::array::growable::GrowableFixedSizeBinary<'a> +); impl_arrow_backed_data_array_growable!( ArrowUtf8Growable, Utf8Type, diff --git a/src/daft-core/src/array/growable/mod.rs b/src/daft-core/src/array/growable/mod.rs index 9c7a93d515..812f9d4f8c 100644 --- a/src/daft-core/src/array/growable/mod.rs +++ b/src/daft-core/src/array/growable/mod.rs @@ -7,9 +7,9 @@ use crate::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, FixedShapeTensorArray, ImageArray, MapArray, TensorArray, TimeArray, TimestampArray, }, - BinaryArray, BooleanArray, ExtensionArray, Float32Array, Float64Array, Int128Array, - Int16Array, Int32Array, Int64Array, Int8Array, NullArray, UInt16Array, UInt32Array, - UInt64Array, UInt8Array, Utf8Array, + BinaryArray, BooleanArray, ExtensionArray, FixedSizeBinaryArray, Float32Array, + Float64Array, Int128Array, Int16Array, Int32Array, Int64Array, Int8Array, NullArray, + UInt16Array, UInt32Array, UInt64Array, UInt8Array, Utf8Array, }, with_match_daft_types, DataType, Series, }; @@ -179,6 +179,10 @@ impl_growable_array!(UInt64Array, arrow_growable::ArrowUInt64Growable<'a>); impl_growable_array!(Float32Array, arrow_growable::ArrowFloat32Growable<'a>); impl_growable_array!(Float64Array, arrow_growable::ArrowFloat64Growable<'a>); impl_growable_array!(BinaryArray, arrow_growable::ArrowBinaryGrowable<'a>); +impl_growable_array!( + FixedSizeBinaryArray, + arrow_growable::ArrowFixedSizeBinaryGrowable<'a> +); impl_growable_array!(Utf8Array, arrow_growable::ArrowUtf8Growable<'a>); impl_growable_array!(ExtensionArray, arrow_growable::ArrowExtensionGrowable<'a>); impl_growable_array!( diff --git a/src/daft-core/src/array/ops/as_arrow.rs b/src/daft-core/src/array/ops/as_arrow.rs index fe7c216f5b..07302c8161 100644 --- a/src/daft-core/src/array/ops/as_arrow.rs +++ b/src/daft-core/src/array/ops/as_arrow.rs @@ -4,7 +4,7 @@ use crate::{ array::DataArray, datatypes::{ logical::{DateArray, Decimal128Array, DurationArray, TimeArray, TimestampArray}, - BinaryArray, BooleanArray, DaftNumericType, NullArray, Utf8Array, + BinaryArray, BooleanArray, DaftNumericType, FixedSizeBinaryArray, NullArray, Utf8Array, }, }; @@ -58,6 +58,7 @@ impl_asarrow_dataarray!(NullArray, array::NullArray); impl_asarrow_dataarray!(Utf8Array, array::Utf8Array); impl_asarrow_dataarray!(BooleanArray, array::BooleanArray); impl_asarrow_dataarray!(BinaryArray, array::BinaryArray); +impl_asarrow_dataarray!(FixedSizeBinaryArray, array::FixedSizeBinaryArray); #[cfg(feature = "python")] impl_asarrow_dataarray!(PythonArray, PseudoArrowArray); diff --git a/src/daft-core/src/array/ops/cast.rs b/src/daft-core/src/array/ops/cast.rs index b92e521238..8e0919d39a 100644 --- a/src/daft-core/src/array/ops/cast.rs +++ b/src/daft-core/src/array/ops/cast.rs @@ -1028,6 +1028,9 @@ impl PythonArray { } DataType::Boolean => pycast_then_arrowcast!(self, DataType::Boolean, "bool"), DataType::Binary => pycast_then_arrowcast!(self, DataType::Binary, "bytes"), + DataType::FixedSizeBinary(size) => { + pycast_then_arrowcast!(self, DataType::FixedSizeBinary(*size), "fixed_size_bytes") + } DataType::Utf8 => pycast_then_arrowcast!(self, DataType::Utf8, "str"), dt @ DataType::UInt8 | dt @ DataType::UInt16 diff --git a/src/daft-core/src/array/ops/compare_agg.rs b/src/daft-core/src/array/ops/compare_agg.rs index 287260f8c1..f358e91b7d 100644 --- a/src/daft-core/src/array/ops/compare_agg.rs +++ b/src/daft-core/src/array/ops/compare_agg.rs @@ -266,6 +266,113 @@ impl DaftCompareAggable for DataArray { } } +fn cmp_fixed_size_binary<'a, F>( + data_array: &'a FixedSizeBinaryArray, + op: F, +) -> DaftResult +where + F: Fn(&'a [u8], &'a [u8]) -> &'a [u8], +{ + let arrow_array = data_array.as_arrow(); + if arrow_array.null_count() == arrow_array.len() { + Ok(FixedSizeBinaryArray::full_null( + data_array.name(), + &DataType::FixedSizeBinary(arrow_array.size()), + 1, + )) + } else if arrow_array.validity().is_some() { + let res = arrow_array + .iter() + .reduce(|v1, v2| match (v1, v2) { + (None, v2) => v2, + (v1, None) => v1, + (Some(v1), Some(v2)) => Some(op(v1, v2)), + }) + .unwrap_or(None); + Ok(FixedSizeBinaryArray::from_iter( + data_array.name(), + std::iter::once(res), + arrow_array.size(), + )) + } else { + let res = arrow_array.values_iter().reduce(|v1, v2| op(v1, v2)); + Ok(FixedSizeBinaryArray::from_iter( + data_array.name(), + std::iter::once(res), + arrow_array.size(), + )) + } +} + +fn grouped_cmp_fixed_size_binary<'a, F>( + data_array: &'a FixedSizeBinaryArray, + op: F, + groups: &GroupIndices, +) -> DaftResult +where + F: Fn(&'a [u8], &'a [u8]) -> &'a [u8], +{ + let arrow_array = data_array.as_arrow(); + let cmp_per_group = if arrow_array.null_count() > 0 { + let cmp_values_iter = groups.iter().map(|g| { + let reduced_val = g + .iter() + .map(|i| { + let idx = *i as usize; + match arrow_array.is_null(idx) { + false => Some(unsafe { arrow_array.value_unchecked(idx) }), + true => None, + } + }) + .reduce(|l, r| match (l, r) { + (None, None) => None, + (None, Some(r)) => Some(r), + (Some(l), None) => Some(l), + (Some(l), Some(r)) => Some(op(l, r)), + }); + reduced_val.unwrap_or_default() + }); + Box::new(arrow2::array::FixedSizeBinaryArray::from_iter( + cmp_values_iter, + arrow_array.size(), + )) + } else { + Box::new(arrow2::array::FixedSizeBinaryArray::from_iter( + groups.iter().map(|g| { + g.iter() + .map(|i| { + let idx = *i as usize; + unsafe { arrow_array.value_unchecked(idx) } + }) + .reduce(|l, r| op(l, r)) + }), + arrow_array.size(), + )) + }; + Ok(DataArray::from(( + data_array.field.name.as_ref(), + cmp_per_group, + ))) +} + +impl DaftCompareAggable for DataArray { + type Output = DaftResult>; + fn min(&self) -> Self::Output { + cmp_fixed_size_binary(self, |l, r| l.min(r)) + } + fn max(&self) -> Self::Output { + cmp_fixed_size_binary(self, |l, r| l.max(r)) + } + + fn grouped_min(&self, groups: &GroupIndices) -> Self::Output { + grouped_cmp_fixed_size_binary(self, |l, r| l.min(r), groups) + } + + fn grouped_max(&self, groups: &GroupIndices) -> Self::Output { + grouped_cmp_fixed_size_binary(self, |l, r| l.max(r), groups) + } +} + fn grouped_cmp_bool( data_array: &BooleanArray, val_to_find: bool, diff --git a/src/daft-core/src/array/ops/comparison.rs b/src/daft-core/src/array/ops/comparison.rs index 2d45551fad..711975a341 100644 --- a/src/daft-core/src/array/ops/comparison.rs +++ b/src/daft-core/src/array/ops/comparison.rs @@ -3,8 +3,8 @@ use num_traits::{NumCast, ToPrimitive}; use crate::{ array::DataArray, datatypes::{ - BinaryArray, BooleanArray, DaftArrowBackedType, DaftNumericType, DataType, NullArray, - Utf8Array, + BinaryArray, BooleanArray, DaftArrowBackedType, DaftNumericType, DataType, Field, + FixedSizeBinaryArray, NullArray, Utf8Array, }, utils::arrow::arrow_bitmap_and_helper, }; @@ -13,7 +13,7 @@ use common_error::{DaftError, DaftResult}; use std::ops::Not; -use super::{full::FullNull, DaftCompare, DaftLogical}; +use super::{from_arrow::FromArrow, full::FullNull, DaftCompare, DaftLogical}; use super::as_arrow::AsArrow; use arrow2::{compute::comparison, scalar::PrimitiveScalar}; @@ -1462,6 +1462,293 @@ impl DaftCompare<&[u8]> for BinaryArray { } } +fn compare_fixed_size_binary( + lhs: &FixedSizeBinaryArray, + rhs: &FixedSizeBinaryArray, + op: F, +) -> DaftResult +where + F: Fn(&[u8], &[u8]) -> bool, +{ + let lhs_arrow = lhs.as_arrow(); + let rhs_arrow = rhs.as_arrow(); + let validity = match (lhs_arrow.validity(), rhs_arrow.validity()) { + (Some(lhs), None) => Some(lhs.clone()), + (None, Some(rhs)) => Some(rhs.clone()), + (None, None) => None, + (Some(lhs), Some(rhs)) => Some(lhs & rhs), + }; + + let values = lhs_arrow + .values_iter() + .zip(rhs_arrow.values_iter()) + .map(|(lhs, rhs)| op(lhs, rhs)); + let values = arrow2::bitmap::Bitmap::from_trusted_len_iter(values); + + BooleanArray::from_arrow( + Field::new(lhs.name(), DataType::Boolean).into(), + Box::new(arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + values, + validity, + )), + ) +} + +fn cmp_fixed_size_binary_scalar( + lhs: &FixedSizeBinaryArray, + rhs: &[u8], + op: F, +) -> DaftResult +where + F: Fn(&[u8], &[u8]) -> bool, +{ + let lhs_arrow = lhs.as_arrow(); + let validity = lhs_arrow.validity().cloned(); + + let values = lhs_arrow.values_iter().map(|lhs| op(lhs, rhs)); + let values = arrow2::bitmap::Bitmap::from_trusted_len_iter(values); + + BooleanArray::from_arrow( + Field::new(lhs.name(), DataType::Boolean).into(), + Box::new(arrow2::array::BooleanArray::new( + arrow2::datatypes::DataType::Boolean, + values, + validity, + )), + ) +} + +impl DaftCompare<&FixedSizeBinaryArray> for FixedSizeBinaryArray { + type Output = DaftResult; + + fn equal(&self, rhs: &FixedSizeBinaryArray) -> Self::Output { + match (self.len(), rhs.len()) { + (x, y) if x == y => compare_fixed_size_binary(self, rhs, |lhs, rhs| lhs == rhs), + (l_size, 1) => { + if let Some(value) = rhs.get(0) { + self.equal(value) + } else { + Ok(BooleanArray::full_null( + self.name(), + &DataType::Boolean, + l_size, + )) + } + } + (1, r_size) => { + if let Some(value) = self.get(0) { + rhs.equal(value).map(|v| v.rename(self.name())) + } else { + Ok(BooleanArray::full_null( + self.name(), + &DataType::Boolean, + r_size, + )) + } + } + (l, r) => Err(DaftError::ValueError(format!( + "trying to compare different length arrays: {}: {l} vs {}: {r}", + self.name(), + rhs.name() + ))), + } + } + + fn not_equal(&self, rhs: &FixedSizeBinaryArray) -> Self::Output { + match (self.len(), rhs.len()) { + (x, y) if x == y => compare_fixed_size_binary(self, rhs, |lhs, rhs| lhs != rhs), + (l_size, 1) => { + if let Some(value) = rhs.get(0) { + self.not_equal(value) + } else { + Ok(BooleanArray::full_null( + self.name(), + &DataType::Boolean, + l_size, + )) + } + } + (1, r_size) => { + if let Some(value) = self.get(0) { + rhs.not_equal(value) + } else { + Ok(BooleanArray::full_null( + self.name(), + &DataType::Boolean, + r_size, + )) + } + } + (l, r) => Err(DaftError::ValueError(format!( + "trying to compare different length arrays: {}: {l} vs {}: {r}", + self.name(), + rhs.name() + ))), + } + } + + fn lt(&self, rhs: &FixedSizeBinaryArray) -> Self::Output { + match (self.len(), rhs.len()) { + (x, y) if x == y => compare_fixed_size_binary(self, rhs, |lhs, rhs| lhs < rhs), + (l_size, 1) => { + if let Some(value) = rhs.get(0) { + self.lt(value) + } else { + Ok(BooleanArray::full_null( + self.name(), + &DataType::Boolean, + l_size, + )) + } + } + (1, r_size) => { + if let Some(value) = self.get(0) { + rhs.gt(value) + } else { + Ok(BooleanArray::full_null( + self.name(), + &DataType::Boolean, + r_size, + )) + } + } + (l, r) => Err(DaftError::ValueError(format!( + "trying to compare different length arrays: {}: {l} vs {}: {r}", + self.name(), + rhs.name() + ))), + } + } + + fn lte(&self, rhs: &FixedSizeBinaryArray) -> Self::Output { + match (self.len(), rhs.len()) { + (x, y) if x == y => compare_fixed_size_binary(self, rhs, |lhs, rhs| lhs <= rhs), + (l_size, 1) => { + if let Some(value) = rhs.get(0) { + self.lte(value) + } else { + Ok(BooleanArray::full_null( + self.name(), + &DataType::Boolean, + l_size, + )) + } + } + (1, r_size) => { + if let Some(value) = self.get(0) { + rhs.gte(value) + } else { + Ok(BooleanArray::full_null( + self.name(), + &DataType::Boolean, + r_size, + )) + } + } + (l, r) => Err(DaftError::ValueError(format!( + "trying to compare different length arrays: {}: {l} vs {}: {r}", + self.name(), + rhs.name() + ))), + } + } + + fn gt(&self, rhs: &FixedSizeBinaryArray) -> Self::Output { + match (self.len(), rhs.len()) { + (x, y) if x == y => compare_fixed_size_binary(self, rhs, |lhs, rhs| lhs > rhs), + (l_size, 1) => { + if let Some(value) = rhs.get(0) { + self.gt(value) + } else { + Ok(BooleanArray::full_null( + self.name(), + &DataType::Boolean, + l_size, + )) + } + } + (1, r_size) => { + if let Some(value) = self.get(0) { + rhs.lt(value) + } else { + Ok(BooleanArray::full_null( + self.name(), + &DataType::Boolean, + r_size, + )) + } + } + (l, r) => Err(DaftError::ValueError(format!( + "trying to compare different length arrays: {}: {l} vs {}: {r}", + self.name(), + rhs.name() + ))), + } + } + + fn gte(&self, rhs: &FixedSizeBinaryArray) -> Self::Output { + match (self.len(), rhs.len()) { + (x, y) if x == y => compare_fixed_size_binary(self, rhs, |lhs, rhs| lhs >= rhs), + (l_size, 1) => { + if let Some(value) = rhs.get(0) { + self.gte(value) + } else { + Ok(BooleanArray::full_null( + self.name(), + &DataType::Boolean, + l_size, + )) + } + } + (1, r_size) => { + if let Some(value) = self.get(0) { + rhs.lte(value) + } else { + Ok(BooleanArray::full_null( + self.name(), + &DataType::Boolean, + r_size, + )) + } + } + (l, r) => Err(DaftError::ValueError(format!( + "trying to compare different length arrays: {}: {l} vs {}: {r}", + self.name(), + rhs.name() + ))), + } + } +} + +impl DaftCompare<&[u8]> for FixedSizeBinaryArray { + type Output = DaftResult; + + fn equal(&self, rhs: &[u8]) -> Self::Output { + cmp_fixed_size_binary_scalar(self, rhs, |lhs, rhs| lhs == rhs) + } + + fn not_equal(&self, rhs: &[u8]) -> Self::Output { + cmp_fixed_size_binary_scalar(self, rhs, |lhs, rhs| lhs != rhs) + } + + fn lt(&self, rhs: &[u8]) -> Self::Output { + cmp_fixed_size_binary_scalar(self, rhs, |lhs, rhs| lhs < rhs) + } + + fn lte(&self, rhs: &[u8]) -> Self::Output { + cmp_fixed_size_binary_scalar(self, rhs, |lhs, rhs| lhs <= rhs) + } + + fn gt(&self, rhs: &[u8]) -> Self::Output { + cmp_fixed_size_binary_scalar(self, rhs, |lhs, rhs| lhs > rhs) + } + + fn gte(&self, rhs: &[u8]) -> Self::Output { + cmp_fixed_size_binary_scalar(self, rhs, |lhs, rhs| lhs >= rhs) + } +} + #[cfg(test)] mod tests { use crate::{array::ops::DaftCompare, datatypes::Int64Array}; diff --git a/src/daft-core/src/array/ops/get.rs b/src/daft-core/src/array/ops/get.rs index 5776d19b31..1868669d0e 100644 --- a/src/daft-core/src/array/ops/get.rs +++ b/src/daft-core/src/array/ops/get.rs @@ -5,8 +5,8 @@ use crate::{ DateArray, Decimal128Array, DurationArray, LogicalArrayImpl, MapArray, TimeArray, TimestampArray, }, - BinaryArray, BooleanArray, DaftLogicalType, DaftNumericType, ExtensionArray, NullArray, - Utf8Array, + BinaryArray, BooleanArray, DaftLogicalType, DaftNumericType, ExtensionArray, + FixedSizeBinaryArray, NullArray, Utf8Array, }, Series, }; @@ -67,6 +67,7 @@ impl LogicalArrayImpl { impl_array_arrow_get!(Utf8Array, &str); impl_array_arrow_get!(BooleanArray, bool); impl_array_arrow_get!(BinaryArray, &[u8]); +impl_array_arrow_get!(FixedSizeBinaryArray, &[u8]); impl_array_arrow_get!(Decimal128Array, i128); impl_array_arrow_get!(DateArray, i32); impl_array_arrow_get!(TimeArray, i64); diff --git a/src/daft-core/src/array/ops/groups.rs b/src/daft-core/src/array/ops/groups.rs index 940f847d43..dfa7adf450 100644 --- a/src/daft-core/src/array/ops/groups.rs +++ b/src/daft-core/src/array/ops/groups.rs @@ -6,8 +6,8 @@ use fnv::FnvHashMap; use crate::{ array::DataArray, datatypes::{ - BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, Float32Array, Float64Array, - NullArray, Utf8Array, + BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, FixedSizeBinaryArray, + Float32Array, Float64Array, NullArray, Utf8Array, }, }; use common_error::DaftResult; @@ -138,6 +138,17 @@ impl IntoGroups for BinaryArray { } } +impl IntoGroups for FixedSizeBinaryArray { + fn make_groups(&self) -> DaftResult { + let array = self.as_arrow(); + if array.null_count() > 0 { + make_groups(array.iter()) + } else { + make_groups(array.values_iter()) + } + } +} + impl IntoGroups for BooleanArray { fn make_groups(&self) -> DaftResult { let array = self.as_arrow(); diff --git a/src/daft-core/src/array/ops/hash.rs b/src/daft-core/src/array/ops/hash.rs index f1a84529f3..9a9833f143 100644 --- a/src/daft-core/src/array/ops/hash.rs +++ b/src/daft-core/src/array/ops/hash.rs @@ -2,8 +2,9 @@ use crate::{ array::DataArray, datatypes::{ logical::{DateArray, Decimal128Array, TimeArray, TimestampArray}, - BinaryArray, BooleanArray, DaftNumericType, Int16Array, Int32Array, Int64Array, Int8Array, - NullArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, Utf8Array, + BinaryArray, BooleanArray, DaftNumericType, FixedSizeBinaryArray, Int16Array, Int32Array, + Int64Array, Int8Array, NullArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + Utf8Array, }, kernels, }; @@ -51,6 +52,18 @@ impl BinaryArray { } } +impl FixedSizeBinaryArray { + pub fn hash(&self, seed: Option<&UInt64Array>) -> DaftResult { + let as_arrowed = self.as_arrow(); + + let seed = seed.map(|v| v.as_arrow()); + + let result = kernels::hashing::hash(as_arrowed, seed)?; + + Ok(DataArray::from((self.name(), Box::new(result)))) + } +} + impl BooleanArray { pub fn hash(&self, seed: Option<&UInt64Array>) -> DaftResult { let as_arrowed = self.as_arrow(); @@ -148,6 +161,21 @@ impl BinaryArray { } } +impl FixedSizeBinaryArray { + pub fn murmur3_32(&self) -> DaftResult { + let as_arrowed = self.as_arrow(); + let has_nulls = as_arrowed + .validity() + .map(|v| v.unset_bits() > 0) + .unwrap_or(false); + if has_nulls { + murmur3_32_hash_from_iter_with_nulls(self.name(), as_arrowed.into_iter()) + } else { + murmur3_32_hash_from_iter_no_nulls(self.name(), as_arrowed.values_iter()) + } + } +} + impl DateArray { pub fn murmur3_32(&self) -> DaftResult { self.physical.murmur3_32() diff --git a/src/daft-core/src/array/ops/is_in.rs b/src/daft-core/src/array/ops/is_in.rs index 725fad4ea2..6a5c9bfec9 100644 --- a/src/daft-core/src/array/ops/is_in.rs +++ b/src/daft-core/src/array/ops/is_in.rs @@ -1,8 +1,8 @@ use crate::{ array::DataArray, datatypes::{ - BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, Float32Array, Float64Array, - NullArray, Utf8Array, + BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, FixedSizeBinaryArray, + Float32Array, Float64Array, NullArray, Utf8Array, }, DataType, }; @@ -78,6 +78,7 @@ macro_rules! impl_is_in_non_numeric_array { impl_is_in_non_numeric_array!(BooleanArray); impl_is_in_non_numeric_array!(Utf8Array); impl_is_in_non_numeric_array!(BinaryArray); +impl_is_in_non_numeric_array!(FixedSizeBinaryArray); impl DaftIsIn<&NullArray> for NullArray { type Output = DaftResult; diff --git a/src/daft-core/src/array/ops/repr.rs b/src/daft-core/src/array/ops/repr.rs index 4c447ad3a5..b2d6643e65 100644 --- a/src/daft-core/src/array/ops/repr.rs +++ b/src/daft-core/src/array/ops/repr.rs @@ -7,8 +7,8 @@ use crate::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, FixedShapeTensorArray, ImageArray, MapArray, TensorArray, TimeArray, TimestampArray, }, - BinaryArray, BooleanArray, DaftNumericType, ExtensionArray, ImageFormat, NullArray, - UInt64Array, Utf8Array, + BinaryArray, BooleanArray, DaftNumericType, ExtensionArray, FixedSizeBinaryArray, + ImageFormat, NullArray, UInt64Array, Utf8Array, }, utils::display_table::{display_date32, display_decimal128, display_time64, display_timestamp}, with_match_daft_types, DataType, Series, @@ -123,6 +123,19 @@ impl BinaryArray { } } } + +impl FixedSizeBinaryArray { + pub fn str_value(&self, idx: usize) -> DaftResult { + let val = self.get(idx); + match val { + None => Ok("None".to_string()), + Some(v) => { + const LEN_TO_TRUNC: usize = 40; + pretty_print_bytes(v, LEN_TO_TRUNC) + } + } + } +} #[cfg(feature = "python")] impl crate::datatypes::PythonArray { pub fn str_value(&self, idx: usize) -> DaftResult { @@ -346,6 +359,7 @@ impl_array_html_value!(Utf8Array); impl_array_html_value!(BooleanArray); impl_array_html_value!(NullArray); impl_array_html_value!(BinaryArray); +impl_array_html_value!(FixedSizeBinaryArray); impl_array_html_value!(ListArray); impl_array_html_value!(FixedSizeListArray); impl_array_html_value!(MapArray); diff --git a/src/daft-core/src/array/ops/sort.rs b/src/daft-core/src/array/ops/sort.rs index ae3bcf012e..1c85e48daf 100644 --- a/src/daft-core/src/array/ops/sort.rs +++ b/src/daft-core/src/array/ops/sort.rs @@ -5,8 +5,8 @@ use crate::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, FixedShapeTensorArray, ImageArray, MapArray, TensorArray, TimeArray, TimestampArray, }, - BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, ExtensionArray, Float32Array, - Float64Array, NullArray, Utf8Array, + BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, ExtensionArray, + FixedSizeBinaryArray, Float32Array, Float64Array, NullArray, Utf8Array, }, kernels::search_sorted::{build_compare_with_nulls, cmp_float}, series::Series, @@ -562,6 +562,30 @@ macro_rules! impl_binary_like_sort { impl_binary_like_sort!(BinaryArray); impl_binary_like_sort!(Utf8Array); +impl FixedSizeBinaryArray { + pub fn argsort(&self, _descending: bool) -> DaftResult> + where + I: DaftIntegerType, + ::Native: arrow2::types::Index, + { + todo!("impl argsort for FixedSizeBinaryArray") + } + pub fn argsort_multikey( + &self, + _others: &[Series], + _descending: &[bool], + ) -> DaftResult> + where + I: DaftIntegerType, + ::Native: arrow2::types::Index, + { + todo!("impl argsort_multikey for FixedSizeBinaryArray") + } + pub fn sort(&self, _descending: bool) -> DaftResult { + todo!("impl sort for FixedSizeBinaryArray") + } +} + impl FixedSizeListArray { pub fn sort(&self, _descending: bool) -> DaftResult { todo!("impl sort for FixedSizeListArray") diff --git a/src/daft-core/src/array/ops/take.rs b/src/daft-core/src/array/ops/take.rs index 880fb2afeb..b2bde15143 100644 --- a/src/daft-core/src/array/ops/take.rs +++ b/src/daft-core/src/array/ops/take.rs @@ -8,8 +8,8 @@ use crate::{ DateArray, Decimal128Array, DurationArray, EmbeddingArray, FixedShapeImageArray, FixedShapeTensorArray, ImageArray, MapArray, TensorArray, TimeArray, TimestampArray, }, - BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, ExtensionArray, NullArray, - Utf8Array, + BinaryArray, BooleanArray, DaftIntegerType, DaftNumericType, ExtensionArray, + FixedSizeBinaryArray, NullArray, Utf8Array, }, DataType, }; @@ -80,6 +80,38 @@ impl_logicalarray_take!(TensorArray); impl_logicalarray_take!(FixedShapeTensorArray); impl_logicalarray_take!(MapArray); +impl FixedSizeBinaryArray { + pub fn take(&self, idx: &DataArray) -> DaftResult + where + I: DaftIntegerType, + ::Native: arrow2::types::Index, + { + let mut growable = FixedSizeBinaryArray::make_growable( + self.name(), + self.data_type(), + vec![self], + idx.data().null_count() > 0, + idx.len(), + ); + + for i in idx { + match i { + None => { + growable.add_nulls(1); + } + Some(i) => { + growable.extend(0, i.to_usize(), 1); + } + } + } + + Ok(growable + .build()? + .downcast::()? + .clone()) + } +} + #[cfg(feature = "python")] impl crate::datatypes::PythonArray { pub fn take(&self, idx: &DataArray) -> DaftResult diff --git a/src/daft-core/src/array/serdes.rs b/src/daft-core/src/array/serdes.rs index 549b6f1685..ba33a796f7 100644 --- a/src/daft-core/src/array/serdes.rs +++ b/src/daft-core/src/array/serdes.rs @@ -5,7 +5,7 @@ use serde::ser::SerializeMap; use crate::{ datatypes::{ logical::LogicalArray, BinaryArray, BooleanArray, DaftLogicalType, DaftNumericType, - ExtensionArray, Int64Array, NullArray, Utf8Array, + ExtensionArray, FixedSizeBinaryArray, Int64Array, NullArray, Utf8Array, }, DataType, IntoSeries, Series, }; @@ -100,6 +100,18 @@ impl serde::Serialize for BinaryArray { } } +impl serde::Serialize for FixedSizeBinaryArray { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut s = serializer.serialize_map(Some(2))?; + s.serialize_entry("field", self.field())?; + s.serialize_entry("values", &IterSer::new(self.as_arrow().iter()))?; + s.end() + } +} + impl serde::Serialize for NullArray { fn serialize(&self, serializer: S) -> Result where diff --git a/src/daft-core/src/datatypes/binary_ops.rs b/src/daft-core/src/datatypes/binary_ops.rs index 6c1a7f1fe3..cf7e9d85ef 100644 --- a/src/daft-core/src/datatypes/binary_ops.rs +++ b/src/daft-core/src/datatypes/binary_ops.rs @@ -103,7 +103,7 @@ impl Add for &DataType { (Null, other) | (other, Null) => { match other { // Condition is for backwards compatibility. TODO: remove - Binary | Date => Err(DaftError::TypeError( + Binary | FixedSizeBinary(..) | Date => Err(DaftError::TypeError( format!("Cannot add types: {}, {}", self, other) )), other if other.is_physical() => Ok(other.clone()), @@ -115,7 +115,7 @@ impl Add for &DataType { (Utf8, other) | (other, Utf8) => { match other { // Date condition is for backwards compatibility. TODO: remove - Binary | Date => Err(DaftError::TypeError( + Binary | FixedSizeBinary(..) | Date => Err(DaftError::TypeError( format!("Cannot add types: {}, {}", self, other) )), other if other.is_physical() => Ok(Utf8), @@ -234,7 +234,9 @@ pub fn try_physical_supertype(l: &DataType, r: &DataType) -> DaftResult Ok(other.clone()), #[cfg(feature = "python")] (Python, _) | (_, Python) => Ok(Python), - (Utf8, o) | (o, Utf8) if o.is_physical() && !matches!(o, Binary) => Ok(Utf8), + (Utf8, o) | (o, Utf8) if o.is_physical() && !matches!(o, Binary | FixedSizeBinary(..)) => { + Ok(Utf8) + } _ => Err(DaftError::TypeError(format!( "Invalid arguments to try_physical_supertype: {}, {}", l, r diff --git a/src/daft-core/src/datatypes/dtype.rs b/src/daft-core/src/datatypes/dtype.rs index 04e28a5c9c..c5094fee57 100644 --- a/src/daft-core/src/datatypes/dtype.rs +++ b/src/daft-core/src/datatypes/dtype.rs @@ -68,6 +68,8 @@ pub enum DataType { Duration(TimeUnit), /// Opaque binary data of variable length whose offsets are represented as [`i64`]. Binary, + /// Opaque binary data of fixed size. Enum parameter specifies the number of bytes per value. + FixedSizeBinary(usize), /// A variable-length UTF-8 encoded string whose offsets are represented as [`i64`]. Utf8, /// A list of some logical data type with a fixed number of elements. @@ -146,6 +148,7 @@ impl DataType { DataType::Time(unit) => Ok(ArrowType::Time64(unit.to_arrow())), DataType::Duration(unit) => Ok(ArrowType::Duration(unit.to_arrow())), DataType::Binary => Ok(ArrowType::LargeBinary), + DataType::FixedSizeBinary(size) => Ok(ArrowType::FixedSizeBinary(*size)), DataType::Utf8 => Ok(ArrowType::LargeUtf8), DataType::FixedSizeList(child_dtype, size) => Ok(ArrowType::FixedSizeList( Box::new(arrow2::datatypes::Field::new( @@ -370,6 +373,7 @@ impl DataType { DataType::Float64 => Some(8.), DataType::Utf8 => Some(VARIABLE_TYPE_SIZE), DataType::Binary => Some(VARIABLE_TYPE_SIZE), + DataType::FixedSizeBinary(size) => Some(size as f64), DataType::FixedSizeList(dtype, len) => { dtype.estimate_size_bytes().map(|b| b * (len as f64)) } @@ -458,9 +462,8 @@ impl From<&ArrowType> for DataType { DataType::Time(timeunit.into()) } ArrowType::Duration(timeunit) => DataType::Duration(timeunit.into()), - ArrowType::Binary | ArrowType::LargeBinary | ArrowType::FixedSizeBinary(_) => { - DataType::Binary - } + ArrowType::FixedSizeBinary(size) => DataType::FixedSizeBinary(*size), + ArrowType::Binary | ArrowType::LargeBinary => DataType::Binary, ArrowType::Utf8 | ArrowType::LargeUtf8 => DataType::Utf8, ArrowType::Decimal(precision, scale) => DataType::Decimal128(*precision, *scale), ArrowType::List(field) | ArrowType::LargeList(field) => { diff --git a/src/daft-core/src/datatypes/matching.rs b/src/daft-core/src/datatypes/matching.rs index ab1817528d..e3c72528f8 100644 --- a/src/daft-core/src/datatypes/matching.rs +++ b/src/daft-core/src/datatypes/matching.rs @@ -26,6 +26,7 @@ macro_rules! with_match_daft_types {( Time(_) => __with_ty__! { TimeType }, Duration(_) => __with_ty__! { DurationType }, Binary => __with_ty__! { BinaryType }, + FixedSizeBinary(_) => __with_ty__! { FixedSizeBinaryType }, Utf8 => __with_ty__! { Utf8Type }, FixedSizeList(_, _) => __with_ty__! { FixedSizeListType }, List(_) => __with_ty__! { ListType }, @@ -74,6 +75,7 @@ macro_rules! with_match_physical_daft_types {( Float32 => __with_ty__! { Float32Type }, Float64 => __with_ty__! { Float64Type }, Binary => __with_ty__! { BinaryType }, + FixedSizeBinary(_) => __with_ty__! { FixedSizeBinaryType }, Utf8 => __with_ty__! { Utf8Type }, FixedSizeList(_, _) => __with_ty__! { FixedSizeListType }, List(_) => __with_ty__! { ListType }, @@ -97,6 +99,7 @@ macro_rules! with_match_arrow_daft_types {( Null => __with_ty__! { NullType }, Boolean => __with_ty__! { BooleanType }, Binary => __with_ty__! { BinaryType }, + FixedSizeBinary(_) => __with_ty__! { FixedSizeBinaryType }, Int8 => __with_ty__! { Int8Type }, Int16 => __with_ty__! { Int16Type }, Int32 => __with_ty__! { Int32Type }, @@ -143,6 +146,7 @@ macro_rules! with_match_comparable_daft_types {( Float64 => __with_ty__! { Float64Type }, Utf8 => __with_ty__! { Utf8Type }, Binary => __with_ty__! { BinaryType }, + FixedSizeBinary(_) => __with_ty__! { FixedSizeBinaryType }, _ => panic!("{:?} not implemented", $key_type) } })} diff --git a/src/daft-core/src/datatypes/mod.rs b/src/daft-core/src/datatypes/mod.rs index c33279c9ab..d3c1c17ef6 100644 --- a/src/daft-core/src/datatypes/mod.rs +++ b/src/daft-core/src/datatypes/mod.rs @@ -189,6 +189,7 @@ impl_daft_arrow_datatype!(UInt64Type, UInt64); impl_daft_arrow_datatype!(Float32Type, Float32); impl_daft_arrow_datatype!(Float64Type, Float64); impl_daft_arrow_datatype!(BinaryType, Binary); +impl_daft_arrow_datatype!(FixedSizeBinaryType, Unknown); impl_daft_arrow_datatype!(Utf8Type, Utf8); impl_daft_arrow_datatype!(ExtensionType, Unknown); @@ -352,6 +353,7 @@ pub type UInt64Array = DataArray; pub type Float32Array = DataArray; pub type Float64Array = DataArray; pub type BinaryArray = DataArray; +pub type FixedSizeBinaryArray = DataArray; pub type Utf8Array = DataArray; pub type ExtensionArray = DataArray; diff --git a/src/daft-core/src/kernels/hashing.rs b/src/daft-core/src/kernels/hashing.rs index 388c178927..c7151c10ce 100644 --- a/src/daft-core/src/kernels/hashing.rs +++ b/src/daft-core/src/kernels/hashing.rs @@ -1,6 +1,8 @@ use arrow2::{ - array::Array, - array::{BinaryArray, BooleanArray, NullArray, PrimitiveArray, Utf8Array}, + array::{ + Array, BinaryArray, BooleanArray, FixedSizeBinaryArray, NullArray, PrimitiveArray, + Utf8Array, + }, datatypes::{DataType, PhysicalType}, error::{Error, Result}, types::{NativeType, Offset}, @@ -93,6 +95,22 @@ fn hash_binary( PrimitiveArray::::new(DataType::UInt64, hashes.into(), None) } +fn hash_fixed_size_binary( + array: &FixedSizeBinaryArray, + seed: Option<&PrimitiveArray>, +) -> PrimitiveArray { + let hashes = if let Some(seed) = seed { + array + .values_iter() + .zip(seed.values_iter()) + .map(|(v, s)| xxh3_64_with_seed(v, *s)) + .collect::>() + } else { + array.values_iter().map(xxh3_64).collect::>() + }; + PrimitiveArray::::new(DataType::UInt64, hashes.into(), None) +} + fn hash_utf8( array: &Utf8Array, seed: Option<&PrimitiveArray>, @@ -166,6 +184,7 @@ pub fn hash(array: &dyn Array, seed: Option<&PrimitiveArray>) -> Result hash_binary::(array.as_any().downcast_ref().unwrap(), seed), LargeBinary => hash_binary::(array.as_any().downcast_ref().unwrap(), seed), + FixedSizeBinary => hash_fixed_size_binary(array.as_any().downcast_ref().unwrap(), seed), Utf8 => hash_utf8::(array.as_any().downcast_ref().unwrap(), seed), LargeUtf8 => hash_utf8::(array.as_any().downcast_ref().unwrap(), seed), t => { diff --git a/src/daft-core/src/kernels/search_sorted.rs b/src/daft-core/src/kernels/search_sorted.rs index d46d23f8a7..53fef1b712 100644 --- a/src/daft-core/src/kernels/search_sorted.rs +++ b/src/daft-core/src/kernels/search_sorted.rs @@ -1,12 +1,10 @@ use std::{cmp::Ordering, iter::zip}; use arrow2::{ - array::Array, array::{ ord::{build_compare, DynComparator}, - BinaryArray, + Array, BinaryArray, FixedSizeBinaryArray, PrimitiveArray, Utf8Array, }, - array::{PrimitiveArray, Utf8Array}, datatypes::{DataType, PhysicalType}, error::{Error, Result}, types::{NativeType, Offset}, @@ -207,6 +205,69 @@ fn search_sorted_binary_array( PrimitiveArray::::new(DataType::UInt64, results.into(), None) } +fn search_sorted_fixed_size_binary_array( + sorted_array: &FixedSizeBinaryArray, + keys: &FixedSizeBinaryArray, + input_reversed: bool, +) -> PrimitiveArray { + let array_size = sorted_array.len(); + let mut left = 0_usize; + let mut right = array_size; + + let mut results: Vec = Vec::with_capacity(array_size); + let mut last_key = keys.iter().next().unwrap_or(None); + for key_val in keys.iter() { + let is_last_key_lt = match (last_key, key_val) { + (None, None) => false, + (None, Some(_)) => input_reversed, + (Some(last_key), Some(key_val)) => { + if !input_reversed { + last_key.lt(key_val) + } else { + last_key.gt(key_val) + } + } + (Some(_), None) => !input_reversed, + }; + if is_last_key_lt { + right = array_size; + } else { + left = 0; + right = if right < array_size { + right + 1 + } else { + array_size + }; + } + while left < right { + let mid_idx = left + ((right - left) >> 1); + let mid_val = unsafe { sorted_array.value_unchecked(mid_idx) }; + let is_key_val_lt = match (key_val, sorted_array.is_valid(mid_idx)) { + (None, false) => false, + (None, true) => input_reversed, + (Some(key_val), true) => { + if !input_reversed { + key_val.lt(mid_val) + } else { + mid_val.lt(key_val) + } + } + (Some(_), false) => !input_reversed, + }; + + if is_key_val_lt { + right = mid_idx; + } else { + left = mid_idx + 1; + } + } + results.push(left.try_into().unwrap()); + last_key = key_val; + } + + PrimitiveArray::::new(DataType::UInt64, results.into(), None) +} + macro_rules! with_match_searching_primitive_type {( $key_type:expr, | $_:tt $T:ident | $($body:tt)* ) => ({ @@ -469,6 +530,11 @@ pub fn search_sorted( keys.as_any().downcast_ref().unwrap(), input_reversed, ), + FixedSizeBinary => search_sorted_fixed_size_binary_array( + sorted_array.as_any().downcast_ref().unwrap(), + keys.as_any().downcast_ref().unwrap(), + input_reversed, + ), t => { return Err(Error::NotYetImplemented(format!( "search_sorted not implemented for type {t:?}" diff --git a/src/daft-core/src/python/datatype.rs b/src/daft-core/src/python/datatype.rs index 598f4b1625..398a4fdf9e 100644 --- a/src/daft-core/src/python/datatype.rs +++ b/src/daft-core/src/python/datatype.rs @@ -145,6 +145,17 @@ impl PyDataType { Ok(DataType::Binary.into()) } + #[staticmethod] + pub fn fixed_size_binary(size: i64) -> PyResult { + if size <= 0 { + return Err(PyValueError::new_err(format!( + "The size for fixed-size binary types must be a positive integer, but got: {}", + size + ))); + } + Ok(DataType::FixedSizeBinary(usize::try_from(size)?).into()) + } + #[staticmethod] pub fn string() -> PyResult { Ok(DataType::Utf8.into()) diff --git a/src/daft-core/src/series/array_impl/binary_ops.rs b/src/daft-core/src/series/array_impl/binary_ops.rs index c2a6be7967..79d4afe39e 100644 --- a/src/daft-core/src/series/array_impl/binary_ops.rs +++ b/src/daft-core/src/series/array_impl/binary_ops.rs @@ -9,7 +9,7 @@ use crate::{ }, datatypes::{ logical::{Decimal128Array, MapArray}, - Int128Array, + FixedSizeBinaryArray, Int128Array, }, series::series_like::SeriesLike, with_match_comparable_daft_types, with_match_numeric_daft_types, DataType, @@ -214,6 +214,7 @@ impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper {} +impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper {} impl SeriesBinaryOps for ArrayWrapper {} diff --git a/src/daft-core/src/series/array_impl/data_array.rs b/src/daft-core/src/series/array_impl/data_array.rs index 00c8bd63ad..66a1009e04 100644 --- a/src/daft-core/src/series/array_impl/data_array.rs +++ b/src/daft-core/src/series/array_impl/data_array.rs @@ -7,6 +7,7 @@ use crate::array::ops::GroupIndices; use crate::array::DataArray; use crate::datatypes::DaftArrowBackedType; +use crate::datatypes::FixedSizeBinaryArray; #[cfg(feature = "python")] use crate::datatypes::PythonArray; use crate::series::array_impl::binary_ops::SeriesBinaryOps; @@ -220,6 +221,7 @@ macro_rules! impl_series_like_for_data_array { impl_series_like_for_data_array!(NullArray); impl_series_like_for_data_array!(BooleanArray); impl_series_like_for_data_array!(BinaryArray); +impl_series_like_for_data_array!(FixedSizeBinaryArray); impl_series_like_for_data_array!(Int8Array); impl_series_like_for_data_array!(Int16Array); impl_series_like_for_data_array!(Int32Array); diff --git a/src/daft-core/src/series/ops/downcast.rs b/src/daft-core/src/series/ops/downcast.rs index 584f875a11..8c509113c5 100644 --- a/src/daft-core/src/series/ops/downcast.rs +++ b/src/daft-core/src/series/ops/downcast.rs @@ -83,6 +83,10 @@ impl Series { self.downcast() } + pub fn fixed_size_binary(&self) -> DaftResult<&FixedSizeBinaryArray> { + self.downcast() + } + pub fn utf8(&self) -> DaftResult<&Utf8Array> { self.downcast() } diff --git a/src/daft-core/src/series/ops/hash.rs b/src/daft-core/src/series/ops/hash.rs index f5196eeae4..cd831b0f9f 100644 --- a/src/daft-core/src/series/ops/hash.rs +++ b/src/daft-core/src/series/ops/hash.rs @@ -27,6 +27,7 @@ impl Series { UInt64 => self.u64()?.murmur3_32(), Utf8 => self.utf8()?.murmur3_32(), Binary => self.binary()?.murmur3_32(), + FixedSizeBinary(_) => self.fixed_size_binary()?.murmur3_32(), Date => self.date()?.murmur3_32(), Time(..) => self.time()?.murmur3_32(), Timestamp(..) => self.timestamp()?.murmur3_32(), diff --git a/src/daft-core/src/series/serdes.rs b/src/daft-core/src/series/serdes.rs index 69a0ca33ec..94629e2a00 100644 --- a/src/daft-core/src/series/serdes.rs +++ b/src/daft-core/src/series/serdes.rs @@ -146,6 +146,12 @@ impl<'d> serde::Deserialize<'d> for Series { map.next_value::>>>()?.into_iter(), ) .into_series()), + FixedSizeBinary(size) => Ok(FixedSizeBinaryArray::from_iter( + field.name.as_str(), + map.next_value::>>>()?.into_iter(), + *size, + ) + .into_series()), Extension(..) => { let physical = map.next_value::()?; let physical = physical.to_arrow(); diff --git a/src/daft-core/src/utils/arrow.rs b/src/daft-core/src/utils/arrow.rs index a18b387994..a3690a837a 100644 --- a/src/daft-core/src/utils/arrow.rs +++ b/src/daft-core/src/utils/arrow.rs @@ -15,9 +15,7 @@ fn coerce_to_daft_compatible_type( ) -> Option { match dtype { arrow2::datatypes::DataType::Utf8 => Some(arrow2::datatypes::DataType::LargeUtf8), - arrow2::datatypes::DataType::Binary | arrow2::datatypes::DataType::FixedSizeBinary(_) => { - Some(arrow2::datatypes::DataType::LargeBinary) - } + arrow2::datatypes::DataType::Binary => Some(arrow2::datatypes::DataType::LargeBinary), arrow2::datatypes::DataType::List(field) => { let new_field = match coerce_to_daft_compatible_type(field.data_type()) { Some(new_inner_dtype) => Box::new( diff --git a/src/daft-core/src/utils/supertype.rs b/src/daft-core/src/utils/supertype.rs index d2335894e2..fcf5f07010 100644 --- a/src/daft-core/src/utils/supertype.rs +++ b/src/daft-core/src/utils/supertype.rs @@ -194,7 +194,7 @@ pub fn get_supertype(l: &DataType, r: &DataType) -> Option { // } // every known type can be casted to a string except binary - (dt, Utf8) if dt.ne(&Binary) => Some(Utf8), + (dt, Utf8) if !matches!(&dt, &Binary | &FixedSizeBinary(_)) => Some(Utf8), (dt, Null) => Some(dt.clone()), // Drop Null Type diff --git a/src/daft-stats/src/column_stats/mod.rs b/src/daft-stats/src/column_stats/mod.rs index b6a0eeb99f..c263d628cd 100644 --- a/src/daft-stats/src/column_stats/mod.rs +++ b/src/daft-stats/src/column_stats/mod.rs @@ -68,7 +68,7 @@ impl ColumnRangeStatistics { DataType::Float32 | DataType::Float64 | DataType::Decimal128(..) | DataType::Boolean | // String types - DataType::Utf8 | DataType::Binary | + DataType::Utf8 | DataType::Binary | DataType::FixedSizeBinary(..) | // Temporal types DataType::Date | DataType::Time(..) | DataType::Timestamp(..) | DataType::Duration(..) => true, diff --git a/tests/expressions/typing/conftest.py b/tests/expressions/typing/conftest.py index 686a559da2..8c72b78364 100644 --- a/tests/expressions/typing/conftest.py +++ b/tests/expressions/typing/conftest.py @@ -28,6 +28,7 @@ (DataType.bool(), pa.array([True, False, None], type=pa.bool_())), (DataType.null(), pa.array([None, None, None], type=pa.null())), (DataType.binary(), pa.array([b"1", b"2", None], type=pa.binary())), + (DataType.fixed_size_binary(1), pa.array([b"1", b"2", None], type=pa.binary(1))), ] ALL_DATATYPES_BINARY_PAIRS = list(itertools.product(ALL_DTYPES, repeat=2)) @@ -182,6 +183,7 @@ def is_comparable(dt: DataType): or dt == DataType.string() or dt == DataType.null() or dt == DataType.binary() + or dt == DataType.fixed_size_binary(1) or dt._is_temporal_type() ) @@ -210,7 +212,9 @@ def has_supertype(dt1: DataType, dt2: DataType) -> bool: for x, y in ((dt1, dt2), (dt2, dt1)): # --- Common types across hierarchies --- either_null = x == DataType.null() - either_string_and_other_not_binary = x == DataType.string() and y != DataType.binary() + either_string_and_other_not_binary = x == DataType.string() and not ( + y == DataType.binary() or y == DataType.fixed_size_binary(1) + ) # --- Within type hierarchies --- both_numeric = (is_numeric(x) and is_numeric(y)) or ((x == DataType.bool()) and is_numeric(y)) diff --git a/tests/series/test_cast.py b/tests/series/test_cast.py index ccec910ea8..fae2026678 100644 --- a/tests/series/test_cast.py +++ b/tests/series/test_cast.py @@ -9,6 +9,7 @@ import pytest from daft.datatype import DataType, ImageMode, TimeUnit +from daft.exceptions import DaftCoreException from daft.series import Series from tests.series import ARROW_FLOAT_TYPES, ARROW_INT_TYPES, ARROW_STRING_TYPES @@ -813,3 +814,20 @@ def test_cast_timestamp_to_time_unsupported_timeunit(timeunit): input = Series.from_pylist([datetime(2022, 1, 6, 12, 34, 56, 78)]) with pytest.raises(ValueError): input.cast(DataType.time(timeunit)) + + +def test_cast_binary_to_fixed_size_binary(): + data = [b"abc", b"def", None, b"bcd", None] + + input = Series.from_pylist(data) + assert input.datatype() == DataType.binary() + casted = input.cast(DataType.fixed_size_binary(3)) + assert casted.to_pylist() == [b"abc", b"def", None, b"bcd", None] + + +def test_cast_binary_to_fixed_size_binary_fails_with_variable_lengths(): + data = [b"abc", b"def", None, b"bcd", None, b"long"] + + input = Series.from_pylist(data) + with pytest.raises(DaftCoreException): + input.cast(DataType.fixed_size_binary(3)) diff --git a/tests/series/test_comparisons.py b/tests/series/test_comparisons.py index 0b1f0356fa..21591db608 100644 --- a/tests/series/test_comparisons.py +++ b/tests/series/test_comparisons.py @@ -582,6 +582,85 @@ def test_comparisons_binary_right_scalar(l_dtype, r_dtype) -> None: assert gt == [False, False, True, None, True, None] +def test_comparisons_fixed_size_binary() -> None: + l_arrow = pa.array([b"11111", b"22222", b"33333", None, b"12345", None], type=pa.binary(5)) + r_arrow = pa.array([b"11111", b"33333", b"11111", b"12345", None, None], type=pa.binary(5)) + # eq, lt, gt, None, None, None + + left = Series.from_arrow(l_arrow) + right = Series.from_arrow(r_arrow) + lt = (left < right).to_pylist() + assert lt == [False, True, False, None, None, None] + + le = (left <= right).to_pylist() + assert le == [True, True, False, None, None, None] + + eq = (left == right).to_pylist() + assert eq == [True, False, False, None, None, None] + + neq = (left != right).to_pylist() + assert neq == [False, True, True, None, None, None] + + ge = (left >= right).to_pylist() + assert ge == [True, False, True, None, None, None] + + gt = (left > right).to_pylist() + assert gt == [False, False, True, None, None, None] + + +def test_comparisons_fixed_size_binary_left_scalar() -> None: + l_arrow = pa.array([b"222"], type=pa.binary(3)) + r_arrow = pa.array([b"111", b"222", b"333", None], type=pa.binary(3)) + # gt, eq, lt + + left = Series.from_arrow(l_arrow) + right = Series.from_arrow(r_arrow) + + lt = (left < right).to_pylist() + assert lt == [False, False, True, None] + + le = (left <= right).to_pylist() + assert le == [False, True, True, None] + + eq = (left == right).to_pylist() + assert eq == [False, True, False, None] + + neq = (left != right).to_pylist() + assert neq == [True, False, True, None] + + ge = (left >= right).to_pylist() + assert ge == [True, True, False, None] + + gt = (left > right).to_pylist() + assert gt == [True, False, False, None] + + +def test_comparisons_fixed_size_binary_right_scalar() -> None: + l_arrow = pa.array([b"111", b"222", b"333", None, b"555", None], type=pa.binary(3)) + r_arrow = pa.array([b"222"], type=pa.binary(3)) + # lt, eq, gt, None, gt, None + + left = Series.from_arrow(l_arrow) + right = Series.from_arrow(r_arrow) + lt = (left < right).to_pylist() + assert lt == [True, False, False, None, False, None] + + le = (left <= right).to_pylist() + assert le == [True, True, False, None, False, None] + + eq = (left == right).to_pylist() + assert eq == [False, True, False, None, False, None] + + neq = (left != right).to_pylist() + assert neq == [True, False, True, None, True, None] + + ge = (left >= right).to_pylist() + assert ge == [False, True, True, None, True, None] + + gt = (left > right).to_pylist() + assert gt == [False, False, True, None, True, None] + + class CustomZero: def __eq__(self, other): if isinstance(other, CustomZero): diff --git a/tests/series/test_filter.py b/tests/series/test_filter.py index 01d83ea253..be187b4bbc 100644 --- a/tests/series/test_filter.py +++ b/tests/series/test_filter.py @@ -54,8 +54,9 @@ def test_series_filter_on_bool() -> None: assert result.to_pylist() == expected -def test_series_filter_on_binary() -> None: - s = Series.from_pylist([b"Y", b"N", None, b"Y", None, b"N"]) +@pytest.mark.parametrize("type", [pa.binary(), pa.binary(1)]) +def test_series_filter_on_binary(type) -> None: + s = Series.from_arrow(pa.array([b"Y", b"N", None, b"Y", None, b"N"], type=type)) pymask = [False, True, True, None, False, False] mask = Series.from_pylist(pymask) diff --git a/tests/series/test_hash.py b/tests/series/test_hash.py index 54f780866f..33b1d7821b 100644 --- a/tests/series/test_hash.py +++ b/tests/series/test_hash.py @@ -87,6 +87,27 @@ def test_hash_binary_array_with_reference(): assert hashed_again.to_pylist() == expected +def test_hash_fixed_size_binary_array_with_reference(): + import pyarrow as pa + + arr = Series.from_arrow(pa.array([b"foo", b"bar", None], type=pa.binary(3))) + expected = [ + xxhash.xxh3_64_intdigest(b"foo"), + xxhash.xxh3_64_intdigest(b"bar"), + xxhash.xxh3_64_intdigest(b"\x00\x00\x00"), + ] + hashed = arr.hash() + assert hashed.to_pylist() == expected + + hashed_again = arr.hash(hashed) + expected = [ + xxhash.xxh3_64_intdigest(b"foo", expected[0]), + xxhash.xxh3_64_intdigest(b"bar", expected[1]), + xxhash.xxh3_64_intdigest(b"\x00\x00\x00", expected[2]), + ] + assert hashed_again.to_pylist() == expected + + def test_hash_null_array_with_reference(): arr = Series.from_pylist([None, None, None]) expected = [xxhash.xxh3_64_intdigest(b""), xxhash.xxh3_64_intdigest(b""), xxhash.xxh3_64_intdigest(b"")] @@ -169,6 +190,16 @@ def test_murmur3_32_hash_bytes(): assert hashes.to_pylist() == [java_answer, None] +def test_murmur3_32_hash_fixed_sized_bytes(): + import pyarrow as pa + + arr = Series.from_arrow(pa.array([b"\x00\x01\x02\x03", None], type=pa.binary(4))) + assert arr.datatype() == DataType.fixed_size_binary(4) + hashes = arr.murmur3_32() + java_answer = -188683207 + assert hashes.to_pylist() == [java_answer, None] + + def test_murmur3_32_hash_date(): arr = Series.from_pylist([date(2017, 11, 16), None]) assert arr.datatype() == DataType.date() diff --git a/tests/series/test_if_else.py b/tests/series/test_if_else.py index 482fc3a841..090d232538 100644 --- a/tests/series/test_if_else.py +++ b/tests/series/test_if_else.py @@ -125,17 +125,22 @@ def test_series_if_else_bool( "if_false_length", [1, 3], ) +@pytest.mark.parametrize( + "type, expected_type", [(pa.binary(), DataType.binary()), (pa.binary(1), DataType.fixed_size_binary(1))] +) def test_series_if_else_binary( if_true_value, if_false_value, if_true_length, if_false_length, + type, + expected_type, ) -> None: - if_true_series = Series.from_arrow(pa.array([if_true_value] * if_true_length, type=pa.binary())) - if_false_series = Series.from_arrow(pa.array([if_false_value] * if_false_length, type=pa.binary())) + if_true_series = Series.from_arrow(pa.array([if_true_value] * if_true_length, type=type)) + if_false_series = Series.from_arrow(pa.array([if_false_value] * if_false_length, type=type)) predicate_series = Series.from_arrow(pa.array([True, False, None])) result = predicate_series.if_else(if_true_series, if_false_series) - assert result.datatype() == DataType.binary() + assert result.datatype() == expected_type assert result.to_pylist() == [if_true_value, if_false_value, None] @@ -553,12 +558,16 @@ def test_series_if_else_predicate_broadcast_bools(predicate_value, expected_resu ["predicate_value", "expected_results"], [(True, [b"Y", b"Y", b"Y"]), (False, [b"N", b"N", b"N"]), (None, [None, None, None])], ) -def test_series_if_else_predicate_broadcast_binary(predicate_value, expected_results) -> None: - if_true_series = Series.from_arrow(pa.array([b"Y", b"Y", b"Y"], type=pa.binary())) - if_false_series = Series.from_arrow(pa.array([b"N", b"N", b"N"], type=pa.binary())) +@pytest.mark.parametrize( + "type, result_type", + [(pa.binary(), DataType.binary()), (pa.binary(1), DataType.fixed_size_binary(1))], +) +def test_series_if_else_predicate_broadcast_binary(predicate_value, expected_results, type, result_type) -> None: + if_true_series = Series.from_arrow(pa.array([b"Y", b"Y", b"Y"], type=type)) + if_false_series = Series.from_arrow(pa.array([b"N", b"N", b"N"], type=type)) predicate_series = Series.from_arrow(pa.array([predicate_value], type=pa.bool_())) result = predicate_series.if_else(if_true_series, if_false_series) - assert result.datatype() == DataType.binary() + assert result.datatype() == result_type assert result.to_pylist() == expected_results diff --git a/tests/series/test_series.py b/tests/series/test_series.py index fe31324588..cdbb15f465 100644 --- a/tests/series/test_series.py +++ b/tests/series/test_series.py @@ -112,8 +112,9 @@ def test_series_pylist_round_trip_null() -> None: assert words["None"] == 2 -def test_series_pylist_round_trip_binary() -> None: - data = pa.array([b"a", b"b", b"c", None, b"d", None]) +@pytest.mark.parametrize("type", [pa.binary(), pa.binary(1)]) +def test_series_pylist_round_trip_binary(type) -> None: + data = pa.array([b"a", b"b", b"c", None, b"d", None], type=type) s = Series.from_arrow(data) @@ -142,6 +143,16 @@ def test_series_bincode_serdes(dtype) -> None: assert s.to_pylist() == copied_s.to_pylist() +def test_series_bincode_serdes_fixed_size_binary() -> None: + s = Series.from_arrow(pa.array([b"a", b"b", b"c", None, b"d", None], type=pa.binary(1))) + serialized = s._debug_bincode_serialize() + copied_s = Series._debug_bincode_deserialize(serialized) + + assert s.name() == copied_s.name() + assert s.datatype() == copied_s.datatype() + assert s.to_pylist() == copied_s.to_pylist() + + @pytest.mark.parametrize( "data", [ diff --git a/tests/series/test_size_bytes.py b/tests/series/test_size_bytes.py index a0e69d83a4..5e90c8db68 100644 --- a/tests/series/test_size_bytes.py +++ b/tests/series/test_size_bytes.py @@ -123,6 +123,24 @@ def test_series_binary_size_bytes(size, with_nulls) -> None: assert s.size_bytes() == get_total_buffer_size(data) +@pytest.mark.parametrize("size", [1, 2, 8, 9]) +@pytest.mark.parametrize("with_nulls", [True, False]) +def test_series_fixed_size_binary_size_bytes(size, with_nulls) -> None: + import random + + pydata = ["".join([str(random.randint(0, 9)) for _ in range(size)]).encode() for _ in range(size)] + + if with_nulls and size > 0: + data = pa.array(pydata[:-1] + [None], pa.binary(size)) + else: + data = pa.array(pydata, pa.binary(size)) + + s = Series.from_arrow(data) + + assert s.datatype() == DataType.fixed_size_binary(size) + assert s.size_bytes() == get_total_buffer_size(data) + + @pytest.mark.parametrize("dtype, size", itertools.product(ARROW_INT_TYPES + ARROW_FLOAT_TYPES, [0, 1, 2, 8, 9, 16])) @pytest.mark.parametrize("with_nulls", [True, False]) def test_series_list_size_bytes(dtype, size, with_nulls) -> None: diff --git a/tests/series/test_take.py b/tests/series/test_take.py index ea68ff515e..6139e7fcea 100644 --- a/tests/series/test_take.py +++ b/tests/series/test_take.py @@ -62,8 +62,9 @@ def time_maker(h, m, s, us): assert taken.to_pylist() == times[::-1] -def test_series_binary_take() -> None: - data = pa.array([b"1", b"2", b"3", None, b"5", None]) +@pytest.mark.parametrize("type", [pa.binary(), pa.binary(1)]) +def test_series_binary_take(type) -> None: + data = pa.array([b"1", b"2", b"3", None, b"5", None], type=type) s = Series.from_arrow(data) pyidx = [2, 0, None, 5] diff --git a/tests/table/test_from_py.py b/tests/table/test_from_py.py index bc8d2be888..95ef19b884 100644 --- a/tests/table/test_from_py.py +++ b/tests/table/test_from_py.py @@ -92,6 +92,7 @@ "float64": pa.array(PYTHON_TYPE_ARRAYS["float"], pa.float64()), "string": pa.array(PYTHON_TYPE_ARRAYS["str"], pa.string()), "binary": pa.array(PYTHON_TYPE_ARRAYS["binary"], pa.binary()), + "fixed_size_binary": pa.array(PYTHON_TYPE_ARRAYS["binary"], pa.binary(3)), "boolean": pa.array(PYTHON_TYPE_ARRAYS["bool"], pa.bool_()), "date32": pa.array(PYTHON_TYPE_ARRAYS["date"], pa.date32()), "date64": pa.array(PYTHON_TYPE_ARRAYS["date"], pa.date64()), @@ -148,6 +149,7 @@ "float64": pa.float64(), "string": pa.large_string(), "binary": pa.large_binary(), + "fixed_size_binary": pa.binary(3), "boolean": pa.bool_(), "date32": pa.date32(), "date64": pa.timestamp("ms"), @@ -354,6 +356,7 @@ def test_from_pydict_arrow_deeply_nested() -> None: (pa.array([1, 2, None, 4], type=pa.int64()), pa.int64()), (pa.array(["a", "b", None, "d"], type=pa.string()), pa.large_string()), (pa.array([b"a", b"b", None, b"d"], type=pa.binary()), pa.large_binary()), + (pa.array([b"a", b"b", None, b"d"], type=pa.binary(1)), pa.binary(1)), (pa.array([[1, 2], [3], None, [None, 4]], pa.list_(pa.int64())), pa.large_list(pa.int64())), (pa.array([[1, 2], [3, 4], None, [None, 6]], pa.list_(pa.int64(), 2)), pa.list_(pa.int64(), 2)), ( @@ -384,6 +387,7 @@ def test_from_pydict_arrow_with_nulls_roundtrip(data, out_dtype, chunked) -> Non (pa.array([1, 2, 3, 4], type=pa.int64()), pa.int64()), (pa.array(["a", "b", "c", "d"], type=pa.string()), pa.large_string()), (pa.array([b"a", b"b", b"c", b"d"], type=pa.binary()), pa.large_binary()), + (pa.array([b"a", b"b", b"c", b"d"], type=pa.binary(1)), pa.binary(1)), (pa.array([[1, 2], [3], [4, 5, 6], [None, 7]], pa.list_(pa.int64())), pa.large_list(pa.int64())), (pa.array([[1, 2], [3, None], [4, 5], [None, 6]], pa.list_(pa.int64(), 2)), pa.list_(pa.int64(), 2)), ( @@ -398,6 +402,7 @@ def test_from_pydict_arrow_with_nulls_roundtrip(data, out_dtype, chunked) -> Non (pa.array([1, 2, None, 4], type=pa.int64()), pa.int64()), (pa.array(["a", "b", None, "d"], type=pa.string()), pa.large_string()), (pa.array([b"a", b"b", None, b"d"], type=pa.binary()), pa.large_binary()), + (pa.array([b"a", b"b", None, b"d"], type=pa.binary(1)), pa.binary(1)), (pa.array([[1, 2], [3], None, [None, 4]], pa.list_(pa.int64())), pa.large_list(pa.int64())), (pa.array([[1, 2], [3, 4], None, [None, 6]], pa.list_(pa.int64(), 2)), pa.list_(pa.int64(), 2)), ( @@ -438,6 +443,7 @@ def test_from_pydict_series() -> None: (pa.array([1, 2, 3, 4], type=pa.int64()), pa.int64()), (pa.array(["a", "b", "c", "d"], type=pa.string()), pa.large_string()), (pa.array([b"a", b"b", b"c", b"d"], type=pa.binary()), pa.large_binary()), + (pa.array([b"a", b"b", b"c", b"d"], type=pa.binary(1)), pa.binary(1)), (pa.array([[1, 2], [3], [4, 5, 6], [None, 7]], pa.list_(pa.int64())), pa.large_list(pa.int64())), (pa.array([[1, 2], [3, None], [4, 5], [None, 6]], pa.list_(pa.int64(), 2)), pa.list_(pa.int64(), 2)), ( @@ -452,6 +458,7 @@ def test_from_pydict_series() -> None: (pa.array([1, 2, None, 4], type=pa.int64()), pa.int64()), (pa.array(["a", "b", None, "d"], type=pa.string()), pa.large_string()), (pa.array([b"a", b"b", None, b"d"], type=pa.binary()), pa.large_binary()), + (pa.array([b"a", b"b", None, b"d"], type=pa.binary(1)), pa.binary(1)), (pa.array([[1, 2], [3], None, [None, 4]], pa.list_(pa.int64())), pa.large_list(pa.int64())), (pa.array([[1, 2], [3, 4], None, [None, 6]], pa.list_(pa.int64(), 2)), pa.list_(pa.int64(), 2)), ( diff --git a/tests/table/test_sorting.py b/tests/table/test_sorting.py index 9ee928294a..0f9e635685 100644 --- a/tests/table/test_sorting.py +++ b/tests/table/test_sorting.py @@ -7,7 +7,6 @@ import pytest from daft import col -from daft.datatype import DataType from daft.logical.schema import Schema from daft.series import Series from daft.table import MicroPartition @@ -186,15 +185,14 @@ def test_table_multiple_col_sorting(sort_dtype, value_dtype, data) -> None: ) def test_table_multiple_col_sorting_binary(data) -> None: a, b, a_desc, b_desc, expected = data - a = [x.to_bytes(1, "little") if x is not None else None for x in a] - b = [x.to_bytes(1, "little") if x is not None else None for x in b] + a = pa.array([x.to_bytes(1, "little") if x is not None else None for x in a], type=pa.binary()) + b = pa.array([x.to_bytes(1, "little") if x is not None else None for x in b], type=pa.binary()) pa_table = pa.Table.from_pydict({"a": a, "b": b}) argsort_order = Series.from_pylist(expected) daft_table = MicroPartition.from_arrow(pa_table) - daft_table = daft_table.eval_expression_list([col("a").cast(DataType.binary()), col("b").cast(DataType.binary())]) assert len(daft_table) == 5 assert daft_table.column_names() == ["a", "b"] diff --git a/tests/table/test_table_aggs.py b/tests/table/test_table_aggs.py index 2f24122e93..84545ff954 100644 --- a/tests/table/test_table_aggs.py +++ b/tests/table/test_table_aggs.py @@ -184,6 +184,32 @@ def test_table_minmax_bool(case) -> None: assert res == expected +test_table_minmax_binary_cases = [ + ([], {"min": [None], "max": [None]}), + ([None], {"min": [None], "max": [None]}), + ([None, None, None], {"min": [None], "max": [None]}), + ([b"1"], {"min": [b"1"], "max": [b"1"]}), + ([None, b"1"], {"min": [b"1"], "max": [b"1"]}), + ([b"a", b"b", b"c", b"a"], {"min": [b"a"], "max": [b"c"]}), +] + + +@pytest.mark.parametrize("case", test_table_minmax_binary_cases, ids=[f"{_}" for _ in test_table_minmax_binary_cases]) +@pytest.mark.parametrize("type", [pa.binary(), pa.binary(1)]) +def test_table_minmax_binary(case, type) -> None: + input, expected = case + daft_table = MicroPartition.from_arrow(pa.table({"input": pa.array(input, type=type)})) + daft_table = daft_table.eval_expression_list( + [ + col("input").alias("min").min(), + col("input").alias("max").max(), + ] + ) + + res = daft_table.to_pydict() + assert res == expected + + test_table_sum_mean_cases = [ ([], {"sum": [None], "mean": [None]}), ([None], {"sum": [None], "mean": [None]}), @@ -495,18 +521,26 @@ def test_groupby_numeric_string_bool_no_nulls(dtype) -> None: ) -def test_groupby_binary_bool_some_nulls() -> None: +@pytest.mark.parametrize("type", [pa.binary(), pa.binary(1)]) +@pytest.mark.parametrize( + "agg, expected", + [ + (col("cookies").max(), [b"2", b"4"]), + (col("cookies").min(), [b"1", b"3"]), + ], +) +def test_groupby_binary_bool_some_nulls(type, agg, expected) -> None: daft_table = MicroPartition.from_pydict( { - "group": Series.from_pylist([b"1", b"1", None]), - "cookies": [2, 2, 3], + "group": Series.from_arrow(pa.array([b"1", b"1", None, None], type=type)), + "cookies": Series.from_arrow(pa.array([b"1", b"2", b"3", b"4"], type=type)), } ) - result_table = daft_table.agg([col("cookies").sum()], group_by=[col("group")]) + result_table = daft_table.agg([agg], group_by=[col("group")]) expected_table = MicroPartition.from_pydict( { "group": Series.from_pylist([b"1", None]), - "cookies": [4, 3], + "cookies": expected, } ) @@ -515,38 +549,26 @@ def test_groupby_binary_bool_some_nulls() -> None: ) -def test_groupby_binary_no_nulls() -> None: - daft_table = MicroPartition.from_pydict( - { - "group": Series.from_pylist([b"1", b"0", b"1", b"0"]), - "cookies": [1, 2, 2, 3], - } - ) - result_table = daft_table.agg([col("cookies").sum()], group_by=[col("group")]) - expected_table = MicroPartition.from_pydict( - { - "group": Series.from_pylist([b"0", b"1"]), - "cookies": [5, 3], - } - ) - - assert set(utils.freeze(utils.pydict_to_rows(result_table.to_pydict()))) == set( - utils.freeze(utils.pydict_to_rows(expected_table.to_pydict())) - ) - - -def test_groupby_binary_no_nulls_max() -> None: +@pytest.mark.parametrize("type", [pa.binary(), pa.binary(1)]) +@pytest.mark.parametrize( + "agg, expected", + [ + (col("cookies").max(), [b"4", b"3"]), + (col("cookies").min(), [b"2", b"1"]), + ], +) +def test_groupby_binary_no_nulls(type, agg, expected) -> None: daft_table = MicroPartition.from_pydict( { - "group": Series.from_pylist([b"1", b"0", b"1", b"0"]), - "cookies": [b"1", b"2", b"2", b"3"], + "group": Series.from_arrow(pa.array([b"1", b"0", b"1", b"0"], type=type)), + "cookies": Series.from_arrow(pa.array([b"1", b"2", b"3", b"4"], type=type)), } ) - result_table = daft_table.agg([col("cookies").max()], group_by=[col("group")]) + result_table = daft_table.agg([agg], group_by=[col("group")]) expected_table = MicroPartition.from_pydict( { "group": Series.from_pylist([b"0", b"1"]), - "cookies": [b"3", "2"], + "cookies": expected, } ) diff --git a/tests/test_datatypes.py b/tests/test_datatypes.py index 1b4c723b4b..5df20df90f 100644 --- a/tests/test_datatypes.py +++ b/tests/test_datatypes.py @@ -19,7 +19,11 @@ daft_numeric_types = daft_int_types + [DataType.float32(), DataType.float64()] daft_string_types = [DataType.string()] -daft_nonnull_types = daft_numeric_types + daft_string_types + [DataType.bool(), DataType.binary(), DataType.date()] +daft_nonnull_types = ( + daft_numeric_types + + daft_string_types + + [DataType.bool(), DataType.binary(), DataType.fixed_size_binary(1), DataType.date()] +) @pytest.mark.parametrize("dtype", daft_nonnull_types)