From f5bcd4d2c468806f42e3a983c0fab08dbe4d39e6 Mon Sep 17 00:00:00 2001
From: Colin Ho
Date: Thu, 31 Oct 2024 13:21:33 -0700
Subject: [PATCH] [FEAT] Streaming physical writes for native executor (#2992)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Streaming writes for swordfish (Parquet + CSV only). Iceberg and Delta writes are here: https://github.com/Eventual-Inc/Daft/pull/2966

Implement streaming writes as a blocking sink. Unpartitioned writes run with 1 worker, and partitioned writes run with NUM_CPUS workers. As a drive-by, this also makes blocking sinks parallelizable.

**Behaviour**
- Unpartitioned: Writes go to a `TargetFileSizeWriter`, which manages file sizes and row group sizes as data is streamed in.
- Partitioned: Data is partitioned via a `Dispatcher` and sent to workers based on the hash. Each worker runs a `PartitionedWriter` that manages partitioning by value, file sizes, and row group sizes.

**Benchmarks:**
I added a new benchmark suite in `tests/benchmarks/test_streaming_writes.py`. It tests writes of TPC-H lineitem to Parquet/CSV, with and without partition columns, and with different file / row group sizes. The streaming executor performs much better when there are partition columns, as seen in the screenshot below. Without partition columns it is about the same; when the target row group size / file size is decreased, it is slightly slower, likely because it does more slicing, but this needs more investigation. Memory usage is the same for both.

[Screenshot: benchmark results, 2024-10-03 at 11.22.32 AM]

Memory test on read->write Parquet TPC-H lineitem sf1:

Native:
[Screenshot: 2024-10-08 at 1.48.34 PM]

Python:
[Screenshot: 2024-10-08 at 1.48.50 PM]

---------

Co-authored-by: Colin Ho
Co-authored-by: Colin Ho
Co-authored-by: Colin Ho
---
 Cargo.lock | 20 ++
 Cargo.toml | 11 +
 daft/io/writer.py | 159 +++++++++++++
 daft/table/partitioning.py | 38 ++--
 src/daft-core/src/utils/identity_hash_set.rs | 13 +-
 src/daft-io/src/lib.rs | 1 +
 src/daft-local-execution/Cargo.toml | 4 +-
 src/daft-local-execution/src/buffer.rs | 82 +++++++
 src/daft-local-execution/src/dispatcher.rs | 102 +++++++++
 .../src/intermediate_ops/buffer.rs | 91 --------
 .../src/intermediate_ops/intermediate_op.rs | 21 +-
 .../src/intermediate_ops/mod.rs | 1 -
 src/daft-local-execution/src/lib.rs | 2 +
 src/daft-local-execution/src/pipeline.rs | 81 +++++--
 src/daft-local-execution/src/run.rs | 2 +-
 .../src/sinks/aggregate.rs | 98 +++++---
 .../src/sinks/blocking_sink.rs | 134 ++++++++---
 .../src/sinks/hash_join_build.rs | 61 +++--
 src/daft-local-execution/src/sinks/mod.rs | 1 +
 src/daft-local-execution/src/sinks/sort.rs | 105 +++++----
 src/daft-local-execution/src/sinks/write.rs | 137 +++++++++++
 src/daft-micropartition/src/ops/concat.rs | 24 +-
 src/daft-micropartition/src/python.rs | 4 +-
 src/daft-physical-plan/Cargo.toml | 3 +
 src/daft-physical-plan/src/local_plan.rs | 27 ++-
 src/daft-physical-plan/src/translate.rs | 15 ++
 src/daft-plan/src/lib.rs | 2 +-
 src/daft-plan/src/sink_info.rs | 1 -
 src/daft-table/src/lib.rs | 17 ++
 src/daft-table/src/ops/hash.rs | 18 +-
 src/daft-table/src/probeable/probe_set.rs | 4 +-
 src/daft-table/src/probeable/probe_table.rs | 4 +-
 src/daft-writers/Cargo.toml | 22 ++
 src/daft-writers/src/batch.rs | 195 ++++++++++++++++
 src/daft-writers/src/file.rs | 215 ++++++++++++++++++
 src/daft-writers/src/lib.rs | 127 +++++++++++
 src/daft-writers/src/partition.rs | 157 +++++++++++++
 src/daft-writers/src/physical.rs | 77 +++++++
 src/daft-writers/src/python.rs | 118 ++++++++++
 src/daft-writers/src/test.rs |
84 +++++++ tests/benchmarks/conftest.py | 77 +++++++ tests/benchmarks/test_local_tpch.py | 83 +------ tests/benchmarks/test_streaming_writes.py | 68 ++++++ tests/cookbook/test_write.py | 18 +- tests/dataframe/test_decimals.py | 6 - tests/dataframe/test_temporals.py | 7 +- tests/io/test_csv_roundtrip.py | 6 +- tests/io/test_parquet_roundtrip.py | 7 - tests/io/test_s3_credentials_refresh.py | 6 - 49 files changed, 2135 insertions(+), 421 deletions(-) create mode 100644 daft/io/writer.py create mode 100644 src/daft-local-execution/src/buffer.rs create mode 100644 src/daft-local-execution/src/dispatcher.rs delete mode 100644 src/daft-local-execution/src/intermediate_ops/buffer.rs create mode 100644 src/daft-local-execution/src/sinks/write.rs create mode 100644 src/daft-writers/Cargo.toml create mode 100644 src/daft-writers/src/batch.rs create mode 100644 src/daft-writers/src/file.rs create mode 100644 src/daft-writers/src/lib.rs create mode 100644 src/daft-writers/src/partition.rs create mode 100644 src/daft-writers/src/physical.rs create mode 100644 src/daft-writers/src/python.rs create mode 100644 src/daft-writers/src/test.rs create mode 100644 tests/benchmarks/test_streaming_writes.py diff --git a/Cargo.lock b/Cargo.lock index 8500d3f053..f5a1ef9c25 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1803,12 +1803,14 @@ dependencies = [ "daft-micropartition", "daft-minhash", "daft-parquet", + "daft-physical-plan", "daft-plan", "daft-scan", "daft-scheduler", "daft-sql", "daft-stats", "daft-table", + "daft-writers", "lazy_static", "libc", "log", @@ -2084,6 +2086,7 @@ dependencies = [ name = "daft-local-execution" version = "0.3.0-dev0" dependencies = [ + "async-trait", "common-daft-config", "common-display", "common-error", @@ -2102,6 +2105,7 @@ dependencies = [ "daft-plan", "daft-scan", "daft-table", + "daft-writers", "futures", "indexmap 2.5.0", "lazy_static", @@ -2363,6 +2367,22 @@ dependencies = [ "serde", ] +[[package]] +name = "daft-writers" +version = "0.3.0-dev0" +dependencies = [ + "common-daft-config", + "common-error", + "common-file-formats", + "daft-core", + "daft-dsl", + "daft-io", + "daft-micropartition", + "daft-plan", + "daft-table", + "pyo3", +] + [[package]] name = "deflate64" version = "0.1.9" diff --git a/Cargo.toml b/Cargo.toml index 0fe63cbb28..6b0b24dbe0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,12 +22,14 @@ daft-local-execution = {path = "src/daft-local-execution", default-features = fa daft-micropartition = {path = "src/daft-micropartition", default-features = false} daft-minhash = {path = "src/daft-minhash", default-features = false} daft-parquet = {path = "src/daft-parquet", default-features = false} +daft-physical-plan = {path = "src/daft-physical-plan", default-features = false} daft-plan = {path = "src/daft-plan", default-features = false} daft-scan = {path = "src/daft-scan", default-features = false} daft-scheduler = {path = "src/daft-scheduler", default-features = false} daft-sql = {path = "src/daft-sql", default-features = false} daft-stats = {path = "src/daft-stats", default-features = false} daft-table = {path = "src/daft-table", default-features = false} +daft-writers = {path = "src/daft-writers", default-features = false} lazy_static = {workspace = true} log = {workspace = true} lzma-sys = {version = "*", features = ["static"]} @@ -53,12 +55,20 @@ python = [ "daft-local-execution/python", "daft-micropartition/python", "daft-parquet/python", + "daft-physical-plan/python", "daft-plan/python", "daft-scan/python", "daft-scheduler/python", "daft-sql/python", 
"daft-stats/python", "daft-table/python", + "daft-functions/python", + "daft-functions-json/python", + "daft-writers/python", + "common-daft-config/python", + "common-system-info/python", + "common-display/python", + "common-resource-request/python", "dep:pyo3", "dep:pyo3-log" ] @@ -141,6 +151,7 @@ members = [ "src/daft-scheduler", "src/daft-sketch", "src/daft-sql", + "src/daft-writers", "src/daft-table", "src/hyperloglog", "src/parquet2" diff --git a/daft/io/writer.py b/daft/io/writer.py new file mode 100644 index 0000000000..a3e99046a8 --- /dev/null +++ b/daft/io/writer.py @@ -0,0 +1,159 @@ +import uuid +from abc import ABC, abstractmethod +from typing import Optional + +from daft.daft import IOConfig +from daft.dependencies import pa, pacsv, pq +from daft.filesystem import ( + _resolve_paths_and_filesystem, + canonicalize_protocol, + get_protocol_from_path, +) +from daft.series import Series +from daft.table.micropartition import MicroPartition +from daft.table.partitioning import ( + partition_strings_to_path, + partition_values_to_str_mapping, +) +from daft.table.table import Table + + +class FileWriterBase(ABC): + def __init__( + self, + root_dir: str, + file_idx: int, + file_format: str, + partition_values: Optional[Table] = None, + compression: Optional[str] = None, + io_config: Optional[IOConfig] = None, + default_partition_fallback: str = "__HIVE_DEFAULT_PARTITION__", + ): + [self.resolved_path], self.fs = _resolve_paths_and_filesystem(root_dir, io_config=io_config) + protocol = get_protocol_from_path(root_dir) + canonicalized_protocol = canonicalize_protocol(protocol) + is_local_fs = canonicalized_protocol == "file" + + self.file_name = f"{uuid.uuid4()}-{file_idx}.{file_format}" + self.partition_values = partition_values + if self.partition_values is not None: + partition_strings = { + key: values.to_pylist()[0] + for key, values in partition_values_to_str_mapping(self.partition_values).items() + } + self.dir_path = partition_strings_to_path(self.resolved_path, partition_strings, default_partition_fallback) + else: + self.dir_path = f"{self.resolved_path}" + + self.full_path = f"{self.dir_path}/{self.file_name}" + if is_local_fs: + self.fs.create_dir(self.dir_path, recursive=True) + + self.compression = compression if compression is not None else "none" + + @abstractmethod + def write(self, table: MicroPartition) -> None: + """Write data to the file using the appropriate writer. + + Args: + table: MicroPartition containing the data to be written. + """ + pass + + @abstractmethod + def close(self) -> Table: + """Close the writer and return metadata about the written file. Write should not be called after close. + + Returns: + Table containing metadata about the written file, including path and partition values. 
+ """ + pass + + +class ParquetFileWriter(FileWriterBase): + def __init__( + self, + root_dir: str, + file_idx: int, + partition_values: Optional[Table] = None, + compression: str = "none", + io_config: Optional[IOConfig] = None, + ): + super().__init__( + root_dir=root_dir, + file_idx=file_idx, + file_format="parquet", + partition_values=partition_values, + compression=compression, + io_config=io_config, + ) + self.is_closed = False + self.current_writer: Optional[pq.ParquetWriter] = None + + def _create_writer(self, schema: pa.Schema) -> pq.ParquetWriter: + return pq.ParquetWriter( + self.full_path, + schema, + compression=self.compression, + use_compliant_nested_type=False, + filesystem=self.fs, + ) + + def write(self, table: MicroPartition) -> None: + assert not self.is_closed, "Cannot write to a closed ParquetFileWriter" + if self.current_writer is None: + self.current_writer = self._create_writer(table.schema().to_pyarrow_schema()) + self.current_writer.write_table(table.to_arrow()) + + def close(self) -> Table: + if self.current_writer is not None: + self.current_writer.close() + + self.is_closed = True + metadata = {"path": Series.from_pylist([self.full_path])} + if self.partition_values is not None: + for col_name in self.partition_values.column_names(): + metadata[col_name] = self.partition_values.get_column(col_name) + return Table.from_pydict(metadata) + + +class CSVFileWriter(FileWriterBase): + def __init__( + self, + root_dir: str, + file_idx: int, + partition_values: Optional[Table] = None, + io_config: Optional[IOConfig] = None, + ): + super().__init__( + root_dir=root_dir, + file_idx=file_idx, + file_format="csv", + partition_values=partition_values, + io_config=io_config, + ) + self.current_writer: Optional[pacsv.CSVWriter] = None + self.is_closed = False + + def _create_writer(self, schema: pa.Schema) -> pacsv.CSVWriter: + return pacsv.CSVWriter( + self.full_path, + schema, + ) + + def write(self, table: MicroPartition) -> None: + assert not self.is_closed, "Cannot write to a closed CSVFileWriter" + if self.current_writer is None: + self.current_writer = self._create_writer(table.schema().to_pyarrow_schema()) + self.current_writer.write_table(table.to_arrow()) + + def close(self) -> Table: + if self.current_writer is not None: + self.current_writer.close() + + self.is_closed = True + metadata = {"path": Series.from_pylist([self.full_path])} + if self.partition_values is not None: + for col_name in self.partition_values.column_names(): + metadata[col_name] = self.partition_values.get_column(col_name) + return Table.from_pydict(metadata) diff --git a/daft/table/partitioning.py b/daft/table/partitioning.py index 70a590cb45..2333a198e5 100644 --- a/daft/table/partitioning.py +++ b/daft/table/partitioning.py @@ -1,13 +1,16 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from daft import Series from daft.expressions import ExpressionsProjection +from daft.table.table import Table from .micropartition import MicroPartition def partition_strings_to_path( - root_path: str, parts: Dict[str, str], partition_null_fallback: str = "__HIVE_DEFAULT_PARTITION__" + root_path: str, + parts: Dict[str, str], + partition_null_fallback: str = "__HIVE_DEFAULT_PARTITION__", ) -> str: keys = parts.keys() values = [partition_null_fallback if value is None else value for value in parts.values()] @@ -15,6 +18,25 @@ def partition_strings_to_path( return f"{root_path}/{postfix}" +def partition_values_to_str_mapping( + partition_values: Union[MicroPartition, 
Table], +) -> Dict[str, Series]: + null_part = Series.from_pylist( + [None] + ) # This is to ensure that the null values are replaced with the default_partition_fallback value + pkey_names = partition_values.column_names() + + partition_strings = {} + + for c in pkey_names: + column = partition_values.get_column(c) + string_names = column._to_str_values() + null_filled = column.is_null().if_else(null_part, string_names) + partition_strings[c] = null_filled + + return partition_strings + + class PartitionedTable: def __init__(self, table: MicroPartition, partition_keys: Optional[ExpressionsProjection]): self.table = table @@ -56,20 +78,10 @@ def partition_values_str(self) -> Optional[MicroPartition]: If the table is not partitioned, returns None. """ - null_part = Series.from_pylist([None]) partition_values = self.partition_values() if partition_values is None: return None else: - pkey_names = partition_values.column_names() - - partition_strings = {} - - for c in pkey_names: - column = partition_values.get_column(c) - string_names = column._to_str_values() - null_filled = column.is_null().if_else(null_part, string_names) - partition_strings[c] = null_filled - + partition_strings = partition_values_to_str_mapping(partition_values) return MicroPartition.from_pydict(partition_strings) diff --git a/src/daft-core/src/utils/identity_hash_set.rs b/src/daft-core/src/utils/identity_hash_set.rs index cadff5f774..71e021376c 100644 --- a/src/daft-core/src/utils/identity_hash_set.rs +++ b/src/daft-core/src/utils/identity_hash_set.rs @@ -1,4 +1,4 @@ -use std::hash::{BuildHasherDefault, Hasher}; +use std::hash::{BuildHasherDefault, Hash, Hasher}; pub type IdentityBuildHasher = BuildHasherDefault; @@ -27,3 +27,14 @@ impl Hasher for IdentityHasher { self.hash = i; } } + +pub struct IndexHash { + pub idx: u64, + pub hash: u64, +} + +impl Hash for IndexHash { + fn hash(&self, state: &mut H) { + state.write_u64(self.hash); + } +} diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs index 63d1d08485..923be570ba 100644 --- a/src/daft-io/src/lib.rs +++ b/src/daft-io/src/lib.rs @@ -12,6 +12,7 @@ mod object_store_glob; mod s3_like; mod stats; mod stream_utils; + use azure_blob::AzureBlobSource; use common_file_formats::FileFormat; use google_cloud::GCSSource; diff --git a/src/daft-local-execution/Cargo.toml b/src/daft-local-execution/Cargo.toml index 8da3b93325..f0be341b6a 100644 --- a/src/daft-local-execution/Cargo.toml +++ b/src/daft-local-execution/Cargo.toml @@ -1,4 +1,5 @@ [dependencies] +async-trait = {workspace = true} common-daft-config = {path = "../common/daft-config", default-features = false} common-display = {path = "../common/display", default-features = false} common-error = {path = "../common/error", default-features = false} @@ -17,6 +18,7 @@ daft-physical-plan = {path = "../daft-physical-plan", default-features = false} daft-plan = {path = "../daft-plan", default-features = false} daft-scan = {path = "../daft-scan", default-features = false} daft-table = {path = "../daft-table", default-features = false} +daft-writers = {path = "../daft-writers", default-features = false} futures = {workspace = true} indexmap = {workspace = true} lazy_static = {workspace = true} @@ -29,7 +31,7 @@ tokio-stream = {workspace = true} tracing = {workspace = true} [features] -python = ["dep:pyo3", "common-daft-config/python", "common-file-formats/python", "common-error/python", "daft-dsl/python", "daft-io/python", "daft-micropartition/python", "daft-plan/python", "daft-scan/python", "common-display/python"] 
+python = ["dep:pyo3", "common-daft-config/python", "common-file-formats/python", "common-error/python", "daft-dsl/python", "daft-io/python", "daft-micropartition/python", "daft-plan/python", "daft-scan/python", "daft-writers/python", "common-display/python"] [lints] workspace = true diff --git a/src/daft-local-execution/src/buffer.rs b/src/daft-local-execution/src/buffer.rs new file mode 100644 index 0000000000..4211200182 --- /dev/null +++ b/src/daft-local-execution/src/buffer.rs @@ -0,0 +1,82 @@ +use std::{cmp::Ordering::*, collections::VecDeque, sync::Arc}; + +use common_error::DaftResult; +use daft_micropartition::MicroPartition; + +// A buffer that accumulates morsels until a threshold is reached +pub struct RowBasedBuffer { + pub buffer: VecDeque>, + pub curr_len: usize, + pub threshold: usize, +} + +impl RowBasedBuffer { + pub fn new(threshold: usize) -> Self { + assert!(threshold > 0); + Self { + buffer: VecDeque::new(), + curr_len: 0, + threshold, + } + } + + // Push a morsel to the buffer + pub fn push(&mut self, part: &Arc) { + self.curr_len += part.len(); + self.buffer.push_back(part.clone()); + } + + // Pop enough morsels that reach the threshold + // - If the buffer currently has not enough morsels, return None + // - If the buffer has exactly enough morsels, return the morsels + // - If the buffer has more than enough morsels, return a vec of morsels, each correctly sized to the threshold. + // The remaining morsels will be pushed back to the buffer + pub fn pop_enough(&mut self) -> DaftResult>>> { + match self.curr_len.cmp(&self.threshold) { + Less => Ok(None), + Equal => { + if self.buffer.len() == 1 { + let part = self.buffer.pop_front().unwrap(); + self.curr_len = 0; + Ok(Some(vec![part])) + } else { + let chunk = MicroPartition::concat(std::mem::take(&mut self.buffer))?; + self.curr_len = 0; + Ok(Some(vec![chunk.into()])) + } + } + Greater => { + let num_ready_chunks = self.curr_len / self.threshold; + let concated = MicroPartition::concat(std::mem::take(&mut self.buffer))?; + let mut start = 0; + let mut parts_to_return = Vec::with_capacity(num_ready_chunks); + for _ in 0..num_ready_chunks { + let end = start + self.threshold; + let part = concated.slice(start, end)?; + parts_to_return.push(part.into()); + start = end; + } + if start < concated.len() { + let part = concated.slice(start, concated.len())?; + self.curr_len = part.len(); + self.buffer.push_back(part.into()); + } else { + self.curr_len = 0; + } + Ok(Some(parts_to_return)) + } + } + } + + // Pop all morsels in the buffer regardless of the threshold + pub fn pop_all(&mut self) -> DaftResult>> { + assert!(self.curr_len < self.threshold); + if self.buffer.is_empty() { + Ok(None) + } else { + let concated = MicroPartition::concat(std::mem::take(&mut self.buffer))?; + self.curr_len = 0; + Ok(Some(concated.into())) + } + } +} diff --git a/src/daft-local-execution/src/dispatcher.rs b/src/daft-local-execution/src/dispatcher.rs new file mode 100644 index 0000000000..d21fc306b6 --- /dev/null +++ b/src/daft-local-execution/src/dispatcher.rs @@ -0,0 +1,102 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use common_error::DaftResult; +use daft_dsl::ExprRef; + +use crate::{ + buffer::RowBasedBuffer, channel::Sender, pipeline::PipelineResultType, + runtime_stats::CountingReceiver, +}; + +#[async_trait] +pub(crate) trait Dispatcher { + async fn dispatch( + &self, + receiver: CountingReceiver, + worker_senders: Vec>, + ) -> DaftResult<()>; +} + +pub(crate) struct RoundRobinBufferedDispatcher { + morsel_size: 
usize, +} + +impl RoundRobinBufferedDispatcher { + pub(crate) fn new(morsel_size: usize) -> Self { + Self { morsel_size } + } +} + +#[async_trait] +impl Dispatcher for RoundRobinBufferedDispatcher { + async fn dispatch( + &self, + mut receiver: CountingReceiver, + worker_senders: Vec>, + ) -> DaftResult<()> { + let mut next_worker_idx = 0; + let mut send_to_next_worker = |data: PipelineResultType| { + let next_worker_sender = worker_senders.get(next_worker_idx).unwrap(); + next_worker_idx = (next_worker_idx + 1) % worker_senders.len(); + next_worker_sender.send(data) + }; + + let mut buffer = RowBasedBuffer::new(self.morsel_size); + while let Some(morsel) = receiver.recv().await { + if morsel.should_broadcast() { + for worker_sender in &worker_senders { + let _ = worker_sender.send(morsel.clone()).await; + } + } else { + buffer.push(morsel.as_data()); + if let Some(ready) = buffer.pop_enough()? { + for r in ready { + let _ = send_to_next_worker(r.into()).await; + } + } + } + } + // Clear all remaining morsels + if let Some(last_morsel) = buffer.pop_all()? { + let _ = send_to_next_worker(last_morsel.into()).await; + } + Ok(()) + } +} + +pub(crate) struct PartitionedDispatcher { + partition_by: Vec, +} + +impl PartitionedDispatcher { + pub(crate) fn new(partition_by: Vec) -> Self { + Self { partition_by } + } +} + +#[async_trait] +impl Dispatcher for PartitionedDispatcher { + async fn dispatch( + &self, + mut receiver: CountingReceiver, + worker_senders: Vec>, + ) -> DaftResult<()> { + while let Some(morsel) = receiver.recv().await { + if morsel.should_broadcast() { + for worker_sender in &worker_senders { + let _ = worker_sender.send(morsel.clone()).await; + } + } else { + let partitions = morsel + .as_data() + .partition_by_hash(&self.partition_by, worker_senders.len())?; + for (partition, worker_sender) in partitions.into_iter().zip(worker_senders.iter()) + { + let _ = worker_sender.send(Arc::new(partition).into()).await; + } + } + } + Ok(()) + } +} diff --git a/src/daft-local-execution/src/intermediate_ops/buffer.rs b/src/daft-local-execution/src/intermediate_ops/buffer.rs deleted file mode 100644 index 3c66301610..0000000000 --- a/src/daft-local-execution/src/intermediate_ops/buffer.rs +++ /dev/null @@ -1,91 +0,0 @@ -use std::{ - cmp::Ordering::{Equal, Greater, Less}, - collections::VecDeque, - sync::Arc, -}; - -use common_error::DaftResult; -use daft_micropartition::MicroPartition; - -pub struct OperatorBuffer { - pub buffer: VecDeque>, - pub curr_len: usize, - pub threshold: usize, -} - -impl OperatorBuffer { - pub fn new(threshold: usize) -> Self { - assert!(threshold > 0); - Self { - buffer: VecDeque::new(), - curr_len: 0, - threshold, - } - } - - pub fn push(&mut self, part: Arc) { - self.curr_len += part.len(); - self.buffer.push_back(part); - } - - pub fn try_clear(&mut self) -> Option>> { - match self.curr_len.cmp(&self.threshold) { - Less => None, - Equal => self.clear_all(), - Greater => Some(self.clear_enough()), - } - } - - fn clear_enough(&mut self) -> DaftResult> { - assert!(self.curr_len > self.threshold); - - let mut to_concat = Vec::with_capacity(self.buffer.len()); - let mut remaining = self.threshold; - - while remaining > 0 { - let part = self.buffer.pop_front().expect("Buffer should not be empty"); - let part_len = part.len(); - if part_len <= remaining { - remaining -= part_len; - to_concat.push(part); - } else { - let (head, tail) = part.split_at(remaining)?; - remaining = 0; - to_concat.push(Arc::new(head)); - self.buffer.push_front(Arc::new(tail)); - break; - } 
- } - assert_eq!(remaining, 0); - - self.curr_len -= self.threshold; - match to_concat.len() { - 1 => Ok(to_concat.pop().unwrap()), - _ => MicroPartition::concat( - &to_concat - .iter() - .map(std::convert::AsRef::as_ref) - .collect::>(), - ) - .map(Arc::new), - } - } - - pub fn clear_all(&mut self) -> Option>> { - if self.buffer.is_empty() { - return None; - } - - let concated = MicroPartition::concat( - &self - .buffer - .iter() - .map(std::convert::AsRef::as_ref) - .collect::>(), - ) - .map(Arc::new); - self.buffer.clear(); - self.curr_len = 0; - Some(concated) - } -} diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index f8b3acbd4c..d7503fa33c 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -6,8 +6,8 @@ use common_runtime::get_compute_runtime; use daft_micropartition::MicroPartition; use tracing::{info_span, instrument}; -use super::buffer::OperatorBuffer; use crate::{ + buffer::RowBasedBuffer, channel::{create_channel, PipelineChannel, Receiver, Sender}, pipeline::{PipelineNode, PipelineResultType}, runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, @@ -185,26 +185,23 @@ impl IntermediateNode { }; for (idx, mut receiver) in receivers.into_iter().enumerate() { - let mut buffer = OperatorBuffer::new(morsel_size); + let mut buffer = RowBasedBuffer::new(morsel_size); while let Some(morsel) = receiver.recv().await { if morsel.should_broadcast() { for worker_sender in &worker_senders { let _ = worker_sender.send((idx, morsel.clone())).await; } } else { - buffer.push(morsel.as_data().clone()); - if let Some(ready) = buffer.try_clear() { - let _ = send_to_next_worker(idx, ready?.into()).await; + buffer.push(morsel.as_data()); + if let Some(ready) = buffer.pop_enough()? { + for part in ready { + let _ = send_to_next_worker(idx, part.into()).await; + } } } } - // Buffer may still have some morsels left above the threshold - while let Some(ready) = buffer.try_clear() { - let _ = send_to_next_worker(idx, ready?.into()).await; - } - // Clear all remaining morsels - if let Some(last_morsel) = buffer.clear_all() { - let _ = send_to_next_worker(idx, last_morsel?.into()).await; + if let Some(ready) = buffer.pop_all()? 
{ + let _ = send_to_next_worker(idx, ready.into()).await; } } Ok(()) diff --git a/src/daft-local-execution/src/intermediate_ops/mod.rs b/src/daft-local-execution/src/intermediate_ops/mod.rs index 5c935a24ed..336f6ebc92 100644 --- a/src/daft-local-execution/src/intermediate_ops/mod.rs +++ b/src/daft-local-execution/src/intermediate_ops/mod.rs @@ -1,7 +1,6 @@ pub mod actor_pool_project; pub mod aggregate; pub mod anti_semi_hash_join_probe; -pub mod buffer; pub mod explode; pub mod filter; pub mod inner_hash_join_probe; diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs index e9b7e08b96..553ad18b40 100644 --- a/src/daft-local-execution/src/lib.rs +++ b/src/daft-local-execution/src/lib.rs @@ -1,5 +1,7 @@ #![feature(let_chains)] +mod buffer; mod channel; +mod dispatcher; mod intermediate_ops; mod pipeline; mod run; diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index cb84dacde4..b371af0c8a 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -1,7 +1,9 @@ use std::{collections::HashMap, sync::Arc}; +use common_daft_config::DaftExecutionConfig; use common_display::{mermaid::MermaidDisplayVisitor, tree::TreeDisplay}; use common_error::DaftResult; +use common_file_formats::FileFormat; use daft_core::{ datatypes::Field, prelude::{Schema, SchemaRef}, @@ -11,10 +13,12 @@ use daft_dsl::{col, join::get_common_join_keys, Expr}; use daft_micropartition::MicroPartition; use daft_physical_plan::{ ActorPoolProject, Concat, EmptyScan, Explode, Filter, HashAggregate, HashJoin, InMemoryScan, - Limit, LocalPhysicalPlan, Pivot, Project, Sample, Sort, UnGroupedAggregate, Unpivot, + Limit, LocalPhysicalPlan, PhysicalWrite, Pivot, Project, Sample, Sort, UnGroupedAggregate, + Unpivot, }; use daft_plan::{populate_aggregation_stages, JoinType}; use daft_table::ProbeState; +use daft_writers::make_writer_factory; use indexmap::IndexSet; use snafu::ResultExt; @@ -28,10 +32,15 @@ use crate::{ sample::SampleOperator, unpivot::UnpivotOperator, }, sinks::{ - aggregate::AggregateSink, blocking_sink::BlockingSinkNode, concat::ConcatSink, - hash_join_build::HashJoinBuildSink, limit::LimitSink, - outer_hash_join_probe::OuterHashJoinProbeSink, sort::SortSink, + aggregate::AggregateSink, + blocking_sink::BlockingSinkNode, + concat::ConcatSink, + hash_join_build::HashJoinBuildSink, + limit::LimitSink, + outer_hash_join_probe::OuterHashJoinProbeSink, + sort::SortSink, streaming_sink::StreamingSinkNode, + write::{WriteFormat, WriteSink}, }, sources::{empty_scan::EmptyScanSource, in_memory::InMemorySource}, ExecutionRuntimeHandle, PipelineCreationSnafu, @@ -102,6 +111,7 @@ pub fn viz_pipeline(root: &dyn PipelineNode) -> String { pub fn physical_plan_to_pipeline( physical_plan: &LocalPhysicalPlan, psets: &HashMap>>, + cfg: &Arc, ) -> crate::Result> { use daft_physical_plan::PhysicalScan; @@ -125,14 +135,14 @@ pub fn physical_plan_to_pipeline( input, projection, .. }) => { let proj_op = ProjectOperator::new(projection.clone()); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; IntermediateNode::new(Arc::new(proj_op), vec![child_node]).boxed() } LocalPhysicalPlan::ActorPoolProject(ActorPoolProject { input, projection, .. 
}) => { let proj_op = ActorPoolProjectOperator::new(projection.clone()); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; IntermediateNode::new(Arc::new(proj_op), vec![child_node]).boxed() } LocalPhysicalPlan::Sample(Sample { @@ -143,33 +153,33 @@ pub fn physical_plan_to_pipeline( .. }) => { let sample_op = SampleOperator::new(*fraction, *with_replacement, *seed); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; IntermediateNode::new(Arc::new(sample_op), vec![child_node]).boxed() } LocalPhysicalPlan::Filter(Filter { input, predicate, .. }) => { let filter_op = FilterOperator::new(predicate.clone()); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; IntermediateNode::new(Arc::new(filter_op), vec![child_node]).boxed() } LocalPhysicalPlan::Explode(Explode { input, to_explode, .. }) => { let explode_op = ExplodeOperator::new(to_explode.clone()); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; IntermediateNode::new(Arc::new(explode_op), vec![child_node]).boxed() } LocalPhysicalPlan::Limit(Limit { input, num_rows, .. }) => { let sink = LimitSink::new(*num_rows as usize); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; StreamingSinkNode::new(Arc::new(sink), vec![child_node]).boxed() } LocalPhysicalPlan::Concat(Concat { input, other, .. }) => { - let left_child = physical_plan_to_pipeline(input, psets)?; - let right_child = physical_plan_to_pipeline(other, psets)?; + let left_child = physical_plan_to_pipeline(input, psets, cfg)?; + let right_child = physical_plan_to_pipeline(other, psets, cfg)?; let sink = ConcatSink {}; StreamingSinkNode::new(Arc::new(sink), vec![left_child, right_child]).boxed() } @@ -189,7 +199,7 @@ pub fn physical_plan_to_pipeline( .collect(), vec![], ); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; let post_first_agg_node = IntermediateNode::new(Arc::new(first_stage_agg_op), vec![child_node]).boxed(); @@ -202,7 +212,7 @@ pub fn physical_plan_to_pipeline( vec![], ); let second_stage_node = - BlockingSinkNode::new(second_stage_agg_sink.boxed(), post_first_agg_node).boxed(); + BlockingSinkNode::new(Arc::new(second_stage_agg_sink), post_first_agg_node).boxed(); let final_stage_project = ProjectOperator::new(final_exprs); @@ -217,7 +227,7 @@ pub fn physical_plan_to_pipeline( }) => { let (first_stage_aggs, second_stage_aggs, final_exprs) = populate_aggregation_stages(aggregations, schema, group_by); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; let (post_first_agg_node, group_by) = if !first_stage_aggs.is_empty() { let agg_op = AggregateOperator::new( first_stage_aggs @@ -244,7 +254,7 @@ pub fn physical_plan_to_pipeline( group_by.clone(), ); let second_stage_node = - BlockingSinkNode::new(second_stage_agg_sink.boxed(), post_first_agg_node).boxed(); + BlockingSinkNode::new(Arc::new(second_stage_agg_sink), post_first_agg_node).boxed(); let final_stage_project = ProjectOperator::new(final_exprs); @@ -258,7 +268,7 @@ pub fn physical_plan_to_pipeline( value_name, .. 
}) => { - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; let unpivot_op = UnpivotOperator::new( ids.clone(), values.clone(), @@ -281,7 +291,7 @@ pub fn physical_plan_to_pipeline( value_column.clone(), names.clone(), ); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; IntermediateNode::new(Arc::new(pivot_op), vec![child_node]).boxed() } LocalPhysicalPlan::Sort(Sort { @@ -291,8 +301,8 @@ pub fn physical_plan_to_pipeline( .. }) => { let sort_sink = SortSink::new(sort_by.clone(), descending.clone()); - let child_node = physical_plan_to_pipeline(input, psets)?; - BlockingSinkNode::new(sort_sink.boxed(), child_node).boxed() + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; + BlockingSinkNode::new(Arc::new(sort_sink), child_node).boxed() } LocalPhysicalPlan::HashJoin(HashJoin { @@ -361,11 +371,11 @@ pub fn physical_plan_to_pipeline( // we should move to a builder pattern let build_sink = HashJoinBuildSink::new(key_schema, casted_build_on, join_type)?; - let build_child_node = physical_plan_to_pipeline(build_child, psets)?; + let build_child_node = physical_plan_to_pipeline(build_child, psets, cfg)?; let build_node = - BlockingSinkNode::new(build_sink.boxed(), build_child_node).boxed(); + BlockingSinkNode::new(Arc::new(build_sink), build_child_node).boxed(); - let probe_child_node = physical_plan_to_pipeline(probe_child, psets)?; + let probe_child_node = physical_plan_to_pipeline(probe_child, psets, cfg)?; match join_type { JoinType::Anti | JoinType::Semi => Ok(IntermediateNode::new( @@ -409,8 +419,29 @@ pub fn physical_plan_to_pipeline( plan_name: physical_plan.name(), })? } - _ => { - unimplemented!("Physical plan not supported: {}", physical_plan.name()); + LocalPhysicalPlan::PhysicalWrite(PhysicalWrite { + input, + file_info, + data_schema, + file_schema, + .. 
+ }) => { + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; + let writer_factory = make_writer_factory(file_info, data_schema, cfg); + let write_format = match (file_info.file_format, file_info.partition_cols.is_some()) { + (FileFormat::Parquet, true) => WriteFormat::PartitionedParquet, + (FileFormat::Parquet, false) => WriteFormat::Parquet, + (FileFormat::Csv, true) => WriteFormat::PartitionedCsv, + (FileFormat::Csv, false) => WriteFormat::Csv, + (_, _) => panic!("Unsupported file format"), + }; + let write_sink = WriteSink::new( + write_format, + writer_factory, + file_info.partition_cols.clone(), + file_schema.clone(), + ); + BlockingSinkNode::new(Arc::new(write_sink), child_node).boxed() } }; diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs index ae0939ef8a..3cab41ef36 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -118,7 +118,7 @@ pub fn run_local( results_buffer_size: Option, ) -> DaftResult>> + Send>> { refresh_chrome_trace(); - let mut pipeline = physical_plan_to_pipeline(physical_plan, &psets)?; + let mut pipeline = physical_plan_to_pipeline(physical_plan, &psets, &cfg)?; let (tx, rx) = create_channel(results_buffer_size.unwrap_or(1)); let handle = std::thread::spawn(move || { let runtime = tokio::runtime::Builder::new_current_thread() diff --git a/src/daft-local-execution/src/sinks/aggregate.rs b/src/daft-local-execution/src/sinks/aggregate.rs index e94ff7c68b..abc8acce4c 100644 --- a/src/daft-local-execution/src/sinks/aggregate.rs +++ b/src/daft-local-execution/src/sinks/aggregate.rs @@ -5,19 +5,43 @@ use daft_dsl::ExprRef; use daft_micropartition::MicroPartition; use tracing::instrument; -use super::blocking_sink::{BlockingSink, BlockingSinkStatus}; -use crate::pipeline::PipelineResultType; +use super::blocking_sink::{BlockingSink, BlockingSinkState, BlockingSinkStatus}; +use crate::{pipeline::PipelineResultType, NUM_CPUS}; enum AggregateState { Accumulating(Vec>), - #[allow(dead_code)] - Done(Arc), + Done, +} + +impl AggregateState { + fn push(&mut self, part: Arc) { + if let Self::Accumulating(ref mut parts) = self { + parts.push(part); + } else { + panic!("AggregateSink should be in Accumulating state"); + } + } + + fn finalize(&mut self) -> Vec> { + let res = if let Self::Accumulating(ref mut parts) = self { + std::mem::take(parts) + } else { + panic!("AggregateSink should be in Accumulating state"); + }; + *self = Self::Done; + res + } +} + +impl BlockingSinkState for AggregateState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } } pub struct AggregateSink { agg_exprs: Vec, group_by: Vec, - state: AggregateState, } impl AggregateSink { @@ -25,47 +49,51 @@ impl AggregateSink { Self { agg_exprs, group_by, - state: AggregateState::Accumulating(vec![]), } } - - pub fn boxed(self) -> Box { - Box::new(self) - } } impl BlockingSink for AggregateSink { #[instrument(skip_all, name = "AggregateSink::sink")] - fn sink(&mut self, input: &Arc) -> DaftResult { - if let AggregateState::Accumulating(parts) = &mut self.state { - parts.push(input.clone()); - Ok(BlockingSinkStatus::NeedMoreInput) - } else { - panic!("AggregateSink should be in Accumulating state"); - } + fn sink( + &self, + input: &Arc, + mut state: Box, + ) -> DaftResult { + state + .as_any_mut() + .downcast_mut::() + .expect("AggregateSink should have AggregateState") + .push(input.clone()); + Ok(BlockingSinkStatus::NeedMoreInput(state)) } #[instrument(skip_all, name = "AggregateSink::finalize")] - fn 
finalize(&mut self) -> DaftResult> { - if let AggregateState::Accumulating(parts) = &mut self.state { - assert!( - !parts.is_empty(), - "We can not finalize AggregateSink with no data" - ); - let concated = MicroPartition::concat( - &parts - .iter() - .map(std::convert::AsRef::as_ref) - .collect::>(), - )?; - let agged = Arc::new(concated.agg(&self.agg_exprs, &self.group_by)?); - self.state = AggregateState::Done(agged.clone()); - Ok(Some(agged.into())) - } else { - panic!("AggregateSink should be in Accumulating state"); - } + fn finalize( + &self, + states: Vec>, + ) -> DaftResult> { + let all_parts = states.into_iter().flat_map(|mut state| { + state + .as_any_mut() + .downcast_mut::() + .expect("AggregateSink should have AggregateState") + .finalize() + }); + let concated = MicroPartition::concat(all_parts)?; + let agged = Arc::new(concated.agg(&self.agg_exprs, &self.group_by)?); + Ok(Some(agged.into())) } + fn name(&self) -> &'static str { "AggregateSink" } + + fn max_concurrency(&self) -> usize { + *NUM_CPUS + } + + fn make_state(&self) -> DaftResult> { + Ok(Box::new(AggregateState::Accumulating(vec![]))) + } } diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index 3fcbf8d660..7835fbe138 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -4,39 +4,58 @@ use common_display::tree::TreeDisplay; use common_error::DaftResult; use common_runtime::get_compute_runtime; use daft_micropartition::MicroPartition; -use tracing::info_span; +use snafu::ResultExt; +use tracing::{info_span, instrument}; use crate::{ - channel::PipelineChannel, + channel::{create_channel, PipelineChannel, Receiver}, + dispatcher::{Dispatcher, RoundRobinBufferedDispatcher}, pipeline::{PipelineNode, PipelineResultType}, runtime_stats::RuntimeStatsContext, - ExecutionRuntimeHandle, + ExecutionRuntimeHandle, JoinSnafu, TaskSet, }; +pub trait BlockingSinkState: Send + Sync { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any; +} + pub enum BlockingSinkStatus { - NeedMoreInput, + NeedMoreInput(Box), #[allow(dead_code)] - Finished, + Finished(Box), } pub trait BlockingSink: Send + Sync { - fn sink(&mut self, input: &Arc) -> DaftResult; - fn finalize(&mut self) -> DaftResult>; + fn sink( + &self, + input: &Arc, + state: Box, + ) -> DaftResult; + fn finalize( + &self, + states: Vec>, + ) -> DaftResult>; fn name(&self) -> &'static str; + fn make_state(&self) -> DaftResult>; + fn make_dispatcher(&self, runtime_handle: &ExecutionRuntimeHandle) -> Arc { + Arc::new(RoundRobinBufferedDispatcher::new( + runtime_handle.default_morsel_size(), + )) + } + fn max_concurrency(&self) -> usize; } pub struct BlockingSinkNode { - // use a RW lock - op: Arc>>, + op: Arc, name: &'static str, child: Box, runtime_stats: Arc, } impl BlockingSinkNode { - pub(crate) fn new(op: Box, child: Box) -> Self { + pub(crate) fn new(op: Arc, child: Box) -> Self { let name = op.name(); Self { - op: Arc::new(tokio::sync::Mutex::new(op)), + op, name, child, runtime_stats: RuntimeStatsContext::new(), @@ -45,6 +64,46 @@ impl BlockingSinkNode { pub(crate) fn boxed(self) -> Box { Box::new(self) } + + #[instrument(level = "info", skip_all, name = "BlockingSink::run_worker")] + async fn run_worker( + op: Arc, + mut input_receiver: Receiver, + rt_context: Arc, + ) -> DaftResult> { + let span = info_span!("BlockingSink::Sink"); + let compute_runtime = get_compute_runtime(); + let mut state = op.make_state()?; + while let 
Some(morsel) = input_receiver.recv().await { + let op = op.clone(); + let morsel = morsel.clone(); + let span = span.clone(); + let rt_context = rt_context.clone(); + let fut = async move { rt_context.in_span(&span, || op.sink(morsel.as_data(), state)) }; + let result = compute_runtime.await_on(fut).await??; + match result { + BlockingSinkStatus::NeedMoreInput(new_state) => { + state = new_state; + } + BlockingSinkStatus::Finished(new_state) => { + return Ok(new_state); + } + } + } + + Ok(state) + } + + fn spawn_workers( + op: Arc, + input_receivers: Vec>, + task_set: &mut TaskSet>>, + stats: Arc, + ) { + for input_receiver in input_receivers { + task_set.spawn(Self::run_worker(op.clone(), input_receiver, stats.clone())); + } + } } impl TreeDisplay for BlockingSinkNode { @@ -81,7 +140,7 @@ impl PipelineNode for BlockingSinkNode { runtime_handle: &mut ExecutionRuntimeHandle, ) -> crate::Result { let child = self.child.as_mut(); - let mut child_results_receiver = child + let child_results_receiver = child .start(false, runtime_handle)? .get_receiver_with_stats(&self.runtime_stats); @@ -89,34 +148,45 @@ impl PipelineNode for BlockingSinkNode { let destination_sender = destination_channel.get_next_sender_with_stats(&self.runtime_stats); let op = self.op.clone(); - let rt_context = self.runtime_stats.clone(); + let runtime_stats = self.runtime_stats.clone(); + let num_workers = op.max_concurrency(); + let (input_senders, input_receivers) = (0..num_workers).map(|_| create_channel(1)).unzip(); + let dispatcher = op.make_dispatcher(runtime_handle); runtime_handle.spawn( async move { - let span = info_span!("BlockingSinkNode::execute"); - let compute_runtime = get_compute_runtime(); - while let Some(val) = child_results_receiver.recv().await { - let op = op.clone(); - let span = span.clone(); - let rt_context = rt_context.clone(); - let fut = async move { - let mut guard = op.lock().await; - rt_context.in_span(&span, || guard.sink(val.as_data())) - }; - let result = compute_runtime.await_on(fut).await??; - if matches!(result, BlockingSinkStatus::Finished) { - break; - } + dispatcher + .dispatch(child_results_receiver, input_senders) + .await + }, + self.name(), + ); + + runtime_handle.spawn( + async move { + let mut task_set = TaskSet::new(); + Self::spawn_workers( + op.clone(), + input_receivers, + &mut task_set, + runtime_stats.clone(), + ); + + let mut finished_states = Vec::with_capacity(num_workers); + while let Some(result) = task_set.join_next().await { + let state = result.context(JoinSnafu)??; + finished_states.push(state); } + + let compute_runtime = get_compute_runtime(); let finalized_result = compute_runtime .await_on(async move { - let mut guard = op.lock().await; - rt_context.in_span(&info_span!("BlockingSinkNode::finalize"), || { - guard.finalize() + runtime_stats.in_span(&info_span!("BlockingSinkNode::finalize"), || { + op.finalize(finished_states) }) }) .await??; - if let Some(part) = finalized_result { - let _ = destination_sender.send(part).await; + if let Some(res) = finalized_result { + let _ = destination_sender.send(res).await; } Ok(()) }, diff --git a/src/daft-local-execution/src/sinks/hash_join_build.rs b/src/daft-local-execution/src/sinks/hash_join_build.rs index c8258e281a..677f63279d 100644 --- a/src/daft-local-execution/src/sinks/hash_join_build.rs +++ b/src/daft-local-execution/src/sinks/hash_join_build.rs @@ -7,7 +7,7 @@ use daft_micropartition::MicroPartition; use daft_plan::JoinType; use daft_table::{make_probeable_builder, ProbeState, ProbeableBuilder, Table}; 
-use super::blocking_sink::{BlockingSink, BlockingSinkStatus}; +use super::blocking_sink::{BlockingSink, BlockingSinkState, BlockingSinkStatus}; use crate::pipeline::PipelineResultType; enum ProbeTableState { @@ -74,8 +74,16 @@ impl ProbeTableState { } } +impl BlockingSinkState for ProbeTableState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} + pub struct HashJoinBuildSink { - probe_table_state: ProbeTableState, + key_schema: SchemaRef, + projection: Vec, + join_type: JoinType, } impl HashJoinBuildSink { @@ -85,13 +93,11 @@ impl HashJoinBuildSink { join_type: &JoinType, ) -> DaftResult { Ok(Self { - probe_table_state: ProbeTableState::new(&key_schema, projection, join_type)?, + key_schema, + projection, + join_type: *join_type, }) } - - pub(crate) fn boxed(self) -> Box { - Box::new(self) - } } impl BlockingSink for HashJoinBuildSink { @@ -99,17 +105,46 @@ impl BlockingSink for HashJoinBuildSink { "HashJoinBuildSink" } - fn sink(&mut self, input: &Arc) -> DaftResult { - self.probe_table_state.add_tables(input)?; - Ok(BlockingSinkStatus::NeedMoreInput) + fn sink( + &self, + input: &Arc, + mut state: Box, + ) -> DaftResult { + state + .as_any_mut() + .downcast_mut::() + .expect("HashJoinBuildSink should have ProbeTableState") + .add_tables(input)?; + Ok(BlockingSinkStatus::NeedMoreInput(state)) } - fn finalize(&mut self) -> DaftResult> { - self.probe_table_state.finalize()?; - if let ProbeTableState::Done { probe_state } = &self.probe_table_state { + fn finalize( + &self, + states: Vec>, + ) -> DaftResult> { + assert_eq!(states.len(), 1); + let mut state = states.into_iter().next().unwrap(); + let probe_table_state = state + .as_any_mut() + .downcast_mut::() + .expect("State type mismatch"); + probe_table_state.finalize()?; + if let ProbeTableState::Done { probe_state } = probe_table_state { Ok(Some(probe_state.clone().into())) } else { panic!("finalize should only be called after the probe table is built") } } + + fn max_concurrency(&self) -> usize { + 1 + } + + fn make_state(&self) -> DaftResult> { + Ok(Box::new(ProbeTableState::new( + &self.key_schema, + self.projection.clone(), + &self.join_type, + )?)) + } } diff --git a/src/daft-local-execution/src/sinks/mod.rs b/src/daft-local-execution/src/sinks/mod.rs index 7960e55a7c..64366385c3 100644 --- a/src/daft-local-execution/src/sinks/mod.rs +++ b/src/daft-local-execution/src/sinks/mod.rs @@ -6,3 +6,4 @@ pub mod limit; pub mod outer_hash_join_probe; pub mod sort; pub mod streaming_sink; +pub mod write; diff --git a/src/daft-local-execution/src/sinks/sort.rs b/src/daft-local-execution/src/sinks/sort.rs index 169ea9e55d..83c933d1ec 100644 --- a/src/daft-local-execution/src/sinks/sort.rs +++ b/src/daft-local-execution/src/sinks/sort.rs @@ -5,18 +5,42 @@ use daft_dsl::ExprRef; use daft_micropartition::MicroPartition; use tracing::instrument; -use super::blocking_sink::{BlockingSink, BlockingSinkStatus}; -use crate::pipeline::PipelineResultType; -pub struct SortSink { - sort_by: Vec, - descending: Vec, - state: SortState, -} +use super::blocking_sink::{BlockingSink, BlockingSinkState, BlockingSinkStatus}; +use crate::{pipeline::PipelineResultType, NUM_CPUS}; enum SortState { Building(Vec>), - #[allow(dead_code)] - Done(Arc), + Done, +} + +impl SortState { + fn push(&mut self, part: Arc) { + if let Self::Building(ref mut parts) = self { + parts.push(part); + } else { + panic!("SortSink should be in Building state"); + } + } + + fn finalize(&mut self) -> Vec> { + let res = if let Self::Building(ref mut parts) = self { + 
std::mem::take(parts) + } else { + panic!("SortSink should be in Building state"); + }; + *self = Self::Done; + res + } +} + +impl BlockingSinkState for SortState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} +pub struct SortSink { + sort_by: Vec, + descending: Vec, } impl SortSink { @@ -24,46 +48,51 @@ impl SortSink { Self { sort_by, descending, - state: SortState::Building(vec![]), } } - pub fn boxed(self) -> Box { - Box::new(self) - } } impl BlockingSink for SortSink { #[instrument(skip_all, name = "SortSink::sink")] - fn sink(&mut self, input: &Arc) -> DaftResult { - if let SortState::Building(parts) = &mut self.state { - parts.push(input.clone()); - } else { - panic!("SortSink should be in Building state"); - } - Ok(BlockingSinkStatus::NeedMoreInput) + fn sink( + &self, + input: &Arc, + mut state: Box, + ) -> DaftResult { + state + .as_any_mut() + .downcast_mut::() + .expect("SortSink should have sort state") + .push(input.clone()); + Ok(BlockingSinkStatus::NeedMoreInput(state)) } #[instrument(skip_all, name = "SortSink::finalize")] - fn finalize(&mut self) -> DaftResult> { - if let SortState::Building(parts) = &mut self.state { - assert!( - !parts.is_empty(), - "We can not finalize SortSink with no data" - ); - let concated = MicroPartition::concat( - &parts - .iter() - .map(std::convert::AsRef::as_ref) - .collect::>(), - )?; - let sorted = Arc::new(concated.sort(&self.sort_by, &self.descending)?); - self.state = SortState::Done(sorted.clone()); - Ok(Some(sorted.into())) - } else { - panic!("SortSink should be in Building state"); - } + fn finalize( + &self, + states: Vec>, + ) -> DaftResult> { + let parts = states.into_iter().flat_map(|mut state| { + let state = state + .as_any_mut() + .downcast_mut::() + .expect("State type mismatch"); + state.finalize() + }); + let concated = MicroPartition::concat(parts)?; + let sorted = Arc::new(concated.sort(&self.sort_by, &self.descending)?); + Ok(Some(sorted.into())) } + fn name(&self) -> &'static str { "SortResult" } + + fn make_state(&self) -> DaftResult> { + Ok(Box::new(SortState::Building(Vec::new()))) + } + + fn max_concurrency(&self) -> usize { + *NUM_CPUS + } } diff --git a/src/daft-local-execution/src/sinks/write.rs b/src/daft-local-execution/src/sinks/write.rs new file mode 100644 index 0000000000..002f32a25a --- /dev/null +++ b/src/daft-local-execution/src/sinks/write.rs @@ -0,0 +1,137 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_core::prelude::SchemaRef; +use daft_dsl::ExprRef; +use daft_micropartition::MicroPartition; +use daft_table::Table; +use daft_writers::{FileWriter, WriterFactory}; +use tracing::instrument; + +use super::blocking_sink::{BlockingSink, BlockingSinkState, BlockingSinkStatus}; +use crate::{ + dispatcher::{Dispatcher, PartitionedDispatcher, RoundRobinBufferedDispatcher}, + pipeline::PipelineResultType, + NUM_CPUS, +}; + +pub enum WriteFormat { + Parquet, + PartitionedParquet, + Csv, + PartitionedCsv, +} + +struct WriteState { + writer: Box, Result = Vec>>, +} + +impl WriteState { + pub fn new( + writer: Box, Result = Vec
>>, + ) -> Self { + Self { writer } + } +} + +impl BlockingSinkState for WriteState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} + +pub(crate) struct WriteSink { + write_format: WriteFormat, + writer_factory: Arc, Result = Vec
>>, + partition_by: Option>, + file_schema: SchemaRef, +} + +impl WriteSink { + pub(crate) fn new( + write_format: WriteFormat, + writer_factory: Arc, Result = Vec
>>, + partition_by: Option>, + file_schema: SchemaRef, + ) -> Self { + Self { + write_format, + writer_factory, + partition_by, + file_schema, + } + } +} + +impl BlockingSink for WriteSink { + #[instrument(skip_all, name = "WriteSink::sink")] + fn sink( + &self, + input: &Arc, + mut state: Box, + ) -> DaftResult { + state + .as_any_mut() + .downcast_mut::() + .expect("WriteSink should have WriteState") + .writer + .write(input)?; + Ok(BlockingSinkStatus::NeedMoreInput(state)) + } + + #[instrument(skip_all, name = "WriteSink::finalize")] + fn finalize( + &self, + states: Vec>, + ) -> DaftResult> { + let mut results = vec![]; + for mut state in states { + let state = state + .as_any_mut() + .downcast_mut::() + .expect("State type mismatch"); + results.extend(state.writer.close()?); + } + let mp = Arc::new(MicroPartition::new_loaded( + self.file_schema.clone(), + results.into(), + None, + )); + Ok(Some(mp.into())) + } + + fn name(&self) -> &'static str { + match self.write_format { + WriteFormat::Parquet => "ParquetSink", + WriteFormat::PartitionedParquet => "PartitionedParquetSink", + WriteFormat::Csv => "CsvSink", + WriteFormat::PartitionedCsv => "PartitionedCsvSink", + } + } + + fn make_state(&self) -> DaftResult> { + let writer = self.writer_factory.create_writer(0, None)?; + Ok(Box::new(WriteState::new(writer)) as Box) + } + + fn make_dispatcher( + &self, + runtime_handle: &crate::ExecutionRuntimeHandle, + ) -> Arc { + if let Some(partition_by) = &self.partition_by { + Arc::new(PartitionedDispatcher::new(partition_by.clone())) + } else { + Arc::new(RoundRobinBufferedDispatcher::new( + runtime_handle.default_morsel_size(), + )) + } + } + + fn max_concurrency(&self) -> usize { + if self.partition_by.is_some() { + *NUM_CPUS + } else { + 1 + } + } +} diff --git a/src/daft-micropartition/src/ops/concat.rs b/src/daft-micropartition/src/ops/concat.rs index 2108cc01e3..2ca6160178 100644 --- a/src/daft-micropartition/src/ops/concat.rs +++ b/src/daft-micropartition/src/ops/concat.rs @@ -1,4 +1,4 @@ -use std::sync::Mutex; +use std::{borrow::Borrow, ops::Deref, sync::Mutex}; use common_error::{DaftError, DaftResult}; use daft_io::IOStatsContext; @@ -7,18 +7,25 @@ use daft_stats::TableMetadata; use crate::micropartition::{MicroPartition, TableState}; impl MicroPartition { - pub fn concat(mps: &[&Self]) -> DaftResult { + pub fn concat(mps: I) -> DaftResult + where + I: IntoIterator, + T: Deref, + T::Target: Borrow, + { + let mps: Vec<_> = mps.into_iter().collect(); if mps.is_empty() { return Err(DaftError::ValueError( "Need at least 1 MicroPartition to perform concat".to_string(), )); } - let first_table = mps.first().unwrap(); + let first_table = mps.first().unwrap().deref().borrow(); - let first_schema = first_table.schema.as_ref(); + let first_schema = &first_table.schema; for tab in mps.iter().skip(1) { - if tab.schema.as_ref() != first_schema { + let tab = tab.deref().borrow(); + if &tab.schema != first_schema { return Err(DaftError::SchemaMismatch(format!( "MicroPartition concat requires all schemas to match, {} vs {}", first_schema, tab.schema @@ -30,13 +37,14 @@ impl MicroPartition { let mut all_tables = vec![]; - for m in mps { + for m in &mps { + let m = m.deref().borrow(); let tables = m.tables_or_read(io_stats.clone())?; all_tables.extend_from_slice(tables.as_slice()); } let mut all_stats = None; - for stats in mps.iter().flat_map(|m| &m.statistics) { + for stats in mps.iter().flat_map(|m| &m.deref().borrow().statistics) { if all_stats.is_none() { all_stats = Some(stats.clone()); } @@ 
-48,7 +56,7 @@ impl MicroPartition { let new_len = all_tables.iter().map(daft_table::Table::len).sum(); Ok(Self { - schema: mps.first().unwrap().schema.clone(), + schema: first_schema.clone(), state: Mutex::new(TableState::Loaded(all_tables.into())), metadata: TableMetadata { length: new_len }, statistics: all_stats, diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index 8875c29517..ab9b4a7db1 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -146,8 +146,8 @@ impl PyMicroPartition { #[staticmethod] pub fn concat(py: Python, to_concat: Vec) -> PyResult { - let mps: Vec<_> = to_concat.iter().map(|t| t.inner.as_ref()).collect(); - py.allow_threads(|| Ok(MicroPartition::concat(mps.as_slice())?.into())) + let mps_iter = to_concat.iter().map(|t| t.inner.as_ref()); + py.allow_threads(|| Ok(MicroPartition::concat(mps_iter)?.into())) } pub fn slice(&self, py: Python, start: i64, end: i64) -> PyResult { diff --git a/src/daft-physical-plan/Cargo.toml b/src/daft-physical-plan/Cargo.toml index 778b8b8560..17d419e63f 100644 --- a/src/daft-physical-plan/Cargo.toml +++ b/src/daft-physical-plan/Cargo.toml @@ -8,6 +8,9 @@ daft-scan = {path = "../daft-scan", default-features = false} log = {workspace = true} strum = {version = "0.26", features = ["derive"]} +[features] +python = ["common-error/python", "common-resource-request/python", "daft-core/python", "daft-dsl/python", "daft-plan/python", "daft-scan/python"] + [lints] workspace = true diff --git a/src/daft-physical-plan/src/local_plan.rs b/src/daft-physical-plan/src/local_plan.rs index ea2144a982..720b53080e 100644 --- a/src/daft-physical-plan/src/local_plan.rs +++ b/src/daft-physical-plan/src/local_plan.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use common_resource_request::ResourceRequest; use daft_core::prelude::*; use daft_dsl::{AggExpr, ExprRef}; -use daft_plan::InMemoryInfo; +use daft_plan::{InMemoryInfo, OutputFileInfo}; use daft_scan::{ScanTask, ScanTaskRef}; pub type LocalPhysicalPlanRef = Arc; @@ -287,7 +287,22 @@ impl LocalPhysicalPlan { .arced() } - #[must_use] + pub(crate) fn physical_write( + input: LocalPhysicalPlanRef, + data_schema: SchemaRef, + file_schema: SchemaRef, + file_info: OutputFileInfo, + ) -> LocalPhysicalPlanRef { + Self::PhysicalWrite(PhysicalWrite { + input, + data_schema, + file_schema, + file_info, + plan_stats: PlanStats {}, + }) + .arced() + } + pub fn schema(&self) -> &SchemaRef { match self { Self::PhysicalScan(PhysicalScan { schema, .. 
}) @@ -447,7 +462,13 @@ pub struct Concat { } #[derive(Debug)] -pub struct PhysicalWrite {} +pub struct PhysicalWrite { + pub input: LocalPhysicalPlanRef, + pub data_schema: SchemaRef, + pub file_schema: SchemaRef, + pub file_info: OutputFileInfo, + pub plan_stats: PlanStats, +} #[derive(Debug)] pub struct PlanStats {} diff --git a/src/daft-physical-plan/src/translate.rs b/src/daft-physical-plan/src/translate.rs index fc9cd0d656..3e5ee5cf13 100644 --- a/src/daft-physical-plan/src/translate.rs +++ b/src/daft-physical-plan/src/translate.rs @@ -176,6 +176,21 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { log::warn!("Repartition Not supported for Local Executor!; This will be a No-Op"); translate(&repartition.input) } + LogicalPlan::Sink(sink) => { + use daft_plan::SinkInfo; + let input = translate(&sink.input)?; + let data_schema = input.schema().clone(); + match sink.sink_info.as_ref() { + SinkInfo::OutputFileInfo(info) => Ok(LocalPhysicalPlan::physical_write( + input, + data_schema, + sink.schema.clone(), + info.clone(), + )), + #[cfg(feature = "python")] + SinkInfo::CatalogInfo(_) => todo!("CatalogInfo not yet implemented"), + } + } LogicalPlan::Explode(explode) => { let input = translate(&explode.input)?; Ok(LocalPhysicalPlan::explode( diff --git a/src/daft-plan/src/lib.rs b/src/daft-plan/src/lib.rs index 2541a143db..25d8fd4b62 100644 --- a/src/daft-plan/src/lib.rs +++ b/src/daft-plan/src/lib.rs @@ -31,7 +31,7 @@ pub use physical_planner::{ #[cfg(feature = "python")] use pyo3::prelude::*; #[cfg(feature = "python")] -pub use sink_info::{DeltaLakeCatalogInfo, IcebergCatalogInfo, LanceCatalogInfo}; +pub use sink_info::{CatalogType, DeltaLakeCatalogInfo, IcebergCatalogInfo, LanceCatalogInfo}; pub use sink_info::{OutputFileInfo, SinkInfo}; pub use source_info::{FileInfo, FileInfos, InMemoryInfo, SourceInfo}; #[cfg(feature = "python")] diff --git a/src/daft-plan/src/sink_info.rs b/src/daft-plan/src/sink_info.rs index 02c8e05273..d083a1f53d 100644 --- a/src/daft-plan/src/sink_info.rs +++ b/src/daft-plan/src/sink_info.rs @@ -63,7 +63,6 @@ pub struct IcebergCatalogInfo { #[derivative(PartialEq = "ignore")] #[derivative(Hash = "ignore")] pub iceberg_schema: PyObject, - #[serde( serialize_with = "serialize_py_object", deserialize_with = "deserialize_py_object" diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 701791aaf9..0a450ace70 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -785,6 +785,23 @@ impl Table { } } +impl PartialEq for Table { + fn eq(&self, other: &Self) -> bool { + if self.len() != other.len() { + return false; + } + if self.schema != other.schema { + return false; + } + for (lhs, rhs) in self.columns.iter().zip(other.columns.iter()) { + if lhs != rhs { + return false; + } + } + true + } +} + impl Display for Table { // `f` is a buffer, and this method must write the formatted string into it fn fmt(&self, f: &mut Formatter) -> Result { diff --git a/src/daft-table/src/ops/hash.rs b/src/daft-table/src/ops/hash.rs index c011597c3f..4b4d120da0 100644 --- a/src/daft-table/src/ops/hash.rs +++ b/src/daft-table/src/ops/hash.rs @@ -1,28 +1,14 @@ -use std::{ - collections::{hash_map::RawEntryMut, HashMap}, - hash::{Hash, Hasher}, -}; +use std::collections::{hash_map::RawEntryMut, HashMap}; use common_error::{DaftError, DaftResult}; use daft_core::{ array::ops::{arrow2::comparison::build_multi_array_is_equal, as_arrow::AsArrow}, datatypes::UInt64Array, - utils::identity_hash_set::IdentityBuildHasher, + 
utils::identity_hash_set::{IdentityBuildHasher, IndexHash}, }; use crate::Table; -pub struct IndexHash { - pub idx: u64, - pub hash: u64, -} - -impl Hash for IndexHash { - fn hash(&self, state: &mut H) { - state.write_u64(self.hash); - } -} - impl Table { pub fn hash_rows(&self) -> DaftResult { if self.num_columns() == 0 { diff --git a/src/daft-table/src/probeable/probe_set.rs b/src/daft-table/src/probeable/probe_set.rs index a948ad2a4b..adf9251756 100644 --- a/src/daft-table/src/probeable/probe_set.rs +++ b/src/daft-table/src/probeable/probe_set.rs @@ -9,12 +9,12 @@ use daft_core::{ prelude::SchemaRef, utils::{ dyn_compare::{build_dyn_multi_array_compare, MultiDynArrayComparator}, - identity_hash_set::IdentityBuildHasher, + identity_hash_set::{IdentityBuildHasher, IndexHash}, }, }; use super::{ArrowTableEntry, IndicesMapper, Probeable, ProbeableBuilder}; -use crate::{ops::hash::IndexHash, Table}; +use crate::Table; pub struct ProbeSet { schema: SchemaRef, hash_table: HashMap, diff --git a/src/daft-table/src/probeable/probe_table.rs b/src/daft-table/src/probeable/probe_table.rs index c0a4bde0de..8e8c8c8647 100644 --- a/src/daft-table/src/probeable/probe_table.rs +++ b/src/daft-table/src/probeable/probe_table.rs @@ -9,12 +9,12 @@ use daft_core::{ prelude::SchemaRef, utils::{ dyn_compare::{build_dyn_multi_array_compare, MultiDynArrayComparator}, - identity_hash_set::IdentityBuildHasher, + identity_hash_set::{IdentityBuildHasher, IndexHash}, }, }; use super::{ArrowTableEntry, IndicesMapper, Probeable, ProbeableBuilder}; -use crate::{ops::hash::IndexHash, Table}; +use crate::Table; pub struct ProbeTable { schema: SchemaRef, diff --git a/src/daft-writers/Cargo.toml b/src/daft-writers/Cargo.toml new file mode 100644 index 0000000000..19cff1807d --- /dev/null +++ b/src/daft-writers/Cargo.toml @@ -0,0 +1,22 @@ +[dependencies] +common-daft-config = {path = "../common/daft-config", default-features = false} +common-error = {path = "../common/error", default-features = false} +common-file-formats = {path = "../common/file-formats", default-features = false} +daft-core = {path = "../daft-core", default-features = false} +daft-dsl = {path = "../daft-dsl", default-features = false} +daft-io = {path = "../daft-io", default-features = false} +daft-micropartition = {path = "../daft-micropartition", default-features = false} +daft-plan = {path = "../daft-plan", default-features = false} +daft-table = {path = "../daft-table", default-features = false} +pyo3 = {workspace = true, optional = true} + +[features] +python = ["dep:pyo3", "common-file-formats/python", "common-error/python", "daft-dsl/python", "daft-io/python", "daft-micropartition/python", "daft-plan/python"] + +[lints] +workspace = true + +[package] +edition = {workspace = true} +name = "daft-writers" +version = {workspace = true} diff --git a/src/daft-writers/src/batch.rs b/src/daft-writers/src/batch.rs new file mode 100644 index 0000000000..d8af2f94fb --- /dev/null +++ b/src/daft-writers/src/batch.rs @@ -0,0 +1,195 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_micropartition::MicroPartition; +use daft_table::Table; + +use crate::{FileWriter, WriterFactory}; + +// TargetBatchWriter is a writer that writes in batches of rows, i.e. for Parquet where we want to write +// a row group at a time. +pub struct TargetBatchWriter { + target_in_memory_chunk_rows: usize, + writer: Box, Result = Option
<Table>>>,
+    leftovers: Option<Arc<MicroPartition>>,
+    is_closed: bool,
+}
+
+impl TargetBatchWriter {
+    pub fn new(
+        target_in_memory_chunk_rows: usize,
+        writer: Box<dyn FileWriter<Input = Arc<MicroPartition>, Result = Option<Table>>>,
+    ) -> Self {
+        Self {
+            target_in_memory_chunk_rows,
+            writer,
+            leftovers: None,
+            is_closed: false,
+        }
+    }
+}
+
+impl FileWriter for TargetBatchWriter {
+    type Input = Arc<MicroPartition>;
+    type Result = Option<Table>
; + + fn write(&mut self, input: &Arc) -> DaftResult<()> { + assert!( + !self.is_closed, + "Cannot write to a closed TargetBatchWriter" + ); + let input = if let Some(leftovers) = self.leftovers.take() { + MicroPartition::concat([&leftovers, input])?.into() + } else { + input.clone() + }; + + let mut local_offset = 0; + loop { + let remaining_rows = input.len() - local_offset; + + use std::cmp::Ordering; + match remaining_rows.cmp(&self.target_in_memory_chunk_rows) { + Ordering::Equal => { + // Write exactly one chunk + let chunk = input.slice(local_offset, local_offset + remaining_rows)?; + return self.writer.write(&chunk.into()); + } + Ordering::Less => { + // Store remaining rows as leftovers + let remainder = input.slice(local_offset, local_offset + remaining_rows)?; + self.leftovers = Some(remainder.into()); + return Ok(()); + } + Ordering::Greater => { + // Write a complete chunk and continue + let chunk = input.slice( + local_offset, + local_offset + self.target_in_memory_chunk_rows, + )?; + self.writer.write(&chunk.into())?; + local_offset += self.target_in_memory_chunk_rows; + } + } + } + } + + fn close(&mut self) -> DaftResult { + if let Some(leftovers) = self.leftovers.take() { + self.writer.write(&leftovers)?; + } + self.is_closed = true; + self.writer.close() + } +} + +pub struct TargetBatchWriterFactory { + writer_factory: Arc, Result = Option
<Table>>>,
+    target_in_memory_chunk_rows: usize,
+}
+
+impl TargetBatchWriterFactory {
+    pub fn new(
+        writer_factory: Arc<dyn WriterFactory<Input = Arc<MicroPartition>, Result = Option<Table>>>,
+        target_in_memory_chunk_rows: usize,
+    ) -> Self {
+        Self {
+            writer_factory,
+            target_in_memory_chunk_rows,
+        }
+    }
+}
+
+impl WriterFactory for TargetBatchWriterFactory {
+    type Input = Arc<MicroPartition>;
+    type Result = Option<Table>
; + + fn create_writer( + &self, + file_idx: usize, + partition_values: Option<&Table>, + ) -> DaftResult>> { + let writer = self + .writer_factory + .create_writer(file_idx, partition_values)?; + Ok(Box::new(TargetBatchWriter::new( + self.target_in_memory_chunk_rows, + writer, + ))) + } +} + +#[cfg(test)] +mod tests { + + use super::*; + use crate::test::{make_dummy_mp, DummyWriterFactory}; + + #[test] + fn test_target_batch_writer_exact_batch() { + let dummy_writer_factory = DummyWriterFactory; + let mut writer = + TargetBatchWriter::new(1, dummy_writer_factory.create_writer(0, None).unwrap()); + + let mp = make_dummy_mp(1); + writer.write(&mp).unwrap(); + let res = writer.close().unwrap(); + + assert!(res.is_some()); + let write_count = res + .unwrap() + .get_column("write_count") + .unwrap() + .u64() + .unwrap() + .get(0) + .unwrap(); + assert_eq!(write_count, 1); + } + + #[test] + fn test_target_batch_writer_small_batches() { + let dummy_writer_factory = DummyWriterFactory; + let mut writer = + TargetBatchWriter::new(3, dummy_writer_factory.create_writer(0, None).unwrap()); + + for _ in 0..8 { + let mp = make_dummy_mp(1); + writer.write(&mp).unwrap(); + } + let res = writer.close().unwrap(); + + assert!(res.is_some()); + let write_count = res + .unwrap() + .get_column("write_count") + .unwrap() + .u64() + .unwrap() + .get(0) + .unwrap(); + assert_eq!(write_count, 3); + } + + #[test] + fn test_target_batch_writer_big_batch() { + let dummy_writer_factory = DummyWriterFactory; + let mut writer = + TargetBatchWriter::new(3, dummy_writer_factory.create_writer(0, None).unwrap()); + + let mp = make_dummy_mp(10); + writer.write(&mp).unwrap(); + let res = writer.close().unwrap(); + + assert!(res.is_some()); + let write_count = res + .unwrap() + .get_column("write_count") + .unwrap() + .u64() + .unwrap() + .get(0) + .unwrap(); + assert_eq!(write_count, 4); + } +} diff --git a/src/daft-writers/src/file.rs b/src/daft-writers/src/file.rs new file mode 100644 index 0000000000..639b720d15 --- /dev/null +++ b/src/daft-writers/src/file.rs @@ -0,0 +1,215 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_micropartition::MicroPartition; +use daft_table::Table; + +use crate::{FileWriter, WriterFactory}; + +// TargetFileSizeWriter is a writer that writes in files of a target size. +// It rotates the writer when the current file reaches the target size. +struct TargetFileSizeWriter { + current_file_rows: usize, + current_writer: Box, Result = Option
<Table>>>,
+    writer_factory: Arc<dyn WriterFactory<Input = Arc<MicroPartition>, Result = Option<Table>>>,
+    target_in_memory_file_rows: usize,
+    results: Vec<Table>,
+    partition_values: Option<Table>,
+    is_closed: bool,
+}
+
+impl TargetFileSizeWriter {
+    fn new(
+        target_in_memory_file_rows: usize,
+        writer_factory: Arc<dyn WriterFactory<Input = Arc<MicroPartition>, Result = Option<Table>>>,
+        partition_values: Option<Table>,
+    ) -> DaftResult<Self> {
+        let writer: Box<dyn FileWriter<Input = Arc<MicroPartition>, Result = Option<Table>>> =
+            writer_factory.create_writer(0, partition_values.as_ref())?;
+        Ok(Self {
+            current_file_rows: 0,
+            current_writer: writer,
+            writer_factory,
+            target_in_memory_file_rows,
+            results: vec![],
+            partition_values,
+            is_closed: false,
+        })
+    }
+
+    fn rotate_writer(&mut self) -> DaftResult<()> {
+        if let Some(result) = self.current_writer.close()? {
+            self.results.push(result);
+        }
+        self.current_writer = self
+            .writer_factory
+            .create_writer(self.results.len(), self.partition_values.as_ref())?;
+        Ok(())
+    }
+}
+
+impl FileWriter for TargetFileSizeWriter {
+    type Input = Arc<MicroPartition>;
+    type Result = Vec<Table>
; + + fn write(&mut self, input: &Arc) -> DaftResult<()> { + assert!( + !self.is_closed, + "Cannot write to a closed TargetFileSizeWriter" + ); + use std::cmp::Ordering; + + let mut local_offset = 0; + + loop { + let remaining_input_rows = input.len() - local_offset; + let rows_until_target = self.target_in_memory_file_rows - self.current_file_rows; + + match remaining_input_rows.cmp(&rows_until_target) { + Ordering::Equal => { + // Write exactly what's needed to fill the current file + let to_write = + input.slice(local_offset, local_offset + remaining_input_rows)?; + self.current_writer.write(&to_write.into())?; + self.rotate_writer()?; + self.current_file_rows = 0; + return Ok(()); + } + Ordering::Less => { + // Write remaining input and update counter + let to_write = + input.slice(local_offset, local_offset + remaining_input_rows)?; + self.current_writer.write(&to_write.into())?; + self.current_file_rows += remaining_input_rows; + return Ok(()); + } + Ordering::Greater => { + // Write what fits in current file + let to_write = input.slice(local_offset, local_offset + rows_until_target)?; + self.current_writer.write(&to_write.into())?; + self.rotate_writer()?; + self.current_file_rows = 0; + + // Update offset and continue loop + local_offset += rows_until_target; + } + } + } + } + + fn close(&mut self) -> DaftResult { + if self.current_file_rows > 0 { + if let Some(result) = self.current_writer.close()? { + self.results.push(result); + } + } + self.is_closed = true; + Ok(std::mem::take(&mut self.results)) + } +} + +pub(crate) struct TargetFileSizeWriterFactory { + writer_factory: Arc, Result = Option
<Table>>>,
+    target_in_memory_file_rows: usize,
+}
+
+impl TargetFileSizeWriterFactory {
+    pub(crate) fn new(
+        writer_factory: Arc<dyn WriterFactory<Input = Arc<MicroPartition>, Result = Option<Table>>>,
+        target_in_memory_file_rows: usize,
+    ) -> Self {
+        Self {
+            writer_factory,
+            target_in_memory_file_rows,
+        }
+    }
+}
+
+impl WriterFactory for TargetFileSizeWriterFactory {
+    type Input = Arc<MicroPartition>;
+    type Result = Vec<Table>
; + + fn create_writer( + &self, + _file_idx: usize, + partition_values: Option<&Table>, + ) -> DaftResult>> { + Ok(Box::new(TargetFileSizeWriter::new( + self.target_in_memory_file_rows, + self.writer_factory.clone(), + partition_values.cloned(), + )?) + as Box< + dyn FileWriter, + >) + } +} + +#[cfg(test)] +mod tests { + + use super::*; + use crate::test::{make_dummy_mp, DummyWriterFactory}; + + #[test] + fn test_target_file_writer_exact_file() { + let dummy_writer_factory = DummyWriterFactory; + let mut writer = + TargetFileSizeWriter::new(1, Arc::new(dummy_writer_factory), None).unwrap(); + + let mp = make_dummy_mp(1); + writer.write(&mp).unwrap(); + let res = writer.close().unwrap(); + assert_eq!(res.len(), 1); + } + + #[test] + fn test_target_file_writer_less_rows_for_one_file() { + let dummy_writer_factory = DummyWriterFactory; + let mut writer = + TargetFileSizeWriter::new(3, Arc::new(dummy_writer_factory), None).unwrap(); + + let mp = make_dummy_mp(2); + writer.write(&mp).unwrap(); + let res = writer.close().unwrap(); + assert_eq!(res.len(), 1); + } + + #[test] + fn test_target_file_writer_more_rows_for_one_file() { + let dummy_writer_factory = DummyWriterFactory; + let mut writer = + TargetFileSizeWriter::new(3, Arc::new(dummy_writer_factory), None).unwrap(); + + let mp = make_dummy_mp(4); + writer.write(&mp).unwrap(); + let res = writer.close().unwrap(); + assert_eq!(res.len(), 2); + } + + #[test] + fn test_target_file_writer_multiple_files() { + let dummy_writer_factory = DummyWriterFactory; + let mut writer = + TargetFileSizeWriter::new(3, Arc::new(dummy_writer_factory), None).unwrap(); + + let mp = make_dummy_mp(10); + writer.write(&mp).unwrap(); + let res = writer.close().unwrap(); + assert_eq!(res.len(), 4); + } + + #[test] + fn test_target_file_writer_many_writes_many_files() { + let dummy_writer_factory = DummyWriterFactory; + let mut writer = + TargetFileSizeWriter::new(3, Arc::new(dummy_writer_factory), None).unwrap(); + + for _ in 0..10 { + let mp = make_dummy_mp(1); + writer.write(&mp).unwrap(); + } + let res = writer.close().unwrap(); + assert_eq!(res.len(), 4); + } +} diff --git a/src/daft-writers/src/lib.rs b/src/daft-writers/src/lib.rs new file mode 100644 index 0000000000..405596485b --- /dev/null +++ b/src/daft-writers/src/lib.rs @@ -0,0 +1,127 @@ +#![feature(hash_raw_entry)] +#![feature(let_chains)] +mod batch; +mod file; +mod partition; +mod physical; + +#[cfg(test)] +mod test; + +#[cfg(feature = "python")] +mod python; + +use std::{cmp::min, sync::Arc}; + +use batch::TargetBatchWriterFactory; +use common_daft_config::DaftExecutionConfig; +use common_error::DaftResult; +use common_file_formats::FileFormat; +use daft_core::prelude::SchemaRef; +use daft_micropartition::MicroPartition; +use daft_plan::OutputFileInfo; +use daft_table::Table; +use file::TargetFileSizeWriterFactory; +use partition::PartitionedWriterFactory; +use physical::PhysicalWriterFactory; + +/// This trait is used to abstract the writing of data to a file. +/// The `Input` type is the type of data that will be written to the file. +/// The `Result` type is the type of the result that will be returned when the file is closed. +pub trait FileWriter: Send + Sync { + type Input; + type Result; + + /// Write data to the file. + fn write(&mut self, data: &Self::Input) -> DaftResult<()>; + + /// Close the file and return the result. The caller should NOT write to the file after calling this method. 
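+    /// Illustrative call sequence (hypothetical `writer` and `part` values): call
+    /// `writer.write(&part)?` any number of times, then a single `let result = writer.close()?`.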
+ fn close(&mut self) -> DaftResult; +} + +/// This trait is used to abstract the creation of a `FileWriter` +/// The `create_writer` method is used to create a new `FileWriter`. +/// `file_idx` is the index of the file that will be written to. +/// `partition_values` is the partition values of the data that will be written to the file. +pub trait WriterFactory: Send + Sync { + type Input; + type Result; + fn create_writer( + &self, + file_idx: usize, + partition_values: Option<&Table>, + ) -> DaftResult>>; +} + +pub fn make_writer_factory( + file_info: &OutputFileInfo, + schema: &SchemaRef, + cfg: &DaftExecutionConfig, +) -> Arc, Result = Vec
>> { + let estimated_row_size_bytes = schema.estimate_row_size_bytes(); + let base_writer_factory = PhysicalWriterFactory::new(file_info.clone()); + match file_info.file_format { + FileFormat::Parquet => { + let target_in_memory_file_size = + cfg.parquet_target_filesize as f64 * cfg.parquet_inflation_factor; + let target_in_memory_row_group_size = + cfg.parquet_target_row_group_size as f64 * cfg.parquet_inflation_factor; + + let target_file_rows = if estimated_row_size_bytes > 0.0 { + target_in_memory_file_size / estimated_row_size_bytes + } else { + target_in_memory_file_size + } as usize; + + let target_row_group_rows = min( + target_file_rows, + if estimated_row_size_bytes > 0.0 { + target_in_memory_row_group_size / estimated_row_size_bytes + } else { + target_in_memory_row_group_size + } as usize, + ); + + let row_group_writer_factory = + TargetBatchWriterFactory::new(Arc::new(base_writer_factory), target_row_group_rows); + + let file_writer_factory = TargetFileSizeWriterFactory::new( + Arc::new(row_group_writer_factory), + target_file_rows, + ); + + if let Some(partition_cols) = &file_info.partition_cols { + let partitioned_writer_factory = PartitionedWriterFactory::new( + Arc::new(file_writer_factory), + partition_cols.clone(), + ); + Arc::new(partitioned_writer_factory) + } else { + Arc::new(file_writer_factory) + } + } + FileFormat::Csv => { + let target_in_memory_file_size = + cfg.csv_target_filesize as f64 * cfg.csv_inflation_factor; + let target_file_rows = if estimated_row_size_bytes > 0.0 { + target_in_memory_file_size / estimated_row_size_bytes + } else { + target_in_memory_file_size + } as usize; + + let file_writer_factory = + TargetFileSizeWriterFactory::new(Arc::new(base_writer_factory), target_file_rows); + + if let Some(partition_cols) = &file_info.partition_cols { + let partitioned_writer_factory = PartitionedWriterFactory::new( + Arc::new(file_writer_factory), + partition_cols.clone(), + ); + Arc::new(partitioned_writer_factory) + } else { + Arc::new(file_writer_factory) + } + } + _ => unreachable!("Physical write should only support Parquet and CSV"), + } +} diff --git a/src/daft-writers/src/partition.rs b/src/daft-writers/src/partition.rs new file mode 100644 index 0000000000..73b0ab0216 --- /dev/null +++ b/src/daft-writers/src/partition.rs @@ -0,0 +1,157 @@ +use std::{ + collections::{hash_map::RawEntryMut, HashMap}, + sync::Arc, +}; + +use common_error::DaftResult; +use daft_core::{array::ops::as_arrow::AsArrow, utils::identity_hash_set::IndexHash}; +use daft_dsl::ExprRef; +use daft_io::IOStatsContext; +use daft_micropartition::MicroPartition; +use daft_table::Table; + +use crate::{FileWriter, WriterFactory}; + +/// PartitionedWriter is a writer that partitions the input data by a set of columns, and writes each partition +/// to a separate file. It uses a map to keep track of the writers for each partition. +struct PartitionedWriter { + // TODO: Figure out a way to NOT use the IndexHash + RawEntryMut pattern here. Ideally we want to store ScalarValues, aka. single Rows of the partition values as keys for the hashmap. + per_partition_writers: + HashMap, Result = Vec
<Table>>>>,
+    saved_partition_values: Vec<Table>,
+    writer_factory: Arc<dyn WriterFactory<Input = Arc<MicroPartition>, Result = Vec<Table>>>,
+    partition_by: Vec<ExprRef>,
+    is_closed: bool,
+}
+
+impl PartitionedWriter {
+    pub fn new(
+        writer_factory: Arc<dyn WriterFactory<Input = Arc<MicroPartition>, Result = Vec<Table>>>,
+        partition_by: Vec<ExprRef>,
+    ) -> Self {
+        Self {
+            per_partition_writers: HashMap::new(),
+            saved_partition_values: vec![],
+            writer_factory,
+            partition_by,
+            is_closed: false,
+        }
+    }
+
+    fn partition(
+        partition_cols: &[ExprRef],
+        data: &Arc<MicroPartition>,
+    ) -> DaftResult<(Vec<Table>, Table)> {
+        let data = data.concat_or_get(IOStatsContext::new("MicroPartition::partition_by_value"))?;
+        let table = data.first().unwrap();
+        let (split_tables, partition_values) = table.partition_by_value(partition_cols)?;
+        Ok((split_tables, partition_values))
+    }
+}
+
+impl FileWriter for PartitionedWriter {
+    type Input = Arc<MicroPartition>;
+    type Result = Vec<Table>
; + + fn write(&mut self, input: &Arc) -> DaftResult<()> { + assert!( + !self.is_closed, + "Cannot write to a closed PartitionedWriter" + ); + + let (split_tables, partition_values) = + Self::partition(self.partition_by.as_slice(), input)?; + let partition_values_hash = partition_values.hash_rows()?; + for (idx, (table, partition_value_hash)) in split_tables + .into_iter() + .zip(partition_values_hash.as_arrow().values_iter()) + .enumerate() + { + let partition_value_row = partition_values.slice(idx, idx + 1)?; + let entry = self.per_partition_writers.raw_entry_mut().from_hash( + *partition_value_hash, + |other| { + (*partition_value_hash == other.hash) && { + let other_table = + self.saved_partition_values.get(other.idx as usize).unwrap(); + other_table == &partition_value_row + } + }, + ); + match entry { + RawEntryMut::Vacant(entry) => { + let mut writer = self + .writer_factory + .create_writer(0, Some(partition_value_row.as_ref()))?; + writer.write(&Arc::new(MicroPartition::new_loaded( + table.schema.clone(), + vec![table].into(), + None, + )))?; + entry.insert_hashed_nocheck( + *partition_value_hash, + IndexHash { + idx: self.saved_partition_values.len() as u64, + hash: *partition_value_hash, + }, + writer, + ); + self.saved_partition_values.push(partition_value_row); + } + RawEntryMut::Occupied(mut entry) => { + let writer = entry.get_mut(); + writer.write(&Arc::new(MicroPartition::new_loaded( + table.schema.clone(), + vec![table].into(), + None, + )))?; + } + } + } + Ok(()) + } + + fn close(&mut self) -> DaftResult { + let mut results = vec![]; + for (_, mut writer) in self.per_partition_writers.drain() { + results.extend(writer.close()?); + } + self.is_closed = true; + Ok(results) + } +} + +pub(crate) struct PartitionedWriterFactory { + writer_factory: Arc, Result = Vec
<Table>>>,
+    partition_cols: Vec<ExprRef>,
+}
+
+impl PartitionedWriterFactory {
+    pub(crate) fn new(
+        writer_factory: Arc<dyn WriterFactory<Input = Arc<MicroPartition>, Result = Vec<Table>>>,
+        partition_cols: Vec<ExprRef>,
+    ) -> Self {
+        Self {
+            writer_factory,
+            partition_cols,
+        }
+    }
+}
+impl WriterFactory for PartitionedWriterFactory {
+    type Input = Arc<MicroPartition>;
+    type Result = Vec<Table>
; + + fn create_writer( + &self, + _file_idx: usize, + _partition_values: Option<&Table>, + ) -> DaftResult>> { + Ok(Box::new(PartitionedWriter::new( + self.writer_factory.clone(), + self.partition_cols.clone(), + )) + as Box< + dyn FileWriter, + >) + } +} diff --git a/src/daft-writers/src/physical.rs b/src/daft-writers/src/physical.rs new file mode 100644 index 0000000000..81cb3dfbfe --- /dev/null +++ b/src/daft-writers/src/physical.rs @@ -0,0 +1,77 @@ +use std::sync::Arc; + +use common_error::{DaftError, DaftResult}; +use common_file_formats::FileFormat; +use daft_micropartition::MicroPartition; +use daft_plan::OutputFileInfo; +use daft_table::Table; + +use crate::{FileWriter, WriterFactory}; + +/// PhysicalWriterFactory is a factory for creating physical writers, i.e. parquet, csv writers. +pub struct PhysicalWriterFactory { + output_file_info: OutputFileInfo, + native: bool, // TODO: Implement native writer +} + +impl PhysicalWriterFactory { + pub fn new(output_file_info: OutputFileInfo) -> Self { + Self { + output_file_info, + native: false, + } + } +} + +impl WriterFactory for PhysicalWriterFactory { + type Input = Arc; + type Result = Option
; + + fn create_writer( + &self, + file_idx: usize, + partition_values: Option<&Table>, + ) -> DaftResult>> { + match self.native { + true => unimplemented!(), + false => { + let writer = create_pyarrow_file_writer( + &self.output_file_info.root_dir, + file_idx, + &self.output_file_info.compression, + &self.output_file_info.io_config, + self.output_file_info.file_format, + partition_values, + )?; + Ok(writer) + } + } + } +} + +pub fn create_pyarrow_file_writer( + root_dir: &str, + file_idx: usize, + compression: &Option, + io_config: &Option, + format: FileFormat, + partition: Option<&Table>, +) -> DaftResult, Result = Option
>>> { + match format { + #[cfg(feature = "python")] + FileFormat::Parquet => Ok(Box::new(crate::python::PyArrowWriter::new_parquet_writer( + root_dir, + file_idx, + compression, + io_config, + partition, + )?)), + #[cfg(feature = "python")] + FileFormat::Csv => Ok(Box::new(crate::python::PyArrowWriter::new_csv_writer( + root_dir, file_idx, io_config, partition, + )?)), + _ => Err(DaftError::ComputeError( + "Unsupported file format for physical write".to_string(), + )), + } +} diff --git a/src/daft-writers/src/python.rs b/src/daft-writers/src/python.rs new file mode 100644 index 0000000000..7bcecb2b03 --- /dev/null +++ b/src/daft-writers/src/python.rs @@ -0,0 +1,118 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_micropartition::{python::PyMicroPartition, MicroPartition}; +use daft_table::{python::PyTable, Table}; +use pyo3::{types::PyAnyMethods, PyObject, Python}; + +use crate::FileWriter; + +pub struct PyArrowWriter { + py_writer: PyObject, + is_closed: bool, +} + +impl PyArrowWriter { + pub fn new_parquet_writer( + root_dir: &str, + file_idx: usize, + compression: &Option, + io_config: &Option, + partition_values: Option<&Table>, + ) -> DaftResult { + Python::with_gil(|py| { + let file_writer_module = py.import_bound(pyo3::intern!(py, "daft.io.writer"))?; + let file_writer_class = file_writer_module.getattr("ParquetFileWriter")?; + let _from_pytable = py + .import_bound(pyo3::intern!(py, "daft.table"))? + .getattr(pyo3::intern!(py, "Table"))? + .getattr(pyo3::intern!(py, "_from_pytable"))?; + let partition_values = match partition_values { + Some(pv) => { + let py_table = _from_pytable.call1((PyTable::from(pv.clone()),))?; + Some(py_table) + } + None => None, + }; + + let py_writer = file_writer_class.call1(( + root_dir, + file_idx, + partition_values, + compression.as_ref().map(|c| c.as_str()), + io_config.as_ref().map(|cfg| daft_io::python::IOConfig { + config: cfg.clone(), + }), + ))?; + Ok(Self { + py_writer: py_writer.into(), + is_closed: false, + }) + }) + } + + pub fn new_csv_writer( + root_dir: &str, + file_idx: usize, + io_config: &Option, + partition_values: Option<&Table>, + ) -> DaftResult { + Python::with_gil(|py| { + let file_writer_module = py.import_bound(pyo3::intern!(py, "daft.io.writer"))?; + let file_writer_class = file_writer_module.getattr("CSVFileWriter")?; + let _from_pytable = py + .import_bound(pyo3::intern!(py, "daft.table"))? + .getattr(pyo3::intern!(py, "Table"))? + .getattr(pyo3::intern!(py, "_from_pytable"))?; + let partition_values = match partition_values { + Some(pv) => { + let py_table = _from_pytable.call1((PyTable::from(pv.clone()),))?; + Some(py_table) + } + None => None, + }; + let py_writer = file_writer_class.call1(( + root_dir, + file_idx, + partition_values, + io_config.as_ref().map(|cfg| daft_io::python::IOConfig { + config: cfg.clone(), + }), + ))?; + Ok(Self { + py_writer: py_writer.into(), + is_closed: false, + }) + }) + } +} + +impl FileWriter for PyArrowWriter { + type Input = Arc; + type Result = Option
; + + fn write(&mut self, data: &Self::Input) -> DaftResult<()> { + assert!(!self.is_closed, "Cannot write to a closed PyArrowWriter"); + Python::with_gil(|py| { + let py_micropartition = py + .import_bound(pyo3::intern!(py, "daft.table"))? + .getattr(pyo3::intern!(py, "MicroPartition"))? + .getattr(pyo3::intern!(py, "_from_pymicropartition"))? + .call1((PyMicroPartition::from(data.clone()),))?; + self.py_writer + .call_method1(py, pyo3::intern!(py, "write"), (py_micropartition,))?; + Ok(()) + }) + } + + fn close(&mut self) -> DaftResult { + self.is_closed = true; + Python::with_gil(|py| { + let result = self + .py_writer + .call_method0(py, pyo3::intern!(py, "close"))? + .getattr(py, pyo3::intern!(py, "_table"))?; + Ok(Some(result.extract::(py)?.into())) + }) + } +} diff --git a/src/daft-writers/src/test.rs b/src/daft-writers/src/test.rs new file mode 100644 index 0000000000..f862930d3f --- /dev/null +++ b/src/daft-writers/src/test.rs @@ -0,0 +1,84 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_core::{ + prelude::{Int64Array, Schema, UInt64Array, Utf8Array}, + series::IntoSeries, +}; +use daft_micropartition::MicroPartition; +use daft_table::Table; + +use crate::{FileWriter, WriterFactory}; + +pub(crate) struct DummyWriterFactory; + +impl WriterFactory for DummyWriterFactory { + type Input = Arc; + type Result = Option
<Table>;
+
+    fn create_writer(
+        &self,
+        file_idx: usize,
+        partition_values: Option<&Table>,
+    ) -> DaftResult<Box<dyn FileWriter<Input = Self::Input, Result = Self::Result>>> {
+        Ok(Box::new(DummyWriter {
+            file_idx: file_idx.to_string(),
+            partition_values: partition_values.cloned(),
+            write_count: 0,
+        })
+            as Box<
+                dyn FileWriter<Input = Self::Input, Result = Self::Result>,
+            >)
+    }
+}
+
+pub(crate) struct DummyWriter {
+    file_idx: String,
+    partition_values: Option<Table>,
+    write_count: usize,
+}
+
+impl FileWriter for DummyWriter {
+    type Input = Arc<MicroPartition>;
+    type Result = Option<Table>
; + + fn write(&mut self, _input: &Self::Input) -> DaftResult<()> { + self.write_count += 1; + Ok(()) + } + + fn close(&mut self) -> DaftResult { + let path_series = + Utf8Array::from_values("path", std::iter::once(self.file_idx.clone())).into_series(); + let write_count_series = + UInt64Array::from_values("write_count", std::iter::once(self.write_count as u64)) + .into_series(); + let path_table = Table::new_unchecked( + Schema::new(vec![ + path_series.field().clone(), + write_count_series.field().clone(), + ]) + .unwrap(), + vec![path_series.into(), write_count_series.into()], + 1, + ); + if let Some(partition_values) = self.partition_values.take() { + let unioned = path_table.union(&partition_values)?; + Ok(Some(unioned)) + } else { + Ok(Some(path_table)) + } + } +} + +pub(crate) fn make_dummy_mp(num_rows: usize) -> Arc { + let series = + Int64Array::from_values("ints", std::iter::repeat(42).take(num_rows)).into_series(); + let schema = Arc::new(Schema::new(vec![series.field().clone()]).unwrap()); + let table = Table::new_unchecked(schema.clone(), vec![series.into()], num_rows); + Arc::new(MicroPartition::new_loaded( + schema.into(), + vec![table].into(), + None, + )) +} diff --git a/tests/benchmarks/conftest.py b/tests/benchmarks/conftest.py index 48e4cd43d9..3c49a2733f 100644 --- a/tests/benchmarks/conftest.py +++ b/tests/benchmarks/conftest.py @@ -6,6 +6,17 @@ import memray import pytest +from fsspec.implementations.local import LocalFileSystem + +import daft +from benchmarking.tpch import data_generation +from tests.assets import TPCH_DBGEN_DIR + +IS_CI = True if os.getenv("CI") else False + +SCALE_FACTOR = 0.2 +NUM_PARTS = [1] if IS_CI else [1, 2] +SOURCE_TYPES = ["in-memory"] if IS_CI else ["parquet", "in-memory"] memray_stats = defaultdict(dict) @@ -48,3 +59,69 @@ def benchmark_wrapper(func, group): return track_mem(func, group) return benchmark_wrapper + + +@pytest.fixture(scope="session", params=NUM_PARTS) +def gen_tpch(request): + # Parametrize the number of parts for each file so that we run tests on single-partition files and multi-partition files + num_parts = request.param + + csv_files_location = data_generation.gen_csv_files(TPCH_DBGEN_DIR, num_parts, SCALE_FACTOR) + parquet_files_location = data_generation.gen_parquet(csv_files_location) + + in_memory_tables = {} + for tbl_name in data_generation.SCHEMA.keys(): + arrow_table = daft.read_parquet(f"{parquet_files_location}/{tbl_name}/*").to_arrow() + in_memory_tables[tbl_name] = daft.from_arrow(arrow_table) + + sqlite_path = data_generation.gen_sqlite_db( + csv_filepath=csv_files_location, + num_parts=num_parts, + ) + + return ( + csv_files_location, + parquet_files_location, + in_memory_tables, + num_parts, + ), sqlite_path + + +@pytest.fixture(scope="module", params=SOURCE_TYPES) +def get_df(gen_tpch, request): + (csv_files_location, parquet_files_location, in_memory_tables, num_parts), _ = gen_tpch + source_type = request.param + print(f"Source Type: {source_type}") + + def _get_df(tbl_name: str): + print(f"Table Name: {tbl_name}, Source Type: {source_type}") + if source_type == "csv": + local_fs = LocalFileSystem() + nonchunked_filepath = f"{csv_files_location}/{tbl_name}.tbl" + chunked_filepath = nonchunked_filepath + ".*" + try: + local_fs.expand_path(chunked_filepath) + fp = chunked_filepath + except FileNotFoundError: + fp = nonchunked_filepath + + df = daft.read_csv( + fp, + has_headers=False, + delimiter="|", + ) + df = df.select( + *[ + daft.col(autoname).alias(colname) + for autoname, colname in 
zip(df.column_names, data_generation.SCHEMA[tbl_name]) + ] + ) + elif source_type == "parquet": + fp = f"{parquet_files_location}/{tbl_name}/*" + df = daft.read_parquet(fp) + elif source_type == "in-memory": + df = in_memory_tables[tbl_name] + + return df + + return _get_df, num_parts diff --git a/tests/benchmarks/test_local_tpch.py b/tests/benchmarks/test_local_tpch.py index 07165d9ebc..eed56f4b0e 100644 --- a/tests/benchmarks/test_local_tpch.py +++ b/tests/benchmarks/test_local_tpch.py @@ -1,13 +1,12 @@ from __future__ import annotations -import os import sys import pytest -from fsspec.implementations.local import LocalFileSystem import daft -from benchmarking.tpch import answers, data_generation +from benchmarking.tpch import answers +from tests.benchmarks.conftest import IS_CI if sys.platform == "win32": pytest.skip(allow_module_level=True) @@ -15,83 +14,9 @@ import itertools import daft.context -from tests.assets import TPCH_DBGEN_DIR -from tests.integration.conftest import * # noqa: F403 +from tests.integration.conftest import check_answer # noqa F401 -IS_CI = True if os.getenv("CI") else False - -SCALE_FACTOR = 0.2 ENGINES = ["native"] if IS_CI else ["native", "python"] -NUM_PARTS = [1] if IS_CI else [1, 2] -SOURCE_TYPES = ["in-memory"] if IS_CI else ["parquet", "in-memory"] - - -@pytest.fixture(scope="session", params=NUM_PARTS) -def gen_tpch(request): - # Parametrize the number of parts for each file so that we run tests on single-partition files and multi-partition files - num_parts = request.param - - csv_files_location = data_generation.gen_csv_files(TPCH_DBGEN_DIR, num_parts, SCALE_FACTOR) - - # Disable native executor to generate parquet files, remove once native executor supports writing parquet files - with daft.context.execution_config_ctx(enable_native_executor=False): - parquet_files_location = data_generation.gen_parquet(csv_files_location) - - in_memory_tables = {} - for tbl_name in data_generation.SCHEMA.keys(): - arrow_table = daft.read_parquet(f"{parquet_files_location}/{tbl_name}/*").to_arrow() - in_memory_tables[tbl_name] = daft.from_arrow(arrow_table) - - sqlite_path = data_generation.gen_sqlite_db( - csv_filepath=csv_files_location, - num_parts=num_parts, - ) - - return ( - csv_files_location, - parquet_files_location, - in_memory_tables, - num_parts, - ), sqlite_path - - -@pytest.fixture(scope="module", params=SOURCE_TYPES) # TODO: Enable CSV after improving the CSV reader -def get_df(gen_tpch, request): - (csv_files_location, parquet_files_location, in_memory_tables, num_parts), _ = gen_tpch - source_type = request.param - - def _get_df(tbl_name: str): - if source_type == "csv": - local_fs = LocalFileSystem() - nonchunked_filepath = f"{csv_files_location}/{tbl_name}.tbl" - chunked_filepath = nonchunked_filepath + ".*" - try: - local_fs.expand_path(chunked_filepath) - fp = chunked_filepath - except FileNotFoundError: - fp = nonchunked_filepath - - df = daft.read_csv( - fp, - has_headers=False, - delimiter="|", - ) - df = df.select( - *[ - daft.col(autoname).alias(colname) - for autoname, colname in zip(df.column_names, data_generation.SCHEMA[tbl_name]) - ] - ) - elif source_type == "parquet": - fp = f"{parquet_files_location}/{tbl_name}/*" - df = daft.read_parquet(fp) - elif source_type == "in-memory": - df = in_memory_tables[tbl_name] - - return df - - return _get_df, num_parts - TPCH_QUESTIONS = list(range(1, 11)) @@ -102,7 +27,7 @@ def _get_df(tbl_name: str): ) @pytest.mark.benchmark(group="tpch") @pytest.mark.parametrize("engine, q", 
itertools.product(ENGINES, TPCH_QUESTIONS)) -def test_tpch(tmp_path, check_answer, get_df, benchmark_with_memray, engine, q): +def test_tpch(tmp_path, check_answer, get_df, benchmark_with_memray, engine, q): # noqa F811 get_df, num_parts = get_df def f(): diff --git a/tests/benchmarks/test_streaming_writes.py b/tests/benchmarks/test_streaming_writes.py new file mode 100644 index 0000000000..eaab068101 --- /dev/null +++ b/tests/benchmarks/test_streaming_writes.py @@ -0,0 +1,68 @@ +import pytest + +import daft +from tests.benchmarks.conftest import IS_CI + +ENGINES = ["native", "python"] + + +@pytest.mark.skipif(IS_CI, reason="Write benchmarks are not run in CI") +@pytest.mark.benchmark(group="write") +@pytest.mark.parametrize("engine", ENGINES) +@pytest.mark.parametrize( + "file_type, target_file_size, target_row_group_size", + [ + ("parquet", None, None), + ( + "parquet", + 5 * 1024 * 1024, + 1024 * 1024, + ), # 5MB target file size, 1MB target row group size + ("csv", None, None), + ], +) +@pytest.mark.parametrize("partition_cols", [None, ["L_SHIPMODE"]]) +@pytest.mark.parametrize("get_df", ["in-memory"], indirect=True) +def test_streaming_write( + tmp_path, + get_df, + benchmark_with_memray, + engine, + file_type, + target_file_size, + target_row_group_size, + partition_cols, +): + get_df, num_parts = get_df + daft_df = get_df("lineitem") + + def f(): + if engine == "native": + ctx = daft.context.execution_config_ctx( + enable_native_executor=True, + parquet_target_filesize=target_file_size, + parquet_target_row_group_size=target_row_group_size, + csv_target_filesize=target_file_size, + ) + elif engine == "python": + ctx = daft.context.execution_config_ctx( + enable_native_executor=False, + parquet_target_filesize=target_file_size, + parquet_target_row_group_size=target_row_group_size, + csv_target_filesize=target_file_size, + ) + else: + raise ValueError(f"{engine} unsupported") + + with ctx: + if file_type == "parquet": + return daft_df.write_parquet(tmp_path, partition_cols=partition_cols) + elif file_type == "csv": + return daft_df.write_csv(tmp_path, partition_cols=partition_cols) + else: + raise ValueError(f"{file_type} unsupported") + + benchmark_group = f"parts-{num_parts}-partition-cols-{partition_cols}-file-type-{file_type}-target-file-size-{target_file_size}-target-row-group-size-{target_row_group_size}" + result_files = benchmark_with_memray(f, benchmark_group).to_pydict()["path"] + read_back = daft.read_parquet(result_files) if file_type == "parquet" else daft.read_csv(result_files) + assert read_back.count_rows() == daft_df.count_rows() diff --git a/tests/cookbook/test_write.py b/tests/cookbook/test_write.py index ddd2c9b040..c00e00f1ac 100644 --- a/tests/cookbook/test_write.py +++ b/tests/cookbook/test_write.py @@ -7,14 +7,9 @@ from pyarrow import dataset as pads import daft -from daft import context from tests.conftest import assert_df_equals from tests.cookbook.assets import COOKBOOK_DATA_CSV -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) PYARROW_GE_7_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (7, 0, 0) @@ -178,9 +173,8 @@ def test_parquet_write_multifile(tmp_path, smaller_parquet_target_filesize): df = daft.from_pydict(data) df2 = df.write_parquet(tmp_path) assert len(df2) > 1 - ds = pads.dataset(tmp_path, format="parquet") - readback = ds.to_table() - assert readback.to_pydict() == data + read_back = 
daft.read_parquet(tmp_path.as_posix() + "/*.parquet").sort(by="x").to_pydict() + assert read_back == data @pytest.mark.skipif( @@ -201,9 +195,7 @@ def test_parquet_write_multifile_with_partitioning(tmp_path, smaller_parquet_tar def test_parquet_write_with_some_empty_partitions(tmp_path): data = {"x": [1, 2, 3], "y": ["a", "b", "c"]} - output_files = daft.from_pydict(data).into_partitions(4).write_parquet(tmp_path) - - assert len(output_files) == 3 + daft.from_pydict(data).into_partitions(4).write_parquet(tmp_path) read_back = daft.read_parquet(tmp_path.as_posix() + "/**/*.parquet").sort("x").to_pydict() assert read_back == data @@ -286,9 +278,7 @@ def test_empty_csv_write_with_partitioning(tmp_path): def test_csv_write_with_some_empty_partitions(tmp_path): data = {"x": [1, 2, 3], "y": ["a", "b", "c"]} - output_files = daft.from_pydict(data).into_partitions(4).write_csv(tmp_path) - - assert len(output_files) == 3 + daft.from_pydict(data).into_partitions(4).write_csv(tmp_path) read_back = daft.read_csv(tmp_path.as_posix() + "/**/*.csv").sort("x").to_pydict() assert read_back == data diff --git a/tests/dataframe/test_decimals.py b/tests/dataframe/test_decimals.py index daafec29f0..3a2d11babe 100644 --- a/tests/dataframe/test_decimals.py +++ b/tests/dataframe/test_decimals.py @@ -7,12 +7,6 @@ import pytest import daft -from daft import context - -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) PYARROW_GE_7_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (7, 0, 0) diff --git a/tests/dataframe/test_temporals.py b/tests/dataframe/test_temporals.py index 9d09ce20fa..1a26b8cffc 100644 --- a/tests/dataframe/test_temporals.py +++ b/tests/dataframe/test_temporals.py @@ -9,12 +9,7 @@ import pytz import daft -from daft import DataType, col, context - -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) +from daft import DataType, col PYARROW_GE_7_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (7, 0, 0) diff --git a/tests/io/test_csv_roundtrip.py b/tests/io/test_csv_roundtrip.py index b9e8ccc9b8..dd288e6806 100644 --- a/tests/io/test_csv_roundtrip.py +++ b/tests/io/test_csv_roundtrip.py @@ -7,12 +7,8 @@ import pytest import daft -from daft import DataType, TimeUnit, context +from daft import DataType, TimeUnit -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) PYARROW_GE_11_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (11, 0, 0) diff --git a/tests/io/test_parquet_roundtrip.py b/tests/io/test_parquet_roundtrip.py index 292c5b98e1..8804ef6d33 100644 --- a/tests/io/test_parquet_roundtrip.py +++ b/tests/io/test_parquet_roundtrip.py @@ -12,13 +12,6 @@ PYARROW_GE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (8, 0, 0) -from daft import context - -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) - @pytest.mark.skipif( not PYARROW_GE_8_0_0, diff --git a/tests/io/test_s3_credentials_refresh.py b/tests/io/test_s3_credentials_refresh.py index 25c5c0cd8c..16a98fadf0 100644 --- a/tests/io/test_s3_credentials_refresh.py +++ 
b/tests/io/test_s3_credentials_refresh.py @@ -10,14 +10,8 @@ import pytest import daft -from daft import context from tests.io.mock_aws_server import start_service, stop_process -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) - @pytest.fixture(scope="session") def aws_log_file(tmp_path_factory: pytest.TempPathFactory) -> Iterator[io.IOBase]:
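
Usage sketch (a minimal, illustrative example, not part of the patch): the streaming write path added here is driven entirely through the existing Python API, mirroring `tests/benchmarks/test_streaming_writes.py`. The output directory and partition column below are placeholders.

    import daft

    df = daft.from_pydict({"x": [1, 2, 3, 4], "part": ["a", "a", "b", "b"]})

    # Small size targets so that TargetFileSizeWriter / TargetBatchWriter rotation is exercised.
    with daft.context.execution_config_ctx(
        enable_native_executor=True,
        parquet_target_filesize=5 * 1024 * 1024,
        parquet_target_row_group_size=1024 * 1024,
    ):
        result = df.write_parquet("/tmp/streaming_write_demo", partition_cols=["part"])

    # The returned DataFrame has one row per written file; its "path" column can be read back.
    paths = result.to_pydict()["path"]
    assert daft.read_parquet(paths).count_rows() == df.count_rows()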