diff --git a/src/daft-csv/src/read.rs b/src/daft-csv/src/read.rs
index bc7a9b9739..1ac5809d04 100644
--- a/src/daft-csv/src/read.rs
+++ b/src/daft-csv/src/read.rs
@@ -7,11 +7,7 @@ use arrow2::{
use async_compat::{Compat, CompatExt};
use common_error::{DaftError, DaftResult};
use csv_async::AsyncReader;
-use daft_core::{
- schema::{Schema, SchemaRef},
- utils::arrow::cast_array_for_daft_if_needed,
- Series,
-};
+use daft_core::{schema::Schema, utils::arrow::cast_array_for_daft_if_needed, Series};
use daft_dsl::optimization::get_required_columns;
use daft_io::{get_runtime, GetResult, IOClient, IOStatsRef};
use daft_table::Table;
@@ -155,16 +151,16 @@ fn assert_stream_send<'u, R>(
// Parallel version of table concat
// get rid of this once Table APIs are parallel
-fn parallel_table_concat(tables: &[Table]) -> DaftResult<Table> {
+fn tables_concat(mut tables: Vec<Table>) -> DaftResult<Table> {
if tables.is_empty() {
return Err(DaftError::ValueError(
"Need at least 1 Table to perform concat".to_string(),
));
}
if tables.len() == 1 {
- return Ok((*tables.first().unwrap()).clone());
+ return Ok(tables.pop().unwrap());
}
- let first_table = tables.first().unwrap();
+ let first_table = tables.pop().unwrap();
let first_schema = first_table.schema.as_ref();
for tab in tables.iter().skip(1) {
@@ -186,7 +182,7 @@ fn parallel_table_concat(tables: &[Table]) -> DaftResult<Table> {
Series::concat(series_to_cat.as_slice())
})
        .collect::<DaftResult<Vec<_>>>()?;
- Table::new(first_table.schema.clone(), new_series)
+ Table::new(first_table.schema, new_series)
}
async fn read_csv_single_into_table(
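For reference, the per-column rayon fan-out that `tables_concat` relies on looks roughly like the sketch below. `Column`, `MiniTable`, and `mini_tables_concat` are illustrative stand-ins, not daft types; the point is that each output column is stitched together by its own rayon task.

```rust
use rayon::prelude::*;

// Toy stand-ins: a "table" is a Vec of columns, a column is a Vec of i64.
type Column = Vec<i64>;
type MiniTable = Vec<Column>;

// Concatenate tables column-by-column, one rayon task per output column,
// mirroring the shape of `tables_concat` above.
fn mini_tables_concat(tables: Vec<MiniTable>) -> Option<MiniTable> {
    let num_columns = tables.first()?.len();
    let new_columns = (0..num_columns)
        .into_par_iter()
        .map(|i| {
            // Gather column `i` from every table, preserving table order.
            tables
                .iter()
                .flat_map(|table| table[i].iter().copied())
                .collect::<Column>()
        })
        .collect::<Vec<_>>();
    Some(new_columns)
}

fn main() {
    let a = vec![vec![1, 2], vec![10, 20]];
    let b = vec![vec![3], vec![30]];
    let out = mini_tables_concat(vec![a, b]).unwrap();
    assert_eq!(out, vec![vec![1, 2, 3], vec![10, 20, 30]]);
}
```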
@@ -204,7 +200,6 @@ async fn read_csv_single_into_table(
let limit = convert_options.as_ref().and_then(|opts| opts.limit);
- let mut remaining_rows = limit.map(|limit| limit as i64);
let include_columns = convert_options
.as_ref()
.and_then(|opts| opts.include_columns.clone());
@@ -255,17 +250,17 @@ async fn read_csv_single_into_table(
let schema_fields = if let Some(include_columns) = &include_columns {
let field_map = fields
.iter()
- .map(|field| (field.name.clone(), field.clone()))
-        .collect::<HashMap<_, _>>();
+ .map(|field| (field.name.as_str(), field))
+        .collect::<HashMap<_, _>>();
include_columns
.iter()
- .map(|col| field_map[col].clone())
+ .map(|col| field_map[col.as_str()].clone())
            .collect::<Vec<_>>()
} else {
fields
};
- let schema: arrow2::datatypes::Schema = schema_fields.clone().into();
+ let schema: arrow2::datatypes::Schema = schema_fields.into();
let schema = Arc::new(Schema::try_from(&schema)?);
let filtered_tables = assert_stream_send(tables.map_ok(move |table| {
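The projection change above swaps a `String`-keyed map of cloned fields for a map that borrows `&str` keys and field references, cloning only the selected fields. A self-contained sketch of that pattern, with a hypothetical `FieldStub` standing in for `arrow2::datatypes::Field`:

```rust
use std::collections::HashMap;

// Hypothetical stand-in for arrow2::datatypes::Field: just a name and a type tag.
#[derive(Clone, Debug, PartialEq)]
struct FieldStub {
    name: String,
    dtype: String,
}

// Select `include_columns` (in the requested order) out of `fields`. The lookup
// map borrows &str keys and &FieldStub values, so only the fields that are
// actually selected get cloned, which is the same idea as the `field_map` change above.
fn project_fields(fields: &[FieldStub], include_columns: &[String]) -> Vec<FieldStub> {
    let field_map = fields
        .iter()
        .map(|field| (field.name.as_str(), field))
        .collect::<HashMap<_, _>>();
    include_columns
        .iter()
        .map(|col| field_map[col.as_str()].clone())
        .collect::<Vec<_>>()
}

fn main() {
    let fields = vec![
        FieldStub { name: "a".into(), dtype: "int64".into() },
        FieldStub { name: "b".into(), dtype: "utf8".into() },
    ];
    let projected = project_fields(&fields, &["b".to_string()]);
    assert_eq!(projected, vec![fields[1].clone()]);
}
```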
@@ -280,6 +275,7 @@ async fn read_csv_single_into_table(
table
}
}));
+ let mut remaining_rows = limit.map(|limit| limit as i64);
let collected_tables = filtered_tables
.try_take_while(|result| {
match (result, remaining_rows) {
@@ -304,7 +300,7 @@ async fn read_csv_single_into_table(
return Table::empty(Some(schema));
}
    // TODO(Clark): Don't concatenate all chunks from a file into a single table, since MicroPartition is natively chunked.
- let concated_table = parallel_table_concat(&collected_tables)?;
+ let concated_table = tables_concat(collected_tables)?;
if let Some(limit) = limit && concated_table.len() > limit {
        // apply head in case the last chunk went over the limit
concated_table.head(limit)
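The limit handling amounts to: charge each chunk against a signed row budget inside `try_take_while`, stop pulling chunks once the budget is spent, and trim any overshoot after concatenation (the `head(limit)` above). A minimal sketch of that pattern, with plain `Vec` chunks in place of `Table`s and `truncate` in place of `head`:

```rust
use futures::{stream, TryStreamExt};

// Charge each chunk against a signed row budget while streaming, then trim the
// overshoot once everything is concatenated.
#[tokio::main]
async fn main() -> Result<(), std::convert::Infallible> {
    let limit = 5usize;
    let chunks = stream::iter(vec![
        Ok::<Vec<i64>, std::convert::Infallible>(vec![1, 2, 3]),
        Ok(vec![4, 5, 6]),
        Ok(vec![7, 8, 9]), // never pulled: the budget runs out before this chunk
    ]);

    let mut remaining_rows = Some(limit as i64);
    let collected: Vec<Vec<i64>> = chunks
        .try_take_while(|chunk| {
            let keep = match remaining_rows {
                // No limit was requested: keep every chunk.
                None => true,
                // Budget already spent: stop the stream.
                Some(rows_left) if rows_left <= 0 => false,
                // Budget remains: take the chunk and charge its rows; a negative
                // result means this chunk overshot and must be trimmed later.
                Some(rows_left) => {
                    remaining_rows = Some(rows_left - chunk.len() as i64);
                    true
                }
            };
            futures::future::ready(Ok(keep))
        })
        .try_collect()
        .await?;

    let mut rows: Vec<i64> = collected.into_iter().flatten().collect();
    rows.truncate(limit); // "head" in case the last chunk went over the limit
    assert_eq!(rows, vec![1, 2, 3, 4, 5]);
    Ok(())
}
```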
@@ -409,17 +405,11 @@ async fn read_csv_single_into_stream(
fields_to_projection_indices(&schema.fields, &convert_options.include_columns);
let fields = schema.fields;
- let fields_subset = projection_indices
- .iter()
- .map(|i| fields.get(*i).unwrap().into())
-        .collect::<Vec<daft_core::datatypes::Field>>();
- let read_schema = daft_core::schema::Schema::new(fields_subset)?;
let stream = parse_into_column_array_chunk_stream(
read_stream,
Arc::new(fields.clone()),
projection_indices,
- Arc::new(read_schema),
- );
+ )?;
Ok((stream, fields))
}
@@ -486,10 +476,15 @@ fn parse_into_column_array_chunk_stream(
stream: impl ByteRecordChunkStream + Send,
    fields: Arc<Vec<arrow2::datatypes::Field>>,
    projection_indices: Arc<Vec<usize>>,
- read_schema: SchemaRef,
-) -> impl TableStream + Send {
+) -> DaftResult<impl TableStream + Send> {
// Parsing stream: we spawn background tokio + rayon tasks so we can pipeline chunk parsing with chunk reading, and
// we further parse each chunk column in parallel on the rayon threadpool.
+
+ let fields_subset = projection_indices
+ .iter()
+ .map(|i| fields.get(*i).unwrap().into())
+        .collect::<Vec<daft_core::datatypes::Field>>();
+ let read_schema = Arc::new(daft_core::schema::Schema::new(fields_subset)?);
let read_daft_fields = Arc::new(
read_schema
.fields
@@ -497,7 +492,8 @@ fn parse_into_column_array_chunk_stream(
.map(|f| Arc::new(f.clone()))
            .collect::<Vec<_>>(),
);
- stream.map_ok(move |record| {
+
+ Ok(stream.map_ok(move |record| {
let (fields, projection_indices) = (fields.clone(), projection_indices.clone());
let read_schema = read_schema.clone();
let read_daft_fields = read_daft_fields.clone();
@@ -528,7 +524,7 @@ fn parse_into_column_array_chunk_stream(
recv.await.context(super::OneShotRecvSnafu {})?
})
.context(super::JoinSnafu {})
- })
+ }))
}
fn fields_to_projection_indices(
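The comment in `parse_into_column_array_chunk_stream` describes the pipelining scheme: the byte-record stream is consumed on tokio, each chunk is parsed on the rayon threadpool, and a oneshot channel carries the parsed columns back to the async side. A stripped-down sketch of that tokio + rayon + oneshot shape (illustrative only; `buffered(4)` and the `expect` calls replace the snafu error contexts used in the real code):

```rust
use futures::{stream, StreamExt};

// Read chunks on tokio, parse each one on the rayon pool, and hand the parsed
// result back through a oneshot channel so reading and parsing overlap.
#[tokio::main]
async fn main() {
    let chunks = stream::iter(vec!["1,2,3", "4,5", "6,7,8,9"]);

    let parsed: Vec<Vec<i64>> = chunks
        .map(|raw| {
            let raw = raw.to_string();
            tokio::spawn(async move {
                let (send, recv) = tokio::sync::oneshot::channel();
                // CPU-bound parsing runs on rayon so the tokio reactor is never blocked.
                rayon::spawn(move || {
                    let values = raw
                        .split(',')
                        .map(|v| v.parse::<i64>().unwrap())
                        .collect::<Vec<_>>();
                    let _ = send.send(values);
                });
                recv.await.expect("rayon task dropped the sender")
            })
        })
        .buffered(4) // pipeline: up to 4 chunks being parsed while more are read
        .map(|join_result| join_result.expect("tokio task panicked"))
        .collect()
        .await;

    assert_eq!(parsed, vec![vec![1, 2, 3], vec![4, 5], vec![6, 7, 8, 9]]);
}
```

Note that `buffered` yields results in the original chunk order, which is what lets the limit logic earlier in the file assume rows arrive in file order even though parsing happens concurrently.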