Skip to content

Commit

Permalink
more cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Colin Ho authored and Colin Ho committed Sep 19, 2024
1 parent 86a5fa4 commit dcac8ba
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ use tracing::{info_span, instrument};
use super::buffer::OperatorBuffer;
use crate::{
channel::{create_channel, create_ordering_aware_channel, PipelineChannel, Receiver, Sender},
create_worker_set,
pipeline::{PipelineNode, PipelineResultType},
runtime_stats::{CountingReceiver, RuntimeStatsContext},
ExecutionRuntimeHandle, JoinSnafu, NUM_CPUS,
ExecutionRuntimeHandle, JoinSnafu, WorkerSet, NUM_CPUS,
};

pub trait IntermediateOperatorState: Send + Sync {
Expand Down Expand Up @@ -102,7 +103,7 @@ impl IntermediateNode {
op: Arc<dyn IntermediateOperator>,
input_receivers: Vec<Receiver<(usize, PipelineResultType)>>,
output_senders: Vec<Sender<Arc<MicroPartition>>>,
worker_set: &mut tokio::task::JoinSet<DaftResult<()>>,
worker_set: &mut WorkerSet<DaftResult<()>>,
stats: Arc<RuntimeStatsContext>,
) {
for (input_receiver, output_sender) in input_receivers.into_iter().zip(output_senders) {
Expand Down Expand Up @@ -209,14 +210,16 @@ impl PipelineNode for IntermediateNode {
(0..num_workers).map(|_| create_channel(1)).unzip();
let (output_senders, mut output_receiver) =
create_ordering_aware_channel(num_workers, maintain_order);
let mut worker_set = tokio::task::JoinSet::new();

let mut worker_set = create_worker_set();
Self::spawn_workers(
op.clone(),
input_receivers,
output_senders,
&mut worker_set,
stats.clone(),
);

Self::forward_input_to_workers(child_result_receivers, input_senders, morsel_size)
.await?;

Expand Down
9 changes: 7 additions & 2 deletions src/daft-local-execution/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,20 @@ lazy_static! {
pub static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get();
}

pub(crate) type WorkerSet<T> = tokio::task::JoinSet<T>;
pub(crate) fn create_worker_set<T>() -> WorkerSet<T> {
tokio::task::JoinSet::new()
}

pub struct ExecutionRuntimeHandle {
worker_set: tokio::task::JoinSet<crate::Result<()>>,
worker_set: WorkerSet<crate::Result<()>>,
default_morsel_size: usize,
}

impl ExecutionRuntimeHandle {
pub fn new(default_morsel_size: usize) -> Self {
Self {
worker_set: tokio::task::JoinSet::new(),
worker_set: create_worker_set(),
default_morsel_size,
}
}
Expand Down
14 changes: 8 additions & 6 deletions src/daft-local-execution/src/sinks/streaming_sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ use tracing::{info_span, instrument};

use crate::{
channel::{create_channel, create_ordering_aware_channel, PipelineChannel, Receiver, Sender},
create_worker_set,
pipeline::{PipelineNode, PipelineResultType},
runtime_stats::{CountingReceiver, RuntimeStatsContext},
ExecutionRuntimeHandle, JoinSnafu, NUM_CPUS,
ExecutionRuntimeHandle, JoinSnafu, WorkerSet, NUM_CPUS,
};

pub trait StreamingSinkState: Send + Sync {
Expand Down Expand Up @@ -103,7 +104,7 @@ impl StreamingSinkNode {
op: Arc<dyn StreamingSink>,
input_receivers: Vec<Receiver<(usize, PipelineResultType)>>,
output_senders: Vec<Sender<Arc<MicroPartition>>>,
worker_set: &mut tokio::task::JoinSet<DaftResult<Box<dyn StreamingSinkState>>>,
worker_set: &mut WorkerSet<DaftResult<Box<dyn StreamingSinkState>>>,
stats: Arc<RuntimeStatsContext>,
) {
for (input_receiver, output_sender) in input_receivers.into_iter().zip(output_senders) {
Expand Down Expand Up @@ -191,24 +192,25 @@ impl PipelineNode for StreamingSinkNode {
destination_channel.get_sender_with_stats(&self.runtime_stats.clone());

let op = self.op.clone();
let stats = self.runtime_stats.clone();
let runtime_stats = self.runtime_stats.clone();
runtime_handle.spawn(
async move {
let num_workers = op.max_concurrency();
let (input_senders, input_receivers) =
(0..num_workers).map(|_| create_channel(1)).unzip();
let (output_senders, mut output_receiver) =
create_ordering_aware_channel(num_workers, maintain_order);
let mut worker_set = tokio::task::JoinSet::new();

let mut worker_set = create_worker_set();
Self::spawn_workers(
op.clone(),
input_receivers,
output_senders,
&mut worker_set,
stats.clone(),
runtime_stats.clone(),
);
Self::forward_input_to_workers(child_result_receivers, input_senders).await?;

Self::forward_input_to_workers(child_result_receivers, input_senders).await?;
while let Some(morsel) = output_receiver.recv().await {
let _ = destination_sender.send(morsel.into()).await;
}
Expand Down

0 comments on commit dcac8ba

Please sign in to comment.