diff --git a/src/daft-connect/src/translation/logical_plan/sort.rs b/src/daft-connect/src/translation/logical_plan/sort.rs
index 15baccbcd8..695c4d5756 100644
--- a/src/daft-connect/src/translation/logical_plan/sort.rs
+++ b/src/daft-connect/src/translation/logical_plan/sort.rs
@@ -50,7 +50,8 @@ pub fn sort(sort: spark_connect::Sort) -> eyre::Result {
         // todo(correctness): is this correct?
         let is_descending = match direction {
             SortDirection::Unspecified => {
-                bail!("Unspecified sort direction is not yet supported")
+                // default to ascending order
+                false
             }
             SortDirection::Ascending => false,
             SortDirection::Descending => true,
@@ -59,7 +60,8 @@ pub fn sort(sort: spark_connect::Sort) -> eyre::Result {
         // todo(correctness): is this correct?
         let tentative_sort_nulls_first = match null_ordering {
             NullOrdering::SortNullsUnspecified => {
-                bail!("Unspecified null ordering is not yet supported")
+                // default: match is_descending
+                is_descending
             }
             NullOrdering::SortNullsFirst => true,
             NullOrdering::SortNullsLast => false,
diff --git a/tests/connect/test_sort.py b/tests/connect/test_sort.py
index 653510db18..43ed0a85cd 100644
--- a/tests/connect/test_sort.py
+++ b/tests/connect/test_sort.py
@@ -3,7 +3,7 @@
 from pyspark.sql.functions import col
 
 
-def test_sort(spark_session):
+def test_sort_desc(spark_session):
     # Create DataFrame from range(10)
     df = spark_session.range(10)
 
@@ -11,6 +11,35 @@ def test_sort(spark_session):
     df_sorted = df.sort(col("id").desc())
 
     # Verify the DataFrame is sorted correctly
-    df_pandas = df.toPandas()
-    df_sorted_pandas = df_sorted.toPandas()
-    assert df_sorted_pandas["id"].equals(df_pandas["id"].sort_values(ascending=False).reset_index(drop=True)), "Data should be sorted in descending order"
+    actual = df_sorted.collect()
+
+    expected = list(range(9, -1, -1))
+    assert [row.id for row in actual] == expected
+
+
+def test_sort_asc(spark_session):
+    # Create DataFrame from range(10)
+    df = spark_session.range(10)
+
+    # Sort the DataFrame by 'id' column in ascending order
+    df_sorted = df.sort(col("id").asc())
+
+    # Verify the DataFrame is sorted correctly
+    actual = df_sorted.collect()
+
+    expected = list(range(10))
+    assert [row.id for row in actual] == expected
+
+
+def test_sort_default(spark_session):
+    # Create DataFrame from range(10)
+    df = spark_session.range(10)
+
+    # Default sort should be ascending
+    df_sorted = df.sort("id")
+
+    # Verify the DataFrame is sorted correctly
+    actual = df_sorted.collect()
+
+    expected = list(range(10))
+    assert [row.id for row in actual] == expected