diff --git a/src/daft-plan/src/logical_ops/project.rs b/src/daft-plan/src/logical_ops/project.rs index bb9c754dbc..7b7218e569 100644 --- a/src/daft-plan/src/logical_ops/project.rs +++ b/src/daft-plan/src/logical_ops/project.rs @@ -226,7 +226,6 @@ impl Project { // (a maybe new input node, and a maybe new list of projection expressions). let upstream_schema = input.schema(); let (projection, substitutions) = Self::factor_expressions(projection, &upstream_schema); - // If there are substitutions to factor out, // create a child projection node to do the factoring. let input = if substitutions.is_empty() { @@ -317,10 +316,23 @@ impl Project { let substituted_expressions = exprs .iter() .map(|e| { - replace_column_with_semantic_id(e.clone().into(), &subexprs_to_replace, schema) - .unwrap() - .as_ref() - .clone() + let new_expr = replace_column_with_semantic_id( + e.clone().into(), + &subexprs_to_replace, + schema, + ) + .unwrap() + .as_ref() + .clone(); + // The substitution can unintentionally change the expression's name + // (since the name depends on the first column referenced, which can be substituted away) + // so re-alias the original name here if it has changed. + let old_name = e.name().unwrap(); + if new_expr.name().unwrap() != old_name { + Expr::Alias(new_expr.into(), old_name.into()) + } else { + new_expr + } }) .collect::>(); diff --git a/src/daft-plan/src/optimization/rules/push_down_projection.rs b/src/daft-plan/src/optimization/rules/push_down_projection.rs index a4259e0c76..dfef1fef25 100644 --- a/src/daft-plan/src/optimization/rules/push_down_projection.rs +++ b/src/daft-plan/src/optimization/rules/push_down_projection.rs @@ -547,4 +547,37 @@ mod tests { assert_optimized_plan_eq(unoptimized, expected.as_str())?; Ok(()) } + + /// Projection merging: Ensure merging happens even when there is computation + /// in both the parent and the child. + #[test] + fn test_merge_projections() -> DaftResult<()> { + let unoptimized = dummy_scan_node(vec![ + Field::new("a", DataType::Int64), + Field::new("b", DataType::Int64), + ]) + .project( + vec![ + binary_op(Operator::Plus, &col("a"), &lit(1)), + binary_op(Operator::Plus, &col("b"), &lit(2)), + col("a").alias("c"), + ], + Default::default(), + )? + .project( + vec![ + binary_op(Operator::Plus, &col("a"), &lit(3)), + col("b"), + binary_op(Operator::Plus, &col("c"), &lit(4)), + ], + Default::default(), + )? + .build(); + + let expected = "\ + Project: [col(a) + lit(1)] + lit(3), col(b) + lit(2), col(a) + lit(4), Partition spec = PartitionSpec { scheme: Unknown, num_partitions: 1, by: None }\ + \n Source: Json, File paths = [/foo], File schema = a (Int64), b (Int64), Format-specific config = Json(JsonSourceConfig), Storage config = Native(NativeStorageConfig { io_config: None }), Output schema = a (Int64), b (Int64)"; + assert_optimized_plan_eq(unoptimized, expected)?; + Ok(()) + } }