diff --git a/ssd/utils/distributed_util.py b/ssd/utils/distributed_util.py index 38c6303b..7a9b5bd1 100644 --- a/ssd/utils/distributed_util.py +++ b/ssd/utils/distributed_util.py @@ -71,8 +71,8 @@ def all_gather(data): tensor = torch.ByteTensor(storage).to("cuda") # obtain Tensor size of each rank - local_size = torch.IntTensor([tensor.numel()]).to("cuda") - size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)] + local_size = torch.LongTensor([tensor.numel()]).to("cuda") + size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)] dist.all_gather(size_list, local_size) size_list = [int(size.item()) for size in size_list] max_size = max(size_list)