Skip to content

Commit

Permalink
[FEAT] connect: add groupBy
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Dec 5, 2024
1 parent 86523a0 commit 3d691bb
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 1 deletion.
15 changes: 15 additions & 0 deletions src/daft-connect/src/translation/expr/unresolved_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,27 @@ pub fn unresolved_to_daft_expr(f: &UnresolvedFunction) -> eyre::Result<daft_dsl:
.wrap_err("Failed to handle <= function"),
">=" => handle_binary_op(arguments, daft_dsl::Operator::GtEq)
.wrap_err("Failed to handle >= function"),
"%" => handle_binary_op(arguments, daft_dsl::Operator::Modulus)
.wrap_err("Failed to handle % function"),
"sum" => handle_sum(arguments).wrap_err("Failed to handle sum function"),
"isnotnull" => handle_isnotnull(arguments).wrap_err("Failed to handle isnotnull function"),
"isnull" => handle_isnull(arguments).wrap_err("Failed to handle isnull function"),
n => bail!("Unresolved function {n} not yet supported"),
}
}

pub fn handle_sum(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
let arguments: [daft_dsl::ExprRef; 1] = match arguments.try_into() {
Ok(arguments) => arguments,
Err(arguments) => {
bail!("requires exactly one argument; got {arguments:?}");

Check warning on line 48 in src/daft-connect/src/translation/expr/unresolved_function.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/expr/unresolved_function.rs#L47-L48

Added lines #L47 - L48 were not covered by tests
}
};

let [arg] = arguments;
Ok(arg.sum())
}

pub fn handle_binary_op(
arguments: Vec<daft_dsl::ExprRef>,
op: daft_dsl::Operator,
Expand Down
6 changes: 5 additions & 1 deletion src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@ use tracing::warn;

use crate::translation::logical_plan::{
aggregate::aggregate, local_relation::local_relation, project::project, range::range,
to_df::to_df,
to_df::to_df, with_columns::with_columns,
};

mod aggregate;
mod local_relation;
mod project;
mod range;
mod to_df;
mod with_columns;

#[derive(Constructor)]
pub struct Plan {
Expand Down Expand Up @@ -49,6 +50,9 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<Plan> {
RelType::Aggregate(a) => {
aggregate(*a).wrap_err("Failed to apply aggregate to logical plan")
}
RelType::WithColumns(w) => {
with_columns(*w).wrap_err("Failed to apply with_columns to logical plan")
}
RelType::ToDf(t) => to_df(*t).wrap_err("Failed to apply to_df to logical plan"),
RelType::LocalRelation(l) => {
local_relation(l).wrap_err("Failed to apply local_relation to logical plan")
Expand Down
30 changes: 30 additions & 0 deletions src/daft-connect/src/translation/logical_plan/with_columns.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use eyre::bail;
use spark_connect::{expression::ExprType, Expression};

use crate::translation::{to_daft_expr, to_logical_plan, Plan};

pub fn with_columns(with_columns: spark_connect::WithColumns) -> eyre::Result<Plan> {
let spark_connect::WithColumns { input, aliases } = with_columns;

let Some(input) = input else {
bail!("input is required");

Check warning on line 10 in src/daft-connect/src/translation/logical_plan/with_columns.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/logical_plan/with_columns.rs#L10

Added line #L10 was not covered by tests
};

let mut plan = to_logical_plan(*input)?;

let daft_exprs: Vec<_> = aliases
.into_iter()
.map(|alias| {
let expression = Expression {
common: None,
expr_type: Some(ExprType::Alias(Box::new(alias))),
};

to_daft_expr(&expression)
})
.try_collect()?;

plan.builder = plan.builder.with_columns(daft_exprs)?;

Ok(plan)
}
33 changes: 33 additions & 0 deletions tests/connect/test_group_by.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import annotations

from pyspark.sql.functions import col


def test_group_by(spark_session):
# Create DataFrame from range(10)
df = spark_session.range(10)

# Add a column that will have repeated values for grouping
df = df.withColumn("group", col("id") % 3)

# Group by the new column and sum the ids in each group
df_grouped = df.groupBy("group").sum("id")

# Convert to pandas to verify the sums
df_grouped_pandas = df_grouped.toPandas()

# Sort by group to ensure consistent order for comparison
df_grouped_pandas = df_grouped_pandas.sort_values("group").reset_index(drop=True)

# Verify the expected sums for each group
# group id
# 0 2 15
# 1 1 12
# 2 0 18
expected = {
"group": [0, 1, 2],
"id": [18, 12, 15], # todo(correctness): should this be "id" for value here?
}

assert df_grouped_pandas["group"].tolist() == expected["group"]
assert df_grouped_pandas["id"].tolist() == expected["id"]

0 comments on commit 3d691bb

Please sign in to comment.