From 2174c7d025bf9bdad1098c13a4a5fa15ad996e1c Mon Sep 17 00:00:00 2001 From: mshannon-sil <131058912+mshannon-sil@users.noreply.github.com> Date: Thu, 2 Nov 2023 11:06:27 -0400 Subject: [PATCH] enforce upper bound on max steps (#55) --- machine/jobs/huggingface/hugging_face_nmt_model_factory.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/machine/jobs/huggingface/hugging_face_nmt_model_factory.py b/machine/jobs/huggingface/hugging_face_nmt_model_factory.py index ce44c75..8601f71 100644 --- a/machine/jobs/huggingface/hugging_face_nmt_model_factory.py +++ b/machine/jobs/huggingface/hugging_face_nmt_model_factory.py @@ -24,6 +24,8 @@ def __init__(self, config: Any, shared_file_service: SharedFileService) -> None: args["output_dir"] = str(self._model_dir) args["overwrite_output_dir"] = True if "max_steps" in self._config: + if self._config.max_steps > 50000: + raise ValueError("max_steps must be less than or equal to 50000") args["max_steps"] = self._config.max_steps parser = HfArgumentParser(cast(Any, Seq2SeqTrainingArguments)) self._training_args = cast(Seq2SeqTrainingArguments, parser.parse_dict(args)[0])