diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index 355fadbdc..361e3f434 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -4,6 +4,7 @@
 import shutil
 import stat
 from collections import OrderedDict
+import warnings
 from typing import List, Dict, Tuple, Iterable, Type, Union, Callable, Optional, Literal
 import numpy as np
 from numpy import ndarray
@@ -22,7 +23,7 @@
 from . import __MODEL_HUB_ORGANIZATION__
 from .evaluation import SentenceEvaluator
-from .util import import_from_string, batch_to_device, fullname, snapshot_download
+from .util import import_from_string, batch_to_device, fullname, is_sentence_transformer_model, load_dir_path, load_file_path
 from .models import Transformer, Pooling
 from .model_card_templates import ModelCardTemplate
 from . import __version__
@@ -59,17 +60,27 @@ class SentenceTransformer(nn.Sequential):
     :param device: Device (like "cuda", "cpu", "mps") that should be used for computation. If None, checks if a GPU can be used.
     :param cache_folder: Path to store models. Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable.
-    :param use_auth_token: Hugging Face authentication token to download private models.
+    :param token: Hugging Face authentication token to download private models.
     """
     def __init__(self, model_name_or_path: Optional[str] = None,
                  modules: Optional[Iterable[nn.Module]] = None,
                  device: Optional[str] = None,
                  cache_folder: Optional[str] = None,
-                 use_auth_token: Union[bool, str, None] = None
+                 token: Optional[Union[bool, str]] = None,
+                 use_auth_token: Optional[Union[bool, str]] = None,
                  ):
         self._model_card_vars = {}
         self._model_card_text = None
         self._model_config = {}
+        if use_auth_token is not None:
+            warnings.warn(
+                "The `use_auth_token` argument is deprecated and will be removed in v3 of SentenceTransformers.", FutureWarning
+            )
+            if token is not None:
+                raise ValueError(
+                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
+                )
+            token = use_auth_token
 
         if cache_folder is None:
             cache_folder = os.getenv('SENTENCE_TRANSFORMERS_HOME')
@@ -86,13 +97,10 @@ def __init__(self, model_name_or_path: Optional[str] = None,
         if model_name_or_path is not None and model_name_or_path != "":
             logger.info("Load pretrained SentenceTransformer: {}".format(model_name_or_path))
 
-            #Old models that don't belong to any organization
+            # Old models that don't belong to any organization
             basic_transformer_models = ['albert-base-v1', 'albert-base-v2', 'albert-large-v1', 'albert-large-v2', 'albert-xlarge-v1', 'albert-xlarge-v2', 'albert-xxlarge-v1', 'albert-xxlarge-v2', 'bert-base-cased-finetuned-mrpc', 'bert-base-cased', 'bert-base-chinese', 'bert-base-german-cased', 'bert-base-german-dbmdz-cased', 'bert-base-german-dbmdz-uncased', 'bert-base-multilingual-cased', 'bert-base-multilingual-uncased', 'bert-base-uncased', 'bert-large-cased-whole-word-masking-finetuned-squad', 'bert-large-cased-whole-word-masking', 'bert-large-cased', 'bert-large-uncased-whole-word-masking-finetuned-squad', 'bert-large-uncased-whole-word-masking', 'bert-large-uncased', 'camembert-base', 'ctrl', 'distilbert-base-cased-distilled-squad', 'distilbert-base-cased', 'distilbert-base-german-cased', 'distilbert-base-multilingual-cased', 'distilbert-base-uncased-distilled-squad', 'distilbert-base-uncased-finetuned-sst-2-english', 'distilbert-base-uncased', 'distilgpt2', 'distilroberta-base', 'gpt2-large', 'gpt2-medium', 'gpt2-xl', 'gpt2', 'openai-gpt', 'roberta-base-openai-detector', 'roberta-base', 'roberta-large-mnli', 'roberta-large-openai-detector', 'roberta-large', 't5-11b', 't5-3b', 't5-base', 't5-large', 't5-small', 'transfo-xl-wt103', 'xlm-clm-ende-1024', 'xlm-clm-enfr-1024', 'xlm-mlm-100-1280', 'xlm-mlm-17-1280', 'xlm-mlm-en-2048', 'xlm-mlm-ende-1024', 'xlm-mlm-enfr-1024', 'xlm-mlm-enro-1024', 'xlm-mlm-tlm-xnli15-1024', 'xlm-mlm-xnli15-1024', 'xlm-roberta-base', 'xlm-roberta-large-finetuned-conll02-dutch', 'xlm-roberta-large-finetuned-conll02-spanish', 'xlm-roberta-large-finetuned-conll03-english', 'xlm-roberta-large-finetuned-conll03-german', 'xlm-roberta-large', 'xlnet-base-cased', 'xlnet-large-cased']
 
-            if os.path.exists(model_name_or_path):
-                #Load from path
-                model_path = model_name_or_path
-            else:
+            if not os.path.exists(model_name_or_path):
                 #Not a path, load from hub
                 if '\\' in model_name_or_path or model_name_or_path.count('/') > 1:
                     raise ValueError("Path {} not found".format(model_name_or_path))
@@ -101,21 +109,10 @@ def __init__(self, model_name_or_path: Optional[str] = None,
                     # A model from sentence-transformers
                     model_name_or_path = __MODEL_HUB_ORGANIZATION__ + "/" + model_name_or_path
 
-            model_path = os.path.join(cache_folder, model_name_or_path.replace("/", "_"))
-
-            if not os.path.exists(os.path.join(model_path, 'modules.json')):
-                # Download from hub with caching
-                snapshot_download(model_name_or_path,
-                                  cache_dir=cache_folder,
-                                  library_name='sentence-transformers',
-                                  library_version=__version__,
-                                  ignore_files=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'],
-                                  use_auth_token=use_auth_token)
-
-            if os.path.exists(os.path.join(model_path, 'modules.json')):    #Load as SentenceTransformer model
-                modules = self._load_sbert_model(model_path)
-            else:   #Load with AutoModel
-                modules = self._load_auto_model(model_path)
+            if is_sentence_transformer_model(model_name_or_path, token, cache_folder=cache_folder):
+                modules = self._load_sbert_model(model_name_or_path, token=token, cache_folder=cache_folder)
+            else:
+                modules = self._load_auto_model(model_name_or_path, token=token, cache_folder=cache_folder)
 
         if modules is not None and not isinstance(modules, OrderedDict):
             modules = OrderedDict([(str(idx), module) for idx, module in enumerate(modules)])
@@ -823,22 +820,22 @@ def _save_checkpoint(self, checkpoint_path, checkpoint_save_total_limit, step):
                 shutil.rmtree(old_checkpoints[0]['path'])
 
-    def _load_auto_model(self, model_name_or_path):
+    def _load_auto_model(self, model_name_or_path: str, token: Optional[Union[bool, str]], cache_folder: Optional[str]):
         """
         Creates a simple Transformer + Mean Pooling model and returns the modules
         """
         logger.warning("No sentence-transformers model found with name {}. Creating a new one with MEAN pooling.".format(model_name_or_path))
-        transformer_model = Transformer(model_name_or_path)
+        transformer_model = Transformer(model_name_or_path, cache_dir=cache_folder, model_args={"token": token})
         pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), 'mean')
         return [transformer_model, pooling_model]
 
-    def _load_sbert_model(self, model_path):
+    def _load_sbert_model(self, model_name_or_path: str, token: Optional[Union[bool, str]], cache_folder: Optional[str]):
         """
         Loads a full sentence-transformers model
         """
         # Check if the config_sentence_transformers.json file exists (exists since v2 of the framework)
-        config_sentence_transformers_json_path = os.path.join(model_path, 'config_sentence_transformers.json')
-        if os.path.exists(config_sentence_transformers_json_path):
+        config_sentence_transformers_json_path = load_file_path(model_name_or_path, 'config_sentence_transformers.json', token=token, cache_folder=cache_folder)
+        if config_sentence_transformers_json_path is not None:
             with open(config_sentence_transformers_json_path) as fIn:
                 self._model_config = json.load(fIn)
 
@@ -846,8 +843,8 @@ def _load_sbert_model(self, model_path):
             logger.warning("You try to use a model that was created with version {}, however, your version is {}. This might cause unexpected behavior or errors. In that case, try to update to the latest version.\n\n\n".format(self._model_config['__version__']['sentence_transformers'], __version__))
 
         # Check if a readme exists
-        model_card_path = os.path.join(model_path, 'README.md')
-        if os.path.exists(model_card_path):
+        model_card_path = load_file_path(model_name_or_path, 'README.md', token=token, cache_folder=cache_folder)
+        if model_card_path is not None:
             try:
                 with open(model_card_path, encoding='utf8') as fIn:
                     self._model_card_text = fIn.read()
@@ -855,14 +852,31 @@ def _load_sbert_model(self, model_path):
             pass
 
         # Load the modules of sentence transformer
-        modules_json_path = os.path.join(model_path, 'modules.json')
+        modules_json_path = load_file_path(model_name_or_path, 'modules.json', token=token, cache_folder=cache_folder)
         with open(modules_json_path) as fIn:
             modules_config = json.load(fIn)
 
         modules = OrderedDict()
         for module_config in modules_config:
             module_class = import_from_string(module_config['type'])
-            module = module_class.load(os.path.join(model_path, module_config['path']))
+            # For Transformer, don't load the full directory, rely on `transformers` instead
+            # But, do load the config file first.
+            if module_class == Transformer and module_config['path'] == "":
+                kwargs = {}
+                for config_name in ['sentence_bert_config.json', 'sentence_roberta_config.json', 'sentence_distilbert_config.json', 'sentence_camembert_config.json', 'sentence_albert_config.json', 'sentence_xlm-roberta_config.json', 'sentence_xlnet_config.json']:
+                    config_path = load_file_path(model_name_or_path, config_name, token=token, cache_folder=cache_folder)
+                    if config_path is not None:
+                        with open(config_path) as fIn:
+                            kwargs = json.load(fIn)
+                        break
+                if "model_args" in kwargs:
+                    kwargs["model_args"]["token"] = token
+                else:
+                    kwargs["model_args"] = {"token": token}
+                module = Transformer(model_name_or_path, cache_dir=cache_folder, **kwargs)
+            else:
+                module_path = load_dir_path(model_name_or_path, module_config['path'], token=token, cache_folder=cache_folder)
+                module = module_class.load(module_path)
             modules[module_config['name']] = module
 
         return modules
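Note for reviewers: the `path == ""` special case above is driven by the `modules.json` layout that sentence-transformers models use on the Hub. Below is a minimal sketch of that layout and the resulting dispatch; the JSON values are illustrative, modelled on typical repos such as sentence-transformers/all-MiniLM-L6-v2, and the prints stand in for the real loading calls:

```python
# Illustrative modules.json contents for a typical Hub model; only the
# Transformer module sits at the repo root, with an empty "path".
modules_config = [
    {"idx": 0, "name": "0", "path": "", "type": "sentence_transformers.models.Transformer"},
    {"idx": 1, "name": "1", "path": "1_Pooling", "type": "sentence_transformers.models.Pooling"},
]

for module_config in modules_config:
    if module_config["path"] == "":
        # Root-level Transformer: instantiated straight from the repo id, so
        # `transformers` picks the weight file (model.safetensors over
        # pytorch_model.bin when available) and handles caching itself.
        print("Transformer: loaded via transformers from the repo root")
    else:
        # Subfolder modules (Pooling, Normalize, ...) are snapshotted one
        # directory at a time via load_dir_path().
        print("Other module: loaded from subfolder", module_config["path"])
```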
""" - if cache_dir is None: - cache_dir = HUGGINGFACE_HUB_CACHE - if isinstance(cache_dir, Path): - cache_dir = str(cache_dir) - - _api = HfApi() - - token = None - if isinstance(use_auth_token, str): - token = use_auth_token - elif use_auth_token: - token = HfFolder.get_token() - - model_info = _api.model_info(repo_id=repo_id, revision=revision, token=token) - - storage_folder = os.path.join( - cache_dir, repo_id.replace("/", "_") - ) - - all_files = model_info.siblings - #Download modules.json as the last file - for idx, repofile in enumerate(all_files): - if repofile.rfilename == "modules.json": - del all_files[idx] - all_files.append(repofile) - break - - for model_file in all_files: - if ignore_files is not None: - skip_download = False - for pattern in ignore_files: - if fnmatch.fnmatch(model_file.rfilename, pattern): - skip_download = True - break - - if skip_download: - continue - - url = hf_hub_url( - repo_id, filename=model_file.rfilename, revision=model_info.sha - ) - relative_filepath = os.path.join(*model_file.rfilename.split("/")) - - # Create potential nested dir - nested_dirname = os.path.dirname( - os.path.join(storage_folder, relative_filepath) - ) - os.makedirs(nested_dirname, exist_ok=True) - - cached_download_args = {'url': url, - 'cache_dir': storage_folder, - 'force_filename': relative_filepath, - 'library_name': library_name, - 'library_version': library_version, - 'user_agent': user_agent, - 'use_auth_token': use_auth_token} - - if version.parse(huggingface_hub.__version__) >= version.parse("0.8.1"): - # huggingface_hub v0.8.1 introduces a new cache layout. We sill use a manual layout - # And need to pass legacy_cache_layout=True to avoid that a warning will be printed - cached_download_args['legacy_cache_layout'] = True - - path = cached_download(**cached_download_args) - - if os.path.exists(path + ".lock"): - os.remove(path + ".lock") - - return storage_folder + + def __init__(self, *args, **kwargs): + kwargs["disable"] = True + super().__init__(*args, **kwargs) + + def __delattr__(self, attr: str) -> None: + """Fix for https://github.com/huggingface/huggingface_hub/issues/1603""" + try: + super().__delattr__(attr) + except AttributeError: + if attr != "_lock": + raise + + +def is_sentence_transformer_model(model_name_or_path: str, token: Optional[Union[bool, str]] = None, cache_folder: Optional[str] = None) -> bool: + return bool(load_file_path(model_name_or_path, "modules.json", token, cache_folder)) + + +def load_file_path(model_name_or_path: str, filename: str, token: Optional[Union[bool, str]], cache_folder: Optional[str]) -> Optional[str]: + # If file is local + file_path = os.path.join(model_name_or_path, filename) + if os.path.exists(file_path): + return file_path + + # If file is remote + try: + return hf_hub_download(model_name_or_path, filename=filename, library_name="sentence-transformers", token=token, cache_dir=cache_folder) + except Exception: + return + + +def load_dir_path(model_name_or_path: str, directory: str, token: Optional[Union[bool, str]], cache_folder: Optional[str]) -> Optional[str]: + # If file is local + dir_path = os.path.join(model_name_or_path, directory) + if os.path.exists(dir_path): + return dir_path + + download_kwargs = { + "repo_id": model_name_or_path, + "allow_patterns":f"{directory}/**", + "library_name": "sentence-transformers", + "token": token, + "cache_dir": cache_folder, + "tqdm_class": disabled_tqdm, + } + # Try to download from the remote + try: + repo_path = snapshot_download(**download_kwargs) + except Exception: 
diff --git a/tests/test_sentence_transformer.py b/tests/test_sentence_transformer.py
new file mode 100644
index 000000000..c8913fcb7
--- /dev/null
+++ b/tests/test_sentence_transformer.py
@@ -0,0 +1,47 @@
+"""
+Tests general behaviour of the SentenceTransformer class
+"""
+
+from pathlib import Path
+import tempfile
+
+import torch
+from sentence_transformers import SentenceTransformer
+from sentence_transformers.models import Transformer, Pooling
+import unittest
+
+
+class TestSentenceTransformer(unittest.TestCase):
+    def test_load_with_safetensors(self):
+        with tempfile.TemporaryDirectory() as cache_folder:
+            safetensors_model = SentenceTransformer(
+                "sentence-transformers-testing/stsb-bert-tiny-safetensors",
+                cache_folder=cache_folder,
+            )
+
+            # Only the safetensors file must be loaded
+            pytorch_files = list(Path(cache_folder).glob("**/pytorch_model.bin"))
+            self.assertEqual(0, len(pytorch_files), msg="PyTorch model file must not be downloaded.")
+            safetensors_files = list(Path(cache_folder).glob("**/model.safetensors"))
+            self.assertEqual(1, len(safetensors_files), msg="Safetensors model file must be downloaded.")
+
+        with tempfile.TemporaryDirectory() as cache_folder:
+            transformer = Transformer(
+                "sentence-transformers-testing/stsb-bert-tiny-safetensors",
+                cache_dir=cache_folder,
+                model_args={"use_safetensors": False},
+            )
+            pooling = Pooling(transformer.get_word_embedding_dimension())
+            pytorch_model = SentenceTransformer(modules=[transformer, pooling])
+
+            # Only the pytorch file must be loaded
+            pytorch_files = list(Path(cache_folder).glob("**/pytorch_model.bin"))
+            self.assertEqual(1, len(pytorch_files), msg="PyTorch model file must be downloaded.")
+            safetensors_files = list(Path(cache_folder).glob("**/model.safetensors"))
+            self.assertEqual(0, len(safetensors_files), msg="Safetensors model file must not be downloaded.")
+
+        sentences = ["This is a test sentence", "This is another test sentence"]
+        self.assertTrue(
+            torch.equal(safetensors_model.encode(sentences, convert_to_tensor=True), pytorch_model.encode(sentences, convert_to_tensor=True)),
+            msg="Ensure that Safetensors and PyTorch loaded models result in identical embeddings",
+        )
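Finally, a minimal sketch of the user-facing behaviour after this change; the model id is an arbitrary public example, and the private-model ids and token values in the comments are placeholders:

```python
from sentence_transformers import SentenceTransformer

# Public models need no token.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = model.encode(["This is a test sentence"])
print(embeddings.shape)

# Private models: pass `token="hf_..."`, or `token=True` to reuse the token
# stored by `huggingface-cli login`:
# model = SentenceTransformer("my-org/private-model", token=True)

# The old spelling still works but emits a FutureWarning and is forwarded
# to `token`; passing both `token` and `use_auth_token` raises a ValueError:
# model = SentenceTransformer("my-org/private-model", use_auth_token=True)
```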