Skip to content

Commit

Permalink
Initial refactor
Browse files Browse the repository at this point in the history
Initial word alignment build job
Update tests

Update from reviewer comments

Small fixes

Reviewer comments - still need to add new word align folder template

one more fix
  • Loading branch information
johnml1135 committed Aug 23, 2024
1 parent 3912c6a commit 2bd1d72
Show file tree
Hide file tree
Showing 35 changed files with 1,398 additions and 556 deletions.
16 changes: 16 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"env": {
"PYTHONPATH": "${workspaceFolder}:${workspaceFolder}/tests"
},
"justMyCode": true
},
{
Expand Down Expand Up @@ -64,6 +67,19 @@
"build1"
]
},
{
"name": "build_word_alignment_model",
"type": "debugpy",
"request": "launch",
"module": "machine.jobs.build_word_alignment_model",
"justMyCode": false,
"args": [
"--model-type",
"thot",
"--build-id",
"build1"
]
},
{
"name": "Python: Debug Tests",
"type": "debugpy",
Expand Down
8 changes: 6 additions & 2 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
{
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
"source.organizeImports": "explicit",
},
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"python.analysis.extraPaths": [
"tests"
],
"python.analysis.importFormat": "relative",
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true
Expand All @@ -17,4 +21,4 @@
"python.analysis.extraPaths": [
"./tests"
]
}
}
3 changes: 1 addition & 2 deletions machine/corpora/usfm_text_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@
from io import TextIOWrapper
from typing import Generator, Iterable, List, Optional, Sequence

from machine.corpora.scripture_ref import ScriptureRef

from ..scripture.verse_ref import Versification
from ..utils.string_utils import has_sentence_ending
from .corpora_utils import gen
from .scripture_ref import ScriptureRef
from .scripture_ref_usfm_parser_handler import ScriptureRefUsfmParserHandler, ScriptureTextType
from .scripture_text import ScriptureText
from .stream_container import StreamContainer
Expand Down
17 changes: 13 additions & 4 deletions machine/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,27 @@
from .local_shared_file_service import LocalSharedFileService
from .nmt_engine_build_job import NmtEngineBuildJob
from .nmt_model_factory import NmtModelFactory
from .shared_file_service import PretranslationInfo, PretranslationWriter, SharedFileService
from .shared_file_service_base import DictToJsonWriter, SharedFileServiceBase
from .smt_engine_build_job import SmtEngineBuildJob
from .smt_model_factory import SmtModelFactory
from .translation_file_service import PretranslationInfo, TranslationFileService
from .word_alignment_build_job import WordAlignmentBuildJob
from .word_alignment_file_service import WordAlignmentFileService, WordAlignmentInfo
from .word_alignment_model_factory import WordAlignmentModelFactory

__all__ = [
"ClearMLSharedFileService",
"LocalSharedFileService",
"NmtEngineBuildJob",
"NmtModelFactory",
"PretranslationInfo",
"PretranslationWriter",
"SharedFileService",
"DictToJsonWriter",
"SharedFileServiceBase",
"SmtEngineBuildJob",
"SmtModelFactory",
"PretranslationInfo",
"TranslationFileService",
"WordAlignmentBuildJob",
"WordAlignmentFileService",
"WordAlignmentInfo",
"WordAlignmentModelFactory",
]
117 changes: 117 additions & 0 deletions machine/jobs/build_clearml_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import json
import logging
import os
from datetime import datetime
from typing import Callable, Optional, Union, cast

import aiohttp
from clearml import Task
from dynaconf.base import Settings

from ..utils.canceled_error import CanceledError
from ..utils.progress_status import ProgressStatus
from .async_scheduler import AsyncScheduler


class ProgressInfo:
    """Mutable record of what the progress and cancellation callbacks last did.

    The callbacks built in this module read these fields to decide whether a new log line,
    runtime-property upload, or task-status poll is due, and overwrite them afterwards.
    """

    # Rounded percentage from the most recent progress report that was logged.
    last_percent_completed: Optional[int] = 0
    # Message from the most recent progress report that was logged.
    last_message: Optional[str] = ""
    # When runtime properties were last scheduled for upload; None until the first upload.
    last_progress_time: Optional[datetime] = None
    # When the task status was last polled for cancellation; None until the first poll.
    last_check_canceled_time: Optional[datetime] = None


def get_clearml_check_canceled(progress_info: ProgressInfo, task: Task) -> Callable[[], None]:
    """Create a callback that raises ``CanceledError`` once the ClearML task is reported as stopped.

    The task status is polled at most once every 20 seconds; calls made in between return
    immediately without contacting the task.

    Args:
        progress_info: Shared state; ``last_check_canceled_time`` is read and updated here.
        task: The ClearML task whose status is polled via ``get_status()``.

    Returns:
        A zero-argument callable that raises ``CanceledError`` when the polled status is
        ``"stopped"`` and otherwise returns ``None``.
    """

    def clearml_check_canceled() -> None:
        current_time = datetime.now()
        last_check_time = progress_info.last_check_canceled_time
        # total_seconds() measures the whole elapsed interval. timedelta.seconds is only the
        # sub-day component, so it wraps every 24 hours and could suppress a poll that is due.
        if last_check_time is None or (current_time - last_check_time).total_seconds() > 20:
            if task.get_status() == "stopped":
                raise CanceledError
            # Only reached when the task was not stopped: remember when we last polled.
            progress_info.last_check_canceled_time = current_time

    return clearml_check_canceled


def get_clearml_progress_caller(
    progress_info: ProgressInfo, task: Task, scheduler: AsyncScheduler, logger: logging.Logger
) -> Callable[[ProgressStatus], None]:
    """Create a progress callback that logs changes and uploads them as ClearML runtime properties.

    A report is ignored when both its rounded percentage and its message equal the previously
    recorded ones. Otherwise it is logged, and an upload of the task's runtime properties is
    scheduled unless one was already scheduled within the last second.

    Args:
        progress_info: Shared state holding the last reported percentage, message and upload time.
        task: The ClearML task; its id, session host and session token address the upload.
        scheduler: Receives the upload coroutine via ``schedule()``.
        logger: Receives one info line per changed progress report.

    Returns:
        A callable that accepts a ``ProgressStatus`` and returns ``None``.
    """

    def clearml_progress(progress_status: ProgressStatus) -> None:
        percent_completed: Optional[int] = None
        if progress_status.percent_completed is not None:
            percent_completed = round(progress_status.percent_completed * 100)
        message = progress_status.message
        if percent_completed == progress_info.last_percent_completed and message == progress_info.last_message:
            # Nothing changed since the last report: no duplicate log line, no upload.
            return

        logger.info(f"{percent_completed}% - {message}")
        current_time = datetime.now()
        last_upload_time = progress_info.last_progress_time
        # total_seconds() measures the whole elapsed interval. timedelta.seconds is only the
        # sub-day component, so it wraps every 24 hours and could suppress an upload that is due.
        if last_upload_time is None or (current_time - last_upload_time).total_seconds() > 1:
            scheduler.schedule(
                update_runtime_properties(
                    task.id,  # type: ignore
                    task.session.host,
                    task.session.token,  # type: ignore
                    create_runtime_properties(task, percent_completed, message),
                )
            )
            progress_info.last_progress_time = current_time
        # Recorded even when the upload was throttled, so the same report is not logged twice.
        progress_info.last_percent_completed = percent_completed
        progress_info.last_message = message

    return clearml_progress


def get_local_progress_caller(progress_info: ProgressInfo, logger: logging.Logger) -> Callable[[ProgressStatus], None]:
    """Create a progress callback that logs each changed report and records it in ``progress_info``.

    Args:
        progress_info: Shared state holding the last reported percentage and message.
        logger: Receives one info line per changed progress report.

    Returns:
        A callable that accepts a ``ProgressStatus`` and returns ``None``.
    """

    def local_progress(progress_status: ProgressStatus) -> None:
        fraction = progress_status.percent_completed
        # The status carries a fraction; it is stored and logged as a rounded percentage.
        percent_completed: Optional[int] = None if fraction is None else round(fraction * 100)
        message = progress_status.message

        same_percent = percent_completed == progress_info.last_percent_completed
        same_message = message == progress_info.last_message
        if same_percent and same_message:
            # Identical to the previous report; stay quiet.
            return

        logger.info(f"{percent_completed}% - {message}")
        progress_info.last_percent_completed = percent_completed
        progress_info.last_message = message

    return local_progress


def update_settings(settings: Settings, args: dict):
    """Merge ``args`` into ``settings`` and normalize the merged values in place.

    After merging, ``model_type`` is lower-cased, a ``build_options`` JSON string (when present)
    is parsed and merged under the key named by ``model_type``, and ``data_dir`` has a leading
    ``~`` expanded.

    Args:
        settings: The settings object to update in place.
        args: Values to merge into ``settings`` before normalization.

    Raises:
        ValueError: If ``build_options`` is not valid JSON.
        TypeError: If ``build_options`` is of a type that ``json.loads`` rejects.
    """
    settings.update(args)
    settings.model_type = cast(str, settings.model_type).lower()

    if "build_options" in settings:
        try:
            parsed_options = json.loads(cast(str, settings.build_options))
        except ValueError as e:
            raise ValueError("Build options could not be parsed: Invalid JSON") from e
        except TypeError as e:
            raise TypeError(f"Build options could not be parsed: {e}") from e
        else:
            # File the parsed options under the (already lower-cased) model type.
            model_key = settings.model_type
            settings.update({model_key: parsed_options})

    expanded_data_dir = os.path.expanduser(cast(str, settings.data_dir))
    settings.data_dir = expanded_data_dir


async def update_runtime_properties(task_id: str, base_url: str, token: str, runtime_props: dict) -> None:
    """POST ``runtime_props`` for ``task_id`` to the ``/tasks.edit`` endpoint under ``base_url``.

    Args:
        task_id: Sent as the ``task`` field of the request body.
        base_url: Base URL that the ``/tasks.edit`` path is resolved against.
        token: Sent as a bearer token in the ``Authorization`` header.
        runtime_props: Sent as the ``runtime`` field of the request body.

    Raises:
        aiohttp.ClientResponseError: If the response status indicates an error.
    """
    auth_headers = {"Authorization": f"Bearer {token}"}
    # Named "payload" rather than "json" so the imported json module is not shadowed.
    payload = {"task": task_id, "runtime": runtime_props, "force": True}
    async with aiohttp.ClientSession(base_url=base_url, headers=auth_headers) as session, session.post(
        "/tasks.edit", json=payload
    ) as response:
        response.raise_for_status()


def create_runtime_properties(task, percent_completed: Optional[int], message: Optional[str]) -> dict:
    """Return a copy of the task's runtime properties with ``progress`` and ``message`` refreshed.

    Args:
        task: Object exposing the current runtime properties as ``task.data.runtime``
            (a mapping, or ``None`` when there are none).
        percent_completed: Stored under ``"progress"`` as a string; ``None`` removes the key.
        message: Stored under ``"message"``; ``None`` removes the key.

    Returns:
        A new dict; the mapping held by ``task`` is left unmodified.
    """
    # Substitute an empty dict before copying, so a task without runtime data does not raise.
    runtime_props = (task.data.runtime or {}).copy()
    if percent_completed is not None:
        runtime_props["progress"] = str(percent_completed)
    else:
        # pop() with a default: the key may never have been set, and del would raise KeyError.
        runtime_props.pop("progress", None)
    if message is not None:
        runtime_props["message"] = message
    else:
        runtime_props.pop("message", None)
    return runtime_props
7 changes: 4 additions & 3 deletions machine/jobs/build_nmt_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

from ..utils.canceled_error import CanceledError
from ..utils.progress_status import ProgressStatus
from .clearml_shared_file_service import ClearMLSharedFileService
from .config import SETTINGS
from .nmt_engine_build_job import NmtEngineBuildJob
from .nmt_model_factory import NmtModelFactory
from .shared_file_service_factory import SharedFileServiceType
from .translation_file_service import TranslationFileService

# Setup logging
logging.basicConfig(
Expand Down Expand Up @@ -58,7 +59,7 @@ def clearml_progress(status: ProgressStatus) -> None:

logger.info(f"Config: {SETTINGS.as_dict()}")

shared_file_service = ClearMLSharedFileService(SETTINGS)
translation_file_service = TranslationFileService(SharedFileServiceType.CLEARML, SETTINGS)
nmt_model_factory: NmtModelFactory
if model_type == "huggingface":
from .huggingface.hugging_face_nmt_model_factory import HuggingFaceNmtModelFactory
Expand All @@ -67,7 +68,7 @@ def clearml_progress(status: ProgressStatus) -> None:
else:
raise RuntimeError("The model type is invalid.")

job = NmtEngineBuildJob(SETTINGS, nmt_model_factory, shared_file_service)
job = NmtEngineBuildJob(SETTINGS, nmt_model_factory, translation_file_service)
train_corpus_size = job.run(progress, check_canceled)
if task is not None:
task.get_logger().report_single_value(name="train_corpus_size", value=train_corpus_size)
Expand Down
Loading

0 comments on commit 2bd1d72

Please sign in to comment.