Improve run_training error handling
The library shouldn't (just) print errors; it should raise errors and
let callers deal with them.

This patch also makes the function raise RuntimeError when the training
process exits with a non-zero return code, or when it times out while
waiting for the process to exit.
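
The behavior described above can be sketched with a plain `subprocess.Popen`. This is a minimal illustration, not the code from the patch: `wait_or_raise` and `run_child` are hypothetical helpers, and the extra `process.wait()` after `kill()` is an addition made here to reap the child.

```python
import subprocess
import sys


def wait_or_raise(process: subprocess.Popen, timeout: float = 60) -> None:
    """Hypothetical helper: turn a bad exit or a hang into RuntimeError."""
    try:
        rc = process.wait(timeout=timeout)
    except subprocess.TimeoutExpired as e:
        # The child did not exit in time: kill it and report the failure.
        process.kill()
        process.wait()  # reap the killed child (not part of the patch)
        raise RuntimeError("Training process timed out on exit.") from e
    if rc:
        # Any non-zero return code is surfaced to the caller.
        raise RuntimeError(f"Training process exited with code {rc}")


def run_child(code: str, timeout: float = 60) -> None:
    """Hypothetical stand-in for the training launcher."""
    process = subprocess.Popen([sys.executable, "-c", code])
    wait_or_raise(process, timeout)
```

A caller then sees an exception instead of a printed message, for example when running `run_child("import sys; sys.exit(3)")`.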

The ilab CLI already handles every Exception raised by the function.
KeyboardInterrupt is not caught by that handler because it derives from
BaseException rather than Exception, but the click library handles it
instead [1].
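
The distinction between Exception and BaseException can be checked directly. The snippet below is a small sketch with a hypothetical `classify` helper; it does not involve click or the ilab handler.

```python
def classify(exc_type: type) -> str:
    """Hypothetical helper: report which handler catches exc_type."""
    try:
        raise exc_type()
    except Exception:
        return "caught by except Exception"
    except BaseException:
        return "escapes except Exception"


# RuntimeError is an Exception subclass; KeyboardInterrupt is not.
print(classify(RuntimeError))
print(classify(KeyboardInterrupt))
print(issubclass(KeyboardInterrupt, Exception))
```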

The intent of this patch is to make `ilab train` exit with a non-zero
return code when the training routine fails.
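
How a raised error becomes a non-zero exit status can be sketched generically. The ilab handler itself is not shown in this commit, so `exit_code_for`, `failing_train`, and `ok_train` below are hypothetical names used only for illustration.

```python
import sys
from typing import Callable


def exit_code_for(train: Callable[[], None]) -> int:
    """Hypothetical caller: map any Exception from training to exit code 1."""
    try:
        train()
    except Exception as e:
        print(f"Training failed: {e}", file=sys.stderr)
        return 1
    return 0


def failing_train() -> None:
    # Stand-in for a training run whose subprocess failed.
    raise RuntimeError("Training process exited with code 1")


def ok_train() -> None:
    # Stand-in for a training run that finished cleanly.
    return None
```

A command-line entry point would pass the result to `sys.exit`. If the library only printed the error, the caller would have nothing to catch and would report success.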

[1] https://click.palletsprojects.com/en/7.x/exceptions/#where-are-errors-handled

Signed-off-by: Ihar Hrachyshka <[email protected]>
booxter committed Jul 26, 2024
1 parent 9fdeb87 commit 3d84967
Showing 1 changed file with 13 additions and 21 deletions.
34 changes: 13 additions & 21 deletions src/instructlab/training/main_ds.py
@@ -663,28 +663,20 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
command.append("--cpu_offload_optimizer_pin_memory")

     print(f"\033[92mRunning command: {' '.join(command)}\033[0m")
-    process = None
+    process = StreamablePopen(
+        f"{train_args.ckpt_output_dir}/full_logs_global{torch_args.node_rank}.log",
+        command,
+    )
+    print("\033[91mTerminating process 🤖\033[0m")
+    process.terminate()
     try:
-        process = StreamablePopen(
-            f"{train_args.ckpt_output_dir}/full_logs_global{torch_args.node_rank}.log",
-            command,
-        )
-
-    except KeyboardInterrupt:
-        print("Process interrupted by user")
-    except Exception as e:
-        print(f"An error occurred: {str(e)}")
-    finally:
-        if "process" not in locals() or process is None:
-            return
-
-        print("\033[91mTerminating process 🤖\033[0m")
-        process.terminate()
-        try:
-            process.wait(timeout=60)
-        except subprocess.TimeoutExpired:
-            print("\033[91mProcess did not terminate in time, killing it.\033[0m")
-            process.kill()
+        rc = process.wait(timeout=60)
+        if rc:
+            raise RuntimeError(f"Training process exited with code {rc}")
+    except subprocess.TimeoutExpired as e:
+        print("\033[91mProcess did not terminate in time, killing it.\033[0m")
+        process.kill()
+        raise RuntimeError("Training process timed out on exit.") from e


if __name__ == "__main__":