Skip to content

Commit

Permalink
add all functions to implement
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinzwang committed Feb 21, 2024
1 parent 9c66a5e commit b493f28
Show file tree
Hide file tree
Showing 17 changed files with 124 additions and 1 deletion.
1 change: 1 addition & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,7 @@ class PyExpr:
def mean(self) -> PyExpr: ...
def min(self) -> PyExpr: ...
def max(self) -> PyExpr: ...
def any_value(self) -> PyExpr: ...
def agg_list(self) -> PyExpr: ...
def agg_concat(self) -> PyExpr: ...
def explode(self) -> PyExpr: ...
Expand Down
11 changes: 11 additions & 0 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1546,6 +1546,17 @@ def max(self, *cols: ColumnInputType) -> "DataFrame":

return self.df._agg([(c, "max") for c in cols], group_by=self.group_by)

def any_value(self, *cols: ColumnInputType) -> "DataFrame":
"""Returns a single non-deterministic value from each group in this GroupedDataFrame
Args:
*cols (Union[str, Expression]): columns to get
Returns:
DataFrame: DataFrame with any values.
"""
return self.df._agg([(c, "any_value") for c in cols], group_by=self.group_by)

Check warning on line 1558 in daft/dataframe/dataframe.py

View check run for this annotation

Codecov / codecov/patch

daft/dataframe/dataframe.py#L1558

Added line #L1558 was not covered by tests

def count(self) -> "DataFrame":
"""Performs grouped count on this GroupedDataFrame.
Expand Down
4 changes: 4 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,10 @@ def _max(self) -> Expression:
expr = self._expr.max()
return Expression._from_pyexpr(expr)

def _any_value(self) -> Expression:
expr = self._expr.any_value()
return Expression._from_pyexpr(expr)

Check warning on line 345 in daft/expressions/expressions.py

View check run for this annotation

Codecov / codecov/patch

daft/expressions/expressions.py#L344-L345

Added lines #L344 - L345 were not covered by tests

def _agg_list(self) -> Expression:
expr = self._expr.agg_list()
return Expression._from_pyexpr(expr)
Expand Down
2 changes: 2 additions & 0 deletions daft/logical/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,8 @@ def agg(
exprs.append(expr._max())
elif op == "mean":
exprs.append(expr._mean())
elif op == "any_value":
exprs.append(expr._any_value())

Check warning on line 175 in daft/logical/builder.py

View check run for this annotation

Codecov / codecov/patch

daft/logical/builder.py#L175

Added line #L175 was not covered by tests
elif op == "list":
exprs.append(expr._agg_list())
elif op == "concat":
Expand Down
43 changes: 43 additions & 0 deletions src/daft-core/src/array/ops/any_value.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
use common_error::DaftResult;

use crate::{
array::{DataArray, FixedSizeListArray, ListArray, StructArray},
datatypes::DaftPhysicalType,
};

use super::DaftAnyValueAggable;

impl<T> DaftAnyValueAggable for DataArray<T>
where
T: DaftPhysicalType,
{
type Output = DaftResult<DataArray<T>>;

fn any_value(&self) -> Self::Output {
todo!()
}

fn grouped_any_value(&self, _groups: &super::GroupIndices) -> Self::Output {
todo!()
}
}

macro_rules! impl_daft_any_value_nested_array {
($arr:ident) => {
impl DaftAnyValueAggable for $arr {
type Output = DaftResult<$arr>;

fn any_value(&self) -> Self::Output {
todo!()
}

fn grouped_any_value(&self, _groups: &super::GroupIndices) -> Self::Output {
todo!()
}
}
};
}

impl_daft_any_value_nested_array!(FixedSizeListArray);
impl_daft_any_value_nested_array!(ListArray);
impl_daft_any_value_nested_array!(StructArray);
7 changes: 7 additions & 0 deletions src/daft-core/src/array/ops/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mod abs;
mod any_value;
mod apply;
mod arange;
mod arithmetic;
Expand Down Expand Up @@ -134,6 +135,12 @@ pub trait DaftCompareAggable {
fn grouped_max(&self, groups: &GroupIndices) -> Self::Output;
}

pub trait DaftAnyValueAggable {
type Output;
fn any_value(&self) -> Self::Output;
fn grouped_any_value(&self, groups: &GroupIndices) -> Self::Output;
}

pub trait DaftListAggable {
type Output;
fn list(&self) -> Self::Output;
Expand Down
10 changes: 10 additions & 0 deletions src/daft-core/src/series/array_impl/data_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,16 @@ macro_rules! impl_series_like_for_data_array {
}
}

fn any_value(&self, groups: Option<&GroupIndices>) -> DaftResult<Series> {
use crate::array::ops::DaftAnyValueAggable;
match groups {
Some(groups) => {
Ok(DaftAnyValueAggable::grouped_any_value(&self.0, groups)?.into_series())
}
None => Ok(DaftAnyValueAggable::any_value(&self.0)?.into_series()),
}
}

fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult<Series> {
match groups {
Some(groups) => Ok(self.0.grouped_list(groups)?.into_series()),
Expand Down
10 changes: 10 additions & 0 deletions src/daft-core/src/series/array_impl/logical_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,16 @@ macro_rules! impl_series_like_for_logical_array {
};
Ok($da::new(self.0.field.clone(), data_array).into_series())
}
fn any_value(&self, groups: Option<&GroupIndices>) -> DaftResult<Series> {
use crate::array::ops::DaftAnyValueAggable;
let data_array = match groups {
Some(groups) => {
DaftAnyValueAggable::grouped_any_value(&self.0.physical, groups)?
}
None => DaftAnyValueAggable::any_value(&self.0.physical)?,
};
Ok($da::new(self.0.field.clone(), data_array).into_series())
}
fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult<Series> {
use crate::array::{ops::DaftListAggable, ListArray};
let data_array = match groups {
Expand Down
4 changes: 4 additions & 0 deletions src/daft-core/src/series/array_impl/nested_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ macro_rules! impl_series_like_for_nested_arrays {
)))
}

fn any_value(&self, _groups: Option<&GroupIndices>) -> DaftResult<Series> {
todo!();
}

fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult<Series> {
use crate::array::ops::DaftListAggable;

Expand Down
4 changes: 4 additions & 0 deletions src/daft-core/src/series/ops/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ impl Series {
self.inner.max(groups)
}

pub fn any_value(&self, groups: Option<&GroupIndices>) -> DaftResult<Series> {
self.inner.any_value(groups)
}

pub fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult<Series> {
self.inner.agg_list(groups)
}
Expand Down
1 change: 1 addition & 0 deletions src/daft-core/src/series/series_like.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub trait SeriesLike: Send + Sync + Any + std::fmt::Debug {
fn validity(&self) -> Option<&arrow2::bitmap::Bitmap>;
fn min(&self, groups: Option<&GroupIndices>) -> DaftResult<Series>;
fn max(&self, groups: Option<&GroupIndices>) -> DaftResult<Series>;
fn any_value(&self, groups: Option<&GroupIndices>) -> DaftResult<Series>;
fn agg_list(&self, groups: Option<&GroupIndices>) -> DaftResult<Series>;
fn broadcast(&self, num: usize) -> DaftResult<Series>;
fn cast(&self, datatype: &DataType) -> DaftResult<Series>;
Expand Down
14 changes: 13 additions & 1 deletion src/daft-dsl/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ pub enum AggExpr {
Mean(ExprRef),
Min(ExprRef),
Max(ExprRef),
AnyValue(ExprRef),
List(ExprRef),
Concat(ExprRef),
MapGroups {
Expand Down Expand Up @@ -87,6 +88,7 @@ impl AggExpr {
| Mean(expr)
| Min(expr)
| Max(expr)
| AnyValue(expr)
| List(expr)
| Concat(expr) => expr.name(),
MapGroups { func: _, inputs } => inputs.first().unwrap().name(),
Expand Down Expand Up @@ -116,6 +118,10 @@ impl AggExpr {
let child_id = expr.semantic_id(schema);
FieldID::new(format!("{child_id}.local_max()"))
}
AnyValue(expr) => {
let child_id = expr.semantic_id(schema);
FieldID::new(format!("{child_id}.local_any_value()"))
}
List(expr) => {
let child_id = expr.semantic_id(schema);
FieldID::new(format!("{child_id}.local_list()"))
Expand All @@ -136,6 +142,7 @@ impl AggExpr {
| Mean(expr)
| Min(expr)
| Max(expr)
| AnyValue(expr)
| List(expr)
| Concat(expr) => vec![expr.clone()],
MapGroups { func: _, inputs } => inputs.iter().map(|e| e.clone().into()).collect(),
Expand Down Expand Up @@ -196,7 +203,7 @@ impl AggExpr {
},
))
}
Min(expr) | Max(expr) => {
Min(expr) | Max(expr) | AnyValue(expr) => {
let field = expr.to_field(schema)?;
Ok(Field::new(field.name.as_str(), field.dtype))
}
Expand Down Expand Up @@ -279,6 +286,10 @@ impl Expr {
Expr::Agg(AggExpr::Max(self.clone().into()))
}

pub fn any_value(&self) -> Self {
Expr::Agg(AggExpr::AnyValue(self.clone().into()))
}

pub fn agg_list(&self) -> Self {
Expr::Agg(AggExpr::List(self.clone().into()))
}
Expand Down Expand Up @@ -604,6 +615,7 @@ impl Display for AggExpr {
Mean(expr) => write!(f, "mean({expr})"),
Min(expr) => write!(f, "min({expr})"),
Max(expr) => write!(f, "max({expr})"),
AnyValue(expr) => write!(f, "any_value({expr})"),
List(expr) => write!(f, "list({expr})"),
Concat(expr) => write!(f, "list({expr})"),
MapGroups { func, inputs } => function_display(f, func, inputs),
Expand Down
4 changes: 4 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ impl PyExpr {
Ok(self.expr.max().into())
}

pub fn any_value(&self) -> PyResult<Self> {
Ok(self.expr.any_value().into())
}

pub fn agg_list(&self) -> PyResult<Self> {
Ok(self.expr.agg_list().into())
}
Expand Down
2 changes: 2 additions & 0 deletions src/daft-dsl/src/treenode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ impl TreeNode for Expr {
| Mean(expr)
| Min(expr)
| Max(expr)
| AnyValue(expr)
| List(expr)
| Concat(expr) => vec![expr.as_ref()],
MapGroups { func: _, inputs } => inputs.iter().collect::<Vec<_>>(),
Expand Down Expand Up @@ -65,6 +66,7 @@ impl TreeNode for Expr {
Mean(expr) => transform(expr.as_ref().clone())?.mean(),
Min(expr) => transform(expr.as_ref().clone())?.min(),
Max(expr) => transform(expr.as_ref().clone())?.max(),
AnyValue(expr) => transform(expr.as_ref().clone())?.any_value(),
List(expr) => transform(expr.as_ref().clone())?.agg_list(),
Concat(expr) => transform(expr.as_ref().clone())?.agg_concat(),
MapGroups { func, inputs } => Expr::Agg(MapGroups {
Expand Down
4 changes: 4 additions & 0 deletions src/daft-plan/src/logical_ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,10 @@ fn replace_column_with_semantic_id_aggexpr(
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::Max, |_| e.clone())
}
AggExpr::AnyValue(ref child) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::AnyValue, |_| e.clone())
}
AggExpr::List(ref child) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::List, |_| e.clone())
Expand Down
3 changes: 3 additions & 0 deletions src/daft-plan/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,9 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc<DaftExecutionConfig>) -> DaftRe
.into()));
final_exprs.push(Column(max_of_max_id.clone()).alias(output_name));
}
AnyValue(_e) => {
todo!()
}
List(e) => {
let list_id = agg_expr.semantic_id(&schema).id;
let concat_of_list_id = Concat(Column(list_id.clone()).into())
Expand Down
1 change: 1 addition & 0 deletions src/daft-table/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ impl Table {
Mean(expr) => Series::mean(&self.eval_expression(expr)?, groups),
Min(expr) => Series::min(&self.eval_expression(expr)?, groups),
Max(expr) => Series::max(&self.eval_expression(expr)?, groups),
AnyValue(expr) => Series::any_value(&self.eval_expression(expr)?, groups),
List(expr) => Series::agg_list(&self.eval_expression(expr)?, groups),
Concat(expr) => Series::agg_concat(&self.eval_expression(expr)?, groups),
MapGroups { .. } => Err(DaftError::ValueError(
Expand Down

0 comments on commit b493f28

Please sign in to comment.