Fixes: it works locally
Jay Chia committed Sep 29, 2024
1 parent 37faaa0 commit 1e8b87d
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions daft/runners/ray_runner.py
@@ -1502,8 +1502,8 @@ def push_based_shuffle_service_context(
     num_cpus = int(ray.cluster_resources()["CPU"])
 
     # Number of mappers is ~2x number of mergers
-    num_map_tasks = num_cpus // 3
-    num_merge_tasks = num_map_tasks * 2
+    num_merge_tasks = num_cpus // 3
+    num_map_tasks = num_merge_tasks * 2 + (num_cpus % 3)
 
     yield RayPushBasedShuffle(num_map_tasks, num_merge_tasks, num_partitions, partition_by)
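The new arithmetic keeps the comment honest (mappers are ~2x mergers) and accounts for every CPU, including the remainder, which the old split left idle. A standalone sanity-check sketch, not part of the commit:

```python
# Sketch: how the fixed split divides cluster CPUs between mappers and mergers.
# The old code computed it backwards (mergers ~2x mappers) and dropped the
# remainder; the fix gives mappers ~2x mergers plus any leftover CPUs.
for num_cpus in (3, 10, 16):
    num_merge_tasks = num_cpus // 3
    num_map_tasks = num_merge_tasks * 2 + (num_cpus % 3)
    assert num_map_tasks + num_merge_tasks == num_cpus  # no CPU left unused
    print(f"{num_cpus} CPUs -> {num_map_tasks} mappers, {num_merge_tasks} mergers")
# 3 CPUs -> 2 mappers, 1 merger
# 10 CPUs -> 7 mappers, 3 mergers
# 16 CPUs -> 11 mappers, 5 mergers
```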

@@ -1570,7 +1570,7 @@ def _get_reducer_inputs_location(self, reducer_idx: int) -> tuple[int, int]:
             if num_reducers > reducer_idx:
                 return merger_idx, reducer_idx
             else:
-                reducer_idx - num_reducers
+                reducer_idx -= num_reducers
         raise ValueError(f"Cannot find merger for reducer_idx: {reducer_idx}")
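The old line computed `reducer_idx - num_reducers` and discarded the result, so the search never advanced past the first merger and every reducer owned by a later merger hit the ValueError. A minimal sketch of the intended lookup, with an illustrative `reducers_per_merger` list standing in for `_num_reducers_for_merger`:

```python
def get_reducer_inputs_location(reducer_idx: int, reducers_per_merger: list[int]) -> tuple[int, int]:
    """Walk the mergers, subtracting each one's reducer count until
    reducer_idx falls inside a merger's range."""
    for merger_idx, num_reducers in enumerate(reducers_per_merger):
        if reducer_idx < num_reducers:
            return merger_idx, reducer_idx
        reducer_idx -= num_reducers  # the `-=` this commit fixes
    raise ValueError(f"Cannot find merger for reducer_idx: {reducer_idx}")

# e.g. with reducers_per_merger=[3, 3, 2], reducer 4 lives at (merger 1, offset 1)
assert get_reducer_inputs_location(4, [3, 3, 2]) == (1, 1)
```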

def _merger_options(self, merger_idx: int) -> dict[str, Any]:
@@ -1613,25 +1613,24 @@ def run(self, materialized_inputs: list[ray.ObjectRef]) -> list[ray.ObjectRef]:
         Each Mapper then should run partitioning on the data into `N` chunks.
         """
         # [N_ROUNDS, N_MERGERS, N_REDUCERS_PER_MERGER] list of outputs
-        merge_results: list[list[list[ray.ObjectRef]]] = []
+        total_merge_results: list[list[list[ray.ObjectRef]]] = []
         map_results_buffer: list[ray.ObjectRef] = []
 
         # Keep running the pipeline while there is still work to do
         while materialized_inputs or map_results_buffer:
             # Drain the map_results_buffer, running merge tasks
             per_round_merge_results = []
-            while map_results_buffer:
-                map_results = map_results_buffer.pop()
-                assert len(map_results) == self._num_mergers
-                for merger_idx, merger_input in enumerate(map_results):
+            if map_results_buffer:
+                for merger_idx in range(self._num_mergers):
+                    merger_input = [mapper_results[merger_idx] for mapper_results in map_results_buffer]
                     merge_results = merge_fn.options(
                         **self._merger_options(merger_idx), num_returns=self._num_reducers_for_merger(merger_idx)
                     ).remote(*merger_input)
                     per_round_merge_results.append(merge_results)
             if per_round_merge_results:
-                merge_results.append(per_round_merge_results)
+                total_merge_results.append(per_round_merge_results)
 
             # Run map tasks:
             map_results_buffer = []
             for i in range(self._num_mappers):
                 if len(materialized_inputs) == 0:
                     break
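Two things change in this hunk. The inner `merge_results` assignment used to shadow the top-level accumulator of the same name, which is why the accumulator is renamed `total_merge_results`. And the old loop popped one mapper's output at a time, launching a merge per (mapper, merger) pair with a single shard; the new loop transposes the buffer so each merger receives its shard from every mapper in the round. A toy illustration of that transpose, with strings standing in for ObjectRefs:

```python
# map_results_buffer is [mapper][merger]; each merger wants its column.
map_results_buffer = [
    ["m0_s0", "m0_s1"],  # mapper 0's shards for mergers 0 and 1
    ["m1_s0", "m1_s1"],  # mapper 1's shards
    ["m2_s0", "m2_s1"],  # mapper 2's shards
]
num_mergers = 2
for merger_idx in range(num_mergers):
    merger_input = [mapper_results[merger_idx] for mapper_results in map_results_buffer]
    print(merger_idx, merger_input)
# 0 ['m0_s0', 'm1_s0', 'm2_s0']
# 1 ['m0_s1', 'm1_s1', 'm2_s1']
```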
@@ -1643,16 +1642,20 @@ def run(self, materialized_inputs: list[ray.ObjectRef]) -> list[ray.ObjectRef]:
                 map_results_buffer.append(map_results)
 
             # Wait for all tasks in this wave to complete
-            ray.wait(per_round_merge_results)
-            ray.wait(map_results_buffer)
+            for results in per_round_merge_results:
+                ray.wait(results)
+            for results in map_results_buffer:
+                ray.wait(results)
 
         # INVARIANT: At this point, the map/merge step is done
         # Start running all the reduce functions
         # TODO: we could stagger this by num CPUs as well, but here we just YOLO run all
         reduce_results = []
         for reducer_idx in range(self._num_reducers):
             assigned_merger_idx, offset = self._get_reducer_inputs_location(reducer_idx)
-            reducer_inputs = [merge_results[round][assigned_merger_idx][offset] for round in range(len(merge_results))]
+            reducer_inputs = [
+                total_merge_results[round][assigned_merger_idx][offset] for round in range(len(total_merge_results))
+            ]
             res = reduce_fn.options(**self._reduce_options(reducer_idx)).remote(*reducer_inputs)
             reduce_results.append(res)
         return reduce_results
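The `ray.wait` change matters because these tasks are launched with `num_returns > 1`, so each element of `per_round_merge_results` is itself a list of `ObjectRef`s, and `ray.wait` accepts only a flat list of refs. A minimal reproduction sketch, assuming a local Ray install; the task body is illustrative:

```python
import ray

ray.init(ignore_reinit_error=True)

@ray.remote(num_returns=2)
def two_outputs():
    # With num_returns=2, .remote() hands back a list of two ObjectRefs.
    return 1, 2

results = [two_outputs.remote() for _ in range(3)]  # list of lists of ObjectRefs
# ray.wait(results)        # TypeError: wait() expects a flat list of ObjectRefs
for refs in results:       # the commit's shape: wait on each inner list
    ray.wait(refs)
```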
