diff --git a/src/arrow2/src/array/growable/primitive.rs b/src/arrow2/src/array/growable/primitive.rs index e443756cb9..4083cb49db 100644 --- a/src/arrow2/src/array/growable/primitive.rs +++ b/src/arrow2/src/array/growable/primitive.rs @@ -1,10 +1,7 @@ use std::sync::Arc; use crate::{ - array::{Array, PrimitiveArray}, - bitmap::MutableBitmap, - datatypes::DataType, - types::NativeType, + array::{Array, PrimitiveArray}, bitmap::MutableBitmap, datatypes::DataType, types::NativeType }; use super::{ diff --git a/src/daft-core/src/prelude.rs b/src/daft-core/src/prelude.rs index 3b71045ddd..6f6ecaf5a5 100644 --- a/src/daft-core/src/prelude.rs +++ b/src/daft-core/src/prelude.rs @@ -2,6 +2,8 @@ //! //! This module re-exports commonly used items from the Daft core library. +// Re-export arrow2 bitmap +pub use arrow2::bitmap; // Re-export core series structures pub use daft_schema::schema::{Schema, SchemaRef}; diff --git a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs index 525e308ebe..bdcebecab6 100644 --- a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use common_error::DaftResult; +use daft_core::prelude::SchemaRef; use daft_dsl::ExprRef; use daft_micropartition::MicroPartition; use daft_plan::JoinType; @@ -43,14 +44,18 @@ impl IntermediateOperatorState for AntiSemiProbeState { pub struct AntiSemiProbeOperator { probe_on: Vec, - join_type: JoinType, + is_semi: bool, + output_schema: SchemaRef, } impl AntiSemiProbeOperator { - pub fn new(probe_on: Vec, join_type: JoinType) -> Self { + const DEFAULT_GROWABLE_SIZE: usize = 20; + + pub fn new(probe_on: Vec, join_type: &JoinType, output_schema: &SchemaRef) -> Self { Self { probe_on, - join_type, + is_semi: *join_type == JoinType::Semi, + output_schema: output_schema.clone(), } } @@ -65,8 +70,11 @@ impl AntiSemiProbeOperator { let input_tables = input.get_tables()?; - let mut probe_side_growable = - GrowableTable::new(&input_tables.iter().collect::>(), false, 20)?; + let mut probe_side_growable = GrowableTable::new( + &input_tables.iter().collect::>(), + false, + Self::DEFAULT_GROWABLE_SIZE, + )?; drop(_growables); { @@ -76,7 +84,7 @@ impl AntiSemiProbeOperator { let iter = probe_set.probe_exists(&join_keys)?; for (probe_row_idx, matched) in iter.enumerate() { - match (self.join_type == JoinType::Semi, matched) { + match (self.is_semi, matched) { (true, true) | (false, false) => { probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); } @@ -109,15 +117,16 @@ impl IntermediateOperator for AntiSemiProbeOperator { .expect("AntiSemiProbeOperator state should be AntiSemiProbeState"); if idx == 0 { - let (probe_table, _) = input.as_probe_table(); - state.set_table(probe_table); + let probe_state = input.as_probe_state(); + state.set_table(probe_state.get_probeable()); Ok(IntermediateOperatorResult::NeedMoreInput(None)) } else { let input = input.as_data(); - let out = match self.join_type { - JoinType::Semi | JoinType::Anti => self.probe_anti_semi(input, state), - _ => unreachable!("Only Semi and Anti joins are supported"), - }?; + if input.is_empty() { + let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); + return Ok(IntermediateOperatorResult::NeedMoreInput(Some(empty))); + } + let out = self.probe_anti_semi(input, state)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) } } diff --git a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs deleted file mode 100644 index dd53b9eac4..0000000000 --- a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs +++ /dev/null @@ -1,268 +0,0 @@ -use std::sync::Arc; - -use common_error::DaftResult; -use daft_core::prelude::SchemaRef; -use daft_dsl::ExprRef; -use daft_micropartition::MicroPartition; -use daft_plan::JoinType; -use daft_table::{GrowableTable, Probeable, Table}; -use indexmap::IndexSet; -use tracing::{info_span, instrument}; - -use super::intermediate_op::{ - IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, -}; -use crate::pipeline::PipelineResultType; - -enum HashJoinProbeState { - Building, - ReadyToProbe(Arc, Arc>), -} - -impl HashJoinProbeState { - fn set_table(&mut self, table: &Arc, tables: &Arc>) { - if matches!(self, Self::Building) { - *self = Self::ReadyToProbe(table.clone(), tables.clone()); - } else { - panic!("HashJoinProbeState should only be in Building state when setting table") - } - } - - fn get_probeable_and_table(&self) -> (&Arc, &Arc>) { - if let Self::ReadyToProbe(probe_table, tables) = self { - (probe_table, tables) - } else { - panic!("get_probeable_and_table can only be used during the ReadyToProbe Phase") - } - } -} - -impl IntermediateOperatorState for HashJoinProbeState { - fn as_any_mut(&mut self) -> &mut dyn std::any::Any { - self - } -} - -pub struct HashJoinProbeOperator { - probe_on: Vec, - common_join_keys: Vec, - left_non_join_columns: Vec, - right_non_join_columns: Vec, - join_type: JoinType, - build_on_left: bool, -} - -impl HashJoinProbeOperator { - pub fn new( - probe_on: Vec, - left_schema: &SchemaRef, - right_schema: &SchemaRef, - join_type: JoinType, - build_on_left: bool, - common_join_keys: IndexSet, - ) -> Self { - let (common_join_keys, left_non_join_columns, right_non_join_columns) = match join_type { - JoinType::Inner | JoinType::Left | JoinType::Right => { - let left_non_join_columns = left_schema - .fields - .keys() - .filter(|c| !common_join_keys.contains(*c)) - .cloned() - .collect(); - let right_non_join_columns = right_schema - .fields - .keys() - .filter(|c| !common_join_keys.contains(*c)) - .cloned() - .collect(); - ( - common_join_keys.into_iter().collect(), - left_non_join_columns, - right_non_join_columns, - ) - } - _ => { - panic!("Semi, Anti, and join are not supported in HashJoinProbeOperator") - } - }; - Self { - probe_on, - common_join_keys, - left_non_join_columns, - right_non_join_columns, - join_type, - build_on_left, - } - } - - fn probe_inner( - &self, - input: &Arc, - state: &HashJoinProbeState, - ) -> DaftResult> { - let (probe_table, tables) = state.get_probeable_and_table(); - - let _growables = info_span!("HashJoinOperator::build_growables").entered(); - - let mut build_side_growable = - GrowableTable::new(&tables.iter().collect::>(), false, 20)?; - - let input_tables = input.get_tables()?; - - let mut probe_side_growable = - GrowableTable::new(&input_tables.iter().collect::>(), false, 20)?; - - drop(_growables); - { - let _loop = info_span!("HashJoinOperator::eval_and_probe").entered(); - for (probe_side_table_idx, table) in input_tables.iter().enumerate() { - // we should emit one table at a time when this is streaming - let join_keys = table.eval_expression_list(&self.probe_on)?; - let idx_mapper = probe_table.probe_indices(&join_keys)?; - - for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { - if let Some(inner_iter) = inner_iter { - for (build_side_table_idx, build_row_idx) in inner_iter { - build_side_growable.extend( - build_side_table_idx as usize, - build_row_idx as usize, - 1, - ); - // we can perform run length compression for this to make this more efficient - probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); - } - } - } - } - } - let build_side_table = build_side_growable.build()?; - let probe_side_table = probe_side_growable.build()?; - - let (left_table, right_table) = if self.build_on_left { - (build_side_table, probe_side_table) - } else { - (probe_side_table, build_side_table) - }; - - let join_keys_table = left_table.get_columns(&self.common_join_keys)?; - let left_non_join_columns = left_table.get_columns(&self.left_non_join_columns)?; - let right_non_join_columns = right_table.get_columns(&self.right_non_join_columns)?; - let final_table = join_keys_table - .union(&left_non_join_columns)? - .union(&right_non_join_columns)?; - - Ok(Arc::new(MicroPartition::new_loaded( - final_table.schema.clone(), - Arc::new(vec![final_table]), - None, - ))) - } - - fn probe_left_right( - &self, - input: &Arc, - state: &HashJoinProbeState, - ) -> DaftResult> { - let (probe_table, tables) = state.get_probeable_and_table(); - - let _growables = info_span!("HashJoinOperator::build_growables").entered(); - - let mut build_side_growable = GrowableTable::new( - &tables.iter().collect::>(), - true, - tables.iter().map(daft_table::Table::len).sum(), - )?; - - let input_tables = input.get_tables()?; - - let mut probe_side_growable = - GrowableTable::new(&input_tables.iter().collect::>(), false, input.len())?; - - drop(_growables); - { - let _loop = info_span!("HashJoinOperator::eval_and_probe").entered(); - for (probe_side_table_idx, table) in input_tables.iter().enumerate() { - let join_keys = table.eval_expression_list(&self.probe_on)?; - let idx_mapper = probe_table.probe_indices(&join_keys)?; - - for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { - if let Some(inner_iter) = inner_iter { - for (build_side_table_idx, build_row_idx) in inner_iter { - build_side_growable.extend( - build_side_table_idx as usize, - build_row_idx as usize, - 1, - ); - probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); - } - } else { - // if there's no match, we should still emit the probe side and fill the build side with nulls - build_side_growable.add_nulls(1); - probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); - } - } - } - } - let build_side_table = build_side_growable.build()?; - let probe_side_table = probe_side_growable.build()?; - - let final_table = if self.join_type == JoinType::Left { - let join_table = probe_side_table.get_columns(&self.common_join_keys)?; - let left = probe_side_table.get_columns(&self.left_non_join_columns)?; - let right = build_side_table.get_columns(&self.right_non_join_columns)?; - join_table.union(&left)?.union(&right)? - } else { - let join_table = probe_side_table.get_columns(&self.common_join_keys)?; - let left = build_side_table.get_columns(&self.left_non_join_columns)?; - let right = probe_side_table.get_columns(&self.right_non_join_columns)?; - join_table.union(&left)?.union(&right)? - }; - Ok(Arc::new(MicroPartition::new_loaded( - final_table.schema.clone(), - Arc::new(vec![final_table]), - None, - ))) - } -} - -impl IntermediateOperator for HashJoinProbeOperator { - #[instrument(skip_all, name = "HashJoinOperator::execute")] - fn execute( - &self, - idx: usize, - input: &PipelineResultType, - state: Option<&mut Box>, - ) -> DaftResult { - let state = state - .expect("HashJoinProbeOperator should have state") - .as_any_mut() - .downcast_mut::() - .expect("HashJoinProbeOperator state should be HashJoinProbeState"); - - if idx == 0 { - let (probe_table, tables) = input.as_probe_table(); - state.set_table(probe_table, tables); - Ok(IntermediateOperatorResult::NeedMoreInput(None)) - } else { - let input = input.as_data(); - let out = match self.join_type { - JoinType::Inner => self.probe_inner(input, state), - JoinType::Left | JoinType::Right => self.probe_left_right(input, state), - _ => { - unimplemented!( - "Only Inner, Left, and Right joins are supported in HashJoinProbeOperator" - ) - } - }?; - Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) - } - } - - fn name(&self) -> &'static str { - "HashJoinProbeOperator" - } - - fn make_state(&self) -> Option> { - Some(Box::new(HashJoinProbeState::Building)) - } -} diff --git a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs new file mode 100644 index 0000000000..a208efea6c --- /dev/null +++ b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs @@ -0,0 +1,199 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_core::prelude::SchemaRef; +use daft_dsl::ExprRef; +use daft_micropartition::MicroPartition; +use daft_table::{GrowableTable, ProbeState}; +use indexmap::IndexSet; +use tracing::{info_span, instrument}; + +use super::intermediate_op::{ + IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, +}; +use crate::pipeline::PipelineResultType; + +enum InnerHashJoinProbeState { + Building, + ReadyToProbe(Arc), +} + +impl InnerHashJoinProbeState { + fn set_probe_state(&mut self, probe_state: Arc) { + if matches!(self, Self::Building) { + *self = Self::ReadyToProbe(probe_state); + } else { + panic!("InnerHashJoinProbeState should only be in Building state when setting table") + } + } + + fn get_probe_state(&self) -> &Arc { + if let Self::ReadyToProbe(probe_state) = self { + probe_state + } else { + panic!("get_probeable_and_table can only be used during the ReadyToProbe Phase") + } + } +} + +impl IntermediateOperatorState for InnerHashJoinProbeState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} + +pub struct InnerHashJoinProbeOperator { + probe_on: Vec, + common_join_keys: Vec, + left_non_join_columns: Vec, + right_non_join_columns: Vec, + build_on_left: bool, + output_schema: SchemaRef, +} + +impl InnerHashJoinProbeOperator { + const DEFAULT_GROWABLE_SIZE: usize = 20; + + pub fn new( + probe_on: Vec, + left_schema: &SchemaRef, + right_schema: &SchemaRef, + build_on_left: bool, + common_join_keys: IndexSet, + output_schema: &SchemaRef, + ) -> Self { + let left_non_join_columns = left_schema + .fields + .keys() + .filter(|c| !common_join_keys.contains(*c)) + .cloned() + .collect(); + let right_non_join_columns = right_schema + .fields + .keys() + .filter(|c| !common_join_keys.contains(*c)) + .cloned() + .collect(); + let common_join_keys = common_join_keys.into_iter().collect(); + Self { + probe_on, + common_join_keys, + left_non_join_columns, + right_non_join_columns, + build_on_left, + output_schema: output_schema.clone(), + } + } + + fn probe_inner( + &self, + input: &Arc, + state: &InnerHashJoinProbeState, + ) -> DaftResult> { + let (probe_table, tables) = { + let probe_state = state.get_probe_state(); + (probe_state.get_probeable(), probe_state.get_tables()) + }; + + let _growables = info_span!("InnerHashJoinOperator::build_growables").entered(); + + let mut build_side_growable = GrowableTable::new( + &tables.iter().collect::>(), + false, + Self::DEFAULT_GROWABLE_SIZE, + )?; + + let input_tables = input.get_tables()?; + + let mut probe_side_growable = GrowableTable::new( + &input_tables.iter().collect::>(), + false, + Self::DEFAULT_GROWABLE_SIZE, + )?; + + drop(_growables); + { + let _loop = info_span!("InnerHashJoinOperator::eval_and_probe").entered(); + for (probe_side_table_idx, table) in input_tables.iter().enumerate() { + // we should emit one table at a time when this is streaming + let join_keys = table.eval_expression_list(&self.probe_on)?; + let idx_mapper = probe_table.probe_indices(&join_keys)?; + + for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { + if let Some(inner_iter) = inner_iter { + for (build_side_table_idx, build_row_idx) in inner_iter { + build_side_growable.extend( + build_side_table_idx as usize, + build_row_idx as usize, + 1, + ); + // we can perform run length compression for this to make this more efficient + probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); + } + } + } + } + } + let build_side_table = build_side_growable.build()?; + let probe_side_table = probe_side_growable.build()?; + + let (left_table, right_table) = if self.build_on_left { + (build_side_table, probe_side_table) + } else { + (probe_side_table, build_side_table) + }; + + let join_keys_table = left_table.get_columns(&self.common_join_keys)?; + let left_non_join_columns = left_table.get_columns(&self.left_non_join_columns)?; + let right_non_join_columns = right_table.get_columns(&self.right_non_join_columns)?; + let final_table = join_keys_table + .union(&left_non_join_columns)? + .union(&right_non_join_columns)?; + + Ok(Arc::new(MicroPartition::new_loaded( + final_table.schema.clone(), + Arc::new(vec![final_table]), + None, + ))) + } +} + +impl IntermediateOperator for InnerHashJoinProbeOperator { + #[instrument(skip_all, name = "InnerHashJoinOperator::execute")] + fn execute( + &self, + idx: usize, + input: &PipelineResultType, + state: Option<&mut Box>, + ) -> DaftResult { + let state = state + .expect("InnerHashJoinProbeOperator should have state") + .as_any_mut() + .downcast_mut::() + .expect("InnerHashJoinProbeOperator state should be InnerHashJoinProbeState"); + match idx { + 0 => { + let probe_state = input.as_probe_state(); + state.set_probe_state(probe_state.clone()); + Ok(IntermediateOperatorResult::NeedMoreInput(None)) + } + _ => { + let input = input.as_data(); + if input.is_empty() { + let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); + return Ok(IntermediateOperatorResult::NeedMoreInput(Some(empty))); + } + let out = self.probe_inner(input, state)?; + Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) + } + } + } + + fn name(&self) -> &'static str { + "InnerHashJoinProbeOperator" + } + + fn make_state(&self) -> Option> { + Some(Box::new(InnerHashJoinProbeState::Building)) + } +} diff --git a/src/daft-local-execution/src/intermediate_ops/mod.rs b/src/daft-local-execution/src/intermediate_ops/mod.rs index 28936fc924..15e880c56c 100644 --- a/src/daft-local-execution/src/intermediate_ops/mod.rs +++ b/src/daft-local-execution/src/intermediate_ops/mod.rs @@ -2,7 +2,7 @@ pub mod aggregate; pub mod anti_semi_hash_join_probe; pub mod buffer; pub mod filter; -pub mod hash_join_probe; +pub mod inner_hash_join_probe; pub mod intermediate_op; pub mod pivot; pub mod project; diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index c231183f9a..4e397ff45d 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -14,7 +14,7 @@ use daft_physical_plan::{ Project, Sample, Sort, UnGroupedAggregate, }; use daft_plan::{populate_aggregation_stages, JoinType}; -use daft_table::{Probeable, Table}; +use daft_table::ProbeState; use indexmap::IndexSet; use snafu::ResultExt; @@ -22,13 +22,14 @@ use crate::{ channel::PipelineChannel, intermediate_ops::{ aggregate::AggregateOperator, anti_semi_hash_join_probe::AntiSemiProbeOperator, - filter::FilterOperator, hash_join_probe::HashJoinProbeOperator, + filter::FilterOperator, inner_hash_join_probe::InnerHashJoinProbeOperator, intermediate_op::IntermediateNode, pivot::PivotOperator, project::ProjectOperator, sample::SampleOperator, }, sinks::{ aggregate::AggregateSink, blocking_sink::BlockingSinkNode, - hash_join_build::HashJoinBuildSink, limit::LimitSink, sort::SortSink, + hash_join_build::HashJoinBuildSink, limit::LimitSink, + outer_hash_join_probe::OuterHashJoinProbeSink, sort::SortSink, streaming_sink::StreamingSinkNode, }, sources::{empty_scan::EmptyScanSource, in_memory::InMemorySource}, @@ -38,7 +39,7 @@ use crate::{ #[derive(Clone)] pub enum PipelineResultType { Data(Arc), - ProbeTable(Arc, Arc>), + ProbeState(Arc), } impl From> for PipelineResultType { @@ -47,9 +48,9 @@ impl From> for PipelineResultType { } } -impl From<(Arc, Arc>)> for PipelineResultType { - fn from((probe_table, tables): (Arc, Arc>)) -> Self { - Self::ProbeTable(probe_table, tables) +impl From> for PipelineResultType { + fn from(probe_state: Arc) -> Self { + Self::ProbeState(probe_state) } } @@ -61,15 +62,15 @@ impl PipelineResultType { } } - pub fn as_probe_table(&self) -> (&Arc, &Arc>) { + pub fn as_probe_state(&self) -> &Arc { match self { - Self::ProbeTable(probe_table, tables) => (probe_table, tables), + Self::ProbeState(probe_state) => probe_state, _ => panic!("Expected probe table"), } } pub fn should_broadcast(&self) -> bool { - matches!(self, Self::ProbeTable(_, _)) + matches!(self, Self::ProbeState(_)) } } @@ -149,7 +150,7 @@ pub fn physical_plan_to_pipeline( }) => { let sink = LimitSink::new(*num_rows as usize); let child_node = physical_plan_to_pipeline(input, psets)?; - StreamingSinkNode::new(sink.boxed(), vec![child_node]).boxed() + StreamingSinkNode::new(Arc::new(sink), vec![child_node]).boxed() } LocalPhysicalPlan::Concat(_) => { todo!("concat") @@ -269,7 +270,7 @@ pub fn physical_plan_to_pipeline( left_on, right_on, join_type, - .. + schema, }) => { let left_schema = left.schema(); let right_schema = right.schema(); @@ -280,11 +281,9 @@ pub fn physical_plan_to_pipeline( let build_on_left = match join_type { JoinType::Inner => true, JoinType::Right => true, + JoinType::Outer => true, JoinType::Left => false, JoinType::Anti | JoinType::Semi => false, - JoinType::Outer => { - unimplemented!("Outer join not supported yet"); - } }; let (build_on, probe_on, build_child, probe_child) = match build_on_left { true => (left_on, right_on, left, right), @@ -293,7 +292,7 @@ pub fn physical_plan_to_pipeline( let build_schema = build_child.schema(); let probe_schema = probe_child.schema(); - let probe_node = || -> DaftResult<_> { + || -> DaftResult<_> { let common_join_keys: IndexSet<_> = get_common_join_keys(left_on, right_on) .map(std::string::ToString::to_string) .collect(); @@ -338,32 +337,46 @@ pub fn physical_plan_to_pipeline( let probe_child_node = physical_plan_to_pipeline(probe_child, psets)?; match join_type { - JoinType::Anti | JoinType::Semi => DaftResult::Ok(IntermediateNode::new( - Arc::new(AntiSemiProbeOperator::new(casted_probe_on, *join_type)), + JoinType::Anti | JoinType::Semi => Ok(IntermediateNode::new( + Arc::new(AntiSemiProbeOperator::new( + casted_probe_on, + join_type, + schema, + )), vec![build_node, probe_child_node], - )), - JoinType::Inner | JoinType::Left | JoinType::Right => { - DaftResult::Ok(IntermediateNode::new( - Arc::new(HashJoinProbeOperator::new( + ) + .boxed()), + JoinType::Inner => Ok(IntermediateNode::new( + Arc::new(InnerHashJoinProbeOperator::new( + casted_probe_on, + left_schema, + right_schema, + build_on_left, + common_join_keys, + schema, + )), + vec![build_node, probe_child_node], + ) + .boxed()), + JoinType::Left | JoinType::Right | JoinType::Outer => { + Ok(StreamingSinkNode::new( + Arc::new(OuterHashJoinProbeSink::new( casted_probe_on, left_schema, right_schema, *join_type, - build_on_left, common_join_keys, + schema, )), vec![build_node, probe_child_node], - )) - } - JoinType::Outer => { - unimplemented!("Outer join not supported yet"); + ) + .boxed()) } } }() .with_context(|_| PipelineCreationSnafu { plan_name: physical_plan.name(), - })?; - probe_node.boxed() + })? } _ => { unimplemented!("Physical plan not supported: {}", physical_plan.name()); diff --git a/src/daft-local-execution/src/runtime_stats.rs b/src/daft-local-execution/src/runtime_stats.rs index de1f657273..566d253e9c 100644 --- a/src/daft-local-execution/src/runtime_stats.rs +++ b/src/daft-local-execution/src/runtime_stats.rs @@ -124,8 +124,8 @@ impl CountingSender { ) -> Result<(), SendError> { let len = match v { PipelineResultType::Data(ref mp) => mp.len(), - PipelineResultType::ProbeTable(_, ref tables) => { - tables.iter().map(daft_table::Table::len).sum() + PipelineResultType::ProbeState(ref state) => { + state.get_tables().iter().map(|t| t.len()).sum() } }; self.sender.send(v).await?; @@ -149,8 +149,8 @@ impl CountingReceiver { if let Some(ref v) = v { let len = match v { PipelineResultType::Data(ref mp) => mp.len(), - PipelineResultType::ProbeTable(_, ref tables) => { - tables.iter().map(daft_table::Table::len).sum() + PipelineResultType::ProbeState(state) => { + state.get_tables().iter().map(|t| t.len()).sum() } }; self.rt.mark_rows_received(len as u64); diff --git a/src/daft-local-execution/src/sinks/hash_join_build.rs b/src/daft-local-execution/src/sinks/hash_join_build.rs index 3af65702cd..c8258e281a 100644 --- a/src/daft-local-execution/src/sinks/hash_join_build.rs +++ b/src/daft-local-execution/src/sinks/hash_join_build.rs @@ -5,7 +5,7 @@ use daft_core::prelude::SchemaRef; use daft_dsl::ExprRef; use daft_micropartition::MicroPartition; use daft_plan::JoinType; -use daft_table::{make_probeable_builder, Probeable, ProbeableBuilder, Table}; +use daft_table::{make_probeable_builder, ProbeState, ProbeableBuilder, Table}; use super::blocking_sink::{BlockingSink, BlockingSinkStatus}; use crate::pipeline::PipelineResultType; @@ -17,8 +17,7 @@ enum ProbeTableState { tables: Vec, }, Done { - probe_table: Arc, - tables: Arc>, + probe_state: Arc, }, } @@ -66,8 +65,7 @@ impl ProbeTableState { let pt = ptb.build(); *self = Self::Done { - probe_table: pt, - tables: Arc::new(tables.clone()), + probe_state: Arc::new(ProbeState::new(pt, Arc::new(tables.clone()))), }; Ok(()) } else { @@ -108,12 +106,8 @@ impl BlockingSink for HashJoinBuildSink { fn finalize(&mut self) -> DaftResult> { self.probe_table_state.finalize()?; - if let ProbeTableState::Done { - probe_table, - tables, - } = &self.probe_table_state - { - Ok(Some((probe_table.clone(), tables.clone()).into())) + if let ProbeTableState::Done { probe_state } = &self.probe_table_state { + Ok(Some(probe_state.clone().into())) } else { panic!("finalize should only be called after the probe table is built") } diff --git a/src/daft-local-execution/src/sinks/limit.rs b/src/daft-local-execution/src/sinks/limit.rs index 40b4d1538f..633c3511c1 100644 --- a/src/daft-local-execution/src/sinks/limit.rs +++ b/src/daft-local-execution/src/sinks/limit.rs @@ -4,51 +4,69 @@ use common_error::DaftResult; use daft_micropartition::MicroPartition; use tracing::instrument; -use super::streaming_sink::{StreamSinkOutput, StreamingSink}; +use super::streaming_sink::{StreamingSink, StreamingSinkOutput, StreamingSinkState}; +use crate::pipeline::PipelineResultType; + +struct LimitSinkState { + remaining: usize, +} + +impl LimitSinkState { + fn new(remaining: usize) -> Self { + Self { remaining } + } + + fn get_remaining_mut(&mut self) -> &mut usize { + &mut self.remaining + } +} + +impl StreamingSinkState for LimitSinkState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} pub struct LimitSink { - #[allow(dead_code)] limit: usize, - remaining: usize, } impl LimitSink { pub fn new(limit: usize) -> Self { - Self { - limit, - remaining: limit, - } - } - pub fn boxed(self) -> Box { - Box::new(self) + Self { limit } } } impl StreamingSink for LimitSink { #[instrument(skip_all, name = "LimitSink::sink")] fn execute( - &mut self, + &self, index: usize, - input: &Arc, - ) -> DaftResult { + input: &PipelineResultType, + state: &mut dyn StreamingSinkState, + ) -> DaftResult { assert_eq!(index, 0); - + let state = state + .as_any_mut() + .downcast_mut::() + .expect("Limit Sink should have LimitSinkState"); + let input = input.as_data(); let input_num_rows = input.len(); - + let remaining = state.get_remaining_mut(); use std::cmp::Ordering::{Equal, Greater, Less}; - match input_num_rows.cmp(&self.remaining) { + match input_num_rows.cmp(remaining) { Less => { - self.remaining -= input_num_rows; - Ok(StreamSinkOutput::NeedMoreInput(Some(input.clone()))) + *remaining -= input_num_rows; + Ok(StreamingSinkOutput::NeedMoreInput(Some(input.clone()))) } Equal => { - self.remaining = 0; - Ok(StreamSinkOutput::Finished(Some(input.clone()))) + *remaining = 0; + Ok(StreamingSinkOutput::Finished(Some(input.clone()))) } Greater => { - let taken = input.head(self.remaining)?; - self.remaining -= taken.len(); - Ok(StreamSinkOutput::Finished(Some(Arc::new(taken)))) + let taken = input.head(*remaining)?; + *remaining = 0; + Ok(StreamingSinkOutput::Finished(Some(Arc::new(taken)))) } } } @@ -56,4 +74,19 @@ impl StreamingSink for LimitSink { fn name(&self) -> &'static str { "Limit" } + + fn finalize( + &self, + _states: Vec>, + ) -> DaftResult>> { + Ok(None) + } + + fn make_state(&self) -> Box { + Box::new(LimitSinkState::new(self.limit)) + } + + fn max_concurrency(&self) -> usize { + 1 + } } diff --git a/src/daft-local-execution/src/sinks/mod.rs b/src/daft-local-execution/src/sinks/mod.rs index 39910e7995..7960e55a7c 100644 --- a/src/daft-local-execution/src/sinks/mod.rs +++ b/src/daft-local-execution/src/sinks/mod.rs @@ -3,5 +3,6 @@ pub mod blocking_sink; pub mod concat; pub mod hash_join_build; pub mod limit; +pub mod outer_hash_join_probe; pub mod sort; pub mod streaming_sink; diff --git a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs new file mode 100644 index 0000000000..ab5ffa8cb0 --- /dev/null +++ b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs @@ -0,0 +1,419 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_core::{ + prelude::{ + bitmap::{and, Bitmap, MutableBitmap}, + BooleanArray, Schema, SchemaRef, + }, + series::{IntoSeries, Series}, +}; +use daft_dsl::ExprRef; +use daft_micropartition::MicroPartition; +use daft_plan::JoinType; +use daft_table::{GrowableTable, ProbeState, Table}; +use indexmap::IndexSet; +use tracing::{info_span, instrument}; + +use super::streaming_sink::{StreamingSink, StreamingSinkOutput, StreamingSinkState}; +use crate::pipeline::PipelineResultType; + +struct IndexBitmapBuilder { + mutable_bitmaps: Vec, +} + +impl IndexBitmapBuilder { + fn new(tables: &[Table]) -> Self { + Self { + mutable_bitmaps: tables + .iter() + .map(|t| MutableBitmap::from_len_set(t.len())) + .collect(), + } + } + + #[inline] + fn mark_used(&mut self, table_idx: usize, row_idx: usize) { + self.mutable_bitmaps[table_idx].set(row_idx, false); + } + + fn build(self) -> IndexBitmap { + IndexBitmap { + bitmaps: self.mutable_bitmaps.into_iter().map(|b| b.into()).collect(), + } + } +} + +struct IndexBitmap { + bitmaps: Vec, +} + +impl IndexBitmap { + fn merge(&self, other: &Self) -> Self { + Self { + bitmaps: self + .bitmaps + .iter() + .zip(other.bitmaps.iter()) + .map(|(a, b)| and(a, b)) + .collect(), + } + } + + fn convert_to_boolean_arrays(self) -> impl Iterator { + self.bitmaps + .into_iter() + .map(|b| BooleanArray::from(("bitmap", b))) + } +} + +enum OuterHashJoinProbeState { + Building, + ReadyToProbe(Arc, Option), +} + +impl OuterHashJoinProbeState { + fn initialize_probe_state(&mut self, probe_state: Arc, needs_bitmap: bool) { + let tables = probe_state.get_tables(); + if matches!(self, Self::Building) { + *self = Self::ReadyToProbe( + probe_state.clone(), + if needs_bitmap { + Some(IndexBitmapBuilder::new(tables)) + } else { + None + }, + ); + } else { + panic!("OuterHashJoinProbeState should only be in Building state when setting table") + } + } + + fn get_probe_state(&self) -> &ProbeState { + if let Self::ReadyToProbe(probe_state, _) = self { + probe_state + } else { + panic!("get_probeable_and_table can only be used during the ReadyToProbe Phase") + } + } + + fn get_bitmap_builder(&mut self) -> &mut Option { + if let Self::ReadyToProbe(_, bitmap_builder) = self { + bitmap_builder + } else { + panic!("get_bitmap can only be used during the ReadyToProbe Phase") + } + } +} + +impl StreamingSinkState for OuterHashJoinProbeState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} + +pub(crate) struct OuterHashJoinProbeSink { + probe_on: Vec, + common_join_keys: Vec, + left_non_join_columns: Vec, + right_non_join_columns: Vec, + right_non_join_schema: SchemaRef, + join_type: JoinType, + output_schema: SchemaRef, +} + +impl OuterHashJoinProbeSink { + pub(crate) fn new( + probe_on: Vec, + left_schema: &SchemaRef, + right_schema: &SchemaRef, + join_type: JoinType, + common_join_keys: IndexSet, + output_schema: &SchemaRef, + ) -> Self { + let left_non_join_columns = left_schema + .fields + .keys() + .filter(|c| !common_join_keys.contains(*c)) + .cloned() + .collect(); + let right_non_join_fields = right_schema + .fields + .values() + .filter(|f| !common_join_keys.contains(&f.name)) + .cloned() + .collect(); + let right_non_join_schema = + Arc::new(Schema::new(right_non_join_fields).expect("right schema should be valid")); + let right_non_join_columns = right_non_join_schema.fields.keys().cloned().collect(); + let common_join_keys = common_join_keys.into_iter().collect(); + Self { + probe_on, + common_join_keys, + left_non_join_columns, + right_non_join_columns, + right_non_join_schema, + join_type, + output_schema: output_schema.clone(), + } + } + + fn probe_left_right( + &self, + input: &Arc, + state: &OuterHashJoinProbeState, + ) -> DaftResult> { + let (probe_table, tables) = { + let probe_state = state.get_probe_state(); + (probe_state.get_probeable(), probe_state.get_tables()) + }; + + let _growables = info_span!("OuterHashJoinProbeSink::build_growables").entered(); + let mut build_side_growable = GrowableTable::new( + &tables.iter().collect::>(), + true, + tables.iter().map(|t| t.len()).sum(), + )?; + + let input_tables = input.get_tables()?; + let mut probe_side_growable = + GrowableTable::new(&input_tables.iter().collect::>(), false, input.len())?; + + drop(_growables); + { + let _loop = info_span!("OuterHashJoinProbeSink::eval_and_probe").entered(); + for (probe_side_table_idx, table) in input_tables.iter().enumerate() { + let join_keys = table.eval_expression_list(&self.probe_on)?; + let idx_mapper = probe_table.probe_indices(&join_keys)?; + + for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { + if let Some(inner_iter) = inner_iter { + for (build_side_table_idx, build_row_idx) in inner_iter { + build_side_growable.extend( + build_side_table_idx as usize, + build_row_idx as usize, + 1, + ); + probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); + } + } else { + // if there's no match, we should still emit the probe side and fill the build side with nulls + build_side_growable.add_nulls(1); + probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); + } + } + } + } + let build_side_table = build_side_growable.build()?; + let probe_side_table = probe_side_growable.build()?; + + let final_table = if self.join_type == JoinType::Left { + let join_table = probe_side_table.get_columns(&self.common_join_keys)?; + let left = probe_side_table.get_columns(&self.left_non_join_columns)?; + let right = build_side_table.get_columns(&self.right_non_join_columns)?; + join_table.union(&left)?.union(&right)? + } else { + let join_table = probe_side_table.get_columns(&self.common_join_keys)?; + let left = build_side_table.get_columns(&self.left_non_join_columns)?; + let right = probe_side_table.get_columns(&self.right_non_join_columns)?; + join_table.union(&left)?.union(&right)? + }; + Ok(Arc::new(MicroPartition::new_loaded( + final_table.schema.clone(), + Arc::new(vec![final_table]), + None, + ))) + } + + fn probe_outer( + &self, + input: &Arc, + state: &mut OuterHashJoinProbeState, + ) -> DaftResult> { + let (probe_table, tables) = { + let probe_state = state.get_probe_state(); + ( + probe_state.get_probeable().clone(), + probe_state.get_tables().clone(), + ) + }; + let bitmap_builder = state.get_bitmap_builder(); + let _growables = info_span!("OuterHashJoinProbeSink::build_growables").entered(); + // Need to set use_validity to true here because we add nulls to the build side + let mut build_side_growable = GrowableTable::new( + &tables.iter().collect::>(), + true, + tables.iter().map(|t| t.len()).sum(), + )?; + + let input_tables = input.get_tables()?; + let mut probe_side_growable = + GrowableTable::new(&input_tables.iter().collect::>(), false, input.len())?; + + let left_idx_used = bitmap_builder + .as_mut() + .expect("bitmap should be set in outer join"); + + drop(_growables); + { + let _loop = info_span!("OuterHashJoinProbeSink::eval_and_probe").entered(); + for (probe_side_table_idx, table) in input_tables.iter().enumerate() { + let join_keys = table.eval_expression_list(&self.probe_on)?; + let idx_mapper = probe_table.probe_indices(&join_keys)?; + + for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { + if let Some(inner_iter) = inner_iter { + for (build_side_table_idx, build_row_idx) in inner_iter { + let build_side_table_idx = build_side_table_idx as usize; + let build_row_idx = build_row_idx as usize; + left_idx_used.mark_used(build_side_table_idx, build_row_idx); + build_side_growable.extend(build_side_table_idx, build_row_idx, 1); + probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); + } + } else { + // if there's no match, we should still emit the probe side and fill the build side with nulls + build_side_growable.add_nulls(1); + probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); + } + } + } + } + let build_side_table = build_side_growable.build()?; + let probe_side_table = probe_side_growable.build()?; + + let join_table = probe_side_table.get_columns(&self.common_join_keys)?; + let left = build_side_table.get_columns(&self.left_non_join_columns)?; + let right = probe_side_table.get_columns(&self.right_non_join_columns)?; + let final_table = join_table.union(&left)?.union(&right)?; + Ok(Arc::new(MicroPartition::new_loaded( + final_table.schema.clone(), + Arc::new(vec![final_table]), + None, + ))) + } + + fn finalize_outer( + &self, + mut states: Vec>, + ) -> DaftResult>> { + let states = states + .iter_mut() + .map(|s| { + s.as_any_mut() + .downcast_mut::() + .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState") + }) + .collect::>(); + let tables = states + .first() + .expect("at least one state should be present") + .get_probe_state() + .get_tables() + .clone(); + + let merged_bitmap = { + let bitmaps = states.into_iter().map(|s| { + if let OuterHashJoinProbeState::ReadyToProbe(_, bitmap) = s { + bitmap + .take() + .expect("bitmap should be present in outer join") + .build() + } else { + panic!("OuterHashJoinProbeState should be in ReadyToProbe state") + } + }); + bitmaps.fold(None, |acc, x| match acc { + None => Some(x), + Some(acc) => Some(acc.merge(&x)), + }) + } + .expect("at least one bitmap should be present"); + + let leftovers = merged_bitmap + .convert_to_boolean_arrays() + .zip(tables.iter()) + .map(|(bitmap, table)| table.mask_filter(&bitmap.into_series())) + .collect::>>()?; + + let build_side_table = Table::concat(&leftovers)?; + + let join_table = build_side_table.get_columns(&self.common_join_keys)?; + let left = build_side_table.get_columns(&self.left_non_join_columns)?; + let right = { + let columns = self + .right_non_join_schema + .fields + .values() + .map(|field| Series::full_null(&field.name, &field.dtype, left.len())) + .collect::>(); + Table::new_unchecked(self.right_non_join_schema.clone(), columns, left.len()) + }; + let final_table = join_table.union(&left)?.union(&right)?; + Ok(Some(Arc::new(MicroPartition::new_loaded( + final_table.schema.clone(), + Arc::new(vec![final_table]), + None, + )))) + } +} + +impl StreamingSink for OuterHashJoinProbeSink { + #[instrument(skip_all, name = "OuterHashJoinProbeSink::execute")] + fn execute( + &self, + idx: usize, + input: &PipelineResultType, + state: &mut dyn StreamingSinkState, + ) -> DaftResult { + match idx { + 0 => { + let state = state + .as_any_mut() + .downcast_mut::() + .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState"); + let probe_state = input.as_probe_state(); + state + .initialize_probe_state(probe_state.clone(), self.join_type == JoinType::Outer); + Ok(StreamingSinkOutput::NeedMoreInput(None)) + } + _ => { + let state = state + .as_any_mut() + .downcast_mut::() + .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState"); + let input = input.as_data(); + if input.is_empty() { + let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); + return Ok(StreamingSinkOutput::NeedMoreInput(Some(empty))); + } + let out = match self.join_type { + JoinType::Left | JoinType::Right => self.probe_left_right(input, state), + JoinType::Outer => self.probe_outer(input, state), + _ => unreachable!( + "Only Left, Right, and Outer joins are supported in OuterHashJoinProbeSink" + ), + }?; + Ok(StreamingSinkOutput::NeedMoreInput(Some(out))) + } + } + } + + fn name(&self) -> &'static str { + "OuterHashJoinProbeSink" + } + + fn make_state(&self) -> Box { + Box::new(OuterHashJoinProbeState::Building) + } + + fn finalize( + &self, + states: Vec>, + ) -> DaftResult>> { + if self.join_type == JoinType::Outer { + self.finalize_outer(states) + } else { + Ok(None) + } + } +} diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index f18a7efca0..0a7000af8f 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -3,14 +3,22 @@ use std::sync::Arc; use common_display::tree::TreeDisplay; use common_error::DaftResult; use daft_micropartition::MicroPartition; -use tracing::info_span; +use snafu::ResultExt; +use tracing::{info_span, instrument}; use crate::{ - channel::PipelineChannel, pipeline::PipelineNode, runtime_stats::RuntimeStatsContext, - ExecutionRuntimeHandle, NUM_CPUS, + channel::{create_channel, PipelineChannel, Receiver, Sender}, + create_task_set, + pipeline::{PipelineNode, PipelineResultType}, + runtime_stats::{CountingReceiver, RuntimeStatsContext}, + ExecutionRuntimeHandle, JoinSnafu, TaskSet, NUM_CPUS, }; -pub enum StreamSinkOutput { +pub trait StreamingSinkState: Send + Sync { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any; +} + +pub enum StreamingSinkOutput { NeedMoreInput(Option>), #[allow(dead_code)] HasMoreOutput(Arc), @@ -19,35 +27,121 @@ pub enum StreamSinkOutput { pub trait StreamingSink: Send + Sync { fn execute( - &mut self, + &self, index: usize, - input: &Arc, - ) -> DaftResult; - #[allow(dead_code)] + input: &PipelineResultType, + state: &mut dyn StreamingSinkState, + ) -> DaftResult; + fn finalize( + &self, + states: Vec>, + ) -> DaftResult>>; fn name(&self) -> &'static str; + fn make_state(&self) -> Box; + fn max_concurrency(&self) -> usize { + *NUM_CPUS + } } pub struct StreamingSinkNode { - // use a RW lock - op: Arc>>, + op: Arc, name: &'static str, children: Vec>, runtime_stats: Arc, } impl StreamingSinkNode { - pub(crate) fn new(op: Box, children: Vec>) -> Self { + pub(crate) fn new(op: Arc, children: Vec>) -> Self { let name = op.name(); Self { - op: Arc::new(tokio::sync::Mutex::new(op)), + op, name, children, runtime_stats: RuntimeStatsContext::new(), } } + pub(crate) fn boxed(self) -> Box { Box::new(self) } + + #[instrument(level = "info", skip_all, name = "StreamingSink::run_worker")] + async fn run_worker( + op: Arc, + mut input_receiver: Receiver<(usize, PipelineResultType)>, + output_sender: Sender>, + rt_context: Arc, + ) -> DaftResult> { + let span = info_span!("StreamingSink::Execute"); + let mut state = op.make_state(); + while let Some((idx, morsel)) = input_receiver.recv().await { + loop { + let result = + rt_context.in_span(&span, || op.execute(idx, &morsel, state.as_mut()))?; + match result { + StreamingSinkOutput::NeedMoreInput(mp) => { + if let Some(mp) = mp { + let _ = output_sender.send(mp).await; + } + break; + } + StreamingSinkOutput::HasMoreOutput(mp) => { + let _ = output_sender.send(mp).await; + } + StreamingSinkOutput::Finished(mp) => { + if let Some(mp) = mp { + let _ = output_sender.send(mp).await; + } + return Ok(state); + } + } + } + } + Ok(state) + } + + fn spawn_workers( + op: Arc, + input_receivers: Vec>, + task_set: &mut TaskSet>>, + stats: Arc, + ) -> Receiver> { + let (output_sender, output_receiver) = create_channel(input_receivers.len()); + for input_receiver in input_receivers { + task_set.spawn(Self::run_worker( + op.clone(), + input_receiver, + output_sender.clone(), + stats.clone(), + )); + } + output_receiver + } + + async fn forward_input_to_workers( + receivers: Vec, + worker_senders: Vec>, + ) -> DaftResult<()> { + let mut next_worker_idx = 0; + let mut send_to_next_worker = |idx, data: PipelineResultType| { + let next_worker_sender = worker_senders.get(next_worker_idx).unwrap(); + next_worker_idx = (next_worker_idx + 1) % worker_senders.len(); + next_worker_sender.send((idx, data)) + }; + + for (idx, mut receiver) in receivers.into_iter().enumerate() { + while let Some(morsel) = receiver.recv().await { + if morsel.should_broadcast() { + for worker_sender in &worker_senders { + let _ = worker_sender.send((idx, morsel.clone())).await; + } + } else { + let _ = send_to_next_worker(idx, morsel.clone()).await; + } + } + } + Ok(()) + } } impl TreeDisplay for StreamingSinkNode { @@ -88,50 +182,49 @@ impl PipelineNode for StreamingSinkNode { maintain_order: bool, runtime_handle: &mut ExecutionRuntimeHandle, ) -> crate::Result { - let child = self - .children - .get_mut(0) - .expect("we should only have 1 child"); - let child_results_channel = child.start(true, runtime_handle)?; - let mut child_results_receiver = - child_results_channel.get_receiver_with_stats(&self.runtime_stats); - - let mut destination_channel = PipelineChannel::new(*NUM_CPUS, maintain_order); - let sender = destination_channel.get_next_sender_with_stats(&self.runtime_stats); + let mut child_result_receivers = Vec::with_capacity(self.children.len()); + for child in &mut self.children { + let child_result_channel = child.start(maintain_order, runtime_handle)?; + child_result_receivers + .push(child_result_channel.get_receiver_with_stats(&self.runtime_stats.clone())); + } + + let mut destination_channel = PipelineChannel::new(1, maintain_order); + let destination_sender = + destination_channel.get_next_sender_with_stats(&self.runtime_stats); + let op = self.op.clone(); let runtime_stats = self.runtime_stats.clone(); + let num_workers = op.max_concurrency(); + let (input_senders, input_receivers) = (0..num_workers).map(|_| create_channel(1)).unzip(); + runtime_handle.spawn( + Self::forward_input_to_workers(child_result_receivers, input_senders), + self.name(), + ); runtime_handle.spawn( async move { - // this should be a RWLock and run in concurrent workers - let span = info_span!("StreamingSink::execute"); - - let mut sink = op.lock().await; - let mut is_active = true; - while is_active && let Some(val) = child_results_receiver.recv().await { - let val = val.as_data(); - loop { - let result = runtime_stats.in_span(&span, || sink.execute(0, val))?; - match result { - StreamSinkOutput::HasMoreOutput(mp) => { - sender.send(mp.into()).await.unwrap(); - } - StreamSinkOutput::NeedMoreInput(mp) => { - if let Some(mp) = mp { - sender.send(mp.into()).await.unwrap(); - } - break; - } - StreamSinkOutput::Finished(mp) => { - if let Some(mp) = mp { - sender.send(mp.into()).await.unwrap(); - } - is_active = false; - break; - } - } - } + let mut task_set = create_task_set(); + let mut output_receiver = Self::spawn_workers( + op.clone(), + input_receivers, + &mut task_set, + runtime_stats.clone(), + ); + + while let Some(morsel) = output_receiver.recv().await { + let _ = destination_sender.send(morsel.into()).await; + } + + let mut finished_states = Vec::with_capacity(num_workers); + while let Some(result) = task_set.join_next().await { + let state = result.context(JoinSnafu)??; + finished_states.push(state); + } + + if let Some(finalized_result) = op.finalize(finished_states)? { + let _ = destination_sender.send(finalized_result.into()).await; } - DaftResult::Ok(()) + Ok(()) }, self.name(), ); diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 1d84e3d7b1..e93f4d7a77 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -29,7 +29,7 @@ mod probeable; mod repr_html; pub use growable::GrowableTable; -pub use probeable::{make_probeable_builder, Probeable, ProbeableBuilder}; +pub use probeable::{make_probeable_builder, ProbeState, Probeable, ProbeableBuilder}; #[cfg(feature = "python")] pub mod python; diff --git a/src/daft-table/src/probeable/mod.rs b/src/daft-table/src/probeable/mod.rs index a3d935246e..3346bd9869 100644 --- a/src/daft-table/src/probeable/mod.rs +++ b/src/daft-table/src/probeable/mod.rs @@ -77,3 +77,23 @@ pub trait Probeable: Send + Sync { table: &'a Table, ) -> DaftResult + 'a>>; } + +#[derive(Clone)] +pub struct ProbeState { + probeable: Arc, + tables: Arc>, +} + +impl ProbeState { + pub fn new(probeable: Arc, tables: Arc>) -> Self { + Self { probeable, tables } + } + + pub fn get_probeable(&self) -> &Arc { + &self.probeable + } + + pub fn get_tables(&self) -> &Arc> { + &self.tables + } +} diff --git a/tests/cookbook/test_joins.py b/tests/cookbook/test_joins.py index b51c863100..d80dce72a2 100644 --- a/tests/cookbook/test_joins.py +++ b/tests/cookbook/test_joins.py @@ -6,16 +6,20 @@ from daft.expressions import col from tests.conftest import assert_df_equals -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) + +def skip_invalid_join_strategies(join_strategy): + if context.get_context().daft_execution_config.enable_native_executor is True: + if join_strategy not in [None, "hash"]: + pytest.skip("Native executor fails for these tests") @pytest.mark.parametrize( - "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True + "join_strategy", + [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], + indirect=True, ) def test_simple_join(join_strategy, daft_df, service_requests_csv_pd_df, repartition_nparts): + skip_invalid_join_strategies(join_strategy) daft_df = daft_df.repartition(repartition_nparts) daft_df_left = daft_df.select(col("Unique Key"), col("Borough")) daft_df_right = daft_df.select(col("Unique Key"), col("Created Date")) @@ -33,9 +37,12 @@ def test_simple_join(join_strategy, daft_df, service_requests_csv_pd_df, reparti @pytest.mark.parametrize( - "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True + "join_strategy", + [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], + indirect=True, ) def test_simple_self_join(join_strategy, daft_df, service_requests_csv_pd_df, repartition_nparts): + skip_invalid_join_strategies(join_strategy) daft_df = daft_df.repartition(repartition_nparts) daft_df = daft_df.select(col("Unique Key"), col("Borough")) @@ -44,7 +51,11 @@ def test_simple_self_join(join_strategy, daft_df, service_requests_csv_pd_df, re service_requests_csv_pd_df = service_requests_csv_pd_df[["Unique Key", "Borough"]] service_requests_csv_pd_df = ( service_requests_csv_pd_df.set_index("Unique Key") - .join(service_requests_csv_pd_df.set_index("Unique Key"), how="inner", rsuffix="_right") + .join( + service_requests_csv_pd_df.set_index("Unique Key"), + how="inner", + rsuffix="_right", + ) .reset_index() ) service_requests_csv_pd_df = service_requests_csv_pd_df.rename({"Borough_right": "right.Borough"}, axis=1) @@ -53,9 +64,12 @@ def test_simple_self_join(join_strategy, daft_df, service_requests_csv_pd_df, re @pytest.mark.parametrize( - "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True + "join_strategy", + [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], + indirect=True, ) def test_simple_join_missing_rvalues(join_strategy, daft_df, service_requests_csv_pd_df, repartition_nparts): + skip_invalid_join_strategies(join_strategy) daft_df_right = daft_df.sort("Unique Key").limit(25).repartition(repartition_nparts) daft_df_left = daft_df.repartition(repartition_nparts) daft_df_left = daft_df_left.select(col("Unique Key"), col("Borough")) @@ -76,9 +90,12 @@ def test_simple_join_missing_rvalues(join_strategy, daft_df, service_requests_cs @pytest.mark.parametrize( - "join_strategy", [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], indirect=True + "join_strategy", + [None, "hash", "sort_merge", "sort_merge_aligned_boundaries", "broadcast"], + indirect=True, ) def test_simple_join_missing_lvalues(join_strategy, daft_df, service_requests_csv_pd_df, repartition_nparts): + skip_invalid_join_strategies(join_strategy) daft_df_right = daft_df.repartition(repartition_nparts) daft_df_left = daft_df.sort(col("Unique Key")).limit(25).repartition(repartition_nparts) daft_df_left = daft_df_left.select(col("Unique Key"), col("Borough")) diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index 980d19aa3f..4b08abea61 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -12,7 +12,7 @@ def skip_invalid_join_strategies(join_strategy, join_type): if context.get_context().daft_execution_config.enable_native_executor is True: - if join_type == "outer" or join_strategy not in [None, "hash"]: + if join_strategy not in [None, "hash"]: pytest.skip("Native executor fails for these tests") else: if (join_strategy == "sort_merge" or join_strategy == "sort_merge_aligned_boundaries") and join_type != "inner":