[FEAT] migrate schema inference → async, block at py boundary (#3432)
Converts schema inference operations for CSV, JSON, and Parquet files to
use async/await instead of synchronous runtime blocking. This
architectural change ensures that blocking operations happen at the
highest level possible (the Python API boundary) rather than deep within
the inference logic. Key changes include:

- Making read_csv_schema, read_json_schema, and read_parquet_schema
async
- Updating scan builder interfaces to use async finish() methods 
- Removing unnecessary runtime.block_on calls from schema inference
paths
- Moving runtime.block_on calls to the Python API layer, where blocking is
unavoidable
- Converting schema-related tests to use the tokio async runtime
- Adding the common-runtime dependency where needed
- Fixes #3423

This change improves the consistency of async IO handling and creates a
cleaner architecture where blocking is consolidated at the Python
interface rather than scattered throughout the codebase.
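
For orientation, here is a minimal, self-contained sketch of the pattern this commit adopts: schema inference becomes a plain `async fn`, and the only code that owns a runtime and blocks is the synchronous boundary. The names `Schema`, `infer_schema`, and `infer_schema_blocking`, and the use of a bare tokio runtime, are illustrative stand-ins rather than Daft's actual API; the real code uses `read_csv_schema`, `common_runtime::get_io_runtime`, and `block_on_current_thread`, as the diffs below show.

```rust
use std::io;

// Illustrative stand-in for the inferred schema type.
#[derive(Debug)]
struct Schema(Vec<String>);

// Library layer: a plain `async fn` with no runtime handle and no block_on.
// Async callers (scan builders, other readers, tests) simply `.await` it.
async fn infer_schema(uri: &str) -> io::Result<Schema> {
    // A real implementation would read a bounded prefix of `uri` here.
    let _ = uri;
    Ok(Schema(vec!["sepal.length".into(), "variety".into()]))
}

// Synchronous boundary (the Python API layer, in Daft's case): the one place
// that owns a runtime and blocks on the async work.
fn infer_schema_blocking(uri: &str) -> io::Result<Schema> {
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()?;
    runtime.block_on(async move { infer_schema(uri).await })
}

fn main() -> io::Result<()> {
    let schema = infer_schema_blocking("data/iris_tiny.csv")?;
    println!("inferred: {schema:?}");
    Ok(())
}
```

One practical payoff of this layering is that library code never calls block_on while it is already running on an async runtime, which tokio rejects with a panic; keeping the inference functions async and blocking only at the boundary rules that failure mode out.
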
andrewgazelka authored Nov 27, 2024
1 parent e89c9f5 commit b6eee0b
Showing 16 changed files with 231 additions and 187 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -194,11 +194,11 @@ chrono-tz = "0.10.0"
comfy-table = "7.1.1"
common-daft-config = {path = "src/common/daft-config"}
common-error = {path = "src/common/error", default-features = false}
common-runtime = {path = "src/common/runtime", default-features = false}
daft-core = {path = "src/daft-core"}
daft-dsl = {path = "src/daft-dsl"}
daft-hash = {path = "src/daft-hash"}
daft-local-execution = {path = "src/daft-local-execution"}
daft-local-plan = {path = "src/daft-local-plan"}
daft-logical-plan = {path = "src/daft-logical-plan"}
daft-scan = {path = "src/daft-scan"}
daft-schema = {path = "src/daft-schema"}
145 changes: 75 additions & 70 deletions src/daft-csv/src/metadata.rs
@@ -3,7 +3,6 @@ use std::{collections::HashSet, sync::Arc};
use arrow2::io::csv::read_async::{AsyncReader, AsyncReaderBuilder};
use async_compat::CompatExt;
use common_error::DaftResult;
use common_runtime::get_io_runtime;
use csv_async::ByteRecord;
use daft_compression::CompressionCodec;
use daft_core::prelude::Schema;
@@ -52,25 +51,22 @@ impl Default for CsvReadStats {
}
}

pub fn read_csv_schema(
pub async fn read_csv_schema(
uri: &str,
parse_options: Option<CsvParseOptions>,
max_bytes: Option<usize>,
io_client: Arc<IOClient>,
io_stats: Option<IOStatsRef>,
) -> DaftResult<(Schema, CsvReadStats)> {
let runtime_handle = get_io_runtime(true);
runtime_handle.block_on_current_thread(async {
read_csv_schema_single(
uri,
parse_options.unwrap_or_default(),
// Default to 1 MiB.
max_bytes.or(Some(1024 * 1024)),
io_client,
io_stats,
)
.await
})
read_csv_schema_single(
uri,
parse_options.unwrap_or_default(),
// Default to 1 MiB.
max_bytes.or(Some(1024 * 1024)),
io_client,
io_stats,
)
.await
}

pub async fn read_csv_schema_bulk(
@@ -81,32 +77,32 @@ pub async fn read_csv_schema_bulk(
io_stats: Option<IOStatsRef>,
num_parallel_tasks: usize,
) -> DaftResult<Vec<(Schema, CsvReadStats)>> {
let runtime_handle = get_io_runtime(true);
let result = runtime_handle
.block_on_current_thread(async {
let task_stream = futures::stream::iter(uris.iter().map(|uri| {
let owned_string = (*uri).to_string();
let owned_client = io_client.clone();
let owned_io_stats = io_stats.clone();
let owned_parse_options = parse_options.clone();
tokio::spawn(async move {
read_csv_schema_single(
&owned_string,
owned_parse_options.unwrap_or_default(),
max_bytes,
owned_client,
owned_io_stats,
)
.await
})
}));
task_stream
.buffered(num_parallel_tasks)
.try_collect::<Vec<_>>()
let result = async {
let task_stream = futures::stream::iter(uris.iter().map(|uri| {
let owned_string = (*uri).to_string();
let owned_client = io_client.clone();
let owned_io_stats = io_stats.clone();
let owned_parse_options = parse_options.clone();
tokio::spawn(async move {
read_csv_schema_single(
&owned_string,
owned_parse_options.unwrap_or_default(),
max_bytes,
owned_client,
owned_io_stats,
)
.await
})
.context(super::JoinSnafu {})?;
result.into_iter().collect::<DaftResult<Vec<_>>>()
})
}));
task_stream
.buffered(num_parallel_tasks)
.try_collect::<Vec<_>>()
.await
}
.await
.context(super::JoinSnafu {})?;

result.into_iter().collect()
}

pub(crate) async fn read_csv_schema_single(
@@ -300,7 +296,8 @@ mod tests {
use crate::CsvParseOptions;

#[rstest]
fn test_csv_schema_local(
#[tokio::test]
async fn test_csv_schema_local(
#[values(
// Uncompressed
None,
Expand Down Expand Up @@ -333,7 +330,8 @@ mod tests {
io_config.s3.anonymous = true;
let io_client = Arc::new(IOClient::new(io_config.into())?);

let (schema, read_stats) = read_csv_schema(file.as_ref(), None, None, io_client, None)?;
let (schema, read_stats) =
read_csv_schema(file.as_ref(), None, None, io_client, None).await?;
assert_eq!(
schema,
Schema::new(vec![
Expand All @@ -350,8 +348,8 @@ mod tests {
Ok(())
}

#[test]
fn test_csv_schema_local_delimiter() -> DaftResult<()> {
#[tokio::test]
async fn test_csv_schema_local_delimiter() -> DaftResult<()> {
let file = format!(
"{}/test/iris_tiny_bar_delimiter.csv",
env!("CARGO_MANIFEST_DIR"),
Expand All @@ -367,7 +365,8 @@ mod tests {
None,
io_client,
None,
)?;
)
.await?;
assert_eq!(
schema,
Schema::new(vec![
Expand All @@ -384,23 +383,23 @@ mod tests {
Ok(())
}

#[test]
fn test_csv_schema_local_read_stats() -> DaftResult<()> {
#[tokio::test]
async fn test_csv_schema_local_read_stats() -> DaftResult<()> {
let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"),);

let mut io_config = IOConfig::default();
io_config.s3.anonymous = true;
let io_client = Arc::new(IOClient::new(io_config.into())?);

let (_, read_stats) = read_csv_schema(file.as_ref(), None, None, io_client, None)?;
let (_, read_stats) = read_csv_schema(file.as_ref(), None, None, io_client, None).await?;
assert_eq!(read_stats.total_bytes_read, 328);
assert_eq!(read_stats.total_records_read, 20);

Ok(())
}

#[test]
fn test_csv_schema_local_no_headers() -> DaftResult<()> {
#[tokio::test]
async fn test_csv_schema_local_no_headers() -> DaftResult<()> {
let file = format!(
"{}/test/iris_tiny_no_headers.csv",
env!("CARGO_MANIFEST_DIR"),
Expand All @@ -416,7 +415,8 @@ mod tests {
None,
io_client,
None,
)?;
)
.await?;
assert_eq!(
schema,
Schema::new(vec![
Expand All @@ -433,8 +433,8 @@ mod tests {
Ok(())
}

#[test]
fn test_csv_schema_local_empty_lines_skipped() -> DaftResult<()> {
#[tokio::test]
async fn test_csv_schema_local_empty_lines_skipped() -> DaftResult<()> {
let file = format!(
"{}/test/iris_tiny_empty_lines.csv",
env!("CARGO_MANIFEST_DIR"),
Expand All @@ -444,7 +444,8 @@ mod tests {
io_config.s3.anonymous = true;
let io_client = Arc::new(IOClient::new(io_config.into())?);

let (schema, read_stats) = read_csv_schema(file.as_ref(), None, None, io_client, None)?;
let (schema, read_stats) =
read_csv_schema(file.as_ref(), None, None, io_client, None).await?;
assert_eq!(
schema,
Schema::new(vec![
Expand All @@ -461,15 +462,16 @@ mod tests {
Ok(())
}

#[test]
fn test_csv_schema_local_nulls() -> DaftResult<()> {
#[tokio::test]
async fn test_csv_schema_local_nulls() -> DaftResult<()> {
let file = format!("{}/test/iris_tiny_nulls.csv", env!("CARGO_MANIFEST_DIR"),);

let mut io_config = IOConfig::default();
io_config.s3.anonymous = true;
let io_client = Arc::new(IOClient::new(io_config.into())?);

let (schema, read_stats) = read_csv_schema(file.as_ref(), None, None, io_client, None)?;
let (schema, read_stats) =
read_csv_schema(file.as_ref(), None, None, io_client, None).await?;
assert_eq!(
schema,
Schema::new(vec![
Expand All @@ -486,8 +488,8 @@ mod tests {
Ok(())
}

#[test]
fn test_csv_schema_local_conflicting_types_utf8_fallback() -> DaftResult<()> {
#[tokio::test]
async fn test_csv_schema_local_conflicting_types_utf8_fallback() -> DaftResult<()> {
let file = format!(
"{}/test/iris_tiny_conflicting_dtypes.csv",
env!("CARGO_MANIFEST_DIR"),
Expand All @@ -497,7 +499,8 @@ mod tests {
io_config.s3.anonymous = true;
let io_client = Arc::new(IOClient::new(io_config.into())?);

let (schema, read_stats) = read_csv_schema(file.as_ref(), None, None, io_client, None)?;
let (schema, read_stats) =
read_csv_schema(file.as_ref(), None, None, io_client, None).await?;
assert_eq!(
schema,
Schema::new(vec![
Expand All @@ -515,16 +518,16 @@ mod tests {
Ok(())
}

#[test]
fn test_csv_schema_local_max_bytes() -> DaftResult<()> {
#[tokio::test]
async fn test_csv_schema_local_max_bytes() -> DaftResult<()> {
let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"),);

let mut io_config = IOConfig::default();
io_config.s3.anonymous = true;
let io_client = Arc::new(IOClient::new(io_config.into())?);

let (schema, read_stats) =
read_csv_schema(file.as_ref(), None, Some(100), io_client, None)?;
read_csv_schema(file.as_ref(), None, Some(100), io_client, None).await?;
assert_eq!(
schema,
Schema::new(vec![
Expand All @@ -550,8 +553,8 @@ mod tests {
Ok(())
}

#[test]
fn test_csv_schema_local_invalid_column_header_mismatch() -> DaftResult<()> {
#[tokio::test]
async fn test_csv_schema_local_invalid_column_header_mismatch() -> DaftResult<()> {
let file = format!(
"{}/test/iris_tiny_invalid_header_cols_mismatch.csv",
env!("CARGO_MANIFEST_DIR"),
Expand All @@ -561,7 +564,7 @@ mod tests {
io_config.s3.anonymous = true;
let io_client = Arc::new(IOClient::new(io_config.into())?);

let err = read_csv_schema(file.as_ref(), None, None, io_client, None);
let err = read_csv_schema(file.as_ref(), None, None, io_client, None).await;
assert!(err.is_err());
let err = err.unwrap_err();
assert!(matches!(err, DaftError::ArrowError(_)), "{}", err);
Expand All @@ -575,8 +578,8 @@ mod tests {
Ok(())
}

#[test]
fn test_csv_schema_local_invalid_no_header_variable_num_cols() -> DaftResult<()> {
#[tokio::test]
async fn test_csv_schema_local_invalid_no_header_variable_num_cols() -> DaftResult<()> {
let file = format!(
"{}/test/iris_tiny_invalid_no_header_variable_num_cols.csv",
env!("CARGO_MANIFEST_DIR"),
Expand All @@ -592,7 +595,8 @@ mod tests {
None,
io_client,
None,
);
)
.await;
assert!(err.is_err());
let err = err.unwrap_err();
assert!(matches!(err, DaftError::ArrowError(_)), "{}", err);
Expand All @@ -607,7 +611,8 @@ mod tests {
}

#[rstest]
fn test_csv_schema_s3(
#[tokio::test]
async fn test_csv_schema_s3(
#[values(
// Uncompressed
None,
Expand Down Expand Up @@ -639,7 +644,7 @@ mod tests {
io_config.s3.anonymous = true;
let io_client = Arc::new(IOClient::new(io_config.into())?);

let (schema, _) = read_csv_schema(file.as_ref(), None, None, io_client, None)?;
let (schema, _) = read_csv_schema(file.as_ref(), None, None, io_client, None).await?;
assert_eq!(
schema,
Schema::new(vec![
21 changes: 14 additions & 7 deletions src/daft-csv/src/python.rs
@@ -55,13 +55,20 @@ pub mod pylib {
multithreaded_io.unwrap_or(true),
io_config.unwrap_or_default().config.into(),
)?;
let (schema, _) = crate::metadata::read_csv_schema(
uri,
parse_options,
max_bytes,
io_client,
Some(io_stats),
)?;

let runtime = common_runtime::get_io_runtime(multithreaded_io.unwrap_or(true));

let (schema, _) = runtime.block_on_current_thread(async move {
crate::metadata::read_csv_schema(
uri,
parse_options,
max_bytes,
io_client,
Some(io_stats),
)
.await
})?;

Ok(Arc::new(schema).into())
})
}
(Diffs for the remaining changed files are not shown.)
