Skip to content

Commit

Permalink
implement IntoIterator for list types
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinzwang committed Apr 10, 2024
1 parent 005e434 commit ce971a8
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 76 deletions.
65 changes: 41 additions & 24 deletions src/daft-core/src/array/fixed_size_list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,30 +132,6 @@ impl FixedSizeListArray {
))
}

/// Returns a boxed iterator that yields one item per row of this array.
///
/// A row whose validity bit is unset yields `None`. Every other row (and
/// every row when there is no validity bitmap) yields `Some` of the slice
/// of `flat_child` from `i * fixed_element_len()` to
/// `(i + 1) * fixed_element_len()`.
///
/// # Panics
///
/// Panics if slicing `flat_child` returns an error.
pub fn iter(&self) -> Box<dyn Iterator<Item = Option<Series>> + '_> {
    let width = self.fixed_element_len();

    // Shared by both branches: materialize the child slice backing `row`.
    let child_at = move |row: usize| {
        let lo = row * width;
        let hi = (row + 1) * width;
        self.flat_child.slice(lo, hi).unwrap()
    };

    match self.validity() {
        Some(bitmap) => Box::new(
            (0..self.len()).map(move |row| bitmap.get_bit(row).then(|| child_at(row))),
        ),
        None => Box::new((0..self.len()).map(move |row| Some(child_at(row)))),
    }
}

pub fn to_arrow(&self) -> Box<dyn arrow2::array::Array> {
let arrow_dtype = self.data_type().to_arrow().unwrap();
Box::new(arrow2::array::FixedSizeListArray::new(
Expand Down Expand Up @@ -190,6 +166,47 @@ impl FixedSizeListArray {
}
}

/// Lets `&FixedSizeListArray` be used directly in `for` loops and iterator
/// chains; each item is `Option<Series>`.
impl<'a> IntoIterator for &'a FixedSizeListArray {
    type Item = Option<Series>;
    type IntoIter = FixedSizeListArrayIter<'a>;

    /// Creates an iterator positioned at row 0 of the borrowed array.
    fn into_iter(self) -> Self::IntoIter {
        let array = self;
        FixedSizeListArrayIter { array, idx: 0 }
    }
}

/// Iterator over a borrowed [`FixedSizeListArray`], yielding `Option<Series>` items.
pub struct FixedSizeListArrayIter<'a> {
/// The array being iterated over.
array: &'a FixedSizeListArray,
/// Index of the next row to yield; starts at 0 and is advanced by `next`.
idx: usize,
}

impl Iterator for FixedSizeListArrayIter<'_> {
type Item = Option<Series>;

fn next(&mut self) -> Option<Self::Item> {
if self.idx < self.array.len() {
if let Some(validity) = self.array.validity() && !validity.get_bit(self.idx) {
self.idx += 1;
Some(None)
} else {
let step = self.array.fixed_element_len();

let start = self.idx * step;
let end = (self.idx + 1) * step;

self.idx += 1;
Some(Some(self.array.flat_child.slice(start, end).unwrap()))
}
} else {
None
}
}
}

#[cfg(test)]
mod tests {
use common_error::DaftResult;
Expand Down
61 changes: 39 additions & 22 deletions src/daft-core/src/array/list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,28 +151,6 @@ impl ListArray {
))
}

/// Returns a boxed iterator that yields one item per list in this array.
///
/// With a validity bitmap, rows `0..len()` are visited by index: a row whose
/// bit is unset yields `None`, otherwise `Some` of the slice of `flat_child`
/// between the row's two consecutive offsets. Without a validity bitmap,
/// every consecutive pair of offsets yields `Some` of the matching slice.
///
/// # Panics
///
/// Panics if an offset lookup is out of bounds or if slicing `flat_child`
/// returns an error.
pub fn iter(&self) -> Box<dyn Iterator<Item = Option<Series>> + '_> {
    match self.validity() {
        Some(bitmap) => Box::new((0..self.len()).map(move |row| {
            bitmap.get_bit(row).then(|| {
                let lo = *self.offsets().get(row).unwrap() as usize;
                let hi = *self.offsets().get(row + 1).unwrap() as usize;

                self.flat_child.slice(lo, hi).unwrap()
            })
        })),
        None => Box::new(self.offsets().windows(2).map(move |bounds| {
            let lo = bounds[0] as usize;
            let hi = bounds[1] as usize;

            Some(self.flat_child.slice(lo, hi).unwrap())
        })),
    }
}

pub fn to_arrow(&self) -> Box<dyn arrow2::array::Array> {
let arrow_dtype = self.data_type().to_arrow().unwrap();
Box::new(arrow2::array::ListArray::new(
Expand Down Expand Up @@ -200,3 +178,42 @@ impl ListArray {
))
}
}

/// Lets `&ListArray` be used directly in `for` loops and iterator chains;
/// each item is `Option<Series>`.
impl<'a> IntoIterator for &'a ListArray {
    type Item = Option<Series>;
    type IntoIter = ListArrayIter<'a>;

    /// Creates an iterator positioned at row 0 of the borrowed array.
    fn into_iter(self) -> Self::IntoIter {
        let array = self;
        ListArrayIter { array, idx: 0 }
    }
}

/// Iterator over a borrowed [`ListArray`], yielding `Option<Series>` items.
pub struct ListArrayIter<'a> {
/// The array being iterated over.
array: &'a ListArray,
/// Index of the next row to yield; starts at 0 and is advanced by `next`.
idx: usize,
}

impl Iterator for ListArrayIter<'_> {
type Item = Option<Series>;

fn next(&mut self) -> Option<Self::Item> {
if self.idx < self.array.len() {
if let Some(validity) = self.array.validity() && !validity.get_bit(self.idx) {
self.idx += 1;
Some(None)
} else {
let start = *self.array.offsets().get(self.idx).unwrap() as usize;
let end = *self.array.offsets().get(self.idx + 1).unwrap() as usize;

self.idx += 1;
Some(Some(self.array.flat_child.slice(start, end).unwrap()))
}
} else {
None
}
}
}
4 changes: 3 additions & 1 deletion src/daft-core/src/array/ops/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,9 +343,11 @@ macro_rules! impl_aggs_list_array {
where
T: Fn(&Series) -> DaftResult<Series>,
{
// TODO(Kevin): Currently this requires full materialization of one Series for every list. We could avoid this by implementing either sorted aggregation or an array builder

// Assumes `op` returns a null Series given an empty Series
let aggs = self
.iter()
.into_iter()
.map(|s| s.unwrap_or(Series::empty("", self.child_data_type())))
.map(|s| op(&s))
.collect::<DaftResult<Vec<_>>>()?;
Expand Down
31 changes: 31 additions & 0 deletions src/daft-core/src/datatypes/agg_ops.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use common_error::{DaftError, DaftResult};

use super::DataType;

/// Returns the data type that a column of `dtype` should be cast to before summing.
///
/// Signed integers map to `Int64`, unsigned integers to `UInt64`, and
/// `Float32` / `Float64` map to themselves.
///
/// # Errors
///
/// Returns `DaftError::TypeError` for any other data type.
pub fn try_sum_supertype(dtype: &DataType) -> DaftResult<DataType> {
    use DataType::*;

    let supertype = match dtype {
        Int8 | Int16 | Int32 | Int64 => Int64,
        UInt8 | UInt16 | UInt32 | UInt64 => UInt64,
        Float32 => Float32,
        Float64 => Float64,
        unsupported => {
            return Err(DaftError::TypeError(format!(
                "Invalid argument to sum supertype: {}",
                unsupported
            )));
        }
    };
    Ok(supertype)
}

/// Get the data type that the mean of a column of the given data type should be casted to.
pub fn try_mean_supertype(dtype: &DataType) -> DaftResult<DataType> {
use DataType::*;
if dtype.is_numeric() {
Ok(Float64)
} else {
Err(DaftError::TypeError(format!(
"Invalid argument to mean supertype: {}",
dtype
)))
}
}
28 changes: 0 additions & 28 deletions src/daft-core/src/datatypes/binary_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,31 +307,3 @@ pub fn try_numeric_supertype(l: &DataType, r: &DataType) -> DaftResult<DataType>
l, r
)))
}

/// Returns the data type that a column of `dtype` should be cast to before summing.
///
/// Signed integers map to `Int64`, unsigned integers to `UInt64`, and
/// `Float32` / `Float64` map to themselves.
///
/// # Errors
///
/// Returns `DaftError::TypeError` for any other data type.
pub fn try_sum_supertype(dtype: &DataType) -> DaftResult<DataType> {
    use DataType::*;

    let supertype = match dtype {
        Int8 | Int16 | Int32 | Int64 => Int64,
        UInt8 | UInt16 | UInt32 | UInt64 => UInt64,
        Float32 => Float32,
        Float64 => Float64,
        unsupported => {
            return Err(DaftError::TypeError(format!(
                "Invalid argument to sum supertype: {}",
                unsupported
            )));
        }
    };
    Ok(supertype)
}

/// Get the data type that the mean of a column of the given data type should be casted to.
pub fn try_mean_supertype(dtype: &DataType) -> DaftResult<DataType> {
use DataType::*;
if dtype.is_numeric() {
Ok(Float64)
} else {
Err(DaftError::TypeError(format!(
"Invalid argument to mean supertype: {}",
dtype
)))
}
}
4 changes: 3 additions & 1 deletion src/daft-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod agg_ops;
mod binary_ops;
mod dtype;
mod field;
Expand All @@ -8,11 +9,12 @@ mod time_unit;

pub use crate::array::{DataArray, FixedSizeListArray};
use crate::array::{ListArray, StructArray};
pub use agg_ops::{try_mean_supertype, try_sum_supertype};
use arrow2::{
compute::comparison::Simd8,
types::{simd::Simd, NativeType},
};
pub use binary_ops::{try_mean_supertype, try_physical_supertype, try_sum_supertype};
pub use binary_ops::try_physical_supertype;
pub use dtype::DataType;
pub use field::Field;
pub use field::FieldID;
Expand Down

0 comments on commit ce971a8

Please sign in to comment.