Skip to content

Commit

Permalink
First set of reviewer comments..
Browse files Browse the repository at this point in the history
  • Loading branch information
johnml1135 committed Aug 20, 2024
1 parent e66e507 commit 652f6fb
Show file tree
Hide file tree
Showing 20 changed files with 411 additions and 295 deletions.
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
{
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": "explicit"
"source.organizeImports": "explicit",
},
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true,
"python.analysis.extraPaths": [
"tests"
],
"python.analysis.importFormat": "relative",
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true
Expand Down
11 changes: 8 additions & 3 deletions machine/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,27 @@
from .local_shared_file_service import LocalSharedFileService
from .nmt_engine_build_job import NmtEngineBuildJob
from .nmt_model_factory import NmtModelFactory
from .shared_file_service import DictToJsonWriter, PretranslationInfo, SharedFileService
from .shared_file_service_base import DictToJsonWriter, SharedFileServiceBase
from .smt_engine_build_job import SmtEngineBuildJob
from .smt_model_factory import SmtModelFactory
from .translation_file_service import PretranslationInfo, TranslationFileService
from .word_alignment_build_job import WordAlignmentBuildJob
from .word_alignment_file_service import WordAlignmentFileService, WordAlignmentInfo
from .word_alignment_model_factory import WordAlignmentModelFactory

__all__ = [
"ClearMLSharedFileService",
"LocalSharedFileService",
"NmtEngineBuildJob",
"NmtModelFactory",
"PretranslationInfo",
"DictToJsonWriter",
"SharedFileService",
"SharedFileServiceBase",
"SmtEngineBuildJob",
"SmtModelFactory",
"PretranslationInfo",
"TranslationFileService",
"WordAlignmentBuildJob",
"WordAlignmentFileService",
"WordAlignmentInfo",
"WordAlignmentModelFactory",
]
5 changes: 3 additions & 2 deletions machine/jobs/build_smt_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
update_runtime_properties,
update_settings,
)
from .clearml_shared_file_service import ClearMLSharedFileService
from .config import SETTINGS
from .shared_file_service_factory import SharedFileServiceType
from .smt_engine_build_job import SmtEngineBuildJob
from .smt_model_factory import SmtModelFactory
from .translation_file_service import TranslationFileService

# Setup logging
logging.basicConfig(
Expand Down Expand Up @@ -54,7 +55,7 @@ def run(args: dict) -> None:

logger.info(f"Config: {SETTINGS.as_dict()}")

shared_file_service = ClearMLSharedFileService(SETTINGS)
shared_file_service = TranslationFileService(SharedFileServiceType.CLEARML, SETTINGS)
smt_model_factory: SmtModelFactory
if SETTINGS.model_type == "thot":
from .thot.thot_smt_model_factory import ThotSmtModelFactory
Expand Down
3 changes: 1 addition & 2 deletions machine/jobs/build_word_alignment_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

from clearml import Task

from machine.jobs.thot.thot_word_alignment_model_factory import ThotWordAlignmentModelFactory

from ..utils.progress_status import ProgressStatus
from .async_scheduler import AsyncScheduler
from .build_clearml_helper import (
Expand All @@ -19,6 +17,7 @@
)
from .clearml_shared_file_service import ClearMLSharedFileService
from .config import SETTINGS
from .thot.thot_word_alignment_model_factory import ThotWordAlignmentModelFactory
from .word_alignment_build_job import WordAlignmentBuildJob
from .word_alignment_model_factory import WordAlignmentModelFactory

Expand Down
48 changes: 17 additions & 31 deletions machine/jobs/clearml_shared_file_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,59 +5,45 @@

from clearml import StorageManager

from .shared_file_service import SharedFileService
from .shared_file_service_base import SharedFileServiceBase

logger = logging.getLogger(__name__)


class ClearMLSharedFileService(SharedFileService):
def _download_file(self, path: str) -> Path:
uri = f"{self._shared_file_uri}/{self._shared_file_folder}/{path}"
class ClearMLSharedFileService(SharedFileServiceBase):
def download_file(self, path: str) -> Path:
local_folder = str(self._data_dir)
file_path = try_n_times(lambda: StorageManager.download_file(uri, local_folder, skip_zero_size_check=True))
file_path = try_n_times(
lambda: StorageManager.download_file(self._get_uri(path), local_folder, skip_zero_size_check=True)
)
if file_path is None:
raise RuntimeError(f"Failed to download file: {uri}")
raise RuntimeError(f"Failed to download file: {self._get_uri(path)}")
return Path(file_path)

def _download_folder(self, path: str) -> Path:
uri = f"{self._shared_file_uri}/{self._shared_file_folder}/{path}"
local_folder = str(self._data_dir)
folder_path = try_n_times(lambda: StorageManager.download_folder(uri, local_folder))
folder_path = try_n_times(lambda: StorageManager.download_folder(self._get_uri(path), local_folder))
if folder_path is None:
raise RuntimeError(f"Failed to download folder: {uri}")
raise RuntimeError(f"Failed to download folder: {self._get_uri(path)}")
return Path(folder_path) / path

def _exists_file(self, path: str) -> bool:
uri = f"{self._shared_file_uri}/{self._shared_file_folder}/{path}"
return try_n_times(lambda: StorageManager.exists_file(uri)) # type: ignore
return try_n_times(lambda: StorageManager.exists_file(self._get_uri(path))) # type: ignore

def _upload_file(self, path: str, local_file_path: Path) -> None:
final_destination = try_n_times(
lambda: StorageManager.upload_file(
str(local_file_path), f"{self._shared_file_uri}/{self._shared_file_folder}/{path}"
)
)
final_destination = try_n_times(lambda: StorageManager.upload_file(str(local_file_path), self._get_uri(path)))
if final_destination is None:
logger.error(
(
f"Failed to upload file {str(local_file_path)} "
f"to {self._shared_file_uri}/{self._shared_file_folder}/{path}."
)
)
logger.error((f"Failed to upload file {str(local_file_path)} " f"to {self._get_uri(path)}."))

def _upload_folder(self, path: str, local_folder_path: Path) -> None:
final_destination = try_n_times(
lambda: StorageManager.upload_folder(
str(local_folder_path), f"{self._shared_file_uri}/{self._shared_file_folder}/{path}"
)
lambda: StorageManager.upload_folder(str(local_folder_path), self._get_uri(path))
)
if final_destination is None:
logger.error(
(
f"Failed to upload folder {str(local_folder_path)} "
f"to {self._shared_file_uri}/{self._shared_file_folder}/{path}."
)
)
logger.error((f"Failed to upload folder {str(local_folder_path)} " f"to {self._get_uri(path)}."))

def _get_uri(self, path: str) -> str:
return f"{self._shared_file_uri}/{self._shared_file_folder}/{path}"


def try_n_times(func: Callable, n=10):
Expand Down
6 changes: 3 additions & 3 deletions machine/jobs/local_shared_file_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import shutil
from pathlib import Path

from .shared_file_service import SharedFileService
from .shared_file_service_base import SharedFileServiceBase

logger = logging.getLogger(__name__)


class LocalSharedFileService(SharedFileService):
def _download_file(self, path: str) -> Path:
class LocalSharedFileService(SharedFileServiceBase):
def download_file(self, path: str) -> Path:
src_path = self._get_path(path)
dst_path = self._data_dir / self._shared_file_folder / path
dst_path.parent.mkdir(parents=True, exist_ok=True)
Expand Down
33 changes: 18 additions & 15 deletions machine/jobs/nmt_engine_build_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,22 @@
from ..translation.translation_engine import TranslationEngine
from ..utils.phased_progress_reporter import Phase, PhasedProgressReporter
from ..utils.progress_status import ProgressStatus
from .engine_build_job import EngineBuildJob
from .nmt_model_factory import NmtModelFactory
from .shared_file_service import DictToJsonWriter, PretranslationInfo, SharedFileService
from .shared_file_service_base import DictToJsonWriter
from .translation_engine_build_job import TranslationEngineBuildJob
from .translation_file_service import PretranslationInfo, TranslationFileService

logger = logging.getLogger(__name__)


class NmtEngineBuildJob(EngineBuildJob):
def __init__(self, config: Any, nmt_model_factory: NmtModelFactory, shared_file_service: SharedFileService) -> None:
class NmtEngineBuildJob(TranslationEngineBuildJob):
def __init__(
self, config: Any, nmt_model_factory: NmtModelFactory, translation_file_service: TranslationFileService
) -> None:
self._nmt_model_factory = nmt_model_factory
super().__init__(config, shared_file_service)
super().__init__(config, translation_file_service)

def start_job(self) -> None:
def _start_job(self) -> None:
self._nmt_model_factory.init()

def _get_progress_reporter(self, progress: Optional[Callable[[ProgressStatus], None]]) -> PhasedProgressReporter:
Expand All @@ -31,10 +34,10 @@ def _get_progress_reporter(self, progress: Optional[Callable[[ProgressStatus], N
phases = [Phase(message="Pretranslating segments", percentage=1.0)]
return PhasedProgressReporter(progress, phases)

def respond_to_no_training_corpus(self) -> None:
def _respond_to_no_training_corpus(self) -> None:
logger.info("No matching entries in the source and target corpus - skipping training")

def train_model(
def _train_model(
self,
progress_reporter: PhasedProgressReporter,
check_canceled: Optional[Callable[[], None]],
Expand Down Expand Up @@ -70,19 +73,19 @@ def train_model(
model_trainer.train(progress=phase_progress, check_canceled=check_canceled)
model_trainer.save()

def batch_inference(
def _batch_inference(
self,
progress_reporter: PhasedProgressReporter,
check_canceled: Optional[Callable[[], None]],
) -> None:
logger.info("Pretranslating segments")
with self._shared_file_service.get_source_pretranslations() as src_pretranslations:
with self._translation_file_service.get_source_pretranslations() as src_pretranslations:
inference_step_count = sum(1 for _ in src_pretranslations)
with ExitStack() as stack:
phase_progress = stack.enter_context(progress_reporter.start_next_phase())
engine = stack.enter_context(self._nmt_model_factory.create_engine())
src_pretranslations = stack.enter_context(self._shared_file_service.get_source_pretranslations())
writer = stack.enter_context(self._shared_file_service.open_target_pretranslation_writer())
src_pretranslations = stack.enter_context(self._translation_file_service.get_source_pretranslations())
writer = stack.enter_context(self._translation_file_service.open_target_pretranslation_writer())
current_inference_step = 0
phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))
batch_size = self._config["inference_batch_size"]
Expand All @@ -93,11 +96,11 @@ def batch_inference(
current_inference_step += len(pi_batch)
phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))

def save_model(self) -> None:
def _save_model(self) -> None:
if "save_model" in self._config and self._config.save_model is not None:
logger.info("Saving model")
model_path = self._nmt_model_factory.save_model()
self._shared_file_service.save_model(
self._translation_file_service.save_model(
model_path, f"models/{self._config.save_model + ''.join(model_path.suffixes)}"
)

Expand All @@ -110,4 +113,4 @@ def _translate_batch(
source_segments = [pi["translation"] for pi in batch]
for i, result in enumerate(engine.translate_batch(source_segments)):
batch[i]["translation"] = result.translation
writer.write(batch[i]) # type: ignore
writer.write(batch[i])
3 changes: 2 additions & 1 deletion machine/jobs/settings.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ default:
add_unk_trg_tokens: true
thot:
word_alignment_model_type: hmm
word_alignment_heuristic: grow-diag-final-and
tokenizer: latin
thot_align:
word_alignment_heuristic: grow-diag-final-and
development:
shared_file_folder: dev
huggingface:
Expand Down
Loading

0 comments on commit 652f6fb

Please sign in to comment.