Skip to content

Commit

Permalink
reduce code
Browse files Browse the repository at this point in the history
  • Loading branch information
Colin Ho authored and Colin Ho committed Sep 19, 2024
1 parent dcac8ba commit bb0a944
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 114 deletions.
98 changes: 60 additions & 38 deletions src/daft-local-execution/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,64 +13,86 @@ pub fn create_channel<T>(buffer_size: usize) -> (Sender<T>, Receiver<T>) {
}

pub struct PipelineChannel {
sender: Sender<PipelineResultType>,
receiver: Receiver<PipelineResultType>,
sender: PipelineSender,
receiver: PipelineReceiver,
}

impl PipelineChannel {
pub(crate) fn new() -> Self {
let (sender, receiver) = create_channel(1);
Self { sender, receiver }
pub fn new(buffer_size: usize, in_order: bool) -> Self {
match in_order {
true => {
let (senders, receivers) = (0..buffer_size).map(|_| create_channel(1)).unzip();
let sender = PipelineSender::InOrder(RoundRobinSender::new(senders));
let receiver = PipelineReceiver::InOrder(RoundRobinReceiver::new(receivers));
Self { sender, receiver }
}
false => {
let (sender, receiver) = create_channel(buffer_size);
let sender = PipelineSender::OutOfOrder(sender);
let receiver = PipelineReceiver::OutOfOrder(receiver);
Self { sender, receiver }
}
}
}

pub(crate) fn get_sender_with_stats(&self, stats: &Arc<RuntimeStatsContext>) -> CountingSender {
CountingSender::new(self.sender.clone(), stats.clone())
fn get_next_sender(&mut self) -> Sender<PipelineResultType> {
match &mut self.sender {
PipelineSender::InOrder(rr) => rr.get_next_sender(),
PipelineSender::OutOfOrder(sender) => sender.clone(),
}
}

pub(crate) fn get_receiver_with_stats(
self,
stats: &Arc<RuntimeStatsContext>,
) -> CountingReceiver {
CountingReceiver::new(self.receiver, stats.clone())
pub(crate) fn get_next_sender_with_stats(
&mut self,
rt: &Arc<RuntimeStatsContext>,
) -> CountingSender {
CountingSender::new(self.get_next_sender(), rt.clone())
}

pub(crate) fn get_receiver(self) -> Receiver<PipelineResultType> {
pub fn get_receiver(self) -> PipelineReceiver {
self.receiver
}

pub(crate) fn get_receiver_with_stats(self, rt: &Arc<RuntimeStatsContext>) -> CountingReceiver {
CountingReceiver::new(self.get_receiver(), rt.clone())
}
}

pub(crate) fn create_ordering_aware_channel<T>(
buffer_size: usize,
ordered: bool,
) -> (Vec<Sender<T>>, OrderingAwareReceiver<T>) {
match ordered {
true => {
let (senders, receivers) = (0..buffer_size).map(|_| create_channel(1)).unzip();
(
senders,
OrderingAwareReceiver::Ordered(RoundRobinReceiver::new(receivers)),
)
}
false => {
let (sender, receiver) = create_channel(buffer_size);
(
(0..buffer_size).map(|_| sender.clone()).collect(),
OrderingAwareReceiver::Unordered(receiver),
)
/// Sending half of a `PipelineChannel`.
///
/// `InOrder` wraps a `RoundRobinSender` that fans results out over several
/// per-worker channels in rotation (presumably so the receiving side can
/// reassemble results in submission order — confirm against `PipelineReceiver`);
/// `OutOfOrder` is a single shared channel, so results arrive as they complete.
pub enum PipelineSender {
InOrder(RoundRobinSender<PipelineResultType>),
OutOfOrder(Sender<PipelineResultType>),
}

/// Rotates over a fixed set of senders, handing out one per call in
/// round-robin order.
pub struct RoundRobinSender<T> {
    senders: Vec<Sender<T>>,
    curr_sender_idx: usize,
}

impl<T> RoundRobinSender<T> {
    /// Wraps the given senders; dispatch begins with the first one.
    pub fn new(senders: Vec<Sender<T>>) -> Self {
        Self {
            senders,
            curr_sender_idx: 0,
        }
    }

    /// Returns a clone of the sender whose turn it is, then advances the
    /// cursor, wrapping back to the start after the last sender.
    pub fn get_next_sender(&mut self) -> Sender<T> {
        let chosen = self.senders[self.curr_sender_idx].clone();
        self.curr_sender_idx += 1;
        if self.curr_sender_idx == self.senders.len() {
            self.curr_sender_idx = 0;
        }
        chosen
    }
}

pub enum OrderingAwareReceiver<T> {
Ordered(RoundRobinReceiver<T>),
Unordered(Receiver<T>),
/// Receiving half of a `PipelineChannel`.
///
/// `InOrder` drains several per-worker channels via a `RoundRobinReceiver`
/// (mirroring the round-robin sender — presumably to yield results in
/// submission order; confirm against `RoundRobinReceiver`); `OutOfOrder`
/// reads a single shared channel in completion order.
pub enum PipelineReceiver {
InOrder(RoundRobinReceiver<PipelineResultType>),
OutOfOrder(Receiver<PipelineResultType>),
}

impl<T> OrderingAwareReceiver<T> {
pub async fn recv(&mut self) -> Option<T> {
impl PipelineReceiver {
pub async fn recv(&mut self) -> Option<PipelineResultType> {
match self {
OrderingAwareReceiver::Ordered(rr_receiver) => rr_receiver.recv().await,
OrderingAwareReceiver::Unordered(receiver) => receiver.recv().await,
PipelineReceiver::InOrder(rr) => rr.recv().await,
PipelineReceiver::OutOfOrder(r) => r.recv().await,
}
}
}
Expand Down
99 changes: 40 additions & 59 deletions src/daft-local-execution/src/intermediate_ops/intermediate_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@ use std::sync::Arc;
use common_display::tree::TreeDisplay;
use common_error::DaftResult;
use daft_micropartition::MicroPartition;
use snafu::ResultExt;
use tracing::{info_span, instrument};

use super::buffer::OperatorBuffer;
use crate::{
channel::{create_channel, create_ordering_aware_channel, PipelineChannel, Receiver, Sender},
create_worker_set,
channel::{create_channel, PipelineChannel, Receiver, Sender},
pipeline::{PipelineNode, PipelineResultType},
runtime_stats::{CountingReceiver, RuntimeStatsContext},
ExecutionRuntimeHandle, JoinSnafu, WorkerSet, NUM_CPUS,
runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext},
ExecutionRuntimeHandle, NUM_CPUS,
};

pub trait IntermediateOperatorState: Send + Sync {
Expand Down Expand Up @@ -70,53 +68,61 @@ impl IntermediateNode {
}

#[instrument(level = "info", skip_all, name = "IntermediateOperator::run_worker")]
async fn run_worker(
pub async fn run_worker(
op: Arc<dyn IntermediateOperator>,
mut input_receiver: Receiver<(usize, PipelineResultType)>,
output_sender: Sender<Arc<MicroPartition>>,
mut receiver: Receiver<(usize, PipelineResultType)>,
sender: CountingSender,
rt_context: Arc<RuntimeStatsContext>,
) -> DaftResult<()> {
let span = info_span!("IntermediateOp::execute");
let mut state = op.make_state();
while let Some((idx, morsel)) = input_receiver.recv().await {
while let Some((idx, morsel)) = receiver.recv().await {
loop {
let result =
rt_context.in_span(&span, || op.execute(idx, &morsel, state.as_mut()))?;
match result {
IntermediateOperatorResult::NeedMoreInput(Some(mp)) => {
let _ = output_sender.send(mp).await;
let _ = sender.send(mp.into()).await;
break;
}
IntermediateOperatorResult::NeedMoreInput(None) => {
break;
}
IntermediateOperatorResult::HasMoreOutput(mp) => {
let _ = output_sender.send(mp).await;
let _ = sender.send(mp.into()).await;
}
}
}
}
Ok(())
}

fn spawn_workers(
op: Arc<dyn IntermediateOperator>,
input_receivers: Vec<Receiver<(usize, PipelineResultType)>>,
output_senders: Vec<Sender<Arc<MicroPartition>>>,
worker_set: &mut WorkerSet<DaftResult<()>>,
stats: Arc<RuntimeStatsContext>,
) {
for (input_receiver, output_sender) in input_receivers.into_iter().zip(output_senders) {
worker_set.spawn(Self::run_worker(
op.clone(),
input_receiver,
output_sender,
stats.clone(),
));
pub fn spawn_workers(
&self,
num_workers: usize,
destination_channel: &mut PipelineChannel,
runtime_handle: &mut ExecutionRuntimeHandle,
) -> Vec<Sender<(usize, PipelineResultType)>> {
let mut worker_senders = Vec::with_capacity(num_workers);
for _ in 0..num_workers {
let (worker_sender, worker_receiver) = create_channel(1);
let destination_sender =
destination_channel.get_next_sender_with_stats(&self.runtime_stats);
runtime_handle.spawn(
Self::run_worker(
self.intermediate_op.clone(),
worker_receiver,
destination_sender,
self.runtime_stats.clone(),
),
self.intermediate_op.name(),
);
worker_senders.push(worker_sender);
}
worker_senders
}

async fn forward_input_to_workers(
pub async fn send_to_workers(
receivers: Vec<CountingReceiver>,
worker_senders: Vec<Sender<(usize, PipelineResultType)>>,
morsel_size: usize,
Expand Down Expand Up @@ -196,41 +202,16 @@ impl PipelineNode for IntermediateNode {
child_result_receivers
.push(child_result_channel.get_receiver_with_stats(&self.runtime_stats));
}
let mut destination_channel = PipelineChannel::new(*NUM_CPUS, maintain_order);

let destination_channel = PipelineChannel::new();
let destination_sender = destination_channel.get_sender_with_stats(&self.runtime_stats);

let op = self.intermediate_op.clone();
let stats = self.runtime_stats.clone();
let morsel_size = runtime_handle.default_morsel_size();
let worker_senders =
self.spawn_workers(*NUM_CPUS, &mut destination_channel, runtime_handle);
runtime_handle.spawn(
async move {
let num_workers = *NUM_CPUS;
let (input_senders, input_receivers) =
(0..num_workers).map(|_| create_channel(1)).unzip();
let (output_senders, mut output_receiver) =
create_ordering_aware_channel(num_workers, maintain_order);

let mut worker_set = create_worker_set();
Self::spawn_workers(
op.clone(),
input_receivers,
output_senders,
&mut worker_set,
stats.clone(),
);

Self::forward_input_to_workers(child_result_receivers, input_senders, morsel_size)
.await?;

while let Some(morsel) = output_receiver.recv().await {
let _ = destination_sender.send(morsel.into()).await;
}
while let Some(result) = worker_set.join_next().await {
result.context(JoinSnafu)??;
}
Ok(())
},
Self::send_to_workers(
child_result_receivers,
worker_senders,
runtime_handle.default_morsel_size(),
),
self.intermediate_op.name(),
);
Ok(destination_channel)
Expand Down
9 changes: 3 additions & 6 deletions src/daft-local-execution/src/runtime_stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::{
use tokio::sync::mpsc::error::SendError;

use crate::{
channel::{Receiver, Sender},
channel::{PipelineReceiver, Sender},
pipeline::PipelineResultType,
};

Expand Down Expand Up @@ -133,15 +133,12 @@ impl CountingSender {
}

pub(crate) struct CountingReceiver {
receiver: Receiver<PipelineResultType>,
receiver: PipelineReceiver,
rt: Arc<RuntimeStatsContext>,
}

impl CountingReceiver {
pub(crate) fn new(
receiver: Receiver<PipelineResultType>,
rt: Arc<RuntimeStatsContext>,
) -> Self {
pub(crate) fn new(receiver: PipelineReceiver, rt: Arc<RuntimeStatsContext>) -> Self {
Self { receiver, rt }
}
#[inline]
Expand Down
7 changes: 4 additions & 3 deletions src/daft-local-execution/src/sinks/blocking_sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,17 @@ impl PipelineNode for BlockingSinkNode {

fn start(
&mut self,
_maintain_order: bool,
maintain_order: bool,
runtime_handle: &mut ExecutionRuntimeHandle,
) -> crate::Result<PipelineChannel> {
let child = self.child.as_mut();
let mut child_results_receiver = child
.start(false, runtime_handle)?
.get_receiver_with_stats(&self.runtime_stats);

let destination_channel = PipelineChannel::new();
let destination_sender = destination_channel.get_sender_with_stats(&self.runtime_stats);
let mut destination_channel = PipelineChannel::new(1, maintain_order);
let destination_sender =
destination_channel.get_next_sender_with_stats(&self.runtime_stats);
let op = self.op.clone();
let rt_context = self.runtime_stats.clone();
runtime_handle.spawn(
Expand Down
11 changes: 5 additions & 6 deletions src/daft-local-execution/src/sinks/streaming_sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use snafu::ResultExt;
use tracing::{info_span, instrument};

use crate::{
channel::{create_channel, create_ordering_aware_channel, PipelineChannel, Receiver, Sender},
channel::{create_channel, PipelineChannel, Receiver, Sender},
create_worker_set,
pipeline::{PipelineNode, PipelineResultType},
runtime_stats::{CountingReceiver, RuntimeStatsContext},
Expand Down Expand Up @@ -187,9 +187,9 @@ impl PipelineNode for StreamingSinkNode {
.push(child_result_channel.get_receiver_with_stats(&self.runtime_stats.clone()));
}

let destination_channel = PipelineChannel::new();
let mut destination_channel = PipelineChannel::new(1, maintain_order);
let destination_sender =
destination_channel.get_sender_with_stats(&self.runtime_stats.clone());
destination_channel.get_next_sender_with_stats(&self.runtime_stats);

let op = self.op.clone();
let runtime_stats = self.runtime_stats.clone();
Expand All @@ -198,14 +198,13 @@ impl PipelineNode for StreamingSinkNode {
let num_workers = op.max_concurrency();
let (input_senders, input_receivers) =
(0..num_workers).map(|_| create_channel(1)).unzip();
let (output_senders, mut output_receiver) =
create_ordering_aware_channel(num_workers, maintain_order);
let (output_sender, mut output_receiver) = create_channel(num_workers);

let mut worker_set = create_worker_set();
Self::spawn_workers(
op.clone(),
input_receivers,
output_senders,
(0..num_workers).map(|_| output_sender.clone()).collect(),
&mut worker_set,
runtime_stats.clone(),
);
Expand Down
4 changes: 2 additions & 2 deletions src/daft-local-execution/src/sources/source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ impl PipelineNode for SourceNode {
self.source
.get_data(maintain_order, runtime_handle, self.io_stats.clone())?;

let channel = PipelineChannel::new();
let counting_sender = channel.get_sender_with_stats(&self.runtime_stats);
let mut channel = PipelineChannel::new(1, maintain_order);
let counting_sender = channel.get_next_sender_with_stats(&self.runtime_stats);
runtime_handle.spawn(
async move {
while let Some(part) = source_stream.next().await {
Expand Down

0 comments on commit bb0a944

Please sign in to comment.