Skip to content

Commit

Permalink
Updated transformers to 4.45.2
Browse files Browse the repository at this point in the history
  • Loading branch information
TaperChipmunk32 committed Oct 23, 2024
1 parent 1415d12 commit 3da797f
Show file tree
Hide file tree
Showing 3 changed files with 225 additions and 207 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def add_lang_code_to_tokenizer(tokenizer: Any, lang_code: str):
lang_id = tokenizer.convert_tokens_to_ids(lang_code)
tokenizer.lang_code_to_id[lang_code] = lang_id

if isinstance(tokenizer, (NllbTokenizer, MBart50Tokenizer, MBartTokenizer)):
if isinstance(tokenizer, (MBart50Tokenizer, MBartTokenizer)):
tokenizer.id_to_lang_code[lang_id] = lang_code
tokenizer.fairseq_tokens_to_ids[lang_code] = lang_id
tokenizer.fairseq_ids_to_tokens[lang_id] = lang_code
Expand Down Expand Up @@ -276,7 +276,7 @@ def add_lang_code_to_tokenizer(tokenizer: Any, lang_code: str):

# For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
# as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument.
forced_bos_token_id = tokenizer.lang_code_to_id[self._tgt_lang]
forced_bos_token_id = tokenizer.convert_tokens_to_ids(self._tgt_lang)
model.config.forced_bos_token_id = forced_bos_token_id
if model.generation_config is not None:
model.generation_config.forced_bos_token_id = forced_bos_token_id
Expand Down
Loading

0 comments on commit 3da797f

Please sign in to comment.