From f36020f488546eec35d89525bef13fe601684163 Mon Sep 17 00:00:00 2001
From: clemsgrs
Date: Thu, 24 Oct 2024 06:32:02 +0200
Subject: [PATCH] moved everything from epoch to iteration + fix loading official pretrained weights

---
 dinov2/configs/ssl_default_config.yaml | 17 ++++----
 dinov2/eval/utils.py                   | 16 ++++---
 dinov2/models/__init__.py              |  2 +-
 dinov2/train/ssl_meta_arch.py          |  3 +-
 dinov2/train/train.py                  | 59 +++++++++++++++-----------
 dinov2/utils/config.py                 |  1 +
 6 files changed, 55 insertions(+), 43 deletions(-)

diff --git a/dinov2/configs/ssl_default_config.yaml b/dinov2/configs/ssl_default_config.yaml
index 6ce2f39cd..6d3356ef9 100644
--- a/dinov2/configs/ssl_default_config.yaml
+++ b/dinov2/configs/ssl_default_config.yaml
@@ -61,23 +61,22 @@ ibot:
 train:
   batch_size_per_gpu: 64
   dataset_path: ImageNet:split=TRAIN
-  output_dir: .
+  output_dir: "output"
   seed: 0
   num_workers: 8
   cache_dataset: true
   centering: "centering" # or "sinkhorn_knopp"
-  save_frequency: 0.1 # save every 10% of an epoch
-  save_every: 5
+  save_frequency: 0.1 # save every x% of an epoch
 tune:
-  tune_every:
+  tune_every: # run tuning every x epochs, leave empty to disable tuning
   query_dataset_path:
   test_dataset_path:
   early_stopping:
     enable: false
     tracking: "auc_20"
     min_max: "max"
-    patience: 20
-    min_epoch: 30
+    patience_pct: 0.2 # stop after x% of total epochs without improvement
+    min_epoch_pct: 0.3 # minimum % of total epochs to run
   knn:
     nb_knn: [10, 20, 100, 200]
     temperature: 0.07
@@ -105,7 +104,7 @@ teacher:
   final_momentum_teacher: 1
   warmup_teacher_temp: 0.04
   teacher_temp: 0.07
-  warmup_teacher_temp_epochs: 30
+  warmup_teacher_temp_pct: 0.3 # warmup for x% of total epochs
 optim:
   epochs: 100
   max_iter:
@@ -113,10 +112,10 @@ optim:
   weight_decay_end: 0.4
   base_lr: 0.004 # learning rate for a batch size of 1024
   lr: 0. # will be set after applying scaling rule
-  warmup_epochs: 10
+  warmup_pct: 0.1 # warmup for x% of total epochs
   min_lr: 1.0e-06
   clip_grad: 3.0
-  freeze_last_layer_epochs: 1
+  freeze_last_layer_pct: 0.01
   scaling_rule: sqrt_wrt_1024
   patch_embed_lr_mult: 0.2
   layerwise_decay: 0.9
diff --git a/dinov2/eval/utils.py b/dinov2/eval/utils.py
index eb5f8ba56..5b6dd449a 100644
--- a/dinov2/eval/utils.py
+++ b/dinov2/eval/utils.py
@@ -170,21 +170,25 @@
     def __init__(
         self,
         tracking: str,
         min_max: str,
-        patience: int = 20,
-        min_epoch: int = 50,
+        nepochs: int,
+        patience_pct: float = 0.2,
+        min_epoch_pct: float = 0.3,
         checkpoint_dir: Optional[Path] = None,
         verbose: bool = False,
     ):
         """
         Args:
-            patience (int): How long to wait after last time validation loss improved.
-            min_epoch (int): Earliest epoch possible for stopping
+            tracking (str): Metric to track for early stopping
+            min_max (str): Whether to minimize or maximize the tracking metric
+            nepochs (int): Total number of training epochs
+            patience_pct (float): Fraction of total epochs to wait without improvement before stopping
+            min_epoch_pct (float): Fraction of total epochs to run before early stopping is enabled
             verbose (bool): If True, prints a message for each validation loss improvement
         """
         self.tracking = tracking
         self.min_max = min_max
-        self.patience = patience
-        self.min_epoch = min_epoch
+        self.patience = int(round(patience_pct * nepochs, 0))
+        self.min_epoch = int(round(min_epoch_pct * nepochs, 0))
         self.checkpoint_dir = checkpoint_dir
         self.verbose = verbose
diff --git a/dinov2/models/__init__.py b/dinov2/models/__init__.py
index 82316c145..fddf20b7f 100644
--- a/dinov2/models/__init__.py
+++ b/dinov2/models/__init__.py
@@ -24,7 +24,7 @@ def update_state_dict(model_dict, state_dict):
         if v.size() != model_dict[k].size():
             updated_state_dict[k] = model_dict[k]
             failure += 1
-            print(f"{k} | ckpt size: {v.size()} | model size: {model_dict[k].size()}")
+            logger.info(f"{k} | ckpt size: {v.size()} | model size: {model_dict[k].size()}")
         else:
             updated_state_dict[k] = v
             success += 1
diff --git a/dinov2/train/ssl_meta_arch.py b/dinov2/train/ssl_meta_arch.py
index dbb8da553..27f9f6542 100644
--- a/dinov2/train/ssl_meta_arch.py
+++ b/dinov2/train/ssl_meta_arch.py
@@ -45,8 +45,7 @@ def __init__(self, cfg):
         if cfg.student.pretrained_weights:
             chkpt = torch.load(cfg.student.pretrained_weights)
             logger.info(f"OPTIONS -- pretrained weights: loading from {cfg.student.pretrained_weights}")
-            sd = chkpt["model"]
-            sd, msg = update_state_dict(student_backbone.state_dict(), sd)
+            sd, msg = update_state_dict(student_backbone.state_dict(), chkpt)
             logger.info(f"pretrained weights loaded: {msg}")
             student_backbone.load_state_dict(sd, strict=False)
 
diff --git a/dinov2/train/train.py b/dinov2/train/train.py
index 9c00d96c1..67b8b81f9 100644
--- a/dinov2/train/train.py
+++ b/dinov2/train/train.py
@@ -57,8 +57,6 @@ def get_args_parser(add_help: bool = True):
     )
     parser.add_argument(
         "--output-dir",
-        "--output_dir",
-        default="output",
         type=str,
         help="Output directory to save logs and checkpoints",
     )
@@ -75,7 +73,7 @@ def build_schedulers(cfg, OFFICIAL_EPOCH_LENGTH):
         base_value=cfg.optim["lr"],
         final_value=cfg.optim["min_lr"],
         total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
-        warmup_iters=cfg.optim["warmup_epochs"] * OFFICIAL_EPOCH_LENGTH,
+        warmup_iters=int(round(cfg.optim["warmup_pct"] * cfg.optim["epochs"], 0)) * OFFICIAL_EPOCH_LENGTH,
         start_warmup_value=0,
     )
     wd = dict(
@@ -91,8 +89,9 @@ def build_schedulers(cfg, OFFICIAL_EPOCH_LENGTH):
     teacher_temp = dict(
         base_value=cfg.teacher["teacher_temp"],
         final_value=cfg.teacher["teacher_temp"],
-        total_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH,
-        warmup_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH,
+        total_iters=int(round(cfg.teacher["warmup_teacher_temp_pct"] * cfg.optim["epochs"], 0)) * OFFICIAL_EPOCH_LENGTH,
+        warmup_iters=int(round(cfg.teacher["warmup_teacher_temp_pct"] * cfg.optim["epochs"], 0))
+        * OFFICIAL_EPOCH_LENGTH,
         start_warmup_value=cfg.teacher["warmup_teacher_temp"],
     )
 
@@ -103,7 +102,7 @@
     last_layer_lr_schedule = CosineScheduler(**lr)
     last_layer_lr_schedule.schedule[
-        : cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH
+        : int(round(cfg.optim["freeze_last_layer_pct"] * cfg.optim["epochs"], 0)) * OFFICIAL_EPOCH_LENGTH
     ] = 0  # mimicking the original schedules
 
     logger.info("Schedulers ready.")
 
@@ -149,7 +148,7 @@ def save_checkpoint(cfg, model, iteration):
 
 def do_tune(
     cfg,
-    epoch,
+    iteration: int,
     model: torch.nn.Module,
     query_dataset,
     test_dataset,
@@ -184,7 +183,7 @@ def do_tune(
     # student = student.to(torch.device(f"cuda:{distributed.get_global_rank()}"))
     # teacher = teacher.to(torch.device(f"cuda:{distributed.get_global_rank()}"))
     if verbose:
-        tqdm.tqdm.write(f"Loading epoch {epoch} weights...")
+        tqdm.tqdm.write(f"Loading iteration {iteration} weights...")
     student_weights = model.student.state_dict()
     teacher_weights = model.teacher.state_dict()
     student_msg = load_weights(student, student_weights)
@@ -325,7 +324,7 @@ def do_train(cfg, model, resume=False):
 
     total_batch_size = cfg.train.batch_size_per_gpu * distributed.get_global_size()
    OFFICIAL_EPOCH_LENGTH = len(dataset) // total_batch_size
-    save_every = int(cfg.train.save_frequency * OFFICIAL_EPOCH_LENGTH)
+    save_every = int(round(cfg.train.save_frequency * OFFICIAL_EPOCH_LENGTH, 0))
     if cfg.optim.max_iter is not None:
         max_iter = cfg.optim.max_iter
     else:
@@ -333,7 +332,7 @@ def do_train(cfg, model, resume=False):
 
     periodic_checkpointer = PeriodicCheckpointer(
         checkpointer,
-        period=cfg.train.save_every * OFFICIAL_EPOCH_LENGTH,
+        period=save_every,
         max_iter=max_iter,
         max_to_keep=3,
     )
@@ -369,8 +368,9 @@ def do_train(cfg, model, resume=False):
     early_stopper = EarlyStoppingDINO(
         cfg.tune.early_stopping.tracking,
         cfg.tune.early_stopping.min_max,
-        cfg.tune.early_stopping.patience,
-        cfg.tune.early_stopping.min_epoch,
+        cfg.optim.epochs,
+        cfg.tune.early_stopping.patience_pct,
+        cfg.tune.early_stopping.min_epoch_pct,
         checkpoint_dir=checkpoint_save_dir,
         verbose=True,
     )
@@ -463,9 +463,19 @@ def do_train(cfg, model, resume=False):
         metric_logger.update(current_batch_size=current_batch_size)
         metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced)
 
+        # logging
+        if distributed.is_main_process() and cfg.wandb.enable:
+            log_dict = {"iteration": iteration}
+            update_log_dict(log_dict, f"{header.lower()}/lr", lr, step="iteration")
+            update_log_dict(log_dict, f"{header.lower()}/wd", wd, step="iteration")
+            update_log_dict(log_dict, f"{header.lower()}/loss", losses_reduced, step="iteration")
+            for loss_name, loss_value in loss_dict.items():
+                update_log_dict(log_dict, f"{header.lower()}/{loss_name}", loss_value, step="iteration")
+            wandb.log(log_dict, step=iteration)
+
         epoch = iteration // OFFICIAL_EPOCH_LENGTH
 
-        # log at the end of each epoch
+        # additional logging at the end of each epoch
         if iteration % OFFICIAL_EPOCH_LENGTH == 0:
             if distributed.is_main_process() and cfg.wandb.enable:
                 # log the total loss and each individual loss to wandb
@@ -482,7 +492,7 @@ def do_train(cfg, model, resume=False):
             if cfg.tune.tune_every and epoch % cfg.tune.tune_every == 0:
                 tune_results = do_tune(
                     cfg,
-                    epoch + 1,
+                    iteration,
                     model,
                     query_dataset,
                     test_dataset,
@@ -499,17 +509,9 @@ def do_train(cfg, model, resume=False):
                 if early_stopper.early_stop and cfg.tune.early_stopping.enable:
                     stop = True
 
-        if stop:
-            if distributed.is_main_process():
-                tqdm.tqdm.write(
-                    f"Stopping early because best {cfg.tune.early_stopping.tracking} was reached {cfg.tune.early_stopping.patience} epochs ago"
-                )
-            break
-
-        # log to wandb
-
-        if distributed.is_main_process() and cfg.wandb.enable and iteration % OFFICIAL_EPOCH_LENGTH == 0:
-            wandb.log(log_dict, step=epoch)
+            # log to wandb
+            if distributed.is_main_process() and cfg.wandb.enable:
+                wandb.log(log_dict, step=epoch)
 
         # checkpointing and testing
 
@@ -521,6 +523,13 @@ def do_train(cfg, model, resume=False):
 
         iteration = iteration + 1
 
+        if stop:
+            if distributed.is_main_process():
+                tqdm.tqdm.write(
+                    f"Stopping early because best {cfg.tune.early_stopping.tracking} was reached {early_stopper.patience} epochs ago"
+                )
+            break
+
     # gather stats from all processes
     metric_logger.synchronize_between_processes()
 
diff --git a/dinov2/utils/config.py b/dinov2/utils/config.py
index 69cc124b9..ee205ee3d 100644
--- a/dinov2/utils/config.py
+++ b/dinov2/utils/config.py
@@ -55,6 +55,7 @@ def default_setup(args, cfg):
             key = os.environ.get("WANDB_API_KEY")
             wandb_run = utils.initialize_wandb(cfg, key=key)
             wandb_run.define_metric("epoch", summary="max")
+            wandb_run.define_metric("iteration", summary="max")
             run_id = wandb_run.id
         else:
            run_id = ""