Skip to content

Commit

Permalink
[fix] Simplify save_to_hub, remove git dependency, add 'token' argument (#2376)
Browse files Browse the repository at this point in the history

* Fix & simplify save_to_hub, add 'token' argument

* Convert sentence_transformers tests to pytest format

* Run formatting

black with line length of 120

* Add mocked tests, add coverage to gitignore

100% (but mocked) coverage on save_to_hub

* Update dependencies

transformers>=4.32.0 is required for 'token'
huggingface_hub>=0.15.1 is required by transformers

* Prevent backward compatibility breaking: return commit link

Thanks @Wauplin for the help
  • Loading branch information
tomaarsen authored Dec 13, 2023
1 parent 8af4744 commit 1565ef6
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 127 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,7 @@ nr_*/
/docs/make.bat
/docs/Makefile
/examples/training/quora_duplicate_questions/quora-IR-dataset/
build
build

htmlcov
.coverage
5 changes: 2 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
transformers>=4.6.0,<5.0.0
tokenizers>=0.10.3
transformers>=4.32.0,<5.0.0
tqdm
torch>=1.6.0
numpy
scikit-learn
scipy
nltk
sentencepiece
huggingface-hub
huggingface-hub>=0.15.1
Pillow
116 changes: 44 additions & 72 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
from numpy import ndarray
import transformers
from huggingface_hub import HfApi, HfFolder, Repository
from huggingface_hub import HfApi
import torch
from torch import nn, Tensor, device
from torch.optim import Optimizer
Expand Down Expand Up @@ -468,8 +468,9 @@ def _create_model_card(self, path: str, model_name: Optional[str] = None, train_
fOut.write(model_card.strip())

def save_to_hub(self,
repo_name: str,
repo_id: str,
organization: Optional[str] = None,
token: Optional[str] = None,
private: Optional[bool] = None,
commit_message: str = "Add new SentenceTransformer model.",
local_model_path: Optional[str] = None,
Expand All @@ -479,90 +480,61 @@ def save_to_hub(self,
"""
Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository.
:param repo_name: Repository name for your model in the Hub.
:param organization: Organization in which you want to push your model or tokenizer (you must be a member of this organization).
:param repo_id: Repository name for your model in the Hub, including the user or organization.
:param token: An authentication token (See https://huggingface.co/settings/token)
:param private: Set to true, for hosting a private model
:param commit_message: Message to commit while pushing.
:param local_model_path: Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded
:param exist_ok: If true, saving to an existing repository is OK. If false, saving only to a new repository is possible
:param replace_model_card: If true, replace an existing model card in the hub with the automatically created model card
:param train_datasets: Datasets used to train the model. If set, the datasets will be added to the model card in the Hub.
:return: The url of the commit of your model in the given repository.
"""
token = HfFolder.get_token()
if token is None:
raise ValueError("You must login to the Hugging Face hub on this computer by typing `transformers-cli login`.")
:param organization: Deprecated. Organization in which you want to push your model or tokenizer (you must be a member of this organization).
if '/' in repo_name:
splits = repo_name.split('/', maxsplit=1)
if organization is None or organization == splits[0]:
organization = splits[0]
repo_name = splits[1]
:return: The url of the commit of your model in the repository on the Hugging Face Hub.
"""
if organization:
if "/" not in repo_id:
logger.warning(
f"Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id=\"{organization}/{repo_id}\"` instead."
)
repo_id = f"{organization}/{repo_id}"
elif repo_id.split("/")[0] != organization:
raise ValueError("Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id`.")
else:
raise ValueError("You passed and invalid repository name: {}.".format(repo_name))
logger.warning(
f"Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id=\"{repo_id}\"` instead."
)

endpoint = "https://huggingface.co"
repo_id = repo_name
if organization:
repo_id = f"{organization}/{repo_id}"
repo_url = HfApi(endpoint=endpoint).create_repo(
api = HfApi(token=token)
repo_url = api.create_repo(
repo_id=repo_id,
private=private,
repo_type=None,
exist_ok=exist_ok,
)
if local_model_path:
folder_url = api.upload_folder(
repo_id=repo_id,
token=token,
private=private,
repo_type=None,
exist_ok=exist_ok,
folder_path=local_model_path,
commit_message=commit_message
)
full_model_name = repo_url[len(endpoint)+1:].strip("/")

with tempfile.TemporaryDirectory() as tmp_dir:
# First create the repo (and clone its content if it's nonempty).
logger.info("Create repository and clone it if it exists")
repo = Repository(tmp_dir, clone_from=repo_url)

# If user provides local files, copy them.
if local_model_path:
copy_tree(local_model_path, tmp_dir)
else: # Else, save model directly into local repo.
else:
with tempfile.TemporaryDirectory() as tmp_dir:
create_model_card = replace_model_card or not os.path.exists(os.path.join(tmp_dir, 'README.md'))
self.save(tmp_dir, model_name=full_model_name, create_model_card=create_model_card, train_datasets=train_datasets)

#Find files larger 5M and track with git-lfs
large_files = []
for root, dirs, files in os.walk(tmp_dir):
for filename in files:
file_path = os.path.join(root, filename)
rel_path = os.path.relpath(file_path, tmp_dir)

if os.path.getsize(file_path) > (5 * 1024 * 1024):
large_files.append(rel_path)

if len(large_files) > 0:
logger.info("Track files with git lfs: {}".format(", ".join(large_files)))
repo.lfs_track(large_files)

logger.info("Push model to the hub. This might take a while")
push_return = repo.push_to_hub(commit_message=commit_message)

def on_rm_error(func, path, exc_info):
# path contains the path of the file that couldn't be removed
# let's just assume that it's read-only and unlink it.
try:
os.chmod(path, stat.S_IWRITE)
os.unlink(path)
except:
pass

# Remove .git folder. On Windows, the .git folder might be read-only and cannot be deleted
# Hence, try to set write permissions on error
try:
for f in os.listdir(tmp_dir):
shutil.rmtree(os.path.join(tmp_dir, f), onerror=on_rm_error)
except Exception as e:
logger.warning("Error when deleting temp folder: {}".format(str(e)))
pass
self.save(tmp_dir, model_name=repo_url.repo_id, create_model_card=create_model_card, train_datasets=train_datasets)
folder_url = api.upload_folder(
repo_id=repo_id,
folder_path=tmp_dir,
commit_message=commit_message
)

refs = api.list_repo_refs(repo_id=repo_id)
for branch in refs.branches:
if branch.name == "main":
return f"https://huggingface.co/{repo_id}/commit/{branch.target_commit}"
# This isn't expected to ever be reached.
return folder_url

return push_return

def smart_batching_collate(self, batch):
"""
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
packages=find_packages(),
python_requires=">=3.8.0",
install_requires=[
'transformers>=4.6.0,<5.0.0',
'transformers>=4.32.0,<5.0.0',
'tqdm',
'torch>=1.6.0',
'numpy',
'scikit-learn',
'scipy',
'nltk',
'sentencepiece',
'huggingface-hub>=0.4.0',
'huggingface-hub>=0.15.1',
'Pillow'
],
classifiers=[
Expand Down
166 changes: 117 additions & 49 deletions tests/test_sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,64 +3,132 @@
"""


import logging
from pathlib import Path
import tempfile
import pytest

from huggingface_hub import HfApi, RepoUrl, GitRefs, GitRefInfo
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, Pooling
import unittest


class TestSentenceTransformer(unittest.TestCase):
def test_load_with_safetensors(self):
with tempfile.TemporaryDirectory() as cache_folder:
safetensors_model = SentenceTransformer(
"sentence-transformers-testing/stsb-bert-tiny-safetensors",
cache_folder=cache_folder,
)

# Only the safetensors file must be loaded
pytorch_files = list(Path(cache_folder).glob("**/pytorch_model.bin"))
self.assertEqual(0, len(pytorch_files), msg="PyTorch model file must not be downloaded.")
safetensors_files = list(Path(cache_folder).glob("**/model.safetensors"))
self.assertEqual(1, len(safetensors_files), msg="Safetensors model file must be downloaded.")

with tempfile.TemporaryDirectory() as cache_folder:
transformer = Transformer(
"sentence-transformers-testing/stsb-bert-tiny-safetensors",
cache_dir=cache_folder,
model_args={"use_safetensors": False},
)
pooling = Pooling(transformer.get_word_embedding_dimension())
pytorch_model = SentenceTransformer(modules=[transformer, pooling])

# Only the pytorch file must be loaded
pytorch_files = list(Path(cache_folder).glob("**/pytorch_model.bin"))
self.assertEqual(1, len(pytorch_files), msg="PyTorch model file must be downloaded.")
safetensors_files = list(Path(cache_folder).glob("**/model.safetensors"))
self.assertEqual(0, len(safetensors_files), msg="Safetensors model file must not be downloaded.")

sentences = ["This is a test sentence", "This is another test sentence"]
self.assertTrue(
torch.equal(safetensors_model.encode(sentences, convert_to_tensor=True), pytorch_model.encode(sentences, convert_to_tensor=True)),
msg="Ensure that Safetensors and PyTorch loaded models result in identical embeddings",


def test_load_with_safetensors() -> None:
    """Check that the safetensors and PyTorch weight formats load interchangeably.

    The same tiny test checkpoint is loaded twice, each time into a fresh
    temporary cache directory: once through ``SentenceTransformer`` directly
    and once through a ``Transformer`` module built with
    ``use_safetensors=False``. After each load the cache directory is searched
    to confirm that only the expected weight file is present, and finally both
    models must yield exactly equal embeddings for the same sentences.
    """
    model_id = "sentence-transformers-testing/stsb-bert-tiny-safetensors"

    # First load: default settings, only `model.safetensors` should end up in the cache.
    with tempfile.TemporaryDirectory() as cache_folder:
        safetensors_model = SentenceTransformer(model_id, cache_folder=cache_folder)

        cache_root = Path(cache_folder)
        bin_hits = list(cache_root.glob("**/pytorch_model.bin"))
        st_hits = list(cache_root.glob("**/model.safetensors"))
        assert 0 == len(bin_hits), "PyTorch model file must not be downloaded."
        assert 1 == len(st_hits), "Safetensors model file must be downloaded."

    # Second load: safetensors explicitly disabled, only `pytorch_model.bin` should be cached.
    with tempfile.TemporaryDirectory() as cache_folder:
        transformer = Transformer(
            model_id,
            cache_dir=cache_folder,
            model_args={"use_safetensors": False},
        )
        pooling = Pooling(transformer.get_word_embedding_dimension())
        pytorch_model = SentenceTransformer(modules=[transformer, pooling])

        cache_root = Path(cache_folder)
        bin_hits = list(cache_root.glob("**/pytorch_model.bin"))
        st_hits = list(cache_root.glob("**/model.safetensors"))
        assert 1 == len(bin_hits), "PyTorch model file must be downloaded."
        assert 0 == len(st_hits), "Safetensors model file must not be downloaded."

    # Both models are already in memory, so the temporary caches are no longer needed here.
    sentences = ["This is a test sentence", "This is another test sentence"]
    safetensors_embeddings = safetensors_model.encode(sentences, convert_to_tensor=True)
    pytorch_embeddings = pytorch_model.encode(sentences, convert_to_tensor=True)
    assert torch.equal(
        safetensors_embeddings,
        pytorch_embeddings,
    ), "Ensure that Safetensors and PyTorch loaded models result in identical embeddings"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA must be available to test moving devices effectively.")
def test_to() -> None:
model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors", device="cpu")

test_device = torch.device("cuda")
assert model.device.type == "cpu"
assert test_device.type == "cuda"

model.to(test_device)
assert model.device.type == "cuda", "The model device should have updated"

@unittest.skipUnless(torch.cuda.is_available(), reason="CUDA must be available to test moving devices effectively.")
def test_to(self):
model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors", device="cpu")
model.encode("Test sentence")
assert model.device.type == "cuda", "Encoding shouldn't change the device"

test_device = torch.device("cuda")
self.assertEqual(model.device.type, "cpu")
self.assertEqual(test_device.type, "cuda")
assert model._target_device == model.device, "Prevent backwards compatibility failure for _target_device"
model._target_device = "cpu"
assert model.device.type == "cpu", "Ensure that setting `_target_device` doesn't crash."

model.to(test_device)
self.assertEqual(model.device.type, "cuda", msg="The model device should have updated")

model.encode("Test sentence")
self.assertEqual(model.device.type, "cuda", msg="Encoding shouldn't change the device")
def test_save_to_hub(monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture) -> None:
def mock_create_repo(self, repo_id, **kwargs):
return RepoUrl(f"https://huggingface.co/{repo_id}")

mock_upload_folder_kwargs = {}

def mock_upload_folder(self, **kwargs):
nonlocal mock_upload_folder_kwargs
mock_upload_folder_kwargs = kwargs

def mock_list_repo_refs(self, repo_id=None, **kwargs):
try:
git_ref_info = GitRefInfo(name="main", ref="refs/heads/main", target_commit="123456")
except TypeError:
git_ref_info = GitRefInfo(dict(name="main", ref="refs/heads/main", targetCommit="123456"))
return GitRefs(branches=[git_ref_info], converts=[], tags=[])

monkeypatch.setattr(HfApi, "create_repo", mock_create_repo)
monkeypatch.setattr(HfApi, "upload_folder", mock_upload_folder)
monkeypatch.setattr(HfApi, "list_repo_refs", mock_list_repo_refs)

model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
url = model.save_to_hub("sentence-transformers-testing/stsb-bert-tiny-safetensors")
assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456"
mock_upload_folder_kwargs.clear()

with pytest.raises(
ValueError, match="Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id`."
):
model.save_to_hub("sentence-transformers-testing/stsb-bert-tiny-safetensors", organization="unrelated")

caplog.clear()
with caplog.at_level(logging.WARNING):
url = model.save_to_hub(
"sentence-transformers-testing/stsb-bert-tiny-safetensors", organization="sentence-transformers-testing"
)
assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456"
assert len(caplog.record_tuples) == 1
assert (
caplog.record_tuples[0][2]
== 'Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id="sentence-transformers-testing/stsb-bert-tiny-safetensors"` instead.'
)
mock_upload_folder_kwargs.clear()

caplog.clear()
with caplog.at_level(logging.WARNING):
url = model.save_to_hub("stsb-bert-tiny-safetensors", organization="sentence-transformers-testing")
assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456"
assert len(caplog.record_tuples) == 1
assert (
caplog.record_tuples[0][2]
== 'Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id="sentence-transformers-testing/stsb-bert-tiny-safetensors"` instead.'
)
mock_upload_folder_kwargs.clear()

self.assertEqual(model._target_device, model.device, msg="Prevent backwards compatibility failure for _target_device")
model._target_device = "cpu"
self.assertEqual(model.device.type, "cpu", msg="Ensure that setting `_target_device` doesn't crash.")
url = model.save_to_hub(
"sentence-transformers-testing/stsb-bert-tiny-safetensors", local_model_path="my_fake_local_model_path"
)
assert mock_upload_folder_kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert mock_upload_folder_kwargs["folder_path"] == "my_fake_local_model_path"
assert url == "https://huggingface.co/sentence-transformers-testing/stsb-bert-tiny-safetensors/commit/123456"

0 comments on commit 1565ef6

Please sign in to comment.