Skip to content

Commit

Permalink
Use weights_only=True with torch.load for transfo_xl (huggingfa…
Browse files Browse the repository at this point in the history
…ce#35241)

fix

Co-authored-by: ydshieh <[email protected]>
  • Loading branch information
ydshieh and ydshieh authored Dec 20, 2024
1 parent 6fae2a8 commit 0fc2970
Showing 1 changed file with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def __init__(
"from a PyTorch pretrained vocabulary, "
"or activate it with environment variables USE_TORCH=1 and USE_TF=0."
)
vocab_dict = torch.load(pretrained_vocab_file)
vocab_dict = torch.load(pretrained_vocab_file, weights_only=True)

if vocab_dict is not None:
for key, value in vocab_dict.items():
Expand Down Expand Up @@ -705,7 +705,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs,

# Instantiate tokenizer.
corpus = cls(*inputs, **kwargs)
corpus_dict = torch.load(resolved_corpus_file)
corpus_dict = torch.load(resolved_corpus_file, weights_only=True)
for key, value in corpus_dict.items():
corpus.__dict__[key] = value
corpus.vocab = vocab
Expand Down Expand Up @@ -784,7 +784,7 @@ def get_lm_corpus(datadir, dataset):
fn_pickle = os.path.join(datadir, "cache.pkl")
if os.path.exists(fn):
logger.info("Loading cached dataset...")
corpus = torch.load(fn_pickle)
corpus = torch.load(fn_pickle, weights_only=True)
elif os.path.exists(fn):
logger.info("Loading cached dataset from pickle...")
if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
Expand Down

0 comments on commit 0fc2970

Please sign in to comment.