Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Add basic list aggregations #2032

Merged
merged 11 commits into from
Apr 11, 2024
Merged
8 changes: 6 additions & 2 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -943,8 +943,12 @@ class PyExpr:
def image_resize(self, w: int, h: int) -> PyExpr: ...
def image_crop(self, bbox: PyExpr) -> PyExpr: ...
def list_join(self, delimiter: PyExpr) -> PyExpr: ...
def list_lengths(self) -> PyExpr: ...
def list_count(self, mode: CountMode) -> PyExpr: ...
def list_get(self, idx: PyExpr, default: PyExpr) -> PyExpr: ...
def list_sum(self) -> PyExpr: ...
def list_mean(self) -> PyExpr: ...
def list_min(self) -> PyExpr: ...
def list_max(self) -> PyExpr: ...
def struct_get(self, name: str) -> PyExpr: ...
def url_download(
self, max_connections: int, raise_error_on_failure: bool, multi_thread: bool, config: IOConfig
Expand Down Expand Up @@ -1037,7 +1041,7 @@ class PySeries:
def partitioning_years(self) -> PySeries: ...
def partitioning_iceberg_bucket(self, n: int) -> PySeries: ...
def partitioning_iceberg_truncate(self, w: int) -> PySeries: ...
def list_lengths(self) -> PySeries: ...
def list_count(self, mode: CountMode) -> PySeries: ...
def list_get(self, idx: PySeries, default: PySeries) -> PySeries: ...
def image_decode(self) -> PySeries: ...
def image_encode(self, image_format: ImageFormat) -> PySeries: ...
Expand Down
45 changes: 44 additions & 1 deletion daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,13 +868,24 @@ def join(self, delimiter: str | Expression) -> Expression:
delimiter_expr = Expression._to_expression(delimiter)
return Expression._from_pyexpr(self._expr.list_join(delimiter_expr._expr))

def count(self, mode: CountMode = CountMode.Valid) -> Expression:
"""Counts the number of elements in each list

Args:
mode: The mode to use for counting. Defaults to CountMode.Valid

Returns:
Expression: a UInt64 expression which is the length of each list
"""
return Expression._from_pyexpr(self._expr.list_count(mode))

def lengths(self) -> Expression:
"""Gets the length of each list

Returns:
Expression: a UInt64 expression which is the length of each list
"""
return Expression._from_pyexpr(self._expr.list_lengths())
return Expression._from_pyexpr(self._expr.list_count(CountMode.All))

def get(self, idx: int | Expression, default: object = None) -> Expression:
"""Gets the element at an index in each list
Expand All @@ -890,6 +901,38 @@ def get(self, idx: int | Expression, default: object = None) -> Expression:
default_expr = lit(default)
return Expression._from_pyexpr(self._expr.list_get(idx_expr._expr, default_expr._expr))

def sum(self) -> Expression:
"""Sums each list. Empty lists and lists with all nulls yield null.

Returns:
Expression: an expression with the type of the list values
"""
return Expression._from_pyexpr(self._expr.list_sum())

def mean(self) -> Expression:
"""Calculates the mean of each list. If no non-null values in a list, the result is null.

Returns:
Expression: a Float64 expression with the type of the list values
"""
return Expression._from_pyexpr(self._expr.list_mean())

def min(self) -> Expression:
"""Calculates the minimum of each list. If no non-null values in a list, the result is null.

Returns:
Expression: a Float64 expression with the type of the list values
"""
return Expression._from_pyexpr(self._expr.list_min())

def max(self) -> Expression:
"""Calculates the maximum of each list. If no non-null values in a list, the result is null.

Returns:
Expression: a Float64 expression with the type of the list values
"""
return Expression._from_pyexpr(self._expr.list_max())


class ExpressionStructNamespace(ExpressionNamespace):
def get(self, name: str) -> Expression:
Expand Down
2 changes: 1 addition & 1 deletion daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ def iceberg_truncate(self, w: int) -> Series:

class SeriesListNamespace(SeriesNamespace):
def lengths(self) -> Series:
return Series._from_pyseries(self._series.list_lengths())
return Series._from_pyseries(self._series.list_count(CountMode.All))

def get(self, idx: Series, default: Series) -> Series:
return Series._from_pyseries(self._series.list_get(idx._series, default._series))
Expand Down
24 changes: 24 additions & 0 deletions src/daft-core/src/array/fixed_size_list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,30 @@ impl FixedSizeListArray {
))
}

pub fn iter(&self) -> Box<dyn Iterator<Item = Option<Series>> + '_> {
kevinzwang marked this conversation as resolved.
Show resolved Hide resolved
let step = self.fixed_element_len();

if let Some(validity) = self.validity() {
Box::new((0..self.len()).map(move |i| {
if validity.get_bit(i) {
let start = i * step;
let end = (i + 1) * step;

Some(self.flat_child.slice(start, end).unwrap())
} else {
None
}
}))
} else {
Box::new((0..self.len()).map(move |i| {
let start = i * step;
let end = (i + 1) * step;

Some(self.flat_child.slice(start, end).unwrap())
}))
}
}

pub fn to_arrow(&self) -> Box<dyn arrow2::array::Array> {
let arrow_dtype = self.data_type().to_arrow().unwrap();
Box::new(arrow2::array::FixedSizeListArray::new(
Expand Down
22 changes: 22 additions & 0 deletions src/daft-core/src/array/list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,28 @@ impl ListArray {
))
}

pub fn iter(&self) -> Box<dyn Iterator<Item = Option<Series>> + '_> {
kevinzwang marked this conversation as resolved.
Show resolved Hide resolved
if let Some(validity) = self.validity() {
Box::new((0..self.len()).map(|i| {
if validity.get_bit(i) {
let start = *self.offsets().get(i).unwrap() as usize;
let end = *self.offsets().get(i + 1).unwrap() as usize;

Some(self.flat_child.slice(start, end).unwrap())
} else {
None
}
}))
} else {
Box::new(self.offsets().windows(2).map(|w| {
let start = w[0] as usize;
let end = w[1] as usize;

Some(self.flat_child.slice(start, end).unwrap())
}))
}
}

pub fn to_arrow(&self) -> Box<dyn arrow2::array::Array> {
let arrow_dtype = self.data_type().to_arrow().unwrap();
Box::new(arrow2::array::ListArray::new(
Expand Down
2 changes: 1 addition & 1 deletion src/daft-core/src/array/ops/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ fn grouped_count_arrow_bitmap(
.iter()
.map(|g| {
g.iter()
.map(|i| validity.get_bit(!*i as usize) as u64)
.map(|i| !validity.get_bit(*i as usize) as u64)
.sum()
})
.collect(),
Expand Down
129 changes: 101 additions & 28 deletions src/daft-core/src/array/ops/list.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use std::iter::repeat;

use crate::array::{
growable::{make_growable, Growable},
FixedSizeListArray, ListArray,
use crate::datatypes::{Int64Array, Utf8Array};
use crate::{
array::{
growable::{make_growable, Growable},
FixedSizeListArray, ListArray,
},
datatypes::UInt64Array,
};
use crate::datatypes::{Int64Array, UInt64Array, Utf8Array};
use crate::DataType;
use crate::{CountMode, DataType};

use crate::series::Series;

Expand Down Expand Up @@ -42,11 +45,34 @@ fn join_arrow_list_of_utf8s(
}

impl ListArray {
pub fn lengths(&self) -> DaftResult<UInt64Array> {
let lengths = self.offsets().lengths().map(|l| Some(l as u64));
pub fn count(&self, mode: CountMode) -> DaftResult<UInt64Array> {
let counts = match (mode, self.flat_child.validity()) {
(CountMode::All, _) | (CountMode::Valid, None) => {
self.offsets().lengths().map(|l| l as u64).collect()
}
(CountMode::Valid, Some(validity)) => self
.offsets()
kevinzwang marked this conversation as resolved.
Show resolved Hide resolved
.windows(2)
.map(|w| {
(w[0]..w[1])
.map(|i| validity.get_bit(i as usize) as u64)
.sum()
})
.collect(),
(CountMode::Null, None) => repeat(0).take(self.offsets().len() - 1).collect(),
(CountMode::Null, Some(validity)) => self
.offsets()
kevinzwang marked this conversation as resolved.
Show resolved Hide resolved
.windows(2)
.map(|w| {
(w[0]..w[1])
.map(|i| !validity.get_bit(i as usize) as u64)
.sum()
})
.collect(),
};

let array = Box::new(
arrow2::array::PrimitiveArray::from_iter(lengths)
.with_validity(self.validity().cloned()),
arrow2::array::PrimitiveArray::from_vec(counts).with_validity(self.validity().cloned()),
);
Ok(UInt64Array::from((self.name(), array)))
}
Expand Down Expand Up @@ -172,27 +198,33 @@ impl ListArray {
}

impl FixedSizeListArray {
pub fn lengths(&self) -> DaftResult<UInt64Array> {
pub fn count(&self, mode: CountMode) -> DaftResult<UInt64Array> {
let size = self.fixed_element_len();
match self.validity() {
None => Ok(UInt64Array::from((
self.name(),
repeat(size as u64)
.take(self.len())
.collect::<Vec<_>>()
.as_slice(),
))),
Some(validity) => {
let arrow_arr = arrow2::array::UInt64Array::from_iter(validity.iter().map(|v| {
if v {
Some(size as u64)
} else {
None
}
}));
Ok(UInt64Array::from((self.name(), Box::new(arrow_arr))))
let counts = match (mode, self.flat_child.validity()) {
(CountMode::All, _) | (CountMode::Valid, None) => {
repeat(size as u64).take(self.len()).collect()
}
}
(CountMode::Valid, Some(validity)) => (0..self.len())
.map(|i| {
(0..size)
.map(|j| validity.get_bit(i * size + j) as u64)
.sum()
})
.collect(),
(CountMode::Null, None) => repeat(0).take(self.len()).collect(),
(CountMode::Null, Some(validity)) => (0..self.len())
.map(|i| {
(0..size)
.map(|j| !validity.get_bit(i * size + j) as u64)
.sum()
})
.collect(),
};

let array = Box::new(
arrow2::array::PrimitiveArray::from_vec(counts).with_validity(self.validity().cloned()),
);
Ok(UInt64Array::from((self.name(), array)))
}

pub fn explode(&self) -> DaftResult<Series> {
Expand Down Expand Up @@ -303,3 +335,44 @@ impl FixedSizeListArray {
}
}
}

macro_rules! impl_aggs_list_array {
($la:ident) => {
impl $la {
fn agg_helper<T>(&self, op: T) -> DaftResult<Series>
where
T: Fn(&Series) -> DaftResult<Series>,
{
// Assumes `op`` returns a null Series given an empty Series
let aggs = self
.iter()
.map(|s| s.unwrap_or(Series::empty("", self.child_data_type())))
.map(|s| op(&s))
.collect::<DaftResult<Vec<_>>>()?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we probably want to have a concat that takes in an iterator of Series for this. materializing all the Series then concating it is going to be slow.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tabled until we create an ArrayBuilder so that we don't have to fully materialize the Series objects to concat


let agg_refs: Vec<_> = aggs.iter().collect();

Series::concat(agg_refs.as_slice()).map(|s| s.rename(self.name()))
}

pub fn sum(&self) -> DaftResult<Series> {
self.agg_helper(|s| s.sum(None))
}

pub fn mean(&self) -> DaftResult<Series> {
self.agg_helper(|s| s.mean(None))
}

pub fn min(&self) -> DaftResult<Series> {
self.agg_helper(|s| s.min(None))
}

pub fn max(&self) -> DaftResult<Series> {
self.agg_helper(|s| s.max(None))
}
}
};
}

impl_aggs_list_array!(ListArray);
impl_aggs_list_array!(FixedSizeListArray);
28 changes: 28 additions & 0 deletions src/daft-core/src/datatypes/binary_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,3 +307,31 @@ pub fn try_numeric_supertype(l: &DataType, r: &DataType) -> DaftResult<DataType>
l, r
)))
}

/// Get the data type that the sum of a column of the given data type should be casted to.
pub fn try_sum_supertype(dtype: &DataType) -> DaftResult<DataType> {
kevinzwang marked this conversation as resolved.
Show resolved Hide resolved
use DataType::*;
match dtype {
Int8 | Int16 | Int32 | Int64 => Ok(Int64),
UInt8 | UInt16 | UInt32 | UInt64 => Ok(UInt64),
Float32 => Ok(Float32),
Float64 => Ok(Float64),
other => Err(DaftError::TypeError(format!(
"Invalid argument to sum supertype: {}",
other
))),
}
}

/// Get the data type that the mean of a column of the given data type should be casted to.
pub fn try_mean_supertype(dtype: &DataType) -> DaftResult<DataType> {
use DataType::*;
if dtype.is_numeric() {
Ok(Float64)
} else {
Err(DaftError::TypeError(format!(
"Invalid argument to mean supertype: {}",
dtype
)))
}
}
2 changes: 1 addition & 1 deletion src/daft-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use arrow2::{
compute::comparison::Simd8,
types::{simd::Simd, NativeType},
};
pub use binary_ops::try_physical_supertype;
pub use binary_ops::{try_mean_supertype, try_physical_supertype, try_sum_supertype};
pub use dtype::DataType;
pub use field::Field;
pub use field::FieldID;
Expand Down
4 changes: 2 additions & 2 deletions src/daft-core/src/python/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,8 @@ impl PySeries {
Ok(self.series.murmur3_32()?.into_series().into())
}

pub fn list_lengths(&self) -> PyResult<Self> {
Ok(self.series.list_lengths()?.into_series().into())
pub fn list_count(&self, mode: CountMode) -> PyResult<Self> {
Ok(self.series.list_count(mode)?.into_series().into())
}

pub fn list_get(&self, idx: &Self, default: &Self) -> PyResult<Self> {
Expand Down
Loading
Loading