Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT]: SQL common table expressions (CTEs) #3137

Merged
merged 8 commits into from
Oct 30, 2024
3 changes: 2 additions & 1 deletion src/daft-sql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,10 @@ mod tests {
#[case::orderby_multi("select * from tbl1 order by i32 desc, f32 asc")]
#[case::whenthen("select case when i32 = 1 then 'a' else 'b' end from tbl1")]
#[case::globalagg("select max(i32) from tbl1")]
#[case::cte("with cte as (select * from tbl1) select * from cte")]
fn test_compiles(mut planner: SQLPlanner, #[case] query: &str) -> SQLPlannerResult<()> {
let plan = planner.plan_sql(query);
assert!(plan.is_ok(), "query: {query}\nerror: {plan:?}");
assert!(&plan.is_ok(), "query: {query}\nerror: {plan:?}");

Ok(())
}
Expand Down
159 changes: 132 additions & 27 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ use daft_plan::{LogicalPlanBuilder, LogicalPlanRef};
use sqlparser::{
ast::{
ArrayElemTypeDef, BinaryOperator, CastKind, DateTimeField, Distinct, ExactNumberInfo,
ExcludeSelectItem, GroupByExpr, Ident, Query, SelectItem, Statement, StructField,
ExcludeSelectItem, GroupByExpr, Ident, Query, SelectItem, SetExpr, Statement, StructField,
Subscript, TableAlias, TableWithJoins, TimezoneInfo, UnaryOperator, Value,
WildcardAdditionalOptions,
WildcardAdditionalOptions, With,
},
dialect::GenericDialect,
parser::{Parser, ParserOptions},
Expand Down Expand Up @@ -66,14 +66,16 @@ pub struct SQLPlanner {
catalog: SQLCatalog,
current_relation: Option<Relation>,
table_map: HashMap<String, Relation>,
cte_map: HashMap<String, Relation>,
}

impl Default for SQLPlanner {
fn default() -> Self {
Self {
catalog: SQLCatalog::new(),
current_relation: None,
table_map: HashMap::new(),
current_relation: Default::default(),
table_map: Default::default(),
cte_map: Default::default(),
}
}
}
Expand All @@ -82,8 +84,7 @@ impl SQLPlanner {
pub fn new(context: SQLCatalog) -> Self {
Self {
catalog: context,
current_relation: None,
table_map: HashMap::new(),
..Default::default()
}
}

Expand All @@ -102,6 +103,69 @@ impl SQLPlanner {
/// Resets the planner's per-query state.
///
/// Empties both the CTE map and the table map and drops the current relation.
fn clear_context(&mut self) {
    self.cte_map.clear();
    self.table_map.clear();
    self.current_relation = None;
}

/// Looks up a relation by `name` for the query currently being planned.
///
/// Sources are consulted in priority order: `table_map` first, then
/// `cte_map`, and finally the catalog. Map hits are returned as clones; a
/// catalog hit is wrapped in a new [`Relation`] named `name`.
///
/// Returns `None` when none of the three sources knows the name.
fn get_table_from_current_scope(&self, name: &str) -> Option<Relation> {
    if let Some(bound) = self.table_map.get(name) {
        return Some(bound.clone());
    }
    if let Some(cte) = self.cte_map.get(name) {
        return Some(cte.clone());
    }
    // Fall back to the catalog; `?` yields `None` if the table is unknown.
    let source = self.catalog.get_table(name)?;
    Some(Relation::new(source.into(), name.to_string()))
}

/// Stores `rel` in `cte_map` under the relation's own name, optionally
/// renaming its columns first.
///
/// When `column_aliases` is non-empty, the relation is re-projected so that
/// each existing column is aliased positionally to the matching entry of
/// `column_aliases`.
///
/// # Errors
///
/// Reports a "Column count mismatch" through `invalid_operation_err!` when the
/// number of aliases differs from the number of columns, and propagates any
/// error returned by the projection itself.
fn register_cte(
    &mut self,
    mut rel: Relation,
    column_aliases: &[Ident],
) -> SQLPlannerResult<()> {
    if !column_aliases.is_empty() {
        let schema = rel.schema();
        let existing = schema.names();
        let alias_count = column_aliases.len();
        if existing.len() != alias_count {
            invalid_operation_err!(
                "Column count mismatch: expected {} columns, found {}",
                alias_count,
                existing.len()
            );
        }

        // Pair every current column with its alias, in schema order.
        let mut renamed = Vec::with_capacity(alias_count);
        for (name, alias) in existing.into_iter().zip(column_aliases) {
            renamed.push(col(name).alias(ident_to_str(alias)));
        }

        rel.inner = rel.inner.select(renamed)?;
    }
    let key = rel.get_name();
    self.cte_map.insert(key, rel);
    Ok(())
}

/// Plans every CTE of a `WITH` clause and registers each one on the planner.
///
/// CTEs are processed in declaration order: each CTE's query is planned with
/// `plan_query`, wrapped in a [`Relation`] named after the CTE alias, and
/// handed to `register_cte` together with the alias' column list.
///
/// # Errors
///
/// Rejects `WITH RECURSIVE`, any CTE carrying a `MATERIALIZED` marker, and
/// any CTE with a `FROM` part; also propagates planning and registration
/// errors.
fn plan_ctes(&mut self, with: &With) -> SQLPlannerResult<()> {
    if with.recursive {
        unsupported_sql_err!("Recursive CTEs are not supported");
    }

    for entry in &with.cte_tables {
        if entry.materialized.is_some() {
            unsupported_sql_err!("MATERIALIZED is not supported");
        }

        if entry.from.is_some() {
            invalid_operation_err!("FROM should only exist in recursive CTEs");
        }

        let cte_name = ident_to_str(&entry.alias.name);
        let planned = self.plan_query(&entry.query)?;

        self.register_cte(
            Relation::new(planned, cte_name),
            entry.alias.columns.as_slice(),
        )?;
    }
    Ok(())
}

pub fn plan_sql(&mut self, sql: &str) -> SQLPlannerResult<LogicalPlanRef> {
Expand Down Expand Up @@ -136,15 +200,24 @@ impl SQLPlanner {
fn plan_query(&mut self, query: &Query) -> SQLPlannerResult<LogicalPlanBuilder> {
check_query_features(query)?;

let selection = query.body.as_select().ok_or_else(|| {
PlannerError::invalid_operation(format!(
"Only SELECT queries are supported, got: '{}'",
query.body
))
})?;
let selection = match query.body.as_ref() {
SetExpr::Select(selection) => selection,
SetExpr::Query(_) => unsupported_sql_err!("Subqueries are not supported"),
SetExpr::SetOperation { .. } => {
unsupported_sql_err!("Set operations are not supported")
}
SetExpr::Values(..) => unsupported_sql_err!("VALUES are not supported"),
SetExpr::Insert(..) => unsupported_sql_err!("INSERT is not supported"),
SetExpr::Update(..) => unsupported_sql_err!("UPDATE is not supported"),
SetExpr::Table(..) => unsupported_sql_err!("TABLE is not supported"),
};
Comment on lines +203 to +213
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A small drive-by change to provide better error messages here.


check_select_features(selection)?;

if let Some(with) = &query.with {
self.plan_ctes(with)?;
}

// FROM/JOIN
let from = selection.clone().from;
let rel = self.plan_from(&from)?;
Expand Down Expand Up @@ -480,7 +553,7 @@ impl SQLPlanner {
Ok(left_rel)
}

fn plan_relation(&self, rel: &sqlparser::ast::TableFactor) -> SQLPlannerResult<Relation> {
fn plan_relation(&mut self, rel: &sqlparser::ast::TableFactor) -> SQLPlannerResult<Relation> {
let (rel, alias) = match rel {
sqlparser::ast::TableFactor::Table {
name,
Expand All @@ -498,12 +571,48 @@ impl SQLPlanner {
..
} => {
let table_name = name.to_string();
let plan = self
.catalog
.get_table(&table_name)
.ok_or_else(|| PlannerError::table_not_found(table_name.clone()))?;
let plan_builder = LogicalPlanBuilder::new(plan, None);
(Relation::new(plan_builder, table_name), alias.clone())
let Some(rel) = self.get_table_from_current_scope(&table_name) else {
table_not_found_err!(table_name)
};
(rel, alias.clone())
}
sqlparser::ast::TableFactor::Derived {
lateral,
subquery,
alias: Some(alias),
} => {
if *lateral {
unsupported_sql_err!("LATERAL");
}
let subquery = self.plan_query(subquery)?;
let rel_name = ident_to_str(&alias.name);
let rel = Relation::new(subquery, rel_name);

(rel, Some(alias.clone()))
}
sqlparser::ast::TableFactor::TableFunction { .. } => {
unsupported_sql_err!("Unsupported table factor: TableFunction")
}
sqlparser::ast::TableFactor::Function { .. } => {
unsupported_sql_err!("Unsupported table factor: Function")
}
sqlparser::ast::TableFactor::UNNEST { .. } => {
unsupported_sql_err!("Unsupported table factor: UNNEST")
}
sqlparser::ast::TableFactor::JsonTable { .. } => {
unsupported_sql_err!("Unsupported table factor: JsonTable")
}
sqlparser::ast::TableFactor::NestedJoin { .. } => {
unsupported_sql_err!("Unsupported table factor: NestedJoin")
}
sqlparser::ast::TableFactor::Pivot { .. } => {
unsupported_sql_err!("Unsupported table factor: Pivot")
}
sqlparser::ast::TableFactor::Unpivot { .. } => {
unsupported_sql_err!("Unsupported table factor: Unpivot")
}
sqlparser::ast::TableFactor::MatchRecognize { .. } => {
unsupported_sql_err!("Unsupported table factor: MatchRecognize")
Comment on lines +593 to +615
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another small drive-by change for better error messages.

}
_ => unsupported_sql_err!("Unsupported table factor"),
};
Expand All @@ -520,8 +629,7 @@ impl SQLPlanner {

let root = idents.next().unwrap();
let root = ident_to_str(root);

let current_relation = match self.table_map.get(&root) {
let current_relation = match self.get_table_from_current_scope(&root) {
Some(rel) => rel,
None => {
return Err(PlannerError::TableNotFound {
Expand Down Expand Up @@ -626,7 +734,7 @@ impl SQLPlanner {
let Some(rel) = self.relation_opt() else {
table_not_found_err!(table_name);
};
let Some(table_rel) = self.table_map.get(&table_name) else {
let Some(table_rel) = self.get_table_from_current_scope(&table_name) else {
table_not_found_err!(table_name);
};
let right_schema = table_rel.inner.schema();
Expand Down Expand Up @@ -673,7 +781,7 @@ impl SQLPlanner {
Value::Null => LiteralValue::Null,
_ => {
return Err(PlannerError::invalid_operation(
"Only string, number, boolean and null literals are supported",
"Only string, number, boolean and null literals are supported. Instead found: `{value}`",
))
}
})
Expand All @@ -683,7 +791,7 @@ impl SQLPlanner {
if let sqlparser::ast::Expr::Value(v) = expr {
self.value_to_lit(v)
} else {
invalid_operation_err!("Only string, number, boolean and null literals are supported");
invalid_operation_err!("Only string, number, boolean and null literals are supported. Instead found: `{expr}`");
}
}
pub(crate) fn plan_expr(&self, expr: &sqlparser::ast::Expr) -> SQLPlannerResult<ExprRef> {
Expand Down Expand Up @@ -1373,9 +1481,6 @@ impl SQLPlanner {
/// /// This function examines various clauses and options in the provided [sqlparser::ast::Query]
/// and returns an error if any unsupported features are encountered.
fn check_query_features(query: &sqlparser::ast::Query) -> SQLPlannerResult<()> {
if let Some(with) = &query.with {
unsupported_sql_err!("WITH: {with}")
}
if !query.limit_by.is_empty() {
unsupported_sql_err!("LIMIT BY");
}
Expand Down
61 changes: 61 additions & 0 deletions tests/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest

import daft
from daft import col
from daft.exceptions import DaftCoreException
from daft.sql.sql import SQLCatalog
from tests.assets import TPCH_QUERIES
Expand Down Expand Up @@ -221,3 +222,63 @@ def test_sql_distinct():
actual = daft.sql("SELECT DISTINCT n FROM df").collect().to_pydict()
expected = df.distinct().collect().to_pydict()
assert actual == expected


# Single-CTE round trip: a CTE that selects everything from `df`, followed by
# `SELECT *` from that CTE, must produce the same dict as collecting `df`
# directly.
def test_sql_cte():
df = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6], "c": ["a", "b", "c"]})
actual = (
daft.sql("""
WITH cte1 AS (select * FROM df)
SELECT * FROM cte1
""")
.collect()
.to_pydict()
)

expected = df.collect().to_pydict()

assert actual == expected


# CTE with an explicit column list: `cte1 (cte_a, cte_b, cte_c)` should rename
# the columns of `df` positionally, so the result must equal a DataFrame
# select that aliases a -> cte_a, b -> cte_b and c -> cte_c.
def test_sql_cte_column_aliases():
df = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6], "c": ["a", "b", "c"]})
actual = (
daft.sql("""
WITH cte1 (cte_a, cte_b, cte_c) AS (select * FROM df)
SELECT * FROM cte1
""")
.collect()
.to_pydict()
)

expected = (
df.select(
col("a").alias("cte_a"),
col("b").alias("cte_b"),
col("c").alias("cte_c"),
)
.collect()
.to_pydict()
)

assert actual == expected


# Two CTEs in one WITH clause, joined with USING: `cte1` wraps `df1`, and
# `cte2` projects `df2` with `x` renamed to `a`. The SQL result must equal the
# DataFrame-API join of `df1` with the same projection of `df2` on column "a".
def test_sql_multiple_ctes():
df1 = daft.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6], "c": ["a", "b", "c"]})
df2 = daft.from_pydict({"x": [1, 0, 3], "y": [True, None, False], "z": [1.0, 2.0, 3.0]})
actual = (
daft.sql("""
WITH
cte1 AS (select * FROM df1),
cte2 AS (select x as a, y, z FROM df2)
SELECT *
FROM cte1
JOIN cte2 USING (a)
""")
.collect()
.to_pydict()
)
expected = df1.join(df2.select(col("x").alias("a"), "y", "z"), on="a").collect().to_pydict()

assert actual == expected
Loading