From ce971a83070ffb4a46a6ca463a1402b4a5c6e7e2 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Wed, 10 Apr 2024 16:01:04 -0700 Subject: [PATCH] implement IntoIterator for list types --- .../src/array/fixed_size_list_array.rs | 65 ++++++++++++------- src/daft-core/src/array/list_array.rs | 61 ++++++++++------- src/daft-core/src/array/ops/list.rs | 4 +- src/daft-core/src/datatypes/agg_ops.rs | 31 +++++++++ src/daft-core/src/datatypes/binary_ops.rs | 28 -------- src/daft-core/src/datatypes/mod.rs | 4 +- 6 files changed, 117 insertions(+), 76 deletions(-) create mode 100644 src/daft-core/src/datatypes/agg_ops.rs diff --git a/src/daft-core/src/array/fixed_size_list_array.rs b/src/daft-core/src/array/fixed_size_list_array.rs index 7d79d605ad..2784f5746c 100644 --- a/src/daft-core/src/array/fixed_size_list_array.rs +++ b/src/daft-core/src/array/fixed_size_list_array.rs @@ -132,30 +132,6 @@ impl FixedSizeListArray { )) } - pub fn iter(&self) -> Box> + '_> { - let step = self.fixed_element_len(); - - if let Some(validity) = self.validity() { - Box::new((0..self.len()).map(move |i| { - if validity.get_bit(i) { - let start = i * step; - let end = (i + 1) * step; - - Some(self.flat_child.slice(start, end).unwrap()) - } else { - None - } - })) - } else { - Box::new((0..self.len()).map(move |i| { - let start = i * step; - let end = (i + 1) * step; - - Some(self.flat_child.slice(start, end).unwrap()) - })) - } - } - pub fn to_arrow(&self) -> Box { let arrow_dtype = self.data_type().to_arrow().unwrap(); Box::new(arrow2::array::FixedSizeListArray::new( @@ -190,6 +166,47 @@ impl FixedSizeListArray { } } +impl<'a> IntoIterator for &'a FixedSizeListArray { + type Item = Option; + + type IntoIter = FixedSizeListArrayIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + FixedSizeListArrayIter { + array: self, + idx: 0, + } + } +} + +pub struct FixedSizeListArrayIter<'a> { + array: &'a FixedSizeListArray, + idx: usize, +} + +impl Iterator for FixedSizeListArrayIter<'_> { + type Item = Option; + + fn next(&mut self) -> Option { + if self.idx < self.array.len() { + if let Some(validity) = self.array.validity() && !validity.get_bit(self.idx) { + self.idx += 1; + Some(None) + } else { + let step = self.array.fixed_element_len(); + + let start = self.idx * step; + let end = (self.idx + 1) * step; + + self.idx += 1; + Some(Some(self.array.flat_child.slice(start, end).unwrap())) + } + } else { + None + } + } +} + #[cfg(test)] mod tests { use common_error::DaftResult; diff --git a/src/daft-core/src/array/list_array.rs b/src/daft-core/src/array/list_array.rs index fb13283e0d..abe879b7bd 100644 --- a/src/daft-core/src/array/list_array.rs +++ b/src/daft-core/src/array/list_array.rs @@ -151,28 +151,6 @@ impl ListArray { )) } - pub fn iter(&self) -> Box> + '_> { - if let Some(validity) = self.validity() { - Box::new((0..self.len()).map(|i| { - if validity.get_bit(i) { - let start = *self.offsets().get(i).unwrap() as usize; - let end = *self.offsets().get(i + 1).unwrap() as usize; - - Some(self.flat_child.slice(start, end).unwrap()) - } else { - None - } - })) - } else { - Box::new(self.offsets().windows(2).map(|w| { - let start = w[0] as usize; - let end = w[1] as usize; - - Some(self.flat_child.slice(start, end).unwrap()) - })) - } - } - pub fn to_arrow(&self) -> Box { let arrow_dtype = self.data_type().to_arrow().unwrap(); Box::new(arrow2::array::ListArray::new( @@ -200,3 +178,42 @@ impl ListArray { )) } } + +impl<'a> IntoIterator for &'a ListArray { + type Item = Option; + + type IntoIter = ListArrayIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + ListArrayIter { + array: self, + idx: 0, + } + } +} + +pub struct ListArrayIter<'a> { + array: &'a ListArray, + idx: usize, +} + +impl Iterator for ListArrayIter<'_> { + type Item = Option; + + fn next(&mut self) -> Option { + if self.idx < self.array.len() { + if let Some(validity) = self.array.validity() && !validity.get_bit(self.idx) { + self.idx += 1; + Some(None) + } else { + let start = *self.array.offsets().get(self.idx).unwrap() as usize; + let end = *self.array.offsets().get(self.idx + 1).unwrap() as usize; + + self.idx += 1; + Some(Some(self.array.flat_child.slice(start, end).unwrap())) + } + } else { + None + } + } +} diff --git a/src/daft-core/src/array/ops/list.rs b/src/daft-core/src/array/ops/list.rs index 1bfdc635c4..bc93cc642a 100644 --- a/src/daft-core/src/array/ops/list.rs +++ b/src/daft-core/src/array/ops/list.rs @@ -343,9 +343,11 @@ macro_rules! impl_aggs_list_array { where T: Fn(&Series) -> DaftResult, { + // TODO(Kevin): Currently this requires full materialization of one Series for every list. We could avoid this by implementing either sorted aggregation or an array builder + // Assumes `op`` returns a null Series given an empty Series let aggs = self - .iter() + .into_iter() .map(|s| s.unwrap_or(Series::empty("", self.child_data_type()))) .map(|s| op(&s)) .collect::>>()?; diff --git a/src/daft-core/src/datatypes/agg_ops.rs b/src/daft-core/src/datatypes/agg_ops.rs new file mode 100644 index 0000000000..48a89968b6 --- /dev/null +++ b/src/daft-core/src/datatypes/agg_ops.rs @@ -0,0 +1,31 @@ +use common_error::{DaftError, DaftResult}; + +use super::DataType; + +/// Get the data type that the sum of a column of the given data type should be casted to. +pub fn try_sum_supertype(dtype: &DataType) -> DaftResult { + use DataType::*; + match dtype { + Int8 | Int16 | Int32 | Int64 => Ok(Int64), + UInt8 | UInt16 | UInt32 | UInt64 => Ok(UInt64), + Float32 => Ok(Float32), + Float64 => Ok(Float64), + other => Err(DaftError::TypeError(format!( + "Invalid argument to sum supertype: {}", + other + ))), + } +} + +/// Get the data type that the mean of a column of the given data type should be casted to. +pub fn try_mean_supertype(dtype: &DataType) -> DaftResult { + use DataType::*; + if dtype.is_numeric() { + Ok(Float64) + } else { + Err(DaftError::TypeError(format!( + "Invalid argument to mean supertype: {}", + dtype + ))) + } +} diff --git a/src/daft-core/src/datatypes/binary_ops.rs b/src/daft-core/src/datatypes/binary_ops.rs index a0c8f1bc91..6d08a6312c 100644 --- a/src/daft-core/src/datatypes/binary_ops.rs +++ b/src/daft-core/src/datatypes/binary_ops.rs @@ -307,31 +307,3 @@ pub fn try_numeric_supertype(l: &DataType, r: &DataType) -> DaftResult l, r ))) } - -/// Get the data type that the sum of a column of the given data type should be casted to. -pub fn try_sum_supertype(dtype: &DataType) -> DaftResult { - use DataType::*; - match dtype { - Int8 | Int16 | Int32 | Int64 => Ok(Int64), - UInt8 | UInt16 | UInt32 | UInt64 => Ok(UInt64), - Float32 => Ok(Float32), - Float64 => Ok(Float64), - other => Err(DaftError::TypeError(format!( - "Invalid argument to sum supertype: {}", - other - ))), - } -} - -/// Get the data type that the mean of a column of the given data type should be casted to. -pub fn try_mean_supertype(dtype: &DataType) -> DaftResult { - use DataType::*; - if dtype.is_numeric() { - Ok(Float64) - } else { - Err(DaftError::TypeError(format!( - "Invalid argument to mean supertype: {}", - dtype - ))) - } -} diff --git a/src/daft-core/src/datatypes/mod.rs b/src/daft-core/src/datatypes/mod.rs index 2ed24633a5..3b937fbbc5 100644 --- a/src/daft-core/src/datatypes/mod.rs +++ b/src/daft-core/src/datatypes/mod.rs @@ -1,3 +1,4 @@ +mod agg_ops; mod binary_ops; mod dtype; mod field; @@ -8,11 +9,12 @@ mod time_unit; pub use crate::array::{DataArray, FixedSizeListArray}; use crate::array::{ListArray, StructArray}; +pub use agg_ops::{try_mean_supertype, try_sum_supertype}; use arrow2::{ compute::comparison::Simd8, types::{simd::Simd, NativeType}, }; -pub use binary_ops::{try_mean_supertype, try_physical_supertype, try_sum_supertype}; +pub use binary_ops::try_physical_supertype; pub use dtype::DataType; pub use field::Field; pub use field::FieldID;