diff --git a/Cargo.lock b/Cargo.lock index 0efd60a689..60d5e4639b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -105,7 +105,7 @@ dependencies = [ [[package]] name = "arrow2" version = "0.17.1" -source = "git+https://github.com/Eventual-Inc/arrow2?rev=065a31da8fd8a75cbece5f99295a4068713a71ed#065a31da8fd8a75cbece5f99295a4068713a71ed" +source = "git+https://github.com/Eventual-Inc/arrow2?rev=0a6f79e0da7e32cc30381f4cc8cf5a8483909f78#0a6f79e0da7e32cc30381f4cc8cf5a8483909f78" dependencies = [ "ahash", "arrow-format", @@ -159,6 +159,25 @@ dependencies = [ "tokio", ] +[[package]] +name = "async-compression" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f658e2baef915ba0f26f1f7c42bfb8e12f532a01f449a090ded75ae7a07e9ba2" +dependencies = [ + "brotli", + "bzip2", + "deflate64", + "flate2", + "futures-core", + "memchr", + "pin-project-lite", + "tokio", + "xz2", + "zstd 0.13.0", + "zstd-safe 7.0.0", +] + [[package]] name = "async-recursion" version = "1.0.5" @@ -791,6 +810,27 @@ dependencies = [ "either", ] +[[package]] +name = "bzip2" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" +dependencies = [ + "bzip2-sys", + "libc", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.11+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "cc" version = "1.0.83" @@ -1069,6 +1109,7 @@ version = "0.1.10" dependencies = [ "arrow2", "async-compat", + "async-compression", "async-stream", "bytes", "common-error", @@ -1081,10 +1122,12 @@ dependencies = [ "pyo3", "pyo3-log", "rayon", + "rstest", "snafu", "tokio", "tokio-stream", "tokio-util", + "url", ] [[package]] @@ -1225,6 +1268,12 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "deflate64" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61ceff48ed7e0e66d428a569d36485a091c39fe118ee1220217655f6b814fa9" + [[package]] name = "der" version = "0.5.1" @@ -1500,6 +1549,12 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" +[[package]] +name = "futures-timer" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e64b03909df88034c26dc1547e8970b91f98bdb65165d6a4e9110d94263dbb2c" + [[package]] name = "futures-util" version = "0.3.28" @@ -1568,6 +1623,12 @@ version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "globset" version = "0.4.13" @@ -2127,6 +2188,17 @@ dependencies = [ "libc", ] +[[package]] +name = "lzma-sys" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5fda04ab3764e6cde78b9974eec4f779acaba7c4e84b36eca3cf77c581b85d27" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + [[package]] name = "matrixmultiply" version = "0.3.8" @@ -2511,7 +2583,7 @@ dependencies = [ "snap", "streaming-decompression", "xxhash-rust", - "zstd", + "zstd 0.12.4", ] 
[[package]] @@ -2948,6 +3020,12 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3cbb081b9784b07cceb8824c8583f86db4814d172ab043f3c23f7dc600bf83d" +[[package]] +name = "relative-path" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c707298afce11da2efef2f600116fa93ffa7a032b5d7b628aa17711ec81383ca" + [[package]] name = "reqwest" version = "0.11.22" @@ -3024,6 +3102,35 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rstest" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97eeab2f3c0a199bc4be135c36c924b6590b88c377d416494288c14f2db30199" +dependencies = [ + "futures", + "futures-timer", + "rstest_macros", + "rustc_version", +] + +[[package]] +name = "rstest_macros" +version = "0.18.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d428f8247852f894ee1be110b375111b586d4fa431f6c46e64ba5a0dcccbe605" +dependencies = [ + "cfg-if", + "glob", + "proc-macro2", + "quote", + "regex", + "relative-path", + "rustc_version", + "syn 2.0.38", + "unicode-ident", +] + [[package]] name = "rustc-demangle" version = "0.1.23" @@ -3257,6 +3364,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e4de37ad025c587a29e8f3f5605c00f70b98715ef90b9061a815b9e59e9042d6" dependencies = [ "doc-comment", + "futures-core", + "pin-project", "snafu-derive", ] @@ -4000,6 +4109,15 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9828b178da53440fa9c766a3d2f73f7cf5d0ac1fe3980c1e5018d899fd19e07b" +[[package]] +name = "xz2" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "388c44dc09d76f1536602ead6d325eb532f5c122f17782bd57fb47baeeb767e2" +dependencies = [ + "lzma-sys", +] + [[package]] name = "zeroize" version = "1.6.0" @@ -4012,7 +4130,16 @@ version = "0.12.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a27595e173641171fc74a1232b7b1c7a7cb6e18222c11e9dfb9888fa424c53c" dependencies = [ - "zstd-safe", + "zstd-safe 6.0.6", +] + +[[package]] +name = "zstd" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bffb3309596d527cfcba7dfc6ed6052f1d39dfbd7c867aa2e865e4a449c10110" +dependencies = [ + "zstd-safe 7.0.0", ] [[package]] @@ -4025,6 +4152,15 @@ dependencies = [ "zstd-sys", ] +[[package]] +name = "zstd-safe" +version = "7.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43747c7422e2924c11144d5229878b98180ef8b06cca4ab5af37afc8a8d8ea3e" +dependencies = [ + "zstd-sys", +] + [[package]] name = "zstd-sys" version = "2.0.8+zstd.1.5.5" diff --git a/Cargo.toml b/Cargo.toml index fff540922d..2a0240bf4c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -76,6 +76,7 @@ members = [ [workspace.dependencies] async-compat = "0.2.1" +async-compression = {version = "0.4.4", features = ["tokio", "all-algorithms"]} async-stream = "0.3.5" bytes = "1.4.0" futures = "0.3.28" @@ -88,15 +89,16 @@ prettytable-rs = "0.10" rand = "^0.8" rayon = "1.7.0" serde_json = "1.0.104" -snafu = "0.7.4" +snafu = {version = "0.7.4", features = ["futures"]} tokio = {version = "1.32.0", features = ["net", "time", "bytes", "process", "signal", "macros", "rt", "rt-multi-thread"]} tokio-stream = {version = "0.1.14", features = ["fs"]} tokio-util = "0.7.8" +url = "2.4.0" [workspace.dependencies.arrow2] git = "https://github.com/Eventual-Inc/arrow2" package = "arrow2" -rev 
= "065a31da8fd8a75cbece5f99295a4068713a71ed" +rev = "0a6f79e0da7e32cc30381f4cc8cf5a8483909f78" [workspace.dependencies.bincode] version = "1.3.3" diff --git a/daft/daft.pyi b/daft/daft.pyi index dffd96aa44..e0a3b644a3 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -190,8 +190,16 @@ class CsvSourceConfig: delimiter: str has_headers: bool + buffer_size: int | None + chunk_size: int | None - def __init__(self, delimiter: str, has_headers: bool): ... + def __init__( + self, + delimiter: str, + has_headers: bool, + buffer_size: int | None = None, + chunk_size: int | None = None, + ): ... class JsonSourceConfig: """ @@ -425,6 +433,9 @@ def read_csv( delimiter: str | None = None, io_config: IOConfig | None = None, multithreaded_io: bool | None = None, + schema: PySchema | None = None, + buffer_size: int | None = None, + chunk_size: int | None = None, ): ... def read_csv_schema( uri: str, diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index c98404d122..e24bbf2899 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -373,6 +373,8 @@ def _handle_tabular_files_scan( csv_options=TableParseCSVOptions( delimiter=format_config.delimiter, header_index=0 if format_config.has_headers else None, + buffer_size=format_config.buffer_size, + chunk_size=format_config.chunk_size, ), read_options=read_options, ) diff --git a/daft/io/_csv.py b/daft/io/_csv.py index 0ab885d418..fac80c76f0 100644 --- a/daft/io/_csv.py +++ b/daft/io/_csv.py @@ -28,6 +28,8 @@ def read_csv( delimiter: str = ",", io_config: Optional["IOConfig"] = None, use_native_downloader: bool = False, + _buffer_size: Optional[int] = None, + _chunk_size: Optional[int] = None, ) -> DataFrame: """Creates a DataFrame from CSV file(s) @@ -62,7 +64,12 @@ def read_csv( if isinstance(path, list) and len(path) == 0: raise ValueError(f"Cannot read DataFrame from from empty list of CSV filepaths") - csv_config = CsvSourceConfig(delimiter=delimiter, has_headers=has_headers) + csv_config = CsvSourceConfig( + delimiter=delimiter, + has_headers=has_headers, + buffer_size=_buffer_size, + chunk_size=_chunk_size, + ) file_format_config = FileFormatConfig.from_csv_config(csv_config) if use_native_downloader: storage_config = StorageConfig.native(NativeStorageConfig(io_config)) diff --git a/daft/runners/partitioning.py b/daft/runners/partitioning.py index 5d02a341a0..4a496ab0e1 100644 --- a/daft/runners/partitioning.py +++ b/daft/runners/partitioning.py @@ -44,10 +44,14 @@ class TableParseCSVOptions: Args: delimiter: The delimiter to use when parsing CSVs, defaults to "," header_index: Index of the header row, or None if no header + buffer_size: Size of the buffer (in bytes) used by the streaming reader. + chunk_size: Size of the chunks (in bytes) deserialized in parallel by the streaming reader. 
""" delimiter: str = "," header_index: int | None = 0 + buffer_size: int | None = None + chunk_size: int | None = None @dataclass(frozen=True) diff --git a/daft/table/table.py b/daft/table/table.py index ec865bd05d..22bda0d675 100644 --- a/daft/table/table.py +++ b/daft/table/table.py @@ -446,6 +446,9 @@ def read_csv( delimiter: str | None = None, io_config: IOConfig | None = None, multithreaded_io: bool | None = None, + schema: Schema | None = None, + buffer_size: int | None = None, + chunk_size: int | None = None, ) -> Table: return Table._from_pytable( _read_csv( @@ -457,6 +460,9 @@ def read_csv( delimiter=delimiter, io_config=io_config, multithreaded_io=multithreaded_io, + schema=schema._schema if schema is not None else None, + buffer_size=buffer_size, + chunk_size=chunk_size, ) ) diff --git a/daft/table/table_io.py b/daft/table/table_io.py index b12ffb9feb..73aff2f66a 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -223,6 +223,9 @@ def read_csv( has_header=has_header, delimiter=csv_options.delimiter, io_config=config.io_config, + schema=schema, + buffer_size=csv_options.buffer_size, + chunk_size=csv_options.chunk_size, ) return _cast_table_to_schema(tbl, read_options=read_options, schema=schema) diff --git a/src/daft-core/src/array/fixed_size_list_array.rs b/src/daft-core/src/array/fixed_size_list_array.rs index ff48317aee..0dbb587a90 100644 --- a/src/daft-core/src/array/fixed_size_list_array.rs +++ b/src/daft-core/src/array/fixed_size_list_array.rs @@ -7,7 +7,7 @@ use crate::datatypes::{DaftArrayType, Field}; use crate::series::Series; use crate::DataType; -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct FixedSizeListArray { pub field: Arc, pub flat_child: Series, diff --git a/src/daft-core/src/array/list_array.rs b/src/daft-core/src/array/list_array.rs index 1f8775801d..62caad10a5 100644 --- a/src/daft-core/src/array/list_array.rs +++ b/src/daft-core/src/array/list_array.rs @@ -7,7 +7,7 @@ use crate::datatypes::{DaftArrayType, Field}; use crate::series::Series; use crate::DataType; -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct ListArray { pub field: Arc, pub flat_child: Series, diff --git a/src/daft-core/src/array/pseudo_arrow/mod.rs b/src/daft-core/src/array/pseudo_arrow/mod.rs index d80981c2f3..80a60d83d8 100644 --- a/src/daft-core/src/array/pseudo_arrow/mod.rs +++ b/src/daft-core/src/array/pseudo_arrow/mod.rs @@ -214,7 +214,7 @@ pub mod compute; #[cfg(feature = "python")] pub mod python; -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct PseudoArrowArray { values: Buffer, validity: Option, diff --git a/src/daft-core/src/array/struct_array.rs b/src/daft-core/src/array/struct_array.rs index b2abb5d04c..a80759fc58 100644 --- a/src/daft-core/src/array/struct_array.rs +++ b/src/daft-core/src/array/struct_array.rs @@ -7,7 +7,7 @@ use crate::datatypes::{DaftArrayType, Field}; use crate::series::Series; use crate::DataType; -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct StructArray { pub field: Arc, pub children: Vec, diff --git a/src/daft-core/src/datatypes/logical.rs b/src/daft-core/src/datatypes/logical.rs index e5cd627e41..75c9d80bb2 100644 --- a/src/daft-core/src/datatypes/logical.rs +++ b/src/daft-core/src/datatypes/logical.rs @@ -15,7 +15,7 @@ use super::{ /// A LogicalArray is a wrapper on top of some underlying array, applying the semantic meaning of its /// field.datatype() to the underlying array. 
-#[derive(Clone)] +#[derive(Clone, Debug)] pub struct LogicalArrayImpl { pub field: Arc, pub physical: PhysicalArray, diff --git a/src/daft-core/src/datatypes/mod.rs b/src/daft-core/src/datatypes/mod.rs index c3c88c4f2d..305877b860 100644 --- a/src/daft-core/src/datatypes/mod.rs +++ b/src/daft-core/src/datatypes/mod.rs @@ -57,7 +57,7 @@ pub trait DaftLogicalType: Send + Sync + DaftDataType + 'static { macro_rules! impl_daft_arrow_datatype { ($ca:ident, $variant:ident) => { - #[derive(Clone)] + #[derive(Clone, Debug)] pub struct $ca {} impl DaftDataType for $ca { @@ -76,7 +76,7 @@ macro_rules! impl_daft_arrow_datatype { macro_rules! impl_daft_non_arrow_datatype { ($ca:ident, $variant:ident) => { - #[derive(Clone)] + #[derive(Clone, Debug)] pub struct $ca {} impl DaftDataType for $ca { @@ -93,7 +93,7 @@ macro_rules! impl_daft_non_arrow_datatype { macro_rules! impl_daft_logical_data_array_datatype { ($ca:ident, $variant:ident, $physical_type:ident) => { - #[derive(Clone)] + #[derive(Clone, Debug)] pub struct $ca {} impl DaftDataType for $ca { @@ -113,7 +113,7 @@ macro_rules! impl_daft_logical_data_array_datatype { macro_rules! impl_daft_logical_fixed_size_list_datatype { ($ca:ident, $variant:ident) => { - #[derive(Clone)] + #[derive(Clone, Debug)] pub struct $ca {} impl DaftDataType for $ca { @@ -133,7 +133,7 @@ macro_rules! impl_daft_logical_fixed_size_list_datatype { macro_rules! impl_nested_datatype { ($ca:ident, $array_type:ident) => { - #[derive(Clone)] + #[derive(Clone, Debug)] pub struct $ca {} impl DaftDataType for $ca { diff --git a/src/daft-core/src/series/array_impl/mod.rs b/src/daft-core/src/series/array_impl/mod.rs index c440c4126a..6a3f0839ad 100644 --- a/src/daft-core/src/series/array_impl/mod.rs +++ b/src/daft-core/src/series/array_impl/mod.rs @@ -5,6 +5,7 @@ pub mod nested_array; use super::Series; +#[derive(Debug)] pub struct ArrayWrapper(pub T); pub trait IntoSeries { diff --git a/src/daft-core/src/series/mod.rs b/src/daft-core/src/series/mod.rs index d79b478ae2..27be057048 100644 --- a/src/daft-core/src/series/mod.rs +++ b/src/daft-core/src/series/mod.rs @@ -19,7 +19,7 @@ pub use array_impl::IntoSeries; pub(crate) use self::series_like::SeriesLike; -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Series { pub inner: Arc, } diff --git a/src/daft-core/src/series/series_like.rs b/src/daft-core/src/series/series_like.rs index 9a91a4ccc3..0595651c04 100644 --- a/src/daft-core/src/series/series_like.rs +++ b/src/daft-core/src/series/series_like.rs @@ -7,7 +7,7 @@ use crate::{ use common_error::DaftResult; use super::Series; -pub trait SeriesLike: Send + Sync + Any { +pub trait SeriesLike: Send + Sync + Any + std::fmt::Debug { #[allow(clippy::wrong_self_convention)] fn into_series(&self) -> Series; fn to_arrow(&self) -> Box; diff --git a/src/daft-csv/Cargo.toml b/src/daft-csv/Cargo.toml index feb97c058b..74b67318f2 100644 --- a/src/daft-csv/Cargo.toml +++ b/src/daft-csv/Cargo.toml @@ -1,6 +1,7 @@ [dependencies] arrow2 = {workspace = true, features = ["io_csv", "io_csv_async"]} async-compat = {workspace = true} +async-compression = {workspace = true} async-stream = {workspace = true} bytes = {workspace = true} common-error = {path = "../common/error", default-features = false} @@ -17,6 +18,10 @@ snafu = {workspace = true} tokio = {workspace = true} tokio-stream = {workspace = true} tokio-util = {workspace = true} +url = {workspace = true} + +[dev-dependencies] +rstest = "0.18.2" [features] default = ["python"] diff --git a/src/daft-csv/src/compression.rs 
b/src/daft-csv/src/compression.rs new file mode 100644 index 0000000000..268b1566d9 --- /dev/null +++ b/src/daft-csv/src/compression.rs @@ -0,0 +1,66 @@ +use async_compression::tokio::bufread::{ + BrotliDecoder, BzDecoder, DeflateDecoder, GzipDecoder, LzmaDecoder, XzDecoder, ZlibDecoder, + ZstdDecoder, +}; +use std::{path::PathBuf, pin::Pin}; +use tokio::io::{AsyncBufRead, AsyncRead}; +use url::Url; + +#[derive(Debug)] +pub enum CompressionCodec { + Brotli, + Bz, + Deflate, + Gzip, + Lzma, + Xz, + Zlib, + Zstd, +} + +impl CompressionCodec { + pub fn from_uri(uri: &str) -> Option { + let url = Url::parse(uri); + let path = match &url { + Ok(url) => url.path(), + _ => uri, + }; + let extension = PathBuf::from(path) + .extension()? + .to_string_lossy() + .to_string(); + Self::from_extension(extension.as_ref()) + } + pub fn from_extension(extension: &str) -> Option { + use CompressionCodec::*; + match extension { + "br" => Some(Brotli), + "bz2" => Some(Bz), + "deflate" => Some(Deflate), + "gz" => Some(Gzip), + "lzma" => Some(Lzma), + "xz" => Some(Xz), + "zl" => Some(Zlib), + "zstd" | "zst" => Some(Zstd), + "snappy" => todo!("Snappy compression support not yet implemented"), + _ => None, + } + } + + pub fn to_decoder( + &self, + reader: T, + ) -> Pin> { + use CompressionCodec::*; + match self { + Brotli => Box::pin(BrotliDecoder::new(reader)), + Bz => Box::pin(BzDecoder::new(reader)), + Deflate => Box::pin(DeflateDecoder::new(reader)), + Gzip => Box::pin(GzipDecoder::new(reader)), + Lzma => Box::pin(LzmaDecoder::new(reader)), + Xz => Box::pin(XzDecoder::new(reader)), + Zlib => Box::pin(ZlibDecoder::new(reader)), + Zstd => Box::pin(ZstdDecoder::new(reader)), + } + } +} diff --git a/src/daft-csv/src/lib.rs b/src/daft-csv/src/lib.rs index 3b750cdbfc..ff727da31c 100644 --- a/src/daft-csv/src/lib.rs +++ b/src/daft-csv/src/lib.rs @@ -3,6 +3,7 @@ use common_error::DaftError; use snafu::Snafu; +mod compression; pub mod metadata; #[cfg(feature = "python")] pub mod python; @@ -16,6 +17,17 @@ pub enum Error { IOError { source: daft_io::Error }, #[snafu(display("{source}"))] CSVError { source: csv_async::Error }, + #[snafu(display("{source}"))] + ArrowError { source: arrow2::error::Error }, + #[snafu(display("Error joining spawned task: {}", source))] + JoinError { source: tokio::task::JoinError }, + #[snafu(display( + "Sender of OneShot Channel Dropped before sending data over: {}", + source + ))] + OneShotRecvError { + source: tokio::sync::oneshot::error::RecvError, + }, } impl From for DaftError { diff --git a/src/daft-csv/src/metadata.rs b/src/daft-csv/src/metadata.rs index f5f572af5c..0e4fcd7b33 100644 --- a/src/daft-csv/src/metadata.rs +++ b/src/daft-csv/src/metadata.rs @@ -1,89 +1,630 @@ -use std::sync::Arc; +use std::{collections::HashSet, sync::Arc}; -use arrow2::io::csv::read_async::{infer, infer_schema, AsyncReaderBuilder}; +use arrow2::io::csv::read_async::{infer, AsyncReader, AsyncReaderBuilder}; use async_compat::CompatExt; use common_error::DaftResult; +use csv_async::ByteRecord; use daft_core::schema::Schema; use daft_io::{get_runtime, GetResult, IOClient, IOStatsRef}; -use futures::{io::Cursor, AsyncRead, AsyncSeek}; -use tokio::fs::File; +use tokio::{ + fs::File, + io::{AsyncBufRead, AsyncRead, BufReader}, +}; +use tokio_util::io::StreamReader; + +use crate::compression::CompressionCodec; + +const DEFAULT_COLUMN_PREFIX: &str = "column_"; pub fn read_csv_schema( uri: &str, has_header: bool, delimiter: Option, + max_bytes: Option, io_client: Arc, io_stats: Option, -) -> DaftResult { +) -> 
DaftResult<(Schema, usize, usize, f64, f64)> { let runtime_handle = get_runtime(true)?; let _rt_guard = runtime_handle.enter(); runtime_handle.block_on(async { - read_csv_schema_single(uri, has_header, delimiter, io_client, io_stats).await + read_csv_schema_single( + uri, + has_header, + delimiter, + // Default to 1 MiB. + max_bytes.or(Some(1024 * 1024)), + io_client, + io_stats, + ) + .await }) } -async fn read_csv_schema_single( +pub(crate) async fn read_csv_schema_single( uri: &str, has_header: bool, delimiter: Option, + max_bytes: Option, io_client: Arc, io_stats: Option, -) -> DaftResult { +) -> DaftResult<(Schema, usize, usize, f64, f64)> { + let compression_codec = CompressionCodec::from_uri(uri); match io_client .single_url_get(uri.to_string(), None, io_stats) .await? { GetResult::File(file) => { - read_csv_schema_from_reader( - File::open(file.path).await?.compat(), + read_csv_schema_from_compressed_reader( + BufReader::new(File::open(file.path).await?), + compression_codec, + has_header, + delimiter, + max_bytes, + ) + .await + } + GetResult::Stream(stream, size, _) => { + read_csv_schema_from_compressed_reader( + StreamReader::new(stream), + compression_codec, + has_header, + delimiter, + // Truncate max_bytes to size if both are set. + max_bytes.map(|m| size.map(|s| m.min(s)).unwrap_or(m)), + ) + .await + } + } +} + +async fn read_csv_schema_from_compressed_reader( + reader: R, + compression_codec: Option, + has_header: bool, + delimiter: Option, + max_bytes: Option, +) -> DaftResult<(Schema, usize, usize, f64, f64)> +where + R: AsyncBufRead + Unpin + Send + 'static, +{ + match compression_codec { + Some(compression) => { + read_csv_schema_from_uncompressed_reader( + compression.to_decoder(reader), has_header, delimiter, + max_bytes, ) .await } - result @ GetResult::Stream(..) 
=> { - read_csv_schema_from_reader(Cursor::new(result.bytes().await?), has_header, delimiter) - .await + None => { + read_csv_schema_from_uncompressed_reader(reader, has_header, delimiter, max_bytes).await } } } -async fn read_csv_schema_from_reader( +async fn read_csv_schema_from_uncompressed_reader( reader: R, has_header: bool, delimiter: Option, -) -> DaftResult + max_bytes: Option, +) -> DaftResult<(Schema, usize, usize, f64, f64)> where - R: AsyncRead + AsyncSeek + Unpin + Sync + Send, + R: AsyncRead + Unpin + Send, +{ + let (schema, total_bytes_read, num_records_read, mean_size, std_size) = + read_csv_arrow_schema_from_uncompressed_reader(reader, has_header, delimiter, max_bytes) + .await?; + Ok(( + Schema::try_from(&schema)?, + total_bytes_read, + num_records_read, + mean_size, + std_size, + )) +} + +async fn read_csv_arrow_schema_from_uncompressed_reader( + reader: R, + has_header: bool, + delimiter: Option, + max_bytes: Option, +) -> DaftResult<(arrow2::datatypes::Schema, usize, usize, f64, f64)> +where + R: AsyncRead + Unpin + Send, { let mut reader = AsyncReaderBuilder::new() .has_headers(has_header) .delimiter(delimiter.unwrap_or(b',')) - .create_reader(reader); - let (fields, _) = infer_schema(&mut reader, None, has_header, &infer).await?; - let schema: arrow2::datatypes::Schema = fields.into(); - Schema::try_from(&schema) + .buffer_capacity(max_bytes.unwrap_or(1 << 20).min(1 << 20)) + .create_reader(reader.compat()); + let (fields, total_bytes_read, num_records_read, mean_size, std_size) = + infer_schema(&mut reader, None, max_bytes, has_header, &infer).await?; + Ok(( + fields.into(), + total_bytes_read, + num_records_read, + mean_size, + std_size, + )) +} + +async fn infer_schema( + reader: &mut AsyncReader, + max_rows: Option, + max_bytes: Option, + has_header: bool, + infer: &F, +) -> arrow2::error::Result<(Vec, usize, usize, f64, f64)> +where + R: futures::AsyncRead + Unpin + Send, + F: Fn(&[u8]) -> arrow2::datatypes::DataType, +{ + let mut record = ByteRecord::new(); + // get or create header names + // when has_header is false, creates default column names with column_ prefix + let (headers, did_read_record): (Vec, bool) = if has_header { + ( + reader + .headers() + .await? + .iter() + .map(|s| s.to_string()) + .collect(), + false, + ) + } else { + // Save the csv reader position before reading headers + if !reader.read_byte_record(&mut record).await? { + return Ok((vec![], 0, 0, 0f64, 0f64)); + } + let first_record_count = record.len(); + ( + (0..first_record_count) + .map(|i| format!("{}{}", DEFAULT_COLUMN_PREFIX, i + 1)) + .collect(), + true, + ) + }; + // keep track of inferred field types + let mut column_types: Vec> = + vec![HashSet::new(); headers.len()]; + let mut records_count = 0; + let mut total_bytes = 0; + let mut mean = 0f64; + let mut m2 = 0f64; + if did_read_record { + records_count += 1; + let record_size = record.as_slice().len(); + total_bytes += record_size; + let delta = (record_size as f64) - mean; + mean += delta / (records_count as f64); + let delta2 = (record_size as f64) - mean; + m2 += delta * delta2; + for (i, column) in column_types.iter_mut().enumerate() { + if let Some(string) = record.get(i) { + column.insert(infer(string)); + } + } + } + let max_records = max_rows.unwrap_or(usize::MAX); + let max_bytes = max_bytes.unwrap_or(usize::MAX); + while records_count < max_records && total_bytes < max_bytes { + if !reader.read_byte_record(&mut record).await? 
{ + break; + } + records_count += 1; + let record_size = record.as_slice().len(); + total_bytes += record_size; + let delta = (record_size as f64) - mean; + mean += delta / (records_count as f64); + let delta2 = (record_size as f64) - mean; + m2 += delta * delta2; + for (i, column) in column_types.iter_mut().enumerate() { + if let Some(string) = record.get(i) { + column.insert(infer(string)); + } + } + } + let fields = merge_schema(&headers, &mut column_types); + let std = (m2 / ((records_count - 1) as f64)).sqrt(); + Ok((fields, total_bytes, records_count, mean, std)) +} + +fn merge_fields( + field_name: &str, + possibilities: &mut HashSet, +) -> arrow2::datatypes::Field { + use arrow2::datatypes::DataType; + + if possibilities.len() > 1 { + // Drop nulls from possibilities. + possibilities.remove(&DataType::Null); + } + // determine data type based on possible types + // if there are incompatible types, use DataType::Utf8 + let data_type = match possibilities.len() { + 1 => possibilities.drain().next().unwrap(), + 2 => { + if possibilities.contains(&DataType::Int64) + && possibilities.contains(&DataType::Float64) + { + // we have an integer and double, fall down to double + DataType::Float64 + } else { + // default to Utf8 for conflicting datatypes (e.g bool and int) + DataType::Utf8 + } + } + _ => DataType::Utf8, + }; + arrow2::datatypes::Field::new(field_name, data_type, true) +} + +fn merge_schema( + headers: &[String], + column_types: &mut [HashSet], +) -> Vec { + headers + .iter() + .zip(column_types.iter_mut()) + .map(|(field_name, possibilities)| merge_fields(field_name, possibilities)) + .collect() } #[cfg(test)] mod tests { use std::sync::Arc; - use common_error::DaftResult; + use common_error::{DaftError, DaftResult}; use daft_core::{datatypes::Field, schema::Schema, DataType}; use daft_io::{IOClient, IOConfig}; + use rstest::rstest; use super::read_csv_schema; + #[rstest] + fn test_csv_schema_local( + #[values( + // Uncompressed + None, + // brotli + Some("br"), + // bzip2 + Some("bz2"), + // deflate + Some("deflate"), + // gzip + Some("gz"), + // lzma + Some("lzma"), + // xz + Some("xz"), + // zlib + Some("zl"), + // zstd + Some("zst"), + )] + compression: Option<&str>, + ) -> DaftResult<()> { + let file = format!( + "{}/test/iris_tiny.csv{}", + env!("CARGO_MANIFEST_DIR"), + compression.map_or("".to_string(), |ext| format!(".{}", ext)) + ); + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let (schema, total_bytes_read, num_records_read, _, _) = + read_csv_schema(file.as_ref(), true, None, None, io_client.clone(), None)?; + assert_eq!( + schema, + Schema::new(vec![ + Field::new("sepal.length", DataType::Float64), + Field::new("sepal.width", DataType::Float64), + Field::new("petal.length", DataType::Float64), + Field::new("petal.width", DataType::Float64), + Field::new("variety", DataType::Utf8), + ])? 
+ .into(), + ); + assert_eq!(total_bytes_read, 328); + assert_eq!(num_records_read, 20); + + Ok(()) + } + + #[test] + fn test_csv_schema_local_delimiter() -> DaftResult<()> { + let file = format!( + "{}/test/iris_tiny_bar_delimiter.csv", + env!("CARGO_MANIFEST_DIR"), + ); + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let (schema, total_bytes_read, num_records_read, _, _) = read_csv_schema( + file.as_ref(), + true, + Some(b'|'), + None, + io_client.clone(), + None, + )?; + assert_eq!( + schema, + Schema::new(vec![ + Field::new("sepal.length", DataType::Float64), + Field::new("sepal.width", DataType::Float64), + Field::new("petal.length", DataType::Float64), + Field::new("petal.width", DataType::Float64), + Field::new("variety", DataType::Utf8), + ])? + .into(), + ); + assert_eq!(total_bytes_read, 328); + assert_eq!(num_records_read, 20); + + Ok(()) + } + + #[test] + fn test_csv_schema_local_read_stats() -> DaftResult<()> { + let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"),); + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let (_, total_bytes_read, num_records_read, _, _) = + read_csv_schema(file.as_ref(), true, None, None, io_client.clone(), None)?; + assert_eq!(total_bytes_read, 328); + assert_eq!(num_records_read, 20); + + Ok(()) + } + + #[test] + fn test_csv_schema_local_no_headers() -> DaftResult<()> { + let file = format!( + "{}/test/iris_tiny_no_headers.csv", + env!("CARGO_MANIFEST_DIR"), + ); + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let (schema, total_bytes_read, num_records_read, _, _) = + read_csv_schema(file.as_ref(), false, None, None, io_client.clone(), None)?; + assert_eq!( + schema, + Schema::new(vec![ + Field::new("column_1", DataType::Float64), + Field::new("column_2", DataType::Float64), + Field::new("column_3", DataType::Float64), + Field::new("column_4", DataType::Float64), + Field::new("column_5", DataType::Utf8), + ])? + .into(), + ); + assert_eq!(total_bytes_read, 328); + assert_eq!(num_records_read, 20); + + Ok(()) + } + + #[test] + fn test_csv_schema_local_empty_lines_skipped() -> DaftResult<()> { + let file = format!( + "{}/test/iris_tiny_empty_lines.csv", + env!("CARGO_MANIFEST_DIR"), + ); + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let (schema, total_bytes_read, num_records_read, _, _) = + read_csv_schema(file.as_ref(), true, None, None, io_client.clone(), None)?; + assert_eq!( + schema, + Schema::new(vec![ + Field::new("sepal.length", DataType::Float64), + Field::new("sepal.width", DataType::Float64), + Field::new("petal.length", DataType::Float64), + Field::new("petal.width", DataType::Float64), + Field::new("variety", DataType::Utf8), + ])? 
+ .into(), + ); + assert_eq!(total_bytes_read, 49); + assert_eq!(num_records_read, 3); + + Ok(()) + } + + #[test] + fn test_csv_schema_local_nulls() -> DaftResult<()> { + let file = format!("{}/test/iris_tiny_nulls.csv", env!("CARGO_MANIFEST_DIR"),); + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let (schema, total_bytes_read, num_records_read, _, _) = + read_csv_schema(file.as_ref(), true, None, None, io_client.clone(), None)?; + assert_eq!( + schema, + Schema::new(vec![ + Field::new("sepal.length", DataType::Float64), + Field::new("sepal.width", DataType::Float64), + Field::new("petal.length", DataType::Float64), + Field::new("petal.width", DataType::Float64), + Field::new("variety", DataType::Utf8), + ])? + .into(), + ); + assert_eq!(total_bytes_read, 82); + assert_eq!(num_records_read, 6); + + Ok(()) + } + + #[test] + fn test_csv_schema_local_conflicting_types_utf8_fallback() -> DaftResult<()> { + let file = format!( + "{}/test/iris_tiny_conflicting_dtypes.csv", + env!("CARGO_MANIFEST_DIR"), + ); + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let (schema, total_bytes_read, num_records_read, _, _) = + read_csv_schema(file.as_ref(), true, None, None, io_client.clone(), None)?; + assert_eq!( + schema, + Schema::new(vec![ + // All conflicting dtypes fall back to UTF8. + Field::new("sepal.length", DataType::Utf8), + Field::new("sepal.width", DataType::Utf8), + Field::new("petal.length", DataType::Utf8), + Field::new("petal.width", DataType::Utf8), + Field::new("variety", DataType::Utf8), + ])? + .into(), + ); + assert_eq!(total_bytes_read, 33); + assert_eq!(num_records_read, 2); + + Ok(()) + } + + #[test] + fn test_csv_schema_local_max_bytes() -> DaftResult<()> { + let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"),); + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let (schema, total_bytes_read, num_records_read, _, _) = read_csv_schema( + file.as_ref(), + true, + None, + Some(100), + io_client.clone(), + None, + )?; + assert_eq!( + schema, + Schema::new(vec![ + Field::new("sepal.length", DataType::Float64), + Field::new("sepal.width", DataType::Float64), + Field::new("petal.length", DataType::Float64), + Field::new("petal.width", DataType::Float64), + Field::new("variety", DataType::Utf8), + ])? + .into(), + ); + // Max bytes doesn't include header, so add 15 bytes to upper bound. 
+ assert!(total_bytes_read <= 100 + 15, "{}", total_bytes_read); + assert!(num_records_read <= 10, "{}", num_records_read); + + Ok(()) + } + #[test] - fn test_csv_schema_from_s3() -> DaftResult<()> { - let file = "s3://daft-public-data/test_fixtures/csv-dev/mvp.csv"; + fn test_csv_schema_local_invalid_column_header_mismatch() -> DaftResult<()> { + let file = format!( + "{}/test/iris_tiny_invalid_header_cols_mismatch.csv", + env!("CARGO_MANIFEST_DIR"), + ); + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let err = read_csv_schema(file.as_ref(), true, None, None, io_client.clone(), None); + assert!(err.is_err()); + let err = err.unwrap_err(); + assert!(matches!(err, DaftError::ArrowError(_)), "{}", err); + assert!( + err.to_string() + .contains("found record with 4 fields, but the previous record has 5 fields"), + "{}", + err + ); + + Ok(()) + } + + #[test] + fn test_csv_schema_local_invalid_no_header_variable_num_cols() -> DaftResult<()> { + let file = format!( + "{}/test/iris_tiny_invalid_no_header_variable_num_cols.csv", + env!("CARGO_MANIFEST_DIR"), + ); + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let err = read_csv_schema(file.as_ref(), true, None, None, io_client.clone(), None); + assert!(err.is_err()); + let err = err.unwrap_err(); + assert!(matches!(err, DaftError::ArrowError(_)), "{}", err); + assert!( + err.to_string() + .contains("found record with 5 fields, but the previous record has 4 fields"), + "{}", + err + ); + + Ok(()) + } + + #[rstest] + fn test_csv_schema_s3( + #[values( + // Uncompressed + None, + // brotli + Some("br"), + // bzip2 + Some("bz2"), + // deflate + Some("deflate"), + // gzip + Some("gz"), + // lzma + Some("lzma"), + // xz + Some("xz"), + // zlib + Some("zl"), + // zstd + Some("zst"), + )] + compression: Option<&str>, + ) -> DaftResult<()> { + let file = format!( + "s3://daft-public-data/test_fixtures/csv-dev/mvp.csv{}", + compression.map_or("".to_string(), |ext| format!(".{}", ext)) + ); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let schema = read_csv_schema(file, true, None, io_client.clone(), None)?; + let (schema, _, _, _, _) = + read_csv_schema(file.as_ref(), true, None, None, io_client.clone(), None)?; assert_eq!( schema, Schema::new(vec![ diff --git a/src/daft-csv/src/python.rs b/src/daft-csv/src/python.rs index f45a03b9b5..def9bfa966 100644 --- a/src/daft-csv/src/python.rs +++ b/src/daft-csv/src/python.rs @@ -32,6 +32,9 @@ pub mod pylib { delimiter: Option<&str>, io_config: Option, multithreaded_io: Option, + schema: Option, + buffer_size: Option, + chunk_size: Option, ) -> PyResult { py.allow_threads(|| { let io_stats = IOStatsContext::new(format!("read_csv: for uri {uri}")); @@ -50,6 +53,10 @@ pub mod pylib { io_client, Some(io_stats), multithreaded_io.unwrap_or(true), + schema.map(|s| s.schema), + buffer_size, + chunk_size, + None, )? 
.into()) }) @@ -61,6 +68,7 @@ pub mod pylib { uri: &str, has_header: Option, delimiter: Option<&str>, + max_bytes: Option, io_config: Option, multithreaded_io: Option, ) -> PyResult { @@ -71,14 +79,15 @@ pub mod pylib { multithreaded_io.unwrap_or(true), io_config.unwrap_or_default().config.into(), )?; - Ok(Arc::new(crate::metadata::read_csv_schema( + let (schema, _, _, _, _) = crate::metadata::read_csv_schema( uri, has_header.unwrap_or(true), str_delimiter_to_byte(delimiter)?, + max_bytes, io_client, Some(io_stats), - )?) - .into()) + )?; + Ok(Arc::new(schema).into()) }) } } diff --git a/src/daft-csv/src/read.rs b/src/daft-csv/src/read.rs index cc8bdcd63c..693a2d0cdb 100644 --- a/src/daft-csv/src/read.rs +++ b/src/daft-csv/src/read.rs @@ -1,22 +1,36 @@ use std::{ collections::{HashMap, HashSet}, + num::NonZeroUsize, sync::Arc, }; use arrow2::{ datatypes::Field, - io::csv::read_async::{ - deserialize_batch, deserialize_column, infer, infer_schema, read_rows, AsyncReaderBuilder, - ByteRecord, - }, + io::csv::read_async::{deserialize_column, read_rows, AsyncReaderBuilder, ByteRecord}, }; -use async_compat::CompatExt; +use async_compat::{Compat, CompatExt}; use common_error::DaftResult; -use daft_core::{schema::Schema, utils::arrow::cast_array_for_daft_if_needed, Series}; +use csv_async::AsyncReader; +use daft_core::{ + schema::{Schema, SchemaRef}, + utils::arrow::cast_array_for_daft_if_needed, + Series, +}; use daft_io::{get_runtime, GetResult, IOClient, IOStatsRef}; use daft_table::Table; -use futures::{io::Cursor, AsyncRead, AsyncSeek}; -use tokio::fs::File; +use futures::TryStreamExt; +use rayon::prelude::{ + IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator, +}; +use snafu::{futures::TryFutureExt, ResultExt}; +use tokio::{ + fs::File, + io::{AsyncBufRead, AsyncRead, BufReader}, +}; +use tokio_util::io::StreamReader; + +use crate::metadata::read_csv_schema_single; +use crate::{compression::CompressionCodec, ArrowSnafu}; #[allow(clippy::too_many_arguments)] pub fn read_csv( @@ -29,6 +43,10 @@ pub fn read_csv( io_client: Arc, io_stats: Option, multithreaded_io: bool, + schema: Option, + buffer_size: Option, + chunk_size: Option, + max_chunks_in_flight: Option, ) -> DaftResult { let runtime_handle = get_runtime(multithreaded_io)?; let _rt_guard = runtime_handle.enter(); @@ -39,9 +57,13 @@ pub fn read_csv( include_columns, num_rows, has_header, - delimiter, + delimiter.unwrap_or(b','), io_client, io_stats, + schema, + buffer_size, + chunk_size, + max_chunks_in_flight, ) .await }) @@ -54,57 +76,178 @@ async fn read_csv_single( include_columns: Option>, num_rows: Option, has_header: bool, - delimiter: Option, + delimiter: u8, io_client: Arc, io_stats: Option, + schema: Option, + buffer_size: Option, + chunk_size: Option, + max_chunks_in_flight: Option, ) -> DaftResult
{ + let (schema, estimated_mean_row_size, estimated_std_row_size) = match schema { + Some(schema) => (schema.to_arrow()?, None, None), + None => { + let (schema, _, _, mean, std) = read_csv_schema_single( + uri, + has_header, + Some(delimiter), + // Read at most 1 MiB when doing schema inference. + Some(1024 * 1024), + io_client.clone(), + io_stats.clone(), + ) + .await?; + (schema.to_arrow()?, Some(mean), Some(std)) + } + }; + let compression_codec = CompressionCodec::from_uri(uri); match io_client .single_url_get(uri.to_string(), None, io_stats) .await? { GetResult::File(file) => { - read_csv_single_from_reader( - File::open(file.path).await?.compat(), + read_csv_from_compressed_reader( + BufReader::new(File::open(file.path).await?), + compression_codec, column_names, include_columns, num_rows, has_header, delimiter, + schema, + // Default buffer size of 512 KiB. + buffer_size.unwrap_or(512 * 1024), + // Default chunk size of 64 KiB. + chunk_size.unwrap_or(64 * 1024), + // Default max chunks in flight is set to 2x the number of cores, which should ensure pipelining of reading chunks + // with the parsing of chunks on the rayon threadpool. + max_chunks_in_flight.unwrap_or( + std::thread::available_parallelism() + .unwrap_or(NonZeroUsize::new(2).unwrap()) + .checked_mul(2.try_into().unwrap()) + .unwrap() + .try_into() + .unwrap(), + ), + estimated_mean_row_size, + estimated_std_row_size, ) .await } - result @ GetResult::Stream(..) => { - // TODO(Clark): Enable streaming remote reads by wrapping the BoxStream in a buffered stream that's - // (1) sync and (2) seekable. - read_csv_single_from_reader( - Cursor::new(result.bytes().await?), + GetResult::Stream(stream, _, _) => { + read_csv_from_compressed_reader( + StreamReader::new(stream), + compression_codec, column_names, include_columns, num_rows, has_header, delimiter, + schema, + // Default buffer size of 512 KiB. + buffer_size.unwrap_or(512 * 1024), + // Default chunk size of 64 KiB. + chunk_size.unwrap_or(64 * 1024), + // Default max chunks in flight is set to 2x the number of cores, which should ensure pipelining of reading chunks + // with the parsing of chunks on the rayon threadpool. + max_chunks_in_flight.unwrap_or( + std::thread::available_parallelism() + .unwrap_or(NonZeroUsize::new(2).unwrap()) + .checked_mul(2.try_into().unwrap()) + .unwrap() + .try_into() + .unwrap(), + ), + estimated_mean_row_size, + estimated_std_row_size, ) .await } } } -async fn read_csv_single_from_reader( +#[allow(clippy::too_many_arguments)] +async fn read_csv_from_compressed_reader( reader: R, + compression_codec: Option, column_names: Option>, include_columns: Option>, num_rows: Option, has_header: bool, - delimiter: Option, + delimiter: u8, + schema: arrow2::datatypes::Schema, + buffer_size: usize, + chunk_size: usize, + max_chunks_in_flight: usize, + estimated_mean_row_size: Option, + estimated_std_row_size: Option, +) -> DaftResult
+where + R: AsyncBufRead + Unpin + Send + 'static, +{ + match compression_codec { + Some(compression) => { + read_csv_from_uncompressed_reader( + compression.to_decoder(reader), + column_names, + include_columns, + num_rows, + has_header, + delimiter, + schema, + buffer_size, + chunk_size, + max_chunks_in_flight, + estimated_mean_row_size, + estimated_std_row_size, + ) + .await + } + None => { + read_csv_from_uncompressed_reader( + reader, + column_names, + include_columns, + num_rows, + has_header, + delimiter, + schema, + buffer_size, + chunk_size, + max_chunks_in_flight, + estimated_mean_row_size, + estimated_std_row_size, + ) + .await + } + } +} + +#[allow(clippy::too_many_arguments)] +async fn read_csv_from_uncompressed_reader( + stream_reader: R, + column_names: Option>, + include_columns: Option>, + num_rows: Option, + has_header: bool, + delimiter: u8, + schema: arrow2::datatypes::Schema, + buffer_size: usize, + chunk_size: usize, + max_chunks_in_flight: usize, + estimated_mean_row_size: Option, + estimated_std_row_size: Option, ) -> DaftResult
where - R: AsyncRead + AsyncSeek + Unpin + Sync + Send, + R: AsyncRead + Unpin + Send, { - let mut reader = AsyncReaderBuilder::new() + let reader = AsyncReaderBuilder::new() .has_headers(has_header) - .delimiter(delimiter.unwrap_or(b',')) - .create_reader(reader); - let (mut fields, _) = infer_schema(&mut reader, None, has_header, &infer).await?; + .delimiter(delimiter) + .buffer_capacity(buffer_size) + .create_reader(stream_reader.compat()); + let mut fields = schema.fields; + // Rename fields, if necessary. if let Some(column_names) = column_names { fields = fields .into_iter() @@ -114,70 +257,28 @@ where }) .collect(); } - let field_name_to_idx = fields - .iter() - .enumerate() - .map(|(idx, f)| (f.name.as_ref(), idx)) - .collect::>(); - let projection_indices = include_columns.as_ref().map(|cols| { - cols.iter() - .map(|c| field_name_to_idx[c]) - .collect::>() - }); - let num_rows = num_rows.unwrap_or(usize::MAX); - // TODO(Clark): Make batch size configurable. - // TODO(Clark): Get estimated average row size in bytes during schema inference and use it to: - // 1. Set a reasonable batch size. - // 2. Preallocate per-column batch vecs. - // 3. Preallocate column Arrow array buffers. - let batch_size = 1024.min(num_rows); - // TODO(Clark): Instead of allocating an array-per-column-batch and concatenating at the end, - // progressively grow a single array per column (with the above preallocation based on estimated - // number of rows). - let mut column_arrays = vec![ - vec![]; - projection_indices - .as_ref() - .map(|p| p.len()) - .unwrap_or(fields.len()) - ]; - let mut buffer = vec![ByteRecord::with_capacity(0, fields.len()); batch_size]; - let mut rows = buffer.as_mut_slice(); - // Number of rows read in last read. - let mut rows_read = 1; - // Total number of rows read across all reads. - let mut total_rows_read = 0; - while rows_read > 0 && total_rows_read < num_rows { - if rows.len() > num_rows - total_rows_read { - // If we need to read less than the entire row buffer, truncate the buffer to the number - // of rows that we actually want to read. - rows = &mut rows[..num_rows - total_rows_read + 1] - } - rows_read = read_rows(&mut reader, 0, rows).await?; - total_rows_read += rows_read; - // TODO(Clark): Parallelize column deserialization over a rayon threadpool. - for (idx, array) in deserialize_batch( - &rows[..rows_read], - &fields, - projection_indices.as_deref(), - 0, - deserialize_column, - )? - .into_arrays() - .into_iter() - .enumerate() - { - column_arrays[idx].push(array); - } - } + // Read CSV into Arrow2 column chunks. + let column_chunks = read_into_column_chunks( + reader, + fields.clone().into(), + fields_to_projection_indices(&fields, &include_columns), + num_rows, + chunk_size, + max_chunks_in_flight, + estimated_mean_row_size, + estimated_std_row_size, + ) + .await?; + // Truncate fields to only contain projected columns. if let Some(include_columns) = include_columns { - // Truncate fields to only contain projected columns. let include_columns: HashSet<&str> = include_columns.into_iter().collect(); fields.retain(|f| include_columns.contains(f.name.as_str())) } - let columns_series = column_arrays - .into_iter() - .zip(fields.iter()) + // Concatenate column chunks and convert into Daft Series. + // Note that this concatenation is done in parallel on the rayon threadpool. + let columns_series = column_chunks + .into_par_iter() + .zip(&fields) .map(|(mut arrays, field)| { let array = if arrays.len() > 1 { // Concatenate all array chunks. 
@@ -190,63 +291,387 @@ where Series::try_from((field.name.as_ref(), cast_array_for_daft_if_needed(array))) }) .collect::>>()?; + // Build Daft Table. let schema: arrow2::datatypes::Schema = fields.into(); let daft_schema = Schema::try_from(&schema)?; Table::new(daft_schema, columns_series) } +#[allow(clippy::too_many_arguments)] +async fn read_into_column_chunks( + mut reader: AsyncReader>, + fields: Arc>, + projection_indices: Arc>, + num_rows: Option, + chunk_size: usize, + max_chunks_in_flight: usize, + estimated_mean_row_size: Option, + estimated_std_row_size: Option, +) -> DaftResult>>> +where + R: AsyncRead + Unpin + Send, +{ + let num_fields = fields.len(); + let num_rows = num_rows.unwrap_or(usize::MAX); + let mut estimated_mean_row_size = estimated_mean_row_size.unwrap_or(200f64); + let mut estimated_std_row_size = estimated_std_row_size.unwrap_or(20f64); + // Stream of unparsed CSV byte record chunks. + let read_stream = async_stream::try_stream! { + // Number of rows read in last read. + let mut rows_read = 1; + // Total number of rows read across all reads. + let mut total_rows_read = 0; + let mut mean = 0f64; + let mut m2 = 0f64; + while rows_read > 0 && total_rows_read < num_rows { + // Allocate a record buffer of size 1 standard above the observed mean record size. + // If the record sizes are normally distributed, this should result in ~85% of the records not requiring + // reallocation during reading. + let record_buffer_size = (estimated_mean_row_size + estimated_std_row_size).ceil() as usize; + // Get chunk size in # of rows, using the estimated mean row size in bytes. + let chunk_size_rows = { + let estimated_rows_per_desired_chunk = chunk_size / (estimated_mean_row_size.ceil() as usize); + // Process at least 8 rows in a chunk, even if the rows are pretty large. + // Cap chunk size at the remaining number of rows we need to read before we reach the num_rows limit. + estimated_rows_per_desired_chunk.max(8).min(num_rows - total_rows_read) + }; + let mut chunk_buffer = vec![ + ByteRecord::with_capacity(record_buffer_size, num_fields); + chunk_size_rows + ]; + + let byte_pos_before = reader.position().byte(); + rows_read = read_rows(&mut reader, 0, chunk_buffer.as_mut_slice()).await.context(ArrowSnafu {})?; + let bytes_read = reader.position().byte() - byte_pos_before; + + // Update stats. + total_rows_read += rows_read; + let delta = (bytes_read as f64) - mean; + mean += delta / (total_rows_read as f64); + let delta2 = (bytes_read as f64) - mean; + m2 += delta * delta2; + estimated_mean_row_size = mean; + estimated_std_row_size = (m2 / ((total_rows_read - 1) as f64)).sqrt(); + + chunk_buffer.truncate(rows_read); + yield chunk_buffer + } + }; + // Parsing stream: we spawn background tokio + rayon tasks so we can pipeline chunk parsing with chunk reading, and + // we further parse each chunk column in parallel on the rayon threadpool. + let parse_stream = read_stream.map_ok(|record| { + let fields = fields.clone(); + let projection_indices = projection_indices.clone(); + tokio::spawn(async move { + let (send, recv) = tokio::sync::oneshot::channel(); + rayon::spawn(move || { + let result = (move || { + let chunk = projection_indices + .par_iter() + .map(|idx| { + deserialize_column( + record.as_slice(), + *idx, + fields[*idx].data_type().clone(), + 0, + ) + }) + .collect::>>>()?; + DaftResult::Ok(chunk) + })(); + let _ = send.send(result); + }); + recv.await.context(super::OneShotRecvSnafu {})? 
+ }) + .context(super::JoinSnafu {}) + }); + // Collect all chunks in chunk x column form. + let chunks = parse_stream + // Limit the number of chunks we have in flight at any given time. + .try_buffered(max_chunks_in_flight) + .try_collect::>() + .await? + .into_iter() + .collect::>>()?; + // Transpose chunk x column into column x chunk. + let mut column_arrays = vec![Vec::with_capacity(chunks.len()); projection_indices.len()]; + for chunk in chunks.into_iter() { + for (idx, col) in chunk.into_iter().enumerate() { + column_arrays[idx].push(col); + } + } + Ok(column_arrays) +} + +fn fields_to_projection_indices( + fields: &Vec, + include_columns: &Option>, +) -> Arc> { + let field_name_to_idx = fields + .iter() + .enumerate() + .map(|(idx, f)| (f.name.as_ref(), idx)) + .collect::>(); + include_columns + .as_ref() + .map_or_else( + || (0..fields.len()).collect(), + |cols| { + cols.iter() + .map(|c| field_name_to_idx[c]) + .collect::>() + }, + ) + .into() +} + #[cfg(test)] mod tests { use std::sync::Arc; - use common_error::DaftResult; + use common_error::{DaftError, DaftResult}; - use daft_core::{datatypes::Field, schema::Schema, DataType}; + use arrow2::io::csv::read::{ + deserialize_batch, deserialize_column, infer, infer_schema, read_rows, ByteRecord, + ReaderBuilder, + }; + use daft_core::{ + datatypes::Field, + schema::Schema, + utils::arrow::{cast_array_for_daft_if_needed, cast_array_from_daft_if_needed}, + DataType, + }; use daft_io::{IOClient, IOConfig}; + use daft_table::Table; + use rstest::rstest; use super::read_csv; + fn check_equal_local_arrow2( + path: &str, + out: &Table, + has_header: bool, + delimiter: Option, + column_names: Option>, + projection: Option>, + limit: Option, + ) { + let mut reader = ReaderBuilder::new() + .delimiter(delimiter.unwrap_or(b',')) + .from_path(path) + .unwrap(); + let (mut fields, _) = infer_schema(&mut reader, None, has_header, &infer).unwrap(); + if !has_header && let Some(column_names) = column_names { + fields = fields.into_iter().zip(column_names.into_iter()).map(|(field, name)| arrow2::datatypes::Field::new(name, field.data_type, true).with_metadata(field.metadata)).collect::>(); + } + let mut rows = vec![ByteRecord::default(); limit.unwrap_or(100)]; + let rows_read = read_rows(&mut reader, 0, &mut rows).unwrap(); + let rows = &rows[..rows_read]; + let chunk = deserialize_batch( + rows, + &fields, + projection.as_ref().map(|p| p.as_slice()), + 0, + deserialize_column, + ) + .unwrap(); + if let Some(projection) = projection { + fields = projection + .into_iter() + .map(|idx| fields[idx].clone()) + .collect(); + } + let columns = chunk + .into_arrays() + .into_iter() + // Roundtrip with Daft for casting. + .map(|c| cast_array_from_daft_if_needed(cast_array_for_daft_if_needed(c))) + .collect::>(); + let schema: arrow2::datatypes::Schema = fields.into(); + // Roundtrip with Daft for casting. 
+ let schema = Schema::try_from(&schema).unwrap().to_arrow().unwrap(); + assert_eq!(out.schema.to_arrow().unwrap(), schema); + let out_columns = (0..out.num_columns()) + .map(|i| out.get_column_by_index(i).unwrap().to_arrow()) + .collect::>(); + assert_eq!(out_columns, columns); + } + + #[rstest] + fn test_csv_read_local( + #[values( + // Uncompressed + None, + // brotli + Some("br"), + // bzip2 + Some("bz2"), + // deflate + Some("deflate"), + // gzip + Some("gz"), + // lzma + Some("lzma"), + // xz + Some("xz"), + // zlib + Some("zl"), + // zstd + Some("zst"), + )] + compression: Option<&str>, + ) -> DaftResult<()> { + let file = format!( + "{}/test/iris_tiny.csv{}", + env!("CARGO_MANIFEST_DIR"), + compression.map_or("".to_string(), |ext| format!(".{}", ext)) + ); + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let table = read_csv( + file.as_ref(), + None, + None, + None, + true, + None, + io_client, + None, + true, + None, + None, + None, + None, + )?; + assert_eq!(table.len(), 20); + assert_eq!( + table.schema, + Schema::new(vec![ + Field::new("sepal.length", DataType::Float64), + Field::new("sepal.width", DataType::Float64), + Field::new("petal.length", DataType::Float64), + Field::new("petal.width", DataType::Float64), + Field::new("variety", DataType::Utf8), + ])? + .into(), + ); + if compression.is_none() { + check_equal_local_arrow2(file.as_ref(), &table, true, None, None, None, None); + } + + Ok(()) + } + #[test] - fn test_csv_read_from_s3() -> DaftResult<()> { - let file = "s3://daft-public-data/test_fixtures/csv-dev/mvp.csv"; + fn test_csv_read_local_no_headers() -> DaftResult<()> { + let file = format!( + "{}/test/iris_tiny_no_headers.csv", + env!("CARGO_MANIFEST_DIR"), + ); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let table = read_csv(file, None, None, None, true, None, io_client, None, true)?; - assert_eq!(table.len(), 100); + let column_names = vec![ + "sepal.length", + "sepal.width", + "petal.length", + "petal.width", + "variety", + ]; + let table = read_csv( + file.as_ref(), + Some(column_names.clone()), + None, + None, + false, + None, + io_client, + None, + true, + None, + None, + None, + None, + )?; + assert_eq!(table.len(), 20); assert_eq!( table.schema, Schema::new(vec![ - Field::new("a", DataType::Int64), - Field::new("b", DataType::Utf8) + Field::new("sepal.length", DataType::Float64), + Field::new("sepal.width", DataType::Float64), + Field::new("petal.length", DataType::Float64), + Field::new("petal.width", DataType::Float64), + Field::new("variety", DataType::Utf8), ])? 
.into(), ); + check_equal_local_arrow2( + file.as_ref(), + &table, + false, + None, + Some(column_names), + None, + None, + ); Ok(()) } #[test] - fn test_csv_read_from_s3_larger_than_batch_size() -> DaftResult<()> { - let file = "s3://daft-public-data/test_fixtures/csv-dev/medium.csv"; + fn test_csv_read_local_delimiter() -> DaftResult<()> { + let file = format!( + "{}/test/iris_tiny_bar_delimiter.csv", + env!("CARGO_MANIFEST_DIR"), + ); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let table = read_csv(file, None, None, None, true, None, io_client, None, true)?; - assert_eq!(table.len(), 5000); + let table = read_csv( + file.as_ref(), + None, + None, + Some(5), + true, + Some(b'|'), + io_client, + None, + true, + None, + None, + None, + None, + )?; + assert_eq!(table.len(), 5); + assert_eq!( + table.schema, + Schema::new(vec![ + Field::new("sepal.length", DataType::Float64), + Field::new("sepal.width", DataType::Float64), + Field::new("petal.length", DataType::Float64), + Field::new("petal.width", DataType::Float64), + Field::new("variety", DataType::Utf8), + ])? + .into(), + ); + check_equal_local_arrow2(file.as_ref(), &table, true, Some(b'|'), None, None, Some(5)); Ok(()) } #[test] - fn test_csv_read_from_s3_limit() -> DaftResult<()> { - let file = "s3://daft-public-data/test_fixtures/csv-dev/mvp.csv"; + fn test_csv_read_local_limit() -> DaftResult<()> { + let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"),); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true; @@ -254,32 +679,40 @@ mod tests { let io_client = Arc::new(IOClient::new(io_config.into())?); let table = read_csv( - file, + file.as_ref(), None, None, - Some(10), + Some(5), true, None, io_client, None, true, + None, + None, + None, + None, )?; - assert_eq!(table.len(), 10); + assert_eq!(table.len(), 5); assert_eq!( table.schema, Schema::new(vec![ - Field::new("a", DataType::Int64), - Field::new("b", DataType::Utf8) + Field::new("sepal.length", DataType::Float64), + Field::new("sepal.width", DataType::Float64), + Field::new("petal.length", DataType::Float64), + Field::new("petal.width", DataType::Float64), + Field::new("variety", DataType::Utf8), ])? .into(), ); + check_equal_local_arrow2(file.as_ref(), &table, true, None, None, None, Some(5)); Ok(()) } #[test] - fn test_csv_read_from_s3_projection() -> DaftResult<()> { - let file = "s3://daft-public-data/test_fixtures/csv-dev/mvp.csv"; + fn test_csv_read_local_projection() -> DaftResult<()> { + let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"),); let mut io_config = IOConfig::default(); io_config.s3.anonymous = true; @@ -287,21 +720,716 @@ mod tests { let io_client = Arc::new(IOClient::new(io_config.into())?); let table = read_csv( - file, + file.as_ref(), None, - Some(vec!["b"]), + Some(vec!["petal.length", "petal.width"]), None, true, None, io_client, None, true, + None, + None, + None, + None, )?; - assert_eq!(table.len(), 100); + assert_eq!(table.len(), 20); assert_eq!( table.schema, - Schema::new(vec![Field::new("b", DataType::Utf8)])?.into(), + Schema::new(vec![ + Field::new("petal.length", DataType::Float64), + Field::new("petal.width", DataType::Float64), + ])? 
+ .into(), + ); + check_equal_local_arrow2( + file.as_ref(), + &table, + true, + None, + None, + Some(vec![2, 3]), + None, + ); + + Ok(()) + } + + #[test] + fn test_csv_read_local_no_headers_and_projection() -> DaftResult<()> { + let file = format!( + "{}/test/iris_tiny_no_headers.csv", + env!("CARGO_MANIFEST_DIR"), + ); + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let column_names = vec![ + "sepal.length", + "sepal.width", + "petal.length", + "petal.width", + "variety", + ]; + let table = read_csv( + file.as_ref(), + Some(column_names.clone()), + Some(vec!["petal.length", "petal.width"]), + None, + false, + None, + io_client, + None, + true, + None, + None, + None, + None, + )?; + assert_eq!(table.len(), 20); + assert_eq!( + table.schema, + Schema::new(vec![ + Field::new("petal.length", DataType::Float64), + Field::new("petal.width", DataType::Float64), + ])? + .into(), + ); + check_equal_local_arrow2( + file.as_ref(), + &table, + false, + None, + Some(column_names), + Some(vec![2, 3]), + None, + ); + + Ok(()) + } + + #[test] + fn test_csv_read_local_larger_than_buffer_size() -> DaftResult<()> { + let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"),); + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let table = read_csv( + file.as_ref(), + None, + None, + None, + true, + None, + io_client, + None, + true, + None, + Some(128), + None, + None, + )?; + assert_eq!(table.len(), 20); + assert_eq!( + table.schema, + Schema::new(vec![ + Field::new("sepal.length", DataType::Float64), + Field::new("sepal.width", DataType::Float64), + Field::new("petal.length", DataType::Float64), + Field::new("petal.width", DataType::Float64), + Field::new("variety", DataType::Utf8), + ])? + .into(), ); + check_equal_local_arrow2(file.as_ref(), &table, true, None, None, None, None); + + Ok(()) + } + + #[test] + fn test_csv_read_local_larger_than_chunk_size() -> DaftResult<()> { + let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"),); + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let table = read_csv( + file.as_ref(), + None, + None, + None, + true, + None, + io_client, + None, + true, + None, + None, + Some(100), + None, + )?; + assert_eq!(table.len(), 20); + assert_eq!( + table.schema, + Schema::new(vec![ + Field::new("sepal.length", DataType::Float64), + Field::new("sepal.width", DataType::Float64), + Field::new("petal.length", DataType::Float64), + Field::new("petal.width", DataType::Float64), + Field::new("variety", DataType::Utf8), + ])? 
+ .into(), + ); + check_equal_local_arrow2(file.as_ref(), &table, true, None, None, None, None); + + Ok(()) + } + + #[test] + fn test_csv_read_local_throttled_streaming() -> DaftResult<()> { + let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"),); + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let table = read_csv( + file.as_ref(), + None, + None, + None, + true, + None, + io_client, + None, + true, + None, + None, + None, + Some(5), + )?; + assert_eq!(table.len(), 20); + assert_eq!( + table.schema, + Schema::new(vec![ + Field::new("sepal.length", DataType::Float64), + Field::new("sepal.width", DataType::Float64), + Field::new("petal.length", DataType::Float64), + Field::new("petal.width", DataType::Float64), + Field::new("variety", DataType::Utf8), + ])? + .into(), + ); + check_equal_local_arrow2(file.as_ref(), &table, true, None, None, None, None); + + Ok(()) + } + + #[test] + fn test_csv_read_local_nulls() -> DaftResult<()> { + let file = format!("{}/test/iris_tiny_nulls.csv", env!("CARGO_MANIFEST_DIR"),); + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let table = read_csv( + file.as_ref(), + None, + None, + None, + true, + None, + io_client, + None, + true, + None, + None, + None, + None, + )?; + assert_eq!(table.len(), 6); + assert_eq!( + table.schema, + Schema::new(vec![ + Field::new("sepal.length", DataType::Float64), + Field::new("sepal.width", DataType::Float64), + Field::new("petal.length", DataType::Float64), + Field::new("petal.width", DataType::Float64), + Field::new("variety", DataType::Utf8), + ])? + .into(), + ); + check_equal_local_arrow2(file.as_ref(), &table, true, None, None, None, None); + + Ok(()) + } + + #[test] + fn test_csv_read_local_empty_lines_dropped() -> DaftResult<()> { + let file = format!( + "{}/test/iris_tiny_empty_lines.csv", + env!("CARGO_MANIFEST_DIR"), + ); + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let table = read_csv( + file.as_ref(), + None, + None, + None, + true, + None, + io_client, + None, + true, + None, + None, + None, + None, + )?; + assert_eq!(table.len(), 3); + assert_eq!( + table.schema, + Schema::new(vec![ + Field::new("sepal.length", DataType::Float64), + Field::new("sepal.width", DataType::Float64), + Field::new("petal.length", DataType::Float64), + Field::new("petal.width", DataType::Float64), + Field::new("variety", DataType::Utf8), + ])? + .into(), + ); + check_equal_local_arrow2(file.as_ref(), &table, true, None, None, None, None); + + Ok(()) + } + + #[test] + fn test_csv_read_local_wrong_type_yields_nulls() -> DaftResult<()> { + let file = format!("{}/test/iris_tiny.csv", env!("CARGO_MANIFEST_DIR"),); + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let schema = Schema::new(vec![ + // Conversion to all of these types should fail, resulting in nulls. 
+ Field::new("sepal.length", DataType::Boolean), + Field::new("sepal.width", DataType::Boolean), + Field::new("petal.length", DataType::Boolean), + Field::new("petal.width", DataType::Boolean), + Field::new("variety", DataType::Int64), + ])?; + let table = read_csv( + file.as_ref(), + None, + None, + None, + true, + None, + io_client, + None, + true, + Some(schema.into()), + None, + None, + None, + )?; + let num_rows = table.len(); + assert_eq!(num_rows, 20); + // Check that all columns are all null. + for idx in 0..table.num_columns() { + let column = table.get_column_by_index(idx)?; + assert_eq!(column.to_arrow().null_count(), num_rows); + } + + Ok(()) + } + + #[test] + fn test_csv_read_local_invalid_cols_header_mismatch() -> DaftResult<()> { + let file = format!( + "{}/test/iris_tiny_invalid_header_cols_mismatch.csv", + env!("CARGO_MANIFEST_DIR"), + ); + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let err = read_csv( + file.as_ref(), + None, + None, + None, + true, + None, + io_client, + None, + true, + None, + None, + None, + None, + ); + assert!(err.is_err()); + let err = err.unwrap_err(); + assert!(matches!(err, DaftError::ArrowError(_)), "{}", err); + assert!( + err.to_string() + .contains("found record with 4 fields, but the previous record has 5 fields"), + "{}", + err + ); + + Ok(()) + } + + #[test] + fn test_csv_read_local_invalid_no_header_variable_num_cols() -> DaftResult<()> { + let file = format!( + "{}/test/iris_tiny_invalid_no_header_variable_num_cols.csv", + env!("CARGO_MANIFEST_DIR"), + ); + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let err = read_csv( + file.as_ref(), + None, + None, + None, + false, + None, + io_client, + None, + true, + None, + None, + None, + None, + ); + assert!(err.is_err()); + let err = err.unwrap_err(); + assert!(matches!(err, DaftError::ArrowError(_)), "{}", err); + assert!( + err.to_string() + .contains("found record with 5 fields, but the previous record has 4 fields"), + "{}", + err + ); + + Ok(()) + } + + #[rstest] + fn test_csv_read_s3_compression( + #[values( + // Uncompressed + None, + // brotli + Some("br"), + // bzip2 + Some("bz2"), + // deflate + Some("deflate"), + // gzip + Some("gz"), + // lzma + Some("lzma"), + // xz + Some("xz"), + // zlib + Some("zl"), + // zstd + Some("zst"), + )] + compression: Option<&str>, + ) -> DaftResult<()> { + let file = format!( + "s3://daft-public-data/test_fixtures/csv-dev/mvp.csv{}", + compression.map_or("".to_string(), |ext| format!(".{}", ext)) + ); + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let table = read_csv( + file.as_ref(), + None, + None, + None, + true, + None, + io_client, + None, + true, + None, + None, + None, + None, + )?; + assert_eq!(table.len(), 100); + assert_eq!( + table.schema, + Schema::new(vec![ + Field::new("a", DataType::Int64), + Field::new("b", DataType::Utf8) + ])? 
+ .into(), + ); + + Ok(()) + } + + #[test] + fn test_csv_read_s3_no_headers() -> DaftResult<()> { + let file = "s3://daft-public-data/test_fixtures/csv-dev/mvp_no_header.csv"; + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let column_names = vec!["a", "b"]; + let table = read_csv( + file.as_ref(), + Some(column_names.clone()), + None, + None, + false, + None, + io_client, + None, + true, + None, + None, + None, + None, + )?; + assert_eq!(table.len(), 100); + assert_eq!( + table.schema, + Schema::new(vec![ + Field::new("a", DataType::Int64), + Field::new("b", DataType::Utf8) + ])? + .into(), + ); + + Ok(()) + } + + #[test] + fn test_csv_read_s3_no_headers_and_projection() -> DaftResult<()> { + let file = "s3://daft-public-data/test_fixtures/csv-dev/mvp_no_header.csv"; + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let column_names = vec!["a", "b"]; + let table = read_csv( + file.as_ref(), + Some(column_names.clone()), + Some(vec!["b"]), + None, + false, + None, + io_client, + None, + true, + None, + None, + None, + None, + )?; + assert_eq!(table.len(), 100); + assert_eq!( + table.schema, + Schema::new(vec![Field::new("b", DataType::Utf8)])?.into(), + ); + + Ok(()) + } + + #[test] + fn test_csv_read_s3_limit() -> DaftResult<()> { + let file = "s3://daft-public-data/test_fixtures/csv-dev/mvp.csv"; + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let table = read_csv( + file, + None, + None, + Some(10), + true, + None, + io_client, + None, + true, + None, + None, + None, + None, + )?; + assert_eq!(table.len(), 10); + assert_eq!( + table.schema, + Schema::new(vec![ + Field::new("a", DataType::Int64), + Field::new("b", DataType::Utf8) + ])? 
+ .into(), + ); + + Ok(()) + } + + #[test] + fn test_csv_read_s3_projection() -> DaftResult<()> { + let file = "s3://daft-public-data/test_fixtures/csv-dev/mvp.csv"; + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let table = read_csv( + file, + None, + Some(vec!["b"]), + None, + true, + None, + io_client, + None, + true, + None, + None, + None, + None, + )?; + assert_eq!(table.len(), 100); + assert_eq!( + table.schema, + Schema::new(vec![Field::new("b", DataType::Utf8)])?.into(), + ); + + Ok(()) + } + + #[test] + fn test_csv_read_s3_larger_than_buffer_size() -> DaftResult<()> { + let file = "s3://daft-public-data/test_fixtures/csv-dev/medium.csv"; + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let table = read_csv( + file, + None, + None, + None, + true, + None, + io_client, + None, + true, + None, + Some(100), + None, + None, + )?; + assert_eq!(table.len(), 5000); + + Ok(()) + } + + #[test] + fn test_csv_read_s3_larger_than_chunk_size() -> DaftResult<()> { + let file = "s3://daft-public-data/test_fixtures/csv-dev/medium.csv"; + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let table = read_csv( + file, + None, + None, + None, + true, + None, + io_client, + None, + true, + None, + None, + Some(100), + None, + )?; + assert_eq!(table.len(), 5000); + + Ok(()) + } + + #[test] + fn test_csv_read_s3_throttled_streaming() -> DaftResult<()> { + let file = "s3://daft-public-data/test_fixtures/csv-dev/medium.csv"; + + let mut io_config = IOConfig::default(); + io_config.s3.anonymous = true; + + let io_client = Arc::new(IOClient::new(io_config.into())?); + + let table = read_csv( + file, + None, + None, + None, + true, + None, + io_client, + None, + true, + None, + None, + None, + Some(5), + )?; + assert_eq!(table.len(), 5000); Ok(()) } diff --git a/src/daft-csv/test/iris_tiny.csv b/src/daft-csv/test/iris_tiny.csv new file mode 100644 index 0000000000..5d15229a3f --- /dev/null +++ b/src/daft-csv/test/iris_tiny.csv @@ -0,0 +1,21 @@ +"sepal.length","sepal.width","petal.length","petal.width","variety" +5.1,3.5,1.4,.2,"Setosa" +4.9,3,1.4,.2,"Setosa" +4.7,3.2,1.3,.2,"Setosa" +4.6,3.1,1.5,.2,"Setosa" +5,3.6,1.4,.2,"Setosa" +5.4,3.9,1.7,.4,"Setosa" +4.6,3.4,1.4,.3,"Setosa" +5,3.4,1.5,.2,"Setosa" +4.4,2.9,1.4,.2,"Setosa" +4.9,3.1,1.5,.1,"Setosa" +5.4,3.7,1.5,.2,"Setosa" +4.8,3.4,1.6,.2,"Setosa" +4.8,3,1.4,.1,"Setosa" +4.3,3,1.1,.1,"Setosa" +5.8,4,1.2,.2,"Setosa" +5.7,4.4,1.5,.4,"Setosa" +5.4,3.9,1.3,.4,"Setosa" +5.1,3.5,1.4,.3,"Setosa" +5.7,3.8,1.7,.3,"Setosa" +5.1,3.8,1.5,.3,"Setosa" diff --git a/src/daft-csv/test/iris_tiny.csv.br b/src/daft-csv/test/iris_tiny.csv.br new file mode 100644 index 0000000000..0df01da799 Binary files /dev/null and b/src/daft-csv/test/iris_tiny.csv.br differ diff --git a/src/daft-csv/test/iris_tiny.csv.bz2 b/src/daft-csv/test/iris_tiny.csv.bz2 new file mode 100644 index 0000000000..415a354b59 Binary files /dev/null and b/src/daft-csv/test/iris_tiny.csv.bz2 differ diff --git a/src/daft-csv/test/iris_tiny.csv.deflate b/src/daft-csv/test/iris_tiny.csv.deflate new file mode 100644 index 0000000000..cb9efee30c --- /dev/null +++ b/src/daft-csv/test/iris_tiny.csv.deflate @@ -0,0 +1,2 @@ +mQ + { ;AaVV%|&jN]oLb]̱_sc:oz@`sm2P턓pw /ďbpn撥1yC3g;֔EgpjLέI/h!!1+5> \ No newline at end of file diff --git 
a/src/daft-csv/test/iris_tiny.csv.gz b/src/daft-csv/test/iris_tiny.csv.gz new file mode 100644 index 0000000000..478749dc1c Binary files /dev/null and b/src/daft-csv/test/iris_tiny.csv.gz differ diff --git a/src/daft-csv/test/iris_tiny.csv.lzma b/src/daft-csv/test/iris_tiny.csv.lzma new file mode 100644 index 0000000000..b43b8bcdaf Binary files /dev/null and b/src/daft-csv/test/iris_tiny.csv.lzma differ diff --git a/src/daft-csv/test/iris_tiny.csv.xz b/src/daft-csv/test/iris_tiny.csv.xz new file mode 100644 index 0000000000..9a61d44473 Binary files /dev/null and b/src/daft-csv/test/iris_tiny.csv.xz differ diff --git a/src/daft-csv/test/iris_tiny.csv.zl b/src/daft-csv/test/iris_tiny.csv.zl new file mode 100644 index 0000000000..cc742f2251 --- /dev/null +++ b/src/daft-csv/test/iris_tiny.csv.zl @@ -0,0 +1,2 @@ +xmQ + { ;AaVV%|&jN]oLb]̱_sc:oz@`sm2P턓pw /ďbpn撥1yC3g;֔EgpjLέI/h!!1+5> \ No newline at end of file diff --git a/src/daft-csv/test/iris_tiny.csv.zst b/src/daft-csv/test/iris_tiny.csv.zst new file mode 100644 index 0000000000..a5160e2852 Binary files /dev/null and b/src/daft-csv/test/iris_tiny.csv.zst differ diff --git a/src/daft-csv/test/iris_tiny_bar_delimiter.csv b/src/daft-csv/test/iris_tiny_bar_delimiter.csv new file mode 100644 index 0000000000..e0c5acb47e --- /dev/null +++ b/src/daft-csv/test/iris_tiny_bar_delimiter.csv @@ -0,0 +1,21 @@ +"sepal.length"|"sepal.width"|"petal.length"|"petal.width"|"variety" +5.1|3.5|1.4|.2|"Setosa" +4.9|3|1.4|.2|"Setosa" +4.7|3.2|1.3|.2|"Setosa" +4.6|3.1|1.5|.2|"Setosa" +5|3.6|1.4|.2|"Setosa" +5.4|3.9|1.7|.4|"Setosa" +4.6|3.4|1.4|.3|"Setosa" +5|3.4|1.5|.2|"Setosa" +4.4|2.9|1.4|.2|"Setosa" +4.9|3.1|1.5|.1|"Setosa" +5.4|3.7|1.5|.2|"Setosa" +4.8|3.4|1.6|.2|"Setosa" +4.8|3|1.4|.1|"Setosa" +4.3|3|1.1|.1|"Setosa" +5.8|4|1.2|.2|"Setosa" +5.7|4.4|1.5|.4|"Setosa" +5.4|3.9|1.3|.4|"Setosa" +5.1|3.5|1.4|.3|"Setosa" +5.7|3.8|1.7|.3|"Setosa" +5.1|3.8|1.5|.3|"Setosa" diff --git a/src/daft-csv/test/iris_tiny_conflicting_dtypes.csv b/src/daft-csv/test/iris_tiny_conflicting_dtypes.csv new file mode 100644 index 0000000000..0bb2781bce --- /dev/null +++ b/src/daft-csv/test/iris_tiny_conflicting_dtypes.csv @@ -0,0 +1,3 @@ +"sepal.length","sepal.width","petal.length","petal.width","variety" +5.1,"bar",1.4,"quux","Setosa" +"foo",3,"baz",.2,false diff --git a/src/daft-csv/test/iris_tiny_empty_lines.csv b/src/daft-csv/test/iris_tiny_empty_lines.csv new file mode 100644 index 0000000000..3df0e5d1f8 --- /dev/null +++ b/src/daft-csv/test/iris_tiny_empty_lines.csv @@ -0,0 +1,6 @@ +"sepal.length","sepal.width","petal.length","petal.width","variety" +5.1,3.5,1.4,.2,"Setosa" + +4.9,3,1.4,.2,"Setosa" + +4.7,3.2,1.3,.2,"Setosa" diff --git a/src/daft-csv/test/iris_tiny_invalid_header_cols_mismatch.csv b/src/daft-csv/test/iris_tiny_invalid_header_cols_mismatch.csv new file mode 100644 index 0000000000..2361cb5866 --- /dev/null +++ b/src/daft-csv/test/iris_tiny_invalid_header_cols_mismatch.csv @@ -0,0 +1,4 @@ +"sepal.length","sepal.width","petal.length","petal.width","variety" +5.1,3.5,1.4,.2 +4.9,3,1.4,.2 +4.7,3.2,1.3,.2 diff --git a/src/daft-csv/test/iris_tiny_invalid_no_header_variable_num_cols.csv b/src/daft-csv/test/iris_tiny_invalid_no_header_variable_num_cols.csv new file mode 100644 index 0000000000..f09824a737 --- /dev/null +++ b/src/daft-csv/test/iris_tiny_invalid_no_header_variable_num_cols.csv @@ -0,0 +1,3 @@ +5.1,3.5,1.4,.2 +4.9,3,1.4,.2,"Seratosa" +4.7,3.2,1.3,.2 diff --git a/src/daft-csv/test/iris_tiny_no_headers.csv b/src/daft-csv/test/iris_tiny_no_headers.csv new file 
mode 100644 index 0000000000..57714e2ca2 --- /dev/null +++ b/src/daft-csv/test/iris_tiny_no_headers.csv @@ -0,0 +1,20 @@ +5.1,3.5,1.4,.2,"Setosa" +4.9,3,1.4,.2,"Setosa" +4.7,3.2,1.3,.2,"Setosa" +4.6,3.1,1.5,.2,"Setosa" +5,3.6,1.4,.2,"Setosa" +5.4,3.9,1.7,.4,"Setosa" +4.6,3.4,1.4,.3,"Setosa" +5,3.4,1.5,.2,"Setosa" +4.4,2.9,1.4,.2,"Setosa" +4.9,3.1,1.5,.1,"Setosa" +5.4,3.7,1.5,.2,"Setosa" +4.8,3.4,1.6,.2,"Setosa" +4.8,3,1.4,.1,"Setosa" +4.3,3,1.1,.1,"Setosa" +5.8,4,1.2,.2,"Setosa" +5.7,4.4,1.5,.4,"Setosa" +5.4,3.9,1.3,.4,"Setosa" +5.1,3.5,1.4,.3,"Setosa" +5.7,3.8,1.7,.3,"Setosa" +5.1,3.8,1.5,.3,"Setosa" diff --git a/src/daft-csv/test/iris_tiny_nulls.csv b/src/daft-csv/test/iris_tiny_nulls.csv new file mode 100644 index 0000000000..5773543eab --- /dev/null +++ b/src/daft-csv/test/iris_tiny_nulls.csv @@ -0,0 +1,7 @@ +"sepal.length","sepal.width","petal.length","petal.width","variety" +5.1,3.5,,.2,"Setosa" +4.9,3,1.4,,"Setosa" +4.7,3.2,1.3,.2,"Setosa" +4.6,,1.5,.2,"Setosa" +,.6,1.4,.2,"Setosa" +5.4,3.9,1.7,.4, diff --git a/src/daft-io/Cargo.toml b/src/daft-io/Cargo.toml index 0ee10318f9..c63b74fc25 100644 --- a/src/daft-io/Cargo.toml +++ b/src/daft-io/Cargo.toml @@ -32,7 +32,7 @@ serde_json = {workspace = true} snafu = {workspace = true} tokio = {workspace = true} tokio-stream = {workspace = true} -url = "2.4.0" +url = {workspace = true} [dependencies.reqwest] default-features = false diff --git a/src/daft-plan/src/source_info/file_format.rs b/src/daft-plan/src/source_info/file_format.rs index 1e4988f4d5..2bcb542f0e 100644 --- a/src/daft-plan/src/source_info/file_format.rs +++ b/src/daft-plan/src/source_info/file_format.rs @@ -79,6 +79,8 @@ impl_bincode_py_state_serialization!(ParquetSourceConfig); pub struct CsvSourceConfig { pub delimiter: String, pub has_headers: bool, + pub buffer_size: Option<usize>, + pub chunk_size: Option<usize>, } #[cfg(feature = "python")] @@ -90,11 +92,20 @@ impl CsvSourceConfig { /// /// * `delimiter` - The character delimiting individual cells in the CSV data. /// * `has_headers` - Whether the CSV has a header row; if so, it will be skipped during data parsing. + /// * `buffer_size` - Size of the buffer (in bytes) used by the streaming reader. + /// * `chunk_size` - Size of the chunks (in bytes) deserialized in parallel by the streaming reader. #[new] - fn new(delimiter: String, has_headers: bool) -> Self { + fn new( + delimiter: String, + has_headers: bool, + buffer_size: Option<usize>, + chunk_size: Option<usize>, + ) -> Self { Self { delimiter, has_headers, + buffer_size, + chunk_size, } } } diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 1f4df77e47..e8fd10537c 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -23,7 +23,7 @@ pub mod python; #[cfg(feature = "python")] pub use python::register_modules; -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct Table { pub schema: SchemaRef, columns: Vec<Series>,