Skip to content

Commit

Permalink
Add option to save the model during build job
Browse files Browse the repository at this point in the history
  • Loading branch information
ddaspit committed Feb 2, 2024
1 parent f915db0 commit 65bc4d4
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 13 deletions.
1 change: 1 addition & 0 deletions machine/jobs/build_nmt_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def main() -> None:
parser.add_argument("--trg-lang", required=True, type=str, help="Target language tag")
parser.add_argument("--clearml", default=False, action="store_true", help="Initializes a ClearML task")
parser.add_argument("--build-options", default=None, type=str, help="Build configurations")
parser.add_argument("--save-model", default=False, action="store_true", help="Save the model")
args = parser.parse_args()

run({k: v for k, v in vars(args).items() if v is not None})
Expand Down
8 changes: 7 additions & 1 deletion machine/jobs/huggingface/hugging_face_nmt_model_factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import tarfile
from pathlib import Path
from typing import Any, cast

Expand Down Expand Up @@ -84,7 +85,12 @@ def create_engine(self) -> TranslationEngine:
)

def save_model(self) -> None:
    """Archive the trained model files and pass the archive to the shared file service.

    Every regular file directly inside ``self._model_dir`` is added to a
    gzip-compressed tarball at ``<data_dir>/builds/<build_id>/model.tar.gz``;
    subdirectories are skipped. The tarball path is then handed to
    ``self._shared_file_service.save_model``. The raw model directory itself
    is no longer uploaded.
    """
    tar_file_path = Path(self._config.data_dir, "builds", self._config.build_id, "model.tar.gz")
    # The build directory may not exist yet; without it tarfile.open raises
    # FileNotFoundError before anything is archived.
    tar_file_path.parent.mkdir(parents=True, exist_ok=True)
    with tarfile.open(tar_file_path, "w:gz") as tar:
        for path in self._model_dir.iterdir():
            if path.is_file():
                # arcname=path.name stores members flat, without the local
                # directory prefix.
                tar.add(path, arcname=path.name)
    self._shared_file_service.save_model(tar_file_path)

@property
def _model_dir(self) -> Path:
Expand Down
4 changes: 4 additions & 0 deletions machine/jobs/nmt_engine_build_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ def run(
current_inference_step += len(pi_batch)
phase_progress(ProgressStatus.from_step(current_inference_step, inference_step_count))

if self._config["save_model"]:
logger.info("Saving model")
self._nmt_model_factory.save_model()


def _translate_batch(
engine: TranslationEngine,
Expand Down
22 changes: 10 additions & 12 deletions machine/jobs/shared_file_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,11 @@ def open_target_pretranslation_writer(self) -> Iterator[PretranslationWriter]:
def get_parent_model(self, language_tag: str) -> Path:
    """Return the path of the parent model folder for ``language_tag``.

    The folder ``parent_models/<language_tag>`` is requested through
    ``self._download_folder`` with caching enabled, and that call's result
    is returned unchanged.
    """
    remote_folder = f"parent_models/{language_tag}"
    return self._download_folder(remote_folder, cache=True)

def save_model(self, model_path: Path) -> None:
    """Upload a saved model under the current build's ID.

    A regular file (for example a ``model.tar.gz`` archive) is uploaded with
    ``self._upload_file`` to ``models/<build_id><suffixes>``, where
    ``<suffixes>`` is every suffix of the local file name joined together.
    Any other path is uploaded with ``self._upload_folder`` to
    ``models/<build_id>``.

    Args:
        model_path: Local file or folder holding the model.
    """
    remote_path = f"models/{self._build_id}"
    if model_path.is_file():
        # Keep compound extensions such as ".tar.gz" on the remote name.
        self._upload_file(remote_path + "".join(model_path.suffixes), model_path)
    else:
        self._upload_folder(remote_path, model_path)

@property
def _data_dir(self) -> Path:
Expand All @@ -102,21 +105,16 @@ def _shared_file_folder(self) -> str:
return shared_file_folder.rstrip("/")

@abstractmethod
def _download_file(self, path: str, cache: bool = False) -> Path:
    """Download the file identified by ``path`` and return a ``Path`` to it.

    Abstract hook with no behaviour of its own; concrete services supply the
    transfer logic.
    NOTE(review): ``path`` is presumably relative to the shared file folder,
    and ``cache`` presumably permits reusing an earlier download -- confirm
    against the implementations.
    """
    ...

@abstractmethod
def _download_folder(self, path: str, cache: bool = False) -> Path:
    """Download the folder identified by ``path`` and return a ``Path`` to it.

    Abstract hook with no behaviour of its own; concrete services supply the
    transfer logic.
    NOTE(review): ``path`` is presumably relative to the shared file folder,
    and ``cache`` presumably permits reusing an earlier download -- confirm
    against the implementations.
    """
    ...

@abstractmethod
def _exists_file(self, path: str) -> bool:
    """Report whether a file exists at ``path``.

    Abstract hook with no behaviour of its own; concrete services supply the
    lookup.
    NOTE(review): ``path`` is presumably relative to the shared file folder
    -- confirm against the implementations.
    """
    ...

@abstractmethod
def _upload_file(self, path: str, local_file_path: Path) -> None:
    """Upload the local file ``local_file_path`` to the destination ``path``.

    Abstract hook with no behaviour of its own; concrete services supply the
    transfer logic.
    NOTE(review): ``path`` is presumably relative to the shared file folder
    -- confirm against the implementations.
    """
    ...

@abstractmethod
def _upload_folder(self, path: str, local_folder_path: Path) -> None:
    """Upload the local folder ``local_folder_path`` to the destination ``path``.

    Abstract hook with no behaviour of its own; concrete services supply the
    transfer logic.
    NOTE(review): ``path`` is presumably relative to the shared file folder
    -- confirm against the implementations.
    """
    ...

0 comments on commit 65bc4d4

Please sign in to comment.