train enhancements: smooth labels, mixup, mixed precision (#121)
Summary:
notes:
-benchmark.py: properly time with mixed precision and smooth labels
-builders.py: default loss is now SoftCrossEntropyLoss
-config.py: options for MIXED_PRECISION, LABEL_SMOOTHING, MIXUP_ALPHA
-net.py: added smooth_one_hot_labels, mixup, SoftCrossEntropyLoss
-trainer.py: uses smooth labels, mixup, mixed precision

Pull Request resolved: #121

Reviewed By: theschnitz

Differential Revision: D25126890

Pulled By: pdollar

fbshipit-source-id: dd4f67ed5202109b372bbd4ed9f429f08238b719
pdollar authored and facebook-github-bot committed Nov 20, 2020
1 parent ace492d commit ca89a79
Showing 5 changed files with 76 additions and 15 deletions.
12 changes: 9 additions & 3 deletions pycls/core/benchmark.py
@@ -8,8 +8,10 @@
"""Benchmarking functions."""

import pycls.core.logging as logging
import pycls.core.net as net
import pycls.datasets.loader as loader
import torch
import torch.cuda.amp as amp
from pycls.core.config import cfg
from pycls.core.timer import Timer

@@ -48,9 +50,12 @@ def compute_time_train(model, loss_fun):
im_size, batch_size = cfg.TRAIN.IM_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
inputs = torch.rand(batch_size, 3, im_size, im_size).cuda(non_blocking=False)
labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False)
labels_one_hot = net.smooth_one_hot_labels(labels)
# Cache BatchNorm2D running stats
bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns]
# Create a GradScaler for mixed precision training
scaler = amp.GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION)
# Compute precise forward backward pass time
fw_timer, bw_timer = Timer(), Timer()
total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
@@ -61,13 +66,14 @@ def compute_time_train(model, loss_fun):
bw_timer.reset()
# Forward
fw_timer.tic()
preds = model(inputs)
loss = loss_fun(preds, labels)
with amp.autocast(enabled=cfg.TRAIN.MIXED_PRECISION):
preds = model(inputs)
loss = loss_fun(preds, labels_one_hot)
torch.cuda.synchronize()
fw_timer.toc()
# Backward
bw_timer.tic()
loss.backward()
scaler.scale(loss).backward()
torch.cuda.synchronize()
bw_timer.toc()
# Restore BatchNorm2D running stats
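
Note on the benchmark change: the forward pass is now wrapped in amp.autocast and the backward pass runs on the scaled loss, so the measured times match the numerics used during training. A minimal, self-contained sketch of that timing pattern (assuming a CUDA device is available; the model, shapes, and inline soft cross-entropy here are illustrative stand-ins, not pycls code):

import time
import torch
import torch.cuda.amp as amp
import torch.nn.functional as F

model = torch.nn.Conv2d(3, 8, 3).cuda()
scaler = amp.GradScaler(enabled=True)
inputs = torch.rand(4, 3, 32, 32).cuda()
labels = torch.zeros(4, dtype=torch.int64).cuda()
labels_one_hot = F.one_hot(labels, num_classes=8).float()

# Time the forward pass under autocast
torch.cuda.synchronize()
start = time.perf_counter()
with amp.autocast(enabled=True):
    preds = model(inputs).mean(dim=(2, 3))  # [N, C] logits
    # soft cross-entropy on one-hot labels, matching the loss the benchmark now times
    loss = torch.sum(-labels_one_hot * F.log_softmax(preds, -1)) / preds.shape[0]
torch.cuda.synchronize()
fw_time = time.perf_counter() - start

# Time the backward pass on the scaled loss, as compute_time_train now does
start = time.perf_counter()
scaler.scale(loss).backward()
torch.cuda.synchronize()
bw_time = time.perf_counter() - start
print(fw_time, bw_time)
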
4 changes: 2 additions & 2 deletions pycls/core/builders.py
@@ -7,8 +7,8 @@

"""Model and loss construction functions."""

import torch
from pycls.core.config import cfg
from pycls.core.net import SoftCrossEntropyLoss
from pycls.models.anynet import AnyNet
from pycls.models.effnet import EffNet
from pycls.models.regnet import RegNet
@@ -19,7 +19,7 @@
_models = {"anynet": AnyNet, "effnet": EffNet, "resnet": ResNet, "regnet": RegNet}

# Supported loss functions
_loss_funs = {"cross_entropy": torch.nn.CrossEntropyLoss}
_loss_funs = {"cross_entropy": SoftCrossEntropyLoss}


def get_model():
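
With this registry swap, configs that request "cross_entropy" now build SoftCrossEntropyLoss. Because the soft loss reduces to the standard loss when the labels are plain one-hot vectors, existing configs should see identical loss values. A quick check of that claim (a sketch, assuming pycls is importable after this commit):

import torch
import torch.nn.functional as F
from pycls.core.net import SoftCrossEntropyLoss

logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
hard = torch.nn.CrossEntropyLoss()(logits, labels)
soft = SoftCrossEntropyLoss()(logits, F.one_hot(labels, 10).float())
assert torch.allclose(hard, soft)  # identical on one-hot labels
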
9 changes: 9 additions & 0 deletions pycls/core/config.py
@@ -242,6 +242,15 @@
# Weights to start training from
_C.TRAIN.WEIGHTS = ""

# If True, train using mixed precision
_C.TRAIN.MIXED_PRECISION = False

# Label smoothing value in 0 to 1 (0 gives no smoothing)
_C.TRAIN.LABEL_SMOOTHING = 0.0

# Batch mixup regularization value in 0 to 1 (0 gives no mixup)
_C.TRAIN.MIXUP_ALPHA = 0.0

# Standard deviation for AlexNet-style PCA jitter (0 gives no PCA jitter)
_C.TRAIN.PCA_STD = 0.1

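
All three options default to off, so existing configs train exactly as before. A sketch of how they might be enabled programmatically (assuming the cfg has not been frozen yet; the values shown are illustrative, not recommendations from this commit):

from pycls.core.config import cfg

cfg.TRAIN.MIXED_PRECISION = True   # run forward/backward under torch.cuda.amp
cfg.TRAIN.LABEL_SMOOTHING = 0.1    # spread 0.1 of the label mass across all classes
cfg.TRAIN.MIXUP_ALPHA = 0.2        # Beta(alpha, alpha) mixing coefficient; 0 disables mixup
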
38 changes: 38 additions & 0 deletions pycls/core/net.py
@@ -9,6 +9,7 @@

import itertools

import numpy as np
import pycls.core.distributed as dist
import torch
from pycls.core.config import cfg
@@ -58,3 +59,40 @@ def complexity(model):
cx = {"h": size, "w": size, "flops": 0, "params": 0, "acts": 0}
cx = unwrap_model(model).complexity(cx)
return {"flops": cx["flops"], "params": cx["params"], "acts": cx["acts"]}


def smooth_one_hot_labels(labels):
"""Convert each label to a one-hot vector."""
n_classes, label_smooth = cfg.MODEL.NUM_CLASSES, cfg.TRAIN.LABEL_SMOOTHING
err_str = "Invalid input to one_hot_vector()"
assert labels.ndim == 1 and labels.max() < n_classes, err_str
shape = (labels.shape[0], n_classes)
neg_val = label_smooth / n_classes
pos_val = 1.0 - label_smooth + neg_val
labels_one_hot = torch.full(shape, neg_val, dtype=torch.float, device=labels.device)
labels_one_hot.scatter_(1, labels.long().view(-1, 1), pos_val)
return labels_one_hot


class SoftCrossEntropyLoss(torch.nn.Module):
"""SoftCrossEntropyLoss (useful for label smoothing and mixup).
Identical to torch.nn.CrossEntropyLoss if used with one-hot labels."""

def __init__(self):
super(SoftCrossEntropyLoss, self).__init__()

def forward(self, x, y):
loss = -y * torch.nn.functional.log_softmax(x, -1)
return torch.sum(loss) / x.shape[0]


def mixup(inputs, labels):
"""Apply mixup to minibatch (https://arxiv.org/abs/1710.09412)."""
alpha = cfg.TRAIN.MIXUP_ALPHA
assert labels.shape[1] == cfg.MODEL.NUM_CLASSES, "mixup labels must be one-hot"
if alpha > 0:
m = np.random.beta(alpha, alpha)
permutation = torch.randperm(labels.shape[0])
inputs = m * inputs + (1.0 - m) * inputs[permutation, :]
labels = m * labels + (1.0 - m) * labels[permutation, :]
return inputs, labels, labels.argmax(1)
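
Taken together, the three additions form a small pipeline: hard labels are smoothed into soft one-hot vectors, mixup blends both the inputs and the soft labels, and SoftCrossEntropyLoss consumes the soft labels while the returned argmax labels feed the error metrics. A CPU-only usage sketch (assuming pycls is importable and the cfg is still mutable; the class count, smoothing, alpha, and random logits are illustrative):

import torch
import pycls.core.net as net
from pycls.core.config import cfg

cfg.MODEL.NUM_CLASSES = 10
cfg.TRAIN.LABEL_SMOOTHING = 0.1
cfg.TRAIN.MIXUP_ALPHA = 0.2

inputs = torch.randn(8, 3, 32, 32)
labels = torch.randint(0, 10, (8,))

labels_one_hot = net.smooth_one_hot_labels(labels)            # [8, 10], each row sums to 1
inputs, labels_one_hot, hard_labels = net.mixup(inputs, labels_one_hot)
logits = torch.randn(8, 10)                                   # stand-in for model(inputs)
loss = net.SoftCrossEntropyLoss()(logits, labels_one_hot)
print(loss.item(), hard_labels.shape)                         # hard_labels = argmax of the mixed soft labels
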
28 changes: 18 additions & 10 deletions pycls/core/trainer.py
@@ -22,6 +22,7 @@
import pycls.core.optimizer as optim
import pycls.datasets.loader as data_loader
import torch
import torch.cuda.amp as amp
from pycls.core.config import cfg


@@ -71,7 +72,7 @@ def setup_model():
return model


def train_epoch(loader, model, loss_fun, optimizer, meter, cur_epoch):
def train_epoch(loader, model, loss_fun, optimizer, scaler, meter, cur_epoch):
"""Performs one epoch of training."""
# Shuffle the data
data_loader.shuffle(loader, cur_epoch)
@@ -85,15 +86,19 @@ def train_epoch(loader, model, loss_fun, optimizer, meter, cur_epoch):
for cur_iter, (inputs, labels) in enumerate(loader):
# Transfer the data to the current GPU device
inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True)
# Perform the forward pass
preds = model(inputs)
# Compute the loss
loss = loss_fun(preds, labels)
# Perform the backward pass
# Convert labels to smoothed one-hot vector
labels_one_hot = net.smooth_one_hot_labels(labels)
# Apply mixup to the batch (no effect if mixup alpha is 0)
inputs, labels_one_hot, labels = net.mixup(inputs, labels_one_hot)
# Perform the forward pass and compute the loss
with amp.autocast(enabled=cfg.TRAIN.MIXED_PRECISION):
preds = model(inputs)
loss = loss_fun(preds, labels_one_hot)
# Perform the backward pass and update the parameters
optimizer.zero_grad()
loss.backward()
# Update the parameters
optimizer.step()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
# Compute the errors
top1_err, top5_err = meters.topk_errors(preds, labels, [1, 5])
# Combine the stats across the GPUs (no reduction if 1 GPU used)
@@ -160,6 +165,8 @@ def train_model():
test_loader = data_loader.construct_test_loader()
train_meter = meters.TrainMeter(len(train_loader))
test_meter = meters.TestMeter(len(test_loader))
# Create a GradScaler for mixed precision training
scaler = amp.GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION)
# Compute model and loader timings
if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0:
benchmark.compute_time_full(model, loss_fun, train_loader, test_loader)
@@ -168,7 +175,8 @@
best_err = np.inf
for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH):
# Train for one epoch
train_epoch(train_loader, model, loss_fun, optimizer, train_meter, cur_epoch)
params = (train_loader, model, loss_fun, optimizer, scaler, train_meter)
train_epoch(*params, cur_epoch)
# Compute precise BN stats
if cfg.BN.USE_PRECISE_STATS:
net.compute_precise_bn_stats(model, train_loader)
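
The reworked inner loop follows the standard torch.cuda.amp recipe: zero the gradients, run the forward pass under autocast, backpropagate the scaled loss, then step the optimizer through the scaler. A stripped-down, torch-only sketch of one such step (assuming a CUDA device; the model, data, and inline soft loss are placeholders, not pycls code):

import torch
import torch.cuda.amp as amp
import torch.nn.functional as F

model = torch.nn.Linear(128, 10).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = amp.GradScaler(enabled=True)

inputs = torch.randn(32, 128).cuda()
labels_one_hot = F.one_hot(torch.randint(0, 10, (32,)), 10).float().cuda()

optimizer.zero_grad()
with amp.autocast(enabled=True):
    preds = model(inputs)
    loss = torch.sum(-labels_one_hot * F.log_softmax(preds, -1)) / preds.shape[0]
scaler.scale(loss).backward()   # backward on the scaled loss
scaler.step(optimizer)          # unscales gradients, skips the step on inf/nan
scaler.update()                 # adjusts the scale factor for the next iteration
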
