[FEAT] Enable pivot for swordfish (#3081)
Adds pivot as an intermediate operator in the native (swordfish) executor and unskips all of the pivot tests.

Co-authored-by: Colin Ho <[email protected]>
colin-ho and Colin Ho authored Oct 21, 2024
Commit 4a8244b (1 parent: 4781ad3)
Showing 7 changed files with 144 additions and 13 deletions.
1 change: 1 addition & 0 deletions src/daft-local-execution/src/intermediate_ops/mod.rs
@@ -4,5 +4,6 @@ pub mod buffer;
 pub mod filter;
 pub mod hash_join_probe;
 pub mod intermediate_op;
+pub mod pivot;
 pub mod project;
 pub mod sample;
57 changes: 57 additions & 0 deletions src/daft-local-execution/src/intermediate_ops/pivot.rs
@@ -0,0 +1,57 @@
+use std::sync::Arc;
+
+use common_error::DaftResult;
+use daft_dsl::ExprRef;
+use tracing::instrument;
+
+use super::intermediate_op::{
+    IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState,
+};
+use crate::pipeline::PipelineResultType;
+
+pub struct PivotOperator {
+    group_by: Vec<ExprRef>,
+    pivot_col: ExprRef,
+    values_col: ExprRef,
+    names: Vec<String>,
+}
+
+impl PivotOperator {
+    pub fn new(
+        group_by: Vec<ExprRef>,
+        pivot_col: ExprRef,
+        values_col: ExprRef,
+        names: Vec<String>,
+    ) -> Self {
+        Self {
+            group_by,
+            pivot_col,
+            values_col,
+            names,
+        }
+    }
+}
+
+impl IntermediateOperator for PivotOperator {
+    #[instrument(skip_all, name = "PivotOperator::execute")]
+    fn execute(
+        &self,
+        _idx: usize,
+        input: &PipelineResultType,
+        _state: Option<&mut Box<dyn IntermediateOperatorState>>,
+    ) -> DaftResult<IntermediateOperatorResult> {
+        let out = input.as_data().pivot(
+            &self.group_by,
+            self.pivot_col.clone(),
+            self.values_col.clone(),
+            self.names.clone(),
+        )?;
+        Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(
+            out,
+        ))))
+    }
+
+    fn name(&self) -> &'static str {
+        "PivotOperator"
+    }
+}
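The operator itself is stateless: each `execute` call pivots one input morsel via the `pivot` method on the underlying data and hands the result straight back inside `NeedMoreInput`, signaling that it is ready for the next morsel. As a rough, self-contained sketch of what that per-morsel reshape produces (invented data and names on plain Rust collections, not Daft's actual kernel):

use std::collections::BTreeMap;

// Toy stand-in for the reshape a pivot performs on one morsel: rows of
// (group key, pivot value, value) become one row per group key, with one
// column per requested name and None where a combination is missing.
fn pivot_rows(
    rows: &[(&str, &str, i64)],
    names: &[&str],
) -> BTreeMap<String, Vec<Option<i64>>> {
    let mut grouped: BTreeMap<String, BTreeMap<&str, i64>> = BTreeMap::new();
    for &(key, pivot, value) in rows {
        grouped.entry(key.to_string()).or_default().insert(pivot, value);
    }
    grouped
        .into_iter()
        .map(|(key, cols)| (key, names.iter().map(|n| cols.get(n).copied()).collect()))
        .collect()
}

fn main() {
    let rows = [("g1", "a", 1), ("g1", "b", 2), ("g2", "a", 3)];
    let out = pivot_rows(&rows, &["a", "b"]);
    println!("{out:?}"); // {"g1": [Some(1), Some(2)], "g2": [Some(3), None]}
}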
24 changes: 21 additions & 3 deletions src/daft-local-execution/src/pipeline.rs
@@ -10,8 +10,8 @@ use daft_core::{
 use daft_dsl::{col, join::get_common_join_keys, Expr};
 use daft_micropartition::MicroPartition;
 use daft_physical_plan::{
-    EmptyScan, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan, Project,
-    Sample, Sort, UnGroupedAggregate,
+    EmptyScan, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan, Pivot,
+    Project, Sample, Sort, UnGroupedAggregate,
 };
 use daft_plan::{populate_aggregation_stages, JoinType};
 use daft_table::{Probeable, Table};
@@ -23,7 +23,8 @@ use crate::{
     intermediate_ops::{
         aggregate::AggregateOperator, anti_semi_hash_join_probe::AntiSemiProbeOperator,
         filter::FilterOperator, hash_join_probe::HashJoinProbeOperator,
-        intermediate_op::IntermediateNode, project::ProjectOperator, sample::SampleOperator,
+        intermediate_op::IntermediateNode, pivot::PivotOperator, project::ProjectOperator,
+        sample::SampleOperator,
     },
     sinks::{
         aggregate::AggregateSink, blocking_sink::BlockingSinkNode,
@@ -234,6 +235,23 @@ pub fn physical_plan_to_pipeline(
 
             IntermediateNode::new(Arc::new(final_stage_project), vec![second_stage_node]).boxed()
         }
+        LocalPhysicalPlan::Pivot(Pivot {
+            input,
+            group_by,
+            pivot_column,
+            value_column,
+            names,
+            ..
+        }) => {
+            let pivot_op = PivotOperator::new(
+                group_by.clone(),
+                pivot_column.clone(),
+                value_column.clone(),
+                names.clone(),
+            );
+            let child_node = physical_plan_to_pipeline(input, psets)?;
+            IntermediateNode::new(Arc::new(pivot_op), vec![child_node]).boxed()
+        }
         LocalPhysicalPlan::Sort(Sort {
             input,
             sort_by,
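The new `Pivot` arm follows the same wiring pattern as the other intermediate operators: recursively build the child pipeline, then wrap the operator in an `IntermediateNode` fed by that child. A minimal, hypothetical reduction of that recursion, with invented `Plan` and `Node` types standing in for the real ones:

// Hypothetical miniature of physical_plan_to_pipeline's recursion: each
// plan variant becomes a pipeline node whose children are built first.
#[derive(Debug)]
enum Plan {
    Scan,
    Pivot { input: Box<Plan> },
}

#[derive(Debug)]
enum Node {
    Source,
    Intermediate { op: &'static str, children: Vec<Node> },
}

fn to_pipeline(plan: &Plan) -> Node {
    match plan {
        Plan::Scan => Node::Source,
        Plan::Pivot { input } => Node::Intermediate {
            op: "PivotOperator",
            children: vec![to_pipeline(input)], // build the child pipeline first
        },
    }
}

fn main() {
    let plan = Plan::Pivot { input: Box::new(Plan::Scan) };
    println!("{:?}", to_pipeline(&plan));
}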
3 changes: 2 additions & 1 deletion src/daft-physical-plan/src/lib.rs
@@ -4,6 +4,7 @@ mod translate;
 
 pub use local_plan::{
     Concat, EmptyScan, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan,
-    LocalPhysicalPlanRef, PhysicalScan, PhysicalWrite, Project, Sample, Sort, UnGroupedAggregate,
+    LocalPhysicalPlanRef, PhysicalScan, PhysicalWrite, Pivot, Project, Sample, Sort,
+    UnGroupedAggregate,
 };
 pub use translate::translate;
34 changes: 33 additions & 1 deletion src/daft-physical-plan/src/local_plan.rs
@@ -29,7 +29,7 @@ pub enum LocalPhysicalPlan {
     // ReduceMerge(ReduceMerge),
     UnGroupedAggregate(UnGroupedAggregate),
     HashAggregate(HashAggregate),
-    // Pivot(Pivot),
+    Pivot(Pivot),
     Concat(Concat),
     HashJoin(HashJoin),
     // SortMergeJoin(SortMergeJoin),
@@ -151,6 +151,26 @@ impl LocalPhysicalPlan {
         .arced()
     }
 
+    pub(crate) fn pivot(
+        input: LocalPhysicalPlanRef,
+        group_by: Vec<ExprRef>,
+        pivot_column: ExprRef,
+        value_column: ExprRef,
+        names: Vec<String>,
+        schema: SchemaRef,
+    ) -> LocalPhysicalPlanRef {
+        Self::Pivot(Pivot {
+            input,
+            group_by,
+            pivot_column,
+            value_column,
+            names,
+            schema,
+            plan_stats: PlanStats {},
+        })
+        .arced()
+    }
+
     pub(crate) fn sort(
         input: LocalPhysicalPlanRef,
         sort_by: Vec<ExprRef>,
@@ -228,6 +248,7 @@ impl LocalPhysicalPlan {
             | Self::Project(Project { schema, .. })
             | Self::UnGroupedAggregate(UnGroupedAggregate { schema, .. })
             | Self::HashAggregate(HashAggregate { schema, .. })
+            | Self::Pivot(Pivot { schema, .. })
             | Self::Sort(Sort { schema, .. })
             | Self::Sample(Sample { schema, .. })
             | Self::HashJoin(HashJoin { schema, .. })
@@ -317,6 +338,17 @@ pub struct HashAggregate {
     pub plan_stats: PlanStats,
 }
 
+#[derive(Debug)]
+pub struct Pivot {
+    pub input: LocalPhysicalPlanRef,
+    pub group_by: Vec<ExprRef>,
+    pub pivot_column: ExprRef,
+    pub value_column: ExprRef,
+    pub names: Vec<String>,
+    pub schema: SchemaRef,
+    pub plan_stats: PlanStats,
+}
+
 #[derive(Debug)]
 pub struct HashJoin {
     pub left: LocalPhysicalPlanRef,
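As with the other constructors in this file, `pivot` builds the node and immediately calls `.arced()`, so every plan node circulates as an `Arc` and subtrees are cheap to clone and share between stages. A hypothetical reduction of that builder pattern, with an invented `MiniPlan` type:

use std::sync::Arc;

// Hypothetical miniature of the Arc'd-constructor pattern used by
// LocalPhysicalPlan: constructors return Arc-wrapped nodes directly.
#[derive(Debug)]
enum MiniPlan {
    Scan,
    Pivot { input: Arc<MiniPlan>, names: Vec<String> },
}

impl MiniPlan {
    fn arced(self) -> Arc<Self> {
        Arc::new(self)
    }

    fn pivot(input: Arc<MiniPlan>, names: Vec<String>) -> Arc<MiniPlan> {
        MiniPlan::Pivot { input, names }.arced()
    }
}

fn main() {
    let scan = MiniPlan::Scan.arced();
    // Cloning the Arc shares the subtree instead of deep-copying it.
    let plan = MiniPlan::pivot(scan.clone(), vec!["a".into(), "b".into()]);
    println!("{plan:?}");
}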
31 changes: 30 additions & 1 deletion src/daft-physical-plan/src/translate.rs
@@ -1,5 +1,5 @@
 use common_error::DaftResult;
-use daft_core::join::JoinStrategy;
+use daft_core::{join::JoinStrategy, prelude::Schema};
 use daft_dsl::ExprRef;
 use daft_plan::{LogicalPlan, LogicalPlanRef, SourceInfo};
 
@@ -70,6 +70,35 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult<LocalPhysicalPlanRef> {
                 ))
             }
         }
+        LogicalPlan::Pivot(pivot) => {
+            let input = translate(&pivot.input)?;
+            let groupby_with_pivot = pivot
+                .group_by
+                .iter()
+                .chain(std::iter::once(&pivot.pivot_column))
+                .cloned()
+                .collect::<Vec<_>>();
+            let aggregate_fields = groupby_with_pivot
+                .iter()
+                .map(|expr| expr.to_field(input.schema()))
+                .chain(std::iter::once(pivot.aggregation.to_field(input.schema())))
+                .collect::<DaftResult<Vec<_>>>()?;
+            let aggregate_schema = Schema::new(aggregate_fields)?;
+            let aggregate = LocalPhysicalPlan::hash_aggregate(
+                input,
+                vec![pivot.aggregation.clone(); 1],
+                groupby_with_pivot,
+                aggregate_schema.into(),
+            );
+            Ok(LocalPhysicalPlan::pivot(
+                aggregate,
+                pivot.group_by.clone(),
+                pivot.pivot_column.clone(),
+                pivot.value_column.clone(),
+                pivot.names.clone(),
+                pivot.output_schema.clone(),
+            ))
+        }
         LogicalPlan::Sort(sort) => {
             let input = translate(&sort.input)?;
             Ok(LocalPhysicalPlan::sort(
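The translation lowers a logical `Pivot` into two physical steps: a `hash_aggregate` whose grouping keys are the user's `group_by` plus the pivot column, followed by the `pivot` reshape. Grouping on the pivot column as well guarantees that every (group, pivot value) pair collapses to exactly one aggregated value before the reshape runs. A self-contained sketch of that first stage on invented data, with sum standing in for the aggregation:

use std::collections::BTreeMap;

// Stage 1 of the lowering, on plain data: aggregate v per (group key k,
// pivot value p). After this, each (k, p) pair appears exactly once, which
// is what the pivot reshape (sketched after pivot.rs above) relies on.
fn main() {
    let rows = [("g1", "a", 1), ("g1", "a", 2), ("g1", "b", 5), ("g2", "a", 7)];
    let mut agg: BTreeMap<(&str, &str), i64> = BTreeMap::new();
    for (k, p, v) in rows {
        *agg.entry((k, p)).or_insert(0) += v; // sum(v) grouped by (k, p)
    }
    // {("g1","a"): 3, ("g1","b"): 5, ("g2","a"): 7}; the pivot stage then
    // reshapes this to g1 -> [3, 5], g2 -> [7, null] for names ["a", "b"].
    println!("{agg:?}");
}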
7 changes: 0 additions & 7 deletions tests/dataframe/test_pivot.py
@@ -1,12 +1,5 @@
 import pytest
 
-from daft import context
-
-pytestmark = pytest.mark.skipif(
-    context.get_context().daft_execution_config.enable_native_executor is True,
-    reason="Native executor fails for these tests",
-)
-
 
 @pytest.mark.parametrize("repartition_nparts", [1, 2, 5])
 def test_pivot(make_df, repartition_nparts):
