Skip to content

Commit

Permalink
Initial word alignment build job (now with files)
Browse files Browse the repository at this point in the history
  • Loading branch information
johnml1135 committed Aug 15, 2024
1 parent bff4239 commit 1893473
Show file tree
Hide file tree
Showing 22 changed files with 837 additions and 303 deletions.
117 changes: 117 additions & 0 deletions machine/jobs/build_clearml_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import json
import logging
import os
from datetime import datetime
from typing import Callable, Optional, Union, cast

import aiohttp
from clearml import Task
from dynaconf.base import Settings

from ..utils.canceled_error import CanceledError
from ..utils.progress_status import ProgressStatus
from .async_scheduler import AsyncScheduler


class ProgressInfo:
    """Mutable record of the last reported progress and of the last throttle timestamps.

    The values below are class-level defaults; assigning through an instance stores the
    new value on that instance only, so separate instances do not share state.
    """

    # Last reported completion percentage; starts at 0.
    last_percent_completed: Optional[int] = 0
    # Last reported status message; starts as the empty string.
    last_message: Optional[str] = ""
    # Time of the last progress push; None until one has happened.
    last_progress_time: Optional[datetime] = None
    # Time of the last cancellation poll; None until one has happened.
    last_check_canceled_time: Optional[datetime] = None


def get_clearml_check_canceled(progress_info: ProgressInfo, task: Task) -> Callable[[], None]:
    """Build a callback that raises ``CanceledError`` once the given task reports status "stopped".

    The task status is polled on the first call and afterwards only when more than
    20 seconds have passed since the previous poll; the time of the last poll is
    stored in ``progress_info.last_check_canceled_time``.

    Args:
        progress_info: Shared state object holding the time of the last poll.
        task: Task whose ``get_status()`` is polled.

    Returns:
        A zero-argument callable that raises ``CanceledError`` when a poll finds the
        task stopped and otherwise returns ``None``.
    """

    def clearml_check_canceled() -> None:
        current_time = datetime.now()
        last_check = progress_info.last_check_canceled_time
        # total_seconds() measures the whole interval. ``timedelta.seconds`` is only the
        # seconds component (0-86399), so it wraps around once a day has passed.
        if last_check is None or (current_time - last_check).total_seconds() > 20:
            if task.get_status() == "stopped":
                raise CanceledError
            progress_info.last_check_canceled_time = current_time

    return clearml_check_canceled


def get_clearml_progress_caller(
    progress_info: ProgressInfo, task: Task, scheduler: AsyncScheduler, logger: logging.Logger
) -> Callable[[ProgressStatus], None]:
    """Build a progress callback that logs changes and pushes them to the task's runtime properties.

    Every change in percentage or message is logged. The runtime-property update is
    scheduled on the first change and afterwards only when more than one second has
    passed since the previous scheduled update.

    Args:
        progress_info: Shared state holding the last reported values and the last push time.
        task: Task whose id, session host and session token are used for the update.
        scheduler: Scheduler that receives the update coroutine.
        logger: Logger that receives the "<percent>% - <message>" lines.

    Returns:
        A callable taking a ``ProgressStatus`` and returning ``None``.
    """

    def clearml_progress(progress_status: ProgressStatus) -> None:
        percent_completed: Optional[int] = None
        if progress_status.percent_completed is not None:
            percent_completed = round(progress_status.percent_completed * 100)
        message = progress_status.message
        if percent_completed == progress_info.last_percent_completed and message == progress_info.last_message:
            # Nothing changed since the last report.
            return
        logger.info(f"{percent_completed}% - {message}")
        current_time = datetime.now()
        last_push = progress_info.last_progress_time
        # total_seconds() measures the whole interval; ``timedelta.seconds`` is a truncated
        # component that wraps around after a day.
        if last_push is None or (current_time - last_push).total_seconds() > 1:
            scheduler.schedule(
                update_runtime_properties(
                    task.id,  # type: ignore
                    task.session.host,
                    task.session.token,  # type: ignore
                    create_runtime_properties(task, percent_completed, message),
                )
            )
            progress_info.last_progress_time = current_time
        progress_info.last_percent_completed = percent_completed
        progress_info.last_message = message

    return clearml_progress


def get_local_progress_caller(progress_info: ProgressInfo, logger: logging.Logger) -> Callable[[ProgressStatus], None]:
    """Build a progress callback that only logs, for runs without a ClearML task.

    A line "<percent>% - <message>" is logged whenever the rounded percentage or the
    message differs from the values remembered in ``progress_info``; the remembered
    values are then updated.
    """

    def local_progress(progress_status: ProgressStatus) -> None:
        fraction = progress_status.percent_completed
        percent_completed: Optional[int] = round(fraction * 100) if fraction is not None else None
        message = progress_status.message
        unchanged = (
            percent_completed == progress_info.last_percent_completed and message == progress_info.last_message
        )
        if unchanged:
            return
        logger.info(f"{percent_completed}% - {message}")
        progress_info.last_percent_completed = percent_completed
        progress_info.last_message = message

    return local_progress


def update_settings(settings: Settings, args: dict):
    """Merge ``args`` into ``settings`` and normalize the derived values in place.

    After merging, ``model_type`` is lower-cased, ``build_options`` (when present) is
    parsed as JSON and merged under the key named by the model type, and ``data_dir``
    has a leading ``~`` expanded.

    Raises:
        ValueError: If ``build_options`` is not valid JSON.
        TypeError: If ``build_options`` has a type that ``json.loads`` rejects.
    """
    settings.update(args)
    lowered_type = cast(str, settings.model_type).lower()
    settings.model_type = lowered_type
    if "build_options" in settings:
        try:
            parsed_options = json.loads(cast(str, settings.build_options))
        except ValueError as err:
            raise ValueError("Build options could not be parsed: Invalid JSON") from err
        except TypeError as err:
            raise TypeError(f"Build options could not be parsed: {err}") from err
        else:
            settings.update({settings.model_type: parsed_options})
    expanded_dir = os.path.expanduser(cast(str, settings.data_dir))
    settings.data_dir = expanded_dir


async def update_runtime_properties(task_id: str, base_url: str, token: str, runtime_props: dict) -> None:
    """POST ``runtime_props`` for ``task_id`` to the ``/tasks.edit`` endpoint under ``base_url``.

    The request carries ``token`` as a bearer token. Whatever
    ``response.raise_for_status()`` raises for the response is propagated.
    """
    auth_headers = {"Authorization": f"Bearer {token}"}
    # Named "payload" rather than "json" so the local does not shadow the json module.
    payload = {"task": task_id, "runtime": runtime_props, "force": True}
    async with aiohttp.ClientSession(base_url=base_url, headers=auth_headers) as session:
        async with session.post("/tasks.edit", json=payload) as response:
            response.raise_for_status()


def create_runtime_properties(task, percent_completed: Optional[int], message: Optional[str]) -> dict:
    """Return a copy of the task's runtime properties with progress and message applied.

    Args:
        task: Object exposing the current runtime properties as ``task.data.runtime``.
        percent_completed: Stored as a string under "progress"; ``None`` removes the key.
        message: Stored under "message"; ``None`` removes the key.

    Returns:
        A new dict; ``task.data.runtime`` itself is not modified.
    """
    # Apply the empty-dict fallback before copying so a None runtime does not fail on .copy().
    runtime_props = (task.data.runtime or {}).copy()
    if percent_completed is not None:
        runtime_props["progress"] = str(percent_completed)
    else:
        # pop with a default tolerates a key that was never set; del would raise KeyError.
        runtime_props.pop("progress", None)
    if message is not None:
        runtime_props["message"] = message
    else:
        runtime_props.pop("message", None)
    return runtime_props
112 changes: 16 additions & 96 deletions machine/jobs/build_smt_engine.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
import argparse
import json
import logging
import os
from datetime import datetime
from typing import Callable, Optional, cast

import aiohttp
from clearml import Task

from ..utils.canceled_error import CanceledError
from ..utils.progress_status import ProgressStatus
from .async_scheduler import AsyncScheduler
from .build_clearml_helper import (
ProgressInfo,
create_runtime_properties,
get_clearml_check_canceled,
get_clearml_progress_caller,
get_local_progress_caller,
update_runtime_properties,
update_settings,
)
from .clearml_shared_file_service import ClearMLSharedFileService
from .config import SETTINGS
from .smt_engine_build_job import SmtEngineBuildJob
Expand All @@ -25,118 +29,34 @@
logger = logging.getLogger(str(__package__) + ".build_smt_engine")


async def update_runtime_properties(task_id: str, base_url: str, token: str, runtime_props: dict) -> None:
    """POST ``runtime_props`` for ``task_id`` to the ``/tasks.edit`` endpoint under ``base_url``.

    The request carries ``token`` as a bearer token. Whatever
    ``response.raise_for_status()`` raises for the response is propagated.
    """
    auth_headers = {"Authorization": f"Bearer {token}"}
    # Named "payload" rather than "json" so the local does not shadow the json module.
    payload = {"task": task_id, "runtime": runtime_props, "force": True}
    async with aiohttp.ClientSession(base_url=base_url, headers=auth_headers) as session:
        async with session.post("/tasks.edit", json=payload) as response:
            response.raise_for_status()


def create_runtime_properties(task, percent_completed: Optional[int], message: Optional[str]) -> dict:
    """Return a copy of the task's runtime properties with progress and message applied.

    Args:
        task: Object exposing the current runtime properties as ``task.data.runtime``.
        percent_completed: Stored as a string under "progress"; ``None`` removes the key.
        message: Stored under "message"; ``None`` removes the key.

    Returns:
        A new dict; ``task.data.runtime`` itself is not modified.
    """
    # Apply the empty-dict fallback before copying so a None runtime does not fail on .copy().
    runtime_props = (task.data.runtime or {}).copy()
    if percent_completed is not None:
        runtime_props["progress"] = str(percent_completed)
    else:
        # pop with a default tolerates a key that was never set; del would raise KeyError.
        runtime_props.pop("progress", None)
    if message is not None:
        runtime_props["message"] = message
    else:
        runtime_props.pop("message", None)
    return runtime_props


def run(args: dict) -> None:
progress: Callable[[ProgressStatus], None]
check_canceled: Optional[Callable[[], None]] = None
task = None
last_percent_completed: Optional[int] = None
last_message: Optional[str] = None
scheduler: Optional[AsyncScheduler] = None
progress_info = ProgressInfo()
if args["clearml"]:
task = Task.init()

scheduler = AsyncScheduler()

last_check_canceled_time: Optional[datetime] = None

def clearml_check_canceled() -> None:
nonlocal last_check_canceled_time
current_time = datetime.now()
if last_check_canceled_time is None or (current_time - last_check_canceled_time).seconds > 20:
if task.get_status() == "stopped":
raise CanceledError
last_check_canceled_time = current_time

check_canceled = clearml_check_canceled
check_canceled = get_clearml_check_canceled(progress_info, task)

task.reload()

last_progress_time: Optional[datetime] = None

def clearml_progress(progress_status: ProgressStatus) -> None:
nonlocal last_percent_completed
nonlocal last_message
nonlocal last_progress_time
percent_completed: Optional[int] = None
if progress_status.percent_completed is not None:
percent_completed = round(progress_status.percent_completed * 100)
message = progress_status.message
if percent_completed != last_percent_completed or message != last_message:
logger.info(f"{percent_completed}% - {message}")
current_time = datetime.now()
if last_progress_time is None or (current_time - last_progress_time).seconds > 1:
new_runtime_props = task.data.runtime.copy() or {}
new_runtime_props["progress"] = str(percent_completed)
new_runtime_props["message"] = message
scheduler.schedule(
update_runtime_properties(
task.id,
task.session.host,
task.session.token,
create_runtime_properties(task, percent_completed, message),
)
)
last_progress_time = current_time
last_percent_completed = percent_completed
last_message = message

progress = clearml_progress
else:
progress = get_clearml_progress_caller(progress_info, task, scheduler, logger)

def local_progress(progress_status: ProgressStatus) -> None:
nonlocal last_percent_completed
nonlocal last_message
percent_completed: Optional[int] = None
if progress_status.percent_completed is not None:
percent_completed = round(progress_status.percent_completed * 100)
message = progress_status.message
if percent_completed != last_percent_completed or message != last_message:
logger.info(f"{percent_completed}% - {message}")
last_percent_completed = percent_completed
last_message = message

progress = local_progress
else:
progress = get_local_progress_caller(ProgressInfo(), logger)

try:
logger.info("SMT Engine Build Job started")

SETTINGS.update(args)
model_type = cast(str, SETTINGS.model_type).lower()
if "build_options" in SETTINGS:
try:
build_options = json.loads(cast(str, SETTINGS.build_options))
except ValueError as e:
raise ValueError("Build options could not be parsed: Invalid JSON") from e
except TypeError as e:
raise TypeError(f"Build options could not be parsed: {e}") from e
SETTINGS.update({model_type: build_options})
SETTINGS.data_dir = os.path.expanduser(cast(str, SETTINGS.data_dir))
update_settings(SETTINGS, args)

logger.info(f"Config: {SETTINGS.as_dict()}")

shared_file_service = ClearMLSharedFileService(SETTINGS)
smt_model_factory: SmtModelFactory
if model_type == "thot":
if SETTINGS.model_type == "thot":
from .thot.thot_smt_model_factory import ThotSmtModelFactory

smt_model_factory = ThotSmtModelFactory(SETTINGS)
Expand Down
103 changes: 103 additions & 0 deletions machine/jobs/build_word_alignment_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import argparse
import logging
from typing import Callable, Optional

from clearml import Task

from machine.jobs.thot.thot_word_alignment_model_factory import ThotWordAlignmentModelFactory

from ..utils.progress_status import ProgressStatus
from .async_scheduler import AsyncScheduler
from .build_clearml_helper import (
ProgressInfo,
create_runtime_properties,
get_clearml_check_canceled,
get_clearml_progress_caller,
get_local_progress_caller,
update_runtime_properties,
update_settings,
)
from .clearml_shared_file_service import ClearMLSharedFileService
from .config import SETTINGS
from .word_alignment_build_job import WordAlignmentBuildJob
from .word_alignment_model_factory import WordAlignmentModelFactory

# Configure logging at import time: INFO level, each record rendered as
# "<time> - <logger name> - <level> - <message>".
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

# Module logger, named after the enclosing package plus this module's name.
logger = logging.getLogger(f"{__package__}.build_word_alignment_model")


def run(args: dict) -> None:
    """Run a word alignment build job, optionally tracked as a ClearML task.

    When ``args["clearml"]`` is truthy, a task and an async scheduler are created,
    cancellation is polled through the task, and progress is pushed to the task's
    runtime properties. Otherwise progress is only logged.

    Args:
        args: Command-line options as a dict; must contain "clearml". All entries are
            merged into ``SETTINGS``.

    Raises:
        RuntimeError: If the configured model type is not "thot".
        Exception: Any error from the build is re-raised, unless a task exists and
            reports status "stopped", in which case the function returns silently.
    """
    progress: Callable[[ProgressStatus], None]
    check_canceled: Optional[Callable[[], None]] = None
    task = None
    scheduler: Optional[AsyncScheduler] = None
    progress_info = ProgressInfo()
    if args["clearml"]:
        task = Task.init()
        scheduler = AsyncScheduler()

        check_canceled = get_clearml_check_canceled(progress_info, task)

        task.reload()

        progress = get_clearml_progress_caller(progress_info, task, scheduler, logger)
    else:
        # Reuse the tracker created above instead of allocating a second, identical one.
        progress = get_local_progress_caller(progress_info, logger)

    try:
        logger.info("Word Alignment Build Job started")
        update_settings(SETTINGS, args)

        logger.info(f"Config: {SETTINGS.as_dict()}")

        shared_file_service = ClearMLSharedFileService(SETTINGS)
        word_alignment_model_factory: WordAlignmentModelFactory
        if SETTINGS.model_type == "thot":
            word_alignment_model_factory = ThotWordAlignmentModelFactory(SETTINGS)
        else:
            raise RuntimeError("The model type is invalid.")

        word_alignment_build_job = WordAlignmentBuildJob(SETTINGS, word_alignment_model_factory, shared_file_service)
        train_corpus_size, confidence = word_alignment_build_job.run(progress, check_canceled)
        if scheduler is not None and task is not None:
            # Push the final 100% state, then record the summary metrics on the task.
            scheduler.schedule(
                update_runtime_properties(
                    task.id, task.session.host, task.session.token, create_runtime_properties(task, 100, "Completed")
                )
            )
            task_logger = task.get_logger()
            task_logger.report_single_value(name="train_corpus_size", value=train_corpus_size)
            task_logger.report_single_value(name="confidence", value=round(confidence, 4))
        logger.info("Finished")
    except Exception as e:
        if task:
            if task.get_status() == "stopped":
                # The task is already stopped: swallow the error instead of marking a failure.
                return
            task.mark_failed(status_reason=type(e).__name__, status_message=str(e))
        # Bare raise re-raises the active exception without adding an extra traceback entry.
        raise
    finally:
        if scheduler is not None:
            scheduler.stop()


def main() -> None:
    """Parse the command-line options and start the word alignment build job.

    Options whose value is ``None`` (i.e. not supplied and without a non-None default)
    are dropped before being passed to ``run``.
    """
    # The description names this script's job; "SMT model" was carried over from the SMT build script.
    parser = argparse.ArgumentParser(description="Trains a word alignment model.")
    parser.add_argument("--model-type", required=True, type=str, help="Model type")
    parser.add_argument("--build-id", required=True, type=str, help="Build id")
    parser.add_argument("--clearml", default=False, action="store_true", help="Initializes a ClearML task")
    parser.add_argument("--build-options", default=None, type=str, help="Build configurations")
    args = parser.parse_args()

    input_args = {k: v for k, v in vars(args).items() if v is not None}

    run(input_args)


if __name__ == "__main__":
    main()
10 changes: 6 additions & 4 deletions machine/jobs/engine_build_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Tuple

from ..corpora.parallel_text_corpus import ParallelTextCorpus
from ..utils.phased_progress_reporter import PhasedProgressReporter
from ..utils.progress_status import ProgressStatus
from .shared_file_service import SharedFileService
Expand Down Expand Up @@ -47,10 +48,11 @@ def start_job(self) -> None: ...

def init_corpus(self) -> None:
logger.info("Downloading data files")
self._source_corpus = self._shared_file_service.create_source_corpus()
self._target_corpus = self._shared_file_service.create_target_corpus()
self._parallel_corpus = self._source_corpus.align_rows(self._target_corpus)
self._parallel_corpus_size = self._parallel_corpus.count(include_empty=False)
if "_source_corpus" not in self.__dict__:
self._source_corpus = self._shared_file_service.create_source_corpus()
self._target_corpus = self._shared_file_service.create_target_corpus()
self._parallel_corpus: ParallelTextCorpus = self._source_corpus.align_rows(self._target_corpus)
self._parallel_corpus_size = self._parallel_corpus.count(include_empty=False)

@abstractmethod
def _get_progress_reporter(
Expand Down
Loading

0 comments on commit 1893473

Please sign in to comment.