Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
Exponential Moving Average of Weights (EMA) (#138)
Browse files Browse the repository at this point in the history
Summary:
EMA as used in "Fast and Accurate Model Scaling" to improve accuracy.
Note that EMA of model weights is nearly free computationally (if not
computed every iter), hence EMA weights area always computed/stored.
Saving/loading checkpoints has been updated, but the code is backward
compatible with checkpoints that do not store the ema weights.

Details:
-config.py: added EMA options
-meters.py: generalized to allow for ema meter
-net.py: added update_model_ema() to compute model ema
-trainer.py: added updating/testing/logging of ema model
-checkpoint.py: save/load_checkpoint() also save/load ema weights

Pull Request resolved: #138

Reviewed By: theschnitz, vaibhava0

Differential Revision: D28204356

Pulled By: pdollar

fbshipit-source-id: 633969856c4359b663cb6811325c3489b74ccf2d
  • Loading branch information
pdollar authored and facebook-github-bot committed May 5, 2021
1 parent d0a090c commit 2c152a6
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 26 deletions.
56 changes: 47 additions & 9 deletions pycls/core/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def has_checkpoint():
return any(_NAME_PREFIX in f for f in pathmgr.ls(checkpoint_dir))


def save_checkpoint(model, optimizer, epoch, best):
"""Saves a checkpoint."""
def save_checkpoint(model, model_ema, optimizer, epoch, test_err, ema_err):
"""Saves a checkpoint and also the best weights so far in a best checkpoint."""
# Save checkpoints only from the master process
if not dist.is_master_proc():
return
Expand All @@ -65,29 +65,67 @@ def save_checkpoint(model, optimizer, epoch, best):
# Record the state
checkpoint = {
"epoch": epoch,
"test_err": test_err,
"ema_err": ema_err,
"model_state": unwrap_model(model).state_dict(),
"ema_state": unwrap_model(model_ema).state_dict(),
"optimizer_state": optimizer.state_dict(),
"cfg": cfg.dump(),
}
# Write the checkpoint
checkpoint_file = get_checkpoint(epoch + 1)
with pathmgr.open(checkpoint_file, "wb") as f:
torch.save(checkpoint, f)
# If best copy checkpoint to the best checkpoint
if best:
# Store the best model and model_ema weights so far
if not pathmgr.exists(get_checkpoint_best()):
pathmgr.copy(checkpoint_file, get_checkpoint_best())
else:
with pathmgr.open(get_checkpoint_best(), "rb") as f:
best = torch.load(f, map_location="cpu")
# Select the best model weights and the best model_ema weights
if test_err < best["test_err"] or ema_err < best["ema_err"]:
if test_err < best["test_err"]:
best["model_state"] = checkpoint["model_state"]
best["test_err"] = test_err
if ema_err < best["ema_err"]:
best["ema_state"] = checkpoint["ema_state"]
best["ema_err"] = ema_err
with pathmgr.open(get_checkpoint_best(), "wb") as f:
torch.save(best, f)
return checkpoint_file


def load_checkpoint(checkpoint_file, model, optimizer=None):
"""Loads the checkpoint from the given file."""
def load_checkpoint(checkpoint_file, model, model_ema=None, optimizer=None):
"""
Loads a checkpoint selectively based on the input options.
Each checkpoint contains both the model and model_ema weights (except checkpoints
created by old versions of the code). If both the model and model_weights are
requested, both sets of weights are loaded. If only the model weights are requested
(that is if model_ema=None), the *better* set of weights is selected to be loaded
(according to the lesser of test_err and ema_err, also stored in the checkpoint).
The code is backward compatible with checkpoints that do not store the ema weights.
"""
err_str = "Checkpoint '{}' not found"
assert pathmgr.exists(checkpoint_file), err_str.format(checkpoint_file)
with pathmgr.open(checkpoint_file, "rb") as f:
checkpoint = torch.load(f, map_location="cpu")
unwrap_model(model).load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optimizer_state"]) if optimizer else ()
return checkpoint["epoch"]
# Get test_err and ema_err (with backward compatibility)
test_err = checkpoint["test_err"] if "test_err" in checkpoint else 100
ema_err = checkpoint["ema_err"] if "ema_err" in checkpoint else 100
# Load model and optionally model_ema weights (with backward compatibility)
ema_state = "ema_state" if "ema_state" in checkpoint else "model_state"
if model_ema:
unwrap_model(model).load_state_dict(checkpoint["model_state"])
unwrap_model(model_ema).load_state_dict(checkpoint[ema_state])
else:
best_state = "model_state" if test_err <= ema_err else ema_state
unwrap_model(model).load_state_dict(checkpoint[best_state])
# Load optimizer if requested
if optimizer:
optimizer.load_state_dict(checkpoint["optimizer_state"])
return checkpoint["epoch"], test_err, ema_err


def delete_checkpoints(checkpoint_dir=None, keep="all"):
Expand Down
6 changes: 6 additions & 0 deletions pycls/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,12 @@
# Gradually warm up the OPTIM.BASE_LR over this number of epochs
_C.OPTIM.WARMUP_EPOCHS = 0

# Exponential Moving Average (EMA) update value
_C.OPTIM.EMA_ALPHA = 1e-5

# Iteration frequency with which to update EMA weights
_C.OPTIM.EMA_UPDATE_PERIOD = 32


# --------------------------------- Training options --------------------------------- #
_C.TRAIN = CfgNode()
Expand Down
14 changes: 8 additions & 6 deletions pycls/core/meters.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,10 @@ def get_global_avg(self):
class TrainMeter(object):
"""Measures training stats."""

def __init__(self, epoch_iters):
def __init__(self, epoch_iters, phase="train"):
self.epoch_iters = epoch_iters
self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters
self.phase = phase
self.iter_timer = Timer()
self.loss = ScalarMeter(cfg.LOG_PERIOD)
self.loss_total = 0.0
Expand Down Expand Up @@ -149,7 +150,7 @@ def get_iter_stats(self, cur_epoch, cur_iter):
def log_iter_stats(self, cur_epoch, cur_iter):
if (cur_iter + 1) % cfg.LOG_PERIOD == 0:
stats = self.get_iter_stats(cur_epoch, cur_iter)
logger.info(logging.dump_log_data(stats, "train_iter"))
logger.info(logging.dump_log_data(stats, self.phase + "_iter"))

def get_epoch_stats(self, cur_epoch):
cur_iter_total = (cur_epoch + 1) * self.epoch_iters
Expand All @@ -173,14 +174,15 @@ def get_epoch_stats(self, cur_epoch):

def log_epoch_stats(self, cur_epoch):
stats = self.get_epoch_stats(cur_epoch)
logger.info(logging.dump_log_data(stats, "train_epoch"))
logger.info(logging.dump_log_data(stats, self.phase + "_epoch"))


class TestMeter(object):
"""Measures testing stats."""

def __init__(self, epoch_iters):
def __init__(self, epoch_iters, phase="test"):
self.epoch_iters = epoch_iters
self.phase = phase
self.iter_timer = Timer()
# Current minibatch errors (smoothed over a window)
self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD)
Expand Down Expand Up @@ -233,7 +235,7 @@ def get_iter_stats(self, cur_epoch, cur_iter):
def log_iter_stats(self, cur_epoch, cur_iter):
if (cur_iter + 1) % cfg.LOG_PERIOD == 0:
stats = self.get_iter_stats(cur_epoch, cur_iter)
logger.info(logging.dump_log_data(stats, "test_iter"))
logger.info(logging.dump_log_data(stats, self.phase + "_iter"))

def get_epoch_stats(self, cur_epoch):
top1_err = self.num_top1_mis / self.num_samples
Expand All @@ -255,4 +257,4 @@ def get_epoch_stats(self, cur_epoch):

def log_epoch_stats(self, cur_epoch):
stats = self.get_epoch_stats(cur_epoch)
logger.info(logging.dump_log_data(stats, "test_epoch"))
logger.info(logging.dump_log_data(stats, self.phase + "_epoch"))
16 changes: 16 additions & 0 deletions pycls/core/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,19 @@ def mixup(inputs, labels):
inputs = m * inputs + (1.0 - m) * inputs[permutation, :]
labels = m * labels + (1.0 - m) * labels[permutation, :]
return inputs, labels, labels.argmax(1)


def update_model_ema(model, model_ema, cur_epoch, cur_iter):
"""Update exponential moving average (ema) of model weights."""
update_period = cfg.OPTIM.EMA_UPDATE_PERIOD
if update_period == 0 or cur_iter % update_period != 0:
return
# Adjust alpha to be fairly independent of other parameters
adjust = cfg.TRAIN.BATCH_SIZE / cfg.OPTIM.MAX_EPOCH * update_period
alpha = min(1.0, cfg.OPTIM.EMA_ALPHA * adjust)
# During warmup simply copy over weights instead of using ema
alpha = 1.0 if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS else alpha
# Take ema of all parameters (not just named parameters)
params = unwrap_model(model).state_dict()
for name, param in unwrap_model(model_ema).state_dict().items():
param.copy_(param * (1.0 - alpha) + params[name] * alpha)
27 changes: 16 additions & 11 deletions pycls/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""Tools for training and testing a model."""

import random
from copy import deepcopy

import numpy as np
import pycls.core.benchmark as benchmark
Expand Down Expand Up @@ -72,7 +73,7 @@ def setup_model():
return model


def train_epoch(loader, model, loss_fun, optimizer, scaler, meter, cur_epoch):
def train_epoch(loader, model, ema, loss_fun, optimizer, scaler, meter, cur_epoch):
"""Performs one epoch of training."""
# Shuffle the data
data_loader.shuffle(loader, cur_epoch)
Expand All @@ -81,6 +82,7 @@ def train_epoch(loader, model, loss_fun, optimizer, scaler, meter, cur_epoch):
optim.set_lr(optimizer, lr)
# Enable training mode
model.train()
ema.train()
meter.reset()
meter.iter_tic()
for cur_iter, (inputs, labels) in enumerate(loader):
Expand All @@ -99,6 +101,8 @@ def train_epoch(loader, model, loss_fun, optimizer, scaler, meter, cur_epoch):
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
# Update ema weights
net.update_model_ema(model, ema, cur_epoch, cur_iter)
# Compute the errors
top1_err, top5_err = meters.topk_errors(preds, labels, [1, 5])
# Combine the stats across the GPUs (no reduction if 1 GPU used)
Expand Down Expand Up @@ -146,48 +150,49 @@ def train_model():
"""Trains the model."""
# Setup training/testing environment
setup_env()
# Construct the model, loss_fun, and optimizer
# Construct the model, ema, loss_fun, and optimizer
model = setup_model()
ema = deepcopy(model)
loss_fun = builders.build_loss_fun().cuda()
optimizer = optim.construct_optimizer(model)
# Load checkpoint or initial weights
start_epoch = 0
if cfg.TRAIN.AUTO_RESUME and cp.has_checkpoint():
file = cp.get_last_checkpoint()
epoch = cp.load_checkpoint(file, model, optimizer)
epoch = cp.load_checkpoint(file, model, ema, optimizer)[0]
logger.info("Loaded checkpoint from: {}".format(file))
start_epoch = epoch + 1
elif cfg.TRAIN.WEIGHTS:
cp.load_checkpoint(cfg.TRAIN.WEIGHTS, model)
cp.load_checkpoint(cfg.TRAIN.WEIGHTS, model, ema)
logger.info("Loaded initial weights from: {}".format(cfg.TRAIN.WEIGHTS))
# Create data loaders and meters
train_loader = data_loader.construct_train_loader()
test_loader = data_loader.construct_test_loader()
train_meter = meters.TrainMeter(len(train_loader))
test_meter = meters.TestMeter(len(test_loader))
ema_meter = meters.TestMeter(len(test_loader), "test_ema")
# Create a GradScaler for mixed precision training
scaler = amp.GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION)
# Compute model and loader timings
if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
# Perform the training loop
logger.info("Start epoch: {}".format(start_epoch + 1))
best_err = np.inf
for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
# Train for one epoch
params = (train_loader, model, loss_fun, optimizer, scaler, train_meter)
params = (train_loader, model, ema, loss_fun, optimizer, scaler, train_meter)
train_epoch(*params, cur_epoch)
# Compute precise BN stats
if cfg.BN.USE_PRECISE_STATS:
net.compute_precise_bn_stats(model, train_loader)
net.compute_precise_bn_stats(ema, train_loader)
# Evaluate the model
test_epoch(test_loader, model, test_meter, cur_epoch)
# Check if checkpoint is best so far (note: should checkpoint meters as well)
stats = test_meter.get_epoch_stats(cur_epoch)
best = stats["top1_err"] <= best_err
best_err = min(stats["top1_err"], best_err)
test_epoch(test_loader, ema, ema_meter, cur_epoch)
test_err = test_meter.get_epoch_stats(cur_epoch)["top1_err"]
ema_err = ema_meter.get_epoch_stats(cur_epoch)["top1_err"]
# Save a checkpoint
file = cp.save_checkpoint(model, optimizer, cur_epoch, best)
file = cp.save_checkpoint(model, ema, optimizer, cur_epoch, test_err, ema_err)
logger.info("Wrote checkpoint to: {}".format(file))


Expand Down

0 comments on commit 2c152a6

Please sign in to comment.