Skip to content

Commit

Permalink
pipeline channel
Browse files Browse the repository at this point in the history
  • Loading branch information
Colin Ho authored and Colin Ho committed Aug 27, 2024
1 parent faa48b6 commit 5357ffa
Show file tree
Hide file tree
Showing 13 changed files with 151 additions and 155 deletions.
81 changes: 49 additions & 32 deletions src/daft-local-execution/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,50 +2,67 @@ use std::sync::Arc;

use crate::{
pipeline::PipelineResultType,
runtime_stats::{CountingSender, RuntimeStatsContext},
runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext},
};

/// Sending half of a Tokio oneshot channel carrying a single `T`.
pub type OneShotSender<T> = tokio::sync::oneshot::Sender<T>;
/// Receiving half of a Tokio oneshot channel carrying a single `T`.
pub type OneShotReceiver<T> = tokio::sync::oneshot::Receiver<T>;

/// Creates a Tokio oneshot channel and returns its `(sender, receiver)` halves.
///
/// Thin wrapper so callers depend on this module's aliases rather than on
/// `tokio::sync::oneshot` directly.
pub fn create_one_shot_channel<T>() -> (OneShotSender<T>, OneShotReceiver<T>) {
    let (tx, rx) = tokio::sync::oneshot::channel::<T>();
    (tx, rx)
}

/// Sending half of a bounded Tokio mpsc channel carrying `T` values.
pub type Sender<T> = tokio::sync::mpsc::Sender<T>;
/// Receiving half of a bounded Tokio mpsc channel carrying `T` values.
pub type Receiver<T> = tokio::sync::mpsc::Receiver<T>;

/// Creates a bounded Tokio mpsc channel with room for `buffer_size` messages
/// and returns its `(sender, receiver)` halves.
///
/// # Panics
///
/// Tokio's `mpsc::channel` panics when asked for a capacity of zero, so
/// `buffer_size` must be at least 1.
pub fn create_channel<T>(buffer_size: usize) -> (Sender<T>, Receiver<T>) {
    let (tx, rx) = tokio::sync::mpsc::channel::<T>(buffer_size);
    (tx, rx)
}

pub fn create_multi_channel(buffer_size: usize, in_order: bool) -> (MultiSender, MultiReceiver) {
if in_order {
let (senders, receivers) = (0..buffer_size).map(|_| create_channel(1)).unzip();
let sender = MultiSender::InOrder(RoundRobinSender::new(senders));
let receiver = MultiReceiver::InOrder(RoundRobinReceiver::new(receivers));
(sender, receiver)
} else {
let (sender, receiver) = create_channel(buffer_size);
let sender = MultiSender::OutOfOrder(sender);
let receiver = MultiReceiver::OutOfOrder(receiver);
(sender, receiver)
/// Bundles the sending and receiving sides of the channel that carries
/// `PipelineResultType` values out of a pipeline node.
///
/// Both sides are created together by `PipelineChannel::new`; senders are
/// handed out through the `get_next_sender*` methods and the receiver is
/// extracted by consuming the channel.
pub struct PipelineChannel {
/// Sending side; its variant matches the ordering mode chosen in `new`.
sender: PipelineSender,
/// Receiving side; always the same variant (in-order / out-of-order) as `sender`.
receiver: PipelineReceiver,
}

impl PipelineChannel {
    /// Builds a channel in one of two modes.
    ///
    /// * `in_order == true`: creates `buffer_size` independent channels, each
    ///   with capacity 1, and wraps their halves in a `RoundRobinSender` /
    ///   `RoundRobinReceiver` pair.
    /// * `in_order == false`: creates one shared channel with capacity
    ///   `buffer_size`.
    ///
    /// # Panics
    ///
    /// In out-of-order mode `buffer_size` is passed straight to Tokio's
    /// `mpsc::channel`, which panics on a capacity of zero.
    pub fn new(buffer_size: usize, in_order: bool) -> Self {
        if in_order {
            // One single-slot channel per slot; the round-robin wrappers own
            // the resulting sender and receiver collections.
            let (txs, rxs) = (0..buffer_size).map(|_| create_channel(1)).unzip();
            Self {
                sender: PipelineSender::InOrder(RoundRobinSender::new(txs)),
                receiver: PipelineReceiver::InOrder(RoundRobinReceiver::new(rxs)),
            }
        } else {
            let (tx, rx) = create_channel(buffer_size);
            Self {
                sender: PipelineSender::OutOfOrder(tx),
                receiver: PipelineReceiver::OutOfOrder(rx),
            }
        }
    }

    /// Returns the sender the next producer should use: the round-robin
    /// wrapper's next sender in in-order mode, or a clone of the single shared
    /// sender in out-of-order mode.
    fn get_next_sender(&mut self) -> Sender<PipelineResultType> {
        match self.sender {
            PipelineSender::InOrder(ref mut round_robin) => round_robin.get_next_sender(),
            PipelineSender::OutOfOrder(ref tx) => tx.clone(),
        }
    }

    /// Like `get_next_sender`, but wraps the sender in a `CountingSender`
    /// that shares the given runtime-stats context.
    pub(crate) fn get_next_sender_with_stats(
        &mut self,
        rt: &Arc<RuntimeStatsContext>,
    ) -> CountingSender {
        let next = self.get_next_sender();
        CountingSender::new(next, Arc::clone(rt))
    }

    /// Consumes the channel and returns its receiving side.
    ///
    /// The sender side still held by the channel is dropped here; only the
    /// senders handed out earlier remain.
    pub fn get_receiver(self) -> PipelineReceiver {
        let Self { receiver, .. } = self;
        receiver
    }

    /// Consumes the channel and returns its receiving side wrapped in a
    /// `CountingReceiver` that shares the given runtime-stats context.
    pub(crate) fn get_receiver_with_stats(self, rt: &Arc<RuntimeStatsContext>) -> CountingReceiver {
        let receiver = self.get_receiver();
        CountingReceiver::new(receiver, Arc::clone(rt))
    }
}

pub enum MultiSender {
/// Sending side of a pipeline channel, in one of two delivery modes.
pub enum PipelineSender {
/// A set of senders managed by a `RoundRobinSender`; used when the channel
/// is created with `in_order == true`.
InOrder(RoundRobinSender<PipelineResultType>),
/// A single shared sender; each producer is given a clone of it.
OutOfOrder(Sender<PipelineResultType>),
}

impl MultiSender {
pub fn get_next_sender(&mut self, stats: &Arc<RuntimeStatsContext>) -> CountingSender {
match self {
Self::InOrder(sender) => CountingSender::new(sender.get_next_sender(), stats.clone()),
Self::OutOfOrder(sender) => CountingSender::new(sender.clone(), stats.clone()),
}
}
}
pub struct RoundRobinSender<T> {
senders: Vec<Sender<T>>,
curr_sender_idx: usize,
Expand All @@ -66,16 +83,16 @@ impl<T> RoundRobinSender<T> {
}
}

pub enum MultiReceiver {
/// Receiving side of a pipeline channel, in one of two delivery modes.
pub enum PipelineReceiver {
/// A set of receivers managed by a `RoundRobinReceiver`; used when the
/// channel is created with `in_order == true`.
InOrder(RoundRobinReceiver<PipelineResultType>),
/// The receiver of a single shared channel.
OutOfOrder(Receiver<PipelineResultType>),
}

impl MultiReceiver {
impl PipelineReceiver {
/// Awaits the next result from whichever receiver this variant wraps and
/// returns it unchanged.
///
/// For the `OutOfOrder` variant this is Tokio's mpsc `recv`, which yields
/// `None` once the channel is closed and drained. The `InOrder` variant
/// delegates to `RoundRobinReceiver::recv`, whose end-of-stream behavior is
/// defined elsewhere.
pub async fn recv(&mut self) -> Option<PipelineResultType> {
match self {
// NOTE(review): the `Self::…` and `PipelineReceiver::…` arm pairs below
// match the same variants, so the second pair is unreachable as written.
// This looks like the before/after sides of a diff — confirm that only
// one pair exists in the real file.
Self::InOrder(receiver) => receiver.recv().await,
Self::OutOfOrder(receiver) => receiver.recv().await,
PipelineReceiver::InOrder(rr) => rr.recv().await,
PipelineReceiver::OutOfOrder(r) => r.recv().await,
}
}
}
Expand Down
5 changes: 2 additions & 3 deletions src/daft-local-execution/src/intermediate_ops/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,10 @@ impl IntermediateOperator for AggregateOperator {
fn execute(
&self,
_idx: usize,
input: &PipelineResultType,
input: PipelineResultType,
_state: Option<&mut Box<dyn IntermediateOperatorState>>,
) -> DaftResult<IntermediateOperatorResult> {
let input = input.as_data();
let out = input.agg(&self.agg_exprs, &self.group_by)?;
let out = input.data().agg(&self.agg_exprs, &self.group_by)?;
Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(
out,
))))
Expand Down
5 changes: 2 additions & 3 deletions src/daft-local-execution/src/intermediate_ops/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@ impl IntermediateOperator for FilterOperator {
fn execute(
&self,
_idx: usize,
input: &PipelineResultType,
input: PipelineResultType,
_state: Option<&mut Box<dyn IntermediateOperatorState>>,
) -> DaftResult<IntermediateOperatorResult> {
let input = input.as_data();
let out = input.filter(&[self.predicate.clone()])?;
let out = input.data().filter(&[self.predicate.clone()])?;
Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(
out,
))))
Expand Down
13 changes: 7 additions & 6 deletions src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@ enum HashJoinProbeState {
}

impl HashJoinProbeState {
fn set_table(&mut self, table: &Arc<ProbeTable>, tables: &Arc<Vec<Table>>) {
fn set_table(&mut self, table: Arc<ProbeTable>, tables: Arc<Vec<Table>>) {
if let HashJoinProbeState::Building = self {
*self = HashJoinProbeState::ReadyToProbe(table.clone(), tables.clone());
*self = HashJoinProbeState::ReadyToProbe(table, tables);
} else {
panic!("HashJoinProbeState should only be in Building state when setting table")
}
}

fn probe(
&self,
input: &Arc<MicroPartition>,
input: Arc<MicroPartition>,
right_on: &[ExprRef],
pruned_right_side_columns: &[String],
) -> DaftResult<Arc<MicroPartition>> {
Expand Down Expand Up @@ -109,17 +109,18 @@ impl IntermediateOperator for HashJoinProbeOperator {
fn execute(
&self,
idx: usize,
input: &PipelineResultType,
input: PipelineResultType,
state: Option<&mut Box<dyn IntermediateOperatorState>>,
) -> DaftResult<IntermediateOperatorResult> {
println!("HashJoinProbeOperator::execute: idx: {}", idx);
match idx {
0 => {
let state = state
.expect("HashJoinProbeOperator should have state")
.as_any_mut()
.downcast_mut::<HashJoinProbeState>()
.expect("HashJoinProbeOperator state should be HashJoinProbeState");
let (probe_table, tables) = input.as_probe_table();
let (probe_table, tables) = input.probe_table();
state.set_table(probe_table, tables);
Ok(IntermediateOperatorResult::NeedMoreInput(None))
}
Expand All @@ -129,7 +130,7 @@ impl IntermediateOperator for HashJoinProbeOperator {
.as_any_mut()
.downcast_mut::<HashJoinProbeState>()
.expect("HashJoinProbeOperator state should be HashJoinProbeState");
let input = input.as_data();
let input = input.data();
let out = state.probe(input, &self.right_on, &self.pruned_right_side_columns)?;
Ok(IntermediateOperatorResult::NeedMoreInput(Some(out)))
}
Expand Down
42 changes: 18 additions & 24 deletions src/daft-local-execution/src/intermediate_ops/intermediate_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ use tracing::{info_span, instrument};
use async_trait::async_trait;

use crate::{
channel::{create_channel, create_multi_channel, MultiSender, Receiver, Sender},
pipeline::{PipelineNode, PipelineResultReceiver, PipelineResultType},
runtime_stats::{CountingSender, RuntimeStatsContext},
channel::{create_channel, PipelineChannel, Receiver, Sender},
pipeline::{PipelineNode, PipelineResultType},
runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext},
ExecutionRuntimeHandle, NUM_CPUS,
};

Expand All @@ -30,7 +30,7 @@ pub trait IntermediateOperator: Send + Sync {
fn execute(
&self,
idx: usize,
input: &PipelineResultType,
input: PipelineResultType,
state: Option<&mut Box<dyn IntermediateOperatorState>>,
) -> DaftResult<IntermediateOperatorResult>;
fn name(&self) -> &'static str;
Expand Down Expand Up @@ -80,14 +80,7 @@ impl IntermediateNode {
let span = info_span!("IntermediateOp::execute");
let mut state = op.make_state();
while let Some((idx, morsel)) = receiver.recv().await {
let len = match morsel {
PipelineResultType::Data(ref data) => data.len(),
PipelineResultType::ProbeTable(_, ref tables) => {
tables.iter().map(|t| t.len()).sum()
}
};
rt_context.mark_rows_received(len as u64);
let result = rt_context.in_span(&span, || op.execute(idx, &morsel, state.as_mut()))?;
let result = rt_context.in_span(&span, || op.execute(idx, morsel, state.as_mut()))?;
match result {
IntermediateOperatorResult::NeedMoreInput(Some(mp)) => {
let _ = sender.send(mp.into()).await;
Expand All @@ -104,13 +97,14 @@ impl IntermediateNode {
pub async fn spawn_workers(
&self,
num_workers: usize,
destination: &mut MultiSender,
destination_channel: &mut PipelineChannel,
runtime_handle: &mut ExecutionRuntimeHandle,
) -> Vec<Sender<(usize, PipelineResultType)>> {
let mut worker_senders = Vec::with_capacity(num_workers);
for _ in 0..num_workers {
let (worker_sender, worker_receiver) = create_channel(1);
let destination_sender = destination.get_next_sender(&self.runtime_stats);
let destination_sender =
destination_channel.get_next_sender_with_stats(&self.runtime_stats);
runtime_handle.spawn(
Self::run_worker(
self.intermediate_op.clone(),
Expand All @@ -126,7 +120,7 @@ impl IntermediateNode {
}

pub async fn send_to_workers(
receivers: Vec<PipelineResultReceiver>,
receivers: Vec<CountingReceiver>,
worker_senders: Vec<Sender<(usize, PipelineResultType)>>,
morsel_size: usize,
) -> DaftResult<()> {
Expand All @@ -138,15 +132,15 @@ impl IntermediateNode {
};

for (idx, mut receiver) in receivers.into_iter().enumerate() {
println!("idx: {}", idx);
let mut buffer = OperatorBuffer::new(morsel_size);
while let Some(morsel) = receiver.recv().await {
let morsel = morsel?;
if morsel.should_broadcast() {
for worker_sender in worker_senders.iter() {
let _ = worker_sender.send((idx, morsel.clone())).await;
}
} else {
buffer.push(morsel.as_data().clone());
buffer.push(morsel.data().clone());
if let Some(ready) = buffer.try_clear() {
let _ = send_to_next_worker(idx, ready?.into()).await;
}
Expand Down Expand Up @@ -200,17 +194,17 @@ impl PipelineNode for IntermediateNode {
&mut self,
maintain_order: bool,
runtime_handle: &mut ExecutionRuntimeHandle,
) -> crate::Result<PipelineResultReceiver> {
) -> crate::Result<PipelineChannel> {
let mut child_result_receivers = Vec::with_capacity(self.children.len());
for child in self.children.iter_mut() {
let child_result_receiver = child.start(maintain_order, runtime_handle).await?;
child_result_receivers.push(child_result_receiver);
let child_result_channel = child.start(maintain_order, runtime_handle).await?;
child_result_receivers
.push(child_result_channel.get_receiver_with_stats(&self.runtime_stats));
}
let (mut destination_sender, destination_receiver) =
create_multi_channel(*NUM_CPUS, maintain_order);
let mut destination_channel = PipelineChannel::new(*NUM_CPUS, maintain_order);

let worker_senders = self
.spawn_workers(*NUM_CPUS, &mut destination_sender, runtime_handle)
.spawn_workers(*NUM_CPUS, &mut destination_channel, runtime_handle)
.await;
runtime_handle.spawn(
Self::send_to_workers(
Expand All @@ -220,7 +214,7 @@ impl PipelineNode for IntermediateNode {
),
self.intermediate_op.name(),
);
Ok(destination_receiver.into())
Ok(destination_channel)
}
fn as_tree_display(&self) -> &dyn TreeDisplay {
self
Expand Down
5 changes: 2 additions & 3 deletions src/daft-local-execution/src/intermediate_ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@ impl IntermediateOperator for ProjectOperator {
fn execute(
&self,
_idx: usize,
input: &PipelineResultType,
input: PipelineResultType,
_state: Option<&mut Box<dyn IntermediateOperatorState>>,
) -> DaftResult<IntermediateOperatorResult> {
let input = input.as_data();
let out = input.eval_expression_list(&self.projection)?;
let out = input.data().eval_expression_list(&self.projection)?;
Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(
out,
))))
Expand Down
Loading

0 comments on commit 5357ffa

Please sign in to comment.