Skip to content

Commit

Permalink
[FEAT] connect: add groupBy
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Nov 20, 2024
1 parent b399d76 commit 128d640
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 1 deletion.
17 changes: 17 additions & 0 deletions src/daft-connect/src/translation/expr/unresolved_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,27 @@ pub fn unresolved_to_daft_expr(f: &UnresolvedFunction) -> eyre::Result<daft_dsl:
"count" => handle_count(arguments).wrap_err("Failed to handle count function"),
"<" => handle_binary_op(arguments, daft_dsl::Operator::Lt)
.wrap_err("Failed to handle < function"),
"%" => handle_binary_op(arguments, daft_dsl::Operator::Modulus)
.wrap_err("Failed to handle % function"),
"sum" => handle_sum(arguments).wrap_err("Failed to handle sum function"),
n => bail!("Unresolved function {n} not yet supported"),
}
}

pub fn handle_sum(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() {
Ok(arguments) => arguments,
Err(arguments) => {
bail!("requires exactly one argument; got {arguments:?}");
}
};

let [arg] = arguments;
Ok(arg.sum())
}



pub fn handle_binary_op(
arguments: Vec<daft_dsl::ExprRef>,
op: daft_dsl::Operator,
Expand Down
8 changes: 7 additions & 1 deletion src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ use eyre::{bail, Context};
use spark_connect::{relation::RelType, Relation};
use tracing::warn;

use crate::translation::logical_plan::{aggregate::aggregate, project::project, range::range};
use crate::translation::logical_plan::{
aggregate::aggregate, project::project, range::range, with_columns::with_columns,
};

mod aggregate;
mod project;
mod range;
mod with_columns;

pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
if let Some(common) = relation.common {
Expand All @@ -24,6 +27,9 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
RelType::Aggregate(a) => {
aggregate(*a).wrap_err("Failed to apply aggregate to logical plan")
}
RelType::WithColumns(w) => {
with_columns(*w).wrap_err("Failed to apply with_columns to logical plan")
}
plan => bail!("Unsupported relation type: {plan:?}"),
}
}
32 changes: 32 additions & 0 deletions src/daft-connect/src/translation/logical_plan/with_columns.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use eyre::bail;
use spark_connect::{expression::ExprType, Expression};

use crate::translation::{to_daft_expr, to_logical_plan};

pub fn with_columns(
with_columns: spark_connect::WithColumns,
) -> eyre::Result<daft_logical_plan::LogicalPlanBuilder> {
let spark_connect::WithColumns { input, aliases } = with_columns;

let Some(input) = input else {
bail!("input is required");
};

let plan = to_logical_plan(*input)?;

let daft_exprs: Vec<_> = aliases
.into_iter()
.map(|alias| {
let expression = Expression {
common: None,
expr_type: Some(ExprType::Alias(Box::new(alias))),
};

to_daft_expr(expression)
})
.try_collect()?;

let plan = plan.with_columns(daft_exprs)?;

Ok(plan)
}
35 changes: 35 additions & 0 deletions tests/connect/test_group_by.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

from pyspark.sql.functions import col


def test_group_by(spark_session):
# Create DataFrame from range(10)
df = spark_session.range(10)

# Add a column that will have repeated values for grouping
df = df.withColumn("group", col("id") % 3)

# Group by the new column and sum the ids in each group
df_grouped = df.groupBy("group").sum("id")

# Convert to pandas to verify the sums
df_grouped_pandas = df_grouped.toPandas()

print(df_grouped_pandas)

# Sort by group to ensure consistent order for comparison
df_grouped_pandas = df_grouped_pandas.sort_values("group").reset_index(drop=True)

# Verify the expected sums for each group
# group id
# 0 2 15
# 1 1 12
# 2 0 18
expected = {
"group": [0, 1, 2],
"id": [18, 12, 15], # todo(correctness): should this be "id" for value here?
}

assert df_grouped_pandas["group"].tolist() == expected["group"]
assert df_grouped_pandas["id"].tolist() == expected["id"]

0 comments on commit 128d640

Please sign in to comment.