diff --git a/src/daft-local-execution/src/intermediate_ops/aggregate.rs b/src/daft-local-execution/src/intermediate_ops/aggregate.rs index c83bc3f540..39750d1401 100644 --- a/src/daft-local-execution/src/intermediate_ops/aggregate.rs +++ b/src/daft-local-execution/src/intermediate_ops/aggregate.rs @@ -29,10 +29,10 @@ impl IntermediateOperator for AggregateOperator { fn execute( &self, _idx: usize, - input: PipelineResultType, + input: &PipelineResultType, _state: Option<&mut Box>, ) -> DaftResult { - let out = input.data().agg(&self.agg_exprs, &self.group_by)?; + let out = input.as_data().agg(&self.agg_exprs, &self.group_by)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( out, )))) diff --git a/src/daft-local-execution/src/intermediate_ops/filter.rs b/src/daft-local-execution/src/intermediate_ops/filter.rs index ef85379ec6..eeb02c4aff 100644 --- a/src/daft-local-execution/src/intermediate_ops/filter.rs +++ b/src/daft-local-execution/src/intermediate_ops/filter.rs @@ -25,10 +25,10 @@ impl IntermediateOperator for FilterOperator { fn execute( &self, _idx: usize, - input: PipelineResultType, + input: &PipelineResultType, _state: Option<&mut Box>, ) -> DaftResult { - let out = input.data().filter(&[self.predicate.clone()])?; + let out = input.as_data().filter(&[self.predicate.clone()])?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( out, )))) diff --git a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs index 7b1bb04d22..1541cb2341 100644 --- a/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs +++ b/src/daft-local-execution/src/intermediate_ops/hash_join_probe.rs @@ -19,9 +19,9 @@ enum HashJoinProbeState { } impl HashJoinProbeState { - fn set_table(&mut self, table: Arc, tables: Arc>) { + fn set_table(&mut self, table: &Arc, tables: &Arc>) { if let HashJoinProbeState::Building = self { - *self = HashJoinProbeState::ReadyToProbe(table, tables); + *self = HashJoinProbeState::ReadyToProbe(table.clone(), tables.clone()); } else { panic!("HashJoinProbeState should only be in Building state when setting table") } @@ -29,7 +29,7 @@ impl HashJoinProbeState { fn probe( &self, - input: Arc, + input: &Arc, right_on: &[ExprRef], pruned_right_side_columns: &[String], ) -> DaftResult> { @@ -109,7 +109,7 @@ impl IntermediateOperator for HashJoinProbeOperator { fn execute( &self, idx: usize, - input: PipelineResultType, + input: &PipelineResultType, state: Option<&mut Box>, ) -> DaftResult { println!("HashJoinProbeOperator::execute: idx: {}", idx); @@ -120,7 +120,7 @@ impl IntermediateOperator for HashJoinProbeOperator { .as_any_mut() .downcast_mut::() .expect("HashJoinProbeOperator state should be HashJoinProbeState"); - let (probe_table, tables) = input.probe_table(); + let (probe_table, tables) = input.as_probe_table(); state.set_table(probe_table, tables); Ok(IntermediateOperatorResult::NeedMoreInput(None)) } @@ -130,7 +130,7 @@ impl IntermediateOperator for HashJoinProbeOperator { .as_any_mut() .downcast_mut::() .expect("HashJoinProbeOperator state should be HashJoinProbeState"); - let input = input.data(); + let input = input.as_data(); let out = state.probe(input, &self.right_on, &self.pruned_right_side_columns)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(out))) } diff --git a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs index 6c4253fe98..9abf35c6f6 100644 --- a/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs +++ b/src/daft-local-execution/src/intermediate_ops/intermediate_op.rs @@ -30,7 +30,7 @@ pub trait IntermediateOperator: Send + Sync { fn execute( &self, idx: usize, - input: PipelineResultType, + input: &PipelineResultType, state: Option<&mut Box>, ) -> DaftResult; fn name(&self) -> &'static str; @@ -80,14 +80,20 @@ impl IntermediateNode { let span = info_span!("IntermediateOp::execute"); let mut state = op.make_state(); while let Some((idx, morsel)) = receiver.recv().await { - let result = rt_context.in_span(&span, || op.execute(idx, morsel, state.as_mut()))?; - match result { - IntermediateOperatorResult::NeedMoreInput(Some(mp)) => { - let _ = sender.send(mp.into()).await; - } - IntermediateOperatorResult::NeedMoreInput(None) => {} - IntermediateOperatorResult::HasMoreOutput(mp) => { - let _ = sender.send(mp.into()).await; + loop { + let result = + rt_context.in_span(&span, || op.execute(idx, &morsel, state.as_mut()))?; + match result { + IntermediateOperatorResult::NeedMoreInput(Some(mp)) => { + let _ = sender.send(mp.into()).await; + break; + } + IntermediateOperatorResult::NeedMoreInput(None) => { + break; + } + IntermediateOperatorResult::HasMoreOutput(mp) => { + let _ = sender.send(mp.into()).await; + } } } } @@ -140,7 +146,7 @@ impl IntermediateNode { let _ = worker_sender.send((idx, morsel.clone())).await; } } else { - buffer.push(morsel.data().clone()); + buffer.push(morsel.as_data().clone()); if let Some(ready) = buffer.try_clear() { let _ = send_to_next_worker(idx, ready?.into()).await; } diff --git a/src/daft-local-execution/src/intermediate_ops/project.rs b/src/daft-local-execution/src/intermediate_ops/project.rs index bc9f7f4eea..6f4b57ba00 100644 --- a/src/daft-local-execution/src/intermediate_ops/project.rs +++ b/src/daft-local-execution/src/intermediate_ops/project.rs @@ -25,10 +25,10 @@ impl IntermediateOperator for ProjectOperator { fn execute( &self, _idx: usize, - input: PipelineResultType, + input: &PipelineResultType, _state: Option<&mut Box>, ) -> DaftResult { - let out = input.data().eval_expression_list(&self.projection)?; + let out = input.as_data().eval_expression_list(&self.projection)?; Ok(IntermediateOperatorResult::NeedMoreInput(Some(Arc::new( out, )))) diff --git a/src/daft-local-execution/src/pipeline.rs b/src/daft-local-execution/src/pipeline.rs index ad9668637f..7b05472ed7 100644 --- a/src/daft-local-execution/src/pipeline.rs +++ b/src/daft-local-execution/src/pipeline.rs @@ -53,14 +53,14 @@ impl From<(Arc, Arc>)> for PipelineResultType { } impl PipelineResultType { - pub fn data(self) -> Arc { + pub fn as_data(&self) -> &Arc { match self { PipelineResultType::Data(data) => data, _ => panic!("Expected data"), } } - pub fn probe_table(self) -> (Arc, Arc>) { + pub fn as_probe_table(&self) -> (&Arc, &Arc>) { match self { PipelineResultType::ProbeTable(probe_table, tables) => (probe_table, tables), _ => panic!("Expected probe table"), diff --git a/src/daft-local-execution/src/run.rs b/src/daft-local-execution/src/run.rs index 04726f86e5..6d9c53e01d 100644 --- a/src/daft-local-execution/src/run.rs +++ b/src/daft-local-execution/src/run.rs @@ -141,7 +141,7 @@ pub fn run_local( .await? .get_receiver(); while let Some(val) = receiver.recv().await { - let _ = tx.send(val.data()).await; + let _ = tx.send(val.as_data().clone()).await; } while let Some(result) = runtime_handle.join_next().await { diff --git a/src/daft-local-execution/src/sinks/blocking_sink.rs b/src/daft-local-execution/src/sinks/blocking_sink.rs index 63ea4c2cee..87ba335215 100644 --- a/src/daft-local-execution/src/sinks/blocking_sink.rs +++ b/src/daft-local-execution/src/sinks/blocking_sink.rs @@ -100,7 +100,7 @@ impl PipelineNode for BlockingSinkNode { let mut guard = op.lock().await; while let Some(val) = child_results_receiver.recv().await { if let BlockingSinkStatus::Finished = - rt_context.in_span(&span, || guard.sink(&val.data()))? + rt_context.in_span(&span, || guard.sink(val.as_data()))? { break; } diff --git a/src/daft-local-execution/src/sinks/streaming_sink.rs b/src/daft-local-execution/src/sinks/streaming_sink.rs index afe4be0d3b..dcf1b9db8d 100644 --- a/src/daft-local-execution/src/sinks/streaming_sink.rs +++ b/src/daft-local-execution/src/sinks/streaming_sink.rs @@ -108,9 +108,9 @@ impl PipelineNode for StreamingSinkNode { let mut sink = op.lock().await; let mut is_active = true; while is_active && let Some(val) = child_results_receiver.recv().await { - let val = val.data(); + let val = val.as_data(); loop { - let result = runtime_stats.in_span(&span, || sink.execute(0, &val))?; + let result = runtime_stats.in_span(&span, || sink.execute(0, val))?; match result { StreamSinkOutput::HasMoreOutput(mp) => { sender.send(mp.into()).await.unwrap();