Skip to content

Commit

Permalink
Add Hll aggregate expression
Browse files Browse the repository at this point in the history
  • Loading branch information
Raunak Bhagat committed Aug 23, 2024
1 parent ab6d1a5 commit 7360b54
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 20 deletions.
10 changes: 10 additions & 0 deletions src/daft-core/src/series/ops/hll.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
use common_error::DaftResult;

use crate::{IntoSeries, Series};

impl Series {
    /// Computes the `hll` operation over this series.
    ///
    /// As written, this hashes the series by calling `hash` with `None` as its
    /// argument (presumably "no seed" -- confirm against `Series::hash`) and
    /// converts the hashed output back into a `Series` via
    /// `IntoSeries::into_series`. No further reduction of the hashed values
    /// happens in this method.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying `hash` call.
    pub fn hll(&self) -> DaftResult<Self> {
        Ok(self.hash(None)?.into_series())
    }
}
1 change: 1 addition & 0 deletions src/daft-core/src/series/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub mod float;
pub mod floor;
pub mod groups;
pub mod hash;
pub mod hll;
pub mod if_else;
pub mod image;
pub mod is_in;
Expand Down
17 changes: 15 additions & 2 deletions src/daft-dsl/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ pub enum AggExpr {
func: FunctionExpr,
inputs: Vec<ExprRef>,
},
Hll(ExprRef),
}

pub fn col<S: Into<Arc<str>>>(name: S) -> ExprRef {
Expand All @@ -109,7 +110,8 @@ impl AggExpr {
| Max(expr)
| AnyValue(expr, _)
| List(expr)
| Concat(expr) => expr.name(),
| Concat(expr)
| Hll(expr) => expr.name(),
MapGroups { func: _, inputs } => inputs.first().unwrap().name(),
}
}
Expand Down Expand Up @@ -171,6 +173,10 @@ impl AggExpr {
FieldID::new(format!("{child_id}.local_concat()"))
}
MapGroups { func, inputs } => function_semantic_id(func, inputs, schema),
Hll(expr) => {
let child_id = expr.semantic_id(schema);
FieldID::new(format!("{child_id}.hll()"))
}
}
}

Expand All @@ -187,7 +193,8 @@ impl AggExpr {
| Max(expr)
| AnyValue(expr, _)
| List(expr)
| Concat(expr) => vec![expr.clone()],
| Concat(expr)
| Hll(expr) => vec![expr.clone()],
MapGroups { func: _, inputs } => inputs.clone(),
}
}
Expand Down Expand Up @@ -224,6 +231,7 @@ impl AggExpr {
}),
ApproxSketch(_) => ApproxSketch(children[0].clone()),
MergeSketch(_) => MergeSketch(children[0].clone()),
Hll(_) => Hll(children[0].clone()),
}
}

Expand Down Expand Up @@ -328,6 +336,10 @@ impl AggExpr {
}
}
MapGroups { func, inputs } => func.to_field(inputs.as_slice(), schema, func),
Hll(expr) => {
let field = expr.to_field(schema)?;
Ok(Field::new(field.name, DataType::UInt64))
}
}
}

Expand Down Expand Up @@ -1028,6 +1040,7 @@ impl Display for AggExpr {
List(expr) => write!(f, "list({expr})"),
Concat(expr) => write!(f, "list({expr})"),
MapGroups { func, inputs } => function_display(f, func, inputs),
Hll(expr) => write!(f, "hll({expr})"),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/daft-dsl/src/resolve_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ fn extract_agg_expr(expr: &Expr) -> DaftResult<AggExpr> {
.map(|input| input.alias(name.clone()))
.collect(),
},
Hll(e) => Hll(Alias(e, name.clone()).into()),
}
}),
// TODO(Kevin): Support a mix of aggregation and non-aggregation expressions
Expand Down
4 changes: 4 additions & 0 deletions src/daft-plan/src/logical_ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,10 @@ fn replace_column_with_semantic_id_aggexpr(
})
}
}
AggExpr::Hll(ref child) => {
replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema)
.map_yes_no(AggExpr::Hll, |_| e.clone())
}
}
}

Expand Down
12 changes: 12 additions & 0 deletions src/daft-plan/src/physical_planner/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,18 @@ pub fn populate_aggregation_stages(
.alias(output_name),
);
}
Hll(expr) => {
let a = agg_expr.semantic_id(schema).id;
let b = agg_expr.semantic_id(schema).id; // Change to HllMerge
first_stage_aggs
.entry(a.clone())
.or_insert_with(|| Hll(expr.alias(&*a)));
second_stage_aggs.entry(b.clone()).or_insert_with(|| {
// Hll Merge
todo!()
});
todo!()
}
}
}
(first_stage_aggs, second_stage_aggs, final_exprs)
Expand Down
37 changes: 19 additions & 18 deletions src/daft-table/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,32 +431,33 @@ impl Table {
agg_expr: &AggExpr,
groups: Option<&GroupIndices>,
) -> DaftResult<Series> {
use daft_dsl::AggExpr::*;
match agg_expr {
Count(expr, mode) => Series::count(&self.eval_expression(expr)?, groups, *mode),
Sum(expr) => Series::sum(&self.eval_expression(expr)?, groups),
ApproxSketch(expr) => Series::approx_sketch(&self.eval_expression(expr)?, groups),
ApproxPercentile(ApproxPercentileParams {
child: expr,
percentiles,
&AggExpr::Count(ref expr, mode) => self.eval_expression(expr)?.count(groups, mode),
AggExpr::Sum(expr) => self.eval_expression(expr)?.sum(groups),
AggExpr::ApproxSketch(expr) => self.eval_expression(expr)?.approx_sketch(groups),
&AggExpr::ApproxPercentile(ApproxPercentileParams {
child: ref expr,
ref percentiles,
force_list_output,
}) => {
let percentiles = percentiles.iter().map(|p| p.0).collect::<Vec<f64>>();
Series::approx_sketch(&self.eval_expression(expr)?, groups)?
.sketch_percentile(&percentiles, *force_list_output)
self.eval_expression(expr)?
.approx_sketch(groups)?
.sketch_percentile(&percentiles, force_list_output)
}
MergeSketch(expr) => Series::merge_sketch(&self.eval_expression(expr)?, groups),
Mean(expr) => Series::mean(&self.eval_expression(expr)?, groups),
Min(expr) => Series::min(&self.eval_expression(expr)?, groups),
Max(expr) => Series::max(&self.eval_expression(expr)?, groups),
AnyValue(expr, ignore_nulls) => {
Series::any_value(&self.eval_expression(expr)?, groups, *ignore_nulls)
AggExpr::MergeSketch(expr) => self.eval_expression(expr)?.merge_sketch(groups),
AggExpr::Mean(expr) => self.eval_expression(expr)?.mean(groups),
AggExpr::Min(expr) => self.eval_expression(expr)?.min(groups),
AggExpr::Max(expr) => self.eval_expression(expr)?.max(groups),
&AggExpr::AnyValue(ref expr, ignore_nulls) => {
self.eval_expression(expr)?.any_value(groups, ignore_nulls)
}
List(expr) => Series::agg_list(&self.eval_expression(expr)?, groups),
Concat(expr) => Series::agg_concat(&self.eval_expression(expr)?, groups),
MapGroups { .. } => Err(DaftError::ValueError(
AggExpr::List(expr) => self.eval_expression(expr)?.agg_list(groups),
AggExpr::Concat(expr) => self.eval_expression(expr)?.agg_concat(groups),
AggExpr::MapGroups { .. } => Err(DaftError::ValueError(
"MapGroups not supported via aggregation, use map_groups instead".to_string(),
)),
AggExpr::Hll(expr) => self.eval_expression(expr)?.hll(),
}
}

Expand Down

0 comments on commit 7360b54

Please sign in to comment.