From bfb72025ece563d6285e4a27117dbd2f2c2469a3 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Thu, 21 Mar 2024 16:26:01 -0700 Subject: [PATCH 01/11] add list sum --- daft/daft.pyi | 1 + daft/expressions/expressions.py | 8 ++ src/daft-core/src/array/ops/list.rs | 110 ++++++++++++++++++++++++- src/daft-core/src/datatypes/field.rs | 33 ++++++++ src/daft-core/src/series/ops/list.rs | 11 +++ src/daft-dsl/src/expr.rs | 49 +---------- src/daft-dsl/src/functions/list/mod.rs | 11 +++ src/daft-dsl/src/functions/list/sum.rs | 34 ++++++++ src/daft-dsl/src/python.rs | 5 ++ 9 files changed, 213 insertions(+), 49 deletions(-) create mode 100644 src/daft-dsl/src/functions/list/sum.rs diff --git a/daft/daft.pyi b/daft/daft.pyi index dfaea18054..9d16d46429 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -945,6 +945,7 @@ class PyExpr: def list_join(self, delimiter: PyExpr) -> PyExpr: ... def list_lengths(self) -> PyExpr: ... def list_get(self, idx: PyExpr, default: PyExpr) -> PyExpr: ... + def list_sum(self) -> PyExpr: ... def struct_get(self, name: str) -> PyExpr: ... def url_download( self, max_connections: int, raise_error_on_failure: bool, multi_thread: bool, config: IOConfig diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 2489af0d22..b9db87866c 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -890,6 +890,14 @@ def get(self, idx: int | Expression, default: object = None) -> Expression: default_expr = lit(default) return Expression._from_pyexpr(self._expr.list_get(idx_expr._expr, default_expr._expr)) + def sum(self) -> Expression: + """Sums each list. Empty lists and lists with all nulls yield null. + + Returns: + Expression: an expression with the type of the list values + """ + return Expression._from_pyexpr(self._expr.list_sum()) + class ExpressionStructNamespace(ExpressionNamespace): def get(self, name: str) -> Expression: diff --git a/src/daft-core/src/array/ops/list.rs b/src/daft-core/src/array/ops/list.rs index 20cdd6cead..2972a117bb 100644 --- a/src/daft-core/src/array/ops/list.rs +++ b/src/daft-core/src/array/ops/list.rs @@ -1,17 +1,20 @@ use std::iter::repeat; +use std::ops::Add; +use crate::array::DataArray; use crate::array::{ growable::{make_growable, Growable}, FixedSizeListArray, ListArray, }; -use crate::datatypes::{Int64Array, UInt64Array, Utf8Array}; +use crate::datatypes::{DaftNumericType, Int64Array, UInt64Array, Utf8Array}; use crate::DataType; use crate::series::Series; use arrow2; -use common_error::DaftResult; +use arrow2::compute::aggregate::Sum; +use common_error::{DaftError, DaftResult}; use super::as_arrow::AsArrow; @@ -169,6 +172,57 @@ impl ListArray { } } } + + fn sum_data_array(&self, arr: &DataArray) -> DaftResult + where + ::Simd: + Add::Simd> + Sum, + { + let sums = arrow2::types::IndexRange::new(0, self.len() as u64) + .map(|i| i as usize) + .map(|i| { + if let Some(validity) = self.validity() && !validity.get_bit(i) { + return None; + } + + let start = *self.offsets().get(i).unwrap() as usize; + let end = *self.offsets().get(i + 1).unwrap() as usize; + + let slice = arr.slice(start, end).unwrap(); + let slice_arr = slice.as_arrow(); + + arrow2::compute::aggregate::sum_primitive(slice_arr) + }); + + let array = arrow2::array::PrimitiveArray::from_trusted_len_iter(sums).boxed(); + + Series::try_from((self.name(), array)) + } + + pub fn sum(&self) -> DaftResult { + use crate::datatypes::DataType::*; + + match self.flat_child.data_type() { + Int8 | Int16 | Int32 | Int64 => { + let casted = self.flat_child.cast(&Int64)?; + let arr = casted.i64()?; + + self.sum_data_array(arr) + } + UInt8 | UInt16 | UInt32 | UInt64 => { + let casted = self.flat_child.cast(&UInt64)?; + let arr = casted.u64()?; + + self.sum_data_array(arr) + } + Float32 => self.sum_data_array(self.flat_child.f32()?), + Float64 => self.sum_data_array(self.flat_child.f64()?), + other => Err(DaftError::TypeError(format!( + "Sum not implemented for {}", + other + ))), + } + } } impl FixedSizeListArray { @@ -302,4 +356,56 @@ impl FixedSizeListArray { } } } + + fn sum_data_array(&self, arr: &DataArray) -> DaftResult + where + ::Simd: + Add::Simd> + Sum, + { + let step = self.fixed_element_len(); + let sums = arrow2::types::IndexRange::new(0, self.len() as u64) + .map(|i| i as usize) + .map(|i| { + if let Some(validity) = self.validity() && !validity.get_bit(i) { + return None; + } + + let start = i * step; + let end = (i + 1) + step; + + let slice = arr.slice(start, end).unwrap(); + let slice_arr = slice.as_arrow(); + + arrow2::compute::aggregate::sum_primitive(slice_arr) + }); + + let array = arrow2::array::PrimitiveArray::from_trusted_len_iter(sums).boxed(); + + Series::try_from((self.name(), array)) + } + + pub fn sum(&self) -> DaftResult { + use crate::datatypes::DataType::*; + + match self.flat_child.data_type() { + Int8 | Int16 | Int32 | Int64 => { + let casted = self.flat_child.cast(&Int64)?; + let arr = casted.i64()?; + + self.sum_data_array(arr) + } + UInt8 | UInt16 | UInt32 | UInt64 => { + let casted = self.flat_child.cast(&UInt64)?; + let arr = casted.u64()?; + + self.sum_data_array(arr) + } + Float32 => self.sum_data_array(self.flat_child.f32()?), + Float64 => self.sum_data_array(self.flat_child.f64()?), + other => Err(DaftError::TypeError(format!( + "Sum not implemented for {}", + other + ))), + } + } } diff --git a/src/daft-core/src/datatypes/field.rs b/src/daft-core/src/datatypes/field.rs index 9362c7c012..33f523f1ba 100644 --- a/src/daft-core/src/datatypes/field.rs +++ b/src/daft-core/src/datatypes/field.rs @@ -128,6 +128,39 @@ impl Field { ))), } } + + pub fn to_sum(&self) -> DaftResult { + Ok(Field::new( + self.name.as_str(), + match &self.dtype { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + DataType::Int64 + } + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { + DataType::UInt64 + } + DataType::Float32 => DataType::Float32, + DataType::Float64 => DataType::Float64, + other => { + return Err(DaftError::TypeError(format!( + "Expected input to sum() to be numeric but received dtype {} for column \"{}\"", + other, self.name, + ))) + } + }, + )) + } + + pub fn to_mean(&self) -> DaftResult { + if self.dtype.is_numeric() { + Ok(Field::new(self.name.as_str(), DataType::Float64)) + } else { + Err(DaftError::TypeError(format!( + "Numeric mean is not implemented for column \"{}\" of type {}", + self.name, self.dtype, + ))) + } + } } impl From<&ArrowField> for Field { diff --git a/src/daft-core/src/series/ops/list.rs b/src/daft-core/src/series/ops/list.rs index c48905fc7f..14eceed574 100644 --- a/src/daft-core/src/series/ops/list.rs +++ b/src/daft-core/src/series/ops/list.rs @@ -67,4 +67,15 @@ impl Series { ))), } } + + pub fn list_sum(&self) -> DaftResult { + match self.data_type() { + DataType::List(_) => self.list()?.sum(), + DataType::FixedSizeList(..) => self.fixed_size_list()?.sum(), + dt => Err(DaftError::TypeError(format!( + "Sum not implemented for {}", + dt + ))), + } + } } diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index 3b4e8716df..a60ff226a2 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -159,53 +159,8 @@ impl AggExpr { let field = expr.to_field(schema)?; Ok(Field::new(field.name.as_str(), DataType::UInt64)) } - Sum(expr) => { - let field = expr.to_field(schema)?; - Ok(Field::new( - field.name.as_str(), - match &field.dtype { - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { - DataType::Int64 - } - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 => DataType::UInt64, - DataType::Float32 => DataType::Float32, - DataType::Float64 => DataType::Float64, - other => { - return Err(DaftError::TypeError(format!( - "Expected input to sum() to be numeric but received dtype {} for column \"{}\"", - other, field.name, - ))) - } - }, - )) - } - Mean(expr) => { - let field = expr.to_field(schema)?; - Ok(Field::new( - field.name.as_str(), - match &field.dtype { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 => DataType::Float64, - other => { - return Err(DaftError::TypeError(format!( - "Numeric mean is not implemented for column \"{}\" of type {}", - field.name, other, - ))) - } - }, - )) - } + Sum(expr) => expr.to_field(schema)?.to_sum(), + Mean(expr) => expr.to_field(schema)?.to_mean(), Min(expr) | Max(expr) | AnyValue(expr, _) => { let field = expr.to_field(schema)?; Ok(Field::new(field.name.as_str(), field.dtype)) diff --git a/src/daft-dsl/src/functions/list/mod.rs b/src/daft-dsl/src/functions/list/mod.rs index 8279dbdb2e..14dc2089c9 100644 --- a/src/daft-dsl/src/functions/list/mod.rs +++ b/src/daft-dsl/src/functions/list/mod.rs @@ -2,12 +2,14 @@ mod explode; mod get; mod join; mod lengths; +mod sum; use explode::ExplodeEvaluator; use get::GetEvaluator; use join::JoinEvaluator; use lengths::LengthsEvaluator; use serde::{Deserialize, Serialize}; +use sum::SumEvaluator; use crate::Expr; @@ -19,6 +21,7 @@ pub enum ListExpr { Join, Lengths, Get, + Sum, } impl ListExpr { @@ -30,6 +33,7 @@ impl ListExpr { Join => &JoinEvaluator {}, Lengths => &LengthsEvaluator {}, Get => &GetEvaluator {}, + Sum => &SumEvaluator {}, } } } @@ -61,3 +65,10 @@ pub fn get(input: &Expr, idx: &Expr, default: &Expr) -> Expr { inputs: vec![input.clone(), idx.clone(), default.clone()], } } + +pub fn sum(input: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::List(ListExpr::Sum), + inputs: vec![input.clone()], + } +} diff --git a/src/daft-dsl/src/functions/list/sum.rs b/src/daft-dsl/src/functions/list/sum.rs new file mode 100644 index 0000000000..484aa65a75 --- /dev/null +++ b/src/daft-dsl/src/functions/list/sum.rs @@ -0,0 +1,34 @@ +use crate::Expr; +use daft_core::{datatypes::Field, schema::Schema, series::Series}; + +use common_error::{DaftError, DaftResult}; + +use super::super::FunctionEvaluator; + +pub(super) struct SumEvaluator {} + +impl FunctionEvaluator for SumEvaluator { + fn fn_name(&self) -> &'static str { + "sum" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { + match inputs { + [input] => Ok(input.to_field(schema)?.to_exploded_field()?.to_sum()?), + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult { + match inputs { + [input] => Ok(input.list_sum()?), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } +} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 9d7f09f498..c0d836199b 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -482,6 +482,11 @@ impl PyExpr { Ok(get(&self.expr, &idx.expr, &default.expr).into()) } + pub fn list_sum(&self) -> PyResult { + use crate::functions::list::sum; + Ok(sum(&self.expr).into()) + } + pub fn struct_get(&self, name: &str) -> PyResult { use crate::functions::struct_::get; Ok(get(&self.expr, name).into()) From 83f6275b92b2cec8d0e2b7af8cd5b8ea94af05f6 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Fri, 22 Mar 2024 12:20:10 -0700 Subject: [PATCH 02/11] add list count and mean --- daft/daft.pyi | 5 +- daft/expressions/expressions.py | 21 ++- daft/series.py | 2 +- src/daft-core/src/array/ops/count.rs | 2 +- src/daft-core/src/array/ops/list.rs | 151 +++++++++++++++++---- src/daft-core/src/python/series.rs | 4 +- src/daft-core/src/series/ops/list.rs | 24 +++- src/daft-dsl/src/functions/list/count.rs | 64 +++++++++ src/daft-dsl/src/functions/list/lengths.rs | 50 ------- src/daft-dsl/src/functions/list/mean.rs | 34 +++++ src/daft-dsl/src/functions/list/mod.rs | 24 +++- src/daft-dsl/src/python.rs | 11 +- src/daft-table/src/ops/explode.rs | 8 +- 13 files changed, 300 insertions(+), 100 deletions(-) create mode 100644 src/daft-dsl/src/functions/list/count.rs delete mode 100644 src/daft-dsl/src/functions/list/lengths.rs create mode 100644 src/daft-dsl/src/functions/list/mean.rs diff --git a/daft/daft.pyi b/daft/daft.pyi index 9d16d46429..e8b7d05fd0 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -943,9 +943,10 @@ class PyExpr: def image_resize(self, w: int, h: int) -> PyExpr: ... def image_crop(self, bbox: PyExpr) -> PyExpr: ... def list_join(self, delimiter: PyExpr) -> PyExpr: ... - def list_lengths(self) -> PyExpr: ... + def list_count(self, mode: CountMode) -> PyExpr: ... def list_get(self, idx: PyExpr, default: PyExpr) -> PyExpr: ... def list_sum(self) -> PyExpr: ... + def list_mean(self) -> PyExpr: ... def struct_get(self, name: str) -> PyExpr: ... def url_download( self, max_connections: int, raise_error_on_failure: bool, multi_thread: bool, config: IOConfig @@ -1038,7 +1039,7 @@ class PySeries: def partitioning_years(self) -> PySeries: ... def partitioning_iceberg_bucket(self, n: int) -> PySeries: ... def partitioning_iceberg_truncate(self, w: int) -> PySeries: ... - def list_lengths(self) -> PySeries: ... + def list_count(self, mode: CountMode) -> PySeries: ... def list_get(self, idx: PySeries, default: PySeries) -> PySeries: ... def image_decode(self) -> PySeries: ... def image_encode(self, image_format: ImageFormat) -> PySeries: ... diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index b9db87866c..afd2e63178 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -868,13 +868,24 @@ def join(self, delimiter: str | Expression) -> Expression: delimiter_expr = Expression._to_expression(delimiter) return Expression._from_pyexpr(self._expr.list_join(delimiter_expr._expr)) + def count(self, mode: CountMode = CountMode.Valid) -> Expression: + """Counts the number of elements in each list + + Args: + mode: The mode to use for counting. Defaults to CountMode.Valid + + Returns: + Expression: a UInt64 expression which is the length of each list + """ + return Expression._from_pyexpr(self._expr.list_count(mode)) + def lengths(self) -> Expression: """Gets the length of each list Returns: Expression: a UInt64 expression which is the length of each list """ - return Expression._from_pyexpr(self._expr.list_lengths()) + return Expression._from_pyexpr(self._expr.list_count(CountMode.All)) def get(self, idx: int | Expression, default: object = None) -> Expression: """Gets the element at an index in each list @@ -898,6 +909,14 @@ def sum(self) -> Expression: """ return Expression._from_pyexpr(self._expr.list_sum()) + def mean(self) -> Expression: + """Calculates the mean of each list. If no non-null values in a list, the result is null. + + Returns: + Expression: a Float64 expression with the type of the list values + """ + return Expression._from_pyexpr(self._expr.list_mean()) + class ExpressionStructNamespace(ExpressionNamespace): def get(self, name: str) -> Expression: diff --git a/daft/series.py b/daft/series.py index 0c8aa4b7ae..abcff7b6ce 100644 --- a/daft/series.py +++ b/daft/series.py @@ -666,7 +666,7 @@ def iceberg_truncate(self, w: int) -> Series: class SeriesListNamespace(SeriesNamespace): def lengths(self) -> Series: - return Series._from_pyseries(self._series.list_lengths()) + return Series._from_pyseries(self._series.list_count(CountMode.All)) def get(self, idx: Series, default: Series) -> Series: return Series._from_pyseries(self._series.list_get(idx._series, default._series)) diff --git a/src/daft-core/src/array/ops/count.rs b/src/daft-core/src/array/ops/count.rs index 04f9fdad1d..e71ff09e71 100644 --- a/src/daft-core/src/array/ops/count.rs +++ b/src/daft-core/src/array/ops/count.rs @@ -32,7 +32,7 @@ fn grouped_count_arrow_bitmap( .iter() .map(|g| { g.iter() - .map(|i| validity.get_bit(!*i as usize) as u64) + .map(|i| !validity.get_bit(*i as usize) as u64) .sum() }) .collect(), diff --git a/src/daft-core/src/array/ops/list.rs b/src/daft-core/src/array/ops/list.rs index 2972a117bb..ee97d26db2 100644 --- a/src/daft-core/src/array/ops/list.rs +++ b/src/daft-core/src/array/ops/list.rs @@ -6,8 +6,8 @@ use crate::array::{ growable::{make_growable, Growable}, FixedSizeListArray, ListArray, }; -use crate::datatypes::{DaftNumericType, Int64Array, UInt64Array, Utf8Array}; -use crate::DataType; +use crate::datatypes::{DaftNumericType, Float64Array, Int64Array, UInt64Array, Utf8Array}; +use crate::{CountMode, DataType}; use crate::series::Series; @@ -45,11 +45,39 @@ fn join_arrow_list_of_utf8s( } impl ListArray { - pub fn lengths(&self) -> DaftResult { - let lengths = self.offsets().lengths().map(|l| Some(l as u64)); + pub fn count(&self, mode: CountMode) -> DaftResult { + let counts = match mode { + CountMode::All => self.offsets().lengths().map(|l| l as u64).collect(), + CountMode::Valid => self + .offsets() + .windows(2) + .map(|w| { + if let Some(validity) = self.flat_child.validity() { + (w[0]..w[1]) + .map(|i| validity.get_bit(i as usize) as u64) + .sum() + } else { + (w[1] - w[0]) as u64 + } + }) + .collect(), + CountMode::Null => self + .offsets() + .windows(2) + .map(|w| { + if let Some(validity) = self.flat_child.validity() { + (w[0]..w[1]) + .map(|i| !validity.get_bit(i as usize) as u64) + .sum() + } else { + (w[1] - w[0]) as u64 + } + }) + .collect(), + }; + let array = Box::new( - arrow2::array::PrimitiveArray::from_iter(lengths) - .with_validity(self.validity().cloned()), + arrow2::array::PrimitiveArray::from_vec(counts).with_validity(self.validity().cloned()), ); Ok(UInt64Array::from((self.name(), array))) } @@ -223,30 +251,73 @@ impl ListArray { ))), } } + + pub fn mean(&self) -> DaftResult { + use crate::datatypes::DataType::*; + + match self.flat_child.data_type() { + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 | Float32 | Float64 => { + let counts = self.count(CountMode::Valid)?; + let sum_series = self.sum()?.cast(&Float64)?; + let sums = sum_series.f64()?; + + let means = counts + .into_iter() + .zip(sums) + .map(|(count, sum)| match (count, sum) { + (Some(count), Some(sum)) if *count != 0 => Some(sum / (*count as f64)), + _ => None, + }); + + let arr = Box::new(arrow2::array::PrimitiveArray::from_trusted_len_iter(means)); + + Ok(Float64Array::from((self.name(), arr))) + } + other => Err(DaftError::TypeError(format!( + "Mean not implemented for {}", + other + ))), + } + } } impl FixedSizeListArray { - pub fn lengths(&self) -> DaftResult { + pub fn count(&self, mode: CountMode) -> DaftResult { let size = self.fixed_element_len(); - match self.validity() { - None => Ok(UInt64Array::from(( - self.name(), - repeat(size as u64) - .take(self.len()) - .collect::>() - .as_slice(), - ))), - Some(validity) => { - let arrow_arr = arrow2::array::UInt64Array::from_iter(validity.iter().map(|v| { - if v { - Some(size as u64) - } else { - None - } - })); - Ok(UInt64Array::from((self.name(), Box::new(arrow_arr)))) + let counts = match mode { + CountMode::All => repeat(size as u64).take(self.len()).collect(), + CountMode::Valid => { + if let Some(validity) = self.flat_child.validity() { + (0..self.len()) + .map(|i| { + (0..size) + .map(|j| validity.get_bit(i * size + j) as u64) + .sum() + }) + .collect() + } else { + repeat(size as u64).take(self.len()).collect() + } } - } + CountMode::Null => { + if let Some(validity) = self.flat_child.validity() { + (0..self.len()) + .map(|i| { + (0..size) + .map(|j| !validity.get_bit(i * size + j) as u64) + .sum() + }) + .collect() + } else { + repeat(0).take(self.len()).collect() + } + } + }; + + let array = Box::new( + arrow2::array::PrimitiveArray::from_vec(counts).with_validity(self.validity().cloned()), + ); + Ok(UInt64Array::from((self.name(), array))) } pub fn explode(&self) -> DaftResult { @@ -371,7 +442,7 @@ impl FixedSizeListArray { } let start = i * step; - let end = (i + 1) + step; + let end = (i + 1) * step; let slice = arr.slice(start, end).unwrap(); let slice_arr = slice.as_arrow(); @@ -408,4 +479,32 @@ impl FixedSizeListArray { ))), } } + + pub fn mean(&self) -> DaftResult { + use crate::datatypes::DataType::*; + + match self.flat_child.data_type() { + Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 | Float32 | Float64 => { + let counts = self.count(CountMode::Valid)?; + let sum_series = self.sum()?.cast(&Float64)?; + let sums = sum_series.f64()?; + + let means = counts + .into_iter() + .zip(sums) + .map(|(count, sum)| match (count, sum) { + (Some(count), Some(sum)) if *count != 0 => Some(sum / (*count as f64)), + _ => None, + }); + + let arr = Box::new(arrow2::array::PrimitiveArray::from_trusted_len_iter(means)); + + Ok(Float64Array::from((self.name(), arr))) + } + other => Err(DaftError::TypeError(format!( + "Mean not implemented for {}", + other + ))), + } + } } diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index 2b8177e001..e5ab0950d6 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -352,8 +352,8 @@ impl PySeries { Ok(self.series.murmur3_32()?.into_series().into()) } - pub fn list_lengths(&self) -> PyResult { - Ok(self.series.list_lengths()?.into_series().into()) + pub fn list_count(&self, mode: CountMode) -> PyResult { + Ok(self.series.list_count(mode)?.into_series().into()) } pub fn list_get(&self, idx: &Self, default: &Self) -> PyResult { diff --git a/src/daft-core/src/series/ops/list.rs b/src/daft-core/src/series/ops/list.rs index 14eceed574..51e26d72ca 100644 --- a/src/daft-core/src/series/ops/list.rs +++ b/src/daft-core/src/series/ops/list.rs @@ -1,5 +1,6 @@ -use crate::datatypes::{DataType, UInt64Array, Utf8Array}; +use crate::datatypes::{DataType, Float64Array, UInt64Array, Utf8Array}; use crate::series::Series; +use crate::CountMode; use common_error::DaftError; use common_error::DaftResult; @@ -17,13 +18,13 @@ impl Series { } } - pub fn list_lengths(&self) -> DaftResult { + pub fn list_count(&self, mode: CountMode) -> DaftResult { use DataType::*; match self.data_type() { - List(_) => self.list()?.lengths(), - FixedSizeList(..) => self.fixed_size_list()?.lengths(), - Embedding(..) | FixedShapeImage(..) => self.as_physical()?.list_lengths(), + List(_) => self.list()?.count(mode), + FixedSizeList(..) => self.fixed_size_list()?.count(mode), + Embedding(..) | FixedShapeImage(..) => self.as_physical()?.list_count(mode), Image(..) => { let struct_array = self.as_physical()?; let data_array = struct_array.struct_()?.children[0].list().unwrap(); @@ -37,7 +38,7 @@ impl Series { Ok(UInt64Array::from((self.name(), array))) } dt => Err(DaftError::TypeError(format!( - "lengths not implemented for {}", + "Count not implemented for {}", dt ))), } @@ -78,4 +79,15 @@ impl Series { ))), } } + + pub fn list_mean(&self) -> DaftResult { + match self.data_type() { + DataType::List(_) => self.list()?.mean(), + DataType::FixedSizeList(..) => self.fixed_size_list()?.mean(), + dt => Err(DaftError::TypeError(format!( + "Mean not implemented for {}", + dt + ))), + } + } } diff --git a/src/daft-dsl/src/functions/list/count.rs b/src/daft-dsl/src/functions/list/count.rs new file mode 100644 index 0000000000..d6f72dac26 --- /dev/null +++ b/src/daft-dsl/src/functions/list/count.rs @@ -0,0 +1,64 @@ +use crate::{functions::FunctionExpr, Expr}; +use daft_core::{ + datatypes::{DataType, Field}, + schema::Schema, + series::{IntoSeries, Series}, +}; + +use common_error::{DaftError, DaftResult}; + +use super::{super::FunctionEvaluator, ListExpr}; + +pub(super) struct CountEvaluator {} + +impl FunctionEvaluator for CountEvaluator { + fn fn_name(&self) -> &'static str { + "count" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, expr: &Expr) -> DaftResult { + match inputs { + [input] => { + let input_field = input.to_field(schema)?; + + match input_field.dtype { + DataType::List(_) | DataType::FixedSizeList(_, _) => match expr { + Expr::Function { + func: FunctionExpr::List(ListExpr::Count(_)), + inputs: _, + } => Ok(Field::new(input.name()?, DataType::UInt64)), + _ => panic!("Expected List Count Expr, got {expr}"), + }, + _ => Err(DaftError::TypeError(format!( + "Expected input to be a list type, received: {}", + input_field.dtype + ))), + } + } + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], expr: &Expr) -> DaftResult { + match inputs { + [input] => { + let mode = match expr { + Expr::Function { + func: FunctionExpr::List(ListExpr::Count(mode)), + inputs: _, + } => mode, + _ => panic!("Expected List Count Expr, got {expr}"), + }; + + Ok(input.list_count(*mode)?.into_series()) + } + _ => Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } +} diff --git a/src/daft-dsl/src/functions/list/lengths.rs b/src/daft-dsl/src/functions/list/lengths.rs deleted file mode 100644 index b7858999bc..0000000000 --- a/src/daft-dsl/src/functions/list/lengths.rs +++ /dev/null @@ -1,50 +0,0 @@ -use crate::Expr; -use daft_core::{ - datatypes::{DataType, Field}, - schema::Schema, - series::{IntoSeries, Series}, -}; - -use common_error::{DaftError, DaftResult}; - -use super::super::FunctionEvaluator; - -pub(super) struct LengthsEvaluator {} - -impl FunctionEvaluator for LengthsEvaluator { - fn fn_name(&self) -> &'static str { - "lengths" - } - - fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { - match inputs { - [input] => { - let input_field = input.to_field(schema)?; - - match input_field.dtype { - DataType::List(_) | DataType::FixedSizeList(_, _) => { - Ok(Field::new(input.name()?, DataType::UInt64)) - } - _ => Err(DaftError::TypeError(format!( - "Expected input to be a list type, received: {}", - input_field.dtype - ))), - } - } - _ => Err(DaftError::SchemaMismatch(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } - - fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult { - match inputs { - [input] => Ok(input.list_lengths()?.into_series()), - _ => Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", - inputs.len() - ))), - } - } -} diff --git a/src/daft-dsl/src/functions/list/mean.rs b/src/daft-dsl/src/functions/list/mean.rs new file mode 100644 index 0000000000..aeac9bc388 --- /dev/null +++ b/src/daft-dsl/src/functions/list/mean.rs @@ -0,0 +1,34 @@ +use crate::Expr; +use daft_core::{datatypes::Field, schema::Schema, series::Series, IntoSeries}; + +use common_error::{DaftError, DaftResult}; + +use super::super::FunctionEvaluator; + +pub(super) struct MeanEvaluator {} + +impl FunctionEvaluator for MeanEvaluator { + fn fn_name(&self) -> &'static str { + "mean" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { + match inputs { + [input] => Ok(input.to_field(schema)?.to_exploded_field()?.to_mean()?), + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult { + match inputs { + [input] => Ok(input.list_mean()?.into_series()), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } +} diff --git a/src/daft-dsl/src/functions/list/mod.rs b/src/daft-dsl/src/functions/list/mod.rs index 14dc2089c9..c78eaccec2 100644 --- a/src/daft-dsl/src/functions/list/mod.rs +++ b/src/daft-dsl/src/functions/list/mod.rs @@ -1,13 +1,16 @@ +mod count; mod explode; mod get; mod join; -mod lengths; +mod mean; mod sum; +use count::CountEvaluator; +use daft_core::CountMode; use explode::ExplodeEvaluator; use get::GetEvaluator; use join::JoinEvaluator; -use lengths::LengthsEvaluator; +use mean::MeanEvaluator; use serde::{Deserialize, Serialize}; use sum::SumEvaluator; @@ -19,9 +22,10 @@ use super::FunctionEvaluator; pub enum ListExpr { Explode, Join, - Lengths, + Count(CountMode), Get, Sum, + Mean, } impl ListExpr { @@ -31,9 +35,10 @@ impl ListExpr { match self { Explode => &ExplodeEvaluator {}, Join => &JoinEvaluator {}, - Lengths => &LengthsEvaluator {}, + Count(_) => &CountEvaluator {}, Get => &GetEvaluator {}, Sum => &SumEvaluator {}, + Mean => &MeanEvaluator {}, } } } @@ -52,9 +57,9 @@ pub fn join(input: &Expr, delimiter: &Expr) -> Expr { } } -pub fn lengths(input: &Expr) -> Expr { +pub fn count(input: &Expr, mode: CountMode) -> Expr { Expr::Function { - func: super::FunctionExpr::List(ListExpr::Lengths), + func: super::FunctionExpr::List(ListExpr::Count(mode)), inputs: vec![input.clone()], } } @@ -72,3 +77,10 @@ pub fn sum(input: &Expr) -> Expr { inputs: vec![input.clone()], } } + +pub fn mean(input: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::List(ListExpr::Mean), + inputs: vec![input.clone()], + } +} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index c0d836199b..5e7eef93d9 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -472,9 +472,9 @@ impl PyExpr { Ok(join(&self.expr, &delimiter.expr).into()) } - pub fn list_lengths(&self) -> PyResult { - use crate::functions::list::lengths; - Ok(lengths(&self.expr).into()) + pub fn list_count(&self, mode: CountMode) -> PyResult { + use crate::functions::list::count; + Ok(count(&self.expr, mode).into()) } pub fn list_get(&self, idx: &Self, default: &Self) -> PyResult { @@ -487,6 +487,11 @@ impl PyExpr { Ok(sum(&self.expr).into()) } + pub fn list_mean(&self) -> PyResult { + use crate::functions::list::mean; + Ok(mean(&self.expr).into()) + } + pub fn struct_get(&self, name: &str) -> PyResult { use crate::functions::struct_::get; Ok(get(&self.expr, name).into()) diff --git a/src/daft-table/src/ops/explode.rs b/src/daft-table/src/ops/explode.rs index 45855805ed..bec57b8dfb 100644 --- a/src/daft-table/src/ops/explode.rs +++ b/src/daft-table/src/ops/explode.rs @@ -1,5 +1,6 @@ use common_error::{DaftError, DaftResult}; use daft_core::series::IntoSeries; +use daft_core::CountMode; use daft_core::{ array::ops::as_arrow::AsArrow, datatypes::{DataType, UInt64Array}, @@ -60,11 +61,14 @@ impl Table { } } } - let first_len = evaluated_columns.first().unwrap().list_lengths()?; + let first_len = evaluated_columns + .first() + .unwrap() + .list_count(CountMode::All)?; if evaluated_columns .iter() .skip(1) - .any(|c| c.list_lengths().unwrap().ne(&first_len)) + .any(|c| c.list_count(CountMode::All).unwrap().ne(&first_len)) { return Err(DaftError::ValueError( "In multicolumn explode, list length did not match".to_string(), From f5ed4f953432cf1667e58979aa55d3a57a61c44c Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Fri, 22 Mar 2024 12:53:37 -0700 Subject: [PATCH 03/11] add list min and max --- daft/daft.pyi | 2 + daft/expressions/expressions.py | 16 ++++ src/daft-core/src/array/ops/list.rs | 125 ++++++++++++++++++++++--- src/daft-core/src/series/ops/list.rs | 22 +++++ src/daft-dsl/src/functions/list/max.rs | 45 +++++++++ src/daft-dsl/src/functions/list/min.rs | 45 +++++++++ src/daft-dsl/src/functions/list/mod.rs | 22 +++++ src/daft-dsl/src/python.rs | 10 ++ 8 files changed, 275 insertions(+), 12 deletions(-) create mode 100644 src/daft-dsl/src/functions/list/max.rs create mode 100644 src/daft-dsl/src/functions/list/min.rs diff --git a/daft/daft.pyi b/daft/daft.pyi index e8b7d05fd0..fb3b832436 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -947,6 +947,8 @@ class PyExpr: def list_get(self, idx: PyExpr, default: PyExpr) -> PyExpr: ... def list_sum(self) -> PyExpr: ... def list_mean(self) -> PyExpr: ... + def list_min(self) -> PyExpr: ... + def list_max(self) -> PyExpr: ... def struct_get(self, name: str) -> PyExpr: ... def url_download( self, max_connections: int, raise_error_on_failure: bool, multi_thread: bool, config: IOConfig diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index afd2e63178..f6a6aba4bf 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -917,6 +917,22 @@ def mean(self) -> Expression: """ return Expression._from_pyexpr(self._expr.list_mean()) + def min(self) -> Expression: + """Calculates the minimum of each list. If no non-null values in a list, the result is null. + + Returns: + Expression: a Float64 expression with the type of the list values + """ + return Expression._from_pyexpr(self._expr.list_min()) + + def max(self) -> Expression: + """Calculates the maximum of each list. If no non-null values in a list, the result is null. + + Returns: + Expression: a Float64 expression with the type of the list values + """ + return Expression._from_pyexpr(self._expr.list_max()) + class ExpressionStructNamespace(ExpressionNamespace): def get(self, name: str) -> Expression: diff --git a/src/daft-core/src/array/ops/list.rs b/src/daft-core/src/array/ops/list.rs index ee97d26db2..bb3a40e26f 100644 --- a/src/daft-core/src/array/ops/list.rs +++ b/src/daft-core/src/array/ops/list.rs @@ -13,11 +13,14 @@ use crate::series::Series; use arrow2; +use arrow2::array::PrimitiveArray; use arrow2::compute::aggregate::Sum; use common_error::{DaftError, DaftResult}; use super::as_arrow::AsArrow; +type Arrow2AggregateFn<'a, T> = &'a dyn Fn(&PrimitiveArray) -> Option; + fn join_arrow_list_of_utf8s( list_element: Option<&dyn arrow2::array::Array>, delimiter_str: &str, @@ -201,7 +204,11 @@ impl ListArray { } } - fn sum_data_array(&self, arr: &DataArray) -> DaftResult + fn agg_data_array( + &self, + arr: &DataArray, + op: Arrow2AggregateFn, + ) -> DaftResult where ::Simd: Add::Simd> + Sum, @@ -219,7 +226,7 @@ impl ListArray { let slice = arr.slice(start, end).unwrap(); let slice_arr = slice.as_arrow(); - arrow2::compute::aggregate::sum_primitive(slice_arr) + op(slice_arr) }); let array = arrow2::array::PrimitiveArray::from_trusted_len_iter(sums).boxed(); @@ -229,22 +236,23 @@ impl ListArray { pub fn sum(&self) -> DaftResult { use crate::datatypes::DataType::*; + use arrow2::compute::aggregate::sum_primitive; match self.flat_child.data_type() { Int8 | Int16 | Int32 | Int64 => { let casted = self.flat_child.cast(&Int64)?; let arr = casted.i64()?; - self.sum_data_array(arr) + self.agg_data_array(arr, &sum_primitive) } UInt8 | UInt16 | UInt32 | UInt64 => { let casted = self.flat_child.cast(&UInt64)?; let arr = casted.u64()?; - self.sum_data_array(arr) + self.agg_data_array(arr, &sum_primitive) } - Float32 => self.sum_data_array(self.flat_child.f32()?), - Float64 => self.sum_data_array(self.flat_child.f64()?), + Float32 => self.agg_data_array(self.flat_child.f32()?, &sum_primitive), + Float64 => self.agg_data_array(self.flat_child.f64()?, &sum_primitive), other => Err(DaftError::TypeError(format!( "Sum not implemented for {}", other @@ -279,6 +287,50 @@ impl ListArray { ))), } } + + pub fn min(&self) -> DaftResult { + use crate::datatypes::DataType::*; + use arrow2::compute::aggregate::min_primitive; + + match self.flat_child.data_type() { + Int8 => self.agg_data_array(self.flat_child.i8()?, &min_primitive), + Int16 => self.agg_data_array(self.flat_child.i16()?, &min_primitive), + Int32 => self.agg_data_array(self.flat_child.i32()?, &min_primitive), + Int64 => self.agg_data_array(self.flat_child.i64()?, &min_primitive), + UInt8 => self.agg_data_array(self.flat_child.u8()?, &min_primitive), + UInt16 => self.agg_data_array(self.flat_child.u16()?, &min_primitive), + UInt32 => self.agg_data_array(self.flat_child.u32()?, &min_primitive), + UInt64 => self.agg_data_array(self.flat_child.u64()?, &min_primitive), + Float32 => self.agg_data_array(self.flat_child.f32()?, &min_primitive), + Float64 => self.agg_data_array(self.flat_child.f64()?, &min_primitive), + other => Err(DaftError::TypeError(format!( + "Min not implemented for {}", + other + ))), + } + } + + pub fn max(&self) -> DaftResult { + use crate::datatypes::DataType::*; + use arrow2::compute::aggregate::max_primitive; + + match self.flat_child.data_type() { + Int8 => self.agg_data_array(self.flat_child.i8()?, &max_primitive), + Int16 => self.agg_data_array(self.flat_child.i16()?, &max_primitive), + Int32 => self.agg_data_array(self.flat_child.i32()?, &max_primitive), + Int64 => self.agg_data_array(self.flat_child.i64()?, &max_primitive), + UInt8 => self.agg_data_array(self.flat_child.u8()?, &max_primitive), + UInt16 => self.agg_data_array(self.flat_child.u16()?, &max_primitive), + UInt32 => self.agg_data_array(self.flat_child.u32()?, &max_primitive), + UInt64 => self.agg_data_array(self.flat_child.u64()?, &max_primitive), + Float32 => self.agg_data_array(self.flat_child.f32()?, &max_primitive), + Float64 => self.agg_data_array(self.flat_child.f64()?, &max_primitive), + other => Err(DaftError::TypeError(format!( + "Max not implemented for {}", + other + ))), + } + } } impl FixedSizeListArray { @@ -428,7 +480,11 @@ impl FixedSizeListArray { } } - fn sum_data_array(&self, arr: &DataArray) -> DaftResult + fn agg_data_array( + &self, + arr: &DataArray, + op: Arrow2AggregateFn, + ) -> DaftResult where ::Simd: Add::Simd> + Sum, @@ -447,7 +503,7 @@ impl FixedSizeListArray { let slice = arr.slice(start, end).unwrap(); let slice_arr = slice.as_arrow(); - arrow2::compute::aggregate::sum_primitive(slice_arr) + op(slice_arr) }); let array = arrow2::array::PrimitiveArray::from_trusted_len_iter(sums).boxed(); @@ -457,22 +513,23 @@ impl FixedSizeListArray { pub fn sum(&self) -> DaftResult { use crate::datatypes::DataType::*; + use arrow2::compute::aggregate::sum_primitive; match self.flat_child.data_type() { Int8 | Int16 | Int32 | Int64 => { let casted = self.flat_child.cast(&Int64)?; let arr = casted.i64()?; - self.sum_data_array(arr) + self.agg_data_array(arr, &sum_primitive) } UInt8 | UInt16 | UInt32 | UInt64 => { let casted = self.flat_child.cast(&UInt64)?; let arr = casted.u64()?; - self.sum_data_array(arr) + self.agg_data_array(arr, &sum_primitive) } - Float32 => self.sum_data_array(self.flat_child.f32()?), - Float64 => self.sum_data_array(self.flat_child.f64()?), + Float32 => self.agg_data_array(self.flat_child.f32()?, &sum_primitive), + Float64 => self.agg_data_array(self.flat_child.f64()?, &sum_primitive), other => Err(DaftError::TypeError(format!( "Sum not implemented for {}", other @@ -507,4 +564,48 @@ impl FixedSizeListArray { ))), } } + + pub fn min(&self) -> DaftResult { + use crate::datatypes::DataType::*; + use arrow2::compute::aggregate::min_primitive; + + match self.flat_child.data_type() { + Int8 => self.agg_data_array(self.flat_child.i8()?, &min_primitive), + Int16 => self.agg_data_array(self.flat_child.i16()?, &min_primitive), + Int32 => self.agg_data_array(self.flat_child.i32()?, &min_primitive), + Int64 => self.agg_data_array(self.flat_child.i64()?, &min_primitive), + UInt8 => self.agg_data_array(self.flat_child.u8()?, &min_primitive), + UInt16 => self.agg_data_array(self.flat_child.u16()?, &min_primitive), + UInt32 => self.agg_data_array(self.flat_child.u32()?, &min_primitive), + UInt64 => self.agg_data_array(self.flat_child.u64()?, &min_primitive), + Float32 => self.agg_data_array(self.flat_child.f32()?, &min_primitive), + Float64 => self.agg_data_array(self.flat_child.f64()?, &min_primitive), + other => Err(DaftError::TypeError(format!( + "Min not implemented for {}", + other + ))), + } + } + + pub fn max(&self) -> DaftResult { + use crate::datatypes::DataType::*; + use arrow2::compute::aggregate::max_primitive; + + match self.flat_child.data_type() { + Int8 => self.agg_data_array(self.flat_child.i8()?, &max_primitive), + Int16 => self.agg_data_array(self.flat_child.i16()?, &max_primitive), + Int32 => self.agg_data_array(self.flat_child.i32()?, &max_primitive), + Int64 => self.agg_data_array(self.flat_child.i64()?, &max_primitive), + UInt8 => self.agg_data_array(self.flat_child.u8()?, &max_primitive), + UInt16 => self.agg_data_array(self.flat_child.u16()?, &max_primitive), + UInt32 => self.agg_data_array(self.flat_child.u32()?, &max_primitive), + UInt64 => self.agg_data_array(self.flat_child.u64()?, &max_primitive), + Float32 => self.agg_data_array(self.flat_child.f32()?, &max_primitive), + Float64 => self.agg_data_array(self.flat_child.f64()?, &max_primitive), + other => Err(DaftError::TypeError(format!( + "Max not implemented for {}", + other + ))), + } + } } diff --git a/src/daft-core/src/series/ops/list.rs b/src/daft-core/src/series/ops/list.rs index 51e26d72ca..a4cf81bb74 100644 --- a/src/daft-core/src/series/ops/list.rs +++ b/src/daft-core/src/series/ops/list.rs @@ -90,4 +90,26 @@ impl Series { ))), } } + + pub fn list_min(&self) -> DaftResult { + match self.data_type() { + DataType::List(_) => self.list()?.min(), + DataType::FixedSizeList(..) => self.fixed_size_list()?.min(), + dt => Err(DaftError::TypeError(format!( + "Min not implemented for {}", + dt + ))), + } + } + + pub fn list_max(&self) -> DaftResult { + match self.data_type() { + DataType::List(_) => self.list()?.max(), + DataType::FixedSizeList(..) => self.fixed_size_list()?.max(), + dt => Err(DaftError::TypeError(format!( + "Max not implemented for {}", + dt + ))), + } + } } diff --git a/src/daft-dsl/src/functions/list/max.rs b/src/daft-dsl/src/functions/list/max.rs new file mode 100644 index 0000000000..34d788f4aa --- /dev/null +++ b/src/daft-dsl/src/functions/list/max.rs @@ -0,0 +1,45 @@ +use crate::Expr; +use daft_core::{datatypes::Field, schema::Schema, series::Series}; + +use common_error::{DaftError, DaftResult}; + +use super::super::FunctionEvaluator; + +pub(super) struct MaxEvaluator {} + +impl FunctionEvaluator for MaxEvaluator { + fn fn_name(&self) -> &'static str { + "max" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { + match inputs { + [input] => { + let field = input.to_field(schema)?.to_exploded_field()?; + + if field.dtype.is_numeric() { + Ok(field) + } else { + Err(DaftError::TypeError(format!( + "Expected input to be numeric, got {}", + field.dtype + ))) + } + } + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult { + match inputs { + [input] => Ok(input.list_max()?), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } +} diff --git a/src/daft-dsl/src/functions/list/min.rs b/src/daft-dsl/src/functions/list/min.rs new file mode 100644 index 0000000000..a2e1988e3e --- /dev/null +++ b/src/daft-dsl/src/functions/list/min.rs @@ -0,0 +1,45 @@ +use crate::Expr; +use daft_core::{datatypes::Field, schema::Schema, series::Series}; + +use common_error::{DaftError, DaftResult}; + +use super::super::FunctionEvaluator; + +pub(super) struct MinEvaluator {} + +impl FunctionEvaluator for MinEvaluator { + fn fn_name(&self) -> &'static str { + "min" + } + + fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { + match inputs { + [input] => { + let field = input.to_field(schema)?.to_exploded_field()?; + + if field.dtype.is_numeric() { + Ok(field) + } else { + Err(DaftError::TypeError(format!( + "Expected input to be numeric, got {}", + field.dtype + ))) + } + } + _ => Err(DaftError::SchemaMismatch(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } + + fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult { + match inputs { + [input] => Ok(input.list_min()?), + _ => Err(DaftError::ValueError(format!( + "Expected 1 input arg, got {}", + inputs.len() + ))), + } + } +} diff --git a/src/daft-dsl/src/functions/list/mod.rs b/src/daft-dsl/src/functions/list/mod.rs index c78eaccec2..f673406f9d 100644 --- a/src/daft-dsl/src/functions/list/mod.rs +++ b/src/daft-dsl/src/functions/list/mod.rs @@ -2,7 +2,9 @@ mod count; mod explode; mod get; mod join; +mod max; mod mean; +mod min; mod sum; use count::CountEvaluator; @@ -10,7 +12,9 @@ use daft_core::CountMode; use explode::ExplodeEvaluator; use get::GetEvaluator; use join::JoinEvaluator; +use max::MaxEvaluator; use mean::MeanEvaluator; +use min::MinEvaluator; use serde::{Deserialize, Serialize}; use sum::SumEvaluator; @@ -26,6 +30,8 @@ pub enum ListExpr { Get, Sum, Mean, + Min, + Max, } impl ListExpr { @@ -39,6 +45,8 @@ impl ListExpr { Get => &GetEvaluator {}, Sum => &SumEvaluator {}, Mean => &MeanEvaluator {}, + Min => &MinEvaluator {}, + Max => &MaxEvaluator {}, } } } @@ -84,3 +92,17 @@ pub fn mean(input: &Expr) -> Expr { inputs: vec![input.clone()], } } + +pub fn min(input: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::List(ListExpr::Min), + inputs: vec![input.clone()], + } +} + +pub fn max(input: &Expr) -> Expr { + Expr::Function { + func: super::FunctionExpr::List(ListExpr::Max), + inputs: vec![input.clone()], + } +} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 5e7eef93d9..1302005dee 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -492,6 +492,16 @@ impl PyExpr { Ok(mean(&self.expr).into()) } + pub fn list_min(&self) -> PyResult { + use crate::functions::list::min; + Ok(min(&self.expr).into()) + } + + pub fn list_max(&self) -> PyResult { + use crate::functions::list::max; + Ok(max(&self.expr).into()) + } + pub fn struct_get(&self, name: &str) -> PyResult { use crate::functions::struct_::get; Ok(get(&self.expr, name).into()) From 9cfbf7076a59909256d04fdf3bce68ad2aeceada Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Fri, 22 Mar 2024 13:22:11 -0700 Subject: [PATCH 04/11] add list agg tests --- tests/table/list/test_list_count_lengths.py | 52 +++++++++++++++++++++ tests/table/list/test_list_lengths.py | 10 ---- tests/table/list/test_list_numeric_aggs.py | 35 ++++++++++++++ 3 files changed, 87 insertions(+), 10 deletions(-) create mode 100644 tests/table/list/test_list_count_lengths.py delete mode 100644 tests/table/list/test_list_lengths.py create mode 100644 tests/table/list/test_list_numeric_aggs.py diff --git a/tests/table/list/test_list_count_lengths.py b/tests/table/list/test_list_count_lengths.py new file mode 100644 index 0000000000..321219d6da --- /dev/null +++ b/tests/table/list/test_list_count_lengths.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import pytest + +from daft.daft import CountMode +from daft.datatype import DataType +from daft.expressions import col +from daft.table import MicroPartition + + +@pytest.fixture +def table(): + return MicroPartition.from_pydict({"col": [None, [], ["a"], [None], ["a", "a"], ["a", None], ["a", None, "a"]]}) + + +@pytest.fixture +def fixed_table(): + table = MicroPartition.from_pydict({"col": [["a", "a"], ["a", "a"], ["a", None], [None, None], None]}) + fixed_dtype = DataType.fixed_size_list(DataType.string(), 2) + return table.eval_expression_list([col("col").cast(fixed_dtype)]) + + +def test_list_lengths(table): + result = table.eval_expression_list([col("col").list.lengths()]) + assert result.to_pydict() == {"col": [None, 0, 1, 1, 2, 2, 3]} + + +def test_fixed_list_lengths(fixed_table): + result = fixed_table.eval_expression_list([col("col").list.lengths()]) + assert result.to_pydict() == {"col": [2, 2, 2, 2, None]} + + +def test_list_count(table): + result = table.eval_expression_list([col("col").list.count(CountMode.All)]) + assert result.to_pydict() == {"col": [None, 0, 1, 1, 2, 2, 3]} + + result = table.eval_expression_list([col("col").list.count(CountMode.Valid)]) + assert result.to_pydict() == {"col": [None, 0, 1, 0, 2, 1, 2]} + + result = table.eval_expression_list([col("col").list.count(CountMode.Null)]) + assert result.to_pydict() == {"col": [None, 0, 0, 1, 0, 1, 1]} + + +def test_fixed_list_count(fixed_table): + result = fixed_table.eval_expression_list([col("col").list.count(CountMode.All)]) + assert result.to_pydict() == {"col": [2, 2, 2, 2, None]} + + result = fixed_table.eval_expression_list([col("col").list.count(CountMode.Valid)]) + assert result.to_pydict() == {"col": [2, 2, 1, 0, None]} + + result = fixed_table.eval_expression_list([col("col").list.count(CountMode.Null)]) + assert result.to_pydict() == {"col": [0, 0, 1, 2, None]} diff --git a/tests/table/list/test_list_lengths.py b/tests/table/list/test_list_lengths.py deleted file mode 100644 index a3520908dc..0000000000 --- a/tests/table/list/test_list_lengths.py +++ /dev/null @@ -1,10 +0,0 @@ -from __future__ import annotations - -from daft.expressions import col -from daft.table import MicroPartition - - -def test_list_lengths(): - table = MicroPartition.from_pydict({"col": [None, [], ["a"], [None], ["a", "a"], ["a", None], ["a", None, "a"]]}) - result = table.eval_expression_list([col("col").list.lengths()]) - assert result.to_pydict() == {"col": [None, 0, 1, 1, 2, 2, 3]} diff --git a/tests/table/list/test_list_numeric_aggs.py b/tests/table/list/test_list_numeric_aggs.py new file mode 100644 index 0000000000..c77d2fe3dc --- /dev/null +++ b/tests/table/list/test_list_numeric_aggs.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import pytest + +from daft.datatype import DataType +from daft.expressions import col +from daft.table import MicroPartition + +table = MicroPartition.from_pydict({"a": [[1, 2], [3, 4], [5, None], [None, None], None]}) +fixed_dtype = DataType.fixed_size_list(DataType.int64(), 2) +fixed_table = table.eval_expression_list([col("a").cast(fixed_dtype)]) + + +@pytest.mark.parametrize("table", [table, fixed_table]) +def test_list_sum(table): + result = table.eval_expression_list([col("a").list.sum()]) + assert result.to_pydict() == {"a": [3, 7, 5, None, None]} + + +@pytest.mark.parametrize("table", [table, fixed_table]) +def test_list_mean(table): + result = table.eval_expression_list([col("a").list.mean()]) + assert result.to_pydict() == {"a": [1.5, 3.5, 5, None, None]} + + +@pytest.mark.parametrize("table", [table, fixed_table]) +def test_list_min(table): + result = table.eval_expression_list([col("a").list.min()]) + assert result.to_pydict() == {"a": [1, 3, 5, None, None]} + + +@pytest.mark.parametrize("table", [table, fixed_table]) +def test_list_max(table): + result = table.eval_expression_list([col("a").list.max()]) + assert result.to_pydict() == {"a": [2, 4, 5, None, None]} From bfe1e6530d7877cab1b55d865990474d0b54d739 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Wed, 3 Apr 2024 17:16:43 -0700 Subject: [PATCH 05/11] revise code based on review --- src/daft-core/src/array/ops/list.rs | 310 +++++++++++----------- src/daft-core/src/datatypes/binary_ops.rs | 28 ++ src/daft-core/src/datatypes/field.rs | 33 --- src/daft-core/src/datatypes/mod.rs | 2 +- src/daft-dsl/src/expr.rs | 19 +- src/daft-dsl/src/functions/list/mean.rs | 15 +- src/daft-dsl/src/functions/list/sum.rs | 14 +- 7 files changed, 221 insertions(+), 200 deletions(-) diff --git a/src/daft-core/src/array/ops/list.rs b/src/daft-core/src/array/ops/list.rs index bb3a40e26f..b706f1dc54 100644 --- a/src/daft-core/src/array/ops/list.rs +++ b/src/daft-core/src/array/ops/list.rs @@ -19,8 +19,6 @@ use common_error::{DaftError, DaftResult}; use super::as_arrow::AsArrow; -type Arrow2AggregateFn<'a, T> = &'a dyn Fn(&PrimitiveArray) -> Option; - fn join_arrow_list_of_utf8s( list_element: Option<&dyn arrow2::array::Array>, delimiter_str: &str, @@ -47,34 +45,82 @@ fn join_arrow_list_of_utf8s( }) } +trait ListChildAggable { + fn agg_data_array(&self, arr: &DataArray, op: U) -> DaftResult + where + T: DaftNumericType, + ::Simd: + Add::Simd> + Sum, + U: Fn(&PrimitiveArray) -> Option; + + fn _min(&self, flat_child: &Series) -> DaftResult { + use crate::datatypes::DataType::*; + use arrow2::compute::aggregate::min_primitive; + + match flat_child.data_type() { + Int8 => self.agg_data_array(flat_child.i8()?, &min_primitive), + Int16 => self.agg_data_array(flat_child.i16()?, &min_primitive), + Int32 => self.agg_data_array(flat_child.i32()?, &min_primitive), + Int64 => self.agg_data_array(flat_child.i64()?, &min_primitive), + UInt8 => self.agg_data_array(flat_child.u8()?, &min_primitive), + UInt16 => self.agg_data_array(flat_child.u16()?, &min_primitive), + UInt32 => self.agg_data_array(flat_child.u32()?, &min_primitive), + UInt64 => self.agg_data_array(flat_child.u64()?, &min_primitive), + Float32 => self.agg_data_array(flat_child.f32()?, &min_primitive), + Float64 => self.agg_data_array(flat_child.f64()?, &min_primitive), + other => Err(DaftError::TypeError(format!( + "Min not implemented for {}", + other + ))), + } + } + + fn _max(&self, flat_child: &Series) -> DaftResult { + use crate::datatypes::DataType::*; + use arrow2::compute::aggregate::max_primitive; + + match flat_child.data_type() { + Int8 => self.agg_data_array(flat_child.i8()?, &max_primitive), + Int16 => self.agg_data_array(flat_child.i16()?, &max_primitive), + Int32 => self.agg_data_array(flat_child.i32()?, &max_primitive), + Int64 => self.agg_data_array(flat_child.i64()?, &max_primitive), + UInt8 => self.agg_data_array(flat_child.u8()?, &max_primitive), + UInt16 => self.agg_data_array(flat_child.u16()?, &max_primitive), + UInt32 => self.agg_data_array(flat_child.u32()?, &max_primitive), + UInt64 => self.agg_data_array(flat_child.u64()?, &max_primitive), + Float32 => self.agg_data_array(flat_child.f32()?, &max_primitive), + Float64 => self.agg_data_array(flat_child.f64()?, &max_primitive), + other => Err(DaftError::TypeError(format!( + "Max not implemented for {}", + other + ))), + } + } +} + impl ListArray { pub fn count(&self, mode: CountMode) -> DaftResult { - let counts = match mode { - CountMode::All => self.offsets().lengths().map(|l| l as u64).collect(), - CountMode::Valid => self + let counts = match (mode, self.flat_child.validity()) { + (CountMode::All, _) | (CountMode::Valid, None) => { + self.offsets().lengths().map(|l| l as u64).collect() + } + (CountMode::Valid, Some(validity)) => self .offsets() .windows(2) .map(|w| { - if let Some(validity) = self.flat_child.validity() { - (w[0]..w[1]) - .map(|i| validity.get_bit(i as usize) as u64) - .sum() - } else { - (w[1] - w[0]) as u64 - } + (w[0]..w[1]) + .map(|i| validity.get_bit(i as usize) as u64) + .sum() }) .collect(), - CountMode::Null => self + (CountMode::Null, None) => repeat(0).take(self.offsets().len() - 1).collect(), + (CountMode::Null, Some(validity)) => self .offsets() .windows(2) .map(|w| { - if let Some(validity) = self.flat_child.validity() { - (w[0]..w[1]) - .map(|i| !validity.get_bit(i as usize) as u64) - .sum() - } else { - (w[1] - w[0]) as u64 - } + (w[0]..w[1]) + .map(|i| !validity.get_bit(i as usize) as u64) + .sum() }) .collect(), }; @@ -204,36 +250,6 @@ impl ListArray { } } - fn agg_data_array( - &self, - arr: &DataArray, - op: Arrow2AggregateFn, - ) -> DaftResult - where - ::Simd: - Add::Simd> + Sum, - { - let sums = arrow2::types::IndexRange::new(0, self.len() as u64) - .map(|i| i as usize) - .map(|i| { - if let Some(validity) = self.validity() && !validity.get_bit(i) { - return None; - } - - let start = *self.offsets().get(i).unwrap() as usize; - let end = *self.offsets().get(i + 1).unwrap() as usize; - - let slice = arr.slice(start, end).unwrap(); - let slice_arr = slice.as_arrow(); - - op(slice_arr) - }); - - let array = arrow2::array::PrimitiveArray::from_trusted_len_iter(sums).boxed(); - - Series::try_from((self.name(), array)) - } - pub fn sum(&self) -> DaftResult { use crate::datatypes::DataType::*; use arrow2::compute::aggregate::sum_primitive; @@ -289,81 +305,66 @@ impl ListArray { } pub fn min(&self) -> DaftResult { - use crate::datatypes::DataType::*; - use arrow2::compute::aggregate::min_primitive; - - match self.flat_child.data_type() { - Int8 => self.agg_data_array(self.flat_child.i8()?, &min_primitive), - Int16 => self.agg_data_array(self.flat_child.i16()?, &min_primitive), - Int32 => self.agg_data_array(self.flat_child.i32()?, &min_primitive), - Int64 => self.agg_data_array(self.flat_child.i64()?, &min_primitive), - UInt8 => self.agg_data_array(self.flat_child.u8()?, &min_primitive), - UInt16 => self.agg_data_array(self.flat_child.u16()?, &min_primitive), - UInt32 => self.agg_data_array(self.flat_child.u32()?, &min_primitive), - UInt64 => self.agg_data_array(self.flat_child.u64()?, &min_primitive), - Float32 => self.agg_data_array(self.flat_child.f32()?, &min_primitive), - Float64 => self.agg_data_array(self.flat_child.f64()?, &min_primitive), - other => Err(DaftError::TypeError(format!( - "Min not implemented for {}", - other - ))), - } + self._min(&self.flat_child) } pub fn max(&self) -> DaftResult { - use crate::datatypes::DataType::*; - use arrow2::compute::aggregate::max_primitive; + self._max(&self.flat_child) + } +} - match self.flat_child.data_type() { - Int8 => self.agg_data_array(self.flat_child.i8()?, &max_primitive), - Int16 => self.agg_data_array(self.flat_child.i16()?, &max_primitive), - Int32 => self.agg_data_array(self.flat_child.i32()?, &max_primitive), - Int64 => self.agg_data_array(self.flat_child.i64()?, &max_primitive), - UInt8 => self.agg_data_array(self.flat_child.u8()?, &max_primitive), - UInt16 => self.agg_data_array(self.flat_child.u16()?, &max_primitive), - UInt32 => self.agg_data_array(self.flat_child.u32()?, &max_primitive), - UInt64 => self.agg_data_array(self.flat_child.u64()?, &max_primitive), - Float32 => self.agg_data_array(self.flat_child.f32()?, &max_primitive), - Float64 => self.agg_data_array(self.flat_child.f64()?, &max_primitive), - other => Err(DaftError::TypeError(format!( - "Max not implemented for {}", - other - ))), - } +impl ListChildAggable for ListArray { + fn agg_data_array(&self, arr: &DataArray, op: U) -> DaftResult + where + T: DaftNumericType, + ::Simd: + Add::Simd> + Sum, + U: Fn(&PrimitiveArray) -> Option, + { + let aggs = arrow2::types::IndexRange::new(0, self.len() as u64) + .map(|i| i as usize) + .map(|i| { + if let Some(validity) = self.validity() && !validity.get_bit(i) { + return None; + } + + let start = *self.offsets().get(i).unwrap() as usize; + let end = *self.offsets().get(i + 1).unwrap() as usize; + + let slice = arr.slice(start, end).unwrap(); + let slice_arr = slice.as_arrow(); + + op(slice_arr) + }); + + let array = arrow2::array::PrimitiveArray::from_trusted_len_iter(aggs).boxed(); + + Series::try_from((self.name(), array)) } } impl FixedSizeListArray { pub fn count(&self, mode: CountMode) -> DaftResult { let size = self.fixed_element_len(); - let counts = match mode { - CountMode::All => repeat(size as u64).take(self.len()).collect(), - CountMode::Valid => { - if let Some(validity) = self.flat_child.validity() { - (0..self.len()) - .map(|i| { - (0..size) - .map(|j| validity.get_bit(i * size + j) as u64) - .sum() - }) - .collect() - } else { - repeat(size as u64).take(self.len()).collect() - } - } - CountMode::Null => { - if let Some(validity) = self.flat_child.validity() { - (0..self.len()) - .map(|i| { - (0..size) - .map(|j| !validity.get_bit(i * size + j) as u64) - .sum() - }) - .collect() - } else { - repeat(0).take(self.len()).collect() - } + let counts = match (mode, self.flat_child.validity()) { + (CountMode::All, _) | (CountMode::Valid, None) => { + repeat(size as u64).take(self.len()).collect() } + (CountMode::Valid, Some(validity)) => (0..self.len()) + .map(|i| { + (0..size) + .map(|j| validity.get_bit(i * size + j) as u64) + .sum() + }) + .collect(), + (CountMode::Null, None) => repeat(0).take(self.len()).collect(), + (CountMode::Null, Some(validity)) => (0..self.len()) + .map(|i| { + (0..size) + .map(|j| !validity.get_bit(i * size + j) as u64) + .sum() + }) + .collect(), }; let array = Box::new( @@ -480,17 +481,15 @@ impl FixedSizeListArray { } } - fn agg_data_array( - &self, - arr: &DataArray, - op: Arrow2AggregateFn, - ) -> DaftResult + fn agg_data_array(&self, arr: &DataArray, op: U) -> DaftResult where + T: DaftNumericType, ::Simd: Add::Simd> + Sum, + U: Fn(&PrimitiveArray) -> Option, { let step = self.fixed_element_len(); - let sums = arrow2::types::IndexRange::new(0, self.len() as u64) + let aggs = arrow2::types::IndexRange::new(0, self.len() as u64) .map(|i| i as usize) .map(|i| { if let Some(validity) = self.validity() && !validity.get_bit(i) { @@ -506,7 +505,7 @@ impl FixedSizeListArray { op(slice_arr) }); - let array = arrow2::array::PrimitiveArray::from_trusted_len_iter(sums).boxed(); + let array = arrow2::array::PrimitiveArray::from_trusted_len_iter(aggs).boxed(); Series::try_from((self.name(), array)) } @@ -566,46 +565,41 @@ impl FixedSizeListArray { } pub fn min(&self) -> DaftResult { - use crate::datatypes::DataType::*; - use arrow2::compute::aggregate::min_primitive; - - match self.flat_child.data_type() { - Int8 => self.agg_data_array(self.flat_child.i8()?, &min_primitive), - Int16 => self.agg_data_array(self.flat_child.i16()?, &min_primitive), - Int32 => self.agg_data_array(self.flat_child.i32()?, &min_primitive), - Int64 => self.agg_data_array(self.flat_child.i64()?, &min_primitive), - UInt8 => self.agg_data_array(self.flat_child.u8()?, &min_primitive), - UInt16 => self.agg_data_array(self.flat_child.u16()?, &min_primitive), - UInt32 => self.agg_data_array(self.flat_child.u32()?, &min_primitive), - UInt64 => self.agg_data_array(self.flat_child.u64()?, &min_primitive), - Float32 => self.agg_data_array(self.flat_child.f32()?, &min_primitive), - Float64 => self.agg_data_array(self.flat_child.f64()?, &min_primitive), - other => Err(DaftError::TypeError(format!( - "Min not implemented for {}", - other - ))), - } + self._min(&self.flat_child) } pub fn max(&self) -> DaftResult { - use crate::datatypes::DataType::*; - use arrow2::compute::aggregate::max_primitive; + self._max(&self.flat_child) + } +} - match self.flat_child.data_type() { - Int8 => self.agg_data_array(self.flat_child.i8()?, &max_primitive), - Int16 => self.agg_data_array(self.flat_child.i16()?, &max_primitive), - Int32 => self.agg_data_array(self.flat_child.i32()?, &max_primitive), - Int64 => self.agg_data_array(self.flat_child.i64()?, &max_primitive), - UInt8 => self.agg_data_array(self.flat_child.u8()?, &max_primitive), - UInt16 => self.agg_data_array(self.flat_child.u16()?, &max_primitive), - UInt32 => self.agg_data_array(self.flat_child.u32()?, &max_primitive), - UInt64 => self.agg_data_array(self.flat_child.u64()?, &max_primitive), - Float32 => self.agg_data_array(self.flat_child.f32()?, &max_primitive), - Float64 => self.agg_data_array(self.flat_child.f64()?, &max_primitive), - other => Err(DaftError::TypeError(format!( - "Max not implemented for {}", - other - ))), - } +impl ListChildAggable for FixedSizeListArray { + fn agg_data_array(&self, arr: &DataArray, op: U) -> DaftResult + where + T: DaftNumericType, + ::Simd: + Add::Simd> + Sum, + U: Fn(&PrimitiveArray) -> Option, + { + let step = self.fixed_element_len(); + let aggs = arrow2::types::IndexRange::new(0, self.len() as u64) + .map(|i| i as usize) + .map(|i| { + if let Some(validity) = self.validity() && !validity.get_bit(i) { + return None; + } + + let start = i * step; + let end = (i + 1) * step; + + let slice = arr.slice(start, end).unwrap(); + let slice_arr = slice.as_arrow(); + + op(slice_arr) + }); + + let array = arrow2::array::PrimitiveArray::from_trusted_len_iter(aggs).boxed(); + + Series::try_from((self.name(), array)) } } diff --git a/src/daft-core/src/datatypes/binary_ops.rs b/src/daft-core/src/datatypes/binary_ops.rs index 6d08a6312c..a0c8f1bc91 100644 --- a/src/daft-core/src/datatypes/binary_ops.rs +++ b/src/daft-core/src/datatypes/binary_ops.rs @@ -307,3 +307,31 @@ pub fn try_numeric_supertype(l: &DataType, r: &DataType) -> DaftResult l, r ))) } + +/// Get the data type that the sum of a column of the given data type should be casted to. +pub fn try_sum_supertype(dtype: &DataType) -> DaftResult { + use DataType::*; + match dtype { + Int8 | Int16 | Int32 | Int64 => Ok(Int64), + UInt8 | UInt16 | UInt32 | UInt64 => Ok(UInt64), + Float32 => Ok(Float32), + Float64 => Ok(Float64), + other => Err(DaftError::TypeError(format!( + "Invalid argument to sum supertype: {}", + other + ))), + } +} + +/// Get the data type that the mean of a column of the given data type should be casted to. +pub fn try_mean_supertype(dtype: &DataType) -> DaftResult { + use DataType::*; + if dtype.is_numeric() { + Ok(Float64) + } else { + Err(DaftError::TypeError(format!( + "Invalid argument to mean supertype: {}", + dtype + ))) + } +} diff --git a/src/daft-core/src/datatypes/field.rs b/src/daft-core/src/datatypes/field.rs index 33f523f1ba..9362c7c012 100644 --- a/src/daft-core/src/datatypes/field.rs +++ b/src/daft-core/src/datatypes/field.rs @@ -128,39 +128,6 @@ impl Field { ))), } } - - pub fn to_sum(&self) -> DaftResult { - Ok(Field::new( - self.name.as_str(), - match &self.dtype { - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { - DataType::Int64 - } - DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { - DataType::UInt64 - } - DataType::Float32 => DataType::Float32, - DataType::Float64 => DataType::Float64, - other => { - return Err(DaftError::TypeError(format!( - "Expected input to sum() to be numeric but received dtype {} for column \"{}\"", - other, self.name, - ))) - } - }, - )) - } - - pub fn to_mean(&self) -> DaftResult { - if self.dtype.is_numeric() { - Ok(Field::new(self.name.as_str(), DataType::Float64)) - } else { - Err(DaftError::TypeError(format!( - "Numeric mean is not implemented for column \"{}\" of type {}", - self.name, self.dtype, - ))) - } - } } impl From<&ArrowField> for Field { diff --git a/src/daft-core/src/datatypes/mod.rs b/src/daft-core/src/datatypes/mod.rs index 68d9c55464..2ed24633a5 100644 --- a/src/daft-core/src/datatypes/mod.rs +++ b/src/daft-core/src/datatypes/mod.rs @@ -12,7 +12,7 @@ use arrow2::{ compute::comparison::Simd8, types::{simd::Simd, NativeType}, }; -pub use binary_ops::try_physical_supertype; +pub use binary_ops::{try_mean_supertype, try_physical_supertype, try_sum_supertype}; pub use dtype::DataType; pub use field::Field; pub use field::FieldID; diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index a60ff226a2..aefa6b3c22 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -1,7 +1,6 @@ use daft_core::{ count_mode::CountMode, - datatypes::DataType, - datatypes::{Field, FieldID}, + datatypes::{try_sum_supertype, DataType, Field, FieldID}, schema::Schema, utils::supertype::try_get_supertype, }; @@ -159,8 +158,20 @@ impl AggExpr { let field = expr.to_field(schema)?; Ok(Field::new(field.name.as_str(), DataType::UInt64)) } - Sum(expr) => expr.to_field(schema)?.to_sum(), - Mean(expr) => expr.to_field(schema)?.to_mean(), + Sum(expr) => { + let field = expr.to_field(schema)?; + Ok(Field::new( + field.name.as_str(), + try_sum_supertype(&field.dtype)?, + )) + } + Mean(expr) => { + let field = expr.to_field(schema)?; + Ok(Field::new( + field.name.as_str(), + try_sum_supertype(&field.dtype)?, + )) + } Min(expr) | Max(expr) | AnyValue(expr, _) => { let field = expr.to_field(schema)?; Ok(Field::new(field.name.as_str(), field.dtype)) diff --git a/src/daft-dsl/src/functions/list/mean.rs b/src/daft-dsl/src/functions/list/mean.rs index aeac9bc388..47d9a0db71 100644 --- a/src/daft-dsl/src/functions/list/mean.rs +++ b/src/daft-dsl/src/functions/list/mean.rs @@ -1,5 +1,10 @@ use crate::Expr; -use daft_core::{datatypes::Field, schema::Schema, series::Series, IntoSeries}; +use daft_core::{ + datatypes::{try_mean_supertype, Field}, + schema::Schema, + series::Series, + IntoSeries, +}; use common_error::{DaftError, DaftResult}; @@ -14,7 +19,13 @@ impl FunctionEvaluator for MeanEvaluator { fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { match inputs { - [input] => Ok(input.to_field(schema)?.to_exploded_field()?.to_mean()?), + [input] => { + let field = input.to_field(schema)?; + Ok(Field::new( + field.name.as_str(), + try_mean_supertype(&field.dtype)?, + )) + } _ => Err(DaftError::SchemaMismatch(format!( "Expected 1 input arg, got {}", inputs.len() diff --git a/src/daft-dsl/src/functions/list/sum.rs b/src/daft-dsl/src/functions/list/sum.rs index 484aa65a75..061f5b5469 100644 --- a/src/daft-dsl/src/functions/list/sum.rs +++ b/src/daft-dsl/src/functions/list/sum.rs @@ -1,5 +1,9 @@ use crate::Expr; -use daft_core::{datatypes::Field, schema::Schema, series::Series}; +use daft_core::{ + datatypes::{try_sum_supertype, Field}, + schema::Schema, + series::Series, +}; use common_error::{DaftError, DaftResult}; @@ -14,7 +18,13 @@ impl FunctionEvaluator for SumEvaluator { fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { match inputs { - [input] => Ok(input.to_field(schema)?.to_exploded_field()?.to_sum()?), + [input] => { + let field = input.to_field(schema)?; + Ok(Field::new( + field.name.as_str(), + try_sum_supertype(&field.dtype)?, + )) + } _ => Err(DaftError::SchemaMismatch(format!( "Expected 1 input arg, got {}", inputs.len() From 93b23526e6b051435ea95a6d17c84dedb29c8096 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Thu, 4 Apr 2024 10:37:26 -0700 Subject: [PATCH 06/11] fix list agg field resolution --- src/daft-dsl/src/functions/list/mean.rs | 6 +++--- src/daft-dsl/src/functions/list/sum.rs | 7 ++++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/daft-dsl/src/functions/list/mean.rs b/src/daft-dsl/src/functions/list/mean.rs index 47d9a0db71..dd569461ef 100644 --- a/src/daft-dsl/src/functions/list/mean.rs +++ b/src/daft-dsl/src/functions/list/mean.rs @@ -20,10 +20,10 @@ impl FunctionEvaluator for MeanEvaluator { fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { match inputs { [input] => { - let field = input.to_field(schema)?; + let inner_field = input.to_field(schema)?.to_exploded_field()?; Ok(Field::new( - field.name.as_str(), - try_mean_supertype(&field.dtype)?, + inner_field.name.as_str(), + try_mean_supertype(&inner_field.dtype)?, )) } _ => Err(DaftError::SchemaMismatch(format!( diff --git a/src/daft-dsl/src/functions/list/sum.rs b/src/daft-dsl/src/functions/list/sum.rs index 061f5b5469..88ec0a56cc 100644 --- a/src/daft-dsl/src/functions/list/sum.rs +++ b/src/daft-dsl/src/functions/list/sum.rs @@ -19,10 +19,11 @@ impl FunctionEvaluator for SumEvaluator { fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { match inputs { [input] => { - let field = input.to_field(schema)?; + let inner_field = input.to_field(schema)?.to_exploded_field()?; + Ok(Field::new( - field.name.as_str(), - try_sum_supertype(&field.dtype)?, + inner_field.name.as_str(), + try_sum_supertype(&inner_field.dtype)?, )) } _ => Err(DaftError::SchemaMismatch(format!( From 25538a582d76cfc4f29e5ba010e1da138ac0acf0 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Thu, 4 Apr 2024 15:49:17 -0700 Subject: [PATCH 07/11] change aggregations to use sliced series aggs --- src/daft-core/src/array/ops/list.rs | 381 +++++------------------ src/daft-core/src/python/series.rs | 2 +- src/daft-core/src/series/ops/list.rs | 10 +- src/daft-dsl/src/functions/list/count.rs | 4 +- src/daft-dsl/src/functions/list/mean.rs | 3 +- src/daft-table/src/ops/explode.rs | 2 +- 6 files changed, 93 insertions(+), 309 deletions(-) diff --git a/src/daft-core/src/array/ops/list.rs b/src/daft-core/src/array/ops/list.rs index b706f1dc54..5604465a64 100644 --- a/src/daft-core/src/array/ops/list.rs +++ b/src/daft-core/src/array/ops/list.rs @@ -1,21 +1,17 @@ use std::iter::repeat; -use std::ops::Add; -use crate::array::DataArray; use crate::array::{ growable::{make_growable, Growable}, FixedSizeListArray, ListArray, }; -use crate::datatypes::{DaftNumericType, Float64Array, Int64Array, UInt64Array, Utf8Array}; +use crate::datatypes::{Int64Array, Utf8Array}; use crate::{CountMode, DataType}; use crate::series::Series; use arrow2; -use arrow2::array::PrimitiveArray; -use arrow2::compute::aggregate::Sum; -use common_error::{DaftError, DaftResult}; +use common_error::DaftResult; use super::as_arrow::AsArrow; @@ -45,92 +41,7 @@ fn join_arrow_list_of_utf8s( }) } -trait ListChildAggable { - fn agg_data_array(&self, arr: &DataArray, op: U) -> DaftResult - where - T: DaftNumericType, - ::Simd: - Add::Simd> + Sum, - U: Fn(&PrimitiveArray) -> Option; - - fn _min(&self, flat_child: &Series) -> DaftResult { - use crate::datatypes::DataType::*; - use arrow2::compute::aggregate::min_primitive; - - match flat_child.data_type() { - Int8 => self.agg_data_array(flat_child.i8()?, &min_primitive), - Int16 => self.agg_data_array(flat_child.i16()?, &min_primitive), - Int32 => self.agg_data_array(flat_child.i32()?, &min_primitive), - Int64 => self.agg_data_array(flat_child.i64()?, &min_primitive), - UInt8 => self.agg_data_array(flat_child.u8()?, &min_primitive), - UInt16 => self.agg_data_array(flat_child.u16()?, &min_primitive), - UInt32 => self.agg_data_array(flat_child.u32()?, &min_primitive), - UInt64 => self.agg_data_array(flat_child.u64()?, &min_primitive), - Float32 => self.agg_data_array(flat_child.f32()?, &min_primitive), - Float64 => self.agg_data_array(flat_child.f64()?, &min_primitive), - other => Err(DaftError::TypeError(format!( - "Min not implemented for {}", - other - ))), - } - } - - fn _max(&self, flat_child: &Series) -> DaftResult { - use crate::datatypes::DataType::*; - use arrow2::compute::aggregate::max_primitive; - - match flat_child.data_type() { - Int8 => self.agg_data_array(flat_child.i8()?, &max_primitive), - Int16 => self.agg_data_array(flat_child.i16()?, &max_primitive), - Int32 => self.agg_data_array(flat_child.i32()?, &max_primitive), - Int64 => self.agg_data_array(flat_child.i64()?, &max_primitive), - UInt8 => self.agg_data_array(flat_child.u8()?, &max_primitive), - UInt16 => self.agg_data_array(flat_child.u16()?, &max_primitive), - UInt32 => self.agg_data_array(flat_child.u32()?, &max_primitive), - UInt64 => self.agg_data_array(flat_child.u64()?, &max_primitive), - Float32 => self.agg_data_array(flat_child.f32()?, &max_primitive), - Float64 => self.agg_data_array(flat_child.f64()?, &max_primitive), - other => Err(DaftError::TypeError(format!( - "Max not implemented for {}", - other - ))), - } - } -} - impl ListArray { - pub fn count(&self, mode: CountMode) -> DaftResult { - let counts = match (mode, self.flat_child.validity()) { - (CountMode::All, _) | (CountMode::Valid, None) => { - self.offsets().lengths().map(|l| l as u64).collect() - } - (CountMode::Valid, Some(validity)) => self - .offsets() - .windows(2) - .map(|w| { - (w[0]..w[1]) - .map(|i| validity.get_bit(i as usize) as u64) - .sum() - }) - .collect(), - (CountMode::Null, None) => repeat(0).take(self.offsets().len() - 1).collect(), - (CountMode::Null, Some(validity)) => self - .offsets() - .windows(2) - .map(|w| { - (w[0]..w[1]) - .map(|i| !validity.get_bit(i as usize) as u64) - .sum() - }) - .collect(), - }; - - let array = Box::new( - arrow2::array::PrimitiveArray::from_vec(counts).with_validity(self.validity().cloned()), - ); - Ok(UInt64Array::from((self.name(), array))) - } - pub fn explode(&self) -> DaftResult { let offsets = self.offsets(); @@ -250,129 +161,66 @@ impl ListArray { } } - pub fn sum(&self) -> DaftResult { - use crate::datatypes::DataType::*; - use arrow2::compute::aggregate::sum_primitive; - - match self.flat_child.data_type() { - Int8 | Int16 | Int32 | Int64 => { - let casted = self.flat_child.cast(&Int64)?; - let arr = casted.i64()?; + fn agg_helper(&self, op: T) -> DaftResult + where + T: Fn(&Series) -> DaftResult, + { + let aggs = if let Some(validity) = self.validity() { + let test_result = op(&Series::empty("", self.flat_child.data_type()))?; - self.agg_data_array(arr, &sum_primitive) - } - UInt8 | UInt16 | UInt32 | UInt64 => { - let casted = self.flat_child.cast(&UInt64)?; - let arr = casted.u64()?; + (0..self.len()) + .map(|i| { + if validity.get_bit(i) { + let start = *self.offsets().get(i).unwrap() as usize; + let end = *self.offsets().get(i + 1).unwrap() as usize; + + let slice = self.flat_child.slice(start, end)?; + op(&slice) + } else { + Ok(Series::full_null("", test_result.data_type(), 1)) + } + }) + .collect::>>()? + } else { + self.offsets() + .windows(2) + .map(|w| { + let start = w[0] as usize; + let end = w[1] as usize; - self.agg_data_array(arr, &sum_primitive) - } - Float32 => self.agg_data_array(self.flat_child.f32()?, &sum_primitive), - Float64 => self.agg_data_array(self.flat_child.f64()?, &sum_primitive), - other => Err(DaftError::TypeError(format!( - "Sum not implemented for {}", - other - ))), - } - } + let slice = self.flat_child.slice(start, end)?; + op(&slice) + }) + .collect::>>()? + }; - pub fn mean(&self) -> DaftResult { - use crate::datatypes::DataType::*; + let agg_refs: Vec<_> = aggs.iter().collect(); - match self.flat_child.data_type() { - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 | Float32 | Float64 => { - let counts = self.count(CountMode::Valid)?; - let sum_series = self.sum()?.cast(&Float64)?; - let sums = sum_series.f64()?; + Ok(Series::concat(agg_refs.as_slice())?.rename(self.name())) + } - let means = counts - .into_iter() - .zip(sums) - .map(|(count, sum)| match (count, sum) { - (Some(count), Some(sum)) if *count != 0 => Some(sum / (*count as f64)), - _ => None, - }); + pub fn count(&self, mode: CountMode) -> DaftResult { + self.agg_helper(|s| s.count(None, mode)) + } - let arr = Box::new(arrow2::array::PrimitiveArray::from_trusted_len_iter(means)); + pub fn sum(&self) -> DaftResult { + self.agg_helper(|s| s.sum(None)) + } - Ok(Float64Array::from((self.name(), arr))) - } - other => Err(DaftError::TypeError(format!( - "Mean not implemented for {}", - other - ))), - } + pub fn mean(&self) -> DaftResult { + self.agg_helper(|s| s.mean(None)) } pub fn min(&self) -> DaftResult { - self._min(&self.flat_child) + self.agg_helper(|s| s.min(None)) } pub fn max(&self) -> DaftResult { - self._max(&self.flat_child) - } -} - -impl ListChildAggable for ListArray { - fn agg_data_array(&self, arr: &DataArray, op: U) -> DaftResult - where - T: DaftNumericType, - ::Simd: - Add::Simd> + Sum, - U: Fn(&PrimitiveArray) -> Option, - { - let aggs = arrow2::types::IndexRange::new(0, self.len() as u64) - .map(|i| i as usize) - .map(|i| { - if let Some(validity) = self.validity() && !validity.get_bit(i) { - return None; - } - - let start = *self.offsets().get(i).unwrap() as usize; - let end = *self.offsets().get(i + 1).unwrap() as usize; - - let slice = arr.slice(start, end).unwrap(); - let slice_arr = slice.as_arrow(); - - op(slice_arr) - }); - - let array = arrow2::array::PrimitiveArray::from_trusted_len_iter(aggs).boxed(); - - Series::try_from((self.name(), array)) + self.agg_helper(|s| s.max(None)) } } impl FixedSizeListArray { - pub fn count(&self, mode: CountMode) -> DaftResult { - let size = self.fixed_element_len(); - let counts = match (mode, self.flat_child.validity()) { - (CountMode::All, _) | (CountMode::Valid, None) => { - repeat(size as u64).take(self.len()).collect() - } - (CountMode::Valid, Some(validity)) => (0..self.len()) - .map(|i| { - (0..size) - .map(|j| validity.get_bit(i * size + j) as u64) - .sum() - }) - .collect(), - (CountMode::Null, None) => repeat(0).take(self.len()).collect(), - (CountMode::Null, Some(validity)) => (0..self.len()) - .map(|i| { - (0..size) - .map(|j| !validity.get_bit(i * size + j) as u64) - .sum() - }) - .collect(), - }; - - let array = Box::new( - arrow2::array::PrimitiveArray::from_vec(counts).with_validity(self.validity().cloned()), - ); - Ok(UInt64Array::from((self.name(), array))) - } - pub fn explode(&self) -> DaftResult { let list_size = self.fixed_element_len(); let total_capacity = if list_size == 0 { @@ -481,125 +329,62 @@ impl FixedSizeListArray { } } - fn agg_data_array(&self, arr: &DataArray, op: U) -> DaftResult + fn agg_helper(&self, op: T) -> DaftResult where - T: DaftNumericType, - ::Simd: - Add::Simd> + Sum, - U: Fn(&PrimitiveArray) -> Option, + T: Fn(&Series) -> DaftResult, { let step = self.fixed_element_len(); - let aggs = arrow2::types::IndexRange::new(0, self.len() as u64) - .map(|i| i as usize) - .map(|i| { - if let Some(validity) = self.validity() && !validity.get_bit(i) { - return None; - } - let start = i * step; - let end = (i + 1) * step; + let aggs = if let Some(validity) = self.validity() { + let test_result = op(&Series::empty("", self.flat_child.data_type()))?; - let slice = arr.slice(start, end).unwrap(); - let slice_arr = slice.as_arrow(); + (0..self.len()) + .map(|i| { + if validity.get_bit(i) { + let start = i * step; + let end = (i + 1) * step; + + let slice = self.flat_child.slice(start, end)?; + op(&slice) + } else { + Ok(Series::full_null("", test_result.data_type(), 1)) + } + }) + .collect::>>()? + } else { + (0..self.len()) + .map(|i| { + let start = i * step; + let end = (i + 1) * step; - op(slice_arr) - }); + let slice = self.flat_child.slice(start, end)?; + op(&slice) + }) + .collect::>>()? + }; - let array = arrow2::array::PrimitiveArray::from_trusted_len_iter(aggs).boxed(); + let agg_refs: Vec<_> = aggs.iter().collect(); - Series::try_from((self.name(), array)) + Series::concat(agg_refs.as_slice()).map(|s| s.rename(self.name())) } - pub fn sum(&self) -> DaftResult { - use crate::datatypes::DataType::*; - use arrow2::compute::aggregate::sum_primitive; - - match self.flat_child.data_type() { - Int8 | Int16 | Int32 | Int64 => { - let casted = self.flat_child.cast(&Int64)?; - let arr = casted.i64()?; - - self.agg_data_array(arr, &sum_primitive) - } - UInt8 | UInt16 | UInt32 | UInt64 => { - let casted = self.flat_child.cast(&UInt64)?; - let arr = casted.u64()?; - - self.agg_data_array(arr, &sum_primitive) - } - Float32 => self.agg_data_array(self.flat_child.f32()?, &sum_primitive), - Float64 => self.agg_data_array(self.flat_child.f64()?, &sum_primitive), - other => Err(DaftError::TypeError(format!( - "Sum not implemented for {}", - other - ))), - } + pub fn count(&self, mode: CountMode) -> DaftResult { + self.agg_helper(|s| s.count(None, mode)) } - pub fn mean(&self) -> DaftResult { - use crate::datatypes::DataType::*; - - match self.flat_child.data_type() { - Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64 | Float32 | Float64 => { - let counts = self.count(CountMode::Valid)?; - let sum_series = self.sum()?.cast(&Float64)?; - let sums = sum_series.f64()?; - - let means = counts - .into_iter() - .zip(sums) - .map(|(count, sum)| match (count, sum) { - (Some(count), Some(sum)) if *count != 0 => Some(sum / (*count as f64)), - _ => None, - }); - - let arr = Box::new(arrow2::array::PrimitiveArray::from_trusted_len_iter(means)); + pub fn sum(&self) -> DaftResult { + self.agg_helper(|s| s.sum(None)) + } - Ok(Float64Array::from((self.name(), arr))) - } - other => Err(DaftError::TypeError(format!( - "Mean not implemented for {}", - other - ))), - } + pub fn mean(&self) -> DaftResult { + self.agg_helper(|s| s.mean(None)) } pub fn min(&self) -> DaftResult { - self._min(&self.flat_child) + self.agg_helper(|s| s.min(None)) } pub fn max(&self) -> DaftResult { - self._max(&self.flat_child) - } -} - -impl ListChildAggable for FixedSizeListArray { - fn agg_data_array(&self, arr: &DataArray, op: U) -> DaftResult - where - T: DaftNumericType, - ::Simd: - Add::Simd> + Sum, - U: Fn(&PrimitiveArray) -> Option, - { - let step = self.fixed_element_len(); - let aggs = arrow2::types::IndexRange::new(0, self.len() as u64) - .map(|i| i as usize) - .map(|i| { - if let Some(validity) = self.validity() && !validity.get_bit(i) { - return None; - } - - let start = i * step; - let end = (i + 1) * step; - - let slice = arr.slice(start, end).unwrap(); - let slice_arr = slice.as_arrow(); - - op(slice_arr) - }); - - let array = arrow2::array::PrimitiveArray::from_trusted_len_iter(aggs).boxed(); - - Series::try_from((self.name(), array)) + self.agg_helper(|s| s.max(None)) } } diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index e5ab0950d6..90234b2eb1 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -353,7 +353,7 @@ impl PySeries { } pub fn list_count(&self, mode: CountMode) -> PyResult { - Ok(self.series.list_count(mode)?.into_series().into()) + Ok(self.series.list_count(mode)?.into()) } pub fn list_get(&self, idx: &Self, default: &Self) -> PyResult { diff --git a/src/daft-core/src/series/ops/list.rs b/src/daft-core/src/series/ops/list.rs index a4cf81bb74..4f0d5fecd7 100644 --- a/src/daft-core/src/series/ops/list.rs +++ b/src/daft-core/src/series/ops/list.rs @@ -1,6 +1,6 @@ -use crate::datatypes::{DataType, Float64Array, UInt64Array, Utf8Array}; +use crate::datatypes::{DataType, UInt64Array, Utf8Array}; use crate::series::Series; -use crate::CountMode; +use crate::{CountMode, IntoSeries}; use common_error::DaftError; use common_error::DaftResult; @@ -18,7 +18,7 @@ impl Series { } } - pub fn list_count(&self, mode: CountMode) -> DaftResult { + pub fn list_count(&self, mode: CountMode) -> DaftResult { use DataType::*; match self.data_type() { @@ -35,7 +35,7 @@ impl Series { ) .with_validity(data_array.validity().cloned()), ); - Ok(UInt64Array::from((self.name(), array))) + Ok(UInt64Array::from((self.name(), array)).into_series()) } dt => Err(DaftError::TypeError(format!( "Count not implemented for {}", @@ -80,7 +80,7 @@ impl Series { } } - pub fn list_mean(&self) -> DaftResult { + pub fn list_mean(&self) -> DaftResult { match self.data_type() { DataType::List(_) => self.list()?.mean(), DataType::FixedSizeList(..) => self.fixed_size_list()?.mean(), diff --git a/src/daft-dsl/src/functions/list/count.rs b/src/daft-dsl/src/functions/list/count.rs index d6f72dac26..fc8c17e435 100644 --- a/src/daft-dsl/src/functions/list/count.rs +++ b/src/daft-dsl/src/functions/list/count.rs @@ -2,7 +2,7 @@ use crate::{functions::FunctionExpr, Expr}; use daft_core::{ datatypes::{DataType, Field}, schema::Schema, - series::{IntoSeries, Series}, + series::Series, }; use common_error::{DaftError, DaftResult}; @@ -53,7 +53,7 @@ impl FunctionEvaluator for CountEvaluator { _ => panic!("Expected List Count Expr, got {expr}"), }; - Ok(input.list_count(*mode)?.into_series()) + Ok(input.list_count(*mode)?) } _ => Err(DaftError::ValueError(format!( "Expected 1 input arg, got {}", diff --git a/src/daft-dsl/src/functions/list/mean.rs b/src/daft-dsl/src/functions/list/mean.rs index dd569461ef..c7409093b3 100644 --- a/src/daft-dsl/src/functions/list/mean.rs +++ b/src/daft-dsl/src/functions/list/mean.rs @@ -3,7 +3,6 @@ use daft_core::{ datatypes::{try_mean_supertype, Field}, schema::Schema, series::Series, - IntoSeries, }; use common_error::{DaftError, DaftResult}; @@ -35,7 +34,7 @@ impl FunctionEvaluator for MeanEvaluator { fn evaluate(&self, inputs: &[Series], _: &Expr) -> DaftResult { match inputs { - [input] => Ok(input.list_mean()?.into_series()), + [input] => Ok(input.list_mean()?), _ => Err(DaftError::ValueError(format!( "Expected 1 input arg, got {}", inputs.len() diff --git a/src/daft-table/src/ops/explode.rs b/src/daft-table/src/ops/explode.rs index bec57b8dfb..003da9d33b 100644 --- a/src/daft-table/src/ops/explode.rs +++ b/src/daft-table/src/ops/explode.rs @@ -80,7 +80,7 @@ impl Table { .collect::>>()?; let capacity_expected = exploded_columns.first().unwrap().len(); - let take_idx = lengths_to_indices(&first_len, capacity_expected)?.into_series(); + let take_idx = lengths_to_indices(first_len.u64()?, capacity_expected)?.into_series(); let mut new_series = self.columns.clone(); From edc30857982ae637227e2f88a6376e63918f0d93 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Thu, 4 Apr 2024 16:09:54 -0700 Subject: [PATCH 08/11] fix mean type --- src/daft-dsl/src/expr.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index aefa6b3c22..b4afa087ff 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -1,6 +1,6 @@ use daft_core::{ count_mode::CountMode, - datatypes::{try_sum_supertype, DataType, Field, FieldID}, + datatypes::{try_mean_supertype, try_sum_supertype, DataType, Field, FieldID}, schema::Schema, utils::supertype::try_get_supertype, }; @@ -169,7 +169,7 @@ impl AggExpr { let field = expr.to_field(schema)?; Ok(Field::new( field.name.as_str(), - try_sum_supertype(&field.dtype)?, + try_mean_supertype(&field.dtype)?, )) } Min(expr) | Max(expr) | AnyValue(expr, _) => { From c40a90c48d8aa6965da23f457c9f0caf6ee3c2ea Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Thu, 4 Apr 2024 16:36:23 -0700 Subject: [PATCH 09/11] add custom list count --- src/daft-core/src/array/ops/list.rs | 80 ++++++++++++++++++++---- src/daft-core/src/python/series.rs | 2 +- src/daft-core/src/series/ops/list.rs | 6 +- src/daft-dsl/src/functions/list/count.rs | 3 +- src/daft-table/src/ops/explode.rs | 2 +- 5 files changed, 75 insertions(+), 18 deletions(-) diff --git a/src/daft-core/src/array/ops/list.rs b/src/daft-core/src/array/ops/list.rs index 5604465a64..e6bac8a497 100644 --- a/src/daft-core/src/array/ops/list.rs +++ b/src/daft-core/src/array/ops/list.rs @@ -1,10 +1,13 @@ use std::iter::repeat; -use crate::array::{ - growable::{make_growable, Growable}, - FixedSizeListArray, ListArray, -}; use crate::datatypes::{Int64Array, Utf8Array}; +use crate::{ + array::{ + growable::{make_growable, Growable}, + FixedSizeListArray, ListArray, + }, + datatypes::UInt64Array, +}; use crate::{CountMode, DataType}; use crate::series::Series; @@ -42,6 +45,38 @@ fn join_arrow_list_of_utf8s( } impl ListArray { + pub fn count(&self, mode: CountMode) -> DaftResult { + let counts = match (mode, self.flat_child.validity()) { + (CountMode::All, _) | (CountMode::Valid, None) => { + self.offsets().lengths().map(|l| l as u64).collect() + } + (CountMode::Valid, Some(validity)) => self + .offsets() + .windows(2) + .map(|w| { + (w[0]..w[1]) + .map(|i| validity.get_bit(i as usize) as u64) + .sum() + }) + .collect(), + (CountMode::Null, None) => repeat(0).take(self.offsets().len() - 1).collect(), + (CountMode::Null, Some(validity)) => self + .offsets() + .windows(2) + .map(|w| { + (w[0]..w[1]) + .map(|i| !validity.get_bit(i as usize) as u64) + .sum() + }) + .collect(), + }; + + let array = Box::new( + arrow2::array::PrimitiveArray::from_vec(counts).with_validity(self.validity().cloned()), + ); + Ok(UInt64Array::from((self.name(), array))) + } + pub fn explode(&self) -> DaftResult { let offsets = self.offsets(); @@ -199,10 +234,6 @@ impl ListArray { Ok(Series::concat(agg_refs.as_slice())?.rename(self.name())) } - pub fn count(&self, mode: CountMode) -> DaftResult { - self.agg_helper(|s| s.count(None, mode)) - } - pub fn sum(&self) -> DaftResult { self.agg_helper(|s| s.sum(None)) } @@ -221,6 +252,35 @@ impl ListArray { } impl FixedSizeListArray { + pub fn count(&self, mode: CountMode) -> DaftResult { + let size = self.fixed_element_len(); + let counts = match (mode, self.flat_child.validity()) { + (CountMode::All, _) | (CountMode::Valid, None) => { + repeat(size as u64).take(self.len()).collect() + } + (CountMode::Valid, Some(validity)) => (0..self.len()) + .map(|i| { + (0..size) + .map(|j| validity.get_bit(i * size + j) as u64) + .sum() + }) + .collect(), + (CountMode::Null, None) => repeat(0).take(self.len()).collect(), + (CountMode::Null, Some(validity)) => (0..self.len()) + .map(|i| { + (0..size) + .map(|j| !validity.get_bit(i * size + j) as u64) + .sum() + }) + .collect(), + }; + + let array = Box::new( + arrow2::array::PrimitiveArray::from_vec(counts).with_validity(self.validity().cloned()), + ); + Ok(UInt64Array::from((self.name(), array))) + } + pub fn explode(&self) -> DaftResult { let list_size = self.fixed_element_len(); let total_capacity = if list_size == 0 { @@ -368,10 +428,6 @@ impl FixedSizeListArray { Series::concat(agg_refs.as_slice()).map(|s| s.rename(self.name())) } - pub fn count(&self, mode: CountMode) -> DaftResult { - self.agg_helper(|s| s.count(None, mode)) - } - pub fn sum(&self) -> DaftResult { self.agg_helper(|s| s.sum(None)) } diff --git a/src/daft-core/src/python/series.rs b/src/daft-core/src/python/series.rs index 90234b2eb1..e5ab0950d6 100644 --- a/src/daft-core/src/python/series.rs +++ b/src/daft-core/src/python/series.rs @@ -353,7 +353,7 @@ impl PySeries { } pub fn list_count(&self, mode: CountMode) -> PyResult { - Ok(self.series.list_count(mode)?.into()) + Ok(self.series.list_count(mode)?.into_series().into()) } pub fn list_get(&self, idx: &Self, default: &Self) -> PyResult { diff --git a/src/daft-core/src/series/ops/list.rs b/src/daft-core/src/series/ops/list.rs index 4f0d5fecd7..961e160d6c 100644 --- a/src/daft-core/src/series/ops/list.rs +++ b/src/daft-core/src/series/ops/list.rs @@ -1,6 +1,6 @@ use crate::datatypes::{DataType, UInt64Array, Utf8Array}; use crate::series::Series; -use crate::{CountMode, IntoSeries}; +use crate::CountMode; use common_error::DaftError; use common_error::DaftResult; @@ -18,7 +18,7 @@ impl Series { } } - pub fn list_count(&self, mode: CountMode) -> DaftResult { + pub fn list_count(&self, mode: CountMode) -> DaftResult { use DataType::*; match self.data_type() { @@ -35,7 +35,7 @@ impl Series { ) .with_validity(data_array.validity().cloned()), ); - Ok(UInt64Array::from((self.name(), array)).into_series()) + Ok(UInt64Array::from((self.name(), array))) } dt => Err(DaftError::TypeError(format!( "Count not implemented for {}", diff --git a/src/daft-dsl/src/functions/list/count.rs b/src/daft-dsl/src/functions/list/count.rs index fc8c17e435..10c818bb24 100644 --- a/src/daft-dsl/src/functions/list/count.rs +++ b/src/daft-dsl/src/functions/list/count.rs @@ -3,6 +3,7 @@ use daft_core::{ datatypes::{DataType, Field}, schema::Schema, series::Series, + IntoSeries, }; use common_error::{DaftError, DaftResult}; @@ -53,7 +54,7 @@ impl FunctionEvaluator for CountEvaluator { _ => panic!("Expected List Count Expr, got {expr}"), }; - Ok(input.list_count(*mode)?) + Ok(input.list_count(*mode)?.into_series()) } _ => Err(DaftError::ValueError(format!( "Expected 1 input arg, got {}", diff --git a/src/daft-table/src/ops/explode.rs b/src/daft-table/src/ops/explode.rs index 003da9d33b..bec57b8dfb 100644 --- a/src/daft-table/src/ops/explode.rs +++ b/src/daft-table/src/ops/explode.rs @@ -80,7 +80,7 @@ impl Table { .collect::>>()?; let capacity_expected = exploded_columns.first().unwrap().len(); - let take_idx = lengths_to_indices(first_len.u64()?, capacity_expected)?.into_series(); + let take_idx = lengths_to_indices(&first_len, capacity_expected)?.into_series(); let mut new_series = self.columns.clone(); From 005e434194208e8b51c72d20ac2745b78c9ffe8d Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Fri, 5 Apr 2024 11:13:51 -0700 Subject: [PATCH 10/11] move iterator logic out and use macros --- .../src/array/fixed_size_list_array.rs | 24 +++ src/daft-core/src/array/list_array.rs | 22 +++ src/daft-core/src/array/ops/list.rs | 140 +++++------------- 3 files changed, 82 insertions(+), 104 deletions(-) diff --git a/src/daft-core/src/array/fixed_size_list_array.rs b/src/daft-core/src/array/fixed_size_list_array.rs index a8faba1a19..7d79d605ad 100644 --- a/src/daft-core/src/array/fixed_size_list_array.rs +++ b/src/daft-core/src/array/fixed_size_list_array.rs @@ -132,6 +132,30 @@ impl FixedSizeListArray { )) } + pub fn iter(&self) -> Box> + '_> { + let step = self.fixed_element_len(); + + if let Some(validity) = self.validity() { + Box::new((0..self.len()).map(move |i| { + if validity.get_bit(i) { + let start = i * step; + let end = (i + 1) * step; + + Some(self.flat_child.slice(start, end).unwrap()) + } else { + None + } + })) + } else { + Box::new((0..self.len()).map(move |i| { + let start = i * step; + let end = (i + 1) * step; + + Some(self.flat_child.slice(start, end).unwrap()) + })) + } + } + pub fn to_arrow(&self) -> Box { let arrow_dtype = self.data_type().to_arrow().unwrap(); Box::new(arrow2::array::FixedSizeListArray::new( diff --git a/src/daft-core/src/array/list_array.rs b/src/daft-core/src/array/list_array.rs index 766f6d5d6b..fb13283e0d 100644 --- a/src/daft-core/src/array/list_array.rs +++ b/src/daft-core/src/array/list_array.rs @@ -151,6 +151,28 @@ impl ListArray { )) } + pub fn iter(&self) -> Box> + '_> { + if let Some(validity) = self.validity() { + Box::new((0..self.len()).map(|i| { + if validity.get_bit(i) { + let start = *self.offsets().get(i).unwrap() as usize; + let end = *self.offsets().get(i + 1).unwrap() as usize; + + Some(self.flat_child.slice(start, end).unwrap()) + } else { + None + } + })) + } else { + Box::new(self.offsets().windows(2).map(|w| { + let start = w[0] as usize; + let end = w[1] as usize; + + Some(self.flat_child.slice(start, end).unwrap()) + })) + } + } + pub fn to_arrow(&self) -> Box { let arrow_dtype = self.data_type().to_arrow().unwrap(); Box::new(arrow2::array::ListArray::new( diff --git a/src/daft-core/src/array/ops/list.rs b/src/daft-core/src/array/ops/list.rs index e6bac8a497..1bfdc635c4 100644 --- a/src/daft-core/src/array/ops/list.rs +++ b/src/daft-core/src/array/ops/list.rs @@ -195,60 +195,6 @@ impl ListArray { } } } - - fn agg_helper(&self, op: T) -> DaftResult - where - T: Fn(&Series) -> DaftResult, - { - let aggs = if let Some(validity) = self.validity() { - let test_result = op(&Series::empty("", self.flat_child.data_type()))?; - - (0..self.len()) - .map(|i| { - if validity.get_bit(i) { - let start = *self.offsets().get(i).unwrap() as usize; - let end = *self.offsets().get(i + 1).unwrap() as usize; - - let slice = self.flat_child.slice(start, end)?; - op(&slice) - } else { - Ok(Series::full_null("", test_result.data_type(), 1)) - } - }) - .collect::>>()? - } else { - self.offsets() - .windows(2) - .map(|w| { - let start = w[0] as usize; - let end = w[1] as usize; - - let slice = self.flat_child.slice(start, end)?; - op(&slice) - }) - .collect::>>()? - }; - - let agg_refs: Vec<_> = aggs.iter().collect(); - - Ok(Series::concat(agg_refs.as_slice())?.rename(self.name())) - } - - pub fn sum(&self) -> DaftResult { - self.agg_helper(|s| s.sum(None)) - } - - pub fn mean(&self) -> DaftResult { - self.agg_helper(|s| s.mean(None)) - } - - pub fn min(&self) -> DaftResult { - self.agg_helper(|s| s.min(None)) - } - - pub fn max(&self) -> DaftResult { - self.agg_helper(|s| s.max(None)) - } } impl FixedSizeListArray { @@ -388,59 +334,45 @@ impl FixedSizeListArray { } } } +} - fn agg_helper(&self, op: T) -> DaftResult - where - T: Fn(&Series) -> DaftResult, - { - let step = self.fixed_element_len(); - - let aggs = if let Some(validity) = self.validity() { - let test_result = op(&Series::empty("", self.flat_child.data_type()))?; - - (0..self.len()) - .map(|i| { - if validity.get_bit(i) { - let start = i * step; - let end = (i + 1) * step; - - let slice = self.flat_child.slice(start, end)?; - op(&slice) - } else { - Ok(Series::full_null("", test_result.data_type(), 1)) - } - }) - .collect::>>()? - } else { - (0..self.len()) - .map(|i| { - let start = i * step; - let end = (i + 1) * step; - - let slice = self.flat_child.slice(start, end)?; - op(&slice) - }) - .collect::>>()? - }; - - let agg_refs: Vec<_> = aggs.iter().collect(); - - Series::concat(agg_refs.as_slice()).map(|s| s.rename(self.name())) - } +macro_rules! impl_aggs_list_array { + ($la:ident) => { + impl $la { + fn agg_helper(&self, op: T) -> DaftResult + where + T: Fn(&Series) -> DaftResult, + { + // Assumes `op`` returns a null Series given an empty Series + let aggs = self + .iter() + .map(|s| s.unwrap_or(Series::empty("", self.child_data_type()))) + .map(|s| op(&s)) + .collect::>>()?; + + let agg_refs: Vec<_> = aggs.iter().collect(); + + Series::concat(agg_refs.as_slice()).map(|s| s.rename(self.name())) + } - pub fn sum(&self) -> DaftResult { - self.agg_helper(|s| s.sum(None)) - } + pub fn sum(&self) -> DaftResult { + self.agg_helper(|s| s.sum(None)) + } - pub fn mean(&self) -> DaftResult { - self.agg_helper(|s| s.mean(None)) - } + pub fn mean(&self) -> DaftResult { + self.agg_helper(|s| s.mean(None)) + } - pub fn min(&self) -> DaftResult { - self.agg_helper(|s| s.min(None)) - } + pub fn min(&self) -> DaftResult { + self.agg_helper(|s| s.min(None)) + } - pub fn max(&self) -> DaftResult { - self.agg_helper(|s| s.max(None)) - } + pub fn max(&self) -> DaftResult { + self.agg_helper(|s| s.max(None)) + } + } + }; } + +impl_aggs_list_array!(ListArray); +impl_aggs_list_array!(FixedSizeListArray); From ce971a83070ffb4a46a6ca463a1402b4a5c6e7e2 Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Wed, 10 Apr 2024 16:01:04 -0700 Subject: [PATCH 11/11] implement IntoIterator for list types --- .../src/array/fixed_size_list_array.rs | 65 ++++++++++++------- src/daft-core/src/array/list_array.rs | 61 ++++++++++------- src/daft-core/src/array/ops/list.rs | 4 +- src/daft-core/src/datatypes/agg_ops.rs | 31 +++++++++ src/daft-core/src/datatypes/binary_ops.rs | 28 -------- src/daft-core/src/datatypes/mod.rs | 4 +- 6 files changed, 117 insertions(+), 76 deletions(-) create mode 100644 src/daft-core/src/datatypes/agg_ops.rs diff --git a/src/daft-core/src/array/fixed_size_list_array.rs b/src/daft-core/src/array/fixed_size_list_array.rs index 7d79d605ad..2784f5746c 100644 --- a/src/daft-core/src/array/fixed_size_list_array.rs +++ b/src/daft-core/src/array/fixed_size_list_array.rs @@ -132,30 +132,6 @@ impl FixedSizeListArray { )) } - pub fn iter(&self) -> Box> + '_> { - let step = self.fixed_element_len(); - - if let Some(validity) = self.validity() { - Box::new((0..self.len()).map(move |i| { - if validity.get_bit(i) { - let start = i * step; - let end = (i + 1) * step; - - Some(self.flat_child.slice(start, end).unwrap()) - } else { - None - } - })) - } else { - Box::new((0..self.len()).map(move |i| { - let start = i * step; - let end = (i + 1) * step; - - Some(self.flat_child.slice(start, end).unwrap()) - })) - } - } - pub fn to_arrow(&self) -> Box { let arrow_dtype = self.data_type().to_arrow().unwrap(); Box::new(arrow2::array::FixedSizeListArray::new( @@ -190,6 +166,47 @@ impl FixedSizeListArray { } } +impl<'a> IntoIterator for &'a FixedSizeListArray { + type Item = Option; + + type IntoIter = FixedSizeListArrayIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + FixedSizeListArrayIter { + array: self, + idx: 0, + } + } +} + +pub struct FixedSizeListArrayIter<'a> { + array: &'a FixedSizeListArray, + idx: usize, +} + +impl Iterator for FixedSizeListArrayIter<'_> { + type Item = Option; + + fn next(&mut self) -> Option { + if self.idx < self.array.len() { + if let Some(validity) = self.array.validity() && !validity.get_bit(self.idx) { + self.idx += 1; + Some(None) + } else { + let step = self.array.fixed_element_len(); + + let start = self.idx * step; + let end = (self.idx + 1) * step; + + self.idx += 1; + Some(Some(self.array.flat_child.slice(start, end).unwrap())) + } + } else { + None + } + } +} + #[cfg(test)] mod tests { use common_error::DaftResult; diff --git a/src/daft-core/src/array/list_array.rs b/src/daft-core/src/array/list_array.rs index fb13283e0d..abe879b7bd 100644 --- a/src/daft-core/src/array/list_array.rs +++ b/src/daft-core/src/array/list_array.rs @@ -151,28 +151,6 @@ impl ListArray { )) } - pub fn iter(&self) -> Box> + '_> { - if let Some(validity) = self.validity() { - Box::new((0..self.len()).map(|i| { - if validity.get_bit(i) { - let start = *self.offsets().get(i).unwrap() as usize; - let end = *self.offsets().get(i + 1).unwrap() as usize; - - Some(self.flat_child.slice(start, end).unwrap()) - } else { - None - } - })) - } else { - Box::new(self.offsets().windows(2).map(|w| { - let start = w[0] as usize; - let end = w[1] as usize; - - Some(self.flat_child.slice(start, end).unwrap()) - })) - } - } - pub fn to_arrow(&self) -> Box { let arrow_dtype = self.data_type().to_arrow().unwrap(); Box::new(arrow2::array::ListArray::new( @@ -200,3 +178,42 @@ impl ListArray { )) } } + +impl<'a> IntoIterator for &'a ListArray { + type Item = Option; + + type IntoIter = ListArrayIter<'a>; + + fn into_iter(self) -> Self::IntoIter { + ListArrayIter { + array: self, + idx: 0, + } + } +} + +pub struct ListArrayIter<'a> { + array: &'a ListArray, + idx: usize, +} + +impl Iterator for ListArrayIter<'_> { + type Item = Option; + + fn next(&mut self) -> Option { + if self.idx < self.array.len() { + if let Some(validity) = self.array.validity() && !validity.get_bit(self.idx) { + self.idx += 1; + Some(None) + } else { + let start = *self.array.offsets().get(self.idx).unwrap() as usize; + let end = *self.array.offsets().get(self.idx + 1).unwrap() as usize; + + self.idx += 1; + Some(Some(self.array.flat_child.slice(start, end).unwrap())) + } + } else { + None + } + } +} diff --git a/src/daft-core/src/array/ops/list.rs b/src/daft-core/src/array/ops/list.rs index 1bfdc635c4..bc93cc642a 100644 --- a/src/daft-core/src/array/ops/list.rs +++ b/src/daft-core/src/array/ops/list.rs @@ -343,9 +343,11 @@ macro_rules! impl_aggs_list_array { where T: Fn(&Series) -> DaftResult, { + // TODO(Kevin): Currently this requires full materialization of one Series for every list. We could avoid this by implementing either sorted aggregation or an array builder + // Assumes `op`` returns a null Series given an empty Series let aggs = self - .iter() + .into_iter() .map(|s| s.unwrap_or(Series::empty("", self.child_data_type()))) .map(|s| op(&s)) .collect::>>()?; diff --git a/src/daft-core/src/datatypes/agg_ops.rs b/src/daft-core/src/datatypes/agg_ops.rs new file mode 100644 index 0000000000..48a89968b6 --- /dev/null +++ b/src/daft-core/src/datatypes/agg_ops.rs @@ -0,0 +1,31 @@ +use common_error::{DaftError, DaftResult}; + +use super::DataType; + +/// Get the data type that the sum of a column of the given data type should be casted to. +pub fn try_sum_supertype(dtype: &DataType) -> DaftResult { + use DataType::*; + match dtype { + Int8 | Int16 | Int32 | Int64 => Ok(Int64), + UInt8 | UInt16 | UInt32 | UInt64 => Ok(UInt64), + Float32 => Ok(Float32), + Float64 => Ok(Float64), + other => Err(DaftError::TypeError(format!( + "Invalid argument to sum supertype: {}", + other + ))), + } +} + +/// Get the data type that the mean of a column of the given data type should be casted to. +pub fn try_mean_supertype(dtype: &DataType) -> DaftResult { + use DataType::*; + if dtype.is_numeric() { + Ok(Float64) + } else { + Err(DaftError::TypeError(format!( + "Invalid argument to mean supertype: {}", + dtype + ))) + } +} diff --git a/src/daft-core/src/datatypes/binary_ops.rs b/src/daft-core/src/datatypes/binary_ops.rs index a0c8f1bc91..6d08a6312c 100644 --- a/src/daft-core/src/datatypes/binary_ops.rs +++ b/src/daft-core/src/datatypes/binary_ops.rs @@ -307,31 +307,3 @@ pub fn try_numeric_supertype(l: &DataType, r: &DataType) -> DaftResult l, r ))) } - -/// Get the data type that the sum of a column of the given data type should be casted to. -pub fn try_sum_supertype(dtype: &DataType) -> DaftResult { - use DataType::*; - match dtype { - Int8 | Int16 | Int32 | Int64 => Ok(Int64), - UInt8 | UInt16 | UInt32 | UInt64 => Ok(UInt64), - Float32 => Ok(Float32), - Float64 => Ok(Float64), - other => Err(DaftError::TypeError(format!( - "Invalid argument to sum supertype: {}", - other - ))), - } -} - -/// Get the data type that the mean of a column of the given data type should be casted to. -pub fn try_mean_supertype(dtype: &DataType) -> DaftResult { - use DataType::*; - if dtype.is_numeric() { - Ok(Float64) - } else { - Err(DaftError::TypeError(format!( - "Invalid argument to mean supertype: {}", - dtype - ))) - } -} diff --git a/src/daft-core/src/datatypes/mod.rs b/src/daft-core/src/datatypes/mod.rs index 2ed24633a5..3b937fbbc5 100644 --- a/src/daft-core/src/datatypes/mod.rs +++ b/src/daft-core/src/datatypes/mod.rs @@ -1,3 +1,4 @@ +mod agg_ops; mod binary_ops; mod dtype; mod field; @@ -8,11 +9,12 @@ mod time_unit; pub use crate::array::{DataArray, FixedSizeListArray}; use crate::array::{ListArray, StructArray}; +pub use agg_ops::{try_mean_supertype, try_sum_supertype}; use arrow2::{ compute::comparison::Simd8, types::{simd::Simd, NativeType}, }; -pub use binary_ops::{try_mean_supertype, try_physical_supertype, try_sum_supertype}; +pub use binary_ops::try_physical_supertype; pub use dtype::DataType; pub use field::Field; pub use field::FieldID;