Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] Compute pool for native executor #2986

Merged
merged 15 commits into from
Oct 23, 2024
22 changes: 21 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ common-display = {path = "src/common/display", default-features = false}
common-file-formats = {path = "src/common/file-formats", default-features = false}
common-hashable-float-wrapper = {path = "src/common/hashable-float-wrapper", default-features = false}
common-resource-request = {path = "src/common/resource-request", default-features = false}
common-runtime = {path = "src/common/runtime", default-features = false}
common-system-info = {path = "src/common/system-info", default-features = false}
common-tracing = {path = "src/common/tracing", default-features = false}
common-version = {path = "src/common/version", default-features = false}
Expand Down
15 changes: 15 additions & 0 deletions src/common/runtime/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[dependencies]
common-error = {path = "../error", default-features = false}
futures = {workspace = true}
lazy_static = {workspace = true}
log = {workspace = true}
oneshot = "0.1.8"
tokio = {workspace = true}

[lints]
workspace = true

[package]
edition = {workspace = true}
name = "common-runtime"
version = {workspace = true}
184 changes: 184 additions & 0 deletions src/common/runtime/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
use std::{
future::Future,
panic::AssertUnwindSafe,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, OnceLock,
},
};

use common_error::{DaftError, DaftResult};
use futures::FutureExt;
use lazy_static::lazy_static;
use tokio::{runtime::RuntimeFlavor, task::JoinHandle};

lazy_static! {
static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get();
static ref THREADED_IO_RUNTIME_NUM_WORKER_THREADS: usize = 8.min(*NUM_CPUS);
static ref COMPUTE_RUNTIME_NUM_WORKER_THREADS: usize = *NUM_CPUS;
static ref COMPUTE_RUNTIME_MAX_BLOCKING_THREADS: usize = 1; // Compute thread should not use blocking threads, limit this to the minimum, i.e. 1
}

static THREADED_IO_RUNTIME: OnceLock<RuntimeRef> = OnceLock::new();
static SINGLE_THREADED_IO_RUNTIME: OnceLock<RuntimeRef> = OnceLock::new();
static COMPUTE_RUNTIME: OnceLock<RuntimeRef> = OnceLock::new();

pub type RuntimeRef = Arc<Runtime>;

#[derive(Debug, Clone, Copy)]
enum PoolType {
Compute,
IO,
}

pub struct Runtime {
runtime: tokio::runtime::Runtime,
pool_type: PoolType,
}

impl Runtime {
pub(crate) fn new(runtime: tokio::runtime::Runtime, pool_type: PoolType) -> RuntimeRef {
Arc::new(Self { runtime, pool_type })
}

async fn execute_task<F>(future: F, pool_type: PoolType) -> DaftResult<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
AssertUnwindSafe(future).catch_unwind().await.map_err(|e| {
let s = if let Some(s) = e.downcast_ref::<String>() {
s.clone()
} else if let Some(s) = e.downcast_ref::<&str>() {
(*s).to_string()

Check warning on line 53 in src/common/runtime/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/common/runtime/src/lib.rs#L50-L53

Added lines #L50 - L53 were not covered by tests
} else {
"unknown internal error".to_string()

Check warning on line 55 in src/common/runtime/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/common/runtime/src/lib.rs#L55

Added line #L55 was not covered by tests
};
DaftError::ComputeError(format!(
"Caught panic when spawning blocking task in the {:?} runtime: {})",
pool_type, s
))

Check warning on line 60 in src/common/runtime/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/common/runtime/src/lib.rs#L57-L60

Added lines #L57 - L60 were not covered by tests
})
}

/// Spawns a task on the runtime and blocks the current thread until the task is completed.
/// Similar to tokio's Runtime::block_on but requires static lifetime + Send
/// You should use this when you are spawning IO tasks from an Expression Evaluator or in the Executor
pub fn block_on<F>(&self, future: F) -> DaftResult<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let (tx, rx) = oneshot::channel();
let pool_type = self.pool_type;
let _join_handle = self.spawn(async move {
let task_output = Self::execute_task(future, pool_type).await;
if tx.send(task_output).is_err() {
log::warn!("Spawned task output ignored: receiver dropped");

Check warning on line 77 in src/common/runtime/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/common/runtime/src/lib.rs#L77

Added line #L77 was not covered by tests
}
});
rx.recv().expect("Spawned task transmitter dropped")
}

/// Spawn a task on the runtime and await on it.
/// You should use this when you are spawning compute or IO tasks from the Executor.
pub async fn await_on<F>(&self, future: F) -> DaftResult<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let (tx, rx) = oneshot::channel();
let pool_type = self.pool_type;
let _join_handle = self.spawn(async move {
let task_output = Self::execute_task(future, pool_type).await;
if tx.send(task_output).is_err() {
log::warn!("Spawned task output ignored: receiver dropped");

Check warning on line 95 in src/common/runtime/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/common/runtime/src/lib.rs#L95

Added line #L95 was not covered by tests
}
});
rx.await.expect("Spawned task transmitter dropped")
}

/// Blocks current thread to compute future. Can not be called in tokio runtime context
///
pub fn block_on_current_thread<F: Future>(&self, future: F) -> F::Output {
self.runtime.block_on(future)
}

pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
colin-ho marked this conversation as resolved.
Show resolved Hide resolved
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
self.runtime.spawn(future)
}
}

fn init_compute_runtime() -> RuntimeRef {
std::thread::spawn(move || {
let mut builder = tokio::runtime::Builder::new_multi_thread();
builder
.worker_threads(*COMPUTE_RUNTIME_NUM_WORKER_THREADS)
.enable_all()
.thread_name_fn(move || {
static COMPUTE_THREAD_ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
let id = COMPUTE_THREAD_ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
format!("Compute-Thread-{}", id)
})
.max_blocking_threads(*COMPUTE_RUNTIME_MAX_BLOCKING_THREADS);
Runtime::new(builder.build().unwrap(), PoolType::Compute)
})
.join()
.unwrap()
}

fn init_io_runtime(multi_thread: bool) -> RuntimeRef {
std::thread::spawn(move || {
let mut builder = tokio::runtime::Builder::new_multi_thread();
builder
.worker_threads(if multi_thread {
*THREADED_IO_RUNTIME_NUM_WORKER_THREADS
} else {
1
})
.enable_all()
.thread_name_fn(move || {
static COMPUTE_THREAD_ATOMIC_ID: AtomicUsize = AtomicUsize::new(0);
let id = COMPUTE_THREAD_ATOMIC_ID.fetch_add(1, Ordering::SeqCst);
format!("IO-Thread-{}", id)
});
Runtime::new(builder.build().unwrap(), PoolType::IO)
})
.join()
.unwrap()
}

pub fn get_compute_runtime() -> RuntimeRef {
COMPUTE_RUNTIME.get_or_init(init_compute_runtime).clone()
}

pub fn get_io_runtime(multi_thread: bool) -> RuntimeRef {
if !multi_thread {
SINGLE_THREADED_IO_RUNTIME
.get_or_init(|| init_io_runtime(false))
.clone()
} else {
THREADED_IO_RUNTIME
.get_or_init(|| init_io_runtime(true))
.clone()
}
}

#[must_use]
pub fn get_io_pool_num_threads() -> Option<usize> {
match tokio::runtime::Handle::try_current() {
Ok(handle) => {
match handle.runtime_flavor() {
RuntimeFlavor::CurrentThread => Some(1),
RuntimeFlavor::MultiThread => Some(*THREADED_IO_RUNTIME_NUM_WORKER_THREADS),
// RuntimeFlavor is #non_exhaustive, so we default to 1 here to be conservative
_ => Some(1),

Check warning on line 179 in src/common/runtime/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/common/runtime/src/lib.rs#L179

Added line #L179 was not covered by tests
}
}
Err(_) => None,

Check warning on line 182 in src/common/runtime/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/common/runtime/src/lib.rs#L182

Added line #L182 was not covered by tests
}
}
1 change: 1 addition & 0 deletions src/daft-csv/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ async-compat = {workspace = true}
async-stream = {workspace = true}
common-error = {path = "../common/error", default-features = false}
common-py-serde = {path = "../common/py-serde", default-features = false}
common-runtime = {path = "../common/runtime", default-features = false}
csv-async = "1.3.0"
daft-compression = {path = "../daft-compression", default-features = false}
daft-core = {path = "../daft-core", default-features = false}
Expand Down
7 changes: 4 additions & 3 deletions src/daft-csv/src/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
use arrow2::io::csv::read_async::{AsyncReader, AsyncReaderBuilder};
use async_compat::CompatExt;
use common_error::DaftResult;
use common_runtime::get_io_runtime;
use csv_async::ByteRecord;
use daft_compression::CompressionCodec;
use daft_core::prelude::Schema;
use daft_decoding::inference::infer;
use daft_io::{get_runtime, GetResult, IOClient, IOStatsRef};
use daft_io::{GetResult, IOClient, IOStatsRef};
use futures::{StreamExt, TryStreamExt};
use snafu::ResultExt;
use tokio::{
Expand Down Expand Up @@ -58,7 +59,7 @@
io_client: Arc<IOClient>,
io_stats: Option<IOStatsRef>,
) -> DaftResult<(Schema, CsvReadStats)> {
let runtime_handle = get_runtime(true)?;
let runtime_handle = get_io_runtime(true);
runtime_handle.block_on_current_thread(async {
read_csv_schema_single(
uri,
Expand All @@ -80,7 +81,7 @@
io_stats: Option<IOStatsRef>,
num_parallel_tasks: usize,
) -> DaftResult<Vec<(Schema, CsvReadStats)>> {
let runtime_handle = get_runtime(true)?;
let runtime_handle = get_io_runtime(true);

Check warning on line 84 in src/daft-csv/src/metadata.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-csv/src/metadata.rs#L84

Added line #L84 was not covered by tests
let result = runtime_handle
.block_on_current_thread(async {
let task_stream = futures::stream::iter(uris.iter().map(|uri| {
Expand Down
7 changes: 4 additions & 3 deletions src/daft-csv/src/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ use arrow2::{
};
use async_compat::{Compat, CompatExt};
use common_error::{DaftError, DaftResult};
use common_runtime::get_io_runtime;
use csv_async::AsyncReader;
use daft_compression::CompressionCodec;
use daft_core::{prelude::*, utils::arrow::cast_array_for_daft_if_needed};
use daft_decoding::deserialize::deserialize_column;
use daft_dsl::optimization::get_required_columns;
use daft_io::{get_runtime, GetResult, IOClient, IOStatsRef};
use daft_io::{GetResult, IOClient, IOStatsRef};
use daft_table::Table;
use futures::{stream::BoxStream, Stream, StreamExt, TryStreamExt};
use rayon::{
Expand Down Expand Up @@ -53,7 +54,7 @@ pub fn read_csv(
multithreaded_io: bool,
max_chunks_in_flight: Option<usize>,
) -> DaftResult<Table> {
let runtime_handle = get_runtime(multithreaded_io)?;
let runtime_handle = get_io_runtime(multithreaded_io);
runtime_handle.block_on_current_thread(async {
read_csv_single_into_table(
uri,
Expand All @@ -80,7 +81,7 @@ pub fn read_csv_bulk(
max_chunks_in_flight: Option<usize>,
num_parallel_tasks: usize,
) -> DaftResult<Vec<Table>> {
let runtime_handle = get_runtime(multithreaded_io)?;
let runtime_handle = get_io_runtime(multithreaded_io);
let tables = runtime_handle.block_on_current_thread(async move {
// Launch a read task per URI, throttling the number of concurrent file reads to num_parallel tasks.
let task_stream = futures::stream::iter(uris.iter().map(|uri| {
Expand Down
1 change: 1 addition & 0 deletions src/daft-functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ base64 = {workspace = true}
common-error = {path = "../common/error", default-features = false}
common-hashable-float-wrapper = {path = "../common/hashable-float-wrapper"}
common-io-config = {path = "../common/io-config", default-features = false}
common-runtime = {path = "../common/runtime", default-features = false}
daft-core = {path = "../daft-core", default-features = false}
daft-dsl = {path = "../daft-dsl", default-features = false}
daft-image = {path = "../daft-image", default-features = false}
Expand Down
Loading
Loading