From 67cd40cb604bb1bd9f75d5eacd017276a1429ec4 Mon Sep 17 00:00:00 2001 From: desmondcheongzx Date: Fri, 22 Nov 2024 18:26:58 -0800 Subject: [PATCH] Move materialized scan tasks into Source nodes --- src/common/scan-info/src/lib.rs | 30 +++++++++- src/daft-local-plan/src/translate.rs | 29 +++------ src/daft-logical-plan/src/logical_plan.rs | 12 ---- .../src/ops/materialized_scan_source.rs | 59 ------------------- src/daft-logical-plan/src/ops/mod.rs | 2 - src/daft-logical-plan/src/ops/source.rs | 48 ++++++++------- .../optimization/rules/enrich_with_stats.rs | 1 + .../optimization/rules/materialize_scans.rs | 15 ++--- .../optimization/rules/push_down_filter.rs | 2 +- .../src/optimization/rules/push_down_limit.rs | 7 ++- .../rules/push_down_projection.rs | 3 - .../src/physical_planner/translate.rs | 39 ++++-------- 12 files changed, 87 insertions(+), 160 deletions(-) delete mode 100644 src/daft-logical-plan/src/ops/materialized_scan_source.rs diff --git a/src/common/scan-info/src/lib.rs b/src/common/scan-info/src/lib.rs index ba3a201614..ac33fcbb62 100644 --- a/src/common/scan-info/src/lib.rs +++ b/src/common/scan-info/src/lib.rs @@ -21,9 +21,33 @@ pub use python::register_modules; pub use scan_operator::{ScanOperator, ScanOperatorRef}; pub use scan_task::{BoxScanTaskLikeIter, ScanTaskLike, ScanTaskLikeRef}; +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ScanState { + Operator(ScanOperatorRef), + Tasks(Vec), +} + +impl ScanState { + pub fn multiline_display(&self) -> Vec { + match self { + Self::Operator(scan_op) => scan_op.0.multiline_display(), + Self::Tasks(scan_tasks) => { + vec![format!("Num Scan Tasks = {}", scan_tasks.len())] + } + } + } + + pub fn get_scan_op(&self) -> &ScanOperatorRef { + match self { + Self::Operator(scan_op) => scan_op, + Self::Tasks(_) => panic!("Tried to get scan op from materialized physical scan info"), + } + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct PhysicalScanInfo { - pub scan_op: ScanOperatorRef, + pub scan_state: ScanState, pub source_schema: SchemaRef, pub partitioning_keys: Vec, pub pushdowns: Pushdowns, @@ -38,7 +62,7 @@ impl PhysicalScanInfo { pushdowns: Pushdowns, ) -> Self { Self { - scan_op, + scan_state: ScanState::Operator(scan_op), source_schema, partitioning_keys, pushdowns, @@ -48,7 +72,7 @@ impl PhysicalScanInfo { #[must_use] pub fn with_pushdowns(&self, pushdowns: Pushdowns) -> Self { Self { - scan_op: self.scan_op.clone(), + scan_state: self.scan_state.clone(), source_schema: self.source_schema.clone(), partitioning_keys: self.partitioning_keys.clone(), pushdowns, diff --git a/src/daft-local-plan/src/translate.rs b/src/daft-local-plan/src/translate.rs index f6e51e0576..29e4081378 100644 --- a/src/daft-local-plan/src/translate.rs +++ b/src/daft-local-plan/src/translate.rs @@ -1,9 +1,8 @@ use common_error::{DaftError, DaftResult}; +use common_scan_info::ScanState; use daft_core::join::JoinStrategy; use daft_dsl::ExprRef; -use daft_logical_plan::{ - ops::MaterializedScanSource, JoinType, LogicalPlan, LogicalPlanRef, SourceInfo, -}; +use daft_logical_plan::{JoinType, LogicalPlan, LogicalPlanRef, SourceInfo}; use super::plan::{LocalPhysicalPlan, LocalPhysicalPlanRef}; @@ -17,7 +16,12 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { )), SourceInfo::Physical(info) => { // We should be able to pass the ScanOperator into the physical plan directly but we need to figure out the serialization story - let scan_tasks = info.scan_op.0.to_scan_tasks(info.pushdowns.clone(), None)?; + let scan_tasks = match &info.scan_state { + ScanState::Operator(scan_op) => { + scan_op.0.to_scan_tasks(info.pushdowns.clone(), None)? + } + ScanState::Tasks(scan_tasks) => scan_tasks.clone(), + }; if scan_tasks.is_empty() { Ok(LocalPhysicalPlan::empty_scan(source.output_schema.clone())) } else { @@ -34,23 +38,6 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { } } } - LogicalPlan::MaterializedScanSource(MaterializedScanSource { - scan_tasks, - pushdowns, - schema, - stats_state, - }) => { - if scan_tasks.is_empty() { - Ok(LocalPhysicalPlan::empty_scan(schema.clone())) - } else { - Ok(LocalPhysicalPlan::physical_scan( - scan_tasks.clone(), - pushdowns.clone(), - schema.clone(), - stats_state.clone(), - )) - } - } LogicalPlan::Filter(filter) => { let input = translate(&filter.input)?; Ok(LocalPhysicalPlan::filter( diff --git a/src/daft-logical-plan/src/logical_plan.rs b/src/daft-logical-plan/src/logical_plan.rs index f18bc7d1f9..cf013d8295 100644 --- a/src/daft-logical-plan/src/logical_plan.rs +++ b/src/daft-logical-plan/src/logical_plan.rs @@ -14,7 +14,6 @@ use crate::stats::StatsState; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum LogicalPlan { Source(Source), - MaterializedScanSource(MaterializedScanSource), Project(Project), ActorPoolProject(ActorPoolProject), Filter(Filter), @@ -45,7 +44,6 @@ impl LogicalPlan { pub fn schema(&self) -> SchemaRef { match self { Self::Source(Source { output_schema, .. }) => output_schema.clone(), - Self::MaterializedScanSource(MaterializedScanSource { schema, .. }) => schema.clone(), Self::Project(Project { projected_schema, .. }) => projected_schema.clone(), @@ -173,7 +171,6 @@ impl LogicalPlan { Self::Intersect(_) => vec![IndexSet::new(), IndexSet::new()], Self::Union(_) => vec![IndexSet::new(), IndexSet::new()], Self::Source(_) => todo!(), - Self::MaterializedScanSource(_) => todo!(), Self::Sink(_) => todo!(), } } @@ -181,7 +178,6 @@ impl LogicalPlan { pub fn name(&self) -> &'static str { match self { Self::Source(..) => "Source", - Self::MaterializedScanSource(..) => "MaterializedScanSource", Self::Project(..) => "Project", Self::ActorPoolProject(..) => "ActorPoolProject", Self::Filter(..) => "Filter", @@ -206,7 +202,6 @@ impl LogicalPlan { pub fn get_stats(&self) -> &StatsState { match self { Self::Source(Source { stats_state, .. }) - | Self::MaterializedScanSource(MaterializedScanSource { stats_state, .. }) | Self::Project(Project { stats_state, .. }) | Self::ActorPoolProject(ActorPoolProject { stats_state, .. }) | Self::Filter(Filter { stats_state, .. }) @@ -239,9 +234,6 @@ impl LogicalPlan { pub fn with_materialized_stats(self) -> Self { match self { Self::Source(plan) => Self::Source(plan.with_materialized_stats()), - Self::MaterializedScanSource(plan) => { - Self::MaterializedScanSource(plan.with_materialized_stats()) - } Self::Project(plan) => Self::Project(plan.with_materialized_stats()), Self::ActorPoolProject(plan) => Self::ActorPoolProject(plan.with_materialized_stats()), Self::Filter(plan) => Self::Filter(plan.with_materialized_stats()), @@ -272,7 +264,6 @@ impl LogicalPlan { pub fn multiline_display(&self) -> Vec { match self { Self::Source(source) => source.multiline_display(), - Self::MaterializedScanSource(plan) => plan.multiline_display(), Self::Project(projection) => projection.multiline_display(), Self::ActorPoolProject(projection) => projection.multiline_display(), Self::Filter(filter) => filter.multiline_display(), @@ -299,7 +290,6 @@ impl LogicalPlan { pub fn children(&self) -> Vec<&Self> { match self { Self::Source(..) => vec![], - Self::MaterializedScanSource(..) => vec![], Self::Project(Project { input, .. }) => vec![input], Self::ActorPoolProject(ActorPoolProject { input, .. }) => vec![input], Self::Filter(Filter { input, .. }) => vec![input], @@ -327,7 +317,6 @@ impl LogicalPlan { match children { [input] => match self { Self::Source(_) => panic!("Source nodes don't have children, with_new_children() should never be called for Source ops"), - Self::MaterializedScanSource(_) => panic!("MaterializedScanSource nodes don't have children, with_new_children() should never be called for MaterializedScanSource ops"), Self::Project(Project { projection, .. }) => Self::Project(Project::try_new( input.clone(), projection.clone(), ).unwrap()), @@ -460,7 +449,6 @@ macro_rules! impl_from_data_struct_for_logical_plan { } impl_from_data_struct_for_logical_plan!(Source); -impl_from_data_struct_for_logical_plan!(MaterializedScanSource); impl_from_data_struct_for_logical_plan!(Project); impl_from_data_struct_for_logical_plan!(Filter); impl_from_data_struct_for_logical_plan!(Limit); diff --git a/src/daft-logical-plan/src/ops/materialized_scan_source.rs b/src/daft-logical-plan/src/ops/materialized_scan_source.rs deleted file mode 100644 index 29e0267c9d..0000000000 --- a/src/daft-logical-plan/src/ops/materialized_scan_source.rs +++ /dev/null @@ -1,59 +0,0 @@ -use common_scan_info::{Pushdowns, ScanTaskLikeRef}; -use daft_schema::schema::SchemaRef; - -use crate::stats::{ApproxStats, PlanStats, StatsState}; - -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub struct MaterializedScanSource { - pub scan_tasks: Vec, - pub pushdowns: Pushdowns, - pub schema: SchemaRef, - pub stats_state: StatsState, -} - -impl MaterializedScanSource { - pub fn new(scan_tasks: Vec, pushdowns: Pushdowns, schema: SchemaRef) -> Self { - Self { - scan_tasks, - pushdowns, - schema, - stats_state: StatsState::NotMaterialized, - } - } - - pub(crate) fn with_materialized_stats(mut self) -> Self { - let mut approx_stats = ApproxStats::empty(); - for st in &self.scan_tasks { - approx_stats.lower_bound_rows += st.num_rows().unwrap_or(0); - let in_memory_size = st.estimate_in_memory_size_bytes(None); - approx_stats.lower_bound_bytes += in_memory_size.unwrap_or(0); - if let Some(st_ub) = st.upper_bound_rows() { - if let Some(ub) = approx_stats.upper_bound_rows { - approx_stats.upper_bound_rows = Some(ub + st_ub); - } else { - approx_stats.upper_bound_rows = st.upper_bound_rows(); - } - } - if let Some(st_ub) = in_memory_size { - if let Some(ub) = approx_stats.upper_bound_bytes { - approx_stats.upper_bound_bytes = Some(ub + st_ub); - } else { - approx_stats.upper_bound_bytes = in_memory_size; - } - } - } - self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats)); - self - } - - pub fn multiline_display(&self) -> Vec { - let mut res = vec![]; - res.push(format!("Num Scan Tasks = [{}]", self.scan_tasks.len())); - res.extend(self.pushdowns.multiline_display()); - res.push(format!("Output schema = {}", self.schema.short_string())); - if let StatsState::Materialized(stats) = &self.stats_state { - res.push(format!("Stats = {}", stats)); - } - res - } -} diff --git a/src/daft-logical-plan/src/ops/mod.rs b/src/daft-logical-plan/src/ops/mod.rs index ef32df0fcb..e70c5c98d8 100644 --- a/src/daft-logical-plan/src/ops/mod.rs +++ b/src/daft-logical-plan/src/ops/mod.rs @@ -6,7 +6,6 @@ mod explode; mod filter; mod join; mod limit; -mod materialized_scan_source; mod monotonically_increasing_id; mod pivot; mod project; @@ -26,7 +25,6 @@ pub use explode::Explode; pub use filter::Filter; pub use join::Join; pub use limit::Limit; -pub use materialized_scan_source::MaterializedScanSource; pub use monotonically_increasing_id::MonotonicallyIncreasingId; pub use pivot::Pivot; pub use project::Project; diff --git a/src/daft-logical-plan/src/ops/source.rs b/src/daft-logical-plan/src/ops/source.rs index b8de63ebf1..72974291ed 100644 --- a/src/daft-logical-plan/src/ops/source.rs +++ b/src/daft-logical-plan/src/ops/source.rs @@ -1,11 +1,10 @@ use std::sync::Arc; use common_daft_config::DaftExecutionConfig; -use common_scan_info::PhysicalScanInfo; +use common_scan_info::{PhysicalScanInfo, ScanState}; use daft_schema::schema::SchemaRef; use crate::{ - ops::MaterializedScanSource, source_info::{InMemoryInfo, PlaceHolderInfo, SourceInfo}, stats::{ApproxStats, PlanStats, StatsState}, }; @@ -30,26 +29,33 @@ impl Source { } } - pub(crate) fn with_materialized_scan_source( - &self, + pub(crate) fn build_materialized_scan_source( + mut self, execution_config: Option<&DaftExecutionConfig>, - ) -> MaterializedScanSource { - match &*self.source_info { - SourceInfo::Physical(PhysicalScanInfo { - scan_op, pushdowns, .. - }) => { - let scan_tasks = scan_op - .0 - .to_scan_tasks(pushdowns.clone(), execution_config) - .expect("Failed to get scan tasks from scan operator"); - MaterializedScanSource::new( - scan_tasks, - pushdowns.clone(), - self.output_schema.clone(), - ) + ) -> Self { + if let Some(scan_info) = Arc::get_mut(&mut self.source_info) { + match scan_info { + SourceInfo::Physical(physical_scan_info) => { + match &mut physical_scan_info.scan_state { + ScanState::Operator(scan_op) => { + let scan_tasks = scan_op + .0 + .to_scan_tasks( + physical_scan_info.pushdowns.clone(), + execution_config, + ) + .expect("Failed to get scan tasks from scan operator"); + physical_scan_info.scan_state = ScanState::Tasks(scan_tasks); + } + ScanState::Tasks(_) => { + panic!("Physical scan nodes are being materialized more than once"); + } + } + } + _ => panic!("Only unmaterialized physical scan nodes can be materialized"), } - _ => panic!("Only physical scan nodes can be materialized"), } + self } pub(crate) fn with_materialized_stats(mut self) -> Self { @@ -79,12 +85,12 @@ impl Source { match self.source_info.as_ref() { SourceInfo::Physical(PhysicalScanInfo { source_schema, - scan_op, + scan_state: scan_op, partitioning_keys, pushdowns, }) => { use itertools::Itertools; - res.extend(scan_op.0.multiline_display()); + res.extend(scan_op.multiline_display()); res.push(format!("File schema = {}", source_schema.short_string())); res.push(format!( diff --git a/src/daft-logical-plan/src/optimization/rules/enrich_with_stats.rs b/src/daft-logical-plan/src/optimization/rules/enrich_with_stats.rs index f582003589..149c141ff0 100644 --- a/src/daft-logical-plan/src/optimization/rules/enrich_with_stats.rs +++ b/src/daft-logical-plan/src/optimization/rules/enrich_with_stats.rs @@ -15,6 +15,7 @@ use super::OptimizerRule; use crate::LogicalPlan; // Add stats to all logical plan nodes in a bottom up fashion. +// All scan nodes MUST be materialized before stats are enriched. impl OptimizerRule for EnrichWithStats { fn try_optimize(&self, plan: Arc) -> DaftResult>> { plan.transform_up(|c| { diff --git a/src/daft-logical-plan/src/optimization/rules/materialize_scans.rs b/src/daft-logical-plan/src/optimization/rules/materialize_scans.rs index 94c89fbd58..45148c734d 100644 --- a/src/daft-logical-plan/src/optimization/rules/materialize_scans.rs +++ b/src/daft-logical-plan/src/optimization/rules/materialize_scans.rs @@ -20,26 +20,23 @@ use crate::{LogicalPlan, SourceInfo}; // Add stats to all logical plan nodes in a bottom up fashion. impl OptimizerRule for MaterializeScans { fn try_optimize(&self, plan: Arc) -> DaftResult>> { - plan.transform_up(|node| self.try_optimize_node(node)) + plan.transform_up(|node| self.try_optimize_node(Arc::unwrap_or_clone(node))) } } impl MaterializeScans { #[allow(clippy::only_used_in_recursion)] - fn try_optimize_node( - &self, - plan: Arc, - ) -> DaftResult>> { - match plan.as_ref() { + fn try_optimize_node(&self, plan: LogicalPlan) -> DaftResult>> { + match plan { LogicalPlan::Source(source) => match &*source.source_info { SourceInfo::Physical(_) => Ok(Transformed::yes( source - .with_materialized_scan_source(self.execution_config.as_deref()) + .build_materialized_scan_source(self.execution_config.as_deref()) .into(), )), - _ => Ok(Transformed::no(plan)), + _ => Ok(Transformed::no(Arc::new(LogicalPlan::Source(source)))), }, - _ => Ok(Transformed::no(plan)), + _ => Ok(Transformed::no(Arc::new(plan))), } } } diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs index 468f46d664..9fa30ea8e5 100644 --- a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs +++ b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs @@ -113,7 +113,7 @@ impl PushDownFilter { needing_filter_op, } = rewrite_predicate_for_partitioning( &new_predicate, - external_info.scan_op.0.partitioning_keys(), + external_info.scan_state.get_scan_op().0.partitioning_keys(), )?; assert!( partition_only_filter.len() diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_limit.rs b/src/daft-logical-plan/src/optimization/rules/push_down_limit.rs index 4edf53233e..e79879604d 100644 --- a/src/daft-logical-plan/src/optimization/rules/push_down_limit.rs +++ b/src/daft-logical-plan/src/optimization/rules/push_down_limit.rs @@ -75,7 +75,12 @@ impl PushDownLimit { SourceInfo::Physical(new_external_info).into(), )) .into(); - let out_plan = if external_info.scan_op.0.can_absorb_limit() { + let out_plan = if external_info + .scan_state + .get_scan_op() + .0 + .can_absorb_limit() + { new_source } else { plan.with_new_children(&[new_source]).into() diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs b/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs index b388f88b3f..7c2391ccce 100644 --- a/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs +++ b/src/daft-logical-plan/src/optimization/rules/push_down_projection.rs @@ -176,9 +176,6 @@ impl PushDownProjection { } } } - LogicalPlan::MaterializedScanSource(..) => { - panic!("Scan nodes should not be materialized before push down projection"); - } LogicalPlan::Project(upstream_projection) => { // Prune columns from the child projection that are not used in this projection. let required_columns = &plan.required_columns()[0]; diff --git a/src/daft-physical-plan/src/physical_planner/translate.rs b/src/daft-physical-plan/src/physical_planner/translate.rs index 16284223f4..59a14ef5a0 100644 --- a/src/daft-physical-plan/src/physical_planner/translate.rs +++ b/src/daft-physical-plan/src/physical_planner/translate.rs @@ -7,7 +7,7 @@ use std::{ use common_daft_config::DaftExecutionConfig; use common_error::{DaftError, DaftResult}; use common_file_formats::FileFormat; -use common_scan_info::PhysicalScanInfo; +use common_scan_info::{PhysicalScanInfo, ScanState}; use daft_core::prelude::*; use daft_dsl::{ col, functions::agg::merge_mean, is_partition_compatible, AggExpr, ApproxPercentileParams, @@ -19,7 +19,7 @@ use daft_logical_plan::{ ops::{ ActorPoolProject as LogicalActorPoolProject, Aggregate as LogicalAggregate, Distinct as LogicalDistinct, Explode as LogicalExplode, Filter as LogicalFilter, - Join as LogicalJoin, Limit as LogicalLimit, MaterializedScanSource, + Join as LogicalJoin, Limit as LogicalLimit, MonotonicallyIncreasingId as LogicalMonotonicallyIncreasingId, Pivot as LogicalPivot, Project as LogicalProject, Repartition as LogicalRepartition, Sample as LogicalSample, Sink as LogicalSink, Sort as LogicalSort, Source, Unpivot as LogicalUnpivot, @@ -42,11 +42,18 @@ pub(super) fn translate_single_logical_node( LogicalPlan::Source(Source { source_info, .. }) => match source_info.as_ref() { SourceInfo::Physical(PhysicalScanInfo { pushdowns, - scan_op, + scan_state, source_schema, .. }) => { - let scan_tasks = scan_op.0.to_scan_tasks(pushdowns.clone(), Some(cfg))?; + let scan_tasks = { + match scan_state { + ScanState::Operator(scan_op) => { + scan_op.0.to_scan_tasks(pushdowns.clone(), Some(cfg))? + } + ScanState::Tasks(scan_tasks) => scan_tasks.clone(), + } + }; if scan_tasks.is_empty() { let clustering_spec = @@ -86,30 +93,6 @@ pub(super) fn translate_single_logical_node( panic!("Placeholder {source_id} should not get to translation. This should have been optimized away"); } }, - LogicalPlan::MaterializedScanSource(MaterializedScanSource { - scan_tasks, schema, .. - }) => { - if scan_tasks.is_empty() { - let clustering_spec = - Arc::new(ClusteringSpec::Unknown(UnknownClusteringConfig::new(1))); - - Ok( - PhysicalPlan::EmptyScan(EmptyScan::new(schema.clone(), clustering_spec)) - .arced(), - ) - } else { - let clustering_spec = Arc::new(ClusteringSpec::Unknown( - UnknownClusteringConfig::new(scan_tasks.len()), - )); - Ok( - PhysicalPlan::TabularScan(TabularScan::new( - scan_tasks.clone(), - clustering_spec, - )) - .arced(), - ) - } - } LogicalPlan::Project(LogicalProject { projection, .. }) => { let input_physical = physical_children.pop().expect("requires 1 input"); Ok(