-
Notifications
You must be signed in to change notification settings - Fork 174
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[FEAT] connect: add repartition support
- [ ] the test is not great but idk how to do it better since rdd does not work with spark connect (I think) - [ ] do we want to support non-shuffle repartitioning?
- Loading branch information
1 parent
14d1d50
commit 581eb36
Showing
3 changed files
with
59 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
41 changes: 41 additions & 0 deletions
41
src/daft-connect/src/translation/logical_plan/repartition.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
use eyre::{bail, ensure, WrapErr}; | ||
|
||
use crate::translation::to_logical_plan; | ||
|
||
pub fn repartition( | ||
repartition: spark_connect::Repartition, | ||
) -> eyre::Result<daft_logical_plan::LogicalPlanBuilder> { | ||
let spark_connect::Repartition { | ||
input, | ||
num_partitions, | ||
shuffle, | ||
} = repartition; | ||
|
||
let Some(input) = input else { | ||
bail!("Input is required"); | ||
}; | ||
|
||
let num_partitions = usize::try_from(num_partitions).map_err(|_| { | ||
eyre::eyre!("Num partitions must be a positive integer, got {num_partitions}") | ||
})?; | ||
|
||
ensure!( | ||
num_partitions > 0, | ||
"Num partitions must be greater than 0, got {num_partitions}" | ||
); | ||
|
||
let plan = to_logical_plan(*input)?; | ||
|
||
// let's make true is default | ||
let shuffle = shuffle.unwrap_or(true); | ||
|
||
if !shuffle { | ||
bail!("Repartitioning without shuffling is not yet supported"); | ||
} | ||
|
||
let plan = plan | ||
.random_shuffle(Some(num_partitions)) | ||
.wrap_err("Failed to apply random_shuffle to logical plan")?; | ||
|
||
Ok(plan) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from __future__ import annotations | ||
|
||
def test_repartition(spark_session): | ||
# Create a simple DataFrame | ||
df = spark_session.range(10) | ||
|
||
# Test repartitioning to 2 partitions | ||
repartitioned = df.repartition(2) | ||
|
||
# Verify data is preserved after repartitioning | ||
original_data = sorted(df.collect()) | ||
repartitioned_data = sorted(repartitioned.collect()) | ||
assert repartitioned_data == original_data, "Data should be preserved after repartitioning" |