Fix source tokens for HF fast tokenizers (#133)
- update MT tutorial to demonstrate TranslationResult
- incompatible with transformers>=4.42 because of a breaking change in the NLLB tokenizer
ddaspit authored Oct 24, 2024
1 parent ca7fba2 commit 3673392
Showing 4 changed files with 56 additions and 28 deletions.
9 changes: 7 additions & 2 deletions machine/translation/huggingface/hugging_face_nmt_engine.py
@@ -203,9 +203,10 @@ def preprocess(self, *args, truncation=TruncationStrategy.DO_NOT_TRUNCATE, src_l
             BatchEncoding,
             super().preprocess(*sentences, truncation=truncation, src_lang=src_lang, tgt_lang=tgt_lang),
         )
-        if inputs.is_fast:
+        if inputs.encodings is not None:
             inputs["input_tokens"] = [
-                inputs.tokens(i) if isinstance(args[i], str) else args[i] for i in range(len(args))
+                _get_encoding_fast_tokens(inputs.encodings[i]) if isinstance(args[i], str) else args[i]
+                for i in range(len(args))
             ]
         else:
             inputs["input_tokens"] = [self.tokenizer.tokenize(s) if isinstance(s, str) else s for s in args]
@@ -379,3 +380,7 @@ def torch_gather_nd(params: torch.Tensor, indices: torch.Tensor, batch_dim: int

     out = torch.gather(params, dim=batch_dim, index=indices)
     return out.reshape(*index_shape, *tail_sizes)
+
+
+def _get_encoding_fast_tokens(encoding) -> List[str]:
+    return [token for (token, mask) in zip(encoding.tokens, encoding.special_tokens_mask) if not mask]
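
The new helper strips special tokens (such as the language code and end-of-sequence markers) by consulting the special_tokens_mask carried by each fast-tokenizer Encoding, rather than using inputs.tokens(i), which includes them. A minimal sketch of the same idea outside the pipeline, assuming the NLLB checkpoint used later in the tutorial is available:

# Sketch only (not part of this commit): drop special tokens from a fast
# tokenizer's Encoding via special_tokens_mask, mirroring _get_encoding_fast_tokens.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "facebook/nllb-200-distilled-600M", src_lang="spa_Latn", tgt_lang="eng_Latn"
)
batch = tokenizer("Desearía reservar una habitación hasta mañana.")

# batch.encodings is only populated for fast (Rust-backed) tokenizers; slow
# tokenizers leave it as None, which is what the new check guards against.
encoding = batch.encodings[0]
source_tokens = [
    token
    for token, is_special in zip(encoding.tokens, encoding.special_tokens_mask)
    if not is_special
]
print(source_tokens)  # subword tokens without special markers such as the language code
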
21 changes: 10 additions & 11 deletions poetry.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
@@ -55,14 +55,14 @@ python = ">=3.8.1,<3.12"
regex = ">=2021.7.6"
numpy = "^1.24.4"
sortedcontainers = "^2.4.0"
-networkx = "^2.6.3"
+networkx = "^3"
charset-normalizer = "^2.1.1"

### extras
sentencepiece = "^0.1.95"
sil-thot = "^3.4.4"
# huggingface extras
-transformers = "^4.34.0"
+transformers = ">=4.34.0,<4.42"
datasets = "^2.4.0"
sacremoses = "^0.0.53"
# job extras
50 changes: 37 additions & 13 deletions samples/machine_translation.ipynb
@@ -40,7 +40,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -60,7 +60,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -130,19 +130,23 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have a trained SMT model and a trained truecasing model, we are ready to translate sentences. First, We need to load the SMT model. The model can be used to translate sentences using the `translate` method.\n"
"Now that we have a trained SMT model and a trained truecasing model, we are ready to translate sentences. First, We need to load the SMT model. The model can be used to translate sentences using the `translate` method. A `TranslationResult` instance is returned when a text segment is translated. In addition to the translated segment, `TranslationResult` contains lots of interesting information about the translated sentence, such as the word confidences, alignment, phrases, and source/target tokens."
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"I would like to book a room until tomorrow.\n"
"Translation: I would like to book a room until tomorrow.\n",
"Source tokens: ['Desearía', 'reservar', 'una', 'habitación', 'hasta', 'mañana', '.']\n",
"Target tokens: ['I', 'would', 'like', 'to', 'book', 'a', 'room', 'until', 'tomorrow', '.']\n",
"Alignment: 0-1 0-2 1-3 1-4 2-5 3-6 4-7 5-8 6-9\n",
"Confidences: [0.1833474940416596, 0.3568307371510516, 0.3556863860951534, 0.2894564705698258, 0.726984900023586, 0.8915912178040876, 0.878754356224247, 0.8849444691927844, 0.8458962922106739, 0.8975745812873857]\n"
]
}
],
@@ -165,7 +169,11 @@
" lowercase_target=True,\n",
") as model:\n",
" result = model.translate(\"Desearía reservar una habitación hasta mañana.\")\n",
" print(result.translation)"
" print(\"Translation:\", result.translation)\n",
" print(\"Source tokens:\", result.source_tokens)\n",
" print(\"Target tokens:\", result.target_tokens)\n",
" print(\"Alignment:\", result.alignment)\n",
" print(\"Confidences:\", result.confidences)"
]
},
{
@@ -267,7 +275,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -399,14 +407,18 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"I would like to book a room until tomorrow.\n"
"Translation: I would like to book a room until tomorrow.\n",
"Source tokens: ['▁D', 'ese', 'aría', '▁res', 'ervar', '▁una', '▁hab', 'itación', '▁hasta', '▁mañana', '.']\n",
"Target tokens: ['▁I', '▁would', '▁like', '▁to', '▁book', '▁a', '▁room', '▁until', '▁tom', 'orrow', '.']\n",
"Alignment: 1-2 2-0 2-1 3-4 4-3 5-5 6-6 8-7 9-8 9-9 10-10\n",
"Confidences: [0.9995167207904968, 0.9988614185814005, 0.9995524502931971, 0.9861009574421602, 0.9987220427038153, 0.998968593209302, 0.9944791909715244, 0.9989702587912649, 0.9749540518542505, 0.9996603689253716, 0.9930446924545876]\n"
]
}
],
@@ -415,7 +427,11 @@
"\n",
"with HuggingFaceNmtEngine(\"out/sp-en-nmt\", src_lang=\"es\", tgt_lang=\"en\") as engine:\n",
" result = engine.translate(\"Desearía reservar una habitación hasta mañana.\")\n",
" print(result.translation)"
" print(\"Translation:\", result.translation)\n",
" print(\"Source tokens:\", result.source_tokens)\n",
" print(\"Target tokens:\", result.target_tokens)\n",
" print(\"Alignment:\", result.alignment)\n",
" print(\"Confidences:\", result.confidences)"
]
},
{
@@ -427,21 +443,29 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"I'd like to reserve a room for tomorrow.\n"
"Translation: I'd like to reserve a room for tomorrow.\n",
"Source tokens: ['▁Dese', 'aría', '▁reser', 'var', '▁una', '▁habitación', '▁hasta', '▁mañana', '.']\n",
"Target tokens: ['▁I', \"'\", 'd', '▁like', '▁to', '▁reserve', '▁a', '▁room', '▁for', '▁tomorrow', '.']\n",
"Alignment: 0-1 0-3 1-0 1-2 2-5 5-6 5-7 6-8 7-9 8-4 8-10\n",
"Confidences: [0.766540320750896, 0.5910241514763206, 0.8868627789322919, 0.8544048979056736, 0.8613305047447863, 0.45655845183164, 0.8814725030368357, 0.8585703155792751, 0.3142652857171965, 0.8780149028315941, 0.8617016651426532]\n"
]
}
],
"source": [
"with HuggingFaceNmtEngine(\"facebook/nllb-200-distilled-600M\", src_lang=\"spa_Latn\", tgt_lang=\"eng_Latn\") as engine:\n",
" result = engine.translate(\"Desearía reservar una habitación hasta mañana.\")\n",
" print(result.translation)"
" print(\"Translation:\", result.translation)\n",
" print(\"Source tokens:\", result.source_tokens)\n",
" print(\"Target tokens:\", result.target_tokens)\n",
" print(\"Alignment:\", result.alignment)\n",
" print(\"Confidences:\", result.confidences)"
]
}
],
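
For readers unfamiliar with the alignment notation in the cell outputs above: each i-j pair links source token i to target token j, so "0-1 0-2" in the SMT output aligns 'Desearía' with 'would' and 'like'. A small, library-independent sketch for turning the printed string into index pairs (the string below is copied from the first output above):

# Hypothetical helper (not part of the library or this commit): parse the printed
# alignment string into (source_index, target_index) tuples.
alignment_str = "0-1 0-2 1-3 1-4 2-5 3-6 4-7 5-8 6-9"
pairs = [tuple(map(int, pair.split("-"))) for pair in alignment_str.split()]
print(pairs)  # [(0, 1), (0, 2), (1, 3), (1, 4), (2, 5), (3, 6), (4, 7), (5, 8), (6, 9)]
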

