diff --git a/deepspeed/checkpoint/utils.py b/deepspeed/checkpoint/utils.py index c305e8884e83..5964da00728e 100644 --- a/deepspeed/checkpoint/utils.py +++ b/deepspeed/checkpoint/utils.py @@ -51,7 +51,12 @@ def clone_tensors_for_torch_save(item, device=torch.device('cpu')): - copy of ``item`` with cloned tensors on target device """ if torch.is_tensor(item): - return item.detach().clone().to(device) + if type(device) is str: + device = torch.device(device) + if device == item.device: + return item.detach().clone() + else: + return item.detach().to(device) elif isinstance(item, list): return [clone_tensors_for_torch_save(v, device) for v in item] elif isinstance(item, tuple):