diff --git a/Cargo.lock b/Cargo.lock index 8500d3f053..f5a1ef9c25 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1803,12 +1803,14 @@ dependencies = [ "daft-micropartition", "daft-minhash", "daft-parquet", + "daft-physical-plan", "daft-plan", "daft-scan", "daft-scheduler", "daft-sql", "daft-stats", "daft-table", + "daft-writers", "lazy_static", "libc", "log", @@ -2084,6 +2086,7 @@ dependencies = [ name = "daft-local-execution" version = "0.3.0-dev0" dependencies = [ + "async-trait", "common-daft-config", "common-display", "common-error", @@ -2102,6 +2105,7 @@ dependencies = [ "daft-plan", "daft-scan", "daft-table", + "daft-writers", "futures", "indexmap 2.5.0", "lazy_static", @@ -2363,6 +2367,22 @@ dependencies = [ "serde", ] +[[package]] +name = "daft-writers" +version = "0.3.0-dev0" +dependencies = [ + "common-daft-config", + "common-error", + "common-file-formats", + "daft-core", + "daft-dsl", + "daft-io", + "daft-micropartition", + "daft-plan", + "daft-table", + "pyo3", +] + [[package]] name = "deflate64" version = "0.1.9" diff --git a/Cargo.toml b/Cargo.toml index 0fe63cbb28..6b0b24dbe0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,12 +22,14 @@ daft-local-execution = {path = "src/daft-local-execution", default-features = fa daft-micropartition = {path = "src/daft-micropartition", default-features = false} daft-minhash = {path = "src/daft-minhash", default-features = false} daft-parquet = {path = "src/daft-parquet", default-features = false} +daft-physical-plan = {path = "src/daft-physical-plan", default-features = false} daft-plan = {path = "src/daft-plan", default-features = false} daft-scan = {path = "src/daft-scan", default-features = false} daft-scheduler = {path = "src/daft-scheduler", default-features = false} daft-sql = {path = "src/daft-sql", default-features = false} daft-stats = {path = "src/daft-stats", default-features = false} daft-table = {path = "src/daft-table", default-features = false} +daft-writers = {path = "src/daft-writers", default-features = false} lazy_static = {workspace = true} log = {workspace = true} lzma-sys = {version = "*", features = ["static"]} @@ -53,12 +55,20 @@ python = [ "daft-local-execution/python", "daft-micropartition/python", "daft-parquet/python", + "daft-physical-plan/python", "daft-plan/python", "daft-scan/python", "daft-scheduler/python", "daft-sql/python", "daft-stats/python", "daft-table/python", + "daft-functions/python", + "daft-functions-json/python", + "daft-writers/python", + "common-daft-config/python", + "common-system-info/python", + "common-display/python", + "common-resource-request/python", "dep:pyo3", "dep:pyo3-log" ] @@ -141,6 +151,7 @@ members = [ "src/daft-scheduler", "src/daft-sketch", "src/daft-sql", + "src/daft-writers", "src/daft-table", "src/hyperloglog", "src/parquet2" diff --git a/daft/io/writer.py b/daft/io/writer.py new file mode 100644 index 0000000000..a3e99046a8 --- /dev/null +++ b/daft/io/writer.py @@ -0,0 +1,159 @@ +import uuid +from abc import ABC, abstractmethod +from typing import Optional + +from daft.daft import IOConfig +from daft.dependencies import pa, pacsv, pq +from daft.filesystem import ( + _resolve_paths_and_filesystem, + canonicalize_protocol, + get_protocol_from_path, +) +from daft.series import Series +from daft.table.micropartition import MicroPartition +from daft.table.partitioning import ( + partition_strings_to_path, + partition_values_to_str_mapping, +) +from daft.table.table import Table + + +class FileWriterBase(ABC): + def __init__( + self, + root_dir: str, + file_idx: int, + 
file_format: str, + partition_values: Optional[Table] = None, + compression: Optional[str] = None, + io_config: Optional[IOConfig] = None, + default_partition_fallback: str = "__HIVE_DEFAULT_PARTITION__", + ): + [self.resolved_path], self.fs = _resolve_paths_and_filesystem(root_dir, io_config=io_config) + protocol = get_protocol_from_path(root_dir) + canonicalized_protocol = canonicalize_protocol(protocol) + is_local_fs = canonicalized_protocol == "file" + + self.file_name = f"{uuid.uuid4()}-{file_idx}.{file_format}" + self.partition_values = partition_values + if self.partition_values is not None: + partition_strings = { + key: values.to_pylist()[0] + for key, values in partition_values_to_str_mapping(self.partition_values).items() + } + self.dir_path = partition_strings_to_path(self.resolved_path, partition_strings, default_partition_fallback) + else: + self.dir_path = f"{self.resolved_path}" + + self.full_path = f"{self.dir_path}/{self.file_name}" + if is_local_fs: + self.fs.create_dir(self.dir_path, recursive=True) + + self.compression = compression if compression is not None else "none" + + @abstractmethod + def write(self, table: MicroPartition) -> None: + """Write data to the file using the appropriate writer. + + Args: + table: MicroPartition containing the data to be written. + """ + pass + + @abstractmethod + def close(self) -> Table: + """Close the writer and return metadata about the written file. Write should not be called after close. + + Returns: + Table containing metadata about the written file, including path and partition values. + """ + pass + + +class ParquetFileWriter(FileWriterBase): + def __init__( + self, + root_dir: str, + file_idx: int, + partition_values: Optional[Table] = None, + compression: str = "none", + io_config: Optional[IOConfig] = None, + ): + super().__init__( + root_dir=root_dir, + file_idx=file_idx, + file_format="parquet", + partition_values=partition_values, + compression=compression, + io_config=io_config, + ) + self.is_closed = False + self.current_writer: Optional[pq.ParquetWriter] = None + + def _create_writer(self, schema: pa.Schema) -> pq.ParquetWriter: + return pq.ParquetWriter( + self.full_path, + schema, + compression=self.compression, + use_compliant_nested_type=False, + filesystem=self.fs, + ) + + def write(self, table: MicroPartition) -> None: + assert not self.is_closed, "Cannot write to a closed ParquetFileWriter" + if self.current_writer is None: + self.current_writer = self._create_writer(table.schema().to_pyarrow_schema()) + self.current_writer.write_table(table.to_arrow()) + + def close(self) -> Table: + if self.current_writer is not None: + self.current_writer.close() + + self.is_closed = True + metadata = {"path": Series.from_pylist([self.full_path])} + if self.partition_values is not None: + for col_name in self.partition_values.column_names(): + metadata[col_name] = self.partition_values.get_column(col_name) + return Table.from_pydict(metadata) + + +class CSVFileWriter(FileWriterBase): + def __init__( + self, + root_dir: str, + file_idx: int, + partition_values: Optional[Table] = None, + io_config: Optional[IOConfig] = None, + ): + super().__init__( + root_dir=root_dir, + file_idx=file_idx, + file_format="csv", + partition_values=partition_values, + io_config=io_config, + ) + self.current_writer: Optional[pacsv.CSVWriter] = None + self.is_closed = False + + def _create_writer(self, schema: pa.Schema) -> pacsv.CSVWriter: + return pacsv.CSVWriter( + self.full_path, + schema, + ) + + def write(self, table: MicroPartition) -> 
None: + assert not self.is_closed, "Cannot write to a closed CSVFileWriter" + if self.current_writer is None: + self.current_writer = self._create_writer(table.schema().to_pyarrow_schema()) + self.current_writer.write_table(table.to_arrow()) + + def close(self) -> Table: + if self.current_writer is not None: + self.current_writer.close() + + self.is_closed = True + metadata = {"path": Series.from_pylist([self.full_path])} + if self.partition_values is not None: + for col_name in self.partition_values.column_names(): + metadata[col_name] = self.partition_values.get_column(col_name) + return Table.from_pydict(metadata) diff --git a/daft/table/partitioning.py b/daft/table/partitioning.py index 70a590cb45..2333a198e5 100644 --- a/daft/table/partitioning.py +++ b/daft/table/partitioning.py @@ -1,13 +1,16 @@ -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from daft import Series from daft.expressions import ExpressionsProjection +from daft.table.table import Table from .micropartition import MicroPartition def partition_strings_to_path( - root_path: str, parts: Dict[str, str], partition_null_fallback: str = "__HIVE_DEFAULT_PARTITION__" + root_path: str, + parts: Dict[str, str], + partition_null_fallback: str = "__HIVE_DEFAULT_PARTITION__", ) -> str: keys = parts.keys() values = [partition_null_fallback if value is None else value for value in parts.values()] @@ -15,6 +18,25 @@ def partition_strings_to_path( return f"{root_path}/{postfix}" +def partition_values_to_str_mapping( + partition_values: Union[MicroPartition, Table], +) -> Dict[str, Series]: + null_part = Series.from_pylist( + [None] + ) # This is to ensure that the null values are replaced with the default_partition_fallback value + pkey_names = partition_values.column_names() + + partition_strings = {} + + for c in pkey_names: + column = partition_values.get_column(c) + string_names = column._to_str_values() + null_filled = column.is_null().if_else(null_part, string_names) + partition_strings[c] = null_filled + + return partition_strings + + class PartitionedTable: def __init__(self, table: MicroPartition, partition_keys: Optional[ExpressionsProjection]): self.table = table @@ -56,20 +78,10 @@ def partition_values_str(self) -> Optional[MicroPartition]: If the table is not partitioned, returns None. 
""" - null_part = Series.from_pylist([None]) partition_values = self.partition_values() if partition_values is None: return None else: - pkey_names = partition_values.column_names() - - partition_strings = {} - - for c in pkey_names: - column = partition_values.get_column(c) - string_names = column._to_str_values() - null_filled = column.is_null().if_else(null_part, string_names) - partition_strings[c] = null_filled - + partition_strings = partition_values_to_str_mapping(partition_values) return MicroPartition.from_pydict(partition_strings) diff --git a/src/daft-core/src/utils/identity_hash_set.rs b/src/daft-core/src/utils/identity_hash_set.rs index cadff5f774..71e021376c 100644 --- a/src/daft-core/src/utils/identity_hash_set.rs +++ b/src/daft-core/src/utils/identity_hash_set.rs @@ -1,4 +1,4 @@ -use std::hash::{BuildHasherDefault, Hasher}; +use std::hash::{BuildHasherDefault, Hash, Hasher}; pub type IdentityBuildHasher = BuildHasherDefault; @@ -27,3 +27,14 @@ impl Hasher for IdentityHasher { self.hash = i; } } + +pub struct IndexHash { + pub idx: u64, + pub hash: u64, +} + +impl Hash for IndexHash { + fn hash(&self, state: &mut H) { + state.write_u64(self.hash); + } +} diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs index 63d1d08485..923be570ba 100644 --- a/src/daft-io/src/lib.rs +++ b/src/daft-io/src/lib.rs @@ -12,6 +12,7 @@ mod object_store_glob; mod s3_like; mod stats; mod stream_utils; + use azure_blob::AzureBlobSource; use common_file_formats::FileFormat; use google_cloud::GCSSource; diff --git a/src/daft-local-execution/Cargo.toml b/src/daft-local-execution/Cargo.toml index 8da3b93325..f0be341b6a 100644 --- a/src/daft-local-execution/Cargo.toml +++ b/src/daft-local-execution/Cargo.toml @@ -1,4 +1,5 @@ [dependencies] +async-trait = {workspace = true} common-daft-config = {path = "../common/daft-config", default-features = false} common-display = {path = "../common/display", default-features = false} common-error = {path = "../common/error", default-features = false} @@ -17,6 +18,7 @@ daft-physical-plan = {path = "../daft-physical-plan", default-features = false} daft-plan = {path = "../daft-plan", default-features = false} daft-scan = {path = "../daft-scan", default-features = false} daft-table = {path = "../daft-table", default-features = false} +daft-writers = {path = "../daft-writers", default-features = false} futures = {workspace = true} indexmap = {workspace = true} lazy_static = {workspace = true} @@ -29,7 +31,7 @@ tokio-stream = {workspace = true} tracing = {workspace = true} [features] -python = ["dep:pyo3", "common-daft-config/python", "common-file-formats/python", "common-error/python", "daft-dsl/python", "daft-io/python", "daft-micropartition/python", "daft-plan/python", "daft-scan/python", "common-display/python"] +python = ["dep:pyo3", "common-daft-config/python", "common-file-formats/python", "common-error/python", "daft-dsl/python", "daft-io/python", "daft-micropartition/python", "daft-plan/python", "daft-scan/python", "daft-writers/python", "common-display/python"] [lints] workspace = true diff --git a/src/daft-local-execution/src/buffer.rs b/src/daft-local-execution/src/buffer.rs new file mode 100644 index 0000000000..4211200182 --- /dev/null +++ b/src/daft-local-execution/src/buffer.rs @@ -0,0 +1,82 @@ +use std::{cmp::Ordering::*, collections::VecDeque, sync::Arc}; + +use common_error::DaftResult; +use daft_micropartition::MicroPartition; + +// A buffer that accumulates morsels until a threshold is reached +pub struct RowBasedBuffer { + pub buffer: 
VecDeque>, + pub curr_len: usize, + pub threshold: usize, +} + +impl RowBasedBuffer { + pub fn new(threshold: usize) -> Self { + assert!(threshold > 0); + Self { + buffer: VecDeque::new(), + curr_len: 0, + threshold, + } + } + + // Push a morsel to the buffer + pub fn push(&mut self, part: &Arc) { + self.curr_len += part.len(); + self.buffer.push_back(part.clone()); + } + + // Pop enough morsels that reach the threshold + // - If the buffer currently has not enough morsels, return None + // - If the buffer has exactly enough morsels, return the morsels + // - If the buffer has more than enough morsels, return a vec of morsels, each correctly sized to the threshold. + // The remaining morsels will be pushed back to the buffer + pub fn pop_enough(&mut self) -> DaftResult>>> { + match self.curr_len.cmp(&self.threshold) { + Less => Ok(None), + Equal => { + if self.buffer.len() == 1 { + let part = self.buffer.pop_front().unwrap(); + self.curr_len = 0; + Ok(Some(vec![part])) + } else { + let chunk = MicroPartition::concat(std::mem::take(&mut self.buffer))?; + self.curr_len = 0; + Ok(Some(vec![chunk.into()])) + } + } + Greater => { + let num_ready_chunks = self.curr_len / self.threshold; + let concated = MicroPartition::concat(std::mem::take(&mut self.buffer))?; + let mut start = 0; + let mut parts_to_return = Vec::with_capacity(num_ready_chunks); + for _ in 0..num_ready_chunks { + let end = start + self.threshold; + let part = concated.slice(start, end)?; + parts_to_return.push(part.into()); + start = end; + } + if start < concated.len() { + let part = concated.slice(start, concated.len())?; + self.curr_len = part.len(); + self.buffer.push_back(part.into()); + } else { + self.curr_len = 0; + } + Ok(Some(parts_to_return)) + } + } + } + + // Pop all morsels in the buffer regardless of the threshold + pub fn pop_all(&mut self) -> DaftResult>> { + assert!(self.curr_len < self.threshold); + if self.buffer.is_empty() { + Ok(None) + } else { + let concated = MicroPartition::concat(std::mem::take(&mut self.buffer))?; + self.curr_len = 0; + Ok(Some(concated.into())) + } + } +} diff --git a/src/daft-local-execution/src/dispatcher.rs b/src/daft-local-execution/src/dispatcher.rs new file mode 100644 index 0000000000..d21fc306b6 --- /dev/null +++ b/src/daft-local-execution/src/dispatcher.rs @@ -0,0 +1,102 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use common_error::DaftResult; +use daft_dsl::ExprRef; + +use crate::{ + buffer::RowBasedBuffer, channel::Sender, pipeline::PipelineResultType, + runtime_stats::CountingReceiver, +}; + +#[async_trait] +pub(crate) trait Dispatcher { + async fn dispatch( + &self, + receiver: CountingReceiver, + worker_senders: Vec>, + ) -> DaftResult<()>; +} + +pub(crate) struct RoundRobinBufferedDispatcher { + morsel_size: usize, +} + +impl RoundRobinBufferedDispatcher { + pub(crate) fn new(morsel_size: usize) -> Self { + Self { morsel_size } + } +} + +#[async_trait] +impl Dispatcher for RoundRobinBufferedDispatcher { + async fn dispatch( + &self, + mut receiver: CountingReceiver, + worker_senders: Vec>, + ) -> DaftResult<()> { + let mut next_worker_idx = 0; + let mut send_to_next_worker = |data: PipelineResultType| { + let next_worker_sender = worker_senders.get(next_worker_idx).unwrap(); + next_worker_idx = (next_worker_idx + 1) % worker_senders.len(); + next_worker_sender.send(data) + }; + + let mut buffer = RowBasedBuffer::new(self.morsel_size); + while let Some(morsel) = receiver.recv().await { + if morsel.should_broadcast() { + for worker_sender in 
&worker_senders { + let _ = worker_sender.send(morsel.clone()).await; + } + } else { + buffer.push(morsel.as_data()); + if let Some(ready) = buffer.pop_enough()? { + for r in ready { + let _ = send_to_next_worker(r.into()).await; + } + } + } + } + // Clear all remaining morsels + if let Some(last_morsel) = buffer.pop_all()? { + let _ = send_to_next_worker(last_morsel.into()).await; + } + Ok(()) + } +} + +pub(crate) struct PartitionedDispatcher { + partition_by: Vec, +} + +impl PartitionedDispatcher { + pub(crate) fn new(partition_by: Vec) -> Self { + Self { partition_by } + } +} + +#[async_trait] +impl Dispatcher for PartitionedDispatcher { + async fn dispatch( + &self, + mut receiver: CountingReceiver, + worker_senders: Vec>, + ) -> DaftResult<()> { + while let Some(morsel) = receiver.recv().await { + if morsel.should_broadcast() { + for worker_sender in &worker_senders { + let _ = worker_sender.send(morsel.clone()).await; + } + } else { + let partitions = morsel + .as_data() + .partition_by_hash(&self.partition_by, worker_senders.len())?; + for (partition, worker_sender) in partitions.into_iter().zip(worker_senders.iter()) + { + let _ = worker_sender.send(Arc::new(partition).into()).await; + } + } + } + Ok(()) + } +} diff --git a/src/daft-local-execution/src/intermediate_ops/buffer.rs b/src/daft-local-execution/src/intermediate_ops/buffer.rs deleted file mode 100644 index 3c66301610..0000000000 --- a/src/daft-local-execution/src/intermediate_ops/buffer.rs +++ /dev/null @@ -1,91 +0,0 @@ -use std::{ - cmp::Ordering::{Equal, Greater, Less}, - collections::VecDeque, - sync::Arc, -}; - -use common_error::DaftResult; -use daft_micropartition::MicroPartition; - -pub struct OperatorBuffer { - pub buffer: VecDeque>, - pub curr_len: usize, - pub threshold: usize, -} - -impl OperatorBuffer { - pub fn new(threshold: usize) -> Self { - assert!(threshold > 0); - Self { - buffer: VecDeque::new(), - curr_len: 0, - threshold, - } - } - - pub fn push(&mut self, part: Arc) { - self.curr_len += part.len(); - self.buffer.push_back(part); - } - - pub fn try_clear(&mut self) -> Option>> { - match self.curr_len.cmp(&self.threshold) { - Less => None, - Equal => self.clear_all(), - Greater => Some(self.clear_enough()), - } - } - - fn clear_enough(&mut self) -> DaftResult> { - assert!(self.curr_len > self.threshold); - - let mut to_concat = Vec::with_capacity(self.buffer.len()); - let mut remaining = self.threshold; - - while remaining > 0 { - let part = self.buffer.pop_front().expect("Buffer should not be empty"); - let part_len = part.len(); - if part_len <= remaining { - remaining -= part_len; - to_concat.push(part); - } else { - let (head, tail) = part.split_at(remaining)?; - remaining = 0; - to_concat.push(Arc::new(head)); - self.buffer.push_front(Arc::new(tail)); - break; - } - } - assert_eq!(remaining, 0); - - self.curr_len -= self.threshold; - match to_concat.len() { - 1 => Ok(to_concat.pop().unwrap()), - _ => MicroPartition::concat( - &to_concat - .iter() - .map(std::convert::AsRef::as_ref) - .collect::>(), - ) - .map(Arc::new), - } - } - - pub fn clear_all(&mut self) -> Option>> { - if self.buffer.is_empty() { - return None; - } - - let concated = MicroPartition::concat( - &self - .buffer - .iter() - .map(std::convert::AsRef::as_ref) - .collect::>(), - ) - .map(Arc::new); - self.buffer.clear(); - self.curr_len = 0; - Some(concated) - } -} diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index 
f8b3acbd4c..d7503fa33c 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -6,8 +6,8 @@ use common_runtime::get_compute_runtime; use daft_micropartition::MicroPartition; use tracing::{info_span, instrument}; -use super::buffer::OperatorBuffer; use crate::{ + buffer::RowBasedBuffer, channel::{create_channel, PipelineChannel, Receiver, Sender}, pipeline::{PipelineNode, PipelineResultType}, runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, @@ -185,26 +185,23 @@ impl IntermediateNode { }; for (idx, mut receiver) in receivers.into_iter().enumerate() { - let mut buffer = OperatorBuffer::new(morsel_size); + let mut buffer = RowBasedBuffer::new(morsel_size); while let Some(morsel) = receiver.recv().await { if morsel.should_broadcast() { for worker_sender in &worker_senders { let _ = worker_sender.send((idx, morsel.clone())).await; } } else { - buffer.push(morsel.as_data().clone()); - if let Some(ready) = buffer.try_clear() { - let _ = send_to_next_worker(idx, ready?.into()).await; + buffer.push(morsel.as_data()); + if let Some(ready) = buffer.pop_enough()? { + for part in ready { + let _ = send_to_next_worker(idx, part.into()).await; + } } } } - // Buffer may still have some morsels left above the threshold - while let Some(ready) = buffer.try_clear() { - let _ = send_to_next_worker(idx, ready?.into()).await; - } - // Clear all remaining morsels - if let Some(last_morsel) = buffer.clear_all() { - let _ = send_to_next_worker(idx, last_morsel?.into()).await; + if let Some(ready) = buffer.pop_all()? { + let _ = send_to_next_worker(idx, ready.into()).await; } } Ok(()) diff --git a/src/daft-local-execution/src/intermediate_ops/mod.rs b/src/daft-local-execution/src/intermediate_ops/mod.rs index 5c935a24ed..336f6ebc92 100644 --- a/src/daft-local-execution/src/intermediate_ops/mod.rs +++ b/src/daft-local-execution/src/intermediate_ops/mod.rs @@ -1,7 +1,6 @@ pub mod actor_pool_project; pub mod aggregate; pub mod anti_semi_hash_join_probe; -pub mod buffer; pub mod explode; pub mod filter; pub mod inner_hash_join_probe; diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs index e9b7e08b96..553ad18b40 100644 --- a/src/daft-local-execution/src/lib.rs +++ b/src/daft-local-execution/src/lib.rs @@ -1,5 +1,7 @@ #![feature(let_chains)] +mod buffer; mod channel; +mod dispatcher; mod intermediate_ops; mod pipeline; mod run; diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index cb84dacde4..b371af0c8a 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -1,7 +1,9 @@ use std::{collections::HashMap, sync::Arc}; +use common_daft_config::DaftExecutionConfig; use common_display::{mermaid::MermaidDisplayVisitor, tree::TreeDisplay}; use common_error::DaftResult; +use common_file_formats::FileFormat; use daft_core::{ datatypes::Field, prelude::{Schema, SchemaRef}, @@ -11,10 +13,12 @@ use daft_dsl::{col, join::get_common_join_keys, Expr}; use daft_micropartition::MicroPartition; use daft_physical_plan::{ ActorPoolProject, Concat, EmptyScan, Explode, Filter, HashAggregate, HashJoin, InMemoryScan, - Limit, LocalPhysicalPlan, Pivot, Project, Sample, Sort, UnGroupedAggregate, Unpivot, + Limit, LocalPhysicalPlan, PhysicalWrite, Pivot, Project, Sample, Sort, UnGroupedAggregate, + Unpivot, }; use daft_plan::{populate_aggregation_stages, JoinType}; use daft_table::ProbeState; +use 
daft_writers::make_writer_factory; use indexmap::IndexSet; use snafu::ResultExt; @@ -28,10 +32,15 @@ use crate::{ sample::SampleOperator, unpivot::UnpivotOperator, }, sinks::{ - aggregate::AggregateSink, blocking_sink::BlockingSinkNode, concat::ConcatSink, - hash_join_build::HashJoinBuildSink, limit::LimitSink, - outer_hash_join_probe::OuterHashJoinProbeSink, sort::SortSink, + aggregate::AggregateSink, + blocking_sink::BlockingSinkNode, + concat::ConcatSink, + hash_join_build::HashJoinBuildSink, + limit::LimitSink, + outer_hash_join_probe::OuterHashJoinProbeSink, + sort::SortSink, streaming_sink::StreamingSinkNode, + write::{WriteFormat, WriteSink}, }, sources::{empty_scan::EmptyScanSource, in_memory::InMemorySource}, ExecutionRuntimeHandle, PipelineCreationSnafu, @@ -102,6 +111,7 @@ pub fn viz_pipeline(root: &dyn PipelineNode) -> String { pub fn physical_plan_to_pipeline( physical_plan: &LocalPhysicalPlan, psets: &HashMap>>, + cfg: &Arc, ) -> crate::Result> { use daft_physical_plan::PhysicalScan; @@ -125,14 +135,14 @@ pub fn physical_plan_to_pipeline( input, projection, .. }) => { let proj_op = ProjectOperator::new(projection.clone()); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; IntermediateNode::new(Arc::new(proj_op), vec![child_node]).boxed() } LocalPhysicalPlan::ActorPoolProject(ActorPoolProject { input, projection, .. }) => { let proj_op = ActorPoolProjectOperator::new(projection.clone()); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; IntermediateNode::new(Arc::new(proj_op), vec![child_node]).boxed() } LocalPhysicalPlan::Sample(Sample { @@ -143,33 +153,33 @@ pub fn physical_plan_to_pipeline( .. }) => { let sample_op = SampleOperator::new(*fraction, *with_replacement, *seed); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; IntermediateNode::new(Arc::new(sample_op), vec![child_node]).boxed() } LocalPhysicalPlan::Filter(Filter { input, predicate, .. }) => { let filter_op = FilterOperator::new(predicate.clone()); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; IntermediateNode::new(Arc::new(filter_op), vec![child_node]).boxed() } LocalPhysicalPlan::Explode(Explode { input, to_explode, .. }) => { let explode_op = ExplodeOperator::new(to_explode.clone()); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; IntermediateNode::new(Arc::new(explode_op), vec![child_node]).boxed() } LocalPhysicalPlan::Limit(Limit { input, num_rows, .. }) => { let sink = LimitSink::new(*num_rows as usize); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; StreamingSinkNode::new(Arc::new(sink), vec![child_node]).boxed() } LocalPhysicalPlan::Concat(Concat { input, other, .. 
}) => { - let left_child = physical_plan_to_pipeline(input, psets)?; - let right_child = physical_plan_to_pipeline(other, psets)?; + let left_child = physical_plan_to_pipeline(input, psets, cfg)?; + let right_child = physical_plan_to_pipeline(other, psets, cfg)?; let sink = ConcatSink {}; StreamingSinkNode::new(Arc::new(sink), vec![left_child, right_child]).boxed() } @@ -189,7 +199,7 @@ pub fn physical_plan_to_pipeline( .collect(), vec![], ); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; let post_first_agg_node = IntermediateNode::new(Arc::new(first_stage_agg_op), vec![child_node]).boxed(); @@ -202,7 +212,7 @@ pub fn physical_plan_to_pipeline( vec![], ); let second_stage_node = - BlockingSinkNode::new(second_stage_agg_sink.boxed(), post_first_agg_node).boxed(); + BlockingSinkNode::new(Arc::new(second_stage_agg_sink), post_first_agg_node).boxed(); let final_stage_project = ProjectOperator::new(final_exprs); @@ -217,7 +227,7 @@ pub fn physical_plan_to_pipeline( }) => { let (first_stage_aggs, second_stage_aggs, final_exprs) = populate_aggregation_stages(aggregations, schema, group_by); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; let (post_first_agg_node, group_by) = if !first_stage_aggs.is_empty() { let agg_op = AggregateOperator::new( first_stage_aggs @@ -244,7 +254,7 @@ pub fn physical_plan_to_pipeline( group_by.clone(), ); let second_stage_node = - BlockingSinkNode::new(second_stage_agg_sink.boxed(), post_first_agg_node).boxed(); + BlockingSinkNode::new(Arc::new(second_stage_agg_sink), post_first_agg_node).boxed(); let final_stage_project = ProjectOperator::new(final_exprs); @@ -258,7 +268,7 @@ pub fn physical_plan_to_pipeline( value_name, .. }) => { - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; let unpivot_op = UnpivotOperator::new( ids.clone(), values.clone(), @@ -281,7 +291,7 @@ pub fn physical_plan_to_pipeline( value_column.clone(), names.clone(), ); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; IntermediateNode::new(Arc::new(pivot_op), vec![child_node]).boxed() } LocalPhysicalPlan::Sort(Sort { @@ -291,8 +301,8 @@ pub fn physical_plan_to_pipeline( .. 
}) => { let sort_sink = SortSink::new(sort_by.clone(), descending.clone()); - let child_node = physical_plan_to_pipeline(input, psets)?; - BlockingSinkNode::new(sort_sink.boxed(), child_node).boxed() + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; + BlockingSinkNode::new(Arc::new(sort_sink), child_node).boxed() } LocalPhysicalPlan::HashJoin(HashJoin { @@ -361,11 +371,11 @@ pub fn physical_plan_to_pipeline( // we should move to a builder pattern let build_sink = HashJoinBuildSink::new(key_schema, casted_build_on, join_type)?; - let build_child_node = physical_plan_to_pipeline(build_child, psets)?; + let build_child_node = physical_plan_to_pipeline(build_child, psets, cfg)?; let build_node = - BlockingSinkNode::new(build_sink.boxed(), build_child_node).boxed(); + BlockingSinkNode::new(Arc::new(build_sink), build_child_node).boxed(); - let probe_child_node = physical_plan_to_pipeline(probe_child, psets)?; + let probe_child_node = physical_plan_to_pipeline(probe_child, psets, cfg)?; match join_type { JoinType::Anti | JoinType::Semi => Ok(IntermediateNode::new( @@ -409,8 +419,29 @@ pub fn physical_plan_to_pipeline( plan_name: physical_plan.name(), })? } - _ => { - unimplemented!("Physical plan not supported: {}", physical_plan.name()); + LocalPhysicalPlan::PhysicalWrite(PhysicalWrite { + input, + file_info, + data_schema, + file_schema, + .. + }) => { + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; + let writer_factory = make_writer_factory(file_info, data_schema, cfg); + let write_format = match (file_info.file_format, file_info.partition_cols.is_some()) { + (FileFormat::Parquet, true) => WriteFormat::PartitionedParquet, + (FileFormat::Parquet, false) => WriteFormat::Parquet, + (FileFormat::Csv, true) => WriteFormat::PartitionedCsv, + (FileFormat::Csv, false) => WriteFormat::Csv, + (_, _) => panic!("Unsupported file format"), + }; + let write_sink = WriteSink::new( + write_format, + writer_factory, + file_info.partition_cols.clone(), + file_schema.clone(), + ); + BlockingSinkNode::new(Arc::new(write_sink), child_node).boxed() } }; diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs index ae0939ef8a..3cab41ef36 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -118,7 +118,7 @@ pub fn run_local( results_buffer_size: Option, ) -> DaftResult>> + Send>> { refresh_chrome_trace(); - let mut pipeline = physical_plan_to_pipeline(physical_plan, &psets)?; + let mut pipeline = physical_plan_to_pipeline(physical_plan, &psets, &cfg)?; let (tx, rx) = create_channel(results_buffer_size.unwrap_or(1)); let handle = std::thread::spawn(move || { let runtime = tokio::runtime::Builder::new_current_thread() diff --git a/src/daft-local-execution/src/sinks/aggregate.rs b/src/daft-local-execution/src/sinks/aggregate.rs index e94ff7c68b..abc8acce4c 100644 --- a/src/daft-local-execution/src/sinks/aggregate.rs +++ b/src/daft-local-execution/src/sinks/aggregate.rs @@ -5,19 +5,43 @@ use daft_dsl::ExprRef; use daft_micropartition::MicroPartition; use tracing::instrument; -use super::blocking_sink::{BlockingSink, BlockingSinkStatus}; -use crate::pipeline::PipelineResultType; +use super::blocking_sink::{BlockingSink, BlockingSinkState, BlockingSinkStatus}; +use crate::{pipeline::PipelineResultType, NUM_CPUS}; enum AggregateState { Accumulating(Vec>), - #[allow(dead_code)] - Done(Arc), + Done, +} + +impl AggregateState { + fn push(&mut self, part: Arc) { + if let Self::Accumulating(ref mut parts) = self { + 
parts.push(part); + } else { + panic!("AggregateSink should be in Accumulating state"); + } + } + + fn finalize(&mut self) -> Vec> { + let res = if let Self::Accumulating(ref mut parts) = self { + std::mem::take(parts) + } else { + panic!("AggregateSink should be in Accumulating state"); + }; + *self = Self::Done; + res + } +} + +impl BlockingSinkState for AggregateState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } } pub struct AggregateSink { agg_exprs: Vec, group_by: Vec, - state: AggregateState, } impl AggregateSink { @@ -25,47 +49,51 @@ impl AggregateSink { Self { agg_exprs, group_by, - state: AggregateState::Accumulating(vec![]), } } - - pub fn boxed(self) -> Box { - Box::new(self) - } } impl BlockingSink for AggregateSink { #[instrument(skip_all, name = "AggregateSink::sink")] - fn sink(&mut self, input: &Arc) -> DaftResult { - if let AggregateState::Accumulating(parts) = &mut self.state { - parts.push(input.clone()); - Ok(BlockingSinkStatus::NeedMoreInput) - } else { - panic!("AggregateSink should be in Accumulating state"); - } + fn sink( + &self, + input: &Arc, + mut state: Box, + ) -> DaftResult { + state + .as_any_mut() + .downcast_mut::() + .expect("AggregateSink should have AggregateState") + .push(input.clone()); + Ok(BlockingSinkStatus::NeedMoreInput(state)) } #[instrument(skip_all, name = "AggregateSink::finalize")] - fn finalize(&mut self) -> DaftResult> { - if let AggregateState::Accumulating(parts) = &mut self.state { - assert!( - !parts.is_empty(), - "We can not finalize AggregateSink with no data" - ); - let concated = MicroPartition::concat( - &parts - .iter() - .map(std::convert::AsRef::as_ref) - .collect::>(), - )?; - let agged = Arc::new(concated.agg(&self.agg_exprs, &self.group_by)?); - self.state = AggregateState::Done(agged.clone()); - Ok(Some(agged.into())) - } else { - panic!("AggregateSink should be in Accumulating state"); - } + fn finalize( + &self, + states: Vec>, + ) -> DaftResult> { + let all_parts = states.into_iter().flat_map(|mut state| { + state + .as_any_mut() + .downcast_mut::() + .expect("AggregateSink should have AggregateState") + .finalize() + }); + let concated = MicroPartition::concat(all_parts)?; + let agged = Arc::new(concated.agg(&self.agg_exprs, &self.group_by)?); + Ok(Some(agged.into())) } + fn name(&self) -> &'static str { "AggregateSink" } + + fn max_concurrency(&self) -> usize { + *NUM_CPUS + } + + fn make_state(&self) -> DaftResult> { + Ok(Box::new(AggregateState::Accumulating(vec![]))) + } } diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index 3fcbf8d660..7835fbe138 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -4,39 +4,58 @@ use common_display::tree::TreeDisplay; use common_error::DaftResult; use common_runtime::get_compute_runtime; use daft_micropartition::MicroPartition; -use tracing::info_span; +use snafu::ResultExt; +use tracing::{info_span, instrument}; use crate::{ - channel::PipelineChannel, + channel::{create_channel, PipelineChannel, Receiver}, + dispatcher::{Dispatcher, RoundRobinBufferedDispatcher}, pipeline::{PipelineNode, PipelineResultType}, runtime_stats::RuntimeStatsContext, - ExecutionRuntimeHandle, + ExecutionRuntimeHandle, JoinSnafu, TaskSet, }; +pub trait BlockingSinkState: Send + Sync { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any; +} + pub enum BlockingSinkStatus { - NeedMoreInput, + NeedMoreInput(Box), #[allow(dead_code)] - 
Finished, + Finished(Box), } pub trait BlockingSink: Send + Sync { - fn sink(&mut self, input: &Arc) -> DaftResult; - fn finalize(&mut self) -> DaftResult>; + fn sink( + &self, + input: &Arc, + state: Box, + ) -> DaftResult; + fn finalize( + &self, + states: Vec>, + ) -> DaftResult>; fn name(&self) -> &'static str; + fn make_state(&self) -> DaftResult>; + fn make_dispatcher(&self, runtime_handle: &ExecutionRuntimeHandle) -> Arc { + Arc::new(RoundRobinBufferedDispatcher::new( + runtime_handle.default_morsel_size(), + )) + } + fn max_concurrency(&self) -> usize; } pub struct BlockingSinkNode { - // use a RW lock - op: Arc>>, + op: Arc, name: &'static str, child: Box, runtime_stats: Arc, } impl BlockingSinkNode { - pub(crate) fn new(op: Box, child: Box) -> Self { + pub(crate) fn new(op: Arc, child: Box) -> Self { let name = op.name(); Self { - op: Arc::new(tokio::sync::Mutex::new(op)), + op, name, child, runtime_stats: RuntimeStatsContext::new(), @@ -45,6 +64,46 @@ impl BlockingSinkNode { pub(crate) fn boxed(self) -> Box { Box::new(self) } + + #[instrument(level = "info", skip_all, name = "BlockingSink::run_worker")] + async fn run_worker( + op: Arc, + mut input_receiver: Receiver, + rt_context: Arc, + ) -> DaftResult> { + let span = info_span!("BlockingSink::Sink"); + let compute_runtime = get_compute_runtime(); + let mut state = op.make_state()?; + while let Some(morsel) = input_receiver.recv().await { + let op = op.clone(); + let morsel = morsel.clone(); + let span = span.clone(); + let rt_context = rt_context.clone(); + let fut = async move { rt_context.in_span(&span, || op.sink(morsel.as_data(), state)) }; + let result = compute_runtime.await_on(fut).await??; + match result { + BlockingSinkStatus::NeedMoreInput(new_state) => { + state = new_state; + } + BlockingSinkStatus::Finished(new_state) => { + return Ok(new_state); + } + } + } + + Ok(state) + } + + fn spawn_workers( + op: Arc, + input_receivers: Vec>, + task_set: &mut TaskSet>>, + stats: Arc, + ) { + for input_receiver in input_receivers { + task_set.spawn(Self::run_worker(op.clone(), input_receiver, stats.clone())); + } + } } impl TreeDisplay for BlockingSinkNode { @@ -81,7 +140,7 @@ impl PipelineNode for BlockingSinkNode { runtime_handle: &mut ExecutionRuntimeHandle, ) -> crate::Result { let child = self.child.as_mut(); - let mut child_results_receiver = child + let child_results_receiver = child .start(false, runtime_handle)? 
.get_receiver_with_stats(&self.runtime_stats); @@ -89,34 +148,45 @@ impl PipelineNode for BlockingSinkNode { let destination_sender = destination_channel.get_next_sender_with_stats(&self.runtime_stats); let op = self.op.clone(); - let rt_context = self.runtime_stats.clone(); + let runtime_stats = self.runtime_stats.clone(); + let num_workers = op.max_concurrency(); + let (input_senders, input_receivers) = (0..num_workers).map(|_| create_channel(1)).unzip(); + let dispatcher = op.make_dispatcher(runtime_handle); runtime_handle.spawn( async move { - let span = info_span!("BlockingSinkNode::execute"); - let compute_runtime = get_compute_runtime(); - while let Some(val) = child_results_receiver.recv().await { - let op = op.clone(); - let span = span.clone(); - let rt_context = rt_context.clone(); - let fut = async move { - let mut guard = op.lock().await; - rt_context.in_span(&span, || guard.sink(val.as_data())) - }; - let result = compute_runtime.await_on(fut).await??; - if matches!(result, BlockingSinkStatus::Finished) { - break; - } + dispatcher + .dispatch(child_results_receiver, input_senders) + .await + }, + self.name(), + ); + + runtime_handle.spawn( + async move { + let mut task_set = TaskSet::new(); + Self::spawn_workers( + op.clone(), + input_receivers, + &mut task_set, + runtime_stats.clone(), + ); + + let mut finished_states = Vec::with_capacity(num_workers); + while let Some(result) = task_set.join_next().await { + let state = result.context(JoinSnafu)??; + finished_states.push(state); } + + let compute_runtime = get_compute_runtime(); let finalized_result = compute_runtime .await_on(async move { - let mut guard = op.lock().await; - rt_context.in_span(&info_span!("BlockingSinkNode::finalize"), || { - guard.finalize() + runtime_stats.in_span(&info_span!("BlockingSinkNode::finalize"), || { + op.finalize(finished_states) }) }) .await??; - if let Some(part) = finalized_result { - let _ = destination_sender.send(part).await; + if let Some(res) = finalized_result { + let _ = destination_sender.send(res).await; } Ok(()) }, diff --git a/src/daft-local-execution/src/sinks/hash_join_build.rs b/src/daft-local-execution/src/sinks/hash_join_build.rs index c8258e281a..677f63279d 100644 --- a/src/daft-local-execution/src/sinks/hash_join_build.rs +++ b/src/daft-local-execution/src/sinks/hash_join_build.rs @@ -7,7 +7,7 @@ use daft_micropartition::MicroPartition; use daft_plan::JoinType; use daft_table::{make_probeable_builder, ProbeState, ProbeableBuilder, Table}; -use super::blocking_sink::{BlockingSink, BlockingSinkStatus}; +use super::blocking_sink::{BlockingSink, BlockingSinkState, BlockingSinkStatus}; use crate::pipeline::PipelineResultType; enum ProbeTableState { @@ -74,8 +74,16 @@ impl ProbeTableState { } } +impl BlockingSinkState for ProbeTableState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} + pub struct HashJoinBuildSink { - probe_table_state: ProbeTableState, + key_schema: SchemaRef, + projection: Vec, + join_type: JoinType, } impl HashJoinBuildSink { @@ -85,13 +93,11 @@ impl HashJoinBuildSink { join_type: &JoinType, ) -> DaftResult { Ok(Self { - probe_table_state: ProbeTableState::new(&key_schema, projection, join_type)?, + key_schema, + projection, + join_type: *join_type, }) } - - pub(crate) fn boxed(self) -> Box { - Box::new(self) - } } impl BlockingSink for HashJoinBuildSink { @@ -99,17 +105,46 @@ impl BlockingSink for HashJoinBuildSink { "HashJoinBuildSink" } - fn sink(&mut self, input: &Arc) -> DaftResult { - self.probe_table_state.add_tables(input)?; - 
Ok(BlockingSinkStatus::NeedMoreInput) + fn sink( + &self, + input: &Arc, + mut state: Box, + ) -> DaftResult { + state + .as_any_mut() + .downcast_mut::() + .expect("HashJoinBuildSink should have ProbeTableState") + .add_tables(input)?; + Ok(BlockingSinkStatus::NeedMoreInput(state)) } - fn finalize(&mut self) -> DaftResult> { - self.probe_table_state.finalize()?; - if let ProbeTableState::Done { probe_state } = &self.probe_table_state { + fn finalize( + &self, + states: Vec>, + ) -> DaftResult> { + assert_eq!(states.len(), 1); + let mut state = states.into_iter().next().unwrap(); + let probe_table_state = state + .as_any_mut() + .downcast_mut::() + .expect("State type mismatch"); + probe_table_state.finalize()?; + if let ProbeTableState::Done { probe_state } = probe_table_state { Ok(Some(probe_state.clone().into())) } else { panic!("finalize should only be called after the probe table is built") } } + + fn max_concurrency(&self) -> usize { + 1 + } + + fn make_state(&self) -> DaftResult> { + Ok(Box::new(ProbeTableState::new( + &self.key_schema, + self.projection.clone(), + &self.join_type, + )?)) + } } diff --git a/src/daft-local-execution/src/sinks/mod.rs b/src/daft-local-execution/src/sinks/mod.rs index 7960e55a7c..64366385c3 100644 --- a/src/daft-local-execution/src/sinks/mod.rs +++ b/src/daft-local-execution/src/sinks/mod.rs @@ -6,3 +6,4 @@ pub mod limit; pub mod outer_hash_join_probe; pub mod sort; pub mod streaming_sink; +pub mod write; diff --git a/src/daft-local-execution/src/sinks/sort.rs b/src/daft-local-execution/src/sinks/sort.rs index 169ea9e55d..83c933d1ec 100644 --- a/src/daft-local-execution/src/sinks/sort.rs +++ b/src/daft-local-execution/src/sinks/sort.rs @@ -5,18 +5,42 @@ use daft_dsl::ExprRef; use daft_micropartition::MicroPartition; use tracing::instrument; -use super::blocking_sink::{BlockingSink, BlockingSinkStatus}; -use crate::pipeline::PipelineResultType; -pub struct SortSink { - sort_by: Vec, - descending: Vec, - state: SortState, -} +use super::blocking_sink::{BlockingSink, BlockingSinkState, BlockingSinkStatus}; +use crate::{pipeline::PipelineResultType, NUM_CPUS}; enum SortState { Building(Vec>), - #[allow(dead_code)] - Done(Arc), + Done, +} + +impl SortState { + fn push(&mut self, part: Arc) { + if let Self::Building(ref mut parts) = self { + parts.push(part); + } else { + panic!("SortSink should be in Building state"); + } + } + + fn finalize(&mut self) -> Vec> { + let res = if let Self::Building(ref mut parts) = self { + std::mem::take(parts) + } else { + panic!("SortSink should be in Building state"); + }; + *self = Self::Done; + res + } +} + +impl BlockingSinkState for SortState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} +pub struct SortSink { + sort_by: Vec, + descending: Vec, } impl SortSink { @@ -24,46 +48,51 @@ impl SortSink { Self { sort_by, descending, - state: SortState::Building(vec![]), } } - pub fn boxed(self) -> Box { - Box::new(self) - } } impl BlockingSink for SortSink { #[instrument(skip_all, name = "SortSink::sink")] - fn sink(&mut self, input: &Arc) -> DaftResult { - if let SortState::Building(parts) = &mut self.state { - parts.push(input.clone()); - } else { - panic!("SortSink should be in Building state"); - } - Ok(BlockingSinkStatus::NeedMoreInput) + fn sink( + &self, + input: &Arc, + mut state: Box, + ) -> DaftResult { + state + .as_any_mut() + .downcast_mut::() + .expect("SortSink should have sort state") + .push(input.clone()); + Ok(BlockingSinkStatus::NeedMoreInput(state)) } #[instrument(skip_all, name = 
"SortSink::finalize")] - fn finalize(&mut self) -> DaftResult> { - if let SortState::Building(parts) = &mut self.state { - assert!( - !parts.is_empty(), - "We can not finalize SortSink with no data" - ); - let concated = MicroPartition::concat( - &parts - .iter() - .map(std::convert::AsRef::as_ref) - .collect::>(), - )?; - let sorted = Arc::new(concated.sort(&self.sort_by, &self.descending)?); - self.state = SortState::Done(sorted.clone()); - Ok(Some(sorted.into())) - } else { - panic!("SortSink should be in Building state"); - } + fn finalize( + &self, + states: Vec>, + ) -> DaftResult> { + let parts = states.into_iter().flat_map(|mut state| { + let state = state + .as_any_mut() + .downcast_mut::() + .expect("State type mismatch"); + state.finalize() + }); + let concated = MicroPartition::concat(parts)?; + let sorted = Arc::new(concated.sort(&self.sort_by, &self.descending)?); + Ok(Some(sorted.into())) } + fn name(&self) -> &'static str { "SortResult" } + + fn make_state(&self) -> DaftResult> { + Ok(Box::new(SortState::Building(Vec::new()))) + } + + fn max_concurrency(&self) -> usize { + *NUM_CPUS + } } diff --git a/src/daft-local-execution/src/sinks/write.rs b/src/daft-local-execution/src/sinks/write.rs new file mode 100644 index 0000000000..002f32a25a --- /dev/null +++ b/src/daft-local-execution/src/sinks/write.rs @@ -0,0 +1,137 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_core::prelude::SchemaRef; +use daft_dsl::ExprRef; +use daft_micropartition::MicroPartition; +use daft_table::Table; +use daft_writers::{FileWriter, WriterFactory}; +use tracing::instrument; + +use super::blocking_sink::{BlockingSink, BlockingSinkState, BlockingSinkStatus}; +use crate::{ + dispatcher::{Dispatcher, PartitionedDispatcher, RoundRobinBufferedDispatcher}, + pipeline::PipelineResultType, + NUM_CPUS, +}; + +pub enum WriteFormat { + Parquet, + PartitionedParquet, + Csv, + PartitionedCsv, +} + +struct WriteState { + writer: Box, Result = Vec>>, +} + +impl WriteState { + pub fn new( + writer: Box, Result = Vec
<Table>>>, + ) -> Self { + Self { writer } + } +} + +impl BlockingSinkState for WriteState { + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } +} + +pub(crate) struct WriteSink { + write_format: WriteFormat, + writer_factory: Arc<dyn WriterFactory<Input = Arc<MicroPartition>, Result = Vec<Table>>>, + partition_by: Option<Vec<ExprRef>>, + file_schema: SchemaRef, +} + +impl WriteSink { + pub(crate) fn new( + write_format: WriteFormat, + writer_factory: Arc<dyn WriterFactory<Input = Arc<MicroPartition>, Result = Vec<Table>
>>, + partition_by: Option>, + file_schema: SchemaRef, + ) -> Self { + Self { + write_format, + writer_factory, + partition_by, + file_schema, + } + } +} + +impl BlockingSink for WriteSink { + #[instrument(skip_all, name = "WriteSink::sink")] + fn sink( + &self, + input: &Arc, + mut state: Box, + ) -> DaftResult { + state + .as_any_mut() + .downcast_mut::() + .expect("WriteSink should have WriteState") + .writer + .write(input)?; + Ok(BlockingSinkStatus::NeedMoreInput(state)) + } + + #[instrument(skip_all, name = "WriteSink::finalize")] + fn finalize( + &self, + states: Vec>, + ) -> DaftResult> { + let mut results = vec![]; + for mut state in states { + let state = state + .as_any_mut() + .downcast_mut::() + .expect("State type mismatch"); + results.extend(state.writer.close()?); + } + let mp = Arc::new(MicroPartition::new_loaded( + self.file_schema.clone(), + results.into(), + None, + )); + Ok(Some(mp.into())) + } + + fn name(&self) -> &'static str { + match self.write_format { + WriteFormat::Parquet => "ParquetSink", + WriteFormat::PartitionedParquet => "PartitionedParquetSink", + WriteFormat::Csv => "CsvSink", + WriteFormat::PartitionedCsv => "PartitionedCsvSink", + } + } + + fn make_state(&self) -> DaftResult> { + let writer = self.writer_factory.create_writer(0, None)?; + Ok(Box::new(WriteState::new(writer)) as Box) + } + + fn make_dispatcher( + &self, + runtime_handle: &crate::ExecutionRuntimeHandle, + ) -> Arc { + if let Some(partition_by) = &self.partition_by { + Arc::new(PartitionedDispatcher::new(partition_by.clone())) + } else { + Arc::new(RoundRobinBufferedDispatcher::new( + runtime_handle.default_morsel_size(), + )) + } + } + + fn max_concurrency(&self) -> usize { + if self.partition_by.is_some() { + *NUM_CPUS + } else { + 1 + } + } +} diff --git a/src/daft-micropartition/src/ops/concat.rs b/src/daft-micropartition/src/ops/concat.rs index 2108cc01e3..2ca6160178 100644 --- a/src/daft-micropartition/src/ops/concat.rs +++ b/src/daft-micropartition/src/ops/concat.rs @@ -1,4 +1,4 @@ -use std::sync::Mutex; +use std::{borrow::Borrow, ops::Deref, sync::Mutex}; use common_error::{DaftError, DaftResult}; use daft_io::IOStatsContext; @@ -7,18 +7,25 @@ use daft_stats::TableMetadata; use crate::micropartition::{MicroPartition, TableState}; impl MicroPartition { - pub fn concat(mps: &[&Self]) -> DaftResult { + pub fn concat(mps: I) -> DaftResult + where + I: IntoIterator, + T: Deref, + T::Target: Borrow, + { + let mps: Vec<_> = mps.into_iter().collect(); if mps.is_empty() { return Err(DaftError::ValueError( "Need at least 1 MicroPartition to perform concat".to_string(), )); } - let first_table = mps.first().unwrap(); + let first_table = mps.first().unwrap().deref().borrow(); - let first_schema = first_table.schema.as_ref(); + let first_schema = &first_table.schema; for tab in mps.iter().skip(1) { - if tab.schema.as_ref() != first_schema { + let tab = tab.deref().borrow(); + if &tab.schema != first_schema { return Err(DaftError::SchemaMismatch(format!( "MicroPartition concat requires all schemas to match, {} vs {}", first_schema, tab.schema @@ -30,13 +37,14 @@ impl MicroPartition { let mut all_tables = vec![]; - for m in mps { + for m in &mps { + let m = m.deref().borrow(); let tables = m.tables_or_read(io_stats.clone())?; all_tables.extend_from_slice(tables.as_slice()); } let mut all_stats = None; - for stats in mps.iter().flat_map(|m| &m.statistics) { + for stats in mps.iter().flat_map(|m| &m.deref().borrow().statistics) { if all_stats.is_none() { all_stats = Some(stats.clone()); } @@ 
-48,7 +56,7 @@ impl MicroPartition { let new_len = all_tables.iter().map(daft_table::Table::len).sum(); Ok(Self { - schema: mps.first().unwrap().schema.clone(), + schema: first_schema.clone(), state: Mutex::new(TableState::Loaded(all_tables.into())), metadata: TableMetadata { length: new_len }, statistics: all_stats, diff --git a/src/daft-micropartition/src/python.rs b/src/daft-micropartition/src/python.rs index 8875c29517..ab9b4a7db1 100644 --- a/src/daft-micropartition/src/python.rs +++ b/src/daft-micropartition/src/python.rs @@ -146,8 +146,8 @@ impl PyMicroPartition { #[staticmethod] pub fn concat(py: Python, to_concat: Vec) -> PyResult { - let mps: Vec<_> = to_concat.iter().map(|t| t.inner.as_ref()).collect(); - py.allow_threads(|| Ok(MicroPartition::concat(mps.as_slice())?.into())) + let mps_iter = to_concat.iter().map(|t| t.inner.as_ref()); + py.allow_threads(|| Ok(MicroPartition::concat(mps_iter)?.into())) } pub fn slice(&self, py: Python, start: i64, end: i64) -> PyResult { diff --git a/src/daft-physical-plan/Cargo.toml b/src/daft-physical-plan/Cargo.toml index 778b8b8560..17d419e63f 100644 --- a/src/daft-physical-plan/Cargo.toml +++ b/src/daft-physical-plan/Cargo.toml @@ -8,6 +8,9 @@ daft-scan = {path = "../daft-scan", default-features = false} log = {workspace = true} strum = {version = "0.26", features = ["derive"]} +[features] +python = ["common-error/python", "common-resource-request/python", "daft-core/python", "daft-dsl/python", "daft-plan/python", "daft-scan/python"] + [lints] workspace = true diff --git a/src/daft-physical-plan/src/local_plan.rs b/src/daft-physical-plan/src/local_plan.rs index ea2144a982..720b53080e 100644 --- a/src/daft-physical-plan/src/local_plan.rs +++ b/src/daft-physical-plan/src/local_plan.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use common_resource_request::ResourceRequest; use daft_core::prelude::*; use daft_dsl::{AggExpr, ExprRef}; -use daft_plan::InMemoryInfo; +use daft_plan::{InMemoryInfo, OutputFileInfo}; use daft_scan::{ScanTask, ScanTaskRef}; pub type LocalPhysicalPlanRef = Arc; @@ -287,7 +287,22 @@ impl LocalPhysicalPlan { .arced() } - #[must_use] + pub(crate) fn physical_write( + input: LocalPhysicalPlanRef, + data_schema: SchemaRef, + file_schema: SchemaRef, + file_info: OutputFileInfo, + ) -> LocalPhysicalPlanRef { + Self::PhysicalWrite(PhysicalWrite { + input, + data_schema, + file_schema, + file_info, + plan_stats: PlanStats {}, + }) + .arced() + } + pub fn schema(&self) -> &SchemaRef { match self { Self::PhysicalScan(PhysicalScan { schema, .. 
}) @@ -447,7 +462,13 @@ pub struct Concat { } #[derive(Debug)] -pub struct PhysicalWrite {} +pub struct PhysicalWrite { + pub input: LocalPhysicalPlanRef, + pub data_schema: SchemaRef, + pub file_schema: SchemaRef, + pub file_info: OutputFileInfo, + pub plan_stats: PlanStats, +} #[derive(Debug)] pub struct PlanStats {} diff --git a/src/daft-physical-plan/src/translate.rs b/src/daft-physical-plan/src/translate.rs index fc9cd0d656..3e5ee5cf13 100644 --- a/src/daft-physical-plan/src/translate.rs +++ b/src/daft-physical-plan/src/translate.rs @@ -176,6 +176,21 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { log::warn!("Repartition Not supported for Local Executor!; This will be a No-Op"); translate(&repartition.input) } + LogicalPlan::Sink(sink) => { + use daft_plan::SinkInfo; + let input = translate(&sink.input)?; + let data_schema = input.schema().clone(); + match sink.sink_info.as_ref() { + SinkInfo::OutputFileInfo(info) => Ok(LocalPhysicalPlan::physical_write( + input, + data_schema, + sink.schema.clone(), + info.clone(), + )), + #[cfg(feature = "python")] + SinkInfo::CatalogInfo(_) => todo!("CatalogInfo not yet implemented"), + } + } LogicalPlan::Explode(explode) => { let input = translate(&explode.input)?; Ok(LocalPhysicalPlan::explode( diff --git a/src/daft-plan/src/lib.rs b/src/daft-plan/src/lib.rs index 2541a143db..25d8fd4b62 100644 --- a/src/daft-plan/src/lib.rs +++ b/src/daft-plan/src/lib.rs @@ -31,7 +31,7 @@ pub use physical_planner::{ #[cfg(feature = "python")] use pyo3::prelude::*; #[cfg(feature = "python")] -pub use sink_info::{DeltaLakeCatalogInfo, IcebergCatalogInfo, LanceCatalogInfo}; +pub use sink_info::{CatalogType, DeltaLakeCatalogInfo, IcebergCatalogInfo, LanceCatalogInfo}; pub use sink_info::{OutputFileInfo, SinkInfo}; pub use source_info::{FileInfo, FileInfos, InMemoryInfo, SourceInfo}; #[cfg(feature = "python")] diff --git a/src/daft-plan/src/sink_info.rs b/src/daft-plan/src/sink_info.rs index 02c8e05273..d083a1f53d 100644 --- a/src/daft-plan/src/sink_info.rs +++ b/src/daft-plan/src/sink_info.rs @@ -63,7 +63,6 @@ pub struct IcebergCatalogInfo { #[derivative(PartialEq = "ignore")] #[derivative(Hash = "ignore")] pub iceberg_schema: PyObject, - #[serde( serialize_with = "serialize_py_object", deserialize_with = "deserialize_py_object" diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs index 701791aaf9..0a450ace70 100644 --- a/src/daft-table/src/lib.rs +++ b/src/daft-table/src/lib.rs @@ -785,6 +785,23 @@ impl Table { } } +impl PartialEq for Table { + fn eq(&self, other: &Self) -> bool { + if self.len() != other.len() { + return false; + } + if self.schema != other.schema { + return false; + } + for (lhs, rhs) in self.columns.iter().zip(other.columns.iter()) { + if lhs != rhs { + return false; + } + } + true + } +} + impl Display for Table { // `f` is a buffer, and this method must write the formatted string into it fn fmt(&self, f: &mut Formatter) -> Result { diff --git a/src/daft-table/src/ops/hash.rs b/src/daft-table/src/ops/hash.rs index c011597c3f..4b4d120da0 100644 --- a/src/daft-table/src/ops/hash.rs +++ b/src/daft-table/src/ops/hash.rs @@ -1,28 +1,14 @@ -use std::{ - collections::{hash_map::RawEntryMut, HashMap}, - hash::{Hash, Hasher}, -}; +use std::collections::{hash_map::RawEntryMut, HashMap}; use common_error::{DaftError, DaftResult}; use daft_core::{ array::ops::{arrow2::comparison::build_multi_array_is_equal, as_arrow::AsArrow}, datatypes::UInt64Array, - utils::identity_hash_set::IdentityBuildHasher, + 
utils::identity_hash_set::{IdentityBuildHasher, IndexHash}, }; use crate::Table; -pub struct IndexHash { - pub idx: u64, - pub hash: u64, -} - -impl Hash for IndexHash { - fn hash(&self, state: &mut H) { - state.write_u64(self.hash); - } -} - impl Table { pub fn hash_rows(&self) -> DaftResult { if self.num_columns() == 0 { diff --git a/src/daft-table/src/probeable/probe_set.rs b/src/daft-table/src/probeable/probe_set.rs index a948ad2a4b..adf9251756 100644 --- a/src/daft-table/src/probeable/probe_set.rs +++ b/src/daft-table/src/probeable/probe_set.rs @@ -9,12 +9,12 @@ use daft_core::{ prelude::SchemaRef, utils::{ dyn_compare::{build_dyn_multi_array_compare, MultiDynArrayComparator}, - identity_hash_set::IdentityBuildHasher, + identity_hash_set::{IdentityBuildHasher, IndexHash}, }, }; use super::{ArrowTableEntry, IndicesMapper, Probeable, ProbeableBuilder}; -use crate::{ops::hash::IndexHash, Table}; +use crate::Table; pub struct ProbeSet { schema: SchemaRef, hash_table: HashMap, diff --git a/src/daft-table/src/probeable/probe_table.rs b/src/daft-table/src/probeable/probe_table.rs index c0a4bde0de..8e8c8c8647 100644 --- a/src/daft-table/src/probeable/probe_table.rs +++ b/src/daft-table/src/probeable/probe_table.rs @@ -9,12 +9,12 @@ use daft_core::{ prelude::SchemaRef, utils::{ dyn_compare::{build_dyn_multi_array_compare, MultiDynArrayComparator}, - identity_hash_set::IdentityBuildHasher, + identity_hash_set::{IdentityBuildHasher, IndexHash}, }, }; use super::{ArrowTableEntry, IndicesMapper, Probeable, ProbeableBuilder}; -use crate::{ops::hash::IndexHash, Table}; +use crate::Table; pub struct ProbeTable { schema: SchemaRef, diff --git a/src/daft-writers/Cargo.toml b/src/daft-writers/Cargo.toml new file mode 100644 index 0000000000..19cff1807d --- /dev/null +++ b/src/daft-writers/Cargo.toml @@ -0,0 +1,22 @@ +[dependencies] +common-daft-config = {path = "../common/daft-config", default-features = false} +common-error = {path = "../common/error", default-features = false} +common-file-formats = {path = "../common/file-formats", default-features = false} +daft-core = {path = "../daft-core", default-features = false} +daft-dsl = {path = "../daft-dsl", default-features = false} +daft-io = {path = "../daft-io", default-features = false} +daft-micropartition = {path = "../daft-micropartition", default-features = false} +daft-plan = {path = "../daft-plan", default-features = false} +daft-table = {path = "../daft-table", default-features = false} +pyo3 = {workspace = true, optional = true} + +[features] +python = ["dep:pyo3", "common-file-formats/python", "common-error/python", "daft-dsl/python", "daft-io/python", "daft-micropartition/python", "daft-plan/python"] + +[lints] +workspace = true + +[package] +edition = {workspace = true} +name = "daft-writers" +version = {workspace = true} diff --git a/src/daft-writers/src/batch.rs b/src/daft-writers/src/batch.rs new file mode 100644 index 0000000000..d8af2f94fb --- /dev/null +++ b/src/daft-writers/src/batch.rs @@ -0,0 +1,195 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_micropartition::MicroPartition; +use daft_table::Table; + +use crate::{FileWriter, WriterFactory}; + +// TargetBatchWriter is a writer that writes in batches of rows, i.e. for Parquet where we want to write +// a row group at a time. +pub struct TargetBatchWriter { + target_in_memory_chunk_rows: usize, + writer: Box, Result = Option
<Table>>>,
+    leftovers: Option<Arc<MicroPartition>>,
+    is_closed: bool,
+}
+
+impl TargetBatchWriter {
+    pub fn new(
+        target_in_memory_chunk_rows: usize,
+        writer: Box<dyn FileWriter<Input = Arc<MicroPartition>, Result = Option<Table>>>,
+    ) -> Self {
+        Self {
+            target_in_memory_chunk_rows,
+            writer,
+            leftovers: None,
+            is_closed: false,
+        }
+    }
+}
+
+impl FileWriter for TargetBatchWriter {
+    type Input = Arc<MicroPartition>;
+    type Result = Option<Table>
; + + fn write(&mut self, input: &Arc) -> DaftResult<()> { + assert!( + !self.is_closed, + "Cannot write to a closed TargetBatchWriter" + ); + let input = if let Some(leftovers) = self.leftovers.take() { + MicroPartition::concat([&leftovers, input])?.into() + } else { + input.clone() + }; + + let mut local_offset = 0; + loop { + let remaining_rows = input.len() - local_offset; + + use std::cmp::Ordering; + match remaining_rows.cmp(&self.target_in_memory_chunk_rows) { + Ordering::Equal => { + // Write exactly one chunk + let chunk = input.slice(local_offset, local_offset + remaining_rows)?; + return self.writer.write(&chunk.into()); + } + Ordering::Less => { + // Store remaining rows as leftovers + let remainder = input.slice(local_offset, local_offset + remaining_rows)?; + self.leftovers = Some(remainder.into()); + return Ok(()); + } + Ordering::Greater => { + // Write a complete chunk and continue + let chunk = input.slice( + local_offset, + local_offset + self.target_in_memory_chunk_rows, + )?; + self.writer.write(&chunk.into())?; + local_offset += self.target_in_memory_chunk_rows; + } + } + } + } + + fn close(&mut self) -> DaftResult { + if let Some(leftovers) = self.leftovers.take() { + self.writer.write(&leftovers)?; + } + self.is_closed = true; + self.writer.close() + } +} + +pub struct TargetBatchWriterFactory { + writer_factory: Arc, Result = Option
<Table>>>,
+    target_in_memory_chunk_rows: usize,
+}
+
+impl TargetBatchWriterFactory {
+    pub fn new(
+        writer_factory: Arc<dyn WriterFactory<Input = Arc<MicroPartition>, Result = Option<Table>>>,
+        target_in_memory_chunk_rows: usize,
+    ) -> Self {
+        Self {
+            writer_factory,
+            target_in_memory_chunk_rows,
+        }
+    }
+}
+
+impl WriterFactory for TargetBatchWriterFactory {
+    type Input = Arc<MicroPartition>;
+    type Result = Option<Table>
; + + fn create_writer( + &self, + file_idx: usize, + partition_values: Option<&Table>, + ) -> DaftResult>> { + let writer = self + .writer_factory + .create_writer(file_idx, partition_values)?; + Ok(Box::new(TargetBatchWriter::new( + self.target_in_memory_chunk_rows, + writer, + ))) + } +} + +#[cfg(test)] +mod tests { + + use super::*; + use crate::test::{make_dummy_mp, DummyWriterFactory}; + + #[test] + fn test_target_batch_writer_exact_batch() { + let dummy_writer_factory = DummyWriterFactory; + let mut writer = + TargetBatchWriter::new(1, dummy_writer_factory.create_writer(0, None).unwrap()); + + let mp = make_dummy_mp(1); + writer.write(&mp).unwrap(); + let res = writer.close().unwrap(); + + assert!(res.is_some()); + let write_count = res + .unwrap() + .get_column("write_count") + .unwrap() + .u64() + .unwrap() + .get(0) + .unwrap(); + assert_eq!(write_count, 1); + } + + #[test] + fn test_target_batch_writer_small_batches() { + let dummy_writer_factory = DummyWriterFactory; + let mut writer = + TargetBatchWriter::new(3, dummy_writer_factory.create_writer(0, None).unwrap()); + + for _ in 0..8 { + let mp = make_dummy_mp(1); + writer.write(&mp).unwrap(); + } + let res = writer.close().unwrap(); + + assert!(res.is_some()); + let write_count = res + .unwrap() + .get_column("write_count") + .unwrap() + .u64() + .unwrap() + .get(0) + .unwrap(); + assert_eq!(write_count, 3); + } + + #[test] + fn test_target_batch_writer_big_batch() { + let dummy_writer_factory = DummyWriterFactory; + let mut writer = + TargetBatchWriter::new(3, dummy_writer_factory.create_writer(0, None).unwrap()); + + let mp = make_dummy_mp(10); + writer.write(&mp).unwrap(); + let res = writer.close().unwrap(); + + assert!(res.is_some()); + let write_count = res + .unwrap() + .get_column("write_count") + .unwrap() + .u64() + .unwrap() + .get(0) + .unwrap(); + assert_eq!(write_count, 4); + } +} diff --git a/src/daft-writers/src/file.rs b/src/daft-writers/src/file.rs new file mode 100644 index 0000000000..639b720d15 --- /dev/null +++ b/src/daft-writers/src/file.rs @@ -0,0 +1,215 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_micropartition::MicroPartition; +use daft_table::Table; + +use crate::{FileWriter, WriterFactory}; + +// TargetFileSizeWriter is a writer that writes in files of a target size. +// It rotates the writer when the current file reaches the target size. +struct TargetFileSizeWriter { + current_file_rows: usize, + current_writer: Box, Result = Option
<Table>>>,
+    writer_factory: Arc<dyn WriterFactory<Input = Arc<MicroPartition>, Result = Option<Table>>>,
+    target_in_memory_file_rows: usize,
+    results: Vec<Table>,
+    partition_values: Option<Table>,
+    is_closed: bool,
+}
+
+impl TargetFileSizeWriter {
+    fn new(
+        target_in_memory_file_rows: usize,
+        writer_factory: Arc<dyn WriterFactory<Input = Arc<MicroPartition>, Result = Option<Table>>>,
+        partition_values: Option<Table>,
+    ) -> DaftResult<Self> {
+        let writer: Box<dyn FileWriter<Input = Arc<MicroPartition>, Result = Option<Table>>> =
+            writer_factory.create_writer(0, partition_values.as_ref())?;
+        Ok(Self {
+            current_file_rows: 0,
+            current_writer: writer,
+            writer_factory,
+            target_in_memory_file_rows,
+            results: vec![],
+            partition_values,
+            is_closed: false,
+        })
+    }
+
+    fn rotate_writer(&mut self) -> DaftResult<()> {
+        if let Some(result) = self.current_writer.close()? {
+            self.results.push(result);
+        }
+        self.current_writer = self
+            .writer_factory
+            .create_writer(self.results.len(), self.partition_values.as_ref())?;
+        Ok(())
+    }
+}
+
+impl FileWriter for TargetFileSizeWriter {
+    type Input = Arc<MicroPartition>;
+    type Result = Vec<Table>
; + + fn write(&mut self, input: &Arc) -> DaftResult<()> { + assert!( + !self.is_closed, + "Cannot write to a closed TargetFileSizeWriter" + ); + use std::cmp::Ordering; + + let mut local_offset = 0; + + loop { + let remaining_input_rows = input.len() - local_offset; + let rows_until_target = self.target_in_memory_file_rows - self.current_file_rows; + + match remaining_input_rows.cmp(&rows_until_target) { + Ordering::Equal => { + // Write exactly what's needed to fill the current file + let to_write = + input.slice(local_offset, local_offset + remaining_input_rows)?; + self.current_writer.write(&to_write.into())?; + self.rotate_writer()?; + self.current_file_rows = 0; + return Ok(()); + } + Ordering::Less => { + // Write remaining input and update counter + let to_write = + input.slice(local_offset, local_offset + remaining_input_rows)?; + self.current_writer.write(&to_write.into())?; + self.current_file_rows += remaining_input_rows; + return Ok(()); + } + Ordering::Greater => { + // Write what fits in current file + let to_write = input.slice(local_offset, local_offset + rows_until_target)?; + self.current_writer.write(&to_write.into())?; + self.rotate_writer()?; + self.current_file_rows = 0; + + // Update offset and continue loop + local_offset += rows_until_target; + } + } + } + } + + fn close(&mut self) -> DaftResult { + if self.current_file_rows > 0 { + if let Some(result) = self.current_writer.close()? { + self.results.push(result); + } + } + self.is_closed = true; + Ok(std::mem::take(&mut self.results)) + } +} + +pub(crate) struct TargetFileSizeWriterFactory { + writer_factory: Arc, Result = Option
<Table>>>,
+    target_in_memory_file_rows: usize,
+}
+
+impl TargetFileSizeWriterFactory {
+    pub(crate) fn new(
+        writer_factory: Arc<dyn WriterFactory<Input = Arc<MicroPartition>, Result = Option<Table>>>,
+        target_in_memory_file_rows: usize,
+    ) -> Self {
+        Self {
+            writer_factory,
+            target_in_memory_file_rows,
+        }
+    }
+}
+
+impl WriterFactory for TargetFileSizeWriterFactory {
+    type Input = Arc<MicroPartition>;
+    type Result = Vec<Table>
; + + fn create_writer( + &self, + _file_idx: usize, + partition_values: Option<&Table>, + ) -> DaftResult>> { + Ok(Box::new(TargetFileSizeWriter::new( + self.target_in_memory_file_rows, + self.writer_factory.clone(), + partition_values.cloned(), + )?) + as Box< + dyn FileWriter, + >) + } +} + +#[cfg(test)] +mod tests { + + use super::*; + use crate::test::{make_dummy_mp, DummyWriterFactory}; + + #[test] + fn test_target_file_writer_exact_file() { + let dummy_writer_factory = DummyWriterFactory; + let mut writer = + TargetFileSizeWriter::new(1, Arc::new(dummy_writer_factory), None).unwrap(); + + let mp = make_dummy_mp(1); + writer.write(&mp).unwrap(); + let res = writer.close().unwrap(); + assert_eq!(res.len(), 1); + } + + #[test] + fn test_target_file_writer_less_rows_for_one_file() { + let dummy_writer_factory = DummyWriterFactory; + let mut writer = + TargetFileSizeWriter::new(3, Arc::new(dummy_writer_factory), None).unwrap(); + + let mp = make_dummy_mp(2); + writer.write(&mp).unwrap(); + let res = writer.close().unwrap(); + assert_eq!(res.len(), 1); + } + + #[test] + fn test_target_file_writer_more_rows_for_one_file() { + let dummy_writer_factory = DummyWriterFactory; + let mut writer = + TargetFileSizeWriter::new(3, Arc::new(dummy_writer_factory), None).unwrap(); + + let mp = make_dummy_mp(4); + writer.write(&mp).unwrap(); + let res = writer.close().unwrap(); + assert_eq!(res.len(), 2); + } + + #[test] + fn test_target_file_writer_multiple_files() { + let dummy_writer_factory = DummyWriterFactory; + let mut writer = + TargetFileSizeWriter::new(3, Arc::new(dummy_writer_factory), None).unwrap(); + + let mp = make_dummy_mp(10); + writer.write(&mp).unwrap(); + let res = writer.close().unwrap(); + assert_eq!(res.len(), 4); + } + + #[test] + fn test_target_file_writer_many_writes_many_files() { + let dummy_writer_factory = DummyWriterFactory; + let mut writer = + TargetFileSizeWriter::new(3, Arc::new(dummy_writer_factory), None).unwrap(); + + for _ in 0..10 { + let mp = make_dummy_mp(1); + writer.write(&mp).unwrap(); + } + let res = writer.close().unwrap(); + assert_eq!(res.len(), 4); + } +} diff --git a/src/daft-writers/src/lib.rs b/src/daft-writers/src/lib.rs new file mode 100644 index 0000000000..405596485b --- /dev/null +++ b/src/daft-writers/src/lib.rs @@ -0,0 +1,127 @@ +#![feature(hash_raw_entry)] +#![feature(let_chains)] +mod batch; +mod file; +mod partition; +mod physical; + +#[cfg(test)] +mod test; + +#[cfg(feature = "python")] +mod python; + +use std::{cmp::min, sync::Arc}; + +use batch::TargetBatchWriterFactory; +use common_daft_config::DaftExecutionConfig; +use common_error::DaftResult; +use common_file_formats::FileFormat; +use daft_core::prelude::SchemaRef; +use daft_micropartition::MicroPartition; +use daft_plan::OutputFileInfo; +use daft_table::Table; +use file::TargetFileSizeWriterFactory; +use partition::PartitionedWriterFactory; +use physical::PhysicalWriterFactory; + +/// This trait is used to abstract the writing of data to a file. +/// The `Input` type is the type of data that will be written to the file. +/// The `Result` type is the type of the result that will be returned when the file is closed. +pub trait FileWriter: Send + Sync { + type Input; + type Result; + + /// Write data to the file. + fn write(&mut self, data: &Self::Input) -> DaftResult<()>; + + /// Close the file and return the result. The caller should NOT write to the file after calling this method. 
+    fn close(&mut self) -> DaftResult<Self::Result>;
+}
+
+/// This trait is used to abstract the creation of a `FileWriter`.
+/// The `create_writer` method is used to create a new `FileWriter`.
+/// `file_idx` is the index of the file that will be written to.
+/// `partition_values` is the partition values of the data that will be written to the file.
+pub trait WriterFactory: Send + Sync {
+    type Input;
+    type Result;
+    fn create_writer(
+        &self,
+        file_idx: usize,
+        partition_values: Option<&Table>,
+    ) -> DaftResult<Box<dyn FileWriter<Input = Self::Input, Result = Self::Result>>>;
+}
+
+pub fn make_writer_factory(
+    file_info: &OutputFileInfo,
+    schema: &SchemaRef,
+    cfg: &DaftExecutionConfig,
+) -> Arc<dyn WriterFactory<Input = Arc<MicroPartition>, Result = Vec<Table>
>> { + let estimated_row_size_bytes = schema.estimate_row_size_bytes(); + let base_writer_factory = PhysicalWriterFactory::new(file_info.clone()); + match file_info.file_format { + FileFormat::Parquet => { + let target_in_memory_file_size = + cfg.parquet_target_filesize as f64 * cfg.parquet_inflation_factor; + let target_in_memory_row_group_size = + cfg.parquet_target_row_group_size as f64 * cfg.parquet_inflation_factor; + + let target_file_rows = if estimated_row_size_bytes > 0.0 { + target_in_memory_file_size / estimated_row_size_bytes + } else { + target_in_memory_file_size + } as usize; + + let target_row_group_rows = min( + target_file_rows, + if estimated_row_size_bytes > 0.0 { + target_in_memory_row_group_size / estimated_row_size_bytes + } else { + target_in_memory_row_group_size + } as usize, + ); + + let row_group_writer_factory = + TargetBatchWriterFactory::new(Arc::new(base_writer_factory), target_row_group_rows); + + let file_writer_factory = TargetFileSizeWriterFactory::new( + Arc::new(row_group_writer_factory), + target_file_rows, + ); + + if let Some(partition_cols) = &file_info.partition_cols { + let partitioned_writer_factory = PartitionedWriterFactory::new( + Arc::new(file_writer_factory), + partition_cols.clone(), + ); + Arc::new(partitioned_writer_factory) + } else { + Arc::new(file_writer_factory) + } + } + FileFormat::Csv => { + let target_in_memory_file_size = + cfg.csv_target_filesize as f64 * cfg.csv_inflation_factor; + let target_file_rows = if estimated_row_size_bytes > 0.0 { + target_in_memory_file_size / estimated_row_size_bytes + } else { + target_in_memory_file_size + } as usize; + + let file_writer_factory = + TargetFileSizeWriterFactory::new(Arc::new(base_writer_factory), target_file_rows); + + if let Some(partition_cols) = &file_info.partition_cols { + let partitioned_writer_factory = PartitionedWriterFactory::new( + Arc::new(file_writer_factory), + partition_cols.clone(), + ); + Arc::new(partitioned_writer_factory) + } else { + Arc::new(file_writer_factory) + } + } + _ => unreachable!("Physical write should only support Parquet and CSV"), + } +} diff --git a/src/daft-writers/src/partition.rs b/src/daft-writers/src/partition.rs new file mode 100644 index 0000000000..73b0ab0216 --- /dev/null +++ b/src/daft-writers/src/partition.rs @@ -0,0 +1,157 @@ +use std::{ + collections::{hash_map::RawEntryMut, HashMap}, + sync::Arc, +}; + +use common_error::DaftResult; +use daft_core::{array::ops::as_arrow::AsArrow, utils::identity_hash_set::IndexHash}; +use daft_dsl::ExprRef; +use daft_io::IOStatsContext; +use daft_micropartition::MicroPartition; +use daft_table::Table; + +use crate::{FileWriter, WriterFactory}; + +/// PartitionedWriter is a writer that partitions the input data by a set of columns, and writes each partition +/// to a separate file. It uses a map to keep track of the writers for each partition. +struct PartitionedWriter { + // TODO: Figure out a way to NOT use the IndexHash + RawEntryMut pattern here. Ideally we want to store ScalarValues, aka. single Rows of the partition values as keys for the hashmap. + per_partition_writers: + HashMap, Result = Vec
<Table>>>>,
+    saved_partition_values: Vec<Table>,
+    writer_factory: Arc<dyn WriterFactory<Input = Arc<MicroPartition>, Result = Vec<Table>>>,
+    partition_by: Vec<ExprRef>,
+    is_closed: bool,
+}
+
+impl PartitionedWriter {
+    pub fn new(
+        writer_factory: Arc<dyn WriterFactory<Input = Arc<MicroPartition>, Result = Vec<Table>>>,
+        partition_by: Vec<ExprRef>,
+    ) -> Self {
+        Self {
+            per_partition_writers: HashMap::new(),
+            saved_partition_values: vec![],
+            writer_factory,
+            partition_by,
+            is_closed: false,
+        }
+    }
+
+    fn partition(
+        partition_cols: &[ExprRef],
+        data: &Arc<MicroPartition>,
+    ) -> DaftResult<(Vec<Table>, Table)> {
+        let data = data.concat_or_get(IOStatsContext::new("MicroPartition::partition_by_value"))?;
+        let table = data.first().unwrap();
+        let (split_tables, partition_values) = table.partition_by_value(partition_cols)?;
+        Ok((split_tables, partition_values))
+    }
+}
+
+impl FileWriter for PartitionedWriter {
+    type Input = Arc<MicroPartition>;
+    type Result = Vec<Table>
; + + fn write(&mut self, input: &Arc) -> DaftResult<()> { + assert!( + !self.is_closed, + "Cannot write to a closed PartitionedWriter" + ); + + let (split_tables, partition_values) = + Self::partition(self.partition_by.as_slice(), input)?; + let partition_values_hash = partition_values.hash_rows()?; + for (idx, (table, partition_value_hash)) in split_tables + .into_iter() + .zip(partition_values_hash.as_arrow().values_iter()) + .enumerate() + { + let partition_value_row = partition_values.slice(idx, idx + 1)?; + let entry = self.per_partition_writers.raw_entry_mut().from_hash( + *partition_value_hash, + |other| { + (*partition_value_hash == other.hash) && { + let other_table = + self.saved_partition_values.get(other.idx as usize).unwrap(); + other_table == &partition_value_row + } + }, + ); + match entry { + RawEntryMut::Vacant(entry) => { + let mut writer = self + .writer_factory + .create_writer(0, Some(partition_value_row.as_ref()))?; + writer.write(&Arc::new(MicroPartition::new_loaded( + table.schema.clone(), + vec![table].into(), + None, + )))?; + entry.insert_hashed_nocheck( + *partition_value_hash, + IndexHash { + idx: self.saved_partition_values.len() as u64, + hash: *partition_value_hash, + }, + writer, + ); + self.saved_partition_values.push(partition_value_row); + } + RawEntryMut::Occupied(mut entry) => { + let writer = entry.get_mut(); + writer.write(&Arc::new(MicroPartition::new_loaded( + table.schema.clone(), + vec![table].into(), + None, + )))?; + } + } + } + Ok(()) + } + + fn close(&mut self) -> DaftResult { + let mut results = vec![]; + for (_, mut writer) in self.per_partition_writers.drain() { + results.extend(writer.close()?); + } + self.is_closed = true; + Ok(results) + } +} + +pub(crate) struct PartitionedWriterFactory { + writer_factory: Arc, Result = Vec
<Table>>>,
+    partition_cols: Vec<ExprRef>,
+}
+
+impl PartitionedWriterFactory {
+    pub(crate) fn new(
+        writer_factory: Arc<dyn WriterFactory<Input = Arc<MicroPartition>, Result = Vec<Table>>>,
+        partition_cols: Vec<ExprRef>,
+    ) -> Self {
+        Self {
+            writer_factory,
+            partition_cols,
+        }
+    }
+}
+impl WriterFactory for PartitionedWriterFactory {
+    type Input = Arc<MicroPartition>;
+    type Result = Vec<Table>
; + + fn create_writer( + &self, + _file_idx: usize, + _partition_values: Option<&Table>, + ) -> DaftResult>> { + Ok(Box::new(PartitionedWriter::new( + self.writer_factory.clone(), + self.partition_cols.clone(), + )) + as Box< + dyn FileWriter, + >) + } +} diff --git a/src/daft-writers/src/physical.rs b/src/daft-writers/src/physical.rs new file mode 100644 index 0000000000..81cb3dfbfe --- /dev/null +++ b/src/daft-writers/src/physical.rs @@ -0,0 +1,77 @@ +use std::sync::Arc; + +use common_error::{DaftError, DaftResult}; +use common_file_formats::FileFormat; +use daft_micropartition::MicroPartition; +use daft_plan::OutputFileInfo; +use daft_table::Table; + +use crate::{FileWriter, WriterFactory}; + +/// PhysicalWriterFactory is a factory for creating physical writers, i.e. parquet, csv writers. +pub struct PhysicalWriterFactory { + output_file_info: OutputFileInfo, + native: bool, // TODO: Implement native writer +} + +impl PhysicalWriterFactory { + pub fn new(output_file_info: OutputFileInfo) -> Self { + Self { + output_file_info, + native: false, + } + } +} + +impl WriterFactory for PhysicalWriterFactory { + type Input = Arc; + type Result = Option
<Table>;
+    fn create_writer(
+        &self,
+        file_idx: usize,
+        partition_values: Option<&Table>,
+    ) -> DaftResult<Box<dyn FileWriter<Input = Self::Input, Result = Self::Result>>> {
+        match self.native {
+            true => unimplemented!(),
+            false => {
+                let writer = create_pyarrow_file_writer(
+                    &self.output_file_info.root_dir,
+                    file_idx,
+                    &self.output_file_info.compression,
+                    &self.output_file_info.io_config,
+                    self.output_file_info.file_format,
+                    partition_values,
+                )?;
+                Ok(writer)
+            }
+        }
+    }
+}
+
+pub fn create_pyarrow_file_writer(
+    root_dir: &str,
+    file_idx: usize,
+    compression: &Option<String>,
+    io_config: &Option<daft_io::IOConfig>,
+    format: FileFormat,
+    partition: Option<&Table>,
+) -> DaftResult<Box<dyn FileWriter<Input = Arc<MicroPartition>, Result = Option<Table>
>>> { + match format { + #[cfg(feature = "python")] + FileFormat::Parquet => Ok(Box::new(crate::python::PyArrowWriter::new_parquet_writer( + root_dir, + file_idx, + compression, + io_config, + partition, + )?)), + #[cfg(feature = "python")] + FileFormat::Csv => Ok(Box::new(crate::python::PyArrowWriter::new_csv_writer( + root_dir, file_idx, io_config, partition, + )?)), + _ => Err(DaftError::ComputeError( + "Unsupported file format for physical write".to_string(), + )), + } +} diff --git a/src/daft-writers/src/python.rs b/src/daft-writers/src/python.rs new file mode 100644 index 0000000000..7bcecb2b03 --- /dev/null +++ b/src/daft-writers/src/python.rs @@ -0,0 +1,118 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_micropartition::{python::PyMicroPartition, MicroPartition}; +use daft_table::{python::PyTable, Table}; +use pyo3::{types::PyAnyMethods, PyObject, Python}; + +use crate::FileWriter; + +pub struct PyArrowWriter { + py_writer: PyObject, + is_closed: bool, +} + +impl PyArrowWriter { + pub fn new_parquet_writer( + root_dir: &str, + file_idx: usize, + compression: &Option, + io_config: &Option, + partition_values: Option<&Table>, + ) -> DaftResult { + Python::with_gil(|py| { + let file_writer_module = py.import_bound(pyo3::intern!(py, "daft.io.writer"))?; + let file_writer_class = file_writer_module.getattr("ParquetFileWriter")?; + let _from_pytable = py + .import_bound(pyo3::intern!(py, "daft.table"))? + .getattr(pyo3::intern!(py, "Table"))? + .getattr(pyo3::intern!(py, "_from_pytable"))?; + let partition_values = match partition_values { + Some(pv) => { + let py_table = _from_pytable.call1((PyTable::from(pv.clone()),))?; + Some(py_table) + } + None => None, + }; + + let py_writer = file_writer_class.call1(( + root_dir, + file_idx, + partition_values, + compression.as_ref().map(|c| c.as_str()), + io_config.as_ref().map(|cfg| daft_io::python::IOConfig { + config: cfg.clone(), + }), + ))?; + Ok(Self { + py_writer: py_writer.into(), + is_closed: false, + }) + }) + } + + pub fn new_csv_writer( + root_dir: &str, + file_idx: usize, + io_config: &Option, + partition_values: Option<&Table>, + ) -> DaftResult { + Python::with_gil(|py| { + let file_writer_module = py.import_bound(pyo3::intern!(py, "daft.io.writer"))?; + let file_writer_class = file_writer_module.getattr("CSVFileWriter")?; + let _from_pytable = py + .import_bound(pyo3::intern!(py, "daft.table"))? + .getattr(pyo3::intern!(py, "Table"))? + .getattr(pyo3::intern!(py, "_from_pytable"))?; + let partition_values = match partition_values { + Some(pv) => { + let py_table = _from_pytable.call1((PyTable::from(pv.clone()),))?; + Some(py_table) + } + None => None, + }; + let py_writer = file_writer_class.call1(( + root_dir, + file_idx, + partition_values, + io_config.as_ref().map(|cfg| daft_io::python::IOConfig { + config: cfg.clone(), + }), + ))?; + Ok(Self { + py_writer: py_writer.into(), + is_closed: false, + }) + }) + } +} + +impl FileWriter for PyArrowWriter { + type Input = Arc; + type Result = Option
; + + fn write(&mut self, data: &Self::Input) -> DaftResult<()> { + assert!(!self.is_closed, "Cannot write to a closed PyArrowWriter"); + Python::with_gil(|py| { + let py_micropartition = py + .import_bound(pyo3::intern!(py, "daft.table"))? + .getattr(pyo3::intern!(py, "MicroPartition"))? + .getattr(pyo3::intern!(py, "_from_pymicropartition"))? + .call1((PyMicroPartition::from(data.clone()),))?; + self.py_writer + .call_method1(py, pyo3::intern!(py, "write"), (py_micropartition,))?; + Ok(()) + }) + } + + fn close(&mut self) -> DaftResult { + self.is_closed = true; + Python::with_gil(|py| { + let result = self + .py_writer + .call_method0(py, pyo3::intern!(py, "close"))? + .getattr(py, pyo3::intern!(py, "_table"))?; + Ok(Some(result.extract::(py)?.into())) + }) + } +} diff --git a/src/daft-writers/src/test.rs b/src/daft-writers/src/test.rs new file mode 100644 index 0000000000..f862930d3f --- /dev/null +++ b/src/daft-writers/src/test.rs @@ -0,0 +1,84 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use daft_core::{ + prelude::{Int64Array, Schema, UInt64Array, Utf8Array}, + series::IntoSeries, +}; +use daft_micropartition::MicroPartition; +use daft_table::Table; + +use crate::{FileWriter, WriterFactory}; + +pub(crate) struct DummyWriterFactory; + +impl WriterFactory for DummyWriterFactory { + type Input = Arc; + type Result = Option
<Table>;
+    fn create_writer(
+        &self,
+        file_idx: usize,
+        partition_values: Option<&Table>,
+    ) -> DaftResult<Box<dyn FileWriter<Input = Self::Input, Result = Self::Result>>> {
+        Ok(Box::new(DummyWriter {
+            file_idx: file_idx.to_string(),
+            partition_values: partition_values.cloned(),
+            write_count: 0,
+        })
+            as Box<
+                dyn FileWriter<Input = Self::Input, Result = Self::Result>,
+            >)
+    }
+}
+
+pub(crate) struct DummyWriter {
+    file_idx: String,
+    partition_values: Option<Table>,
+    write_count: usize,
+}
+
+impl FileWriter for DummyWriter {
+    type Input = Arc<MicroPartition>;
+    type Result = Option<Table>
; + + fn write(&mut self, _input: &Self::Input) -> DaftResult<()> { + self.write_count += 1; + Ok(()) + } + + fn close(&mut self) -> DaftResult { + let path_series = + Utf8Array::from_values("path", std::iter::once(self.file_idx.clone())).into_series(); + let write_count_series = + UInt64Array::from_values("write_count", std::iter::once(self.write_count as u64)) + .into_series(); + let path_table = Table::new_unchecked( + Schema::new(vec![ + path_series.field().clone(), + write_count_series.field().clone(), + ]) + .unwrap(), + vec![path_series.into(), write_count_series.into()], + 1, + ); + if let Some(partition_values) = self.partition_values.take() { + let unioned = path_table.union(&partition_values)?; + Ok(Some(unioned)) + } else { + Ok(Some(path_table)) + } + } +} + +pub(crate) fn make_dummy_mp(num_rows: usize) -> Arc { + let series = + Int64Array::from_values("ints", std::iter::repeat(42).take(num_rows)).into_series(); + let schema = Arc::new(Schema::new(vec![series.field().clone()]).unwrap()); + let table = Table::new_unchecked(schema.clone(), vec![series.into()], num_rows); + Arc::new(MicroPartition::new_loaded( + schema.into(), + vec![table].into(), + None, + )) +} diff --git a/tests/benchmarks/conftest.py b/tests/benchmarks/conftest.py index 48e4cd43d9..3c49a2733f 100644 --- a/tests/benchmarks/conftest.py +++ b/tests/benchmarks/conftest.py @@ -6,6 +6,17 @@ import memray import pytest +from fsspec.implementations.local import LocalFileSystem + +import daft +from benchmarking.tpch import data_generation +from tests.assets import TPCH_DBGEN_DIR + +IS_CI = True if os.getenv("CI") else False + +SCALE_FACTOR = 0.2 +NUM_PARTS = [1] if IS_CI else [1, 2] +SOURCE_TYPES = ["in-memory"] if IS_CI else ["parquet", "in-memory"] memray_stats = defaultdict(dict) @@ -48,3 +59,69 @@ def benchmark_wrapper(func, group): return track_mem(func, group) return benchmark_wrapper + + +@pytest.fixture(scope="session", params=NUM_PARTS) +def gen_tpch(request): + # Parametrize the number of parts for each file so that we run tests on single-partition files and multi-partition files + num_parts = request.param + + csv_files_location = data_generation.gen_csv_files(TPCH_DBGEN_DIR, num_parts, SCALE_FACTOR) + parquet_files_location = data_generation.gen_parquet(csv_files_location) + + in_memory_tables = {} + for tbl_name in data_generation.SCHEMA.keys(): + arrow_table = daft.read_parquet(f"{parquet_files_location}/{tbl_name}/*").to_arrow() + in_memory_tables[tbl_name] = daft.from_arrow(arrow_table) + + sqlite_path = data_generation.gen_sqlite_db( + csv_filepath=csv_files_location, + num_parts=num_parts, + ) + + return ( + csv_files_location, + parquet_files_location, + in_memory_tables, + num_parts, + ), sqlite_path + + +@pytest.fixture(scope="module", params=SOURCE_TYPES) +def get_df(gen_tpch, request): + (csv_files_location, parquet_files_location, in_memory_tables, num_parts), _ = gen_tpch + source_type = request.param + print(f"Source Type: {source_type}") + + def _get_df(tbl_name: str): + print(f"Table Name: {tbl_name}, Source Type: {source_type}") + if source_type == "csv": + local_fs = LocalFileSystem() + nonchunked_filepath = f"{csv_files_location}/{tbl_name}.tbl" + chunked_filepath = nonchunked_filepath + ".*" + try: + local_fs.expand_path(chunked_filepath) + fp = chunked_filepath + except FileNotFoundError: + fp = nonchunked_filepath + + df = daft.read_csv( + fp, + has_headers=False, + delimiter="|", + ) + df = df.select( + *[ + daft.col(autoname).alias(colname) + for autoname, colname in 
zip(df.column_names, data_generation.SCHEMA[tbl_name]) + ] + ) + elif source_type == "parquet": + fp = f"{parquet_files_location}/{tbl_name}/*" + df = daft.read_parquet(fp) + elif source_type == "in-memory": + df = in_memory_tables[tbl_name] + + return df + + return _get_df, num_parts diff --git a/tests/benchmarks/test_local_tpch.py b/tests/benchmarks/test_local_tpch.py index 07165d9ebc..eed56f4b0e 100644 --- a/tests/benchmarks/test_local_tpch.py +++ b/tests/benchmarks/test_local_tpch.py @@ -1,13 +1,12 @@ from __future__ import annotations -import os import sys import pytest -from fsspec.implementations.local import LocalFileSystem import daft -from benchmarking.tpch import answers, data_generation +from benchmarking.tpch import answers +from tests.benchmarks.conftest import IS_CI if sys.platform == "win32": pytest.skip(allow_module_level=True) @@ -15,83 +14,9 @@ import itertools import daft.context -from tests.assets import TPCH_DBGEN_DIR -from tests.integration.conftest import * # noqa: F403 +from tests.integration.conftest import check_answer # noqa F401 -IS_CI = True if os.getenv("CI") else False - -SCALE_FACTOR = 0.2 ENGINES = ["native"] if IS_CI else ["native", "python"] -NUM_PARTS = [1] if IS_CI else [1, 2] -SOURCE_TYPES = ["in-memory"] if IS_CI else ["parquet", "in-memory"] - - -@pytest.fixture(scope="session", params=NUM_PARTS) -def gen_tpch(request): - # Parametrize the number of parts for each file so that we run tests on single-partition files and multi-partition files - num_parts = request.param - - csv_files_location = data_generation.gen_csv_files(TPCH_DBGEN_DIR, num_parts, SCALE_FACTOR) - - # Disable native executor to generate parquet files, remove once native executor supports writing parquet files - with daft.context.execution_config_ctx(enable_native_executor=False): - parquet_files_location = data_generation.gen_parquet(csv_files_location) - - in_memory_tables = {} - for tbl_name in data_generation.SCHEMA.keys(): - arrow_table = daft.read_parquet(f"{parquet_files_location}/{tbl_name}/*").to_arrow() - in_memory_tables[tbl_name] = daft.from_arrow(arrow_table) - - sqlite_path = data_generation.gen_sqlite_db( - csv_filepath=csv_files_location, - num_parts=num_parts, - ) - - return ( - csv_files_location, - parquet_files_location, - in_memory_tables, - num_parts, - ), sqlite_path - - -@pytest.fixture(scope="module", params=SOURCE_TYPES) # TODO: Enable CSV after improving the CSV reader -def get_df(gen_tpch, request): - (csv_files_location, parquet_files_location, in_memory_tables, num_parts), _ = gen_tpch - source_type = request.param - - def _get_df(tbl_name: str): - if source_type == "csv": - local_fs = LocalFileSystem() - nonchunked_filepath = f"{csv_files_location}/{tbl_name}.tbl" - chunked_filepath = nonchunked_filepath + ".*" - try: - local_fs.expand_path(chunked_filepath) - fp = chunked_filepath - except FileNotFoundError: - fp = nonchunked_filepath - - df = daft.read_csv( - fp, - has_headers=False, - delimiter="|", - ) - df = df.select( - *[ - daft.col(autoname).alias(colname) - for autoname, colname in zip(df.column_names, data_generation.SCHEMA[tbl_name]) - ] - ) - elif source_type == "parquet": - fp = f"{parquet_files_location}/{tbl_name}/*" - df = daft.read_parquet(fp) - elif source_type == "in-memory": - df = in_memory_tables[tbl_name] - - return df - - return _get_df, num_parts - TPCH_QUESTIONS = list(range(1, 11)) @@ -102,7 +27,7 @@ def _get_df(tbl_name: str): ) @pytest.mark.benchmark(group="tpch") @pytest.mark.parametrize("engine, q", 
itertools.product(ENGINES, TPCH_QUESTIONS)) -def test_tpch(tmp_path, check_answer, get_df, benchmark_with_memray, engine, q): +def test_tpch(tmp_path, check_answer, get_df, benchmark_with_memray, engine, q): # noqa F811 get_df, num_parts = get_df def f(): diff --git a/tests/benchmarks/test_streaming_writes.py b/tests/benchmarks/test_streaming_writes.py new file mode 100644 index 0000000000..eaab068101 --- /dev/null +++ b/tests/benchmarks/test_streaming_writes.py @@ -0,0 +1,68 @@ +import pytest + +import daft +from tests.benchmarks.conftest import IS_CI + +ENGINES = ["native", "python"] + + +@pytest.mark.skipif(IS_CI, reason="Write benchmarks are not run in CI") +@pytest.mark.benchmark(group="write") +@pytest.mark.parametrize("engine", ENGINES) +@pytest.mark.parametrize( + "file_type, target_file_size, target_row_group_size", + [ + ("parquet", None, None), + ( + "parquet", + 5 * 1024 * 1024, + 1024 * 1024, + ), # 5MB target file size, 1MB target row group size + ("csv", None, None), + ], +) +@pytest.mark.parametrize("partition_cols", [None, ["L_SHIPMODE"]]) +@pytest.mark.parametrize("get_df", ["in-memory"], indirect=True) +def test_streaming_write( + tmp_path, + get_df, + benchmark_with_memray, + engine, + file_type, + target_file_size, + target_row_group_size, + partition_cols, +): + get_df, num_parts = get_df + daft_df = get_df("lineitem") + + def f(): + if engine == "native": + ctx = daft.context.execution_config_ctx( + enable_native_executor=True, + parquet_target_filesize=target_file_size, + parquet_target_row_group_size=target_row_group_size, + csv_target_filesize=target_file_size, + ) + elif engine == "python": + ctx = daft.context.execution_config_ctx( + enable_native_executor=False, + parquet_target_filesize=target_file_size, + parquet_target_row_group_size=target_row_group_size, + csv_target_filesize=target_file_size, + ) + else: + raise ValueError(f"{engine} unsupported") + + with ctx: + if file_type == "parquet": + return daft_df.write_parquet(tmp_path, partition_cols=partition_cols) + elif file_type == "csv": + return daft_df.write_csv(tmp_path, partition_cols=partition_cols) + else: + raise ValueError(f"{file_type} unsupported") + + benchmark_group = f"parts-{num_parts}-partition-cols-{partition_cols}-file-type-{file_type}-target-file-size-{target_file_size}-target-row-group-size-{target_row_group_size}" + result_files = benchmark_with_memray(f, benchmark_group).to_pydict()["path"] + read_back = daft.read_parquet(result_files) if file_type == "parquet" else daft.read_csv(result_files) + assert read_back.count_rows() == daft_df.count_rows() diff --git a/tests/cookbook/test_write.py b/tests/cookbook/test_write.py index ddd2c9b040..c00e00f1ac 100644 --- a/tests/cookbook/test_write.py +++ b/tests/cookbook/test_write.py @@ -7,14 +7,9 @@ from pyarrow import dataset as pads import daft -from daft import context from tests.conftest import assert_df_equals from tests.cookbook.assets import COOKBOOK_DATA_CSV -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) PYARROW_GE_7_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (7, 0, 0) @@ -178,9 +173,8 @@ def test_parquet_write_multifile(tmp_path, smaller_parquet_target_filesize): df = daft.from_pydict(data) df2 = df.write_parquet(tmp_path) assert len(df2) > 1 - ds = pads.dataset(tmp_path, format="parquet") - readback = ds.to_table() - assert readback.to_pydict() == data + read_back = 
daft.read_parquet(tmp_path.as_posix() + "/*.parquet").sort(by="x").to_pydict() + assert read_back == data @pytest.mark.skipif( @@ -201,9 +195,7 @@ def test_parquet_write_multifile_with_partitioning(tmp_path, smaller_parquet_tar def test_parquet_write_with_some_empty_partitions(tmp_path): data = {"x": [1, 2, 3], "y": ["a", "b", "c"]} - output_files = daft.from_pydict(data).into_partitions(4).write_parquet(tmp_path) - - assert len(output_files) == 3 + daft.from_pydict(data).into_partitions(4).write_parquet(tmp_path) read_back = daft.read_parquet(tmp_path.as_posix() + "/**/*.parquet").sort("x").to_pydict() assert read_back == data @@ -286,9 +278,7 @@ def test_empty_csv_write_with_partitioning(tmp_path): def test_csv_write_with_some_empty_partitions(tmp_path): data = {"x": [1, 2, 3], "y": ["a", "b", "c"]} - output_files = daft.from_pydict(data).into_partitions(4).write_csv(tmp_path) - - assert len(output_files) == 3 + daft.from_pydict(data).into_partitions(4).write_csv(tmp_path) read_back = daft.read_csv(tmp_path.as_posix() + "/**/*.csv").sort("x").to_pydict() assert read_back == data diff --git a/tests/dataframe/test_decimals.py b/tests/dataframe/test_decimals.py index daafec29f0..3a2d11babe 100644 --- a/tests/dataframe/test_decimals.py +++ b/tests/dataframe/test_decimals.py @@ -7,12 +7,6 @@ import pytest import daft -from daft import context - -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) PYARROW_GE_7_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (7, 0, 0) diff --git a/tests/dataframe/test_temporals.py b/tests/dataframe/test_temporals.py index 9d09ce20fa..1a26b8cffc 100644 --- a/tests/dataframe/test_temporals.py +++ b/tests/dataframe/test_temporals.py @@ -9,12 +9,7 @@ import pytz import daft -from daft import DataType, col, context - -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) +from daft import DataType, col PYARROW_GE_7_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (7, 0, 0) diff --git a/tests/io/test_csv_roundtrip.py b/tests/io/test_csv_roundtrip.py index b9e8ccc9b8..dd288e6806 100644 --- a/tests/io/test_csv_roundtrip.py +++ b/tests/io/test_csv_roundtrip.py @@ -7,12 +7,8 @@ import pytest import daft -from daft import DataType, TimeUnit, context +from daft import DataType, TimeUnit -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) PYARROW_GE_11_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (11, 0, 0) diff --git a/tests/io/test_parquet_roundtrip.py b/tests/io/test_parquet_roundtrip.py index 292c5b98e1..8804ef6d33 100644 --- a/tests/io/test_parquet_roundtrip.py +++ b/tests/io/test_parquet_roundtrip.py @@ -12,13 +12,6 @@ PYARROW_GE_8_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (8, 0, 0) -from daft import context - -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) - @pytest.mark.skipif( not PYARROW_GE_8_0_0, diff --git a/tests/io/test_s3_credentials_refresh.py b/tests/io/test_s3_credentials_refresh.py index 25c5c0cd8c..16a98fadf0 100644 --- a/tests/io/test_s3_credentials_refresh.py +++ 
b/tests/io/test_s3_credentials_refresh.py @@ -10,14 +10,8 @@ import pytest import daft -from daft import context from tests.io.mock_aws_server import start_service, stop_process -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) - @pytest.fixture(scope="session") def aws_log_file(tmp_path_factory: pytest.TempPathFactory) -> Iterator[io.IOBase]:
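
For illustration of the sizing logic in `make_writer_factory` above: the byte-size targets from the execution config are converted into row-count targets using the schema's estimated row size, and the row-group target is capped at the file target. A minimal, self-contained sketch of that arithmetic; the helper name and the concrete numbers are assumptions for the example, not part of the crate:

use std::cmp::min;

// Convert an on-disk byte target into an in-memory row target, guarding against a
// zero row-size estimate the same way make_writer_factory does.
fn rows_for_target(target_size_bytes: f64, inflation_factor: f64, est_row_size_bytes: f64) -> usize {
    let in_memory_target = target_size_bytes * inflation_factor;
    if est_row_size_bytes > 0.0 {
        (in_memory_target / est_row_size_bytes) as usize
    } else {
        in_memory_target as usize
    }
}

fn main() {
    // Hypothetical numbers: 512 MiB parquet file target, 128 MiB row-group target,
    // 3.0x in-memory inflation, ~200 bytes per row estimated from the schema.
    let target_file_rows = rows_for_target(512.0 * 1024.0 * 1024.0, 3.0, 200.0);
    // Row groups never exceed the file target, mirroring the `min` in make_writer_factory.
    let target_row_group_rows = min(
        target_file_rows,
        rows_for_target(128.0 * 1024.0 * 1024.0, 3.0, 200.0),
    );
    println!("{target_file_rows} rows per file, {target_row_group_rows} rows per row group");
}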
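
Similarly, TargetBatchWriter and TargetFileSizeWriter above both compare buffered rows against a row target, emit full chunks, and carry the remainder (buffered as leftovers, or left in the currently open file) until close() flushes it. A rough sketch of that row accounting, matching the unit-test expectations in batch.rs and file.rs; the helper function is hypothetical and only illustrates the split:

// Split `total_rows` of buffered input against a row target: how many full chunks are
// written immediately and how many rows remain for later (buffered, or flushed on close).
fn split_rows(total_rows: usize, target_rows: usize) -> (usize, usize) {
    (total_rows / target_rows, total_rows % target_rows)
}

fn main() {
    // Mirrors test_target_batch_writer_small_batches: 8 rows against a target of 3
    // yield 2 full batches plus 2 leftover rows, which close() flushes as a third write.
    assert_eq!(split_rows(8, 3), (2, 2));

    // Mirrors test_target_file_writer_multiple_files: 10 rows against a target of 3
    // rows per file produce 3 full files plus one final partial file, i.e. 4 results.
    let (full_files, remainder) = split_rows(10, 3);
    let total_files = full_files + usize::from(remainder > 0);
    assert_eq!(total_files, 4);
}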