Skip to content

Commit

Permalink
[FEAT]: sql HAVING (#3364)
Browse files Browse the repository at this point in the history
  • Loading branch information
universalmind303 authored Nov 20, 2024
1 parent bdfb8c6 commit b6695eb
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 4 deletions.
35 changes: 31 additions & 4 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,13 +347,20 @@ impl<'a> SQLPlanner<'a> {
let has_aggs = projections.iter().any(has_agg) || !groupby_exprs.is_empty();

if has_aggs {
let having = selection
.having
.as_ref()
.map(|h| self.plan_expr(h))
.transpose()?;

self.plan_aggregate_query(
&projections,
&schema,
has_orderby,
groupby_exprs,
query,
&projection_schema,
having,
)?;
} else {
self.plan_non_agg_query(projections, schema, has_orderby, query, projection_schema)?;
Expand Down Expand Up @@ -464,6 +471,7 @@ impl<'a> SQLPlanner<'a> {
Ok(())
}

#[allow(clippy::too_many_arguments)]
fn plan_aggregate_query(
&mut self,
projections: &Vec<Arc<Expr>>,
Expand All @@ -472,6 +480,7 @@ impl<'a> SQLPlanner<'a> {
groupby_exprs: Vec<Arc<Expr>>,
query: &Query,
projection_schema: &Schema,
having: Option<Arc<Expr>>,
) -> Result<(), PlannerError> {
let mut final_projection = Vec::with_capacity(projections.len());
let mut aggs = Vec::with_capacity(projections.len());
Expand Down Expand Up @@ -500,6 +509,15 @@ impl<'a> SQLPlanner<'a> {
final_projection.push(p.clone());
}
}

if let Some(having) = &having {
if has_agg(having) {
let having = having.alias(having.semantic_id(schema).id);

aggs.push(having);
}
}

let groupby_exprs = groupby_exprs
.into_iter()
.map(|e| {
Expand Down Expand Up @@ -631,7 +649,7 @@ impl<'a> SQLPlanner<'a> {
}

let rel = self.relation_mut();
rel.inner = rel.inner.aggregate(aggs, groupby_exprs)?;
rel.inner = rel.inner.aggregate(aggs.clone(), groupby_exprs)?;

let has_orderby_before_projection = !orderbys_before_projection.is_empty();
let has_orderby_after_projection = !orderbys_after_projection.is_empty();
Expand All @@ -650,6 +668,16 @@ impl<'a> SQLPlanner<'a> {
)?;
}

if let Some(having) = having {
// if it's an agg, it's already resolved during .agg, so we just reference the column name
let having = if has_agg(&having) {
col(having.semantic_id(schema).id)
} else {
having
};
rel.inner = rel.inner.filter(having)?;
}

// apply the final projection
rel.inner = rel.inner.select(final_projection)?;

Expand All @@ -661,6 +689,7 @@ impl<'a> SQLPlanner<'a> {
orderbys_after_projection_nulls_first,
)?;
}

Ok(())
}

Expand Down Expand Up @@ -1999,9 +2028,7 @@ fn check_select_features(selection: &sqlparser::ast::Select) -> SQLPlannerResult
if !selection.sort_by.is_empty() {
unsupported_sql_err!("SORT BY");
}
if selection.having.is_some() {
unsupported_sql_err!("HAVING");
}

if !selection.named_window.is_empty() {
unsupported_sql_err!("WINDOW");
}
Expand Down
63 changes: 63 additions & 0 deletions tests/sql/test_aggs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import pytest

import daft
from daft import col
from daft.sql import SQLCatalog


def test_aggs_sql():
Expand Down Expand Up @@ -41,3 +44,63 @@ def test_aggs_sql():
)

assert actual == expected


@pytest.mark.parametrize(
    "agg,cond,expected",
    [
        ("sum(values)", "sum(values) > 10", {"values": [20.5, 29.5]}),
        ("sum(values)", "values > 10", {"values": [20.5, 29.5]}),
        ("sum(values) as sum_v", "sum(values) > 10", {"sum_v": [20.5, 29.5]}),
        ("sum(values) as sum_v", "sum_v > 10", {"sum_v": [20.5, 29.5]}),
        ("count(*) as cnt", "cnt > 2", {"cnt": [3, 5]}),
        ("count(*) as cnt", "count(*) > 2", {"cnt": [3, 5]}),
        ("count(*)", "count(*) > 2", {"count": [3, 5]}),
        ("count(*) as cnt", "sum(values) > 10", {"cnt": [3, 5]}),
        ("sum(values), count(*)", "id > 1", {"values": [10.0, 29.5], "count": [2, 5]}),
    ],
)
def test_having(agg, cond, expected):
    """HAVING on a grouped query: the condition may reference an aggregate
    expression, an aggregate's alias, or a grouping column."""
    source = daft.from_pydict(
        {
            "id": [1, 2, 3, 3, 3, 3, 2, 1, 3, 1],
            "values": [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5],
        }
    )
    cat = SQLCatalog({"df": source})

    # NOTE: the trailing comma after the projection is accepted by daft's SQL dialect.
    result = daft.sql(
        f"""
    SELECT
        {agg},
    from df
    group by id
    having {cond}
    """,
        cat,
    ).to_pydict()

    assert result == expected


def test_having_non_grouped():
    """HAVING without GROUP BY: the whole table is one implicit group, so the
    aggregate condition filters the single result row."""
    source = daft.from_pydict(
        {
            "id": [1, 2, 3, 3, 3, 3, 2, 1, 3, 1],
            "values": [1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 10.5],
            "floats": [0.01, 0.011, 0.01047, 0.02, 0.019, 0.018, 0.017, 0.016, 0.015, 0.014],
        }
    )
    cat = SQLCatalog({"df": source})

    result = daft.sql(
        """
    SELECT
        count(*) ,
    from df
    having sum(values) > 40
    """,
        cat,
    ).to_pydict()

    # sum(values) == 60.0 > 40, so the single group survives the filter.
    assert result == {"count": [10]}

0 comments on commit b6695eb

Please sign in to comment.