Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Enable concat for swordfish #2976

Merged
merged 15 commits into from
Oct 22, 2024
17 changes: 8 additions & 9 deletions src/daft-local-execution/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ use daft_core::{
use daft_dsl::{col, join::get_common_join_keys, Expr};
use daft_micropartition::MicroPartition;
use daft_physical_plan::{
EmptyScan, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan, Pivot,
Project, Sample, Sort, UnGroupedAggregate, Unpivot,
Concat, EmptyScan, Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan,
Pivot, Project, Sample, Sort, UnGroupedAggregate, Unpivot,
};
use daft_plan::{populate_aggregation_stages, JoinType};
use daft_table::ProbeState;
Expand All @@ -27,7 +27,7 @@ use crate::{
sample::SampleOperator, unpivot::UnpivotOperator,
},
sinks::{
aggregate::AggregateSink, blocking_sink::BlockingSinkNode,
aggregate::AggregateSink, blocking_sink::BlockingSinkNode, concat::ConcatSink,
hash_join_build::HashJoinBuildSink, limit::LimitSink,
outer_hash_join_probe::OuterHashJoinProbeSink, sort::SortSink,
streaming_sink::StreamingSinkNode,
Expand Down Expand Up @@ -152,12 +152,11 @@ pub fn physical_plan_to_pipeline(
let child_node = physical_plan_to_pipeline(input, psets)?;
StreamingSinkNode::new(Arc::new(sink), vec![child_node]).boxed()
}
LocalPhysicalPlan::Concat(_) => {
todo!("concat")
// let sink = ConcatSink::new();
// let left_child = physical_plan_to_pipeline(input, psets)?;
// let right_child = physical_plan_to_pipeline(other, psets)?;
// PipelineNode::double_sink(sink, left_child, right_child)
LocalPhysicalPlan::Concat(Concat { input, other, .. }) => {
let left_child = physical_plan_to_pipeline(input, psets)?;
let right_child = physical_plan_to_pipeline(other, psets)?;
let sink = ConcatSink {};
StreamingSinkNode::new(Arc::new(sink), vec![left_child, right_child]).boxed()
}
LocalPhysicalPlan::UnGroupedAggregate(UnGroupedAggregate {
input,
Expand Down
129 changes: 68 additions & 61 deletions src/daft-local-execution/src/sinks/concat.rs
Original file line number Diff line number Diff line change
@@ -1,61 +1,68 @@
// use std::sync::Arc;

// use common_error::DaftResult;
// use daft_micropartition::MicroPartition;
// use tracing::instrument;

// use super::sink::{Sink, SinkResultType};

// #[derive(Clone)]
// pub struct ConcatSink {
// result_left: Vec<Arc<MicroPartition>>,
// result_right: Vec<Arc<MicroPartition>>,
// }

// impl ConcatSink {
// pub fn new() -> Self {
// Self {
// result_left: Vec::new(),
// result_right: Vec::new(),
// }
// }

// #[instrument(skip_all, name = "ConcatSink::sink")]
// fn sink_left(&mut self, input: &Arc<MicroPartition>) -> DaftResult<SinkResultType> {
// self.result_left.push(input.clone());
// Ok(SinkResultType::NeedMoreInput)
// }

// #[instrument(skip_all, name = "ConcatSink::sink")]
// fn sink_right(&mut self, input: &Arc<MicroPartition>) -> DaftResult<SinkResultType> {
// self.result_right.push(input.clone());
// Ok(SinkResultType::NeedMoreInput)
// }
// }

// impl Sink for ConcatSink {
// fn sink(&mut self, index: usize, input: &Arc<MicroPartition>) -> DaftResult<SinkResultType> {
// match index {
// 0 => self.sink_left(input),
// 1 => self.sink_right(input),
// _ => panic!("concat only supports 2 inputs, got {index}"),
// }
// }

// fn in_order(&self) -> bool {
// true
// }

// fn num_inputs(&self) -> usize {
// 2
// }

// #[instrument(skip_all, name = "ConcatSink::finalize")]
// fn finalize(self: Box<Self>) -> DaftResult<Vec<Arc<MicroPartition>>> {
// Ok(self
// .result_left
// .into_iter()
// .chain(self.result_right.into_iter())
// .collect())
// }
// }
use std::sync::Arc;

use common_error::{DaftError, DaftResult};
use daft_micropartition::MicroPartition;
use tracing::instrument;

use super::streaming_sink::{StreamingSink, StreamingSinkOutput, StreamingSinkState};
use crate::pipeline::PipelineResultType;

struct ConcatSinkState {
// The index of the last morsel of data that was received, which should be strictly non-decreasing.
pub curr_idx: usize,
}
impl StreamingSinkState for ConcatSinkState {
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
}

pub struct ConcatSink {}

impl StreamingSink for ConcatSink {
/// Execute for the ConcatSink operator does not do any computation and simply returns the input data.
/// It only expects that the indices of the input data are strictly non-decreasing.
/// TODO(Colin): If maintain_order is false, technically we could accept any index. Make this optimization later.
#[instrument(skip_all, name = "ConcatSink::sink")]
fn execute(
&self,
index: usize,
input: &PipelineResultType,
state: &mut dyn StreamingSinkState,
) -> DaftResult<StreamingSinkOutput> {
let state = state
.as_any_mut()
.downcast_mut::<ConcatSinkState>()
.expect("ConcatSink should have ConcatSinkState");

// If the index is the same as the current index or one more than the current index, then we can accept the morsel.
if state.curr_idx == index || state.curr_idx + 1 == index {
state.curr_idx = index;
Ok(StreamingSinkOutput::NeedMoreInput(Some(
input.as_data().clone(),
)))
} else {
Err(DaftError::ComputeError(format!("Concat sink received out-of-order data. Expected index to be {} or {}, but got {}.", state.curr_idx, state.curr_idx + 1, index)))
}
}

fn name(&self) -> &'static str {
"Concat"
}

fn finalize(
&self,
_states: Vec<Box<dyn StreamingSinkState>>,
) -> DaftResult<Option<Arc<MicroPartition>>> {
Ok(None)
}

fn make_state(&self) -> Box<dyn StreamingSinkState> {
Box::new(ConcatSinkState { curr_idx: 0 })
}

/// Since the ConcatSink does not do any computation, it does not need to spawn multiple workers.
fn max_concurrency(&self) -> usize {
1
}
}
14 changes: 14 additions & 0 deletions src/daft-local-execution/src/sinks/streaming_sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,30 @@ pub enum StreamingSinkOutput {
}

pub trait StreamingSink: Send + Sync {
/// Execute the StreamingSink operator on the morsel of input data,
/// received from the child with the given index,
/// with the given state.
fn execute(
&self,
index: usize,
input: &PipelineResultType,
state: &mut dyn StreamingSinkState,
) -> DaftResult<StreamingSinkOutput>;

/// Finalize the StreamingSink operator, with the given states from each worker.
fn finalize(
&self,
states: Vec<Box<dyn StreamingSinkState>>,
) -> DaftResult<Option<Arc<MicroPartition>>>;

/// The name of the StreamingSink operator.
fn name(&self) -> &'static str;

/// Create a new worker-local state for this StreamingSink.
fn make_state(&self) -> Box<dyn StreamingSinkState>;

/// The maximum number of concurrent workers that can be spawned for this sink.
/// Each worker will has its own StreamingSinkState.
fn max_concurrency(&self) -> usize {
*NUM_CPUS
}
Expand Down Expand Up @@ -118,6 +130,8 @@ impl StreamingSinkNode {
output_receiver
}

// Forwards input from the children to the workers in a round-robin fashion.
// Always exhausts the input from one child before moving to the next.
async fn forward_input_to_workers(
receivers: Vec<CountingReceiver>,
worker_senders: Vec<Sender<(usize, PipelineResultType)>>,
Expand Down
7 changes: 1 addition & 6 deletions tests/dataframe/test_approx_count_distinct.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@
import pytest

import daft
from daft import col, context

pytestmark = pytest.mark.skipif(
context.get_context().daft_execution_config.enable_native_executor is True,
reason="Native executor fails for these tests",
)
from daft import col

TESTS = [
[[], 0],
Expand Down
7 changes: 0 additions & 7 deletions tests/dataframe/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,6 @@

import pytest

from daft import context

pytestmark = pytest.mark.skipif(
context.get_context().daft_execution_config.enable_native_executor is True,
reason="Native executor fails for these tests",
)


def test_simple_concat(make_df):
df1 = make_df({"foo": [1, 2, 3]})
Expand Down
6 changes: 0 additions & 6 deletions tests/dataframe/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import pytest

from daft import context


def test_sample_fraction(make_df, valid_data: list[dict[str, float]]) -> None:
df = make_df(valid_data)
Expand Down Expand Up @@ -100,10 +98,6 @@ def test_sample_without_replacement(make_df, valid_data: list[dict[str, float]])
assert pylist[0] != pylist[1]


@pytest.mark.skipif(
context.get_context().daft_execution_config.enable_native_executor is True,
reason="Native executor fails for concat",
)
def test_sample_with_concat(make_df, valid_data: list[dict[str, float]]) -> None:
df1 = make_df(valid_data)
df2 = make_df(valid_data)
Expand Down
6 changes: 0 additions & 6 deletions tests/dataframe/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,6 @@
import pytest

import daft
from daft import context

pytestmark = pytest.mark.skipif(
context.get_context().daft_execution_config.enable_native_executor is True,
reason="Native executor fails for these tests",
)


def add_1(df):
Expand Down
Loading