Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Add basic list aggregations #2032

Merged
merged 11 commits into from
Apr 11, 2024
Merged
8 changes: 6 additions & 2 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -943,8 +943,12 @@ class PyExpr:
def image_resize(self, w: int, h: int) -> PyExpr: ...
def image_crop(self, bbox: PyExpr) -> PyExpr: ...
def list_join(self, delimiter: PyExpr) -> PyExpr: ...
def list_lengths(self) -> PyExpr: ...
def list_count(self, mode: CountMode) -> PyExpr: ...
def list_get(self, idx: PyExpr, default: PyExpr) -> PyExpr: ...
def list_sum(self) -> PyExpr: ...
def list_mean(self) -> PyExpr: ...
def list_min(self) -> PyExpr: ...
def list_max(self) -> PyExpr: ...
def struct_get(self, name: str) -> PyExpr: ...
def url_download(
self, max_connections: int, raise_error_on_failure: bool, multi_thread: bool, config: IOConfig
Expand Down Expand Up @@ -1037,7 +1041,7 @@ class PySeries:
def partitioning_years(self) -> PySeries: ...
def partitioning_iceberg_bucket(self, n: int) -> PySeries: ...
def partitioning_iceberg_truncate(self, w: int) -> PySeries: ...
def list_lengths(self) -> PySeries: ...
def list_count(self, mode: CountMode) -> PySeries: ...
def list_get(self, idx: PySeries, default: PySeries) -> PySeries: ...
def image_decode(self) -> PySeries: ...
def image_encode(self, image_format: ImageFormat) -> PySeries: ...
Expand Down
45 changes: 44 additions & 1 deletion daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,13 +868,24 @@ def join(self, delimiter: str | Expression) -> Expression:
delimiter_expr = Expression._to_expression(delimiter)
return Expression._from_pyexpr(self._expr.list_join(delimiter_expr._expr))

def count(self, mode: CountMode = CountMode.Valid) -> Expression:
"""Counts the number of elements in each list

Args:
mode: The mode to use for counting. Defaults to CountMode.Valid

Returns:
Expression: a UInt64 expression which is the length of each list
"""
return Expression._from_pyexpr(self._expr.list_count(mode))

def lengths(self) -> Expression:
"""Gets the length of each list

Returns:
Expression: a UInt64 expression which is the length of each list
"""
return Expression._from_pyexpr(self._expr.list_lengths())
return Expression._from_pyexpr(self._expr.list_count(CountMode.All))

def get(self, idx: int | Expression, default: object = None) -> Expression:
"""Gets the element at an index in each list
Expand All @@ -890,6 +901,38 @@ def get(self, idx: int | Expression, default: object = None) -> Expression:
default_expr = lit(default)
return Expression._from_pyexpr(self._expr.list_get(idx_expr._expr, default_expr._expr))

def sum(self) -> Expression:
"""Sums each list. Empty lists and lists with all nulls yield null.

Returns:
Expression: an expression with the type of the list values
"""
return Expression._from_pyexpr(self._expr.list_sum())

def mean(self) -> Expression:
"""Calculates the mean of each list. If no non-null values in a list, the result is null.

Returns:
Expression: a Float64 expression with the type of the list values
"""
return Expression._from_pyexpr(self._expr.list_mean())

def min(self) -> Expression:
"""Calculates the minimum of each list. If no non-null values in a list, the result is null.

Returns:
Expression: a Float64 expression with the type of the list values
"""
return Expression._from_pyexpr(self._expr.list_min())

def max(self) -> Expression:
"""Calculates the maximum of each list. If no non-null values in a list, the result is null.

Returns:
Expression: a Float64 expression with the type of the list values
"""
return Expression._from_pyexpr(self._expr.list_max())


class ExpressionStructNamespace(ExpressionNamespace):
def get(self, name: str) -> Expression:
Expand Down
2 changes: 1 addition & 1 deletion daft/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ def iceberg_truncate(self, w: int) -> Series:

class SeriesListNamespace(SeriesNamespace):
def lengths(self) -> Series:
return Series._from_pyseries(self._series.list_lengths())
return Series._from_pyseries(self._series.list_count(CountMode.All))

def get(self, idx: Series, default: Series) -> Series:
return Series._from_pyseries(self._series.list_get(idx._series, default._series))
Expand Down
2 changes: 1 addition & 1 deletion src/daft-core/src/array/ops/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ fn grouped_count_arrow_bitmap(
.iter()
.map(|g| {
g.iter()
.map(|i| validity.get_bit(!*i as usize) as u64)
.map(|i| !validity.get_bit(*i as usize) as u64)
.sum()
})
.collect(),
Expand Down
Loading
Loading