Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] connect: add modulus operator and withColumns support #3351

Merged
merged 1 commit into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/daft-connect/src/translation/expr/unresolved_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,27 @@
.wrap_err("Failed to handle <= function"),
">=" => handle_binary_op(arguments, daft_dsl::Operator::GtEq)
.wrap_err("Failed to handle >= function"),
"%" => handle_binary_op(arguments, daft_dsl::Operator::Modulus)
.wrap_err("Failed to handle % function"),
"sum" => handle_sum(arguments).wrap_err("Failed to handle sum function"),
"isnotnull" => handle_isnotnull(arguments).wrap_err("Failed to handle isnotnull function"),
"isnull" => handle_isnull(arguments).wrap_err("Failed to handle isnull function"),
n => bail!("Unresolved function {n} not yet supported"),
}
}

/// Translate a Spark `sum` call into a daft sum aggregation.
///
/// Expects exactly one argument expression; any other arity is an error
/// (the caller wraps it with "Failed to handle sum function" context).
pub fn handle_sum(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl::ExprRef> {
    // Converting the Vec into a fixed-size array both checks the arity and
    // hands back the original Vec on failure so we can report what we got.
    match <[daft_dsl::ExprRef; 1]>::try_from(arguments) {
        Ok([arg]) => Ok(arg.sum()),
        Err(arguments) => bail!("requires exactly one argument; got {arguments:?}"),
    }
}

pub fn handle_binary_op(
arguments: Vec<daft_dsl::ExprRef>,
op: daft_dsl::Operator,
Expand Down
6 changes: 5 additions & 1 deletion src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@ use tracing::warn;

use crate::translation::logical_plan::{
aggregate::aggregate, local_relation::local_relation, project::project, range::range,
to_df::to_df,
to_df::to_df, with_columns::with_columns,
};

mod aggregate;
mod local_relation;
mod project;
mod range;
mod to_df;
mod with_columns;

#[derive(Constructor)]
pub struct Plan {
Expand Down Expand Up @@ -49,6 +50,9 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<Plan> {
RelType::Aggregate(a) => {
aggregate(*a).wrap_err("Failed to apply aggregate to logical plan")
}
RelType::WithColumns(w) => {
with_columns(*w).wrap_err("Failed to apply with_columns to logical plan")
}
RelType::ToDf(t) => to_df(*t).wrap_err("Failed to apply to_df to logical plan"),
RelType::LocalRelation(l) => {
local_relation(l).wrap_err("Failed to apply local_relation to logical plan")
Expand Down
30 changes: 30 additions & 0 deletions src/daft-connect/src/translation/logical_plan/with_columns.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
use eyre::bail;
use spark_connect::{expression::ExprType, Expression};

use crate::translation::{to_daft_expr, to_logical_plan, Plan};

/// Translate a Spark Connect `WithColumns` relation into a daft logical plan.
///
/// Translates the child relation first, then layers every alias on top of it
/// as an added/replaced column via the plan builder.
pub fn with_columns(with_columns: spark_connect::WithColumns) -> eyre::Result<Plan> {
    let spark_connect::WithColumns { input, aliases } = with_columns;

    // The child relation is mandatory; without it there is nothing to extend.
    let input = match input {
        Some(input) => input,
        None => bail!("input is required"),
    };

    let mut plan = to_logical_plan(*input)?;

    // Wrap each alias in a standalone Alias expression and translate it,
    // short-circuiting on the first translation failure.
    let mut daft_exprs = Vec::with_capacity(aliases.len());
    for alias in aliases {
        let expression = Expression {
            common: None,
            expr_type: Some(ExprType::Alias(Box::new(alias))),
        };
        daft_exprs.push(to_daft_expr(&expression)?);
    }

    plan.builder = plan.builder.with_columns(daft_exprs)?;

    Ok(plan)
}
andrewgazelka marked this conversation as resolved.
Show resolved Hide resolved
33 changes: 33 additions & 0 deletions tests/connect/test_group_by.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from __future__ import annotations

from pyspark.sql.functions import col


def test_group_by(spark_session):
    """Group ids 0..9 by id % 3 and verify the per-group sums."""
    # Build the source frame and derive a three-bucket grouping key in one go.
    df = spark_session.range(10).withColumn("group", col("id") % 3)

    # Sum the ids within each bucket.
    grouped = df.groupBy("group").sum("id")

    # Materialize and order deterministically so the comparison is stable.
    result = grouped.toPandas().sort_values("group").reset_index(drop=True)

    # range(10) partitioned by id % 3:
    #   group 0 -> 0 + 3 + 6 + 9 = 18
    #   group 1 -> 1 + 4 + 7     = 12
    #   group 2 -> 2 + 5 + 8     = 15
    expected = {
        "group": [0, 1, 2],
        "id": [18, 12, 15],  # todo(correctness): should this be "id" for value here?
    }

    assert result["group"].tolist() == expected["group"]
    assert result["id"].tolist() == expected["id"]
Loading