fix deprecated tokenizer methods, add test cases
mshannon-sil committed Nov 6, 2024
1 parent 3a9df2c commit 1fe806e
Showing 3 changed files with 150 additions and 35 deletions.
17 changes: 12 additions & 5 deletions machine/translation/huggingface/hugging_face_nmt_engine.py
@@ -4,14 +4,15 @@
import logging
import re
from math import exp, prod
from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union, cast
from typing import Iterable, List, Optional, Sequence, Tuple, Union, cast

import torch # pyright: ignore[reportMissingImports]
from sacremoses import MosesPunctNormalizer
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer,
M2M100Tokenizer,
NllbTokenizer,
NllbTokenizerFast,
PreTrainedModel,
@@ -73,17 +74,23 @@ def __init__(
self._pipeline_kwargs["prefix"] = f"translate {src_lang} to {tgt_lang}: "
else:
additional_special_tokens = self._tokenizer.additional_special_tokens
if isinstance(self._tokenizer, M2M100Tokenizer):
src_lang_token = self._tokenizer.lang_code_to_token.get(src_lang) if src_lang is not None else None
tgt_lang_token = self._tokenizer.lang_code_to_token.get(tgt_lang) if tgt_lang is not None else None
else:
src_lang_token = src_lang
tgt_lang_token = tgt_lang
if (
src_lang is not None
and src_lang not in cast(Any, self._tokenizer).lang_code_to_id
and src_lang not in additional_special_tokens
and src_lang_token not in self._tokenizer.added_tokens_encoder
and src_lang_token not in additional_special_tokens
):
raise ValueError(f"The specified model does not support the language code '{src_lang}'")

if (
tgt_lang is not None
and tgt_lang not in cast(Any, self._tokenizer).lang_code_to_id
and tgt_lang not in additional_special_tokens
and tgt_lang_token not in self._tokenizer.added_tokens_encoder
and tgt_lang_token not in additional_special_tokens
):
raise ValueError(f"The specified model does not support the language code '{tgt_lang}'")

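For context, the engine now treats a language code as supported if its token is either in the tokenizer's added_tokens_encoder or among its additional_special_tokens, instead of probing the deprecated lang_code_to_id mapping; for M2M100 the raw code is first mapped to its "__code__" token. Below is a minimal standalone sketch of that check, assuming the NLLB checkpoint used in the tests (the supports_lang_code helper is illustrative, not part of the library):

from transformers import AutoTokenizer, M2M100Tokenizer


def supports_lang_code(tokenizer, lang_code: str) -> bool:
    # M2M100 registers language codes as "__xx__" tokens, so map the raw code first.
    if isinstance(tokenizer, M2M100Tokenizer):
        lang_token = tokenizer.lang_code_to_token.get(lang_code)
    else:
        lang_token = lang_code
    # Supported if the token was added to the vocab or is a registered special token.
    return (
        lang_token in tokenizer.added_tokens_encoder
        or lang_token in tokenizer.additional_special_tokens
    )


tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
print(supports_lang_code(tokenizer, "eng_Latn"))  # expected True for the stock NLLB vocab
print(supports_lang_code(tokenizer, "xx_XX"))     # expected False
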
49 changes: 29 additions & 20 deletions machine/translation/huggingface/hugging_face_nmt_model_trainer.py
@@ -24,6 +24,7 @@
NllbTokenizer,
NllbTokenizerFast,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
Seq2SeqTrainer,
Seq2SeqTrainingArguments,
@@ -218,30 +219,12 @@ def add_tokens(tokenizer: Any, missing_tokens: List[str]) -> Any:
if missing_tokens:
tokenizer = add_tokens(tokenizer, missing_tokens)

def add_lang_code_to_tokenizer(tokenizer: Any, lang_code: str):
if lang_code in tokenizer.lang_code_to_id:
return
tokenizer.add_special_tokens(
{"additional_special_tokens": tokenizer.additional_special_tokens + [lang_code]}
)
lang_id = tokenizer.convert_tokens_to_ids(lang_code)
tokenizer.lang_code_to_id[lang_code] = lang_id

if isinstance(tokenizer, (MBart50Tokenizer, MBartTokenizer)):
tokenizer.id_to_lang_code[lang_id] = lang_code
tokenizer.fairseq_tokens_to_ids[lang_code] = lang_id
tokenizer.fairseq_ids_to_tokens[lang_id] = lang_code
elif isinstance(tokenizer, M2M100Tokenizer):
tokenizer.lang_code_to_token[lang_code] = lang_code
tokenizer.lang_token_to_id[lang_code] = lang_id
tokenizer.id_to_lang_token[lang_id] = lang_code

if isinstance(tokenizer, MULTILINGUAL_TOKENIZERS):
logger.info("Add new language codes as tokens")
if self._src_lang is not None:
add_lang_code_to_tokenizer(tokenizer, self._src_lang)
_add_lang_code_to_tokenizer(tokenizer, self._src_lang)
if self._tgt_lang is not None:
add_lang_code_to_tokenizer(tokenizer, self._tgt_lang)
_add_lang_code_to_tokenizer(tokenizer, self._tgt_lang)

# We resize the embeddings only when necessary to avoid index errors.
embedding_size = cast(Any, model.get_input_embeddings()).weight.shape[0]
@@ -413,3 +396,29 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
if self._max_steps is None
else ProgressStatus.from_step(state.global_step, self._max_steps)
)


def _add_lang_code_to_tokenizer(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], lang_code: str):
if isinstance(tokenizer, M2M100Tokenizer):
lang_token = "__" + lang_code + "__"
else:
lang_token = lang_code

if lang_token in tokenizer.added_tokens_encoder:
return

tokenizer.add_special_tokens(
{"additional_special_tokens": tokenizer.additional_special_tokens + [lang_token]} # type: ignore
)
lang_id = cast(int, tokenizer.convert_tokens_to_ids(lang_token))

if isinstance(tokenizer, (MBart50Tokenizer, MBartTokenizer)):
tokenizer.lang_code_to_id[lang_code] = lang_id
tokenizer.id_to_lang_code[lang_id] = lang_code
tokenizer.fairseq_tokens_to_ids[lang_code] = lang_id
tokenizer.fairseq_ids_to_tokens[lang_id] = lang_code
elif isinstance(tokenizer, M2M100Tokenizer):
tokenizer.lang_code_to_id[lang_code] = lang_id
tokenizer.lang_code_to_token[lang_code] = lang_token
tokenizer.lang_token_to_id[lang_token] = lang_id
tokenizer.id_to_lang_token[lang_id] = lang_token
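In practice the helper is paired with the embedding resize mentioned earlier in the trainer: after a new language code is registered, the tokenizer's vocabulary can outgrow the model's embedding matrix. A rough usage sketch, reusing the NLLB checkpoint from the tests (the surrounding glue is illustrative, not the trainer's exact flow):

from transformers import AutoModelForSeq2SeqLM, NllbTokenizer

from machine.translation.huggingface.hugging_face_nmt_model_trainer import _add_lang_code_to_tokenizer

tokenizer = NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

# Register a new language code as an additional special token.
_add_lang_code_to_tokenizer(tokenizer, "new_lang")

# Resize the embeddings only when the vocab actually outgrew them, as the trainer does.
embedding_size = model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    model.resize_token_embeddings(len(tokenizer))
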
119 changes: 109 additions & 10 deletions tests/translation/huggingface/test_hugging_face_nmt_model_trainer.py
@@ -6,11 +6,23 @@
skip("skipping Hugging Face tests on MacOS", allow_module_level=True)

from tempfile import TemporaryDirectory

from transformers import PreTrainedTokenizerFast, Seq2SeqTrainingArguments
from typing import cast

from transformers import (
M2M100Tokenizer,
MBart50Tokenizer,
MBart50TokenizerFast,
MBartTokenizer,
MBartTokenizerFast,
NllbTokenizer,
NllbTokenizerFast,
PreTrainedTokenizerFast,
Seq2SeqTrainingArguments,
)

from machine.corpora import DictionaryTextCorpus, MemoryText, TextRow
from machine.translation.huggingface import HuggingFaceNmtEngine, HuggingFaceNmtModelTrainer
from machine.translation.huggingface.hugging_face_nmt_model_trainer import _add_lang_code_to_tokenizer


def test_train_non_empty_corpus() -> None:
@@ -142,10 +154,8 @@ def test_update_tokenizer_missing_char() -> None:
"Ḻ, ḻ, Ṉ, ॽ, " + "‌ and " + "‍" + " are new characters"
)
finetuned_result_nochar_composite = finetuned_engine_nochar.tokenizer.encode("Ḏ is a composite character")
normalized_result_nochar1 = finetuned_engine_nochar.tokenizer.backend_tokenizer.normalizer.normalize_str(
"‌ "
)
normalized_result_nochar2 = finetuned_engine_nochar.tokenizer.backend_tokenizer.normalizer.normalize_str("‍")
norm_result_nochar1 = finetuned_engine_nochar.tokenizer.backend_tokenizer.normalizer.normalize_str("‌ ")
norm_result_nochar2 = finetuned_engine_nochar.tokenizer.backend_tokenizer.normalizer.normalize_str("‍")

with HuggingFaceNmtModelTrainer(
"hf-internal-testing/tiny-random-nllb",
@@ -167,11 +177,11 @@
"Ḻ, ḻ, Ṉ, ॽ, " + "‌ and " + "‍" + " are new characters"
)
finetuned_result_char_composite = finetuned_engine_char.tokenizer.encode("Ḏ is a composite character")
normalized_result_char1 = finetuned_engine_char.tokenizer.backend_tokenizer.normalizer.normalize_str("‌ ")
normalized_result_char2 = finetuned_engine_char.tokenizer.backend_tokenizer.normalizer.normalize_str("‍")
norm_result_char1 = finetuned_engine_char.tokenizer.backend_tokenizer.normalizer.normalize_str("‌ ")
norm_result_char2 = finetuned_engine_char.tokenizer.backend_tokenizer.normalizer.normalize_str("‍")

assert normalized_result_nochar1 != normalized_result_char1
assert normalized_result_nochar2 != normalized_result_char2
assert norm_result_nochar1 != norm_result_char1
assert norm_result_nochar2 != norm_result_char2

assert finetuned_result_nochar != finetuned_result_char
assert finetuned_result_nochar_composite != finetuned_result_char_composite
@@ -467,5 +477,94 @@ def test_update_tokenizer_no_missing_char() -> None:
assert finetuned_result_nochar == finetuned_result_char


def test_nllb_tokenizer_add_lang_code() -> None:
with TemporaryDirectory() as temp_dir:
tokenizer = cast(NllbTokenizer, NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M"))
assert "new_lang" not in tokenizer.added_tokens_encoder
_add_lang_code_to_tokenizer(tokenizer, "new_lang")
assert "new_lang" in tokenizer.added_tokens_encoder
tokenizer.save_pretrained(temp_dir)
new_tokenizer = cast(NllbTokenizer, NllbTokenizer.from_pretrained(temp_dir))
assert "new_lang" in new_tokenizer.added_tokens_encoder
return


def test_nllb_tokenizer_fast_add_lang_code() -> None:
with TemporaryDirectory() as temp_dir:
tokenizer = cast(NllbTokenizerFast, NllbTokenizerFast.from_pretrained("facebook/nllb-200-distilled-600M"))
assert "new_lang" not in tokenizer.added_tokens_encoder
_add_lang_code_to_tokenizer(tokenizer, "new_lang")
assert "new_lang" in tokenizer.added_tokens_encoder
tokenizer.save_pretrained(temp_dir)
new_tokenizer = cast(NllbTokenizerFast, NllbTokenizerFast.from_pretrained(temp_dir))
assert "new_lang" in new_tokenizer.added_tokens_encoder
return


def test_mbart_tokenizer_add_lang_code() -> None:
with TemporaryDirectory() as temp_dir:
tokenizer = cast(MBartTokenizer, MBartTokenizer.from_pretrained("hf-internal-testing/tiny-random-nllb"))
assert "nl_NS" not in tokenizer.added_tokens_encoder
_add_lang_code_to_tokenizer(tokenizer, "nl_NS")
assert "nl_NS" in tokenizer.added_tokens_encoder
tokenizer.save_pretrained(temp_dir)
new_tokenizer = cast(MBartTokenizer, MBartTokenizer.from_pretrained(temp_dir))
assert "nl_NS" in new_tokenizer.added_tokens_encoder
return


def test_mbart_tokenizer_fast_add_lang_code() -> None:
with TemporaryDirectory() as temp_dir:
tokenizer = cast(MBartTokenizerFast, MBartTokenizerFast.from_pretrained("hf-internal-testing/tiny-random-nllb"))
assert "nl_NS" not in tokenizer.added_tokens_encoder
_add_lang_code_to_tokenizer(tokenizer, "nl_NS")
assert "nl_NS" in tokenizer.added_tokens_encoder
tokenizer.save_pretrained(temp_dir)
new_tokenizer = cast(MBartTokenizerFast, MBartTokenizerFast.from_pretrained(temp_dir))
assert "nl_NS" in new_tokenizer.added_tokens_encoder
return


def test_mbart_50_tokenizer_add_lang_code() -> None:
with TemporaryDirectory() as temp_dir:
tokenizer = cast(MBart50Tokenizer, MBart50Tokenizer.from_pretrained("hf-internal-testing/tiny-random-mbart50"))
assert "nl_NS" not in tokenizer.added_tokens_encoder
_add_lang_code_to_tokenizer(tokenizer, "nl_NS")
assert "nl_NS" in tokenizer.added_tokens_encoder
tokenizer.save_pretrained(temp_dir)
new_tokenizer = cast(MBart50Tokenizer, MBart50Tokenizer.from_pretrained(temp_dir))
assert "nl_NS" in new_tokenizer.added_tokens_encoder
return


def test_mbart_50_tokenizer_fast_add_lang_code() -> None:
with TemporaryDirectory() as temp_dir:
tokenizer = cast(
MBart50TokenizerFast, MBart50TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-mbart50")
)
assert "nl_NS" not in tokenizer.added_tokens_encoder
_add_lang_code_to_tokenizer(tokenizer, "nl_NS")
assert "nl_NS" in tokenizer.added_tokens_encoder
tokenizer.save_pretrained(temp_dir)
new_tokenizer = cast(MBart50TokenizerFast, MBart50TokenizerFast.from_pretrained(temp_dir))
assert "nl_NS" in new_tokenizer.added_tokens_encoder
return


def test_m2m_100_tokenizer_add_lang_code() -> None:
with TemporaryDirectory() as temp_dir:
tokenizer = cast(M2M100Tokenizer, M2M100Tokenizer.from_pretrained("stas/tiny-m2m_100"))
assert "nc" not in tokenizer.lang_code_to_id
assert "__nc__" not in tokenizer.added_tokens_encoder
_add_lang_code_to_tokenizer(tokenizer, "nc")
assert "nc" in tokenizer.lang_code_to_id
assert "__nc__" in tokenizer.added_tokens_encoder
tokenizer.save_pretrained(temp_dir)
new_tokenizer = cast(M2M100Tokenizer, M2M100Tokenizer.from_pretrained(temp_dir))
assert "nc" in tokenizer.lang_code_to_id
assert "__nc__" in new_tokenizer.added_tokens_encoder
return


def _row(row_ref: int, text: str) -> TextRow:
return TextRow("text1", row_ref, segment=[text])
