From 18934739368397c358d2f9ddef078ee4685045fd Mon Sep 17 00:00:00 2001 From: John Lambert Date: Thu, 15 Aug 2024 16:28:45 -0400 Subject: [PATCH] Initial word alignment build job (now with files) --- machine/jobs/build_clearml_helper.py | 117 ++++++++++++++ machine/jobs/build_smt_engine.py | 112 ++------------ machine/jobs/build_word_alignment_model.py | 103 +++++++++++++ machine/jobs/engine_build_job.py | 10 +- machine/jobs/nmt_engine_build_job.py | 8 +- machine/jobs/settings.yaml | 3 +- machine/jobs/shared_file_service.py | 33 ++-- machine/jobs/smt_engine_build_job.py | 8 +- machine/jobs/smt_model_factory.py | 16 +- machine/jobs/thot/thot_model_factory.py | 57 +++++++ machine/jobs/thot/thot_smt_model_factory.py | 53 +------ .../thot/thot_word_alignment_model_factory.py | 49 ++++++ machine/jobs/word_alignment_build_job.py | 85 +++++++++++ machine/jobs/word_alignment_model_factory.py | 17 +++ .../thot_symmetrized_word_alignment_model.py | 6 +- .../thot/thot_word_alignment_model_utils.py | 29 ++-- tests/jobs/test_nmt_engine_build_job.py | 10 +- tests/jobs/test_smt_engine_build_job.py | 8 +- tests/jobs/test_word_alignment_build_job.py | 143 ++++++++++++++++++ .../thot/test_thot_smt_model_trainer.py | 98 +----------- .../test_thot_word_alignment_model_trainer.py | 72 +++++++++ .../thot/thot_model_trainer_helper.py | 103 +++++++++++++ 22 files changed, 837 insertions(+), 303 deletions(-) create mode 100644 machine/jobs/build_clearml_helper.py create mode 100644 machine/jobs/build_word_alignment_model.py create mode 100644 machine/jobs/thot/thot_model_factory.py create mode 100644 machine/jobs/thot/thot_word_alignment_model_factory.py create mode 100644 machine/jobs/word_alignment_build_job.py create mode 100644 machine/jobs/word_alignment_model_factory.py create mode 100644 tests/jobs/test_word_alignment_build_job.py create mode 100644 tests/translation/thot/test_thot_word_alignment_model_trainer.py create mode 100644 tests/translation/thot/thot_model_trainer_helper.py 
diff --git a/machine/jobs/build_clearml_helper.py b/machine/jobs/build_clearml_helper.py new file mode 100644 index 00000000..c7373023 --- /dev/null +++ b/machine/jobs/build_clearml_helper.py @@ -0,0 +1,117 @@ +import json +import logging +import os +from datetime import datetime +from typing import Callable, Optional, Union, cast + +import aiohttp +from clearml import Task +from dynaconf.base import Settings + +from ..utils.canceled_error import CanceledError +from ..utils.progress_status import ProgressStatus +from .async_scheduler import AsyncScheduler + + +class ProgressInfo: + last_percent_completed: Union[int, None] = 0 + last_message: Union[str, None] = "" + last_progress_time: Union[datetime, None] = None + last_check_canceled_time: Union[datetime, None] = None + + +def get_clearml_check_canceled(progress_info: ProgressInfo, task: Task) -> Callable[[], None]: + + def clearml_check_canceled() -> None: + current_time = datetime.now() + if ( + progress_info.last_check_canceled_time is None + or (current_time - progress_info.last_check_canceled_time).seconds > 20 + ): + if task.get_status() == "stopped": + raise CanceledError + progress_info.last_check_canceled_time = current_time + + return clearml_check_canceled + + +def get_clearml_progress_caller( + progress_info: ProgressInfo, task: Task, scheduler: AsyncScheduler, logger: logging.Logger +) -> Callable[[ProgressStatus], None]: + def clearml_progress(progress_status: ProgressStatus) -> None: + percent_completed: Optional[int] = None + if progress_status.percent_completed is not None: + percent_completed = round(progress_status.percent_completed * 100) + message = progress_status.message + if percent_completed != progress_info.last_percent_completed or message != progress_info.last_message: + logger.info(f"{percent_completed}% - {message}") + current_time = datetime.now() + if ( + progress_info.last_progress_time is None + or (current_time - progress_info.last_progress_time).seconds > 1 + ): + 
new_runtime_props = task.data.runtime.copy() or {} # type: ignore + new_runtime_props["progress"] = str(percent_completed) + new_runtime_props["message"] = message + scheduler.schedule( + update_runtime_properties( + task.id, # type: ignore + task.session.host, + task.session.token, # type: ignore + create_runtime_properties(task, percent_completed, message), + ) + ) + progress_info.last_progress_time = current_time + progress_info.last_percent_completed = percent_completed + progress_info.last_message = message + + return clearml_progress + + +def get_local_progress_caller(progress_info: ProgressInfo, logger: logging.Logger) -> Callable[[ProgressStatus], None]: + + def local_progress(progress_status: ProgressStatus) -> None: + percent_completed: Optional[int] = None + if progress_status.percent_completed is not None: + percent_completed = round(progress_status.percent_completed * 100) + message = progress_status.message + if percent_completed != progress_info.last_percent_completed or message != progress_info.last_message: + logger.info(f"{percent_completed}% - {message}") + progress_info.last_percent_completed = percent_completed + progress_info.last_message = message + + return local_progress + + +def update_settings(settings: Settings, args: dict): + settings.update(args) + settings.model_type = cast(str, settings.model_type).lower() + if "build_options" in settings: + try: + build_options = json.loads(cast(str, settings.build_options)) + except ValueError as e: + raise ValueError("Build options could not be parsed: Invalid JSON") from e + except TypeError as e: + raise TypeError(f"Build options could not be parsed: {e}") from e + settings.update({settings.model_type: build_options}) + settings.data_dir = os.path.expanduser(cast(str, settings.data_dir)) + + +async def update_runtime_properties(task_id: str, base_url: str, token: str, runtime_props: dict) -> None: + async with aiohttp.ClientSession(base_url=base_url, headers={"Authorization": f"Bearer {token}"}) 
as session: + json = {"task": task_id, "runtime": runtime_props, "force": True} + async with session.post("/tasks.edit", json=json) as response: + response.raise_for_status() + + +def create_runtime_properties(task, percent_completed: Optional[int], message: Optional[str]) -> dict: + runtime_props = (task.data.runtime or {}).copy() + if percent_completed is not None: + runtime_props["progress"] = str(percent_completed) + else: + runtime_props.pop("progress", None) + if message is not None: + runtime_props["message"] = message + else: + runtime_props.pop("message", None) + return runtime_props diff --git a/machine/jobs/build_smt_engine.py b/machine/jobs/build_smt_engine.py index 9db0b9a3..3a7742b6 100644 --- a/machine/jobs/build_smt_engine.py +++ b/machine/jobs/build_smt_engine.py @@ -1,16 +1,20 @@ import argparse -import json import logging -import os -from datetime import datetime from typing import Callable, Optional, cast -import aiohttp from clearml import Task -from ..utils.canceled_error import CanceledError from ..utils.progress_status import ProgressStatus from .async_scheduler import AsyncScheduler +from .build_clearml_helper import ( + ProgressInfo, + create_runtime_properties, + get_clearml_check_canceled, + get_clearml_progress_caller, + get_local_progress_caller, + update_runtime_properties, + update_settings, +) from .clearml_shared_file_service import ClearMLSharedFileService from .config import SETTINGS from .smt_engine_build_job import SmtEngineBuildJob @@ -25,118 +29,34 @@ logger = logging.getLogger(str(__package__) + ".build_smt_engine") -async def update_runtime_properties(task_id: str, base_url: str, token: str, runtime_props: dict) -> None: - async with aiohttp.ClientSession(base_url=base_url, headers={"Authorization": f"Bearer {token}"}) as session: - json = {"task": task_id, "runtime": runtime_props, "force": True} - async with session.post("/tasks.edit", json=json) as response: - response.raise_for_status() - - -def create_runtime_properties(task,
percent_completed: Optional[int], message: Optional[str]) -> dict: - runtime_props = task.data.runtime.copy() or {} - if percent_completed is not None: - runtime_props["progress"] = str(percent_completed) - else: - del runtime_props["progress"] - if message is not None: - runtime_props["message"] = message - else: - del runtime_props["message"] - return runtime_props - - def run(args: dict) -> None: progress: Callable[[ProgressStatus], None] check_canceled: Optional[Callable[[], None]] = None task = None - last_percent_completed: Optional[int] = None - last_message: Optional[str] = None scheduler: Optional[AsyncScheduler] = None + progress_info = ProgressInfo() if args["clearml"]: task = Task.init() - scheduler = AsyncScheduler() - last_check_canceled_time: Optional[datetime] = None - - def clearml_check_canceled() -> None: - nonlocal last_check_canceled_time - current_time = datetime.now() - if last_check_canceled_time is None or (current_time - last_check_canceled_time).seconds > 20: - if task.get_status() == "stopped": - raise CanceledError - last_check_canceled_time = current_time - - check_canceled = clearml_check_canceled + check_canceled = get_clearml_check_canceled(progress_info, task) task.reload() - last_progress_time: Optional[datetime] = None - - def clearml_progress(progress_status: ProgressStatus) -> None: - nonlocal last_percent_completed - nonlocal last_message - nonlocal last_progress_time - percent_completed: Optional[int] = None - if progress_status.percent_completed is not None: - percent_completed = round(progress_status.percent_completed * 100) - message = progress_status.message - if percent_completed != last_percent_completed or message != last_message: - logger.info(f"{percent_completed}% - {message}") - current_time = datetime.now() - if last_progress_time is None or (current_time - last_progress_time).seconds > 1: - new_runtime_props = task.data.runtime.copy() or {} - new_runtime_props["progress"] = str(percent_completed) - 
new_runtime_props["message"] = message - scheduler.schedule( - update_runtime_properties( - task.id, - task.session.host, - task.session.token, - create_runtime_properties(task, percent_completed, message), - ) - ) - last_progress_time = current_time - last_percent_completed = percent_completed - last_message = message - - progress = clearml_progress - else: + progress = get_clearml_progress_caller(progress_info, task, scheduler, logger) - def local_progress(progress_status: ProgressStatus) -> None: - nonlocal last_percent_completed - nonlocal last_message - percent_completed: Optional[int] = None - if progress_status.percent_completed is not None: - percent_completed = round(progress_status.percent_completed * 100) - message = progress_status.message - if percent_completed != last_percent_completed or message != last_message: - logger.info(f"{percent_completed}% - {message}") - last_percent_completed = percent_completed - last_message = message - - progress = local_progress + else: + progress = get_local_progress_caller(ProgressInfo(), logger) try: logger.info("SMT Engine Build Job started") - - SETTINGS.update(args) - model_type = cast(str, SETTINGS.model_type).lower() - if "build_options" in SETTINGS: - try: - build_options = json.loads(cast(str, SETTINGS.build_options)) - except ValueError as e: - raise ValueError("Build options could not be parsed: Invalid JSON") from e - except TypeError as e: - raise TypeError(f"Build options could not be parsed: {e}") from e - SETTINGS.update({model_type: build_options}) - SETTINGS.data_dir = os.path.expanduser(cast(str, SETTINGS.data_dir)) + update_settings(SETTINGS, args) logger.info(f"Config: {SETTINGS.as_dict()}") shared_file_service = ClearMLSharedFileService(SETTINGS) smt_model_factory: SmtModelFactory - if model_type == "thot": + if SETTINGS.model_type == "thot": from .thot.thot_smt_model_factory import ThotSmtModelFactory smt_model_factory = ThotSmtModelFactory(SETTINGS) diff --git 
a/machine/jobs/build_word_alignment_model.py b/machine/jobs/build_word_alignment_model.py new file mode 100644 index 00000000..57175986 --- /dev/null +++ b/machine/jobs/build_word_alignment_model.py @@ -0,0 +1,103 @@ +import argparse +import logging +from typing import Callable, Optional + +from clearml import Task + +from .thot.thot_word_alignment_model_factory import ThotWordAlignmentModelFactory + +from ..utils.progress_status import ProgressStatus +from .async_scheduler import AsyncScheduler +from .build_clearml_helper import ( + ProgressInfo, + create_runtime_properties, + get_clearml_check_canceled, + get_clearml_progress_caller, + get_local_progress_caller, + update_runtime_properties, + update_settings, +) +from .clearml_shared_file_service import ClearMLSharedFileService +from .config import SETTINGS +from .word_alignment_build_job import WordAlignmentBuildJob +from .word_alignment_model_factory import WordAlignmentModelFactory + +# Setup logging +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + level=logging.INFO, +) + +logger = logging.getLogger(str(__package__) + ".build_word_alignment_model") + + +def run(args: dict) -> None: + progress: Callable[[ProgressStatus], None] + check_canceled: Optional[Callable[[], None]] = None + task = None + scheduler: Optional[AsyncScheduler] = None + progress_info = ProgressInfo() + if args["clearml"]: + task = Task.init() + scheduler = AsyncScheduler() + + check_canceled = get_clearml_check_canceled(progress_info, task) + + task.reload() + + progress = get_clearml_progress_caller(progress_info, task, scheduler, logger) + + else: + progress = get_local_progress_caller(ProgressInfo(), logger) + + try: + logger.info("Word Alignment Build Job started") + update_settings(SETTINGS, args) + + logger.info(f"Config: {SETTINGS.as_dict()}") + + shared_file_service = ClearMLSharedFileService(SETTINGS) + word_alignment_model_factory: WordAlignmentModelFactory + if
SETTINGS.model_type == "thot": + word_alignment_model_factory = ThotWordAlignmentModelFactory(SETTINGS) + else: + raise RuntimeError("The model type is invalid.") + + word_alignment_build_job = WordAlignmentBuildJob(SETTINGS, word_alignment_model_factory, shared_file_service) + train_corpus_size, confidence = word_alignment_build_job.run(progress, check_canceled) + if scheduler is not None and task is not None: + scheduler.schedule( + update_runtime_properties( + task.id, task.session.host, task.session.token, create_runtime_properties(task, 100, "Completed") + ) + ) + task.get_logger().report_single_value(name="train_corpus_size", value=train_corpus_size) + task.get_logger().report_single_value(name="confidence", value=round(confidence, 4)) + logger.info("Finished") + except Exception as e: + if task: + if task.get_status() == "stopped": + return + else: + task.mark_failed(status_reason=type(e).__name__, status_message=str(e)) + raise e + finally: + if scheduler is not None: + scheduler.stop() + + +def main() -> None: + parser = argparse.ArgumentParser(description="Trains a word alignment model.") + parser.add_argument("--model-type", required=True, type=str, help="Model type") + parser.add_argument("--build-id", required=True, type=str, help="Build id") + parser.add_argument("--clearml", default=False, action="store_true", help="Initializes a ClearML task") + parser.add_argument("--build-options", default=None, type=str, help="Build configurations") + args = parser.parse_args() + + input_args = {k: v for k, v in vars(args).items() if v is not None} + + run(input_args) + + +if __name__ == "__main__": + main() diff --git a/machine/jobs/engine_build_job.py b/machine/jobs/engine_build_job.py index a36939ba..07a3ec2f 100644 --- a/machine/jobs/engine_build_job.py +++ b/machine/jobs/engine_build_job.py @@ -2,6 +2,7 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Optional, Tuple +from ..corpora.parallel_text_corpus import ParallelTextCorpus from
..utils.phased_progress_reporter import PhasedProgressReporter from ..utils.progress_status import ProgressStatus from .shared_file_service import SharedFileService @@ -47,10 +48,11 @@ def start_job(self) -> None: ... def init_corpus(self) -> None: logger.info("Downloading data files") - self._source_corpus = self._shared_file_service.create_source_corpus() - self._target_corpus = self._shared_file_service.create_target_corpus() - self._parallel_corpus = self._source_corpus.align_rows(self._target_corpus) - self._parallel_corpus_size = self._parallel_corpus.count(include_empty=False) + if "_source_corpus" not in self.__dict__: + self._source_corpus = self._shared_file_service.create_source_corpus() + self._target_corpus = self._shared_file_service.create_target_corpus() + self._parallel_corpus: ParallelTextCorpus = self._source_corpus.align_rows(self._target_corpus) + self._parallel_corpus_size = self._parallel_corpus.count(include_empty=False) @abstractmethod def _get_progress_reporter( diff --git a/machine/jobs/nmt_engine_build_job.py b/machine/jobs/nmt_engine_build_job.py index 4b752951..c74bf2e5 100644 --- a/machine/jobs/nmt_engine_build_job.py +++ b/machine/jobs/nmt_engine_build_job.py @@ -8,7 +8,7 @@ from ..utils.progress_status import ProgressStatus from .engine_build_job import EngineBuildJob from .nmt_model_factory import NmtModelFactory -from .shared_file_service import PretranslationInfo, PretranslationWriter, SharedFileService +from .shared_file_service import DictToJsonWriter, PretranslationInfo, SharedFileService logger = logging.getLogger(__name__) @@ -85,7 +85,7 @@ def batch_inference( writer = stack.enter_context(self._shared_file_service.open_target_pretranslation_writer()) current_inference_step = 0 phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count)) - batch_size = self._config["pretranslation_batch_size"] + batch_size = self._config["inference_batch_size"] for pi_batch in batch(src_pretranslations, batch_size): 
if check_canceled is not None: check_canceled() @@ -105,9 +105,9 @@ def save_model(self) -> None: def _translate_batch( engine: TranslationEngine, batch: Sequence[PretranslationInfo], - writer: PretranslationWriter, + writer: DictToJsonWriter, ) -> None: source_segments = [pi["translation"] for pi in batch] for i, result in enumerate(engine.translate_batch(source_segments)): batch[i]["translation"] = result.translation - writer.write(batch[i]) + writer.write(batch[i]) # type: ignore diff --git a/machine/jobs/settings.yaml b/machine/jobs/settings.yaml index b787b6d6..a60b9ef5 100644 --- a/machine/jobs/settings.yaml +++ b/machine/jobs/settings.yaml @@ -2,7 +2,7 @@ default: data_dir: ~/machine shared_file_uri: s3://silnlp/ shared_file_folder: production - pretranslation_batch_size: 1024 + inference_batch_size: 1024 huggingface: parent_model_name: facebook/nllb-200-distilled-1.3B train_params: @@ -27,6 +27,7 @@ default: add_unk_trg_tokens: true thot: word_alignment_model_type: hmm + word_alignment_heuristic: grow-diag-final-and tokenizer: latin development: shared_file_folder: dev diff --git a/machine/jobs/shared_file_service.py b/machine/jobs/shared_file_service.py index d9c71827..25678af3 100644 --- a/machine/jobs/shared_file_service.py +++ b/machine/jobs/shared_file_service.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from contextlib import contextmanager from pathlib import Path -from typing import Any, Generator, Iterator, List, TextIO, TypedDict +from typing import Any, Generator, Iterator, List, MutableMapping, TextIO, TypedDict import json_stream @@ -18,12 +18,18 @@ class PretranslationInfo(TypedDict): translation: str -class PretranslationWriter: +class WordAlignmentInfo(TypedDict): + refs: List[str] + alignmnent: str + + +class DictToJsonWriter: def __init__(self, file: TextIO) -> None: self._file = file self._first = True - def write(self, pi: PretranslationInfo) -> None: + # Use MutableMapping rather than TypeDict to allow for more flexible input 
+ def write(self, pi: MutableMapping) -> None: if not self._first: self._file.write(",\n") self._file.write(" " + json.dumps(pi)) @@ -36,12 +42,10 @@ def __init__( config: Any, source_filename: str = "train.src.txt", target_filename: str = "train.trg.txt", - pretranslation_filename: str = "pretranslate.src.json", ) -> None: self._config = config self._source_filename = source_filename self._target_filename = target_filename - self._pretranslation_filename = pretranslation_filename def create_source_corpus(self) -> TextCorpus: return TextFileTextCorpus(self._download_file(f"{self._build_path}/{self._source_filename}")) @@ -56,7 +60,7 @@ def exists_target_corpus(self) -> bool: return self._exists_file(f"{self._build_path}/{self._target_filename}") def get_source_pretranslations(self) -> ContextManagedGenerator[PretranslationInfo, None, None]: - src_pretranslate_path = self._download_file(f"{self._build_path}/{self._pretranslation_filename}") + src_pretranslate_path = self._download_file(f"{self._build_path}/pretranslate.src.json") def generator() -> Generator[PretranslationInfo, None, None]: with src_pretranslate_path.open("r", encoding="utf-8-sig") as file: @@ -71,15 +75,22 @@ def generator() -> Generator[PretranslationInfo, None, None]: return ContextManagedGenerator(generator()) @contextmanager - def open_target_pretranslation_writer(self) -> Iterator[PretranslationWriter]: + def open_target_pretranslation_writer(self) -> Iterator[DictToJsonWriter]: + return self._open_target_writer("pretranslate.trg.json") + + @contextmanager + def open_target_alignment_writer(self) -> Iterator[DictToJsonWriter]: + return self._open_target_writer("word_alignments.json") + + def _open_target_writer(self, filename) -> Iterator[DictToJsonWriter]: build_dir = self._data_dir / self._shared_file_folder / self._build_path build_dir.mkdir(parents=True, exist_ok=True) - target_pretranslate_path = build_dir / self._pretranslation_filename - with target_pretranslate_path.open("w", 
encoding="utf-8", newline="\n") as file: + target_path = build_dir / filename + with target_path.open("w", encoding="utf-8", newline="\n") as file: + file.write("[\n") - yield PretranslationWriter(file) + yield DictToJsonWriter(file) file.write("\n]\n") - self._upload_file(f"{self._build_path}/{self._pretranslation_filename}", target_pretranslate_path) + self._upload_file(f"{self._build_path}/{filename}", target_path) def save_model(self, model_path: Path, destination: str) -> None: if model_path.is_file(): diff --git a/machine/jobs/smt_engine_build_job.py b/machine/jobs/smt_engine_build_job.py index 86d2ef4c..1aee9f41 100644 --- a/machine/jobs/smt_engine_build_job.py +++ b/machine/jobs/smt_engine_build_job.py @@ -7,7 +7,7 @@ from ..utils.phased_progress_reporter import Phase, PhasedProgressReporter from ..utils.progress_status import ProgressStatus from .engine_build_job import EngineBuildJob -from .shared_file_service import PretranslationInfo, PretranslationWriter, SharedFileService +from .shared_file_service import DictToJsonWriter, PretranslationInfo, SharedFileService from .smt_model_factory import SmtModelFactory logger = logging.getLogger(__name__) @@ -74,7 +74,7 @@ def batch_inference( writer = stack.enter_context(self._shared_file_service.open_target_pretranslation_writer()) current_inference_step = 0 phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count)) - batch_size = self._config["pretranslation_batch_size"] + batch_size = self._config["inference_batch_size"] for pi_batch in batch(src_pretranslations, batch_size): if check_canceled is not None: check_canceled() @@ -93,9 +93,9 @@ def save_model(self) -> None: def _translate_batch( engine: TranslationEngine, batch: Sequence[PretranslationInfo], - writer: PretranslationWriter, + writer: DictToJsonWriter, ) -> None: source_segments = [pi["translation"] for pi in batch] for i, result in enumerate(engine.translate_batch(source_segments)): batch[i]["translation"] =
result.translation - writer.write(batch[i]) + writer.write(batch[i]) # type: ignore diff --git a/machine/jobs/smt_model_factory.py b/machine/jobs/smt_model_factory.py index ac8aa854..02942666 100644 --- a/machine/jobs/smt_model_factory.py +++ b/machine/jobs/smt_model_factory.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from pathlib import Path from typing import Optional from ..corpora.parallel_text_corpus import ParallelTextCorpus @@ -9,18 +8,10 @@ from ..translation.trainer import Trainer from ..translation.translation_engine import TranslationEngine from ..translation.truecaser import Truecaser +from .thot.thot_model_factory import ThotModelFactory -class SmtModelFactory(ABC): - @abstractmethod - def init(self) -> None: ... - - @abstractmethod - def create_tokenizer(self) -> Tokenizer[str, int, str]: ... - - @abstractmethod - def create_detokenizer(self) -> Detokenizer[str, str]: ... - +class SmtModelFactory(ABC, ThotModelFactory): @abstractmethod def create_model_trainer(self, tokenizer: Tokenizer[str, int, str], corpus: ParallelTextCorpus) -> Trainer: ... @@ -37,6 +28,3 @@ def create_truecaser_trainer(self, tokenizer: Tokenizer[str, int, str], target_c @abstractmethod def create_truecaser(self) -> Truecaser: ... - - @abstractmethod - def save_model(self) -> Path: ... 
diff --git a/machine/jobs/thot/thot_model_factory.py b/machine/jobs/thot/thot_model_factory.py new file mode 100644 index 00000000..d00fa2ab --- /dev/null +++ b/machine/jobs/thot/thot_model_factory.py @@ -0,0 +1,57 @@ +import os +import shutil +from pathlib import Path +from typing import Any + +from ...tokenization.detokenizer import Detokenizer +from ...tokenization.latin_word_detokenizer import LatinWordDetokenizer +from ...tokenization.latin_word_tokenizer import LatinWordTokenizer +from ...tokenization.tokenizer import Tokenizer +from ...tokenization.whitespace_detokenizer import WhitespaceDetokenizer +from ...tokenization.whitespace_tokenizer import WhitespaceTokenizer +from ...tokenization.zwsp_word_detokenizer import ZwspWordDetokenizer +from ...tokenization.zwsp_word_tokenizer import ZwspWordTokenizer + +_THOT_NEW_MODEL_DIRECTORY = os.path.join(os.path.dirname(__file__), "thot-new-model") + +_TOKENIZERS = ["latin", "whitespace", "zwsp"] + + +class ThotModelFactory: + def __init__(self, config: Any) -> None: + self._config = config + + def init(self) -> None: + shutil.copytree(_THOT_NEW_MODEL_DIRECTORY, self._model_dir, dirs_exist_ok=True) + + def create_tokenizer(self) -> Tokenizer[str, int, str]: + name: str = self._config.thot.tokenizer + name = name.lower() + if name == "latin": + return LatinWordTokenizer() + if name == "whitespace": + return WhitespaceTokenizer() + if name == "zwsp": + return ZwspWordTokenizer() + raise RuntimeError(f"Unknown tokenizer: {name}. Available tokenizers are: {_TOKENIZERS}.") + + def create_detokenizer(self) -> Detokenizer[str, str]: + name: str = self._config.thot.tokenizer + name = name.lower() + if name == "latin": + return LatinWordDetokenizer() + if name == "whitespace": + return WhitespaceDetokenizer() + if name == "zwsp": + return ZwspWordDetokenizer() + raise RuntimeError(f"Unknown detokenizer: {name}. 
Available detokenizers are: {_TOKENIZERS}.") + + def save_model(self) -> Path: + tar_file_basename = Path( + self._config.data_dir, self._config.shared_file_folder, "builds", self._config.build_id, "model" + ) + return Path(shutil.make_archive(str(tar_file_basename), "gztar", self._model_dir)) + + @property + def _model_dir(self) -> Path: + return Path(self._config.data_dir, self._config.shared_file_folder, "builds", self._config.build_id, "model") diff --git a/machine/jobs/thot/thot_smt_model_factory.py b/machine/jobs/thot/thot_smt_model_factory.py index 9e89e30f..53cce722 100644 --- a/machine/jobs/thot/thot_smt_model_factory.py +++ b/machine/jobs/thot/thot_smt_model_factory.py @@ -1,18 +1,9 @@ -import os -import shutil -from pathlib import Path -from typing import Any, Optional +from typing import Optional from ...corpora.parallel_text_corpus import ParallelTextCorpus from ...corpora.text_corpus import TextCorpus from ...tokenization.detokenizer import Detokenizer -from ...tokenization.latin_word_detokenizer import LatinWordDetokenizer -from ...tokenization.latin_word_tokenizer import LatinWordTokenizer from ...tokenization.tokenizer import Tokenizer -from ...tokenization.whitespace_detokenizer import WhitespaceDetokenizer -from ...tokenization.whitespace_tokenizer import WhitespaceTokenizer -from ...tokenization.zwsp_word_detokenizer import ZwspWordDetokenizer -from ...tokenization.zwsp_word_tokenizer import ZwspWordTokenizer from ...translation.thot.thot_smt_model import ThotSmtModel from ...translation.thot.thot_smt_model_trainer import ThotSmtModelTrainer from ...translation.trainer import Trainer @@ -21,40 +12,8 @@ from ...translation.unigram_truecaser import UnigramTruecaser, UnigramTruecaserTrainer from ..smt_model_factory import SmtModelFactory -_THOT_NEW_MODEL_DIRECTORY = os.path.join(os.path.dirname(__file__), "thot-new-model") - -_TOKENIZERS = ["latin", "whitespace", "zwsp"] - class ThotSmtModelFactory(SmtModelFactory): - def __init__(self, config: 
Any) -> None: - self._config = config - - def init(self) -> None: - shutil.copytree(_THOT_NEW_MODEL_DIRECTORY, self._model_dir, dirs_exist_ok=True) - - def create_tokenizer(self) -> Tokenizer[str, int, str]: - name: str = self._config.thot.tokenizer - name = name.lower() - if name == "latin": - return LatinWordTokenizer() - if name == "whitespace": - return WhitespaceTokenizer() - if name == "zwsp": - return ZwspWordTokenizer() - raise RuntimeError(f"Unknown tokenizer: {name}. Available tokenizers are: {_TOKENIZERS}.") - - def create_detokenizer(self) -> Detokenizer[str, str]: - name: str = self._config.thot.tokenizer - name = name.lower() - if name == "latin": - return LatinWordDetokenizer() - if name == "whitespace": - return WhitespaceDetokenizer() - if name == "zwsp": - return ZwspWordDetokenizer() - raise RuntimeError(f"Unknown detokenizer: {name}. Available detokenizers are: {_TOKENIZERS}.") - def create_model_trainer(self, tokenizer: Tokenizer[str, int, str], corpus: ParallelTextCorpus) -> Trainer: return ThotSmtModelTrainer( word_alignment_model_type=self._config.thot.word_alignment_model_type, @@ -90,13 +49,3 @@ def create_truecaser_trainer(self, tokenizer: Tokenizer[str, int, str], target_c def create_truecaser(self) -> Truecaser: return UnigramTruecaser(model_path=self._model_dir / "unigram-casing-model.txt") - - def save_model(self) -> Path: - tar_file_basename = Path( - self._config.data_dir, self._config.shared_file_folder, "builds", self._config.build_id, "model" - ) - return Path(shutil.make_archive(str(tar_file_basename), "gztar", self._model_dir)) - - @property - def _model_dir(self) -> Path: - return Path(self._config.data_dir, self._config.shared_file_folder, "builds", self._config.build_id, "model") diff --git a/machine/jobs/thot/thot_word_alignment_model_factory.py b/machine/jobs/thot/thot_word_alignment_model_factory.py new file mode 100644 index 00000000..c5e2e8cd --- /dev/null +++ b/machine/jobs/thot/thot_word_alignment_model_factory.py @@ 
-0,0 +1,49 @@ +from pathlib import Path + +from ...corpora.parallel_text_corpus import ParallelTextCorpus +from ...tokenization.tokenizer import Tokenizer +from ...translation.symmetrized_word_alignment_model_trainer import SymmetrizedWordAlignmentModelTrainer +from ...translation.thot.thot_symmetrized_word_alignment_model import ThotSymmetrizedWordAlignmentModel +from ...translation.thot.thot_word_alignment_model_trainer import ThotWordAlignmentModelTrainer +from ...translation.thot.thot_word_alignment_model_utils import create_thot_word_alignment_model +from ...translation.trainer import Trainer +from ...translation.word_alignment_model import WordAlignmentModel +from ..word_alignment_model_factory import WordAlignmentModelFactory + + +class ThotWordAlignmentModelFactory(WordAlignmentModelFactory): + + def create_model_trainer(self, tokenizer: Tokenizer[str, int, str], corpus: ParallelTextCorpus) -> Trainer: + direct_trainer = ThotWordAlignmentModelTrainer( + self._config.thot.word_alignment_model_type, + corpus.lowercase(), + prefix_filename=self._direct_model_path, + source_tokenizer=tokenizer, + target_tokenizer=tokenizer, + ) + inverse_trainer = ThotWordAlignmentModelTrainer( + self._config.thot.word_alignment_model_type, + corpus.invert().lowercase(), + prefix_filename=self._inverse_model_path, + source_tokenizer=tokenizer, + target_tokenizer=tokenizer, + ) + return SymmetrizedWordAlignmentModelTrainer(direct_trainer, inverse_trainer) + + def create_alignment_model( + self, + ) -> WordAlignmentModel: + model = ThotSymmetrizedWordAlignmentModel( + create_thot_word_alignment_model(self._config.thot.word_alignment_model_type, self._direct_model_path), + create_thot_word_alignment_model(self._config.thot.word_alignment_model_type, self._inverse_model_path), + ) + model.heuristic = self._config.thot.word_alignment_heuristic + return model + + @property + def _direct_model_path(self) -> Path: + return self._model_dir / "tm" / "src_trg_invswm" + + @property + def 
_inverse_model_path(self) -> Path: + return self._model_dir / "tm" / "src_trg_swm" diff --git a/machine/jobs/word_alignment_build_job.py b/machine/jobs/word_alignment_build_job.py new file mode 100644 index 00000000..94755e03 --- /dev/null +++ b/machine/jobs/word_alignment_build_job.py @@ -0,0 +1,85 @@ +import logging +from contextlib import ExitStack +from typing import Any, Callable, Optional + +from ..utils.phased_progress_reporter import Phase, PhasedProgressReporter +from ..utils.progress_status import ProgressStatus +from .engine_build_job import EngineBuildJob +from .shared_file_service import SharedFileService, WordAlignmentInfo +from .word_alignment_model_factory import WordAlignmentModelFactory + +logger = logging.getLogger(__name__) + + +class WordAlignmentBuildJob(EngineBuildJob): + def __init__( + self, + config: Any, + word_alignment_model_factory: WordAlignmentModelFactory, + shared_file_service: SharedFileService, + ) -> None: + self._word_alignment_model_factory = word_alignment_model_factory + super().__init__(config, shared_file_service) + + def start_job(self) -> None: + self._word_alignment_model_factory.init() + self._tokenizer = self._word_alignment_model_factory.create_tokenizer() + logger.info(f"Tokenizer: {type(self._tokenizer).__name__}") + + def _get_progress_reporter(self, progress: Optional[Callable[[ProgressStatus], None]]) -> PhasedProgressReporter: + phases = [ + Phase(message="Training Word Alignment model", percentage=0.9), + Phase(message="Aligning segments", percentage=0.1), + ] + return PhasedProgressReporter(progress, phases) + + def respond_to_no_training_corpus(self) -> None: + raise RuntimeError("No parallel corpus data found") + + def train_model( + self, + progress_reporter: PhasedProgressReporter, + check_canceled: Optional[Callable[[], None]], + ) -> None: + + with progress_reporter.start_next_phase() as phase_progress, self._word_alignment_model_factory.create_model_trainer( + self._tokenizer, self._parallel_corpus + ) 
as trainer: + trainer.train(progress=phase_progress, check_canceled=check_canceled) + trainer.save() + self._train_corpus_size = trainer.stats.train_corpus_size + self._confidence = trainer.stats.metrics["bleu"] * 100 + + if check_canceled is not None: + check_canceled() + + def batch_inference( + self, + progress_reporter: PhasedProgressReporter, + check_canceled: Optional[Callable[[], None]], + ) -> None: + inference_step_count = self._parallel_corpus.count(include_empty=False) + + with ExitStack() as stack: + phase_progress = stack.enter_context(progress_reporter.start_next_phase()) + alignment_model = stack.enter_context(self._word_alignment_model_factory.create_alignment_model()) + self.init_corpus() + writer = stack.enter_context(self._shared_file_service.open_target_alignment_writer()) + current_inference_step = 0 + phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count)) + batch_size = self._config["inference_batch_size"] + segment_batch = list(self._parallel_corpus.lowercase().take(batch_size)) + if check_canceled is not None: + check_canceled() + alignments = alignment_model.align_batch(segment_batch) + if check_canceled is not None: + check_canceled() + for row, alignment in zip(self._parallel_corpus.get_rows(), alignments): + writer.write(WordAlignmentInfo(refs=row.source_refs, alignment=str(alignment))) # type: ignore + + def save_model(self) -> None: + logger.info("Saving model") + model_path = self._word_alignment_model_factory.save_model() + self._shared_file_service.save_model( + model_path, f"builds/{self._config['build_id']}/model{''.join(model_path.suffixes)}" + ) diff --git a/machine/jobs/word_alignment_model_factory.py b/machine/jobs/word_alignment_model_factory.py new file mode 100644 index 00000000..9ed163dd --- /dev/null +++ b/machine/jobs/word_alignment_model_factory.py @@ -0,0 +1,17 @@ +from abc import ABC, abstractmethod + +from ..corpora.parallel_text_corpus import ParallelTextCorpus +from 
..tokenization.tokenizer import Tokenizer +from ..translation.trainer import Trainer +from ..translation.word_alignment_model import WordAlignmentModel +from .thot.thot_model_factory import ThotModelFactory + + +class WordAlignmentModelFactory(ABC, ThotModelFactory): + @abstractmethod + def create_model_trainer(self, tokenizer: Tokenizer[str, int, str], corpus: ParallelTextCorpus) -> Trainer: ... + + @abstractmethod + def create_alignment_model( + self, + ) -> WordAlignmentModel: ... diff --git a/machine/translation/thot/thot_symmetrized_word_alignment_model.py b/machine/translation/thot/thot_symmetrized_word_alignment_model.py index 6c3815f0..c1cbe284 100644 --- a/machine/translation/thot/thot_symmetrized_word_alignment_model.py +++ b/machine/translation/thot/thot_symmetrized_word_alignment_model.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Sequence, cast +from typing import List, Sequence, Union, cast import thot.alignment as ta @@ -38,7 +38,9 @@ def inverse_word_alignment_model(self) -> ThotWordAlignmentModel: return cast(ThotWordAlignmentModel, self._inverse_word_alignment_model) @heuristic.setter - def heuristic(self, value: SymmetrizationHeuristic) -> None: + def heuristic(self, value: Union[SymmetrizationHeuristic, str]) -> None: + if isinstance(value, str): + value = SymmetrizationHeuristic[value.upper().replace("-", "_")] self._heuristic = value self._aligner.heuristic = _convert_heuristic(self._heuristic) diff --git a/machine/translation/thot/thot_word_alignment_model_utils.py b/machine/translation/thot/thot_word_alignment_model_utils.py index b4a18884..82482a86 100644 --- a/machine/translation/thot/thot_word_alignment_model_utils.py +++ b/machine/translation/thot/thot_word_alignment_model_utils.py @@ -1,5 +1,6 @@ -from typing import Union +from typing import Optional, Union +from ...utils.typeshed import StrPath from .thot_fast_align_word_alignment_model import ThotFastAlignWordAlignmentModel from 
.thot_hmm_word_alignment_model import ThotHmmWordAlignmentModel from .thot_ibm1_word_alignment_model import ThotIbm1WordAlignmentModel @@ -11,25 +12,31 @@ from .thot_word_alignment_model_type import ThotWordAlignmentModelType -def create_thot_word_alignment_model(type: Union[str, int]) -> ThotWordAlignmentModel: +def create_thot_word_alignment_model( + type: Union[str, int], prefix_filename: Optional[StrPath] = None +) -> ThotWordAlignmentModel: if isinstance(type, str): type = ThotWordAlignmentModelType[type.upper()] if type == ThotWordAlignmentModelType.FAST_ALIGN: - return ThotFastAlignWordAlignmentModel() + return ThotFastAlignWordAlignmentModel(prefix_filename) if type == ThotWordAlignmentModelType.IBM1: - return ThotIbm1WordAlignmentModel() + return ThotIbm1WordAlignmentModel(prefix_filename) if type == ThotWordAlignmentModelType.IBM2: - return ThotIbm2WordAlignmentModel() + return ThotIbm2WordAlignmentModel(prefix_filename) if type == ThotWordAlignmentModelType.HMM: - return ThotHmmWordAlignmentModel() + return ThotHmmWordAlignmentModel(prefix_filename) if type == ThotWordAlignmentModelType.IBM3: - return ThotIbm3WordAlignmentModel() + return ThotIbm3WordAlignmentModel(prefix_filename) if type == ThotWordAlignmentModelType.IBM4: - return ThotIbm4WordAlignmentModel() + return ThotIbm4WordAlignmentModel(prefix_filename) raise ValueError("The word alignment model type is unknown.") -def create_thot_symmetrized_word_alignment_model(type: Union[int, str]) -> ThotSymmetrizedWordAlignmentModel: - direct_model = create_thot_word_alignment_model(type) - inverse_model = create_thot_word_alignment_model(type) +def create_thot_symmetrized_word_alignment_model( + type: Union[int, str], + direct_prefix_filename: Optional[StrPath] = None, + inverse_prefix_filename: Optional[StrPath] = None, +) -> ThotSymmetrizedWordAlignmentModel: + direct_model = create_thot_word_alignment_model(type, direct_prefix_filename) + inverse_model = create_thot_word_alignment_model(type, 
inverse_prefix_filename) return ThotSymmetrizedWordAlignmentModel(direct_model, inverse_model) diff --git a/tests/jobs/test_nmt_engine_build_job.py b/tests/jobs/test_nmt_engine_build_job.py index b4697107..6f192787 100644 --- a/tests/jobs/test_nmt_engine_build_job.py +++ b/tests/jobs/test_nmt_engine_build_job.py @@ -10,7 +10,7 @@ from machine.annotations import Range from machine.corpora import DictionaryTextCorpus -from machine.jobs import NmtEngineBuildJob, NmtModelFactory, PretranslationInfo, PretranslationWriter, SharedFileService +from machine.jobs import DictToJsonWriter, NmtEngineBuildJob, NmtModelFactory, PretranslationInfo, SharedFileService from machine.translation import ( Phrase, Trainer, @@ -116,10 +116,10 @@ def __init__(self, decoy: Decoy) -> None: self.target_pretranslations = "" @contextmanager - def open_target_pretranslation_writer(env: _TestEnvironment) -> Iterator[PretranslationWriter]: + def open_target_pretranslation_writer(env: _TestEnvironment) -> Iterator[DictToJsonWriter]: file = StringIO() file.write("[\n") - yield PretranslationWriter(file) + yield DictToJsonWriter(file) file.write("\n]\n") env.target_pretranslations = file.getvalue() @@ -128,9 +128,7 @@ def open_target_pretranslation_writer(env: _TestEnvironment) -> Iterator[Pretran ) self.job = NmtEngineBuildJob( - MockSettings( - {"src_lang": "es", "trg_lang": "en", "save_model": "save-model", "pretranslation_batch_size": 100} - ), + MockSettings({"src_lang": "es", "trg_lang": "en", "save_model": "save-model", "inference_batch_size": 100}), self.nmt_model_factory, self.shared_file_service, ) diff --git a/tests/jobs/test_smt_engine_build_job.py b/tests/jobs/test_smt_engine_build_job.py index ff77d49d..a6dbb903 100644 --- a/tests/jobs/test_smt_engine_build_job.py +++ b/tests/jobs/test_smt_engine_build_job.py @@ -10,7 +10,7 @@ from machine.annotations import Range from machine.corpora import DictionaryTextCorpus, MemoryText, TextRow -from machine.jobs import PretranslationInfo, 
PretranslationWriter, SharedFileService, SmtEngineBuildJob, SmtModelFactory +from machine.jobs import DictToJsonWriter, PretranslationInfo, SharedFileService, SmtEngineBuildJob, SmtModelFactory from machine.tokenization import WHITESPACE_DETOKENIZER, WHITESPACE_TOKENIZER from machine.translation import ( Phrase, @@ -147,10 +147,10 @@ def __init__(self, decoy: Decoy) -> None: self.target_pretranslations = "" @contextmanager - def open_target_pretranslation_writer(env: _TestEnvironment) -> Iterator[PretranslationWriter]: + def open_target_pretranslation_writer(env: _TestEnvironment) -> Iterator[DictToJsonWriter]: file = StringIO() file.write("[\n") - yield PretranslationWriter(file) + yield DictToJsonWriter(file) file.write("\n]\n") env.target_pretranslations = file.getvalue() @@ -159,7 +159,7 @@ def open_target_pretranslation_writer(env: _TestEnvironment) -> Iterator[Pretran ) self.job = SmtEngineBuildJob( - MockSettings({"build_id": "mybuild", "pretranslation_batch_size": 100}), + MockSettings({"build_id": "mybuild", "inference_batch_size": 100}), self.smt_model_factory, self.shared_file_service, ) diff --git a/tests/jobs/test_word_alignment_build_job.py b/tests/jobs/test_word_alignment_build_job.py new file mode 100644 index 00000000..e5e8d5dc --- /dev/null +++ b/tests/jobs/test_word_alignment_build_job.py @@ -0,0 +1,143 @@ +import json +from contextlib import contextmanager +from io import StringIO +from pathlib import Path +from typing import Iterator + +from decoy import Decoy, matchers +from pytest import raises +from testutils.mock_settings import MockSettings + +from machine.corpora import DictionaryTextCorpus, MemoryText, TextRow +from machine.jobs import ( + DictToJsonWriter, + PretranslationInfo, + SharedFileService, + WordAlignmentBuildJob, + WordAlignmentModelFactory, +) +from machine.tokenization import WHITESPACE_TOKENIZER +from machine.translation import Trainer, TrainStats, WordAlignmentMatrix +from machine.translation.word_alignment_model import 
WordAlignmentModel +from machine.utils import CanceledError, ContextManagedGenerator + + +def test_run(decoy: Decoy) -> None: + env = _TestEnvironment(decoy) + env.job.run() + + alignments = json.loads(env.alignment_json) + assert len(alignments) == 1 + assert alignments[0]["alignment"] == "0-0 1-1 2-2" + decoy.verify( + env.shared_file_service.save_model(matchers.Anything(), f"builds/{env.job._config.build_id}/model.zip"), times=1 + ) + + +def test_cancel(decoy: Decoy) -> None: + env = _TestEnvironment(decoy) + checker = _CancellationChecker(3) + with raises(CanceledError): + env.job.run(check_canceled=checker.check_canceled) + + assert env.alignment_json == "" + + +class _TestEnvironment: + def __init__(self, decoy: Decoy) -> None: + self.model_trainer = decoy.mock(cls=Trainer) + decoy.when(self.model_trainer.__enter__()).then_return(self.model_trainer) + stats = TrainStats() + stats.train_corpus_size = 3 + stats.metrics["bleu"] = 30.0 + decoy.when(self.model_trainer.stats).then_return(stats) + + self.model = decoy.mock(cls=WordAlignmentModel) + decoy.when(self.model.__enter__()).then_return(self.model) + decoy.when(self.model.align_batch(matchers.Anything())).then_return( + [ + WordAlignmentMatrix.from_word_pairs(row_count=3, column_count=3, set_values=[(0, 0), (1, 1), (2, 2)]), + ] + ) + + self.word_alignment_model_factory = decoy.mock(cls=WordAlignmentModelFactory) + decoy.when(self.word_alignment_model_factory.create_tokenizer()).then_return(WHITESPACE_TOKENIZER) + decoy.when( + self.word_alignment_model_factory.create_model_trainer(matchers.Anything(), matchers.Anything()) + ).then_return(self.model_trainer) + decoy.when(self.word_alignment_model_factory.create_alignment_model()).then_return(self.model) + decoy.when(self.word_alignment_model_factory.save_model()).then_return(Path("model.zip")) + + self.shared_file_service = decoy.mock(cls=SharedFileService) + decoy.when(self.shared_file_service.create_source_corpus()).then_return( + DictionaryTextCorpus( + 
MemoryText( + "text1", + [ + TextRow("text1", 1, ["¿Le importaría darnos las llaves de la habitación, por favor?"]), + TextRow("text1", 2, ["¿Le importaría cambiarme a otra habitación más tranquila?"]), + TextRow("text1", 3, ["Me parece que existe un problema."]), + ], + ) + ) + ) + decoy.when(self.shared_file_service.create_target_corpus()).then_return( + DictionaryTextCorpus( + MemoryText( + "text1", + [ + TextRow("text1", 1, ["Would you mind giving us the room keys, please?"]), + TextRow("text1", 2, ["Would you mind moving me to another quieter room?"]), + TextRow("text1", 3, ["I think there is a problem."]), + ], + ) + ) + ) + decoy.when(self.shared_file_service.exists_source_corpus()).then_return(True) + decoy.when(self.shared_file_service.exists_target_corpus()).then_return(True) + decoy.when(self.shared_file_service.get_source_pretranslations()).then_do( + lambda: ContextManagedGenerator( + ( + pi + for pi in [ + PretranslationInfo( + corpusId="corpus1", + textId="text1", + refs=["ref1"], + translation="Por favor, tengo reservada una habitación.", + ) + ] + ) + ) + ) + + self.alignment_json = "" + + @contextmanager + def open_target_alignment_writer(env: _TestEnvironment) -> Iterator[DictToJsonWriter]: + file = StringIO() + file.write("[\n") + yield DictToJsonWriter(file) + file.write("\n]\n") + env.alignment_json = file.getvalue() + + decoy.when(self.shared_file_service.open_target_alignment_writer()).then_do( + lambda: open_target_alignment_writer(self) + ) + + self.job = WordAlignmentBuildJob( + MockSettings({"build_id": "mybuild", "inference_batch_size": 100}), + self.word_alignment_model_factory, + self.shared_file_service, + ) + + +class _CancellationChecker: + def __init__(self, raise_count: int) -> None: + self._call_count = 0 + self._raise_count = raise_count + + def check_canceled(self) -> None: + self._call_count += 1 + if self._call_count == self._raise_count: + raise CanceledError diff --git 
a/tests/translation/thot/test_thot_smt_model_trainer.py b/tests/translation/thot/test_thot_smt_model_trainer.py index 19933e06..dff454c9 100644 --- a/tests/translation/thot/test_thot_smt_model_trainer.py +++ b/tests/translation/thot/test_thot_smt_model_trainer.py @@ -1,92 +1,14 @@ import os from tempfile import TemporaryDirectory -from machine.corpora import ( - AlignedWordPair, - AlignmentRow, - DictionaryAlignmentCorpus, - DictionaryTextCorpus, - MemoryAlignmentCollection, - MemoryText, - TextRow, -) from machine.translation.thot import ThotSmtModel, ThotSmtModelTrainer, ThotSmtParameters, ThotWordAlignmentModelType +from .thot_model_trainer_helper import get_emtpy_parallel_corpus, get_parallel_corpus + def test_train_non_empty_corpus() -> None: with TemporaryDirectory() as temp_dir: - source_corpus = DictionaryTextCorpus( - [ - MemoryText( - "text1", - [ - _row(1, "¿ le importaría darnos las llaves de la habitación , por favor ?"), - _row( - 2, - "he hecho la reserva de una habitación tranquila doble con ||| teléfono ||| y televisión a " - "nombre de rosario cabedo .", - ), - _row(3, "¿ le importaría cambiarme a otra habitación más tranquila ?"), - _row(4, "por favor , tengo reservada una habitación ."), - _row(5, "me parece que existe un problema ."), - _row( - 6, - "¿ tiene habitaciones libres con televisión , aire acondicionado y caja fuerte ?", - ), - _row(7, "¿ le importaría mostrarnos una habitación con televisión ?"), - _row(8, "¿ tiene teléfono ?"), - _row(9, "voy a marcharme el dos a las ocho de la noche ."), - _row(10, "¿ cuánto cuesta una habitación individual por semana ?"), - ], - ) - ] - ) - - target_corpus = DictionaryTextCorpus( - [ - MemoryText( - "text1", - [ - _row(1, "would you mind giving us the keys to the room , please ?"), - _row( - 2, - "i have made a reservation for a quiet , double room with a ||| telephone ||| and a tv for " - "rosario cabedo .", - ), - _row(3, "would you mind moving me to a quieter room ?"), - _row(4, "i have booked 
a room ."), - _row(5, "i think that there is a problem ."), - _row(6, "do you have any rooms with a tv , air conditioning and a safe available ?"), - _row(7, "would you mind showing us a room with a tv ?"), - _row(8, "does it have a telephone ?"), - _row(9, "i am leaving on the second at eight in the evening ."), - _row(10, "how much does a single room cost per week ?"), - ], - ) - ] - ) - - alignment_corpus = DictionaryAlignmentCorpus( - [ - MemoryAlignmentCollection( - "text1", - [ - _alignment(1, AlignedWordPair(8, 9)), - _alignment(2, AlignedWordPair(6, 10)), - _alignment(3, AlignedWordPair(6, 8)), - _alignment(4, AlignedWordPair(6, 4)), - _alignment(5), - _alignment(6, AlignedWordPair(2, 4)), - _alignment(7, AlignedWordPair(5, 6)), - _alignment(8), - _alignment(9), - _alignment(10, AlignedWordPair(4, 5)), - ], - ) - ] - ) - - corpus = source_corpus.align_rows(target_corpus, alignment_corpus) + corpus = get_parallel_corpus() parameters = ThotSmtParameters( translation_model_filename_prefix=os.path.join(temp_dir, "tm", "src_trg"), @@ -105,11 +27,7 @@ def test_train_non_empty_corpus() -> None: def test_train_empty_corpus() -> None: with TemporaryDirectory() as temp_dir: - source_corpus = DictionaryTextCorpus([]) - target_corpus = DictionaryTextCorpus([]) - alignment_corpus = DictionaryAlignmentCorpus([]) - - corpus = source_corpus.align_rows(target_corpus, alignment_corpus) + corpus = get_emtpy_parallel_corpus() parameters = ThotSmtParameters( translation_model_filename_prefix=os.path.join(temp_dir, "tm", "src_trg"), @@ -124,11 +42,3 @@ def test_train_empty_corpus() -> None: with ThotSmtModel(ThotWordAlignmentModelType.HMM, parameters) as model: result = model.translate("una habitación individual por semana") assert result.translation == "una habitación individual por semana" - - -def _row(row_ref: int, text: str) -> TextRow: - return TextRow("text1", row_ref, segment=[text]) - - -def _alignment(row_ref: int, *pairs: AlignedWordPair) -> AlignmentRow: - return 
AlignmentRow("text1", row_ref, aligned_word_pairs=pairs) diff --git a/tests/translation/thot/test_thot_word_alignment_model_trainer.py b/tests/translation/thot/test_thot_word_alignment_model_trainer.py new file mode 100644 index 00000000..7ef2d809 --- /dev/null +++ b/tests/translation/thot/test_thot_word_alignment_model_trainer.py @@ -0,0 +1,72 @@ +from pathlib import Path +from tempfile import TemporaryDirectory + +from machine.corpora.parallel_text_corpus import ParallelTextCorpus +from machine.tokenization.whitespace_tokenizer import WhitespaceTokenizer +from machine.translation.symmetrized_word_alignment_model_trainer import SymmetrizedWordAlignmentModelTrainer +from machine.translation.thot import ThotWordAlignmentModelTrainer +from machine.translation.thot.thot_symmetrized_word_alignment_model import ThotSymmetrizedWordAlignmentModel +from machine.translation.thot.thot_word_alignment_model_utils import create_thot_word_alignment_model +from machine.translation.word_alignment_matrix import WordAlignmentMatrix + +from .thot_model_trainer_helper import get_emtpy_parallel_corpus, get_parallel_corpus + + +def train_model( + corpus: ParallelTextCorpus, direct_model_path: Path, inverse_model_path: Path, thot_word_alignment_model_type: str +): + tokenizer = WhitespaceTokenizer() + direct_trainer = ThotWordAlignmentModelTrainer( + thot_word_alignment_model_type, + corpus.lowercase(), + prefix_filename=direct_model_path, + source_tokenizer=tokenizer, + target_tokenizer=tokenizer, + ) + inverse_trainer = ThotWordAlignmentModelTrainer( + thot_word_alignment_model_type, + corpus.invert().lowercase(), + prefix_filename=inverse_model_path, + source_tokenizer=tokenizer, + target_tokenizer=tokenizer, + ) + + with SymmetrizedWordAlignmentModelTrainer(direct_trainer, inverse_trainer) as trainer: + trainer.train() + trainer.save() + + +def test_train_non_empty_corpus() -> None: + with TemporaryDirectory() as temp_dir: + corpus = get_parallel_corpus() + 
thot_word_alignment_model_type = "hmm" + tmp_path = Path(temp_dir) + direct_model_path = tmp_path / "tm" / "src_trg_invswm" + inverse_model_path = tmp_path / "tm" / "src_trg_swm" + train_model(corpus, direct_model_path, inverse_model_path, thot_word_alignment_model_type) + with ThotSymmetrizedWordAlignmentModel( + create_thot_word_alignment_model(thot_word_alignment_model_type, direct_model_path), + create_thot_word_alignment_model(thot_word_alignment_model_type, inverse_model_path), + ) as model: + matrix = model.align("una habitación individual por semana", "a single room cost per week") + assert matrix == WordAlignmentMatrix.from_word_pairs( + 6, 7, {(0, 0), (1, 1), (2, 2), (4, 3), (3, 4), (3, 5), (5, 6)} + ) + + +def test_train_empty_corpus() -> None: + with TemporaryDirectory() as temp_dir: + corpus = get_emtpy_parallel_corpus() + thot_word_alignment_model_type = "hmm" + tmp_path = Path(temp_dir) + direct_model_path = tmp_path / "tm" / "src_trg_invswm" + inverse_model_path = tmp_path / "tm" / "src_trg_swm" + train_model(corpus, direct_model_path, inverse_model_path, thot_word_alignment_model_type) + with ThotSymmetrizedWordAlignmentModel( + create_thot_word_alignment_model(thot_word_alignment_model_type, direct_model_path), + create_thot_word_alignment_model(thot_word_alignment_model_type, inverse_model_path), + ) as model: + matrix = model.align("una habitación individual por semana", "a single room cost per week") + assert matrix == WordAlignmentMatrix.from_word_pairs( + 6, 7, {(0, 0), (1, 1), (2, 2), (4, 3), (3, 4), (3, 5), (5, 6)} + ) diff --git a/tests/translation/thot/thot_model_trainer_helper.py b/tests/translation/thot/thot_model_trainer_helper.py new file mode 100644 index 00000000..6bb232a9 --- /dev/null +++ b/tests/translation/thot/thot_model_trainer_helper.py @@ -0,0 +1,103 @@ +from machine.corpora import ( + AlignedWordPair, + AlignmentRow, + DictionaryAlignmentCorpus, + DictionaryTextCorpus, + MemoryAlignmentCollection, + MemoryText, + TextRow, 
+) +from machine.corpora.parallel_text_corpus import ParallelTextCorpus + + +def get_parallel_corpus() -> ParallelTextCorpus: + source_corpus = DictionaryTextCorpus( + [ + MemoryText( + "text1", + [ + _row(1, "¿ le importaría darnos las llaves de la habitación , por favor ?"), + _row( + 2, + "he hecho la reserva de una habitación tranquila doble con ||| teléfono ||| y televisión a " + "nombre de rosario cabedo .", + ), + _row(3, "¿ le importaría cambiarme a otra habitación más tranquila ?"), + _row(4, "por favor , tengo reservada una habitación ."), + _row(5, "me parece que existe un problema ."), + _row( + 6, + "¿ tiene habitaciones libres con televisión , aire acondicionado y caja fuerte ?", + ), + _row(7, "¿ le importaría mostrarnos una habitación con televisión ?"), + _row(8, "¿ tiene teléfono ?"), + _row(9, "voy a marcharme el dos a las ocho de la noche ."), + _row(10, "¿ cuánto cuesta una habitación individual por semana ?"), + ], + ) + ] + ) + + target_corpus = DictionaryTextCorpus( + [ + MemoryText( + "text1", + [ + _row(1, "would you mind giving us the keys to the room , please ?"), + _row( + 2, + "i have made a reservation for a quiet , double room with a ||| telephone ||| and a tv for " + "rosario cabedo .", + ), + _row(3, "would you mind moving me to a quieter room ?"), + _row(4, "i have booked a room ."), + _row(5, "i think that there is a problem ."), + _row(6, "do you have any rooms with a tv , air conditioning and a safe available ?"), + _row(7, "would you mind showing us a room with a tv ?"), + _row(8, "does it have a telephone ?"), + _row(9, "i am leaving on the second at eight in the evening ."), + _row(10, "how much does a single room cost per week ?"), + ], + ) + ] + ) + + alignment_corpus = DictionaryAlignmentCorpus( + [ + MemoryAlignmentCollection( + "text1", + [ + _alignment(1, AlignedWordPair(8, 9)), + _alignment(2, AlignedWordPair(6, 10)), + _alignment(3, AlignedWordPair(6, 8)), + _alignment(4, AlignedWordPair(6, 4)), + _alignment(5), + 
_alignment(6, AlignedWordPair(2, 4)), + _alignment(7, AlignedWordPair(5, 6)), + _alignment(8), + _alignment(9), + _alignment(10, AlignedWordPair(4, 5)), + ], + ) + ] + ) + + corpus = source_corpus.align_rows(target_corpus, alignment_corpus) + return corpus + + +def get_emtpy_parallel_corpus() -> ParallelTextCorpus: + source_corpus = DictionaryTextCorpus([]) + target_corpus = DictionaryTextCorpus([]) + alignment_corpus = DictionaryAlignmentCorpus([]) + + corpus = source_corpus.align_rows(target_corpus, alignment_corpus) + return corpus + + +def _row(row_ref: int, text: str) -> TextRow: + return TextRow("text1", row_ref, segment=[text]) + + +def _alignment(row_ref: int, *pairs: AlignedWordPair) -> AlignmentRow: + return AlignmentRow("text1", row_ref, aligned_word_pairs=pairs)