Skip to content

Commit

Permalink
[FEAT] Enable explode for swordfish (#3077)
Browse files Browse the repository at this point in the history
Adds explode as an intermediate operator. Unskips all the explode tests

---------

Co-authored-by: Colin Ho <[email protected]>
  • Loading branch information
colin-ho and Colin Ho authored Oct 23, 2024
1 parent 0727dc1 commit d1b06fb
Show file tree
Hide file tree
Showing 11 changed files with 95 additions and 25 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/daft-local-execution/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ common-tracing = {path = "../common/tracing", default-features = false}
daft-core = {path = "../daft-core", default-features = false}
daft-csv = {path = "../daft-csv", default-features = false}
daft-dsl = {path = "../daft-dsl", default-features = false}
daft-functions = {path = "../daft-functions", default-features = false}
daft-io = {path = "../daft-io", default-features = false}
daft-json = {path = "../daft-json", default-features = false}
daft-micropartition = {path = "../daft-micropartition", default-features = false}
Expand Down
42 changes: 42 additions & 0 deletions src/daft-local-execution/src/intermediate_ops/explode.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
use std::sync::Arc;

use common_error::DaftResult;
use daft_dsl::ExprRef;
use daft_functions::list::explode;
use tracing::instrument;

use super::intermediate_op::{
IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState,
};
use crate::pipeline::PipelineResultType;

/// Intermediate pipeline operator that explodes list columns into one row per element.
pub struct ExplodeOperator {
    /// Expressions to explode. Note: `new` wraps each incoming expression in an
    /// `explode(...)` call, so by the time they are stored here they are already
    /// explode expressions ready to be evaluated in `execute`.
    to_explode: Vec<ExprRef>,
}

impl ExplodeOperator {
pub fn new(to_explode: Vec<ExprRef>) -> Self {
Self {
to_explode: to_explode.into_iter().map(explode).collect(),
}
}
}

impl IntermediateOperator for ExplodeOperator {
    /// Explodes the input micropartition using the pre-wrapped explode expressions
    /// and reports that the operator is ready for more input.
    #[instrument(skip_all, name = "ExplodeOperator::execute")]
    fn execute(
        &self,
        _idx: usize,
        input: &PipelineResultType,
        _state: Option<&mut Box<dyn IntermediateOperatorState>>,
    ) -> DaftResult<IntermediateOperatorResult> {
        let exploded = input.as_data().explode(&self.to_explode)?;
        let result = IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(exploded)));
        Ok(result)
    }

    fn name(&self) -> &'static str {
        "ExplodeOperator"
    }
}
1 change: 1 addition & 0 deletions src/daft-local-execution/src/intermediate_ops/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod aggregate;
pub mod anti_semi_hash_join_probe;
pub mod buffer;
pub mod explode;
pub mod filter;
pub mod inner_hash_join_probe;
pub mod intermediate_op;
Expand Down
18 changes: 13 additions & 5 deletions src/daft-local-execution/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use daft_core::{
use daft_dsl::{col, join::get_common_join_keys, Expr};
use daft_micropartition::MicroPartition;
use daft_physical_plan::{
Concat, EmptyScan, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan,
Pivot, Project, Sample, Sort, UnGroupedAggregate, Unpivot,
Concat, EmptyScan, Explode, Filter, HashAggregate, HashJoin, InMemoryScan, Limit,
LocalPhysicalPlan, Pivot, Project, Sample, Sort, UnGroupedAggregate, Unpivot,
};
use daft_plan::{populate_aggregation_stages, JoinType};
use daft_table::ProbeState;
Expand All @@ -22,9 +22,10 @@ use crate::{
channel::PipelineChannel,
intermediate_ops::{
aggregate::AggregateOperator, anti_semi_hash_join_probe::AntiSemiProbeOperator,
filter::FilterOperator, inner_hash_join_probe::InnerHashJoinProbeOperator,
intermediate_op::IntermediateNode, pivot::PivotOperator, project::ProjectOperator,
sample::SampleOperator, unpivot::UnpivotOperator,
explode::ExplodeOperator, filter::FilterOperator,
inner_hash_join_probe::InnerHashJoinProbeOperator, intermediate_op::IntermediateNode,
pivot::PivotOperator, project::ProjectOperator, sample::SampleOperator,
unpivot::UnpivotOperator,
},
sinks::{
aggregate::AggregateSink, blocking_sink::BlockingSinkNode, concat::ConcatSink,
Expand Down Expand Up @@ -145,6 +146,13 @@ pub fn physical_plan_to_pipeline(
let child_node = physical_plan_to_pipeline(input, psets)?;
IntermediateNode::new(Arc::new(filter_op), vec![child_node]).boxed()
}
LocalPhysicalPlan::Explode(Explode {
input, to_explode, ..
}) => {
let explode_op = ExplodeOperator::new(to_explode.clone());
let child_node = physical_plan_to_pipeline(input, psets)?;
IntermediateNode::new(Arc::new(explode_op), vec![child_node]).boxed()
}
LocalPhysicalPlan::Limit(Limit {
input, num_rows, ..
}) => {
Expand Down
6 changes: 3 additions & 3 deletions src/daft-physical-plan/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ mod local_plan;
mod translate;

pub use local_plan::{
Concat, EmptyScan, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan,
LocalPhysicalPlanRef, PhysicalScan, PhysicalWrite, Pivot, Project, Sample, Sort,
UnGroupedAggregate, Unpivot,
Concat, EmptyScan, Explode, Filter, HashAggregate, HashJoin, InMemoryScan, Limit,
LocalPhysicalPlan, LocalPhysicalPlanRef, PhysicalScan, PhysicalWrite, Pivot, Project, Sample,
Sort, UnGroupedAggregate, Unpivot,
};
pub use translate::translate;
25 changes: 24 additions & 1 deletion src/daft-physical-plan/src/local_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub enum LocalPhysicalPlan {
Project(Project),
Filter(Filter),
Limit(Limit),
// Explode(Explode),
Explode(Explode),
Unpivot(Unpivot),
Sort(Sort),
// Split(Split),
Expand Down Expand Up @@ -107,6 +107,20 @@ impl LocalPhysicalPlan {
.arced()
}

pub(crate) fn explode(
input: LocalPhysicalPlanRef,
to_explode: Vec<ExprRef>,
schema: SchemaRef,
) -> LocalPhysicalPlanRef {
Self::Explode(Explode {
input,
to_explode,
schema,
plan_stats: PlanStats {},
})
.arced()
}

pub(crate) fn project(
input: LocalPhysicalPlanRef,
projection: Vec<ExprRef>,
Expand Down Expand Up @@ -272,6 +286,7 @@ impl LocalPhysicalPlan {
| Self::Sort(Sort { schema, .. })
| Self::Sample(Sample { schema, .. })
| Self::HashJoin(HashJoin { schema, .. })
| Self::Explode(Explode { schema, .. })
| Self::Unpivot(Unpivot { schema, .. })
| Self::Concat(Concat { schema, .. }) => schema,
Self::InMemoryScan(InMemoryScan { info, .. }) => &info.source_schema,
Expand Down Expand Up @@ -323,6 +338,14 @@ pub struct Limit {
pub plan_stats: PlanStats,
}

/// Local physical plan node for exploding list columns into one row per element.
#[derive(Debug)]
pub struct Explode {
    /// Child plan producing the rows to explode.
    pub input: LocalPhysicalPlanRef,
    /// Expressions identifying the columns to explode.
    pub to_explode: Vec<ExprRef>,
    /// Output schema after the explode is applied.
    pub schema: SchemaRef,
    /// Plan statistics (currently an empty placeholder struct).
    pub plan_stats: PlanStats,
}

#[derive(Debug)]
pub struct Sort {
pub input: LocalPhysicalPlanRef,
Expand Down
8 changes: 8 additions & 0 deletions src/daft-physical-plan/src/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,14 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult<LocalPhysicalPlanRef> {
log::warn!("Repartition Not supported for Local Executor!; This will be a No-Op");
translate(&repartition.input)
}
LogicalPlan::Explode(explode) => {
let input = translate(&explode.input)?;
Ok(LocalPhysicalPlan::explode(
input,
explode.to_explode.clone(),
explode.exploded_schema.clone(),
))
}
_ => todo!("{} not yet implemented", plan.name()),
}
}
6 changes: 0 additions & 6 deletions tests/dataframe/test_explode.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,8 @@
import pyarrow as pa
import pytest

from daft import context
from daft.expressions import col

pytestmark = pytest.mark.skipif(
context.get_context().daft_execution_config.enable_native_executor is True,
reason="Native executor fails for these tests",
)


@pytest.mark.parametrize(
"data",
Expand Down
7 changes: 1 addition & 6 deletions tests/dataframe/test_wildcard.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
import pytest

import daft
from daft import col, context
from daft import col
from daft.exceptions import DaftCoreException

pytestmark = pytest.mark.skipif(
context.get_context().daft_execution_config.enable_native_executor is True,
reason="Native executor fails for these tests",
)


def test_wildcard_select():
df = daft.from_pydict(
Expand Down
5 changes: 1 addition & 4 deletions tests/sql/test_list_exprs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import pyarrow as pa
import pytest

import daft
from daft import col, context
from daft import col
from daft.daft import CountMode
from daft.sql.sql import SQLCatalog

Expand Down Expand Up @@ -62,8 +61,6 @@ def test_list_counts():


def test_list_explode():
if context.get_context().daft_execution_config.enable_native_executor is True:
pytest.skip("Native executor fails for these tests")
df = daft.from_pydict({"col": [[1, 2, 3], [1, 2], [1, None, 4], []]})
catalog = SQLCatalog({"test": df})
expected = df.explode(col("col"))
Expand Down

0 comments on commit d1b06fb

Please sign in to comment.