diff --git a/open_lm/main.py b/open_lm/main.py index 62a6fa50..f31c1224 100644 --- a/open_lm/main.py +++ b/open_lm/main.py @@ -59,6 +59,7 @@ check_exists, start_sync_process, remote_sync_with_expon_backoff, + get_metadata_file, get_string_for_epoch, log_num_checkpoints, terminate_sync_process, @@ -203,6 +204,7 @@ def save_checkpoint( next_shard_per_source=None, samples_seen=None, shard_shuffle_seed=None, + train_data_string=None, averagers=None, ): cpu_state, optim_state = None, None @@ -246,6 +248,20 @@ def save_checkpoint( "is_final_checkpoint": is_final_checkpoint, "evaluation_metrics": evaluation_metrics, } + if next_shard_per_source is not None: + checkpoint_dict_stats["next_shard_per_source"] = next_shard_per_source + + if samples_seen is not None: + checkpoint_dict_stats["samples_seen"] = samples_seen + + if step is not None: + checkpoint_dict_stats["step"] = step + + if shard_shuffle_seed is not None: + checkpoint_dict_stats["shard_shuffle_seed"] = shard_shuffle_seed + + if train_data_string is not None: + checkpoint_dict_stats["train_data_string"] = train_data_string prefixes = { "epoch_": checkpoint_dict_model, @@ -752,6 +768,7 @@ def main(args): # Only enter training loop if there are steps to be done. done_training = global_step >= total_steps epoch = start_epoch + num_ckpt_too_few_tokens = 0 while not done_training: if is_master(args): logging.info(f"Start epoch {epoch}") @@ -823,6 +840,15 @@ def main(args): logging.info("Training exiting due to NaN value") break + expected_steps = data["train"].dataloader.num_batches + if steps_done_epoch < (1 - args.data_tolerate_error_p) * expected_steps and not done_training: + num_ckpt_too_few_tokens += 1 + + if num_ckpt_too_few_tokens > args.data_tolerate_num_ckpts: + raise RuntimeError( + f"{num_ckpt_too_few_tokens} checkpoints happened where the number of tokens seen was less than {1 - args.data_tolerate_error_p} of expected. This is likely due to transient errors e.g. reading from S3." + ) + epoch = epoch + 1 evaluation_metrics = [] if "val_list" in data and (epoch % args.val_frequency == 0 or done_training): @@ -840,6 +866,29 @@ def main(args): logging.error(traceback.format_exc()) logging.warning("evaluation failed! continuing to save_checkpoint") + if is_master(args): + end_of_epoch_log = { + "epoch": epoch, + "tokens": (global_step + 1) * args.global_batch_size * args.seq_len, + "checkpoints_too_few_tokens": num_ckpt_too_few_tokens, + "percentage_of_data_seen": steps_done_epoch / expected_steps, + } + + if args.dataset_manifest is not None: + for i in range(len(next_shard_per_source)): + end_of_epoch_log[f"next_shard_{i}"] = next_shard_per_source[i] + end_of_epoch_log[f"dataset_pass_{i}"] = next_shard_per_source[i] // len( + get_metadata_file(args.dataset_manifest[i]) + ) + + for name, val in end_of_epoch_log.items(): + name = "train/" + name + if writer is not None: + writer.add_scalar(name, val, global_step) + if args.wandb: + assert wandb is not None, "Please install wandb." + wandb.log({name: val, "step": global_step, "tokens": end_of_epoch_log["tokens"]}) + # Saving checkpoints. save_checkpoint( args, @@ -853,6 +902,7 @@ def main(args): next_shard_per_source=next_shard_per_source if args.dataset_manifest is not None else None, samples_seen=samples_seen if args.dataset_manifest is not None else None, shard_shuffle_seed=args.shard_shuffle_seed, + train_data_string=train_data_string_per_source if args.dataset_manifest is not None else None, averagers=averagers, ) diff --git a/open_lm/params.py b/open_lm/params.py index 4cad8e19..0a7a3f64 100644 --- a/open_lm/params.py +++ b/open_lm/params.py @@ -775,6 +775,19 @@ def parse_args(args): default=0, help="Whether to log the average model training loss. if not 0, it will log the average loss over the specified number of steps.", ) + parser.add_argument( + "--data-tolerate-error-p", + type=float, + default=0.09, # Roughly the number required to not repeat more than 10% of data. + help="This is the percentage of expected tokens above which the checkpoint is considered failed because of not having seen enough data.", + ) + parser.add_argument( + "--data-tolerate-num-ckpts", + type=int, + default=0, + help="This is the maximum number of failed checkpoints (due to not having seen enough tokens) that are allowed", + ) + add_model_args(parser) config = maybe_load_config(parser, args)