diff --git a/src/daft-core/src/series/ops/hll.rs b/src/daft-core/src/series/ops/hll.rs new file mode 100644 index 0000000000..462b813627 --- /dev/null +++ b/src/daft-core/src/series/ops/hll.rs @@ -0,0 +1,10 @@ +use common_error::DaftResult; + +use crate::{IntoSeries, Series}; + +impl Series { + pub fn hll(&self) -> DaftResult { + let series = self.hash(None)?.into_series(); + Ok(series) + } +} diff --git a/src/daft-core/src/series/ops/mod.rs b/src/daft-core/src/series/ops/mod.rs index ea09bee99e..b414ccbee1 100644 --- a/src/daft-core/src/series/ops/mod.rs +++ b/src/daft-core/src/series/ops/mod.rs @@ -20,6 +20,7 @@ pub mod float; pub mod floor; pub mod groups; pub mod hash; +pub mod hll; pub mod if_else; pub mod image; pub mod is_in; diff --git a/src/daft-dsl/src/expr.rs b/src/daft-dsl/src/expr.rs index 3bf1631888..6b6df07fa1 100644 --- a/src/daft-dsl/src/expr.rs +++ b/src/daft-dsl/src/expr.rs @@ -85,6 +85,7 @@ pub enum AggExpr { func: FunctionExpr, inputs: Vec, }, + Hll(ExprRef), } pub fn col>>(name: S) -> ExprRef { @@ -109,7 +110,8 @@ impl AggExpr { | Max(expr) | AnyValue(expr, _) | List(expr) - | Concat(expr) => expr.name(), + | Concat(expr) + | Hll(expr) => expr.name(), MapGroups { func: _, inputs } => inputs.first().unwrap().name(), } } @@ -171,6 +173,10 @@ impl AggExpr { FieldID::new(format!("{child_id}.local_concat()")) } MapGroups { func, inputs } => function_semantic_id(func, inputs, schema), + Hll(expr) => { + let child_id = expr.semantic_id(schema); + FieldID::new(format!("{child_id}.hll()")) + } } } @@ -187,7 +193,8 @@ impl AggExpr { | Max(expr) | AnyValue(expr, _) | List(expr) - | Concat(expr) => vec![expr.clone()], + | Concat(expr) + | Hll(expr) => vec![expr.clone()], MapGroups { func: _, inputs } => inputs.clone(), } } @@ -224,6 +231,7 @@ impl AggExpr { }), ApproxSketch(_) => ApproxSketch(children[0].clone()), MergeSketch(_) => MergeSketch(children[0].clone()), + Hll(_) => Hll(children[0].clone()), } } @@ -328,6 +336,10 @@ impl AggExpr { } } MapGroups { func, inputs } => func.to_field(inputs.as_slice(), schema, func), + Hll(expr) => { + let field = expr.to_field(schema)?; + Ok(Field::new(field.name, DataType::UInt64)) + } } } @@ -1028,6 +1040,7 @@ impl Display for AggExpr { List(expr) => write!(f, "list({expr})"), Concat(expr) => write!(f, "list({expr})"), MapGroups { func, inputs } => function_display(f, func, inputs), + Hll(expr) => write!(f, "hll({expr})"), } } } diff --git a/src/daft-dsl/src/resolve_expr.rs b/src/daft-dsl/src/resolve_expr.rs index 152abbbd4a..e7412bb6f9 100644 --- a/src/daft-dsl/src/resolve_expr.rs +++ b/src/daft-dsl/src/resolve_expr.rs @@ -246,6 +246,7 @@ fn extract_agg_expr(expr: &Expr) -> DaftResult { .map(|input| input.alias(name.clone())) .collect(), }, + Hll(e) => Hll(Alias(e, name.clone()).into()), } }), // TODO(Kevin): Support a mix of aggregation and non-aggregation expressions diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs index 8d346c85b1..827a9be204 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -448,6 +448,10 @@ fn replace_column_with_semantic_id_aggexpr( }) } } + AggExpr::Hll(ref child) => { + replace_column_with_semantic_id(child.clone(), subexprs_to_replace, schema) + .map_yes_no(AggExpr::Hll, |_| e.clone()) + } } } diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index 8ef727eedf..e2e1ab34b2 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -923,6 +923,18 @@ pub fn populate_aggregation_stages( .alias(output_name), ); } + Hll(expr) => { + let a = agg_expr.semantic_id(schema).id; + let b = agg_expr.semantic_id(schema).id; // Change to HllMerge + first_stage_aggs + .entry(a.clone()) + .or_insert_with(|| Hll(expr.alias(&*a))); + second_stage_aggs.entry(b.clone()).or_insert_with(|| { + // Hll Merge + todo!() + }); + todo!() + } } } (first_stage_aggs, second_stage_aggs, final_exprs) diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 4b820d8e7a..49b0cca76b 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -431,32 +431,33 @@ impl Table { agg_expr: &AggExpr, groups: Option<&GroupIndices>, ) -> DaftResult { - use daft_dsl::AggExpr::*; match agg_expr { - Count(expr, mode) => Series::count(&self.eval_expression(expr)?, groups, *mode), - Sum(expr) => Series::sum(&self.eval_expression(expr)?, groups), - ApproxSketch(expr) => Series::approx_sketch(&self.eval_expression(expr)?, groups), - ApproxPercentile(ApproxPercentileParams { - child: expr, - percentiles, + &AggExpr::Count(ref expr, mode) => self.eval_expression(expr)?.count(groups, mode), + AggExpr::Sum(expr) => self.eval_expression(expr)?.sum(groups), + AggExpr::ApproxSketch(expr) => self.eval_expression(expr)?.approx_sketch(groups), + &AggExpr::ApproxPercentile(ApproxPercentileParams { + child: ref expr, + ref percentiles, force_list_output, }) => { let percentiles = percentiles.iter().map(|p| p.0).collect::>(); - Series::approx_sketch(&self.eval_expression(expr)?, groups)? - .sketch_percentile(&percentiles, *force_list_output) + self.eval_expression(expr)? + .approx_sketch(groups)? + .sketch_percentile(&percentiles, force_list_output) } - MergeSketch(expr) => Series::merge_sketch(&self.eval_expression(expr)?, groups), - Mean(expr) => Series::mean(&self.eval_expression(expr)?, groups), - Min(expr) => Series::min(&self.eval_expression(expr)?, groups), - Max(expr) => Series::max(&self.eval_expression(expr)?, groups), - AnyValue(expr, ignore_nulls) => { - Series::any_value(&self.eval_expression(expr)?, groups, *ignore_nulls) + AggExpr::MergeSketch(expr) => self.eval_expression(expr)?.merge_sketch(groups), + AggExpr::Mean(expr) => self.eval_expression(expr)?.mean(groups), + AggExpr::Min(expr) => self.eval_expression(expr)?.min(groups), + AggExpr::Max(expr) => self.eval_expression(expr)?.max(groups), + &AggExpr::AnyValue(ref expr, ignore_nulls) => { + self.eval_expression(expr)?.any_value(groups, ignore_nulls) } - List(expr) => Series::agg_list(&self.eval_expression(expr)?, groups), - Concat(expr) => Series::agg_concat(&self.eval_expression(expr)?, groups), - MapGroups { .. } => Err(DaftError::ValueError( + AggExpr::List(expr) => self.eval_expression(expr)?.agg_list(groups), + AggExpr::Concat(expr) => self.eval_expression(expr)?.agg_concat(groups), + AggExpr::MapGroups { .. } => Err(DaftError::ValueError( "MapGroups not supported via aggregation, use map_groups instead".to_string(), )), + AggExpr::Hll(expr) => self.eval_expression(expr)?.hll(), } }