add the choice of kl_loss when co-training, update README
mzhaoshuai committed Mar 23, 2021
1 parent e4d8ea7 commit 0bdf1d8
Showing 8 changed files with 43 additions and 28 deletions.
55 changes: 27 additions & 28 deletions README.md
@@ -1,9 +1,9 @@
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![arXiv](https://img.shields.io/badge/cs.CV-%09arXiv%3A2011.14660-red)](https://arxiv.org/abs/2011.14660)

# SplitNet: Divide and Co-training
# Divide and Co-training

SplitNet achieves 98.71% on CIFAR-10, 89.46% on CIFAR-100, and 83.60% on ImageNet (SE-ResNet-101, 64x4d, 320px)
Divide and co-training achieves 98.71% on CIFAR-10, 89.46% on CIFAR-100, and 83.60% on ImageNet (SE-ResNet-101, 64x4d, 320px)
by dividing one existing large network into several small ones and co-training.

## Table of Contents
@@ -31,31 +31,30 @@ by dividing one existing large network into several small ones and co-training.
<div align="justify">

This is the code for the paper
<a href="https://arxiv.org/abs/2011.14660">SplitNet: Divide and Co-training.</a>
<a href="https://arxiv.org/abs/2011.14660">
Towards Better Accuracy-efficiency Trade-offs: Divide and Co-training.</a>
<br />

The width of a neural network matters since increasing
the width will necessarily increase the model capacity. However,
the performance of a network does not improve linearly
with the width and soon gets saturated. To tackle this problem,
we propose to increase the number of networks rather
than purely scaling up the width. To prove it, one large network
is divided into several small ones, and each of these
small networks has a fraction of the original one’s parameters.
We then train these small networks together and make
them see various views of the same data to learn different
and complementary knowledge. During this co-training process,
networks can also learn from each other. As a result,
small networks can achieve better ensemble performance
The width of a neural network matters since increasing the width
will necessarily increase the model capacity.
However, the performance of a network does not improve linearly
with the width and soon gets saturated.
In this case, we argue that increasing the number of networks (ensemble)
can achieve better accuracy-efficiency trade-offs than purely increasing the width.
To prove this,
one large network is divided into several small ones
with respect to its parameters and regularization components.
Each of these small networks has a fraction of the original one's parameters.
We then train these small networks together and make them see various
views of the same data to increase their diversity.
During this co-training process,
networks can also learn from each other.
As a result, small networks can achieve better ensemble performance
than the large one with few or no extra parameters or FLOPs.
This reveals that the number of networks is a new dimension
of effective model scaling, besides depth/width/resolution.
Small networks can also achieve faster inference speed
than the large one by concurrent running on different devices.
We validate the idea --- increasing the number of
networks is a new dimension of effective model scaling ---
with different network architectures on common benchmarks
through extensive experiments.
than the large one by concurrent running on different devices.
We validate our argument with 8 different neural architectures on
common benchmarks through extensive experiments.
</div>
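
To make the divide-and-co-training idea above concrete, the sketch below (a hypothetical illustration, not the repository's implementation) builds `split_factor` small networks whose width is shrunk so that each holds roughly a 1/K share of the parameters, feeds each network a different view of the same batch, and ensembles their softmax outputs. The `SmallNetEnsemble` class name, the `make_small_net` factory, and the width rule are assumptions made only for this example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SmallNetEnsemble(nn.Module):
    """Illustrative only: K small networks trained together and ensembled."""

    def __init__(self, make_small_net, split_factor=2, base_width=64):
        super().__init__()
        # Parameter counts of conv/FC layers scale roughly with width^2,
        # so width / sqrt(K) gives each small network ~1/K of the parameters.
        small_width = int(base_width / (split_factor ** 0.5))
        self.models = nn.ModuleList(
            make_small_net(small_width) for _ in range(split_factor)
        )

    def forward(self, x_views):
        # During co-training each network sees a differently augmented view
        # of the same batch; at test time all views can be identical.
        logits = [m(x) for m, x in zip(self.models, x_views)]
        # Ensemble by averaging the softmax outputs of the small networks.
        probs = torch.stack([F.softmax(l, dim=-1) for l in logits], dim=0).mean(0)
        return logits, probs
```

During training, each network's supervised loss would be combined with a co-training term such as the ones defined in `model/splitnet.py` later in this commit.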

<div align=center>
@@ -70,8 +69,8 @@

## Features and TODO

- [x] Support SplitNet with different models, i.e., ResNet, Wide-ResNet, ResNeXt, ResNeXSt, SENet,
Shake-Shake, DenseNet, PyramidNet (+Shake-Drop), EfficientNet. Also support ResNeSt without SplitNet.
- [x] Support divide and co-training with different models, i.e., ResNet, Wide-ResNet, ResNeXt, ResNeXSt, SENet,
Shake-Shake, DenseNet, PyramidNet (+Shake-Drop), EfficientNet.
- [x] Different data augmentation methods, i.e., mixup, random erasing, auto-augment, rand-augment, cutout
- [x] Distributed training (tested with multi-GPUs on single machine)
- [x] Multi-GPUs synchronized BatchNormalization
@@ -197,7 +196,7 @@ wget -qO- https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/
- Download [The SVHN dataset](http://ufldl.stanford.edu/housenumbers/) (*Format 2: Cropped Digits*),
put them in the `dataset/svhn` directory.

- `cd` to `github` directory and clone the `SplitNet-Divide-and-Co-training` repo.
- `cd` to `github` directory and clone the `Divide-and-Co-training` repo.
For brevity, rename it to `splitnet`.


@@ -291,9 +290,9 @@ Then run
## Citations

```
@misc{2020_SplitNet,
@misc{2020_splitnet,
author = {Shuai Zhao and Liguang Zhou and Wenxiao Wang and Deng Cai and Tin Lun Lam and Yangsheng Xu},
title = {SplitNet: Divide and Co-training},
title = {Towards Better Accuracy-efficiency Trade-offs: Divide and Co-training},
howpublished = {arXiv},
year = {2020}
}
Binary file modified miscs/fig1_width.png
Binary file modified miscs/fig2_framework.png
Binary file modified miscs/fig3_latency.png
Binary file modified miscs/res_cifar10.png
Binary file modified miscs/res_cifar100.png
Binary file modified miscs/res_imagenet.png
16 changes: 16 additions & 0 deletions model/splitnet.py
@@ -202,6 +202,7 @@ def __init__(self,
self.models = nn.ModuleList(models)
self.criterion = criterion
if args.is_identical_init:
print("INFO:PyTorch: Using identical initialization.")
self._identical_init()

# data transform - use different transformers for different networks
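
As an aside on the `args.is_identical_init` option whose logging is added above: one plausible way to start all small networks from the same weights is to copy the parameters of the first network into the others. The sketch below is a hypothetical illustration and not necessarily what the repository's `_identical_init` does.

```python
import torch

@torch.no_grad()
def identical_init(models):
    """Hypothetical sketch: give every small network the weights of models[0].

    Assumes all entries of `models` share the same architecture, as the
    divided small networks do.
    """
    reference_state = models[0].state_dict()
    for model in models[1:]:
        model.load_state_dict(reference_state)
```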
@@ -222,6 +223,7 @@ def __init__(self,
self.cot_weight_warm_up_epochs = args.cot_weight_warm_up_epochs
# self.kl_temperature = args.kl_temperature
self.cot_loss_choose = args.cot_loss_choose
print("INFO:PyTorch: The co-training loss is {}.".format(self.cot_loss_choose))
self.num_classes = args.num_classes

def forward(self, x, target=None, mode='train', epoch=0, streams=None):
@@ -335,6 +337,20 @@ def _co_training_loss(self, outputs, loss_choose, epoch=0):
H_mean = (- p_mean * torch.log(p_mean)).sum(-1).mean()
H_sep = (- p_all * F.log_softmax(outputs_all, dim=-1)).sum(-1).mean()
cot_loss = weight_now * (H_mean - H_sep)

elif loss_choose == 'kl_seperate':
outputs_all = torch.stack(outputs, dim=0)
# repeat [1,2,3] like [1,1,2,2,3,3] and [2,3,1,3,1,2]
outputs_r1 = torch.repeat_interleave(outputs_all, self.split_factor - 1, dim=0)
index_list = [j for i in range(self.split_factor) for j in range(self.split_factor) if j!=i]
outputs_r2 = torch.index_select(outputs_all, dim=0, index=torch.tensor(index_list, dtype=torch.long).cuda())
# calculate the KL divergence
kl_loss = F.kl_div(F.log_softmax(outputs_r1, dim=-1),
F.softmax(outputs_r2, dim=-1).detach(),
reduction='none')
# sum the class-wise KL per sample, average over the batch, sum over all
# ordered network pairs, then normalize by (split_factor - 1)
cot_loss = weight_now * (kl_loss.sum(-1).mean(-1).sum() / (self.split_factor - 1))

else:
raise NotImplementedError
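
Two co-training losses now coexist in `_co_training_loss`. The branch visible in the context lines computes `H_mean - H_sep`, which (assuming `p_mean` and `p_all` are the averaged and per-network softmax outputs) is a generalized Jensen-Shannon style term: the entropy of the averaged prediction minus the mean per-network entropy. The new `kl_seperate` branch instead accumulates the KL divergence between every ordered pair of networks. The self-contained snippet below replays that pairing bookkeeping for `split_factor = 3` on random logits; it mirrors the tensor shapes in the diff but is only an illustration, run on CPU without the `.cuda()` call and with the `weight_now` warm-up factor omitted.

```python
import torch
import torch.nn.functional as F

# Illustration of the pairing used by the `kl_seperate` branch above:
# split_factor K = 3, batch size B = 4, C = 10 classes.
K, B, C = 3, 4, 10
outputs = [torch.randn(B, C) for _ in range(K)]
outputs_all = torch.stack(outputs, dim=0)                # [K, B, C]

# Each network is repeated (K - 1) times: order [0, 0, 1, 1, 2, 2] ...
outputs_r1 = torch.repeat_interleave(outputs_all, K - 1, dim=0)
# ... and paired with every *other* network: order [1, 2, 0, 2, 0, 1].
index_list = [j for i in range(K) for j in range(K) if j != i]
outputs_r2 = torch.index_select(
    outputs_all, dim=0, index=torch.tensor(index_list, dtype=torch.long))

# Element-wise KL(softmax(r2) || softmax(r1)); shape [K * (K - 1), B, C].
kl = F.kl_div(F.log_softmax(outputs_r1, dim=-1),
              F.softmax(outputs_r2, dim=-1).detach(),
              reduction='none')

# Sum over classes, average over the batch, sum over all ordered pairs,
# then divide by (K - 1): each network's mean divergence from its peers,
# summed over the K networks.
cot_loss = kl.sum(-1).mean(-1).sum() / (K - 1)
print(index_list)       # [1, 2, 0, 2, 0, 1]
print(cot_loss.item())  # a non-negative scalar
```

Because the target side is detached, gradients only flow through `outputs_r1`, so each network is pulled toward its peers' current predictions without back-propagating through them.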

