From c1c3182a2778f58232e63cbd54f7f6ca9049c2b4 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Thu, 19 Sep 2024 12:01:53 -0700 Subject: [PATCH 01/13] initial implementation --- Cargo.lock | 1 + src/arrow2/src/bitmap/mutable.rs | 19 + src/daft-local-execution/Cargo.toml | 1 + src/daft-local-execution/src/channel.rs | 98 ++-- .../anti_semi_hash_join_probe.rs | 13 +- .../src/intermediate_ops/hash_join_probe.rs | 273 ----------- .../intermediate_ops/inner_hash_join_probe.rs | 181 +++++++ .../src/intermediate_ops/intermediate_op.rs | 90 ++-- .../src/intermediate_ops/mod.rs | 2 +- src/daft-local-execution/src/pipeline.rs | 47 +- src/daft-local-execution/src/runtime_stats.rs | 9 +- .../src/sinks/blocking_sink.rs | 7 +- src/daft-local-execution/src/sinks/limit.rs | 86 +++- src/daft-local-execution/src/sinks/mod.rs | 1 + .../src/sinks/outer_hash_join_probe.rs | 451 ++++++++++++++++++ .../src/sinks/streaming_sink.rs | 197 ++++++-- .../src/sources/source.rs | 4 +- tests/cookbook/test_joins.py | 35 +- tests/dataframe/test_joins.py | 2 +- 19 files changed, 1022 insertions(+), 495 deletions(-) delete mode 100644 src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs create mode 100644 src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs create mode 100644 src/daft-local-execution/src/sinks/outer_hash_join_probe.rs diff --git a/Cargo.lock b/Cargo.lock index 90e4745ba8..a7fec8ec0e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1948,6 +1948,7 @@ dependencies = [ name = "daft-local-execution" version = "0.3.0-dev0" dependencies = [ + "arrow2", "common-daft-config", "common-display", "common-error", diff --git a/src/arrow2/src/bitmap/mutable.rs b/src/arrow2/src/bitmap/mutable.rs index cb77decd84..edf373a506 100644 --- a/src/arrow2/src/bitmap/mutable.rs +++ b/src/arrow2/src/bitmap/mutable.rs @@ -325,6 +325,25 @@ impl MutableBitmap { pub(crate) fn bitchunks_exact_mut(&mut self) -> BitChunksExactMut { BitChunksExactMut::new(&mut self.buffer, self.length) } + + pub fn or(&self, other: &MutableBitmap) -> MutableBitmap { + assert_eq!( + self.length, other.length, + "Bitmaps must have the same length" + ); + + let new_buffer: Vec = self + .buffer + .iter() + .zip(other.buffer.iter()) + .map(|(&a, &b)| a | b) // Apply bitwise OR on each pair of bytes + .collect(); + + MutableBitmap { + buffer: new_buffer, + length: self.length, + } + } } impl From for Bitmap { diff --git a/src/daft-local-execution/Cargo.toml b/src/daft-local-execution/Cargo.toml index 2dac516672..bf49c9fa9d 100644 --- a/src/daft-local-execution/Cargo.toml +++ b/src/daft-local-execution/Cargo.toml @@ -1,4 +1,5 @@ [dependencies] +arrow2 = {workspace = true} common-daft-config = {path = "../common/daft-config", default-features = false} common-display = {path = "../common/display", default-features = false} common-error = {path = "../common/error", default-features = false} diff --git a/src/daft-local-execution/src/channel.rs b/src/daft-local-execution/src/channel.rs index bb22b9d4ea..84f9bbe8a3 100644 --- a/src/daft-local-execution/src/channel.rs +++ b/src/daft-local-execution/src/channel.rs @@ -13,86 +13,64 @@ pub fn create_channel(buffer_size: usize) -> (Sender, Receiver) { } pub struct PipelineChannel { - sender: PipelineSender, - receiver: PipelineReceiver, + sender: Sender, + receiver: Receiver, } impl PipelineChannel { - pub fn new(buffer_size: usize, in_order: bool) -> Self { - match in_order { - true => { - let (senders, receivers) = (0..buffer_size).map(|_| create_channel(1)).unzip(); - let sender = PipelineSender::InOrder(RoundRobinSender::new(senders)); - let receiver = PipelineReceiver::InOrder(RoundRobinReceiver::new(receivers)); - Self { sender, receiver } - } - false => { - let (sender, receiver) = create_channel(buffer_size); - let sender = PipelineSender::OutOfOrder(sender); - let receiver = PipelineReceiver::OutOfOrder(receiver); - Self { sender, receiver } - } - } + pub(crate) fn new() -> Self { + let (sender, receiver) = create_channel(1); + Self { sender, receiver } } - fn get_next_sender(&mut self) -> Sender { - match &mut self.sender { - PipelineSender::InOrder(rr) => rr.get_next_sender(), - PipelineSender::OutOfOrder(sender) => sender.clone(), - } + pub(crate) fn get_sender_with_stats(&self, stats: &Arc) -> CountingSender { + CountingSender::new(self.sender.clone(), stats.clone()) } - pub(crate) fn get_next_sender_with_stats( - &mut self, - rt: &Arc, - ) -> CountingSender { - CountingSender::new(self.get_next_sender(), rt.clone()) + pub(crate) fn get_receiver_with_stats( + self, + stats: &Arc, + ) -> CountingReceiver { + CountingReceiver::new(self.receiver, stats.clone()) } - pub fn get_receiver(self) -> PipelineReceiver { + pub(crate) fn get_receiver(self) -> Receiver { self.receiver } - - pub(crate) fn get_receiver_with_stats(self, rt: &Arc) -> CountingReceiver { - CountingReceiver::new(self.get_receiver(), rt.clone()) - } -} - -pub enum PipelineSender { - InOrder(RoundRobinSender), - OutOfOrder(Sender), } -pub struct RoundRobinSender { - senders: Vec>, - curr_sender_idx: usize, -} - -impl RoundRobinSender { - pub fn new(senders: Vec>) -> Self { - Self { - senders, - curr_sender_idx: 0, +pub(crate) fn make_ordering_aware_channel( + buffer_size: usize, + ordered: bool, +) -> (Vec>, OrderingAwareReceiver) { + match ordered { + true => { + let (senders, receivers) = (0..buffer_size).map(|_| create_channel(1)).unzip(); + ( + senders, + OrderingAwareReceiver::Ordered(RoundRobinReceiver::new(receivers)), + ) + } + false => { + let (sender, receiver) = create_channel(buffer_size); + ( + (0..buffer_size).map(|_| sender.clone()).collect(), + OrderingAwareReceiver::Unordered(receiver), + ) } - } - - pub fn get_next_sender(&mut self) -> Sender { - let next_idx = self.curr_sender_idx; - self.curr_sender_idx = (next_idx + 1) % self.senders.len(); - self.senders[next_idx].clone() } } -pub enum PipelineReceiver { - InOrder(RoundRobinReceiver), - OutOfOrder(Receiver), +pub enum OrderingAwareReceiver { + Ordered(RoundRobinReceiver), + Unordered(Receiver), } -impl PipelineReceiver { - pub async fn recv(&mut self) -> Option { +impl OrderingAwareReceiver { + pub async fn recv(&mut self) -> Option { match self { - PipelineReceiver::InOrder(rr) => rr.recv().await, - PipelineReceiver::OutOfOrder(r) => r.recv().await, + OrderingAwareReceiver::Ordered(rr_receiver) => rr_receiver.recv().await, + OrderingAwareReceiver::Unordered(receiver) => receiver.recv().await, } } } diff --git a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs index 14bc949ff0..637beb3fc5 100644 --- a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs @@ -43,14 +43,14 @@ impl IntermediateOperatorState for AntiSemiProbeState { pub struct AntiSemiProbeOperator { probe_on: Vec, - join_type: JoinType, + is_semi: bool, } impl AntiSemiProbeOperator { - pub fn new(probe_on: Vec, join_type: JoinType) -> Self { + pub fn new(probe_on: Vec, join_type: &JoinType) -> Self { Self { probe_on, - join_type, + is_semi: *join_type == JoinType::Semi, } } @@ -76,7 +76,7 @@ impl AntiSemiProbeOperator { let iter = probe_set.probe_exists(&join_keys)?; for (probe_row_idx, matched) in iter.enumerate() { - match (self.join_type == JoinType::Semi, matched) { + match (self.is_semi, matched) { (true, true) | (false, false) => { probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); } @@ -120,10 +120,7 @@ impl IntermediateOperator for AntiSemiProbeOperator { .downcast_mut::() .expect("AntiSemiProbeOperator state should be AntiSemiProbeState"); let input = input.as_data(); - let out = match self.join_type { - JoinType::Semi | JoinType::Anti => self.probe_anti_semi(input, state), - _ => unreachable!("Only Semi and Anti joins are supported"), - }?; + let out = self.probe_anti_semi(input, state)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) } } diff --git a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs deleted file mode 100644 index f849064964..0000000000 --- a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs +++ /dev/null @@ -1,273 +0,0 @@ -use std::sync::Arc; - -use common_error::DaftResult; -use daft_core::prelude::SchemaRef; -use daft_dsl::ExprRef; -use daft_micropartition::MicroPartition; -use daft_plan::JoinType; -use daft_table::{GrowableTable, Probeable, Table}; -use indexmap::IndexSet; -use tracing::{info_span, instrument}; - -use super::intermediate_op::{ - IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, -}; -use crate::pipeline::PipelineResultType; - -enum HashJoinProbeState { - Building, - ReadyToProbe(Arc, Arc>), -} - -impl HashJoinProbeState { - fn set_table(&mut self, table: &Arc, tables: &Arc>) { - if let HashJoinProbeState::Building = self { - *self = HashJoinProbeState::ReadyToProbe(table.clone(), tables.clone()); - } else { - panic!("HashJoinProbeState should only be in Building state when setting table") - } - } - - fn get_probeable_and_table(&self) -> (&Arc, &Arc>) { - if let HashJoinProbeState::ReadyToProbe(probe_table, tables) = self { - (probe_table, tables) - } else { - panic!("get_probeable_and_table can only be used during the ReadyToProbe Phase") - } - } -} - -impl IntermediateOperatorState for HashJoinProbeState { - fn as_any_mut(&mut self) -> &mut dyn std::any::Any { - self - } -} - -pub struct HashJoinProbeOperator { - probe_on: Vec, - common_join_keys: Vec, - left_non_join_columns: Vec, - right_non_join_columns: Vec, - join_type: JoinType, - build_on_left: bool, -} - -impl HashJoinProbeOperator { - pub fn new( - probe_on: Vec, - left_schema: &SchemaRef, - right_schema: &SchemaRef, - join_type: JoinType, - build_on_left: bool, - common_join_keys: IndexSet, - ) -> Self { - let (common_join_keys, left_non_join_columns, right_non_join_columns) = match join_type { - JoinType::Inner | JoinType::Left | JoinType::Right => { - let left_non_join_columns = left_schema - .fields - .keys() - .filter(|c| !common_join_keys.contains(*c)) - .cloned() - .collect(); - let right_non_join_columns = right_schema - .fields - .keys() - .filter(|c| !common_join_keys.contains(*c)) - .cloned() - .collect(); - ( - common_join_keys.into_iter().collect(), - left_non_join_columns, - right_non_join_columns, - ) - } - _ => { - panic!("Semi, Anti, and join are not supported in HashJoinProbeOperator") - } - }; - Self { - probe_on, - common_join_keys, - left_non_join_columns, - right_non_join_columns, - join_type, - build_on_left, - } - } - - fn probe_inner( - &self, - input: &Arc, - state: &mut HashJoinProbeState, - ) -> DaftResult> { - let (probe_table, tables) = state.get_probeable_and_table(); - - let _growables = info_span!("HashJoinOperator::build_growables").entered(); - - let mut build_side_growable = - GrowableTable::new(&tables.iter().collect::>(), false, 20)?; - - let input_tables = input.get_tables()?; - - let mut probe_side_growable = - GrowableTable::new(&input_tables.iter().collect::>(), false, 20)?; - - drop(_growables); - { - let _loop = info_span!("HashJoinOperator::eval_and_probe").entered(); - for (probe_side_table_idx, table) in input_tables.iter().enumerate() { - // we should emit one table at a time when this is streaming - let join_keys = table.eval_expression_list(&self.probe_on)?; - let idx_mapper = probe_table.probe_indices(&join_keys)?; - - for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { - if let Some(inner_iter) = inner_iter { - for (build_side_table_idx, build_row_idx) in inner_iter { - build_side_growable.extend( - build_side_table_idx as usize, - build_row_idx as usize, - 1, - ); - // we can perform run length compression for this to make this more efficient - probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); - } - } - } - } - } - let build_side_table = build_side_growable.build()?; - let probe_side_table = probe_side_growable.build()?; - - let (left_table, right_table) = if self.build_on_left { - (build_side_table, probe_side_table) - } else { - (probe_side_table, build_side_table) - }; - - let join_keys_table = left_table.get_columns(&self.common_join_keys)?; - let left_non_join_columns = left_table.get_columns(&self.left_non_join_columns)?; - let right_non_join_columns = right_table.get_columns(&self.right_non_join_columns)?; - let final_table = join_keys_table - .union(&left_non_join_columns)? - .union(&right_non_join_columns)?; - - Ok(Arc::new(MicroPartition::new_loaded( - final_table.schema.clone(), - Arc::new(vec![final_table]), - None, - ))) - } - - fn probe_left_right( - &self, - input: &Arc, - state: &mut HashJoinProbeState, - ) -> DaftResult> { - let (probe_table, tables) = state.get_probeable_and_table(); - - let _growables = info_span!("HashJoinOperator::build_growables").entered(); - - let mut build_side_growable = GrowableTable::new( - &tables.iter().collect::>(), - true, - tables.iter().map(|t| t.len()).sum(), - )?; - - let input_tables = input.get_tables()?; - - let mut probe_side_growable = - GrowableTable::new(&input_tables.iter().collect::>(), false, input.len())?; - - drop(_growables); - { - let _loop = info_span!("HashJoinOperator::eval_and_probe").entered(); - for (probe_side_table_idx, table) in input_tables.iter().enumerate() { - let join_keys = table.eval_expression_list(&self.probe_on)?; - let idx_mapper = probe_table.probe_indices(&join_keys)?; - - for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { - if let Some(inner_iter) = inner_iter { - for (build_side_table_idx, build_row_idx) in inner_iter { - build_side_growable.extend( - build_side_table_idx as usize, - build_row_idx as usize, - 1, - ); - probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); - } - } else { - // if there's no match, we should still emit the probe side and fill the build side with nulls - build_side_growable.add_nulls(1); - probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); - } - } - } - } - let build_side_table = build_side_growable.build()?; - let probe_side_table = probe_side_growable.build()?; - - let final_table = if self.join_type == JoinType::Left { - let join_table = probe_side_table.get_columns(&self.common_join_keys)?; - let left = probe_side_table.get_columns(&self.left_non_join_columns)?; - let right = build_side_table.get_columns(&self.right_non_join_columns)?; - join_table.union(&left)?.union(&right)? - } else { - let join_table = probe_side_table.get_columns(&self.common_join_keys)?; - let left = build_side_table.get_columns(&self.left_non_join_columns)?; - let right = probe_side_table.get_columns(&self.right_non_join_columns)?; - join_table.union(&left)?.union(&right)? - }; - Ok(Arc::new(MicroPartition::new_loaded( - final_table.schema.clone(), - Arc::new(vec![final_table]), - None, - ))) - } -} - -impl IntermediateOperator for HashJoinProbeOperator { - #[instrument(skip_all, name = "HashJoinOperator::execute")] - fn execute( - &self, - idx: usize, - input: &PipelineResultType, - state: Option<&mut Box>, - ) -> DaftResult { - match idx { - 0 => { - let state = state - .expect("HashJoinProbeOperator should have state") - .as_any_mut() - .downcast_mut::() - .expect("HashJoinProbeOperator state should be HashJoinProbeState"); - let (probe_table, tables) = input.as_probe_table(); - state.set_table(probe_table, tables); - Ok(IntermediateOperatorResult::NeedMoreInput(None)) - } - _ => { - let state = state - .expect("HashJoinProbeOperator should have state") - .as_any_mut() - .downcast_mut::() - .expect("HashJoinProbeOperator state should be HashJoinProbeState"); - let input = input.as_data(); - let out = match self.join_type { - JoinType::Inner => self.probe_inner(input, state), - JoinType::Left | JoinType::Right => self.probe_left_right(input, state), - _ => { - unimplemented!("Only Inner, Left, and Right joins are supported in HashJoinProbeOperator") - } - }?; - Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) - } - } - } - - fn name(&self) -> &'static str { - "HashJoinProbeOperator" - } - - fn make_state(&self) -> Option> { - Some(Box::new(HashJoinProbeState::Building)) - } -} diff --git a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs new file mode 100644 index 0000000000..e102804770 --- /dev/null +++ b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs @@ -0,0 +1,181 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_core::prelude::SchemaRef; +use daft_dsl::ExprRef; +use daft_micropartition::MicroPartition; +use daft_table::{GrowableTable, Probeable, Table}; +use indexmap::IndexSet; +use tracing::{info_span, instrument}; + +use super::intermediate_op::{ + IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, +}; +use crate::pipeline::PipelineResultType; + +enum InnerHashJoinProbeState { + Building, + ReadyToProbe(Arc, Arc>), +} + +impl InnerHashJoinProbeState { + fn set_table(&mut self, table: &Arc, tables: &Arc>) { + if let InnerHashJoinProbeState::Building = self { + *self = InnerHashJoinProbeState::ReadyToProbe(table.clone(), tables.clone()); + } else { + panic!("InnerHashJoinProbeState should only be in Building state when setting table") + } + } + + fn get_probeable_and_table(&self) -> (&Arc, &Arc>) { + if let InnerHashJoinProbeState::ReadyToProbe(probe_table, tables) = self { + (probe_table, tables) + } else { + panic!("get_probeable_and_table can only be used during the ReadyToProbe Phase") + } + } +} + +impl IntermediateOperatorState for InnerHashJoinProbeState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} + +pub struct InnerHashJoinProbeOperator { + probe_on: Vec, + common_join_keys: Vec, + left_non_join_columns: Vec, + right_non_join_columns: Vec, + build_on_left: bool, +} + +impl InnerHashJoinProbeOperator { + pub fn new( + probe_on: Vec, + left_schema: &SchemaRef, + right_schema: &SchemaRef, + build_on_left: bool, + common_join_keys: IndexSet, + ) -> Self { + let left_non_join_columns = left_schema + .fields + .keys() + .filter(|c| !common_join_keys.contains(*c)) + .cloned() + .collect(); + let right_non_join_columns = right_schema + .fields + .keys() + .filter(|c| !common_join_keys.contains(*c)) + .cloned() + .collect(); + let common_join_keys = common_join_keys.into_iter().collect(); + Self { + probe_on, + common_join_keys, + left_non_join_columns, + right_non_join_columns, + build_on_left, + } + } + + fn probe_inner( + &self, + input: &Arc, + state: &mut InnerHashJoinProbeState, + ) -> DaftResult> { + let (probe_table, tables) = state.get_probeable_and_table(); + + let _growables = info_span!("InnerHashJoinOperator::build_growables").entered(); + + let mut build_side_growable = + GrowableTable::new(&tables.iter().collect::>(), false, 20)?; + + let input_tables = input.get_tables()?; + + let mut probe_side_growable = + GrowableTable::new(&input_tables.iter().collect::>(), false, 20)?; + + drop(_growables); + { + let _loop = info_span!("InnerHashJoinOperator::eval_and_probe").entered(); + for (probe_side_table_idx, table) in input_tables.iter().enumerate() { + // we should emit one table at a time when this is streaming + let join_keys = table.eval_expression_list(&self.probe_on)?; + let idx_mapper = probe_table.probe_indices(&join_keys)?; + + for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { + if let Some(inner_iter) = inner_iter { + for (build_side_table_idx, build_row_idx) in inner_iter { + build_side_growable.extend( + build_side_table_idx as usize, + build_row_idx as usize, + 1, + ); + // we can perform run length compression for this to make this more efficient + probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); + } + } + } + } + } + let build_side_table = build_side_growable.build()?; + let probe_side_table = probe_side_growable.build()?; + + let (left_table, right_table) = if self.build_on_left { + (build_side_table, probe_side_table) + } else { + (probe_side_table, build_side_table) + }; + + let join_keys_table = left_table.get_columns(&self.common_join_keys)?; + let left_non_join_columns = left_table.get_columns(&self.left_non_join_columns)?; + let right_non_join_columns = right_table.get_columns(&self.right_non_join_columns)?; + let final_table = join_keys_table + .union(&left_non_join_columns)? + .union(&right_non_join_columns)?; + + Ok(Arc::new(MicroPartition::new_loaded( + final_table.schema.clone(), + Arc::new(vec![final_table]), + None, + ))) + } +} + +impl IntermediateOperator for InnerHashJoinProbeOperator { + #[instrument(skip_all, name = "InnerHashJoinOperator::execute")] + fn execute( + &self, + idx: usize, + input: &PipelineResultType, + state: Option<&mut Box>, + ) -> DaftResult { + let state = state + .expect("InnerHashJoinProbeOperator should have state") + .as_any_mut() + .downcast_mut::() + .expect("InnerHashJoinProbeOperator state should be InnerHashJoinProbeState"); + match idx { + 0 => { + let (probe_table, tables) = input.as_probe_table(); + state.set_table(probe_table, tables); + Ok(IntermediateOperatorResult::NeedMoreInput(None)) + } + _ => { + let input = input.as_data(); + let out = self.probe_inner(input, state)?; + Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) + } + } + } + + fn name(&self) -> &'static str { + "InnerHashJoinProbeOperator" + } + + fn make_state(&self) -> Option> { + Some(Box::new(InnerHashJoinProbeState::Building)) + } +} diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index abb5c5388b..8f561c94b7 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -7,9 +7,9 @@ use tracing::{info_span, instrument}; use super::buffer::OperatorBuffer; use crate::{ - channel::{create_channel, PipelineChannel, Receiver, Sender}, + channel::{create_channel, make_ordering_aware_channel, PipelineChannel, Receiver, Sender}, pipeline::{PipelineNode, PipelineResultType}, - runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, + runtime_stats::{CountingReceiver, RuntimeStatsContext}, ExecutionRuntimeHandle, NUM_CPUS, }; @@ -68,28 +68,28 @@ impl IntermediateNode { } #[instrument(level = "info", skip_all, name = "IntermediateOperator::run_worker")] - pub async fn run_worker( + async fn run_worker( op: Arc, - mut receiver: Receiver<(usize, PipelineResultType)>, - sender: CountingSender, + mut input_receiver: Receiver<(usize, PipelineResultType)>, + output_sender: Sender>, rt_context: Arc, ) -> DaftResult<()> { let span = info_span!("IntermediateOp::execute"); let mut state = op.make_state(); - while let Some((idx, morsel)) = receiver.recv().await { + while let Some((idx, morsel)) = input_receiver.recv().await { loop { let result = rt_context.in_span(&span, || op.execute(idx, &morsel, state.as_mut()))?; match result { IntermediateOperatorResult::NeedMoreInput(Some(mp)) => { - let _ = sender.send(mp.into()).await; + let _ = output_sender.send(mp).await; break; } IntermediateOperatorResult::NeedMoreInput(None) => { break; } IntermediateOperatorResult::HasMoreOutput(mp) => { - let _ = sender.send(mp.into()).await; + let _ = output_sender.send(mp).await; } } } @@ -97,32 +97,24 @@ impl IntermediateNode { Ok(()) } - pub fn spawn_workers( - &self, - num_workers: usize, - destination_channel: &mut PipelineChannel, - runtime_handle: &mut ExecutionRuntimeHandle, - ) -> Vec> { - let mut worker_senders = Vec::with_capacity(num_workers); - for _ in 0..num_workers { - let (worker_sender, worker_receiver) = create_channel(1); - let destination_sender = - destination_channel.get_next_sender_with_stats(&self.runtime_stats); - runtime_handle.spawn( - Self::run_worker( - self.intermediate_op.clone(), - worker_receiver, - destination_sender, - self.runtime_stats.clone(), - ), - self.intermediate_op.name(), - ); - worker_senders.push(worker_sender); + fn spawn_workers( + op: Arc, + input_receivers: Vec>, + output_senders: Vec>>, + worker_set: &mut tokio::task::JoinSet>, + stats: Arc, + ) { + for (input_receiver, output_sender) in input_receivers.into_iter().zip(output_senders) { + worker_set.spawn(Self::run_worker( + op.clone(), + input_receiver, + output_sender, + stats.clone(), + )); } - worker_senders } - pub async fn send_to_workers( + async fn forward_input_to_workers( receivers: Vec, worker_senders: Vec>, morsel_size: usize, @@ -202,16 +194,36 @@ impl PipelineNode for IntermediateNode { child_result_receivers .push(child_result_channel.get_receiver_with_stats(&self.runtime_stats)); } - let mut destination_channel = PipelineChannel::new(*NUM_CPUS, maintain_order); - let worker_senders = - self.spawn_workers(*NUM_CPUS, &mut destination_channel, runtime_handle); + let destination_channel = PipelineChannel::new(); + let destination_sender = destination_channel.get_sender_with_stats(&self.runtime_stats); + + let op = self.intermediate_op.clone(); + let stats = self.runtime_stats.clone(); + let morsel_size = runtime_handle.default_morsel_size(); runtime_handle.spawn( - Self::send_to_workers( - child_result_receivers, - worker_senders, - runtime_handle.default_morsel_size(), - ), + async move { + let num_workers = *NUM_CPUS; + let (input_senders, input_receivers) = + (0..num_workers).map(|_| create_channel(1)).unzip(); + let (output_senders, mut output_receiver) = + make_ordering_aware_channel(num_workers, maintain_order); + let mut worker_set = tokio::task::JoinSet::new(); + Self::spawn_workers( + op.clone(), + input_receivers, + output_senders, + &mut worker_set, + stats.clone(), + ); + Self::forward_input_to_workers(child_result_receivers, input_senders, morsel_size) + .await?; + + while let Some(morsel) = output_receiver.recv().await { + let _ = destination_sender.send(morsel.into()).await; + } + Ok(()) + }, self.intermediate_op.name(), ); Ok(destination_channel) diff --git a/src/daft-local-execution/src/intermediate_ops/mod.rs b/src/daft-local-execution/src/intermediate_ops/mod.rs index 593f9ef5ed..21c505abda 100644 --- a/src/daft-local-execution/src/intermediate_ops/mod.rs +++ b/src/daft-local-execution/src/intermediate_ops/mod.rs @@ -2,6 +2,6 @@ pub mod aggregate; pub mod anti_semi_hash_join_probe; pub mod buffer; pub mod filter; -pub mod hash_join_probe; +pub mod inner_hash_join_probe; pub mod intermediate_op; pub mod project; diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index 62935dfff0..7f990fc63d 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -22,12 +22,13 @@ use crate::{ channel::PipelineChannel, intermediate_ops::{ aggregate::AggregateOperator, anti_semi_hash_join_probe::AntiSemiProbeOperator, - filter::FilterOperator, hash_join_probe::HashJoinProbeOperator, + filter::FilterOperator, inner_hash_join_probe::InnerHashJoinProbeOperator, intermediate_op::IntermediateNode, project::ProjectOperator, }, sinks::{ aggregate::AggregateSink, blocking_sink::BlockingSinkNode, - hash_join_build::HashJoinBuildSink, limit::LimitSink, sort::SortSink, + hash_join_build::HashJoinBuildSink, limit::LimitSink, + outer_hash_join_probe::OuterHashJoinProbeSink, sort::SortSink, streaming_sink::StreamingSinkNode, }, sources::in_memory::InMemorySource, @@ -131,7 +132,7 @@ pub fn physical_plan_to_pipeline( }) => { let sink = LimitSink::new(*num_rows as usize); let child_node = physical_plan_to_pipeline(input, psets)?; - StreamingSinkNode::new(sink.boxed(), vec![child_node]).boxed() + StreamingSinkNode::new(Arc::new(sink), vec![child_node]).boxed() } LocalPhysicalPlan::Concat(_) => { todo!("concat") @@ -238,11 +239,9 @@ pub fn physical_plan_to_pipeline( let build_on_left = match join_type { JoinType::Inner => true, JoinType::Right => true, + JoinType::Outer => true, JoinType::Left => false, JoinType::Anti | JoinType::Semi => false, - JoinType::Outer => { - unimplemented!("Outer join not supported yet"); - } }; let (build_on, probe_on, build_child, probe_child) = match build_on_left { true => (left_on, right_on, left, right), @@ -251,7 +250,7 @@ pub fn physical_plan_to_pipeline( let build_schema = build_child.schema(); let probe_schema = probe_child.schema(); - let probe_node = || -> DaftResult<_> { + || -> DaftResult<_> { let common_join_keys: IndexSet<_> = get_common_join_keys(left_on, right_on) .map(|k| k.to_string()) .collect(); @@ -297,32 +296,40 @@ pub fn physical_plan_to_pipeline( let probe_child_node = physical_plan_to_pipeline(probe_child, psets)?; match join_type { - JoinType::Anti | JoinType::Semi => DaftResult::Ok(IntermediateNode::new( - Arc::new(AntiSemiProbeOperator::new(casted_probe_on, *join_type)), + JoinType::Anti | JoinType::Semi => Ok(IntermediateNode::new( + Arc::new(AntiSemiProbeOperator::new(casted_probe_on, join_type)), + vec![build_node, probe_child_node], + ) + .boxed()), + JoinType::Inner => Ok(IntermediateNode::new( + Arc::new(InnerHashJoinProbeOperator::new( + casted_probe_on, + left_schema, + right_schema, + build_on_left, + common_join_keys, + )), vec![build_node, probe_child_node], - )), - JoinType::Inner | JoinType::Left | JoinType::Right => { - DaftResult::Ok(IntermediateNode::new( - Arc::new(HashJoinProbeOperator::new( + ) + .boxed()), + JoinType::Left | JoinType::Right | JoinType::Outer => { + Ok(StreamingSinkNode::new( + Arc::new(OuterHashJoinProbeSink::new( casted_probe_on, left_schema, right_schema, *join_type, - build_on_left, common_join_keys, )), vec![build_node, probe_child_node], - )) - } - JoinType::Outer => { - unimplemented!("Outer join not supported yet"); + ) + .boxed()) } } }() .with_context(|_| PipelineCreationSnafu { plan_name: physical_plan.name(), - })?; - probe_node.boxed() + })? } _ => { unimplemented!("Physical plan not supported: {}", physical_plan.name()); diff --git a/src/daft-local-execution/src/runtime_stats.rs b/src/daft-local-execution/src/runtime_stats.rs index 7489a8fd36..cc8b8a368a 100644 --- a/src/daft-local-execution/src/runtime_stats.rs +++ b/src/daft-local-execution/src/runtime_stats.rs @@ -8,7 +8,7 @@ use std::{ use tokio::sync::mpsc::error::SendError; use crate::{ - channel::{PipelineReceiver, Sender}, + channel::{Receiver, Sender}, pipeline::PipelineResultType, }; @@ -133,12 +133,15 @@ impl CountingSender { } pub(crate) struct CountingReceiver { - receiver: PipelineReceiver, + receiver: Receiver, rt: Arc, } impl CountingReceiver { - pub(crate) fn new(receiver: PipelineReceiver, rt: Arc) -> Self { + pub(crate) fn new( + receiver: Receiver, + rt: Arc, + ) -> Self { Self { receiver, rt } } #[inline] diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index 09e42ae81f..00012cb2f1 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -78,7 +78,7 @@ impl PipelineNode for BlockingSinkNode { fn start( &mut self, - maintain_order: bool, + _maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, ) -> crate::Result { let child = self.child.as_mut(); @@ -86,9 +86,8 @@ impl PipelineNode for BlockingSinkNode { .start(false, runtime_handle)? .get_receiver_with_stats(&self.runtime_stats); - let mut destination_channel = PipelineChannel::new(1, maintain_order); - let destination_sender = - destination_channel.get_next_sender_with_stats(&self.runtime_stats); + let destination_channel = PipelineChannel::new(); + let destination_sender = destination_channel.get_sender_with_stats(&self.runtime_stats); let op = self.op.clone(); let rt_context = self.runtime_stats.clone(); runtime_handle.spawn( diff --git a/src/daft-local-execution/src/sinks/limit.rs b/src/daft-local-execution/src/sinks/limit.rs index 91435961d5..5dcb70d109 100644 --- a/src/daft-local-execution/src/sinks/limit.rs +++ b/src/daft-local-execution/src/sinks/limit.rs @@ -4,51 +4,76 @@ use common_error::DaftResult; use daft_micropartition::MicroPartition; use tracing::instrument; -use super::streaming_sink::{StreamSinkOutput, StreamingSink}; +use super::streaming_sink::{StreamingSink, StreamingSinkOutput, StreamingSinkState}; +use crate::pipeline::PipelineResultType; + +struct LimitSinkState { + remaining: usize, +} + +impl LimitSinkState { + fn new(remaining: usize) -> Self { + Self { remaining } + } + + fn get_remaining(&self) -> usize { + self.remaining + } + + fn set_remaining(&mut self, remaining: usize) { + self.remaining = remaining; + } +} + +impl StreamingSinkState for LimitSinkState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} pub struct LimitSink { - #[allow(dead_code)] limit: usize, - remaining: usize, } impl LimitSink { pub fn new(limit: usize) -> Self { - Self { - limit, - remaining: limit, - } - } - pub fn boxed(self) -> Box { - Box::new(self) + Self { limit } } } impl StreamingSink for LimitSink { #[instrument(skip_all, name = "LimitSink::sink")] fn execute( - &mut self, + &self, index: usize, - input: &Arc, - ) -> DaftResult { + input: &PipelineResultType, + state: &mut dyn StreamingSinkState, + ) -> DaftResult { assert_eq!(index, 0); - + let state = state + .as_any_mut() + .downcast_mut::() + .expect("Limit Sink should have LimitSinkState"); + let input = input.as_data(); let input_num_rows = input.len(); - + let mut remaining = state.get_remaining(); use std::cmp::Ordering::*; - match input_num_rows.cmp(&self.remaining) { + match input_num_rows.cmp(&remaining) { Less => { - self.remaining -= input_num_rows; - Ok(StreamSinkOutput::NeedMoreInput(Some(input.clone()))) + remaining -= input_num_rows; + state.set_remaining(remaining); + Ok(StreamingSinkOutput::NeedMoreInput(Some(input.clone()))) } Equal => { - self.remaining = 0; - Ok(StreamSinkOutput::Finished(Some(input.clone()))) + remaining = 0; + state.set_remaining(remaining); + Ok(StreamingSinkOutput::Finished(Some(input.clone()))) } Greater => { - let taken = input.head(self.remaining)?; - self.remaining -= taken.len(); - Ok(StreamSinkOutput::Finished(Some(Arc::new(taken)))) + let taken = input.head(remaining)?; + remaining -= taken.len(); + state.set_remaining(remaining); + Ok(StreamingSinkOutput::Finished(Some(Arc::new(taken)))) } } } @@ -56,4 +81,19 @@ impl StreamingSink for LimitSink { fn name(&self) -> &'static str { "Limit" } + + fn finalize( + &self, + _states: Vec>, + ) -> DaftResult>> { + Ok(None) + } + + fn make_state(&self) -> Box { + Box::new(LimitSinkState::new(self.limit)) + } + + fn max_concurrency(&self) -> usize { + 1 + } } diff --git a/src/daft-local-execution/src/sinks/mod.rs b/src/daft-local-execution/src/sinks/mod.rs index 39910e7995..7960e55a7c 100644 --- a/src/daft-local-execution/src/sinks/mod.rs +++ b/src/daft-local-execution/src/sinks/mod.rs @@ -3,5 +3,6 @@ pub mod blocking_sink; pub mod concat; pub mod hash_join_build; pub mod limit; +pub mod outer_hash_join_probe; pub mod sort; pub mod streaming_sink; diff --git a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs new file mode 100644 index 0000000000..fdbe559277 --- /dev/null +++ b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs @@ -0,0 +1,451 @@ +use std::sync::Arc; + +use arrow2::bitmap::MutableBitmap; +use common_error::DaftResult; +use daft_core::{ + prelude::{Field, Schema, SchemaRef}, + series::Series, +}; +use daft_dsl::ExprRef; +use daft_micropartition::MicroPartition; +use daft_plan::JoinType; +use daft_table::{GrowableTable, Probeable, Table}; +use indexmap::IndexSet; +use tracing::{info_span, instrument}; + +use super::streaming_sink::{StreamingSink, StreamingSinkOutput, StreamingSinkState}; +use crate::pipeline::PipelineResultType; + +struct IndexBitmapTracker { + bitmaps: Vec, +} + +impl IndexBitmapTracker { + fn new(tables: &Arc>) -> Self { + let bitmaps = tables + .iter() + .map(|table| MutableBitmap::from_len_zeroed(table.len())) + .collect(); + Self { bitmaps } + } + + fn set_true(&mut self, table_idx: usize, row_idx: usize) { + self.bitmaps[table_idx].set(row_idx, true); + } + + fn or(&self, other: &Self) -> Self { + let bitmaps = self + .bitmaps + .iter() + .zip(other.bitmaps.iter()) + .map(|(a, b)| a.or(b)) + .collect(); + Self { bitmaps } + } + + fn get_unused_indices(&self) -> impl Iterator + '_ { + self.bitmaps + .iter() + .enumerate() + .flat_map(|(table_idx, bitmap)| { + bitmap + .iter() + .enumerate() + .filter_map(move |(row_idx, is_set)| { + if !is_set { + Some((table_idx, row_idx)) + } else { + None + } + }) + }) + } +} + +enum OuterHashJoinProbeState { + Building, + ReadyToProbe( + Arc, + Arc>, + Option, + ), +} + +impl OuterHashJoinProbeState { + fn initialize_probe_state( + &mut self, + table: &Arc, + tables: &Arc>, + needs_bitmap: bool, + ) { + if let OuterHashJoinProbeState::Building = self { + *self = OuterHashJoinProbeState::ReadyToProbe( + table.clone(), + tables.clone(), + if needs_bitmap { + Some(IndexBitmapTracker::new(tables)) + } else { + None + }, + ); + } else { + panic!("OuterHashJoinProbeState should only be in Building state when setting table") + } + } + + fn get_probeable_and_tables(&self) -> (Arc, Arc>) { + if let OuterHashJoinProbeState::ReadyToProbe(probe_table, tables, _) = self { + (probe_table.clone(), tables.clone()) + } else { + panic!("get_probeable_and_table can only be used during the ReadyToProbe Phase") + } + } + + fn get_bitmap(&mut self) -> &mut Option { + if let OuterHashJoinProbeState::ReadyToProbe(_, _, bitmap) = self { + bitmap + } else { + panic!("get_bitmap can only be used during the ReadyToProbe Phase") + } + } +} + +impl StreamingSinkState for OuterHashJoinProbeState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} + +pub(crate) struct OuterHashJoinProbeSink { + probe_on: Vec, + common_join_keys: Vec, + left_non_join_columns: Vec, + right_non_join_columns: Vec, + join_type: JoinType, +} + +impl OuterHashJoinProbeSink { + pub(crate) fn new( + probe_on: Vec, + left_schema: &SchemaRef, + right_schema: &SchemaRef, + join_type: JoinType, + common_join_keys: IndexSet, + ) -> Self { + let left_non_join_columns = left_schema + .fields + .values() + .filter(|field| !common_join_keys.contains(field.name.as_str())) + .cloned() + .collect(); + let right_non_join_columns = right_schema + .fields + .values() + .filter(|field| !common_join_keys.contains(field.name.as_str())) + .cloned() + .collect(); + let common_join_keys = common_join_keys.into_iter().collect(); + Self { + probe_on, + common_join_keys, + left_non_join_columns, + right_non_join_columns, + join_type, + } + } + + fn probe_left_right( + &self, + input: &Arc, + state: &mut OuterHashJoinProbeState, + ) -> DaftResult> { + let (probe_table, tables) = state.get_probeable_and_tables(); + + let _growables = info_span!("OuterHashJoinProbeSink::build_growables").entered(); + + let mut build_side_growable = GrowableTable::new( + &tables.iter().collect::>(), + true, + tables.iter().map(|t| t.len()).sum(), + )?; + + let input_tables = input.get_tables()?; + + let mut probe_side_growable = + GrowableTable::new(&input_tables.iter().collect::>(), false, input.len())?; + + drop(_growables); + { + let _loop = info_span!("OuterHashJoinProbeSink::eval_and_probe").entered(); + for (probe_side_table_idx, table) in input_tables.iter().enumerate() { + let join_keys = table.eval_expression_list(&self.probe_on)?; + let idx_mapper = probe_table.probe_indices(&join_keys)?; + + for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { + if let Some(inner_iter) = inner_iter { + for (build_side_table_idx, build_row_idx) in inner_iter { + build_side_growable.extend( + build_side_table_idx as usize, + build_row_idx as usize, + 1, + ); + probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); + } + } else { + // if there's no match, we should still emit the probe side and fill the build side with nulls + build_side_growable.add_nulls(1); + probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); + } + } + } + } + let build_side_table = build_side_growable.build()?; + let probe_side_table = probe_side_growable.build()?; + + let final_table = if self.join_type == JoinType::Left { + let join_table = probe_side_table.get_columns(&self.common_join_keys)?; + let left = probe_side_table.get_columns( + &self + .left_non_join_columns + .iter() + .map(|f| f.name.as_str()) + .collect::>(), + )?; + let right = build_side_table.get_columns( + &self + .right_non_join_columns + .iter() + .map(|f| f.name.as_str()) + .collect::>(), + )?; + join_table.union(&left)?.union(&right)? + } else { + let join_table = probe_side_table.get_columns(&self.common_join_keys)?; + let left = build_side_table.get_columns( + &self + .left_non_join_columns + .iter() + .map(|f| f.name.as_str()) + .collect::>(), + )?; + let right = probe_side_table.get_columns( + &self + .right_non_join_columns + .iter() + .map(|f| f.name.as_str()) + .collect::>(), + )?; + join_table.union(&left)?.union(&right)? + }; + Ok(Arc::new(MicroPartition::new_loaded( + final_table.schema.clone(), + Arc::new(vec![final_table]), + None, + ))) + } + + fn probe_outer( + &self, + input: &Arc, + state: &mut OuterHashJoinProbeState, + ) -> DaftResult> { + let (probe_table, tables) = state.get_probeable_and_tables(); + let bitmap = state.get_bitmap(); + let _growables = info_span!("OuterHashJoinProbeSink::build_growables").entered(); + + // Need to set use_validity to true here because we add nulls to the build side + let mut build_side_growable = + GrowableTable::new(&tables.iter().collect::>(), true, 20)?; + + let input_tables = input.get_tables()?; + + let mut probe_side_growable = + GrowableTable::new(&input_tables.iter().collect::>(), false, 20)?; + + let left_idx_used = bitmap.as_mut().expect("bitmap should be set in outer join"); + + drop(_growables); + { + let _loop = info_span!("OuterHashJoinProbeSink::eval_and_probe").entered(); + for (probe_side_table_idx, table) in input_tables.iter().enumerate() { + let join_keys = table.eval_expression_list(&self.probe_on)?; + let idx_mapper = probe_table.probe_indices(&join_keys)?; + + for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { + if let Some(inner_iter) = inner_iter { + for (build_side_table_idx, build_row_idx) in inner_iter { + left_idx_used + .set_true(build_side_table_idx as usize, build_row_idx as usize); + build_side_growable.extend( + build_side_table_idx as usize, + build_row_idx as usize, + 1, + ); + probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); + } + } else { + // if there's no match, we should still emit the probe side and fill the build side with nulls + build_side_growable.add_nulls(1); + probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); + } + } + } + } + let build_side_table = build_side_growable.build()?; + let probe_side_table = probe_side_growable.build()?; + + let join_table = probe_side_table.get_columns(&self.common_join_keys)?; + let left = build_side_table.get_columns( + &self + .left_non_join_columns + .iter() + .map(|f| f.name.as_str()) + .collect::>(), + )?; + let right = probe_side_table.get_columns( + &self + .right_non_join_columns + .iter() + .map(|f| f.name.as_str()) + .collect::>(), + )?; + let final_table = join_table.union(&left)?.union(&right)?; + Ok(Arc::new(MicroPartition::new_loaded( + final_table.schema.clone(), + Arc::new(vec![final_table]), + None, + ))) + } + + fn finalize_outer( + &self, + mut states: Vec>, + ) -> DaftResult>> { + let states = states + .iter_mut() + .map(|s| { + s.as_any_mut() + .downcast_mut::() + .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState") + }) + .collect::>(); + let tables = states + .first() + .expect("at least one state should be present") + .get_probeable_and_tables() + .1; + + let merged_bitmap = { + let bitmaps = states + .into_iter() + .map(|s| { + if let OuterHashJoinProbeState::ReadyToProbe(_, _, bitmap) = s { + bitmap + .take() + .expect("bitmap should be present in outer join") + } else { + panic!("OuterHashJoinProbeState should be in ReadyToProbe state") + } + }) + .collect::>(); + bitmaps.into_iter().fold(None, |acc, x| match acc { + None => Some(x), + Some(acc) => Some(acc.or(&x)), + }) + } + .expect("at least one bitmap should be present"); + + let mut build_side_growable = + GrowableTable::new(&tables.iter().collect::>(), true, 20)?; + + for (table_idx, row_idx) in merged_bitmap.get_unused_indices() { + build_side_growable.extend(table_idx, row_idx, 1); + } + + let build_side_table = build_side_growable.build()?; + + let join_table = build_side_table.get_columns(&self.common_join_keys)?; + let left = build_side_table.get_columns( + &self + .left_non_join_columns + .iter() + .map(|f| f.name.as_str()) + .collect::>(), + )?; + let right = { + let schema = Schema::new(self.right_non_join_columns.to_vec())?; + let columns = self + .right_non_join_columns + .iter() + .map(|field| Series::full_null(&field.name, &field.dtype, left.len())) + .collect::>(); + Table::new_unchecked(schema, columns, left.len()) + }; + let final_table = join_table.union(&left)?.union(&right)?; + Ok(Some(Arc::new(MicroPartition::new_loaded( + final_table.schema.clone(), + Arc::new(vec![final_table]), + None, + )))) + } +} + +impl StreamingSink for OuterHashJoinProbeSink { + #[instrument(skip_all, name = "OuterHashJoinProbeSink::execute")] + fn execute( + &self, + idx: usize, + input: &PipelineResultType, + state: &mut dyn StreamingSinkState, + ) -> DaftResult { + match idx { + 0 => { + let state = state + .as_any_mut() + .downcast_mut::() + .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState"); + let (probe_table, tables) = input.as_probe_table(); + state.initialize_probe_state( + probe_table, + tables, + self.join_type == JoinType::Outer, + ); + Ok(StreamingSinkOutput::NeedMoreInput(None)) + } + _ => { + let state = state + .as_any_mut() + .downcast_mut::() + .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState"); + let input = input.as_data(); + let out = match self.join_type { + JoinType::Left | JoinType::Right => self.probe_left_right(input, state), + JoinType::Outer => self.probe_outer(input, state), + _ => unreachable!( + "Only Left, Right, and Outer joins are supported in OuterHashJoinProbeSink" + ), + }?; + Ok(StreamingSinkOutput::NeedMoreInput(Some(out))) + } + } + } + + fn name(&self) -> &'static str { + "OuterHashJoinProbeSink" + } + + fn make_state(&self) -> Box { + Box::new(OuterHashJoinProbeState::Building) + } + + fn finalize( + &self, + states: Vec>, + ) -> DaftResult>> { + if self.join_type == JoinType::Outer { + self.finalize_outer(states) + } else { + Ok(None) + } + } +} diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index 1804a3e07e..686e96a599 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -3,14 +3,21 @@ use std::sync::Arc; use common_display::tree::TreeDisplay; use common_error::DaftResult; use daft_micropartition::MicroPartition; -use tracing::info_span; +use snafu::ResultExt; +use tracing::{info_span, instrument}; use crate::{ - channel::PipelineChannel, pipeline::PipelineNode, runtime_stats::RuntimeStatsContext, - ExecutionRuntimeHandle, NUM_CPUS, + channel::{create_channel, make_ordering_aware_channel, PipelineChannel, Receiver, Sender}, + pipeline::{PipelineNode, PipelineResultType}, + runtime_stats::{CountingReceiver, RuntimeStatsContext}, + ExecutionRuntimeHandle, JoinSnafu, NUM_CPUS, }; -pub enum StreamSinkOutput { +pub trait StreamingSinkState: Send + Sync { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any; +} + +pub enum StreamingSinkOutput { NeedMoreInput(Option>), #[allow(dead_code)] HasMoreOutput(Arc), @@ -19,35 +26,121 @@ pub enum StreamSinkOutput { pub trait StreamingSink: Send + Sync { fn execute( - &mut self, + &self, index: usize, - input: &Arc, - ) -> DaftResult; - #[allow(dead_code)] + input: &PipelineResultType, + state: &mut dyn StreamingSinkState, + ) -> DaftResult; + fn finalize( + &self, + states: Vec>, + ) -> DaftResult>>; fn name(&self) -> &'static str; + fn make_state(&self) -> Box; + fn max_concurrency(&self) -> usize { + *NUM_CPUS + } } pub(crate) struct StreamingSinkNode { - // use a RW lock - op: Arc>>, + op: Arc, name: &'static str, children: Vec>, runtime_stats: Arc, } impl StreamingSinkNode { - pub(crate) fn new(op: Box, children: Vec>) -> Self { + pub(crate) fn new(op: Arc, children: Vec>) -> Self { let name = op.name(); StreamingSinkNode { - op: Arc::new(tokio::sync::Mutex::new(op)), + op, name, children, runtime_stats: RuntimeStatsContext::new(), } } + pub(crate) fn boxed(self) -> Box { Box::new(self) } + + #[instrument(level = "info", skip_all, name = "StreamingSink::run_worker")] + async fn run_worker( + op: Arc, + mut input_receiver: Receiver<(usize, PipelineResultType)>, + output_sender: Sender>, + rt_context: Arc, + ) -> DaftResult> { + let span = info_span!("StreamingSink::Execute"); + let mut state = op.make_state(); + while let Some((idx, morsel)) = input_receiver.recv().await { + loop { + let result = + rt_context.in_span(&span, || op.execute(idx, &morsel, state.as_mut()))?; + match result { + StreamingSinkOutput::NeedMoreInput(Some(mp)) => { + let _ = output_sender.send(mp).await; + break; + } + StreamingSinkOutput::NeedMoreInput(None) => { + break; + } + StreamingSinkOutput::HasMoreOutput(mp) => { + let _ = output_sender.send(mp).await; + } + StreamingSinkOutput::Finished(mp) => { + if let Some(mp) = mp { + let _ = output_sender.send(mp).await; + } + return Ok(state); + } + } + } + } + Ok(state) + } + + fn spawn_workers( + op: Arc, + input_receivers: Vec>, + output_senders: Vec>>, + worker_set: &mut tokio::task::JoinSet>>, + stats: Arc, + ) { + for (input_receiver, output_sender) in input_receivers.into_iter().zip(output_senders) { + worker_set.spawn(Self::run_worker( + op.clone(), + input_receiver, + output_sender, + stats.clone(), + )); + } + } + + async fn forward_input_to_workers( + receivers: Vec, + worker_senders: Vec>, + ) -> DaftResult<()> { + let mut next_worker_idx = 0; + let mut send_to_next_worker = |idx, data: PipelineResultType| { + let next_worker_sender = worker_senders.get(next_worker_idx).unwrap(); + next_worker_idx = (next_worker_idx + 1) % worker_senders.len(); + next_worker_sender.send((idx, data)) + }; + + for (idx, mut receiver) in receivers.into_iter().enumerate() { + while let Some(morsel) = receiver.recv().await { + if morsel.should_broadcast() { + for worker_sender in worker_senders.iter() { + let _ = worker_sender.send((idx, morsel.clone())).await; + } + } else { + let _ = send_to_next_worker(idx, morsel.clone()).await; + } + } + } + Ok(()) + } } impl TreeDisplay for StreamingSinkNode { @@ -87,50 +180,50 @@ impl PipelineNode for StreamingSinkNode { maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, ) -> crate::Result { - let child = self - .children - .get_mut(0) - .expect("we should only have 1 child"); - let child_results_channel = child.start(true, runtime_handle)?; - let mut child_results_receiver = - child_results_channel.get_receiver_with_stats(&self.runtime_stats); - - let mut destination_channel = PipelineChannel::new(*NUM_CPUS, maintain_order); - let sender = destination_channel.get_next_sender_with_stats(&self.runtime_stats); + let mut child_result_receivers = Vec::with_capacity(self.children.len()); + for child in self.children.iter_mut() { + let child_result_channel = child.start(maintain_order, runtime_handle)?; + child_result_receivers + .push(child_result_channel.get_receiver_with_stats(&self.runtime_stats.clone())); + } + + let destination_channel = PipelineChannel::new(); + let destination_sender = + destination_channel.get_sender_with_stats(&self.runtime_stats.clone()); + let op = self.op.clone(); - let runtime_stats = self.runtime_stats.clone(); + let stats = self.runtime_stats.clone(); runtime_handle.spawn( async move { - // this should be a RWLock and run in concurrent workers - let span = info_span!("StreamingSink::execute"); - - let mut sink = op.lock().await; - let mut is_active = true; - while is_active && let Some(val) = child_results_receiver.recv().await { - let val = val.as_data(); - loop { - let result = runtime_stats.in_span(&span, || sink.execute(0, val))?; - match result { - StreamSinkOutput::HasMoreOutput(mp) => { - sender.send(mp.into()).await.unwrap(); - } - StreamSinkOutput::NeedMoreInput(mp) => { - if let Some(mp) = mp { - sender.send(mp.into()).await.unwrap(); - } - break; - } - StreamSinkOutput::Finished(mp) => { - if let Some(mp) = mp { - sender.send(mp.into()).await.unwrap(); - } - is_active = false; - break; - } - } - } + let num_workers = op.max_concurrency(); + let (input_senders, input_receivers) = + (0..num_workers).map(|_| create_channel(1)).unzip(); + let (output_senders, mut output_receiver) = + make_ordering_aware_channel(num_workers, maintain_order); + let mut worker_set = tokio::task::JoinSet::new(); + Self::spawn_workers( + op.clone(), + input_receivers, + output_senders, + &mut worker_set, + stats.clone(), + ); + Self::forward_input_to_workers(child_result_receivers, input_senders).await?; + + while let Some(morsel) = output_receiver.recv().await { + let _ = destination_sender.send(morsel.into()).await; + } + + let mut finished_states = Vec::with_capacity(num_workers); + while let Some(result) = worker_set.join_next().await { + let state = result.context(JoinSnafu)??; + finished_states.push(state); + } + + if let Some(finalized_result) = op.finalize(finished_states)? { + let _ = destination_sender.send(finalized_result.into()).await; } - DaftResult::Ok(()) + Ok(()) }, self.name(), ); diff --git a/src/daft-local-execution/src/sources/source.rs b/src/daft-local-execution/src/sources/source.rs index 175dc66427..1e6d77c814 100644 --- a/src/daft-local-execution/src/sources/source.rs +++ b/src/daft-local-execution/src/sources/source.rs @@ -76,8 +76,8 @@ impl PipelineNode for SourceNode { self.source .get_data(maintain_order, runtime_handle, self.io_stats.clone())?; - let mut channel = PipelineChannel::new(1, maintain_order); - let counting_sender = channel.get_next_sender_with_stats(&self.runtime_stats); + let channel = PipelineChannel::new(); + let counting_sender = channel.get_sender_with_stats(&self.runtime_stats); runtime_handle.spawn( async move { while let Some(part) = source_stream.next().await { diff --git a/tests/cookbook/test_joins.py b/tests/cookbook/test_joins.py index b51c863100..d80dce72a2 100644 --- a/tests/cookbook/test_joins.py +++ b/tests/cookbook/test_joins.py @@ -6,16 +6,20 @@ from daft.expressions import col from tests.conftest import assert_df_equals -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) + +def skip_invalid_join_strategies(join_strategy): + if context.get_context().daft_execution_config.enable_native_executor is True: + if join_strategy not in [None, "hash"]: + pytest.skip("Native executor fails for these tests") @pytest.mark.parametrize( - "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True + "join_strategy", + [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], + indirect=True, ) def test_simple_join(join_strategy, daft_df, service_requests_csv_pd_df, repartition_nparts): + skip_invalid_join_strategies(join_strategy) daft_df = daft_df.repartition(repartition_nparts) daft_df_left = daft_df.select(col("Unique Key"), col("Borough")) daft_df_right = daft_df.select(col("Unique Key"), col("Created Date")) @@ -33,9 +37,12 @@ def test_simple_join(join_strategy, daft_df, service_requests_csv_pd_df, reparti @pytest.mark.parametrize( - "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True + "join_strategy", + [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], + indirect=True, ) def test_simple_self_join(join_strategy, daft_df, service_requests_csv_pd_df, repartition_nparts): + skip_invalid_join_strategies(join_strategy) daft_df = daft_df.repartition(repartition_nparts) daft_df = daft_df.select(col("Unique Key"), col("Borough")) @@ -44,7 +51,11 @@ def test_simple_self_join(join_strategy, daft_df, service_requests_csv_pd_df, re service_requests_csv_pd_df = service_requests_csv_pd_df[["Unique Key", "Borough"]] service_requests_csv_pd_df = ( service_requests_csv_pd_df.set_index("Unique Key") - .join(service_requests_csv_pd_df.set_index("Unique Key"), how="inner", rsuffix="_right") + .join( + service_requests_csv_pd_df.set_index("Unique Key"), + how="inner", + rsuffix="_right", + ) .reset_index() ) service_requests_csv_pd_df = service_requests_csv_pd_df.rename({"Borough_right": "right.Borough"}, axis=1) @@ -53,9 +64,12 @@ def test_simple_self_join(join_strategy, daft_df, service_requests_csv_pd_df, re @pytest.mark.parametrize( - "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True + "join_strategy", + [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], + indirect=True, ) def test_simple_join_missing_rvalues(join_strategy, daft_df, service_requests_csv_pd_df, repartition_nparts): + skip_invalid_join_strategies(join_strategy) daft_df_right = daft_df.sort("Unique Key").limit(25).repartition(repartition_nparts) daft_df_left = daft_df.repartition(repartition_nparts) daft_df_left = daft_df_left.select(col("Unique Key"), col("Borough")) @@ -76,9 +90,12 @@ def test_simple_join_missing_rvalues(join_strategy, daft_df, service_requests_cs @pytest.mark.parametrize( - "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True + "join_strategy", + [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], + indirect=True, ) def test_simple_join_missing_lvalues(join_strategy, daft_df, service_requests_csv_pd_df, repartition_nparts): + skip_invalid_join_strategies(join_strategy) daft_df_right = daft_df.repartition(repartition_nparts) daft_df_left = daft_df.sort(col("Unique Key")).limit(25).repartition(repartition_nparts) daft_df_left = daft_df_left.select(col("Unique Key"), col("Borough")) diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index b0bdbf9df4..799f5ba056 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -11,7 +11,7 @@ def skip_invalid_join_strategies(join_strategy, join_type): if context.get_context().daft_execution_config.enable_native_executor is True: - if join_type == "outer" or join_strategy not in [None, "hash"]: + if join_strategy not in [None, "hash"]: pytest.skip("Native executor fails for these tests") else: if (join_strategy == "sort_merge" or join_strategy == "sort_merge_aligned_boundaries") and join_type != "inner": From 3aa8aafbdb64ec0736f13490bc05573a3c5699ea Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Thu, 19 Sep 2024 13:26:05 -0700 Subject: [PATCH 02/13] clean up lil bit and fix test --- src/daft-local-execution/src/channel.rs | 2 +- .../src/intermediate_ops/intermediate_op.rs | 10 +++++++--- .../src/sinks/streaming_sink.rs | 13 ++++++------- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/daft-local-execution/src/channel.rs b/src/daft-local-execution/src/channel.rs index 84f9bbe8a3..f5351bd3cf 100644 --- a/src/daft-local-execution/src/channel.rs +++ b/src/daft-local-execution/src/channel.rs @@ -39,7 +39,7 @@ impl PipelineChannel { } } -pub(crate) fn make_ordering_aware_channel( +pub(crate) fn create_ordering_aware_channel( buffer_size: usize, ordered: bool, ) -> (Vec>, OrderingAwareReceiver) { diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index 8f561c94b7..6c0ead0253 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -3,14 +3,15 @@ use std::sync::Arc; use common_display::tree::TreeDisplay; use common_error::DaftResult; use daft_micropartition::MicroPartition; +use snafu::ResultExt; use tracing::{info_span, instrument}; use super::buffer::OperatorBuffer; use crate::{ - channel::{create_channel, make_ordering_aware_channel, PipelineChannel, Receiver, Sender}, + channel::{create_channel, create_ordering_aware_channel, PipelineChannel, Receiver, Sender}, pipeline::{PipelineNode, PipelineResultType}, runtime_stats::{CountingReceiver, RuntimeStatsContext}, - ExecutionRuntimeHandle, NUM_CPUS, + ExecutionRuntimeHandle, JoinSnafu, NUM_CPUS, }; pub trait IntermediateOperatorState: Send + Sync { @@ -207,7 +208,7 @@ impl PipelineNode for IntermediateNode { let (input_senders, input_receivers) = (0..num_workers).map(|_| create_channel(1)).unzip(); let (output_senders, mut output_receiver) = - make_ordering_aware_channel(num_workers, maintain_order); + create_ordering_aware_channel(num_workers, maintain_order); let mut worker_set = tokio::task::JoinSet::new(); Self::spawn_workers( op.clone(), @@ -222,6 +223,9 @@ impl PipelineNode for IntermediateNode { while let Some(morsel) = output_receiver.recv().await { let _ = destination_sender.send(morsel.into()).await; } + while let Some(result) = worker_set.join_next().await { + result.context(JoinSnafu)??; + } Ok(()) }, self.intermediate_op.name(), diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index 686e96a599..60cc165f0f 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -7,7 +7,7 @@ use snafu::ResultExt; use tracing::{info_span, instrument}; use crate::{ - channel::{create_channel, make_ordering_aware_channel, PipelineChannel, Receiver, Sender}, + channel::{create_channel, create_ordering_aware_channel, PipelineChannel, Receiver, Sender}, pipeline::{PipelineNode, PipelineResultType}, runtime_stats::{CountingReceiver, RuntimeStatsContext}, ExecutionRuntimeHandle, JoinSnafu, NUM_CPUS, @@ -78,11 +78,10 @@ impl StreamingSinkNode { let result = rt_context.in_span(&span, || op.execute(idx, &morsel, state.as_mut()))?; match result { - StreamingSinkOutput::NeedMoreInput(Some(mp)) => { - let _ = output_sender.send(mp).await; - break; - } - StreamingSinkOutput::NeedMoreInput(None) => { + StreamingSinkOutput::NeedMoreInput(mp) => { + if let Some(mp) = mp { + let _ = output_sender.send(mp).await; + } break; } StreamingSinkOutput::HasMoreOutput(mp) => { @@ -199,7 +198,7 @@ impl PipelineNode for StreamingSinkNode { let (input_senders, input_receivers) = (0..num_workers).map(|_| create_channel(1)).unzip(); let (output_senders, mut output_receiver) = - make_ordering_aware_channel(num_workers, maintain_order); + create_ordering_aware_channel(num_workers, maintain_order); let mut worker_set = tokio::task::JoinSet::new(); Self::spawn_workers( op.clone(), From 86a5fa4550de2f8daad9359efefeabf4dcee259c Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Thu, 19 Sep 2024 13:40:53 -0700 Subject: [PATCH 03/13] get right side null schema in .new --- .../src/sinks/outer_hash_join_probe.rs | 92 ++++++------------- 1 file changed, 29 insertions(+), 63 deletions(-) diff --git a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs index fdbe559277..740a060e80 100644 --- a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs +++ b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use arrow2::bitmap::MutableBitmap; use common_error::DaftResult; use daft_core::{ - prelude::{Field, Schema, SchemaRef}, + prelude::{Schema, SchemaRef}, series::Series, }; use daft_dsl::ExprRef; @@ -119,8 +119,9 @@ impl StreamingSinkState for OuterHashJoinProbeState { pub(crate) struct OuterHashJoinProbeSink { probe_on: Vec, common_join_keys: Vec, - left_non_join_columns: Vec, - right_non_join_columns: Vec, + left_non_join_columns: Vec, + right_non_join_columns: Vec, + right_non_join_schema: SchemaRef, join_type: JoinType, } @@ -134,22 +135,26 @@ impl OuterHashJoinProbeSink { ) -> Self { let left_non_join_columns = left_schema .fields - .values() - .filter(|field| !common_join_keys.contains(field.name.as_str())) + .keys() + .filter(|c| !common_join_keys.contains(*c)) .cloned() .collect(); - let right_non_join_columns = right_schema + let right_non_join_fields = right_schema .fields .values() - .filter(|field| !common_join_keys.contains(field.name.as_str())) + .filter(|f| !common_join_keys.contains(&f.name)) .cloned() .collect(); + let right_non_join_schema = + Arc::new(Schema::new(right_non_join_fields).expect("right schema should be valid")); + let right_non_join_columns = right_non_join_schema.fields.keys().cloned().collect(); let common_join_keys = common_join_keys.into_iter().collect(); Self { probe_on, common_join_keys, left_non_join_columns, right_non_join_columns, + right_non_join_schema, join_type, } } @@ -204,37 +209,13 @@ impl OuterHashJoinProbeSink { let final_table = if self.join_type == JoinType::Left { let join_table = probe_side_table.get_columns(&self.common_join_keys)?; - let left = probe_side_table.get_columns( - &self - .left_non_join_columns - .iter() - .map(|f| f.name.as_str()) - .collect::>(), - )?; - let right = build_side_table.get_columns( - &self - .right_non_join_columns - .iter() - .map(|f| f.name.as_str()) - .collect::>(), - )?; + let left = probe_side_table.get_columns(&self.left_non_join_columns)?; + let right = build_side_table.get_columns(&self.right_non_join_columns)?; join_table.union(&left)?.union(&right)? } else { let join_table = probe_side_table.get_columns(&self.common_join_keys)?; - let left = build_side_table.get_columns( - &self - .left_non_join_columns - .iter() - .map(|f| f.name.as_str()) - .collect::>(), - )?; - let right = probe_side_table.get_columns( - &self - .right_non_join_columns - .iter() - .map(|f| f.name.as_str()) - .collect::>(), - )?; + let left = build_side_table.get_columns(&self.left_non_join_columns)?; + let right = probe_side_table.get_columns(&self.right_non_join_columns)?; join_table.union(&left)?.union(&right)? }; Ok(Arc::new(MicroPartition::new_loaded( @@ -254,13 +235,16 @@ impl OuterHashJoinProbeSink { let _growables = info_span!("OuterHashJoinProbeSink::build_growables").entered(); // Need to set use_validity to true here because we add nulls to the build side - let mut build_side_growable = - GrowableTable::new(&tables.iter().collect::>(), true, 20)?; + let mut build_side_growable = GrowableTable::new( + &tables.iter().collect::>(), + true, + tables.iter().map(|t| t.len()).sum(), + )?; let input_tables = input.get_tables()?; let mut probe_side_growable = - GrowableTable::new(&input_tables.iter().collect::>(), false, 20)?; + GrowableTable::new(&input_tables.iter().collect::>(), false, input.len())?; let left_idx_used = bitmap.as_mut().expect("bitmap should be set in outer join"); @@ -295,20 +279,8 @@ impl OuterHashJoinProbeSink { let probe_side_table = probe_side_growable.build()?; let join_table = probe_side_table.get_columns(&self.common_join_keys)?; - let left = build_side_table.get_columns( - &self - .left_non_join_columns - .iter() - .map(|f| f.name.as_str()) - .collect::>(), - )?; - let right = probe_side_table.get_columns( - &self - .right_non_join_columns - .iter() - .map(|f| f.name.as_str()) - .collect::>(), - )?; + let left = build_side_table.get_columns(&self.left_non_join_columns)?; + let right = probe_side_table.get_columns(&self.right_non_join_columns)?; let final_table = join_table.union(&left)?.union(&right)?; Ok(Arc::new(MicroPartition::new_loaded( final_table.schema.clone(), @@ -365,21 +337,15 @@ impl OuterHashJoinProbeSink { let build_side_table = build_side_growable.build()?; let join_table = build_side_table.get_columns(&self.common_join_keys)?; - let left = build_side_table.get_columns( - &self - .left_non_join_columns - .iter() - .map(|f| f.name.as_str()) - .collect::>(), - )?; + let left = build_side_table.get_columns(&self.left_non_join_columns)?; let right = { - let schema = Schema::new(self.right_non_join_columns.to_vec())?; let columns = self - .right_non_join_columns - .iter() + .right_non_join_schema + .fields + .values() .map(|field| Series::full_null(&field.name, &field.dtype, left.len())) .collect::>(); - Table::new_unchecked(schema, columns, left.len()) + Table::new_unchecked(self.right_non_join_schema.clone(), columns, left.len()) }; let final_table = join_table.union(&left)?.union(&right)?; Ok(Some(Arc::new(MicroPartition::new_loaded( From dcac8ba6d3a519448c3f600dcf8c062ecf1094bf Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Thu, 19 Sep 2024 13:52:47 -0700 Subject: [PATCH 04/13] more cleanup --- .../src/intermediate_ops/intermediate_op.rs | 9 ++++++--- src/daft-local-execution/src/lib.rs | 9 +++++++-- .../src/sinks/streaming_sink.rs | 14 ++++++++------ 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index 6c0ead0253..7f0a945502 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -9,9 +9,10 @@ use tracing::{info_span, instrument}; use super::buffer::OperatorBuffer; use crate::{ channel::{create_channel, create_ordering_aware_channel, PipelineChannel, Receiver, Sender}, + create_worker_set, pipeline::{PipelineNode, PipelineResultType}, runtime_stats::{CountingReceiver, RuntimeStatsContext}, - ExecutionRuntimeHandle, JoinSnafu, NUM_CPUS, + ExecutionRuntimeHandle, JoinSnafu, WorkerSet, NUM_CPUS, }; pub trait IntermediateOperatorState: Send + Sync { @@ -102,7 +103,7 @@ impl IntermediateNode { op: Arc, input_receivers: Vec>, output_senders: Vec>>, - worker_set: &mut tokio::task::JoinSet>, + worker_set: &mut WorkerSet>, stats: Arc, ) { for (input_receiver, output_sender) in input_receivers.into_iter().zip(output_senders) { @@ -209,7 +210,8 @@ impl PipelineNode for IntermediateNode { (0..num_workers).map(|_| create_channel(1)).unzip(); let (output_senders, mut output_receiver) = create_ordering_aware_channel(num_workers, maintain_order); - let mut worker_set = tokio::task::JoinSet::new(); + + let mut worker_set = create_worker_set(); Self::spawn_workers( op.clone(), input_receivers, @@ -217,6 +219,7 @@ impl PipelineNode for IntermediateNode { &mut worker_set, stats.clone(), ); + Self::forward_input_to_workers(child_result_receivers, input_senders, morsel_size) .await?; diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs index 732f306768..01250dc43f 100644 --- a/src/daft-local-execution/src/lib.rs +++ b/src/daft-local-execution/src/lib.rs @@ -14,15 +14,20 @@ lazy_static! { pub static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get(); } +pub(crate) type WorkerSet = tokio::task::JoinSet; +pub(crate) fn create_worker_set() -> WorkerSet { + tokio::task::JoinSet::new() +} + pub struct ExecutionRuntimeHandle { - worker_set: tokio::task::JoinSet>, + worker_set: WorkerSet>, default_morsel_size: usize, } impl ExecutionRuntimeHandle { pub fn new(default_morsel_size: usize) -> Self { Self { - worker_set: tokio::task::JoinSet::new(), + worker_set: create_worker_set(), default_morsel_size, } } diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index 60cc165f0f..b580ae1d75 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -8,9 +8,10 @@ use tracing::{info_span, instrument}; use crate::{ channel::{create_channel, create_ordering_aware_channel, PipelineChannel, Receiver, Sender}, + create_worker_set, pipeline::{PipelineNode, PipelineResultType}, runtime_stats::{CountingReceiver, RuntimeStatsContext}, - ExecutionRuntimeHandle, JoinSnafu, NUM_CPUS, + ExecutionRuntimeHandle, JoinSnafu, WorkerSet, NUM_CPUS, }; pub trait StreamingSinkState: Send + Sync { @@ -103,7 +104,7 @@ impl StreamingSinkNode { op: Arc, input_receivers: Vec>, output_senders: Vec>>, - worker_set: &mut tokio::task::JoinSet>>, + worker_set: &mut WorkerSet>>, stats: Arc, ) { for (input_receiver, output_sender) in input_receivers.into_iter().zip(output_senders) { @@ -191,7 +192,7 @@ impl PipelineNode for StreamingSinkNode { destination_channel.get_sender_with_stats(&self.runtime_stats.clone()); let op = self.op.clone(); - let stats = self.runtime_stats.clone(); + let runtime_stats = self.runtime_stats.clone(); runtime_handle.spawn( async move { let num_workers = op.max_concurrency(); @@ -199,16 +200,17 @@ impl PipelineNode for StreamingSinkNode { (0..num_workers).map(|_| create_channel(1)).unzip(); let (output_senders, mut output_receiver) = create_ordering_aware_channel(num_workers, maintain_order); - let mut worker_set = tokio::task::JoinSet::new(); + + let mut worker_set = create_worker_set(); Self::spawn_workers( op.clone(), input_receivers, output_senders, &mut worker_set, - stats.clone(), + runtime_stats.clone(), ); - Self::forward_input_to_workers(child_result_receivers, input_senders).await?; + Self::forward_input_to_workers(child_result_receivers, input_senders).await?; while let Some(morsel) = output_receiver.recv().await { let _ = destination_sender.send(morsel.into()).await; } From bb0a9448d905e93c0f5a81d3df46641f8db54ebc Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Thu, 19 Sep 2024 14:06:12 -0700 Subject: [PATCH 05/13] reduce code --- src/daft-local-execution/src/channel.rs | 98 +++++++++++------- .../src/intermediate_ops/intermediate_op.rs | 99 ++++++++----------- src/daft-local-execution/src/runtime_stats.rs | 9 +- .../src/sinks/blocking_sink.rs | 7 +- .../src/sinks/streaming_sink.rs | 11 +-- .../src/sources/source.rs | 4 +- 6 files changed, 114 insertions(+), 114 deletions(-) diff --git a/src/daft-local-execution/src/channel.rs b/src/daft-local-execution/src/channel.rs index f5351bd3cf..bb22b9d4ea 100644 --- a/src/daft-local-execution/src/channel.rs +++ b/src/daft-local-execution/src/channel.rs @@ -13,64 +13,86 @@ pub fn create_channel(buffer_size: usize) -> (Sender, Receiver) { } pub struct PipelineChannel { - sender: Sender, - receiver: Receiver, + sender: PipelineSender, + receiver: PipelineReceiver, } impl PipelineChannel { - pub(crate) fn new() -> Self { - let (sender, receiver) = create_channel(1); - Self { sender, receiver } + pub fn new(buffer_size: usize, in_order: bool) -> Self { + match in_order { + true => { + let (senders, receivers) = (0..buffer_size).map(|_| create_channel(1)).unzip(); + let sender = PipelineSender::InOrder(RoundRobinSender::new(senders)); + let receiver = PipelineReceiver::InOrder(RoundRobinReceiver::new(receivers)); + Self { sender, receiver } + } + false => { + let (sender, receiver) = create_channel(buffer_size); + let sender = PipelineSender::OutOfOrder(sender); + let receiver = PipelineReceiver::OutOfOrder(receiver); + Self { sender, receiver } + } + } } - pub(crate) fn get_sender_with_stats(&self, stats: &Arc) -> CountingSender { - CountingSender::new(self.sender.clone(), stats.clone()) + fn get_next_sender(&mut self) -> Sender { + match &mut self.sender { + PipelineSender::InOrder(rr) => rr.get_next_sender(), + PipelineSender::OutOfOrder(sender) => sender.clone(), + } } - pub(crate) fn get_receiver_with_stats( - self, - stats: &Arc, - ) -> CountingReceiver { - CountingReceiver::new(self.receiver, stats.clone()) + pub(crate) fn get_next_sender_with_stats( + &mut self, + rt: &Arc, + ) -> CountingSender { + CountingSender::new(self.get_next_sender(), rt.clone()) } - pub(crate) fn get_receiver(self) -> Receiver { + pub fn get_receiver(self) -> PipelineReceiver { self.receiver } + + pub(crate) fn get_receiver_with_stats(self, rt: &Arc) -> CountingReceiver { + CountingReceiver::new(self.get_receiver(), rt.clone()) + } } -pub(crate) fn create_ordering_aware_channel( - buffer_size: usize, - ordered: bool, -) -> (Vec>, OrderingAwareReceiver) { - match ordered { - true => { - let (senders, receivers) = (0..buffer_size).map(|_| create_channel(1)).unzip(); - ( - senders, - OrderingAwareReceiver::Ordered(RoundRobinReceiver::new(receivers)), - ) - } - false => { - let (sender, receiver) = create_channel(buffer_size); - ( - (0..buffer_size).map(|_| sender.clone()).collect(), - OrderingAwareReceiver::Unordered(receiver), - ) +pub enum PipelineSender { + InOrder(RoundRobinSender), + OutOfOrder(Sender), +} + +pub struct RoundRobinSender { + senders: Vec>, + curr_sender_idx: usize, +} + +impl RoundRobinSender { + pub fn new(senders: Vec>) -> Self { + Self { + senders, + curr_sender_idx: 0, } } + + pub fn get_next_sender(&mut self) -> Sender { + let next_idx = self.curr_sender_idx; + self.curr_sender_idx = (next_idx + 1) % self.senders.len(); + self.senders[next_idx].clone() + } } -pub enum OrderingAwareReceiver { - Ordered(RoundRobinReceiver), - Unordered(Receiver), +pub enum PipelineReceiver { + InOrder(RoundRobinReceiver), + OutOfOrder(Receiver), } -impl OrderingAwareReceiver { - pub async fn recv(&mut self) -> Option { +impl PipelineReceiver { + pub async fn recv(&mut self) -> Option { match self { - OrderingAwareReceiver::Ordered(rr_receiver) => rr_receiver.recv().await, - OrderingAwareReceiver::Unordered(receiver) => receiver.recv().await, + PipelineReceiver::InOrder(rr) => rr.recv().await, + PipelineReceiver::OutOfOrder(r) => r.recv().await, } } } diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index 7f0a945502..abb5c5388b 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -3,16 +3,14 @@ use std::sync::Arc; use common_display::tree::TreeDisplay; use common_error::DaftResult; use daft_micropartition::MicroPartition; -use snafu::ResultExt; use tracing::{info_span, instrument}; use super::buffer::OperatorBuffer; use crate::{ - channel::{create_channel, create_ordering_aware_channel, PipelineChannel, Receiver, Sender}, - create_worker_set, + channel::{create_channel, PipelineChannel, Receiver, Sender}, pipeline::{PipelineNode, PipelineResultType}, - runtime_stats::{CountingReceiver, RuntimeStatsContext}, - ExecutionRuntimeHandle, JoinSnafu, WorkerSet, NUM_CPUS, + runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, + ExecutionRuntimeHandle, NUM_CPUS, }; pub trait IntermediateOperatorState: Send + Sync { @@ -70,28 +68,28 @@ impl IntermediateNode { } #[instrument(level = "info", skip_all, name = "IntermediateOperator::run_worker")] - async fn run_worker( + pub async fn run_worker( op: Arc, - mut input_receiver: Receiver<(usize, PipelineResultType)>, - output_sender: Sender>, + mut receiver: Receiver<(usize, PipelineResultType)>, + sender: CountingSender, rt_context: Arc, ) -> DaftResult<()> { let span = info_span!("IntermediateOp::execute"); let mut state = op.make_state(); - while let Some((idx, morsel)) = input_receiver.recv().await { + while let Some((idx, morsel)) = receiver.recv().await { loop { let result = rt_context.in_span(&span, || op.execute(idx, &morsel, state.as_mut()))?; match result { IntermediateOperatorResult::NeedMoreInput(Some(mp)) => { - let _ = output_sender.send(mp).await; + let _ = sender.send(mp.into()).await; break; } IntermediateOperatorResult::NeedMoreInput(None) => { break; } IntermediateOperatorResult::HasMoreOutput(mp) => { - let _ = output_sender.send(mp).await; + let _ = sender.send(mp.into()).await; } } } @@ -99,24 +97,32 @@ impl IntermediateNode { Ok(()) } - fn spawn_workers( - op: Arc, - input_receivers: Vec>, - output_senders: Vec>>, - worker_set: &mut WorkerSet>, - stats: Arc, - ) { - for (input_receiver, output_sender) in input_receivers.into_iter().zip(output_senders) { - worker_set.spawn(Self::run_worker( - op.clone(), - input_receiver, - output_sender, - stats.clone(), - )); + pub fn spawn_workers( + &self, + num_workers: usize, + destination_channel: &mut PipelineChannel, + runtime_handle: &mut ExecutionRuntimeHandle, + ) -> Vec> { + let mut worker_senders = Vec::with_capacity(num_workers); + for _ in 0..num_workers { + let (worker_sender, worker_receiver) = create_channel(1); + let destination_sender = + destination_channel.get_next_sender_with_stats(&self.runtime_stats); + runtime_handle.spawn( + Self::run_worker( + self.intermediate_op.clone(), + worker_receiver, + destination_sender, + self.runtime_stats.clone(), + ), + self.intermediate_op.name(), + ); + worker_senders.push(worker_sender); } + worker_senders } - async fn forward_input_to_workers( + pub async fn send_to_workers( receivers: Vec, worker_senders: Vec>, morsel_size: usize, @@ -196,41 +202,16 @@ impl PipelineNode for IntermediateNode { child_result_receivers .push(child_result_channel.get_receiver_with_stats(&self.runtime_stats)); } + let mut destination_channel = PipelineChannel::new(*NUM_CPUS, maintain_order); - let destination_channel = PipelineChannel::new(); - let destination_sender = destination_channel.get_sender_with_stats(&self.runtime_stats); - - let op = self.intermediate_op.clone(); - let stats = self.runtime_stats.clone(); - let morsel_size = runtime_handle.default_morsel_size(); + let worker_senders = + self.spawn_workers(*NUM_CPUS, &mut destination_channel, runtime_handle); runtime_handle.spawn( - async move { - let num_workers = *NUM_CPUS; - let (input_senders, input_receivers) = - (0..num_workers).map(|_| create_channel(1)).unzip(); - let (output_senders, mut output_receiver) = - create_ordering_aware_channel(num_workers, maintain_order); - - let mut worker_set = create_worker_set(); - Self::spawn_workers( - op.clone(), - input_receivers, - output_senders, - &mut worker_set, - stats.clone(), - ); - - Self::forward_input_to_workers(child_result_receivers, input_senders, morsel_size) - .await?; - - while let Some(morsel) = output_receiver.recv().await { - let _ = destination_sender.send(morsel.into()).await; - } - while let Some(result) = worker_set.join_next().await { - result.context(JoinSnafu)??; - } - Ok(()) - }, + Self::send_to_workers( + child_result_receivers, + worker_senders, + runtime_handle.default_morsel_size(), + ), self.intermediate_op.name(), ); Ok(destination_channel) diff --git a/src/daft-local-execution/src/runtime_stats.rs b/src/daft-local-execution/src/runtime_stats.rs index cc8b8a368a..7489a8fd36 100644 --- a/src/daft-local-execution/src/runtime_stats.rs +++ b/src/daft-local-execution/src/runtime_stats.rs @@ -8,7 +8,7 @@ use std::{ use tokio::sync::mpsc::error::SendError; use crate::{ - channel::{Receiver, Sender}, + channel::{PipelineReceiver, Sender}, pipeline::PipelineResultType, }; @@ -133,15 +133,12 @@ impl CountingSender { } pub(crate) struct CountingReceiver { - receiver: Receiver, + receiver: PipelineReceiver, rt: Arc, } impl CountingReceiver { - pub(crate) fn new( - receiver: Receiver, - rt: Arc, - ) -> Self { + pub(crate) fn new(receiver: PipelineReceiver, rt: Arc) -> Self { Self { receiver, rt } } #[inline] diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index 00012cb2f1..09e42ae81f 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -78,7 +78,7 @@ impl PipelineNode for BlockingSinkNode { fn start( &mut self, - _maintain_order: bool, + maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, ) -> crate::Result { let child = self.child.as_mut(); @@ -86,8 +86,9 @@ impl PipelineNode for BlockingSinkNode { .start(false, runtime_handle)? .get_receiver_with_stats(&self.runtime_stats); - let destination_channel = PipelineChannel::new(); - let destination_sender = destination_channel.get_sender_with_stats(&self.runtime_stats); + let mut destination_channel = PipelineChannel::new(1, maintain_order); + let destination_sender = + destination_channel.get_next_sender_with_stats(&self.runtime_stats); let op = self.op.clone(); let rt_context = self.runtime_stats.clone(); runtime_handle.spawn( diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index b580ae1d75..d3d09063a6 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -7,7 +7,7 @@ use snafu::ResultExt; use tracing::{info_span, instrument}; use crate::{ - channel::{create_channel, create_ordering_aware_channel, PipelineChannel, Receiver, Sender}, + channel::{create_channel, PipelineChannel, Receiver, Sender}, create_worker_set, pipeline::{PipelineNode, PipelineResultType}, runtime_stats::{CountingReceiver, RuntimeStatsContext}, @@ -187,9 +187,9 @@ impl PipelineNode for StreamingSinkNode { .push(child_result_channel.get_receiver_with_stats(&self.runtime_stats.clone())); } - let destination_channel = PipelineChannel::new(); + let mut destination_channel = PipelineChannel::new(1, maintain_order); let destination_sender = - destination_channel.get_sender_with_stats(&self.runtime_stats.clone()); + destination_channel.get_next_sender_with_stats(&self.runtime_stats); let op = self.op.clone(); let runtime_stats = self.runtime_stats.clone(); @@ -198,14 +198,13 @@ impl PipelineNode for StreamingSinkNode { let num_workers = op.max_concurrency(); let (input_senders, input_receivers) = (0..num_workers).map(|_| create_channel(1)).unzip(); - let (output_senders, mut output_receiver) = - create_ordering_aware_channel(num_workers, maintain_order); + let (output_sender, mut output_receiver) = create_channel(num_workers); let mut worker_set = create_worker_set(); Self::spawn_workers( op.clone(), input_receivers, - output_senders, + (0..num_workers).map(|_| output_sender.clone()).collect(), &mut worker_set, runtime_stats.clone(), ); diff --git a/src/daft-local-execution/src/sources/source.rs b/src/daft-local-execution/src/sources/source.rs index 1e6d77c814..175dc66427 100644 --- a/src/daft-local-execution/src/sources/source.rs +++ b/src/daft-local-execution/src/sources/source.rs @@ -76,8 +76,8 @@ impl PipelineNode for SourceNode { self.source .get_data(maintain_order, runtime_handle, self.io_stats.clone())?; - let channel = PipelineChannel::new(); - let counting_sender = channel.get_sender_with_stats(&self.runtime_stats); + let mut channel = PipelineChannel::new(1, maintain_order); + let counting_sender = channel.get_next_sender_with_stats(&self.runtime_stats); runtime_handle.spawn( async move { while let Some(part) = source_stream.next().await { From c797da442b3f3384923b26747edb4b580dd49138 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Thu, 19 Sep 2024 14:31:24 -0700 Subject: [PATCH 06/13] need to drop sender --- .../src/sinks/streaming_sink.rs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index d3d09063a6..1ed9758159 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -103,18 +103,19 @@ impl StreamingSinkNode { fn spawn_workers( op: Arc, input_receivers: Vec>, - output_senders: Vec>>, worker_set: &mut WorkerSet>>, stats: Arc, - ) { - for (input_receiver, output_sender) in input_receivers.into_iter().zip(output_senders) { + ) -> Receiver> { + let (output_sender, output_receiver) = create_channel(input_receivers.len()); + for input_receiver in input_receivers { worker_set.spawn(Self::run_worker( op.clone(), input_receiver, - output_sender, + output_sender.clone(), stats.clone(), )); } + output_receiver } async fn forward_input_to_workers( @@ -198,13 +199,11 @@ impl PipelineNode for StreamingSinkNode { let num_workers = op.max_concurrency(); let (input_senders, input_receivers) = (0..num_workers).map(|_| create_channel(1)).unzip(); - let (output_sender, mut output_receiver) = create_channel(num_workers); let mut worker_set = create_worker_set(); - Self::spawn_workers( + let mut output_receiver = Self::spawn_workers( op.clone(), input_receivers, - (0..num_workers).map(|_| output_sender.clone()).collect(), &mut worker_set, runtime_stats.clone(), ); From dbe7723d51cdf01db710ae2e2c4f5d07184d8d68 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Wed, 25 Sep 2024 15:41:29 -0700 Subject: [PATCH 07/13] bitmap improvements + probe state --- Cargo.lock | 1 - src/arrow2/src/bitmap/mutable.rs | 19 -- src/daft-core/src/prelude.rs | 2 + src/daft-local-execution/Cargo.toml | 1 - .../anti_semi_hash_join_probe.rs | 8 +- .../intermediate_ops/inner_hash_join_probe.rs | 23 ++- src/daft-local-execution/src/pipeline.rs | 16 +- src/daft-local-execution/src/runtime_stats.rs | 8 +- .../src/sinks/hash_join_build.rs | 16 +- .../src/sinks/outer_hash_join_probe.rs | 191 ++++++++++-------- src/daft-table/src/lib.rs | 2 +- src/daft-table/src/probeable/mod.rs | 20 ++ 12 files changed, 162 insertions(+), 145 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a7fec8ec0e..90e4745ba8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1948,7 +1948,6 @@ dependencies = [ name = "daft-local-execution" version = "0.3.0-dev0" dependencies = [ - "arrow2", "common-daft-config", "common-display", "common-error", diff --git a/src/arrow2/src/bitmap/mutable.rs b/src/arrow2/src/bitmap/mutable.rs index edf373a506..cb77decd84 100644 --- a/src/arrow2/src/bitmap/mutable.rs +++ b/src/arrow2/src/bitmap/mutable.rs @@ -325,25 +325,6 @@ impl MutableBitmap { pub(crate) fn bitchunks_exact_mut(&mut self) -> BitChunksExactMut { BitChunksExactMut::new(&mut self.buffer, self.length) } - - pub fn or(&self, other: &MutableBitmap) -> MutableBitmap { - assert_eq!( - self.length, other.length, - "Bitmaps must have the same length" - ); - - let new_buffer: Vec = self - .buffer - .iter() - .zip(other.buffer.iter()) - .map(|(&a, &b)| a | b) // Apply bitwise OR on each pair of bytes - .collect(); - - MutableBitmap { - buffer: new_buffer, - length: self.length, - } - } } impl From for Bitmap { diff --git a/src/daft-core/src/prelude.rs b/src/daft-core/src/prelude.rs index 3b71045ddd..0de797d1b0 100644 --- a/src/daft-core/src/prelude.rs +++ b/src/daft-core/src/prelude.rs @@ -3,6 +3,8 @@ //! This module re-exports commonly used items from the Daft core library. // Re-export core series structures +// Re-export arrow2 bitmap +pub use arrow2::bitmap; pub use daft_schema::schema::{Schema, SchemaRef}; // Re-export count mode enum diff --git a/src/daft-local-execution/Cargo.toml b/src/daft-local-execution/Cargo.toml index bf49c9fa9d..2dac516672 100644 --- a/src/daft-local-execution/Cargo.toml +++ b/src/daft-local-execution/Cargo.toml @@ -1,5 +1,4 @@ [dependencies] -arrow2 = {workspace = true} common-daft-config = {path = "../common/daft-config", default-features = false} common-display = {path = "../common/display", default-features = false} common-error = {path = "../common/error", default-features = false} diff --git a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs index 637beb3fc5..3050f6339a 100644 --- a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs @@ -18,9 +18,9 @@ enum AntiSemiProbeState { } impl AntiSemiProbeState { - fn set_table(&mut self, table: &Arc) { + fn set_table(&mut self, table: Arc) { if let AntiSemiProbeState::Building = self { - *self = AntiSemiProbeState::ReadyToProbe(table.clone()); + *self = AntiSemiProbeState::ReadyToProbe(table); } else { panic!("AntiSemiProbeState should only be in Building state when setting table") } @@ -109,8 +109,8 @@ impl IntermediateOperator for AntiSemiProbeOperator { .as_any_mut() .downcast_mut::() .expect("AntiSemiProbeOperator state should be AntiSemiProbeState"); - let (probe_table, _) = input.as_probe_table(); - state.set_table(probe_table); + let probe_state = input.as_probe_state(); + state.set_table(probe_state.get_probeable()); Ok(IntermediateOperatorResult::NeedMoreInput(None)) } _ => { diff --git a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs index e102804770..ae48eea4ab 100644 --- a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs @@ -4,7 +4,7 @@ use common_error::DaftResult; use daft_core::prelude::SchemaRef; use daft_dsl::ExprRef; use daft_micropartition::MicroPartition; -use daft_table::{GrowableTable, Probeable, Table}; +use daft_table::{GrowableTable, ProbeState}; use indexmap::IndexSet; use tracing::{info_span, instrument}; @@ -15,21 +15,21 @@ use crate::pipeline::PipelineResultType; enum InnerHashJoinProbeState { Building, - ReadyToProbe(Arc, Arc>), + ReadyToProbe(Arc), } impl InnerHashJoinProbeState { - fn set_table(&mut self, table: &Arc, tables: &Arc>) { + fn set_probe_state(&mut self, probe_state: Arc) { if let InnerHashJoinProbeState::Building = self { - *self = InnerHashJoinProbeState::ReadyToProbe(table.clone(), tables.clone()); + *self = InnerHashJoinProbeState::ReadyToProbe(probe_state); } else { panic!("InnerHashJoinProbeState should only be in Building state when setting table") } } - fn get_probeable_and_table(&self) -> (&Arc, &Arc>) { - if let InnerHashJoinProbeState::ReadyToProbe(probe_table, tables) = self { - (probe_table, tables) + fn get_probe_state(&self) -> Arc { + if let InnerHashJoinProbeState::ReadyToProbe(probe_state) = self { + probe_state.clone() } else { panic!("get_probeable_and_table can only be used during the ReadyToProbe Phase") } @@ -85,7 +85,10 @@ impl InnerHashJoinProbeOperator { input: &Arc, state: &mut InnerHashJoinProbeState, ) -> DaftResult> { - let (probe_table, tables) = state.get_probeable_and_table(); + let (probe_table, tables) = { + let probe_state = state.get_probe_state(); + (probe_state.get_probeable(), probe_state.get_tables()) + }; let _growables = info_span!("InnerHashJoinOperator::build_growables").entered(); @@ -159,8 +162,8 @@ impl IntermediateOperator for InnerHashJoinProbeOperator { .expect("InnerHashJoinProbeOperator state should be InnerHashJoinProbeState"); match idx { 0 => { - let (probe_table, tables) = input.as_probe_table(); - state.set_table(probe_table, tables); + let probe_state = input.as_probe_state(); + state.set_probe_state(probe_state.clone()); Ok(IntermediateOperatorResult::NeedMoreInput(None)) } _ => { diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index 7f990fc63d..2a37ca2308 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -14,7 +14,7 @@ use daft_physical_plan::{ UnGroupedAggregate, }; use daft_plan::{populate_aggregation_stages, JoinType}; -use daft_table::{Probeable, Table}; +use daft_table::ProbeState; use indexmap::IndexSet; use snafu::ResultExt; @@ -38,7 +38,7 @@ use crate::{ #[derive(Clone)] pub enum PipelineResultType { Data(Arc), - ProbeTable(Arc, Arc>), + ProbeState(Arc), } impl From> for PipelineResultType { @@ -47,9 +47,9 @@ impl From> for PipelineResultType { } } -impl From<(Arc, Arc>)> for PipelineResultType { - fn from((probe_table, tables): (Arc, Arc>)) -> Self { - PipelineResultType::ProbeTable(probe_table, tables) +impl From> for PipelineResultType { + fn from(probe_state: Arc) -> Self { + PipelineResultType::ProbeState(probe_state) } } @@ -61,15 +61,15 @@ impl PipelineResultType { } } - pub fn as_probe_table(&self) -> (&Arc, &Arc>) { + pub fn as_probe_state(&self) -> &Arc { match self { - PipelineResultType::ProbeTable(probe_table, tables) => (probe_table, tables), + PipelineResultType::ProbeState(probe_state) => probe_state, _ => panic!("Expected probe table"), } } pub fn should_broadcast(&self) -> bool { - matches!(self, PipelineResultType::ProbeTable(_, _)) + matches!(self, PipelineResultType::ProbeState(_)) } } diff --git a/src/daft-local-execution/src/runtime_stats.rs b/src/daft-local-execution/src/runtime_stats.rs index 7489a8fd36..3afddd27b3 100644 --- a/src/daft-local-execution/src/runtime_stats.rs +++ b/src/daft-local-execution/src/runtime_stats.rs @@ -124,7 +124,9 @@ impl CountingSender { ) -> Result<(), SendError> { let len = match v { PipelineResultType::Data(ref mp) => mp.len(), - PipelineResultType::ProbeTable(_, ref tables) => tables.iter().map(|t| t.len()).sum(), + PipelineResultType::ProbeState(ref state) => { + state.get_tables().iter().map(|t| t.len()).sum() + } }; self.sender.send(v).await?; self.rt.mark_rows_emitted(len as u64); @@ -147,8 +149,8 @@ impl CountingReceiver { if let Some(ref v) = v { let len = match v { PipelineResultType::Data(ref mp) => mp.len(), - PipelineResultType::ProbeTable(_, ref tables) => { - tables.iter().map(|t| t.len()).sum() + PipelineResultType::ProbeState(state) => { + state.get_tables().iter().map(|t| t.len()).sum() } }; self.rt.mark_rows_received(len as u64); diff --git a/src/daft-local-execution/src/sinks/hash_join_build.rs b/src/daft-local-execution/src/sinks/hash_join_build.rs index 5f84045101..6c106bc677 100644 --- a/src/daft-local-execution/src/sinks/hash_join_build.rs +++ b/src/daft-local-execution/src/sinks/hash_join_build.rs @@ -5,7 +5,7 @@ use daft_core::prelude::SchemaRef; use daft_dsl::ExprRef; use daft_micropartition::MicroPartition; use daft_plan::JoinType; -use daft_table::{make_probeable_builder, Probeable, ProbeableBuilder, Table}; +use daft_table::{make_probeable_builder, ProbeState, ProbeableBuilder, Table}; use super::blocking_sink::{BlockingSink, BlockingSinkStatus}; use crate::pipeline::PipelineResultType; @@ -17,8 +17,7 @@ enum ProbeTableState { tables: Vec, }, Done { - probe_table: Arc, - tables: Arc>, + probe_state: Arc, }, } @@ -66,8 +65,7 @@ impl ProbeTableState { let pt = ptb.build(); *self = Self::Done { - probe_table: pt, - tables: Arc::new(tables.clone()), + probe_state: Arc::new(ProbeState::new(pt, Arc::new(tables.clone()))), }; Ok(()) } else { @@ -108,12 +106,8 @@ impl BlockingSink for HashJoinBuildSink { fn finalize(&mut self) -> DaftResult> { self.probe_table_state.finalize()?; - if let ProbeTableState::Done { - probe_table, - tables, - } = &self.probe_table_state - { - Ok(Some((probe_table.clone(), tables.clone()).into())) + if let ProbeTableState::Done { probe_state } = &self.probe_table_state { + Ok(Some(probe_state.clone().into())) } else { panic!("finalize should only be called after the probe table is built") } diff --git a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs index 740a060e80..d2b1a389bb 100644 --- a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs +++ b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs @@ -1,89 +1,103 @@ use std::sync::Arc; -use arrow2::bitmap::MutableBitmap; use common_error::DaftResult; use daft_core::{ - prelude::{Schema, SchemaRef}, + prelude::{ + bitmap::{or, Bitmap, MutableBitmap}, + Schema, SchemaRef, + }, series::Series, }; use daft_dsl::ExprRef; use daft_micropartition::MicroPartition; use daft_plan::JoinType; -use daft_table::{GrowableTable, Probeable, Table}; +use daft_table::{GrowableTable, ProbeState, Table}; use indexmap::IndexSet; use tracing::{info_span, instrument}; use super::streaming_sink::{StreamingSink, StreamingSinkOutput, StreamingSinkState}; use crate::pipeline::PipelineResultType; -struct IndexBitmapTracker { - bitmaps: Vec, +struct IndexBitmapBuilder { + bitmap: MutableBitmap, } -impl IndexBitmapTracker { - fn new(tables: &Arc>) -> Self { - let bitmaps = tables - .iter() - .map(|table| MutableBitmap::from_len_zeroed(table.len())) - .collect(); - Self { bitmaps } +impl IndexBitmapBuilder { + const TABLE_IDX_SHIFT: usize = 36; + const LOWER_MASK: u64 = (1 << Self::TABLE_IDX_SHIFT) - 1; + + fn new(tables: &[Table]) -> Self { + let total_len = tables.iter().map(|t| t.len()).sum(); + Self { + bitmap: MutableBitmap::with_capacity(total_len), + } } - fn set_true(&mut self, table_idx: usize, row_idx: usize) { - self.bitmaps[table_idx].set(row_idx, true); + #[inline] + fn mark_used(&mut self, table_idx: usize, row_idx: usize) { + let idx = (table_idx as u64) << Self::TABLE_IDX_SHIFT | row_idx as u64; + self.bitmap.set(idx as usize, true); } + fn build(self) -> IndexBitmap { + IndexBitmap { + bitmap: self.bitmap.into(), + table_idx_shift: Self::TABLE_IDX_SHIFT, + lower_mask: Self::LOWER_MASK, + } + } +} + +struct IndexBitmap { + bitmap: Bitmap, + table_idx_shift: usize, + lower_mask: u64, +} + +impl IndexBitmap { fn or(&self, other: &Self) -> Self { - let bitmaps = self - .bitmaps - .iter() - .zip(other.bitmaps.iter()) - .map(|(a, b)| a.or(b)) - .collect(); - Self { bitmaps } + assert_eq!(self.table_idx_shift, other.table_idx_shift); + assert_eq!(self.lower_mask, other.lower_mask); + Self { + bitmap: or(&self.bitmap, &other.bitmap), + table_idx_shift: self.table_idx_shift, + lower_mask: self.lower_mask, + } + } + + fn unset_bits(&self) -> usize { + self.bitmap.unset_bits() } fn get_unused_indices(&self) -> impl Iterator + '_ { - self.bitmaps + self.bitmap .iter() .enumerate() - .flat_map(|(table_idx, bitmap)| { - bitmap - .iter() - .enumerate() - .filter_map(move |(row_idx, is_set)| { - if !is_set { - Some((table_idx, row_idx)) - } else { - None - } - }) + .filter_map(move |(idx, is_set)| { + if !is_set { + let table_idx = idx >> self.table_idx_shift; + let row_idx = (idx as u64 & self.lower_mask) as usize; + Some((table_idx, row_idx)) + } else { + None + } }) } } enum OuterHashJoinProbeState { Building, - ReadyToProbe( - Arc, - Arc>, - Option, - ), + ReadyToProbe(Arc, Option), } impl OuterHashJoinProbeState { - fn initialize_probe_state( - &mut self, - table: &Arc, - tables: &Arc>, - needs_bitmap: bool, - ) { + fn initialize_probe_state(&mut self, probe_state: Arc, needs_bitmap: bool) { + let tables = probe_state.get_tables().clone(); if let OuterHashJoinProbeState::Building = self { *self = OuterHashJoinProbeState::ReadyToProbe( - table.clone(), - tables.clone(), + probe_state, if needs_bitmap { - Some(IndexBitmapTracker::new(tables)) + Some(IndexBitmapBuilder::new(&tables)) } else { None }, @@ -93,17 +107,17 @@ impl OuterHashJoinProbeState { } } - fn get_probeable_and_tables(&self) -> (Arc, Arc>) { - if let OuterHashJoinProbeState::ReadyToProbe(probe_table, tables, _) = self { - (probe_table.clone(), tables.clone()) + fn get_probe_state(&self) -> &ProbeState { + if let OuterHashJoinProbeState::ReadyToProbe(probe_state, _) = self { + probe_state } else { panic!("get_probeable_and_table can only be used during the ReadyToProbe Phase") } } - fn get_bitmap(&mut self) -> &mut Option { - if let OuterHashJoinProbeState::ReadyToProbe(_, _, bitmap) = self { - bitmap + fn get_bitmap_builder(&mut self) -> &mut Option { + if let OuterHashJoinProbeState::ReadyToProbe(_, bitmap_builder) = self { + bitmap_builder } else { panic!("get_bitmap can only be used during the ReadyToProbe Phase") } @@ -164,7 +178,10 @@ impl OuterHashJoinProbeSink { input: &Arc, state: &mut OuterHashJoinProbeState, ) -> DaftResult> { - let (probe_table, tables) = state.get_probeable_and_tables(); + let (probe_table, tables) = { + let probe_state = state.get_probe_state(); + (probe_state.get_probeable(), probe_state.get_tables()) + }; let _growables = info_span!("OuterHashJoinProbeSink::build_growables").entered(); @@ -230,8 +247,11 @@ impl OuterHashJoinProbeSink { input: &Arc, state: &mut OuterHashJoinProbeState, ) -> DaftResult> { - let (probe_table, tables) = state.get_probeable_and_tables(); - let bitmap = state.get_bitmap(); + let (probe_table, tables) = { + let probe_state = state.get_probe_state(); + (probe_state.get_probeable(), probe_state.get_tables()) + }; + let bitmap_builder = state.get_bitmap_builder(); let _growables = info_span!("OuterHashJoinProbeSink::build_growables").entered(); // Need to set use_validity to true here because we add nulls to the build side @@ -246,7 +266,9 @@ impl OuterHashJoinProbeSink { let mut probe_side_growable = GrowableTable::new(&input_tables.iter().collect::>(), false, input.len())?; - let left_idx_used = bitmap.as_mut().expect("bitmap should be set in outer join"); + let left_idx_used = bitmap_builder + .as_mut() + .expect("bitmap should be set in outer join"); drop(_growables); { @@ -258,13 +280,10 @@ impl OuterHashJoinProbeSink { for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { if let Some(inner_iter) = inner_iter { for (build_side_table_idx, build_row_idx) in inner_iter { - left_idx_used - .set_true(build_side_table_idx as usize, build_row_idx as usize); - build_side_growable.extend( - build_side_table_idx as usize, - build_row_idx as usize, - 1, - ); + let build_side_table_idx = build_side_table_idx as usize; + let build_row_idx = build_row_idx as usize; + left_idx_used.mark_used(build_side_table_idx, build_row_idx); + build_side_growable.extend(build_side_table_idx, build_row_idx, 1); probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); } } else { @@ -304,31 +323,32 @@ impl OuterHashJoinProbeSink { let tables = states .first() .expect("at least one state should be present") - .get_probeable_and_tables() - .1; + .get_probe_state() + .get_tables(); let merged_bitmap = { - let bitmaps = states - .into_iter() - .map(|s| { - if let OuterHashJoinProbeState::ReadyToProbe(_, _, bitmap) = s { - bitmap - .take() - .expect("bitmap should be present in outer join") - } else { - panic!("OuterHashJoinProbeState should be in ReadyToProbe state") - } - }) - .collect::>(); - bitmaps.into_iter().fold(None, |acc, x| match acc { + let bitmaps = states.into_iter().map(|s| { + if let OuterHashJoinProbeState::ReadyToProbe(_, bitmap) = s { + bitmap + .take() + .expect("bitmap should be present in outer join") + .build() + } else { + panic!("OuterHashJoinProbeState should be in ReadyToProbe state") + } + }); + bitmaps.fold(None, |acc, x| match acc { None => Some(x), Some(acc) => Some(acc.or(&x)), }) } .expect("at least one bitmap should be present"); - let mut build_side_growable = - GrowableTable::new(&tables.iter().collect::>(), true, 20)?; + let mut build_side_growable = GrowableTable::new( + &tables.iter().collect::>(), + true, + merged_bitmap.unset_bits(), + )?; for (table_idx, row_idx) in merged_bitmap.get_unused_indices() { build_side_growable.extend(table_idx, row_idx, 1); @@ -370,12 +390,9 @@ impl StreamingSink for OuterHashJoinProbeSink { .as_any_mut() .downcast_mut::() .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState"); - let (probe_table, tables) = input.as_probe_table(); - state.initialize_probe_state( - probe_table, - tables, - self.join_type == JoinType::Outer, - ); + let probe_state = input.as_probe_state(); + state + .initialize_probe_state(probe_state.clone(), self.join_type == JoinType::Outer); Ok(StreamingSinkOutput::NeedMoreInput(None)) } _ => { diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 100ca31726..484e30fe3b 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -28,7 +28,7 @@ mod probeable; mod repr_html; pub use growable::GrowableTable; -pub use probeable::{make_probeable_builder, Probeable, ProbeableBuilder}; +pub use probeable::{make_probeable_builder, ProbeState, Probeable, ProbeableBuilder}; #[cfg(feature = "python")] pub mod python; diff --git a/src/daft-table/src/probeable/mod.rs b/src/daft-table/src/probeable/mod.rs index a3d935246e..5dfae13b30 100644 --- a/src/daft-table/src/probeable/mod.rs +++ b/src/daft-table/src/probeable/mod.rs @@ -77,3 +77,23 @@ pub trait Probeable: Send + Sync { table: &'a Table, ) -> DaftResult + 'a>>; } + +#[derive(Clone)] +pub struct ProbeState { + probeable: Arc, + tables: Arc>, +} + +impl ProbeState { + pub fn new(probeable: Arc, tables: Arc>) -> Self { + Self { probeable, tables } + } + + pub fn get_probeable(&self) -> Arc { + self.probeable.clone() + } + + pub fn get_tables(&self) -> Arc> { + self.tables.clone() + } +} From 94245f810e81791481538c054f41077e267d527a Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Wed, 25 Sep 2024 16:36:09 -0700 Subject: [PATCH 08/13] prefix sum --- src/arrow2/src/array/growable/primitive.rs | 5 +- .../src/sinks/outer_hash_join_probe.rs | 56 ++++++++++++------- 2 files changed, 36 insertions(+), 25 deletions(-) diff --git a/src/arrow2/src/array/growable/primitive.rs b/src/arrow2/src/array/growable/primitive.rs index e443756cb9..4083cb49db 100644 --- a/src/arrow2/src/array/growable/primitive.rs +++ b/src/arrow2/src/array/growable/primitive.rs @@ -1,10 +1,7 @@ use std::sync::Arc; use crate::{ - array::{Array, PrimitiveArray}, - bitmap::MutableBitmap, - datatypes::DataType, - types::NativeType, + array::{Array, PrimitiveArray}, bitmap::MutableBitmap, datatypes::DataType, types::NativeType }; use super::{ diff --git a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs index d2b1a389bb..7edebae898 100644 --- a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs +++ b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs @@ -20,48 +20,55 @@ use crate::pipeline::PipelineResultType; struct IndexBitmapBuilder { bitmap: MutableBitmap, + prefix_sums: Vec, } impl IndexBitmapBuilder { - const TABLE_IDX_SHIFT: usize = 36; - const LOWER_MASK: u64 = (1 << Self::TABLE_IDX_SHIFT) - 1; - fn new(tables: &[Table]) -> Self { - let total_len = tables.iter().map(|t| t.len()).sum(); + println!("tables: {:#?}", tables); + let prefix_sums = tables + .iter() + .map(|t| t.len()) + .scan(0, |acc, x| { + let prev = *acc; + *acc += x; + Some(prev) + }) + .collect::>(); + println!("prefix_sums: {:?}", prefix_sums); + let total_len = prefix_sums.last().unwrap() + tables.last().unwrap().len(); + println!("total_len: {:?}", total_len); Self { - bitmap: MutableBitmap::with_capacity(total_len), + bitmap: MutableBitmap::from_len_zeroed(total_len), + prefix_sums, } } #[inline] fn mark_used(&mut self, table_idx: usize, row_idx: usize) { - let idx = (table_idx as u64) << Self::TABLE_IDX_SHIFT | row_idx as u64; - self.bitmap.set(idx as usize, true); + let idx = self.prefix_sums[table_idx] + row_idx; + self.bitmap.set(idx, true); } fn build(self) -> IndexBitmap { IndexBitmap { bitmap: self.bitmap.into(), - table_idx_shift: Self::TABLE_IDX_SHIFT, - lower_mask: Self::LOWER_MASK, + prefix_sums: self.prefix_sums, } } } struct IndexBitmap { bitmap: Bitmap, - table_idx_shift: usize, - lower_mask: u64, + prefix_sums: Vec, } impl IndexBitmap { fn or(&self, other: &Self) -> Self { - assert_eq!(self.table_idx_shift, other.table_idx_shift); - assert_eq!(self.lower_mask, other.lower_mask); + assert_eq!(self.prefix_sums, other.prefix_sums); Self { bitmap: or(&self.bitmap, &other.bitmap), - table_idx_shift: self.table_idx_shift, - lower_mask: self.lower_mask, + prefix_sums: self.prefix_sums.clone(), } } @@ -70,16 +77,21 @@ impl IndexBitmap { } fn get_unused_indices(&self) -> impl Iterator + '_ { + let mut curr_table = 0; self.bitmap .iter() .enumerate() .filter_map(move |(idx, is_set)| { - if !is_set { - let table_idx = idx >> self.table_idx_shift; - let row_idx = (idx as u64 & self.lower_mask) as usize; - Some((table_idx, row_idx)) - } else { + if is_set { None + } else { + while curr_table < self.prefix_sums.len() - 1 + && idx >= self.prefix_sums[curr_table + 1] + { + curr_table += 1; + } + let row_idx = idx - self.prefix_sums[curr_table]; + Some((curr_table, row_idx)) } }) } @@ -343,7 +355,8 @@ impl OuterHashJoinProbeSink { }) } .expect("at least one bitmap should be present"); - + println!("num nulls: {}", merged_bitmap.unset_bits()); + println!("tables: {:#?}", tables); let mut build_side_growable = GrowableTable::new( &tables.iter().collect::>(), true, @@ -351,6 +364,7 @@ impl OuterHashJoinProbeSink { )?; for (table_idx, row_idx) in merged_bitmap.get_unused_indices() { + println!("table_idx: {:?}, row_idx: {:?}", table_idx, row_idx); build_side_growable.extend(table_idx, row_idx, 1); } From 787f29fb985ae0491bf079f10380fcdafed82946 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Wed, 25 Sep 2024 16:43:57 -0700 Subject: [PATCH 09/13] cleanup --- .../anti_semi_hash_join_probe.rs | 15 ++++++++++----- .../intermediate_ops/inner_hash_join_probe.rs | 16 ++++++++++++---- .../src/sinks/outer_hash_join_probe.rs | 7 +------ 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs index 09509ab221..e67cb069a9 100644 --- a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs @@ -18,9 +18,9 @@ enum AntiSemiProbeState { } impl AntiSemiProbeState { - fn set_table(&mut self, table: Arc) { + fn set_table(&mut self, table: &Arc) { if let Self::Building = self { - *self = Self::ReadyToProbe(table); + *self = Self::ReadyToProbe(table.clone()); } else { panic!("AntiSemiProbeState should only be in Building state when setting table") } @@ -47,6 +47,8 @@ pub struct AntiSemiProbeOperator { } impl AntiSemiProbeOperator { + const DEFAULT_GROWABLE_SIZE: usize = 20; + pub fn new(probe_on: Vec, join_type: &JoinType) -> Self { Self { probe_on, @@ -65,8 +67,11 @@ impl AntiSemiProbeOperator { let input_tables = input.get_tables()?; - let mut probe_side_growable = - GrowableTable::new(&input_tables.iter().collect::>(), false, 20)?; + let mut probe_side_growable = GrowableTable::new( + &input_tables.iter().collect::>(), + false, + Self::DEFAULT_GROWABLE_SIZE, + )?; drop(_growables); { @@ -110,7 +115,7 @@ impl IntermediateOperator for AntiSemiProbeOperator { .downcast_mut::() .expect("AntiSemiProbeOperator state should be AntiSemiProbeState"); let probe_state = input.as_probe_state(); - state.set_table(probe_state.get_probeable()); + state.set_table(&probe_state.get_probeable()); Ok(IntermediateOperatorResult::NeedMoreInput(None)) } _ => { diff --git a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs index 6359c0761d..5f857a0e77 100644 --- a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs @@ -51,6 +51,8 @@ pub struct InnerHashJoinProbeOperator { } impl InnerHashJoinProbeOperator { + const DEFAULT_GROWABLE_SIZE: usize = 20; + pub fn new( probe_on: Vec, left_schema: &SchemaRef, @@ -92,13 +94,19 @@ impl InnerHashJoinProbeOperator { let _growables = info_span!("InnerHashJoinOperator::build_growables").entered(); - let mut build_side_growable = - GrowableTable::new(&tables.iter().collect::>(), false, 20)?; + let mut build_side_growable = GrowableTable::new( + &tables.iter().collect::>(), + false, + Self::DEFAULT_GROWABLE_SIZE, + )?; let input_tables = input.get_tables()?; - let mut probe_side_growable = - GrowableTable::new(&input_tables.iter().collect::>(), false, 20)?; + let mut probe_side_growable = GrowableTable::new( + &input_tables.iter().collect::>(), + false, + Self::DEFAULT_GROWABLE_SIZE, + )?; drop(_growables); { diff --git a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs index 1a8753eb7c..c416e38aca 100644 --- a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs +++ b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs @@ -25,7 +25,6 @@ struct IndexBitmapBuilder { impl IndexBitmapBuilder { fn new(tables: &[Table]) -> Self { - println!("tables: {:#?}", tables); let prefix_sums = tables .iter() .map(|t| t.len()) @@ -35,9 +34,7 @@ impl IndexBitmapBuilder { Some(prev) }) .collect::>(); - println!("prefix_sums: {:?}", prefix_sums); let total_len = prefix_sums.last().unwrap() + tables.last().unwrap().len(); - println!("total_len: {:?}", total_len); Self { bitmap: MutableBitmap::from_len_zeroed(total_len), prefix_sums, @@ -355,8 +352,7 @@ impl OuterHashJoinProbeSink { }) } .expect("at least one bitmap should be present"); - println!("num nulls: {}", merged_bitmap.unset_bits()); - println!("tables: {:#?}", tables); + let mut build_side_growable = GrowableTable::new( &tables.iter().collect::>(), true, @@ -364,7 +360,6 @@ impl OuterHashJoinProbeSink { )?; for (table_idx, row_idx) in merged_bitmap.get_unused_indices() { - println!("table_idx: {:?}, row_idx: {:?}", table_idx, row_idx); build_side_growable.extend(table_idx, row_idx, 1); } From a2d417ae7652578be7846e38dcc884eb6a74cfb9 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Wed, 25 Sep 2024 20:10:22 -0700 Subject: [PATCH 10/13] fix deadlock; --- src/daft-local-execution/src/sinks/streaming_sink.rs | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index 768d3ca5a2..f1b4935074 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -194,12 +194,14 @@ impl PipelineNode for StreamingSinkNode { let op = self.op.clone(); let runtime_stats = self.runtime_stats.clone(); + let num_workers = op.max_concurrency(); + let (input_senders, input_receivers) = (0..num_workers).map(|_| create_channel(1)).unzip(); + runtime_handle.spawn( + Self::forward_input_to_workers(child_result_receivers, input_senders), + self.name(), + ); runtime_handle.spawn( async move { - let num_workers = op.max_concurrency(); - let (input_senders, input_receivers) = - (0..num_workers).map(|_| create_channel(1)).unzip(); - let mut worker_set = create_worker_set(); let mut output_receiver = Self::spawn_workers( op.clone(), @@ -208,7 +210,6 @@ impl PipelineNode for StreamingSinkNode { runtime_stats.clone(), ); - Self::forward_input_to_workers(child_result_receivers, input_senders).await?; while let Some(morsel) = output_receiver.recv().await { let _ = destination_sender.send(morsel.into()).await; } From f93847ac14fcee588acddcedcf4a45e6ccd65d84 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Wed, 25 Sep 2024 21:47:19 -0700 Subject: [PATCH 11/13] check for empty --- .../src/intermediate_ops/anti_semi_hash_join_probe.rs | 9 ++++++++- .../src/intermediate_ops/inner_hash_join_probe.rs | 7 +++++++ src/daft-local-execution/src/pipeline.rs | 10 ++++++++-- .../src/sinks/outer_hash_join_probe.rs | 7 +++++++ 4 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs index e67cb069a9..df3e012a71 100644 --- a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use common_error::DaftResult; +use daft_core::prelude::SchemaRef; use daft_dsl::ExprRef; use daft_micropartition::MicroPartition; use daft_plan::JoinType; @@ -44,15 +45,17 @@ impl IntermediateOperatorState for AntiSemiProbeState { pub struct AntiSemiProbeOperator { probe_on: Vec, is_semi: bool, + output_schema: SchemaRef, } impl AntiSemiProbeOperator { const DEFAULT_GROWABLE_SIZE: usize = 20; - pub fn new(probe_on: Vec, join_type: &JoinType) -> Self { + pub fn new(probe_on: Vec, join_type: &JoinType, output_schema: &SchemaRef) -> Self { Self { probe_on, is_semi: *join_type == JoinType::Semi, + output_schema: output_schema.clone(), } } @@ -125,6 +128,10 @@ impl IntermediateOperator for AntiSemiProbeOperator { .downcast_mut::() .expect("AntiSemiProbeOperator state should be AntiSemiProbeState"); let input = input.as_data(); + if input.is_empty() { + let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); + return Ok(IntermediateOperatorResult::NeedMoreInput(Some(empty))); + } let out = self.probe_anti_semi(input, state)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) } diff --git a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs index 5f857a0e77..9672fdc0c7 100644 --- a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs @@ -48,6 +48,7 @@ pub struct InnerHashJoinProbeOperator { left_non_join_columns: Vec, right_non_join_columns: Vec, build_on_left: bool, + output_schema: SchemaRef, } impl InnerHashJoinProbeOperator { @@ -59,6 +60,7 @@ impl InnerHashJoinProbeOperator { right_schema: &SchemaRef, build_on_left: bool, common_join_keys: IndexSet, + output_schema: &SchemaRef, ) -> Self { let left_non_join_columns = left_schema .fields @@ -79,6 +81,7 @@ impl InnerHashJoinProbeOperator { left_non_join_columns, right_non_join_columns, build_on_left, + output_schema: output_schema.clone(), } } @@ -176,6 +179,10 @@ impl IntermediateOperator for InnerHashJoinProbeOperator { } _ => { let input = input.as_data(); + if input.is_empty() { + let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); + return Ok(IntermediateOperatorResult::NeedMoreInput(Some(empty))); + } let out = self.probe_inner(input, state)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) } diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index e9779749e8..848c9f683d 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -230,7 +230,7 @@ pub fn physical_plan_to_pipeline( left_on, right_on, join_type, - .. + schema, }) => { let left_schema = left.schema(); let right_schema = right.schema(); @@ -299,7 +299,11 @@ pub fn physical_plan_to_pipeline( match join_type { JoinType::Anti | JoinType::Semi => Ok(IntermediateNode::new( - Arc::new(AntiSemiProbeOperator::new(casted_probe_on, join_type)), + Arc::new(AntiSemiProbeOperator::new( + casted_probe_on, + join_type, + schema, + )), vec![build_node, probe_child_node], ) .boxed()), @@ -310,6 +314,7 @@ pub fn physical_plan_to_pipeline( right_schema, build_on_left, common_join_keys, + schema, )), vec![build_node, probe_child_node], ) @@ -322,6 +327,7 @@ pub fn physical_plan_to_pipeline( right_schema, *join_type, common_join_keys, + schema, )), vec![build_node, probe_child_node], ) diff --git a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs index c416e38aca..2a15bf062f 100644 --- a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs +++ b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs @@ -146,6 +146,7 @@ pub(crate) struct OuterHashJoinProbeSink { right_non_join_columns: Vec, right_non_join_schema: SchemaRef, join_type: JoinType, + output_schema: SchemaRef, } impl OuterHashJoinProbeSink { @@ -155,6 +156,7 @@ impl OuterHashJoinProbeSink { right_schema: &SchemaRef, join_type: JoinType, common_join_keys: IndexSet, + output_schema: &SchemaRef, ) -> Self { let left_non_join_columns = left_schema .fields @@ -179,6 +181,7 @@ impl OuterHashJoinProbeSink { right_non_join_columns, right_non_join_schema, join_type, + output_schema: output_schema.clone(), } } @@ -410,6 +413,10 @@ impl StreamingSink for OuterHashJoinProbeSink { .downcast_mut::() .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState"); let input = input.as_data(); + if input.is_empty() { + let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); + return Ok(StreamingSinkOutput::NeedMoreInput(Some(empty))); + } let out = match self.join_type { JoinType::Left | JoinType::Right => self.probe_left_right(input, state), JoinType::Outer => self.probe_outer(input, state), From d60933691640f30bc1ec9534a20fc51997ade143 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Wed, 25 Sep 2024 22:15:37 -0700 Subject: [PATCH 12/13] use mask filter --- src/daft-core/src/prelude.rs | 2 +- .../src/sinks/outer_hash_join_probe.rs | 93 ++++++------------- 2 files changed, 30 insertions(+), 65 deletions(-) diff --git a/src/daft-core/src/prelude.rs b/src/daft-core/src/prelude.rs index 0de797d1b0..6f6ecaf5a5 100644 --- a/src/daft-core/src/prelude.rs +++ b/src/daft-core/src/prelude.rs @@ -2,9 +2,9 @@ //! //! This module re-exports commonly used items from the Daft core library. -// Re-export core series structures // Re-export arrow2 bitmap pub use arrow2::bitmap; +// Re-export core series structures pub use daft_schema::schema::{Schema, SchemaRef}; // Re-export count mode enum diff --git a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs index 2a15bf062f..f1bf813823 100644 --- a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs +++ b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs @@ -3,10 +3,10 @@ use std::sync::Arc; use common_error::DaftResult; use daft_core::{ prelude::{ - bitmap::{or, Bitmap, MutableBitmap}, - Schema, SchemaRef, + bitmap::{and, Bitmap, MutableBitmap}, + BooleanArray, Schema, SchemaRef, }, - series::Series, + series::{IntoSeries, Series}, }; use daft_dsl::ExprRef; use daft_micropartition::MicroPartition; @@ -19,78 +19,51 @@ use super::streaming_sink::{StreamingSink, StreamingSinkOutput, StreamingSinkSta use crate::pipeline::PipelineResultType; struct IndexBitmapBuilder { - bitmap: MutableBitmap, - prefix_sums: Vec, + mutable_bitmaps: Vec, } impl IndexBitmapBuilder { fn new(tables: &[Table]) -> Self { - let prefix_sums = tables - .iter() - .map(|t| t.len()) - .scan(0, |acc, x| { - let prev = *acc; - *acc += x; - Some(prev) - }) - .collect::>(); - let total_len = prefix_sums.last().unwrap() + tables.last().unwrap().len(); Self { - bitmap: MutableBitmap::from_len_zeroed(total_len), - prefix_sums, + mutable_bitmaps: tables + .iter() + .map(|t| MutableBitmap::from_len_set(t.len())) + .collect(), } } #[inline] fn mark_used(&mut self, table_idx: usize, row_idx: usize) { - let idx = self.prefix_sums[table_idx] + row_idx; - self.bitmap.set(idx, true); + self.mutable_bitmaps[table_idx].set(row_idx, false); } fn build(self) -> IndexBitmap { IndexBitmap { - bitmap: self.bitmap.into(), - prefix_sums: self.prefix_sums, + bitmaps: self.mutable_bitmaps.into_iter().map(|b| b.into()).collect(), } } } struct IndexBitmap { - bitmap: Bitmap, - prefix_sums: Vec, + bitmaps: Vec, } impl IndexBitmap { - fn or(&self, other: &Self) -> Self { - assert_eq!(self.prefix_sums, other.prefix_sums); + fn merge(&self, other: &Self) -> Self { Self { - bitmap: or(&self.bitmap, &other.bitmap), - prefix_sums: self.prefix_sums.clone(), + bitmaps: self + .bitmaps + .iter() + .zip(other.bitmaps.iter()) + .map(|(a, b)| and(a, b)) + .collect(), } } - fn unset_bits(&self) -> usize { - self.bitmap.unset_bits() - } - - fn get_unused_indices(&self) -> impl Iterator + '_ { - let mut curr_table = 0; - self.bitmap - .iter() - .enumerate() - .filter_map(move |(idx, is_set)| { - if is_set { - None - } else { - while curr_table < self.prefix_sums.len() - 1 - && idx >= self.prefix_sums[curr_table + 1] - { - curr_table += 1; - } - let row_idx = idx - self.prefix_sums[curr_table]; - Some((curr_table, row_idx)) - } - }) + fn convert_to_boolean_arrays(self) -> impl Iterator { + self.bitmaps + .into_iter() + .map(|b| BooleanArray::from(("bitmap", b))) } } @@ -196,7 +169,6 @@ impl OuterHashJoinProbeSink { }; let _growables = info_span!("OuterHashJoinProbeSink::build_growables").entered(); - let mut build_side_growable = GrowableTable::new( &tables.iter().collect::>(), true, @@ -204,7 +176,6 @@ impl OuterHashJoinProbeSink { )?; let input_tables = input.get_tables()?; - let mut probe_side_growable = GrowableTable::new(&input_tables.iter().collect::>(), false, input.len())?; @@ -265,7 +236,6 @@ impl OuterHashJoinProbeSink { }; let bitmap_builder = state.get_bitmap_builder(); let _growables = info_span!("OuterHashJoinProbeSink::build_growables").entered(); - // Need to set use_validity to true here because we add nulls to the build side let mut build_side_growable = GrowableTable::new( &tables.iter().collect::>(), @@ -274,7 +244,6 @@ impl OuterHashJoinProbeSink { )?; let input_tables = input.get_tables()?; - let mut probe_side_growable = GrowableTable::new(&input_tables.iter().collect::>(), false, input.len())?; @@ -351,22 +320,18 @@ impl OuterHashJoinProbeSink { }); bitmaps.fold(None, |acc, x| match acc { None => Some(x), - Some(acc) => Some(acc.or(&x)), + Some(acc) => Some(acc.merge(&x)), }) } .expect("at least one bitmap should be present"); - let mut build_side_growable = GrowableTable::new( - &tables.iter().collect::>(), - true, - merged_bitmap.unset_bits(), - )?; - - for (table_idx, row_idx) in merged_bitmap.get_unused_indices() { - build_side_growable.extend(table_idx, row_idx, 1); - } + let leftovers = merged_bitmap + .convert_to_boolean_arrays() + .zip(tables.iter()) + .map(|(bitmap, table)| table.mask_filter(&bitmap.into_series())) + .collect::>>()?; - let build_side_table = build_side_growable.build()?; + let build_side_table = Table::concat(&leftovers)?; let join_table = build_side_table.get_columns(&self.common_join_keys)?; let left = build_side_table.get_columns(&self.left_non_join_columns)?; From d4e8942a517b912461cbb01a0e636e87ccb7140d Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Mon, 21 Oct 2024 18:04:40 -0700 Subject: [PATCH 13/13] feedback --- .../anti_semi_hash_join_probe.rs | 2 +- .../intermediate_ops/inner_hash_join_probe.rs | 4 ++-- src/daft-local-execution/src/sinks/limit.rs | 23 +++++++------------ .../src/sinks/outer_hash_join_probe.rs | 12 ++++++---- src/daft-table/src/probeable/mod.rs | 8 +++---- 5 files changed, 23 insertions(+), 26 deletions(-) diff --git a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs index 5ad6dc3ff9..bdcebecab6 100644 --- a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs @@ -118,7 +118,7 @@ impl IntermediateOperator for AntiSemiProbeOperator { if idx == 0 { let probe_state = input.as_probe_state(); - state.set_table(&probe_state.get_probeable()); + state.set_table(probe_state.get_probeable()); Ok(IntermediateOperatorResult::NeedMoreInput(None)) } else { let input = input.as_data(); diff --git a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs index db23681d09..a208efea6c 100644 --- a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs @@ -27,9 +27,9 @@ impl InnerHashJoinProbeState { } } - fn get_probe_state(&self) -> Arc { + fn get_probe_state(&self) -> &Arc { if let Self::ReadyToProbe(probe_state) = self { - probe_state.clone() + probe_state } else { panic!("get_probeable_and_table can only be used during the ReadyToProbe Phase") } diff --git a/src/daft-local-execution/src/sinks/limit.rs b/src/daft-local-execution/src/sinks/limit.rs index 20f5e418b9..633c3511c1 100644 --- a/src/daft-local-execution/src/sinks/limit.rs +++ b/src/daft-local-execution/src/sinks/limit.rs @@ -16,12 +16,8 @@ impl LimitSinkState { Self { remaining } } - fn get_remaining(&self) -> usize { - self.remaining - } - - fn set_remaining(&mut self, remaining: usize) { - self.remaining = remaining; + fn get_remaining_mut(&mut self) -> &mut usize { + &mut self.remaining } } @@ -56,23 +52,20 @@ impl StreamingSink for LimitSink { .expect("Limit Sink should have LimitSinkState"); let input = input.as_data(); let input_num_rows = input.len(); - let mut remaining = state.get_remaining(); + let remaining = state.get_remaining_mut(); use std::cmp::Ordering::{Equal, Greater, Less}; - match input_num_rows.cmp(&remaining) { + match input_num_rows.cmp(remaining) { Less => { - remaining -= input_num_rows; - state.set_remaining(remaining); + *remaining -= input_num_rows; Ok(StreamingSinkOutput::NeedMoreInput(Some(input.clone()))) } Equal => { - remaining = 0; - state.set_remaining(remaining); + *remaining = 0; Ok(StreamingSinkOutput::Finished(Some(input.clone()))) } Greater => { - let taken = input.head(remaining)?; - remaining -= taken.len(); - state.set_remaining(remaining); + let taken = input.head(*remaining)?; + *remaining = 0; Ok(StreamingSinkOutput::Finished(Some(Arc::new(taken)))) } } diff --git a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs index ec8a7485c2..ab5ffa8cb0 100644 --- a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs +++ b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs @@ -77,9 +77,9 @@ impl OuterHashJoinProbeState { let tables = probe_state.get_tables(); if matches!(self, Self::Building) { *self = Self::ReadyToProbe( - probe_state, + probe_state.clone(), if needs_bitmap { - Some(IndexBitmapBuilder::new(&tables)) + Some(IndexBitmapBuilder::new(tables)) } else { None }, @@ -232,7 +232,10 @@ impl OuterHashJoinProbeSink { ) -> DaftResult> { let (probe_table, tables) = { let probe_state = state.get_probe_state(); - (probe_state.get_probeable(), probe_state.get_tables()) + ( + probe_state.get_probeable().clone(), + probe_state.get_tables().clone(), + ) }; let bitmap_builder = state.get_bitmap_builder(); let _growables = info_span!("OuterHashJoinProbeSink::build_growables").entered(); @@ -305,7 +308,8 @@ impl OuterHashJoinProbeSink { .first() .expect("at least one state should be present") .get_probe_state() - .get_tables(); + .get_tables() + .clone(); let merged_bitmap = { let bitmaps = states.into_iter().map(|s| { diff --git a/src/daft-table/src/probeable/mod.rs b/src/daft-table/src/probeable/mod.rs index 5dfae13b30..3346bd9869 100644 --- a/src/daft-table/src/probeable/mod.rs +++ b/src/daft-table/src/probeable/mod.rs @@ -89,11 +89,11 @@ impl ProbeState { Self { probeable, tables } } - pub fn get_probeable(&self) -> Arc { - self.probeable.clone() + pub fn get_probeable(&self) -> &Arc { + &self.probeable } - pub fn get_tables(&self) -> Arc> { - self.tables.clone() + pub fn get_tables(&self) -> &Arc> { + &self.tables } }