Skip to content

Commit

Permalink
Add mocked tests, add coverage to gitignore
Browse files Browse the repository at this point in the history
100% (but mocked) coverage on save_to_hub
  • Loading branch information
tomaarsen committed Dec 13, 2023
1 parent 06aee05 commit a89fb89
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 2 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,7 @@ nr_*/
/docs/make.bat
/docs/Makefile
/examples/training/quora_duplicate_questions/quora-IR-dataset/
build
build

htmlcov
.coverage
8 changes: 7 additions & 1 deletion sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,10 +476,16 @@ def save_to_hub(self,
"""
if organization:
if "/" not in repo_id:
logger.warning(f"Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id=\"{organization}/{repo_id}\"` instead.")
logger.warning(
f"Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id=\"{organization}/{repo_id}\"` instead."
)
repo_id = f"{organization}/{repo_id}"
elif repo_id.split("/")[0] != organization:
raise ValueError("Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id`.")
else:
logger.warning(
f"Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id=\"{repo_id}\"` instead."
)

api = HfApi(token=token)
repo_url = api.create_repo(
Expand Down
37 changes: 37 additions & 0 deletions tests/test_sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
"""


import logging
from pathlib import Path
import tempfile
import pytest

from huggingface_hub import HfApi, RepoUrl
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, Pooling
Expand Down Expand Up @@ -64,3 +66,38 @@ def test_to() -> None:
assert model._target_device == model.device, "Prevent backwards compatibility failure for _target_device"
model._target_device = "cpu"
assert model.device.type == "cpu", "Ensure that setting `_target_device` doesn't crash."

def test_save_to_hub(monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture) -> None:
def mock_create_repo(self, repo_id, **kwargs):
return RepoUrl(f"https://huggingface.co/{repo_id}")

def mock_upload_folder(self, **kwargs):
return kwargs

monkeypatch.setattr(HfApi, "create_repo", mock_create_repo)
monkeypatch.setattr(HfApi, "upload_folder", mock_upload_folder)

model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors")
kwargs = model.save_to_hub("sentence-transformers-testing/stsb-bert-tiny-safetensors")
assert kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"

with pytest.raises(ValueError, match="Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id`."):
model.save_to_hub("sentence-transformers-testing/stsb-bert-tiny-safetensors", organization="unrelated")

caplog.clear()
with caplog.at_level(logging.WARNING):
kwargs = model.save_to_hub("sentence-transformers-testing/stsb-bert-tiny-safetensors", organization="sentence-transformers-testing")
assert kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert len(caplog.record_tuples) == 1
assert caplog.record_tuples[0][2] == "Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id=\"sentence-transformers-testing/stsb-bert-tiny-safetensors\"` instead."

caplog.clear()
with caplog.at_level(logging.WARNING):
kwargs = model.save_to_hub("stsb-bert-tiny-safetensors", organization="sentence-transformers-testing")
assert kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert len(caplog.record_tuples) == 1
assert caplog.record_tuples[0][2] == "Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id=\"sentence-transformers-testing/stsb-bert-tiny-safetensors\"` instead."

kwargs = model.save_to_hub("sentence-transformers-testing/stsb-bert-tiny-safetensors", local_model_path="my_fake_local_model_path")
assert kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors"
assert kwargs["folder_path"] == "my_fake_local_model_path"

0 comments on commit a89fb89

Please sign in to comment.