[core/ FEAT] Add the possibility to push custom tags using PreTrainedModel itself #28405

Merged Jan 15, 2024 (29 commits). The diff below shows changes from 3 commits.

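For orientation, a hedged sketch of the usage this PR is aiming for, based on the API shown in the diff below (the repo ids and tag values are illustrative, and the tag-setting method is renamed during the review, so the merged name may differ):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Method name as of this revision of the diff; the review below suggests
# splitting it into set/add variants, so the merged API may differ.
model.set_model_tags(["my-custom-tag"])

# The tags are injected into the README.md (model card) metadata of the
# target repo when pushing. The repo id is illustrative.
model.push_to_hub("username/gpt2-with-custom-tags")
```
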
Commits

dd5cbe3  v1 tags (younesbelkada, Jan 9, 2024)
f78ed31  remove unneeded conversion (younesbelkada, Jan 9, 2024)
da00274  v2 (younesbelkada, Jan 9, 2024)
1180585  rm unneeded warning (younesbelkada, Jan 9, 2024)
3485b10  Merge remote-tracking branch 'upstream/main' into set-custom-tag (younesbelkada, Jan 9, 2024)
4b82255  add more utility methods (younesbelkada, Jan 9, 2024)
4c7806e  Update src/transformers/utils/hub.py (younesbelkada, Jan 9, 2024)
8b89796  Update src/transformers/utils/hub.py (younesbelkada, Jan 9, 2024)
e73dc7b  Update src/transformers/utils/hub.py (younesbelkada, Jan 9, 2024)
fbef2de  more enhancements (younesbelkada, Jan 9, 2024)
0e4daad  oops (younesbelkada, Jan 9, 2024)
c19e751  merge tags (younesbelkada, Jan 9, 2024)
eb93371  clean up (younesbelkada, Jan 9, 2024)
1fe93b3  revert unneeded change (younesbelkada, Jan 9, 2024)
a24ad9b  Merge remote-tracking branch 'upstream/main' into set-custom-tag (younesbelkada, Jan 10, 2024)
6cfd6f5  add extensive docs (younesbelkada, Jan 10, 2024)
40a1d4b  more docs (younesbelkada, Jan 10, 2024)
dc31941  more kwargs (younesbelkada, Jan 10, 2024)
acd676b  add test (younesbelkada, Jan 10, 2024)
db3197d  oops (younesbelkada, Jan 10, 2024)
f14cf93  fix test (younesbelkada, Jan 10, 2024)
31117f4  Update src/transformers/modeling_utils.py (younesbelkada, Jan 10, 2024)
36f2cb7  Update src/transformers/utils/hub.py (younesbelkada, Jan 10, 2024)
514f13b  Update src/transformers/modeling_utils.py (younesbelkada, Jan 10, 2024)
b3d5900  Update src/transformers/trainer.py (younesbelkada, Jan 15, 2024)
22d3412  Update src/transformers/modeling_utils.py (younesbelkada, Jan 15, 2024)
85584ae  Merge remote-tracking branch 'upstream/main' into set-custom-tag (younesbelkada, Jan 15, 2024)
1e3fc1e  add more conditions (younesbelkada, Jan 15, 2024)
59738c6  more logic (younesbelkada, Jan 15, 2024)

52 changes: 50 additions & 2 deletions src/transformers/modeling_utils.py
@@ -87,7 +87,7 @@
replace_return_docstrings,
strtobool,
)
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from .utils.hub import convert_file_size_to_int, create_and_tag_model_card, get_checkpoint_shard_files
from .utils.import_utils import (
ENV_VARS_TRUE_VALUES,
is_sagemaker_mp_enabled,
@@ -1163,6 +1163,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
_no_split_modules = None
_skip_keys_device_placement = None
_keep_in_fp32_modules = None
_model_tags = None

# a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
# keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
@@ -1225,6 +1226,9 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
# when a different component (e.g. language_model) is used.
self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)

# Default the model tags with `"transformers"`
self._model_tags = ["transformers"]

def post_init(self):
"""
A method executed at the end of each Transformer model initialization, to execute code that needs the model's
@@ -1239,6 +1243,24 @@ def _backward_compatibility_gradient_checkpointing(self):
# Remove the attribute now that it has been consumed, so it's not saved in the config.
delattr(self.config, "gradient_checkpointing")

def set_model_tags(self, tags: Union[List[str], str]) -> None:
r"""
Manually set the model tags with `tags`

Args:
tags (`Union[List[str], str]`):
The desired tags to inject in the model
"""
if isinstance(tags, str):
tags = [tags]

if self._model_tags is None:
self._model_tags = []

for tag in tags:
if tag not in self._model_tags:
self._model_tags.append(tag)
Review comment (Collaborator):
I don't think we want an append behaviour here - more of an overwrite. Otherwise, if I accidentally add a tag, there's no way to ever remove it. I'd suggest a set_model_tags and an add_model_tags method.

Review comment (Contributor):
I'd also be in favor of switching to add_model_tags if we are appending tags to a list. I don't know if adding support for removing tags is really necessary (but I don't have full context).

Review comment (Contributor, Author):
Yeah, that makes sense. I changed that method to add_tags and added a new method set_tags.
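
A minimal sketch of the split the reviewers are asking for, assuming the two-method design discussed above (the class and method names here are illustrative; the merged PR may name them differently):

```python
from typing import List, Union


class TaggableModelSketch:
    """Illustrative stand-in for the tag-handling part of PreTrainedModel."""

    _model_tags = None

    def set_model_tags(self, tags: Union[List[str], str]) -> None:
        # Overwrite: replaces whatever tags were set before.
        if isinstance(tags, str):
            tags = [tags]
        self._model_tags = list(tags)

    def add_model_tags(self, tags: Union[List[str], str]) -> None:
        # Append: keeps existing tags and adds new ones, skipping duplicates.
        if isinstance(tags, str):
            tags = [tags]
        if self._model_tags is None:
            self._model_tags = []
        for tag in tags:
            if tag not in self._model_tags:
                self._model_tags.append(tag)
```

With an overwrite variant available, an accidentally added tag can be cleared simply by calling the set method again.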


@classmethod
def _from_config(cls, config, **kwargs):
"""
@@ -2403,10 +2425,19 @@ def save_pretrained(
if safe_serialization:
# At some point we will need to deal better with save_function (used for TPU and other distributed
# joyfulness), but for now this is enough.
safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
safe_save_file(
shard,
os.path.join(save_directory, shard_file),
metadata={"format": "pt"},
)
else:
save_function(shard, os.path.join(save_directory, shard_file))

if self._model_tags is not None:
logger.warning(
"Detected tags in the model but you are not using safe_serialization, they will be silently ignored. To properly save these tags you should use safe serialization."
)

if index is None:
weights_file_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
path_to_weights = os.path.join(save_directory, _add_variant(weights_file_name, variant))
@@ -2425,6 +2456,13 @@
)

if push_to_hub:
# Eventually create an empty model card
model_card = create_and_tag_model_card(repo_id, self._model_tags)

# Update model card if needed:
if model_card is not None:
model_card.save(os.path.join(save_directory, "README.md"))

self._upload_modified_files(
save_directory,
repo_id,
@@ -2433,6 +2471,16 @@
token=token,
)

@wraps(PushToHubMixin.push_to_hub)
def push_to_hub(self, *args, **kwargs):
if "tags" not in kwargs:
kwargs["tags"] = self._model_tags
elif "tags" in kwargs and self._model_tags is not None:
logger.warning(
"You manually passed `tags` to `push_to_hub` method and the model has already some tags set, we will use the tags that you passed."
)
return super().push_to_hub(*args, **kwargs)

Review comment on the logger.warning call (Contributor, @Wauplin, Jan 9, 2024):
(nit) I think using warnings.warn would be more appropriate here. I usually tend to follow the "rule" from https://docs.python.org/3/howto/logging.html:

warnings.warn() in library code if the issue is avoidable and the client application should be modified to eliminate the warning.

logging.warning() if there is nothing the client application can do about the situation, but the event should still be noted.

Here the problem is most likely avoidable by the user.

Review comment (Contributor, Author):
Thanks! I think that transformers fully switched to using logger.warning (I was told that in a PR review from @LysandreJik, but I don't remember which one 😢). Let me try to find that comment and link it here.

Review comment (Member):
No, you can keep using warnings.warn! See #26527.

Review comment on the tags handling (Contributor):
Should we use both instead?

Review comment (Contributor):
+1 on merging both lists if they exist (for flexibility).

Review comment (Contributor, Author):
Makes sense, done!
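
Following the "merge both lists" resolution above, a hedged sketch of what the merged behaviour could look like (not the final implementation from the PR, just an illustration of the agreed direction; the import path for PushToHubMixin is an assumption):

```python
from functools import wraps

from transformers.utils import PushToHubMixin  # import path assumed


class TaggedModelSketch(PushToHubMixin):
    """Illustrative stand-in for PreTrainedModel, showing only the tag merge."""

    _model_tags = None

    @wraps(PushToHubMixin.push_to_hub)
    def push_to_hub(self, *args, **kwargs):
        # Start from the model's own tags, then layer the caller's tags on top,
        # skipping duplicates, so neither list is silently dropped.
        tags = list(self._model_tags) if self._model_tags is not None else []
        for tag in kwargs.get("tags") or []:
            if tag not in tags:
                tags.append(tag)
        if tags:
            kwargs["tags"] = tags
        return super().push_to_hub(*args, **kwargs)
```
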

def get_memory_footprint(self, return_buffers=True):
r"""
Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
33 changes: 33 additions & 0 deletions src/transformers/utils/hub.py
@@ -33,6 +33,8 @@
from huggingface_hub import (
_CACHED_NO_EXIST,
CommitOperationAdd,
ModelCard,
ModelCardData,
constants,
create_branch,
create_commit,
@@ -762,6 +764,7 @@ def push_to_hub(
safe_serialization: bool = True,
revision: str = None,
commit_description: str = None,
tags: List[str] = None,
**deprecated_kwargs,
) -> str:
"""
@@ -795,6 +798,8 @@
Branch to push the uploaded files to.
commit_description (`str`, *optional*):
The description of the commit that will be created
tags (`List[str]`, *optional*):
List of tags to push on the Hub.

Examples:

@@ -855,6 +860,9 @@
repo_id, private=private, token=token, repo_url=repo_url, organization=organization
)

# Create a new empty model card and eventually tag it
model_card = create_and_tag_model_card(repo_id, tags)

if use_temp_dir is None:
use_temp_dir = not os.path.isdir(working_dir)

@@ -864,6 +872,10 @@
# Save all files.
self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)

# Update model card if needed:
if model_card is not None:
model_card.save(os.path.join(work_dir, "README.md"))

return self._upload_modified_files(
work_dir,
repo_id,
@@ -1081,6 +1093,27 @@ def extract_info_from_url(url):
return {"repo": cache_repo, "revision": revision, "filename": filename}


def create_and_tag_model_card(repo_id, tags=None):
"""
Creates a dummy model card and tags it.
"""
model_card = None

try:
# Check if the model card is present on the remote repo
model_card = ModelCard.load(repo_id)
except EntryNotFoundError:
# Otherwise create a simple model card from template
card_data = ModelCardData(language="en", tags=[])
model_card = ModelCard.from_template(card_data)

if model_card is not None and tags is not None:
for model_tag in tags:
model_card.data.tags.append(model_tag)

return model_card

Review comment on `if model_card is not None and tags is not None:` (Collaborator):
Why would model_card be None here?

Review comment (Contributor, Author):
Ah, nice catch - yes, it should never be None here.
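
A hedged usage sketch of the helper above, showing how the loaded-or-created card ends up with the extra tags (the repo id, tag values, and local directory are illustrative, and the import path assumes this revision of the PR is installed):

```python
import os

# The helper is added to src/transformers/utils/hub.py in this diff, so the
# import below assumes a build of transformers that includes this change.
from transformers.utils.hub import create_and_tag_model_card

save_directory = "/tmp/my-model"  # illustrative local directory
os.makedirs(save_directory, exist_ok=True)

# Loads the existing card from the Hub if the repo already has one, otherwise
# creates a minimal card from the template, then appends the given tags.
card = create_and_tag_model_card("username/my-model", tags=["my-finetuning-lib"])
if card is not None:
    # save_pretrained/push_to_hub write the card next to the weights as README.md.
    card.save(os.path.join(save_directory, "README.md"))
```
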


def clean_files_for(file):
"""
Remove, if they exist, file, file.json and file.lock