Skip to content

Commit

Permalink
test_merge_does_not_unfactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiayue Charles Lin committed Sep 14, 2023
1 parent c5adbeb commit dc0aab3
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/daft-plan/src/logical_ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,8 @@ mod tests {
Ok(())
}

/// Test that common leaf expressions are not factored out.
/// Test that common leaf expressions are not factored out
/// (since this would not save computation and only introduces another materialization)
/// e.g.
/// 3 as x, 3 as y, a as w, a as z
/// ->
Expand Down
60 changes: 60 additions & 0 deletions src/daft-plan/src/optimization/rules/push_down_projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,63 @@ impl OptimizerRule for PushDownProjection {
}
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use common_error::DaftResult;
use daft_core::{datatypes::Field, DataType};
use daft_dsl::{binary_op, col, lit, Operator};

use crate::{
logical_ops::Project,
optimization::{
optimizer::{RuleBatch, RuleExecutionStrategy},
rules::PushDownProjection,
Optimizer,
},
test::dummy_scan_node,
JoinType, LogicalPlan, PartitionScheme,
};

/// Helper that creates an optimizer with the PushDownFilter rule registered, optimizes
/// the provided plan with said optimizer, and compares the optimized plan's repr with
/// the provided expected repr.
fn assert_optimized_plan_eq(plan: Arc<LogicalPlan>, expected: &str) -> DaftResult<()> {
let optimizer = Optimizer::with_rule_batches(
vec![RuleBatch::new(
vec![Box::new(PushDownProjection::new())],
RuleExecutionStrategy::Once,
)],
Default::default(),
);
let optimized_plan = optimizer
.optimize_with_rules(
optimizer.rule_batches[0].rules.as_slice(),
plan.clone(),
&optimizer.rule_batches[0].order,
)?
.unwrap()
.clone();
assert_eq!(optimized_plan.repr_indent(), expected);

Ok(())
}

/// Projection merging: Ensure factored projections do not get merged.
#[test]
fn test_merge_does_not_unfactor() -> DaftResult<()> {
let a2 = binary_op(Operator::Plus, &col("a"), &col("a"));
let a4 = binary_op(Operator::Plus, &a2, &a2);
let a8 = binary_op(Operator::Plus, &a4, &a4);
let expressions = vec![a8.alias("x")];
let unoptimized = dummy_scan_node(vec![Field::new("a", DataType::Int64)])
.project(expressions, Default::default())?
.build();

let expected = unoptimized.repr_indent();
assert_optimized_plan_eq(unoptimized, expected.as_str())?;
Ok(())
}
}

0 comments on commit dc0aab3

Please sign in to comment.