From 005e434194208e8b51c72d20ac2745b78c9ffe8d Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Fri, 5 Apr 2024 11:13:51 -0700 Subject: [PATCH] move iterator logic out and use macros --- .../src/array/fixed_size_list_array.rs | 24 +++ src/daft-core/src/array/list_array.rs | 22 +++ src/daft-core/src/array/ops/list.rs | 140 +++++------------- 3 files changed, 82 insertions(+), 104 deletions(-) diff --git a/src/daft-core/src/array/fixed_size_list_array.rs b/src/daft-core/src/array/fixed_size_list_array.rs index a8faba1a19..7d79d605ad 100644 --- a/src/daft-core/src/array/fixed_size_list_array.rs +++ b/src/daft-core/src/array/fixed_size_list_array.rs @@ -132,6 +132,30 @@ impl FixedSizeListArray { )) } + pub fn iter(&self) -> Box<dyn Iterator<Item = Option<Series>> + '_> { + let step = self.fixed_element_len(); + + if let Some(validity) = self.validity() { + Box::new((0..self.len()).map(move |i| { + if validity.get_bit(i) { + let start = i * step; + let end = (i + 1) * step; + + Some(self.flat_child.slice(start, end).unwrap()) + } else { + None + } + })) + } else { + Box::new((0..self.len()).map(move |i| { + let start = i * step; + let end = (i + 1) * step; + + Some(self.flat_child.slice(start, end).unwrap()) + })) + } + } + pub fn to_arrow(&self) -> Box<dyn arrow2::array::Array> { let arrow_dtype = self.data_type().to_arrow().unwrap(); Box::new(arrow2::array::FixedSizeListArray::new( diff --git a/src/daft-core/src/array/list_array.rs b/src/daft-core/src/array/list_array.rs index 766f6d5d6b..fb13283e0d 100644 --- a/src/daft-core/src/array/list_array.rs +++ b/src/daft-core/src/array/list_array.rs @@ -151,6 +151,28 @@ impl ListArray { )) } + pub fn iter(&self) -> Box<dyn Iterator<Item = Option<Series>> + '_> { + if let Some(validity) = self.validity() { + Box::new((0..self.len()).map(|i| { + if validity.get_bit(i) { + let start = *self.offsets().get(i).unwrap() as usize; + let end = *self.offsets().get(i + 1).unwrap() as usize; + + Some(self.flat_child.slice(start, end).unwrap()) + } else { + None + } + })) + } else { + Box::new(self.offsets().windows(2).map(|w| { 
let start = w[0] as usize; + let end = w[1] as usize; + + Some(self.flat_child.slice(start, end).unwrap()) + })) + } + } + pub fn to_arrow(&self) -> Box<dyn arrow2::array::Array> { let arrow_dtype = self.data_type().to_arrow().unwrap(); Box::new(arrow2::array::ListArray::new( diff --git a/src/daft-core/src/array/ops/list.rs b/src/daft-core/src/array/ops/list.rs index e6bac8a497..1bfdc635c4 100644 --- a/src/daft-core/src/array/ops/list.rs +++ b/src/daft-core/src/array/ops/list.rs @@ -195,60 +195,6 @@ impl ListArray { } } } - - fn agg_helper<T>(&self, op: T) -> DaftResult<Series> - where - T: Fn(&Series) -> DaftResult<Series>, - { - let aggs = if let Some(validity) = self.validity() { - let test_result = op(&Series::empty("", self.flat_child.data_type()))?; - - (0..self.len()) - .map(|i| { - if validity.get_bit(i) { - let start = *self.offsets().get(i).unwrap() as usize; - let end = *self.offsets().get(i + 1).unwrap() as usize; - - let slice = self.flat_child.slice(start, end)?; - op(&slice) - } else { - Ok(Series::full_null("", test_result.data_type(), 1)) - } - }) - .collect::<DaftResult<Vec<_>>>()? - } else { - self.offsets() - .windows(2) - .map(|w| { - let start = w[0] as usize; - let end = w[1] as usize; - - let slice = self.flat_child.slice(start, end)?; - op(&slice) - }) - .collect::<DaftResult<Vec<_>>>()? 
- }; - - let agg_refs: Vec<_> = aggs.iter().collect(); - - Ok(Series::concat(agg_refs.as_slice())?.rename(self.name())) - } - - pub fn sum(&self) -> DaftResult<Series> { - self.agg_helper(|s| s.sum(None)) - } - - pub fn mean(&self) -> DaftResult<Series> { - self.agg_helper(|s| s.mean(None)) - } - - pub fn min(&self) -> DaftResult<Series> { - self.agg_helper(|s| s.min(None)) - } - - pub fn max(&self) -> DaftResult<Series> { - self.agg_helper(|s| s.max(None)) - } } impl FixedSizeListArray { @@ -388,59 +334,45 @@ impl FixedSizeListArray { } } } +} - fn agg_helper<T>(&self, op: T) -> DaftResult<Series> - where - T: Fn(&Series) -> DaftResult<Series>, - { - let step = self.fixed_element_len(); - - let aggs = if let Some(validity) = self.validity() { - let test_result = op(&Series::empty("", self.flat_child.data_type()))?; - - (0..self.len()) - .map(|i| { - if validity.get_bit(i) { - let start = i * step; - let end = (i + 1) * step; - - let slice = self.flat_child.slice(start, end)?; - op(&slice) - } else { - Ok(Series::full_null("", test_result.data_type(), 1)) - } - }) - .collect::<DaftResult<Vec<_>>>()? - } else { - (0..self.len()) - .map(|i| { - let start = i * step; - let end = (i + 1) * step; - - let slice = self.flat_child.slice(start, end)?; - op(&slice) - }) - .collect::<DaftResult<Vec<_>>>()? - }; - - let agg_refs: Vec<_> = aggs.iter().collect(); - - Series::concat(agg_refs.as_slice()).map(|s| s.rename(self.name())) - } +macro_rules! 
impl_aggs_list_array { + ($la:ident) => { + impl $la { + fn agg_helper<T>(&self, op: T) -> DaftResult<Series> + where + T: Fn(&Series) -> DaftResult<Series>, + { + // Assumes `op` returns a null Series given an empty Series + let aggs = self + .iter() + .map(|s| s.unwrap_or(Series::empty("", self.child_data_type()))) + .map(|s| op(&s)) + .collect::<DaftResult<Vec<_>>>()?; + + let agg_refs: Vec<_> = aggs.iter().collect(); + + Series::concat(agg_refs.as_slice()).map(|s| s.rename(self.name())) + } - pub fn sum(&self) -> DaftResult<Series> { - self.agg_helper(|s| s.sum(None)) - } + pub fn sum(&self) -> DaftResult<Series> { + self.agg_helper(|s| s.sum(None)) + } - pub fn mean(&self) -> DaftResult<Series> { - self.agg_helper(|s| s.mean(None)) - } + pub fn mean(&self) -> DaftResult<Series> { + self.agg_helper(|s| s.mean(None)) + } - pub fn min(&self) -> DaftResult<Series> { - self.agg_helper(|s| s.min(None)) - } + pub fn min(&self) -> DaftResult<Series> { + self.agg_helper(|s| s.min(None)) + } - pub fn max(&self) -> DaftResult<Series> { - self.agg_helper(|s| s.max(None)) - } + pub fn max(&self) -> DaftResult<Series> { + self.agg_helper(|s| s.max(None)) + } + } + }; } + +impl_aggs_list_array!(ListArray); +impl_aggs_list_array!(FixedSizeListArray);