first pass predicate apply in csv
samster25 committed Dec 11, 2023
1 parent 6e5ab1b commit 8651316
Showing 7 changed files with 100 additions and 44 deletions.
16 changes: 14 additions & 2 deletions src/daft-csv/src/options.rs
@@ -1,8 +1,10 @@
use daft_core::{impl_bincode_py_state_serialization, schema::SchemaRef};
use daft_dsl::{Expr, ExprRef};
use serde::{Deserialize, Serialize};
#[cfg(feature = "python")]
use {
daft_core::python::schema::PySchema,
daft_dsl::python::PyExpr,
pyo3::{
pyclass, pyclass::CompareOp, pymethods, types::PyBytes, PyObject, PyResult, PyTypeInfo,
Python, ToPyObject,
@@ -17,6 +19,7 @@ pub struct CsvConvertOptions {
pub include_columns: Option<Vec<String>>,
pub column_names: Option<Vec<String>>,
pub schema: Option<SchemaRef>,
pub predicate: Option<ExprRef>,
}

impl CsvConvertOptions {
@@ -25,12 +28,14 @@ impl CsvConvertOptions {
include_columns: Option<Vec<String>>,
column_names: Option<Vec<String>>,
schema: Option<SchemaRef>,
predicate: Option<ExprRef>,
) -> Self {
Self {
limit,
include_columns,
column_names,
schema,
predicate,
}
}

@@ -40,6 +45,7 @@ impl CsvConvertOptions {
include_columns: self.include_columns,
column_names: self.column_names,
schema: self.schema,
predicate: self.predicate,
}
}

@@ -49,6 +55,7 @@ impl CsvConvertOptions {
include_columns,
column_names: self.column_names,
schema: self.schema,
predicate: self.predicate,
}
}

@@ -58,6 +65,7 @@ impl CsvConvertOptions {
include_columns: self.include_columns,
column_names,
schema: self.schema,
predicate: self.predicate,
}
}

@@ -67,13 +75,14 @@ impl CsvConvertOptions {
include_columns: self.include_columns,
column_names: self.column_names,
schema,
predicate: self.predicate,
}
}
}

impl Default for CsvConvertOptions {
fn default() -> Self {
Self::new_internal(None, None, None, None)
Self::new_internal(None, None, None, None, None)
}
}

@@ -88,19 +97,22 @@ impl CsvConvertOptions {
/// * `include_columns` - The names of the columns that should be kept, e.g. via a projection.
/// * `column_names` - The names for the CSV columns.
/// * `schema` - The names and dtypes for the CSV columns.
/// * `predicate` - Expression used to filter rows, applied before the limit.
#[new]
#[pyo3(signature = (limit=None, include_columns=None, column_names=None, schema=None))]
#[pyo3(signature = (limit=None, include_columns=None, column_names=None, schema=None, predicate=None))]
pub fn new(
limit: Option<usize>,
include_columns: Option<Vec<String>>,
column_names: Option<Vec<String>>,
schema: Option<PySchema>,
predicate: Option<PyExpr>,
) -> Self {
Self::new_internal(
limit,
include_columns,
column_names,
schema.map(|s| s.into()),
predicate.map(|p| p.expr.into()),
)
}

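
As a hedged illustration of how the new field is meant to be used, here is a minimal sketch that wires an already-built predicate into the internal constructor. Only the five-argument `new_internal` signature and the `predicate: Option<ExprRef>` field are taken from the diff above; the helper function and the way the caller obtains the `ExprRef` are hypothetical.

use daft_dsl::ExprRef;

// Hypothetical helper: pass a pre-built predicate through CsvConvertOptions.
// Rows matching the predicate are kept, and filtering happens before any limit.
fn csv_options_with_predicate(predicate: ExprRef) -> CsvConvertOptions {
    CsvConvertOptions::new_internal(
        None,            // limit
        None,            // include_columns
        None,            // column_names
        None,            // schema
        Some(predicate), // predicate pushed into the CSV decode path
    )
}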
100 changes: 67 additions & 33 deletions src/daft-csv/src/read.rs
@@ -10,7 +10,7 @@ use csv_async::AsyncReader;
use daft_core::{schema::Schema, utils::arrow::cast_array_for_daft_if_needed, Series};
use daft_io::{get_runtime, GetResult, IOClient, IOStatsRef};
use daft_table::Table;
use futures::{Stream, StreamExt, TryStreamExt, future};
use futures::{future, Stream, StreamExt, TryStreamExt};
use rayon::prelude::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
};
@@ -31,14 +31,8 @@ use daft_compression::CompressionCodec;
use daft_decoding::deserialize::deserialize_column;

trait ByteRecordChunkStream = Stream<Item = super::Result<Vec<ByteRecord>>>;
trait ColumnArrayChunkStream = Stream<
Item = super::Result<
Context<
JoinHandle<DaftResult<Vec<Box<dyn arrow2::array::Array>>>>,
super::JoinSnafu,
super::Error,
>,
>,
trait ColumnArrayChunkStream<T> = Stream<
Item = super::Result<Context<JoinHandle<DaftResult<T>>, super::JoinSnafu, super::Error>>,
>;

#[allow(clippy::too_many_arguments)]
@@ -136,6 +130,13 @@ pub fn read_csv_bulk(
tables.into_iter().collect::<DaftResult<Vec<_>>>()
}

#[inline]
fn assert_stream_send<'u, R>(
s: impl 'u + Send + Stream<Item = R>,
) -> impl 'u + Send + Stream<Item = R> {
s
}

async fn read_csv_single_into_table(
uri: &str,
convert_options: Option<CsvConvertOptions>,
@@ -148,11 +149,15 @@ async fn read_csv_single_into_table(
let include_columns = convert_options
.as_ref()
.and_then(|opts| opts.include_columns.clone());
let predicate = convert_options
.as_ref()
.and_then(|opts| opts.predicate.clone());

let mut remaining_rows = convert_options
let limit = convert_options
.as_ref()
.and_then(|opts| opts.limit.map(|limit| limit as i64));
let (chunk_stream, fields) = read_csv_single_into_stream(
let mut remaining_rows = limit;
let (chunk_stream, mut fields) = read_csv_single_into_stream(
uri,
convert_options.unwrap_or_default(),
parse_options.unwrap_or_default(),
.unwrap()
});
// Collect all chunks in chunk x column form.
let chunks = chunk_stream
let tables = chunk_stream
// Limit the number of chunks we have in flight at any given time.
.try_buffered(max_chunks_in_flight)
.try_buffered(max_chunks_in_flight);
if let Some(include_columns) = include_columns {
let field_map = fields
.into_iter()
.map(|field| (field.name.clone(), field))
.collect::<HashMap<String, Field>>();
fields = include_columns
.into_iter()
.map(|col| field_map[&col].clone())
.collect::<Vec<_>>();
}
let schema: arrow2::datatypes::Schema = fields.clone().into();
let daft_schema = Arc::new(Schema::try_from(&schema)?);
let owned_daft_schema = daft_schema.clone();

let filtered_tables = assert_stream_send(tables.map_ok(move |result| {
let arrow_chunk = result?;
let columns_series = arrow_chunk
.into_par_iter()
.zip(&fields)
.map(|(array, field)| {
Series::try_from((field.name.as_ref(), cast_array_for_daft_if_needed(array)))
})
.collect::<DaftResult<Vec<Series>>>()?;
let table = Table::new(owned_daft_schema.clone(), columns_series)?;
if let Some(predicate) = &predicate {
table.filter(&[predicate.as_ref()])
} else {
Ok(table)
}
}));
let collected_tables = filtered_tables
.try_take_while(|result| {
match (result, remaining_rows) {
// Limit has been met, early-terminate.
@@ -194,21 +230,17 @@
.into_iter()
.collect::<DaftResult<Vec<_>>>()?;
// Handle empty table case.
if chunks.is_empty() {
let schema: arrow2::datatypes::Schema = fields.into();
let daft_schema = Arc::new(Schema::try_from(&schema)?);
if collected_tables.is_empty() {
return Table::empty(Some(daft_schema));
}
// Transpose chunk x column into column x chunk.
let mut column_arrays = vec![Vec::with_capacity(chunks.len()); chunks[0].len()];
for chunk in chunks.into_iter() {
for (idx, col) in chunk.into_iter().enumerate() {
column_arrays[idx].push(col);
}
}
// Build table from chunks.
// TODO(Clark): Don't concatenate all chunks from a file into a single table, since MicroPartition is natively chunked.
chunks_to_table(column_arrays, include_columns, fields)
let concated_table = Table::concat(&collected_tables)?;
if let Some(limit) = limit {
// Apply head in case the last chunk went over the limit.
concated_table.head(limit as usize)
} else {
Ok(concated_table)
}
}

async fn read_csv_single_into_stream(
@@ -218,7 +250,10 @@ async fn read_csv_single_into_stream(
read_options: Option<CsvReadOptions>,
io_client: Arc<IOClient>,
io_stats: Option<IOStatsRef>,
) -> DaftResult<(impl ColumnArrayChunkStream + Send, Vec<Field>)> {
) -> DaftResult<(
impl ColumnArrayChunkStream<Vec<Box<dyn arrow2::array::Array>>> + Send,
Vec<Field>,
)> {
let (mut schema, estimated_mean_row_size, estimated_std_row_size) = match convert_options.schema
{
Some(schema) => (schema.to_arrow()?, None, None),
@@ -306,12 +341,11 @@ async fn read_csv_single_into_stream(
let projection_indices =
fields_to_projection_indices(&schema.fields, &convert_options.include_columns);
let fields = schema.fields;
let stream =
parse_into_column_array_chunk_stream(
read_stream,
Arc::new(fields.clone()),
projection_indices,
);
let stream = parse_into_column_array_chunk_stream(
read_stream,
Arc::new(fields.clone()),
projection_indices,
);

Ok((stream, fields))
}
@@ -378,7 +412,7 @@ fn parse_into_column_array_chunk_stream(
stream: impl ByteRecordChunkStream + Send,
fields: Arc<Vec<arrow2::datatypes::Field>>,
projection_indices: Arc<Vec<usize>>,
) -> impl ColumnArrayChunkStream + Send {
) -> impl ColumnArrayChunkStream<Vec<Box<dyn arrow2::array::Array>>> + Send {
// Parsing stream: we spawn background tokio + rayon tasks so we can pipeline chunk parsing with chunk reading, and
// we further parse each chunk column in parallel on the rayon threadpool.
stream.map_ok(move |record| {
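
The behavioral core of this file's change is easy to lose in the diff: each decoded chunk is materialized into a Table, filtered by the pushed-down predicate, and only the surviving rows are charged against `remaining_rows`; after concatenation, `head(limit)` trims any overshoot from the final chunk. Below is a self-contained sketch of that filter-then-limit pattern, with plain vectors standing in for Daft tables and a closure standing in for the predicate; it is not the daft-csv API itself.

// Simplified sketch of the filter-then-limit pattern used above.
fn filter_then_limit(
    chunks: Vec<Vec<i64>>,
    predicate: impl Fn(&i64) -> bool,
    limit: Option<usize>,
) -> Vec<i64> {
    let mut remaining = limit.map(|l| l as i64);
    let mut out = Vec::new();
    for chunk in chunks {
        // Early-terminate once the limit has been met.
        if matches!(remaining, Some(r) if r <= 0) {
            break;
        }
        // Filter first, then charge only the surviving rows against the limit.
        let filtered: Vec<i64> = chunk.into_iter().filter(|v| predicate(v)).collect();
        if let Some(r) = remaining.as_mut() {
            *r -= filtered.len() as i64;
        }
        out.extend(filtered);
    }
    // The last chunk may overshoot, so trim to the exact limit afterwards.
    if let Some(l) = limit {
        out.truncate(l);
    }
    out
}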
1 change: 1 addition & 0 deletions src/daft-micropartition/src/micropartition.rs
@@ -155,6 +155,7 @@ fn materialize_scan_task(
.as_ref()
.map(|cols| cols.iter().map(|col| col.to_string()).collect()),
None,
scan_task.pushdowns.filters.clone(),
);
let parse_options = CsvParseOptions::new_with_defaults(
cfg.has_headers,
1 change: 0 additions & 1 deletion src/daft-scan/src/glob.rs
@@ -65,7 +65,6 @@ fn run_glob(
io_stats: Option<IOStatsRef>,
) -> DaftResult<FileInfoIterator> {
let (_, parsed_glob_path) = parse_url(glob_path)?;

// Construct a static-lifetime BoxStream returning the FileMetadata
let glob_input = parsed_glob_path.as_ref().to_string();
let runtime_handle = runtime.handle();
4 changes: 1 addition & 3 deletions src/daft-table/src/ffi.rs
@@ -52,9 +52,7 @@ pub fn record_batches_to_table(
.collect::<DaftResult<Vec<_>>>()?;
tables.push(Table::from_columns(columns)?)
}
Ok(Table::concat(
tables.iter().collect::<Vec<&Table>>().as_slice(),
)?)
Ok(Table::concat(tables.as_slice())?)
})
}

16 changes: 11 additions & 5 deletions src/daft-table/src/lib.rs
@@ -216,19 +216,19 @@ impl Table {
Ok(Table::new(self.schema.clone(), new_series?).unwrap())
}

pub fn concat(tables: &[&Table]) -> DaftResult<Self> {
pub fn concat<T: AsRef<Table>>(tables: &[T]) -> DaftResult<Self> {
if tables.is_empty() {
return Err(DaftError::ValueError(
"Need at least 1 Table to perform concat".to_string(),
));
}
if tables.len() == 1 {
return Ok((*tables.first().unwrap()).clone());
return Ok((*tables.first().unwrap().as_ref()).clone());
}
let first_table = tables.first().unwrap();
let first_table = tables.first().unwrap().as_ref();

let first_schema = first_table.schema.as_ref();
for tab in tables.iter().skip(1) {
for tab in tables.iter().skip(1).map(|t| t.as_ref()) {
if tab.schema.as_ref() != first_schema {
return Err(DaftError::SchemaMismatch(format!(
"Table concat requires all schemas to match, {} vs {}",
Expand All @@ -241,7 +241,7 @@ impl Table {
for i in 0..num_columns {
let series_to_cat: Vec<&Series> = tables
.iter()
.map(|s| s.get_column_by_index(i).unwrap())
.map(|s| s.as_ref().get_column_by_index(i).unwrap())
.collect();
new_series.push(Series::concat(series_to_cat.as_slice())?);
}
@@ -496,6 +496,12 @@ impl Display for Table {
}
}

impl AsRef<Table> for Table {
fn as_ref(&self) -> &Table {
self
}
}

#[cfg(test)]
mod test {

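
The `AsRef<Table>` bound on `Table::concat`, combined with the new `impl AsRef<Table> for Table`, is what lets the ffi.rs call site above pass a `Vec<Table>` directly. A brief sketch of the two call shapes this enables; the wrapper functions are illustrative only.

// Owned tables no longer need to be collected into a Vec<&Table> first;
// this mirrors the simplification made in ffi.rs above.
fn concat_owned(tables: Vec<Table>) -> DaftResult<Table> {
    Table::concat(tables.as_slice())
}

// The previous call shape still compiles: &Table gets AsRef<Table> through
// std's blanket `impl AsRef<U> for &T where T: AsRef<U>`.
fn concat_borrowed(tables: &[&Table]) -> DaftResult<Table> {
    Table::concat(tables)
}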
6 changes: 6 additions & 0 deletions src/daft-table/src/python.rs
@@ -354,6 +354,12 @@ impl From<PyTable> for Table {
}
}

impl AsRef<Table> for PyTable {
fn as_ref(&self) -> &Table {
&self.table
}
}

pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {
parent.add_class::<PyTable>()?;
Ok(())
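
Likewise, giving the Python wrapper an `AsRef<Table>` impl means a slice of `PyTable`s can be concatenated without unwrapping each element. A hypothetical call site, shown only to illustrate the new bound:

// Illustrative only: this function is not part of the commit, but the call
// compiles because PyTable now implements AsRef<Table>.
fn concat_py_tables(py_tables: &[PyTable]) -> DaftResult<Table> {
    Table::concat(py_tables)
}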
