Skip to content

Commit

Permalink
Hash scan tasks properly; remove is_python_scan; fix unwrap_or_default for StatsState
Browse files Browse the repository at this point in the history
  • Loading branch information
desmondcheongzx committed Nov 23, 2024
1 parent 50e9f74 commit d5c4167
Show file tree
Hide file tree
Showing 30 changed files with 193 additions and 68 deletions.
4 changes: 0 additions & 4 deletions src/common/scan-info/src/scan_operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@ pub trait ScanOperator: Send + Sync + Debug {
pushdowns: Pushdowns,
config: Option<&DaftExecutionConfig>,
) -> DaftResult<Vec<ScanTaskLikeRef>>;

fn is_python_scan(&self) -> bool {
false
}
}

impl Display for dyn ScanOperator {
Expand Down
14 changes: 7 additions & 7 deletions src/common/scan-info/src/scan_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ use crate::Pushdowns;
pub trait ScanTaskLike: Debug + DisplayAs + Send + Sync {
fn as_any(&self) -> &dyn Any;
fn as_any_arc(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
fn as_ptr(&self) -> *const ();
fn dyn_eq(&self, other: &dyn ScanTaskLike) -> bool;
fn dyn_hash(&self, state: &mut dyn Hasher);
#[must_use]
fn materialized_schema(&self) -> SchemaRef;
#[must_use]
Expand All @@ -41,12 +41,6 @@ pub trait ScanTaskLike: Debug + DisplayAs + Send + Sync {

pub type ScanTaskLikeRef = Arc<dyn ScanTaskLike>;

// Identity-based hashing for scan-task trait objects: hashes the object's
// data pointer (via `as_ptr`), so two logically equal tasks stored at
// different addresses hash differently.
// NOTE(review): this looks inconsistent with the `dyn_eq`-based `PartialEq`
// impl below (equal-but-distinct tasks would violate the Hash/Eq contract) —
// presumably why this impl was removed in this commit.
impl Hash for dyn ScanTaskLike {
fn hash<H: Hasher>(&self, state: &mut H) {
self.as_ptr().hash(state);
}
}

impl Eq for dyn ScanTaskLike + '_ {}

impl PartialEq for dyn ScanTaskLike + '_ {
Expand All @@ -55,4 +49,10 @@ impl PartialEq for dyn ScanTaskLike + '_ {
}
}

// Value-based hashing for scan-task trait objects: delegates to the
// object-safe `dyn_hash` method on the trait, so each concrete
// `ScanTaskLike` implementation hashes by its own contents. This keeps
// `Hash` consistent with the `dyn_eq`-based `PartialEq` impl above.
impl Hash for dyn ScanTaskLike + '_ {
fn hash<H: Hasher>(&self, state: &mut H) {
// `dyn Hasher` is accepted here because `&mut H` coerces to
// `&mut dyn Hasher` at the call site.
self.dyn_hash(state);
}
}

pub type BoxScanTaskLikeIter = Box<dyn Iterator<Item = DaftResult<Arc<dyn ScanTaskLike>>>>;
16 changes: 10 additions & 6 deletions src/common/scan-info/src/test/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use std::{any::Any, sync::Arc};
use std::{
any::Any,
hash::{Hash, Hasher},
sync::Arc,
};

use common_daft_config::DaftExecutionConfig;
use common_display::DisplayAs;
Expand All @@ -9,7 +13,7 @@ use serde::{Deserialize, Serialize};

use crate::{PartitionField, Pushdowns, ScanOperator, ScanTaskLike, ScanTaskLikeRef};

#[derive(Debug, Serialize, Deserialize, PartialEq)]
#[derive(Debug, Serialize, Deserialize, PartialEq, Hash)]
struct DummyScanTask {
pub schema: SchemaRef,
pub pushdowns: Pushdowns,
Expand All @@ -31,17 +35,17 @@ impl ScanTaskLike for DummyScanTask {
self
}

fn as_ptr(&self) -> *const () {
std::ptr::from_ref(self).cast::<()>()
}

fn dyn_eq(&self, other: &dyn ScanTaskLike) -> bool {
other
.as_any()
.downcast_ref::<Self>()
.map_or(false, |a| a == self)
}

fn dyn_hash(&self, mut state: &mut dyn Hasher) {
self.hash(&mut state);
}

fn materialized_schema(&self) -> SchemaRef {
self.schema.clone()
}
Expand Down
2 changes: 1 addition & 1 deletion src/daft-logical-plan/src/ops/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ impl Aggregate {
// TODO(desmond): We can use the schema here for better estimations. For now, use the old logic.
let input_stats = self.input.get_stats();
assert!(matches!(input_stats, StatsState::Materialized(..)));
let input_stats = input_stats.unwrap_or_default();
let input_stats = input_stats.clone().unwrap_or_default();
let est_bytes_per_row_lower = input_stats.approx_stats.lower_bound_bytes
/ (input_stats.approx_stats.lower_bound_rows.max(1));
let est_bytes_per_row_upper =
Expand Down
4 changes: 2 additions & 2 deletions src/daft-logical-plan/src/ops/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ impl Concat {
assert!(matches!(input_stats, StatsState::Materialized(..)));
let other_stats = self.other.get_stats();
assert!(matches!(other_stats, StatsState::Materialized(..)));
let input_stats = input_stats.unwrap_or_default();
let other_stats = other_stats.unwrap_or_default();
let input_stats = input_stats.clone().unwrap_or_default();
let other_stats = other_stats.clone().unwrap_or_default();
let approx_stats = &input_stats.approx_stats + &other_stats.approx_stats;
self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats));
self
Expand Down
2 changes: 1 addition & 1 deletion src/daft-logical-plan/src/ops/distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ impl Distinct {
// TODO(desmond): We can simply use NDVs here. For now, do a naive estimation.
let input_stats = self.input.get_stats();
assert!(matches!(input_stats, StatsState::Materialized(..)));
let input_stats = input_stats.unwrap_or_default();
let input_stats = input_stats.clone().unwrap_or_default();
let est_bytes_per_row_lower = input_stats.approx_stats.lower_bound_bytes
/ (input_stats.approx_stats.lower_bound_rows.max(1));
let approx_stats = ApproxStats {
Expand Down
2 changes: 1 addition & 1 deletion src/daft-logical-plan/src/ops/explode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl Explode {
pub(crate) fn with_materialized_stats(mut self) -> Self {
let input_stats = self.input.get_stats();
assert!(matches!(input_stats, StatsState::Materialized(..)));
let input_stats = input_stats.unwrap_or_default();
let input_stats = input_stats.clone().unwrap_or_default();
let approx_stats = ApproxStats {
lower_bound_rows: input_stats.approx_stats.lower_bound_rows,
upper_bound_rows: None,
Expand Down
2 changes: 1 addition & 1 deletion src/daft-logical-plan/src/ops/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl Filter {
// TODO(desmond): We can do better estimations here. For now, reuse the old logic.
let input_stats = self.input.get_stats();
assert!(matches!(input_stats, StatsState::Materialized(..)));
let input_stats = input_stats.unwrap_or_default();
let input_stats = input_stats.clone().unwrap_or_default();
let upper_bound_rows = input_stats.approx_stats.upper_bound_rows;
let upper_bound_bytes = input_stats.approx_stats.upper_bound_bytes;
let approx_stats = ApproxStats {
Expand Down
4 changes: 2 additions & 2 deletions src/daft-logical-plan/src/ops/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,8 @@ impl Join {
let right_stats = self.right.get_stats();
assert!(matches!(left_stats, StatsState::Materialized(..)));
assert!(matches!(right_stats, StatsState::Materialized(..)));
let left_stats = left_stats.unwrap_or_default();
let right_stats = right_stats.unwrap_or_default();
let left_stats = left_stats.clone().unwrap_or_default();
let right_stats = right_stats.clone().unwrap_or_default();
let approx_stats = ApproxStats {
lower_bound_rows: 0,
upper_bound_rows: left_stats
Expand Down
2 changes: 1 addition & 1 deletion src/daft-logical-plan/src/ops/limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl Limit {
pub(crate) fn with_materialized_stats(mut self) -> Self {
let input_stats = self.input.get_stats();
assert!(matches!(input_stats, StatsState::Materialized(..)));
let input_stats = input_stats.unwrap_or_default();
let input_stats = input_stats.clone().unwrap_or_default();
let limit = self.limit as usize;
let est_bytes_per_row_lower = input_stats.approx_stats.lower_bound_bytes
/ input_stats.approx_stats.lower_bound_rows.max(1);
Expand Down
2 changes: 1 addition & 1 deletion src/daft-logical-plan/src/ops/sample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl Sample {
// TODO(desmond): We can do better estimations with the projection schema. For now, reuse the old logic.
let input_stats = self.input.get_stats();
assert!(matches!(input_stats, StatsState::Materialized(..)));
let input_stats = input_stats.unwrap_or_default();
let input_stats = input_stats.clone().unwrap_or_default();
let approx_stats = input_stats
.approx_stats
.apply(|v| ((v as f64) * self.fraction) as usize);
Expand Down
2 changes: 1 addition & 1 deletion src/daft-logical-plan/src/ops/unpivot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl Unpivot {
pub(crate) fn with_materialized_stats(mut self) -> Self {
let input_stats = self.input.get_stats();
assert!(matches!(input_stats, StatsState::Materialized(..)));
let input_stats = input_stats.unwrap_or_default();
let input_stats = input_stats.clone().unwrap_or_default();
let num_values = self.values.len();
let approx_stats = ApproxStats {
lower_bound_rows: input_stats.approx_stats.lower_bound_rows * num_values,
Expand Down
4 changes: 2 additions & 2 deletions src/daft-logical-plan/src/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ pub enum StatsState {
}

impl StatsState {
pub fn unwrap_or_default(&self) -> PlanStats {
pub fn unwrap_or_default(self) -> PlanStats {
match self {
Self::Materialized(stats) => stats.clone(),
Self::Materialized(stats) => stats,
Self::NotMaterialized => PlanStats::default(),
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/daft-parquet/src/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ pub fn build_row_ranges(
} else {
let mut rows_to_add = limit.unwrap_or(metadata.num_rows as i64);

for (i, rg) in &metadata.row_groups {
for (i, rg) in metadata.row_groups.iter() {
if (curr_row_index + rg.num_rows()) < row_start_offset {
curr_row_index += rg.num_rows();
continue;
Expand Down
4 changes: 2 additions & 2 deletions src/daft-parquet/src/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,8 @@ fn apply_field_ids_to_parquet_file_metadata(

let new_row_groups = file_metadata
.row_groups
.into_values()
.map(|rg| {
.iter()
.map(|(_, rg)| {
let new_columns = rg
.columns()
.iter()
Expand Down
22 changes: 14 additions & 8 deletions src/daft-scan/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
#![feature(if_let_guard)]
#![feature(let_chains)]
use std::{any::Any, borrow::Cow, fmt::Debug, sync::Arc};
use std::{
any::Any,
borrow::Cow,
fmt::Debug,
hash::{Hash, Hasher},
sync::Arc,
};

use common_display::DisplayAs;
use common_error::DaftError;
Expand Down Expand Up @@ -100,7 +106,7 @@ impl From<Error> for pyo3::PyErr {
}

/// Specification of a subset of a file to be read.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub enum ChunkSpec {
/// Selection of Parquet row groups.
Parquet(Vec<i64>),
Expand All @@ -119,7 +125,7 @@ impl ChunkSpec {
}
}

#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Hash)]
pub enum DataSource {
File {
path: String,
Expand Down Expand Up @@ -349,7 +355,7 @@ impl DisplayAs for DataSource {
}
}

#[derive(Debug, PartialEq, Serialize, Deserialize)]
#[derive(Debug, PartialEq, Serialize, Deserialize, Hash)]
pub struct ScanTask {
pub sources: Vec<DataSource>,

Expand Down Expand Up @@ -381,17 +387,17 @@ impl ScanTaskLike for ScanTask {
self
}

fn as_ptr(&self) -> *const () {
std::ptr::from_ref::<Self>(self).cast()
}

fn dyn_eq(&self, other: &dyn ScanTaskLike) -> bool {
other
.as_any()
.downcast_ref::<Self>()
.map_or(false, |a| a == self)
}

fn dyn_hash(&self, mut state: &mut dyn Hasher) {
self.hash(&mut state);
}

fn materialized_schema(&self) -> SchemaRef {
self.materialized_schema()
}
Expand Down
39 changes: 29 additions & 10 deletions src/daft-scan/src/python.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::hash::{Hash, Hasher};

use common_py_serde::{deserialize_py_object, serialize_py_object};
use pyo3::{prelude::*, types::PyTuple};
use serde::{Deserialize, Serialize};
Expand All @@ -15,27 +17,48 @@ struct PyObjectSerializableWrapper(

/// Python arguments to a Python function that produces Tables
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PythonTablesFactoryArgs(Vec<PyObjectSerializableWrapper>);
pub struct PythonTablesFactoryArgs {
args: Vec<PyObjectSerializableWrapper>,
hash: u64,
}

// Hashes the precomputed `hash` field (a u64 computed once in `new` from the
// Python objects' `__hash__` values) rather than re-hashing the PyObjects.
// This avoids needing the GIL at hash time and keeps hashing cheap.
// NOTE(review): `PartialEq` for this type compares by object pointer, while
// this hash is by Python value hash — confirm the Hash/Eq contract holds for
// the intended usage.
impl Hash for PythonTablesFactoryArgs {
fn hash<H: Hasher>(&self, state: &mut H) {
self.hash.hash(state);
}
}

impl PythonTablesFactoryArgs {
pub fn new(args: Vec<PyObject>) -> Self {
Self(args.into_iter().map(PyObjectSerializableWrapper).collect())
let mut hasher = std::collections::hash_map::DefaultHasher::new();
Python::with_gil(|py| {
for obj in &args {
obj.bind(py)
.hash()
.expect("Failed to hash PyObject")
.hash(&mut hasher);
}
});
Self {
args: args.into_iter().map(PyObjectSerializableWrapper).collect(),
hash: hasher.finish(),
}
}

#[must_use]
pub fn to_pytuple<'a>(&self, py: Python<'a>) -> Bound<'a, PyTuple> {
pyo3::types::PyTuple::new_bound(py, self.0.iter().map(|x| x.0.bind(py)))
pyo3::types::PyTuple::new_bound(py, self.args.iter().map(|x| x.0.bind(py)))
}
}

impl PartialEq for PythonTablesFactoryArgs {
fn eq(&self, other: &Self) -> bool {
if self.0.len() != other.0.len() {
if self.args.len() != other.args.len() {
return false;
}
self.0
self.args
.iter()
.zip(other.0.iter())
.zip(other.args.iter())
.all(|(s, o)| (s.0.as_ptr() as isize) == (o.0.as_ptr() as isize))
}
}
Expand Down Expand Up @@ -215,10 +238,6 @@ pub mod pylib {
}

impl ScanOperator for PythonScanOperatorBridge {
fn is_python_scan(&self) -> bool {
true
}

fn partitioning_keys(&self) -> &[PartitionField] {
&self.partitioning_keys
}
Expand Down
23 changes: 22 additions & 1 deletion src/daft-stats/src/column_stats/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ mod arithmetic;
mod comparison;
mod logical;

use std::string::FromUtf8Error;
use std::{
hash::{Hash, Hasher},
string::FromUtf8Error,
};

use daft_core::prelude::*;
use snafu::{ResultExt, Snafu};
Expand All @@ -14,6 +17,24 @@ pub enum ColumnRangeStatistics {
Loaded(Series, Series),
}

// Content-based hashing for column range statistics.
//
// `Missing` contributes nothing to the hasher; `Loaded(lower, upper)` folds in
// the per-element hashes of both bound Series (via `Series::hash(None)`).
// NOTE(review): because `Missing` writes no bytes, `Missing` and a
// `Loaded` pair of empty Series would hash identically — confirm that is
// acceptable (equality would still distinguish them).
impl Hash for ColumnRangeStatistics {
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
Self::Missing => (),
Self::Loaded(l, u) => {
// Series::hash returns a per-row hash column; panics (expect)
// if the underlying dtype cannot be hashed.
let lower_hashes = l
.hash(None)
.expect("Failed to hash lower column range statistics");
lower_hashes.into_iter().for_each(|h| h.hash(state));
let upper_hashes = u
.hash(None)
.expect("Failed to hash upper column range statistics");
upper_hashes.into_iter().for_each(|h| h.hash(state));
}
}
}
}

#[derive(PartialEq, Eq, Debug)]
pub enum TruthValue {
False,
Expand Down
Loading

0 comments on commit d5c4167

Please sign in to comment.