-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[core / FEAT] Add the possibility to push custom tags using PreTrainedModel itself
#28405
Changes from 3 commits
dd5cbe3
f78ed31
da00274
1180585
3485b10
4b82255
4c7806e
8b89796
e73dc7b
fbef2de
0e4daad
c19e751
eb93371
1fe93b3
a24ad9b
6cfd6f5
40a1d4b
dc31941
acd676b
db3197d
f14cf93
31117f4
36f2cb7
514f13b
b3d5900
22d3412
85584ae
1e3fc1e
59738c6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -87,7 +87,7 @@ | |
replace_return_docstrings, | ||
strtobool, | ||
) | ||
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files | ||
from .utils.hub import convert_file_size_to_int, create_and_tag_model_card, get_checkpoint_shard_files | ||
from .utils.import_utils import ( | ||
ENV_VARS_TRUE_VALUES, | ||
is_sagemaker_mp_enabled, | ||
|
@@ -1163,6 +1163,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix | |
_no_split_modules = None | ||
_skip_keys_device_placement = None | ||
_keep_in_fp32_modules = None | ||
_model_tags = None | ||
|
||
# a list of `re` patterns of `state_dict` keys that should be removed from the list of missing | ||
# keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings. | ||
|
@@ -1225,6 +1226,9 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs): | |
# when a different component (e.g. language_model) is used. | ||
self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules) | ||
|
||
# Default the model tags with `"transformers"` | ||
self._model_tags = ["transformers"] | ||
|
||
def post_init(self): | ||
""" | ||
A method executed at the end of each Transformer model initialization, to execute code that needs the model's | ||
|
@@ -1239,6 +1243,24 @@ def _backward_compatibility_gradient_checkpointing(self): | |
# Remove the attribute now that is has been consumed, so it's no saved in the config. | ||
delattr(self.config, "gradient_checkpointing") | ||
|
||
def set_model_tags(self, tags: Union[List[str], str]) -> None:
    r"""
    Manually set the model tags with `tags`, overwriting any tags that were previously set.

    A single string is treated as a one-element list. Duplicate entries in `tags` are
    dropped while preserving first-seen order.

    Args:
        tags (`Union[List[str], str]`):
            The desired tags to inject in the model
    """
    if isinstance(tags, str):
        tags = [tags]

    # True "set" semantics: replace the previous tag list instead of appending to it,
    # so an accidentally-added tag can be removed by calling this method again.
    # (Appending behaviour belongs in an `add`-style method, not a setter.)
    self._model_tags = list(dict.fromkeys(tags))
|
||
@classmethod | ||
def _from_config(cls, config, **kwargs): | ||
""" | ||
|
@@ -2403,10 +2425,19 @@ def save_pretrained( | |
if safe_serialization: | ||
# At some point we will need to deal better with save_function (used for TPU and other distributed | ||
# joyfulness), but for now this enough. | ||
safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"}) | ||
safe_save_file( | ||
shard, | ||
os.path.join(save_directory, shard_file), | ||
metadata={"format": "pt"}, | ||
) | ||
else: | ||
save_function(shard, os.path.join(save_directory, shard_file)) | ||
|
||
if self._model_tags is not None: | ||
logger.warning( | ||
"Detected tags in the model but you are not using safe_serialization, they will be silently ignored. To properly save these tags you should use safe serialization." | ||
) | ||
|
||
if index is None: | ||
weights_file_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME | ||
path_to_weights = os.path.join(save_directory, _add_variant(weights_file_name, variant)) | ||
|
@@ -2425,6 +2456,13 @@ def save_pretrained( | |
) | ||
|
||
if push_to_hub: | ||
# Eventually create an empty model card | ||
model_card = create_and_tag_model_card(repo_id, self._model_tags) | ||
|
||
# Update model card if needed: | ||
if model_card is not None: | ||
model_card.save(os.path.join(save_directory, "README.md")) | ||
|
||
self._upload_modified_files( | ||
save_directory, | ||
repo_id, | ||
|
@@ -2433,6 +2471,16 @@ def save_pretrained( | |
token=token, | ||
) | ||
|
||
@wraps(PushToHubMixin.push_to_hub)
def push_to_hub(self, *args, **kwargs):
    """
    Upload the model to the Hub. Thin wrapper around `PushToHubMixin.push_to_hub` that
    injects the model's own tags (`self._model_tags`) into the `tags` keyword argument.

    Tags passed explicitly by the caller are merged with the model's tags (model tags
    first, caller tags appended after, duplicates dropped) instead of silently
    replacing one list with the other, so no source of tags is lost.
    """
    tags = list(self._model_tags) if self._model_tags is not None else []

    caller_tags = kwargs.get("tags") or []
    if isinstance(caller_tags, str):
        caller_tags = [caller_tags]

    # Merge both tag sources rather than letting the caller's list win outright.
    for tag in caller_tags:
        if tag not in tags:
            tags.append(tag)

    if tags:
        kwargs["tags"] = tags

    return super().push_to_hub(*args, **kwargs)
|
||
def get_memory_footprint(self, return_buffers=True): | ||
r""" | ||
Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,6 +33,8 @@ | |
from huggingface_hub import ( | ||
_CACHED_NO_EXIST, | ||
CommitOperationAdd, | ||
ModelCard, | ||
ModelCardData, | ||
constants, | ||
create_branch, | ||
create_commit, | ||
|
@@ -762,6 +764,7 @@ def push_to_hub( | |
safe_serialization: bool = True, | ||
revision: str = None, | ||
commit_description: str = None, | ||
tags: List[str] = None, | ||
younesbelkada marked this conversation as resolved.
Show resolved
Hide resolved
|
||
**deprecated_kwargs, | ||
) -> str: | ||
""" | ||
|
@@ -795,6 +798,8 @@ def push_to_hub( | |
Branch to push the uploaded files to. | ||
commit_description (`str`, *optional*): | ||
The description of the commit that will be created | ||
tags (`List[str]`, *optional*): | ||
List of tags to push on the Hub. | ||
|
||
Examples: | ||
|
||
|
@@ -855,6 +860,9 @@ def push_to_hub( | |
repo_id, private=private, token=token, repo_url=repo_url, organization=organization | ||
) | ||
|
||
# Create a new empty model card and eventually tag it | ||
model_card = create_and_tag_model_card(repo_id, tags) | ||
|
||
if use_temp_dir is None: | ||
use_temp_dir = not os.path.isdir(working_dir) | ||
|
||
|
@@ -864,6 +872,10 @@ def push_to_hub( | |
# Save all files. | ||
self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization) | ||
|
||
# Update model card if needed: | ||
if model_card is not None: | ||
model_card.save(os.path.join(work_dir, "README.md")) | ||
|
||
return self._upload_modified_files( | ||
work_dir, | ||
repo_id, | ||
|
@@ -1081,6 +1093,27 @@ def extract_info_from_url(url): | |
return {"repo": cache_repo, "revision": revision, "filename": filename} | ||
|
||
|
||
def create_and_tag_model_card(repo_id, tags=None):
    """
    Load the model card from `repo_id` if one exists on the Hub, otherwise create a
    minimal model card from the default template, then append `tags` to it.

    Args:
        repo_id (`str`):
            The repo_id where to look for the model card.
        tags (`List[str]`, *optional*):
            The list of tags to add in the model card.

    Returns:
        `ModelCard`: the (possibly newly created) model card with `tags` applied.
    """
    try:
        # Check if the model card is present on the remote repo
        model_card = ModelCard.load(repo_id)
    except EntryNotFoundError:
        # Otherwise create a simple model card from template
        card_data = ModelCardData(language="en", tags=[])
        model_card = ModelCard.from_template(card_data)

    # `model_card` can never be None here: either the load succeeded or the template
    # branch built one — so no None-check is needed before tagging.
    if tags is not None:
        # A card loaded from the Hub may have no tags at all, in which case
        # `data.tags` is None and appending would raise AttributeError.
        if model_card.data.tags is None:
            model_card.data.tags = []
        for model_tag in tags:
            # Avoid duplicating tags already present on the remote card.
            if model_tag not in model_card.data.tags:
                model_card.data.tags.append(model_tag)

    return model_card
|
||
|
||
def clean_files_for(file): | ||
""" | ||
Remove, if they exist, file, file.json and file.lock | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we want an append behaviour here - more of an overwrite. Otherwise, if I accidentally add a tag there's no way to ever remove it. I'd suggest a
set_model_tags
andadd_model_tags
methodThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd also be in favor of switching to
add_model_tags
if we are appending tags to a list.Don't know if adding support for removing tags if really necessary (but I don't have full context).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah makes sense, I changed that method to
add_tags
and added a new methodset_tags