From 9ac4ec4dc70600c6c7ca10c648cc668f6978375f Mon Sep 17 00:00:00 2001
From: Andrew Gazelka
Date: Wed, 20 Nov 2024 01:23:05 -0800
Subject: [PATCH] [FEAT]: connect: `df.sort`

---
 .../src/translation/logical_plan.rs           |  6 ++
 .../src/translation/logical_plan/sort.rs      | 96 +++++++++++++++++++
 tests/connect/test_sort.py                    | 16 ++++
 3 files changed, 118 insertions(+)
 create mode 100644 src/daft-connect/src/translation/logical_plan/sort.rs
 create mode 100644 tests/connect/test_sort.py

diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs
index 439f5bd551..66d9a3e623 100644
--- a/src/daft-connect/src/translation/logical_plan.rs
+++ b/src/daft-connect/src/translation/logical_plan.rs
@@ -18,6 +18,8 @@ use futures::TryStreamExt;
 use spark_connect::{relation::RelType, Limit, Relation, ShowString};
 use tracing::warn;
 
+use crate::translation::logical_plan::sort::sort;
+
 mod aggregate;
 mod drop;
 mod filter;
@@ -25,6 +27,7 @@ mod local_relation;
 mod project;
 mod range;
 mod read;
+mod sort;
 mod to_df;
 mod with_columns;
 
@@ -113,6 +116,9 @@ impl SparkAnalyzer<'_> {
             RelType::Read(r) => read::read(r)
                 .await
                 .wrap_err("Failed to apply read to logical plan"),
+            RelType::Sort(s) => sort(*s)
+                .await
+                .wrap_err("Failed to apply sort to logical plan"),
             RelType::Drop(d) => self
                 .drop(*d)
                 .await
diff --git a/src/daft-connect/src/translation/logical_plan/sort.rs b/src/daft-connect/src/translation/logical_plan/sort.rs
new file mode 100644
index 0000000000..3884ca9984
--- /dev/null
+++ b/src/daft-connect/src/translation/logical_plan/sort.rs
@@ -0,0 +1,96 @@
+use eyre::{bail, WrapErr};
+use spark_connect::expression::{
+    sort_order::{NullOrdering, SortDirection},
+    SortOrder,
+};
+use tracing::warn;
+
+use crate::translation::{to_daft_expr, to_logical_plan, Plan};
+
+/// Translates a Spark Connect `Sort` relation into a sorted plan.
+///
+/// The input relation is translated first. Each `SortOrder` then contributes a sort
+/// expression, a descending flag and a nulls-first flag to the builder's `sort` call.
+///
+/// `is_global` and the requested null ordering are not implemented yet: the former is
+/// ignored with a warning, the latter is replaced by "nulls first iff descending".
+///
+/// # Errors
+///
+/// Fails when the input or a sort child is missing, when a direction or null ordering
+/// is invalid or unspecified, or when translating the input, a child or the sort fails.
+pub async fn sort(sort: spark_connect::Sort) -> eyre::Result<Plan> {
+    let spark_connect::Sort {
+        input,
+        order,
+        is_global,
+    } = sort;
+
+    if let Some(is_global) = is_global {
+        warn!("Ignoring is_global {is_global}; not yet implemented");
+    }
+
+    let Some(input) = input else {
+        bail!("Input is required");
+    };
+
+    let mut plan = Box::pin(to_logical_plan(*input)).await?;
+
+    let mut sort_by = Vec::new();
+    let mut descending = Vec::new();
+    let mut nulls_first = Vec::new();
+
+    for o in &order {
+        let SortOrder {
+            child,
+            direction,
+            null_ordering,
+        } = o;
+
+        let Some(child) = child else {
+            bail!("Child is required");
+        };
+
+        let child = to_daft_expr(child)?;
+
+        let direction = SortDirection::try_from(*direction)
+            .wrap_err_with(|| format!("Invalid sort direction: {direction:?}"))?;
+
+        let null_ordering = NullOrdering::try_from(*null_ordering)
+            .wrap_err_with(|| format!("Invalid null ordering: {null_ordering:?}"))?;
+
+        // todo(correctness): is this correct?
+        let is_descending = match direction {
+            SortDirection::Unspecified => {
+                bail!("Unspecified sort direction is not yet supported")
+            }
+            SortDirection::Ascending => false,
+            SortDirection::Descending => true,
+        };
+
+        // todo(correctness): is this correct?
+        let tentative_sort_nulls_first = match null_ordering {
+            NullOrdering::SortNullsUnspecified => {
+                bail!("Unspecified null ordering is not yet supported")
+            }
+            NullOrdering::SortNullsFirst => true,
+            NullOrdering::SortNullsLast => false,
+        };
+
+        // https://github.com/Eventual-Inc/Daft/blob/7922d2d810ff92b00008d877aa9a6553bc0dedab/src/daft-core/src/utils/mod.rs#L10-L19
+        let sort_nulls_first = is_descending;
+
+        if sort_nulls_first != tentative_sort_nulls_first {
+            // Report the requested value: that is the one being ignored.
+            warn!("Ignoring nulls_first {tentative_sort_nulls_first}; not yet implemented");
+        }
+
+        sort_by.push(child);
+        descending.push(is_descending);
+        nulls_first.push(sort_nulls_first);
+    }
+
+    plan.builder = plan.builder.sort(sort_by, descending, nulls_first)?;
+
+    Ok(plan)
+}
diff --git a/tests/connect/test_sort.py b/tests/connect/test_sort.py
new file mode 100644
index 0000000000..653510db18
--- /dev/null
+++ b/tests/connect/test_sort.py
@@ -0,0 +1,16 @@
+from __future__ import annotations
+
+from pyspark.sql.functions import col
+
+
+def test_sort(spark_session):
+    # Create DataFrame from range(10)
+    df = spark_session.range(10)
+
+    # Sort the DataFrame by 'id' column in descending order
+    df_sorted = df.sort(col("id").desc())
+
+    # Verify the DataFrame is sorted correctly
+    df_pandas = df.toPandas()
+    df_sorted_pandas = df_sorted.toPandas()
+    assert df_sorted_pandas["id"].equals(df_pandas["id"].sort_values(ascending=False).reset_index(drop=True)), "Data should be sorted in descending order"