diff --git a/src/daft-logical-plan/src/builder.rs b/src/daft-logical-plan/src/builder.rs index ece988e08c..d82186c60e 100644 --- a/src/daft-logical-plan/src/builder.rs +++ b/src/daft-logical-plan/src/builder.rs @@ -464,7 +464,8 @@ impl LogicalPlanBuilder { } pub fn union(&self, other: &Self, is_all: bool) -> DaftResult { let logical_plan: LogicalPlan = - ops::Union::new(self.plan.clone(), other.plan.clone(), is_all).to_logical_plan()?; + ops::Union::try_new(self.plan.clone(), other.plan.clone(), is_all)? + .to_logical_plan()?; Ok(self.with_new_plan(logical_plan)) } diff --git a/src/daft-logical-plan/src/logical_plan.rs b/src/daft-logical-plan/src/logical_plan.rs index f1fc3bbd8c..71f6d5bc96 100644 --- a/src/daft-logical-plan/src/logical_plan.rs +++ b/src/daft-logical-plan/src/logical_plan.rs @@ -279,7 +279,7 @@ impl LogicalPlan { Self::Source(_) => panic!("Source nodes don't have children, with_new_children() should never be called for Source ops"), Self::Concat(_) => Self::Concat(Concat::try_new(input1.clone(), input2.clone()).unwrap()), Self::Intersect(inner) => Self::Intersect(Intersect::try_new(input1.clone(), input2.clone(), inner.is_all).unwrap()), - Self::Union(inner) => Self::Union(Union::new(input1.clone(), input2.clone(), inner.is_all)), + Self::Union(inner) => Self::Union(Union::try_new(input1.clone(), input2.clone(), inner.is_all).unwrap()), Self::Join(Join { left_on, right_on, null_equals_nulls, join_type, join_strategy, .. }) => Self::Join(Join::try_new( input1.clone(), input2.clone(), diff --git a/src/daft-logical-plan/src/ops/set_operations.rs b/src/daft-logical-plan/src/ops/set_operations.rs index 9a37cfebfa..017104f226 100644 --- a/src/daft-logical-plan/src/ops/set_operations.rs +++ b/src/daft-logical-plan/src/ops/set_operations.rs @@ -128,8 +128,23 @@ impl Union { /// > select * from t0 union select * from t1; /// ``` /// This is valid in Union, but not in Concat - pub(crate) fn new(lhs: Arc, rhs: Arc, is_all: bool) -> Self { - Self { lhs, rhs, is_all } + pub(crate) fn try_new( + lhs: Arc, + rhs: Arc, + is_all: bool, + ) -> logical_plan::Result { + if lhs.schema().len() != rhs.schema().len() { + return Err(DaftError::SchemaMismatch(format!( + "Both plans must have the same num of fields to union, \ + but got[lhs: {} v.s rhs: {}], lhs schema: {}, rhs schema: {}", + lhs.schema().len(), + rhs.schema().len(), + lhs.schema(), + rhs.schema() + ))) + .context(CreationSnafu); + } + Ok(Self { lhs, rhs, is_all }) } /// union could be represented as a concat + distinct diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs b/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs index e2a237d0fd..5451106d2d 100644 --- a/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs +++ b/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs @@ -417,6 +417,11 @@ impl PushDownProjection { Ok(new_plan) } LogicalPlan::Union(union) => { + if !union.is_all { + // can not push down past a DISTINCT + return Ok(Transformed::no(plan)); + } + // Get required columns from projection and upstream. let combined_dependencies = plan .required_columns() diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 3aeb60768f..1dbf6512a2 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -236,7 +236,7 @@ impl SQLPlanner { return left.union(&right, true).map_err(|e| e.into()); } - (Union, SetQuantifier::None) => { + (Union, SetQuantifier::None | SetQuantifier::Distinct) => { let left = self.plan_query(&make_query(left))?; let right = self.plan_query(&make_query(right))?; return left.union(&right, false).map_err(|e| e.into()); @@ -247,7 +247,7 @@ impl SQLPlanner { let right = self.plan_query(&make_query(right))?; return left.intersect(&right, true).map_err(|e| e.into()); } - (Intersect, SetQuantifier::None) => { + (Intersect, SetQuantifier::None | SetQuantifier::Distinct) => { let left = self.plan_query(&make_query(left))?; let right = self.plan_query(&make_query(right))?; return left.intersect(&right, false).map_err(|e| e.into());