Skip to content

Commit

Permalink
Introduce decorator hack to preserve BC with incorrect usage of save_…
Browse files Browse the repository at this point in the history
…to_hub (#2380)

E.g. providing keyword arguments using positional arguments & vice versa.
  • Loading branch information
tomaarsen authored Dec 14, 2023
1 parent 594143e commit 9a2e415
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 1 deletion.
3 changes: 2 additions & 1 deletion sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from . import __MODEL_HUB_ORGANIZATION__
from .evaluation import SentenceEvaluator
from .util import import_from_string, batch_to_device, fullname, is_sentence_transformer_model, load_dir_path, load_file_path
from .util import import_from_string, batch_to_device, fullname, is_sentence_transformer_model, load_dir_path, load_file_path, save_to_hub_args_decorator
from .models import Transformer, Pooling
from .model_card_templates import ModelCardTemplate
from . import __version__
Expand Down Expand Up @@ -471,6 +471,7 @@ def _create_model_card(self, path: str, model_name: Optional[str] = None, train_
with open(os.path.join(path, "README.md"), "w", encoding='utf8') as fOut:
fOut.write(model_card.strip())

@save_to_hub_args_decorator
def save_to_hub(self,
repo_id: str,
organization: Optional[str] = None,
Expand Down
20 changes: 20 additions & 0 deletions sentence_transformers/util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import requests
from torch import Tensor, device
from typing import List, Callable
Expand Down Expand Up @@ -482,3 +483,22 @@ def load_dir_path(model_name_or_path: str, directory: str, token: Optional[Union
download_kwargs["local_files_only"] = True
repo_path = snapshot_download(**download_kwargs)
return os.path.join(repo_path, directory)


def save_to_hub_args_decorator(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
# If repo_id not already set, use repo_name
repo_name = kwargs.pop("repo_name", None)
if repo_name and "repo_id" not in kwargs:
logger.warning(
"Providing a `repo_name` keyword argument to `save_to_hub` is deprecated, please use `repo_id` instead."
)
kwargs["repo_id"] = repo_name

# If positional args are used, adjust for the new "token" keyword argument
if len(args) >= 2:
args = (*args[:2], None, *args[2:])

return func(self, *args, **kwargs)
return wrapper
33 changes: 33 additions & 0 deletions tests/test_sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,36 @@ def mock_list_repo_refs(self, repo_id=None, **kwargs):
assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert mock_upload_folder_kwargs["folder_path"] == "my_fake_local_model_path"
assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456"
mock_upload_folder_kwargs.clear()

# Incorrect usage: Using deprecated "repo_name" positional argument
caplog.clear()
with caplog.at_level(logging.WARNING):
url = model.save_to_hub(repo_name="sentence-transformers-testing/stsb-bert-tiny-safetensors")
assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456"
assert len(caplog.record_tuples) == 1
assert (
caplog.record_tuples[0][2]
== "Providing a `repo_name` keyword argument to `save_to_hub` is deprecated, please use `repo_id` instead."
)
mock_upload_folder_kwargs.clear()

# Incorrect usage: Use positional arguments from before "token" was introduced
caplog.clear()
with caplog.at_level(logging.WARNING):
url = model.save_to_hub(
"stsb-bert-tiny-safetensors", # repo_name
"sentence-transformers-testing", # organization
True, # private
"Adding new awesome Model!", # commit message
exist_ok=True,
)
assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert mock_upload_folder_kwargs["commit_message"] == "Adding new awesome Model!"
assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456"
assert len(caplog.record_tuples) == 1
assert (
caplog.record_tuples[0][2]
== 'Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id="sentence-transformers-testing/stsb-bert-tiny-safetensors"` instead.'
)

0 comments on commit 9a2e415

Please sign in to comment.