From 3d849676f56704a719f6487d3ed0a6037ffd4130 Mon Sep 17 00:00:00 2001 From: Ihar Hrachyshka Date: Wed, 24 Jul 2024 18:21:09 -0400 Subject: [PATCH] Improve run_training error handling The library shouldn't (just) print errors; it should raise errors and let callers deal with them. This patch also makes the function raise RuntimeError when the training process returns a non-zero return code, or when waiting for it to exit times out. The ilab CLI already handles all Exceptions raised by the function. KeyboardInterrupt won't be caught by this exception handler since it derives from BaseException rather than Exception, but it's handled by the click library instead [1]. The intent of this patch is to make `ilab train` fail with a non-zero return code when the training routine fails. [1] https://click.palletsprojects.com/en/7.x/exceptions/#where-are-errors-handled Signed-off-by: Ihar Hrachyshka --- src/instructlab/training/main_ds.py | 34 +++++++++++------------------ 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index b56d5b85..2e7cf8b6 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -663,28 +663,20 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: command.append("--cpu_offload_optimizer_pin_memory") print(f"\033[92mRunning command: {' '.join(command)}\033[0m") - process = None + process = StreamablePopen( + f"{train_args.ckpt_output_dir}/full_logs_global{torch_args.node_rank}.log", + command, + ) + print("\033[91mTerminating process 🤖\033[0m") + process.terminate() try: - process = StreamablePopen( - f"{train_args.ckpt_output_dir}/full_logs_global{torch_args.node_rank}.log", - command, - ) - - except KeyboardInterrupt: - print("Process interrupted by user") - except Exception as e: - print(f"An error occurred: {str(e)}") - finally: - if "process" not in locals() or process is None: - return - - print("\033[91mTerminating process 
🤖\033[0m") - process.terminate() - try: - process.wait(timeout=60) - except subprocess.TimeoutExpired: - print("\033[91mProcess did not terminate in time, killing it.\033[0m") - process.kill() + rc = process.wait(timeout=60) + if rc: + raise RuntimeError(f"Training process exited with code {rc}") + except subprocess.TimeoutExpired as e: + print("\033[91mProcess did not terminate in time, killing it.\033[0m") + process.kill() + raise RuntimeError("Training process timed out on exit.") from e if __name__ == "__main__":