Skip to content

Commit

Permalink
Update test_sort.py
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Nov 25, 2024
1 parent e6833f7 commit 8e1b27a
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 6 deletions.
6 changes: 4 additions & 2 deletions src/daft-connect/src/translation/logical_plan/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ pub fn sort(sort: spark_connect::Sort) -> eyre::Result<LogicalPlanBuilder> {
// todo(correctness): is this correct?
let is_descending = match direction {
SortDirection::Unspecified => {
bail!("Unspecified sort direction is not yet supported")
// default to ascending order
false

Check warning on line 54 in src/daft-connect/src/translation/logical_plan/sort.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/logical_plan/sort.rs#L54

Added line #L54 was not covered by tests
}
SortDirection::Ascending => false,
SortDirection::Descending => true,
Expand All @@ -59,7 +60,8 @@ pub fn sort(sort: spark_connect::Sort) -> eyre::Result<LogicalPlanBuilder> {
// todo(correctness): is this correct?
let tentative_sort_nulls_first = match null_ordering {
NullOrdering::SortNullsUnspecified => {
bail!("Unspecified null ordering is not yet supported")
// default: match is_descending
is_descending

Check warning on line 64 in src/daft-connect/src/translation/logical_plan/sort.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/logical_plan/sort.rs#L64

Added line #L64 was not covered by tests
}
NullOrdering::SortNullsFirst => true,
NullOrdering::SortNullsLast => false,
Expand Down
37 changes: 33 additions & 4 deletions tests/connect/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,43 @@
from pyspark.sql.functions import col


def test_sort(spark_session):
def test_sort_desc(spark_session):
    """Sorting by 'id' descending returns rows from the highest id to the lowest."""
    # Create DataFrame from range(10) — ids 0 through 9.
    df = spark_session.range(10)

    # Sort the DataFrame by 'id' column in descending order.
    df_sorted = df.sort(col("id").desc())

    # Collect the sorted rows and compare against the exact expected sequence.
    # (Exact-sequence comparison replaces the older pandas round-trip check,
    # which duplicated this verification via toPandas()/sort_values.)
    actual = df_sorted.collect()

    expected = list(range(9, -1, -1))
    assert [row.id for row in actual] == expected


def test_sort_asc(spark_session):
    """Sorting by 'id' ascending returns rows from the lowest id to the highest."""
    # Build a 10-row DataFrame with ids 0 through 9.
    frame = spark_session.range(10)

    # Explicitly request ascending order on the 'id' column.
    ordered = frame.sort(col("id").asc())

    # Collected ids should match 0, 1, ..., 9 exactly and in order.
    ids = [row.id for row in ordered.collect()]
    assert ids == list(range(10))


def test_sort_default(spark_session):
    """Sorting with no direction specified defaults to ascending order."""
    # Build a 10-row DataFrame with ids 0 through 9.
    frame = spark_session.range(10)

    # Sort by column name only — no .asc()/.desc() — to exercise the
    # unspecified-direction path, which should behave as ascending.
    ordered = frame.sort("id")

    # Collected ids should match 0, 1, ..., 9 exactly and in order.
    ids = [row.id for row in ordered.collect()]
    assert ids == list(range(10))

0 comments on commit 8e1b27a

Please sign in to comment.