Skip to content

Commit

Permalink
Publish num_steps
Browse files Browse the repository at this point in the history
Publish inference_step
Apply temp huggingface fix
  • Loading branch information
johnml1135 committed Oct 12, 2023
1 parent adf9bb9 commit 7da35a0
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 27 deletions.
8 changes: 6 additions & 2 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@
},
"editor.formatOnSave": true,
"editor.formatOnType": true,
"isort.args":["--profile", "black"]
"isort.args": [
"--profile",
"black"
]
},
// Add the IDs of extensions you want installed when the container is created.
"extensions": [
Expand All @@ -54,7 +57,8 @@
"donjayamanne.githistory",
"tamasfe.even-better-toml",
"github.vscode-github-actions",
"mhutchie.git-graph"
"mhutchie.git-graph",
"GitHub.copilot"
]
}
}
Expand Down
9 changes: 1 addition & 8 deletions machine/jobs/build_nmt_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,10 @@


def run(args: dict) -> None:
check_canceled: Optional[Callable[[], None]] = None
task = None
if args["clearml"]:
task = Task.init()

def clearml_check_canceled() -> None:
if task.get_status() in {"stopped", "stopping"}:
raise CanceledError

check_canceled = clearml_check_canceled

try:
logger.info("NMT Engine Build Job started")

Expand Down Expand Up @@ -60,7 +53,7 @@ def clearml_check_canceled() -> None:
raise RuntimeError("The model type is invalid.")

job = NmtEngineBuildJob(SETTINGS, nmt_model_factory, shared_file_service)
job.run(check_canceled)
job.run(task)
logger.info("Finished")
except Exception as e:
logger.exception(e, stack_info=True)
Expand Down
12 changes: 12 additions & 0 deletions machine/jobs/huggingface/hugging_face_nmt_model_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, cast

from transformers import AutoConfig, AutoModelForSeq2SeqLM, HfArgumentParser, PreTrainedModel, Seq2SeqTrainingArguments
from transformers.integrations import ClearMLCallback

from ...corpora.parallel_text_corpus import ParallelTextCorpus
from ...corpora.text_corpus import TextCorpus
Expand Down Expand Up @@ -77,3 +78,14 @@ def save_model(self) -> None:
@property
def _model_dir(self) -> Path:
    # Directory holding the model artifacts for the current build:
    # <data_dir>/builds/<build_id>/model
    # NOTE(review): assumes self._config exposes `data_dir` (path-like) and
    # `build_id` (str) — confirm against the settings schema.
    return Path(self._config.data_dir, "builds", self._config.build_id, "model")


# FIXME - remove this code when the fix is applied to Huggingface
# https://github.com/huggingface/transformers/pull/26763
def on_train_end(
    self: ClearMLCallback, args, state, control, model=None, tokenizer=None, metrics=None, logs=None, **kwargs
):
    """No-op override of ClearMLCallback.on_train_end.

    Temporary workaround that disables the callback's end-of-training
    behavior until the upstream transformers fix lands.
    """


# Monkeypatch the callback class so every instance uses the no-op above.
ClearMLCallback.on_train_end = on_train_end
50 changes: 33 additions & 17 deletions machine/jobs/nmt_engine_build_job.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import logging
from contextlib import ExitStack
from typing import Any, Callable, Optional, Sequence
from typing import Any, Optional, Sequence

import pandas as pd
from clearml import Model, Task

from ..corpora.corpora_utils import batch
from ..translation.translation_engine import TranslationEngine
from ..utils.canceled_error import CanceledError
from .nmt_model_factory import NmtModelFactory
from .shared_file_service import PretranslationInfo, PretranslationWriter, SharedFileService

Expand All @@ -15,10 +19,11 @@ def __init__(self, config: Any, nmt_model_factory: NmtModelFactory, shared_file_
self._config = config
self._nmt_model_factory = nmt_model_factory
self._shared_file_service = shared_file_service
self.clearml_task: Optional[Task] = None

def run(self, check_canceled: Optional[Callable[[], None]] = None) -> None:
if check_canceled is not None:
check_canceled()
def run(self, task: Optional[Task]) -> None:
self.clearml_task = task
self._send_clearml_config()

self._nmt_model_factory.init()

Expand All @@ -28,46 +33,57 @@ def run(self, check_canceled: Optional[Callable[[], None]] = None) -> None:
parallel_corpus = source_corpus.align_rows(target_corpus)

if parallel_corpus.count(include_empty=False):
if check_canceled is not None:
check_canceled()
self._check_canceled()

if self._nmt_model_factory.train_tokenizer:
logger.info("Training source tokenizer")
with self._nmt_model_factory.create_source_tokenizer_trainer(source_corpus) as source_tokenizer_trainer:
source_tokenizer_trainer.train(check_canceled=check_canceled)
source_tokenizer_trainer.train(check_canceled=self._check_canceled)
source_tokenizer_trainer.save()

if check_canceled is not None:
check_canceled()
self._check_canceled()

logger.info("Training target tokenizer")
with self._nmt_model_factory.create_target_tokenizer_trainer(target_corpus) as target_tokenizer_trainer:
target_tokenizer_trainer.train(check_canceled=check_canceled)
target_tokenizer_trainer.train(check_canceled=self._check_canceled)
target_tokenizer_trainer.save()

if check_canceled is not None:
check_canceled()
self._check_canceled()

logger.info("Training NMT model")
with self._nmt_model_factory.create_model_trainer(parallel_corpus) as model_trainer:
model_trainer.train(check_canceled=check_canceled)
model_trainer.train(check_canceled=self._check_canceled)
model_trainer.save()
else:
logger.info("No matching entries in the source and target corpus - skipping training")

if check_canceled is not None:
check_canceled()
self._check_canceled()

logger.info("Pretranslating segments")
with ExitStack() as stack:
model = stack.enter_context(self._nmt_model_factory.create_engine())
src_pretranslations = stack.enter_context(self._shared_file_service.get_source_pretranslations())
writer = stack.enter_context(self._shared_file_service.open_target_pretranslation_writer())
for pi_batch in batch(src_pretranslations, self._config["batch_size"]):
if check_canceled is not None:
check_canceled()
self._check_canceled()
_translate_batch(model, pi_batch, writer)

def _send_clearml_config(self) -> None:
    """Report the configured total training steps to the ClearML task, if one is attached."""
    task = self.clearml_task
    if not task:
        # No ClearML task attached (e.g. local run) — nothing to report.
        return
    task.get_logger().report_single_value(name="total_steps", value=self._config["max_steps"])

def _check_canceled(self) -> None:
    """Raise CanceledError when the attached ClearML task has been stopped.

    Does nothing when no ClearML task is attached.
    """
    task = self.clearml_task
    if task and task.get_status() in ("stopped", "stopping"):
        raise CanceledError

def _update_inference_step(self, step_num: int) -> None:
    """Report the current pretranslation (inference) step number to ClearML.

    No-op when no ClearML task is attached.
    """
    if not self.clearml_task:
        return
    # Re-mark the task as started so reporting works even after training completed.
    self.clearml_task.mark_started(force=True)
    self.clearml_task.get_logger().report_single_value(name="inference_step", value=step_num)
    # Hack around a ClearML reporting bug: https://github.com/allegroai/clearml/issues/1119
    self.clearml_task.get_logger().flush(wait=True)


def _translate_batch(
engine: TranslationEngine,
Expand Down

0 comments on commit 7da35a0

Please sign in to comment.