From 562bab2a77350b9a05b193dc8991b7e0c5368c44 Mon Sep 17 00:00:00 2001
From: John Lambert
Date: Wed, 8 Nov 2023 15:38:42 -0500
Subject: [PATCH] First try at fixing OOM's

---
Review notes (text between "---" and the diff is ignored by `git am`):
- A failed chunk is now retried on its own, and source text is overwritten
  only after a chunk succeeds. Previously every retry restarted from the first
  segment and re-read entries that already held translated text as source.
- When the batch size is already 1 the exception is re-raised. Previously a
  persistent (or non-OOM) error made the `while True` loop spin forever.
- The unused module global `batch_divisor` is dropped.
- Leading whitespace of this hunk was reconstructed; check the context-line
  indentation against the target file (or apply with --ignore-whitespace).
  The post-image blob id on the "index" line predates these review changes.

 machine/jobs/nmt_engine_build_job.py | 58 +++++++++++++++++++++++-----
 1 file changed, 48 insertions(+), 10 deletions(-)

diff --git a/machine/jobs/nmt_engine_build_job.py b/machine/jobs/nmt_engine_build_job.py
index 36d7539..276e5ae 100644
--- a/machine/jobs/nmt_engine_build_job.py
+++ b/machine/jobs/nmt_engine_build_job.py
@@ -87,20 +87,58 @@ def run(
             current_inference_step = 0
             phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))
             batch_size = self._config["batch_size"]
+            translate_batch = TranslateBatch(batch_size)
             for pi_batch in batch(src_pretranslations, batch_size):
                 if check_canceled is not None:
                     check_canceled()
-                _translate_batch(engine, pi_batch, writer)
+                translate_batch.translate(engine, pi_batch, writer)
                 current_inference_step += len(pi_batch)
                 phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))
 
 
-def _translate_batch(
-    engine: TranslationEngine,
-    batch: Sequence[PretranslationInfo],
-    writer: PretranslationWriter,
-) -> None:
-    source_segments = [pi["translation"] for pi in batch]
-    for i, result in enumerate(engine.translate_batch(source_segments)):
-        batch[i]["translation"] = result.translation
-        writer.write(batch[i])
+class TranslateBatch:
+    """Translate pretranslation batches, shrinking the sub-batch size on failure.
+
+    The reduced size is kept on the instance, so later batches start from the
+    size that last worked instead of failing again at the configured size.
+    """
+
+    def __init__(self, initial_batch_size: int) -> None:
+        # Clamp to 1 so that the chunking loop in translate() always advances.
+        self.batch_size = max(initial_batch_size, 1)
+
+    def translate(
+        self,
+        engine: TranslationEngine,
+        batch: Sequence[PretranslationInfo],
+        writer: PretranslationWriter,
+    ) -> None:
+        """Translate ``batch`` in place and write every entry to ``writer``.
+
+        Any exception raised while translating a chunk is treated as a possible
+        out-of-memory condition: the sub-batch size is halved and the failed
+        chunk is retried. Once the size is already 1 the exception is re-raised,
+        because shrinking the chunk can no longer help.
+        """
+        step = 0
+        while step < len(batch):
+            chunk = batch[step : step + self.batch_size]
+            try:
+                # Consume the results inside the try in case the engine yields them lazily.
+                translations = [
+                    result.translation
+                    for result in engine.translate_batch([pi["translation"] for pi in chunk])
+                ]
+            except Exception as e:
+                if self.batch_size <= 1:
+                    raise
+                self.batch_size = max(self.batch_size // 2, 1)
+                logger.info(f"Translation failed ({e!r}), reducing batch size to {self.batch_size}")
+                continue
+            # Overwrite the source text only after the whole chunk succeeded, so
+            # that a retry never feeds already-translated text back to the engine.
+            for pi, translation in zip(chunk, translations):
+                pi["translation"] = translation
+            step += len(chunk)
+        for pi in batch:
+            writer.write(pi)