diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index afdd09597b3177..e792c076b7a77a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2408,7 +2408,8 @@ def save_pretrained( save_function(shard, os.path.join(save_directory, shard_file)) if index is None: - path_to_weights = os.path.join(save_directory, _add_variant(WEIGHTS_NAME, variant)) + weights_file_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + path_to_weights = os.path.join(save_directory, _add_variant(weights_file_name, variant)) logger.info(f"Model weights saved in {path_to_weights}") else: save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME