diff --git a/src/daft-local-execution/src/intermediate_ops/aggregate.rs b/src/daft-local-execution/src/intermediate_ops/aggregate.rs index 13f93cb818..96d8b8da30 100644 --- a/src/daft-local-execution/src/intermediate_ops/aggregate.rs +++ b/src/daft-local-execution/src/intermediate_ops/aggregate.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use common_error::DaftResult; use daft_dsl::ExprRef; +use daft_table::Table; use tracing::instrument; use super::intermediate_op::{ @@ -31,9 +32,10 @@ impl IntermediateOperator for AggregateOperator { input: &PipelineResultType, _state: Option<&mut Box>, ) -> DaftResult { - let out = input.as_data().agg(&self.agg_exprs, &self.group_by)?; + let input = Table::concat(input.as_data())?; + let out = input.agg(&self.agg_exprs, &self.group_by)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( - out, + vec![out; 1], )))) } diff --git a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs index b4d7792b1e..205a447122 100644 --- a/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/anti_semi_hash_join_probe.rs @@ -1,16 +1,9 @@ use std::sync::Arc; use common_error::DaftResult; -use daft_core::{ - prelude::{ - bitmap::{Bitmap, MutableBitmap}, - BooleanArray, - }, - series::IntoSeries, -}; use daft_dsl::ExprRef; use daft_plan::JoinType; -use daft_table::{Probeable, Table}; +use daft_table::{GrowableTable, Probeable, Table}; use tracing::{info_span, instrument}; use super::intermediate_op::{ @@ -60,31 +53,36 @@ impl AntiSemiProbeOperator { } } - fn probe_anti_semi(&self, input: &Table, state: &mut AntiSemiProbeState) -> DaftResult { + fn probe_anti_semi( + &self, + input: &[Table], + state: &mut AntiSemiProbeState, + ) -> DaftResult
{ let probe_set = state.get_probeable(); let _growables = info_span!("AntiSemiOperator::build_growables").entered(); - let mut input_idx_matches = MutableBitmap::from_len_zeroed(input.len()); + let mut probe_side_growable = + GrowableTable::new(&input.iter().collect::>(), false, 20)?; drop(_growables); { let _loop = info_span!("AntiSemiOperator::eval_and_probe").entered(); - let join_keys = input.eval_expression_list(&self.probe_on)?; - let iter = probe_set.probe_exists(&join_keys)?; - - for (probe_row_idx, matched) in iter.enumerate() { - match (self.join_type == JoinType::Semi, matched) { - (true, true) | (false, false) => { - input_idx_matches.set(probe_row_idx, true); + for (probe_side_table_idx, table) in input.iter().enumerate() { + let join_keys = table.eval_expression_list(&self.probe_on)?; + let iter = probe_set.probe_exists(&join_keys)?; + + for (probe_row_idx, matched) in iter.enumerate() { + match (self.join_type == JoinType::Semi, matched) { + (true, true) | (false, false) => { + probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); + } + _ => {} } - _ => {} } } } - let bitmap: Bitmap = input_idx_matches.into(); - let result = input.mask_filter(&BooleanArray::from(("bitmap", bitmap)).into_series())?; - Ok(result) + probe_side_growable.build() } } @@ -119,7 +117,7 @@ impl IntermediateOperator for AntiSemiProbeOperator { _ => unreachable!("Only Semi and Anti joins are supported"), }?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( - out, + vec![out; 1], )))) } } diff --git a/src/daft-local-execution/src/intermediate_ops/buffer.rs b/src/daft-local-execution/src/intermediate_ops/buffer.rs index 2d5ea04333..18d3e9627a 100644 --- a/src/daft-local-execution/src/intermediate_ops/buffer.rs +++ b/src/daft-local-execution/src/intermediate_ops/buffer.rs @@ -1,10 +1,10 @@ -use std::{cmp::Ordering::*, collections::VecDeque, sync::Arc}; +use std::{cmp::Ordering::*, collections::VecDeque}; use common_error::DaftResult; use daft_table::Table; 
pub struct OperatorBuffer { - pub buffer: VecDeque<Arc<Table>>, + pub buffer: VecDeque<Table>, pub curr_len: usize, pub threshold: usize, } @@ -19,12 +19,14 @@ impl OperatorBuffer { } } - pub fn push(&mut self, part: Arc<Table>
) { - self.curr_len += part.len(); - self.buffer.push_back(part); + pub fn push(&mut self, parts: &[Table]) { + for part in parts { + self.buffer.push_back(part.clone()); + self.curr_len += part.len(); + } } - pub fn try_clear(&mut self) -> Option>> { + pub fn try_clear(&mut self) -> Option>> { match self.curr_len.cmp(&self.threshold) { Less => None, Equal => self.clear_all(), @@ -32,9 +34,8 @@ impl OperatorBuffer { } } - fn clear_enough(&mut self) -> DaftResult> { + fn clear_enough(&mut self) -> DaftResult> { assert!(self.curr_len > self.threshold); - let mut to_concat = Vec::with_capacity(self.buffer.len()); let mut remaining = self.threshold; @@ -47,28 +48,28 @@ impl OperatorBuffer { } else { let (head, tail) = part.split_at(remaining)?; remaining = 0; - to_concat.push(Arc::new(head)); - self.buffer.push_front(Arc::new(tail)); + to_concat.push(head); + self.buffer.push_front(tail); break; } } assert_eq!(remaining, 0); self.curr_len -= self.threshold; - match to_concat.len() { - 1 => Ok(to_concat.pop().unwrap()), - _ => Ok(Arc::new(Table::concat(&to_concat)?)), - } + Ok(to_concat) } - pub fn clear_all(&mut self) -> Option>> { + pub fn clear_all(&mut self) -> Option>> { if self.buffer.is_empty() { return None; } - let concated = Table::concat(&std::mem::take(&mut self.buffer).iter().collect::>()) - .map(Arc::new); self.curr_len = 0; - Some(concated) + Some( + std::mem::take(&mut self.buffer) + .into_iter() + .map(Ok) + .collect(), + ) } } diff --git a/src/daft-local-execution/src/intermediate_ops/filter.rs b/src/daft-local-execution/src/intermediate_ops/filter.rs index da8dbcd19c..f506af9d2b 100644 --- a/src/daft-local-execution/src/intermediate_ops/filter.rs +++ b/src/daft-local-execution/src/intermediate_ops/filter.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use common_error::DaftResult; use daft_dsl::ExprRef; +use daft_table::Table; use tracing::instrument; use super::intermediate_op::{ @@ -27,7 +28,11 @@ impl IntermediateOperator for FilterOperator { input: 
&PipelineResultType, _state: Option<&mut Box>, ) -> DaftResult { - let out = input.as_data().filter(&[self.predicate.clone()])?; + let out = input + .as_data() + .iter() + .map(|t| t.filter(&[self.predicate.clone()])) + .collect::>>()?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( out, )))) diff --git a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs index 3be41bcfe9..a8022a657c 100644 --- a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs @@ -1,10 +1,7 @@ use std::sync::Arc; use common_error::DaftResult; -use daft_core::{ - prelude::{SchemaRef, UInt64Array}, - series::IntoSeries, -}; +use daft_core::prelude::SchemaRef; use daft_dsl::ExprRef; use daft_plan::JoinType; use daft_table::{GrowableTable, Probeable, Table}; @@ -97,7 +94,7 @@ impl HashJoinProbeOperator { } } - fn probe_inner(&self, input: &Table, state: &mut HashJoinProbeState) -> DaftResult
<Table> { + fn probe_inner(&self, input: &[Table], state: &mut HashJoinProbeState) -> DaftResult<Table>
{ let (probe_table, tables) = state.get_probeable_and_table(); let _growables = info_span!("HashJoinOperator::build_growables").entered(); @@ -105,30 +102,34 @@ impl HashJoinProbeOperator { let mut build_side_growable = GrowableTable::new(&tables.iter().collect::>(), false, 20)?; - let mut input_idx_matches = vec![]; + let mut probe_side_growable = + GrowableTable::new(&input.iter().collect::>(), false, 20)?; drop(_growables); { let _loop = info_span!("HashJoinOperator::eval_and_probe").entered(); - let join_keys = input.eval_expression_list(&self.probe_on)?; - let idx_mapper = probe_table.probe_indices(&join_keys)?; - - for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { - if let Some(inner_iter) = inner_iter { - for (build_side_table_idx, build_row_idx) in inner_iter { - build_side_growable.extend( - build_side_table_idx as usize, - build_row_idx as usize, - 1, - ); - input_idx_matches.push(probe_row_idx as u64); + for (probe_side_table_idx, table) in input.iter().enumerate() { + // we should emit one table at a time when this is streaming + let join_keys = table.eval_expression_list(&self.probe_on)?; + let idx_mapper = probe_table.probe_indices(&join_keys)?; + + for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { + if let Some(inner_iter) = inner_iter { + for (build_side_table_idx, build_row_idx) in inner_iter { + build_side_growable.extend( + build_side_table_idx as usize, + build_row_idx as usize, + 1, + ); + // we can perform run length compression for this to make this more efficient + probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); + } } } } } let build_side_table = build_side_growable.build()?; - let probe_side_table = - input.take(&UInt64Array::from(("matches", input_idx_matches)).into_series())?; + let probe_side_table = probe_side_growable.build()?; let (left_table, right_table) = if self.build_on_left { (build_side_table, probe_side_table) @@ -139,14 +140,16 @@ impl HashJoinProbeOperator { let 
join_keys_table = left_table.get_columns(&self.common_join_keys)?; let left_non_join_columns = left_table.get_columns(&self.left_non_join_columns)?; let right_non_join_columns = right_table.get_columns(&self.right_non_join_columns)?; - let final_table = join_keys_table + join_keys_table .union(&left_non_join_columns)? - .union(&right_non_join_columns)?; - - Ok(final_table) + .union(&right_non_join_columns) } - fn probe_left_right(&self, input: &Table, state: &mut HashJoinProbeState) -> DaftResult
<Table> { + fn probe_left_right( + &self, + input: &[Table], + state: &mut HashJoinProbeState, + ) -> DaftResult<Table>
{ let (probe_table, tables) = state.get_probeable_and_table(); let _growables = info_span!("HashJoinOperator::build_growables").entered(); @@ -157,47 +160,48 @@ impl HashJoinProbeOperator { tables.iter().map(|t| t.len()).sum(), )?; - let mut input_idx_matches = Vec::with_capacity(input.len()); + let mut probe_side_growable = + GrowableTable::new(&input.iter().collect::>(), false, input.len())?; drop(_growables); { let _loop = info_span!("HashJoinOperator::eval_and_probe").entered(); - let join_keys = input.eval_expression_list(&self.probe_on)?; - let idx_mapper = probe_table.probe_indices(&join_keys)?; - - for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { - if let Some(inner_iter) = inner_iter { - for (build_side_table_idx, build_row_idx) in inner_iter { - build_side_growable.extend( - build_side_table_idx as usize, - build_row_idx as usize, - 1, - ); - input_idx_matches.push(probe_row_idx as u64); + for (probe_side_table_idx, table) in input.iter().enumerate() { + let join_keys = table.eval_expression_list(&self.probe_on)?; + let idx_mapper = probe_table.probe_indices(&join_keys)?; + + for (probe_row_idx, inner_iter) in idx_mapper.make_iter().enumerate() { + if let Some(inner_iter) = inner_iter { + for (build_side_table_idx, build_row_idx) in inner_iter { + build_side_growable.extend( + build_side_table_idx as usize, + build_row_idx as usize, + 1, + ); + probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); + } + } else { + // if there's no match, we should still emit the probe side and fill the build side with nulls + build_side_growable.add_nulls(1); + probe_side_growable.extend(probe_side_table_idx, probe_row_idx, 1); } - } else { - // if there's no match, we should still emit the probe side and fill the build side with nulls - build_side_growable.add_nulls(1); - input_idx_matches.push(probe_row_idx as u64); } } } let build_side_table = build_side_growable.build()?; - let probe_side_table = - 
input.take(&UInt64Array::from(("matches", input_idx_matches)).into_series())?; + let probe_side_table = probe_side_growable.build()?; - let final_table = if self.join_type == JoinType::Left { + if self.join_type == JoinType::Left { let join_table = probe_side_table.get_columns(&self.common_join_keys)?; let left = probe_side_table.get_columns(&self.left_non_join_columns)?; let right = build_side_table.get_columns(&self.right_non_join_columns)?; - join_table.union(&left)?.union(&right)? + join_table.union(&left)?.union(&right) } else { let join_table = probe_side_table.get_columns(&self.common_join_keys)?; let left = build_side_table.get_columns(&self.left_non_join_columns)?; let right = probe_side_table.get_columns(&self.right_non_join_columns)?; - join_table.union(&left)?.union(&right)? - }; - Ok(final_table) + join_table.union(&left)?.union(&right) + } } } @@ -235,7 +239,7 @@ impl IntermediateOperator for HashJoinProbeOperator { } }?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( - out, + vec![out; 1], )))) } } diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index a54424a6ec..328c1913bb 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -18,9 +18,9 @@ pub trait IntermediateOperatorState: Send + Sync { } pub enum IntermediateOperatorResult { - NeedMoreInput(Option>), + NeedMoreInput(Option>>), #[allow(dead_code)] - HasMoreOutput(Arc
), + HasMoreOutput(Arc>), } pub trait IntermediateOperator: Send + Sync { @@ -142,19 +142,19 @@ impl IntermediateNode { let _ = worker_sender.send((idx, morsel.clone())).await; } } else { - buffer.push(morsel.as_data().clone()); + buffer.push(morsel.as_data()); if let Some(ready) = buffer.try_clear() { - let _ = send_to_next_worker(idx, ready?.into()).await; + let _ = send_to_next_worker(idx, Arc::new(ready?).into()).await; } } } // Buffer may still have some morsels left above the threshold while let Some(ready) = buffer.try_clear() { - let _ = send_to_next_worker(idx, ready?.into()).await; + let _ = send_to_next_worker(idx, Arc::new(ready?).into()).await; } // Clear all remaining morsels if let Some(last_morsel) = buffer.clear_all() { - let _ = send_to_next_worker(idx, last_morsel?.into()).await; + let _ = send_to_next_worker(idx, Arc::new(last_morsel?).into()).await; } } Ok(()) diff --git a/src/daft-local-execution/src/intermediate_ops/project.rs b/src/daft-local-execution/src/intermediate_ops/project.rs index abd37b461f..52fd0257b6 100644 --- a/src/daft-local-execution/src/intermediate_ops/project.rs +++ b/src/daft-local-execution/src/intermediate_ops/project.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use common_error::DaftResult; use daft_dsl::ExprRef; +use daft_table::Table; use tracing::instrument; use super::intermediate_op::{ @@ -27,7 +28,11 @@ impl IntermediateOperator for ProjectOperator { input: &PipelineResultType, _state: Option<&mut Box>, ) -> DaftResult { - let out = input.as_data().eval_expression_list(&self.projection)?; + let out = input + .as_data() + .iter() + .map(|t| t.eval_expression_list(&self.projection)) + .collect::>>()?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( out, )))) diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index c52615671b..7d39cae882 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -36,12 +36,12 @@ use 
crate::{ #[derive(Clone)] pub enum PipelineResultType { - Data(Arc<Table>), ProbeTable(Arc<dyn Probeable>, Arc<Vec<Table>>), } -impl From<Arc<Table>> for PipelineResultType { - fn from(data: Arc<Table>) -> Self { Self::Data(data) } +impl From<Arc<Vec<Table>>> for PipelineResultType { + fn from(data: Arc<Vec<Table>>) -> Self { Self::Data(data) } } @@ -53,7 +53,7 @@ impl From<(Arc<dyn Probeable>, Arc<Vec<Table>>)> for PipelineResultType { } impl PipelineResultType { - pub fn as_data(&self) -> &Arc<Table>
{ + pub fn as_data(&self) -> &Arc> { match self { Self::Data(data) => data, _ => panic!("Expected data"), diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs index 4f296f70df..0c3f9f38c7 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -169,7 +169,7 @@ pub fn run_local( }); struct ReceiverIterator { - receiver: Receiver>, + receiver: Receiver>>, handle: Option>>, } @@ -178,12 +178,14 @@ pub fn run_local( fn next(&mut self) -> Option { match self.receiver.blocking_recv() { - Some(part) => Some(Ok(Arc::new(MicroPartition::new_loaded( - part.schema.clone(), - vec![part.as_ref().clone()].into(), - None, - )))), - None => { + Some(part) if !part.is_empty() && !part.first().unwrap().is_empty() => { + Some(Ok(Arc::new(MicroPartition::new_loaded( + part.first().unwrap().schema.clone(), + part, + None, + )))) + } + _ => { if self.handle.is_some() { let join_result = self .handle diff --git a/src/daft-local-execution/src/sinks/aggregate.rs b/src/daft-local-execution/src/sinks/aggregate.rs index 31cc020f16..ea7c2ac153 100644 --- a/src/daft-local-execution/src/sinks/aggregate.rs +++ b/src/daft-local-execution/src/sinks/aggregate.rs @@ -9,7 +9,7 @@ use super::blocking_sink::{BlockingSink, BlockingSinkStatus}; use crate::pipeline::PipelineResultType; enum AggregateState { - Accumulating(Vec>), + Accumulating(Vec
<Table>), #[allow(dead_code)] Done(Table), } @@ -36,9 +36,11 @@ impl AggregateSink { impl BlockingSink for AggregateSink { #[instrument(skip_all, name = "AggregateSink::sink")] - fn sink(&mut self, input: &Arc<Table>
) -> DaftResult { + fn sink(&mut self, input: &[Table]) -> DaftResult { if let AggregateState::Accumulating(parts) = &mut self.state { - parts.push(input.clone()); + for t in input { + parts.push(t.clone()); + } Ok(BlockingSinkStatus::NeedMoreInput) } else { panic!("AggregateSink should be in Accumulating state"); @@ -55,7 +57,7 @@ impl BlockingSink for AggregateSink { let concated = Table::concat(parts)?; let agged = concated.agg(&self.agg_exprs, &self.group_by)?; self.state = AggregateState::Done(agged.clone()); - Ok(Some(Arc::new(agged).into())) + Ok(Some(Arc::new(vec![agged; 1]).into())) } else { panic!("AggregateSink should be in Accumulating state"); } diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index d01ee9e101..01dd8cda8b 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -18,7 +18,7 @@ pub enum BlockingSinkStatus { } pub trait BlockingSink: Send + Sync { - fn sink(&mut self, input: &Arc
<Table>) -> DaftResult<BlockingSinkStatus>; + fn sink(&mut self, input: &[Table]) -> DaftResult<BlockingSinkStatus>; fn finalize(&mut self) -> DaftResult<Option<PipelineResultType>>; fn name(&self) -> &'static str; } diff --git a/src/daft-local-execution/src/sinks/hash_join_build.rs index 1c38a8f317..9a6d0d4c1f 100644 --- a/src/daft-local-execution/src/sinks/hash_join_build.rs +++ b/src/daft-local-execution/src/sinks/hash_join_build.rs @@ -97,8 +97,10 @@ impl BlockingSink for HashJoinBuildSink { "HashJoinBuildSink" } - fn sink(&mut self, input: &Arc<Table>
) -> DaftResult { - self.probe_table_state.add_tables(input)?; + fn sink(&mut self, input: &[Table]) -> DaftResult { + for t in input { + self.probe_table_state.add_tables(t)?; + } Ok(BlockingSinkStatus::NeedMoreInput) } diff --git a/src/daft-local-execution/src/sinks/limit.rs b/src/daft-local-execution/src/sinks/limit.rs index 3c642d9fc3..f7e57a24cc 100644 --- a/src/daft-local-execution/src/sinks/limit.rs +++ b/src/daft-local-execution/src/sinks/limit.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use common_error::DaftResult; use daft_table::Table; use tracing::instrument; @@ -26,27 +24,33 @@ impl LimitSink { impl StreamingSink for LimitSink { #[instrument(skip_all, name = "LimitSink::sink")] - fn execute(&mut self, index: usize, input: &Arc
) -> DaftResult { + fn execute(&mut self, index: usize, input: &[Table]) -> DaftResult { assert_eq!(index, 0); - - let input_num_rows = input.len(); - use std::cmp::Ordering::*; - match input_num_rows.cmp(&self.remaining) { - Less => { - self.remaining -= input_num_rows; - Ok(StreamSinkOutput::NeedMoreInput(Some(input.clone()))) - } - Equal => { - self.remaining = 0; - Ok(StreamSinkOutput::Finished(Some(input.clone()))) - } - Greater => { - let taken = input.head(self.remaining)?; - self.remaining -= taken.len(); - Ok(StreamSinkOutput::Finished(Some(taken.into()))) + + let mut result = vec![]; + for t in input { + let input_num_rows = t.len(); + + match input_num_rows.cmp(&self.remaining) { + Less => { + result.push(t.clone()); + self.remaining -= input_num_rows; + } + Equal => { + self.remaining = 0; + result.push(t.clone()); + return Ok(StreamSinkOutput::Finished(Some(result.into()))); + } + Greater => { + let taken = t.head(self.remaining)?; + self.remaining -= taken.len(); + result.push(taken.clone()); + return Ok(StreamSinkOutput::Finished(Some(result.into()))); + } } } + Ok(StreamSinkOutput::NeedMoreInput(Some(result.into()))) } fn name(&self) -> &'static str { diff --git a/src/daft-local-execution/src/sinks/sort.rs b/src/daft-local-execution/src/sinks/sort.rs index 4247b1676e..cb8682e71c 100644 --- a/src/daft-local-execution/src/sinks/sort.rs +++ b/src/daft-local-execution/src/sinks/sort.rs @@ -14,7 +14,7 @@ pub struct SortSink { } enum SortState { - Building(Vec>), + Building(Vec
<Table>), #[allow(dead_code)] Done(Table), } @@ -34,9 +34,11 @@ impl SortSink { impl BlockingSink for SortSink { #[instrument(skip_all, name = "SortSink::sink")] - fn sink(&mut self, input: &Arc<Table>
) -> DaftResult { + fn sink(&mut self, input: &[Table]) -> DaftResult { if let SortState::Building(parts) = &mut self.state { - parts.push(input.clone()); + for t in input { + parts.push(t.clone()); + } } else { panic!("SortSink should be in Building state"); } @@ -53,7 +55,7 @@ impl BlockingSink for SortSink { let concated = Table::concat(parts)?; let sorted = concated.sort(&self.sort_by, &self.descending)?; self.state = SortState::Done(sorted.clone()); - Ok(Some(Arc::new(sorted).into())) + Ok(Some(Arc::new(vec![sorted; 1]).into())) } else { panic!("SortSink should be in Building state"); } diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index 529842c7c5..b330f9f7a9 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -11,14 +11,14 @@ use crate::{ }; pub enum StreamSinkOutput { - NeedMoreInput(Option>), + NeedMoreInput(Option>>), #[allow(dead_code)] - HasMoreOutput(Arc
<Table>), - Finished(Option<Arc<Table>>), + HasMoreOutput(Arc<Vec<Table>>), + Finished(Option<Arc<Vec<Table>>>), } pub trait StreamingSink: Send + Sync { - fn execute(&mut self, index: usize, input: &Arc<Table>
) -> DaftResult; + fn execute(&mut self, index: usize, input: &[Table]) -> DaftResult; #[allow(dead_code)] fn name(&self) -> &'static str; } diff --git a/src/daft-local-execution/src/sources/in_memory.rs b/src/daft-local-execution/src/sources/in_memory.rs index ca8cbc0536..04a4b7f0e5 100644 --- a/src/daft-local-execution/src/sources/in_memory.rs +++ b/src/daft-local-execution/src/sources/in_memory.rs @@ -33,7 +33,7 @@ impl Source for InMemorySource { _io_stats: IOStatsRef, ) -> crate::Result> { if self.data.is_empty() { - let empty = Table::empty(Some(self.schema.clone())); + let empty = vec![Table::empty(Some(self.schema.clone())); 1]; return Ok(Box::pin(futures::stream::once(async { Ok(Arc::new(empty)) }))); @@ -41,9 +41,7 @@ impl Source for InMemorySource { let data = self.data.clone(); let stream = try_stream! { for mp in data { - for table in mp.get_tables()?.iter() { - yield Arc::new(table.clone()); - } + yield mp.get_tables()? } }; Ok(Box::pin(stream)) diff --git a/src/daft-local-execution/src/sources/scan_task.rs b/src/daft-local-execution/src/sources/scan_task.rs index 8c31b314d1..4c51df7d6a 100644 --- a/src/daft-local-execution/src/sources/scan_task.rs +++ b/src/daft-local-execution/src/sources/scan_task.rs @@ -34,7 +34,7 @@ impl ScanTaskSource { )] async fn process_scan_task_stream( scan_task: Arc, - sender: Sender>>, + sender: Sender>>>, maintain_order: bool, io_stats: IOStatsRef, ) -> DaftResult<()> { @@ -42,12 +42,12 @@ impl ScanTaskSource { let mut stream = stream_scan_task(scan_task, Some(io_stats), maintain_order).await?; let mut has_data = false; while let Some(partition) = stream.next().await { - let _ = sender.send(Ok(Arc::new(partition?))).await; + let _ = sender.send(Ok(Arc::new(vec![partition?]))).await; has_data = true; } if !has_data { let empty = Table::empty(Some(schema.clone())); - let _ = sender.send(Ok(Arc::new(empty))).await; + let _ = sender.send(Ok(Arc::new(vec![empty]))).await; } Ok(()) } diff --git 
a/src/daft-local-execution/src/sources/source.rs b/src/daft-local-execution/src/sources/source.rs index 1865f7284d..75ffdd9894 100644 --- a/src/daft-local-execution/src/sources/source.rs +++ b/src/daft-local-execution/src/sources/source.rs @@ -11,7 +11,7 @@ use crate::{ ExecutionRuntimeHandle, }; -pub type SourceStream<'a> = BoxStream<'a, DaftResult<Arc<Table>>>; +pub type SourceStream<'a> = BoxStream<'a, DaftResult<Arc<Vec<Table>>>>; pub(crate) trait Source: Send + Sync { fn name(&self) -> &'static str;