diff --git a/Cargo.lock b/Cargo.lock index fd22dcfa10..0771b3414b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1843,6 +1843,16 @@ dependencies = [ "memchr", ] +[[package]] +name = "ctor" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a2785755761f3ddc1492979ce1e48d2c00d09311c39e4466429188f3dd6501" +dependencies = [ + "quote", + "syn 2.0.87", +] + [[package]] name = "daft" version = "0.3.0-dev0" @@ -2398,6 +2408,7 @@ dependencies = [ "common-py-serde", "common-runtime", "common-scan-info", + "ctor", "daft-core", "daft-csv", "daft-decoding", diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index f971a078ad..77528da220 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -184,9 +184,10 @@ def explain( return None def num_partitions(self) -> int: - daft_execution_config = get_context().daft_execution_config # We need to run the optimizer since that could change the number of partitions - return self.__builder.optimize().to_physical_plan_scheduler(daft_execution_config).num_partitions() + return ( + self.__builder.optimize().to_physical_plan_scheduler(get_context().daft_execution_config).num_partitions() + ) @DataframePublicAPI def schema(self) -> Schema: diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs index bbd53b6024..ac8600936d 100644 --- a/src/common/daft-config/src/lib.rs +++ b/src/common/daft-config/src/lib.rs @@ -37,7 +37,7 @@ impl DaftPlanningConfig { /// 3. Task generation from physical plan /// 4. Task scheduling /// 5. Task local execution -#[derive(Clone, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct DaftExecutionConfig { pub scan_tasks_min_size_bytes: usize, pub scan_tasks_max_size_bytes: usize, diff --git a/src/common/scan-info/src/lib.rs b/src/common/scan-info/src/lib.rs index ba3a201614..44727c61df 100644 --- a/src/common/scan-info/src/lib.rs +++ b/src/common/scan-info/src/lib.rs @@ -10,7 +10,7 @@ mod scan_operator; mod scan_task; pub mod test; -use std::{fmt::Debug, hash::Hash}; +use std::{fmt::Debug, hash::Hash, sync::Arc}; use daft_schema::schema::SchemaRef; pub use expr_rewriter::{rewrite_predicate_for_partitioning, PredicateGroups}; @@ -19,11 +19,35 @@ pub use pushdowns::Pushdowns; #[cfg(feature = "python")] pub use python::register_modules; pub use scan_operator::{ScanOperator, ScanOperatorRef}; -pub use scan_task::{BoxScanTaskLikeIter, ScanTaskLike, ScanTaskLikeRef}; +pub use scan_task::{ScanTaskLike, ScanTaskLikeRef, SPLIT_AND_MERGE_PASS}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ScanState { + Operator(ScanOperatorRef), + Tasks(Arc>), +} + +impl ScanState { + pub fn multiline_display(&self) -> Vec { + match self { + Self::Operator(scan_op) => scan_op.0.multiline_display(), + Self::Tasks(scan_tasks) => { + vec![format!("Num Scan Tasks = {}", scan_tasks.len())] + } + } + } + + pub fn get_scan_op(&self) -> &ScanOperatorRef { + match self { + Self::Operator(scan_op) => scan_op, + Self::Tasks(_) => panic!("Tried to get scan op from materialized physical scan info"), + } + } +} #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct PhysicalScanInfo { - pub scan_op: ScanOperatorRef, + pub scan_state: ScanState, pub source_schema: SchemaRef, pub partitioning_keys: Vec, pub pushdowns: Pushdowns, @@ -38,7 +62,7 @@ impl PhysicalScanInfo { pushdowns: Pushdowns, ) -> Self { Self { - scan_op, + scan_state: ScanState::Operator(scan_op), source_schema, partitioning_keys, pushdowns, @@ -48,7 +72,7 @@ 
impl PhysicalScanInfo { #[must_use] pub fn with_pushdowns(&self, pushdowns: Pushdowns) -> Self { Self { - scan_op: self.scan_op.clone(), + scan_state: self.scan_state.clone(), source_schema: self.source_schema.clone(), partitioning_keys: self.partitioning_keys.clone(), pushdowns, diff --git a/src/common/scan-info/src/scan_operator.rs b/src/common/scan-info/src/scan_operator.rs index b965f62bb6..10c798fe5f 100644 --- a/src/common/scan-info/src/scan_operator.rs +++ b/src/common/scan-info/src/scan_operator.rs @@ -4,7 +4,6 @@ use std::{ sync::Arc, }; -use common_daft_config::DaftExecutionConfig; use common_error::DaftResult; use daft_schema::schema::SchemaRef; @@ -33,11 +32,7 @@ pub trait ScanOperator: Send + Sync + Debug { /// If cfg provided, `to_scan_tasks` should apply the appropriate transformations /// (merging, splitting) to the outputted scan tasks - fn to_scan_tasks( - &self, - pushdowns: Pushdowns, - config: Option<&DaftExecutionConfig>, - ) -> DaftResult>; + fn to_scan_tasks(&self, pushdowns: Pushdowns) -> DaftResult>; } impl Display for dyn ScanOperator { diff --git a/src/common/scan-info/src/scan_task.rs b/src/common/scan-info/src/scan_task.rs index 3d6e5f466b..886fe42891 100644 --- a/src/common/scan-info/src/scan_task.rs +++ b/src/common/scan-info/src/scan_task.rs @@ -1,4 +1,9 @@ -use std::{any::Any, fmt::Debug, sync::Arc}; +use std::{ + any::Any, + fmt::Debug, + hash::{Hash, Hasher}, + sync::{Arc, OnceLock}, +}; use common_daft_config::DaftExecutionConfig; use common_display::DisplayAs; @@ -13,6 +18,7 @@ pub trait ScanTaskLike: Debug + DisplayAs + Send + Sync { fn as_any(&self) -> &dyn Any; fn as_any_arc(self: Arc) -> Arc; fn dyn_eq(&self, other: &dyn ScanTaskLike) -> bool; + fn dyn_hash(&self, state: &mut dyn Hasher); #[must_use] fn materialized_schema(&self) -> SchemaRef; #[must_use] @@ -35,10 +41,27 @@ pub trait ScanTaskLike: Debug + DisplayAs + Send + Sync { pub type ScanTaskLikeRef = Arc; +impl Eq for dyn ScanTaskLike + '_ {} + impl PartialEq for dyn ScanTaskLike + '_ { fn eq(&self, other: &Self) -> bool { self.dyn_eq(other) } } -pub type BoxScanTaskLikeIter = Box>>>; +impl Hash for dyn ScanTaskLike + '_ { + fn hash(&self, state: &mut H) { + self.dyn_hash(state); + } +} + +// Forward declare splitting and merging pass so that scan tasks can be split and merged +// with common/scan-info without importing daft-scan. 
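// The `OnceLock` below acts as a registration point: daft-scan (which depends on this
// crate) can install its split-and-merge function at startup, and planner code that cannot
// depend on daft-scan looks the pass up and applies it once an execution config is
// available. A minimal sketch of such a registration, assuming a `ctor`-style constructor
// in daft-scan (the `ctor` dependency added in Cargo.lock above); the function names here
// are illustrative, not the crate's actual API:
//
//     fn split_and_merge(
//         tasks: Arc<Vec<ScanTaskLikeRef>>,
//         pushdowns: &Pushdowns,
//         cfg: &DaftExecutionConfig,
//     ) -> DaftResult<Arc<Vec<ScanTaskLikeRef>>> {
//         // A real pass would split oversized scan tasks and merge undersized ones
//         // based on `cfg.scan_tasks_min_size_bytes` / `cfg.scan_tasks_max_size_bytes`.
//         Ok(tasks)
//     }
//
//     #[ctor::ctor]
//     fn register_split_and_merge_pass() {
//         let _ = SPLIT_AND_MERGE_PASS.set(&split_and_merge);
//     }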
+pub type SplitAndMergePass = dyn Fn( + Arc>, + &Pushdowns, + &DaftExecutionConfig, + ) -> DaftResult>> + + Sync + + Send; +pub static SPLIT_AND_MERGE_PASS: OnceLock<&SplitAndMergePass> = OnceLock::new(); diff --git a/src/common/scan-info/src/test/mod.rs b/src/common/scan-info/src/test/mod.rs index c5b14039c1..2fd717db1b 100644 --- a/src/common/scan-info/src/test/mod.rs +++ b/src/common/scan-info/src/test/mod.rs @@ -1,4 +1,8 @@ -use std::{any::Any, sync::Arc}; +use std::{ + any::Any, + hash::{Hash, Hasher}, + sync::Arc, +}; use common_daft_config::DaftExecutionConfig; use common_display::DisplayAs; @@ -9,7 +13,7 @@ use serde::{Deserialize, Serialize}; use crate::{PartitionField, Pushdowns, ScanOperator, ScanTaskLike, ScanTaskLikeRef}; -#[derive(Debug, Serialize, Deserialize, PartialEq)] +#[derive(Debug, Serialize, Deserialize, PartialEq, Hash)] struct DummyScanTask { pub schema: SchemaRef, pub pushdowns: Pushdowns, @@ -38,6 +42,10 @@ impl ScanTaskLike for DummyScanTask { .map_or(false, |a| a == self) } + fn dyn_hash(&self, mut state: &mut dyn Hasher) { + self.hash(&mut state); + } + fn materialized_schema(&self) -> SchemaRef { self.schema.clone() } @@ -121,11 +129,7 @@ impl ScanOperator for DummyScanOperator { vec!["DummyScanOperator".to_string()] } - fn to_scan_tasks( - &self, - pushdowns: Pushdowns, - _: Option<&DaftExecutionConfig>, - ) -> DaftResult> { + fn to_scan_tasks(&self, pushdowns: Pushdowns) -> DaftResult> { let scan_task = Arc::new(DummyScanTask { schema: self.schema.clone(), pushdowns, diff --git a/src/daft-catalog/src/lib.rs b/src/daft-catalog/src/lib.rs index 8492e6ae37..73f75864c8 100644 --- a/src/daft-catalog/src/lib.rs +++ b/src/daft-catalog/src/lib.rs @@ -168,21 +168,21 @@ mod tests { ]) .unwrap(), ); - LogicalPlan::Source(Source { - output_schema: schema.clone(), - source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { + LogicalPlan::Source(Source::new( + schema.clone(), + Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { source_schema: schema, clustering_spec: Arc::new(ClusteringSpec::unknown()), source_id: 0, })), - }) + )) .arced() } #[test] fn test_register_and_unregister_named_table() { let mut catalog = DaftMetaCatalog::new_from_env(); - let plan = LogicalPlanBuilder::new(mock_plan(), None); + let plan = LogicalPlanBuilder::from(mock_plan()); // Register a table assert!(catalog @@ -198,7 +198,7 @@ mod tests { #[test] fn test_read_registered_table() { let mut catalog = DaftMetaCatalog::new_from_env(); - let plan = LogicalPlanBuilder::new(mock_plan(), None); + let plan = LogicalPlanBuilder::from(mock_plan()); catalog.register_named_table("test_table", plan).unwrap(); diff --git a/src/daft-connect/src/op/execute/root.rs b/src/daft-connect/src/op/execute/root.rs index 1e1fac147b..4f765243c8 100644 --- a/src/daft-connect/src/op/execute/root.rs +++ b/src/daft-connect/src/op/execute/root.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, future::ready}; +use std::{collections::HashMap, future::ready, sync::Arc}; use common_daft_config::DaftExecutionConfig; use daft_local_execution::NativeExecutor; @@ -33,10 +33,10 @@ impl Session { let execution_fut = async { let plan = translation::to_logical_plan(command)?; let optimized_plan = plan.optimize()?; - let cfg = DaftExecutionConfig::default(); + let cfg = Arc::new(DaftExecutionConfig::default()); let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?; let mut result_stream = native_executor - .run(HashMap::new(), cfg.into(), None)? + .run(HashMap::new(), cfg, None)? 
.into_stream(); while let Some(result) = result_stream.next().await { diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index 6f804c150f..c931614ff3 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -15,7 +15,7 @@ use daft_local_plan::{ Limit, LocalPhysicalPlan, MonotonicallyIncreasingId, PhysicalWrite, Pivot, Project, Sample, Sort, UnGroupedAggregate, Unpivot, }; -use daft_logical_plan::JoinType; +use daft_logical_plan::{stats::StatsState, JoinType}; use daft_micropartition::MicroPartition; use daft_physical_plan::{extract_agg_expr, populate_aggregation_stages}; use daft_scan::ScanTaskRef; @@ -319,18 +319,54 @@ pub fn physical_plan_to_pipeline( null_equals_null, join_type, schema, + .. }) => { let left_schema = left.schema(); let right_schema = right.schema(); - // Determine the build and probe sides based on the join type - // Currently it is a naive determination, in the future we should leverage the cardinality of the tables - // to determine the build and probe sides + // To determine whether to use the left or right side of a join for building a probe table, we consider: + // 1. Cardinality of the sides. Probe tables should be built on the smaller side. + // 2. Join type. Different join types have different requirements for which side can build the probe table. + let left_stats_state = left.get_stats_state(); + let right_stats_state = right.get_stats_state(); + let build_on_left = match (left_stats_state, right_stats_state) { + (StatsState::Materialized(left_stats), StatsState::Materialized(right_stats)) => { + left_stats.approx_stats.upper_bound_bytes + <= right_stats.approx_stats.upper_bound_bytes + } + // If stats are only available on the right side of the join, and the upper bound bytes on the + // right are under the broadcast join size threshold, we build on the right instead of the left. + (StatsState::NotMaterialized, StatsState::Materialized(right_stats)) => right_stats + .approx_stats + .upper_bound_bytes + .map_or(true, |size| size > cfg.broadcast_join_size_bytes_threshold), + // If stats are not available, we fall back and build on the left by default. + _ => true, + }; + + // TODO(desmond): We might potentially want to flip the probe table side for + // left/right outer joins if one side is significantly larger. Needs to be tuned. + // + // In greater detail, consider a right outer join where the left side is several orders + // of magnitude larger than the right. An extreme example might have 1B rows on the left, + // and 10 rows on the right. + // + // Typically we would build the probe table on the left, then stream rows from the right + // to match against the probe table. But in this case we would have a giant intermediate + // probe table. + // + // An alternative 2-pass algorithm would be to: + // 1. Build the probe table on the right, but add a second data structure to keep track of + // which rows on the right have been matched. + // 2. Stream rows on the left until all rows have been seen. + // 3. Finally, emit all unmatched rows from the right. let build_on_left = match join_type { - JoinType::Inner => true, - JoinType::Right => true, - JoinType::Outer => true, + JoinType::Inner => build_on_left, + JoinType::Outer => build_on_left, + // For left outer joins, we build on right so we can stream the left side. JoinType::Left => false, + // For right outer joins, we build on left so we can stream the right side. 
+ JoinType::Right => true, JoinType::Anti | JoinType::Semi => false, }; let (build_on, probe_on, build_child, probe_child) = match build_on_left { @@ -421,6 +457,7 @@ pub fn physical_plan_to_pipeline( left_schema, right_schema, *join_type, + build_on_left, common_join_keys, schema, probe_state_bridge, diff --git a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs index a8ca50130f..4af93075a0 100644 --- a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs +++ b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs @@ -126,6 +126,7 @@ struct OuterHashJoinParams { right_non_join_columns: Vec, right_non_join_schema: SchemaRef, join_type: JoinType, + build_on_left: bool, } pub(crate) struct OuterHashJoinProbeSink { @@ -134,16 +135,23 @@ pub(crate) struct OuterHashJoinProbeSink { probe_state_bridge: ProbeStateBridgeRef, } +#[allow(clippy::too_many_arguments)] impl OuterHashJoinProbeSink { pub(crate) fn new( probe_on: Vec, left_schema: &SchemaRef, right_schema: &SchemaRef, join_type: JoinType, + build_on_left: bool, common_join_keys: IndexSet, output_schema: &SchemaRef, probe_state_bridge: ProbeStateBridgeRef, ) -> Self { + // For outer joins, we need to swap the left and right schemas if we are building on the right. + let (left_schema, right_schema) = match (join_type, build_on_left) { + (JoinType::Outer, false) => (right_schema, left_schema), + _ => (left_schema, right_schema), + }; let left_non_join_columns = left_schema .fields .keys() @@ -168,6 +176,7 @@ impl OuterHashJoinProbeSink { right_non_join_columns, right_non_join_schema, join_type, + build_on_left, }), output_schema: output_schema.clone(), probe_state_bridge, @@ -243,6 +252,7 @@ impl OuterHashJoinProbeSink { ))) } + #[allow(clippy::too_many_arguments)] fn probe_outer( input: &Arc, probe_state: &ProbeState, @@ -251,6 +261,7 @@ impl OuterHashJoinProbeSink { common_join_keys: &[String], left_non_join_columns: &[String], right_non_join_columns: &[String], + build_on_left: bool, ) -> DaftResult> { let probe_table = probe_state.get_probeable().clone(); let tables = probe_state.get_tables().clone(); @@ -297,6 +308,12 @@ impl OuterHashJoinProbeSink { let join_table = probe_side_table.get_columns(common_join_keys)?; let left = build_side_table.get_columns(left_non_join_columns)?; let right = probe_side_table.get_columns(right_non_join_columns)?; + // If we built the probe table on the right, flip the order of union. + let (left, right) = if build_on_left { + (left, right) + } else { + (right, left) + }; let final_table = join_table.union(&left)?.union(&right)?; Ok(Arc::new(MicroPartition::new_loaded( final_table.schema.clone(), @@ -310,6 +327,7 @@ impl OuterHashJoinProbeSink { common_join_keys: &[String], left_non_join_columns: &[String], right_non_join_schema: &SchemaRef, + build_on_left: bool, ) -> DaftResult>> { let mut states_iter = states.iter_mut(); let first_state = states_iter @@ -372,6 +390,12 @@ impl OuterHashJoinProbeSink { .collect::>(); Table::new_unchecked(right_non_join_schema.clone(), columns, left.len()) }; + // If we built the probe table on the right, flip the order of union. 
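// (As in `probe_outer`: `left` carries the build-side non-join columns and `right` the
// probe-side ones, so when the build side was the actual right table they are swapped
// back into the output schema's left/right order before the union.)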
+ let (left, right) = if build_on_left { + (left, right) + } else { + (right, left) + }; let final_table = join_table.union(&left)?.union(&right)?; Ok(Some(Arc::new(MicroPartition::new_loaded( final_table.schema.clone(), @@ -426,6 +450,7 @@ impl StreamingSink for OuterHashJoinProbeSink { ¶ms.common_join_keys, ¶ms.left_non_join_columns, ¶ms.right_non_join_columns, + params.build_on_left, ) } _ => unreachable!( @@ -462,6 +487,7 @@ impl StreamingSink for OuterHashJoinProbeSink { ¶ms.common_join_keys, ¶ms.left_non_join_columns, ¶ms.right_non_join_schema, + params.build_on_left, ) .await }) diff --git a/src/daft-local-plan/src/plan.rs b/src/daft-local-plan/src/plan.rs index 26f796b1e4..7d541421a3 100644 --- a/src/daft-local-plan/src/plan.rs +++ b/src/daft-local-plan/src/plan.rs @@ -4,7 +4,10 @@ use common_resource_request::ResourceRequest; use common_scan_info::{Pushdowns, ScanTaskLikeRef}; use daft_core::prelude::*; use daft_dsl::{AggExpr, ExprRef}; -use daft_logical_plan::{InMemoryInfo, OutputFileInfo}; +use daft_logical_plan::{ + stats::{PlanStats, StatsState}, + InMemoryInfo, OutputFileInfo, +}; pub type LocalPhysicalPlanRef = Arc; #[derive(Debug, strum::IntoStaticStr)] @@ -56,24 +59,54 @@ impl LocalPhysicalPlan { self.into() } - pub(crate) fn in_memory_scan(in_memory_info: InMemoryInfo) -> LocalPhysicalPlanRef { + pub fn get_stats_state(&self) -> &StatsState { + match self { + Self::InMemoryScan(InMemoryScan { stats_state, .. }) + | Self::PhysicalScan(PhysicalScan { stats_state, .. }) + | Self::EmptyScan(EmptyScan { stats_state, .. }) + | Self::Project(Project { stats_state, .. }) + | Self::ActorPoolProject(ActorPoolProject { stats_state, .. }) + | Self::Filter(Filter { stats_state, .. }) + | Self::Limit(Limit { stats_state, .. }) + | Self::Explode(Explode { stats_state, .. }) + | Self::Unpivot(Unpivot { stats_state, .. }) + | Self::Sort(Sort { stats_state, .. }) + | Self::Sample(Sample { stats_state, .. }) + | Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId { stats_state, .. }) + | Self::UnGroupedAggregate(UnGroupedAggregate { stats_state, .. }) + | Self::HashAggregate(HashAggregate { stats_state, .. }) + | Self::Pivot(Pivot { stats_state, .. }) + | Self::Concat(Concat { stats_state, .. }) + | Self::HashJoin(HashJoin { stats_state, .. }) + | Self::PhysicalWrite(PhysicalWrite { stats_state, .. }) => stats_state, + #[cfg(feature = "python")] + Self::CatalogWrite(CatalogWrite { stats_state, .. }) + | Self::LanceWrite(LanceWrite { stats_state, .. 
}) => stats_state, + } + } + + pub(crate) fn in_memory_scan( + in_memory_info: InMemoryInfo, + stats_state: StatsState, + ) -> LocalPhysicalPlanRef { Self::InMemoryScan(InMemoryScan { info: in_memory_info, - plan_stats: PlanStats {}, + stats_state, }) .arced() } pub(crate) fn physical_scan( - scan_tasks: Vec, + scan_tasks: Arc>, pushdowns: Pushdowns, schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::PhysicalScan(PhysicalScan { scan_tasks, pushdowns, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -81,29 +114,37 @@ impl LocalPhysicalPlan { pub(crate) fn empty_scan(schema: SchemaRef) -> LocalPhysicalPlanRef { Self::EmptyScan(EmptyScan { schema, - plan_stats: PlanStats {}, + stats_state: StatsState::Materialized(PlanStats::empty().into()), }) .arced() } - pub(crate) fn filter(input: LocalPhysicalPlanRef, predicate: ExprRef) -> LocalPhysicalPlanRef { + pub(crate) fn filter( + input: LocalPhysicalPlanRef, + predicate: ExprRef, + stats_state: StatsState, + ) -> LocalPhysicalPlanRef { let schema = input.schema().clone(); Self::Filter(Filter { input, predicate, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } - pub(crate) fn limit(input: LocalPhysicalPlanRef, num_rows: i64) -> LocalPhysicalPlanRef { + pub(crate) fn limit( + input: LocalPhysicalPlanRef, + num_rows: i64, + stats_state: StatsState, + ) -> LocalPhysicalPlanRef { let schema = input.schema().clone(); Self::Limit(Limit { input, num_rows, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -112,12 +153,13 @@ impl LocalPhysicalPlan { input: LocalPhysicalPlanRef, to_explode: Vec, schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::Explode(Explode { input, to_explode, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -126,12 +168,13 @@ impl LocalPhysicalPlan { input: LocalPhysicalPlanRef, projection: Vec, schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::Project(Project { input, projection, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -140,12 +183,13 @@ impl LocalPhysicalPlan { input: LocalPhysicalPlanRef, projection: Vec, schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::ActorPoolProject(ActorPoolProject { input, projection, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -154,12 +198,13 @@ impl LocalPhysicalPlan { input: LocalPhysicalPlanRef, aggregations: Vec, schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::UnGroupedAggregate(UnGroupedAggregate { input, aggregations, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -169,13 +214,14 @@ impl LocalPhysicalPlan { aggregations: Vec, group_by: Vec, schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::HashAggregate(HashAggregate { input, aggregations, group_by, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -187,6 +233,7 @@ impl LocalPhysicalPlan { variable_name: String, value_name: String, schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::Unpivot(Unpivot { input, @@ -195,11 +242,12 @@ impl LocalPhysicalPlan { variable_name, value_name, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } + #[allow(clippy::too_many_arguments)] pub(crate) fn pivot( input: LocalPhysicalPlanRef, group_by: Vec, @@ -208,6 +256,7 @@ impl LocalPhysicalPlan { aggregation: AggExpr, names: Vec, schema: SchemaRef, + stats_state: StatsState, ) -> 
LocalPhysicalPlanRef { Self::Pivot(Pivot { input, @@ -217,7 +266,7 @@ impl LocalPhysicalPlan { aggregation, names, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -227,15 +276,16 @@ impl LocalPhysicalPlan { sort_by: Vec, descending: Vec, nulls_first: Vec, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { let schema = input.schema().clone(); Self::Sort(Sort { input, sort_by, - nulls_first, descending, + nulls_first, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -245,6 +295,7 @@ impl LocalPhysicalPlan { fraction: f64, with_replacement: bool, seed: Option, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { let schema = input.schema().clone(); Self::Sample(Sample { @@ -253,7 +304,7 @@ impl LocalPhysicalPlan { with_replacement, seed, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -262,16 +313,18 @@ impl LocalPhysicalPlan { input: LocalPhysicalPlanRef, column_name: String, schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId { input, column_name, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } + #[allow(clippy::too_many_arguments)] pub(crate) fn hash_join( left: LocalPhysicalPlanRef, right: LocalPhysicalPlanRef, @@ -280,6 +333,7 @@ impl LocalPhysicalPlan { null_equals_null: Option>, join_type: JoinType, schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::HashJoin(HashJoin { left, @@ -289,6 +343,7 @@ impl LocalPhysicalPlan { null_equals_null, join_type, schema, + stats_state, }) .arced() } @@ -296,13 +351,14 @@ impl LocalPhysicalPlan { pub(crate) fn concat( input: LocalPhysicalPlanRef, other: LocalPhysicalPlanRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { let schema = input.schema().clone(); Self::Concat(Concat { input, other, schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -312,13 +368,14 @@ impl LocalPhysicalPlan { data_schema: SchemaRef, file_schema: SchemaRef, file_info: OutputFileInfo, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::PhysicalWrite(PhysicalWrite { input, data_schema, file_schema, file_info, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -329,13 +386,14 @@ impl LocalPhysicalPlan { catalog_type: daft_logical_plan::CatalogType, data_schema: SchemaRef, file_schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::CatalogWrite(CatalogWrite { input, catalog_type, data_schema, file_schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -346,13 +404,14 @@ impl LocalPhysicalPlan { lance_info: daft_logical_plan::LanceCatalogInfo, data_schema: SchemaRef, file_schema: SchemaRef, + stats_state: StatsState, ) -> LocalPhysicalPlanRef { Self::LanceWrite(LanceWrite { input, lance_info, data_schema, file_schema, - plan_stats: PlanStats {}, + stats_state, }) .arced() } @@ -388,21 +447,21 @@ impl LocalPhysicalPlan { #[derive(Debug)] pub struct InMemoryScan { pub info: InMemoryInfo, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] pub struct PhysicalScan { - pub scan_tasks: Vec, + pub scan_tasks: Arc>, pub pushdowns: Pushdowns, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] pub struct EmptyScan { pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -410,7 +469,7 @@ pub struct Project { pub input: LocalPhysicalPlanRef, pub projection: Vec, pub schema: SchemaRef, - pub 
plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -418,7 +477,7 @@ pub struct ActorPoolProject { pub input: LocalPhysicalPlanRef, pub projection: Vec, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -426,7 +485,7 @@ pub struct Filter { pub input: LocalPhysicalPlanRef, pub predicate: ExprRef, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -434,7 +493,7 @@ pub struct Limit { pub input: LocalPhysicalPlanRef, pub num_rows: i64, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -442,7 +501,7 @@ pub struct Explode { pub input: LocalPhysicalPlanRef, pub to_explode: Vec, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -452,7 +511,7 @@ pub struct Sort { pub descending: Vec, pub nulls_first: Vec, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -462,7 +521,7 @@ pub struct Sample { pub with_replacement: bool, pub seed: Option, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -470,7 +529,7 @@ pub struct MonotonicallyIncreasingId { pub input: LocalPhysicalPlanRef, pub column_name: String, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -478,7 +537,7 @@ pub struct UnGroupedAggregate { pub input: LocalPhysicalPlanRef, pub aggregations: Vec, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -487,7 +546,7 @@ pub struct HashAggregate { pub aggregations: Vec, pub group_by: Vec, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -498,7 +557,7 @@ pub struct Unpivot { pub variable_name: String, pub value_name: String, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -510,7 +569,7 @@ pub struct Pivot { pub aggregation: AggExpr, pub names: Vec, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -522,6 +581,7 @@ pub struct HashJoin { pub null_equals_null: Option>, pub join_type: JoinType, pub schema: SchemaRef, + pub stats_state: StatsState, } #[derive(Debug)] @@ -529,7 +589,7 @@ pub struct Concat { pub input: LocalPhysicalPlanRef, pub other: LocalPhysicalPlanRef, pub schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[derive(Debug)] @@ -538,7 +598,7 @@ pub struct PhysicalWrite { pub data_schema: SchemaRef, pub file_schema: SchemaRef, pub file_info: OutputFileInfo, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[cfg(feature = "python")] @@ -548,7 +608,7 @@ pub struct CatalogWrite { pub catalog_type: daft_logical_plan::CatalogType, pub data_schema: SchemaRef, pub file_schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } #[cfg(feature = "python")] @@ -558,8 +618,5 @@ pub struct LanceWrite { pub lance_info: daft_logical_plan::LanceCatalogInfo, pub data_schema: SchemaRef, pub file_schema: SchemaRef, - pub plan_stats: PlanStats, + pub stats_state: StatsState, } - -#[derive(Debug)] -pub struct PlanStats {} diff --git a/src/daft-local-plan/src/translate.rs b/src/daft-local-plan/src/translate.rs index b8f214dd4f..aac5046e46 100644 --- a/src/daft-local-plan/src/translate.rs +++ b/src/daft-local-plan/src/translate.rs @@ 
-1,4 +1,7 @@ +use std::sync::Arc; + use common_error::{DaftError, DaftResult}; +use common_scan_info::ScanState; use daft_core::join::JoinStrategy; use daft_dsl::ExprRef; use daft_logical_plan::{JoinType, LogicalPlan, LogicalPlanRef, SourceInfo}; @@ -9,10 +12,18 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { match plan.as_ref() { LogicalPlan::Source(source) => { match source.source_info.as_ref() { - SourceInfo::InMemory(info) => Ok(LocalPhysicalPlan::in_memory_scan(info.clone())), + SourceInfo::InMemory(info) => Ok(LocalPhysicalPlan::in_memory_scan( + info.clone(), + source.stats_state.clone(), + )), SourceInfo::Physical(info) => { // We should be able to pass the ScanOperator into the physical plan directly but we need to figure out the serialization story - let scan_tasks = info.scan_op.0.to_scan_tasks(info.pushdowns.clone(), None)?; + let scan_tasks = match &info.scan_state { + ScanState::Operator(scan_op) => { + Arc::new(scan_op.0.to_scan_tasks(info.pushdowns.clone())?) + } + ScanState::Tasks(scan_tasks) => scan_tasks.clone(), + }; if scan_tasks.is_empty() { Ok(LocalPhysicalPlan::empty_scan(source.output_schema.clone())) } else { @@ -20,6 +31,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { scan_tasks, info.pushdowns.clone(), source.output_schema.clone(), + source.stats_state.clone(), )) } } @@ -30,11 +42,19 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { } LogicalPlan::Filter(filter) => { let input = translate(&filter.input)?; - Ok(LocalPhysicalPlan::filter(input, filter.predicate.clone())) + Ok(LocalPhysicalPlan::filter( + input, + filter.predicate.clone(), + filter.stats_state.clone(), + )) } LogicalPlan::Limit(limit) => { let input = translate(&limit.input)?; - Ok(LocalPhysicalPlan::limit(input, limit.limit)) + Ok(LocalPhysicalPlan::limit( + input, + limit.limit, + limit.stats_state.clone(), + )) } LogicalPlan::Project(project) => { let input = translate(&project.input)?; @@ -42,6 +62,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { input, project.projection.clone(), project.projected_schema.clone(), + project.stats_state.clone(), )) } LogicalPlan::ActorPoolProject(actor_pool_project) => { @@ -50,6 +71,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { input, actor_pool_project.projection.clone(), actor_pool_project.projected_schema.clone(), + actor_pool_project.stats_state.clone(), )) } LogicalPlan::Sample(sample) => { @@ -59,6 +81,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { sample.fraction, sample.with_replacement, sample.seed, + sample.stats_state.clone(), )) } LogicalPlan::Aggregate(aggregate) => { @@ -68,6 +91,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { input, aggregate.aggregations.clone(), aggregate.output_schema.clone(), + aggregate.stats_state.clone(), )) } else { Ok(LocalPhysicalPlan::hash_aggregate( @@ -75,6 +99,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { aggregate.aggregations.clone(), aggregate.groupby.clone(), aggregate.output_schema.clone(), + aggregate.stats_state.clone(), )) } } @@ -87,6 +112,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { unpivot.variable_name.clone(), unpivot.value_name.clone(), unpivot.output_schema.clone(), + unpivot.stats_state.clone(), )) } LogicalPlan::Pivot(pivot) => { @@ -99,6 +125,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { pivot.aggregation.clone(), pivot.names.clone(), pivot.output_schema.clone(), + pivot.stats_state.clone(), )) } LogicalPlan::Sort(sort) => { @@ -108,6 +135,7 @@ pub fn translate(plan: 
&LogicalPlanRef) -> DaftResult { sort.sort_by.clone(), sort.descending.clone(), sort.nulls_first.clone(), + sort.stats_state.clone(), )) } LogicalPlan::Join(join) => { @@ -134,6 +162,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { join.null_equals_nulls.clone(), join.join_type, join.output_schema.clone(), + join.stats_state.clone(), )) } LogicalPlan::Distinct(distinct) => { @@ -150,12 +179,17 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { vec![], col_exprs, schema, + distinct.stats_state.clone(), )) } LogicalPlan::Concat(concat) => { let input = translate(&concat.input)?; let other = translate(&concat.other)?; - Ok(LocalPhysicalPlan::concat(input, other)) + Ok(LocalPhysicalPlan::concat( + input, + other, + concat.stats_state.clone(), + )) } LogicalPlan::Repartition(repartition) => { log::warn!("Repartition Not supported for Local Executor!; This will be a No-Op"); @@ -167,6 +201,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { input, monotonically_increasing_id.column_name.clone(), monotonically_increasing_id.schema.clone(), + monotonically_increasing_id.stats_state.clone(), )) } LogicalPlan::Sink(sink) => { @@ -179,6 +214,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { data_schema, sink.schema.clone(), info.clone(), + sink.stats_state.clone(), )), #[cfg(feature = "python")] SinkInfo::CatalogInfo(info) => match &info.catalog { @@ -189,6 +225,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { info.catalog.clone(), data_schema, sink.schema.clone(), + sink.stats_state.clone(), )) } daft_logical_plan::CatalogType::Lance(info) => { @@ -197,6 +234,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { info.clone(), data_schema, sink.schema.clone(), + sink.stats_state.clone(), )) } }, @@ -208,6 +246,7 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { input, explode.to_explode.clone(), explode.exploded_schema.clone(), + explode.stats_state.clone(), )) } _ => todo!("{} not yet implemented", plan.name()), diff --git a/src/daft-logical-plan/src/builder.rs b/src/daft-logical-plan/src/builder.rs index c945c80203..f40a55ed4f 100644 --- a/src/daft-logical-plan/src/builder.rs +++ b/src/daft-logical-plan/src/builder.rs @@ -134,7 +134,7 @@ impl LogicalPlanBuilder { )); let logical_plan: LogicalPlan = ops::Source::new(schema, source_info.into()).into(); - Ok(Self::new(logical_plan.into(), None)) + Ok(Self::from(Arc::new(logical_plan))) } pub fn table_scan( @@ -186,7 +186,7 @@ impl LogicalPlanBuilder { schema_with_generated_fields }; let logical_plan: LogicalPlan = ops::Source::new(output_schema, source_info.into()).into(); - Ok(Self::new(logical_plan.into(), None)) + Ok(Self::from(Arc::new(logical_plan))) } pub fn select(&self, to_select: Vec) -> DaftResult { diff --git a/src/daft-logical-plan/src/display.rs b/src/daft-logical-plan/src/display.rs index be83f5237b..84db90d273 100644 --- a/src/daft-logical-plan/src/display.rs +++ b/src/daft-logical-plan/src/display.rs @@ -51,14 +51,14 @@ mod test { ]) .unwrap(), ); - LogicalPlan::Source(Source { - output_schema: schema.clone(), - source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { + LogicalPlan::Source(Source::new( + schema.clone(), + Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { source_schema: schema, clustering_spec: Arc::new(ClusteringSpec::unknown()), source_id: 0, })), - }) + )) .arced() } @@ -71,25 +71,25 @@ mod test { ]) .unwrap(), ); - LogicalPlan::Source(Source { - output_schema: schema.clone(), - source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { 
+ LogicalPlan::Source(Source::new( + schema.clone(), + Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { source_schema: schema, clustering_spec: Arc::new(ClusteringSpec::unknown()), source_id: 0, })), - }) + )) .arced() } #[test] // create a random, complex plan and check if it can be displayed as expected fn test_mermaid_display() -> DaftResult<()> { - let subplan = LogicalPlanBuilder::new(plan_1(), None) + let subplan = LogicalPlanBuilder::from(plan_1()) .filter(col("id").eq(lit(1)))? .build(); - let subplan2 = LogicalPlanBuilder::new(plan_2(), None) + let subplan2 = LogicalPlanBuilder::from(plan_2()) .filter( startswith(col("last_name"), lit("S")).and(endswith(col("last_name"), lit("n"))), )? @@ -99,7 +99,7 @@ mod test { .sort(vec![col("last_name")], vec![false], vec![false])? .build(); - let plan = LogicalPlanBuilder::new(subplan, None) + let plan = LogicalPlanBuilder::from(subplan) .join( subplan2, vec![col("id")], @@ -159,11 +159,11 @@ Project1 --> Limit0 #[test] // create a random, complex plan and check if it can be displayed as expected fn test_mermaid_display_simple() -> DaftResult<()> { - let subplan = LogicalPlanBuilder::new(plan_1(), None) + let subplan = LogicalPlanBuilder::from(plan_1()) .filter(col("id").eq(lit(1)))? .build(); - let subplan2 = LogicalPlanBuilder::new(plan_2(), None) + let subplan2 = LogicalPlanBuilder::from(plan_2()) .filter( startswith(col("last_name"), lit("S")).and(endswith(col("last_name"), lit("n"))), )? @@ -173,7 +173,7 @@ Project1 --> Limit0 .sort(vec![col("last_name")], vec![false], vec![false])? .build(); - let plan = LogicalPlanBuilder::new(subplan, None) + let plan = LogicalPlanBuilder::from(subplan) .join_with_null_safe_equal( subplan2, vec![col("id")], diff --git a/src/daft-logical-plan/src/lib.rs b/src/daft-logical-plan/src/lib.rs index 88f21f8797..5296d99a23 100644 --- a/src/daft-logical-plan/src/lib.rs +++ b/src/daft-logical-plan/src/lib.rs @@ -10,6 +10,7 @@ pub mod optimization; pub mod partitioning; pub mod sink_info; pub mod source_info; +pub mod stats; #[cfg(test)] mod test; mod treenode; diff --git a/src/daft-logical-plan/src/logical_plan.rs b/src/daft-logical-plan/src/logical_plan.rs index 01c6b510c8..fc2f065038 100644 --- a/src/daft-logical-plan/src/logical_plan.rs +++ b/src/daft-logical-plan/src/logical_plan.rs @@ -8,6 +8,7 @@ use indexmap::IndexSet; use snafu::Snafu; pub use crate::ops::*; +use crate::stats::PlanStats; /// Logical plan for a Daft query. #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -39,6 +40,7 @@ impl LogicalPlan { pub fn arced(self) -> Arc { Arc::new(self) } + pub fn schema(&self) -> SchemaRef { match self { Self::Source(Source { output_schema, .. }) => output_schema.clone(), @@ -197,29 +199,91 @@ impl LogicalPlan { } } + pub fn materialized_stats(&self) -> &PlanStats { + match self { + Self::Source(Source { stats_state, .. }) + | Self::Project(Project { stats_state, .. }) + | Self::ActorPoolProject(ActorPoolProject { stats_state, .. }) + | Self::Filter(Filter { stats_state, .. }) + | Self::Limit(Limit { stats_state, .. }) + | Self::Explode(Explode { stats_state, .. }) + | Self::Unpivot(Unpivot { stats_state, .. }) + | Self::Sort(Sort { stats_state, .. }) + | Self::Repartition(Repartition { stats_state, .. }) + | Self::Distinct(Distinct { stats_state, .. }) + | Self::Aggregate(Aggregate { stats_state, .. }) + | Self::Pivot(Pivot { stats_state, .. }) + | Self::Concat(Concat { stats_state, .. }) + | Self::Join(Join { stats_state, .. }) + | Self::Sink(Sink { stats_state, .. 
}) + | Self::Sample(Sample { stats_state, .. }) + | Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId { stats_state, .. }) => { + stats_state.materialized_stats() + } + Self::Intersect(_) => { + panic!("Intersect nodes should be optimized away before stats are materialized") + } + Self::Union(_) => { + panic!("Union nodes should be optimized away before stats are materialized") + } + } + } + + // Materializes stats over logical plans. If stats are already materialized, this function recomputes stats, which might be + // useful if stats become stale during query planning. + pub fn with_materialized_stats(self) -> Self { + match self { + Self::Source(plan) => Self::Source(plan.with_materialized_stats()), + Self::Project(plan) => Self::Project(plan.with_materialized_stats()), + Self::ActorPoolProject(plan) => Self::ActorPoolProject(plan.with_materialized_stats()), + Self::Filter(plan) => Self::Filter(plan.with_materialized_stats()), + Self::Limit(plan) => Self::Limit(plan.with_materialized_stats()), + Self::Explode(plan) => Self::Explode(plan.with_materialized_stats()), + Self::Unpivot(plan) => Self::Unpivot(plan.with_materialized_stats()), + Self::Sort(plan) => Self::Sort(plan.with_materialized_stats()), + Self::Repartition(plan) => Self::Repartition(plan.with_materialized_stats()), + Self::Distinct(plan) => Self::Distinct(plan.with_materialized_stats()), + Self::Aggregate(plan) => Self::Aggregate(plan.with_materialized_stats()), + Self::Pivot(plan) => Self::Pivot(plan.with_materialized_stats()), + Self::Concat(plan) => Self::Concat(plan.with_materialized_stats()), + Self::Intersect(_) => { + panic!("Intersect should be optimized away before stats are derived") + } + Self::Union(_) => { + panic!("Union should be optimized away before stats are derived") + } + Self::Join(plan) => Self::Join(plan.with_materialized_stats()), + Self::Sink(plan) => Self::Sink(plan.with_materialized_stats()), + Self::Sample(plan) => Self::Sample(plan.with_materialized_stats()), + Self::MonotonicallyIncreasingId(plan) => { + Self::MonotonicallyIncreasingId(plan.with_materialized_stats()) + } + } + } + pub fn multiline_display(&self) -> Vec { match self { Self::Source(source) => source.multiline_display(), Self::Project(projection) => projection.multiline_display(), Self::ActorPoolProject(projection) => projection.multiline_display(), - Self::Filter(Filter { predicate, .. }) => vec![format!("Filter: {predicate}")], - Self::Limit(Limit { limit, .. 
}) => vec![format!("Limit: {limit}")], + Self::Filter(filter) => filter.multiline_display(), + Self::Limit(limit) => limit.multiline_display(), Self::Explode(explode) => explode.multiline_display(), Self::Unpivot(unpivot) => unpivot.multiline_display(), Self::Sort(sort) => sort.multiline_display(), Self::Repartition(repartition) => repartition.multiline_display(), - Self::Distinct(_) => vec!["Distinct".to_string()], + Self::Distinct(distinct) => distinct.multiline_display(), Self::Aggregate(aggregate) => aggregate.multiline_display(), Self::Pivot(pivot) => pivot.multiline_display(), - Self::Concat(_) => vec!["Concat".to_string()], + Self::Concat(concat) => concat.multiline_display(), Self::Intersect(inner) => inner.multiline_display(), Self::Union(inner) => inner.multiline_display(), Self::Join(join) => join.multiline_display(), Self::Sink(sink) => sink.multiline_display(), - Self::Sample(sample) => { - vec![format!("Sample: {fraction}", fraction = sample.fraction)] + Self::Sample(sample) => sample.multiline_display(), + Self::MonotonicallyIncreasingId(monotonically_increasing_id) => { + monotonically_increasing_id.multiline_display() } - Self::MonotonicallyIncreasingId(_) => vec!["MonotonicallyIncreasingId".to_string()], } } @@ -237,7 +301,7 @@ impl LogicalPlan { Self::Distinct(Distinct { input, .. }) => vec![input], Self::Aggregate(Aggregate { input, .. }) => vec![input], Self::Pivot(Pivot { input, .. }) => vec![input], - Self::Concat(Concat { input, other }) => vec![input, other], + Self::Concat(Concat { input, other, .. }) => vec![input, other], Self::Join(Join { left, right, .. }) => vec![left, right], Self::Sink(Sink { input, .. }) => vec![input], Self::Intersect(Intersect { lhs, rhs, .. }) => vec![lhs, rhs], @@ -267,7 +331,8 @@ impl LogicalPlan { Self::Pivot(Pivot { group_by, pivot_column, value_column, aggregation, names, ..}) => Self::Pivot(Pivot::try_new(input.clone(), group_by.clone(), pivot_column.clone(), value_column.clone(), aggregation.into(), names.clone()).unwrap()), Self::Sink(Sink { sink_info, .. }) => Self::Sink(Sink::try_new(input.clone(), sink_info.clone()).unwrap()), Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId {column_name, .. }) => Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId::new(input.clone(), Some(column_name))), - Self::Unpivot(Unpivot {ids, values, variable_name, value_name, output_schema, ..}) => Self::Unpivot(Unpivot { input: input.clone(), ids: ids.clone(), values: values.clone(), variable_name: variable_name.clone(), value_name: value_name.clone(), output_schema: output_schema.clone() }), + Self::Unpivot(Unpivot {ids, values, variable_name, value_name, output_schema, ..}) => + Self::Unpivot(Unpivot::new(input.clone(), ids.clone(), values.clone(), variable_name.clone(), value_name.clone(), output_schema.clone())), Self::Sample(Sample {fraction, with_replacement, seed, ..}) => Self::Sample(Sample::new(input.clone(), *fraction, *with_replacement, *seed)), Self::Concat(_) => panic!("Concat ops should never have only one input, but got one"), Self::Intersect(_) => panic!("Intersect ops should never have only one input, but got one"), @@ -377,7 +442,7 @@ macro_rules! 
impl_from_data_struct_for_logical_plan { impl From<$name> for Arc { fn from(data: $name) -> Self { - Arc::new(LogicalPlan::$name(data)) + Self::new(LogicalPlan::$name(data)) } } }; diff --git a/src/daft-logical-plan/src/ops/actor_pool_project.rs b/src/daft-logical-plan/src/ops/actor_pool_project.rs index fa1c8bb970..78ec2a681f 100644 --- a/src/daft-logical-plan/src/ops/actor_pool_project.rs +++ b/src/daft-logical-plan/src/ops/actor_pool_project.rs @@ -16,6 +16,7 @@ use snafu::ResultExt; use crate::{ logical_plan::{CreationSnafu, Error, Result}, + stats::StatsState, LogicalPlan, }; @@ -25,6 +26,7 @@ pub struct ActorPoolProject { pub input: Arc, pub projection: Vec, pub projected_schema: SchemaRef, + pub stats_state: StatsState, } impl ActorPoolProject { @@ -64,9 +66,17 @@ impl ActorPoolProject { input, projection, projected_schema, + stats_state: StatsState::NotMaterialized, }) } + pub(crate) fn with_materialized_stats(mut self) -> Self { + // TODO(desmond): We can do better estimations with the projection schema. For now, reuse the old logic. + let input_stats = self.input.materialized_stats(); + self.stats_state = StatsState::Materialized(input_stats.clone().into()); + self + } + pub fn resource_request(&self) -> Option { get_resource_request(self.projection.as_slice()) } @@ -115,6 +125,9 @@ impl ActorPoolProject { multiline_display.join(", ") )); } + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/ops/agg.rs b/src/daft-logical-plan/src/ops/agg.rs index be82d0d010..5b99338b1c 100644 --- a/src/daft-logical-plan/src/ops/agg.rs +++ b/src/daft-logical-plan/src/ops/agg.rs @@ -7,6 +7,7 @@ use snafu::ResultExt; use crate::{ logical_plan::{self, CreationSnafu}, + stats::{ApproxStats, PlanStats, StatsState}, LogicalPlan, }; @@ -26,6 +27,7 @@ pub struct Aggregate { pub groupby: Vec, pub output_schema: SchemaRef, + pub stats_state: StatsState, } impl Aggregate { @@ -55,9 +57,46 @@ impl Aggregate { aggregations, groupby, output_schema, + stats_state: StatsState::NotMaterialized, }) } + pub(crate) fn with_materialized_stats(mut self) -> Self { + // TODO(desmond): We can use the schema here for better estimations. For now, use the old logic. 
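// The estimate below derives lower/upper bytes-per-row figures from the input stats and
// then splits on whether a group-by is present: an ungrouped aggregation collapses to at
// most one output row, while a grouped aggregation keeps the input's upper bounds, since
// every input row could form its own group.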
+ let input_stats = self.input.materialized_stats(); + let est_bytes_per_row_lower = input_stats.approx_stats.lower_bound_bytes + / (input_stats.approx_stats.lower_bound_rows.max(1)); + let est_bytes_per_row_upper = + input_stats + .approx_stats + .upper_bound_bytes + .and_then(|bytes| { + input_stats + .approx_stats + .upper_bound_rows + .map(|rows| bytes / rows.max(1)) + }); + let approx_stats = if self.groupby.is_empty() { + ApproxStats { + lower_bound_rows: input_stats.approx_stats.lower_bound_rows.min(1), + upper_bound_rows: Some(1), + lower_bound_bytes: input_stats.approx_stats.lower_bound_bytes.min(1) + * est_bytes_per_row_lower, + upper_bound_bytes: est_bytes_per_row_upper, + } + } else { + ApproxStats { + lower_bound_rows: input_stats.approx_stats.lower_bound_rows.min(1), + upper_bound_rows: input_stats.approx_stats.upper_bound_rows, + lower_bound_bytes: input_stats.approx_stats.lower_bound_bytes.min(1) + * est_bytes_per_row_lower, + upper_bound_bytes: input_stats.approx_stats.upper_bound_bytes, + } + }; + self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); + self + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!( @@ -74,6 +113,9 @@ impl Aggregate { "Output schema = {}", self.output_schema.short_string() )); + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/ops/concat.rs b/src/daft-logical-plan/src/ops/concat.rs index 39541e39de..fb18441c4c 100644 --- a/src/daft-logical-plan/src/ops/concat.rs +++ b/src/daft-logical-plan/src/ops/concat.rs @@ -3,16 +3,29 @@ use std::sync::Arc; use common_error::DaftError; use snafu::ResultExt; -use crate::{logical_plan, logical_plan::CreationSnafu, LogicalPlan}; +use crate::{ + logical_plan::{self, CreationSnafu}, + stats::{PlanStats, StatsState}, + LogicalPlan, +}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Concat { // Upstream nodes. pub input: Arc, pub other: Arc, + pub stats_state: StatsState, } impl Concat { + pub(crate) fn new(input: Arc, other: Arc) -> Self { + Self { + input, + other, + stats_state: StatsState::NotMaterialized, + } + } + pub(crate) fn try_new( input: Arc, other: Arc, @@ -26,6 +39,27 @@ impl Concat { ))) .context(CreationSnafu); } - Ok(Self { input, other }) + Ok(Self { + input, + other, + stats_state: StatsState::NotMaterialized, + }) + } + + pub(crate) fn with_materialized_stats(mut self) -> Self { + // TODO(desmond): We can do better estimations with the projection schema. For now, reuse the old logic. + let input_stats = self.input.materialized_stats(); + let other_stats = self.other.materialized_stats(); + let approx_stats = &input_stats.approx_stats + &other_stats.approx_stats; + self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); + self + } + + pub fn multiline_display(&self) -> Vec { + let mut res = vec![format!("Concat")]; + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } + res } } diff --git a/src/daft-logical-plan/src/ops/distinct.rs b/src/daft-logical-plan/src/ops/distinct.rs index 11fe254e75..899dab940b 100644 --- a/src/daft-logical-plan/src/ops/distinct.rs +++ b/src/daft-logical-plan/src/ops/distinct.rs @@ -1,15 +1,46 @@ use std::sync::Arc; -use crate::LogicalPlan; +use crate::{ + stats::{ApproxStats, PlanStats, StatsState}, + LogicalPlan, +}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Distinct { // Upstream node. 
pub input: Arc, + pub stats_state: StatsState, } impl Distinct { pub(crate) fn new(input: Arc) -> Self { - Self { input } + Self { + input, + stats_state: StatsState::NotMaterialized, + } + } + + pub(crate) fn with_materialized_stats(mut self) -> Self { + // TODO(desmond): We can simply use NDVs here. For now, do a naive estimation. + let input_stats = self.input.materialized_stats(); + let est_bytes_per_row_lower = input_stats.approx_stats.lower_bound_bytes + / (input_stats.approx_stats.lower_bound_rows.max(1)); + let approx_stats = ApproxStats { + lower_bound_rows: input_stats.approx_stats.lower_bound_rows.min(1), + upper_bound_rows: input_stats.approx_stats.upper_bound_rows, + lower_bound_bytes: input_stats.approx_stats.lower_bound_bytes.min(1) + * est_bytes_per_row_lower, + upper_bound_bytes: input_stats.approx_stats.upper_bound_bytes, + }; + self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); + self + } + + pub fn multiline_display(&self) -> Vec { + let mut res = vec![format!("Distinct")]; + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } + res } } diff --git a/src/daft-logical-plan/src/ops/explode.rs b/src/daft-logical-plan/src/ops/explode.rs index daa7ca99b0..00624102f4 100644 --- a/src/daft-logical-plan/src/ops/explode.rs +++ b/src/daft-logical-plan/src/ops/explode.rs @@ -7,6 +7,7 @@ use snafu::ResultExt; use crate::{ logical_plan::{self, CreationSnafu}, + stats::{ApproxStats, PlanStats, StatsState}, LogicalPlan, }; @@ -17,6 +18,7 @@ pub struct Explode { // Expressions to explode. e.g. col("a") pub to_explode: Vec, pub exploded_schema: SchemaRef, + pub stats_state: StatsState, } impl Explode { @@ -59,9 +61,22 @@ impl Explode { input, to_explode, exploded_schema, + stats_state: StatsState::NotMaterialized, }) } + pub(crate) fn with_materialized_stats(mut self) -> Self { + let input_stats = self.input.materialized_stats(); + let approx_stats = ApproxStats { + lower_bound_rows: input_stats.approx_stats.lower_bound_rows, + upper_bound_rows: None, + lower_bound_bytes: input_stats.approx_stats.lower_bound_bytes, + upper_bound_bytes: None, + }; + self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); + self + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!( @@ -69,6 +84,9 @@ impl Explode { self.to_explode.iter().map(|e| e.to_string()).join(", ") )); res.push(format!("Schema = {}", self.exploded_schema.short_string())); + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/ops/filter.rs b/src/daft-logical-plan/src/ops/filter.rs index 0f12bf9a49..62bb34a46a 100644 --- a/src/daft-logical-plan/src/ops/filter.rs +++ b/src/daft-logical-plan/src/ops/filter.rs @@ -7,6 +7,7 @@ use snafu::ResultExt; use crate::{ logical_plan::{CreationSnafu, Result}, + stats::{ApproxStats, PlanStats, StatsState}, LogicalPlan, }; @@ -16,6 +17,7 @@ pub struct Filter { pub input: Arc, // The Boolean expression to filter on. pub predicate: ExprRef, + pub stats_state: StatsState, } impl Filter { @@ -33,6 +35,34 @@ impl Filter { ))) .context(CreationSnafu); } - Ok(Self { input, predicate }) + Ok(Self { + input, + predicate, + stats_state: StatsState::NotMaterialized, + }) + } + + pub(crate) fn with_materialized_stats(mut self) -> Self { + // Assume no row/column pruning in cardinality-affecting operations. + // TODO(desmond): We can do better estimations here. 
For now, reuse the old logic. + let input_stats = self.input.materialized_stats(); + let upper_bound_rows = input_stats.approx_stats.upper_bound_rows; + let upper_bound_bytes = input_stats.approx_stats.upper_bound_bytes; + let approx_stats = ApproxStats { + lower_bound_rows: 0, + upper_bound_rows, + lower_bound_bytes: 0, + upper_bound_bytes, + }; + self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); + self + } + + pub fn multiline_display(&self) -> Vec { + let mut res = vec![format!("Filter: {}", self.predicate)]; + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } + res } } diff --git a/src/daft-logical-plan/src/ops/join.rs b/src/daft-logical-plan/src/ops/join.rs index db787cf3a2..5484a5c701 100644 --- a/src/daft-logical-plan/src/ops/join.rs +++ b/src/daft-logical-plan/src/ops/join.rs @@ -18,6 +18,7 @@ use uuid::Uuid; use crate::{ logical_plan::{self, CreationSnafu}, ops::Project, + stats::{ApproxStats, PlanStats, StatsState}, LogicalPlan, }; @@ -33,6 +34,7 @@ pub struct Join { pub join_type: JoinType, pub join_strategy: Option, pub output_schema: SchemaRef, + pub stats_state: StatsState, } impl std::hash::Hash for Join { @@ -49,6 +51,30 @@ impl std::hash::Hash for Join { } impl Join { + #[allow(clippy::too_many_arguments)] + pub(crate) fn new( + left: Arc, + right: Arc, + left_on: Vec, + right_on: Vec, + null_equals_nulls: Option>, + join_type: JoinType, + join_strategy: Option, + output_schema: SchemaRef, + ) -> Self { + Self { + left, + right, + left_on, + right_on, + null_equals_nulls, + join_type, + join_strategy, + output_schema, + stats_state: StatsState::NotMaterialized, + } + } + #[allow(clippy::too_many_arguments)] pub(crate) fn try_new( left: Arc, @@ -129,6 +155,7 @@ impl Join { join_type, join_strategy, output_schema, + stats_state: StatsState::NotMaterialized, }) } else { let common_join_keys: HashSet<_> = @@ -224,6 +251,7 @@ impl Join { join_type, join_strategy, output_schema, + stats_state: StatsState::NotMaterialized, }) } } @@ -283,6 +311,27 @@ impl Join { .unzip() } + pub(crate) fn with_materialized_stats(mut self) -> Self { + // Assume a Primary-key + Foreign-Key join which would yield the max of the two tables. + // TODO(desmond): We can do better estimations here. For now, use the old logic. 
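// Concretely: the row and byte lower bounds are set to zero (the join may produce no
// matches), and each upper bound is the max of the two inputs' upper bounds when both are
// known, otherwise it is left unknown (None).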
+ let left_stats = self.left.materialized_stats(); + let right_stats = self.right.materialized_stats(); + let approx_stats = ApproxStats { + lower_bound_rows: 0, + upper_bound_rows: left_stats + .approx_stats + .upper_bound_rows + .and_then(|l| right_stats.approx_stats.upper_bound_rows.map(|r| l.max(r))), + lower_bound_bytes: 0, + upper_bound_bytes: left_stats + .approx_stats + .upper_bound_bytes + .and_then(|l| right_stats.approx_stats.upper_bound_bytes.map(|r| l.max(r))), + }; + self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); + self + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!("Join: Type = {}", self.join_type)); @@ -320,6 +369,9 @@ impl Join { "Output schema = {}", self.output_schema.short_string() )); + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/ops/limit.rs b/src/daft-logical-plan/src/ops/limit.rs index 4d91ee4a84..fdb2ecab7c 100644 --- a/src/daft-logical-plan/src/ops/limit.rs +++ b/src/daft-logical-plan/src/ops/limit.rs @@ -1,6 +1,9 @@ use std::sync::Arc; -use crate::LogicalPlan; +use crate::{ + stats::{ApproxStats, PlanStats, StatsState}, + LogicalPlan, +}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Limit { @@ -11,6 +14,7 @@ pub struct Limit { // Whether to send tasks in waves (maximize throughput) or // eagerly one-at-a-time (maximize time-to-first-result) pub eager: bool, + pub stats_state: StatsState, } impl Limit { @@ -19,6 +23,46 @@ impl Limit { input, limit, eager, + stats_state: StatsState::NotMaterialized, } } + + pub(crate) fn with_materialized_stats(mut self) -> Self { + let input_stats = self.input.materialized_stats(); + let limit = self.limit as usize; + let est_bytes_per_row_lower = input_stats.approx_stats.lower_bound_bytes + / input_stats.approx_stats.lower_bound_rows.max(1); + let est_bytes_per_row_upper = + input_stats + .approx_stats + .upper_bound_bytes + .and_then(|bytes| { + input_stats + .approx_stats + .upper_bound_rows + .map(|rows| bytes / rows.max(1)) + }); + let new_lower_rows = input_stats.approx_stats.lower_bound_rows.min(limit); + let new_upper_rows = input_stats + .approx_stats + .upper_bound_rows + .map(|ub| ub.min(limit)) + .unwrap_or(limit); + let approx_stats = ApproxStats { + lower_bound_rows: new_lower_rows, + upper_bound_rows: Some(new_upper_rows), + lower_bound_bytes: new_lower_rows * est_bytes_per_row_lower, + upper_bound_bytes: est_bytes_per_row_upper.map(|x| x * new_upper_rows), + }; + self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); + self + } + + pub fn multiline_display(&self) -> Vec { + let mut res = vec![format!("Limit: {}", self.limit)]; + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } + res + } } diff --git a/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs b/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs index 1e798f4ff3..9be863a686 100644 --- a/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs +++ b/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs @@ -2,13 +2,14 @@ use std::sync::Arc; use daft_core::prelude::*; -use crate::LogicalPlan; +use crate::{stats::StatsState, LogicalPlan}; #[derive(Hash, Eq, PartialEq, Debug, Clone)] pub struct MonotonicallyIncreasingId { pub input: Arc, pub schema: Arc, pub column_name: String, + pub stats_state: StatsState, } impl MonotonicallyIncreasingId { @@ -28,6 
+29,22 @@ impl MonotonicallyIncreasingId { input, schema: Arc::new(schema_with_id), column_name: column_name.to_string(), + stats_state: StatsState::NotMaterialized, } } + + pub(crate) fn with_materialized_stats(mut self) -> Self { + // TODO(desmond): We can do better estimations with the projection schema. For now, reuse the old logic. + let input_stats = self.input.materialized_stats(); + self.stats_state = StatsState::Materialized(input_stats.clone().into()); + self + } + + pub fn multiline_display(&self) -> Vec { + let mut res = vec![format!("MonotonicallyIncreasingId")]; + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } + res + } } diff --git a/src/daft-logical-plan/src/ops/pivot.rs b/src/daft-logical-plan/src/ops/pivot.rs index da204fdb34..57ee3bb1c5 100644 --- a/src/daft-logical-plan/src/ops/pivot.rs +++ b/src/daft-logical-plan/src/ops/pivot.rs @@ -9,6 +9,7 @@ use snafu::ResultExt; use crate::{ logical_plan::{self, CreationSnafu}, + stats::StatsState, LogicalPlan, }; @@ -21,6 +22,7 @@ pub struct Pivot { pub aggregation: AggExpr, pub names: Vec, pub output_schema: SchemaRef, + pub stats_state: StatsState, } impl Pivot { @@ -78,9 +80,17 @@ impl Pivot { aggregation: agg_expr.clone(), names, output_schema, + stats_state: StatsState::NotMaterialized, }) } + pub(crate) fn with_materialized_stats(mut self) -> Self { + // TODO(desmond): Pivoting does affect cardinality, but for now we keep the old logic. + let input_stats = self.input.materialized_stats(); + self.stats_state = StatsState::Materialized(input_stats.clone().into()); + self + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push("Pivot:".to_string()); @@ -95,6 +105,9 @@ impl Pivot { "Output schema = {}", self.output_schema.short_string() )); + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/ops/project.rs b/src/daft-logical-plan/src/ops/project.rs index 40f0565102..dc165d5c5c 100644 --- a/src/daft-logical-plan/src/ops/project.rs +++ b/src/daft-logical-plan/src/ops/project.rs @@ -9,6 +9,7 @@ use snafu::ResultExt; use crate::{ logical_plan::{CreationSnafu, Result}, + stats::StatsState, LogicalPlan, }; @@ -18,6 +19,7 @@ pub struct Project { pub input: Arc, pub projection: Vec, pub projected_schema: SchemaRef, + pub stats_state: StatsState, } impl Project { @@ -38,8 +40,10 @@ impl Project { input: factored_input, projection: factored_projection, projected_schema, + stats_state: StatsState::NotMaterialized, }) } + /// Create a new Projection using the specified output schema pub(crate) fn new_from_schema(input: Arc, schema: SchemaRef) -> Result { let expr: Vec = schema @@ -50,11 +54,22 @@ impl Project { Self::try_new(input, expr) } + pub(crate) fn with_materialized_stats(mut self) -> Self { + // TODO(desmond): We can do better estimations with the projection schema. For now, reuse the old logic. 
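+ // A projection does not change the row count, so we clone the input's stats as-is; byte estimates are not yet adjusted for added or dropped columns.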
+ let input_stats = self.input.materialized_stats(); + self.stats_state = StatsState::Materialized(input_stats.clone().into()); + self + } + pub fn multiline_display(&self) -> Vec { - vec![format!( + let mut res = vec![format!( "Project: {}", self.projection.iter().map(|e| e.to_string()).join(", ") - )] + )]; + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } + res } fn try_factor_subexpressions( diff --git a/src/daft-logical-plan/src/ops/repartition.rs b/src/daft-logical-plan/src/ops/repartition.rs index 1dce616d62..ac12970c49 100644 --- a/src/daft-logical-plan/src/ops/repartition.rs +++ b/src/daft-logical-plan/src/ops/repartition.rs @@ -5,6 +5,7 @@ use daft_dsl::ExprResolver; use crate::{ partitioning::{HashRepartitionConfig, RepartitionSpec}, + stats::StatsState, LogicalPlan, }; @@ -13,6 +14,7 @@ pub struct Repartition { // Upstream node. pub input: Arc, pub repartition_spec: RepartitionSpec, + pub stats_state: StatsState, } impl Repartition { @@ -36,9 +38,17 @@ impl Repartition { Ok(Self { input, repartition_spec, + stats_state: StatsState::NotMaterialized, }) } + pub(crate) fn with_materialized_stats(mut self) -> Self { + // Repartitioning does not affect cardinality. + let input_stats = self.input.materialized_stats(); + self.stats_state = StatsState::Materialized(input_stats.clone().into()); + self + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!( @@ -46,6 +56,9 @@ impl Repartition { self.repartition_spec.var_name(), )); res.extend(self.repartition_spec.multiline_display()); + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/ops/sample.rs b/src/daft-logical-plan/src/ops/sample.rs index 9d96594666..d11cc5d1d6 100644 --- a/src/daft-logical-plan/src/ops/sample.rs +++ b/src/daft-logical-plan/src/ops/sample.rs @@ -3,7 +3,10 @@ use std::{ sync::Arc, }; -use crate::LogicalPlan; +use crate::{ + stats::{PlanStats, StatsState}, + LogicalPlan, +}; #[derive(Clone, Debug, PartialEq)] pub struct Sample { @@ -12,6 +15,7 @@ pub struct Sample { pub fraction: f64, pub with_replacement: bool, pub seed: Option, + pub stats_state: StatsState, } impl Eq for Sample {} @@ -44,14 +48,28 @@ impl Sample { fraction, with_replacement, seed, + stats_state: StatsState::NotMaterialized, } } + pub(crate) fn with_materialized_stats(mut self) -> Self { + // TODO(desmond): We can do better estimations with the projection schema. For now, reuse the old logic. 
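+ // Scale both the row and byte estimates by the sampling fraction.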
+ let input_stats = self.input.materialized_stats(); + let approx_stats = input_stats + .approx_stats + .apply(|v| ((v as f64) * self.fraction) as usize); + self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); + self + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!("Sample: {}", self.fraction)); res.push(format!("With replacement = {}", self.with_replacement)); res.push(format!("Seed = {:?}", self.seed)); + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/ops/set_operations.rs b/src/daft-logical-plan/src/ops/set_operations.rs index 017104f226..65ceb807b8 100644 --- a/src/daft-logical-plan/src/ops/set_operations.rs +++ b/src/daft-logical-plan/src/ops/set_operations.rs @@ -183,10 +183,7 @@ impl Union { (self.lhs.clone(), self.rhs.clone()) }; // we don't want to use `try_new` as we have already checked the schema - let concat = LogicalPlan::Concat(Concat { - input: lhs, - other: rhs, - }); + let concat = LogicalPlan::Concat(Concat::new(lhs, rhs)); if self.is_all { Ok(concat) } else { diff --git a/src/daft-logical-plan/src/ops/sink.rs b/src/daft-logical-plan/src/ops/sink.rs index 1370ab91f4..e5eb9f3f2e 100644 --- a/src/daft-logical-plan/src/ops/sink.rs +++ b/src/daft-logical-plan/src/ops/sink.rs @@ -6,7 +6,11 @@ use daft_dsl::ExprResolver; #[cfg(feature = "python")] use crate::sink_info::CatalogType; -use crate::{sink_info::SinkInfo, LogicalPlan, OutputFileInfo}; +use crate::{ + sink_info::SinkInfo, + stats::{PlanStats, StatsState}, + LogicalPlan, OutputFileInfo, +}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Sink { @@ -15,6 +19,7 @@ pub struct Sink { pub schema: SchemaRef, /// Information about the sink data location. pub sink_info: Arc, + pub stats_state: StatsState, } impl Sink { @@ -82,9 +87,17 @@ impl Sink { input, schema, sink_info, + stats_state: StatsState::NotMaterialized, }) } + pub(crate) fn with_materialized_stats(mut self) -> Self { + // Post-write DataFrame will contain paths to files that were written. + // TODO(desmond): Estimate output size via root directory and estimates for # of partitions given partitioning column. + self.stats_state = StatsState::Materialized(PlanStats::empty().into()); + self + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; @@ -110,6 +123,9 @@ impl Sink { }, } res.push(format!("Output schema = {}", self.schema.short_string())); + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/ops/sort.rs b/src/daft-logical-plan/src/ops/sort.rs index 85cd8c2a64..9c2cd046fd 100644 --- a/src/daft-logical-plan/src/ops/sort.rs +++ b/src/daft-logical-plan/src/ops/sort.rs @@ -6,7 +6,7 @@ use daft_dsl::{ExprRef, ExprResolver}; use itertools::Itertools; use snafu::ResultExt; -use crate::{logical_plan, logical_plan::CreationSnafu, LogicalPlan}; +use crate::{logical_plan, logical_plan::CreationSnafu, stats::StatsState, LogicalPlan}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Sort { @@ -15,6 +15,7 @@ pub struct Sort { pub sort_by: Vec, pub descending: Vec, pub nulls_first: Vec, + pub stats_state: StatsState, } impl Sort { @@ -54,9 +55,17 @@ impl Sort { sort_by, descending, nulls_first, + stats_state: StatsState::NotMaterialized, }) } + pub(crate) fn with_materialized_stats(mut self) -> Self { + // Sorting does not affect cardinality. 
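+ // Propagate the input's stats unchanged.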
+ let input_stats = self.input.materialized_stats(); + self.stats_state = StatsState::Materialized(input_stats.clone().into()); + self + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; // Must have at least one expression to sort by. @@ -76,6 +85,9 @@ impl Sort { }) .join(", "); res.push(format!("Sort: Sort by = {}", pairs)); + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/ops/source.rs b/src/daft-logical-plan/src/ops/source.rs index 111575f411..4044e08c72 100644 --- a/src/daft-logical-plan/src/ops/source.rs +++ b/src/daft-logical-plan/src/ops/source.rs @@ -1,9 +1,13 @@ use std::sync::Arc; -use common_scan_info::PhysicalScanInfo; +use common_error::DaftResult; +use common_scan_info::{PhysicalScanInfo, ScanState}; use daft_schema::schema::SchemaRef; -use crate::source_info::{InMemoryInfo, PlaceHolderInfo, SourceInfo}; +use crate::{ + source_info::{InMemoryInfo, PlaceHolderInfo, SourceInfo}, + stats::{ApproxStats, PlanStats, StatsState}, +}; #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct Source { @@ -13,6 +17,7 @@ pub struct Source { /// Information about the source data location. pub source_info: Arc, + pub stats_state: StatsState, } impl Source { @@ -20,21 +25,92 @@ impl Source { Self { output_schema, source_info, + stats_state: StatsState::NotMaterialized, } } + // Helper method that converts the ScanOperatorRef inside a Source node's PhysicalScanInfo into scan tasks. + // Should only be called if a Source node's source info contains PhysicalScanInfo. The PhysicalScanInfo + // should also hold a ScanState::Operator and not a ScanState::Tasks (which would indicate that we're + // materializing this physical scan node multiple times). + pub(crate) fn build_materialized_scan_source(mut self) -> DaftResult { + let new_physical_scan_info = match Arc::unwrap_or_clone(self.source_info) { + SourceInfo::Physical(mut physical_scan_info) => { + let scan_tasks = match &physical_scan_info.scan_state { + ScanState::Operator(scan_op) => scan_op + .0 + .to_scan_tasks(physical_scan_info.pushdowns.clone())?, + ScanState::Tasks(_) => { + panic!("Physical scan nodes are being materialized more than once"); + } + }; + physical_scan_info.scan_state = ScanState::Tasks(Arc::new(scan_tasks)); + physical_scan_info + } + _ => panic!("Only unmaterialized physical scan nodes can be materialized"), + }; + self.source_info = Arc::new(SourceInfo::Physical(new_physical_scan_info)); + Ok(self) + } + + pub(crate) fn with_materialized_stats(mut self) -> Self { + let approx_stats = match &*self.source_info { + SourceInfo::InMemory(InMemoryInfo { + size_bytes, + num_rows, + .. 
+ }) => ApproxStats { + lower_bound_rows: *num_rows, + upper_bound_rows: Some(*num_rows), + lower_bound_bytes: *size_bytes, + upper_bound_bytes: Some(*size_bytes), + }, + SourceInfo::Physical(physical_scan_info) => match &physical_scan_info.scan_state { + ScanState::Operator(_) => { + panic!("Scan nodes should be materialized before stats are materialized") + } + ScanState::Tasks(scan_tasks) => { + let mut approx_stats = ApproxStats::empty(); + for st in scan_tasks.iter() { + approx_stats.lower_bound_rows += st.num_rows().unwrap_or(0); + let in_memory_size = st.estimate_in_memory_size_bytes(None); + approx_stats.lower_bound_bytes += in_memory_size.unwrap_or(0); + if let Some(st_ub) = st.upper_bound_rows() { + if let Some(ub) = approx_stats.upper_bound_rows { + approx_stats.upper_bound_rows = Some(ub + st_ub); + } else { + approx_stats.upper_bound_rows = st.upper_bound_rows(); + } + } + if let Some(st_ub) = in_memory_size { + if let Some(ub) = approx_stats.upper_bound_bytes { + approx_stats.upper_bound_bytes = Some(ub + st_ub); + } else { + approx_stats.upper_bound_bytes = in_memory_size; + } + } + } + approx_stats + } + }, + SourceInfo::PlaceHolder(_) => ApproxStats::empty(), + }; + self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); + self + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; match self.source_info.as_ref() { SourceInfo::Physical(PhysicalScanInfo { source_schema, - scan_op, + scan_state, partitioning_keys, pushdowns, }) => { use itertools::Itertools; - res.extend(scan_op.0.multiline_display()); + res.extend(scan_state.multiline_display()); res.push(format!("File schema = {}", source_schema.short_string())); res.push(format!( @@ -61,6 +137,9 @@ impl Source { "Output schema = {}", self.output_schema.short_string() )); + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/ops/unpivot.rs b/src/daft-logical-plan/src/ops/unpivot.rs index cec9cd1c00..46a7071bf5 100644 --- a/src/daft-logical-plan/src/ops/unpivot.rs +++ b/src/daft-logical-plan/src/ops/unpivot.rs @@ -8,6 +8,7 @@ use snafu::ResultExt; use crate::{ logical_plan::{self, CreationSnafu}, + stats::{ApproxStats, PlanStats, StatsState}, LogicalPlan, }; @@ -19,9 +20,30 @@ pub struct Unpivot { pub variable_name: String, pub value_name: String, pub output_schema: SchemaRef, + pub stats_state: StatsState, } impl Unpivot { + pub(crate) fn new( + input: Arc, + ids: Vec, + values: Vec, + variable_name: String, + value_name: String, + output_schema: SchemaRef, + ) -> Self { + Self { + input, + ids, + values, + variable_name, + value_name, + output_schema, + stats_state: StatsState::NotMaterialized, + } + } + + // Similar to new, except that `try_new` is not given the output schema and instead extracts it. 
pub(crate) fn try_new( input: Arc, ids: Vec, @@ -71,9 +93,26 @@ impl Unpivot { variable_name: variable_name.to_string(), value_name: value_name.to_string(), output_schema, + stats_state: StatsState::NotMaterialized, }) } + pub(crate) fn with_materialized_stats(mut self) -> Self { + let input_stats = self.input.materialized_stats(); + let num_values = self.values.len(); + let approx_stats = ApproxStats { + lower_bound_rows: input_stats.approx_stats.lower_bound_rows * num_values, + upper_bound_rows: input_stats + .approx_stats + .upper_bound_rows + .map(|v| v * num_values), + lower_bound_bytes: input_stats.approx_stats.lower_bound_bytes, + upper_bound_bytes: input_stats.approx_stats.upper_bound_bytes, + }; + self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into()); + self + } + pub fn multiline_display(&self) -> Vec { let mut res = vec![]; res.push(format!( @@ -85,6 +124,9 @@ impl Unpivot { self.ids.iter().map(|e| e.to_string()).join(", ") )); res.push(format!("Schema = {}", self.output_schema.short_string())); + if let StatsState::Materialized(stats) = &self.stats_state { + res.push(format!("Stats = {}", stats)); + } res } } diff --git a/src/daft-logical-plan/src/optimization/optimizer.rs b/src/daft-logical-plan/src/optimization/optimizer.rs index 084018522a..61a3ff314e 100644 --- a/src/daft-logical-plan/src/optimization/optimizer.rs +++ b/src/daft-logical-plan/src/optimization/optimizer.rs @@ -6,8 +6,8 @@ use common_treenode::Transformed; use super::{ logical_plan_tracker::LogicalPlanTracker, rules::{ - DropRepartition, EliminateCrossJoin, LiftProjectFromAgg, OptimizerRule, PushDownFilter, - PushDownLimit, PushDownProjection, SplitActorPoolProjects, + DropRepartition, EliminateCrossJoin, EnrichWithStats, LiftProjectFromAgg, MaterializeScans, + OptimizerRule, PushDownFilter, PushDownLimit, PushDownProjection, SplitActorPoolProjects, }, }; use crate::LogicalPlan; @@ -136,6 +136,18 @@ impl Optimizer { RuleExecutionStrategy::FixedPoint(Some(3)), )); + // --- Materialize scan nodes --- + rule_batches.push(RuleBatch::new( + vec![Box::new(MaterializeScans::new())], + RuleExecutionStrategy::Once, + )); + + // --- Enrich logical plan with stats --- + rule_batches.push(RuleBatch::new( + vec![Box::new(EnrichWithStats::new())], + RuleExecutionStrategy::Once, + )); + Self::with_rule_batches(rule_batches, config) } diff --git a/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs b/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs index c8e888fecf..e9e3a2e524 100644 --- a/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs +++ b/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs @@ -52,7 +52,9 @@ impl OptimizerRule for EliminateCrossJoin { if !can_flatten_join_inputs(filter.input.as_ref()) { return Ok(Transformed::no(Arc::new(LogicalPlan::Filter(filter)))); } - let Filter { input, predicate } = filter; + let Filter { + input, predicate, .. 
+ } = filter; flatten_join_inputs( Arc::unwrap_or_clone(input), &mut possible_join_keys, @@ -306,16 +308,16 @@ fn find_inner_join( .non_distinct_union(right_input.schema().as_ref()); let (left_keys, right_keys) = join_keys.iter().cloned().unzip(); - return Ok(LogicalPlan::Join(Join { - left: left_input, - right: right_input, - left_on: left_keys, - right_on: right_keys, - null_equals_nulls: None, - join_type: JoinType::Inner, - join_strategy: None, - output_schema: Arc::new(join_schema), - }) + return Ok(LogicalPlan::Join(Join::new( + left_input, + right_input, + left_keys, + right_keys, + None, + JoinType::Inner, + None, + Arc::new(join_schema), + )) .arced()); } } @@ -327,16 +329,16 @@ fn find_inner_join( .schema() .non_distinct_union(right.schema().as_ref()); - Ok(LogicalPlan::Join(Join { - left: left_input, + Ok(LogicalPlan::Join(Join::new( + left_input, right, - left_on: vec![], - right_on: vec![], - null_equals_nulls: None, - join_type: JoinType::Inner, - join_strategy: None, - output_schema: Arc::new(join_schema), - }) + vec![], + vec![], + None, + JoinType::Inner, + None, + Arc::new(join_schema), + )) .arced()) } @@ -449,14 +451,14 @@ mod tests { ]) .unwrap(), ); - LogicalPlan::Source(Source { - output_schema: schema.clone(), - source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { + LogicalPlan::Source(Source::new( + schema.clone(), + Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { source_schema: schema, clustering_spec: Arc::new(ClusteringSpec::unknown()), source_id: 0, })), - }) + )) .arced() } @@ -470,14 +472,14 @@ mod tests { ]) .unwrap(), ); - LogicalPlan::Source(Source { - output_schema: schema.clone(), - source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { + LogicalPlan::Source(Source::new( + schema.clone(), + Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { source_schema: schema, clustering_spec: Arc::new(ClusteringSpec::unknown()), source_id: 0, })), - }) + )) .arced() } diff --git a/src/daft-logical-plan/src/optimization/rules/enrich_with_stats.rs b/src/daft-logical-plan/src/optimization/rules/enrich_with_stats.rs new file mode 100644 index 0000000000..f6bc338ab6 --- /dev/null +++ b/src/daft-logical-plan/src/optimization/rules/enrich_with_stats.rs @@ -0,0 +1,27 @@ +#[derive(Default, Debug)] +pub struct EnrichWithStats {} + +impl EnrichWithStats { + pub fn new() -> Self { + Self {} + } +} +use std::sync::Arc; + +use common_error::DaftResult; +use common_treenode::{Transformed, TreeNode}; + +use super::OptimizerRule; +use crate::LogicalPlan; + +// Add stats to all logical plan nodes in a bottom up fashion. +// All scan nodes MUST be materialized before stats are enriched. 
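+// `transform_up` rewrites children before parents, so every call to `with_materialized_stats` can read already-materialized stats from its inputs.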
+impl OptimizerRule for EnrichWithStats { + fn try_optimize(&self, plan: Arc) -> DaftResult>> { + plan.transform_up(|node: Arc| { + Ok(Transformed::yes( + Arc::unwrap_or_clone(node).with_materialized_stats().into(), + )) + }) + } +} diff --git a/src/daft-logical-plan/src/optimization/rules/materialize_scans.rs b/src/daft-logical-plan/src/optimization/rules/materialize_scans.rs new file mode 100644 index 0000000000..4b9f9707ed --- /dev/null +++ b/src/daft-logical-plan/src/optimization/rules/materialize_scans.rs @@ -0,0 +1,47 @@ +#[derive(Default, Debug)] +pub struct MaterializeScans {} + +impl MaterializeScans { + pub fn new() -> Self { + Self {} + } +} +use std::sync::Arc; + +use common_error::DaftResult; +use common_treenode::{Transformed, TreeNode}; + +use super::OptimizerRule; +use crate::{LogicalPlan, SourceInfo}; + +// Materialize scan tasks from scan operators for all physical scans. +impl OptimizerRule for MaterializeScans { + fn try_optimize(&self, plan: Arc) -> DaftResult>> { + plan.transform_up(|node| self.try_optimize_node(node)) + } +} + +impl MaterializeScans { + #[allow(clippy::only_used_in_recursion)] + fn try_optimize_node( + &self, + plan: Arc, + ) -> DaftResult>> { + match &*plan { + LogicalPlan::Source(source) => match &*source.source_info { + SourceInfo::Physical(_) => { + let source_plan = Arc::unwrap_or_clone(plan); + if let LogicalPlan::Source(source) = source_plan { + Ok(Transformed::yes( + source.build_materialized_scan_source()?.into(), + )) + } else { + unreachable!("This logical plan was already matched as a Source node") + } + } + _ => Ok(Transformed::no(plan)), + }, + _ => Ok(Transformed::no(plan)), + } + } +} diff --git a/src/daft-logical-plan/src/optimization/rules/mod.rs b/src/daft-logical-plan/src/optimization/rules/mod.rs index 78ce51533a..75e0f36c88 100644 --- a/src/daft-logical-plan/src/optimization/rules/mod.rs +++ b/src/daft-logical-plan/src/optimization/rules/mod.rs @@ -1,6 +1,8 @@ mod drop_repartition; mod eliminate_cross_join; +mod enrich_with_stats; mod lift_project_from_agg; +mod materialize_scans; mod push_down_filter; mod push_down_limit; mod push_down_projection; @@ -9,7 +11,9 @@ mod split_actor_pool_projects; pub use drop_repartition::DropRepartition; pub use eliminate_cross_join::EliminateCrossJoin; +pub use enrich_with_stats::EnrichWithStats; pub use lift_project_from_agg::LiftProjectFromAgg; +pub use materialize_scans::MaterializeScans; pub use push_down_filter::PushDownFilter; pub use push_down_limit::PushDownLimit; pub use push_down_projection::PushDownProjection; diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs index 2cf2ea14ad..9fa30ea8e5 100644 --- a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs +++ b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs @@ -113,7 +113,7 @@ impl PushDownFilter { needing_filter_op, } = rewrite_predicate_for_partitioning( &new_predicate, - external_info.scan_op.0.partitioning_keys(), + external_info.scan_state.get_scan_op().0.partitioning_keys(), )?; assert!( partition_only_filter.len() @@ -239,7 +239,7 @@ impl PushDownFilter { .into(); child_plan.with_new_children(&[new_filter]).into() } - LogicalPlan::Concat(Concat { input, other }) => { + LogicalPlan::Concat(Concat { input, other, .. }) => { // Push filter into each side of the concat. 
let new_input: LogicalPlan = Filter::try_new(input.clone(), filter.predicate.clone())?.into(); diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_limit.rs b/src/daft-logical-plan/src/optimization/rules/push_down_limit.rs index b32ab85fbb..e79879604d 100644 --- a/src/daft-logical-plan/src/optimization/rules/push_down_limit.rs +++ b/src/daft-logical-plan/src/optimization/rules/push_down_limit.rs @@ -37,6 +37,7 @@ impl PushDownLimit { input, limit, eager, + .. }) => { let limit = *limit as usize; match input.as_ref() { @@ -74,7 +75,12 @@ impl PushDownLimit { SourceInfo::Physical(new_external_info).into(), )) .into(); - let out_plan = if external_info.scan_op.0.can_absorb_limit() { + let out_plan = if external_info + .scan_state + .get_scan_op() + .0 + .can_absorb_limit() + { new_source } else { plan.with_new_children(&[new_source]).into() @@ -93,6 +99,7 @@ impl PushDownLimit { input, limit: child_limit, eager: child_eagar, + .. }) => { let new_limit = limit.min(*child_limit as usize); let new_eager = eager | child_eagar; diff --git a/src/daft-logical-plan/src/source_info/mod.rs b/src/daft-logical-plan/src/source_info/mod.rs index 07ae94b841..11122464a0 100644 --- a/src/daft-logical-plan/src/source_info/mod.rs +++ b/src/daft-logical-plan/src/source_info/mod.rs @@ -16,7 +16,7 @@ use { use crate::partitioning::ClusteringSpecRef; -#[derive(Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum SourceInfo { InMemory(InMemoryInfo), Physical(PhysicalScanInfo), @@ -78,7 +78,7 @@ impl Hash for InMemoryInfo { static PLACEHOLDER_ID_COUNTER: AtomicUsize = AtomicUsize::new(0); -#[derive(Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct PlaceHolderInfo { pub source_schema: SchemaRef, pub clustering_spec: ClusteringSpecRef, diff --git a/src/daft-logical-plan/src/stats.rs b/src/daft-logical-plan/src/stats.rs new file mode 100644 index 0000000000..22c3f85198 --- /dev/null +++ b/src/daft-logical-plan/src/stats.rs @@ -0,0 +1,146 @@ +use std::{fmt::Display, hash::Hash, ops::Deref}; + +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq, Hash)] +pub enum StatsState { + Materialized(AlwaysSame), + NotMaterialized, +} + +impl StatsState { + pub fn materialized_stats(&self) -> &PlanStats { + match self { + Self::Materialized(stats) => stats, + Self::NotMaterialized => panic!("Tried to get unmaterialized stats"), + } + } +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct PlanStats { + // Currently we're only putting cardinality stats in the plan stats. + // In the future we want to start including column stats, including min, max, NDVs, etc. 
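+ // Lower and upper bounds on the row count and the in-memory size in bytes.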
+ pub approx_stats: ApproxStats, +} + +impl PlanStats { + pub fn new(approx_stats: ApproxStats) -> Self { + Self { approx_stats } + } + + pub fn empty() -> Self { + Self { + approx_stats: ApproxStats::empty(), + } + } +} + +impl Default for PlanStats { + fn default() -> Self { + Self::empty() + } +} + +impl Display for PlanStats { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{{ Lower bound rows = {}, Upper bound rows = {}, Lower bound bytes = {}, Upper bound bytes = {} }}", + self.approx_stats.lower_bound_rows, + self.approx_stats.upper_bound_rows.map_or("None".to_string(), |v| v.to_string()), + self.approx_stats.lower_bound_bytes, + self.approx_stats.upper_bound_bytes.map_or("None".to_string(), |v| v.to_string()), + ) + } +} + +// We implement PartialEq, Eq, and Hash for AlwaysSame, then add PlanStats to LogicalPlans wrapped by AlwaysSame. +// This allows all PlanStats to be considered equal, so that logical/physical plans that are enriched with +// stats can easily implement PartialEq, Eq, and Hash in a way that ignores PlanStats when considering equality. + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct AlwaysSame(T); + +impl Deref for AlwaysSame { + type Target = T; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Hash for AlwaysSame { + #[inline] + fn hash(&self, _state: &mut H) { + // Add nothing to hash state since all AlwaysSame should hash the same. + } +} + +impl Eq for AlwaysSame {} + +impl PartialEq for AlwaysSame { + #[inline] + fn eq(&self, _other: &Self) -> bool { + true + } +} + +impl From for AlwaysSame { + #[inline] + fn from(value: T) -> Self { + Self(value) + } +} + +impl Display for AlwaysSame { + #[inline] + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +#[derive(Clone, Debug, PartialEq, Eq, Deserialize, Serialize)] +pub struct ApproxStats { + pub lower_bound_rows: usize, + pub upper_bound_rows: Option, + pub lower_bound_bytes: usize, + pub upper_bound_bytes: Option, +} + +impl ApproxStats { + pub fn empty() -> Self { + Self { + lower_bound_rows: 0, + upper_bound_rows: None, + lower_bound_bytes: 0, + upper_bound_bytes: None, + } + } + pub fn apply usize>(&self, f: F) -> Self { + Self { + lower_bound_rows: f(self.lower_bound_rows), + upper_bound_rows: self.upper_bound_rows.map(&f), + lower_bound_bytes: f(self.lower_bound_rows), + upper_bound_bytes: self.upper_bound_bytes.map(&f), + } + } +} + +use std::ops::Add; +impl Add for &ApproxStats { + type Output = ApproxStats; + fn add(self, rhs: Self) -> Self::Output { + ApproxStats { + lower_bound_rows: self.lower_bound_rows + rhs.lower_bound_rows, + upper_bound_rows: self + .upper_bound_rows + .and_then(|l_ub| rhs.upper_bound_rows.map(|v| v + l_ub)), + lower_bound_bytes: self.lower_bound_bytes + rhs.lower_bound_bytes, + upper_bound_bytes: self + .upper_bound_bytes + .and_then(|l_ub| rhs.upper_bound_bytes.map(|v| v + l_ub)), + } + } +} diff --git a/src/daft-parquet/src/metadata.rs b/src/daft-parquet/src/metadata.rs index 32c1090ddd..aab541c04f 100644 --- a/src/daft-parquet/src/metadata.rs +++ b/src/daft-parquet/src/metadata.rs @@ -176,8 +176,8 @@ fn apply_field_ids_to_parquet_file_metadata( let new_row_groups = file_metadata .row_groups - .into_values() - .map(|rg| { + .iter() + .map(|(_, rg)| { let new_columns = rg .columns() .iter() diff --git a/src/daft-physical-plan/src/ops/scan.rs b/src/daft-physical-plan/src/ops/scan.rs index 6c8472beda..6333b05100 100644 --- 
a/src/daft-physical-plan/src/ops/scan.rs +++ b/src/daft-physical-plan/src/ops/scan.rs @@ -6,15 +6,15 @@ use common_scan_info::ScanTaskLikeRef; use daft_logical_plan::partitioning::ClusteringSpec; use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct TabularScan { - pub scan_tasks: Vec, + pub scan_tasks: Arc>, pub clustering_spec: Arc, } impl TabularScan { pub(crate) fn new( - scan_tasks: Vec, + scan_tasks: Arc>, clustering_spec: Arc, ) -> Self { Self { @@ -100,7 +100,7 @@ Clustering spec = {{ {clustering_spec} }} let mut s = base_display(self); writeln!(s, "Scan Tasks: [").unwrap(); - for st in &self.scan_tasks { + for st in self.scan_tasks.iter() { writeln!(s, "{}", st.as_ref().display_as(DisplayLevel::Verbose)).unwrap(); } s diff --git a/src/daft-physical-plan/src/physical_planner/planner.rs b/src/daft-physical-plan/src/physical_planner/planner.rs index f5c146cd2e..23ad09376d 100644 --- a/src/daft-physical-plan/src/physical_planner/planner.rs +++ b/src/daft-physical-plan/src/physical_planner/planner.rs @@ -215,6 +215,7 @@ impl TreeNodeRewriter for ReplacePlaceholdersWithMaterializedResult { LogicalPlan::Source(Source { output_schema: _, source_info, + .. }) => match source_info.as_ref() { SourceInfo::PlaceHolder(phi) => { assert!(self.mat_results.is_some()); @@ -226,10 +227,10 @@ impl TreeNodeRewriter for ReplacePlaceholdersWithMaterializedResult { mat_results.in_memory_info.clustering_spec = Some(phi.clustering_spec.clone()); mat_results.in_memory_info.source_schema = phi.source_schema.clone(); - let new_source_node = LogicalPlan::Source(Source { - output_schema: mat_results.in_memory_info.source_schema.clone(), - source_info: SourceInfo::InMemory(mat_results.in_memory_info).into(), - }) + let new_source_node = LogicalPlan::Source(Source::new( + mat_results.in_memory_info.source_schema.clone(), + SourceInfo::InMemory(mat_results.in_memory_info).into(), + )) .arced(); Ok(Transformed::new( new_source_node, diff --git a/src/daft-physical-plan/src/physical_planner/translate.rs b/src/daft-physical-plan/src/physical_planner/translate.rs index 9bffaef97b..ec4e4a1985 100644 --- a/src/daft-physical-plan/src/physical_planner/translate.rs +++ b/src/daft-physical-plan/src/physical_planner/translate.rs @@ -7,7 +7,7 @@ use std::{ use common_daft_config::DaftExecutionConfig; use common_error::{DaftError, DaftResult}; use common_file_formats::FileFormat; -use common_scan_info::PhysicalScanInfo; +use common_scan_info::{PhysicalScanInfo, ScanState, SPLIT_AND_MERGE_PASS}; use daft_core::prelude::*; use daft_dsl::{ col, functions::agg::merge_mean, is_partition_compatible, AggExpr, ApproxPercentileParams, @@ -38,15 +38,22 @@ pub(super) fn translate_single_logical_node( physical_children: &mut Vec, cfg: &DaftExecutionConfig, ) -> DaftResult { - match logical_plan { + let physical_plan = match logical_plan { LogicalPlan::Source(Source { source_info, .. }) => match source_info.as_ref() { SourceInfo::Physical(PhysicalScanInfo { pushdowns, - scan_op, + scan_state, source_schema, .. }) => { - let scan_tasks = scan_op.0.to_scan_tasks(pushdowns.clone(), Some(cfg))?; + let scan_tasks = { + match scan_state { + ScanState::Operator(scan_op) => { + Arc::new(scan_op.0.to_scan_tasks(pushdowns.clone())?) 
+ } + ScanState::Tasks(scan_tasks) => scan_tasks.clone(), + } + }; if scan_tasks.is_empty() { let clustering_spec = @@ -58,6 +65,14 @@ pub(super) fn translate_single_logical_node( )) .arced()) } else { + // Perform scan task splitting and merging. + let scan_tasks = if let Some(split_and_merge_pass) = SPLIT_AND_MERGE_PASS.get() + { + split_and_merge_pass(scan_tasks, pushdowns, cfg)? + } else { + scan_tasks + }; + let clustering_spec = Arc::new(ClusteringSpec::Unknown( UnknownClusteringConfig::new(scan_tasks.len()), )); @@ -205,7 +220,7 @@ pub(super) fn translate_single_logical_node( }; Ok(repartitioned_plan.arced()) } - LogicalPlan::Distinct(LogicalDistinct { input }) => { + LogicalPlan::Distinct(LogicalDistinct { input, .. }) => { let input_physical = physical_children.pop().expect("requires 1 input"); let col_exprs = input .schema() @@ -756,7 +771,12 @@ pub(super) fn translate_single_logical_node( LogicalPlan::Union(_) => Err(DaftError::InternalError( "Union should already be optimized away".to_string(), )), - } + }?; + // TODO(desmond): We can't perform this check for now because ScanTasks currently provide + // different size estimations depending on when the approximation is computed. Once we fix + // this, we can add back in the assertion here. + // debug_assert!(logical_plan.get_stats().approx_stats == physical_plan.approximate_stats()); + Ok(physical_plan) } pub fn extract_agg_expr(expr: &ExprRef) -> DaftResult { diff --git a/src/daft-physical-plan/src/plan.rs b/src/daft-physical-plan/src/plan.rs index 740905b6e8..acf456f14d 100644 --- a/src/daft-physical-plan/src/plan.rs +++ b/src/daft-physical-plan/src/plan.rs @@ -1,8 +1,11 @@ -use std::{cmp::max, collections::HashSet, ops::Add, sync::Arc}; +use std::{cmp::max, collections::HashSet, sync::Arc}; use common_display::ascii::AsciiTreeDisplay; -use daft_logical_plan::partitioning::{ - ClusteringSpec, HashClusteringConfig, RangeClusteringConfig, UnknownClusteringConfig, +use daft_logical_plan::{ + partitioning::{ + ClusteringSpec, HashClusteringConfig, RangeClusteringConfig, UnknownClusteringConfig, + }, + stats::ApproxStats, }; use serde::{Deserialize, Serialize}; @@ -43,48 +46,6 @@ pub enum PhysicalPlan { LanceWrite(LanceWrite), } -pub struct ApproxStats { - pub lower_bound_rows: usize, - pub upper_bound_rows: Option, - pub lower_bound_bytes: usize, - pub upper_bound_bytes: Option, -} - -impl ApproxStats { - fn empty() -> Self { - Self { - lower_bound_rows: 0, - upper_bound_rows: None, - lower_bound_bytes: 0, - upper_bound_bytes: None, - } - } - fn apply usize>(&self, f: F) -> Self { - Self { - lower_bound_rows: f(self.lower_bound_rows), - upper_bound_rows: self.upper_bound_rows.map(&f), - lower_bound_bytes: f(self.lower_bound_rows), - upper_bound_bytes: self.upper_bound_bytes.map(&f), - } - } -} - -impl Add for &ApproxStats { - type Output = ApproxStats; - fn add(self, rhs: Self) -> Self::Output { - ApproxStats { - lower_bound_rows: self.lower_bound_rows + rhs.lower_bound_rows, - upper_bound_rows: self - .upper_bound_rows - .and_then(|l_ub| rhs.upper_bound_rows.map(|v| v + l_ub)), - lower_bound_bytes: self.lower_bound_bytes + rhs.lower_bound_bytes, - upper_bound_bytes: self - .upper_bound_bytes - .and_then(|l_ub| rhs.upper_bound_bytes.map(|v| v + l_ub)), - } - } -} - impl PhysicalPlan { pub fn arced(self) -> PhysicalPlanRef { Arc::new(self) @@ -229,16 +190,24 @@ impl PhysicalPlan { }, Self::TabularScan(TabularScan { scan_tasks, .. 
}) => { let mut stats = ApproxStats::empty(); - for st in scan_tasks { + for st in scan_tasks.iter() { stats.lower_bound_rows += st.num_rows().unwrap_or(0); let in_memory_size = st.estimate_in_memory_size_bytes(None); stats.lower_bound_bytes += in_memory_size.unwrap_or(0); - stats.upper_bound_rows = stats - .upper_bound_rows - .and_then(|st_ub| st.upper_bound_rows().map(|ub| st_ub + ub)); - stats.upper_bound_bytes = stats - .upper_bound_bytes - .and_then(|st_ub| in_memory_size.map(|ub| st_ub + ub)); + if let Some(st_ub) = st.upper_bound_rows() { + if let Some(ub) = stats.upper_bound_rows { + stats.upper_bound_rows = Some(ub + st_ub); + } else { + stats.upper_bound_rows = st.upper_bound_rows(); + } + } + if let Some(st_ub) = in_memory_size { + if let Some(ub) = stats.upper_bound_bytes { + stats.upper_bound_bytes = Some(ub + st_ub); + } else { + stats.upper_bound_bytes = in_memory_size; + } + } } stats } diff --git a/src/daft-scan/Cargo.toml b/src/daft-scan/Cargo.toml index 12eeb71f5b..49344c53ab 100644 --- a/src/daft-scan/Cargo.toml +++ b/src/daft-scan/Cargo.toml @@ -8,6 +8,7 @@ common-io-config = {path = "../common/io-config", default-features = false} common-py-serde = {path = "../common/py-serde", default-features = false} common-runtime = {path = "../common/runtime", default-features = false} common-scan-info = {path = "../common/scan-info", default-features = false} +ctor = "0.2.9" daft-core = {path = "../daft-core", default-features = false} daft-csv = {path = "../daft-csv", default-features = false} daft-decoding = {path = "../daft-decoding", default-features = false} diff --git a/src/daft-scan/src/anonymous.rs b/src/daft-scan/src/anonymous.rs index 2931996809..17f8c6574a 100644 --- a/src/daft-scan/src/anonymous.rs +++ b/src/daft-scan/src/anonymous.rs @@ -1,16 +1,11 @@ use std::sync::Arc; -use common_daft_config::DaftExecutionConfig; use common_error::DaftResult; use common_file_formats::{FileFormatConfig, ParquetSourceConfig}; use common_scan_info::{PartitionField, Pushdowns, ScanOperator, ScanTaskLike, ScanTaskLikeRef}; use daft_schema::schema::SchemaRef; -use crate::{ - scan_task_iters::{merge_by_sizes, split_by_row_groups, BoxScanTaskIter}, - storage_config::StorageConfig, - ChunkSpec, DataSource, ScanTask, -}; +use crate::{storage_config::StorageConfig, ChunkSpec, DataSource, ScanTask}; #[derive(Debug)] pub struct AnonymousScanOperator { files: Vec, @@ -74,11 +69,7 @@ impl ScanOperator for AnonymousScanOperator { lines } - fn to_scan_tasks( - &self, - pushdowns: Pushdowns, - cfg: Option<&DaftExecutionConfig>, - ) -> DaftResult> { + fn to_scan_tasks(&self, pushdowns: Pushdowns) -> DaftResult> { let files = self.files.clone(); let file_format_config = self.file_format_config.clone(); let schema = self.schema.clone(); @@ -95,10 +86,12 @@ impl ScanOperator for AnonymousScanOperator { }; // Create one ScanTask per file. 
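+ // Splitting and merging of scan tasks is now handled by the SPLIT_AND_MERGE_PASS during physical translation, so the operator simply emits one task per file.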
- let mut scan_tasks: BoxScanTaskIter = - Box::new(files.into_iter().zip(row_groups).map(|(f, rg)| { + Ok(files + .into_iter() + .zip(row_groups) + .map(|(f, rg)| { let chunk_spec = rg.map(ChunkSpec::Parquet); - Ok(ScanTask::new( + Arc::new(ScanTask::new( vec![DataSource::File { path: f, chunk_spec, @@ -114,23 +107,9 @@ impl ScanOperator for AnonymousScanOperator { storage_config.clone(), pushdowns.clone(), None, - ) - .into()) - })); - - if let Some(cfg) = cfg { - scan_tasks = split_by_row_groups( - scan_tasks, - cfg.parquet_split_row_groups_max_files, - cfg.scan_tasks_min_size_bytes, - cfg.scan_tasks_max_size_bytes, - ); - - scan_tasks = merge_by_sizes(scan_tasks, &pushdowns, cfg); - } - - scan_tasks - .map(|st| st.map(|task| task as Arc)) - .collect() + )) + }) + .map(|st| st as Arc) + .collect()) } } diff --git a/src/daft-scan/src/glob.rs b/src/daft-scan/src/glob.rs index 899e0ebc89..2f8d0f071f 100644 --- a/src/daft-scan/src/glob.rs +++ b/src/daft-scan/src/glob.rs @@ -1,6 +1,5 @@ use std::{sync::Arc, vec}; -use common_daft_config::DaftExecutionConfig; use common_error::{DaftError, DaftResult}; use common_file_formats::{CsvSourceConfig, FileFormat, FileFormatConfig, ParquetSourceConfig}; use common_runtime::RuntimeRef; @@ -21,7 +20,6 @@ use snafu::Snafu; use crate::{ hive::{hive_partitions_to_fields, hive_partitions_to_series, parse_hive_partitioning}, - scan_task_iters::{merge_by_sizes, split_by_row_groups, BoxScanTaskIter}, storage_config::StorageConfig, ChunkSpec, DataSource, ScanTask, }; @@ -355,11 +353,7 @@ impl ScanOperator for GlobScanOperator { lines } - fn to_scan_tasks( - &self, - pushdowns: Pushdowns, - cfg: Option<&DaftExecutionConfig>, - ) -> DaftResult> { + fn to_scan_tasks(&self, pushdowns: Pushdowns) -> DaftResult> { let (io_runtime, io_client) = self.storage_config.get_io_client_and_runtime()?; let io_stats = IOStatsContext::new(format!( "GlobScanOperator::to_scan_tasks for {:#?}", @@ -397,89 +391,79 @@ impl ScanOperator for GlobScanOperator { .collect(); let partition_schema = Schema::new(partition_fields)?; // Create one ScanTask per file. - let mut scan_tasks: BoxScanTaskIter = Box::new(files.enumerate().filter_map(|(idx, f)| { - let scan_task_result = (|| { - let FileMetadata { - filepath: path, - size: size_bytes, - .. - } = f?; - // Create partition values from hive partitions, if any. - let mut partition_values = if hive_partitioning { - let hive_partitions = parse_hive_partitioning(&path)?; - hive_partitions_to_series(&hive_partitions, &partition_schema)? - } else { - vec![] - }; - // Extend partition values based on whether a file_path_column is set (this column is inherently a partition). - if let Some(fp_col) = &file_path_column { - let trimmed = path.trim_start_matches("file://"); - let file_paths_column_series = - Utf8Array::from_iter(fp_col, std::iter::once(Some(trimmed))).into_series(); - partition_values.push(file_paths_column_series); - } - let (partition_spec, generated_fields) = if !partition_values.is_empty() { - let partition_values_table = Table::from_nonempty_columns(partition_values)?; - // If there are partition values, evaluate them against partition filters, if any. - if let Some(partition_filters) = &pushdowns.partition_filters { - let filter_result = - partition_values_table.filter(&[partition_filters.clone()])?; - if filter_result.is_empty() { - // Skip the current file since it does not satisfy the partition filters. 
- return Ok(None); - } + files + .enumerate() + .filter_map(|(idx, f)| { + let scan_task_result = (|| { + let FileMetadata { + filepath: path, + size: size_bytes, + .. + } = f?; + // Create partition values from hive partitions, if any. + let mut partition_values = if hive_partitioning { + let hive_partitions = parse_hive_partitioning(&path)?; + hive_partitions_to_series(&hive_partitions, &partition_schema)? + } else { + vec![] + }; + // Extend partition values based on whether a file_path_column is set (this column is inherently a partition). + if let Some(fp_col) = &file_path_column { + let trimmed = path.trim_start_matches("file://"); + let file_paths_column_series = + Utf8Array::from_iter(fp_col, std::iter::once(Some(trimmed))) + .into_series(); + partition_values.push(file_paths_column_series); } - let generated_fields = partition_values_table.schema.clone(); - let partition_spec = PartitionSpec { - keys: partition_values_table, + let (partition_spec, generated_fields) = if !partition_values.is_empty() { + let partition_values_table = + Table::from_nonempty_columns(partition_values)?; + // If there are partition values, evaluate them against partition filters, if any. + if let Some(partition_filters) = &pushdowns.partition_filters { + let filter_result = + partition_values_table.filter(&[partition_filters.clone()])?; + if filter_result.is_empty() { + // Skip the current file since it does not satisfy the partition filters. + return Ok(None); + } + } + let generated_fields = partition_values_table.schema.clone(); + let partition_spec = PartitionSpec { + keys: partition_values_table, + }; + (Some(partition_spec), Some(generated_fields)) + } else { + (None, None) }; - (Some(partition_spec), Some(generated_fields)) - } else { - (None, None) - }; - let row_group = row_groups - .as_ref() - .and_then(|rgs| rgs.get(idx).cloned()) - .flatten(); - let chunk_spec = row_group.map(ChunkSpec::Parquet); - Ok(Some(ScanTask::new( - vec![DataSource::File { - path, - chunk_spec, - size_bytes, - iceberg_delete_files: None, - metadata: None, - partition_spec, - statistics: None, - parquet_metadata: None, - }], - file_format_config.clone(), - schema.clone(), - storage_config.clone(), - pushdowns.clone(), - generated_fields, - ))) - })(); - match scan_task_result { - Ok(Some(scan_task)) => Some(Ok(scan_task.into())), - Ok(None) => None, - Err(e) => Some(Err(e)), - } - })); - - if let Some(cfg) = cfg { - scan_tasks = split_by_row_groups( - scan_tasks, - cfg.parquet_split_row_groups_max_files, - cfg.scan_tasks_min_size_bytes, - cfg.scan_tasks_max_size_bytes, - ); - - scan_tasks = merge_by_sizes(scan_tasks, &pushdowns, cfg); - } - - scan_tasks - .map(|st| st.map(|task| task as Arc)) + let row_group = row_groups + .as_ref() + .and_then(|rgs| rgs.get(idx).cloned()) + .flatten(); + let chunk_spec = row_group.map(ChunkSpec::Parquet); + Ok(Some(ScanTask::new( + vec![DataSource::File { + path, + chunk_spec, + size_bytes, + iceberg_delete_files: None, + metadata: None, + partition_spec, + statistics: None, + parquet_metadata: None, + }], + file_format_config.clone(), + schema.clone(), + storage_config.clone(), + pushdowns.clone(), + generated_fields, + ))) + })(); + match scan_task_result { + Ok(Some(scan_task)) => Some(Ok(Arc::new(scan_task) as Arc)), + Ok(None) => None, + Err(e) => Some(Err(e)), + } + }) .collect() } } diff --git a/src/daft-scan/src/lib.rs b/src/daft-scan/src/lib.rs index 788c0f3b60..efae6d250c 100644 --- a/src/daft-scan/src/lib.rs +++ b/src/daft-scan/src/lib.rs @@ -1,6 +1,12 @@ 
#![feature(if_let_guard)] #![feature(let_chains)] -use std::{any::Any, borrow::Cow, fmt::Debug, sync::Arc}; +use std::{ + any::Any, + borrow::Cow, + fmt::Debug, + hash::{Hash, Hasher}, + sync::Arc, +}; use common_display::DisplayAs; use common_error::DaftError; @@ -100,7 +106,7 @@ impl From for pyo3::PyErr { } /// Specification of a subset of a file to be read. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub enum ChunkSpec { /// Selection of Parquet row groups. Parquet(Vec), @@ -149,6 +155,63 @@ pub enum DataSource { }, } +impl Hash for DataSource { + fn hash(&self, state: &mut H) { + // Hash everything except for cached parquet metadata. + match self { + Self::File { + path, + chunk_spec, + size_bytes, + iceberg_delete_files, + metadata, + partition_spec, + statistics, + .. + } => { + path.hash(state); + if let Some(chunk_spec) = chunk_spec { + chunk_spec.hash(state); + } + size_bytes.hash(state); + iceberg_delete_files.hash(state); + metadata.hash(state); + partition_spec.hash(state); + statistics.hash(state); + } + Self::Database { + path, + size_bytes, + metadata, + statistics, + } => { + path.hash(state); + size_bytes.hash(state); + metadata.hash(state); + statistics.hash(state); + } + #[cfg(feature = "python")] + Self::PythonFactoryFunction { + module, + func_name, + func_args, + size_bytes, + metadata, + statistics, + partition_spec, + } => { + module.hash(state); + func_name.hash(state); + func_args.hash(state); + size_bytes.hash(state); + metadata.hash(state); + statistics.hash(state); + partition_spec.hash(state); + } + } + } +} + impl DataSource { #[must_use] pub fn get_path(&self) -> &str { @@ -349,7 +412,7 @@ impl DisplayAs for DataSource { } } -#[derive(Debug, PartialEq, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Serialize, Deserialize, Hash)] pub struct ScanTask { pub sources: Vec, @@ -388,6 +451,10 @@ impl ScanTaskLike for ScanTask { .map_or(false, |a| a == self) } + fn dyn_hash(&self, mut state: &mut dyn Hasher) { + self.hash(&mut state); + } + fn materialized_schema(&self) -> SchemaRef { self.materialized_schema() } diff --git a/src/daft-scan/src/python.rs b/src/daft-scan/src/python.rs index dd30b4541c..4c9da39372 100644 --- a/src/daft-scan/src/python.rs +++ b/src/daft-scan/src/python.rs @@ -1,3 +1,5 @@ +use std::hash::{Hash, Hasher}; + use common_py_serde::{deserialize_py_object, serialize_py_object}; use pyo3::{prelude::*, types::PyTuple}; use serde::{Deserialize, Serialize}; @@ -15,27 +17,48 @@ struct PyObjectSerializableWrapper( /// Python arguments to a Python function that produces Tables #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PythonTablesFactoryArgs(Vec); +pub struct PythonTablesFactoryArgs { + args: Vec, + hash: u64, +} + +impl Hash for PythonTablesFactoryArgs { + fn hash(&self, state: &mut H) { + self.hash.hash(state); + } +} impl PythonTablesFactoryArgs { pub fn new(args: Vec) -> Self { - Self(args.into_iter().map(PyObjectSerializableWrapper).collect()) + let mut hasher = std::collections::hash_map::DefaultHasher::new(); + Python::with_gil(|py| { + for obj in &args { + // Only hash hashable PyObjects. 
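+ // Unhashable arguments are skipped, so the precomputed hash can collide for argument lists that differ only in unhashable objects; `PartialEq` still distinguishes them by object pointer.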
+ if let Ok(hash) = obj.bind(py).hash() { + hash.hash(&mut hasher); + } + } + }); + Self { + args: args.into_iter().map(PyObjectSerializableWrapper).collect(), + hash: hasher.finish(), + } } #[must_use] pub fn to_pytuple<'a>(&self, py: Python<'a>) -> Bound<'a, PyTuple> { - pyo3::types::PyTuple::new_bound(py, self.0.iter().map(|x| x.0.bind(py))) + pyo3::types::PyTuple::new_bound(py, self.args.iter().map(|x| x.0.bind(py))) } } impl PartialEq for PythonTablesFactoryArgs { fn eq(&self, other: &Self) -> bool { - if self.0.len() != other.0.len() { + if self.args.len() != other.args.len() { return false; } - self.0 + self.args .iter() - .zip(other.0.iter()) + .zip(other.args.iter()) .all(|(s, o)| (s.0.as_ptr() as isize) == (o.0.as_ptr() as isize)) } } @@ -43,7 +66,7 @@ impl PartialEq for PythonTablesFactoryArgs { pub mod pylib { use std::sync::Arc; - use common_daft_config::{DaftExecutionConfig, PyDaftExecutionConfig}; + use common_daft_config::PyDaftExecutionConfig; use common_error::DaftResult; use common_file_formats::{python::PyFileFormatConfig, FileFormatConfig}; use common_py_serde::impl_bincode_py_state_serialization; @@ -66,7 +89,6 @@ pub mod pylib { use crate::{ anonymous::AnonymousScanOperator, glob::GlobScanOperator, - scan_task_iters::{merge_by_sizes, split_by_row_groups, BoxScanTaskIter}, storage_config::{PyStorageConfig, PythonStorageConfig}, DataSource, ScanTask, }; @@ -248,11 +270,7 @@ pub mod pylib { lines } - fn to_scan_tasks( - &self, - pushdowns: Pushdowns, - cfg: Option<&DaftExecutionConfig>, - ) -> DaftResult> { + fn to_scan_tasks(&self, pushdowns: Pushdowns) -> DaftResult> { let scan_tasks = Python::with_gil(|py| { let pypd = PyPushdowns(pushdowns.clone().into()).into_py(py); let pyiter = @@ -269,20 +287,8 @@ pub mod pylib { ) })?; - let mut scan_tasks: BoxScanTaskIter = Box::new(scan_tasks.into_iter()); - - if let Some(cfg) = cfg { - scan_tasks = split_by_row_groups( - scan_tasks, - cfg.parquet_split_row_groups_max_files, - cfg.scan_tasks_min_size_bytes, - cfg.scan_tasks_max_size_bytes, - ); - - scan_tasks = merge_by_sizes(scan_tasks, &pushdowns, cfg); - } - scan_tasks + .into_iter() .map(|st| st.map(|task| task as Arc)) .collect() } diff --git a/src/daft-scan/src/scan_task_iters.rs b/src/daft-scan/src/scan_task_iters.rs index c1c92b2483..3ee2a18ccd 100644 --- a/src/daft-scan/src/scan_task_iters.rs +++ b/src/daft-scan/src/scan_task_iters.rs @@ -1,8 +1,9 @@ use std::sync::Arc; use common_daft_config::DaftExecutionConfig; -use common_error::DaftResult; +use common_error::{DaftError, DaftResult}; use common_file_formats::{FileFormatConfig, ParquetSourceConfig}; +use common_scan_info::{ScanTaskLike, ScanTaskLikeRef, SPLIT_AND_MERGE_PASS}; use daft_io::IOStatsContext; use daft_parquet::read::read_parquet_metadata; use parquet2::metadata::RowGroupList; @@ -308,3 +309,43 @@ pub(crate) fn split_by_row_groups( ) } } + +fn split_and_merge_pass( + scan_tasks: Arc>, + pushdowns: &Pushdowns, + cfg: &DaftExecutionConfig, +) -> DaftResult>> { + // Perform scan task splitting and merging if there are only ScanTasks (i.e. no DummyScanTasks). + if scan_tasks + .iter() + .all(|st| st.as_any().downcast_ref::().is_some()) + { + // TODO(desmond): Here we downcast Arc to Arc. ScanTask and DummyScanTask (test only) are + // the only non-test implementer of ScanTaskLike. It might be possible to avoid the downcast by implementing merging + // at the trait level, but today that requires shifting around a non-trivial amount of code to avoid circular dependencies. 
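+ // Downcast each Arc<dyn ScanTaskLike> to a concrete Arc<ScanTask>, split oversized Parquet tasks by row group, merge undersized neighbors by byte size, then re-erase back to trait objects.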
+ let iter: BoxScanTaskIter = Box::new(scan_tasks.as_ref().iter().map(|st| { + st.clone() + .as_any_arc() + .downcast::() + .map_err(|e| DaftError::TypeError(format!("Expected Arc, found {:?}", e))) + })); + let split_tasks = split_by_row_groups( + iter, + cfg.parquet_split_row_groups_max_files, + cfg.scan_tasks_min_size_bytes, + cfg.scan_tasks_max_size_bytes, + ); + let merged_tasks = merge_by_sizes(split_tasks, pushdowns, cfg); + let scan_tasks: Vec> = merged_tasks + .map(|st| st.map(|task| task as Arc)) + .collect::>>()?; + Ok(Arc::new(scan_tasks)) + } else { + Ok(scan_tasks) + } +} + +#[ctor::ctor] +fn set_pass() { + let _ = SPLIT_AND_MERGE_PASS.set(&split_and_merge_pass); +} diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index c304cad020..485dcf0aeb 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -55,14 +55,14 @@ mod tests { ]) .unwrap(), ); - LogicalPlan::Source(Source { - output_schema: schema.clone(), - source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { + LogicalPlan::Source(Source::new( + schema.clone(), + Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { source_schema: schema, clustering_spec: Arc::new(ClusteringSpec::unknown()), source_id: 0, })), - }) + )) .arced() } @@ -76,14 +76,14 @@ mod tests { ]) .unwrap(), ); - LogicalPlan::Source(Source { - output_schema: schema.clone(), - source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { + LogicalPlan::Source(Source::new( + schema.clone(), + Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { source_schema: schema, clustering_spec: Arc::new(ClusteringSpec::unknown()), source_id: 0, })), - }) + )) .arced() } @@ -97,14 +97,14 @@ mod tests { ]) .unwrap(), ); - LogicalPlan::Source(Source { - output_schema: schema.clone(), - source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { + LogicalPlan::Source(Source::new( + schema.clone(), + Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo { source_schema: schema, clustering_spec: Arc::new(ClusteringSpec::unknown()), source_id: 0, })), - }) + )) .arced() } @@ -191,7 +191,7 @@ mod tests { let sql = "select test as a from tbl1"; let plan = planner.plan_sql(sql).unwrap(); - let expected = LogicalPlanBuilder::new(tbl_1, None) + let expected = LogicalPlanBuilder::from(tbl_1) .select(vec![col("test").alias("a")]) .unwrap() .build(); @@ -203,7 +203,7 @@ mod tests { let sql = "select test as a from tbl1 where test = 'a'"; let plan = planner.plan_sql(sql)?; - let expected = LogicalPlanBuilder::new(tbl_1, None) + let expected = LogicalPlanBuilder::from(tbl_1) .filter(col("test").eq(lit("a")))? .select(vec![col("test").alias("a")])? .build(); @@ -216,7 +216,7 @@ mod tests { let sql = "select test as a from tbl1 limit 10"; let plan = planner.plan_sql(sql)?; - let expected = LogicalPlanBuilder::new(tbl_1, None) + let expected = LogicalPlanBuilder::from(tbl_1) .select(vec![col("test").alias("a")])? .limit(10, true)? .build(); @@ -230,7 +230,7 @@ mod tests { let sql = "select utf8 from tbl1 order by utf8 desc"; let plan = planner.plan_sql(sql)?; - let expected = LogicalPlanBuilder::new(tbl_1, None) + let expected = LogicalPlanBuilder::from(tbl_1) .select(vec![col("utf8")])? .sort(vec![col("utf8")], vec![true], vec![true])? 
diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs
index c304cad020..485dcf0aeb 100644
--- a/src/daft-sql/src/lib.rs
+++ b/src/daft-sql/src/lib.rs
@@ -55,14 +55,14 @@ mod tests {
             ])
             .unwrap(),
         );
-        LogicalPlan::Source(Source {
-            output_schema: schema.clone(),
-            source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo {
+        LogicalPlan::Source(Source::new(
+            schema.clone(),
+            Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo {
                 source_schema: schema,
                 clustering_spec: Arc::new(ClusteringSpec::unknown()),
                 source_id: 0,
             })),
-        })
+        ))
         .arced()
     }
 
@@ -76,14 +76,14 @@ mod tests {
             ])
             .unwrap(),
         );
-        LogicalPlan::Source(Source {
-            output_schema: schema.clone(),
-            source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo {
+        LogicalPlan::Source(Source::new(
+            schema.clone(),
+            Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo {
                 source_schema: schema,
                 clustering_spec: Arc::new(ClusteringSpec::unknown()),
                 source_id: 0,
             })),
-        })
+        ))
         .arced()
     }
 
@@ -97,14 +97,14 @@ mod tests {
             ])
             .unwrap(),
         );
-        LogicalPlan::Source(Source {
-            output_schema: schema.clone(),
-            source_info: Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo {
+        LogicalPlan::Source(Source::new(
+            schema.clone(),
+            Arc::new(SourceInfo::PlaceHolder(PlaceHolderInfo {
                 source_schema: schema,
                 clustering_spec: Arc::new(ClusteringSpec::unknown()),
                 source_id: 0,
             })),
-        })
+        ))
         .arced()
     }
 
@@ -191,7 +191,7 @@ mod tests {
         let sql = "select test as a from tbl1";
         let plan = planner.plan_sql(sql).unwrap();
 
-        let expected = LogicalPlanBuilder::new(tbl_1, None)
+        let expected = LogicalPlanBuilder::from(tbl_1)
             .select(vec![col("test").alias("a")])
             .unwrap()
             .build();
@@ -203,7 +203,7 @@ mod tests {
         let sql = "select test as a from tbl1 where test = 'a'";
         let plan = planner.plan_sql(sql)?;
 
-        let expected = LogicalPlanBuilder::new(tbl_1, None)
+        let expected = LogicalPlanBuilder::from(tbl_1)
             .filter(col("test").eq(lit("a")))?
             .select(vec![col("test").alias("a")])?
             .build();
@@ -216,7 +216,7 @@ mod tests {
         let sql = "select test as a from tbl1 limit 10";
         let plan = planner.plan_sql(sql)?;
 
-        let expected = LogicalPlanBuilder::new(tbl_1, None)
+        let expected = LogicalPlanBuilder::from(tbl_1)
             .select(vec![col("test").alias("a")])?
             .limit(10, true)?
             .build();
@@ -230,7 +230,7 @@ mod tests {
         let sql = "select utf8 from tbl1 order by utf8 desc";
         let plan = planner.plan_sql(sql)?;
 
-        let expected = LogicalPlanBuilder::new(tbl_1, None)
+        let expected = LogicalPlanBuilder::from(tbl_1)
             .select(vec![col("utf8")])?
             .sort(vec![col("utf8")], vec![true], vec![true])?
             .build();
@@ -241,7 +241,7 @@ mod tests {
 
     #[rstest]
     fn test_cast(mut planner: SQLPlanner, tbl_1: LogicalPlanRef) -> SQLPlannerResult<()> {
-        let builder = LogicalPlanBuilder::new(tbl_1, None);
+        let builder = LogicalPlanBuilder::from(tbl_1);
         let cases = vec![
             (
                 "select bool::text from tbl1",
@@ -285,7 +285,7 @@ mod tests {
             if null_equals_null { "<=>" } else { "=" }
         );
         let plan = planner.plan_sql(&sql)?;
-        let expected = LogicalPlanBuilder::new(tbl_2, None)
+        let expected = LogicalPlanBuilder::from(tbl_2)
             .join_with_null_safe_equal(
                 tbl_3,
                 vec![col("id")],
@@ -312,7 +312,7 @@ mod tests {
         let sql = "select * from tbl2 join tbl3 on tbl2.id = tbl3.id and tbl2.val > 0";
         let plan = planner.plan_sql(&sql)?;
 
-        let expected = LogicalPlanBuilder::new(tbl_2, None)
+        let expected = LogicalPlanBuilder::from(tbl_2)
             .filter(col("val").gt(lit(0 as i64)))?
             .join_with_null_safe_equal(
                 tbl_3,
@@ -394,7 +394,7 @@ mod tests {
         let sql = "select max(i32) from tbl1";
         let plan = planner.plan_sql(sql)?;
 
-        let expected = LogicalPlanBuilder::new(tbl_1, None)
+        let expected = LogicalPlanBuilder::from(tbl_1)
             .aggregate(vec![col("i32").max()], vec![])?
             .select(vec![col("i32")])?
             .build();
@@ -469,7 +469,7 @@ mod tests {
             field: Field::new("i32", DataType::Int32),
             depth: 1,
         }));
-        let subquery = LogicalPlanBuilder::new(tbl_2, None)
+        let subquery = LogicalPlanBuilder::from(tbl_2)
             .filter(col("id").eq(outer_col))?
             .aggregate(vec![col("id").max()], vec![])?
             .select(vec![col("id")])?
@@ -477,7 +477,7 @@ mod tests {
 
         let subquery = Arc::new(Expr::Subquery(Subquery { plan: subquery }));
 
-        let expected = LogicalPlanBuilder::new(tbl_1, None)
+        let expected = LogicalPlanBuilder::from(tbl_1)
             .filter(col("i64").gt(subquery))?
             .select(vec![col("utf8")])?
             .build();
diff --git a/src/daft-stats/src/column_stats/mod.rs b/src/daft-stats/src/column_stats/mod.rs
index 491c63ec40..f2e733815c 100644
--- a/src/daft-stats/src/column_stats/mod.rs
+++ b/src/daft-stats/src/column_stats/mod.rs
@@ -2,7 +2,10 @@ mod arithmetic;
 mod comparison;
 mod logical;
 
-use std::string::FromUtf8Error;
+use std::{
+    hash::{Hash, Hasher},
+    string::FromUtf8Error,
+};
 
 use daft_core::prelude::*;
 use snafu::{ResultExt, Snafu};
@@ -14,6 +17,24 @@ pub enum ColumnRangeStatistics {
     Loaded(Series, Series),
 }
 
+impl Hash for ColumnRangeStatistics {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        match self {
+            Self::Missing => (),
+            Self::Loaded(l, u) => {
+                let lower_hashes = l
+                    .hash(None)
+                    .expect("Failed to hash lower column range statistics");
+                lower_hashes.into_iter().for_each(|h| h.hash(state));
+                let upper_hashes = u
+                    .hash(None)
+                    .expect("Failed to hash upper column range statistics");
+                upper_hashes.into_iter().for_each(|h| h.hash(state));
+            }
+        }
+    }
+}
+
 #[derive(PartialEq, Eq, Debug)]
 pub enum TruthValue {
     False,
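Note on the daft-stats changes above: the hand-written `Hash` implementations follow the standard library contract that `a == b` must imply `hash(a) == hash(b)`; `ColumnRangeStatistics` feeds the per-row hashes of its lower and upper `Series` into the caller's `Hasher`, mirroring how its equality compares values. A minimal sketch of keeping a manual `Hash` consistent with a manual `PartialEq`, with an illustrative `Range` enum and plain `i64` bounds standing in for `Series`:

use std::{
    collections::HashSet,
    hash::{Hash, Hasher},
};

// Illustrative stand-in for a statistics range: either absent or a loaded [lower, upper] pair.
#[derive(Debug, Eq)]
enum Range {
    Missing,
    Loaded(i64, i64),
}

// Equality compares the bound values...
impl PartialEq for Range {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Self::Missing, Self::Missing) => true,
            (Self::Loaded(l1, u1), Self::Loaded(l2, u2)) => l1 == l2 && u1 == u2,
            _ => false,
        }
    }
}

// ...so Hash must be computed from exactly the same values: equal ranges hash identically.
impl Hash for Range {
    fn hash<H: Hasher>(&self, state: &mut H) {
        match self {
            Self::Missing => (),
            Self::Loaded(l, u) => {
                l.hash(state);
                u.hash(state);
            }
        }
    }
}

fn main() {
    let mut set = HashSet::new();
    set.insert(Range::Loaded(0, 10));
    // A value equal to an existing entry is deduplicated, which relies on the Hash/Eq contract.
    assert!(!set.insert(Range::Loaded(0, 10)));
    assert!(set.insert(Range::Missing));
    println!("set size: {}", set.len());
}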
diff --git a/src/daft-stats/src/partition_spec.rs b/src/daft-stats/src/partition_spec.rs
index 24834bf116..79fe4c21b4 100644
--- a/src/daft-stats/src/partition_spec.rs
+++ b/src/daft-stats/src/partition_spec.rs
@@ -1,4 +1,7 @@
-use std::collections::HashMap;
+use std::{
+    collections::HashMap,
+    hash::{Hash, Hasher},
+};
 
 use daft_core::array::ops::{DaftCompare, DaftLogical};
 use daft_dsl::{ExprRef, Literal};
@@ -71,3 +74,15 @@ impl PartialEq for PartitionSpec {
 }
 
 impl Eq for PartitionSpec {}
+
+// Manually implement Hash to ensure consistency with `PartialEq`.
+impl Hash for PartitionSpec {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.keys.schema.hash(state);
+
+        for column in &self.keys {
+            let column_hashes = column.hash(None).expect("Failed to hash column");
+            column_hashes.into_iter().for_each(|h| h.hash(state));
+        }
+    }
+}
diff --git a/src/daft-stats/src/table_metadata.rs b/src/daft-stats/src/table_metadata.rs
index bcd76e96c4..4db512bf57 100644
--- a/src/daft-stats/src/table_metadata.rs
+++ b/src/daft-stats/src/table_metadata.rs
@@ -1,6 +1,6 @@
 use serde::{Deserialize, Serialize};
 
-#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
 pub struct TableMetadata {
     pub length: usize,
 }
diff --git a/src/daft-stats/src/table_stats.rs b/src/daft-stats/src/table_stats.rs
index e0d91d24c6..5f6f32a5a8 100644
--- a/src/daft-stats/src/table_stats.rs
+++ b/src/daft-stats/src/table_stats.rs
@@ -1,6 +1,7 @@
 use std::{
     collections::HashMap,
     fmt::Display,
+    hash::{Hash, Hasher},
     ops::{BitAnd, BitOr, Not},
 };
 
@@ -17,6 +18,15 @@ pub struct TableStatistics {
     pub columns: IndexMap<String, ColumnRangeStatistics>,
 }
 
+impl Hash for TableStatistics {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        for (key, value) in &self.columns {
+            key.hash(state);
+            value.hash(state);
+        }
+    }
+}
+
 impl TableStatistics {
     pub fn from_stats_table(table: &Table) -> DaftResult<Self> {
         // Assumed format is each column having 2 rows:
diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs
index d48cb36d33..359758e802 100644
--- a/src/daft-table/src/lib.rs
+++ b/src/daft-table/src/lib.rs
@@ -6,6 +6,7 @@ use core::slice;
 use std::{
     collections::{HashMap, HashSet},
     fmt::{Display, Formatter, Result},
+    hash::{Hash, Hasher},
 };
 
 use arrow2::array::Array;
@@ -47,6 +48,17 @@ pub struct Table {
     num_rows: usize,
 }
 
+impl Hash for Table {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.schema.hash(state);
+        for col in &self.columns {
+            let hashes = col.hash(None).expect("Failed to hash column");
+            hashes.into_iter().for_each(|h| h.hash(state));
+        }
+        self.num_rows.hash(state);
+    }
+}
+
 #[inline]
 fn _validate_schema(schema: &Schema, columns: &[Series]) -> DaftResult<()> {
     if schema.fields.len() != columns.len() {
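Note on the `Hash` implementations for `PartitionSpec`, `TableStatistics`, and `Table` above: each one reduces columnar data to a stream of per-row hash values and feeds that stream into the caller's `Hasher`, alongside any scalar fields such as the schema or row count. The sketch below mirrors that structure with plain standard-library types; `MiniTable` and the `hash_rows` helper are illustrative inventions, not Daft APIs.

use std::{
    collections::hash_map::DefaultHasher,
    hash::{Hash, Hasher},
};

// Illustrative stand-in for a hashed column: produce one u64 hash per row,
// roughly the shape of the per-row hashes the Series-based impls consume.
fn hash_rows(rows: &[i64]) -> Vec<u64> {
    rows.iter()
        .map(|row| {
            let mut h = DefaultHasher::new();
            row.hash(&mut h);
            h.finish()
        })
        .collect()
}

struct MiniTable {
    column_names: Vec<String>,
    columns: Vec<Vec<i64>>,
    num_rows: usize,
}

// Fold the schema-like names, the per-row column hashes, and the row count
// into the caller's Hasher, mirroring the structure of the impls above.
impl Hash for MiniTable {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.column_names.hash(state);
        for col in &self.columns {
            for row_hash in hash_rows(col) {
                row_hash.hash(state);
            }
        }
        self.num_rows.hash(state);
    }
}

fn main() {
    let t = MiniTable {
        column_names: vec!["a".to_string()],
        columns: vec![vec![1, 2, 3]],
        num_rows: 3,
    };
    let mut hasher = DefaultHasher::new();
    t.hash(&mut hasher);
    println!("table hash: {}", hasher.finish());
}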