Skip to content

Commit

Permalink
move iterator logic out and use macros
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinzwang committed Apr 5, 2024
1 parent c40a90c commit 005e434
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 104 deletions.
24 changes: 24 additions & 0 deletions src/daft-core/src/array/fixed_size_list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,30 @@ impl FixedSizeListArray {
))
}

/// Iterates over the rows of this array, yielding the child values of each row.
///
/// Row `i` covers `i * fixed_element_len() .. (i + 1) * fixed_element_len()` of
/// `flat_child`. A row whose validity bit is unset yields `None`; when the array has no
/// validity bitmap every row yields `Some`.
///
/// # Panics
///
/// Panics if slicing `flat_child` fails for a valid row.
pub fn iter(&self) -> Box<dyn Iterator<Item = Option<Series>> + '_> {
    let width = self.fixed_element_len();
    let validity = self.validity();

    Box::new((0..self.len()).map(move |row| {
        // A missing bitmap means "all rows valid".
        let is_valid = match &validity {
            Some(bits) => bits.get_bit(row),
            None => true,
        };

        is_valid.then(|| {
            let first = row * width;
            let past_last = (row + 1) * width;

            self.flat_child.slice(first, past_last).unwrap()
        })
    }))
}

pub fn to_arrow(&self) -> Box<dyn arrow2::array::Array> {
let arrow_dtype = self.data_type().to_arrow().unwrap();
Box::new(arrow2::array::FixedSizeListArray::new(
Expand Down
22 changes: 22 additions & 0 deletions src/daft-core/src/array/list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,28 @@ impl ListArray {
))
}

/// Iterates over the rows of this array, yielding the child values of each row.
///
/// With a validity bitmap, row `i` yields `Some` of the `flat_child` slice bounded by
/// `offsets()[i]` and `offsets()[i + 1]` when its bit is set, and `None` otherwise.
/// Without a bitmap, one `Some` slice is produced per adjacent pair of offsets.
///
/// # Panics
///
/// Panics if an offset lookup is out of range or if slicing `flat_child` fails.
pub fn iter(&self) -> Box<dyn Iterator<Item = Option<Series>> + '_> {
    match self.validity() {
        Some(validity) => Box::new((0..self.len()).map(move |row| {
            validity.get_bit(row).then(|| {
                let lo = *self.offsets().get(row).unwrap() as usize;
                let hi = *self.offsets().get(row + 1).unwrap() as usize;

                self.flat_child.slice(lo, hi).unwrap()
            })
        })),
        None => Box::new(self.offsets().windows(2).map(move |bounds| {
            let (lo, hi) = (bounds[0] as usize, bounds[1] as usize);

            Some(self.flat_child.slice(lo, hi).unwrap())
        })),
    }
}

pub fn to_arrow(&self) -> Box<dyn arrow2::array::Array> {
let arrow_dtype = self.data_type().to_arrow().unwrap();
Box::new(arrow2::array::ListArray::new(
Expand Down
140 changes: 36 additions & 104 deletions src/daft-core/src/array/ops/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,60 +195,6 @@ impl ListArray {
}
}
}

/// Applies `op` to the child values of every row and concatenates the per-row results
/// into a single Series named after this array.
///
/// When a validity bitmap is present, `op` is first run once on an empty Series to
/// discover its output data type; rows whose validity bit is unset then contribute a
/// one-element all-null Series of that type instead of calling `op`.
///
/// # Errors
///
/// Returns the first error produced by slicing `flat_child`, by `op`, or by the final
/// concatenation.
fn agg_helper<T>(&self, op: T) -> DaftResult<Series>
where
    T: Fn(&Series) -> DaftResult<Series>,
{
    let mut per_row: Vec<Series> = Vec::new();

    match self.validity() {
        Some(validity) => {
            // Only the data type of this probe result is used, to build null placeholders.
            let probe = op(&Series::empty("", self.flat_child.data_type()))?;

            for row in 0..self.len() {
                let agg = if validity.get_bit(row) {
                    let lo = *self.offsets().get(row).unwrap() as usize;
                    let hi = *self.offsets().get(row + 1).unwrap() as usize;

                    let values = self.flat_child.slice(lo, hi)?;
                    op(&values)?
                } else {
                    Series::full_null("", probe.data_type(), 1)
                };
                per_row.push(agg);
            }
        }
        None => {
            for bounds in self.offsets().windows(2) {
                let lo = bounds[0] as usize;
                let hi = bounds[1] as usize;

                let values = self.flat_child.slice(lo, hi)?;
                per_row.push(op(&values)?);
            }
        }
    }

    let refs: Vec<&Series> = per_row.iter().collect();
    let combined = Series::concat(refs.as_slice())?;

    Ok(combined.rename(self.name()))
}

/// Computes a per-row sum: `agg_helper` passes each row's child values to
/// `Series::sum(None)` and returns the concatenated results.
pub fn sum(&self) -> DaftResult<Series> {
self.agg_helper(|s| s.sum(None))
}

/// Computes a per-row mean: `agg_helper` passes each row's child values to
/// `Series::mean(None)` and returns the concatenated results.
pub fn mean(&self) -> DaftResult<Series> {
self.agg_helper(|s| s.mean(None))
}

/// Computes a per-row minimum: `agg_helper` passes each row's child values to
/// `Series::min(None)` and returns the concatenated results.
pub fn min(&self) -> DaftResult<Series> {
self.agg_helper(|s| s.min(None))
}

/// Computes a per-row maximum: `agg_helper` passes each row's child values to
/// `Series::max(None)` and returns the concatenated results.
pub fn max(&self) -> DaftResult<Series> {
self.agg_helper(|s| s.max(None))
}
}

impl FixedSizeListArray {
Expand Down Expand Up @@ -388,59 +334,45 @@ impl FixedSizeListArray {
}
}
}
}

/// Applies `op` to the child values of every fixed-width row and concatenates the
/// per-row results into a single Series named after this array.
///
/// Row `i` covers `i * fixed_element_len() .. (i + 1) * fixed_element_len()` of
/// `flat_child`. When a validity bitmap is present, `op` is first run once on an empty
/// Series to discover its output data type; rows whose validity bit is unset then
/// contribute a one-element all-null Series of that type instead of calling `op`.
///
/// # Errors
///
/// Returns the first error produced by slicing `flat_child`, by `op`, or by the final
/// concatenation.
fn agg_helper<T>(&self, op: T) -> DaftResult<Series>
where
    T: Fn(&Series) -> DaftResult<Series>,
{
    let width = self.fixed_element_len();
    let mut per_row: Vec<Series> = Vec::new();

    if let Some(validity) = self.validity() {
        // Only the data type of this probe result is used, to build null placeholders.
        let probe = op(&Series::empty("", self.flat_child.data_type()))?;

        for row in 0..self.len() {
            if validity.get_bit(row) {
                let values = self.flat_child.slice(row * width, (row + 1) * width)?;
                per_row.push(op(&values)?);
            } else {
                per_row.push(Series::full_null("", probe.data_type(), 1));
            }
        }
    } else {
        for row in 0..self.len() {
            let values = self.flat_child.slice(row * width, (row + 1) * width)?;
            per_row.push(op(&values)?);
        }
    }

    let refs: Vec<&Series> = per_row.iter().collect();
    let combined = Series::concat(refs.as_slice())?;

    Ok(combined.rename(self.name()))
}
/// Generates the per-row aggregation methods (`sum`, `mean`, `min`, `max`) for a
/// list-like array type.
///
/// The target type `$la` must provide `iter()` yielding one `Option<Series>` per row,
/// plus `child_data_type()` and `name()`, all of which the generated code calls.
macro_rules! impl_aggs_list_array {
    ($la:ident) => {
        impl $la {
            /// Applies `op` to every row and concatenates the per-row results into one
            /// Series named after this array.
            ///
            /// Rows for which `iter()` yields `None` are aggregated over an empty Series
            /// of the child data type.
            ///
            /// # Errors
            ///
            /// Returns the first error produced by `op` or by the final concatenation.
            fn agg_helper<T>(&self, op: T) -> DaftResult<Series>
            where
                T: Fn(&Series) -> DaftResult<Series>,
            {
                // Assumes `op` returns a null Series given an empty Series
                let aggs = self
                    .iter()
                    .map(|row| {
                        // Build the empty placeholder lazily so valid rows do not pay
                        // for a Series they never use.
                        let values = row
                            .unwrap_or_else(|| Series::empty("", self.child_data_type()));
                        op(&values)
                    })
                    .collect::<DaftResult<Vec<_>>>()?;

                let agg_refs: Vec<_> = aggs.iter().collect();

                Series::concat(agg_refs.as_slice()).map(|s| s.rename(self.name()))
            }

            /// Per-row sum, computed by calling `sum(None)` on each row's values.
            pub fn sum(&self) -> DaftResult<Series> {
                self.agg_helper(|s| s.sum(None))
            }

            /// Per-row mean, computed by calling `mean(None)` on each row's values.
            pub fn mean(&self) -> DaftResult<Series> {
                self.agg_helper(|s| s.mean(None))
            }

            /// Per-row minimum, computed by calling `min(None)` on each row's values.
            pub fn min(&self) -> DaftResult<Series> {
                self.agg_helper(|s| s.min(None))
            }

            /// Per-row maximum, computed by calling `max(None)` on each row's values.
            pub fn max(&self) -> DaftResult<Series> {
                self.agg_helper(|s| s.max(None))
            }
        }
    };
}

impl_aggs_list_array!(ListArray);
impl_aggs_list_array!(FixedSizeListArray);

0 comments on commit 005e434

Please sign in to comment.