diff --git a/src/daft-sql/src/modules/aggs.rs b/src/daft-sql/src/modules/aggs.rs index e9b524874c..f2b70d68bf 100644 --- a/src/daft-sql/src/modules/aggs.rs +++ b/src/daft-sql/src/modules/aggs.rs @@ -1,7 +1,7 @@ use std::sync::Arc; -use daft_dsl::{AggExpr, Expr, ExprRef, LiteralValue}; -use sqlparser::ast::FunctionArg; +use daft_dsl::{col, AggExpr, Expr, ExprRef, LiteralValue}; +use sqlparser::ast::{FunctionArg, FunctionArgExpr}; use crate::{ ensure, @@ -34,20 +34,50 @@ impl SQLModule for SQLModuleAggs { impl SQLFunction for AggExpr { fn to_expr(&self, inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult { - let inputs = self.args_to_expr_unnamed(inputs, planner)?; - to_expr(self, inputs.as_slice()) + // COUNT(*) needs a bit of extra handling, so we process that outside of `to_expr` + if let AggExpr::Count(_, _) = self { + handle_count(inputs, planner) + } else { + let inputs = self.args_to_expr_unnamed(inputs, planner)?; + to_expr(self, inputs.as_slice()) + } } } +fn handle_count(inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult { + Ok(match inputs { + [FunctionArg::Unnamed(FunctionArgExpr::Wildcard)] => match planner.relation_opt() { + Some(rel) => { + let schema = rel.schema(); + col(schema.fields[0].name.clone()) + .count(daft_core::count_mode::CountMode::All) + .alias("count") + } + None => unsupported_sql_err!("Wildcard is not supported in this context"), + }, + [FunctionArg::Unnamed(FunctionArgExpr::QualifiedWildcard(name))] => { + match planner.relation_opt() { + Some(rel) if name.to_string() == rel.name => { + let schema = rel.schema(); + col(schema.fields[0].name.clone()) + .count(daft_core::count_mode::CountMode::All) + .alias("count") + } + _ => unsupported_sql_err!("Wildcard is not supported in this context"), + } + } + [expr] => { + // SQL default COUNT ignores nulls + let input = planner.plan_function_arg(expr)?; + input.count(daft_core::count_mode::CountMode::Valid) + } + _ => unsupported_sql_err!("COUNT takes exactly one argument"), + }) +} + pub(crate) fn to_expr(expr: &AggExpr, args: &[ExprRef]) -> SQLPlannerResult { match expr { - AggExpr::Count(_, _) => { - // SQL default COUNT ignores nulls. - ensure!(args.len() == 1, "count takes exactly one argument"); - Ok(args[0] - .clone() - .count(daft_core::count_mode::CountMode::Valid)) - } + AggExpr::Count(_, _) => unreachable!("count should be handled by by this point"), AggExpr::Sum(_) => { ensure!(args.len() == 1, "sum takes exactly one argument"); Ok(args[0].clone().sum()) diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index b956b2c999..08abedf903 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -31,14 +31,17 @@ use sqlparser::{ /// This is used to keep track of the table name associated with a logical plan while planning a SQL query #[derive(Debug, Clone)] pub(crate) struct Relation { - inner: LogicalPlanBuilder, - name: String, + pub(crate) inner: LogicalPlanBuilder, + pub(crate) name: String, } impl Relation { pub fn new(inner: LogicalPlanBuilder, name: String) -> Self { Relation { inner, name } } + pub(crate) fn schema(&self) -> SchemaRef { + self.inner.schema() + } } pub struct SQLPlanner { @@ -70,6 +73,10 @@ impl SQLPlanner { self.current_relation.as_mut().expect("relation not set") } + pub(crate) fn relation_opt(&self) -> Option<&Relation> { + self.current_relation.as_ref() + } + pub fn plan_sql(&mut self, sql: &str) -> SQLPlannerResult { let tokens = Tokenizer::new(&GenericDialect {}, sql).tokenize()?; diff --git a/tests/sql/test_sql.py b/tests/sql/test_sql.py index 2792771e78..662e4d8b14 100644 --- a/tests/sql/test_sql.py +++ b/tests/sql/test_sql.py @@ -113,3 +113,21 @@ def test_sql_groupby_agg(): catalog = SQLCatalog({"test": df}) df = daft.sql("SELECT sum(v) FROM test GROUP BY n ORDER BY n", catalog=catalog) assert df.collect().to_pydict() == {"n": [1, 2], "v": [3, 7]} + + +def test_sql_count_star(): + df = daft.from_pydict( + { + "a": ["a", "b", None, "c"], + "b": [4, 3, 2, None], + } + ) + catalog = SQLCatalog({"df": df}) + df2 = daft.sql("SELECT count(*) FROM df", catalog) + actual = df2.collect().to_pydict() + expected = df.count().collect().to_pydict() + assert actual == expected + df2 = daft.sql("SELECT count(b) FROM df", catalog) + actual = df2.collect().to_pydict() + expected = df.agg(daft.col("b").count()).collect().to_pydict() + assert actual == expected