diff --git a/src/daft-algebra/src/boolean.rs b/src/daft-algebra/src/boolean.rs index 38f659e00c..30b4f53392 100644 --- a/src/daft-algebra/src/boolean.rs +++ b/src/daft-algebra/src/boolean.rs @@ -1,4 +1,6 @@ -use common_treenode::{TreeNode, TreeNodeRecursion}; +use std::sync::Arc; + +use common_treenode::{Transformed, TreeNode, TreeNodeRecursion}; use daft_dsl::{Expr, ExprRef, Operator}; pub fn split_conjunction(expr: &ExprRef) -> Vec { @@ -22,3 +24,179 @@ pub fn split_conjunction(expr: &ExprRef) -> Vec { pub fn combine_conjunction>(exprs: T) -> Option { exprs.into_iter().reduce(|acc, e| acc.and(e)) } + +/// Converts a boolean expression to conjunctive normal form (AND of ORs) +pub fn to_cnf(expr: ExprRef) -> ExprRef { + let dnf_form = to_dnf(expr.not()).not(); + + apply_de_morgans(dnf_form).data +} + +/// Converts a boolean expression to disjunctive normal form (OR of ANDs) +pub fn to_dnf(expr: ExprRef) -> ExprRef { + let dm_expr = apply_de_morgans(expr).data; + + // apply distributive property recursively + dm_expr + .transform_down(|e| { + Ok(match e.as_ref() { + Expr::BinaryOp { + op: Operator::And, + left, + right, + } => { + if let Expr::BinaryOp { + op: Operator::Or, + left: right_left, + right: right_right, + } = right.as_ref() + { + // (x & (y | z)) -> ((x & y) | (x & z)) + Transformed::yes(Arc::new(Expr::BinaryOp { + op: Operator::Or, + left: Arc::new(Expr::BinaryOp { + op: Operator::And, + left: left.clone(), + right: right_left.clone(), + }), + right: Arc::new(Expr::BinaryOp { + op: Operator::And, + left: left.clone(), + right: right_right.clone(), + }), + })) + } else if let Expr::BinaryOp { + op: Operator::Or, + left: left_left, + right: left_right, + } = left.as_ref() + { + // ((x | y) & z) -> ((x & z) | (y & z)) + Transformed::yes(Arc::new(Expr::BinaryOp { + op: Operator::Or, + left: Arc::new(Expr::BinaryOp { + op: Operator::And, + left: left_left.clone(), + right: right.clone(), + }), + right: Arc::new(Expr::BinaryOp { + op: Operator::And, + left: left_right.clone(), + right: right.clone(), + }), + })) + } else { + Transformed::no(e) + } + } + _ => Transformed::no(e), + }) + }) + .unwrap() + .data +} + +/// Transform boolean expression by applying De Morgan's law + eliminate double negations recursively +fn apply_de_morgans(expr: ExprRef) -> Transformed { + expr.transform_down(|e| { + Ok(match e.as_ref() { + Expr::Not(ne) => match ne.as_ref() { + // !x -> x + Expr::Not(nne) => Transformed::yes(nne.clone()), + // !(x & y) -> ((!x) | (!y)) + Expr::BinaryOp { + op: Operator::And, + left, + right, + } => Transformed::yes(Arc::new(Expr::BinaryOp { + op: Operator::Or, + left: left.clone().not(), + right: right.clone().not(), + })), + // !(x | y) -> ((!x) & (!y)) + Expr::BinaryOp { + op: Operator::Or, + left, + right, + } => Transformed::yes(Arc::new(Expr::BinaryOp { + op: Operator::And, + left: left.clone().not(), + right: right.clone().not(), + })), + _ => Transformed::no(e), + }, + _ => Transformed::no(e), + }) + }) + .unwrap() +} + +#[cfg(test)] +mod tests { + use daft_dsl::col; + + use crate::boolean::{to_cnf, to_dnf}; + + #[test] + fn dnf_simple() { + // a & (b | c) -> (a & b) | (a & c) + let expr = col("a").and(col("b").or(col("c"))); + let expected = col("a").and(col("b")).or(col("a").and(col("c"))); + + assert_eq!(expected, to_dnf(expr)); + } + + #[test] + fn cnf_simple() { + // a | (b & c) -> (a | b) & (a | c) + let expr = col("a").or(col("b").and(col("c"))); + let expected = col("a").or(col("b")).and(col("a").or(col("c"))); + + assert_eq!(expected, to_cnf(expr)); + } + + #[test] + fn dnf_neg() { + // !(a & ((!b) | c)) -> (!a) | (b & (!c)) + let expr = col("a").and(col("b").not().or(col("c"))).not(); + let expected = col("a").not().or(col("b").and(col("c").not())); + + assert_eq!(expected, to_dnf(expr)); + } + + #[test] + fn cnf_neg() { + // !(a | ((!b) & c)) -> (!a) & (b | (!c)) + let expr = col("a").or(col("b").not().and(col("c"))).not(); + let expected = col("a").not().and(col("b").or(col("c").not())); + + assert_eq!(expected, to_cnf(expr)); + } + + #[test] + fn dnf_nested() { + // a & b & ((c & d) | (e & f)) -> (a & b & c & d) | (a & b & e & f) + let expr = col("a") + .and(col("b")) + .and((col("c").and(col("d"))).or(col("e").and(col("f")))); + let expected = (col("a").and(col("b")).and(col("c").and(col("d")))) + .or(col("a").and(col("b")).and(col("e").and(col("f")))); + + assert_eq!(expected, to_dnf(expr)); + } + + #[test] + fn cnf_nested() { + // a & b & ((c & d) | (e & f)) -> a & b & (c | e) & (c | f) & (d | e) & (d | f) + let expr = col("a") + .and(col("b")) + .and((col("c").and(col("d"))).or(col("e").and(col("f")))); + let expected = col("a").and(col("b")).and( + (col("c").or(col("e"))) + .and(col("d").or(col("e"))) + .and(col("c").or(col("f")).and(col("d").or(col("f")))), + ); + + assert_eq!(expected, to_cnf(expr)); + } +} diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs index 6e5be33c40..442d8120dc 100644 --- a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs +++ b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs @@ -6,7 +6,7 @@ use std::{ use common_error::DaftResult; use common_scan_info::{rewrite_predicate_for_partitioning, PredicateGroups}; use common_treenode::{DynTreeNode, Transformed, TreeNode}; -use daft_algebra::boolean::{combine_conjunction, split_conjunction}; +use daft_algebra::boolean::{combine_conjunction, split_conjunction, to_cnf}; use daft_core::join::JoinType; use daft_dsl::{ col, @@ -273,7 +273,8 @@ impl PushDownFilter { let left_cols = HashSet::<_>::from_iter(child_join.left.schema().names()); let right_cols = HashSet::<_>::from_iter(child_join.right.schema().names()); - for predicate in split_conjunction(&filter.predicate) { + // TODO: simplify predicates, since they may be expanded with `to_cnf` + for predicate in split_conjunction(&to_cnf(filter.predicate.clone())) { let pred_cols = HashSet::<_>::from_iter(get_required_columns(&predicate)); match ( @@ -963,4 +964,60 @@ mod tests { assert_optimized_plan_eq(plan, expected)?; Ok(()) } + + /// Tests that a complex predicate can be separated so that it can be pushed down into one side of the join. + /// Modeled after TPC-H Q7 + #[rstest] + fn filter_commutes_with_join_complex() -> DaftResult<()> { + let left_scan_op = dummy_scan_operator(vec![Field::new("a", DataType::Utf8)]); + let right_scan_op = dummy_scan_operator(vec![Field::new("b", DataType::Utf8)]); + + let plan = dummy_scan_node(left_scan_op.clone()) + .join( + dummy_scan_node(right_scan_op.clone()), + vec![], + vec![], + JoinType::Inner, + None, + None, + None, + false, + )? + .filter( + (col("a").eq(lit("FRANCE")).and(col("b").eq(lit("GERMANY")))) + .or(col("a").eq(lit("GERMANY")).and(col("b").eq(lit("FRANCE")))), + )? + .build(); + + let expected = dummy_scan_node_with_pushdowns( + left_scan_op, + Pushdowns::default().with_filters(Some( + col("a").eq(lit("FRANCE")).or(col("a").eq(lit("GERMANY"))), + )), + ) + .join( + dummy_scan_node_with_pushdowns( + right_scan_op, + Pushdowns::default().with_filters(Some( + col("b").eq(lit("GERMANY")).or(col("b").eq(lit("FRANCE"))), + )), + ), + vec![], + vec![], + JoinType::Inner, + None, + None, + None, + false, + )? + .filter( + (col("b").eq(lit("GERMANY")).or(col("a").eq(lit("GERMANY")))) + .and(col("a").eq(lit("FRANCE")).or(col("b").eq(lit("FRANCE")))), + )? + .build(); + + assert_optimized_plan_eq(plan, expected)?; + + Ok(()) + } }