Skip to content

Commit

Permalink
Correctly resolve trust_remote_code=None for AutoTokenizer
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 committed Jan 9, 2024
1 parent 87a6cf4 commit 9338dd9
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/transformers/models/auto/tokenization_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,14 @@ def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
tokenizer_auto_map = config.auto_map["AutoTokenizer"]

has_remote_code = tokenizer_auto_map is not None
has_local_code = config_tokenizer_class is not None or type(config) in TOKENIZER_MAPPING
has_local_code = False
if type(config) in TOKENIZER_MAPPING:
has_local_code = True
elif config_tokenizer_class is not None:
for module_name, tokenizers in TOKENIZER_MAPPING_NAMES.items():
if config_tokenizer_class in tokenizers:
has_local_code = True
break
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_local_code, has_remote_code
)
Expand Down

0 comments on commit 9338dd9

Please sign in to comment.