diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index f1f2b1ce004db3..4d59635c90e97b 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2657,7 +2657,10 @@ def save_pretrained(
             ):
                 os.remove(full_filename)
         # Save the model
-        for shard_file, tensors in state_dict_split.filename_to_tensors.items():
+        filename_to_tensors = state_dict_split.filename_to_tensors.items()
+        if module_map:
+            filename_to_tensors = logging.tqdm(filename_to_tensors, desc="Saving checkpoint shards")
+        for shard_file, tensors in filename_to_tensors:
             shard = {tensor: state_dict[tensor] for tensor in tensors}
             # remake shard with onloaded parameters if necessary
             if module_map:
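
For context, here is a minimal, self-contained sketch of the pattern the diff introduces: conditionally wrapping the shard iterable in a tqdm progress bar, which adds a progress display without touching the loop body. The function and parameter names (`save_shards`, `show_progress`) are hypothetical, and the stock `tqdm` package stands in for `transformers.utils.logging.tqdm` used in the patch:

```python
from tqdm import tqdm  # stand-in for transformers.utils.logging.tqdm


def save_shards(filename_to_tensors: dict, show_progress: bool = True) -> None:
    """Hypothetical helper mirroring the loop structure in the diff."""
    items = filename_to_tensors.items()
    if show_progress:
        # Wrapping the iterable is all that is needed; iteration itself
        # drives the progress bar, so the loop body stays unchanged.
        items = tqdm(items, desc="Saving checkpoint shards")
    for shard_file, tensor_names in items:
        # In the real code, the shard dict is built here and written to disk.
        print(f"writing {shard_file} ({len(tensor_names)} tensors)")


save_shards(
    {
        "model-00001-of-00002.safetensors": ["a", "b"],
        "model-00002-of-00002.safetensors": ["c"],
    }
)
```

Gating the bar on `module_map` keeps the common fast path quiet: the bar only appears for the offloaded-parameter case, where re-onloading each shard makes saving slow enough that progress feedback is useful.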