-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[refactor]
model loading - no more unnecessary file downloads (#2345)
* Refactor model loading: no full repo download * Add simple test regarding efficient loading * Replace use_auth_token with token in docstring Deprecated arguments are not listed in docstrings * Prevent crash if internet is down * Use load_file_path in "is_sbert_model"
- Loading branch information
Showing
3 changed files
with
152 additions
and
115 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
""" | ||
Tests general behaviour of the SentenceTransformer class | ||
""" | ||
|
||
from pathlib import Path | ||
import tempfile | ||
|
||
import torch | ||
from sentence_transformers import SentenceTransformer | ||
from sentence_transformers.models import Transformer, Pooling | ||
import unittest | ||
|
||
|
||
class TestSentenceTransformer(unittest.TestCase): | ||
def test_load_with_safetensors(self): | ||
with tempfile.TemporaryDirectory() as cache_folder: | ||
safetensors_model = SentenceTransformer( | ||
"sentence-transformers-testing/stsb-bert-tiny-safetensors", | ||
cache_folder=cache_folder, | ||
) | ||
|
||
# Only the safetensors file must be loaded | ||
pytorch_files = list(Path(cache_folder).glob("**/pytorch_model.bin")) | ||
self.assertEqual(0, len(pytorch_files), msg="PyTorch model file must not be downloaded.") | ||
safetensors_files = list(Path(cache_folder).glob("**/model.safetensors")) | ||
self.assertEqual(1, len(safetensors_files), msg="Safetensors model file must be downloaded.") | ||
|
||
with tempfile.TemporaryDirectory() as cache_folder: | ||
transformer = Transformer( | ||
"sentence-transformers-testing/stsb-bert-tiny-safetensors", | ||
cache_dir=cache_folder, | ||
model_args={"use_safetensors": False}, | ||
) | ||
pooling = Pooling(transformer.get_word_embedding_dimension()) | ||
pytorch_model = SentenceTransformer(modules=[transformer, pooling]) | ||
|
||
# Only the pytorch file must be loaded | ||
pytorch_files = list(Path(cache_folder).glob("**/pytorch_model.bin")) | ||
self.assertEqual(1, len(pytorch_files), msg="PyTorch model file must be downloaded.") | ||
safetensors_files = list(Path(cache_folder).glob("**/model.safetensors")) | ||
self.assertEqual(0, len(safetensors_files), msg="Safetensors model file must not be downloaded.") | ||
|
||
sentences = ["This is a test sentence", "This is another test sentence"] | ||
self.assertTrue( | ||
torch.equal(safetensors_model.encode(sentences, convert_to_tensor=True), pytorch_model.encode(sentences, convert_to_tensor=True)), | ||
msg="Ensure that Safetensors and PyTorch loaded models result in identical embeddings", | ||
) |