Skip to content

Commit

Permalink
[FEAT] connect: add repartition support
Browse files Browse the repository at this point in the history
- [ ] the test is not great but idk how to do it better since rdd does
  not work with spark connect (I think)
- [ ] do we want to support non-shuffle repartitioning?
  • Loading branch information
andrewgazelka committed Nov 21, 2024
1 parent e4dadeb commit e1aa22e
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 1 deletion.
6 changes: 5 additions & 1 deletion src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ use spark_connect::{relation::RelType, Relation};
use tracing::warn;

use crate::translation::logical_plan::{
aggregate::aggregate, project::project, range::range, set_op::set_op,
aggregate::aggregate, project::project, range::range, repartition::repartition, set_op::set_op,
with_columns::with_columns,
};

mod aggregate;
mod project;
mod range;
mod repartition;
mod set_op;
mod with_columns;

Expand All @@ -32,6 +33,9 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
RelType::WithColumns(w) => {
with_columns(*w).wrap_err("Failed to apply with_columns to logical plan")
}
RelType::Repartition(r) => {
repartition(*r).wrap_err("Failed to apply repartition to logical plan")
}
RelType::SetOp(s) => set_op(*s).wrap_err("Failed to apply set_op to logical plan"),
plan => bail!("Unsupported relation type: {plan:?}"),
}
Expand Down
41 changes: 41 additions & 0 deletions src/daft-connect/src/translation/logical_plan/repartition.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use eyre::{bail, ensure, WrapErr};

use crate::translation::to_logical_plan;

pub fn repartition(
repartition: spark_connect::Repartition,
) -> eyre::Result<daft_logical_plan::LogicalPlanBuilder> {
let spark_connect::Repartition {
input,
num_partitions,
shuffle,
} = repartition;

let Some(input) = input else {
bail!("Input is required");
};

let num_partitions = usize::try_from(num_partitions).map_err(|_| {
eyre::eyre!("Num partitions must be a positive integer, got {num_partitions}")
})?;

ensure!(
num_partitions > 0,
"Num partitions must be greater than 0, got {num_partitions}"
);

let plan = to_logical_plan(*input)?;

// let's make true is default
let shuffle = shuffle.unwrap_or(true);

if !shuffle {
bail!("Repartitioning without shuffling is not yet supported");
}

let plan = plan
.random_shuffle(Some(num_partitions))
.wrap_err("Failed to apply random_shuffle to logical plan")?;

Ok(plan)
}
13 changes: 13 additions & 0 deletions tests/connect/test_repartition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from __future__ import annotations

def test_repartition(spark_session):
# Create a simple DataFrame
df = spark_session.range(10)

# Test repartitioning to 2 partitions
repartitioned = df.repartition(2)

# Verify data is preserved after repartitioning
original_data = sorted(df.collect())
repartitioned_data = sorted(repartitioned.collect())
assert repartitioned_data == original_data, "Data should be preserved after repartitioning"

0 comments on commit e1aa22e

Please sign in to comment.