-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[core
/ FEAT] Add the possibility to push custom tags using PreTrainedModel
itself
#28405
Changes from 1 commit
dd5cbe3
f78ed31
da00274
1180585
3485b10
4b82255
4c7806e
8b89796
e73dc7b
fbef2de
0e4daad
c19e751
eb93371
1fe93b3
a24ad9b
6cfd6f5
40a1d4b
dc31941
acd676b
db3197d
f14cf93
31117f4
36f2cb7
514f13b
b3d5900
22d3412
85584ae
1e3fc1e
59738c6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -87,7 +87,7 @@ | |
replace_return_docstrings, | ||
strtobool, | ||
) | ||
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files | ||
from .utils.hub import convert_file_size_to_int, create_and_tag_model_card, get_checkpoint_shard_files | ||
from .utils.import_utils import ( | ||
ENV_VARS_TRUE_VALUES, | ||
is_sagemaker_mp_enabled, | ||
|
@@ -2423,20 +2423,12 @@ def save_pretrained( | |
# Save the model | ||
for shard_file, shard in shards.items(): | ||
if safe_serialization: | ||
# Retrieve model tags and convert it to a dict of strings | ||
model_tags = self._model_tags | ||
metadata = {"format": "pt"} | ||
|
||
if model_tags is not None: | ||
# Convert as strings | ||
metadata["model_tags"] = json.dumps(model_tags, indent=2, sort_keys=True) | ||
|
||
# At some point we will need to deal better with save_function (used for TPU and other distributed | ||
# joyfulness), but for now this enough. | ||
safe_save_file( | ||
shard, | ||
os.path.join(save_directory, shard_file), | ||
metadata=metadata, | ||
metadata={"format": "pt"}, | ||
) | ||
else: | ||
save_function(shard, os.path.join(save_directory, shard_file)) | ||
|
@@ -2464,6 +2456,13 @@ def save_pretrained( | |
) | ||
|
||
if push_to_hub: | ||
# Eventually create an empty model card | ||
model_card = create_and_tag_model_card(repo_id, self._model_tags) | ||
|
||
# Update model card if needed: | ||
if model_card is not None: | ||
model_card.save(os.path.join(save_directory, "README.md")) | ||
|
||
self._upload_modified_files( | ||
save_directory, | ||
repo_id, | ||
|
@@ -2472,6 +2471,16 @@ def save_pretrained( | |
token=token, | ||
) | ||
|
||
@wraps(PushToHubMixin.push_to_hub) | ||
def push_to_hub(self, *args, **kwargs): | ||
if "tags" not in kwargs: | ||
kwargs["tags"] = self._model_tags | ||
elif "tags" in kwargs and self._model_tags is not None: | ||
logger.warning( | ||
"You manually passed `tags` to `push_to_hub` method and the model has already some tags set, we will use the tags that you passed." | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we use both instead? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 on merging both lists if they exist (for flexibility) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense, done ! |
||
return super().push_to_hub(*args, **kwargs) | ||
|
||
def get_memory_footprint(self, return_buffers=True): | ||
r""" | ||
Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,6 +33,8 @@ | |
from huggingface_hub import ( | ||
_CACHED_NO_EXIST, | ||
CommitOperationAdd, | ||
ModelCard, | ||
ModelCardData, | ||
constants, | ||
create_branch, | ||
create_commit, | ||
|
@@ -762,6 +764,7 @@ def push_to_hub( | |
safe_serialization: bool = True, | ||
revision: str = None, | ||
commit_description: str = None, | ||
tags: List[str] = None, | ||
younesbelkada marked this conversation as resolved.
Show resolved
Hide resolved
|
||
**deprecated_kwargs, | ||
) -> str: | ||
""" | ||
|
@@ -795,6 +798,8 @@ def push_to_hub( | |
Branch to push the uploaded files to. | ||
commit_description (`str`, *optional*): | ||
The description of the commit that will be created | ||
tags (`List[str]`, *optional*): | ||
List of tags to push on the Hub. | ||
|
||
Examples: | ||
|
||
|
@@ -855,6 +860,9 @@ def push_to_hub( | |
repo_id, private=private, token=token, repo_url=repo_url, organization=organization | ||
) | ||
|
||
# Create a new empty model card and eventually tag it | ||
model_card = create_and_tag_model_card(repo_id, tags) | ||
|
||
if use_temp_dir is None: | ||
use_temp_dir = not os.path.isdir(working_dir) | ||
|
||
|
@@ -864,6 +872,10 @@ def push_to_hub( | |
# Save all files. | ||
self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization) | ||
|
||
# Update model card if needed: | ||
if model_card is not None: | ||
model_card.save(os.path.join(work_dir, "README.md")) | ||
|
||
return self._upload_modified_files( | ||
work_dir, | ||
repo_id, | ||
|
@@ -1081,6 +1093,27 @@ def extract_info_from_url(url): | |
return {"repo": cache_repo, "revision": revision, "filename": filename} | ||
|
||
|
||
def create_and_tag_model_card(repo_id, tags=None): | ||
younesbelkada marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Creates a dummy model card and tags it. | ||
""" | ||
model_card = None | ||
|
||
younesbelkada marked this conversation as resolved.
Show resolved
Hide resolved
|
||
try: | ||
# Check if the model card is present on the remote repo | ||
model_card = ModelCard.load(repo_id) | ||
except EntryNotFoundError: | ||
# Otherwise create a simple model card from template | ||
card_data = ModelCardData(language="en", tags=[]) | ||
Wauplin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
model_card = ModelCard.from_template(card_data) | ||
Wauplin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if model_card is not None and tags is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why would There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah nice catch, yes it should never be None here |
||
for model_tag in tags: | ||
model_card.data.tags.append(model_tag) | ||
younesbelkada marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return model_card | ||
|
||
|
||
def clean_files_for(file): | ||
""" | ||
Remove, if they exist, file, file.json and file.lock | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(nit) I think using
warnings.warn
would be more appropriate here. I usually tend to follow the "rule" from https://docs.python.org/3/howto/logging.html:Here the problem is most likely avoidable by the user.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! I think that transformers fully switched to using
logger.warning
(I got told that in a PR review from @LysandreJik but I don't remember which one 😢 ) let me try to find that and link the comment hereThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No you can keep using
warnings.warn
! see #26527