Skip to content

Commit

Permalink
[FEAT] connect: add groupBy
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Nov 20, 2024
1 parent 7f0c85d commit b7cdee2
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 1 deletion.
17 changes: 17 additions & 0 deletions src/daft-connect/src/translation/expr/unresolved_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,27 @@ pub fn unresolved_to_daft_expr(f: UnresolvedFunction) -> eyre::Result<daft_dsl::
"count" => handle_count(arguments).wrap_err("Failed to handle count function"),
"<" => handle_binary_op(arguments, daft_dsl::Operator::Lt)
.wrap_err("Failed to handle < function"),
"%" => handle_binary_op(arguments, daft_dsl::Operator::Modulus)
.wrap_err("Failed to handle % function"),
"sum" => handle_sum(arguments).wrap_err("Failed to handle sum function"),
n => bail!("Unresolved function {n} not yet supported"),
}
}

/// Translate Spark's `sum` aggregation into a Daft sum expression.
///
/// Requires exactly one argument expression; any other arity is an error.
pub fn handle_sum(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
    // Destructuring into a fixed-size array enforces the single-argument arity;
    // `try_from` hands the Vec back on failure so we can include it in the message.
    let [arg] = <[daft_dsl::ExprRef; 1]>::try_from(arguments)
        .map_err(|arguments| eyre::eyre!("requires exactly one argument; got {arguments:?}"))?;

    Ok(arg.sum())
}



pub fn handle_binary_op(
arguments: Vec<daft_dsl::ExprRef>,
op: daft_dsl::Operator,
Expand Down
8 changes: 7 additions & 1 deletion src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ use eyre::{bail, Context};
use spark_connect::{relation::RelType, Relation};
use tracing::warn;

use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range};
use crate::translation::logical_plan::{
aggregate::aggregate, project::project, range::range, with_columns::with_columns,
};

mod aggregate;
mod project;
mod range;
mod with_columns;

pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
if let Some(common) = relation.common {
Expand All @@ -24,6 +27,9 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
RelType::Aggregate(a) => {
aggregate(*a).wrap_err("Failed to apply aggregate to logical plan")
}
RelType::WithColumns(w) => {
with_columns(*w).wrap_err("Failed to apply with_columns to logical plan")
}
plan => bail!("Unsupported relation type: {plan:?}"),
}
}
3 changes: 3 additions & 0 deletions src/daft-connect/src/translation/logical_plan/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ pub fn aggregate(aggregate: spark_connect::Aggregate) -> eyre::Result<LogicalPla
bail!("GroupType must be specified; got Unspecified")
}
GroupType::Groupby => {
println!("grouping_expressions: {grouping_expressions:?}");
println!("aggregate_expressions: {aggregate_expressions:?}");

let plan = plan
.aggregate(grouping_expressions, aggregate_expressions)
.wrap_err("Failed to apply aggregate to logical plan")?;
Expand Down
32 changes: 32 additions & 0 deletions src/daft-connect/src/translation/logical_plan/with_columns.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use eyre::bail;
use spark_connect::{expression::ExprType, Expression};

use crate::translation::{to_daft_expr, to_logical_plan};

pub fn with_columns(
with_columns: spark_connect::WithColumns,
) -> eyre::Result<daft_logical_plan::LogicalPlanBuilder> {
let spark_connect::WithColumns { input, aliases } = with_columns;

let Some(input) = input else {
bail!("input is required");
};

let plan = to_logical_plan(*input)?;

let daft_exprs: Vec<_> = aliases
.into_iter()
.map(|alias| {
let expression = Expression {
common: None,
expr_type: Some(ExprType::Alias(Box::new(alias))),
};

to_daft_expr(expression)
})
.try_collect()?;

let plan = plan.with_columns(daft_exprs)?;

Ok(plan)
}
29 changes: 29 additions & 0 deletions tests/connect/test_group_by.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from __future__ import annotations

from pyspark.sql.functions import col, count


def test_group_by(spark_session):
    """groupBy on a derived key followed by sum should aggregate correctly."""
    # Create DataFrame from range(10): ids 0 through 9
    df = spark_session.range(10)

    # Add a column that will have repeated values for grouping
    df = df.withColumn("group", col("id") % 3)

    # Group by the new column and sum the ids in each group
    df_grouped = df.groupBy("group").sum("id")

    # Convert to pandas to verify the sums
    df_grouped_pandas = df_grouped.toPandas()

    # Verify we have 3 groups
    assert len(df_grouped_pandas) == 3, "Should have 3 groups (0, 1, 2)"

    # Verify the sums are correct: ids 0..9 bucketed by id % 3.
    # (Previous values 9/10/11 contradicted the arithmetic in the comments.)
    expected_sums = {
        0: 18,  # 0 + 3 + 6 + 9 = 18
        1: 12,  # 1 + 4 + 7 = 12
        2: 15,  # 2 + 5 + 8 = 15
    }
    for _, row in df_grouped_pandas.iterrows():
        assert row["sum(id)"] == expected_sums[row["group"]], f"Sum for group {row['group']} is incorrect"

0 comments on commit b7cdee2

Please sign in to comment.