From 5357ffae3da4e1b5d5c7131a5b442fa66a812b9e Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 26 Aug 2024 17:55:16 -0700 Subject: [PATCH] pipeline channel --- src/daft-local-execution/src/channel.rs | 81 +++++++++++-------- .../src/intermediate_ops/aggregate.rs | 5 +- .../src/intermediate_ops/filter.rs | 5 +- .../src/intermediate_ops/hash_join_probe.rs | 13 +-- .../src/intermediate_ops/intermediate_op.rs | 42 +++++----- .../src/intermediate_ops/project.rs | 5 +- src/daft-local-execution/src/pipeline.rs | 44 ++-------- src/daft-local-execution/src/run.rs | 7 +- src/daft-local-execution/src/runtime_stats.rs | 30 ++++++- .../src/sinks/blocking_sink.rs | 35 ++++---- .../src/sinks/hash_join_build.rs | 1 + .../src/sinks/streaming_sink.rs | 26 +++--- .../src/sources/source.rs | 12 ++- 13 files changed, 151 insertions(+), 155 deletions(-) diff --git a/src/daft-local-execution/src/channel.rs b/src/daft-local-execution/src/channel.rs index db73bbb64f..bb22b9d4ea 100644 --- a/src/daft-local-execution/src/channel.rs +++ b/src/daft-local-execution/src/channel.rs @@ -2,16 +2,9 @@ use std::sync::Arc; use crate::{ pipeline::PipelineResultType, - runtime_stats::{CountingSender, RuntimeStatsContext}, + runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, }; -pub type OneShotSender = tokio::sync::oneshot::Sender; -pub type OneShotReceiver = tokio::sync::oneshot::Receiver; - -pub fn create_one_shot_channel() -> (OneShotSender, OneShotReceiver) { - tokio::sync::oneshot::channel() -} - pub type Sender = tokio::sync::mpsc::Sender; pub type Receiver = tokio::sync::mpsc::Receiver; @@ -19,33 +12,57 @@ pub fn create_channel(buffer_size: usize) -> (Sender, Receiver) { tokio::sync::mpsc::channel(buffer_size) } -pub fn create_multi_channel(buffer_size: usize, in_order: bool) -> (MultiSender, MultiReceiver) { - if in_order { - let (senders, receivers) = (0..buffer_size).map(|_| create_channel(1)).unzip(); - let sender = MultiSender::InOrder(RoundRobinSender::new(senders)); - let receiver = MultiReceiver::InOrder(RoundRobinReceiver::new(receivers)); - (sender, receiver) - } else { - let (sender, receiver) = create_channel(buffer_size); - let sender = MultiSender::OutOfOrder(sender); - let receiver = MultiReceiver::OutOfOrder(receiver); - (sender, receiver) +pub struct PipelineChannel { + sender: PipelineSender, + receiver: PipelineReceiver, +} + +impl PipelineChannel { + pub fn new(buffer_size: usize, in_order: bool) -> Self { + match in_order { + true => { + let (senders, receivers) = (0..buffer_size).map(|_| create_channel(1)).unzip(); + let sender = PipelineSender::InOrder(RoundRobinSender::new(senders)); + let receiver = PipelineReceiver::InOrder(RoundRobinReceiver::new(receivers)); + Self { sender, receiver } + } + false => { + let (sender, receiver) = create_channel(buffer_size); + let sender = PipelineSender::OutOfOrder(sender); + let receiver = PipelineReceiver::OutOfOrder(receiver); + Self { sender, receiver } + } + } + } + + fn get_next_sender(&mut self) -> Sender { + match &mut self.sender { + PipelineSender::InOrder(rr) => rr.get_next_sender(), + PipelineSender::OutOfOrder(sender) => sender.clone(), + } + } + + pub(crate) fn get_next_sender_with_stats( + &mut self, + rt: &Arc, + ) -> CountingSender { + CountingSender::new(self.get_next_sender(), rt.clone()) + } + + pub fn get_receiver(self) -> PipelineReceiver { + self.receiver + } + + pub(crate) fn get_receiver_with_stats(self, rt: &Arc) -> CountingReceiver { + CountingReceiver::new(self.get_receiver(), rt.clone()) } } -pub enum 
MultiSender { +pub enum PipelineSender { InOrder(RoundRobinSender), OutOfOrder(Sender), } -impl MultiSender { - pub fn get_next_sender(&mut self, stats: &Arc) -> CountingSender { - match self { - Self::InOrder(sender) => CountingSender::new(sender.get_next_sender(), stats.clone()), - Self::OutOfOrder(sender) => CountingSender::new(sender.clone(), stats.clone()), - } - } -} pub struct RoundRobinSender { senders: Vec>, curr_sender_idx: usize, @@ -66,16 +83,16 @@ impl RoundRobinSender { } } -pub enum MultiReceiver { +pub enum PipelineReceiver { InOrder(RoundRobinReceiver), OutOfOrder(Receiver), } -impl MultiReceiver { +impl PipelineReceiver { pub async fn recv(&mut self) -> Option { match self { - Self::InOrder(receiver) => receiver.recv().await, - Self::OutOfOrder(receiver) => receiver.recv().await, + PipelineReceiver::InOrder(rr) => rr.recv().await, + PipelineReceiver::OutOfOrder(r) => r.recv().await, } } } diff --git a/src/daft-local-execution/src/intermediate_ops/aggregate.rs b/src/daft-local-execution/src/intermediate_ops/aggregate.rs index 69f2c5a6e9..c83bc3f540 100644 --- a/src/daft-local-execution/src/intermediate_ops/aggregate.rs +++ b/src/daft-local-execution/src/intermediate_ops/aggregate.rs @@ -29,11 +29,10 @@ impl IntermediateOperator for AggregateOperator { fn execute( &self, _idx: usize, - input: &PipelineResultType, + input: PipelineResultType, _state: Option<&mut Box>, ) -> DaftResult { - let input = input.as_data(); - let out = input.agg(&self.agg_exprs, &self.group_by)?; + let out = input.data().agg(&self.agg_exprs, &self.group_by)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( out, )))) diff --git a/src/daft-local-execution/src/intermediate_ops/filter.rs b/src/daft-local-execution/src/intermediate_ops/filter.rs index e8b1611b87..ef85379ec6 100644 --- a/src/daft-local-execution/src/intermediate_ops/filter.rs +++ b/src/daft-local-execution/src/intermediate_ops/filter.rs @@ -25,11 +25,10 @@ impl IntermediateOperator for FilterOperator { fn execute( &self, _idx: usize, - input: &PipelineResultType, + input: PipelineResultType, _state: Option<&mut Box>, ) -> DaftResult { - let input = input.as_data(); - let out = input.filter(&[self.predicate.clone()])?; + let out = input.data().filter(&[self.predicate.clone()])?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( out, )))) diff --git a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs index b8a4053eb8..7b1bb04d22 100644 --- a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs @@ -19,9 +19,9 @@ enum HashJoinProbeState { } impl HashJoinProbeState { - fn set_table(&mut self, table: &Arc, tables: &Arc>) { + fn set_table(&mut self, table: Arc, tables: Arc>) { if let HashJoinProbeState::Building = self { - *self = HashJoinProbeState::ReadyToProbe(table.clone(), tables.clone()); + *self = HashJoinProbeState::ReadyToProbe(table, tables); } else { panic!("HashJoinProbeState should only be in Building state when setting table") } @@ -29,7 +29,7 @@ impl HashJoinProbeState { fn probe( &self, - input: &Arc, + input: Arc, right_on: &[ExprRef], pruned_right_side_columns: &[String], ) -> DaftResult> { @@ -109,9 +109,10 @@ impl IntermediateOperator for HashJoinProbeOperator { fn execute( &self, idx: usize, - input: &PipelineResultType, + input: PipelineResultType, state: Option<&mut Box>, ) -> DaftResult { + println!("HashJoinProbeOperator::execute: idx: 
{}", idx); match idx { 0 => { let state = state @@ -119,7 +120,7 @@ impl IntermediateOperator for HashJoinProbeOperator { .as_any_mut() .downcast_mut::() .expect("HashJoinProbeOperator state should be HashJoinProbeState"); - let (probe_table, tables) = input.as_probe_table(); + let (probe_table, tables) = input.probe_table(); state.set_table(probe_table, tables); Ok(IntermediateOperatorResult::NeedMoreInput(None)) } @@ -129,7 +130,7 @@ impl IntermediateOperator for HashJoinProbeOperator { .as_any_mut() .downcast_mut::() .expect("HashJoinProbeOperator state should be HashJoinProbeState"); - let input = input.as_data(); + let input = input.data(); let out = state.probe(input, &self.right_on, &self.pruned_right_side_columns)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) } diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index b59f1773ba..6c4253fe98 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -8,9 +8,9 @@ use tracing::{info_span, instrument}; use async_trait::async_trait; use crate::{ - channel::{create_channel, create_multi_channel, MultiSender, Receiver, Sender}, - pipeline::{PipelineNode, PipelineResultReceiver, PipelineResultType}, - runtime_stats::{CountingSender, RuntimeStatsContext}, + channel::{create_channel, PipelineChannel, Receiver, Sender}, + pipeline::{PipelineNode, PipelineResultType}, + runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, ExecutionRuntimeHandle, NUM_CPUS, }; @@ -30,7 +30,7 @@ pub trait IntermediateOperator: Send + Sync { fn execute( &self, idx: usize, - input: &PipelineResultType, + input: PipelineResultType, state: Option<&mut Box>, ) -> DaftResult; fn name(&self) -> &'static str; @@ -80,14 +80,7 @@ impl IntermediateNode { let span = info_span!("IntermediateOp::execute"); let mut state = op.make_state(); while let Some((idx, morsel)) = receiver.recv().await { - let len = match morsel { - PipelineResultType::Data(ref data) => data.len(), - PipelineResultType::ProbeTable(_, ref tables) => { - tables.iter().map(|t| t.len()).sum() - } - }; - rt_context.mark_rows_received(len as u64); - let result = rt_context.in_span(&span, || op.execute(idx, &morsel, state.as_mut()))?; + let result = rt_context.in_span(&span, || op.execute(idx, morsel, state.as_mut()))?; match result { IntermediateOperatorResult::NeedMoreInput(Some(mp)) => { let _ = sender.send(mp.into()).await; @@ -104,13 +97,14 @@ impl IntermediateNode { pub async fn spawn_workers( &self, num_workers: usize, - destination: &mut MultiSender, + destination_channel: &mut PipelineChannel, runtime_handle: &mut ExecutionRuntimeHandle, ) -> Vec> { let mut worker_senders = Vec::with_capacity(num_workers); for _ in 0..num_workers { let (worker_sender, worker_receiver) = create_channel(1); - let destination_sender = destination.get_next_sender(&self.runtime_stats); + let destination_sender = + destination_channel.get_next_sender_with_stats(&self.runtime_stats); runtime_handle.spawn( Self::run_worker( self.intermediate_op.clone(), @@ -126,7 +120,7 @@ impl IntermediateNode { } pub async fn send_to_workers( - receivers: Vec, + receivers: Vec, worker_senders: Vec>, morsel_size: usize, ) -> DaftResult<()> { @@ -138,15 +132,15 @@ impl IntermediateNode { }; for (idx, mut receiver) in receivers.into_iter().enumerate() { + println!("idx: {}", idx); let mut buffer = OperatorBuffer::new(morsel_size); 
while let Some(morsel) = receiver.recv().await { - let morsel = morsel?; if morsel.should_broadcast() { for worker_sender in worker_senders.iter() { let _ = worker_sender.send((idx, morsel.clone())).await; } } else { - buffer.push(morsel.as_data().clone()); + buffer.push(morsel.data().clone()); if let Some(ready) = buffer.try_clear() { let _ = send_to_next_worker(idx, ready?.into()).await; } @@ -200,17 +194,17 @@ impl PipelineNode for IntermediateNode { &mut self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result { + ) -> crate::Result { let mut child_result_receivers = Vec::with_capacity(self.children.len()); for child in self.children.iter_mut() { - let child_result_receiver = child.start(maintain_order, runtime_handle).await?; - child_result_receivers.push(child_result_receiver); + let child_result_channel = child.start(maintain_order, runtime_handle).await?; + child_result_receivers + .push(child_result_channel.get_receiver_with_stats(&self.runtime_stats)); } - let (mut destination_sender, destination_receiver) = - create_multi_channel(*NUM_CPUS, maintain_order); + let mut destination_channel = PipelineChannel::new(*NUM_CPUS, maintain_order); let worker_senders = self - .spawn_workers(*NUM_CPUS, &mut destination_sender, runtime_handle) + .spawn_workers(*NUM_CPUS, &mut destination_channel, runtime_handle) .await; runtime_handle.spawn( Self::send_to_workers( @@ -220,7 +214,7 @@ impl PipelineNode for IntermediateNode { ), self.intermediate_op.name(), ); - Ok(destination_receiver.into()) + Ok(destination_channel) } fn as_tree_display(&self) -> &dyn TreeDisplay { self diff --git a/src/daft-local-execution/src/intermediate_ops/project.rs b/src/daft-local-execution/src/intermediate_ops/project.rs index 090116ad71..bc9f7f4eea 100644 --- a/src/daft-local-execution/src/intermediate_ops/project.rs +++ b/src/daft-local-execution/src/intermediate_ops/project.rs @@ -25,11 +25,10 @@ impl IntermediateOperator for ProjectOperator { fn execute( &self, _idx: usize, - input: &PipelineResultType, + input: PipelineResultType, _state: Option<&mut Box>, ) -> DaftResult { - let input = input.as_data(); - let out = input.eval_expression_list(&self.projection)?; + let out = input.data().eval_expression_list(&self.projection)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( out, )))) diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index 6dce8ca44e..ad9668637f 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, sync::Arc}; use crate::{ - channel::{MultiReceiver, OneShotReceiver}, + channel::PipelineChannel, intermediate_ops::{ aggregate::AggregateOperator, filter::FilterOperator, hash_join_probe::HashJoinProbeOperator, intermediate_op::IntermediateNode, @@ -13,7 +13,7 @@ use crate::{ streaming_sink::StreamingSinkNode, }, sources::in_memory::InMemorySource, - ExecutionRuntimeHandle, OneShotRecvSnafu, PipelineCreationSnafu, + ExecutionRuntimeHandle, PipelineCreationSnafu, }; use async_trait::async_trait; @@ -53,14 +53,14 @@ impl From<(Arc, Arc>)> for PipelineResultType { } impl PipelineResultType { - pub fn as_data(&self) -> &Arc { + pub fn data(self) -> Arc { match self { PipelineResultType::Data(data) => data, _ => panic!("Expected data"), } } - pub fn as_probe_table(&self) -> (&Arc, &Arc>) { + pub fn probe_table(self) -> (Arc, Arc>) { match self { PipelineResultType::ProbeTable(probe_table, tables) => (probe_table, 
tables), _ => panic!("Expected probe table"), @@ -72,40 +72,6 @@ impl PipelineResultType { } } -pub enum PipelineResultReceiver { - Multi(MultiReceiver), - OneShot(OneShotReceiver, bool), -} - -impl From for PipelineResultReceiver { - fn from(rx: MultiReceiver) -> Self { - PipelineResultReceiver::Multi(rx) - } -} - -impl From> for PipelineResultReceiver { - fn from(rx: OneShotReceiver) -> Self { - PipelineResultReceiver::OneShot(rx, false) - } -} - -impl PipelineResultReceiver { - pub async fn recv(&mut self) -> Option> { - match self { - PipelineResultReceiver::Multi(rx) => rx.recv().await.map(Ok), - PipelineResultReceiver::OneShot(rx, done) => { - if *done { - None - } else { - let result = rx.await.context(OneShotRecvSnafu); - *done = true; - Some(result) - } - } - } - } -} - #[async_trait] pub trait PipelineNode: Sync + Send + TreeDisplay { fn children(&self) -> Vec<&dyn PipelineNode>; @@ -114,7 +80,7 @@ pub trait PipelineNode: Sync + Send + TreeDisplay { &mut self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result; + ) -> crate::Result; fn as_tree_display(&self) -> &dyn TreeDisplay; } diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs index 7767635f1f..04726f86e5 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -136,9 +136,12 @@ pub fn run_local( .expect("Failed to create tokio runtime"); runtime.block_on(async { let mut runtime_handle = ExecutionRuntimeHandle::new(cfg.default_morsel_size); - let mut receiver = pipeline.start(true, &mut runtime_handle).await?; + let mut receiver = pipeline + .start(true, &mut runtime_handle) + .await? + .get_receiver(); while let Some(val) = receiver.recv().await { - let _ = tx.send(val?.as_data().clone()).await; + let _ = tx.send(val.data()).await; } while let Some(result) = runtime_handle.join_next().await { diff --git a/src/daft-local-execution/src/runtime_stats.rs b/src/daft-local-execution/src/runtime_stats.rs index adc1c8270d..bcd9e4a1f5 100644 --- a/src/daft-local-execution/src/runtime_stats.rs +++ b/src/daft-local-execution/src/runtime_stats.rs @@ -7,7 +7,10 @@ use std::{ use tokio::sync::mpsc::error::SendError; -use crate::{channel::Sender, pipeline::PipelineResultType}; +use crate::{ + channel::{PipelineReceiver, Sender}, + pipeline::PipelineResultType, +}; #[derive(Default)] pub(crate) struct RuntimeStatsContext { @@ -129,3 +132,28 @@ impl CountingSender { Ok(()) } } + +pub(crate) struct CountingReceiver { + receiver: PipelineReceiver, + rt: Arc, +} + +impl CountingReceiver { + pub(crate) fn new(receiver: PipelineReceiver, rt: Arc) -> Self { + Self { receiver, rt } + } + #[inline] + pub(crate) async fn recv(&mut self) -> Option { + let v = self.receiver.recv().await; + if let Some(ref v) = v { + let len = match v { + PipelineResultType::Data(ref mp) => mp.len(), + PipelineResultType::ProbeTable(_, ref tables) => { + tables.iter().map(|t| t.len()).sum() + } + }; + self.rt.mark_rows_received(len as u64); + } + v + } +} diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index 1399791f96..63ea4c2cee 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -6,8 +6,8 @@ use daft_micropartition::MicroPartition; use tracing::info_span; use crate::{ - channel::create_one_shot_channel, - pipeline::{PipelineNode, PipelineResultReceiver, PipelineResultType}, + channel::PipelineChannel, + 
pipeline::{PipelineNode, PipelineResultType}, runtime_stats::RuntimeStatsContext, ExecutionRuntimeHandle, }; @@ -80,25 +80,27 @@ impl PipelineNode for BlockingSinkNode { async fn start( &mut self, - _maintain_order: bool, + maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result { + ) -> crate::Result { let child = self.child.as_mut(); - let mut child_results_receiver = child.start(false, runtime_handle).await?; - let op = self.op.clone(); + let mut child_results_receiver = child + .start(false, runtime_handle) + .await? + .get_receiver_with_stats(&self.runtime_stats); - let (destination_sender, destination_receiver) = create_one_shot_channel(); + let mut destination_channel = PipelineChannel::new(1, maintain_order); + let destination_sender = + destination_channel.get_next_sender_with_stats(&self.runtime_stats); + let op = self.op.clone(); let rt_context = self.runtime_stats.clone(); runtime_handle.spawn( async move { let span = info_span!("BlockingSinkNode::execute"); let mut guard = op.lock().await; while let Some(val) = child_results_receiver.recv().await { - let val = val?; - let val = val.as_data(); - rt_context.mark_rows_received(val.len() as u64); if let BlockingSinkStatus::Finished = - rt_context.in_span(&span, || guard.sink(val))? + rt_context.in_span(&span, || guard.sink(&val.data()))? { break; } @@ -108,20 +110,13 @@ impl PipelineNode for BlockingSinkNode { guard.finalize() })?; if let Some(part) = finalized_result { - let len = match part { - PipelineResultType::Data(ref part) => part.len(), - PipelineResultType::ProbeTable(_, ref tables) => { - tables.iter().map(|t| t.len()).sum() - } - }; - let _ = destination_sender.send(part); - rt_context.mark_rows_emitted(len as u64); + let _ = destination_sender.send(part).await; } Ok(()) }, self.name(), ); - Ok(destination_receiver.into()) + Ok(destination_channel) } fn as_tree_display(&self) -> &dyn TreeDisplay { self diff --git a/src/daft-local-execution/src/sinks/hash_join_build.rs b/src/daft-local-execution/src/sinks/hash_join_build.rs index 54965371cf..0a86756334 100644 --- a/src/daft-local-execution/src/sinks/hash_join_build.rs +++ b/src/daft-local-execution/src/sinks/hash_join_build.rs @@ -102,6 +102,7 @@ impl BlockingSink for HashJoinBuildSink { tables, } = &self.probe_table_state { + println!("HashJoinBuildSink::finalize: table_len: {}", tables.len()); Ok(Some((probe_table.clone(), tables.clone()).into())) } else { panic!("finalize should only be called after the probe table is built") diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index 4ab12e8b4c..afe4be0d3b 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -6,9 +6,7 @@ use daft_micropartition::MicroPartition; use tracing::info_span; use crate::{ - channel::create_multi_channel, - pipeline::{PipelineNode, PipelineResultReceiver}, - runtime_stats::RuntimeStatsContext, + channel::PipelineChannel, pipeline::PipelineNode, runtime_stats::RuntimeStatsContext, ExecutionRuntimeHandle, NUM_CPUS, }; use async_trait::async_trait; @@ -89,14 +87,17 @@ impl PipelineNode for StreamingSinkNode { &mut self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result { + ) -> crate::Result { let child = self .children .get_mut(0) .expect("we should only have 1 child"); - let mut child_results_receiver = child.start(true, runtime_handle).await?; - let (mut destination_sender, 
destination_receiver) = - create_multi_channel(*NUM_CPUS, maintain_order); + let child_results_channel = child.start(true, runtime_handle).await?; + let mut child_results_receiver = + child_results_channel.get_receiver_with_stats(&self.runtime_stats); + + let mut destination_channel = PipelineChannel::new(*NUM_CPUS, maintain_order); + let sender = destination_channel.get_next_sender_with_stats(&self.runtime_stats); let op = self.op.clone(); let runtime_stats = self.runtime_stats.clone(); runtime_handle.spawn( @@ -107,26 +108,21 @@ impl PipelineNode for StreamingSinkNode { let mut sink = op.lock().await; let mut is_active = true; while is_active && let Some(val) = child_results_receiver.recv().await { - let val = val?; - let val = val.as_data(); - runtime_stats.mark_rows_received(val.len() as u64); + let val = val.data(); loop { - let result = runtime_stats.in_span(&span, || sink.execute(0, val))?; + let result = runtime_stats.in_span(&span, || sink.execute(0, &val))?; match result { StreamSinkOutput::HasMoreOutput(mp) => { - let sender = destination_sender.get_next_sender(&runtime_stats); sender.send(mp.into()).await.unwrap(); } StreamSinkOutput::NeedMoreInput(mp) => { if let Some(mp) = mp { - let sender = destination_sender.get_next_sender(&runtime_stats); sender.send(mp.into()).await.unwrap(); } break; } StreamSinkOutput::Finished(mp) => { if let Some(mp) = mp { - let sender = destination_sender.get_next_sender(&runtime_stats); sender.send(mp.into()).await.unwrap(); } is_active = false; @@ -139,7 +135,7 @@ impl PipelineNode for StreamingSinkNode { }, self.name(), ); - Ok(destination_receiver.into()) + Ok(destination_channel) } fn as_tree_display(&self) -> &dyn TreeDisplay { self diff --git a/src/daft-local-execution/src/sources/source.rs b/src/daft-local-execution/src/sources/source.rs index 1b98759e89..be7d6bef67 100644 --- a/src/daft-local-execution/src/sources/source.rs +++ b/src/daft-local-execution/src/sources/source.rs @@ -8,9 +8,7 @@ use futures::{stream::BoxStream, StreamExt}; use async_trait::async_trait; use crate::{ - channel::create_multi_channel, - pipeline::{PipelineNode, PipelineResultReceiver}, - runtime_stats::RuntimeStatsContext, + channel::PipelineChannel, pipeline::PipelineNode, runtime_stats::RuntimeStatsContext, ExecutionRuntimeHandle, }; @@ -76,13 +74,13 @@ impl PipelineNode for SourceNode { &mut self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result { + ) -> crate::Result { let mut source_stream = self.source .get_data(maintain_order, runtime_handle, self.io_stats.clone())?; - let (mut tx, rx) = create_multi_channel(1, maintain_order); - let counting_sender = tx.get_next_sender(&self.runtime_stats); + let mut channel = PipelineChannel::new(1, maintain_order); + let counting_sender = channel.get_next_sender_with_stats(&self.runtime_stats); runtime_handle.spawn( async move { while let Some(part) = source_stream.next().await { @@ -92,7 +90,7 @@ impl PipelineNode for SourceNode { }, self.name(), ); - Ok(rx.into()) + Ok(channel) } fn as_tree_display(&self) -> &dyn TreeDisplay { self
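Note on the PipelineChannel abstraction added in channel.rs: it replaces both the old one-shot channel (previously used by the blocking sink) and the old MultiSender/MultiReceiver pair with a single type that either wraps one mpsc channel (out-of-order) or fans out over per-worker channels that are drained round-robin (in-order). The snippet below is a minimal, self-contained sketch of that round-robin idea only; the type names, the u64 payload, and the lack of exhausted-receiver bookkeeping are simplifications, not the actual Daft implementation, and it assumes a tokio dependency with the macros and rt-multi-thread features enabled.

use tokio::sync::mpsc::{channel, Receiver, Sender};

// Hands out one underlying sender per worker, cycling through them.
struct RoundRobinSender<T> {
    senders: Vec<Sender<T>>,
    next: usize,
}

impl<T> RoundRobinSender<T> {
    fn get_next_sender(&mut self) -> Sender<T> {
        let s = self.senders[self.next].clone();
        self.next = (self.next + 1) % self.senders.len();
        s
    }
}

// Drains the per-worker channels in the same fixed order, so output order
// matches the order in which senders were handed out.
struct RoundRobinReceiver<T> {
    receivers: Vec<Receiver<T>>,
    next: usize,
}

impl<T> RoundRobinReceiver<T> {
    async fn recv(&mut self) -> Option<T> {
        // Simplified: closed channels are re-polled instead of being tracked.
        for _ in 0..self.receivers.len() {
            let idx = self.next;
            self.next = (self.next + 1) % self.receivers.len();
            if let Some(v) = self.receivers[idx].recv().await {
                return Some(v);
            }
        }
        None
    }
}

#[tokio::main]
async fn main() {
    let (senders, receivers): (Vec<_>, Vec<_>) = (0..4).map(|_| channel::<u64>(1)).unzip();
    let mut tx = RoundRobinSender { senders, next: 0 };
    let mut rx = RoundRobinReceiver { receivers, next: 0 };

    // Each "worker" gets its own sender and may finish in arbitrary order.
    for i in 0..4u64 {
        let s = tx.get_next_sender();
        tokio::spawn(async move {
            let _ = s.send(i).await;
        });
    }
    drop(tx); // drop the originals so each channel closes once its worker finishes

    // Prints 0, 1, 2, 3 in order regardless of worker completion order.
    while let Some(v) = rx.recv().await {
        println!("{v}");
    }
}

In the patch, get_next_sender_with_stats additionally wraps the returned sender in a CountingSender so runtime statistics are updated on every send; the next sketch shows that wrapping pattern.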
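Note on the stats plumbing moved into runtime_stats.rs: rather than each operator calling mark_rows_received/mark_rows_emitted itself, the channel endpoints are wrapped (get_next_sender_with_stats / get_receiver_with_stats) so rows are counted at the channel boundary. The sketch below shows that wrapping pattern in isolation; Stats and the Vec<i64> payload are illustrative stand-ins for RuntimeStatsContext and MicroPartition, not the real types, with the same tokio feature assumptions as the previous sketch.

use std::sync::{
    atomic::{AtomicU64, Ordering},
    Arc,
};
use tokio::sync::mpsc::{channel, error::SendError, Receiver, Sender};

#[derive(Default)]
struct Stats {
    rows_received: AtomicU64,
    rows_emitted: AtomicU64,
}

// Wraps a sender and records how many rows are emitted through it.
struct CountingSender {
    sender: Sender<Vec<i64>>, // a Vec of values stands in for a MicroPartition
    stats: Arc<Stats>,
}

impl CountingSender {
    async fn send(&self, batch: Vec<i64>) -> Result<(), SendError<Vec<i64>>> {
        self.stats
            .rows_emitted
            .fetch_add(batch.len() as u64, Ordering::Relaxed);
        self.sender.send(batch).await
    }
}

// Wraps a receiver and records how many rows are received from it.
struct CountingReceiver {
    receiver: Receiver<Vec<i64>>,
    stats: Arc<Stats>,
}

impl CountingReceiver {
    async fn recv(&mut self) -> Option<Vec<i64>> {
        let v = self.receiver.recv().await;
        if let Some(ref batch) = v {
            self.stats
                .rows_received
                .fetch_add(batch.len() as u64, Ordering::Relaxed);
        }
        v
    }
}

#[tokio::main]
async fn main() {
    let stats = Arc::new(Stats::default());
    let (tx, rx) = channel(8);
    let tx = CountingSender { sender: tx, stats: stats.clone() };
    let mut rx = CountingReceiver { receiver: rx, stats: stats.clone() };

    tx.send(vec![1, 2, 3]).await.unwrap();
    drop(tx); // close the channel so recv() returns None after draining
    while rx.recv().await.is_some() {}

    assert_eq!(stats.rows_emitted.load(Ordering::Relaxed), 3);
    assert_eq!(stats.rows_received.load(Ordering::Relaxed), 3);
}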