Skip to content

Commit

Permalink
Exception catching improvements
Browse files Browse the repository at this point in the history
Signed-off-by: Kelly A <[email protected]>
  • Loading branch information
kellyaa committed May 14, 2024
1 parent e1334b5 commit cd260a3
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 62 deletions.
10 changes: 4 additions & 6 deletions build/accelerate_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,12 @@ def main():
except (TypeError, ValueError, EnvironmentError) as e:
logging.error(traceback.format_exc())
write_termination_log(
"Exception raised during training. This may be a problem with your input: {}".format(
e
)
f"Exception raised during training. This may be a problem with your input: {e}"
)
sys.exit(USER_ERROR_EXIT_CODE)
except Exception as e: # pylint: disable=broad-except
logging.error(traceback.format_exc())
write_termination_log("Unhandled exception during training")
write_termination_log(f"Unhandled exception during training. {e}")
sys.exit(INTERNAL_ERROR_EXIT_CODE)

##########
Expand All @@ -85,11 +83,11 @@ def main():
return_code = e.returncode
if return_code not in [INTERNAL_ERROR_EXIT_CODE, USER_ERROR_EXIT_CODE]:
return_code = INTERNAL_ERROR_EXIT_CODE
write_termination_log("Unhandled exception during training")
write_termination_log(f"Unhandled exception during training. {e}")
sys.exit(return_code)
except Exception as e: # pylint: disable=broad-except
logging.error(traceback.format_exc())
write_termination_log("Unhandled exception during training")
write_termination_log(f"Unhandled exception during training. {e}")
sys.exit(INTERNAL_ERROR_EXIT_CODE)

return 0
Expand Down
28 changes: 19 additions & 9 deletions build/launch_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
import sys
import traceback

# Third Party
from huggingface_hub.utils._validators import HFValidationError
from torch.cuda import OutOfMemoryError

# First Party
import logging

Expand Down Expand Up @@ -77,9 +81,7 @@ def main():
except Exception as e: # pylint: disable=broad-except
logging.error(traceback.format_exc())
write_termination_log(
"Exception raised during training. This may be a problem with your input: {}".format(
e
)
f"Exception raised during training. This may be a problem with your input: {e}"
)
sys.exit(USER_ERROR_EXIT_CODE)

Expand All @@ -97,14 +99,18 @@ def main():
peft_config=tune_config,
tracker_configs=tracker_config_args,
)
except MemoryError:
except (MemoryError, OutOfMemoryError) as e:
logging.error(traceback.format_exc())
write_termination_log("OOM error during training")
write_termination_log(f"OOM error during training. {e}")
sys.exit(INTERNAL_ERROR_EXIT_CODE)
except FileNotFoundError as e:
logging.error(traceback.format_exc())
write_termination_log("Unable to load file: {}".format(e))
sys.exit(USER_ERROR_EXIT_CODE)
except HFValidationError as e:
logging.error(traceback.format_exc())
write_termination_log(f"Specified base model not found. Exception: {e}")
sys.exit(USER_ERROR_EXIT_CODE)
except (TypeError, ValueError, EnvironmentError) as e:
logging.error(traceback.format_exc())
write_termination_log(
Expand All @@ -113,7 +119,7 @@ def main():
sys.exit(USER_ERROR_EXIT_CODE)
except Exception as e: # pylint: disable=broad-except
logging.error(traceback.format_exc())
write_termination_log("Unhandled exception during training")
write_termination_log(f"Unhandled exception during training: {e}")
sys.exit(INTERNAL_ERROR_EXIT_CODE)

if merge_model:
Expand Down Expand Up @@ -142,7 +148,9 @@ def main():
)
except Exception as e: # pylint: disable=broad-except
logging.error(traceback.format_exc())
write_termination_log("Exception encountered merging model checkpoints")
write_termination_log(
f"Exception encountered merging base model with checkpoint. {e}"
)
sys.exit(INTERNAL_ERROR_EXIT_CODE)
else:
try:
Expand All @@ -161,7 +169,7 @@ def main():
except Exception as e: # pylint: disable=broad-except
logging.error(traceback.format_exc())
write_termination_log(
"Exception encountered writing output model to storage"
f"Exception encountered writing output model to storage: {e}"
)
sys.exit(INTERNAL_ERROR_EXIT_CODE)

Expand All @@ -175,7 +183,9 @@ def main():
shutil.copy(train_logs_filepath, original_output_dir)
except Exception as e: # pylint: disable=broad-except
logging.error(traceback.format_exc())
write_termination_log("Exception encountered in capturing training logs")
write_termination_log(
f"Exception encountered in capturing training logs: {e}"
)
sys.exit(INTERNAL_ERROR_EXIT_CODE)

return 0
Expand Down
5 changes: 3 additions & 2 deletions build/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
INTERNAL_ERROR_EXIT_CODE = 203


def write_termination_log(text, log_file="/dev/termination-log"):
def write_termination_log(text):
log_file = os.environ.get("TERMINATION_LOG_FILE", "/dev/termination-log")
try:
with open(log_file, "a", encoding="utf-8") as handle:
handle.write(text)
Expand Down Expand Up @@ -220,7 +221,7 @@ def process_accelerate_launch_args(job_config_dict):
)

# Add training_script
script = os.environ.get("LAUNCH_TRAINING_SCRIPT") or "/app/launch_training.py"
script = os.environ.get("LAUNCH_TRAINING_SCRIPT", "/app/launch_training.py")
accelerate_launch_args.append(script)

logging.debug("accelerate_launch_args: %s", accelerate_launch_args)
Expand Down
116 changes: 71 additions & 45 deletions tests/build/test_launch_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,87 +60,113 @@
}


def test_successful_pt():
"""Check if we can bootstrap and peft tune causallm models"""
def serialize_args(args_json):
    """Return *args_json* pickled, base64-encoded and decoded to an ASCII string.

    The tests in this file store the result in the
    ``SFT_TRAINER_CONFIG_JSON_ENV_VAR`` environment variable.
    """
    pickled = pickle.dumps(args_json)
    return base64.b64encode(pickled).decode("ascii")


def setup_env(tempdir):
    """Configure the environment variables the launch scripts read during a test.

    Sets ``LAUNCH_TRAINING_SCRIPT`` to ``SCRIPT``, makes sure ``./`` is on
    ``PYTHONPATH`` and points ``TERMINATION_LOG_FILE`` at
    ``<tempdir>/termination-log``.

    Args:
        tempdir: Directory in which the termination log should be created.
    """
    os.environ["LAUNCH_TRAINING_SCRIPT"] = SCRIPT

    # os.environ stores strings verbatim: a literal "$PYTHONPATH" is never
    # expanded the way a shell would, so build the value from the real variable.
    search_path = [
        entry
        for entry in os.environ.get("PYTHONPATH", "").split(os.pathsep)
        if entry
    ]
    if "./" not in search_path:
        # Guarded so that calling this from several tests does not stack
        # duplicate entries.
        search_path.insert(0, "./")
    os.environ["PYTHONPATH"] = os.pathsep.join(search_path)

    # Same spelling the assertions in this file use to locate the log.
    os.environ["TERMINATION_LOG_FILE"] = tempdir + "/termination-log"


def cleanup_env():
    """Remove the environment variables that ``setup_env`` assigns.

    Variables that are not currently set are ignored.
    """
    for variable in ("LAUNCH_TRAINING_SCRIPT", "PYTHONPATH", "TERMINATION_LOG_FILE"):
        os.environ.pop(variable, None)


def test_successful_pt():
"""Check if we can bootstrap and peft tune causallm models"""
with tempfile.TemporaryDirectory() as tempdir:
setup_env(tempdir)
TRAIN_KWARGS = {**BASE_PEFT_KWARGS, **{"output_dir": tempdir}}
message_bytes = pickle.dumps(TRAIN_KWARGS)
base64_bytes = base64.b64encode(message_bytes)
serialized_args = base64_bytes.decode("ascii")

serialized_args = serialize_args(TRAIN_KWARGS)
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args

assert main() == 0
assert os.path.exists(tempdir + "/termination-log") is False


def test_bad_script_path():
"""Check if we can bootstrap and peft tune causallm models"""
os.environ["LAUNCH_TRAINING_SCRIPT"] = "invalid"
os.environ["PYTHONPATH"] = "./:$PYTHONPATH"
"""Check for appropriate error for an invalid training script location"""
with tempfile.TemporaryDirectory() as tempdir:
setup_env(tempdir)
TRAIN_KWARGS = {**BASE_PEFT_KWARGS, **{"output_dir": tempdir}}
message_bytes = pickle.dumps(TRAIN_KWARGS)
base64_bytes = base64.b64encode(message_bytes)
serialized_args = base64_bytes.decode("ascii")

serialized_args = serialize_args(TRAIN_KWARGS)
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args
os.environ["LAUNCH_TRAINING_SCRIPT"] = "/not/here"

with pytest.raises(SystemExit) as pytest_wrapped_e:
main()
assert pytest_wrapped_e.type == SystemExit
assert pytest_wrapped_e.value.code == INTERNAL_ERROR_EXIT_CODE
with pytest.raises(SystemExit) as pytest_wrapped_e:
main()
assert pytest_wrapped_e.type == SystemExit
assert pytest_wrapped_e.value.code == INTERNAL_ERROR_EXIT_CODE
assert os.stat(tempdir + "/termination-log").st_size > 0


def test_blank_env_var():
os.environ["LAUNCH_TRAINING_SCRIPT"] = SCRIPT
os.environ["PYTHONPATH"] = "./:$PYTHONPATH"
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = ""
with pytest.raises(SystemExit) as pytest_wrapped_e:
main()
assert pytest_wrapped_e.type == SystemExit
assert pytest_wrapped_e.value.code == USER_ERROR_EXIT_CODE
with tempfile.TemporaryDirectory() as tempdir:
setup_env(tempdir)
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = ""
with pytest.raises(SystemExit) as pytest_wrapped_e:
main()
assert pytest_wrapped_e.type == SystemExit
assert pytest_wrapped_e.value.code == USER_ERROR_EXIT_CODE
assert os.stat(tempdir + "/termination-log").st_size > 0


def test_faulty_file_path():
os.environ["LAUNCH_TRAINING_SCRIPT"] = SCRIPT
os.environ["PYTHONPATH"] = "./:$PYTHONPATH"
with tempfile.TemporaryDirectory() as tempdir:
setup_env(tempdir)
faulty_path = os.path.join(tempdir, "non_existent_file.pkl")
TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"training_data_path": faulty_path, "output_dir": tempdir},
}
message_bytes = pickle.dumps(TRAIN_KWARGS)
base64_bytes = base64.b64encode(message_bytes)
serialized_args = base64_bytes.decode("ascii")
serialized_args = serialize_args(TRAIN_KWARGS)
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args
with pytest.raises(SystemExit) as pytest_wrapped_e:
main()
assert pytest_wrapped_e.type == SystemExit
assert pytest_wrapped_e.value.code == USER_ERROR_EXIT_CODE
assert os.stat(tempdir + "/termination-log").st_size > 0


def test_config_parsing_error():
os.environ["LAUNCH_TRAINING_SCRIPT"] = SCRIPT
os.environ["PYTHONPATH"] = "./:$PYTHONPATH"
def test_bad_base_model_path():
    """Check that an invalid base model path exits with the user-error code.

    Also checks that something was written to the termination log.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        setup_env(tempdir)
        train_kwargs = {**BASE_PEFT_KWARGS, "model_name_or_path": "/wrong/path"}
        os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialize_args(train_kwargs)

        with pytest.raises(SystemExit) as exit_info:
            main()

        assert exit_info.type == SystemExit
        assert exit_info.value.code == USER_ERROR_EXIT_CODE
        # Checked inside the context manager: the log lives in tempdir, which
        # is deleted as soon as the with-block exits.
        termination_log = tempdir + "/termination-log"
        assert os.stat(termination_log).st_size > 0

TRAIN_KWARGS = {
**BASE_PEFT_KWARGS,
**{"num_train_epochs": "five"},
} # Intentional type error
message_bytes = pickle.dumps(TRAIN_KWARGS)
base64_bytes = base64.b64encode(message_bytes)
serialized_args = base64_bytes.decode("ascii")
os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialized_args
with pytest.raises(SystemExit) as pytest_wrapped_e:
main()
assert pytest_wrapped_e.type == SystemExit
assert pytest_wrapped_e.value.code == USER_ERROR_EXIT_CODE

def test_config_parsing_error():
    """Check that a wrongly typed training argument exits with the user-error code.

    Also checks that something was written to the termination log.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        setup_env(tempdir)
        # Intentional type error: the epoch count is given as a word.
        train_kwargs = {**BASE_PEFT_KWARGS, "num_train_epochs": "five"}
        os.environ["SFT_TRAINER_CONFIG_JSON_ENV_VAR"] = serialize_args(train_kwargs)

        with pytest.raises(SystemExit) as exit_info:
            main()

        assert exit_info.type == SystemExit
        assert exit_info.value.code == USER_ERROR_EXIT_CODE
        # Checked inside the context manager: the log lives in tempdir, which
        # is deleted as soon as the with-block exits.
        termination_log = tempdir + "/termination-log"
        assert os.stat(termination_log).st_size > 0


def test_cleanup():
# This runs to unset env variables that could disrupt other tests
os.environ.pop("LAUNCH_TRAINING_SCRIPT", None)
cleanup_env()
assert True

0 comments on commit cd260a3

Please sign in to comment.