Improve error handling for s3 read errors. #273

Merged: 14 commits, May 14, 2024
50 changes: 50 additions & 0 deletions open_lm/main.py
@@ -59,6 +59,7 @@
    check_exists,
    start_sync_process,
    remote_sync_with_expon_backoff,
    get_metadata_file,
    get_string_for_epoch,
    log_num_checkpoints,
    terminate_sync_process,
@@ -203,6 +204,7 @@ def save_checkpoint(
    next_shard_per_source=None,
    samples_seen=None,
    shard_shuffle_seed=None,
    train_data_string=None,
    averagers=None,
):
    cpu_state, optim_state = None, None
@@ -246,6 +248,20 @@
"is_final_checkpoint": is_final_checkpoint,
"evaluation_metrics": evaluation_metrics,
}
if next_shard_per_source is not None:
checkpoint_dict_stats["next_shard_per_source"] = next_shard_per_source

if samples_seen is not None:
checkpoint_dict_stats["samples_seen"] = samples_seen

if step is not None:
checkpoint_dict_stats["step"] = step

if shard_shuffle_seed is not None:
checkpoint_dict_stats["shard_shuffle_seed"] = shard_shuffle_seed

if train_data_string is not None:
checkpoint_dict_stats["train_data_string"] = train_data_string

prefixes = {
"epoch_": checkpoint_dict_model,
@@ -752,6 +768,7 @@ def main(args):
    # Only enter training loop if there are steps to be done.
    done_training = global_step >= total_steps
    epoch = start_epoch
    num_ckpt_too_few_tokens = 0
    while not done_training:
        if is_master(args):
            logging.info(f"Start epoch {epoch}")
@@ -823,6 +840,15 @@
logging.info("Training exiting due to NaN value")
break

expected_steps = data["train"].dataloader.num_batches
if steps_done_epoch < (1 - args.data_tolerate_error_p) * expected_steps and not done_training:
num_ckpt_too_few_tokens += 1

if num_ckpt_too_few_tokens > args.data_tolerate_num_ckpts:
raise RuntimeError(
f"{num_ckpt_too_few_tokens} checkpoints happened where the number of tokens seen was less than {1 - args.data_tolerate_error_p} of expected. This is likely due to transient errors e.g. reading from S3."
)
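To make the tolerance check concrete, a small worked example with illustrative numbers (none of these values come from the PR itself):

# Illustrative values only.
expected_steps = 1000
data_tolerate_error_p = 0.09  # default defined in params.py below
threshold = (1 - data_tolerate_error_p) * expected_steps  # 910.0
# An epoch that completed only 900 steps falls below the threshold, so
# num_ckpt_too_few_tokens becomes 1. With the default
# data_tolerate_num_ckpts = 0, the check 1 > 0 is True and training raises
# the RuntimeError immediately.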

        epoch = epoch + 1
        evaluation_metrics = []
        if "val_list" in data and (epoch % args.val_frequency == 0 or done_training):
@@ -840,6 +866,29 @@
                logging.error(traceback.format_exc())
                logging.warning("evaluation failed! continuing to save_checkpoint")

        if is_master(args):
            end_of_epoch_log = {
                "epoch": epoch,
                "tokens": (global_step + 1) * args.global_batch_size * args.seq_len,
                "checkpoints_too_few_tokens": num_ckpt_too_few_tokens,
                "percentage_of_data_seen": steps_done_epoch / expected_steps,
            }

            if args.dataset_manifest is not None:
                for i in range(len(next_shard_per_source)):
                    end_of_epoch_log[f"next_shard_{i}"] = next_shard_per_source[i]
                    end_of_epoch_log[f"dataset_pass_{i}"] = next_shard_per_source[i] // len(
                        get_metadata_file(args.dataset_manifest[i])
                    )

            for name, val in end_of_epoch_log.items():
                name = "train/" + name
                if writer is not None:
                    writer.add_scalar(name, val, global_step)
                if args.wandb:
                    assert wandb is not None, "Please install wandb."
                    wandb.log({name: val, "step": global_step, "tokens": end_of_epoch_log["tokens"]})
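The dataset_pass_{i} value logged above is integer division of the next shard index by the number of shards in that source's manifest; a short sketch of the same arithmetic with illustrative numbers:

# Illustrative: a source whose manifest lists 512 shards, with shard 1300 up next.
num_shards = 512   # stands in for len(get_metadata_file(args.dataset_manifest[i]))
next_shard = 1300
dataset_pass = next_shard // num_shards  # 2, i.e. the loader is on its third pass
# Token count logged above: (global_step + 1) * global_batch_size * seq_len,
# e.g. (999 + 1) * 256 * 2048 = 524,288,000 tokens.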

        # Saving checkpoints.
        save_checkpoint(
            args,
@@ -853,6 +902,7 @@
            next_shard_per_source=next_shard_per_source if args.dataset_manifest is not None else None,
            samples_seen=samples_seen if args.dataset_manifest is not None else None,
            shard_shuffle_seed=args.shard_shuffle_seed,
            train_data_string=train_data_string_per_source if args.dataset_manifest is not None else None,
            averagers=averagers,
        )

13 changes: 13 additions & 0 deletions open_lm/params.py
@@ -775,6 +775,19 @@ def parse_args(args):
        default=0,
        help="Whether to log the average model training loss. If not 0, the average loss is logged over the specified number of steps.",
    )
    parser.add_argument(
        "--data-tolerate-error-p",
        type=float,
        default=0.09,  # Roughly the value needed to avoid repeating more than 10% of the data.
        help="Fraction of expected data that may be missed before a checkpoint is considered failed; a checkpoint fails if it sees fewer than (1 - p) of the expected tokens.",
    )
    parser.add_argument(
        "--data-tolerate-num-ckpts",
        type=int,
        default=0,
        help="Maximum number of failed checkpoints (due to not having seen enough tokens) allowed before training raises an error.",
    )
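For illustration, the two new flags might be passed like this when launching training; the launcher, process count, manifest path, and tolerance values here are assumptions, not part of this diff:

torchrun --nproc-per-node 8 -m open_lm.main \
    --dataset-manifest s3://bucket/manifest.jsonl \
    --data-tolerate-error-p 0.09 \
    --data-tolerate-num-ckpts 2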

    add_model_args(parser)

    config = maybe_load_config(parser, args)