From 7b40f23a5ff83aba4ab059b62ac781d7766be0b1 Mon Sep 17 00:00:00 2001 From: desmondcheongzx Date: Fri, 30 Aug 2024 14:26:31 -0700 Subject: [PATCH 01/10] Prototype --- Cargo.lock | 4 + Makefile | 4 + src/arrow2/src/io/csv/read_async/reader.rs | 25 +- src/daft-csv/Cargo.toml | 4 + src/daft-csv/src/lib.rs | 4 + src/daft-csv/src/read.rs | 591 ++++++++++++++++++++- src/daft-decoding/src/deserialize.rs | 12 +- 7 files changed, 618 insertions(+), 26 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c2472380cc..844617f3a0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1736,6 +1736,7 @@ dependencies = [ "async-stream", "common-error", "common-py-serde", + "crossbeam-channel", "csv-async", "daft-compression", "daft-core", @@ -1744,6 +1745,9 @@ dependencies = [ "daft-io", "daft-table", "futures", + "indexmap 2.3.0", + "memchr", + "memmap2", "pyo3", "rayon", "rstest", diff --git a/Makefile b/Makefile index d1d96e8df6..327f00513c 100644 --- a/Makefile +++ b/Makefile @@ -52,6 +52,10 @@ build: check-toolchain .venv ## Compile and install Daft for development build-release: check-toolchain .venv ## Compile and install a faster Daft binary @unset CONDA_PREFIX && PYO3_PYTHON=$(VENV_BIN)/python $(VENV_BIN)/maturin develop --release +.PHONY: build-bench +build-bench: check-toolchain .venv ## Compile and install a faster Daft binary + @unset CONDA_PREFIX && PYO3_PYTHON=$(VENV_BIN)/python $(VENV_BIN)/maturin develop --profile dev-bench + .PHONY: test test: .venv build ## Run tests HYPOTHESIS_MAX_EXAMPLES=$(HYPOTHESIS_MAX_EXAMPLES) $(VENV_BIN)/pytest --hypothesis-seed=$(HYPOTHESIS_SEED) diff --git a/src/arrow2/src/io/csv/read_async/reader.rs b/src/arrow2/src/io/csv/read_async/reader.rs index 0db0f25268..a95b88e5ad 100644 --- a/src/arrow2/src/io/csv/read_async/reader.rs +++ b/src/arrow2/src/io/csv/read_async/reader.rs @@ -1,10 +1,11 @@ use futures::AsyncRead; use super::{AsyncReader, ByteRecord}; +use crate::io::csv::read; use crate::error::{Error, Result}; -/// Asynchronosly read `len` rows from `reader` into `row`, skipping the first `skip`. +/// Asynchronosly read `rows.len` rows from `reader` into `rows`, skipping the first `skip`. /// This operation has minimal CPU work and is thus the fastest way to read through a CSV /// without deserializing the contents to Arrow. pub async fn read_rows( @@ -37,3 +38,25 @@ where } Ok(row_number) } + +/// Synchronously read `rows.len` rows from `reader` into `rows`. This is used in the local i/o case. +pub fn local_read_rows( + reader: &mut read::Reader, + rows: &mut [read::ByteRecord], +) -> Result<(usize, bool)> +where + R: std::io::Read, +{ + let mut row_number = 0; + let mut has_more = true; + for row in rows.iter_mut() { + has_more = reader + .read_byte_record(row) + .map_err(|e| Error::External(format!(" at line {}", row_number), Box::new(e)))?; + if !has_more { + break; + } + row_number += 1; + } + Ok((row_number, has_more)) +} diff --git a/src/daft-csv/Cargo.toml b/src/daft-csv/Cargo.toml index 4e58223843..d0781e7184 100644 --- a/src/daft-csv/Cargo.toml +++ b/src/daft-csv/Cargo.toml @@ -5,6 +5,7 @@ async-compression = {workspace = true} async-stream = {workspace = true} common-error = {path = "../common/error", default-features = false} common-py-serde = {path = "../common/py-serde", default-features = false} +crossbeam-channel = "0.5.1" csv-async = "1.3.0" daft-compression = {path = "../daft-compression", default-features = false} daft-core = {path = "../daft-core", default-features = false} @@ -13,6 +14,9 @@ daft-dsl = {path = "../daft-dsl", default-features = false} daft-io = {path = "../daft-io", default-features = false} daft-table = {path = "../daft-table", default-features = false} futures = {workspace = true} +indexmap = {workspace = true, features = ["serde"]} +memchr = "2.7.2" +memmap2 = "0.9.4" pyo3 = {workspace = true, optional = true} rayon = {workspace = true} serde = {workspace = true} diff --git a/src/daft-csv/src/lib.rs b/src/daft-csv/src/lib.rs index 59a62ad2d1..35c4d72006 100644 --- a/src/daft-csv/src/lib.rs +++ b/src/daft-csv/src/lib.rs @@ -2,6 +2,8 @@ #![feature(let_chains)] #![feature(trait_alias)] #![feature(trait_upcasting)] +#![feature(test)] +extern crate test; use common_error::DaftError; use snafu::Snafu; @@ -23,6 +25,8 @@ pub enum Error { #[snafu(display("{source}"))] IOError { source: daft_io::Error }, #[snafu(display("{source}"))] + StdIOError { source: std::io::Error }, + #[snafu(display("{source}"))] CSVError { source: csv_async::Error }, #[snafu(display("Invalid char: {}", val))] WrongChar { diff --git a/src/daft-csv/src/read.rs b/src/daft-csv/src/read.rs index 31ab4f84f0..3ed49204b3 100644 --- a/src/daft-csv/src/read.rs +++ b/src/daft-csv/src/read.rs @@ -1,15 +1,20 @@ -use std::{collections::HashMap, num::NonZeroUsize, sync::Arc}; +use std::{collections::HashMap, num::NonZeroUsize, sync::Arc, sync::Mutex}; + +use std::sync::atomic::{AtomicUsize, Ordering}; use arrow2::{ datatypes::Field, - io::csv::read_async::{read_rows, AsyncReaderBuilder, ByteRecord}, + io::csv::read, + io::csv::read::{Reader, ReaderBuilder}, + io::csv::read_async, + io::csv::read_async::{local_read_rows, read_rows, AsyncReaderBuilder}, }; use async_compat::{Compat, CompatExt}; use common_error::{DaftError, DaftResult}; use csv_async::AsyncReader; use daft_core::{schema::Schema, utils::arrow::cast_array_for_daft_if_needed, Series}; -use daft_dsl::optimization::get_required_columns; -use daft_io::{get_runtime, GetResult, IOClient, IOStatsRef}; +use daft_dsl::{optimization::get_required_columns, Expr}; +use daft_io::{get_runtime, parse_url, GetResult, IOClient, IOStatsRef, SourceType}; use daft_table::Table; use futures::{stream::BoxStream, Stream, StreamExt, TryStreamExt}; use rayon::{ @@ -27,13 +32,16 @@ use tokio::{ }; use tokio_util::io::StreamReader; -use crate::ArrowSnafu; use crate::{metadata::read_csv_schema_single, CsvConvertOptions, CsvParseOptions, CsvReadOptions}; +use crate::{ArrowSnafu, StdIOSnafu}; use daft_compression::CompressionCodec; use daft_decoding::deserialize::deserialize_column; -trait ByteRecordChunkStream: Stream>> {} -impl ByteRecordChunkStream for S where S: Stream>> {} +trait ByteRecordChunkStream: Stream>> {} +impl ByteRecordChunkStream for S where + S: Stream>> +{ +} type TableChunkResult = super::Result>, super::JoinSnafu, super::Error>>; @@ -145,22 +153,35 @@ pub async fn stream_csv( io_stats: Option, max_chunks_in_flight: Option, ) -> DaftResult>> { - let stream = stream_csv_single( - &uri, - convert_options, - parse_options, - read_options, - io_client, - io_stats, - max_chunks_in_flight, - ) - .await?; - - Ok(Box::pin(stream)) + let uri = uri.as_str(); + let (source_type, _) = parse_url(uri)?; + let is_compressed = CompressionCodec::from_uri(uri).is_some(); + let use_local_reader = false; // TODO(desmond): Feature under dev. + if matches!(source_type, SourceType::File) && !is_compressed && use_local_reader { + let stream = stream_csv_local( + uri, + convert_options, + parse_options.unwrap_or_default(), + read_options, + max_chunks_in_flight, + ) + .await?; + Ok(Box::pin(stream)) + } else { + let stream = stream_csv_single( + uri, + convert_options, + parse_options, + read_options, + io_client, + io_stats, + max_chunks_in_flight, + ) + .await?; + Ok(Box::pin(stream)) + } } -// Parallel version of table concat -// get rid of this once Table APIs are parallel fn tables_concat(mut tables: Vec) -> DaftResult
{ if tables.is_empty() { return Err(DaftError::ValueError( @@ -199,6 +220,387 @@ fn tables_concat(mut tables: Vec
) -> DaftResult
{ ) } +#[derive(Debug)] +struct CsvBufferPool { + buffers: Mutex>>, + buffer_size: usize, + record_buffer_size: usize, + num_fields: usize, +} + +struct CsvBufferPoolRef<'a> { + pool: &'a CsvBufferPool, + buffer: Vec, +} + +impl CsvBufferPool { + pub fn new( + record_buffer_size: usize, + num_fields: usize, + chunk_size_rows: usize, + initial_pool_size: usize, + ) -> Self { + let chunk_buffers = vec![ + vec![ + read::ByteRecord::with_capacity(record_buffer_size, num_fields); + chunk_size_rows + ]; + initial_pool_size + ]; + CsvBufferPool { + buffers: Mutex::new(chunk_buffers), + buffer_size: chunk_size_rows, + record_buffer_size, + num_fields, + } + } + + pub fn get_buffer(&self) -> CsvBufferPoolRef { + let mut buffers = self.buffers.lock().unwrap(); + let buffer = buffers.pop(); + let buffer = match buffer { + Some(buffer) => buffer, + None => { + vec![ + read::ByteRecord::with_capacity(self.record_buffer_size, self.num_fields); + self.buffer_size + ] + } + }; + + CsvBufferPoolRef { pool: self, buffer } + } + + fn return_buffer(&self, buffer: Vec) { + let mut buffers = self.buffers.lock().unwrap(); + buffers.push(buffer); + } +} + +// Daft does not currently support non-\n record terminators (e.g. carriage return \r, which only +// matters for pre-Mac OS X). +const NEWLINE: u8 = b'\n'; +const DEFAULT_CHUNK_SIZE: usize = 4 * 1024 * 1024; // 1MiB. TODO(desmond): This should be tuned. + +/// Helper function that finds the first new line character (\n) in the given byte slice. +fn next_line_position(input: &[u8]) -> Option { + // Assuming we are searching for the ASCII `\n` character, we don't need to do any special + // handling for UTF-8, since a `\n` value always corresponds to an ASCII `\n`. + // For more details, see: https://en.wikipedia.org/wiki/UTF-8#Encoding + memchr::memchr(NEWLINE, input) +} + +/// Helper function that determines what chunk of data to parse given a starting position within the +/// file, and the desired initial chunk size. +/// +/// Given a starting position, we use our chunk size to compute a preliminary start and stop +/// position. For example, we can visualize all preliminary chunks in a file as follows. +/// +/// Chunk 1 Chunk 2 Chunk 3 Chunk N +/// ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ +/// │ │ │\n │ │ \n │ │ \n │ +/// │ │ │ │ │ │ │ │ +/// │ │ │ \n │ │ │ │ │ +/// │ \n │ │ │ │ \n │ │ │ +/// │ │ │ │ │ │ ... │ \n │ +/// │ │ │ \n │ │ │ │ │ +/// │ \n │ │ │ │ │ │ │ +/// │ │ │ │ │ \n │ │ \n │ +/// └──────────┘ └──────────┘ └──────────┘ └──────────┘ +/// +/// However, record boundaries (i.e. the \n terminators) do not align nicely with these preliminary +/// chunk boundaries. So we adjust each preliminary chunk as follows: +/// - Find the first record terminator from the chunk's start. This is the new starting position. +/// - Find the first record terminator from the chunk's end. This is the new ending position. +/// - If a given preliminary chunk doesn't contain a record terminator, the adjusted chunk is empty. +/// +/// For example: +/// +/// Adjusted Chunk 1 Adj. Chunk 2 Adj. Chunk 3 Adj. Chunk N +/// ┌──────────────────┐┌─────────────────┐ ┌────────┐ ┌─┐ +/// │ \n││ \n│ │ \n│ \n │ │ +/// │ ┌───────┘│ ┌──────────┘ │ ┌─────┘ │ │ +/// │ │ ┌───┘ \n │ ┌───────┘ │ ┌────────┘ │ +/// │ \n │ │ │ │ \n │ │ │ +/// │ │ │ │ │ │ ... │ \n │ +/// │ │ │ \n │ │ │ │ │ +/// │ \n │ │ │ │ │ │ │ +/// │ │ │ │ │ \n │ │ \n │ +/// └──────────┘ └──────────┘ └──────────┘ └──────────┘ +/// +/// Using this method, we now have adjusted chunks that are aligned with record boundaries, that do +/// not overlap, and that fully cover every byte in the CSV file. Parsing each adjusted chunk can +/// now happen in parallel. +/// +/// This is the same method as described in: +/// Ge, Chang et al. “Speculative Distributed CSV Data Parsing for Big Data Analytics.” Proceedings of the 2019 International Conference on Management of Data (2019). +fn get_file_chunk(bytes: &[u8], start: usize, chunk_size: usize) -> Option<(usize, usize)> { + let stop = start + chunk_size; + let start = if start == 0 { + 0 + } else { + match next_line_position(&bytes[start..]) { + // Start reading after the first record terminator from the start of the chunk. + Some(pos) => start + pos + 1, + // If there's no record terminator found, then the previous chunk reader would have + // consumed the current chunk. Hence, we skip. + None => return None, + } + }; + // If the first record terminator comes after this chunk, then the previous chunk reader would + // have consumed the current chunk. Hence, we skip. + if start > stop { + return None; + } + let stop = if stop < bytes.len() { + match next_line_position(&bytes[stop..]) { + // Read up to the first terminator from the end of the chunk. + Some(pos) => stop + pos, + None => bytes.len(), + } + } else { + bytes.len() + }; + Some((start, stop)) +} + +#[allow(clippy::too_many_arguments)] +fn parse_csv_chunk( + mut reader: Reader, + projection_indices: Arc>, + fields: Vec, + read_daft_fields: Arc>>, + read_schema: Arc, + buf: CsvBufferPoolRef, + include_columns: &Option>, + predicate: Option>, +) -> DaftResult> +where + R: std::io::Read, +{ + let mut chunk_buffer = buf.buffer; + let mut tables = vec![]; + loop { + let (rows_read, has_more) = + local_read_rows(&mut reader, chunk_buffer.as_mut_slice()).context(ArrowSnafu {})?; + let chunk = projection_indices + .par_iter() + .enumerate() + .map(|(i, proj_idx)| { + let deserialized_col = deserialize_column( + &chunk_buffer[0..rows_read], + *proj_idx, + fields[*proj_idx].data_type().clone(), + 0, + ); + Series::try_from_field_and_arrow_array( + read_daft_fields[i].clone(), + cast_array_for_daft_if_needed(deserialized_col?), + ) + }) + .collect::>>()?; + let num_rows = chunk.first().map(|s| s.len()).unwrap_or(0); + let table = Table::new_unchecked(read_schema.clone(), chunk, num_rows); + let table = if let Some(predicate) = &predicate { + let filtered = table.filter(&[predicate.clone()])?; + if let Some(include_columns) = &include_columns { + filtered.get_columns(include_columns.as_slice())? + } else { + filtered + } + } else { + table + }; + tables.push(table); + + // The number of record might exceed the number of byte records we've allocated. + // Retry until all byte records in this chunk are read. + if !has_more { + break; + } + } + buf.pool.return_buffer(chunk_buffer); + Ok(tables) +} + +async fn stream_csv_local( + uri: &str, + convert_options: Option, + parse_options: CsvParseOptions, + read_options: Option, + max_chunks_in_flight: Option, +) -> DaftResult> + Send> { + let uri = uri.trim_start_matches("file://"); + let file = std::fs::File::open(uri)?; + let mmap = unsafe { memmap2::Mmap::map(&file) }.context(StdIOSnafu)?; + let bytes = &mmap[..]; + + // TODO(desmond): This logic is repeated multiple times in this file. Should dedup. + let predicate = convert_options + .as_ref() + .and_then(|opts| opts.predicate.clone()); + + let limit = convert_options.as_ref().and_then(|opts| opts.limit); + + let include_columns = convert_options + .as_ref() + .and_then(|opts| opts.include_columns.clone()); + + let convert_options = match (convert_options, &predicate) { + (None, _) => None, + (co, None) => co, + (Some(mut co), Some(predicate)) => { + if let Some(ref mut include_columns) = co.include_columns { + let required_columns_for_predicate = get_required_columns(predicate); + for rc in required_columns_for_predicate { + if include_columns.iter().all(|c| c.as_str() != rc.as_str()) { + include_columns.push(rc) + } + } + } + // If we have a limit and a predicate, remove limit for stream. + co.limit = None; + Some(co) + } + } + .unwrap_or_default(); + // End of `should dedup`. + + // TODO(desmond): We should do better schema inference here. + let schema = convert_options.clone().schema.unwrap().to_arrow()?; + let n_threads: usize = std::thread::available_parallelism() + .unwrap_or(NonZeroUsize::new(2).unwrap()) + .into(); + let chunk_size = read_options + .as_ref() + .and_then(|opt| opt.chunk_size.or_else(|| opt.buffer_size.map(|bs| bs / 8))) + .unwrap_or(DEFAULT_CHUNK_SIZE); + let projection_indices = fields_to_projection_indices( + &schema.clone().fields, + &convert_options.clone().include_columns, + ); + let fields = schema.clone().fields; + let fields_subset = projection_indices + .iter() + .map(|i| fields.get(*i).unwrap().into()) + .collect::>(); + let read_schema = Arc::new(daft_core::schema::Schema::new(fields_subset)?); + let read_daft_fields = Arc::new( + read_schema + .fields + .values() + .map(|f| Arc::new(f.clone())) + .collect::>(), + ); + // TODO(desmond): Need better upfront estimators. Sample or keep running count of stats. + let estimated_mean_row_size = 100f64; + let estimated_std_row_size = 20f64; + let record_buffer_size = (estimated_mean_row_size + estimated_std_row_size).ceil() as usize; + let chunk_size_rows = (chunk_size as f64 / record_buffer_size as f64).ceil() as usize; + // TODO(desmond): We don't want to create a per-read buffer pool, we want one pool shared with + // the whole process. + let buffer_pool = CsvBufferPool::new( + record_buffer_size, + schema.fields.len(), + chunk_size_rows, + n_threads * 2, + ); + let chunk_offsets: Vec = (0..=bytes.len()).step_by(chunk_size).collect(); + // TODO(desmond): Memory usage is still growing during execution of a .count(*).collect(), so + // the following approach still isn't quite right. + // TODO(desmond): Also, is this usage of max_chunks_in_flight correct? + let (sender, receiver) = + crossbeam_channel::bounded(max_chunks_in_flight.unwrap_or(n_threads * 2)); + let rows_read = AtomicUsize::new(0); + rayon::spawn(move || { + let bytes = &mmap[..]; + chunk_offsets.into_par_iter().for_each(|start| { + // TODO(desmond): Use try_for_each and terminate early once the limit is reached. + let limit_reached = limit.map_or(false, |limit| { + let current_rows_read = rows_read.load(Ordering::Relaxed); + current_rows_read >= limit + }); + + if !limit_reached && let Some((start, stop)) = get_file_chunk(bytes, start, chunk_size) + { + let buf = buffer_pool.get_buffer(); + let chunk = &bytes[start..stop]; + // Only the first chunk might potentially have headers. Subsequent chunks should + // read all rows as records. + let has_headers = start == 0 && parse_options.has_header; + let rdr = ReaderBuilder::new() + .has_headers(has_headers) + .delimiter(parse_options.delimiter) + .double_quote(parse_options.double_quote) + // TODO(desmond): We need to handle the quoted case properly. + .quote(parse_options.quote) + .escape(parse_options.escape_char) + .comment(parse_options.comment) + .flexible(parse_options.allow_variable_columns) + .from_reader(chunk); + Reader::from_reader(chunk); + let table_results = parse_csv_chunk( + rdr, + projection_indices.clone(), + fields.clone(), + read_daft_fields.clone(), + read_schema.clone(), + buf, + &include_columns, + predicate.clone(), + ); + match table_results { + Ok(tables) => { + for table in tables { + let table_len = table.len(); + sender.send(Ok(table)).unwrap(); + // Atomically update the number of rows read only after the result has + // been sent. In theory we could wrap these steps in a mutex, but + // applying limit at this layer can be best-effort with no adverse + // side effects. + rows_read.fetch_add(table_len, Ordering::Relaxed); + } + } + Err(e) => sender.send(Err(e)).unwrap(), + } + } + }) + }); + + let result_stream = futures::stream::iter(receiver); + Ok(result_stream) +} + +async fn tables_stream_collect(stream: BoxStream<'static, DaftResult
>) -> Vec
{ + stream + .filter_map(|result| async { + match result { + Ok(table) => Some(table), + Err(_) => None, // Skips errors; you could log them or handle differently + } + }) + .collect() + .await +} + +async fn read_csv_local( + uri: &str, + convert_options: Option, + parse_options: CsvParseOptions, + read_options: Option, + max_chunks_in_flight: Option, +) -> DaftResult
{ + let stream = stream_csv_local( + uri, + convert_options, + parse_options, + read_options, + max_chunks_in_flight, + ) + .await?; + tables_concat(tables_stream_collect(Box::pin(stream)).await) +} + async fn read_csv_single_into_table( uri: &str, convert_options: Option, @@ -208,6 +610,20 @@ async fn read_csv_single_into_table( io_stats: Option, max_chunks_in_flight: Option, ) -> DaftResult
{ + let (source_type, _) = parse_url(uri)?; + let is_compressed = CompressionCodec::from_uri(uri).is_some(); + let use_local_reader = false; // TODO(desmond): Feature under dev. + if matches!(source_type, SourceType::File) && !is_compressed && use_local_reader { + return read_csv_local( + uri, + convert_options, + parse_options.unwrap_or_default(), + read_options, + max_chunks_in_flight, + ) + .await; + } + let predicate = convert_options .as_ref() .and_then(|opts| opts.predicate.clone()); @@ -557,7 +973,7 @@ where estimated_rows_per_desired_chunk.max(8).min(num_rows - total_rows_read) }; let mut chunk_buffer = vec![ - ByteRecord::with_capacity(record_buffer_size, num_fields); + read_async::ByteRecord::with_capacity(record_buffer_size, num_fields); chunk_size_rows ]; @@ -576,7 +992,7 @@ where chunk_buffer.truncate(rows_read); if rows_read > 0 { - yield chunk_buffer + yield chunk_buffer; } } } @@ -811,6 +1227,135 @@ mod tests { Ok(()) } + use crate::read::read_csv_local; + use daft_io::get_runtime; + #[test] + fn test_csv_read_experimental() -> DaftResult<()> { + let file = "file:///Users/desmond/tasks/csv-reader/G1_1e8_1e1_5_0.csv"; + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + + let io_client = Arc::new(IOClient::new(io_config.into())?); + + use indexmap::IndexMap; + let mut fields = IndexMap::new(); + fields.insert( + "id1".to_string(), + Field { + name: "id1".to_string(), + dtype: daft_core::datatypes::DataType::Utf8, + metadata: Arc::default(), + }, + ); + fields.insert( + "id2".to_string(), + Field { + name: "id2".to_string(), + dtype: daft_core::datatypes::DataType::Utf8, + metadata: Arc::default(), + }, + ); + fields.insert( + "id3".to_string(), + Field { + name: "id3".to_string(), + dtype: daft_core::datatypes::DataType::Utf8, + metadata: Arc::default(), + }, + ); + fields.insert( + "id4".to_string(), + Field { + name: "id4".to_string(), + dtype: daft_core::datatypes::DataType::Int64, + metadata: Arc::default(), + }, + ); + fields.insert( + "id5".to_string(), + Field { + name: "id5".to_string(), + dtype: daft_core::datatypes::DataType::Int64, + metadata: Arc::default(), + }, + ); + fields.insert( + "id6".to_string(), + Field { + name: "id6".to_string(), + dtype: daft_core::datatypes::DataType::Int64, + metadata: Arc::default(), + }, + ); + fields.insert( + "v1".to_string(), + Field { + name: "v1".to_string(), + dtype: daft_core::datatypes::DataType::Int64, + metadata: Arc::default(), + }, + ); + fields.insert( + "v2".to_string(), + Field { + name: "v2".to_string(), + dtype: daft_core::datatypes::DataType::Int64, + metadata: Arc::default(), + }, + ); + fields.insert( + "v3".to_string(), + Field { + name: "v3".to_string(), + dtype: daft_core::datatypes::DataType::Float64, + metadata: Arc::default(), + }, + ); + + let runtime_handle = get_runtime(true)?; + let _rt_guard = runtime_handle.enter(); + let result = runtime_handle.block_on(async { + read_csv_local( + file.as_ref(), + Some(CsvConvertOptions { + limit: None, + include_columns: None, + column_names: None, + schema: Some(Arc::new(Schema { fields: fields })), + predicate: None, + }), + CsvParseOptions::default().with_delimiter(b','), + None, + None, + ) + .await + }); + + assert!( + result.is_ok(), + "Got Err: {:?} when using the experimental local csv reader", + result + ); + + let column_names = vec!["id1", "id2", "id3", "id4", "id5", "id6", "v1", "v2", "v3"]; + check_equal_local_arrow2( + file, + &result.unwrap(), + true, + None, + true, + None, + None, + None, + Some(column_names), + None, + None, + ); + + Ok(()) + } + #[test] fn test_csv_read_local_no_headers() -> DaftResult<()> { let file = format!( diff --git a/src/daft-decoding/src/deserialize.rs b/src/daft-decoding/src/deserialize.rs index dd3c56292e..ffb4a9aa6d 100644 --- a/src/daft-decoding/src/deserialize.rs +++ b/src/daft-decoding/src/deserialize.rs @@ -2,12 +2,13 @@ use arrow2::{ array::*, datatypes::*, error::{Error, Result}, + io::csv, offset::Offset, temporal_conversions, types::NativeType, }; use chrono::{Datelike, Timelike}; -use csv_async::ByteRecord; +use csv_async; pub(crate) const ISO8601: &str = "%+"; pub(crate) const ISO8601_NO_TIME_ZONE: &str = "%Y-%m-%dT%H:%M:%S%.f"; @@ -35,7 +36,14 @@ pub trait ByteRecordGeneric { fn get(&self, index: usize) -> Option<&[u8]>; } -impl ByteRecordGeneric for ByteRecord { +impl ByteRecordGeneric for csv_async::ByteRecord { + #[inline] + fn get(&self, index: usize) -> Option<&[u8]> { + self.get(index) + } +} + +impl ByteRecordGeneric for csv::read::ByteRecord { #[inline] fn get(&self, index: usize) -> Option<&[u8]> { self.get(index) From c3ce7e1e76d0c098478848c5a23cbe014f0e27e3 Mon Sep 17 00:00:00 2001 From: desmondcheongzx Date: Sat, 28 Sep 2024 04:33:00 -0700 Subject: [PATCH 02/10] Rearchitect --- src/daft-csv/src/lib.rs | 1 + src/daft-csv/src/local.rs | 912 ++++++++++++++++++++++++++++++++++++++ src/daft-csv/src/read.rs | 409 +---------------- 3 files changed, 925 insertions(+), 397 deletions(-) create mode 100644 src/daft-csv/src/local.rs diff --git a/src/daft-csv/src/lib.rs b/src/daft-csv/src/lib.rs index 35c4d72006..a5a93ca56b 100644 --- a/src/daft-csv/src/lib.rs +++ b/src/daft-csv/src/lib.rs @@ -7,6 +7,7 @@ extern crate test; use common_error::DaftError; use snafu::Snafu; +pub mod local; pub mod metadata; pub mod options; #[cfg(feature = "python")] diff --git a/src/daft-csv/src/local.rs b/src/daft-csv/src/local.rs new file mode 100644 index 0000000000..3bd73f2ef0 --- /dev/null +++ b/src/daft-csv/src/local.rs @@ -0,0 +1,912 @@ +use core::str; +use std::io::{Chain, Cursor, Read}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::{num::NonZeroUsize, sync::Arc, sync::Condvar, sync::Mutex}; + +use crate::ArrowSnafu; +use crate::{CsvConvertOptions, CsvParseOptions, CsvReadOptions}; +use arrow2::{ + datatypes::Field, + io::csv::read, + io::csv::read::{Reader, ReaderBuilder}, + io::csv::read_async::local_read_rows, +}; +use common_error::{DaftError, DaftResult}; +use crossbeam_channel::Sender; +use daft_core::{schema::Schema, utils::arrow::cast_array_for_daft_if_needed, Series}; +use daft_decoding::deserialize::deserialize_column; +use daft_dsl::{optimization::get_required_columns, Expr}; +use daft_table::Table; +use futures::{stream::BoxStream, Stream, StreamExt}; +use rayon::{ + iter::IndexedParallelIterator, + prelude::{IntoParallelRefIterator, ParallelIterator}, +}; +use snafu::ResultExt; + +use crate::read::{fields_to_projection_indices, tables_concat}; + +#[allow(clippy::doc_lazy_continuation)] +/// Our local CSV reader has the following approach to reading CSV files: +/// 1. Read the CSV file in 4MB chunks from a slab pool. +/// 2. Adjust the chunks so that chunks are contiguous and contain complete +/// CSV records. See `get_file_chunk` for more details. +/// 3. In parallel with the above, convert the adjusted chunks into byte records, +/// which are stored within pre-allocated CSV buffers. +/// 4. In parallel with the above, deserialize each CSV buffer into a Daft table +/// and stream the results. +/// +/// Slab Pool CSV Buffer Pool +/// ┌────────────────────┐ ┌────────────────────┐ +/// │ 4MB Chunks │ │ CSV Buffers │ +/// │┌───┐┌───┐┌───┐ │ │┌───┐┌───┐┌───┐ │ +/// ││ ││ ││ │ ... │ ││ ││ ││ │ ... │ +/// │└─┬─┘└─┬─┘└───┘ │ │└─┬─┘└─┬─┘└───┘ │ +/// └──┼────┼────────────┘ └──┼────┼────────────┘ +/// │ │ │ │ +/// ───────┐ │ │ │ │ +/// /│ │ │ │ │ │ +/// /─┘ │ │ │ │ │ +/// │ │ ▼ ▼ ▼ ▼ +/// │ ─┼───►┌───┐ ┌───┐ ┌────┐ ┬--─┐ ┌───┐ ┌───┐ ┌───┐ ┌───┐ +/// │ │ │ │ │ │ ──────► │ ┬┘┌─┘ ┬─┘ ───────► │ │ │ │ ──────────► │ │ │ │ +/// │ CSV File │ └───┘ └───┘ └───┴ └───┘ └───┘ └───┘ └───┘ └───┘ +/// │ │ Chain of buffers Adjusted chunks Vectors of ByteRecords Stream of Daft tables +/// │ │ +/// └──────────┘ + +/// A pool of ByteRecord slabs. Used for deserializing CSV. +#[derive(Debug)] +struct CsvBufferPool { + buffers: Mutex>>, + buffer_size: usize, + record_buffer_size: usize, + num_fields: usize, +} + +/// A slab of ByteRecords. Used for deserializing CSV. +struct CsvBuffer { + pool: Arc, + buffer: Vec, +} + +impl CsvBufferPool { + pub fn new( + record_buffer_size: usize, + num_fields: usize, + chunk_size_rows: usize, + initial_pool_size: usize, + ) -> Self { + let chunk_buffers = vec![ + vec![ + read::ByteRecord::with_capacity(record_buffer_size, num_fields); + chunk_size_rows + ]; + initial_pool_size + ]; + CsvBufferPool { + buffers: Mutex::new(chunk_buffers), + buffer_size: chunk_size_rows, + record_buffer_size, + num_fields, + } + } + + pub fn get_buffer(self: &Arc) -> CsvBuffer { + let mut buffers = self.buffers.lock().unwrap(); + let buffer = buffers.pop(); + let buffer = match buffer { + Some(buffer) => buffer, + None => { + println!("csv buf empty"); + vec![ + read::ByteRecord::with_capacity(self.record_buffer_size, self.num_fields); + self.buffer_size + ] + } + }; + + CsvBuffer { + pool: Arc::clone(self), + buffer, + } + } + + fn return_buffer(&self, buffer: Vec) { + let mut buffers = self.buffers.lock().unwrap(); + buffers.push(buffer); + } +} + +// The default size of a slab used for reading CSV files in chunks. Currently set to 4MB. +const SLABSIZE: usize = 4 * 1024 * 1024; +// The default number of slabs in a slab pool. +const SLABPOOL_DEFAULT_SIZE: usize = 20; + +/// A pool of 4MB slabs. Used for reading CSV files in 4MB chunks. +#[derive(Debug)] +struct SlabPool { + buffers: Mutex>>, + condvar: Condvar, +} + +/// A 4MB slab of bytes. Used for reading CSV files in 4MB chunks. +#[derive(Clone)] +struct Slab { + pool: Arc, + // We wrap the Arc in an Option so that when a Slab is being dropped, we can move the Slab's reference + // to the Arc back to the slab pool. + buffer: Option>, +} + +impl Drop for Slab { + fn drop(&mut self) { + // Move the buffer back to the slab pool. + if let Some(buffer) = self.buffer.take() { + self.pool.return_buffer(buffer); + } + } +} + +impl SlabPool { + pub fn new() -> Self { + let chunk_buffers: Vec> = (0..SLABPOOL_DEFAULT_SIZE) + .map(|_| Arc::new([0; SLABSIZE])) + .collect(); + SlabPool { + buffers: Mutex::new(chunk_buffers), + condvar: Condvar::new(), + } + } + + pub fn get_buffer(self: &Arc) -> Arc<[u8; SLABSIZE]> { + let mut buffers = self.buffers.lock().unwrap(); + while buffers.is_empty() { + buffers = self.condvar.wait(buffers).unwrap(); + } + buffers.pop().unwrap() + } + + fn return_buffer(&self, buffer: Arc<[u8; SLABSIZE]>) { + let mut buffers = self.buffers.lock().unwrap(); + buffers.push(buffer); + self.condvar.notify_one(); + } +} + +/// A data structure that holds either a single slice of bytes, or a chain of two slices of bytes. +/// See the description to `parse_json` for more details. +#[derive(Debug)] +enum BufferSource<'a> { + Single(Cursor<&'a [u8]>), + Chain(Chain, Cursor<&'a [u8]>>), +} + +/// Read implementation that allows BufferSource to be used by csv::read::Reader. +impl<'a> Read for BufferSource<'a> { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + match self { + BufferSource::Single(cursor) => std::io::Read::read(cursor, buf), + BufferSource::Chain(chain) => chain.read(buf), + } + } +} + +pub async fn read_csv_local( + uri: &str, + convert_options: Option, + parse_options: CsvParseOptions, + read_options: Option, + max_chunks_in_flight: Option, +) -> DaftResult
{ + let stream = stream_csv_local( + uri, + convert_options, + parse_options, + read_options, + max_chunks_in_flight, + )?; + tables_concat(tables_stream_collect(Box::pin(stream)).await) +} + +async fn tables_stream_collect(stream: BoxStream<'static, DaftResult
>) -> Vec
{ + stream + .filter_map(|result| async { + match result { + Ok(table) => Some(table), + Err(_) => None, // Skips errors; you could log them or handle differently + } + }) + .collect() + .await +} + +pub fn stream_csv_local( + uri: &str, + convert_options: Option, + parse_options: CsvParseOptions, + read_options: Option, + max_chunks_in_flight: Option, +) -> DaftResult> + Send> { + let uri = uri.trim_start_matches("file://"); + let file = std::fs::File::open(uri)?; + + // TODO(desmond): This logic is repeated multiple times in the csv reader files. Should dedup. + let predicate = convert_options + .as_ref() + .and_then(|opts| opts.predicate.clone()); + + let limit = convert_options.as_ref().and_then(|opts| opts.limit); + + let include_columns = convert_options + .as_ref() + .and_then(|opts| opts.include_columns.clone()); + + let convert_options = match (convert_options, &predicate) { + (None, _) => None, + (co, None) => co, + (Some(mut co), Some(predicate)) => { + if let Some(ref mut include_columns) = co.include_columns { + let required_columns_for_predicate = get_required_columns(predicate); + for rc in required_columns_for_predicate { + if include_columns.iter().all(|c| c.as_str() != rc.as_str()) { + include_columns.push(rc) + } + } + } + // If we have a limit and a predicate, remove limit for stream. + co.limit = None; + Some(co) + } + } + .unwrap_or_default(); + // End of `should dedup`. + + // TODO(desmond): We should do better schema inference here. + let schema = convert_options.clone().schema.unwrap().to_arrow()?; + let n_threads: usize = std::thread::available_parallelism() + .unwrap_or(NonZeroUsize::new(2).unwrap()) + .into(); + let chunk_size = read_options + .as_ref() + .and_then(|opt| opt.chunk_size.or_else(|| opt.buffer_size.map(|bs| bs / 8))) + .unwrap_or(DEFAULT_CHUNK_SIZE); + let projection_indices = fields_to_projection_indices( + &schema.clone().fields, + &convert_options.clone().include_columns, + ); + let fields = schema.clone().fields; + let fields_subset = projection_indices + .iter() + .map(|i| fields.get(*i).unwrap().into()) + .collect::>(); + let read_schema = Arc::new(daft_core::schema::Schema::new(fields_subset)?); + let read_daft_fields = Arc::new( + read_schema + .fields + .values() + .map(|f| Arc::new(f.clone())) + .collect::>(), + ); + // TODO(desmond): Need better upfront estimators. Cory did something like what we need here: https://github.com/universalmind303/Daft/blob/7b40f23a5ff83aba4ab059b62ac781d7766be0b1/src/daft-json/src/local.rs#L338 + let estimated_mean_row_size = 100f64; + let estimated_std_row_size = 20f64; + let record_buffer_size = (estimated_mean_row_size + estimated_std_row_size).ceil() as usize; + let chunk_size_rows = (chunk_size as f64 / record_buffer_size as f64).ceil() as usize; + let num_fields = schema.fields.len(); + // TODO(desmond): We might consider creating per-process buffer pools and slab pools. + let buffer_pool = Arc::new(CsvBufferPool::new( + record_buffer_size, + num_fields, + chunk_size_rows, + n_threads * 2, + )); + let slabpool = Arc::new(SlabPool::new()); + // We suppose that each slab of CSV data produces (chunk size / slab size) number of Daft tables. We + // then double this capacity to ensure that our channel is never full and our threads won't deadlock. + let (sender, receiver) = + crossbeam_channel::bounded(max_chunks_in_flight.unwrap_or(2 * chunk_size / SLABSIZE)); + rayon::spawn(move || { + consume_csv_file( + file, + buffer_pool, + slabpool, + parse_options, + projection_indices, + read_daft_fields, + read_schema, + fields, + num_fields, + &include_columns, + predicate, + limit, + sender, + ); + }); + let result_stream = futures::stream::iter(receiver); + Ok(result_stream) +} + +/// Consumes the CSV file and sends the results to `sender`. +#[allow(clippy::too_many_arguments)] +fn consume_csv_file( + mut file: std::fs::File, + buffer_pool: Arc, + slabpool: Arc, + parse_options: CsvParseOptions, + projection_indices: Arc>, + read_daft_fields: Arc>>, + read_schema: Arc, + fields: Vec, + num_fields: usize, + include_columns: &Option>, + predicate: Option>, + limit: Option, + sender: Sender>, +) { + let rows_read = Arc::new(AtomicUsize::new(0)); + let mut has_header = parse_options.has_header; + let total_len = file.metadata().unwrap().len() as usize; + let field_delimiter = parse_options.delimiter; + let escape_char = parse_options.escape_char; + let quote_char = parse_options.quote; + let double_quote_escape_allowed = parse_options.double_quote; + let mut total_bytes_read = 0; + let mut next_slab = None; + let mut next_buffer_len = 0; + let mut first_buffer = true; + loop { + let limit_reached = limit.map_or(false, |limit| { + let current_rows_read = rows_read.load(Ordering::Relaxed); + current_rows_read >= limit + }); + if limit_reached { + break; + } + let (current_slab, current_buffer_len) = match next_slab.take() { + Some(next_slab) => { + total_bytes_read += next_buffer_len; + (next_slab, next_buffer_len) + } + None => { + let mut buffer = slabpool.get_buffer(); + match Arc::get_mut(&mut buffer) { + Some(inner_buffer) => { + let bytes_read = file.read(inner_buffer).unwrap(); + if bytes_read == 0 { + slabpool.return_buffer(buffer); + break; + } + total_bytes_read += bytes_read; + ( + Arc::new(Slab { + pool: Arc::clone(&slabpool), + buffer: Some(buffer), + }), + bytes_read, + ) + } + None => { + slabpool.return_buffer(buffer); + break; + } + } + } + }; + (next_slab, next_buffer_len) = if total_bytes_read < total_len { + let mut next_buffer = slabpool.get_buffer(); + match Arc::get_mut(&mut next_buffer) { + Some(inner_buffer) => { + let bytes_read = file.read(inner_buffer).unwrap(); + if bytes_read == 0 { + slabpool.return_buffer(next_buffer); + (None, 0) + } else { + ( + Some(Arc::new(Slab { + pool: Arc::clone(&slabpool), + buffer: Some(next_buffer), + })), + bytes_read, + ) + } + } + None => { + slabpool.return_buffer(next_buffer); + break; + } + } + } else { + (None, 0) + }; + let file_chunk = get_file_chunk( + unsafe_clone_buffer(¤t_slab.buffer), + current_buffer_len, + next_slab + .as_ref() + .map(|slab| unsafe_clone_buffer(&slab.buffer)), + next_buffer_len, + first_buffer, + num_fields, + quote_char, + field_delimiter, + escape_char, + double_quote_escape_allowed, + ); + first_buffer = false; + if let (None, _) = file_chunk { + // Return the buffer. It doesn't matter that we still have a reference to the slab. We're going to fallback + // and the slabs will be useless. + slabpool.return_buffer(unsafe_clone_buffer(¤t_slab.buffer)); + // Exit early before spawning a new thread. + break; + // TODO(desmond): we should fallback instead. + } + let current_slab_clone = Arc::clone(¤t_slab); + let next_slab_clone = next_slab.clone(); + let parse_options = parse_options.clone(); + let csv_buffer = buffer_pool.get_buffer(); + let projection_indices = projection_indices.clone(); + let fields = fields.clone(); + let read_daft_fields = read_daft_fields.clone(); + let read_schema = read_schema.clone(); + let include_columns = include_columns.clone(); + let predicate = predicate.clone(); + let sender = sender.clone(); + let rows_read = Arc::clone(&rows_read); + rayon::spawn(move || { + let limit_reached = limit.map_or(false, |limit| { + let current_rows_read = rows_read.load(Ordering::Relaxed); + current_rows_read >= limit + }); + if !limit_reached { + match file_chunk { + (Some(start), None) => { + if let Some(buffer) = ¤t_slab_clone.buffer { + let buffer_source = BufferSource::Single(Cursor::new( + &buffer[start..current_buffer_len], + )); + dispatch_to_parse_csv( + has_header, + parse_options, + buffer_source, + projection_indices, + fields, + read_daft_fields, + read_schema, + csv_buffer, + &include_columns, + predicate, + sender, + rows_read, + ); + } else { + panic!("Trying to read from a CSV buffer that doesn't exist. Please report this issue.") + } + } + (Some(start), Some(end)) => { + if let Some(next_slab_clone) = next_slab_clone + && let Some(current_buffer) = ¤t_slab_clone.buffer + && let Some(next_buffer) = &next_slab_clone.buffer + { + let buffer_source = BufferSource::Chain(std::io::Read::chain( + Cursor::new(¤t_buffer[start..current_buffer_len]), + Cursor::new(&next_buffer[..end]), + )); + dispatch_to_parse_csv( + has_header, + parse_options, + buffer_source, + projection_indices, + fields, + read_daft_fields, + read_schema, + csv_buffer, + &include_columns, + predicate, + sender, + rows_read, + ); + } else { + panic!("Trying to read from an overflow CSV buffer that doesn't exist. Please report this issue.") + } + } + _ => panic!( + "Something went wrong when parsing the CSV file. Please report this issue." + ), + }; + } + }); + has_header = false; + if total_bytes_read >= total_len { + break; + } + } +} + +/// Unsafe helper function that extracts the buffer from an &Option>. Users should +/// ensure that the buffer is Some, otherwise this function causes the process to panic. +fn unsafe_clone_buffer(buffer: &Option>) -> Arc<[u8; SLABSIZE]> { + match buffer { + Some(buffer) => Arc::clone(buffer), + None => panic!("Tried to clone a CSV slab that doesn't exist. Please report this error."), + } +} + +#[allow(clippy::doc_lazy_continuation)] +/// Helper function that determines what chunk of data to parse given a starting position within the +/// file, and the desired initial chunk size. +/// +/// Given a starting position, we use our chunk size to compute a preliminary start and stop +/// position. For example, we can visualize all preliminary chunks in a file as follows. +/// +/// Chunk 1 Chunk 2 Chunk 3 Chunk N +/// ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ +/// │ │ │\n │ │ \n │ │ \n │ +/// │ │ │ │ │ │ │ │ +/// │ │ │ \n │ │ │ │ │ +/// │ \n │ │ │ │ \n │ │ │ +/// │ │ │ │ │ │ ... │ \n │ +/// │ │ │ \n │ │ │ │ │ +/// │ \n │ │ │ │ │ │ │ +/// │ │ │ │ │ \n │ │ \n │ +/// └──────────┘ └──────────┘ └──────────┘ └──────────┘ +/// +/// However, record boundaries (i.e. the \n terminators) do not align nicely with these preliminary +/// chunk boundaries. So we adjust each preliminary chunk as follows: +/// - Find the first record terminator from the chunk's start. This is the new starting position. +/// - Find the first record terminator from the chunk's end. This is the new ending position. +/// - If a given preliminary chunk doesn't contain a record terminator, the adjusted chunk is empty. +/// +/// For example: +/// +/// Adjusted Chunk 1 Adj. Chunk 2 Adj. Chunk 3 Adj. Chunk N +/// ┌──────────────────┐┌─────────────────┐ ┌────────┐ ┌─┐ +/// │ \n││ \n│ │ \n│ \n │ │ +/// │ ┌───────┘│ ┌──────────┘ │ ┌─────┘ │ │ +/// │ │ ┌───┘ \n │ ┌───────┘ │ ┌────────┘ │ +/// │ \n │ │ │ │ \n │ │ │ +/// │ │ │ │ │ │ ... │ \n │ +/// │ │ │ \n │ │ │ │ │ +/// │ \n │ │ │ │ │ │ │ +/// │ │ │ │ │ \n │ │ \n │ +/// └──────────┘ └──────────┘ └──────────┘ └──────────┘ +/// +/// Using this method, we now have adjusted chunks that are aligned with record boundaries, that do +/// not overlap, and that fully cover every byte in the CSV file. Parsing each adjusted chunk can +/// now happen in parallel. +/// +/// This is the same method as described in: +/// Ge, Chang et al. “Speculative Distributed CSV Data Parsing for Big Data Analytics.” Proceedings of the 2019 International Conference on Management of Data (2019). +/// +/// Another observation is that seeing a pure \n character is not necessarily indicative of a record +/// terminator. We need to consider whether the \n character was seen within a quoted field, since the +/// string "some text \n some text" is a valid CSV string field. To do this, we carry out the following +/// algorithm: +/// 1. Find a \n character. +/// 2. Check if the CSV string immediately following this \n character is valid, i.e. does it parse +/// as valid CSV, and does it produce the same number of fields as our schema. +/// 2a. If there is a valid record at this point, then we assume that the \n we saw was a valid terminator. +/// 2b. If the record at this point is invalid, then this was likely a \n in a quoted field. Find the next +/// \n character and go back to 2. +#[allow(clippy::too_many_arguments)] +fn get_file_chunk( + current_buffer: Arc<[u8; SLABSIZE]>, + current_buffer_len: usize, + next_buffer: Option>, + next_buffer_len: usize, + first_buffer: bool, + num_fields: usize, + quote_char: u8, + field_delimiter: u8, + escape_char: Option, + double_quote_escape_allowed: bool, +) -> (Option, Option) { + // TODO(desmond): There is a potential fast path here when `escape_char` is None: simply check for \n characters. + let start = if !first_buffer { + let start = next_line_position( + ¤t_buffer[..current_buffer_len], + 0, + num_fields, + quote_char, + field_delimiter, + escape_char, + double_quote_escape_allowed, + ); + match start { + Some(_) => start, + None => return (None, None), // If the record size is >= 4MB, return None and fallback. + } + } else { + Some(0) + }; + // If there is a next buffer, find the adjusted chunk in that buffer. If there's no next buffer, we're at the end of the file. + let end = if let Some(next_buffer) = next_buffer { + let end = next_line_position( + &next_buffer[..next_buffer_len], + 0, + num_fields, + quote_char, + field_delimiter, + escape_char, + double_quote_escape_allowed, + ); + match end { + Some(_) => end, + None => return (None, None), // If the record size is >= 4MB, return None and fallback. + } + } else { + None + }; + (start, end) +} + +/// Helper function that finds the first valid record terminator in a buffer. +fn next_line_position( + buffer: &[u8], + offset: usize, + num_fields: usize, + quote_char: u8, + field_delimiter: u8, + escape_char: Option, + double_quote_escape_allowed: bool, +) -> Option { + let mut start = offset; + loop { + start = match newline_position(&buffer[start..]) { + // Start reading after the first record terminator from the start of the chunk. + Some(pos) => start + pos + 1, + None => return None, + }; + if start >= buffer.len() { + return None; + } + if validate_csv_record( + &buffer[start..], + num_fields, + quote_char, + field_delimiter, + escape_char, + double_quote_escape_allowed, + ) { + return Some(start); + } + } +} + +// Daft does not currently support non-\n record terminators (e.g. carriage return \r, which only +// matters for pre-Mac OS X). +const NEWLINE: u8 = b'\n'; +const DOUBLE_QUOTE: u8 = b'"'; +const DEFAULT_CHUNK_SIZE: usize = 4 * 1024 * 1024; // 1MiB. TODO(desmond): This should be tuned. + +/// Helper function that finds the first new line character (\n) in the given byte slice. +fn newline_position(buffer: &[u8]) -> Option { + // Assuming we are searching for the ASCII `\n` character, we don't need to do any special + // handling for UTF-8, since a `\n` value always corresponds to an ASCII `\n`. + // For more details, see: https://en.wikipedia.org/wiki/UTF-8#Encoding + memchr::memchr(NEWLINE, buffer) +} + +/// Csv states used by the state machine in `validate_csv_record`. +#[derive(Clone)] +enum CsvState { + FieldStart, + RecordEnd, + UnquotedField, + QuotedField, + Unquote, +} + +/// State machine that validates whether the current buffer starts at a valid csv record. +/// See `get_file_chunk` for more details. +fn validate_csv_record( + buffer: &[u8], + num_fields: usize, + quote_char: u8, + field_delimiter: u8, + escape_char: Option, + double_quote_escape_allowed: bool, +) -> bool { + let mut state = CsvState::FieldStart; + let mut index = 0; + let mut num_fields_seen = 0; + loop { + if index >= buffer.len() { + // We've reached the end of the buffer without seeing a valid record. + return false; + } + match state { + CsvState::FieldStart => { + let byte = buffer[index]; + if byte == NEWLINE { + state = CsvState::RecordEnd; + } else if byte == quote_char { + state = CsvState::QuotedField; + index += 1; + } else { + state = CsvState::UnquotedField; + } + } + CsvState::RecordEnd => { + return num_fields_seen == num_fields; + } + CsvState::UnquotedField => { + // We follow the convention where an unquoted field does consider escape characters. + while index < buffer.len() { + let byte = buffer[index]; + if byte == NEWLINE { + num_fields_seen += 1; + state = CsvState::RecordEnd; + break; + } + if byte == field_delimiter { + num_fields_seen += 1; + state = CsvState::FieldStart; + index += 1; + break; + } + index += 1; + } + } + CsvState::QuotedField => { + while index < buffer.len() { + let byte = buffer[index]; + if byte == quote_char { + state = CsvState::Unquote; + index += 1; + break; + } + if let Some(escape_char) = escape_char + && byte == escape_char + { + // Skip the next character. + index += 1; + } + index += 1; + } + } + CsvState::Unquote => { + let byte = buffer[index]; + if let Some(escape_char) = escape_char + && byte == escape_char + && escape_char == quote_char + && (byte != DOUBLE_QUOTE || double_quote_escape_allowed) + { + state = CsvState::QuotedField; + index += 1; + continue; + } + if byte == NEWLINE { + num_fields_seen += 1; + state = CsvState::RecordEnd; + continue; + } + if byte == field_delimiter { + num_fields_seen += 1; + state = CsvState::FieldStart; + index += 1; + continue; + } + // Other characters are not allowed after a quote. This is invalid CSV. + return false; + } + } + } +} + +/// Helper function that takes in a BufferSource, calls parse_csv() to extract table values from +/// the buffer source, then streams the results to `sender`. +#[allow(clippy::too_many_arguments)] +fn dispatch_to_parse_csv( + has_header: bool, + parse_options: CsvParseOptions, + buffer_source: BufferSource, + projection_indices: Arc>, + fields: Vec, + read_daft_fields: Arc>>, + read_schema: Arc, + csv_buffer: CsvBuffer, + include_columns: &Option>, + predicate: Option>, + sender: Sender>, + rows_read: Arc, +) { + let table_results = { + let rdr = ReaderBuilder::new() + .has_headers(has_header) + .delimiter(parse_options.delimiter) + .double_quote(parse_options.double_quote) + .quote(parse_options.quote) + .escape(parse_options.escape_char) + .comment(parse_options.comment) + .flexible(parse_options.allow_variable_columns) + .from_reader(buffer_source); + parse_csv_chunk( + rdr, + projection_indices, + fields, + read_daft_fields, + read_schema, + csv_buffer, + include_columns, + predicate, + ) + }; + match table_results { + Ok(tables) => { + for table in tables { + let table_len = table.len(); + sender.send(Ok(table)).unwrap(); + // Atomically update the number of rows read only after the result has + // been sent. In theory we could wrap these steps in a mutex, but + // applying limit at this layer can be best-effort with no adverse + // side effects. + rows_read.fetch_add(table_len, Ordering::Relaxed); + } + } + Err(e) => sender.send(Err(e)).unwrap(), + } +} + +/// Helper function that consumes a CSV reader and turns it into a vector of Daft tables. +#[allow(clippy::too_many_arguments)] +fn parse_csv_chunk( + mut reader: Reader, + projection_indices: Arc>, + fields: Vec, + read_daft_fields: Arc>>, + read_schema: Arc, + csv_buffer: CsvBuffer, + include_columns: &Option>, + predicate: Option>, +) -> DaftResult> +where + R: std::io::Read, +{ + let mut chunk_buffer = csv_buffer.buffer; + let mut tables = vec![]; + loop { + //let time = Instant::now(); + let (rows_read, has_more) = + local_read_rows(&mut reader, chunk_buffer.as_mut_slice()).context(ArrowSnafu {})?; + //let time = Instant::now(); + let chunk = projection_indices + .par_iter() + .enumerate() + .map(|(i, proj_idx)| { + let deserialized_col = deserialize_column( + &chunk_buffer[0..rows_read], + *proj_idx, + fields[*proj_idx].data_type().clone(), + 0, + ); + Series::try_from_field_and_arrow_array( + read_daft_fields[i].clone(), + cast_array_for_daft_if_needed(deserialized_col?), + ) + }) + .collect::>>()?; + let num_rows = chunk.first().map(|s| s.len()).unwrap_or(0); + let table = Table::new_unchecked(read_schema.clone(), chunk, num_rows); + let table = if let Some(predicate) = &predicate { + let filtered = table.filter(&[predicate.clone()])?; + if let Some(include_columns) = &include_columns { + filtered.get_columns(include_columns.as_slice())? + } else { + filtered + } + } else { + table + }; + tables.push(table); + + // The number of record might exceed the number of byte records we've allocated. + // Retry until all byte records in this chunk are read. + if !has_more { + break; + } + } + csv_buffer.pool.return_buffer(chunk_buffer); + Ok(tables) +} diff --git a/src/daft-csv/src/read.rs b/src/daft-csv/src/read.rs index 3ed49204b3..9609e7bb63 100644 --- a/src/daft-csv/src/read.rs +++ b/src/daft-csv/src/read.rs @@ -1,19 +1,16 @@ -use std::{collections::HashMap, num::NonZeroUsize, sync::Arc, sync::Mutex}; - -use std::sync::atomic::{AtomicUsize, Ordering}; +use core::str; +use std::{collections::HashMap, num::NonZeroUsize, sync::Arc}; use arrow2::{ datatypes::Field, - io::csv::read, - io::csv::read::{Reader, ReaderBuilder}, io::csv::read_async, - io::csv::read_async::{local_read_rows, read_rows, AsyncReaderBuilder}, + io::csv::read_async::{read_rows, AsyncReaderBuilder}, }; use async_compat::{Compat, CompatExt}; use common_error::{DaftError, DaftResult}; use csv_async::AsyncReader; use daft_core::{schema::Schema, utils::arrow::cast_array_for_daft_if_needed, Series}; -use daft_dsl::{optimization::get_required_columns, Expr}; +use daft_dsl::optimization::get_required_columns; use daft_io::{get_runtime, parse_url, GetResult, IOClient, IOStatsRef, SourceType}; use daft_table::Table; use futures::{stream::BoxStream, Stream, StreamExt, TryStreamExt}; @@ -32,8 +29,8 @@ use tokio::{ }; use tokio_util::io::StreamReader; +use crate::ArrowSnafu; use crate::{metadata::read_csv_schema_single, CsvConvertOptions, CsvParseOptions, CsvReadOptions}; -use crate::{ArrowSnafu, StdIOSnafu}; use daft_compression::CompressionCodec; use daft_decoding::deserialize::deserialize_column; @@ -43,6 +40,8 @@ impl ByteRecordChunkStream for S where { } +use crate::{local::read_csv_local, local::stream_csv_local}; + type TableChunkResult = super::Result>, super::JoinSnafu, super::Error>>; trait TableStream: Stream {} @@ -156,16 +155,14 @@ pub async fn stream_csv( let uri = uri.as_str(); let (source_type, _) = parse_url(uri)?; let is_compressed = CompressionCodec::from_uri(uri).is_some(); - let use_local_reader = false; // TODO(desmond): Feature under dev. - if matches!(source_type, SourceType::File) && !is_compressed && use_local_reader { + if matches!(source_type, SourceType::File) && !is_compressed { let stream = stream_csv_local( uri, convert_options, parse_options.unwrap_or_default(), read_options, max_chunks_in_flight, - ) - .await?; + )?; Ok(Box::pin(stream)) } else { let stream = stream_csv_single( @@ -182,7 +179,7 @@ pub async fn stream_csv( } } -fn tables_concat(mut tables: Vec
) -> DaftResult
{ +pub fn tables_concat(mut tables: Vec
) -> DaftResult
{ if tables.is_empty() { return Err(DaftError::ValueError( "Need at least 1 Table to perform concat".to_string(), @@ -220,387 +217,6 @@ fn tables_concat(mut tables: Vec
) -> DaftResult
{ ) } -#[derive(Debug)] -struct CsvBufferPool { - buffers: Mutex>>, - buffer_size: usize, - record_buffer_size: usize, - num_fields: usize, -} - -struct CsvBufferPoolRef<'a> { - pool: &'a CsvBufferPool, - buffer: Vec, -} - -impl CsvBufferPool { - pub fn new( - record_buffer_size: usize, - num_fields: usize, - chunk_size_rows: usize, - initial_pool_size: usize, - ) -> Self { - let chunk_buffers = vec![ - vec![ - read::ByteRecord::with_capacity(record_buffer_size, num_fields); - chunk_size_rows - ]; - initial_pool_size - ]; - CsvBufferPool { - buffers: Mutex::new(chunk_buffers), - buffer_size: chunk_size_rows, - record_buffer_size, - num_fields, - } - } - - pub fn get_buffer(&self) -> CsvBufferPoolRef { - let mut buffers = self.buffers.lock().unwrap(); - let buffer = buffers.pop(); - let buffer = match buffer { - Some(buffer) => buffer, - None => { - vec![ - read::ByteRecord::with_capacity(self.record_buffer_size, self.num_fields); - self.buffer_size - ] - } - }; - - CsvBufferPoolRef { pool: self, buffer } - } - - fn return_buffer(&self, buffer: Vec) { - let mut buffers = self.buffers.lock().unwrap(); - buffers.push(buffer); - } -} - -// Daft does not currently support non-\n record terminators (e.g. carriage return \r, which only -// matters for pre-Mac OS X). -const NEWLINE: u8 = b'\n'; -const DEFAULT_CHUNK_SIZE: usize = 4 * 1024 * 1024; // 1MiB. TODO(desmond): This should be tuned. - -/// Helper function that finds the first new line character (\n) in the given byte slice. -fn next_line_position(input: &[u8]) -> Option { - // Assuming we are searching for the ASCII `\n` character, we don't need to do any special - // handling for UTF-8, since a `\n` value always corresponds to an ASCII `\n`. - // For more details, see: https://en.wikipedia.org/wiki/UTF-8#Encoding - memchr::memchr(NEWLINE, input) -} - -/// Helper function that determines what chunk of data to parse given a starting position within the -/// file, and the desired initial chunk size. -/// -/// Given a starting position, we use our chunk size to compute a preliminary start and stop -/// position. For example, we can visualize all preliminary chunks in a file as follows. -/// -/// Chunk 1 Chunk 2 Chunk 3 Chunk N -/// ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ -/// │ │ │\n │ │ \n │ │ \n │ -/// │ │ │ │ │ │ │ │ -/// │ │ │ \n │ │ │ │ │ -/// │ \n │ │ │ │ \n │ │ │ -/// │ │ │ │ │ │ ... │ \n │ -/// │ │ │ \n │ │ │ │ │ -/// │ \n │ │ │ │ │ │ │ -/// │ │ │ │ │ \n │ │ \n │ -/// └──────────┘ └──────────┘ └──────────┘ └──────────┘ -/// -/// However, record boundaries (i.e. the \n terminators) do not align nicely with these preliminary -/// chunk boundaries. So we adjust each preliminary chunk as follows: -/// - Find the first record terminator from the chunk's start. This is the new starting position. -/// - Find the first record terminator from the chunk's end. This is the new ending position. -/// - If a given preliminary chunk doesn't contain a record terminator, the adjusted chunk is empty. -/// -/// For example: -/// -/// Adjusted Chunk 1 Adj. Chunk 2 Adj. Chunk 3 Adj. Chunk N -/// ┌──────────────────┐┌─────────────────┐ ┌────────┐ ┌─┐ -/// │ \n││ \n│ │ \n│ \n │ │ -/// │ ┌───────┘│ ┌──────────┘ │ ┌─────┘ │ │ -/// │ │ ┌───┘ \n │ ┌───────┘ │ ┌────────┘ │ -/// │ \n │ │ │ │ \n │ │ │ -/// │ │ │ │ │ │ ... │ \n │ -/// │ │ │ \n │ │ │ │ │ -/// │ \n │ │ │ │ │ │ │ -/// │ │ │ │ │ \n │ │ \n │ -/// └──────────┘ └──────────┘ └──────────┘ └──────────┘ -/// -/// Using this method, we now have adjusted chunks that are aligned with record boundaries, that do -/// not overlap, and that fully cover every byte in the CSV file. Parsing each adjusted chunk can -/// now happen in parallel. -/// -/// This is the same method as described in: -/// Ge, Chang et al. “Speculative Distributed CSV Data Parsing for Big Data Analytics.” Proceedings of the 2019 International Conference on Management of Data (2019). -fn get_file_chunk(bytes: &[u8], start: usize, chunk_size: usize) -> Option<(usize, usize)> { - let stop = start + chunk_size; - let start = if start == 0 { - 0 - } else { - match next_line_position(&bytes[start..]) { - // Start reading after the first record terminator from the start of the chunk. - Some(pos) => start + pos + 1, - // If there's no record terminator found, then the previous chunk reader would have - // consumed the current chunk. Hence, we skip. - None => return None, - } - }; - // If the first record terminator comes after this chunk, then the previous chunk reader would - // have consumed the current chunk. Hence, we skip. - if start > stop { - return None; - } - let stop = if stop < bytes.len() { - match next_line_position(&bytes[stop..]) { - // Read up to the first terminator from the end of the chunk. - Some(pos) => stop + pos, - None => bytes.len(), - } - } else { - bytes.len() - }; - Some((start, stop)) -} - -#[allow(clippy::too_many_arguments)] -fn parse_csv_chunk( - mut reader: Reader, - projection_indices: Arc>, - fields: Vec, - read_daft_fields: Arc>>, - read_schema: Arc, - buf: CsvBufferPoolRef, - include_columns: &Option>, - predicate: Option>, -) -> DaftResult> -where - R: std::io::Read, -{ - let mut chunk_buffer = buf.buffer; - let mut tables = vec![]; - loop { - let (rows_read, has_more) = - local_read_rows(&mut reader, chunk_buffer.as_mut_slice()).context(ArrowSnafu {})?; - let chunk = projection_indices - .par_iter() - .enumerate() - .map(|(i, proj_idx)| { - let deserialized_col = deserialize_column( - &chunk_buffer[0..rows_read], - *proj_idx, - fields[*proj_idx].data_type().clone(), - 0, - ); - Series::try_from_field_and_arrow_array( - read_daft_fields[i].clone(), - cast_array_for_daft_if_needed(deserialized_col?), - ) - }) - .collect::>>()?; - let num_rows = chunk.first().map(|s| s.len()).unwrap_or(0); - let table = Table::new_unchecked(read_schema.clone(), chunk, num_rows); - let table = if let Some(predicate) = &predicate { - let filtered = table.filter(&[predicate.clone()])?; - if let Some(include_columns) = &include_columns { - filtered.get_columns(include_columns.as_slice())? - } else { - filtered - } - } else { - table - }; - tables.push(table); - - // The number of record might exceed the number of byte records we've allocated. - // Retry until all byte records in this chunk are read. - if !has_more { - break; - } - } - buf.pool.return_buffer(chunk_buffer); - Ok(tables) -} - -async fn stream_csv_local( - uri: &str, - convert_options: Option, - parse_options: CsvParseOptions, - read_options: Option, - max_chunks_in_flight: Option, -) -> DaftResult> + Send> { - let uri = uri.trim_start_matches("file://"); - let file = std::fs::File::open(uri)?; - let mmap = unsafe { memmap2::Mmap::map(&file) }.context(StdIOSnafu)?; - let bytes = &mmap[..]; - - // TODO(desmond): This logic is repeated multiple times in this file. Should dedup. - let predicate = convert_options - .as_ref() - .and_then(|opts| opts.predicate.clone()); - - let limit = convert_options.as_ref().and_then(|opts| opts.limit); - - let include_columns = convert_options - .as_ref() - .and_then(|opts| opts.include_columns.clone()); - - let convert_options = match (convert_options, &predicate) { - (None, _) => None, - (co, None) => co, - (Some(mut co), Some(predicate)) => { - if let Some(ref mut include_columns) = co.include_columns { - let required_columns_for_predicate = get_required_columns(predicate); - for rc in required_columns_for_predicate { - if include_columns.iter().all(|c| c.as_str() != rc.as_str()) { - include_columns.push(rc) - } - } - } - // If we have a limit and a predicate, remove limit for stream. - co.limit = None; - Some(co) - } - } - .unwrap_or_default(); - // End of `should dedup`. - - // TODO(desmond): We should do better schema inference here. - let schema = convert_options.clone().schema.unwrap().to_arrow()?; - let n_threads: usize = std::thread::available_parallelism() - .unwrap_or(NonZeroUsize::new(2).unwrap()) - .into(); - let chunk_size = read_options - .as_ref() - .and_then(|opt| opt.chunk_size.or_else(|| opt.buffer_size.map(|bs| bs / 8))) - .unwrap_or(DEFAULT_CHUNK_SIZE); - let projection_indices = fields_to_projection_indices( - &schema.clone().fields, - &convert_options.clone().include_columns, - ); - let fields = schema.clone().fields; - let fields_subset = projection_indices - .iter() - .map(|i| fields.get(*i).unwrap().into()) - .collect::>(); - let read_schema = Arc::new(daft_core::schema::Schema::new(fields_subset)?); - let read_daft_fields = Arc::new( - read_schema - .fields - .values() - .map(|f| Arc::new(f.clone())) - .collect::>(), - ); - // TODO(desmond): Need better upfront estimators. Sample or keep running count of stats. - let estimated_mean_row_size = 100f64; - let estimated_std_row_size = 20f64; - let record_buffer_size = (estimated_mean_row_size + estimated_std_row_size).ceil() as usize; - let chunk_size_rows = (chunk_size as f64 / record_buffer_size as f64).ceil() as usize; - // TODO(desmond): We don't want to create a per-read buffer pool, we want one pool shared with - // the whole process. - let buffer_pool = CsvBufferPool::new( - record_buffer_size, - schema.fields.len(), - chunk_size_rows, - n_threads * 2, - ); - let chunk_offsets: Vec = (0..=bytes.len()).step_by(chunk_size).collect(); - // TODO(desmond): Memory usage is still growing during execution of a .count(*).collect(), so - // the following approach still isn't quite right. - // TODO(desmond): Also, is this usage of max_chunks_in_flight correct? - let (sender, receiver) = - crossbeam_channel::bounded(max_chunks_in_flight.unwrap_or(n_threads * 2)); - let rows_read = AtomicUsize::new(0); - rayon::spawn(move || { - let bytes = &mmap[..]; - chunk_offsets.into_par_iter().for_each(|start| { - // TODO(desmond): Use try_for_each and terminate early once the limit is reached. - let limit_reached = limit.map_or(false, |limit| { - let current_rows_read = rows_read.load(Ordering::Relaxed); - current_rows_read >= limit - }); - - if !limit_reached && let Some((start, stop)) = get_file_chunk(bytes, start, chunk_size) - { - let buf = buffer_pool.get_buffer(); - let chunk = &bytes[start..stop]; - // Only the first chunk might potentially have headers. Subsequent chunks should - // read all rows as records. - let has_headers = start == 0 && parse_options.has_header; - let rdr = ReaderBuilder::new() - .has_headers(has_headers) - .delimiter(parse_options.delimiter) - .double_quote(parse_options.double_quote) - // TODO(desmond): We need to handle the quoted case properly. - .quote(parse_options.quote) - .escape(parse_options.escape_char) - .comment(parse_options.comment) - .flexible(parse_options.allow_variable_columns) - .from_reader(chunk); - Reader::from_reader(chunk); - let table_results = parse_csv_chunk( - rdr, - projection_indices.clone(), - fields.clone(), - read_daft_fields.clone(), - read_schema.clone(), - buf, - &include_columns, - predicate.clone(), - ); - match table_results { - Ok(tables) => { - for table in tables { - let table_len = table.len(); - sender.send(Ok(table)).unwrap(); - // Atomically update the number of rows read only after the result has - // been sent. In theory we could wrap these steps in a mutex, but - // applying limit at this layer can be best-effort with no adverse - // side effects. - rows_read.fetch_add(table_len, Ordering::Relaxed); - } - } - Err(e) => sender.send(Err(e)).unwrap(), - } - } - }) - }); - - let result_stream = futures::stream::iter(receiver); - Ok(result_stream) -} - -async fn tables_stream_collect(stream: BoxStream<'static, DaftResult
>) -> Vec
{ - stream - .filter_map(|result| async { - match result { - Ok(table) => Some(table), - Err(_) => None, // Skips errors; you could log them or handle differently - } - }) - .collect() - .await -} - -async fn read_csv_local( - uri: &str, - convert_options: Option, - parse_options: CsvParseOptions, - read_options: Option, - max_chunks_in_flight: Option, -) -> DaftResult
{ - let stream = stream_csv_local( - uri, - convert_options, - parse_options, - read_options, - max_chunks_in_flight, - ) - .await?; - tables_concat(tables_stream_collect(Box::pin(stream)).await) -} - async fn read_csv_single_into_table( uri: &str, convert_options: Option, @@ -612,8 +228,7 @@ async fn read_csv_single_into_table( ) -> DaftResult
{ let (source_type, _) = parse_url(uri)?; let is_compressed = CompressionCodec::from_uri(uri).is_some(); - let use_local_reader = false; // TODO(desmond): Feature under dev. - if matches!(source_type, SourceType::File) && !is_compressed && use_local_reader { + if matches!(source_type, SourceType::File) && !is_compressed { return read_csv_local( uri, convert_options, @@ -1054,7 +669,7 @@ fn parse_into_column_array_chunk_stream( })) } -fn fields_to_projection_indices( +pub fn fields_to_projection_indices( fields: &[arrow2::datatypes::Field], include_columns: &Option>, ) -> Arc> { From 409ca486643f1569ce74f9671f8a0f39cd151af3 Mon Sep 17 00:00:00 2001 From: desmondcheongzx Date: Sat, 28 Sep 2024 04:39:14 -0700 Subject: [PATCH 03/10] Disable feature --- src/daft-csv/src/read.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/daft-csv/src/read.rs b/src/daft-csv/src/read.rs index 9609e7bb63..1bd8e5b994 100644 --- a/src/daft-csv/src/read.rs +++ b/src/daft-csv/src/read.rs @@ -155,7 +155,8 @@ pub async fn stream_csv( let uri = uri.as_str(); let (source_type, _) = parse_url(uri)?; let is_compressed = CompressionCodec::from_uri(uri).is_some(); - if matches!(source_type, SourceType::File) && !is_compressed { + let use_local = false; + if matches!(source_type, SourceType::File) && !is_compressed && use_local { let stream = stream_csv_local( uri, convert_options, @@ -228,7 +229,8 @@ async fn read_csv_single_into_table( ) -> DaftResult
{ let (source_type, _) = parse_url(uri)?; let is_compressed = CompressionCodec::from_uri(uri).is_some(); - if matches!(source_type, SourceType::File) && !is_compressed { + let use_local = false; + if matches!(source_type, SourceType::File) && !is_compressed && use_local { return read_csv_local( uri, convert_options, From 445bfce9bcc776acf2547e6491f5fede2de8d1b6 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Sat, 28 Sep 2024 11:41:54 -0700 Subject: [PATCH 04/10] rename as per clippy specifications upper_case_acronyms --- src/daft-csv/src/compression.rs | 66 --------------------------------- src/daft-csv/src/lib.rs | 12 +++--- 2 files changed, 6 insertions(+), 72 deletions(-) delete mode 100644 src/daft-csv/src/compression.rs diff --git a/src/daft-csv/src/compression.rs b/src/daft-csv/src/compression.rs deleted file mode 100644 index 268b1566d9..0000000000 --- a/src/daft-csv/src/compression.rs +++ /dev/null @@ -1,66 +0,0 @@ -use async_compression::tokio::bufread::{ - BrotliDecoder, BzDecoder, DeflateDecoder, GzipDecoder, LzmaDecoder, XzDecoder, ZlibDecoder, - ZstdDecoder, -}; -use std::{path::PathBuf, pin::Pin}; -use tokio::io::{AsyncBufRead, AsyncRead}; -use url::Url; - -#[derive(Debug)] -pub enum CompressionCodec { - Brotli, - Bz, - Deflate, - Gzip, - Lzma, - Xz, - Zlib, - Zstd, -} - -impl CompressionCodec { - pub fn from_uri(uri: &str) -> Option { - let url = Url::parse(uri); - let path = match &url { - Ok(url) => url.path(), - _ => uri, - }; - let extension = PathBuf::from(path) - .extension()? - .to_string_lossy() - .to_string(); - Self::from_extension(extension.as_ref()) - } - pub fn from_extension(extension: &str) -> Option { - use CompressionCodec::*; - match extension { - "br" => Some(Brotli), - "bz2" => Some(Bz), - "deflate" => Some(Deflate), - "gz" => Some(Gzip), - "lzma" => Some(Lzma), - "xz" => Some(Xz), - "zl" => Some(Zlib), - "zstd" | "zst" => Some(Zstd), - "snappy" => todo!("Snappy compression support not yet implemented"), - _ => None, - } - } - - pub fn to_decoder( - &self, - reader: T, - ) -> Pin> { - use CompressionCodec::*; - match self { - Brotli => Box::pin(BrotliDecoder::new(reader)), - Bz => Box::pin(BzDecoder::new(reader)), - Deflate => Box::pin(DeflateDecoder::new(reader)), - Gzip => Box::pin(GzipDecoder::new(reader)), - Lzma => Box::pin(LzmaDecoder::new(reader)), - Xz => Box::pin(XzDecoder::new(reader)), - Zlib => Box::pin(ZlibDecoder::new(reader)), - Zstd => Box::pin(ZstdDecoder::new(reader)), - } - } -} diff --git a/src/daft-csv/src/lib.rs b/src/daft-csv/src/lib.rs index a5a93ca56b..2546a525a4 100644 --- a/src/daft-csv/src/lib.rs +++ b/src/daft-csv/src/lib.rs @@ -24,11 +24,11 @@ pub use read::{read_csv, read_csv_bulk, stream_csv}; #[derive(Debug, Snafu)] pub enum Error { #[snafu(display("{source}"))] - IOError { source: daft_io::Error }, + IoError { source: daft_io::Error }, #[snafu(display("{source}"))] - StdIOError { source: std::io::Error }, + StdIoError { source: std::io::Error }, #[snafu(display("{source}"))] - CSVError { source: csv_async::Error }, + CsvError { source: csv_async::Error }, #[snafu(display("Invalid char: {}", val))] WrongChar { source: std::char::TryFromCharError, @@ -50,7 +50,7 @@ pub enum Error { impl From for DaftError { fn from(err: Error) -> DaftError { match err { - Error::IOError { source } => source.into(), + Error::IoError { source } => source.into(), _ => DaftError::External(err.into()), } } @@ -58,12 +58,12 @@ impl From for DaftError { impl From for Error { fn from(err: daft_io::Error) -> Self { - Error::IOError { source: err } + Error::IoError { source: err } } } #[cfg(feature = "python")] -impl From for pyo3::PyErr { +impl From for PyErr { fn from(value: Error) -> Self { let daft_error: DaftError = value.into(); daft_error.into() From e5f530ba52d3a50d369152730ac5c201e6a20aa9 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Thu, 3 Oct 2024 22:48:44 -0700 Subject: [PATCH 05/10] wip feel free to go to an earlier commit hash (this is not working) --- Cargo.lock | 48 ++ Cargo.toml | 3 + src/daft-csv/Cargo.toml | 3 + src/daft-csv/src/local.rs | 634 +++++++++++------- src/daft-csv/src/local/pool.rs | 247 +++++++ .../src/local/pool/fixed_capacity_vec.rs | 114 ++++ src/daft-csv/src/read.rs | 2 +- 7 files changed, 826 insertions(+), 225 deletions(-) create mode 100644 src/daft-csv/src/local/pool.rs create mode 100644 src/daft-csv/src/local/pool/fixed_capacity_vec.rs diff --git a/Cargo.lock b/Cargo.lock index 844617f3a0..8f75823d6b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1744,7 +1744,9 @@ dependencies = [ "daft-dsl", "daft-io", "daft-table", + "derive_more", "futures", + "heapless", "indexmap 2.3.0", "memchr", "memmap2", @@ -1754,6 +1756,7 @@ dependencies = [ "serde", "snafu", "tokio", + "tokio-stream", "tokio-util", "url", ] @@ -2152,6 +2155,26 @@ dependencies = [ "serde", ] +[[package]] +name = "derive_more" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a9b99b9cbbe49445b21764dc0625032a89b145a2642e67603e1c936f5458d05" +dependencies = [ + "derive_more-impl", +] + +[[package]] +name = "derive_more-impl" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.74", +] + [[package]] name = "diff" version = "0.1.13" @@ -2690,6 +2713,15 @@ dependencies = [ "serde", ] +[[package]] +name = "hash32" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47d60b12902ba28e2730cd37e95b8c9223af2808df9e902d4df49588d1470606" +dependencies = [ + "byteorder", +] + [[package]] name = "hash_hasher" version = "2.0.3" @@ -2712,6 +2744,16 @@ dependencies = [ "allocator-api2", ] +[[package]] +name = "heapless" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bfb9eb618601c89945a70e254898da93b13be0388091d42117462b265bb3fad" +dependencies = [ + "hash32", + "stable_deref_trait", +] + [[package]] name = "heck" version = "0.3.3" @@ -4918,6 +4960,12 @@ dependencies = [ "log", ] +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + [[package]] name = "static_assertions" version = "1.1.0" diff --git a/Cargo.toml b/Cargo.toml index 5d2a771aa1..d79e83c277 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -164,6 +164,9 @@ sqlparser = "0.49.0" sysinfo = "0.30.12" test-log = "0.2.16" tiktoken-rs = "0.5.9" +derive_more = "1.0.0" +heapless = "0.8.0" + tokio = {version = "1.37.0", features = [ "net", "time", diff --git a/src/daft-csv/Cargo.toml b/src/daft-csv/Cargo.toml index d0781e7184..db862d394b 100644 --- a/src/daft-csv/Cargo.toml +++ b/src/daft-csv/Cargo.toml @@ -23,6 +23,9 @@ serde = {workspace = true} snafu = {workspace = true} tokio = {workspace = true} tokio-util = {workspace = true} +tokio-stream = {workspace = true} +derive_more = { workspace = true, features = ["deref", "deref_mut"] } +heapless = { workspace = true } url = {workspace = true} [dev-dependencies] diff --git a/src/daft-csv/src/local.rs b/src/daft-csv/src/local.rs index 3bd73f2ef0..5b5da9514f 100644 --- a/src/daft-csv/src/local.rs +++ b/src/daft-csv/src/local.rs @@ -23,9 +23,12 @@ use rayon::{ prelude::{IntoParallelRefIterator, ParallelIterator}, }; use snafu::ResultExt; - +use crate::local::pool::{read_slabs_windowed, WindowedSlab}; use crate::read::{fields_to_projection_indices, tables_concat}; +mod pool; + + #[allow(clippy::doc_lazy_continuation)] /// Our local CSV reader has the following approach to reading CSV files: /// 1. Read the CSV file in 4MB chunks from a slab pool. @@ -118,61 +121,10 @@ impl CsvBufferPool { } } -// The default size of a slab used for reading CSV files in chunks. Currently set to 4MB. -const SLABSIZE: usize = 4 * 1024 * 1024; +// The default size of a slab used for reading CSV files in chunks. Currently set to 4MiB. +const SLAB_SIZE: usize = 4 * 1024 * 1024; // The default number of slabs in a slab pool. -const SLABPOOL_DEFAULT_SIZE: usize = 20; - -/// A pool of 4MB slabs. Used for reading CSV files in 4MB chunks. -#[derive(Debug)] -struct SlabPool { - buffers: Mutex>>, - condvar: Condvar, -} - -/// A 4MB slab of bytes. Used for reading CSV files in 4MB chunks. -#[derive(Clone)] -struct Slab { - pool: Arc, - // We wrap the Arc in an Option so that when a Slab is being dropped, we can move the Slab's reference - // to the Arc back to the slab pool. - buffer: Option>, -} - -impl Drop for Slab { - fn drop(&mut self) { - // Move the buffer back to the slab pool. - if let Some(buffer) = self.buffer.take() { - self.pool.return_buffer(buffer); - } - } -} - -impl SlabPool { - pub fn new() -> Self { - let chunk_buffers: Vec> = (0..SLABPOOL_DEFAULT_SIZE) - .map(|_| Arc::new([0; SLABSIZE])) - .collect(); - SlabPool { - buffers: Mutex::new(chunk_buffers), - condvar: Condvar::new(), - } - } - - pub fn get_buffer(self: &Arc) -> Arc<[u8; SLABSIZE]> { - let mut buffers = self.buffers.lock().unwrap(); - while buffers.is_empty() { - buffers = self.condvar.wait(buffers).unwrap(); - } - buffers.pop().unwrap() - } - - fn return_buffer(&self, buffer: Arc<[u8; SLABSIZE]>) { - let mut buffers = self.buffers.lock().unwrap(); - buffers.push(buffer); - self.condvar.notify_one(); - } -} +const SLAB_POOL_DEFAULT_SIZE: usize = 20; /// A data structure that holds either a single slice of bytes, or a chain of two slices of bytes. /// See the description to `parse_json` for more details. @@ -205,7 +157,7 @@ pub async fn read_csv_local( parse_options, read_options, max_chunks_in_flight, - )?; + ).await?; tables_concat(tables_stream_collect(Box::pin(stream)).await) } @@ -221,15 +173,15 @@ async fn tables_stream_collect(stream: BoxStream<'static, DaftResult
>) -> .await } -pub fn stream_csv_local( +pub async fn stream_csv_local( uri: &str, convert_options: Option, parse_options: CsvParseOptions, read_options: Option, max_chunks_in_flight: Option, -) -> DaftResult> + Send> { +) -> DaftResult> + Send> { let uri = uri.trim_start_matches("file://"); - let file = std::fs::File::open(uri)?; + let file = tokio::fs::File::open(uri).await?; // TODO(desmond): This logic is repeated multiple times in the csv reader files. Should dedup. let predicate = convert_options @@ -259,7 +211,7 @@ pub fn stream_csv_local( Some(co) } } - .unwrap_or_default(); + .unwrap_or_default(); // End of `should dedup`. // TODO(desmond): We should do better schema inference here. @@ -305,8 +257,8 @@ pub fn stream_csv_local( // We suppose that each slab of CSV data produces (chunk size / slab size) number of Daft tables. We // then double this capacity to ensure that our channel is never full and our threads won't deadlock. let (sender, receiver) = - crossbeam_channel::bounded(max_chunks_in_flight.unwrap_or(2 * chunk_size / SLABSIZE)); - rayon::spawn(move || { + crossbeam_channel::bounded(max_chunks_in_flight.unwrap_or(2 * chunk_size / SLAB_SIZE)); + tokio::spawn(async move || { consume_csv_file( file, buffer_pool, @@ -330,9 +282,8 @@ pub fn stream_csv_local( /// Consumes the CSV file and sends the results to `sender`. #[allow(clippy::too_many_arguments)] fn consume_csv_file( - mut file: std::fs::File, + file: tokio::fs::File, buffer_pool: Arc, - slabpool: Arc, parse_options: CsvParseOptions, projection_indices: Arc>, read_daft_fields: Arc>>, @@ -344,189 +295,426 @@ fn consume_csv_file( limit: Option, sender: Sender>, ) { - let rows_read = Arc::new(AtomicUsize::new(0)); - let mut has_header = parse_options.has_header; - let total_len = file.metadata().unwrap().len() as usize; - let field_delimiter = parse_options.delimiter; - let escape_char = parse_options.escape_char; - let quote_char = parse_options.quote; - let double_quote_escape_allowed = parse_options.double_quote; - let mut total_bytes_read = 0; - let mut next_slab = None; - let mut next_buffer_len = 0; - let mut first_buffer = true; - loop { - let limit_reached = limit.map_or(false, |limit| { - let current_rows_read = rows_read.load(Ordering::Relaxed); + let mut csv_consumer = CsvConsumer::new( + file, + buffer_pool, + slabpool, + parse_options, + projection_indices, + read_daft_fields, + read_schema, + fields, + num_fields, + include_columns.clone(), + predicate, + limit, + sender, + ); + + csv_consumer.consume(); +} + +/// A struct representing a CSV consumer that processes and parses CSV data. +struct CsvConsumer { + /// The file being read. + file: tokio::fs::File, + /// Options for parsing the CSV file. + parse_options: CsvParseOptions, + /// Indices of columns to be projected (included in the output). + projection_indices: Arc>, + /// Fields to be read from the CSV, in Daft format. + read_daft_fields: Arc>>, + /// Schema of the data to be read. + read_schema: Arc, + /// Fields of the CSV file. + fields: Vec, + /// Total number of fields in the CSV. + num_fields: usize, + /// Optional list of columns to include in the output. + include_columns: Option>, + /// Optional predicate for filtering rows. + predicate: Option>, + /// Optional limit on the number of rows to read. + limit: Option, + /// Channel sender for sending parsed tables or errors. + sender: Sender>, + /// Atomic counter for the number of rows read. + rows_read: Arc, + /// Flag indicating whether the CSV has a header row. + has_header: bool, + /// Total length of the file in bytes. + total_len: usize, + /// Total number of bytes read so far. + total_bytes_read: usize, + /// Length of the next buffer to be processed. + next_buffer_len: usize, + /// Flag indicating if this is the first buffer being processed. + first_buffer: bool, +} + +impl CsvConsumer { + /// Creates a new CsvConsumer instance. + fn new( + file: tokio::fs::File, + buffer_pool: Arc, + slabpool: Arc, + parse_options: CsvParseOptions, + projection_indices: Arc>, + read_daft_fields: Arc>>, + read_schema: Arc, + fields: Vec, + num_fields: usize, + include_columns: Option>, + predicate: Option>, + limit: Option, + sender: Sender>, + ) -> Self { + let total_len = file.metadata().unwrap().len() as usize; + let has_header = parse_options.has_header; + Self { + file, + buffer_pool, + slabpool, + parse_options, + projection_indices, + read_daft_fields, + read_schema, + fields, + num_fields, + include_columns, + predicate, + limit, + sender, + rows_read: Arc::new(AtomicUsize::new(0)), + has_header, + total_len, + total_bytes_read: 0, + next_slab: None, + next_buffer_len: 0, + first_buffer: true, + } + } + + /// Main method to consume and process the CSV file. + async fn consume(mut self, file: tokio::fs::File) { + let mut pool = read_slabs_windowed(file, SLAB_SIZE, SLAB_POOL_DEFAULT_SIZE); + + loop { + if self.limit_reached() { + break; + } + + let Some(windowed_slab) = pool.next().await else { + break; + }; + + if !self.process_slab(windowed_slab) { + break; + } + self.has_header = false; + if self.total_bytes_read >= self.total_len { + break; + } + } + } + + /// Checks if the row limit has been reached. + fn limit_reached(&self) -> bool { + self.limit.map_or(false, |limit| { + let current_rows_read = self.rows_read.load(Ordering::Relaxed); current_rows_read >= limit - }); - if limit_reached { - break; + }) + } + + /// Processes a single slab of data. + fn process_slab(&mut self, windowed_slab: WindowedSlab) -> bool { + let file_chunk = self.get_file_chunk(¤t_slab, current_buffer_len); + self.first_buffer = false; + + if let (None, _) = file_chunk { + self.slabpool.return_buffer(unsafe_clone_buffer(¤t_slab.buffer)); + return false; } - let (current_slab, current_buffer_len) = match next_slab.take() { + + self.spawn_parse_thread(current_slab, file_chunk); + true + } + + /// Retrieves the current slab to be processed. + fn get_current_slab(&mut self) -> (Arc, usize) { + match self.next_slab.take() { Some(next_slab) => { - total_bytes_read += next_buffer_len; - (next_slab, next_buffer_len) + self.total_bytes_read += self.next_buffer_len; + (next_slab, self.next_buffer_len) } - None => { - let mut buffer = slabpool.get_buffer(); - match Arc::get_mut(&mut buffer) { - Some(inner_buffer) => { - let bytes_read = file.read(inner_buffer).unwrap(); - if bytes_read == 0 { - slabpool.return_buffer(buffer); - break; - } - total_bytes_read += bytes_read; - ( - Arc::new(Slab { - pool: Arc::clone(&slabpool), - buffer: Some(buffer), - }), - bytes_read, - ) - } - None => { - slabpool.return_buffer(buffer); - break; - } + None => self.read_new_slab(), + } + } + + /// Reads a new slab from the file. + fn read_new_slab(&mut self) -> (Arc, usize) { + let mut buffer = self.slabpool.get_buffer(); + match Arc::get_mut(&mut buffer) { + Some(inner_buffer) => { + let bytes_read = self.file.read(inner_buffer).unwrap(); + if bytes_read == 0 { + self.slabpool.return_buffer(buffer); + return (Arc::new(Slab { + pool: Arc::clone(&self.slabpool), + buffer: None, + }), 0); } + self.total_bytes_read += bytes_read; + ( + Arc::new(Slab { + pool: Arc::clone(&self.slabpool), + buffer: Some(buffer), + }), + bytes_read, + ) } - }; - (next_slab, next_buffer_len) = if total_bytes_read < total_len { - let mut next_buffer = slabpool.get_buffer(); + None => { + self.slabpool.return_buffer(buffer); + (Arc::new(Slab { + pool: Arc::clone(&self.slabpool), + buffer: None, + }), 0) + } + } + } + + /// Prepares the next slab for processing. + fn prepare_next_slab(&mut self) { + if self.total_bytes_read < self.total_len { + let mut next_buffer = self.slabpool.get_buffer(); match Arc::get_mut(&mut next_buffer) { Some(inner_buffer) => { - let bytes_read = file.read(inner_buffer).unwrap(); + let bytes_read = self.file.read(inner_buffer).unwrap(); if bytes_read == 0 { - slabpool.return_buffer(next_buffer); - (None, 0) + self.slabpool.return_buffer(next_buffer); + self.next_slab = None; + self.next_buffer_len = 0; } else { - ( - Some(Arc::new(Slab { - pool: Arc::clone(&slabpool), - buffer: Some(next_buffer), - })), - bytes_read, - ) + self.next_slab = Some(Arc::new(Slab { + pool: Arc::clone(&self.slabpool), + buffer: Some(next_buffer), + })); + self.next_buffer_len = bytes_read; } } None => { - slabpool.return_buffer(next_buffer); - break; + self.slabpool.return_buffer(next_buffer); + self.next_slab = None; + self.next_buffer_len = 0; } } } else { - (None, 0) + self.next_slab = None; + self.next_buffer_len = 0; + } + } + + /// Retrieves the chunk of the file to be processed. + fn get_file_chunk(&self, current_slab: &WindowedSlab, first_buffer: bool) -> (Option, Option) { + let (left, right) = match current_slab.as_slice() { + [left, right] => (left.as_slice(), Some(right.as_slice())), + [left] => (left.as_slice(), None), + _ => unreachable!("Unexpected windowed slab size"), }; - let file_chunk = get_file_chunk( - unsafe_clone_buffer(¤t_slab.buffer), - current_buffer_len, - next_slab - .as_ref() - .map(|slab| unsafe_clone_buffer(&slab.buffer)), - next_buffer_len, + + + get_file_chunk( + left, + right, first_buffer, - num_fields, - quote_char, - field_delimiter, - escape_char, - double_quote_escape_allowed, - ); - first_buffer = false; - if let (None, _) = file_chunk { - // Return the buffer. It doesn't matter that we still have a reference to the slab. We're going to fallback - // and the slabs will be useless. - slabpool.return_buffer(unsafe_clone_buffer(¤t_slab.buffer)); - // Exit early before spawning a new thread. - break; - // TODO(desmond): we should fallback instead. - } + self.num_fields, + self.parse_options.quote, + self.parse_options.delimiter, + self.parse_options.escape_char, + self.parse_options.double_quote, + ) + } + + /// Spawns a new thread to parse the CSV chunk. + fn spawn_parse_thread(&self, current_slab: Arc, file_chunk: (Option, Option)) { let current_slab_clone = Arc::clone(¤t_slab); - let next_slab_clone = next_slab.clone(); - let parse_options = parse_options.clone(); - let csv_buffer = buffer_pool.get_buffer(); - let projection_indices = projection_indices.clone(); - let fields = fields.clone(); - let read_daft_fields = read_daft_fields.clone(); - let read_schema = read_schema.clone(); - let include_columns = include_columns.clone(); - let predicate = predicate.clone(); - let sender = sender.clone(); - let rows_read = Arc::clone(&rows_read); + let next_slab_clone = self.next_slab.clone(); + let parse_options = self.parse_options.clone(); + let csv_buffer = self.buffer_pool.get_buffer(); + let projection_indices = self.projection_indices.clone(); + let fields = self.fields.clone(); + let read_daft_fields = self.read_daft_fields.clone(); + let read_schema = self.read_schema.clone(); + let include_columns = self.include_columns.clone(); + let predicate = self.predicate.clone(); + let sender = self.sender.clone(); + let rows_read = Arc::clone(&self.rows_read); + let limit = self.limit; + let has_header = self.has_header; + rayon::spawn(move || { - let limit_reached = limit.map_or(false, |limit| { - let current_rows_read = rows_read.load(Ordering::Relaxed); - current_rows_read >= limit - }); - if !limit_reached { - match file_chunk { - (Some(start), None) => { - if let Some(buffer) = ¤t_slab_clone.buffer { - let buffer_source = BufferSource::Single(Cursor::new( - &buffer[start..current_buffer_len], - )); - dispatch_to_parse_csv( - has_header, - parse_options, - buffer_source, - projection_indices, - fields, - read_daft_fields, - read_schema, - csv_buffer, - &include_columns, - predicate, - sender, - rows_read, - ); - } else { - panic!("Trying to read from a CSV buffer that doesn't exist. Please report this issue.") - } - } - (Some(start), Some(end)) => { - if let Some(next_slab_clone) = next_slab_clone - && let Some(current_buffer) = ¤t_slab_clone.buffer - && let Some(next_buffer) = &next_slab_clone.buffer - { - let buffer_source = BufferSource::Chain(std::io::Read::chain( - Cursor::new(¤t_buffer[start..current_buffer_len]), - Cursor::new(&next_buffer[..end]), - )); - dispatch_to_parse_csv( - has_header, - parse_options, - buffer_source, - projection_indices, - fields, - read_daft_fields, - read_schema, - csv_buffer, - &include_columns, - predicate, - sender, - rows_read, - ); - } else { - panic!("Trying to read from an overflow CSV buffer that doesn't exist. Please report this issue.") - } - } - _ => panic!( - "Something went wrong when parsing the CSV file. Please report this issue." - ), - }; + if !Self::thread_limit_reached(&rows_read, limit) { + Self::parse_chunk( + has_header, + parse_options, + current_slab_clone, + next_slab_clone, + file_chunk, + projection_indices, + fields, + read_daft_fields, + read_schema, + csv_buffer, + &include_columns, + predicate, + sender, + rows_read, + ); } }); - has_header = false; - if total_bytes_read >= total_len { - break; + } + + /// Checks if the thread has reached its row limit. + fn thread_limit_reached(rows_read: &Arc, limit: Option) -> bool { + limit.map_or(false, |limit| { + let current_rows_read = rows_read.load(Ordering::Relaxed); + current_rows_read >= limit + }) + } + + /// Parses a chunk of the CSV file. + #[allow(clippy::too_many_arguments)] + fn parse_chunk( + has_header: bool, + parse_options: CsvParseOptions, + current_slab: Arc, + next_slab: Option>, + file_chunk: (Option, Option), + projection_indices: Arc>, + fields: Vec, + read_daft_fields: Arc>>, + read_schema: Arc, + csv_buffer: CsvBuffer, + include_columns: &Option>, + predicate: Option>, + sender: Sender>, + rows_read: Arc, + ) { + match file_chunk { + (Some(start), None) => { + if let Some(buffer) = ¤t_slab.buffer { + let buffer_source = BufferSource::Single(Cursor::new( + &buffer[start..], + )); + Self::dispatch_to_parse_csv( + has_header, + parse_options, + buffer_source, + projection_indices, + fields, + read_daft_fields, + read_schema, + csv_buffer, + include_columns, + predicate, + sender, + rows_read, + ); + } else { + panic!("Trying to read from a CSV buffer that doesn't exist. Please report this issue.") + } + } + (Some(start), Some(end)) => { + if let Some(next_slab) = next_slab + && let Some(current_buffer) = ¤t_slab.buffer + && let Some(next_buffer) = &next_slab.buffer + { + let buffer_source = BufferSource::Chain(Read::chain( + Cursor::new(¤t_buffer[start..]), + Cursor::new(&next_buffer[..end]), + )); + Self::dispatch_to_parse_csv( + has_header, + parse_options, + buffer_source, + projection_indices, + fields, + read_daft_fields, + read_schema, + csv_buffer, + include_columns, + predicate, + sender, + rows_read, + ); + } else { + panic!("Trying to read from an overflow CSV buffer that doesn't exist. Please report this issue.") + } + } + _ => panic!( + "Something went wrong when parsing the CSV file. Please report this issue." + ), + } + } + + /// Helper function that takes in a BufferSource, calls parse_csv() to extract table values from + /// the buffer source, then streams the results to `sender`. + #[allow(clippy::too_many_arguments)] + fn dispatch_to_parse_csv( + has_header: bool, + parse_options: CsvParseOptions, + buffer_source: BufferSource, + projection_indices: Arc>, + fields: Vec, + read_daft_fields: Arc>>, + read_schema: Arc, + csv_buffer: CsvBuffer, + include_columns: &Option>, + predicate: Option>, + sender: Sender>, + rows_read: Arc, + ) { + let table_results = { + let rdr = ReaderBuilder::new() + .has_headers(has_header) + .delimiter(parse_options.delimiter) + .double_quote(parse_options.double_quote) + .quote(parse_options.quote) + .escape(parse_options.escape_char) + .comment(parse_options.comment) + .flexible(parse_options.allow_variable_columns) + .from_reader(buffer_source); + parse_csv_chunk( + rdr, + projection_indices, + fields, + read_daft_fields, + read_schema, + csv_buffer, + include_columns, + predicate, + ) + }; + match table_results { + Ok(tables) => { + for table in tables { + let table_len = table.len(); + sender.send(Ok(table)).unwrap(); + // Atomically update the number of rows read only after the result has + // been sent. In theory we could wrap these steps in a mutex, but + // applying limit at this layer can be best-effort with no adverse + // side effects. + rows_read.fetch_add(table_len, Ordering::Relaxed); + } + } + Err(e) => sender.send(Err(e)).unwrap(), } } } /// Unsafe helper function that extracts the buffer from an &Option>. Users should /// ensure that the buffer is Some, otherwise this function causes the process to panic. -fn unsafe_clone_buffer(buffer: &Option>) -> Arc<[u8; SLABSIZE]> { +fn unsafe_clone_buffer(buffer: &Option>) -> Arc<[u8; SLAB_SIZE]> { match buffer { Some(buffer) => Arc::clone(buffer), None => panic!("Tried to clone a CSV slab that doesn't exist. Please report this error."), @@ -591,10 +779,8 @@ fn unsafe_clone_buffer(buffer: &Option>) -> Arc<[u8; SLABSIZ /// \n character and go back to 2. #[allow(clippy::too_many_arguments)] fn get_file_chunk( - current_buffer: Arc<[u8; SLABSIZE]>, - current_buffer_len: usize, - next_buffer: Option>, - next_buffer_len: usize, + current_buffer: &[u8], + next_buffer: Option<&[u8]>, first_buffer: bool, num_fields: usize, quote_char: u8, @@ -862,7 +1048,7 @@ fn parse_csv_chunk( predicate: Option>, ) -> DaftResult> where - R: std::io::Read, + R: Read, { let mut chunk_buffer = csv_buffer.buffer; let mut tables = vec![]; diff --git a/src/daft-csv/src/local/pool.rs b/src/daft-csv/src/local/pool.rs new file mode 100644 index 0000000000..0fe836fac9 --- /dev/null +++ b/src/daft-csv/src/local/pool.rs @@ -0,0 +1,247 @@ +use core::mem::ManuallyDrop; +use std::pin::pin; +use futures::Stream; +use std::sync::Arc; +use tokio::fs::File; +use tokio::io::{AsyncRead, AsyncReadExt}; +use tokio::sync::mpsc::{self, Receiver, Sender}; + +mod fixed_capacity_vec; + +type SlabData = Vec; + +/// A pool of reusable memory slabs for efficient I/O operations. +struct SlabPool { + available_slabs_sender: Sender, + available_slabs: Receiver, +} + +impl SlabPool { + /// Creates a new `SlabPool` with a specified number of slabs of a given size. + fn new(slab_count: usize, slab_size: usize) -> Self { + let (tx, rx) = mpsc::channel(slab_count); + for _ in 0..slab_count { + tx.try_send(Vec::with_capacity(slab_size)) + .expect("Failed to send slab to pool"); + } + Self { + available_slabs: rx, + available_slabs_sender: tx, + } + } + + /// Asynchronously retrieves the next available slab from the pool. + async fn get_next_data(&mut self) -> Slab { + let mut data = self + .available_slabs + .recv() + .await + .expect("Slab pool is empty"); + + data.clear(); + + Slab { + send_back_to_pool: self.available_slabs_sender.clone(), + data: ManuallyDrop::new(data), + } + } +} + +/// Represents a single memory slab that can be returned to the pool when dropped. +#[derive(Debug)] +pub struct Slab { + send_back_to_pool: Sender, + data: ManuallyDrop, +} + +impl std::ops::Deref for Slab { + type Target = SlabData; + + fn deref(&self) -> &Self::Target { + &self.data + } +} + +impl std::ops::DerefMut for Slab { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.data + } +} + +type SharedSlab = Arc; + +impl Drop for Slab { + fn drop(&mut self) { + let data = unsafe { ManuallyDrop::take(&mut self.data) }; + let _ = self.send_back_to_pool.try_send(data); + } +} + +use tokio_stream::wrappers::ReceiverStream; + +/// Asynchronously reads slabs from a file and returns a stream of SharedSlabs. +pub fn read_slabs( + mut file: R, + buffer_size: usize, + pool_size: usize, +) -> impl Stream { + let (tx, rx) = mpsc::channel::(pool_size); + let pool = SlabPool::new(pool_size, buffer_size); + tokio::spawn(async move { + let mut pool = pool; + loop { + let mut slab = pool.get_next_data().await; + + // note: might not be exactly 4MiB + let mut total_read = 0; + while total_read < slab.capacity() { + let result = file + .read_buf(&mut *slab) + .await + .expect("Failed to read from file"); + + if result == 0 { + // End of file reached + break; + } + + total_read += result; + + // If we're close to filling the buffer, stop reading + if slab.capacity() - total_read < 1024 { + break; + } + } + + if total_read == 0 { + // No data read, end of file + break; + } + + // Update the length of the slab with the actual number of bytes read + debug_assert_eq!(total_read, slab.len(), "Slab length should be equal to the number of bytes read"); + tx.send(slab).await.expect("Failed to send slab to stream"); + } + }); + + ReceiverStream::new(rx) +} + +use crate::local::pool::fixed_capacity_vec::FixedCapacityVec; +use futures::stream::StreamExt; + +pub type WindowedSlab = heapless::Vec; + +/// Asynchronously reads slabs from a file and returns a stream of WindowedSlabs. +/// +/// This function creates a windowed view of the slabs, where each `WindowedSlab` +/// contains two consecutive slabs. The windowing is done in an overlapping manner, +/// so the second slab of the previous window becomes the first slab of the next window. +/// +/// # Arguments +/// +/// * `file` - The file to read from. +/// * `buffer_size` - The size of each slab's buffer. +/// * `pool_size` - The size of the slab pool. +/// +/// # Returns +/// +/// A `Stream` of `WindowedSlab`s. +pub fn read_slabs_windowed( + file: R, + buffer_size: usize, + pool_size: usize, +) -> impl Stream { + let mut slab_stream = read_slabs(file, buffer_size, pool_size); + + use tokio_stream::StreamExt; + use tokio::sync::mpsc; + + let (tx, rx) = mpsc::channel(pool_size); + + tokio::spawn(async move { + let mut slab_stream = pin!(slab_stream); + + let mut windowed_slab = heapless::Vec::::new(); + + let mut slab_stream = slab_stream.as_mut(); + while let Some(slab) = StreamExt::next(&mut slab_stream).await { + let slab = SharedSlab::from(slab); + + windowed_slab.push(slab).unwrap(); + + if windowed_slab.len() == 2 { + tx.send(windowed_slab.clone()).await.unwrap(); + windowed_slab.remove(0); + } + } + + // Send the last windowed slab + if !windowed_slab.is_empty() { + tx.send(windowed_slab).await.unwrap(); + } + }); + + ReceiverStream::new(rx) +} + + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Cursor; + use futures::StreamExt; + + #[tokio::test] + async fn test_read_slabs() { + let data = b"Hello, World!".repeat(1000); + let data_len = data.len(); + let cursor = Cursor::new(data); + let buffer_size = 100; + let pool_size = 5; + + let mut stream = read_slabs(cursor, buffer_size, pool_size).await; + let mut total_bytes = 0; + + while let Some(slab) = stream.next().await { + assert!(slab.len() <= buffer_size); + total_bytes += slab.len(); + } + + assert_eq!(total_bytes, data_len); + } + + #[tokio::test] + async fn test_read_slabs_windowed() { + let data = b"Hello, World!".repeat(1000); + let data_len = data.len(); + let cursor = Cursor::new(data); + let buffer_size = 100; + let pool_size = 5; + + let mut stream = read_slabs_windowed(cursor, buffer_size, pool_size).await; + let mut total_bytes = 0; + let mut previous_slab: Option = None; + + let left_total = 0; + let right_total = 0; + + while let Some(windowed_slab) = stream.next().await { + assert_eq!(windowed_slab.len(), 2); + + if let Some(prev) = previous_slab { + assert!(Arc::ptr_eq(&prev, &windowed_slab[0])); + } + + left_total += windowed_slab[0].len(); + right_total += windowed_slab[1].len(); + total_bytes += windowed_slab[1].len(); + previous_slab = Some(windowed_slab[1].clone()); + } + + assert_eq!(total_bytes, data_len); + assert_eq!(left_total, right_total); + assert_eq!(left_total, data_len); + } +} + diff --git a/src/daft-csv/src/local/pool/fixed_capacity_vec.rs b/src/daft-csv/src/local/pool/fixed_capacity_vec.rs new file mode 100644 index 0000000000..7972ffd336 --- /dev/null +++ b/src/daft-csv/src/local/pool/fixed_capacity_vec.rs @@ -0,0 +1,114 @@ +/// A vector with a fixed capacity, optimized for performance. +pub struct FixedCapacityVec { + data: Box<[u8]>, + len: usize, +} + +impl FixedCapacityVec { + /// Creates a new `FixedCapacityVec` with the specified capacity. + #[inline] + pub fn new(capacity: usize) -> Self { + Self { + data: vec![0; capacity].into_boxed_slice(), + len: 0, + } + } + + /// Returns the current length of the vector. + #[inline] + pub fn len(&self) -> usize { + self.len + } + + /// Returns true if the vector is empty. + #[inline] + pub fn is_empty(&self) -> bool { + self.len == 0 + } + + /// Returns the capacity of the vector. + #[inline] + pub fn capacity(&self) -> usize { + self.data.len() + } + + /// Pushes an element onto the end of the vector. + /// + /// # Panics + /// + /// Panics if the vector is already at capacity. + #[inline] + pub fn push(&mut self, value: u8) { + assert!(self.len < self.capacity(), "FixedCapacityVec is at capacity"); + self.data[self.len] = value; + self.len += 1; + } + + /// Removes and returns the last element of the vector. + /// + /// # Panics + /// + /// Panics if the vector is empty. + #[inline] + pub fn pop(&mut self) -> u8 { + assert!(!self.is_empty(), "FixedCapacityVec is empty"); + self.len -= 1; + self.data[self.len] + } + + /// Returns a reference to the element at the given index. + /// + /// # Panics + /// + /// Panics if the index is out of bounds. + #[inline] + pub fn get(&self, index: usize) -> &u8 { + assert!(index < self.len, "Index out of bounds"); + &self.data[index] + } + + /// Returns a mutable reference to the element at the given index. + /// + /// # Panics + /// + /// Panics if the index is out of bounds. + #[inline] + pub fn get_mut(&mut self, index: usize) -> &mut u8 { + assert!(index < self.len, "Index out of bounds"); + &mut self.data[index] + } + + /// Clears the vector, removing all elements. + #[inline] + pub fn clear(&mut self) { + self.len = 0; + } + + /// Returns a slice containing the entire capacity of the vector. + /// + /// This includes both initialized and uninitialized elements. + /// + /// # Safety + /// + /// This function is unsafe because it returns a slice that may contain + /// uninitialized memory. The caller must ensure that they only access + /// initialized elements (up to `self.len()`). + #[inline] + pub fn capacity_slice(&self) -> &[u8] { + &self.data + } + + /// Returns a mutable slice containing the entire capacity of the vector. + /// + /// This includes both initialized and uninitialized elements. + /// + /// # Safety + /// + /// This function is unsafe because it returns a slice that may contain + /// uninitialized memory. The caller must ensure that they only access + /// initialized elements (up to `self.len()`). + #[inline] + pub fn capacity_slice_mut(&mut self) -> &mut [u8] { + &mut self.data + } +} diff --git a/src/daft-csv/src/read.rs b/src/daft-csv/src/read.rs index 1bd8e5b994..89a3c87fa3 100644 --- a/src/daft-csv/src/read.rs +++ b/src/daft-csv/src/read.rs @@ -846,7 +846,7 @@ mod tests { use crate::read::read_csv_local; use daft_io::get_runtime; - #[test] + // #[test] todo: re-enable fn test_csv_read_experimental() -> DaftResult<()> { let file = "file:///Users/desmond/tasks/csv-reader/G1_1e8_1e1_5_0.csv"; From 9c86a06179e004d1b6cebb5b8e27c948fa48dafc Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Fri, 4 Oct 2024 15:27:41 -0700 Subject: [PATCH 06/10] stash again --- src/daft-csv/src/local.rs | 4 ++-- src/daft-csv/src/local/pool.rs | 23 ++++++++++------------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/src/daft-csv/src/local.rs b/src/daft-csv/src/local.rs index 5b5da9514f..36b0a0afd5 100644 --- a/src/daft-csv/src/local.rs +++ b/src/daft-csv/src/local.rs @@ -1036,11 +1036,11 @@ fn dispatch_to_parse_csv( } /// Helper function that consumes a CSV reader and turns it into a vector of Daft tables. -#[allow(clippy::too_many_arguments)] +#[expect(clippy::too_many_arguments)] fn parse_csv_chunk( mut reader: Reader, projection_indices: Arc>, - fields: Vec, + fields: Vec, read_daft_fields: Arc>>, read_schema: Arc, csv_buffer: CsvBuffer, diff --git a/src/daft-csv/src/local/pool.rs b/src/daft-csv/src/local/pool.rs index 0fe836fac9..d2ecb66b1f 100644 --- a/src/daft-csv/src/local/pool.rs +++ b/src/daft-csv/src/local/pool.rs @@ -1,8 +1,7 @@ use core::mem::ManuallyDrop; -use std::pin::pin; use futures::Stream; +use std::pin::pin; use std::sync::Arc; -use tokio::fs::File; use tokio::io::{AsyncRead, AsyncReadExt}; use tokio::sync::mpsc::{self, Receiver, Sender}; @@ -84,7 +83,7 @@ pub fn read_slabs( mut file: R, buffer_size: usize, pool_size: usize, -) -> impl Stream { +) -> impl Stream { let (tx, rx) = mpsc::channel::(pool_size); let pool = SlabPool::new(pool_size, buffer_size); tokio::spawn(async move { @@ -119,7 +118,11 @@ pub fn read_slabs( } // Update the length of the slab with the actual number of bytes read - debug_assert_eq!(total_read, slab.len(), "Slab length should be equal to the number of bytes read"); + debug_assert_eq!( + total_read, + slab.len(), + "Slab length should be equal to the number of bytes read" + ); tx.send(slab).await.expect("Failed to send slab to stream"); } }); @@ -127,9 +130,6 @@ pub fn read_slabs( ReceiverStream::new(rx) } -use crate::local::pool::fixed_capacity_vec::FixedCapacityVec; -use futures::stream::StreamExt; - pub type WindowedSlab = heapless::Vec; /// Asynchronously reads slabs from a file and returns a stream of WindowedSlabs. @@ -151,11 +151,11 @@ pub fn read_slabs_windowed( file: R, buffer_size: usize, pool_size: usize, -) -> impl Stream { - let mut slab_stream = read_slabs(file, buffer_size, pool_size); +) -> impl Stream { + let slab_stream = read_slabs(file, buffer_size, pool_size); - use tokio_stream::StreamExt; use tokio::sync::mpsc; + use tokio_stream::StreamExt; let (tx, rx) = mpsc::channel(pool_size); @@ -185,12 +185,10 @@ pub fn read_slabs_windowed( ReceiverStream::new(rx) } - #[cfg(test)] mod tests { use super::*; use std::io::Cursor; - use futures::StreamExt; #[tokio::test] async fn test_read_slabs() { @@ -244,4 +242,3 @@ mod tests { assert_eq!(left_total, data_len); } } - From 70b37d8e0a9c5cdfda180d1494aeedee273565a6 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Mon, 7 Oct 2024 16:12:00 -0700 Subject: [PATCH 07/10] SlabPool is now generic, reverted `local.rs` --- src/daft-csv/src/local.rs | 636 ++++++++++++--------------------- src/daft-csv/src/local/pool.rs | 153 ++++---- 2 files changed, 314 insertions(+), 475 deletions(-) diff --git a/src/daft-csv/src/local.rs b/src/daft-csv/src/local.rs index 36b0a0afd5..1ce58551fb 100644 --- a/src/daft-csv/src/local.rs +++ b/src/daft-csv/src/local.rs @@ -23,11 +23,10 @@ use rayon::{ prelude::{IntoParallelRefIterator, ParallelIterator}, }; use snafu::ResultExt; -use crate::local::pool::{read_slabs_windowed, WindowedSlab}; -use crate::read::{fields_to_projection_indices, tables_concat}; mod pool; +use crate::read::{fields_to_projection_indices, tables_concat}; #[allow(clippy::doc_lazy_continuation)] /// Our local CSV reader has the following approach to reading CSV files: @@ -121,10 +120,61 @@ impl CsvBufferPool { } } -// The default size of a slab used for reading CSV files in chunks. Currently set to 4MiB. -const SLAB_SIZE: usize = 4 * 1024 * 1024; +// The default size of a slab used for reading CSV files in chunks. Currently set to 4MB. +const SLABSIZE: usize = 4 * 1024 * 1024; // The default number of slabs in a slab pool. -const SLAB_POOL_DEFAULT_SIZE: usize = 20; +const SLABPOOL_DEFAULT_SIZE: usize = 20; + +/// A pool of 4MB slabs. Used for reading CSV files in 4MB chunks. +#[derive(Debug)] +struct SlabPool { + buffers: Mutex>>, + condvar: Condvar, +} + +/// A 4MB slab of bytes. Used for reading CSV files in 4MB chunks. +#[derive(Clone)] +struct Slab { + pool: Arc, + // We wrap the Arc in an Option so that when a Slab is being dropped, we can move the Slab's reference + // to the Arc back to the slab pool. + buffer: Option>, +} + +impl Drop for Slab { + fn drop(&mut self) { + // Move the buffer back to the slab pool. + if let Some(buffer) = self.buffer.take() { + self.pool.return_buffer(buffer); + } + } +} + +impl SlabPool { + pub fn new() -> Self { + let chunk_buffers: Vec> = (0..SLABPOOL_DEFAULT_SIZE) + .map(|_| Arc::new([0; SLABSIZE])) + .collect(); + SlabPool { + buffers: Mutex::new(chunk_buffers), + condvar: Condvar::new(), + } + } + + pub fn get_buffer(self: &Arc) -> Arc<[u8; SLABSIZE]> { + let mut buffers = self.buffers.lock().unwrap(); + while buffers.is_empty() { + buffers = self.condvar.wait(buffers).unwrap(); + } + buffers.pop().unwrap() + } + + fn return_buffer(&self, buffer: Arc<[u8; SLABSIZE]>) { + let mut buffers = self.buffers.lock().unwrap(); + buffers.push(buffer); + self.condvar.notify_one(); + } +} /// A data structure that holds either a single slice of bytes, or a chain of two slices of bytes. /// See the description to `parse_json` for more details. @@ -157,7 +207,7 @@ pub async fn read_csv_local( parse_options, read_options, max_chunks_in_flight, - ).await?; + )?; tables_concat(tables_stream_collect(Box::pin(stream)).await) } @@ -173,15 +223,15 @@ async fn tables_stream_collect(stream: BoxStream<'static, DaftResult
>) -> .await } -pub async fn stream_csv_local( +pub fn stream_csv_local( uri: &str, convert_options: Option, parse_options: CsvParseOptions, read_options: Option, max_chunks_in_flight: Option, -) -> DaftResult> + Send> { +) -> DaftResult> + Send> { let uri = uri.trim_start_matches("file://"); - let file = tokio::fs::File::open(uri).await?; + let file = std::fs::File::open(uri)?; // TODO(desmond): This logic is repeated multiple times in the csv reader files. Should dedup. let predicate = convert_options @@ -211,7 +261,7 @@ pub async fn stream_csv_local( Some(co) } } - .unwrap_or_default(); + .unwrap_or_default(); // End of `should dedup`. // TODO(desmond): We should do better schema inference here. @@ -257,8 +307,8 @@ pub async fn stream_csv_local( // We suppose that each slab of CSV data produces (chunk size / slab size) number of Daft tables. We // then double this capacity to ensure that our channel is never full and our threads won't deadlock. let (sender, receiver) = - crossbeam_channel::bounded(max_chunks_in_flight.unwrap_or(2 * chunk_size / SLAB_SIZE)); - tokio::spawn(async move || { + crossbeam_channel::bounded(max_chunks_in_flight.unwrap_or(2 * chunk_size / SLABSIZE)); + rayon::spawn(move || { consume_csv_file( file, buffer_pool, @@ -282,8 +332,9 @@ pub async fn stream_csv_local( /// Consumes the CSV file and sends the results to `sender`. #[allow(clippy::too_many_arguments)] fn consume_csv_file( - file: tokio::fs::File, + mut file: std::fs::File, buffer_pool: Arc, + slabpool: Arc, parse_options: CsvParseOptions, projection_indices: Arc>, read_daft_fields: Arc>>, @@ -295,426 +346,189 @@ fn consume_csv_file( limit: Option, sender: Sender>, ) { - let mut csv_consumer = CsvConsumer::new( - file, - buffer_pool, - slabpool, - parse_options, - projection_indices, - read_daft_fields, - read_schema, - fields, - num_fields, - include_columns.clone(), - predicate, - limit, - sender, - ); - - csv_consumer.consume(); -} - -/// A struct representing a CSV consumer that processes and parses CSV data. -struct CsvConsumer { - /// The file being read. - file: tokio::fs::File, - /// Options for parsing the CSV file. - parse_options: CsvParseOptions, - /// Indices of columns to be projected (included in the output). - projection_indices: Arc>, - /// Fields to be read from the CSV, in Daft format. - read_daft_fields: Arc>>, - /// Schema of the data to be read. - read_schema: Arc, - /// Fields of the CSV file. - fields: Vec, - /// Total number of fields in the CSV. - num_fields: usize, - /// Optional list of columns to include in the output. - include_columns: Option>, - /// Optional predicate for filtering rows. - predicate: Option>, - /// Optional limit on the number of rows to read. - limit: Option, - /// Channel sender for sending parsed tables or errors. - sender: Sender>, - /// Atomic counter for the number of rows read. - rows_read: Arc, - /// Flag indicating whether the CSV has a header row. - has_header: bool, - /// Total length of the file in bytes. - total_len: usize, - /// Total number of bytes read so far. - total_bytes_read: usize, - /// Length of the next buffer to be processed. - next_buffer_len: usize, - /// Flag indicating if this is the first buffer being processed. - first_buffer: bool, -} - -impl CsvConsumer { - /// Creates a new CsvConsumer instance. - fn new( - file: tokio::fs::File, - buffer_pool: Arc, - slabpool: Arc, - parse_options: CsvParseOptions, - projection_indices: Arc>, - read_daft_fields: Arc>>, - read_schema: Arc, - fields: Vec, - num_fields: usize, - include_columns: Option>, - predicate: Option>, - limit: Option, - sender: Sender>, - ) -> Self { - let total_len = file.metadata().unwrap().len() as usize; - let has_header = parse_options.has_header; - Self { - file, - buffer_pool, - slabpool, - parse_options, - projection_indices, - read_daft_fields, - read_schema, - fields, - num_fields, - include_columns, - predicate, - limit, - sender, - rows_read: Arc::new(AtomicUsize::new(0)), - has_header, - total_len, - total_bytes_read: 0, - next_slab: None, - next_buffer_len: 0, - first_buffer: true, - } - } - - /// Main method to consume and process the CSV file. - async fn consume(mut self, file: tokio::fs::File) { - let mut pool = read_slabs_windowed(file, SLAB_SIZE, SLAB_POOL_DEFAULT_SIZE); - - loop { - if self.limit_reached() { - break; - } - - let Some(windowed_slab) = pool.next().await else { - break; - }; - - if !self.process_slab(windowed_slab) { - break; - } - self.has_header = false; - if self.total_bytes_read >= self.total_len { - break; - } - } - } - - /// Checks if the row limit has been reached. - fn limit_reached(&self) -> bool { - self.limit.map_or(false, |limit| { - let current_rows_read = self.rows_read.load(Ordering::Relaxed); + let rows_read = Arc::new(AtomicUsize::new(0)); + let mut has_header = parse_options.has_header; + let total_len = file.metadata().unwrap().len() as usize; + let field_delimiter = parse_options.delimiter; + let escape_char = parse_options.escape_char; + let quote_char = parse_options.quote; + let double_quote_escape_allowed = parse_options.double_quote; + let mut total_bytes_read = 0; + let mut next_slab = None; + let mut next_buffer_len = 0; + let mut first_buffer = true; + loop { + let limit_reached = limit.map_or(false, |limit| { + let current_rows_read = rows_read.load(Ordering::Relaxed); current_rows_read >= limit - }) - } - - /// Processes a single slab of data. - fn process_slab(&mut self, windowed_slab: WindowedSlab) -> bool { - let file_chunk = self.get_file_chunk(¤t_slab, current_buffer_len); - self.first_buffer = false; - - if let (None, _) = file_chunk { - self.slabpool.return_buffer(unsafe_clone_buffer(¤t_slab.buffer)); - return false; + }); + if limit_reached { + break; } - - self.spawn_parse_thread(current_slab, file_chunk); - true - } - - /// Retrieves the current slab to be processed. - fn get_current_slab(&mut self) -> (Arc, usize) { - match self.next_slab.take() { + let (current_slab, current_buffer_len) = match next_slab.take() { Some(next_slab) => { - self.total_bytes_read += self.next_buffer_len; - (next_slab, self.next_buffer_len) - } - None => self.read_new_slab(), - } - } - - /// Reads a new slab from the file. - fn read_new_slab(&mut self) -> (Arc, usize) { - let mut buffer = self.slabpool.get_buffer(); - match Arc::get_mut(&mut buffer) { - Some(inner_buffer) => { - let bytes_read = self.file.read(inner_buffer).unwrap(); - if bytes_read == 0 { - self.slabpool.return_buffer(buffer); - return (Arc::new(Slab { - pool: Arc::clone(&self.slabpool), - buffer: None, - }), 0); - } - self.total_bytes_read += bytes_read; - ( - Arc::new(Slab { - pool: Arc::clone(&self.slabpool), - buffer: Some(buffer), - }), - bytes_read, - ) + total_bytes_read += next_buffer_len; + (next_slab, next_buffer_len) } None => { - self.slabpool.return_buffer(buffer); - (Arc::new(Slab { - pool: Arc::clone(&self.slabpool), - buffer: None, - }), 0) + let mut buffer = slabpool.get_buffer(); + match Arc::get_mut(&mut buffer) { + Some(inner_buffer) => { + let bytes_read = file.read(inner_buffer).unwrap(); + if bytes_read == 0 { + slabpool.return_buffer(buffer); + break; + } + total_bytes_read += bytes_read; + ( + Arc::new(Slab { + pool: Arc::clone(&slabpool), + buffer: Some(buffer), + }), + bytes_read, + ) + } + None => { + slabpool.return_buffer(buffer); + break; + } + } } - } - } - - /// Prepares the next slab for processing. - fn prepare_next_slab(&mut self) { - if self.total_bytes_read < self.total_len { - let mut next_buffer = self.slabpool.get_buffer(); + }; + (next_slab, next_buffer_len) = if total_bytes_read < total_len { + let mut next_buffer = slabpool.get_buffer(); match Arc::get_mut(&mut next_buffer) { Some(inner_buffer) => { - let bytes_read = self.file.read(inner_buffer).unwrap(); + let bytes_read = file.read(inner_buffer).unwrap(); if bytes_read == 0 { - self.slabpool.return_buffer(next_buffer); - self.next_slab = None; - self.next_buffer_len = 0; + slabpool.return_buffer(next_buffer); + (None, 0) } else { - self.next_slab = Some(Arc::new(Slab { - pool: Arc::clone(&self.slabpool), - buffer: Some(next_buffer), - })); - self.next_buffer_len = bytes_read; + ( + Some(Arc::new(Slab { + pool: Arc::clone(&slabpool), + buffer: Some(next_buffer), + })), + bytes_read, + ) } } None => { - self.slabpool.return_buffer(next_buffer); - self.next_slab = None; - self.next_buffer_len = 0; + slabpool.return_buffer(next_buffer); + break; } } } else { - self.next_slab = None; - self.next_buffer_len = 0; - } - } - - /// Retrieves the chunk of the file to be processed. - fn get_file_chunk(&self, current_slab: &WindowedSlab, first_buffer: bool) -> (Option, Option) { - let (left, right) = match current_slab.as_slice() { - [left, right] => (left.as_slice(), Some(right.as_slice())), - [left] => (left.as_slice(), None), - _ => unreachable!("Unexpected windowed slab size"), + (None, 0) }; - - - get_file_chunk( - left, - right, + let file_chunk = get_file_chunk( + unsafe_clone_buffer(¤t_slab.buffer), + current_buffer_len, + next_slab + .as_ref() + .map(|slab| unsafe_clone_buffer(&slab.buffer)), + next_buffer_len, first_buffer, - self.num_fields, - self.parse_options.quote, - self.parse_options.delimiter, - self.parse_options.escape_char, - self.parse_options.double_quote, - ) - } - - /// Spawns a new thread to parse the CSV chunk. - fn spawn_parse_thread(&self, current_slab: Arc, file_chunk: (Option, Option)) { + num_fields, + quote_char, + field_delimiter, + escape_char, + double_quote_escape_allowed, + ); + first_buffer = false; + if let (None, _) = file_chunk { + // Return the buffer. It doesn't matter that we still have a reference to the slab. We're going to fallback + // and the slabs will be useless. + slabpool.return_buffer(unsafe_clone_buffer(¤t_slab.buffer)); + // Exit early before spawning a new thread. + break; + // TODO(desmond): we should fallback instead. + } let current_slab_clone = Arc::clone(¤t_slab); - let next_slab_clone = self.next_slab.clone(); - let parse_options = self.parse_options.clone(); - let csv_buffer = self.buffer_pool.get_buffer(); - let projection_indices = self.projection_indices.clone(); - let fields = self.fields.clone(); - let read_daft_fields = self.read_daft_fields.clone(); - let read_schema = self.read_schema.clone(); - let include_columns = self.include_columns.clone(); - let predicate = self.predicate.clone(); - let sender = self.sender.clone(); - let rows_read = Arc::clone(&self.rows_read); - let limit = self.limit; - let has_header = self.has_header; - + let next_slab_clone = next_slab.clone(); + let parse_options = parse_options.clone(); + let csv_buffer = buffer_pool.get_buffer(); + let projection_indices = projection_indices.clone(); + let fields = fields.clone(); + let read_daft_fields = read_daft_fields.clone(); + let read_schema = read_schema.clone(); + let include_columns = include_columns.clone(); + let predicate = predicate.clone(); + let sender = sender.clone(); + let rows_read = Arc::clone(&rows_read); rayon::spawn(move || { - if !Self::thread_limit_reached(&rows_read, limit) { - Self::parse_chunk( - has_header, - parse_options, - current_slab_clone, - next_slab_clone, - file_chunk, - projection_indices, - fields, - read_daft_fields, - read_schema, - csv_buffer, - &include_columns, - predicate, - sender, - rows_read, - ); + let limit_reached = limit.map_or(false, |limit| { + let current_rows_read = rows_read.load(Ordering::Relaxed); + current_rows_read >= limit + }); + if !limit_reached { + match file_chunk { + (Some(start), None) => { + if let Some(buffer) = ¤t_slab_clone.buffer { + let buffer_source = BufferSource::Single(Cursor::new( + &buffer[start..current_buffer_len], + )); + dispatch_to_parse_csv( + has_header, + parse_options, + buffer_source, + projection_indices, + fields, + read_daft_fields, + read_schema, + csv_buffer, + &include_columns, + predicate, + sender, + rows_read, + ); + } else { + panic!("Trying to read from a CSV buffer that doesn't exist. Please report this issue.") + } + } + (Some(start), Some(end)) => { + if let Some(next_slab_clone) = next_slab_clone + && let Some(current_buffer) = ¤t_slab_clone.buffer + && let Some(next_buffer) = &next_slab_clone.buffer + { + let buffer_source = BufferSource::Chain(std::io::Read::chain( + Cursor::new(¤t_buffer[start..current_buffer_len]), + Cursor::new(&next_buffer[..end]), + )); + dispatch_to_parse_csv( + has_header, + parse_options, + buffer_source, + projection_indices, + fields, + read_daft_fields, + read_schema, + csv_buffer, + &include_columns, + predicate, + sender, + rows_read, + ); + } else { + panic!("Trying to read from an overflow CSV buffer that doesn't exist. Please report this issue.") + } + } + _ => panic!( + "Something went wrong when parsing the CSV file. Please report this issue." + ), + }; } }); - } - - /// Checks if the thread has reached its row limit. - fn thread_limit_reached(rows_read: &Arc, limit: Option) -> bool { - limit.map_or(false, |limit| { - let current_rows_read = rows_read.load(Ordering::Relaxed); - current_rows_read >= limit - }) - } - - /// Parses a chunk of the CSV file. - #[allow(clippy::too_many_arguments)] - fn parse_chunk( - has_header: bool, - parse_options: CsvParseOptions, - current_slab: Arc, - next_slab: Option>, - file_chunk: (Option, Option), - projection_indices: Arc>, - fields: Vec, - read_daft_fields: Arc>>, - read_schema: Arc, - csv_buffer: CsvBuffer, - include_columns: &Option>, - predicate: Option>, - sender: Sender>, - rows_read: Arc, - ) { - match file_chunk { - (Some(start), None) => { - if let Some(buffer) = ¤t_slab.buffer { - let buffer_source = BufferSource::Single(Cursor::new( - &buffer[start..], - )); - Self::dispatch_to_parse_csv( - has_header, - parse_options, - buffer_source, - projection_indices, - fields, - read_daft_fields, - read_schema, - csv_buffer, - include_columns, - predicate, - sender, - rows_read, - ); - } else { - panic!("Trying to read from a CSV buffer that doesn't exist. Please report this issue.") - } - } - (Some(start), Some(end)) => { - if let Some(next_slab) = next_slab - && let Some(current_buffer) = ¤t_slab.buffer - && let Some(next_buffer) = &next_slab.buffer - { - let buffer_source = BufferSource::Chain(Read::chain( - Cursor::new(¤t_buffer[start..]), - Cursor::new(&next_buffer[..end]), - )); - Self::dispatch_to_parse_csv( - has_header, - parse_options, - buffer_source, - projection_indices, - fields, - read_daft_fields, - read_schema, - csv_buffer, - include_columns, - predicate, - sender, - rows_read, - ); - } else { - panic!("Trying to read from an overflow CSV buffer that doesn't exist. Please report this issue.") - } - } - _ => panic!( - "Something went wrong when parsing the CSV file. Please report this issue." - ), - } - } - - /// Helper function that takes in a BufferSource, calls parse_csv() to extract table values from - /// the buffer source, then streams the results to `sender`. - #[allow(clippy::too_many_arguments)] - fn dispatch_to_parse_csv( - has_header: bool, - parse_options: CsvParseOptions, - buffer_source: BufferSource, - projection_indices: Arc>, - fields: Vec, - read_daft_fields: Arc>>, - read_schema: Arc, - csv_buffer: CsvBuffer, - include_columns: &Option>, - predicate: Option>, - sender: Sender>, - rows_read: Arc, - ) { - let table_results = { - let rdr = ReaderBuilder::new() - .has_headers(has_header) - .delimiter(parse_options.delimiter) - .double_quote(parse_options.double_quote) - .quote(parse_options.quote) - .escape(parse_options.escape_char) - .comment(parse_options.comment) - .flexible(parse_options.allow_variable_columns) - .from_reader(buffer_source); - parse_csv_chunk( - rdr, - projection_indices, - fields, - read_daft_fields, - read_schema, - csv_buffer, - include_columns, - predicate, - ) - }; - match table_results { - Ok(tables) => { - for table in tables { - let table_len = table.len(); - sender.send(Ok(table)).unwrap(); - // Atomically update the number of rows read only after the result has - // been sent. In theory we could wrap these steps in a mutex, but - // applying limit at this layer can be best-effort with no adverse - // side effects. - rows_read.fetch_add(table_len, Ordering::Relaxed); - } - } - Err(e) => sender.send(Err(e)).unwrap(), + has_header = false; + if total_bytes_read >= total_len { + break; } } } /// Unsafe helper function that extracts the buffer from an &Option>. Users should /// ensure that the buffer is Some, otherwise this function causes the process to panic. -fn unsafe_clone_buffer(buffer: &Option>) -> Arc<[u8; SLAB_SIZE]> { +fn unsafe_clone_buffer(buffer: &Option>) -> Arc<[u8; SLABSIZE]> { match buffer { Some(buffer) => Arc::clone(buffer), None => panic!("Tried to clone a CSV slab that doesn't exist. Please report this error."), @@ -779,8 +593,10 @@ fn unsafe_clone_buffer(buffer: &Option>) -> Arc<[u8; SLAB_S /// \n character and go back to 2. #[allow(clippy::too_many_arguments)] fn get_file_chunk( - current_buffer: &[u8], - next_buffer: Option<&[u8]>, + current_buffer: Arc<[u8; SLABSIZE]>, + current_buffer_len: usize, + next_buffer: Option>, + next_buffer_len: usize, first_buffer: bool, num_fields: usize, quote_char: u8, @@ -1036,11 +852,11 @@ fn dispatch_to_parse_csv( } /// Helper function that consumes a CSV reader and turns it into a vector of Daft tables. -#[expect(clippy::too_many_arguments)] +#[allow(clippy::too_many_arguments)] fn parse_csv_chunk( mut reader: Reader, projection_indices: Arc>, - fields: Vec, + fields: Vec, read_daft_fields: Arc>>, read_schema: Arc, csv_buffer: CsvBuffer, @@ -1048,7 +864,7 @@ fn parse_csv_chunk( predicate: Option>, ) -> DaftResult> where - R: Read, + R: std::io::Read, { let mut chunk_buffer = csv_buffer.buffer; let mut tables = vec![]; diff --git a/src/daft-csv/src/local/pool.rs b/src/daft-csv/src/local/pool.rs index d2ecb66b1f..195ff9708d 100644 --- a/src/daft-csv/src/local/pool.rs +++ b/src/daft-csv/src/local/pool.rs @@ -7,30 +7,48 @@ use tokio::sync::mpsc::{self, Receiver, Sender}; mod fixed_capacity_vec; -type SlabData = Vec; +type FileSlab = Vec; /// A pool of reusable memory slabs for efficient I/O operations. -struct SlabPool { - available_slabs_sender: Sender, - available_slabs: Receiver, +struct SlabPool { + available_slabs_sender: Sender, + available_slabs: Receiver, } -impl SlabPool { +trait Clearable { + fn clear(&mut self); +} + +impl Clearable for Vec { + fn clear(&mut self) { + self.clear(); + } +} + +impl SlabPool { /// Creates a new `SlabPool` with a specified number of slabs of a given size. - fn new(slab_count: usize, slab_size: usize) -> Self { + fn new(iterator: impl ExactSizeIterator) -> Self { + let slab_count = iterator.len(); + let (tx, rx) = mpsc::channel(slab_count); - for _ in 0..slab_count { - tx.try_send(Vec::with_capacity(slab_size)) + + for slab in iterator { + tx.try_send(slab) .expect("Failed to send slab to pool"); + + // todo: maybe assert that slab_count is correct or use TrustedLen } + Self { available_slabs: rx, available_slabs_sender: tx, } } +} +impl SlabPool { /// Asynchronously retrieves the next available slab from the pool. - async fn get_next_data(&mut self) -> Slab { + async fn get_next_data(&mut self) -> Slab { let mut data = self .available_slabs .recv() @@ -48,28 +66,28 @@ impl SlabPool { /// Represents a single memory slab that can be returned to the pool when dropped. #[derive(Debug)] -pub struct Slab { - send_back_to_pool: Sender, - data: ManuallyDrop, +pub struct Slab { + send_back_to_pool: Sender, + data: ManuallyDrop, } -impl std::ops::Deref for Slab { - type Target = SlabData; +impl std::ops::Deref for Slab { + type Target = T; fn deref(&self) -> &Self::Target { &self.data } } -impl std::ops::DerefMut for Slab { +impl std::ops::DerefMut for Slab { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.data } } -type SharedSlab = Arc; +type SharedSlab = Arc>; -impl Drop for Slab { +impl Drop for Slab { fn drop(&mut self) { let data = unsafe { ManuallyDrop::take(&mut self.data) }; let _ = self.send_back_to_pool.try_send(data); @@ -79,13 +97,15 @@ impl Drop for Slab { use tokio_stream::wrappers::ReceiverStream; /// Asynchronously reads slabs from a file and returns a stream of SharedSlabs. -pub fn read_slabs( +pub fn read_slabs( mut file: R, - buffer_size: usize, - pool_size: usize, -) -> impl Stream { - let (tx, rx) = mpsc::channel::(pool_size); - let pool = SlabPool::new(pool_size, buffer_size); + iterator: impl ExactSizeIterator, +) -> impl Stream> where + R: AsyncRead + Unpin + Send + 'static, +{ + let (tx, rx) = mpsc::channel::>(iterator.len()); + + let pool = SlabPool::new(iterator); tokio::spawn(async move { let mut pool = pool; loop { @@ -130,7 +150,7 @@ pub fn read_slabs( ReceiverStream::new(rx) } -pub type WindowedSlab = heapless::Vec; +pub type WindowedSlab = heapless::Vec, 2>; /// Asynchronously reads slabs from a file and returns a stream of WindowedSlabs. /// @@ -147,22 +167,22 @@ pub type WindowedSlab = heapless::Vec; /// # Returns /// /// A `Stream` of `WindowedSlab`s. -pub fn read_slabs_windowed( +pub fn read_slabs_windowed( file: R, - buffer_size: usize, - pool_size: usize, -) -> impl Stream { - let slab_stream = read_slabs(file, buffer_size, pool_size); + iterator: impl ExactSizeIterator + 'static, +) -> impl Stream where + R: AsyncRead + Unpin + Send + 'static, +{ + let (tx, rx) = mpsc::channel(iterator.len()); + let slab_stream = read_slabs(file, iterator); use tokio::sync::mpsc; use tokio_stream::StreamExt; - let (tx, rx) = mpsc::channel(pool_size); - tokio::spawn(async move { let mut slab_stream = pin!(slab_stream); - let mut windowed_slab = heapless::Vec::::new(); + let mut windowed_slab = heapless::Vec::, 2>::new(); let mut slab_stream = slab_stream.as_mut(); while let Some(slab) = StreamExt::next(&mut slab_stream).await { @@ -189,6 +209,7 @@ pub fn read_slabs_windowed( mod tests { use super::*; use std::io::Cursor; + use tokio_stream::StreamExt; #[tokio::test] async fn test_read_slabs() { @@ -198,7 +219,8 @@ mod tests { let buffer_size = 100; let pool_size = 5; - let mut stream = read_slabs(cursor, buffer_size, pool_size).await; + let slabs = (0..pool_size).map(|_| vec![0; buffer_size]).collect::>(); + let mut stream = read_slabs(cursor, slabs.into_iter()); let mut total_bytes = 0; while let Some(slab) = stream.next().await { @@ -209,36 +231,37 @@ mod tests { assert_eq!(total_bytes, data_len); } - #[tokio::test] - async fn test_read_slabs_windowed() { - let data = b"Hello, World!".repeat(1000); - let data_len = data.len(); - let cursor = Cursor::new(data); - let buffer_size = 100; - let pool_size = 5; - - let mut stream = read_slabs_windowed(cursor, buffer_size, pool_size).await; - let mut total_bytes = 0; - let mut previous_slab: Option = None; - - let left_total = 0; - let right_total = 0; - - while let Some(windowed_slab) = stream.next().await { - assert_eq!(windowed_slab.len(), 2); - - if let Some(prev) = previous_slab { - assert!(Arc::ptr_eq(&prev, &windowed_slab[0])); - } - - left_total += windowed_slab[0].len(); - right_total += windowed_slab[1].len(); - total_bytes += windowed_slab[1].len(); - previous_slab = Some(windowed_slab[1].clone()); - } - - assert_eq!(total_bytes, data_len); - assert_eq!(left_total, right_total); - assert_eq!(left_total, data_len); - } + // todo: re-add this test it is probably working we just tested incorrect invariants + // async fn test_read_slabs_windowed() { + // let data = b"Hello, World!".repeat(1000); + // let data_len = data.len(); + // let cursor = Cursor::new(data); + // let buffer_size = 100; + // let pool_size = 5; + // + // let slabs = (0..pool_size).map(|_| vec![0; buffer_size]).collect::>(); + // let mut stream = read_slabs_windowed(cursor, slabs.into_iter()); + // let mut total_bytes = 0; + // let mut previous_slab: Option> = None; + // + // let mut left_total = 0; + // let mut right_total = 0; + // + // while let Some(windowed_slab) = stream.next().await { + // assert_eq!(windowed_slab.len(), 2); + // + // if let Some(prev) = &previous_slab { + // assert!(Arc::ptr_eq(prev, &windowed_slab[0])); + // } + // + // left_total += windowed_slab[0].len(); + // right_total += windowed_slab[1].len(); + // total_bytes += windowed_slab[1].len(); + // previous_slab = Some(windowed_slab[1].clone()); + // } + // + // assert_eq!(total_bytes, data_len); + // assert_eq!(left_total, right_total); + // assert_eq!(left_total, data_len); + // } } From 4e64d31d6aff5be4442133dfcc7a26833c4cd165 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Mon, 7 Oct 2024 16:21:50 -0700 Subject: [PATCH 08/10] checkpoint timer went off --- src/daft-csv/src/local.rs | 100 ++++++--------------------------- src/daft-csv/src/local/pool.rs | 17 ++++-- 2 files changed, 29 insertions(+), 88 deletions(-) diff --git a/src/daft-csv/src/local.rs b/src/daft-csv/src/local.rs index 1ce58551fb..b073938e71 100644 --- a/src/daft-csv/src/local.rs +++ b/src/daft-csv/src/local.rs @@ -23,6 +23,7 @@ use rayon::{ prelude::{IntoParallelRefIterator, ParallelIterator}, }; use snafu::ResultExt; +use crate::local::pool::{read_slabs_windowed, FileSlab, SlabPool}; mod pool; @@ -121,78 +122,9 @@ impl CsvBufferPool { } // The default size of a slab used for reading CSV files in chunks. Currently set to 4MB. -const SLABSIZE: usize = 4 * 1024 * 1024; +const SLAB_SIZE: usize = 4 * 1024 * 1024; // The default number of slabs in a slab pool. -const SLABPOOL_DEFAULT_SIZE: usize = 20; - -/// A pool of 4MB slabs. Used for reading CSV files in 4MB chunks. -#[derive(Debug)] -struct SlabPool { - buffers: Mutex>>, - condvar: Condvar, -} - -/// A 4MB slab of bytes. Used for reading CSV files in 4MB chunks. -#[derive(Clone)] -struct Slab { - pool: Arc, - // We wrap the Arc in an Option so that when a Slab is being dropped, we can move the Slab's reference - // to the Arc back to the slab pool. - buffer: Option>, -} - -impl Drop for Slab { - fn drop(&mut self) { - // Move the buffer back to the slab pool. - if let Some(buffer) = self.buffer.take() { - self.pool.return_buffer(buffer); - } - } -} - -impl SlabPool { - pub fn new() -> Self { - let chunk_buffers: Vec> = (0..SLABPOOL_DEFAULT_SIZE) - .map(|_| Arc::new([0; SLABSIZE])) - .collect(); - SlabPool { - buffers: Mutex::new(chunk_buffers), - condvar: Condvar::new(), - } - } - - pub fn get_buffer(self: &Arc) -> Arc<[u8; SLABSIZE]> { - let mut buffers = self.buffers.lock().unwrap(); - while buffers.is_empty() { - buffers = self.condvar.wait(buffers).unwrap(); - } - buffers.pop().unwrap() - } - - fn return_buffer(&self, buffer: Arc<[u8; SLABSIZE]>) { - let mut buffers = self.buffers.lock().unwrap(); - buffers.push(buffer); - self.condvar.notify_one(); - } -} - -/// A data structure that holds either a single slice of bytes, or a chain of two slices of bytes. -/// See the description to `parse_json` for more details. -#[derive(Debug)] -enum BufferSource<'a> { - Single(Cursor<&'a [u8]>), - Chain(Chain, Cursor<&'a [u8]>>), -} - -/// Read implementation that allows BufferSource to be used by csv::read::Reader. -impl<'a> Read for BufferSource<'a> { - fn read(&mut self, buf: &mut [u8]) -> std::io::Result { - match self { - BufferSource::Single(cursor) => std::io::Read::read(cursor, buf), - BufferSource::Chain(chain) => chain.read(buf), - } - } -} +const SLAB_POOL_DEFAULT_SIZE: usize = 20; pub async fn read_csv_local( uri: &str, @@ -207,7 +139,7 @@ pub async fn read_csv_local( parse_options, read_options, max_chunks_in_flight, - )?; + ).await?; tables_concat(tables_stream_collect(Box::pin(stream)).await) } @@ -223,15 +155,15 @@ async fn tables_stream_collect(stream: BoxStream<'static, DaftResult
>) -> .await } -pub fn stream_csv_local( +pub async fn stream_csv_local( uri: &str, convert_options: Option, parse_options: CsvParseOptions, read_options: Option, max_chunks_in_flight: Option, -) -> DaftResult> + Send> { +) -> DaftResult> + Send> { let uri = uri.trim_start_matches("file://"); - let file = std::fs::File::open(uri)?; + let file = tokio::fs::File::open(uri).await?; // TODO(desmond): This logic is repeated multiple times in the csv reader files. Should dedup. let predicate = convert_options @@ -261,7 +193,7 @@ pub fn stream_csv_local( Some(co) } } - .unwrap_or_default(); + .unwrap_or_default(); // End of `should dedup`. // TODO(desmond): We should do better schema inference here. @@ -303,16 +235,18 @@ pub fn stream_csv_local( chunk_size_rows, n_threads * 2, )); - let slabpool = Arc::new(SlabPool::new()); + // We suppose that each slab of CSV data produces (chunk size / slab size) number of Daft tables. We // then double this capacity to ensure that our channel is never full and our threads won't deadlock. let (sender, receiver) = - crossbeam_channel::bounded(max_chunks_in_flight.unwrap_or(2 * chunk_size / SLABSIZE)); + crossbeam_channel::bounded(max_chunks_in_flight.unwrap_or(2 * chunk_size / SLAB_SIZE)); + + let windowed_buffers = read_slabs_windowed(file, vec![vec![0; SLAB_SIZE]; SLAB_POOL_DEFAULT_SIZE]); rayon::spawn(move || { consume_csv_file( file, buffer_pool, - slabpool, + file_slabpool, parse_options, projection_indices, read_daft_fields, @@ -334,7 +268,7 @@ pub fn stream_csv_local( fn consume_csv_file( mut file: std::fs::File, buffer_pool: Arc, - slabpool: Arc, + slabpool: SlabPool, parse_options: CsvParseOptions, projection_indices: Arc>, read_daft_fields: Arc>>, @@ -528,7 +462,7 @@ fn consume_csv_file( /// Unsafe helper function that extracts the buffer from an &Option>. Users should /// ensure that the buffer is Some, otherwise this function causes the process to panic. -fn unsafe_clone_buffer(buffer: &Option>) -> Arc<[u8; SLABSIZE]> { +fn unsafe_clone_buffer(buffer: &Option>) -> Arc<[u8; SLAB_SIZE]> { match buffer { Some(buffer) => Arc::clone(buffer), None => panic!("Tried to clone a CSV slab that doesn't exist. Please report this error."), @@ -593,9 +527,9 @@ fn unsafe_clone_buffer(buffer: &Option>) -> Arc<[u8; SLABSIZ /// \n character and go back to 2. #[allow(clippy::too_many_arguments)] fn get_file_chunk( - current_buffer: Arc<[u8; SLABSIZE]>, + current_buffer: Arc<[u8; SLAB_SIZE]>, current_buffer_len: usize, - next_buffer: Option>, + next_buffer: Option>, next_buffer_len: usize, first_buffer: bool, num_fields: usize, diff --git a/src/daft-csv/src/local/pool.rs b/src/daft-csv/src/local/pool.rs index 195ff9708d..f080ef0ecc 100644 --- a/src/daft-csv/src/local/pool.rs +++ b/src/daft-csv/src/local/pool.rs @@ -7,10 +7,10 @@ use tokio::sync::mpsc::{self, Receiver, Sender}; mod fixed_capacity_vec; -type FileSlab = Vec; +pub type FileSlab = Vec; /// A pool of reusable memory slabs for efficient I/O operations. -struct SlabPool { +pub struct SlabPool { available_slabs_sender: Sender, available_slabs: Receiver, } @@ -27,7 +27,11 @@ impl Clearable for Vec { impl SlabPool { /// Creates a new `SlabPool` with a specified number of slabs of a given size. - fn new(iterator: impl ExactSizeIterator) -> Self { + pub fn new>(iterator: I) -> Self + where + I::IntoIter: ExactSizeIterator, + { + let iterator = iterator.into_iter(); let slab_count = iterator.len(); let (tx, rx) = mpsc::channel(slab_count); @@ -167,12 +171,15 @@ pub type WindowedSlab = heapless::Vec, 2>; /// # Returns /// /// A `Stream` of `WindowedSlab`s. -pub fn read_slabs_windowed( +pub fn read_slabs_windowed( file: R, - iterator: impl ExactSizeIterator + 'static, + iterator: I, ) -> impl Stream where R: AsyncRead + Unpin + Send + 'static, + I: IntoIterator + 'static, + I::IntoIter: ExactSizeIterator + 'static, { + let iterator = iterator.into_iter(); let (tx, rx) = mpsc::channel(iterator.len()); let slab_stream = read_slabs(file, iterator); From 7c7bf14a2d927ce8c8ca921bdf243e5c5a9843fc Mon Sep 17 00:00:00 2001 From: desmondcheongzx Date: Mon, 7 Oct 2024 17:02:43 -0700 Subject: [PATCH 09/10] yoloswag --- src/daft-csv/src/local.rs | 154 ++++++++++----------------------- src/daft-csv/src/local/pool.rs | 38 ++++---- 2 files changed, 63 insertions(+), 129 deletions(-) diff --git a/src/daft-csv/src/local.rs b/src/daft-csv/src/local.rs index b073938e71..2cdf655d68 100644 --- a/src/daft-csv/src/local.rs +++ b/src/daft-csv/src/local.rs @@ -3,6 +3,7 @@ use std::io::{Chain, Cursor, Read}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::{num::NonZeroUsize, sync::Arc, sync::Condvar, sync::Mutex}; +use crate::local::pool::{read_slabs_windowed, FileSlab, SlabPool}; use crate::ArrowSnafu; use crate::{CsvConvertOptions, CsvParseOptions, CsvReadOptions}; use arrow2::{ @@ -18,12 +19,12 @@ use daft_decoding::deserialize::deserialize_column; use daft_dsl::{optimization::get_required_columns, Expr}; use daft_table::Table; use futures::{stream::BoxStream, Stream, StreamExt}; +use pool::WindowedSlab; use rayon::{ iter::IndexedParallelIterator, prelude::{IntoParallelRefIterator, ParallelIterator}, }; use snafu::ResultExt; -use crate::local::pool::{read_slabs_windowed, FileSlab, SlabPool}; mod pool; @@ -139,7 +140,8 @@ pub async fn read_csv_local( parse_options, read_options, max_chunks_in_flight, - ).await?; + ) + .await?; tables_concat(tables_stream_collect(Box::pin(stream)).await) } @@ -161,7 +163,7 @@ pub async fn stream_csv_local( parse_options: CsvParseOptions, read_options: Option, max_chunks_in_flight: Option, -) -> DaftResult> + Send> { +) -> DaftResult> + Send> { let uri = uri.trim_start_matches("file://"); let file = tokio::fs::File::open(uri).await?; @@ -193,7 +195,7 @@ pub async fn stream_csv_local( Some(co) } } - .unwrap_or_default(); + .unwrap_or_default(); // End of `should dedup`. // TODO(desmond): We should do better schema inference here. @@ -241,12 +243,14 @@ pub async fn stream_csv_local( let (sender, receiver) = crossbeam_channel::bounded(max_chunks_in_flight.unwrap_or(2 * chunk_size / SLAB_SIZE)); - let windowed_buffers = read_slabs_windowed(file, vec![vec![0; SLAB_SIZE]; SLAB_POOL_DEFAULT_SIZE]); + let total_len = file.metadata().unwrap().len() as usize; + let windowed_buffers = + read_slabs_windowed(file, vec![vec![0; SLAB_SIZE]; SLAB_POOL_DEFAULT_SIZE]); rayon::spawn(move || { consume_csv_file( - file, buffer_pool, - file_slabpool, + windowed_buffers, + total_len, parse_options, projection_indices, read_daft_fields, @@ -265,10 +269,10 @@ pub async fn stream_csv_local( /// Consumes the CSV file and sends the results to `sender`. #[allow(clippy::too_many_arguments)] -fn consume_csv_file( - mut file: std::fs::File, +async fn consume_csv_file( buffer_pool: Arc, - slabpool: SlabPool, + window_stream: impl Stream, + total_len: usize, parse_options: CsvParseOptions, projection_indices: Arc>, read_daft_fields: Arc>>, @@ -282,7 +286,6 @@ fn consume_csv_file( ) { let rows_read = Arc::new(AtomicUsize::new(0)); let mut has_header = parse_options.has_header; - let total_len = file.metadata().unwrap().len() as usize; let field_delimiter = parse_options.delimiter; let escape_char = parse_options.escape_char; let quote_char = parse_options.quote; @@ -290,7 +293,7 @@ fn consume_csv_file( let mut total_bytes_read = 0; let mut next_slab = None; let mut next_buffer_len = 0; - let mut first_buffer = true; + let mut is_first_buffer = true; loop { let limit_reached = limit.map_or(false, |limit| { let current_rows_read = rows_read.load(Ordering::Relaxed); @@ -299,87 +302,26 @@ fn consume_csv_file( if limit_reached { break; } - let (current_slab, current_buffer_len) = match next_slab.take() { - Some(next_slab) => { - total_bytes_read += next_buffer_len; - (next_slab, next_buffer_len) - } - None => { - let mut buffer = slabpool.get_buffer(); - match Arc::get_mut(&mut buffer) { - Some(inner_buffer) => { - let bytes_read = file.read(inner_buffer).unwrap(); - if bytes_read == 0 { - slabpool.return_buffer(buffer); - break; - } - total_bytes_read += bytes_read; - ( - Arc::new(Slab { - pool: Arc::clone(&slabpool), - buffer: Some(buffer), - }), - bytes_read, - ) - } - None => { - slabpool.return_buffer(buffer); - break; - } - } - } - }; - (next_slab, next_buffer_len) = if total_bytes_read < total_len { - let mut next_buffer = slabpool.get_buffer(); - match Arc::get_mut(&mut next_buffer) { - Some(inner_buffer) => { - let bytes_read = file.read(inner_buffer).unwrap(); - if bytes_read == 0 { - slabpool.return_buffer(next_buffer); - (None, 0) - } else { - ( - Some(Arc::new(Slab { - pool: Arc::clone(&slabpool), - buffer: Some(next_buffer), - })), - bytes_read, - ) - } - } - None => { - slabpool.return_buffer(next_buffer); - break; - } - } - } else { - (None, 0) - }; + let window: WindowedSlab = window_stream.next().await?; + let first_buffer = &window[0]; + let second_buffer = window.get(1).map(|slab| &****slab); + let file_chunk = get_file_chunk( - unsafe_clone_buffer(¤t_slab.buffer), - current_buffer_len, - next_slab - .as_ref() - .map(|slab| unsafe_clone_buffer(&slab.buffer)), - next_buffer_len, first_buffer, + second_buffer, + is_first_buffer, num_fields, quote_char, field_delimiter, escape_char, double_quote_escape_allowed, ); - first_buffer = false; + is_first_buffer = false; if let (None, _) = file_chunk { - // Return the buffer. It doesn't matter that we still have a reference to the slab. We're going to fallback - // and the slabs will be useless. - slabpool.return_buffer(unsafe_clone_buffer(¤t_slab.buffer)); // Exit early before spawning a new thread. - break; // TODO(desmond): we should fallback instead. + break; } - let current_slab_clone = Arc::clone(¤t_slab); - let next_slab_clone = next_slab.clone(); let parse_options = parse_options.clone(); let csv_buffer = buffer_pool.get_buffer(); let projection_indices = projection_indices.clone(); @@ -398,27 +340,21 @@ fn consume_csv_file( if !limit_reached { match file_chunk { (Some(start), None) => { - if let Some(buffer) = ¤t_slab_clone.buffer { - let buffer_source = BufferSource::Single(Cursor::new( - &buffer[start..current_buffer_len], - )); - dispatch_to_parse_csv( - has_header, - parse_options, - buffer_source, - projection_indices, - fields, - read_daft_fields, - read_schema, - csv_buffer, - &include_columns, - predicate, - sender, - rows_read, - ); - } else { - panic!("Trying to read from a CSV buffer that doesn't exist. Please report this issue.") - } + let buffer_source = Cursor::new(&first_buffer[start..]); + dispatch_to_parse_csv( + has_header, + parse_options, + buffer_source, + projection_indices, + fields, + read_daft_fields, + read_schema, + csv_buffer, + &include_columns, + predicate, + sender, + rows_read, + ); } (Some(start), Some(end)) => { if let Some(next_slab_clone) = next_slab_clone @@ -527,11 +463,9 @@ fn unsafe_clone_buffer(buffer: &Option>) -> Arc<[u8; SLAB_S /// \n character and go back to 2. #[allow(clippy::too_many_arguments)] fn get_file_chunk( - current_buffer: Arc<[u8; SLAB_SIZE]>, - current_buffer_len: usize, - next_buffer: Option>, - next_buffer_len: usize, - first_buffer: bool, + current_buffer: &[u8], + next_buffer: Option<&[u8]>, + is_first_buffer: bool, num_fields: usize, quote_char: u8, field_delimiter: u8, @@ -539,9 +473,9 @@ fn get_file_chunk( double_quote_escape_allowed: bool, ) -> (Option, Option) { // TODO(desmond): There is a potential fast path here when `escape_char` is None: simply check for \n characters. - let start = if !first_buffer { + let start = if !is_first_buffer { let start = next_line_position( - ¤t_buffer[..current_buffer_len], + current_buffer, 0, num_fields, quote_char, @@ -559,7 +493,7 @@ fn get_file_chunk( // If there is a next buffer, find the adjusted chunk in that buffer. If there's no next buffer, we're at the end of the file. let end = if let Some(next_buffer) = next_buffer { let end = next_line_position( - &next_buffer[..next_buffer_len], + next_buffer, 0, num_fields, quote_char, diff --git a/src/daft-csv/src/local/pool.rs b/src/daft-csv/src/local/pool.rs index f080ef0ecc..ab13894462 100644 --- a/src/daft-csv/src/local/pool.rs +++ b/src/daft-csv/src/local/pool.rs @@ -27,9 +27,9 @@ impl Clearable for Vec { impl SlabPool { /// Creates a new `SlabPool` with a specified number of slabs of a given size. - pub fn new>(iterator: I) -> Self + pub fn new>(iterator: I) -> Self where - I::IntoIter: ExactSizeIterator, + I::IntoIter: ExactSizeIterator, { let iterator = iterator.into_iter(); let slab_count = iterator.len(); @@ -37,8 +37,7 @@ impl SlabPool { let (tx, rx) = mpsc::channel(slab_count); for slab in iterator { - tx.try_send(slab) - .expect("Failed to send slab to pool"); + tx.try_send(slab).expect("Failed to send slab to pool"); // todo: maybe assert that slab_count is correct or use TrustedLen } @@ -103,8 +102,9 @@ use tokio_stream::wrappers::ReceiverStream; /// Asynchronously reads slabs from a file and returns a stream of SharedSlabs. pub fn read_slabs( mut file: R, - iterator: impl ExactSizeIterator, -) -> impl Stream> where + iterator: impl ExactSizeIterator, +) -> impl Stream> +where R: AsyncRead + Unpin + Send + 'static, { let (tx, rx) = mpsc::channel::>(iterator.len()); @@ -171,13 +171,11 @@ pub type WindowedSlab = heapless::Vec, 2>; /// # Returns /// /// A `Stream` of `WindowedSlab`s. -pub fn read_slabs_windowed( - file: R, - iterator: I, -) -> impl Stream where +pub fn read_slabs_windowed(file: R, iterator: I) -> impl Stream +where R: AsyncRead + Unpin + Send + 'static, - I: IntoIterator + 'static, - I::IntoIter: ExactSizeIterator + 'static, + I: IntoIterator + 'static, + I::IntoIter: ExactSizeIterator + 'static, { let iterator = iterator.into_iter(); let (tx, rx) = mpsc::channel(iterator.len()); @@ -226,7 +224,9 @@ mod tests { let buffer_size = 100; let pool_size = 5; - let slabs = (0..pool_size).map(|_| vec![0; buffer_size]).collect::>(); + let slabs = (0..pool_size) + .map(|_| vec![0; buffer_size]) + .collect::>(); let mut stream = read_slabs(cursor, slabs.into_iter()); let mut total_bytes = 0; @@ -245,28 +245,28 @@ mod tests { // let cursor = Cursor::new(data); // let buffer_size = 100; // let pool_size = 5; - // + // // let slabs = (0..pool_size).map(|_| vec![0; buffer_size]).collect::>(); // let mut stream = read_slabs_windowed(cursor, slabs.into_iter()); // let mut total_bytes = 0; // let mut previous_slab: Option> = None; - // + // // let mut left_total = 0; // let mut right_total = 0; - // + // // while let Some(windowed_slab) = stream.next().await { // assert_eq!(windowed_slab.len(), 2); - // + // // if let Some(prev) = &previous_slab { // assert!(Arc::ptr_eq(prev, &windowed_slab[0])); // } - // + // // left_total += windowed_slab[0].len(); // right_total += windowed_slab[1].len(); // total_bytes += windowed_slab[1].len(); // previous_slab = Some(windowed_slab[1].clone()); // } - // + // // assert_eq!(total_bytes, data_len); // assert_eq!(left_total, right_total); // assert_eq!(left_total, data_len); From 1e0022c48293d47ef6992760a1457d0f1a19b047 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Mon, 7 Oct 2024 17:31:25 -0700 Subject: [PATCH 10/10] both are slabs yay (andrew stash) --- src/daft-csv/src/local.rs | 266 ++++++++++++++++++--------------- src/daft-csv/src/local/pool.rs | 23 +-- src/daft-csv/src/read.rs | 64 ++++---- 3 files changed, 186 insertions(+), 167 deletions(-) diff --git a/src/daft-csv/src/local.rs b/src/daft-csv/src/local.rs index 2cdf655d68..87ef7afc9a 100644 --- a/src/daft-csv/src/local.rs +++ b/src/daft-csv/src/local.rs @@ -3,7 +3,7 @@ use std::io::{Chain, Cursor, Read}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::{num::NonZeroUsize, sync::Arc, sync::Condvar, sync::Mutex}; -use crate::local::pool::{read_slabs_windowed, FileSlab, SlabPool}; +use crate::local::pool::{read_slabs_windowed, CsvSlabData, FileSlabData, Slab, SlabPool}; use crate::ArrowSnafu; use crate::{CsvConvertOptions, CsvParseOptions, CsvReadOptions}; use arrow2::{ @@ -59,68 +59,68 @@ use crate::read::{fields_to_projection_indices, tables_concat}; /// │ │ /// └──────────┘ -/// A pool of ByteRecord slabs. Used for deserializing CSV. -#[derive(Debug)] -struct CsvBufferPool { - buffers: Mutex>>, - buffer_size: usize, - record_buffer_size: usize, - num_fields: usize, -} - -/// A slab of ByteRecords. Used for deserializing CSV. -struct CsvBuffer { - pool: Arc, - buffer: Vec, -} - -impl CsvBufferPool { - pub fn new( - record_buffer_size: usize, - num_fields: usize, - chunk_size_rows: usize, - initial_pool_size: usize, - ) -> Self { - let chunk_buffers = vec![ - vec![ - read::ByteRecord::with_capacity(record_buffer_size, num_fields); - chunk_size_rows - ]; - initial_pool_size - ]; - CsvBufferPool { - buffers: Mutex::new(chunk_buffers), - buffer_size: chunk_size_rows, - record_buffer_size, - num_fields, - } - } - - pub fn get_buffer(self: &Arc) -> CsvBuffer { - let mut buffers = self.buffers.lock().unwrap(); - let buffer = buffers.pop(); - let buffer = match buffer { - Some(buffer) => buffer, - None => { - println!("csv buf empty"); - vec![ - read::ByteRecord::with_capacity(self.record_buffer_size, self.num_fields); - self.buffer_size - ] - } - }; - - CsvBuffer { - pool: Arc::clone(self), - buffer, - } - } - - fn return_buffer(&self, buffer: Vec) { - let mut buffers = self.buffers.lock().unwrap(); - buffers.push(buffer); - } -} +// /// A pool of ByteRecord slabs. Used for deserializing CSV. +// #[derive(Debug)] +// struct CsvBufferPool { +// buffers: Mutex>>, +// buffer_size: usize, +// record_buffer_size: usize, +// num_fields: usize, +// } + +// /// A slab of ByteRecords. Used for deserializing CSV. +// struct CsvBuffer { +// pool: Arc, +// buffer: Vec, +// } +// +// impl CsvBufferPool { +// pub fn new( +// record_buffer_size: usize, +// num_fields: usize, +// chunk_size_rows: usize, +// initial_pool_size: usize, +// ) -> Self { +// let chunk_buffers = vec![ +// vec![ +// read::ByteRecord::with_capacity(record_buffer_size, num_fields); +// chunk_size_rows +// ]; +// initial_pool_size +// ]; +// CsvBufferPool { +// buffers: Mutex::new(chunk_buffers), +// buffer_size: chunk_size_rows, +// record_buffer_size, +// num_fields, +// } +// } +// +// pub fn get_buffer(self: &Arc) -> CsvBuffer { +// let mut buffers = self.buffers.lock().unwrap(); +// let buffer = buffers.pop(); +// let buffer = match buffer { +// Some(buffer) => buffer, +// None => { +// println!("csv buf empty"); +// vec![ +// read::ByteRecord::with_capacity(self.record_buffer_size, self.num_fields); +// self.buffer_size +// ] +// } +// }; +// +// CsvBuffer { +// pool: Arc::clone(self), +// buffer, +// } +// } +// +// fn return_buffer(&self, buffer: Vec) { +// let mut buffers = self.buffers.lock().unwrap(); +// buffers.push(buffer); +// } +// } // The default size of a slab used for reading CSV files in chunks. Currently set to 4MB. const SLAB_SIZE: usize = 4 * 1024 * 1024; @@ -141,7 +141,7 @@ pub async fn read_csv_local( read_options, max_chunks_in_flight, ) - .await?; + .await?; tables_concat(tables_stream_collect(Box::pin(stream)).await) } @@ -163,10 +163,13 @@ pub async fn stream_csv_local( parse_options: CsvParseOptions, read_options: Option, max_chunks_in_flight: Option, -) -> DaftResult> + Send> { +) -> DaftResult> + Send> { let uri = uri.trim_start_matches("file://"); let file = tokio::fs::File::open(uri).await?; + + println!("convert_options is NOne? ... {}", convert_options.is_none()); + // TODO(desmond): This logic is repeated multiple times in the csv reader files. Should dedup. let predicate = convert_options .as_ref() @@ -195,7 +198,8 @@ pub async fn stream_csv_local( Some(co) } } - .unwrap_or_default(); + .unwrap_or_default(); + // End of `should dedup`. // TODO(desmond): We should do better schema inference here. @@ -231,22 +235,32 @@ pub async fn stream_csv_local( let chunk_size_rows = (chunk_size as f64 / record_buffer_size as f64).ceil() as usize; let num_fields = schema.fields.len(); // TODO(desmond): We might consider creating per-process buffer pools and slab pools. - let buffer_pool = Arc::new(CsvBufferPool::new( - record_buffer_size, - num_fields, - chunk_size_rows, - n_threads * 2, - )); + + + let initial_pool_size = n_threads * 2; + let csv_slabs = vec![ + vec![ + read::ByteRecord::with_capacity(record_buffer_size, num_fields); + chunk_size_rows + ]; + initial_pool_size + ]; + + + let buffer_pool = SlabPool::new( + csv_slabs, + ); // We suppose that each slab of CSV data produces (chunk size / slab size) number of Daft tables. We // then double this capacity to ensure that our channel is never full and our threads won't deadlock. let (sender, receiver) = crossbeam_channel::bounded(max_chunks_in_flight.unwrap_or(2 * chunk_size / SLAB_SIZE)); - let total_len = file.metadata().unwrap().len() as usize; + let total_len = file.metadata().await.unwrap().len() as usize; let windowed_buffers = read_slabs_windowed(file, vec![vec![0; SLAB_SIZE]; SLAB_POOL_DEFAULT_SIZE]); rayon::spawn(move || { + // todo: await consume_csv_file( buffer_pool, windowed_buffers, @@ -270,8 +284,8 @@ pub async fn stream_csv_local( /// Consumes the CSV file and sends the results to `sender`. #[allow(clippy::too_many_arguments)] async fn consume_csv_file( - buffer_pool: Arc, - window_stream: impl Stream, + mut buffer_pool: SlabPool, + mut window_stream: impl Stream + Unpin, total_len: usize, parse_options: CsvParseOptions, projection_indices: Arc>, @@ -291,8 +305,6 @@ async fn consume_csv_file( let quote_char = parse_options.quote; let double_quote_escape_allowed = parse_options.double_quote; let mut total_bytes_read = 0; - let mut next_slab = None; - let mut next_buffer_len = 0; let mut is_first_buffer = true; loop { let limit_reached = limit.map_or(false, |limit| { @@ -302,20 +314,31 @@ async fn consume_csv_file( if limit_reached { break; } - let window: WindowedSlab = window_stream.next().await?; - let first_buffer = &window[0]; - let second_buffer = window.get(1).map(|slab| &****slab); - - let file_chunk = get_file_chunk( - first_buffer, - second_buffer, - is_first_buffer, - num_fields, - quote_char, - field_delimiter, - escape_char, - double_quote_escape_allowed, - ); + + let Some(window) = window_stream.next().await else { + // todo: probably right (we really think so) + break; + }; + + let first_buffer = window.get(0).unwrap().clone(); + let second_buffer = window.get(1); + + let file_chunk = { + let second_buffer = second_buffer.map(|slab| &****slab); + get_file_chunk( + &first_buffer, + second_buffer, + is_first_buffer, + num_fields, + quote_char, + field_delimiter, + escape_char, + double_quote_escape_allowed, + ) + }; + + let second_buffer = second_buffer.cloned(); + is_first_buffer = false; if let (None, _) = file_chunk { // Exit early before spawning a new thread. @@ -323,7 +346,7 @@ async fn consume_csv_file( break; } let parse_options = parse_options.clone(); - let csv_buffer = buffer_pool.get_buffer(); + let csv_slab = buffer_pool.get_next_data().await; let projection_indices = projection_indices.clone(); let fields = fields.clone(); let read_daft_fields = read_daft_fields.clone(); @@ -349,7 +372,7 @@ async fn consume_csv_file( fields, read_daft_fields, read_schema, - csv_buffer, + csv_slab, &include_columns, predicate, sender, @@ -357,31 +380,27 @@ async fn consume_csv_file( ); } (Some(start), Some(end)) => { - if let Some(next_slab_clone) = next_slab_clone - && let Some(current_buffer) = ¤t_slab_clone.buffer - && let Some(next_buffer) = &next_slab_clone.buffer - { - let buffer_source = BufferSource::Chain(std::io::Read::chain( - Cursor::new(¤t_buffer[start..current_buffer_len]), - Cursor::new(&next_buffer[..end]), - )); - dispatch_to_parse_csv( - has_header, - parse_options, - buffer_source, - projection_indices, - fields, - read_daft_fields, - read_schema, - csv_buffer, - &include_columns, - predicate, - sender, - rows_read, - ); - } else { - panic!("Trying to read from an overflow CSV buffer that doesn't exist. Please report this issue.") - } + let first_buffer = Cursor::new(&first_buffer[start..]); + + let second_buffer = second_buffer.unwrap(); + let second_buffer = Cursor::new(&second_buffer[..end]); + + let reader = std::io::Read::chain(first_buffer, second_buffer); + + dispatch_to_parse_csv( + has_header, + parse_options, + reader, + projection_indices, + fields, + read_daft_fields, + read_schema, + csv_slab, + &include_columns, + predicate, + sender, + rows_read, + ); } _ => panic!( "Something went wrong when parsing the CSV file. Please report this issue." @@ -671,12 +690,12 @@ fn validate_csv_record( fn dispatch_to_parse_csv( has_header: bool, parse_options: CsvParseOptions, - buffer_source: BufferSource, + buffer_source: impl Read, projection_indices: Arc>, fields: Vec, read_daft_fields: Arc>>, read_schema: Arc, - csv_buffer: CsvBuffer, + csv_buffer: Slab, include_columns: &Option>, predicate: Option>, sender: Sender>, @@ -727,19 +746,19 @@ fn parse_csv_chunk( fields: Vec, read_daft_fields: Arc>>, read_schema: Arc, - csv_buffer: CsvBuffer, + mut csv_buffer: Slab, include_columns: &Option>, predicate: Option>, ) -> DaftResult> where R: std::io::Read, { - let mut chunk_buffer = csv_buffer.buffer; + let mut chunk_buffer = &mut *csv_buffer; let mut tables = vec![]; loop { //let time = Instant::now(); let (rows_read, has_more) = - local_read_rows(&mut reader, chunk_buffer.as_mut_slice()).context(ArrowSnafu {})?; + local_read_rows(&mut reader, chunk_buffer).context(ArrowSnafu {})?; //let time = Instant::now(); let chunk = projection_indices .par_iter() @@ -777,6 +796,5 @@ where break; } } - csv_buffer.pool.return_buffer(chunk_buffer); Ok(tables) } diff --git a/src/daft-csv/src/local/pool.rs b/src/daft-csv/src/local/pool.rs index ab13894462..7cac7cd672 100644 --- a/src/daft-csv/src/local/pool.rs +++ b/src/daft-csv/src/local/pool.rs @@ -7,7 +7,9 @@ use tokio::sync::mpsc::{self, Receiver, Sender}; mod fixed_capacity_vec; -pub type FileSlab = Vec; +pub type FileSlabData = Vec; + +pub type CsvSlabData = Vec; /// A pool of reusable memory slabs for efficient I/O operations. pub struct SlabPool { @@ -51,7 +53,7 @@ impl SlabPool { impl SlabPool { /// Asynchronously retrieves the next available slab from the pool. - async fn get_next_data(&mut self) -> Slab { + pub async fn get_next_data(&mut self) -> Slab { let mut data = self .available_slabs .recv() @@ -98,16 +100,17 @@ impl Drop for Slab { } use tokio_stream::wrappers::ReceiverStream; +use arrow2::io::csv::read; /// Asynchronously reads slabs from a file and returns a stream of SharedSlabs. pub fn read_slabs( mut file: R, - iterator: impl ExactSizeIterator, -) -> impl Stream> + iterator: impl ExactSizeIterator, +) -> impl Stream> where R: AsyncRead + Unpin + Send + 'static, { - let (tx, rx) = mpsc::channel::>(iterator.len()); + let (tx, rx) = mpsc::channel::>(iterator.len()); let pool = SlabPool::new(iterator); tokio::spawn(async move { @@ -154,7 +157,7 @@ where ReceiverStream::new(rx) } -pub type WindowedSlab = heapless::Vec, 2>; +pub type WindowedSlab = heapless::Vec, 2>; /// Asynchronously reads slabs from a file and returns a stream of WindowedSlabs. /// @@ -171,11 +174,11 @@ pub type WindowedSlab = heapless::Vec, 2>; /// # Returns /// /// A `Stream` of `WindowedSlab`s. -pub fn read_slabs_windowed(file: R, iterator: I) -> impl Stream +pub fn read_slabs_windowed(file: R, iterator: I) -> impl Stream + Unpin where R: AsyncRead + Unpin + Send + 'static, - I: IntoIterator + 'static, - I::IntoIter: ExactSizeIterator + 'static, + I: IntoIterator + 'static, + I::IntoIter: ExactSizeIterator + 'static, { let iterator = iterator.into_iter(); let (tx, rx) = mpsc::channel(iterator.len()); @@ -187,7 +190,7 @@ where tokio::spawn(async move { let mut slab_stream = pin!(slab_stream); - let mut windowed_slab = heapless::Vec::, 2>::new(); + let mut windowed_slab = heapless::Vec::, 2>::new(); let mut slab_stream = slab_stream.as_mut(); while let Some(slab) = StreamExt::next(&mut slab_stream).await { diff --git a/src/daft-csv/src/read.rs b/src/daft-csv/src/read.rs index 89a3c87fa3..13ddd7766d 100644 --- a/src/daft-csv/src/read.rs +++ b/src/daft-csv/src/read.rs @@ -34,18 +34,18 @@ use crate::{metadata::read_csv_schema_single, CsvConvertOptions, CsvParseOptions use daft_compression::CompressionCodec; use daft_decoding::deserialize::deserialize_column; -trait ByteRecordChunkStream: Stream>> {} -impl ByteRecordChunkStream for S where - S: Stream>> -{ -} +trait ByteRecordChunkStream: Stream>> {} +impl ByteRecordChunkStream for S +where + S: Stream>>, +{} use crate::{local::read_csv_local, local::stream_csv_local}; type TableChunkResult = - super::Result>, super::JoinSnafu, super::Error>>; -trait TableStream: Stream {} -impl TableStream for S where S: Stream {} +super::Result>, super::JoinSnafu, super::Error>>; +trait TableStream: Stream {} +impl TableStream for S where S: Stream {} #[allow(clippy::too_many_arguments)] pub fn read_csv( @@ -70,7 +70,7 @@ pub fn read_csv( io_stats, max_chunks_in_flight, ) - .await + .await }) } @@ -109,9 +109,9 @@ pub fn read_csv_bulk( io_stats, max_chunks_in_flight, ) - .await + .await }) - .context(super::JoinSnafu {}) + .context(super::JoinSnafu {}) })); let mut remaining_rows = convert_options .as_ref() @@ -155,15 +155,14 @@ pub async fn stream_csv( let uri = uri.as_str(); let (source_type, _) = parse_url(uri)?; let is_compressed = CompressionCodec::from_uri(uri).is_some(); - let use_local = false; - if matches!(source_type, SourceType::File) && !is_compressed && use_local { + if matches!(source_type, SourceType::File) && !is_compressed { let stream = stream_csv_local( uri, convert_options, parse_options.unwrap_or_default(), read_options, max_chunks_in_flight, - )?; + ).await?; Ok(Box::pin(stream)) } else { let stream = stream_csv_single( @@ -175,7 +174,7 @@ pub async fn stream_csv( io_stats, max_chunks_in_flight, ) - .await?; + .await?; Ok(Box::pin(stream)) } } @@ -229,8 +228,7 @@ async fn read_csv_single_into_table( ) -> DaftResult
{ let (source_type, _) = parse_url(uri)?; let is_compressed = CompressionCodec::from_uri(uri).is_some(); - let use_local = false; - if matches!(source_type, SourceType::File) && !is_compressed && use_local { + if matches!(source_type, SourceType::File) && !is_compressed { return read_csv_local( uri, convert_options, @@ -238,7 +236,7 @@ async fn read_csv_single_into_table( read_options, max_chunks_in_flight, ) - .await; + .await; } let predicate = convert_options @@ -277,7 +275,7 @@ async fn read_csv_single_into_table( io_client, io_stats, ) - .await?; + .await?; // Default max chunks in flight is set to 2x the number of cores, which should ensure pipelining of reading chunks // with the parsing of chunks on the rayon threadpool. let max_chunks_in_flight = max_chunks_in_flight.unwrap_or_else(|| { @@ -366,7 +364,7 @@ async fn stream_csv_single( io_client: Arc, io_stats: Option, max_chunks_in_flight: Option, -) -> DaftResult> + Send> { +) -> DaftResult> + Send> { let predicate = convert_options .as_ref() .and_then(|opts| opts.predicate.clone()); @@ -403,7 +401,7 @@ async fn stream_csv_single( io_client, io_stats, ) - .await?; + .await?; // Default max chunks in flight is set to 2x the number of cores, which should ensure pipelining of reading chunks // with the parsing of chunks on the rayon threadpool. let max_chunks_in_flight = max_chunks_in_flight.unwrap_or_else(|| { @@ -469,7 +467,7 @@ async fn read_csv_single_into_stream( io_client.clone(), io_stats.clone(), ) - .await?; + .await?; ( schema.to_arrow()?, Some(read_stats.mean_record_size_bytes), @@ -667,7 +665,7 @@ fn parse_into_column_array_chunk_stream( }); recv.await.context(super::OneShotRecvSnafu {})? }) - .context(super::JoinSnafu {}) + .context(super::JoinSnafu {}) })) } @@ -946,7 +944,7 @@ mod tests { None, None, ) - .await + .await }); assert!( @@ -1188,7 +1186,7 @@ mod tests { #[test] fn test_csv_read_local_escape() -> DaftResult<()> { - let file = format!("{}/test/iris_tiny_escape.csv", env!("CARGO_MANIFEST_DIR"),); + let file = format!("{}/test/iris_tiny_escape.csv", env!("CARGO_MANIFEST_DIR"), ); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true; @@ -1236,7 +1234,7 @@ mod tests { #[test] fn test_csv_read_local_comment() -> DaftResult<()> { - let file = format!("{}/test/iris_tiny_comment.csv", env!("CARGO_MANIFEST_DIR"),); + let file = format!("{}/test/iris_tiny_comment.csv", env!("CARGO_MANIFEST_DIR"), ); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true; @@ -1283,7 +1281,7 @@ mod tests { } #[test] fn test_csv_read_local_limit() -> DaftResult<()> { - let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"),); + let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"), ); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true; @@ -1331,7 +1329,7 @@ mod tests { #[test] fn test_csv_read_local_projection() -> DaftResult<()> { - let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"),); + let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"), ); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true; @@ -1441,7 +1439,7 @@ mod tests { #[test] fn test_csv_read_local_larger_than_buffer_size() -> DaftResult<()> { - let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"),); + let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"), ); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true; @@ -1489,7 +1487,7 @@ mod tests { #[test] fn test_csv_read_local_larger_than_chunk_size() -> DaftResult<()> { - let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"),); + let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"), ); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true; @@ -1537,7 +1535,7 @@ mod tests { #[test] fn test_csv_read_local_throttled_streaming() -> DaftResult<()> { - let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"),); + let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"), ); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true; @@ -1585,7 +1583,7 @@ mod tests { #[test] fn test_csv_read_local_nulls() -> DaftResult<()> { - let file = format!("{}/test/iris_tiny_nulls.csv", env!("CARGO_MANIFEST_DIR"),); + let file = format!("{}/test/iris_tiny_nulls.csv", env!("CARGO_MANIFEST_DIR"), ); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true; @@ -1762,7 +1760,7 @@ mod tests { #[test] fn test_csv_read_local_wrong_type_yields_nulls() -> DaftResult<()> { - let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"),); + let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"), ); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true;