Skip to content

Commit

Permalink
fix(connect): count returns the number of rows in the table
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewgazelka committed Dec 19, 2024
1 parent c30f6a8 commit 041a1ee
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
11 changes: 10 additions & 1 deletion src/daft-connect/src/translation/expr/unresolved_function.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use daft_core::count_mode::CountMode;
use daft_schema::dtype::DataType;
use eyre::{bail, Context};
use spark_connect::expression::UnresolvedFunction;

Expand Down Expand Up @@ -79,7 +80,15 @@ pub fn handle_count(arguments: Vec<daft_dsl::ExprRef>) -> eyre::Result<daft_dsl:

let [arg] = arguments;

let count = arg.count(CountMode::All);
let count = if arg == daft_dsl::lit(1_i32) {
// Count(Literal(1)) is handled differently by Daft. Generally speaking though, what
// this means is we are counting the number of rows in the table.
daft_dsl::col("*")
.count(CountMode::All)
.cast(&DataType::Int64)
} else {
arg.count(CountMode::All).cast(&DataType::Int64)

Check warning on line 90 in src/daft-connect/src/translation/expr/unresolved_function.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-connect/src/translation/expr/unresolved_function.rs#L90

Added line #L90 was not covered by tests
};

Ok(count)
}
Expand Down
13 changes: 13 additions & 0 deletions tests/connect/test_count.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from __future__ import annotations


def test_count(spark_session):
# Create a range using Spark
# For example, creating a range from 0 to 9
spark_range = spark_session.range(10) # Creates DataFrame with numbers 0 to 9

# Convert to Pandas DataFrame
count = spark_range.count()

# Verify the DataFrame has expected values
assert count == 10, "DataFrame should have 10 rows"

0 comments on commit 041a1ee

Please sign in to comment.