make add_lang_code_to_tokenizer a public function again
mshannon-sil committed Nov 6, 2024
1 parent 1fe806e commit 2269ae0
Showing 3 changed files with 13 additions and 13 deletions.
4 changes: 2 additions & 2 deletions machine/translation/huggingface/__init__.py
@@ -8,6 +8,6 @@

 from .hugging_face_nmt_engine import HuggingFaceNmtEngine
 from .hugging_face_nmt_model import HuggingFaceNmtModel
-from .hugging_face_nmt_model_trainer import HuggingFaceNmtModelTrainer
+from .hugging_face_nmt_model_trainer import HuggingFaceNmtModelTrainer, add_lang_code_to_tokenizer

-__all__ = ["HuggingFaceNmtEngine", "HuggingFaceNmtModel", "HuggingFaceNmtModelTrainer"]
+__all__ = ["add_lang_code_to_tokenizer", "HuggingFaceNmtEngine", "HuggingFaceNmtModel", "HuggingFaceNmtModelTrainer"]
machine/translation/huggingface/hugging_face_nmt_model_trainer.py
@@ -222,9 +222,9 @@ def add_tokens(tokenizer: Any, missing_tokens: List[str]) -> Any:
         if isinstance(tokenizer, MULTILINGUAL_TOKENIZERS):
             logger.info("Add new language codes as tokens")
             if self._src_lang is not None:
-                _add_lang_code_to_tokenizer(tokenizer, self._src_lang)
+                add_lang_code_to_tokenizer(tokenizer, self._src_lang)
             if self._tgt_lang is not None:
-                _add_lang_code_to_tokenizer(tokenizer, self._tgt_lang)
+                add_lang_code_to_tokenizer(tokenizer, self._tgt_lang)

         # We resize the embeddings only when necessary to avoid index errors.
         embedding_size = cast(Any, model.get_input_embeddings()).weight.shape[0]
@@ -398,7 +398,7 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
     )


-def _add_lang_code_to_tokenizer(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], lang_code: str):
+def add_lang_code_to_tokenizer(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], lang_code: str):
     if isinstance(tokenizer, M2M100Tokenizer):
         lang_token = "__" + lang_code + "__"
     else:
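As the hunk above shows, M2M-100 wraps language codes in double underscores before adding them, while the other multilingual tokenizers use the code verbatim (only the M2M100Tokenizer branch is visible in this diff; the rest is inferred from the tests below). A short sketch of the observable effect, using the tiny checkpoint from the M2M-100 test:

    from transformers import M2M100Tokenizer

    from machine.translation.huggingface import add_lang_code_to_tokenizer

    tokenizer = M2M100Tokenizer.from_pretrained("stas/tiny-m2m_100")
    add_lang_code_to_tokenizer(tokenizer, "nc")

    # The bare code gains an id, and the wrapped form becomes an added token.
    assert "nc" in tokenizer.lang_code_to_id
    assert "__nc__" in tokenizer.added_tokens_encoder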
tests/translation/huggingface/test_hugging_face_nmt_model_trainer.py
@@ -22,7 +22,7 @@

 from machine.corpora import DictionaryTextCorpus, MemoryText, TextRow
 from machine.translation.huggingface import HuggingFaceNmtEngine, HuggingFaceNmtModelTrainer
-from machine.translation.huggingface.hugging_face_nmt_model_trainer import _add_lang_code_to_tokenizer
+from machine.translation.huggingface.hugging_face_nmt_model_trainer import add_lang_code_to_tokenizer


 def test_train_non_empty_corpus() -> None:
@@ -481,7 +481,7 @@ def test_nllb_tokenizer_add_lang_code() -> None:
     with TemporaryDirectory() as temp_dir:
         tokenizer = cast(NllbTokenizer, NllbTokenizer.from_pretrained("facebook/nllb-200-distilled-600M"))
         assert "new_lang" not in tokenizer.added_tokens_encoder
-        _add_lang_code_to_tokenizer(tokenizer, "new_lang")
+        add_lang_code_to_tokenizer(tokenizer, "new_lang")
         assert "new_lang" in tokenizer.added_tokens_encoder
         tokenizer.save_pretrained(temp_dir)
         new_tokenizer = cast(NllbTokenizer, NllbTokenizer.from_pretrained(temp_dir))
@@ -493,7 +493,7 @@ def test_nllb_tokenizer_fast_add_lang_code() -> None:
     with TemporaryDirectory() as temp_dir:
         tokenizer = cast(NllbTokenizerFast, NllbTokenizerFast.from_pretrained("facebook/nllb-200-distilled-600M"))
         assert "new_lang" not in tokenizer.added_tokens_encoder
-        _add_lang_code_to_tokenizer(tokenizer, "new_lang")
+        add_lang_code_to_tokenizer(tokenizer, "new_lang")
         assert "new_lang" in tokenizer.added_tokens_encoder
         tokenizer.save_pretrained(temp_dir)
         new_tokenizer = cast(NllbTokenizerFast, NllbTokenizerFast.from_pretrained(temp_dir))
@@ -505,7 +505,7 @@ def test_mbart_tokenizer_add_lang_code() -> None:
     with TemporaryDirectory() as temp_dir:
         tokenizer = cast(MBartTokenizer, MBartTokenizer.from_pretrained("hf-internal-testing/tiny-random-nllb"))
         assert "nl_NS" not in tokenizer.added_tokens_encoder
-        _add_lang_code_to_tokenizer(tokenizer, "nl_NS")
+        add_lang_code_to_tokenizer(tokenizer, "nl_NS")
         assert "nl_NS" in tokenizer.added_tokens_encoder
         tokenizer.save_pretrained(temp_dir)
         new_tokenizer = cast(MBartTokenizer, MBartTokenizer.from_pretrained(temp_dir))
@@ -517,7 +517,7 @@ def test_mbart_tokenizer_fast_add_lang_code() -> None:
     with TemporaryDirectory() as temp_dir:
         tokenizer = cast(MBartTokenizerFast, MBartTokenizerFast.from_pretrained("hf-internal-testing/tiny-random-nllb"))
         assert "nl_NS" not in tokenizer.added_tokens_encoder
-        _add_lang_code_to_tokenizer(tokenizer, "nl_NS")
+        add_lang_code_to_tokenizer(tokenizer, "nl_NS")
         assert "nl_NS" in tokenizer.added_tokens_encoder
         tokenizer.save_pretrained(temp_dir)
         new_tokenizer = cast(MBartTokenizerFast, MBartTokenizerFast.from_pretrained(temp_dir))
@@ -529,7 +529,7 @@ def test_mbart_50_tokenizer_add_lang_code() -> None:
     with TemporaryDirectory() as temp_dir:
         tokenizer = cast(MBart50Tokenizer, MBart50Tokenizer.from_pretrained("hf-internal-testing/tiny-random-mbart50"))
         assert "nl_NS" not in tokenizer.added_tokens_encoder
-        _add_lang_code_to_tokenizer(tokenizer, "nl_NS")
+        add_lang_code_to_tokenizer(tokenizer, "nl_NS")
         assert "nl_NS" in tokenizer.added_tokens_encoder
         tokenizer.save_pretrained(temp_dir)
         new_tokenizer = cast(MBart50Tokenizer, MBart50Tokenizer.from_pretrained(temp_dir))
@@ -543,7 +543,7 @@ def test_mbart_50_tokenizer_fast_add_lang_code() -> None:
             MBart50TokenizerFast, MBart50TokenizerFast.from_pretrained("hf-internal-testing/tiny-random-mbart50")
         )
         assert "nl_NS" not in tokenizer.added_tokens_encoder
-        _add_lang_code_to_tokenizer(tokenizer, "nl_NS")
+        add_lang_code_to_tokenizer(tokenizer, "nl_NS")
         assert "nl_NS" in tokenizer.added_tokens_encoder
         tokenizer.save_pretrained(temp_dir)
         new_tokenizer = cast(MBart50TokenizerFast, MBart50TokenizerFast.from_pretrained(temp_dir))
@@ -556,7 +556,7 @@ def test_m2m_100_tokenizer_add_lang_code() -> None:
         tokenizer = cast(M2M100Tokenizer, M2M100Tokenizer.from_pretrained("stas/tiny-m2m_100"))
         assert "nc" not in tokenizer.lang_code_to_id
         assert "__nc__" not in tokenizer.added_tokens_encoder
-        _add_lang_code_to_tokenizer(tokenizer, "nc")
+        add_lang_code_to_tokenizer(tokenizer, "nc")
         assert "nc" in tokenizer.lang_code_to_id
         assert "__nc__" in tokenizer.added_tokens_encoder
         tokenizer.save_pretrained(temp_dir)
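Every test above follows the same round-trip pattern: register the code, save the tokenizer, reload it, and check that the code survived. A condensed sketch of that pattern outside the test harness (the checkpoint and language code are illustrative):

    from tempfile import TemporaryDirectory

    from transformers import NllbTokenizerFast

    from machine.translation.huggingface import add_lang_code_to_tokenizer

    tokenizer = NllbTokenizerFast.from_pretrained("facebook/nllb-200-distilled-600M")
    add_lang_code_to_tokenizer(tokenizer, "new_lang")

    # The added code should persist through save_pretrained/from_pretrained.
    with TemporaryDirectory() as temp_dir:
        tokenizer.save_pretrained(temp_dir)
        reloaded = NllbTokenizerFast.from_pretrained(temp_dir)
        assert "new_lang" in reloaded.added_tokens_encoder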
