upload resnexst and some new results on ImageNet
mzhaoshuai committed Jan 27, 2021
1 parent 062357f commit e4d8ea7
Showing 14 changed files with 789 additions and 56 deletions.
17 changes: 11 additions & 6 deletions README.md
@@ -3,7 +3,7 @@

# SplitNet: Divide and Co-training

SplitNet achieves 98.71%(CIFAR-10), 89.46%(CIFAR-100), and 83.34%(ImageNet, 416px)
SplitNet achieves 98.71% on CIFAR-10, 89.46% on CIFAR-100, and 83.60% on ImageNet (SE-ResNeXt-101, 64x4d, 320px)
by dividing one existing large network into several small ones and co-training.

## Table of Contents
@@ -12,7 +12,7 @@ by dividing one existing large network into several small ones and co-training.
* [Introduction](#Introduction)
* [Features and TODO](#Features-and-TODO)
* [Results and Checkpoints](#Results-and-Checkpoints)
* [Checkpoints](miscs/checkpoints.md)
* [Benchmarks and Checkpoints](miscs/checkpoints.md)
* [Installation](#Installation)
* [Training](#Training)
* [Evaluation](#Evaluation)
@@ -21,6 +21,11 @@ by dividing one existing large network into several small ones and co-training.
* [Acknowledgements](#Acknowledgements)
<!--te-->


## News
- [2021/01/27] Add new results (83.60% top-1) on ImageNet. Upload a new model, ResNeXSt, a combination of ResNeSt and ResNeXt (a rough sketch follows below).

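A rough sketch of the ResNeXSt idea (the block layout, channel widths, and class name below are illustrative assumptions, not the actual implementation under `model/`): keep a ResNeXt-style bottleneck with a grouped 3x3 convolution, but replace that convolution with the split-attention convolution added in `model/layers/split_attn.py`.

```python
import torch
from torch import nn

from model.layers import SplitAttnConv2d  # added in this commit


class ResNeXStBottleneckSketch(nn.Module):
    """Illustrative only: ResNeXt grouped bottleneck + ResNeSt split attention."""

    def __init__(self, channels=256, width=64, cardinality=4, radix=2):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, width, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        # ResNeXt's grouped 3x3 conv replaced by a split-attention conv.
        self.conv2 = SplitAttnConv2d(width, width, 3, padding=1,
                                     groups=cardinality, radix=radix,
                                     norm_layer=nn.BatchNorm2d)
        self.conv3 = nn.Conv2d(width, channels, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.conv2(out)            # grouped conv + radix attention
        out = self.bn3(self.conv3(out))
        return self.relu(out + x)        # identity shortcut


block = ResNeXStBottleneckSketch()
y = block(torch.randn(2, 256, 56, 56))   # -> torch.Size([2, 256, 56, 56])
```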

## Introduction

<div align="justify">
@@ -65,7 +70,7 @@ through extensive experiments.

## Features and TODO

- [x] Support SplitNet with different models, i.e., ResNet, Wide-ResNet, ResNeXt, SENet,
- [x] Support SplitNet with different models, i.e., ResNet, Wide-ResNet, ResNeXt, ResNeXSt, SENet,
Shake-Shake, DenseNet, PyramidNet (+Shake-Drop), EfficientNet. Also support ResNeSt without SplitNet.
- [x] Different data augmentation methods, i.e., mixup, random erasing, auto-augment, rand-augment, cutout
- [x] Distributed training (tested with multiple GPUs on a single machine)
@@ -99,9 +104,9 @@ Experiments on ImageNet are conducted on a single machine with 8 RTX 2080Ti GPUs
</div>


### Checkpoints
### Benchmarks and Checkpoints

[Checkpoints](miscs/checkpoints.md)
[Benchmarks and Checkpoints](miscs/checkpoints.md)

## Installation

@@ -289,7 +294,7 @@ Then run
@misc{2020_SplitNet,
author = {Shuai Zhao and Liguang Zhou and Wenxiao Wang and Deng Cai and Tin Lun Lam and Yangsheng Xu},
title = {SplitNet: Divide and Co-training},
howpublished = {ArXiv},
howpublished = {arXiv},
year = {2020}
}
```
20 changes: 13 additions & 7 deletions miscs/checkpoints.md
@@ -1,4 +1,4 @@
# Checkpoints
# Benchmarks and Checkpoints

Each zip file contains 4 types of files

@@ -12,13 +12,19 @@ Any issues about checkpoints should be raised at
[![checkpoints](https://img.shields.io/badge/issue-3-yellow)](https://github.com/mzhaoshuai/SplitNet-Divide-and-Co-training/issues/3).
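
For reference, a minimal loading sketch; the file name and state-dict key below are assumptions, so check the actual contents of each zip:

```python
import torch

# Hypothetical file name and key layout inside a downloaded zip.
ckpt = torch.load('model_best.pth.tar', map_location='cpu')
state_dict = ckpt.get('state_dict', ckpt)
print(list(state_dict.keys())[:5])
# `model` stands for the matching network built with this repo's model code:
# model.load_state_dict(state_dict)
```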


## ImageNet

General training protocol: batch size 256, 120 epochs, cosine learning-rate schedule with initial LR 0.1, AutoAugment/RandAugment, label smoothing, mixup, and random erasing.
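
A minimal sketch of that schedule is below; the optimizer choice, momentum, and weight decay are assumptions, while batch size 256, 120 epochs, and the cosine schedule starting at 0.1 follow the protocol above.

```python
import torch
from torch import nn

model = nn.Linear(1, 1)  # stand-in for the actual network
# Assumed SGD settings (momentum/weight decay are not taken from this repo's scripts).
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=120)

for epoch in range(120):
    # ... train one epoch over ImageNet (batch size 256) with AutoAugment/RandAugment,
    # label smoothing, mixup, and random erasing applied in the data pipeline ...
    optimizer.step()   # placeholder for the per-batch updates
    scheduler.step()
```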

## ImageNet

| Methods | Top-1./Top-5 Acc (%) | # MParams/GFLOPs | Checkpoints |
| Methods | Top-1/Top-5 Acc (%) | MParams/GFLOPs | Checkpoints |
|-----------------------------------|-----------------------|---------------------|--------------|
| ResNet-50, 224px | 78.84 / 94.47 | 25.7 / 5.5 | [resnet50_split1_imagenet_256_06](https://github.com/mzhaoshuai/SplitNet-Divide-and-Co-training/releases/download/1.0.0/resnet50_split1_imagenet_256_06.zip) |
| SE-ResNet-50, 224px | 79.47 / 94.54 | 28.2 / 4.9 | [se_resnet50_split1_imagenet_256_01](https://github.com/mzhaoshuai/SplitNet-Divide-and-Co-training/releases/download/1.0.0/se_resnet50_split1_imagenet_256_01.zip) |
| ResNeXSt-50, 4x16d, 224px | 79.85 / 94.98 | 17.8 / 4.3 | [resnexst50_4x16d_split1_imagenet_256_01](https://github.com/mzhaoshuai/SplitNet-Divide-and-Co-training/releases/download/1.0.0/resnexst50_4x16d_split1_imagenet_256_01.zip) |
| ResNeXSt-50, 8x16d, 224px | 80.90 / 95.36 | 30.5 / 6.8 | [resnexst50_8x16d_split1_imagenet_256_03](https://github.com/mzhaoshuai/SplitNet-Divide-and-Co-training/releases/download/1.0.0/resnexst50_8x16d_split1_imagenet_256_03.zip) |
| ResNeXSt-50, 4x32d, 224px | 81.10 / 95.49 | 37.1 / 8.3 | [resnexst50_4x32d_split1_imagenet_256_05](https://github.com/mzhaoshuai/SplitNet-Divide-and-Co-training/releases/download/1.0.0/resnexst50_4x32d_split1_imagenet_256_05.zip) |
| ResNet-101, 224px | 80.16 / 94.54 | 44.8 / 9.2 | [resnet101_split1_imagenet_256_01](https://github.com/mzhaoshuai/SplitNet-Divide-and-Co-training/releases/download/1.0.0/resnet101_split1_imagenet_256_01.zip) |
| WRN-50-2, 224px | 80.66 / 95.16 | 68.9 / 12.8 | [wide_resnet50_2_split1_imagenet_256_01](https://github.com/mzhaoshuai/SplitNet-Divide-and-Co-training/releases/download/1.0.0/wide_resnet50_2_split1_imagenet_256_01.zip) |
| WRN-50-2, S=2, 224px | 79.64 / 94.82 | 51.4 / 10.9 | [wide_resnet50_2_split2_imagenet_256_02](https://github.com/mzhaoshuai/SplitNet-Divide-and-Co-training/releases/download/1.0.0/wide_resnet50_2_split2_imagenet_256_02.zip) |
@@ -28,12 +34,12 @@ Any issues about checkpoints should be raised at
| ResNeXt-101, 64x4d, S=2, 224px | 82.13 / 95.98 | 88.6 / 18.8 | [resnext101_64x4d_split2_imagenet_256_02](https://github.com/mzhaoshuai/SplitNet-Divide-and-Co-training/releases/download/1.0.0/resnext101_64x4d_split2_imagenet_256_02.zip) |
| EfficientNet-B7, 320px | 81.83 / 95.78 | 66.7 / 10.6 | [efficientnetb7_split1_imagenet_128_03](https://github.com/mzhaoshuai/SplitNet-Divide-and-Co-training/releases/download/1.0.0/efficientnetb7_split1_imagenet_128_03.zip) |
| EfficientNet-B7, S=2, 320px | 82.74 / 96.30 | 68.2 / 10.5 | [efficientnetb7_split2_imagenet_128_02](https://github.com/mzhaoshuai/SplitNet-Divide-and-Co-training/releases/download/1.0.0/efficientnetb7_split2_imagenet_128_02.zip) |
| SE-ResNeXt-101, 64x4d, S=2, 416px | 83.34 / 96.61 | 98.0 / 61.1 | [se_resnext101_64x4d_split2_imagenet_128_02](https://github.com/mzhaoshuai/SplitNet-Divide-and-Co-training/releases/download/1.0.0/se_resnext101_64x4d_split2_imagenet_128_02.zip) |

| SE-ResNeXt-101, 64x4d, S=2, 416px, 120 epochs | 83.34 / 96.61 | 98.0 / 61.1 | [se_resnext101_64x4d_split2_imagenet_128_02](https://github.com/mzhaoshuai/SplitNet-Divide-and-Co-training/releases/download/1.0.0/se_resnext101_64x4d_split2_imagenet_128_02.zip) |
| SE-ResNeXt-101, 64x4d, S=2, 320px, 350 epochs | 83.60 / 96.69 | 98.0 / 38.2 | [se_resnext101_64x4d_B_split2_imagenet_128_05](https://github.com/mzhaoshuai/SplitNet-Divide-and-Co-training/releases/download/1.0.0/se_resnext101_64x4d_B_split2_imagenet_128_05.zip) |

## CIFAR-100

| Methods | Top-1. Acc (%) | # MParams/GFLOPs | Checkpoints |
| Methods | Top-1 Acc (%) | MParams/GFLOPs | Checkpoints |
|---------------------------|----------------|---------------------|--------------|
| WRN-28-10 | 84.50 | 36.5 / 5.25 | [wide_resnet28_10_split1_cifar100_128_01_acc84.5](https://github.com/mzhaoshuai/SplitNet-Divide-and-Co-training/releases/download/1.0.1/wide_resnet28_10_split1_cifar100_128_01_acc84.5.zip) |
| WRN-28-10, S=2 | 85.52 | 35.8 / 5.16 | [wide_resnet28_10_split2_cifar100_128_02_acc85.52](https://github.com/mzhaoshuai/SplitNet-Divide-and-Co-training/releases/download/1.0.1/wide_resnet28_10_split2_cifar100_128_02_acc85.52.zip) |
@@ -51,7 +57,7 @@ Any issues about checkpoints should be raised at

## CIFAR-10

| Methods | Top-1. Acc (%) | # MParams/GFLOPs | Checkpoints |
| Methods | Top-1 Acc (%) | MParams/GFLOPs | Checkpoints |
|---------------------------|----------------|---------------------|--------------|
| WRN-28-10 | 97.59 | 36.5 / 5.25 | [wide_resnet28_10_split1_cifar10_128_08_acc97.59](https://github.com/mzhaoshuai/SplitNet-Divide-and-Co-training/releases/download/1.0.2/wide_resnet28_10_split1_cifar10_128_08_acc97.59.zip) |
| WRN-28-10, S=2 | 98.19 | 35.8 / 5.16 | [wide_resnet28_10_split2_cifar10_128_07_acc98.19](https://github.com/mzhaoshuai/SplitNet-Divide-and-Co-training/releases/download/1.0.2/wide_resnet28_10_split2_cifar10_128_07_acc98.19.zip) |
Binary file modified miscs/fig3_latency.png
Binary file modified miscs/res_imagenet.png
1 change: 1 addition & 0 deletions model/layers/__init__.py
@@ -1,3 +1,4 @@
# coding=utf-8
from .drop import get_drop
from .activations import get_act
from .split_attn import SplitAttnConv2d
88 changes: 88 additions & 0 deletions model/layers/split_attn.py
@@ -0,0 +1,88 @@
""" Split Attention Conv2d (for ResNeSt Models)
Paper: `ResNeSt: Split-Attention Networks` - /https://arxiv.org/abs/2004.08955
Adapted from original PyTorch impl at https://github.com/zhanghang1989/ResNeSt
Modified for torchscript compat, performance, and consistency with timm by Ross Wightman
"""
import torch
import torch.nn.functional as F
from torch import nn


class RadixSoftmax(nn.Module):
def __init__(self, radix, cardinality):
super(RadixSoftmax, self).__init__()
self.radix = radix
self.cardinality = cardinality

def forward(self, x):
batch = x.size(0)
if self.radix > 1:
x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
x = F.softmax(x, dim=1)
x = x.reshape(batch, -1)
else:
x = torch.sigmoid(x)
return x


class SplitAttnConv2d(nn.Module):
"""Split-Attention Conv2d
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
dilation=1, groups=1, bias=False, radix=2, reduction_factor=4,
act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs):
super(SplitAttnConv2d, self).__init__()
self.radix = radix
self.drop_block = drop_block
mid_chs = out_channels * radix
attn_chs = max(in_channels * radix // reduction_factor, 32)

self.conv = nn.Conv2d(
in_channels, mid_chs, kernel_size, stride, padding, dilation,
groups=groups * radix, bias=bias, **kwargs)
self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None
self.act0 = act_layer(inplace=True)
self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None
self.act1 = act_layer(inplace=True)
self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
self.rsoftmax = RadixSoftmax(radix, groups)

@property
def in_channels(self):
return self.conv.in_channels

@property
def out_channels(self):
return self.fc1.out_channels

def forward(self, x):
x = self.conv(x)
if self.bn0 is not None:
x = self.bn0(x)
if self.drop_block is not None:
x = self.drop_block(x)
x = self.act0(x)

B, RC, H, W = x.shape
if self.radix > 1:
x = x.reshape((B, self.radix, RC // self.radix, H, W))
x_gap = x.sum(dim=1)
else:
x_gap = x
x_gap = F.adaptive_avg_pool2d(x_gap, 1)
x_gap = self.fc1(x_gap)
if self.bn1 is not None:
x_gap = self.bn1(x_gap)
x_gap = self.act1(x_gap)
x_attn = self.fc2(x_gap)

x_attn = self.rsoftmax(x_attn).view(B, -1, 1, 1)
if self.radix > 1:
out = (x * x_attn.reshape((B, self.radix, RC // self.radix, 1, 1))).sum(dim=1)
else:
out = x * x_attn
return out.contiguous()
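

# Minimal usage sketch (illustrative shapes/settings, assuming BatchNorm2d as
# the norm layer): the conv output is split into `radix` groups and reweighted
# by the radix softmax above before the groups are summed back together.
if __name__ == '__main__':
    layer = SplitAttnConv2d(64, 64, kernel_size=3, padding=1, groups=4,
                            radix=2, norm_layer=nn.BatchNorm2d)
    out = layer(torch.randn(2, 64, 56, 56))
    assert out.shape == (2, 64, 56, 56)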
4 changes: 3 additions & 1 deletion model/resnest/resnest.py
@@ -63,7 +63,9 @@ def resnest200(pretrained=False, root='~/.encoding/models', **kwargs):
avd=True, avd_first=False, **kwargs)
if pretrained:
model.load_state_dict(torch.hub.load_state_dict_from_url(
resnest_model_urls['resnest200'], progress=True, check_hash=True))
resnest_model_urls['resnest200'],
model_dir=root,
progress=True, check_hash=True))
return model
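
# Usage sketch for the change above (the path is illustrative): passing
# `model_dir=root` caches the downloaded weights under the user-supplied root
# instead of torch hub's default checkpoint directory, e.g.
#   model = resnest200(pretrained=True, root='~/.encoding/models')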


1 change: 1 addition & 0 deletions model/resnest/resnet.py
@@ -126,6 +126,7 @@ def forward(self, x):

return out


class ResNet(nn.Module):
"""ResNet Variants
9 changes: 5 additions & 4 deletions model/resnest/splat.py
@@ -8,6 +8,7 @@

__all__ = ['SplAtConv2d']


class SplAtConv2d(Module):
"""Split-Attention Conv2d
"""
@@ -55,10 +56,10 @@ def forward(self, x):
batch, rchannel = x.shape[:2]
if self.radix > 1:
if torch.__version__ < '1.5':
splited = torch.split(x, int(rchannel//self.radix), dim=1)
splited = torch.split(x, int(rchannel // self.radix), dim=1)
else:
splited = torch.split(x, rchannel//self.radix, dim=1)
gap = sum(splited)
splited = torch.split(x, rchannel // self.radix, dim=1)
gap = sum(splited)
else:
gap = x
gap = F.adaptive_avg_pool2d(gap, 1)
@@ -81,6 +82,7 @@ def forward(self, x):
out = atten * x
return out.contiguous()


class rSoftMax(nn.Module):
def __init__(self, radix, cardinality):
super().__init__()
@@ -96,4 +98,3 @@ def forward(self, x):
else:
x = torch.sigmoid(x)
return x
