diff --git a/Cargo.lock b/Cargo.lock
index 592a0793ba..814269a215 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1688,6 +1688,7 @@ dependencies = [
  "daft-micropartition",
  "daft-minhash",
  "daft-parquet",
+ "daft-physical-plan",
  "daft-plan",
  "daft-scan",
  "daft-scheduler",
diff --git a/Cargo.toml b/Cargo.toml
index 1d1065f026..eae7570216 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -20,6 +20,7 @@ daft-local-execution = {path = "src/daft-local-execution", default-features = fa
 daft-micropartition = {path = "src/daft-micropartition", default-features = false}
 daft-minhash = {path = "src/daft-minhash", default-features = false}
 daft-parquet = {path = "src/daft-parquet", default-features = false}
+daft-physical-plan = {path = "src/daft-physical-plan", default-features = false}
 daft-plan = {path = "src/daft-plan", default-features = false}
 daft-scan = {path = "src/daft-scan", default-features = false}
 daft-scheduler = {path = "src/daft-scheduler", default-features = false}
@@ -47,6 +48,7 @@ python = [
   "daft-json/python",
   "daft-micropartition/python",
   "daft-parquet/python",
+  "daft-physical-plan/python",
   "daft-plan/python",
   "daft-scan/python",
   "daft-scheduler/python",
diff --git a/daft/io/writer.py b/daft/io/writer.py
new file mode 100644
index 0000000000..0b98c69dad
--- /dev/null
+++ b/daft/io/writer.py
@@ -0,0 +1,79 @@
+import uuid
+from typing import Optional, Union
+
+from daft.daft import IOConfig
+from daft.dependencies import pa, pacsv, pq
+from daft.filesystem import (
+    _resolve_paths_and_filesystem,
+    canonicalize_protocol,
+    get_protocol_from_path,
+)
+from daft.table.micropartition import MicroPartition
+
+
+class FileWriterBase:
+    def __init__(
+        self,
+        root_dir: str,
+        file_idx: int,
+        file_format: str,
+        compression: Optional[str] = None,
+        io_config: Optional[IOConfig] = None,
+    ):
+        [self.resolved_path], self.fs = _resolve_paths_and_filesystem(root_dir, io_config=io_config)
+        protocol = get_protocol_from_path(root_dir)
+        canonicalized_protocol = canonicalize_protocol(protocol)
+        is_local_fs = canonicalized_protocol == "file"
+        if is_local_fs:
+            self.fs.create_dir(root_dir)
+
+        self.file_name = f"{uuid.uuid4()}-{file_idx}.{file_format}"
+        self.full_path = f"{self.resolved_path}/{self.file_name}"
+        self.compression = compression if compression is not None else "none"
+        self.current_writer: Optional[Union[pq.ParquetWriter, pacsv.CSVWriter]] = None
+
+    def _create_writer(self, schema: pa.Schema):
+        raise NotImplementedError("Subclasses must implement this method.")
+
+    def write(self, table: MicroPartition):
+        if self.current_writer is None:
+            self.current_writer = self._create_writer(table.schema().to_pyarrow_schema())
+        self.current_writer.write_table(table.to_arrow())
+
+    def close(self) -> Optional[str]:
+        if self.current_writer is None:
+            return None
+        self.current_writer.close()
+        return self.full_path
+
+
+class ParquetFileWriter(FileWriterBase):
+    def __init__(
+        self,
+        root_dir: str,
+        file_idx: int,
+        compression: str = "none",
+        io_config: Optional[IOConfig] = None,
+    ):
+        super().__init__(root_dir, file_idx, "parquet", compression, io_config)
+
+    def _create_writer(self, schema: pa.Schema) -> pq.ParquetWriter:
+        return pq.ParquetWriter(
+            self.full_path,
+            schema,
+            compression=self.compression,
+            use_compliant_nested_type=False,
+            filesystem=self.fs,
+        )
+
+
+class CSVFileWriter(FileWriterBase):
+    def __init__(self, root_dir: str, file_idx: int, io_config: Optional[IOConfig] = None):
+        super().__init__(root_dir, file_idx, "csv", None, io_config)
+
+    def _create_writer(self, schema: pa.Schema) -> pacsv.CSVWriter:
+        file_path = f"{self.resolved_path}/{self.file_name}"
+        return pacsv.CSVWriter(
+            file_path,
+            schema,
+        )
diff --git a/src/daft-local-execution/src/buffer.rs b/src/daft-local-execution/src/buffer.rs
new file mode 100644
index 0000000000..391e4ab604
--- /dev/null
+++ b/src/daft-local-execution/src/buffer.rs
@@ -0,0 +1,89 @@
+use std::{cmp::Ordering::*, collections::VecDeque, sync::Arc};
+
+use common_error::DaftResult;
+use daft_micropartition::MicroPartition;
+
+pub struct RowBasedBuffer {
+    pub buffer: VecDeque<Arc<MicroPartition>>,
+    pub curr_len: usize,
+    pub threshold: usize,
+}
+
+impl RowBasedBuffer {
+    pub fn new(threshold: usize) -> Self {
+        assert!(threshold > 0);
+        Self {
+            buffer: VecDeque::new(),
+            curr_len: 0,
+            threshold,
+        }
+    }
+
+    pub fn push(&mut self, part: Arc<MicroPartition>) {
+        self.curr_len += part.len();
+        self.buffer.push_back(part);
+    }
+
+    pub fn pop_enough(&mut self) -> DaftResult<Option<Vec<Arc<MicroPartition>>>> {
+        match self.curr_len.cmp(&self.threshold) {
+            Less => Ok(None),
+            Equal => {
+                if self.buffer.len() == 1 {
+                    let part = self.buffer.pop_front().unwrap();
+                    self.curr_len = 0;
+                    Ok(Some(vec![part]))
+                } else {
+                    let chunk = MicroPartition::concat(
+                        &std::mem::take(&mut self.buffer)
+                            .iter()
+                            .map(|x| x.as_ref())
+                            .collect::<Vec<_>>(),
+                    )?;
+                    self.curr_len = 0;
+                    Ok(Some(vec![Arc::new(chunk)]))
+                }
+            }
+            Greater => {
+                let num_ready_chunks = self.curr_len / self.threshold;
+                let concated = MicroPartition::concat(
+                    &std::mem::take(&mut self.buffer)
+                        .iter()
+                        .map(|x| x.as_ref())
+                        .collect::<Vec<_>>(),
+                )?;
+                let mut start = 0;
+                let mut parts_to_return = Vec::with_capacity(num_ready_chunks);
+                for _ in 0..num_ready_chunks {
+                    let end = start + self.threshold;
+                    let part = Arc::new(concated.slice(start, end)?);
+                    parts_to_return.push(part);
+                    start = end;
+                }
+                if start < concated.len() {
+                    let part = Arc::new(concated.slice(start, concated.len())?);
+                    self.curr_len = part.len();
+                    self.buffer.push_back(part);
+                } else {
+                    self.curr_len = 0;
+                }
+                Ok(Some(parts_to_return))
+            }
+        }
+    }
+
+    pub fn pop_all(&mut self) -> DaftResult<Option<Arc<MicroPartition>>> {
+        assert!(self.curr_len < self.threshold);
+        if self.buffer.is_empty() {
+            Ok(None)
+        } else {
+            let concated = MicroPartition::concat(
+                &std::mem::take(&mut self.buffer)
+                    .iter()
+                    .map(|x| x.as_ref())
+                    .collect::<Vec<_>>(),
+            )?;
+            self.curr_len = 0;
+            Ok(Some(Arc::new(concated)))
+        }
+    }
+}
diff --git a/src/daft-local-execution/src/intermediate_ops/buffer.rs b/src/daft-local-execution/src/intermediate_ops/buffer.rs
deleted file mode 100644
index 67b17c5380..0000000000
--- a/src/daft-local-execution/src/intermediate_ops/buffer.rs
+++ /dev/null
@@ -1,77 +0,0 @@
-use std::{cmp::Ordering::*, collections::VecDeque, sync::Arc};
-
-use common_error::DaftResult;
-use daft_micropartition::MicroPartition;
-
-pub struct OperatorBuffer {
-    pub buffer: VecDeque<Arc<MicroPartition>>,
-    pub curr_len: usize,
-    pub threshold: usize,
-}
-
-impl OperatorBuffer {
-    pub fn new(threshold: usize) -> Self {
-        assert!(threshold > 0);
-        Self {
-            buffer: VecDeque::new(),
-            curr_len: 0,
-            threshold,
-        }
-    }
-
-    pub fn push(&mut self, part: Arc<MicroPartition>) {
-        self.curr_len += part.len();
-        self.buffer.push_back(part);
-    }
-
-    pub fn try_clear(&mut self) -> Option<DaftResult<Arc<MicroPartition>>> {
-        match self.curr_len.cmp(&self.threshold) {
-            Less => None,
-            Equal => self.clear_all(),
-            Greater => Some(self.clear_enough()),
-        }
-    }
-
-    fn clear_enough(&mut self) -> DaftResult<Arc<MicroPartition>> {
-        assert!(self.curr_len > self.threshold);
-
-        let mut to_concat = Vec::with_capacity(self.buffer.len());
-        let mut remaining = self.threshold;
-
-        while remaining >
0 { - let part = self.buffer.pop_front().expect("Buffer should not be empty"); - let part_len = part.len(); - if part_len <= remaining { - remaining -= part_len; - to_concat.push(part); - } else { - let (head, tail) = part.split_at(remaining)?; - remaining = 0; - to_concat.push(Arc::new(head)); - self.buffer.push_front(Arc::new(tail)); - break; - } - } - assert_eq!(remaining, 0); - - self.curr_len -= self.threshold; - match to_concat.len() { - 1 => Ok(to_concat.pop().unwrap()), - _ => MicroPartition::concat(&to_concat.iter().map(|x| x.as_ref()).collect::>()) - .map(Arc::new), - } - } - - pub fn clear_all(&mut self) -> Option>> { - if self.buffer.is_empty() { - return None; - } - - let concated = - MicroPartition::concat(&self.buffer.iter().map(|x| x.as_ref()).collect::>()) - .map(Arc::new); - self.buffer.clear(); - self.curr_len = 0; - Some(concated) - } -} diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index abb5c5388b..da81b364d8 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -5,8 +5,8 @@ use common_error::DaftResult; use daft_micropartition::MicroPartition; use tracing::{info_span, instrument}; -use super::buffer::OperatorBuffer; use crate::{ + buffer::RowBasedBuffer, channel::{create_channel, PipelineChannel, Receiver, Sender}, pipeline::{PipelineNode, PipelineResultType}, runtime_stats::{CountingReceiver, CountingSender, RuntimeStatsContext}, @@ -135,7 +135,7 @@ impl IntermediateNode { }; for (idx, mut receiver) in receivers.into_iter().enumerate() { - let mut buffer = OperatorBuffer::new(morsel_size); + let mut buffer = RowBasedBuffer::new(morsel_size); while let Some(morsel) = receiver.recv().await { if morsel.should_broadcast() { for worker_sender in worker_senders.iter() { @@ -143,18 +143,15 @@ impl IntermediateNode { } } else { buffer.push(morsel.as_data().clone()); - if let Some(ready) = buffer.try_clear() { - let _ = send_to_next_worker(idx, ready?.into()).await; + if let Some(ready) = buffer.pop_enough()? { + for part in ready { + let _ = send_to_next_worker(idx, part.into()).await; + } } } } - // Buffer may still have some morsels left above the threshold - while let Some(ready) = buffer.try_clear() { - let _ = send_to_next_worker(idx, ready?.into()).await; - } - // Clear all remaining morsels - if let Some(last_morsel) = buffer.clear_all() { - let _ = send_to_next_worker(idx, last_morsel?.into()).await; + if let Some(ready) = buffer.pop_all()? { + let _ = send_to_next_worker(idx, ready.into()).await; } } Ok(()) diff --git a/src/daft-local-execution/src/intermediate_ops/mod.rs b/src/daft-local-execution/src/intermediate_ops/mod.rs index 593f9ef5ed..7dafadebba 100644 --- a/src/daft-local-execution/src/intermediate_ops/mod.rs +++ b/src/daft-local-execution/src/intermediate_ops/mod.rs @@ -1,6 +1,5 @@ pub mod aggregate; pub mod anti_semi_hash_join_probe; -pub mod buffer; pub mod filter; pub mod hash_join_probe; pub mod intermediate_op; diff --git a/src/daft-local-execution/src/lib.rs b/src/daft-local-execution/src/lib.rs index 1356689a08..e0a0e4cb85 100644 --- a/src/daft-local-execution/src/lib.rs +++ b/src/daft-local-execution/src/lib.rs @@ -1,4 +1,5 @@ #![feature(let_chains)] +mod buffer; mod channel; mod intermediate_ops; mod pipeline; @@ -14,15 +15,20 @@ lazy_static! 
{ pub static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get(); } +pub(crate) type TaskSet = tokio::task::JoinSet; +pub(crate) fn create_task_set() -> TaskSet { + TaskSet::new() +} + pub struct ExecutionRuntimeHandle { - worker_set: tokio::task::JoinSet>, + worker_set: TaskSet>, default_morsel_size: usize, } impl ExecutionRuntimeHandle { pub fn new(default_morsel_size: usize) -> Self { Self { - worker_set: tokio::task::JoinSet::new(), + worker_set: create_task_set(), default_morsel_size, } } diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index 0f84ac2636..e3bfae0ebe 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -1,7 +1,9 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{cmp::min, collections::HashMap, sync::Arc}; +use common_daft_config::DaftExecutionConfig; use common_display::{mermaid::MermaidDisplayVisitor, tree::TreeDisplay}; use common_error::DaftResult; +use common_file_formats::FileFormat; use daft_core::{ datatypes::Field, prelude::{Schema, SchemaRef}, @@ -10,8 +12,8 @@ use daft_core::{ use daft_dsl::{col, join::get_common_join_keys, Expr}; use daft_micropartition::MicroPartition; use daft_physical_plan::{ - Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan, Project, Sort, - UnGroupedAggregate, + Filter, HashAggregate, HashJoin, InMemoryScan, Limit, LocalPhysicalPlan, PhysicalWrite, + Project, Sort, UnGroupedAggregate, }; use daft_plan::{populate_aggregation_stages, JoinType}; use daft_table::{Probeable, Table}; @@ -27,8 +29,9 @@ use crate::{ }, sinks::{ aggregate::AggregateSink, blocking_sink::BlockingSinkNode, - hash_join_build::HashJoinBuildSink, limit::LimitSink, sort::SortSink, - streaming_sink::StreamingSinkNode, + hash_join_build::HashJoinBuildSink, limit::LimitSink, + partitioned_write::PartitionedWriteNode, sort::SortSink, streaming_sink::StreamingSinkNode, + unpartitioned_write::UnpartionedWriteNode, }, sources::in_memory::InMemorySource, ExecutionRuntimeHandle, PipelineCreationSnafu, @@ -99,6 +102,7 @@ pub(crate) fn viz_pipeline(root: &dyn PipelineNode) -> String { pub fn physical_plan_to_pipeline( physical_plan: &LocalPhysicalPlan, psets: &HashMap>>, + cfg: &Arc, ) -> crate::Result> { use daft_physical_plan::PhysicalScan; @@ -118,21 +122,21 @@ pub fn physical_plan_to_pipeline( input, projection, .. }) => { let proj_op = ProjectOperator::new(projection.clone()); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; IntermediateNode::new(Arc::new(proj_op), vec![child_node]).boxed() } LocalPhysicalPlan::Filter(Filter { input, predicate, .. }) => { let filter_op = FilterOperator::new(predicate.clone()); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; IntermediateNode::new(Arc::new(filter_op), vec![child_node]).boxed() } LocalPhysicalPlan::Limit(Limit { input, num_rows, .. 
}) => { let sink = LimitSink::new(*num_rows as usize); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; StreamingSinkNode::new(sink.boxed(), vec![child_node]).boxed() } LocalPhysicalPlan::Concat(_) => { @@ -158,7 +162,7 @@ pub fn physical_plan_to_pipeline( .collect(), vec![], ); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; let post_first_agg_node = IntermediateNode::new(Arc::new(first_stage_agg_op), vec![child_node]).boxed(); @@ -186,7 +190,7 @@ pub fn physical_plan_to_pipeline( }) => { let (first_stage_aggs, second_stage_aggs, final_exprs) = populate_aggregation_stages(aggregations, schema, group_by); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; let (post_first_agg_node, group_by) = if !first_stage_aggs.is_empty() { let agg_op = AggregateOperator::new( first_stage_aggs @@ -226,7 +230,7 @@ pub fn physical_plan_to_pipeline( .. }) => { let sort_sink = SortSink::new(sort_by.clone(), descending.clone()); - let child_node = physical_plan_to_pipeline(input, psets)?; + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; BlockingSinkNode::new(sort_sink.boxed(), child_node).boxed() } LocalPhysicalPlan::HashJoin(HashJoin { @@ -298,11 +302,11 @@ pub fn physical_plan_to_pipeline( // we should move to a builder pattern let build_sink = HashJoinBuildSink::new(key_schema.clone(), casted_build_on, join_type)?; - let build_child_node = physical_plan_to_pipeline(build_child, psets)?; + let build_child_node = physical_plan_to_pipeline(build_child, psets, cfg)?; let build_node = BlockingSinkNode::new(build_sink.boxed(), build_child_node).boxed(); - let probe_child_node = physical_plan_to_pipeline(probe_child, psets)?; + let probe_child_node = physical_plan_to_pipeline(probe_child, psets, cfg)?; match join_type { JoinType::Anti | JoinType::Semi => DaftResult::Ok(IntermediateNode::new( @@ -332,8 +336,53 @@ pub fn physical_plan_to_pipeline( })?; probe_node.boxed() } - _ => { - unimplemented!("Physical plan not supported: {}", physical_plan.name()); + LocalPhysicalPlan::PhysicalWrite(PhysicalWrite { + input, + file_info, + data_schema, + file_schema, + .. 
+ }) => { + let child_node = physical_plan_to_pipeline(input, psets, cfg)?; + let (inflation_factor, target_file_size) = match file_info.file_format { + FileFormat::Parquet => (cfg.parquet_inflation_factor, cfg.parquet_target_filesize), + FileFormat::Csv => (cfg.csv_inflation_factor, cfg.csv_target_filesize), + _ => unreachable!("Unsupported file format"), + }; + let estimated_row_size_bytes = data_schema.estimate_row_size_bytes(); + let target_file_rows = if estimated_row_size_bytes > 0.0 { + target_file_size as f64 * inflation_factor / estimated_row_size_bytes + } else { + target_file_size as f64 * inflation_factor + } as usize; + // Just assume same chunk size for CSV and Parquet for now + let target_chunk_rows = min( + target_file_rows, + if estimated_row_size_bytes > 0.0 { + cfg.parquet_target_row_group_size as f64 * inflation_factor + / estimated_row_size_bytes + } else { + cfg.parquet_target_row_group_size as f64 * inflation_factor + } as usize, + ); + match &file_info.partition_cols { + Some(_) => PartitionedWriteNode::new( + child_node, + file_info, + file_schema, + target_file_rows, + target_chunk_rows, + ) + .boxed(), + None => UnpartionedWriteNode::new( + child_node, + file_info, + file_schema, + target_file_rows, + target_chunk_rows, + ) + .boxed(), + } } }; diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs index 38d7c3e479..48d8d9b40b 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -121,7 +121,7 @@ pub fn run_local( results_buffer_size: Option, ) -> DaftResult>> + Send>> { refresh_chrome_trace(); - let mut pipeline = physical_plan_to_pipeline(physical_plan, &psets)?; + let mut pipeline = physical_plan_to_pipeline(physical_plan, &psets, &cfg)?; let (tx, rx) = create_channel(results_buffer_size.unwrap_or(1)); let handle = std::thread::spawn(move || { let runtime = tokio::runtime::Builder::new_multi_thread() diff --git a/src/daft-local-execution/src/sinks/mod.rs b/src/daft-local-execution/src/sinks/mod.rs index 39910e7995..f2212eda3c 100644 --- a/src/daft-local-execution/src/sinks/mod.rs +++ b/src/daft-local-execution/src/sinks/mod.rs @@ -3,5 +3,7 @@ pub mod blocking_sink; pub mod concat; pub mod hash_join_build; pub mod limit; +pub mod partitioned_write; pub mod sort; pub mod streaming_sink; +pub mod unpartitioned_write; diff --git a/src/daft-local-execution/src/sinks/partitioned_write.rs b/src/daft-local-execution/src/sinks/partitioned_write.rs new file mode 100644 index 0000000000..2ac53596f6 --- /dev/null +++ b/src/daft-local-execution/src/sinks/partitioned_write.rs @@ -0,0 +1,397 @@ +use std::{collections::HashMap, sync::Arc}; + +use common_display::tree::TreeDisplay; +use common_error::DaftResult; +use common_file_formats::FileFormat; +use daft_core::{ + prelude::{SchemaRef, Utf8Array}, + series::IntoSeries, +}; +use daft_dsl::ExprRef; +use daft_io::IOStatsContext; +use daft_micropartition::{create_file_writer, FileWriter, MicroPartition}; +use daft_plan::OutputFileInfo; +use daft_table::Table; +use snafu::ResultExt; + +use crate::{ + buffer::RowBasedBuffer, + channel::{create_channel, PipelineChannel, Receiver, Sender}, + pipeline::PipelineNode, + runtime_stats::{CountingReceiver, RuntimeStatsContext}, + ExecutionRuntimeHandle, JoinSnafu, NUM_CPUS, +}; + +struct SizedDataWriter { + root_dir: String, + file_format: FileFormat, + compression: Option, + io_config: Option, + writer: Box, + target_file_rows: usize, + written_files: Vec>, + written_rows_so_far: usize, +} + +impl SizedDataWriter { 
+ fn new( + root_dir: String, + target_file_rows: usize, + file_format: FileFormat, + compression: Option, + io_config: Option, + ) -> DaftResult { + Ok(Self { + writer: create_file_writer(&root_dir, 0, &compression, &io_config, file_format)?, + root_dir, + file_format, + compression, + io_config, + target_file_rows, + written_files: vec![], + written_rows_so_far: 0, + }) + } + + fn write(&mut self, data: &Arc) -> DaftResult<()> { + let len = data.len(); + self.writer.write(data)?; + if self.written_rows_so_far + len >= self.target_file_rows { + let file_path = self.writer.close()?; + if let Some(file_path) = file_path { + self.written_files.push(Some(file_path)); + } + self.written_rows_so_far = 0; + self.writer = create_file_writer( + &self.root_dir, + self.written_files.len(), + &self.compression, + &self.io_config, + self.file_format, + )?; + } + Ok(()) + } + + fn finalize(&mut self) -> DaftResult>> { + if let Some(file_path) = self.writer.close()? { + self.written_files.push(Some(file_path)); + } + Ok(self.written_files.clone()) + } +} + +pub(crate) struct PartitionedWriteNode { + child: Box, + runtime_stats: Arc, + file_info: OutputFileInfo, + file_schema: SchemaRef, + target_file_rows: usize, + target_chunk_rows: usize, +} + +impl PartitionedWriteNode { + pub(crate) fn new( + child: Box, + file_info: &OutputFileInfo, + file_schema: &SchemaRef, + target_file_rows: usize, + target_chunk_rows: usize, + ) -> Self { + Self { + child, + runtime_stats: RuntimeStatsContext::new(), + file_info: file_info.clone(), + file_schema: file_schema.clone(), + target_file_rows, + target_chunk_rows, + } + } + + pub(crate) fn boxed(self) -> Box { + Box::new(self) + } + + fn partition( + partition_cols: &[ExprRef], + default_partition_fallback: Arc, + data: &Arc, + ) -> DaftResult<(Vec, Table, Vec>)> { + let (split_tables, partition_values) = data.partition_by_value(partition_cols)?; + let concated = partition_values + .concat_or_get(IOStatsContext::new("MicroPartition::partition_by_value"))?; + let partition_values_table = concated.first().unwrap(); + let pkey_names = partition_values_table.column_names(); + + let mut values_string_values = Vec::with_capacity(partition_values_table.len()); + for name in pkey_names.iter() { + let column = partition_values_table.get_column(name)?; + let string_names = column.to_str_values()?; + let default_part = Utf8Array::from_iter( + "default", + std::iter::once(Some(default_partition_fallback.clone())), + ) + .into_series(); + let null_filled = string_names.if_else(&default_part, &column.not_null()?)?; + values_string_values.push(null_filled); + } + + let mut part_keys_postfixes = Vec::with_capacity(partition_values_table.len()); + for i in 0..partition_values_table.len() { + let postfix = pkey_names + .iter() + .zip(values_string_values.iter()) + .map(|(pkey, values)| { + format!("{}={}", pkey, values.utf8().unwrap().get(i).unwrap()) + }) + .collect::>() + .join("/"); + part_keys_postfixes.push(Arc::from(postfix)); + } + + Ok(( + split_tables, + partition_values_table.clone(), + part_keys_postfixes, + )) + } + + async fn run_writer( + mut input_receiver: Receiver>, + file_info: Arc, + default_partition_fallback: Arc, + target_chunk_rows: usize, + target_file_rows: usize, + ) -> DaftResult> { + let mut writers: HashMap, SizedDataWriter> = HashMap::new(); + let mut buffers: HashMap, RowBasedBuffer> = HashMap::new(); + let mut partition_key_values: HashMap, Table> = HashMap::new(); + while let Some(data) = input_receiver.recv().await { + let (split_tables, 
partition_values_table, part_keys_postfixes) = Self::partition( + file_info.partition_cols.as_ref().unwrap(), + default_partition_fallback.clone(), + &data, + )?; + for (idx, (postfix, partition)) in part_keys_postfixes + .iter() + .zip(split_tables.into_iter()) + .enumerate() + { + if !partition_key_values.contains_key(postfix) { + let partition_value_row = partition_values_table.slice(idx, idx + 1)?; + partition_key_values.insert(postfix.clone(), partition_value_row); + } + + let buffer = if let Some(buffer) = buffers.get_mut(postfix) { + buffer + } else { + let buffer = RowBasedBuffer::new(target_chunk_rows); + buffers.insert(postfix.clone(), buffer); + &mut buffers.get_mut(postfix).unwrap() + }; + buffer.push(Arc::new(partition)); + + if let Some(ready) = buffer.pop_enough()? { + for part in ready { + if let Some(writer) = writers.get_mut(postfix) { + writer.write(&part)?; + } else { + let mut writer = SizedDataWriter::new( + format!("{}/{}", file_info.root_dir, postfix), + target_file_rows, + file_info.file_format, + file_info.compression.clone(), + file_info.io_config.clone(), + )?; + writer.write(&part)?; + writers.insert(postfix.clone(), writer); + } + } + } + } + } + for (postfix, buffer) in buffers.iter_mut() { + let remaining = buffer.pop_all()?; + if let Some(part) = remaining { + if let Some(writer) = writers.get_mut(postfix) { + writer.write(&part)?; + } else { + let mut writer = SizedDataWriter::new( + format!("{}/{}", file_info.root_dir, postfix), + target_file_rows, + file_info.file_format, + file_info.compression.clone(), + file_info.io_config.clone(), + )?; + writer.write(&part)?; + writers.insert(postfix.clone(), writer); + } + } + } + + let mut written_files = Vec::with_capacity(writers.len()); + let mut partition_keys_values = Vec::with_capacity(writers.len()); + for (postfix, partition_key_value) in partition_key_values.iter() { + let writer = writers.get_mut(postfix).unwrap(); + let file_paths = writer.finalize()?; + if !file_paths.is_empty() { + if file_paths.len() > partition_key_value.len() { + let mut columns = Vec::with_capacity(partition_key_value.num_columns()); + let column_names = partition_key_value.column_names(); + for name in column_names { + let column = partition_key_value.get_column(name)?; + let broadcasted = column.broadcast(file_paths.len())?; + columns.push(broadcasted); + } + let table = Table::from_nonempty_columns(columns)?; + partition_keys_values.push(table); + } else { + partition_keys_values.push(partition_key_value.clone()); + } + } + written_files.extend(file_paths.into_iter()); + } + if written_files.is_empty() { + return Ok(None); + } + let written_files_table = Table::from_nonempty_columns(vec![Utf8Array::from_iter( + "path", + written_files.into_iter(), + ) + .into_series()])?; + if !partition_keys_values.is_empty() { + let unioned = written_files_table.union(&Table::concat(&partition_keys_values)?)?; + Ok(Some(unioned)) + } else { + Ok(Some(written_files_table)) + } + } + + fn spawn_writers( + num_writers: usize, + task_set: &mut tokio::task::JoinSet>>, + file_info: &Arc, + default_partition_fallback: &str, + target_chunk_rows: usize, + target_file_rows: usize, + ) -> Vec>> { + let mut writer_senders = Vec::with_capacity(num_writers); + for _ in 0..num_writers { + let (writer_sender, writer_receiver) = create_channel(1); + task_set.spawn(Self::run_writer( + writer_receiver, + file_info.clone(), + default_partition_fallback.into(), + target_chunk_rows, + target_file_rows, + )); + writer_senders.push(writer_sender); + } + 
writer_senders + } + + async fn dispatch( + mut input_receiver: CountingReceiver, + senders: Vec>>, + partiton_cols: Vec, + ) -> DaftResult<()> { + while let Some(data) = input_receiver.recv().await { + let partitioned = data + .as_data() + .partition_by_hash(&partiton_cols, senders.len())?; + for (idx, mp) in partitioned.into_iter().enumerate() { + if !mp.is_empty() { + let _ = senders[idx].send(mp.into()).await; + } + } + } + Ok(()) + } +} + +impl TreeDisplay for PartitionedWriteNode { + fn display_as(&self, level: common_display::DisplayLevel) -> String { + use std::fmt::Write; + let mut display = String::new(); + writeln!(display, "{}", self.name()).unwrap(); + use common_display::DisplayLevel::*; + match level { + Compact => {} + _ => { + let rt_result = self.runtime_stats.result(); + rt_result.display(&mut display, true, true, true).unwrap(); + } + } + display + } + + fn get_children(&self) -> Vec<&dyn TreeDisplay> { + vec![self.child.as_tree_display()] + } +} + +impl PipelineNode for PartitionedWriteNode { + fn children(&self) -> Vec<&dyn PipelineNode> { + vec![self.child.as_ref()] + } + + fn name(&self) -> &'static str { + "PartitionedWrite" + } + + fn start( + &mut self, + maintain_order: bool, + runtime_handle: &mut ExecutionRuntimeHandle, + ) -> crate::Result { + let child = self.child.as_mut(); + let child_results_receiver = child + .start(false, runtime_handle)? + .get_receiver_with_stats(&self.runtime_stats); + + let mut destination_channel = PipelineChannel::new(1, maintain_order); + let destination_sender = + destination_channel.get_next_sender_with_stats(&self.runtime_stats); + let file_info = Arc::new(self.file_info.clone()); + let target_chunk_rows = self.target_chunk_rows; + let target_file_rows = self.target_file_rows; + let schema = self.file_schema.clone(); + runtime_handle.spawn( + async move { + let mut task_set = tokio::task::JoinSet::new(); + let writer_senders = Self::spawn_writers( + *NUM_CPUS, + &mut task_set, + &file_info, + "__HIVE_DEFAULT_PARTITION__", + target_chunk_rows, + target_file_rows, + ); + Self::dispatch( + child_results_receiver, + writer_senders, + file_info.partition_cols.clone().unwrap(), + ) + .await?; + let mut results = vec![]; + while let Some(result) = task_set.join_next().await { + if let Some(result) = result.context(JoinSnafu)?? 
{ + results.push(result); + } + } + if results.is_empty() { + return Ok(()); + } + let result_mp = + Arc::new(MicroPartition::new_loaded(schema, Arc::new(results), None)); + let _ = destination_sender.send(result_mp.into()).await; + Ok(()) + }, + self.name(), + ); + Ok(destination_channel) + } + fn as_tree_display(&self) -> &dyn TreeDisplay { + self + } +} diff --git a/src/daft-local-execution/src/sinks/unpartitioned_write.rs b/src/daft-local-execution/src/sinks/unpartitioned_write.rs new file mode 100644 index 0000000000..6b6ce435a9 --- /dev/null +++ b/src/daft-local-execution/src/sinks/unpartitioned_write.rs @@ -0,0 +1,230 @@ +use std::sync::Arc; + +use common_display::tree::TreeDisplay; +use common_error::DaftResult; +use daft_core::{ + prelude::{SchemaRef, Utf8Array}, + series::IntoSeries, +}; +use daft_micropartition::{create_file_writer, FileWriter, MicroPartition}; +use daft_plan::OutputFileInfo; +use daft_table::Table; +use snafu::ResultExt; + +use crate::{ + buffer::RowBasedBuffer, + channel::{create_channel, PipelineChannel, Receiver, Sender}, + create_task_set, + pipeline::PipelineNode, + runtime_stats::{CountingReceiver, RuntimeStatsContext}, + ExecutionRuntimeHandle, JoinSnafu, TaskSet, NUM_CPUS, +}; + +pub(crate) struct UnpartionedWriteNode { + child: Box, + runtime_stats: Arc, + file_info: OutputFileInfo, + file_schema: SchemaRef, + target_in_memory_file_rows: usize, + target_in_memory_chunk_rows: usize, +} + +impl UnpartionedWriteNode { + pub(crate) fn new( + child: Box, + file_info: &OutputFileInfo, + file_schema: &SchemaRef, + target_in_memory_file_rows: usize, + target_in_memory_chunk_rows: usize, + ) -> Self { + Self { + child, + runtime_stats: RuntimeStatsContext::new(), + file_info: file_info.clone(), + file_schema: file_schema.clone(), + target_in_memory_file_rows, + target_in_memory_chunk_rows, + } + } + + pub(crate) fn boxed(self) -> Box { + Box::new(self) + } + + async fn run_writer( + mut input_receiver: Receiver<(Arc, usize)>, + file_info: Arc, + ) -> DaftResult> { + let mut written_file_paths = vec![]; + let mut current_writer: Option> = None; + let mut current_file_idx = None; + while let Some((data, file_idx)) = input_receiver.recv().await { + if current_file_idx.is_none() || current_file_idx.unwrap() != file_idx { + if let Some(writer) = current_writer.take() { + if let Some(path) = writer.close()? { + written_file_paths.push(path); + } + } + current_file_idx = Some(file_idx); + current_writer = Some(create_file_writer( + &file_info.root_dir, + file_idx, + &file_info.compression, + &file_info.io_config, + file_info.file_format, + )?); + } + if let Some(writer) = current_writer.as_mut() { + writer.write(&data)?; + } + } + if let Some(writer) = current_writer { + if let Some(path) = writer.close()? 
{ + written_file_paths.push(path); + } + } + Ok(written_file_paths) + } + + fn spawn_writers( + num_writers: usize, + task_set: &mut TaskSet>>, + file_info: &Arc, + ) -> Vec, usize)>> { + let mut writer_senders = Vec::with_capacity(num_writers); + for _ in 0..num_writers { + let (writer_sender, writer_receiver) = create_channel(1); + task_set.spawn(Self::run_writer(writer_receiver, file_info.clone())); + writer_senders.push(writer_sender); + } + writer_senders + } + + async fn dispatch( + mut input_receiver: CountingReceiver, + target_chunk_rows: usize, + target_file_rows: usize, + senders: Vec, usize)>>, + ) -> DaftResult<()> { + let mut curr_sent_rows = 0; + let mut curr_file_idx = 0; + let mut curr_sender_idx = 0; + let mut buffer = RowBasedBuffer::new(target_chunk_rows); + while let Some(data) = input_receiver.recv().await { + let data = data.as_data(); + if data.is_empty() { + continue; + } + buffer.push(data.clone()); + if let Some(ready) = buffer.pop_enough()? { + for part in ready { + curr_sent_rows += part.len(); + let _ = senders[curr_sender_idx].send((part, curr_file_idx)).await; + if curr_sent_rows >= target_file_rows { + curr_sent_rows = 0; + curr_file_idx += 1; + curr_sender_idx = (curr_sender_idx + 1) % senders.len(); + } + } + } + } + if let Some(leftover) = buffer.pop_all()? { + let _ = senders[curr_file_idx].send((leftover, curr_file_idx)).await; + } + Ok(()) + } +} + +impl TreeDisplay for UnpartionedWriteNode { + fn display_as(&self, level: common_display::DisplayLevel) -> String { + use std::fmt::Write; + let mut display = String::new(); + writeln!(display, "{}", self.name()).unwrap(); + use common_display::DisplayLevel::*; + match level { + Compact => {} + _ => { + let rt_result = self.runtime_stats.result(); + rt_result.display(&mut display, true, true, true).unwrap(); + } + } + display + } + + fn get_children(&self) -> Vec<&dyn TreeDisplay> { + vec![self.child.as_tree_display()] + } +} + +impl PipelineNode for UnpartionedWriteNode { + fn children(&self) -> Vec<&dyn PipelineNode> { + vec![self.child.as_ref()] + } + + fn name(&self) -> &'static str { + "UnpartionedWrite" + } + + fn start( + &mut self, + maintain_order: bool, + runtime_handle: &mut ExecutionRuntimeHandle, + ) -> crate::Result { + let child = self.child.as_mut(); + let child_results_receiver = child + .start(false, runtime_handle)? 
+ .get_receiver_with_stats(&self.runtime_stats); + + let mut destination_channel = PipelineChannel::new(1, maintain_order); + let destination_sender = + destination_channel.get_next_sender_with_stats(&self.runtime_stats); + let file_info = Arc::new(self.file_info.clone()); + let target_chunk_rows = self.target_in_memory_chunk_rows; + let target_file_rows = self.target_in_memory_file_rows; + let file_schema = self.file_schema.clone(); + runtime_handle.spawn( + async move { + let mut task_set = create_task_set(); + let writer_senders = Self::spawn_writers(*NUM_CPUS, &mut task_set, &file_info); + Self::dispatch( + child_results_receiver, + target_chunk_rows, + target_file_rows, + writer_senders, + ) + .await?; + + let mut results = vec![]; + while let Some(result) = task_set.join_next().await { + results.extend(result.context(JoinSnafu)??); + } + if results.is_empty() { + return Ok(()); + } + + let written_file_paths_series = Utf8Array::from(( + "path", + results + .iter() + .map(|v| v.as_str()) + .collect::>() + .as_slice(), + )) + .into_series(); + let result_table = Table::from_nonempty_columns(vec![written_file_paths_series])?; + let result_mp = Arc::new(MicroPartition::new_loaded( + file_schema, + Arc::new(vec![result_table]), + None, + )); + let _ = destination_sender.send(result_mp.into()).await; + Ok(()) + }, + self.name(), + ); + Ok(destination_channel) + } + fn as_tree_display(&self) -> &dyn TreeDisplay { + self + } +} diff --git a/src/daft-local-execution/src/sources/in_memory.rs b/src/daft-local-execution/src/sources/in_memory.rs index 1bf08a8913..82ace7a274 100644 --- a/src/daft-local-execution/src/sources/in_memory.rs +++ b/src/daft-local-execution/src/sources/in_memory.rs @@ -17,6 +17,7 @@ impl InMemorySource { pub fn new(data: Vec>, schema: SchemaRef) -> Self { Self { data, schema } } + pub fn boxed(self) -> Box { Box::new(self) as Box } @@ -32,7 +33,7 @@ impl Source for InMemorySource { ) -> crate::Result> { if self.data.is_empty() { let empty = Arc::new(MicroPartition::empty(Some(self.schema.clone()))); - return Ok(Box::pin(futures::stream::once(async { empty }))); + return Ok(Box::pin(futures::stream::iter(vec![empty]))); } Ok(Box::pin(futures::stream::iter(self.data.clone()))) } diff --git a/src/daft-micropartition/src/lib.rs b/src/daft-micropartition/src/lib.rs index 1a01f4e933..e0508dfcce 100644 --- a/src/daft-micropartition/src/lib.rs +++ b/src/daft-micropartition/src/lib.rs @@ -1,7 +1,10 @@ #![feature(let_chains)] #![feature(iterator_try_reduce)] -use common_error::DaftError; +use std::sync::Arc; + +use common_error::{DaftError, DaftResult}; +use common_file_formats::FileFormat; use snafu::Snafu; mod micropartition; mod ops; @@ -13,6 +16,8 @@ pub mod python; #[cfg(feature = "python")] use pyo3::PyErr; #[cfg(feature = "python")] +pub mod py_writers; +#[cfg(feature = "python")] pub use python::register_modules; #[derive(Debug, Snafu)] @@ -59,3 +64,33 @@ impl From for pyo3::PyErr { daft_error.into() } } + +pub trait FileWriter: Send + Sync { + fn write(&self, data: &Arc) -> DaftResult<()>; + fn close(&self) -> DaftResult>; +} + +pub fn create_file_writer( + root_dir: &str, + file_idx: usize, + compression: &Option, + io_config: &Option, + format: FileFormat, +) -> DaftResult> { + match format { + #[cfg(feature = "python")] + FileFormat::Parquet => Ok(Box::new(py_writers::PyArrowParquetWriter::new( + root_dir, + file_idx, + compression, + io_config, + )?)), + #[cfg(feature = "python")] + FileFormat::Csv => Ok(Box::new(py_writers::PyArrowCSVWriter::new( + root_dir, 
file_idx, io_config, + )?)), + _ => Err(DaftError::ComputeError( + "Unsupported file format for physical write".to_string(), + )), + } +} diff --git a/src/daft-micropartition/src/py_writers.rs b/src/daft-micropartition/src/py_writers.rs new file mode 100644 index 0000000000..263e72db58 --- /dev/null +++ b/src/daft-micropartition/src/py_writers.rs @@ -0,0 +1,108 @@ +use std::sync::Arc; + +use common_error::DaftResult; +use pyo3::{types::PyAnyMethods, PyObject, Python}; + +use crate::{python::PyMicroPartition, FileWriter, MicroPartition}; + +pub struct PyArrowParquetWriter { + py_writer: PyObject, +} + +impl PyArrowParquetWriter { + pub fn new( + root_dir: &str, + file_idx: usize, + compression: &Option, + io_config: &Option, + ) -> DaftResult { + Python::with_gil(|py| { + let file_writer_module = py.import_bound(pyo3::intern!(py, "daft.io.writer"))?; + let file_writer_class = file_writer_module.getattr("ParquetFileWriter")?; + + let py_writer = file_writer_class.call1(( + root_dir, + file_idx, + compression.as_ref().map(|c| c.as_str()), + io_config.as_ref().map(|cfg| daft_io::python::IOConfig { + config: cfg.clone(), + }), + ))?; + Ok(Self { + py_writer: py_writer.into(), + }) + }) + } +} + +impl FileWriter for PyArrowParquetWriter { + fn write(&self, data: &Arc) -> DaftResult<()> { + Python::with_gil(|py| { + let py_micropartition = py + .import_bound(pyo3::intern!(py, "daft.table"))? + .getattr(pyo3::intern!(py, "MicroPartition"))? + .getattr(pyo3::intern!(py, "_from_pymicropartition"))? + .call1((PyMicroPartition::from(data.clone()),))?; + self.py_writer + .call_method1(py, "write", (py_micropartition,))?; + Ok(()) + }) + } + + fn close(&self) -> DaftResult> { + Python::with_gil(|py| { + let result = self.py_writer.call_method0(py, "close")?; + Ok(result.extract::>(py)?) + }) + } +} + +pub struct PyArrowCSVWriter { + py_writer: PyObject, +} + +impl PyArrowCSVWriter { + pub fn new( + root_dir: &str, + file_idx: usize, + io_config: &Option, + ) -> DaftResult { + Python::with_gil(|py| { + let file_writer_module = py.import_bound(pyo3::intern!(py, "daft.io.writer"))?; + let file_writer_class = file_writer_module.getattr("CSVFileWriter")?; + + let py_writer = file_writer_class.call1(( + root_dir, + file_idx, + io_config.as_ref().map(|cfg| daft_io::python::IOConfig { + config: cfg.clone(), + }), + ))?; + Ok(Self { + py_writer: py_writer.into(), + }) + }) + } +} + +impl FileWriter for PyArrowCSVWriter { + fn write(&self, data: &Arc) -> DaftResult<()> { + Python::with_gil(|py| { + let py_micropartition = py + .import_bound(pyo3::intern!(py, "daft.table"))? + .getattr(pyo3::intern!(py, "MicroPartition"))? + .getattr(pyo3::intern!(py, "_from_pymicropartition"))? + .call1((PyMicroPartition::from(data.clone()),))?; + self.py_writer + .call_method1(py, "write", (py_micropartition,))?; + Ok(()) + }) + } + + fn close(&self) -> DaftResult> { + Python::with_gil(|py| { + let result = self.py_writer.call_method0(py, "close")?; + Ok(result.extract::>(py)?) 
+ }) + } +} diff --git a/src/daft-parquet/src/stream_reader.rs b/src/daft-parquet/src/stream_reader.rs index 178141c64d..bdf1a971e5 100644 --- a/src/daft-parquet/src/stream_reader.rs +++ b/src/daft-parquet/src/stream_reader.rs @@ -9,7 +9,7 @@ use arrow2::io::parquet::read; use common_error::DaftResult; use daft_core::{prelude::*, utils::arrow::cast_array_for_daft_if_needed}; use daft_dsl::ExprRef; -use daft_io::IOStatsRef; +use daft_io::{get_runtime, IOStatsRef}; use daft_table::Table; use futures::{stream::BoxStream, StreamExt}; use itertools::Itertools; @@ -514,13 +514,8 @@ pub(crate) fn local_parquet_stream( // Create a channel for each row group to send the processed tables to the stream // Each channel is expected to have a number of chunks equal to the number of chunks in the row group - let (senders, receivers): (Vec<_>, Vec<_>) = row_ranges - .iter() - .map(|rg_range| { - let expected_num_chunks = - f32::ceil(rg_range.num_rows as f32 / chunk_size as f32) as usize; - crossbeam_channel::bounded(expected_num_chunks) - }) + let (senders, receivers): (Vec<_>, Vec<_>) = (0..row_ranges.len()) + .map(|_| tokio::sync::mpsc::channel::>(1)) .unzip(); let owned_uri = uri.to_string(); @@ -542,34 +537,39 @@ pub(crate) fn local_parquet_stream( DaftResult::Ok(table_iter) }); - rayon::spawn(move || { + let runtime = get_runtime(true)?; + let _ = runtime.block_on_io_pool(async move { // Once a row group has been read into memory and we have the column iterators, // we can start processing them in parallel. - let par_table_iters = table_iters.zip(senders).par_bridge(); + let par_table_iters = table_iters.zip(senders); // For each vec of column iters, iterate through them in parallel lock step such that each iteration // produces a chunk of the row group that can be converted into a table. 
par_table_iters.for_each(move |(table_iter_result, tx)| { - let table_iter = match table_iter_result { - Ok(t) => t, - Err(e) => { - let _ = tx.send(Err(e)); - return; - } - }; - for table_result in table_iter { - let table_err = table_result.is_err(); - if let Err(crossbeam_channel::TrySendError::Full(_)) = tx.try_send(table_result) { - panic!("Parquet stream channel should not be full") - } - if table_err { - break; + tokio::spawn(async move { + let table_iter = match table_iter_result { + Ok(t) => t, + Err(e) => { + let _ = tx.send(Err(e)).await; + return; + } + }; + for table_result in table_iter { + let table_err = table_result.is_err(); + let _ = tx.send(table_result).await; + if table_err { + break; + } } - } + }); }); }); - let result_stream = futures::stream::iter(receivers.into_iter().map(futures::stream::iter)); + let result_stream = futures::stream::iter( + receivers + .into_iter() + .map(tokio_stream::wrappers::ReceiverStream::new), + ); match maintain_order { true => Ok((metadata, Box::pin(result_stream.flatten()))), diff --git a/src/daft-physical-plan/Cargo.toml b/src/daft-physical-plan/Cargo.toml index 778b8b8560..17d419e63f 100644 --- a/src/daft-physical-plan/Cargo.toml +++ b/src/daft-physical-plan/Cargo.toml @@ -8,6 +8,9 @@ daft-scan = {path = "../daft-scan", default-features = false} log = {workspace = true} strum = {version = "0.26", features = ["derive"]} +[features] +python = ["common-error/python", "common-resource-request/python", "daft-core/python", "daft-dsl/python", "daft-plan/python", "daft-scan/python"] + [lints] workspace = true diff --git a/src/daft-physical-plan/src/local_plan.rs b/src/daft-physical-plan/src/local_plan.rs index 548e6505d8..d56a8fcb96 100644 --- a/src/daft-physical-plan/src/local_plan.rs +++ b/src/daft-physical-plan/src/local_plan.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use common_resource_request::ResourceRequest; use daft_core::prelude::*; use daft_dsl::{AggExpr, ExprRef}; -use daft_plan::InMemoryInfo; +use daft_plan::{InMemoryInfo, OutputFileInfo}; use daft_scan::{ScanTask, ScanTaskRef}; pub type LocalPhysicalPlanRef = Arc; @@ -190,6 +190,22 @@ impl LocalPhysicalPlan { .arced() } + pub(crate) fn physical_write( + input: LocalPhysicalPlanRef, + data_schema: SchemaRef, + file_schema: SchemaRef, + file_info: OutputFileInfo, + ) -> LocalPhysicalPlanRef { + Self::PhysicalWrite(PhysicalWrite { + input, + data_schema, + file_schema, + file_info, + plan_stats: PlanStats {}, + }) + .arced() + } + pub fn schema(&self) -> &SchemaRef { match self { Self::PhysicalScan(PhysicalScan { schema, .. 
}) @@ -293,7 +309,13 @@ pub struct Concat { #[derive(Debug)] -pub struct PhysicalWrite {} +pub struct PhysicalWrite { + pub input: LocalPhysicalPlanRef, + pub data_schema: SchemaRef, + pub file_schema: SchemaRef, + pub file_info: OutputFileInfo, + pub plan_stats: PlanStats, +} #[derive(Debug)] pub struct PlanStats {} diff --git a/src/daft-physical-plan/src/translate.rs b/src/daft-physical-plan/src/translate.rs index 2d6be9a40b..48e85d982d 100644 --- a/src/daft-physical-plan/src/translate.rs +++ b/src/daft-physical-plan/src/translate.rs @@ -105,6 +105,23 @@ pub fn translate(plan: &LogicalPlanRef) -> DaftResult { log::warn!("Repartition Not supported for Local Executor!; This will be a No-Op"); translate(&repartition.input) } + LogicalPlan::Sink(sink) => { + use daft_plan::SinkInfo; + let input = translate(&sink.input)?; + let data_schema = input.schema().clone(); + match sink.sink_info.as_ref() { + SinkInfo::OutputFileInfo(info) => Ok(LocalPhysicalPlan::physical_write( + input, + data_schema, + sink.schema.clone(), + info.clone(), + )), + #[cfg(feature = "python")] + SinkInfo::CatalogInfo(_) => { + todo!("CatalogInfo not yet implemented") + } + } + } _ => todo!("{} not yet implemented", plan.name()), } } diff --git a/tests/cookbook/test_write.py b/tests/cookbook/test_write.py index 46db61d47e..4910e1f52c 100644 --- a/tests/cookbook/test_write.py +++ b/tests/cookbook/test_write.py @@ -7,14 +7,9 @@ from pyarrow import dataset as pads import daft -from daft import context from tests.conftest import assert_df_equals from tests.cookbook.assets import COOKBOOK_DATA_CSV -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) PYARROW_GE_7_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (7, 0, 0) @@ -178,9 +173,8 @@ def test_parquet_write_multifile(tmp_path, smaller_parquet_target_filesize): df = daft.from_pydict(data) df2 = df.write_parquet(tmp_path) assert len(df2) > 1 - ds = pads.dataset(tmp_path, format="parquet") - readback = ds.to_table() - assert readback.to_pydict() == data + read_back = daft.read_parquet(tmp_path.as_posix() + "/*.parquet").sort(by="x").to_pydict() + assert read_back == data @pytest.mark.skipif( diff --git a/tests/dataframe/test_decimals.py b/tests/dataframe/test_decimals.py index daafec29f0..3a2d11babe 100644 --- a/tests/dataframe/test_decimals.py +++ b/tests/dataframe/test_decimals.py @@ -7,12 +7,6 @@ import pytest import daft -from daft import context - -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) PYARROW_GE_7_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (7, 0, 0) diff --git a/tests/dataframe/test_temporals.py b/tests/dataframe/test_temporals.py index 8843028b01..700cdd16da 100644 --- a/tests/dataframe/test_temporals.py +++ b/tests/dataframe/test_temporals.py @@ -11,11 +11,6 @@ import daft from daft import DataType, col, context -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) - PYARROW_GE_7_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (7, 0, 0) @@ -48,7 +43,10 @@ def test_temporal_arithmetic_with_same_type() -> None: @pytest.mark.parametrize("format", ["csv", "parquet"]) -@pytest.mark.parametrize("use_native_downloader", [True, False]) 
+@pytest.mark.parametrize( + "use_native_downloader", + [True, False] if context.get_context().daft_execution_config.enable_native_executor is False else [True], +) def test_temporal_file_roundtrip(format, use_native_downloader) -> None: data = { "date32": pa.array([1], pa.date32()), diff --git a/tests/io/test_parquet.py b/tests/io/test_parquet.py index c30ae8da1a..d158afa7e4 100644 --- a/tests/io/test_parquet.py +++ b/tests/io/test_parquet.py @@ -21,10 +21,6 @@ from ..integration.io.conftest import minio_create_bucket -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) PYARROW_GE_11_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (11, 0, 0) PYARROW_GE_13_0_0 = tuple(int(s) for s in pa.__version__.split(".") if s.isnumeric()) >= (13, 0, 0) @@ -49,7 +45,10 @@ def storage_config_from_use_native_downloader(use_native_downloader: bool) -> St return StorageConfig.python(PythonStorageConfig(None)) -@pytest.mark.parametrize("use_native_downloader", [True, False]) +@pytest.mark.parametrize( + "use_native_downloader", + [True, False] if context.get_context().daft_execution_config.enable_native_executor is False else [True], +) @pytest.mark.parametrize("use_deprecated_int96_timestamps", [True, False]) def test_parquet_read_int96_timestamps(use_deprecated_int96_timestamps, use_native_downloader): data = { @@ -81,7 +80,10 @@ def test_parquet_read_int96_timestamps(use_deprecated_int96_timestamps, use_nati assert df.to_arrow() == expected.to_arrow(), f"Expected:\n{expected}\n\nReceived:\n{df.to_arrow()}" -@pytest.mark.parametrize("use_native_downloader", [True, False]) +@pytest.mark.parametrize( + "use_native_downloader", + [True, False] if context.get_context().daft_execution_config.enable_native_executor is False else [True], +) @pytest.mark.parametrize("coerce_to", [TimeUnit.ms(), TimeUnit.us()]) def test_parquet_read_int96_timestamps_overflow(coerce_to, use_native_downloader): # NOTE: datetime.datetime(3000, 1, 1) and datetime.datetime(1000, 1, 1) cannot be represented by our timestamp64(nanosecond) diff --git a/tests/io/test_s3_credentials_refresh.py b/tests/io/test_s3_credentials_refresh.py index 25c5c0cd8c..16a98fadf0 100644 --- a/tests/io/test_s3_credentials_refresh.py +++ b/tests/io/test_s3_credentials_refresh.py @@ -10,14 +10,8 @@ import pytest import daft -from daft import context from tests.io.mock_aws_server import start_service, stop_process -pytestmark = pytest.mark.skipif( - context.get_context().daft_execution_config.enable_native_executor is True, - reason="Native executor fails for these tests", -) - @pytest.fixture(scope="session") def aws_log_file(tmp_path_factory: pytest.TempPathFactory) -> Iterator[io.IOBase]:
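Reviewer note (illustrative only, not part of the diff): the new RowBasedBuffer replaces OperatorBuffer's try_clear/clear_all with a pop_enough/pop_all pair that returns DaftResult<Option<...>> instead of Option<DaftResult<...>>, so callers propagate errors with ? before branching on readiness. Below is a minimal sketch of the intended call pattern, assuming crate-internal access to crate::buffer::RowBasedBuffer; the function and variable names (rechunk, parts) are hypothetical.

    use std::sync::Arc;

    use common_error::DaftResult;
    use daft_micropartition::MicroPartition;

    use crate::buffer::RowBasedBuffer;

    // Re-chunk incoming micropartitions into morsels of roughly `target_rows` rows each.
    // Hypothetical helper for illustration; not part of this PR.
    fn rechunk(
        parts: Vec<Arc<MicroPartition>>,
        target_rows: usize,
    ) -> DaftResult<Vec<Arc<MicroPartition>>> {
        let mut buffer = RowBasedBuffer::new(target_rows);
        let mut out = Vec::new();
        for part in parts {
            buffer.push(part);
            // Drain every full `target_rows`-sized chunk that is ready.
            if let Some(ready) = buffer.pop_enough()? {
                out.extend(ready);
            }
        }
        // Flush the sub-threshold remainder, if any, as one final chunk.
        if let Some(rest) = buffer.pop_all()? {
            out.push(rest);
        }
        Ok(out)
    }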