diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs
index 53a0cfc923..a7d97b0d15 100644
--- a/src/daft-connect/src/translation/logical_plan.rs
+++ b/src/daft-connect/src/translation/logical_plan.rs
@@ -4,13 +4,14 @@ use spark_connect::{relation::RelType, Limit, Relation};
 use tracing::warn;
 
 use crate::translation::logical_plan::{
-    aggregate::aggregate, project::project, range::range, set_op::set_op,
+    aggregate::aggregate, project::project, range::range, repartition::repartition, set_op::set_op,
     with_columns::with_columns,
 };
 
 mod aggregate;
 mod project;
 mod range;
+mod repartition;
 mod set_op;
 mod with_columns;
 
@@ -33,6 +34,9 @@ pub fn to_logical_plan(relation: Relation) -> eyre::Result<daft_logical_plan::LogicalPlanBuilder> {
         RelType::WithColumns(w) => {
             with_columns(*w).wrap_err("Failed to apply with_columns to logical plan")
         }
+        RelType::Repartition(r) => {
+            repartition(*r).wrap_err("Failed to apply repartition to logical plan")
+        }
         RelType::SetOp(s) => set_op(*s).wrap_err("Failed to apply set_op to logical plan"),
         plan => bail!("Unsupported relation type: {plan:?}"),
     }
diff --git a/src/daft-connect/src/translation/logical_plan/repartition.rs b/src/daft-connect/src/translation/logical_plan/repartition.rs
new file mode 100644
index 0000000000..76ff0175fc
--- /dev/null
+++ b/src/daft-connect/src/translation/logical_plan/repartition.rs
@@ -0,0 +1,41 @@
+use eyre::{bail, ensure, WrapErr};
+
+use crate::translation::to_logical_plan;
+
+pub fn repartition(
+    repartition: spark_connect::Repartition,
+) -> eyre::Result<daft_logical_plan::LogicalPlanBuilder> {
+    let spark_connect::Repartition {
+        input,
+        num_partitions,
+        shuffle,
+    } = repartition;
+
+    let Some(input) = input else {
+        bail!("Input is required");
+    };
+
+    let num_partitions = usize::try_from(num_partitions).map_err(|_| {
+        eyre::eyre!("Num partitions must be a positive integer, got {num_partitions}")
+    })?;
+
+    ensure!(
+        num_partitions > 0,
+        "Num partitions must be greater than 0, got {num_partitions}"
+    );
+
+    let plan = to_logical_plan(*input)?;
+
+    // Treat an unset `shuffle` flag as true, i.e. shuffle by default.
+    let shuffle = shuffle.unwrap_or(true);
+
+    if !shuffle {
+        bail!("Repartitioning without shuffling is not yet supported");
+    }
+
+    let plan = plan
+        .random_shuffle(Some(num_partitions))
+        .wrap_err("Failed to apply random_shuffle to logical plan")?;
+
+    Ok(plan)
+}
diff --git a/tests/connect/test_repartition.py b/tests/connect/test_repartition.py
new file mode 100644
index 0000000000..7d7c8e25f6
--- /dev/null
+++ b/tests/connect/test_repartition.py
@@ -0,0 +1,13 @@
+from __future__ import annotations
+
+def test_repartition(spark_session):
+    # Create a simple DataFrame
+    df = spark_session.range(10)
+
+    # Test repartitioning to 2 partitions
+    repartitioned = df.repartition(2)
+
+    # Verify data is preserved after repartitioning
+    original_data = sorted(df.collect())
+    repartitioned_data = sorted(repartitioned.collect())
+    assert repartitioned_data == original_data, "Data should be preserved after repartitioning"