address efficiency issues
mshannon-sil committed Jan 11, 2024
1 parent 4adbb9a commit 8629f3f
Showing 2 changed files with 43 additions and 33 deletions.
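The change itself is small: MosesPunctNormalizer was previously constructed, and its substitution regexes compiled, on every call that needed it; this commit builds the normalizer once per object and reuses it. Below is a minimal sketch of that pattern, with a hypothetical Engine class standing in for the real classes in the diff:

import re

from sacremoses import MosesPunctNormalizer


class Engine:
    def __init__(self) -> None:
        # Paid once, at construction time: building the normalizer and
        # compiling its substitution regexes is the expensive part.
        self._mpn = MosesPunctNormalizer()
        self._mpn.substitutions = [(re.compile(r), sub) for r, sub in self._mpn.substitutions]

    def preprocess(self, sentences):
        # Paid per call: only the normalization itself remains in the hot path.
        return [self._mpn.normalize(s) for s in sentences]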
machine/translation/huggingface/hugging_face_nmt_engine.py (50 changes: 34 additions & 16 deletions)
@@ -4,7 +4,7 @@
 import logging
 import re
 from math import exp, prod
-from typing import Any, Iterable, List, Sequence, Tuple, Union, cast
+from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union, cast

 import torch  # pyright: ignore[reportMissingImports]
 from sacremoses import MosesPunctNormalizer
@@ -15,6 +15,8 @@
     NllbTokenizer,
     NllbTokenizerFast,
     PreTrainedModel,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
     TranslationPipeline,
 )
 from transformers.generation import BeamSearchEncoderDecoderOutput, GreedySearchEncoderDecoderOutput
@@ -48,6 +50,11 @@ def __init__(
             PreTrainedModel, AutoModelForSeq2SeqLM.from_pretrained(str(self._model), config=model_config)
         )
         self._tokenizer = AutoTokenizer.from_pretrained(self._model.name_or_path, use_fast=True)
+        if isinstance(self._tokenizer, (NllbTokenizer, NllbTokenizerFast)):
+            self._mpn = MosesPunctNormalizer()
+            self._mpn.substitutions = [(re.compile(r), sub) for r, sub in self._mpn.substitutions]
+        else:
+            self._mpn = None

         src_lang = self._pipeline_kwargs.get("src_lang")
         tgt_lang = self._pipeline_kwargs.get("tgt_lang")
@@ -81,6 +88,7 @@ def __init__(
         self._pipeline = _TranslationPipeline(
             model=self._model,
             tokenizer=self._tokenizer,
+            mpn=self._mpn,
             batch_size=self._batch_size,
             **self._pipeline_kwargs,
         )
@@ -159,24 +167,34 @@ def close(self) -> None:


 class _TranslationPipeline(TranslationPipeline):
+    def __init__(
+        self,
+        model: Union[PreTrainedModel, StrPath, str],
+        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+        batch_size: int,
+        mpn: Optional[MosesPunctNormalizer] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(model=model, tokenizer=tokenizer, batch_size=batch_size, **kwargs)
+        self._mpn = mpn
+
     def preprocess(self, *args, truncation=TruncationStrategy.DO_NOT_TRUNCATE, src_lang=None, tgt_lang=None):
         if self.tokenizer is None:
             raise RuntimeError("No tokenizer is specified.")
-        sentences = [
-            s
-            if isinstance(s, str)
-            else self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(s), use_source_tokenizer=True)
-            for s in args
-        ]
-        if isinstance(self.tokenizer, (NllbTokenizer, NllbTokenizerFast)):
-            mpn = MosesPunctNormalizer()
-            mpn.substitutions = [(re.compile(r), sub) for r, sub in mpn.substitutions]
-
-            def normalize_all(lines: Iterable[str]) -> Iterable[str]:
-                for line in lines:
-                    yield mpn.normalize(line)
-
-            sentences = list(normalize_all(sentences))
+        if self._mpn:
+            sentences = [
+                self._mpn.normalize(s)
+                if isinstance(s, str)
+                else self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(s), use_source_tokenizer=True)
+                for s in args
+            ]
+        else:
+            sentences = [
+                s
+                if isinstance(s, str)
+                else self.tokenizer.decode(self.tokenizer.convert_tokens_to_ids(s), use_source_tokenizer=True)
+                for s in args
+            ]
         inputs = cast(
             BatchEncoding, super().preprocess(*sentences, truncation=truncation, src_lang=src_lang, tgt_lang=tgt_lang)
         )
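To see why this matters for throughput, here is an illustrative (unofficial) micro-benchmark comparing the old per-call construction with a single shared, pre-compiled normalizer; the sentence data and iteration counts are made up:

import re
import timeit

from sacremoses import MosesPunctNormalizer

sentences = ["Hello , world !"] * 32  # a made-up batch


def per_batch() -> None:
    # Old behaviour: construct the normalizer and compile its regexes on every call.
    mpn = MosesPunctNormalizer()
    mpn.substitutions = [(re.compile(r), sub) for r, sub in mpn.substitutions]
    for s in sentences:
        mpn.normalize(s)


shared = MosesPunctNormalizer()
shared.substitutions = [(re.compile(r), sub) for r, sub in shared.substitutions]


def reuse_shared() -> None:
    # New behaviour: only normalization remains in the hot path.
    for s in sentences:
        shared.normalize(s)


print("construct per batch:", timeit.timeit(per_batch, number=100))
print("reuse one instance :", timeit.timeit(reuse_shared, number=100))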
machine/translation/huggingface/hugging_face_nmt_model_trainer.py (26 changes: 9 additions & 17 deletions)
@@ -4,7 +4,7 @@
 import os
 import re
 from pathlib import Path
-from typing import Any, Callable, Iterable, List, Optional, Union, cast
+from typing import Any, Callable, List, Optional, Union, cast

 import datasets.utils.logging as datasets_logging
 import torch  # pyright: ignore[reportMissingImports]
@@ -96,6 +96,8 @@ def __init__(
         self.max_target_length = max_target_length
         self._add_unk_src_tokens = add_unk_src_tokens
         self._add_unk_trg_tokens = add_unk_trg_tokens
+        self._mpn = MosesPunctNormalizer()
+        self._mpn.substitutions = [(re.compile(r), sub) for r, sub in self._mpn.substitutions]

     @property
     def stats(self) -> TrainStats:
@@ -169,9 +171,7 @@ def find_missing_characters(tokenizer: Any, train_dataset: Dataset, lang_codes:
             for lang_code in lang_codes:
                 for ex in train_dataset["translation"]:
                     charset = charset | set(ex[lang_code])
-            mpn = MosesPunctNormalizer()
-            mpn.substitutions = [(re.compile(r), sub) for r, sub in mpn.substitutions]
-            charset = {mpn.normalize(char) for char in charset}
+            charset = {self._mpn.normalize(char) for char in charset}
             charset = {tokenizer.backend_tokenizer.normalizer.normalize_str(char) for char in charset}
             charset = set(filter(None, {char.strip() for char in charset}))
             missing_characters = sorted(list(charset - vocab))
@@ -302,20 +302,12 @@ def add_lang_code_to_tokenizer(tokenizer: Any, lang_code: str):
             )

         def preprocess_function(examples):
-            inputs = [ex[src_lang] for ex in examples["translation"]]
-            targets = [ex[tgt_lang] for ex in examples["translation"]]
-            inputs = [prefix + inp for inp in inputs]
-
             if isinstance(tokenizer, (NllbTokenizer, NllbTokenizerFast)):
-                mpn = MosesPunctNormalizer()
-                mpn.substitutions = [(re.compile(r), sub) for r, sub in mpn.substitutions]
-
-                def normalize_all(lines: Iterable[str]) -> Iterable[str]:
-                    for line in lines:
-                        yield mpn.normalize(line)
-
-                inputs = list(normalize_all(inputs))
-                targets = list(normalize_all(targets))
+                inputs = [self._mpn.normalize(prefix + ex[src_lang]) for ex in examples["translation"]]
+                targets = [self._mpn.normalize(ex[tgt_lang]) for ex in examples["translation"]]
+            else:
+                inputs = [prefix + ex[src_lang] for ex in examples["translation"]]
+                targets = [ex[tgt_lang] for ex in examples["translation"]]

             model_inputs = tokenizer(inputs, max_length=max_source_length, truncation=True)
             # Tokenize targets with the `text_target` keyword argument
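The trainer-side change follows the same idea: the normalizer built in __init__ is reused from the nested preprocess_function, and normalization is applied only for NLLB tokenizers. A simplified, self-contained sketch of that closure pattern (the class and method names here are illustrative, not the trainer's real API):

import re

from sacremoses import MosesPunctNormalizer
from transformers import NllbTokenizer, NllbTokenizerFast


class TrainerSketch:
    def __init__(self) -> None:
        # Built once per trainer, shared by every call to preprocess_function.
        self._mpn = MosesPunctNormalizer()
        self._mpn.substitutions = [(re.compile(r), sub) for r, sub in self._mpn.substitutions]

    def make_preprocess(self, tokenizer, src_lang: str, tgt_lang: str, prefix: str = ""):
        def preprocess_function(examples):
            # Only NLLB models get Moses punctuation normalization.
            if isinstance(tokenizer, (NllbTokenizer, NllbTokenizerFast)):
                inputs = [self._mpn.normalize(prefix + ex[src_lang]) for ex in examples["translation"]]
                targets = [self._mpn.normalize(ex[tgt_lang]) for ex in examples["translation"]]
            else:
                inputs = [prefix + ex[src_lang] for ex in examples["translation"]]
                targets = [ex[tgt_lang] for ex in examples["translation"]]
            # Tokenize targets with the `text_target` keyword argument.
            return tokenizer(inputs, text_target=targets, truncation=True)

        return preprocess_function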
