diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs index 8e69146ea9..bb9c754dbc 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -603,7 +603,8 @@ mod tests { Ok(()) } - /// Test that common leaf expressions are not factored out. + /// Test that common leaf expressions are not factored out + /// (since this would not save computation and only introduces another materialization) /// e.g. /// 3 as x, 3 as y, a as w, a as z /// -> diff --git a/src/daft-plan/src/optimization/rules/push_down_projection.rs b/src/daft-plan/src/optimization/rules/push_down_projection.rs index 78b55655e7..a4259e0c76 100644 --- a/src/daft-plan/src/optimization/rules/push_down_projection.rs +++ b/src/daft-plan/src/optimization/rules/push_down_projection.rs @@ -488,3 +488,63 @@ impl OptimizerRule for PushDownProjection { } } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use common_error::DaftResult; + use daft_core::{datatypes::Field, DataType}; + use daft_dsl::{binary_op, col, lit, Operator}; + + use crate::{ + logical_ops::Project, + optimization::{ + optimizer::{RuleBatch, RuleExecutionStrategy}, + rules::PushDownProjection, + Optimizer, + }, + test::dummy_scan_node, + JoinType, LogicalPlan, PartitionScheme, + }; + + /// Helper that creates an optimizer with the PushDownFilter rule registered, optimizes + /// the provided plan with said optimizer, and compares the optimized plan's repr with + /// the provided expected repr. + fn assert_optimized_plan_eq(plan: Arc, expected: &str) -> DaftResult<()> { + let optimizer = Optimizer::with_rule_batches( + vec![RuleBatch::new( + vec![Box::new(PushDownProjection::new())], + RuleExecutionStrategy::Once, + )], + Default::default(), + ); + let optimized_plan = optimizer + .optimize_with_rules( + optimizer.rule_batches[0].rules.as_slice(), + plan.clone(), + &optimizer.rule_batches[0].order, + )? + .unwrap() + .clone(); + assert_eq!(optimized_plan.repr_indent(), expected); + + Ok(()) + } + + /// Projection merging: Ensure factored projections do not get merged. + #[test] + fn test_merge_does_not_unfactor() -> DaftResult<()> { + let a2 = binary_op(Operator::Plus, &col("a"), &col("a")); + let a4 = binary_op(Operator::Plus, &a2, &a2); + let a8 = binary_op(Operator::Plus, &a4, &a4); + let expressions = vec![a8.alias("x")]; + let unoptimized = dummy_scan_node(vec![Field::new("a", DataType::Int64)]) + .project(expressions, Default::default())? + .build(); + + let expected = unoptimized.repr_indent(); + assert_optimized_plan_eq(unoptimized, expected.as_str())?; + Ok(()) + } +}