Add trust_remote_code parameter to SentenceTransformer (#2398)
Also disallow configs to set trust_remote_code
tomaarsen authored Jan 8, 2024
1 parent fdef4ef commit aac80ac
Showing 2 changed files with 14 additions and 6 deletions.
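For context, a minimal usage sketch of the new parameter; the model name is illustrative and stands in for any Hub repository that ships its own modeling code:

    from sentence_transformers import SentenceTransformer

    # trust_remote_code defaults to False. Set it to True only for repositories
    # you trust and whose code you have read, since that code is executed locally.
    model = SentenceTransformer("example-org/custom-code-model", trust_remote_code=True)
    embeddings = model.encode(["This parameter gates remote code execution."])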
sentence_transformers/SentenceTransformer.py: 11 additions, 6 deletions
@@ -64,12 +64,16 @@ class SentenceTransformer(nn.Sequential):
     :param device: Device (like "cuda", "cpu", "mps") that should be used for computation. If None, checks if a GPU
         can be used.
     :param cache_folder: Path to store models. Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable.
+    :param trust_remote_code: Whether or not to allow for custom models defined on the Hub in their own modeling files.
+        This option should only be set to True for repositories you trust and in which you have read the code, as it
+        will execute code present on the Hub on your local machine.
     :param token: Hugging Face authentication token to download private models.
     """
     def __init__(self, model_name_or_path: Optional[str] = None,
                  modules: Optional[Iterable[nn.Module]] = None,
                  device: Optional[str] = None,
                  cache_folder: Optional[str] = None,
+                 trust_remote_code: bool = False,
                  token: Optional[Union[bool, str]] = None,
                  use_auth_token: Optional[Union[bool, str]] = None,
                  ):
@@ -114,9 +118,9 @@ def __init__(self, model_name_or_path: Optional[str] = None,
                 model_name_or_path = __MODEL_HUB_ORGANIZATION__ + "/" + model_name_or_path
 
             if is_sentence_transformer_model(model_name_or_path, token, cache_folder=cache_folder):
-                modules = self._load_sbert_model(model_name_or_path, token=token, cache_folder=cache_folder)
+                modules = self._load_sbert_model(model_name_or_path, token=token, cache_folder=cache_folder, trust_remote_code=trust_remote_code)
             else:
-                modules = self._load_auto_model(model_name_or_path, token=token, cache_folder=cache_folder)
+                modules = self._load_auto_model(model_name_or_path, token=token, cache_folder=cache_folder, trust_remote_code=trust_remote_code)
 
         if modules is not None and not isinstance(modules, OrderedDict):
             modules = OrderedDict([(str(idx), module) for idx, module in enumerate(modules)])
@@ -806,16 +810,16 @@ def _save_checkpoint(self, checkpoint_path, checkpoint_save_total_limit, step):
                 shutil.rmtree(old_checkpoints[0]['path'])
 
 
-    def _load_auto_model(self, model_name_or_path: str, token: Optional[Union[bool, str]], cache_folder: Optional[str]):
+    def _load_auto_model(self, model_name_or_path: str, token: Optional[Union[bool, str]], cache_folder: Optional[str], trust_remote_code: bool = False):
         """
         Creates a simple Transformer + Mean Pooling model and returns the modules
         """
         logger.warning("No sentence-transformers model found with name {}. Creating a new one with MEAN pooling.".format(model_name_or_path))
-        transformer_model = Transformer(model_name_or_path, cache_dir=cache_folder, model_args={"token": token})
+        transformer_model = Transformer(model_name_or_path, cache_dir=cache_folder, model_args={"token": token, "trust_remote_code": trust_remote_code})
         pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), 'mean')
         return [transformer_model, pooling_model]
 
-    def _load_sbert_model(self, model_name_or_path: str, token: Optional[Union[bool, str]], cache_folder: Optional[str]):
+    def _load_sbert_model(self, model_name_or_path: str, token: Optional[Union[bool, str]], cache_folder: Optional[str], trust_remote_code: bool = False):
         """
         Loads a full sentence-transformers model
         """
@@ -857,8 +861,9 @@ def _load_sbert_model(self, model_name_or_path: str, token: Optional[Union[bool,
                         break
                 if "model_args" in kwargs:
                     kwargs["model_args"]["token"] = token
+                    kwargs["model_args"]["trust_remote_code"] = trust_remote_code
                 else:
-                    kwargs["model_args"] = {"token": token}
+                    kwargs["model_args"] = {"token": token, "trust_remote_code": trust_remote_code}
                 module = Transformer(model_name_or_path, cache_dir=cache_folder, **kwargs)
             else:
                 module_path = load_dir_path(model_name_or_path, module_config['path'], token=token, cache_folder=cache_folder)
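Both loaders thread trust_remote_code through the module's model_args dict, which the Transformer module forwards to the underlying transformers from_pretrained calls. A simplified sketch of that mechanism, not the exact Transformer internals:

    from transformers import AutoConfig, AutoModel

    def load_backbone(model_name_or_path, cache_dir=None, model_args=None):
        # model_args may now carry "token" and "trust_remote_code"; both are
        # standard from_pretrained keyword arguments in recent transformers.
        model_args = model_args or {}
        config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir, **model_args)
        return AutoModel.from_pretrained(model_name_or_path, config=config, cache_dir=cache_dir, **model_args)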
sentence_transformers/models/Transformer.py: 3 additions, 0 deletions
@@ -142,6 +142,9 @@ def load(input_path: str):
 
         with open(sbert_config_path) as fIn:
             config = json.load(fIn)
+        # Don't allow configs to set trust_remote_code
+        if "model_args" in config:
+            config["model_args"].pop("trust_remote_code")
         return Transformer(model_name_or_path=input_path, **config)


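One caveat on the hunk above: dict.pop without a default raises KeyError when the key is absent, so a saved config that contains a model_args dict without trust_remote_code would fail to load. A defensive variant (illustrative, not part of this commit) supplies a default:

    # Drop trust_remote_code if present; do nothing when it is absent.
    if "model_args" in config:
        config["model_args"].pop("trust_remote_code", None)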
