From 2bd1d723a06a146f8d9f40a15eca95fc0317ddfd Mon Sep 17 00:00:00 2001
From: John Lambert
Date: Fri, 16 Aug 2024 14:54:22 -0400
Subject: [PATCH] Initial refactor

Initial word alignment build job

Update tests

Update from reviewer comments

Small fixes

Reviewer comments - still need to add new word align folder template

one more fix
---
 .vscode/launch.json                           |  16 +++
 .vscode/settings.json                         |   8 +-
 machine/corpora/usfm_text_base.py             |   3 +-
 machine/jobs/__init__.py                      |  17 ++-
 machine/jobs/build_clearml_helper.py          | 117 +++++++++++++
 machine/jobs/build_nmt_engine.py              |   7 +-
 machine/jobs/build_smt_engine.py              | 119 +++-------
 machine/jobs/build_word_alignment_model.py    | 105 ++++++++++++
 machine/jobs/clearml_shared_file_service.py   |  48 +++----
 machine/jobs/local_shared_file_service.py     |   6 +-
 machine/jobs/nmt_engine_build_job.py          | 115 +++++++--------
 machine/jobs/settings.yaml                    |   6 +-
 machine/jobs/shared_file_service.py           | 117 ---------------
 machine/jobs/shared_file_service_base.py      |  77 ++++++++++
 machine/jobs/shared_file_service_factory.py   |  20 +++
 machine/jobs/smt_engine_build_job.py          |  80 +++++------
 machine/jobs/smt_model_factory.py             |  29 ++--
 machine/jobs/thot/__init__.py                 |   3 +
 machine/jobs/thot/thot_smt_model_factory.py   |  55 +------
 .../thot/thot_word_alignment_model_factory.py |  54 +++++++
 machine/jobs/translation_engine_build_job.py  |  88 ++++++++++++
 machine/jobs/translation_file_service.py      |  74 ++++++++++
 machine/jobs/word_alignment_build_job.py      | 135 ++++++++++++++++++
 machine/jobs/word_alignment_file_service.py   |  55 +++++++
 machine/jobs/word_alignment_model_factory.py  |  35 +++++
 machine/tokenization/__init__.py              |   3 +
 machine/tokenization/tokenizer_factory.py     |  54 +++++++
 .../thot_symmetrized_word_alignment_model.py  |   6 +-
 .../thot/thot_word_alignment_model_utils.py   |  29 ++--
 tests/jobs/test_nmt_engine_build_job.py       |  34 +++--
 tests/jobs/test_smt_engine_build_job.py       |  33 +++--
 tests/jobs/test_word_alignment_build_job.py   | 122 ++++++++++++++++
 .../thot/test_thot_smt_model_trainer.py       |  98 +------------
 .../test_thot_word_alignment_model_trainer.py |  83 +++++++++++
 .../thot/thot_model_trainer_helper.py         | 103 +++++++++++++
 35 files changed, 1398 insertions(+), 556 deletions(-)
 create mode 100644 machine/jobs/build_clearml_helper.py
 create mode 100644 machine/jobs/build_word_alignment_model.py
 delete mode 100644 machine/jobs/shared_file_service.py
 create mode 100644 machine/jobs/shared_file_service_base.py
 create mode 100644 machine/jobs/shared_file_service_factory.py
 create mode 100644 machine/jobs/thot/thot_word_alignment_model_factory.py
 create mode 100644 machine/jobs/translation_engine_build_job.py
 create mode 100644 machine/jobs/translation_file_service.py
 create mode 100644 machine/jobs/word_alignment_build_job.py
 create mode 100644 machine/jobs/word_alignment_file_service.py
 create mode 100644 machine/jobs/word_alignment_model_factory.py
 create mode 100644 machine/tokenization/tokenizer_factory.py
 create mode 100644 tests/jobs/test_word_alignment_build_job.py
 create mode 100644 tests/translation/thot/test_thot_word_alignment_model_trainer.py
 create mode 100644 tests/translation/thot/thot_model_trainer_helper.py

diff --git a/.vscode/launch.json b/.vscode/launch.json
index e43ba656..7ec3d766 100755
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -10,6 +10,9 @@
             "request": "launch",
             "program": "${file}",
             "console": "integratedTerminal",
+            "env": {
+                "PYTHONPATH": "${workspaceFolder}:${workspaceFolder}/tests"
+            },
             "justMyCode": true
         },
         {
@@ -64,6 +67,19 @@
                 "build1"
             ]
         },
+        {
+            "name": 
"build_word_alignment_model", + "type": "debugpy", + "request": "launch", + "module": "machine.jobs.build_word_alignment_model", + "justMyCode": false, + "args": [ + "--model-type", + "thot", + "--build-id", + "build1" + ] + }, { "name": "Python: Debug Tests", "type": "debugpy", diff --git a/.vscode/settings.json b/.vscode/settings.json index 5bbc1b5d..4498368c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,10 +1,14 @@ { "editor.formatOnSave": true, "editor.codeActionsOnSave": { - "source.organizeImports": "explicit" + "source.organizeImports": "explicit", }, "python.testing.unittestEnabled": false, "python.testing.pytestEnabled": true, + "python.analysis.extraPaths": [ + "tests" + ], + "python.analysis.importFormat": "relative", "[python]": { "editor.defaultFormatter": "ms-python.black-formatter", "editor.formatOnSave": true @@ -17,4 +21,4 @@ "python.analysis.extraPaths": [ "./tests" ] -} +} \ No newline at end of file diff --git a/machine/corpora/usfm_text_base.py b/machine/corpora/usfm_text_base.py index bae08fed..3c799f5e 100644 --- a/machine/corpora/usfm_text_base.py +++ b/machine/corpora/usfm_text_base.py @@ -2,11 +2,10 @@ from io import TextIOWrapper from typing import Generator, Iterable, List, Optional, Sequence -from machine.corpora.scripture_ref import ScriptureRef - from ..scripture.verse_ref import Versification from ..utils.string_utils import has_sentence_ending from .corpora_utils import gen +from .scripture_ref import ScriptureRef from .scripture_ref_usfm_parser_handler import ScriptureRefUsfmParserHandler, ScriptureTextType from .scripture_text import ScriptureText from .stream_container import StreamContainer diff --git a/machine/jobs/__init__.py b/machine/jobs/__init__.py index 0849ee00..727da68f 100644 --- a/machine/jobs/__init__.py +++ b/machine/jobs/__init__.py @@ -2,18 +2,27 @@ from .local_shared_file_service import LocalSharedFileService from .nmt_engine_build_job import NmtEngineBuildJob from .nmt_model_factory import NmtModelFactory -from .shared_file_service import PretranslationInfo, PretranslationWriter, SharedFileService +from .shared_file_service_base import DictToJsonWriter, SharedFileServiceBase from .smt_engine_build_job import SmtEngineBuildJob from .smt_model_factory import SmtModelFactory +from .translation_file_service import PretranslationInfo, TranslationFileService +from .word_alignment_build_job import WordAlignmentBuildJob +from .word_alignment_file_service import WordAlignmentFileService, WordAlignmentInfo +from .word_alignment_model_factory import WordAlignmentModelFactory __all__ = [ "ClearMLSharedFileService", "LocalSharedFileService", "NmtEngineBuildJob", "NmtModelFactory", - "PretranslationInfo", - "PretranslationWriter", - "SharedFileService", + "DictToJsonWriter", + "SharedFileServiceBase", "SmtEngineBuildJob", "SmtModelFactory", + "PretranslationInfo", + "TranslationFileService", + "WordAlignmentBuildJob", + "WordAlignmentFileService", + "WordAlignmentInfo", + "WordAlignmentModelFactory", ] diff --git a/machine/jobs/build_clearml_helper.py b/machine/jobs/build_clearml_helper.py new file mode 100644 index 00000000..c7373023 --- /dev/null +++ b/machine/jobs/build_clearml_helper.py @@ -0,0 +1,117 @@ +import json +import logging +import os +from datetime import datetime +from typing import Callable, Optional, Union, cast + +import aiohttp +from clearml import Task +from dynaconf.base import Settings + +from ..utils.canceled_error import CanceledError +from ..utils.progress_status import ProgressStatus +from 
.async_scheduler import AsyncScheduler + + +class ProgressInfo: + last_percent_completed: Union[int, None] = 0 + last_message: Union[str, None] = "" + last_progress_time: Union[datetime, None] = None + last_check_canceled_time: Union[datetime, None] = None + + +def get_clearml_check_canceled(progress_info: ProgressInfo, task: Task) -> Callable[[], None]: + + def clearml_check_canceled() -> None: + current_time = datetime.now() + if ( + progress_info.last_check_canceled_time is None + or (current_time - progress_info.last_check_canceled_time).seconds > 20 + ): + if task.get_status() == "stopped": + raise CanceledError + progress_info.last_check_canceled_time = current_time + + return clearml_check_canceled + + +def get_clearml_progress_caller( + progress_info: ProgressInfo, task: Task, scheduler: AsyncScheduler, logger: logging.Logger +) -> Callable[[ProgressStatus], None]: + def clearml_progress(progress_status: ProgressStatus) -> None: + percent_completed: Optional[int] = None + if progress_status.percent_completed is not None: + percent_completed = round(progress_status.percent_completed * 100) + message = progress_status.message + if percent_completed != progress_info.last_percent_completed or message != progress_info.last_message: + logger.info(f"{percent_completed}% - {message}") + current_time = datetime.now() + if ( + progress_info.last_progress_time is None + or (current_time - progress_info.last_progress_time).seconds > 1 + ): + new_runtime_props = task.data.runtime.copy() or {} # type: ignore + new_runtime_props["progress"] = str(percent_completed) + new_runtime_props["message"] = message + scheduler.schedule( + update_runtime_properties( + task.id, # type: ignore + task.session.host, + task.session.token, # type: ignore + create_runtime_properties(task, percent_completed, message), + ) + ) + progress_info.last_progress_time = current_time + progress_info.last_percent_completed = percent_completed + progress_info.last_message = message + + return clearml_progress + + +def get_local_progress_caller(progress_info: ProgressInfo, logger: logging.Logger) -> Callable[[ProgressStatus], None]: + + def local_progress(progress_status: ProgressStatus) -> None: + percent_completed: Optional[int] = None + if progress_status.percent_completed is not None: + percent_completed = round(progress_status.percent_completed * 100) + message = progress_status.message + if percent_completed != progress_info.last_percent_completed or message != progress_info.last_message: + logger.info(f"{percent_completed}% - {message}") + progress_info.last_percent_completed = percent_completed + progress_info.last_message = message + + return local_progress + + +def update_settings(settings: Settings, args: dict): + settings.update(args) + settings.model_type = cast(str, settings.model_type).lower() + if "build_options" in settings: + try: + build_options = json.loads(cast(str, settings.build_options)) + except ValueError as e: + raise ValueError("Build options could not be parsed: Invalid JSON") from e + except TypeError as e: + raise TypeError(f"Build options could not be parsed: {e}") from e + settings.update({settings.model_type: build_options}) + settings.data_dir = os.path.expanduser(cast(str, settings.data_dir)) + + +async def update_runtime_properties(task_id: str, base_url: str, token: str, runtime_props: dict) -> None: + async with aiohttp.ClientSession(base_url=base_url, headers={"Authorization": f"Bearer {token}"}) as session: + json = {"task": task_id, "runtime": runtime_props, "force": True} + async with 
session.post("/tasks.edit", json=json) as response: + response.raise_for_status() + + +def create_runtime_properties(task, percent_completed: Optional[int], message: Optional[str]) -> dict: + runtime_props = task.data.runtime.copy() or {} + if percent_completed is not None: + runtime_props["progress"] = str(percent_completed) + else: + del runtime_props["progress"] + if message is not None: + runtime_props["message"] = message + else: + del runtime_props["message"] + return runtime_props diff --git a/machine/jobs/build_nmt_engine.py b/machine/jobs/build_nmt_engine.py index 4adc565b..fe062074 100644 --- a/machine/jobs/build_nmt_engine.py +++ b/machine/jobs/build_nmt_engine.py @@ -8,10 +8,11 @@ from ..utils.canceled_error import CanceledError from ..utils.progress_status import ProgressStatus -from .clearml_shared_file_service import ClearMLSharedFileService from .config import SETTINGS from .nmt_engine_build_job import NmtEngineBuildJob from .nmt_model_factory import NmtModelFactory +from .shared_file_service_factory import SharedFileServiceType +from .translation_file_service import TranslationFileService # Setup logging logging.basicConfig( @@ -58,7 +59,7 @@ def clearml_progress(status: ProgressStatus) -> None: logger.info(f"Config: {SETTINGS.as_dict()}") - shared_file_service = ClearMLSharedFileService(SETTINGS) + translation_file_service = TranslationFileService(SharedFileServiceType.CLEARML, SETTINGS) nmt_model_factory: NmtModelFactory if model_type == "huggingface": from .huggingface.hugging_face_nmt_model_factory import HuggingFaceNmtModelFactory @@ -67,7 +68,7 @@ def clearml_progress(status: ProgressStatus) -> None: else: raise RuntimeError("The model type is invalid.") - job = NmtEngineBuildJob(SETTINGS, nmt_model_factory, shared_file_service) + job = NmtEngineBuildJob(SETTINGS, nmt_model_factory, translation_file_service) train_corpus_size = job.run(progress, check_canceled) if task is not None: task.get_logger().report_single_value(name="train_corpus_size", value=train_corpus_size) diff --git a/machine/jobs/build_smt_engine.py b/machine/jobs/build_smt_engine.py index 9db0b9a3..56a4caa0 100644 --- a/machine/jobs/build_smt_engine.py +++ b/machine/jobs/build_smt_engine.py @@ -1,20 +1,25 @@ import argparse -import json import logging -import os -from datetime import datetime -from typing import Callable, Optional, cast +from typing import Callable, Optional -import aiohttp from clearml import Task -from ..utils.canceled_error import CanceledError from ..utils.progress_status import ProgressStatus from .async_scheduler import AsyncScheduler -from .clearml_shared_file_service import ClearMLSharedFileService +from .build_clearml_helper import ( + ProgressInfo, + create_runtime_properties, + get_clearml_check_canceled, + get_clearml_progress_caller, + get_local_progress_caller, + update_runtime_properties, + update_settings, +) from .config import SETTINGS +from .shared_file_service_factory import SharedFileServiceType from .smt_engine_build_job import SmtEngineBuildJob from .smt_model_factory import SmtModelFactory +from .translation_file_service import TranslationFileService # Setup logging logging.basicConfig( @@ -25,118 +30,34 @@ logger = logging.getLogger(str(__package__) + ".build_smt_engine") -async def update_runtime_properties(task_id: str, base_url: str, token: str, runtime_props: dict) -> None: - async with aiohttp.ClientSession(base_url=base_url, headers={"Authorization": f"Bearer {token}"}) as session: - json = {"task": task_id, "runtime": runtime_props, "force": True} - 
async with session.post("/tasks.edit", json=json) as response: - response.raise_for_status() - - -def create_runtime_properties(task, percent_completed: Optional[int], message: Optional[str]) -> dict: - runtime_props = task.data.runtime.copy() or {} - if percent_completed is not None: - runtime_props["progress"] = str(percent_completed) - else: - del runtime_props["progress"] - if message is not None: - runtime_props["message"] = message - else: - del runtime_props["message"] - return runtime_props - - def run(args: dict) -> None: progress: Callable[[ProgressStatus], None] check_canceled: Optional[Callable[[], None]] = None task = None - last_percent_completed: Optional[int] = None - last_message: Optional[str] = None scheduler: Optional[AsyncScheduler] = None + progress_info = ProgressInfo() if args["clearml"]: task = Task.init() - scheduler = AsyncScheduler() - last_check_canceled_time: Optional[datetime] = None - - def clearml_check_canceled() -> None: - nonlocal last_check_canceled_time - current_time = datetime.now() - if last_check_canceled_time is None or (current_time - last_check_canceled_time).seconds > 20: - if task.get_status() == "stopped": - raise CanceledError - last_check_canceled_time = current_time - - check_canceled = clearml_check_canceled + check_canceled = get_clearml_check_canceled(progress_info, task) task.reload() - last_progress_time: Optional[datetime] = None - - def clearml_progress(progress_status: ProgressStatus) -> None: - nonlocal last_percent_completed - nonlocal last_message - nonlocal last_progress_time - percent_completed: Optional[int] = None - if progress_status.percent_completed is not None: - percent_completed = round(progress_status.percent_completed * 100) - message = progress_status.message - if percent_completed != last_percent_completed or message != last_message: - logger.info(f"{percent_completed}% - {message}") - current_time = datetime.now() - if last_progress_time is None or (current_time - last_progress_time).seconds > 1: - new_runtime_props = task.data.runtime.copy() or {} - new_runtime_props["progress"] = str(percent_completed) - new_runtime_props["message"] = message - scheduler.schedule( - update_runtime_properties( - task.id, - task.session.host, - task.session.token, - create_runtime_properties(task, percent_completed, message), - ) - ) - last_progress_time = current_time - last_percent_completed = percent_completed - last_message = message - - progress = clearml_progress - else: + progress = get_clearml_progress_caller(progress_info, task, scheduler, logger) - def local_progress(progress_status: ProgressStatus) -> None: - nonlocal last_percent_completed - nonlocal last_message - percent_completed: Optional[int] = None - if progress_status.percent_completed is not None: - percent_completed = round(progress_status.percent_completed * 100) - message = progress_status.message - if percent_completed != last_percent_completed or message != last_message: - logger.info(f"{percent_completed}% - {message}") - last_percent_completed = percent_completed - last_message = message - - progress = local_progress + else: + progress = get_local_progress_caller(ProgressInfo(), logger) try: logger.info("SMT Engine Build Job started") - - SETTINGS.update(args) - model_type = cast(str, SETTINGS.model_type).lower() - if "build_options" in SETTINGS: - try: - build_options = json.loads(cast(str, SETTINGS.build_options)) - except ValueError as e: - raise ValueError("Build options could not be parsed: Invalid JSON") from e - except TypeError as e: - raise 
TypeError(f"Build options could not be parsed: {e}") from e - SETTINGS.update({model_type: build_options}) - SETTINGS.data_dir = os.path.expanduser(cast(str, SETTINGS.data_dir)) + update_settings(SETTINGS, args) logger.info(f"Config: {SETTINGS.as_dict()}") - shared_file_service = ClearMLSharedFileService(SETTINGS) + shared_file_service = TranslationFileService(SharedFileServiceType.CLEARML, SETTINGS) smt_model_factory: SmtModelFactory - if model_type == "thot": + if SETTINGS.model_type == "thot": from .thot.thot_smt_model_factory import ThotSmtModelFactory smt_model_factory = ThotSmtModelFactory(SETTINGS) diff --git a/machine/jobs/build_word_alignment_model.py b/machine/jobs/build_word_alignment_model.py new file mode 100644 index 00000000..757fd841 --- /dev/null +++ b/machine/jobs/build_word_alignment_model.py @@ -0,0 +1,105 @@ +import argparse +import logging +from typing import Callable, Optional + +from clearml import Task + +from ..utils.progress_status import ProgressStatus +from .async_scheduler import AsyncScheduler +from .build_clearml_helper import ( + ProgressInfo, + create_runtime_properties, + get_clearml_check_canceled, + get_clearml_progress_caller, + get_local_progress_caller, + update_runtime_properties, + update_settings, +) +from .config import SETTINGS +from .shared_file_service_factory import SharedFileServiceType +from .thot.thot_word_alignment_model_factory import ThotWordAlignmentModelFactory +from .word_alignment_build_job import WordAlignmentBuildJob +from .word_alignment_file_service import WordAlignmentFileService +from .word_alignment_model_factory import WordAlignmentModelFactory + +# Setup logging +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + level=logging.INFO, +) + +logger = logging.getLogger(str(__package__) + ".build_word_alignment_model") + + +def run(args: dict) -> WordAlignmentBuildJob: + progress: Callable[[ProgressStatus], None] + check_canceled: Optional[Callable[[], None]] = None + task = None + scheduler: Optional[AsyncScheduler] = None + progress_info = ProgressInfo() + if args["clearml"]: + task = Task.init() + scheduler = AsyncScheduler() + + check_canceled = get_clearml_check_canceled(progress_info, task) + + task.reload() + + progress = get_clearml_progress_caller(progress_info, task, scheduler, logger) + + else: + progress = get_local_progress_caller(ProgressInfo(), logger) + + try: + logger.info("Word Alignment Build Job started") + update_settings(SETTINGS, args) + + logger.info(f"Config: {SETTINGS.as_dict()}") + + word_alignment_file_service = WordAlignmentFileService(SharedFileServiceType.CLEARML, SETTINGS) + word_alignment_model_factory: WordAlignmentModelFactory + if SETTINGS.model_type == "thot": + word_alignment_model_factory = ThotWordAlignmentModelFactory(SETTINGS) + else: + raise RuntimeError("The model type is invalid.") + + word_alignment_build_job = WordAlignmentBuildJob( + SETTINGS, word_alignment_model_factory, word_alignment_file_service + ) + train_corpus_size = word_alignment_build_job.run(progress, check_canceled) + if scheduler is not None and task is not None: + scheduler.schedule( + update_runtime_properties( + task.id, task.session.host, task.session.token, create_runtime_properties(task, 100, "Completed") + ) + ) + task.get_logger().report_single_value(name="train_corpus_size", value=train_corpus_size) + logger.info("Finished") + except Exception as e: + if task: + if task.get_status() == "stopped": + return word_alignment_build_job + else: + 
task.mark_failed(status_reason=type(e).__name__, status_message=str(e))
+        raise e
+    finally:
+        if scheduler is not None:
+            scheduler.stop()
+    return word_alignment_build_job
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Trains a word alignment model.")
+    parser.add_argument("--model-type", required=True, type=str, help="Model type")
+    parser.add_argument("--build-id", required=True, type=str, help="Build id")
+    parser.add_argument("--clearml", default=False, action="store_true", help="Initializes a ClearML task")
+    parser.add_argument("--build-options", default=None, type=str, help="Build configurations")
+    args = parser.parse_args()
+
+    input_args = {k: v for k, v in vars(args).items() if v is not None}
+
+    run(input_args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/machine/jobs/clearml_shared_file_service.py b/machine/jobs/clearml_shared_file_service.py
index 07bc9e44..aeeaafff 100644
--- a/machine/jobs/clearml_shared_file_service.py
+++ b/machine/jobs/clearml_shared_file_service.py
@@ -5,59 +5,45 @@
 
 from clearml import StorageManager
 
-from .shared_file_service import SharedFileService
+from .shared_file_service_base import SharedFileServiceBase
 
 logger = logging.getLogger(__name__)
 
 
-class ClearMLSharedFileService(SharedFileService):
-    def _download_file(self, path: str) -> Path:
-        uri = f"{self._shared_file_uri}/{self._shared_file_folder}/{path}"
+class ClearMLSharedFileService(SharedFileServiceBase):
+    def download_file(self, path: str) -> Path:
         local_folder = str(self._data_dir)
-        file_path = try_n_times(lambda: StorageManager.download_file(uri, local_folder, skip_zero_size_check=True))
+        file_path = try_n_times(
+            lambda: StorageManager.download_file(self._get_uri(path), local_folder, skip_zero_size_check=True)
+        )
         if file_path is None:
-            raise RuntimeError(f"Failed to download file: {uri}")
+            raise RuntimeError(f"Failed to download file: {self._get_uri(path)}")
         return Path(file_path)
 
     def _download_folder(self, path: str) -> Path:
-        uri = f"{self._shared_file_uri}/{self._shared_file_folder}/{path}"
         local_folder = str(self._data_dir)
-        folder_path = try_n_times(lambda: StorageManager.download_folder(uri, local_folder))
+        folder_path = try_n_times(lambda: StorageManager.download_folder(self._get_uri(path), local_folder))
         if folder_path is None:
-            raise RuntimeError(f"Failed to download folder: {uri}")
+            raise RuntimeError(f"Failed to download folder: {self._get_uri(path)}")
         return Path(folder_path) / path
 
     def _exists_file(self, path: str) -> bool:
-        uri = f"{self._shared_file_uri}/{self._shared_file_folder}/{path}"
-        return try_n_times(lambda: StorageManager.exists_file(uri))  # type: ignore
+        return try_n_times(lambda: StorageManager.exists_file(self._get_uri(path)))  # type: ignore
 
     def _upload_file(self, path: str, local_file_path: Path) -> None:
-        final_destination = try_n_times(
-            lambda: StorageManager.upload_file(
-                str(local_file_path), f"{self._shared_file_uri}/{self._shared_file_folder}/{path}"
-            )
-        )
+        final_destination = try_n_times(lambda: StorageManager.upload_file(str(local_file_path), self._get_uri(path)))
         if final_destination is None:
-            logger.error(
-                (
-                    f"Failed to upload file {str(local_file_path)} "
-                    f"to {self._shared_file_uri}/{self._shared_file_folder}/{path}."
- ) - ) + logger.error((f"Failed to upload file {str(local_file_path)} " f"to {self._get_uri(path)}.")) def _upload_folder(self, path: str, local_folder_path: Path) -> None: final_destination = try_n_times( - lambda: StorageManager.upload_folder( - str(local_folder_path), f"{self._shared_file_uri}/{self._shared_file_folder}/{path}" - ) + lambda: StorageManager.upload_folder(str(local_folder_path), self._get_uri(path)) ) if final_destination is None: - logger.error( - ( - f"Failed to upload folder {str(local_folder_path)} " - f"to {self._shared_file_uri}/{self._shared_file_folder}/{path}." - ) - ) + logger.error((f"Failed to upload folder {str(local_folder_path)} " f"to {self._get_uri(path)}.")) + + def _get_uri(self, path: str) -> str: + return f"{self._shared_file_uri}/{self._shared_file_folder}/{path}" def try_n_times(func: Callable, n=10): diff --git a/machine/jobs/local_shared_file_service.py b/machine/jobs/local_shared_file_service.py index 77553e68..a0d5fd8f 100644 --- a/machine/jobs/local_shared_file_service.py +++ b/machine/jobs/local_shared_file_service.py @@ -2,13 +2,13 @@ import shutil from pathlib import Path -from .shared_file_service import SharedFileService +from .shared_file_service_base import SharedFileServiceBase logger = logging.getLogger(__name__) -class LocalSharedFileService(SharedFileService): - def _download_file(self, path: str) -> Path: +class LocalSharedFileService(SharedFileServiceBase): + def download_file(self, path: str) -> Path: src_path = self._get_path(path) dst_path = self._data_dir / self._shared_file_folder / path dst_path.parent.mkdir(parents=True, exist_ok=True) diff --git a/machine/jobs/nmt_engine_build_job.py b/machine/jobs/nmt_engine_build_job.py index c1643112..060b6f07 100644 --- a/machine/jobs/nmt_engine_build_job.py +++ b/machine/jobs/nmt_engine_build_job.py @@ -1,92 +1,95 @@ import logging from contextlib import ExitStack -from typing import Any, Callable, Optional, Sequence +from typing import Any, Callable, Optional, Sequence, Tuple from ..corpora.corpora_utils import batch from ..translation.translation_engine import TranslationEngine from ..utils.phased_progress_reporter import Phase, PhasedProgressReporter from ..utils.progress_status import ProgressStatus from .nmt_model_factory import NmtModelFactory -from .shared_file_service import PretranslationInfo, PretranslationWriter, SharedFileService +from .shared_file_service_base import DictToJsonWriter +from .translation_engine_build_job import TranslationEngineBuildJob +from .translation_file_service import PretranslationInfo, TranslationFileService logger = logging.getLogger(__name__) -class NmtEngineBuildJob: - def __init__(self, config: Any, nmt_model_factory: NmtModelFactory, shared_file_service: SharedFileService) -> None: - self._config = config +class NmtEngineBuildJob(TranslationEngineBuildJob): + def __init__( + self, config: Any, nmt_model_factory: NmtModelFactory, translation_file_service: TranslationFileService + ) -> None: self._nmt_model_factory = nmt_model_factory - self._shared_file_service = shared_file_service - - def run( - self, - progress: Optional[Callable[[ProgressStatus], None]] = None, - check_canceled: Optional[Callable[[], None]] = None, - ) -> int: - if check_canceled is not None: - check_canceled() - self._nmt_model_factory.init() + super().__init__(config, translation_file_service) - logger.info("Downloading data files") - source_corpus = self._shared_file_service.create_source_corpus() - target_corpus = self._shared_file_service.create_target_corpus() - 
parallel_corpus = source_corpus.align_rows(target_corpus) - parallel_corpus_size = parallel_corpus.count(include_empty=False) - - if parallel_corpus_size > 0: + def _get_progress_reporter(self, progress: Optional[Callable[[ProgressStatus], None]]) -> PhasedProgressReporter: + if self.parallel_corpus_size > 0: phases = [ Phase(message="Training NMT model", percentage=0.9), Phase(message="Pretranslating segments", percentage=0.1), ] else: phases = [Phase(message="Pretranslating segments", percentage=1.0)] - progress_reporter = PhasedProgressReporter(progress, phases) + return PhasedProgressReporter(progress, phases) - if parallel_corpus_size > 0: - if check_canceled is not None: - check_canceled() + def _respond_to_no_training_corpus(self) -> Tuple[int, float]: + logger.info("No matching entries in the source and target corpus - skipping training") + return 0, float("nan") - if self._nmt_model_factory.train_tokenizer: - logger.info("Training source tokenizer") - with self._nmt_model_factory.create_source_tokenizer_trainer(source_corpus) as source_tokenizer_trainer: - source_tokenizer_trainer.train(check_canceled=check_canceled) - source_tokenizer_trainer.save() + def _train_model( + self, + progress_reporter: PhasedProgressReporter, + check_canceled: Optional[Callable[[], None]], + ) -> Tuple[int, float]: + if check_canceled is not None: + check_canceled() - if check_canceled is not None: - check_canceled() + if self._nmt_model_factory.train_tokenizer: + logger.info("Training source tokenizer") + with self._nmt_model_factory.create_source_tokenizer_trainer( + self.source_corpus + ) as source_tokenizer_trainer: + source_tokenizer_trainer.train(check_canceled=check_canceled) + source_tokenizer_trainer.save() - logger.info("Training target tokenizer") - with self._nmt_model_factory.create_target_tokenizer_trainer(target_corpus) as target_tokenizer_trainer: - target_tokenizer_trainer.train(check_canceled=check_canceled) - target_tokenizer_trainer.save() + if check_canceled is not None: + check_canceled() - if check_canceled is not None: - check_canceled() + logger.info("Training target tokenizer") + with self._nmt_model_factory.create_target_tokenizer_trainer( + self.target_corpus + ) as target_tokenizer_trainer: + target_tokenizer_trainer.train(check_canceled=check_canceled) + target_tokenizer_trainer.save() - logger.info("Training NMT model") - with progress_reporter.start_next_phase() as phase_progress, self._nmt_model_factory.create_model_trainer( - parallel_corpus - ) as model_trainer: - model_trainer.train(progress=phase_progress, check_canceled=check_canceled) - model_trainer.save() - else: - logger.info("No matching entries in the source and target corpus - skipping training") + if check_canceled is not None: + check_canceled() - if check_canceled is not None: - check_canceled() + logger.info("Training NMT model") + with progress_reporter.start_next_phase() as phase_progress, self._nmt_model_factory.create_model_trainer( + self.parallel_corpus + ) as model_trainer: + model_trainer.train(progress=phase_progress, check_canceled=check_canceled) + model_trainer.save() + train_corpus_size = model_trainer.stats.train_corpus_size + return train_corpus_size, float("nan") + def _batch_inference( + self, + progress_reporter: PhasedProgressReporter, + check_canceled: Optional[Callable[[], None]], + ) -> None: logger.info("Pretranslating segments") - with self._shared_file_service.get_source_pretranslations() as src_pretranslations: + with self._translation_file_service.get_source_pretranslations() 
as src_pretranslations: inference_step_count = sum(1 for _ in src_pretranslations) with ExitStack() as stack: phase_progress = stack.enter_context(progress_reporter.start_next_phase()) engine = stack.enter_context(self._nmt_model_factory.create_engine()) - src_pretranslations = stack.enter_context(self._shared_file_service.get_source_pretranslations()) - writer = stack.enter_context(self._shared_file_service.open_target_pretranslation_writer()) + src_pretranslations = stack.enter_context(self._translation_file_service.get_source_pretranslations()) + writer = stack.enter_context(self._translation_file_service.open_target_pretranslation_writer()) current_inference_step = 0 phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count)) - batch_size = self._config["pretranslation_batch_size"] + batch_size = self._config["inference_batch_size"] for pi_batch in batch(src_pretranslations, batch_size): if check_canceled is not None: check_canceled() @@ -94,19 +97,19 @@ def run( current_inference_step += len(pi_batch) phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count)) + def _save_model(self) -> None: if "save_model" in self._config and self._config.save_model is not None: logger.info("Saving model") model_path = self._nmt_model_factory.save_model() - self._shared_file_service.save_model( + self._translation_file_service.save_model( model_path, f"models/{self._config.save_model + ''.join(model_path.suffixes)}" ) - return parallel_corpus_size def _translate_batch( engine: TranslationEngine, batch: Sequence[PretranslationInfo], - writer: PretranslationWriter, + writer: DictToJsonWriter, ) -> None: source_segments = [pi["translation"] for pi in batch] for i, result in enumerate(engine.translate_batch(source_segments)): diff --git a/machine/jobs/settings.yaml b/machine/jobs/settings.yaml index b787b6d6..a7d6853c 100644 --- a/machine/jobs/settings.yaml +++ b/machine/jobs/settings.yaml @@ -2,7 +2,7 @@ default: data_dir: ~/machine shared_file_uri: s3://silnlp/ shared_file_folder: production - pretranslation_batch_size: 1024 + inference_batch_size: 1024 huggingface: parent_model_name: facebook/nllb-200-distilled-1.3B train_params: @@ -28,6 +28,10 @@ default: thot: word_alignment_model_type: hmm tokenizer: latin + thot_align: + word_alignment_heuristic: grow-diag-final-and + model_type: hmm + tokenizer: latin development: shared_file_folder: dev huggingface: diff --git a/machine/jobs/shared_file_service.py b/machine/jobs/shared_file_service.py deleted file mode 100644 index 07fe6ec4..00000000 --- a/machine/jobs/shared_file_service.py +++ /dev/null @@ -1,117 +0,0 @@ -import json -from abc import ABC, abstractmethod -from contextlib import contextmanager -from pathlib import Path -from typing import Any, Generator, Iterator, List, TextIO, TypedDict - -import json_stream - -from ..corpora.text_corpus import TextCorpus -from ..corpora.text_file_text_corpus import TextFileTextCorpus -from ..utils.context_managed_generator import ContextManagedGenerator - - -class PretranslationInfo(TypedDict): - corpusId: str # noqa: N815 - textId: str # noqa: N815 - refs: List[str] - translation: str - - -class PretranslationWriter: - def __init__(self, file: TextIO) -> None: - self._file = file - self._first = True - - def write(self, pi: PretranslationInfo) -> None: - if not self._first: - self._file.write(",\n") - self._file.write(" " + json.dumps(pi)) - self._first = False - - -class SharedFileService(ABC): - def __init__(self, config: Any) -> None: - self._config = 
config - - def create_source_corpus(self) -> TextCorpus: - return TextFileTextCorpus(self._download_file(f"builds/{self._build_id}/train.src.txt")) - - def create_target_corpus(self) -> TextCorpus: - return TextFileTextCorpus(self._download_file(f"builds/{self._build_id}/train.trg.txt")) - - def exists_source_corpus(self) -> bool: - return self._exists_file(f"builds/{self._build_id}/train.src.txt") - - def exists_target_corpus(self) -> bool: - return self._exists_file(f"builds/{self._build_id}/train.trg.txt") - - def get_source_pretranslations(self) -> ContextManagedGenerator[PretranslationInfo, None, None]: - src_pretranslate_path = self._download_file(f"builds/{self._build_id}/pretranslate.src.json") - - def generator() -> Generator[PretranslationInfo, None, None]: - with src_pretranslate_path.open("r", encoding="utf-8-sig") as file: - for pi in json_stream.load(file): - yield PretranslationInfo( - corpusId=pi["corpusId"], - textId=pi["textId"], - refs=list(pi["refs"]), - translation=pi["translation"], - ) - - return ContextManagedGenerator(generator()) - - @contextmanager - def open_target_pretranslation_writer(self) -> Iterator[PretranslationWriter]: - build_id: str = self._config.build_id - build_dir = self._data_dir / self._shared_file_folder / "builds" / build_id - build_dir.mkdir(parents=True, exist_ok=True) - target_pretranslate_path = build_dir / "pretranslate.trg.json" - with target_pretranslate_path.open("w", encoding="utf-8", newline="\n") as file: - file.write("[\n") - yield PretranslationWriter(file) - file.write("\n]\n") - self._upload_file(f"builds/{self._build_id}/pretranslate.trg.json", target_pretranslate_path) - - def save_model(self, model_path: Path, destination: str) -> None: - if model_path.is_file(): - self._upload_file(destination, model_path) - else: - self._upload_folder(destination, model_path) - - @property - def _data_dir(self) -> Path: - return Path(self._config.data_dir) - - @property - def _build_id(self) -> str: - return self._config.build_id - - @property - def _engine_id(self) -> str: - return self._config.engine_id - - @property - def _shared_file_uri(self) -> str: - shared_file_uri: str = self._config.shared_file_uri - return shared_file_uri.rstrip("/") - - @property - def _shared_file_folder(self) -> str: - shared_file_folder: str = self._config.shared_file_folder - return shared_file_folder.rstrip("/") - - @abstractmethod - def _download_file(self, path: str) -> Path: ... - - @abstractmethod - def _download_folder(self, path: str) -> Path: ... - - @abstractmethod - def _exists_file(self, path: str) -> bool: ... - - @abstractmethod - def _upload_file(self, path: str, local_file_path: Path) -> None: ... - - @abstractmethod - def _upload_folder(self, path: str, local_folder_path: Path) -> None: ... 
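
The deleted `SharedFileService` above mixed generic blob-storage access with pretranslation-specific JSON handling; the files added below split those roles, leaving `SharedFileServiceBase` with only storage primitives plus the generalized `DictToJsonWriter` (formerly `PretranslationWriter`). A minimal sketch of how `DictToJsonWriter` streams records into a JSON array; the records here are invented sample data, not part of this patch:

```python
import io

from machine.jobs import DictToJsonWriter

# DictToJsonWriter emits only the comma-separated elements; the caller is
# responsible for the surrounding "[" and "]", as open_target_writer does below.
buffer = io.StringIO()
buffer.write("[\n")
writer = DictToJsonWriter(buffer)
writer.write({"corpusId": "c1", "textId": "MAT", "refs": ["MAT 1:1"], "translation": "sample"})
writer.write({"corpusId": "c1", "textId": "MAT", "refs": ["MAT 1:2"], "translation": "sample"})
buffer.write("\n]\n")
print(buffer.getvalue())  # one JSON object per element, written incrementally
```

Because each element is written as it is produced, a job can emit a large pretranslation or alignment file without ever holding the whole array in memory.
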
diff --git a/machine/jobs/shared_file_service_base.py b/machine/jobs/shared_file_service_base.py new file mode 100644 index 00000000..ac8ed831 --- /dev/null +++ b/machine/jobs/shared_file_service_base.py @@ -0,0 +1,77 @@ +import json +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Any, Iterator, TextIO + + +class DictToJsonWriter: + def __init__(self, file: TextIO) -> None: + self._file = file + self._first = True + + def write(self, pi: object) -> None: + if not self._first: + self._file.write(",\n") + self._file.write(" " + json.dumps(pi)) + self._first = False + + +class SharedFileServiceBase(ABC): + def __init__( + self, + config: Any, + ) -> None: + self._config = config + + def upload_path(self, path: Path, destination: str) -> None: + if path.is_file(): + self._upload_file(destination, path) + else: + self._upload_folder(destination, path) + + def open_target_writer(self, filename) -> Iterator[DictToJsonWriter]: + build_dir = self._data_dir / self._shared_file_folder / self.build_path + build_dir.mkdir(parents=True, exist_ok=True) + target_path = build_dir / filename + with target_path.open("w", encoding="utf-8", newline="\n") as file: + file.write("[\n") + yield DictToJsonWriter(file) + file.write("\n]\n") + self._upload_file(f"{self.build_path}/{filename}", target_path) + + @property + def build_path(self) -> str: + return f"builds/{self._config.build_id}" + + @property + def _data_dir(self) -> Path: + return Path(self._config.data_dir) + + @property + def _engine_id(self) -> str: + return self._config.engine_id + + @property + def _shared_file_uri(self) -> str: + shared_file_uri: str = self._config.shared_file_uri + return shared_file_uri.rstrip("/") + + @property + def _shared_file_folder(self) -> str: + shared_file_folder: str = self._config.shared_file_folder + return shared_file_folder.rstrip("/") + + @abstractmethod + def download_file(self, path: str) -> Path: ... + + @abstractmethod + def _download_folder(self, path: str) -> Path: ... + + @abstractmethod + def _exists_file(self, path: str) -> bool: ... + + @abstractmethod + def _upload_file(self, path: str, local_file_path: Path) -> None: ... + + @abstractmethod + def _upload_folder(self, path: str, local_folder_path: Path) -> None: ... 
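
The factory added next resolves a concrete service from either a `SharedFileServiceType` member or its string name. A sketch of the intended usage, with a hypothetical config object standing in for the dynaconf `SETTINGS` the jobs actually pass (any attribute-access object carrying the keys `SharedFileServiceBase` reads will do):

```python
from types import SimpleNamespace

from machine.jobs.shared_file_service_factory import SharedFileServiceType, get_shared_file_service

# Hypothetical config for illustration only; values mirror the defaults in
# machine/jobs/settings.yaml. The base service reads these lazily via properties.
config = SimpleNamespace(
    data_dir="~/machine",
    build_id="build1",
    engine_id="engine1",
    shared_file_uri="s3://silnlp/",
    shared_file_folder="production",
)

# The enum member and its (case-insensitive) string name resolve identically.
service = get_shared_file_service(SharedFileServiceType.LOCAL, config)
same = get_shared_file_service("local", config)
assert type(service) is type(same)
```

Note that `TranslationFileService` and `WordAlignmentFileService` (added later in this patch) wrap the returned service rather than subclassing it, which is why their constructors take the same `(type, config)` pair.
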
diff --git a/machine/jobs/shared_file_service_factory.py b/machine/jobs/shared_file_service_factory.py new file mode 100644 index 00000000..ebfce044 --- /dev/null +++ b/machine/jobs/shared_file_service_factory.py @@ -0,0 +1,20 @@ +from enum import IntEnum, auto +from typing import Any, Union + +from .clearml_shared_file_service import ClearMLSharedFileService +from .local_shared_file_service import LocalSharedFileService +from .shared_file_service_base import SharedFileServiceBase + + +class SharedFileServiceType(IntEnum): + LOCAL = auto() + CLEARML = auto() + + +def get_shared_file_service(type: Union[str, SharedFileServiceType], config: Any) -> SharedFileServiceBase: + if isinstance(type, str): + type = SharedFileServiceType[type.upper()] + if type == SharedFileServiceType.LOCAL: + return LocalSharedFileService(config) + elif type == SharedFileServiceType.CLEARML: + return ClearMLSharedFileService(config) diff --git a/machine/jobs/smt_engine_build_job.py b/machine/jobs/smt_engine_build_job.py index 33b80953..5852af29 100644 --- a/machine/jobs/smt_engine_build_job.py +++ b/machine/jobs/smt_engine_build_job.py @@ -3,55 +3,47 @@ from typing import Any, Callable, Optional, Sequence, Tuple from ..corpora.corpora_utils import batch +from ..tokenization.tokenizer_factory import create_detokenizer, create_tokenizer from ..translation.translation_engine import TranslationEngine from ..utils.phased_progress_reporter import Phase, PhasedProgressReporter from ..utils.progress_status import ProgressStatus -from .shared_file_service import PretranslationInfo, PretranslationWriter, SharedFileService +from .shared_file_service_base import DictToJsonWriter from .smt_model_factory import SmtModelFactory +from .translation_engine_build_job import TranslationEngineBuildJob +from .translation_file_service import PretranslationInfo, TranslationFileService logger = logging.getLogger(__name__) -class SmtEngineBuildJob: - def __init__(self, config: Any, smt_model_factory: SmtModelFactory, shared_file_service: SharedFileService) -> None: - self._config = config +class SmtEngineBuildJob(TranslationEngineBuildJob): + def __init__( + self, config: Any, smt_model_factory: SmtModelFactory, shared_file_service: TranslationFileService + ) -> None: self._smt_model_factory = smt_model_factory - self._shared_file_service = shared_file_service - - def run( - self, - progress: Optional[Callable[[ProgressStatus], None]] = None, - check_canceled: Optional[Callable[[], None]] = None, - ) -> Tuple[int, float]: - if check_canceled is not None: - check_canceled() - self._smt_model_factory.init() - tokenizer = self._smt_model_factory.create_tokenizer() - logger.info(f"Tokenizer: {type(tokenizer).__name__}") - - logger.info("Downloading data files") - source_corpus = self._shared_file_service.create_source_corpus() - target_corpus = self._shared_file_service.create_target_corpus() - parallel_corpus = source_corpus.align_rows(target_corpus) - parallel_corpus_size = parallel_corpus.count(include_empty=False) - if parallel_corpus_size == 0: - raise RuntimeError("No parallel corpus data found") + self._tokenizer = create_tokenizer(config.thot.tokenizer) + logger.info(f"Tokenizer: {type(self._tokenizer).__name__}") + super().__init__(config, shared_file_service) - with self._shared_file_service.get_source_pretranslations() as src_pretranslations: - inference_step_count = sum(1 for _ in src_pretranslations) + def _get_progress_reporter(self, progress: Optional[Callable[[ProgressStatus], None]]) -> PhasedProgressReporter: phases = [ 
Phase(message="Training SMT model", percentage=0.85), Phase(message="Training truecaser", percentage=0.05), Phase(message="Pretranslating segments", percentage=0.1), ] - progress_reporter = PhasedProgressReporter(progress, phases) + return PhasedProgressReporter(progress, phases) - if check_canceled is not None: - check_canceled() + def _respond_to_no_training_corpus(self) -> Tuple[int, float]: + raise RuntimeError("No parallel corpus data found") + + def _train_model( + self, + progress_reporter: PhasedProgressReporter, + check_canceled: Optional[Callable[[], None]], + ) -> Tuple[int, float]: with progress_reporter.start_next_phase() as phase_progress, self._smt_model_factory.create_model_trainer( - tokenizer, parallel_corpus + self._tokenizer, self.parallel_corpus ) as trainer: trainer.train(progress=phase_progress, check_canceled=check_canceled) trainer.save() @@ -59,24 +51,33 @@ def run( confidence = trainer.stats.metrics["bleu"] * 100 with progress_reporter.start_next_phase() as phase_progress, self._smt_model_factory.create_truecaser_trainer( - tokenizer, target_corpus + self._tokenizer, self.target_corpus ) as truecase_trainer: truecase_trainer.train(progress=phase_progress, check_canceled=check_canceled) truecase_trainer.save() if check_canceled is not None: check_canceled() + return train_corpus_size, confidence + + def _batch_inference( + self, + progress_reporter: PhasedProgressReporter, + check_canceled: Optional[Callable[[], None]], + ) -> None: + with self._translation_file_service.get_source_pretranslations() as src_pretranslations: + inference_step_count = sum(1 for _ in src_pretranslations) with ExitStack() as stack: - detokenizer = self._smt_model_factory.create_detokenizer() + detokenizer = create_detokenizer(self._config.thot.tokenizer) truecaser = self._smt_model_factory.create_truecaser() phase_progress = stack.enter_context(progress_reporter.start_next_phase()) - engine = stack.enter_context(self._smt_model_factory.create_engine(tokenizer, detokenizer, truecaser)) - src_pretranslations = stack.enter_context(self._shared_file_service.get_source_pretranslations()) - writer = stack.enter_context(self._shared_file_service.open_target_pretranslation_writer()) + engine = stack.enter_context(self._smt_model_factory.create_engine(self._tokenizer, detokenizer, truecaser)) + src_pretranslations = stack.enter_context(self._translation_file_service.get_source_pretranslations()) + writer = stack.enter_context(self._translation_file_service.open_target_pretranslation_writer()) current_inference_step = 0 phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count)) - batch_size = self._config["pretranslation_batch_size"] + batch_size = self._config["inference_batch_size"] for pi_batch in batch(src_pretranslations, batch_size): if check_canceled is not None: check_canceled() @@ -84,19 +85,18 @@ def run( current_inference_step += len(pi_batch) phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count)) + def _save_model(self) -> None: logger.info("Saving model") model_path = self._smt_model_factory.save_model() - self._shared_file_service.save_model( + self._translation_file_service.save_model( model_path, f"builds/{self._config['build_id']}/model{''.join(model_path.suffixes)}" ) - return train_corpus_size, confidence - def _translate_batch( engine: TranslationEngine, batch: Sequence[PretranslationInfo], - writer: PretranslationWriter, + writer: DictToJsonWriter, ) -> None: source_segments = [pi["translation"] for pi in batch] for 
i, result in enumerate(engine.translate_batch(source_segments)): diff --git a/machine/jobs/smt_model_factory.py b/machine/jobs/smt_model_factory.py index 8b228404..9c08ae4a 100644 --- a/machine/jobs/smt_model_factory.py +++ b/machine/jobs/smt_model_factory.py @@ -1,5 +1,7 @@ +import shutil from abc import ABC, abstractmethod from pathlib import Path +from typing import Any, Optional from ..corpora.parallel_text_corpus import ParallelTextCorpus from ..corpora.text_corpus import TextCorpus @@ -11,21 +13,21 @@ class SmtModelFactory(ABC): - @abstractmethod - def init(self) -> None: ... - - @abstractmethod - def create_tokenizer(self) -> Tokenizer[str, int, str]: ... + def __init__(self, config: Any) -> None: + self._config = config - @abstractmethod - def create_detokenizer(self) -> Detokenizer[str, str]: ... + def init(self) -> None: + pass @abstractmethod def create_model_trainer(self, tokenizer: Tokenizer[str, int, str], corpus: ParallelTextCorpus) -> Trainer: ... @abstractmethod def create_engine( - self, tokenizer: Tokenizer[str, int, str], detokenizer: Detokenizer[str, str], truecaser: Truecaser + self, + tokenizer: Tokenizer[str, int, str], + detokenizer: Detokenizer[str, str], + truecaser: Optional[Truecaser] = None, ) -> TranslationEngine: ... @abstractmethod @@ -34,5 +36,12 @@ def create_truecaser_trainer(self, tokenizer: Tokenizer[str, int, str], target_c @abstractmethod def create_truecaser(self) -> Truecaser: ... - @abstractmethod - def save_model(self) -> Path: ... + def save_model(self) -> Path: + tar_file_basename = Path( + self._config.data_dir, self._config.shared_file_folder, "builds", self._config.build_id, "model" + ) + return Path(shutil.make_archive(str(tar_file_basename), "gztar", self._model_dir)) + + @property + def _model_dir(self) -> Path: + return Path(self._config.data_dir, self._config.shared_file_folder, "builds", self._config.build_id, "model") diff --git a/machine/jobs/thot/__init__.py b/machine/jobs/thot/__init__.py index e69de29b..e8a5bc61 100644 --- a/machine/jobs/thot/__init__.py +++ b/machine/jobs/thot/__init__.py @@ -0,0 +1,3 @@ +import os + +_THOT_NEW_MODEL_DIRECTORY = os.path.join(os.path.dirname(__file__), "thot-new-model") diff --git a/machine/jobs/thot/thot_smt_model_factory.py b/machine/jobs/thot/thot_smt_model_factory.py index 6c4b64ca..cee3d2df 100644 --- a/machine/jobs/thot/thot_smt_model_factory.py +++ b/machine/jobs/thot/thot_smt_model_factory.py @@ -1,18 +1,10 @@ -import os import shutil -from pathlib import Path -from typing import Any +from typing import Optional from ...corpora.parallel_text_corpus import ParallelTextCorpus from ...corpora.text_corpus import TextCorpus from ...tokenization.detokenizer import Detokenizer -from ...tokenization.latin_word_detokenizer import LatinWordDetokenizer -from ...tokenization.latin_word_tokenizer import LatinWordTokenizer from ...tokenization.tokenizer import Tokenizer -from ...tokenization.whitespace_detokenizer import WhitespaceDetokenizer -from ...tokenization.whitespace_tokenizer import WhitespaceTokenizer -from ...tokenization.zwsp_word_detokenizer import ZwspWordDetokenizer -from ...tokenization.zwsp_word_tokenizer import ZwspWordTokenizer from ...translation.thot.thot_smt_model import ThotSmtModel from ...translation.thot.thot_smt_model_trainer import ThotSmtModelTrainer from ...translation.trainer import Trainer @@ -20,41 +12,13 @@ from ...translation.truecaser import Truecaser from ...translation.unigram_truecaser import UnigramTruecaser, UnigramTruecaserTrainer from ..smt_model_factory import 
SmtModelFactory - -_THOT_NEW_MODEL_DIRECTORY = os.path.join(os.path.dirname(__file__), "thot-new-model") - -_TOKENIZERS = ["latin", "whitespace", "zwsp"] +from . import _THOT_NEW_MODEL_DIRECTORY class ThotSmtModelFactory(SmtModelFactory): - def __init__(self, config: Any) -> None: - self._config = config - def init(self) -> None: shutil.copytree(_THOT_NEW_MODEL_DIRECTORY, self._model_dir, dirs_exist_ok=True) - def create_tokenizer(self) -> Tokenizer[str, int, str]: - name: str = self._config.thot.tokenizer - name = name.lower() - if name == "latin": - return LatinWordTokenizer() - if name == "whitespace": - return WhitespaceTokenizer() - if name == "zwsp": - return ZwspWordTokenizer() - raise RuntimeError(f"Unknown tokenizer: {name}. Available tokenizers are: {_TOKENIZERS}.") - - def create_detokenizer(self) -> Detokenizer[str, str]: - name: str = self._config.thot.tokenizer - name = name.lower() - if name == "latin": - return LatinWordDetokenizer() - if name == "whitespace": - return WhitespaceDetokenizer() - if name == "zwsp": - return ZwspWordDetokenizer() - raise RuntimeError(f"Unknown detokenizer: {name}. Available detokenizers are: {_TOKENIZERS}.") - def create_model_trainer(self, tokenizer: Tokenizer[str, int, str], corpus: ParallelTextCorpus) -> Trainer: return ThotSmtModelTrainer( word_alignment_model_type=self._config.thot.word_alignment_model_type, @@ -67,7 +31,10 @@ def create_model_trainer(self, tokenizer: Tokenizer[str, int, str], corpus: Para ) def create_engine( - self, tokenizer: Tokenizer[str, int, str], detokenizer: Detokenizer[str, str], truecaser: Truecaser + self, + tokenizer: Tokenizer[str, int, str], + detokenizer: Detokenizer[str, str], + truecaser: Optional[Truecaser] = None, ) -> TranslationEngine: return ThotSmtModel( word_alignment_model_type=self._config.thot.word_alignment_model_type, @@ -87,13 +54,3 @@ def create_truecaser_trainer(self, tokenizer: Tokenizer[str, int, str], target_c def create_truecaser(self) -> Truecaser: return UnigramTruecaser(model_path=self._model_dir / "unigram-casing-model.txt") - - def save_model(self) -> Path: - tar_file_basename = Path( - self._config.data_dir, self._config.shared_file_folder, "builds", self._config.build_id, "model" - ) - return Path(shutil.make_archive(str(tar_file_basename), "gztar", self._model_dir)) - - @property - def _model_dir(self) -> Path: - return Path(self._config.data_dir, self._config.shared_file_folder, "builds", self._config.build_id, "model") diff --git a/machine/jobs/thot/thot_word_alignment_model_factory.py b/machine/jobs/thot/thot_word_alignment_model_factory.py new file mode 100644 index 00000000..60105e60 --- /dev/null +++ b/machine/jobs/thot/thot_word_alignment_model_factory.py @@ -0,0 +1,54 @@ +import shutil +from pathlib import Path + +from ...corpora.parallel_text_corpus import ParallelTextCorpus +from ...tokenization.tokenizer import Tokenizer +from ...translation.symmetrized_word_alignment_model_trainer import SymmetrizedWordAlignmentModelTrainer +from ...translation.thot.thot_symmetrized_word_alignment_model import ThotSymmetrizedWordAlignmentModel +from ...translation.thot.thot_word_alignment_model_trainer import ThotWordAlignmentModelTrainer +from ...translation.thot.thot_word_alignment_model_utils import create_thot_word_alignment_model +from ...translation.trainer import Trainer +from ...translation.word_alignment_model import WordAlignmentModel +from ..word_alignment_model_factory import WordAlignmentModelFactory +from . 
import _THOT_NEW_MODEL_DIRECTORY + + +class ThotWordAlignmentModelFactory(WordAlignmentModelFactory): + def init(self) -> None: + shutil.copytree(_THOT_NEW_MODEL_DIRECTORY, self._model_dir, dirs_exist_ok=True) + + def create_model_trainer(self, tokenizer: Tokenizer[str, int, str], corpus: ParallelTextCorpus) -> Trainer: + (self._model_dir / "tm").mkdir(parents=True, exist_ok=True) + direct_trainer = ThotWordAlignmentModelTrainer( + self._config.thot_align.model_type, + corpus.lowercase(), + prefix_filename=self._direct_model_path, + source_tokenizer=tokenizer, + target_tokenizer=tokenizer, + ) + inverse_trainer = ThotWordAlignmentModelTrainer( + self._config.thot_align.model_type, + corpus.invert().lowercase(), + prefix_filename=self._inverse_model_path, + source_tokenizer=tokenizer, + target_tokenizer=tokenizer, + ) + return SymmetrizedWordAlignmentModelTrainer(direct_trainer, inverse_trainer) + + def create_alignment_model( + self, + ) -> WordAlignmentModel: + model = ThotSymmetrizedWordAlignmentModel( + create_thot_word_alignment_model(self._config.thot_align.model_type, self._direct_model_path), + create_thot_word_alignment_model(self._config.thot_align.model_type, self._inverse_model_path), + ) + model.heuristic = self._config.thot_align.word_alignment_heuristic + return model + + @property + def _direct_model_path(self) -> Path: + return self._model_dir / "tm" / "src_trg_invswm" + + @property + def _inverse_model_path(self) -> Path: + return self._model_dir / "tm" / "src_trg_swm" diff --git a/machine/jobs/translation_engine_build_job.py b/machine/jobs/translation_engine_build_job.py new file mode 100644 index 00000000..0293e69c --- /dev/null +++ b/machine/jobs/translation_engine_build_job.py @@ -0,0 +1,88 @@ +import logging +from abc import ABC, abstractmethod +from typing import Any, Callable, Optional, Tuple + +from ..corpora.parallel_text_corpus import ParallelTextCorpus +from ..corpora.text_corpus import TextCorpus +from ..utils.phased_progress_reporter import PhasedProgressReporter +from ..utils.progress_status import ProgressStatus +from .translation_file_service import TranslationFileService + +logger = logging.getLogger(__name__) + + +class TranslationEngineBuildJob(ABC): + def __init__(self, config: Any, translation_file_service: TranslationFileService) -> None: + self._config = config + self._translation_file_service = translation_file_service + + def run( + self, + progress: Optional[Callable[[ProgressStatus], None]] = None, + check_canceled: Optional[Callable[[], None]] = None, + ) -> Tuple[int, float]: + if check_canceled is not None: + check_canceled() + + progress_reporter = self._get_progress_reporter(progress) + + if self.parallel_corpus_size == 0: + train_corpus_size, confidence = self._respond_to_no_training_corpus() + else: + train_corpus_size, confidence = self._train_model(progress_reporter, check_canceled) + + if check_canceled is not None: + check_canceled() + + logger.info("Pretranslating segments") + self._batch_inference(progress_reporter, check_canceled) + + self._save_model() + return train_corpus_size, confidence + + @abstractmethod + def _get_progress_reporter( + self, progress: Optional[Callable[[ProgressStatus], None]] + ) -> PhasedProgressReporter: ... + + @abstractmethod + def _respond_to_no_training_corpus(self) -> Tuple[int, float]: ... + + @abstractmethod + def _train_model( + self, + progress_reporter: PhasedProgressReporter, + check_canceled: Optional[Callable[[], None]], + ) -> Tuple[int, float]: ... 
+ + @abstractmethod + def _batch_inference( + self, + progress_reporter: PhasedProgressReporter, + check_canceled: Optional[Callable[[], None]], + ) -> None: ... + + @abstractmethod + def _save_model(self) -> None: ... + + @property + def source_corpus(self) -> TextCorpus: + if "_source_corpus" not in self.__dict__: + self._source_corpus = self._translation_file_service.create_source_corpus() + return self._source_corpus + + @property + def target_corpus(self) -> TextCorpus: + if "_target_corpus" not in self.__dict__: + self._target_corpus = self._translation_file_service.create_target_corpus() + return self._target_corpus + + @property + def parallel_corpus(self) -> ParallelTextCorpus: + if "_parallel_corpus" not in self.__dict__: + self._parallel_corpus = self.source_corpus.align_rows(self.target_corpus) + return self._parallel_corpus + + @property + def parallel_corpus_size(self) -> int: + return self.parallel_corpus.count(include_empty=False) diff --git a/machine/jobs/translation_file_service.py b/machine/jobs/translation_file_service.py new file mode 100644 index 00000000..66a0bf0f --- /dev/null +++ b/machine/jobs/translation_file_service.py @@ -0,0 +1,74 @@ +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Generator, Iterator, List, TypedDict, Union + +import json_stream + +from ..corpora.text_corpus import TextCorpus +from ..corpora.text_file_text_corpus import TextFileTextCorpus +from ..utils.context_managed_generator import ContextManagedGenerator +from .shared_file_service_base import DictToJsonWriter, SharedFileServiceBase +from .shared_file_service_factory import SharedFileServiceType, get_shared_file_service + + +class PretranslationInfo(TypedDict): + corpusId: str # noqa: N815 + textId: str # noqa: N815 + refs: List[str] + translation: str + + +SOURCE_FILENAME = "train.src.txt" +TARGET_FILENAME = "train.trg.txt" +SOURCE_PRETRANSLATION_FILENAME = "pretranslate.src.json" +TARGET_PRETRANSLATION_FILENAME = "pretranslate.trg.json" + + +class TranslationFileService: + def __init__( + self, + type: Union[str, SharedFileServiceType], + config: Any, + ) -> None: + + self.shared_file_service: SharedFileServiceBase = get_shared_file_service(type, config) + + def create_source_corpus(self) -> TextCorpus: + return TextFileTextCorpus( + self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{SOURCE_FILENAME}") + ) + + def create_target_corpus(self) -> TextCorpus: + return TextFileTextCorpus( + self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{TARGET_FILENAME}") + ) + + def exists_source_corpus(self) -> bool: + return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{SOURCE_FILENAME}") + + def exists_target_corpus(self) -> bool: + return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{TARGET_FILENAME}") + + def get_source_pretranslations(self) -> ContextManagedGenerator[PretranslationInfo, None, None]: + src_pretranslate_path = self.shared_file_service.download_file( + f"{self.shared_file_service.build_path}/{SOURCE_PRETRANSLATION_FILENAME}" + ) + + def generator() -> Generator[PretranslationInfo, None, None]: + with src_pretranslate_path.open("r", encoding="utf-8-sig") as file: + for pi in json_stream.load(file): + yield PretranslationInfo( + corpusId=pi["corpusId"], + textId=pi["textId"], + refs=list(pi["refs"]), + translation=pi["translation"], + ) + + return ContextManagedGenerator(generator()) + + def save_model(self, model_path: 
Path, destination: str) -> None:
+        self.shared_file_service.upload_path(model_path, destination)
+
+    @contextmanager
+    def open_target_pretranslation_writer(self) -> Iterator[DictToJsonWriter]:
+        # Re-yield from the underlying writer so @contextmanager receives a generator.
+        with self.shared_file_service.open_target_writer(TARGET_PRETRANSLATION_FILENAME) as writer:
+            yield writer
diff --git a/machine/jobs/word_alignment_build_job.py b/machine/jobs/word_alignment_build_job.py
new file mode 100644
index 00000000..df47583f
--- /dev/null
+++ b/machine/jobs/word_alignment_build_job.py
@@ -0,0 +1,135 @@
+import logging
+from contextlib import ExitStack
+from typing import Any, Callable, Optional
+
+from ..corpora.parallel_text_corpus import ParallelTextCorpus
+from ..corpora.text_corpus import TextCorpus
+from ..tokenization.tokenizer_factory import create_tokenizer
+from ..utils.phased_progress_reporter import Phase, PhasedProgressReporter
+from ..utils.progress_status import ProgressStatus
+from .word_alignment_file_service import WordAlignmentFileService, WordAlignmentInfo
+from .word_alignment_model_factory import WordAlignmentModelFactory
+
+logger = logging.getLogger(__name__)
+
+
+class WordAlignmentBuildJob:
+    def __init__(
+        self,
+        config: Any,
+        word_alignment_model_factory: WordAlignmentModelFactory,
+        word_alignment_file_service: WordAlignmentFileService,
+    ) -> None:
+        self._word_alignment_model_factory = word_alignment_model_factory
+        self._word_alignment_model_factory.init()
+        self._config = config
+        self._tokenizer = create_tokenizer(self._config.thot_align.tokenizer)
+        self._word_alignment_file_service = word_alignment_file_service
+        self._train_corpus_size = -1
+
+    def run(
+        self,
+        progress: Optional[Callable[[ProgressStatus], None]] = None,
+        check_canceled: Optional[Callable[[], None]] = None,
+    ) -> int:
+        if check_canceled is not None:
+            check_canceled()
+
+        progress_reporter = self._get_progress_reporter(progress)
+
+        if self.parallel_corpus_size == 0:
+            raise RuntimeError("No parallel corpus data found")
+
+        train_corpus_size = self._train_model(progress_reporter, check_canceled)
+
+        if check_canceled is not None:
+            check_canceled()
+
+        logger.info("Generating alignments")
+        self._batch_inference(progress_reporter, check_canceled)
+
+        self._save_model()
+        return train_corpus_size
+
+    def _get_progress_reporter(self, progress: Optional[Callable[[ProgressStatus], None]]) -> PhasedProgressReporter:
+        phases = [
+            Phase(message="Training Word Alignment model", percentage=0.9),
+            Phase(message="Aligning segments", percentage=0.1),
+        ]
+        return PhasedProgressReporter(progress, phases)
+
+    def _train_model(
+        self,
+        progress_reporter: PhasedProgressReporter,
+        check_canceled: Optional[Callable[[], None]],
+    ) -> int:
+
+        with progress_reporter.start_next_phase() as phase_progress, self._word_alignment_model_factory.create_model_trainer(
+            self._tokenizer, self.parallel_corpus
+        ) as trainer:
+            trainer.train(progress=phase_progress, check_canceled=check_canceled)
+            trainer.save()
+            train_corpus_size = trainer.stats.train_corpus_size
+
+        if check_canceled is not None:
+            check_canceled()
+        return train_corpus_size
+
+    def _batch_inference(
+        self,
+        progress_reporter: PhasedProgressReporter,
+        check_canceled: Optional[Callable[[], None]],
+    ) -> None:
+        inference_step_count = self.parallel_corpus.count(include_empty=False)
+
+        with ExitStack() as stack:
+            phase_progress = stack.enter_context(progress_reporter.start_next_phase())
+            alignment_model = stack.enter_context(self._word_alignment_model_factory.create_alignment_model())
+            writer = 
stack.enter_context(self._word_alignment_file_service.open_target_alignment_writer()) + current_inference_step = 0 + phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count)) + batch_size = self._config["inference_batch_size"] + segment_batch = list(self.parallel_corpus.lowercase().tokenize(self._tokenizer).take(batch_size)) + if check_canceled is not None: + check_canceled() + alignments = alignment_model.align_batch(segment_batch) + if check_canceled is not None: + check_canceled() + for row, alignment in zip(self.parallel_corpus.get_rows(), alignments): + writer.write( + WordAlignmentInfo( + refs=[str(ref) for ref in row.source_refs], + column_count=alignment.column_count, + row_count=alignment.row_count, + alignment=str(alignment), + ) + ) + + def _save_model(self) -> None: + logger.info("Saving model") + model_path = self._word_alignment_model_factory.save_model() + self._word_alignment_file_service.save_model( + model_path, f"builds/{self._config['build_id']}/model{''.join(model_path.suffixes)}" + ) + + @property + def source_corpus(self) -> TextCorpus: + if "_source_corpus" not in self.__dict__: + self._source_corpus = self._word_alignment_file_service.create_source_corpus() + return self._source_corpus + + @property + def target_corpus(self) -> TextCorpus: + if "_target_corpus" not in self.__dict__: + self._target_corpus = self._word_alignment_file_service.create_target_corpus() + return self._target_corpus + + @property + def parallel_corpus(self) -> ParallelTextCorpus: + if "_parallel_corpus" not in self.__dict__: + self._parallel_corpus = self.source_corpus.align_rows(self.target_corpus) + return self._parallel_corpus + + @property + def parallel_corpus_size(self) -> int: + return self.parallel_corpus.count(include_empty=False) diff --git a/machine/jobs/word_alignment_file_service.py b/machine/jobs/word_alignment_file_service.py new file mode 100644 index 00000000..ecd264d6 --- /dev/null +++ b/machine/jobs/word_alignment_file_service.py @@ -0,0 +1,55 @@ +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Iterator, List, TypedDict, Union + +from ..corpora.text_corpus import TextCorpus +from ..corpora.text_file_text_corpus import TextFileTextCorpus +from .shared_file_service_base import DictToJsonWriter, SharedFileServiceBase +from .shared_file_service_factory import SharedFileServiceType, get_shared_file_service + + +class WordAlignmentInfo(TypedDict): + refs: List[str] + column_count: int + row_count: int + alignment: str + + +class WordAlignmentFileService: + def __init__( + self, + type: Union[str, SharedFileServiceType], + config: Any, + source_filename: str = "train.src.txt", + target_filename: str = "train.trg.txt", + word_alignment_filename: str = "word_alignments.json", + ) -> None: + + self._source_filename = source_filename + self._target_filename = target_filename + self._word_alignment_filename = word_alignment_filename + + self.shared_file_service: SharedFileServiceBase = get_shared_file_service(type, config) + + def create_source_corpus(self) -> TextCorpus: + return TextFileTextCorpus( + self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{self._source_filename}") + ) + + def create_target_corpus(self) -> TextCorpus: + return TextFileTextCorpus( + self.shared_file_service.download_file(f"{self.shared_file_service.build_path}/{self._target_filename}") + ) + + def exists_source_corpus(self) -> bool: + return 
self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{self._source_filename}")
+
+    def exists_target_corpus(self) -> bool:
+        return self.shared_file_service._exists_file(f"{self.shared_file_service.build_path}/{self._target_filename}")
+
+    def save_model(self, model_path: Path, destination: str) -> None:
+        self.shared_file_service.upload_path(model_path, destination)
+
+    @contextmanager
+    def open_target_alignment_writer(self) -> Iterator[DictToJsonWriter]:
+        # Re-yield from the underlying writer so @contextmanager receives a generator.
+        with self.shared_file_service.open_target_writer(self._word_alignment_filename) as writer:
+            yield writer
diff --git a/machine/jobs/word_alignment_model_factory.py b/machine/jobs/word_alignment_model_factory.py
new file mode 100644
index 00000000..a4cb107a
--- /dev/null
+++ b/machine/jobs/word_alignment_model_factory.py
@@ -0,0 +1,35 @@
+import shutil
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Any
+
+from ..corpora.parallel_text_corpus import ParallelTextCorpus
+from ..tokenization.tokenizer import Tokenizer
+from ..translation.trainer import Trainer
+from ..translation.word_alignment_model import WordAlignmentModel
+
+
+class WordAlignmentModelFactory(ABC):
+    def __init__(self, config: Any) -> None:
+        self._config = config
+
+    def init(self) -> None:
+        pass
+
+    @abstractmethod
+    def create_model_trainer(self, tokenizer: Tokenizer[str, int, str], corpus: ParallelTextCorpus) -> Trainer: ...
+
+    @abstractmethod
+    def create_alignment_model(
+        self,
+    ) -> WordAlignmentModel: ...
+
+    def save_model(self) -> Path:
+        tar_file_basename = Path(
+            self._config.data_dir, self._config.shared_file_folder, "builds", self._config.build_id, "model"
+        )
+        return Path(shutil.make_archive(str(tar_file_basename), "gztar", self._model_dir))
+
+    @property
+    def _model_dir(self) -> Path:
+        return Path(self._config.data_dir, self._config.shared_file_folder, "builds", self._config.build_id, "model")
diff --git a/machine/tokenization/__init__.py b/machine/tokenization/__init__.py
index aab2dbf3..5084b67c 100644
--- a/machine/tokenization/__init__.py
+++ b/machine/tokenization/__init__.py
@@ -9,6 +9,7 @@
 from .string_tokenizer import StringTokenizer
 from .tokenization_utils import get_ranges, split
 from .tokenizer import Tokenizer
+from .tokenizer_factory import create_detokenizer, create_tokenizer
 from .whitespace_detokenizer import WHITESPACE_DETOKENIZER, WhitespaceDetokenizer
 from .whitespace_tokenizer import WHITESPACE_TOKENIZER, WhitespaceTokenizer
 from .zwsp_word_detokenizer import ZwspWordDetokenizer
@@ -27,6 +28,8 @@
     "StringDetokenizer",
     "StringTokenizer",
     "Tokenizer",
+    "create_detokenizer",
+    "create_tokenizer",
     "WHITESPACE_DETOKENIZER",
     "WHITESPACE_TOKENIZER",
     "WhitespaceDetokenizer",
diff --git a/machine/tokenization/tokenizer_factory.py b/machine/tokenization/tokenizer_factory.py
new file mode 100644
index 00000000..b1d6eda9
--- /dev/null
+++ b/machine/tokenization/tokenizer_factory.py
@@ -0,0 +1,58 @@
+from enum import Enum, auto
+from typing import Union
+
+from . 
import (
+    Detokenizer,
+    LatinSentenceTokenizer,
+    LatinWordDetokenizer,
+    LatinWordTokenizer,
+    LineSegmentTokenizer,
+    NullTokenizer,
+    Tokenizer,
+    WhitespaceDetokenizer,
+    WhitespaceTokenizer,
+    ZwspWordDetokenizer,
+    ZwspWordTokenizer,
+)
+
+
+class TokenizerType(Enum):
+    NULL = auto()
+    LINE_SEGMENT = auto()
+    WHITESPACE = auto()
+    LATIN = auto()
+    LATIN_SENTENCE = auto()
+    ZWSP = auto()
+
+
+def create_tokenizer(type: Union[str, TokenizerType]) -> Tokenizer[str, int, str]:
+    if isinstance(type, str):
+        type = TokenizerType[type.upper()]
+    if type == TokenizerType.NULL:
+        return NullTokenizer()
+    if type == TokenizerType.LINE_SEGMENT:
+        return LineSegmentTokenizer()
+    if type == TokenizerType.WHITESPACE:
+        return WhitespaceTokenizer()
+    if type == TokenizerType.LATIN:
+        return LatinWordTokenizer()
+    if type == TokenizerType.LATIN_SENTENCE:
+        return LatinSentenceTokenizer()
+    if type == TokenizerType.ZWSP:
+        return ZwspWordTokenizer()
+    raise RuntimeError(
+        f"Unknown tokenizer: {type}. Available tokenizers are: "
+        "null, line_segment, whitespace, latin, latin_sentence, zwsp."
+    )
+
+
+def create_detokenizer(type: Union[str, TokenizerType]) -> Detokenizer[str, str]:
+    if isinstance(type, str):
+        type = TokenizerType[type.upper()]
+    if type == TokenizerType.WHITESPACE:
+        return WhitespaceDetokenizer()
+    if type == TokenizerType.LATIN:
+        return LatinWordDetokenizer()
+    if type == TokenizerType.ZWSP:
+        return ZwspWordDetokenizer()
+    raise RuntimeError(f"Unknown detokenizer: {type}. Available detokenizers are: whitespace, latin, zwsp.")
diff --git a/machine/translation/thot/thot_symmetrized_word_alignment_model.py b/machine/translation/thot/thot_symmetrized_word_alignment_model.py
index 6c3815f0..c1cbe284 100644
--- a/machine/translation/thot/thot_symmetrized_word_alignment_model.py
+++ b/machine/translation/thot/thot_symmetrized_word_alignment_model.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import List, Sequence, cast
+from typing import List, Sequence, Union, cast
 
 import thot.alignment as ta
 
@@ -38,7 +38,9 @@ def inverse_word_alignment_model(self) -> ThotWordAlignmentModel:
         return cast(ThotWordAlignmentModel, self._inverse_word_alignment_model)
 
     @heuristic.setter
-    def heuristic(self, value: SymmetrizationHeuristic) -> None:
+    def heuristic(self, value: Union[SymmetrizationHeuristic, str]) -> None:
+        if isinstance(value, str):
+            value = SymmetrizationHeuristic[value.upper().replace("-", "_")]
         self._heuristic = value
         self._aligner.heuristic = _convert_heuristic(self._heuristic)
 
diff --git a/machine/translation/thot/thot_word_alignment_model_utils.py b/machine/translation/thot/thot_word_alignment_model_utils.py
index b4a18884..82482a86 100644
--- a/machine/translation/thot/thot_word_alignment_model_utils.py
+++ b/machine/translation/thot/thot_word_alignment_model_utils.py
@@ -1,5 +1,6 @@
-from typing import Union
+from typing import Optional, Union
 
+from ...utils.typeshed import StrPath
 from .thot_fast_align_word_alignment_model import ThotFastAlignWordAlignmentModel
 from .thot_hmm_word_alignment_model import ThotHmmWordAlignmentModel
 from .thot_ibm1_word_alignment_model import ThotIbm1WordAlignmentModel
@@ -11,25 +12,31 @@
 from .thot_word_alignment_model_type import ThotWordAlignmentModelType
 
 
-def create_thot_word_alignment_model(type: Union[str, int]) -> ThotWordAlignmentModel:
+def create_thot_word_alignment_model(
+    type: Union[str, int], prefix_filename: Optional[StrPath] = None
+) -> ThotWordAlignmentModel:
     if isinstance(type, str):
         type = ThotWordAlignmentModelType[type.upper()]
     if type == ThotWordAlignmentModelType.FAST_ALIGN:
-        return ThotFastAlignWordAlignmentModel()
+        return 
ThotFastAlignWordAlignmentModel(prefix_filename) if type == ThotWordAlignmentModelType.IBM1: - return ThotIbm1WordAlignmentModel() + return ThotIbm1WordAlignmentModel(prefix_filename) if type == ThotWordAlignmentModelType.IBM2: - return ThotIbm2WordAlignmentModel() + return ThotIbm2WordAlignmentModel(prefix_filename) if type == ThotWordAlignmentModelType.HMM: - return ThotHmmWordAlignmentModel() + return ThotHmmWordAlignmentModel(prefix_filename) if type == ThotWordAlignmentModelType.IBM3: - return ThotIbm3WordAlignmentModel() + return ThotIbm3WordAlignmentModel(prefix_filename) if type == ThotWordAlignmentModelType.IBM4: - return ThotIbm4WordAlignmentModel() + return ThotIbm4WordAlignmentModel(prefix_filename) raise ValueError("The word alignment model type is unknown.") -def create_thot_symmetrized_word_alignment_model(type: Union[int, str]) -> ThotSymmetrizedWordAlignmentModel: - direct_model = create_thot_word_alignment_model(type) - inverse_model = create_thot_word_alignment_model(type) +def create_thot_symmetrized_word_alignment_model( + type: Union[int, str], + direct_prefix_filename: Optional[StrPath] = None, + inverse_prefix_filename: Optional[StrPath] = None, +) -> ThotSymmetrizedWordAlignmentModel: + direct_model = create_thot_word_alignment_model(type, direct_prefix_filename) + inverse_model = create_thot_word_alignment_model(type, inverse_prefix_filename) return ThotSymmetrizedWordAlignmentModel(direct_model, inverse_model) diff --git a/tests/jobs/test_nmt_engine_build_job.py b/tests/jobs/test_nmt_engine_build_job.py index b4697107..a5e416d6 100644 --- a/tests/jobs/test_nmt_engine_build_job.py +++ b/tests/jobs/test_nmt_engine_build_job.py @@ -10,7 +10,13 @@ from machine.annotations import Range from machine.corpora import DictionaryTextCorpus -from machine.jobs import NmtEngineBuildJob, NmtModelFactory, PretranslationInfo, PretranslationWriter, SharedFileService +from machine.jobs import ( + DictToJsonWriter, + NmtEngineBuildJob, + NmtModelFactory, + PretranslationInfo, + TranslationFileService, +) from machine.translation import ( Phrase, Trainer, @@ -30,7 +36,7 @@ def test_run(decoy: Decoy) -> None: pretranslations = json.loads(env.target_pretranslations) assert len(pretranslations) == 1 assert pretranslations[0]["translation"] == "Please, I have booked a room." 
- decoy.verify(env.shared_file_service.save_model(Path("model.tar.gz"), "models/save-model.tar.gz"), times=1) + decoy.verify(env.translation_file_service.save_model(Path("model.tar.gz"), "models/save-model.tar.gz"), times=1) def test_cancel(decoy: Decoy) -> None: @@ -92,12 +98,12 @@ def __init__(self, decoy: Decoy) -> None: decoy.when(self.nmt_model_factory.create_engine()).then_return(self.engine) decoy.when(self.nmt_model_factory.save_model()).then_return(Path("model.tar.gz")) - self.shared_file_service = decoy.mock(cls=SharedFileService) - decoy.when(self.shared_file_service.create_source_corpus()).then_return(DictionaryTextCorpus()) - decoy.when(self.shared_file_service.create_target_corpus()).then_return(DictionaryTextCorpus()) - decoy.when(self.shared_file_service.exists_source_corpus()).then_return(True) - decoy.when(self.shared_file_service.exists_target_corpus()).then_return(True) - decoy.when(self.shared_file_service.get_source_pretranslations()).then_do( + self.translation_file_service = decoy.mock(cls=TranslationFileService) + decoy.when(self.translation_file_service.create_source_corpus()).then_return(DictionaryTextCorpus()) + decoy.when(self.translation_file_service.create_target_corpus()).then_return(DictionaryTextCorpus()) + decoy.when(self.translation_file_service.exists_source_corpus()).then_return(True) + decoy.when(self.translation_file_service.exists_target_corpus()).then_return(True) + decoy.when(self.translation_file_service.get_source_pretranslations()).then_do( lambda: ContextManagedGenerator( ( pi @@ -116,23 +122,21 @@ def __init__(self, decoy: Decoy) -> None: self.target_pretranslations = "" @contextmanager - def open_target_pretranslation_writer(env: _TestEnvironment) -> Iterator[PretranslationWriter]: + def open_target_pretranslation_writer(env: _TestEnvironment) -> Iterator[DictToJsonWriter]: file = StringIO() file.write("[\n") - yield PretranslationWriter(file) + yield DictToJsonWriter(file) file.write("\n]\n") env.target_pretranslations = file.getvalue() - decoy.when(self.shared_file_service.open_target_pretranslation_writer()).then_do( + decoy.when(self.translation_file_service.open_target_pretranslation_writer()).then_do( lambda: open_target_pretranslation_writer(self) ) self.job = NmtEngineBuildJob( - MockSettings( - {"src_lang": "es", "trg_lang": "en", "save_model": "save-model", "pretranslation_batch_size": 100} - ), + MockSettings({"src_lang": "es", "trg_lang": "en", "save_model": "save-model", "inference_batch_size": 100}), self.nmt_model_factory, - self.shared_file_service, + self.translation_file_service, ) diff --git a/tests/jobs/test_smt_engine_build_job.py b/tests/jobs/test_smt_engine_build_job.py index ff77d49d..4acf4123 100644 --- a/tests/jobs/test_smt_engine_build_job.py +++ b/tests/jobs/test_smt_engine_build_job.py @@ -10,18 +10,18 @@ from machine.annotations import Range from machine.corpora import DictionaryTextCorpus, MemoryText, TextRow -from machine.jobs import PretranslationInfo, PretranslationWriter, SharedFileService, SmtEngineBuildJob, SmtModelFactory -from machine.tokenization import WHITESPACE_DETOKENIZER, WHITESPACE_TOKENIZER +from machine.jobs import DictToJsonWriter, PretranslationInfo, SmtEngineBuildJob, SmtModelFactory +from machine.jobs.translation_file_service import TranslationFileService from machine.translation import ( Phrase, Trainer, TrainStats, - TranslationEngine, TranslationResult, TranslationSources, Truecaser, WordAlignmentMatrix, ) +from machine.translation.translation_engine import TranslationEngine from 
machine.utils import CanceledError, ContextManagedGenerator @@ -33,7 +33,8 @@ def test_run(decoy: Decoy) -> None: assert len(pretranslations) == 1 assert pretranslations[0]["translation"] == "Please, I have booked a room." decoy.verify( - env.shared_file_service.save_model(matchers.Anything(), f"builds/{env.job._config.build_id}/model.zip"), times=1 + env.translation_file_service.save_model(matchers.Anything(), f"builds/{env.job._config.build_id}/model.zip"), + times=1, ) @@ -87,8 +88,6 @@ def __init__(self, decoy: Decoy) -> None: self.truecaser = decoy.mock(cls=Truecaser) self.smt_model_factory = decoy.mock(cls=SmtModelFactory) - decoy.when(self.smt_model_factory.create_tokenizer()).then_return(WHITESPACE_TOKENIZER) - decoy.when(self.smt_model_factory.create_detokenizer()).then_return(WHITESPACE_DETOKENIZER) decoy.when(self.smt_model_factory.create_model_trainer(matchers.Anything(), matchers.Anything())).then_return( self.model_trainer ) @@ -101,8 +100,8 @@ def __init__(self, decoy: Decoy) -> None: decoy.when(self.smt_model_factory.create_truecaser()).then_return(self.truecaser) decoy.when(self.smt_model_factory.save_model()).then_return(Path("model.zip")) - self.shared_file_service = decoy.mock(cls=SharedFileService) - decoy.when(self.shared_file_service.create_source_corpus()).then_return( + self.translation_file_service = decoy.mock(cls=TranslationFileService) + decoy.when(self.translation_file_service.create_source_corpus()).then_return( DictionaryTextCorpus( MemoryText( "text1", @@ -114,7 +113,7 @@ def __init__(self, decoy: Decoy) -> None: ) ) ) - decoy.when(self.shared_file_service.create_target_corpus()).then_return( + decoy.when(self.translation_file_service.create_target_corpus()).then_return( DictionaryTextCorpus( MemoryText( "text1", @@ -126,9 +125,9 @@ def __init__(self, decoy: Decoy) -> None: ) ) ) - decoy.when(self.shared_file_service.exists_source_corpus()).then_return(True) - decoy.when(self.shared_file_service.exists_target_corpus()).then_return(True) - decoy.when(self.shared_file_service.get_source_pretranslations()).then_do( + decoy.when(self.translation_file_service.exists_source_corpus()).then_return(True) + decoy.when(self.translation_file_service.exists_target_corpus()).then_return(True) + decoy.when(self.translation_file_service.get_source_pretranslations()).then_do( lambda: ContextManagedGenerator( ( pi @@ -147,21 +146,21 @@ def __init__(self, decoy: Decoy) -> None: self.target_pretranslations = "" @contextmanager - def open_target_pretranslation_writer(env: _TestEnvironment) -> Iterator[PretranslationWriter]: + def open_target_pretranslation_writer(env: _TestEnvironment) -> Iterator[DictToJsonWriter]: file = StringIO() file.write("[\n") - yield PretranslationWriter(file) + yield DictToJsonWriter(file) file.write("\n]\n") env.target_pretranslations = file.getvalue() - decoy.when(self.shared_file_service.open_target_pretranslation_writer()).then_do( + decoy.when(self.translation_file_service.open_target_pretranslation_writer()).then_do( lambda: open_target_pretranslation_writer(self) ) self.job = SmtEngineBuildJob( - MockSettings({"build_id": "mybuild", "pretranslation_batch_size": 100}), + MockSettings({"build_id": "mybuild", "inference_batch_size": 100, "thot": {"tokenizer": "latin"}}), self.smt_model_factory, - self.shared_file_service, + self.translation_file_service, ) diff --git a/tests/jobs/test_word_alignment_build_job.py b/tests/jobs/test_word_alignment_build_job.py new file mode 100644 index 00000000..7f2c9be2 --- /dev/null +++ 
b/tests/jobs/test_word_alignment_build_job.py @@ -0,0 +1,122 @@ +import json +from contextlib import contextmanager +from io import StringIO +from pathlib import Path +from typing import Iterator + +from decoy import Decoy, matchers +from pytest import raises +from testutils.mock_settings import MockSettings + +from machine.corpora import DictionaryTextCorpus, MemoryText, TextRow +from machine.jobs import DictToJsonWriter, WordAlignmentBuildJob, WordAlignmentModelFactory +from machine.jobs.word_alignment_file_service import WordAlignmentFileService +from machine.translation import Trainer, TrainStats, WordAlignmentMatrix +from machine.translation.word_alignment_model import WordAlignmentModel +from machine.utils import CanceledError + + +def test_run(decoy: Decoy) -> None: + env = _TestEnvironment(decoy) + env.job.run() + + alignments = json.loads(env.alignment_json) + assert len(alignments) == 1 + assert alignments[0]["alignment"] == "0-0 1-1 2-2" + decoy.verify( + env.word_alignment_file_service.save_model(matchers.Anything(), f"builds/{env.job._config.build_id}/model.zip"), + times=1, + ) + + +def test_cancel(decoy: Decoy) -> None: + env = _TestEnvironment(decoy) + checker = _CancellationChecker(3) + with raises(CanceledError): + env.job.run(check_canceled=checker.check_canceled) + + assert env.alignment_json == "" + + +class _TestEnvironment: + def __init__(self, decoy: Decoy) -> None: + self.model_trainer = decoy.mock(cls=Trainer) + decoy.when(self.model_trainer.__enter__()).then_return(self.model_trainer) + stats = TrainStats() + stats.train_corpus_size = 3 + stats.metrics["bleu"] = 30.0 + decoy.when(self.model_trainer.stats).then_return(stats) + + self.model = decoy.mock(cls=WordAlignmentModel) + decoy.when(self.model.__enter__()).then_return(self.model) + decoy.when(self.model.align_batch(matchers.Anything())).then_return( + [ + WordAlignmentMatrix.from_word_pairs(row_count=3, column_count=3, set_values=[(0, 0), (1, 1), (2, 2)]), + ] + ) + + self.word_alignment_model_factory = decoy.mock(cls=WordAlignmentModelFactory) + decoy.when( + self.word_alignment_model_factory.create_model_trainer(matchers.Anything(), matchers.Anything()) + ).then_return(self.model_trainer) + decoy.when(self.word_alignment_model_factory.create_alignment_model()).then_return(self.model) + decoy.when(self.word_alignment_model_factory.save_model()).then_return(Path("model.zip")) + + self.word_alignment_file_service = decoy.mock(cls=WordAlignmentFileService) + decoy.when(self.word_alignment_file_service.create_source_corpus()).then_return( + DictionaryTextCorpus( + MemoryText( + "text1", + [ + TextRow("text1", 1, ["¿Le importaría darnos las llaves de la habitación, por favor?"]), + TextRow("text1", 2, ["¿Le importaría cambiarme a otra habitación más tranquila?"]), + TextRow("text1", 3, ["Me parece que existe un problema."]), + ], + ) + ) + ) + decoy.when(self.word_alignment_file_service.create_target_corpus()).then_return( + DictionaryTextCorpus( + MemoryText( + "text1", + [ + TextRow("text1", 1, ["Would you mind giving us the room keys, please?"]), + TextRow("text1", 2, ["Would you mind moving me to another quieter room?"]), + TextRow("text1", 3, ["I think there is a problem."]), + ], + ) + ) + ) + decoy.when(self.word_alignment_file_service.exists_source_corpus()).then_return(True) + decoy.when(self.word_alignment_file_service.exists_target_corpus()).then_return(True) + + self.alignment_json = "" + + @contextmanager + def open_target_alignment_writer(env: _TestEnvironment) -> Iterator[DictToJsonWriter]: + file 
= StringIO()
+            file.write("[\n")
+            yield DictToJsonWriter(file)
+            file.write("\n]\n")
+            env.alignment_json = file.getvalue()
+
+        decoy.when(self.word_alignment_file_service.open_target_alignment_writer()).then_do(
+            lambda: open_target_alignment_writer(self)
+        )
+
+        self.job = WordAlignmentBuildJob(
+            MockSettings({"build_id": "mybuild", "inference_batch_size": 100, "thot_align": {"tokenizer": "latin"}}),
+            self.word_alignment_model_factory,
+            self.word_alignment_file_service,
+        )
+
+
+class _CancellationChecker:
+    def __init__(self, raise_count: int) -> None:
+        self._call_count = 0
+        self._raise_count = raise_count
+
+    def check_canceled(self) -> None:
+        self._call_count += 1
+        if self._call_count == self._raise_count:
+            raise CanceledError
diff --git a/tests/translation/thot/test_thot_smt_model_trainer.py b/tests/translation/thot/test_thot_smt_model_trainer.py
index 19933e06..5be7e4c5 100644
--- a/tests/translation/thot/test_thot_smt_model_trainer.py
+++ b/tests/translation/thot/test_thot_smt_model_trainer.py
@@ -1,92 +1,14 @@
 import os
 from tempfile import TemporaryDirectory
 
-from machine.corpora import (
-    AlignedWordPair,
-    AlignmentRow,
-    DictionaryAlignmentCorpus,
-    DictionaryTextCorpus,
-    MemoryAlignmentCollection,
-    MemoryText,
-    TextRow,
-)
+from translation.thot.thot_model_trainer_helper import get_empty_parallel_corpus, get_parallel_corpus
+
 from machine.translation.thot import ThotSmtModel, ThotSmtModelTrainer, ThotSmtParameters, ThotWordAlignmentModelType
 
 
 def test_train_non_empty_corpus() -> None:
     with TemporaryDirectory() as temp_dir:
-        source_corpus = DictionaryTextCorpus(
-            [
-                MemoryText(
-                    "text1",
-                    [
-                        _row(1, "¿ le importaría darnos las llaves de la habitación , por favor ?"),
-                        _row(
-                            2,
-                            "he hecho la reserva de una habitación tranquila doble con ||| teléfono ||| y televisión a "
-                            "nombre de rosario cabedo .",
-                        ),
-                        _row(3, "¿ le importaría cambiarme a otra habitación más tranquila ?"),
-                        _row(4, "por favor , tengo reservada una habitación ."),
-                        _row(5, "me parece que existe un problema ."),
-                        _row(
-                            6,
-                            "¿ tiene habitaciones libres con televisión , aire acondicionado y caja fuerte ?",
-                        ),
-                        _row(7, "¿ le importaría mostrarnos una habitación con televisión ?"),
-                        _row(8, "¿ tiene teléfono ?"),
-                        _row(9, "voy a marcharme el dos a las ocho de la noche ."),
-                        _row(10, "¿ cuánto cuesta una habitación individual por semana ?"),
-                    ],
-                )
-            ]
-        )
-
-        target_corpus = DictionaryTextCorpus(
-            [
-                MemoryText(
-                    "text1",
-                    [
-                        _row(1, "would you mind giving us the keys to the room , please ?"),
-                        _row(
-                            2,
-                            "i have made a reservation for a quiet , double room with a ||| telephone ||| and a tv for "
-                            "rosario cabedo .",
-                        ),
-                        _row(3, "would you mind moving me to a quieter room ?"),
-                        _row(4, "i have booked a room ."),
-                        _row(5, "i think that there is a problem ."),
-                        _row(6, "do you have any rooms with a tv , air conditioning and a safe available ?"),
-                        _row(7, "would you mind showing us a room with a tv ?"),
-                        _row(8, "does it have a telephone ?"),
-                        _row(9, "i am leaving on the second at eight in the evening ."),
-                        _row(10, "how much does a single room cost per week ?"),
-                    ],
-                )
-            ]
-        )
-
-        alignment_corpus = DictionaryAlignmentCorpus(
-            [
-                MemoryAlignmentCollection(
-                    "text1",
-                    [
-                        _alignment(1, AlignedWordPair(8, 9)),
-                        _alignment(2, AlignedWordPair(6, 10)),
-                        _alignment(3, AlignedWordPair(6, 8)),
-                        _alignment(4, AlignedWordPair(6, 4)),
-                        _alignment(5),
-                        _alignment(6, AlignedWordPair(2, 4)),
-                        _alignment(7, AlignedWordPair(5, 6)),
-                        _alignment(8),
-                        
_alignment(9),
-                        _alignment(10, AlignedWordPair(4, 5)),
-                    ],
-                )
-            ]
-        )
-
-        corpus = source_corpus.align_rows(target_corpus, alignment_corpus)
+        corpus = get_parallel_corpus()
 
         parameters = ThotSmtParameters(
             translation_model_filename_prefix=os.path.join(temp_dir, "tm", "src_trg"),
@@ -105,11 +27,7 @@ def test_train_non_empty_corpus() -> None:
 
 
 def test_train_empty_corpus() -> None:
     with TemporaryDirectory() as temp_dir:
-        source_corpus = DictionaryTextCorpus([])
-        target_corpus = DictionaryTextCorpus([])
-        alignment_corpus = DictionaryAlignmentCorpus([])
-
-        corpus = source_corpus.align_rows(target_corpus, alignment_corpus)
+        corpus = get_empty_parallel_corpus()
 
         parameters = ThotSmtParameters(
             translation_model_filename_prefix=os.path.join(temp_dir, "tm", "src_trg"),
@@ -124,11 +42,3 @@ def test_train_empty_corpus() -> None:
     with ThotSmtModel(ThotWordAlignmentModelType.HMM, parameters) as model:
         result = model.translate("una habitación individual por semana")
         assert result.translation == "una habitación individual por semana"
-
-
-def _row(row_ref: int, text: str) -> TextRow:
-    return TextRow("text1", row_ref, segment=[text])
-
-
-def _alignment(row_ref: int, *pairs: AlignedWordPair) -> AlignmentRow:
-    return AlignmentRow("text1", row_ref, aligned_word_pairs=pairs)
diff --git a/tests/translation/thot/test_thot_word_alignment_model_trainer.py b/tests/translation/thot/test_thot_word_alignment_model_trainer.py
new file mode 100644
index 00000000..d5b69c60
--- /dev/null
+++ b/tests/translation/thot/test_thot_word_alignment_model_trainer.py
@@ -0,0 +1,86 @@
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+from translation.thot.thot_model_trainer_helper import get_empty_parallel_corpus, get_parallel_corpus
+
+from machine.corpora.parallel_text_corpus import ParallelTextCorpus
+from machine.tokenization import StringTokenizer, WhitespaceTokenizer
+from machine.translation.symmetrized_word_alignment_model_trainer import SymmetrizedWordAlignmentModelTrainer
+from machine.translation.thot import ThotWordAlignmentModelTrainer
+from machine.translation.thot.thot_symmetrized_word_alignment_model import ThotSymmetrizedWordAlignmentModel
+from machine.translation.thot.thot_word_alignment_model_utils import create_thot_word_alignment_model
+from machine.translation.word_alignment_matrix import WordAlignmentMatrix
+
+
+def train_model(
+    corpus: ParallelTextCorpus,
+    direct_model_path: Path,
+    inverse_model_path: Path,
+    thot_word_alignment_model_type: str,
+    tokenizer: StringTokenizer,
+) -> None:
+    direct_trainer = ThotWordAlignmentModelTrainer(
+        thot_word_alignment_model_type,
+        corpus.lowercase(),
+        prefix_filename=direct_model_path,
+        source_tokenizer=tokenizer,
+        target_tokenizer=tokenizer,
+    )
+    inverse_trainer = ThotWordAlignmentModelTrainer(
+        thot_word_alignment_model_type,
+        corpus.invert().lowercase(),
+        prefix_filename=inverse_model_path,
+        source_tokenizer=tokenizer,
+        target_tokenizer=tokenizer,
+    )
+
+    with SymmetrizedWordAlignmentModelTrainer(direct_trainer, inverse_trainer) as trainer:
+        trainer.train(lambda status: print(f"{status.message}: {status.percent_completed:.2%}"))
+        trainer.save()
+
+
+def test_train_non_empty_corpus() -> None:
+    thot_word_alignment_model_type = "hmm"
+    tokenizer = WhitespaceTokenizer()
+    corpus = get_parallel_corpus()
+
+    with TemporaryDirectory() as temp_dir:
+        tmp_path = Path(temp_dir)
+        (tmp_path / "tm").mkdir()
+        direct_model_path = tmp_path / "tm" / "src_trg_invswm"
+        inverse_model_path = tmp_path / "tm" / "src_trg_swm"
+        
train_model(corpus, direct_model_path, inverse_model_path, thot_word_alignment_model_type, tokenizer)
+        with ThotSymmetrizedWordAlignmentModel(
+            create_thot_word_alignment_model(thot_word_alignment_model_type, direct_model_path),
+            create_thot_word_alignment_model(thot_word_alignment_model_type, inverse_model_path),
+        ) as model:
+            matrix = model.align(
+                list(tokenizer.tokenize("una habitación individual por semana")),
+                list(tokenizer.tokenize("a single room cost per week")),
+            )
+            assert matrix == WordAlignmentMatrix.from_word_pairs(5, 6, {(0, 2), (1, 2), (2, 3), (2, 4), (2, 5)})
+
+
+def test_train_empty_corpus() -> None:
+    thot_word_alignment_model_type = "hmm"
+    tokenizer = WhitespaceTokenizer()
+    corpus = get_empty_parallel_corpus()
+    with TemporaryDirectory() as temp_dir:
+        tmp_path = Path(temp_dir)
+        direct_model_path = tmp_path / "tm" / "src_trg_invswm"
+        inverse_model_path = tmp_path / "tm" / "src_trg_swm"
+        train_model(corpus, direct_model_path, inverse_model_path, thot_word_alignment_model_type, tokenizer)
+        with ThotSymmetrizedWordAlignmentModel(
+            create_thot_word_alignment_model(thot_word_alignment_model_type, direct_model_path),
+            create_thot_word_alignment_model(thot_word_alignment_model_type, inverse_model_path),
+        ) as model:
+            matrix = model.align(
+                list(tokenizer.tokenize("una habitación individual por semana")),
+                list(tokenizer.tokenize("a single room cost per week")),
+            )
+            assert matrix == WordAlignmentMatrix.from_word_pairs(5, 6, {(0, 0)})
+
+
+if __name__ == "__main__":
+    test_train_non_empty_corpus()
+    test_train_empty_corpus()
diff --git a/tests/translation/thot/thot_model_trainer_helper.py b/tests/translation/thot/thot_model_trainer_helper.py
new file mode 100644
index 00000000..6bb232a9
--- /dev/null
+++ b/tests/translation/thot/thot_model_trainer_helper.py
@@ -0,0 +1,103 @@
+from machine.corpora import (
+    AlignedWordPair,
+    AlignmentRow,
+    DictionaryAlignmentCorpus,
+    DictionaryTextCorpus,
+    MemoryAlignmentCollection,
+    MemoryText,
+    TextRow,
+)
+from machine.corpora.parallel_text_corpus import ParallelTextCorpus
+
+
+def get_parallel_corpus() -> ParallelTextCorpus:
+    source_corpus = DictionaryTextCorpus(
+        [
+            MemoryText(
+                "text1",
+                [
+                    _row(1, "¿ le importaría darnos las llaves de la habitación , por favor ?"),
+                    _row(
+                        2,
+                        "he hecho la reserva de una habitación tranquila doble con ||| teléfono ||| y televisión a "
+                        "nombre de rosario cabedo .",
+                    ),
+                    _row(3, "¿ le importaría cambiarme a otra habitación más tranquila ?"),
+                    _row(4, "por favor , tengo reservada una habitación ."),
+                    _row(5, "me parece que existe un problema ."),
+                    _row(
+                        6,
+                        "¿ tiene habitaciones libres con televisión , aire acondicionado y caja fuerte ?",
+                    ),
+                    _row(7, "¿ le importaría mostrarnos una habitación con televisión ?"),
+                    _row(8, "¿ tiene teléfono ?"),
+                    _row(9, "voy a marcharme el dos a las ocho de la noche ."),
+                    _row(10, "¿ cuánto cuesta una habitación individual por semana ?"),
+                ],
+            )
+        ]
+    )
+
+    target_corpus = DictionaryTextCorpus(
+        [
+            MemoryText(
+                "text1",
+                [
+                    _row(1, "would you mind giving us the keys to the room , please ?"),
+                    _row(
+                        2,
+                        "i have made a reservation for a quiet , double room with a ||| telephone ||| and a tv for "
+                        "rosario cabedo .",
+                    ),
+                    _row(3, "would you mind moving me to a quieter room ?"),
+                    _row(4, "i have booked a room ."),
+                    _row(5, "i think that there is a problem ."),
+                    _row(6, "do you have any rooms with a tv , air conditioning and a safe available ?"),
+                    _row(7, "would you mind showing us a room with a tv ?"),
+                    _row(8, "does it have a telephone ?"),
+                    _row(9, "i am leaving on 
the second at eight in the evening ."),
+                    _row(10, "how much does a single room cost per week ?"),
+                ],
+            )
+        ]
+    )
+
+    alignment_corpus = DictionaryAlignmentCorpus(
+        [
+            MemoryAlignmentCollection(
+                "text1",
+                [
+                    _alignment(1, AlignedWordPair(8, 9)),
+                    _alignment(2, AlignedWordPair(6, 10)),
+                    _alignment(3, AlignedWordPair(6, 8)),
+                    _alignment(4, AlignedWordPair(6, 4)),
+                    _alignment(5),
+                    _alignment(6, AlignedWordPair(2, 4)),
+                    _alignment(7, AlignedWordPair(5, 6)),
+                    _alignment(8),
+                    _alignment(9),
+                    _alignment(10, AlignedWordPair(4, 5)),
+                ],
+            )
+        ]
+    )
+
+    corpus = source_corpus.align_rows(target_corpus, alignment_corpus)
+    return corpus
+
+
+def get_empty_parallel_corpus() -> ParallelTextCorpus:
+    source_corpus = DictionaryTextCorpus([])
+    target_corpus = DictionaryTextCorpus([])
+    alignment_corpus = DictionaryAlignmentCorpus([])
+
+    corpus = source_corpus.align_rows(target_corpus, alignment_corpus)
+    return corpus
+
+
+def _row(row_ref: int, text: str) -> TextRow:
+    return TextRow("text1", row_ref, segment=[text])
+
+
+def _alignment(row_ref: int, *pairs: AlignedWordPair) -> AlignmentRow:
+    return AlignmentRow("text1", row_ref, aligned_word_pairs=pairs)
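
For reference, here is a minimal sketch of how the pieces added by this patch fit together when run against the local shared file service. It is illustrative, not part of the patch: the config keys mirror the ones WordAlignmentBuildJob and ThotWordAlignmentModelFactory read above, but the concrete values, the dynaconf-based settings object, and the SharedFileServiceType.LOCAL member are assumptions; real defaults come from machine/jobs/settings.yaml.

    from dynaconf import Dynaconf

    from machine.jobs import WordAlignmentBuildJob, WordAlignmentFileService
    from machine.jobs.shared_file_service_factory import SharedFileServiceType
    from machine.jobs.thot.thot_word_alignment_model_factory import ThotWordAlignmentModelFactory

    # Hypothetical settings; keys mirror those read in word_alignment_build_job.py
    # and thot_word_alignment_model_factory.py.
    config = Dynaconf()
    config.update(
        {
            "build_id": "build1",
            "data_dir": "~/machine",
            "shared_file_folder": "default",
            "inference_batch_size": 100,
            "thot_align": {
                "tokenizer": "latin",
                "model_type": "hmm",
                "word_alignment_heuristic": "grow-diag-final-and",
            },
        }
    )

    # Wire the job by hand (build_word_alignment_model.py presumably does the
    # equivalent for ClearML builds). The local share must already contain
    # train.src.txt and train.trg.txt under the build folder; SharedFileServiceType.LOCAL
    # is assumed from shared_file_service_factory.py.
    job = WordAlignmentBuildJob(
        config,
        ThotWordAlignmentModelFactory(config),
        WordAlignmentFileService(SharedFileServiceType.LOCAL, config),
    )
    train_corpus_size = job.run(progress=lambda status: print(status.message))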