diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index be33cff0c..a5e696ce0 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -967,6 +967,7 @@ def _load_auto_model(
             model_name_or_path,
             cache_dir=cache_folder,
             model_args={"token": token, "trust_remote_code": trust_remote_code},
+            tokenizer_args={"token": token, "trust_remote_code": trust_remote_code},
         )
         pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), "mean")
         return [transformer_model, pooling_model]
@@ -1038,10 +1039,13 @@ def _load_sbert_model(
                             kwargs = json.load(fIn)
                         break
                 if "model_args" in kwargs:
-                    kwargs["model_args"]["token"] = token
-                    kwargs["model_args"]["trust_remote_code"] = trust_remote_code
+                    kwargs["model_args"].update({"token": token, "trust_remote_code": trust_remote_code})
                 else:
                     kwargs["model_args"] = {"token": token, "trust_remote_code": trust_remote_code}
+                if "tokenizer_args" in kwargs:
+                    kwargs["tokenizer_args"].update({"token": token, "trust_remote_code": trust_remote_code})
+                else:
+                    kwargs["tokenizer_args"] = {"token": token, "trust_remote_code": trust_remote_code}
                 module = Transformer(model_name_or_path, cache_dir=cache_folder, **kwargs)
             else:
                 module_path = load_dir_path(
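
The practical effect of the diff: `token` and `trust_remote_code` are now forwarded to the tokenizer (via `tokenizer_args`) in addition to the model and config (via `model_args`), so repositories with custom tokenizer code or gated tokenizer files load without a separate authentication/remote-code error. A minimal usage sketch, assuming the `token` and `trust_remote_code` keyword arguments exposed on `SentenceTransformer` in this version; the repository id and token below are placeholders, not part of this diff:

from sentence_transformers import SentenceTransformer

# Hypothetical gated repository that ships custom modeling and tokenizer code.
# After this change, `trust_remote_code` and `token` also reach the
# AutoTokenizer.from_pretrained call (through tokenizer_args), not only
# AutoModel/AutoConfig (through model_args).
model = SentenceTransformer(
    "my-org/custom-embedding-model",  # placeholder model id
    trust_remote_code=True,
    token="hf_xxx",  # or authenticate beforehand with `huggingface-cli login`
)

embeddings = model.encode(["The tokenizer now loads with the same credentials."])
print(embeddings.shape)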