Skip to content

Commit

Permalink
First try at fixing OOM's
Browse files Browse the repository at this point in the history
  • Loading branch information
johnml1135 committed Nov 8, 2023
1 parent 2174c7d commit 562bab2
Showing 1 changed file with 28 additions and 10 deletions.
38 changes: 28 additions & 10 deletions machine/jobs/nmt_engine_build_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,38 @@ def run(
current_inference_step = 0
phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))
batch_size = self._config["batch_size"]
translate_batch = TranslateBatch(batch_size)
for pi_batch in batch(src_pretranslations, batch_size):
if check_canceled is not None:
check_canceled()
_translate_batch(engine, pi_batch, writer)
translate_batch.translate(engine, pi_batch, writer)
current_inference_step += len(pi_batch)
phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))


def _translate_batch(
    engine: TranslationEngine,
    batch: Sequence[PretranslationInfo],
    writer: PretranslationWriter,
) -> None:
    """Translate every entry of ``batch`` in one engine call and write it out.

    The ``"translation"`` value of each entry is what gets sent to
    ``engine.translate_batch``; it is then overwritten in place with the
    ``translation`` attribute of the matching result, and the updated entry is
    handed to ``writer.write``.

    Args:
        engine: Object whose ``translate_batch`` is called with the list of
            texts and yields one result per text.
        batch: Entries to translate; mutated in place.
        writer: Receives each entry right after it has been updated.
    """
    texts = [entry["translation"] for entry in batch]
    results = engine.translate_batch(texts)
    # Results are matched to entries by position.
    for index, outcome in enumerate(results):
        entry = batch[index]
        entry["translation"] = outcome.translation
        writer.write(entry)
# NOTE(review): nothing in the visible code reads this value — confirm it is
# still needed before keeping it.
batch_divisor = 1


class TranslateBatch:
    """Translate batches in chunks, halving the chunk size when the engine fails.

    ``batch_size`` is kept on the instance, so a reduction triggered by one
    call to :meth:`translate` also applies to every later call.
    """

    def __init__(self, initial_batch_size):
        # A zero or negative size would make chunking impossible; clamp to 1.
        self.batch_size = max(initial_batch_size, 1)

    def translate(
        self,
        engine: TranslationEngine,
        batch: Sequence[PretranslationInfo],
        writer: PretranslationWriter,
    ) -> None:
        """Translate ``batch`` with ``engine`` and write every entry out.

        The ``"translation"`` value of each entry is used as the source text.
        Sources are sent to ``engine.translate_batch`` in chunks of at most
        ``self.batch_size``. When a chunk raises, the batch size is halved and
        the same chunk is retried; chunks that already succeeded are kept.
        Entries are updated in place and passed to ``writer.write`` only after
        the whole batch has been translated.

        Args:
            engine: Object whose ``translate_batch`` takes a list of texts and
                yields one result (with a ``translation`` attribute) per text.
            batch: Entries to translate; mutated in place on success.
            writer: Receives each entry after it has been updated.

        Raises:
            Exception: Whatever the engine raised, once the batch size is
                already 1 and cannot be reduced any further.
        """
        # Snapshot the sources once. The entries' "translation" values are
        # overwritten with the output, so re-reading them on a retry would
        # feed already-translated text back into the engine.
        source_segments = [pi["translation"] for pi in batch]
        translations = list(source_segments)
        total = len(source_segments)
        start = 0
        while start < total:
            # Guard against batch_size having been set to <= 0 from outside.
            size = max(self.batch_size, 1)
            chunk = source_segments[start : start + size]
            try:
                results = list(engine.translate_batch(chunk))
            except Exception:
                # Any engine failure is treated as a possible out-of-memory
                # condition. Once the size is 1 there is nothing left to
                # shrink, so propagate instead of retrying forever.
                if self.batch_size <= 1:
                    raise
                self.batch_size = max(self.batch_size // 2, 1)
                logger.info(f"Out of memory error, reducing batch size to {self.batch_size}")
                continue
            for offset, result in enumerate(results[: len(chunk)]):
                translations[start + offset] = result.translation
            start += len(chunk)
        for pi, translation in zip(batch, translations):
            pi["translation"] = translation
            writer.write(pi)

0 comments on commit 562bab2

Please sign in to comment.