Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
loadams authored Dec 4, 2024
2 parents 10cdc5f + fc23007 commit b2175a4
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions deepspeed/utils/zero_to_fp32.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,19 +514,20 @@ def to_torch_tensor(state_dict, return_empty_tensor=False):
     """
     Convert state_dict of GatheredTensor to torch tensor
     """
+    torch_state_dict = {}
     converted_tensors = {}
     for name, tensor in state_dict.items():
         tensor_id = id(tensor)
-        if tensor_id in converted_tensors:
-            shared_tensor = state_dict[converted_tensors[tensor_id]]
-            state_dict[name] = shared_tensor
+        if tensor_id in converted_tensors:  # shared tensors
+            shared_tensor = torch_state_dict[converted_tensors[tensor_id]]
+            torch_state_dict[name] = shared_tensor
         else:
             converted_tensors[tensor_id] = name
             if return_empty_tensor:
-                state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
+                torch_state_dict[name] = torch.empty(tensor.shape, dtype=tensor.dtype)
             else:
-                state_dict[name] = tensor.contiguous()
-    return state_dict
+                torch_state_dict[name] = tensor.contiguous()
+    return torch_state_dict


def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir,
Expand Down Expand Up @@ -660,8 +661,9 @@ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
         else:
             torch.save(shard_state_dict, output_path)
         # release the memory of current shard
-        for tensor_name in shard_state_dict:
+        for tensor_name in list(shard_state_dict.keys()):
             del state_dict[tensor_name]
+            del shard_state_dict[tensor_name]
         del shard_state_dict
         gc.collect()

Expand Down

0 comments on commit b2175a4

Please sign in to comment.