Skip to content

Commit

Permalink
perf(optimizer): convert filter predicate to CNF to push through join (
Browse files Browse the repository at this point in the history
…#3623)

By converting the predicate to Conjunctive Normal Form, we are able to
split the predicate into expressions that are isolated to the schemas of
an individual side of a join. Then, we can split by conjunctions and
push those new expressions through the joins.

For example, consider the schemas:
```
left = [a: Utf8]
right = [b: Utf8]
```

If our plan was
```
Filter
|  by: (col(a) = "FRANCE" & col(b) = "GERMANY") | (col(a) = "GERMANY" & col(b) = "FRANCE")
|
Join
| \
|  Scan(right)
Scan(left)
```

That filter predicate in CNF is
```
(col(a) = "FRANCE" | col(a) = "GERMANY") 
  & (col(a) = "FRANCE" | col(b) = "FRANCE") 
  & (col(b) = "GERMANY" | col(a) = "GERMANY") 
  & (col(b) = "GERMANY" | col(b) = "FRANCE")
```
The first and last expressions in this conjunction only reference columns
from a single side of the join, so they can be pushed down, resulting in
this optimized plan:

```
Filter
|  by: (col(a) = "FRANCE" | col(b) = "FRANCE") & (col(b) = "GERMANY" | col(a) = "GERMANY") 
|
Join
| \
|  Scan(right)
|    pushdowns: Filter { (col(b) = "GERMANY" | col(b) = "FRANCE") }
Scan(left)
|    pushdowns: Filter { (col(a) = "FRANCE" | col(a) = "GERMANY") }
```
  • Loading branch information
kevinzwang authored Dec 19, 2024
1 parent 07f6b2c commit 28d3fa7
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 3 deletions.
180 changes: 179 additions & 1 deletion src/daft-algebra/src/boolean.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use common_treenode::{TreeNode, TreeNodeRecursion};
use std::sync::Arc;

use common_treenode::{Transformed, TreeNode, TreeNodeRecursion};
use daft_dsl::{Expr, ExprRef, Operator};

pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
Expand All @@ -22,3 +24,179 @@ pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
/// Joins all expressions in `exprs` with AND, left to right.
///
/// Returns `None` when `exprs` yields nothing, and the sole expression unchanged when it
/// yields exactly one.
pub fn combine_conjunction<T: IntoIterator<Item = ExprRef>>(exprs: T) -> Option<ExprRef> {
    let mut remaining = exprs.into_iter();
    // An empty input has no conjunction to build.
    let first = remaining.next()?;
    Some(remaining.fold(first, |conjunction, next| conjunction.and(next)))
}

/// Converts a boolean expression to conjunctive normal form (AND of ORs).
///
/// Implemented through the duality between the two normal forms: the negated input is
/// converted to DNF with `to_dnf`, the result is negated again, and `apply_de_morgans`
/// pushes that outer negation back down, turning the OR-of-ANDs into an AND-of-ORs.
pub fn to_cnf(expr: ExprRef) -> ExprRef {
    let negated = expr.not();
    let negated_dnf = to_dnf(negated);

    apply_de_morgans(negated_dnf.not()).data
}

/// Converts a boolean expression to disjunctive normal form (OR of ANDs).
///
/// Negations are first pushed towards the leaves with `apply_de_morgans`, then AND is
/// distributed over OR.
///
/// A single top-down distribution pass does not revisit the ancestors of a node it
/// rewrites, so an input such as `a & (b & (c | d))` would be left as
/// `a & ((b & c) | (b & d))`, which is not in DNF. The pass is therefore repeated until the
/// expression stops changing. Inputs that one pass already normalizes produce exactly the
/// same output as before.
///
/// Distribution duplicates operands, so the result may be considerably larger than the input.
pub fn to_dnf(expr: ExprRef) -> ExprRef {
    let mut current = apply_de_morgans(expr).data;

    loop {
        let next = distribute_and_over_or(current.clone());
        // Structural equality means the pass found nothing left to distribute.
        if next == current {
            return next;
        }
        current = next;
    }
}

/// Runs one top-down pass that rewrites each AND having an OR operand into an OR of ANDs.
///
/// When both operands are ORs, the right-hand one is distributed first; the left-hand one is
/// then handled when the traversal descends into the newly created children.
fn distribute_and_over_or(expr: ExprRef) -> ExprRef {
    expr.transform_down(|e| {
        Ok(match e.as_ref() {
            Expr::BinaryOp {
                op: Operator::And,
                left,
                right,
            } => {
                if let Expr::BinaryOp {
                    op: Operator::Or,
                    left: right_left,
                    right: right_right,
                } = right.as_ref()
                {
                    // (x & (y | z)) -> ((x & y) | (x & z))
                    Transformed::yes(Arc::new(Expr::BinaryOp {
                        op: Operator::Or,
                        left: Arc::new(Expr::BinaryOp {
                            op: Operator::And,
                            left: left.clone(),
                            right: right_left.clone(),
                        }),
                        right: Arc::new(Expr::BinaryOp {
                            op: Operator::And,
                            left: left.clone(),
                            right: right_right.clone(),
                        }),
                    }))
                } else if let Expr::BinaryOp {
                    op: Operator::Or,
                    left: left_left,
                    right: left_right,
                } = left.as_ref()
                {
                    // ((x | y) & z) -> ((x & z) | (y & z))
                    Transformed::yes(Arc::new(Expr::BinaryOp {
                        op: Operator::Or,
                        left: Arc::new(Expr::BinaryOp {
                            op: Operator::And,
                            left: left_left.clone(),
                            right: right.clone(),
                        }),
                        right: Arc::new(Expr::BinaryOp {
                            op: Operator::And,
                            left: left_right.clone(),
                            right: right.clone(),
                        }),
                    }))
                } else {
                    Transformed::no(e)
                }
            }
            _ => Transformed::no(e),
        })
    })
    // The closure above never returns `Err`, so the traversal cannot fail.
    .expect("distributing AND over OR is infallible")
    .data
}

/// Transforms a boolean expression by applying De Morgan's laws and eliminating double
/// negations until neither rewrite applies anywhere in the expression.
///
/// A single top-down pass does not re-examine the node produced by stripping a double
/// negation, so `!!!(a & b)` would stop at `!(a & b)` and `!!!!a` would stop at `!!a`.
/// The pass is therefore repeated until the expression stops changing.
///
/// When one pass is enough, that pass's `Transformed` value is returned as-is; otherwise the
/// final expression is returned wrapped in `Transformed::yes`.
fn apply_de_morgans(expr: ExprRef) -> Transformed<ExprRef> {
    let mut result = de_morgans_pass(expr);

    loop {
        let next = de_morgans_pass(result.data.clone());
        // Structural equality means the extra pass found nothing left to rewrite.
        if next.data == result.data {
            return result;
        }
        result = Transformed::yes(next.data);
    }
}

/// Runs one top-down pass of De Morgan's laws and double-negation elimination.
fn de_morgans_pass(expr: ExprRef) -> Transformed<ExprRef> {
    expr.transform_down(|e| {
        Ok(match e.as_ref() {
            Expr::Not(ne) => match ne.as_ref() {
                // !!x -> x
                Expr::Not(nne) => Transformed::yes(nne.clone()),
                // !(x & y) -> ((!x) | (!y))
                Expr::BinaryOp {
                    op: Operator::And,
                    left,
                    right,
                } => Transformed::yes(Arc::new(Expr::BinaryOp {
                    op: Operator::Or,
                    left: left.clone().not(),
                    right: right.clone().not(),
                })),
                // !(x | y) -> ((!x) & (!y))
                Expr::BinaryOp {
                    op: Operator::Or,
                    left,
                    right,
                } => Transformed::yes(Arc::new(Expr::BinaryOp {
                    op: Operator::And,
                    left: left.clone().not(),
                    right: right.clone().not(),
                })),
                _ => Transformed::no(e),
            },
            _ => Transformed::no(e),
        })
    })
    // The closure above never returns `Err`, so the traversal cannot fail.
    .expect("applying De Morgan's laws is infallible")
}

#[cfg(test)]
mod tests {
    use daft_dsl::col;

    use super::{to_cnf, to_dnf};

    /// AND distributes over a single OR operand.
    #[test]
    fn dnf_simple() {
        // a & (b | c) -> (a & b) | (a & c)
        let b_or_c = col("b").or(col("c"));
        let input = col("a").and(b_or_c);

        let a_and_b = col("a").and(col("b"));
        let a_and_c = col("a").and(col("c"));
        let want = a_and_b.or(a_and_c);

        assert_eq!(want, to_dnf(input));
    }

    /// OR distributes over a single AND operand.
    #[test]
    fn cnf_simple() {
        // a | (b & c) -> (a | b) & (a | c)
        let b_and_c = col("b").and(col("c"));
        let input = col("a").or(b_and_c);

        let a_or_b = col("a").or(col("b"));
        let a_or_c = col("a").or(col("c"));
        let want = a_or_b.and(a_or_c);

        assert_eq!(want, to_cnf(input));
    }

    /// A negated conjunction is pushed down to the leaves before conversion to DNF.
    #[test]
    fn dnf_neg() {
        // !(a & ((!b) | c)) -> (!a) | (b & (!c))
        let not_b_or_c = col("b").not().or(col("c"));
        let input = col("a").and(not_b_or_c).not();

        let b_and_not_c = col("b").and(col("c").not());
        let want = col("a").not().or(b_and_not_c);

        assert_eq!(want, to_dnf(input));
    }

    /// A negated disjunction is pushed down to the leaves before conversion to CNF.
    #[test]
    fn cnf_neg() {
        // !(a | ((!b) & c)) -> (!a) & (b | (!c))
        let not_b_and_c = col("b").not().and(col("c"));
        let input = col("a").or(not_b_and_c).not();

        let b_or_not_c = col("b").or(col("c").not());
        let want = col("a").not().and(b_or_not_c);

        assert_eq!(want, to_cnf(input));
    }

    /// A conjunction prefix is copied into every branch of the disjunction.
    #[test]
    fn dnf_nested() {
        // a & b & ((c & d) | (e & f)) -> (a & b & c & d) | (a & b & e & f)
        let c_and_d = col("c").and(col("d"));
        let e_and_f = col("e").and(col("f"));
        let input = col("a").and(col("b")).and(c_and_d.or(e_and_f));

        let first_branch = col("a").and(col("b")).and(col("c").and(col("d")));
        let second_branch = col("a").and(col("b")).and(col("e").and(col("f")));
        let want = first_branch.or(second_branch);

        assert_eq!(want, to_dnf(input));
    }

    /// Only the disjunction is expanded; the conjunction prefix is left in place.
    #[test]
    fn cnf_nested() {
        // a & b & ((c & d) | (e & f)) -> a & b & (c | e) & (c | f) & (d | e) & (d | f)
        let c_and_d = col("c").and(col("d"));
        let e_and_f = col("e").and(col("f"));
        let input = col("a").and(col("b")).and(c_and_d.or(e_and_f));

        // The expanded clauses are grouped as ((c | e) & (d | e)) & ((c | f) & (d | f)).
        let clauses_with_e = col("c").or(col("e")).and(col("d").or(col("e")));
        let clauses_with_f = col("c").or(col("f")).and(col("d").or(col("f")));
        let want = col("a")
            .and(col("b"))
            .and(clauses_with_e.and(clauses_with_f));

        assert_eq!(want, to_cnf(input));
    }
}
61 changes: 59 additions & 2 deletions src/daft-logical-plan/src/optimization/rules/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
use common_error::DaftResult;
use common_scan_info::{rewrite_predicate_for_partitioning, PredicateGroups};
use common_treenode::{DynTreeNode, Transformed, TreeNode};
use daft_algebra::boolean::{combine_conjunction, split_conjunction};
use daft_algebra::boolean::{combine_conjunction, split_conjunction, to_cnf};
use daft_core::join::JoinType;
use daft_dsl::{
col,
Expand Down Expand Up @@ -273,7 +273,8 @@ impl PushDownFilter {
let left_cols = HashSet::<_>::from_iter(child_join.left.schema().names());
let right_cols = HashSet::<_>::from_iter(child_join.right.schema().names());

for predicate in split_conjunction(&filter.predicate) {
// TODO: simplify predicates, since they may be expanded with `to_cnf`
for predicate in split_conjunction(&to_cnf(filter.predicate.clone())) {
let pred_cols = HashSet::<_>::from_iter(get_required_columns(&predicate));

match (
Expand Down Expand Up @@ -963,4 +964,60 @@ mod tests {
assert_optimized_plan_eq(plan, expected)?;
Ok(())
}

/// Tests that a complex predicate can be separated so that it can be pushed down into one side of the join.
/// Modeled after TPC-H Q7
#[rstest]
fn filter_commutes_with_join_complex() -> DaftResult<()> {
let left_scan_op = dummy_scan_operator(vec![Field::new("a", DataType::Utf8)]);
let right_scan_op = dummy_scan_operator(vec![Field::new("b", DataType::Utf8)]);

let plan = dummy_scan_node(left_scan_op.clone())
.join(
dummy_scan_node(right_scan_op.clone()),
vec![],
vec![],
JoinType::Inner,
None,
None,
None,
false,
)?
.filter(
(col("a").eq(lit("FRANCE")).and(col("b").eq(lit("GERMANY"))))
.or(col("a").eq(lit("GERMANY")).and(col("b").eq(lit("FRANCE")))),
)?
.build();

let expected = dummy_scan_node_with_pushdowns(
left_scan_op,
Pushdowns::default().with_filters(Some(
col("a").eq(lit("FRANCE")).or(col("a").eq(lit("GERMANY"))),
)),
)
.join(
dummy_scan_node_with_pushdowns(
right_scan_op,
Pushdowns::default().with_filters(Some(
col("b").eq(lit("GERMANY")).or(col("b").eq(lit("FRANCE"))),
)),
),
vec![],
vec![],
JoinType::Inner,
None,
None,
None,
false,
)?
.filter(
(col("b").eq(lit("GERMANY")).or(col("a").eq(lit("GERMANY"))))
.and(col("a").eq(lit("FRANCE")).or(col("b").eq(lit("FRANCE")))),
)?
.build();

assert_optimized_plan_eq(plan, expected)?;

Ok(())
}
}

0 comments on commit 28d3fa7

Please sign in to comment.