diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index a66fc26541..954bacd997 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -154,9 +154,10 @@ mod tests { #[case::orderby_multi("select * from tbl1 order by i32 desc, f32 asc")] #[case::whenthen("select case when i32 = 1 then 'a' else 'b' end from tbl1")] #[case::globalagg("select max(i32) from tbl1")] + #[case::cte("with cte as (select * from tbl1) select * from cte")] fn test_compiles(mut planner: SQLPlanner, #[case] query: &str) -> SQLPlannerResult<()> { let plan = planner.plan_sql(query); - assert!(plan.is_ok(), "query: {query}\nerror: {plan:?}"); + assert!(&plan.is_ok(), "query: {query}\nerror: {plan:?}"); Ok(()) } diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index d81796a4a6..b27a0060ce 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -12,9 +12,9 @@ use daft_plan::{LogicalPlanBuilder, LogicalPlanRef}; use sqlparser::{ ast::{ ArrayElemTypeDef, BinaryOperator, CastKind, DateTimeField, Distinct, ExactNumberInfo, - ExcludeSelectItem, GroupByExpr, Ident, Query, SelectItem, Statement, StructField, + ExcludeSelectItem, GroupByExpr, Ident, Query, SelectItem, SetExpr, Statement, StructField, Subscript, TableAlias, TableWithJoins, TimezoneInfo, UnaryOperator, Value, - WildcardAdditionalOptions, + WildcardAdditionalOptions, With, }, dialect::GenericDialect, parser::{Parser, ParserOptions}, @@ -66,14 +66,16 @@ pub struct SQLPlanner { catalog: SQLCatalog, current_relation: Option, table_map: HashMap, + cte_map: HashMap, } impl Default for SQLPlanner { fn default() -> Self { Self { catalog: SQLCatalog::new(), - current_relation: None, - table_map: HashMap::new(), + current_relation: Default::default(), + table_map: Default::default(), + cte_map: Default::default(), } } } @@ -82,8 +84,7 @@ impl SQLPlanner { pub fn new(context: SQLCatalog) -> Self { Self { catalog: context, - current_relation: None, - table_map: HashMap::new(), + ..Default::default() } } @@ -102,6 +103,69 @@ impl SQLPlanner { fn clear_context(&mut self) { self.current_relation = None; self.table_map.clear(); + self.cte_map.clear(); + } + + fn get_table_from_current_scope(&self, name: &str) -> Option { + let table = self.table_map.get(name).cloned(); + table + .or_else(|| self.cte_map.get(name).cloned()) + .or_else(|| { + self.catalog + .get_table(name) + .map(|table| Relation::new(table.into(), name.to_string())) + }) + } + + fn register_cte( + &mut self, + mut rel: Relation, + column_aliases: &[Ident], + ) -> SQLPlannerResult<()> { + if !column_aliases.is_empty() { + let schema = rel.schema(); + let columns = schema.names(); + if columns.len() != column_aliases.len() { + invalid_operation_err!( + "Column count mismatch: expected {} columns, found {}", + column_aliases.len(), + columns.len() + ); + } + + let projection = columns + .into_iter() + .zip(column_aliases) + .map(|(name, alias)| col(name).alias(ident_to_str(alias))) + .collect::>(); + + rel.inner = rel.inner.select(projection)?; + } + self.cte_map.insert(rel.get_name(), rel); + Ok(()) + } + + fn plan_ctes(&mut self, with: &With) -> SQLPlannerResult<()> { + if with.recursive { + unsupported_sql_err!("Recursive CTEs are not supported"); + } + + for cte in &with.cte_tables { + if cte.materialized.is_some() { + unsupported_sql_err!("MATERIALIZED is not supported"); + } + + if cte.from.is_some() { + invalid_operation_err!("FROM should only exist in recursive CTEs"); + } + + let name = ident_to_str(&cte.alias.name); + let plan = self.plan_query(&cte.query)?; + let rel = Relation::new(plan, name); + + self.register_cte(rel, cte.alias.columns.as_slice())?; + } + Ok(()) } pub fn plan_sql(&mut self, sql: &str) -> SQLPlannerResult { @@ -136,15 +200,24 @@ impl SQLPlanner { fn plan_query(&mut self, query: &Query) -> SQLPlannerResult { check_query_features(query)?; - let selection = query.body.as_select().ok_or_else(|| { - PlannerError::invalid_operation(format!( - "Only SELECT queries are supported, got: '{}'", - query.body - )) - })?; + let selection = match query.body.as_ref() { + SetExpr::Select(selection) => selection, + SetExpr::Query(_) => unsupported_sql_err!("Subqueries are not supported"), + SetExpr::SetOperation { .. } => { + unsupported_sql_err!("Set operations are not supported") + } + SetExpr::Values(..) => unsupported_sql_err!("VALUES are not supported"), + SetExpr::Insert(..) => unsupported_sql_err!("INSERT is not supported"), + SetExpr::Update(..) => unsupported_sql_err!("UPDATE is not supported"), + SetExpr::Table(..) => unsupported_sql_err!("TABLE is not supported"), + }; check_select_features(selection)?; + if let Some(with) = &query.with { + self.plan_ctes(with)?; + } + // FROM/JOIN let from = selection.clone().from; let rel = self.plan_from(&from)?; @@ -480,7 +553,7 @@ impl SQLPlanner { Ok(left_rel) } - fn plan_relation(&self, rel: &sqlparser::ast::TableFactor) -> SQLPlannerResult { + fn plan_relation(&mut self, rel: &sqlparser::ast::TableFactor) -> SQLPlannerResult { let (rel, alias) = match rel { sqlparser::ast::TableFactor::Table { name, @@ -498,12 +571,48 @@ impl SQLPlanner { .. } => { let table_name = name.to_string(); - let plan = self - .catalog - .get_table(&table_name) - .ok_or_else(|| PlannerError::table_not_found(table_name.clone()))?; - let plan_builder = LogicalPlanBuilder::new(plan, None); - (Relation::new(plan_builder, table_name), alias.clone()) + let Some(rel) = self.get_table_from_current_scope(&table_name) else { + table_not_found_err!(table_name) + }; + (rel, alias.clone()) + } + sqlparser::ast::TableFactor::Derived { + lateral, + subquery, + alias: Some(alias), + } => { + if *lateral { + unsupported_sql_err!("LATERAL"); + } + let subquery = self.plan_query(subquery)?; + let rel_name = ident_to_str(&alias.name); + let rel = Relation::new(subquery, rel_name); + + (rel, Some(alias.clone())) + } + sqlparser::ast::TableFactor::TableFunction { .. } => { + unsupported_sql_err!("Unsupported table factor: TableFunction") + } + sqlparser::ast::TableFactor::Function { .. } => { + unsupported_sql_err!("Unsupported table factor: Function") + } + sqlparser::ast::TableFactor::UNNEST { .. } => { + unsupported_sql_err!("Unsupported table factor: UNNEST") + } + sqlparser::ast::TableFactor::JsonTable { .. } => { + unsupported_sql_err!("Unsupported table factor: JsonTable") + } + sqlparser::ast::TableFactor::NestedJoin { .. } => { + unsupported_sql_err!("Unsupported table factor: NestedJoin") + } + sqlparser::ast::TableFactor::Pivot { .. } => { + unsupported_sql_err!("Unsupported table factor: Pivot") + } + sqlparser::ast::TableFactor::Unpivot { .. } => { + unsupported_sql_err!("Unsupported table factor: Unpivot") + } + sqlparser::ast::TableFactor::MatchRecognize { .. } => { + unsupported_sql_err!("Unsupported table factor: MatchRecognize") } _ => unsupported_sql_err!("Unsupported table factor"), }; @@ -520,8 +629,7 @@ impl SQLPlanner { let root = idents.next().unwrap(); let root = ident_to_str(root); - - let current_relation = match self.table_map.get(&root) { + let current_relation = match self.get_table_from_current_scope(&root) { Some(rel) => rel, None => { return Err(PlannerError::TableNotFound { @@ -626,7 +734,7 @@ impl SQLPlanner { let Some(rel) = self.relation_opt() else { table_not_found_err!(table_name); }; - let Some(table_rel) = self.table_map.get(&table_name) else { + let Some(table_rel) = self.get_table_from_current_scope(&table_name) else { table_not_found_err!(table_name); }; let right_schema = table_rel.inner.schema(); @@ -673,7 +781,7 @@ impl SQLPlanner { Value::Null => LiteralValue::Null, _ => { return Err(PlannerError::invalid_operation( - "Only string, number, boolean and null literals are supported", + "Only string, number, boolean and null literals are supported. Instead found: `{value}`", )) } }) @@ -683,7 +791,7 @@ impl SQLPlanner { if let sqlparser::ast::Expr::Value(v) = expr { self.value_to_lit(v) } else { - invalid_operation_err!("Only string, number, boolean and null literals are supported"); + invalid_operation_err!("Only string, number, boolean and null literals are supported. Instead found: `{expr}`"); } } pub(crate) fn plan_expr(&self, expr: &sqlparser::ast::Expr) -> SQLPlannerResult { @@ -1373,9 +1481,6 @@ impl SQLPlanner { /// /// This function examines various clauses and options in the provided [sqlparser::ast::Query] /// and returns an error if any unsupported features are encountered. fn check_query_features(query: &sqlparser::ast::Query) -> SQLPlannerResult<()> { - if let Some(with) = &query.with { - unsupported_sql_err!("WITH: {with}") - } if !query.limit_by.is_empty() { unsupported_sql_err!("LIMIT BY"); } diff --git a/tests/sql/test_sql.py b/tests/sql/test_sql.py index 6bcd716854..c550a1f5a4 100644 --- a/tests/sql/test_sql.py +++ b/tests/sql/test_sql.py @@ -4,6 +4,7 @@ import pytest import daft +from daft import col from daft.exceptions import DaftCoreException from daft.sql.sql import SQLCatalog from tests.assets import TPCH_QUERIES @@ -221,3 +222,63 @@ def test_sql_distinct(): actual = daft.sql("SELECT DISTINCT n FROM df").collect().to_pydict() expected = df.distinct().collect().to_pydict() assert actual == expected + + +def test_sql_cte(): + df = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6], "c": ["a", "b", "c"]}) + actual = ( + daft.sql(""" + WITH cte1 AS (select * FROM df) + SELECT * FROM cte1 + """) + .collect() + .to_pydict() + ) + + expected = df.collect().to_pydict() + + assert actual == expected + + +def test_sql_cte_column_aliases(): + df = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6], "c": ["a", "b", "c"]}) + actual = ( + daft.sql(""" + WITH cte1 (cte_a, cte_b, cte_c) AS (select * FROM df) + SELECT * FROM cte1 + """) + .collect() + .to_pydict() + ) + + expected = ( + df.select( + col("a").alias("cte_a"), + col("b").alias("cte_b"), + col("c").alias("cte_c"), + ) + .collect() + .to_pydict() + ) + + assert actual == expected + + +def test_sql_multiple_ctes(): + df1 = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6], "c": ["a", "b", "c"]}) + df2 = daft.from_pydict({"x": [1, 0, 3], "y": [True, None, False], "z": [1.0, 2.0, 3.0]}) + actual = ( + daft.sql(""" + WITH + cte1 AS (select * FROM df1), + cte2 AS (select x as a, y, z FROM df2) + SELECT * + FROM cte1 + JOIN cte2 USING (a) + """) + .collect() + .to_pydict() + ) + expected = df1.join(df2.select(col("x").alias("a"), "y", "z"), on="a").collect().to_pydict() + + assert actual == expected