diff --git a/src/daft-connect/src/translation/logical_plan/sort.rs b/src/daft-connect/src/translation/logical_plan/sort.rs index 15baccbcd8..695c4d5756 100644 --- a/src/daft-connect/src/translation/logical_plan/sort.rs +++ b/src/daft-connect/src/translation/logical_plan/sort.rs @@ -50,7 +50,8 @@ pub fn sort(sort: spark_connect::Sort) -> eyre::Result { // todo(correctness): is this correct? let is_descending = match direction { SortDirection::Unspecified => { - bail!("Unspecified sort direction is not yet supported") + // default to ascending order + false } SortDirection::Ascending => false, SortDirection::Descending => true, @@ -59,7 +60,8 @@ pub fn sort(sort: spark_connect::Sort) -> eyre::Result { // todo(correctness): is this correct? let tentative_sort_nulls_first = match null_ordering { NullOrdering::SortNullsUnspecified => { - bail!("Unspecified null ordering is not yet supported") + // default: match is_descending + is_descending } NullOrdering::SortNullsFirst => true, NullOrdering::SortNullsLast => false, diff --git a/tests/connect/test_sort.py b/tests/connect/test_sort.py index 653510db18..112a1fc658 100644 --- a/tests/connect/test_sort.py +++ b/tests/connect/test_sort.py @@ -3,14 +3,33 @@ from pyspark.sql.functions import col -def test_sort(spark_session): - # Create DataFrame from range(10) - df = spark_session.range(10) +def test_sort_multiple_columns(spark_session): + # Create DataFrame with two columns using range + df = spark_session.range(4).select( + (col("id") % 2).alias("num"), + (col("id") % 2).cast("string").alias("letter") + ) - # Sort the DataFrame by 'id' column in descending order - df_sorted = df.sort(col("id").desc()) + # Sort by multiple columns + df_sorted = df.sort(col("num").asc(), col("letter").desc()) # Verify the DataFrame is sorted correctly - df_pandas = df.toPandas() - df_sorted_pandas = df_sorted.toPandas() - assert df_sorted_pandas["id"].equals(df_pandas["id"].sort_values(ascending=False).reset_index(drop=True)), "Data should be sorted in descending order" + actual = df_sorted.collect() + expected = [(0, "0"), (0, "0"), (1, "1"), (1, "1")] + assert [(row.num, row.letter) for row in actual] == expected + + +def test_sort_mixed_order(spark_session): + # Create DataFrame with two columns using range + df = spark_session.range(4).select( + (col("id") % 2).alias("num"), + (col("id") % 2).cast("string").alias("letter") + ) + + # Sort with mixed ascending/descending order + df_sorted = df.sort(col("num").desc(), col("letter").asc()) + + # Verify the DataFrame is sorted correctly + actual = df_sorted.collect() + expected = [(1, "1"), (1, "1"), (0, "0"), (0, "0")] + assert [(row.num, row.letter) for row in actual] == expected