diff --git a/machine/jobs/huggingface/hugging_face_nmt_model_factory.py b/machine/jobs/huggingface/hugging_face_nmt_model_factory.py index ce44c75..8601f71 100644 --- a/machine/jobs/huggingface/hugging_face_nmt_model_factory.py +++ b/machine/jobs/huggingface/hugging_face_nmt_model_factory.py @@ -24,6 +24,8 @@ def __init__(self, config: Any, shared_file_service: SharedFileService) -> None: args["output_dir"] = str(self._model_dir) args["overwrite_output_dir"] = True if "max_steps" in self._config: + if self._config.max_steps > 50000: + raise ValueError("max_steps must be less than or equal to 50000") args["max_steps"] = self._config.max_steps parser = HfArgumentParser(cast(Any, Seq2SeqTrainingArguments)) self._training_args = cast(Seq2SeqTrainingArguments, parser.parse_dict(args)[0])