From 86513162ab75c38d6a88f6423819c9c7e69024fc Mon Sep 17 00:00:00 2001
From: Sammy Sidhu
Date: Mon, 11 Dec 2023 15:30:52 -0500
Subject: [PATCH] first pass predicate apply in csv

---
 src/daft-csv/src/options.rs                   |  16 ++-
 src/daft-csv/src/read.rs                      | 100 ++++++++++++------
 src/daft-micropartition/src/micropartition.rs |   1 +
 src/daft-scan/src/glob.rs                     |   1 -
 src/daft-table/src/ffi.rs                     |   4 +-
 src/daft-table/src/lib.rs                     |  16 ++-
 src/daft-table/src/python.rs                  |   6 ++
 7 files changed, 100 insertions(+), 44 deletions(-)

diff --git a/src/daft-csv/src/options.rs b/src/daft-csv/src/options.rs
index eceebadfd7..da81ed3a36 100644
--- a/src/daft-csv/src/options.rs
+++ b/src/daft-csv/src/options.rs
@@ -1,8 +1,10 @@
 use daft_core::{impl_bincode_py_state_serialization, schema::SchemaRef};
+use daft_dsl::{Expr, ExprRef};
 use serde::{Deserialize, Serialize};
 #[cfg(feature = "python")]
 use {
     daft_core::python::schema::PySchema,
+    daft_dsl::python::PyExpr,
     pyo3::{
         pyclass, pyclass::CompareOp, pymethods, types::PyBytes, PyObject, PyResult, PyTypeInfo,
         Python, ToPyObject,
@@ -17,6 +19,7 @@ pub struct CsvConvertOptions {
     pub include_columns: Option<Vec<String>>,
     pub column_names: Option<Vec<String>>,
     pub schema: Option<SchemaRef>,
+    pub predicate: Option<ExprRef>,
 }
 
 impl CsvConvertOptions {
@@ -25,12 +28,14 @@
         include_columns: Option<Vec<String>>,
         column_names: Option<Vec<String>>,
         schema: Option<SchemaRef>,
+        predicate: Option<ExprRef>,
     ) -> Self {
         Self {
             limit,
             include_columns,
             column_names,
             schema,
+            predicate: predicate,
         }
     }
 
@@ -40,6 +45,7 @@
             include_columns: self.include_columns,
             column_names: self.column_names,
             schema: self.schema,
+            predicate: self.predicate,
         }
     }
 
@@ -49,6 +55,7 @@
             include_columns,
             column_names: self.column_names,
             schema: self.schema,
+            predicate: self.predicate,
         }
     }
 
@@ -58,6 +65,7 @@
             include_columns: self.include_columns,
             column_names,
             schema: self.schema,
+            predicate: self.predicate,
         }
     }
 
@@ -67,13 +75,14 @@
             include_columns: self.include_columns,
             column_names: self.column_names,
             schema,
+            predicate: self.predicate,
         }
     }
 }
 
 impl Default for CsvConvertOptions {
     fn default() -> Self {
-        Self::new_internal(None, None, None, None)
+        Self::new_internal(None, None, None, None, None)
     }
 }
 
@@ -88,19 +97,22 @@ impl CsvConvertOptions {
     /// * `include_columns` - The names of the columns that should be kept, e.g. via a projection.
     /// * `column_names` - The names for the CSV columns.
     /// * `schema` - The names and dtypes for the CSV columns.
+    /// * `predicate` - Expression to filter rows applied before the limit
     #[new]
-    #[pyo3(signature = (limit=None, include_columns=None, column_names=None, schema=None))]
+    #[pyo3(signature = (limit=None, include_columns=None, column_names=None, schema=None, predicate=None))]
     pub fn new(
         limit: Option<usize>,
         include_columns: Option<Vec<String>>,
         column_names: Option<Vec<String>>,
         schema: Option<PySchema>,
+        predicate: Option<PyExpr>,
     ) -> Self {
         Self::new_internal(
             limit,
             include_columns,
             column_names,
             schema.map(|s| s.into()),
+            predicate.map(|p| p.expr.into()),
         )
     }
 
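Note on the options.rs change above: `CsvConvertOptions` is a consuming builder, so every `with_*` method rebuilds the struct field by field, which is why each one now has to carry `predicate` forward explicitly. The snippet below is not Daft code; it is a minimal standalone sketch of that pattern with a hypothetical `Options` struct, showing why forgetting the new field in any one `with_*` method would silently drop the predicate.

```rust
// Hypothetical stand-in for CsvConvertOptions; not the real Daft type.
#[derive(Debug, Clone)]
struct Options {
    limit: Option<usize>,
    predicate: Option<String>, // stand-in for Option<ExprRef>
}

impl Options {
    // Consuming builder: each method moves `self` and rebuilds the struct,
    // so every field (including newly added ones) must be copied across.
    fn with_limit(self, limit: Option<usize>) -> Self {
        Self {
            limit,
            predicate: self.predicate, // dropping this line would lose the predicate
        }
    }
}

fn main() {
    let opts = Options {
        limit: None,
        predicate: Some("a > 1".to_string()),
    };
    let opts = opts.with_limit(Some(10));
    // The predicate survives the builder call only because with_limit copies it.
    assert_eq!(opts.predicate.as_deref(), Some("a > 1"));
    assert_eq!(opts.limit, Some(10));
}
```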
diff --git a/src/daft-csv/src/read.rs b/src/daft-csv/src/read.rs
index 7b02d30d2d..3a7dab407e 100644
--- a/src/daft-csv/src/read.rs
+++ b/src/daft-csv/src/read.rs
@@ -10,7 +10,7 @@ use csv_async::AsyncReader;
 use daft_core::{schema::Schema, utils::arrow::cast_array_for_daft_if_needed, Series};
 use daft_io::{get_runtime, GetResult, IOClient, IOStatsRef};
 use daft_table::Table;
-use futures::{Stream, StreamExt, TryStreamExt, future};
+use futures::{future, Stream, StreamExt, TryStreamExt};
 use rayon::prelude::{
     IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
 };
@@ -31,14 +31,8 @@ use daft_compression::CompressionCodec;
 use daft_decoding::deserialize::deserialize_column;
 
 trait ByteRecordChunkStream = Stream<Item = super::Result<Vec<ByteRecord>>>;
-trait ColumnArrayChunkStream = Stream<
-    Item = super::Result<
-        Context<
-            JoinHandle<DaftResult<Vec<Box<dyn arrow2::array::Array>>>>,
-            super::JoinSnafu,
-            super::Error,
-        >,
-    >,
+trait ColumnArrayChunkStream<T> = Stream<
+    Item = super::Result<Context<JoinHandle<DaftResult<T>>, super::JoinSnafu, super::Error>>,
 >;
 
 #[allow(clippy::too_many_arguments)]
@@ -136,6 +130,13 @@ pub fn read_csv_bulk(
     tables.into_iter().collect::<DaftResult<Vec<_>>>()
 }
 
+#[inline]
+fn assert_stream_send<'u, R>(
+    s: impl 'u + Send + Stream<Item = R>,
+) -> impl 'u + Send + Stream<Item = R> {
+    s
+}
+
 async fn read_csv_single_into_table(
     uri: &str,
     convert_options: Option<CsvConvertOptions>,
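Note on `assert_stream_send` just above: it is an identity function whose only job is to make the compiler prove, at the point of construction, that the composed stream is `Send` (and to keep that bound visible through `impl Trait`), instead of surfacing a confusing error later at the spawn or `.await` site. The snippet below is not from the patch; it is a self-contained sketch, assuming only the `futures` crate, that shows the helper in isolation.

```rust
use futures::{executor::block_on, stream, Stream, StreamExt};

// Identity helper with the same shape as the one added in read.rs:
// it compiles only if the argument stream is Send, and it re-exports
// that Send bound through the returned `impl Trait`.
fn assert_stream_send<'a, R>(
    s: impl 'a + Send + Stream<Item = R>,
) -> impl 'a + Send + Stream<Item = R> {
    s
}

fn main() {
    // A stream over owned data is Send, so this wraps without complaint;
    // a stream capturing an Rc or other non-Send state would fail to compile here.
    let doubled = assert_stream_send(stream::iter(vec![1, 2, 3]).map(|x| x * 2));
    let collected: Vec<i32> = block_on(doubled.collect());
    assert_eq!(collected, vec![2, 4, 6]);
}
```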
@@ -148,11 +149,15 @@
     let include_columns = convert_options
         .as_ref()
         .and_then(|opts| opts.include_columns.clone());
+    let predicate = convert_options
+        .as_ref()
+        .and_then(|opts| opts.predicate.clone());
 
-    let mut remaining_rows = convert_options
+    let limit = convert_options
         .as_ref()
         .and_then(|opts| opts.limit.map(|limit| limit as i64));
-    let (chunk_stream, fields) = read_csv_single_into_stream(
+    let mut remaining_rows = limit;
+    let (chunk_stream, mut fields) = read_csv_single_into_stream(
         uri,
         convert_options.unwrap_or_default(),
         parse_options.unwrap_or_default(),
@@ -172,9 +177,40 @@
             .unwrap()
     });
     // Collect all chunks in chunk x column form.
-    let chunks = chunk_stream
+    let tables = chunk_stream
         // Limit the number of chunks we have in flight at any given time.
-        .try_buffered(max_chunks_in_flight)
+        .try_buffered(max_chunks_in_flight);
+    if let Some(include_columns) = include_columns {
+        let field_map = fields
+            .into_iter()
+            .map(|field| (field.name.clone(), field))
+            .collect::<HashMap<String, _>>();
+        fields = include_columns
+            .into_iter()
+            .map(|col| field_map[&col].clone())
+            .collect::<Vec<_>>();
+    }
+    let schema: arrow2::datatypes::Schema = fields.clone().into();
+    let daft_schema = Arc::new(Schema::try_from(&schema)?);
+    let owned_daft_schema = daft_schema.clone();
+
+    let filtered_tables = assert_stream_send(tables.map_ok(move |result| {
+        let arrow_chunk = result?;
+        let columns_series = arrow_chunk
+            .into_par_iter()
+            .zip(&fields)
+            .map(|(array, field)| {
+                Series::try_from((field.name.as_ref(), cast_array_for_daft_if_needed(array)))
+            })
+            .collect::<DaftResult<Vec<Series>>>()?;
+        let table = Table::new(owned_daft_schema.clone(), columns_series)?;
+        if let Some(predicate) = &predicate {
+            table.filter(&[predicate.as_ref()])
+        } else {
+            Ok(table)
+        }
+    }));
+    let collected_tables = filtered_tables
         .try_take_while(|result| {
             match (result, remaining_rows) {
                 // Limit has been met, early-terminate.
@@ -194,21 +230,17 @@
         .into_iter()
         .collect::<DaftResult<Vec<_>>>()?;
     // Handle empty table case.
-    if chunks.is_empty() {
-        let schema: arrow2::datatypes::Schema = fields.into();
-        let daft_schema = Arc::new(Schema::try_from(&schema)?);
+    if collected_tables.is_empty() {
         return Table::empty(Some(daft_schema));
     }
-    // Transpose chunk x column into column x chunk.
-    let mut column_arrays = vec![Vec::with_capacity(chunks.len()); chunks[0].len()];
-    for chunk in chunks.into_iter() {
-        for (idx, col) in chunk.into_iter().enumerate() {
-            column_arrays[idx].push(col);
-        }
-    }
-    // Build table from chunks.
     // TODO(Clark): Don't concatenate all chunks from a file into a single table, since MicroPartition is natively chunked.
-    chunks_to_table(column_arrays, include_columns, fields)
+    let concated_table = Table::concat(&collected_tables)?;
+    if let Some(limit) = limit {
+        // apply head in case the last chunk went over the limit
+        concated_table.head(limit as usize)
+    } else {
+        Ok(concated_table)
+    }
 }
 
 async fn read_csv_single_into_stream(
@@ -218,7 +250,10 @@
     read_options: Option<CsvReadOptions>,
     io_client: Arc<IOClient>,
     io_stats: Option<IOStatsRef>,
-) -> DaftResult<(impl ColumnArrayChunkStream + Send, Vec<arrow2::datatypes::Field>)> {
+) -> DaftResult<(
+    impl ColumnArrayChunkStream<Vec<Box<dyn arrow2::array::Array>>> + Send,
+    Vec<arrow2::datatypes::Field>,
+)> {
     let (mut schema, estimated_mean_row_size, estimated_std_row_size) =
         match convert_options.schema {
             Some(schema) => (schema.to_arrow()?, None, None),
@@ -306,12 +341,11 @@
     let projection_indices =
         fields_to_projection_indices(&schema.fields, &convert_options.include_columns);
     let fields = schema.fields;
-    let stream =
-        parse_into_column_array_chunk_stream(
-            read_stream,
-            Arc::new(fields.clone()),
-            projection_indices,
-        );
+    let stream = parse_into_column_array_chunk_stream(
+        read_stream,
+        Arc::new(fields.clone()),
+        projection_indices,
+    );
 
     Ok((stream, fields))
 }
@@ -378,7 +412,7 @@ fn parse_into_column_array_chunk_stream(
     stream: impl ByteRecordChunkStream + Send,
     fields: Arc<Vec<arrow2::datatypes::Field>>,
     projection_indices: Arc<Vec<usize>>,
-) -> impl ColumnArrayChunkStream + Send {
+) -> impl ColumnArrayChunkStream<Vec<Box<dyn arrow2::array::Array>>> + Send {
    // Parsing stream: we spawn background tokio + rayon tasks so we can pipeline chunk parsing with chunk reading, and
    // we further parse each chunk column in parallel on the rayon threadpool.
    stream.map_ok(move |record| {
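Note on the read.rs changes above: each arrow chunk is now turned into a `Table` and filtered by the pushed-down predicate as it streams in, the remaining-row budget is counted against the filtered rows, and a final `head()` trims the overshoot from the last chunk that crossed the limit. The snippet below is not Daft code; it is a self-contained sketch, assuming only the `futures` crate and plain `Vec<i32>` chunks, of that filter / count / trim shape.

```rust
use futures::{executor::block_on, stream, StreamExt};

fn main() {
    let limit = 5usize;
    let chunks = vec![vec![1, 2, 3, 4], vec![5, 6, 7, 8], vec![9, 10, 11, 12]];

    let mut remaining = limit as i64;
    let kept: Vec<Vec<i32>> = block_on(
        stream::iter(chunks)
            // "Predicate": keep even values only, per chunk (stands in for Table::filter).
            .map(|chunk| chunk.into_iter().filter(|x| x % 2 == 0).collect::<Vec<_>>())
            // Stop pulling chunks once the limit is satisfied (stands in for try_take_while);
            // the chunk that crosses the limit is still included, hence the trim below.
            .take_while(|filtered| {
                let take = remaining > 0;
                remaining -= filtered.len() as i64;
                futures::future::ready(take)
            })
            .collect(),
    );

    // Concatenate and apply the limit once more, since the last chunk can overshoot it
    // (this mirrors the final Table::concat + head in the patch).
    let mut rows: Vec<i32> = kept.into_iter().flatten().collect();
    rows.truncate(limit);
    assert_eq!(rows, vec![2, 4, 6, 8, 10]);
}
```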
diff --git a/src/daft-micropartition/src/micropartition.rs b/src/daft-micropartition/src/micropartition.rs
index 94d81ed3de..cb2386660f 100644
--- a/src/daft-micropartition/src/micropartition.rs
+++ b/src/daft-micropartition/src/micropartition.rs
@@ -155,6 +155,7 @@ fn materialize_scan_task(
                         .as_ref()
                         .map(|cols| cols.iter().map(|col| col.to_string()).collect()),
                     None,
+                    scan_task.pushdowns.filters.clone(),
                 );
                 let parse_options = CsvParseOptions::new_with_defaults(
                     cfg.has_headers,
diff --git a/src/daft-scan/src/glob.rs b/src/daft-scan/src/glob.rs
index 0175ef003c..80c4200910 100644
--- a/src/daft-scan/src/glob.rs
+++ b/src/daft-scan/src/glob.rs
@@ -65,7 +65,6 @@ fn run_glob(
     io_stats: Option<IOStatsRef>,
 ) -> DaftResult {
     let (_, parsed_glob_path) = parse_url(glob_path)?;
-    // Construct a static-lifetime BoxStream returning the FileMetadata
     let glob_input = parsed_glob_path.as_ref().to_string();
     let runtime_handle = runtime.handle();
 
diff --git a/src/daft-table/src/ffi.rs b/src/daft-table/src/ffi.rs
index 0e0e80f669..c488dcb1fc 100644
--- a/src/daft-table/src/ffi.rs
+++ b/src/daft-table/src/ffi.rs
@@ -52,9 +52,7 @@ pub fn record_batches_to_table(
                 .collect::<DaftResult<Vec<Series>>>()?;
             tables.push(Table::from_columns(columns)?)
         }
-        Ok(Table::concat(
-            tables.iter().collect::<Vec<_>>().as_slice(),
-        )?)
+        Ok(Table::concat(tables.as_slice())?)
     })
 }
 
diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs
index feaeaaa0e8..154e53d991 100644
--- a/src/daft-table/src/lib.rs
+++ b/src/daft-table/src/lib.rs
@@ -216,19 +216,19 @@ impl Table {
         Ok(Table::new(self.schema.clone(), new_series?).unwrap())
     }
 
-    pub fn concat(tables: &[&Table]) -> DaftResult<Self> {
+    pub fn concat<T: AsRef<Table>>(tables: &[T]) -> DaftResult<Self> {
         if tables.is_empty() {
             return Err(DaftError::ValueError(
                 "Need at least 1 Table to perform concat".to_string(),
             ));
         }
         if tables.len() == 1 {
-            return Ok((*tables.first().unwrap()).clone());
+            return Ok((*tables.first().unwrap().as_ref()).clone());
         }
-        let first_table = tables.first().unwrap();
+        let first_table = tables.first().unwrap().as_ref();
 
         let first_schema = first_table.schema.as_ref();
-        for tab in tables.iter().skip(1) {
+        for tab in tables.iter().skip(1).map(|t| t.as_ref()) {
             if tab.schema.as_ref() != first_schema {
                 return Err(DaftError::SchemaMismatch(format!(
                     "Table concat requires all schemas to match, {} vs {}",
@@ -241,7 +241,7 @@ impl Table {
         for i in 0..num_columns {
             let series_to_cat: Vec<&Series> = tables
                 .iter()
-                .map(|s| s.get_column_by_index(i).unwrap())
+                .map(|s| s.as_ref().get_column_by_index(i).unwrap())
                 .collect();
             new_series.push(Series::concat(series_to_cat.as_slice())?);
         }
@@ -496,6 +496,12 @@ impl Display for Table {
     }
 }
 
+impl AsRef<Table> for Table {
+    fn as_ref(&self) -> &Table {
+        self
+    }
+}
+
 #[cfg(test)]
 mod test {
 
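Note on the Table::concat change above: making `concat` generic over `AsRef<Table>` (together with the new `AsRef` impls for `Table` and `PyTable`) lets one signature accept `&[Table]`, `&[&Table]`, and `&[PyTable]` alike, which is why the ffi.rs call site can drop its intermediate `Vec<&Table>`. The snippet below is not Daft code; it is a self-contained sketch of that pattern with hypothetical `Doc`/`PyDoc` types.

```rust
// Hypothetical stand-ins for Table and PyTable; not the real Daft types.
#[derive(Debug, Clone)]
struct Doc {
    text: String,
}

struct PyDoc {
    doc: Doc,
}

impl AsRef<Doc> for Doc {
    fn as_ref(&self) -> &Doc {
        self
    }
}

impl AsRef<Doc> for PyDoc {
    fn as_ref(&self) -> &Doc {
        &self.doc
    }
}

// Same shape as the new `pub fn concat<T: AsRef<Table>>(tables: &[T])`.
fn concat<T: AsRef<Doc>>(docs: &[T]) -> Doc {
    let text = docs
        .iter()
        .map(|d| d.as_ref().text.as_str())
        .collect::<Vec<_>>()
        .join(" ");
    Doc { text }
}

fn main() {
    let a = Doc { text: "hello".into() };
    let b = Doc { text: "world".into() };

    // Owned slice, reference slice, and wrapper slice all work with one signature;
    // the reference case relies on the blanket `impl AsRef<U> for &T`.
    assert_eq!(concat(&[a.clone(), b.clone()]).text, "hello world");
    assert_eq!(concat(&[&a, &b]).text, "hello world");
    assert_eq!(concat(&[PyDoc { doc: a }, PyDoc { doc: b }]).text, "hello world");
}
```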
diff --git a/src/daft-table/src/python.rs b/src/daft-table/src/python.rs
index 7c2900d52b..596695967d 100644
--- a/src/daft-table/src/python.rs
+++ b/src/daft-table/src/python.rs
@@ -354,6 +354,12 @@ impl From<PyTable> for Table {
     }
 }
 
+impl AsRef<Table> for PyTable {
+    fn as_ref(&self) -> &Table {
+        &self.table
+    }
+}
+
 pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {
     parent.add_class::<PyTable>()?;
     Ok(())
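One more note, tying back to the CsvConvertOptions pymethods change earlier in the patch: because every keyword in the `#[pyo3(signature = ...)]` defaults to `None`, appending `predicate=None` keeps existing Python call sites working unchanged. The snippet below is not Daft code; it is a minimal pyo3 sketch of that pattern with a hypothetical `Options` class, using a plain `String` as a stand-in for the real expression type.

```rust
use pyo3::prelude::*;

// Hypothetical options class, not the real CsvConvertOptions.
#[pyclass]
struct Options {
    limit: Option<usize>,
    predicate: Option<String>,
}

#[pymethods]
impl Options {
    // Every keyword defaults to None, so adding `predicate=None` at the end
    // is backward compatible with callers that only pass `limit`.
    #[new]
    #[pyo3(signature = (limit=None, predicate=None))]
    fn new(limit: Option<usize>, predicate: Option<String>) -> Self {
        Self { limit, predicate }
    }
}

fn main() {
    // Embedding the interpreter just to exercise the constructor from Rust.
    pyo3::prepare_freethreaded_python();
    Python::with_gil(|py| {
        let opts = Py::new(py, Options::new(Some(10), None)).unwrap();
        assert!(opts.borrow(py).predicate.is_none());
        assert_eq!(opts.borrow(py).limit, Some(10));
    });
}
```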