-
Notifications
You must be signed in to change notification settings - Fork 174
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3394a66
commit f7e7817
Showing
3 changed files
with
102 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
use eyre::{bail, WrapErr}; | ||
use spark_connect::expression::{ | ||
sort_order::{NullOrdering, SortDirection}, | ||
SortOrder, | ||
}; | ||
use tracing::warn; | ||
|
||
use crate::translation::{logical_plan::LogicalPlanBuilder, to_daft_expr, to_logical_plan}; | ||
|
||
pub fn sort(sort: spark_connect::Sort) -> eyre::Result<LogicalPlanBuilder> { | ||
let spark_connect::Sort { | ||
input, | ||
order, | ||
is_global, | ||
} = sort; | ||
|
||
if let Some(is_global) = is_global { | ||
warn!("Ignoring is_global {is_global}; not yet implemented"); | ||
} | ||
|
||
let Some(input) = input else { | ||
bail!("Input is required"); | ||
}; | ||
|
||
let plan = to_logical_plan(*input)?; | ||
|
||
let mut sort_by = Vec::new(); | ||
let mut descending = Vec::new(); | ||
let mut nulls_first = Vec::new(); | ||
|
||
for o in &order { | ||
let SortOrder { | ||
child, | ||
direction, | ||
null_ordering, | ||
} = o; | ||
|
||
let Some(child) = child else { | ||
bail!("Child is required"); | ||
}; | ||
|
||
let child = to_daft_expr(child)?; | ||
|
||
let direction = SortDirection::try_from(*direction) | ||
.wrap_err_with(|| format!("Invalid sort direction: {direction:?}"))?; | ||
|
||
let null_ordering = NullOrdering::try_from(*null_ordering) | ||
.wrap_err_with(|| format!("Invalid null ordering: {null_ordering:?}"))?; | ||
|
||
// todo(correctness): is this correct? | ||
let is_descending = match direction { | ||
SortDirection::Unspecified => { | ||
bail!("Unspecified sort direction is not yet supported") | ||
} | ||
SortDirection::Ascending => false, | ||
SortDirection::Descending => true, | ||
}; | ||
|
||
// todo(correctness): is this correct? | ||
let tentative_sort_nulls_first = match null_ordering { | ||
NullOrdering::SortNullsUnspecified => { | ||
bail!("Unspecified null ordering is not yet supported") | ||
} | ||
NullOrdering::SortNullsFirst => true, | ||
NullOrdering::SortNullsLast => false, | ||
}; | ||
|
||
// https://github.com/Eventual-Inc/Daft/blob/7922d2d810ff92b00008d877aa9a6553bc0dedab/src/daft-core/src/utils/mod.rs#L10-L19 | ||
let sort_nulls_first = is_descending; | ||
|
||
if sort_nulls_first != tentative_sort_nulls_first { | ||
warn!("Ignoring nulls_first {sort_nulls_first}; not yet implemented"); | ||
} | ||
|
||
sort_by.push(child); | ||
descending.push(is_descending); | ||
nulls_first.push(sort_nulls_first); | ||
} | ||
|
||
let plan = plan.sort(sort_by, descending, nulls_first)?; | ||
|
||
Ok(plan) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from __future__ import annotations | ||
|
||
from pyspark.sql.functions import col | ||
|
||
|
||
def test_sort(spark_session): | ||
# Create DataFrame from range(10) | ||
df = spark_session.range(10) | ||
|
||
# Sort the DataFrame by 'id' column in descending order | ||
df_sorted = df.sort(col("id").desc()) | ||
|
||
# Verify the DataFrame is sorted correctly | ||
df_pandas = df.toPandas() | ||
df_sorted_pandas = df_sorted.toPandas() | ||
assert df_sorted_pandas["id"].equals(df_pandas["id"].sort_values(ascending=False).reset_index(drop=True)), "Data should be sorted in descending order" |