diff --git a/machine/translation/huggingface/__init__.py b/machine/translation/huggingface/__init__.py index c4c5ceb..3342eaf 100644 --- a/machine/translation/huggingface/__init__.py +++ b/machine/translation/huggingface/__init__.py @@ -8,6 +8,6 @@ from .hugging_face_nmt_engine import HuggingFaceNmtEngine from .hugging_face_nmt_model import HuggingFaceNmtModel -from .hugging_face_nmt_model_trainer import HuggingFaceNmtModelTrainer +from .hugging_face_nmt_model_trainer import HuggingFaceNmtModelTrainer, add_lang_code_to_tokenizer -__all__ = ["HuggingFaceNmtEngine", "HuggingFaceNmtModel", "HuggingFaceNmtModelTrainer"] +__all__ = ["add_lang_code_to_tokenizer", "HuggingFaceNmtEngine", "HuggingFaceNmtModel", "HuggingFaceNmtModelTrainer"] diff --git a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py index 2c80e1d..f15ba4e 100644 --- a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py +++ b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py @@ -222,9 +222,9 @@ def add_tokens(tokenizer: Any, missing_tokens: List[str]) -> Any: if isinstance(tokenizer, MULTILINGUAL_TOKENIZERS): logger.info("Add new language codes as tokens") if self._src_lang is not None: - _add_lang_code_to_tokenizer(tokenizer, self._src_lang) + add_lang_code_to_tokenizer(tokenizer, self._src_lang) if self._tgt_lang is not None: - _add_lang_code_to_tokenizer(tokenizer, self._tgt_lang) + add_lang_code_to_tokenizer(tokenizer, self._tgt_lang) # We resize the embeddings only when necessary to avoid index errors. embedding_size = cast(Any, model.get_input_embeddings()).weight.shape[0] @@ -398,7 +398,7 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra ) -def _add_lang_code_to_tokenizer(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], lang_code: str): +def add_lang_code_to_tokenizer(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], lang_code: str): if isinstance(tokenizer, M2M100Tokenizer): lang_token = "__" + lang_code + "__" else: diff --git a/tests/translation/huggingface/test_hugging_face_nmt_model_trainer.py b/tests/translation/huggingface/test_hugging_face_nmt_model_trainer.py index 068a8c6..4094616 100644 --- a/tests/translation/huggingface/test_hugging_face_nmt_model_trainer.py +++ b/tests/translation/huggingface/test_hugging_face_nmt_model_trainer.py @@ -22,7 +22,7 @@ from machine.corpora import DictionaryTextCorpus, MemoryText, TextRow from machine.translation.huggingface import HuggingFaceNmtEngine, HuggingFaceNmtModelTrainer -from machine.translation.huggingface.hugging_face_nmt_model_trainer import _add_lang_code_to_tokenizer +from machine.translation.huggingface.hugging_face_nmt_model_trainer import add_lang_code_to_tokenizer def test_train_non_empty_corpus() -> None: @@ -481,7 +481,7 @@ def test_nllb_tokenizer_add_lang_code() -> None: with TemporaryDirectory() as temp_dir: tokenizer = cast(NllbTokenizer, NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")) assert "new_lang" not in tokenizer.added_tokens_encoder - _add_lang_code_to_tokenizer(tokenizer, "new_lang") + add_lang_code_to_tokenizer(tokenizer, "new_lang") assert "new_lang" in tokenizer.added_tokens_encoder tokenizer.save_pretrained(temp_dir) new_tokenizer = cast(NllbTokenizer, NllbTokenizer.from_pretrained(temp_dir)) @@ -493,7 +493,7 @@ def test_nllb_tokenizer_fast_add_lang_code() -> None: with TemporaryDirectory() as temp_dir: tokenizer = cast(NllbTokenizerFast, NllbTokenizerFast.from_pretrained("facebook/nllb-200-distilled-600M")) assert "new_lang" not in tokenizer.added_tokens_encoder - _add_lang_code_to_tokenizer(tokenizer, "new_lang") + add_lang_code_to_tokenizer(tokenizer, "new_lang") assert "new_lang" in tokenizer.added_tokens_encoder tokenizer.save_pretrained(temp_dir) new_tokenizer = cast(NllbTokenizerFast, NllbTokenizerFast.from_pretrained(temp_dir)) @@ -505,7 +505,7 @@ def test_mbart_tokenizer_add_lang_code() -> None: with TemporaryDirectory() as temp_dir: tokenizer = cast(MBartTokenizer, MBartTokenizer.from_pretrained("hf-internal-testing/tiny-random-nllb")) assert "nl_NS" not in tokenizer.added_tokens_encoder - _add_lang_code_to_tokenizer(tokenizer, "nl_NS") + add_lang_code_to_tokenizer(tokenizer, "nl_NS") assert "nl_NS" in tokenizer.added_tokens_encoder tokenizer.save_pretrained(temp_dir) new_tokenizer = cast(MBartTokenizer, MBartTokenizer.from_pretrained(temp_dir)) @@ -517,7 +517,7 @@ def test_mbart_tokenizer_fast_add_lang_code() -> None: with TemporaryDirectory() as temp_dir: tokenizer = cast(MBartTokenizerFast, MBartTokenizerFast.from_pretrained("hf-internal-testing/tiny-random-nllb")) assert "nl_NS" not in tokenizer.added_tokens_encoder - _add_lang_code_to_tokenizer(tokenizer, "nl_NS") + add_lang_code_to_tokenizer(tokenizer, "nl_NS") assert "nl_NS" in tokenizer.added_tokens_encoder tokenizer.save_pretrained(temp_dir) new_tokenizer = cast(MBartTokenizerFast, MBartTokenizerFast.from_pretrained(temp_dir)) @@ -529,7 +529,7 @@ def test_mbart_50_tokenizer_add_lang_code() -> None: with TemporaryDirectory() as temp_dir: tokenizer = cast(MBart50Tokenizer, MBart50Tokenizer.from_pretrained("hf-internal-testing/tiny-random-mbart50")) assert "nl_NS" not in tokenizer.added_tokens_encoder - _add_lang_code_to_tokenizer(tokenizer, "nl_NS") + add_lang_code_to_tokenizer(tokenizer, "nl_NS") assert "nl_NS" in tokenizer.added_tokens_encoder tokenizer.save_pretrained(temp_dir) new_tokenizer = cast(MBart50Tokenizer, MBart50Tokenizer.from_pretrained(temp_dir)) @@ -543,7 +543,7 @@ def test_mbart_50_tokenizer_fast_add_lang_code() -> None: MBart50TokenizerFast, MBart50TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-mbart50") ) assert "nl_NS" not in tokenizer.added_tokens_encoder - _add_lang_code_to_tokenizer(tokenizer, "nl_NS") + add_lang_code_to_tokenizer(tokenizer, "nl_NS") assert "nl_NS" in tokenizer.added_tokens_encoder tokenizer.save_pretrained(temp_dir) new_tokenizer = cast(MBart50TokenizerFast, MBart50TokenizerFast.from_pretrained(temp_dir)) @@ -556,7 +556,7 @@ def test_m2m_100_tokenizer_add_lang_code() -> None: tokenizer = cast(M2M100Tokenizer, M2M100Tokenizer.from_pretrained("stas/tiny-m2m_100")) assert "nc" not in tokenizer.lang_code_to_id assert "__nc__" not in tokenizer.added_tokens_encoder - _add_lang_code_to_tokenizer(tokenizer, "nc") + add_lang_code_to_tokenizer(tokenizer, "nc") assert "nc" in tokenizer.lang_code_to_id assert "__nc__" in tokenizer.added_tokens_encoder tokenizer.save_pretrained(temp_dir)