Skip to content

Commit

Permalink
consolidate runtime + wrap state
Browse files Browse the repository at this point in the history
  • Loading branch information
Colin Ho authored and Colin Ho committed Oct 22, 2024
1 parent 85341c3 commit a88a193
Show file tree
Hide file tree
Showing 16 changed files with 155 additions and 201 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

65 changes: 0 additions & 65 deletions src/common/runtime/src/compute.rs

This file was deleted.

90 changes: 0 additions & 90 deletions src/common/runtime/src/io.rs

This file was deleted.

122 changes: 109 additions & 13 deletions src/common/runtime/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
pub mod compute;
pub mod io;
use std::{
future::Future,
panic::AssertUnwindSafe,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, OnceLock,
},
};

use std::sync::atomic::{AtomicUsize, Ordering};

use compute::{ComputeRuntime, ComputeRuntimeRef, COMPUTE_RUNTIME};
use io::{IORuntime, IORuntimeRef, SINGLE_THREADED_IO_RUNTIME, THREADED_IO_RUNTIME};
use common_error::{DaftError, DaftResult};
use futures::FutureExt;
use lazy_static::lazy_static;
use tokio::runtime::RuntimeFlavor;
use tokio::{runtime::RuntimeFlavor, task::JoinHandle};

lazy_static! {
static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get();
Expand All @@ -15,7 +19,99 @@ lazy_static! {
static ref COMPUTE_RUNTIME_MAX_BLOCKING_THREADS: usize = 1; // Compute thread should not use blocking threads, limit this to the minimum, i.e. 1
}

fn init_compute_runtime() -> ComputeRuntimeRef {
// Process-wide runtime slots. Each `OnceLock` starts empty and is filled
// through `get_or_init` by the accessor functions further down in this file.
// NOTE(review): THREADED_IO_RUNTIME is presumably filled by the multi-threaded
// branch of `get_io_runtime` — that branch is not fully visible here; confirm.
static THREADED_IO_RUNTIME: OnceLock<RuntimeRef> = OnceLock::new();
static SINGLE_THREADED_IO_RUNTIME: OnceLock<RuntimeRef> = OnceLock::new();
static COMPUTE_RUNTIME: OnceLock<RuntimeRef> = OnceLock::new();

/// Shared, reference-counted handle to a [`Runtime`].
pub type RuntimeRef = Arc<Runtime>;

/// Identifies which thread pool a [`Runtime`] represents.
///
/// The `Debug` rendering of this value is embedded in the error message
/// produced when a task panics, so variant names are user-visible.
#[derive(Clone, Copy, Debug)]
enum PoolType {
    /// Pool created by `init_compute_runtime`.
    Compute,
    /// Pool created by `init_io_runtime`.
    IO,
}

/// A tokio runtime paired with the kind of pool it serves.
///
/// Instances are handed out as [`RuntimeRef`] (an `Arc<Runtime>`).
pub struct Runtime {
// The underlying tokio runtime that tasks are spawned onto.
runtime: tokio::runtime::Runtime,
// Which pool this runtime is; used to label panic error messages.
pool_type: PoolType,
}

impl Runtime {
pub(crate) fn new(runtime: tokio::runtime::Runtime, pool_type: PoolType) -> RuntimeRef {
Arc::new(Self { runtime, pool_type })
}

async fn execute_task<F>(future: F, pool_type: PoolType) -> DaftResult<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
AssertUnwindSafe(future).catch_unwind().await.map_err(|e| {
let s = if let Some(s) = e.downcast_ref::<String>() {
s.clone()
} else if let Some(s) = e.downcast_ref::<&str>() {
(*s).to_string()

Check warning on line 53 in src/common/runtime/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/common/runtime/src/lib.rs#L50-L53

Added lines #L50 - L53 were not covered by tests
} else {
"unknown internal error".to_string()

Check warning on line 55 in src/common/runtime/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/common/runtime/src/lib.rs#L55

Added line #L55 was not covered by tests
};
DaftError::ComputeError(format!(
"Caught panic when spawning blocking task in the {:?} runtime: {})",
pool_type, s
))

Check warning on line 60 in src/common/runtime/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/common/runtime/src/lib.rs#L57-L60

Added lines #L57 - L60 were not covered by tests
})
}

/// Similar to tokio's Runtime::block_on but requires static lifetime + Send
/// You should use this when you are spawning IO tasks from an Expression Evaluator or in the Executor
pub fn block_on<F>(&self, future: F) -> DaftResult<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let (tx, rx) = oneshot::channel();
let pool_type = self.pool_type;
let _join_handle = self.spawn(async move {
let task_output = Self::execute_task(future, pool_type).await;
if tx.send(task_output).is_err() {
log::warn!("Spawned task output ignored: receiver dropped");

Check warning on line 76 in src/common/runtime/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/common/runtime/src/lib.rs#L76

Added line #L76 was not covered by tests
}
});
rx.recv().expect("Spawned task transmitter dropped")
}

/// Similar to block_on, but is async and can be awaited
pub async fn await_on<F>(&self, future: F) -> DaftResult<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let (tx, rx) = oneshot::channel();
let pool_type = self.pool_type;
let _join_handle = self.spawn(async move {
let task_output = Self::execute_task(future, pool_type).await;
if tx.send(task_output).is_err() {
log::warn!("Spawned task output ignored: receiver dropped");

Check warning on line 93 in src/common/runtime/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/common/runtime/src/lib.rs#L93

Added line #L93 was not covered by tests
}
});
rx.await.expect("Spawned task transmitter dropped")
}

/// Blocks current thread to compute future. Can not be called in tokio runtime context
///
pub fn block_on_current_thread<F: Future>(&self, future: F) -> F::Output {
self.runtime.block_on(future)
}

pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
self.runtime.spawn(future)
}
}

fn init_compute_runtime() -> RuntimeRef {
std::thread::spawn(move || {
let mut builder = tokio::runtime::Builder::new_multi_thread();
builder
Expand All @@ -27,13 +123,13 @@ fn init_compute_runtime() -> ComputeRuntimeRef {
format!("Compute-Thread-{}", id)
})
.max_blocking_threads(*COMPUTE_RUNTIME_MAX_BLOCKING_THREADS);
ComputeRuntime::new(builder.build().unwrap())
Runtime::new(builder.build().unwrap(), PoolType::Compute)
})
.join()
.unwrap()
}

fn init_io_runtime(multi_thread: bool) -> IORuntimeRef {
fn init_io_runtime(multi_thread: bool) -> RuntimeRef {
std::thread::spawn(move || {
let mut builder = tokio::runtime::Builder::new_multi_thread();
builder
Expand All @@ -48,17 +144,17 @@ fn init_io_runtime(multi_thread: bool) -> IORuntimeRef {
let id = COMPUTE_THREAD_ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
format!("IO-Thread-{}", id)
});
IORuntime::new(builder.build().unwrap())
Runtime::new(builder.build().unwrap(), PoolType::IO)
})
.join()
.unwrap()
}

pub fn get_compute_runtime() -> ComputeRuntimeRef {
/// Returns a shared handle to the process-wide compute runtime, building it
/// with `init_compute_runtime` on first use.
pub fn get_compute_runtime() -> RuntimeRef {
    let runtime = COMPUTE_RUNTIME.get_or_init(init_compute_runtime);
    Arc::clone(runtime)
}

pub fn get_io_runtime(multi_thread: bool) -> IORuntimeRef {
pub fn get_io_runtime(multi_thread: bool) -> RuntimeRef {
if !multi_thread {
SINGLE_THREADED_IO_RUNTIME
.get_or_init(|| init_io_runtime(false))
Expand Down
1 change: 0 additions & 1 deletion src/daft-io/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ hyper-tls = "0.5.0"
itertools = {workspace = true}
lazy_static = {workspace = true}
log = {workspace = true}
oneshot = "0.1.8"
openssl-sys = {version = "0.9.102", features = ["vendored"]}
pyo3 = {workspace = true, optional = true}
rand = "0.8.5"
Expand Down
4 changes: 2 additions & 2 deletions src/daft-local-execution/src/intermediate_ops/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use daft_dsl::ExprRef;
use tracing::instrument;

use super::intermediate_op::{
IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState,
IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorStateWrapper,
};
use crate::pipeline::PipelineResultType;

Expand All @@ -29,7 +29,7 @@ impl IntermediateOperator for AggregateOperator {
&self,
_idx: usize,
input: &PipelineResultType,
_state: &mut dyn IntermediateOperatorState,
_state: &IntermediateOperatorStateWrapper,
) -> DaftResult<IntermediateOperatorResult> {
let out = input.as_data().agg(&self.agg_exprs, &self.group_by)?;
Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use tracing::{info_span, instrument};

use super::intermediate_op::{
IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState,
IntermediateOperatorStateWrapper,
};
use crate::pipeline::PipelineResultType;

Expand Down Expand Up @@ -108,9 +109,10 @@ impl IntermediateOperator for AntiSemiProbeOperator {
&self,
idx: usize,
input: &PipelineResultType,
state: &mut dyn IntermediateOperatorState,
state: &IntermediateOperatorStateWrapper,
) -> DaftResult<IntermediateOperatorResult> {
let state = state
let mut guard = state.inner.lock().unwrap();
let state = guard
.as_any_mut()
.downcast_mut::<AntiSemiProbeState>()
.expect("AntiSemiProbeOperator state should be AntiSemiProbeState");
Expand Down
4 changes: 2 additions & 2 deletions src/daft-local-execution/src/intermediate_ops/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use daft_dsl::ExprRef;
use tracing::instrument;

use super::intermediate_op::{
IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState,
IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorStateWrapper,
};
use crate::pipeline::PipelineResultType;

Expand All @@ -25,7 +25,7 @@ impl IntermediateOperator for FilterOperator {
&self,
_idx: usize,
input: &PipelineResultType,
_state: &mut dyn IntermediateOperatorState,
_state: &IntermediateOperatorStateWrapper,
) -> DaftResult<IntermediateOperatorResult> {
let out = input.as_data().filter(&[self.predicate.clone()])?;
Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new(
Expand Down
Loading

0 comments on commit a88a193

Please sign in to comment.