From 85138a52fb3e6993937257a9cc2bf1033718a46b Mon Sep 17 00:00:00 2001 From: John Lambert Date: Wed, 14 Aug 2024 09:15:20 -0400 Subject: [PATCH] Minor naming fixes --- machine/jobs/engine_build_job.py | 4 ++-- machine/jobs/nmt_engine_build_job.py | 2 +- machine/jobs/smt_engine_build_job.py | 2 +- machine/jobs/smt_model_factory.py | 6 +++++- machine/jobs/thot/thot_smt_model_factory.py | 7 +++++-- 5 files changed, 14 insertions(+), 7 deletions(-) diff --git a/machine/jobs/engine_build_job.py b/machine/jobs/engine_build_job.py index d638e8e..a36939b 100644 --- a/machine/jobs/engine_build_job.py +++ b/machine/jobs/engine_build_job.py @@ -37,7 +37,7 @@ def run( check_canceled() logger.info("Pretranslating segments") - self.pretranslate_segments(progress_reporter, check_canceled) + self.batch_inference(progress_reporter, check_canceled) self.save_model() return self._train_corpus_size, self._confidence @@ -68,7 +68,7 @@ def train_model( ) -> None: ... @abstractmethod - def pretranslate_segments( + def batch_inference( self, progress_reporter: PhasedProgressReporter, check_canceled: Optional[Callable[[], None]], diff --git a/machine/jobs/nmt_engine_build_job.py b/machine/jobs/nmt_engine_build_job.py index baa5bee..4b75295 100644 --- a/machine/jobs/nmt_engine_build_job.py +++ b/machine/jobs/nmt_engine_build_job.py @@ -70,7 +70,7 @@ def train_model( model_trainer.train(progress=phase_progress, check_canceled=check_canceled) model_trainer.save() - def pretranslate_segments( + def batch_inference( self, progress_reporter: PhasedProgressReporter, check_canceled: Optional[Callable[[], None]], diff --git a/machine/jobs/smt_engine_build_job.py b/machine/jobs/smt_engine_build_job.py index ff95ae5..86d2ef4 100644 --- a/machine/jobs/smt_engine_build_job.py +++ b/machine/jobs/smt_engine_build_job.py @@ -57,7 +57,7 @@ def train_model( if check_canceled is not None: check_canceled() - def pretranslate_segments( + def batch_inference( self, progress_reporter: PhasedProgressReporter, check_canceled: Optional[Callable[[], None]], diff --git a/machine/jobs/smt_model_factory.py b/machine/jobs/smt_model_factory.py index 8b22840..ac8aa85 100644 --- a/machine/jobs/smt_model_factory.py +++ b/machine/jobs/smt_model_factory.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from pathlib import Path +from typing import Optional from ..corpora.parallel_text_corpus import ParallelTextCorpus from ..corpora.text_corpus import TextCorpus @@ -25,7 +26,10 @@ def create_model_trainer(self, tokenizer: Tokenizer[str, int, str], corpus: Para @abstractmethod def create_engine( - self, tokenizer: Tokenizer[str, int, str], detokenizer: Detokenizer[str, str], truecaser: Truecaser + self, + tokenizer: Tokenizer[str, int, str], + detokenizer: Detokenizer[str, str], + truecaser: Optional[Truecaser] = None, ) -> TranslationEngine: ... @abstractmethod diff --git a/machine/jobs/thot/thot_smt_model_factory.py b/machine/jobs/thot/thot_smt_model_factory.py index 6c4b64c..9e89e30 100644 --- a/machine/jobs/thot/thot_smt_model_factory.py +++ b/machine/jobs/thot/thot_smt_model_factory.py @@ -1,7 +1,7 @@ import os import shutil from pathlib import Path -from typing import Any +from typing import Any, Optional from ...corpora.parallel_text_corpus import ParallelTextCorpus from ...corpora.text_corpus import TextCorpus @@ -67,7 +67,10 @@ def create_model_trainer(self, tokenizer: Tokenizer[str, int, str], corpus: Para ) def create_engine( - self, tokenizer: Tokenizer[str, int, str], detokenizer: Detokenizer[str, str], truecaser: Truecaser + self, + tokenizer: Tokenizer[str, int, str], + detokenizer: Detokenizer[str, str], + truecaser: Optional[Truecaser] = None, ) -> TranslationEngine: return ThotSmtModel( word_alignment_model_type=self._config.thot.word_alignment_model_type,