diff --git a/Cargo.lock b/Cargo.lock index 3fcc294fef..c146000729 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1385,6 +1385,18 @@ dependencies = [ "serde", ] +[[package]] +name = "common-runtime" +version = "0.3.0-dev0" +dependencies = [ + "common-error", + "futures", + "lazy_static", + "log", + "oneshot", + "tokio", +] + [[package]] name = "common-system-info" version = "0.3.0-dev0" @@ -1674,6 +1686,7 @@ dependencies = [ "common-file-formats", "common-hashable-float-wrapper", "common-resource-request", + "common-runtime", "common-system-info", "common-tracing", "common-version", @@ -1763,6 +1776,7 @@ dependencies = [ "async-stream", "common-error", "common-py-serde", + "common-runtime", "csv-async", "daft-compression", "daft-core", @@ -1823,6 +1837,7 @@ dependencies = [ "common-error", "common-hashable-float-wrapper", "common-io-config", + "common-runtime", "daft-core", "daft-dsl", "daft-image", @@ -1892,6 +1907,7 @@ dependencies = [ "common-error", "common-file-formats", "common-io-config", + "common-runtime", "futures", "globset", "google-cloud-storage", @@ -1903,7 +1919,6 @@ dependencies = [ "lazy_static", "log", "md5", - "oneshot", "openssl-sys", "pyo3", "rand 0.8.5", @@ -1925,6 +1940,7 @@ dependencies = [ "chrono", "common-error", "common-py-serde", + "common-runtime", "daft-compression", "daft-core", "daft-decoding", @@ -1956,6 +1972,7 @@ dependencies = [ "common-display", "common-error", "common-file-formats", + "common-runtime", "common-tracing", "daft-core", "daft-csv", @@ -1989,6 +2006,7 @@ dependencies = [ "bincode", "common-error", "common-file-formats", + "common-runtime", "daft-core", "daft-csv", "daft-dsl", @@ -2023,6 +2041,7 @@ dependencies = [ "bytes", "common-arrow-ffi", "common-error", + "common-runtime", "crossbeam-channel", "daft-core", "daft-dsl", @@ -2100,6 +2119,7 @@ dependencies = [ "common-file-formats", "common-io-config", "common-py-serde", + "common-runtime", "daft-core", "daft-csv", "daft-dsl", diff --git a/Cargo.toml b/Cargo.toml index dde3543e70..dfc8fce9d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ common-display = {path = "src/common/display", default-features = false} common-file-formats = {path = "src/common/file-formats", default-features = false} common-hashable-float-wrapper = {path = "src/common/hashable-float-wrapper", default-features = false} common-resource-request = {path = "src/common/resource-request", default-features = false} +common-runtime = {path = "src/common/runtime", default-features = false} common-system-info = {path = "src/common/system-info", default-features = false} common-tracing = {path = "src/common/tracing", default-features = false} common-version = {path = "src/common/version", default-features = false} diff --git a/src/common/runtime/Cargo.toml b/src/common/runtime/Cargo.toml new file mode 100644 index 0000000000..a32376ceb2 --- /dev/null +++ b/src/common/runtime/Cargo.toml @@ -0,0 +1,15 @@ +[dependencies] +common-error = {path = "../error", default-features = false} +futures = {workspace = true} +lazy_static = {workspace = true} +log = {workspace = true} +oneshot = "0.1.8" +tokio = {workspace = true} + +[lints] +workspace = true + +[package] +edition = {workspace = true} +name = "common-runtime" +version = {workspace = true} diff --git a/src/common/runtime/src/lib.rs b/src/common/runtime/src/lib.rs new file mode 100644 index 0000000000..22a7fcdec9 --- /dev/null +++ b/src/common/runtime/src/lib.rs @@ -0,0 +1,185 @@ +use std::{ + future::Future, + panic::AssertUnwindSafe, + sync::{ + atomic::{AtomicUsize, 
Ordering},
+        Arc, OnceLock,
+    },
+};
+
+use common_error::{DaftError, DaftResult};
+use futures::FutureExt;
+use lazy_static::lazy_static;
+use tokio::{runtime::RuntimeFlavor, task::JoinHandle};
+
+lazy_static! {
+    static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get();
+    static ref THREADED_IO_RUNTIME_NUM_WORKER_THREADS: usize = 8.min(*NUM_CPUS);
+    static ref COMPUTE_RUNTIME_NUM_WORKER_THREADS: usize = *NUM_CPUS;
+    static ref COMPUTE_RUNTIME_MAX_BLOCKING_THREADS: usize = 1; // Compute thread should not use blocking threads, limit this to the minimum, i.e. 1
+}
+
+static THREADED_IO_RUNTIME: OnceLock<RuntimeRef> = OnceLock::new();
+static SINGLE_THREADED_IO_RUNTIME: OnceLock<RuntimeRef> = OnceLock::new();
+static COMPUTE_RUNTIME: OnceLock<RuntimeRef> = OnceLock::new();
+
+pub type RuntimeRef = Arc<Runtime>;
+
+#[derive(Debug, Clone, Copy)]
+enum PoolType {
+    Compute,
+    IO,
+}
+
+pub struct Runtime {
+    runtime: tokio::runtime::Runtime,
+    pool_type: PoolType,
+}
+
+impl Runtime {
+    pub(crate) fn new(runtime: tokio::runtime::Runtime, pool_type: PoolType) -> RuntimeRef {
+        Arc::new(Self { runtime, pool_type })
+    }
+
+    // TODO: figure out a way to cancel the Future if this output is dropped.
+    async fn execute_task<F>(future: F, pool_type: PoolType) -> DaftResult<F::Output>
+    where
+        F: Future + Send + 'static,
+        F::Output: Send + 'static,
+    {
+        AssertUnwindSafe(future).catch_unwind().await.map_err(|e| {
+            let s = if let Some(s) = e.downcast_ref::<String>() {
+                s.clone()
+            } else if let Some(s) = e.downcast_ref::<&str>() {
+                (*s).to_string()
+            } else {
+                "unknown internal error".to_string()
+            };
+            DaftError::ComputeError(format!(
+                "Caught panic when spawning blocking task in the {:?} runtime: {})",
+                pool_type, s
+            ))
+        })
+    }
+
+    /// Spawns a task on the runtime and blocks the current thread until the task is completed.
+    /// Similar to tokio's Runtime::block_on but requires static lifetime + Send
+    /// You should use this when you are spawning IO tasks from an Expression Evaluator or in the Executor
+    pub fn block_on<F>(&self, future: F) -> DaftResult<F::Output>
+    where
+        F: Future + Send + 'static,
+        F::Output: Send + 'static,
+    {
+        let (tx, rx) = oneshot::channel();
+        let pool_type = self.pool_type;
+        let _join_handle = self.spawn(async move {
+            let task_output = Self::execute_task(future, pool_type).await;
+            if tx.send(task_output).is_err() {
+                log::warn!("Spawned task output ignored: receiver dropped");
+            }
+        });
+        rx.recv().expect("Spawned task transmitter dropped")
+    }
+
+    /// Spawn a task on the runtime and await on it.
+    /// You should use this when you are spawning compute or IO tasks from the Executor.
+    pub async fn await_on<F>(&self, future: F) -> DaftResult<F::Output>
+    where
+        F: Future + Send + 'static,
+        F::Output: Send + 'static,
+    {
+        let (tx, rx) = oneshot::channel();
+        let pool_type = self.pool_type;
+        let _join_handle = self.spawn(async move {
+            let task_output = Self::execute_task(future, pool_type).await;
+            if tx.send(task_output).is_err() {
+                log::warn!("Spawned task output ignored: receiver dropped");
+            }
+        });
+        rx.await.expect("Spawned task transmitter dropped")
+    }
+
+    /// Blocks current thread to compute future. 
Can not be called in tokio runtime context + /// + pub fn block_on_current_thread(&self, future: F) -> F::Output { + self.runtime.block_on(future) + } + + pub fn spawn(&self, future: F) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + self.runtime.spawn(future) + } +} + +fn init_compute_runtime() -> RuntimeRef { + std::thread::spawn(move || { + let mut builder = tokio::runtime::Builder::new_multi_thread(); + builder + .worker_threads(*COMPUTE_RUNTIME_NUM_WORKER_THREADS) + .enable_all() + .thread_name_fn(move || { + static COMPUTE_THREAD_ATOMIC_ID: AtomicUsize = AtomicUsize::new(0); + let id = COMPUTE_THREAD_ATOMIC_ID.fetch_add(1, Ordering::SeqCst); + format!("Compute-Thread-{}", id) + }) + .max_blocking_threads(*COMPUTE_RUNTIME_MAX_BLOCKING_THREADS); + Runtime::new(builder.build().unwrap(), PoolType::Compute) + }) + .join() + .unwrap() +} + +fn init_io_runtime(multi_thread: bool) -> RuntimeRef { + std::thread::spawn(move || { + let mut builder = tokio::runtime::Builder::new_multi_thread(); + builder + .worker_threads(if multi_thread { + *THREADED_IO_RUNTIME_NUM_WORKER_THREADS + } else { + 1 + }) + .enable_all() + .thread_name_fn(move || { + static COMPUTE_THREAD_ATOMIC_ID: AtomicUsize = AtomicUsize::new(0); + let id = COMPUTE_THREAD_ATOMIC_ID.fetch_add(1, Ordering::SeqCst); + format!("IO-Thread-{}", id) + }); + Runtime::new(builder.build().unwrap(), PoolType::IO) + }) + .join() + .unwrap() +} + +pub fn get_compute_runtime() -> RuntimeRef { + COMPUTE_RUNTIME.get_or_init(init_compute_runtime).clone() +} + +pub fn get_io_runtime(multi_thread: bool) -> RuntimeRef { + if !multi_thread { + SINGLE_THREADED_IO_RUNTIME + .get_or_init(|| init_io_runtime(false)) + .clone() + } else { + THREADED_IO_RUNTIME + .get_or_init(|| init_io_runtime(true)) + .clone() + } +} + +#[must_use] +pub fn get_io_pool_num_threads() -> Option { + match tokio::runtime::Handle::try_current() { + Ok(handle) => { + match handle.runtime_flavor() { + RuntimeFlavor::CurrentThread => Some(1), + RuntimeFlavor::MultiThread => Some(*THREADED_IO_RUNTIME_NUM_WORKER_THREADS), + // RuntimeFlavor is #non_exhaustive, so we default to 1 here to be conservative + _ => Some(1), + } + } + Err(_) => None, + } +} diff --git a/src/daft-csv/Cargo.toml b/src/daft-csv/Cargo.toml index dde511422b..f54b59521c 100644 --- a/src/daft-csv/Cargo.toml +++ b/src/daft-csv/Cargo.toml @@ -4,6 +4,7 @@ async-compat = {workspace = true} async-stream = {workspace = true} common-error = {path = "../common/error", default-features = false} common-py-serde = {path = "../common/py-serde", default-features = false} +common-runtime = {path = "../common/runtime", default-features = false} csv-async = "1.3.0" daft-compression = {path = "../daft-compression", default-features = false} daft-core = {path = "../daft-core", default-features = false} diff --git a/src/daft-csv/src/metadata.rs b/src/daft-csv/src/metadata.rs index 14d5d472ab..17ce4c0267 100644 --- a/src/daft-csv/src/metadata.rs +++ b/src/daft-csv/src/metadata.rs @@ -3,11 +3,12 @@ use std::{collections::HashSet, sync::Arc}; use arrow2::io::csv::read_async::{AsyncReader, AsyncReaderBuilder}; use async_compat::CompatExt; use common_error::DaftResult; +use common_runtime::get_io_runtime; use csv_async::ByteRecord; use daft_compression::CompressionCodec; use daft_core::prelude::Schema; use daft_decoding::inference::infer; -use daft_io::{get_runtime, GetResult, IOClient, IOStatsRef}; +use daft_io::{GetResult, IOClient, IOStatsRef}; use futures::{StreamExt, TryStreamExt}; use 
snafu::ResultExt; use tokio::{ @@ -58,7 +59,7 @@ pub fn read_csv_schema( io_client: Arc, io_stats: Option, ) -> DaftResult<(Schema, CsvReadStats)> { - let runtime_handle = get_runtime(true)?; + let runtime_handle = get_io_runtime(true); runtime_handle.block_on_current_thread(async { read_csv_schema_single( uri, @@ -80,7 +81,7 @@ pub async fn read_csv_schema_bulk( io_stats: Option, num_parallel_tasks: usize, ) -> DaftResult> { - let runtime_handle = get_runtime(true)?; + let runtime_handle = get_io_runtime(true); let result = runtime_handle .block_on_current_thread(async { let task_stream = futures::stream::iter(uris.iter().map(|uri| { diff --git a/src/daft-csv/src/read.rs b/src/daft-csv/src/read.rs index ce5cb556e4..0754099c21 100644 --- a/src/daft-csv/src/read.rs +++ b/src/daft-csv/src/read.rs @@ -6,12 +6,13 @@ use arrow2::{ }; use async_compat::{Compat, CompatExt}; use common_error::{DaftError, DaftResult}; +use common_runtime::get_io_runtime; use csv_async::AsyncReader; use daft_compression::CompressionCodec; use daft_core::{prelude::*, utils::arrow::cast_array_for_daft_if_needed}; use daft_decoding::deserialize::deserialize_column; use daft_dsl::optimization::get_required_columns; -use daft_io::{get_runtime, GetResult, IOClient, IOStatsRef}; +use daft_io::{GetResult, IOClient, IOStatsRef}; use daft_table::Table; use futures::{stream::BoxStream, Stream, StreamExt, TryStreamExt}; use rayon::{ @@ -53,7 +54,7 @@ pub fn read_csv( multithreaded_io: bool, max_chunks_in_flight: Option, ) -> DaftResult { - let runtime_handle = get_runtime(multithreaded_io)?; + let runtime_handle = get_io_runtime(multithreaded_io); runtime_handle.block_on_current_thread(async { read_csv_single_into_table( uri, @@ -80,7 +81,7 @@ pub fn read_csv_bulk( max_chunks_in_flight: Option, num_parallel_tasks: usize, ) -> DaftResult> { - let runtime_handle = get_runtime(multithreaded_io)?; + let runtime_handle = get_io_runtime(multithreaded_io); let tables = runtime_handle.block_on_current_thread(async move { // Launch a read task per URI, throttling the number of concurrent file reads to num_parallel tasks. 
let task_stream = futures::stream::iter(uris.iter().map(|uri| { diff --git a/src/daft-functions/Cargo.toml b/src/daft-functions/Cargo.toml index febb241e13..d8452d3dbe 100644 --- a/src/daft-functions/Cargo.toml +++ b/src/daft-functions/Cargo.toml @@ -4,6 +4,7 @@ base64 = {workspace = true} common-error = {path = "../common/error", default-features = false} common-hashable-float-wrapper = {path = "../common/hashable-float-wrapper"} common-io-config = {path = "../common/io-config", default-features = false} +common-runtime = {path = "../common/runtime", default-features = false} daft-core = {path = "../daft-core", default-features = false} daft-dsl = {path = "../daft-dsl", default-features = false} daft-image = {path = "../daft-image", default-features = false} diff --git a/src/daft-functions/src/tokenize/bpe.rs b/src/daft-functions/src/tokenize/bpe.rs index 98c826f498..9423cf8c72 100644 --- a/src/daft-functions/src/tokenize/bpe.rs +++ b/src/daft-functions/src/tokenize/bpe.rs @@ -8,7 +8,8 @@ use std::{ use base64::{engine::general_purpose, DecodeError, Engine}; use common_error::{DaftError, DaftResult}; -use daft_io::{get_io_client, get_runtime, IOConfig}; +use common_runtime::get_io_runtime; +use daft_io::{get_io_client, IOConfig}; use snafu::{prelude::*, Snafu}; use tiktoken_rs::CoreBPE; @@ -158,12 +159,11 @@ fn get_file_bpe( ) -> DaftResult { // Fetch the token file as a string let client = get_io_client(false, io_config)?; - let runtime = get_runtime(false)?; + let runtime = get_io_runtime(false); let path = path.to_string(); - let file_bytes = runtime.block_on_io_pool(async move { - client.single_url_get(path, None, None).await?.bytes().await - })??; + let file_bytes = runtime + .block_on(async move { client.single_url_get(path, None, None).await?.bytes().await })??; let file_str = std::str::from_utf8(&file_bytes).with_context(|_| InvalidUtf8SequenceSnafu)?; let tokens_res = parse_tokens(file_str)?; diff --git a/src/daft-functions/src/uri/download.rs b/src/daft-functions/src/uri/download.rs index 15ebc2f9fc..24d3f89d33 100644 --- a/src/daft-functions/src/uri/download.rs +++ b/src/daft-functions/src/uri/download.rs @@ -1,9 +1,10 @@ use std::sync::Arc; use common_error::{DaftError, DaftResult}; +use common_runtime::get_io_runtime; use daft_core::prelude::*; use daft_dsl::{functions::ScalarUDF, ExprRef}; -use daft_io::{get_io_client, get_runtime, Error, IOConfig, IOStatsContext, IOStatsRef}; +use daft_io::{get_io_client, Error, IOConfig, IOStatsContext, IOStatsRef}; use futures::{StreamExt, TryStreamExt}; use serde::Serialize; use snafu::prelude::*; @@ -98,7 +99,7 @@ fn url_download( } ); - let runtime_handle = get_runtime(true)?; + let runtime_handle = get_io_runtime(true); let max_connections = match multi_thread { false => max_connections, true => max_connections * usize::from(std::thread::available_parallelism()?), @@ -139,7 +140,7 @@ fn url_download( stream.try_collect::>().await }; - let mut results = runtime_handle.block_on_io_pool(fetches)??; + let mut results = runtime_handle.block_on(fetches)??; results.sort_by_key(|k| k.0); let mut offsets: Vec = Vec::with_capacity(results.len() + 1); diff --git a/src/daft-functions/src/uri/upload.rs b/src/daft-functions/src/uri/upload.rs index 4ab677614c..1ad91b888b 100644 --- a/src/daft-functions/src/uri/upload.rs +++ b/src/daft-functions/src/uri/upload.rs @@ -1,9 +1,10 @@ use std::sync::Arc; use common_error::{DaftError, DaftResult}; +use common_runtime::get_io_runtime; use daft_core::prelude::*; use daft_dsl::{functions::ScalarUDF, ExprRef}; 
-use daft_io::{get_io_client, get_runtime, IOConfig, IOStatsRef, SourceType}; +use daft_io::{get_io_client, IOConfig, IOStatsRef, SourceType}; use futures::{StreamExt, TryStreamExt}; use serde::Serialize; @@ -109,7 +110,7 @@ pub fn url_upload( })?; } - let runtime_handle = get_runtime(multi_thread)?; + let runtime_handle = get_io_runtime(multi_thread); let max_connections = match multi_thread { false => max_connections, true => max_connections * usize::from(std::thread::available_parallelism()?), @@ -143,7 +144,7 @@ pub fn url_upload( .await }; - let mut results = runtime_handle.block_on_io_pool(uploads)??; + let mut results = runtime_handle.block_on(uploads)??; results.sort_by_key(|k| k.0); Ok(results.into_iter().map(|(_, path)| path).collect()) diff --git a/src/daft-io/Cargo.toml b/src/daft-io/Cargo.toml index 486a6810ae..88f7ca85d0 100644 --- a/src/daft-io/Cargo.toml +++ b/src/daft-io/Cargo.toml @@ -16,6 +16,7 @@ bytes = {workspace = true} common-error = {path = "../common/error", default-features = false} common-file-formats = {path = "../common/file-formats", default-features = false} common-io-config = {path = "../common/io-config", default-features = false} +common-runtime = {path = "../common/runtime", default-features = false} futures = {workspace = true} globset = "0.4" google-cloud-storage = {version = "0.15.0", default-features = false, features = ["default-tls", "auth"]} @@ -26,7 +27,6 @@ hyper-tls = "0.5.0" itertools = {workspace = true} lazy_static = {workspace = true} log = {workspace = true} -oneshot = "0.1.8" openssl-sys = {version = "0.9.102", features = ["vendored"]} pyo3 = {workspace = true, optional = true} rand = "0.8.5" diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs index 745fc4065c..63d1d08485 100644 --- a/src/daft-io/src/lib.rs +++ b/src/daft-io/src/lib.rs @@ -14,22 +14,13 @@ mod stats; mod stream_utils; use azure_blob::AzureBlobSource; use common_file_formats::FileFormat; -use futures::FutureExt; use google_cloud::GCSSource; use huggingface::HFSource; use lazy_static::lazy_static; #[cfg(feature = "python")] pub mod python; -use std::{ - borrow::Cow, - collections::HashMap, - future::Future, - hash::Hash, - ops::Range, - panic::AssertUnwindSafe, - sync::{Arc, OnceLock}, -}; +use std::{borrow::Cow, collections::HashMap, hash::Hash, ops::Range, sync::Arc}; use common_error::{DaftError, DaftResult}; pub use common_io_config::{AzureConfig, IOConfig, S3Config}; @@ -41,7 +32,6 @@ pub use python::register_modules; use s3_like::S3LikeSource; use snafu::{prelude::*, Snafu}; pub use stats::{IOStatsContext, IOStatsRef}; -use tokio::{runtime::RuntimeFlavor, task::JoinHandle}; use url::ParseError; use self::{http::HttpSource, local::LocalSource, object_io::ObjectSource}; @@ -429,12 +419,7 @@ pub fn parse_url(input: &str) -> Result<(SourceType, Cow<'_, str>)> { } type CacheKey = (bool, Arc); -static THREADED_RUNTIME: OnceLock = OnceLock::new(); -static SINGLE_THREADED_RUNTIME: OnceLock = OnceLock::new(); - lazy_static! 
{ - static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get(); - static ref THREADED_RUNTIME_NUM_WORKER_THREADS: usize = 8.min(*NUM_CPUS); static ref CLIENT_CACHE: std::sync::RwLock>> = std::sync::RwLock::new(HashMap::new()); } @@ -458,102 +443,4 @@ pub fn get_io_client(multi_thread: bool, config: Arc) -> DaftResult; - -pub struct Runtime { - runtime: tokio::runtime::Runtime, -} - -impl Runtime { - fn new(runtime: tokio::runtime::Runtime) -> RuntimeRef { - Arc::new(Self { runtime }) - } - - /// Similar to tokio's Runtime::block_on but requires static lifetime + Send - /// You should use this when you are spawning IO tasks from an Expression Evaluator or in the Executor - pub fn block_on_io_pool(&self, future: F) -> DaftResult - where - F: Future + Send + 'static, - F::Output: Send + 'static, - { - let (tx, rx) = oneshot::channel(); - let _join_handle = self.spawn(async move { - let task_output = AssertUnwindSafe(future).catch_unwind().await.map_err(|e| { - let s = if let Some(s) = e.downcast_ref::() { - s.clone() - } else if let Some(s) = e.downcast_ref::<&str>() { - (*s).to_string() - } else { - "unknown internal error".to_string() - }; - DaftError::ComputeError(format!( - "Caught panic when spawning blocking task in io pool {s})" - )) - }); - - if tx.send(task_output).is_err() { - log::warn!("Spawned task output ignored: receiver dropped"); - } - }); - rx.recv().expect("Spawned task transmitter dropped") - } - - /// Blocks current thread to compute future. Can not be called in tokio runtime context - /// - pub fn block_on_current_thread(&self, future: F) -> F::Output { - self.runtime.block_on(future) - } - - pub fn spawn(&self, future: F) -> JoinHandle - where - F: Future + Send + 'static, - F::Output: Send + 'static, - { - self.runtime.spawn(future) - } -} - -fn init_runtime(num_threads: usize) -> Arc { - std::thread::spawn(move || { - Runtime::new( - tokio::runtime::Builder::new_multi_thread() - .worker_threads(num_threads) - .enable_all() - .build() - .unwrap(), - ) - }) - .join() - .unwrap() -} - -pub fn get_runtime(multi_thread: bool) -> DaftResult { - if !multi_thread { - let runtime = SINGLE_THREADED_RUNTIME - .get_or_init(|| init_runtime(1)) - .clone(); - Ok(runtime) - } else { - let runtime = THREADED_RUNTIME - .get_or_init(|| init_runtime(*THREADED_RUNTIME_NUM_WORKER_THREADS)) - .clone(); - Ok(runtime) - } -} - -#[must_use] -pub fn get_io_pool_num_threads() -> Option { - match tokio::runtime::Handle::try_current() { - Ok(handle) => { - match handle.runtime_flavor() { - RuntimeFlavor::CurrentThread => Some(1), - RuntimeFlavor::MultiThread => Some(*THREADED_RUNTIME_NUM_WORKER_THREADS), - // RuntimeFlavor is #non_exhaustive, so we default to 1 here to be conservative - _ => Some(1), - } - } - Err(_) => None, - } -} - type DynError = Box; diff --git a/src/daft-io/src/python.rs b/src/daft-io/src/python.rs index 6dac52af8a..bd4d79243a 100644 --- a/src/daft-io/src/python.rs +++ b/src/daft-io/src/python.rs @@ -3,10 +3,11 @@ pub use py::register_modules; mod py { use common_error::DaftResult; + use common_runtime::get_io_runtime; use futures::TryStreamExt; use pyo3::{prelude::*, types::PyDict}; - use crate::{get_io_client, get_runtime, parse_url, s3_like, stats::IOStatsContext}; + use crate::{get_io_client, parse_url, s3_like, stats::IOStatsContext}; #[pyfunction] fn io_glob( @@ -28,7 +29,7 @@ mod py { io_config.unwrap_or_default().config.into(), )?; let (scheme, path) = parse_url(&path)?; - let runtime_handle = get_runtime(multithreaded_io)?; + let runtime_handle 
= get_io_runtime(multithreaded_io); runtime_handle.block_on_current_thread(async move { let source = io_client.get_source(&scheme).await?; @@ -63,7 +64,7 @@ mod py { #[pyfunction] fn s3_config_from_env(py: Python) -> PyResult { let s3_config: DaftResult = py.allow_threads(|| { - let runtime = get_runtime(false)?; + let runtime = get_io_runtime(false); runtime.block_on_current_thread(async { Ok(s3_like::s3_config_from_env().await?) }) }); Ok(common_io_config::python::S3Config { config: s3_config? }) diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index 1604bf0aff..a3d8e9df51 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -17,6 +17,7 @@ use aws_sdk_s3::{ use aws_sig_auth::signer::SigningRequirements; use aws_smithy_async::rt::sleep::TokioSleep; use common_io_config::S3Config; +use common_runtime::get_io_pool_num_threads; use futures::{stream::BoxStream, StreamExt, TryStreamExt}; use reqwest::StatusCode; use s3::{ @@ -34,7 +35,6 @@ use url::{ParseError, Position}; use super::object_io::{GetResult, ObjectSource}; use crate::{ - get_io_pool_num_threads, object_io::{FileMetadata, FileType, LSResult}, stats::IOStatsRef, stream_utils::io_stats_on_bytestream, diff --git a/src/daft-json/Cargo.toml b/src/daft-json/Cargo.toml index 1cf8308a7a..ce1120fbc1 100644 --- a/src/daft-json/Cargo.toml +++ b/src/daft-json/Cargo.toml @@ -3,6 +3,7 @@ arrow2 = {workspace = true} chrono = {workspace = true} common-error = {path = "../common/error", default-features = false} common-py-serde = {path = "../common/py-serde", default-features = false} +common-runtime = {path = "../common/runtime", default-features = false} daft-compression = {path = "../daft-compression", default-features = false} daft-core = {path = "../daft-core", default-features = false} daft-decoding = {path = "../daft-decoding"} diff --git a/src/daft-json/src/read.rs b/src/daft-json/src/read.rs index ba9933a46b..9503c77152 100644 --- a/src/daft-json/src/read.rs +++ b/src/daft-json/src/read.rs @@ -1,10 +1,11 @@ use std::{collections::HashMap, num::NonZeroUsize, sync::Arc}; use common_error::{DaftError, DaftResult}; +use common_runtime::get_io_runtime; use daft_compression::CompressionCodec; use daft_core::{prelude::*, utils::arrow::cast_array_for_daft_if_needed}; use daft_dsl::optimization::get_required_columns; -use daft_io::{get_runtime, parse_url, GetResult, IOClient, IOStatsRef, SourceType}; +use daft_io::{parse_url, GetResult, IOClient, IOStatsRef, SourceType}; use daft_table::Table; use futures::{stream::BoxStream, Stream, StreamExt, TryStreamExt}; use rayon::prelude::{IntoParallelIterator, ParallelIterator}; @@ -46,7 +47,7 @@ pub fn read_json( multithreaded_io: bool, max_chunks_in_flight: Option, ) -> DaftResult
{ - let runtime_handle = get_runtime(multithreaded_io)?; + let runtime_handle = get_io_runtime(multithreaded_io); runtime_handle.block_on_current_thread(async { read_json_single_into_table( uri, @@ -73,7 +74,7 @@ pub fn read_json_bulk( max_chunks_in_flight: Option, num_parallel_tasks: usize, ) -> DaftResult> { - let runtime_handle = get_runtime(multithreaded_io)?; + let runtime_handle = get_io_runtime(multithreaded_io); let tables = runtime_handle.block_on_current_thread(async move { // Launch a read task per URI, throttling the number of concurrent file reads to num_parallel tasks. let task_stream = futures::stream::iter(uris.iter().map(|uri| { diff --git a/src/daft-json/src/schema.rs b/src/daft-json/src/schema.rs index e867c513c5..e9af632c4e 100644 --- a/src/daft-json/src/schema.rs +++ b/src/daft-json/src/schema.rs @@ -1,9 +1,10 @@ use std::{collections::HashSet, sync::Arc}; use common_error::DaftResult; +use common_runtime::get_io_runtime; use daft_compression::CompressionCodec; use daft_core::prelude::Schema; -use daft_io::{get_runtime, GetResult, IOClient, IOStatsRef}; +use daft_io::{GetResult, IOClient, IOStatsRef}; use futures::{StreamExt, TryStreamExt}; use indexmap::IndexMap; use snafu::ResultExt; @@ -55,7 +56,7 @@ pub fn read_json_schema( io_client: Arc, io_stats: Option, ) -> DaftResult { - let runtime_handle = get_runtime(true)?; + let runtime_handle = get_io_runtime(true); runtime_handle.block_on_current_thread(async { read_json_schema_single( uri, @@ -77,7 +78,7 @@ pub async fn read_json_schema_bulk( io_stats: Option, num_parallel_tasks: usize, ) -> DaftResult> { - let runtime_handle = get_runtime(true)?; + let runtime_handle = get_io_runtime(true); let result = runtime_handle .block_on_current_thread(async { let task_stream = futures::stream::iter(uris.iter().map(|uri| { diff --git a/src/daft-local-execution/Cargo.toml b/src/daft-local-execution/Cargo.toml index 932462215d..8da3b93325 100644 --- a/src/daft-local-execution/Cargo.toml +++ b/src/daft-local-execution/Cargo.toml @@ -3,6 +3,7 @@ common-daft-config = {path = "../common/daft-config", default-features = false} common-display = {path = "../common/display", default-features = false} common-error = {path = "../common/error", default-features = false} common-file-formats = {path = "../common/file-formats", default-features = false} +common-runtime = {path = "../common/runtime", default-features = false} common-tracing = {path = "../common/tracing", default-features = false} daft-core = {path = "../daft-core", default-features = false} daft-csv = {path = "../daft-csv", default-features = false} diff --git a/src/daft-local-execution/src/intermediate_ops/aggregate.rs b/src/daft-local-execution/src/intermediate_ops/aggregate.rs index 13f93cb818..4b8fa7bbb6 100644 --- a/src/daft-local-execution/src/intermediate_ops/aggregate.rs +++ b/src/daft-local-execution/src/intermediate_ops/aggregate.rs @@ -29,7 +29,7 @@ impl IntermediateOperator for AggregateOperator { &self, _idx: usize, input: &PipelineResultType, - _state: Option<&mut Box>, + _state: &IntermediateOperatorState, ) -> DaftResult { let out = input.as_data().agg(&self.agg_exprs, &self.group_by)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( diff --git a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs index bdcebecab6..45c71d75df 100644 --- a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs +++ 
b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs @@ -9,7 +9,8 @@ use daft_table::{GrowableTable, Probeable}; use tracing::{info_span, instrument}; use super::intermediate_op::{ - IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, + DynIntermediateOpState, IntermediateOperator, IntermediateOperatorResult, + IntermediateOperatorState, }; use crate::pipeline::PipelineResultType; @@ -36,7 +37,7 @@ impl AntiSemiProbeState { } } -impl IntermediateOperatorState for AntiSemiProbeState { +impl DynIntermediateOpState for AntiSemiProbeState { fn as_any_mut(&mut self) -> &mut dyn std::any::Any { self } @@ -108,34 +109,30 @@ impl IntermediateOperator for AntiSemiProbeOperator { &self, idx: usize, input: &PipelineResultType, - state: Option<&mut Box>, + state: &IntermediateOperatorState, ) -> DaftResult { - let state = state - .expect("AntiSemiProbeOperator should have state") - .as_any_mut() - .downcast_mut::() - .expect("AntiSemiProbeOperator state should be AntiSemiProbeState"); - - if idx == 0 { - let probe_state = input.as_probe_state(); - state.set_table(probe_state.get_probeable()); - Ok(IntermediateOperatorResult::NeedMoreInput(None)) - } else { - let input = input.as_data(); - if input.is_empty() { - let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); - return Ok(IntermediateOperatorResult::NeedMoreInput(Some(empty))); + state.with_state_mut::(|state| { + if idx == 0 { + let probe_state = input.as_probe_state(); + state.set_table(probe_state.get_probeable()); + Ok(IntermediateOperatorResult::NeedMoreInput(None)) + } else { + let input = input.as_data(); + if input.is_empty() { + let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); + return Ok(IntermediateOperatorResult::NeedMoreInput(Some(empty))); + } + let out = self.probe_anti_semi(input, state)?; + Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) } - let out = self.probe_anti_semi(input, state)?; - Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) - } + }) } fn name(&self) -> &'static str { "AntiSemiProbeOperator" } - fn make_state(&self) -> Option> { - Some(Box::new(AntiSemiProbeState::Building)) + fn make_state(&self) -> Box { + Box::new(AntiSemiProbeState::Building) } } diff --git a/src/daft-local-execution/src/intermediate_ops/explode.rs b/src/daft-local-execution/src/intermediate_ops/explode.rs index 774be696a8..30ec8b02b5 100644 --- a/src/daft-local-execution/src/intermediate_ops/explode.rs +++ b/src/daft-local-execution/src/intermediate_ops/explode.rs @@ -28,7 +28,7 @@ impl IntermediateOperator for ExplodeOperator { &self, _idx: usize, input: &PipelineResultType, - _state: Option<&mut Box>, + _state: &IntermediateOperatorState, ) -> DaftResult { let out = input.as_data().explode(&self.to_explode)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( diff --git a/src/daft-local-execution/src/intermediate_ops/filter.rs b/src/daft-local-execution/src/intermediate_ops/filter.rs index da8dbcd19c..aad3bd7e7d 100644 --- a/src/daft-local-execution/src/intermediate_ops/filter.rs +++ b/src/daft-local-execution/src/intermediate_ops/filter.rs @@ -25,7 +25,7 @@ impl IntermediateOperator for FilterOperator { &self, _idx: usize, input: &PipelineResultType, - _state: Option<&mut Box>, + _state: &IntermediateOperatorState, ) -> DaftResult { let out = input.as_data().filter(&[self.predicate.clone()])?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( diff --git 
a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs index a208efea6c..c257b1e618 100644 --- a/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/inner_hash_join_probe.rs @@ -9,7 +9,8 @@ use indexmap::IndexSet; use tracing::{info_span, instrument}; use super::intermediate_op::{ - IntermediateOperator, IntermediateOperatorResult, IntermediateOperatorState, + DynIntermediateOpState, IntermediateOperator, IntermediateOperatorResult, + IntermediateOperatorState, }; use crate::pipeline::PipelineResultType; @@ -36,7 +37,7 @@ impl InnerHashJoinProbeState { } } -impl IntermediateOperatorState for InnerHashJoinProbeState { +impl DynIntermediateOpState for InnerHashJoinProbeState { fn as_any_mut(&mut self) -> &mut dyn std::any::Any { self } @@ -164,14 +165,9 @@ impl IntermediateOperator for InnerHashJoinProbeOperator { &self, idx: usize, input: &PipelineResultType, - state: Option<&mut Box>, + state: &IntermediateOperatorState, ) -> DaftResult { - let state = state - .expect("InnerHashJoinProbeOperator should have state") - .as_any_mut() - .downcast_mut::() - .expect("InnerHashJoinProbeOperator state should be InnerHashJoinProbeState"); - match idx { + state.with_state_mut::(|state| match idx { 0 => { let probe_state = input.as_probe_state(); state.set_probe_state(probe_state.clone()); @@ -186,14 +182,14 @@ impl IntermediateOperator for InnerHashJoinProbeOperator { let out = self.probe_inner(input, state)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) } - } + }) } fn name(&self) -> &'static str { "InnerHashJoinProbeOperator" } - fn make_state(&self) -> Option> { - Some(Box::new(InnerHashJoinProbeState::Building)) + fn make_state(&self) -> Box { + Box::new(InnerHashJoinProbeState::Building) } } diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index 7b0267c7c0..412d7641a7 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -1,7 +1,8 @@ -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use common_display::tree::TreeDisplay; use common_error::DaftResult; +use common_runtime::get_compute_runtime; use daft_micropartition::MicroPartition; use tracing::{info_span, instrument}; @@ -13,10 +14,41 @@ use crate::{ ExecutionRuntimeHandle, NUM_CPUS, }; -pub trait IntermediateOperatorState: Send + Sync { +pub(crate) trait DynIntermediateOpState: Send + Sync { fn as_any_mut(&mut self) -> &mut dyn std::any::Any; } +struct DefaultIntermediateOperatorState {} +impl DynIntermediateOpState for DefaultIntermediateOperatorState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} + +pub(crate) struct IntermediateOperatorState { + inner: Mutex>, +} + +impl IntermediateOperatorState { + fn new(inner: Box) -> Arc { + Arc::new(Self { + inner: Mutex::new(inner), + }) + } + + pub(crate) fn with_state_mut(&self, f: F) -> R + where + F: FnOnce(&mut T) -> R, + { + let mut guard = self.inner.lock().unwrap(); + let state = guard + .as_any_mut() + .downcast_mut::() + .expect("State type mismatch"); + f(state) + } +} + pub enum IntermediateOperatorResult { NeedMoreInput(Option>), #[allow(dead_code)] @@ -28,11 +60,11 @@ pub trait IntermediateOperator: Send + Sync { &self, idx: usize, input: &PipelineResultType, - state: Option<&mut Box>, + state: 
&IntermediateOperatorState, ) -> DaftResult; fn name(&self) -> &'static str; - fn make_state(&self) -> Option> { - None + fn make_state(&self) -> Box { + Box::new(DefaultIntermediateOperatorState {}) } } @@ -75,11 +107,19 @@ impl IntermediateNode { rt_context: Arc, ) -> DaftResult<()> { let span = info_span!("IntermediateOp::execute"); - let mut state = op.make_state(); + let compute_runtime = get_compute_runtime(); + let state_wrapper = IntermediateOperatorState::new(op.make_state()); while let Some((idx, morsel)) = receiver.recv().await { loop { - let result = - rt_context.in_span(&span, || op.execute(idx, &morsel, state.as_mut()))?; + let op = op.clone(); + let morsel = morsel.clone(); + let span = span.clone(); + let rt_context = rt_context.clone(); + let state_wrapper = state_wrapper.clone(); + let fut = async move { + rt_context.in_span(&span, || op.execute(idx, &morsel, &state_wrapper)) + }; + let result = compute_runtime.await_on(fut).await??; match result { IntermediateOperatorResult::NeedMoreInput(Some(mp)) => { let _ = sender.send(mp.into()).await; diff --git a/src/daft-local-execution/src/intermediate_ops/pivot.rs b/src/daft-local-execution/src/intermediate_ops/pivot.rs index d942053dd9..afac5f9b02 100644 --- a/src/daft-local-execution/src/intermediate_ops/pivot.rs +++ b/src/daft-local-execution/src/intermediate_ops/pivot.rs @@ -38,7 +38,7 @@ impl IntermediateOperator for PivotOperator { &self, _idx: usize, input: &PipelineResultType, - _state: Option<&mut Box>, + _state: &IntermediateOperatorState, ) -> DaftResult { let out = input.as_data().pivot( &self.group_by, diff --git a/src/daft-local-execution/src/intermediate_ops/project.rs b/src/daft-local-execution/src/intermediate_ops/project.rs index abd37b461f..370de989aa 100644 --- a/src/daft-local-execution/src/intermediate_ops/project.rs +++ b/src/daft-local-execution/src/intermediate_ops/project.rs @@ -25,7 +25,7 @@ impl IntermediateOperator for ProjectOperator { &self, _idx: usize, input: &PipelineResultType, - _state: Option<&mut Box>, + _state: &IntermediateOperatorState, ) -> DaftResult { let out = input.as_data().eval_expression_list(&self.projection)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( diff --git a/src/daft-local-execution/src/intermediate_ops/sample.rs b/src/daft-local-execution/src/intermediate_ops/sample.rs index d74224dc3c..b0e4610292 100644 --- a/src/daft-local-execution/src/intermediate_ops/sample.rs +++ b/src/daft-local-execution/src/intermediate_ops/sample.rs @@ -30,7 +30,7 @@ impl IntermediateOperator for SampleOperator { &self, _idx: usize, input: &PipelineResultType, - _state: Option<&mut Box>, + _state: &IntermediateOperatorState, ) -> DaftResult { let out = input diff --git a/src/daft-local-execution/src/intermediate_ops/unpivot.rs b/src/daft-local-execution/src/intermediate_ops/unpivot.rs index 746d0563c8..5171f9ad42 100644 --- a/src/daft-local-execution/src/intermediate_ops/unpivot.rs +++ b/src/daft-local-execution/src/intermediate_ops/unpivot.rs @@ -38,7 +38,7 @@ impl IntermediateOperator for UnpivotOperator { &self, _idx: usize, input: &PipelineResultType, - _state: Option<&mut Box>, + _state: &IntermediateOperatorState, ) -> DaftResult { let out = input.as_data().unpivot( &self.ids, diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs index c316434a63..e9b7e08b96 100644 --- a/src/daft-local-execution/src/lib.rs +++ b/src/daft-local-execution/src/lib.rs @@ -6,17 +6,41 @@ mod run; mod runtime_stats; mod sinks; mod sources; + use 
common_error::{DaftError, DaftResult}; use lazy_static::lazy_static; pub use run::NativeExecutor; use snafu::{futures::TryFutureExt, Snafu}; + lazy_static! { pub static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get(); } -pub(crate) type TaskSet = tokio::task::JoinSet; -pub(crate) fn create_task_set() -> TaskSet { - tokio::task::JoinSet::new() +pub(crate) struct TaskSet { + inner: tokio::task::JoinSet, +} + +impl TaskSet { + fn new() -> Self { + Self { + inner: tokio::task::JoinSet::new(), + } + } + + fn spawn(&mut self, future: F) + where + F: std::future::Future + 'static, + { + self.inner.spawn_local(future); + } + + async fn join_next(&mut self) -> Option> { + self.inner.join_next().await + } + + async fn shutdown(&mut self) { + self.inner.shutdown().await; + } } pub struct ExecutionRuntimeHandle { @@ -28,13 +52,13 @@ impl ExecutionRuntimeHandle { #[must_use] pub fn new(default_morsel_size: usize) -> Self { Self { - worker_set: create_task_set(), + worker_set: TaskSet::new(), default_morsel_size, } } pub fn spawn( &mut self, - task: impl std::future::Future> + Send + 'static, + task: impl std::future::Future> + 'static, node_name: &str, ) { let node_name = node_name.to_string(); diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs index a89d06b2f8..ae0939ef8a 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -2,10 +2,7 @@ use std::{ collections::HashMap, fs::File, io::Write, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, + sync::Arc, time::{SystemTime, UNIX_EPOCH}, }; @@ -124,19 +121,14 @@ pub fn run_local( let mut pipeline = physical_plan_to_pipeline(physical_plan, &psets)?; let (tx, rx) = create_channel(results_buffer_size.unwrap_or(1)); let handle = std::thread::spawn(move || { - let runtime = tokio::runtime::Builder::new_multi_thread() + let runtime = tokio::runtime::Builder::new_current_thread() .enable_all() - .max_blocking_threads(10) - .thread_name_fn(|| { - static ATOMIC_ID: AtomicUsize = AtomicUsize::new(0); - let id = ATOMIC_ID.fetch_add(1, Ordering::SeqCst); - format!("Executor-Worker-{id}") - }) .build() .expect("Failed to create tokio runtime"); - runtime.block_on(async { + let execution_task = async { let mut runtime_handle = ExecutionRuntimeHandle::new(cfg.default_morsel_size); let mut receiver = pipeline.start(true, &mut runtime_handle)?.get_receiver(); + while let Some(val) = receiver.recv().await { let _ = tx.send(val.as_data().clone()).await; } @@ -164,6 +156,18 @@ pub fn run_local( writeln!(file, "```mermaid\n{}\n```", viz_pipeline(pipeline.as_ref()))?; } Ok(()) + }; + + let local_set = tokio::task::LocalSet::new(); + local_set.block_on(&runtime, async { + tokio::select! 
{ + biased; + _ = tokio::signal::ctrl_c() => { + log::info!("Received Ctrl-C, shutting down execution engine"); + Ok(()) + } + result = execution_task => result, + } }) }); diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index dc38e1df34..3fcbf8d660 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use common_display::tree::TreeDisplay; use common_error::DaftResult; +use common_runtime::get_compute_runtime; use daft_micropartition::MicroPartition; use tracing::info_span; @@ -92,19 +93,28 @@ impl PipelineNode for BlockingSinkNode { runtime_handle.spawn( async move { let span = info_span!("BlockingSinkNode::execute"); - let mut guard = op.lock().await; + let compute_runtime = get_compute_runtime(); while let Some(val) = child_results_receiver.recv().await { - if matches!( - rt_context.in_span(&span, || guard.sink(val.as_data()))?, - BlockingSinkStatus::Finished - ) { + let op = op.clone(); + let span = span.clone(); + let rt_context = rt_context.clone(); + let fut = async move { + let mut guard = op.lock().await; + rt_context.in_span(&span, || guard.sink(val.as_data())) + }; + let result = compute_runtime.await_on(fut).await??; + if matches!(result, BlockingSinkStatus::Finished) { break; } } - let finalized_result = rt_context - .in_span(&info_span!("BlockingSinkNode::finalize"), || { - guard.finalize() - })?; + let finalized_result = compute_runtime + .await_on(async move { + let mut guard = op.lock().await; + rt_context.in_span(&info_span!("BlockingSinkNode::finalize"), || { + guard.finalize() + }) + }) + .await??; if let Some(part) = finalized_result { let _ = destination_sender.send(part).await; } diff --git a/src/daft-local-execution/src/sinks/concat.rs b/src/daft-local-execution/src/sinks/concat.rs index 5b98cb84c6..3fc710c691 100644 --- a/src/daft-local-execution/src/sinks/concat.rs +++ b/src/daft-local-execution/src/sinks/concat.rs @@ -4,14 +4,16 @@ use common_error::{DaftError, DaftResult}; use daft_micropartition::MicroPartition; use tracing::instrument; -use super::streaming_sink::{StreamingSink, StreamingSinkOutput, StreamingSinkState}; +use super::streaming_sink::{ + DynStreamingSinkState, StreamingSink, StreamingSinkOutput, StreamingSinkState, +}; use crate::pipeline::PipelineResultType; struct ConcatSinkState { // The index of the last morsel of data that was received, which should be strictly non-decreasing. pub curr_idx: usize, } -impl StreamingSinkState for ConcatSinkState { +impl DynStreamingSinkState for ConcatSinkState { fn as_any_mut(&mut self) -> &mut dyn std::any::Any { self } @@ -28,22 +30,19 @@ impl StreamingSink for ConcatSink { &self, index: usize, input: &PipelineResultType, - state: &mut dyn StreamingSinkState, + state_handle: &StreamingSinkState, ) -> DaftResult { - let state = state - .as_any_mut() - .downcast_mut::() - .expect("ConcatSink should have ConcatSinkState"); - - // If the index is the same as the current index or one more than the current index, then we can accept the morsel. - if state.curr_idx == index || state.curr_idx + 1 == index { - state.curr_idx = index; - Ok(StreamingSinkOutput::NeedMoreInput(Some( - input.as_data().clone(), - ))) - } else { - Err(DaftError::ComputeError(format!("Concat sink received out-of-order data. 
Expected index to be {} or {}, but got {}.", state.curr_idx, state.curr_idx + 1, index))) - } + state_handle.with_state_mut::(|state| { + // If the index is the same as the current index or one more than the current index, then we can accept the morsel. + if state.curr_idx == index || state.curr_idx + 1 == index { + state.curr_idx = index; + Ok(StreamingSinkOutput::NeedMoreInput(Some( + input.as_data().clone(), + ))) + } else { + Err(DaftError::ComputeError(format!("Concat sink received out-of-order data. Expected index to be {} or {}, but got {}.", state.curr_idx, state.curr_idx + 1, index))) + } + }) } fn name(&self) -> &'static str { @@ -52,12 +51,12 @@ impl StreamingSink for ConcatSink { fn finalize( &self, - _states: Vec>, + _states: Vec>, ) -> DaftResult>> { Ok(None) } - fn make_state(&self) -> Box { + fn make_state(&self) -> Box { Box::new(ConcatSinkState { curr_idx: 0 }) } diff --git a/src/daft-local-execution/src/sinks/limit.rs b/src/daft-local-execution/src/sinks/limit.rs index 633c3511c1..ff24a703e2 100644 --- a/src/daft-local-execution/src/sinks/limit.rs +++ b/src/daft-local-execution/src/sinks/limit.rs @@ -4,7 +4,9 @@ use common_error::DaftResult; use daft_micropartition::MicroPartition; use tracing::instrument; -use super::streaming_sink::{StreamingSink, StreamingSinkOutput, StreamingSinkState}; +use super::streaming_sink::{ + DynStreamingSinkState, StreamingSink, StreamingSinkOutput, StreamingSinkState, +}; use crate::pipeline::PipelineResultType; struct LimitSinkState { @@ -21,7 +23,7 @@ impl LimitSinkState { } } -impl StreamingSinkState for LimitSinkState { +impl DynStreamingSinkState for LimitSinkState { fn as_any_mut(&mut self) -> &mut dyn std::any::Any { self } @@ -43,32 +45,31 @@ impl StreamingSink for LimitSink { &self, index: usize, input: &PipelineResultType, - state: &mut dyn StreamingSinkState, + state_handle: &StreamingSinkState, ) -> DaftResult { assert_eq!(index, 0); - let state = state - .as_any_mut() - .downcast_mut::() - .expect("Limit Sink should have LimitSinkState"); let input = input.as_data(); let input_num_rows = input.len(); - let remaining = state.get_remaining_mut(); - use std::cmp::Ordering::{Equal, Greater, Less}; - match input_num_rows.cmp(remaining) { - Less => { - *remaining -= input_num_rows; - Ok(StreamingSinkOutput::NeedMoreInput(Some(input.clone()))) - } - Equal => { - *remaining = 0; - Ok(StreamingSinkOutput::Finished(Some(input.clone()))) - } - Greater => { - let taken = input.head(*remaining)?; - *remaining = 0; - Ok(StreamingSinkOutput::Finished(Some(Arc::new(taken)))) + + state_handle.with_state_mut::(|state| { + let remaining = state.get_remaining_mut(); + use std::cmp::Ordering::{Equal, Greater, Less}; + match input_num_rows.cmp(remaining) { + Less => { + *remaining -= input_num_rows; + Ok(StreamingSinkOutput::NeedMoreInput(Some(input.clone()))) + } + Equal => { + *remaining = 0; + Ok(StreamingSinkOutput::Finished(Some(input.clone()))) + } + Greater => { + let taken = input.head(*remaining)?; + *remaining = 0; + Ok(StreamingSinkOutput::Finished(Some(Arc::new(taken)))) + } } - } + }) } fn name(&self) -> &'static str { @@ -77,12 +78,12 @@ impl StreamingSink for LimitSink { fn finalize( &self, - _states: Vec>, + _states: Vec>, ) -> DaftResult>> { Ok(None) } - fn make_state(&self) -> Box { + fn make_state(&self) -> Box { Box::new(LimitSinkState::new(self.limit)) } diff --git a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs index ab5ffa8cb0..23cefecf92 100644 --- 
a/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs +++ b/src/daft-local-execution/src/sinks/outer_hash_join_probe.rs @@ -15,7 +15,9 @@ use daft_table::{GrowableTable, ProbeState, Table}; use indexmap::IndexSet; use tracing::{info_span, instrument}; -use super::streaming_sink::{StreamingSink, StreamingSinkOutput, StreamingSinkState}; +use super::streaming_sink::{ + DynStreamingSinkState, StreamingSink, StreamingSinkOutput, StreamingSinkState, +}; use crate::pipeline::PipelineResultType; struct IndexBitmapBuilder { @@ -106,7 +108,7 @@ impl OuterHashJoinProbeState { } } -impl StreamingSinkState for OuterHashJoinProbeState { +impl DynStreamingSinkState for OuterHashJoinProbeState { fn as_any_mut(&mut self) -> &mut dyn std::any::Any { self } @@ -294,7 +296,7 @@ impl OuterHashJoinProbeSink { fn finalize_outer( &self, - mut states: Vec>, + mut states: Vec>, ) -> DaftResult>> { let states = states .iter_mut() @@ -363,24 +365,19 @@ impl StreamingSink for OuterHashJoinProbeSink { &self, idx: usize, input: &PipelineResultType, - state: &mut dyn StreamingSinkState, + state_handle: &StreamingSinkState, ) -> DaftResult { match idx { 0 => { - let state = state - .as_any_mut() - .downcast_mut::() - .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState"); - let probe_state = input.as_probe_state(); - state - .initialize_probe_state(probe_state.clone(), self.join_type == JoinType::Outer); + state_handle.with_state_mut::(|state| { + state.initialize_probe_state( + input.as_probe_state().clone(), + self.join_type == JoinType::Outer, + ); + }); Ok(StreamingSinkOutput::NeedMoreInput(None)) } - _ => { - let state = state - .as_any_mut() - .downcast_mut::() - .expect("OuterHashJoinProbeSink state should be OuterHashJoinProbeState"); + _ => state_handle.with_state_mut::(|state| { let input = input.as_data(); if input.is_empty() { let empty = Arc::new(MicroPartition::empty(Some(self.output_schema.clone()))); @@ -394,7 +391,7 @@ impl StreamingSink for OuterHashJoinProbeSink { ), }?; Ok(StreamingSinkOutput::NeedMoreInput(Some(out))) - } + }), } } @@ -402,13 +399,13 @@ impl StreamingSink for OuterHashJoinProbeSink { "OuterHashJoinProbeSink" } - fn make_state(&self) -> Box { + fn make_state(&self) -> Box { Box::new(OuterHashJoinProbeState::Building) } fn finalize( &self, - states: Vec>, + states: Vec>, ) -> DaftResult>> { if self.join_type == JoinType::Outer { self.finalize_outer(states) diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index 6e8a022cdb..102fd39618 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -1,23 +1,47 @@ -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use common_display::tree::TreeDisplay; use common_error::DaftResult; +use common_runtime::get_compute_runtime; use daft_micropartition::MicroPartition; use snafu::ResultExt; use tracing::{info_span, instrument}; use crate::{ channel::{create_channel, PipelineChannel, Receiver, Sender}, - create_task_set, pipeline::{PipelineNode, PipelineResultType}, runtime_stats::{CountingReceiver, RuntimeStatsContext}, ExecutionRuntimeHandle, JoinSnafu, TaskSet, NUM_CPUS, }; -pub trait StreamingSinkState: Send + Sync { +pub trait DynStreamingSinkState: Send + Sync { fn as_any_mut(&mut self) -> &mut dyn std::any::Any; } +pub(crate) struct StreamingSinkState { + inner: Mutex>, +} + +impl StreamingSinkState { + fn new(inner: Box) -> Arc { + Arc::new(Self { + inner: Mutex::new(inner), + 
}) + } + + pub(crate) fn with_state_mut(&self, f: F) -> R + where + F: FnOnce(&mut T) -> R, + { + let mut guard = self.inner.lock().unwrap(); + let state = guard + .as_any_mut() + .downcast_mut::() + .expect("State type mismatch"); + f(state) + } +} + pub enum StreamingSinkOutput { NeedMoreInput(Option>), #[allow(dead_code)] @@ -33,20 +57,20 @@ pub trait StreamingSink: Send + Sync { &self, index: usize, input: &PipelineResultType, - state: &mut dyn StreamingSinkState, + state_handle: &StreamingSinkState, ) -> DaftResult; /// Finalize the StreamingSink operator, with the given states from each worker. fn finalize( &self, - states: Vec>, + states: Vec>, ) -> DaftResult>>; /// The name of the StreamingSink operator. fn name(&self) -> &'static str; /// Create a new worker-local state for this StreamingSink. - fn make_state(&self) -> Box; + fn make_state(&self) -> Box; /// The maximum number of concurrent workers that can be spawned for this sink. /// Each worker will has its own StreamingSinkState. @@ -83,13 +107,25 @@ impl StreamingSinkNode { mut input_receiver: Receiver<(usize, PipelineResultType)>, output_sender: Sender>, rt_context: Arc, - ) -> DaftResult> { + ) -> DaftResult> { let span = info_span!("StreamingSink::Execute"); - let mut state = op.make_state(); + let compute_runtime = get_compute_runtime(); + let state_wrapper = StreamingSinkState::new(op.make_state()); + let mut finished = false; while let Some((idx, morsel)) = input_receiver.recv().await { + if finished { + break; + } loop { - let result = - rt_context.in_span(&span, || op.execute(idx, &morsel, state.as_mut()))?; + let op = op.clone(); + let morsel = morsel.clone(); + let span = span.clone(); + let rt_context = rt_context.clone(); + let state_wrapper = state_wrapper.clone(); + let fut = async move { + rt_context.in_span(&span, || op.execute(idx, &morsel, state_wrapper.as_ref())) + }; + let result = compute_runtime.await_on(fut).await??; match result { StreamingSinkOutput::NeedMoreInput(mp) => { if let Some(mp) = mp { @@ -104,18 +140,26 @@ impl StreamingSinkNode { if let Some(mp) = mp { let _ = output_sender.send(mp).await; } - return Ok(state); + finished = true; + break; } } } } - Ok(state) + + // Take the state out of the Arc and Mutex because we need to return it. + // It should be guaranteed that the ONLY holder of state at this point is this function. + Ok(Arc::into_inner(state_wrapper) + .expect("Completed worker should have exclusive access to state wrapper") + .inner + .into_inner() + .expect("Completed worker should have exclusive access to inner state")) } fn spawn_workers( op: Arc, input_receivers: Vec>, - task_set: &mut TaskSet>>, + task_set: &mut TaskSet>>, stats: Arc, ) -> Receiver> { let (output_sender, output_receiver) = create_channel(input_receivers.len()); @@ -217,7 +261,7 @@ impl PipelineNode for StreamingSinkNode { ); runtime_handle.spawn( async move { - let mut task_set = create_task_set(); + let mut task_set = TaskSet::new(); let mut output_receiver = Self::spawn_workers( op.clone(), input_receivers, @@ -235,8 +279,16 @@ impl PipelineNode for StreamingSinkNode { finished_states.push(state); } - if let Some(finalized_result) = op.finalize(finished_states)? 
{ - let _ = destination_sender.send(finalized_result.into()).await; + let compute_runtime = get_compute_runtime(); + let finalized_result = compute_runtime + .await_on(async move { + runtime_stats.in_span(&info_span!("StreamingSinkNode::finalize"), || { + op.finalize(finished_states) + }) + }) + .await??; + if let Some(res) = finalized_result { + let _ = destination_sender.send(res.into()).await; } Ok(()) }, diff --git a/src/daft-local-execution/src/sources/scan_task.rs b/src/daft-local-execution/src/sources/scan_task.rs index 17380251ee..3be2f61691 100644 --- a/src/daft-local-execution/src/sources/scan_task.rs +++ b/src/daft-local-execution/src/sources/scan_task.rs @@ -138,50 +138,52 @@ async fn get_delete_map( .storage_config .get_io_client_and_runtime()?; let scan_tasks = scan_tasks.to_vec(); - runtime.block_on_io_pool(async move { - let mut delete_map = scan_tasks - .iter() - .flat_map(|st| st.sources.iter().map(|s| s.get_path().to_string())) - .map(|path| (path, vec![])) - .collect::>(); - let columns_to_read = Some(vec!["file_path".to_string(), "pos".to_string()]); - let result = read_parquet_bulk_async( - delete_files.into_iter().collect(), - columns_to_read, - None, - None, - None, - None, - io_client, - None, - *NUM_CPUS, - ParquetSchemaInferenceOptions::new(None), - None, - None, - None, - None, - ) - .await?; + runtime + .await_on(async move { + let mut delete_map = scan_tasks + .iter() + .flat_map(|st| st.sources.iter().map(|s| s.get_path().to_string())) + .map(|path| (path, vec![])) + .collect::>(); + let columns_to_read = Some(vec!["file_path".to_string(), "pos".to_string()]); + let result = read_parquet_bulk_async( + delete_files.into_iter().collect(), + columns_to_read, + None, + None, + None, + None, + io_client, + None, + *NUM_CPUS, + ParquetSchemaInferenceOptions::new(None), + None, + None, + None, + None, + ) + .await?; - for table_result in result { - let table = table_result?; - // values in the file_path column are guaranteed by the iceberg spec to match the full URI of the corresponding data file - // https://iceberg.apache.org/spec/#position-delete-files - let file_paths = table.get_column("file_path")?.downcast::()?; - let positions = table.get_column("pos")?.downcast::()?; + for table_result in result { + let table = table_result?; + // values in the file_path column are guaranteed by the iceberg spec to match the full URI of the corresponding data file + // https://iceberg.apache.org/spec/#position-delete-files + let file_paths = table.get_column("file_path")?.downcast::()?; + let positions = table.get_column("pos")?.downcast::()?; - for (file, pos) in file_paths - .as_arrow() - .values_iter() - .zip(positions.as_arrow().values_iter()) - { - if delete_map.contains_key(file) { - delete_map.get_mut(file).unwrap().push(*pos); + for (file, pos) in file_paths + .as_arrow() + .values_iter() + .zip(positions.as_arrow().values_iter()) + { + if delete_map.contains_key(file) { + delete_map.get_mut(file).unwrap().push(*pos); + } } } - } - Ok(Some(delete_map)) - })? + Ok(Some(delete_map)) + }) + .await? 
diff --git a/src/daft-micropartition/Cargo.toml b/src/daft-micropartition/Cargo.toml
index 2b8a405ef8..b03935d6a6 100644
--- a/src/daft-micropartition/Cargo.toml
+++ b/src/daft-micropartition/Cargo.toml
@@ -3,6 +3,7 @@ arrow2 = {workspace = true}
 bincode = {workspace = true}
 common-error = {path = "../common/error", default-features = false}
 common-file-formats = {path = "../common/file-formats", default-features = false}
+common-runtime = {path = "../common/runtime", default-features = false}
 daft-core = {path = "../daft-core", default-features = false}
 daft-csv = {path = "../daft-csv", default-features = false}
 daft-dsl = {path = "../daft-dsl", default-features = false}
diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs
index ab5b592fc7..af78186cd8 100644
--- a/src/daft-micropartition/src/micropartition.rs
+++ b/src/daft-micropartition/src/micropartition.rs
@@ -7,10 +7,11 @@ use std::{
 use arrow2::io::parquet::read::schema::infer_schema_with_options;
 use common_error::DaftResult;
 use common_file_formats::{CsvSourceConfig, FileFormatConfig, ParquetSourceConfig};
+use common_runtime::get_io_runtime;
 use daft_core::prelude::*;
 use daft_csv::{CsvConvertOptions, CsvParseOptions, CsvReadOptions};
 use daft_dsl::ExprRef;
-use daft_io::{get_runtime, IOClient, IOConfig, IOStatsContext, IOStatsRef};
+use daft_io::{IOClient, IOConfig, IOStatsContext, IOStatsRef};
 use daft_json::{JsonConvertOptions, JsonParseOptions, JsonReadOptions};
 use daft_parquet::read::{
     read_parquet_bulk, read_parquet_metadata_bulk, ParquetSchemaInferenceOptions,
@@ -1064,7 +1065,7 @@ pub fn read_parquet_into_micropartition>(
             chunk_size,
         );
     }
-    let runtime_handle = get_runtime(multithreaded_io)?;
+    let runtime_handle = get_io_runtime(multithreaded_io);
     // Attempt to read TableStatistics from the Parquet file
     let meta_io_client = io_client.clone();
     let meta_io_stats = io_stats.clone();
diff --git a/src/daft-parquet/Cargo.toml b/src/daft-parquet/Cargo.toml
index 217ac9c96c..f80e46566e 100644
--- a/src/daft-parquet/Cargo.toml
+++ b/src/daft-parquet/Cargo.toml
@@ -5,6 +5,7 @@ async-stream = {workspace = true}
 bytes = {workspace = true}
 common-arrow-ffi = {path = "../common/arrow-ffi", default-features = false}
 common-error = {path = "../common/error", default-features = false}
+common-runtime = {path = "../common/runtime", default-features = false}
 crossbeam-channel = "0.5.1"
 daft-core = {path = "../daft-core", default-features = false}
 daft-dsl = {path = "../daft-dsl", default-features = false}
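Call sites also drop the trailing `?` because the new getter is infallible: the runtime is built lazily once per process and handed out as a shared reference. Roughly, as an illustration of the pattern rather than the crate's actual code:

```rust
// Illustrative sketch of an infallible, process-global runtime getter, which is
// why the `?` disappears at the call sites above. Not the common-runtime code.
use std::sync::{Arc, OnceLock};

type RuntimeRef = Arc<tokio::runtime::Runtime>;

static IO_RUNTIME: OnceLock<RuntimeRef> = OnceLock::new();

fn get_io_runtime(multithreaded: bool) -> RuntimeRef {
    // One multithreaded pool is kept for the whole process; a real
    // implementation would also keep a single-threaded variant for `false`.
    let _ = multithreaded;
    IO_RUNTIME
        .get_or_init(|| {
            Arc::new(
                tokio::runtime::Builder::new_multi_thread()
                    .worker_threads(8)
                    .enable_all()
                    .build()
                    .expect("failed to build IO runtime"),
            )
        })
        .clone()
}

fn main() {
    // Same handle on every call; no Result to unwrap.
    let rt = get_io_runtime(true);
    let out = rt.block_on(async { "io work" });
    assert_eq!(out, "io work");
}
```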
diff --git a/src/daft-parquet/src/read.rs b/src/daft-parquet/src/read.rs
index 5897f0f2be..73141bf7ac 100644
--- a/src/daft-parquet/src/read.rs
+++ b/src/daft-parquet/src/read.rs
@@ -11,11 +11,12 @@ use arrow2::{
     },
 };
 use common_error::DaftResult;
+use common_runtime::get_io_runtime;
 use daft_core::prelude::*;
 #[cfg(feature = "python")]
 use daft_core::python::PyTimeUnit;
 use daft_dsl::{optimization::get_required_columns, ExprRef};
-use daft_io::{get_runtime, parse_url, IOClient, IOStatsRef, SourceType};
+use daft_io::{parse_url, IOClient, IOStatsRef, SourceType};
 use daft_table::Table;
 use futures::{
     future::{join_all, try_join_all},
@@ -674,7 +675,7 @@ pub fn read_parquet(
     schema_infer_options: ParquetSchemaInferenceOptions,
     metadata: Option>,
 ) -> DaftResult {
-    let runtime_handle = daft_io::get_runtime(multithreaded_io)?;
+    let runtime_handle = get_io_runtime(multithreaded_io);
 
     runtime_handle.block_on_current_thread(async {
         read_parquet_single(
@@ -713,7 +714,7 @@ pub fn read_parquet_into_pyarrow(
     schema_infer_options: ParquetSchemaInferenceOptions,
     file_timeout_ms: Option,
 ) -> DaftResult {
-    let runtime_handle = daft_io::get_runtime(multithreaded_io)?;
+    let runtime_handle = get_io_runtime(multithreaded_io);
     runtime_handle.block_on_current_thread(async {
         let fut = read_parquet_single_into_arrow(
             uri,
@@ -760,7 +761,7 @@ pub fn read_parquet_bulk>(
     delete_map: Option>>,
     chunk_size: Option,
 ) -> DaftResult> {
-    let runtime_handle = daft_io::get_runtime(multithreaded_io)?;
+    let runtime_handle = get_io_runtime(multithreaded_io);
     let columns = columns.map(|s| s.iter().map(|v| v.as_ref().to_string()).collect::>());
 
     if let Some(ref row_groups) = row_groups {
@@ -903,7 +904,7 @@ pub fn read_parquet_into_pyarrow_bulk>(
     multithreaded_io: bool,
     schema_infer_options: ParquetSchemaInferenceOptions,
 ) -> DaftResult> {
-    let runtime_handle = get_runtime(multithreaded_io)?;
+    let runtime_handle = get_io_runtime(multithreaded_io);
     let columns = columns.map(|s| s.iter().map(|v| v.as_ref().to_string()).collect::>());
     if let Some(ref row_groups) = row_groups {
         if row_groups.len() != uris.len() {
@@ -960,7 +961,7 @@ pub fn read_parquet_schema(
     schema_inference_options: ParquetSchemaInferenceOptions,
     field_id_mapping: Option>>,
 ) -> DaftResult<(Schema, FileMetaData)> {
-    let runtime_handle = get_runtime(true)?;
+    let runtime_handle = get_io_runtime(true);
     let builder = runtime_handle.block_on_current_thread(async {
         ParquetReaderBuilder::from_uri(uri, io_client.clone(), io_stats, field_id_mapping).await
     })?;
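These read paths keep the same shape throughout: fetch the IO runtime handle, then drive a single async read to completion from the calling (synchronous) thread. A small sketch of what a `block_on_current_thread`-style call does, with a stand-in runtime type rather than Daft's:

```rust
// Illustrative only: run a future to completion on the calling thread while an
// existing runtime supplies the IO/timer drivers for anything it spawns.
// This mirrors how the block_on_current_thread calls above are used.
use std::future::Future;

struct IoRuntime {
    runtime: tokio::runtime::Runtime,
}

impl IoRuntime {
    fn block_on_current_thread<F: Future>(&self, future: F) -> F::Output {
        // `Runtime::block_on` polls the future right here; the runtime's worker
        // threads still service any tasks the future spawns.
        self.runtime.block_on(future)
    }
}

fn main() {
    let rt = IoRuntime {
        runtime: tokio::runtime::Runtime::new().expect("failed to build runtime"),
    };
    let bytes = rt.block_on_current_thread(async {
        // Stand-in for an async parquet/CSV read.
        tokio::task::yield_now().await;
        vec![1u8, 2, 3]
    });
    assert_eq!(bytes.len(), 3);
}
```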
@@ -1015,7 +1016,7 @@ pub fn read_parquet_statistics(
     io_stats: Option,
     field_id_mapping: Option>>,
 ) -> DaftResult {
-    let runtime_handle = get_runtime(true)?;
+    let runtime_handle = get_io_runtime(true);
 
     if uris.data_type() != &DataType::Utf8 {
         return Err(common_error::DaftError::ValueError(format!(
@@ -1154,7 +1155,7 @@ mod tests {
         io_config.s3.anonymous = true;
 
         let io_client = Arc::new(IOClient::new(io_config.into())?);
-        let runtime_handle = daft_io::get_runtime(true)?;
+        let runtime_handle = get_io_runtime(true);
         runtime_handle.block_on_current_thread(async move {
             let tables = stream_parquet(
                 file,
@@ -1187,9 +1188,9 @@ mod tests {
 
         let io_config = IOConfig::default();
         let io_client = Arc::new(IOClient::new(io_config.into())?);
-        let runtime_handle = daft_io::get_runtime(true)?;
+        let runtime_handle = get_io_runtime(true);
 
-        runtime_handle.block_on_io_pool(async move {
+        runtime_handle.block_on(async move {
             let metadata = read_parquet_metadata(&file, io_client, None, None).await?;
             let serialized = bincode::serialize(&metadata).unwrap();
             let deserialized = bincode::deserialize::(&serialized).unwrap();
@@ -1214,9 +1215,9 @@ mod tests {
             .into();
         let io_config = IOConfig::default();
         let io_client = Arc::new(IOClient::new(io_config.into()).unwrap());
-        let runtime_handle = daft_io::get_runtime(true).unwrap();
+        let runtime_handle = get_io_runtime(true);
         let file_metadata = runtime_handle
-            .block_on_io_pool({
+            .block_on({
                 let parquet = parquet.clone();
                 let io_client = io_client.clone();
                 async move { read_parquet_metadata(&parquet, io_client, None, None).await }
diff --git a/src/daft-scan/Cargo.toml b/src/daft-scan/Cargo.toml
index d4c5e5a230..d7f5bd557c 100644
--- a/src/daft-scan/Cargo.toml
+++ b/src/daft-scan/Cargo.toml
@@ -5,6 +5,7 @@ common-error = {path = "../common/error", default-features = false}
 common-file-formats = {path = "../common/file-formats", default-features = false}
 common-io-config = {path = "../common/io-config", default-features = false}
 common-py-serde = {path = "../common/py-serde", default-features = false}
+common-runtime = {path = "../common/runtime", default-features = false}
 daft-core = {path = "../daft-core", default-features = false}
 daft-csv = {path = "../daft-csv", default-features = false}
 daft-dsl = {path = "../daft-dsl", default-features = false}
diff --git a/src/daft-scan/src/glob.rs b/src/daft-scan/src/glob.rs
index 72f6307184..883de475eb 100644
--- a/src/daft-scan/src/glob.rs
+++ b/src/daft-scan/src/glob.rs
@@ -2,9 +2,10 @@ use std::{sync::Arc, vec};
 
 use common_error::{DaftError, DaftResult};
 use common_file_formats::{CsvSourceConfig, FileFormat, FileFormatConfig, ParquetSourceConfig};
+use common_runtime::RuntimeRef;
 use daft_core::{prelude::Utf8Array, series::IntoSeries};
 use daft_csv::CsvParseOptions;
-use daft_io::{parse_url, FileMetadata, IOClient, IOStatsContext, IOStatsRef, RuntimeRef};
+use daft_io::{parse_url, FileMetadata, IOClient, IOStatsContext, IOStatsRef};
 use daft_parquet::read::ParquetSchemaInferenceOptions;
 use daft_schema::{dtype::DataType, field::Field, schema::SchemaRef};
 use daft_stats::PartitionSpec;
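daft-scan now takes `RuntimeRef` from common-runtime instead of daft-io, but callers such as the glob scan still receive a `(runtime, client)` pair and block on it as before. A rough sketch of that calling shape, with stand-in types that are not Daft's actual API:

```rust
// Stand-in types to illustrate handing out a shared runtime handle next to an
// IO client, as the storage-config plumbing below does. Hypothetical names.
use std::sync::Arc;

type RuntimeRef = Arc<tokio::runtime::Runtime>;

struct IoClient;

impl IoClient {
    async fn list(&self, glob: &str) -> Vec<String> {
        // Pretend this hits object storage.
        vec![format!("{glob}/part-0.parquet")]
    }
}

fn get_io_client_and_runtime() -> (RuntimeRef, Arc<IoClient>) {
    (
        Arc::new(tokio::runtime::Runtime::new().expect("failed to build runtime")),
        Arc::new(IoClient),
    )
}

fn main() {
    let (runtime, io_client) = get_io_client_and_runtime();
    // Callers that are not already async block on the shared runtime.
    let files = runtime.block_on(async move { io_client.list("s3://bucket/data").await });
    assert_eq!(files.len(), 1);
}
```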
diff --git a/src/daft-scan/src/storage_config.rs b/src/daft-scan/src/storage_config.rs
index 9a672c8cce..c502ae62dc 100644
--- a/src/daft-scan/src/storage_config.rs
+++ b/src/daft-scan/src/storage_config.rs
@@ -3,7 +3,8 @@ use std::sync::Arc;
 use common_error::DaftResult;
 use common_io_config::IOConfig;
 use common_py_serde::impl_bincode_py_state_serialization;
-use daft_io::{get_io_client, get_runtime, IOClient, RuntimeRef};
+use common_runtime::{get_io_runtime, RuntimeRef};
+use daft_io::{get_io_client, IOClient};
 use serde::{Deserialize, Serialize};
 #[cfg(feature = "python")]
 use {
@@ -29,7 +30,7 @@ impl StorageConfig {
             Self::Native(cfg) => {
                 let multithreaded_io = cfg.multithreaded_io;
                 Ok((
-                    get_runtime(multithreaded_io)?,
+                    get_io_runtime(multithreaded_io),
                     get_io_client(
                         multithreaded_io,
                         Arc::new(cfg.io_config.clone().unwrap_or_default()),
@@ -40,7 +41,7 @@ impl StorageConfig {
             Self::Python(cfg) => {
                 let multithreaded_io = true; // Hardcode to use multithreaded IO if Python storage config is used for data fetches
                 Ok((
-                    get_runtime(multithreaded_io)?,
+                    get_io_runtime(multithreaded_io),
                     get_io_client(
                         multithreaded_io,
                         Arc::new(cfg.io_config.clone().unwrap_or_default()),