diff --git a/src/daft-csv/src/read.rs b/src/daft-csv/src/read.rs
index bc7a9b9739..1ac5809d04 100644
--- a/src/daft-csv/src/read.rs
+++ b/src/daft-csv/src/read.rs
@@ -7,11 +7,7 @@ use arrow2::{
 use async_compat::{Compat, CompatExt};
 use common_error::{DaftError, DaftResult};
 use csv_async::AsyncReader;
-use daft_core::{
-    schema::{Schema, SchemaRef},
-    utils::arrow::cast_array_for_daft_if_needed,
-    Series,
-};
+use daft_core::{schema::Schema, utils::arrow::cast_array_for_daft_if_needed, Series};
 use daft_dsl::optimization::get_required_columns;
 use daft_io::{get_runtime, GetResult, IOClient, IOStatsRef};
 use daft_table::Table;
@@ -155,16 +151,16 @@ fn assert_stream_send<'u, R>(
 
 // Parallel version of table concat
 // get rid of this once Table APIs are parallel
-fn parallel_table_concat(tables: &[Table]) -> DaftResult<Table> {
+fn tables_concat(mut tables: Vec<Table>) -> DaftResult<Table> {
     if tables.is_empty() {
         return Err(DaftError::ValueError(
             "Need at least 1 Table to perform concat".to_string(),
         ));
     }
     if tables.len() == 1 {
-        return Ok((*tables.first().unwrap()).clone());
+        return Ok(tables.pop().unwrap());
     }
-    let first_table = tables.first().unwrap();
+    let first_table = tables.pop().unwrap();
 
     let first_schema = first_table.schema.as_ref();
     for tab in tables.iter().skip(1) {
@@ -186,7 +182,7 @@ fn parallel_table_concat(tables: &[Table]) -> DaftResult<Table> {
             Series::concat(series_to_cat.as_slice())
         })
         .collect::<DaftResult<Vec<_>>>()?;
-    Table::new(first_table.schema.clone(), new_series)
+    Table::new(first_table.schema, new_series)
 }
 
 async fn read_csv_single_into_table(
@@ -204,7 +200,6 @@ async fn read_csv_single_into_table(
 
     let limit = convert_options.as_ref().and_then(|opts| opts.limit);
 
-    let mut remaining_rows = limit.map(|limit| limit as i64);
     let include_columns = convert_options
         .as_ref()
         .and_then(|opts| opts.include_columns.clone());
@@ -255,17 +250,17 @@ async fn read_csv_single_into_table(
     let schema_fields = if let Some(include_columns) = &include_columns {
         let field_map = fields
             .iter()
-            .map(|field| (field.name.clone(), field.clone()))
-            .collect::<HashMap<String, Field>>();
+            .map(|field| (field.name.as_str(), field))
+            .collect::<HashMap<&str, &Field>>();
         include_columns
             .iter()
-            .map(|col| field_map[col].clone())
+            .map(|col| field_map[col.as_str()].clone())
             .collect::<Vec<_>>()
     } else {
         fields
     };
 
-    let schema: arrow2::datatypes::Schema = schema_fields.clone().into();
+    let schema: arrow2::datatypes::Schema = schema_fields.into();
     let schema = Arc::new(Schema::try_from(&schema)?);
 
     let filtered_tables = assert_stream_send(tables.map_ok(move |table| {
@@ -280,6 +275,7 @@ async fn read_csv_single_into_table(
             table
         }
     }));
+    let mut remaining_rows = limit.map(|limit| limit as i64);
     let collected_tables = filtered_tables
         .try_take_while(|result| {
             match (result, remaining_rows) {
@@ -304,7 +300,7 @@ async fn read_csv_single_into_table(
         return Table::empty(Some(schema));
     }
     // // TODO(Clark): Don't concatenate all chunks from a file into a single table, since MicroPartition is natively chunked.
-    let concated_table = parallel_table_concat(&collected_tables)?;
+    let concated_table = tables_concat(collected_tables)?;
     if let Some(limit) = limit && concated_table.len() > limit {
         // apply head incase that last chunk went over limit
         concated_table.head(limit)
@@ -409,17 +405,11 @@ async fn read_csv_single_into_stream(
         fields_to_projection_indices(&schema.fields, &convert_options.include_columns);
 
     let fields = schema.fields;
-    let fields_subset = projection_indices
-        .iter()
-        .map(|i| fields.get(*i).unwrap().into())
-        .collect::<Vec<daft_core::datatypes::Field>>();
-    let read_schema = daft_core::schema::Schema::new(fields_subset)?;
     let stream = parse_into_column_array_chunk_stream(
         read_stream,
         Arc::new(fields.clone()),
         projection_indices,
-        Arc::new(read_schema),
-    );
+    )?;
 
     Ok((stream, fields))
 }
@@ -486,10 +476,15 @@ fn parse_into_column_array_chunk_stream(
     stream: impl ByteRecordChunkStream + Send,
     fields: Arc<Vec<arrow2::datatypes::Field>>,
     projection_indices: Arc<Vec<usize>>,
-    read_schema: SchemaRef,
-) -> impl TableStream + Send {
+) -> DaftResult<impl TableStream + Send> {
     // Parsing stream: we spawn background tokio + rayon tasks so we can pipeline chunk parsing with chunk reading, and
     // we further parse each chunk column in parallel on the rayon threadpool.
+
+    let fields_subset = projection_indices
+        .iter()
+        .map(|i| fields.get(*i).unwrap().into())
+        .collect::<Vec<daft_core::datatypes::Field>>();
+    let read_schema = Arc::new(daft_core::schema::Schema::new(fields_subset)?);
     let read_daft_fields = Arc::new(
         read_schema
             .fields
@@ -497,7 +492,8 @@ fn parse_into_column_array_chunk_stream(
             .map(|f| Arc::new(f.clone()))
             .collect::<Vec<_>>(),
     );
-    stream.map_ok(move |record| {
+
+    Ok(stream.map_ok(move |record| {
         let (fields, projection_indices) = (fields.clone(), projection_indices.clone());
         let read_schema = read_schema.clone();
         let read_daft_fields = read_daft_fields.clone();
@@ -528,7 +524,7 @@ fn parse_into_column_array_chunk_stream(
             recv.await.context(super::OneShotRecvSnafu {})?
         })
         .context(super::JoinSnafu {})
-    })
+    }))
 }
 
 fn fields_to_projection_indices(
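
Note on the `tables_concat` signature change above: taking the chunk list by value (`Vec<Table>` rather than `&[Table]`) lets the single-table fast path move out the table it already owns instead of cloning it, and it is also why the final `Table::new(first_table.schema, new_series)` call can drop the `.clone()`. Below is a minimal sketch of that ownership pattern, using a hypothetical `Chunk` type rather than Daft's `Table`:

```rust
// Hypothetical stand-in for a columnar table chunk; not Daft's `Table`.
#[derive(Debug, Clone, PartialEq)]
struct Chunk {
    rows: Vec<i64>,
}

// Taking `Vec<Chunk>` by value means the single-chunk case can return the
// chunk it already owns via `pop()` instead of cloning it.
fn chunks_concat(mut chunks: Vec<Chunk>) -> Result<Chunk, String> {
    if chunks.is_empty() {
        return Err("Need at least 1 chunk to perform concat".to_string());
    }
    if chunks.len() == 1 {
        // Move the only chunk out of the Vec; no clone required.
        return Ok(chunks.pop().unwrap());
    }
    // General case: fold every chunk's rows into one buffer, consuming the inputs.
    let total: usize = chunks.iter().map(|c| c.rows.len()).sum();
    let mut rows = Vec::with_capacity(total);
    for chunk in chunks {
        rows.extend(chunk.rows);
    }
    Ok(Chunk { rows })
}

fn main() {
    let single = vec![Chunk { rows: vec![1, 2, 3] }];
    assert_eq!(chunks_concat(single).unwrap().rows, vec![1, 2, 3]);

    let many = vec![Chunk { rows: vec![1, 2] }, Chunk { rows: vec![3, 4] }];
    assert_eq!(chunks_concat(many).unwrap().rows, vec![1, 2, 3, 4]);
}
```

Moving the schema-subset construction into `parse_into_column_array_chunk_stream` follows a similar idea: derived state is built where it is consumed, so callers only pass the raw fields and projection indices.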