From 7b40f23a5ff83aba4ab059b62ac781d7766be0b1 Mon Sep 17 00:00:00 2001 From: desmondcheongzx Date: Fri, 30 Aug 2024 14:26:31 -0700 Subject: [PATCH] Prototype --- Cargo.lock | 4 + Makefile | 4 + src/arrow2/src/io/csv/read_async/reader.rs | 25 +- src/daft-csv/Cargo.toml | 4 + src/daft-csv/src/lib.rs | 4 + src/daft-csv/src/read.rs | 591 ++++++++++++++++++++- src/daft-decoding/src/deserialize.rs | 12 +- 7 files changed, 618 insertions(+), 26 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c2472380cc..844617f3a0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1736,6 +1736,7 @@ dependencies = [ "async-stream", "common-error", "common-py-serde", + "crossbeam-channel", "csv-async", "daft-compression", "daft-core", @@ -1744,6 +1745,9 @@ dependencies = [ "daft-io", "daft-table", "futures", + "indexmap 2.3.0", + "memchr", + "memmap2", "pyo3", "rayon", "rstest", diff --git a/Makefile b/Makefile index d1d96e8df6..327f00513c 100644 --- a/Makefile +++ b/Makefile @@ -52,6 +52,10 @@ build: check-toolchain .venv ## Compile and install Daft for development build-release: check-toolchain .venv ## Compile and install a faster Daft binary @unset CONDA_PREFIX && PYO3_PYTHON=$(VENV_BIN)/python $(VENV_BIN)/maturin develop --release +.PHONY: build-bench +build-bench: check-toolchain .venv ## Compile and install a faster Daft binary + @unset CONDA_PREFIX && PYO3_PYTHON=$(VENV_BIN)/python $(VENV_BIN)/maturin develop --profile dev-bench + .PHONY: test test: .venv build ## Run tests HYPOTHESIS_MAX_EXAMPLES=$(HYPOTHESIS_MAX_EXAMPLES) $(VENV_BIN)/pytest --hypothesis-seed=$(HYPOTHESIS_SEED) diff --git a/src/arrow2/src/io/csv/read_async/reader.rs b/src/arrow2/src/io/csv/read_async/reader.rs index 0db0f25268..a95b88e5ad 100644 --- a/src/arrow2/src/io/csv/read_async/reader.rs +++ b/src/arrow2/src/io/csv/read_async/reader.rs @@ -1,10 +1,11 @@ use futures::AsyncRead; use super::{AsyncReader, ByteRecord}; +use crate::io::csv::read; use crate::error::{Error, Result}; -/// Asynchronosly read `len` rows from `reader` into `row`, skipping the first `skip`. +/// Asynchronosly read `rows.len` rows from `reader` into `rows`, skipping the first `skip`. /// This operation has minimal CPU work and is thus the fastest way to read through a CSV /// without deserializing the contents to Arrow. pub async fn read_rows( @@ -37,3 +38,25 @@ where } Ok(row_number) } + +/// Synchronously read `rows.len` rows from `reader` into `rows`. This is used in the local i/o case. 
+pub fn local_read_rows( + reader: &mut read::Reader, + rows: &mut [read::ByteRecord], +) -> Result<(usize, bool)> +where + R: std::io::Read, +{ + let mut row_number = 0; + let mut has_more = true; + for row in rows.iter_mut() { + has_more = reader + .read_byte_record(row) + .map_err(|e| Error::External(format!(" at line {}", row_number), Box::new(e)))?; + if !has_more { + break; + } + row_number += 1; + } + Ok((row_number, has_more)) +} diff --git a/src/daft-csv/Cargo.toml b/src/daft-csv/Cargo.toml index 4e58223843..d0781e7184 100644 --- a/src/daft-csv/Cargo.toml +++ b/src/daft-csv/Cargo.toml @@ -5,6 +5,7 @@ async-compression = {workspace = true} async-stream = {workspace = true} common-error = {path = "../common/error", default-features = false} common-py-serde = {path = "../common/py-serde", default-features = false} +crossbeam-channel = "0.5.1" csv-async = "1.3.0" daft-compression = {path = "../daft-compression", default-features = false} daft-core = {path = "../daft-core", default-features = false} @@ -13,6 +14,9 @@ daft-dsl = {path = "../daft-dsl", default-features = false} daft-io = {path = "../daft-io", default-features = false} daft-table = {path = "../daft-table", default-features = false} futures = {workspace = true} +indexmap = {workspace = true, features = ["serde"]} +memchr = "2.7.2" +memmap2 = "0.9.4" pyo3 = {workspace = true, optional = true} rayon = {workspace = true} serde = {workspace = true} diff --git a/src/daft-csv/src/lib.rs b/src/daft-csv/src/lib.rs index 59a62ad2d1..35c4d72006 100644 --- a/src/daft-csv/src/lib.rs +++ b/src/daft-csv/src/lib.rs @@ -2,6 +2,8 @@ #![feature(let_chains)] #![feature(trait_alias)] #![feature(trait_upcasting)] +#![feature(test)] +extern crate test; use common_error::DaftError; use snafu::Snafu; @@ -23,6 +25,8 @@ pub enum Error { #[snafu(display("{source}"))] IOError { source: daft_io::Error }, #[snafu(display("{source}"))] + StdIOError { source: std::io::Error }, + #[snafu(display("{source}"))] CSVError { source: csv_async::Error }, #[snafu(display("Invalid char: {}", val))] WrongChar { diff --git a/src/daft-csv/src/read.rs b/src/daft-csv/src/read.rs index 31ab4f84f0..3ed49204b3 100644 --- a/src/daft-csv/src/read.rs +++ b/src/daft-csv/src/read.rs @@ -1,15 +1,20 @@ -use std::{collections::HashMap, num::NonZeroUsize, sync::Arc}; +use std::{collections::HashMap, num::NonZeroUsize, sync::Arc, sync::Mutex}; + +use std::sync::atomic::{AtomicUsize, Ordering}; use arrow2::{ datatypes::Field, - io::csv::read_async::{read_rows, AsyncReaderBuilder, ByteRecord}, + io::csv::read, + io::csv::read::{Reader, ReaderBuilder}, + io::csv::read_async, + io::csv::read_async::{local_read_rows, read_rows, AsyncReaderBuilder}, }; use async_compat::{Compat, CompatExt}; use common_error::{DaftError, DaftResult}; use csv_async::AsyncReader; use daft_core::{schema::Schema, utils::arrow::cast_array_for_daft_if_needed, Series}; -use daft_dsl::optimization::get_required_columns; -use daft_io::{get_runtime, GetResult, IOClient, IOStatsRef}; +use daft_dsl::{optimization::get_required_columns, Expr}; +use daft_io::{get_runtime, parse_url, GetResult, IOClient, IOStatsRef, SourceType}; use daft_table::Table; use futures::{stream::BoxStream, Stream, StreamExt, TryStreamExt}; use rayon::{ @@ -27,13 +32,16 @@ use tokio::{ }; use tokio_util::io::StreamReader; -use crate::ArrowSnafu; use crate::{metadata::read_csv_schema_single, CsvConvertOptions, CsvParseOptions, CsvReadOptions}; +use crate::{ArrowSnafu, StdIOSnafu}; use daft_compression::CompressionCodec; use 
daft_decoding::deserialize::deserialize_column; -trait ByteRecordChunkStream: Stream>> {} -impl ByteRecordChunkStream for S where S: Stream>> {} +trait ByteRecordChunkStream: Stream>> {} +impl ByteRecordChunkStream for S where + S: Stream>> +{ +} type TableChunkResult = super::Result>, super::JoinSnafu, super::Error>>; @@ -145,22 +153,35 @@ pub async fn stream_csv( io_stats: Option, max_chunks_in_flight: Option, ) -> DaftResult>> { - let stream = stream_csv_single( - &uri, - convert_options, - parse_options, - read_options, - io_client, - io_stats, - max_chunks_in_flight, - ) - .await?; - - Ok(Box::pin(stream)) + let uri = uri.as_str(); + let (source_type, _) = parse_url(uri)?; + let is_compressed = CompressionCodec::from_uri(uri).is_some(); + let use_local_reader = false; // TODO(desmond): Feature under dev. + if matches!(source_type, SourceType::File) && !is_compressed && use_local_reader { + let stream = stream_csv_local( + uri, + convert_options, + parse_options.unwrap_or_default(), + read_options, + max_chunks_in_flight, + ) + .await?; + Ok(Box::pin(stream)) + } else { + let stream = stream_csv_single( + uri, + convert_options, + parse_options, + read_options, + io_client, + io_stats, + max_chunks_in_flight, + ) + .await?; + Ok(Box::pin(stream)) + } } -// Parallel version of table concat -// get rid of this once Table APIs are parallel fn tables_concat(mut tables: Vec) -> DaftResult
<Table> {
     if tables.is_empty() {
         return Err(DaftError::ValueError(
@@ -199,6 +220,387 @@ fn tables_concat(mut tables: Vec<Table>) -> DaftResult<Table>
{ ) } +#[derive(Debug)] +struct CsvBufferPool { + buffers: Mutex>>, + buffer_size: usize, + record_buffer_size: usize, + num_fields: usize, +} + +struct CsvBufferPoolRef<'a> { + pool: &'a CsvBufferPool, + buffer: Vec, +} + +impl CsvBufferPool { + pub fn new( + record_buffer_size: usize, + num_fields: usize, + chunk_size_rows: usize, + initial_pool_size: usize, + ) -> Self { + let chunk_buffers = vec![ + vec![ + read::ByteRecord::with_capacity(record_buffer_size, num_fields); + chunk_size_rows + ]; + initial_pool_size + ]; + CsvBufferPool { + buffers: Mutex::new(chunk_buffers), + buffer_size: chunk_size_rows, + record_buffer_size, + num_fields, + } + } + + pub fn get_buffer(&self) -> CsvBufferPoolRef { + let mut buffers = self.buffers.lock().unwrap(); + let buffer = buffers.pop(); + let buffer = match buffer { + Some(buffer) => buffer, + None => { + vec![ + read::ByteRecord::with_capacity(self.record_buffer_size, self.num_fields); + self.buffer_size + ] + } + }; + + CsvBufferPoolRef { pool: self, buffer } + } + + fn return_buffer(&self, buffer: Vec) { + let mut buffers = self.buffers.lock().unwrap(); + buffers.push(buffer); + } +} + +// Daft does not currently support non-\n record terminators (e.g. carriage return \r, which only +// matters for pre-Mac OS X). +const NEWLINE: u8 = b'\n'; +const DEFAULT_CHUNK_SIZE: usize = 4 * 1024 * 1024; // 1MiB. TODO(desmond): This should be tuned. + +/// Helper function that finds the first new line character (\n) in the given byte slice. +fn next_line_position(input: &[u8]) -> Option { + // Assuming we are searching for the ASCII `\n` character, we don't need to do any special + // handling for UTF-8, since a `\n` value always corresponds to an ASCII `\n`. + // For more details, see: https://en.wikipedia.org/wiki/UTF-8#Encoding + memchr::memchr(NEWLINE, input) +} + +/// Helper function that determines what chunk of data to parse given a starting position within the +/// file, and the desired initial chunk size. +/// +/// Given a starting position, we use our chunk size to compute a preliminary start and stop +/// position. For example, we can visualize all preliminary chunks in a file as follows. +/// +/// Chunk 1 Chunk 2 Chunk 3 Chunk N +/// ┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐ +/// │ │ │\n │ │ \n │ │ \n │ +/// │ │ │ │ │ │ │ │ +/// │ │ │ \n │ │ │ │ │ +/// │ \n │ │ │ │ \n │ │ │ +/// │ │ │ │ │ │ ... │ \n │ +/// │ │ │ \n │ │ │ │ │ +/// │ \n │ │ │ │ │ │ │ +/// │ │ │ │ │ \n │ │ \n │ +/// └──────────┘ └──────────┘ └──────────┘ └──────────┘ +/// +/// However, record boundaries (i.e. the \n terminators) do not align nicely with these preliminary +/// chunk boundaries. So we adjust each preliminary chunk as follows: +/// - Find the first record terminator from the chunk's start. This is the new starting position. +/// - Find the first record terminator from the chunk's end. This is the new ending position. +/// - If a given preliminary chunk doesn't contain a record terminator, the adjusted chunk is empty. +/// +/// For example: +/// +/// Adjusted Chunk 1 Adj. Chunk 2 Adj. Chunk 3 Adj. Chunk N +/// ┌──────────────────┐┌─────────────────┐ ┌────────┐ ┌─┐ +/// │ \n││ \n│ │ \n│ \n │ │ +/// │ ┌───────┘│ ┌──────────┘ │ ┌─────┘ │ │ +/// │ │ ┌───┘ \n │ ┌───────┘ │ ┌────────┘ │ +/// │ \n │ │ │ │ \n │ │ │ +/// │ │ │ │ │ │ ... 
│ \n │ +/// │ │ │ \n │ │ │ │ │ +/// │ \n │ │ │ │ │ │ │ +/// │ │ │ │ │ \n │ │ \n │ +/// └──────────┘ └──────────┘ └──────────┘ └──────────┘ +/// +/// Using this method, we now have adjusted chunks that are aligned with record boundaries, that do +/// not overlap, and that fully cover every byte in the CSV file. Parsing each adjusted chunk can +/// now happen in parallel. +/// +/// This is the same method as described in: +/// Ge, Chang et al. “Speculative Distributed CSV Data Parsing for Big Data Analytics.” Proceedings of the 2019 International Conference on Management of Data (2019). +fn get_file_chunk(bytes: &[u8], start: usize, chunk_size: usize) -> Option<(usize, usize)> { + let stop = start + chunk_size; + let start = if start == 0 { + 0 + } else { + match next_line_position(&bytes[start..]) { + // Start reading after the first record terminator from the start of the chunk. + Some(pos) => start + pos + 1, + // If there's no record terminator found, then the previous chunk reader would have + // consumed the current chunk. Hence, we skip. + None => return None, + } + }; + // If the first record terminator comes after this chunk, then the previous chunk reader would + // have consumed the current chunk. Hence, we skip. + if start > stop { + return None; + } + let stop = if stop < bytes.len() { + match next_line_position(&bytes[stop..]) { + // Read up to the first terminator from the end of the chunk. + Some(pos) => stop + pos, + None => bytes.len(), + } + } else { + bytes.len() + }; + Some((start, stop)) +} + +#[allow(clippy::too_many_arguments)] +fn parse_csv_chunk( + mut reader: Reader, + projection_indices: Arc>, + fields: Vec, + read_daft_fields: Arc>>, + read_schema: Arc, + buf: CsvBufferPoolRef, + include_columns: &Option>, + predicate: Option>, +) -> DaftResult> +where + R: std::io::Read, +{ + let mut chunk_buffer = buf.buffer; + let mut tables = vec![]; + loop { + let (rows_read, has_more) = + local_read_rows(&mut reader, chunk_buffer.as_mut_slice()).context(ArrowSnafu {})?; + let chunk = projection_indices + .par_iter() + .enumerate() + .map(|(i, proj_idx)| { + let deserialized_col = deserialize_column( + &chunk_buffer[0..rows_read], + *proj_idx, + fields[*proj_idx].data_type().clone(), + 0, + ); + Series::try_from_field_and_arrow_array( + read_daft_fields[i].clone(), + cast_array_for_daft_if_needed(deserialized_col?), + ) + }) + .collect::>>()?; + let num_rows = chunk.first().map(|s| s.len()).unwrap_or(0); + let table = Table::new_unchecked(read_schema.clone(), chunk, num_rows); + let table = if let Some(predicate) = &predicate { + let filtered = table.filter(&[predicate.clone()])?; + if let Some(include_columns) = &include_columns { + filtered.get_columns(include_columns.as_slice())? + } else { + filtered + } + } else { + table + }; + tables.push(table); + + // The number of record might exceed the number of byte records we've allocated. + // Retry until all byte records in this chunk are read. + if !has_more { + break; + } + } + buf.pool.return_buffer(chunk_buffer); + Ok(tables) +} + +async fn stream_csv_local( + uri: &str, + convert_options: Option, + parse_options: CsvParseOptions, + read_options: Option, + max_chunks_in_flight: Option, +) -> DaftResult> + Send> { + let uri = uri.trim_start_matches("file://"); + let file = std::fs::File::open(uri)?; + let mmap = unsafe { memmap2::Mmap::map(&file) }.context(StdIOSnafu)?; + let bytes = &mmap[..]; + + // TODO(desmond): This logic is repeated multiple times in this file. Should dedup. 
+ let predicate = convert_options + .as_ref() + .and_then(|opts| opts.predicate.clone()); + + let limit = convert_options.as_ref().and_then(|opts| opts.limit); + + let include_columns = convert_options + .as_ref() + .and_then(|opts| opts.include_columns.clone()); + + let convert_options = match (convert_options, &predicate) { + (None, _) => None, + (co, None) => co, + (Some(mut co), Some(predicate)) => { + if let Some(ref mut include_columns) = co.include_columns { + let required_columns_for_predicate = get_required_columns(predicate); + for rc in required_columns_for_predicate { + if include_columns.iter().all(|c| c.as_str() != rc.as_str()) { + include_columns.push(rc) + } + } + } + // If we have a limit and a predicate, remove limit for stream. + co.limit = None; + Some(co) + } + } + .unwrap_or_default(); + // End of `should dedup`. + + // TODO(desmond): We should do better schema inference here. + let schema = convert_options.clone().schema.unwrap().to_arrow()?; + let n_threads: usize = std::thread::available_parallelism() + .unwrap_or(NonZeroUsize::new(2).unwrap()) + .into(); + let chunk_size = read_options + .as_ref() + .and_then(|opt| opt.chunk_size.or_else(|| opt.buffer_size.map(|bs| bs / 8))) + .unwrap_or(DEFAULT_CHUNK_SIZE); + let projection_indices = fields_to_projection_indices( + &schema.clone().fields, + &convert_options.clone().include_columns, + ); + let fields = schema.clone().fields; + let fields_subset = projection_indices + .iter() + .map(|i| fields.get(*i).unwrap().into()) + .collect::>(); + let read_schema = Arc::new(daft_core::schema::Schema::new(fields_subset)?); + let read_daft_fields = Arc::new( + read_schema + .fields + .values() + .map(|f| Arc::new(f.clone())) + .collect::>(), + ); + // TODO(desmond): Need better upfront estimators. Sample or keep running count of stats. + let estimated_mean_row_size = 100f64; + let estimated_std_row_size = 20f64; + let record_buffer_size = (estimated_mean_row_size + estimated_std_row_size).ceil() as usize; + let chunk_size_rows = (chunk_size as f64 / record_buffer_size as f64).ceil() as usize; + // TODO(desmond): We don't want to create a per-read buffer pool, we want one pool shared with + // the whole process. + let buffer_pool = CsvBufferPool::new( + record_buffer_size, + schema.fields.len(), + chunk_size_rows, + n_threads * 2, + ); + let chunk_offsets: Vec = (0..=bytes.len()).step_by(chunk_size).collect(); + // TODO(desmond): Memory usage is still growing during execution of a .count(*).collect(), so + // the following approach still isn't quite right. + // TODO(desmond): Also, is this usage of max_chunks_in_flight correct? + let (sender, receiver) = + crossbeam_channel::bounded(max_chunks_in_flight.unwrap_or(n_threads * 2)); + let rows_read = AtomicUsize::new(0); + rayon::spawn(move || { + let bytes = &mmap[..]; + chunk_offsets.into_par_iter().for_each(|start| { + // TODO(desmond): Use try_for_each and terminate early once the limit is reached. + let limit_reached = limit.map_or(false, |limit| { + let current_rows_read = rows_read.load(Ordering::Relaxed); + current_rows_read >= limit + }); + + if !limit_reached && let Some((start, stop)) = get_file_chunk(bytes, start, chunk_size) + { + let buf = buffer_pool.get_buffer(); + let chunk = &bytes[start..stop]; + // Only the first chunk might potentially have headers. Subsequent chunks should + // read all rows as records. 
+ let has_headers = start == 0 && parse_options.has_header; + let rdr = ReaderBuilder::new() + .has_headers(has_headers) + .delimiter(parse_options.delimiter) + .double_quote(parse_options.double_quote) + // TODO(desmond): We need to handle the quoted case properly. + .quote(parse_options.quote) + .escape(parse_options.escape_char) + .comment(parse_options.comment) + .flexible(parse_options.allow_variable_columns) + .from_reader(chunk); + Reader::from_reader(chunk); + let table_results = parse_csv_chunk( + rdr, + projection_indices.clone(), + fields.clone(), + read_daft_fields.clone(), + read_schema.clone(), + buf, + &include_columns, + predicate.clone(), + ); + match table_results { + Ok(tables) => { + for table in tables { + let table_len = table.len(); + sender.send(Ok(table)).unwrap(); + // Atomically update the number of rows read only after the result has + // been sent. In theory we could wrap these steps in a mutex, but + // applying limit at this layer can be best-effort with no adverse + // side effects. + rows_read.fetch_add(table_len, Ordering::Relaxed); + } + } + Err(e) => sender.send(Err(e)).unwrap(), + } + } + }) + }); + + let result_stream = futures::stream::iter(receiver); + Ok(result_stream) +} + +async fn tables_stream_collect(stream: BoxStream<'static, DaftResult
<Table>>) -> Vec<Table>
{
+    stream
+        .filter_map(|result| async {
+            match result {
+                Ok(table) => Some(table),
+                Err(_) => None, // Skips errors; you could log them or handle differently
+            }
+        })
+        .collect()
+        .await
+}
+
+async fn read_csv_local(
+    uri: &str,
+    convert_options: Option<CsvConvertOptions>,
+    parse_options: CsvParseOptions,
+    read_options: Option<CsvReadOptions>,
+    max_chunks_in_flight: Option<usize>,
+) -> DaftResult<Table>
{ + let stream = stream_csv_local( + uri, + convert_options, + parse_options, + read_options, + max_chunks_in_flight, + ) + .await?; + tables_concat(tables_stream_collect(Box::pin(stream)).await) +} + async fn read_csv_single_into_table( uri: &str, convert_options: Option, @@ -208,6 +610,20 @@ async fn read_csv_single_into_table( io_stats: Option, max_chunks_in_flight: Option, ) -> DaftResult
{ + let (source_type, _) = parse_url(uri)?; + let is_compressed = CompressionCodec::from_uri(uri).is_some(); + let use_local_reader = false; // TODO(desmond): Feature under dev. + if matches!(source_type, SourceType::File) && !is_compressed && use_local_reader { + return read_csv_local( + uri, + convert_options, + parse_options.unwrap_or_default(), + read_options, + max_chunks_in_flight, + ) + .await; + } + let predicate = convert_options .as_ref() .and_then(|opts| opts.predicate.clone()); @@ -557,7 +973,7 @@ where estimated_rows_per_desired_chunk.max(8).min(num_rows - total_rows_read) }; let mut chunk_buffer = vec![ - ByteRecord::with_capacity(record_buffer_size, num_fields); + read_async::ByteRecord::with_capacity(record_buffer_size, num_fields); chunk_size_rows ]; @@ -576,7 +992,7 @@ where chunk_buffer.truncate(rows_read); if rows_read > 0 { - yield chunk_buffer + yield chunk_buffer; } } } @@ -811,6 +1227,135 @@ mod tests { Ok(()) } + use crate::read::read_csv_local; + use daft_io::get_runtime; + #[test] + fn test_csv_read_experimental() -> DaftResult<()> { + let file = "file:///Users/desmond/tasks/csv-reader/G1_1e8_1e1_5_0.csv"; + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + + let io_client = Arc::new(IOClient::new(io_config.into())?); + + use indexmap::IndexMap; + let mut fields = IndexMap::new(); + fields.insert( + "id1".to_string(), + Field { + name: "id1".to_string(), + dtype: daft_core::datatypes::DataType::Utf8, + metadata: Arc::default(), + }, + ); + fields.insert( + "id2".to_string(), + Field { + name: "id2".to_string(), + dtype: daft_core::datatypes::DataType::Utf8, + metadata: Arc::default(), + }, + ); + fields.insert( + "id3".to_string(), + Field { + name: "id3".to_string(), + dtype: daft_core::datatypes::DataType::Utf8, + metadata: Arc::default(), + }, + ); + fields.insert( + "id4".to_string(), + Field { + name: "id4".to_string(), + dtype: daft_core::datatypes::DataType::Int64, + metadata: Arc::default(), + }, + ); + fields.insert( + "id5".to_string(), + Field { + name: "id5".to_string(), + dtype: daft_core::datatypes::DataType::Int64, + metadata: Arc::default(), + }, + ); + fields.insert( + "id6".to_string(), + Field { + name: "id6".to_string(), + dtype: daft_core::datatypes::DataType::Int64, + metadata: Arc::default(), + }, + ); + fields.insert( + "v1".to_string(), + Field { + name: "v1".to_string(), + dtype: daft_core::datatypes::DataType::Int64, + metadata: Arc::default(), + }, + ); + fields.insert( + "v2".to_string(), + Field { + name: "v2".to_string(), + dtype: daft_core::datatypes::DataType::Int64, + metadata: Arc::default(), + }, + ); + fields.insert( + "v3".to_string(), + Field { + name: "v3".to_string(), + dtype: daft_core::datatypes::DataType::Float64, + metadata: Arc::default(), + }, + ); + + let runtime_handle = get_runtime(true)?; + let _rt_guard = runtime_handle.enter(); + let result = runtime_handle.block_on(async { + read_csv_local( + file.as_ref(), + Some(CsvConvertOptions { + limit: None, + include_columns: None, + column_names: None, + schema: Some(Arc::new(Schema { fields: fields })), + predicate: None, + }), + CsvParseOptions::default().with_delimiter(b','), + None, + None, + ) + .await + }); + + assert!( + result.is_ok(), + "Got Err: {:?} when using the experimental local csv reader", + result + ); + + let column_names = vec!["id1", "id2", "id3", "id4", "id5", "id6", "v1", "v2", "v3"]; + check_equal_local_arrow2( + file, + &result.unwrap(), + true, + None, + true, + None, + None, + None, + Some(column_names), + None, + 
None, + ); + + Ok(()) + } + #[test] fn test_csv_read_local_no_headers() -> DaftResult<()> { let file = format!( diff --git a/src/daft-decoding/src/deserialize.rs b/src/daft-decoding/src/deserialize.rs index dd3c56292e..ffb4a9aa6d 100644 --- a/src/daft-decoding/src/deserialize.rs +++ b/src/daft-decoding/src/deserialize.rs @@ -2,12 +2,13 @@ use arrow2::{ array::*, datatypes::*, error::{Error, Result}, + io::csv, offset::Offset, temporal_conversions, types::NativeType, }; use chrono::{Datelike, Timelike}; -use csv_async::ByteRecord; +use csv_async; pub(crate) const ISO8601: &str = "%+"; pub(crate) const ISO8601_NO_TIME_ZONE: &str = "%Y-%m-%dT%H:%M:%S%.f"; @@ -35,7 +36,14 @@ pub trait ByteRecordGeneric { fn get(&self, index: usize) -> Option<&[u8]>; } -impl ByteRecordGeneric for ByteRecord { +impl ByteRecordGeneric for csv_async::ByteRecord { + #[inline] + fn get(&self, index: usize) -> Option<&[u8]> { + self.get(index) + } +} + +impl ByteRecordGeneric for csv::read::ByteRecord { #[inline] fn get(&self, index: usize) -> Option<&[u8]> { self.get(index)
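
For reference, the chunk-boundary alignment that the `get_file_chunk` doc comment describes can be exercised on its own. Below is a minimal standalone sketch (not part of the patch) of the same start/stop adjustment, with `memchr` replaced by a plain byte search and a toy input chosen purely for illustration:

// Minimal sketch of the chunk-boundary alignment used by `get_file_chunk`:
// a worker starts after the first '\n' at or past its nominal offset and stops
// at the first '\n' at or past its nominal end, so no record is split across
// two adjusted chunks. `memchr` is replaced by a plain byte search here.
fn next_line_position(input: &[u8]) -> Option<usize> {
    input.iter().position(|&b| b == b'\n')
}

fn get_file_chunk(bytes: &[u8], start: usize, chunk_size: usize) -> Option<(usize, usize)> {
    let stop = start + chunk_size;
    let start = if start == 0 {
        0
    } else {
        // Start after the first record terminator from the nominal start.
        match next_line_position(&bytes[start..]) {
            Some(pos) => start + pos + 1,
            None => return None, // The previous chunk's reader consumed this chunk.
        }
    };
    if start > stop {
        return None;
    }
    let stop = if stop < bytes.len() {
        // Stop at the first record terminator from the nominal end.
        match next_line_position(&bytes[stop..]) {
            Some(pos) => stop + pos,
            None => bytes.len(),
        }
    } else {
        bytes.len()
    };
    Some((start, stop))
}

fn main() {
    let csv: &[u8] = b"id,val\n1,a\n2,bb\n3,ccc\n4,dddd\n";
    let chunk_size = 10;
    for nominal_start in (0..csv.len()).step_by(chunk_size) {
        if let Some((start, stop)) = get_file_chunk(csv, nominal_start, chunk_size) {
            // Each adjusted chunk holds only whole records.
            println!("{:?}", std::str::from_utf8(&csv[start..stop]).unwrap());
        }
    }
}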