Skip to content

Commit

Permalink
[FEAT] Filter predicates in SQL join
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinzwang committed Nov 20, 2024
1 parent b6695eb commit fb83763
Showing 1 changed file with 162 additions and 125 deletions.
287 changes: 162 additions & 125 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::{
borrow::Cow,
cell::{Ref, RefCell, RefMut},
collections::{HashMap, HashSet},
rc::Rc,
Expand All @@ -11,8 +10,9 @@ use daft_core::prelude::*;
use daft_dsl::{
col,
common_treenode::{Transformed, TreeNode},
has_agg, lit, literals_to_series, null_lit, AggExpr, Expr, ExprRef, LiteralValue, Operator,
OuterReferenceColumn, Subquery,
has_agg, lit, literals_to_series, null_lit,
optimization::conjuct,
AggExpr, Expr, ExprRef, LiteralValue, Operator, OuterReferenceColumn, Subquery,
};
use daft_functions::{
numeric::{ceil::ceil, floor::floor},
Expand Down Expand Up @@ -297,9 +297,8 @@ impl<'a> SQLPlanner<'a> {

// FROM/JOIN
let from = selection.clone().from;
let rel = self.plan_from(&from)?;
let schema = rel.schema();
self.current_relation = Some(rel);
self.plan_from(&from)?;
let schema = self.relation_opt().unwrap().schema();

// SELECT
let mut projections = Vec::with_capacity(selection.projection.len());
Expand Down Expand Up @@ -756,145 +755,163 @@ impl<'a> SQLPlanner<'a> {
Ok((exprs, desc, nulls_first))
}

fn plan_from(&mut self, from: &[TableWithJoins]) -> SQLPlannerResult<Relation> {
/// Plans the FROM clause of a query and populates self.current_relation and self.table_map
/// Should only be called once per query.
fn plan_from(&mut self, from: &[TableWithJoins]) -> SQLPlannerResult<()> {
if from.len() > 1 {
let mut from_iter = from.iter();

let first = from_iter.next().unwrap();
let mut rel = self.new_with_context().plan_relation(&first.relation)?;
let mut rel = self.plan_relation(&first.relation)?;
self.table_map.insert(rel.get_name(), rel.clone());
for tbl in from_iter {
let right = self.new_with_context().plan_relation(&tbl.relation)?;
let right = self.plan_relation(&tbl.relation)?;
self.table_map.insert(right.get_name(), right.clone());
let right_join_prefix = Some(format!("{}.", right.get_name()));

rel.inner =
rel.inner
.cross_join(right.inner, None, right_join_prefix.as_deref())?;
}
return Ok(rel);
self.current_relation = Some(rel);
return Ok(());
}

let from = from.iter().next().unwrap();

fn collect_idents(
left: &[Ident],
right: &[Ident],
left_rel: &Relation,
right_rel: &Relation,
) -> SQLPlannerResult<(Vec<ExprRef>, Vec<ExprRef>)> {
let (left, right) = match (left, right) {
// both are fully qualified: `join on a.x = b.y`
([tbl_a, Ident{value: col_a, ..}], [tbl_b, Ident{value: col_b, ..}]) => {
if left_rel.get_name() == tbl_b.value && right_rel.get_name() == tbl_a.value {
(col_b.clone(), col_a.clone())
} else {
(col_a.clone(), col_b.clone())
}
macro_rules! return_non_ident_errors {
($e:expr) => {
if !matches!(
$e,
PlannerError::ColumnNotFound { .. } | PlannerError::TableNotFound { .. }
) {
return Err($e);
}
// only one is fully qualified: `join on x = b.y`
([Ident{value: col_a, ..}], [tbl_b, Ident{value: col_b, ..}]) => {
if tbl_b.value == right_rel.get_name() {
(col_a.clone(), col_b.clone())
} else if tbl_b.value == left_rel.get_name() {
(col_b.clone(), col_a.clone())
} else {
unsupported_sql_err!("Could not determine which table the identifiers belong to")
}
};
}

#[allow(clippy::too_many_arguments)]
fn process_join_on(
sql_expr: &sqlparser::ast::Expr,
left_planner: &SQLPlanner,
right_planner: &SQLPlanner,
left_on: &mut Vec<ExprRef>,
right_on: &mut Vec<ExprRef>,
null_eq_nulls: &mut Vec<bool>,
left_filters: &mut Vec<ExprRef>,
right_filters: &mut Vec<ExprRef>,
) -> SQLPlannerResult<()> {
// check if join expression is actually a filter on one of the tables
match (
left_planner.plan_expr(sql_expr),
right_planner.plan_expr(sql_expr),
) {
(Ok(_), Ok(_)) => {
return Err(PlannerError::invalid_operation(format!(
"Ambiguous reference to column name in join: {}",
sql_expr
)));
}
// only one is fully qualified: `join on a.x = y`
([tbl_a, Ident{value: col_a, ..}], [Ident{value: col_b, ..}]) => {
// find out which one the qualified identifier belongs to
// we assume the other identifier belongs to the other table
if tbl_a.value == left_rel.get_name() {
(col_a.clone(), col_b.clone())
} else if tbl_a.value == right_rel.get_name() {
(col_b.clone(), col_a.clone())
} else {
unsupported_sql_err!("Could not determine which table the identifiers belong to")
}
(Ok(expr), _) => {
left_filters.push(expr);
return Ok(());
}
// neither are fully qualified: `join on x = y`
([left], [right]) => {
let left = ident_to_str(left);
let right = ident_to_str(right);

// we don't know which table the identifiers belong to, so we need to check both
let left_schema = left_rel.schema();
let right_schema = right_rel.schema();

// if the left side is in the left schema, then we assume the right side is in the right schema
if left_schema.get_field(&left).is_ok() {
(left, right)
// if the right side is in the left schema, then we assume the left side is in the right schema
} else if right_schema.get_field(&left).is_ok() {
(right, left)
} else {
unsupported_sql_err!("JOIN clauses must reference columns in the joined tables; found `{}`", left);
}
(_, Ok(expr)) => {
right_filters.push(expr);
return Ok(());
}
(Err(left_err), Err(right_err)) => {
return_non_ident_errors!(left_err);
return_non_ident_errors!(right_err);
}
}

match sql_expr {
// join key
sqlparser::ast::Expr::BinaryOp {
left,
right,
op: op @ BinaryOperator::Eq,
}
_ => unsupported_sql_err!(
"collect_compound_identifiers: Expected left.len() == 2 && right.len() == 2, but found left.len() == {:?}, right.len() == {:?}",
left.len(),
right.len()
),
};
Ok((vec![col(left)], vec![col(right)]))
}
| sqlparser::ast::Expr::BinaryOp {
left,
right,
op: op @ BinaryOperator::Spaceship,
} => {
let null_equals_null = *op == BinaryOperator::Spaceship;

fn process_join_on(
expression: &sqlparser::ast::Expr,
left_rel: &Relation,
right_rel: &Relation,
) -> SQLPlannerResult<(Vec<ExprRef>, Vec<ExprRef>, Vec<bool>)> {
if let sqlparser::ast::Expr::BinaryOp { left, op, right } = expression {
match *op {
BinaryOperator::Eq | BinaryOperator::Spaceship => {
let null_equals_null = *op == BinaryOperator::Spaceship;

let left = get_idents_vec(left)?;
let right = get_idents_vec(right)?;

collect_idents(&left, &right, left_rel, right_rel)
.map(|(left, right)| (left, right, vec![null_equals_null]))
}
BinaryOperator::And => {
let (mut left_i, mut right_i, mut null_equals_nulls_i) =
process_join_on(left, left_rel, right_rel)?;
let (mut left_j, mut right_j, mut null_equals_nulls_j) =
process_join_on(right, left_rel, right_rel)?;
left_i.append(&mut left_j);
right_i.append(&mut right_j);
null_equals_nulls_i.append(&mut null_equals_nulls_j);
Ok((left_i, right_i, null_equals_nulls_i))
}
_ => {
unsupported_sql_err!("JOIN clauses support '=' constraints combined with 'AND'; found op = '{}'", op);
let mut last_error = None;

for (left, right) in [(left, right), (right, left)] {
let left_expr = left_planner.plan_expr(left);
let right_expr = right_planner.plan_expr(right);

if let Ok(left_expr) = &left_expr && let Ok(right_expr) = &right_expr {
left_on.push(left_expr.clone());
right_on.push(right_expr.clone());
null_eq_nulls.push(null_equals_null);

return Ok(())
}

for expr_result in [left_expr, right_expr] {
if let Err(e) = expr_result {
return_non_ident_errors!(e);

last_error = Some(e);
}
}
}

Err(last_error.unwrap())
}
} else if let sqlparser::ast::Expr::Nested(expr) = expression {
process_join_on(expr, left_rel, right_rel)
} else {
unsupported_sql_err!("JOIN clauses support '=' constraints combined with 'AND'; found expression = {:?}", expression);
// multiple expressions
sqlparser::ast::Expr::BinaryOp {
left,
right,
op: BinaryOperator::And,
} => {
process_join_on(left, left_planner, right_planner, left_on, right_on, null_eq_nulls, left_filters, right_filters)?;
process_join_on(right, left_planner, right_planner, left_on, right_on, null_eq_nulls, left_filters, right_filters)?;

Ok(())
}
// nested expression
sqlparser::ast::Expr::Nested(expr) => process_join_on(
expr,
left_planner,
right_planner,
left_on,
right_on,
null_eq_nulls,
left_filters,
right_filters,
),
_ => unsupported_sql_err!("JOIN clauses support '=' constraints and filter predicates combined with 'AND'; found expression = {:?}", sql_expr)
}
}

let relation = from.relation.clone();
let mut left_rel = self.new_with_context().plan_relation(&relation)?;
self.table_map.insert(left_rel.get_name(), left_rel.clone());
let left_rel = self.plan_relation(&relation)?;
self.current_relation = Some(left_rel.clone());
self.table_map.insert(left_rel.get_name(), left_rel);

for join in &from.joins {
use sqlparser::ast::{
JoinConstraint,
JoinOperator::{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter},
};
let right_rel = self.new_with_context().plan_relation(&join.relation)?;
self.table_map
.insert(right_rel.get_name(), right_rel.clone());
let right_rel = self.plan_relation(&join.relation)?;
let right_rel_name = right_rel.get_name();
let right_join_prefix = Some(format!("{right_rel_name}."));

// construct a planner with the right table to use for expr planning
let mut right_planner = self.new_with_context();
right_planner.current_relation = Some(right_rel.clone());
right_planner
.table_map
.insert(right_rel.get_name(), right_rel.clone());

let (join_type, constraint) = match &join.join_operator {
Inner(constraint) => (JoinType::Inner, constraint),
LeftOuter(constraint) => (JoinType::Left, constraint),
Expand All @@ -906,37 +923,65 @@ impl<'a> SQLPlanner<'a> {
_ => unsupported_sql_err!("Unsupported join type: {:?}", join.join_operator),
};

let (left_on, right_on, null_eq_null, keep_join_keys) = match &constraint {
let mut left_on = Vec::new();
let mut right_on = Vec::new();
let mut null_eq_null = Vec::new();
let mut left_filters = Vec::new();
let mut right_filters = Vec::new();

let keep_join_keys = match &constraint {
JoinConstraint::On(expr) => {
let (left_on, right_on, null_equals_nulls) =
process_join_on(expr, &left_rel, &right_rel)?;
(left_on, right_on, Some(null_equals_nulls), true)
process_join_on(
expr,
self,
&right_planner,
&mut left_on,
&mut right_on,
&mut null_eq_null,
&mut left_filters,
&mut right_filters,
)?;

true
}
JoinConstraint::Using(idents) => {
let on = idents
left_on = idents
.iter()
.map(|i| col(i.value.clone()))
.collect::<Vec<_>>();
(on.clone(), on, None, false)
right_on.clone_from(&left_on);

false
}
JoinConstraint::Natural => unsupported_sql_err!("NATURAL JOIN not supported"),
JoinConstraint::None => unsupported_sql_err!("JOIN without ON/USING not supported"),
};

left_rel.inner = left_rel.inner.join_with_null_safe_equal(
right_rel.inner,
let mut left_plan = self.current_relation.as_ref().unwrap().inner.clone();
if let Some(left_predicate) = conjuct(left_filters) {
left_plan = left_plan.filter(left_predicate)?;
}

let mut right_plan = right_rel.inner.clone();
if let Some(right_predicate) = conjuct(right_filters) {
right_plan = right_plan.filter(right_predicate)?;
}

self.relation_mut().inner = left_plan.join_with_null_safe_equal(
right_plan,
left_on,
right_on,
null_eq_null,
Some(null_eq_null),
join_type,
None,
None,
right_join_prefix.as_deref(),
keep_join_keys,
)?;
self.table_map.insert(right_rel_name, right_rel);
}

Ok(left_rel)
Ok(())
}

fn plan_relation(&self, rel: &sqlparser::ast::TableFactor) -> SQLPlannerResult<Relation> {
Expand Down Expand Up @@ -2105,14 +2150,6 @@ fn idents_to_str(idents: &[Ident]) -> String {
.join(".")
}

fn get_idents_vec(expr: &sqlparser::ast::Expr) -> SQLPlannerResult<Cow<Vec<Ident>>> {
match expr {
sqlparser::ast::Expr::Identifier(ident) => Ok(Cow::Owned(vec![ident.clone()])),
sqlparser::ast::Expr::CompoundIdentifier(idents) => Ok(Cow::Borrowed(idents)),
_ => invalid_operation_err!("expected an identifier"),
}
}

/// unresolves an alias in a projection
/// Example:
/// ```sql
Expand Down

0 comments on commit fb83763

Please sign in to comment.