From c466dd54498982ca6bfc29dd19f712bd62261735 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Tue, 19 Nov 2024 23:31:49 -0800 Subject: [PATCH] [FEAT] connect: add groupBy --- .../translation/expr/unresolved_function.rs | 17 +++++++++ .../src/translation/logical_plan.rs | 8 ++++- .../translation/logical_plan/with_columns.rs | 32 +++++++++++++++++ tests/connect/test_group_by.py | 35 +++++++++++++++++++ 4 files changed, 91 insertions(+), 1 deletion(-) create mode 100644 src/daft-connect/src/translation/logical_plan/with_columns.rs create mode 100644 tests/connect/test_group_by.py diff --git a/src/daft-connect/src/translation/expr/unresolved_function.rs b/src/daft-connect/src/translation/expr/unresolved_function.rs index ef8a2a568d..6af019d43f 100644 --- a/src/daft-connect/src/translation/expr/unresolved_function.rs +++ b/src/daft-connect/src/translation/expr/unresolved_function.rs @@ -32,10 +32,27 @@ pub fn unresolved_to_daft_expr(f: &UnresolvedFunction) -> eyre::Result=" => handle_binary_op(arguments, daft_dsl::Operator::GtEq) .wrap_err("Failed to handle >= function"), + "%" => handle_binary_op(arguments, daft_dsl::Operator::Modulus) + .wrap_err("Failed to handle % function"), + "sum" => handle_sum(arguments).wrap_err("Failed to handle sum function"), n => bail!("Unresolved function {n} not yet supported"), } } +pub fn handle_sum(arguments: Vec) -> eyre::Result { + let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() { + Ok(arguments) => arguments, + Err(arguments) => { + bail!("requires exactly one argument; got {arguments:?}"); + } + }; + + let [arg] = arguments; + Ok(arg.sum()) +} + + + pub fn handle_binary_op( arguments: Vec, op: daft_dsl::Operator, diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index 947e0cd0d3..1e78eecdbf 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -3,11 +3,14 @@ use eyre::{bail, Context}; use spark_connect::{relation::RelType, Relation}; use tracing::warn; -use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range}; +use crate::translation::logical_plan::{ + aggregate::aggregate, project::project, range::range, with_columns::with_columns, +}; mod aggregate; mod project; mod range; +mod with_columns; pub fn to_logical_plan(relation: Relation) -> eyre::Result { if let Some(common) = relation.common { @@ -24,6 +27,9 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result { RelType::Aggregate(a) => { aggregate(*a).wrap_err("Failed to apply aggregate to logical plan") } + RelType::WithColumns(w) => { + with_columns(*w).wrap_err("Failed to apply with_columns to logical plan") + } plan => bail!("Unsupported relation type: {plan:?}"), } } diff --git a/src/daft-connect/src/translation/logical_plan/with_columns.rs b/src/daft-connect/src/translation/logical_plan/with_columns.rs new file mode 100644 index 0000000000..2ab9424a72 --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/with_columns.rs @@ -0,0 +1,32 @@ +use eyre::bail; +use spark_connect::{expression::ExprType, Expression}; + +use crate::translation::{to_daft_expr, to_logical_plan}; + +pub fn with_columns( + with_columns: spark_connect::WithColumns, +) -> eyre::Result { + let spark_connect::WithColumns { input, aliases } = with_columns; + + let Some(input) = input else { + bail!("input is required"); + }; + + let plan = to_logical_plan(*input)?; + + let daft_exprs: Vec<_> = aliases + .into_iter() + .map(|alias| { + let expression = Expression { + common: None, + expr_type: Some(ExprType::Alias(Box::new(alias))), + }; + + to_daft_expr(&expression) + }) + .try_collect()?; + + let plan = plan.with_columns(daft_exprs)?; + + Ok(plan) +} diff --git a/tests/connect/test_group_by.py b/tests/connect/test_group_by.py new file mode 100644 index 0000000000..40efbb20c6 --- /dev/null +++ b/tests/connect/test_group_by.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from pyspark.sql.functions import col + + +def test_group_by(spark_session): + # Create DataFrame from range(10) + df = spark_session.range(10) + + # Add a column that will have repeated values for grouping + df = df.withColumn("group", col("id") % 3) + + # Group by the new column and sum the ids in each group + df_grouped = df.groupBy("group").sum("id") + + # Convert to pandas to verify the sums + df_grouped_pandas = df_grouped.toPandas() + + print(df_grouped_pandas) + + # Sort by group to ensure consistent order for comparison + df_grouped_pandas = df_grouped_pandas.sort_values("group").reset_index(drop=True) + + # Verify the expected sums for each group + # group id + # 0 2 15 + # 1 1 12 + # 2 0 18 + expected = { + "group": [0, 1, 2], + "id": [18, 12, 15], # todo(correctness): should this be "id" for value here? + } + + assert df_grouped_pandas["group"].tolist() == expected["group"] + assert df_grouped_pandas["id"].tolist() == expected["id"]