Skip to content

Commit

Permalink
add SQLPlanner.new_with_context
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinzwang committed Nov 19, 2024
1 parent 79d8b03 commit 17f815a
Showing 1 changed file with 25 additions and 24 deletions.
49 changes: 25 additions & 24 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,13 @@ impl<'a> SQLPlanner<'a> {
}
}

fn new_with_context(&'a self) -> Self {
Self {
context: self.context.clone(),
..Default::default()
}
}

/// SAFETY: it is up to the caller to ensure that the relation is set before calling this method.
/// It's a programming error to call this method without setting the relation first.
/// Some methods such as `plan_expr` do not require the relation to be set.
Expand Down Expand Up @@ -172,7 +179,7 @@ impl<'a> SQLPlanner<'a> {
Ok(())
}

fn plan_ctes(&mut self, with: &With) -> SQLPlannerResult<()> {
fn plan_ctes(&self, with: &With) -> SQLPlannerResult<()> {
if with.recursive {
unsupported_sql_err!("Recursive CTEs are not supported");
}
Expand All @@ -187,7 +194,7 @@ impl<'a> SQLPlanner<'a> {
}

let name = ident_to_str(&cte.alias.name);
let plan = self.plan_query(&cte.query)?;
let plan = self.new_with_context().plan_query(&cte.query)?;
let rel = Relation::new(plan, name);

self.register_cte(rel, cte.alias.columns.as_slice())?;
Expand Down Expand Up @@ -255,33 +262,27 @@ impl<'a> SQLPlanner<'a> {
format_clause: None,
}
}
match (op, set_quantifier) {
(Union, SetQuantifier::All) => {
let left = self.plan_query(&make_query(left))?;
let right = self.plan_query(&make_query(right))?;
return left.union(&right, true).map_err(|e| e.into());
}

let left = self.new_with_context().plan_query(&make_query(left))?;
let right = self.new_with_context().plan_query(&make_query(right))?;

return match (op, set_quantifier) {
(Union, SetQuantifier::All) => left.union(&right, true).map_err(|e| e.into()),

(Union, SetQuantifier::None | SetQuantifier::Distinct) => {
let left = self.plan_query(&make_query(left))?;
let right = self.plan_query(&make_query(right))?;
return left.union(&right, false).map_err(|e| e.into());
left.union(&right, false).map_err(|e| e.into())
}

(Intersect, SetQuantifier::All) => {
let left = self.plan_query(&make_query(left))?;
let right = self.plan_query(&make_query(right))?;
return left.intersect(&right, true).map_err(|e| e.into());
left.intersect(&right, true).map_err(|e| e.into())
}
(Intersect, SetQuantifier::None | SetQuantifier::Distinct) => {
let left = self.plan_query(&make_query(left))?;
let right = self.plan_query(&make_query(right))?;
return left.intersect(&right, false).map_err(|e| e.into());
left.intersect(&right, false).map_err(|e| e.into())
}
(op, set_quantifier) => {
unsupported_sql_err!("{op} {set_quantifier} is not supported.")
}
}
};
}
SetExpr::Values(..) => unsupported_sql_err!("VALUES are not supported"),
SetExpr::Insert(..) => unsupported_sql_err!("INSERT is not supported"),
Expand Down Expand Up @@ -677,10 +678,10 @@ impl<'a> SQLPlanner<'a> {
let mut from_iter = from.iter();

let first = from_iter.next().unwrap();
let mut rel = self.plan_relation(&first.relation)?;
let mut rel = self.new_with_context().plan_relation(&first.relation)?;
self.table_map.insert(rel.get_name(), rel.clone());
for tbl in from_iter {
let right = self.plan_relation(&tbl.relation)?;
let right = self.new_with_context().plan_relation(&tbl.relation)?;
self.table_map.insert(right.get_name(), right.clone());
let right_join_prefix = Some(format!("{}.", right.get_name()));

Expand Down Expand Up @@ -797,15 +798,15 @@ impl<'a> SQLPlanner<'a> {
}

let relation = from.relation.clone();
let mut left_rel = self.plan_relation(&relation)?;
let mut left_rel = self.new_with_context().plan_relation(&relation)?;
self.table_map.insert(left_rel.get_name(), left_rel.clone());

for join in &from.joins {
use sqlparser::ast::{
JoinConstraint,
JoinOperator::{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter},
};
let right_rel = self.plan_relation(&join.relation)?;
let right_rel = self.new_with_context().plan_relation(&join.relation)?;
self.table_map
.insert(right_rel.get_name(), right_rel.clone());
let right_rel_name = right_rel.get_name();
Expand Down Expand Up @@ -855,7 +856,7 @@ impl<'a> SQLPlanner<'a> {
Ok(left_rel)
}

fn plan_relation(&mut self, rel: &sqlparser::ast::TableFactor) -> SQLPlannerResult<Relation> {
fn plan_relation(&self, rel: &sqlparser::ast::TableFactor) -> SQLPlannerResult<Relation> {
let (rel, alias) = match rel {
sqlparser::ast::TableFactor::Table {
name,
Expand Down Expand Up @@ -897,7 +898,7 @@ impl<'a> SQLPlanner<'a> {
if *lateral {
unsupported_sql_err!("LATERAL");
}
let subquery = self.plan_query(subquery)?;
let subquery = self.new_with_context().plan_query(subquery)?;
let rel_name = ident_to_str(&alias.name);
let rel = Relation::new(subquery, rel_name);

Expand Down

0 comments on commit 17f815a

Please sign in to comment.