diff --git a/src/daft-logical-plan/src/logical_plan.rs b/src/daft-logical-plan/src/logical_plan.rs index cf013d8295..fc2f065038 100644 --- a/src/daft-logical-plan/src/logical_plan.rs +++ b/src/daft-logical-plan/src/logical_plan.rs @@ -8,7 +8,7 @@ use indexmap::IndexSet; use snafu::Snafu; pub use crate::ops::*; -use crate::stats::StatsState; +use crate::stats::PlanStats; /// Logical plan for a Daft query. #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -199,7 +199,7 @@ impl LogicalPlan { } } - pub fn get_stats(&self) -> &StatsState { + pub fn materialized_stats(&self) -> &PlanStats { match self { Self::Source(Source { stats_state, .. }) | Self::Project(Project { stats_state, .. }) @@ -218,7 +218,7 @@ impl LogicalPlan { | Self::Sink(Sink { stats_state, .. }) | Self::Sample(Sample { stats_state, .. }) | Self::MonotonicallyIncreasingId(MonotonicallyIncreasingId { stats_state, .. }) => { - stats_state + stats_state.materialized_stats() } Self::Intersect(_) => { panic!("Intersect nodes should be optimized away before stats are materialized") diff --git a/src/daft-logical-plan/src/ops/actor_pool_project.rs b/src/daft-logical-plan/src/ops/actor_pool_project.rs index 36f1388b88..63e28cd6c8 100644 --- a/src/daft-logical-plan/src/ops/actor_pool_project.rs +++ b/src/daft-logical-plan/src/ops/actor_pool_project.rs @@ -72,9 +72,8 @@ impl ActorPoolProject { pub(crate) fn with_materialized_stats(mut self) -> Self { // TODO(desmond): We can do better estimations with the projection schema. For now, reuse the old logic. - let input_stats = self.input.get_stats(); - assert!(matches!(input_stats, StatsState::Materialized(..))); - self.stats_state = input_stats.clone(); + let input_stats = self.input.materialized_stats(); + self.stats_state = StatsState::Materialized(input_stats.clone()); self } diff --git a/src/daft-logical-plan/src/ops/agg.rs b/src/daft-logical-plan/src/ops/agg.rs index 2513f33723..243b200543 100644 --- a/src/daft-logical-plan/src/ops/agg.rs +++ b/src/daft-logical-plan/src/ops/agg.rs @@ -63,9 +63,7 @@ impl Aggregate { pub(crate) fn with_materialized_stats(mut self) -> Self { // TODO(desmond): We can use the schema here for better estimations. For now, use the old logic. - let input_stats = self.input.get_stats(); - assert!(matches!(input_stats, StatsState::Materialized(..))); - let input_stats = input_stats.clone().unwrap_or_default(); + let input_stats = self.input.materialized_stats(); let est_bytes_per_row_lower = input_stats.approx_stats.lower_bound_bytes / (input_stats.approx_stats.lower_bound_rows.max(1)); let est_bytes_per_row_upper = diff --git a/src/daft-logical-plan/src/ops/concat.rs b/src/daft-logical-plan/src/ops/concat.rs index 724d41a682..601eaa135b 100644 --- a/src/daft-logical-plan/src/ops/concat.rs +++ b/src/daft-logical-plan/src/ops/concat.rs @@ -48,12 +48,8 @@ impl Concat { pub(crate) fn with_materialized_stats(mut self) -> Self { // TODO(desmond): We can do better estimations with the projection schema. For now, reuse the old logic. - let input_stats = self.input.get_stats(); - assert!(matches!(input_stats, StatsState::Materialized(..))); - let other_stats = self.other.get_stats(); - assert!(matches!(other_stats, StatsState::Materialized(..))); - let input_stats = input_stats.clone().unwrap_or_default(); - let other_stats = other_stats.clone().unwrap_or_default(); + let input_stats = self.input.materialized_stats(); + let other_stats = self.other.materialized_stats(); let approx_stats = &input_stats.approx_stats + &other_stats.approx_stats; self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats)); self diff --git a/src/daft-logical-plan/src/ops/distinct.rs b/src/daft-logical-plan/src/ops/distinct.rs index b9e9fd30ea..0ea1f8b480 100644 --- a/src/daft-logical-plan/src/ops/distinct.rs +++ b/src/daft-logical-plan/src/ops/distinct.rs @@ -22,9 +22,7 @@ impl Distinct { pub(crate) fn with_materialized_stats(mut self) -> Self { // TODO(desmond): We can simply use NDVs here. For now, do a naive estimation. - let input_stats = self.input.get_stats(); - assert!(matches!(input_stats, StatsState::Materialized(..))); - let input_stats = input_stats.clone().unwrap_or_default(); + let input_stats = self.input.materialized_stats(); let est_bytes_per_row_lower = input_stats.approx_stats.lower_bound_bytes / (input_stats.approx_stats.lower_bound_rows.max(1)); let approx_stats = ApproxStats { diff --git a/src/daft-logical-plan/src/ops/explode.rs b/src/daft-logical-plan/src/ops/explode.rs index 1de48c8588..e75c8bae84 100644 --- a/src/daft-logical-plan/src/ops/explode.rs +++ b/src/daft-logical-plan/src/ops/explode.rs @@ -66,9 +66,7 @@ impl Explode { } pub(crate) fn with_materialized_stats(mut self) -> Self { - let input_stats = self.input.get_stats(); - assert!(matches!(input_stats, StatsState::Materialized(..))); - let input_stats = input_stats.clone().unwrap_or_default(); + let input_stats = self.input.materialized_stats(); let approx_stats = ApproxStats { lower_bound_rows: input_stats.approx_stats.lower_bound_rows, upper_bound_rows: None, diff --git a/src/daft-logical-plan/src/ops/filter.rs b/src/daft-logical-plan/src/ops/filter.rs index 6b29a0bca4..1bbc9a5d8a 100644 --- a/src/daft-logical-plan/src/ops/filter.rs +++ b/src/daft-logical-plan/src/ops/filter.rs @@ -45,9 +45,7 @@ impl Filter { pub(crate) fn with_materialized_stats(mut self) -> Self { // Assume no row/column pruning in cardinality-affecting operations. // TODO(desmond): We can do better estimations here. For now, reuse the old logic. - let input_stats = self.input.get_stats(); - assert!(matches!(input_stats, StatsState::Materialized(..))); - let input_stats = input_stats.clone().unwrap_or_default(); + let input_stats = self.input.materialized_stats(); let upper_bound_rows = input_stats.approx_stats.upper_bound_rows; let upper_bound_bytes = input_stats.approx_stats.upper_bound_bytes; let approx_stats = ApproxStats { diff --git a/src/daft-logical-plan/src/ops/join.rs b/src/daft-logical-plan/src/ops/join.rs index 04b51d9754..0a547cdef4 100644 --- a/src/daft-logical-plan/src/ops/join.rs +++ b/src/daft-logical-plan/src/ops/join.rs @@ -314,12 +314,8 @@ impl Join { pub(crate) fn with_materialized_stats(mut self) -> Self { // Assume a Primary-key + Foreign-Key join which would yield the max of the two tables. // TODO(desmond): We can do better estimations here. For now, use the old logic. - let left_stats = self.left.get_stats(); - let right_stats = self.right.get_stats(); - assert!(matches!(left_stats, StatsState::Materialized(..))); - assert!(matches!(right_stats, StatsState::Materialized(..))); - let left_stats = left_stats.clone().unwrap_or_default(); - let right_stats = right_stats.clone().unwrap_or_default(); + let left_stats = self.left.materialized_stats(); + let right_stats = self.right.materialized_stats(); let approx_stats = ApproxStats { lower_bound_rows: 0, upper_bound_rows: left_stats diff --git a/src/daft-logical-plan/src/ops/limit.rs b/src/daft-logical-plan/src/ops/limit.rs index 46f718a2c1..784a86e123 100644 --- a/src/daft-logical-plan/src/ops/limit.rs +++ b/src/daft-logical-plan/src/ops/limit.rs @@ -28,9 +28,7 @@ impl Limit { } pub(crate) fn with_materialized_stats(mut self) -> Self { - let input_stats = self.input.get_stats(); - assert!(matches!(input_stats, StatsState::Materialized(..))); - let input_stats = input_stats.clone().unwrap_or_default(); + let input_stats = self.input.materialized_stats(); let limit = self.limit as usize; let est_bytes_per_row_lower = input_stats.approx_stats.lower_bound_bytes / input_stats.approx_stats.lower_bound_rows.max(1); diff --git a/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs b/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs index bc83c218de..5c89038d85 100644 --- a/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs +++ b/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs @@ -35,9 +35,8 @@ impl MonotonicallyIncreasingId { pub(crate) fn with_materialized_stats(mut self) -> Self { // TODO(desmond): We can do better estimations with the projection schema. For now, reuse the old logic. - let input_stats = self.input.get_stats(); - assert!(matches!(input_stats, StatsState::Materialized(..))); - self.stats_state = input_stats.clone(); + let input_stats = self.input.materialized_stats(); + self.stats_state = StatsState::Materialized(input_stats.clone()); self } diff --git a/src/daft-logical-plan/src/ops/pivot.rs b/src/daft-logical-plan/src/ops/pivot.rs index cbeb9cca7a..0cd904edb2 100644 --- a/src/daft-logical-plan/src/ops/pivot.rs +++ b/src/daft-logical-plan/src/ops/pivot.rs @@ -87,9 +87,8 @@ impl Pivot { pub(crate) fn with_materialized_stats(mut self) -> Self { // TODO(desmond): Pivoting does affect cardinality, but for now we keep the old logic. - let input_stats = self.input.get_stats(); - assert!(matches!(input_stats, StatsState::Materialized(..))); - self.stats_state = input_stats.clone(); + let input_stats = self.input.materialized_stats(); + self.stats_state = StatsState::Materialized(input_stats.clone()); self } diff --git a/src/daft-logical-plan/src/ops/project.rs b/src/daft-logical-plan/src/ops/project.rs index de760b3135..cffce4cfba 100644 --- a/src/daft-logical-plan/src/ops/project.rs +++ b/src/daft-logical-plan/src/ops/project.rs @@ -56,9 +56,8 @@ impl Project { pub(crate) fn with_materialized_stats(mut self) -> Self { // TODO(desmond): We can do better estimations with the projection schema. For now, reuse the old logic. - let input_stats = self.input.get_stats(); - assert!(matches!(input_stats, StatsState::Materialized(..))); - self.stats_state = input_stats.clone(); + let input_stats = self.input.materialized_stats(); + self.stats_state = StatsState::Materialized(input_stats.clone()); self } diff --git a/src/daft-logical-plan/src/ops/repartition.rs b/src/daft-logical-plan/src/ops/repartition.rs index 9deb07a0c5..67b0f4d71f 100644 --- a/src/daft-logical-plan/src/ops/repartition.rs +++ b/src/daft-logical-plan/src/ops/repartition.rs @@ -44,9 +44,8 @@ impl Repartition { pub(crate) fn with_materialized_stats(mut self) -> Self { // Repartitioning does not affect cardinality. - let input_stats = self.input.get_stats(); - assert!(matches!(input_stats, StatsState::Materialized(..))); - self.stats_state = input_stats.clone(); + let input_stats = self.input.materialized_stats(); + self.stats_state = StatsState::Materialized(input_stats.clone()); self } diff --git a/src/daft-logical-plan/src/ops/sample.rs b/src/daft-logical-plan/src/ops/sample.rs index 410d019dcf..0d7a1d2c44 100644 --- a/src/daft-logical-plan/src/ops/sample.rs +++ b/src/daft-logical-plan/src/ops/sample.rs @@ -54,9 +54,7 @@ impl Sample { pub(crate) fn with_materialized_stats(mut self) -> Self { // TODO(desmond): We can do better estimations with the projection schema. For now, reuse the old logic. - let input_stats = self.input.get_stats(); - assert!(matches!(input_stats, StatsState::Materialized(..))); - let input_stats = input_stats.clone().unwrap_or_default(); + let input_stats = self.input.materialized_stats(); let approx_stats = input_stats .approx_stats .apply(|v| ((v as f64) * self.fraction) as usize); diff --git a/src/daft-logical-plan/src/ops/sort.rs b/src/daft-logical-plan/src/ops/sort.rs index 4d7a80a58e..d8cebb74af 100644 --- a/src/daft-logical-plan/src/ops/sort.rs +++ b/src/daft-logical-plan/src/ops/sort.rs @@ -61,9 +61,8 @@ impl Sort { pub(crate) fn with_materialized_stats(mut self) -> Self { // Sorting does not affect cardinality. - let input_stats = self.input.get_stats(); - assert!(matches!(input_stats, StatsState::Materialized(..))); - self.stats_state = input_stats.clone(); + let input_stats = self.input.materialized_stats(); + self.stats_state = StatsState::Materialized(input_stats.clone()); self } diff --git a/src/daft-logical-plan/src/ops/unpivot.rs b/src/daft-logical-plan/src/ops/unpivot.rs index 309be10ec8..7bf0c25f41 100644 --- a/src/daft-logical-plan/src/ops/unpivot.rs +++ b/src/daft-logical-plan/src/ops/unpivot.rs @@ -98,9 +98,7 @@ impl Unpivot { } pub(crate) fn with_materialized_stats(mut self) -> Self { - let input_stats = self.input.get_stats(); - assert!(matches!(input_stats, StatsState::Materialized(..))); - let input_stats = input_stats.clone().unwrap_or_default(); + let input_stats = self.input.materialized_stats(); let num_values = self.values.len(); let approx_stats = ApproxStats { lower_bound_rows: input_stats.approx_stats.lower_bound_rows * num_values, diff --git a/src/daft-logical-plan/src/stats.rs b/src/daft-logical-plan/src/stats.rs index 7e898eebfa..0e8cec64d3 100644 --- a/src/daft-logical-plan/src/stats.rs +++ b/src/daft-logical-plan/src/stats.rs @@ -9,10 +9,10 @@ pub enum StatsState { } impl StatsState { - pub fn unwrap_or_default(self) -> PlanStats { + pub fn materialized_stats(&self) -> &PlanStats { match self { Self::Materialized(stats) => stats, - Self::NotMaterialized => PlanStats::default(), + Self::NotMaterialized => panic!("Tried to get unmaterialized stats"), } } }