diff --git a/daft/table/micropartition.py b/daft/table/micropartition.py
index f931fda2bd..de14ab9921 100644
--- a/daft/table/micropartition.py
+++ b/daft/table/micropartition.py
@@ -310,6 +310,7 @@ def read_parquet(
         start_offset: int | None = None,
         num_rows: int | None = None,
         row_groups: list[int] | None = None,
+        predicates: list[Expression] | None = None,
         io_config: IOConfig | None = None,
         multithreaded_io: bool | None = None,
         coerce_int96_timestamp_unit: TimeUnit = TimeUnit.ns(),
@@ -321,6 +322,7 @@ def read_parquet(
                 start_offset,
                 num_rows,
                 row_groups,
+                [pred._expr for pred in predicates] if predicates is not None else None,
                 io_config,
                 multithreaded_io,
                 coerce_int96_timestamp_unit._timeunit,
@@ -335,6 +337,7 @@ def read_parquet_bulk(
         start_offset: int | None = None,
         num_rows: int | None = None,
         row_groups_per_path: list[list[int] | None] | None = None,
+        predicates: list[Expression] | None = None,
         io_config: IOConfig | None = None,
         num_parallel_tasks: int | None = 128,
         multithreaded_io: bool | None = None,
@@ -347,6 +350,7 @@ def read_parquet_bulk(
                 start_offset,
                 num_rows,
                 row_groups_per_path,
+                [pred._expr for pred in predicates] if predicates is not None else None,
                 io_config,
                 num_parallel_tasks,
                 multithreaded_io,
diff --git a/daft/table/table.py b/daft/table/table.py
index 1669170e23..0ef1333a9e 100644
--- a/daft/table/table.py
+++ b/daft/table/table.py
@@ -383,6 +383,7 @@ def read_parquet(
         start_offset: int | None = None,
         num_rows: int | None = None,
         row_groups: list[int] | None = None,
+        predicates: list[Expression] | None = None,
         io_config: IOConfig | None = None,
         multithreaded_io: bool | None = None,
         coerce_int96_timestamp_unit: TimeUnit = TimeUnit.ns(),
@@ -394,6 +395,7 @@ def read_parquet(
                 start_offset=start_offset,
                 num_rows=num_rows,
                 row_groups=row_groups,
+                predicates=[pred._expr for pred in predicates] if predicates is not None else None,
                 io_config=io_config,
                 multithreaded_io=multithreaded_io,
                 coerce_int96_timestamp_unit=coerce_int96_timestamp_unit._timeunit,
@@ -408,6 +410,7 @@ def read_parquet_bulk(
         start_offset: int | None = None,
         num_rows: int | None = None,
         row_groups_per_path: list[list[int] | None] | None = None,
+        predicates: list[Expression] | None = None,
         io_config: IOConfig | None = None,
         num_parallel_tasks: int | None = 128,
         multithreaded_io: bool | None = None,
@@ -419,6 +422,7 @@ def read_parquet_bulk(
             start_offset=start_offset,
             num_rows=num_rows,
             row_groups=row_groups_per_path,
+            predicates=[pred._expr for pred in predicates] if predicates is not None else None,
             io_config=io_config,
             num_parallel_tasks=num_parallel_tasks,
             multithreaded_io=multithreaded_io,
diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs
index b272023d46..bf7d216cbc 100644
--- a/src/daft-micropartition/src/micropartition.rs
+++ b/src/daft-micropartition/src/micropartition.rs
@@ -8,6 +8,7 @@
 use common_error::DaftResult;
 use daft_core::schema::{Schema, SchemaRef};
 use daft_csv::{CsvConvertOptions, CsvParseOptions, CsvReadOptions};
+use daft_dsl::{Expr, ExprRef};
 use daft_parquet::read::{
     read_parquet_bulk, read_parquet_metadata_bulk, ParquetSchemaInferenceOptions,
 };
@@ -119,6 +120,11 @@ fn materialize_scan_task(
                     None,
                     scan_task.pushdowns.limit,
                     row_groups,
+                    scan_task
+                        .pushdowns
+                        .filters
+                        .as_ref()
+                        .map(|v| v.as_ref().clone()),
                     io_client.clone(),
                     io_stats,
                     8,
@@ -387,6 +393,11 @@ impl MicroPartition {
                 None,
                 scan_task.pushdowns.limit,
                 row_groups,
+                scan_task
+                    .pushdowns
+                    .filters
+                    .as_ref()
+                    .map(|v| v.as_ref().clone()),
                 cfg.io_config
                     .clone()
                     .map(|c| Arc::new(c.clone()))
@@ -635,6 +646,7 @@ pub(crate) fn read_parquet_into_micropartition(
     start_offset: Option<usize>,
     num_rows: Option<usize>,
     row_groups: Option<Vec<Option<Vec<i64>>>>,
+    predicates: Option<Vec<ExprRef>>,
     io_config: Arc<IOConfig>,
     io_stats: Option<IOStatsRef>,
     num_parallel_tasks: usize,
@@ -648,6 +660,49 @@ pub(crate) fn read_parquet_into_micropartition(
     // Run the required I/O to retrieve all the Parquet FileMetaData
     let runtime_handle = daft_io::get_runtime(multithreaded_io)?;
     let io_client = daft_io::get_io_client(multithreaded_io, io_config.clone())?;
+
+    if let Some(ref predicates) = predicates {
+        // We have a predicate, so we will perform eager read with the predicate
+        // Since we currently
+        let all_tables = read_parquet_bulk(
+            uris,
+            columns,
+            None,
+            None,
+            row_groups,
+            Some(predicates.clone()),
+            io_client,
+            io_stats,
+            num_parallel_tasks,
+            runtime_handle,
+            schema_infer_options,
+        )?;
+
+        let unioned_schema = all_tables
+            .iter()
+            .map(|t| t.schema.clone())
+            .try_reduce(|l, r| DaftResult::Ok(l.union(&r)?.into()))?;
+        let full_daft_schema = unioned_schema.expect("we need at least 1 schema");
+        // Hack to avoid to owned schema
+        let full_daft_schema = Schema {
+            fields: full_daft_schema.fields.clone(),
+        };
+        let pruned_daft_schema = prune_fields_from_schema(full_daft_schema, columns)?;
+
+        let all_tables = all_tables
+            .into_iter()
+            .map(|t| t.cast_to_schema(&pruned_daft_schema))
+            .collect::<DaftResult<Vec<_>>>()?;
+        let loaded =
+            MicroPartition::new_loaded(Arc::new(pruned_daft_schema), all_tables.into(), None);
+
+        if let Some(num_rows) = num_rows {
+            return loaded.head(num_rows);
+        } else {
+            return Ok(loaded);
+        }
+    }
+
     let meta_io_client = io_client.clone();
     let meta_io_stats = io_stats.clone();
     let metadata = runtime_handle.block_on(async move {
@@ -775,6 +830,7 @@ pub(crate) fn read_parquet_into_micropartition(
             start_offset,
             num_rows,
             row_groups,
+            None,
             io_client,
             io_stats,
             num_parallel_tasks,
diff --git a/src/daft-micropartition/src/ops/filter.rs b/src/daft-micropartition/src/ops/filter.rs
index b49df69c95..88ab516de9 100644
--- a/src/daft-micropartition/src/ops/filter.rs
+++ b/src/daft-micropartition/src/ops/filter.rs
@@ -8,7 +8,7 @@
 use crate::{micropartition::MicroPartition, DaftCoreComputeSnafu};
 use daft_stats::TruthValue;
 impl MicroPartition {
-    pub fn filter(&self, predicate: &[Expr]) -> DaftResult<Self> {
+    pub fn filter<E: AsRef<Expr>>(&self, predicate: &[E]) -> DaftResult<Self> {
         let io_stats = IOStatsContext::new("MicroPartition::filter");
         if predicate.is_empty() {
             return Ok(Self::empty(Some(self.schema.clone())));
@@ -16,7 +16,7 @@ impl MicroPartition {
         if let Some(statistics) = &self.statistics {
             let folded_expr = predicate
                 .iter()
-                .cloned()
+                .map(|e| e.as_ref().clone())
                 .reduce(|a, b| a.and(&b))
                 .expect("should have at least 1 expr");
             let eval_result = statistics.eval_expression(&folded_expr)?;
diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs
index f2a6796ffc..56b8691673 100644
--- a/src/daft-micropartition/src/python.rs
+++ b/src/daft-micropartition/src/python.rs
@@ -405,6 +405,7 @@ impl PyMicroPartition {
         start_offset: Option<i64>,
         num_rows: Option<i64>,
         row_groups: Option<Vec<i64>>,
+        predicates: Option<Vec<PyExpr>>,
         io_config: Option<IOConfig>,
         multithreaded_io: Option<bool>,
         coerce_int96_timestamp_unit: Option<PyTimeUnit>,
@@ -423,6 +424,8 @@ impl PyMicroPartition {
                 start_offset,
                 num_rows,
                 row_groups.map(|rg| vec![Some(rg)]),
+                predicates
+                    .map(|e_vec| e_vec.into_iter().map(|e| e.expr.into()).collect::<Vec<_>>()),
                 io_config,
                 Some(io_stats),
                 1,
@@ -442,6 +445,7 @@ impl PyMicroPartition {
         start_offset: Option<i64>,
         num_rows: Option<i64>,
         row_groups: Option<Vec<Option<Vec<i64>>>>,
+        predicates: Option<Vec<PyExpr>>,
         io_config: Option<IOConfig>,
         num_parallel_tasks: Option<i64>,
         multithreaded_io: Option<bool>,
@@ -461,6 +465,8 @@ impl PyMicroPartition {
                 start_offset,
                 num_rows,
                 row_groups,
+                predicates
+                    .map(|e_vec| e_vec.into_iter().map(|e| e.expr.into()).collect::<Vec<_>>()),
                 io_config,
                 Some(io_stats),
                 num_parallel_tasks.unwrap_or(128) as usize,
diff --git a/src/daft-parquet/src/file.rs b/src/daft-parquet/src/file.rs
index 246069077d..e81332c25b 100644
--- a/src/daft-parquet/src/file.rs
+++ b/src/daft-parquet/src/file.rs
@@ -3,7 +3,7 @@ use std::{collections::HashSet, sync::Arc};
 use arrow2::io::parquet::read::schema::infer_schema_with_options;
 use common_error::DaftResult;
 use daft_core::{schema::Schema, utils::arrow::cast_array_for_daft_if_needed, Series};
-use daft_dsl::Expr;
+use daft_dsl::{Expr, ExprRef};
 use daft_io::{IOClient, IOStatsRef};
 use daft_stats::TruthValue;
 use daft_table::Table;
@@ -33,7 +33,7 @@ pub(crate) struct ParquetReaderBuilder {
     limit: Option<usize>,
     row_groups: Option<Vec<i64>>,
     schema_inference_options: ParquetSchemaInferenceOptions,
-    predicate: Option<Expr>,
+    predicates: Option<Vec<ExprRef>>,
 }
 
 use parquet2::read::decompress;
@@ -99,7 +99,7 @@ pub(crate) fn build_row_ranges(
     limit: Option<usize>,
     row_start_offset: usize,
     row_groups: Option<&[i64]>,
-    predicate: Option<&Expr>,
+    predicates: Option<&[ExprRef]>,
     schema: &Schema,
     metadata: &parquet2::metadata::FileMetaData,
     uri: &str,
@@ -107,6 +107,15 @@
     let limit = limit.map(|v| v as i64);
     let mut row_ranges = vec![];
     let mut curr_row_index = 0;
+
+    let folded_expr = predicates.map(|preds| {
+        preds
+            .iter()
+            .cloned()
+            .reduce(|a, b| a.and(&b).into())
+            .expect("should have at least 1 expr")
+    });
+
     if let Some(row_groups) = row_groups {
         let mut rows_to_add: i64 = limit.unwrap_or(i64::MAX);
         for i in row_groups {
@@ -122,11 +131,12 @@
                 break;
             }
             let rg = metadata.row_groups.get(i).unwrap();
-            if let Some(pred) = predicate {
+            if let Some(ref pred) = folded_expr {
                 let stats = statistics::row_group_metadata_to_table_stats(rg, schema)
                     .with_context(|_| UnableToConvertRowGroupMetadataToStatsSnafu {
                         path: uri.to_string(),
                     })?;
+
                 let evaled = stats.eval_expression(pred).with_context(|_| {
                     UnableToRunExpressionOnStatsSnafu {
                         path: uri.to_string(),
@@ -153,7 +163,7 @@
             curr_row_index += rg.num_rows();
             continue;
         } else if rows_to_add > 0 {
-            if let Some(pred) = predicate {
+            if let Some(ref pred) = folded_expr {
                 let stats = statistics::row_group_metadata_to_table_stats(rg, schema)
                     .with_context(|_| UnableToConvertRowGroupMetadataToStatsSnafu {
                         path: uri.to_string(),
@@ -203,7 +213,7 @@
             limit: None,
             row_groups: None,
             schema_inference_options: Default::default(),
-            predicate: None,
+            predicates: None,
         })
     }
 
@@ -259,8 +269,9 @@
         self
     }
 
-    pub fn set_filter(mut self, predicate: Expr) -> Self {
-        self.predicate = Some(predicate);
+    pub fn set_filter(mut self, predicates: Vec<ExprRef>) -> Self {
+        assert_eq!(self.limit, None);
+        self.predicates = Some(predicates);
         self
     }
 
@@ -276,13 +287,13 @@
                 .fields
                 .retain(|f| names_to_keep.contains(f.name.as_str()));
         }
-        // DONT UNWRAP
+        // TODO: DONT UNWRAP
         let daft_schema = Schema::try_from(&arrow_schema).unwrap();
         let row_ranges = build_row_ranges(
             self.limit,
             self.row_start_offset,
             self.row_groups.as_deref(),
-            self.predicate.as_ref(),
+            self.predicates.as_deref(),
             &daft_schema,
             &self.metadata,
             &self.uri,
diff --git a/src/daft-parquet/src/python.rs b/src/daft-parquet/src/python.rs
index 73f202feb0..05c63f5cb7 100644
--- a/src/daft-parquet/src/python.rs
+++ b/src/daft-parquet/src/python.rs
@@ -5,6 +5,7 @@ pub mod pylib {
         ffi::field_to_py,
         python::{datatype::PyTimeUnit, schema::PySchema, PySeries},
     };
+    use daft_dsl::python::PyExpr;
     use daft_io::{get_io_client, python::IOConfig, IOStatsContext};
     use daft_table::python::PyTable;
     use pyo3::{pyfunction, types::PyModule, PyResult, Python};
@@ -21,6 +22,7 @@ pub mod pylib {
         start_offset: Option<i64>,
         num_rows: Option<i64>,
         row_groups: Option<Vec<i64>>,
+        predicates: Option<Vec<PyExpr>>,
         io_config: Option<IOConfig>,
         multithreaded_io: Option<bool>,
         coerce_int96_timestamp_unit: Option<PyTimeUnit>,
@@ -43,6 +45,8 @@ pub mod pylib {
                 start_offset,
                 num_rows,
                 row_groups,
+                predicates
+                    .map(|e_vec| e_vec.into_iter().map(|e| e.expr.into()).collect::<Vec<_>>()),
                 io_client,
                 Some(io_stats.clone()),
                 runtime_handle,
@@ -127,6 +131,7 @@ pub mod pylib {
         start_offset: Option<i64>,
         num_rows: Option<i64>,
         row_groups: Option<Vec<Option<Vec<i64>>>>,
+        predicates: Option<Vec<PyExpr>>,
         io_config: Option<IOConfig>,
         num_parallel_tasks: Option<i64>,
         multithreaded_io: Option<bool>,
@@ -150,6 +155,8 @@ pub mod pylib {
                 start_offset,
                 num_rows,
                 row_groups,
+                predicates
+                    .map(|e_vec| e_vec.into_iter().map(|e| e.expr.into()).collect::<Vec<_>>()),
                 io_client,
                 Some(io_stats),
                 num_parallel_tasks.unwrap_or(128) as usize,
diff --git a/src/daft-parquet/src/read.rs b/src/daft-parquet/src/read.rs
index 7311b9b1fa..fa91f49ba1 100644
--- a/src/daft-parquet/src/read.rs
+++ b/src/daft-parquet/src/read.rs
@@ -1,4 +1,4 @@
-use std::sync::Arc;
+use std::{fmt::format, sync::Arc};
 
 use common_error::DaftResult;
 
@@ -7,6 +7,7 @@ use daft_core::{
     schema::Schema,
     DataType, IntoSeries, Series,
 };
+use daft_dsl::{Expr, ExprRef};
 use daft_io::{get_runtime, parse_url, IOClient, IOStatsRef, SourceType};
 use daft_table::Table;
 use futures::{
@@ -61,12 +62,18 @@ async fn read_parquet_single(
     start_offset: Option<usize>,
     num_rows: Option<usize>,
     row_groups: Option<Vec<i64>>,
+    predicates: Option<Vec<ExprRef>>,
     io_client: Arc<IOClient>,
     io_stats: Option<IOStatsRef>,
     schema_infer_options: ParquetSchemaInferenceOptions,
 ) -> DaftResult<Table> {
+    let pred_set = predicates.is_some();
+    if pred_set && num_rows.is_some() {
+        return Err(common_error::DaftError::ValueError("Parquet Reader Currently doesn't support having both `num_rows` and `predicate` set at the same time".to_string()));
+    }
     let (source_type, fixed_uri) = parse_url(uri)?;
-    let (metadata, table) = if matches!(source_type, SourceType::File) {
+    let (metadata, mut table) = if matches!(source_type, SourceType::File) {
+        // TODO thread predicate to local parquet read
         crate::stream_reader::local_parquet_read_async(
             fixed_uri.as_ref(),
             columns.map(|s| s.iter().map(|s| s.to_string()).collect_vec()),
@@ -99,6 +106,12 @@ async fn read_parquet_single(
             builder
         };
 
+        let builder = if let Some(ref predicates) = predicates {
+            builder.set_filter(predicates.clone())
+        } else {
+            builder
+        };
+
         let parquet_reader = builder.build()?;
         let ranges = parquet_reader.prebuffer_ranges(io_client, io_stats)?;
         Ok((
@@ -117,44 +130,48 @@ async fn read_parquet_single(
     let metadata_num_columns = metadata.schema().fields().len();
 
-    if let Some(row_groups) = row_groups {
-        let expected_rows: usize = row_groups
-            .iter()
-            .map(|i| rows_per_row_groups.get(*i as usize).unwrap())
-            .sum();
-        if expected_rows != table.len() {
-            return Err(super::Error::ParquetNumRowMismatch {
-                path: uri.into(),
-                metadata_num_rows: expected_rows,
-                read_rows: table.len(),
-            }
-            .into());
-        }
+    if let Some(predicates) = predicates {
+        // TODO ideally pipeline this with IO and before concating, rather than after
+        table = table.filter(predicates.as_slice())?;
     } else {
-        match (start_offset, num_rows) {
-            (None, None) if metadata_num_rows != table.len() => {
-                Err(super::Error::ParquetNumRowMismatch {
+        if let Some(row_groups) = row_groups {
+            let expected_rows: usize = row_groups
+                .iter()
+                .map(|i| rows_per_row_groups.get(*i as usize).unwrap())
+                .sum();
+            if expected_rows != table.len() {
+                return Err(super::Error::ParquetNumRowMismatch {
                     path: uri.into(),
-                metadata_num_rows,
+                    metadata_num_rows: expected_rows,
                     read_rows: table.len(),
-                })
+                }
+                .into());
             }
-            (Some(s), None) if metadata_num_rows.saturating_sub(s) != table.len() => {
-                Err(super::Error::ParquetNumRowMismatch {
+        } else {
+            match (start_offset, num_rows) {
+                (None, None) if metadata_num_rows != table.len() => {
+                    Err(super::Error::ParquetNumRowMismatch {
+                        path: uri.into(),
+                        metadata_num_rows,
+                        read_rows: table.len(),
+                    })
+                }
+                (Some(s), None) if metadata_num_rows.saturating_sub(s) != table.len() => {
+                    Err(super::Error::ParquetNumRowMismatch {
+                        path: uri.into(),
+                        metadata_num_rows: metadata_num_rows.saturating_sub(s),
+                        read_rows: table.len(),
+                    })
+                }
+                (_, Some(n)) if n < table.len() => Err(super::Error::ParquetNumRowMismatch {
                     path: uri.into(),
-                metadata_num_rows: metadata_num_rows.saturating_sub(s),
+                    metadata_num_rows: n.min(metadata_num_rows),
                     read_rows: table.len(),
-                })
-            }
-            (_, Some(n)) if n < table.len() => Err(super::Error::ParquetNumRowMismatch {
-                path: uri.into(),
-                metadata_num_rows: n.min(metadata_num_rows),
-                read_rows: table.len(),
-            }),
-            _ => Ok(()),
-        }?;
-    };
-
+                }),
+                _ => Ok(()),
+            }?;
+        };
+    }
     let expected_num_columns = if let Some(columns) = columns {
         columns.len()
     } else {
@@ -316,6 +333,7 @@ pub fn read_parquet(
     start_offset: Option<usize>,
     num_rows: Option<usize>,
     row_groups: Option<Vec<i64>>,
+    predicates: Option<Vec<ExprRef>>,
     io_client: Arc<IOClient>,
     io_stats: Option<IOStatsRef>,
     runtime_handle: Arc<tokio::runtime::Runtime>,
@@ -329,6 +347,7 @@ pub fn read_parquet(
             start_offset,
             num_rows,
             row_groups,
+            predicates,
             io_client,
             io_stats,
             schema_infer_options,
@@ -373,6 +392,7 @@ pub fn read_parquet_bulk(
     start_offset: Option<usize>,
     num_rows: Option<usize>,
     row_groups: Option<Vec<Option<Vec<i64>>>>,
+    predicates: Option<Vec<ExprRef>>,
     io_client: Arc<IOClient>,
     io_stats: Option<IOStatsRef>,
     num_parallel_tasks: usize,
@@ -396,6 +416,7 @@ pub fn read_parquet_bulk(
             let uri = uri.to_string();
             let owned_columns = owned_columns.clone();
             let owned_row_group = row_groups.as_ref().and_then(|rgs| rgs[i].clone());
+            let owned_predicates = predicates.clone();
 
             let io_client = io_client.clone();
             let io_stats = io_stats.clone();
@@ -412,6 +433,7 @@ pub fn read_parquet_bulk(
                         start_offset,
                         num_rows,
                         owned_row_group,
+                        owned_predicates,
                         io_client,
                         io_stats,
                         schema_infer_options,
@@ -643,6 +665,7 @@ mod tests {
             None,
             None,
             None,
+            None,
             io_client,
             None,
             runtime_handle,
diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs
index c296d7cc57..feaeaaa0e8 100644
--- a/src/daft-table/src/lib.rs
+++ b/src/daft-table/src/lib.rs
@@ -174,16 +174,20 @@ impl Table {
         Ok(column_sizes?.iter().sum())
     }
 
-    pub fn filter(&self, predicate: &[Expr]) -> DaftResult<Self> {
+    pub fn filter<E: AsRef<Expr>>(&self, predicate: &[E]) -> DaftResult<Self> {
         if predicate.is_empty() {
             Ok(self.clone())
         } else if predicate.len() == 1 {
-            let mask = self.eval_expression(predicate.get(0).unwrap())?;
+            let mask = self.eval_expression(predicate.get(0).unwrap().as_ref())?;
             self.mask_filter(&mask)
         } else {
-            let mut expr = predicate.get(0).unwrap().and(predicate.get(1).unwrap());
+            let mut expr = predicate
+                .get(0)
+                .unwrap()
+                .as_ref()
+                .and(predicate.get(1).unwrap().as_ref());
             for i in 2..predicate.len() {
-                let next = predicate.get(i).unwrap();
+                let next = predicate.get(i).unwrap().as_ref();
                 expr = expr.and(next);
             }
             let mask = self.eval_expression(&expr)?;
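
Not part of the diff: a minimal usage sketch of the new Python-level `predicates` argument added above. The leading `path`/`columns` parameters and the `col` import path are assumed from the surrounding Daft APIs rather than shown in this diff, and the file URI is hypothetical. Predicates in the list are AND-ed together by the reader, used to prune row groups via Parquet statistics, and then re-applied to the decoded rows; per the reader change, `predicates` currently cannot be combined with `num_rows`.

```python
from daft.expressions import col  # assumed import path for the expression builder
from daft.table import MicroPartition

# Hypothetical file: row groups whose statistics rule out `amount > 100` are skipped,
# and the same filter is applied to whatever rows are actually decoded.
part = MicroPartition.read_parquet(
    "s3://bucket/events.parquet",      # hypothetical path argument
    columns=["user_id", "amount"],     # assumed leading parameter
    predicates=[col("amount") > 100],  # list entries are AND-ed together
)
```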