Skip to content

Commit

Permalink
[FEAT] Enable unpivot for swordfish (#3078)
Browse files Browse the repository at this point in the history
Adds unpivot as an intermediate operator and re-enables all of the previously skipped unpivot tests.

---------

Co-authored-by: Colin Ho <[email protected]>
  • Loading branch information
colin-ho and Colin Ho authored Oct 22, 2024
1 parent b9b2d72 commit 31d5412
Show file tree
Hide file tree
Showing 7 changed files with 123 additions and 10 deletions.
1 change: 1 addition & 0 deletions src/daft-local-execution/src/intermediate_ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ pub mod intermediate_op;
pub mod pivot;
pub mod project;
pub mod sample;
pub mod unpivot;
57 changes: 57 additions & 0 deletions src/daft-local-execution/src/intermediate_ops/unpivot.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use std::sync::Arc;

use common_error::DaftResult;
use daft_dsl::ExprRef;
use tracing::instrument;

use super::intermediate_op::{
IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState,
};
use crate::pipeline::PipelineResultType;

/// Intermediate operator that runs `unpivot` on each input it is given.
///
/// The four fields are stored at construction time and forwarded, unchanged,
/// as the arguments of the `unpivot` call made in `execute`.
pub struct UnpivotOperator {
/// Expressions passed as the `ids` argument of `unpivot`.
/// NOTE(review): presumably the identifier columns repeated on every output row — confirm.
ids: Vec<ExprRef>,
/// Expressions passed as the `values` argument of `unpivot`.
/// NOTE(review): presumably the columns that get melted into rows — confirm.
values: Vec<ExprRef>,
/// Passed as the `variable_name` argument of `unpivot`.
/// NOTE(review): presumably the name of the output column holding the source column names — confirm.
variable_name: String,
/// Passed as the `value_name` argument of `unpivot`.
/// NOTE(review): presumably the name of the output column holding the unpivoted values — confirm.
value_name: String,
}

impl UnpivotOperator {
pub fn new(
ids: Vec<ExprRef>,
values: Vec<ExprRef>,
variable_name: String,
value_name: String,
) -> Self {
Self {
ids,
values,
variable_name,
value_name,
}
}
}

impl IntermediateOperator for UnpivotOperator {
    /// Unpivots the data carried by `input` and returns it as this operator's output.
    ///
    /// The worker index and the optional operator state are unused.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the underlying `unpivot` call.
    #[instrument(skip_all, name = "UnpivotOperator::execute")]
    fn execute(
        &self,
        _idx: usize,
        input: &PipelineResultType,
        _state: Option<&mut Box<dyn IntermediateOperatorState>>,
    ) -> DaftResult<IntermediateOperatorResult> {
        // Borrow every stored argument at once; each binding is a reference into `self`.
        let Self {
            ids,
            values,
            variable_name,
            value_name,
        } = self;

        let unpivoted = input
            .as_data()
            .unpivot(ids, values, variable_name, value_name)?;

        // Always report `NeedMoreInput` together with the freshly produced output.
        let output = Some(Arc::new(unpivoted));
        Ok(IntermediateOperatorResult::NeedMoreInput(output))
    }

    /// Static display name of this operator.
    fn name(&self) -> &'static str {
        "UnpivotOperator"
    }
}
21 changes: 19 additions & 2 deletions src/daft-local-execution/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use daft_dsl::{col, join::get_common_join_keys, Expr};
use daft_micropartition::MicroPartition;
use daft_physical_plan::{
EmptyScan, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan, Pivot,
Project, Sample, Sort, UnGroupedAggregate,
Project, Sample, Sort, UnGroupedAggregate, Unpivot,
};
use daft_plan::{populate_aggregation_stages, JoinType};
use daft_table::ProbeState;
Expand All @@ -24,7 +24,7 @@ use crate::{
aggregate::AggregateOperator, anti_semi_hash_join_probe::AntiSemiProbeOperator,
filter::FilterOperator, inner_hash_join_probe::InnerHashJoinProbeOperator,
intermediate_op::IntermediateNode, pivot::PivotOperator, project::ProjectOperator,
sample::SampleOperator,
sample::SampleOperator, unpivot::UnpivotOperator,
},
sinks::{
aggregate::AggregateSink, blocking_sink::BlockingSinkNode,
Expand Down Expand Up @@ -236,6 +236,23 @@ pub fn physical_plan_to_pipeline(

IntermediateNode::new(Arc::new(final_stage_project), vec![second_stage_node]).boxed()
}
LocalPhysicalPlan::Unpivot(Unpivot {
input,
ids,
values,
variable_name,
value_name,
..
}) => {
let child_node = physical_plan_to_pipeline(input, psets)?;
let unpivot_op = UnpivotOperator::new(
ids.clone(),
values.clone(),
variable_name.clone(),
value_name.clone(),
);
IntermediateNode::new(Arc::new(unpivot_op), vec![child_node]).boxed()
}
LocalPhysicalPlan::Pivot(Pivot {
input,
group_by,
Expand Down
2 changes: 1 addition & 1 deletion src/daft-physical-plan/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ mod translate;
pub use local_plan::{
Concat, EmptyScan, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan,
LocalPhysicalPlanRef, PhysicalScan, PhysicalWrite, Pivot, Project, Sample, Sort,
UnGroupedAggregate,
UnGroupedAggregate, Unpivot,
};
pub use translate::translate;
34 changes: 33 additions & 1 deletion src/daft-physical-plan/src/local_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub enum LocalPhysicalPlan {
Filter(Filter),
Limit(Limit),
// Explode(Explode),
// Unpivot(Unpivot),
Unpivot(Unpivot),
Sort(Sort),
// Split(Split),
Sample(Sample),
Expand Down Expand Up @@ -151,6 +151,26 @@ impl LocalPhysicalPlan {
.arced()
}

/// Builds an `Unpivot` plan node over `input` and returns it wrapped via `arced()`.
///
/// All arguments are moved into the node unchanged; `plan_stats` is filled
/// with an empty `PlanStats {}`.
pub(crate) fn unpivot(
    input: LocalPhysicalPlanRef,
    ids: Vec<ExprRef>,
    values: Vec<ExprRef>,
    variable_name: String,
    value_name: String,
    schema: SchemaRef,
) -> LocalPhysicalPlanRef {
    let plan_stats = PlanStats {};
    let node = Unpivot {
        input,
        ids,
        values,
        variable_name,
        value_name,
        schema,
        plan_stats,
    };
    Self::Unpivot(node).arced()
}

pub(crate) fn pivot(
input: LocalPhysicalPlanRef,
group_by: Vec<ExprRef>,
Expand Down Expand Up @@ -252,6 +272,7 @@ impl LocalPhysicalPlan {
| Self::Sort(Sort { schema, .. })
| Self::Sample(Sample { schema, .. })
| Self::HashJoin(HashJoin { schema, .. })
| Self::Unpivot(Unpivot { schema, .. })
| Self::Concat(Concat { schema, .. }) => schema,
Self::InMemoryScan(InMemoryScan { info, .. }) => &info.source_schema,
_ => todo!("{:?}", self),
Expand Down Expand Up @@ -338,6 +359,17 @@ pub struct HashAggregate {
pub plan_stats: PlanStats,
}

/// Local physical plan node for an unpivot over a single child plan.
///
/// Constructed by `LocalPhysicalPlan::unpivot`.
#[derive(Debug)]
pub struct Unpivot {
/// Child plan whose output is unpivoted.
pub input: LocalPhysicalPlanRef,
/// Expressions supplied as the unpivot `ids`.
/// NOTE(review): presumably the identifier columns kept on each output row — confirm.
pub ids: Vec<ExprRef>,
/// Expressions supplied as the unpivot `values`.
/// NOTE(review): presumably the columns melted into rows — confirm.
pub values: Vec<ExprRef>,
/// NOTE(review): presumably the name of the output column holding source column names — confirm.
pub variable_name: String,
/// NOTE(review): presumably the name of the output column holding the unpivoted values — confirm.
pub value_name: String,
/// Schema reported for this node by `LocalPhysicalPlan`'s schema accessor.
pub schema: SchemaRef,
/// Plan statistics; the constructor fills this with an empty `PlanStats {}`.
pub plan_stats: PlanStats,
}

#[derive(Debug)]
pub struct Pivot {
pub input: LocalPhysicalPlanRef,
Expand Down
11 changes: 11 additions & 0 deletions src/daft-physical-plan/src/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,17 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult<LocalPhysicalPlanRef> {
))
}
}
LogicalPlan::Unpivot(unpivot) => {
let input = translate(&unpivot.input)?;
Ok(LocalPhysicalPlan::unpivot(
input,
unpivot.ids.clone(),
unpivot.values.clone(),
unpivot.variable_name.clone(),
unpivot.value_name.clone(),
unpivot.output_schema.clone(),
))
}
LogicalPlan::Pivot(pivot) => {
let input = translate(&pivot.input)?;
let groupby_with_pivot = pivot
Expand Down
7 changes: 1 addition & 6 deletions tests/dataframe/test_unpivot.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
import pytest

from daft import col, context
from daft import col
from daft.datatype import DataType

pytestmark = pytest.mark.skipif(
context.get_context().daft_execution_config.enable_native_executor is True,
reason="Native executor fails for these tests",
)


@pytest.mark.parametrize("n_partitions", [1, 2, 4])
def test_unpivot(make_df, n_partitions):
Expand Down

0 comments on commit 31d5412

Please sign in to comment.