Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(optimizer): convert filter predicate to CNF to push through join #3623

Merged
merged 3 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 161 additions & 1 deletion src/daft-algebra/src/boolean.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use common_treenode::{TreeNode, TreeNodeRecursion};
use std::sync::Arc;

use common_treenode::{Transformed, TreeNode, TreeNodeRecursion};
use daft_dsl::{Expr, ExprRef, Operator};

pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
Expand All @@ -22,3 +24,161 @@ pub fn split_conjunction(expr: &ExprRef) -> Vec<ExprRef> {
pub fn combine_conjunction<T: IntoIterator<Item = ExprRef>>(exprs: T) -> Option<ExprRef> {
exprs.into_iter().reduce(|acc, e| acc.and(e))
}

/// Converts a boolean expression to conjunctive normal form (AND of ORs)
pub fn to_cnf(expr: ExprRef) -> ExprRef {
let dnf_form = to_dnf(expr.not()).not();

andrewgazelka marked this conversation as resolved.
Show resolved Hide resolved
apply_de_morgans(dnf_form).data
}

/// Converts a boolean expression to disjunctive normal form (OR of ANDs)
pub fn to_dnf(expr: ExprRef) -> ExprRef {
let dm_expr = apply_de_morgans(expr).data;

// apply distributive property recursively
dm_expr
.transform_down(|e| {
Ok(match e.as_ref() {
Expr::BinaryOp {
op: Operator::And,
left,
right,
} => {
if let Expr::BinaryOp {
op: Operator::Or,
left: right_left,
right: right_right,
} = right.as_ref()
{
// (x & (y | z)) -> ((x & y) | (x & z))
Transformed::yes(Arc::new(Expr::BinaryOp {
op: Operator::Or,
left: Arc::new(Expr::BinaryOp {
op: Operator::And,
left: left.clone(),
right: right_left.clone(),
}),
right: Arc::new(Expr::BinaryOp {
op: Operator::And,
left: left.clone(),
right: right_right.clone(),
}),
}))
} else if let Expr::BinaryOp {
op: Operator::Or,
left: left_left,
right: left_right,
} = left.as_ref()
{
// ((x | y) & z) -> ((x & z) | (y & z))
Transformed::yes(Arc::new(Expr::BinaryOp {
op: Operator::Or,
left: Arc::new(Expr::BinaryOp {
op: Operator::And,
left: left_left.clone(),
right: right.clone(),
}),
right: Arc::new(Expr::BinaryOp {
op: Operator::And,
left: left_right.clone(),
right: right.clone(),
}),
}))
} else {
Transformed::no(e)
}
}
_ => Transformed::no(e),
})
})
.unwrap()
.data
}

/// Transform boolean expression by applying De Morgan's law + eliminate double negations recursively
fn apply_de_morgans(expr: ExprRef) -> Transformed<ExprRef> {
expr.transform_down(|e| {
Ok(match e.as_ref() {
Expr::Not(ne) => match ne.as_ref() {
// !x -> x
Expr::Not(nne) => Transformed::yes(nne.clone()),
// !(x & y) -> ((!x) | (!y))
Expr::BinaryOp {
op: Operator::And,
left,
right,
} => Transformed::yes(Arc::new(Expr::BinaryOp {
op: Operator::Or,
left: left.clone().not(),
right: right.clone().not(),
})),
// !(x | y) -> ((!x) & (!y))
Expr::BinaryOp {
op: Operator::Or,
left,
right,
} => Transformed::yes(Arc::new(Expr::BinaryOp {
op: Operator::And,
left: left.clone().not(),
right: right.clone().not(),
})),
_ => Transformed::no(e),
},
_ => Transformed::no(e),
})
})
.unwrap()
}

#[cfg(test)]
mod tests {
use daft_dsl::col;

use crate::boolean::{to_cnf, to_dnf};

#[test]
fn dnf_simple() {
// a & (b | c) -> (a & b) | (a & c)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we have more tests with negations?

let expr = col("a").and(col("b").or(col("c")));
let expected = col("a").and(col("b")).or(col("a").and(col("c")));

assert_eq!(expected, to_dnf(expr));
}

#[test]
fn cnf_simple() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in future I suppose could do some rust proptest stuff here to verify they are correct but not needed for this issue imo

// a | (b & c) -> (a | b) & (a | c)
let expr = col("a").or(col("b").and(col("c")));
let expected = col("a").or(col("b")).and(col("a").or(col("c")));

assert_eq!(expected, to_cnf(expr));
}

#[test]
fn dnf_nested() {
// a & b & ((c & d) | (e & f)) -> (a & b & c & d) | (a & b & e & f)
let expr = col("a")
.and(col("b"))
.and((col("c").and(col("d"))).or(col("e").and(col("f"))));
let expected = (col("a").and(col("b")).and(col("c").and(col("d"))))
.or(col("a").and(col("b")).and(col("e").and(col("f"))));

assert_eq!(expected, to_dnf(expr));
}

#[test]
fn cnf_nested() {
// a & b & ((c & d) | (e & f)) -> a & b & (c | e) & (c | f) & (d | e) & (d | f)
let expr = col("a")
.and(col("b"))
.and((col("c").and(col("d"))).or(col("e").and(col("f"))));
let expected = col("a").and(col("b")).and(
(col("c").or(col("e")))
.and(col("d").or(col("e")))
.and(col("c").or(col("f")).and(col("d").or(col("f")))),
);

assert_eq!(expected, to_cnf(expr));
}
}
61 changes: 59 additions & 2 deletions src/daft-logical-plan/src/optimization/rules/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
use common_error::DaftResult;
use common_scan_info::{rewrite_predicate_for_partitioning, PredicateGroups};
use common_treenode::{DynTreeNode, Transformed, TreeNode};
use daft_algebra::boolean::{combine_conjunction, split_conjunction};
use daft_algebra::boolean::{combine_conjunction, split_conjunction, to_cnf};
use daft_core::join::JoinType;
use daft_dsl::{
col,
Expand Down Expand Up @@ -273,7 +273,8 @@ impl PushDownFilter {
let left_cols = HashSet::<_>::from_iter(child_join.left.schema().names());
let right_cols = HashSet::<_>::from_iter(child_join.right.schema().names());

for predicate in split_conjunction(&filter.predicate) {
// TODO: simplify predicates, since they may be expanded with `to_cnf`
for predicate in split_conjunction(&to_cnf(filter.predicate.clone())) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this only applicable to predicates? I'm wondering if this should be part of the SimplifyExprs optimizer pass?

Copy link
Member Author

@kevinzwang kevinzwang Dec 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be a general simplification yes. Just put the comment here because this part specifically expands the predicate

let pred_cols = HashSet::<_>::from_iter(get_required_columns(&predicate));

match (
Expand Down Expand Up @@ -963,4 +964,60 @@ mod tests {
assert_optimized_plan_eq(plan, expected)?;
Ok(())
}

/// Tests that a complex predicate can be separated so that it can be pushed down into one side of the join.
/// Modeled after TPC-H Q7
#[rstest]
fn filter_commutes_with_join_complex() -> DaftResult<()> {
let left_scan_op = dummy_scan_operator(vec![Field::new("a", DataType::Utf8)]);
let right_scan_op = dummy_scan_operator(vec![Field::new("b", DataType::Utf8)]);

let plan = dummy_scan_node(left_scan_op.clone())
.join(
dummy_scan_node(right_scan_op.clone()),
vec![],
vec![],
JoinType::Inner,
None,
None,
None,
false,
)?
.filter(
(col("a").eq(lit("FRANCE")).and(col("b").eq(lit("GERMANY"))))
.or(col("a").eq(lit("GERMANY")).and(col("b").eq(lit("FRANCE")))),
)?
.build();

let expected = dummy_scan_node_with_pushdowns(
left_scan_op,
Pushdowns::default().with_filters(Some(
col("a").eq(lit("FRANCE")).or(col("a").eq(lit("GERMANY"))),
)),
)
.join(
dummy_scan_node_with_pushdowns(
right_scan_op,
Pushdowns::default().with_filters(Some(
col("b").eq(lit("GERMANY")).or(col("b").eq(lit("FRANCE"))),
)),
),
vec![],
vec![],
JoinType::Inner,
None,
None,
None,
false,
)?
.filter(
(col("b").eq(lit("GERMANY")).or(col("a").eq(lit("GERMANY"))))
.and(col("a").eq(lit("FRANCE")).or(col("b").eq(lit("FRANCE")))),
)?
.build();

assert_optimized_plan_eq(plan, expected)?;

Ok(())
}
}
Loading