Commit f36020f

moved everything from epoch to iteration + fix loading official pretrained weights
clemsgrs committed Oct 24, 2024
1 parent c09982b commit f36020f
Showing 6 changed files with 55 additions and 43 deletions.
17 changes: 8 additions & 9 deletions dinov2/configs/ssl_default_config.yaml
@@ -61,23 +61,22 @@ ibot:
train:
batch_size_per_gpu: 64
dataset_path: ImageNet:split=TRAIN
- output_dir: .
+ output_dir: "output"
seed: 0
num_workers: 8
cache_dataset: true
centering: "centering" # or "sinkhorn_knopp"
- save_frequency: 0.1 # save every 10% of an epoch
- save_every: 5
+ save_frequency: 0.1 # save every x% of an epoch
tune:
- tune_every:
+ tune_every: # run tuning every x% of an epoch, leave empty to disable tuning
query_dataset_path:
test_dataset_path:
early_stopping:
enable: false
tracking: "auc_20"
min_max: "max"
- patience: 20
- min_epoch: 30
+ patience_pct: 0.2 # stop after x% of total epochs without improvement
+ min_epoch_pct: 0.3 # minimum % of total epochs to run
knn:
nb_knn: [10, 20, 100, 200]
temperature: 0.07
@@ -105,18 +104,18 @@ teacher:
final_momentum_teacher: 1
warmup_teacher_temp: 0.04
teacher_temp: 0.07
- warmup_teacher_temp_epochs: 30
+ warmup_teacher_temp_pct: 0.3 # warmup for x% of total epochs
optim:
epochs: 100
max_iter:
weight_decay: 0.04
weight_decay_end: 0.4
base_lr: 0.004 # learning rate for a batch size of 1024
lr: 0. # will be set after applying scaling rule
- warmup_epochs: 10
+ warmup_pct: 0.1 # percentage of iterations for warmup
min_lr: 1.0e-06
clip_grad: 3.0
- freeze_last_layer_epochs: 1
+ freeze_last_layer_pct: 0.01
scaling_rule: sqrt_wrt_1024
patch_embed_lr_mult: 0.2
layerwise_decay: 0.9
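For orientation (not part of the commit): with the default optim.epochs of 100, the new percentage-based settings reduce to the same values as the old epoch-based ones. A minimal sketch, assuming a hypothetical OFFICIAL_EPOCH_LENGTH of 1250 iterations:

epochs = 100                   # cfg.optim.epochs
official_epoch_length = 1250   # hypothetical: len(dataset) // total_batch_size

warmup_iters = int(round(0.1 * epochs)) * official_epoch_length    # warmup_pct: 0.1 -> old warmup_epochs: 10
patience = int(round(0.2 * epochs))                                # patience_pct: 0.2 -> old patience: 20
min_epoch = int(round(0.3 * epochs))                               # min_epoch_pct: 0.3 -> old min_epoch: 30
freeze_iters = int(round(0.01 * epochs)) * official_epoch_length   # freeze_last_layer_pct: 0.01 -> old freeze_last_layer_epochs: 1
print(warmup_iters, patience, min_epoch, freeze_iters)             # 12500 20 30 1250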
16 changes: 10 additions & 6 deletions dinov2/eval/utils.py
@@ -170,21 +170,25 @@ def __init__(
self,
tracking: str,
min_max: str,
- patience: int = 20,
- min_epoch: int = 50,
+ nepochs: int,
+ patience_pct: float = 0.2,
+ min_epoch_pct: float = 0.3,
checkpoint_dir: Optional[Path] = None,
verbose: bool = False,
):
"""
Args:
- patience (int): How long to wait after last time validation loss improved.
- min_epoch (int): Earliest epoch possible for stopping
+ tracking (str): Metric to track for early stopping
+ min_max (str): Whether to minimize or maximize the tracking metric
+ nepochs (int): Total number of epochs
+ patience_pct (float): Percentage of total epochs to wait without improvement before stopping early
+ min_epoch_pct (float): Percentage of total epochs to run before early stopping is enabled
verbose (bool): If True, prints a message for each validation loss improvement
"""
self.tracking = tracking
self.min_max = min_max
- self.patience = patience
- self.min_epoch = min_epoch
+ self.patience = int(round(patience_pct * nepochs, 0))
+ self.min_epoch = int(round(min_epoch_pct * nepochs, 0))
self.checkpoint_dir = checkpoint_dir
self.verbose = verbose

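Only __init__ changes here; the rest of EarlyStoppingDINO sits outside this diff. As a rough sketch of the pattern the new arguments feed into (the update method name and body below are assumptions, not the class's actual code): the tracker counts epochs without improvement and only raises the stop flag once both min_epoch has passed and patience is exhausted.

class EarlyStopSketch:
    # Simplified stand-in for the tracking logic; names and behavior are assumptions, not the repo's code.
    def __init__(self, nepochs, patience_pct=0.2, min_epoch_pct=0.3, min_max="max"):
        self.patience = int(round(patience_pct * nepochs))    # e.g. 0.2 * 100 -> 20 epochs
        self.min_epoch = int(round(min_epoch_pct * nepochs))  # e.g. 0.3 * 100 -> 30 epochs
        self.sign = 1 if min_max == "max" else -1
        self.best, self.counter, self.early_stop = None, 0, False

    def update(self, epoch, metric):
        if self.best is None or self.sign * (metric - self.best) > 0:
            self.best, self.counter = metric, 0
        else:
            self.counter += 1
        # stopping is only allowed once the minimum-epoch budget has been spent
        if epoch >= self.min_epoch and self.counter >= self.patience:
            self.early_stop = True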
2 changes: 1 addition & 1 deletion dinov2/models/__init__.py
@@ -24,7 +24,7 @@ def update_state_dict(model_dict, state_dict):
if v.size() != model_dict[k].size():
updated_state_dict[k] = model_dict[k]
failure += 1
print(f"{k} | ckpt size: {v.size()} | model size: {model_dict[k].size()}")
logger.info(f"{k} | ckpt size: {v.size()} | model size: {model_dict[k].size()}")
else:
updated_state_dict[k] = v
success += 1
3 changes: 1 addition & 2 deletions dinov2/train/ssl_meta_arch.py
@@ -45,8 +45,7 @@ def __init__(self, cfg):
if cfg.student.pretrained_weights:
chkpt = torch.load(cfg.student.pretrained_weights)
logger.info(f"OPTIONS -- pretrained weights: loading from {cfg.student.pretrained_weights}")
- sd = chkpt["model"]
- sd, msg = update_state_dict(student_backbone.state_dict(), sd)
+ sd, msg = update_state_dict(student_backbone.state_dict(), chkpt)
logger.info(f"pretrained weights loaded: {msg}")
student_backbone.load_state_dict(sd, strict=False)

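The removed line assumed checkpoints are wrapped as {"model": state_dict}, which holds for checkpoints written by this training loop but not for the official DINOv2 release weights, which ship as a bare state dict; passing the raw checkpoint to update_state_dict handles the latter, which appears to be the fix the commit message refers to. A small hedged sketch of a loader tolerating both layouts (the helper name and logic are illustrative, not the repo's code):

import torch

def load_backbone_weights(backbone, path):
    # Sketch: accept either a bare state dict (official release) or {"model": state_dict} (repo checkpoints).
    chkpt = torch.load(path, map_location="cpu")
    state_dict = chkpt.get("model", chkpt) if isinstance(chkpt, dict) else chkpt
    missing, unexpected = backbone.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(missing)} | unexpected keys: {len(unexpected)}")
    return backbone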
59 changes: 34 additions & 25 deletions dinov2/train/train.py
@@ -57,8 +57,6 @@ def get_args_parser(add_help: bool = True):
)
parser.add_argument(
"--output-dir",
"--output_dir",
default="output",
type=str,
help="Output directory to save logs and checkpoints",
)
@@ -75,7 +73,7 @@ def build_schedulers(cfg, OFFICIAL_EPOCH_LENGTH):
base_value=cfg.optim["lr"],
final_value=cfg.optim["min_lr"],
total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH,
- warmup_iters=cfg.optim["warmup_epochs"] * OFFICIAL_EPOCH_LENGTH,
+ warmup_iters=int(round(cfg.optim["warmup_pct"] * cfg.optim["epochs"], 0)) * OFFICIAL_EPOCH_LENGTH,
start_warmup_value=0,
)
wd = dict(
@@ -91,8 +89,9 @@ def build_schedulers(cfg, OFFICIAL_EPOCH_LENGTH):
teacher_temp = dict(
base_value=cfg.teacher["teacher_temp"],
final_value=cfg.teacher["teacher_temp"],
- total_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH,
- warmup_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH,
+ total_iters=int(round(cfg.teacher["warmup_teacher_temp_pct"] * cfg.optim["epochs"], 0)) * OFFICIAL_EPOCH_LENGTH,
+ warmup_iters=int(round(cfg.teacher["warmup_teacher_temp_pct"] * cfg.optim["epochs"], 0))
+ * OFFICIAL_EPOCH_LENGTH,
start_warmup_value=cfg.teacher["warmup_teacher_temp"],
)

@@ -103,7 +102,7 @@ def build_schedulers(cfg, OFFICIAL_EPOCH_LENGTH):
last_layer_lr_schedule = CosineScheduler(**lr)

last_layer_lr_schedule.schedule[
- : cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH
+ : int(round(cfg.optim["freeze_last_layer_pct"] * cfg.optim["epochs"], 0)) * OFFICIAL_EPOCH_LENGTH
] = 0 # mimicking the original schedules

logger.info("Schedulers ready.")
@@ -149,7 +148,7 @@ def save_checkpoint(cfg, model, iteration):

def do_tune(
cfg,
- epoch,
+ iteration: int,
model: torch.nn.Module,
query_dataset,
test_dataset,
@@ -184,7 +183,7 @@ def do_tune(
# student = student.to(torch.device(f"cuda:{distributed.get_global_rank()}"))
# teacher = teacher.to(torch.device(f"cuda:{distributed.get_global_rank()}"))
if verbose:
tqdm.tqdm.write(f"Loading epoch {epoch} weights...")
tqdm.tqdm.write(f"Loading epoch {iteration} weights...")
student_weights = model.student.state_dict()
teacher_weights = model.teacher.state_dict()
student_msg = load_weights(student, student_weights)
@@ -325,15 +324,15 @@ def do_train(cfg, model, resume=False):

total_batch_size = cfg.train.batch_size_per_gpu * distributed.get_global_size()
OFFICIAL_EPOCH_LENGTH = len(dataset) // total_batch_size
- save_every = int(cfg.train.save_frequency * OFFICIAL_EPOCH_LENGTH)
+ save_every = int(round(cfg.train.save_frequency * OFFICIAL_EPOCH_LENGTH, 0))
if cfg.optim.max_iter is not None:
max_iter = cfg.optim.max_iter
else:
max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH

periodic_checkpointer = PeriodicCheckpointer(
checkpointer,
- period=cfg.train.save_every * OFFICIAL_EPOCH_LENGTH,
+ period=save_every,
max_iter=max_iter,
max_to_keep=3,
)
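The PeriodicCheckpointer period is now the precomputed save_every, which is already an iteration count, rather than cfg.train.save_every re-scaled by the epoch length. A quick worked example with hypothetical numbers:

OFFICIAL_EPOCH_LENGTH = 1250   # hypothetical: len(dataset) // total_batch_size
save_frequency = 0.1           # cfg.train.save_frequency
save_every = int(round(save_frequency * OFFICIAL_EPOCH_LENGTH, 0))
print(save_every)              # 125 -> a checkpoint every 10% of an epoch, keeping the 3 most recent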
@@ -369,8 +368,9 @@
early_stopper = EarlyStoppingDINO(
cfg.tune.early_stopping.tracking,
cfg.tune.early_stopping.min_max,
- cfg.tune.early_stopping.patience,
- cfg.tune.early_stopping.min_epoch,
+ cfg.optim.epochs,
+ cfg.tune.early_stopping.patience_pct,
+ cfg.tune.early_stopping.min_epoch_pct,
checkpoint_dir=checkpoint_save_dir,
verbose=True,
)
@@ -463,9 +463,19 @@
metric_logger.update(current_batch_size=current_batch_size)
metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced)

+ # logging
+ if distributed.is_main_process() and cfg.wandb.enable:
+     log_dict = {"iteration": iteration}
+     update_log_dict(log_dict, f"{header.lower()}/lr", lr, step="iteration")
+     update_log_dict(log_dict, f"{header.lower()}/wd", wd, step="iteration")
+     update_log_dict(log_dict, f"{header.lower()}/loss", losses_reduced, step="iteration")
+     for loss_name, loss_value in loss_dict.items():
+         update_log_dict(log_dict, f"{header.lower()}/{loss_name}", loss_value, step="iteration")
+     wandb.log(log_dict, step=iteration)

epoch = iteration // OFFICIAL_EPOCH_LENGTH

- # log at the end of each epoch
+ # additional logging at the end of each epoch
if iteration % OFFICIAL_EPOCH_LENGTH == 0:
if distributed.is_main_process() and cfg.wandb.enable:
# log the total loss and each individual loss to wandb
@@ -482,7 +492,7 @@
if cfg.tune.tune_every and epoch % cfg.tune.tune_every == 0:
tune_results = do_tune(
cfg,
- epoch + 1,
+ iteration,
model,
query_dataset,
test_dataset,
@@ -499,17 +509,9 @@
if early_stopper.early_stop and cfg.tune.early_stopping.enable:
stop = True

- if stop:
-     if distributed.is_main_process():
-         tqdm.tqdm.write(
-             f"Stopping early because best {cfg.tune.early_stopping.tracking} was reached {cfg.tune.early_stopping.patience} epochs ago"
-         )
-     break

- # log to wandb
-
- if distributed.is_main_process() and cfg.wandb.enable and iteration % OFFICIAL_EPOCH_LENGTH == 0:
-     wandb.log(log_dict, step=epoch)
+ # log to wandb
+ if distributed.is_main_process() and cfg.wandb.enable:
+     wandb.log(log_dict, step=epoch)

# checkpointing and testing

@@ -521,6 +523,13 @@

iteration = iteration + 1

+ if stop:
+     if distributed.is_main_process():
+         tqdm.tqdm.write(
+             f"Stopping early because best {cfg.tune.early_stopping.tracking} was reached {early_stopper.patience} epochs ago"
+         )
+     break

# gather stats from all processes

metric_logger.synchronize_between_processes()
1 change: 1 addition & 0 deletions dinov2/utils/config.py
@@ -55,6 +55,7 @@ def default_setup(args, cfg):
key = os.environ.get("WANDB_API_KEY")
wandb_run = utils.initialize_wandb(cfg, key=key)
wandb_run.define_metric("epoch", summary="max")
wandb_run.define_metric("iteration", summary="max")
run_id = wandb_run.id
else:
run_id = ""
