Skip to content

Commit

Permalink
add any_value implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinzwang committed Feb 23, 2024
1 parent b493f28 commit 611d1d6
Show file tree
Hide file tree
Showing 15 changed files with 168 additions and 52 deletions.
2 changes: 1 addition & 1 deletion daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,7 @@ class PyExpr:
def mean(self) -> PyExpr: ...
def min(self) -> PyExpr: ...
def max(self) -> PyExpr: ...
def any_value(self) -> PyExpr: ...
def any_value(self, ignore_nulls: bool) -> PyExpr: ...
def agg_list(self) -> PyExpr: ...
def agg_concat(self) -> PyExpr: ...
def explode(self) -> PyExpr: ...
Expand Down
4 changes: 2 additions & 2 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,8 @@ def _max(self) -> Expression:
expr = self._expr.max()
return Expression._from_pyexpr(expr)

def _any_value(self) -> Expression:
expr = self._expr.any_value()
def _any_value(self, ignore_nulls=False) -> Expression:
expr = self._expr.any_value(ignore_nulls)
return Expression._from_pyexpr(expr)

Check warning on line 345 in daft/expressions/expressions.py

View check run for this annotation

Codecov / codecov/patch

daft/expressions/expressions.py#L344-L345

Added lines #L344 - L345 were not covered by tests

def _agg_list(self) -> Expression:
Expand Down
88 changes: 73 additions & 15 deletions src/daft-core/src/array/ops/any_value.rs
Original file line number Diff line number Diff line change
@@ -1,43 +1,101 @@
use arrow2::{array::PrimitiveArray, bitmap::Bitmap};
use common_error::DaftResult;

use crate::{
array::{DataArray, FixedSizeListArray, ListArray, StructArray},
datatypes::DaftPhysicalType,
datatypes::*,
};

use super::DaftAnyValueAggable;
use super::{full::FullNull, DaftAnyValueAggable, GroupIndices};

/// Picks one representative row index per group for the grouped `any_value` aggregation.
///
/// For each entry of `groups` the result holds:
/// * when `ignore_nulls` is set and a `validity` bitmap is supplied: the first index in
///   the group whose validity bit is set, or a null slot if no index in the group is valid;
/// * otherwise: the first index of the group, or a null slot if the group is empty.
///
/// The returned array is unnamed (`""`); the callers in this file hand it to `take`.
fn get_any_grouped_idx(
    groups: &GroupIndices,
    ignore_nulls: bool,
    validity: Option<&Bitmap>,
) -> UInt64Array {
    // A match guard expresses "ignore_nulls && validity is present" without relying on
    // `if … && let …` chains, which not every toolchain accepts.
    let group_indices = match validity {
        Some(validity) if ignore_nulls => {
            Box::new(PrimitiveArray::from_trusted_len_iter(groups.iter().map(|g| {
                // First index in the group whose validity bit is set; `None` becomes a null slot.
                g.iter().copied().find(|&i| validity.get_bit(i as usize))
            })))
        }
        // No bitmap to consult (or nulls are acceptable): just take each group's first index.
        _ => Box::new(PrimitiveArray::from_trusted_len_iter(
            groups.iter().map(|g| g.first().copied()),
        )),
    };

    DataArray::from(("", group_indices))
}

impl<T> DaftAnyValueAggable for DataArray<T>
where
T: DaftPhysicalType,
T: DaftNumericType,
{
type Output = DaftResult<DataArray<T>>;

fn any_value(&self) -> Self::Output {
todo!()
fn any_value(&self, ignore_nulls: bool) -> Self::Output {
if ignore_nulls && let Some(validity) = self.validity() {
for i in 0..self.len() {
if validity.get_bit(i) {
return self.slice(i, i+1);
}
}

Ok(DataArray::full_null(
self.name(),
self.data_type(),
1
))
} else {
self.slice(0, 1)
}
}

fn grouped_any_value(&self, _groups: &super::GroupIndices) -> Self::Output {
todo!()
fn grouped_any_value(&self, groups: &GroupIndices, ignore_nulls: bool) -> Self::Output {
self.take(&get_any_grouped_idx(groups, ignore_nulls, self.validity()))
}
}

macro_rules! impl_daft_any_value_nested_array {
macro_rules! impl_daft_any_value {
($arr:ident) => {
impl DaftAnyValueAggable for $arr {
type Output = DaftResult<$arr>;

fn any_value(&self) -> Self::Output {
todo!()
fn any_value(&self, ignore_nulls: bool) -> Self::Output {
if ignore_nulls && let Some(validity) = self.validity() {
for i in 0..self.len() {
if validity.get_bit(i) {
return self.slice(i, i+1);
}
}

Ok($arr::full_null(
self.name(),
self.data_type(),
1
))
} else {
self.slice(0, 1)
}
}

fn grouped_any_value(&self, _groups: &super::GroupIndices) -> Self::Output {
todo!()
fn grouped_any_value(&self, groups: &GroupIndices, ignore_nulls: bool) -> Self::Output {
self.take(&get_any_grouped_idx(groups, ignore_nulls, self.validity()))
}
}
};
}

impl_daft_any_value_nested_array!(FixedSizeListArray);
impl_daft_any_value_nested_array!(ListArray);
impl_daft_any_value_nested_array!(StructArray);
impl_daft_any_value!(Utf8Array);
impl_daft_any_value!(BooleanArray);
impl_daft_any_value!(BinaryArray);
impl_daft_any_value!(NullArray);
impl_daft_any_value!(ExtensionArray);
impl_daft_any_value!(FixedSizeListArray);
impl_daft_any_value!(ListArray);
impl_daft_any_value!(StructArray);

#[cfg(feature = "python")]
impl_daft_any_value!(PythonArray);
4 changes: 2 additions & 2 deletions src/daft-core/src/array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ pub trait DaftCompareAggable {

pub trait DaftAnyValueAggable {
type Output;
fn any_value(&self) -> Self::Output;
fn grouped_any_value(&self, groups: &GroupIndices) -> Self::Output;
fn any_value(&self, ignore_nulls: bool) -> Self::Output;
fn grouped_any_value(&self, groups: &GroupIndices, ignore_nulls: bool) -> Self::Output;
}

pub trait DaftListAggable {
Expand Down
15 changes: 12 additions & 3 deletions src/daft-core/src/series/array_impl/data_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,13 +162,22 @@ macro_rules! impl_series_like_for_data_array {
}
}

fn any_value(&self, groups: Option<&GroupIndices>) -> DaftResult<Series> {
fn any_value(
&self,
groups: Option<&GroupIndices>,
ignore_nulls: bool,
) -> DaftResult<Series> {
use crate::array::ops::DaftAnyValueAggable;
match groups {
Some(groups) => {
Ok(DaftAnyValueAggable::grouped_any_value(&self.0, groups)?.into_series())
Ok(
DaftAnyValueAggable::grouped_any_value(&self.0, groups, ignore_nulls)?
.into_series(),
)
}
None => {
Ok(DaftAnyValueAggable::any_value(&self.0, ignore_nulls)?.into_series())
}
None => Ok(DaftAnyValueAggable::any_value(&self.0)?.into_series()),
}
}

Expand Down
16 changes: 11 additions & 5 deletions src/daft-core/src/series/array_impl/logical_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,19 @@ macro_rules! impl_series_like_for_logical_array {
};
Ok($da::new(self.0.field.clone(), data_array).into_series())
}
fn any_value(&self, groups: Option<&GroupIndices>) -> DaftResult<Series> {
fn any_value(
&self,
groups: Option<&GroupIndices>,
ignore_nulls: bool,
) -> DaftResult<Series> {
use crate::array::ops::DaftAnyValueAggable;
let data_array = match groups {
Some(groups) => {
DaftAnyValueAggable::grouped_any_value(&self.0.physical, groups)?
}
None => DaftAnyValueAggable::any_value(&self.0.physical)?,
Some(groups) => DaftAnyValueAggable::grouped_any_value(
&self.0.physical,
groups,
ignore_nulls,
)?,
None => DaftAnyValueAggable::any_value(&self.0.physical, ignore_nulls)?,
};
Ok($da::new(self.0.field.clone(), data_array).into_series())
}
Expand Down
19 changes: 17 additions & 2 deletions src/daft-core/src/series/array_impl/nested_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,23 @@ macro_rules! impl_series_like_for_nested_arrays {
)))
}

fn any_value(&self, _groups: Option<&GroupIndices>) -> DaftResult<Series> {
todo!();
fn any_value(
&self,
groups: Option<&GroupIndices>,
ignore_nulls: bool,
) -> DaftResult<Series> {
use crate::array::ops::DaftAnyValueAggable;
match groups {
Some(groups) => {
Ok(
DaftAnyValueAggable::grouped_any_value(&self.0, groups, ignore_nulls)?
.into_series(),
)
}
None => {
Ok(DaftAnyValueAggable::any_value(&self.0, ignore_nulls)?.into_series())
}
}
}

fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult<Series> {
Expand Down
8 changes: 6 additions & 2 deletions src/daft-core/src/series/ops/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,12 @@ impl Series {
self.inner.max(groups)
}

pub fn any_value(&self, groups: Option<&GroupIndices>) -> DaftResult<Series> {
self.inner.any_value(groups)
pub fn any_value(
&self,
groups: Option<&GroupIndices>,
ignore_nulls: bool,
) -> DaftResult<Series> {
self.inner.any_value(groups, ignore_nulls)
}

pub fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult<Series> {
Expand Down
2 changes: 1 addition & 1 deletion src/daft-core/src/series/series_like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub trait SeriesLike: Send + Sync + Any + std::fmt::Debug {
fn validity(&self) -> Option<&arrow2::bitmap::Bitmap>;
fn min(&self, groups: Option<&GroupIndices>) -> DaftResult<Series>;
fn max(&self, groups: Option<&GroupIndices>) -> DaftResult<Series>;
fn any_value(&self, groups: Option<&GroupIndices>) -> DaftResult<Series>;
fn any_value(&self, groups: Option<&GroupIndices>, ignore_nulls: bool) -> DaftResult<Series>;
fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult<Series>;
fn broadcast(&self, num: usize) -> DaftResult<Series>;
fn cast(&self, datatype: &DataType) -> DaftResult<Series>;
Expand Down
22 changes: 13 additions & 9 deletions src/daft-dsl/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ pub enum AggExpr {
Mean(ExprRef),
Min(ExprRef),
Max(ExprRef),
AnyValue(ExprRef),
AnyValue(ExprRef, bool),
List(ExprRef),
Concat(ExprRef),
MapGroups {
Expand Down Expand Up @@ -88,7 +88,7 @@ impl AggExpr {
| Mean(expr)
| Min(expr)
| Max(expr)
| AnyValue(expr)
| AnyValue(expr, _)
| List(expr)
| Concat(expr) => expr.name(),
MapGroups { func: _, inputs } => inputs.first().unwrap().name(),
Expand Down Expand Up @@ -118,9 +118,11 @@ impl AggExpr {
let child_id = expr.semantic_id(schema);
FieldID::new(format!("{child_id}.local_max()"))
}
AnyValue(expr) => {
AnyValue(expr, ignore_nulls) => {
let child_id = expr.semantic_id(schema);
FieldID::new(format!("{child_id}.local_any_value()"))
FieldID::new(format!(
"{child_id}.local_any_value(ignore_nulls={ignore_nulls})"
))
}
List(expr) => {
let child_id = expr.semantic_id(schema);
Expand All @@ -142,7 +144,7 @@ impl AggExpr {
| Mean(expr)
| Min(expr)
| Max(expr)
| AnyValue(expr)
| AnyValue(expr, _)
| List(expr)
| Concat(expr) => vec![expr.clone()],
MapGroups { func: _, inputs } => inputs.iter().map(|e| e.clone().into()).collect(),
Expand Down Expand Up @@ -203,7 +205,7 @@ impl AggExpr {
},
))
}
Min(expr) | Max(expr) | AnyValue(expr) => {
Min(expr) | Max(expr) | AnyValue(expr, _) => {
let field = expr.to_field(schema)?;
Ok(Field::new(field.name.as_str(), field.dtype))
}
Expand Down Expand Up @@ -286,8 +288,8 @@ impl Expr {
Expr::Agg(AggExpr::Max(self.clone().into()))
}

pub fn any_value(&self) -> Self {
Expr::Agg(AggExpr::AnyValue(self.clone().into()))
pub fn any_value(&self, ignore_nulls: bool) -> Self {
Expr::Agg(AggExpr::AnyValue(self.clone().into(), ignore_nulls))
}

pub fn agg_list(&self) -> Self {
Expand Down Expand Up @@ -615,7 +617,9 @@ impl Display for AggExpr {
Mean(expr) => write!(f, "mean({expr})"),
Min(expr) => write!(f, "min({expr})"),
Max(expr) => write!(f, "max({expr})"),
AnyValue(expr) => write!(f, "any_value({expr})"),
AnyValue(expr, ignore_nulls) => {
write!(f, "any_value({expr}, ignore_nulls={ignore_nulls})")
}
List(expr) => write!(f, "list({expr})"),
Concat(expr) => write!(f, "list({expr})"),
MapGroups { func, inputs } => function_display(f, func, inputs),
Expand Down
4 changes: 2 additions & 2 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ impl PyExpr {
Ok(self.expr.max().into())
}

pub fn any_value(&self) -> PyResult<Self> {
Ok(self.expr.any_value().into())
pub fn any_value(&self, ignore_nulls: bool) -> PyResult<Self> {
Ok(self.expr.any_value(ignore_nulls).into())
}

pub fn agg_list(&self) -> PyResult<Self> {
Expand Down
6 changes: 4 additions & 2 deletions src/daft-dsl/src/treenode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ impl TreeNode for Expr {
| Mean(expr)
| Min(expr)
| Max(expr)
| AnyValue(expr)
| AnyValue(expr, _)
| List(expr)
| Concat(expr) => vec![expr.as_ref()],
MapGroups { func: _, inputs } => inputs.iter().collect::<Vec<_>>(),
Expand Down Expand Up @@ -66,7 +66,9 @@ impl TreeNode for Expr {
Mean(expr) => transform(expr.as_ref().clone())?.mean(),
Min(expr) => transform(expr.as_ref().clone())?.min(),
Max(expr) => transform(expr.as_ref().clone())?.max(),
AnyValue(expr) => transform(expr.as_ref().clone())?.any_value(),
AnyValue(expr, ignore_nulls) => {
transform(expr.as_ref().clone())?.any_value(ignore_nulls)
}
List(expr) => transform(expr.as_ref().clone())?.agg_list(),
Concat(expr) => transform(expr.as_ref().clone())?.agg_concat(),
MapGroups { func, inputs } => Expr::Agg(MapGroups {
Expand Down
8 changes: 5 additions & 3 deletions src/daft-plan/src/logical_ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,9 +370,11 @@ fn replace_column_with_semantic_id_aggexpr(
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::Max, |_| e.clone())
}
AggExpr::AnyValue(ref child) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::AnyValue, |_| e.clone())
AggExpr::AnyValue(ref child, ignore_nulls) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema).map_yes_no(
|transformed_child| AggExpr::AnyValue(transformed_child, ignore_nulls),
|_| e.clone(),
)
}
AggExpr::List(ref child) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
Expand Down
18 changes: 16 additions & 2 deletions src/daft-plan/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,8 +441,22 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc<DaftExecutionConfig>) -> DaftRe
.into()));
final_exprs.push(Column(max_of_max_id.clone()).alias(output_name));
}
AnyValue(_e) => {
todo!()
AnyValue(e, ignore_nulls) => {
let any_id = agg_expr.semantic_id(&schema).id;
let any_of_any_id =
AnyValue(Column(any_id.clone()).into(), *ignore_nulls)
.semantic_id(&schema)
.id;
first_stage_aggs.entry(any_id.clone()).or_insert(AnyValue(
e.alias(any_id.clone()).clone().into(),
*ignore_nulls,
));
second_stage_aggs
.entry(any_of_any_id.clone())
.or_insert(AnyValue(
Column(any_id.clone()).alias(any_of_any_id.clone()).into(),
*ignore_nulls,
));
}
List(e) => {
let list_id = agg_expr.semantic_id(&schema).id;
Expand Down
Loading

0 comments on commit 611d1d6

Please sign in to comment.