From 523827938e5494ec6c81193b0024284eee8afa20 Mon Sep 17 00:00:00 2001
From: Kev Wang
Date: Wed, 11 Dec 2024 14:01:08 -0800
Subject: [PATCH] feat: support for basic subquery execution (#3536)

Subquery execution is now possible through rewrite rules that convert
subqueries into joins. This only covers subqueries that can be converted
into equi-joins, not general subqueries. However, that already gets us to
21/22 TPC-H queries (although Q16 is still missing a count-distinct
implementation on the SQL side).

Also includes a drive-by fix for SQL substring, which is one-indexed,
unlike Daft's zero-indexed substr.
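As an illustration of the rewrite (mirroring the doc comments on the new
rules in `unnest_subquery.rs` below), an uncorrelated scalar subquery such as

```sql
SELECT val
FROM tbl1
WHERE key = (SELECT max(key) FROM tbl2)
```

is turned into a join against the subquery's aliased single-column output:

```sql
SELECT val
FROM tbl1
CROSS JOIN (SELECT max(key) FROM tbl2) AS subquery
WHERE key = subquery.key
```

Correlated scalar subqueries become left joins on the pulled-up correlated
columns, and IN/EXISTS predicates become semi/anti joins; see the doc
comments on `UnnestScalarSubquery` and `UnnestPredicateSubquery` for the
full before/after examples.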
---

Note on alternative implementations:
- Although Datafusion's optimizer rules for subqueries are more capable
  than these, they rely on non-equi joins (which we do not support yet)
  and even then cannot decorrelate all subqueries. Instead, once we have
  non-equi joins, we should implement the
  [Unnesting Arbitrary Subqueries](https://cs.emis.de/LNI/Proceedings/Proceedings241/383.pdf)
  paper (which
  [duckdb also implements](https://duckdb.org/2023/05/26/correlated-subqueries-in-sql.html#performance)).

---

todo:
- [x] add more tests
---
 src/daft-dsl/src/expr/mod.rs                  |  34 +-
 src/daft-logical-plan/src/logical_plan.rs     |  33 +-
 .../src/optimization/optimizer.rs             |   8 +-
 .../optimization/rules/drop_repartition.rs    |  13 +-
 .../rules/lift_project_from_agg.rs            |  10 +-
 .../src/optimization/rules/mod.rs             |   2 +
 .../optimization/rules/push_down_filter.rs    |  15 +-
 .../src/optimization/rules/push_down_limit.rs |  15 +-
 .../rules/push_down_projection.rs             |  11 +-
 .../rules/split_actor_pool_projects.rs        |  22 +-
 .../src/optimization/rules/unnest_subquery.rs | 780 ++++++++++++++++++
 .../src/optimization/test/mod.rs              |  13 +-
 src/daft-sql/src/planner.rs                   |   3 +
 tests/benchmarks/test_local_tpch.py           |  46 ++
 tests/sql/test_utf8_exprs.py                  |   2 +-
 15 files changed, 968 insertions(+), 39 deletions(-)
 create mode 100644 src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs

diff --git a/src/daft-dsl/src/expr/mod.rs b/src/daft-dsl/src/expr/mod.rs
index de41b34cfa..bccec09f32 100644
--- a/src/daft-dsl/src/expr/mod.rs
+++ b/src/daft-dsl/src/expr/mod.rs
@@ -2,6 +2,8 @@
 mod tests;
 
 use std::{
+    any::Any,
+    hash::{DefaultHasher, Hash, Hasher},
     io::{self, Write},
     sync::Arc,
 };
@@ -38,8 +40,11 @@ use crate::{
 pub trait SubqueryPlan: std::fmt::Debug + std::fmt::Display + Send + Sync {
     fn as_any(&self) -> &dyn std::any::Any;
+    fn as_any_arc(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
     fn name(&self) -> &'static str;
     fn schema(&self) -> SchemaRef;
+    fn dyn_eq(&self, other: &dyn SubqueryPlan) -> bool;
+    fn dyn_hash(&self, state: &mut dyn Hasher);
 }
 
 #[derive(Display, Debug, Clone)]
@@ -60,6 +65,14 @@ impl Subquery {
     pub fn name(&self) -> &'static str {
         self.plan.name()
     }
+
+    pub fn semantic_id(&self) -> FieldID {
+        let mut s = DefaultHasher::new();
+        self.hash(&mut s);
+        let hash = s.finish();
+
+        FieldID::new(format!("subquery({}-{})", self.name(), hash))
+    }
 }
 
 impl Serialize for Subquery {
@@ -76,7 +89,7 @@ impl<'de> Deserialize<'de> for Subquery {
 
 impl PartialEq for Subquery {
     fn eq(&self, other: &Self) -> bool {
-        self.plan.name() == other.plan.name() && self.plan.schema() == other.plan.schema()
+        self.plan.dyn_eq(other.plan.as_ref())
     }
 }
@@ -84,8 +97,7 @@ impl Eq for Subquery {}
 
 impl std::hash::Hash for Subquery {
     fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
-        self.plan.name().hash(state);
-        self.plan.schema().hash(state);
+        self.plan.dyn_hash(state);
     }
 }
@@ -177,7 +189,7 @@ pub struct OuterReferenceColumn {
 
 impl Display for OuterReferenceColumn {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "outer_col({}, {})", self.field.name, self.depth)
+        write!(f, "outer_col({}, depth={})", self.field.name, self.depth)
     }
 }
@@ -744,10 +756,18 @@ impl Expr {
             // Agg: Separate path.
             Self::Agg(agg_expr) => agg_expr.semantic_id(schema),
             Self::ScalarFunction(sf) => scalar_function_semantic_id(sf, schema),
+            Self::Subquery(subquery) => subquery.semantic_id(),
+            Self::InSubquery(expr, subquery) => {
+                let child_id = expr.semantic_id(schema);
+                let subquery_id = subquery.semantic_id();
 
-            Self::Subquery(..) | Self::InSubquery(..) | Self::Exists(..) => {
-                FieldID::new("__subquery__")
-            } // todo: better/unique id
+                FieldID::new(format!("({child_id} IN {subquery_id})"))
+            }
+            Self::Exists(subquery) => {
+                let subquery_id = subquery.semantic_id();
+
+                FieldID::new(format!("(EXISTS {subquery_id})"))
+            }
             Self::OuterReferenceColumn(c) => {
                 let name = &c.field.name;
                 let depth = c.depth;
diff --git a/src/daft-logical-plan/src/logical_plan.rs b/src/daft-logical-plan/src/logical_plan.rs
index fc2f065038..673f372a69 100644
--- a/src/daft-logical-plan/src/logical_plan.rs
+++ b/src/daft-logical-plan/src/logical_plan.rs
@@ -1,8 +1,13 @@
-use std::{num::NonZeroUsize, sync::Arc};
+use std::{
+    any::Any,
+    hash::{Hash, Hasher},
+    num::NonZeroUsize,
+    sync::Arc,
+};
 
 use common_display::ascii::AsciiTreeDisplay;
 use common_error::DaftError;
-use daft_dsl::{optimization::get_required_columns, SubqueryPlan};
+use daft_dsl::{optimization::get_required_columns, Subquery, SubqueryPlan};
 use daft_schema::schema::SchemaRef;
 use indexmap::IndexSet;
 use snafu::Snafu;
@@ -396,6 +401,10 @@ impl SubqueryPlan for LogicalPlan {
         self
     }
 
+    fn as_any_arc(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
+        self
+    }
+
     fn name(&self) -> &'static str {
         Self::name(self)
     }
@@ -403,6 +412,26 @@ impl SubqueryPlan for LogicalPlan {
     fn schema(&self) -> SchemaRef {
         Self::schema(self)
     }
+
+    fn dyn_eq(&self, other: &dyn SubqueryPlan) -> bool {
+        other
+            .as_any()
+            .downcast_ref::<Self>()
+            .map_or(false, |other| self == other)
+    }
+
+    fn dyn_hash(&self, mut state: &mut dyn Hasher) {
+        self.hash(&mut state);
+    }
+}
+
+pub(crate) fn downcast_subquery(subquery: &Subquery) -> LogicalPlanRef {
+    subquery
+        .plan
+        .clone()
+        .as_any_arc()
+        .downcast::<LogicalPlan>()
+        .expect("subquery plan should be a LogicalPlan")
 }
 
 #[derive(Debug, Snafu)]
diff --git a/src/daft-logical-plan/src/optimization/optimizer.rs b/src/daft-logical-plan/src/optimization/optimizer.rs
index 257d7de736..76f6251438 100644
--- a/src/daft-logical-plan/src/optimization/optimizer.rs
+++ b/src/daft-logical-plan/src/optimization/optimizer.rs
@@ -8,7 +8,7 @@ use super::{
     rules::{
         DropRepartition, EliminateCrossJoin, EnrichWithStats, LiftProjectFromAgg, MaterializeScans,
         OptimizerRule, PushDownFilter, PushDownLimit, PushDownProjection, SimplifyExpressionsRule,
-        SplitActorPoolProjects,
+        SplitActorPoolProjects, UnnestPredicateSubquery, UnnestScalarSubquery,
     },
 };
 use crate::LogicalPlan;
@@ -93,10 +93,12 @@ impl Optimizer {
             // --- Rewrite rules ---
             RuleBatch::new(
                 vec![
-                    Box::new(SplitActorPoolProjects::new()),
                     Box::new(LiftProjectFromAgg::new()),
+                    Box::new(UnnestScalarSubquery::new()),
+                    Box::new(UnnestPredicateSubquery::new()),
+                    Box::new(SplitActorPoolProjects::new()),
                 ],
-                RuleExecutionStrategy::Once,
+                RuleExecutionStrategy::FixedPoint(None),
             ),
             // we want to simplify expressions first to make the rest of the rules easier
             RuleBatch::new(
diff --git a/src/daft-logical-plan/src/optimization/rules/drop_repartition.rs b/src/daft-logical-plan/src/optimization/rules/drop_repartition.rs
index 7a96a87fb3..5e868617d4 100644
--- a/src/daft-logical-plan/src/optimization/rules/drop_repartition.rs
+++ b/src/daft-logical-plan/src/optimization/rules/drop_repartition.rs
@@ -52,7 +52,9 @@ mod tests {
 
     use crate::{
         optimization::{
-            rules::drop_repartition::DropRepartition, test::assert_optimized_plan_with_rules_eq,
+            optimizer::{RuleBatch, RuleExecutionStrategy},
+            rules::drop_repartition::DropRepartition,
+            test::assert_optimized_plan_with_rules_eq,
         },
         test::{dummy_scan_node, dummy_scan_operator},
         LogicalPlan,
@@ -65,7 +67,14 @@ mod tests {
         plan: Arc<LogicalPlan>,
         expected: Arc<LogicalPlan>,
     ) -> DaftResult<()> {
-        assert_optimized_plan_with_rules_eq(plan, expected, vec![Box::new(DropRepartition::new())])
+        assert_optimized_plan_with_rules_eq(
+            plan,
+            expected,
+            vec![RuleBatch::new(
+                vec![Box::new(DropRepartition::new())],
+                RuleExecutionStrategy::Once,
+            )],
+        )
     }
 
     /// Tests that DropRepartition does drops the upstream Repartition in back-to-back Repartitions.
diff --git a/src/daft-logical-plan/src/optimization/rules/lift_project_from_agg.rs b/src/daft-logical-plan/src/optimization/rules/lift_project_from_agg.rs
index 872859356f..19c7169e03 100644
--- a/src/daft-logical-plan/src/optimization/rules/lift_project_from_agg.rs
+++ b/src/daft-logical-plan/src/optimization/rules/lift_project_from_agg.rs
@@ -115,7 +115,10 @@ mod tests {
 
     use super::LiftProjectFromAgg;
     use crate::{
-        optimization::test::assert_optimized_plan_with_rules_eq,
+        optimization::{
+            optimizer::{RuleBatch, RuleExecutionStrategy},
+            test::assert_optimized_plan_with_rules_eq,
+        },
         test::{dummy_scan_node, dummy_scan_operator},
         LogicalPlan,
     };
@@ -127,7 +130,10 @@ mod tests {
         assert_optimized_plan_with_rules_eq(
             plan,
             expected,
-            vec![Box::new(LiftProjectFromAgg::new())],
+            vec![RuleBatch::new(
+                vec![Box::new(LiftProjectFromAgg::new())],
+                RuleExecutionStrategy::Once,
+            )],
         )
     }
 
diff --git a/src/daft-logical-plan/src/optimization/rules/mod.rs b/src/daft-logical-plan/src/optimization/rules/mod.rs
index 88d528f9a1..f540a77cb0 100644
--- a/src/daft-logical-plan/src/optimization/rules/mod.rs
+++ b/src/daft-logical-plan/src/optimization/rules/mod.rs
@@ -10,6 +10,7 @@ mod reorder_joins;
 mod rule;
 mod simplify_expressions;
 mod split_actor_pool_projects;
+mod unnest_subquery;
 
 pub use drop_repartition::DropRepartition;
 pub use eliminate_cross_join::EliminateCrossJoin;
@@ -22,3 +23,4 @@ pub use push_down_projection::PushDownProjection;
 pub use rule::OptimizerRule;
 pub use simplify_expressions::SimplifyExpressionsRule;
 pub use split_actor_pool_projects::SplitActorPoolProjects;
+pub use unnest_subquery::{UnnestPredicateSubquery, UnnestScalarSubquery};
diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs
index 9fa30ea8e5..2b77bd8e9a 100644
--- a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs
+++ b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs
@@ -359,7 +359,11 @@ mod tests {
     use rstest::rstest;
 
     use crate::{
-        optimization::{rules::PushDownFilter, test::assert_optimized_plan_with_rules_eq},
+        optimization::{
+            optimizer::{RuleBatch, RuleExecutionStrategy},
+            rules::PushDownFilter,
+            test::assert_optimized_plan_with_rules_eq,
+        },
         test::{dummy_scan_node, dummy_scan_node_with_pushdowns, dummy_scan_operator},
         LogicalPlan,
     };
@@ -371,7 +375,14 @@ mod tests {
         plan: Arc<LogicalPlan>,
         expected: Arc<LogicalPlan>,
     ) -> DaftResult<()> {
-        assert_optimized_plan_with_rules_eq(plan, expected, vec![Box::new(PushDownFilter::new())])
+        assert_optimized_plan_with_rules_eq(
+            plan,
+            expected,
+            vec![RuleBatch::new(
+                vec![Box::new(PushDownFilter::new())],
+                RuleExecutionStrategy::Once,
+            )],
+        )
     }
 
     /// Tests that we can't pushdown a filter into a ScanOperator that has a limit.
diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_limit.rs b/src/daft-logical-plan/src/optimization/rules/push_down_limit.rs
index e79879604d..1a002ea572 100644
--- a/src/daft-logical-plan/src/optimization/rules/push_down_limit.rs
+++ b/src/daft-logical-plan/src/optimization/rules/push_down_limit.rs
@@ -137,7 +137,11 @@ mod tests {
     use rstest::rstest;
 
     use crate::{
-        optimization::{rules::PushDownLimit, test::assert_optimized_plan_with_rules_eq},
+        optimization::{
+            optimizer::{RuleBatch, RuleExecutionStrategy},
+            rules::PushDownLimit,
+            test::assert_optimized_plan_with_rules_eq,
+        },
         test::{dummy_scan_node, dummy_scan_node_with_pushdowns, dummy_scan_operator},
         LogicalPlan, LogicalPlanBuilder,
     };
@@ -149,7 +153,14 @@ mod tests {
         plan: Arc<LogicalPlan>,
         expected: Arc<LogicalPlan>,
     ) -> DaftResult<()> {
-        assert_optimized_plan_with_rules_eq(plan, expected, vec![Box::new(PushDownLimit::new())])
+        assert_optimized_plan_with_rules_eq(
+            plan,
+            expected,
+            vec![RuleBatch::new(
+                vec![Box::new(PushDownLimit::new())],
+                RuleExecutionStrategy::Once,
+            )],
+        )
     }
 
     /// Tests that Limit pushes into external Source.
diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs b/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs
index 632f5e3bfe..bc0d153c3d 100644
--- a/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs
+++ b/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs
@@ -680,7 +680,11 @@ mod tests {
     };
 
     use crate::{
-        optimization::{rules::PushDownProjection, test::assert_optimized_plan_with_rules_eq},
+        optimization::{
+            optimizer::{RuleBatch, RuleExecutionStrategy},
+            rules::PushDownProjection,
+            test::assert_optimized_plan_with_rules_eq,
+        },
         test::{dummy_scan_node, dummy_scan_node_with_pushdowns, dummy_scan_operator},
         LogicalPlan,
     };
@@ -695,7 +699,10 @@ mod tests {
         assert_optimized_plan_with_rules_eq(
             plan,
             expected,
-            vec![Box::new(PushDownProjection::new())],
+            vec![RuleBatch::new(
+                vec![Box::new(PushDownProjection::new())],
+                RuleExecutionStrategy::Once,
+            )],
         )
     }
 
diff --git a/src/daft-logical-plan/src/optimization/rules/split_actor_pool_projects.rs b/src/daft-logical-plan/src/optimization/rules/split_actor_pool_projects.rs
index d2b717ef20..a7bb607019 100644
--- a/src/daft-logical-plan/src/optimization/rules/split_actor_pool_projects.rs
+++ b/src/daft-logical-plan/src/optimization/rules/split_actor_pool_projects.rs
@@ -516,7 +516,11 @@ mod tests {
     use super::SplitActorPoolProjects;
     use crate::{
         ops::{ActorPoolProject, Project},
-        optimization::{rules::PushDownProjection, test::assert_optimized_plan_with_rules_eq},
+        optimization::{
+            optimizer::{RuleBatch, RuleExecutionStrategy},
+            rules::PushDownProjection,
+            test::assert_optimized_plan_with_rules_eq,
+        },
         test::{dummy_scan_node, dummy_scan_operator},
         LogicalPlan,
     };
@@ -531,7 +535,10 @@ mod tests {
         assert_optimized_plan_with_rules_eq(
             plan,
             expected,
-            vec![Box::new(SplitActorPoolProjects {})],
+            vec![RuleBatch::new(
+                vec![Box::new(SplitActorPoolProjects::new())],
+                RuleExecutionStrategy::Once,
+            )],
         )
     }
 
@@ -545,10 +552,13 @@ mod tests {
         assert_optimized_plan_with_rules_eq(
             plan,
             expected,
-            vec![
-                Box::new(SplitActorPoolProjects {}),
-                Box::new(PushDownProjection::new()),
-            ],
+            vec![RuleBatch::new(
+                vec![
+                    Box::new(SplitActorPoolProjects::new()),
+                    Box::new(PushDownProjection::new()),
+                ],
+                RuleExecutionStrategy::Once,
+            )],
         )
     }
 
diff --git a/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs b/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs
new file mode 100644
index 0000000000..3413e8cc53
--- /dev/null
+++ b/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs
@@ -0,0 +1,780 @@
+use std::{collections::HashSet, sync::Arc};
+
+use common_error::{DaftError, DaftResult};
+use common_treenode::{DynTreeNode, Transformed, TreeNode};
+use daft_core::{join::JoinType, prelude::SchemaRef};
+use daft_dsl::{
+    col,
+    optimization::{conjuct, split_conjuction},
+    Expr, ExprRef, Operator, Subquery,
+};
+use itertools::multiunzip;
+use uuid::Uuid;
+
+use super::OptimizerRule;
+use crate::{
+    logical_plan::downcast_subquery,
+    ops::{Aggregate, Filter, Join, Project},
+    LogicalPlan, LogicalPlanRef,
+};
+
+/// Rewriter rule to convert scalar subqueries into joins.
+///
+/// ## Examples
+/// ### Example 1 - Uncorrelated subquery
+/// Before:
+/// ```sql
+/// SELECT val
+/// FROM tbl1
+/// WHERE key = (SELECT max(key) FROM tbl2)
+/// ```
+/// After:
+/// ```sql
+/// SELECT val
+/// FROM tbl1
+/// CROSS JOIN (SELECT max(key) FROM tbl2) AS subquery
+/// WHERE key = subquery.key -- this can then be pushed into the join in a future rule
+/// ```
+///
+/// ### Example 2 - Correlated subquery
+/// Before:
+/// ```sql
+/// SELECT val
+/// FROM tbl1
+/// WHERE outer_key =
+/// (
+///     SELECT max(outer_key)
+///     FROM tbl2
+///     WHERE inner_key = tbl1.inner_key
+/// )
+/// ```
+/// After:
+/// ```sql
+/// SELECT val
+/// FROM tbl1
+/// LEFT JOIN
+/// (
+///     SELECT inner_key, max(outer_key)
+///     FROM tbl2
+///     GROUP BY inner_key
+/// ) AS subquery
+/// ON inner_key
+/// WHERE outer_key = subquery.outer_key
+/// ```
+#[derive(Debug)]
+pub struct UnnestScalarSubquery {}
+
+impl UnnestScalarSubquery {
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl UnnestScalarSubquery {
+    fn unnest_subqueries(
+        input: LogicalPlanRef,
+        exprs: Vec<&ExprRef>,
+    ) -> DaftResult<Transformed<(LogicalPlanRef, Vec<ExprRef>)>> {
+        let mut subqueries = HashSet::new();
+
+        let new_exprs = exprs
+            .into_iter()
+            .map(|expr| {
+                expr.clone()
+                    .transform_down(|e| {
+                        if let Expr::Subquery(subquery) = e.as_ref() {
+                            subqueries.insert(subquery.clone());
+
+                            Ok(Transformed::yes(col(subquery.semantic_id().id)))
+                        } else {
+                            Ok(Transformed::no(e))
+                        }
+                    })
+                    .unwrap()
+                    .data
+            })
+            .collect();
+
+        if subqueries.is_empty() {
+            return Ok(Transformed::no((input, new_exprs)));
+        }
+
+        let new_input = subqueries
+            .into_iter()
+            .try_fold(input, |curr_input, subquery| {
+                let subquery_alias = subquery.semantic_id().id;
+                let subquery_plan = downcast_subquery(&subquery);
+
+                let subquery_col_names = subquery_plan.schema().names();
+                let [output_col] = subquery_col_names.as_slice() else {
+                    return Err(DaftError::ValueError(format!(
+                        "Expected scalar subquery to have one output column, received: {}",
+                        subquery_col_names.len()
+                    )));
+                };
+
+                // alias output column
+                let subquery_plan = Arc::new(LogicalPlan::Project(Project::try_new(
+                    subquery_plan,
+                    vec![col(output_col.as_str()).alias(subquery_alias)],
+                )?));
+
+                let (decorrelated_subquery, subquery_on, input_on) =
+                    pull_up_correlated_cols(subquery_plan)?;
+
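+                // If no correlated columns were pulled up, the subquery is
+                // uncorrelated: a keyless inner join (effectively a cross join)
+                // attaches its scalar result to every input row. Otherwise,
+                // left-join on the pulled-up keys so that input rows without a
+                // match still surface (with a null scalar).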
+                if subquery_on.is_empty() {
+                    // uncorrelated scalar subquery
+                    Ok(Arc::new(LogicalPlan::Join(Join::try_new(
+                        curr_input,
+                        decorrelated_subquery,
+                        vec![],
+                        vec![],
+                        None,
+                        JoinType::Inner,
+                        None,
+                        None,
+                        None,
+                        false,
+                    )?)))
+                } else {
+                    // correlated scalar subquery
+                    Ok(Arc::new(LogicalPlan::Join(Join::try_new(
+                        curr_input,
+                        decorrelated_subquery,
+                        input_on,
+                        subquery_on,
+                        None,
+                        JoinType::Left,
+                        None,
+                        None,
+                        None,
+                        false,
+                    )?)))
+                }
+            })?;
+
+        Ok(Transformed::yes((new_input, new_exprs)))
+    }
+}
+
+impl OptimizerRule for UnnestScalarSubquery {
+    fn try_optimize(&self, plan: Arc<LogicalPlan>) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
+        plan.transform_down(|node| match node.as_ref() {
+            LogicalPlan::Filter(Filter {
+                input, predicate, ..
+            }) => {
+                let unnest_result =
+                    Self::unnest_subqueries(input.clone(), split_conjuction(predicate))?;
+
+                if !unnest_result.transformed {
+                    return Ok(Transformed::no(node));
+                }
+
+                let (new_input, new_predicates) = unnest_result.data;
+
+                let new_predicate = conjuct(new_predicates)
+                    .expect("predicates are guaranteed to exist at this point, so 'conjuct' should never return 'None'");
+
+                let new_filter = Arc::new(LogicalPlan::Filter(Filter::try_new(
+                    new_input,
+                    new_predicate,
+                )?));
+
+                // preserve original schema
+                let new_plan = Arc::new(LogicalPlan::Project(Project::new_from_schema(
+                    new_filter,
+                    input.schema(),
+                )?));
+
+                Ok(Transformed::yes(new_plan))
+            }
+            LogicalPlan::Project(Project {
+                input, projection, ..
+            }) => {
+                let unnest_result =
+                    Self::unnest_subqueries(input.clone(), projection.iter().collect())?;
+
+                if !unnest_result.transformed {
+                    return Ok(Transformed::no(node));
+                }
+
+                let (new_input, new_projection) = unnest_result.data;
+
+                // preserve original schema
+                let new_plan = Arc::new(LogicalPlan::Project(Project::try_new(
+                    new_input,
+                    new_projection,
+                )?));
+
+                Ok(Transformed::yes(new_plan))
+            }
+            _ => Ok(Transformed::no(node)),
+        })
+    }
+}
+
+/// Rewriter rule to convert IN and EXISTS subqueries into joins.
+///
+/// ## Examples
+/// ### Example 1 - Uncorrelated `IN` Query
+/// Before:
+/// ```sql
+/// SELECT val
+/// FROM tbl1
+/// WHERE key IN (SELECT key FROM tbl2)
+/// ```
+/// After:
+/// ```sql
+/// SELECT val
+/// FROM tbl1
+/// SEMI JOIN (SELECT key FROM tbl2) AS subquery
+/// ON key = subquery.key
+/// ```
+///
+/// ### Example 2 - Correlated `NOT EXISTS` Query
+/// Before:
+/// ```sql
+/// SELECT val
+/// FROM tbl1
+/// WHERE NOT EXISTS
+/// (
+///     SELECT *
+///     FROM tbl2
+///     WHERE key = tbl1.key
+/// )
+/// ```
+///
+/// After:
+/// ```sql
+/// SELECT val
+/// FROM tbl1
+/// ANTI JOIN (SELECT * FROM tbl2) AS subquery
+/// ON key = subquery.key
+/// ```
+#[derive(Debug)]
+pub struct UnnestPredicateSubquery {}
+
+impl UnnestPredicateSubquery {
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+#[derive(Eq, Hash, PartialEq)]
+struct PredicateSubquery {
+    pub subquery: Subquery,
+    pub in_expr: Option<ExprRef>,
+    pub join_type: JoinType,
+}
+
+impl OptimizerRule for UnnestPredicateSubquery {
+    fn try_optimize(&self, plan: Arc<LogicalPlan>) -> DaftResult<Transformed<Arc<LogicalPlan>>> {
+        plan.transform_down(|node| match node.as_ref() {
+            LogicalPlan::Filter(Filter {
+                input, predicate, ..
+            }) => {
+                let mut subqueries = HashSet::new();
+
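+                // Pull IN/EXISTS predicates out of the conjunction, mapping
+                // `IN`/`EXISTS` to a semi join and `NOT IN`/`NOT EXISTS` to an
+                // anti join; every other predicate is kept for the rebuilt filter.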
+                let new_predicates = split_conjuction(predicate)
+                    .into_iter()
+                    .filter(|expr| {
+                        match expr.as_ref() {
+                            Expr::InSubquery(in_expr, subquery) => {
+                                subqueries.insert(PredicateSubquery { subquery: subquery.clone(), in_expr: Some(in_expr.clone()), join_type: JoinType::Semi });
+                                false
+                            }
+                            Expr::Exists(subquery) => {
+                                subqueries.insert(PredicateSubquery { subquery: subquery.clone(), in_expr: None, join_type: JoinType::Semi });
+                                false
+                            }
+                            Expr::Not(e) => {
+                                match e.as_ref() {
+                                    Expr::InSubquery(in_expr, subquery) => {
+                                        subqueries.insert(PredicateSubquery { subquery: subquery.clone(), in_expr: Some(in_expr.clone()), join_type: JoinType::Anti });
+                                        false
+                                    }
+                                    Expr::Exists(subquery) => {
+                                        subqueries.insert(PredicateSubquery { subquery: subquery.clone(), in_expr: None, join_type: JoinType::Anti });
+                                        false
+                                    }
+                                    _ => true
+                                }
+                            }
+                            _ => true
+                        }
+                    })
+                    .cloned()
+                    .collect::<Vec<_>>();
+
+                if subqueries.is_empty() {
+                    return Ok(Transformed::no(node));
+                }
+
+                let new_input = subqueries.into_iter().try_fold(input.clone(), |curr_input, PredicateSubquery { subquery, in_expr, join_type }| {
+                    let subquery_plan = downcast_subquery(&subquery);
+                    let subquery_schema = subquery_plan.schema();
+
+                    let (decorrelated_subquery, mut subquery_on, mut input_on) =
+                        pull_up_correlated_cols(subquery_plan)?;
+
+                    if let Some(in_expr) = in_expr {
+                        let subquery_col_names = subquery_schema.names();
+                        let [output_col] = subquery_col_names.as_slice() else {
+                            return Err(DaftError::ValueError(format!("Expected IN subquery to have one output column, received: {}", subquery_col_names.len())));
+                        };
+
+                        input_on.push(in_expr);
+                        subquery_on.push(col(output_col.as_str()));
+                    }
+
+                    if subquery_on.is_empty() {
+                        return Err(DaftError::ValueError("Expected IN/EXISTS subquery to be correlated, found uncorrelated subquery.".to_string()));
+                    }
+
+                    Ok(Arc::new(LogicalPlan::Join(Join::try_new(
+                        curr_input,
+                        decorrelated_subquery,
+                        input_on,
+                        subquery_on,
+                        None,
+                        join_type,
+                        None,
+                        None,
+                        None,
+                        false
+                    )?)))
+                })?;
+
+                let new_plan = if let Some(new_predicate) = conjuct(new_predicates) {
+                    // add filter back if there are non-subquery predicates
+                    Arc::new(LogicalPlan::Filter(Filter::try_new(
+                        new_input,
+                        new_predicate,
+                    )?))
+                } else {
+                    new_input
+                };
+
+                Ok(Transformed::yes(new_plan))
+            }
+            _ => Ok(Transformed::no(node)),
+        })
+    }
+}
+
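+/// Recursively pulls `col = outer_col` equality predicates out of a subquery
+/// plan, removing them from filters and carrying the referenced columns through
+/// projects and aggregates, and returns the decorrelated plan along with the
+/// subquery-side and outer-side columns to use as join keys.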
+fn pull_up_correlated_cols(
+    plan: LogicalPlanRef,
+) -> DaftResult<(LogicalPlanRef, Vec<ExprRef>, Vec<ExprRef>)> {
+    let (new_inputs, subquery_on, outer_on): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(
+        plan.arc_children()
+            .into_iter()
+            .map(pull_up_correlated_cols)
+            .collect::<DaftResult<Vec<_>>>()?,
+    );
+
+    let plan = if new_inputs.is_empty() {
+        plan
+    } else {
+        Arc::new(plan.with_new_children(&new_inputs))
+    };
+
+    let mut subquery_on = subquery_on.into_iter().flatten().collect::<Vec<_>>();
+    let mut outer_on = outer_on.into_iter().flatten().collect::<Vec<_>>();
+
+    match plan.as_ref() {
+        LogicalPlan::Filter(Filter {
+            input, predicate, ..
+        }) => {
+            let mut found_correlated_col = false;
+
+            let preds = split_conjuction(predicate)
+                .into_iter()
+                .filter(|expr| {
+                    if let Expr::BinaryOp {
+                        op: Operator::Eq,
+                        left,
+                        right,
+                    } = expr.as_ref()
+                    {
+                        match (left.as_ref(), right.as_ref()) {
+                            (
+                                Expr::Column(subquery_col_name),
+                                Expr::OuterReferenceColumn(outer_col),
+                            )
+                            | (
+                                Expr::OuterReferenceColumn(outer_col),
+                                Expr::Column(subquery_col_name),
+                            ) => {
+                                // remove correlated col from filter, use in join instead
+                                subquery_on.push(col(subquery_col_name.clone()));
+                                outer_on.push(col(outer_col.field.name.as_str()));
+
+                                found_correlated_col = true;
+                                return false;
+                            }
+                            _ => {}
+                        }
+                    }
+
+                    true
+                })
+                .cloned()
+                .collect::<Vec<_>>();
+
+            // no new correlated cols found
+            if !found_correlated_col {
+                return Ok((plan.clone(), subquery_on, outer_on));
+            }
+
+            if let Some(new_predicate) = conjuct(preds) {
+                let new_plan = Arc::new(LogicalPlan::Filter(Filter::try_new(
+                    input.clone(),
+                    new_predicate,
+                )?));
+
+                Ok((new_plan, subquery_on, outer_on))
+            } else {
+                // all predicates are correlated so filter can be completely removed
+                Ok((input.clone(), subquery_on, outer_on))
+            }
+        }
+        LogicalPlan::Project(Project {
+            input,
+            projection,
+            projected_schema,
+            ..
+        }) => {
+            // ensure all columns that need to be pulled up are in the projection
+
+            let (new_subquery_on, missing_exprs) =
+                get_missing_exprs(subquery_on, projection, projected_schema);
+
+            if missing_exprs.is_empty() {
+                // project already contains all necessary columns
+                Ok((plan.clone(), new_subquery_on, outer_on))
+            } else {
+                let new_projection = [projection.clone(), missing_exprs].concat();
+
+                let new_plan = Arc::new(LogicalPlan::Project(Project::try_new(
+                    input.clone(),
+                    new_projection,
+                )?));
+
+                Ok((new_plan, new_subquery_on, outer_on))
+            }
+        }
+        LogicalPlan::Aggregate(Aggregate {
+            input,
+            aggregations,
+            groupby,
+            output_schema,
+            ..
+        }) => {
+            // put columns that need to be pulled up into the groupby
+
+            let (new_subquery_on, missing_groupbys) =
+                get_missing_exprs(subquery_on, groupby, output_schema);
+
+            if missing_groupbys.is_empty() {
+                // agg already contains all necessary columns
+                Ok((plan.clone(), new_subquery_on, outer_on))
+            } else {
+                let new_groupby = [groupby.clone(), missing_groupbys].concat();
+
+                let new_plan = Arc::new(LogicalPlan::Aggregate(Aggregate::try_new(
+                    input.clone(),
+                    aggregations.clone(),
+                    new_groupby,
+                )?));
+
+                Ok((new_plan, new_subquery_on, outer_on))
+            }
+        }
+
+        // ops that can trivially pull up correlated cols
+        LogicalPlan::Distinct(..)
+        | LogicalPlan::MonotonicallyIncreasingId(..)
+        | LogicalPlan::Repartition(..)
+        | LogicalPlan::Union(..)
+        | LogicalPlan::Intersect(..)
+        | LogicalPlan::Sort(..) => Ok((plan.clone(), subquery_on, outer_on)),
+
+        // ops that cannot pull up correlated columns
+        LogicalPlan::ActorPoolProject(..)
+        | LogicalPlan::Limit(..)
+        | LogicalPlan::Sample(..)
+        | LogicalPlan::Source(..)
+        | LogicalPlan::Explode(..)
+        | LogicalPlan::Unpivot(..)
+        | LogicalPlan::Pivot(..)
+        | LogicalPlan::Concat(..)
+        | LogicalPlan::Join(..)
+        | LogicalPlan::Sink(..) => {
+            if subquery_on.is_empty() {
+                Ok((plan.clone(), vec![], vec![]))
+            } else {
+                Err(DaftError::NotImplemented(format!(
+                    "Pulling up correlated columns not supported for: {}",
+                    plan.name()
+                )))
+            }
+        }
+    }
+}
+
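+/// Splits the pull-up columns into those the operator already outputs and those
+/// that must be appended to it, renaming an appended column with a UUID suffix
+/// when its name would collide with an existing output column.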
+fn get_missing_exprs(
+    subquery_on: Vec<ExprRef>,
+    existing_exprs: &[ExprRef],
+    schema: &SchemaRef,
+) -> (Vec<ExprRef>, Vec<ExprRef>) {
+    let mut new_subquery_on = Vec::new();
+    let mut missing_exprs = Vec::new();
+
+    for expr in subquery_on {
+        if existing_exprs.contains(&expr) {
+            // column already exists in schema
+            new_subquery_on.push(expr);
+        } else if schema.has_field(expr.name()) {
+            // another expression takes pull up column name, we rename the pull up column.
+            let new_name = format!("{}-{}", expr.name(), Uuid::new_v4());
+
+            new_subquery_on.push(col(new_name.clone()));
+            missing_exprs.push(expr.alias(new_name));
+        } else {
+            // missing from schema, can keep original name
+
+            new_subquery_on.push(expr.clone());
+            missing_exprs.push(expr);
+        }
+    }
+
+    (new_subquery_on, missing_exprs)
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use common_error::DaftResult;
+    use daft_core::join::JoinType;
+    use daft_dsl::{col, Expr, OuterReferenceColumn, Subquery};
+    use daft_schema::{dtype::DataType, field::Field};
+
+    use super::{UnnestPredicateSubquery, UnnestScalarSubquery};
+    use crate::{
+        optimization::{
+            optimizer::{RuleBatch, RuleExecutionStrategy},
+            test::assert_optimized_plan_with_rules_eq,
+        },
+        test::{dummy_scan_node, dummy_scan_operator},
+        LogicalPlanRef,
+    };
+
+    fn assert_scalar_optimized_plan_eq(
+        plan: LogicalPlanRef,
+        expected: LogicalPlanRef,
+    ) -> DaftResult<()> {
+        assert_optimized_plan_with_rules_eq(
+            plan,
+            expected,
+            vec![RuleBatch::new(
+                vec![Box::new(UnnestScalarSubquery::new())],
+                RuleExecutionStrategy::Once,
+            )],
+        )
+    }
+
+    fn assert_predicate_optimized_plan_eq(
+        plan: LogicalPlanRef,
+        expected: LogicalPlanRef,
+    ) -> DaftResult<()> {
+        assert_optimized_plan_with_rules_eq(
+            plan,
+            expected,
+            vec![RuleBatch::new(
+                vec![Box::new(UnnestPredicateSubquery::new())],
+                RuleExecutionStrategy::Once,
+            )],
+        )
+    }
+
+    #[test]
+    fn uncorrelated_scalar_subquery() -> DaftResult<()> {
+        let tbl1 = dummy_scan_node(dummy_scan_operator(vec![
+            Field::new("key", DataType::Int64),
+            Field::new("val", DataType::Int64),
+        ]));
+
+        let tbl2 = dummy_scan_node(dummy_scan_operator(vec![Field::new(
+            "key",
+            DataType::Int64,
+        )]));
+
+        let subquery = tbl2.aggregate(vec![col("key").max()], vec![])?;
+        let subquery_expr = Arc::new(Expr::Subquery(Subquery {
+            plan: subquery.build(),
+        }));
+        let subquery_alias = subquery_expr.semantic_id(&subquery.schema()).id;
+
+        let plan = tbl1
+            .filter(col("key").eq(subquery_expr))?
+            .select(vec![col("val")])?
+            .build();
+
+        let expected = tbl1
+            .join(
+                subquery.select(vec![col("key").alias(subquery_alias.clone())])?,
+                vec![],
+                vec![],
+                JoinType::Inner,
+                None,
+                None,
+                None,
+                false,
+            )?
+            .filter(col("key").eq(col(subquery_alias)))?
+            .select(vec![col("key"), col("val")])?
+            .select(vec![col("val")])?
+ .build(); + + assert_scalar_optimized_plan_eq(plan, expected)?; + Ok(()) + } + + #[test] + fn correlated_scalar_subquery() -> DaftResult<()> { + let tbl1 = dummy_scan_node(dummy_scan_operator(vec![ + Field::new("outer_key", DataType::Int64), + Field::new("inner_key", DataType::Int64), + Field::new("val", DataType::Int64), + ])); + + let tbl2 = dummy_scan_node(dummy_scan_operator(vec![ + Field::new("outer_key", DataType::Int64), + Field::new("inner_key", DataType::Int64), + ])); + + let subquery = tbl2 + .filter(col("inner_key").eq(Arc::new(Expr::OuterReferenceColumn( + OuterReferenceColumn { + field: Field::new("inner_key", DataType::Int64), + depth: 1, + }, + ))))? + .aggregate(vec![col("outer_key").max()], vec![])?; + let subquery_expr = Arc::new(Expr::Subquery(Subquery { + plan: subquery.build(), + })); + let subquery_alias = subquery_expr.semantic_id(&subquery.schema()).id; + + let plan = tbl1 + .filter(col("outer_key").eq(subquery_expr))? + .select(vec![col("val")])? + .build(); + + let expected = tbl1 + .join( + tbl2.aggregate(vec![col("outer_key").max()], vec![col("inner_key")])? + .select(vec![ + col("outer_key").alias(subquery_alias.clone()), + col("inner_key"), + ])?, + vec![col("inner_key")], + vec![col("inner_key")], + JoinType::Left, + None, + None, + None, + false, + )? + .filter(col("outer_key").eq(col(subquery_alias)))? + .select(vec![col("outer_key"), col("inner_key"), col("val")])? + .select(vec![col("val")])? + .build(); + + assert_scalar_optimized_plan_eq(plan, expected)?; + Ok(()) + } + + #[test] + fn uncorrelated_predicate_subquery() -> DaftResult<()> { + let tbl1 = dummy_scan_node(dummy_scan_operator(vec![ + Field::new("key", DataType::Int64), + Field::new("val", DataType::Int64), + ])); + + let tbl2 = dummy_scan_node(dummy_scan_operator(vec![Field::new( + "key", + DataType::Int64, + )])); + + let plan = tbl1 + .filter(Arc::new(Expr::InSubquery( + col("key"), + Subquery { plan: tbl2.build() }, + )))? + .select(vec![col("val")])? + .build(); + + let expected = tbl1 + .join( + tbl2, + vec![col("key")], + vec![col("key")], + JoinType::Semi, + None, + None, + None, + false, + )? + .select(vec![col("val")])? + .build(); + + assert_predicate_optimized_plan_eq(plan, expected)?; + Ok(()) + } + + #[test] + fn correlated_predicate_subquery() -> DaftResult<()> { + let tbl1 = dummy_scan_node(dummy_scan_operator(vec![ + Field::new("key", DataType::Int64), + Field::new("val", DataType::Int64), + ])); + + let tbl2 = dummy_scan_node(dummy_scan_operator(vec![Field::new( + "key", + DataType::Int64, + )])); + + let subquery = tbl2 + .filter( + col("key").eq(Arc::new(Expr::OuterReferenceColumn(OuterReferenceColumn { + field: Field::new("key", DataType::Int64), + depth: 1, + }))), + )? + .build(); + + let plan = tbl1 + .filter(Arc::new(Expr::Exists(Subquery { plan: subquery })).not())? + .select(vec![col("val")])? + .build(); + + let expected = tbl1 + .join( + tbl2, + vec![col("key")], + vec![col("key")], + JoinType::Anti, + None, + None, + None, + false, + )? + .select(vec![col("val")])? 
+            .build();
+
+        assert_predicate_optimized_plan_eq(plan, expected)?;
+        Ok(())
+    }
+}
diff --git a/src/daft-logical-plan/src/optimization/test/mod.rs b/src/daft-logical-plan/src/optimization/test/mod.rs
index e682d4cc07..7a8fe12121 100644
--- a/src/daft-logical-plan/src/optimization/test/mod.rs
+++ b/src/daft-logical-plan/src/optimization/test/mod.rs
@@ -2,12 +2,8 @@ use std::sync::Arc;
 
 use common_error::DaftResult;
 
-use super::optimizer::OptimizerRuleInBatch;
 use crate::{
-    optimization::{
-        optimizer::{RuleBatch, RuleExecutionStrategy},
-        Optimizer,
-    },
+    optimization::{optimizer::RuleBatch, Optimizer},
     LogicalPlan,
 };
 
@@ -17,12 +13,9 @@ use crate::{
 pub fn assert_optimized_plan_with_rules_eq(
     plan: Arc<LogicalPlan>,
     expected: Arc<LogicalPlan>,
-    rules: Vec<Box<dyn OptimizerRuleInBatch>>,
+    rule_batches: Vec<RuleBatch>,
 ) -> DaftResult<()> {
-    let optimizer = Optimizer::with_rule_batches(
-        vec![RuleBatch::new(rules, RuleExecutionStrategy::Once)],
-        Default::default(),
-    );
+    let optimizer = Optimizer::with_rule_batches(rule_batches, Default::default());
     let optimized_plan = optimizer
         .optimize_with_rules(optimizer.rule_batches[0].rules.as_slice(), plan.clone())?
         .data;
diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs
index a8e7a90c7b..e7a1fa381c 100644
--- a/src/daft-sql/src/planner.rs
+++ b/src/daft-sql/src/planner.rs
@@ -1458,6 +1458,9 @@ impl<'a> SQLPlanner<'a> {
                 let start = self.plan_expr(substring_from)?;
                 let length = self.plan_expr(substring_for)?;
 
+                // SQL substring is one indexed
+                let start = start.sub(lit(1));
+
                 Ok(daft_functions::utf8::substr(expr, start, length))
             }
             SQLExpr::Substring { special: false, .. } => {
diff --git a/tests/benchmarks/test_local_tpch.py b/tests/benchmarks/test_local_tpch.py
index f5d2b96d32..023781aec9 100644
--- a/tests/benchmarks/test_local_tpch.py
+++ b/tests/benchmarks/test_local_tpch.py
@@ -46,3 +46,49 @@ def f():
     benchmark_group = f"q{q}-parts-{num_parts}"
     daft_pd_df = benchmark_with_memray(f, benchmark_group).to_pandas()
     check_answer(daft_pd_df, q, tmp_path)
+
+
+@pytest.mark.skipif(
+    get_tests_daft_runner_name() not in {"py", "native"},
+    reason="requires PyRunner to be in use",
+)
+@pytest.mark.benchmark(group="tpch")
+@pytest.mark.parametrize("engine, q", itertools.product(ENGINES, TPCH_QUESTIONS))
+def test_tpch_sql(tmp_path, check_answer, get_df, benchmark_with_memray, engine, q):  # noqa F811
+    from daft.sql import SQLCatalog
+
+    get_df, num_parts = get_df
+
+    # TODO: remove this once SQL allows case-insensitive column names
+    def lowercase_column_names(df):
+        return df.select(*[daft.col(name).alias(name.lower()) for name in df.column_names])
+
+    table_names = [
+        "part",
+        "supplier",
+        "partsupp",
+        "customer",
+        "orders",
+        "lineitem",
+        "nation",
+        "region",
+    ]
+    catalog = SQLCatalog({tbl: lowercase_column_names(get_df(tbl)) for tbl in table_names})
+
+    with open(f"benchmarking/tpch/queries/{q:02}.sql") as query_file:
+        query = query_file.read()
+
+    def f():
+        if engine == "native":
+            daft.context.set_runner_native()
+        elif engine == "python":
+            daft.context.set_runner_py()
+        else:
+            raise ValueError(f"{engine} unsupported")
+
+        daft_df = daft.sql(query, catalog=catalog)
+        return daft_df.to_arrow()
+
+    benchmark_group = f"q{q}-sql-parts-{num_parts}"
+    daft_pd_df = benchmark_with_memray(f, benchmark_group).to_pandas()
+    check_answer(daft_pd_df, q, tmp_path)
diff --git a/tests/sql/test_utf8_exprs.py b/tests/sql/test_utf8_exprs.py
index 8f45032c4e..c103db8503 100644
--- a/tests/sql/test_utf8_exprs.py
+++ b/tests/sql/test_utf8_exprs.py
@@ -52,7 +52,7 @@ def test_utf8_exprs():
repeat(a, 2) as repeat_a, a like 'a%' as like_a, a ilike 'a%' as ilike_a, - substring(a, 1, 3) as substring_a, + substring(a, 2, 3) as substring_a, count_matches(a, 'a') as count_matches_a_0, count_matches(a, 'a', case_sensitive := true) as count_matches_a_1, count_matches(a, 'a', case_sensitive := false, whole_words := false) as count_matches_a_2,