From ad175ae4e2ed4d178c07b1da5628673623b12e37 Mon Sep 17 00:00:00 2001 From: Colin Ho Date: Fri, 6 Dec 2024 13:44:41 -0800 Subject: [PATCH] refactor(swordfish): Generic broadcast state bridge (#3508) Follow on from https://github.com/Eventual-Inc/Daft/pull/3437#discussion_r1872261677 Co-authored-by: Colin Ho --- .../anti_semi_hash_join_probe.rs | 12 ++--- .../src/intermediate_ops/cross_join.rs | 11 ++--- .../intermediate_ops/inner_hash_join_probe.rs | 10 ++--- src/daft-local-execution/src/lib.rs | 1 + src/daft-local-execution/src/pipeline.rs | 9 ++-- .../src/sinks/cross_join_collect.rs | 41 ++--------------- .../src/sinks/hash_join_build.rs | 45 +++---------------- .../src/sinks/outer_hash_join_probe.rs | 20 ++++----- src/daft-local-execution/src/state_bridge.rs | 35 +++++++++++++++ 9 files changed, 76 insertions(+), 108 deletions(-) create mode 100644 src/daft-local-execution/src/state_bridge.rs diff --git a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs index 765039651e..3b3ebf692c 100644 --- a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs @@ -6,17 +6,17 @@ use daft_core::prelude::SchemaRef; use daft_dsl::ExprRef; use daft_logical_plan::JoinType; use daft_micropartition::MicroPartition; -use daft_table::{GrowableTable, Probeable}; +use daft_table::{GrowableTable, ProbeState, Probeable}; use tracing::{info_span, instrument}; use super::intermediate_op::{ IntermediateOpExecuteResult, IntermediateOpState, IntermediateOperator, IntermediateOperatorResult, }; -use crate::sinks::hash_join_build::ProbeStateBridgeRef; +use crate::state_bridge::BroadcastStateBridgeRef; enum AntiSemiProbeState { - Building(ProbeStateBridgeRef), + Building(BroadcastStateBridgeRef), Probing(Arc), } @@ -24,7 +24,7 @@ impl AntiSemiProbeState { async fn get_or_await_probeable(&mut self) -> Arc { match self { Self::Building(bridge) => { - let probe_state = bridge.get_probe_state().await; + let probe_state = bridge.get_state().await; let probeable = probe_state.get_probeable(); *self = Self::Probing(probeable.clone()); probeable.clone() @@ -48,7 +48,7 @@ struct AntiSemiJoinParams { pub(crate) struct AntiSemiProbeOperator { params: Arc, output_schema: SchemaRef, - probe_state_bridge: ProbeStateBridgeRef, + probe_state_bridge: BroadcastStateBridgeRef, } impl AntiSemiProbeOperator { @@ -58,7 +58,7 @@ impl AntiSemiProbeOperator { probe_on: Vec, join_type: &JoinType, output_schema: &SchemaRef, - probe_state_bridge: ProbeStateBridgeRef, + probe_state_bridge: BroadcastStateBridgeRef, ) -> Self { Self { params: Arc::new(AntiSemiJoinParams { diff --git a/src/daft-local-execution/src/intermediate_ops/cross_join.rs b/src/daft-local-execution/src/intermediate_ops/cross_join.rs index 65dc52f561..773d69d3a4 100644 --- a/src/daft-local-execution/src/intermediate_ops/cross_join.rs +++ b/src/daft-local-execution/src/intermediate_ops/cross_join.rs @@ -4,22 +4,23 @@ use common_error::DaftResult; use common_runtime::RuntimeRef; use daft_core::{join::JoinSide, prelude::SchemaRef}; use daft_micropartition::MicroPartition; +use daft_table::Table; use tracing::{info_span, instrument, Instrument}; use super::intermediate_op::{ IntermediateOpExecuteResult, IntermediateOpState, IntermediateOperator, IntermediateOperatorResult, }; -use crate::sinks::cross_join_collect::CrossJoinStateBridgeRef; +use crate::state_bridge::BroadcastStateBridgeRef; struct CrossJoinState { - bridge: CrossJoinStateBridgeRef, + bridge: BroadcastStateBridgeRef>, stream_idx: usize, collect_idx: usize, } impl CrossJoinState { - fn new(bridge: CrossJoinStateBridgeRef) -> Self { + fn new(bridge: BroadcastStateBridgeRef>) -> Self { Self { bridge, stream_idx: 0, @@ -37,14 +38,14 @@ impl IntermediateOpState for CrossJoinState { pub struct CrossJoinOperator { output_schema: SchemaRef, stream_side: JoinSide, - state_bridge: CrossJoinStateBridgeRef, + state_bridge: BroadcastStateBridgeRef>, } impl CrossJoinOperator { pub(crate) fn new( output_schema: SchemaRef, stream_side: JoinSide, - state_bridge: CrossJoinStateBridgeRef, + state_bridge: BroadcastStateBridgeRef>, ) -> Self { Self { output_schema, diff --git a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs index ada98dac23..ec119d58e7 100644 --- a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs @@ -13,10 +13,10 @@ use super::intermediate_op::{ IntermediateOpExecuteResult, IntermediateOpState, IntermediateOperator, IntermediateOperatorResult, }; -use crate::sinks::hash_join_build::ProbeStateBridgeRef; +use crate::state_bridge::BroadcastStateBridgeRef; enum InnerHashJoinProbeState { - Building(ProbeStateBridgeRef), + Building(BroadcastStateBridgeRef), Probing(Arc), } @@ -24,7 +24,7 @@ impl InnerHashJoinProbeState { async fn get_or_await_probe_state(&mut self) -> Arc { match self { Self::Building(bridge) => { - let probe_state = bridge.get_probe_state().await; + let probe_state = bridge.get_state().await; *self = Self::Probing(probe_state.clone()); probe_state } @@ -50,7 +50,7 @@ struct InnerHashJoinParams { pub struct InnerHashJoinProbeOperator { params: Arc, output_schema: SchemaRef, - probe_state_bridge: ProbeStateBridgeRef, + probe_state_bridge: BroadcastStateBridgeRef, } impl InnerHashJoinProbeOperator { @@ -63,7 +63,7 @@ impl InnerHashJoinProbeOperator { build_on_left: bool, common_join_keys: IndexSet, output_schema: &SchemaRef, - probe_state_bridge: ProbeStateBridgeRef, + probe_state_bridge: BroadcastStateBridgeRef, ) -> Self { let left_non_join_columns = left_schema .fields diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs index df22857519..ec6020eaef 100644 --- a/src/daft-local-execution/src/lib.rs +++ b/src/daft-local-execution/src/lib.rs @@ -8,6 +8,7 @@ mod run; mod runtime_stats; mod sinks; mod sources; +mod state_bridge; use std::{ future::Future, diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index fa14998def..010df060a7 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -37,8 +37,8 @@ use crate::{ aggregate::AggregateSink, blocking_sink::BlockingSinkNode, concat::ConcatSink, - cross_join_collect::{CrossJoinCollectSink, CrossJoinStateBridge}, - hash_join_build::{HashJoinBuildSink, ProbeStateBridge}, + cross_join_collect::CrossJoinCollectSink, + hash_join_build::HashJoinBuildSink, limit::LimitSink, monotonically_increasing_id::MonotonicallyIncreasingIdSink, outer_hash_join_probe::OuterHashJoinProbeSink, @@ -48,6 +48,7 @@ use crate::{ write::{WriteFormat, WriteSink}, }, sources::{empty_scan::EmptyScanSource, in_memory::InMemorySource}, + state_bridge::BroadcastStateBridge, ExecutionRuntimeContext, PipelineCreationSnafu, }; @@ -416,7 +417,7 @@ pub fn physical_plan_to_pipeline( .map(|(e, f)| e.clone().cast(&f.dtype)) .collect::>(); // we should move to a builder pattern - let probe_state_bridge = ProbeStateBridge::new(); + let probe_state_bridge = BroadcastStateBridge::new(); let build_sink = HashJoinBuildSink::new( key_schema, casted_build_on, @@ -516,7 +517,7 @@ pub fn physical_plan_to_pipeline( let stream_child_node = physical_plan_to_pipeline(stream_child, psets, cfg)?; let collect_child_node = physical_plan_to_pipeline(collect_child, psets, cfg)?; - let state_bridge = CrossJoinStateBridge::new(); + let state_bridge = BroadcastStateBridge::new(); let collect_node = BlockingSinkNode::new( Arc::new(CrossJoinCollectSink::new(state_bridge.clone())), collect_child_node, diff --git a/src/daft-local-execution/src/sinks/cross_join_collect.rs b/src/daft-local-execution/src/sinks/cross_join_collect.rs index 5d156a337d..98e659b553 100644 --- a/src/daft-local-execution/src/sinks/cross_join_collect.rs +++ b/src/daft-local-execution/src/sinks/cross_join_collect.rs @@ -1,4 +1,4 @@ -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use common_error::DaftResult; use common_runtime::RuntimeRef; @@ -10,40 +10,7 @@ use super::blocking_sink::{ BlockingSink, BlockingSinkFinalizeResult, BlockingSinkSinkResult, BlockingSinkState, BlockingSinkStatus, }; - -pub(crate) type CrossJoinStateBridgeRef = Arc; - -// TODO(Colin): rework into more generic broadcast bridge that can be used for both probe table and micropartition -pub(crate) struct CrossJoinStateBridge { - inner: OnceLock>>, - notify: tokio::sync::Notify, -} - -impl CrossJoinStateBridge { - pub(crate) fn new() -> Arc { - Arc::new(Self { - inner: OnceLock::new(), - notify: tokio::sync::Notify::new(), - }) - } - - pub(crate) fn set_state(&self, state: Arc>) { - assert!( - !self.inner.set(state).is_err(), - "CrossJoinStateBridge should be set only once" - ); - self.notify.notify_waiters(); - } - - pub(crate) async fn get_state(&self) -> Arc> { - loop { - if let Some(state) = self.inner.get() { - return state.clone(); - } - self.notify.notified().await; - } - } -} +use crate::state_bridge::BroadcastStateBridgeRef; struct CrossJoinCollectState(Option>); @@ -54,11 +21,11 @@ impl BlockingSinkState for CrossJoinCollectState { } pub struct CrossJoinCollectSink { - state_bridge: CrossJoinStateBridgeRef, + state_bridge: BroadcastStateBridgeRef>, } impl CrossJoinCollectSink { - pub(crate) fn new(state_bridge: CrossJoinStateBridgeRef) -> Self { + pub(crate) fn new(state_bridge: BroadcastStateBridgeRef>) -> Self { Self { state_bridge } } } diff --git a/src/daft-local-execution/src/sinks/hash_join_build.rs b/src/daft-local-execution/src/sinks/hash_join_build.rs index 79d2c9d3bd..257b594c3f 100644 --- a/src/daft-local-execution/src/sinks/hash_join_build.rs +++ b/src/daft-local-execution/src/sinks/hash_join_build.rs @@ -1,4 +1,4 @@ -use std::sync::{Arc, OnceLock}; +use std::sync::Arc; use common_error::DaftResult; use common_runtime::RuntimeRef; @@ -12,42 +12,7 @@ use super::blocking_sink::{ BlockingSink, BlockingSinkFinalizeResult, BlockingSinkSinkResult, BlockingSinkState, BlockingSinkStatus, }; - -/// ProbeStateBridge is a bridge between the build and probe phase of a hash join. -/// It is used to pass the probe state from the build phase to the probe phase. -/// The build phase sets the probe state once building is complete, and the probe phase -/// waits for the probe state to be set via the `get_probe_state` method. -pub(crate) type ProbeStateBridgeRef = Arc; -pub(crate) struct ProbeStateBridge { - inner: OnceLock>, - notify: tokio::sync::Notify, -} - -impl ProbeStateBridge { - pub(crate) fn new() -> Arc { - Arc::new(Self { - inner: OnceLock::new(), - notify: tokio::sync::Notify::new(), - }) - } - - pub(crate) fn set_probe_state(&self, state: Arc) { - assert!( - !self.inner.set(state).is_err(), - "ProbeStateBridge should be set only once" - ); - self.notify.notify_waiters(); - } - - pub(crate) async fn get_probe_state(&self) -> Arc { - loop { - if let Some(state) = self.inner.get() { - return state.clone(); - } - self.notify.notified().await; - } - } -} +use crate::state_bridge::BroadcastStateBridgeRef; enum ProbeTableState { Building { @@ -131,7 +96,7 @@ pub struct HashJoinBuildSink { projection: Vec, nulls_equal_aware: Option>, join_type: JoinType, - probe_state_bridge: ProbeStateBridgeRef, + probe_state_bridge: BroadcastStateBridgeRef, } impl HashJoinBuildSink { @@ -140,7 +105,7 @@ impl HashJoinBuildSink { projection: Vec, nulls_equal_aware: Option>, join_type: &JoinType, - probe_state_bridge: ProbeStateBridgeRef, + probe_state_bridge: BroadcastStateBridgeRef, ) -> DaftResult { Ok(Self { key_schema, @@ -188,7 +153,7 @@ impl BlockingSink for HashJoinBuildSink { .expect("State type mismatch"); let finalized_probe_state = probe_table_state.finalize(); self.probe_state_bridge - .set_probe_state(finalized_probe_state.into()); + .set_state(finalized_probe_state.into()); Ok(None).into() } diff --git a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs index 4af93075a0..4c3729c150 100644 --- a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs +++ b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs @@ -17,15 +17,13 @@ use futures::{stream, StreamExt}; use indexmap::IndexSet; use tracing::{info_span, instrument}; -use super::{ - hash_join_build::ProbeStateBridgeRef, - streaming_sink::{ - StreamingSink, StreamingSinkExecuteResult, StreamingSinkFinalizeResult, - StreamingSinkOutput, StreamingSinkState, - }, +use super::streaming_sink::{ + StreamingSink, StreamingSinkExecuteResult, StreamingSinkFinalizeResult, StreamingSinkOutput, + StreamingSinkState, }; use crate::{ dispatcher::{DispatchSpawner, RoundRobinDispatcher, UnorderedDispatcher}, + state_bridge::BroadcastStateBridgeRef, ExecutionRuntimeContext, }; @@ -79,7 +77,7 @@ impl IndexBitmap { } enum OuterHashJoinState { - Building(ProbeStateBridgeRef, bool), + Building(BroadcastStateBridgeRef, bool), Probing(Arc, Option), } @@ -87,7 +85,7 @@ impl OuterHashJoinState { async fn get_or_build_probe_state(&mut self) -> Arc { match self { Self::Building(bridge, needs_bitmap) => { - let probe_state = bridge.get_probe_state().await; + let probe_state = bridge.get_state().await; let builder = needs_bitmap.then(|| IndexBitmapBuilder::new(probe_state.get_tables())); *self = Self::Probing(probe_state.clone(), builder); @@ -100,7 +98,7 @@ impl OuterHashJoinState { async fn get_or_build_bitmap(&mut self) -> &mut Option { match self { Self::Building(bridge, _) => { - let probe_state = bridge.get_probe_state().await; + let probe_state = bridge.get_state().await; let builder = IndexBitmapBuilder::new(probe_state.get_tables()); *self = Self::Probing(probe_state, Some(builder)); match self { @@ -132,7 +130,7 @@ struct OuterHashJoinParams { pub(crate) struct OuterHashJoinProbeSink { params: Arc, output_schema: SchemaRef, - probe_state_bridge: ProbeStateBridgeRef, + probe_state_bridge: BroadcastStateBridgeRef, } #[allow(clippy::too_many_arguments)] @@ -145,7 +143,7 @@ impl OuterHashJoinProbeSink { build_on_left: bool, common_join_keys: IndexSet, output_schema: &SchemaRef, - probe_state_bridge: ProbeStateBridgeRef, + probe_state_bridge: BroadcastStateBridgeRef, ) -> Self { // For outer joins, we need to swap the left and right schemas if we are building on the right. let (left_schema, right_schema) = match (join_type, build_on_left) { diff --git a/src/daft-local-execution/src/state_bridge.rs b/src/daft-local-execution/src/state_bridge.rs new file mode 100644 index 0000000000..0821dda1a1 --- /dev/null +++ b/src/daft-local-execution/src/state_bridge.rs @@ -0,0 +1,35 @@ +use std::sync::{Arc, OnceLock}; + +/// BroadcastStateBridge is a bridge to send state from one node to another. +/// e.g. from the build phase to the probe phase of a join operation. +pub(crate) type BroadcastStateBridgeRef = Arc>; +pub(crate) struct BroadcastStateBridge { + inner: OnceLock>, + notify: tokio::sync::Notify, +} + +impl BroadcastStateBridge { + pub(crate) fn new() -> Arc { + Arc::new(Self { + inner: OnceLock::new(), + notify: tokio::sync::Notify::new(), + }) + } + + pub(crate) fn set_state(&self, state: Arc) { + assert!( + !self.inner.set(state).is_err(), + "BroadcastStateBridge should be set only once" + ); + self.notify.notify_waiters(); + } + + pub(crate) async fn get_state(&self) -> Arc { + loop { + if let Some(state) = self.inner.get() { + return state.clone(); + } + self.notify.notified().await; + } + } +}