diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index 75bb2b07c0..c304cad020 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -72,6 +72,7 @@ mod tests { Schema::new(vec![ Field::new("text", DataType::Utf8), Field::new("id", DataType::Int32), + Field::new("val", DataType::Int32), ]) .unwrap(), ); @@ -138,6 +139,7 @@ mod tests { #[case::slice("select list_utf8[0:2] from tbl1")] #[case::join("select * from tbl2 join tbl3 on tbl2.id = tbl3.id")] #[case::null_safe_join("select * from tbl2 left join tbl3 on tbl2.id <=> tbl3.id")] + #[case::join_with_filter("select * from tbl2 join tbl3 on tbl2.id = tbl3.id and tbl2.val > 0")] #[case::from("select tbl2.text from tbl2")] #[case::using("select tbl2.text from tbl2 join tbl3 using (id)")] #[case( @@ -301,6 +303,34 @@ mod tests { Ok(()) } + #[rstest] + fn test_join_with_filter( + mut planner: SQLPlanner, + tbl_2: LogicalPlanRef, + tbl_3: LogicalPlanRef, + ) -> SQLPlannerResult<()> { + let sql = "select * from tbl2 join tbl3 on tbl2.id = tbl3.id and tbl2.val > 0"; + let plan = planner.plan_sql(&sql)?; + + let expected = LogicalPlanBuilder::new(tbl_2, None) + .filter(col("val").gt(lit(0 as i64)))? + .join_with_null_safe_equal( + tbl_3, + vec![col("id")], + vec![col("id")], + Some(vec![false]), + JoinType::Inner, + None, + None, + Some("tbl3."), + true, + )? + .select(vec![col("*")])? + .build(); + assert_eq!(plan, expected); + Ok(()) + } + #[rstest] #[case::abs("select abs(i32) as abs from tbl1")] #[case::ceil("select ceil(i32) as ceil from tbl1")] diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index f6334c5632..a8e7a90c7b 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -1,5 +1,4 @@ use std::{ - borrow::Cow, cell::{Ref, RefCell, RefMut}, collections::{HashMap, HashSet}, rc::Rc, @@ -11,8 +10,9 @@ use daft_core::prelude::*; use daft_dsl::{ col, common_treenode::{Transformed, TreeNode}, - has_agg, lit, literals_to_series, null_lit, AggExpr, Expr, ExprRef, LiteralValue, Operator, - OuterReferenceColumn, Subquery, + has_agg, lit, literals_to_series, null_lit, + optimization::conjuct, + AggExpr, Expr, ExprRef, LiteralValue, Operator, OuterReferenceColumn, Subquery, }; use daft_functions::{ numeric::{ceil::ceil, floor::floor}, @@ -297,9 +297,8 @@ impl<'a> SQLPlanner<'a> { // FROM/JOIN let from = selection.clone().from; - let rel = self.plan_from(&from)?; - let schema = rel.schema(); - self.current_relation = Some(rel); + self.plan_from(&from)?; + let schema = self.relation_opt().unwrap().schema(); // SELECT let mut projections = Vec::with_capacity(selection.projection.len()); @@ -756,15 +755,17 @@ impl<'a> SQLPlanner<'a> { Ok((exprs, desc, nulls_first)) } - fn plan_from(&mut self, from: &[TableWithJoins]) -> SQLPlannerResult { + /// Plans the FROM clause of a query and populates self.current_relation and self.table_map + /// Should only be called once per query. + fn plan_from(&mut self, from: &[TableWithJoins]) -> SQLPlannerResult<()> { if from.len() > 1 { let mut from_iter = from.iter(); let first = from_iter.next().unwrap(); - let mut rel = self.new_with_context().plan_relation(&first.relation)?; + let mut rel = self.plan_relation(&first.relation)?; self.table_map.insert(rel.get_name(), rel.clone()); for tbl in from_iter { - let right = self.new_with_context().plan_relation(&tbl.relation)?; + let right = self.plan_relation(&tbl.relation)?; self.table_map.insert(right.get_name(), right.clone()); let right_join_prefix = Some(format!("{}.", right.get_name())); @@ -772,129 +773,145 @@ impl<'a> SQLPlanner<'a> { rel.inner .cross_join(right.inner, None, right_join_prefix.as_deref())?; } - return Ok(rel); + self.current_relation = Some(rel); + return Ok(()); } let from = from.iter().next().unwrap(); - fn collect_idents( - left: &[Ident], - right: &[Ident], - left_rel: &Relation, - right_rel: &Relation, - ) -> SQLPlannerResult<(Vec, Vec)> { - let (left, right) = match (left, right) { - // both are fully qualified: `join on a.x = b.y` - ([tbl_a, Ident{value: col_a, ..}], [tbl_b, Ident{value: col_b, ..}]) => { - if left_rel.get_name() == tbl_b.value && right_rel.get_name() == tbl_a.value { - (col_b.clone(), col_a.clone()) - } else { - (col_a.clone(), col_b.clone()) - } + macro_rules! return_non_ident_errors { + ($e:expr) => { + if !matches!( + $e, + PlannerError::ColumnNotFound { .. } | PlannerError::TableNotFound { .. } + ) { + return Err($e); } - // only one is fully qualified: `join on x = b.y` - ([Ident{value: col_a, ..}], [tbl_b, Ident{value: col_b, ..}]) => { - if tbl_b.value == right_rel.get_name() { - (col_a.clone(), col_b.clone()) - } else if tbl_b.value == left_rel.get_name() { - (col_b.clone(), col_a.clone()) - } else { - unsupported_sql_err!("Could not determine which table the identifiers belong to") - } + }; + } + + #[allow(clippy::too_many_arguments)] + fn process_join_on( + sql_expr: &sqlparser::ast::Expr, + left_planner: &SQLPlanner, + right_planner: &SQLPlanner, + left_on: &mut Vec, + right_on: &mut Vec, + null_eq_nulls: &mut Vec, + left_filters: &mut Vec, + right_filters: &mut Vec, + ) -> SQLPlannerResult<()> { + // check if join expression is actually a filter on one of the tables + match ( + left_planner.plan_expr(sql_expr), + right_planner.plan_expr(sql_expr), + ) { + (Ok(_), Ok(_)) => { + return Err(PlannerError::invalid_operation(format!( + "Ambiguous reference to column name in join: {}", + sql_expr + ))); } - // only one is fully qualified: `join on a.x = y` - ([tbl_a, Ident{value: col_a, ..}], [Ident{value: col_b, ..}]) => { - // find out which one the qualified identifier belongs to - // we assume the other identifier belongs to the other table - if tbl_a.value == left_rel.get_name() { - (col_a.clone(), col_b.clone()) - } else if tbl_a.value == right_rel.get_name() { - (col_b.clone(), col_a.clone()) - } else { - unsupported_sql_err!("Could not determine which table the identifiers belong to") - } + (Ok(expr), _) => { + left_filters.push(expr); + return Ok(()); } - // neither are fully qualified: `join on x = y` - ([left], [right]) => { - let left = ident_to_str(left); - let right = ident_to_str(right); - - // we don't know which table the identifiers belong to, so we need to check both - let left_schema = left_rel.schema(); - let right_schema = right_rel.schema(); - - // if the left side is in the left schema, then we assume the right side is in the right schema - if left_schema.get_field(&left).is_ok() { - (left, right) - // if the right side is in the left schema, then we assume the left side is in the right schema - } else if right_schema.get_field(&left).is_ok() { - (right, left) - } else { - unsupported_sql_err!("JOIN clauses must reference columns in the joined tables; found `{}`", left); - } + (_, Ok(expr)) => { + right_filters.push(expr); + return Ok(()); + } + (Err(left_err), Err(right_err)) => { + return_non_ident_errors!(left_err); + return_non_ident_errors!(right_err); + } + } + match sql_expr { + // join key + sqlparser::ast::Expr::BinaryOp { + left, + right, + op: op @ BinaryOperator::Eq, } - _ => unsupported_sql_err!( - "collect_compound_identifiers: Expected left.len() == 2 && right.len() == 2, but found left.len() == {:?}, right.len() == {:?}", - left.len(), - right.len() - ), - }; - Ok((vec![col(left)], vec![col(right)])) - } + | sqlparser::ast::Expr::BinaryOp { + left, + right, + op: op @ BinaryOperator::Spaceship, + } => { + let null_equals_null = *op == BinaryOperator::Spaceship; - fn process_join_on( - expression: &sqlparser::ast::Expr, - left_rel: &Relation, - right_rel: &Relation, - ) -> SQLPlannerResult<(Vec, Vec, Vec)> { - if let sqlparser::ast::Expr::BinaryOp { left, op, right } = expression { - match *op { - BinaryOperator::Eq | BinaryOperator::Spaceship => { - let null_equals_null = *op == BinaryOperator::Spaceship; - - let left = get_idents_vec(left)?; - let right = get_idents_vec(right)?; - - collect_idents(&left, &right, left_rel, right_rel) - .map(|(left, right)| (left, right, vec![null_equals_null])) - } - BinaryOperator::And => { - let (mut left_i, mut right_i, mut null_equals_nulls_i) = - process_join_on(left, left_rel, right_rel)?; - let (mut left_j, mut right_j, mut null_equals_nulls_j) = - process_join_on(right, left_rel, right_rel)?; - left_i.append(&mut left_j); - right_i.append(&mut right_j); - null_equals_nulls_i.append(&mut null_equals_nulls_j); - Ok((left_i, right_i, null_equals_nulls_i)) - } - _ => { - unsupported_sql_err!("JOIN clauses support '=' constraints combined with 'AND'; found op = '{}'", op); + let mut last_error = None; + + for (left, right) in [(left, right), (right, left)] { + let left_expr = left_planner.plan_expr(left); + let right_expr = right_planner.plan_expr(right); + + if let Ok(left_expr) = &left_expr && let Ok(right_expr) = &right_expr { + left_on.push(left_expr.clone()); + right_on.push(right_expr.clone()); + null_eq_nulls.push(null_equals_null); + + return Ok(()) + } + + for expr_result in [left_expr, right_expr] { + if let Err(e) = expr_result { + return_non_ident_errors!(e); + + last_error = Some(e); + } + } } + + Err(last_error.unwrap()) } - } else if let sqlparser::ast::Expr::Nested(expr) = expression { - process_join_on(expr, left_rel, right_rel) - } else { - unsupported_sql_err!("JOIN clauses support '=' constraints combined with 'AND'; found expression = {:?}", expression); + // multiple expressions + sqlparser::ast::Expr::BinaryOp { + left, + right, + op: BinaryOperator::And, + } => { + process_join_on(left, left_planner, right_planner, left_on, right_on, null_eq_nulls, left_filters, right_filters)?; + process_join_on(right, left_planner, right_planner, left_on, right_on, null_eq_nulls, left_filters, right_filters)?; + + Ok(()) + } + // nested expression + sqlparser::ast::Expr::Nested(expr) => process_join_on( + expr, + left_planner, + right_planner, + left_on, + right_on, + null_eq_nulls, + left_filters, + right_filters, + ), + _ => unsupported_sql_err!("JOIN clauses support '=' constraints and filter predicates combined with 'AND'; found expression = {:?}", sql_expr) } } let relation = from.relation.clone(); - let mut left_rel = self.new_with_context().plan_relation(&relation)?; - self.table_map.insert(left_rel.get_name(), left_rel.clone()); + let left_rel = self.plan_relation(&relation)?; + self.current_relation = Some(left_rel.clone()); + self.table_map.insert(left_rel.get_name(), left_rel); for join in &from.joins { use sqlparser::ast::{ JoinConstraint, JoinOperator::{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, RightOuter}, }; - let right_rel = self.new_with_context().plan_relation(&join.relation)?; - self.table_map - .insert(right_rel.get_name(), right_rel.clone()); + let right_rel = self.plan_relation(&join.relation)?; let right_rel_name = right_rel.get_name(); let right_join_prefix = Some(format!("{right_rel_name}.")); + // construct a planner with the right table to use for expr planning + let mut right_planner = self.new_with_context(); + right_planner.current_relation = Some(right_rel.clone()); + right_planner + .table_map + .insert(right_rel.get_name(), right_rel.clone()); + let (join_type, constraint) = match &join.join_operator { Inner(constraint) => (JoinType::Inner, constraint), LeftOuter(constraint) => (JoinType::Left, constraint), @@ -906,37 +923,66 @@ impl<'a> SQLPlanner<'a> { _ => unsupported_sql_err!("Unsupported join type: {:?}", join.join_operator), }; - let (left_on, right_on, null_eq_null, keep_join_keys) = match &constraint { + let mut left_on = Vec::new(); + let mut right_on = Vec::new(); + let mut left_filters = Vec::new(); + let mut right_filters = Vec::new(); + + let (keep_join_keys, null_eq_nulls) = match &constraint { JoinConstraint::On(expr) => { - let (left_on, right_on, null_equals_nulls) = - process_join_on(expr, &left_rel, &right_rel)?; - (left_on, right_on, Some(null_equals_nulls), true) + let mut null_eq_nulls = Vec::new(); + + process_join_on( + expr, + self, + &right_planner, + &mut left_on, + &mut right_on, + &mut null_eq_nulls, + &mut left_filters, + &mut right_filters, + )?; + + (true, Some(null_eq_nulls)) } JoinConstraint::Using(idents) => { - let on = idents + left_on = idents .iter() .map(|i| col(i.value.clone())) .collect::>(); - (on.clone(), on, None, false) + right_on.clone_from(&left_on); + + (false, None) } JoinConstraint::Natural => unsupported_sql_err!("NATURAL JOIN not supported"), JoinConstraint::None => unsupported_sql_err!("JOIN without ON/USING not supported"), }; - left_rel.inner = left_rel.inner.join_with_null_safe_equal( - right_rel.inner, + let mut left_plan = self.current_relation.as_ref().unwrap().inner.clone(); + if let Some(left_predicate) = conjuct(left_filters) { + left_plan = left_plan.filter(left_predicate)?; + } + + let mut right_plan = right_rel.inner.clone(); + if let Some(right_predicate) = conjuct(right_filters) { + right_plan = right_plan.filter(right_predicate)?; + } + + self.relation_mut().inner = left_plan.join_with_null_safe_equal( + right_plan, left_on, right_on, - null_eq_null, + null_eq_nulls, join_type, None, None, right_join_prefix.as_deref(), keep_join_keys, )?; + self.table_map.insert(right_rel_name, right_rel); } - Ok(left_rel) + Ok(()) } fn plan_relation(&self, rel: &sqlparser::ast::TableFactor) -> SQLPlannerResult { @@ -2105,14 +2151,6 @@ fn idents_to_str(idents: &[Ident]) -> String { .join(".") } -fn get_idents_vec(expr: &sqlparser::ast::Expr) -> SQLPlannerResult>> { - match expr { - sqlparser::ast::Expr::Identifier(ident) => Ok(Cow::Owned(vec![ident.clone()])), - sqlparser::ast::Expr::CompoundIdentifier(idents) => Ok(Cow::Borrowed(idents)), - _ => invalid_operation_err!("expected an identifier"), - } -} - /// unresolves an alias in a projection /// Example: /// ```sql