diff --git a/machine/jobs/huggingface/hugging_face_nmt_model_factory.py b/machine/jobs/huggingface/hugging_face_nmt_model_factory.py index 685a5d3..d17fe25 100644 --- a/machine/jobs/huggingface/hugging_face_nmt_model_factory.py +++ b/machine/jobs/huggingface/hugging_face_nmt_model_factory.py @@ -10,9 +10,9 @@ from ...corpora.text_corpus import TextCorpus from ...translation.huggingface.hugging_face_nmt_engine import HuggingFaceNmtEngine from ...translation.huggingface.hugging_face_nmt_model_trainer import HuggingFaceNmtModelTrainer -from ...translation.nmt_translation_engine import NmtTranslationEngine from ...translation.null_trainer import NullTrainer from ...translation.trainer import Trainer +from ...translation.translation_engine import TranslationEngine from ..nmt_model_factory import NmtModelFactory from ..shared_file_service import SharedFileService @@ -70,7 +70,7 @@ def create_model_trainer(self, corpus: ParallelTextCorpus) -> Trainer: add_unk_trg_tokens=self._config.huggingface.tokenizer.add_unk_trg_tokens, ) - def create_engine(self, half_previous_batch_size=False) -> NmtTranslationEngine: + def create_engine(self, half_previous_batch_size=False) -> TranslationEngine: if half_previous_batch_size: self._config.huggingface.generate_params.batch_size = max( self._config.huggingface.generate_params.batch_size // 2, 1 diff --git a/machine/jobs/nmt_engine_build_job.py b/machine/jobs/nmt_engine_build_job.py index 3799205..36d7539 100644 --- a/machine/jobs/nmt_engine_build_job.py +++ b/machine/jobs/nmt_engine_build_job.py @@ -3,7 +3,7 @@ from typing import Any, Callable, Optional, Sequence from ..corpora.corpora_utils import batch -from ..translation.nmt_translation_engine import NmtTranslationEngine +from ..translation.translation_engine import TranslationEngine from ..utils.phased_progress_reporter import Phase, PhasedProgressReporter from ..utils.progress_status import ProgressStatus from .nmt_model_factory import NmtModelFactory @@ -81,48 +81,26 @@ def run( inference_step_count = sum(1 for _ in src_pretranslations) with ExitStack() as stack: phase_progress = stack.enter_context(progress_reporter.start_next_phase()) + engine = stack.enter_context(self._nmt_model_factory.create_engine()) src_pretranslations = stack.enter_context(self._shared_file_service.get_source_pretranslations()) writer = stack.enter_context(self._shared_file_service.open_target_pretranslation_writer()) current_inference_step = 0 phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count)) batch_size = self._config["batch_size"] - translate_batch = TranslateBatch(stack, self._nmt_model_factory) for pi_batch in batch(src_pretranslations, batch_size): if check_canceled is not None: check_canceled() - translate_batch.translate(pi_batch, writer) + _translate_batch(engine, pi_batch, writer) current_inference_step += len(pi_batch) phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count)) -batch_divisor = 1 - - -class TranslateBatch: - def __init__(self, stack: ExitStack, nmt_model_factory: NmtModelFactory): - self._stack = stack - self._nmt_model_factory = nmt_model_factory - self._engine: NmtTranslationEngine = self._stack.enter_context(self._nmt_model_factory.create_engine()) - - def translate( - self, - batch: Sequence[PretranslationInfo], - writer: PretranslationWriter, - ) -> None: - while True: - source_segments = [pi["translation"] for pi in batch] - outer_batch_size = len(source_segments) - try: - for step in range(0, outer_batch_size, self._engine.get_batch_size()): - for i, result in enumerate( - self._engine.translate_batch(source_segments[step : step + self._engine.get_batch_size()]) - ): - batch[i + step]["translation"] = result.translation - for i in range(len(source_segments)): - writer.write(batch[i]) - break - except Exception: - logger.info(f"Out of memory error, reducing batch size to {self._engine.get_batch_size() // 2}") - self._engine = self._stack.enter_context( - self._nmt_model_factory.create_engine(half_previous_batch_size=True) - ) +def _translate_batch( + engine: TranslationEngine, + batch: Sequence[PretranslationInfo], + writer: PretranslationWriter, +) -> None: + source_segments = [pi["translation"] for pi in batch] + for i, result in enumerate(engine.translate_batch(source_segments)): + batch[i]["translation"] = result.translation + writer.write(batch[i]) diff --git a/machine/jobs/nmt_model_factory.py b/machine/jobs/nmt_model_factory.py index c108b9a..6161320 100644 --- a/machine/jobs/nmt_model_factory.py +++ b/machine/jobs/nmt_model_factory.py @@ -2,8 +2,8 @@ from ..corpora.parallel_text_corpus import ParallelTextCorpus from ..corpora.text_corpus import TextCorpus -from ..translation.nmt_translation_engine import NmtTranslationEngine from ..translation.trainer import Trainer +from ..translation.translation_engine import TranslationEngine class NmtModelFactory(ABC): @@ -29,7 +29,7 @@ def create_model_trainer(self, corpus: ParallelTextCorpus) -> Trainer: ... @abstractmethod - def create_engine(self, half_previous_batch_size=False) -> NmtTranslationEngine: + def create_engine(self, half_previous_batch_size=False) -> TranslationEngine: ... @abstractmethod diff --git a/machine/jobs/settings.yaml b/machine/jobs/settings.yaml index fa8908f..1343812 100644 --- a/machine/jobs/settings.yaml +++ b/machine/jobs/settings.yaml @@ -34,7 +34,5 @@ staging: max_steps: 10 huggingface: parent_model_name: facebook/nllb-200-distilled-600M - train_params: - group_by_length: false generate_params: num_beams: 1 diff --git a/machine/translation/huggingface/hugging_face_nmt_engine.py b/machine/translation/huggingface/hugging_face_nmt_engine.py index 2bd4abc..dd137f6 100644 --- a/machine/translation/huggingface/hugging_face_nmt_engine.py +++ b/machine/translation/huggingface/hugging_face_nmt_engine.py @@ -12,7 +12,7 @@ from ...annotations.range import Range from ...utils.typeshed import StrPath -from ..nmt_translation_engine import NmtTranslationEngine +from ..translation_engine import TranslationEngine from ..translation_result import TranslationResult from ..translation_result_builder import TranslationResultBuilder from ..translation_sources import TranslationSources @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) -class HuggingFaceNmtEngine(NmtTranslationEngine): +class HuggingFaceNmtEngine(TranslationEngine): def __init__( self, model: Union[PreTrainedModel, StrPath, str], @@ -63,11 +63,7 @@ def __init__( ): raise ValueError(f"'{tgt_lang}' is not a valid language code.") - batch_size = self._pipeline_kwargs.pop("batch_size") - if batch_size is not None: - self._batch_size = int(batch_size) # type: ignore[assignment] - else: - self._batch_size = 16 + self._batch_size = int(self._pipeline_kwargs.pop("batch_size", 1)) # If not set, default to not backing off (1.0). self._oom_batch_size_backoff_multiplier = self._pipeline_kwargs.pop("oom_batch_size_backoff_multiplier", 1.0) @@ -88,9 +84,6 @@ def translate_n(self, n: int, segment: Union[str, Sequence[str]]) -> Sequence[Tr def translate_batch(self, segments: Sequence[Union[str, Sequence[str]]]) -> Sequence[TranslationResult]: return [results[0] for results in self.translate_n_batch(1, segments)] - def get_batch_size(self) -> int: - return self._batch_size - def translate_n_batch( self, n: int, segments: Sequence[Union[str, Sequence[str]]] ) -> Sequence[Sequence[TranslationResult]]: @@ -106,13 +99,19 @@ def translate_n_batch( all_results.extend(self._try_translate_n_batch(n, segments[step : step + self._batch_size])) return all_results except Exception as e: + # The out or memory error is not inherited from if self._oom_batch_size_backoff_multiplier >= 0.9999: raise Exception( - "Likely an Out of Memory Error. Change oom_batch_size_backoff_multiplier to < 1 to gracefuly handle these type of errors." + "Likely an Out of Memory Error. Change oom_batch_size_backoff_multiplier " + + "to < 1 to gracefuly handle these type of errors." ) from e + if self._batch_size == 1: + # Could it be another error? + raise e self._batch_size = max(int(round(self._batch_size * self._oom_batch_size_backoff_multiplier)), 1) logger.info( - f"Out of memory error caught, reducing batch size to {self._batch_size}. Remaking translation pipeline." + f"Out of memory error caught with message {e.args[0]}, reducing batch size to {self._batch_size}. " + + "Remaking translation pipeline." ) self._pipeline = _TranslationPipeline( model=self._model, diff --git a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py index 07b04bc..e4be94a 100644 --- a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py +++ b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py @@ -134,7 +134,6 @@ def train( # Set seed before initializing model. set_seed(self._training_args.seed) - logger.info("Initializing tokenizer.") if isinstance(self._model, PreTrainedModel): model = self._model self._original_use_cache = model.config.use_cache @@ -148,6 +147,8 @@ def train( num_labels=0, ) model = cast(PreTrainedModel, AutoModelForSeq2SeqLM.from_pretrained(self._model, config=config)) + + logger.info("Initializing tokenizer") tokenizer = AutoTokenizer.from_pretrained(model.name_or_path, use_fast=True) src_lang = self._src_lang @@ -194,8 +195,8 @@ def add_tokens(tokenizer: Any, missing_tokens: List[str]) -> Any: logger.info(f"Added {len(missing_tokens)} tokens to the tokenizer: {missing_tokens}") return AutoTokenizer.from_pretrained(str(tokenizer_dir), use_fast=True) - logger.info("Checking for missing tokens.") if self._add_unk_src_tokens or self._add_unk_trg_tokens: + logger.info("Checking for missing tokens") if not isinstance(tokenizer, PreTrainedTokenizerFast): logger.warning( f"Tokenizer can not be updated from default configuration: \ @@ -235,8 +236,8 @@ def add_lang_code_to_tokenizer(tokenizer: Any, lang_code: str): tokenizer.lang_token_to_id[lang_code] = lang_id tokenizer.id_to_lang_token[lang_id] = lang_code - logger.info("Add new language codes as tokens.") if isinstance(tokenizer, MULTILINGUAL_TOKENIZERS): + logger.info("Add new language codes as tokens") if self._src_lang is not None: add_lang_code_to_tokenizer(tokenizer, self._src_lang) if self._tgt_lang is not None: @@ -312,7 +313,7 @@ def preprocess_function(examples): model_inputs["labels"] = labels["input_ids"] return model_inputs - logger.info("Run tokenizer.") + logger.info("Run tokenizer") train_dataset = train_dataset.map( preprocess_function, batched=True, @@ -343,7 +344,7 @@ def preprocess_function(examples): ], ) - logger.info("Train NMT model.") + logger.info("Train NMT model") ckpt = None if self._training_args.resume_from_checkpoint is not None: ckpt = self._training_args.resume_from_checkpoint @@ -357,7 +358,7 @@ def preprocess_function(examples): self._metrics["train_samples"] = len(train_dataset) self._trainer.log_metrics("train", self._metrics) - logger.info("Model training finished.") + logger.info("Model training finished") def save(self) -> None: if self._trainer is None: diff --git a/machine/translation/nmt_translation_engine.py b/machine/translation/nmt_translation_engine.py deleted file mode 100644 index c486a07..0000000 --- a/machine/translation/nmt_translation_engine.py +++ /dev/null @@ -1,12 +0,0 @@ -from __future__ import annotations - -from abc import abstractmethod -from typing import ContextManager - -from .translation_engine import TranslationEngine - - -class NmtTranslationEngine(TranslationEngine, ContextManager["NmtTranslationEngine"]): - @abstractmethod - def get_batch_size(self) -> int: - ... diff --git a/tests/jobs/test_nmt_engine_build_job.py b/tests/jobs/test_nmt_engine_build_job.py index 142149f..b3a1bf6 100644 --- a/tests/jobs/test_nmt_engine_build_job.py +++ b/tests/jobs/test_nmt_engine_build_job.py @@ -10,7 +10,7 @@ from machine.corpora import DictionaryTextCorpus from machine.jobs import NmtEngineBuildJob, NmtModelFactory, PretranslationInfo, PretranslationWriter, SharedFileService from machine.translation import Phrase, Trainer, TrainStats, TranslationResult, TranslationSources, WordAlignmentMatrix -from machine.translation.nmt_translation_engine import NmtTranslationEngine +from machine.translation.translation_engine import TranslationEngine from machine.utils import CanceledError, ContextManagedGenerator @@ -45,9 +45,8 @@ def __init__(self, decoy: Decoy) -> None: stats.metrics["bleu"] = 30.0 decoy.when(self.model_trainer.stats).then_return(stats) - self.engine = decoy.mock(cls=NmtTranslationEngine) + self.engine = decoy.mock(cls=TranslationEngine) decoy.when(self.engine.__enter__()).then_return(self.engine) - decoy.when(self.engine.get_batch_size()).then_return(16) decoy.when(self.engine.translate_batch(matchers.Anything())).then_return( [ TranslationResult(