From 94bb3704045f6e4383b909dbf319fec42957910c Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Thu, 21 Dec 2023 16:01:03 -0800 Subject: [PATCH] [PERF] Iceberg Partition Pruning (#1688) * Implements Partition Transforms which map source fields to partition fields * Implements Partition Filtering when creating scan tasks * Implements Predicate to Partition Filter rewriting * Allow Iceberg Scan to leverage partition filtering * Implements EmptyScan which kicks in whenever we have no files to scan * Fixes bug with incorrect length when we have a predicate and a limit in a ScanTask * Fixes bug with `df.num_partitions()` where we didn't optimize the logical plan before computing the number of partitions --- Cargo.lock | 1 + daft/daft.pyi | 26 ++- daft/dataframe/dataframe.py | 3 +- daft/execution/physical_plan.py | 1 - daft/execution/rust_physical_plan_shim.py | 32 ++- daft/iceberg/iceberg_scan.py | 58 +++-- daft/io/scan.py | 11 +- src/daft-core/src/series/ops/partitioning.rs | 29 ++- .../src/functions/partitioning/evaluators.rs | 22 +- .../src/functions/partitioning/mod.rs | 16 +- src/daft-dsl/src/lib.rs | 1 + src/daft-dsl/src/optimization.rs | 29 +++ src/daft-dsl/src/python.rs | 20 -- src/daft-micropartition/src/micropartition.rs | 1 + src/daft-plan/Cargo.toml | 1 + src/daft-plan/src/logical_ops/source.rs | 6 +- src/daft-plan/src/optimization/rules/mod.rs | 1 - .../optimization/rules/push_down_filter.rs | 21 +- src/daft-plan/src/optimization/rules/utils.rs | 28 --- src/daft-plan/src/physical_ops/empty_scan.rs | 19 ++ src/daft-plan/src/physical_ops/mod.rs | 2 + src/daft-plan/src/physical_plan.rs | 19 ++ src/daft-plan/src/planner.rs | 34 ++- src/daft-scan/src/expr_rewriter.rs | 157 +++++++++++++ src/daft-scan/src/lib.rs | 107 ++++++--- src/daft-scan/src/python.rs | 126 +++++++--- .../iceberg/test_partition_pruning.py | 219 ++++++++++++++++++ tests/integration/iceberg/test_table_load.py | 3 +- tests/series/test_partitioning.py | 16 +- 29 files changed, 821 insertions(+), 188 deletions(-) delete mode 100644 src/daft-plan/src/optimization/rules/utils.rs create mode 100644 src/daft-plan/src/physical_ops/empty_scan.rs create mode 100644 src/daft-scan/src/expr_rewriter.rs create mode 100644 tests/integration/iceberg/test_partition_pruning.py diff --git a/Cargo.lock b/Cargo.lock index 1cb5cafe1b..e567ba59bf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1387,6 +1387,7 @@ dependencies = [ "daft-scan", "daft-table", "indexmap 2.1.0", + "itertools", "log", "pyo3", "pyo3-log", diff --git a/daft/daft.pyi b/daft/daft.pyi index 2f3156d6fe..e52274f087 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -552,7 +552,8 @@ class ScanTask: storage_config: StorageConfig, size_bytes: int | None, pushdowns: Pushdowns | None, - ) -> ScanTask: + partition_values: PyTable | None, + ) -> ScanTask | None: """ Create a Catalog Scan Task """ @@ -585,17 +586,36 @@ class PartitionField: Partitioning Field of a Scan Source such as Hive or Iceberg """ + field: PyField + def __init__( - self, field: PyField, source_field: PyField | None = None, transform: PyExpr | None = None + self, field: PyField, source_field: PyField | None = None, transform: PartitionTransform | None = None ) -> None: ... +class PartitionTransform: + """ + Partitioning Transform from a Data Catalog source field to a Partitioning Columns + """ + + @staticmethod + def identity() -> PartitionTransform: ... + @staticmethod + def year() -> PartitionTransform: ... + @staticmethod + def month() -> PartitionTransform: ... + @staticmethod + def day() -> PartitionTransform: ... + @staticmethod + def hour() -> PartitionTransform: ... + class Pushdowns: """ Pushdowns from the query optimizer that can optimize scanning data sources. """ columns: list[str] | None - filters: list[str] | None + filters: PyExpr | None + partition_filters: PyExpr | None limit: int | None def read_parquet( diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index 1e191b6494..070f7194e0 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -132,7 +132,8 @@ def explain(self, show_optimized: bool = False, simple=False) -> None: def num_partitions(self) -> int: daft_execution_config = get_context().daft_execution_config - return self.__builder.to_physical_plan_scheduler(daft_execution_config).num_partitions() + # We need to run the optimizer since that could change the number of partitions + return self.__builder.optimize().to_physical_plan_scheduler(daft_execution_config).num_partitions() @DataframePublicAPI def schema(self) -> Schema: diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index 40868673ec..a95a8d8b3c 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -453,7 +453,6 @@ def global_limit( # since we will never take more than the remaining limit anyway. child_plan = local_limit(child_plan=child_plan, limit=remaining_rows) started = False - while True: # Check if any inputs finished executing. # Apply and deduct the rolling global limit. diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index bb1016088a..5d772bf49e 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -42,6 +42,17 @@ def scan_with_tasks( yield scan_step +def empty_scan( + schema: Schema, +) -> physical_plan.InProgressPhysicalPlan[PartitionT]: + """yield a plan to create an empty Partition""" + scan_step = execution_step.PartitionTaskBuilder[PartitionT](inputs=[], partial_metadatas=None,).add_instruction( + instruction=EmptyScan(schema=schema), + resource_request=ResourceRequest(memory_bytes=0), + ) + yield scan_step + + @dataclass(frozen=True) class ScanWithTask(execution_step.SingleOutputInstruction): scan_task: ScanTask @@ -51,7 +62,8 @@ def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: def _scan(self, inputs: list[MicroPartition]) -> list[MicroPartition]: assert len(inputs) == 0 - return [MicroPartition._from_scan_task(self.scan_task)] + table = MicroPartition._from_scan_task(self.scan_task) + return [table] def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]: assert len(input_metadatas) == 0 @@ -64,6 +76,24 @@ def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) ] +@dataclass(frozen=True) +class EmptyScan(execution_step.SingleOutputInstruction): + schema: Schema + + def run(self, inputs: list[MicroPartition]) -> list[MicroPartition]: + return [MicroPartition.empty(self.schema)] + + def run_partial_metadata(self, input_metadatas: list[PartialPartitionMetadata]) -> list[PartialPartitionMetadata]: + assert len(input_metadatas) == 0 + + return [ + PartialPartitionMetadata( + num_rows=0, + size_bytes=0, + ) + ] + + def tabular_scan( schema: PySchema, columns_to_read: list[str] | None, diff --git a/daft/iceberg/iceberg_scan.py b/daft/iceberg/iceberg_scan.py index e17a66238d..a0d3b12416 100644 --- a/daft/iceberg/iceberg_scan.py +++ b/daft/iceberg/iceberg_scan.py @@ -1,26 +1,32 @@ from __future__ import annotations +import logging import warnings from collections.abc import Iterator +import pyarrow as pa from pyiceberg.io.pyarrow import schema_to_pyarrow from pyiceberg.partitioning import PartitionField as IcebergPartitionField from pyiceberg.partitioning import PartitionSpec as IcebergPartitionSpec from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.table import Table +from pyiceberg.typedef import Record +import daft from daft.daft import ( FileFormatConfig, ParquetSourceConfig, + PartitionTransform, Pushdowns, ScanTask, StorageConfig, ) from daft.datatype import DataType -from daft.expressions.expressions import col from daft.io.scan import PartitionField, ScanOperator, make_partition_field from daft.logical.schema import Field, Schema +logger = logging.getLogger(__name__) + def _iceberg_partition_field_to_daft_partition_field( iceberg_schema: IcebergSchema, pfield: IcebergPartitionField @@ -37,7 +43,6 @@ def _iceberg_partition_field_to_daft_partition_field( arrow_result_type = schema_to_pyarrow(iceberg_result_type) daft_result_type = DataType.from_arrow_type(arrow_result_type) result_field = Field.create(name, daft_result_type) - from pyiceberg.transforms import ( DayTransform, HourTransform, @@ -46,25 +51,20 @@ def _iceberg_partition_field_to_daft_partition_field( YearTransform, ) - expr = None + tfm = None if isinstance(transform, IdentityTransform): - expr = col(source_name) - if source_name != name: - expr = expr.alias(name) + tfm = PartitionTransform.identity() elif isinstance(transform, YearTransform): - expr = col(source_name).dt.year().alias(name) + tfm = PartitionTransform.year() elif isinstance(transform, MonthTransform): - expr = col(source_name).dt.month().alias(name) + tfm = PartitionTransform.month() elif isinstance(transform, DayTransform): - expr = col(source_name).dt.day().alias(name) + tfm = PartitionTransform.day() elif isinstance(transform, HourTransform): - warnings.warn( - "HourTransform not implemented, Please make a comment: https://github.com/Eventual-Inc/Daft/issues/1606" - ) + tfm = PartitionTransform.hour() else: warnings.warn(f"{transform} not implemented, Please make an issue!") - - return make_partition_field(result_field, daft_field, transform=expr) + return make_partition_field(result_field, daft_field, transform=tfm) def iceberg_partition_spec_to_fields(iceberg_schema: IcebergSchema, spec: IcebergPartitionSpec) -> list[PartitionField]: @@ -83,15 +83,37 @@ def __init__(self, iceberg_table: Table, storage_config: StorageConfig) -> None: def schema(self) -> Schema: return self._schema + def display_name(self) -> str: + return f"IcebergScanOperator({'.'.join(self._table.name())})" + def partitioning_keys(self) -> list[PartitionField]: return self._partition_keys + def _iceberg_record_to_partition_spec(self, record: Record) -> daft.table.Table | None: + arrays = dict() + assert len(record._position_to_field_name) == len(self._partition_keys) + for name, value, pfield in zip(record._position_to_field_name, record.record_fields(), self._partition_keys): + field = Field._from_pyfield(pfield.field) + field_name = field.name + field_dtype = field.dtype + arrow_type = field_dtype.to_arrow_dtype() + assert name == field_name + arrays[name] = daft.Series.from_arrow(pa.array([value], type=arrow_type), name=name).cast(field_dtype) + if len(arrays) > 0: + return daft.table.Table.from_pydict(arrays) + else: + return None + def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: limit = pushdowns.limit iceberg_tasks = self._table.scan(limit=limit).plan_files() - limit_files = limit is not None and pushdowns.filters is None + limit_files = limit is not None and pushdowns.filters is None and pushdowns.partition_filters is None + if len(self.partitioning_keys()) > 0 and pushdowns.partition_filters is None: + logging.warn( + f"{self.display_name()} has Partitioning Keys: {self.partitioning_keys()} but no partition filter was specified. This will result in a full table scan." + ) scan_tasks = [] if limit is not None: @@ -114,9 +136,8 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: if len(task.delete_files) > 0: raise NotImplementedError(f"Iceberg Merge-on-Read currently not supported, please make an issue!") - # TODO: Thread in PartitionSpec to each ScanTask: P1 # TODO: Thread in Statistics to each ScanTask: P2 - + pspec = self._iceberg_record_to_partition_spec(file.partition) st = ScanTask.catalog_scan_task( file=path, file_format=file_format_config, @@ -125,7 +146,10 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]: storage_config=self._storage_config, size_bytes=file.file_size_in_bytes, pushdowns=pushdowns, + partition_values=pspec._table if pspec is not None else None, ) + if st is None: + continue rows_left -= record_count scan_tasks.append(st) return iter(scan_tasks) diff --git a/daft/io/scan.py b/daft/io/scan.py index 1e110ee3d0..3f4e17dfe1 100644 --- a/daft/io/scan.py +++ b/daft/io/scan.py @@ -3,18 +3,17 @@ import abc from collections.abc import Iterator -from daft.daft import PartitionField, Pushdowns, ScanTask -from daft.expressions.expressions import Expression +from daft.daft import PartitionField, PartitionTransform, Pushdowns, ScanTask from daft.logical.schema import Field, Schema def make_partition_field( - field: Field, source_field: Field | None = None, transform: Expression | None = None + field: Field, source_field: Field | None = None, transform: PartitionTransform | None = None ) -> PartitionField: return PartitionField( field._field, source_field._field if source_field is not None else None, - transform._expr if transform is not None else None, + transform, ) @@ -23,6 +22,10 @@ class ScanOperator(abc.ABC): def schema(self) -> Schema: raise NotImplementedError() + @abc.abstractmethod + def display_name(self) -> str: + return self.__class__.__name__ + @abc.abstractmethod def partitioning_keys(self) -> list[PartitionField]: raise NotImplementedError() diff --git a/src/daft-core/src/series/ops/partitioning.rs b/src/daft-core/src/series/ops/partitioning.rs index 7469e1783e..f11160a617 100644 --- a/src/daft-core/src/series/ops/partitioning.rs +++ b/src/daft-core/src/series/ops/partitioning.rs @@ -1,17 +1,14 @@ use crate::datatypes::logical::TimestampArray; use crate::datatypes::{Int32Array, Int64Array, TimeUnit}; use crate::series::array_impl::IntoSeries; -use crate::{ - datatypes::{logical::DateArray, DataType}, - series::Series, -}; +use crate::{datatypes::DataType, series::Series}; use common_error::{DaftError, DaftResult}; impl Series { pub fn partitioning_years(&self) -> DaftResult { let epoch_year = Int32Array::from(("1970", vec![1970])).into_series(); - match self.data_type() { + let value = match self.data_type() { DataType::Date | DataType::Timestamp(_, None) => { let years_since_ce = self.dt_year()?; &years_since_ce - &epoch_year @@ -25,13 +22,14 @@ impl Series { "Can only run partitioning_years() operation on temporal types, got {}", self.data_type() ))), - } + }?; + value.cast(&DataType::Int32) } pub fn partitioning_months(&self) -> DaftResult { let months_in_year = Int32Array::from(("months", vec![12])).into_series(); let month_of_epoch = Int32Array::from(("months", vec![1])).into_series(); - match self.data_type() { + let value = match self.data_type() { DataType::Date | DataType::Timestamp(_, None) => { let years_since_1970 = self.partitioning_years()?; let months_of_this_year = self.dt_month()?; @@ -47,24 +45,22 @@ impl Series { "Can only run partitioning_years() operation on temporal types, got {}", self.data_type() ))), - } + }?; + value.cast(&DataType::Int32) } pub fn partitioning_days(&self) -> DaftResult { match self.data_type() { - DataType::Date => { - let downcasted = self.downcast::()?; - downcasted.cast(&DataType::Int32) - } + DataType::Date => Ok(self.clone()), DataType::Timestamp(_, None) => { let ts_array = self.downcast::()?; - ts_array.date()?.cast(&DataType::Int32) + Ok(ts_array.date()?.into_series()) } DataType::Timestamp(tu, Some(_)) => { let array = self.cast(&DataType::Timestamp(*tu, None))?; let ts_array = array.downcast::()?; - ts_array.date()?.cast(&DataType::Int32) + Ok(ts_array.date()?.into_series()) } _ => Err(DaftError::ComputeError(format!( @@ -75,7 +71,7 @@ impl Series { } pub fn partitioning_hours(&self) -> DaftResult { - match self.data_type() { + let value = match self.data_type() { DataType::Timestamp(unit, _) => { let ts_array = self.downcast::()?; let physical = &ts_array.physical; @@ -93,6 +89,7 @@ impl Series { "Can only run partitioning_hours() operation on timestamp types, got {}", self.data_type() ))), - } + }?; + value.cast(&DataType::Int32) } } diff --git a/src/daft-dsl/src/functions/partitioning/evaluators.rs b/src/daft-dsl/src/functions/partitioning/evaluators.rs index 5ef92b5758..bbd759df28 100644 --- a/src/daft-dsl/src/functions/partitioning/evaluators.rs +++ b/src/daft-dsl/src/functions/partitioning/evaluators.rs @@ -11,22 +11,23 @@ use common_error::{DaftError, DaftResult}; use super::super::FunctionEvaluator; macro_rules! impl_func_evaluator_for_partitioning { - ($name:ident, $op:ident, $kernel:ident) => { + ($name:ident, $op:ident, $kernel:ident, $result_type:ident) => { pub(super) struct $name {} impl FunctionEvaluator for $name { fn fn_name(&self) -> &'static str { - "$op" + stringify!($op) } fn to_field(&self, inputs: &[Expr], schema: &Schema, _: &Expr) -> DaftResult { match inputs { [input] => match input.to_field(schema) { Ok(field) if field.dtype.is_temporal() => { - Ok(Field::new(field.name, DataType::Int32)) + Ok(Field::new(field.name, $result_type)) } Ok(field) => Err(DaftError::TypeError(format!( - "Expected input to $op to be temporal, got {}", + "Expected input to {} to be temporal, got {}", + stringify!($op), field.dtype ))), Err(e) => Err(e), @@ -42,7 +43,8 @@ macro_rules! impl_func_evaluator_for_partitioning { match inputs { [input] => input.$kernel(), _ => Err(DaftError::ValueError(format!( - "Expected 1 input arg, got {}", + "Expected 1 input arg for {}, got {}", + stringify!($op), inputs.len() ))), } @@ -50,8 +52,8 @@ macro_rules! impl_func_evaluator_for_partitioning { } }; } - -impl_func_evaluator_for_partitioning!(YearsEvaluator, years, partitioning_years); -impl_func_evaluator_for_partitioning!(MonthsEvaluator, months, partitioning_months); -impl_func_evaluator_for_partitioning!(DaysEvaluator, days, partitioning_days); -impl_func_evaluator_for_partitioning!(HoursEvaluator, hours, partitioning_hours); +use DataType::{Date, Int32}; +impl_func_evaluator_for_partitioning!(YearsEvaluator, years, partitioning_years, Int32); +impl_func_evaluator_for_partitioning!(MonthsEvaluator, months, partitioning_months, Int32); +impl_func_evaluator_for_partitioning!(DaysEvaluator, days, partitioning_days, Date); +impl_func_evaluator_for_partitioning!(HoursEvaluator, hours, partitioning_hours, Int32); diff --git a/src/daft-dsl/src/functions/partitioning/mod.rs b/src/daft-dsl/src/functions/partitioning/mod.rs index f3aebb6fcd..ab36ad4c79 100644 --- a/src/daft-dsl/src/functions/partitioning/mod.rs +++ b/src/daft-dsl/src/functions/partitioning/mod.rs @@ -32,30 +32,30 @@ impl PartitioningExpr { } } -pub fn days(input: &Expr) -> Expr { +pub fn days(input: Expr) -> Expr { Expr::Function { func: super::FunctionExpr::Partitioning(PartitioningExpr::Days), - inputs: vec![input.clone()], + inputs: vec![input], } } -pub fn hours(input: &Expr) -> Expr { +pub fn hours(input: Expr) -> Expr { Expr::Function { func: super::FunctionExpr::Partitioning(PartitioningExpr::Hours), - inputs: vec![input.clone()], + inputs: vec![input], } } -pub fn months(input: &Expr) -> Expr { +pub fn months(input: Expr) -> Expr { Expr::Function { func: super::FunctionExpr::Partitioning(PartitioningExpr::Months), - inputs: vec![input.clone()], + inputs: vec![input], } } -pub fn years(input: &Expr) -> Expr { +pub fn years(input: Expr) -> Expr { Expr::Function { func: super::FunctionExpr::Partitioning(PartitioningExpr::Years), - inputs: vec![input.clone()], + inputs: vec![input], } } diff --git a/src/daft-dsl/src/lib.rs b/src/daft-dsl/src/lib.rs index ac43e833a1..387acc53e9 100644 --- a/src/daft-dsl/src/lib.rs +++ b/src/daft-dsl/src/lib.rs @@ -9,6 +9,7 @@ mod pyobject; #[cfg(feature = "python")] pub mod python; mod treenode; +pub use common_treenode; pub use expr::binary_op; pub use expr::col; pub use expr::{AggExpr, Expr, ExprRef, Operator}; diff --git a/src/daft-dsl/src/optimization.rs b/src/daft-dsl/src/optimization.rs index 1f99bd2931..d2f3d53a18 100644 --- a/src/daft-dsl/src/optimization.rs +++ b/src/daft-dsl/src/optimization.rs @@ -2,6 +2,8 @@ use std::collections::HashMap; use common_treenode::{Transformed, TreeNode, VisitRecursion}; +use crate::Operator; + use super::expr::Expr; pub fn get_required_columns(e: &Expr) -> Vec { @@ -42,3 +44,30 @@ pub fn replace_columns_with_expressions(expr: &Expr, replace_map: &HashMap Vec<&Expr> { + let mut splits = vec![]; + _split_conjuction(expr, &mut splits); + splits +} + +fn _split_conjuction<'a>(expr: &'a Expr, out_exprs: &mut Vec<&'a Expr>) { + match expr { + Expr::BinaryOp { + op: Operator::And, + left, + right, + } => { + _split_conjuction(left, out_exprs); + _split_conjuction(right, out_exprs); + } + Expr::Alias(inner_expr, ..) => _split_conjuction(inner_expr, out_exprs), + _ => { + out_exprs.push(expr); + } + } +} + +pub fn conjuct(exprs: Vec) -> Option { + exprs.into_iter().reduce(|acc, expr| acc.and(&expr)) +} diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index ba1a82c162..e43fc1be4a 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -288,26 +288,6 @@ impl PyExpr { Ok(day_of_week(&self.expr).into()) } - pub fn partitioning_days(&self) -> PyResult { - use functions::partitioning::days; - Ok(days(&self.expr).into()) - } - - pub fn partitioning_hours(&self) -> PyResult { - use functions::partitioning::hours; - Ok(hours(&self.expr).into()) - } - - pub fn partitioning_months(&self) -> PyResult { - use functions::partitioning::months; - Ok(months(&self.expr).into()) - } - - pub fn partitioning_years(&self) -> PyResult { - use functions::partitioning::years; - Ok(years(&self.expr).into()) - } - pub fn utf8_endswith(&self, pattern: &Self) -> PyResult { use crate::functions::utf8::endswith; Ok(endswith(&self.expr, &pattern.expr).into()) diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index 0619a415f1..aff4033a55 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -865,6 +865,7 @@ pub(crate) fn read_parquet_into_micropartition( ) .into(), Pushdowns::new( + None, None, columns .map(|cols| Arc::new(cols.iter().map(|v| v.to_string()).collect::>())), diff --git a/src/daft-plan/Cargo.toml b/src/daft-plan/Cargo.toml index 50b110d32d..00f4d08297 100644 --- a/src/daft-plan/Cargo.toml +++ b/src/daft-plan/Cargo.toml @@ -10,6 +10,7 @@ daft-dsl = {path = "../daft-dsl", default-features = false} daft-scan = {path = "../daft-scan", default-features = false} daft-table = {path = "../daft-table", default-features = false} indexmap = {workspace = true} +itertools = {workspace = true} log = {workspace = true} pyo3 = {workspace = true, optional = true} pyo3-log = {workspace = true, optional = true} diff --git a/src/daft-plan/src/logical_ops/source.rs b/src/daft-plan/src/logical_ops/source.rs index 7ec8a88d14..f219fb1a86 100644 --- a/src/daft-plan/src/logical_ops/source.rs +++ b/src/daft-plan/src/logical_ops/source.rs @@ -53,11 +53,15 @@ impl Source { partitioning_keys, pushdowns, })) => { + use itertools::Itertools; res.push("Source:".to_string()); res.push(format!("{}", scan_op)); res.push(format!("File schema = {}", source_schema.short_string())); - res.push(format!("Partitioning keys = {:?}", partitioning_keys)); + res.push(format!( + "Partitioning keys = [{}]", + partitioning_keys.iter().map(|k| format!("{k}")).join(" ") + )); res.extend(pushdowns.multiline_display()); } #[cfg(feature = "python")] diff --git a/src/daft-plan/src/optimization/rules/mod.rs b/src/daft-plan/src/optimization/rules/mod.rs index fd1578af73..2c78fd9b7b 100644 --- a/src/daft-plan/src/optimization/rules/mod.rs +++ b/src/daft-plan/src/optimization/rules/mod.rs @@ -3,7 +3,6 @@ mod push_down_filter; mod push_down_limit; mod push_down_projection; mod rule; -mod utils; pub use drop_repartition::DropRepartition; pub use push_down_filter::PushDownFilter; diff --git a/src/daft-plan/src/optimization/rules/push_down_filter.rs b/src/daft-plan/src/optimization/rules/push_down_filter.rs index a21178139e..9ef921e33f 100644 --- a/src/daft-plan/src/optimization/rules/push_down_filter.rs +++ b/src/daft-plan/src/optimization/rules/push_down_filter.rs @@ -7,9 +7,12 @@ use common_error::DaftResult; use daft_dsl::{ col, functions::FunctionExpr, - optimization::{get_required_columns, replace_columns_with_expressions}, + optimization::{ + conjuct, get_required_columns, replace_columns_with_expressions, split_conjuction, + }, Expr, }; +use daft_scan::{rewrite_predicate_for_partitioning, ScanExternalInfo}; use crate::{ logical_ops::{Concat, Filter, Project, Source}, @@ -17,10 +20,7 @@ use crate::{ LogicalPlan, }; -use super::{ - utils::{conjuct, split_conjuction}, - ApplyOrder, OptimizerRule, Transformed, -}; +use super::{ApplyOrder, OptimizerRule, Transformed}; /// Optimization rules for pushing Filters further into the logical plan. #[derive(Default, Debug)] @@ -116,8 +116,19 @@ impl OptimizerRule for PushDownFilter { return Ok(Transformed::No(plan)); } let new_predicate = external_info.pushdowns().filters.as_ref().map(|f| predicate.and(f)).unwrap_or(predicate.clone()); + let partition_filter = if let ExternalInfo::Scan(ScanExternalInfo {scan_op, ..}) = &external_info { + rewrite_predicate_for_partitioning(new_predicate.clone(), scan_op.0.partitioning_keys())? + } else { + None + }; let new_pushdowns = external_info.pushdowns().with_filters(Some(Arc::new(new_predicate))); + + let new_pushdowns = if let Some(pfilter) = partition_filter { + new_pushdowns.with_partition_filters(Some(Arc::new(pfilter))) + } else { + new_pushdowns + }; let new_external_info = external_info.with_pushdowns(new_pushdowns); let new_source = LogicalPlan::Source(Source::new( source.output_schema.clone(), diff --git a/src/daft-plan/src/optimization/rules/utils.rs b/src/daft-plan/src/optimization/rules/utils.rs deleted file mode 100644 index c8e7a3019d..0000000000 --- a/src/daft-plan/src/optimization/rules/utils.rs +++ /dev/null @@ -1,28 +0,0 @@ -use daft_dsl::{Expr, Operator}; - -pub fn split_conjuction(expr: &Expr) -> Vec<&Expr> { - let mut splits = vec![]; - _split_conjuction(expr, &mut splits); - splits -} - -fn _split_conjuction<'a>(expr: &'a Expr, out_exprs: &mut Vec<&'a Expr>) { - match expr { - Expr::BinaryOp { - op: Operator::And, - left, - right, - } => { - _split_conjuction(left, out_exprs); - _split_conjuction(right, out_exprs); - } - Expr::Alias(inner_expr, ..) => _split_conjuction(inner_expr, out_exprs), - _ => { - out_exprs.push(expr); - } - } -} - -pub fn conjuct(exprs: Vec) -> Option { - exprs.into_iter().reduce(|acc, expr| acc.and(&expr)) -} diff --git a/src/daft-plan/src/physical_ops/empty_scan.rs b/src/daft-plan/src/physical_ops/empty_scan.rs new file mode 100644 index 0000000000..fbf9e6dd3f --- /dev/null +++ b/src/daft-plan/src/physical_ops/empty_scan.rs @@ -0,0 +1,19 @@ +use crate::PartitionSpec; +use daft_core::schema::SchemaRef; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct EmptyScan { + pub schema: SchemaRef, + pub partition_spec: Arc, +} + +impl EmptyScan { + pub(crate) fn new(schema: SchemaRef, partition_spec: Arc) -> Self { + Self { + schema, + partition_spec, + } + } +} diff --git a/src/daft-plan/src/physical_ops/mod.rs b/src/daft-plan/src/physical_ops/mod.rs index e55aca8dd0..87e2dd29ea 100644 --- a/src/daft-plan/src/physical_ops/mod.rs +++ b/src/daft-plan/src/physical_ops/mod.rs @@ -3,6 +3,7 @@ mod broadcast_join; mod coalesce; mod concat; mod csv; +mod empty_scan; mod explode; mod fanout; mod filter; @@ -24,6 +25,7 @@ pub use broadcast_join::BroadcastJoin; pub use coalesce::Coalesce; pub use concat::Concat; pub use csv::{TabularScanCsv, TabularWriteCsv}; +pub use empty_scan::EmptyScan; pub use explode::Explode; pub use fanout::{FanoutByHash, FanoutByRange, FanoutRandom}; pub use filter::Filter; diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs index 68112897ee..010b9d4bc7 100644 --- a/src/daft-plan/src/physical_plan.rs +++ b/src/daft-plan/src/physical_plan.rs @@ -37,6 +37,7 @@ pub enum PhysicalPlan { TabularScanCsv(TabularScanCsv), TabularScanJson(TabularScanJson), TabularScan(TabularScan), + EmptyScan(EmptyScan), Project(Project), Filter(Filter), Limit(Limit), @@ -65,6 +66,7 @@ impl PhysicalPlan { #[cfg(feature = "python")] Self::InMemoryScan(InMemoryScan { partition_spec, .. }) => partition_spec.clone(), Self::TabularScan(TabularScan { partition_spec, .. }) => partition_spec.clone(), + Self::EmptyScan(EmptyScan { partition_spec, .. }) => partition_spec.clone(), Self::TabularScanParquet(TabularScanParquet { partition_spec, .. }) => { partition_spec.clone() } @@ -179,6 +181,7 @@ impl PhysicalPlan { .iter() .map(|scan_task| scan_task.size_bytes()) .sum::>(), + Self::EmptyScan(..) => Some(0), // Assume no row/column pruning in cardinality-affecting operations. // TODO(Clark): Estimate row/column pruning to get a better size approximation. Self::Filter(Filter { input, .. }) @@ -392,6 +395,22 @@ impl PhysicalPlan { .collect::>(),))?; Ok(py_iter.into()) } + PhysicalPlan::EmptyScan(EmptyScan { schema, .. }) => { + let schema_mod = py.import(pyo3::intern!(py, "daft.logical.schema"))?; + let python_schema = schema_mod + .getattr(pyo3::intern!(py, "Schema"))? + .getattr(pyo3::intern!(py, "_from_pyschema"))? + .call1((PySchema { + schema: schema.clone(), + },))?; + + let py_iter = py + .import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))? + .getattr(pyo3::intern!(py, "empty_scan"))? + .call1((python_schema,))?; + Ok(py_iter.into()) + } + PhysicalPlan::TabularScanParquet(TabularScanParquet { projection_schema, external_info: diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index ba43518b73..d1ddda0ee8 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -75,6 +75,7 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftRe SourceInfo::ExternalInfo(ExternalSourceInfo::Scan(ScanExternalInfo { pushdowns, scan_op, + source_schema, .. })) => { let scan_tasks = scan_op.0.to_scan_tasks(pushdowns.clone())?; @@ -85,17 +86,30 @@ pub fn plan(logical_plan: &LogicalPlan, cfg: Arc) -> DaftRe cfg.merge_scan_tasks_min_size_bytes, cfg.merge_scan_tasks_max_size_bytes, ); - let scan_tasks = scan_tasks.collect::>>()?; - let partition_spec = Arc::new(PartitionSpec::new_internal( - PartitionScheme::Unknown, - scan_tasks.len(), - None, - )); - Ok(PhysicalPlan::TabularScan(TabularScan::new( - scan_tasks, - partition_spec, - ))) + if scan_tasks.is_empty() { + let partition_spec = Arc::new(PartitionSpec::new_internal( + PartitionScheme::Unknown, + 1, + None, + )); + + Ok(PhysicalPlan::EmptyScan(EmptyScan::new( + source_schema.clone(), + partition_spec, + ))) + } else { + let partition_spec = Arc::new(PartitionSpec::new_internal( + PartitionScheme::Unknown, + scan_tasks.len(), + None, + )); + + Ok(PhysicalPlan::TabularScan(TabularScan::new( + scan_tasks, + partition_spec, + ))) + } } #[cfg(feature = "python")] SourceInfo::InMemoryInfo(mem_info) => { diff --git a/src/daft-scan/src/expr_rewriter.rs b/src/daft-scan/src/expr_rewriter.rs new file mode 100644 index 0000000000..eeebfd1e6d --- /dev/null +++ b/src/daft-scan/src/expr_rewriter.rs @@ -0,0 +1,157 @@ +use std::collections::{HashMap, HashSet}; + +use common_error::DaftResult; +use daft_dsl::{ + col, + common_treenode::{Transformed, TreeNode, VisitRecursion}, + functions::partitioning, + null_lit, + optimization::{conjuct, split_conjuction}, + Expr, Operator, +}; + +use crate::{PartitionField, PartitionTransform}; + +fn unalias(expr: Expr) -> DaftResult { + expr.transform(&|e| { + if let Expr::Alias(e, _) = e { + Ok(Transformed::Yes(e.as_ref().clone())) + } else { + Ok(Transformed::No(e)) + } + }) +} + +fn apply_partitioning_expr(expr: Expr, tfm: PartitionTransform) -> Option { + use PartitionTransform::*; + match tfm { + Identity => Some(expr), + Year => Some(partitioning::years(expr)), + Month => Some(partitioning::months(expr)), + Day => Some(partitioning::days(expr)), + Hour => Some(partitioning::hours(expr)), + Void => Some(null_lit()), + _ => None, + } +} + +pub fn rewrite_predicate_for_partitioning( + predicate: Expr, + pfields: &[PartitionField], +) -> DaftResult> { + if pfields.is_empty() { + return Ok(None); + } + + let predicate = unalias(predicate)?; + + let source_to_pfield = { + let mut map = HashMap::with_capacity(pfields.len()); + for pf in pfields.iter() { + if let Some(ref source_field) = pf.source_field { + let prev_value = map.insert(source_field.name.as_str(), pf); + if let Some(prev_value) = prev_value { + return Err(common_error::DaftError::ValueError(format!("Duplicate Partitioning Columns found on same source field: {source_field}\n1: {prev_value}\n2: {pf}"))); + } + } + } + map + }; + + let with_part_cols = predicate.transform(&|expr| { + use Operator::*; + match expr { + // Binary Op for Eq + // All transforms should work as is + Expr::BinaryOp { + op: Eq, + ref left, ref right } => { + if let Expr::Column(col_name) = left.as_ref() && let Some(pfield) = source_to_pfield.get(col_name.as_ref()) { + if let Some(tfm) = pfield.transform && tfm.supports_equals() && let Some(new_expr) = apply_partitioning_expr(right.as_ref().clone(), tfm) { + return Ok(Transformed::Yes(Expr::BinaryOp { op: Eq, left: col(pfield.field.name.as_str()).into(), right: new_expr.into() })); + } + Ok(Transformed::No(expr)) + } else if let Expr::Column(col_name) = right.as_ref() && let Some(pfield) = source_to_pfield.get(col_name.as_ref()) { + if let Some(tfm) = pfield.transform && tfm.supports_equals() && let Some(new_expr) = apply_partitioning_expr(left.as_ref().clone(), tfm) { + return Ok(Transformed::Yes(Expr::BinaryOp { op: Eq, left: new_expr.into(), right: col(pfield.field.name.as_str()).into() })); + } + Ok(Transformed::No(expr)) + } else { + Ok(Transformed::No(expr)) + } + }, + // Binary Op for NotEq + // Should only work for Identity + Expr::BinaryOp { + op: NotEq, + ref left, ref right } => { + if let Expr::Column(col_name) = left.as_ref() && let Some(pfield) = source_to_pfield.get(col_name.as_ref()) { + if let Some(tfm) = pfield.transform && tfm.supports_not_equals() && let Some(new_expr) = apply_partitioning_expr(right.as_ref().clone(), tfm) { + return Ok(Transformed::Yes(Expr::BinaryOp { op: NotEq, left: col(pfield.field.name.as_str()).into(), right: new_expr.into() })); + } + Ok(Transformed::No(expr)) + } else if let Expr::Column(col_name) = right.as_ref() && let Some(pfield) = source_to_pfield.get(col_name.as_ref()) { + if let Some(tfm) = pfield.transform && tfm.supports_not_equals() && let Some(new_expr) = apply_partitioning_expr(left.as_ref().clone(), tfm) { + return Ok(Transformed::Yes(Expr::BinaryOp { op: NotEq, left: new_expr.into(), right: col(pfield.field.name.as_str()).into() })); + } + Ok(Transformed::No(expr)) + } else { + Ok(Transformed::No(expr)) + } + }, + // Binary Op for Lt | LtEq | Gt | GtEq + // we need to relax Lt and LtEq and only allow certain Transforms + Expr::BinaryOp { + op, + ref left, ref right } if matches!(op, Lt | LtEq | Gt | GtEq)=> { + use PartitionTransform::*; + + let relaxed_op = match op { + Lt | LtEq => LtEq, + Gt | GtEq => GtEq, + _ => unreachable!("this branch only supports Lt | LtEq | Gt | GtEq") + }; + + if let Expr::Column(col_name) = left.as_ref() && let Some(pfield) = source_to_pfield.get(col_name.as_ref()) { + if let Some(tfm) = pfield.transform && tfm.supports_comparison() && matches!(tfm, Year | Month | Hour | Day) && let Some(new_expr) = apply_partitioning_expr(right.as_ref().clone(), tfm) { + return Ok(Transformed::Yes(Expr::BinaryOp { op: relaxed_op, left: col(pfield.field.name.as_str()).into(), right: new_expr.into() })); + } + Ok(Transformed::No(expr)) + } else if let Expr::Column(col_name) = right.as_ref() && let Some(pfield) = source_to_pfield.get(col_name.as_ref()) { + if let Some(tfm) = pfield.transform && tfm.supports_comparison() && let Some(new_expr) = apply_partitioning_expr(left.as_ref().clone(), tfm) { + return Ok(Transformed::Yes(Expr::BinaryOp { op: relaxed_op, left: new_expr.into(), right: col(pfield.field.name.as_str()).into() })); + } + Ok(Transformed::No(expr)) + } else { + Ok(Transformed::No(expr)) + } + }, + + Expr::IsNull(ref expr) if let Expr::Column(col_name) = expr.as_ref() && let Some(pfield) = source_to_pfield.get(col_name.as_ref()) => { + Ok(Transformed::Yes(Expr::IsNull(col(pfield.field.name.as_str()).into()))) + }, + _ => Ok(Transformed::No(expr)) + } + })?; + + let p_keys = HashSet::<&str>::from_iter(pfields.iter().map(|p| p.field.name.as_ref())); + + let split = split_conjuction(&with_part_cols); + let filtered = split + .into_iter() + .filter(|p| { + let mut keep = true; + p.apply(&mut |e| { + if let Expr::Column(col_name) = e && !p_keys.contains(col_name.as_ref()) { + keep = false; + + } + Ok(VisitRecursion::Continue) + }) + .unwrap(); + keep + }) + .cloned() + .collect::>(); + Ok(conjuct(filtered)) +} diff --git a/src/daft-scan/src/lib.rs b/src/daft-scan/src/lib.rs index b283425498..a96601751c 100644 --- a/src/daft-scan/src/lib.rs +++ b/src/daft-scan/src/lib.rs @@ -1,3 +1,5 @@ +#![feature(if_let_guard)] +#![feature(let_chains)] use std::{ fmt::{Debug, Display}, hash::{Hash, Hasher}, @@ -6,7 +8,7 @@ use std::{ use common_error::{DaftError, DaftResult}; use daft_core::{datatypes::Field, schema::SchemaRef}; -use daft_dsl::{optimization::get_required_columns, Expr, ExprRef}; +use daft_dsl::ExprRef; use daft_stats::{PartitionSpec, TableMetadata, TableStatistics}; use file_format::FileFormatConfig; use serde::{Deserialize, Serialize}; @@ -28,7 +30,8 @@ use pyo3::PyErr; pub use python::register_modules; use snafu::Snafu; use storage_config::StorageConfig; - +mod expr_rewriter; +pub use expr_rewriter::rewrite_predicate_for_partitioning; #[derive(Debug, Snafu)] pub enum Error { #[cfg(feature = "python")] @@ -267,7 +270,11 @@ impl ScanTask { } pub fn num_rows(&self) -> Option { - self.metadata.as_ref().map(|m| m.length) + if self.pushdowns.filters.is_some() { + None + } else { + self.metadata.as_ref().map(|m| m.length) + } } pub fn size_bytes(&self) -> Option { @@ -294,35 +301,23 @@ impl ScanTask { pub struct PartitionField { field: Field, source_field: Option, - transform: Option, + transform: Option, } impl PartitionField { pub fn new( field: Field, source_field: Option, - transform: Option, + transform: Option, ) -> DaftResult { match (&source_field, &transform) { - (Some(sf), Some(tfm)) => { - let req_columns = get_required_columns(tfm); - match req_columns.as_slice() { - [col] => { - if col == &sf.name { - Ok(PartitionField { - field, - source_field, - transform, - }) - } else { - Err(DaftError::ValueError(format!("PartitionField transform's required column and source_field differ: {} vs {}" , col, sf.name))) - } - } - _ => Err(DaftError::ValueError(format!( - "PartitionField only supports unary transforms but received {}", - tfm - ))), - } + (Some(_), Some(_)) => { + // TODO ADD VALIDATION OF TRANSFORM based on types + Ok(PartitionField { + field, + source_field, + transform, + }) } (None, Some(tfm)) => Err(DaftError::ValueError(format!( "transform set in PartitionField: {} but source_field not set", @@ -340,13 +335,54 @@ impl PartitionField { impl Display for PartitionField { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if let Some(tfm) = &self.transform { - write!(f, "PartitionField({}, {})", self.field, tfm) + write!( + f, + "PartitionField({}, src={}, tfm={})", + self.field, + self.source_field.as_ref().unwrap(), + tfm + ) } else { write!(f, "PartitionField({})", self.field) } } } +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub enum PartitionTransform { + /// https://iceberg.apache.org/spec/#partitioning + /// For Delta, Hudi and Hive, it should always be `Identity`. + Identity, + Bucket(u64), + Truncate(u64), + Year, + Month, + Day, + Hour, + Void, +} + +impl PartitionTransform { + pub fn supports_equals(&self) -> bool { + true + } + + pub fn supports_not_equals(&self) -> bool { + matches!(self, Self::Identity) + } + + pub fn supports_comparison(&self) -> bool { + use PartitionTransform::*; + matches!(self, Identity | Truncate(_) | Year | Month | Day | Hour) + } +} + +impl Display for PartitionTransform { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self:?}") + } +} + pub trait ScanOperator: Send + Sync + Display + Debug { fn schema(&self) -> SchemaRef; fn partitioning_keys(&self) -> &[PartitionField]; @@ -420,6 +456,8 @@ impl ScanExternalInfo { pub struct Pushdowns { /// Optional filters to apply to the source data. pub filters: Option, + /// Optional filters to apply on partitioning keys. + pub partition_filters: Option, /// Optional columns to select from the source data. pub columns: Option>>, /// Optional number of rows to read. @@ -428,18 +466,20 @@ pub struct Pushdowns { impl Default for Pushdowns { fn default() -> Self { - Self::new(None, None, None) + Self::new(None, None, None, None) } } impl Pushdowns { pub fn new( filters: Option, + partition_filters: Option, columns: Option>>, limit: Option, ) -> Self { Self { filters, + partition_filters, columns, limit, } @@ -448,6 +488,7 @@ impl Pushdowns { pub fn with_limit(&self, limit: Option) -> Self { Self { filters: self.filters.clone(), + partition_filters: self.partition_filters.clone(), columns: self.columns.clone(), limit, } @@ -456,6 +497,16 @@ impl Pushdowns { pub fn with_filters(&self, filters: Option) -> Self { Self { filters, + partition_filters: self.partition_filters.clone(), + columns: self.columns.clone(), + limit: self.limit, + } + } + + pub fn with_partition_filters(&self, partition_filters: Option) -> Self { + Self { + filters: self.filters.clone(), + partition_filters, columns: self.columns.clone(), limit: self.limit, } @@ -464,6 +515,7 @@ impl Pushdowns { pub fn with_columns(&self, columns: Option>>) -> Self { Self { filters: self.filters.clone(), + partition_filters: self.partition_filters.clone(), columns, limit: self.limit, } @@ -477,6 +529,9 @@ impl Pushdowns { if let Some(filters) = &self.filters { res.push(format!("Filter pushdown = {}", filters)); } + if let Some(pfilters) = &self.partition_filters { + res.push(format!("Partition Filter = {}", pfilters)); + } if let Some(limit) = self.limit { res.push(format!("Limit pushdown = {}", limit)); } diff --git a/src/daft-scan/src/python.rs b/src/daft-scan/src/python.rs index a9b945f332..d435410b2c 100644 --- a/src/daft-scan/src/python.rs +++ b/src/daft-scan/src/python.rs @@ -9,6 +9,7 @@ pub mod pylib { use daft_core::impl_bincode_py_state_serialization; use daft_stats::PartitionSpec; use daft_stats::TableMetadata; + use daft_table::python::PyTable; use daft_table::Table; use pyo3::prelude::*; use pyo3::types::PyBytes; @@ -109,6 +110,7 @@ pub mod pylib { can_absorb_filter: bool, can_absorb_limit: bool, can_absorb_select: bool, + display_name: String, } impl PythonScanOperatorBridge { @@ -143,6 +145,11 @@ pub mod pylib { abc.call_method0(py, pyo3::intern!(py, "can_absorb_select"))? .extract::(py) } + + fn _display_name(abc: &PyObject, py: Python) -> PyResult { + abc.call_method0(py, pyo3::intern!(py, "display_name"))? + .extract::(py) + } } #[pymethods] @@ -154,6 +161,8 @@ pub mod pylib { let can_absorb_filter = Self::_can_absorb_filter(&abc, py)?; let can_absorb_limit = Self::_can_absorb_limit(&abc, py)?; let can_absorb_select = Self::_can_absorb_select(&abc, py)?; + let display_name = Self::_display_name(&abc, py)?; + Ok(Self { operator: abc, schema, @@ -161,32 +170,14 @@ pub mod pylib { can_absorb_filter, can_absorb_limit, can_absorb_select, + display_name, }) } } impl Display for PythonScanOperatorBridge { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "PythonScanOperator -operator:\n{:#?} -can_absorb_filter: {} -can_absorb_limit: {} -can_absorb_select: {} -schema:\n{} -partitioning_keys:\n", - self.operator, - self.can_absorb_filter, - self.can_absorb_limit, - self.can_absorb_select, - self.schema - )?; - - for p in self.partitioning_keys.iter() { - writeln!(f, "{p}")?; - } - Ok(()) + write!(f, "PythonScanOperator: {}", self.display_name,) } } @@ -261,6 +252,7 @@ partitioning_keys:\n", #[pymethods] impl PyScanTask { + #[allow(clippy::too_many_arguments)] #[staticmethod] pub fn catalog_scan_task( file: String, @@ -270,10 +262,27 @@ partitioning_keys:\n", storage_config: PyStorageConfig, size_bytes: Option, pushdowns: Option, - ) -> PyResult { - // TODO(Sammy): This should parsed from the operator and passed in here - let empty_pspec = PartitionSpec { - keys: Table::empty(None)?, + partition_values: Option, + ) -> PyResult> { + if let Some(ref pvalues) = partition_values && let Some(Some(ref partition_filters)) = pushdowns.as_ref().map(|p| &p.0.partition_filters) { + let table = &pvalues.table; + let eval_pred = table.eval_expression_list(&[partition_filters.as_ref().clone()])?; + assert_eq!(eval_pred.num_columns(), 1); + let series = eval_pred.get_column_by_index(0)?; + assert_eq!(series.data_type(), &daft_core::DataType::Boolean); + let boolean = series.bool()?; + assert_eq!(boolean.len(), 1); + let value = boolean.get(0); + match value { + Some(false) => return Ok(None), + None | Some(true) => {} + } + } + + let pspec = PartitionSpec { + keys: partition_values + .map(|p| p.table) + .unwrap_or_else(|| Table::empty(None).unwrap()), }; let data_source = DataFileSource::CatalogDataFile { path: file, @@ -282,7 +291,7 @@ partitioning_keys:\n", metadata: TableMetadata { length: num_rows as usize, }, - partition_spec: empty_pspec, + partition_spec: pspec, statistics: None, }; @@ -293,7 +302,7 @@ partitioning_keys:\n", storage_config.into(), pushdowns.map(|p| p.0.as_ref().clone()).unwrap_or_default(), ); - Ok(PyScanTask(scan_task.into())) + Ok(Some(PyScanTask(scan_task.into()))) } pub fn __repr__(&self) -> PyResult { @@ -325,12 +334,12 @@ partitioning_keys:\n", fn new( field: PyField, source_field: Option, - transform: Option, + transform: Option, ) -> PyResult { let p_field = PartitionField::new( field.field, source_field.map(|f| f.into()), - transform.map(|e| e.expr), + transform.map(|e| e.0), )?; Ok(PyPartitionField(Arc::new(p_field))) } @@ -338,6 +347,52 @@ partitioning_keys:\n", pub fn __repr__(&self) -> PyResult { Ok(format!("{}", self.0)) } + + #[getter] + pub fn field(&self) -> PyResult { + Ok(self.0.field.clone().into()) + } + } + + #[pyclass(module = "daft.daft", name = "PartitionTransform", frozen)] + #[derive(Debug, Clone, Serialize, Deserialize)] + pub struct PyPartitionTransform(crate::PartitionTransform); + + #[pymethods] + impl PyPartitionTransform { + #[staticmethod] + pub fn identity() -> PyResult { + Ok(Self(crate::PartitionTransform::Identity)) + } + + #[staticmethod] + pub fn year() -> PyResult { + Ok(Self(crate::PartitionTransform::Year)) + } + + #[staticmethod] + pub fn month() -> PyResult { + Ok(Self(crate::PartitionTransform::Month)) + } + + #[staticmethod] + pub fn day() -> PyResult { + Ok(Self(crate::PartitionTransform::Day)) + } + + #[staticmethod] + pub fn hour() -> PyResult { + Ok(Self(crate::PartitionTransform::Hour)) + } + + #[staticmethod] + pub fn void() -> PyResult { + Ok(Self(crate::PartitionTransform::Void)) + } + + pub fn __repr__(&self) -> PyResult { + Ok(format!("{}", self.0)) + } } #[pyclass(module = "daft.daft", name = "Pushdowns", frozen)] @@ -354,9 +409,17 @@ partitioning_keys:\n", } #[getter] - pub fn filters(&self) -> Option> { - //TODO(Sammy): Figure out how to pass filters back to python - None + pub fn filters(&self) -> Option { + self.0.filters.as_ref().map(|e| PyExpr { + expr: e.as_ref().clone(), + }) + } + + #[getter] + pub fn partition_filters(&self) -> Option { + self.0.partition_filters.as_ref().map(|e| PyExpr { + expr: e.as_ref().clone(), + }) } #[getter] @@ -370,6 +433,7 @@ pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> { parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; + parent.add_class::()?; parent.add_class::()?; Ok(()) } diff --git a/tests/integration/iceberg/test_partition_pruning.py b/tests/integration/iceberg/test_partition_pruning.py new file mode 100644 index 0000000000..22a63de860 --- /dev/null +++ b/tests/integration/iceberg/test_partition_pruning.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +import pytest + +pyiceberg = pytest.importorskip("pyiceberg") +import itertools +from datetime import date, datetime + +import pandas as pd +import pytz + +import daft +from daft.expressions import Expression +from tests.conftest import assert_df_equals + + +@pytest.mark.integration() +def test_daft_iceberg_table_predicate_pushdown_days(local_iceberg_catalog): + + tab = local_iceberg_catalog.load_table("default.test_partitioned_by_days") + df = daft.read_iceberg(tab) + df = df.where(df["ts"] < date(2023, 3, 6)) + df.collect() + daft_pandas = df.to_pandas() + iceberg_pandas = tab.scan().to_arrow().to_pandas() + # need to use datetime here + iceberg_pandas = iceberg_pandas[iceberg_pandas["ts"] < datetime(2023, 3, 6, tzinfo=pytz.utc)] + + assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[]) + + +def is_null(obj): + if isinstance(obj, Expression): + return obj.is_null() + elif isinstance(obj, pd.Series): + return obj.isnull() + else: + raise NotImplementedError() + + +def udf_func(obj): + if isinstance(obj, Expression): + return obj.apply(lambda x: str(x)[:1] == "1", return_dtype=daft.DataType.bool()) + elif isinstance(obj, pd.Series): + return obj.apply(lambda x: str(x)[:1] == "1") + else: + raise NotImplementedError() + + +@pytest.mark.integration() +@pytest.mark.parametrize( + "predicate, table, limit", + itertools.product( + [ + lambda x: x < date(2023, 3, 6), + lambda x: x == date(2023, 3, 6), + lambda x: x > date(2023, 3, 6), + lambda x: x != date(2023, 3, 6), + lambda x: date(2023, 3, 6) > x, + lambda x: date(2023, 3, 6) == x, + lambda x: date(2023, 3, 6) < x, + lambda x: date(2023, 3, 6) != x, + is_null, + udf_func, + ], + [ + "test_partitioned_by_months", + "test_partitioned_by_years", + ], + [None, 1, 2, 1000], + ), +) +def test_daft_iceberg_table_predicate_pushdown_on_date_column(predicate, table, limit, local_iceberg_catalog): + tab = local_iceberg_catalog.load_table(f"default.{table}") + df = daft.read_iceberg(tab) + df = df.where(predicate(df["dt"])) + if limit: + df = df.limit(limit) + df.collect() + + daft_pandas = df.to_pandas() + iceberg_pandas = tab.scan().to_arrow().to_pandas() + iceberg_pandas = iceberg_pandas[predicate(iceberg_pandas["dt"])] + if limit: + iceberg_pandas = iceberg_pandas[:limit] + assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[]) + + +@pytest.mark.integration() +@pytest.mark.parametrize( + "predicate, table, limit", + itertools.product( + [ + lambda x: x < datetime(2023, 3, 6, tzinfo=pytz.utc), + lambda x: x == datetime(2023, 3, 6, tzinfo=pytz.utc), + lambda x: x > datetime(2023, 3, 6, tzinfo=pytz.utc), + lambda x: x != datetime(2023, 3, 6, tzinfo=pytz.utc), + lambda x: datetime(2023, 3, 6, tzinfo=pytz.utc) > x, + lambda x: datetime(2023, 3, 6, tzinfo=pytz.utc) == x, + lambda x: datetime(2023, 3, 6, tzinfo=pytz.utc) < x, + lambda x: datetime(2023, 3, 6, tzinfo=pytz.utc) != x, + is_null, + udf_func, + ], + [ + "test_partitioned_by_days", + "test_partitioned_by_hours", + "test_partitioned_by_identity", + ], + [None, 1, 2, 1000], + ), +) +def test_daft_iceberg_table_predicate_pushdown_on_timestamp_column(predicate, table, limit, local_iceberg_catalog): + tab = local_iceberg_catalog.load_table(f"default.{table}") + df = daft.read_iceberg(tab) + df = df.where(predicate(df["ts"])) + if limit: + df = df.limit(limit) + df.collect() + + daft_pandas = df.to_pandas() + iceberg_pandas = tab.scan().to_arrow().to_pandas() + iceberg_pandas = iceberg_pandas[predicate(iceberg_pandas["ts"])] + if limit: + iceberg_pandas = iceberg_pandas[:limit] + + assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[]) + + +@pytest.mark.integration() +@pytest.mark.parametrize( + "predicate, table, limit", + itertools.product( + [ + lambda x: x < "d", + lambda x: x == "d", + lambda x: x > "d", + lambda x: x != "d", + lambda x: x == "z", + lambda x: "d" > x, + lambda x: "d" == x, + lambda x: "d" < x, + lambda x: "d" != x, + lambda x: "z" == x, + is_null, + udf_func, + ], + [ + "test_partitioned_by_truncate", + ], + [None, 1, 2, 1000], + ), +) +def test_daft_iceberg_table_predicate_pushdown_on_letter(predicate, table, limit, local_iceberg_catalog): + tab = local_iceberg_catalog.load_table(f"default.{table}") + df = daft.read_iceberg(tab) + df = df.where(predicate(df["letter"])) + if limit: + df = df.limit(limit) + df.collect() + + daft_pandas = df.to_pandas() + iceberg_pandas = tab.scan().to_arrow().to_pandas() + iceberg_pandas = iceberg_pandas[predicate(iceberg_pandas["letter"])] + if limit: + iceberg_pandas = iceberg_pandas[:limit] + + assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[]) + + +@pytest.mark.integration() +@pytest.mark.parametrize( + "predicate, table, limit", + itertools.product( + [ + lambda x: x < 4, + lambda x: x == 4, + lambda x: x > 4, + lambda x: x != 4, + lambda x: x == 100, + lambda x: 4 > x, + lambda x: 4 == x, + lambda x: 4 < x, + lambda x: 4 != x, + lambda x: 100 == x, + is_null, + udf_func, + ], + [ + "test_partitioned_by_bucket", + ], + [None, 1, 2, 1000], + ), +) +def test_daft_iceberg_table_predicate_pushdown_on_number(predicate, table, limit, local_iceberg_catalog): + tab = local_iceberg_catalog.load_table(f"default.{table}") + df = daft.read_iceberg(tab) + df = df.where(predicate(df["number"])) + if limit: + df = df.limit(limit) + df.collect() + + daft_pandas = df.to_pandas() + iceberg_pandas = tab.scan().to_arrow().to_pandas() + iceberg_pandas = iceberg_pandas[predicate(iceberg_pandas["number"])] + if limit: + iceberg_pandas = iceberg_pandas[:limit] + + assert_df_equals(daft_pandas, iceberg_pandas, sort_key=[]) + + +@pytest.mark.integration() +def test_daft_iceberg_table_predicate_pushdown_empty_scan(local_iceberg_catalog): + tab = local_iceberg_catalog.load_table("default.test_partitioned_by_months") + df = daft.read_iceberg(tab) + df = df.where(df["dt"] > date(2030, 1, 1)) + df.collect() + values = df.to_arrow() + assert len(values) == 0 diff --git a/tests/integration/iceberg/test_table_load.py b/tests/integration/iceberg/test_table_load.py index 08ef7757a0..81c8d7c8cb 100644 --- a/tests/integration/iceberg/test_table_load.py +++ b/tests/integration/iceberg/test_table_load.py @@ -3,6 +3,7 @@ import pytest pyiceberg = pytest.importorskip("pyiceberg") + from pyiceberg.io.pyarrow import schema_to_pyarrow import daft @@ -34,7 +35,7 @@ def test_daft_iceberg_table_open(local_iceberg_tables): # "test_positional_mor_deletes", # Need Merge on Read # "test_positional_mor_double_deletes", # Need Merge on Read # "test_table_sanitized_character", # Bug in scan().to_arrow().to_arrow() - # "test_table_version", # we have bugs when loading no files + "test_table_version", # we have bugs when loading no files "test_uuid_and_fixed_unpartitioned", ] diff --git a/tests/series/test_partitioning.py b/tests/series/test_partitioning.py index 66fa7422dd..3d03ca9d1a 100644 --- a/tests/series/test_partitioning.py +++ b/tests/series/test_partitioning.py @@ -24,7 +24,9 @@ ) def test_partitioning_days(input, dtype, expected): s = Series.from_pylist(input).cast(dtype) - assert s.partitioning.days().to_pylist() == expected + d = s.partitioning.days() + assert d.datatype() == DataType.date() + assert d.cast(DataType.int32()).to_pylist() == expected @pytest.mark.parametrize( @@ -49,7 +51,9 @@ def test_partitioning_days(input, dtype, expected): ) def test_partitioning_months(input, dtype, expected): s = Series.from_pylist(input).cast(dtype) - assert s.partitioning.months().to_pylist() == expected + m = s.partitioning.months() + assert m.datatype() == DataType.int32() + assert m.to_pylist() == expected @pytest.mark.parametrize( @@ -70,7 +74,9 @@ def test_partitioning_months(input, dtype, expected): ) def test_partitioning_years(input, dtype, expected): s = Series.from_pylist(input).cast(dtype) - assert s.partitioning.years().to_pylist() == expected + y = s.partitioning.years() + assert y.datatype() == DataType.int32() + assert y.to_pylist() == expected @pytest.mark.parametrize( @@ -91,4 +97,6 @@ def test_partitioning_years(input, dtype, expected): ) def test_partitioning_hours(input, dtype, expected): s = Series.from_pylist(input).cast(dtype) - assert s.partitioning.hours().to_pylist() == expected + h = s.partitioning.hours() + assert h.datatype() == DataType.int32() + assert h.to_pylist() == expected