From d38ad6a17164b9bf07477ceb17ca5c7f09657720 Mon Sep 17 00:00:00 2001
From: Ammar Ahmad Awan
Date: Tue, 5 Jan 2021 10:31:49 -0800
Subject: [PATCH] change dist to torch.distributed to fix bug in assert. (#638)

---
 deepspeed/utils/distributed.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/deepspeed/utils/distributed.py b/deepspeed/utils/distributed.py
index e70f00b440bb..c9722af21c24 100644
--- a/deepspeed/utils/distributed.py
+++ b/deepspeed/utils/distributed.py
@@ -77,9 +77,9 @@ def mpi_discovery(distributed_port=TORCH_DISTRIBUTED_DEFAULT_PORT, verbose=True)
                     os.environ['MASTER_PORT']))
 
     if torch.distributed.is_initialized():
-        assert dist.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, dist.get_rank())
-        assert dist.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
-            world_size, dist.get_world_size())
+        assert torch.distributed.get_rank() == rank, "MPI rank {} does not match torch rank {}".format(rank, torch.distributed.get_rank())
+        assert torch.distributed.get_world_size() == world_size, "MPI world size {} does not match torch world size {}".format(
+            world_size, torch.distributed.get_world_size())
 
 
 def in_aml():
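
Note (not part of the patch): a minimal, self-contained sketch of the check this change fixes. The asserts previously used the name `dist`, which was not bound to torch.distributed in this module, so a mismatch surfaced as a NameError rather than the intended message; referring to torch.distributed directly avoids that. The main() helper, the single-process gloo setup, and the hard-coded rank/world_size below are illustrative assumptions, not DeepSpeed code.

import os
import torch.distributed

def main():
    # Stand-ins for the values mpi_discovery() would read from MPI.
    rank, world_size = 0, 1

    # Single-process gloo group so torch.distributed.is_initialized() is True.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)

    if torch.distributed.is_initialized():
        # With the patch applied, both the comparison and the error message
        # use torch.distributed, so a mismatch reports ranks/world sizes cleanly.
        assert torch.distributed.get_rank() == rank, \
            "MPI rank {} does not match torch rank {}".format(
                rank, torch.distributed.get_rank())
        assert torch.distributed.get_world_size() == world_size, \
            "MPI world size {} does not match torch world size {}".format(
                world_size, torch.distributed.get_world_size())

    torch.distributed.destroy_process_group()

if __name__ == "__main__":
    main()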