Skip to content

Commit

Permalink
Add revision to load a specific model version
Browse files Browse the repository at this point in the history
  • Loading branch information
tomaarsen committed Jan 17, 2024
1 parent 4b00a34 commit dd41547
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 18 deletions.
54 changes: 40 additions & 14 deletions sentence_transformers/SentenceTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class SentenceTransformer(nn.Sequential):
:param device: Device (like "cuda", "cpu", "mps") that should be used for computation. If None, checks if a GPU
can be used.
:param cache_folder: Path to store models. Can also be set by the SENTENCE_TRANSFORMERS_HOME environment variable.
:param revision: The specific model version to use. It can be a branch name, a tag name, or a commit id,
for a stored model on Hugging Face.
:param trust_remote_code: Whether or not to allow for custom models defined on the Hub in their own modeling files.
This option should only be set to True for repositories you trust and in which you have read the code, as it
will execute code present on the Hub on your local machine.
Expand All @@ -83,6 +85,7 @@ def __init__(
device: Optional[str] = None,
cache_folder: Optional[str] = None,
trust_remote_code: bool = False,
revision: Optional[str] = None,
token: Optional[Union[bool, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
):
Expand Down Expand Up @@ -187,13 +190,21 @@ def __init__(
# A model from sentence-transformers
model_name_or_path = __MODEL_HUB_ORGANIZATION__ + "/" + model_name_or_path

if is_sentence_transformer_model(model_name_or_path, token, cache_folder=cache_folder):
if is_sentence_transformer_model(model_name_or_path, token, cache_folder=cache_folder, revision=revision):
modules = self._load_sbert_model(
model_name_or_path, token=token, cache_folder=cache_folder, trust_remote_code=trust_remote_code
model_name_or_path,
token=token,
cache_folder=cache_folder,
revision=revision,
trust_remote_code=trust_remote_code,
)
else:
modules = self._load_auto_model(
model_name_or_path, token=token, cache_folder=cache_folder, trust_remote_code=trust_remote_code
model_name_or_path,
token=token,
cache_folder=cache_folder,
revision=revision,
trust_remote_code=trust_remote_code,
)

if modules is not None and not isinstance(modules, OrderedDict):
Expand Down Expand Up @@ -942,6 +953,7 @@ def _load_auto_model(
model_name_or_path: str,
token: Optional[Union[bool, str]],
cache_folder: Optional[str],
revision: Optional[str] = None,
trust_remote_code: bool = False,
):
"""
Expand All @@ -955,8 +967,8 @@ def _load_auto_model(
transformer_model = Transformer(
model_name_or_path,
cache_dir=cache_folder,
model_args={"token": token, "trust_remote_code": trust_remote_code},
tokenizer_args={"token": token, "trust_remote_code": trust_remote_code},
model_args={"token": token, "trust_remote_code": trust_remote_code, "revision": revision},
tokenizer_args={"token": token, "trust_remote_code": trust_remote_code, "revision": revision},
)
pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), "mean")
return [transformer_model, pooling_model]
Expand All @@ -966,14 +978,19 @@ def _load_sbert_model(
model_name_or_path: str,
token: Optional[Union[bool, str]],
cache_folder: Optional[str],
revision: Optional[str] = None,
trust_remote_code: bool = False,
):
"""
Loads a full sentence-transformers model
"""
# Check if the config_sentence_transformers.json file exists (exists since v2 of the framework)
config_sentence_transformers_json_path = load_file_path(
model_name_or_path, "config_sentence_transformers.json", token=token, cache_folder=cache_folder
model_name_or_path,
"config_sentence_transformers.json",
token=token,
cache_folder=cache_folder,
revision=revision,
)
if config_sentence_transformers_json_path is not None:
with open(config_sentence_transformers_json_path) as fIn:
Expand All @@ -991,7 +1008,9 @@ def _load_sbert_model(
)

# Check if a readme exists
model_card_path = load_file_path(model_name_or_path, "README.md", token=token, cache_folder=cache_folder)
model_card_path = load_file_path(
model_name_or_path, "README.md", token=token, cache_folder=cache_folder, revision=revision
)
if model_card_path is not None:
try:
with open(model_card_path, encoding="utf8") as fIn:
Expand All @@ -1000,7 +1019,9 @@ def _load_sbert_model(
pass

# Load the modules of sentence transformer
modules_json_path = load_file_path(model_name_or_path, "modules.json", token=token, cache_folder=cache_folder)
modules_json_path = load_file_path(
model_name_or_path, "modules.json", token=token, cache_folder=cache_folder, revision=revision
)
with open(modules_json_path) as fIn:
modules_config = json.load(fIn)

Expand All @@ -1021,24 +1042,29 @@ def _load_sbert_model(
"sentence_xlnet_config.json",
]:
config_path = load_file_path(
model_name_or_path, config_name, token=token, cache_folder=cache_folder
model_name_or_path, config_name, token=token, cache_folder=cache_folder, revision=revision
)
if config_path is not None:
with open(config_path) as fIn:
kwargs = json.load(fIn)
break
hub_kwargs = {"token": token, "trust_remote_code": trust_remote_code, "revision": revision}
if "model_args" in kwargs:
kwargs["model_args"].update({"token": token, "trust_remote_code": trust_remote_code})
kwargs["model_args"].update(hub_kwargs)
else:
kwargs["model_args"] = {"token": token, "trust_remote_code": trust_remote_code}
kwargs["model_args"] = hub_kwargs
if "tokenizer_args" in kwargs:
kwargs["tokenizer_args"].update({"token": token, "trust_remote_code": trust_remote_code})
kwargs["tokenizer_args"].update(hub_kwargs)
else:
kwargs["tokenizer_args"] = {"token": token, "trust_remote_code": trust_remote_code}
kwargs["tokenizer_args"] = hub_kwargs
module = Transformer(model_name_or_path, cache_dir=cache_folder, **kwargs)
else:
module_path = load_dir_path(
model_name_or_path, module_config["path"], token=token, cache_folder=cache_folder
model_name_or_path,
module_config["path"],
token=token,
cache_folder=cache_folder,
revision=revision,
)
module = module_class.load(module_path)
modules[module_config["name"]] = module
Expand Down
21 changes: 17 additions & 4 deletions sentence_transformers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,13 +473,20 @@ def __delattr__(self, attr: str) -> None:


def is_sentence_transformer_model(
model_name_or_path: str, token: Optional[Union[bool, str]] = None, cache_folder: Optional[str] = None
model_name_or_path: str,
token: Optional[Union[bool, str]] = None,
cache_folder: Optional[str] = None,
revision: Optional[str] = None,
) -> bool:
return bool(load_file_path(model_name_or_path, "modules.json", token, cache_folder))
return bool(load_file_path(model_name_or_path, "modules.json", token, cache_folder, revision=revision))


def load_file_path(
model_name_or_path: str, filename: str, token: Optional[Union[bool, str]], cache_folder: Optional[str]
model_name_or_path: str,
filename: str,
token: Optional[Union[bool, str]],
cache_folder: Optional[str],
revision: Optional[str] = None,
) -> Optional[str]:
# If file is local
file_path = os.path.join(model_name_or_path, filename)
Expand All @@ -491,6 +498,7 @@ def load_file_path(
return hf_hub_download(
model_name_or_path,
filename=filename,
revision=revision,
library_name="sentence-transformers",
token=token,
cache_dir=cache_folder,
Expand All @@ -500,7 +508,11 @@ def load_file_path(


def load_dir_path(
model_name_or_path: str, directory: str, token: Optional[Union[bool, str]], cache_folder: Optional[str]
model_name_or_path: str,
directory: str,
token: Optional[Union[bool, str]],
cache_folder: Optional[str],
revision: Optional[str] = None,
) -> Optional[str]:
# If file is local
dir_path = os.path.join(model_name_or_path, directory)
Expand All @@ -509,6 +521,7 @@ def load_dir_path(

download_kwargs = {
"repo_id": model_name_or_path,
"revision": revision,
"allow_patterns": f"{directory}/**",
"library_name": "sentence-transformers",
"token": token,
Expand Down
15 changes: 15 additions & 0 deletions tests/test_sentence_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,18 @@ def mock_list_repo_refs(self, repo_id=None, **kwargs):
caplog.record_tuples[0][2]
== 'Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id="sentence-transformers-testing/stsb-bert-tiny-safetensors"` instead.'
)


def test_load_with_revision() -> None:
main_model = SentenceTransformer("sentence-transformers-testing/stsb-bert-tiny-safetensors", revision="main")
latest_model = SentenceTransformer(
"sentence-transformers-testing/stsb-bert-tiny-safetensors", revision="f3cb857cba53019a20df283396bcca179cf051a4"
)
older_model = SentenceTransformer(
"sentence-transformers-testing/stsb-bert-tiny-safetensors", revision="ba33022fdf0b0fc2643263f0726f44d0a07d0e24"
)

test_sentence = ["Hello there!"]
main_embeddings = main_model.encode(test_sentence, convert_to_tensor=True)
assert torch.equal(main_embeddings, latest_model.encode(test_sentence, convert_to_tensor=True))
assert not torch.equal(main_embeddings, older_model.encode(test_sentence, convert_to_tensor=True))

0 comments on commit dd41547

Please sign in to comment.