Skip to content

Commit

Permalink
test_merge_projections
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiayue Charles Lin committed Sep 14, 2023
1 parent dc0aab3 commit c202431
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 5 deletions.
22 changes: 17 additions & 5 deletions src/daft-plan/src/logical_ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ impl Project {
// (a maybe new input node, and a maybe new list of projection expressions).
let upstream_schema = input.schema();
let (projection, substitutions) = Self::factor_expressions(projection, &upstream_schema);

// If there are substitutions to factor out,
// create a child projection node to do the factoring.
let input = if substitutions.is_empty() {
Expand Down Expand Up @@ -317,10 +316,23 @@ impl Project {
let substituted_expressions = exprs
.iter()
.map(|e| {
replace_column_with_semantic_id(e.clone().into(), &subexprs_to_replace, schema)
.unwrap()
.as_ref()
.clone()
let new_expr = replace_column_with_semantic_id(
e.clone().into(),
&subexprs_to_replace,
schema,
)
.unwrap()
.as_ref()
.clone();
// The substitution can unintentionally change the expression's name
// (since the name depends on the first column referenced, which can be substituted away)
// so re-alias the original name here if it has changed.
let old_name = e.name().unwrap();
if new_expr.name().unwrap() != old_name {
Expr::Alias(new_expr.into(), old_name.into())
} else {
new_expr
}
})
.collect::<Vec<_>>();

Expand Down
33 changes: 33 additions & 0 deletions src/daft-plan/src/optimization/rules/push_down_projection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -547,4 +547,37 @@ mod tests {
assert_optimized_plan_eq(unoptimized, expected.as_str())?;
Ok(())
}

/// Projection merging: Ensure merging happens even when there is computation
/// in both the parent and the child.
#[test]
fn test_merge_projections() -> DaftResult<()> {
let unoptimized = dummy_scan_node(vec![
Field::new("a", DataType::Int64),
Field::new("b", DataType::Int64),
])
.project(
vec![
binary_op(Operator::Plus, &col("a"), &lit(1)),
binary_op(Operator::Plus, &col("b"), &lit(2)),
col("a").alias("c"),
],
Default::default(),
)?
.project(
vec![
binary_op(Operator::Plus, &col("a"), &lit(3)),
col("b"),
binary_op(Operator::Plus, &col("c"), &lit(4)),
],
Default::default(),
)?
.build();

let expected = "\
Project: [col(a) + lit(1)] + lit(3), col(b) + lit(2), col(a) + lit(4), Partition spec = PartitionSpec { scheme: Unknown, num_partitions: 1, by: None }\
\n Source: Json, File paths = [/foo], File schema = a (Int64), b (Int64), Format-specific config = Json(JsonSourceConfig), Storage config = Native(NativeStorageConfig { io_config: None }), Output schema = a (Int64), b (Int64)";
assert_optimized_plan_eq(unoptimized, expected)?;
Ok(())
}
}

0 comments on commit c202431

Please sign in to comment.