Skip to content

Commit

Permalink
Fix LogicalPlan::..._with_subqueries methods (#13589)
Browse files Browse the repository at this point in the history
* Fix `LogicalPlan::transform_..._with_subqueries` methods

* add subquery tests
  • Loading branch information
peter-toth authored Nov 30, 2024
1 parent e2b335c commit 3ab67d8
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 20 deletions.
146 changes: 144 additions & 2 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3494,9 +3494,13 @@ mod tests {
use super::*;
use crate::builder::LogicalTableSource;
use crate::logical_plan::table_scan;
use crate::{col, exists, in_subquery, lit, placeholder, GroupingSet};
use crate::{
col, exists, in_subquery, lit, placeholder, scalar_subquery, GroupingSet,
};

use datafusion_common::tree_node::{TransformedResult, TreeNodeVisitor};
use datafusion_common::tree_node::{
TransformedResult, TreeNodeRewriter, TreeNodeVisitor,
};
use datafusion_common::{not_impl_err, Constraint, ScalarValue};

use crate::test::function_stub::count;
Expand Down Expand Up @@ -4157,4 +4161,142 @@ digraph {
.unwrap();
assert_eq!(limit, new_limit);
}

#[test]
fn test_with_subqueries_jump() {
// The test plan contains a `Project` node above a `Filter` node, and the
// `Project` node contains a subquery plan with a `Filter` root node, so returning
// `TreeNodeRecursion::Jump` on `Project` should cause not visiting any of the
// `Filter`s.
let subquery_schema =
Schema::new(vec![Field::new("sub_id", DataType::Int32, false)]);

let subquery_plan =
table_scan(TableReference::none(), &subquery_schema, Some(vec![0]))
.unwrap()
.filter(col("sub_id").eq(lit(0)))
.unwrap()
.build()
.unwrap();

let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);

let plan = table_scan(TableReference::none(), &schema, Some(vec![0]))
.unwrap()
.filter(col("id").eq(lit(0)))
.unwrap()
.project(vec![col("id"), scalar_subquery(Arc::new(subquery_plan))])
.unwrap()
.build()
.unwrap();

let mut filter_found = false;
plan.apply_with_subqueries(|plan| {
match plan {
LogicalPlan::Projection(..) => return Ok(TreeNodeRecursion::Jump),
LogicalPlan::Filter(..) => filter_found = true,
_ => {}
}
Ok(TreeNodeRecursion::Continue)
})
.unwrap();
assert!(!filter_found);

struct ProjectJumpVisitor {
filter_found: bool,
}

impl ProjectJumpVisitor {
fn new() -> Self {
Self {
filter_found: false,
}
}
}

impl<'n> TreeNodeVisitor<'n> for ProjectJumpVisitor {
type Node = LogicalPlan;

fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
match node {
LogicalPlan::Projection(..) => return Ok(TreeNodeRecursion::Jump),
LogicalPlan::Filter(..) => self.filter_found = true,
_ => {}
}
Ok(TreeNodeRecursion::Continue)
}
}

let mut visitor = ProjectJumpVisitor::new();
plan.visit_with_subqueries(&mut visitor).unwrap();
assert!(!visitor.filter_found);

let mut filter_found = false;
plan.clone()
.transform_down_with_subqueries(|plan| {
match plan {
LogicalPlan::Projection(..) => {
return Ok(Transformed::new(plan, false, TreeNodeRecursion::Jump))
}
LogicalPlan::Filter(..) => filter_found = true,
_ => {}
}
Ok(Transformed::no(plan))
})
.unwrap();
assert!(!filter_found);

let mut filter_found = false;
plan.clone()
.transform_down_up_with_subqueries(
|plan| {
match plan {
LogicalPlan::Projection(..) => {
return Ok(Transformed::new(
plan,
false,
TreeNodeRecursion::Jump,
))
}
LogicalPlan::Filter(..) => filter_found = true,
_ => {}
}
Ok(Transformed::no(plan))
},
|plan| Ok(Transformed::no(plan)),
)
.unwrap();
assert!(!filter_found);

struct ProjectJumpRewriter {
filter_found: bool,
}

impl ProjectJumpRewriter {
fn new() -> Self {
Self {
filter_found: false,
}
}
}

impl TreeNodeRewriter for ProjectJumpRewriter {
type Node = LogicalPlan;

fn f_down(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
match node {
LogicalPlan::Projection(..) => {
return Ok(Transformed::new(node, false, TreeNodeRecursion::Jump))
}
LogicalPlan::Filter(..) => self.filter_found = true,
_ => {}
}
Ok(Transformed::no(node))
}
}

let mut rewriter = ProjectJumpRewriter::new();
plan.rewrite_with_subqueries(&mut rewriter).unwrap();
assert!(!rewriter.filter_found);
}
}
38 changes: 20 additions & 18 deletions datafusion/expr/src/logical_plan/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,10 @@ fn rewrite_extension_inputs<F: FnMut(LogicalPlan) -> Result<Transformed<LogicalP
macro_rules! handle_transform_recursion {
($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{
$F_DOWN?
.transform_children(|n| n.map_subqueries($F_CHILD))?
.transform_sibling(|n| n.map_children($F_CHILD))?
.transform_children(|n| {
n.map_subqueries($F_CHILD)?
.transform_sibling(|n| n.map_children($F_CHILD))
})?
.transform_parent($F_UP)
}};
}
Expand Down Expand Up @@ -675,9 +677,11 @@ impl LogicalPlan {
visitor
.f_down(self)?
.visit_children(|| {
self.apply_subqueries(|c| c.visit_with_subqueries(visitor))
self.apply_subqueries(|c| c.visit_with_subqueries(visitor))?
.visit_sibling(|| {
self.apply_children(|c| c.visit_with_subqueries(visitor))
})
})?
.visit_sibling(|| self.apply_children(|c| c.visit_with_subqueries(visitor)))?
.visit_parent(|| visitor.f_up(self))
}

Expand Down Expand Up @@ -710,13 +714,12 @@ impl LogicalPlan {
node: &LogicalPlan,
f: &mut F,
) -> Result<TreeNodeRecursion> {
f(node)?
.visit_children(|| {
node.apply_subqueries(|c| apply_with_subqueries_impl(c, f))
})?
.visit_sibling(|| {
node.apply_children(|c| apply_with_subqueries_impl(c, f))
})
f(node)?.visit_children(|| {
node.apply_subqueries(|c| apply_with_subqueries_impl(c, f))?
.visit_sibling(|| {
node.apply_children(|c| apply_with_subqueries_impl(c, f))
})
})
}

apply_with_subqueries_impl(self, &mut f)
Expand Down Expand Up @@ -746,13 +749,12 @@ impl LogicalPlan {
node: LogicalPlan,
f: &mut F,
) -> Result<Transformed<LogicalPlan>> {
f(node)?
.transform_children(|n| {
n.map_subqueries(|c| transform_down_with_subqueries_impl(c, f))
})?
.transform_sibling(|n| {
n.map_children(|c| transform_down_with_subqueries_impl(c, f))
})
f(node)?.transform_children(|n| {
n.map_subqueries(|c| transform_down_with_subqueries_impl(c, f))?
.transform_sibling(|n| {
n.map_children(|c| transform_down_with_subqueries_impl(c, f))
})
})
}

transform_down_with_subqueries_impl(self, &mut f)
Expand Down

0 comments on commit 3ab67d8

Please sign in to comment.