diff --git a/machine/jobs/build_nmt_engine.py b/machine/jobs/build_nmt_engine.py index 8d3a369..19031e5 100644 --- a/machine/jobs/build_nmt_engine.py +++ b/machine/jobs/build_nmt_engine.py @@ -30,7 +30,7 @@ def run(args: dict) -> None: task = Task.init() def clearml_check_canceled() -> None: - if task.get_status() in {"stopped", "stopping"}: + if task.get_status() == "stopped": raise CanceledError check_canceled = clearml_check_canceled @@ -72,7 +72,10 @@ def clearml_progress(status: ProgressStatus) -> None: logger.info("Finished") except Exception as e: if task: - task.mark_failed(status_reason=type(e).__name__, status_message=str(e)) + if task.get_status() == "stopped": + return + else: + task.mark_failed(status_reason=type(e).__name__, status_message=str(e)) raise e diff --git a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py index ddc8dfc..907a045 100644 --- a/machine/translation/huggingface/hugging_face_nmt_model_trainer.py +++ b/machine/translation/huggingface/hugging_face_nmt_model_trainer.py @@ -382,21 +382,21 @@ def __init__( self._check_canceled = check_canceled def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None: + if self._check_canceled is not None: + self._check_canceled() + if self._progress is not None and state.is_local_process_zero: self._progress( ProgressStatus(0) if self._max_steps is None else ProgressStatus.from_step(0, self._max_steps) ) + def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None: if self._check_canceled is not None: self._check_canceled() - def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None: if self._progress is not None and state.is_local_process_zero: self._progress( ProgressStatus(state.global_step) if self._max_steps is None else ProgressStatus.from_step(state.global_step, self._max_steps) ) - - if self._check_canceled is not None: - self._check_canceled()