Skip to content

Commit

Permalink
[FEAT]: connect: df.sort
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Dec 19, 2024
1 parent c30f6a8 commit 9ac4ec4
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/daft-connect/src/translation/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@ use futures::TryStreamExt;
use spark_connect::{relation::RelType, Limit, Relation, ShowString};
use tracing::warn;

use crate::translation::logical_plan::sort::sort;

mod aggregate;
mod drop;
mod filter;
mod local_relation;
mod project;
mod range;
mod read;
mod sort;
mod to_df;
mod with_columns;

Expand Down Expand Up @@ -113,6 +116,9 @@ impl SparkAnalyzer<'_> {
RelType::Read(r) => read::read(r)
.await
.wrap_err("Failed to apply read to logical plan"),
RelType::Sort(s) => sort(*s)
.await
.wrap_err("Failed to apply sort to logical plan"),
RelType::Drop(d) => self
.drop(*d)
.await
Expand Down
83 changes: 83 additions & 0 deletions src/daft-connect/src/translation/logical_plan/sort.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use eyre::{bail, WrapErr};
use spark_connect::expression::{
sort_order::{NullOrdering, SortDirection},
SortOrder,
};
use tracing::warn;

use crate::translation::{to_daft_expr, to_logical_plan, Plan};

pub async fn sort(sort: spark_connect::Sort) -> eyre::Result<Plan> {
let spark_connect::Sort {
input,
order,
is_global,
} = sort;

if let Some(is_global) = is_global {
warn!("Ignoring is_global {is_global}; not yet implemented");
}

let Some(input) = input else {
bail!("Input is required");
};

let mut plan = Box::pin(to_logical_plan(*input)).await?;

let mut sort_by = Vec::new();
let mut descending = Vec::new();
let mut nulls_first = Vec::new();

for o in &order {
let SortOrder {
child,
direction,
null_ordering,
} = o;

let Some(child) = child else {
bail!("Child is required");
};

let child = to_daft_expr(child)?;

let direction = SortDirection::try_from(*direction)
.wrap_err_with(|| format!("Invalid sort direction: {direction:?}"))?;

let null_ordering = NullOrdering::try_from(*null_ordering)
.wrap_err_with(|| format!("Invalid null ordering: {null_ordering:?}"))?;

// todo(correctness): is this correct?
let is_descending = match direction {
SortDirection::Unspecified => {
bail!("Unspecified sort direction is not yet supported")
}
SortDirection::Ascending => false,
SortDirection::Descending => true,
};

// todo(correctness): is this correct?
let tentative_sort_nulls_first = match null_ordering {
NullOrdering::SortNullsUnspecified => {
bail!("Unspecified null ordering is not yet supported")
}
NullOrdering::SortNullsFirst => true,
NullOrdering::SortNullsLast => false,
};

// https://github.com/Eventual-Inc/Daft/blob/7922d2d810ff92b00008d877aa9a6553bc0dedab/src/daft-core/src/utils/mod.rs#L10-L19
let sort_nulls_first = is_descending;

if sort_nulls_first != tentative_sort_nulls_first {
warn!("Ignoring nulls_first {sort_nulls_first}; not yet implemented");
}

sort_by.push(child);
descending.push(is_descending);
nulls_first.push(sort_nulls_first);
}

plan.builder = plan.builder.sort(sort_by, descending, nulls_first)?;

Ok(plan)
}
16 changes: 16 additions & 0 deletions tests/connect/test_sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from __future__ import annotations

from pyspark.sql.functions import col


def test_sort(spark_session):
    """Sorting a range DataFrame by id descending matches pandas' ordering."""
    df = spark_session.range(10)

    # Apply a descending sort on the 'id' column.
    sorted_df = df.sort(col("id").desc())

    # Build the expected ordering with pandas from the unsorted frame.
    expected = df.toPandas()["id"].sort_values(ascending=False).reset_index(drop=True)
    actual = sorted_df.toPandas()["id"]

    assert actual.equals(expected), "Data should be sorted in descending order"

0 comments on commit 9ac4ec4

Please sign in to comment.