diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index be3f6739c3..90d89cf852 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -1512,25 +1512,41 @@ def reduce( Then, the reduce instruction is applied to each `i`th slice across the child lists. """ - materializations = list() + pending_materializations: dict[str, MultiOutputPartitionTask[PartitionT]] = {} + inputs_to_reduce = [] + metadatas = [] stage_id = next(stage_id_counter) # Dispatch all fanouts. for step in fanout_plan: + # Check any completed materializations, collect their partitions and metadatas, and add to reduce inputs + newly_completed = [(i, m) for i, m in pending_materializations.items() if m.done()] + for i, completed in newly_completed: + del pending_materializations[i] + inputs_to_reduce.append(deque(completed.partitions())) + metadatas.append(deque(completed.partition_metadatas())) + + # Process another step in the fanout plan if isinstance(step, PartitionTaskBuilder): - step = step.finalize_partition_task_multi_output(stage_id=stage_id) - materializations.append(step) - yield step + finalized_step = step.finalize_partition_task_multi_output(stage_id=stage_id) + pending_materializations[finalized_step.id()] = finalized_step + yield finalized_step + else: + yield step # All fanouts dispatched. Wait for all of them to materialize # (since we need all of them to emit even a single reduce). - while any(not _.done() for _ in materializations): - logger.debug("reduce blocked on completion of all sources in: %s", materializations) - yield None + while pending_materializations: + newly_completed = [(i, m) for i, m in pending_materializations.items() if m.done()] + for i, completed in newly_completed: + del pending_materializations[i] + inputs_to_reduce.append(deque(completed.partitions())) + metadatas.append(deque(completed.partition_metadatas())) + + if pending_materializations: + logger.debug("reduce blocked on completion of all sources in: %s", pending_materializations) + yield None - inputs_to_reduce = [deque(_.partitions()) for _ in materializations] - metadatas = [deque(_.partition_metadatas()) for _ in materializations] - del materializations if not isinstance(reduce_instructions, list): reduce_instructions = [reduce_instructions] * len(inputs_to_reduce[0]) reduce_instructions_ = deque(reduce_instructions)