From 36733928490b08da15ab08b728b023c9621fc1a3 Mon Sep 17 00:00:00 2001 From: Damien Daspit Date: Thu, 24 Oct 2024 07:22:50 -0500 Subject: [PATCH] Fix source tokens for HF fast tokenizers (#133) - update MT tutorial to demonstrate TranslationResult - incompatible with transformers>=4.42, because of breaking change in NLLB tokenizer --- .../huggingface/hugging_face_nmt_engine.py | 9 +++- poetry.lock | 21 ++++---- pyproject.toml | 4 +- samples/machine_translation.ipynb | 50 ++++++++++++++----- 4 files changed, 56 insertions(+), 28 deletions(-) diff --git a/machine/translation/huggingface/hugging_face_nmt_engine.py b/machine/translation/huggingface/hugging_face_nmt_engine.py index 2906fed..b3b609a 100644 --- a/machine/translation/huggingface/hugging_face_nmt_engine.py +++ b/machine/translation/huggingface/hugging_face_nmt_engine.py @@ -203,9 +203,10 @@ def preprocess(self, *args, truncation=TruncationStrategy.DO_NOT_TRUNCATE, src_l BatchEncoding, super().preprocess(*sentences, truncation=truncation, src_lang=src_lang, tgt_lang=tgt_lang), ) - if inputs.is_fast: + if inputs.encodings is not None: inputs["input_tokens"] = [ - inputs.tokens(i) if isinstance(args[i], str) else args[i] for i in range(len(args)) + _get_encoding_fast_tokens(inputs.encodings[i]) if isinstance(args[i], str) else args[i] + for i in range(len(args)) ] else: inputs["input_tokens"] = [self.tokenizer.tokenize(s) if isinstance(s, str) else s for s in args] @@ -379,3 +380,7 @@ def torch_gather_nd(params: torch.Tensor, indices: torch.Tensor, batch_dim: int out = torch.gather(params, dim=batch_dim, index=indices) return out.reshape(*index_shape, *tail_sizes) + + +def _get_encoding_fast_tokens(encoding) -> List[str]: + return [token for (token, mask) in zip(encoding.tokens, encoding.special_tokens_mask) if not mask] diff --git a/poetry.lock b/poetry.lock index 31f0d6c..a28c47d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1737,21 +1737,21 @@ files = [ [[package]] name = "networkx" -version = "2.6.3" +version = "3.1" description = "Python package for creating and manipulating graphs and networks" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "networkx-2.6.3-py3-none-any.whl", hash = "sha256:80b6b89c77d1dfb64a4c7854981b60aeea6360ac02c6d4e4913319e0a313abef"}, - {file = "networkx-2.6.3.tar.gz", hash = "sha256:c0946ed31d71f1b732b5aaa6da5a0388a345019af232ce2f49c766e2d6795c51"}, + {file = "networkx-3.1-py3-none-any.whl", hash = "sha256:4f33f68cb2afcf86f28a45f43efc27a9386b535d567d2127f8f61d51dec58d36"}, + {file = "networkx-3.1.tar.gz", hash = "sha256:de346335408f84de0eada6ff9fafafff9bcda11f0a0dfaa931133debb146ab61"}, ] [package.extras] -default = ["matplotlib (>=3.3)", "numpy (>=1.19)", "pandas (>=1.1)", "scipy (>=1.5,!=1.6.1)"] -developer = ["black (==21.5b1)", "pre-commit (>=2.12)"] -doc = ["nb2plots (>=0.6)", "numpydoc (>=1.1)", "pillow (>=8.2)", "pydata-sphinx-theme (>=0.6,<1.0)", "sphinx (>=4.0,<5.0)", "sphinx-gallery (>=0.9,<1.0)", "texext (>=0.6.6)"] -extra = ["lxml (>=4.5)", "pydot (>=1.4.1)", "pygraphviz (>=1.7)"] -test = ["codecov (>=2.1)", "pytest (>=6.2)", "pytest-cov (>=2.12)"] +default = ["matplotlib (>=3.4)", "numpy (>=1.20)", "pandas (>=1.3)", "scipy (>=1.8)"] +developer = ["mypy (>=1.1)", "pre-commit (>=3.2)"] +doc = ["nb2plots (>=0.6)", "numpydoc (>=1.5)", "pillow (>=9.4)", "pydata-sphinx-theme (>=0.13)", "sphinx (>=6.1)", "sphinx-gallery (>=0.12)", "texext (>=0.6.7)"] +extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.10)", "sympy (>=1.10)"] +test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"] [[package]] name = "nodeenv" @@ -2630,7 +2630,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -4008,4 +4007,4 @@ thot = ["sil-thot"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.12" -content-hash = "e25fd409a86457951a9ba91a820377a7d2cf6c424f2e922bc7aa2a92011b20c6" +content-hash = "70fb2f7721d9f7ec9c0cb9b053a8a84b7bcac62a66c052988b73bb3fae9d18ba" diff --git a/pyproject.toml b/pyproject.toml index 16fb753..9367d14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,14 +55,14 @@ python = ">=3.8.1,<3.12" regex = ">=2021.7.6" numpy = "^1.24.4" sortedcontainers = "^2.4.0" -networkx = "^2.6.3" +networkx = "^3" charset-normalizer = "^2.1.1" ### extras sentencepiece = "^0.1.95" sil-thot = "^3.4.4" # huggingface extras -transformers = "^4.34.0" +transformers = ">=4.34.0,<4.42" datasets = "^2.4.0" sacremoses = "^0.0.53" # job extras diff --git a/samples/machine_translation.ipynb b/samples/machine_translation.ipynb index 255a054..89103f2 100644 --- a/samples/machine_translation.ipynb +++ b/samples/machine_translation.ipynb @@ -40,7 +40,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -60,7 +60,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -130,19 +130,23 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now that we have a trained SMT model and a trained truecasing model, we are ready to translate sentences. First, We need to load the SMT model. The model can be used to translate sentences using the `translate` method.\n" + "Now that we have a trained SMT model and a trained truecasing model, we are ready to translate sentences. First, We need to load the SMT model. The model can be used to translate sentences using the `translate` method. A `TranslationResult` instance is returned when a text segment is translated. In addition to the translated segment, `TranslationResult` contains lots of interesting information about the translated sentence, such as the word confidences, alignment, phrases, and source/target tokens." ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "I would like to book a room until tomorrow.\n" + "Translation: I would like to book a room until tomorrow.\n", + "Source tokens: ['Desearía', 'reservar', 'una', 'habitación', 'hasta', 'mañana', '.']\n", + "Target tokens: ['I', 'would', 'like', 'to', 'book', 'a', 'room', 'until', 'tomorrow', '.']\n", + "Alignment: 0-1 0-2 1-3 1-4 2-5 3-6 4-7 5-8 6-9\n", + "Confidences: [0.1833474940416596, 0.3568307371510516, 0.3556863860951534, 0.2894564705698258, 0.726984900023586, 0.8915912178040876, 0.878754356224247, 0.8849444691927844, 0.8458962922106739, 0.8975745812873857]\n" ] } ], @@ -165,7 +169,11 @@ " lowercase_target=True,\n", ") as model:\n", " result = model.translate(\"Desearía reservar una habitación hasta mañana.\")\n", - " print(result.translation)" + " print(\"Translation:\", result.translation)\n", + " print(\"Source tokens:\", result.source_tokens)\n", + " print(\"Target tokens:\", result.target_tokens)\n", + " print(\"Alignment:\", result.alignment)\n", + " print(\"Confidences:\", result.confidences)" ] }, { @@ -267,7 +275,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -399,14 +407,18 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "I would like to book a room until tomorrow.\n" + "Translation: I would like to book a room until tomorrow.\n", + "Source tokens: ['▁D', 'ese', 'aría', '▁res', 'ervar', '▁una', '▁hab', 'itación', '▁hasta', '▁mañana', '.']\n", + "Target tokens: ['▁I', '▁would', '▁like', '▁to', '▁book', '▁a', '▁room', '▁until', '▁tom', 'orrow', '.']\n", + "Alignment: 1-2 2-0 2-1 3-4 4-3 5-5 6-6 8-7 9-8 9-9 10-10\n", + "Confidences: [0.9995167207904968, 0.9988614185814005, 0.9995524502931971, 0.9861009574421602, 0.9987220427038153, 0.998968593209302, 0.9944791909715244, 0.9989702587912649, 0.9749540518542505, 0.9996603689253716, 0.9930446924545876]\n" ] } ], @@ -415,7 +427,11 @@ "\n", "with HuggingFaceNmtEngine(\"out/sp-en-nmt\", src_lang=\"es\", tgt_lang=\"en\") as engine:\n", " result = engine.translate(\"Desearía reservar una habitación hasta mañana.\")\n", - " print(result.translation)" + " print(\"Translation:\", result.translation)\n", + " print(\"Source tokens:\", result.source_tokens)\n", + " print(\"Target tokens:\", result.target_tokens)\n", + " print(\"Alignment:\", result.alignment)\n", + " print(\"Confidences:\", result.confidences)" ] }, { @@ -427,21 +443,29 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "I'd like to reserve a room for tomorrow.\n" + "Translation: I'd like to reserve a room for tomorrow.\n", + "Source tokens: ['▁Dese', 'aría', '▁reser', 'var', '▁una', '▁habitación', '▁hasta', '▁mañana', '.']\n", + "Target tokens: ['▁I', \"'\", 'd', '▁like', '▁to', '▁reserve', '▁a', '▁room', '▁for', '▁tomorrow', '.']\n", + "Alignment: 0-1 0-3 1-0 1-2 2-5 5-6 5-7 6-8 7-9 8-4 8-10\n", + "Confidences: [0.766540320750896, 0.5910241514763206, 0.8868627789322919, 0.8544048979056736, 0.8613305047447863, 0.45655845183164, 0.8814725030368357, 0.8585703155792751, 0.3142652857171965, 0.8780149028315941, 0.8617016651426532]\n" ] } ], "source": [ "with HuggingFaceNmtEngine(\"facebook/nllb-200-distilled-600M\", src_lang=\"spa_Latn\", tgt_lang=\"eng_Latn\") as engine:\n", " result = engine.translate(\"Desearía reservar una habitación hasta mañana.\")\n", - " print(result.translation)" + " print(\"Translation:\", result.translation)\n", + " print(\"Source tokens:\", result.source_tokens)\n", + " print(\"Target tokens:\", result.target_tokens)\n", + " print(\"Alignment:\", result.alignment)\n", + " print(\"Confidences:\", result.confidences)" ] } ],