From a89fb89d2bd502f0a401f70b60a2c6475e7f9e58 Mon Sep 17 00:00:00 2001 From: Tom Aarsen Date: Wed, 13 Dec 2023 12:55:58 +0100 Subject: [PATCH] Add mocked tests, add coverage to gitignore 100% (but mocked) coverage on save_to_hub --- .gitignore | 5 ++- sentence_transformers/SentenceTransformer.py | 8 ++++- tests/test_sentence_transformer.py | 37 ++++++++++++++++++++ 3 files changed, 48 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 149eca463..8c8b2456a 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,7 @@ nr_*/ /docs/make.bat /docs/Makefile /examples/training/quora_duplicate_questions/quora-IR-dataset/ -build \ No newline at end of file +build + +htmlcov +.coverage \ No newline at end of file diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py index 888a006b4..07fb08ce1 100644 --- a/sentence_transformers/SentenceTransformer.py +++ b/sentence_transformers/SentenceTransformer.py @@ -476,10 +476,16 @@ def save_to_hub(self, """ if organization: if "/" not in repo_id: - logger.warning(f"Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id=\"{organization}/{repo_id}\"` instead.") + logger.warning( + f"Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id=\"{organization}/{repo_id}\"` instead." + ) repo_id = f"{organization}/{repo_id}" elif repo_id.split("/")[0] != organization: raise ValueError("Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id`.") + else: + logger.warning( + f"Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id=\"{repo_id}\"` instead." + ) api = HfApi(token=token) repo_url = api.create_repo( diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py index 674e7801d..d2ef00f53 100644 --- a/tests/test_sentence_transformer.py +++ b/tests/test_sentence_transformer.py @@ -3,10 +3,12 @@ """ +import logging from pathlib import Path import tempfile import pytest +from huggingface_hub import HfApi, RepoUrl import torch from sentence_transformers import SentenceTransformer from sentence_transformers.models import Transformer, Pooling @@ -64,3 +66,38 @@ def test_to() -> None: assert model._target_device == model.device, "Prevent backwards compatibility failure for _target_device" model._target_device = "cpu" assert model.device.type == "cpu", "Ensure that setting `_target_device` doesn't crash." + +def test_save_to_hub(monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture) -> None: + def mock_create_repo(self, repo_id, **kwargs): + return RepoUrl(f"https://huggingface.co/{repo_id}") + + def mock_upload_folder(self, **kwargs): + return kwargs + + monkeypatch.setattr(HfApi, "create_repo", mock_create_repo) + monkeypatch.setattr(HfApi, "upload_folder", mock_upload_folder) + + model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors") + kwargs = model.save_to_hub("sentence-transformers-testing/stsb-bert-tiny-safetensors") + assert kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors" + + with pytest.raises(ValueError, match="Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id`."): + model.save_to_hub("sentence-transformers-testing/stsb-bert-tiny-safetensors", organization="unrelated") + + caplog.clear() + with caplog.at_level(logging.WARNING): + kwargs = model.save_to_hub("sentence-transformers-testing/stsb-bert-tiny-safetensors", organization="sentence-transformers-testing") + assert kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors" + assert len(caplog.record_tuples) == 1 + assert caplog.record_tuples[0][2] == "Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id=\"sentence-transformers-testing/stsb-bert-tiny-safetensors\"` instead." + + caplog.clear() + with caplog.at_level(logging.WARNING): + kwargs = model.save_to_hub("stsb-bert-tiny-safetensors", organization="sentence-transformers-testing") + assert kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors" + assert len(caplog.record_tuples) == 1 + assert caplog.record_tuples[0][2] == "Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id=\"sentence-transformers-testing/stsb-bert-tiny-safetensors\"` instead." + + kwargs = model.save_to_hub("sentence-transformers-testing/stsb-bert-tiny-safetensors", local_model_path="my_fake_local_model_path") + assert kwargs["repo_id"] == "sentence-transformers-testing/stsb-bert-tiny-safetensors" + assert kwargs["folder_path"] == "my_fake_local_model_path"