From d1b06fb2b7dc0f47b0715bd6087701a56cac5f23 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Tue, 22 Oct 2024 18:57:06 -0700 Subject: [PATCH] [FEAT] Enable explode for swordfish (#3077) Adds explode as an intermediate operator. Unskips all the explode tests --------- Co-authored-by: Colin Ho --- Cargo.lock | 1 + src/daft-local-execution/Cargo.toml | 1 + .../src/intermediate_ops/explode.rs | 42 +++++++++++++++++++ .../src/intermediate_ops/mod.rs | 1 + src/daft-local-execution/src/pipeline.rs | 18 +++++--- src/daft-physical-plan/src/lib.rs | 6 +-- src/daft-physical-plan/src/local_plan.rs | 25 ++++++++++- src/daft-physical-plan/src/translate.rs | 8 ++++ tests/dataframe/test_explode.py | 6 --- tests/dataframe/test_wildcard.py | 7 +--- tests/sql/test_list_exprs.py | 5 +-- 11 files changed, 95 insertions(+), 25 deletions(-) create mode 100644 src/daft-local-execution/src/intermediate_ops/explode.rs diff --git a/Cargo.lock b/Cargo.lock index c57309e10f..3fcc294fef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1960,6 +1960,7 @@ dependencies = [ "daft-core", "daft-csv", "daft-dsl", + "daft-functions", "daft-io", "daft-json", "daft-micropartition", diff --git a/src/daft-local-execution/Cargo.toml b/src/daft-local-execution/Cargo.toml index cd061c1c35..932462215d 100644 --- a/src/daft-local-execution/Cargo.toml +++ b/src/daft-local-execution/Cargo.toml @@ -7,6 +7,7 @@ common-tracing = {path = "../common/tracing", default-features = false} daft-core = {path = "../daft-core", default-features = false} daft-csv = {path = "../daft-csv", default-features = false} daft-dsl = {path = "../daft-dsl", default-features = false} +daft-functions = {path = "../daft-functions", default-features = false} daft-io = {path = "../daft-io", default-features = false} daft-json = {path = "../daft-json", default-features = false} daft-micropartition = {path = "../daft-micropartition", default-features = false} diff --git a/src/daft-local-execution/src/intermediate_ops/explode.rs b/src/daft-local-execution/src/intermediate_ops/explode.rs new file mode 100644 index 0000000000..774be696a8 --- /dev/null +++ b/src/daft-local-execution/src/intermediate_ops/explode.rs @@ -0,0 +1,42 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_dsl::ExprRef; +use daft_functions::list::explode; +use tracing::instrument; + +use super::intermediate_op::{ + IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, +}; +use crate::pipeline::PipelineResultType; + +pub struct ExplodeOperator { + to_explode: Vec, +} + +impl ExplodeOperator { + pub fn new(to_explode: Vec) -> Self { + Self { + to_explode: to_explode.into_iter().map(explode).collect(), + } + } +} + +impl IntermediateOperator for ExplodeOperator { + #[instrument(skip_all, name = "ExplodeOperator::execute")] + fn execute( + &self, + _idx: usize, + input: &PipelineResultType, + _state: Option<&mut Box>, + ) -> DaftResult { + let out = input.as_data().explode(&self.to_explode)?; + Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( + out, + )))) + } + + fn name(&self) -> &'static str { + "ExplodeOperator" + } +} diff --git a/src/daft-local-execution/src/intermediate_ops/mod.rs b/src/daft-local-execution/src/intermediate_ops/mod.rs index 098bbcbfbe..7d97464e24 100644 --- a/src/daft-local-execution/src/intermediate_ops/mod.rs +++ b/src/daft-local-execution/src/intermediate_ops/mod.rs @@ -1,6 +1,7 @@ pub mod aggregate; pub mod anti_semi_hash_join_probe; pub mod buffer; +pub mod explode; pub mod filter; pub mod inner_hash_join_probe; pub mod intermediate_op; diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index f29135afb5..eccece1a56 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -10,8 +10,8 @@ use daft_core::{ use daft_dsl::{col, join::get_common_join_keys, Expr}; use daft_micropartition::MicroPartition; use daft_physical_plan::{ - Concat, EmptyScan, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan, - Pivot, Project, Sample, Sort, UnGroupedAggregate, Unpivot, + Concat, EmptyScan, Explode, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, + LocalPhysicalPlan, Pivot, Project, Sample, Sort, UnGroupedAggregate, Unpivot, }; use daft_plan::{populate_aggregation_stages, JoinType}; use daft_table::ProbeState; @@ -22,9 +22,10 @@ use crate::{ channel::PipelineChannel, intermediate_ops::{ aggregate::AggregateOperator, anti_semi_hash_join_probe::AntiSemiProbeOperator, - filter::FilterOperator, inner_hash_join_probe::InnerHashJoinProbeOperator, - intermediate_op::IntermediateNode, pivot::PivotOperator, project::ProjectOperator, - sample::SampleOperator, unpivot::UnpivotOperator, + explode::ExplodeOperator, filter::FilterOperator, + inner_hash_join_probe::InnerHashJoinProbeOperator, intermediate_op::IntermediateNode, + pivot::PivotOperator, project::ProjectOperator, sample::SampleOperator, + unpivot::UnpivotOperator, }, sinks::{ aggregate::AggregateSink, blocking_sink::BlockingSinkNode, concat::ConcatSink, @@ -145,6 +146,13 @@ pub fn physical_plan_to_pipeline( let child_node = physical_plan_to_pipeline(input, psets)?; IntermediateNode::new(Arc::new(filter_op), vec![child_node]).boxed() } + LocalPhysicalPlan::Explode(Explode { + input, to_explode, .. + }) => { + let explode_op = ExplodeOperator::new(to_explode.clone()); + let child_node = physical_plan_to_pipeline(input, psets)?; + IntermediateNode::new(Arc::new(explode_op), vec![child_node]).boxed() + } LocalPhysicalPlan::Limit(Limit { input, num_rows, .. }) => { diff --git a/src/daft-physical-plan/src/lib.rs b/src/daft-physical-plan/src/lib.rs index 75aa616394..ba20720855 100644 --- a/src/daft-physical-plan/src/lib.rs +++ b/src/daft-physical-plan/src/lib.rs @@ -3,8 +3,8 @@ mod local_plan; mod translate; pub use local_plan::{ - Concat, EmptyScan, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan, - LocalPhysicalPlanRef, PhysicalScan, PhysicalWrite, Pivot, Project, Sample, Sort, - UnGroupedAggregate, Unpivot, + Concat, EmptyScan, Explode, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, + LocalPhysicalPlan, LocalPhysicalPlanRef, PhysicalScan, PhysicalWrite, Pivot, Project, Sample, + Sort, UnGroupedAggregate, Unpivot, }; pub use translate::translate; diff --git a/src/daft-physical-plan/src/local_plan.rs b/src/daft-physical-plan/src/local_plan.rs index 4ba861798e..94672c2463 100644 --- a/src/daft-physical-plan/src/local_plan.rs +++ b/src/daft-physical-plan/src/local_plan.rs @@ -15,7 +15,7 @@ pub enum LocalPhysicalPlan { Project(Project), Filter(Filter), Limit(Limit), - // Explode(Explode), + Explode(Explode), Unpivot(Unpivot), Sort(Sort), // Split(Split), @@ -107,6 +107,20 @@ impl LocalPhysicalPlan { .arced() } + pub(crate) fn explode( + input: LocalPhysicalPlanRef, + to_explode: Vec, + schema: SchemaRef, + ) -> LocalPhysicalPlanRef { + Self::Explode(Explode { + input, + to_explode, + schema, + plan_stats: PlanStats {}, + }) + .arced() + } + pub(crate) fn project( input: LocalPhysicalPlanRef, projection: Vec, @@ -272,6 +286,7 @@ impl LocalPhysicalPlan { | Self::Sort(Sort { schema, .. }) | Self::Sample(Sample { schema, .. }) | Self::HashJoin(HashJoin { schema, .. }) + | Self::Explode(Explode { schema, .. }) | Self::Unpivot(Unpivot { schema, .. }) | Self::Concat(Concat { schema, .. }) => schema, Self::InMemoryScan(InMemoryScan { info, .. }) => &info.source_schema, @@ -323,6 +338,14 @@ pub struct Limit { pub plan_stats: PlanStats, } +#[derive(Debug)] +pub struct Explode { + pub input: LocalPhysicalPlanRef, + pub to_explode: Vec, + pub schema: SchemaRef, + pub plan_stats: PlanStats, +} + #[derive(Debug)] pub struct Sort { pub input: LocalPhysicalPlanRef, diff --git a/src/daft-physical-plan/src/translate.rs b/src/daft-physical-plan/src/translate.rs index 726b3232d5..7dcb0f552b 100644 --- a/src/daft-physical-plan/src/translate.rs +++ b/src/daft-physical-plan/src/translate.rs @@ -158,6 +158,14 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { log::warn!("Repartition Not supported for Local Executor!; This will be a No-Op"); translate(&repartition.input) } + LogicalPlan::Explode(explode) => { + let input = translate(&explode.input)?; + Ok(LocalPhysicalPlan::explode( + input, + explode.to_explode.clone(), + explode.exploded_schema.clone(), + )) + } _ => todo!("{} not yet implemented", plan.name()), } } diff --git a/tests/dataframe/test_explode.py b/tests/dataframe/test_explode.py index 0e8dbd73d2..26416f9938 100644 --- a/tests/dataframe/test_explode.py +++ b/tests/dataframe/test_explode.py @@ -3,14 +3,8 @@ import pyarrow as pa import pytest -from daft import context from daft.expressions import col -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) - @pytest.mark.parametrize( "data", diff --git a/tests/dataframe/test_wildcard.py b/tests/dataframe/test_wildcard.py index 3497912be5..e732292c53 100644 --- a/tests/dataframe/test_wildcard.py +++ b/tests/dataframe/test_wildcard.py @@ -1,14 +1,9 @@ import pytest import daft -from daft import col, context +from daft import col from daft.exceptions import DaftCoreException -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) - def test_wildcard_select(): df = daft.from_pydict( diff --git a/tests/sql/test_list_exprs.py b/tests/sql/test_list_exprs.py index 9b76735e44..2f0799fb71 100644 --- a/tests/sql/test_list_exprs.py +++ b/tests/sql/test_list_exprs.py @@ -1,8 +1,7 @@ import pyarrow as pa -import pytest import daft -from daft import col, context +from daft import col from daft.daft import CountMode from daft.sql.sql import SQLCatalog @@ -62,8 +61,6 @@ def test_list_counts(): def test_list_explode(): - if context.get_context().daft_execution_config.enable_native_executor is True: - pytest.skip("Native executor fails for these tests") df = daft.from_pydict({"col": [[1, 2, 3], [1, 2], [1, None, 4], []]}) catalog = SQLCatalog({"test": df}) expected = df.explode(col("col"))