def _merger_options(self, merger_idx: int) -> dict[str, Any]:
    """Build Ray task options that spread merge tasks round-robin over cluster nodes.

    The ``merger_idx``-th merge task is given a soft node-affinity to the
    ``merger_idx % num_alive_nodes``-th live node, with nodes ordered by node ID.

    Args:
        merger_idx: Index of the merge task being scheduled.

    Returns:
        A dict holding a ``"scheduling_strategy"`` entry, or an empty dict
        (no placement preference) when no live node is reported.
    """
    # Query the cluster once: calling ray.nodes() separately for the length and
    # for the lookup can observe two different membership lists and index out of
    # range. Dead nodes are dropped so tasks are not pinned to a node that is gone,
    # and IDs are sorted so the same index maps to the same node across calls
    # (as long as membership is unchanged), which _reduce_options depends on.
    alive_node_ids = sorted(node["NodeID"] for node in ray.nodes() if node.get("Alive", True))
    if not alive_node_ids:
        # Nothing to pin to; let Ray use its default scheduling rather than
        # failing on a modulo by zero.
        return {}

    node_id = alive_node_ids[merger_idx % len(alive_node_ids)]
    return {
        # soft=True: affinity is only a preference, so the task can still run
        # elsewhere if the chosen node is unavailable.
        "scheduling_strategy": ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(node_id, soft=True)
    }


def _reduce_options(self, reducer_idx: int) -> dict[str, Any]:
    """Build Ray task options for a reduce task, reusing its merger's placement.

    Args:
        reducer_idx: Index of the reduce task being scheduled.

    Returns:
        The same options ``_merger_options`` produces for the merger that
        ``_get_reducer_inputs_location`` associates with ``reducer_idx``.

    Raises:
        ValueError: Propagated from ``_get_reducer_inputs_location`` when no
            merger is found for ``reducer_idx``.
    """
    # Only the merger index is needed; the second tuple element is unused here.
    assigned_merger_idx, _ = self._get_reducer_inputs_location(reducer_idx)
    return self._merger_options(assigned_merger_idx)