diff --git a/src/daft-connect/src/translation/expr/unresolved_function.rs b/src/daft-connect/src/translation/expr/unresolved_function.rs index 7e3c42552b..90db495831 100644 --- a/src/daft-connect/src/translation/expr/unresolved_function.rs +++ b/src/daft-connect/src/translation/expr/unresolved_function.rs @@ -26,10 +26,27 @@ pub fn unresolved_to_daft_expr(f: UnresolvedFunction) -> eyre::Result handle_count(arguments).wrap_err("Failed to handle count function"), "<" => handle_binary_op(arguments, daft_dsl::Operator::Lt) .wrap_err("Failed to handle < function"), + "%" => handle_binary_op(arguments, daft_dsl::Operator::Modulus) + .wrap_err("Failed to handle % function"), + "sum" => handle_sum(arguments).wrap_err("Failed to handle sum function"), n => bail!("Unresolved function {n} not yet supported"), } } +pub fn handle_sum(arguments: Vec) -> eyre::Result { + let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() { + Ok(arguments) => arguments, + Err(arguments) => { + bail!("requires exactly one argument; got {arguments:?}"); + } + }; + + let [arg] = arguments; + Ok(arg.sum()) +} + + + pub fn handle_binary_op( arguments: Vec, op: daft_dsl::Operator, diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index 947e0cd0d3..1e78eecdbf 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -3,11 +3,14 @@ use eyre::{bail, Context}; use spark_connect::{relation::RelType, Relation}; use tracing::warn; -use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range}; +use crate::translation::logical_plan::{ + aggregate::aggregate, project::project, range::range, with_columns::with_columns, +}; mod aggregate; mod project; mod range; +mod with_columns; pub fn to_logical_plan(relation: Relation) -> eyre::Result { if let Some(common) = relation.common { @@ -24,6 +27,9 @@ pub fn to_logical_plan(relation: Relation) -> 
eyre::Result { RelType::Aggregate(a) => { aggregate(*a).wrap_err("Failed to apply aggregate to logical plan") } + RelType::WithColumns(w) => { + with_columns(*w).wrap_err("Failed to apply with_columns to logical plan") + } plan => bail!("Unsupported relation type: {plan:?}"), } } diff --git a/src/daft-connect/src/translation/logical_plan/with_columns.rs b/src/daft-connect/src/translation/logical_plan/with_columns.rs new file mode 100644 index 0000000000..2a96822f93 --- /dev/null +++ b/src/daft-connect/src/translation/logical_plan/with_columns.rs @@ -0,0 +1,32 @@ +use eyre::bail; +use spark_connect::{expression::ExprType, Expression}; + +use crate::translation::{to_daft_expr, to_logical_plan}; + +pub fn with_columns( + with_columns: spark_connect::WithColumns, +) -> eyre::Result { + let spark_connect::WithColumns { input, aliases } = with_columns; + + let Some(input) = input else { + bail!("input is required"); + }; + + let plan = to_logical_plan(*input)?; + + let daft_exprs: Vec<_> = aliases + .into_iter() + .map(|alias| { + let expression = Expression { + common: None, + expr_type: Some(ExprType::Alias(Box::new(alias))), + }; + + to_daft_expr(expression) + }) + .try_collect()?; + + let plan = plan.with_columns(daft_exprs)?; + + Ok(plan) +} diff --git a/tests/connect/test_group_by.py b/tests/connect/test_group_by.py new file mode 100644 index 0000000000..734b3e8f21 --- /dev/null +++ b/tests/connect/test_group_by.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from pyspark.sql.functions import col, count + + +def test_group_by(spark_session): + # Create DataFrame from range(10) + df = spark_session.range(10) + + # Add a column that will have repeated values for grouping + df = df.withColumn("group", col("id") % 3) + + # Group by the new column and sum the ids in each group + df_grouped = df.groupBy("group").sum("id") + + # Convert to pandas to verify the sums + df_grouped_pandas = df_grouped.toPandas() + + # Verify we have 3 groups + assert 
len(df_grouped_pandas) == 3, "Should have 3 groups (0, 1, 2)" + + # Verify the sums are correct + expected_sums = { + 0: 18, # 0 + 3 + 6 + 9 = 18 + 1: 12, # 1 + 4 + 7 = 12 + 2: 15 # 2 + 5 + 8 = 15 + } + for _, row in df_grouped_pandas.iterrows(): + assert row["sum(id)"] == expected_sums[row["group"]], f"Sum for group {row['group']} is incorrect"