diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py
index 6a7f5a2879..2e40a280a4 100644
--- a/daft/execution/physical_plan.py
+++ b/daft/execution/physical_plan.py
@@ -59,6 +59,7 @@
     from daft.daft import FileFormat, IOConfig, JoinType, PyExpr
     from daft.logical.schema import Schema
+    from daft.runners.partitioning import PartialPartitionMetadata
 
 
 # A PhysicalPlan that is still being built - may yield both PartitionTaskBuilders and PartitionTasks.
@@ -1624,8 +1625,71 @@ def streaming_push_exchange_op(
 def fully_materializing_exchange_op(
     child_plan: InProgressPhysicalPlan[PartitionT], partition_by: list[PyExpr], num_partitions: int
 ) -> InProgressPhysicalPlan[PartitionT]:
-    # 1. Materialize everything
-    raise NotImplementedError("TODO: Sammy")
+    from daft.expressions import Expression
+
+    # Step 1: Naively materialize all child partitions; pass all other steps through unchanged.
+    stage_id_children = next(stage_id_counter)
+    materialized_partitions: list[SingleOutputPartitionTask[PartitionT]] = []
+    for step in child_plan:
+        if isinstance(step, PartitionTaskBuilder):
+            task = step.finalize_partition_task_single_output(stage_id=stage_id_children)
+            materialized_partitions.append(task)
+            yield task
+        else:
+            yield step
+
+    # Step 2: Wait for all child partitions to be done, yielding None to signal that
+    # the plan is blocked and has no new work to offer.
+    while any(not p.done() for p in materialized_partitions):
+        yield None
+
+    # Step 3: Emit one fanout map task per materialized child partition, hash-partitioning
+    # it into `num_partitions` outputs.
+    stage_id_map_tasks = next(stage_id_counter)
+    materialized_map_partitions: list[MultiOutputPartitionTask[PartitionT]] = []
+    while materialized_partitions:
+        materialized_child_partition = materialized_partitions.pop(0)
+        map_task = (
+            PartitionTaskBuilder(
+                inputs=[materialized_child_partition.partition()],
+                partial_metadatas=materialized_child_partition.partial_metadatas,
+                resource_request=ResourceRequest(),
+            )
+            .add_instruction(
+                execution_step.FanoutHash(
+                    _num_outputs=num_partitions,
+                    partition_by=ExpressionsProjection([Expression._from_pyexpr(expr) for expr in partition_by]),
+                ),
+                ResourceRequest(),
+            )
+            .finalize_partition_task_multi_output(stage_id=stage_id_map_tasks)
+        )
+        materialized_map_partitions.append(map_task)
+        yield map_task
+
+    # Step 4: Wait for all map tasks to complete.
+    while any(not p.done() for p in materialized_map_partitions):
+        yield None
+
+    # Step 5: Transpose the map outputs so that reduce task i receives the i-th output
+    # of every map task, then emit one ReduceMerge task per output partition.
+    transposed_results: list[list[tuple[PartitionT, PartialPartitionMetadata]]] = [[] for _ in range(num_partitions)]
+    for map_task in materialized_map_partitions:
+        partitions = map_task.partitions()
+        partition_metadatas = map_task.partial_metadatas
+        for i, (partition, meta) in enumerate(zip(partitions, partition_metadatas)):
+            transposed_results[i].append((partition, meta))
+
+    for partitions in transposed_results:
+        reduce_task = PartitionTaskBuilder(
+            inputs=[p for p, _ in partitions],
+            partial_metadatas=[m for _, m in partitions],
+            resource_request=ResourceRequest(),
+        ).add_instruction(
+            instruction=execution_step.ReduceMerge(),
+        )
+        yield reduce_task
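
For readers skimming the diff: the new operator is a classic fully-materializing hash shuffle (materialize inputs, fan out each input by key hash, transpose, merge). Below is a minimal pure-Python sketch of the same map/transpose/reduce flow, using plain lists of dicts as stand-in partitions and a hypothetical `key` column in place of daft's PyExpr partition keys; it is illustrative only and uses no daft APIs.

    def hash_fanout(partition: list[dict], key: str, num_partitions: int) -> list[list[dict]]:
        """Map side (Step 3 analogue): split one input partition into hash buckets by key."""
        buckets: list[list[dict]] = [[] for _ in range(num_partitions)]
        for row in partition:
            buckets[hash(row[key]) % num_partitions].append(row)
        return buckets

    def shuffle(partitions: list[list[dict]], key: str, num_partitions: int) -> list[list[dict]]:
        # Run the fanout over every materialized input partition.
        map_outputs = [hash_fanout(p, key, num_partitions) for p in partitions]
        # Step 5 analogue: "transpose" so reduce task i sees bucket i from every map task;
        # the ReduceMerge analogue here is simply list concatenation.
        return [
            [row for output in map_outputs for row in output[i]]
            for i in range(num_partitions)
        ]

    if __name__ == "__main__":
        parts = [[{"k": 1}, {"k": 2}], [{"k": 2}, {"k": 3}]]
        for i, out in enumerate(shuffle(parts, "k", num_partitions=2)):
            print(i, out)

After the transpose, every row with the same key hash lands in the same output partition, which is the invariant the downstream reduce tasks rely on.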
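Steps 2 and 4 lean on the InProgressPhysicalPlan convention that yielding None means "blocked: no new work, poll in-flight tasks again". A toy sketch of a driver loop under that assumption follows; the Task class and drive function here are hypothetical stand-ins, not daft's scheduler.

    from typing import Iterator, Optional

    class Task:
        """Hypothetical stand-in for a PartitionTask: done() flips true after one poll."""
        def __init__(self) -> None:
            self._polls = 0
        def done(self) -> bool:
            self._polls += 1
            return self._polls > 1

    def plan() -> Iterator[Optional[Task]]:
        tasks = [Task(), Task()]
        yield from tasks                      # hand tasks to the driver (Step 1/3 analogue)
        while any(not t.done() for t in tasks):
            yield None                        # blocked: nothing to emit yet (Step 2/4 analogue)
        yield Task()                          # downstream work once inputs are ready (Step 5 analogue)

    def drive(p: Iterator[Optional[Task]]) -> None:
        for step in p:
            if step is None:
                continue  # a real driver would wait for task completions here
            print("scheduled", step)

    drive(plan())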