From 1deeaf9db21a251df3ff607606cdafa3d1999418 Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Mon, 28 Oct 2024 12:34:02 -0500 Subject: [PATCH] [FEAT]: sql cross join (#3110) still todo: - [x] add tests ## Notes for reviewers: This does not actually implement a physical cross join; it implements only the logical cross join and the cross-join-to-inner-join optimization (`eliminate_cross_join.rs`). This treats an inner join with no join conditions as a cross join (inspired by a recent [change in DataFusion](https://github.com/apache/datafusion/pull/12985)). If the cross join cannot be optimized away, an error will be raised when attempting to execute the plan. --- src/common/error/src/error.rs | 8 + src/daft-physical-plan/src/translate.rs | 16 +- src/daft-plan/src/builder.rs | 24 + src/daft-plan/src/logical_ops/project.rs | 9 + .../src/logical_optimization/join_key_set.rs | 156 ++++ src/daft-plan/src/logical_optimization/mod.rs | 1 + .../src/logical_optimization/optimizer.rs | 5 +- .../rules/eliminate_cross_join.rs | 729 ++++++++++++++++++ .../src/logical_optimization/rules/mod.rs | 2 + .../src/physical_planner/translate.rs | 7 +- src/daft-sql/src/planner.rs | 39 +- 11 files changed, 984 insertions(+), 12 deletions(-) create mode 100644 src/daft-plan/src/logical_optimization/join_key_set.rs create mode 100644 src/daft-plan/src/logical_optimization/rules/eliminate_cross_join.rs diff --git a/src/common/error/src/error.rs b/src/common/error/src/error.rs index 31cb71b7ba..7574780721 100644 --- a/src/common/error/src/error.rs +++ b/src/common/error/src/error.rs @@ -46,6 +46,14 @@ pub enum DaftError { FmtError(#[from] std::fmt::Error), #[error("DaftError::RegexError {0}")] RegexError(#[from] regex::Error), + #[error("Not Yet Implemented: {0}")] + NotImplemented(String), +} + +impl DaftError { + pub fn not_implemented(msg: T) -> Self { + Self::NotImplemented(msg.to_string()) + } } impl From for DaftError { diff --git a/src/daft-physical-plan/src/translate.rs 
b/src/daft-physical-plan/src/translate.rs index 7dcb0f552b..8c851e7c39 100644 --- a/src/daft-physical-plan/src/translate.rs +++ b/src/daft-physical-plan/src/translate.rs @@ -1,7 +1,7 @@ -use common_error::DaftResult; +use common_error::{DaftError, DaftResult}; use daft_core::{join::JoinStrategy, prelude::Schema}; use daft_dsl::ExprRef; -use daft_plan::{LogicalPlan, LogicalPlanRef, SourceInfo}; +use daft_plan::{JoinType, LogicalPlan, LogicalPlanRef, SourceInfo}; use crate::local_plan::{LocalPhysicalPlan, LocalPhysicalPlanRef}; @@ -119,8 +119,18 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { )) } LogicalPlan::Join(join) => { + if join.left_on.is_empty() + && join.right_on.is_empty() + && join.join_type == JoinType::Inner + { + return Err(DaftError::not_implemented( + "Joins without join conditions (cross join) are not supported yet", + )); + } if join.join_strategy.is_some_and(|x| x != JoinStrategy::Hash) { - todo!("Only hash join is supported for now") + return Err(DaftError::not_implemented( + "Only hash join is supported for now", + )); } let left = translate(&join.left)?; let right = translate(&join.right)?; diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs index 64f4faf5c5..a8386c25c7 100644 --- a/src/daft-plan/src/builder.rs +++ b/src/daft-plan/src/builder.rs @@ -84,6 +84,13 @@ impl From<&LogicalPlanBuilder> for LogicalPlanRef { value.plan.clone() } } + +impl From for LogicalPlanBuilder { + fn from(plan: LogicalPlanRef) -> Self { + Self::new(plan, None) + } +} + pub trait IntoGlobPath { fn into_glob_path(self) -> Vec; } @@ -468,6 +475,23 @@ impl LogicalPlanBuilder { Ok(self.with_new_plan(logical_plan)) } + pub fn cross_join>( + &self, + right: Right, + join_suffix: Option<&str>, + join_prefix: Option<&str>, + ) -> DaftResult { + self.join( + right, + vec![], + vec![], + JoinType::Inner, + None, + join_suffix, + join_prefix, + ) + } + pub fn concat(&self, other: &Self) -> DaftResult { let logical_plan: LogicalPlan = 
logical_ops::Concat::try_new(self.plan.clone(), other.plan.clone())?.into(); diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs index 78de22bea6..d51a7b51a8 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -37,6 +37,15 @@ impl Project { projected_schema, }) } + /// Create a new Projection using the specified output schema + pub(crate) fn new_from_schema(input: Arc, schema: SchemaRef) -> Result { + let expr: Vec = schema + .names() + .into_iter() + .map(|n| Arc::new(Expr::Column(Arc::from(n)))) + .collect(); + Self::try_new(input, expr) + } pub fn multiline_display(&self) -> Vec { vec![format!( diff --git a/src/daft-plan/src/logical_optimization/join_key_set.rs b/src/daft-plan/src/logical_optimization/join_key_set.rs new file mode 100644 index 0000000000..a0fd79fbd5 --- /dev/null +++ b/src/daft-plan/src/logical_optimization/join_key_set.rs @@ -0,0 +1,156 @@ +// Borrowed from DataFusion project: datafusion/optimizer/src/join_key_set.rs + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [JoinKeySet] for tracking the set of join keys in a plan. 
+ +use std::sync::Arc; + +use daft_dsl::{Expr, ExprRef}; +use indexmap::{Equivalent, IndexSet}; + +/// Tracks a set of equality Join keys +/// +/// A join key is an expression that is used to join two tables via an equality +/// predicate such as `a.x = b.y` +/// +/// This struct models `a.x + 5 = b.y AND a.z = b.z` as two join keys +/// 1. `(a.x + 5, b.y)` +/// 2. `(a.z, b.z)` +/// +/// # Important properties: +/// +/// 1. Retains insert order +/// 2. Can quickly look up if a pair of expressions are in the set. +#[derive(Debug)] +pub struct JoinKeySet { + inner: IndexSet<(ExprRef, ExprRef)>, +} + +impl JoinKeySet { + /// Create a new empty set + pub fn new() -> Self { + Self { + inner: IndexSet::new(), + } + } + + /// Return true if the set contains a join pair + /// where left = right or right = left + pub fn contains(&self, left: &Expr, right: &Expr) -> bool { + self.inner.contains(&ExprPair::new(left, right)) + || self.inner.contains(&ExprPair::new(right, left)) + } + + /// Insert the join key `(left = right)` into the set if join pair `(right = + /// left)` is not already in the set + /// + /// returns true if the pair was inserted + pub fn insert(&mut self, left: &Expr, right: &Expr) -> bool { + if self.contains(left, right) { + false + } else { + self.inner + .insert((left.clone().arced(), right.clone().arced())); + true + } + } + + /// Same as [`Self::insert`] but avoids cloning expression if they + /// are owned + pub fn insert_owned(&mut self, left: Expr, right: Expr) -> bool { + if self.contains(&left, &right) { + false + } else { + self.inner.insert((Arc::new(left), Arc::new(right))); + true + } + } + + /// Inserts potentially many join keys into the set, copying only when necessary + /// + /// returns true if any of the pairs were inserted + pub fn insert_all<'a>( + &mut self, + iter: impl IntoIterator, + ) -> bool { + let mut inserted = false; + for (left, right) in iter { + inserted |= self.insert(left, right); + } + inserted + } + + /// Same as 
[`Self::insert_all`] but avoids cloning expressions if they are + /// already owned + /// + /// returns true if any of the pairs were inserted + pub fn insert_all_owned(&mut self, iter: impl IntoIterator) -> bool { + let mut inserted = false; + for (left, right) in iter { + inserted |= self.insert_owned(Arc::unwrap_or_clone(left), Arc::unwrap_or_clone(right)); + } + inserted + } + + /// Inserts any join keys that are common to both `s1` and `s2` into self + pub fn insert_intersection(&mut self, s1: &Self, s2: &Self) { + // note can't use inner.intersection as we need to consider both (l, r) + // and (r, l) in equality + for (left, right) in &s1.inner { + if s2.contains(left.as_ref(), right.as_ref()) { + self.insert(left.as_ref(), right.as_ref()); + } + } + } + + /// returns true if this set is empty + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + /// Return the length of this set + #[cfg(test)] + pub fn len(&self) -> usize { + self.inner.len() + } + + /// Return an iterator over the join keys in this set + pub fn iter(&self) -> impl Iterator { + self.inner.iter().map(|(l, r)| (l, r)) + } +} + +/// Custom comparison operation to avoid copying owned values +/// +/// This behaves like a `(Expr, Expr)` tuple for hashing and comparison, but +/// avoids copying the values simply to comparing them. 
+#[derive(Debug, Eq, PartialEq, Hash)] +struct ExprPair<'a>(&'a Expr, &'a Expr); + +impl<'a> ExprPair<'a> { + fn new(left: &'a Expr, right: &'a Expr) -> Self { + Self(left, right) + } +} + +impl<'a> Equivalent<(ExprRef, ExprRef)> for ExprPair<'a> { + fn equivalent(&self, other: &(ExprRef, ExprRef)) -> bool { + self.0 == other.0.as_ref() && self.1 == other.1.as_ref() + } +} diff --git a/src/daft-plan/src/logical_optimization/mod.rs b/src/daft-plan/src/logical_optimization/mod.rs index 37bd2306cd..37806b54b3 100644 --- a/src/daft-plan/src/logical_optimization/mod.rs +++ b/src/daft-plan/src/logical_optimization/mod.rs @@ -1,3 +1,4 @@ +pub(crate) mod join_key_set; mod logical_plan_tracker; mod optimizer; mod rules; diff --git a/src/daft-plan/src/logical_optimization/optimizer.rs b/src/daft-plan/src/logical_optimization/optimizer.rs index a53d5980da..63e6259836 100644 --- a/src/daft-plan/src/logical_optimization/optimizer.rs +++ b/src/daft-plan/src/logical_optimization/optimizer.rs @@ -6,8 +6,8 @@ use common_treenode::Transformed; use super::{ logical_plan_tracker::LogicalPlanTracker, rules::{ - DropRepartition, OptimizerRule, PushDownFilter, PushDownLimit, PushDownProjection, - SplitActorPoolProjects, + DropRepartition, EliminateCrossJoin, OptimizerRule, PushDownFilter, PushDownLimit, + PushDownProjection, SplitActorPoolProjects, }, }; use crate::LogicalPlan; @@ -112,6 +112,7 @@ impl Optimizer { Box::new(DropRepartition::new()), Box::new(PushDownFilter::new()), Box::new(PushDownProjection::new()), + Box::new(EliminateCrossJoin::new()), ], // Use a fixed-point policy for the pushdown rules: PushDownProjection can produce a Filter node // at the current node, which would require another batch application in order to have a chance to push diff --git a/src/daft-plan/src/logical_optimization/rules/eliminate_cross_join.rs b/src/daft-plan/src/logical_optimization/rules/eliminate_cross_join.rs new file mode 100644 index 0000000000..92d6e30cb1 --- /dev/null +++ 
b/src/daft-plan/src/logical_optimization/rules/eliminate_cross_join.rs @@ -0,0 +1,729 @@ +/// Heavily inspired by DataFusion's EliminateCrossJoin rule: https://github.com/apache/datafusion/blob/b978cf8236436038a106ed94fb0d7eaa6ba99962/datafusion/optimizer/src/eliminate_cross_join.rs +use std::sync::Arc; + +use common_error::DaftResult; +use common_treenode::{Transformed, TreeNode}; +use daft_core::{ + join::JoinType, + prelude::{Schema, SchemaRef, TimeUnit}, +}; +use daft_dsl::{optimization::get_required_columns, Expr, ExprRef, Operator}; +use daft_schema::dtype::DataType; + +use super::OptimizerRule; +use crate::{ + logical_ops::{Filter, Join, Project}, + logical_optimization::join_key_set::JoinKeySet, + LogicalPlan, LogicalPlanRef, +}; + +#[derive(Default, Debug)] +pub struct EliminateCrossJoin {} + +impl EliminateCrossJoin { + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for EliminateCrossJoin { + fn try_optimize(&self, plan: Arc) -> DaftResult>> { + let schema = plan.schema(); + let mut possible_join_keys = JoinKeySet::new(); + let mut all_inputs: Vec> = vec![]; + let plan = Arc::unwrap_or_clone(plan); + + let parent_predicate = if let LogicalPlan::Filter(filter) = plan { + // if input isn't a join that can potentially be rewritten + // avoid unwrapping the input + let rewriteable = matches!( + filter.input.as_ref(), + LogicalPlan::Join(Join { + join_type: JoinType::Inner, + join_strategy: None, + .. 
+ }) + ); + if !rewriteable { + return rewrite_children(self, Arc::new(LogicalPlan::Filter(filter))); + } + if !can_flatten_join_inputs(filter.input.as_ref()) { + return Ok(Transformed::no(Arc::new(LogicalPlan::Filter(filter)))); + } + let Filter { input, predicate } = filter; + flatten_join_inputs( + Arc::unwrap_or_clone(input), + &mut possible_join_keys, + &mut all_inputs, + )?; + extract_possible_join_keys(&predicate, &mut possible_join_keys); + Some(predicate) + } else if matches!( + plan, + LogicalPlan::Join(Join { + join_type: JoinType::Inner, + join_strategy: None, + .. + }) + ) { + if !can_flatten_join_inputs(&plan) { + return Ok(Transformed::no(plan.arced())); + } + flatten_join_inputs(plan, &mut possible_join_keys, &mut all_inputs)?; + None + } else { + // recursively try to rewrite children + return rewrite_children(self, Arc::new(plan)); + }; + // Join keys are handled locally: + let mut all_join_keys = JoinKeySet::new(); + let mut left = all_inputs.remove(0); + while !all_inputs.is_empty() { + left = find_inner_join( + left, + &mut all_inputs, + &possible_join_keys, + &mut all_join_keys, + )?; + } + left = rewrite_children(self, left)?.data; + if schema != left.schema() { + let project = Project::new_from_schema(left, schema)?; + + left = Arc::new(LogicalPlan::Project(project)); + } + let Some(predicate) = parent_predicate else { + return Ok(Transformed::yes(left)); + }; + + // If there are no join keys then do nothing: + if all_join_keys.is_empty() { + let f = Filter::try_new(left, predicate)?; + + Ok(Transformed::yes(Arc::new(LogicalPlan::Filter(f)))) + } else { + // Remove join expressions from filter: + match remove_join_expressions(predicate, &all_join_keys) { + Some(filter_expr) => { + let f = Filter::try_new(left, Arc::new(filter_expr))?; + + Ok(Transformed::yes(Arc::new(LogicalPlan::Filter(f)))) + } + _ => Ok(Transformed::yes(left)), + } + } + } +} + +fn rewrite_children( + optimizer: &impl OptimizerRule, + plan: Arc, +) -> DaftResult>> { + 
plan.map_children(|input| optimizer.try_optimize(input)) +} + +fn flatten_join_inputs( + plan: LogicalPlan, + possible_join_keys: &mut JoinKeySet, + all_inputs: &mut Vec, +) -> DaftResult<()> { + if let LogicalPlan::Join( + join @ Join { + join_type: JoinType::Inner, + join_strategy: None, + .. + }, + ) = plan + { + let keys = join.left_on.into_iter().zip(join.right_on); + possible_join_keys.insert_all_owned(keys); + flatten_join_inputs( + Arc::unwrap_or_clone(join.left), + possible_join_keys, + all_inputs, + )?; + flatten_join_inputs( + Arc::unwrap_or_clone(join.right), + possible_join_keys, + all_inputs, + )?; + } else { + all_inputs.push(Arc::new(plan)); + } + + Ok(()) +} + +/// Returns true if the plan is an inner join (or cross join) that can be flattened with +/// `flatten_join_inputs` +/// +/// Must stay in sync with `flatten_join_inputs` +fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool { + // can only flatten inner / cross joins + match plan { + LogicalPlan::Join(join) if join.join_type == JoinType::Inner => {} + _ => return false, + }; + + for child in plan.children() { + if matches!( + child, + LogicalPlan::Join(Join { + join_strategy: None, + join_type: JoinType::Inner, + .. + }) + ) && !can_flatten_join_inputs(child) + { + return false; + } + } + true +} + +/// Extract join keys from a WHERE clause +fn extract_possible_join_keys(expr: &Expr, join_keys: &mut JoinKeySet) { + if let Expr::BinaryOp { left, op, right } = expr { + match op { + Operator::Eq => { + // insert handles ensuring we don't add the same Join keys multiple times + join_keys.insert(left, right); + } + Operator::And => { + extract_possible_join_keys(left, join_keys); + extract_possible_join_keys(right, join_keys); + } + // Join predicates common to both branches of an OR expr are pulled up as well. 
+ Operator::Or => { + let mut left_join_keys = JoinKeySet::new(); + let mut right_join_keys = JoinKeySet::new(); + + extract_possible_join_keys(left, &mut left_join_keys); + extract_possible_join_keys(right, &mut right_join_keys); + + join_keys.insert_intersection(&left_join_keys, &right_join_keys); + } + _ => (), + }; + } +} + +/// Remove join expressions from a filter expression +/// +/// # Returns +/// * `Some()` when there are few remaining predicates in filter_expr +/// * `None` otherwise +fn remove_join_expressions(expr: ExprRef, join_keys: &JoinKeySet) -> Option { + match Arc::unwrap_or_clone(expr) { + Expr::BinaryOp { + left, + op: Operator::Eq, + right, + } if join_keys.contains(&left, &right) => { + // was a join key, so remove it + None + } + // Fix for join predicates from inside of OR expr also pulled up properly. + Expr::BinaryOp { left, op, right } if op == Operator::And => { + let l = remove_join_expressions(left, join_keys); + let r = remove_join_expressions(right, join_keys); + match (l, r) { + (Some(ll), Some(rr)) => Some(Expr::BinaryOp { + left: Arc::new(ll), + op, + right: Arc::new(rr), + }), + (Some(ll), _) => Some(ll), + (_, Some(rr)) => Some(rr), + _ => None, + } + } + Expr::BinaryOp { left, op, right } if op == Operator::Or => { + let l = remove_join_expressions(left, join_keys); + let r = remove_join_expressions(right, join_keys); + match (l, r) { + (Some(ll), Some(rr)) => Some(Expr::BinaryOp { + left: Arc::new(ll), + op, + right: Arc::new(rr), + }), + // When either `left` or `right` is empty, it means they are `true` + // so OR'ing anything with them will also be true + _ => None, + } + } + other => Some(other), + } +} + +/// Finds the next to join with the left input plan, +/// +/// Finds the next `right` from `rights` that can be joined with `left_input` +/// plan based on the join keys in `possible_join_keys`. +/// +/// If such a matching `right` is found: +/// 1. Adds the matching join keys to `all_join_keys`. +/// 2. 
Returns `left_input JOIN right ON (all join keys)`. +/// +/// If no matching `right` is found: +/// 1. Removes the first plan from `rights` +/// 2. Returns `left_input CROSS JOIN right`. +fn find_inner_join( + left_input: LogicalPlanRef, + rights: &mut Vec, + possible_join_keys: &JoinKeySet, + all_join_keys: &mut JoinKeySet, +) -> DaftResult { + for (i, right_input) in rights.iter().enumerate() { + let mut join_keys = vec![]; + + for (l, r) in possible_join_keys.iter() { + let key_pair = find_valid_equijoin_key_pair( + l.clone(), + r.clone(), + left_input.schema(), + right_input.schema(), + )?; + + // Save join keys + if let Some((valid_l, valid_r)) = key_pair { + if can_hash(&valid_l.get_type(left_input.schema().as_ref())?) { + join_keys.push((valid_l, valid_r)); + } + } + } + + // Found one or more matching join keys + if !join_keys.is_empty() { + all_join_keys.insert_all(join_keys.iter()); + let right_input = rights.remove(i); + let join_schema = left_input + .schema() + .non_distinct_union(right_input.schema().as_ref()); + + let (left_keys, right_keys) = join_keys.iter().cloned().unzip(); + return Ok(LogicalPlan::Join(Join { + left: left_input, + right: right_input, + left_on: left_keys, + right_on: right_keys, + join_type: JoinType::Inner, + join_strategy: None, + output_schema: Arc::new(join_schema), + }) + .arced()); + } + } + + // no matching right plan had any join keys, cross join with the first right + // plan + let right = rights.remove(0); + let join_schema = left_input + .schema() + .non_distinct_union(right.schema().as_ref()); + + Ok(LogicalPlan::Join(Join { + left: left_input, + right, + left_on: vec![], + right_on: vec![], + join_type: JoinType::Inner, + join_strategy: None, + output_schema: Arc::new(join_schema), + }) + .arced()) +} + +/// Check whether all columns are from the schema. 
+pub fn check_all_columns_from_schema(columns: &[String], schema: &Schema) -> DaftResult { + for col in columns { + let exist = schema.get_index(col).is_ok(); + + if !exist { + return Ok(false); + } + } + + Ok(true) +} + +/// Given the two sides of an equijoin predicate, return a valid join key pair. +/// If there is no valid join key pair, return None. +/// +/// A valid join means: +/// 1. All referenced columns of the left side are from the left schema, and +/// all referenced columns of the right side are from the right schema. +/// 2. Or the opposite: all referenced columns of the left side are from the right schema, +/// and those of the right side are from the left schema. +/// +pub fn find_valid_equijoin_key_pair( + left_key: ExprRef, + right_key: ExprRef, + left_schema: SchemaRef, + right_schema: SchemaRef, +) -> DaftResult> { + let left_using_columns = get_required_columns(&left_key); + let right_using_columns = get_required_columns(&right_key); + + // Conditions like `a = 10` reference no columns on one side, so they are not equijoin keys. + if left_using_columns.is_empty() || right_using_columns.is_empty() { + return Ok(None); + } + + if check_all_columns_from_schema(&left_using_columns, &left_schema)? + && check_all_columns_from_schema(&right_using_columns, &right_schema)? + { + return Ok(Some((left_key, right_key))); + } else if check_all_columns_from_schema(&right_using_columns, &left_schema)? + && check_all_columns_from_schema(&left_using_columns, &right_schema)? + { + return Ok(Some((right_key, left_key))); + } + + Ok(None) +} + +/// Can this data type be used in hash join equality conditions? +/// Data types here come from the function `equal_rows`; if more data types are supported +/// in `equal_rows` (hash join), add them here to generate the join logical plan. 
+pub fn can_hash(data_type: &DataType) -> bool { + match data_type { + DataType::Null => true, + DataType::Boolean => true, + DataType::Int8 => true, + DataType::Int16 => true, + DataType::Int32 => true, + DataType::Int64 => true, + DataType::UInt8 => true, + DataType::UInt16 => true, + DataType::UInt32 => true, + DataType::UInt64 => true, + DataType::Float32 => true, + DataType::Float64 => true, + DataType::Timestamp(time_unit, _) => match time_unit { + TimeUnit::Seconds => true, + TimeUnit::Milliseconds => true, + TimeUnit::Microseconds => true, + TimeUnit::Nanoseconds => true, + }, + DataType::Utf8 => true, + + DataType::Decimal128(_, _) => true, + DataType::Date => true, + + DataType::FixedSizeBinary(_) => true, + + DataType::List(_) => true, + + DataType::FixedSizeList(_, _) => true, + DataType::Struct(fields) => fields.iter().all(|f| can_hash(&f.dtype)), + _ => false, + } +} +#[cfg(test)] +mod tests { + use common_display::mermaid::{MermaidDisplay, MermaidDisplayOptions}; + use daft_dsl::{col, lit}; + use daft_schema::field::Field; + use rstest::*; + + use super::*; + use crate::{ + logical_plan::Source, source_info::PlaceHolderInfo, ClusteringSpec, LogicalPlan, + LogicalPlanBuilder, LogicalPlanRef, SourceInfo, + }; + + #[fixture] + fn t1() -> LogicalPlanRef { + let schema = Arc::new( + Schema::new(vec![ + Field::new("a", DataType::UInt32), + Field::new("b", DataType::UInt32), + Field::new("c", DataType::UInt32), + ]) + .unwrap(), + ); + LogicalPlan::Source(Source { + output_schema: schema.clone(), + source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { + source_schema: schema, + clustering_spec: Arc::new(ClusteringSpec::unknown()), + source_id: 0, + })), + }) + .arced() + } + + #[fixture] + fn t2() -> LogicalPlanRef { + let schema = Arc::new( + Schema::new(vec![ + Field::new("a", DataType::UInt32), + Field::new("b", DataType::UInt32), + Field::new("c", DataType::UInt32), + ]) + .unwrap(), + ); + LogicalPlan::Source(Source { + output_schema: 
schema.clone(), + source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { + source_schema: schema, + clustering_spec: Arc::new(ClusteringSpec::unknown()), + source_id: 0, + })), + }) + .arced() + } + + fn assert_optimized_plan_eq(plan: LogicalPlanRef, expected: LogicalPlanRef) { + let starting_schema = plan.schema(); + + let rule = EliminateCrossJoin::new(); + let transformed_plan = rule.try_optimize(plan).unwrap(); + assert!(transformed_plan.transformed, "failed to optimize plan"); + let actual = transformed_plan.data; + + if actual != expected { + println!( + "expected:\n{}\nactual:\n{}", + expected.repr_mermaid(MermaidDisplayOptions::default()), + actual.repr_mermaid(MermaidDisplayOptions::default()) + ); + } + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + assert_eq!(starting_schema, actual.schema()) + } + + #[rstest] + fn eliminate_cross_with_simple_and(t1: LogicalPlanRef, t2: LogicalPlanRef) -> DaftResult<()> { + // could eliminate to inner join since filter has Join predicates + let plan = LogicalPlanBuilder::from(t1.clone()) + .cross_join(t2.clone(), None, None)? + .filter(col("a").eq(col("right.a")).and(col("b").eq(col("right.b"))))? + .build(); + + let expected = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).select(vec![ + col("a").alias("right.a"), + col("b").alias("right.b"), + col("c").alias("right.c"), + ])?, + vec![col("a"), col("b")], + vec![col("right.a"), col("right.b")], + JoinType::Inner, + None, + None, + None, + )? + .build(); + + assert_optimized_plan_eq(plan, expected); + + Ok(()) + } + + #[rstest] + fn eliminate_cross_with_simple_or(t1: LogicalPlanRef, t2: LogicalPlanRef) -> DaftResult<()> { + // could not eliminate to inner join since filter OR expression and there is no common + // Join predicates in left and right of OR expr. + let plan = LogicalPlanBuilder::from(t1.clone()) + .cross_join(t2.clone(), None, None)? 
+ .filter(col("a").eq(col("right.a")).or(col("right.b").eq(col("a"))))? + .build(); + + let expected = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).select(vec![ + col("a").alias("right.a"), + col("b").alias("right.b"), + col("c").alias("right.c"), + ])?, + vec![], + vec![], + JoinType::Inner, + None, + None, + None, + )? + .filter(col("a").eq(col("right.a")).or(col("right.b").eq(col("a"))))? + .build(); + + assert_optimized_plan_eq(plan, expected); + + Ok(()) + } + + #[rstest] + fn eliminate_cross_with_and(t1: LogicalPlanRef, t2: LogicalPlanRef) -> DaftResult<()> { + let expr1 = col("a").eq(col("right.a")); + let expr2 = col("right.c").lt(lit(20u32)); + let expr3 = col("a").eq(col("right.a")); + let expr4 = col("right.c").eq(lit(10u32)); + // could eliminate to inner join + let plan = LogicalPlanBuilder::from(t1.clone()) + .cross_join(t2.clone(), None, None)? + .filter(expr1.and(expr2.clone()).and(expr3).and(expr4.clone()))? + .build(); + + let expected = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).select(vec![ + col("a").alias("right.a"), + col("b").alias("right.b"), + col("c").alias("right.c"), + ])?, + vec![col("a")], + vec![col("right.a")], + JoinType::Inner, + None, + None, + None, + )? + .filter(expr2.and(expr4))? + .build(); + + assert_optimized_plan_eq(plan, expected); + + Ok(()) + } + + #[rstest] + fn eliminate_cross_with_or(t1: LogicalPlanRef, t2: LogicalPlanRef) -> DaftResult<()> { + // could eliminate to inner join since Or predicates have common Join predicates + let expr1 = col("a").eq(col("right.a")); + let expr2 = col("right.c").lt(lit(15u32)); + let expr3 = col("a").eq(col("right.a")); + let expr4 = col("right.c").eq(lit(688u32)); + let plan = LogicalPlanBuilder::from(t1.clone()) + .cross_join(t2.clone(), None, None)? + .filter(expr1.and(expr2.clone()).or(expr3.and(expr4.clone())))? 
+ .build(); + + let expected = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).select(vec![ + col("a").alias("right.a"), + col("b").alias("right.b"), + col("c").alias("right.c"), + ])?, + vec![col("a")], + vec![col("right.a")], + JoinType::Inner, + None, + None, + None, + )? + .filter(expr2.or(expr4))? + .build(); + + assert_optimized_plan_eq(plan, expected); + + Ok(()) + } + + #[rstest] + fn eliminate_cross_join_multi_tables( + t1: LogicalPlanRef, + t2: LogicalPlanRef, + #[from(t1)] t3: LogicalPlanRef, + #[from(t1)] t4: LogicalPlanRef, + ) -> DaftResult<()> { + // could eliminate to inner join + let plan1 = LogicalPlanBuilder::from(t1.clone()) + .cross_join(t2.clone(), None, Some("t2."))? + .filter( + col("a") + .eq(col("t2.a")) + .and(col("t2.c").lt(lit(15u32))) + .or(col("a").eq(col("t2.a")).and(col("t2.c").eq(lit(688u32)))), + )? + .build(); + + let plan2 = LogicalPlanBuilder::from(t3.clone()) + .cross_join(t4.clone(), None, Some("t4."))? + .filter( + (col("a") + .eq(col("t4.a")) + .and(col("t4.c").lt(lit(15u32))) + .or(col("a").eq(col("t4.a")).and(col("c").eq(lit(688u32))))) + .or(col("a").eq(col("t4.a")).and(col("b").eq(col("t4.b")))), + )? + .build(); + + let plan = LogicalPlanBuilder::from(plan1.clone()) + .cross_join(plan2.clone(), None, Some("t3."))? + .filter( + col("t3.a") + .eq(col("a")) + .and(col("t4.c").lt(lit(15u32))) + .or(col("t3.a").eq(col("a")).and(col("t4.c").eq(lit(688u32)))), + )? + .build(); + let plan_1 = LogicalPlanBuilder::from(t1) + .join( + LogicalPlanBuilder::from(t2).select(vec![ + col("a").alias("t2.a"), + col("b").alias("t2.b"), + col("c").alias("t2.c"), + ])?, + vec![col("a")], + vec![col("t2.a")], + JoinType::Inner, + None, + None, + None, + )? + .filter(col("t2.c").lt(lit(15u32)).or(col("t2.c").eq(lit(688u32))))? 
+ .build(); + + let plan_2 = LogicalPlanBuilder::from(t3) + .join( + LogicalPlanBuilder::from(t4).select(vec![ + col("a").alias("t4.a"), + col("b").alias("t4.b"), + col("c").alias("t4.c"), + ])?, + vec![col("a")], + vec![col("t4.a")], + JoinType::Inner, + None, + None, + None, + )? + .filter( + col("t4.c") + .lt(lit(15u32)) + .or(col("c").eq(lit(688u32))) + .or(col("b").eq(col("t4.b"))), + )? + .select(vec![ + col("a").alias("t3.a"), + col("b").alias("t3.b"), + col("c").alias("t3.c"), + col("t4.a"), + col("t4.b"), + col("t4.c"), + ])? + .build(); + let expected = LogicalPlanBuilder::from(plan_1) + .join( + plan_2, + vec![col("a")], + vec![col("t3.a")], + JoinType::Inner, + None, + None, + None, + )? + .filter(col("t4.c").lt(lit(15u32)).or(col("t4.c").eq(lit(688u32))))? + .build(); + + assert_optimized_plan_eq(plan, expected); + + Ok(()) + } +} diff --git a/src/daft-plan/src/logical_optimization/rules/mod.rs b/src/daft-plan/src/logical_optimization/rules/mod.rs index ac8579123a..06f6382ea8 100644 --- a/src/daft-plan/src/logical_optimization/rules/mod.rs +++ b/src/daft-plan/src/logical_optimization/rules/mod.rs @@ -1,4 +1,5 @@ mod drop_repartition; +mod eliminate_cross_join; mod push_down_filter; mod push_down_limit; mod push_down_projection; @@ -6,6 +7,7 @@ mod rule; mod split_actor_pool_projects; pub use drop_repartition::DropRepartition; +pub use eliminate_cross_join::EliminateCrossJoin; pub use push_down_filter::PushDownFilter; pub use push_down_limit::PushDownLimit; pub use push_down_projection::PushDownProjection; diff --git a/src/daft-plan/src/physical_planner/translate.rs b/src/daft-plan/src/physical_planner/translate.rs index 46f2b3a610..85a14532c2 100644 --- a/src/daft-plan/src/physical_planner/translate.rs +++ b/src/daft-plan/src/physical_planner/translate.rs @@ -5,7 +5,7 @@ use std::{ }; use common_daft_config::DaftExecutionConfig; -use common_error::DaftResult; +use common_error::{DaftError, DaftResult}; use common_file_formats::FileFormat; use 
daft_core::prelude::*; use daft_dsl::{ @@ -428,6 +428,11 @@ pub(super) fn translate_single_logical_node( join_strategy, .. }) => { + if left_on.is_empty() && right_on.is_empty() && join_type == &JoinType::Inner { + return Err(DaftError::not_implemented( + "Joins without join conditions (cross join) are not supported yet", + )); + } let mut right_physical = physical_children.pop().expect("requires 1 inputs"); let mut left_physical = physical_children.pop().expect("requires 2 inputs"); diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 8eb2f40fb0..2f47b5008b 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -26,6 +26,7 @@ use crate::{ error::{PlannerError, SQLPlannerResult}, invalid_operation_err, table_not_found_err, unsupported_sql_err, }; + /// A named logical plan /// This is used to keep track of the table name associated with a logical plan while planning a SQL query #[derive(Debug, Clone)] @@ -145,11 +146,8 @@ impl SQLPlanner { // FROM/JOIN let from = selection.clone().from; - if from.len() != 1 { - unsupported_sql_err!("Only exactly one table is supported"); - } - - self.current_relation = Some(self.plan_from(&from[0])?); + let rel = self.plan_from(&from)?; + self.current_relation = Some(rel); // WHERE if let Some(selection) = &selection.selection { @@ -293,7 +291,34 @@ impl SQLPlanner { Ok((exprs, desc)) } - fn plan_from(&mut self, from: &TableWithJoins) -> SQLPlannerResult { + fn plan_from(&mut self, from: &[TableWithJoins]) -> SQLPlannerResult { + if from.len() > 1 { + // todo!("cross join") + let mut from_iter = from.iter(); + + let first = from_iter.next().unwrap(); + let mut rel = self.plan_relation(&first.relation)?; + self.table_map.insert(rel.get_name(), rel.clone()); + for tbl in from_iter { + let right = self.plan_relation(&tbl.relation)?; + self.table_map.insert(right.get_name(), right.clone()); + let right_join_prefix = Some(format!("{}.", right.get_name())); + + rel.inner = rel.inner.join( + 
right.inner, + vec![], + vec![], + JoinType::Inner, + None, + None, + right_join_prefix.as_deref(), + )?; + } + return Ok(rel); + } + + let from = from.iter().next().unwrap(); + fn collect_compound_identifiers( left: &[Ident], right: &[Ident], @@ -494,6 +519,7 @@ impl SQLPlanner { let root = idents.next().unwrap(); let root = ident_to_str(root); + let current_relation = match self.table_map.get(&root) { Some(rel) => rel, None => { @@ -518,6 +544,7 @@ impl SQLPlanner { // If duplicate columns are present in the schema, it adds the table name as a prefix. (df.column_name) // So we first check if the prefixed column name is present in the schema. let current_schema = self.relation_opt().unwrap().inner.schema(); + let f = current_schema.get_field(&ident_str).ok(); if let Some(field) = f { Ok(vec![col(field.name.clone())])