diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py
index f0dfc5259c..898a1f2207 100644
--- a/daft/runners/ray_runner.py
+++ b/daft/runners/ray_runner.py
@@ -1502,8 +1502,8 @@ def push_based_shuffle_service_context(
     num_cpus = int(ray.cluster_resources()["CPU"])
 
     # Number of mappers is ~2x number of mergers
-    num_map_tasks = num_cpus // 3
-    num_merge_tasks = num_map_tasks * 2
+    num_merge_tasks = num_cpus // 3
+    num_map_tasks = num_merge_tasks * 2 + (num_cpus % 3)
 
     yield RayPushBasedShuffle(num_map_tasks, num_merge_tasks, num_partitions, partition_by)
 
@@ -1570,7 +1570,7 @@ def _get_reducer_inputs_location(self, reducer_idx: int) -> tuple[int, int]:
             if num_reducers > reducer_idx:
                 return merger_idx, reducer_idx
             else:
-                reducer_idx - num_reducers
+                reducer_idx -= num_reducers
         raise ValueError(f"Cannot find merger for reducer_idx: {reducer_idx}")
 
     def _merger_options(self, merger_idx: int) -> dict[str, Any]:
@@ -1613,25 +1613,24 @@ def run(self, materialized_inputs: list[ray.ObjectRef]) -> list[ray.ObjectRef]:
         Each Mapper then should run partitioning on the data into `N` chunks.
         """
         # [N_ROUNDS, N_MERGERS, N_REDUCERS_PER_MERGER] list of outputs
-        merge_results: list[list[list[ray.ObjectRef]]] = []
+        total_merge_results: list[list[list[ray.ObjectRef]]] = []
         map_results_buffer: list[ray.ObjectRef] = []
 
         # Keep running the pipeline while there is still work to do
        while materialized_inputs or map_results_buffer:
             # Drain the map_results_buffer, running merge tasks
             per_round_merge_results = []
-            while map_results_buffer:
-                map_results = map_results_buffer.pop()
-                assert len(map_results) == self._num_mergers
-                for merger_idx, merger_input in enumerate(map_results):
+            if map_results_buffer:
+                for merger_idx in range(self._num_mergers):
+                    merger_input = [mapper_results[merger_idx] for mapper_results in map_results_buffer]
                     merge_results = merge_fn.options(
                         **self._merger_options(merger_idx), num_returns=self._num_reducers_for_merger(merger_idx)
                     ).remote(*merger_input)
                     per_round_merge_results.append(merge_results)
-            if per_round_merge_results:
-                merge_results.append(per_round_merge_results)
+                total_merge_results.append(per_round_merge_results)
 
             # Run map tasks:
+            map_results_buffer = []
             for i in range(self._num_mappers):
                 if len(materialized_inputs) == 0:
                     break
@@ -1643,8 +1642,10 @@ def run(self, materialized_inputs: list[ray.ObjectRef]) -> list[ray.ObjectRef]:
                 map_results_buffer.append(map_results)
 
             # Wait for all tasks in this wave to complete
-            ray.wait(per_round_merge_results)
-            ray.wait(map_results_buffer)
+            for results in per_round_merge_results:
+                ray.wait(results)
+            for results in map_results_buffer:
+                ray.wait(results)
 
         # INVARIANT: At this point, the map/merge step is done
         # Start running all the reduce functions
@@ -1652,7 +1653,9 @@ def run(self, materialized_inputs: list[ray.ObjectRef]) -> list[ray.ObjectRef]:
         reduce_results = []
         for reducer_idx in range(self._num_reducers):
             assigned_merger_idx, offset = self._get_reducer_inputs_location(reducer_idx)
-            reducer_inputs = [merge_results[round][assigned_merger_idx][offset] for round in range(len(merge_results))]
+            reducer_inputs = [
+                total_merge_results[round][assigned_merger_idx][offset] for round in range(len(total_merge_results))
+            ]
             res = reduce_fn.options(**self._reduce_options(reducer_idx)).remote(*reducer_inputs)
             reduce_results.append(res)
         return reduce_results
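
For intuition, a minimal standalone sketch (hypothetical, not part of the patch) of the arithmetic the first hunk introduces: mergers now take roughly a third of the cluster's CPUs and mappers take the remaining two thirds plus any remainder, so the "~2x" comment actually matches the code:

# Hypothetical sketch of the corrected CPU split (assumes an 8-CPU cluster).
num_cpus = 8
num_merge_tasks = num_cpus // 3                       # 2 mergers
num_map_tasks = num_merge_tasks * 2 + (num_cpus % 3)  # 4 + 2 leftover = 6 mappers
assert num_map_tasks + num_merge_tasks == num_cpus    # every CPU is accounted for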
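
Likewise, a sketch of the lookup that the corrected `_get_reducer_inputs_location` performs (the `find_merger` helper and the reducer counts below are made up for illustration). The old `reducer_idx - num_reducers` was a bare expression whose result was discarded, so the index never advanced past the first merger; the fix decrements it in place:

def find_merger(reducer_idx: int, reducers_per_merger: list[int]) -> tuple[int, int]:
    # Walk the mergers, subtracting each merger's reducer count until the
    # remaining index falls within a single merger's range.
    for merger_idx, num_reducers in enumerate(reducers_per_merger):
        if num_reducers > reducer_idx:
            return merger_idx, reducer_idx
        reducer_idx -= num_reducers  # the fix: actually decrement the index
    raise ValueError(f"Cannot find merger for reducer_idx: {reducer_idx}")

# Example: three mergers owning [4, 4, 3] reducers; global reducer 9
# resolves to merger 2 at offset 1.
assert find_merger(9, [4, 4, 3]) == (2, 1)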