From 652f6fbbad48dc965c3268330325e6a51653ed12 Mon Sep 17 00:00:00 2001 From: John Lambert Date: Tue, 20 Aug 2024 14:17:57 -0400 Subject: [PATCH] First set of reviewer comments.. --- .vscode/settings.json | 3 +- machine/jobs/__init__.py | 11 +- machine/jobs/build_smt_engine.py | 5 +- machine/jobs/build_word_alignment_model.py | 3 +- machine/jobs/clearml_shared_file_service.py | 48 +++--- machine/jobs/local_shared_file_service.py | 6 +- machine/jobs/nmt_engine_build_job.py | 33 +++-- machine/jobs/settings.yaml | 3 +- machine/jobs/shared_file_service.py | 138 ------------------ machine/jobs/shared_file_service_base.py | 77 ++++++++++ machine/jobs/shared_file_service_factory.py | 20 +++ machine/jobs/smt_engine_build_job.py | 31 ++-- .../thot/thot_word_alignment_model_factory.py | 2 +- ...job.py => translation_engine_build_job.py} | 36 ++--- machine/jobs/translation_file_service.py | 78 ++++++++++ machine/jobs/word_alignment_build_job.py | 64 ++++++-- machine/jobs/word_alignment_file_service.py | 55 +++++++ tests/jobs/test_nmt_engine_build_job.py | 26 ++-- tests/jobs/test_smt_engine_build_job.py | 24 +-- tests/jobs/test_word_alignment_build_job.py | 43 ++---- 20 files changed, 411 insertions(+), 295 deletions(-) delete mode 100644 machine/jobs/shared_file_service.py create mode 100644 machine/jobs/shared_file_service_base.py create mode 100644 machine/jobs/shared_file_service_factory.py rename machine/jobs/{engine_build_job.py => translation_engine_build_job.py} (66%) create mode 100644 machine/jobs/translation_file_service.py create mode 100644 machine/jobs/word_alignment_file_service.py diff --git a/.vscode/settings.json b/.vscode/settings.json index af37422..024f091 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,13 +1,14 @@ { "editor.formatOnSave": true, "editor.codeActionsOnSave": { - "source.organizeImports": "explicit" + "source.organizeImports": "explicit", }, "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true, "python.analysis.extraPaths": [ "tests" ], + "python.analysis.importFormat": "relative", "[python]": { "editor.defaultFormatter": "ms-python.black-formatter", "editor.formatOnSave": true diff --git a/machine/jobs/__init__.py b/machine/jobs/__init__.py index ec82eac..727da68 100644 --- a/machine/jobs/__init__.py +++ b/machine/jobs/__init__.py @@ -2,10 +2,12 @@ from .local_shared_file_service import LocalSharedFileService from .nmt_engine_build_job import NmtEngineBuildJob from .nmt_model_factory import NmtModelFactory -from .shared_file_service import DictToJsonWriter, PretranslationInfo, SharedFileService +from .shared_file_service_base import DictToJsonWriter, SharedFileServiceBase from .smt_engine_build_job import SmtEngineBuildJob from .smt_model_factory import SmtModelFactory +from .translation_file_service import PretranslationInfo, TranslationFileService from .word_alignment_build_job import WordAlignmentBuildJob +from .word_alignment_file_service import WordAlignmentFileService, WordAlignmentInfo from .word_alignment_model_factory import WordAlignmentModelFactory __all__ = [ @@ -13,11 +15,14 @@ "LocalSharedFileService", "NmtEngineBuildJob", "NmtModelFactory", - "PretranslationInfo", "DictToJsonWriter", - "SharedFileService", + "SharedFileServiceBase", "SmtEngineBuildJob", "SmtModelFactory", + "PretranslationInfo", + "TranslationFileService", "WordAlignmentBuildJob", + "WordAlignmentFileService", + "WordAlignmentInfo", "WordAlignmentModelFactory", ] diff --git a/machine/jobs/build_smt_engine.py 
b/machine/jobs/build_smt_engine.py index d2cc169..56a4caa 100644 --- a/machine/jobs/build_smt_engine.py +++ b/machine/jobs/build_smt_engine.py @@ -15,10 +15,11 @@ update_runtime_properties, update_settings, ) -from .clearml_shared_file_service import ClearMLSharedFileService from .config import SETTINGS +from .shared_file_service_factory import SharedFileServiceType from .smt_engine_build_job import SmtEngineBuildJob from .smt_model_factory import SmtModelFactory +from .translation_file_service import TranslationFileService # Setup logging logging.basicConfig( @@ -54,7 +55,7 @@ def run(args: dict) -> None: logger.info(f"Config: {SETTINGS.as_dict()}") - shared_file_service = ClearMLSharedFileService(SETTINGS) + shared_file_service = TranslationFileService(SharedFileServiceType.CLEARML, SETTINGS) smt_model_factory: SmtModelFactory if SETTINGS.model_type == "thot": from .thot.thot_smt_model_factory import ThotSmtModelFactory diff --git a/machine/jobs/build_word_alignment_model.py b/machine/jobs/build_word_alignment_model.py index 5717598..802ff35 100644 --- a/machine/jobs/build_word_alignment_model.py +++ b/machine/jobs/build_word_alignment_model.py @@ -4,8 +4,6 @@ from clearml import Task -from machine.jobs.thot.thot_word_alignment_model_factory import ThotWordAlignmentModelFactory - from ..utils.progress_status import ProgressStatus from .async_scheduler import AsyncScheduler from .build_clearml_helper import ( @@ -19,6 +17,7 @@ ) from .clearml_shared_file_service import ClearMLSharedFileService from .config import SETTINGS +from .thot.thot_word_alignment_model_factory import ThotWordAlignmentModelFactory from .word_alignment_build_job import WordAlignmentBuildJob from .word_alignment_model_factory import WordAlignmentModelFactory diff --git a/machine/jobs/clearml_shared_file_service.py b/machine/jobs/clearml_shared_file_service.py index 07bc9e4..aeeaaff 100644 --- a/machine/jobs/clearml_shared_file_service.py +++ b/machine/jobs/clearml_shared_file_service.py @@ -5,59 +5,45 @@ from clearml import StorageManager -from .shared_file_service import SharedFileService +from .shared_file_service_base import SharedFileServiceBase logger = logging.getLogger(__name__) -class ClearMLSharedFileService(SharedFileService): - def _download_file(self, path: str) -> Path: - uri = f"{self._shared_file_uri}/{self._shared_file_folder}/{path}" +class ClearMLSharedFileService(SharedFileServiceBase): + def download_file(self, path: str) -> Path: local_folder = str(self._data_dir) - file_path = try_n_times(lambda: StorageManager.download_file(uri, local_folder, skip_zero_size_check=True)) + file_path = try_n_times( + lambda: StorageManager.download_file(self._get_uri(path), local_folder, skip_zero_size_check=True) + ) if file_path is None: - raise RuntimeError(f"Failed to download file: {uri}") + raise RuntimeError(f"Failed to download file: {self._get_uri(path)}") return Path(file_path) def _download_folder(self, path: str) -> Path: - uri = f"{self._shared_file_uri}/{self._shared_file_folder}/{path}" local_folder = str(self._data_dir) - folder_path = try_n_times(lambda: StorageManager.download_folder(uri, local_folder)) + folder_path = try_n_times(lambda: StorageManager.download_folder(self._get_uri(path), local_folder)) if folder_path is None: - raise RuntimeError(f"Failed to download folder: {uri}") + raise RuntimeError(f"Failed to download folder: {self._get_uri(path)}") return Path(folder_path) / path def _exists_file(self, path: str) -> bool: - uri = 
f"{self._shared_file_uri}/{self._shared_file_folder}/{path}" - return try_n_times(lambda: StorageManager.exists_file(uri)) # type: ignore + return try_n_times(lambda: StorageManager.exists_file(self._get_uri(path))) # type: ignore def _upload_file(self, path: str, local_file_path: Path) -> None: - final_destination = try_n_times( - lambda: StorageManager.upload_file( - str(local_file_path), f"{self._shared_file_uri}/{self._shared_file_folder}/{path}" - ) - ) + final_destination = try_n_times(lambda: StorageManager.upload_file(str(local_file_path), self._get_uri(path))) if final_destination is None: - logger.error( - ( - f"Failed to upload file {str(local_file_path)} " - f"to {self._shared_file_uri}/{self._shared_file_folder}/{path}." - ) - ) + logger.error((f"Failed to upload file {str(local_file_path)} " f"to {self._get_uri(path)}.")) def _upload_folder(self, path: str, local_folder_path: Path) -> None: final_destination = try_n_times( - lambda: StorageManager.upload_folder( - str(local_folder_path), f"{self._shared_file_uri}/{self._shared_file_folder}/{path}" - ) + lambda: StorageManager.upload_folder(str(local_folder_path), self._get_uri(path)) ) if final_destination is None: - logger.error( - ( - f"Failed to upload folder {str(local_folder_path)} " - f"to {self._shared_file_uri}/{self._shared_file_folder}/{path}." - ) - ) + logger.error((f"Failed to upload folder {str(local_folder_path)} " f"to {self._get_uri(path)}.")) + + def _get_uri(self, path: str) -> str: + return f"{self._shared_file_uri}/{self._shared_file_folder}/{path}" def try_n_times(func: Callable, n=10): diff --git a/machine/jobs/local_shared_file_service.py b/machine/jobs/local_shared_file_service.py index 77553e6..a0d5fd8 100644 --- a/machine/jobs/local_shared_file_service.py +++ b/machine/jobs/local_shared_file_service.py @@ -2,13 +2,13 @@ import shutil from pathlib import Path -from .shared_file_service import SharedFileService +from .shared_file_service_base import SharedFileServiceBase logger = logging.getLogger(__name__) -class LocalSharedFileService(SharedFileService): - def _download_file(self, path: str) -> Path: +class LocalSharedFileService(SharedFileServiceBase): + def download_file(self, path: str) -> Path: src_path = self._get_path(path) dst_path = self._data_dir / self._shared_file_folder / path dst_path.parent.mkdir(parents=True, exist_ok=True) diff --git a/machine/jobs/nmt_engine_build_job.py b/machine/jobs/nmt_engine_build_job.py index c74bf2e..dc3e9ee 100644 --- a/machine/jobs/nmt_engine_build_job.py +++ b/machine/jobs/nmt_engine_build_job.py @@ -6,19 +6,22 @@ from ..translation.translation_engine import TranslationEngine from ..utils.phased_progress_reporter import Phase, PhasedProgressReporter from ..utils.progress_status import ProgressStatus -from .engine_build_job import EngineBuildJob from .nmt_model_factory import NmtModelFactory -from .shared_file_service import DictToJsonWriter, PretranslationInfo, SharedFileService +from .shared_file_service_base import DictToJsonWriter +from .translation_engine_build_job import TranslationEngineBuildJob +from .translation_file_service import PretranslationInfo, TranslationFileService logger = logging.getLogger(__name__) -class NmtEngineBuildJob(EngineBuildJob): - def __init__(self, config: Any, nmt_model_factory: NmtModelFactory, shared_file_service: SharedFileService) -> None: +class NmtEngineBuildJob(TranslationEngineBuildJob): + def __init__( + self, config: Any, nmt_model_factory: NmtModelFactory, translation_file_service: TranslationFileService + ) -> 
None: self._nmt_model_factory = nmt_model_factory - super().__init__(config, shared_file_service) + super().__init__(config, translation_file_service) - def start_job(self) -> None: + def _start_job(self) -> None: self._nmt_model_factory.init() def _get_progress_reporter(self, progress: Optional[Callable[[ProgressStatus], None]]) -> PhasedProgressReporter: @@ -31,10 +34,10 @@ def _get_progress_reporter(self, progress: Optional[Callable[[ProgressStatus], N phases = [Phase(message="Pretranslating segments", percentage=1.0)] return PhasedProgressReporter(progress, phases) - def respond_to_no_training_corpus(self) -> None: + def _respond_to_no_training_corpus(self) -> None: logger.info("No matching entries in the source and target corpus - skipping training") - def train_model( + def _train_model( self, progress_reporter: PhasedProgressReporter, check_canceled: Optional[Callable[[], None]], @@ -70,19 +73,19 @@ def train_model( model_trainer.train(progress=phase_progress, check_canceled=check_canceled) model_trainer.save() - def batch_inference( + def _batch_inference( self, progress_reporter: PhasedProgressReporter, check_canceled: Optional[Callable[[], None]], ) -> None: logger.info("Pretranslating segments") - with self._shared_file_service.get_source_pretranslations() as src_pretranslations: + with self._translation_file_service.get_source_pretranslations() as src_pretranslations: inference_step_count = sum(1 for _ in src_pretranslations) with ExitStack() as stack: phase_progress = stack.enter_context(progress_reporter.start_next_phase()) engine = stack.enter_context(self._nmt_model_factory.create_engine()) - src_pretranslations = stack.enter_context(self._shared_file_service.get_source_pretranslations()) - writer = stack.enter_context(self._shared_file_service.open_target_pretranslation_writer()) + src_pretranslations = stack.enter_context(self._translation_file_service.get_source_pretranslations()) + writer = stack.enter_context(self._translation_file_service.open_target_pretranslation_writer()) current_inference_step = 0 phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count)) batch_size = self._config["inference_batch_size"] @@ -93,11 +96,11 @@ def batch_inference( current_inference_step += len(pi_batch) phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count)) - def save_model(self) -> None: + def _save_model(self) -> None: if "save_model" in self._config and self._config.save_model is not None: logger.info("Saving model") model_path = self._nmt_model_factory.save_model() - self._shared_file_service.save_model( + self._translation_file_service.save_model( model_path, f"models/{self._config.save_model + ''.join(model_path.suffixes)}" ) @@ -110,4 +113,4 @@ def _translate_batch( source_segments = [pi["translation"] for pi in batch] for i, result in enumerate(engine.translate_batch(source_segments)): batch[i]["translation"] = result.translation - writer.write(batch[i]) # type: ignore + writer.write(batch[i]) diff --git a/machine/jobs/settings.yaml b/machine/jobs/settings.yaml index a60b9ef..ccd6ab0 100644 --- a/machine/jobs/settings.yaml +++ b/machine/jobs/settings.yaml @@ -27,8 +27,9 @@ default: add_unk_trg_tokens: true thot: word_alignment_model_type: hmm - word_alignment_heuristic: grow-diag-final-and tokenizer: latin + thot_align: + word_alignment_heuristic: grow-diag-final-and development: shared_file_folder: dev huggingface: diff --git a/machine/jobs/shared_file_service.py b/machine/jobs/shared_file_service.py deleted file 
mode 100644 index 3780f96..0000000 --- a/machine/jobs/shared_file_service.py +++ /dev/null @@ -1,138 +0,0 @@ -import json -from abc import ABC, abstractmethod -from contextlib import contextmanager -from pathlib import Path -from typing import Any, Generator, Iterator, List, MutableMapping, TextIO, TypedDict - -import json_stream - -from ..corpora.text_corpus import TextCorpus -from ..corpora.text_file_text_corpus import TextFileTextCorpus -from ..utils.context_managed_generator import ContextManagedGenerator - - -class PretranslationInfo(TypedDict): - corpusId: str # noqa: N815 - textId: str # noqa: N815 - refs: List[str] - translation: str - - -class WordAlignmentInfo(TypedDict): - refs: List[str] - column_count: int - row_count: int - alignmnent: str - - -class DictToJsonWriter: - def __init__(self, file: TextIO) -> None: - self._file = file - self._first = True - - # Use MutableMapping rather than TypeDict to allow for more flexible input - def write(self, pi: MutableMapping) -> None: - if not self._first: - self._file.write(",\n") - self._file.write(" " + json.dumps(pi)) - self._first = False - - -class SharedFileService(ABC): - def __init__( - self, - config: Any, - source_filename: str = "train.src.txt", - target_filename: str = "train.trg.txt", - ) -> None: - self._config = config - self._source_filename = source_filename - self._target_filename = target_filename - - def create_source_corpus(self) -> TextCorpus: - return TextFileTextCorpus(self._download_file(f"{self._build_path}/{self._source_filename}")) - - def create_target_corpus(self) -> TextCorpus: - return TextFileTextCorpus(self._download_file(f"{self._build_path}/{self._target_filename}")) - - def exists_source_corpus(self) -> bool: - return self._exists_file(f"{self._build_path}/{self._source_filename}") - - def exists_target_corpus(self) -> bool: - return self._exists_file(f"{self._build_path}/{self._target_filename}") - - def get_source_pretranslations(self) -> ContextManagedGenerator[PretranslationInfo, None, None]: - src_pretranslate_path = self._download_file(f"{self._build_path}/pretranslate.src.json") - - def generator() -> Generator[PretranslationInfo, None, None]: - with src_pretranslate_path.open("r", encoding="utf-8-sig") as file: - for pi in json_stream.load(file): - yield PretranslationInfo( - corpusId=pi["corpusId"], - textId=pi["textId"], - refs=list(pi["refs"]), - translation=pi["translation"], - ) - - return ContextManagedGenerator(generator()) - - @contextmanager - def open_target_pretranslation_writer(self) -> Iterator[DictToJsonWriter]: - return self._open_target_writer("pretranslate.trg.json") - - @contextmanager - def open_target_alignment_writer(self) -> Iterator[DictToJsonWriter]: - return self._open_target_writer("word_alignments.json") - - def _open_target_writer(self, filename) -> Iterator[DictToJsonWriter]: - build_dir = self._data_dir / self._shared_file_folder / self._build_path - build_dir.mkdir(parents=True, exist_ok=True) - target_path = build_dir / filename - with target_path.open("w", encoding="utf-8", newline="\n") as file: - file.write("[\n") - yield DictToJsonWriter(file) - file.write("\n]\n") - self._upload_file(f"{self._build_path}/{filename}", target_path) - - def save_model(self, model_path: Path, destination: str) -> None: - if model_path.is_file(): - self._upload_file(destination, model_path) - else: - self._upload_folder(destination, model_path) - - @property - def _data_dir(self) -> Path: - return Path(self._config.data_dir) - - @property - def _build_path(self) -> str: - 
return f"builds/{self._config.build_id}" - - @property - def _engine_id(self) -> str: - return self._config.engine_id - - @property - def _shared_file_uri(self) -> str: - shared_file_uri: str = self._config.shared_file_uri - return shared_file_uri.rstrip("/") - - @property - def _shared_file_folder(self) -> str: - shared_file_folder: str = self._config.shared_file_folder - return shared_file_folder.rstrip("/") - - @abstractmethod - def _download_file(self, path: str) -> Path: ... - - @abstractmethod - def _download_folder(self, path: str) -> Path: ... - - @abstractmethod - def _exists_file(self, path: str) -> bool: ... - - @abstractmethod - def _upload_file(self, path: str, local_file_path: Path) -> None: ... - - @abstractmethod - def _upload_folder(self, path: str, local_folder_path: Path) -> None: ... diff --git a/machine/jobs/shared_file_service_base.py b/machine/jobs/shared_file_service_base.py new file mode 100644 index 0000000..ac8ed83 --- /dev/null +++ b/machine/jobs/shared_file_service_base.py @@ -0,0 +1,77 @@ +import json +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Iterator, TextIO + + +class DictToJsonWriter: + def __init__(self, file: TextIO) -> None: + self._file = file + self._first = True + + def write(self, pi: object) -> None: + if not self._first: + self._file.write(",\n") + self._file.write(" " + json.dumps(pi)) + self._first = False + + +class SharedFileServiceBase(ABC): + def __init__( + self, + config: Any, + ) -> None: + self._config = config + + def upload_path(self, path: Path, destination: str) -> None: + if path.is_file(): + self._upload_file(destination, path) + else: + self._upload_folder(destination, path) + + def open_target_writer(self, filename) -> Iterator[DictToJsonWriter]: + build_dir = self._data_dir / self._shared_file_folder / self.build_path + build_dir.mkdir(parents=True, exist_ok=True) + target_path = build_dir / filename + with target_path.open("w", encoding="utf-8", newline="\n") as file: + file.write("[\n") + yield DictToJsonWriter(file) + file.write("\n]\n") + self._upload_file(f"{self.build_path}/{filename}", target_path) + + @property + def build_path(self) -> str: + return f"builds/{self._config.build_id}" + + @property + def _data_dir(self) -> Path: + return Path(self._config.data_dir) + + @property + def _engine_id(self) -> str: + return self._config.engine_id + + @property + def _shared_file_uri(self) -> str: + shared_file_uri: str = self._config.shared_file_uri + return shared_file_uri.rstrip("/") + + @property + def _shared_file_folder(self) -> str: + shared_file_folder: str = self._config.shared_file_folder + return shared_file_folder.rstrip("/") + + @abstractmethod + def download_file(self, path: str) -> Path: ... + + @abstractmethod + def _download_folder(self, path: str) -> Path: ... + + @abstractmethod + def _exists_file(self, path: str) -> bool: ... + + @abstractmethod + def _upload_file(self, path: str, local_file_path: Path) -> None: ... + + @abstractmethod + def _upload_folder(self, path: str, local_folder_path: Path) -> None: ... 
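A note on the bracket protocol above: DictToJsonWriter emits only the comma-separated objects, while the caller writes the surrounding "[" and "]", exactly as open_target_writer does around its yield. A minimal sketch of that contract (standalone; the field values are illustrative):

    import io

    from machine.jobs import DictToJsonWriter

    with io.StringIO() as file:
        file.write("[\n")
        writer = DictToJsonWriter(file)
        # write() prepends ",\n" before every object except the first
        writer.write({"corpusId": "c1", "textId": "t1", "refs": ["r1"], "translation": "hola"})
        writer.write({"corpusId": "c1", "textId": "t2", "refs": ["r2"], "translation": "mundo"})
        file.write("\n]\n")
        print(file.getvalue())  # a well-formed JSON array of two objects

The tests later in this patch capture writer output with the same StringIO pattern.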
diff --git a/machine/jobs/shared_file_service_factory.py b/machine/jobs/shared_file_service_factory.py
new file mode 100644
index 0000000..afdcbe0
--- /dev/null
+++ b/machine/jobs/shared_file_service_factory.py
@@ -0,0 +1,21 @@
+from enum import IntEnum, auto
+from typing import Any, Union
+
+from .clearml_shared_file_service import ClearMLSharedFileService
+from .local_shared_file_service import LocalSharedFileService
+from .shared_file_service_base import SharedFileServiceBase
+
+
+class SharedFileServiceType(IntEnum):
+    LOCAL = auto()
+    CLEARML = auto()
+
+
+def get_shared_file_service(type: Union[str, SharedFileServiceType], config: Any) -> SharedFileServiceBase:
+    if isinstance(type, str):
+        type = SharedFileServiceType[type.upper()]
+    if type == SharedFileServiceType.LOCAL:
+        return LocalSharedFileService(config)
+    elif type == SharedFileServiceType.CLEARML:
+        return ClearMLSharedFileService(config)
+    raise ValueError(f"Unknown shared file service type: {type}")
diff --git a/machine/jobs/smt_engine_build_job.py b/machine/jobs/smt_engine_build_job.py
index 1aee9f4..8b73d1e 100644
--- a/machine/jobs/smt_engine_build_job.py
+++ b/machine/jobs/smt_engine_build_job.py
@@ -6,19 +6,22 @@
 from ..translation.translation_engine import TranslationEngine
 from ..utils.phased_progress_reporter import Phase, PhasedProgressReporter
 from ..utils.progress_status import ProgressStatus
-from .engine_build_job import EngineBuildJob
-from .shared_file_service import DictToJsonWriter, PretranslationInfo, SharedFileService
+from .shared_file_service_base import DictToJsonWriter
 from .smt_model_factory import SmtModelFactory
+from .translation_engine_build_job import TranslationEngineBuildJob
+from .translation_file_service import PretranslationInfo, TranslationFileService
 
 logger = logging.getLogger(__name__)
 
 
-class SmtEngineBuildJob(EngineBuildJob):
-    def __init__(self, config: Any, smt_model_factory: SmtModelFactory, shared_file_service: SharedFileService) -> None:
+class SmtEngineBuildJob(TranslationEngineBuildJob):
+    def __init__(
+        self, config: Any, smt_model_factory: SmtModelFactory, shared_file_service: TranslationFileService
+    ) -> None:
         self._smt_model_factory = smt_model_factory
         super().__init__(config, shared_file_service)
 
-    def start_job(self) -> None:
+    def _start_job(self) -> None:
         self._smt_model_factory.init()
         self._tokenizer = self._smt_model_factory.create_tokenizer()
         logger.info(f"Tokenizer: {type(self._tokenizer).__name__}")
@@ -31,10 +34,10 @@ def _get_progress_reporter(self, progress: Optional[Callable[[ProgressStatus], N
         ]
         return PhasedProgressReporter(progress, phases)
 
-    def respond_to_no_training_corpus(self) -> None:
+    def _respond_to_no_training_corpus(self) -> None:
         raise RuntimeError("No parallel corpus data found")
 
-    def train_model(
+    def _train_model(
         self,
         progress_reporter: PhasedProgressReporter,
         check_canceled: Optional[Callable[[], None]],
@@ -57,12 +60,12 @@ def train_model(
         if check_canceled is not None:
             check_canceled()
 
-    def batch_inference(
+    def _batch_inference(
         self,
         progress_reporter: PhasedProgressReporter,
         check_canceled: Optional[Callable[[], None]],
     ) -> None:
-        with self._shared_file_service.get_source_pretranslations() as src_pretranslations:
+        with self._translation_file_service.get_source_pretranslations() as src_pretranslations:
             inference_step_count = sum(1 for _ in src_pretranslations)
 
         with ExitStack() as stack:
@@ -70,8 +73,8 @@ def batch_inference(
             truecaser = self._smt_model_factory.create_truecaser()
             phase_progress = stack.enter_context(progress_reporter.start_next_phase())
engine = stack.enter_context(self._smt_model_factory.create_engine(self._tokenizer, detokenizer, truecaser)) - src_pretranslations = stack.enter_context(self._shared_file_service.get_source_pretranslations()) - writer = stack.enter_context(self._shared_file_service.open_target_pretranslation_writer()) + src_pretranslations = stack.enter_context(self._translation_file_service.get_source_pretranslations()) + writer = stack.enter_context(self._translation_file_service.open_target_pretranslation_writer()) current_inference_step = 0 phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count)) batch_size = self._config["inference_batch_size"] @@ -82,10 +85,10 @@ def batch_inference( current_inference_step += len(pi_batch) phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count)) - def save_model(self) -> None: + def _save_model(self) -> None: logger.info("Saving model") model_path = self._smt_model_factory.save_model() - self._shared_file_service.save_model( + self._translation_file_service.save_model( model_path, f"builds/{self._config['build_id']}/model{''.join(model_path.suffixes)}" ) @@ -98,4 +101,4 @@ def _translate_batch( source_segments = [pi["translation"] for pi in batch] for i, result in enumerate(engine.translate_batch(source_segments)): batch[i]["translation"] = result.translation - writer.write(batch[i]) # type: ignore + writer.write(batch[i]) diff --git a/machine/jobs/thot/thot_word_alignment_model_factory.py b/machine/jobs/thot/thot_word_alignment_model_factory.py index 7ba92b4..2c46e72 100644 --- a/machine/jobs/thot/thot_word_alignment_model_factory.py +++ b/machine/jobs/thot/thot_word_alignment_model_factory.py @@ -38,7 +38,7 @@ def create_alignment_model( create_thot_word_alignment_model(self._config.thot.word_alignment_model_type, self._direct_model_path), create_thot_word_alignment_model(self._config.thot.word_alignment_model_type, self._inverse_model_path), ) - model.heuristic = self._config.thot.word_alignment_heuristic + model.heuristic = self._config.thot_align.word_alignment_heuristic return model @property diff --git a/machine/jobs/engine_build_job.py b/machine/jobs/translation_engine_build_job.py similarity index 66% rename from machine/jobs/engine_build_job.py rename to machine/jobs/translation_engine_build_job.py index 07a3ec2..68fdfe0 100644 --- a/machine/jobs/engine_build_job.py +++ b/machine/jobs/translation_engine_build_job.py @@ -5,15 +5,15 @@ from ..corpora.parallel_text_corpus import ParallelTextCorpus from ..utils.phased_progress_reporter import PhasedProgressReporter from ..utils.progress_status import ProgressStatus -from .shared_file_service import SharedFileService +from .translation_file_service import TranslationFileService logger = logging.getLogger(__name__) -class EngineBuildJob(ABC): - def __init__(self, config: Any, shared_file_service: SharedFileService) -> None: +class TranslationEngineBuildJob(ABC): + def __init__(self, config: Any, translation_file_service: TranslationFileService) -> None: self._config = config - self._shared_file_service = shared_file_service + self._translation_file_service = translation_file_service self._train_corpus_size = -1 self._confidence = -1 @@ -25,32 +25,32 @@ def run( if check_canceled is not None: check_canceled() - self.start_job() - self.init_corpus() + self._start_job() + self._init_corpus() progress_reporter = self._get_progress_reporter(progress) if self._parallel_corpus_size == 0: - self.respond_to_no_training_corpus() + 
self._respond_to_no_training_corpus() else: - self.train_model(progress_reporter, check_canceled) + self._train_model(progress_reporter, check_canceled) if check_canceled is not None: check_canceled() logger.info("Pretranslating segments") - self.batch_inference(progress_reporter, check_canceled) + self._batch_inference(progress_reporter, check_canceled) - self.save_model() + self._save_model() return self._train_corpus_size, self._confidence @abstractmethod - def start_job(self) -> None: ... + def _start_job(self) -> None: ... - def init_corpus(self) -> None: + def _init_corpus(self) -> None: logger.info("Downloading data files") if "_source_corpus" not in self.__dict__: - self._source_corpus = self._shared_file_service.create_source_corpus() - self._target_corpus = self._shared_file_service.create_target_corpus() + self._source_corpus = self._translation_file_service.create_source_corpus() + self._target_corpus = self._translation_file_service.create_target_corpus() self._parallel_corpus: ParallelTextCorpus = self._source_corpus.align_rows(self._target_corpus) self._parallel_corpus_size = self._parallel_corpus.count(include_empty=False) @@ -60,21 +60,21 @@ def _get_progress_reporter( ) -> PhasedProgressReporter: ... @abstractmethod - def respond_to_no_training_corpus(self) -> None: ... + def _respond_to_no_training_corpus(self) -> None: ... @abstractmethod - def train_model( + def _train_model( self, progress_reporter: PhasedProgressReporter, check_canceled: Optional[Callable[[], None]], ) -> None: ... @abstractmethod - def batch_inference( + def _batch_inference( self, progress_reporter: PhasedProgressReporter, check_canceled: Optional[Callable[[], None]], ) -> None: ... @abstractmethod - def save_model(self) -> None: ... + def _save_model(self) -> None: ... 
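The rename from EngineBuildJob to TranslationEngineBuildJob also makes the template-method shape easier to see: run() fixes the sequence (_start_job, _init_corpus, then either _train_model or _respond_to_no_training_corpus, then _batch_inference and _save_model), and subclasses such as SmtEngineBuildJob and NmtEngineBuildJob only fill in the hooks. A hypothetical minimal subclass, purely to illustrate the contract (NoOpEngineBuildJob is not part of this patch, and actually calling run() still requires a real config and TranslationFileService):

    from typing import Callable, Optional

    from machine.jobs.translation_engine_build_job import TranslationEngineBuildJob
    from machine.utils.phased_progress_reporter import Phase, PhasedProgressReporter
    from machine.utils.progress_status import ProgressStatus


    class NoOpEngineBuildJob(TranslationEngineBuildJob):
        def _start_job(self) -> None:
            pass  # a real job initializes its model factory here

        def _get_progress_reporter(
            self, progress: Optional[Callable[[ProgressStatus], None]]
        ) -> PhasedProgressReporter:
            phases = [
                Phase(message="Training model", percentage=0.9),
                Phase(message="Pretranslating segments", percentage=0.1),
            ]
            return PhasedProgressReporter(progress, phases)

        def _respond_to_no_training_corpus(self) -> None:
            raise RuntimeError("No parallel corpus data found")

        def _train_model(self, progress_reporter, check_canceled) -> None:
            # a real job trains here; run() returns these two fields
            self._train_corpus_size = self._parallel_corpus_size
            self._confidence = 0.0

        def _batch_inference(self, progress_reporter, check_canceled) -> None:
            pass  # a real job streams pretranslations via self._translation_file_service

        def _save_model(self) -> None:
            pass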
diff --git a/machine/jobs/translation_file_service.py b/machine/jobs/translation_file_service.py
new file mode 100644
index 0000000..87b9c82
--- /dev/null
+++ b/machine/jobs/translation_file_service.py
@@ -0,0 +1,77 @@
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Any, Generator, Iterator, List, TypedDict, Union
+
+import json_stream
+
+from ..corpora.text_corpus import TextCorpus
+from ..corpora.text_file_text_corpus import TextFileTextCorpus
+from ..utils.context_managed_generator import ContextManagedGenerator
+from .shared_file_service_base import DictToJsonWriter, SharedFileServiceBase
+from .shared_file_service_factory import SharedFileServiceType, get_shared_file_service
+
+
+class PretranslationInfo(TypedDict):
+    corpusId: str  # noqa: N815
+    textId: str  # noqa: N815
+    refs: List[str]
+    translation: str
+
+
+class TranslationFileService:
+    def __init__(
+        self,
+        type: Union[str, SharedFileServiceType],
+        config: Any,
+        source_filename: str = "train.src.txt",
+        target_filename: str = "train.trg.txt",
+        source_pretranslate_filename: str = "pretranslate.src.json",
+        target_pretranslate_filename: str = "pretranslate.trg.json",
+    ) -> None:
+
+        self._source_filename = source_filename
+        self._target_filename = target_filename
+        self._source_pretranslate_filename = source_pretranslate_filename
+        self._target_pretranslate_filename = target_pretranslate_filename
+
+        self.shared_file_service: SharedFileServiceBase = get_shared_file_service(type, config)
+
+    def create_source_corpus(self) -> TextCorpus:
+        return TextFileTextCorpus(
+            self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{self._source_filename}")
+        )
+
+    def create_target_corpus(self) -> TextCorpus:
+        return TextFileTextCorpus(
+            self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{self._target_filename}")
+        )
+
+    def exists_source_corpus(self) -> bool:
+        return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{self._source_filename}")
+
+    def exists_target_corpus(self) -> bool:
+        return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{self._target_filename}")
+
+    def get_source_pretranslations(self) -> ContextManagedGenerator[PretranslationInfo, None, None]:
+        src_pretranslate_path = self.shared_file_service.download_file(
+            f"{self.shared_file_service.build_path}/{self._source_pretranslate_filename}"
+        )
+
+        def generator() -> Generator[PretranslationInfo, None, None]:
+            with src_pretranslate_path.open("r", encoding="utf-8-sig") as file:
+                for pi in json_stream.load(file):
+                    yield PretranslationInfo(
+                        corpusId=pi["corpusId"],
+                        textId=pi["textId"],
+                        refs=list(pi["refs"]),
+                        translation=pi["translation"],
+                    )
+
+        return ContextManagedGenerator(generator())
+
+    def save_model(self, model_path: Path, destination: str) -> None:
+        self.shared_file_service.upload_path(model_path, destination)
+
+    @contextmanager
+    def open_target_pretranslation_writer(self) -> Iterator[DictToJsonWriter]:
+        return self.shared_file_service.open_target_writer(self._target_pretranslate_filename)
diff --git a/machine/jobs/word_alignment_build_job.py b/machine/jobs/word_alignment_build_job.py
index 3395461..d5dee72 100644
--- a/machine/jobs/word_alignment_build_job.py
+++ b/machine/jobs/word_alignment_build_job.py
@@ -1,27 +1,64 @@
 import logging
 from contextlib import ExitStack
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Tuple
 
 from ..corpora.parallel_text_corpus import ParallelTextCorpus
 from ..utils.phased_progress_reporter import Phase, PhasedProgressReporter
 from ..utils.progress_status import ProgressStatus
-from .engine_build_job import EngineBuildJob
-from .shared_file_service import SharedFileService, WordAlignmentInfo
+from .word_alignment_file_service import WordAlignmentFileService, WordAlignmentInfo
 from .word_alignment_model_factory import WordAlignmentModelFactory
 
 logger = logging.getLogger(__name__)
 
 
-class WordAlignmentBuildJob(EngineBuildJob):
+class WordAlignmentBuildJob:
     def __init__(
         self,
         config: Any,
         word_alignment_model_factory: WordAlignmentModelFactory,
-        shared_file_service: SharedFileService,
+        word_alignment_file_service: WordAlignmentFileService,
     ) -> None:
         self._word_alignment_model_factory = word_alignment_model_factory
-        super().__init__(config, shared_file_service)
+        self._config = config
+        self._word_alignment_file_service = word_alignment_file_service
+        self._train_corpus_size = -1
+        self._confidence = -1.0
 
-    def start_job(self) -> None:
+    def run(
+        self,
+        progress: Optional[Callable[[ProgressStatus], None]] = None,
+        check_canceled: Optional[Callable[[], None]] = None,
+    ) -> Tuple[int, float]:
+        if check_canceled is not None:
+            check_canceled()
+
+        self._start_job()
+        self._init_corpus()
+        progress_reporter = self._get_progress_reporter(progress)
+
+        if self._parallel_corpus_size == 0:
+            self._respond_to_no_training_corpus()
+        else:
+            self._train_model(progress_reporter, check_canceled)
+
+        if check_canceled is not None:
+            check_canceled()
+
+        logger.info("Generating word alignments")
+        self._batch_inference(progress_reporter, check_canceled)
+
+        self._save_model()
+        return self._train_corpus_size, self._confidence
+
+    def _init_corpus(self) -> None:
+        logger.info("Downloading data files")
+        if "_source_corpus" not in self.__dict__:
+            self._source_corpus = self._word_alignment_file_service.create_source_corpus()
+            self._target_corpus = self._word_alignment_file_service.create_target_corpus()
+            self._parallel_corpus: ParallelTextCorpus = self._source_corpus.align_rows(self._target_corpus)
+            self._parallel_corpus_size = self._parallel_corpus.count(include_empty=False)
+
+    def _start_job(self) -> None:
         self._word_alignment_model_factory.init()
         self._tokenizer = self._word_alignment_model_factory.create_tokenizer()
         logger.info(f"Tokenizer: {type(self._tokenizer).__name__}")
@@ -33,10 +70,10 @@ def _get_progress_reporter(self, progress: Optional[Callable[[ProgressStatus], N
         ]
         return PhasedProgressReporter(progress, phases)
 
-    def respond_to_no_training_corpus(self) -> None:
+    def _respond_to_no_training_corpus(self) -> None:
         raise RuntimeError("No parallel corpus data found")
 
-    def train_model(
+    def _train_model(
         self,
         progress_reporter: PhasedProgressReporter,
         check_canceled: Optional[Callable[[], None]],
@@ -53,7 +90,7 @@ def train_model(
         if check_canceled is not None:
             check_canceled()
 
-    def batch_inference(
+    def _batch_inference(
         self,
         progress_reporter: PhasedProgressReporter,
         check_canceled: Optional[Callable[[], None]],
@@ -63,8 +100,8 @@
 
         with ExitStack() as stack:
             phase_progress = stack.enter_context(progress_reporter.start_next_phase())
            alignment_model = stack.enter_context(self._word_alignment_model_factory.create_alignment_model())
-            self.init_corpus()
-            writer = stack.enter_context(self._shared_file_service.open_target_alignment_writer())
+            self._init_corpus()
+            writer = stack.enter_context(self._word_alignment_file_service.open_target_alignment_writer())
             current_inference_step = 0
             phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))
             batch_size = self._config["inference_batch_size"]
@@ -84,9 +121,9 @@ def batch_inference(
                     )  # type: ignore
                 )
 
-    def save_model(self) -> None:
+    def _save_model(self) -> None:
         logger.info("Saving model")
         model_path = self._word_alignment_model_factory.save_model()
-        self._shared_file_service.save_model(
+        self._word_alignment_file_service.save_model(
             model_path, f"builds/{self._config['build_id']}/model{''.join(model_path.suffixes)}"
         )
diff --git a/machine/jobs/word_alignment_file_service.py b/machine/jobs/word_alignment_file_service.py
new file mode 100644
index 0000000..694591c
--- /dev/null
+++ b/machine/jobs/word_alignment_file_service.py
@@ -0,0 +1,55 @@
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Any, Iterator, List, TypedDict, Union
+
+from ..corpora.text_corpus import TextCorpus
+from ..corpora.text_file_text_corpus import TextFileTextCorpus
+from .shared_file_service_base import DictToJsonWriter, SharedFileServiceBase
+from .shared_file_service_factory import SharedFileServiceType, get_shared_file_service
+
+
+class WordAlignmentInfo(TypedDict):
+    refs: List[str]
+    column_count: int
+    row_count: int
+    alignment: str
+
+
+class WordAlignmentFileService:
+    def __init__(
+        self,
+        type: Union[str, SharedFileServiceType],
+        config: Any,
+        source_filename: str = "train.src.txt",
+        target_filename: str = "train.trg.txt",
+        word_alignment_filename: str = "word_alignments.json",
+    ) -> None:
+
+        self._source_filename = source_filename
+        self._target_filename = target_filename
+        self._word_alignment_filename = word_alignment_filename
+
+        self.shared_file_service: SharedFileServiceBase = get_shared_file_service(type, config)
+
+    def create_source_corpus(self) -> TextCorpus:
+        return TextFileTextCorpus(
+            self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{self._source_filename}")
+        )
+
+    def create_target_corpus(self) -> TextCorpus:
+        return TextFileTextCorpus(
+            self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{self._target_filename}")
+        )
+
+    def exists_source_corpus(self) -> bool:
+        return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{self._source_filename}")
+
+    def exists_target_corpus(self) -> bool:
+        return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{self._target_filename}")
+
+    def save_model(self, model_path: Path, destination: str) -> None:
+        self.shared_file_service.upload_path(model_path, destination)
+
+    @contextmanager
+    def open_target_alignment_writer(self) -> Iterator[DictToJsonWriter]:
+        return self.shared_file_service.open_target_writer(self._word_alignment_filename)
diff --git a/tests/jobs/test_nmt_engine_build_job.py b/tests/jobs/test_nmt_engine_build_job.py
index 6f19278..a5e416d 100644
--- a/tests/jobs/test_nmt_engine_build_job.py
+++ b/tests/jobs/test_nmt_engine_build_job.py
@@ -10,7 +10,13 @@
 from machine.annotations import Range
 from machine.corpora import DictionaryTextCorpus
-from machine.jobs import DictToJsonWriter, NmtEngineBuildJob, NmtModelFactory, PretranslationInfo, SharedFileService
+from machine.jobs import (
+    DictToJsonWriter,
+    NmtEngineBuildJob,
+    NmtModelFactory,
+    PretranslationInfo,
+    TranslationFileService,
+)
 from machine.translation import (
     Phrase,
     Trainer,
@@ -30,7 +36,7 @@ def test_run(decoy: Decoy) -> None:
     pretranslations = json.loads(env.target_pretranslations)
     assert len(pretranslations)
== 1 assert pretranslations[0]["translation"] == "Please, I have booked a room." - decoy.verify(env.shared_file_service.save_model(Path("model.tar.gz"), "models/save-model.tar.gz"), times=1) + decoy.verify(env.translation_file_service.save_model(Path("model.tar.gz"), "models/save-model.tar.gz"), times=1) def test_cancel(decoy: Decoy) -> None: @@ -92,12 +98,12 @@ def __init__(self, decoy: Decoy) -> None: decoy.when(self.nmt_model_factory.create_engine()).then_return(self.engine) decoy.when(self.nmt_model_factory.save_model()).then_return(Path("model.tar.gz")) - self.shared_file_service = decoy.mock(cls=SharedFileService) - decoy.when(self.shared_file_service.create_source_corpus()).then_return(DictionaryTextCorpus()) - decoy.when(self.shared_file_service.create_target_corpus()).then_return(DictionaryTextCorpus()) - decoy.when(self.shared_file_service.exists_source_corpus()).then_return(True) - decoy.when(self.shared_file_service.exists_target_corpus()).then_return(True) - decoy.when(self.shared_file_service.get_source_pretranslations()).then_do( + self.translation_file_service = decoy.mock(cls=TranslationFileService) + decoy.when(self.translation_file_service.create_source_corpus()).then_return(DictionaryTextCorpus()) + decoy.when(self.translation_file_service.create_target_corpus()).then_return(DictionaryTextCorpus()) + decoy.when(self.translation_file_service.exists_source_corpus()).then_return(True) + decoy.when(self.translation_file_service.exists_target_corpus()).then_return(True) + decoy.when(self.translation_file_service.get_source_pretranslations()).then_do( lambda: ContextManagedGenerator( ( pi @@ -123,14 +129,14 @@ def open_target_pretranslation_writer(env: _TestEnvironment) -> Iterator[DictToJ file.write("\n]\n") env.target_pretranslations = file.getvalue() - decoy.when(self.shared_file_service.open_target_pretranslation_writer()).then_do( + decoy.when(self.translation_file_service.open_target_pretranslation_writer()).then_do( lambda: open_target_pretranslation_writer(self) ) self.job = NmtEngineBuildJob( MockSettings({"src_lang": "es", "trg_lang": "en", "save_model": "save-model", "inference_batch_size": 100}), self.nmt_model_factory, - self.shared_file_service, + self.translation_file_service, ) diff --git a/tests/jobs/test_smt_engine_build_job.py b/tests/jobs/test_smt_engine_build_job.py index a6dbb90..3aafcfc 100644 --- a/tests/jobs/test_smt_engine_build_job.py +++ b/tests/jobs/test_smt_engine_build_job.py @@ -10,18 +10,19 @@ from machine.annotations import Range from machine.corpora import DictionaryTextCorpus, MemoryText, TextRow -from machine.jobs import DictToJsonWriter, PretranslationInfo, SharedFileService, SmtEngineBuildJob, SmtModelFactory +from machine.jobs import DictToJsonWriter, PretranslationInfo, SmtEngineBuildJob, SmtModelFactory +from machine.jobs.translation_file_service import TranslationFileService from machine.tokenization import WHITESPACE_DETOKENIZER, WHITESPACE_TOKENIZER from machine.translation import ( Phrase, Trainer, TrainStats, - TranslationEngine, TranslationResult, TranslationSources, Truecaser, WordAlignmentMatrix, ) +from machine.translation.translation_engine import TranslationEngine from machine.utils import CanceledError, ContextManagedGenerator @@ -33,7 +34,8 @@ def test_run(decoy: Decoy) -> None: assert len(pretranslations) == 1 assert pretranslations[0]["translation"] == "Please, I have booked a room." 
decoy.verify( - env.shared_file_service.save_model(matchers.Anything(), f"builds/{env.job._config.build_id}/model.zip"), times=1 + env.translation_file_service.save_model(matchers.Anything(), f"builds/{env.job._config.build_id}/model.zip"), + times=1, ) @@ -101,8 +103,8 @@ def __init__(self, decoy: Decoy) -> None: decoy.when(self.smt_model_factory.create_truecaser()).then_return(self.truecaser) decoy.when(self.smt_model_factory.save_model()).then_return(Path("model.zip")) - self.shared_file_service = decoy.mock(cls=SharedFileService) - decoy.when(self.shared_file_service.create_source_corpus()).then_return( + self.translation_file_service = decoy.mock(cls=TranslationFileService) + decoy.when(self.translation_file_service.create_source_corpus()).then_return( DictionaryTextCorpus( MemoryText( "text1", @@ -114,7 +116,7 @@ def __init__(self, decoy: Decoy) -> None: ) ) ) - decoy.when(self.shared_file_service.create_target_corpus()).then_return( + decoy.when(self.translation_file_service.create_target_corpus()).then_return( DictionaryTextCorpus( MemoryText( "text1", @@ -126,9 +128,9 @@ def __init__(self, decoy: Decoy) -> None: ) ) ) - decoy.when(self.shared_file_service.exists_source_corpus()).then_return(True) - decoy.when(self.shared_file_service.exists_target_corpus()).then_return(True) - decoy.when(self.shared_file_service.get_source_pretranslations()).then_do( + decoy.when(self.translation_file_service.exists_source_corpus()).then_return(True) + decoy.when(self.translation_file_service.exists_target_corpus()).then_return(True) + decoy.when(self.translation_file_service.get_source_pretranslations()).then_do( lambda: ContextManagedGenerator( ( pi @@ -154,14 +156,14 @@ def open_target_pretranslation_writer(env: _TestEnvironment) -> Iterator[DictToJ file.write("\n]\n") env.target_pretranslations = file.getvalue() - decoy.when(self.shared_file_service.open_target_pretranslation_writer()).then_do( + decoy.when(self.translation_file_service.open_target_pretranslation_writer()).then_do( lambda: open_target_pretranslation_writer(self) ) self.job = SmtEngineBuildJob( MockSettings({"build_id": "mybuild", "inference_batch_size": 100}), self.smt_model_factory, - self.shared_file_service, + self.translation_file_service, ) diff --git a/tests/jobs/test_word_alignment_build_job.py b/tests/jobs/test_word_alignment_build_job.py index e5e8d5d..ad4535e 100644 --- a/tests/jobs/test_word_alignment_build_job.py +++ b/tests/jobs/test_word_alignment_build_job.py @@ -9,17 +9,12 @@ from testutils.mock_settings import MockSettings from machine.corpora import DictionaryTextCorpus, MemoryText, TextRow -from machine.jobs import ( - DictToJsonWriter, - PretranslationInfo, - SharedFileService, - WordAlignmentBuildJob, - WordAlignmentModelFactory, -) +from machine.jobs import DictToJsonWriter, WordAlignmentBuildJob, WordAlignmentModelFactory +from machine.jobs.word_alignment_file_service import WordAlignmentFileService from machine.tokenization import WHITESPACE_TOKENIZER from machine.translation import Trainer, TrainStats, WordAlignmentMatrix from machine.translation.word_alignment_model import WordAlignmentModel -from machine.utils import CanceledError, ContextManagedGenerator +from machine.utils import CanceledError def test_run(decoy: Decoy) -> None: @@ -30,7 +25,8 @@ def test_run(decoy: Decoy) -> None: assert len(alignments) == 1 assert alignments[0]["alignment"] == "0-0 1-1 2-2" decoy.verify( - env.shared_file_service.save_model(matchers.Anything(), f"builds/{env.job._config.build_id}/model.zip"), times=1 + 
env.word_alignment_file_service.save_model(matchers.Anything(), f"builds/{env.job._config.build_id}/model.zip"), + times=1, ) @@ -68,8 +64,8 @@ def __init__(self, decoy: Decoy) -> None: decoy.when(self.word_alignment_model_factory.create_alignment_model()).then_return(self.model) decoy.when(self.word_alignment_model_factory.save_model()).then_return(Path("model.zip")) - self.shared_file_service = decoy.mock(cls=SharedFileService) - decoy.when(self.shared_file_service.create_source_corpus()).then_return( + self.word_alignment_file_service = decoy.mock(cls=WordAlignmentFileService) + decoy.when(self.word_alignment_file_service.create_source_corpus()).then_return( DictionaryTextCorpus( MemoryText( "text1", @@ -81,7 +77,7 @@ def __init__(self, decoy: Decoy) -> None: ) ) ) - decoy.when(self.shared_file_service.create_target_corpus()).then_return( + decoy.when(self.word_alignment_file_service.create_target_corpus()).then_return( DictionaryTextCorpus( MemoryText( "text1", @@ -93,23 +89,8 @@ def __init__(self, decoy: Decoy) -> None: ) ) ) - decoy.when(self.shared_file_service.exists_source_corpus()).then_return(True) - decoy.when(self.shared_file_service.exists_target_corpus()).then_return(True) - decoy.when(self.shared_file_service.get_source_pretranslations()).then_do( - lambda: ContextManagedGenerator( - ( - pi - for pi in [ - PretranslationInfo( - corpusId="corpus1", - textId="text1", - refs=["ref1"], - translation="Por favor, tengo reservada una habitaciĆ³n.", - ) - ] - ) - ) - ) + decoy.when(self.word_alignment_file_service.exists_source_corpus()).then_return(True) + decoy.when(self.word_alignment_file_service.exists_target_corpus()).then_return(True) self.alignment_json = "" @@ -121,14 +102,14 @@ def open_target_alignment_writer(env: _TestEnvironment) -> Iterator[DictToJsonWr file.write("\n]\n") env.alignment_json = file.getvalue() - decoy.when(self.shared_file_service.open_target_alignment_writer()).then_do( + decoy.when(self.word_alignment_file_service.open_target_alignment_writer()).then_do( lambda: open_target_alignment_writer(self) ) self.job = WordAlignmentBuildJob( MockSettings({"build_id": "mybuild", "inference_batch_size": 100}), self.word_alignment_model_factory, - self.shared_file_service, + self.word_alignment_file_service, )
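Taken together, the factory keeps call sites small: both file services accept either the SharedFileServiceType enum or its case-insensitive string name and delegate storage to the SharedFileServiceBase it returns. A quick sketch against the LOCAL backend (SimpleNamespace stands in for the real dynaconf settings; all paths are illustrative):

    from types import SimpleNamespace

    from machine.jobs.shared_file_service_factory import SharedFileServiceType, get_shared_file_service
    from machine.jobs.translation_file_service import TranslationFileService

    config = SimpleNamespace(
        data_dir="/tmp/machine",                 # local working directory
        build_id="build1",
        engine_id="engine1",
        shared_file_uri="/tmp/machine/shared/",  # trailing slashes are stripped by the service
        shared_file_folder="dev/",
    )

    svc = get_shared_file_service("local", config)  # same as passing SharedFileServiceType.LOCAL
    print(type(svc).__name__)  # LocalSharedFileService
    print(svc.build_path)      # builds/build1

    # The higher-level services construct their backend through the same factory,
    # as build_smt_engine.py does with TranslationFileService(SharedFileServiceType.CLEARML, SETTINGS):
    file_service = TranslationFileService(SharedFileServiceType.LOCAL, config)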