From 8ed174c67559b31ceda5b2d1ed76a5069783482b Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Tue, 5 Nov 2024 18:16:20 -0600 Subject: [PATCH] [BUG]: Sql groupby and orderby with aliases and projections (#3177) Co-authored-by: Kev Wang --- src/daft-sql/src/error.rs | 1 - src/daft-sql/src/lib.rs | 26 +++ src/daft-sql/src/planner.rs | 372 ++++++++++++++++++++++++++++-------- tests/sql/test_sql.py | 38 +++- 4 files changed, 360 insertions(+), 77 deletions(-) diff --git a/src/daft-sql/src/error.rs b/src/daft-sql/src/error.rs index 7cfa8428aa..7f033d1b08 100644 --- a/src/daft-sql/src/error.rs +++ b/src/daft-sql/src/error.rs @@ -1,7 +1,6 @@ use common_error::DaftError; use snafu::Snafu; use sqlparser::{parser::ParserError, tokenizer::TokenizerError}; - #[derive(Debug, Snafu)] pub enum PlannerError { #[snafu(display("Tokenization error: {source}"))] diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index 954bacd997..af97b738c4 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -336,9 +336,35 @@ mod tests { let expected = LogicalPlanBuilder::new(tbl_1, None) .aggregate(vec![col("i32").max()], vec![])? + .select(vec![col("i32")])? .build(); assert_eq!(plan, expected); Ok(()) } + + #[rstest] + #[case::basic("select utf8 from tbl1 order by utf8")] + #[case::asc("select utf8 from tbl1 order by utf8 asc")] + #[case::desc("select utf8 from tbl1 order by utf8 desc")] + #[case::with_alias("select utf8 as a from tbl1 order by a")] + #[case::with_alias_in_projection_only("select utf8 as a from tbl1 order by utf8")] + #[case::with_groupby("select utf8, sum(i32) from tbl1 group by utf8 order by utf8")] + #[case::with_groupby_and_alias( + "select utf8 as a, sum(i32) from tbl1 group by utf8 order by utf8" + )] + #[case::with_groupby_and_alias_mixed("select utf8 as a from tbl1 group by a order by utf8")] + #[case::with_groupby_and_alias_mixed_2("select utf8 as a from tbl1 group by utf8 order by a")] + #[case::with_groupby_and_alias_mixed_asc( + "select utf8 as a from tbl1 group by utf8 order by a asc" + )] + fn test_compiles_orderby(mut planner: SQLPlanner, #[case] query: &str) -> SQLPlannerResult<()> { + let plan = planner.plan_sql(query); + if let Err(e) = plan { + panic!("query: {query}\nerror: {e:?}"); + } + assert!(plan.is_ok(), "query: {query}\nerror: {plan:?}"); + + Ok(()) + } } diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 76e30d5912..aeb63735ee 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -1,6 +1,9 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; -use common_error::DaftResult; +use common_error::{DaftError, DaftResult}; use daft_core::prelude::*; use daft_dsl::{ col, @@ -22,10 +25,8 @@ use sqlparser::{ }; use crate::{ - catalog::SQLCatalog, - column_not_found_err, - error::{PlannerError, SQLPlannerResult}, - invalid_operation_err, table_not_found_err, unsupported_sql_err, + catalog::SQLCatalog, column_not_found_err, error::*, invalid_operation_err, + table_not_found_err, unsupported_sql_err, }; /// A named logical plan @@ -221,6 +222,7 @@ impl SQLPlanner { // FROM/JOIN let from = selection.clone().from; let rel = self.plan_from(&from)?; + let schema = rel.schema(); self.current_relation = Some(rel); // WHERE @@ -246,32 +248,36 @@ impl SQLPlanner { .collect::>>()?; } } + let mut projections = Vec::with_capacity(selection.projection.len()); + let mut projection_fields = Vec::with_capacity(selection.projection.len()); + for expr in &selection.projection { + let exprs = self.select_item_to_expr(expr, &schema)?; - // split the selection into the groupby expressions and the rest - let (groupby_selection, to_select) = selection - .projection - .iter() - .map(|expr| self.select_item_to_expr(expr)) - .collect::>>()? - .into_iter() - .flatten() - .partition::, _>(|expr| { - groupby_exprs - .iter() - .any(|e| expr.input_mapping() == e.input_mapping()) - }); + let fields = exprs + .iter() + .map(|expr| expr.to_field(&schema).map_err(PlannerError::from)) + .collect::>>()?; - if !groupby_exprs.is_empty() { - let rel = self.relation_mut(); - rel.inner = rel.inner.aggregate(to_select, groupby_exprs.clone())?; - } else if !to_select.is_empty() { - let rel = self.relation_mut(); - let has_aggs = to_select.iter().any(has_agg); - if has_aggs { - rel.inner = rel.inner.aggregate(to_select, vec![])?; - } else { - rel.inner = rel.inner.select(to_select)?; - } + projections.extend(exprs); + + projection_fields.extend(fields); + } + + let projection_schema = Schema::new(projection_fields)?; + let has_orderby = query.order_by.is_some(); + let has_aggs = projections.iter().any(has_agg); + + if has_aggs { + self.plan_aggregate_query( + &projections, + &schema, + has_orderby, + groupby_exprs, + query, + &projection_schema, + )?; + } else { + self.plan_non_agg_query(projections, schema, has_orderby, query, projection_schema)?; } match &selection.distinct { @@ -283,71 +289,239 @@ impl SQLPlanner { None => {} } - if let Some(order_by) = &query.order_by { + if let Some(limit) = &query.limit { + let limit = self.plan_expr(limit)?; + if let Expr::Literal(LiteralValue::Int64(limit)) = limit.as_ref() { + let rel = self.relation_mut(); + rel.inner = rel.inner.limit(*limit, true)?; // TODO: Should this be eager or not? + } else { + invalid_operation_err!( + "LIMIT must be a constant integer, instead got: {limit}" + ); + } + } + + Ok(self.current_relation.clone().unwrap().inner) + } + + fn plan_non_agg_query( + &mut self, + projections: Vec>, + schema: Arc, + has_orderby: bool, + query: &Query, + projection_schema: Schema, + ) -> Result<(), PlannerError> { + // Final/selected cols + // if there is an orderby, and it references a column that is not part of the final projection (such as an alias) + // then we need to keep the original column in the projection, and remove it at the end + // ex: `SELECT a as b, c FROM t ORDER BY a` + // we need to keep a, and c in the first projection + // but only c in the final projection + // We only will apply 2 projections if there is an order by and the order by references a column that is not in the final projection + let mut final_projection = Vec::with_capacity(projections.len()); + let mut orderby_projection = Vec::with_capacity(projections.len()); + for p in &projections { + let fld = p.to_field(&schema); + + let fld = fld?; + let name = fld.name.clone(); + + // if there is an orderby, then the final projection will only contain the columns that are in the orderby + final_projection.push(if has_orderby { + col(name.as_ref()) + } else { + // otherwise we just do a normal projection + p.clone() + }); + } + + if has_orderby { + let order_by = query.order_by.as_ref().unwrap(); + if order_by.interpolate.is_some() { unsupported_sql_err!("ORDER BY [query] [INTERPOLATE]"); + }; + + let (orderby_exprs, orderby_desc) = + self.plan_order_by_exprs(order_by.exprs.as_slice())?; + + for expr in &orderby_exprs { + if let Err(DaftError::FieldNotFound(_)) = expr.to_field(&projection_schema) { + // this is likely an alias + orderby_projection.push(expr.clone()); + } } - // TODO: if ordering by a column not in the projection, this will fail. - let (exprs, descending) = self.plan_order_by_exprs(order_by.exprs.as_slice())?; + // if the orderby references a column that is not in the final projection + // then we need an additional projection + let needs_projection = !orderby_projection.is_empty(); + let rel = self.relation_mut(); - rel.inner = rel.inner.sort(exprs, descending)?; - } + if needs_projection { + let pre_orderby_projections = projections + .iter() + .cloned() + .chain(orderby_projection) + .collect::>() // dedup + .into_iter() + .collect::>(); + rel.inner = rel.inner.select(pre_orderby_projections)?; + } else { + rel.inner = rel.inner.select(projections)?; + } + + rel.inner = rel.inner.sort(orderby_exprs, orderby_desc)?; - // Properly apply or remove the groupby columns from the selection - // This needs to be done after the orderby - // otherwise, the orderby will not be able to reference the grouping columns - // - // ex: SELECT sum(a) as sum_a, max(a) as max_a, b as c FROM table GROUP BY b - // - // The groupby columns are [b] - // the evaluation of sum(a) and max(a) are already handled by the earlier aggregate, - // so our projection is [sum_a, max_a, (b as c)] - // leaving us to handle (b as c) - // - // we filter for the columns in the schema that are not in the groupby keys, - // [sum_a, max_a, b] -> [sum_a, max_a] - // - // Then we add the groupby columns back in with the correct expressions - // this gives us the final projection: [sum_a, max_a, (b as c)] - if !groupby_exprs.is_empty() { + if needs_projection { + rel.inner = rel.inner.select(final_projection)?; + } + } else { let rel = self.relation_mut(); - let schema = rel.inner.schema(); + rel.inner = rel.inner.select(projections)?; + } - let groupby_keys = groupby_exprs - .iter() - .map(|e| Ok(e.to_field(&schema)?.name)) - .collect::>>()?; + Ok(()) + } - let selection_colums = schema - .exclude(groupby_keys.as_ref())? + fn plan_aggregate_query( + &mut self, + projections: &Vec>, + schema: &Arc, + has_orderby: bool, + groupby_exprs: Vec>, + query: &Query, + projection_schema: &Schema, + ) -> Result<(), PlannerError> { + let mut final_projection = Vec::with_capacity(projections.len()); + let mut orderby_projection = Vec::with_capacity(projections.len()); + let mut aggs = Vec::with_capacity(projections.len()); + let mut orderby_exprs = None; + let mut orderby_desc = None; + for p in projections { + let fld = p.to_field(schema)?; + + let name = fld.name.clone(); + if has_agg(p) { + // this is an aggregate, so it is resolved during `.agg`. So we just push the column name + final_projection.push(col(name.as_ref())); + // add it to the aggs list + aggs.push(p.clone()); + } else { + // otherwise we just do a normal projection + final_projection.push(p.clone()); + } + } + let groupby_exprs = groupby_exprs + .into_iter() + .map(|e| { + // instead of trying to do an additional projection for the groupby column, we just map it back to the original (unaliased) column + // ex: SELECT a as b FROM t GROUP BY b + // in this case, we need to resolve b to a + if let Err(DaftError::FieldNotFound(_)) = e.to_field(schema) { + // this is likely an alias + unresolve_alias(e, &final_projection) + } else { + Ok(e) + } + }) + .collect::>>()?; + + if has_orderby { + let order_by = query.order_by.as_ref().unwrap(); + + if order_by.interpolate.is_some() { + unsupported_sql_err!("ORDER BY [query] [INTERPOLATE]"); + }; + + let (exprs, desc) = self.plan_order_by_exprs(order_by.exprs.as_slice())?; + + orderby_exprs = Some(exprs.clone()); + orderby_desc = Some(desc); + + for expr in &exprs { + // if the orderby references a column that is not in the final projection + // then we need an additional projection + if let Err(DaftError::FieldNotFound(_)) = expr.to_field(projection_schema) { + orderby_projection.push(expr.clone()); + } + } + } + + let rel = self.relation_mut(); + rel.inner = rel.inner.aggregate(aggs, groupby_exprs)?; + + let needs_projection = !orderby_projection.is_empty(); + if needs_projection { + let orderby_projection = rel + .schema() .names() .iter() .map(|n| col(n.as_str())) - .chain(groupby_selection) - .collect(); + .chain(orderby_projection) + .collect::>() // dedup + .into_iter() + .collect::>(); - rel.inner = rel.inner.select(selection_colums)?; + rel.inner = rel.inner.select(orderby_projection)?; } - if let Some(limit) = &query.limit { - let limit = self.plan_expr(limit)?; - if let Expr::Literal(LiteralValue::Int64(limit)) = limit.as_ref() { - let rel = self.relation_mut(); - rel.inner = rel.inner.limit(*limit, true)?; // TODO: Should this be eager or not? - } else { - invalid_operation_err!( - "LIMIT must be a constant integer, instead got: {limit}" - ); + // these are orderbys that are part of the final projection + let mut orderbys_after_projection = Vec::new(); + let mut orderbys_after_projection_desc = Vec::new(); + + // these are orderbys that are not part of the final projection + let mut orderbys_before_projection = Vec::new(); + let mut orderbys_before_projection_desc = Vec::new(); + + if let Some(orderby_exprs) = orderby_exprs { + // this needs to be done after the aggregation and any projections + // because the orderby may reference an alias, or an intermediate column that is not in the final projection + let schema = rel.schema(); + for (i, expr) in orderby_exprs.iter().enumerate() { + if let Err(DaftError::FieldNotFound(_)) = expr.to_field(&schema) { + orderbys_after_projection.push(expr.clone()); + let desc = orderby_desc.clone().map(|o| o[i]).unwrap(); + orderbys_after_projection_desc.push(desc); + } else { + let desc = orderby_desc.clone().map(|o| o[i]).unwrap(); + + orderbys_before_projection.push(expr.clone()); + orderbys_before_projection_desc.push(desc); + } } } - Ok(self.current_relation.clone().unwrap().inner) + let has_orderby_before_projection = !orderbys_before_projection.is_empty(); + let has_orderby_after_projection = !orderbys_after_projection.is_empty(); + + // PERF(cory): if there are order bys from both parts, can we combine them into a single sort instead of two? + // or can we optimize them into a single sort? + + // order bys that are not in the final projection + if has_orderby_before_projection { + rel.inner = rel + .inner + .sort(orderbys_before_projection, orderbys_before_projection_desc)?; + } + + rel.inner = rel.inner.select(final_projection)?; + + // order bys that are in the final projection + if has_orderby_after_projection { + rel.inner = rel + .inner + .sort(orderbys_after_projection, orderbys_after_projection_desc)?; + } + Ok(()) } fn plan_order_by_exprs( &self, expr: &[sqlparser::ast::OrderByExpr], ) -> SQLPlannerResult<(Vec, Vec)> { + if expr.is_empty() { + unsupported_sql_err!("ORDER BY []"); + } let mut exprs = Vec::with_capacity(expr.len()); let mut desc = Vec::with_capacity(expr.len()); for order_by_expr in expr { @@ -677,7 +851,11 @@ impl SQLPlanner { } } - fn select_item_to_expr(&self, item: &SelectItem) -> SQLPlannerResult> { + fn select_item_to_expr( + &self, + item: &SelectItem, + schema: &Schema, + ) -> SQLPlannerResult> { fn wildcard_exclude( schema: SchemaRef, exclusion: &ExcludeSelectItem, @@ -725,8 +903,14 @@ impl SQLPlanner { .collect::>() }) .map_err(std::convert::Into::into) - } else { + } else if schema.is_empty() { Ok(vec![col("*")]) + } else { + Ok(schema + .names() + .iter() + .map(|n| col(n.as_ref())) + .collect::>()) } } SelectItem::QualifiedWildcard(object_name, wildcard_opts) => { @@ -1594,7 +1778,8 @@ pub fn sql_expr>(s: S) -> SQLPlannerResult { .with_tokens(tokens); let expr = parser.parse_select_item()?; - let exprs = planner.select_item_to_expr(&expr)?; + let empty_schema = Schema::empty(); + let exprs = planner.select_item_to_expr(&expr, &empty_schema)?; if exprs.len() != 1 { invalid_operation_err!("expected a single expression, found {}", exprs.len()) } @@ -1615,3 +1800,40 @@ fn idents_to_str(idents: &[Ident]) -> String { .collect::>() .join(".") } + +/// unresolves an alias in a projection +/// Example: +/// ```sql +/// SELECT a as b, c FROM t group by b +/// ``` +/// in this case if you tried to unresolve the expr `b` using the projections [`a as b`, `c`] you would get `a` +/// +/// Since sql allows you to use the alias in the group by or the order by clause, we need to unresolve the alias to the original expression +/// ex: +/// All of the following are valid sql queries +/// `select a as b, c from t group by b` +/// `select a as b, c from t group by a` +/// `select a as b, c from t group by a order by a` +/// `select a as b, c from t group by a order by b` +/// `select a as b, c from t group by b order by a` +/// `select a as b, c from t group by b order by b` +/// +/// In all of the above cases, the group by and order by clauses are resolved to the original expression `a` +/// +/// This is needed for resolving group by and order by clauses +fn unresolve_alias(expr: ExprRef, projection: &[ExprRef]) -> SQLPlannerResult { + projection + .iter() + .find_map(|p| { + if let Expr::Alias(e, alias) = &p.as_ref() { + if expr.name() == alias.as_ref() { + Some(e.clone()) + } else { + None + } + } else { + None + } + }) + .ok_or_else(|| PlannerError::column_not_found(expr.name(), "projection")) +} diff --git a/tests/sql/test_sql.py b/tests/sql/test_sql.py index c550a1f5a4..2973580d05 100644 --- a/tests/sql/test_sql.py +++ b/tests/sql/test_sql.py @@ -106,7 +106,7 @@ def test_sql_global_agg(): df = daft.sql("SELECT max(n) max_n, sum(n) sum_n FROM test", catalog=catalog) assert df.collect().to_pydict() == {"max_n": [3], "sum_n": [6]} # If there is agg and non-agg, it should fail - with pytest.raises(Exception, match="Expected aggregation"): + with pytest.raises(Exception, match="Column not found"): daft.sql("SELECT n,max(n) max_n FROM test", catalog=catalog) @@ -224,6 +224,42 @@ def test_sql_distinct(): assert actual == expected +@pytest.mark.parametrize( + "query", + [ + "select utf8 from tbl1 order by utf8", + "select utf8 from tbl1 order by utf8 asc", + "select utf8 from tbl1 order by utf8 desc", + "select utf8 as a from tbl1 order by a", + "select utf8 as a from tbl1 order by utf8", + "select utf8 as a from tbl1 order by utf8 asc", + "select utf8 as a from tbl1 order by utf8 desc", + "select utf8 from tbl1 group by utf8 order by utf8", + "select utf8 as a from tbl1 group by utf8 order by utf8", + "select utf8 as a from tbl1 group by a order by utf8", + "select utf8 as a from tbl1 group by a order by a", + "select sum(i32), utf8 as a from tbl1 group by utf8 order by a", + "select sum(i32) as s, utf8 as a from tbl1 group by utf8 order by s", + ], +) +def test_compiles(query): + tbl1 = daft.from_pydict( + { + "utf8": ["group1", "group1", "group2", "group2"], + "i32": [1, 2, 3, 3], + } + ) + catalog = SQLCatalog({"tbl1": tbl1}) + try: + res = daft.sql(query, catalog=catalog) + data = res.collect().to_pydict() + assert data + + except Exception as e: + print(f"Error: {e}") + raise + + def test_sql_cte(): df = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6], "c": ["a", "b", "c"]}) actual = (