From 95a61d26bfb198c590570229a81f7e3bec0049a0 Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Sat, 14 Dec 2024 09:04:50 +0800 Subject: [PATCH 01/14] perf: lazily import pyiceberg and unity catalog if available (#3565) `memray` reports that a not-insignificant amount of memory is being taken by our catalog module at import-time image This PR makes those imports lazy Co-authored-by: Jay Chia --- daft/catalog/__init__.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/daft/catalog/__init__.py b/daft/catalog/__init__.py index 201c1f9b47..5e74fd4c08 100644 --- a/daft/catalog/__init__.py +++ b/daft/catalog/__init__.py @@ -45,21 +45,12 @@ from daft.dataframe import DataFrame -_PYICEBERG_AVAILABLE = False -try: - from pyiceberg.catalog import Catalog as PyIcebergCatalog - - _PYICEBERG_AVAILABLE = True -except ImportError: - pass +from typing import TYPE_CHECKING -_UNITY_AVAILABLE = False -try: +if TYPE_CHECKING: + from pyiceberg.catalog import Catalog as PyIcebergCatalog from daft.unity_catalog import UnityCatalog - _UNITY_AVAILABLE = True -except ImportError: - pass __all__ = [ "read_table", @@ -136,6 +127,22 @@ def register_python_catalog(catalog: PyIcebergCatalog | UnityCatalog, name: str >>> daft.catalog.register_python_catalog(catalog, "my_daft_catalog") """ + _PYICEBERG_AVAILABLE = False + try: + from pyiceberg.catalog import Catalog as PyIcebergCatalog + + _PYICEBERG_AVAILABLE = True + except ImportError: + pass + + _UNITY_AVAILABLE = False + try: + from daft.unity_catalog import UnityCatalog + + _UNITY_AVAILABLE = True + except ImportError: + pass + python_catalog: PyIcebergCatalog if _PYICEBERG_AVAILABLE and isinstance(catalog, PyIcebergCatalog): from daft.catalog.pyiceberg import PyIcebergCatalogAdaptor From 6c21917c9123e5eb61cc0a03f06b31e0a7b0f885 Mon Sep 17 00:00:00 2001 From: ccmao1130 Date: Mon, 16 Dec 2024 13:26:09 -0800 Subject: [PATCH 02/14] docs: update tpch benchmark link (#3542) --- docs/source/faq/benchmarks.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/faq/benchmarks.rst b/docs/source/faq/benchmarks.rst index 3c8df1a7d7..f66c4f2e5f 100644 --- a/docs/source/faq/benchmarks.rst +++ b/docs/source/faq/benchmarks.rst @@ -133,7 +133,7 @@ Benchmarking Code Our benchmarking scripts and code can be found in the `distributed-query-benchmarks `_ GitHub repository. * TPC-H queries for Daft were written by us. -* TPC-H queries for SparkSQL was adapted from `this repository `_. +* TPC-H queries for SparkSQL was adapted from `this repository `_. * TPC-H queries for Dask and Modin were adapted from these repositories for questions `Q1-7 `_ and `Q8-10 `_. Infrastructure From e148248dae8af90c8993d2ec6b2f471521c0a7f2 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 16 Dec 2024 22:02:49 -0800 Subject: [PATCH 03/14] feat(swordfish): Optimize grouped aggregations (#3534) Optimize swordfish grouped aggs for high cardinality groups ### Approach There's 3 strategies for grouped aggs: 1. Partition each input morsel into `N` partitions, then do a partial agg. (good for high cardinality). 2. Do a partial agg, then partition into `N` partitions. (good for low cardinality). Can be optimized with https://github.com/Eventual-Inc/Daft/pull/3556 3. Partition only, no partial agg. (only for map_groups, which has no partial agg). ### Notes on alternative approaches - Distributing partitions across workers (i.e. 
having each worker being responsible for accumulating only one partition) is much slower for low cardinality aggs (TPCH Q1 would have been 1.5x slower). This is because most of the work will end up being on only a few workers, reducing parallelism. - Simply partitioning the input and then only aggregating at the end works well with higher cardinality, but low cardinality takes a hit. (TPCH Q1 would have been 2.5x slower). - Probe Table approach was much slower, due to many calls to the multi-table dyn comparator. It was also much more complex to implement. ### Benchmarks [MrPowers Benchmarks](https://github.com/MrPowers/mrpowers-benchmarks) results (seconds, lower is better). | Query | this PR | Pyrunner | Current swordfish | |-------|---------|----------|-------------------| | q1 | 0.285720| 0.768858 | 0.356499 | | q2 | 4.780064| 6.122199 | 53.340565 | | q3 | 2.201079| 3.922857 | 16.935125 | | q4 | 0.313106| 0.545192 | 0.335541 | | q5 | 1.618228| 2.889354 | 10.665339 | | q7 | 2.087872| 3.856998 | 16.072660 | | q10 | 6.306756| 8.173738 | 53.800501 | --------- Co-authored-by: EC2 Default User Co-authored-by: Colin Ho Co-authored-by: EC2 Default User --- daft/context.py | 8 +- daft/daft/__init__.pyi | 6 + src/common/daft-config/src/lib.rs | 4 + src/common/daft-config/src/python.rs | 19 + .../src/intermediate_ops/aggregate.rs | 56 --- .../src/intermediate_ops/mod.rs | 1 - src/daft-local-execution/src/lib.rs | 2 + src/daft-local-execution/src/pipeline.rs | 89 +--- .../src/sinks/aggregate.rs | 62 ++- .../src/sinks/grouped_aggregate.rs | 438 ++++++++++++++++++ src/daft-local-execution/src/sinks/mod.rs | 1 + tests/dataframe/test_intersect.py | 4 +- tests/sql/test_aggs.py | 1 + 13 files changed, 537 insertions(+), 154 deletions(-) delete mode 100644 src/daft-local-execution/src/intermediate_ops/aggregate.rs create mode 100644 src/daft-local-execution/src/sinks/grouped_aggregate.rs diff --git a/daft/context.py b/daft/context.py index 0b071a431b..dbec115099 100644 --- a/daft/context.py +++ b/daft/context.py @@ -344,6 +344,8 @@ def set_execution_config( csv_target_filesize: int | None = None, csv_inflation_factor: float | None = None, shuffle_aggregation_default_partitions: int | None = None, + partial_aggregation_threshold: int | None = None, + high_cardinality_aggregation_threshold: float | None = None, read_sql_partition_size_bytes: int | None = None, enable_aqe: bool | None = None, enable_native_executor: bool | None = None, @@ -384,7 +386,9 @@ def set_execution_config( parquet_inflation_factor: Inflation Factor of parquet files (In-Memory-Size / File-Size) ratio. Defaults to 3.0 csv_target_filesize: Target File Size when writing out CSV Files. Defaults to 512MB csv_inflation_factor: Inflation Factor of CSV files (In-Memory-Size / File-Size) ratio. Defaults to 0.5 - shuffle_aggregation_default_partitions: Maximum number of partitions to create when performing aggregations. Defaults to 200, unless the number of input partitions is less than 200. + shuffle_aggregation_default_partitions: Maximum number of partitions to create when performing aggregations on the Ray Runner. Defaults to 200, unless the number of input partitions is less than 200. + partial_aggregation_threshold: Threshold for performing partial aggregations on the Native Runner. Defaults to 10000 rows. + high_cardinality_aggregation_threshold: Threshold selectivity for performing high cardinality aggregations on the Native Runner. Defaults to 0.8. 
read_sql_partition_size_bytes: Target size of partition when reading from SQL databases. Defaults to 512MB enable_aqe: Enables Adaptive Query Execution, Defaults to False enable_native_executor: Enables the native executor, Defaults to False @@ -413,6 +417,8 @@ def set_execution_config( csv_target_filesize=csv_target_filesize, csv_inflation_factor=csv_inflation_factor, shuffle_aggregation_default_partitions=shuffle_aggregation_default_partitions, + partial_aggregation_threshold=partial_aggregation_threshold, + high_cardinality_aggregation_threshold=high_cardinality_aggregation_threshold, read_sql_partition_size_bytes=read_sql_partition_size_bytes, enable_aqe=enable_aqe, enable_native_executor=enable_native_executor, diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index 47e88d9afd..da292d2df1 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -1730,6 +1730,8 @@ class PyDaftExecutionConfig: csv_target_filesize: int | None = None, csv_inflation_factor: float | None = None, shuffle_aggregation_default_partitions: int | None = None, + partial_aggregation_threshold: int | None = None, + high_cardinality_aggregation_threshold: float | None = None, read_sql_partition_size_bytes: int | None = None, enable_aqe: bool | None = None, enable_native_executor: bool | None = None, @@ -1765,6 +1767,10 @@ class PyDaftExecutionConfig: @property def shuffle_aggregation_default_partitions(self) -> int: ... @property + def partial_aggregation_threshold(self) -> int: ... + @property + def high_cardinality_aggregation_threshold(self) -> float: ... + @property def read_sql_partition_size_bytes(self) -> int: ... @property def enable_aqe(self) -> bool: ... diff --git a/src/common/daft-config/src/lib.rs b/src/common/daft-config/src/lib.rs index 3ffa4ea158..590fd5cf6c 100644 --- a/src/common/daft-config/src/lib.rs +++ b/src/common/daft-config/src/lib.rs @@ -44,6 +44,8 @@ pub struct DaftExecutionConfig { pub csv_target_filesize: usize, pub csv_inflation_factor: f64, pub shuffle_aggregation_default_partitions: usize, + pub partial_aggregation_threshold: usize, + pub high_cardinality_aggregation_threshold: f64, pub read_sql_partition_size_bytes: usize, pub enable_aqe: bool, pub enable_native_executor: bool, @@ -70,6 +72,8 @@ impl Default for DaftExecutionConfig { csv_target_filesize: 512 * 1024 * 1024, // 512MB csv_inflation_factor: 0.5, shuffle_aggregation_default_partitions: 200, + partial_aggregation_threshold: 10000, + high_cardinality_aggregation_threshold: 0.8, read_sql_partition_size_bytes: 512 * 1024 * 1024, // 512MB enable_aqe: false, enable_native_executor: false, diff --git a/src/common/daft-config/src/python.rs b/src/common/daft-config/src/python.rs index aceefd63d2..3228263b07 100644 --- a/src/common/daft-config/src/python.rs +++ b/src/common/daft-config/src/python.rs @@ -90,6 +90,8 @@ impl PyDaftExecutionConfig { csv_target_filesize: Option, csv_inflation_factor: Option, shuffle_aggregation_default_partitions: Option, + partial_aggregation_threshold: Option, + high_cardinality_aggregation_threshold: Option, read_sql_partition_size_bytes: Option, enable_aqe: Option, enable_native_executor: Option, @@ -146,6 +148,13 @@ impl PyDaftExecutionConfig { { config.shuffle_aggregation_default_partitions = shuffle_aggregation_default_partitions; } + if let Some(partial_aggregation_threshold) = partial_aggregation_threshold { + config.partial_aggregation_threshold = partial_aggregation_threshold; + } + if let Some(high_cardinality_aggregation_threshold) = 
high_cardinality_aggregation_threshold + { + config.high_cardinality_aggregation_threshold = high_cardinality_aggregation_threshold; + } if let Some(read_sql_partition_size_bytes) = read_sql_partition_size_bytes { config.read_sql_partition_size_bytes = read_sql_partition_size_bytes; } @@ -245,6 +254,16 @@ impl PyDaftExecutionConfig { Ok(self.config.shuffle_aggregation_default_partitions) } + #[getter] + fn get_partial_aggregation_threshold(&self) -> PyResult { + Ok(self.config.partial_aggregation_threshold) + } + + #[getter] + fn get_high_cardinality_aggregation_threshold(&self) -> PyResult { + Ok(self.config.high_cardinality_aggregation_threshold) + } + #[getter] fn get_read_sql_partition_size_bytes(&self) -> PyResult { Ok(self.config.read_sql_partition_size_bytes) diff --git a/src/daft-local-execution/src/intermediate_ops/aggregate.rs b/src/daft-local-execution/src/intermediate_ops/aggregate.rs deleted file mode 100644 index cb9344b160..0000000000 --- a/src/daft-local-execution/src/intermediate_ops/aggregate.rs +++ /dev/null @@ -1,56 +0,0 @@ -use std::sync::Arc; - -use common_runtime::RuntimeRef; -use daft_dsl::ExprRef; -use daft_micropartition::MicroPartition; -use tracing::instrument; - -use super::intermediate_op::{ - IntermediateOpExecuteResult, IntermediateOpState, IntermediateOperator, - IntermediateOperatorResult, -}; - -struct AggParams { - agg_exprs: Vec, - group_by: Vec, -} - -pub struct AggregateOperator { - params: Arc, -} - -impl AggregateOperator { - pub fn new(agg_exprs: Vec, group_by: Vec) -> Self { - Self { - params: Arc::new(AggParams { - agg_exprs, - group_by, - }), - } - } -} - -impl IntermediateOperator for AggregateOperator { - #[instrument(skip_all, name = "AggregateOperator::execute")] - fn execute( - &self, - input: Arc, - state: Box, - runtime: &RuntimeRef, - ) -> IntermediateOpExecuteResult { - let params = self.params.clone(); - runtime - .spawn(async move { - let out = input.agg(¶ms.agg_exprs, ¶ms.group_by)?; - Ok(( - state, - IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(out))), - )) - }) - .into() - } - - fn name(&self) -> &'static str { - "AggregateOperator" - } -} diff --git a/src/daft-local-execution/src/intermediate_ops/mod.rs b/src/daft-local-execution/src/intermediate_ops/mod.rs index f0af96b763..ca01bafaa4 100644 --- a/src/daft-local-execution/src/intermediate_ops/mod.rs +++ b/src/daft-local-execution/src/intermediate_ops/mod.rs @@ -1,5 +1,4 @@ pub mod actor_pool_project; -pub mod aggregate; pub mod anti_semi_hash_join_probe; pub mod cross_join; pub mod explode; diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs index ec6020eaef..b1b5d99bb0 100644 --- a/src/daft-local-execution/src/lib.rs +++ b/src/daft-local-execution/src/lib.rs @@ -1,4 +1,6 @@ #![feature(let_chains)] +#![feature(option_get_or_insert_default)] + mod buffer; mod channel; mod dispatcher; diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index 0c75de9be6..8881cb37de 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -10,7 +10,7 @@ use daft_core::{ prelude::{Schema, SchemaRef}, utils::supertype, }; -use daft_dsl::{col, join::get_common_join_keys, Expr}; +use daft_dsl::{col, join::get_common_join_keys}; use daft_local_plan::{ ActorPoolProject, Concat, CrossJoin, EmptyScan, Explode, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan, MonotonicallyIncreasingId, PhysicalWrite, Pivot, @@ -18,7 +18,6 @@ use daft_local_plan::{ }; 
use daft_logical_plan::{stats::StatsState, JoinType}; use daft_micropartition::{partitioning::PartitionSet, MicroPartition}; -use daft_physical_plan::{extract_agg_expr, populate_aggregation_stages}; use daft_scan::ScanTaskRef; use daft_writers::make_physical_writer_factory; use indexmap::IndexSet; @@ -27,7 +26,7 @@ use snafu::ResultExt; use crate::{ channel::Receiver, intermediate_ops::{ - actor_pool_project::ActorPoolProjectOperator, aggregate::AggregateOperator, + actor_pool_project::ActorPoolProjectOperator, anti_semi_hash_join_probe::AntiSemiProbeOperator, cross_join::CrossJoinOperator, explode::ExplodeOperator, filter::FilterOperator, inner_hash_join_probe::InnerHashJoinProbeOperator, intermediate_op::IntermediateNode, @@ -38,6 +37,7 @@ use crate::{ blocking_sink::BlockingSinkNode, concat::ConcatSink, cross_join_collect::CrossJoinCollectSink, + grouped_aggregate::GroupedAggregateSink, hash_join_build::HashJoinBuildSink, limit::LimitSink, monotonically_increasing_id::MonotonicallyIncreasingIdSink, @@ -170,42 +170,13 @@ pub fn physical_plan_to_pipeline( schema, .. }) => { - let aggregations = aggregations - .iter() - .map(extract_agg_expr) - .collect::>>() - .with_context(|_| PipelineCreationSnafu { - plan_name: physical_plan.name(), - })?; - - let (first_stage_aggs, second_stage_aggs, final_exprs) = - populate_aggregation_stages(&aggregations, schema, &[]); - let first_stage_agg_op = AggregateOperator::new( - first_stage_aggs - .values() - .cloned() - .map(|e| Arc::new(Expr::Agg(e))) - .collect(), - vec![], - ); let child_node = physical_plan_to_pipeline(input, psets, cfg)?; - let post_first_agg_node = - IntermediateNode::new(Arc::new(first_stage_agg_op), vec![child_node]).boxed(); - - let second_stage_agg_sink = AggregateSink::new( - second_stage_aggs - .values() - .cloned() - .map(|e| Arc::new(Expr::Agg(e))) - .collect(), - vec![], - ); - let second_stage_node = - BlockingSinkNode::new(Arc::new(second_stage_agg_sink), post_first_agg_node).boxed(); - - let final_stage_project = ProjectOperator::new(final_exprs); - - IntermediateNode::new(Arc::new(final_stage_project), vec![second_stage_node]).boxed() + let agg_sink = AggregateSink::new(aggregations, schema).with_context(|_| { + PipelineCreationSnafu { + plan_name: physical_plan.name(), + } + })?; + BlockingSinkNode::new(Arc::new(agg_sink), child_node).boxed() } LocalPhysicalPlan::HashAggregate(HashAggregate { input, @@ -214,48 +185,12 @@ pub fn physical_plan_to_pipeline( schema, .. 
}) => { - let aggregations = aggregations - .iter() - .map(extract_agg_expr) - .collect::>>() + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; + let agg_sink = GroupedAggregateSink::new(aggregations, group_by, schema, cfg) .with_context(|_| PipelineCreationSnafu { plan_name: physical_plan.name(), })?; - - let (first_stage_aggs, second_stage_aggs, final_exprs) = - populate_aggregation_stages(&aggregations, schema, group_by); - let child_node = physical_plan_to_pipeline(input, psets, cfg)?; - let (post_first_agg_node, group_by) = if !first_stage_aggs.is_empty() { - let agg_op = AggregateOperator::new( - first_stage_aggs - .values() - .cloned() - .map(|e| Arc::new(Expr::Agg(e))) - .collect(), - group_by.clone(), - ); - ( - IntermediateNode::new(Arc::new(agg_op), vec![child_node]).boxed(), - &group_by.iter().map(|e| col(e.name())).collect(), - ) - } else { - (child_node, group_by) - }; - - let second_stage_agg_sink = AggregateSink::new( - second_stage_aggs - .values() - .cloned() - .map(|e| Arc::new(Expr::Agg(e))) - .collect(), - group_by.clone(), - ); - let second_stage_node = - BlockingSinkNode::new(Arc::new(second_stage_agg_sink), post_first_agg_node).boxed(); - - let final_stage_project = ProjectOperator::new(final_exprs); - - IntermediateNode::new(Arc::new(final_stage_project), vec![second_stage_node]).boxed() + BlockingSinkNode::new(Arc::new(agg_sink), child_node).boxed() } LocalPhysicalPlan::Unpivot(Unpivot { input, diff --git a/src/daft-local-execution/src/sinks/aggregate.rs b/src/daft-local-execution/src/sinks/aggregate.rs index 4ac74c8fb5..148fb7e2e4 100644 --- a/src/daft-local-execution/src/sinks/aggregate.rs +++ b/src/daft-local-execution/src/sinks/aggregate.rs @@ -2,8 +2,10 @@ use std::sync::Arc; use common_error::DaftResult; use common_runtime::RuntimeRef; -use daft_dsl::ExprRef; +use daft_core::prelude::SchemaRef; +use daft_dsl::{Expr, ExprRef}; use daft_micropartition::MicroPartition; +use daft_physical_plan::extract_agg_expr; use tracing::instrument; use super::blocking_sink::{ @@ -44,8 +46,9 @@ impl BlockingSinkState for AggregateState { } struct AggParams { - agg_exprs: Vec, - group_by: Vec, + sink_agg_exprs: Vec, + finalize_agg_exprs: Vec, + final_projections: Vec, } pub struct AggregateSink { @@ -53,13 +56,31 @@ pub struct AggregateSink { } impl AggregateSink { - pub fn new(agg_exprs: Vec, group_by: Vec) -> Self { - Self { + pub fn new(aggregations: &[ExprRef], schema: &SchemaRef) -> DaftResult { + let aggregations = aggregations + .iter() + .map(extract_agg_expr) + .collect::>>()?; + let (sink_aggs, finalize_aggs, final_projections) = + daft_physical_plan::populate_aggregation_stages(&aggregations, schema, &[]); + let sink_agg_exprs = sink_aggs + .values() + .cloned() + .map(|e| Arc::new(Expr::Agg(e))) + .collect(); + let finalize_agg_exprs = finalize_aggs + .values() + .cloned() + .map(|e| Arc::new(Expr::Agg(e))) + .collect(); + + Ok(Self { agg_sink_params: Arc::new(AggParams { - agg_exprs, - group_by, + sink_agg_exprs, + finalize_agg_exprs, + final_projections, }), - } + }) } } @@ -69,14 +90,20 @@ impl BlockingSink for AggregateSink { &self, input: Arc, mut state: Box, - _runtime: &RuntimeRef, + runtime: &RuntimeRef, ) -> BlockingSinkSinkResult { - state - .as_any_mut() - .downcast_mut::() - .expect("AggregateSink should have AggregateState") - .push(input); - Ok(BlockingSinkStatus::NeedMoreInput(state)).into() + let params = self.agg_sink_params.clone(); + runtime + .spawn(async move { + let agg_state = state + .as_any_mut() + .downcast_mut::() + 
.expect("AggregateSink should have AggregateState"); + let agged = Arc::new(input.agg(¶ms.sink_agg_exprs, &[])?); + agg_state.push(agged); + Ok(BlockingSinkStatus::NeedMoreInput(state)) + }) + .into() } #[instrument(skip_all, name = "AggregateSink::finalize")] @@ -96,8 +123,9 @@ impl BlockingSink for AggregateSink { .finalize() }); let concated = MicroPartition::concat(all_parts)?; - let agged = Arc::new(concated.agg(¶ms.agg_exprs, ¶ms.group_by)?); - Ok(Some(agged)) + let agged = concated.agg(¶ms.finalize_agg_exprs, &[])?; + let projected = agged.eval_expression_list(¶ms.final_projections)?; + Ok(Some(Arc::new(projected))) }) .into() } diff --git a/src/daft-local-execution/src/sinks/grouped_aggregate.rs b/src/daft-local-execution/src/sinks/grouped_aggregate.rs new file mode 100644 index 0000000000..c240480eed --- /dev/null +++ b/src/daft-local-execution/src/sinks/grouped_aggregate.rs @@ -0,0 +1,438 @@ +use std::{ + collections::HashSet, + sync::{Arc, Mutex}, +}; + +use common_daft_config::DaftExecutionConfig; +use common_error::DaftResult; +use common_runtime::RuntimeRef; +use daft_core::prelude::SchemaRef; +use daft_dsl::{col, Expr, ExprRef}; +use daft_micropartition::MicroPartition; +use daft_physical_plan::extract_agg_expr; +use tracing::{info_span, instrument, Instrument}; + +use super::blocking_sink::{ + BlockingSink, BlockingSinkFinalizeResult, BlockingSinkSinkResult, BlockingSinkState, + BlockingSinkStatus, +}; +use crate::NUM_CPUS; + +#[derive(Clone)] +enum AggStrategy { + // TODO: This would probably benefit from doing sharded aggs. + AggThenPartition, + PartitionThenAgg(usize), + PartitionOnly, +} + +impl AggStrategy { + fn execute_strategy( + &self, + inner_states: &mut [Option], + input: Arc, + params: &GroupedAggregateParams, + ) -> DaftResult<()> { + match self { + Self::AggThenPartition => Self::execute_agg_then_partition(inner_states, input, params), + Self::PartitionThenAgg(threshold) => { + Self::execute_partition_then_agg(inner_states, input, params, *threshold) + } + Self::PartitionOnly => Self::execute_partition_only(inner_states, input, params), + } + } + + fn execute_agg_then_partition( + inner_states: &mut [Option], + input: Arc, + params: &GroupedAggregateParams, + ) -> DaftResult<()> { + let agged = input.agg( + params.partial_agg_exprs.as_slice(), + params.group_by.as_slice(), + )?; + let partitioned = + agged.partition_by_hash(params.final_group_by.as_slice(), inner_states.len())?; + for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) { + let state = state.get_or_insert_default(); + state.partially_aggregated.push(p); + } + Ok(()) + } + + fn execute_partition_then_agg( + inner_states: &mut [Option], + input: Arc, + params: &GroupedAggregateParams, + partial_agg_threshold: usize, + ) -> DaftResult<()> { + let partitioned = + input.partition_by_hash(params.group_by.as_slice(), inner_states.len())?; + for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) { + let state = state.get_or_insert_default(); + if state.unaggregated_size + p.len() >= partial_agg_threshold { + let unaggregated = std::mem::take(&mut state.unaggregated); + let aggregated = + MicroPartition::concat(unaggregated.iter().chain(std::iter::once(&p)))?.agg( + params.partial_agg_exprs.as_slice(), + params.group_by.as_slice(), + )?; + state.partially_aggregated.push(aggregated); + state.unaggregated_size = 0; + } else { + state.unaggregated_size += p.len(); + state.unaggregated.push(p); + } + } + Ok(()) + } + + fn execute_partition_only( + inner_states: &mut 
[Option], + input: Arc, + params: &GroupedAggregateParams, + ) -> DaftResult<()> { + let partitioned = + input.partition_by_hash(params.group_by.as_slice(), inner_states.len())?; + for (p, state) in partitioned.into_iter().zip(inner_states.iter_mut()) { + let state = state.get_or_insert_default(); + state.unaggregated_size += p.len(); + state.unaggregated.push(p); + } + Ok(()) + } +} + +#[derive(Default)] +struct SinglePartitionAggregateState { + partially_aggregated: Vec, + unaggregated: Vec, + unaggregated_size: usize, +} + +enum GroupedAggregateState { + Accumulating { + inner_states: Vec>, + strategy: Option, + partial_agg_threshold: usize, + high_cardinality_threshold_ratio: f64, + }, + Done, +} + +impl GroupedAggregateState { + fn new( + num_partitions: usize, + partial_agg_threshold: usize, + high_cardinality_threshold_ratio: f64, + ) -> Self { + let inner_states = (0..num_partitions).map(|_| None).collect::>(); + Self::Accumulating { + inner_states, + strategy: None, + partial_agg_threshold, + high_cardinality_threshold_ratio, + } + } + + fn push( + &mut self, + input: Arc, + params: &GroupedAggregateParams, + global_strategy_lock: &Arc>>, + ) -> DaftResult<()> { + let Self::Accumulating { + ref mut inner_states, + strategy, + partial_agg_threshold, + high_cardinality_threshold_ratio, + } = self + else { + panic!("GroupedAggregateSink should be in Accumulating state"); + }; + + // If we have determined a strategy, execute it. + if let Some(strategy) = strategy { + strategy.execute_strategy(inner_states, input, params)?; + } else { + // Otherwise, determine the strategy and execute + let decided_strategy = Self::determine_agg_strategy( + &input, + params, + *high_cardinality_threshold_ratio, + *partial_agg_threshold, + strategy, + global_strategy_lock, + )?; + decided_strategy.execute_strategy(inner_states, input, params)?; + } + Ok(()) + } + + fn determine_agg_strategy( + input: &Arc, + params: &GroupedAggregateParams, + high_cardinality_threshold_ratio: f64, + partial_agg_threshold: usize, + local_strategy_cache: &mut Option, + global_strategy_lock: &Arc>>, + ) -> DaftResult { + let mut global_strategy = global_strategy_lock.lock().unwrap(); + // If some other worker has determined a strategy, use that. + if let Some(global_strat) = global_strategy.as_ref() { + *local_strategy_cache = Some(global_strat.clone()); + return Ok(global_strat.clone()); + } + + // Else determine the strategy. + let groupby = input.eval_expression_list(params.group_by.as_slice())?; + + let groupkey_hashes = groupby + .get_tables()? + .iter() + .map(|t| t.hash_rows()) + .collect::>>()?; + let estimated_num_groups = groupkey_hashes + .iter() + .flatten() + .collect::>() + .len(); + + let decided_strategy = if estimated_num_groups as f64 / input.len() as f64 + >= high_cardinality_threshold_ratio + { + AggStrategy::PartitionThenAgg(partial_agg_threshold) + } else { + AggStrategy::AggThenPartition + }; + + *local_strategy_cache = Some(decided_strategy.clone()); + *global_strategy = Some(decided_strategy.clone()); + Ok(decided_strategy) + } + + fn finalize(&mut self) -> Vec> { + let res = if let Self::Accumulating { + ref mut inner_states, + .. 
+ } = self + { + std::mem::take(inner_states) + } else { + panic!("GroupedAggregateSink should be in Accumulating state"); + }; + *self = Self::Done; + res + } +} + +impl BlockingSinkState for GroupedAggregateState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} + +struct GroupedAggregateParams { + // The original aggregations and group by expressions + original_aggregations: Vec, + group_by: Vec, + // The expressions for to be used for partial aggregation + partial_agg_exprs: Vec, + // The expressions for the final aggregation + final_agg_exprs: Vec, + final_group_by: Vec, + final_projections: Vec, +} + +pub struct GroupedAggregateSink { + grouped_aggregate_params: Arc, + partial_agg_threshold: usize, + high_cardinality_threshold_ratio: f64, + global_strategy_lock: Arc>>, +} + +impl GroupedAggregateSink { + pub fn new( + aggregations: &[ExprRef], + group_by: &[ExprRef], + schema: &SchemaRef, + cfg: &DaftExecutionConfig, + ) -> DaftResult { + let aggregations = aggregations + .iter() + .map(extract_agg_expr) + .collect::>>()?; + let (partial_aggs, final_aggs, final_projections) = + daft_physical_plan::populate_aggregation_stages(&aggregations, schema, group_by); + let partial_agg_exprs = partial_aggs + .into_values() + .map(|e| Arc::new(Expr::Agg(e))) + .collect::>(); + let final_agg_exprs = final_aggs + .into_values() + .map(|e| Arc::new(Expr::Agg(e))) + .collect::>(); + let final_group_by = if !partial_agg_exprs.is_empty() { + group_by.iter().map(|e| col(e.name())).collect::>() + } else { + group_by.to_vec() + }; + let strategy = if partial_agg_exprs.is_empty() && !final_agg_exprs.is_empty() { + Some(AggStrategy::PartitionOnly) + } else { + None + }; + Ok(Self { + grouped_aggregate_params: Arc::new(GroupedAggregateParams { + original_aggregations: aggregations + .into_iter() + .map(|e| Expr::Agg(e).into()) + .collect(), + group_by: group_by.to_vec(), + partial_agg_exprs, + final_agg_exprs, + final_group_by, + final_projections, + }), + partial_agg_threshold: cfg.partial_aggregation_threshold, + high_cardinality_threshold_ratio: cfg.high_cardinality_aggregation_threshold, + global_strategy_lock: Arc::new(Mutex::new(strategy)), + }) + } + + fn num_partitions(&self) -> usize { + *NUM_CPUS + } +} + +impl BlockingSink for GroupedAggregateSink { + #[instrument(skip_all, name = "GroupedAggregateSink::sink")] + fn sink( + &self, + input: Arc, + mut state: Box, + runtime: &RuntimeRef, + ) -> BlockingSinkSinkResult { + let params = self.grouped_aggregate_params.clone(); + let strategy_lock = self.global_strategy_lock.clone(); + runtime + .spawn( + async move { + let agg_state = state + .as_any_mut() + .downcast_mut::() + .expect("GroupedAggregateSink should have GroupedAggregateState"); + + agg_state.push(input, ¶ms, &strategy_lock)?; + Ok(BlockingSinkStatus::NeedMoreInput(state)) + } + .instrument(info_span!("GroupedAggregateSink::sink")), + ) + .into() + } + + #[instrument(skip_all, name = "GroupedAggregateSink::finalize")] + fn finalize( + &self, + states: Vec>, + runtime: &RuntimeRef, + ) -> BlockingSinkFinalizeResult { + let params = self.grouped_aggregate_params.clone(); + let num_partitions = self.num_partitions(); + runtime + .spawn( + async move { + let mut state_iters = states + .into_iter() + .map(|mut state| { + state + .as_any_mut() + .downcast_mut::() + .expect("GroupedAggregateSink should have GroupedAggregateState") + .finalize() + .into_iter() + }) + .collect::>(); + + let mut per_partition_finalize_tasks = tokio::task::JoinSet::new(); + for _ in 
0..num_partitions { + let per_partition_state = state_iters + .iter_mut() + .map(|state| { + state.next().expect( + "GroupedAggregateState should have SinglePartitionAggregateState", + ) + }) + .collect::>(); + let params = params.clone(); + per_partition_finalize_tasks.spawn(async move { + let mut unaggregated = vec![]; + let mut partially_aggregated = vec![]; + for state in per_partition_state.into_iter().flatten() { + unaggregated.extend(state.unaggregated); + partially_aggregated.extend(state.partially_aggregated); + } + + // If we have no partially aggregated partitions, aggregate the unaggregated partitions using the original aggregations + if partially_aggregated.is_empty() { + let concated = MicroPartition::concat(&unaggregated)?; + let agged = concated + .agg(¶ms.original_aggregations, ¶ms.group_by)?; + Ok(agged) + } + // If we have no unaggregated partitions, finalize the partially aggregated partitions + else if unaggregated.is_empty() { + let concated = MicroPartition::concat(&partially_aggregated)?; + let agged = concated + .agg(¶ms.final_agg_exprs, ¶ms.final_group_by)?; + let projected = + agged.eval_expression_list(¶ms.final_projections)?; + Ok(projected) + } + // Otherwise, partially aggregate the unaggregated partitions, concatenate them with the partially aggregated partitions, and finalize the result. + else { + let leftover_partial_agg = + MicroPartition::concat(&unaggregated)? + .agg(¶ms.partial_agg_exprs, ¶ms.group_by)?; + let concated = MicroPartition::concat( + partially_aggregated + .iter() + .chain(std::iter::once(&leftover_partial_agg)), + )?; + let agged = concated + .agg(¶ms.final_agg_exprs, ¶ms.final_group_by)?; + let projected = + agged.eval_expression_list(¶ms.final_projections)?; + Ok(projected) + } + }); + } + let results = per_partition_finalize_tasks + .join_all() + .await + .into_iter() + .collect::>>()?; + let concated = MicroPartition::concat(&results)?; + Ok(Some(Arc::new(concated))) + } + .instrument(info_span!("GroupedAggregateSink::finalize")), + ) + .into() + } + + fn name(&self) -> &'static str { + "GroupedAggregateSink" + } + + fn max_concurrency(&self) -> usize { + *NUM_CPUS + } + + fn make_state(&self) -> DaftResult> { + Ok(Box::new(GroupedAggregateState::new( + self.num_partitions(), + self.partial_agg_threshold, + self.high_cardinality_threshold_ratio, + ))) + } +} diff --git a/src/daft-local-execution/src/sinks/mod.rs b/src/daft-local-execution/src/sinks/mod.rs index 497e05e50e..8d36320128 100644 --- a/src/daft-local-execution/src/sinks/mod.rs +++ b/src/daft-local-execution/src/sinks/mod.rs @@ -2,6 +2,7 @@ pub mod aggregate; pub mod blocking_sink; pub mod concat; pub mod cross_join_collect; +pub mod grouped_aggregate; pub mod hash_join_build; pub mod limit; pub mod monotonically_increasing_id; diff --git a/tests/dataframe/test_intersect.py b/tests/dataframe/test_intersect.py index 24330ec84e..4061fe7f71 100644 --- a/tests/dataframe/test_intersect.py +++ b/tests/dataframe/test_intersect.py @@ -14,7 +14,7 @@ def test_simple_intersect(make_df): def test_intersect_with_duplicate(make_df): df1 = make_df({"foo": [1, 2, 2, 3]}) df2 = make_df({"bar": [2, 3, 3]}) - result = df1.intersect(df2) + result = df1.intersect(df2).sort(by="foo") assert result.to_pydict() == {"foo": [2, 3]} @@ -37,7 +37,7 @@ def test_intersect_with_nulls(make_df): df2 = make_df({"bar": [2, 3, None]}) df2_without_null = make_df({"bar": [2, 3]}) - result = df1.intersect(df2) + result = df1.intersect(df2).sort(by="foo") assert result.to_pydict() == {"foo": [2, None]} result = 
df1_without_mull.intersect(df2) diff --git a/tests/sql/test_aggs.py b/tests/sql/test_aggs.py index 7fede4dc00..ee001afb2c 100644 --- a/tests/sql/test_aggs.py +++ b/tests/sql/test_aggs.py @@ -83,6 +83,7 @@ def test_having(agg, cond, expected): from df group by id having {cond} + order by id """, catalog, ).to_pydict() From 47f5897be295beeb887ca585a15173e35530767b Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Tue, 17 Dec 2024 12:00:44 -0800 Subject: [PATCH 04/14] feat(connect): support `DdlParse` (#3580) Co-authored-by: Cory Grinstead --- Cargo.lock | 1 + Cargo.toml | 1 + src/daft-connect/Cargo.toml | 1 + src/daft-connect/src/lib.rs | 24 ++++++- src/daft-schema/src/schema.rs | 7 +- src/daft-sql/src/lib.rs | 3 + src/daft-sql/src/planner.rs | 111 +++++++++++++++++++++++++++-- tests/connect/test_analyze_plan.py | 18 +++++ 8 files changed, 160 insertions(+), 6 deletions(-) create mode 100644 tests/connect/test_analyze_plan.py diff --git a/Cargo.lock b/Cargo.lock index fd01681f0b..3011a56b24 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1988,6 +1988,7 @@ dependencies = [ "daft-micropartition", "daft-scan", "daft-schema", + "daft-sql", "daft-table", "dashmap", "eyre", diff --git a/Cargo.toml b/Cargo.toml index b6f5284a60..d5a5cf218d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -204,6 +204,7 @@ daft-logical-plan = {path = "src/daft-logical-plan"} daft-micropartition = {path = "src/daft-micropartition"} daft-scan = {path = "src/daft-scan"} daft-schema = {path = "src/daft-schema"} +daft-sql = {path = "src/daft-sql"} daft-table = {path = "src/daft-table"} derivative = "2.2.0" derive_builder = "0.20.2" diff --git a/src/daft-connect/Cargo.toml b/src/daft-connect/Cargo.toml index b1d1f63052..22ddfe04bc 100644 --- a/src/daft-connect/Cargo.toml +++ b/src/daft-connect/Cargo.toml @@ -11,6 +11,7 @@ daft-logical-plan = {workspace = true} daft-micropartition = {workspace = true} daft-scan = {workspace = true} daft-schema = {workspace = true} +daft-sql = {workspace = true} daft-table = {workspace = true} dashmap = "6.1.0" eyre = "0.6.12" diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs index 439a74dc57..369bfe8e47 100644 --- a/src/daft-connect/src/lib.rs +++ b/src/daft-connect/src/lib.rs @@ -323,7 +323,29 @@ impl SparkConnectService for DaftSparkConnectService { Ok(Response::new(response)) } - _ => unimplemented_err!("Analyze plan operation is not yet implemented"), + Analyze::DdlParse(DdlParse { ddl_string }) => { + let daft_schema = match daft_sql::sql_schema(&ddl_string) { + Ok(daft_schema) => daft_schema, + Err(e) => return invalid_argument_err!("{e}"), + }; + + let daft_schema = daft_schema.to_struct(); + + let schema = translation::to_spark_datatype(&daft_schema); + + let schema = analyze_plan_response::Schema { + schema: Some(schema), + }; + + let response = AnalyzePlanResponse { + session_id, + server_side_session_id: String::new(), + result: Some(analyze_plan_response::Result::Schema(schema)), + }; + + Ok(Response::new(response)) + } + other => unimplemented_err!("Analyze plan operation is not yet implemented: {other:?}"), } } diff --git a/src/daft-schema/src/schema.rs b/src/daft-schema/src/schema.rs index af8eb77e96..a1fc464e96 100644 --- a/src/daft-schema/src/schema.rs +++ b/src/daft-schema/src/schema.rs @@ -13,7 +13,7 @@ use derive_more::Display; use indexmap::IndexMap; use serde::{Deserialize, Serialize}; -use crate::field::Field; +use crate::{field::Field, prelude::DataType}; pub type SchemaRef = Arc; @@ -48,6 +48,11 @@ impl Schema { Ok(Self { fields: map }) } + pub fn 
to_struct(&self) -> DataType { + let fields = self.fields.values().cloned().collect(); + DataType::Struct(fields) + } + pub fn exclude>(&self, names: &[S]) -> DaftResult { let mut fields = IndexMap::new(); let names = names.iter().map(|s| s.as_ref()).collect::>(); diff --git a/src/daft-sql/src/lib.rs b/src/daft-sql/src/lib.rs index bcb71494b6..75a819c204 100644 --- a/src/daft-sql/src/lib.rs +++ b/src/daft-sql/src/lib.rs @@ -4,7 +4,10 @@ pub mod catalog; pub mod error; pub mod functions; mod modules; + mod planner; +pub use planner::*; + #[cfg(feature = "python")] pub mod python; mod table_provider; diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index e7a1fa381c..683391b601 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -21,10 +21,10 @@ use daft_functions::{ use daft_logical_plan::{LogicalPlanBuilder, LogicalPlanRef}; use sqlparser::{ ast::{ - ArrayElemTypeDef, BinaryOperator, CastKind, DateTimeField, Distinct, ExactNumberInfo, - ExcludeSelectItem, GroupByExpr, Ident, Query, SelectItem, SetExpr, Statement, StructField, - Subscript, TableAlias, TableWithJoins, TimezoneInfo, UnaryOperator, Value, - WildcardAdditionalOptions, With, + ArrayElemTypeDef, BinaryOperator, CastKind, ColumnDef, DateTimeField, Distinct, + ExactNumberInfo, ExcludeSelectItem, GroupByExpr, Ident, Query, SelectItem, SetExpr, + Statement, StructField, Subscript, TableAlias, TableWithJoins, TimezoneInfo, UnaryOperator, + Value, WildcardAdditionalOptions, With, }, dialect::GenericDialect, parser::{Parser, ParserOptions}, @@ -1262,6 +1262,28 @@ impl<'a> SQLPlanner<'a> { } } + fn column_to_field(&self, column_def: &ColumnDef) -> SQLPlannerResult { + let ColumnDef { + name, + data_type, + collation, + options, + } = column_def; + + if let Some(collation) = collation { + unsupported_sql_err!("collation operation ({collation:?}) is not supported") + } + + if !options.is_empty() { + unsupported_sql_err!("unsupported options: {options:?}") + } + + let name = ident_to_str(name); + let data_type = self.sql_dtype_to_dtype(data_type)?; + + Ok(Field::new(name, data_type)) + } + fn value_to_lit(&self, value: &Value) -> SQLPlannerResult { Ok(match value { Value::SingleQuotedString(s) => LiteralValue::Utf8(s.clone()), @@ -2114,6 +2136,32 @@ fn check_wildcard_options( Ok(()) } + +pub fn sql_schema>(s: S) -> SQLPlannerResult { + let planner = SQLPlanner::default(); + + let tokens = Tokenizer::new(&GenericDialect, s.as_ref()).tokenize()?; + + let mut parser = Parser::new(&GenericDialect) + .with_options(ParserOptions { + trailing_commas: true, + ..Default::default() + }) + .with_tokens(tokens); + + let column_defs = parser.parse_comma_separated(Parser::parse_column_def)?; + + let fields: Result, _> = column_defs + .into_iter() + .map(|c| planner.column_to_field(&c)) + .collect(); + + let fields = fields?; + + let schema = Schema::new(fields)?; + Ok(Arc::new(schema)) +} + pub fn sql_expr>(s: S) -> SQLPlannerResult { let mut planner = SQLPlanner::default(); @@ -2138,6 +2186,12 @@ pub fn sql_expr>(s: S) -> SQLPlannerResult { // ---------------- // Helper functions // ---------------- + +/// # Examples +/// ``` +/// // Quoted identifier "MyCol" -> "MyCol" +/// // Unquoted identifier MyCol -> "MyCol" +/// ``` fn ident_to_str(ident: &Ident) -> String { if ident.quote_style == Some('"') { ident.value.to_string() @@ -2190,3 +2244,52 @@ fn unresolve_alias(expr: ExprRef, projection: &[ExprRef]) -> SQLPlannerResult Date: Tue, 17 Dec 2024 12:16:01 -0800 Subject: [PATCH 05/14] ci: Add ability to 
array-ify args and run multiple jobs (#3584) # Overview Previously, the `run-cluster` workflow only ran one ray-job-submission. This PR extends the ability to be able to run any arbitrary array of job submissions by enabling us to pass an array into the `entrypoint_args` input param. This then splits the command into its multiple pieces and submits them all. ## Example Usage ```sh gh workflow run run-cluster.yaml \ --ref $current_branch \ -f working_dir="." \ -f daft_wheel_url="https://github-actions-artifacts-bucket.s3.us-west-2.amazonaws.com/builds/54428e3738e96764af60cfdd8a0e4a41717ec9f9/getdaft-0.3.0.dev0-cp38-abi3-manylinux_2_31_x86_64.whl" \ -f entrypoint_script="benchmarking/tpcds/ray_entrypoint.py" \ -f entrypoint_args="[\"--tpcds-gen-folder='gendata' --question='1'\", \"--tpcds-gen-folder='gendata' --question='2'\"]" ``` The above invocation runs TPC-DS queries 1 and 2. --- .github/ci-scripts/job_runner.py | 114 ++++++++++++++++++++ .github/ci-scripts/templatize_ray_config.py | 2 + .github/workflows/run-cluster.yaml | 36 +++---- benchmarking/tpcds/ray_entrypoint.py | 64 ++++++++--- 4 files changed, 180 insertions(+), 36 deletions(-) create mode 100644 .github/ci-scripts/job_runner.py diff --git a/.github/ci-scripts/job_runner.py b/.github/ci-scripts/job_runner.py new file mode 100644 index 0000000000..12c949136f --- /dev/null +++ b/.github/ci-scripts/job_runner.py @@ -0,0 +1,114 @@ +# /// script +# requires-python = ">=3.12" +# dependencies = [] +# /// + +import argparse +import asyncio +import json +from dataclasses import dataclass +from datetime import datetime, timedelta +from pathlib import Path +from typing import Optional + +from ray.job_submission import JobStatus, JobSubmissionClient + + +def parse_env_var_str(env_var_str: str) -> dict: + iter = map( + lambda s: s.strip().split("="), + filter(lambda s: s, env_var_str.split(",")), + ) + return {k: v for k, v in iter} + + +async def print_logs(logs): + async for lines in logs: + print(lines, end="") + + +async def wait_on_job(logs, timeout_s): + await asyncio.wait_for(print_logs(logs), timeout=timeout_s) + + +@dataclass +class Result: + query: int + duration: timedelta + error_msg: Optional[str] + + +def submit_job( + working_dir: Path, + entrypoint_script: str, + entrypoint_args: str, + env_vars: str, + enable_ray_tracing: bool, +): + env_vars_dict = parse_env_var_str(env_vars) + if enable_ray_tracing: + env_vars_dict["DAFT_ENABLE_RAY_TRACING"] = "1" + + client = JobSubmissionClient(address="http://localhost:8265") + + if entrypoint_args.startswith("[") and entrypoint_args.endswith("]"): + # this is a json-encoded list of strings; parse accordingly + list_of_entrypoint_args: list[str] = json.loads(entrypoint_args) + else: + list_of_entrypoint_args: list[str] = [entrypoint_args] + + results = [] + + for index, args in enumerate(list_of_entrypoint_args): + entrypoint = f"DAFT_RUNNER=ray python {entrypoint_script} {args}" + print(f"{entrypoint=}") + start = datetime.now() + job_id = client.submit_job( + entrypoint=entrypoint, + runtime_env={ + "working_dir": working_dir, + "env_vars": env_vars_dict, + }, + ) + + asyncio.run(wait_on_job(client.tail_job_logs(job_id), timeout_s=60 * 30)) + + status = client.get_job_status(job_id) + assert status.is_terminal(), "Job should have terminated" + end = datetime.now() + duration = end - start + error_msg = None + if status != JobStatus.SUCCEEDED: + job_info = client.get_job_info(job_id) + error_msg = job_info.message + + result = Result(query=index, duration=duration, 
error_msg=error_msg) + results.append(result) + + print(f"{results=}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--working-dir", type=Path, required=True) + parser.add_argument("--entrypoint-script", type=str, required=True) + parser.add_argument("--entrypoint-args", type=str, required=True) + parser.add_argument("--env-vars", type=str, required=True) + parser.add_argument("--enable-ray-tracing", action="store_true") + + args = parser.parse_args() + + if not (args.working_dir.exists() and args.working_dir.is_dir()): + raise ValueError("The working-dir must exist and be a directory") + + entrypoint: Path = args.working_dir / args.entrypoint_script + if not (entrypoint.exists() and entrypoint.is_file()): + raise ValueError("The entrypoint script must exist and be a file") + + submit_job( + working_dir=args.working_dir, + entrypoint_script=args.entrypoint_script, + entrypoint_args=args.entrypoint_args, + env_vars=args.env_vars, + enable_ray_tracing=args.enable_ray_tracing, + ) diff --git a/.github/ci-scripts/templatize_ray_config.py b/.github/ci-scripts/templatize_ray_config.py index 30fe47f995..1608cf8dee 100644 --- a/.github/ci-scripts/templatize_ray_config.py +++ b/.github/ci-scripts/templatize_ray_config.py @@ -110,5 +110,7 @@ class Metadata(BaseModel, extra="allow"): if metadata: metadata = Metadata(**metadata) content = content.replace(OTHER_INSTALL_PLACEHOLDER, " ".join(metadata.dependencies)) + else: + content = content.replace(OTHER_INSTALL_PLACEHOLDER, "") print(content) diff --git a/.github/workflows/run-cluster.yaml b/.github/workflows/run-cluster.yaml index 644250d1f1..e0262be2cb 100644 --- a/.github/workflows/run-cluster.yaml +++ b/.github/workflows/run-cluster.yaml @@ -34,7 +34,7 @@ on: type: string required: true entrypoint_args: - description: Entry-point arguments + description: Entry-point arguments (either a simple string or a JSON list) type: string required: false default: "" @@ -79,24 +79,15 @@ jobs: uv run \ --python 3.12 \ .github/ci-scripts/templatize_ray_config.py \ - --cluster-name "ray-ci-run-${{ github.run_id }}_${{ github.run_attempt }}" \ - --daft-wheel-url '${{ inputs.daft_wheel_url }}' \ - --daft-version '${{ inputs.daft_version }}' \ - --python-version '${{ inputs.python_version }}' \ - --cluster-profile '${{ inputs.cluster_profile }}' \ - --working-dir '${{ inputs.working_dir }}' \ - --entrypoint-script '${{ inputs.entrypoint_script }}' + --cluster-name="ray-ci-run-${{ github.run_id }}_${{ github.run_attempt }}" \ + --daft-wheel-url='${{ inputs.daft_wheel_url }}' \ + --daft-version='${{ inputs.daft_version }}' \ + --python-version='${{ inputs.python_version }}' \ + --cluster-profile='${{ inputs.cluster_profile }}' \ + --working-dir='${{ inputs.working_dir }}' \ + --entrypoint-script='${{ inputs.entrypoint_script }}' ) >> .github/assets/ray.yaml cat .github/assets/ray.yaml - - name: Setup ray env vars - run: | - source .venv/bin/activate - ray_env_var=$(python .github/ci-scripts/format_env_vars.py \ - --env-vars '${{ inputs.env_vars }}' \ - --enable-ray-tracing \ - ) - echo $ray_env_var - echo "ray_env_var=$ray_env_var" >> $GITHUB_ENV - name: Download private ssh key run: | KEY=$(aws secretsmanager get-secret-value --secret-id ci-github-actions-ray-cluster-key-3 --query SecretString --output text) @@ -117,11 +108,12 @@ jobs: echo 'Invalid command submitted; command cannot be empty' exit 1 fi - ray job submit \ - --working-dir ${{ inputs.working_dir }} \ - --address http://localhost:8265 \ - --runtime-env-json 
"$ray_env_var" \ - -- python ${{ inputs.entrypoint_script }} ${{ inputs.entrypoint_args }} + python .github/ci-scripts/job_runner.py \ + --working-dir='${{ inputs.working_dir }}' \ + --entrypoint-script='${{ inputs.entrypoint_script }}' \ + --entrypoint-args='${{ inputs.entrypoint_args }}' \ + --env-vars='${{ inputs.env_vars }}' \ + --enable-ray-tracing - name: Download log files from ray cluster run: | source .venv/bin/activate diff --git a/benchmarking/tpcds/ray_entrypoint.py b/benchmarking/tpcds/ray_entrypoint.py index 10e52c4198..20656b37b9 100644 --- a/benchmarking/tpcds/ray_entrypoint.py +++ b/benchmarking/tpcds/ray_entrypoint.py @@ -1,17 +1,54 @@ import argparse from pathlib import Path -import helpers - import daft +from daft.sql.sql import SQLCatalog + +TABLE_NAMES = [ + "call_center", + "catalog_page", + "catalog_returns", + "catalog_sales", + "customer", + "customer_address", + "customer_demographics", + "date_dim", + "household_demographics", + "income_band", + "inventory", + "item", + "promotion", + "reason", + "ship_mode", + "store", + "store_returns", + "store_sales", + "time_dim", + "warehouse", + "web_page", + "web_returns", + "web_sales", + "web_site", +] + + +def register_catalog(scale_factor: int) -> SQLCatalog: + return SQLCatalog( + tables={ + table: daft.read_parquet( + f"s3://eventual-dev-benchmarking-fixtures/uncompressed/tpcds-dbgen/{scale_factor}/{table}.parquet" + ) + for table in TABLE_NAMES + } + ) def run( - parquet_folder: Path, question: int, dry_run: bool, + scale_factor: int, ): - catalog = helpers.generate_catalog(parquet_folder) + catalog = register_catalog(scale_factor) query_file = Path(__file__).parent / "queries" / f"{question:02}.sql" with open(query_file) as f: query = f.read() @@ -23,27 +60,26 @@ def run( if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument( - "--tpcds-gen-folder", - required=True, - type=Path, - help="Path to the TPC-DS data generation folder", - ) parser.add_argument( "--question", - required=True, type=int, help="The TPC-DS question index to run", + required=True, ) parser.add_argument( "--dry-run", action="store_true", help="Whether or not to run the query in dry-run mode; if true, only the plan will be printed out", ) + parser.add_argument( + "--scale-factor", + type=int, + help="Which scale factor to run this data at", + required=False, + default=2, + ) args = parser.parse_args() - tpcds_gen_folder: Path = args.tpcds_gen_folder - assert tpcds_gen_folder.exists() assert args.question in range(1, 100) - run(args.tpcds_gen_folder, args.question, args.dry_run) + run(question=args.question, dry_run=args.dry_run, scale_factor=args.scale_factor) From 5165e5e67d14be17b9ee4713dd6f6b01d29e963a Mon Sep 17 00:00:00 2001 From: Kev Wang Date: Tue, 17 Dec 2024 12:25:50 -0800 Subject: [PATCH 06/14] chore: move symbolic and boolean algebra code into new crate (#3570) Moving some of the code we have around symbolic/boolean algebra on expressions into its own crate, because I anticipate that we will be building more of this kind of thing, so it would be nicer to consolidate it as well as make it easier to reuse. It also allows us to better test these things in isolation of the context they are being used in. For example, we'll be building some optimization rules that more intelligently finds predicates for filter pushdown into joins, and that may use both `split_conjunction` as well as some expression simplification logic. 
Also took this opportunity to fix a typo (conjuct -> conjunct) and rename `conjunct` (which is an adjective) to `combine_conjunction`. Otherwise everything else is pretty much a straightforward move --- Cargo.lock | 15 + Cargo.toml | 2 + src/common/scan-info/Cargo.toml | 1 + src/common/scan-info/src/expr_rewriter.rs | 9 +- src/daft-algebra/Cargo.toml | 16 + src/daft-algebra/src/boolean.rs | 24 + src/daft-algebra/src/lib.rs | 4 + src/daft-algebra/src/simplify.rs | 465 ++++++++++++++++++ src/daft-dsl/src/optimization.rs | 30 +- src/daft-logical-plan/Cargo.toml | 1 + .../optimization/rules/push_down_filter.rs | 37 +- .../rules/simplify_expressions.rs | 463 +---------------- .../src/optimization/rules/unnest_subquery.rs | 27 +- src/daft-sql/Cargo.toml | 1 + src/daft-sql/src/planner.rs | 10 +- 15 files changed, 570 insertions(+), 535 deletions(-) create mode 100644 src/daft-algebra/Cargo.toml create mode 100644 src/daft-algebra/src/boolean.rs create mode 100644 src/daft-algebra/src/lib.rs create mode 100644 src/daft-algebra/src/simplify.rs diff --git a/Cargo.lock b/Cargo.lock index 3011a56b24..048324ca60 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1577,6 +1577,7 @@ dependencies = [ "common-display", "common-error", "common-file-formats", + "daft-algebra", "daft-dsl", "daft-schema", "pyo3", @@ -1905,6 +1906,7 @@ dependencies = [ "common-system-info", "common-tracing", "common-version", + "daft-algebra", "daft-catalog", "daft-catalog-python-catalog", "daft-compression", @@ -1941,6 +1943,17 @@ dependencies = [ "tikv-jemallocator", ] +[[package]] +name = "daft-algebra" +version = "0.3.0-dev0" +dependencies = [ + "common-error", + "common-treenode", + "daft-dsl", + "daft-schema", + "rstest", +] + [[package]] name = "daft-catalog" version = "0.3.0-dev0" @@ -2329,6 +2342,7 @@ dependencies = [ "common-resource-request", "common-scan-info", "common-treenode", + "daft-algebra", "daft-core", "daft-dsl", "daft-functions", @@ -2536,6 +2550,7 @@ dependencies = [ "common-error", "common-io-config", "common-runtime", + "daft-algebra", "daft-core", "daft-dsl", "daft-functions", diff --git a/Cargo.toml b/Cargo.toml index d5a5cf218d..ba00e45cd4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ common-scan-info = {path = "src/common/scan-info", default-features = false} common-system-info = {path = "src/common/system-info", default-features = false} common-tracing = {path = "src/common/tracing", default-features = false} common-version = {path = "src/common/version", default-features = false} +daft-algebra = {path = "src/daft-algebra", default-features = false} daft-catalog = {path = "src/daft-catalog", default-features = false} daft-catalog-python-catalog = {path = "src/daft-catalog/python-catalog", optional = true} daft-compression = {path = "src/daft-compression", default-features = false} @@ -149,6 +150,7 @@ members = [ "src/common/scan-info", "src/common/system-info", "src/common/treenode", + "src/daft-algebra", "src/daft-catalog", "src/daft-core", "src/daft-csv", diff --git a/src/common/scan-info/Cargo.toml b/src/common/scan-info/Cargo.toml index 04f9550997..0aecf55f6e 100644 --- a/src/common/scan-info/Cargo.toml +++ b/src/common/scan-info/Cargo.toml @@ -3,6 +3,7 @@ common-daft-config = {path = "../daft-config", default-features = false} common-display = {path = "../display", default-features = false} common-error = {path = "../error", default-features = false} common-file-formats = {path = "../file-formats", default-features = false} +daft-algebra = {path = "../../daft-algebra", default-features = 
false} daft-dsl = {path = "../../daft-dsl", default-features = false} daft-schema = {path = "../../daft-schema", default-features = false} pyo3 = {workspace = true, optional = true} diff --git a/src/common/scan-info/src/expr_rewriter.rs b/src/common/scan-info/src/expr_rewriter.rs index f678ad07c1..fedf212a18 100644 --- a/src/common/scan-info/src/expr_rewriter.rs +++ b/src/common/scan-info/src/expr_rewriter.rs @@ -1,13 +1,12 @@ use std::collections::HashMap; use common_error::DaftResult; +use daft_algebra::boolean::split_conjunction; use daft_dsl::{ col, common_treenode::{Transformed, TreeNode, TreeNodeRecursion}, functions::{partitioning, FunctionExpr}, - null_lit, - optimization::split_conjuction, - Expr, ExprRef, Operator, + null_lit, Expr, ExprRef, Operator, }; use crate::{PartitionField, PartitionTransform}; @@ -93,7 +92,7 @@ pub fn rewrite_predicate_for_partitioning( // Before rewriting predicate for partition filter pushdown, partition predicate clauses into groups that will need // to be applied at the data level (i.e. any clauses that aren't pure partition predicates with identity // transformations). - let data_split = split_conjuction(predicate); + let data_split = split_conjunction(predicate); // Predicates that reference both partition columns and data columns. let mut needs_filter_op_preds: Vec = vec![]; // Predicates that only reference data columns (no partition column references) or only reference partition columns @@ -332,7 +331,7 @@ pub fn rewrite_predicate_for_partitioning( let with_part_cols = with_part_cols.data; // Filter to predicate clauses that only involve partition columns. - let split = split_conjuction(&with_part_cols); + let split = split_conjunction(&with_part_cols); let mut part_preds: Vec = vec![]; for e in split { let mut all_part_keys = true; diff --git a/src/daft-algebra/Cargo.toml b/src/daft-algebra/Cargo.toml new file mode 100644 index 0000000000..89e942c700 --- /dev/null +++ b/src/daft-algebra/Cargo.toml @@ -0,0 +1,16 @@ +[dependencies] +common-error = {path = "../common/error", default-features = false} +common-treenode = {path = "../common/treenode", default-features = false} +daft-dsl = {path = "../daft-dsl", default-features = false} +daft-schema = {path = "../daft-schema", default-features = false} + +[dev-dependencies] +rstest = {workspace = true} + +[lints] +workspace = true + +[package] +edition = {workspace = true} +name = "daft-algebra" +version = {workspace = true} diff --git a/src/daft-algebra/src/boolean.rs b/src/daft-algebra/src/boolean.rs new file mode 100644 index 0000000000..38f659e00c --- /dev/null +++ b/src/daft-algebra/src/boolean.rs @@ -0,0 +1,24 @@ +use common_treenode::{TreeNode, TreeNodeRecursion}; +use daft_dsl::{Expr, ExprRef, Operator}; + +pub fn split_conjunction(expr: &ExprRef) -> Vec { + let mut splits = vec![]; + + expr.apply(|e| match e.as_ref() { + Expr::BinaryOp { + op: Operator::And, .. + } + | Expr::Alias(..) 
=> Ok(TreeNodeRecursion::Continue), + _ => { + splits.push(e.clone()); + Ok(TreeNodeRecursion::Jump) + } + }) + .unwrap(); + + splits +} + +pub fn combine_conjunction>(exprs: T) -> Option { + exprs.into_iter().reduce(|acc, e| acc.and(e)) +} diff --git a/src/daft-algebra/src/lib.rs b/src/daft-algebra/src/lib.rs new file mode 100644 index 0000000000..317ef5eea1 --- /dev/null +++ b/src/daft-algebra/src/lib.rs @@ -0,0 +1,4 @@ +pub mod boolean; +mod simplify; + +pub use simplify::simplify_expr; diff --git a/src/daft-algebra/src/simplify.rs b/src/daft-algebra/src/simplify.rs new file mode 100644 index 0000000000..698a48fa1b --- /dev/null +++ b/src/daft-algebra/src/simplify.rs @@ -0,0 +1,465 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use common_treenode::Transformed; +use daft_dsl::{lit, null_lit, Expr, ExprRef, LiteralValue, Operator}; +use daft_schema::{dtype::DataType, schema::SchemaRef}; + +pub fn simplify_expr(expr: Expr, schema: &SchemaRef) -> DaftResult> { + Ok(match expr { + // ---------------- + // Eq + // ---------------- + // true = A --> A + // false = A --> !A + Expr::BinaryOp { + op: Operator::Eq, + left, + right, + } + // A = true --> A + // A = false --> !A + | Expr::BinaryOp { + op: Operator::Eq, + left: right, + right: left, + } if is_bool_lit(&left) && is_bool_type(&right, schema) => { + Transformed::yes(match as_bool_lit(&left) { + Some(true) => right, + Some(false) => right.not(), + None => unreachable!(), + }) + } + + // null = A --> null + // A = null --> null + Expr::BinaryOp { + op: Operator::Eq, + left, + right, + } + | Expr::BinaryOp { + op: Operator::Eq, + left: right, + right: left, + } if is_null(&left) && is_bool_type(&right, schema) => Transformed::yes(null_lit()), + + // ---------------- + // Neq + // ---------------- + // true != A --> !A + // false != A --> A + Expr::BinaryOp { + op: Operator::NotEq, + left, + right, + } + // A != true --> !A + // A != false --> A + | Expr::BinaryOp { + op: Operator::NotEq, + left: right, + right: left, + } if is_bool_lit(&left) && is_bool_type(&right, schema) => { + Transformed::yes(match as_bool_lit(&left) { + Some(true) => right.not(), + Some(false) => right, + None => unreachable!(), + }) + } + + // null != A --> null + // A != null --> null + Expr::BinaryOp { + op: Operator::NotEq, + left, + right, + } + | Expr::BinaryOp { + op: Operator::NotEq, + left: right, + right: left, + } if is_null(&left) && is_bool_type(&right, schema) => Transformed::yes(null_lit()), + + // ---------------- + // OR + // ---------------- + + // true OR A --> true + Expr::BinaryOp { + op: Operator::Or, + left, + right: _, + } if is_true(&left) => Transformed::yes(left), + // false OR A --> A + Expr::BinaryOp { + op: Operator::Or, + left, + right, + } if is_false(&left) => Transformed::yes(right), + // A OR true --> true + Expr::BinaryOp { + op: Operator::Or, + left: _, + right, + } if is_true(&right) => Transformed::yes(right), + // A OR false --> A + Expr::BinaryOp { + left, + op: Operator::Or, + right, + } if is_false(&right) => Transformed::yes(left), + + // ---------------- + // AND (TODO) + // ---------------- + + // ---------------- + // Multiplication + // ---------------- + + // A * 1 --> A + // 1 * A --> A + Expr::BinaryOp { + op: Operator::Multiply, + left, + right, + }| Expr::BinaryOp { + op: Operator::Multiply, + left: right, + right: left, + } if is_one(&right) => Transformed::yes(left), + + // A * null --> null + Expr::BinaryOp { + op: Operator::Multiply, + left: _, + right, + } if is_null(&right) => 
Transformed::yes(right), + // null * A --> null + Expr::BinaryOp { + op: Operator::Multiply, + left, + right: _, + } if is_null(&left) => Transformed::yes(left), + + // TODO: Can't do this one because we don't have a way to determine if an expr potentially contains nulls (nullable) + // A * 0 --> 0 (if A is not null and not floating/decimal) + // 0 * A --> 0 (if A is not null and not floating/decimal) + + // ---------------- + // Division + // ---------------- + // A / 1 --> A + Expr::BinaryOp { + op: Operator::TrueDivide, + left, + right, + } if is_one(&right) => Transformed::yes(left), + // null / A --> null + Expr::BinaryOp { + op: Operator::TrueDivide, + left, + right: _, + } if is_null(&left) => Transformed::yes(left), + // A / null --> null + Expr::BinaryOp { + op: Operator::TrueDivide, + left: _, + right, + } if is_null(&right) => Transformed::yes(right), + + // ---------------- + // Addition + // ---------------- + // A + 0 --> A + Expr::BinaryOp { + op: Operator::Plus, + left, + right, + } if is_zero(&right) => Transformed::yes(left), + + // 0 + A --> A + Expr::BinaryOp { + op: Operator::Plus, + left, + right, + } if is_zero(&left) => Transformed::yes(right), + + // ---------------- + // Subtraction + // ---------------- + + // A - 0 --> A + Expr::BinaryOp { + op: Operator::Minus, + left, + right, + } if is_zero(&right) => Transformed::yes(left), + + // A - null --> null + Expr::BinaryOp { + op: Operator::Minus, + left: _, + right, + } if is_null(&right) => Transformed::yes(right), + // null - A --> null + Expr::BinaryOp { + op: Operator::Minus, + left, + right: _, + } if is_null(&left) => Transformed::yes(left), + + // ---------------- + // Modulus + // ---------------- + + // A % null --> null + Expr::BinaryOp { + op: Operator::Modulus, + left: _, + right, + } if is_null(&right) => Transformed::yes(right), + + // null % A --> null + Expr::BinaryOp { + op: Operator::Modulus, + left, + right: _, + } if is_null(&left) => Transformed::yes(left), + + // A BETWEEN low AND high --> A >= low AND A <= high + Expr::Between(expr, low, high) => { + Transformed::yes(expr.clone().lt_eq(high).and(expr.gt_eq(low))) + } + Expr::Not(expr) => match Arc::unwrap_or_clone(expr) { + // NOT (BETWEEN A AND B) --> A < low OR A > high + Expr::Between(expr, low, high) => { + Transformed::yes(expr.clone().lt(low).or(expr.gt(high))) + } + // expr NOT IN () --> true + Expr::IsIn(_, list) if list.is_empty() => Transformed::yes(lit(true)), + + expr => { + let expr = simplify_expr(expr, schema)?; + if expr.transformed { + Transformed::yes(expr.data.not()) + } else { + Transformed::no(expr.data.not()) + } + } + }, + // expr IN () --> false + Expr::IsIn(_, list) if list.is_empty() => Transformed::yes(lit(false)), + + other => Transformed::no(Arc::new(other)), + }) +} + +fn is_zero(s: &Expr) -> bool { + match s { + Expr::Literal(LiteralValue::Int32(0)) + | Expr::Literal(LiteralValue::Int64(0)) + | Expr::Literal(LiteralValue::UInt32(0)) + | Expr::Literal(LiteralValue::UInt64(0)) + | Expr::Literal(LiteralValue::Float64(0.)) => true, + Expr::Literal(LiteralValue::Decimal(v, _p, _s)) if *v == 0 => true, + _ => false, + } +} + +fn is_one(s: &Expr) -> bool { + match s { + Expr::Literal(LiteralValue::Int32(1)) + | Expr::Literal(LiteralValue::Int64(1)) + | Expr::Literal(LiteralValue::UInt32(1)) + | Expr::Literal(LiteralValue::UInt64(1)) + | Expr::Literal(LiteralValue::Float64(1.)) => true, + + Expr::Literal(LiteralValue::Decimal(v, _p, s)) => { + *s >= 0 && POWS_OF_TEN.get(*s as usize).is_some_and(|pow| v == pow) + } + _ 
=> false, + } +} + +fn is_true(expr: &Expr) -> bool { + match expr { + Expr::Literal(LiteralValue::Boolean(v)) => *v, + _ => false, + } +} +fn is_false(expr: &Expr) -> bool { + match expr { + Expr::Literal(LiteralValue::Boolean(v)) => !*v, + _ => false, + } +} + +/// returns true if expr is a +/// `Expr::Literal(LiteralValue::Boolean(v))` , false otherwise +fn is_bool_lit(expr: &Expr) -> bool { + matches!(expr, Expr::Literal(LiteralValue::Boolean(_))) +} + +fn is_bool_type(expr: &Expr, schema: &SchemaRef) -> bool { + matches!(expr.get_type(schema), Ok(DataType::Boolean)) +} + +fn as_bool_lit(expr: &Expr) -> Option { + expr.as_literal().and_then(|l| l.as_bool()) +} + +fn is_null(expr: &Expr) -> bool { + matches!(expr, Expr::Literal(LiteralValue::Null)) +} + +static POWS_OF_TEN: [i128; 38] = [ + 1, + 10, + 100, + 1000, + 10000, + 100000, + 1000000, + 10000000, + 100000000, + 1000000000, + 10000000000, + 100000000000, + 1000000000000, + 10000000000000, + 100000000000000, + 1000000000000000, + 10000000000000000, + 100000000000000000, + 1000000000000000000, + 10000000000000000000, + 100000000000000000000, + 1000000000000000000000, + 10000000000000000000000, + 100000000000000000000000, + 1000000000000000000000000, + 10000000000000000000000000, + 100000000000000000000000000, + 1000000000000000000000000000, + 10000000000000000000000000000, + 100000000000000000000000000000, + 1000000000000000000000000000000, + 10000000000000000000000000000000, + 100000000000000000000000000000000, + 1000000000000000000000000000000000, + 10000000000000000000000000000000000, + 100000000000000000000000000000000000, + 1000000000000000000000000000000000000, + 10000000000000000000000000000000000000, +]; + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use common_error::DaftResult; + use daft_dsl::{col, lit, null_lit, ExprRef}; + use daft_schema::{ + dtype::DataType, + field::Field, + schema::{Schema, SchemaRef}, + }; + use rstest::{fixture, rstest}; + + use crate::simplify_expr; + + #[fixture] + fn schema() -> SchemaRef { + Arc::new( + Schema::new(vec![ + Field::new("bool", DataType::Boolean), + Field::new("int", DataType::Int32), + ]) + .unwrap(), + ) + } + + #[rstest] + // true = A --> A + #[case(col("bool").eq(lit(true)), col("bool"))] + // false = A --> !A + #[case(col("bool").eq(lit(false)), col("bool").not())] + // A == true ---> A + #[case(col("bool").eq(lit(true)), col("bool"))] + // null = A --> null + #[case(null_lit().eq(col("bool")), null_lit())] + // A == false ---> !A + #[case(col("bool").eq(lit(false)), col("bool").not())] + // true != A --> !A + #[case(lit(true).not_eq(col("bool")), col("bool").not())] + // false != A --> A + #[case(lit(false).not_eq(col("bool")), col("bool"))] + // true OR A --> true + #[case(lit(true).or(col("bool")), lit(true))] + // false OR A --> A + #[case(lit(false).or(col("bool")), col("bool"))] + // A OR true --> true + #[case(col("bool").or(lit(true)), lit(true))] + // A OR false --> A + #[case(col("bool").or(lit(false)), col("bool"))] + fn test_simplify_bool_exprs( + #[case] input: ExprRef, + #[case] expected: ExprRef, + schema: SchemaRef, + ) -> DaftResult<()> { + let optimized = simplify_expr(Arc::unwrap_or_clone(input), &schema)?; + + assert!(optimized.transformed); + assert_eq!(optimized.data, expected); + Ok(()) + } + + #[rstest] + // A * 1 --> A + #[case(col("int").mul(lit(1)), col("int"))] + // 1 * A --> A + #[case(lit(1).mul(col("int")), col("int"))] + // A / 1 --> A + #[case(col("int").div(lit(1)), col("int"))] + // A + 0 --> A + #[case(col("int").add(lit(0)), 
col("int"))] + // A - 0 --> A + #[case(col("int").sub(lit(0)), col("int"))] + fn test_math_exprs( + #[case] input: ExprRef, + #[case] expected: ExprRef, + schema: SchemaRef, + ) -> DaftResult<()> { + let optimized = simplify_expr(Arc::unwrap_or_clone(input), &schema)?; + + assert!(optimized.transformed); + assert_eq!(optimized.data, expected); + Ok(()) + } + + #[rstest] + fn test_not_between(schema: SchemaRef) -> DaftResult<()> { + let input = col("int").between(lit(1), lit(10)).not(); + let expected = col("int").lt(lit(1)).or(col("int").gt(lit(10))); + + let optimized = simplify_expr(Arc::unwrap_or_clone(input), &schema)?; + + assert!(optimized.transformed); + assert_eq!(optimized.data, expected); + Ok(()) + } + + #[rstest] + fn test_between(schema: SchemaRef) -> DaftResult<()> { + let input = col("int").between(lit(1), lit(10)); + let expected = col("int").lt_eq(lit(10)).and(col("int").gt_eq(lit(1))); + + let optimized = simplify_expr(Arc::unwrap_or_clone(input), &schema)?; + + assert!(optimized.transformed); + assert_eq!(optimized.data, expected); + Ok(()) + } +} diff --git a/src/daft-dsl/src/optimization.rs b/src/daft-dsl/src/optimization.rs index 06cff96959..38a9c8588e 100644 --- a/src/daft-dsl/src/optimization.rs +++ b/src/daft-dsl/src/optimization.rs @@ -2,8 +2,7 @@ use std::collections::HashMap; use common_treenode::{Transformed, TreeNode, TreeNodeRecursion}; -use super::expr::Expr; -use crate::{ExprRef, Operator}; +use crate::{Expr, ExprRef}; pub fn get_required_columns(e: &ExprRef) -> Vec { let mut cols = vec![]; @@ -57,30 +56,3 @@ pub fn replace_columns_with_expressions( .expect("Error occurred when rewriting column expressions"); transformed.data } - -pub fn split_conjuction(expr: &ExprRef) -> Vec<&ExprRef> { - let mut splits = vec![]; - _split_conjuction(expr, &mut splits); - splits -} - -fn _split_conjuction<'a>(expr: &'a ExprRef, out_exprs: &mut Vec<&'a ExprRef>) { - match expr.as_ref() { - Expr::BinaryOp { - op: Operator::And, - left, - right, - } => { - _split_conjuction(left, out_exprs); - _split_conjuction(right, out_exprs); - } - Expr::Alias(inner_expr, ..) 
=> _split_conjuction(inner_expr, out_exprs), - _ => { - out_exprs.push(expr); - } - } -} - -pub fn conjuct>(exprs: T) -> Option { - exprs.into_iter().reduce(|acc, expr| acc.and(expr)) -} diff --git a/src/daft-logical-plan/Cargo.toml b/src/daft-logical-plan/Cargo.toml index 1b4dab023f..707d881977 100644 --- a/src/daft-logical-plan/Cargo.toml +++ b/src/daft-logical-plan/Cargo.toml @@ -8,6 +8,7 @@ common-py-serde = {path = "../common/py-serde", default-features = false} common-resource-request = {path = "../common/resource-request", default-features = false} common-scan-info = {path = "../common/scan-info", default-features = false} common-treenode = {path = "../common/treenode", default-features = false} +daft-algebra = {path = "../daft-algebra", default-features = false} daft-core = {path = "../daft-core", default-features = false} daft-dsl = {path = "../daft-dsl", default-features = false} daft-functions = {path = "../daft-functions", default-features = false} diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs index 2b77bd8e9a..6e5be33c40 100644 --- a/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs +++ b/src/daft-logical-plan/src/optimization/rules/push_down_filter.rs @@ -6,12 +6,11 @@ use std::{ use common_error::DaftResult; use common_scan_info::{rewrite_predicate_for_partitioning, PredicateGroups}; use common_treenode::{DynTreeNode, Transformed, TreeNode}; +use daft_algebra::boolean::{combine_conjunction, split_conjunction}; use daft_core::join::JoinType; use daft_dsl::{ col, - optimization::{ - conjuct, get_required_columns, replace_columns_with_expressions, split_conjuction, - }, + optimization::{get_required_columns, replace_columns_with_expressions}, ExprRef, }; @@ -56,20 +55,20 @@ impl PushDownFilter { // Filter-Filter --> Filter // Split predicate expression on conjunctions (ANDs). - let parent_predicates = split_conjuction(&filter.predicate); - let predicate_set: HashSet<&ExprRef> = parent_predicates.iter().copied().collect(); + let parent_predicates = split_conjunction(&filter.predicate); + let predicate_set: HashSet<&ExprRef> = parent_predicates.iter().collect(); // Add child predicate expressions to parent predicate expressions, eliminating duplicates. let new_predicates: Vec = parent_predicates .iter() .chain( - split_conjuction(&child_filter.predicate) + split_conjunction(&child_filter.predicate) .iter() - .filter(|e| !predicate_set.contains(**e)), + .filter(|e| !predicate_set.contains(*e)), ) .map(|e| (*e).clone()) .collect::>(); // Reconjunct predicate expressions. - let new_predicate = conjuct(new_predicates).unwrap(); + let new_predicate = combine_conjunction(new_predicates).unwrap(); let new_filter: Arc = LogicalPlan::from(Filter::try_new(child_filter.input.clone(), new_predicate)?) .into(); @@ -133,8 +132,8 @@ impl PushDownFilter { return Ok(Transformed::no(plan)); } - let data_filter = conjuct(data_only_filter); - let partition_filter = conjuct(partition_only_filter); + let data_filter = combine_conjunction(data_only_filter); + let partition_filter = combine_conjunction(partition_only_filter); assert!(data_filter.is_some() || partition_filter.is_some()); let new_pushdowns = if let Some(data_filter) = data_filter { @@ -158,7 +157,7 @@ impl PushDownFilter { // TODO(Clark): Support pushing predicates referencing both partition and data columns into the scan. 
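
For reference, the semantics of the moved helpers are easy to model outside of Rust. Below is a minimal Python sketch (toy tuple-based expression nodes, not the actual `daft-algebra`/`daft_dsl` API) of what `split_conjunction` and `combine_conjunction` do:

```python
# Toy expression nodes: ("and", left, right), ("alias", inner, name), or a leaf string.
def split_conjunction(expr):
    """Recursively collect the conjuncts of an AND chain, looking through aliases."""
    if isinstance(expr, tuple) and expr[0] == "and":
        return split_conjunction(expr[1]) + split_conjunction(expr[2])
    if isinstance(expr, tuple) and expr[0] == "alias":
        return split_conjunction(expr[1])
    return [expr]


def combine_conjunction(exprs):
    """Fold a list of conjuncts back into a single AND chain; None if the list is empty."""
    out = None
    for e in exprs:
        out = e if out is None else ("and", out, e)
    return out


pred = ("and", ("and", "a > 1", "b IS NOT NULL"), "c = 2")
assert split_conjunction(pred) == ["a > 1", "b IS NOT NULL", "c = 2"]
assert combine_conjunction(["a > 1", "c = 2"]) == ("and", "a > 1", "c = 2")
assert combine_conjunction([]) is None
```
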
let filter_op: LogicalPlan = Filter::try_new( new_source.into(), - conjuct(needing_filter_op).unwrap(), + combine_conjunction(needing_filter_op).unwrap(), )? .into(); return Ok(Transformed::yes(filter_op.into())); @@ -176,7 +175,7 @@ impl PushDownFilter { // don't involve compute. // // Filter-Projection --> {Filter-}Projection-Filter - let predicates = split_conjuction(&filter.predicate); + let predicates = split_conjunction(&filter.predicate); let projection_input_mapping = child_project .projection .iter() @@ -191,7 +190,7 @@ impl PushDownFilter { let mut can_push: Vec = vec![]; let mut can_not_push: Vec = vec![]; for predicate in predicates { - let predicate_cols = get_required_columns(predicate); + let predicate_cols = get_required_columns(&predicate); if predicate_cols .iter() .all(|col| projection_input_mapping.contains_key(col)) @@ -212,7 +211,7 @@ impl PushDownFilter { return Ok(Transformed::no(plan)); } // Create new Filter with predicates that can be pushed past Projection. - let predicates_to_push = conjuct(can_push).unwrap(); + let predicates_to_push = combine_conjunction(can_push).unwrap(); let push_down_filter: LogicalPlan = Filter::try_new(child_project.input.clone(), predicates_to_push)?.into(); // Create new Projection. @@ -226,7 +225,7 @@ impl PushDownFilter { } else { // Otherwise, add a Filter after Projection that filters with predicate expressions // that couldn't be pushed past the Projection, returning a Filter-Projection-Filter subplan. - let post_projection_predicate = conjuct(can_not_push).unwrap(); + let post_projection_predicate = combine_conjunction(can_not_push).unwrap(); let post_projection_filter: LogicalPlan = Filter::try_new(new_projection.into(), post_projection_predicate)?.into(); post_projection_filter.into() @@ -274,7 +273,7 @@ impl PushDownFilter { let left_cols = HashSet::<_>::from_iter(child_join.left.schema().names()); let right_cols = HashSet::<_>::from_iter(child_join.right.schema().names()); - for predicate in split_conjuction(&filter.predicate).into_iter().cloned() { + for predicate in split_conjunction(&filter.predicate) { let pred_cols = HashSet::<_>::from_iter(get_required_columns(&predicate)); match ( @@ -307,11 +306,11 @@ impl PushDownFilter { } } - let left_pushdowns = conjuct(left_pushdowns); - let right_pushdowns = conjuct(right_pushdowns); + let left_pushdowns = combine_conjunction(left_pushdowns); + let right_pushdowns = combine_conjunction(right_pushdowns); if left_pushdowns.is_some() || right_pushdowns.is_some() { - let kept_predicates = conjuct(kept_predicates); + let kept_predicates = combine_conjunction(kept_predicates); let new_left = left_pushdowns.map_or_else( || child_join.left.clone(), diff --git a/src/daft-logical-plan/src/optimization/rules/simplify_expressions.rs b/src/daft-logical-plan/src/optimization/rules/simplify_expressions.rs index bb890e2a17..bb395ae428 100644 --- a/src/daft-logical-plan/src/optimization/rules/simplify_expressions.rs +++ b/src/daft-logical-plan/src/optimization/rules/simplify_expressions.rs @@ -3,9 +3,7 @@ use std::sync::Arc; use common_error::DaftResult; use common_scan_info::{PhysicalScanInfo, ScanState}; use common_treenode::{Transformed, TreeNode}; -use daft_core::prelude::SchemaRef; -use daft_dsl::{lit, null_lit, Expr, ExprRef, LiteralValue, Operator}; -use daft_schema::dtype::DataType; +use daft_algebra::simplify_expr; use super::OptimizerRule; use crate::LogicalPlan; @@ -46,364 +44,13 @@ impl OptimizerRule for SimplifyExpressionsRule { } } -fn simplify_expr(expr: Expr, schema: &SchemaRef) 
-> DaftResult> { - Ok(match expr { - // ---------------- - // Eq - // ---------------- - // true = A --> A - // false = A --> !A - Expr::BinaryOp { - op: Operator::Eq, - left, - right, - } - // A = true --> A - // A = false --> !A - | Expr::BinaryOp { - op: Operator::Eq, - left: right, - right: left, - } if is_bool_lit(&left) && is_bool_type(&right, schema) => { - Transformed::yes(match as_bool_lit(&left) { - Some(true) => right, - Some(false) => right.not(), - None => unreachable!(), - }) - } - - // null = A --> null - // A = null --> null - Expr::BinaryOp { - op: Operator::Eq, - left, - right, - } - | Expr::BinaryOp { - op: Operator::Eq, - left: right, - right: left, - } if is_null(&left) && is_bool_type(&right, schema) => Transformed::yes(null_lit()), - - // ---------------- - // Neq - // ---------------- - // true != A --> !A - // false != A --> A - Expr::BinaryOp { - op: Operator::NotEq, - left, - right, - } - // A != true --> !A - // A != false --> A - | Expr::BinaryOp { - op: Operator::NotEq, - left: right, - right: left, - } if is_bool_lit(&left) && is_bool_type(&right, schema) => { - Transformed::yes(match as_bool_lit(&left) { - Some(true) => right.not(), - Some(false) => right, - None => unreachable!(), - }) - } - - // null != A --> null - // A != null --> null - Expr::BinaryOp { - op: Operator::NotEq, - left, - right, - } - | Expr::BinaryOp { - op: Operator::NotEq, - left: right, - right: left, - } if is_null(&left) && is_bool_type(&right, schema) => Transformed::yes(null_lit()), - - // ---------------- - // OR - // ---------------- - - // true OR A --> true - Expr::BinaryOp { - op: Operator::Or, - left, - right: _, - } if is_true(&left) => Transformed::yes(left), - // false OR A --> A - Expr::BinaryOp { - op: Operator::Or, - left, - right, - } if is_false(&left) => Transformed::yes(right), - // A OR true --> true - Expr::BinaryOp { - op: Operator::Or, - left: _, - right, - } if is_true(&right) => Transformed::yes(right), - // A OR false --> A - Expr::BinaryOp { - left, - op: Operator::Or, - right, - } if is_false(&right) => Transformed::yes(left), - - // ---------------- - // AND (TODO) - // ---------------- - - // ---------------- - // Multiplication - // ---------------- - - // A * 1 --> A - // 1 * A --> A - Expr::BinaryOp { - op: Operator::Multiply, - left, - right, - }| Expr::BinaryOp { - op: Operator::Multiply, - left: right, - right: left, - } if is_one(&right) => Transformed::yes(left), - - // A * null --> null - Expr::BinaryOp { - op: Operator::Multiply, - left: _, - right, - } if is_null(&right) => Transformed::yes(right), - // null * A --> null - Expr::BinaryOp { - op: Operator::Multiply, - left, - right: _, - } if is_null(&left) => Transformed::yes(left), - - // TODO: Can't do this one because we don't have a way to determine if an expr potentially contains nulls (nullable) - // A * 0 --> 0 (if A is not null and not floating/decimal) - // 0 * A --> 0 (if A is not null and not floating/decimal) - - // ---------------- - // Division - // ---------------- - // A / 1 --> A - Expr::BinaryOp { - op: Operator::TrueDivide, - left, - right, - } if is_one(&right) => Transformed::yes(left), - // null / A --> null - Expr::BinaryOp { - op: Operator::TrueDivide, - left, - right: _, - } if is_null(&left) => Transformed::yes(left), - // A / null --> null - Expr::BinaryOp { - op: Operator::TrueDivide, - left: _, - right, - } if is_null(&right) => Transformed::yes(right), - - // ---------------- - // Addition - // ---------------- - // A + 0 --> A - Expr::BinaryOp { - op: 
Operator::Plus, - left, - right, - } if is_zero(&right) => Transformed::yes(left), - - // 0 + A --> A - Expr::BinaryOp { - op: Operator::Plus, - left, - right, - } if is_zero(&left) => Transformed::yes(right), - - // ---------------- - // Subtraction - // ---------------- - - // A - 0 --> A - Expr::BinaryOp { - op: Operator::Minus, - left, - right, - } if is_zero(&right) => Transformed::yes(left), - - // A - null --> null - Expr::BinaryOp { - op: Operator::Minus, - left: _, - right, - } if is_null(&right) => Transformed::yes(right), - // null - A --> null - Expr::BinaryOp { - op: Operator::Minus, - left, - right: _, - } if is_null(&left) => Transformed::yes(left), - - // ---------------- - // Modulus - // ---------------- - - // A % null --> null - Expr::BinaryOp { - op: Operator::Modulus, - left: _, - right, - } if is_null(&right) => Transformed::yes(right), - - // null % A --> null - Expr::BinaryOp { - op: Operator::Modulus, - left, - right: _, - } if is_null(&left) => Transformed::yes(left), - - // A BETWEEN low AND high --> A >= low AND A <= high - Expr::Between(expr, low, high) => { - Transformed::yes(expr.clone().lt_eq(high).and(expr.gt_eq(low))) - } - Expr::Not(expr) => match Arc::unwrap_or_clone(expr) { - // NOT (BETWEEN A AND B) --> A < low OR A > high - Expr::Between(expr, low, high) => { - Transformed::yes(expr.clone().lt(low).or(expr.gt(high))) - } - // expr NOT IN () --> true - Expr::IsIn(_, list) if list.is_empty() => Transformed::yes(lit(true)), - - expr => { - let expr = simplify_expr(expr, schema)?; - if expr.transformed { - Transformed::yes(expr.data.not()) - } else { - Transformed::no(expr.data.not()) - } - } - }, - // expr IN () --> false - Expr::IsIn(_, list) if list.is_empty() => Transformed::yes(lit(false)), - - other => Transformed::no(Arc::new(other)), - }) -} - -fn is_zero(s: &Expr) -> bool { - match s { - Expr::Literal(LiteralValue::Int32(0)) - | Expr::Literal(LiteralValue::Int64(0)) - | Expr::Literal(LiteralValue::UInt32(0)) - | Expr::Literal(LiteralValue::UInt64(0)) - | Expr::Literal(LiteralValue::Float64(0.)) => true, - Expr::Literal(LiteralValue::Decimal(v, _p, _s)) if *v == 0 => true, - _ => false, - } -} - -fn is_one(s: &Expr) -> bool { - match s { - Expr::Literal(LiteralValue::Int32(1)) - | Expr::Literal(LiteralValue::Int64(1)) - | Expr::Literal(LiteralValue::UInt32(1)) - | Expr::Literal(LiteralValue::UInt64(1)) - | Expr::Literal(LiteralValue::Float64(1.)) => true, - - Expr::Literal(LiteralValue::Decimal(v, _p, s)) => { - *s >= 0 && POWS_OF_TEN.get(*s as usize).is_some_and(|pow| v == pow) - } - _ => false, - } -} - -fn is_true(expr: &Expr) -> bool { - match expr { - Expr::Literal(LiteralValue::Boolean(v)) => *v, - _ => false, - } -} -fn is_false(expr: &Expr) -> bool { - match expr { - Expr::Literal(LiteralValue::Boolean(v)) => !*v, - _ => false, - } -} - -/// returns true if expr is a -/// `Expr::Literal(LiteralValue::Boolean(v))` , false otherwise -fn is_bool_lit(expr: &Expr) -> bool { - matches!(expr, Expr::Literal(LiteralValue::Boolean(_))) -} - -fn is_bool_type(expr: &Expr, schema: &SchemaRef) -> bool { - matches!(expr.get_type(schema), Ok(DataType::Boolean)) -} - -fn as_bool_lit(expr: &Expr) -> Option { - expr.as_literal().and_then(|l| l.as_bool()) -} - -fn is_null(expr: &Expr) -> bool { - matches!(expr, Expr::Literal(LiteralValue::Null)) -} - -static POWS_OF_TEN: [i128; 38] = [ - 1, - 10, - 100, - 1000, - 10000, - 100000, - 1000000, - 10000000, - 100000000, - 1000000000, - 10000000000, - 100000000000, - 1000000000000, - 10000000000000, - 
100000000000000, - 1000000000000000, - 10000000000000000, - 100000000000000000, - 1000000000000000000, - 10000000000000000000, - 100000000000000000000, - 1000000000000000000000, - 10000000000000000000000, - 100000000000000000000000, - 1000000000000000000000000, - 10000000000000000000000000, - 100000000000000000000000000, - 1000000000000000000000000000, - 10000000000000000000000000000, - 100000000000000000000000000000, - 1000000000000000000000000000000, - 10000000000000000000000000000000, - 100000000000000000000000000000000, - 1000000000000000000000000000000000, - 10000000000000000000000000000000000, - 100000000000000000000000000000000000, - 1000000000000000000000000000000000000, - 10000000000000000000000000000000000000, -]; - #[cfg(test)] mod test { use std::sync::Arc; use daft_core::prelude::Schema; - use daft_dsl::{col, lit, null_lit, ExprRef}; + use daft_dsl::{col, lit}; use daft_schema::{dtype::DataType, field::Field}; - use rstest::rstest; use super::SimplifyExpressionsRule; use crate::{ @@ -436,112 +83,6 @@ mod test { ) } - #[rstest] - // true = A --> A - #[case(col("bool").eq(lit(true)), col("bool"))] - // false = A --> !A - #[case(col("bool").eq(lit(false)), col("bool").not())] - // A == true ---> A - #[case(col("bool").eq(lit(true)), col("bool"))] - // null = A --> null - #[case(null_lit().eq(col("bool")), null_lit())] - // A == false ---> !A - #[case(col("bool").eq(lit(false)), col("bool").not())] - // true != A --> !A - #[case(lit(true).not_eq(col("bool")), col("bool").not())] - // false != A --> A - #[case(lit(false).not_eq(col("bool")), col("bool"))] - // true OR A --> true - #[case(lit(true).or(col("bool")), lit(true))] - // false OR A --> A - #[case(lit(false).or(col("bool")), col("bool"))] - // A OR true --> true - #[case(col("bool").or(lit(true)), lit(true))] - // A OR false --> A - #[case(col("bool").or(lit(false)), col("bool"))] - fn test_simplify_bool_exprs(#[case] input: ExprRef, #[case] expected: ExprRef) { - let source = make_source().filter(input).unwrap().build(); - let optimizer = SimplifyExpressionsRule::new(); - let optimized = optimizer.try_optimize(source).unwrap(); - - let LogicalPlan::Filter(Filter { predicate, .. }) = optimized.data.as_ref() else { - panic!("Expected Filter, got {:?}", optimized.data) - }; - - // make sure the expression is simplified - assert!(optimized.transformed); - - assert_eq!(predicate, &expected); - } - - #[rstest] - // A * 1 --> A - #[case(col("int").mul(lit(1)), col("int"))] - // 1 * A --> A - #[case(lit(1).mul(col("int")), col("int"))] - // A / 1 --> A - #[case(col("int").div(lit(1)), col("int"))] - // A + 0 --> A - #[case(col("int").add(lit(0)), col("int"))] - // A - 0 --> A - #[case(col("int").sub(lit(0)), col("int"))] - fn test_math_exprs(#[case] input: ExprRef, #[case] expected: ExprRef) { - let source = make_source().select(vec![input]).unwrap().build(); - let optimizer = SimplifyExpressionsRule::new(); - let optimized = optimizer.try_optimize(source).unwrap(); - - let LogicalPlan::Project(Project { projection, .. 
}) = optimized.data.as_ref() else { - panic!("Expected Filter, got {:?}", optimized.data) - }; - - let projection = projection.first().unwrap(); - - // make sure the expression is simplified - assert!(optimized.transformed); - - assert_eq!(projection, &expected); - } - - #[test] - fn test_not_between() { - let source = make_source() - .filter(col("int").between(lit(1), lit(10)).not()) - .unwrap() - .build(); - let optimizer = SimplifyExpressionsRule::new(); - let optimized = optimizer.try_optimize(source).unwrap(); - - let LogicalPlan::Filter(Filter { predicate, .. }) = optimized.data.as_ref() else { - panic!("Expected Filter, got {:?}", optimized.data) - }; - - // make sure the expression is simplified - assert!(optimized.transformed); - - assert_eq!(predicate, &col("int").lt(lit(1)).or(col("int").gt(lit(10)))); - } - - #[test] - fn test_between() { - let source = make_source() - .filter(col("int").between(lit(1), lit(10))) - .unwrap() - .build(); - let optimizer = SimplifyExpressionsRule::new(); - let optimized = optimizer.try_optimize(source).unwrap(); - - let LogicalPlan::Filter(Filter { predicate, .. }) = optimized.data.as_ref() else { - panic!("Expected Filter, got {:?}", optimized.data) - }; - - // make sure the expression is simplified - assert!(optimized.transformed); - - assert_eq!( - predicate, - &col("int").lt_eq(lit(10)).and(col("int").gt_eq(lit(1))) - ); - } #[test] fn test_nested_plan() { let source = make_source() diff --git a/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs b/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs index 3413e8cc53..5039cc9767 100644 --- a/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs +++ b/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs @@ -2,12 +2,9 @@ use std::{collections::HashSet, sync::Arc}; use common_error::{DaftError, DaftResult}; use common_treenode::{DynTreeNode, Transformed, TreeNode}; +use daft_algebra::boolean::{combine_conjunction, split_conjunction}; use daft_core::{join::JoinType, prelude::SchemaRef}; -use daft_dsl::{ - col, - optimization::{conjuct, split_conjuction}, - Expr, ExprRef, Operator, Subquery, -}; +use daft_dsl::{col, Expr, ExprRef, Operator, Subquery}; use itertools::multiunzip; use uuid::Uuid; @@ -73,12 +70,12 @@ impl UnnestScalarSubquery { impl UnnestScalarSubquery { fn unnest_subqueries( input: LogicalPlanRef, - exprs: Vec<&ExprRef>, + exprs: &[ExprRef], ) -> DaftResult)>> { let mut subqueries = HashSet::new(); let new_exprs = exprs - .into_iter() + .iter() .map(|expr| { expr.clone() .transform_down(|e| { @@ -164,7 +161,7 @@ impl OptimizerRule for UnnestScalarSubquery { input, predicate, .. }) => { let unnest_result = - Self::unnest_subqueries(input.clone(), split_conjuction(predicate))?; + Self::unnest_subqueries(input.clone(), &split_conjunction(predicate))?; if !unnest_result.transformed { return Ok(Transformed::no(node)); @@ -172,7 +169,7 @@ impl OptimizerRule for UnnestScalarSubquery { let (new_input, new_predicates) = unnest_result.data; - let new_predicate = conjuct(new_predicates) + let new_predicate = combine_conjunction(new_predicates) .expect("predicates are guaranteed to exist at this point, so 'conjunct' should never return 'None'"); let new_filter = Arc::new(LogicalPlan::Filter(Filter::try_new( @@ -192,7 +189,7 @@ impl OptimizerRule for UnnestScalarSubquery { input, projection, .. 
}) => { let unnest_result = - Self::unnest_subqueries(input.clone(), projection.iter().collect())?; + Self::unnest_subqueries(input.clone(), projection)?; if !unnest_result.transformed { return Ok(Transformed::no(node)); @@ -275,7 +272,7 @@ impl OptimizerRule for UnnestPredicateSubquery { }) => { let mut subqueries = HashSet::new(); - let new_predicates = split_conjuction(predicate) + let new_predicates = split_conjunction(predicate) .into_iter() .filter(|expr| { match expr.as_ref() { @@ -303,7 +300,6 @@ impl OptimizerRule for UnnestPredicateSubquery { _ => true } }) - .cloned() .collect::>(); if subqueries.is_empty() { @@ -345,7 +341,7 @@ impl OptimizerRule for UnnestPredicateSubquery { )?))) })?; - let new_plan = if let Some(new_predicate) = conjuct(new_predicates) { + let new_plan = if let Some(new_predicate) = combine_conjunction(new_predicates) { // add filter back if there are non-subquery predicates Arc::new(LogicalPlan::Filter(Filter::try_new( new_input, @@ -387,7 +383,7 @@ fn pull_up_correlated_cols( }) => { let mut found_correlated_col = false; - let preds = split_conjuction(predicate) + let preds = split_conjunction(predicate) .into_iter() .filter(|expr| { if let Expr::BinaryOp { @@ -418,7 +414,6 @@ fn pull_up_correlated_cols( true }) - .cloned() .collect::>(); // no new correlated cols found @@ -426,7 +421,7 @@ fn pull_up_correlated_cols( return Ok((plan.clone(), subquery_on, outer_on)); } - if let Some(new_predicate) = conjuct(preds) { + if let Some(new_predicate) = combine_conjunction(preds) { let new_plan = Arc::new(LogicalPlan::Filter(Filter::try_new( input.clone(), new_predicate, diff --git a/src/daft-sql/Cargo.toml b/src/daft-sql/Cargo.toml index 6e45c23741..a402235011 100644 --- a/src/daft-sql/Cargo.toml +++ b/src/daft-sql/Cargo.toml @@ -3,6 +3,7 @@ common-daft-config = {path = "../common/daft-config"} common-error = {path = "../common/error"} common-io-config = {path = "../common/io-config", default-features = false} common-runtime = {workspace = true} +daft-algebra = {path = "../daft-algebra"} daft-core = {path = "../daft-core"} daft-dsl = {path = "../daft-dsl"} daft-functions = {path = "../daft-functions"} diff --git a/src/daft-sql/src/planner.rs b/src/daft-sql/src/planner.rs index 683391b601..ce2ef703a3 100644 --- a/src/daft-sql/src/planner.rs +++ b/src/daft-sql/src/planner.rs @@ -6,13 +6,13 @@ use std::{ }; use common_error::{DaftError, DaftResult}; +use daft_algebra::boolean::combine_conjunction; use daft_core::prelude::*; use daft_dsl::{ col, common_treenode::{Transformed, TreeNode}, - has_agg, lit, literals_to_series, null_lit, - optimization::conjuct, - AggExpr, Expr, ExprRef, LiteralValue, Operator, OuterReferenceColumn, Subquery, + has_agg, lit, literals_to_series, null_lit, AggExpr, Expr, ExprRef, LiteralValue, Operator, + OuterReferenceColumn, Subquery, }; use daft_functions::{ numeric::{ceil::ceil, floor::floor}, @@ -959,12 +959,12 @@ impl<'a> SQLPlanner<'a> { }; let mut left_plan = self.current_relation.as_ref().unwrap().inner.clone(); - if let Some(left_predicate) = conjuct(left_filters) { + if let Some(left_predicate) = combine_conjunction(left_filters) { left_plan = left_plan.filter(left_predicate)?; } let mut right_plan = right_rel.inner.clone(); - if let Some(right_predicate) = conjuct(right_filters) { + if let Some(right_predicate) = combine_conjunction(right_filters) { right_plan = right_plan.filter(right_predicate)?; } From 8620635483ce3d1bf063b381e2bb0b4abf7b4856 Mon Sep 17 00:00:00 2001 From: Kev Wang Date: Tue, 17 Dec 2024 14:21:09 -0800 
Subject: [PATCH 07/14] perf: filter null join key optimization rule (#3583) --- .../src/optimization/optimizer.rs | 8 +- .../rules/filter_null_join_key.rs | 300 ++++++++++++++++++ .../src/optimization/rules/mod.rs | 2 + 3 files changed, 307 insertions(+), 3 deletions(-) create mode 100644 src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs diff --git a/src/daft-logical-plan/src/optimization/optimizer.rs b/src/daft-logical-plan/src/optimization/optimizer.rs index 76f6251438..64ed0a65c9 100644 --- a/src/daft-logical-plan/src/optimization/optimizer.rs +++ b/src/daft-logical-plan/src/optimization/optimizer.rs @@ -6,9 +6,10 @@ use common_treenode::Transformed; use super::{ logical_plan_tracker::LogicalPlanTracker, rules::{ - DropRepartition, EliminateCrossJoin, EnrichWithStats, LiftProjectFromAgg, MaterializeScans, - OptimizerRule, PushDownFilter, PushDownLimit, PushDownProjection, SimplifyExpressionsRule, - SplitActorPoolProjects, UnnestPredicateSubquery, UnnestScalarSubquery, + DropRepartition, EliminateCrossJoin, EnrichWithStats, FilterNullJoinKey, + LiftProjectFromAgg, MaterializeScans, OptimizerRule, PushDownFilter, PushDownLimit, + PushDownProjection, SimplifyExpressionsRule, SplitActorPoolProjects, + UnnestPredicateSubquery, UnnestScalarSubquery, }, }; use crate::LogicalPlan; @@ -109,6 +110,7 @@ impl Optimizer { RuleBatch::new( vec![ Box::new(DropRepartition::new()), + Box::new(FilterNullJoinKey::new()), Box::new(PushDownFilter::new()), Box::new(PushDownProjection::new()), Box::new(EliminateCrossJoin::new()), diff --git a/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs b/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs new file mode 100644 index 0000000000..80aab76c18 --- /dev/null +++ b/src/daft-logical-plan/src/optimization/rules/filter_null_join_key.rs @@ -0,0 +1,300 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use common_treenode::{Transformed, TreeNode}; +use daft_algebra::boolean::combine_conjunction; +use daft_core::join::JoinType; + +use super::OptimizerRule; +use crate::{ + ops::{Filter, Join}, + LogicalPlan, +}; + +/// Optimization rule for filtering out nulls from join keys. +/// +/// When a join will always discard null keys from a join side, +/// this rule inserts a filter before that side to remove rows where a join key is null. +/// This reduces the cardinality of the tables before a join to improve join performance, +/// and can also be pushed down with other rules to reduce source and intermediate output sizes. 
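
The join semantics this rule relies on can be checked directly from Python: an inner equi-join never matches a NULL key, so rows whose key is NULL can be dropped before the join without changing the result. A small sketch, assuming the usual `daft.from_pydict` / `DataFrame.join` / `DataFrame.where` API and the `~` / `is_null` expression operators (exact output column layout may differ between versions):

```python
import daft

left = daft.from_pydict({"x": [1, 2, None]})
right = daft.from_pydict({"y": [1, None, None]})

# Inner equi-join: None keys on either side can never produce a match,
# so filtering them out up front is a pure optimization.
joined = left.join(right, left_on="x", right_on="y", how="inner")
print(joined.to_pydict())  # only the x == y == 1 row survives

# The pre-filter this rule inserts is equivalent to:
pre_filtered = left.where(~left["x"].is_null()).join(
    right.where(~right["y"].is_null()), left_on="x", right_on="y", how="inner"
)
print(pre_filtered.to_pydict())  # same result, smaller join inputs
```
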
+/// +/// # Example +/// ```sql +/// SELECT * FROM left JOIN right ON left.x = right.y +/// ``` +/// turns into +/// ```sql +/// SELECT * +/// FROM (SELECT * FROM left WHERE x IS NOT NULL) AS non_null_left +/// JOIN (SELECT * FROM right WHERE x IS NOT NULL) AS non_null_right +/// ON non_null_left.x = non_null_right.y +/// ``` +/// +/// So if `left` was: +/// ``` +/// ╭───────╮ +/// │ x │ +/// │ --- │ +/// │ Int64 │ +/// ╞═══════╡ +/// │ 1 │ +/// ├╌╌╌╌╌╌╌┤ +/// │ 2 │ +/// ├╌╌╌╌╌╌╌┤ +/// │ None │ +/// ╰───────╯ +/// ``` +/// And `right` was: +/// ``` +/// ╭───────╮ +/// │ y │ +/// │ --- │ +/// │ Int64 │ +/// ╞═══════╡ +/// │ 1 │ +/// ├╌╌╌╌╌╌╌┤ +/// │ None │ +/// ├╌╌╌╌╌╌╌┤ +/// │ None │ +/// ╰───────╯ +/// ``` +/// the original query would join on all rows, whereas the new query would first filter out null rows and join on the following: +/// +/// `non_null_left`: +/// ``` +/// ╭───────╮ +/// │ x │ +/// │ --- │ +/// │ Int64 │ +/// ╞═══════╡ +/// │ 1 │ +/// ├╌╌╌╌╌╌╌┤ +/// │ 2 │ +/// ╰───────╯ +/// ``` +/// `non_null_right`: +/// ``` +/// ╭───────╮ +/// │ y │ +/// │ --- │ +/// │ Int64 │ +/// ╞═══════╡ +/// │ 1 │ +/// ╰───────╯ +/// ``` +#[derive(Default, Debug)] +pub struct FilterNullJoinKey {} + +impl FilterNullJoinKey { + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for FilterNullJoinKey { + fn try_optimize(&self, plan: Arc) -> DaftResult>> { + plan.transform(|node| { + if let LogicalPlan::Join(Join { + left, + right, + left_on, + right_on, + null_equals_nulls, + join_type, + .. + }) = node.as_ref() + { + let mut null_equals_nulls_iter = null_equals_nulls.as_ref().map_or_else( + || Box::new(std::iter::repeat(false)) as Box>, + |x| Box::new(x.clone().into_iter()), + ); + + let (can_filter_left, can_filter_right) = match join_type { + JoinType::Inner => (true, true), + JoinType::Left => (false, true), + JoinType::Right => (true, false), + JoinType::Outer => (false, false), + JoinType::Anti => (false, true), + JoinType::Semi => (true, true), + }; + + let left_null_pred = if can_filter_left { + combine_conjunction( + null_equals_nulls_iter + .by_ref() + .zip(left_on) + .filter(|(null_eq_null, _)| !null_eq_null) + .map(|(_, left_key)| left_key.clone().is_null().not()), + ) + } else { + None + }; + + let right_null_pred = if can_filter_right { + combine_conjunction( + null_equals_nulls_iter + .by_ref() + .zip(right_on) + .filter(|(null_eq_null, _)| !null_eq_null) + .map(|(_, right_key)| right_key.clone().is_null().not()), + ) + } else { + None + }; + + if left_null_pred.is_none() && right_null_pred.is_none() { + Ok(Transformed::no(node.clone())) + } else { + let new_left = if let Some(pred) = left_null_pred { + Arc::new(LogicalPlan::Filter(Filter::try_new(left.clone(), pred)?)) + } else { + left.clone() + }; + + let new_right = if let Some(pred) = right_null_pred { + Arc::new(LogicalPlan::Filter(Filter::try_new(right.clone(), pred)?)) + } else { + right.clone() + }; + + let new_join = Arc::new(node.with_new_children(&[new_left, new_right])); + + Ok(Transformed::yes(new_join)) + } + } else { + Ok(Transformed::no(node)) + } + }) + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use common_error::DaftResult; + use daft_core::prelude::*; + use daft_dsl::col; + + use crate::{ + optimization::{ + optimizer::{RuleBatch, RuleExecutionStrategy}, + rules::filter_null_join_key::FilterNullJoinKey, + test::assert_optimized_plan_with_rules_eq, + }, + test::{dummy_scan_node, dummy_scan_operator}, + LogicalPlan, + }; + + /// Helper that creates an optimizer with the FilterNullJoinKey 
rule registered, optimizes + /// the provided plan with said optimizer, and compares the optimized plan with + /// the provided expected plan. + fn assert_optimized_plan_eq( + plan: Arc, + expected: Arc, + ) -> DaftResult<()> { + assert_optimized_plan_with_rules_eq( + plan, + expected, + vec![RuleBatch::new( + vec![Box::new(FilterNullJoinKey::new())], + RuleExecutionStrategy::Once, + )], + ) + } + + #[test] + fn filter_keys_basic() -> DaftResult<()> { + let left_scan = dummy_scan_node(dummy_scan_operator(vec![ + Field::new("a", DataType::Int64), + Field::new("b", DataType::Utf8), + ])); + + let right_scan = dummy_scan_node(dummy_scan_operator(vec![ + Field::new("c", DataType::Int64), + Field::new("d", DataType::Utf8), + ])); + + let plan = left_scan + .join( + right_scan.clone(), + vec![col("a")], + vec![col("c")], + JoinType::Inner, + None, + None, + None, + false, + )? + .build(); + + let expected = left_scan + .filter(col("a").is_null().not())? + .clone() + .join( + right_scan.filter(col("c").is_null().not())?, + vec![col("a")], + vec![col("c")], + JoinType::Inner, + None, + None, + None, + false, + )? + .build(); + + assert_optimized_plan_eq(plan, expected)?; + + Ok(()) + } + + #[test] + fn filter_keys_null_equals_nulls() -> DaftResult<()> { + let left_scan = dummy_scan_node(dummy_scan_operator(vec![ + Field::new("a", DataType::Int64), + Field::new("b", DataType::Utf8), + Field::new("c", DataType::Boolean), + ])); + + let right_scan = dummy_scan_node(dummy_scan_operator(vec![ + Field::new("d", DataType::Int64), + Field::new("e", DataType::Utf8), + Field::new("f", DataType::Boolean), + ])); + + let plan = left_scan + .join_with_null_safe_equal( + right_scan.clone(), + vec![col("a"), col("b"), col("c")], + vec![col("d"), col("e"), col("f")], + Some(vec![false, true, false]), + JoinType::Left, + None, + None, + None, + false, + )? + .build(); + + let expected_predicate = col("d").is_null().not().and(col("f").is_null().not()); + + let expected = left_scan + .clone() + .join_with_null_safe_equal( + right_scan.filter(expected_predicate)?, + vec![col("a"), col("b"), col("c")], + vec![col("d"), col("e"), col("f")], + Some(vec![false, true, false]), + JoinType::Left, + None, + None, + None, + false, + )? 
+ .build(); + + assert_optimized_plan_eq(plan, expected)?; + + Ok(()) + } +} diff --git a/src/daft-logical-plan/src/optimization/rules/mod.rs b/src/daft-logical-plan/src/optimization/rules/mod.rs index f540a77cb0..787c0e7ac6 100644 --- a/src/daft-logical-plan/src/optimization/rules/mod.rs +++ b/src/daft-logical-plan/src/optimization/rules/mod.rs @@ -1,6 +1,7 @@ mod drop_repartition; mod eliminate_cross_join; mod enrich_with_stats; +mod filter_null_join_key; mod lift_project_from_agg; mod materialize_scans; mod push_down_filter; @@ -15,6 +16,7 @@ mod unnest_subquery; pub use drop_repartition::DropRepartition; pub use eliminate_cross_join::EliminateCrossJoin; pub use enrich_with_stats::EnrichWithStats; +pub use filter_null_join_key::FilterNullJoinKey; pub use lift_project_from_agg::LiftProjectFromAgg; pub use materialize_scans::MaterializeScans; pub use push_down_filter::PushDownFilter; From cc5ad0099d2cb9197a590886b0344c92f7821bc4 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Tue, 17 Dec 2024 17:39:05 -0800 Subject: [PATCH 08/14] chore: Fix ordering in sql tests (#3596) Co-authored-by: Colin Ho --- tests/dataframe/test_intersect.py | 2 +- tests/integration/sql/docker-compose/docker-compose.yml | 7 +++---- tests/sql/test_sql.py | 5 ++--- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/dataframe/test_intersect.py b/tests/dataframe/test_intersect.py index 4061fe7f71..59e9d9a79e 100644 --- a/tests/dataframe/test_intersect.py +++ b/tests/dataframe/test_intersect.py @@ -7,7 +7,7 @@ def test_simple_intersect(make_df): df1 = make_df({"foo": [1, 2, 3]}) df2 = make_df({"bar": [2, 3, 4]}) - result = df1.intersect(df2) + result = df1.intersect(df2).sort(by="foo") assert result.to_pydict() == {"foo": [2, 3]} diff --git a/tests/integration/sql/docker-compose/docker-compose.yml b/tests/integration/sql/docker-compose/docker-compose.yml index b8eb8c3eba..3748cb1da4 100644 --- a/tests/integration/sql/docker-compose/docker-compose.yml +++ b/tests/integration/sql/docker-compose/docker-compose.yml @@ -1,13 +1,12 @@ -version: '3.7' services: trino: - image: trinodb/trino + image: trinodb/trino:467 container_name: trino ports: - 8080:8080 postgres: - image: postgres:latest + image: postgres:17.2 container_name: postgres environment: POSTGRES_DB: postgres @@ -19,7 +18,7 @@ services: - postgres_data:/var/lib/postgresql/data mysql: - image: mysql:latest + image: mysql:9.1 container_name: mysql environment: MYSQL_DATABASE: mysql diff --git a/tests/sql/test_sql.py b/tests/sql/test_sql.py index f06c4f64f7..caa5d2a8a9 100644 --- a/tests/sql/test_sql.py +++ b/tests/sql/test_sql.py @@ -211,9 +211,8 @@ def test_sql_tbl_alias(): def test_sql_distinct(): df = daft.from_pydict({"n": [1, 1, 2, 2]}) - actual = daft.sql("SELECT DISTINCT n FROM df").collect().to_pydict() - expected = df.distinct().collect().to_pydict() - assert actual == expected + df = daft.sql("SELECT DISTINCT n FROM df").collect().to_pydict() + assert set(df["n"]) == {1, 2} @pytest.mark.parametrize( From 4bb041348f77c493c5a69399e5d9d6cf21f5d340 Mon Sep 17 00:00:00 2001 From: Desmond Cheong Date: Wed, 18 Dec 2024 00:18:57 -0800 Subject: [PATCH 09/14] fix(parquet): Fix parquet reads of required fields nested within optional fields (#3598) In https://dist-data.slack.com/archives/C041NA2RBFD/p1734108409644689, it was discovered that when a parquet schema contained required fields nested within optional fields, e.g. 
``` optional group struct_field = 10 { required binary nested_field (STRING) = 38; } ``` we would hit the following error: ``` validity mask length must match the number of values ``` Nested parquet fields can be marked as `required` within `optional` parent fields because the parent field may be optional, but if the parent field is defined then the nested field _must also be defined_. `required`/`optional` affects the definition levels that encode parquet values. Internally, Daft also constructs `Required*` or `Optional*`[1] state based on whether the field is marked as `required` vs `optional` respectively. `Required*` state means that we don't rely on a validity mask, whereas for `Optional*` state we maintain a validity mask. However, when a parent field is `optional`, even if a nested field is `required`, there can be situations where the nested value is NULL because the parent value is NULL. In such situations, the `required` nested field should keep track `Optional*` state (i.e. keep track of a validity mask), but should still decode parquet definition levels as if it were `required`. In this PR, we keep track if any parent of a nested field is nullable. In those cases we keep track of `Optional*` state. [1]: E.g. `RequiredDictionary` vs `OptionalDictionary` states. --- .../parquet/read/deserialize/binary/nested.rs | 7 ++- .../read/deserialize/boolean/nested.rs | 7 ++- .../read/deserialize/dictionary/nested.rs | 4 +- .../deserialize/fixed_size_binary/nested.rs | 7 ++- .../src/io/parquet/read/deserialize/mod.rs | 4 +- .../src/io/parquet/read/deserialize/nested.rs | 27 ++++++++ .../parquet/read/deserialize/nested_utils.rs | 6 +- .../parquet/read/deserialize/null/nested.rs | 5 ++ .../read/deserialize/primitive/nested.rs | 8 ++- tests/io/test_parquet.py | 61 +++++++++++++++++++ 10 files changed, 128 insertions(+), 8 deletions(-) diff --git a/src/arrow2/src/io/parquet/read/deserialize/binary/nested.rs b/src/arrow2/src/io/parquet/read/deserialize/binary/nested.rs index 998b8c03cf..a06b741092 100644 --- a/src/arrow2/src/io/parquet/read/deserialize/binary/nested.rs +++ b/src/arrow2/src/io/parquet/read/deserialize/binary/nested.rs @@ -53,8 +53,9 @@ impl<'a, O: Offset> NestedDecoder<'a> for BinaryDecoder { &self, page: &'a DataPage, dict: Option<&'a Self::Dictionary>, + is_parent_nullable: bool, ) -> Result { - let is_optional = + let is_optional = is_parent_nullable || page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; let is_filtered = page.selected_rows().is_some(); @@ -145,6 +146,7 @@ pub struct NestedIter { chunk_size: Option, rows_remaining: usize, values_remaining: usize, + is_parent_nullable: bool, } impl NestedIter { @@ -155,6 +157,7 @@ impl NestedIter { num_rows: usize, chunk_size: Option, num_values: usize, + is_parent_nullable: bool, ) -> Self { Self { iter, @@ -165,6 +168,7 @@ impl NestedIter { chunk_size, rows_remaining: num_rows, values_remaining: num_values, + is_parent_nullable, } } } @@ -182,6 +186,7 @@ impl Iterator for NestedIter { &mut self.values_remaining, &self.init, self.chunk_size, + self.is_parent_nullable, &BinaryDecoder::::default(), ); match maybe_state { diff --git a/src/arrow2/src/io/parquet/read/deserialize/boolean/nested.rs b/src/arrow2/src/io/parquet/read/deserialize/boolean/nested.rs index 8a811deba7..6fd0f9afd5 100644 --- a/src/arrow2/src/io/parquet/read/deserialize/boolean/nested.rs +++ b/src/arrow2/src/io/parquet/read/deserialize/boolean/nested.rs @@ -53,8 +53,9 @@ impl<'a> NestedDecoder<'a> for BooleanDecoder { &self, page: &'a 
DataPage, _: Option<&'a Self::Dictionary>, + is_parent_nullable: bool, ) -> Result { - let is_optional = + let is_optional = is_parent_nullable || page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; let is_filtered = page.selected_rows().is_some(); @@ -116,6 +117,7 @@ pub struct NestedIter { rows_remaining: usize, chunk_size: Option, values_remaining: usize, + is_parent_nullable: bool, } impl NestedIter { @@ -125,6 +127,7 @@ impl NestedIter { num_rows: usize, chunk_size: Option, num_values: usize, + is_parent_nullable: bool, ) -> Self { Self { iter, @@ -133,6 +136,7 @@ impl NestedIter { rows_remaining: num_rows, chunk_size, values_remaining: num_values, + is_parent_nullable, } } } @@ -153,6 +157,7 @@ impl Iterator for NestedIter { &mut self.values_remaining, &self.init, self.chunk_size, + self.is_parent_nullable, &BooleanDecoder::default(), ); match maybe_state { diff --git a/src/arrow2/src/io/parquet/read/deserialize/dictionary/nested.rs b/src/arrow2/src/io/parquet/read/deserialize/dictionary/nested.rs index 89448bf243..0ab5e520cd 100644 --- a/src/arrow2/src/io/parquet/read/deserialize/dictionary/nested.rs +++ b/src/arrow2/src/io/parquet/read/deserialize/dictionary/nested.rs @@ -87,8 +87,9 @@ impl<'a, K: DictionaryKey> NestedDecoder<'a> for DictionaryDecoder { &self, page: &'a DataPage, _: Option<&'a Self::Dictionary>, + is_parent_nullable: bool, ) -> Result { - let is_optional = + let is_optional = is_parent_nullable || page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; let is_filtered = page.selected_rows().is_some(); @@ -200,6 +201,7 @@ pub fn next_dict Box> panic!("Reading a dictionary logical type with Daft is not currently supported. Please file an issue."), &DictionaryDecoder::::default(), chunk_size, + panic!("Reading a dictionary logical type with Daft is not currently supported. 
Please file an issue."), ); match error { Ok(_) => {} diff --git a/src/arrow2/src/io/parquet/read/deserialize/fixed_size_binary/nested.rs b/src/arrow2/src/io/parquet/read/deserialize/fixed_size_binary/nested.rs index bda884bde2..cd9b2059c7 100644 --- a/src/arrow2/src/io/parquet/read/deserialize/fixed_size_binary/nested.rs +++ b/src/arrow2/src/io/parquet/read/deserialize/fixed_size_binary/nested.rs @@ -50,8 +50,9 @@ impl<'a> NestedDecoder<'a> for BinaryDecoder { &self, page: &'a DataPage, dict: Option<&'a Self::Dictionary>, + is_parent_nullable: bool, ) -> Result { - let is_optional = + let is_optional = is_parent_nullable || page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; let is_filtered = page.selected_rows().is_some(); @@ -141,6 +142,7 @@ pub struct NestedIter { chunk_size: Option, rows_remaining: usize, values_remaining: usize, + is_parent_nullable: bool, } impl NestedIter { @@ -151,6 +153,7 @@ impl NestedIter { num_rows: usize, chunk_size: Option, num_values: usize, + is_parent_nullable: bool, ) -> Self { let size = FixedSizeBinaryArray::get_size(&data_type); Self { @@ -163,6 +166,7 @@ impl NestedIter { chunk_size, rows_remaining: num_rows, values_remaining: num_values, + is_parent_nullable, } } } @@ -179,6 +183,7 @@ impl Iterator for NestedIter { &mut self.values_remaining, &self.init, self.chunk_size, + self.is_parent_nullable, &BinaryDecoder { size: self.size }, ); match maybe_state { diff --git a/src/arrow2/src/io/parquet/read/deserialize/mod.rs b/src/arrow2/src/io/parquet/read/deserialize/mod.rs index c05978d748..3d641e452f 100644 --- a/src/arrow2/src/io/parquet/read/deserialize/mod.rs +++ b/src/arrow2/src/io/parquet/read/deserialize/mod.rs @@ -154,7 +154,7 @@ where } nested::columns_to_iter_recursive( - columns, types, field, init, num_rows, chunk_size, num_values, + columns, types, field, init, num_rows, chunk_size, num_values, false, ) } @@ -242,7 +242,7 @@ where { Ok(Box::new( nested::columns_to_iter_recursive( - columns, types, field, init, num_rows, chunk_size, num_values, + columns, types, field, init, num_rows, chunk_size, num_values, false, )? 
.map(|x| x.map(|x| x.1)), )) diff --git a/src/arrow2/src/io/parquet/read/deserialize/nested.rs b/src/arrow2/src/io/parquet/read/deserialize/nested.rs index ba7125b76f..c4b09bcc42 100644 --- a/src/arrow2/src/io/parquet/read/deserialize/nested.rs +++ b/src/arrow2/src/io/parquet/read/deserialize/nested.rs @@ -34,6 +34,7 @@ where })) } +#[allow(clippy::too_many_arguments)] pub fn columns_to_iter_recursive<'a, I>( mut columns: Vec, mut types: Vec<&PrimitiveType>, @@ -42,6 +43,7 @@ pub fn columns_to_iter_recursive<'a, I>( num_rows: usize, chunk_size: Option, mut num_values: Vec, + is_parent_nullable: bool, ) -> Result> where I: Pages + 'a, @@ -61,6 +63,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, )) } Boolean => { @@ -72,6 +75,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, )) } Primitive(Int8) => { @@ -84,6 +88,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, |x: i32| x as i8, )) } @@ -97,6 +102,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, |x: i32| x as i16, )) } @@ -110,6 +116,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, |x: i32| x, )) } @@ -123,6 +130,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, |x: i64| x, )) } @@ -136,6 +144,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, |x: i32| x as u8, )) } @@ -149,6 +158,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, |x: i32| x as u16, )) } @@ -163,6 +173,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, |x: i32| x as u32, )), // some implementations of parquet write arrow's u32 into i64. @@ -173,6 +184,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, |x: i64| x as u32, )), other => { @@ -192,6 +204,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, |x: i64| x as u64, )) } @@ -205,6 +218,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, |x: f32| x, )) } @@ -218,6 +232,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, |x: f64| x, )) } @@ -231,6 +246,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, )) } LargeBinary | LargeUtf8 => { @@ -243,6 +259,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, )) } _ => match field.data_type().to_logical_type() { @@ -274,6 +291,7 @@ where num_rows, chunk_size, num_values, + is_parent_nullable || field.is_nullable, )?; let iter = iter.map(move |x| { let (mut nested, array) = x?; @@ -293,6 +311,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, |x: i32| x as i128, )), PhysicalType::Int64 => primitive(primitive::NestedIter::new( @@ -302,6 +321,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, |x: i64| x as i128, )), PhysicalType::FixedLenByteArray(n) if n > 16 => { @@ -317,6 +337,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, ); // Convert the fixed length byte array to Decimal. 
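
The schema shape described in the commit message (a `required` leaf under an `optional` parent) is easy to reproduce from Python. A compact sketch of the failing case, using a hypothetical file path and mirroring the fuller test added below:

```python
import pyarrow as pa
import pyarrow.parquet as papq
import daft

# Optional (nullable) outer struct with a required (non-nullable) inner field.
schema = pa.schema(
    [
        pa.field(
            "outer",
            pa.struct([pa.field("inner", pa.string(), nullable=False)]),
            nullable=True,
        )
    ]
)
table = pa.Table.from_pylist(
    [{"outer": {"inner": "a"}}, {"outer": None}, {"outer": {"inner": "b"}}],
    schema=schema,
)
papq.write_table(table, "/tmp/nested_required.parquet")  # hypothetical path

# This is the kind of read that previously failed with
# "validity mask length must match the number of values"; with this fix the
# nested required field keeps a validity mask whenever any ancestor is nullable.
df = daft.read_parquet("/tmp/nested_required.parquet")
print(df.to_pydict())
```
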
let iter = iter.map(move |x| { @@ -359,6 +380,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, |x: i32| i256(I256::new(x as i128)), )), PhysicalType::Int64 => primitive(primitive::NestedIter::new( @@ -368,6 +390,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, |x: i64| i256(I256::new(x as i128)), )), PhysicalType::FixedLenByteArray(n) if n <= 16 => { @@ -378,6 +401,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, ); // Convert the fixed length byte array to Decimal. let iter = iter.map(move |x| { @@ -410,6 +434,7 @@ where num_rows, chunk_size, num_values.pop().unwrap(), + is_parent_nullable, ); // Convert the fixed length byte array to Decimal. let iter = iter.map(move |x| { @@ -465,6 +490,7 @@ where num_rows, chunk_size, num_values, + is_parent_nullable || field.is_nullable, ) }) .collect::>>()?; @@ -481,6 +507,7 @@ where num_rows, chunk_size, num_values, + is_parent_nullable || field.is_nullable, )?; let iter = iter.map(move |x| { let (mut nested, array) = x?; diff --git a/src/arrow2/src/io/parquet/read/deserialize/nested_utils.rs b/src/arrow2/src/io/parquet/read/deserialize/nested_utils.rs index 3772cf3099..879b529d3e 100644 --- a/src/arrow2/src/io/parquet/read/deserialize/nested_utils.rs +++ b/src/arrow2/src/io/parquet/read/deserialize/nested_utils.rs @@ -252,6 +252,7 @@ pub(super) trait NestedDecoder<'a> { &self, page: &'a DataPage, dict: Option<&'a Self::Dictionary>, + is_parent_nullable: bool, ) -> Result; /// Initializes a new state @@ -360,8 +361,9 @@ pub(super) fn extend<'a, D: NestedDecoder<'a>>( values_remaining: &mut usize, decoder: &D, chunk_size: Option, + is_parent_nullable: bool, ) -> Result<()> { - let mut values_page = decoder.build_state(page, dict)?; + let mut values_page = decoder.build_state(page, dict, is_parent_nullable)?; let mut page = NestedPage::try_new(page)?; let capacity = chunk_size.unwrap_or(0); @@ -574,6 +576,7 @@ pub(super) fn next<'a, I, D>( values_remaining: &mut usize, init: &[InitNested], chunk_size: Option, + is_parent_nullable: bool, decoder: &D, ) -> MaybeNext> where @@ -624,6 +627,7 @@ where values_remaining, decoder, chunk_size, + is_parent_nullable, ); match error { Ok(_) => {} diff --git a/src/arrow2/src/io/parquet/read/deserialize/null/nested.rs b/src/arrow2/src/io/parquet/read/deserialize/null/nested.rs index 8763cb0447..81cd924508 100644 --- a/src/arrow2/src/io/parquet/read/deserialize/null/nested.rs +++ b/src/arrow2/src/io/parquet/read/deserialize/null/nested.rs @@ -34,6 +34,7 @@ impl<'a> NestedDecoder<'a> for NullDecoder { &self, _page: &'a DataPage, dict: Option<&'a Self::Dictionary>, + _: bool, ) -> Result { if let Some(n) = dict { return Ok(*n); @@ -74,6 +75,7 @@ where rows_remaining: usize, chunk_size: Option, values_remaining: usize, + is_parent_nullable: bool, decoder: NullDecoder, } @@ -88,6 +90,7 @@ where num_rows: usize, chunk_size: Option, num_values: usize, + is_parent_nullable: bool, ) -> Self { Self { iter, @@ -97,6 +100,7 @@ where chunk_size, rows_remaining: num_rows, values_remaining: num_values, + is_parent_nullable, decoder: NullDecoder {}, } } @@ -117,6 +121,7 @@ where &mut self.values_remaining, &self.init, self.chunk_size, + self.is_parent_nullable, &self.decoder, ); match maybe_state { diff --git a/src/arrow2/src/io/parquet/read/deserialize/primitive/nested.rs b/src/arrow2/src/io/parquet/read/deserialize/primitive/nested.rs index 6b65f693fa..7790467c26 100644 --- 
a/src/arrow2/src/io/parquet/read/deserialize/primitive/nested.rs +++ b/src/arrow2/src/io/parquet/read/deserialize/primitive/nested.rs @@ -87,8 +87,9 @@ where &self, page: &'a DataPage, dict: Option<&'a Self::Dictionary>, + is_parent_nullable: bool, ) -> Result { - let is_optional = + let is_optional = is_parent_nullable || page.descriptor.primitive_type.field_info.repetition == Repetition::Optional; let is_filtered = page.selected_rows().is_some(); @@ -185,6 +186,7 @@ where rows_remaining: usize, chunk_size: Option, values_remaining: usize, + is_parent_nullable: bool, decoder: PrimitiveDecoder, } @@ -196,6 +198,7 @@ where P: ParquetNativeType, F: Copy + Fn(P) -> T, { + #[allow(clippy::too_many_arguments)] pub fn new( iter: I, init: Vec, @@ -203,6 +206,7 @@ where num_rows: usize, chunk_size: Option, num_values: usize, + is_parent_nullable: bool, op: F, ) -> Self { Self { @@ -214,6 +218,7 @@ where chunk_size, rows_remaining: num_rows, values_remaining: num_values, + is_parent_nullable, decoder: PrimitiveDecoder::new(op), } } @@ -238,6 +243,7 @@ where &mut self.values_remaining, &self.init, self.chunk_size, + self.is_parent_nullable, &self.decoder, ); match maybe_state { diff --git a/tests/io/test_parquet.py b/tests/io/test_parquet.py index 8249e291d1..52ff279349 100644 --- a/tests/io/test_parquet.py +++ b/tests/io/test_parquet.py @@ -381,3 +381,64 @@ def test_parquet_limits_across_row_groups(tmpdir, minio_io_config): ) # Reset the target row group size. daft.set_execution_config(parquet_target_row_group_size=default_row_group_size) + + +@pytest.mark.parametrize("optional_outer_struct", [True, False]) +@pytest.mark.parametrize("optional_inner_struct", [True, False]) +def test_parquet_nested_optional_or_required_fields(tmpdir, optional_outer_struct, optional_inner_struct): + schema = pa.schema( + [ + pa.field( + "outer_struct_field", + pa.struct( + [ + pa.field( + "inner_struct_field", + pa.struct( + [ + pa.field("optional_field_str", pa.string()), + pa.field("optional_field_binary", pa.binary()), + pa.field("optional_field_int", pa.int32()), + pa.field("optional_field_bool", pa.bool_()), + pa.field("required_field_str", pa.string(), nullable=False), + pa.field("required_field_binary", pa.binary(), nullable=False), + pa.field("required_field_int", pa.int32(), nullable=False), + pa.field("required_field_bool", pa.bool_(), nullable=False), + ] + ), + nullable=optional_inner_struct, + ) + ] + ), + nullable=optional_outer_struct, + ) + ] + ) + num_records = 8192 + data = [ + { + "outer_struct_field": None + if optional_outer_struct and i % 4 == 0 + else { + "inner_struct_field": None + if optional_inner_struct and i % 5 == 0 + else { + "optional_field_str": f"string_{i}" if i % 3 != 0 else None, + "optional_field_binary": f"binary_{i}".encode() if i % 5 != 0 else None, + "optional_field_int": i if i % 7 != 0 else None, + "optional_field_bool": bool(i % 3) if i % 11 != 0 else None, + "required_field_str": f"string_{i}", + "required_field_binary": f"binary_{i}".encode(), + "required_field_int": i, + "required_field_bool": bool(i % 3), + } + } + } + for i in range(num_records) + ] + expected = pa.Table.from_pylist(data, schema=schema) + output_file = f"{tmpdir}/{uuid.uuid4()!s}.parquet" + papq.write_table(expected, output_file) + expected = MicroPartition.from_arrow(expected) + df = daft.read_parquet(output_file) + assert df.to_arrow() == expected.to_arrow(), f"Expected:\n{expected.to_arrow()}\n\nReceived:\n{df.to_arrow()}" From 855a02d4630c2e73df7cdb58352d03cde3dd3d16 Mon Sep 17 00:00:00 2001 From: 
Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Wed, 18 Dec 2024 18:01:34 +0800 Subject: [PATCH 10/14] fix: guard concurrent extension datatype setting with a lock (#3589) Fixes a bug where concurrent accesses to `_ensure_registered_super_ext_type` might potentially cause race conditions, erroring out on multiple calls to `pa.register_extension_type(DaftExtension(pa.null()))` from different threads. --------- Co-authored-by: Jay Chia --- daft/datatype.py | 43 +++++++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/daft/datatype.py b/daft/datatype.py index b15902c41d..45db65723b 100644 --- a/daft/datatype.py +++ b/daft/datatype.py @@ -1,5 +1,6 @@ from __future__ import annotations +import threading from typing import TYPE_CHECKING, Union from daft.context import get_context @@ -576,6 +577,7 @@ def __hash__(self) -> int: DataTypeLike = Union[DataType, type] +_EXT_TYPE_REGISTRATION_LOCK = threading.Lock() _EXT_TYPE_REGISTERED = False _STATIC_DAFT_EXTENSION = None @@ -583,31 +585,36 @@ def __hash__(self) -> int: def _ensure_registered_super_ext_type(): global _EXT_TYPE_REGISTERED global _STATIC_DAFT_EXTENSION + + # Double-checked locking: avoid grabbing the lock if we know that the ext type + # has already been registered. if not _EXT_TYPE_REGISTERED: + with _EXT_TYPE_REGISTRATION_LOCK: + if not _EXT_TYPE_REGISTERED: - class DaftExtension(pa.ExtensionType): - def __init__(self, dtype, metadata=b""): - # attributes need to be set first before calling - # super init (as that calls serialize) - self._metadata = metadata - super().__init__(dtype, "daft.super_extension") + class DaftExtension(pa.ExtensionType): + def __init__(self, dtype, metadata=b""): + # attributes need to be set first before calling + # super init (as that calls serialize) + self._metadata = metadata + super().__init__(dtype, "daft.super_extension") - def __reduce__(self): - return type(self).__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__()) + def __reduce__(self): + return type(self).__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__()) - def __arrow_ext_serialize__(self): - return self._metadata + def __arrow_ext_serialize__(self): + return self._metadata - @classmethod - def __arrow_ext_deserialize__(cls, storage_type, serialized): - return cls(storage_type, serialized) + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + return cls(storage_type, serialized) - _STATIC_DAFT_EXTENSION = DaftExtension - pa.register_extension_type(DaftExtension(pa.null())) - import atexit + _STATIC_DAFT_EXTENSION = DaftExtension + pa.register_extension_type(DaftExtension(pa.null())) + import atexit - atexit.register(lambda: pa.unregister_extension_type("daft.super_extension")) - _EXT_TYPE_REGISTERED = True + atexit.register(lambda: pa.unregister_extension_type("daft.super_extension")) + _EXT_TYPE_REGISTERED = True def get_super_ext_type(): From 07752b8aa2bde5fc6f7eb00639499fe3b4f8ca3c Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Wed, 18 Dec 2024 18:21:04 +0800 Subject: [PATCH 11/14] ci: Always download logs (#3588) Changes code to always download logs even in the event of a failure Co-authored-by: Jay Chia --- .github/workflows/run-cluster.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/run-cluster.yaml b/.github/workflows/run-cluster.yaml index e0262be2cb..2a6a2a531d 100644 --- a/.github/workflows/run-cluster.yaml +++ 
b/.github/workflows/run-cluster.yaml @@ -115,6 +115,7 @@ jobs: --env-vars='${{ inputs.env_vars }}' \ --enable-ray-tracing - name: Download log files from ray cluster + if: always() run: | source .venv/bin/activate ray rsync-down .github/assets/ray.yaml /tmp/ray/session_*/logs ray-daft-logs @@ -144,6 +145,7 @@ jobs: source .venv/bin/activate ray down .github/assets/ray.yaml -y - name: Upload log files + if: always() uses: actions/upload-artifact@v4 with: name: ray-daft-logs From 660250249f10b0e85f1a2c0ab1369edff4500be7 Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Wed, 18 Dec 2024 15:15:41 -0600 Subject: [PATCH 12/14] refactor: allow InMemory to take in non python based entries (#3554) --- Cargo.lock | 35 +- Cargo.toml | 15 +- src/common/partitioning/Cargo.toml | 17 + src/common/partitioning/src/lib.rs | 167 +++++++++ src/daft-connect/src/lib.rs | 1 + src/daft-connect/src/op/execute/root.rs | 12 +- src/daft-connect/src/op/execute/write.rs | 15 +- src/daft-connect/src/session.rs | 6 + src/daft-connect/src/translation.rs | 2 +- .../src/translation/datatype/codec.rs | 5 +- .../src/translation/logical_plan.rs | 133 ++++--- .../src/translation/logical_plan/aggregate.rs | 73 ++-- .../src/translation/logical_plan/drop.rs | 57 +-- .../src/translation/logical_plan/filter.rs | 30 +- .../logical_plan/local_relation.rs | 108 ++---- .../src/translation/logical_plan/project.rs | 24 +- .../src/translation/logical_plan/range.rs | 79 +++-- .../src/translation/logical_plan/read.rs | 7 +- .../src/translation/logical_plan/to_df.rs | 44 ++- .../translation/logical_plan/with_columns.rs | 45 +-- src/daft-connect/src/translation/schema.rs | 11 +- src/daft-local-execution/src/pipeline.rs | 14 +- src/daft-local-execution/src/run.rs | 28 +- .../src/sources/in_memory.rs | 8 +- src/daft-logical-plan/Cargo.toml | 1 + src/daft-logical-plan/src/builder.rs | 8 +- .../src/optimization/rules/push_down_limit.rs | 14 +- src/daft-logical-plan/src/source_info/mod.rs | 16 +- src/daft-micropartition/Cargo.toml | 20 +- src/daft-micropartition/src/lib.rs | 2 +- src/daft-micropartition/src/micropartition.rs | 1 + src/daft-micropartition/src/partitioning.rs | 328 ++++++++++-------- src/daft-scheduler/Cargo.toml | 1 + src/daft-scheduler/src/adaptive.rs | 3 +- 34 files changed, 783 insertions(+), 547 deletions(-) create mode 100644 src/common/partitioning/Cargo.toml create mode 100644 src/common/partitioning/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 048324ca60..9ad391fad1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -915,7 +915,7 @@ dependencies = [ "serde_json", "time", "url", - "uuid 1.10.0", + "uuid 1.11.0", ] [[package]] @@ -936,7 +936,7 @@ dependencies = [ "time", "tz-rs", "url", - "uuid 1.10.0", + "uuid 1.11.0", ] [[package]] @@ -958,7 +958,7 @@ dependencies = [ "sha2", "time", "url", - "uuid 1.10.0", + "uuid 1.11.0", ] [[package]] @@ -978,7 +978,7 @@ dependencies = [ "serde_json", "time", "url", - "uuid 1.10.0", + "uuid 1.11.0", ] [[package]] @@ -1537,6 +1537,17 @@ dependencies = [ "typetag", ] +[[package]] +name = "common-partitioning" +version = "0.3.0-dev0" +dependencies = [ + "common-error", + "common-py-serde", + "futures", + "pyo3", + "serde", +] + [[package]] name = "common-py-serde" version = "0.3.0-dev0" @@ -1900,6 +1911,7 @@ dependencies = [ "common-display", "common-file-formats", "common-hashable-float-wrapper", + "common-partitioning", "common-resource-request", "common-runtime", "common-scan-info", @@ -2013,7 +2025,7 @@ dependencies = [ "tokio", "tonic", "tracing", - "uuid 1.10.0", + "uuid 1.11.0", 
] [[package]] @@ -2145,7 +2157,7 @@ dependencies = [ "tiktoken-rs", "tokio", "typetag", - "uuid 1.10.0", + "uuid 1.11.0", "xxhash-rust", ] @@ -2338,6 +2350,7 @@ dependencies = [ "common-error", "common-file-formats", "common-io-config", + "common-partitioning", "common-py-serde", "common-resource-request", "common-scan-info", @@ -2357,7 +2370,7 @@ dependencies = [ "serde", "snafu", "test-log", - "uuid 1.10.0", + "uuid 1.11.0", ] [[package]] @@ -2368,6 +2381,7 @@ dependencies = [ "bincode", "common-error", "common-file-formats", + "common-partitioning", "common-runtime", "common-scan-info", "daft-core", @@ -2379,10 +2393,12 @@ dependencies = [ "daft-scan", "daft-stats", "daft-table", + "dashmap", "futures", "parquet2", "pyo3", "snafu", + "tracing", ] [[package]] @@ -2498,6 +2514,7 @@ dependencies = [ "common-error", "common-file-formats", "common-io-config", + "common-partitioning", "common-py-serde", "daft-core", "daft-dsl", @@ -6894,9 +6911,9 @@ dependencies = [ [[package]] name = "uuid" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ "getrandom 0.2.15", "serde", diff --git a/Cargo.toml b/Cargo.toml index ba00e45cd4..b371fd37be 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,6 +3,7 @@ common-daft-config = {path = "src/common/daft-config", default-features = false} common-display = {path = "src/common/display", default-features = false} common-file-formats = {path = "src/common/file-formats", default-features = false} common-hashable-float-wrapper = {path = "src/common/hashable-float-wrapper", default-features = false} +common-partitioning = {path = "src/common/partitioning", default-features = false} common-resource-request = {path = "src/common/resource-request", default-features = false} common-runtime = {path = "src/common/runtime", default-features = false} common-scan-info = {path = "src/common/scan-info", default-features = false} @@ -47,18 +48,12 @@ sysinfo = {workspace = true} # maturin will turn this on python = [ "common-daft-config/python", - "common-daft-config/python", - "common-daft-config/python", - "common-display/python", "common-display/python", - "common-display/python", - "common-resource-request/python", - "common-resource-request/python", + "common-partitioning/python", "common-resource-request/python", + "common-file-formats/python", "common-scan-info/python", "common-system-info/python", - "common-system-info/python", - "common-system-info/python", "daft-catalog-python-catalog/python", "daft-catalog/python", "daft-connect/python", @@ -80,7 +75,6 @@ python = [ "daft-scheduler/python", "daft-sql/python", "daft-stats/python", - "daft-stats/python", "daft-table/python", "daft-writers/python", "dep:daft-catalog-python-catalog", @@ -177,7 +171,8 @@ members = [ "src/daft-connect", "src/parquet2", # "src/spark-connect-script", - "src/generated/spark-connect" + "src/generated/spark-connect", + "src/common/partitioning" ] [workspace.dependencies] diff --git a/src/common/partitioning/Cargo.toml b/src/common/partitioning/Cargo.toml new file mode 100644 index 0000000000..e89aaf6827 --- /dev/null +++ b/src/common/partitioning/Cargo.toml @@ -0,0 +1,17 @@ +[dependencies] +common-py-serde = {path = "../py-serde", optional = true} +pyo3 = {workspace = true, optional = true} +common-error.workspace = true +futures.workspace = true +serde.workspace = true + 
+[features] +python = ["dep:pyo3", "common-error/python", "common-py-serde"] + +[lints] +workspace = true + +[package] +name = "common-partitioning" +edition.workspace = true +version.workspace = true diff --git a/src/common/partitioning/src/lib.rs b/src/common/partitioning/src/lib.rs new file mode 100644 index 0000000000..2df5c8cbb7 --- /dev/null +++ b/src/common/partitioning/src/lib.rs @@ -0,0 +1,167 @@ +use std::{any::Any, sync::Arc}; + +use common_error::DaftResult; +use futures::stream::BoxStream; +use serde::{Deserialize, Serialize}; +#[cfg(feature = "python")] +use { + common_py_serde::{deserialize_py_object, serialize_py_object}, + pyo3::PyObject, +}; + +/// Common trait interface for dataset partitioning, defined in this shared crate to avoid circular dependencies. +/// Acts as a forward declaration for concrete partition implementations. _(Specifically the `MicroPartition` type defined in `daft-micropartition`)_ +pub trait Partition: std::fmt::Debug + Send + Sync { + fn as_any(&self) -> &dyn Any; + fn size_bytes(&self) -> DaftResult>; +} + +impl Partition for Arc +where + T: Partition, +{ + fn as_any(&self) -> &dyn Any { + (**self).as_any() + } + fn size_bytes(&self) -> DaftResult> { + (**self).size_bytes() + } +} + +/// An Arc'd reference to a [`Partition`] +pub type PartitionRef = Arc; + +/// Key used to identify a partition +pub type PartitionId = usize; + +/// ported over from `daft/runners/partitioning.py` +// TODO: port over the rest of the functionality +#[derive(Debug, Clone)] +pub struct PartitionMetadata { + pub num_rows: usize, + pub size_bytes: usize, +} + +/// A partition set is a collection of partitions. +/// It is up to the implementation to decide how to store and manage the partition batches. +/// For example, an in memory partition set could likely be stored as `HashMap>`. +/// +/// It is important to note that the methods do not take `&mut self` but instead take `&self`. +/// So it is up to the implementation to manage any interior mutability. +pub trait PartitionSet: std::fmt::Debug + Send + Sync { + /// Merge all micropartitions into a single micropartition + fn get_merged_partitions(&self) -> DaftResult; + /// Get a preview of the micropartitions + fn get_preview_partitions(&self, num_rows: usize) -> DaftResult>; + /// Number of partitions + fn num_partitions(&self) -> usize; + fn len(&self) -> usize; + /// Check if the partition set is empty + fn is_empty(&self) -> bool { + self.len() == 0 + } + /// Size of the partition set in bytes + fn size_bytes(&self) -> DaftResult; + /// Check if a partition exists + fn has_partition(&self, idx: &PartitionId) -> bool; + /// Delete a partition + fn delete_partition(&self, idx: &PartitionId) -> DaftResult<()>; + /// Set a partition + fn set_partition(&self, idx: PartitionId, part: &T) -> DaftResult<()>; + /// Get a partition + fn get_partition(&self, idx: &PartitionId) -> DaftResult; + /// Consume the partition set and return a stream of partitions + fn to_partition_stream(&self) -> BoxStream<'static, DaftResult>; + fn metadata(&self) -> PartitionMetadata; +} + +impl PartitionSet

<P> for Arc<PS> +where + P: Partition + Clone, + PS: PartitionSet<P>
+ Clone, +{ + fn get_merged_partitions(&self) -> DaftResult<PartitionRef> { + PS::get_merged_partitions(self) + } + + fn get_preview_partitions(&self, num_rows: usize) -> DaftResult<Vec<P>> { + PS::get_preview_partitions(self, num_rows) + } + + fn num_partitions(&self) -> usize { + PS::num_partitions(self) + } + + fn len(&self) -> usize { + PS::len(self) + } + + fn size_bytes(&self) -> DaftResult<usize> { + PS::size_bytes(self) + } + + fn has_partition(&self, idx: &PartitionId) -> bool { + PS::has_partition(self, idx) + } + + fn delete_partition(&self, idx: &PartitionId) -> DaftResult<()> { + PS::delete_partition(self, idx) + } + + fn set_partition(&self, idx: PartitionId, part: &P) -> DaftResult<()> { + PS::set_partition(self, idx, part) + } + + fn get_partition(&self, idx: &PartitionId) -> DaftResult<P> { + PS::get_partition(self, idx) + } + + fn to_partition_stream(&self) -> BoxStream<'static, DaftResult<P>
> { + PS::to_partition_stream(self) + } + + fn metadata(&self) -> PartitionMetadata { + PS::metadata(self) + } +} + +pub type PartitionSetRef = Arc>; + +pub trait PartitionSetCache>: + std::fmt::Debug + Send + Sync +{ + fn get_partition_set(&self, key: &str) -> Option>; + fn get_all_partition_sets(&self) -> Vec>; + fn put_partition_set(&self, key: &str, partition_set: &PS); + fn rm_partition_set(&self, key: &str); + fn clear(&self); +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum PartitionCacheEntry { + #[serde( + serialize_with = "serialize_py_object", + deserialize_with = "deserialize_py_object" + )] + #[cfg(feature = "python")] + /// in python, the partition cache is a weakvalue dictionary, so it will store the entry as long as this reference exists. + Python(PyObject), + + Rust { + key: String, + #[serde(skip)] + /// We don't ever actually reference the value, we're just holding it to ensure the partition set is kept alive. + /// + /// It's only wrapped in an `Option` to satisfy serde Deserialize. We skip (de)serializing, but serde still complains if it's not an Option. + value: Option>, + }, +} + +impl PartitionCacheEntry { + pub fn new_rust(key: String, value: Arc) -> Self { + Self::Rust { + key, + value: Some(value), + } + } +} diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs index 369bfe8e47..1371421396 100644 --- a/src/daft-connect/src/lib.rs +++ b/src/daft-connect/src/lib.rs @@ -30,6 +30,7 @@ use crate::session::Session; mod config; mod err; mod op; + mod session; mod translation; pub mod util; diff --git a/src/daft-connect/src/op/execute/root.rs b/src/daft-connect/src/op/execute/root.rs index fb18318708..4a9feb4ed7 100644 --- a/src/daft-connect/src/op/execute/root.rs +++ b/src/daft-connect/src/op/execute/root.rs @@ -10,7 +10,6 @@ use crate::{ op::execute::{ExecuteStream, PlanIds}, session::Session, translation, - translation::Plan, }; impl Session { @@ -30,19 +29,24 @@ impl Session { let finished = context.finished(); let (tx, rx) = tokio::sync::mpsc::channel::>(1); + + let pset = self.psets.clone(); + tokio::spawn(async move { let execution_fut = async { - let Plan { builder, psets } = translation::to_logical_plan(command).await?; + let translator = translation::SparkAnalyzer::new(&pset); + let lp = translator.to_logical_plan(command).await?; // todo: convert optimize to async (looks like A LOT of work)... 
it touches a lot of API // I tried and spent about an hour and gave up ~ Andrew Gazelka 🪦 2024-12-09 - let optimized_plan = tokio::task::spawn_blocking(move || builder.optimize()) + let optimized_plan = tokio::task::spawn_blocking(move || lp.optimize()) .await .unwrap()?; let cfg = Arc::new(DaftExecutionConfig::default()); let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?; - let mut result_stream = native_executor.run(psets, cfg, None)?.into_stream(); + + let mut result_stream = native_executor.run(&pset, cfg, None)?.into_stream(); while let Some(result) = result_stream.next().await { let result = result?; diff --git a/src/daft-connect/src/op/execute/write.rs b/src/daft-connect/src/op/execute/write.rs index 44696f8164..5db783f5e1 100644 --- a/src/daft-connect/src/op/execute/write.rs +++ b/src/daft-connect/src/op/execute/write.rs @@ -32,6 +32,7 @@ impl Session { }; let finished = context.finished(); + let pset = self.psets.clone(); let result = async move { let WriteOperation { @@ -109,19 +110,19 @@ impl Session { } }; - let mut plan = translation::to_logical_plan(input).await?; + let translator = translation::SparkAnalyzer::new(&pset); - plan.builder = plan - .builder + let plan = translator.to_logical_plan(input).await?; + + let plan = plan .table_write(&path, FileFormat::Parquet, None, None, None) .wrap_err("Failed to create table write plan")?; - let optimized_plan = plan.builder.optimize()?; + let optimized_plan = plan.optimize()?; let cfg = DaftExecutionConfig::default(); let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?; - let mut result_stream = native_executor - .run(plan.psets, cfg.into(), None)? - .into_stream(); + + let mut result_stream = native_executor.run(&pset, cfg.into(), None)?.into_stream(); // this is so we make sure the operation is actually done // before we return diff --git a/src/daft-connect/src/session.rs b/src/daft-connect/src/session.rs index 24f7fabe80..30f827ba9e 100644 --- a/src/daft-connect/src/session.rs +++ b/src/daft-connect/src/session.rs @@ -1,5 +1,6 @@ use std::collections::BTreeMap; +use daft_micropartition::partitioning::InMemoryPartitionSetCache; use uuid::Uuid; pub struct Session { @@ -10,6 +11,9 @@ pub struct Session { id: String, server_side_session_id: String, + /// MicroPartitionSet associated with this session + /// this will be filled up as the user runs queries + pub(crate) psets: InMemoryPartitionSetCache, } impl Session { @@ -24,10 +28,12 @@ impl Session { pub fn new(id: String) -> Self { let server_side_session_id = Uuid::new_v4(); let server_side_session_id = server_side_session_id.to_string(); + Self { config_values: Default::default(), id, server_side_session_id, + psets: InMemoryPartitionSetCache::empty(), } } diff --git a/src/daft-connect/src/translation.rs b/src/daft-connect/src/translation.rs index 8b61b93f98..5d9bf89881 100644 --- a/src/daft-connect/src/translation.rs +++ b/src/daft-connect/src/translation.rs @@ -9,5 +9,5 @@ mod schema; pub use datatype::{deser_spark_datatype, to_daft_datatype, to_spark_datatype}; pub use expr::to_daft_expr; pub use literal::to_daft_literal; -pub use logical_plan::{to_logical_plan, Plan}; +pub use logical_plan::SparkAnalyzer; pub use schema::relation_to_schema; diff --git a/src/daft-connect/src/translation/datatype/codec.rs b/src/daft-connect/src/translation/datatype/codec.rs index 50f2d94a02..4f554765ba 100644 --- a/src/daft-connect/src/translation/datatype/codec.rs +++ b/src/daft-connect/src/translation/datatype/codec.rs @@ -2,7 
+2,6 @@ use color_eyre::Help; use eyre::{bail, ensure, eyre}; use serde_json::Value; use spark_connect::data_type::Kind; -use tracing::warn; #[derive(Debug)] enum TypeTag { @@ -211,12 +210,10 @@ fn deser_struct_field( bail!("expected object"); }; - let Some(metadata) = object.remove("metadata") else { + let Some(_metadata) = object.remove("metadata") else { bail!("missing metadata"); }; - warn!("ignoring metadata: {metadata:?}"); - let Some(name) = object.remove("name") else { bail!("missing name"); }; diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index b6097d17ad..15eb495502 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -1,14 +1,9 @@ use daft_logical_plan::LogicalPlanBuilder; -use daft_micropartition::partitioning::InMemoryPartitionSet; +use daft_micropartition::partitioning::InMemoryPartitionSetCache; use eyre::{bail, Context}; use spark_connect::{relation::RelType, Limit, Relation}; use tracing::warn; -use crate::translation::logical_plan::{ - aggregate::aggregate, drop::drop, filter::filter, local_relation::local_relation, - project::project, range::range, read::read, to_df::to_df, with_columns::with_columns, -}; - mod aggregate; mod drop; mod filter; @@ -19,82 +14,86 @@ mod read; mod to_df; mod with_columns; -pub struct Plan { - pub builder: LogicalPlanBuilder, - pub psets: InMemoryPartitionSet, +pub struct SparkAnalyzer<'a> { + pub psets: &'a InMemoryPartitionSetCache, } -impl Plan { - pub fn new(builder: LogicalPlanBuilder) -> Self { - Self { - builder, - psets: InMemoryPartitionSet::default(), - } +impl SparkAnalyzer<'_> { + pub fn new(pset: &InMemoryPartitionSetCache) -> SparkAnalyzer { + SparkAnalyzer { psets: pset } } -} -impl From for Plan { - fn from(builder: LogicalPlanBuilder) -> Self { - Self { - builder, - psets: InMemoryPartitionSet::default(), - } - } -} + pub async fn to_logical_plan(&self, relation: Relation) -> eyre::Result { + let Some(common) = relation.common else { + bail!("Common metadata is required"); + }; -pub async fn to_logical_plan(relation: Relation) -> eyre::Result { - if let Some(common) = relation.common { if common.origin.is_some() { warn!("Ignoring common metadata for relation: {common:?}; not yet implemented"); } - }; - let Some(rel_type) = relation.rel_type else { - bail!("Relation type is required"); - }; + let Some(rel_type) = relation.rel_type else { + bail!("Relation type is required"); + }; - match rel_type { - RelType::Limit(l) => limit(*l) - .await - .wrap_err("Failed to apply limit to logical plan"), - RelType::Range(r) => range(r).wrap_err("Failed to apply range to logical plan"), - RelType::Project(p) => project(*p) - .await - .wrap_err("Failed to apply project to logical plan"), - RelType::Filter(f) => filter(*f) - .await - .wrap_err("Failed to apply filter to logical plan"), - RelType::Aggregate(a) => aggregate(*a) - .await - .wrap_err("Failed to apply aggregate to logical plan"), - RelType::WithColumns(w) => with_columns(*w) - .await - .wrap_err("Failed to apply with_columns to logical plan"), - RelType::ToDf(t) => to_df(*t) - .await - .wrap_err("Failed to apply to_df to logical plan"), - RelType::LocalRelation(l) => { - local_relation(l).wrap_err("Failed to apply local_relation to logical plan") + match rel_type { + RelType::Limit(l) => self + .limit(*l) + .await + .wrap_err("Failed to apply limit to logical plan"), + RelType::Range(r) => self + .range(r) + .wrap_err("Failed to apply range to 
logical plan"), + RelType::Project(p) => self + .project(*p) + .await + .wrap_err("Failed to apply project to logical plan"), + RelType::Aggregate(a) => self + .aggregate(*a) + .await + .wrap_err("Failed to apply aggregate to logical plan"), + RelType::WithColumns(w) => self + .with_columns(*w) + .await + .wrap_err("Failed to apply with_columns to logical plan"), + RelType::ToDf(t) => self + .to_df(*t) + .await + .wrap_err("Failed to apply to_df to logical plan"), + RelType::LocalRelation(l) => { + let Some(plan_id) = common.plan_id else { + bail!("Plan ID is required for LocalRelation"); + }; + self.local_relation(plan_id, l) + .wrap_err("Failed to apply local_relation to logical plan") + } + RelType::Read(r) => read::read(r) + .await + .wrap_err("Failed to apply read to logical plan"), + RelType::Drop(d) => self + .drop(*d) + .await + .wrap_err("Failed to apply drop to logical plan"), + RelType::Filter(f) => self + .filter(*f) + .await + .wrap_err("Failed to apply filter to logical plan"), + plan => bail!("Unsupported relation type: {plan:?}"), } - RelType::Read(r) => read(r) - .await - .wrap_err("Failed to apply read to logical plan"), - RelType::Drop(d) => drop(*d) - .await - .wrap_err("Failed to apply drop to logical plan"), - plan => bail!("Unsupported relation type: {plan:?}"), } } -async fn limit(limit: Limit) -> eyre::Result { - let Limit { input, limit } = limit; +impl SparkAnalyzer<'_> { + async fn limit(&self, limit: Limit) -> eyre::Result { + let Limit { input, limit } = limit; - let Some(input) = input else { - bail!("input must be set"); - }; + let Some(input) = input else { + bail!("input must be set"); + }; - let mut plan = Box::pin(to_logical_plan(*input)).await?; - plan.builder = plan.builder.limit(i64::from(limit), false)?; // todo: eager or no + let plan = Box::pin(self.to_logical_plan(*input)).await?; - Ok(plan) + plan.limit(i64::from(limit), false) + .wrap_err("Failed to apply limit to logical plan") + } } diff --git a/src/daft-connect/src/translation/logical_plan/aggregate.rs b/src/daft-connect/src/translation/logical_plan/aggregate.rs index 3687f191f8..2a46b0cbba 100644 --- a/src/daft-connect/src/translation/logical_plan/aggregate.rs +++ b/src/daft-connect/src/translation/logical_plan/aggregate.rs @@ -1,52 +1,59 @@ +use daft_logical_plan::LogicalPlanBuilder; use eyre::{bail, WrapErr}; use spark_connect::aggregate::GroupType; -use crate::translation::{logical_plan::Plan, to_daft_expr, to_logical_plan}; +use super::SparkAnalyzer; +use crate::translation::to_daft_expr; -pub async fn aggregate(aggregate: spark_connect::Aggregate) -> eyre::Result { - let spark_connect::Aggregate { - input, - group_type, - grouping_expressions, - aggregate_expressions, - pivot, - grouping_sets, - } = aggregate; +impl SparkAnalyzer<'_> { + pub async fn aggregate( + &self, + aggregate: spark_connect::Aggregate, + ) -> eyre::Result { + let spark_connect::Aggregate { + input, + group_type, + grouping_expressions, + aggregate_expressions, + pivot, + grouping_sets, + } = aggregate; - let Some(input) = input else { - bail!("input is required"); - }; + let Some(input) = input else { + bail!("input is required"); + }; - let mut plan = Box::pin(to_logical_plan(*input)).await?; + let mut plan = Box::pin(self.to_logical_plan(*input)).await?; - let group_type = GroupType::try_from(group_type) - .wrap_err_with(|| format!("Invalid group type: {group_type:?}"))?; + let group_type = GroupType::try_from(group_type) + .wrap_err_with(|| format!("Invalid group type: {group_type:?}"))?; - 
assert_groupby(group_type)?; + assert_groupby(group_type)?; - if let Some(pivot) = pivot { - bail!("Pivot not yet supported; got {pivot:?}"); - } + if let Some(pivot) = pivot { + bail!("Pivot not yet supported; got {pivot:?}"); + } - if !grouping_sets.is_empty() { - bail!("Grouping sets not yet supported; got {grouping_sets:?}"); - } + if !grouping_sets.is_empty() { + bail!("Grouping sets not yet supported; got {grouping_sets:?}"); + } - let grouping_expressions: Vec<_> = grouping_expressions - .iter() - .map(to_daft_expr) - .try_collect()?; + let grouping_expressions: Vec<_> = grouping_expressions + .iter() + .map(to_daft_expr) + .try_collect()?; - let aggregate_expressions: Vec<_> = aggregate_expressions - .iter() - .map(to_daft_expr) - .try_collect()?; + let aggregate_expressions: Vec<_> = aggregate_expressions + .iter() + .map(to_daft_expr) + .try_collect()?; - plan.builder = plan.builder + plan = plan .aggregate(aggregate_expressions.clone(), grouping_expressions.clone()) .wrap_err_with(|| format!("Failed to apply aggregate to logical plan aggregate_expressions={aggregate_expressions:?} grouping_expressions={grouping_expressions:?}"))?; - Ok(plan) + Ok(plan) + } } fn assert_groupby(plan: GroupType) -> eyre::Result<()> { diff --git a/src/daft-connect/src/translation/logical_plan/drop.rs b/src/daft-connect/src/translation/logical_plan/drop.rs index 35613add68..b5cac5a41b 100644 --- a/src/daft-connect/src/translation/logical_plan/drop.rs +++ b/src/daft-connect/src/translation/logical_plan/drop.rs @@ -1,39 +1,40 @@ +use daft_logical_plan::LogicalPlanBuilder; use eyre::bail; -use crate::translation::{to_logical_plan, Plan}; +use super::SparkAnalyzer; -pub async fn drop(drop: spark_connect::Drop) -> eyre::Result { - let spark_connect::Drop { - input, - columns, - column_names, - } = drop; +impl SparkAnalyzer<'_> { + pub async fn drop(&self, drop: spark_connect::Drop) -> eyre::Result { + let spark_connect::Drop { + input, + columns, + column_names, + } = drop; - let Some(input) = input else { - bail!("input is required"); - }; + let Some(input) = input else { + bail!("input is required"); + }; - if !columns.is_empty() { - bail!("columns is not supported; use column_names instead"); - } - - let mut plan = Box::pin(to_logical_plan(*input)).await?; + if !columns.is_empty() { + bail!("columns is not supported; use column_names instead"); + } - // Get all column names from the schema - let all_columns = plan.builder.schema().names(); + let plan = Box::pin(self.to_logical_plan(*input)).await?; - // Create a set of columns to drop for efficient lookup - let columns_to_drop: std::collections::HashSet<_> = column_names.iter().collect(); + // Get all column names from the schema + let all_columns = plan.schema().names(); - // Create expressions for all columns except the ones being dropped - let to_select = all_columns - .iter() - .filter(|col_name| !columns_to_drop.contains(*col_name)) - .map(|col_name| daft_dsl::col(col_name.clone())) - .collect(); + // Create a set of columns to drop for efficient lookup + let columns_to_drop: std::collections::HashSet<_> = column_names.iter().collect(); - // Use select to keep only the columns we want - plan.builder = plan.builder.select(to_select)?; + // Create expressions for all columns except the ones being dropped + let to_select = all_columns + .iter() + .filter(|col_name| !columns_to_drop.contains(*col_name)) + .map(|col_name| daft_dsl::col(col_name.clone())) + .collect(); - Ok(plan) + // Use select to keep only the columns we want + 
Ok(plan.select(to_select)?) + } } diff --git a/src/daft-connect/src/translation/logical_plan/filter.rs b/src/daft-connect/src/translation/logical_plan/filter.rs index 6879464abc..43ad4c7a52 100644 --- a/src/daft-connect/src/translation/logical_plan/filter.rs +++ b/src/daft-connect/src/translation/logical_plan/filter.rs @@ -1,22 +1,24 @@ +use daft_logical_plan::LogicalPlanBuilder; use eyre::bail; -use crate::translation::{to_daft_expr, to_logical_plan, Plan}; +use super::SparkAnalyzer; +use crate::translation::to_daft_expr; -pub async fn filter(filter: spark_connect::Filter) -> eyre::Result { - let spark_connect::Filter { input, condition } = filter; +impl SparkAnalyzer<'_> { + pub async fn filter(&self, filter: spark_connect::Filter) -> eyre::Result { + let spark_connect::Filter { input, condition } = filter; - let Some(input) = input else { - bail!("input is required"); - }; + let Some(input) = input else { + bail!("input is required"); + }; - let Some(condition) = condition else { - bail!("condition is required"); - }; + let Some(condition) = condition else { + bail!("condition is required"); + }; - let condition = to_daft_expr(&condition)?; + let condition = to_daft_expr(&condition)?; - let mut plan = Box::pin(to_logical_plan(*input)).await?; - plan.builder = plan.builder.filter(condition)?; - - Ok(plan) + let plan = Box::pin(self.to_logical_plan(*input)).await?; + Ok(plan.filter(condition)?) + } } diff --git a/src/daft-connect/src/translation/logical_plan/local_relation.rs b/src/daft-connect/src/translation/logical_plan/local_relation.rs index 7244ed09b9..574e35a2fd 100644 --- a/src/daft-connect/src/translation/logical_plan/local_relation.rs +++ b/src/daft-connect/src/translation/logical_plan/local_relation.rs @@ -1,32 +1,28 @@ -use std::{collections::HashMap, io::Cursor, sync::Arc}; +use std::{io::Cursor, sync::Arc}; use arrow2::io::ipc::{ read::{StreamMetadata, StreamReader, StreamState, Version}, IpcField, IpcSchema, }; use daft_core::series::Series; -use daft_logical_plan::{ - logical_plan::Source, InMemoryInfo, LogicalPlan, LogicalPlanBuilder, PyLogicalPlanBuilder, - SourceInfo, +use daft_logical_plan::LogicalPlanBuilder; +use daft_micropartition::partitioning::{ + MicroPartitionSet, PartitionCacheEntry, PartitionMetadata, PartitionSet, PartitionSetCache, }; -use daft_micropartition::partitioning::InMemoryPartitionSet; use daft_schema::dtype::DaftDataType; use daft_table::Table; use eyre::{bail, ensure, WrapErr}; use itertools::Itertools; -use crate::translation::{deser_spark_datatype, logical_plan::Plan, to_daft_datatype}; +use super::SparkAnalyzer; +use crate::translation::{deser_spark_datatype, to_daft_datatype}; -pub fn local_relation(plan: spark_connect::LocalRelation) -> eyre::Result { - #[cfg(not(feature = "python"))] - { - bail!("LocalRelation plan is only supported in Python mode"); - } - - #[cfg(feature = "python")] - { - use daft_micropartition::{python::PyMicroPartition, MicroPartition}; - use pyo3::{types::PyAnyMethods, Python}; +impl SparkAnalyzer<'_> { + pub fn local_relation( + &self, + plan_id: i64, + plan: spark_connect::LocalRelation, + ) -> eyre::Result { let spark_connect::LocalRelation { data, schema } = plan; let Some(data) = data else { @@ -139,67 +135,33 @@ pub fn local_relation(plan: spark_connect::LocalRelation) -> eyre::Result "Mismatch in row counts across columns; all columns must have the same number of rows." 
); - let Some(&num_rows) = num_rows.first() else { - bail!("No columns were found; at least one column is required.") - }; + let batch = Table::from_nonempty_columns(columns)?; - let table = Table::new_with_size(daft_schema.clone(), columns, num_rows) - .wrap_err("Failed to create Table from columns and schema.")?; - - tables.push(table); + tables.push(batch); } tables }; - // Note: Verify if the Daft schema used here matches the schema of the table. - let micro_partition = MicroPartition::new_loaded(daft_schema, Arc::new(tables), None); - let micro_partition = Arc::new(micro_partition); - - let plan = Python::with_gil(|py| { - // Convert MicroPartition to a logical plan using Python interop. - let py_micropartition = py - .import_bound(pyo3::intern!(py, "daft.table"))? - .getattr(pyo3::intern!(py, "MicroPartition"))? - .getattr(pyo3::intern!(py, "_from_pymicropartition"))? - .call1((PyMicroPartition::from(micro_partition.clone()),))?; - - // ERROR: 2: AttributeError: 'daft.daft.PySchema' object has no attribute '_schema' - let py_plan_builder = py - .import_bound(pyo3::intern!(py, "daft.dataframe.dataframe"))? - .getattr(pyo3::intern!(py, "to_logical_plan_builder"))? - .call1((py_micropartition,))?; - - let py_plan_builder = py_plan_builder.getattr(pyo3::intern!(py, "_builder"))?; - - let plan: PyLogicalPlanBuilder = py_plan_builder.extract()?; - - Ok::<_, eyre::Error>(plan.builder) - })?; - - let cache_key = grab_singular_cache_key(&plan)?; - - let mut psets = HashMap::new(); - psets.insert(cache_key, vec![micro_partition]); - - let plan = Plan { - builder: plan, - psets: InMemoryPartitionSet::new(psets), - }; - - Ok(plan) + let pset = MicroPartitionSet::from_tables(plan_id as usize, tables)?; + let PartitionMetadata { + size_bytes, + num_rows, + } = pset.metadata(); + let num_partitions = pset.num_partitions(); + + let partition_key: Arc = uuid::Uuid::new_v4().to_string().into(); + let pset = Arc::new(pset); + self.psets.put_partition_set(&partition_key, &pset); + + let lp = LogicalPlanBuilder::in_memory_scan( + &partition_key, + PartitionCacheEntry::new_rust(partition_key.to_string(), pset), + daft_schema, + num_partitions, + size_bytes, + num_rows, + )?; + + Ok(lp) } } - -fn grab_singular_cache_key(plan: &LogicalPlanBuilder) -> eyre::Result { - let plan = &*plan.plan; - - let LogicalPlan::Source(Source { source_info, .. }) = plan else { - bail!("Expected a source plan"); - }; - - let SourceInfo::InMemory(InMemoryInfo { cache_key, .. }) = &**source_info else { - bail!("Expected an in-memory source"); - }; - - Ok(cache_key.clone()) -} diff --git a/src/daft-connect/src/translation/logical_plan/project.rs b/src/daft-connect/src/translation/logical_plan/project.rs index af03c8dc2e..448242d31d 100644 --- a/src/daft-connect/src/translation/logical_plan/project.rs +++ b/src/daft-connect/src/translation/logical_plan/project.rs @@ -3,22 +3,26 @@ //! TL;DR: Project is Spark's equivalent of SQL SELECT - it selects columns, renames them via aliases, //! and creates new columns from expressions. 
Example: `df.select(col("id").alias("my_number"))` +use daft_logical_plan::LogicalPlanBuilder; use eyre::bail; use spark_connect::Project; -use crate::translation::{logical_plan::Plan, to_daft_expr, to_logical_plan}; +use super::SparkAnalyzer; +use crate::translation::to_daft_expr; -pub async fn project(project: Project) -> eyre::Result { - let Project { input, expressions } = project; +impl SparkAnalyzer<'_> { + pub async fn project(&self, project: Project) -> eyre::Result { + let Project { input, expressions } = project; - let Some(input) = input else { - bail!("Project input is required"); - }; + let Some(input) = input else { + bail!("Project input is required"); + }; - let mut plan = Box::pin(to_logical_plan(*input)).await?; + let mut plan = Box::pin(self.to_logical_plan(*input)).await?; - let daft_exprs: Vec<_> = expressions.iter().map(to_daft_expr).try_collect()?; - plan.builder = plan.builder.select(daft_exprs)?; + let daft_exprs: Vec<_> = expressions.iter().map(to_daft_expr).try_collect()?; + plan = plan.select(daft_exprs)?; - Ok(plan) + Ok(plan) + } } diff --git a/src/daft-connect/src/translation/logical_plan/range.rs b/src/daft-connect/src/translation/logical_plan/range.rs index ff15e0cacb..1660bef5bf 100644 --- a/src/daft-connect/src/translation/logical_plan/range.rs +++ b/src/daft-connect/src/translation/logical_plan/range.rs @@ -2,56 +2,59 @@ use daft_logical_plan::LogicalPlanBuilder; use eyre::{ensure, Context}; use spark_connect::Range; -use crate::translation::logical_plan::Plan; +use super::SparkAnalyzer; -pub fn range(range: Range) -> eyre::Result { - #[cfg(not(feature = "python"))] - { - use eyre::bail; - bail!("Range operations require Python feature to be enabled"); - } +impl SparkAnalyzer<'_> { + pub fn range(&self, range: Range) -> eyre::Result { + #[cfg(not(feature = "python"))] + { + use eyre::bail; + bail!("Range operations require Python feature to be enabled"); + } - #[cfg(feature = "python")] - { - use daft_scan::python::pylib::ScanOperatorHandle; - use pyo3::prelude::*; - let Range { - start, - end, - step, - num_partitions, - } = range; + #[cfg(feature = "python")] + { + use daft_scan::python::pylib::ScanOperatorHandle; + use pyo3::prelude::*; + let Range { + start, + end, + step, + num_partitions, + } = range; - let partitions = num_partitions.unwrap_or(1); + let partitions = num_partitions.unwrap_or(1); - ensure!(partitions > 0, "num_partitions must be greater than 0"); + ensure!(partitions > 0, "num_partitions must be greater than 0"); - let start = start.unwrap_or(0); + let start = start.unwrap_or(0); - let step = usize::try_from(step).wrap_err("step must be a positive integer")?; - ensure!(step > 0, "step must be greater than 0"); + let step = usize::try_from(step).wrap_err("step must be a positive integer")?; + ensure!(step > 0, "step must be greater than 0"); - let plan = Python::with_gil(|py| { - let range_module = PyModule::import_bound(py, "daft.io._range") - .wrap_err("Failed to import range module")?; + let plan = Python::with_gil(|py| { + let range_module = PyModule::import_bound(py, "daft.io._range") + .wrap_err("Failed to import range module")?; - let range = range_module - .getattr(pyo3::intern!(py, "RangeScanOperator")) - .wrap_err("Failed to get range function")?; + let range = range_module + .getattr(pyo3::intern!(py, "RangeScanOperator")) + .wrap_err("Failed to get range function")?; - let range = range - .call1((start, end, step, partitions)) - .wrap_err("Failed to create range scan operator")? 
- .to_object(py); + let range = range + .call1((start, end, step, partitions)) + .wrap_err("Failed to create range scan operator")? + .to_object(py); - let scan_operator_handle = ScanOperatorHandle::from_python_scan_operator(range, py)?; + let scan_operator_handle = + ScanOperatorHandle::from_python_scan_operator(range, py)?; - let plan = LogicalPlanBuilder::table_scan(scan_operator_handle.into(), None)?; + let plan = LogicalPlanBuilder::table_scan(scan_operator_handle.into(), None)?; - eyre::Result::<_>::Ok(plan) - }) - .wrap_err("Failed to create range scan")?; + eyre::Result::<_>::Ok(plan) + }) + .wrap_err("Failed to create range scan")?; - Ok(plan.into()) + Ok(plan) + } } } diff --git a/src/daft-connect/src/translation/logical_plan/read.rs b/src/daft-connect/src/translation/logical_plan/read.rs index fc8a834fbb..9a73783191 100644 --- a/src/daft-connect/src/translation/logical_plan/read.rs +++ b/src/daft-connect/src/translation/logical_plan/read.rs @@ -1,12 +1,11 @@ +use daft_logical_plan::LogicalPlanBuilder; use eyre::{bail, WrapErr}; use spark_connect::read::ReadType; use tracing::warn; -use crate::translation::Plan; - mod data_source; -pub async fn read(read: spark_connect::Read) -> eyre::Result { +pub async fn read(read: spark_connect::Read) -> eyre::Result { let spark_connect::Read { is_streaming, read_type, @@ -28,5 +27,5 @@ pub async fn read(read: spark_connect::Read) -> eyre::Result { .wrap_err("Failed to create data source"), }?; - Ok(Plan::from(builder)) + Ok(builder) } diff --git a/src/daft-connect/src/translation/logical_plan/to_df.rs b/src/daft-connect/src/translation/logical_plan/to_df.rs index c2a355a1e5..e3d172661b 100644 --- a/src/daft-connect/src/translation/logical_plan/to_df.rs +++ b/src/daft-connect/src/translation/logical_plan/to_df.rs @@ -1,30 +1,28 @@ +use daft_logical_plan::LogicalPlanBuilder; use eyre::{bail, WrapErr}; -use crate::translation::{logical_plan::Plan, to_logical_plan}; +use super::SparkAnalyzer; +impl SparkAnalyzer<'_> { + pub async fn to_df(&self, to_df: spark_connect::ToDf) -> eyre::Result { + let spark_connect::ToDf { + input, + column_names, + } = to_df; -pub async fn to_df(to_df: spark_connect::ToDf) -> eyre::Result { - let spark_connect::ToDf { - input, - column_names, - } = to_df; + let Some(input) = input else { + bail!("Input is required"); + }; - let Some(input) = input else { - bail!("Input is required"); - }; + let mut plan = Box::pin(self.to_logical_plan(*input)).await?; - let mut plan = Box::pin(to_logical_plan(*input)) - .await - .wrap_err("Failed to translate relation to logical plan")?; + let column_names: Vec<_> = column_names + .iter() + .map(|s| daft_dsl::col(s.as_str())) + .collect(); - let column_names: Vec<_> = column_names - .iter() - .map(|s| daft_dsl::col(s.as_str())) - .collect(); - - plan.builder = plan - .builder - .select(column_names) - .wrap_err("Failed to add columns to logical plan")?; - - Ok(plan) + plan = plan + .select(column_names) + .wrap_err("Failed to add columns to logical plan")?; + Ok(plan) + } } diff --git a/src/daft-connect/src/translation/logical_plan/with_columns.rs b/src/daft-connect/src/translation/logical_plan/with_columns.rs index 08396ecdba..97b3c3d1d1 100644 --- a/src/daft-connect/src/translation/logical_plan/with_columns.rs +++ b/src/daft-connect/src/translation/logical_plan/with_columns.rs @@ -1,30 +1,35 @@ +use daft_logical_plan::LogicalPlanBuilder; use eyre::bail; use spark_connect::{expression::ExprType, Expression}; -use crate::translation::{to_daft_expr, to_logical_plan, Plan}; +use 
super::SparkAnalyzer; +use crate::translation::to_daft_expr; -pub async fn with_columns(with_columns: spark_connect::WithColumns) -> eyre::Result { - let spark_connect::WithColumns { input, aliases } = with_columns; +impl SparkAnalyzer<'_> { + pub async fn with_columns( + &self, + with_columns: spark_connect::WithColumns, + ) -> eyre::Result { + let spark_connect::WithColumns { input, aliases } = with_columns; - let Some(input) = input else { - bail!("input is required"); - }; + let Some(input) = input else { + bail!("input is required"); + }; - let mut plan = Box::pin(to_logical_plan(*input)).await?; + let plan = Box::pin(self.to_logical_plan(*input)).await?; - let daft_exprs: Vec<_> = aliases - .into_iter() - .map(|alias| { - let expression = Expression { - common: None, - expr_type: Some(ExprType::Alias(Box::new(alias))), - }; + let daft_exprs: Vec<_> = aliases + .into_iter() + .map(|alias| { + let expression = Expression { + common: None, + expr_type: Some(ExprType::Alias(Box::new(alias))), + }; - to_daft_expr(&expression) - }) - .try_collect()?; + to_daft_expr(&expression) + }) + .try_collect()?; - plan.builder = plan.builder.with_columns(daft_exprs)?; - - Ok(plan) + Ok(plan.with_columns(daft_exprs)?) + } } diff --git a/src/daft-connect/src/translation/schema.rs b/src/daft-connect/src/translation/schema.rs index 1868eaeb2d..605f1b640a 100644 --- a/src/daft-connect/src/translation/schema.rs +++ b/src/daft-connect/src/translation/schema.rs @@ -1,10 +1,12 @@ +use daft_micropartition::partitioning::InMemoryPartitionSetCache; use spark_connect::{ data_type::{Kind, Struct, StructField}, DataType, Relation, }; use tracing::warn; -use crate::translation::{to_logical_plan, to_spark_datatype}; +use super::SparkAnalyzer; +use crate::translation::to_spark_datatype; #[tracing::instrument(skip_all)] pub async fn relation_to_schema(input: Relation) -> eyre::Result { @@ -14,9 +16,12 @@ pub async fn relation_to_schema(input: Relation) -> eyre::Result { } } - let plan = Box::pin(to_logical_plan(input)).await?; + // We're just checking the schema here, so we don't need to use a persistent cache as it won't be used + let pset = InMemoryPartitionSetCache::empty(); + let translator = SparkAnalyzer::new(&pset); + let plan = Box::pin(translator.to_logical_plan(input)).await?; - let result = plan.builder.schema(); + let result = plan.schema(); let fields: eyre::Result> = result .fields diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index 8881cb37de..d13277e673 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -17,7 +17,10 @@ use daft_local_plan::{ Project, Sample, Sort, UnGroupedAggregate, Unpivot, }; use daft_logical_plan::{stats::StatsState, JoinType}; -use daft_micropartition::{partitioning::PartitionSet, MicroPartition}; +use daft_micropartition::{ + partitioning::{MicroPartitionSet, PartitionSetCache}, + MicroPartition, MicroPartitionRef, +}; use daft_scan::ScanTaskRef; use daft_writers::make_physical_writer_factory; use indexmap::IndexSet; @@ -78,7 +81,7 @@ pub fn viz_pipeline(root: &dyn PipelineNode) -> String { pub fn physical_plan_to_pipeline( physical_plan: &LocalPhysicalPlan, - psets: &(impl PartitionSet + ?Sized), + psets: &(impl PartitionSetCache> + ?Sized), cfg: &Arc, ) -> crate::Result> { use daft_local_plan::PhysicalScan; @@ -105,9 +108,12 @@ pub fn physical_plan_to_pipeline( scan_task_source.arced().into() } LocalPhysicalPlan::InMemoryScan(InMemoryScan { info, .. 
}) => { + let cache_key: Arc = info.cache_key.clone().into(); + let materialized_pset = psets - .get_partition(&info.cache_key) - .unwrap_or_else(|_| panic!("Cache key not found: {:?}", info.cache_key)); + .get_partition_set(&cache_key) + .unwrap_or_else(|| panic!("Cache key not found: {:?}", info.cache_key)); + InMemorySource::new(materialized_pset, info.source_schema.clone()) .arced() .into() diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs index 94d9173184..32b92c9015 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -12,8 +12,8 @@ use common_tracing::refresh_chrome_trace; use daft_local_plan::{translate, LocalPhysicalPlan}; use daft_logical_plan::LogicalPlanBuilder; use daft_micropartition::{ - partitioning::{InMemoryPartitionSet, PartitionSet}, - MicroPartition, + partitioning::{InMemoryPartitionSetCache, MicroPartitionSet, PartitionSetCache}, + MicroPartition, MicroPartitionRef, }; use futures::{FutureExt, Stream}; use loole::RecvFuture; @@ -80,22 +80,25 @@ impl PyNativeExecutor { cfg: PyDaftExecutionConfig, results_buffer_size: Option, ) -> PyResult { - let native_psets: HashMap>> = psets + let native_psets: HashMap> = psets .into_iter() .map(|(part_id, parts)| { ( part_id, - parts - .into_iter() - .map(std::convert::Into::into) - .collect::>>(), + Arc::new( + parts + .into_iter() + .map(std::convert::Into::into) + .collect::>>() + .into(), + ), ) }) .collect(); - let psets = InMemoryPartitionSet::new(native_psets); + let psets = InMemoryPartitionSetCache::new(&native_psets); let out = py.allow_threads(|| { self.executor - .run(psets, cfg.config, results_buffer_size) + .run(&psets, cfg.config, results_buffer_size) .map(|res| res.into_iter()) })?; let iter = Box::new(out.map(|part| { @@ -117,6 +120,7 @@ impl NativeExecutor { ) -> DaftResult { let logical_plan = logical_plan_builder.build(); let local_physical_plan = translate(&logical_plan)?; + Ok(Self { local_physical_plan, cancel: CancellationToken::new(), @@ -125,7 +129,7 @@ impl NativeExecutor { pub fn run( &self, - psets: impl PartitionSet, + psets: &(impl PartitionSetCache> + ?Sized), cfg: Arc, results_buffer_size: Option, ) -> DaftResult { @@ -250,13 +254,13 @@ impl IntoIterator for ExecutionEngineResult { pub fn run_local( physical_plan: &LocalPhysicalPlan, - psets: impl PartitionSet, + psets: &(impl PartitionSetCache> + ?Sized), cfg: Arc, results_buffer_size: Option, cancel: CancellationToken, ) -> DaftResult { refresh_chrome_trace(); - let pipeline = physical_plan_to_pipeline(physical_plan, &psets, &cfg)?; + let pipeline = physical_plan_to_pipeline(physical_plan, psets, &cfg)?; let (tx, rx) = create_channel(results_buffer_size.unwrap_or(1)); let handle = std::thread::spawn(move || { let runtime = tokio::runtime::Builder::new_current_thread() diff --git a/src/daft-local-execution/src/sources/in_memory.rs b/src/daft-local-execution/src/sources/in_memory.rs index b76775ca3c..53f682f543 100644 --- a/src/daft-local-execution/src/sources/in_memory.rs +++ b/src/daft-local-execution/src/sources/in_memory.rs @@ -4,19 +4,19 @@ use async_trait::async_trait; use common_error::DaftResult; use daft_core::prelude::SchemaRef; use daft_io::IOStatsRef; -use daft_micropartition::partitioning::PartitionBatchRef; +use daft_micropartition::{partitioning::PartitionSetRef, MicroPartitionRef}; use tracing::instrument; use super::source::Source; use crate::sources::source::SourceStream; pub struct InMemorySource { - data: PartitionBatchRef, + data: PartitionSetRef, 
schema: SchemaRef, } impl InMemorySource { - pub fn new(data: PartitionBatchRef, schema: SchemaRef) -> Self { + pub fn new(data: PartitionSetRef, schema: SchemaRef) -> Self { Self { data, schema } } pub fn arced(self) -> Arc { @@ -32,7 +32,7 @@ impl Source for InMemorySource { _maintain_order: bool, _io_stats: IOStatsRef, ) -> DaftResult> { - Ok(self.data.clone().into_partition_stream()) + Ok(self.data.clone().to_partition_stream()) } fn name(&self) -> &'static str { "InMemory" diff --git a/src/daft-logical-plan/Cargo.toml b/src/daft-logical-plan/Cargo.toml index 707d881977..f183a63ef3 100644 --- a/src/daft-logical-plan/Cargo.toml +++ b/src/daft-logical-plan/Cargo.toml @@ -4,6 +4,7 @@ common-display = {path = "../common/display", default-features = false} common-error = {path = "../common/error", default-features = false} common-file-formats = {path = "../common/file-formats", default-features = false} common-io-config = {path = "../common/io-config", default-features = false} +common-partitioning = {path = "../common/partitioning", default-features = false} common-py-serde = {path = "../common/py-serde", default-features = false} common-resource-request = {path = "../common/resource-request", default-features = false} common-scan-info = {path = "../common/scan-info", default-features = false} diff --git a/src/daft-logical-plan/src/builder.rs b/src/daft-logical-plan/src/builder.rs index e8418105e0..38921e71fe 100644 --- a/src/daft-logical-plan/src/builder.rs +++ b/src/daft-logical-plan/src/builder.rs @@ -15,7 +15,6 @@ use daft_schema::schema::{Schema, SchemaRef}; #[cfg(feature = "python")] use { crate::sink_info::{CatalogInfo, IcebergCatalogInfo}, - crate::source_info::InMemoryInfo, common_daft_config::PyDaftPlanningConfig, daft_dsl::python::PyExpr, // daft_scan::python::pylib::ScanOperatorHandle, @@ -31,7 +30,7 @@ use crate::{ HashRepartitionConfig, IntoPartitionsConfig, RandomShuffleConfig, RepartitionSpec, }, sink_info::{OutputFileInfo, SinkInfo}, - source_info::SourceInfo, + source_info::{InMemoryInfo, SourceInfo}, LogicalPlanRef, }; @@ -114,10 +113,9 @@ impl LogicalPlanBuilder { Self::new(self.plan.clone(), Some(config)) } - #[cfg(feature = "python")] pub fn in_memory_scan( partition_key: &str, - cache_entry: PyObject, + cache_entry: common_partitioning::PartitionCacheEntry, schema: Arc, num_partitions: usize, size_bytes: usize, @@ -695,7 +693,7 @@ impl PyLogicalPlanBuilder { ) -> PyResult { Ok(LogicalPlanBuilder::in_memory_scan( partition_key, - cache_entry, + common_partitioning::PartitionCacheEntry::Python(cache_entry), schema.into(), num_partitions, size_bytes, diff --git a/src/daft-logical-plan/src/optimization/rules/push_down_limit.rs b/src/daft-logical-plan/src/optimization/rules/push_down_limit.rs index 1a002ea572..8c2d8e67bc 100644 --- a/src/daft-logical-plan/src/optimization/rules/push_down_limit.rs +++ b/src/daft-logical-plan/src/optimization/rules/push_down_limit.rs @@ -273,10 +273,16 @@ mod tests { fn limit_does_not_push_into_in_memory_source() -> DaftResult<()> { let py_obj = Python::with_gil(|py| py.None()); let schema: Arc = Schema::new(vec![Field::new("a", DataType::Int64)])?.into(); - let plan = - LogicalPlanBuilder::in_memory_scan("foo", py_obj, schema, Default::default(), 5, 3)? - .limit(5, false)? - .build(); + let plan = LogicalPlanBuilder::in_memory_scan( + "foo", + common_partitioning::PartitionCacheEntry::Python(py_obj), + schema, + Default::default(), + 5, + 3, + )? + .limit(5, false)? 
+ .build(); assert_optimized_plan_eq(plan.clone(), plan)?; Ok(()) } diff --git a/src/daft-logical-plan/src/source_info/mod.rs b/src/daft-logical-plan/src/source_info/mod.rs index 11122464a0..9ae2f1d68b 100644 --- a/src/daft-logical-plan/src/source_info/mod.rs +++ b/src/daft-logical-plan/src/source_info/mod.rs @@ -4,15 +4,11 @@ use std::{ sync::atomic::AtomicUsize, }; +use common_partitioning::PartitionCacheEntry; use common_scan_info::PhysicalScanInfo; use daft_schema::schema::SchemaRef; pub use file_info::{FileInfo, FileInfos}; use serde::{Deserialize, Serialize}; -#[cfg(feature = "python")] -use { - common_py_serde::{deserialize_py_object, serialize_py_object}, - pyo3::PyObject, -}; use crate::partitioning::ClusteringSpecRef; @@ -27,24 +23,18 @@ pub enum SourceInfo { pub struct InMemoryInfo { pub source_schema: SchemaRef, pub cache_key: String, - #[cfg(feature = "python")] - #[serde( - serialize_with = "serialize_py_object", - deserialize_with = "deserialize_py_object" - )] - pub cache_entry: PyObject, + pub cache_entry: PartitionCacheEntry, pub num_partitions: usize, pub size_bytes: usize, pub num_rows: usize, pub clustering_spec: Option, } -#[cfg(feature = "python")] impl InMemoryInfo { pub fn new( source_schema: SchemaRef, cache_key: String, - cache_entry: PyObject, + cache_entry: PartitionCacheEntry, num_partitions: usize, size_bytes: usize, num_rows: usize, diff --git a/src/daft-micropartition/Cargo.toml b/src/daft-micropartition/Cargo.toml index ef1e3bec20..c39b0b5b78 100644 --- a/src/daft-micropartition/Cargo.toml +++ b/src/daft-micropartition/Cargo.toml @@ -3,6 +3,7 @@ arrow2 = {workspace = true} bincode = {workspace = true} common-error = {path = "../common/error", default-features = false} common-file-formats = {path = "../common/file-formats", default-features = false} +common-partitioning = {path = "../common/partitioning", default-features = false} common-runtime = {path = "../common/runtime", default-features = false} common-scan-info = {path = "../common/scan-info", default-features = false} daft-core = {path = "../daft-core", default-features = false} @@ -14,13 +15,30 @@ daft-parquet = {path = "../daft-parquet", default-features = false} daft-scan = {path = "../daft-scan", default-features = false} daft-stats = {path = "../daft-stats", default-features = false} daft-table = {path = "../daft-table", default-features = false} +dashmap = "6.1.0" futures = {workspace = true} parquet2 = {workspace = true} pyo3 = {workspace = true, optional = true} snafu = {workspace = true} +tracing = {workspace = true} [features] -python = ["dep:pyo3", "common-error/python", "common-file-formats/python", "common-scan-info/python", "daft-core/python", "daft-dsl/python", "daft-table/python", "daft-io/python", "daft-parquet/python", "daft-scan/python", "daft-stats/python", "daft-csv/python", "daft-json/python"] +python = [ + "dep:pyo3", + "common-error/python", + "common-partitioning/python", + "common-file-formats/python", + "common-scan-info/python", + "daft-core/python", + "daft-dsl/python", + "daft-table/python", + "daft-io/python", + "daft-parquet/python", + "daft-scan/python", + "daft-stats/python", + "daft-csv/python", + "daft-json/python" +] [lints] workspace = true diff --git a/src/daft-micropartition/src/lib.rs b/src/daft-micropartition/src/lib.rs index 097506d887..1ea7e3cfec 100644 --- a/src/daft-micropartition/src/lib.rs +++ b/src/daft-micropartition/src/lib.rs @@ -7,7 +7,7 @@ use snafu::Snafu; mod micropartition; mod ops; -pub use micropartition::MicroPartition; +pub use 
micropartition::{MicroPartition, MicroPartitionRef}; #[cfg(feature = "python")] pub mod python; diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index cf9dc22a56..84b4e014dd 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -60,6 +60,7 @@ impl Display for TableState { } } } +pub type MicroPartitionRef = Arc; #[derive(Debug)] pub struct MicroPartition { diff --git a/src/daft-micropartition/src/partitioning.rs b/src/daft-micropartition/src/partitioning.rs index 7bf1a3d586..76667a8618 100644 --- a/src/daft-micropartition/src/partitioning.rs +++ b/src/daft-micropartition/src/partitioning.rs @@ -1,152 +1,79 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{ + any::Any, + sync::{Arc, Weak}, +}; use common_error::{DaftError, DaftResult}; -use daft_dsl::Expr; +pub use common_partitioning::*; +use daft_table::Table; +use dashmap::DashMap; use futures::stream::BoxStream; -use crate::MicroPartition; -type PartitionId = String; +use crate::{micropartition::MicroPartitionRef, MicroPartition}; -/// ported over from `daft/runners/partitioning.py` -// TODO: port over the rest of the functionality -#[derive(Debug, Clone)] -pub struct Boundaries { - pub sort_by: Vec, - pub bounds: Arc, -} - -/// ported over from `daft/runners/partitioning.py` -// TODO: port over the rest of the functionality -#[derive(Debug, Clone)] -pub struct PartitionMetadata { - pub num_rows: usize, - pub size_bytes: usize, - pub boundaries: Option, -} - -impl PartitionMetadata { - pub fn from_micro_partition(mp: &MicroPartition) -> Self { - let num_rows = mp.len(); - let size_bytes = mp.size_bytes().unwrap_or(None).unwrap_or(0); - Self { - num_rows, - size_bytes, - boundaries: None, - } +impl Partition for MicroPartition { + fn as_any(&self) -> &dyn Any { + self + } + fn size_bytes(&self) -> DaftResult> { + self.size_bytes() } } -/// A collection of related [`MicroPartition`]'s that can be processed as a single unit. 
-pub trait PartitionBatch: Send + Sync { - fn micropartitions(&self) -> Vec>; - fn metadata(&self) -> PartitionMetadata; - fn into_partition_stream( - self: Arc, - ) -> BoxStream<'static, DaftResult>>; -} - -/// an in memory batch of [`MicroPartition`]'s -#[derive(Debug, Clone)] -pub struct InMemoryPartitionBatch { - pub partition: Vec>, - pub metadata: Option, +// An in memory partition set +#[derive(Debug, Default, Clone)] +pub struct MicroPartitionSet { + pub partitions: DashMap, } - -impl PartitionBatch for InMemoryPartitionBatch { - fn micropartitions(&self) -> Vec> { - self.partition.clone() +impl From> for MicroPartitionSet { + fn from(value: Vec) -> Self { + let partitions = value + .into_iter() + .enumerate() + .map(|(i, v)| (i as PartitionId, v)) + .collect(); + Self { partitions } } +} - fn metadata(&self) -> PartitionMetadata { - if let Some(metadata) = &self.metadata { - metadata.clone() - } else if self.partition.is_empty() { - PartitionMetadata { - num_rows: 0, - size_bytes: 0, - boundaries: None, - } - } else { - PartitionMetadata::from_micro_partition(&self.partition[0]) +impl MicroPartitionSet { + pub fn new>(psets: T) -> Self { + Self { + partitions: psets.into_iter().collect(), } } - fn into_partition_stream( - self: Arc, - ) -> BoxStream<'static, DaftResult>> { - Box::pin(futures::stream::iter( - self.partition.clone().into_iter().map(Ok), - )) + pub fn empty() -> Self { + Self::default() } -} - -/// an arc'd reference to a [`PartitionBatch`] -pub type PartitionBatchRef = Arc; -/// a collection of [`MicroPartition`] -/// -/// Since we can have different partition sets such as an in memory, or a distributed partition set, we need to abstract over the partition set. -/// This trait defines the common operations that can be performed on a partition set. 
-pub trait PartitionSet { - /// Merge all micropartitions into a single micropartition - fn get_merged_micropartitions(&self) -> DaftResult; - /// Get a preview of the micropartitions - fn get_preview_micropartitions(&self, num_rows: usize) -> DaftResult>>; - fn items(&self) -> DaftResult>; - fn values(&self) -> DaftResult> { - let items = self.items()?; - Ok(items.into_iter().map(|(_, mp)| mp).collect()) - } - /// Number of partitions - fn num_partitions(&self) -> usize; - - fn len(&self) -> usize; - /// Check if the partition set is empty - fn is_empty(&self) -> bool; - /// Size of the partition set in bytes - fn size_bytes(&self) -> DaftResult; - /// Check if a partition exists - fn has_partition(&self, idx: &PartitionId) -> bool; - /// Delete a partition - fn delete_partition(&mut self, idx: &PartitionId) -> DaftResult<()>; - /// Set a partition - fn set_partition(&mut self, idx: PartitionId, part: PartitionBatchRef) -> DaftResult<()>; - /// Get a partition - fn get_partition(&self, idx: &PartitionId) -> DaftResult; -} - -/// An in memory partition set -#[derive(Debug, Default)] -pub struct InMemoryPartitionSet { - pub partitions: HashMap>>, -} + pub fn from_tables(id: PartitionId, tables: Vec) -> DaftResult { + if tables.is_empty() { + return Ok(Self::empty()); + } -impl InMemoryPartitionSet { - pub fn new(psets: HashMap>>) -> Self { - Self { partitions: psets } + let schema = &tables[0].schema; + let mp = MicroPartition::new_loaded(schema.clone(), Arc::new(tables), None); + Ok(Self::new(vec![(id, Arc::new(mp))])) } } -impl PartitionSet for InMemoryPartitionSet { - fn get_merged_micropartitions(&self) -> DaftResult { - let parts = self.values()?; - let parts = parts - .into_iter() - .flat_map(|mat_res| mat_res.micropartitions()); - - MicroPartition::concat(parts) +impl PartitionSet for MicroPartitionSet { + fn get_merged_partitions(&self) -> DaftResult { + let parts = self.partitions.iter().map(|v| v.value().clone()); + MicroPartition::concat(parts).map(|mp| Arc::new(mp) as _) } - fn get_preview_micropartitions( - &self, - mut num_rows: usize, - ) -> DaftResult>> { + fn get_preview_partitions(&self, mut num_rows: usize) -> DaftResult> { let mut preview_parts = vec![]; - for part in self.partitions.values().flatten() { + for part in self.partitions.iter().map(|v| v.value().clone()) { let part_len = part.len(); if part_len >= num_rows { - preview_parts.push(Arc::new(part.slice(0, num_rows)?)); + let mp = part.slice(0, num_rows)?; + let part = Arc::new(mp); + + preview_parts.push(part); break; } else { num_rows -= part_len; @@ -156,20 +83,6 @@ impl PartitionSet for InMemoryPartitionSet { Ok(preview_parts) } - fn items(&self) -> DaftResult> { - self.partitions - .iter() - .map(|(k, v)| { - let partition = InMemoryPartitionBatch { - partition: v.clone(), - metadata: None, - }; - - Ok((k.clone(), Arc::new(partition) as _)) - }) - .collect() - } - fn num_partitions(&self) -> usize { self.partitions.len() } @@ -183,40 +96,147 @@ impl PartitionSet for InMemoryPartitionSet { } fn size_bytes(&self) -> DaftResult { - let partitions = self.values()?; - let mut partitions = partitions.into_iter().flat_map(|mp| mp.micropartitions()); - partitions.try_fold(0, |acc, mp| Ok(acc + mp.size_bytes()?.unwrap_or(0))) + let mut parts = self.partitions.iter().map(|v| v.value().clone()); + + parts.try_fold(0, |acc, mp| Ok(acc + mp.size_bytes()?.unwrap_or(0))) } fn has_partition(&self, partition_id: &PartitionId) -> bool { self.partitions.contains_key(partition_id) } - fn delete_partition(&mut self, 
partition_id: &PartitionId) -> DaftResult<()> { + fn delete_partition(&self, partition_id: &PartitionId) -> DaftResult<()> { self.partitions.remove(partition_id); Ok(()) } - fn set_partition( - &mut self, - partition_id: PartitionId, - part: PartitionBatchRef, - ) -> DaftResult<()> { - let part = part.micropartitions(); - - self.partitions.insert(partition_id, part); + fn set_partition(&self, partition_id: PartitionId, part: &MicroPartitionRef) -> DaftResult<()> { + self.partitions.insert(partition_id, part.clone()); Ok(()) } - fn get_partition(&self, idx: &PartitionId) -> DaftResult { + fn get_partition(&self, idx: &PartitionId) -> DaftResult { let part = self .partitions .get(idx) .ok_or(DaftError::ValueError("Partition not found".to_string()))?; - Ok(Arc::new(InMemoryPartitionBatch { - partition: part.clone(), - metadata: None, - })) + Ok(part.clone()) + } + + fn to_partition_stream(&self) -> BoxStream<'static, DaftResult> { + let partitions = self.partitions.clone().into_iter().map(|(_, v)| v).map(Ok); + + Box::pin(futures::stream::iter(partitions)) + } + + fn metadata(&self) -> PartitionMetadata { + let size_bytes = self.size_bytes().unwrap_or(0); + let num_rows = self.partitions.iter().map(|v| v.value().len()).sum(); + PartitionMetadata { + num_rows, + size_bytes, + } + } +} + +/// An in-memory cache for partition sets +/// +/// Note: this holds weak references to the partition sets. It's structurally similar to a WeakValueHashMap +/// +/// This means that if the partition set is dropped, it will be removed from the cache. +/// So the partition set must outlive the lifetime of the value in the cache. +/// +/// if the partition set is dropped before the cache, it will be removed +/// ex: +/// ```rust,no_run +/// +/// let cache = InMemoryPartitionSetCache::empty(); +/// let outer =Arc::new(MicroPartitionSet::empty()); +/// cache.put_partition_set("outer", &outer); +/// { +/// let inner = Arc::new(MicroPartitionSet::empty()); +/// cache.put_partition_set("inner", &inner); +/// cache.get_partition_set("inner"); // Some(inner) +/// // inner is dropped here +/// } +/// +/// cache.get_partition_set("inner"); // None +/// cache.get_partition_set("outer"); // Some(outer) +/// drop(outer); +/// cache.get_partition_set("outer"); // None +/// ``` +#[derive(Debug, Default, Clone)] +pub struct InMemoryPartitionSetCache { + pub partition_sets: DashMap>, +} + +impl InMemoryPartitionSetCache { + pub fn new<'a, T: IntoIterator)>>( + psets: T, + ) -> Self { + Self { + partition_sets: psets + .into_iter() + .map(|(k, v)| (k.clone(), Arc::downgrade(v))) + .collect(), + } + } + pub fn empty() -> Self { + Self::default() + } +} + +impl PartitionSetCache> for InMemoryPartitionSetCache { + fn get_partition_set(&self, key: &str) -> Option> { + let weak_pset = self.partition_sets.get(key).map(|v| v.value().clone())?; + // if the partition set has been dropped, remove it from the cache + let Some(pset) = weak_pset.upgrade() else { + tracing::trace!("Removing dropped partition set from cache: {}", key); + self.partition_sets.remove(key); + return None; + }; + + Some(pset as _) + } + + fn get_all_partition_sets(&self) -> Vec> { + let psets = self.partition_sets.iter().filter_map(|v| { + let pset = v.value().upgrade()?; + Some(pset as _) + }); + + psets.collect() + } + + fn put_partition_set(&self, key: &str, partition_set: &Arc) { + self.partition_sets + .insert(key.to_string(), Arc::downgrade(partition_set)); + } + + fn rm_partition_set(&self, key: &str) { + self.partition_sets.remove(key); + } + + fn 
clear(&self) { + self.partition_sets.clear(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cache_drops_pset() { + let cache = InMemoryPartitionSetCache::empty(); + + { + let pset = Arc::new(MicroPartitionSet::empty()); + cache.put_partition_set("key", &pset); + assert!(cache.get_partition_set("key").is_some()); + } + + assert!(cache.get_partition_set("key").is_none()); } } diff --git a/src/daft-scheduler/Cargo.toml b/src/daft-scheduler/Cargo.toml index b69b8ff548..7132253ac9 100644 --- a/src/daft-scheduler/Cargo.toml +++ b/src/daft-scheduler/Cargo.toml @@ -4,6 +4,7 @@ common-display = {path = "../common/display", default-features = false} common-error = {path = "../common/error", default-features = false} common-file-formats = {path = "../common/file-formats", default-features = false} common-io-config = {path = "../common/io-config", default-features = false} +common-partitioning = {path = "../common/partitioning", default-features = false} common-py-serde = {path = "../common/py-serde", default-features = false} daft-core = {path = "../daft-core", default-features = false} daft-dsl = {path = "../daft-dsl", default-features = false} diff --git a/src/daft-scheduler/src/adaptive.rs b/src/daft-scheduler/src/adaptive.rs index c85e575467..d3b27ca310 100644 --- a/src/daft-scheduler/src/adaptive.rs +++ b/src/daft-scheduler/src/adaptive.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use common_daft_config::DaftExecutionConfig; +use common_partitioning::PartitionCacheEntry; use daft_core::prelude::Schema; use daft_logical_plan::{InMemoryInfo, LogicalPlan}; use daft_physical_plan::{AdaptivePlanner, MaterializedResults}; @@ -66,7 +67,7 @@ impl AdaptivePhysicalPlanScheduler { let in_memory_info = InMemoryInfo::new( Schema::empty().into(), // TODO thread in schema from in memory scan partition_key.into(), - cache_entry, + PartitionCacheEntry::Python(cache_entry), num_partitions, size_bytes, num_rows, From 3a3707aec24145d5770c4f5511e30977006193d3 Mon Sep 17 00:00:00 2001 From: Kev Wang Date: Wed, 18 Dec 2024 14:31:13 -0800 Subject: [PATCH 13/14] chore!: remove pyarrow-based file reader (#3587) BREAKING CHANGE: removed the following: reading using PyArrow, `PythonStorageConfig` type and `StorageConfig.python` function, `daft.table.schema_inference`, refactored `StorageConfig` and `NativeStorageConfig` into one type Also includes two small fixes: - `Schema` struct equality now also depends on the ordering of the fields - Monotonically increasing ID schema now always inserts ID as the first column todo: - [x] refactor StorageConfig and NativeStorageConfig into one struct - [x] clean up table_io.py --- daft/daft/__init__.pyi | 31 +- daft/delta_lake/delta_lake_scan.py | 2 +- daft/expressions/expressions.py | 38 +- daft/hudi/hudi_scan.py | 2 +- daft/io/_csv.py | 11 +- daft/io/_deltalake.py | 6 +- daft/io/_hudi.py | 4 +- daft/io/_iceberg.py | 4 +- daft/io/_json.py | 11 +- daft/io/_parquet.py | 9 +- daft/io/_sql.py | 4 +- daft/sql/sql_scan.py | 2 +- daft/table/schema_inference.py | 141 ------ daft/table/table_io.py | 300 +++-------- daft/udf_library/__init__.py | 0 daft/udf_library/url_udfs.py | 88 ---- .../src/sources/scan_task.rs | 23 +- .../src/ops/monotonically_increasing_id.rs | 3 +- src/daft-micropartition/src/micropartition.rs | 477 +++++++----------- src/daft-micropartition/src/python.rs | 10 +- src/daft-scan/src/builder.rs | 31 +- src/daft-scan/src/lib.rs | 17 +- src/daft-scan/src/python.rs | 26 +- src/daft-scan/src/scan_task_iters.rs | 6 +- src/daft-scan/src/storage_config.rs | 
194 +------ src/daft-schema/src/schema.rs | 16 +- src/daft-sql/src/table_provider/read_csv.rs | 3 - tests/dataframe/test_creation.py | 77 +-- tests/dataframe/test_temporals.py | 7 +- .../integration/io/test_url_download_http.py | 7 +- .../io/test_url_download_private_aws_s3.py | 4 +- .../io/test_url_download_public_aws_s3.py | 10 +- .../io/test_url_download_public_azure.py | 4 +- ...test_url_download_s3_local_retry_server.py | 4 +- tests/io/test_parquet.py | 18 +- tests/sql/test_table_funcs.py | 4 +- tests/table/table_io/test_csv.py | 101 +--- tests/table/table_io/test_json.py | 51 +- tests/table/table_io/test_parquet.py | 61 +-- tests/table/table_io/test_read_time_cast.py | 4 +- tests/udf_library/test_url_udfs.py | 47 +- 41 files changed, 467 insertions(+), 1391 deletions(-) delete mode 100644 daft/table/schema_inference.py delete mode 100644 daft/udf_library/__init__.py delete mode 100644 daft/udf_library/url_udfs.py diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index da292d2df1..6860f72491 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -586,37 +586,14 @@ class IOConfig: """Replaces values if provided, returning a new IOConfig.""" ... -class NativeStorageConfig: - """Storage configuration for the Rust-native I/O layer.""" +class StorageConfig: + """Configuration for interacting with a particular storage backend.""" # Whether or not to use a multithreaded tokio runtime for processing I/O multithreaded_io: bool io_config: IOConfig - def __init__(self, multithreaded_io: bool, io_config: IOConfig): ... - -class PythonStorageConfig: - """Storage configuration for the legacy Python I/O layer.""" - - io_config: IOConfig - - def __init__(self, io_config: IOConfig): ... - -class StorageConfig: - """Configuration for interacting with a particular storage backend, using a particular I/O layer implementation.""" - - @staticmethod - def native(config: NativeStorageConfig) -> StorageConfig: - """Create from a native storage config.""" - ... - - @staticmethod - def python(config: PythonStorageConfig) -> StorageConfig: - """Create from a Python storage config.""" - ... - - @property - def config(self) -> NativeStorageConfig | PythonStorageConfig: ... + def __init__(self, multithreaded_io: bool, io_config: IOConfig | None): ... class ScanTask: """A batch of scan tasks for reading data from an external source.""" @@ -650,8 +627,8 @@ class ScanTask: url: str, file_format: FileFormatConfig, schema: PySchema, - num_rows: int | None, storage_config: StorageConfig, + num_rows: int | None, size_bytes: int | None, pushdowns: Pushdowns | None, stats: PyTable | None, diff --git a/daft/delta_lake/delta_lake_scan.py b/daft/delta_lake/delta_lake_scan.py index 20dd2e93ae..b81b0128e9 100644 --- a/daft/delta_lake/delta_lake_scan.py +++ b/daft/delta_lake/delta_lake_scan.py @@ -41,7 +41,7 @@ def __init__( # Thus, if we don't detect any credentials being available, we attempt to detect it from the environment using our Daft credentials chain. 
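For callers that prefer not to depend on this environment-based detection, `read_deltalake` still accepts an explicit `io_config`, which the scan then uses directly. A minimal sketch, with a placeholder table URI and region (both assumptions, not values from this patch):

```python
import daft
from daft.io import IOConfig, S3Config

# Explicit region/credentials bypass the credential-chain detection described above.
io_config = IOConfig(s3=S3Config(region_name="us-west-2"))
df = daft.read_deltalake("s3://bucket/path/to/delta-table", io_config=io_config)
```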
# # See: https://github.com/delta-io/delta-rs/issues/2117 - deltalake_sdk_io_config = storage_config.config.io_config + deltalake_sdk_io_config = storage_config.io_config scheme = urlparse(table_uri).scheme if scheme == "s3" or scheme == "s3a": # Try to get region from boto3 diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 5e865ff936..58440bfa80 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -1333,7 +1333,6 @@ def download( max_connections: int = 32, on_error: Literal["raise", "null"] = "raise", io_config: IOConfig | None = None, - use_native_downloader: bool = True, ) -> Expression: """Treats each string as a URL, and downloads the bytes contents as a bytes column. @@ -1351,37 +1350,26 @@ def download( the error but fallback to a Null value. Defaults to "raise". io_config: IOConfig to use when accessing remote storage. Note that the S3Config's `max_connections` parameter will be overridden with `max_connections` that is passed in as a kwarg. - use_native_downloader (bool): Use the native downloader rather than python based one. - Defaults to True. Returns: Expression: a Binary expression which is the bytes contents of the URL, or None if an error occurred during download """ - if use_native_downloader: + raise_on_error = False + if on_error == "raise": + raise_on_error = True + elif on_error == "null": raise_on_error = False - if on_error == "raise": - raise_on_error = True - elif on_error == "null": - raise_on_error = False - else: - raise NotImplementedError(f"Unimplemented on_error option: {on_error}.") - - if not (isinstance(max_connections, int) and max_connections > 0): - raise ValueError(f"Invalid value for `max_connections`: {max_connections}") - - multi_thread = ExpressionUrlNamespace._should_use_multithreading_tokio_runtime() - io_config = ExpressionUrlNamespace._override_io_config_max_connections(max_connections, io_config) - return Expression._from_pyexpr( - _url_download(self._expr, max_connections, raise_on_error, multi_thread, io_config) - ) else: - from daft.udf_library import url_udfs + raise NotImplementedError(f"Unimplemented on_error option: {on_error}.") - return url_udfs.download_udf( - Expression._from_pyexpr(self._expr), - max_worker_threads=max_connections, - on_error=on_error, - ) + if not (isinstance(max_connections, int) and max_connections > 0): + raise ValueError(f"Invalid value for `max_connections`: {max_connections}") + + multi_thread = ExpressionUrlNamespace._should_use_multithreading_tokio_runtime() + io_config = ExpressionUrlNamespace._override_io_config_max_connections(max_connections, io_config) + return Expression._from_pyexpr( + _url_download(self._expr, max_connections, raise_on_error, multi_thread, io_config) + ) def upload( self, diff --git a/daft/hudi/hudi_scan.py b/daft/hudi/hudi_scan.py index d68cfd99bb..2f342e4a6a 100644 --- a/daft/hudi/hudi_scan.py +++ b/daft/hudi/hudi_scan.py @@ -25,7 +25,7 @@ class HudiScanOperator(ScanOperator): def __init__(self, table_uri: str, storage_config: StorageConfig) -> None: super().__init__() - resolved_path, self._resolved_fs = _resolve_paths_and_filesystem(table_uri, storage_config.config.io_config) + resolved_path, self._resolved_fs = _resolve_paths_and_filesystem(table_uri, storage_config.io_config) self._table = HudiTable(table_uri, self._resolved_fs, resolved_path[0]) self._storage_config = storage_config self._schema = Schema.from_pyarrow_schema(self._table.schema) diff --git a/daft/io/_csv.py b/daft/io/_csv.py index 
27ebd2a096..5c57c21918 100644 --- a/daft/io/_csv.py +++ b/daft/io/_csv.py @@ -8,8 +8,6 @@ CsvSourceConfig, FileFormatConfig, IOConfig, - NativeStorageConfig, - PythonStorageConfig, StorageConfig, ) from daft.dataframe import DataFrame @@ -32,7 +30,6 @@ def read_csv( io_config: Optional["IOConfig"] = None, file_path_column: Optional[str] = None, hive_partitioning: bool = False, - use_native_downloader: bool = True, schema_hints: Optional[Dict[str, DataType]] = None, _buffer_size: Optional[int] = None, _chunk_size: Optional[int] = None, @@ -58,8 +55,6 @@ def read_csv( io_config (IOConfig): Config to be used with the native downloader file_path_column: Include the source path(s) as a column with this name. Defaults to None. hive_partitioning: Whether to infer hive_style partitions from file paths and include them as columns in the Dataframe. Defaults to False. - use_native_downloader: Whether to use the native downloader instead of PyArrow for reading Parquet. This - is currently experimental. returns: DataFrame: parsed DataFrame @@ -91,10 +86,8 @@ def read_csv( chunk_size=_chunk_size, ) file_format_config = FileFormatConfig.from_csv_config(csv_config) - if use_native_downloader: - storage_config = StorageConfig.native(NativeStorageConfig(True, io_config)) - else: - storage_config = StorageConfig.python(PythonStorageConfig(io_config=io_config)) + storage_config = StorageConfig(True, io_config) + builder = get_tabular_files_scan( path=path, infer_schema=infer_schema, diff --git a/daft/io/_deltalake.py b/daft/io/_deltalake.py index 24ba15ee43..8cbf6a6cf5 100644 --- a/daft/io/_deltalake.py +++ b/daft/io/_deltalake.py @@ -4,7 +4,7 @@ from daft import context from daft.api_annotations import PublicAPI -from daft.daft import IOConfig, NativeStorageConfig, ScanOperatorHandle, StorageConfig +from daft.daft import IOConfig, ScanOperatorHandle, StorageConfig from daft.dataframe import DataFrame from daft.dependencies import unity_catalog from daft.io.catalog import DataCatalogTable @@ -60,7 +60,7 @@ def read_deltalake( ) io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config - storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, io_config)) + storage_config = StorageConfig(multithreaded_io, io_config) if isinstance(table, str): table_uri = table @@ -72,7 +72,7 @@ def read_deltalake( # Override the storage_config with the one provided by Unity catalog table_io_config = table.io_config if table_io_config is not None: - storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, table_io_config)) + storage_config = StorageConfig(multithreaded_io, table_io_config) else: raise ValueError( f"table argument must be a table URI string, DataCatalogTable or UnityCatalogTable instance, but got: {type(table)}, {table}" diff --git a/daft/io/_hudi.py b/daft/io/_hudi.py index a8d4c6f999..2a70188f65 100644 --- a/daft/io/_hudi.py +++ b/daft/io/_hudi.py @@ -4,7 +4,7 @@ from daft import context from daft.api_annotations import PublicAPI -from daft.daft import IOConfig, NativeStorageConfig, ScanOperatorHandle, StorageConfig +from daft.daft import IOConfig, ScanOperatorHandle, StorageConfig from daft.dataframe import DataFrame from daft.logical.builder import LogicalPlanBuilder @@ -33,7 +33,7 @@ def read_hudi( io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config multithreaded_io = context.get_context().get_or_create_runner().name != "ray" - storage_config = 
StorageConfig.native(NativeStorageConfig(multithreaded_io, io_config)) + storage_config = StorageConfig(multithreaded_io, io_config) hudi_operator = HudiScanOperator(table_uri, storage_config=storage_config) diff --git a/daft/io/_iceberg.py b/daft/io/_iceberg.py index 62f47babba..dbf94dd76d 100644 --- a/daft/io/_iceberg.py +++ b/daft/io/_iceberg.py @@ -4,7 +4,7 @@ from daft import context from daft.api_annotations import PublicAPI -from daft.daft import IOConfig, NativeStorageConfig, ScanOperatorHandle, StorageConfig +from daft.daft import IOConfig, ScanOperatorHandle, StorageConfig from daft.dataframe import DataFrame from daft.logical.builder import LogicalPlanBuilder @@ -123,7 +123,7 @@ def read_iceberg( io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config multithreaded_io = context.get_context().get_or_create_runner().name != "ray" - storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, io_config)) + storage_config = StorageConfig(multithreaded_io, io_config) iceberg_operator = IcebergScanOperator(pyiceberg_table, snapshot_id=snapshot_id, storage_config=storage_config) diff --git a/daft/io/_json.py b/daft/io/_json.py index 3626d775a1..62bb02ac15 100644 --- a/daft/io/_json.py +++ b/daft/io/_json.py @@ -8,8 +8,6 @@ FileFormatConfig, IOConfig, JsonSourceConfig, - NativeStorageConfig, - PythonStorageConfig, StorageConfig, ) from daft.dataframe import DataFrame @@ -25,7 +23,6 @@ def read_json( io_config: Optional["IOConfig"] = None, file_path_column: Optional[str] = None, hive_partitioning: bool = False, - use_native_downloader: bool = True, schema_hints: Optional[Dict[str, DataType]] = None, _buffer_size: Optional[int] = None, _chunk_size: Optional[int] = None, @@ -45,8 +42,6 @@ def read_json( io_config (IOConfig): Config to be used with the native downloader file_path_column: Include the source path(s) as a column with this name. Defaults to None. hive_partitioning: Whether to infer hive_style partitions from file paths and include them as columns in the Dataframe. Defaults to False. - use_native_downloader: Whether to use the native downloader instead of PyArrow for reading Parquet. This - is currently experimental. 
returns: DataFrame: parsed DataFrame @@ -68,10 +63,8 @@ def read_json( json_config = JsonSourceConfig(_buffer_size, _chunk_size) file_format_config = FileFormatConfig.from_json_config(json_config) - if use_native_downloader: - storage_config = StorageConfig.native(NativeStorageConfig(True, io_config)) - else: - storage_config = StorageConfig.python(PythonStorageConfig(io_config=io_config)) + storage_config = StorageConfig(True, io_config) + builder = get_tabular_files_scan( path=path, infer_schema=infer_schema, diff --git a/daft/io/_parquet.py b/daft/io/_parquet.py index b65c9b5441..e133f2a505 100644 --- a/daft/io/_parquet.py +++ b/daft/io/_parquet.py @@ -7,9 +7,7 @@ from daft.daft import ( FileFormatConfig, IOConfig, - NativeStorageConfig, ParquetSourceConfig, - PythonStorageConfig, StorageConfig, ) from daft.dataframe import DataFrame @@ -26,7 +24,6 @@ def read_parquet( io_config: Optional["IOConfig"] = None, file_path_column: Optional[str] = None, hive_partitioning: bool = False, - use_native_downloader: bool = True, coerce_int96_timestamp_unit: Optional[Union[str, TimeUnit]] = None, schema_hints: Optional[Dict[str, DataType]] = None, _multithreaded_io: Optional[bool] = None, @@ -49,7 +46,6 @@ def read_parquet( io_config (IOConfig): Config to be used with the native downloader file_path_column: Include the source path(s) as a column with this name. Defaults to None. hive_partitioning: Whether to infer hive_style partitions from file paths and include them as columns in the Dataframe. Defaults to False. - use_native_downloader: Whether to use the native downloader instead of PyArrow for reading Parquet. coerce_int96_timestamp_unit: TimeUnit to coerce Int96 TimeStamps to. e.g.: [ns, us, ms], Defaults to None. _multithreaded_io: Whether to use multithreading for IO threads. Setting this to False can be helpful in reducing the amount of system resources (number of connections and thread contention) when running in the Ray runner. 
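With `use_native_downloader` removed, the native Parquet reader is always used and a read is tuned only through `io_config` (plus the internal multithreaded-IO flag). A minimal sketch of the call after this change, using a hypothetical glob path:

```python
import daft
from daft.io import IOConfig

# Reads always go through the native reader; only I/O behavior is configurable.
df = daft.read_parquet("data/*.parquet", io_config=IOConfig())
```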
@@ -87,10 +83,7 @@ def read_parquet( file_format_config = FileFormatConfig.from_parquet_config( ParquetSourceConfig(coerce_int96_timestamp_unit=pytimeunit, row_groups=row_groups, chunk_size=_chunk_size) ) - if use_native_downloader: - storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, io_config)) - else: - storage_config = StorageConfig.python(PythonStorageConfig(io_config=io_config)) + storage_config = StorageConfig(multithreaded_io, io_config) builder = get_tabular_files_scan( path=path, diff --git a/daft/io/_sql.py b/daft/io/_sql.py index 1fa714fcc5..09065dbfd1 100644 --- a/daft/io/_sql.py +++ b/daft/io/_sql.py @@ -5,7 +5,7 @@ from daft import context, from_pydict from daft.api_annotations import PublicAPI -from daft.daft import PythonStorageConfig, ScanOperatorHandle, StorageConfig +from daft.daft import ScanOperatorHandle, StorageConfig from daft.dataframe import DataFrame from daft.datatype import DataType from daft.logical.builder import LogicalPlanBuilder @@ -94,7 +94,7 @@ def read_sql( ) io_config = context.get_context().daft_planning_config.default_io_config - storage_config = StorageConfig.python(PythonStorageConfig(io_config)) + storage_config = StorageConfig(True, io_config) sql_conn = SQLConnection.from_url(conn) if isinstance(conn, str) else SQLConnection.from_connection_factory(conn) sql_operator = SQLScanOperator( diff --git a/daft/sql/sql_scan.py b/daft/sql/sql_scan.py index 48ee7fe4d6..7c090ef90a 100644 --- a/daft/sql/sql_scan.py +++ b/daft/sql/sql_scan.py @@ -271,8 +271,8 @@ def _construct_scan_task( url=self.conn.url, file_format=file_format_config, schema=self._schema._schema, - num_rows=num_rows, storage_config=self.storage_config, + num_rows=num_rows, size_bytes=size_bytes, pushdowns=pushdowns if not apply_pushdowns_to_sql else None, stats=stats, diff --git a/daft/table/schema_inference.py b/daft/table/schema_inference.py deleted file mode 100644 index 1d0a1bffe7..0000000000 --- a/daft/table/schema_inference.py +++ /dev/null @@ -1,141 +0,0 @@ -from __future__ import annotations - -import pathlib - -from daft.daft import ( - CsvParseOptions, - JsonParseOptions, - NativeStorageConfig, - PythonStorageConfig, - StorageConfig, -) -from daft.datatype import DataType -from daft.dependencies import pacsv, pajson, pq -from daft.filesystem import _resolve_paths_and_filesystem -from daft.logical.schema import Schema -from daft.runners.partitioning import TableParseCSVOptions -from daft.table import MicroPartition -from daft.table.table_io import FileInput, _open_stream - - -def from_csv( - file: FileInput, - storage_config: StorageConfig | None = None, - csv_options: TableParseCSVOptions = TableParseCSVOptions(), -) -> Schema: - """Infers a Schema from a CSV file. - - Args: - file (str | IO): either a file-like object or a string file path (potentially prefixed with a protocol such as "s3://") - fs (fsspec.AbstractFileSystem): fsspec FileSystem to use for reading data. - By default, Daft will automatically construct a FileSystem instance internally. - csv_options (vPartitionParseCSVOptions, optional): CSV-specific configs to apply when reading the file - read_options (TableReadOptions, optional): Options for reading the file - Returns: - Schema: Inferred Schema from the CSV. 
- """ - # Have PyArrow generate the column names if user specifies that there are no headers - pyarrow_autogenerate_column_names = csv_options.header_index is None - - io_config = None - if storage_config is not None: - config = storage_config.config - if isinstance(config, NativeStorageConfig): - assert isinstance(file, (str, pathlib.Path)), "Native downloader only works on string inputs to read_csv" - io_config = config.io_config - return Schema.from_csv( - str(file), - parse_options=CsvParseOptions( - has_header=csv_options.header_index is not None, - delimiter=csv_options.delimiter, - double_quote=csv_options.double_quote, - quote=csv_options.quote, - allow_variable_columns=csv_options.allow_variable_columns, - escape_char=csv_options.escape_char, - comment=csv_options.comment, - ), - io_config=io_config, - ) - - assert isinstance(config, PythonStorageConfig) - io_config = config.io_config - - with _open_stream(file, io_config) as f: - reader = pacsv.open_csv( - f, - parse_options=pacsv.ParseOptions( - delimiter=csv_options.delimiter, - ), - read_options=pacsv.ReadOptions( - autogenerate_column_names=pyarrow_autogenerate_column_names, - ), - ) - - return Schema.from_pyarrow_schema(reader.schema) - - -def from_json( - file: FileInput, - storage_config: StorageConfig | None = None, -) -> Schema: - """Reads a Schema from a JSON file. - - Args: - file (FileInput): either a file-like object or a string file path (potentially prefixed with a protocol such as "s3://") - read_options (TableReadOptions, optional): Options for reading the file - - Returns: - Schema: Inferred Schema from the JSON - """ - io_config = None - if storage_config is not None: - config = storage_config.config - if isinstance(config, NativeStorageConfig): - assert isinstance(file, (str, pathlib.Path)), "Native downloader only works on string inputs to read_json" - io_config = config.io_config - return Schema.from_json( - str(file), - parse_options=JsonParseOptions(), - io_config=io_config, - ) - - assert isinstance(config, PythonStorageConfig) - io_config = config.io_config - - with _open_stream(file, io_config) as f: - table = pajson.read_json(f) - - return MicroPartition.from_arrow(table).schema() - - -def from_parquet( - file: FileInput, - storage_config: StorageConfig | None = None, -) -> Schema: - """Infers a Schema from a Parquet file.""" - io_config = None - if storage_config is not None: - config = storage_config.config - if isinstance(config, NativeStorageConfig): - assert isinstance( - file, (str, pathlib.Path) - ), "Native downloader only works on string inputs to read_parquet" - io_config = config.io_config - return Schema.from_parquet(str(file), io_config=io_config) - - assert isinstance(config, PythonStorageConfig) - io_config = config.io_config - - if not isinstance(file, (str, pathlib.Path)): - # BytesIO path. 
- f = file - else: - paths, fs = _resolve_paths_and_filesystem(file, io_config=io_config) - assert len(paths) == 1 - path = paths[0] - f = fs.open_input_file(path) - - pqf = pq.ParquetFile(f) - arrow_schema = pqf.metadata.schema.to_arrow_schema() - - return Schema._from_field_name_and_types([(f.name, DataType.from_arrow_type(f.type)) for f in arrow_schema]) diff --git a/daft/table/table_io.py b/daft/table/table_io.py index 1a74e37659..30744d57d3 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -1,11 +1,10 @@ from __future__ import annotations -import contextlib import math import pathlib import random import time -from typing import IO, TYPE_CHECKING, Any, Iterator, Union +from typing import TYPE_CHECKING, Any, Iterator, Union from uuid import uuid4 from daft.context import get_context @@ -18,11 +17,9 @@ JsonConvertOptions, JsonParseOptions, JsonReadOptions, - NativeStorageConfig, - PythonStorageConfig, StorageConfig, ) -from daft.dependencies import pa, pacsv, pads, pajson, pq +from daft.dependencies import pa, pacsv, pads, pq from daft.expressions import ExpressionsProjection, col from daft.filesystem import ( _resolve_paths_and_filesystem, @@ -41,10 +38,10 @@ from .micropartition import MicroPartition from .partitioning import PartitionedTable, partition_strings_to_path -FileInput = Union[pathlib.Path, str, IO[bytes]] +FileInput = Union[pathlib.Path, str] if TYPE_CHECKING: - from collections.abc import Callable, Generator + from collections.abc import Callable from pyiceberg.schema import Schema as IcebergSchema from pyiceberg.table import TableProperties as IcebergTableProperties @@ -53,22 +50,6 @@ from daft.sql.sql_connection import SQLConnection -@contextlib.contextmanager -def _open_stream( - file: FileInput, - io_config: IOConfig | None, -) -> Generator[pa.NativeFile, None, None]: - """Opens the provided file for reading, yield a pyarrow file handle.""" - if isinstance(file, (pathlib.Path, str)): - paths, fs = _resolve_paths_and_filesystem(file, io_config=io_config) - assert len(paths) == 1 - path = paths[0] - with fs.open_input_stream(path) as f: - yield f - else: - yield file - - def _cast_table_to_schema(table: MicroPartition, read_options: TableReadOptions, schema: Schema) -> pa.Table: """Performs a cast of a Daft MicroPartition to the requested Schema/Data. 
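The `table_io` readers below now take the unified `StorageConfig` and fall back to a default multithreaded one when the caller passes nothing. A rough sketch of constructing it, assuming the caller already has an `IOConfig` in hand:

```python
from daft.daft import IOConfig, StorageConfig

io_config = IOConfig()  # stand-in for a caller-provided configuration
# The first argument toggles the multithreaded tokio runtime for I/O.
storage_config = StorageConfig(True, io_config)
```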
@@ -108,40 +89,23 @@ def read_json( Returns: MicroPartition: Parsed MicroPartition from JSON """ - io_config = None - if storage_config is not None: - config = storage_config.config - if isinstance(config, NativeStorageConfig): - assert isinstance(file, (str, pathlib.Path)), "Native downloader only works on string inputs to read_json" - json_convert_options = JsonConvertOptions( - limit=read_options.num_rows, - include_columns=read_options.column_names, - schema=schema._schema if schema is not None else None, - ) - json_parse_options = JsonParseOptions() - tbl = MicroPartition.read_json( - str(file), - convert_options=json_convert_options, - parse_options=json_parse_options, - read_options=json_read_options, - io_config=config.io_config, - ) - return _cast_table_to_schema(tbl, read_options=read_options, schema=schema) - else: - assert isinstance(config, PythonStorageConfig) - io_config = config.io_config - - with _open_stream(file, io_config) as f: - table = pajson.read_json(f) - - if read_options.column_names is not None: - table = table.select(read_options.column_names) - - # TODO(jay): Can't limit number of rows with current PyArrow filesystem so we have to shave it off after the read - if read_options.num_rows is not None: - table = table[: read_options.num_rows] - - return _cast_table_to_schema(MicroPartition.from_arrow(table), read_options=read_options, schema=schema) + # TODO: move this logic into Rust + config = storage_config if storage_config is not None else StorageConfig(True, IOConfig()) + assert isinstance(file, (str, pathlib.Path)), "Native downloader only works on string inputs to read_json" + json_convert_options = JsonConvertOptions( + limit=read_options.num_rows, + include_columns=read_options.column_names, + schema=schema._schema if schema is not None else None, + ) + json_parse_options = JsonParseOptions() + tbl = MicroPartition.read_json( + str(file), + convert_options=json_convert_options, + parse_options=json_parse_options, + read_options=json_read_options, + io_config=config.io_config, + ) + return _cast_table_to_schema(tbl, read_options=read_options, schema=schema) def read_parquet( @@ -162,70 +126,20 @@ def read_parquet( Returns: MicroPartition: Parsed MicroPartition from Parquet """ - io_config = None - if storage_config is not None: - config = storage_config.config - if isinstance(config, NativeStorageConfig): - assert isinstance( - file, (str, pathlib.Path) - ), "Native downloader only works on string or Path inputs to read_parquet" - tbl = MicroPartition.read_parquet( - str(file), - columns=read_options.column_names, - num_rows=read_options.num_rows, - io_config=config.io_config, - coerce_int96_timestamp_unit=parquet_options.coerce_int96_timestamp_unit, - multithreaded_io=config.multithreaded_io, - ) - return _cast_table_to_schema(tbl, read_options=read_options, schema=schema) - - assert isinstance(config, PythonStorageConfig) - io_config = config.io_config - - f: IO - if not isinstance(file, (str, pathlib.Path)): - f = file - else: - paths, fs = _resolve_paths_and_filesystem(file, io_config=io_config) - assert len(paths) == 1 - path = paths[0] - f = fs.open_input_file(path) - - # If no rows required, we manually construct an empty table with the right schema - if read_options.num_rows == 0: - pqf = pq.ParquetFile( - f, - coerce_int96_timestamp_unit=str(parquet_options.coerce_int96_timestamp_unit), - ) - arrow_schema = pqf.metadata.schema.to_arrow_schema() - table = pa.Table.from_arrays( - [pa.array([], type=field.type) for field in arrow_schema], - 
schema=arrow_schema, - ) - elif read_options.num_rows is not None: - pqf = pq.ParquetFile( - f, - coerce_int96_timestamp_unit=str(parquet_options.coerce_int96_timestamp_unit), - ) - # Only read the required row groups. - rows_needed = read_options.num_rows - for i in range(pqf.metadata.num_row_groups): - row_group_meta = pqf.metadata.row_group(i) - rows_needed -= row_group_meta.num_rows - if rows_needed <= 0: - break - table = pqf.read_row_groups(list(range(i + 1)), columns=read_options.column_names) - if rows_needed < 0: - # Need to truncate the table to the row limit. - table = table.slice(length=read_options.num_rows) - else: - table = pq.read_table( - f, - columns=read_options.column_names, - coerce_int96_timestamp_unit=str(parquet_options.coerce_int96_timestamp_unit), - ) - - return _cast_table_to_schema(MicroPartition.from_arrow(table), read_options=read_options, schema=schema) + # TODO: move this logic into Rust + config = storage_config if storage_config is not None else StorageConfig(True, IOConfig()) + assert isinstance( + file, (str, pathlib.Path) + ), "Native downloader only works on string or Path inputs to read_parquet" + tbl = MicroPartition.read_parquet( + str(file), + columns=read_options.column_names, + num_rows=read_options.num_rows, + io_config=config.io_config, + coerce_int96_timestamp_unit=parquet_options.coerce_int96_timestamp_unit, + multithreaded_io=config.multithreaded_io, + ) + return _cast_table_to_schema(tbl, read_options=read_options, schema=schema) def read_sql( @@ -260,24 +174,6 @@ def read_sql( return _cast_table_to_schema(mp, read_options=read_options, schema=schema) -class PACSVStreamHelper: - def __init__(self, stream: pa.CSVStreamReader) -> None: - self.stream = stream - - def __next__(self) -> pa.RecordBatch: - return self.stream.read_next_batch() - - def __iter__(self) -> PACSVStreamHelper: - return self - - -def skip_comment(comment: str | None) -> Callable | None: - if comment is None: - return None - else: - return lambda row: "skip" if row.text.startswith(comment) else "error" - - def read_csv( file: FileInput, schema: Schema, @@ -298,104 +194,34 @@ def read_csv( Returns: MicroPartition: Parsed MicroPartition from CSV """ - io_config = None - if storage_config is not None: - config = storage_config.config - if isinstance(config, NativeStorageConfig): - assert isinstance( - file, (str, pathlib.Path) - ), "Native downloader only works on string or Path inputs to read_csv" - has_header = csv_options.header_index is not None - csv_convert_options = CsvConvertOptions( - limit=read_options.num_rows, - include_columns=read_options.column_names, - column_names=schema.column_names() if not has_header else None, - schema=schema._schema if schema is not None else None, - ) - csv_parse_options = CsvParseOptions( - has_header=has_header, - delimiter=csv_options.delimiter, - double_quote=csv_options.double_quote, - quote=csv_options.quote, - allow_variable_columns=csv_options.allow_variable_columns, - escape_char=csv_options.escape_char, - comment=csv_options.comment, - ) - csv_read_options = CsvReadOptions(buffer_size=csv_options.buffer_size, chunk_size=csv_options.chunk_size) - tbl = MicroPartition.read_csv( - str(file), - convert_options=csv_convert_options, - parse_options=csv_parse_options, - read_options=csv_read_options, - io_config=config.io_config, - ) - return _cast_table_to_schema(tbl, read_options=read_options, schema=schema) - else: - assert isinstance(config, PythonStorageConfig) - io_config = config.io_config - - with _open_stream(file, 
io_config) as f: - from daft.utils import get_arrow_version - - arrow_version = get_arrow_version() - - if csv_options.comment is not None and arrow_version < (7, 0, 0): - raise ValueError( - "pyarrow < 7.0.0 doesn't support handling comments in CSVs, please upgrade pyarrow to 7.0.0+." - ) - - parse_options = pacsv.ParseOptions( - delimiter=csv_options.delimiter, - quote_char=csv_options.quote, - escape_char=csv_options.escape_char, - ) - - if arrow_version >= (7, 0, 0): - parse_options.invalid_row_handler = skip_comment(csv_options.comment) - - pacsv_stream = pacsv.open_csv( - f, - parse_options=parse_options, - read_options=pacsv.ReadOptions( - # If no header, we use the schema's column names. Otherwise we use the headers in the CSV file. - column_names=(schema.column_names() if csv_options.header_index is None else None), - ), - convert_options=pacsv.ConvertOptions( - # Column pruning - include_columns=read_options.column_names, - # If any columns are missing, parse as null array - include_missing_columns=True, - ), - ) - - if read_options.num_rows is not None: - rows_left = read_options.num_rows - pa_batches = [] - pa_schema = None - for record_batch in PACSVStreamHelper(pacsv_stream): - if pa_schema is None: - pa_schema = record_batch.schema - if record_batch.num_rows > rows_left: - record_batch = record_batch.slice(0, rows_left) - pa_batches.append(record_batch) - rows_left -= record_batch.num_rows - - # Break needs to be here; always need to process at least one record batch - if rows_left <= 0: - break - - # If source schema isn't determined, then the file was truly empty; set an empty source schema - if pa_schema is None: - pa_schema = pa.schema([]) - - daft_table = MicroPartition.from_arrow_record_batches(pa_batches, pa_schema) - assert len(daft_table) <= read_options.num_rows - - else: - pa_table = pacsv_stream.read_all() - daft_table = MicroPartition.from_arrow(pa_table) - - return _cast_table_to_schema(daft_table, read_options=read_options, schema=schema) + # TODO: move this logic into Rust + config = storage_config if storage_config is not None else StorageConfig(True, IOConfig()) + assert isinstance(file, (str, pathlib.Path)), "Native downloader only works on string or Path inputs to read_csv" + has_header = csv_options.header_index is not None + csv_convert_options = CsvConvertOptions( + limit=read_options.num_rows, + include_columns=read_options.column_names, + column_names=schema.column_names() if not has_header else None, + schema=schema._schema if schema is not None else None, + ) + csv_parse_options = CsvParseOptions( + has_header=has_header, + delimiter=csv_options.delimiter, + double_quote=csv_options.double_quote, + quote=csv_options.quote, + allow_variable_columns=csv_options.allow_variable_columns, + escape_char=csv_options.escape_char, + comment=csv_options.comment, + ) + csv_read_options = CsvReadOptions(buffer_size=csv_options.buffer_size, chunk_size=csv_options.chunk_size) + tbl = MicroPartition.read_csv( + str(file), + convert_options=csv_convert_options, + parse_options=csv_parse_options, + read_options=csv_read_options, + io_config=config.io_config, + ) + return _cast_table_to_schema(tbl, read_options=read_options, schema=schema) def partitioned_table_to_hive_iter(partitioned: PartitionedTable, root_path: str) -> Iterator[tuple[pa.Table, str]]: diff --git a/daft/udf_library/__init__.py b/daft/udf_library/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/daft/udf_library/url_udfs.py b/daft/udf_library/url_udfs.py deleted file 
mode 100644 index 37f4c485bc..0000000000 --- a/daft/udf_library/url_udfs.py +++ /dev/null @@ -1,88 +0,0 @@ -from __future__ import annotations - -import logging -import threading -from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Literal - -from daft import filesystem -from daft.datatype import DataType -from daft.series import Series -from daft.udf import udf - -thread_local = threading.local() - - -logger = logging.getLogger(__name__) - - -def _worker_thread_initializer() -> None: - """Initializes per-thread local state.""" - thread_local.filesystems_cache = {} - - -def _download(path: str | None, on_error: Literal["raise", "null"]) -> bytes | None: - if path is None: - return None - protocol = filesystem.get_protocol_from_path(path) - - # If no fsspec filesystem provided, first check the cache. - # If none in the cache, create one based on the path protocol. - fs = thread_local.filesystems_cache.get(protocol, None) - fs = filesystem.get_filesystem(protocol) - thread_local.filesystems_cache[protocol] = fs - - try: - return fs.cat_file(path) - except Exception as e: - if on_error == "raise": - raise - elif on_error == "null": - logger.error( - "Encountered error during download from URL %s and falling back to Null\n\n%s: %s", path, e, str(e) - ) - return None - else: - raise NotImplementedError(f"Unimplemented on_error option: {on_error}.\n\nEncountered error: {e}") - - -def _warmup_fsspec_registry(urls_pylist: list[str | None]) -> None: - """HACK: filesystem.get_filesystem calls fsspec.get_filesystem_class under the hood, which throws an error if accessed concurrently for the first time. We "warm" it up in a single-threaded fashion here. - - This should be fixed in the next release of FSSpec - See: https://github.com/Eventual-Inc/Daft/issues/892 - """ - import fsspec - - protocols = {filesystem.get_protocol_from_path(url) for url in urls_pylist if url is not None} - for protocol in protocols: - fsspec.get_filesystem_class(protocol) - - -@udf(return_dtype=DataType.binary()) -def download_udf( - urls, - max_worker_threads: int = 8, - on_error: Literal["raise", "null"] = "raise", -): - """Downloads the contents of the supplied URLs. - - Args: - urls: URLs as a UTF8 string series - max_worker_threads: max number of worker threads to use, defaults to 8 - on_error: Behavior when a URL download error is encountered - "raise" to raise the error immediately or "null" to log - the error but fallback to a Null value. Defaults to "raise". - fs (fsspec.AbstractFileSystem): fsspec FileSystem to use for downloading data. - By default, Daft will automatically construct a FileSystem instance internally. 
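With this thread-pool UDF deleted, `Expression.url.download` always runs through the native downloader. A small sketch of the equivalent call after the change (column names are illustrative):

```python
import daft

df = daft.from_pydict({"urls": ["https://www.google.com", None]})
# on_error="null" logs failures and falls back to null instead of raising.
df = df.with_column("data", df["urls"].url.download(max_connections=8, on_error="null"))
```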
- """ - urls_pylist = urls.to_arrow().to_pylist() - - _warmup_fsspec_registry(urls_pylist) - - executor = ThreadPoolExecutor(max_workers=max_worker_threads, initializer=_worker_thread_initializer) - results: list[bytes | None] = [None for _ in range(len(urls))] - future_to_idx = {executor.submit(_download, urls_pylist[i], on_error): i for i in range(len(urls))} - for future in as_completed(future_to_idx): - results[future_to_idx[future]] = future.result() - - return Series.from_pylist(results) diff --git a/src/daft-local-execution/src/sources/scan_task.rs b/src/daft-local-execution/src/sources/scan_task.rs index 04cadaa767..6e4a76e1fe 100644 --- a/src/daft-local-execution/src/sources/scan_task.rs +++ b/src/daft-local-execution/src/sources/scan_task.rs @@ -15,7 +15,7 @@ use daft_io::IOStatsRef; use daft_json::{JsonConvertOptions, JsonParseOptions, JsonReadOptions}; use daft_micropartition::MicroPartition; use daft_parquet::read::{read_parquet_bulk_async, ParquetSchemaInferenceOptions}; -use daft_scan::{storage_config::StorageConfig, ChunkSpec, ScanTask}; +use daft_scan::{ChunkSpec, ScanTask}; use futures::{Stream, StreamExt, TryStreamExt}; use snafu::ResultExt; use tracing::instrument; @@ -241,19 +241,14 @@ async fn stream_scan_task( } let source = scan_task.sources.first().unwrap(); let url = source.get_path(); - let (io_config, multi_threaded_io) = match scan_task.storage_config.as_ref() { - StorageConfig::Native(native_storage_config) => ( - native_storage_config.io_config.as_ref(), - native_storage_config.multithreaded_io, - ), - - #[cfg(feature = "python")] - StorageConfig::Python(python_storage_config) => { - (python_storage_config.io_config.as_ref(), true) - } - }; - let io_config = Arc::new(io_config.cloned().unwrap_or_default()); - let io_client = daft_io::get_io_client(multi_threaded_io, io_config)?; + let io_config = Arc::new( + scan_task + .storage_config + .io_config + .clone() + .unwrap_or_default(), + ); + let io_client = daft_io::get_io_client(scan_task.storage_config.multithreaded_io, io_config)?; let table_stream = match scan_task.file_format_config.as_ref() { FileFormatConfig::Parquet(ParquetSourceConfig { coerce_int96_timestamp_unit, diff --git a/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs b/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs index 9be863a686..170296fa2a 100644 --- a/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs +++ b/src/daft-logical-plan/src/ops/monotonically_increasing_id.rs @@ -17,7 +17,8 @@ impl MonotonicallyIncreasingId { let column_name = column_name.unwrap_or("id"); let mut schema_with_id_index_map = input.schema().fields.clone(); - schema_with_id_index_map.insert( + schema_with_id_index_map.shift_insert( + 0, column_name.to_string(), Field::new(column_name, DataType::UInt64), ); diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index 84b4e014dd..652ce3e57c 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -6,7 +6,9 @@ use std::{ use arrow2::io::parquet::read::schema::infer_schema_with_options; use common_error::DaftResult; -use common_file_formats::{CsvSourceConfig, FileFormatConfig, ParquetSourceConfig}; +#[cfg(feature = "python")] +use common_file_formats::DatabaseSourceConfig; +use common_file_formats::{FileFormatConfig, ParquetSourceConfig}; use common_runtime::get_io_runtime; use common_scan_info::Pushdowns; use daft_core::prelude::*; @@ -17,16 +19,11 @@ use 
daft_json::{JsonConvertOptions, JsonParseOptions, JsonReadOptions}; use daft_parquet::read::{ read_parquet_bulk, read_parquet_metadata_bulk, ParquetSchemaInferenceOptions, }; -use daft_scan::{ - storage_config::{NativeStorageConfig, StorageConfig}, - ChunkSpec, DataSource, ScanTask, -}; +use daft_scan::{storage_config::StorageConfig, ChunkSpec, DataSource, ScanTask}; use daft_stats::{PartitionSpec, TableMetadata, TableStatistics}; use daft_table::Table; use parquet2::metadata::FileMetaData; use snafu::ResultExt; -#[cfg(feature = "python")] -use {crate::PyIOSnafu, common_file_formats::DatabaseSourceConfig}; use crate::{DaftCSVSnafu, DaftCoreComputeSnafu}; @@ -112,280 +109,193 @@ fn materialize_scan_task( .iter() .map(daft_scan::DataSource::get_path); - let mut table_values = match scan_task.storage_config.as_ref() { - StorageConfig::Native(native_storage_config) => { - let multithreaded_io = native_storage_config.multithreaded_io; - let io_config = Arc::new(native_storage_config.io_config.clone().unwrap_or_default()); - let io_client = daft_io::get_io_client(multithreaded_io, io_config).unwrap(); - - match scan_task.file_format_config.as_ref() { - // ******************** - // Native Parquet Reads - // ******************** - FileFormatConfig::Parquet(ParquetSourceConfig { - coerce_int96_timestamp_unit, - field_id_mapping, - chunk_size, - .. - }) => { - let inference_options = - ParquetSchemaInferenceOptions::new(Some(*coerce_int96_timestamp_unit)); + let multithreaded_io = scan_task.storage_config.multithreaded_io; + let io_config = Arc::new( + scan_task + .storage_config + .io_config + .clone() + .unwrap_or_default(), + ); + let io_client = daft_io::get_io_client(multithreaded_io, io_config).unwrap(); + + let mut table_values = match scan_task.file_format_config.as_ref() { + // ******************** + // Native Parquet Reads + // ******************** + FileFormatConfig::Parquet(ParquetSourceConfig { + coerce_int96_timestamp_unit, + field_id_mapping, + chunk_size, + .. 
+ }) => { + let inference_options = + ParquetSchemaInferenceOptions::new(Some(*coerce_int96_timestamp_unit)); - // TODO: This is a hardcoded magic value but should be configurable - let num_parallel_tasks = 8; + // TODO: This is a hardcoded magic value but should be configurable + let num_parallel_tasks = 8; - let urls = urls.collect::>(); + let urls = urls.collect::>(); - // Create vec of all unique delete files in the scan task - let iceberg_delete_files = scan_task - .sources - .iter() - .filter_map(|s| s.get_iceberg_delete_files()) - .flatten() - .map(String::as_str) - .collect::>() - .into_iter() - .collect::>(); - - let delete_map = _read_delete_files( - iceberg_delete_files.as_slice(), - urls.as_slice(), - io_client.clone(), - io_stats.clone(), - num_parallel_tasks, - multithreaded_io, - &inference_options, - ) - .context(DaftCoreComputeSnafu)?; + // Create vec of all unique delete files in the scan task + let iceberg_delete_files = scan_task + .sources + .iter() + .filter_map(|s| s.get_iceberg_delete_files()) + .flatten() + .map(String::as_str) + .collect::>() + .into_iter() + .collect::>(); - let row_groups = parquet_sources_to_row_groups(scan_task.sources.as_slice()); - let metadatas = scan_task - .sources - .iter() - .map(|s| s.get_parquet_metadata().cloned()) - .collect::>>(); - daft_parquet::read::read_parquet_bulk( - urls.as_slice(), - file_column_names.as_deref(), - None, - scan_task.pushdowns.limit, - row_groups, - scan_task.pushdowns.filters.clone(), - io_client, - io_stats, - num_parallel_tasks, - multithreaded_io, - &inference_options, - field_id_mapping.clone(), - metadatas, - Some(delete_map), - *chunk_size, - ) - .context(DaftCoreComputeSnafu)? - } + let delete_map = _read_delete_files( + iceberg_delete_files.as_slice(), + urls.as_slice(), + io_client.clone(), + io_stats.clone(), + num_parallel_tasks, + multithreaded_io, + &inference_options, + ) + .context(DaftCoreComputeSnafu)?; - // **************** - // Native CSV Reads - // **************** - FileFormatConfig::Csv(cfg) => { - let schema_of_file = scan_task.schema.clone(); - let col_names = if !cfg.has_headers { - Some( - schema_of_file - .fields - .values() - .map(|f| f.name.as_str()) - .collect::>(), - ) - } else { - None - }; - let convert_options = CsvConvertOptions::new_internal( - scan_task.pushdowns.limit, - file_column_names - .as_ref() - .map(|cols| cols.iter().map(|col| (*col).to_string()).collect()), - col_names - .as_ref() - .map(|cols| cols.iter().map(|col| (*col).to_string()).collect()), - Some(schema_of_file), - scan_task.pushdowns.filters.clone(), - ); - let parse_options = CsvParseOptions::new_with_defaults( - cfg.has_headers, - cfg.delimiter, - cfg.double_quote, - cfg.quote, - cfg.allow_variable_columns, - cfg.escape_char, - cfg.comment, - ) - .context(DaftCSVSnafu)?; - let read_options = - CsvReadOptions::new_internal(cfg.buffer_size, cfg.chunk_size); - let uris = urls.collect::>(); - daft_csv::read_csv_bulk( - uris.as_slice(), - Some(convert_options), - Some(parse_options), - Some(read_options), - io_client, - io_stats, - native_storage_config.multithreaded_io, - None, - 8, - ) - .context(DaftCoreComputeSnafu)? 
- } + let row_groups = parquet_sources_to_row_groups(scan_task.sources.as_slice()); + let metadatas = scan_task + .sources + .iter() + .map(|s| s.get_parquet_metadata().cloned()) + .collect::>>(); + daft_parquet::read::read_parquet_bulk( + urls.as_slice(), + file_column_names.as_deref(), + None, + scan_task.pushdowns.limit, + row_groups, + scan_task.pushdowns.filters.clone(), + io_client, + io_stats, + num_parallel_tasks, + multithreaded_io, + &inference_options, + field_id_mapping.clone(), + metadatas, + Some(delete_map), + *chunk_size, + ) + .context(DaftCoreComputeSnafu)? + } - // **************** - // Native JSON Reads - // **************** - FileFormatConfig::Json(cfg) => { - let convert_options = JsonConvertOptions::new_internal( - scan_task.pushdowns.limit, - file_column_names - .as_ref() - .map(|cols| cols.iter().map(|col| (*col).to_string()).collect()), - Some(scan_task.schema.clone()), - scan_task.pushdowns.filters.clone(), - ); - let parse_options = JsonParseOptions::new_internal(); - let read_options = - JsonReadOptions::new_internal(cfg.buffer_size, cfg.chunk_size); - let uris = urls.collect::>(); - daft_json::read_json_bulk( - uris.as_slice(), - Some(convert_options), - Some(parse_options), - Some(read_options), - io_client, - io_stats, - native_storage_config.multithreaded_io, - None, - 8, - ) - .context(DaftCoreComputeSnafu)? - } - #[cfg(feature = "python")] - FileFormatConfig::Database(_) => { - return Err(common_error::DaftError::TypeError( - "Native reads for Database file format not implemented".to_string(), - )) - .context(DaftCoreComputeSnafu); - } - #[cfg(feature = "python")] - FileFormatConfig::PythonFunction => { - return Err(common_error::DaftError::TypeError( - "Native reads for PythonFunction file format not implemented".to_string(), - )) - .context(DaftCoreComputeSnafu); - } - } + // **************** + // Native CSV Reads + // **************** + FileFormatConfig::Csv(cfg) => { + let schema_of_file = scan_task.schema.clone(); + let col_names = if !cfg.has_headers { + Some( + schema_of_file + .fields + .values() + .map(|f| f.name.as_str()) + .collect::>(), + ) + } else { + None + }; + let convert_options = CsvConvertOptions::new_internal( + scan_task.pushdowns.limit, + file_column_names + .as_ref() + .map(|cols| cols.iter().map(|col| (*col).to_string()).collect()), + col_names + .as_ref() + .map(|cols| cols.iter().map(|col| (*col).to_string()).collect()), + Some(schema_of_file), + scan_task.pushdowns.filters.clone(), + ); + let parse_options = CsvParseOptions::new_with_defaults( + cfg.has_headers, + cfg.delimiter, + cfg.double_quote, + cfg.quote, + cfg.allow_variable_columns, + cfg.escape_char, + cfg.comment, + ) + .context(DaftCSVSnafu)?; + let read_options = CsvReadOptions::new_internal(cfg.buffer_size, cfg.chunk_size); + let uris = urls.collect::>(); + daft_csv::read_csv_bulk( + uris.as_slice(), + Some(convert_options), + Some(parse_options), + Some(read_options), + io_client, + io_stats, + scan_task.storage_config.multithreaded_io, + None, + 8, + ) + .context(DaftCoreComputeSnafu)? 
+ } + + // **************** + // Native JSON Reads + // **************** + FileFormatConfig::Json(cfg) => { + let convert_options = JsonConvertOptions::new_internal( + scan_task.pushdowns.limit, + file_column_names + .as_ref() + .map(|cols| cols.iter().map(|col| (*col).to_string()).collect()), + Some(scan_task.schema.clone()), + scan_task.pushdowns.filters.clone(), + ); + let parse_options = JsonParseOptions::new_internal(); + let read_options = JsonReadOptions::new_internal(cfg.buffer_size, cfg.chunk_size); + let uris = urls.collect::>(); + daft_json::read_json_bulk( + uris.as_slice(), + Some(convert_options), + Some(parse_options), + Some(read_options), + io_client, + io_stats, + scan_task.storage_config.multithreaded_io, + None, + 8, + ) + .context(DaftCoreComputeSnafu)? } #[cfg(feature = "python")] - StorageConfig::Python(_) => { - use pyo3::Python; - match scan_task.file_format_config.as_ref() { - FileFormatConfig::Parquet(ParquetSourceConfig { - coerce_int96_timestamp_unit, - .. - }) => Python::with_gil(|py| { - urls.map(|url| { - crate::python::read_parquet_into_py_table( - py, - url, - scan_task.schema.clone().into(), - (*coerce_int96_timestamp_unit).into(), - scan_task.storage_config.clone().into(), - scan_task - .pushdowns - .columns - .as_ref() - .map(|cols| cols.as_ref().clone()), - scan_task.pushdowns.limit, - ) - .map(std::convert::Into::into) - .context(PyIOSnafu) - }) - .collect::>>() - })?, - FileFormatConfig::Csv(CsvSourceConfig { - has_headers, - delimiter, - double_quote, - .. - }) => Python::with_gil(|py| { - urls.map(|url| { - crate::python::read_csv_into_py_table( - py, - url, - *has_headers, - *delimiter, - *double_quote, - scan_task.schema.clone().into(), - scan_task.storage_config.clone().into(), - scan_task - .pushdowns - .columns - .as_ref() - .map(|cols| cols.as_ref().clone()), - scan_task.pushdowns.limit, - ) - .map(std::convert::Into::into) - .context(PyIOSnafu) - }) - .collect::>>() - })?, - FileFormatConfig::Json(_) => Python::with_gil(|py| { - urls.map(|url| { - crate::python::read_json_into_py_table( - py, - url, - scan_task.schema.clone().into(), - scan_task.storage_config.clone().into(), - scan_task - .pushdowns - .columns - .as_ref() - .map(|cols| cols.as_ref().clone()), - scan_task.pushdowns.limit, - ) - .map(std::convert::Into::into) - .context(PyIOSnafu) - }) - .collect::>>() - })?, - FileFormatConfig::Database(DatabaseSourceConfig { sql, conn }) => { - let predicate = scan_task + FileFormatConfig::Database(DatabaseSourceConfig { sql, conn }) => { + let predicate = scan_task + .pushdowns + .filters + .as_ref() + .map(|p| (*p.as_ref()).clone().into()); + pyo3::Python::with_gil(|py| { + let table = crate::python::read_sql_into_py_table( + py, + sql, + conn, + predicate.clone(), + scan_task.schema.clone().into(), + scan_task .pushdowns - .filters + .columns .as_ref() - .map(|p| (*p.as_ref()).clone().into()); - Python::with_gil(|py| { - let table = crate::python::read_sql_into_py_table( - py, - sql, - conn, - predicate.clone(), - scan_task.schema.clone().into(), - scan_task - .pushdowns - .columns - .as_ref() - .map(|cols| cols.as_ref().clone()), - scan_task.pushdowns.limit, - ) - .map(std::convert::Into::into) - .context(PyIOSnafu)?; - Ok(vec![table]) - })? - } - FileFormatConfig::PythonFunction => { - let tables = crate::python::read_pyfunc_into_table_iter(&scan_task)?; - tables.collect::>>()? 
- } - } + .map(|cols| cols.as_ref().clone()), + scan_task.pushdowns.limit, + ) + .map(std::convert::Into::into) + .context(crate::PyIOSnafu)?; + Ok(vec![table]) + })? + } + #[cfg(feature = "python")] + FileFormatConfig::PythonFunction => { + let tables = crate::python::read_pyfunc_into_table_iter(&scan_task)?; + tables.collect::>>()? } }; @@ -473,11 +383,10 @@ impl MicroPartition { &scan_task.metadata, &scan_task.statistics, scan_task.file_format_config.as_ref(), - scan_task.storage_config.as_ref(), ) { // CASE: ScanTask provides all required metadata. // If the scan_task provides metadata (e.g. retrieved from a catalog) we can use it to create an unloaded MicroPartition - (Some(metadata), Some(statistics), _, _) if scan_task.pushdowns.filters.is_none() => { + (Some(metadata), Some(statistics), _) if scan_task.pushdowns.filters.is_none() => { Ok(Self::new_unloaded( scan_task.clone(), scan_task @@ -502,7 +411,6 @@ impl MicroPartition { chunk_size, .. }), - StorageConfig::Native(cfg), ) => { let uris = scan_task .sources @@ -538,10 +446,15 @@ impl MicroPartition { row_groups, scan_task.pushdowns.filters.clone(), scan_task.partition_spec(), - cfg.io_config.clone().map(Arc::new).unwrap_or_default(), + scan_task + .storage_config + .io_config + .clone() + .map(Arc::new) + .unwrap_or_default(), Some(io_stats), if scan_task.sources.len() == 1 { 1 } else { 128 }, // Hardcoded for to 128 bulk reads - cfg.multithreaded_io, + scan_task.storage_config.multithreaded_io, &ParquetSchemaInferenceOptions { coerce_int96_timestamp_unit, ..Default::default() @@ -677,7 +590,8 @@ impl MicroPartition { .collect::>>()?; let mut schema_with_id_index_map = self.schema.fields.clone(); - schema_with_id_index_map.insert( + schema_with_id_index_map.shift_insert( + 0, column_name.to_string(), Field::new(column_name, DataType::UInt64), ); @@ -1198,14 +1112,7 @@ pub fn read_parquet_into_micropartition>( }) .into(), scan_task_daft_schema, - StorageConfig::Native( - NativeStorageConfig::new_internal( - multithreaded_io, - Some(io_config.as_ref().clone()), - ) - .into(), - ) - .into(), + StorageConfig::new_internal(multithreaded_io, Some(io_config.as_ref().clone())).into(), Pushdowns::new( None, None, diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index 06c59c97fd..e062abb2b5 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -12,7 +12,7 @@ use daft_io::{python::IOConfig, IOStatsContext}; use daft_json::{JsonConvertOptions, JsonParseOptions, JsonReadOptions}; use daft_parquet::read::ParquetSchemaInferenceOptions; use daft_scan::{ - python::pylib::PyScanTask, storage_config::PyStorageConfig, DataSource, ScanTask, ScanTaskRef, + python::pylib::PyScanTask, storage_config::StorageConfig, DataSource, ScanTask, ScanTaskRef, }; use daft_stats::{TableMetadata, TableStatistics}; use daft_table::{python::PyTable, Table}; @@ -534,7 +534,7 @@ impl PyMicroPartition { py: Python, uri: &str, schema: PySchema, - storage_config: PyStorageConfig, + storage_config: StorageConfig, include_columns: Option>, num_rows: Option, ) -> PyResult { @@ -801,7 +801,7 @@ pub fn read_json_into_py_table( py: Python, uri: &str, schema: PySchema, - storage_config: PyStorageConfig, + storage_config: StorageConfig, include_columns: Option>, num_rows: Option, ) -> PyResult { @@ -831,7 +831,7 @@ pub fn read_csv_into_py_table( delimiter: Option, double_quote: bool, schema: PySchema, - storage_config: PyStorageConfig, + storage_config: StorageConfig, include_columns: Option>, 
num_rows: Option, ) -> PyResult { @@ -863,7 +863,7 @@ pub fn read_parquet_into_py_table( uri: &str, schema: PySchema, coerce_int96_timestamp_unit: PyTimeUnit, - storage_config: PyStorageConfig, + storage_config: StorageConfig, include_columns: Option>, num_rows: Option, ) -> PyResult { diff --git a/src/daft-scan/src/builder.rs b/src/daft-scan/src/builder.rs index 158c272495..b154f20d92 100644 --- a/src/daft-scan/src/builder.rs +++ b/src/daft-scan/src/builder.rs @@ -10,10 +10,7 @@ use daft_schema::{field::Field, schema::SchemaRef}; #[cfg(feature = "python")] use {crate::python::pylib::ScanOperatorHandle, pyo3::prelude::*}; -use crate::{ - glob::GlobScanOperator, - storage_config::{NativeStorageConfig, StorageConfig}, -}; +use crate::{glob::GlobScanOperator, storage_config::StorageConfig}; pub struct ParquetScanBuilder { pub glob_paths: Vec, @@ -109,9 +106,10 @@ impl ParquetScanBuilder { GlobScanOperator::try_new( self.glob_paths, Arc::new(FileFormatConfig::Parquet(cfg)), - Arc::new(StorageConfig::Native(Arc::new( - NativeStorageConfig::new_internal(self.multithreaded, self.io_config), - ))), + Arc::new(StorageConfig::new_internal( + self.multithreaded, + self.io_config, + )), self.infer_schema, self.schema, self.file_path_column, @@ -144,7 +142,6 @@ pub struct CsvScanBuilder { pub allow_variable_columns: bool, pub buffer_size: Option, pub chunk_size: Option, - pub use_native_downloader: bool, pub schema_hints: Option, } @@ -172,7 +169,6 @@ impl CsvScanBuilder { allow_variable_columns: false, buffer_size: None, chunk_size: None, - use_native_downloader: true, schema_hints: None, } } @@ -236,10 +232,6 @@ impl CsvScanBuilder { self.schema_hints = Some(schema_hints); self } - pub fn use_native_downloader(mut self, use_native_downloader: bool) -> Self { - self.use_native_downloader = use_native_downloader; - self - } pub async fn finish(self) -> DaftResult { let cfg = CsvSourceConfig { @@ -258,9 +250,7 @@ impl CsvScanBuilder { GlobScanOperator::try_new( self.glob_paths, Arc::new(FileFormatConfig::Csv(cfg)), - Arc::new(StorageConfig::Native(Arc::new( - NativeStorageConfig::new_internal(false, self.io_config), - ))), + Arc::new(StorageConfig::new_internal(false, self.io_config)), self.infer_schema, self.schema, self.file_path_column, @@ -279,25 +269,22 @@ pub fn delta_scan>( io_config: Option, multithreaded_io: bool, ) -> DaftResult { - use crate::storage_config::{NativeStorageConfig, PyStorageConfig, StorageConfig}; + use crate::storage_config::StorageConfig; Python::with_gil(|py| { let io_config = io_config.unwrap_or_default(); - let native_storage_config = NativeStorageConfig { + let storage_config = StorageConfig { io_config: Some(io_config), multithreaded_io, }; - let py_storage_config: PyStorageConfig = - Arc::new(StorageConfig::Native(Arc::new(native_storage_config))).into(); - // let py_io_config = PyIOConfig { config: io_config }; let delta_lake_scan = PyModule::import_bound(py, "daft.delta_lake.delta_lake_scan")?; let delta_lake_scan_operator = delta_lake_scan.getattr(pyo3::intern!(py, "DeltaLakeScanOperator"))?; let delta_lake_operator = delta_lake_scan_operator - .call1((glob_path.as_ref(), py_storage_config))? + .call1((glob_path.as_ref(), storage_config))? 
.to_object(py); let scan_operator_handle = ScanOperatorHandle::from_python_scan_operator(delta_lake_operator, py)?; diff --git a/src/daft-scan/src/lib.rs b/src/daft-scan/src/lib.rs index efae6d250c..2f984dc213 100644 --- a/src/daft-scan/src/lib.rs +++ b/src/daft-scan/src/lib.rs @@ -769,8 +769,7 @@ impl ScanTask { let storage_config = self.storage_config.multiline_display(); if !storage_config.is_empty() { res.push(format!( - "{} storage config = {{ {} }}", - self.storage_config.var_name(), + "storage config = {{ {} }}", storage_config.join(", ") )); } @@ -841,11 +840,7 @@ mod test { use daft_schema::{schema::Schema, time_unit::TimeUnit}; use itertools::Itertools; - use crate::{ - glob::GlobScanOperator, - storage_config::{NativeStorageConfig, StorageConfig}, - DataSource, ScanTask, - }; + use crate::{glob::GlobScanOperator, storage_config::StorageConfig, DataSource, ScanTask}; fn make_scan_task(num_sources: usize) -> ScanTask { let sources = (0..num_sources) @@ -872,9 +867,7 @@ mod test { sources, Arc::new(file_format_config), Arc::new(Schema::empty()), - Arc::new(StorageConfig::Native(Arc::new( - NativeStorageConfig::new_internal(false, None), - ))), + Arc::new(StorageConfig::new_internal(false, None)), Pushdowns::default(), None, ) @@ -897,9 +890,7 @@ mod test { let glob_scan_operator: GlobScanOperator = GlobScanOperator::try_new( sources, Arc::new(file_format_config), - Arc::new(StorageConfig::Native(Arc::new( - NativeStorageConfig::new_internal(false, None), - ))), + Arc::new(StorageConfig::new_internal(false, None)), false, Some(Arc::new(Schema::empty())), None, diff --git a/src/daft-scan/src/python.rs b/src/daft-scan/src/python.rs index e27d68ef3a..d6e0665047 100644 --- a/src/daft-scan/src/python.rs +++ b/src/daft-scan/src/python.rs @@ -4,7 +4,7 @@ use common_py_serde::{deserialize_py_object, serialize_py_object}; use pyo3::{prelude::*, types::PyTuple}; use serde::{Deserialize, Serialize}; -use crate::storage_config::{NativeStorageConfig, PyStorageConfig, PythonStorageConfig}; +use crate::storage_config::StorageConfig; #[derive(Debug, Clone, Serialize, Deserialize)] struct PyObjectSerializableWrapper( @@ -87,9 +87,7 @@ pub mod pylib { use super::PythonTablesFactoryArgs; use crate::{ - anonymous::AnonymousScanOperator, - glob::GlobScanOperator, - storage_config::{PyStorageConfig, PythonStorageConfig}, + anonymous::AnonymousScanOperator, glob::GlobScanOperator, storage_config::StorageConfig, DataSource, ScanTask, }; #[pyclass(module = "daft.daft", frozen)] @@ -110,7 +108,7 @@ pub mod pylib { files: Vec, schema: PySchema, file_format_config: PyFileFormatConfig, - storage_config: PyStorageConfig, + storage_config: StorageConfig, ) -> PyResult { py.allow_threads(|| { let schema = schema.schema; @@ -132,7 +130,7 @@ pub mod pylib { py: Python, glob_path: Vec, file_format_config: PyFileFormatConfig, - storage_config: PyStorageConfig, + storage_config: StorageConfig, hive_partitioning: bool, infer_schema: bool, schema: Option, @@ -346,7 +344,7 @@ pub mod pylib { file: String, file_format: PyFileFormatConfig, schema: PySchema, - storage_config: PyStorageConfig, + storage_config: StorageConfig, num_rows: Option, size_bytes: Option, iceberg_delete_files: Option>, @@ -410,7 +408,7 @@ pub mod pylib { url: String, file_format: PyFileFormatConfig, schema: PySchema, - storage_config: PyStorageConfig, + storage_config: StorageConfig, num_rows: Option, size_bytes: Option, pushdowns: Option, @@ -473,9 +471,7 @@ pub mod pylib { schema.schema, // HACK: StorageConfig isn't used when running the Python 
function but this is a non-optional arg for // ScanTask creation, so we just put in a placeholder here - Arc::new(crate::storage_config::StorageConfig::Python(Arc::new( - PythonStorageConfig { io_config: None }, - ))), + Arc::new(Default::default()), pushdowns.map(|p| p.0.as_ref().clone()).unwrap_or_default(), None, ); @@ -563,9 +559,7 @@ pub mod pylib { vec![data_source], Arc::new(FileFormatConfig::Parquet(default::Default::default())), Arc::new(schema), - Arc::new(crate::storage_config::StorageConfig::Native(Arc::new( - default::Default::default(), - ))), + Arc::new(Default::default()), Pushdowns::new(None, None, columns.map(Arc::new), None), None, ); @@ -574,9 +568,7 @@ pub mod pylib { } pub fn register_modules(parent: &Bound) -> PyResult<()> { - parent.add_class::()?; - parent.add_class::()?; - parent.add_class::()?; + parent.add_class::()?; parent.add_class::()?; parent.add_class::()?; diff --git a/src/daft-scan/src/scan_task_iters.rs b/src/daft-scan/src/scan_task_iters.rs index 3ee2a18ccd..8c38c1ecec 100644 --- a/src/daft-scan/src/scan_task_iters.rs +++ b/src/daft-scan/src/scan_task_iters.rs @@ -8,9 +8,7 @@ use daft_io::IOStatsContext; use daft_parquet::read::read_parquet_metadata; use parquet2::metadata::RowGroupList; -use crate::{ - storage_config::StorageConfig, ChunkSpec, DataSource, Pushdowns, ScanTask, ScanTaskRef, -}; +use crate::{ChunkSpec, DataSource, Pushdowns, ScanTask, ScanTaskRef}; pub(crate) type BoxScanTaskIter<'a> = Box> + 'a>; @@ -206,13 +204,11 @@ pub(crate) fn split_by_row_groups( FileFormatConfig::Parquet(ParquetSourceConfig { field_id_mapping, .. }), - StorageConfig::Native(_), [source], Some(None), None, ) = ( t.file_format_config.as_ref(), - t.storage_config.as_ref(), &t.sources[..], t.sources.first().map(DataSource::get_chunk_spec), t.pushdowns.limit, diff --git a/src/daft-scan/src/storage_config.rs b/src/daft-scan/src/storage_config.rs index c502ae62dc..15efe8f177 100644 --- a/src/daft-scan/src/storage_config.rs +++ b/src/daft-scan/src/storage_config.rs @@ -9,76 +9,21 @@ use serde::{Deserialize, Serialize}; #[cfg(feature = "python")] use { common_io_config::python, - pyo3::{pyclass, pymethods, IntoPy, PyObject, PyResult, Python}, - std::hash::{Hash, Hasher}, + pyo3::{pyclass, pymethods, PyObject, PyResult, Python}, + std::hash::Hash, }; /// Configuration for interacting with a particular storage backend, using a particular /// I/O layer implementation. 
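The surrounding hunk replaces the old Native/Python `StorageConfig` enum with a single struct whose Python constructor takes `(multithreaded_io, io_config)`. A rough sketch of constructing it from Python under the new bindings, assuming `IOConfig` and `StorageConfig` remain importable from `daft.daft` as the registration code in this patch suggests:

```python
# Assumption: both classes are exposed on the daft.daft extension module.
from daft.daft import IOConfig, StorageConfig

# Multithreaded IO with a default IOConfig, matching the table_io.py change above.
config = StorageConfig(True, IOConfig())
```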
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] -pub enum StorageConfig { - Native(Arc), - #[cfg(feature = "python")] - Python(Arc), -} - -impl StorageConfig { - pub fn get_io_client_and_runtime(&self) -> DaftResult<(RuntimeRef, Arc)> { - // Grab an IOClient and Runtime - // TODO: This should be cleaned up and hidden behind a better API from daft-io - match self { - Self::Native(cfg) => { - let multithreaded_io = cfg.multithreaded_io; - Ok(( - get_io_runtime(multithreaded_io), - get_io_client( - multithreaded_io, - Arc::new(cfg.io_config.clone().unwrap_or_default()), - )?, - )) - } - #[cfg(feature = "python")] - Self::Python(cfg) => { - let multithreaded_io = true; // Hardcode to use multithreaded IO if Python storage config is used for data fetches - Ok(( - get_io_runtime(multithreaded_io), - get_io_client( - multithreaded_io, - Arc::new(cfg.io_config.clone().unwrap_or_default()), - )?, - )) - } - } - } - - #[must_use] - pub fn var_name(&self) -> &'static str { - match self { - Self::Native(_) => "Native", - #[cfg(feature = "python")] - Self::Python(_) => "Python", - } - } - - #[must_use] - pub fn multiline_display(&self) -> Vec { - match self { - Self::Native(source) => source.multiline_display(), - #[cfg(feature = "python")] - Self::Python(source) => source.multiline_display(), - } - } -} - -/// Storage configuration for the Rust-native I/O layer. -#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] #[cfg_attr(feature = "python", pyclass(module = "daft.daft"))] -pub struct NativeStorageConfig { +pub struct StorageConfig { + // TODO: store Arc instead pub io_config: Option, pub multithreaded_io: bool, } -impl NativeStorageConfig { +impl StorageConfig { #[must_use] pub fn new_internal(multithreaded_io: bool, io_config: Option) -> Self { Self { @@ -87,6 +32,19 @@ impl NativeStorageConfig { } } + pub fn get_io_client_and_runtime(&self) -> DaftResult<(RuntimeRef, Arc)> { + // Grab an IOClient and Runtime + // TODO: This should be cleaned up and hidden behind a better API from daft-io + let multithreaded_io = self.multithreaded_io; + Ok(( + get_io_runtime(multithreaded_io), + get_io_client( + multithreaded_io, + Arc::new(self.io_config.clone().unwrap_or_default()), + )?, + )) + } + #[must_use] pub fn multiline_display(&self) -> Vec { let mut res = vec![]; @@ -101,7 +59,7 @@ impl NativeStorageConfig { } } -impl Default for NativeStorageConfig { +impl Default for StorageConfig { fn default() -> Self { Self::new_internal(true, None) } @@ -109,7 +67,7 @@ impl Default for NativeStorageConfig { #[cfg(feature = "python")] #[pymethods] -impl NativeStorageConfig { +impl StorageConfig { #[new] #[must_use] pub fn new(multithreaded_io: bool, io_config: Option) -> Self { @@ -129,114 +87,4 @@ impl NativeStorageConfig { } } -/// Storage configuration for the legacy Python I/O layer. 
-#[derive(Clone, Debug, Serialize, Deserialize)] -#[cfg(feature = "python")] -#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))] -pub struct PythonStorageConfig { - /// IOConfig is used when constructing Python filesystems (PyArrow or fsspec filesystems) - /// and also used for globbing (since we have no Python-based globbing anymore) - pub io_config: Option, -} - -#[cfg(feature = "python")] -impl PythonStorageConfig { - #[must_use] - pub fn multiline_display(&self) -> Vec { - let mut res = vec![]; - if let Some(io_config) = &self.io_config { - res.push(format!( - "IO config = {}", - io_config.multiline_display().join(", ") - )); - } - res - } -} - -#[cfg(feature = "python")] -#[pymethods] -impl PythonStorageConfig { - #[new] - #[must_use] - pub fn new(io_config: Option) -> Self { - Self { - io_config: io_config.map(|c| c.config), - } - } - - #[getter] - #[must_use] - pub fn io_config(&self) -> Option { - self.io_config - .as_ref() - .map(|c| python::IOConfig { config: c.clone() }) - } -} - -#[cfg(feature = "python")] -impl PartialEq for PythonStorageConfig { - fn eq(&self, other: &Self) -> bool { - self.io_config.eq(&other.io_config) - } -} - -#[cfg(feature = "python")] -impl Eq for PythonStorageConfig {} - -#[cfg(feature = "python")] -impl Hash for PythonStorageConfig { - fn hash(&self, state: &mut H) { - self.io_config.hash(state); - } -} - -/// A Python-exposed interface for storage configs. -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(transparent)] -#[cfg_attr( - feature = "python", - pyclass(module = "daft.daft", name = "StorageConfig") -)] -pub struct PyStorageConfig(Arc); - -#[cfg(feature = "python")] -#[pymethods] -impl PyStorageConfig { - /// Create from a native storage config. - #[staticmethod] - fn native(config: NativeStorageConfig) -> Self { - Self(Arc::new(StorageConfig::Native(config.into()))) - } - - /// Create from a Python storage config. - #[staticmethod] - fn python(config: PythonStorageConfig) -> Self { - Self(Arc::new(StorageConfig::Python(config.into()))) - } - - /// Get the underlying storage config. 
- #[getter] - fn get_config(&self, py: Python) -> PyObject { - use StorageConfig::{Native, Python}; - - match self.0.as_ref() { - Native(config) => config.as_ref().clone().into_py(py), - Python(config) => config.as_ref().clone().into_py(py), - } - } -} - -impl_bincode_py_state_serialization!(PyStorageConfig); - -impl From for Arc { - fn from(value: PyStorageConfig) -> Self { - value.0 - } -} - -impl From> for PyStorageConfig { - fn from(value: Arc) -> Self { - Self(value) - } -} +impl_bincode_py_state_serialization!(StorageConfig); diff --git a/src/daft-schema/src/schema.rs b/src/daft-schema/src/schema.rs index a1fc464e96..65a37bf9ad 100644 --- a/src/daft-schema/src/schema.rs +++ b/src/daft-schema/src/schema.rs @@ -17,7 +17,7 @@ use crate::{field::Field, prelude::DataType}; pub type SchemaRef = Arc; -#[derive(Debug, Display, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Display, Serialize, Deserialize)] #[serde(transparent)] #[display("{}\n", make_schema_vertical_table( fields.iter().map(|(name, field)| (name.clone(), field.dtype.to_string())) @@ -334,3 +334,17 @@ impl TryFrom<&arrow2::datatypes::Schema> for Schema { Self::new(daft_fields) } } + +/// Custom impl of PartialEq because IndexMap PartialEq does not check for ordering +impl PartialEq for Schema { + fn eq(&self, other: &Self) -> bool { + self.fields.len() == other.fields.len() + && self + .fields + .iter() + .zip(other.fields.iter()) + .all(|(s, o)| s == o) + } +} + +impl Eq for Schema {} diff --git a/src/daft-sql/src/table_provider/read_csv.rs b/src/daft-sql/src/table_provider/read_csv.rs index 241fad455f..0ced7ea5a9 100644 --- a/src/daft-sql/src/table_provider/read_csv.rs +++ b/src/daft-sql/src/table_provider/read_csv.rs @@ -46,7 +46,6 @@ impl TryFrom for CsvScanBuilder { let buffer_size = args.try_get_named("buffer_size")?; let file_path_column = args.try_get_named("file_path_column")?; let hive_partitioning = args.try_get_named("hive_partitioning")?.unwrap_or(false); - let use_native_downloader = args.try_get_named("use_native_downloader")?.unwrap_or(true); let schema = None; // TODO let schema_hints = None; // TODO let io_config = args.get_named("io_config").map(expr_to_iocfg).transpose()?; @@ -65,7 +64,6 @@ impl TryFrom for CsvScanBuilder { io_config, file_path_column, hive_partitioning, - use_native_downloader, schema_hints, buffer_size, chunk_size, @@ -95,7 +93,6 @@ impl SQLTableFunction for ReadCsvFunction { "io_config", "file_path_column", "hive_partitioning", - "use_native_downloader", // "schema_hints", "buffer_size", "chunk_size", diff --git a/tests/dataframe/test_creation.py b/tests/dataframe/test_creation.py index 4772919c14..3ef0f56f5a 100644 --- a/tests/dataframe/test_creation.py +++ b/tests/dataframe/test_creation.py @@ -367,8 +367,7 @@ def create_temp_filename() -> str: yield os.path.join(dir, "tempfile") -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_create_dataframe_csv(valid_data: list[dict[str, float]], use_native_downloader) -> None: +def test_create_dataframe_csv(valid_data: list[dict[str, float]]) -> None: with create_temp_filename() as fname: with open(fname, "w") as f: header = list(valid_data[0].keys()) @@ -377,7 +376,7 @@ def test_create_dataframe_csv(valid_data: list[dict[str, float]], use_native_dow writer.writerows([[item[col] for col in header] for item in valid_data]) f.flush() - df = daft.read_csv(fname, use_native_downloader=use_native_downloader) + df = daft.read_csv(fname) assert df.column_names == COL_NAMES pd_df = df.to_pandas() @@ -385,8 +384,7 @@ 
def test_create_dataframe_csv(valid_data: list[dict[str, float]], use_native_dow assert len(pd_df) == len(valid_data) -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_create_dataframe_multiple_csvs(valid_data: list[dict[str, float]], use_native_downloader) -> None: +def test_create_dataframe_multiple_csvs(valid_data: list[dict[str, float]]) -> None: with create_temp_filename() as f1name, create_temp_filename() as f2name: with open(f1name, "w") as f1, open(f2name, "w") as f2: for f in (f1, f2): @@ -396,7 +394,7 @@ def test_create_dataframe_multiple_csvs(valid_data: list[dict[str, float]], use_ writer.writerows([[item[col] for col in header] for item in valid_data]) f.flush() - df = daft.read_csv([f1name, f2name], use_native_downloader=use_native_downloader) + df = daft.read_csv([f1name, f2name]) assert df.column_names == COL_NAMES pd_df = df.to_pandas() @@ -475,8 +473,7 @@ def test_create_dataframe_csv_with_file_path_column_duplicate_field_names() -> N daft.read_json(fname, file_path_column="path") -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_create_dataframe_csv_generate_headers(valid_data: list[dict[str, float]], use_native_downloader) -> None: +def test_create_dataframe_csv_generate_headers(valid_data: list[dict[str, float]]) -> None: with create_temp_filename() as fname: with open(fname, "w") as f: header = list(valid_data[0].keys()) @@ -485,7 +482,7 @@ def test_create_dataframe_csv_generate_headers(valid_data: list[dict[str, float] f.flush() cnames = [f"column_{i}" for i in range(1, 6)] - df = daft.read_csv(fname, has_headers=False, use_native_downloader=use_native_downloader) + df = daft.read_csv(fname, has_headers=False) assert df.column_names == cnames pd_df = df.to_pandas() @@ -493,8 +490,7 @@ def test_create_dataframe_csv_generate_headers(valid_data: list[dict[str, float] assert len(pd_df) == len(valid_data) -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_create_dataframe_csv_column_projection(valid_data: list[dict[str, float]], use_native_downloader) -> None: +def test_create_dataframe_csv_column_projection(valid_data: list[dict[str, float]]) -> None: with create_temp_filename() as fname: with open(fname, "w") as f: header = list(valid_data[0].keys()) @@ -514,8 +510,7 @@ def test_create_dataframe_csv_column_projection(valid_data: list[dict[str, float assert len(pd_df) == len(valid_data) -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_create_dataframe_csv_custom_delimiter(valid_data: list[dict[str, float]], use_native_downloader) -> None: +def test_create_dataframe_csv_custom_delimiter(valid_data: list[dict[str, float]]) -> None: with create_temp_filename() as fname: with open(fname, "w") as f: header = list(valid_data[0].keys()) @@ -524,7 +519,7 @@ def test_create_dataframe_csv_custom_delimiter(valid_data: list[dict[str, float] writer.writerows([[item[col] for col in header] for item in valid_data]) f.flush() - df = daft.read_csv(fname, delimiter="\t", use_native_downloader=use_native_downloader) + df = daft.read_csv(fname, delimiter="\t") assert df.column_names == COL_NAMES pd_df = df.to_pandas() @@ -532,8 +527,7 @@ def test_create_dataframe_csv_custom_delimiter(valid_data: list[dict[str, float] assert len(pd_df) == len(valid_data) -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_create_dataframe_csv_provided_schema(valid_data: list[dict[str, float]], use_native_downloader) -> None: +def test_create_dataframe_csv_provided_schema(valid_data: 
list[dict[str, float]]) -> None: with create_temp_filename() as fname: with open(fname, "w") as f: header = list(valid_data[0].keys()) @@ -553,7 +547,6 @@ def test_create_dataframe_csv_provided_schema(valid_data: list[dict[str, float]] "p_width": DataType.string(), "variety": DataType.string(), }, - use_native_downloader=use_native_downloader, ) assert df.column_names == ["s_length", "s_width", "p_length", "p_width", "variety"] assert df.schema()["s_length"].dtype == DataType.float32() @@ -567,10 +560,7 @@ def test_create_dataframe_csv_provided_schema(valid_data: list[dict[str, float]] assert len(pd_df) == len(valid_data) -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_create_dataframe_csv_provided_schema_no_headers( - valid_data: list[dict[str, float]], use_native_downloader -) -> None: +def test_create_dataframe_csv_provided_schema_no_headers(valid_data: list[dict[str, float]]) -> None: with create_temp_filename() as fname: with open(fname, "w") as f: header = list(valid_data[0].keys()) @@ -592,7 +582,6 @@ def test_create_dataframe_csv_provided_schema_no_headers( infer_schema=False, schema=schema_for_csv_without_headers, has_headers=False, - use_native_downloader=use_native_downloader, ) assert df.column_names == list(schema_for_csv_without_headers.keys()) @@ -601,8 +590,7 @@ def test_create_dataframe_csv_provided_schema_no_headers( assert len(pd_df) == len(valid_data) -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_create_dataframe_csv_schema_hints_partial(valid_data: list[dict[str, float]], use_native_downloader) -> None: +def test_create_dataframe_csv_schema_hints_partial(valid_data: list[dict[str, float]]) -> None: with create_temp_filename() as fname: with open(fname, "w") as f: header = list(valid_data[0].keys()) @@ -619,7 +607,6 @@ def test_create_dataframe_csv_schema_hints_partial(valid_data: list[dict[str, fl "sepal_length": DataType.float64(), "sepal_width": DataType.float64(), }, - use_native_downloader=use_native_downloader, ) assert df.column_names == COL_NAMES @@ -628,10 +615,7 @@ def test_create_dataframe_csv_schema_hints_partial(valid_data: list[dict[str, fl assert len(pd_df) == len(valid_data) -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_create_dataframe_csv_schema_hints_override_types( - valid_data: list[dict[str, float]], use_native_downloader -) -> None: +def test_create_dataframe_csv_schema_hints_override_types(valid_data: list[dict[str, float]]) -> None: with create_temp_filename() as fname: with open(fname, "w") as f: header = list(valid_data[0].keys()) @@ -647,7 +631,6 @@ def test_create_dataframe_csv_schema_hints_override_types( schema={ "sepal_length": DataType.string(), # Override the inferred float64 type to string }, - use_native_downloader=use_native_downloader, ) assert df.column_names == COL_NAMES @@ -659,10 +642,7 @@ def test_create_dataframe_csv_schema_hints_override_types( assert pd_df["sepal_length"][0] == str(valid_data[0]["sepal_length"]) -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_create_dataframe_csv_schema_hints_ignore_random_hint( - valid_data: list[dict[str, float]], use_native_downloader -) -> None: +def test_create_dataframe_csv_schema_hints_ignore_random_hint(valid_data: list[dict[str, float]]) -> None: with create_temp_filename() as fname: with open(fname, "w") as f: header = list(valid_data[0].keys()) @@ -678,7 +658,6 @@ def test_create_dataframe_csv_schema_hints_ignore_random_hint( schema={ "foo": DataType.string(), # 
Random column name that is not in the table }, - use_native_downloader=use_native_downloader, ) assert df.column_names == COL_NAMES @@ -687,10 +666,7 @@ def test_create_dataframe_csv_schema_hints_ignore_random_hint( assert len(pd_df) == len(valid_data) -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_create_dataframe_csv_without_schema_or_inference( - valid_data: list[dict[str, float]], use_native_downloader -) -> None: +def test_create_dataframe_csv_without_schema_or_inference(valid_data: list[dict[str, float]]) -> None: with create_temp_filename() as fname: with open(fname, "w") as f: header = list(valid_data[0].keys()) @@ -704,7 +680,6 @@ def test_create_dataframe_csv_without_schema_or_inference( fname, delimiter="\t", infer_schema=False, - use_native_downloader=use_native_downloader, ) @@ -713,8 +688,7 @@ def test_create_dataframe_csv_without_schema_or_inference( ### -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_create_dataframe_json(valid_data: list[dict[str, float]], use_native_downloader) -> None: +def test_create_dataframe_json(valid_data: list[dict[str, float]]) -> None: with create_temp_filename() as fname: with open(fname, "w") as f: for data in valid_data: @@ -722,7 +696,7 @@ def test_create_dataframe_json(valid_data: list[dict[str, float]], use_native_do f.write("\n") f.flush() - df = daft.read_json(fname, use_native_downloader=use_native_downloader) + df = daft.read_json(fname) assert df.column_names == COL_NAMES pd_df = df.to_pandas() @@ -730,8 +704,7 @@ def test_create_dataframe_json(valid_data: list[dict[str, float]], use_native_do assert len(pd_df) == len(valid_data) -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_create_dataframe_multiple_jsons(valid_data: list[dict[str, float]], use_native_downloader) -> None: +def test_create_dataframe_multiple_jsons(valid_data: list[dict[str, float]]) -> None: with create_temp_filename() as f1name, create_temp_filename() as f2name: with open(f1name, "w") as f1, open(f2name, "w") as f2: for f in (f1, f2): @@ -740,7 +713,7 @@ def test_create_dataframe_multiple_jsons(valid_data: list[dict[str, float]], use f.write("\n") f.flush() - df = daft.read_json([f1name, f2name], use_native_downloader=use_native_downloader) + df = daft.read_json([f1name, f2name]) assert df.column_names == COL_NAMES pd_df = df.to_pandas() @@ -815,8 +788,7 @@ def test_create_dataframe_json_with_file_path_column_duplicate_field_names() -> daft.read_json(fname, file_path_column="path") -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_create_dataframe_json_column_projection(valid_data: list[dict[str, float]], use_native_downloader) -> None: +def test_create_dataframe_json_column_projection(valid_data: list[dict[str, float]]) -> None: with create_temp_filename() as fname: with open(fname, "w") as f: for data in valid_data: @@ -826,7 +798,7 @@ def test_create_dataframe_json_column_projection(valid_data: list[dict[str, floa col_subset = COL_NAMES[:3] - df = daft.read_json(fname, use_native_downloader=use_native_downloader) + df = daft.read_json(fname) df = df.select(*col_subset) assert df.column_names == col_subset @@ -835,21 +807,17 @@ def test_create_dataframe_json_column_projection(valid_data: list[dict[str, floa assert len(pd_df) == len(valid_data) -# TODO(Clark): Debug why this segfaults for the native downloader and is slow for the Python downloader. 
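These test updates drop the `use_native_downloader` parametrization, so CSV and JSON reads always go through the native path. A minimal sketch of the surviving API surface, with a placeholder file name:

```python
import daft
from daft.datatype import DataType

# Illustrative: schema hints pass straight to read_csv now that the
# use_native_downloader flag is gone; "iris.csv" is a placeholder path.
df = daft.read_csv("iris.csv", schema={"sepal_length": DataType.string()})
```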
-# @pytest.mark.parametrize("use_native_downloader", [True, False]) @pytest.mark.skip def test_create_dataframe_json_https() -> None: df = daft.read_json( "https://github.com/Eventual-Inc/mnist-json/raw/master/mnist_handwritten_test.json.gz", - # use_native_downloader=use_native_downloader, ) df.collect() assert set(df.column_names) == {"label", "image"} assert len(df) == 10000 -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_create_dataframe_json_provided_schema(valid_data: list[dict[str, float]], use_native_downloader) -> None: +def test_create_dataframe_json_provided_schema(valid_data: list[dict[str, float]]) -> None: with create_temp_filename() as fname: with open(fname, "w") as f: for data in valid_data: @@ -867,7 +835,6 @@ def test_create_dataframe_json_provided_schema(valid_data: list[dict[str, float] "petal_width": DataType.float32(), "variety": DataType.string(), }, - use_native_downloader=use_native_downloader, ) assert df.column_names == COL_NAMES diff --git a/tests/dataframe/test_temporals.py b/tests/dataframe/test_temporals.py index 52b70b4a46..a719025764 100644 --- a/tests/dataframe/test_temporals.py +++ b/tests/dataframe/test_temporals.py @@ -43,8 +43,7 @@ def test_temporal_arithmetic_with_same_type() -> None: @pytest.mark.parametrize("format", ["csv", "parquet"]) -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_temporal_file_roundtrip(format, use_native_downloader) -> None: +def test_temporal_file_roundtrip(format) -> None: data = { "date32": pa.array([1], pa.date32()), "date64": pa.array([1], pa.date64()), @@ -86,10 +85,10 @@ def test_temporal_file_roundtrip(format, use_native_downloader) -> None: with tempfile.TemporaryDirectory() as dirname: if format == "csv": df.write_csv(dirname) - df_readback = daft.read_csv(dirname, use_native_downloader=use_native_downloader).collect() + df_readback = daft.read_csv(dirname).collect() elif format == "parquet": df.write_parquet(dirname) - df_readback = daft.read_parquet(dirname, use_native_downloader=use_native_downloader).collect() + df_readback = daft.read_parquet(dirname).collect() assert df.to_pydict() == df_readback.to_pydict() diff --git a/tests/integration/io/test_url_download_http.py b/tests/integration/io/test_url_download_http.py index 03b96ef921..94406c0dd2 100644 --- a/tests/integration/io/test_url_download_http.py +++ b/tests/integration/io/test_url_download_http.py @@ -6,11 +6,10 @@ @pytest.mark.integration() -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_url_download_http(mock_http_image_urls, image_data, use_native_downloader): +def test_url_download_http(mock_http_image_urls, image_data): data = {"urls": mock_http_image_urls} df = daft.from_pydict(data) - df = df.with_column("data", df["urls"].url.download(use_native_downloader=use_native_downloader)) + df = df.with_column("data", df["urls"].url.download()) assert df.to_pydict() == {**data, "data": [image_data for _ in range(len(mock_http_image_urls))]} @@ -20,7 +19,7 @@ def test_url_download_http_error_codes(nginx_config, status_code): server_url, _ = nginx_config data = {"urls": [f"{server_url}/{status_code}.html"]} df = daft.from_pydict(data) - df = df.with_column("data", df["urls"].url.download(on_error="raise", use_native_downloader=True)) + df = df.with_column("data", df["urls"].url.download(on_error="raise")) # 404 should always be corner-cased to return FileNotFoundError if status_code == 404: diff --git a/tests/integration/io/test_url_download_private_aws_s3.py 
b/tests/integration/io/test_url_download_private_aws_s3.py index 8d0f795c74..25bcca3705 100644 --- a/tests/integration/io/test_url_download_private_aws_s3.py +++ b/tests/integration/io/test_url_download_private_aws_s3.py @@ -27,7 +27,7 @@ def io_config(pytestconfig) -> IOConfig: def test_url_download_aws_s3_public_bucket_with_creds(small_images_s3_paths, io_config): data = {"urls": small_images_s3_paths} df = daft.from_pydict(data) - df = df.with_column("data", df["urls"].url.download(use_native_downloader=True, io_config=io_config)) + df = df.with_column("data", df["urls"].url.download(io_config=io_config)) data = df.to_pydict() assert len(data["data"]) == 12 @@ -38,5 +38,5 @@ def test_url_download_aws_s3_public_bucket_with_creds(small_images_s3_paths, io_ @pytest.mark.integration() def test_read_parquet_aws_s3_public_bucket_with_creds(io_config): filename = "s3://daft-public-data/test_fixtures/parquet-dev/mvp.parquet" - df = daft.read_parquet(filename, io_config=io_config, use_native_downloader=True).collect() + df = daft.read_parquet(filename, io_config=io_config).collect() assert len(df) == 100 diff --git a/tests/integration/io/test_url_download_public_aws_s3.py b/tests/integration/io/test_url_download_public_aws_s3.py index 67c2c9422d..37a41e93a9 100644 --- a/tests/integration/io/test_url_download_public_aws_s3.py +++ b/tests/integration/io/test_url_download_public_aws_s3.py @@ -38,7 +38,7 @@ def test_url_download_aws_s3_public_bucket_custom_s3fs_wrong_region(small_images def test_url_download_aws_s3_public_bucket_native_downloader(aws_public_s3_config, small_images_s3_paths): data = {"urls": small_images_s3_paths} df = daft.from_pydict(data) - df = df.with_column("data", df["urls"].url.download(io_config=aws_public_s3_config, use_native_downloader=True)) + df = df.with_column("data", df["urls"].url.download(io_config=aws_public_s3_config)) data = df.to_pydict() assert len(data["data"]) == 12 @@ -61,9 +61,7 @@ def test_url_download_aws_s3_public_bucket_native_downloader_with_connect_timeou ) with pytest.raises((ReadTimeoutError, ConnectTimeoutError), match="timed out when trying to connect to"): - df = df.with_column( - "data", df["urls"].url.download(io_config=connect_timeout_config, use_native_downloader=True) - ).collect() + df = df.with_column("data", df["urls"].url.download(io_config=connect_timeout_config)).collect() @pytest.mark.integration() @@ -81,6 +79,4 @@ def test_url_download_aws_s3_public_bucket_native_downloader_with_read_timeout(s ) with pytest.raises((ReadTimeoutError, ConnectTimeoutError), match="timed out when trying to connect to"): - df = df.with_column( - "data", df["urls"].url.download(io_config=read_timeout_config, use_native_downloader=True) - ).collect() + df = df.with_column("data", df["urls"].url.download(io_config=read_timeout_config)).collect() diff --git a/tests/integration/io/test_url_download_public_azure.py b/tests/integration/io/test_url_download_public_azure.py index 74f28bff6b..049f16b4dd 100644 --- a/tests/integration/io/test_url_download_public_azure.py +++ b/tests/integration/io/test_url_download_public_azure.py @@ -9,9 +9,7 @@ def test_url_download_public_azure(azure_storage_public_config) -> None: data = {"urls": ["az://public-anonymous/mvp.parquet"]} df = daft.from_pydict(data) - df = df.with_column( - "data", df["urls"].url.download(io_config=azure_storage_public_config, use_native_downloader=True) - ) + df = df.with_column("data", df["urls"].url.download(io_config=azure_storage_public_config)) data = df.to_pydict() assert len(data["data"]) 
== 1 diff --git a/tests/integration/io/test_url_download_s3_local_retry_server.py b/tests/integration/io/test_url_download_s3_local_retry_server.py index 686e1caf79..5935c56a81 100644 --- a/tests/integration/io/test_url_download_s3_local_retry_server.py +++ b/tests/integration/io/test_url_download_s3_local_retry_server.py @@ -10,7 +10,5 @@ def test_url_download_local_retry_server(retry_server_s3_config): bucket = "80-per-second-rate-limited-gets-bucket" data = {"urls": [f"s3://{bucket}/foo{i}" for i in range(100)]} df = daft.from_pydict(data) - df = df.with_column( - "data", df["urls"].url.download(io_config=retry_server_s3_config, use_native_downloader=True, on_error="null") - ) + df = df.with_column("data", df["urls"].url.download(io_config=retry_server_s3_config, on_error="null")) assert df.to_pydict() == {**data, "data": [f"foo{i}".encode() for i in range(100)]} diff --git a/tests/io/test_parquet.py b/tests/io/test_parquet.py index 52ff279349..2557809012 100644 --- a/tests/io/test_parquet.py +++ b/tests/io/test_parquet.py @@ -12,7 +12,6 @@ import pytest import daft -from daft.daft import NativeStorageConfig, PythonStorageConfig, StorageConfig from daft.datatype import DataType, TimeUnit from daft.expressions import col from daft.logical.schema import Schema @@ -37,16 +36,8 @@ def _parquet_write_helper(data: pa.Table, row_group_size: int | None = None, pap yield file -def storage_config_from_use_native_downloader(use_native_downloader: bool) -> StorageConfig: - if use_native_downloader: - return StorageConfig.native(NativeStorageConfig(True, None)) - else: - return StorageConfig.python(PythonStorageConfig(None)) - - -@pytest.mark.parametrize("use_native_downloader", [True, False]) @pytest.mark.parametrize("use_deprecated_int96_timestamps", [True, False]) -def test_parquet_read_int96_timestamps(use_deprecated_int96_timestamps, use_native_downloader): +def test_parquet_read_int96_timestamps(use_deprecated_int96_timestamps): data = { "timestamp_ms": pa.array([1, 2, 3], pa.timestamp("ms")), "timestamp_us": pa.array([1, 2, 3], pa.timestamp("us")), @@ -72,13 +63,12 @@ def test_parquet_read_int96_timestamps(use_deprecated_int96_timestamps, use_nati papq_write_table_kwargs=papq_write_table_kwargs, ) as f: expected = MicroPartition.from_pydict(data) - df = daft.read_parquet(f, schema={k: v for k, v in schema}, use_native_downloader=use_native_downloader) + df = daft.read_parquet(f, schema={k: v for k, v in schema}) assert df.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{df.to_arrow()}" -@pytest.mark.parametrize("use_native_downloader", [True, False]) @pytest.mark.parametrize("coerce_to", [TimeUnit.ms(), TimeUnit.us()]) -def test_parquet_read_int96_timestamps_overflow(coerce_to, use_native_downloader): +def test_parquet_read_int96_timestamps_overflow(coerce_to): # NOTE: datetime.datetime(3000, 1, 1) and datetime.datetime(1000, 1, 1) cannot be represented by our timestamp64(nanosecond) # type. However they can be written to Parquet's INT96 type. Here we test that a round-trip is possible if provided with # the appropriate flags. 
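The comment above explains why the overflow test coerces INT96 values to a coarser unit: dates like year 3000 or 1000 cannot be represented as nanosecond timestamps but round-trip fine once coerced. A minimal sketch of that read path, using the same call the test makes and a placeholder file name:

```python
import daft
from daft.datatype import TimeUnit

# Illustrative: coerce INT96 timestamps to milliseconds on read, as exercised
# by test_parquet_read_int96_timestamps_overflow in this patch.
df = daft.read_parquet("int96_timestamps.parquet", coerce_int96_timestamp_unit=TimeUnit.ms())
```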
@@ -100,7 +90,7 @@ def test_parquet_read_int96_timestamps_overflow(coerce_to, use_native_downloader papq_write_table_kwargs=papq_write_table_kwargs, ) as f: expected = MicroPartition.from_pydict(data) - df = daft.read_parquet(f, coerce_int96_timestamp_unit=coerce_to, use_native_downloader=use_native_downloader) + df = daft.read_parquet(f, coerce_int96_timestamp_unit=coerce_to) assert df.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{df}" diff --git a/tests/sql/test_table_funcs.py b/tests/sql/test_table_funcs.py index ba9fe518fc..5b765dc0d1 100644 --- a/tests/sql/test_table_funcs.py +++ b/tests/sql/test_table_funcs.py @@ -51,7 +51,6 @@ def test_read_csv_other_options( allow_variable_columns=True, file_path_column="filepath", hive_partitioning=False, - use_native_downloader=True, ): df1 = daft.read_csv( sample_csv_path, @@ -61,9 +60,8 @@ def test_read_csv_other_options( allow_variable_columns=allow_variable_columns, file_path_column=file_path_column, hive_partitioning=hive_partitioning, - use_native_downloader=use_native_downloader, ) df2 = daft.sql( - f"SELECT * FROM read_csv('{sample_csv_path}', delimiter {op} '{delimiter}', escape_char {op} '{escape_char}', comment {op} '{comment}', allow_variable_columns {op} {str(allow_variable_columns).lower()}, file_path_column {op} '{file_path_column}', hive_partitioning {op} {str(hive_partitioning).lower()}, use_native_downloader {op} {str(use_native_downloader).lower()})" + f"SELECT * FROM read_csv('{sample_csv_path}', delimiter {op} '{delimiter}', escape_char {op} '{escape_char}', comment {op} '{comment}', allow_variable_columns {op} {str(allow_variable_columns).lower()}, file_path_column {op} '{file_path_column}', hive_partitioning {op} {str(hive_partitioning).lower()})" ).collect() assert df1.to_pydict() == df2.to_pydict() diff --git a/tests/table/table_io/test_csv.py b/tests/table/table_io/test_csv.py index fa8e96fbee..ea0118db3b 100644 --- a/tests/table/table_io/test_csv.py +++ b/tests/table/table_io/test_csv.py @@ -9,21 +9,14 @@ import pytest import daft -from daft.daft import NativeStorageConfig, PythonStorageConfig, StorageConfig +from daft.daft import CsvParseOptions from daft.datatype import DataType from daft.logical.schema import Schema from daft.runners.partitioning import TableParseCSVOptions, TableReadOptions -from daft.table import MicroPartition, schema_inference, table_io +from daft.table import MicroPartition, table_io from daft.utils import get_arrow_version -def storage_config_from_use_native_downloader(use_native_downloader: bool) -> StorageConfig: - if use_native_downloader: - return StorageConfig.native(NativeStorageConfig(True, None)) - else: - return StorageConfig.python(PythonStorageConfig(None)) - - def test_read_input(tmpdir): tmpdir = pathlib.Path(tmpdir) data = {"foo": [1, 2, 3]} @@ -39,9 +32,6 @@ def test_read_input(tmpdir): assert table_io.read_csv(tmpdir / "file.csv", schema=schema).to_pydict() == data assert table_io.read_csv(str(tmpdir / "file.csv"), schema=schema).to_pydict() == data - with open(tmpdir / "file.csv", "rb") as f: - assert table_io.read_csv(f, schema=schema).to_pydict() == data - @contextlib.contextmanager def _csv_write_helper(header: list[str] | None, data: list[list[str | None]], **kwargs): @@ -64,8 +54,7 @@ def _csv_write_helper(header: list[str] | None, data: list[list[str | None]], ** ("True", DataType.bool()), ], ) -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_csv_infer_schema(data, expected_dtype, use_native_downloader): +def 
test_csv_infer_schema(data, expected_dtype): with _csv_write_helper( header=["id", "data"], data=[ @@ -74,13 +63,11 @@ def test_csv_infer_schema(data, expected_dtype, use_native_downloader): ["3", None], ], ) as f: - storage_config = storage_config_from_use_native_downloader(use_native_downloader) - schema = schema_inference.from_csv(f, storage_config=storage_config) + schema = Schema.from_csv(f) assert schema == Schema._from_field_name_and_types([("id", DataType.int64()), ("data", expected_dtype)]) -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_csv_infer_schema_custom_delimiter(use_native_downloader): +def test_csv_infer_schema_custom_delimiter(): with _csv_write_helper( header=["id", "data"], data=[ @@ -90,15 +77,11 @@ def test_csv_infer_schema_custom_delimiter(use_native_downloader): ], delimiter="|", ) as f: - storage_config = storage_config_from_use_native_downloader(use_native_downloader) - schema = schema_inference.from_csv( - f, storage_config=storage_config, csv_options=TableParseCSVOptions(delimiter="|") - ) + schema = Schema.from_csv(f, parse_options=CsvParseOptions(delimiter="|")) assert schema == Schema._from_field_name_and_types([("id", DataType.int64()), ("data", DataType.int64())]) -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_csv_infer_schema_no_header(use_native_downloader): +def test_csv_infer_schema_no_header(): with _csv_write_helper( header=None, data=[ @@ -107,15 +90,8 @@ def test_csv_infer_schema_no_header(use_native_downloader): ["3", None], ], ) as f: - storage_config = storage_config_from_use_native_downloader(use_native_downloader) - schema = schema_inference.from_csv( - f, storage_config=storage_config, csv_options=TableParseCSVOptions(header_index=None) - ) - fields = ( - [("column_1", DataType.int64()), ("column_2", DataType.int64())] - if use_native_downloader - else [("f0", DataType.int64()), ("f1", DataType.int64())] - ) + schema = Schema.from_csv(f, parse_options=CsvParseOptions(has_header=False)) + fields = [("column_1", DataType.int64()), ("column_2", DataType.int64())] assert schema == Schema._from_field_name_and_types(fields) @@ -129,8 +105,7 @@ def test_csv_infer_schema_no_header(use_native_downloader): ("True", daft.Series.from_pylist([True, True, None])), ], ) -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_csv_read_data(data, expected_data_series, use_native_downloader): +def test_csv_read_data(data, expected_data_series): with _csv_write_helper( header=["id", "data"], data=[ @@ -139,7 +114,6 @@ def test_csv_read_data(data, expected_data_series, use_native_downloader): ["3", None], ], ) as f: - storage_config = storage_config_from_use_native_downloader(use_native_downloader) schema = Schema._from_field_name_and_types( [("id", DataType.int64()), ("data", expected_data_series.datatype())] ) @@ -149,12 +123,11 @@ def test_csv_read_data(data, expected_data_series, use_native_downloader): "data": expected_data_series, } ) - table = table_io.read_csv(f, schema, storage_config=storage_config) + table = table_io.read_csv(f, schema) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_csv_read_data_csv_limit_rows(use_native_downloader): +def test_csv_read_data_csv_limit_rows(): with _csv_write_helper( header=["id", "data"], data=[ @@ -163,8 +136,6 @@ def test_csv_read_data_csv_limit_rows(use_native_downloader): ["3", None], ], ) as f: - storage_config = 
storage_config_from_use_native_downloader(use_native_downloader) - schema = Schema._from_field_name_and_types([("id", DataType.int64()), ("data", DataType.int64())]) expected = MicroPartition.from_pydict( { @@ -175,14 +146,12 @@ def test_csv_read_data_csv_limit_rows(use_native_downloader): table = table_io.read_csv( f, schema, - storage_config=storage_config, read_options=TableReadOptions(num_rows=2), ) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_csv_read_data_csv_select_columns(use_native_downloader): +def test_csv_read_data_csv_select_columns(): with _csv_write_helper( header=["id", "data"], data=[ @@ -191,8 +160,6 @@ def test_csv_read_data_csv_select_columns(use_native_downloader): ["3", None], ], ) as f: - storage_config = storage_config_from_use_native_downloader(use_native_downloader) - schema = Schema._from_field_name_and_types([("id", DataType.int64()), ("data", DataType.int64())]) expected = MicroPartition.from_pydict( { @@ -202,14 +169,12 @@ def test_csv_read_data_csv_select_columns(use_native_downloader): table = table_io.read_csv( f, schema, - storage_config=storage_config, read_options=TableReadOptions(column_names=["data"]), ) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_csv_read_data_csv_custom_delimiter(use_native_downloader): +def test_csv_read_data_csv_custom_delimiter(): with _csv_write_helper( header=["id", "data"], data=[ @@ -219,8 +184,6 @@ def test_csv_read_data_csv_custom_delimiter(use_native_downloader): ], delimiter="|", ) as f: - storage_config = storage_config_from_use_native_downloader(use_native_downloader) - schema = Schema._from_field_name_and_types([("id", DataType.int64()), ("data", DataType.int64())]) expected = MicroPartition.from_pydict( { @@ -231,14 +194,12 @@ def test_csv_read_data_csv_custom_delimiter(use_native_downloader): table = table_io.read_csv( f, schema, - storage_config=storage_config, csv_options=TableParseCSVOptions(delimiter="|"), ) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_csv_read_data_csv_no_header(use_native_downloader): +def test_csv_read_data_csv_no_header(): with _csv_write_helper( header=None, data=[ @@ -247,8 +208,6 @@ def test_csv_read_data_csv_no_header(use_native_downloader): ["3", None], ], ) as f: - storage_config = storage_config_from_use_native_downloader(use_native_downloader) - schema = Schema._from_field_name_and_types([("id", DataType.int64()), ("data", DataType.int64())]) expected = MicroPartition.from_pydict( { @@ -259,15 +218,13 @@ def test_csv_read_data_csv_no_header(use_native_downloader): table = table_io.read_csv( f, schema, - storage_config=storage_config, csv_options=TableParseCSVOptions(header_index=None), read_options=TableReadOptions(column_names=["id", "data"]), ) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_csv_read_data_csv_custom_quote(use_native_downloader): +def test_csv_read_data_csv_custom_quote(): with _csv_write_helper( header=["'id'", "'data'"], data=[ @@ -276,8 +233,6 @@ def test_csv_read_data_csv_custom_quote(use_native_downloader): ["3", "aa"], ], ) as f: - storage_config = 
storage_config_from_use_native_downloader(use_native_downloader) - schema = Schema._from_field_name_and_types([("id", DataType.int64()), ("data", DataType.string())]) expected = MicroPartition.from_pydict( { @@ -288,16 +243,13 @@ def test_csv_read_data_csv_custom_quote(use_native_downloader): table = table_io.read_csv( f, schema, - storage_config=storage_config, csv_options=TableParseCSVOptions(quote="'"), ) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" -# TODO this test still fails with use_native_downloader = True -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_csv_read_data_custom_escape(use_native_downloader): +def test_csv_read_data_custom_escape(): with _csv_write_helper( header=["id", "data"], data=[ @@ -310,8 +262,6 @@ def test_csv_read_data_custom_escape(use_native_downloader): doublequote=False, quoting=csv.QUOTE_ALL, ) as f: - storage_config = storage_config_from_use_native_downloader(use_native_downloader) - schema = Schema._from_field_name_and_types([("id", DataType.int64()), ("data", DataType.string())]) expected = MicroPartition.from_pydict( { @@ -322,16 +272,13 @@ def test_csv_read_data_custom_escape(use_native_downloader): table = table_io.read_csv( f, schema, - storage_config=storage_config, csv_options=TableParseCSVOptions(escape_char="\\", double_quote=False), ) assert table.to_arrow() == expected.to_arrow(), f"Received:\n{table}\n\nExpected:\n{expected}" -# TODO Not testing use_native_downloader = False, as pyarrow does not support comments directly -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_csv_read_data_custom_comment(use_native_downloader): +def test_csv_read_data_custom_comment(): with tempfile.TemporaryDirectory() as directory_name: file = os.path.join(directory_name, "tempfile") with open(file, "w", newline="") as f: @@ -341,8 +288,6 @@ def test_csv_read_data_custom_comment(use_native_downloader): f.write("# comment line\n") writer.writerow(["3", "aa"]) - storage_config = storage_config_from_use_native_downloader(use_native_downloader) - schema = Schema._from_field_name_and_types([("id", DataType.int64()), ("data", DataType.string())]) expected = MicroPartition.from_pydict( { @@ -356,7 +301,6 @@ def test_csv_read_data_custom_comment(use_native_downloader): table = table_io.read_csv( file, schema, - storage_config=storage_config, csv_options=TableParseCSVOptions(comment="#"), ) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" @@ -370,8 +314,6 @@ def test_csv_read_data_variable_missing_columns(): ["2", "2"], ], ) as f: - storage_config = storage_config_from_use_native_downloader(True) - schema = Schema._from_field_name_and_types( [ ("id", DataType.int64()), @@ -387,7 +329,6 @@ def test_csv_read_data_variable_missing_columns(): table = table_io.read_csv( f, schema, - storage_config=storage_config, csv_options=TableParseCSVOptions(allow_variable_columns=True), ) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" @@ -401,8 +342,6 @@ def test_csv_read_data_variable_extra_columns(): ["2", "2", "2"], ], ) as f: - storage_config = storage_config_from_use_native_downloader(True) - schema = Schema._from_field_name_and_types( [ ("id", DataType.int64()), @@ -418,7 +357,6 @@ def test_csv_read_data_variable_extra_columns(): table = table_io.read_csv( f, schema, - storage_config=storage_config, csv_options=TableParseCSVOptions(allow_variable_columns=True), ) assert table.to_arrow() 
== expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" @@ -432,8 +370,6 @@ def test_csv_read_data_variable_columns_with_non_matching_types(): ["2", "2"], ], ) as f: - storage_config = storage_config_from_use_native_downloader(True) - schema = Schema._from_field_name_and_types( [ ("id", DataType.int64()), @@ -449,7 +385,6 @@ def test_csv_read_data_variable_columns_with_non_matching_types(): table = table_io.read_csv( f, schema, - storage_config=storage_config, csv_options=TableParseCSVOptions(allow_variable_columns=True), ) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" diff --git a/tests/table/table_io/test_json.py b/tests/table/table_io/test_json.py index 202f17ba29..08798c6a9b 100644 --- a/tests/table/table_io/test_json.py +++ b/tests/table/table_io/test_json.py @@ -10,22 +10,13 @@ import pytest import daft -from daft.daft import NativeStorageConfig, PythonStorageConfig, StorageConfig from daft.datatype import DataType from daft.logical.schema import Schema from daft.runners.partitioning import TableReadOptions -from daft.table import MicroPartition, schema_inference, table_io +from daft.table import MicroPartition, table_io -def storage_config_from_use_native_downloader(use_native_downloader: bool) -> StorageConfig: - if use_native_downloader: - return StorageConfig.native(NativeStorageConfig(True, None)) - else: - return StorageConfig.python(PythonStorageConfig(None)) - - -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_read_input(tmpdir, use_native_downloader): +def test_read_input(tmpdir): tmpdir = pathlib.Path(tmpdir) data = {"foo": [1, 2, 3]} with open(tmpdir / "file.json", "w") as f: @@ -34,18 +25,10 @@ def test_read_input(tmpdir, use_native_downloader): f.write("\n") schema = Schema._from_field_name_and_types([("foo", DataType.int64())]) - storage_config = storage_config_from_use_native_downloader(use_native_downloader) # Test pathlib, str and IO - assert table_io.read_json(tmpdir / "file.json", schema=schema, storage_config=storage_config).to_pydict() == data - assert ( - table_io.read_json(str(tmpdir / "file.json"), schema=schema, storage_config=storage_config).to_pydict() == data - ) - - with open(tmpdir / "file.json", "rb") as f: - if use_native_downloader: - f = tmpdir / "file.json" - assert table_io.read_json(f, schema=schema, storage_config=storage_config).to_pydict() == data + assert table_io.read_json(tmpdir / "file.json", schema=schema).to_pydict() == data + assert table_io.read_json(str(tmpdir / "file.json"), schema=schema).to_pydict() == data @contextlib.contextmanager @@ -78,16 +61,14 @@ def _json_write_helper(data: dict[str, list[Any]]): ([1, None, 2], DataType.list(DataType.int64())), ], ) -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_json_infer_schema(data, expected_dtype, use_native_downloader): +def test_json_infer_schema(data, expected_dtype): with _json_write_helper( { "id": [1, 2, 3], "data": [data, data, None], } ) as f: - storage_config = storage_config_from_use_native_downloader(use_native_downloader) - schema = schema_inference.from_json(f, storage_config=storage_config) + schema = Schema.from_json(f) assert schema == Schema._from_field_name_and_types([("id", DataType.int64()), ("data", expected_dtype)]) @@ -102,8 +83,7 @@ def test_json_infer_schema(data, expected_dtype, use_native_downloader): ({"foo": 1}, daft.Series.from_pylist([{"foo": 1}, {"foo": 1}, None])), ], ) -@pytest.mark.parametrize("use_native_downloader", [True, False]) 
-def test_json_read_data(data, expected_data_series, use_native_downloader): +def test_json_read_data(data, expected_data_series): with _json_write_helper( { "id": [1, 2, 3], @@ -119,13 +99,11 @@ def test_json_read_data(data, expected_data_series, use_native_downloader): "data": expected_data_series, } ) - storage_config = storage_config_from_use_native_downloader(use_native_downloader) - table = table_io.read_json(f, schema, storage_config=storage_config) + table = table_io.read_json(f, schema) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_json_read_data_limit_rows(use_native_downloader): +def test_json_read_data_limit_rows(): with _json_write_helper( { "id": [1, 2, 3], @@ -139,13 +117,11 @@ def test_json_read_data_limit_rows(use_native_downloader): "data": [1, 2], } ) - storage_config = storage_config_from_use_native_downloader(use_native_downloader) - table = table_io.read_json(f, schema, read_options=TableReadOptions(num_rows=2), storage_config=storage_config) + table = table_io.read_json(f, schema, read_options=TableReadOptions(num_rows=2)) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_json_read_data_select_columns(use_native_downloader): +def test_json_read_data_select_columns(): with _json_write_helper( { "id": [1, 2, 3], @@ -158,8 +134,5 @@ def test_json_read_data_select_columns(use_native_downloader): "data": [1, 2, None], } ) - storage_config = storage_config_from_use_native_downloader(use_native_downloader) - table = table_io.read_json( - f, schema, read_options=TableReadOptions(column_names=["data"]), storage_config=storage_config - ) + table = table_io.read_json(f, schema, read_options=TableReadOptions(column_names=["data"])) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" diff --git a/tests/table/table_io/test_parquet.py b/tests/table/table_io/test_parquet.py index d350d016ca..7dadf8deb0 100644 --- a/tests/table/table_io/test_parquet.py +++ b/tests/table/table_io/test_parquet.py @@ -12,7 +12,6 @@ import pytest import daft -from daft.daft import NativeStorageConfig, PythonStorageConfig, StorageConfig from daft.datatype import DataType, TimeUnit from daft.exceptions import DaftCoreException from daft.logical.schema import Schema @@ -21,7 +20,6 @@ MicroPartition, read_parquet_into_pyarrow, read_parquet_into_pyarrow_bulk, - schema_inference, table_io, ) @@ -29,13 +27,6 @@ PYARROW_GE_13_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (13, 0, 0) -def storage_config_from_use_native_downloader(use_native_downloader: bool) -> StorageConfig: - if use_native_downloader: - return StorageConfig.native(NativeStorageConfig(True, None)) - else: - return StorageConfig.python(PythonStorageConfig(None)) - - def test_read_input(tmpdir): tmpdir = pathlib.Path(tmpdir) data = pa.Table.from_pydict({"foo": [1, 2, 3]}) @@ -48,9 +39,6 @@ def test_read_input(tmpdir): assert table_io.read_parquet(tmpdir / "file.parquet", schema=schema).to_arrow() == data assert table_io.read_parquet(str(tmpdir / "file.parquet"), schema=schema).to_arrow() == data - with open(tmpdir / "file.parquet", "rb") as f: - assert table_io.read_parquet(f, schema=schema).to_arrow() == data - @contextlib.contextmanager def _parquet_write_helper(data: pa.Table, row_group_size: int | None = None, 
papq_write_table_kwargs: dict = {}): @@ -76,14 +64,7 @@ def _parquet_write_helper(data: pa.Table, row_group_size: int | None = None, pap ([1, None, 2], DataType.list(DataType.int64())), ], ) -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_parquet_infer_schema(data, expected_dtype, use_native_downloader): - # HACK: Pyarrow 13 changed their schema parsing behavior so we receive DataType.list(..) instead of DataType.list(..) - # However, our native downloader still parses DataType.list(..) regardless of PyArrow version - if PYARROW_GE_13_0_0 and not use_native_downloader and expected_dtype == DataType.list(DataType.int64()): - expected_dtype = DataType.list(DataType.int64()) - storage_config = storage_config_from_use_native_downloader(use_native_downloader) - +def test_parquet_infer_schema(data, expected_dtype): with _parquet_write_helper( pa.Table.from_pydict( { @@ -92,17 +73,15 @@ def test_parquet_infer_schema(data, expected_dtype, use_native_downloader): } ) ) as f: - schema = schema_inference.from_parquet(f, storage_config=storage_config) + schema = Schema.from_parquet(f) assert schema == Schema._from_field_name_and_types([("id", DataType.int64()), ("data", expected_dtype)]) -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_parquet_read_empty(use_native_downloader): +def test_parquet_read_empty(): with _parquet_write_helper(pa.Table.from_pydict({"foo": pa.array([], type=pa.int64())})) as f: schema = Schema._from_field_name_and_types([("foo", DataType.int64())]) expected = MicroPartition.from_pydict({"foo": pa.array([], type=pa.int64())}) - storage_config = storage_config_from_use_native_downloader(use_native_downloader) - table = table_io.read_parquet(f, schema, storage_config=storage_config) + table = table_io.read_parquet(f, schema) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" @@ -123,8 +102,7 @@ def test_parquet_read_empty(use_native_downloader): ({"foo": 1}, daft.Series.from_pylist([{"foo": 1}, {"foo": 1}, None])), ], ) -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_parquet_read_data(data, expected_data_series, use_native_downloader): +def test_parquet_read_data(data, expected_data_series): with _parquet_write_helper( pa.Table.from_pydict( { @@ -142,14 +120,12 @@ def test_parquet_read_data(data, expected_data_series, use_native_downloader): "data": expected_data_series, } ) - storage_config = storage_config_from_use_native_downloader(use_native_downloader) - table = table_io.read_parquet(f, schema, storage_config=storage_config) + table = table_io.read_parquet(f, schema) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" @pytest.mark.parametrize("row_group_size", [None, 1, 3]) -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_parquet_read_data_limit_rows(row_group_size, use_native_downloader): +def test_parquet_read_data_limit_rows(row_group_size): with _parquet_write_helper( pa.Table.from_pydict( { @@ -166,10 +142,7 @@ def test_parquet_read_data_limit_rows(row_group_size, use_native_downloader): "data": [1, 2], } ) - storage_config = storage_config_from_use_native_downloader(use_native_downloader) - table = table_io.read_parquet( - f, schema, read_options=TableReadOptions(num_rows=2), storage_config=storage_config - ) + table = table_io.read_parquet(f, schema, read_options=TableReadOptions(num_rows=2)) assert table.to_arrow() == expected.to_arrow(), 
f"Expected:\n{expected}\n\nReceived:\n{table}" @@ -180,8 +153,7 @@ def test_parquet_read_data_multi_row_groups(): assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" -@pytest.mark.parametrize("use_native_downloader", [True, False]) -def test_parquet_read_data_select_columns(use_native_downloader): +def test_parquet_read_data_select_columns(): with _parquet_write_helper( pa.Table.from_pydict( { @@ -196,10 +168,7 @@ def test_parquet_read_data_select_columns(use_native_downloader): "data": [1, 2, None], } ) - storage_config = storage_config_from_use_native_downloader(use_native_downloader) - table = table_io.read_parquet( - f, schema, read_options=TableReadOptions(column_names=["data"]), storage_config=storage_config - ) + table = table_io.read_parquet(f, schema, read_options=TableReadOptions(column_names=["data"])) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" @@ -208,9 +177,8 @@ def test_parquet_read_data_select_columns(use_native_downloader): ### -@pytest.mark.parametrize("use_native_downloader", [True, False]) @pytest.mark.parametrize("use_deprecated_int96_timestamps", [True, False]) -def test_parquet_read_int96_timestamps(use_deprecated_int96_timestamps, use_native_downloader): +def test_parquet_read_int96_timestamps(use_deprecated_int96_timestamps): data = { "timestamp_ms": pa.array([1, 2, 3], pa.timestamp("ms")), "timestamp_us": pa.array([1, 2, 3], pa.timestamp("us")), @@ -237,19 +205,16 @@ def test_parquet_read_int96_timestamps(use_deprecated_int96_timestamps, use_nati ) as f: schema = Schema._from_field_name_and_types(schema) expected = MicroPartition.from_pydict(data) - storage_config = storage_config_from_use_native_downloader(use_native_downloader) table = table_io.read_parquet( f, schema, read_options=TableReadOptions(column_names=schema.column_names()), - storage_config=storage_config, ) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" -@pytest.mark.parametrize("use_native_downloader", [True, False]) @pytest.mark.parametrize("coerce_to", [TimeUnit.ms(), TimeUnit.us()]) -def test_parquet_read_int96_timestamps_overflow(coerce_to, use_native_downloader): +def test_parquet_read_int96_timestamps_overflow(coerce_to): # NOTE: datetime.datetime(3000, 1, 1) and datetime.datetime(1000, 1, 1) cannot be represented by our timestamp64(nanosecond) # type. However they can be written to Parquet's INT96 type. Here we test that a round-trip is possible if provided with # the appropriate flags. 
@@ -275,13 +240,11 @@ def test_parquet_read_int96_timestamps_overflow(coerce_to, use_native_downloader ) as f: schema = Schema._from_field_name_and_types(schema) expected = MicroPartition.from_pydict(data) - storage_config = storage_config_from_use_native_downloader(use_native_downloader) table = table_io.read_parquet( f, schema, read_options=TableReadOptions(column_names=schema.column_names()), parquet_options=TableParseParquetOptions(coerce_int96_timestamp_unit=coerce_to), - storage_config=storage_config, ) assert table.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{table}" diff --git a/tests/table/table_io/test_read_time_cast.py b/tests/table/table_io/test_read_time_cast.py index 65f00305dd..5d59806a42 100644 --- a/tests/table/table_io/test_read_time_cast.py +++ b/tests/table/table_io/test_read_time_cast.py @@ -28,9 +28,9 @@ ), # Test reordering of columns ( - pa.Table.from_pydict({"foo": pa.array([1, 2, 3]), "bar": pa.array([1, 2, 3])}), + pa.Table.from_pydict({"foo": pa.array([1, 2, 3]), "bar": pa.array([4, 5, 6])}), Schema._from_field_name_and_types([("bar", DataType.int64()), ("foo", DataType.int64())]), - MicroPartition.from_pydict({"bar": pa.array([1, 2, 3]), "foo": pa.array([1, 2, 3])}), + MicroPartition.from_pydict({"bar": pa.array([4, 5, 6]), "foo": pa.array([1, 2, 3])}), ), # Test automatic insertion of null values for missing column ( diff --git a/tests/udf_library/test_url_udfs.py b/tests/udf_library/test_url_udfs.py index ff4dd11a82..66ddc6c4fa 100644 --- a/tests/udf_library/test_url_udfs.py +++ b/tests/udf_library/test_url_udfs.py @@ -9,7 +9,7 @@ import daft from daft.expressions import col -from tests.conftest import assert_df_equals, get_tests_daft_runner_name +from tests.conftest import assert_df_equals def _get_filename(): @@ -30,31 +30,28 @@ def files(tmpdir) -> list[str]: return filepaths -@pytest.mark.parametrize("use_native_downloader", [False, True]) -def test_download(files, use_native_downloader): +def test_download(files): # Run it twice to ensure runtime works for _ in range(2): df = daft.from_pydict({"filenames": [str(f) for f in files]}) - df = df.with_column("bytes", col("filenames").url.download(use_native_downloader=use_native_downloader)) + df = df.with_column("bytes", col("filenames").url.download()) pd_df = pd.DataFrame.from_dict({"filenames": [str(f) for f in files]}) pd_df["bytes"] = pd.Series([pathlib.Path(fn).read_bytes() for fn in files]) assert_df_equals(df.to_pandas(), pd_df, sort_key="filenames") -@pytest.mark.parametrize("use_native_downloader", [False, True]) -def test_download_with_none(files, use_native_downloader): +def test_download_with_none(files): data = {"id": list(range(len(files) * 2)), "filenames": [str(f) for f in files] + [None for _ in range(len(files))]} # Run it twice to ensure runtime works for _ in range(2): df = daft.from_pydict(data) - df = df.with_column("bytes", col("filenames").url.download(use_native_downloader=use_native_downloader)) + df = df.with_column("bytes", col("filenames").url.download()) pd_df = pd.DataFrame.from_dict(data) pd_df["bytes"] = pd.Series([pathlib.Path(fn).read_bytes() if fn is not None else None for fn in files]) assert_df_equals(df.to_pandas(), pd_df, sort_key="id") -@pytest.mark.parametrize("use_native_downloader", [False, True]) -def test_download_with_missing_urls(files, use_native_downloader): +def test_download_with_missing_urls(files): data = { "id": list(range(len(files) * 2)), "filenames": [str(f) for f in files] + [str(uuid.uuid4()) for _ in 
range(len(files))], @@ -63,7 +60,10 @@ def test_download_with_missing_urls(files, use_native_downloader): for _ in range(2): df = daft.from_pydict(data) df = df.with_column( - "bytes", col("filenames").url.download(on_error="null", use_native_downloader=use_native_downloader) + "bytes", + col("filenames").url.download( + on_error="null", + ), ) pd_df = pd.DataFrame.from_dict(data) pd_df["bytes"] = pd.Series( @@ -72,8 +72,7 @@ def test_download_with_missing_urls(files, use_native_downloader): assert_df_equals(df.to_pandas(), pd_df, sort_key="id") -@pytest.mark.parametrize("use_native_downloader", [False, True]) -def test_download_with_missing_urls_reraise_errors(files, use_native_downloader): +def test_download_with_missing_urls_reraise_errors(files): data = { "id": list(range(len(files) * 2)), "filenames": [str(f) for f in files] + [str(uuid.uuid4()) for _ in range(len(files))], @@ -82,26 +81,18 @@ def test_download_with_missing_urls_reraise_errors(files, use_native_downloader) for _ in range(2): df = daft.from_pydict(data) df = df.with_column( - "bytes", col("filenames").url.download(on_error="raise", use_native_downloader=use_native_downloader) + "bytes", + col("filenames").url.download( + on_error="raise", + ), ) # TODO: Change to a FileNotFound Error - if not use_native_downloader: - with pytest.raises(RuntimeError) as exc_info: - df.collect() + with pytest.raises(FileNotFoundError): + df.collect() - # Ray's wrapping of the exception loses information about the `.cause`, but preserves it in the string error message - if get_tests_daft_runner_name() == "ray": - assert "FileNotFoundError" in str(exc_info.value) - else: - assert isinstance(exc_info.value.__cause__, FileNotFoundError) - else: - with pytest.raises(FileNotFoundError): - df.collect() - -@pytest.mark.parametrize("use_native_downloader", [False, True]) -def test_download_with_duplicate_urls(files, use_native_downloader): +def test_download_with_duplicate_urls(files): data = { "id": list(range(len(files) * 2)), "filenames": [str(f) for f in files] * 2, @@ -109,7 +100,7 @@ def test_download_with_duplicate_urls(files, use_native_downloader): # Run it twice to ensure runtime works for _ in range(2): df = daft.from_pydict(data) - df = df.with_column("bytes", col("filenames").url.download(use_native_downloader=use_native_downloader)) + df = df.with_column("bytes", col("filenames").url.download()) pd_df = pd.DataFrame.from_dict(data) pd_df["bytes"] = pd.Series( [pathlib.Path(fn).read_bytes() if pathlib.Path(fn).exists() else None for fn in files * 2] From ca4d3f794a376616557110ab12f6a115f17735f9 Mon Sep 17 00:00:00 2001 From: Cory Grinstead Date: Wed, 18 Dec 2024 16:35:18 -0600 Subject: [PATCH 14/14] feat(connect): df.show (#3560) depends on https://github.com/Eventual-Inc/Daft/pull/3554 [see here for proper diff](https://github.com/universalmind303/Daft/compare/refactor-lp-3...universalmind303:Daft:connect_show?expand=1) --- src/daft-connect/src/session.rs | 1 - .../src/translation/logical_plan.rs | 105 +++++++++++++++++- src/daft-dsl/src/lit.rs | 6 + tests/connect/test_show.py | 9 ++ 4 files changed, 116 insertions(+), 5 deletions(-) create mode 100644 tests/connect/test_show.py diff --git a/src/daft-connect/src/session.rs b/src/daft-connect/src/session.rs index 30f827ba9e..7de8d5851b 100644 --- a/src/daft-connect/src/session.rs +++ b/src/daft-connect/src/session.rs @@ -28,7 +28,6 @@ impl Session { pub fn new(id: String) -> Self { let server_side_session_id = Uuid::new_v4(); let server_side_session_id = 
server_side_session_id.to_string(); - Self { config_values: Default::default(), id, diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs index 15eb495502..439f5bd551 100644 --- a/src/daft-connect/src/translation/logical_plan.rs +++ b/src/daft-connect/src/translation/logical_plan.rs @@ -1,7 +1,21 @@ +use std::sync::Arc; + +use common_daft_config::DaftExecutionConfig; +use daft_core::prelude::Schema; +use daft_dsl::LiteralValue; +use daft_local_execution::NativeExecutor; use daft_logical_plan::LogicalPlanBuilder; -use daft_micropartition::partitioning::InMemoryPartitionSetCache; +use daft_micropartition::{ + partitioning::{ + InMemoryPartitionSetCache, MicroPartitionSet, PartitionCacheEntry, PartitionMetadata, + PartitionSet, PartitionSetCache, + }, + MicroPartition, +}; +use daft_table::Table; use eyre::{bail, Context}; -use spark_connect::{relation::RelType, Limit, Relation}; +use futures::TryStreamExt; +use spark_connect::{relation::RelType, Limit, Relation, ShowString}; use tracing::warn; mod aggregate; @@ -22,6 +36,35 @@ impl SparkAnalyzer<'_> { pub fn new(pset: &InMemoryPartitionSetCache) -> SparkAnalyzer { SparkAnalyzer { psets: pset } } + pub fn create_in_memory_scan( + &self, + plan_id: usize, + schema: Arc<Schema>, + tables: Vec<Table>,
+ ) -> eyre::Result<LogicalPlanBuilder> { + let partition_key = uuid::Uuid::new_v4().to_string(); + + let pset = Arc::new(MicroPartitionSet::from_tables(plan_id, tables)?); + + let PartitionMetadata { + num_rows, + size_bytes, + } = pset.metadata(); + let num_partitions = pset.num_partitions(); + + self.psets.put_partition_set(&partition_key, &pset); + + let cache_entry = PartitionCacheEntry::new_rust(partition_key.clone(), pset); + + Ok(LogicalPlanBuilder::in_memory_scan( + &partition_key, + cache_entry, + schema, + num_partitions, + size_bytes, + num_rows, + )?) + } pub async fn to_logical_plan(&self, relation: Relation) -> eyre::Result<LogicalPlanBuilder> { let Some(common) = relation.common else { @@ -78,12 +121,18 @@ impl SparkAnalyzer<'_> { .filter(*f) .await .wrap_err("Failed to apply filter to logical plan"), + RelType::ShowString(ss) => { + let Some(plan_id) = common.plan_id else { + bail!("Plan ID is required for LocalRelation"); + }; + self.show_string(plan_id, *ss) + .await + .wrap_err("Failed to show string") + } plan => bail!("Unsupported relation type: {plan:?}"), } } -} -impl SparkAnalyzer<'_> { async fn limit(&self, limit: Limit) -> eyre::Result<LogicalPlanBuilder> { let Limit { input, limit } = limit; @@ -96,4 +145,52 @@ impl SparkAnalyzer<'_> { plan.limit(i64::from(limit), false) .wrap_err("Failed to apply limit to logical plan") } + + /// right now this just naively applies a limit to the logical plan + /// In the future, we want this to more closely match our daft implementation + async fn show_string( + &self, + plan_id: i64, + show_string: ShowString, + ) -> eyre::Result<LogicalPlanBuilder> { + let ShowString { + input, + num_rows, + truncate: _, + vertical, + } = show_string; + + if vertical { + bail!("Vertical show string is not supported"); + } + + let Some(input) = input else { + bail!("input must be set"); + }; + + let plan = Box::pin(self.to_logical_plan(*input)).await?; + let plan = plan.limit(num_rows as i64, true)?; + + let optimized_plan = tokio::task::spawn_blocking(move || plan.optimize()) + .await + .unwrap()?; + + let cfg = Arc::new(DaftExecutionConfig::default()); + let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?; + let result_stream = native_executor.run(self.psets, cfg, None)?.into_stream(); + let batch = result_stream.try_collect::<Vec<_>>().await?; + let single_batch = MicroPartition::concat(batch)?; + let tbls = single_batch.get_tables()?; + let tbl = Table::concat(&tbls)?; + let output = tbl.to_comfy_table(None).to_string(); + + let s = LiteralValue::Utf8(output) + .into_single_value_series()? + .rename("show_string"); + + let tbl = Table::from_nonempty_columns(vec![s])?; + let schema = tbl.schema.clone(); + + self.create_in_memory_scan(plan_id as _, schema, vec![tbl]) + } } diff --git a/src/daft-dsl/src/lit.rs b/src/daft-dsl/src/lit.rs index 1d86442aef..c1c7ce81c3 100644 --- a/src/daft-dsl/src/lit.rs +++ b/src/daft-dsl/src/lit.rs @@ -444,6 +444,12 @@ pub fn null_lit() -> ExprRef { Arc::new(Expr::Literal(LiteralValue::Null)) } +impl LiteralValue { + pub fn into_single_value_series(self) -> DaftResult<Series> { + literals_to_series(&[self]) + } +} + /// Convert a slice of literals to a series.
/// This function will return an error if the literals are not all the same type pub fn literals_to_series(values: &[LiteralValue]) -> DaftResult<Series> { diff --git a/tests/connect/test_show.py b/tests/connect/test_show.py new file mode 100644 index 0000000000..a463d5de72 --- /dev/null +++ b/tests/connect/test_show.py @@ -0,0 +1,9 @@ +from __future__ import annotations + + +def test_show(spark_session): + df = spark_session.range(10) + try: + df.show() + except Exception as e: + assert False, e