From 3b751704ba901fa1beca933d733db16194307f2b Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 23 Aug 2024 16:16:55 -0700 Subject: [PATCH 1/6] init --- Cargo.lock | 1 + src/daft-local-execution/Cargo.toml | 1 + src/daft-local-execution/src/channel.rs | 117 ++---- .../src/intermediate_ops/aggregate.rs | 19 +- .../src/intermediate_ops/filter.rs | 19 +- .../src/intermediate_ops/hash_join_probe.rs | 142 +++++++ .../src/intermediate_ops/intermediate_op.rs | 143 ++++--- .../src/intermediate_ops/mod.rs | 1 + .../src/intermediate_ops/project.rs | 19 +- src/daft-local-execution/src/pipeline.rs | 175 +++++++- src/daft-local-execution/src/run.rs | 14 +- src/daft-local-execution/src/runtime_stats.rs | 16 +- .../src/sinks/aggregate.rs | 6 +- .../src/sinks/blocking_sink.rs | 34 +- .../src/sinks/hash_join.rs | 383 ------------------ .../src/sinks/hash_join_build.rs | 129 ++++++ src/daft-local-execution/src/sinks/mod.rs | 2 +- src/daft-local-execution/src/sinks/sort.rs | 6 +- .../src/sinks/streaming_sink.rs | 34 +- .../src/sources/in_memory.rs | 23 +- .../src/sources/scan_task.rs | 135 +++--- .../src/sources/source.rs | 40 +- 22 files changed, 755 insertions(+), 704 deletions(-) create mode 100644 src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs delete mode 100644 src/daft-local-execution/src/sinks/hash_join.rs create mode 100644 src/daft-local-execution/src/sinks/hash_join_build.rs diff --git a/Cargo.lock b/Cargo.lock index 8bf5e58b39..79fd15515c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1913,6 +1913,7 @@ dependencies = [ "pyo3", "snafu", "tokio", + "tokio-stream", "tracing", ] diff --git a/src/daft-local-execution/Cargo.toml b/src/daft-local-execution/Cargo.toml index 07b9bc4682..5c8e208dab 100644 --- a/src/daft-local-execution/Cargo.toml +++ b/src/daft-local-execution/Cargo.toml @@ -24,6 +24,7 @@ num-format = "0.4.4" pyo3 = {workspace = true, optional = true} snafu = {workspace = true} tokio = {workspace = true} +tokio-stream = {workspace = true} tracing = {workspace = true} [features] diff --git a/src/daft-local-execution/src/channel.rs b/src/daft-local-execution/src/channel.rs index 89031c2347..0f8cd621e6 100644 --- a/src/daft-local-execution/src/channel.rs +++ b/src/daft-local-execution/src/channel.rs @@ -1,96 +1,74 @@ -use std::sync::Arc; +pub type OneShotSender = tokio::sync::oneshot::Sender; +pub type OneShotReceiver = tokio::sync::oneshot::Receiver; -use daft_micropartition::MicroPartition; +pub fn create_one_shot_channel() -> (OneShotSender, OneShotReceiver) { + tokio::sync::oneshot::channel() +} -pub type SingleSender = tokio::sync::mpsc::Sender>; -pub type SingleReceiver = tokio::sync::mpsc::Receiver>; +pub type Sender = tokio::sync::mpsc::Sender; +pub type Receiver = tokio::sync::mpsc::Receiver; -pub fn create_single_channel(buffer_size: usize) -> (SingleSender, SingleReceiver) { +pub fn create_channel(buffer_size: usize) -> (Sender, Receiver) { tokio::sync::mpsc::channel(buffer_size) } -pub fn create_channel(buffer_size: usize, in_order: bool) -> (MultiSender, MultiReceiver) { +pub fn create_multi_channel( + buffer_size: usize, + in_order: bool, +) -> (MultiSender, MultiReceiver) { if in_order { - let (senders, receivers) = (0..buffer_size).map(|_| create_single_channel(1)).unzip(); - let sender = MultiSender::InOrder(InOrderSender::new(senders)); - let receiver = MultiReceiver::InOrder(InOrderReceiver::new(receivers)); + let (senders, receivers) = (0..buffer_size).map(|_| create_channel(1)).unzip(); + let sender = MultiSender::InOrder(RoundRobinSender::new(senders)); + let receiver = MultiReceiver::InOrder(RoundRobinReceiver::new(receivers)); (sender, receiver) } else { - let (sender, receiver) = create_single_channel(buffer_size); - let sender = MultiSender::OutOfOrder(OutOfOrderSender::new(sender)); - let receiver = MultiReceiver::OutOfOrder(OutOfOrderReceiver::new(receiver)); + let (sender, receiver) = create_channel(buffer_size); + let sender = MultiSender::OutOfOrder(sender); + let receiver = MultiReceiver::OutOfOrder(receiver); (sender, receiver) } } -pub enum MultiSender { - InOrder(InOrderSender), - OutOfOrder(OutOfOrderSender), +pub enum MultiSender { + InOrder(RoundRobinSender), + OutOfOrder(Sender), } -impl MultiSender { - pub fn get_next_sender(&mut self) -> SingleSender { +impl MultiSender { + pub fn get_next_sender(&mut self) -> Sender { match self { Self::InOrder(sender) => sender.get_next_sender(), - Self::OutOfOrder(sender) => sender.get_sender(), - } - } - - pub fn buffer_size(&self) -> usize { - match self { - Self::InOrder(sender) => sender.senders.len(), - Self::OutOfOrder(sender) => sender.sender.max_capacity(), - } - } - - pub fn in_order(&self) -> bool { - match self { - Self::InOrder(_) => true, - Self::OutOfOrder(_) => false, + Self::OutOfOrder(sender) => sender.clone(), } } } -pub struct InOrderSender { - senders: Vec, +pub struct RoundRobinSender { + senders: Vec>, curr_sender_idx: usize, } -impl InOrderSender { - pub fn new(senders: Vec) -> Self { +impl RoundRobinSender { + pub fn new(senders: Vec>) -> Self { Self { senders, curr_sender_idx: 0, } } - pub fn get_next_sender(&mut self) -> SingleSender { + pub fn get_next_sender(&mut self) -> Sender { let next_idx = self.curr_sender_idx; self.curr_sender_idx = (next_idx + 1) % self.senders.len(); self.senders[next_idx].clone() } } -pub struct OutOfOrderSender { - sender: SingleSender, +pub enum MultiReceiver { + InOrder(RoundRobinReceiver), + OutOfOrder(Receiver), } -impl OutOfOrderSender { - pub fn new(sender: SingleSender) -> Self { - Self { sender } - } - - pub fn get_sender(&self) -> SingleSender { - self.sender.clone() - } -} - -pub enum MultiReceiver { - InOrder(InOrderReceiver), - OutOfOrder(OutOfOrderReceiver), -} - -impl MultiReceiver { - pub async fn recv(&mut self) -> Option> { +impl MultiReceiver { + pub async fn recv(&mut self) -> Option { match self { Self::InOrder(receiver) => receiver.recv().await, Self::OutOfOrder(receiver) => receiver.recv().await, @@ -98,14 +76,14 @@ impl MultiReceiver { } } -pub struct InOrderReceiver { - receivers: Vec, +pub struct RoundRobinReceiver { + receivers: Vec>, curr_receiver_idx: usize, is_done: bool, } -impl InOrderReceiver { - pub fn new(receivers: Vec) -> Self { +impl RoundRobinReceiver { + pub fn new(receivers: Vec>) -> Self { Self { receivers, curr_receiver_idx: 0, @@ -113,7 +91,7 @@ impl InOrderReceiver { } } - pub async fn recv(&mut self) -> Option> { + pub async fn recv(&mut self) -> Option { if self.is_done { return None; } @@ -128,20 +106,3 @@ impl InOrderReceiver { None } } - -pub struct OutOfOrderReceiver { - receiver: SingleReceiver, -} - -impl OutOfOrderReceiver { - pub fn new(receiver: SingleReceiver) -> Self { - Self { receiver } - } - - pub async fn recv(&mut self) -> Option> { - if let Some(val) = self.receiver.recv().await { - return Some(val); - } - None - } -} diff --git a/src/daft-local-execution/src/intermediate_ops/aggregate.rs b/src/daft-local-execution/src/intermediate_ops/aggregate.rs index ec0cbcda7a..69f2c5a6e9 100644 --- a/src/daft-local-execution/src/intermediate_ops/aggregate.rs +++ b/src/daft-local-execution/src/intermediate_ops/aggregate.rs @@ -2,10 +2,13 @@ use std::sync::Arc; use common_error::DaftResult; use daft_dsl::ExprRef; -use daft_micropartition::MicroPartition; use tracing::instrument; -use super::intermediate_op::IntermediateOperator; +use crate::pipeline::PipelineResultType; + +use super::intermediate_op::{ + IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, +}; pub struct AggregateOperator { agg_exprs: Vec, @@ -23,9 +26,17 @@ impl AggregateOperator { impl IntermediateOperator for AggregateOperator { #[instrument(skip_all, name = "AggregateOperator::execute")] - fn execute(&self, input: &Arc) -> DaftResult> { + fn execute( + &self, + _idx: usize, + input: &PipelineResultType, + _state: Option<&mut Box>, + ) -> DaftResult { + let input = input.as_data(); let out = input.agg(&self.agg_exprs, &self.group_by)?; - Ok(Arc::new(out)) + Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( + out, + )))) } fn name(&self) -> &'static str { diff --git a/src/daft-local-execution/src/intermediate_ops/filter.rs b/src/daft-local-execution/src/intermediate_ops/filter.rs index 0461cd7283..e8b1611b87 100644 --- a/src/daft-local-execution/src/intermediate_ops/filter.rs +++ b/src/daft-local-execution/src/intermediate_ops/filter.rs @@ -2,10 +2,13 @@ use std::sync::Arc; use common_error::DaftResult; use daft_dsl::ExprRef; -use daft_micropartition::MicroPartition; use tracing::instrument; -use super::intermediate_op::IntermediateOperator; +use crate::pipeline::PipelineResultType; + +use super::intermediate_op::{ + IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, +}; pub struct FilterOperator { predicate: ExprRef, @@ -19,9 +22,17 @@ impl FilterOperator { impl IntermediateOperator for FilterOperator { #[instrument(skip_all, name = "FilterOperator::execute")] - fn execute(&self, input: &Arc) -> DaftResult> { + fn execute( + &self, + _idx: usize, + input: &PipelineResultType, + _state: Option<&mut Box>, + ) -> DaftResult { + let input = input.as_data(); let out = input.filter(&[self.predicate.clone()])?; - Ok(Arc::new(out)) + Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( + out, + )))) } fn name(&self) -> &'static str { diff --git a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs new file mode 100644 index 0000000000..c2f3ff4a3e --- /dev/null +++ b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs @@ -0,0 +1,142 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_core::JoinType; +use daft_dsl::ExprRef; +use daft_micropartition::MicroPartition; +use daft_table::{GrowableTable, ProbeTable, Table}; +use tracing::{info_span, instrument}; + +use crate::pipeline::PipelineResultType; + +use super::intermediate_op::{ + IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, +}; + +enum HashJoinProbeState { + Building, + ReadyToProbe(Arc, Arc>), +} + +impl HashJoinProbeState { + fn probe( + &self, + input: &Arc, + right_on: &[ExprRef], + pruned_right_side_columns: &[String], + ) -> DaftResult> { + if let HashJoinProbeState::ReadyToProbe(probe_table, tables) = self { + let _growables = info_span!("HashJoinOperator::build_growables").entered(); + + // Left should only be created once per probe table + let mut left_growable = + GrowableTable::new(&tables.iter().collect::>(), false, 20)?; + // right should only be created morsel + + let right_input_tables = input.get_tables()?; + + let mut right_growable = + GrowableTable::new(&right_input_tables.iter().collect::>(), false, 20)?; + + drop(_growables); + { + let _loop = info_span!("HashJoinOperator::eval_and_probe").entered(); + for (r_table_idx, table) in right_input_tables.iter().enumerate() { + // we should emit one table at a time when this is streaming + let join_keys = table.eval_expression_list(right_on)?; + let iter = probe_table.probe(&join_keys)?; + + for (l_table_idx, l_row_idx, right_idx) in iter { + left_growable.extend(l_table_idx as usize, l_row_idx as usize, 1); + // we can perform run length compression for this to make this more efficient + right_growable.extend(r_table_idx, right_idx as usize, 1); + } + } + } + let left_table = left_growable.build()?; + let right_table = right_growable.build()?; + + let pruned_right_table = right_table.get_columns(pruned_right_side_columns)?; + + let final_table = left_table.union(&pruned_right_table)?; + Ok(Arc::new(MicroPartition::new_loaded( + final_table.schema.clone(), + Arc::new(vec![final_table]), + None, + ))) + } else { + panic!("probe can only be used during the ReadyToProbe Phase") + } + } +} + +impl IntermediateOperatorState for HashJoinProbeState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} + +pub struct HashJoinProbeOperator { + right_on: Vec, + pruned_right_side_columns: Vec, + _join_type: JoinType, +} + +impl HashJoinProbeOperator { + pub fn new( + right_on: Vec, + pruned_right_side_columns: Vec, + join_type: JoinType, + ) -> Self { + Self { + right_on, + pruned_right_side_columns, + _join_type: join_type, + } + } +} + +impl IntermediateOperator for HashJoinProbeOperator { + #[instrument(skip_all, name = "HashJoinOperator::execute")] + fn execute( + &self, + idx: usize, + input: &PipelineResultType, + state: Option<&mut Box>, + ) -> DaftResult { + match idx { + 0 => { + let (probe_table, tables) = input.as_probe_table(); + let state = state + .expect("HashJoinProbeOperator should have state") + .as_any_mut() + .downcast_mut::() + .expect("HashJoinProbeOperator state should be HashJoinProbeState"); + if let HashJoinProbeState::Building = state { + *state = HashJoinProbeState::ReadyToProbe(probe_table.clone(), tables.clone()); + } else { + panic!("HashJoinProbeOperator should only be in Building state on first input"); + } + Ok(IntermediateOperatorResult::NeedMoreInput(None)) + } + _ => { + let state = state + .expect("HashJoinProbeOperator should have state") + .as_any_mut() + .downcast_mut::() + .expect("HashJoinProbeOperator state should be HashJoinProbeState"); + let input = input.as_data(); + let out = state.probe(input, &self.right_on, &self.pruned_right_side_columns)?; + Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) + } + } + } + + fn name(&self) -> &'static str { + "HashJoinProbeOperator" + } + + fn make_state(&self) -> Option> { + Some(Box::new(HashJoinProbeState::Building)) + } +} diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index ecb7182fab..283b7badb8 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -8,20 +8,35 @@ use tracing::{info_span, instrument}; use async_trait::async_trait; use crate::{ - channel::{ - create_channel, create_single_channel, MultiReceiver, MultiSender, SingleReceiver, - SingleSender, - }, - pipeline::PipelineNode, + channel::{create_channel, create_multi_channel, MultiSender, Receiver, Sender}, + pipeline::{PipelineNode, PipelineResultReceiver, PipelineResultType}, runtime_stats::{CountingSender, RuntimeStatsContext}, ExecutionRuntimeHandle, NUM_CPUS, }; use super::buffer::OperatorBuffer; +pub trait IntermediateOperatorState: Send + Sync { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any; +} + +pub enum IntermediateOperatorResult { + NeedMoreInput(Option>), + #[allow(dead_code)] + HasMoreOutput(Arc), +} + pub trait IntermediateOperator: Send + Sync { - fn execute(&self, input: &Arc) -> DaftResult>; + fn execute( + &self, + idx: usize, + input: &PipelineResultType, + state: Option<&mut Box>, + ) -> DaftResult; fn name(&self) -> &'static str; + fn make_state(&self) -> Option> { + None + } } pub(crate) struct IntermediateNode { @@ -58,29 +73,44 @@ impl IntermediateNode { #[instrument(level = "info", skip_all, name = "IntermediateOperator::run_worker")] pub async fn run_worker( op: Arc, - mut receiver: SingleReceiver, - sender: SingleSender, + mut receiver: Receiver<(usize, PipelineResultType)>, + sender: Sender, rt_context: Arc, ) -> DaftResult<()> { let span = info_span!("IntermediateOp::execute"); + let mut state = op.make_state(); let sender = CountingSender::new(sender, rt_context.clone()); - while let Some(morsel) = receiver.recv().await { - rt_context.mark_rows_received(morsel.len() as u64); - let result = rt_context.in_span(&span, || op.execute(&morsel))?; - let _ = sender.send(result).await; + while let Some((idx, morsel)) = receiver.recv().await { + let len = match morsel { + PipelineResultType::Data(ref data) => data.len(), + PipelineResultType::ProbeTable(_, ref tables) => { + tables.iter().map(|t| t.len()).sum() + } + }; + rt_context.mark_rows_received(len as u64); + let result = rt_context.in_span(&span, || op.execute(idx, &morsel, state.as_mut()))?; + match result { + IntermediateOperatorResult::NeedMoreInput(Some(mp)) => { + let _ = sender.send(mp.into()).await; + } + IntermediateOperatorResult::NeedMoreInput(None) => {} + IntermediateOperatorResult::HasMoreOutput(mp) => { + let _ = sender.send(mp.into()).await; + } + } } Ok(()) } pub async fn spawn_workers( &self, - destination: &mut MultiSender, + num_workers: usize, + destination: &mut MultiSender, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> Vec { - let num_senders = destination.buffer_size(); - let mut worker_senders = Vec::with_capacity(num_senders); - for _ in 0..num_senders { - let (worker_sender, worker_receiver) = create_single_channel(1); + ) -> Vec> { + let mut worker_senders = Vec::with_capacity(num_workers); + for _ in 0..num_workers { + let (worker_sender, worker_receiver) = create_channel(1); let destination_sender = destination.get_next_sender(); runtime_handle.spawn( Self::run_worker( @@ -97,33 +127,40 @@ impl IntermediateNode { } pub async fn send_to_workers( - mut receiver: MultiReceiver, - worker_senders: Vec, + receivers: Vec, + worker_senders: Vec>, morsel_size: usize, ) -> DaftResult<()> { let mut next_worker_idx = 0; - let mut send_to_next_worker = |morsel: Arc| { + let mut send_to_next_worker = |idx, data: PipelineResultType| { let next_worker_sender = worker_senders.get(next_worker_idx).unwrap(); next_worker_idx = (next_worker_idx + 1) % worker_senders.len(); - next_worker_sender.send(morsel) + next_worker_sender.send((idx, data)) }; - let mut buffer = OperatorBuffer::new(morsel_size); - while let Some(morsel) = receiver.recv().await { - buffer.push(morsel); - if let Some(ready) = buffer.try_clear() { - let _ = send_to_next_worker(ready?).await; + for (idx, mut receiver) in receivers.into_iter().enumerate() { + let mut buffer = OperatorBuffer::new(morsel_size); + while let Some(morsel) = receiver.recv().await { + let morsel = morsel?; + if morsel.should_broadcast() { + for worker_sender in worker_senders.iter() { + let _ = worker_sender.send((idx, morsel.clone())).await; + } + } else { + buffer.push(morsel.as_data().clone()); + if let Some(ready) = buffer.try_clear() { + let _ = send_to_next_worker(idx, ready?.into()).await; + } + } + } + // Buffer may still have some morsels left above the threshold + while let Some(ready) = buffer.try_clear() { + let _ = send_to_next_worker(idx, ready?.into()).await; + } + // Clear all remaining morsels + if let Some(last_morsel) = buffer.clear_all() { + let _ = send_to_next_worker(idx, last_morsel?.into()).await; } - } - - // Buffer may still have some morsels left above the threshold - while let Some(ready) = buffer.try_clear() { - let _ = send_to_next_worker(ready?).await; - } - - // Clear all remaining morsels - if let Some(last_morsel) = buffer.clear_all() { - let _ = send_to_next_worker(last_morsel?).await; } Ok(()) } @@ -162,31 +199,29 @@ impl PipelineNode for IntermediateNode { async fn start( &mut self, - mut destination: MultiSender, + maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result<()> { - assert_eq!( - self.children.len(), - 1, - "we only support 1 child for Intermediate Node for now" - ); - let (sender, receiver) = create_channel(*NUM_CPUS, destination.in_order()); - let child = self - .children - .get_mut(0) - .expect("we should only have 1 child"); - child.start(sender, runtime_handle).await?; - - let worker_senders = self.spawn_workers(&mut destination, runtime_handle).await; + ) -> crate::Result { + let mut child_result_receivers = Vec::with_capacity(self.children.len()); + for child in self.children.iter_mut() { + let child_result_receiver = child.start(maintain_order, runtime_handle).await?; + child_result_receivers.push(child_result_receiver); + } + let (mut destination_sender, destination_receiver) = + create_multi_channel(*NUM_CPUS, maintain_order); + + let worker_senders = self + .spawn_workers(*NUM_CPUS, &mut destination_sender, runtime_handle) + .await; runtime_handle.spawn( Self::send_to_workers( - receiver, + child_result_receivers, worker_senders, runtime_handle.default_morsel_size(), ), self.intermediate_op.name(), ); - Ok(()) + Ok(destination_receiver.into()) } fn as_tree_display(&self) -> &dyn TreeDisplay { self diff --git a/src/daft-local-execution/src/intermediate_ops/mod.rs b/src/daft-local-execution/src/intermediate_ops/mod.rs index 290f4bf184..781e0f5db9 100644 --- a/src/daft-local-execution/src/intermediate_ops/mod.rs +++ b/src/daft-local-execution/src/intermediate_ops/mod.rs @@ -1,5 +1,6 @@ pub mod aggregate; pub mod buffer; pub mod filter; +pub mod hash_join_probe; pub mod intermediate_op; pub mod project; diff --git a/src/daft-local-execution/src/intermediate_ops/project.rs b/src/daft-local-execution/src/intermediate_ops/project.rs index a38bbcd2cc..090116ad71 100644 --- a/src/daft-local-execution/src/intermediate_ops/project.rs +++ b/src/daft-local-execution/src/intermediate_ops/project.rs @@ -2,10 +2,13 @@ use std::sync::Arc; use common_error::DaftResult; use daft_dsl::ExprRef; -use daft_micropartition::MicroPartition; use tracing::instrument; -use super::intermediate_op::IntermediateOperator; +use crate::pipeline::PipelineResultType; + +use super::intermediate_op::{ + IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, +}; pub struct ProjectOperator { projection: Vec, @@ -19,9 +22,17 @@ impl ProjectOperator { impl IntermediateOperator for ProjectOperator { #[instrument(skip_all, name = "ProjectOperator::execute")] - fn execute(&self, input: &Arc) -> DaftResult> { + fn execute( + &self, + _idx: usize, + input: &PipelineResultType, + _state: Option<&mut Box>, + ) -> DaftResult { + let input = input.as_data(); let out = input.eval_expression_list(&self.projection)?; - Ok(Arc::new(out)) + Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( + out, + )))) } fn name(&self) -> &'static str { diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index 723bcc18ed..f87e7f8fea 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -1,24 +1,29 @@ use std::{collections::HashMap, sync::Arc}; use crate::{ + channel::{MultiReceiver, OneShotReceiver, Receiver}, intermediate_ops::{ - aggregate::AggregateOperator, filter::FilterOperator, intermediate_op::IntermediateNode, + aggregate::AggregateOperator, filter::FilterOperator, + hash_join_probe::HashJoinProbeOperator, intermediate_op::IntermediateNode, project::ProjectOperator, }, sinks::{ - aggregate::AggregateSink, - blocking_sink::BlockingSinkNode, - hash_join::{HashJoinNode, HashJoinOperator}, - limit::LimitSink, - sort::SortSink, + aggregate::AggregateSink, blocking_sink::BlockingSinkNode, + hash_join_build::HashJoinBuildSink, limit::LimitSink, sort::SortSink, streaming_sink::StreamingSinkNode, }, sources::in_memory::InMemorySource, - ExecutionRuntimeHandle, PipelineCreationSnafu, + ExecutionRuntimeHandle, OneShotRecvSnafu, PipelineCreationSnafu, }; use async_trait::async_trait; use common_display::{mermaid::MermaidDisplayVisitor, tree::TreeDisplay}; +use common_error::DaftResult; +use daft_core::{ + datatypes::Field, + schema::{Schema, SchemaRef}, + utils::supertype, +}; use daft_dsl::Expr; use daft_micropartition::MicroPartition; use daft_physical_plan::{ @@ -26,9 +31,88 @@ use daft_physical_plan::{ UnGroupedAggregate, }; use daft_plan::populate_aggregation_stages; +use daft_table::{ProbeTable, Table}; use snafu::ResultExt; -use crate::channel::MultiSender; +#[derive(Clone)] +pub enum PipelineResultType { + Data(Arc), + ProbeTable(Arc, Arc>), +} + +impl From> for PipelineResultType { + fn from(data: Arc) -> Self { + PipelineResultType::Data(data) + } +} + +impl From<(Arc, Arc>)> for PipelineResultType { + fn from((probe_table, tables): (Arc, Arc>)) -> Self { + PipelineResultType::ProbeTable(probe_table, tables) + } +} + +impl PipelineResultType { + pub fn as_data(&self) -> &Arc { + match self { + PipelineResultType::Data(data) => data, + _ => panic!("Expected data"), + } + } + + pub fn as_probe_table(&self) -> (&Arc, &Arc>) { + match self { + PipelineResultType::ProbeTable(probe_table, tables) => (probe_table, tables), + _ => panic!("Expected probe table"), + } + } + + pub fn should_broadcast(&self) -> bool { + matches!(self, PipelineResultType::ProbeTable(_, _)) + } +} + +pub enum PipelineResultReceiver { + Single(Receiver), + Multi(MultiReceiver), + OneShot(OneShotReceiver, bool), +} + +impl From> for PipelineResultReceiver { + fn from(rx: Receiver) -> Self { + PipelineResultReceiver::Multi(MultiReceiver::OutOfOrder(rx)) + } +} + +impl From> for PipelineResultReceiver { + fn from(rx: MultiReceiver) -> Self { + PipelineResultReceiver::Multi(rx) + } +} + +impl From> for PipelineResultReceiver { + fn from(rx: OneShotReceiver) -> Self { + PipelineResultReceiver::OneShot(rx, false) + } +} + +impl PipelineResultReceiver { + pub async fn recv(&mut self) -> Option> { + match self { + PipelineResultReceiver::Single(rx) => rx.recv().await.map(Ok), + PipelineResultReceiver::Multi(rx) => rx.recv().await.map(Ok), + PipelineResultReceiver::OneShot(rx, done) => { + if *done { + None + } else { + let result = rx.await.context(OneShotRecvSnafu); + *done = true; + Some(result) + } + } + } + } +} #[async_trait] pub trait PipelineNode: Sync + Send + TreeDisplay { @@ -36,9 +120,9 @@ pub trait PipelineNode: Sync + Send + TreeDisplay { fn name(&self) -> &'static str; async fn start( &mut self, - destination: MultiSender, + maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result<()>; + ) -> crate::Result; fn as_tree_display(&self) -> &dyn TreeDisplay; } @@ -192,18 +276,71 @@ pub fn physical_plan_to_pipeline( let left_node = physical_plan_to_pipeline(left, psets)?; let right_node = physical_plan_to_pipeline(right, psets)?; - // we should move to a builder pattern - let sink = HashJoinOperator::new( - left_on.clone(), - right_on.clone(), - *join_type, - left_schema, - right_schema, - ) + let probe_node = || -> DaftResult<_> { + let left_key_fields = left_on + .iter() + .map(|e| e.to_field(left_schema)) + .collect::>>()?; + let right_key_fields = right_on + .iter() + .map(|e| e.to_field(right_schema)) + .collect::>>()?; + let key_schema: SchemaRef = Schema::new( + left_key_fields + .into_iter() + .zip(right_key_fields.into_iter()) + .map(|(l, r)| { + // TODO we should be using the comparison_op function here instead but i'm just using existing behavior for now + let dtype = supertype::try_get_supertype(&l.dtype, &r.dtype)?; + Ok(Field::new(l.name, dtype)) + }) + .collect::>>()?, + )? + .into(); + + let left_on = left_on + .iter() + .zip(key_schema.fields.values()) + .map(|(e, f)| e.clone().cast(&f.dtype)) + .collect::>(); + let right_on = right_on + .iter() + .zip(key_schema.fields.values()) + .map(|(e, f)| e.clone().cast(&f.dtype)) + .collect::>(); + let common_join_keys = left_on + .iter() + .zip(right_on.iter()) + .filter_map(|(l, r)| { + if l.name() == r.name() { + Some(l.name()) + } else { + None + } + }) + .collect::>(); + let pruned_right_side_columns = right_schema + .fields + .keys() + .filter(|k| !common_join_keys.contains(k.as_str())) + .cloned() + .collect::>(); + + // we should move to a builder pattern + let build_sink = HashJoinBuildSink::new(key_schema.clone(), left_on)?; + let build_node = BlockingSinkNode::new(build_sink.boxed(), left_node).boxed(); + + let probe_op = + HashJoinProbeOperator::new(right_on, pruned_right_side_columns, *join_type); + DaftResult::Ok(IntermediateNode::new( + Arc::new(probe_op), + vec![build_node, right_node], + )) + }() .with_context(|_| PipelineCreationSnafu { plan_name: physical_plan.name(), })?; - HashJoinNode::new(sink, left_node, right_node).boxed() + probe_node.boxed() } _ => { unimplemented!("Physical plan not supported: {}", physical_plan.name()); diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs index 1e8a03ba24..7767635f1f 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -23,9 +23,9 @@ use { }; use crate::{ - channel::{create_channel, create_single_channel, SingleReceiver}, + channel::{create_channel, Receiver}, pipeline::{physical_plan_to_pipeline, viz_pipeline}, - Error, ExecutionRuntimeHandle, NUM_CPUS, + Error, ExecutionRuntimeHandle, }; #[cfg(feature = "python")] @@ -122,7 +122,7 @@ pub fn run_local( ) -> DaftResult>> + Send>> { refresh_chrome_trace(); let mut pipeline = physical_plan_to_pipeline(physical_plan, &psets)?; - let (tx, rx) = create_single_channel(results_buffer_size.unwrap_or(1)); + let (tx, rx) = create_channel(results_buffer_size.unwrap_or(1)); let handle = std::thread::spawn(move || { let runtime = tokio::runtime::Builder::new_multi_thread() .enable_all() @@ -135,12 +135,10 @@ pub fn run_local( .build() .expect("Failed to create tokio runtime"); runtime.block_on(async { - let (sender, mut receiver) = create_channel(*NUM_CPUS, true); - let mut runtime_handle = ExecutionRuntimeHandle::new(cfg.default_morsel_size); - pipeline.start(sender, &mut runtime_handle).await?; + let mut receiver = pipeline.start(true, &mut runtime_handle).await?; while let Some(val) = receiver.recv().await { - let _ = tx.send(val).await; + let _ = tx.send(val?.as_data().clone()).await; } while let Some(result) = runtime_handle.join_next().await { @@ -170,7 +168,7 @@ pub fn run_local( }); struct ReceiverIterator { - receiver: SingleReceiver, + receiver: Receiver>, handle: Option>>, } diff --git a/src/daft-local-execution/src/runtime_stats.rs b/src/daft-local-execution/src/runtime_stats.rs index e3e80886d2..adc1c8270d 100644 --- a/src/daft-local-execution/src/runtime_stats.rs +++ b/src/daft-local-execution/src/runtime_stats.rs @@ -5,10 +5,9 @@ use std::{ time::Instant, }; -use daft_micropartition::MicroPartition; use tokio::sync::mpsc::error::SendError; -use crate::channel::SingleSender; +use crate::{channel::Sender, pipeline::PipelineResultType}; #[derive(Default)] pub(crate) struct RuntimeStatsContext { @@ -108,20 +107,23 @@ impl RuntimeStatsContext { } pub(crate) struct CountingSender { - sender: SingleSender, + sender: Sender, rt: Arc, } impl CountingSender { - pub(crate) fn new(sender: SingleSender, rt: Arc) -> Self { + pub(crate) fn new(sender: Sender, rt: Arc) -> Self { Self { sender, rt } } #[inline] pub(crate) async fn send( &self, - v: Arc, - ) -> Result<(), SendError>> { - let len = v.len(); + v: PipelineResultType, + ) -> Result<(), SendError> { + let len = match v { + PipelineResultType::Data(ref mp) => mp.len(), + PipelineResultType::ProbeTable(_, ref tables) => tables.iter().map(|t| t.len()).sum(), + }; self.sender.send(v).await?; self.rt.mark_rows_emitted(len as u64); Ok(()) diff --git a/src/daft-local-execution/src/sinks/aggregate.rs b/src/daft-local-execution/src/sinks/aggregate.rs index 0ed1862674..33163758ef 100644 --- a/src/daft-local-execution/src/sinks/aggregate.rs +++ b/src/daft-local-execution/src/sinks/aggregate.rs @@ -5,6 +5,8 @@ use daft_dsl::ExprRef; use daft_micropartition::MicroPartition; use tracing::instrument; +use crate::pipeline::PipelineResultType; + use super::blocking_sink::{BlockingSink, BlockingSinkStatus}; enum AggregateState { @@ -45,7 +47,7 @@ impl BlockingSink for AggregateSink { } #[instrument(skip_all, name = "AggregateSink::finalize")] - fn finalize(&mut self) -> DaftResult>> { + fn finalize(&mut self) -> DaftResult> { if let AggregateState::Accumulating(parts) = &mut self.state { assert!( !parts.is_empty(), @@ -55,7 +57,7 @@ impl BlockingSink for AggregateSink { MicroPartition::concat(&parts.iter().map(|x| x.as_ref()).collect::>())?; let agged = Arc::new(concated.agg(&self.agg_exprs, &self.group_by)?); self.state = AggregateState::Done(agged.clone()); - Ok(Some(agged)) + Ok(Some(agged.into())) } else { panic!("AggregateSink should be in Accumulating state"); } diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index 17320734ec..1399791f96 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -6,10 +6,10 @@ use daft_micropartition::MicroPartition; use tracing::info_span; use crate::{ - channel::{create_channel, MultiSender}, - pipeline::PipelineNode, + channel::create_one_shot_channel, + pipeline::{PipelineNode, PipelineResultReceiver, PipelineResultType}, runtime_stats::RuntimeStatsContext, - ExecutionRuntimeHandle, NUM_CPUS, + ExecutionRuntimeHandle, }; use async_trait::async_trait; pub enum BlockingSinkStatus { @@ -20,7 +20,7 @@ pub enum BlockingSinkStatus { pub trait BlockingSink: Send + Sync { fn sink(&mut self, input: &Arc) -> DaftResult; - fn finalize(&mut self) -> DaftResult>>; + fn finalize(&mut self) -> DaftResult>; fn name(&self) -> &'static str; } @@ -80,24 +80,25 @@ impl PipelineNode for BlockingSinkNode { async fn start( &mut self, - mut destination: MultiSender, + _maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result<()> { - let (sender, mut streaming_receiver) = create_channel(*NUM_CPUS, true); - // now we can start building the right side + ) -> crate::Result { let child = self.child.as_mut(); - child.start(sender, runtime_handle).await?; + let mut child_results_receiver = child.start(false, runtime_handle).await?; let op = self.op.clone(); + let (destination_sender, destination_receiver) = create_one_shot_channel(); let rt_context = self.runtime_stats.clone(); runtime_handle.spawn( async move { let span = info_span!("BlockingSinkNode::execute"); let mut guard = op.lock().await; - while let Some(val) = streaming_receiver.recv().await { + while let Some(val) = child_results_receiver.recv().await { + let val = val?; + let val = val.as_data(); rt_context.mark_rows_received(val.len() as u64); if let BlockingSinkStatus::Finished = - rt_context.in_span(&span, || guard.sink(&val))? + rt_context.in_span(&span, || guard.sink(val))? { break; } @@ -107,15 +108,20 @@ impl PipelineNode for BlockingSinkNode { guard.finalize() })?; if let Some(part) = finalized_result { - let len = part.len(); - let _ = destination.get_next_sender().send(part).await; + let len = match part { + PipelineResultType::Data(ref part) => part.len(), + PipelineResultType::ProbeTable(_, ref tables) => { + tables.iter().map(|t| t.len()).sum() + } + }; + let _ = destination_sender.send(part); rt_context.mark_rows_emitted(len as u64); } Ok(()) }, self.name(), ); - Ok(()) + Ok(destination_receiver.into()) } fn as_tree_display(&self) -> &dyn TreeDisplay { self diff --git a/src/daft-local-execution/src/sinks/hash_join.rs b/src/daft-local-execution/src/sinks/hash_join.rs deleted file mode 100644 index eabe56024e..0000000000 --- a/src/daft-local-execution/src/sinks/hash_join.rs +++ /dev/null @@ -1,383 +0,0 @@ -use std::sync::Arc; - -use crate::{ - channel::{create_channel, MultiSender}, - intermediate_ops::intermediate_op::{IntermediateNode, IntermediateOperator}, - pipeline::PipelineNode, - runtime_stats::RuntimeStatsContext, - ExecutionRuntimeHandle, JoinSnafu, PipelineExecutionSnafu, NUM_CPUS, -}; -use async_trait::async_trait; -use common_display::tree::TreeDisplay; -use common_error::DaftResult; -use daft_core::{ - datatypes::Field, - schema::{Schema, SchemaRef}, - utils::supertype, -}; -use daft_dsl::ExprRef; -use daft_micropartition::MicroPartition; -use daft_plan::JoinType; -use snafu::{futures::TryFutureExt, ResultExt}; -use tracing::info_span; - -use super::blocking_sink::{BlockingSink, BlockingSinkStatus}; -use daft_table::{GrowableTable, ProbeTable, ProbeTableBuilder, Table}; - -enum HashJoinState { - Building { - probe_table_builder: Option, - projection: Vec, - tables: Vec, - }, - Probing { - probe_table: Arc, - tables: Arc>, - }, -} - -impl HashJoinState { - fn new(key_schema: &SchemaRef, projection: Vec) -> DaftResult { - Ok(Self::Building { - probe_table_builder: Some(ProbeTableBuilder::new(key_schema.clone())?), - projection, - tables: vec![], - }) - } - - fn add_tables(&mut self, input: &Arc) -> DaftResult<()> { - if let Self::Building { - ref mut probe_table_builder, - projection, - tables, - } = self - { - let probe_table_builder = probe_table_builder.as_mut().unwrap(); - for table in input.get_tables()?.iter() { - tables.push(table.clone()); - let join_keys = table.eval_expression_list(projection)?; - - probe_table_builder.add_table(&join_keys)?; - } - Ok(()) - } else { - panic!("add_tables can only be used during the Building Phase") - } - } - fn finalize(&mut self) -> DaftResult<()> { - if let Self::Building { - probe_table_builder, - tables, - .. - } = self - { - let ptb = std::mem::take(probe_table_builder).expect("should be set in building mode"); - let pt = ptb.build(); - - *self = Self::Probing { - probe_table: Arc::new(pt), - tables: Arc::new(tables.clone()), - }; - Ok(()) - } else { - panic!("finalize can only be used during the Building Phase") - } - } -} - -pub(crate) struct HashJoinOperator { - right_on: Vec, - pruned_right_side_columns: Vec, - _join_type: JoinType, - join_state: HashJoinState, -} - -impl HashJoinOperator { - pub(crate) fn new( - left_on: Vec, - right_on: Vec, - join_type: JoinType, - left_schema: &SchemaRef, - right_schema: &SchemaRef, - ) -> DaftResult { - let left_key_fields = left_on - .iter() - .map(|e| e.to_field(left_schema)) - .collect::>>()?; - let right_key_fields = right_on - .iter() - .map(|e| e.to_field(right_schema)) - .collect::>>()?; - let key_schema: SchemaRef = Schema::new( - left_key_fields - .into_iter() - .zip(right_key_fields.into_iter()) - .map(|(l, r)| { - // TODO we should be using the comparison_op function here instead but i'm just using existing behavior for now - let dtype = supertype::try_get_supertype(&l.dtype, &r.dtype)?; - Ok(Field::new(l.name, dtype)) - }) - .collect::>>()?, - )? - .into(); - - let left_on = left_on - .into_iter() - .zip(key_schema.fields.values()) - .map(|(e, f)| e.cast(&f.dtype)) - .collect::>(); - let right_on = right_on - .into_iter() - .zip(key_schema.fields.values()) - .map(|(e, f)| e.cast(&f.dtype)) - .collect::>(); - let common_join_keys = left_on - .iter() - .zip(right_on.iter()) - .filter_map(|(l, r)| { - if l.name() == r.name() { - Some(l.name()) - } else { - None - } - }) - .collect::>(); - let pruned_right_side_columns = right_schema - .fields - .keys() - .filter(|k| !common_join_keys.contains(k.as_str())) - .cloned() - .collect::>(); - assert_eq!(join_type, JoinType::Inner); - Ok(Self { - right_on, - pruned_right_side_columns, - _join_type: join_type, - join_state: HashJoinState::new(&key_schema, left_on)?, - }) - } - - fn as_sink(&mut self) -> &mut dyn BlockingSink { - self - } - - fn as_intermediate_op(&self) -> Arc { - if let HashJoinState::Probing { - probe_table, - tables, - } = &self.join_state - { - Arc::new(HashJoinProber { - probe_table: probe_table.clone(), - tables: tables.clone(), - right_on: self.right_on.clone(), - pruned_right_side_columns: self.pruned_right_side_columns.clone(), - }) - } else { - panic!("can't call as_intermediate_op when not in probing state") - } - } -} - -struct HashJoinProber { - probe_table: Arc, - tables: Arc>, - right_on: Vec, - pruned_right_side_columns: Vec, -} - -impl IntermediateOperator for HashJoinProber { - fn name(&self) -> &'static str { - "HashJoinProber" - } - fn execute(&self, input: &Arc) -> DaftResult> { - let _span = info_span!("HashJoinOperator::execute").entered(); - let _growables = info_span!("HashJoinOperator::build_growables").entered(); - - // Left should only be created once per probe table - let mut left_growable = - GrowableTable::new(&self.tables.iter().collect::>(), false, 20)?; - // right should only be created morsel - - let right_input_tables = input.get_tables()?; - - let mut right_growable = - GrowableTable::new(&right_input_tables.iter().collect::>(), false, 20)?; - - drop(_growables); - { - let _loop = info_span!("HashJoinOperator::eval_and_probe").entered(); - for (r_table_idx, table) in right_input_tables.iter().enumerate() { - // we should emit one table at a time when this is streaming - let join_keys = table.eval_expression_list(&self.right_on)?; - let iter = self.probe_table.probe(&join_keys)?; - - for (l_table_idx, l_row_idx, right_idx) in iter { - left_growable.extend(l_table_idx as usize, l_row_idx as usize, 1); - // we can perform run length compression for this to make this more efficient - right_growable.extend(r_table_idx, right_idx as usize, 1); - } - } - } - let left_table = left_growable.build()?; - let right_table = right_growable.build()?; - - let pruned_right_table = right_table.get_columns(&self.pruned_right_side_columns)?; - - let final_table = left_table.union(&pruned_right_table)?; - Ok(Arc::new(MicroPartition::new_loaded( - final_table.schema.clone(), - Arc::new(vec![final_table]), - None, - ))) - } -} - -impl BlockingSink for HashJoinOperator { - fn name(&self) -> &'static str { - "HashJoin" - } - - fn sink(&mut self, input: &Arc) -> DaftResult { - self.join_state.add_tables(input)?; - Ok(BlockingSinkStatus::NeedMoreInput) - } - fn finalize(&mut self) -> DaftResult>> { - self.join_state.finalize()?; - Ok(None) - } -} - -pub(crate) struct HashJoinNode { - // use a RW lock - hash_join: Arc>, - left: Box, - right: Box, - build_runtime_stats: Arc, - probe_runtime_stats: Arc, -} - -impl HashJoinNode { - pub(crate) fn new( - op: HashJoinOperator, - left: Box, - right: Box, - ) -> Self { - HashJoinNode { - hash_join: Arc::new(tokio::sync::Mutex::new(op)), - left, - right, - build_runtime_stats: RuntimeStatsContext::new(), - probe_runtime_stats: RuntimeStatsContext::new(), - } - } - pub(crate) fn boxed(self) -> Box { - Box::new(self) - } -} - -impl TreeDisplay for HashJoinNode { - fn display_as(&self, level: common_display::DisplayLevel) -> String { - use std::fmt::Write; - let mut display = String::new(); - writeln!(display, "{}", self.name()).unwrap(); - use common_display::DisplayLevel::*; - match level { - Compact => {} - _ => { - let build_rt_result = self.build_runtime_stats.result(); - writeln!(display, "Probe Table Build:").unwrap(); - - build_rt_result - .display(&mut display, true, false, true) - .unwrap(); - - let probe_rt_result = self.probe_runtime_stats.result(); - writeln!(display, "\nProbe Phase:").unwrap(); - probe_rt_result - .display(&mut display, true, true, true) - .unwrap(); - } - } - display - } - fn get_children(&self) -> Vec<&dyn TreeDisplay> { - vec![self.left.as_tree_display(), self.right.as_tree_display()] - } -} - -#[async_trait] -impl PipelineNode for HashJoinNode { - fn children(&self) -> Vec<&dyn PipelineNode> { - vec![self.left.as_ref(), self.right.as_ref()] - } - - fn name(&self) -> &'static str { - "HashJoin" - } - - async fn start( - &mut self, - mut destination: MultiSender, - runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result<()> { - let (sender, mut pt_receiver) = create_channel(*NUM_CPUS, false); - self.left.start(sender, runtime_handle).await?; - let hash_join = self.hash_join.clone(); - let build_runtime_stats = self.build_runtime_stats.clone(); - let name = self.name(); - let probe_table_build = tokio::spawn( - async move { - let span = info_span!("ProbeTable::sink"); - let mut guard = hash_join.lock().await; - let sink = guard.as_sink(); - while let Some(val) = pt_receiver.recv().await { - build_runtime_stats.mark_rows_received(val.len() as u64); - if let BlockingSinkStatus::Finished = - build_runtime_stats.in_span(&span, || sink.sink(&val))? - { - break; - } - } - build_runtime_stats - .in_span(&info_span!("ProbeTable::finalize"), || sink.finalize())?; - DaftResult::Ok(()) - } - .with_context(move |_| PipelineExecutionSnafu { node_name: name }), - ); - // should wrap in context join handle - - let (right_sender, streaming_receiver) = create_channel(*NUM_CPUS, destination.in_order()); - // now we can start building the right side - self.right.start(right_sender, runtime_handle).await?; - - probe_table_build.await.context(JoinSnafu {})??; - - let hash_join = self.hash_join.clone(); - let probing_op = { - let guard = hash_join.lock().await; - guard.as_intermediate_op() - }; - let probing_node = IntermediateNode::new_with_runtime_stats( - probing_op, - vec![], - self.probe_runtime_stats.clone(), - ); - let worker_senders = probing_node - .spawn_workers(&mut destination, runtime_handle) - .await; - runtime_handle.spawn( - IntermediateNode::send_to_workers( - streaming_receiver, - worker_senders, - runtime_handle.default_morsel_size(), - ), - self.name(), - ); - Ok(()) - } - - fn as_tree_display(&self) -> &dyn TreeDisplay { - self - } -} diff --git a/src/daft-local-execution/src/sinks/hash_join_build.rs b/src/daft-local-execution/src/sinks/hash_join_build.rs new file mode 100644 index 0000000000..3899bdb234 --- /dev/null +++ b/src/daft-local-execution/src/sinks/hash_join_build.rs @@ -0,0 +1,129 @@ +use std::sync::Arc; + +use crate::pipeline::PipelineResultType; +use common_error::DaftResult; +use daft_core::schema::SchemaRef; +use daft_dsl::ExprRef; +use daft_micropartition::MicroPartition; + +use super::blocking_sink::{BlockingSink, BlockingSinkStatus}; +use daft_table::{ProbeTable, ProbeTableBuilder, Table}; + +enum ProbeTableState { + Building { + probe_table_builder: Option, + projection: Vec, + tables: Vec
, + }, + Done { + probe_table: Arc, + tables: Arc>, + }, +} + +impl ProbeTableState { + fn new(key_schema: &SchemaRef, projection: Vec) -> DaftResult { + Ok(Self::Building { + probe_table_builder: Some(ProbeTableBuilder::new(key_schema.clone())?), + projection, + tables: vec![], + }) + } + + fn add_tables(&mut self, input: &Arc) -> DaftResult<()> { + if let Self::Building { + ref mut probe_table_builder, + projection, + tables, + } = self + { + let probe_table_builder = probe_table_builder.as_mut().unwrap(); + for table in input.get_tables()?.iter() { + tables.push(table.clone()); + let join_keys = table.eval_expression_list(projection)?; + + probe_table_builder.add_table(&join_keys)?; + } + Ok(()) + } else { + panic!("add_tables can only be used during the Building Phase") + } + } + fn finalize(&mut self) -> DaftResult<()> { + if let Self::Building { + probe_table_builder, + tables, + .. + } = self + { + let ptb = std::mem::take(probe_table_builder).expect("should be set in building mode"); + let pt = ptb.build(); + + *self = Self::Done { + probe_table: Arc::new(pt), + tables: Arc::new(tables.clone()), + }; + Ok(()) + } else { + panic!("finalize can only be used during the Building Phase") + } + } +} + +pub(crate) struct HashJoinBuildSink { + probe_table_state: ProbeTableState, +} + +impl HashJoinBuildSink { + pub(crate) fn new(key_schema: SchemaRef, projection: Vec) -> DaftResult { + Ok(Self { + probe_table_state: ProbeTableState::new(&key_schema, projection)?, + }) + } + + pub(crate) fn boxed(self) -> Box { + Box::new(self) + } +} + +impl BlockingSink for HashJoinBuildSink { + fn name(&self) -> &'static str { + "HashJoinBuildSink" + } + + fn sink(&mut self, input: &Arc) -> DaftResult { + self.probe_table_state.add_tables(input)?; + Ok(BlockingSinkStatus::NeedMoreInput) + } + fn finalize(&mut self) -> DaftResult> { + self.probe_table_state.finalize()?; + if let ProbeTableState::Done { + probe_table, + tables, + } = &self.probe_table_state + { + Ok(Some((probe_table.clone(), tables.clone()).into())) + } else { + panic!("finalize should only be called after the probe table is built") + } + } +} + +// pub(crate) struct HashJoinOperator { +// join_state: ProbeTableState, +// } + +// impl HashJoinOperator { +// pub(crate) fn new( +// left_on: Vec, +// right_on: Vec, +// join_type: JoinType, +// left_schema: &SchemaRef, +// right_schema: &SchemaRef, +// ) -> DaftResult { +// +// Ok(Self { +// join_state: ProbeTableState::new(&key_schema, left_on)?, +// }) +// } +// } diff --git a/src/daft-local-execution/src/sinks/mod.rs b/src/daft-local-execution/src/sinks/mod.rs index 865c6df167..39910e7995 100644 --- a/src/daft-local-execution/src/sinks/mod.rs +++ b/src/daft-local-execution/src/sinks/mod.rs @@ -1,7 +1,7 @@ pub mod aggregate; pub mod blocking_sink; pub mod concat; -pub mod hash_join; +pub mod hash_join_build; pub mod limit; pub mod sort; pub mod streaming_sink; diff --git a/src/daft-local-execution/src/sinks/sort.rs b/src/daft-local-execution/src/sinks/sort.rs index 0a8ef48b5c..86d951fd83 100644 --- a/src/daft-local-execution/src/sinks/sort.rs +++ b/src/daft-local-execution/src/sinks/sort.rs @@ -6,7 +6,7 @@ use daft_micropartition::MicroPartition; use tracing::instrument; use super::blocking_sink::{BlockingSink, BlockingSinkStatus}; - +use crate::pipeline::PipelineResultType; pub struct SortSink { sort_by: Vec, descending: Vec, @@ -44,7 +44,7 @@ impl BlockingSink for SortSink { } #[instrument(skip_all, name = "SortSink::finalize")] - fn finalize(&mut self) -> DaftResult>> { + fn finalize(&mut self) -> DaftResult> { if let SortState::Building(parts) = &mut self.state { assert!( !parts.is_empty(), @@ -54,7 +54,7 @@ impl BlockingSink for SortSink { MicroPartition::concat(&parts.iter().map(|x| x.as_ref()).collect::>())?; let sorted = Arc::new(concated.sort(&self.sort_by, &self.descending)?); self.state = SortState::Done(sorted.clone()); - Ok(Some(sorted)) + Ok(Some(sorted.into())) } else { panic!("SortSink should be in Building state"); } diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index 76bd905c9c..de674d3e55 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -6,8 +6,8 @@ use daft_micropartition::MicroPartition; use tracing::info_span; use crate::{ - channel::{create_channel, MultiSender}, - pipeline::PipelineNode, + channel::create_multi_channel, + pipeline::{PipelineNode, PipelineResultReceiver}, runtime_stats::RuntimeStatsContext, ExecutionRuntimeHandle, NUM_CPUS, }; @@ -87,16 +87,16 @@ impl PipelineNode for StreamingSinkNode { async fn start( &mut self, - mut destination: MultiSender, + maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result<()> { - let (sender, mut streaming_receiver) = create_channel(*NUM_CPUS, destination.in_order()); - // now we can start building the right side + ) -> crate::Result { let child = self .children .get_mut(0) .expect("we should only have 1 child"); - child.start(sender, runtime_handle).await?; + let mut child_results_receiver = child.start(true, runtime_handle).await?; + let (mut destination_sender, destination_receiver) = + create_multi_channel(*NUM_CPUS, maintain_order); let op = self.op.clone(); let runtime_stats = self.runtime_stats.clone(); runtime_handle.spawn( @@ -106,22 +106,24 @@ impl PipelineNode for StreamingSinkNode { let mut sink = op.lock().await; let mut is_active = true; - while is_active && let Some(val) = streaming_receiver.recv().await { + while is_active && let Some(val) = child_results_receiver.recv().await { + let val = val?; + let val = val.as_data(); runtime_stats.mark_rows_received(val.len() as u64); loop { - let result = runtime_stats.in_span(&span, || sink.execute(0, &val))?; + let result = runtime_stats.in_span(&span, || sink.execute(0, val))?; match result { StreamSinkOutput::HasMoreOutput(mp) => { let len = mp.len() as u64; - let sender = destination.get_next_sender(); - sender.send(mp).await.unwrap(); + let sender = destination_sender.get_next_sender(); + sender.send(mp.into()).await.unwrap(); runtime_stats.mark_rows_emitted(len); } StreamSinkOutput::NeedMoreInput(mp) => { if let Some(mp) = mp { let len = mp.len() as u64; - let sender = destination.get_next_sender(); - sender.send(mp).await.unwrap(); + let sender = destination_sender.get_next_sender(); + sender.send(mp.into()).await.unwrap(); runtime_stats.mark_rows_emitted(len); } break; @@ -129,8 +131,8 @@ impl PipelineNode for StreamingSinkNode { StreamSinkOutput::Finished(mp) => { if let Some(mp) = mp { let len = mp.len() as u64; - let sender = destination.get_next_sender(); - sender.send(mp).await.unwrap(); + let sender = destination_sender.get_next_sender(); + sender.send(mp.into()).await.unwrap(); runtime_stats.mark_rows_emitted(len); } is_active = false; @@ -143,7 +145,7 @@ impl PipelineNode for StreamingSinkNode { }, self.name(), ); - Ok(()) + Ok(destination_receiver.into()) } fn as_tree_display(&self) -> &dyn TreeDisplay { self diff --git a/src/daft-local-execution/src/sources/in_memory.rs b/src/daft-local-execution/src/sources/in_memory.rs index 8d02a7486d..1a0719c3d0 100644 --- a/src/daft-local-execution/src/sources/in_memory.rs +++ b/src/daft-local-execution/src/sources/in_memory.rs @@ -1,11 +1,12 @@ use std::sync::Arc; -use crate::{channel::MultiSender, runtime_stats::RuntimeStatsContext, ExecutionRuntimeHandle}; +use crate::ExecutionRuntimeHandle; use daft_io::IOStatsRef; use daft_micropartition::MicroPartition; use tracing::instrument; use super::source::Source; +use crate::sources::source::SourceStream; pub struct InMemorySource { data: Vec>, @@ -24,24 +25,12 @@ impl Source for InMemorySource { #[instrument(name = "InMemorySource::get_data", level = "info", skip_all)] fn get_data( &self, - mut destination: MultiSender, - runtime_handle: &mut ExecutionRuntimeHandle, - runtime_stats: Arc, + _maintain_order: bool, + _runtime_handle: &mut ExecutionRuntimeHandle, _io_stats: IOStatsRef, - ) -> crate::Result<()> { + ) -> crate::Result> { let data = self.data.clone(); - runtime_handle.spawn( - async move { - for part in data { - let len = part.len(); - let _ = destination.get_next_sender().send(part).await; - runtime_stats.mark_rows_emitted(len as u64); - } - Ok(()) - }, - self.name(), - ); - Ok(()) + Ok(Box::pin(futures::stream::iter(data))) } fn name(&self) -> &'static str { "InMemory" diff --git a/src/daft-local-execution/src/sources/scan_task.rs b/src/daft-local-execution/src/sources/scan_task.rs index b82649713d..0f8f079481 100644 --- a/src/daft-local-execution/src/sources/scan_task.rs +++ b/src/daft-local-execution/src/sources/scan_task.rs @@ -1,5 +1,4 @@ use common_error::DaftResult; -use daft_core::schema::SchemaRef; use daft_csv::{CsvConvertOptions, CsvParseOptions, CsvReadOptions}; use daft_io::IOStatsRef; use daft_json::{JsonConvertOptions, JsonParseOptions, JsonReadOptions}; @@ -10,19 +9,16 @@ use daft_scan::{ storage_config::StorageConfig, ChunkSpec, ScanTask, }; -use daft_stats::{PartitionSpec, TableStatistics}; -use daft_table::Table; -use futures::{stream::BoxStream, StreamExt}; +use futures::{Stream, StreamExt}; use std::sync::Arc; +use tokio_stream::wrappers::ReceiverStream; use crate::{ - channel::{MultiSender, SingleSender}, - runtime_stats::{CountingSender, RuntimeStatsContext}, + channel::{create_channel, Sender}, + sources::source::{Source, SourceStream}, ExecutionRuntimeHandle, }; -use super::source::{Source, SourceStream}; - use tracing::instrument; pub struct ScanTaskSource { @@ -41,16 +37,11 @@ impl ScanTaskSource { )] async fn process_scan_task_stream( scan_task: Arc, - sender: SingleSender, - morsel_size: usize, + sender: Sender>, maintain_order: bool, io_stats: IOStatsRef, - runtime_stats: Arc, ) -> DaftResult<()> { - let mut stream = - stream_scan_task(scan_task, Some(io_stats), maintain_order, morsel_size).await?; - let sender = CountingSender::new(sender, runtime_stats.clone()); - + let mut stream = stream_scan_task(scan_task, Some(io_stats), maintain_order).await?; while let Some(partition) = stream.next().await { let _ = sender.send(partition?).await; } @@ -64,28 +55,45 @@ impl Source for ScanTaskSource { #[instrument(name = "ScanTaskSource::get_data", level = "info", skip_all)] fn get_data( &self, - mut destination: MultiSender, + maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - runtime_stats: Arc, io_stats: IOStatsRef, - ) -> crate::Result<()> { - let morsel_size = runtime_handle.default_morsel_size(); - let maintain_order = destination.in_order(); - for scan_task in self.scan_tasks.clone() { - let sender = destination.get_next_sender(); - runtime_handle.spawn( - Self::process_scan_task_stream( - scan_task, - sender, - morsel_size, - maintain_order, - io_stats.clone(), - runtime_stats.clone(), - ), - self.name(), - ); + ) -> crate::Result> { + match maintain_order { + true => { + let (senders, receivers): (Vec<_>, Vec<_>) = (0..self.scan_tasks.len()) + .map(|_| create_channel(1)) + .unzip(); + for (scan_task, sender) in self.scan_tasks.iter().zip(senders) { + runtime_handle.spawn( + Self::process_scan_task_stream( + scan_task.clone(), + sender, + maintain_order, + io_stats.clone(), + ), + self.name(), + ); + } + let stream = futures::stream::iter(receivers.into_iter().map(ReceiverStream::new)); + Ok(Box::pin(stream.flatten())) + } + false => { + let (sender, receiver) = create_channel(self.scan_tasks.len()); + for scan_task in self.scan_tasks.iter() { + runtime_handle.spawn( + Self::process_scan_task_stream( + scan_task.clone(), + sender.clone(), + maintain_order, + io_stats.clone(), + ), + self.name(), + ); + } + Ok(Box::pin(ReceiverStream::new(receiver))) + } } - Ok(()) } fn name(&self) -> &'static str { @@ -97,8 +105,7 @@ async fn stream_scan_task( scan_task: Arc, io_stats: Option, maintain_order: bool, - morsel_size: usize, -) -> DaftResult> { +) -> DaftResult>> + Send> { let pushdown_columns = scan_task .pushdowns .columns @@ -294,43 +301,21 @@ async fn stream_scan_task( } }; - let mp_stream = chunk_tables_into_micropartition_stream( - table_stream, - scan_task.materialized_schema(), - scan_task.partition_spec().cloned(), - scan_task.statistics.clone(), - morsel_size, - ); - Ok(Box::pin(mp_stream)) -} - -fn chunk_tables_into_micropartition_stream( - mut table_stream: BoxStream<'static, DaftResult
>, - schema: SchemaRef, - partition_spec: Option, - statistics: Option, - morsel_size: usize, -) -> SourceStream<'static> { - let chunked_stream = async_stream::try_stream! { - let mut buffer = vec![]; - let mut total_rows = 0; - while let Some(table) = table_stream.next().await { - let table = table?; - let casted_table = table.cast_to_schema_with_fill(schema.as_ref(), partition_spec.as_ref().map(|pspec| pspec.to_fill_map()).as_ref())?; - total_rows += casted_table.len(); - buffer.push(casted_table); - - if total_rows >= morsel_size { - let mp = Arc::new(MicroPartition::new_loaded(schema.clone(), Arc::new(buffer), statistics.clone())); - buffer = vec![]; - total_rows = 0; - yield mp; - } - } - if !buffer.is_empty() { - let mp = Arc::new(MicroPartition::new_loaded(schema, Arc::new(buffer), statistics)); - yield mp; - } - }; - Box::pin(chunked_stream) + Ok(table_stream.map(move |table| { + let table = table?; + let casted_table = table.cast_to_schema_with_fill( + scan_task.materialized_schema().as_ref(), + scan_task + .partition_spec() + .as_ref() + .map(|pspec| pspec.to_fill_map()) + .as_ref(), + )?; + let mp = Arc::new(MicroPartition::new_loaded( + scan_task.materialized_schema().clone(), + Arc::new(vec![casted_table]), + scan_task.statistics.clone(), + )); + Ok(mp) + })) } diff --git a/src/daft-local-execution/src/sources/source.rs b/src/daft-local-execution/src/sources/source.rs index 818e361589..de42ea26db 100644 --- a/src/daft-local-execution/src/sources/source.rs +++ b/src/daft-local-execution/src/sources/source.rs @@ -1,29 +1,29 @@ use std::sync::Arc; use common_display::{tree::TreeDisplay, utils::bytes_to_human_readable}; -use common_error::DaftResult; use daft_io::{IOStatsContext, IOStatsRef}; use daft_micropartition::MicroPartition; -use futures::stream::BoxStream; +use futures::{stream::BoxStream, StreamExt}; use async_trait::async_trait; use crate::{ - channel::MultiSender, pipeline::PipelineNode, runtime_stats::RuntimeStatsContext, + channel::create_multi_channel, + pipeline::{PipelineNode, PipelineResultReceiver}, + runtime_stats::{CountingSender, RuntimeStatsContext}, ExecutionRuntimeHandle, }; -pub type SourceStream<'a> = BoxStream<'a, DaftResult>>; +pub type SourceStream<'a> = BoxStream<'a, Arc>; pub(crate) trait Source: Send + Sync { fn name(&self) -> &'static str; fn get_data( &self, - destination: MultiSender, + maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - runtime_stats: Arc, io_stats: IOStatsRef, - ) -> crate::Result<()>; + ) -> crate::Result>; } struct SourceNode { @@ -74,15 +74,25 @@ impl PipelineNode for SourceNode { } async fn start( &mut self, - destination: MultiSender, + maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result<()> { - self.source.get_data( - destination, - runtime_handle, - self.runtime_stats.clone(), - self.io_stats.clone(), - ) + ) -> crate::Result { + let mut source_stream = + self.source + .get_data(maintain_order, runtime_handle, self.io_stats.clone())?; + + let (mut tx, rx) = create_multi_channel(1, maintain_order); + let counting_sender = CountingSender::new(tx.get_next_sender(), self.runtime_stats.clone()); + runtime_handle.spawn( + async move { + while let Some(part) = source_stream.next().await { + let _ = counting_sender.send(part.into()).await; + } + Ok(()) + }, + self.name(), + ); + Ok(rx.into()) } fn as_tree_display(&self) -> &dyn TreeDisplay { self From faa48b60f0cfe0dc77eff7b56ad106d592727574 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 23 Aug 2024 18:05:43 -0700 Subject: [PATCH 2/6] lil cleanup --- Cargo.lock | 2 - src/daft-local-execution/Cargo.toml | 2 - src/daft-local-execution/src/channel.rs | 36 +++++++------ .../src/intermediate_ops/hash_join_probe.rs | 16 +++--- .../src/intermediate_ops/intermediate_op.rs | 7 ++- src/daft-local-execution/src/pipeline.rs | 16 ++---- .../src/sinks/hash_join_build.rs | 19 ------- .../src/sinks/streaming_sink.rs | 12 ++--- .../src/sources/scan_task.rs | 54 ++++++++----------- .../src/sources/source.rs | 4 +- 10 files changed, 65 insertions(+), 103 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 79fd15515c..b547e28349 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1888,7 +1888,6 @@ dependencies = [ name = "daft-local-execution" version = "0.3.0-dev0" dependencies = [ - "async-stream", "async-trait", "common-daft-config", "common-display", @@ -1904,7 +1903,6 @@ dependencies = [ "daft-physical-plan", "daft-plan", "daft-scan", - "daft-stats", "daft-table", "futures", "lazy_static", diff --git a/src/daft-local-execution/Cargo.toml b/src/daft-local-execution/Cargo.toml index 5c8e208dab..66949390f5 100644 --- a/src/daft-local-execution/Cargo.toml +++ b/src/daft-local-execution/Cargo.toml @@ -1,5 +1,4 @@ [dependencies] -async-stream = {workspace = true} async-trait = {workspace = true} common-daft-config = {path = "../common/daft-config", default-features = false} common-display = {path = "../common/display", default-features = false} @@ -15,7 +14,6 @@ daft-parquet = {path = "../daft-parquet", default-features = false} daft-physical-plan = {path = "../daft-physical-plan", default-features = false} daft-plan = {path = "../daft-plan", default-features = false} daft-scan = {path = "../daft-scan", default-features = false} -daft-stats = {path = "../daft-stats", default-features = false} daft-table = {path = "../daft-table", default-features = false} futures = {workspace = true} lazy_static = {workspace = true} diff --git a/src/daft-local-execution/src/channel.rs b/src/daft-local-execution/src/channel.rs index 0f8cd621e6..db73bbb64f 100644 --- a/src/daft-local-execution/src/channel.rs +++ b/src/daft-local-execution/src/channel.rs @@ -1,3 +1,10 @@ +use std::sync::Arc; + +use crate::{ + pipeline::PipelineResultType, + runtime_stats::{CountingSender, RuntimeStatsContext}, +}; + pub type OneShotSender = tokio::sync::oneshot::Sender; pub type OneShotReceiver = tokio::sync::oneshot::Receiver; @@ -12,10 +19,7 @@ pub fn create_channel(buffer_size: usize) -> (Sender, Receiver) { tokio::sync::mpsc::channel(buffer_size) } -pub fn create_multi_channel( - buffer_size: usize, - in_order: bool, -) -> (MultiSender, MultiReceiver) { +pub fn create_multi_channel(buffer_size: usize, in_order: bool) -> (MultiSender, MultiReceiver) { if in_order { let (senders, receivers) = (0..buffer_size).map(|_| create_channel(1)).unzip(); let sender = MultiSender::InOrder(RoundRobinSender::new(senders)); @@ -29,16 +33,16 @@ pub fn create_multi_channel( } } -pub enum MultiSender { - InOrder(RoundRobinSender), - OutOfOrder(Sender), +pub enum MultiSender { + InOrder(RoundRobinSender), + OutOfOrder(Sender), } -impl MultiSender { - pub fn get_next_sender(&mut self) -> Sender { +impl MultiSender { + pub fn get_next_sender(&mut self, stats: &Arc) -> CountingSender { match self { - Self::InOrder(sender) => sender.get_next_sender(), - Self::OutOfOrder(sender) => sender.clone(), + Self::InOrder(sender) => CountingSender::new(sender.get_next_sender(), stats.clone()), + Self::OutOfOrder(sender) => CountingSender::new(sender.clone(), stats.clone()), } } } @@ -62,13 +66,13 @@ impl RoundRobinSender { } } -pub enum MultiReceiver { - InOrder(RoundRobinReceiver), - OutOfOrder(Receiver), +pub enum MultiReceiver { + InOrder(RoundRobinReceiver), + OutOfOrder(Receiver), } -impl MultiReceiver { - pub async fn recv(&mut self) -> Option { +impl MultiReceiver { + pub async fn recv(&mut self) -> Option { match self { Self::InOrder(receiver) => receiver.recv().await, Self::OutOfOrder(receiver) => receiver.recv().await, diff --git a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs index c2f3ff4a3e..b8a4053eb8 100644 --- a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs @@ -19,6 +19,14 @@ enum HashJoinProbeState { } impl HashJoinProbeState { + fn set_table(&mut self, table: &Arc, tables: &Arc>) { + if let HashJoinProbeState::Building = self { + *self = HashJoinProbeState::ReadyToProbe(table.clone(), tables.clone()); + } else { + panic!("HashJoinProbeState should only be in Building state when setting table") + } + } + fn probe( &self, input: &Arc, @@ -106,17 +114,13 @@ impl IntermediateOperator for HashJoinProbeOperator { ) -> DaftResult { match idx { 0 => { - let (probe_table, tables) = input.as_probe_table(); let state = state .expect("HashJoinProbeOperator should have state") .as_any_mut() .downcast_mut::() .expect("HashJoinProbeOperator state should be HashJoinProbeState"); - if let HashJoinProbeState::Building = state { - *state = HashJoinProbeState::ReadyToProbe(probe_table.clone(), tables.clone()); - } else { - panic!("HashJoinProbeOperator should only be in Building state on first input"); - } + let (probe_table, tables) = input.as_probe_table(); + state.set_table(probe_table, tables); Ok(IntermediateOperatorResult::NeedMoreInput(None)) } _ => { diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index 283b7badb8..b59f1773ba 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -74,12 +74,11 @@ impl IntermediateNode { pub async fn run_worker( op: Arc, mut receiver: Receiver<(usize, PipelineResultType)>, - sender: Sender, + sender: CountingSender, rt_context: Arc, ) -> DaftResult<()> { let span = info_span!("IntermediateOp::execute"); let mut state = op.make_state(); - let sender = CountingSender::new(sender, rt_context.clone()); while let Some((idx, morsel)) = receiver.recv().await { let len = match morsel { PipelineResultType::Data(ref data) => data.len(), @@ -105,13 +104,13 @@ impl IntermediateNode { pub async fn spawn_workers( &self, num_workers: usize, - destination: &mut MultiSender, + destination: &mut MultiSender, runtime_handle: &mut ExecutionRuntimeHandle, ) -> Vec> { let mut worker_senders = Vec::with_capacity(num_workers); for _ in 0..num_workers { let (worker_sender, worker_receiver) = create_channel(1); - let destination_sender = destination.get_next_sender(); + let destination_sender = destination.get_next_sender(&self.runtime_stats); runtime_handle.spawn( Self::run_worker( self.intermediate_op.clone(), diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index f87e7f8fea..6dce8ca44e 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, sync::Arc}; use crate::{ - channel::{MultiReceiver, OneShotReceiver, Receiver}, + channel::{MultiReceiver, OneShotReceiver}, intermediate_ops::{ aggregate::AggregateOperator, filter::FilterOperator, hash_join_probe::HashJoinProbeOperator, intermediate_op::IntermediateNode, @@ -73,19 +73,12 @@ impl PipelineResultType { } pub enum PipelineResultReceiver { - Single(Receiver), - Multi(MultiReceiver), + Multi(MultiReceiver), OneShot(OneShotReceiver, bool), } -impl From> for PipelineResultReceiver { - fn from(rx: Receiver) -> Self { - PipelineResultReceiver::Multi(MultiReceiver::OutOfOrder(rx)) - } -} - -impl From> for PipelineResultReceiver { - fn from(rx: MultiReceiver) -> Self { +impl From for PipelineResultReceiver { + fn from(rx: MultiReceiver) -> Self { PipelineResultReceiver::Multi(rx) } } @@ -99,7 +92,6 @@ impl From> for PipelineResultReceiver { impl PipelineResultReceiver { pub async fn recv(&mut self) -> Option> { match self { - PipelineResultReceiver::Single(rx) => rx.recv().await.map(Ok), PipelineResultReceiver::Multi(rx) => rx.recv().await.map(Ok), PipelineResultReceiver::OneShot(rx, done) => { if *done { diff --git a/src/daft-local-execution/src/sinks/hash_join_build.rs b/src/daft-local-execution/src/sinks/hash_join_build.rs index 3899bdb234..54965371cf 100644 --- a/src/daft-local-execution/src/sinks/hash_join_build.rs +++ b/src/daft-local-execution/src/sinks/hash_join_build.rs @@ -108,22 +108,3 @@ impl BlockingSink for HashJoinBuildSink { } } } - -// pub(crate) struct HashJoinOperator { -// join_state: ProbeTableState, -// } - -// impl HashJoinOperator { -// pub(crate) fn new( -// left_on: Vec, -// right_on: Vec, -// join_type: JoinType, -// left_schema: &SchemaRef, -// right_schema: &SchemaRef, -// ) -> DaftResult { -// -// Ok(Self { -// join_state: ProbeTableState::new(&key_schema, left_on)?, -// }) -// } -// } diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index de674d3e55..4ab12e8b4c 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -114,26 +114,20 @@ impl PipelineNode for StreamingSinkNode { let result = runtime_stats.in_span(&span, || sink.execute(0, val))?; match result { StreamSinkOutput::HasMoreOutput(mp) => { - let len = mp.len() as u64; - let sender = destination_sender.get_next_sender(); + let sender = destination_sender.get_next_sender(&runtime_stats); sender.send(mp.into()).await.unwrap(); - runtime_stats.mark_rows_emitted(len); } StreamSinkOutput::NeedMoreInput(mp) => { if let Some(mp) = mp { - let len = mp.len() as u64; - let sender = destination_sender.get_next_sender(); + let sender = destination_sender.get_next_sender(&runtime_stats); sender.send(mp.into()).await.unwrap(); - runtime_stats.mark_rows_emitted(len); } break; } StreamSinkOutput::Finished(mp) => { if let Some(mp) = mp { - let len = mp.len() as u64; - let sender = destination_sender.get_next_sender(); + let sender = destination_sender.get_next_sender(&runtime_stats); sender.send(mp.into()).await.unwrap(); - runtime_stats.mark_rows_emitted(len); } is_active = false; break; diff --git a/src/daft-local-execution/src/sources/scan_task.rs b/src/daft-local-execution/src/sources/scan_task.rs index 0f8f079481..76c08b565c 100644 --- a/src/daft-local-execution/src/sources/scan_task.rs +++ b/src/daft-local-execution/src/sources/scan_task.rs @@ -59,41 +59,33 @@ impl Source for ScanTaskSource { runtime_handle: &mut ExecutionRuntimeHandle, io_stats: IOStatsRef, ) -> crate::Result> { - match maintain_order { - true => { - let (senders, receivers): (Vec<_>, Vec<_>) = (0..self.scan_tasks.len()) - .map(|_| create_channel(1)) - .unzip(); - for (scan_task, sender) in self.scan_tasks.iter().zip(senders) { - runtime_handle.spawn( - Self::process_scan_task_stream( - scan_task.clone(), - sender, - maintain_order, - io_stats.clone(), - ), - self.name(), - ); - } - let stream = futures::stream::iter(receivers.into_iter().map(ReceiverStream::new)); - Ok(Box::pin(stream.flatten())) - } + let (senders, receivers): (Vec<_>, Vec<_>) = match maintain_order { + true => (0..self.scan_tasks.len()) + .map(|_| create_channel(1)) + .unzip(), false => { let (sender, receiver) = create_channel(self.scan_tasks.len()); - for scan_task in self.scan_tasks.iter() { - runtime_handle.spawn( - Self::process_scan_task_stream( - scan_task.clone(), - sender.clone(), - maintain_order, - io_stats.clone(), - ), - self.name(), - ); - } - Ok(Box::pin(ReceiverStream::new(receiver))) + ( + std::iter::repeat(sender) + .take(self.scan_tasks.len()) + .collect(), + vec![receiver], + ) } + }; + for (scan_task, sender) in self.scan_tasks.iter().zip(senders) { + runtime_handle.spawn( + Self::process_scan_task_stream( + scan_task.clone(), + sender, + maintain_order, + io_stats.clone(), + ), + self.name(), + ); } + let stream = futures::stream::iter(receivers.into_iter().map(ReceiverStream::new)); + Ok(Box::pin(stream.flatten())) } fn name(&self) -> &'static str { diff --git a/src/daft-local-execution/src/sources/source.rs b/src/daft-local-execution/src/sources/source.rs index de42ea26db..1b98759e89 100644 --- a/src/daft-local-execution/src/sources/source.rs +++ b/src/daft-local-execution/src/sources/source.rs @@ -10,7 +10,7 @@ use async_trait::async_trait; use crate::{ channel::create_multi_channel, pipeline::{PipelineNode, PipelineResultReceiver}, - runtime_stats::{CountingSender, RuntimeStatsContext}, + runtime_stats::RuntimeStatsContext, ExecutionRuntimeHandle, }; @@ -82,7 +82,7 @@ impl PipelineNode for SourceNode { .get_data(maintain_order, runtime_handle, self.io_stats.clone())?; let (mut tx, rx) = create_multi_channel(1, maintain_order); - let counting_sender = CountingSender::new(tx.get_next_sender(), self.runtime_stats.clone()); + let counting_sender = tx.get_next_sender(&self.runtime_stats); runtime_handle.spawn( async move { while let Some(part) = source_stream.next().await { From 5357ffae3da4e1b5d5c7131a5b442fa66a812b9e Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 26 Aug 2024 17:55:16 -0700 Subject: [PATCH 3/6] pipline channel --- src/daft-local-execution/src/channel.rs | 81 +++++++++++-------- .../src/intermediate_ops/aggregate.rs | 5 +- .../src/intermediate_ops/filter.rs | 5 +- .../src/intermediate_ops/hash_join_probe.rs | 13 +-- .../src/intermediate_ops/intermediate_op.rs | 42 +++++----- .../src/intermediate_ops/project.rs | 5 +- src/daft-local-execution/src/pipeline.rs | 44 ++-------- src/daft-local-execution/src/run.rs | 7 +- src/daft-local-execution/src/runtime_stats.rs | 30 ++++++- .../src/sinks/blocking_sink.rs | 35 ++++---- .../src/sinks/hash_join_build.rs | 1 + .../src/sinks/streaming_sink.rs | 26 +++--- .../src/sources/source.rs | 12 ++- 13 files changed, 151 insertions(+), 155 deletions(-) diff --git a/src/daft-local-execution/src/channel.rs b/src/daft-local-execution/src/channel.rs index db73bbb64f..bb22b9d4ea 100644 --- a/src/daft-local-execution/src/channel.rs +++ b/src/daft-local-execution/src/channel.rs @@ -2,16 +2,9 @@ use std::sync::Arc; use crate::{ pipeline::PipelineResultType, - runtime_stats::{CountingSender, RuntimeStatsContext}, + runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, }; -pub type OneShotSender = tokio::sync::oneshot::Sender; -pub type OneShotReceiver = tokio::sync::oneshot::Receiver; - -pub fn create_one_shot_channel() -> (OneShotSender, OneShotReceiver) { - tokio::sync::oneshot::channel() -} - pub type Sender = tokio::sync::mpsc::Sender; pub type Receiver = tokio::sync::mpsc::Receiver; @@ -19,33 +12,57 @@ pub fn create_channel(buffer_size: usize) -> (Sender, Receiver) { tokio::sync::mpsc::channel(buffer_size) } -pub fn create_multi_channel(buffer_size: usize, in_order: bool) -> (MultiSender, MultiReceiver) { - if in_order { - let (senders, receivers) = (0..buffer_size).map(|_| create_channel(1)).unzip(); - let sender = MultiSender::InOrder(RoundRobinSender::new(senders)); - let receiver = MultiReceiver::InOrder(RoundRobinReceiver::new(receivers)); - (sender, receiver) - } else { - let (sender, receiver) = create_channel(buffer_size); - let sender = MultiSender::OutOfOrder(sender); - let receiver = MultiReceiver::OutOfOrder(receiver); - (sender, receiver) +pub struct PipelineChannel { + sender: PipelineSender, + receiver: PipelineReceiver, +} + +impl PipelineChannel { + pub fn new(buffer_size: usize, in_order: bool) -> Self { + match in_order { + true => { + let (senders, receivers) = (0..buffer_size).map(|_| create_channel(1)).unzip(); + let sender = PipelineSender::InOrder(RoundRobinSender::new(senders)); + let receiver = PipelineReceiver::InOrder(RoundRobinReceiver::new(receivers)); + Self { sender, receiver } + } + false => { + let (sender, receiver) = create_channel(buffer_size); + let sender = PipelineSender::OutOfOrder(sender); + let receiver = PipelineReceiver::OutOfOrder(receiver); + Self { sender, receiver } + } + } + } + + fn get_next_sender(&mut self) -> Sender { + match &mut self.sender { + PipelineSender::InOrder(rr) => rr.get_next_sender(), + PipelineSender::OutOfOrder(sender) => sender.clone(), + } + } + + pub(crate) fn get_next_sender_with_stats( + &mut self, + rt: &Arc, + ) -> CountingSender { + CountingSender::new(self.get_next_sender(), rt.clone()) + } + + pub fn get_receiver(self) -> PipelineReceiver { + self.receiver + } + + pub(crate) fn get_receiver_with_stats(self, rt: &Arc) -> CountingReceiver { + CountingReceiver::new(self.get_receiver(), rt.clone()) } } -pub enum MultiSender { +pub enum PipelineSender { InOrder(RoundRobinSender), OutOfOrder(Sender), } -impl MultiSender { - pub fn get_next_sender(&mut self, stats: &Arc) -> CountingSender { - match self { - Self::InOrder(sender) => CountingSender::new(sender.get_next_sender(), stats.clone()), - Self::OutOfOrder(sender) => CountingSender::new(sender.clone(), stats.clone()), - } - } -} pub struct RoundRobinSender { senders: Vec>, curr_sender_idx: usize, @@ -66,16 +83,16 @@ impl RoundRobinSender { } } -pub enum MultiReceiver { +pub enum PipelineReceiver { InOrder(RoundRobinReceiver), OutOfOrder(Receiver), } -impl MultiReceiver { +impl PipelineReceiver { pub async fn recv(&mut self) -> Option { match self { - Self::InOrder(receiver) => receiver.recv().await, - Self::OutOfOrder(receiver) => receiver.recv().await, + PipelineReceiver::InOrder(rr) => rr.recv().await, + PipelineReceiver::OutOfOrder(r) => r.recv().await, } } } diff --git a/src/daft-local-execution/src/intermediate_ops/aggregate.rs b/src/daft-local-execution/src/intermediate_ops/aggregate.rs index 69f2c5a6e9..c83bc3f540 100644 --- a/src/daft-local-execution/src/intermediate_ops/aggregate.rs +++ b/src/daft-local-execution/src/intermediate_ops/aggregate.rs @@ -29,11 +29,10 @@ impl IntermediateOperator for AggregateOperator { fn execute( &self, _idx: usize, - input: &PipelineResultType, + input: PipelineResultType, _state: Option<&mut Box>, ) -> DaftResult { - let input = input.as_data(); - let out = input.agg(&self.agg_exprs, &self.group_by)?; + let out = input.data().agg(&self.agg_exprs, &self.group_by)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( out, )))) diff --git a/src/daft-local-execution/src/intermediate_ops/filter.rs b/src/daft-local-execution/src/intermediate_ops/filter.rs index e8b1611b87..ef85379ec6 100644 --- a/src/daft-local-execution/src/intermediate_ops/filter.rs +++ b/src/daft-local-execution/src/intermediate_ops/filter.rs @@ -25,11 +25,10 @@ impl IntermediateOperator for FilterOperator { fn execute( &self, _idx: usize, - input: &PipelineResultType, + input: PipelineResultType, _state: Option<&mut Box>, ) -> DaftResult { - let input = input.as_data(); - let out = input.filter(&[self.predicate.clone()])?; + let out = input.data().filter(&[self.predicate.clone()])?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( out, )))) diff --git a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs index b8a4053eb8..7b1bb04d22 100644 --- a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs @@ -19,9 +19,9 @@ enum HashJoinProbeState { } impl HashJoinProbeState { - fn set_table(&mut self, table: &Arc, tables: &Arc>) { + fn set_table(&mut self, table: Arc, tables: Arc>) { if let HashJoinProbeState::Building = self { - *self = HashJoinProbeState::ReadyToProbe(table.clone(), tables.clone()); + *self = HashJoinProbeState::ReadyToProbe(table, tables); } else { panic!("HashJoinProbeState should only be in Building state when setting table") } @@ -29,7 +29,7 @@ impl HashJoinProbeState { fn probe( &self, - input: &Arc, + input: Arc, right_on: &[ExprRef], pruned_right_side_columns: &[String], ) -> DaftResult> { @@ -109,9 +109,10 @@ impl IntermediateOperator for HashJoinProbeOperator { fn execute( &self, idx: usize, - input: &PipelineResultType, + input: PipelineResultType, state: Option<&mut Box>, ) -> DaftResult { + println!("HashJoinProbeOperator::execute: idx: {}", idx); match idx { 0 => { let state = state @@ -119,7 +120,7 @@ impl IntermediateOperator for HashJoinProbeOperator { .as_any_mut() .downcast_mut::() .expect("HashJoinProbeOperator state should be HashJoinProbeState"); - let (probe_table, tables) = input.as_probe_table(); + let (probe_table, tables) = input.probe_table(); state.set_table(probe_table, tables); Ok(IntermediateOperatorResult::NeedMoreInput(None)) } @@ -129,7 +130,7 @@ impl IntermediateOperator for HashJoinProbeOperator { .as_any_mut() .downcast_mut::() .expect("HashJoinProbeOperator state should be HashJoinProbeState"); - let input = input.as_data(); + let input = input.data(); let out = state.probe(input, &self.right_on, &self.pruned_right_side_columns)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) } diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index b59f1773ba..6c4253fe98 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -8,9 +8,9 @@ use tracing::{info_span, instrument}; use async_trait::async_trait; use crate::{ - channel::{create_channel, create_multi_channel, MultiSender, Receiver, Sender}, - pipeline::{PipelineNode, PipelineResultReceiver, PipelineResultType}, - runtime_stats::{CountingSender, RuntimeStatsContext}, + channel::{create_channel, PipelineChannel, Receiver, Sender}, + pipeline::{PipelineNode, PipelineResultType}, + runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, ExecutionRuntimeHandle, NUM_CPUS, }; @@ -30,7 +30,7 @@ pub trait IntermediateOperator: Send + Sync { fn execute( &self, idx: usize, - input: &PipelineResultType, + input: PipelineResultType, state: Option<&mut Box>, ) -> DaftResult; fn name(&self) -> &'static str; @@ -80,14 +80,7 @@ impl IntermediateNode { let span = info_span!("IntermediateOp::execute"); let mut state = op.make_state(); while let Some((idx, morsel)) = receiver.recv().await { - let len = match morsel { - PipelineResultType::Data(ref data) => data.len(), - PipelineResultType::ProbeTable(_, ref tables) => { - tables.iter().map(|t| t.len()).sum() - } - }; - rt_context.mark_rows_received(len as u64); - let result = rt_context.in_span(&span, || op.execute(idx, &morsel, state.as_mut()))?; + let result = rt_context.in_span(&span, || op.execute(idx, morsel, state.as_mut()))?; match result { IntermediateOperatorResult::NeedMoreInput(Some(mp)) => { let _ = sender.send(mp.into()).await; @@ -104,13 +97,14 @@ impl IntermediateNode { pub async fn spawn_workers( &self, num_workers: usize, - destination: &mut MultiSender, + destination_channel: &mut PipelineChannel, runtime_handle: &mut ExecutionRuntimeHandle, ) -> Vec> { let mut worker_senders = Vec::with_capacity(num_workers); for _ in 0..num_workers { let (worker_sender, worker_receiver) = create_channel(1); - let destination_sender = destination.get_next_sender(&self.runtime_stats); + let destination_sender = + destination_channel.get_next_sender_with_stats(&self.runtime_stats); runtime_handle.spawn( Self::run_worker( self.intermediate_op.clone(), @@ -126,7 +120,7 @@ impl IntermediateNode { } pub async fn send_to_workers( - receivers: Vec, + receivers: Vec, worker_senders: Vec>, morsel_size: usize, ) -> DaftResult<()> { @@ -138,15 +132,15 @@ impl IntermediateNode { }; for (idx, mut receiver) in receivers.into_iter().enumerate() { + println!("idx: {}", idx); let mut buffer = OperatorBuffer::new(morsel_size); while let Some(morsel) = receiver.recv().await { - let morsel = morsel?; if morsel.should_broadcast() { for worker_sender in worker_senders.iter() { let _ = worker_sender.send((idx, morsel.clone())).await; } } else { - buffer.push(morsel.as_data().clone()); + buffer.push(morsel.data().clone()); if let Some(ready) = buffer.try_clear() { let _ = send_to_next_worker(idx, ready?.into()).await; } @@ -200,17 +194,17 @@ impl PipelineNode for IntermediateNode { &mut self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result { + ) -> crate::Result { let mut child_result_receivers = Vec::with_capacity(self.children.len()); for child in self.children.iter_mut() { - let child_result_receiver = child.start(maintain_order, runtime_handle).await?; - child_result_receivers.push(child_result_receiver); + let child_result_channel = child.start(maintain_order, runtime_handle).await?; + child_result_receivers + .push(child_result_channel.get_receiver_with_stats(&self.runtime_stats)); } - let (mut destination_sender, destination_receiver) = - create_multi_channel(*NUM_CPUS, maintain_order); + let mut destination_channel = PipelineChannel::new(*NUM_CPUS, maintain_order); let worker_senders = self - .spawn_workers(*NUM_CPUS, &mut destination_sender, runtime_handle) + .spawn_workers(*NUM_CPUS, &mut destination_channel, runtime_handle) .await; runtime_handle.spawn( Self::send_to_workers( @@ -220,7 +214,7 @@ impl PipelineNode for IntermediateNode { ), self.intermediate_op.name(), ); - Ok(destination_receiver.into()) + Ok(destination_channel) } fn as_tree_display(&self) -> &dyn TreeDisplay { self diff --git a/src/daft-local-execution/src/intermediate_ops/project.rs b/src/daft-local-execution/src/intermediate_ops/project.rs index 090116ad71..bc9f7f4eea 100644 --- a/src/daft-local-execution/src/intermediate_ops/project.rs +++ b/src/daft-local-execution/src/intermediate_ops/project.rs @@ -25,11 +25,10 @@ impl IntermediateOperator for ProjectOperator { fn execute( &self, _idx: usize, - input: &PipelineResultType, + input: PipelineResultType, _state: Option<&mut Box>, ) -> DaftResult { - let input = input.as_data(); - let out = input.eval_expression_list(&self.projection)?; + let out = input.data().eval_expression_list(&self.projection)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( out, )))) diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index 6dce8ca44e..ad9668637f 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, sync::Arc}; use crate::{ - channel::{MultiReceiver, OneShotReceiver}, + channel::PipelineChannel, intermediate_ops::{ aggregate::AggregateOperator, filter::FilterOperator, hash_join_probe::HashJoinProbeOperator, intermediate_op::IntermediateNode, @@ -13,7 +13,7 @@ use crate::{ streaming_sink::StreamingSinkNode, }, sources::in_memory::InMemorySource, - ExecutionRuntimeHandle, OneShotRecvSnafu, PipelineCreationSnafu, + ExecutionRuntimeHandle, PipelineCreationSnafu, }; use async_trait::async_trait; @@ -53,14 +53,14 @@ impl From<(Arc, Arc>)> for PipelineResultType { } impl PipelineResultType { - pub fn as_data(&self) -> &Arc { + pub fn data(self) -> Arc { match self { PipelineResultType::Data(data) => data, _ => panic!("Expected data"), } } - pub fn as_probe_table(&self) -> (&Arc, &Arc>) { + pub fn probe_table(self) -> (Arc, Arc>) { match self { PipelineResultType::ProbeTable(probe_table, tables) => (probe_table, tables), _ => panic!("Expected probe table"), @@ -72,40 +72,6 @@ impl PipelineResultType { } } -pub enum PipelineResultReceiver { - Multi(MultiReceiver), - OneShot(OneShotReceiver, bool), -} - -impl From for PipelineResultReceiver { - fn from(rx: MultiReceiver) -> Self { - PipelineResultReceiver::Multi(rx) - } -} - -impl From> for PipelineResultReceiver { - fn from(rx: OneShotReceiver) -> Self { - PipelineResultReceiver::OneShot(rx, false) - } -} - -impl PipelineResultReceiver { - pub async fn recv(&mut self) -> Option> { - match self { - PipelineResultReceiver::Multi(rx) => rx.recv().await.map(Ok), - PipelineResultReceiver::OneShot(rx, done) => { - if *done { - None - } else { - let result = rx.await.context(OneShotRecvSnafu); - *done = true; - Some(result) - } - } - } - } -} - #[async_trait] pub trait PipelineNode: Sync + Send + TreeDisplay { fn children(&self) -> Vec<&dyn PipelineNode>; @@ -114,7 +80,7 @@ pub trait PipelineNode: Sync + Send + TreeDisplay { &mut self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result; + ) -> crate::Result; fn as_tree_display(&self) -> &dyn TreeDisplay; } diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs index 7767635f1f..04726f86e5 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -136,9 +136,12 @@ pub fn run_local( .expect("Failed to create tokio runtime"); runtime.block_on(async { let mut runtime_handle = ExecutionRuntimeHandle::new(cfg.default_morsel_size); - let mut receiver = pipeline.start(true, &mut runtime_handle).await?; + let mut receiver = pipeline + .start(true, &mut runtime_handle) + .await? + .get_receiver(); while let Some(val) = receiver.recv().await { - let _ = tx.send(val?.as_data().clone()).await; + let _ = tx.send(val.data()).await; } while let Some(result) = runtime_handle.join_next().await { diff --git a/src/daft-local-execution/src/runtime_stats.rs b/src/daft-local-execution/src/runtime_stats.rs index adc1c8270d..bcd9e4a1f5 100644 --- a/src/daft-local-execution/src/runtime_stats.rs +++ b/src/daft-local-execution/src/runtime_stats.rs @@ -7,7 +7,10 @@ use std::{ use tokio::sync::mpsc::error::SendError; -use crate::{channel::Sender, pipeline::PipelineResultType}; +use crate::{ + channel::{PipelineReceiver, Sender}, + pipeline::PipelineResultType, +}; #[derive(Default)] pub(crate) struct RuntimeStatsContext { @@ -129,3 +132,28 @@ impl CountingSender { Ok(()) } } + +pub(crate) struct CountingReceiver { + receiver: PipelineReceiver, + rt: Arc, +} + +impl CountingReceiver { + pub(crate) fn new(receiver: PipelineReceiver, rt: Arc) -> Self { + Self { receiver, rt } + } + #[inline] + pub(crate) async fn recv(&mut self) -> Option { + let v = self.receiver.recv().await; + if let Some(ref v) = v { + let len = match v { + PipelineResultType::Data(ref mp) => mp.len(), + PipelineResultType::ProbeTable(_, ref tables) => { + tables.iter().map(|t| t.len()).sum() + } + }; + self.rt.mark_rows_received(len as u64); + } + v + } +} diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index 1399791f96..63ea4c2cee 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -6,8 +6,8 @@ use daft_micropartition::MicroPartition; use tracing::info_span; use crate::{ - channel::create_one_shot_channel, - pipeline::{PipelineNode, PipelineResultReceiver, PipelineResultType}, + channel::PipelineChannel, + pipeline::{PipelineNode, PipelineResultType}, runtime_stats::RuntimeStatsContext, ExecutionRuntimeHandle, }; @@ -80,25 +80,27 @@ impl PipelineNode for BlockingSinkNode { async fn start( &mut self, - _maintain_order: bool, + maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result { + ) -> crate::Result { let child = self.child.as_mut(); - let mut child_results_receiver = child.start(false, runtime_handle).await?; - let op = self.op.clone(); + let mut child_results_receiver = child + .start(false, runtime_handle) + .await? + .get_receiver_with_stats(&self.runtime_stats); - let (destination_sender, destination_receiver) = create_one_shot_channel(); + let mut destination_channel = PipelineChannel::new(1, maintain_order); + let destination_sender = + destination_channel.get_next_sender_with_stats(&self.runtime_stats); + let op = self.op.clone(); let rt_context = self.runtime_stats.clone(); runtime_handle.spawn( async move { let span = info_span!("BlockingSinkNode::execute"); let mut guard = op.lock().await; while let Some(val) = child_results_receiver.recv().await { - let val = val?; - let val = val.as_data(); - rt_context.mark_rows_received(val.len() as u64); if let BlockingSinkStatus::Finished = - rt_context.in_span(&span, || guard.sink(val))? + rt_context.in_span(&span, || guard.sink(&val.data()))? { break; } @@ -108,20 +110,13 @@ impl PipelineNode for BlockingSinkNode { guard.finalize() })?; if let Some(part) = finalized_result { - let len = match part { - PipelineResultType::Data(ref part) => part.len(), - PipelineResultType::ProbeTable(_, ref tables) => { - tables.iter().map(|t| t.len()).sum() - } - }; - let _ = destination_sender.send(part); - rt_context.mark_rows_emitted(len as u64); + let _ = destination_sender.send(part).await; } Ok(()) }, self.name(), ); - Ok(destination_receiver.into()) + Ok(destination_channel) } fn as_tree_display(&self) -> &dyn TreeDisplay { self diff --git a/src/daft-local-execution/src/sinks/hash_join_build.rs b/src/daft-local-execution/src/sinks/hash_join_build.rs index 54965371cf..0a86756334 100644 --- a/src/daft-local-execution/src/sinks/hash_join_build.rs +++ b/src/daft-local-execution/src/sinks/hash_join_build.rs @@ -102,6 +102,7 @@ impl BlockingSink for HashJoinBuildSink { tables, } = &self.probe_table_state { + println!("HashJoinBuildSink::finalize: table_len: {}", tables.len()); Ok(Some((probe_table.clone(), tables.clone()).into())) } else { panic!("finalize should only be called after the probe table is built") diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index 4ab12e8b4c..afe4be0d3b 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -6,9 +6,7 @@ use daft_micropartition::MicroPartition; use tracing::info_span; use crate::{ - channel::create_multi_channel, - pipeline::{PipelineNode, PipelineResultReceiver}, - runtime_stats::RuntimeStatsContext, + channel::PipelineChannel, pipeline::PipelineNode, runtime_stats::RuntimeStatsContext, ExecutionRuntimeHandle, NUM_CPUS, }; use async_trait::async_trait; @@ -89,14 +87,17 @@ impl PipelineNode for StreamingSinkNode { &mut self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result { + ) -> crate::Result { let child = self .children .get_mut(0) .expect("we should only have 1 child"); - let mut child_results_receiver = child.start(true, runtime_handle).await?; - let (mut destination_sender, destination_receiver) = - create_multi_channel(*NUM_CPUS, maintain_order); + let child_results_channel = child.start(true, runtime_handle).await?; + let mut child_results_receiver = + child_results_channel.get_receiver_with_stats(&self.runtime_stats); + + let mut destination_channel = PipelineChannel::new(*NUM_CPUS, maintain_order); + let sender = destination_channel.get_next_sender_with_stats(&self.runtime_stats); let op = self.op.clone(); let runtime_stats = self.runtime_stats.clone(); runtime_handle.spawn( @@ -107,26 +108,21 @@ impl PipelineNode for StreamingSinkNode { let mut sink = op.lock().await; let mut is_active = true; while is_active && let Some(val) = child_results_receiver.recv().await { - let val = val?; - let val = val.as_data(); - runtime_stats.mark_rows_received(val.len() as u64); + let val = val.data(); loop { - let result = runtime_stats.in_span(&span, || sink.execute(0, val))?; + let result = runtime_stats.in_span(&span, || sink.execute(0, &val))?; match result { StreamSinkOutput::HasMoreOutput(mp) => { - let sender = destination_sender.get_next_sender(&runtime_stats); sender.send(mp.into()).await.unwrap(); } StreamSinkOutput::NeedMoreInput(mp) => { if let Some(mp) = mp { - let sender = destination_sender.get_next_sender(&runtime_stats); sender.send(mp.into()).await.unwrap(); } break; } StreamSinkOutput::Finished(mp) => { if let Some(mp) = mp { - let sender = destination_sender.get_next_sender(&runtime_stats); sender.send(mp.into()).await.unwrap(); } is_active = false; @@ -139,7 +135,7 @@ impl PipelineNode for StreamingSinkNode { }, self.name(), ); - Ok(destination_receiver.into()) + Ok(destination_channel) } fn as_tree_display(&self) -> &dyn TreeDisplay { self diff --git a/src/daft-local-execution/src/sources/source.rs b/src/daft-local-execution/src/sources/source.rs index 1b98759e89..be7d6bef67 100644 --- a/src/daft-local-execution/src/sources/source.rs +++ b/src/daft-local-execution/src/sources/source.rs @@ -8,9 +8,7 @@ use futures::{stream::BoxStream, StreamExt}; use async_trait::async_trait; use crate::{ - channel::create_multi_channel, - pipeline::{PipelineNode, PipelineResultReceiver}, - runtime_stats::RuntimeStatsContext, + channel::PipelineChannel, pipeline::PipelineNode, runtime_stats::RuntimeStatsContext, ExecutionRuntimeHandle, }; @@ -76,13 +74,13 @@ impl PipelineNode for SourceNode { &mut self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, - ) -> crate::Result { + ) -> crate::Result { let mut source_stream = self.source .get_data(maintain_order, runtime_handle, self.io_stats.clone())?; - let (mut tx, rx) = create_multi_channel(1, maintain_order); - let counting_sender = tx.get_next_sender(&self.runtime_stats); + let mut channel = PipelineChannel::new(1, maintain_order); + let counting_sender = channel.get_next_sender_with_stats(&self.runtime_stats); runtime_handle.spawn( async move { while let Some(part) = source_stream.next().await { @@ -92,7 +90,7 @@ impl PipelineNode for SourceNode { }, self.name(), ); - Ok(rx.into()) + Ok(channel) } fn as_tree_display(&self) -> &dyn TreeDisplay { self From 8dde19260ee33d79e2f79c0f6dfef8b7a8825117 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 26 Aug 2024 18:00:42 -0700 Subject: [PATCH 4/6] pipline channel --- .../src/intermediate_ops/aggregate.rs | 4 +-- .../src/intermediate_ops/filter.rs | 4 +-- .../src/intermediate_ops/hash_join_probe.rs | 12 ++++----- .../src/intermediate_ops/intermediate_op.rs | 26 ++++++++++++------- .../src/intermediate_ops/project.rs | 4 +-- src/daft-local-execution/src/pipeline.rs | 4 +-- src/daft-local-execution/src/run.rs | 2 +- .../src/sinks/blocking_sink.rs | 2 +- .../src/sinks/streaming_sink.rs | 4 +-- 9 files changed, 34 insertions(+), 28 deletions(-) diff --git a/src/daft-local-execution/src/intermediate_ops/aggregate.rs b/src/daft-local-execution/src/intermediate_ops/aggregate.rs index c83bc3f540..39750d1401 100644 --- a/src/daft-local-execution/src/intermediate_ops/aggregate.rs +++ b/src/daft-local-execution/src/intermediate_ops/aggregate.rs @@ -29,10 +29,10 @@ impl IntermediateOperator for AggregateOperator { fn execute( &self, _idx: usize, - input: PipelineResultType, + input: &PipelineResultType, _state: Option<&mut Box>, ) -> DaftResult { - let out = input.data().agg(&self.agg_exprs, &self.group_by)?; + let out = input.as_data().agg(&self.agg_exprs, &self.group_by)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( out, )))) diff --git a/src/daft-local-execution/src/intermediate_ops/filter.rs b/src/daft-local-execution/src/intermediate_ops/filter.rs index ef85379ec6..eeb02c4aff 100644 --- a/src/daft-local-execution/src/intermediate_ops/filter.rs +++ b/src/daft-local-execution/src/intermediate_ops/filter.rs @@ -25,10 +25,10 @@ impl IntermediateOperator for FilterOperator { fn execute( &self, _idx: usize, - input: PipelineResultType, + input: &PipelineResultType, _state: Option<&mut Box>, ) -> DaftResult { - let out = input.data().filter(&[self.predicate.clone()])?; + let out = input.as_data().filter(&[self.predicate.clone()])?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( out, )))) diff --git a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs index 7b1bb04d22..1541cb2341 100644 --- a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs @@ -19,9 +19,9 @@ enum HashJoinProbeState { } impl HashJoinProbeState { - fn set_table(&mut self, table: Arc, tables: Arc>) { + fn set_table(&mut self, table: &Arc, tables: &Arc>) { if let HashJoinProbeState::Building = self { - *self = HashJoinProbeState::ReadyToProbe(table, tables); + *self = HashJoinProbeState::ReadyToProbe(table.clone(), tables.clone()); } else { panic!("HashJoinProbeState should only be in Building state when setting table") } @@ -29,7 +29,7 @@ impl HashJoinProbeState { fn probe( &self, - input: Arc, + input: &Arc, right_on: &[ExprRef], pruned_right_side_columns: &[String], ) -> DaftResult> { @@ -109,7 +109,7 @@ impl IntermediateOperator for HashJoinProbeOperator { fn execute( &self, idx: usize, - input: PipelineResultType, + input: &PipelineResultType, state: Option<&mut Box>, ) -> DaftResult { println!("HashJoinProbeOperator::execute: idx: {}", idx); @@ -120,7 +120,7 @@ impl IntermediateOperator for HashJoinProbeOperator { .as_any_mut() .downcast_mut::() .expect("HashJoinProbeOperator state should be HashJoinProbeState"); - let (probe_table, tables) = input.probe_table(); + let (probe_table, tables) = input.as_probe_table(); state.set_table(probe_table, tables); Ok(IntermediateOperatorResult::NeedMoreInput(None)) } @@ -130,7 +130,7 @@ impl IntermediateOperator for HashJoinProbeOperator { .as_any_mut() .downcast_mut::() .expect("HashJoinProbeOperator state should be HashJoinProbeState"); - let input = input.data(); + let input = input.as_data(); let out = state.probe(input, &self.right_on, &self.pruned_right_side_columns)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) } diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index 6c4253fe98..9abf35c6f6 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -30,7 +30,7 @@ pub trait IntermediateOperator: Send + Sync { fn execute( &self, idx: usize, - input: PipelineResultType, + input: &PipelineResultType, state: Option<&mut Box>, ) -> DaftResult; fn name(&self) -> &'static str; @@ -80,14 +80,20 @@ impl IntermediateNode { let span = info_span!("IntermediateOp::execute"); let mut state = op.make_state(); while let Some((idx, morsel)) = receiver.recv().await { - let result = rt_context.in_span(&span, || op.execute(idx, morsel, state.as_mut()))?; - match result { - IntermediateOperatorResult::NeedMoreInput(Some(mp)) => { - let _ = sender.send(mp.into()).await; - } - IntermediateOperatorResult::NeedMoreInput(None) => {} - IntermediateOperatorResult::HasMoreOutput(mp) => { - let _ = sender.send(mp.into()).await; + loop { + let result = + rt_context.in_span(&span, || op.execute(idx, &morsel, state.as_mut()))?; + match result { + IntermediateOperatorResult::NeedMoreInput(Some(mp)) => { + let _ = sender.send(mp.into()).await; + break; + } + IntermediateOperatorResult::NeedMoreInput(None) => { + break; + } + IntermediateOperatorResult::HasMoreOutput(mp) => { + let _ = sender.send(mp.into()).await; + } } } } @@ -140,7 +146,7 @@ impl IntermediateNode { let _ = worker_sender.send((idx, morsel.clone())).await; } } else { - buffer.push(morsel.data().clone()); + buffer.push(morsel.as_data().clone()); if let Some(ready) = buffer.try_clear() { let _ = send_to_next_worker(idx, ready?.into()).await; } diff --git a/src/daft-local-execution/src/intermediate_ops/project.rs b/src/daft-local-execution/src/intermediate_ops/project.rs index bc9f7f4eea..6f4b57ba00 100644 --- a/src/daft-local-execution/src/intermediate_ops/project.rs +++ b/src/daft-local-execution/src/intermediate_ops/project.rs @@ -25,10 +25,10 @@ impl IntermediateOperator for ProjectOperator { fn execute( &self, _idx: usize, - input: PipelineResultType, + input: &PipelineResultType, _state: Option<&mut Box>, ) -> DaftResult { - let out = input.data().eval_expression_list(&self.projection)?; + let out = input.as_data().eval_expression_list(&self.projection)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( out, )))) diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index ad9668637f..7b05472ed7 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -53,14 +53,14 @@ impl From<(Arc, Arc>)> for PipelineResultType { } impl PipelineResultType { - pub fn data(self) -> Arc { + pub fn as_data(&self) -> &Arc { match self { PipelineResultType::Data(data) => data, _ => panic!("Expected data"), } } - pub fn probe_table(self) -> (Arc, Arc>) { + pub fn as_probe_table(&self) -> (&Arc, &Arc>) { match self { PipelineResultType::ProbeTable(probe_table, tables) => (probe_table, tables), _ => panic!("Expected probe table"), diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs index 04726f86e5..6d9c53e01d 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -141,7 +141,7 @@ pub fn run_local( .await? .get_receiver(); while let Some(val) = receiver.recv().await { - let _ = tx.send(val.data()).await; + let _ = tx.send(val.as_data().clone()).await; } while let Some(result) = runtime_handle.join_next().await { diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index 63ea4c2cee..87ba335215 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -100,7 +100,7 @@ impl PipelineNode for BlockingSinkNode { let mut guard = op.lock().await; while let Some(val) = child_results_receiver.recv().await { if let BlockingSinkStatus::Finished = - rt_context.in_span(&span, || guard.sink(&val.data()))? + rt_context.in_span(&span, || guard.sink(val.as_data()))? { break; } diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index afe4be0d3b..dcf1b9db8d 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -108,9 +108,9 @@ impl PipelineNode for StreamingSinkNode { let mut sink = op.lock().await; let mut is_active = true; while is_active && let Some(val) = child_results_receiver.recv().await { - let val = val.data(); + let val = val.as_data(); loop { - let result = runtime_stats.in_span(&span, || sink.execute(0, &val))?; + let result = runtime_stats.in_span(&span, || sink.execute(0, val))?; match result { StreamSinkOutput::HasMoreOutput(mp) => { sender.send(mp.into()).await.unwrap(); From 05524b8d7f7a849b2c0352a31d12cc33bd29a5f0 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 26 Aug 2024 18:03:25 -0700 Subject: [PATCH 5/6] remove printlns --- .../src/intermediate_ops/hash_join_probe.rs | 1 - .../src/intermediate_ops/intermediate_op.rs | 1 - src/daft-local-execution/src/sinks/hash_join_build.rs | 2 +- 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs index 1541cb2341..b8a4053eb8 100644 --- a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs @@ -112,7 +112,6 @@ impl IntermediateOperator for HashJoinProbeOperator { input: &PipelineResultType, state: Option<&mut Box>, ) -> DaftResult { - println!("HashJoinProbeOperator::execute: idx: {}", idx); match idx { 0 => { let state = state diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index 9abf35c6f6..aec8e91be6 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -138,7 +138,6 @@ impl IntermediateNode { }; for (idx, mut receiver) in receivers.into_iter().enumerate() { - println!("idx: {}", idx); let mut buffer = OperatorBuffer::new(morsel_size); while let Some(morsel) = receiver.recv().await { if morsel.should_broadcast() { diff --git a/src/daft-local-execution/src/sinks/hash_join_build.rs b/src/daft-local-execution/src/sinks/hash_join_build.rs index 0a86756334..b7940d4ebe 100644 --- a/src/daft-local-execution/src/sinks/hash_join_build.rs +++ b/src/daft-local-execution/src/sinks/hash_join_build.rs @@ -95,6 +95,7 @@ impl BlockingSink for HashJoinBuildSink { self.probe_table_state.add_tables(input)?; Ok(BlockingSinkStatus::NeedMoreInput) } + fn finalize(&mut self) -> DaftResult> { self.probe_table_state.finalize()?; if let ProbeTableState::Done { @@ -102,7 +103,6 @@ impl BlockingSink for HashJoinBuildSink { tables, } = &self.probe_table_state { - println!("HashJoinBuildSink::finalize: table_len: {}", tables.len()); Ok(Some((probe_table.clone(), tables.clone()).into())) } else { panic!("finalize should only be called after the probe table is built") From 7c76de2ff969efdf45275af8aef89f748ffe83d9 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Thu, 29 Aug 2024 14:07:00 -0700 Subject: [PATCH 6/6] no async trait --- Cargo.lock | 1 - src/daft-local-execution/Cargo.toml | 1 - .../src/intermediate_ops/intermediate_op.rs | 14 +++++--------- src/daft-local-execution/src/pipeline.rs | 4 +--- src/daft-local-execution/src/run.rs | 5 +---- .../src/sinks/blocking_sink.rs | 7 ++----- .../src/sinks/streaming_sink.rs | 7 +++---- src/daft-local-execution/src/sources/source.rs | 5 +---- 8 files changed, 13 insertions(+), 31 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b547e28349..7c06eeb8ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1888,7 +1888,6 @@ dependencies = [ name = "daft-local-execution" version = "0.3.0-dev0" dependencies = [ - "async-trait", "common-daft-config", "common-display", "common-error", diff --git a/src/daft-local-execution/Cargo.toml b/src/daft-local-execution/Cargo.toml index 66949390f5..678ea0ccbe 100644 --- a/src/daft-local-execution/Cargo.toml +++ b/src/daft-local-execution/Cargo.toml @@ -1,5 +1,4 @@ [dependencies] -async-trait = {workspace = true} common-daft-config = {path = "../common/daft-config", default-features = false} common-display = {path = "../common/display", default-features = false} common-error = {path = "../common/error", default-features = false} diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index aec8e91be6..d4d6ea4456 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -5,8 +5,6 @@ use common_error::DaftResult; use daft_micropartition::MicroPartition; use tracing::{info_span, instrument}; -use async_trait::async_trait; - use crate::{ channel::{create_channel, PipelineChannel, Receiver, Sender}, pipeline::{PipelineNode, PipelineResultType}, @@ -100,7 +98,7 @@ impl IntermediateNode { Ok(()) } - pub async fn spawn_workers( + pub fn spawn_workers( &self, num_workers: usize, destination_channel: &mut PipelineChannel, @@ -185,7 +183,6 @@ impl TreeDisplay for IntermediateNode { } } -#[async_trait] impl PipelineNode for IntermediateNode { fn children(&self) -> Vec<&dyn PipelineNode> { self.children.iter().map(|v| v.as_ref()).collect() @@ -195,22 +192,21 @@ impl PipelineNode for IntermediateNode { self.intermediate_op.name() } - async fn start( + fn start( &mut self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, ) -> crate::Result { let mut child_result_receivers = Vec::with_capacity(self.children.len()); for child in self.children.iter_mut() { - let child_result_channel = child.start(maintain_order, runtime_handle).await?; + let child_result_channel = child.start(maintain_order, runtime_handle)?; child_result_receivers .push(child_result_channel.get_receiver_with_stats(&self.runtime_stats)); } let mut destination_channel = PipelineChannel::new(*NUM_CPUS, maintain_order); - let worker_senders = self - .spawn_workers(*NUM_CPUS, &mut destination_channel, runtime_handle) - .await; + let worker_senders = + self.spawn_workers(*NUM_CPUS, &mut destination_channel, runtime_handle); runtime_handle.spawn( Self::send_to_workers( child_result_receivers, diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index 7b05472ed7..f182b9f9cc 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -16,7 +16,6 @@ use crate::{ ExecutionRuntimeHandle, PipelineCreationSnafu, }; -use async_trait::async_trait; use common_display::{mermaid::MermaidDisplayVisitor, tree::TreeDisplay}; use common_error::DaftResult; use daft_core::{ @@ -72,11 +71,10 @@ impl PipelineResultType { } } -#[async_trait] pub trait PipelineNode: Sync + Send + TreeDisplay { fn children(&self) -> Vec<&dyn PipelineNode>; fn name(&self) -> &'static str; - async fn start( + fn start( &mut self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs index 6d9c53e01d..ba500c272b 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -136,10 +136,7 @@ pub fn run_local( .expect("Failed to create tokio runtime"); runtime.block_on(async { let mut runtime_handle = ExecutionRuntimeHandle::new(cfg.default_morsel_size); - let mut receiver = pipeline - .start(true, &mut runtime_handle) - .await? - .get_receiver(); + let mut receiver = pipeline.start(true, &mut runtime_handle)?.get_receiver(); while let Some(val) = receiver.recv().await { let _ = tx.send(val.as_data().clone()).await; } diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index 87ba335215..09e42ae81f 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -11,7 +11,6 @@ use crate::{ runtime_stats::RuntimeStatsContext, ExecutionRuntimeHandle, }; -use async_trait::async_trait; pub enum BlockingSinkStatus { NeedMoreInput, #[allow(dead_code)] @@ -68,7 +67,6 @@ impl TreeDisplay for BlockingSinkNode { } } -#[async_trait] impl PipelineNode for BlockingSinkNode { fn children(&self) -> Vec<&dyn PipelineNode> { vec![self.child.as_ref()] @@ -78,15 +76,14 @@ impl PipelineNode for BlockingSinkNode { self.name } - async fn start( + fn start( &mut self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, ) -> crate::Result { let child = self.child.as_mut(); let mut child_results_receiver = child - .start(false, runtime_handle) - .await? + .start(false, runtime_handle)? .get_receiver_with_stats(&self.runtime_stats); let mut destination_channel = PipelineChannel::new(1, maintain_order); diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index dcf1b9db8d..1804a3e07e 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -9,7 +9,7 @@ use crate::{ channel::PipelineChannel, pipeline::PipelineNode, runtime_stats::RuntimeStatsContext, ExecutionRuntimeHandle, NUM_CPUS, }; -use async_trait::async_trait; + pub enum StreamSinkOutput { NeedMoreInput(Option>), #[allow(dead_code)] @@ -73,7 +73,6 @@ impl TreeDisplay for StreamingSinkNode { } } -#[async_trait] impl PipelineNode for StreamingSinkNode { fn children(&self) -> Vec<&dyn PipelineNode> { self.children.iter().map(|v| v.as_ref()).collect() @@ -83,7 +82,7 @@ impl PipelineNode for StreamingSinkNode { self.name } - async fn start( + fn start( &mut self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, @@ -92,7 +91,7 @@ impl PipelineNode for StreamingSinkNode { .children .get_mut(0) .expect("we should only have 1 child"); - let child_results_channel = child.start(true, runtime_handle).await?; + let child_results_channel = child.start(true, runtime_handle)?; let mut child_results_receiver = child_results_channel.get_receiver_with_stats(&self.runtime_stats); diff --git a/src/daft-local-execution/src/sources/source.rs b/src/daft-local-execution/src/sources/source.rs index be7d6bef67..175dc66427 100644 --- a/src/daft-local-execution/src/sources/source.rs +++ b/src/daft-local-execution/src/sources/source.rs @@ -5,8 +5,6 @@ use daft_io::{IOStatsContext, IOStatsRef}; use daft_micropartition::MicroPartition; use futures::{stream::BoxStream, StreamExt}; -use async_trait::async_trait; - use crate::{ channel::PipelineChannel, pipeline::PipelineNode, runtime_stats::RuntimeStatsContext, ExecutionRuntimeHandle, @@ -62,7 +60,6 @@ impl TreeDisplay for SourceNode { } } -#[async_trait] impl PipelineNode for SourceNode { fn name(&self) -> &'static str { self.source.name() @@ -70,7 +67,7 @@ impl PipelineNode for SourceNode { fn children(&self) -> Vec<&dyn PipelineNode> { vec![] } - async fn start( + fn start( &mut self, maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle,