diff --git a/Cargo.lock b/Cargo.lock index 76a93d64ad..60f509ebc5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1343,6 +1343,7 @@ dependencies = [ "pyo3-log", "serde", "snafu", + "tokio", ] [[package]] diff --git a/src/daft-micropartition/Cargo.toml b/src/daft-micropartition/Cargo.toml index 5c7e026dc2..a1a8df1a1e 100644 --- a/src/daft-micropartition/Cargo.toml +++ b/src/daft-micropartition/Cargo.toml @@ -18,6 +18,7 @@ pyo3 = {workspace = true, optional = true} pyo3-log = {workspace = true} serde = {workspace = true} snafu = {workspace = true} +tokio = {workspace = true} [features] default = ["python"] diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs index 12e9b2e9ab..48321b51f6 100644 --- a/src/daft-micropartition/src/micropartition.rs +++ b/src/daft-micropartition/src/micropartition.rs @@ -26,7 +26,7 @@ use snafu::ResultExt; use crate::PyIOSnafu; use crate::{DaftCSVSnafu, DaftCoreComputeSnafu}; -use daft_io::{IOConfig, IOStatsContext, IOStatsRef}; +use daft_io::{IOClient, IOConfig, IOStatsContext, IOStatsRef}; use daft_stats::TableMetadata; use daft_stats::TableStatistics; @@ -422,6 +422,7 @@ impl MicroPartition { &ParquetSchemaInferenceOptions { coerce_int96_timestamp_unit: *coerce_int96_timestamp_unit, }, + Some(schema.clone()), field_id_mapping, ) .context(DaftCoreComputeSnafu)?; @@ -697,29 +698,66 @@ pub(crate) fn read_json_into_micropartition( } } -// TODO: Deduplicate this with the other `rename_schema_recursively` function in file.rs -fn rename_schema_recursively( - daft_schema: Schema, - field_id_mapping: &BTreeMap, -) -> DaftResult { - Schema::new( - daft_schema - .fields - .into_iter() - .map(|(_, field)| { - if let Some(field_id) = field.metadata.get("field_id") { - let field_id = str::parse::(field_id).unwrap(); - let mapped_field = field_id_mapping.get(&field_id); - match mapped_field { - None => field, - Some(mapped_field) => field.rename(&mapped_field.name), - } - } else { - field - } - }) - .collect(), - ) +#[allow(clippy::too_many_arguments)] +fn _read_parquet_into_loaded_micropartition( + io_client: Arc, + runtime_handle: Arc, + uris: &[&str], + columns: Option<&[&str]>, + start_offset: Option, + num_rows: Option, + row_groups: Option>>>, + predicate: Option, + io_stats: Option, + num_parallel_tasks: usize, + schema_infer_options: &ParquetSchemaInferenceOptions, + catalog_provided_schema: Option, + field_id_mapping: &Option>>, +) -> DaftResult { + let all_tables = read_parquet_bulk( + uris, + columns, + start_offset, + num_rows, + row_groups, + predicate, + io_client, + io_stats, + num_parallel_tasks, + runtime_handle, + schema_infer_options, + field_id_mapping, + )?; + + // Prefer using the `catalog_provided_schema` but fall back onto inferred schema from Parquet files + let full_daft_schema = match catalog_provided_schema { + Some(catalog_provided_schema) => catalog_provided_schema, + None => { + let unioned_schema = all_tables + .iter() + .map(|t| t.schema.clone()) + .try_reduce(|l, r| DaftResult::Ok(l.union(&r)?.into()))?; + unioned_schema.expect("we need at least 1 schema") + } + }; + + // Hack to avoid to owned schema + let full_daft_schema = Schema { + fields: full_daft_schema.fields.clone(), + }; + let pruned_daft_schema = prune_fields_from_schema(full_daft_schema, columns)?; + + let all_tables = all_tables + .into_iter() + .map(|t| t.cast_to_schema(&pruned_daft_schema)) + .collect::>>()?; + + // TODO: we can pass in stats here to optimize downstream workloads such as join + Ok(MicroPartition::new_loaded( + Arc::new(pruned_daft_schema), + all_tables.into(), + None, + )) } #[allow(clippy::too_many_arguments)] @@ -735,6 +773,7 @@ pub(crate) fn read_parquet_into_micropartition( num_parallel_tasks: usize, multithreaded_io: bool, schema_infer_options: &ParquetSchemaInferenceOptions, + catalog_provided_schema: Option, field_id_mapping: &Option>>, ) -> DaftResult { if let Some(so) = start_offset && so > 0 { @@ -745,44 +784,23 @@ pub(crate) fn read_parquet_into_micropartition( let runtime_handle = daft_io::get_runtime(multithreaded_io)?; let io_client = daft_io::get_io_client(multithreaded_io, io_config.clone())?; - if let Some(predicate) = predicate { + if predicate.is_some() { // We have a predicate, so we will perform eager read only reading what row groups we need. - let all_tables = read_parquet_bulk( + return _read_parquet_into_loaded_micropartition( + io_client, + runtime_handle, uris, columns, - None, + start_offset, num_rows, row_groups, - Some(predicate.clone()), - io_client, + predicate, io_stats, num_parallel_tasks, - runtime_handle, schema_infer_options, + catalog_provided_schema, field_id_mapping, - )?; - - let unioned_schema = all_tables - .iter() - .map(|t| t.schema.clone()) - .try_reduce(|l, r| DaftResult::Ok(l.union(&r)?.into()))?; - let full_daft_schema = unioned_schema.expect("we need at least 1 schema"); - // Hack to avoid to owned schema - let full_daft_schema = Schema { - fields: full_daft_schema.fields.clone(), - }; - let pruned_daft_schema = prune_fields_from_schema(full_daft_schema, columns)?; - - let all_tables = all_tables - .into_iter() - .map(|t| t.cast_to_schema(&pruned_daft_schema)) - .collect::>>()?; - // TODO: we can pass in stats here to optimize downstream workloads such as join - return Ok(MicroPartition::new_loaded( - Arc::new(pruned_daft_schema), - all_tables.into(), - None, - )); + ); } let meta_io_client = io_client.clone(); @@ -824,46 +842,40 @@ pub(crate) fn read_parquet_into_micropartition( None }; - // Union and prune the schema using the specified `columns` - let resolved_schemas = if let Some(field_id_mapping) = field_id_mapping { - pq_file_schemas - .into_iter() - .map(|pq_file_schema| { - rename_schema_recursively(pq_file_schema, field_id_mapping.as_ref()) - }) - .collect::>>()? - } else { - pq_file_schemas - }; - let unioned_schema = resolved_schemas - .into_iter() - .try_reduce(|l, r| l.union(&r))?; - let full_daft_schema = unioned_schema.expect("we need at least 1 schema"); - let pruned_daft_schema = prune_fields_from_schema(full_daft_schema, columns)?; + if let Some(stats) = stats { + // Statistics are provided by the Parquet file, so we create an unloaded MicroPartition + // by constructing an appropriate ScanTask + + // Prefer using the `catalog_provided_schema` but fall back onto inferred schema from Parquet files + let scan_task_daft_schema = match catalog_provided_schema { + Some(catalog_provided_schema) => catalog_provided_schema, + None => { + let unioned_schema = pq_file_schemas.into_iter().try_reduce(|l, r| l.union(&r))?; + Arc::new(unioned_schema.expect("we need at least 1 schema")) + } + }; - // Get total number of rows, accounting for selected `row_groups` and the indicated `num_rows` - let total_rows_no_limit = match &row_groups { - None => metadata.iter().map(|fm| fm.num_rows).sum(), - Some(row_groups) => metadata - .iter() - .zip(row_groups.iter()) - .map(|(fm, rg)| match rg { - Some(rg) => rg - .iter() - .map(|rg_idx| fm.row_groups.get(*rg_idx as usize).unwrap().num_rows()) - .sum::(), - None => fm.num_rows, - }) - .sum(), - }; - let total_rows = num_rows - .map(|num_rows| num_rows.min(total_rows_no_limit)) - .unwrap_or(total_rows_no_limit); + // Get total number of rows, accounting for selected `row_groups` and the indicated `num_rows` + let total_rows_no_limit = match &row_groups { + None => metadata.iter().map(|fm| fm.num_rows).sum(), + Some(row_groups) => metadata + .iter() + .zip(row_groups.iter()) + .map(|(fm, rg)| match rg { + Some(rg) => rg + .iter() + .map(|rg_idx| fm.row_groups.get(*rg_idx as usize).unwrap().num_rows()) + .sum::(), + None => fm.num_rows, + }) + .sum(), + }; + let total_rows = num_rows + .map(|num_rows| num_rows.min(total_rows_no_limit)) + .unwrap_or(total_rows_no_limit); - if let Some(stats) = stats { let owned_urls = uris.iter().map(|s| s.to_string()).collect::>(); - let daft_schema = Arc::new(pruned_daft_schema); let size_bytes = metadata .iter() .map(|m| -> u64 { @@ -891,7 +903,7 @@ pub(crate) fn read_parquet_into_micropartition( field_id_mapping: field_id_mapping.clone(), }) .into(), - daft_schema.clone(), + scan_task_daft_schema, StorageConfig::Native( NativeStorageConfig::new_internal( multithreaded_io, @@ -909,15 +921,6 @@ pub(crate) fn read_parquet_into_micropartition( ), ); - let exprs = daft_schema - .fields - .keys() - .map(|n| daft_dsl::col(n.as_str())) - .collect::>(); - - // use schema to update stats - let stats = stats.eval_expression_list(exprs.as_slice(), daft_schema.as_ref())?; - Ok(MicroPartition::new_unloaded( scan_task.materialized_schema(), Arc::new(scan_task), @@ -925,29 +928,21 @@ pub(crate) fn read_parquet_into_micropartition( stats, )) } else { - let all_tables = read_parquet_bulk( + _read_parquet_into_loaded_micropartition( + io_client, + runtime_handle, uris, columns, start_offset, num_rows, row_groups, - None, - io_client, + predicate, io_stats, num_parallel_tasks, - runtime_handle, schema_infer_options, + catalog_provided_schema, field_id_mapping, - )?; - let all_tables = all_tables - .into_iter() - .map(|t| t.cast_to_schema(&pruned_daft_schema)) - .collect::>>()?; - Ok(MicroPartition::new_loaded( - Arc::new(pruned_daft_schema), - all_tables.into(), - None, - )) + ) } } diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index 4c16e2b70f..fa9b53f328 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -543,6 +543,7 @@ impl PyMicroPartition { 1, multithreaded_io.unwrap_or(true), &schema_infer_options, + None, &None, ) })?; @@ -584,6 +585,7 @@ impl PyMicroPartition { num_parallel_tasks.unwrap_or(128) as usize, multithreaded_io.unwrap_or(true), &schema_infer_options, + None, &None, ) })?;