From c9fc34a4be558efce2a26d2b4e08cd8880524dc8 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Thu, 17 Oct 2024 15:15:25 -0700 Subject: [PATCH] Use file store for tests (#6632) This PR changes the `init_method` for tests to `FileStore` for robustness. --- tests/unit/common.py | 50 +++++++++++++++++++++++++++----------------- 1 file changed, 31 insertions(+), 19 deletions(-) diff --git a/tests/unit/common.py b/tests/unit/common.py index 69ba4c2708ac..685f943df2fe 100644 --- a/tests/unit/common.py +++ b/tests/unit/common.py @@ -147,16 +147,13 @@ class DistributedExec(ABC): def run(self): ... - def __call__(self, request=None): + def __call__(self, request): self._fixture_kwargs = self._get_fixture_kwargs(request, self.run) world_size = self.world_size if self.requires_cuda_env and not get_accelerator().is_available(): pytest.skip("only supported in accelerator environments.") - if isinstance(world_size, int): - world_size = [world_size] - for procs in world_size: - self._launch_procs(procs) + self._launch_with_file_store(request, world_size) def _get_fixture_kwargs(self, request, func): if not request: @@ -172,7 +169,7 @@ def _get_fixture_kwargs(self, request, func): pass # test methods can have kwargs that are not fixtures return fixture_kwargs - def _launch_daemonic_procs(self, num_procs): + def _launch_daemonic_procs(self, num_procs, init_method): # Create process pool or use cached one master_port = None @@ -198,7 +195,7 @@ def _launch_daemonic_procs(self, num_procs): master_port = get_master_port() # Run the test - args = [(local_rank, num_procs, master_port) for local_rank in range(num_procs)] + args = [(local_rank, num_procs, master_port, init_method) for local_rank in range(num_procs)] skip_msgs_async = pool.starmap_async(self._dist_run, args) try: @@ -218,7 +215,7 @@ def _launch_daemonic_procs(self, num_procs): assert len(set(skip_msgs)) == 1, "Multiple different skip messages received" pytest.skip(skip_msgs[0]) - def _launch_non_daemonic_procs(self, num_procs): + def _launch_non_daemonic_procs(self, num_procs, init_method): assert not self.reuse_dist_env, "Cannot reuse distributed environment with non-daemonic processes" master_port = get_master_port() @@ -227,7 +224,7 @@ def _launch_non_daemonic_procs(self, num_procs): prev_start_method = mp.get_start_method() mp.set_start_method('spawn', force=True) for local_rank in range(num_procs): - p = mp.Process(target=self._dist_run, args=(local_rank, num_procs, master_port, skip_msg)) + p = mp.Process(target=self._dist_run, args=(local_rank, num_procs, master_port, init_method, skip_msg)) p.start() processes.append(p) mp.set_start_method(prev_start_method, force=True) @@ -269,7 +266,7 @@ def _launch_non_daemonic_procs(self, num_procs): # add a check here to assert all exit messages are equal pytest.skip(skip_msg.get()) - def _launch_procs(self, num_procs): + def _launch_procs(self, num_procs, init_method): # Verify we have enough accelerator devices to run this test if get_accelerator().is_available() and get_accelerator().device_count() < num_procs: pytest.skip( @@ -284,11 +281,11 @@ def _launch_procs(self, num_procs): mp.set_start_method('forkserver', force=True) if self.non_daemonic_procs: - self._launch_non_daemonic_procs(num_procs) + self._launch_non_daemonic_procs(num_procs, init_method) else: - self._launch_daemonic_procs(num_procs) + self._launch_daemonic_procs(num_procs, init_method) - def _dist_run(self, local_rank, num_procs, master_port, skip_msg=""): + def _dist_run(self, local_rank, num_procs, master_port, init_method, skip_msg=""): if not dist.is_initialized(): """ Initialize deepspeed.comm and execute the user function. """ if self.set_dist_env: @@ -312,7 +309,10 @@ def _dist_run(self, local_rank, num_procs, master_port, skip_msg=""): get_accelerator().set_device(local_rank) if self.init_distributed: - deepspeed.init_distributed(dist_backend=self.backend) + deepspeed.init_distributed(dist_backend=self.backend, + init_method=init_method, + rank=local_rank, + world_size=num_procs) dist.barrier() try: @@ -328,6 +328,22 @@ def _dist_run(self, local_rank, num_procs, master_port, skip_msg=""): return skip_msg + def _launch_with_file_store(self, request, world_size): + tmpdir = request.getfixturevalue("tmpdir") + dist_file_store = tmpdir.join("dist_file_store") + assert not os.path.exists(dist_file_store) + init_method = f"file://{dist_file_store}" + + if isinstance(world_size, int): + world_size = [world_size] + for procs in world_size: + try: + self._launch_procs(procs, init_method) + finally: + if os.path.exists(dist_file_store): + os.remove(dist_file_store) + time.sleep(0.5) + def _dist_destroy(self): if (dist is not None) and dist.is_initialized(): dist.barrier() @@ -473,11 +489,7 @@ def __call__(self, request): else: world_size = self._fixture_kwargs.get("world_size", self.world_size) - if isinstance(world_size, int): - world_size = [world_size] - for procs in world_size: - self._launch_procs(procs) - time.sleep(0.5) + self._launch_with_file_store(request, world_size) def _get_current_test_func(self, request): # DistributedTest subclasses may have multiple test methods