diff --git a/README.md b/README.md index 9c76b8c..7163d94 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,106 @@ -# MambaOut -MambaOut: Do We Really Need Mamba for Vision? +# [MambaOut: Do We Really Need Mamba for Vision?](https://arxiv.org/abs/2405.xxxxx) + +
+ +In memory of Kobe Bryant
+> "What can I say, Mamba out." +> +> — *Kobe Bryant, NBA farewell speech, 2016* + + + +This is a PyTorch implementation of MambaOut proposed by our paper "[MambaOut: Do We Really Need Mamba for Vision?](https://arxiv.org/abs/2303.16900)". + + +## Requirements +PyTorch and timm 0.6.11 (`pip install timm==0.6.11`). + +Data preparation: ImageNet with the following folder structure, you can extract ImageNet by this [script](https://gist.github.com/BIGBALLON/8a71d225eff18d88e469e6ea9b39cef4). + +``` +│imagenet/ +├──train/ +│ ├── n01440764 +│ │ ├── n01440764_10026.JPEG +│ │ ├── n01440764_10027.JPEG +│ │ ├── ...... +│ ├── ...... +├──val/ +│ ├── n01440764 +│ │ ├── ILSVRC2012_val_00000293.JPEG +│ │ ├── ILSVRC2012_val_00002138.JPEG +│ │ ├── ...... +│ ├── ...... +``` + + +## Models +### MambaOut trained on ImageNet +| Model | Resolution | Params | MACs | Top1 Acc | +| :--- | :---: | :---: | :---: | :---: | +| [mambaout_femto](https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_femto.pth) | 224 | 7.3M | 1.2G | 78.9 | +| [mambaout_tiny](https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_tiny.pth) | 224 | 26.5M | 4.5G | 82.7 | +| [mambaout_small](https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_small.pth) | 224 | 48.5M | 9.0G | 84.1 | +| [mambaout_base](https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_base.pth) | 224 | 84.8M | 15.8G | 84.2 | + + +#### Usage +We also provide a Colab notebook which runs the steps to perform inference with MambaOut: [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/) + + +## Validation + +To evaluate models, run: + +```bash +MODEL=mambaout_tiny +python3 validate.py /path/to/imagenet --model $MODEL -b 128 \ + --pretrained +``` + +## Train +We use batch size of 4096 by default and we show how to train models with 8 GPUs. For multi-node training, adjust `--grad-accum-steps` according to your situations. + + +```bash +DATA_PATH=/path/to/imagenet +CODE_PATH=/path/to/code/MambaOut # modify code path here + + +ALL_BATCH_SIZE=4096 +NUM_GPU=8 +GRAD_ACCUM_STEPS=4 # Adjust according to your GPU numbers and memory size. +let BATCH_SIZE=ALL_BATCH_SIZE/NUM_GPU/GRAD_ACCUM_STEPS + + +MODEL=mambaout_tiny +DROP_PATH=0.2 + + +cd $CODE_PATH && sh distributed_train.sh $NUM_GPU $DATA_PATH \ +--model $MODEL --opt adamw --lr 4e-3 --warmup-epochs 20 \ +-b $BATCH_SIZE --grad-accum-steps $GRAD_ACCUM_STEPS \ +--drop-path $DROP_PATH +``` +Training scripts of other models are shown in [scripts](/scripts/). + + +## Bibtex +``` +@article{yu2024mambaout, + title={MambaOut: Do We Really Need Mamba for Vision?}, + author={Yu, Weihao and and Wang, Xinchao}, + journal={arXiv preprint arXiv:2405.xxxxx}, + year={2024} +} +``` + +## Acknowledgment +Weihao was partly supported by Snap Research Fellowship, Google TPU Research Cloud (TRC), and Google Cloud Research Credits program. We thank Dongze Lian, Qiuhong Shen, Xingyi Yang, and Gongfan Fang for valuable discussions. + +Our implementation is based on [pytorch-image-models](https://github.com/huggingface/pytorch-image-models), [poolformer](https://github.com/sail-sg/poolformer), [ConvNeXt](https://github.com/facebookresearch/ConvNeXt), [metaformer](https://github.com/sail-sg/metaformer) and [inceptionnext](https://github.com/sail-sg/inceptionnext). diff --git a/distributed_train.sh b/distributed_train.sh new file mode 100755 index 0000000..1985669 --- /dev/null +++ b/distributed_train.sh @@ -0,0 +1,5 @@ +#!/bin/bash +NUM_PROC=$1 +shift +python3 -m torch.distributed.launch --nproc_per_node=$NUM_PROC train.py "$@" + diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..9350b4e --- /dev/null +++ b/models/__init__.py @@ -0,0 +1 @@ +from .mambaout import * \ No newline at end of file diff --git a/models/mambaout.py b/models/mambaout.py new file mode 100755 index 0000000..5ef851f --- /dev/null +++ b/models/mambaout.py @@ -0,0 +1,313 @@ +""" +MambaOut models for image classification. +Some implementations are modified from: +timm (https://github.com/rwightman/pytorch-image-models), +MetaFormer (https://github.com/sail-sg/metaformer), +InceptionNeXt (https://github.com/sail-sg/inceptionnext) +""" +from functools import partial +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import trunc_normal_, DropPath +from timm.models.registry import register_model +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': 1.0, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + 'mambaout_femto': _cfg( + url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_femto.pth'), + 'mambaout_tiny': _cfg( + url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_tiny.pth'), + 'mambaout_small': _cfg( + url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_small.pth'), + 'mambaout_base': _cfg( + url='https://github.com/yuweihao/MambaOut/releases/download/model/mambaout_base.pth'), +} + + +class StemLayer(nn.Module): + r""" Code modified from InternImage: + https://github.com/OpenGVLab/InternImage + """ + + def __init__(self, + in_channels=3, + out_channels=96, + act_layer=nn.GELU, + norm_layer=partial(nn.LayerNorm, eps=1e-6)): + super().__init__() + self.conv1 = nn.Conv2d(in_channels, + out_channels // 2, + kernel_size=3, + stride=2, + padding=1) + self.norm1 = norm_layer(out_channels // 2) + self.act = act_layer() + self.conv2 = nn.Conv2d(out_channels // 2, + out_channels, + kernel_size=3, + stride=2, + padding=1) + self.norm2 = norm_layer(out_channels) + + def forward(self, x): + x = self.conv1(x) + x = x.permute(0, 2, 3, 1) + x = self.norm1(x) + x = x.permute(0, 3, 1, 2) + x = self.act(x) + x = self.conv2(x) + x = x.permute(0, 2, 3, 1) + x = self.norm2(x) + return x + + +class DownsampleLayer(nn.Module): + r""" Code modified from InternImage: + https://github.com/OpenGVLab/InternImage + """ + def __init__(self, in_channels=96, out_channels=198, norm_layer=partial(nn.LayerNorm, eps=1e-6)): + super().__init__() + self.conv = nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1) + self.norm = norm_layer(out_channels) + + def forward(self, x): + x = self.conv(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + x = self.norm(x) + return x + + +class MlpHead(nn.Module): + """ MLP classification head + """ + def __init__(self, dim, num_classes=1000, act_layer=nn.GELU, mlp_ratio=4, + norm_layer=partial(nn.LayerNorm, eps=1e-6), head_dropout=0., bias=True): + super().__init__() + hidden_features = int(mlp_ratio * dim) + self.fc1 = nn.Linear(dim, hidden_features, bias=bias) + self.act = act_layer() + self.norm = norm_layer(hidden_features) + self.fc2 = nn.Linear(hidden_features, num_classes, bias=bias) + self.head_dropout = nn.Dropout(head_dropout) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.norm(x) + x = self.head_dropout(x) + x = self.fc2(x) + return x + + +class GatedCNNBlock(nn.Module): + r""" Our implementation of Gated CNN Block: https://arxiv.org/pdf/1612.08083 + Args: + conv_ratio: control the number of channels to conduct depthwise convolution. + Conduct convolution on partial channels can improve paraitcal efficiency. + The idea of partical channels is from ShuffleNet V2 (https://arxiv.org/abs/1807.11164) and + also used by InceptionNeXt (https://arxiv.org/abs/2303.16900) and FasterNet (https://arxiv.org/abs/2303.03667) + """ + def __init__(self, dim, expension_ratio=8/3, kernel_size=7, conv_ratio=1.0, + norm_layer=partial(nn.LayerNorm,eps=1e-6), + act_layer=nn.GELU, + drop_path=0., + **kwargs): + super().__init__() + self.norm = norm_layer(dim) + hidden = int(expension_ratio * dim) + self.fc1 = nn.Linear(dim, hidden * 2) + self.act = act_layer() + conv_channels = int(conv_ratio * dim) + self.split_indices = (hidden, hidden - conv_channels, conv_channels) + self.conv = nn.Conv2d(conv_channels, conv_channels, kernel_size=kernel_size, padding=kernel_size//2, groups=conv_channels) + self.fc2 = nn.Linear(hidden, dim) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x # [B, H, W, C] + x = self.norm(x) + g, i, c = torch.split(self.fc1(x), self.split_indices, dim=-1) + c = c.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + c = self.conv(c) + c = c.permute(0, 2, 3, 1) # [B, C, H, W] -> [B, H, W, C] + x = self.fc2(self.act(g) * torch.cat((i, c), dim=-1)) + x = self.drop_path(x) + return x + shortcut + +r""" +downsampling (stem) for the first stage is two layer of conv with k3, s2 and p1 +downsamplings for the last 3 stages is a layer of conv with k3, s2 and p1 +DOWNSAMPLE_LAYERS_FOUR_STAGES format: [Downsampling, Downsampling, Downsampling, Downsampling] +use `partial` to specify some arguments +""" +DOWNSAMPLE_LAYERS_FOUR_STAGES = [StemLayer] + [DownsampleLayer]*3 + + +class MambaOut(nn.Module): + r""" MetaFormer + A PyTorch impl of : `MetaFormer Baselines for Vision` - + https://arxiv.org/abs/2210.13452 + + Args: + in_chans (int): Number of input image channels. Default: 3. + num_classes (int): Number of classes for classification head. Default: 1000. + depths (list or tuple): Number of blocks at each stage. Default: [3, 3, 9, 3]. + dims (int): Feature dimension at each stage. Default: [96, 192, 384, 576]. + downsample_layers: (list or tuple): Downsampling layers before each stage. + drop_path_rate (float): Stochastic depth rate. Default: 0. + output_norm: norm before classifier head. Default: partial(nn.LayerNorm, eps=1e-6). + head_fn: classification head. Default: nn.Linear. + head_dropout (float): dropout for MLP classifier. Default: 0. + """ + def __init__(self, in_chans=3, num_classes=1000, + depths=[3, 3, 9, 3], + dims=[96, 192, 384, 576], + downsample_layers=DOWNSAMPLE_LAYERS_FOUR_STAGES, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + conv_ratio=1.0, + kernel_size=7, + drop_path_rate=0., + output_norm=partial(nn.LayerNorm, eps=1e-6), + head_fn=MlpHead, + head_dropout=0.0, + **kwargs, + ): + super().__init__() + self.num_classes = num_classes + + if not isinstance(depths, (list, tuple)): + depths = [depths] # it means the model has only one stage + if not isinstance(dims, (list, tuple)): + dims = [dims] + + num_stage = len(depths) + self.num_stage = num_stage + + if not isinstance(downsample_layers, (list, tuple)): + downsample_layers = [downsample_layers] * num_stage + down_dims = [in_chans] + dims + self.downsample_layers = nn.ModuleList( + [downsample_layers[i](down_dims[i], down_dims[i+1]) for i in range(num_stage)] + ) + + dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + + self.stages = nn.ModuleList() + cur = 0 + for i in range(num_stage): + stage = nn.Sequential( + *[GatedCNNBlock(dim=dims[i], + norm_layer=norm_layer, + act_layer=act_layer, + kernel_size=kernel_size, + conv_ratio=conv_ratio, + drop_path=dp_rates[cur + j], + ) for j in range(depths[i])] + ) + self.stages.append(stage) + cur += depths[i] + + self.norm = output_norm(dims[-1]) + + if head_dropout > 0.0: + self.head = head_fn(dims[-1], num_classes, head_dropout=head_dropout) + else: + self.head = head_fn(dims[-1], num_classes) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'norm'} + + def forward_features(self, x): + for i in range(self.num_stage): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + return self.norm(x.mean([1, 2])) # (B, H, W, C) -> (B, C) + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + + +############################################################################### +# a series of MambaOut models +@register_model +def mambaout_femto(pretrained=False, **kwargs): + model = MambaOut( + depths=[3, 3, 9, 3], + dims=[48, 96, 192, 288], + **kwargs) + model.default_cfg = default_cfgs['mambaout_femto'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def mambaout_tiny(pretrained=False, **kwargs): + model = MambaOut( + depths=[3, 3, 9, 3], + dims=[96, 192, 384, 576], + **kwargs) + model.default_cfg = default_cfgs['mambaout_tiny'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def mambaout_small(pretrained=False, **kwargs): + model = MambaOut( + depths=[3, 4, 27, 3], + dims=[96, 192, 384, 576], + **kwargs) + model.default_cfg = default_cfgs['mambaout_small'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model + + +@register_model +def mambaout_base(pretrained=False, **kwargs): + model = MambaOut( + depths=[3, 4, 27, 3], + dims=[128, 256, 512, 768], + **kwargs) + model.default_cfg = default_cfgs['mambaout_base'] + if pretrained: + state_dict = torch.hub.load_state_dict_from_url( + url= model.default_cfg['url'], map_location="cpu", check_hash=True) + model.load_state_dict(state_dict) + return model \ No newline at end of file diff --git a/scripts/train_mambaout_base.sh b/scripts/train_mambaout_base.sh new file mode 100644 index 0000000..d4edd3b --- /dev/null +++ b/scripts/train_mambaout_base.sh @@ -0,0 +1,18 @@ +DATA_PATH=/path/to/imagenet +CODE_PATH=/path/to/code/MambaOut # modify code path here + + +ALL_BATCH_SIZE=4096 +NUM_GPU=8 +GRAD_ACCUM_STEPS=4 # Adjust according to your GPU numbers and memory size. +let BATCH_SIZE=ALL_BATCH_SIZE/NUM_GPU/GRAD_ACCUM_STEPS + + +MODEL=mambaout_base +DROP_PATH=0.6 + + +cd $CODE_PATH && sh distributed_train.sh $NUM_GPU $DATA_PATH \ +--model $MODEL --opt adamw --lr 4e-3 --warmup-epochs 20 \ +-b $BATCH_SIZE --grad-accum-steps $GRAD_ACCUM_STEPS \ +--drop-path $DROP_PATH \ No newline at end of file diff --git a/scripts/train_mambaout_femto.sh b/scripts/train_mambaout_femto.sh new file mode 100644 index 0000000..1c03d7b --- /dev/null +++ b/scripts/train_mambaout_femto.sh @@ -0,0 +1,18 @@ +DATA_PATH=/path/to/imagenet +CODE_PATH=/path/to/code/MambaOut # modify code path here + + +ALL_BATCH_SIZE=4096 +NUM_GPU=8 +GRAD_ACCUM_STEPS=4 # Adjust according to your GPU numbers and memory size. +let BATCH_SIZE=ALL_BATCH_SIZE/NUM_GPU/GRAD_ACCUM_STEPS + + +MODEL=mambaout_femto +DROP_PATH=0.025 + + +cd $CODE_PATH && sh distributed_train.sh $NUM_GPU $DATA_PATH \ +--model $MODEL --opt adamw --lr 4e-3 --warmup-epochs 20 \ +-b $BATCH_SIZE --grad-accum-steps $GRAD_ACCUM_STEPS \ +--drop-path $DROP_PATH \ No newline at end of file diff --git a/scripts/train_mambaout_small.sh b/scripts/train_mambaout_small.sh new file mode 100644 index 0000000..5865983 --- /dev/null +++ b/scripts/train_mambaout_small.sh @@ -0,0 +1,18 @@ +DATA_PATH=/path/to/imagenet +CODE_PATH=/path/to/code/MambaOut # modify code path here + + +ALL_BATCH_SIZE=4096 +NUM_GPU=8 +GRAD_ACCUM_STEPS=4 # Adjust according to your GPU numbers and memory size. +let BATCH_SIZE=ALL_BATCH_SIZE/NUM_GPU/GRAD_ACCUM_STEPS + + +MODEL=mambaout_small +DROP_PATH=0.4 + + +cd $CODE_PATH && sh distributed_train.sh $NUM_GPU $DATA_PATH \ +--model $MODEL --opt adamw --lr 4e-3 --warmup-epochs 20 \ +-b $BATCH_SIZE --grad-accum-steps $GRAD_ACCUM_STEPS \ +--drop-path $DROP_PATH \ No newline at end of file diff --git a/scripts/train_mambaout_tiny.sh b/scripts/train_mambaout_tiny.sh new file mode 100644 index 0000000..3393240 --- /dev/null +++ b/scripts/train_mambaout_tiny.sh @@ -0,0 +1,18 @@ +DATA_PATH=/path/to/imagenet +CODE_PATH=/path/to/code/MambaOut # modify code path here + + +ALL_BATCH_SIZE=4096 +NUM_GPU=8 +GRAD_ACCUM_STEPS=4 # Adjust according to your GPU numbers and memory size. +let BATCH_SIZE=ALL_BATCH_SIZE/NUM_GPU/GRAD_ACCUM_STEPS + + +MODEL=mambaout_tiny +DROP_PATH=0.2 + + +cd $CODE_PATH && sh distributed_train.sh $NUM_GPU $DATA_PATH \ +--model $MODEL --opt adamw --lr 4e-3 --warmup-epochs 20 \ +-b $BATCH_SIZE --grad-accum-steps $GRAD_ACCUM_STEPS \ +--drop-path $DROP_PATH \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..00e7b47 --- /dev/null +++ b/train.py @@ -0,0 +1,926 @@ +#!/usr/bin/env python3 +r""" +This script is mostly copied from https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/train.py +and make some modifications: +1) enable the gradient accumulation (`--grad-accum-steps`) +2) add `--head-dropout` for ConvFormer and CAFormer with MLP head +3) Set some default values of hyper-parameters following DeiT: +-j 8 \ +--opt adamw \ +--epochs 300 \ +--sched cosine \ +--warmup-epochs 5 \ +--warmup-lr 1e-6 \ +--min-lr 1e-5 \ +--weight-decay 0.05 \ +--smoothing 0.1 \ +--aa rand-m9-mstd0.5-inc1 \ +--mixup 0.8 \ +--cutmix 1.0 \ +--remode pixel \ +--reprob 0.25 \ +""" + +""" ImageNet Training Script + +This is intended to be a lean and easily modifiable ImageNet training script that reproduces ImageNet +training results with some of the latest networks and training techniques. It favours canonical PyTorch +and standard Python style over trying to be able to 'do it all.' That said, it offers quite a few speed +and training result improvements over the usual PyTorch example scripts. Repurpose as you see fit. + +This script was started from an early version of the PyTorch ImageNet example +(https://github.com/pytorch/examples/tree/master/imagenet) + +NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples +(https://github.com/NVIDIA/apex/tree/master/examples/imagenet) + +Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) +""" +import argparse +import logging +import os +import time +from collections import OrderedDict +from contextlib import suppress +from datetime import datetime + +import torch +import torch.nn as nn +import torchvision.utils +import yaml +from torch.nn.parallel import DistributedDataParallel as NativeDDP + +from timm import utils +from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset +from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \ + LabelSmoothingCrossEntropy +from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \ + convert_splitbn_model, convert_sync_batchnorm, model_parameters, set_fast_norm +from timm.optim import create_optimizer_v2, optimizer_kwargs +from timm.scheduler import create_scheduler +# from timm.utils import ApexScaler, NativeScaler +from utils import ApexScalerAccum as ApexScaler +from utils import NativeScalerAccum as NativeScaler + +import models + +try: + from apex import amp + from apex.parallel import DistributedDataParallel as ApexDDP + from apex.parallel import convert_syncbn_model + has_apex = True +except ImportError: + has_apex = False + +has_native_amp = False +try: + if getattr(torch.cuda.amp, 'autocast') is not None: + has_native_amp = True +except AttributeError: + pass + +try: + import wandb + has_wandb = True +except ImportError: + has_wandb = False + +try: + from functorch.compile import memory_efficient_fusion + has_functorch = True +except ImportError as e: + has_functorch = False + + +torch.backends.cudnn.benchmark = True +_logger = logging.getLogger('train') + +# The first arg parser parses out only the --config argument, this argument is used to +# load a yaml file containing key-values that override the defaults for the main parser below +config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False) +parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', + help='YAML config file specifying default arguments') + + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') + +# Dataset parameters +group = parser.add_argument_group('Dataset parameters') +# Keep this argument outside of the dataset group because it is positional. +parser.add_argument('data_dir', metavar='DIR', + help='path to dataset') +group.add_argument('--dataset', '-d', metavar='NAME', default='', + help='dataset type (default: ImageFolder/ImageTar if empty)') +group.add_argument('--train-split', metavar='NAME', default='train', + help='dataset train split (default: train)') +group.add_argument('--val-split', metavar='NAME', default='validation', + help='dataset validation split (default: validation)') +group.add_argument('--dataset-download', action='store_true', default=False, + help='Allow download of dataset for torch/ and tfds/ datasets that support it.') +group.add_argument('--class-map', default='', type=str, metavar='FILENAME', + help='path to class to idx mapping file (default: "")') + +# Model parameters +group = parser.add_argument_group('Model parameters') +group.add_argument('--model', default='resnet50', type=str, metavar='MODEL', + help='Name of model to train (default: "resnet50"') +group.add_argument('--pretrained', action='store_true', default=False, + help='Start with pretrained version of specified network (if avail)') +group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', + help='Initialize model from this checkpoint (default: none)') +group.add_argument('--resume', default='', type=str, metavar='PATH', + help='Resume full model and optimizer state from checkpoint (default: none)') +group.add_argument('--no-resume-opt', action='store_true', default=False, + help='prevent resume of optimizer state when resuming model') +group.add_argument('--num-classes', type=int, default=None, metavar='N', + help='number of label classes (Model default if None)') +group.add_argument('--gp', default=None, type=str, metavar='POOL', + help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') +group.add_argument('--img-size', type=int, default=None, metavar='N', + help='Image patch size (default: None => model default)') +group.add_argument('--input-size', default=None, nargs=3, type=int, + metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') +group.add_argument('--crop-pct', default=None, type=float, + metavar='N', help='Input image center crop percent (for validation only)') +group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override mean pixel value of dataset') +group.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', + help='Override std deviation of dataset') +group.add_argument('--interpolation', default='', type=str, metavar='NAME', + help='Image resize interpolation type (overrides model)') +group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N', + help='Input batch size for training (default: 128)') +group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N', + help='Validation batch size override (default: None)') +group.add_argument('--channels-last', action='store_true', default=False, + help='Use channels_last memory layout') +scripting_group = group.add_mutually_exclusive_group() +scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true', + help='torch.jit.script the full model') +scripting_group.add_argument('--aot-autograd', default=False, action='store_true', + help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)") +group.add_argument('--fuser', default='', type=str, + help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") +group.add_argument('--fast-norm', default=False, action='store_true', + help='enable experimental fast-norm') +group.add_argument('--grad-checkpointing', action='store_true', default=False, + help='Enable gradient checkpointing through model blocks/stages') + +# Optimizer parameters +group = parser.add_argument_group('Optimizer parameters') +group.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', + help='Optimizer (default: "adamw"') +group.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', + help='Optimizer Epsilon (default: None, use opt default)') +group.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', + help='Optimizer Betas (default: None, use opt default)') +group.add_argument('--momentum', type=float, default=0.9, metavar='M', + help='Optimizer momentum (default: 0.9)') +group.add_argument('--weight-decay', type=float, default=0.05, + help='weight decay (default: 0.05)') +group.add_argument('--clip-grad', type=float, default=None, metavar='NORM', + help='Clip gradient norm (default: None, no clipping)') +group.add_argument('--clip-mode', type=str, default='norm', + help='Gradient clipping mode. One of ("norm", "value", "agc")') +group.add_argument('--layer-decay', type=float, default=None, + help='layer-wise learning rate decay (default: None)') + +# Learning rate schedule parameters +group = parser.add_argument_group('Learning rate schedule parameters') +group.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', + help='LR scheduler (default: "cosine"') +group.add_argument('--lr', type=float, default=0.05, metavar='LR', + help='learning rate (default: 0.05)') +group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', + help='learning rate noise on/off epoch percentages') +group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', + help='learning rate noise limit percent (default: 0.67)') +group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', + help='learning rate noise std-dev (default: 1.0)') +group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', + help='learning rate cycle len multiplier (default: 1.0)') +group.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT', + help='amount to decay each learning rate cycle (default: 0.5)') +group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', + help='learning rate cycle limit, cycles enabled if > 1') +group.add_argument('--lr-k-decay', type=float, default=1.0, + help='learning rate k-decay for cosine/poly (default: 1.0)') +group.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', + help='warmup learning rate (default: 1e-6)') +group.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', + help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') +group.add_argument('--epochs', type=int, default=300, metavar='N', + help='number of epochs to train (default: 300)') +parser.add_argument('--grad-accum-steps', default=1, type=int, + help='gradient accumulation steps') +group.add_argument('--epoch-repeats', type=float, default=0., metavar='N', + help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') +group.add_argument('--start-epoch', default=None, type=int, metavar='N', + help='manual epoch number (useful on restarts)') +group.add_argument('--decay-milestones', default=[30, 60], type=int, nargs='+', metavar="MILESTONES", + help='list of decay epoch indices for multistep lr. must be increasing') +group.add_argument('--decay-epochs', type=float, default=100, metavar='N', + help='epoch interval to decay LR') +group.add_argument('--warmup-epochs', type=int, default=5, metavar='N', + help='epochs to warmup LR, if scheduler supports') +group.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', + help='epochs to cooldown LR at min_lr, after cyclic schedule ends') +group.add_argument('--patience-epochs', type=int, default=10, metavar='N', + help='patience epochs for Plateau LR scheduler (default: 10') +group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', + help='LR decay rate (default: 0.1)') + +# Augmentation & regularization parameters +group = parser.add_argument_group('Augmentation and regularization parameters') +group.add_argument('--no-aug', action='store_true', default=False, + help='Disable all training augmentation, override other train aug args') +group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', + help='Random resize scale (default: 0.08 1.0)') +group.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO', + help='Random resize aspect ratio (default: 0.75 1.33)') +group.add_argument('--hflip', type=float, default=0.5, + help='Horizontal flip training aug probability') +group.add_argument('--vflip', type=float, default=0., + help='Vertical flip training aug probability') +group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', + help='Color jitter factor (default: 0.4)') +group.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', + help='Use AutoAugment policy. "v0" or "original". (default: rand-m9-mstd0.5-inc1)'), +group.add_argument('--aug-repeats', type=float, default=0, + help='Number of augmentation repetitions (distributed training only) (default: 0)') +group.add_argument('--aug-splits', type=int, default=0, + help='Number of augmentation splits (default: 0, valid: 0 or >=2)') +group.add_argument('--jsd-loss', action='store_true', default=False, + help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') +group.add_argument('--bce-loss', action='store_true', default=False, + help='Enable BCE loss w/ Mixup/CutMix use.') +group.add_argument('--bce-target-thresh', type=float, default=None, + help='Threshold for binarizing softened BCE targets (default: None, disabled)') +group.add_argument('--reprob', type=float, default=0.25, metavar='PCT', + help='Random erase prob (default: 0.25)') +group.add_argument('--remode', type=str, default='pixel', + help='Random erase mode (default: "pixel")') +group.add_argument('--recount', type=int, default=1, + help='Random erase count (default: 1)') +group.add_argument('--resplit', action='store_true', default=False, + help='Do not random erase first (clean) augmentation split') +group.add_argument('--mixup', type=float, default=0.8, + help='mixup alpha, mixup enabled if > 0. (default: 0.8)') +group.add_argument('--cutmix', type=float, default=1.0, + help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)') +group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, + help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') +group.add_argument('--mixup-prob', type=float, default=1.0, + help='Probability of performing mixup or cutmix when either/both is enabled') +group.add_argument('--mixup-switch-prob', type=float, default=0.5, + help='Probability of switching to cutmix when both mixup and cutmix enabled') +group.add_argument('--mixup-mode', type=str, default='batch', + help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') +group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', + help='Turn off mixup after this epoch, disabled if 0 (default: 0)') +group.add_argument('--smoothing', type=float, default=0.1, + help='Label smoothing (default: 0.1)') +group.add_argument('--train-interpolation', type=str, default='random', + help='Training interpolation (random, bilinear, bicubic default: "random")') +group.add_argument('--drop', type=float, default=0.0, metavar='PCT', + help='Dropout rate (default: 0.)') +group.add_argument('--drop-connect', type=float, default=None, metavar='PCT', + help='Drop connect rate, DEPRECATED, use drop-path (default: None)') +group.add_argument('--drop-path', type=float, default=None, metavar='PCT', + help='Drop path rate (default: None)') +group.add_argument('--drop-block', type=float, default=None, metavar='PCT', + help='Drop block rate (default: None)') +group.add_argument('--head-dropout', type=float, default=0.0, metavar='PCT', + help='dropout rate for classifier (default: 0.0)') + +# Batch norm parameters (only works with gen_efficientnet based models currently) +group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.') +group.add_argument('--bn-momentum', type=float, default=None, + help='BatchNorm momentum override (if not None)') +group.add_argument('--bn-eps', type=float, default=None, + help='BatchNorm epsilon override (if not None)') +group.add_argument('--sync-bn', action='store_true', + help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') +group.add_argument('--dist-bn', type=str, default='reduce', + help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') +group.add_argument('--split-bn', action='store_true', + help='Enable separate BN layers per augmentation split.') + +# Model Exponential Moving Average +group = parser.add_argument_group('Model exponential moving average parameters') +group.add_argument('--model-ema', action='store_true', default=False, + help='Enable tracking moving average of model weights') +group.add_argument('--model-ema-force-cpu', action='store_true', default=False, + help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.') +group.add_argument('--model-ema-decay', type=float, default=0.9998, + help='decay factor for model weights moving average (default: 0.9998)') + +# Misc +group = parser.add_argument_group('Miscellaneous parameters') +group.add_argument('--seed', type=int, default=42, metavar='S', + help='random seed (default: 42)') +group.add_argument('--worker-seeding', type=str, default='all', + help='worker seed mode (default: all)') +group.add_argument('--log-interval', type=int, default=50, metavar='N', + help='how many batches to wait before logging training status') +group.add_argument('--recovery-interval', type=int, default=0, metavar='N', + help='how many batches to wait before writing recovery checkpoint') +group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N', + help='number of checkpoints to keep (default: 10)') +group.add_argument('-j', '--workers', type=int, default=8, metavar='N', + help='how many training processes to use (default: 8)') +group.add_argument('--save-images', action='store_true', default=False, + help='save images of input bathes every log interval for debugging') +group.add_argument('--amp', action='store_true', default=False, + help='use NVIDIA Apex AMP or Native AMP for mixed precision training') +group.add_argument('--apex-amp', action='store_true', default=False, + help='Use NVIDIA Apex AMP mixed precision') +group.add_argument('--native-amp', action='store_true', default=False, + help='Use Native Torch AMP mixed precision') +group.add_argument('--no-ddp-bb', action='store_true', default=False, + help='Force broadcast buffers for native DDP to off.') +group.add_argument('--pin-mem', action='store_true', default=False, + help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') +group.add_argument('--no-prefetcher', action='store_true', default=False, + help='disable fast prefetcher') +group.add_argument('--output', default='', type=str, metavar='PATH', + help='path to output folder (default: none, current dir)') +group.add_argument('--experiment', default='', type=str, metavar='NAME', + help='name of train experiment, name of sub-folder for output') +group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', + help='Best metric (default: "top1"') +group.add_argument('--tta', type=int, default=0, metavar='N', + help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') +group.add_argument("--local_rank", default=0, type=int) +group.add_argument('--use-multi-epochs-loader', action='store_true', default=False, + help='use the multi-epochs-loader to save time at the beginning of every epoch') +group.add_argument('--log-wandb', action='store_true', default=False, + help='log training and validation metrics to wandb') + + +def _parse_args(): + # Do we have a config file to parse? + args_config, remaining = config_parser.parse_known_args() + if args_config.config: + with open(args_config.config, 'r') as f: + cfg = yaml.safe_load(f) + parser.set_defaults(**cfg) + + # The main arg parser parses the rest of the args, the usual + # defaults will have been overridden if config file specified. + args = parser.parse_args(remaining) + + # Cache the args as a text string to save them in the output dir later + args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) + return args, args_text + + +def main(): + utils.setup_default_logging() + args, args_text = _parse_args() + + args.prefetcher = not args.no_prefetcher + args.distributed = False + if 'WORLD_SIZE' in os.environ: + args.distributed = int(os.environ['WORLD_SIZE']) > 1 + args.device = 'cuda:0' + args.world_size = 1 + args.rank = 0 # global rank + if args.distributed: + if 'LOCAL_RANK' in os.environ: + args.local_rank = int(os.getenv('LOCAL_RANK')) + args.device = 'cuda:%d' % args.local_rank + torch.cuda.set_device(args.local_rank) + torch.distributed.init_process_group(backend='nccl', init_method='env://') + args.world_size = torch.distributed.get_world_size() + args.rank = torch.distributed.get_rank() + _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' + % (args.rank, args.world_size)) + else: + _logger.info('Training with a single process on 1 GPUs.') + assert args.rank >= 0 + + if args.rank == 0 and args.log_wandb: + if has_wandb: + wandb.init(project=args.experiment, config=args) + else: + _logger.warning("You've requested to log metrics to wandb but package not found. " + "Metrics not being logged to wandb, try `pip install wandb`") + + # resolve AMP arguments based on PyTorch / Apex availability + use_amp = None + if args.amp: + # `--amp` chooses native amp before apex (APEX ver not actively maintained) + if has_native_amp: + args.native_amp = True + elif has_apex: + args.apex_amp = True + if args.apex_amp and has_apex: + use_amp = 'apex' + elif args.native_amp and has_native_amp: + use_amp = 'native' + elif args.apex_amp or args.native_amp: + _logger.warning("Neither APEX or native Torch AMP is available, using float32. " + "Install NVIDA apex or upgrade to PyTorch 1.6") + + utils.random_seed(args.seed, args.rank) + + if args.fuser: + utils.set_jit_fuser(args.fuser) + if args.fast_norm: + set_fast_norm() + + create_model_args = dict( + model_name=args.model, + pretrained=args.pretrained, + num_classes=args.num_classes, + drop_rate=args.drop, + drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path + drop_path_rate=args.drop_path, + drop_block_rate=args.drop_block, + global_pool=args.gp, + bn_momentum=args.bn_momentum, + bn_eps=args.bn_eps, + scriptable=args.torchscript, + checkpoint_path=args.initial_checkpoint + ) + + model = create_model(**create_model_args) + + if args.num_classes is None: + assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' + args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly + + if args.grad_checkpointing: + model.set_grad_checkpointing(enable=True) + + if args.local_rank == 0: + _logger.info( + f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}') + + data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) + + # setup augmentation batch splits for contrastive loss or split bn + num_aug_splits = 0 + if args.aug_splits > 0: + assert args.aug_splits > 1, 'A split of 1 makes no sense' + num_aug_splits = args.aug_splits + + # enable split bn (separate bn stats per batch-portion) + if args.split_bn: + assert num_aug_splits > 1 or args.resplit + model = convert_splitbn_model(model, max(num_aug_splits, 2)) + + # move model to GPU, enable channels last layout if set + model.cuda() + if args.channels_last: + model = model.to(memory_format=torch.channels_last) + + # setup synchronized BatchNorm for distributed training + if args.distributed and args.sync_bn: + args.dist_bn = '' # disable dist_bn when sync BN active + assert not args.split_bn + if has_apex and use_amp == 'apex': + # Apex SyncBN used with Apex AMP + # WARNING this won't currently work with models using BatchNormAct2d + model = convert_syncbn_model(model) + else: + model = convert_sync_batchnorm(model) + if args.local_rank == 0: + _logger.info( + 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' + 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') + + if args.torchscript: + assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' + assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' + model = torch.jit.script(model) + if args.aot_autograd: + assert has_functorch, "functorch is needed for --aot-autograd" + model = memory_efficient_fusion(model) + + optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args)) + + # setup automatic mixed-precision (AMP) loss scaling and op casting + amp_autocast = suppress # do nothing + loss_scaler = None + if use_amp == 'apex': + model, optimizer = amp.initialize(model, optimizer, opt_level='O1') + loss_scaler = ApexScaler() + if args.local_rank == 0: + _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') + elif use_amp == 'native': + amp_autocast = torch.cuda.amp.autocast + loss_scaler = NativeScaler() + if args.local_rank == 0: + _logger.info('Using native Torch AMP. Training in mixed precision.') + else: + if args.local_rank == 0: + _logger.info('AMP not enabled. Training in float32.') + + + # optionally resume from a checkpoint + resume_epoch = None + if args.resume: + resume_epoch = resume_checkpoint( + model, args.resume, + optimizer=None if args.no_resume_opt else optimizer, + loss_scaler=None if args.no_resume_opt else loss_scaler, + log_info=args.local_rank == 0) + + # setup exponential moving average of model weights, SWA could be used here too + model_ema = None + if args.model_ema: + # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper + model_ema = utils.ModelEmaV2( + model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None) + if args.resume: + load_checkpoint(model_ema.module, args.resume, use_ema=True) + + # setup distributed training + if args.distributed: + if has_apex and use_amp == 'apex': + # Apex DDP preferred unless native amp is activated + if args.local_rank == 0: + _logger.info("Using NVIDIA APEX DistributedDataParallel.") + model = ApexDDP(model, delay_allreduce=True) + else: + if args.local_rank == 0: + _logger.info("Using native Torch DistributedDataParallel.") + model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb) + # NOTE: EMA model does not need to be wrapped by DDP + + # setup learning rate schedule and starting epoch + lr_scheduler, num_epochs = create_scheduler(args, optimizer) + start_epoch = 0 + if args.start_epoch is not None: + # a specified start_epoch will always override the resume epoch + start_epoch = args.start_epoch + elif resume_epoch is not None: + start_epoch = resume_epoch + if lr_scheduler is not None and start_epoch > 0: + lr_scheduler.step(start_epoch) + + if args.local_rank == 0: + _logger.info('Scheduled epochs: {}'.format(num_epochs)) + + # create the train and eval datasets + dataset_train = create_dataset( + args.dataset, root=args.data_dir, split=args.train_split, is_training=True, + class_map=args.class_map, + download=args.dataset_download, + batch_size=args.batch_size, + repeats=args.epoch_repeats) + dataset_eval = create_dataset( + args.dataset, root=args.data_dir, split=args.val_split, is_training=False, + class_map=args.class_map, + download=args.dataset_download, + batch_size=args.batch_size) + + total_batch_size = args.batch_size * args.grad_accum_steps * args.world_size + num_training_steps_per_epoch = len(dataset_train) // total_batch_size + if args.local_rank == 0: + _logger.info('Total batch size: {}'.format(total_batch_size)) + + # setup mixup / cutmix + collate_fn = None + mixup_fn = None + mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None + if mixup_active: + mixup_args = dict( + mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, + prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, + label_smoothing=args.smoothing, num_classes=args.num_classes) + if args.prefetcher: + assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) + collate_fn = FastCollateMixup(**mixup_args) + else: + mixup_fn = Mixup(**mixup_args) + + # wrap dataset in AugMix helper + if num_aug_splits > 1: + dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) + + # create data loaders w/ augmentation pipeiine + train_interpolation = args.train_interpolation + if args.no_aug or not train_interpolation: + train_interpolation = data_config['interpolation'] + loader_train = create_loader( + dataset_train, + input_size=data_config['input_size'], + batch_size=args.batch_size, + is_training=True, + use_prefetcher=args.prefetcher, + no_aug=args.no_aug, + re_prob=args.reprob, + re_mode=args.remode, + re_count=args.recount, + re_split=args.resplit, + scale=args.scale, + ratio=args.ratio, + hflip=args.hflip, + vflip=args.vflip, + color_jitter=args.color_jitter, + auto_augment=args.aa, + num_aug_repeats=args.aug_repeats, + num_aug_splits=num_aug_splits, + interpolation=train_interpolation, + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + distributed=args.distributed, + collate_fn=collate_fn, + pin_memory=args.pin_mem, + use_multi_epochs_loader=args.use_multi_epochs_loader, + worker_seeding=args.worker_seeding, + ) + + loader_eval = create_loader( + dataset_eval, + input_size=data_config['input_size'], + batch_size=args.validation_batch_size or args.batch_size, + is_training=False, + use_prefetcher=args.prefetcher, + interpolation=data_config['interpolation'], + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + distributed=args.distributed, + crop_pct=data_config['crop_pct'], + pin_memory=args.pin_mem, + ) + + # setup loss function + if args.jsd_loss: + assert num_aug_splits > 1 # JSD only valid with aug splits set + train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing) + elif mixup_active: + # smoothing is handled with mixup target transform which outputs sparse, soft targets + if args.bce_loss: + train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh) + else: + train_loss_fn = SoftTargetCrossEntropy() + elif args.smoothing: + if args.bce_loss: + train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh) + else: + train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing) + else: + train_loss_fn = nn.CrossEntropyLoss() + train_loss_fn = train_loss_fn.cuda() + validate_loss_fn = nn.CrossEntropyLoss().cuda() + + # setup checkpoint saver and eval metric tracking + eval_metric = args.eval_metric + best_metric = None + best_epoch = None + saver = None + output_dir = None + if args.rank == 0: + if args.experiment: + exp_name = args.experiment + else: + exp_name = '-'.join([ + datetime.now().strftime("%Y%m%d-%H%M%S"), + safe_model_name(args.model), + str(data_config['input_size'][-1]) + ]) + output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name) + decreasing = True if eval_metric == 'loss' else False + saver = utils.CheckpointSaver( + model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, + checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist) + with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: + f.write(args_text) + + try: + for epoch in range(start_epoch, num_epochs): + if args.distributed and hasattr(loader_train.sampler, 'set_epoch'): + loader_train.sampler.set_epoch(epoch) + + train_metrics = train_one_epoch( + epoch, model, loader_train, optimizer, train_loss_fn, args, + lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, + amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn, + grad_accum_steps=args.grad_accum_steps, num_training_steps_per_epoch=num_training_steps_per_epoch + ) + + if args.distributed and args.dist_bn in ('broadcast', 'reduce'): + if args.local_rank == 0: + _logger.info("Distributing BatchNorm running means and vars") + utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce') + + eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) + + if model_ema is not None and not args.model_ema_force_cpu: + if args.distributed and args.dist_bn in ('broadcast', 'reduce'): + utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') + ema_eval_metrics = validate( + model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') + eval_metrics = ema_eval_metrics + + if lr_scheduler is not None: + # step LR for next epoch + lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) + + if output_dir is not None: + utils.update_summary( + epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), + write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb) + + if saver is not None: + # save proper checkpoint with eval metric + save_metric = eval_metrics[eval_metric] + best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric) + + except KeyboardInterrupt: + pass + if best_metric is not None: + _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) + + +def train_one_epoch( + epoch, model, loader, optimizer, loss_fn, args, + lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress, + loss_scaler=None, model_ema=None, mixup_fn=None, + grad_accum_steps=1, num_training_steps_per_epoch=None): + + if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: + if args.prefetcher and loader.mixup_enabled: + loader.mixup_enabled = False + elif mixup_fn is not None: + mixup_fn.mixup_enabled = False + + second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order + batch_time_m = utils.AverageMeter() + data_time_m = utils.AverageMeter() + losses_m = utils.AverageMeter() + + model.train() + optimizer.zero_grad() + + end = time.time() + last_idx = len(loader) - 1 + num_updates = epoch * len(loader) + for batch_idx, (input, target) in enumerate(loader): + step = batch_idx // grad_accum_steps + if step >= num_training_steps_per_epoch: + continue + # last_batch = batch_idx == last_idx + last_batch = ((batch_idx + 1) // grad_accum_steps) == num_training_steps_per_epoch + data_time_m.update(time.time() - end) + if not args.prefetcher: + input, target = input.cuda(), target.cuda() + if mixup_fn is not None: + input, target = mixup_fn(input, target) + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) + + with amp_autocast(): + output = model(input) + loss = loss_fn(output, target) + + if not args.distributed: + losses_m.update(loss.item(), input.size(0)) + + + update_grad = (batch_idx + 1) % grad_accum_steps == 0 + loss_update = loss / grad_accum_steps + if loss_scaler is not None: + loss_scaler( + loss_update, optimizer, + clip_grad=args.clip_grad, clip_mode=args.clip_mode, + parameters=model_parameters(model, exclude_head='agc' in args.clip_mode), + create_graph=second_order, update_grad=update_grad) + else: + loss_update.backward(create_graph=second_order) + if update_grad: + if args.clip_grad is not None: + utils.dispatch_clip_grad( + model_parameters(model, exclude_head='agc' in args.clip_mode), + value=args.clip_grad, mode=args.clip_mode) + optimizer.step() + + if update_grad: + optimizer.zero_grad() + if model_ema is not None: + model_ema.update(model) + + torch.cuda.synchronize() + num_updates += 1 + batch_time_m.update(time.time() - end) + if last_batch or batch_idx % args.log_interval == 0: + lrl = [param_group['lr'] for param_group in optimizer.param_groups] + lr = sum(lrl) / len(lrl) + + if args.distributed: + reduced_loss = utils.reduce_tensor(loss.data, args.world_size) + losses_m.update(reduced_loss.item(), input.size(0)) + + if args.local_rank == 0: + _logger.info( + 'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' + 'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) ' + 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ' + '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' + 'LR: {lr:.3e} ' + 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( + epoch, + batch_idx, len(loader), + 100. * batch_idx / last_idx, + loss=losses_m, + batch_time=batch_time_m, + rate=input.size(0) * args.world_size / batch_time_m.val, + rate_avg=input.size(0) * args.world_size / batch_time_m.avg, + lr=lr, + data_time=data_time_m)) + + if args.save_images and output_dir: + torchvision.utils.save_image( + input, + os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), + padding=0, + normalize=True) + + if saver is not None and args.recovery_interval and ( + last_batch or (batch_idx + 1) % args.recovery_interval == 0): + saver.save_recovery(epoch, batch_idx=batch_idx) + + if lr_scheduler is not None: + lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) + + end = time.time() + # end for + + if hasattr(optimizer, 'sync_lookahead'): + optimizer.sync_lookahead() + + return OrderedDict([('loss', losses_m.avg)]) + + +def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''): + batch_time_m = utils.AverageMeter() + losses_m = utils.AverageMeter() + top1_m = utils.AverageMeter() + top5_m = utils.AverageMeter() + + model.eval() + + end = time.time() + last_idx = len(loader) - 1 + with torch.no_grad(): + for batch_idx, (input, target) in enumerate(loader): + last_batch = batch_idx == last_idx + if not args.prefetcher: + input = input.cuda() + target = target.cuda() + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) + + with amp_autocast(): + output = model(input) + if isinstance(output, (tuple, list)): + output = output[0] + + # augmentation reduction + reduce_factor = args.tta + if reduce_factor > 1: + output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) + target = target[0:target.size(0):reduce_factor] + + loss = loss_fn(output, target) + acc1, acc5 = utils.accuracy(output, target, topk=(1, 5)) + + if args.distributed: + reduced_loss = utils.reduce_tensor(loss.data, args.world_size) + acc1 = utils.reduce_tensor(acc1, args.world_size) + acc5 = utils.reduce_tensor(acc5, args.world_size) + else: + reduced_loss = loss.data + + torch.cuda.synchronize() + + losses_m.update(reduced_loss.item(), input.size(0)) + top1_m.update(acc1.item(), output.size(0)) + top5_m.update(acc5.item(), output.size(0)) + + batch_time_m.update(time.time() - end) + end = time.time() + if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0): + log_name = 'Test' + log_suffix + _logger.info( + '{0}: [{1:>4d}/{2}] ' + 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' + 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' + 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' + 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( + log_name, batch_idx, last_idx, batch_time=batch_time_m, + loss=losses_m, top1=top1_m, top5=top5_m)) + + metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)]) + + return metrics + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..8e55df8 --- /dev/null +++ b/utils.py @@ -0,0 +1,61 @@ +# Modifed form timm and swin repo. + +""" CUDA / AMP utils + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch + +try: + from apex import amp + has_apex = True +except ImportError: + amp = None + has_apex = False + +from timm.utils.clip_grad import dispatch_clip_grad + + +class ApexScalerAccum: + state_dict_key = "amp" + + def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, + update_grad=True): + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + dispatch_clip_grad(amp.master_params(optimizer), clip_grad, mode=clip_mode) + optimizer.step() + + def state_dict(self): + if 'state_dict' in amp.__dict__: + return amp.state_dict() + + def load_state_dict(self, state_dict): + if 'load_state_dict' in amp.__dict__: + amp.load_state_dict(state_dict) + + +class NativeScalerAccum: + state_dict_key = "amp_scaler" + + def __init__(self): + self._scaler = torch.cuda.amp.GradScaler() + + def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, + update_grad=True): + self._scaler.scale(loss).backward(create_graph=create_graph) + if update_grad: + if clip_grad is not None: + assert parameters is not None + self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place + dispatch_clip_grad(parameters, clip_grad, mode=clip_mode) + self._scaler.step(optimizer) + self._scaler.update() + + def state_dict(self): + return self._scaler.state_dict() + + def load_state_dict(self, state_dict): + self._scaler.load_state_dict(state_dict) diff --git a/validate.py b/validate.py new file mode 100644 index 0000000..84ef470 --- /dev/null +++ b/validate.py @@ -0,0 +1,400 @@ +#!/usr/bin/env python3 +# This script is mostly copied from https://github.com/rwightman/pytorch-image-models/blob/v0.6.11/validate.py +""" ImageNet Validation Script + +This is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained +models or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes +canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit. + +Hacked together by Ross Wightman (https://github.com/rwightman) +""" +import argparse +import os +import csv +import glob +import json +import time +import logging +import torch +import torch.nn as nn +import torch.nn.parallel +from collections import OrderedDict +from contextlib import suppress + +from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm +from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet +from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\ + decay_batch_step, check_batch_size_retry + +import models + +has_apex = False +try: + from apex import amp + has_apex = True +except ImportError: + pass + +has_native_amp = False +try: + if getattr(torch.cuda.amp, 'autocast') is not None: + has_native_amp = True +except AttributeError: + pass + +try: + from functorch.compile import memory_efficient_fusion + has_functorch = True +except ImportError as e: + has_functorch = False + +torch.backends.cudnn.benchmark = True +_logger = logging.getLogger('validate') + + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') +parser.add_argument('data', metavar='DIR', + help='path to dataset') +parser.add_argument('--dataset', '-d', metavar='NAME', default='', + help='dataset type (default: ImageFolder/ImageTar if empty)') +parser.add_argument('--split', metavar='NAME', default='validation', + help='dataset split (default: validation)') +parser.add_argument('--dataset-download', action='store_true', default=False, + help='Allow download of dataset for torch/ and tfds/ datasets that support it.') +parser.add_argument('--model', '-m', metavar='NAME', default='dpn92', + help='model architecture (default: dpn92)') +parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', + help='number of data loading workers (default: 2)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', help='mini-batch size (default: 256)') +parser.add_argument('--img-size', default=None, type=int, + metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--input-size', default=None, nargs=3, type=int, + metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') +parser.add_argument('--use-train-size', action='store_true', default=False, + help='force use of train input size, even when test size is specified in pretrained cfg') +parser.add_argument('--crop-pct', default=None, type=float, + metavar='N', help='Input image center crop pct') +parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override mean pixel value of dataset') +parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', + help='Override std deviation of of dataset') +parser.add_argument('--interpolation', default='', type=str, metavar='NAME', + help='Image resize interpolation type (overrides model)') +parser.add_argument('--num-classes', type=int, default=None, + help='Number classes in dataset') +parser.add_argument('--class-map', default='', type=str, metavar='FILENAME', + help='path to class to idx mapping file (default: "")') +parser.add_argument('--gp', default=None, type=str, metavar='POOL', + help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') +parser.add_argument('--log-freq', default=10, type=int, + metavar='N', help='batch logging frequency (default: 10)') +parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='use pre-trained model') +parser.add_argument('--num-gpu', type=int, default=1, + help='Number of GPUS to use') +parser.add_argument('--test-pool', dest='test_pool', action='store_true', + help='enable test time pool') +parser.add_argument('--no-prefetcher', action='store_true', default=False, + help='disable fast prefetcher') +parser.add_argument('--pin-mem', action='store_true', default=False, + help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') +parser.add_argument('--channels-last', action='store_true', default=False, + help='Use channels_last memory layout') +parser.add_argument('--amp', action='store_true', default=False, + help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.') +parser.add_argument('--apex-amp', action='store_true', default=False, + help='Use NVIDIA Apex AMP mixed precision') +parser.add_argument('--native-amp', action='store_true', default=False, + help='Use Native Torch AMP mixed precision') +parser.add_argument('--tf-preprocessing', action='store_true', default=False, + help='Use Tensorflow preprocessing pipeline (require CPU TF installed') +parser.add_argument('--use-ema', dest='use_ema', action='store_true', + help='use ema version of weights if present') +scripting_group = parser.add_mutually_exclusive_group() +scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true', + help='torch.jit.script the full model') +scripting_group.add_argument('--aot-autograd', default=False, action='store_true', + help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)") +parser.add_argument('--fuser', default='', type=str, + help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") +parser.add_argument('--fast-norm', default=False, action='store_true', + help='enable experimental fast-norm') +parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', + help='Output csv file for validation results (summary)') +parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME', + help='Real labels JSON file for imagenet evaluation') +parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME', + help='Valid label indices txt file for validation of partial label space') +parser.add_argument('--retry', default=False, action='store_true', + help='Enable batch size decay & retry for single model validation') + + +def validate(args): + # might as well try to validate something + args.pretrained = args.pretrained or not args.checkpoint + args.prefetcher = not args.no_prefetcher + amp_autocast = suppress # do nothing + if args.amp: + if has_native_amp: + args.native_amp = True + elif has_apex: + args.apex_amp = True + else: + _logger.warning("Neither APEX or Native Torch AMP is available.") + assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set." + if args.native_amp: + amp_autocast = torch.cuda.amp.autocast + _logger.info('Validating in mixed precision with native PyTorch AMP.') + elif args.apex_amp: + _logger.info('Validating in mixed precision with NVIDIA APEX AMP.') + else: + _logger.info('Validating in float32. AMP not enabled.') + + if args.fuser: + set_jit_fuser(args.fuser) + if args.fast_norm: + set_fast_norm() + + # create model + model = create_model( + args.model, + pretrained=args.pretrained, + num_classes=args.num_classes, + in_chans=3, + global_pool=args.gp, + scriptable=args.torchscript) + if args.num_classes is None: + assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' + args.num_classes = model.num_classes + + if args.checkpoint: + load_checkpoint(model, args.checkpoint, args.use_ema) + + param_count = sum([m.numel() for m in model.parameters()]) + _logger.info('Model %s created, param count: %d' % (args.model, param_count)) + + data_config = resolve_data_config( + vars(args), + model=model, + use_test_size=not args.use_train_size, + verbose=True + ) + test_time_pool = False + if args.test_pool: + model, test_time_pool = apply_test_time_pool(model, data_config) + + if args.torchscript: + torch.jit.optimized_execution(True) + model = torch.jit.script(model) + if args.aot_autograd: + assert has_functorch, "functorch is needed for --aot-autograd" + model = memory_efficient_fusion(model) + + model = model.cuda() + if args.apex_amp: + model = amp.initialize(model, opt_level='O1') + + if args.channels_last: + model = model.to(memory_format=torch.channels_last) + + if args.num_gpu > 1: + model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))) + + criterion = nn.CrossEntropyLoss().cuda() + + dataset = create_dataset( + root=args.data, name=args.dataset, split=args.split, + download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map) + + if args.valid_labels: + with open(args.valid_labels, 'r') as f: + valid_labels = {int(line.rstrip()) for line in f} + valid_labels = [i in valid_labels for i in range(args.num_classes)] + else: + valid_labels = None + + if args.real_labels: + real_labels = RealLabelsImagenet(dataset.filenames(basename=True), real_json=args.real_labels) + else: + real_labels = None + + crop_pct = 1.0 if test_time_pool else data_config['crop_pct'] + loader = create_loader( + dataset, + input_size=data_config['input_size'], + batch_size=args.batch_size, + use_prefetcher=args.prefetcher, + interpolation=data_config['interpolation'], + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + crop_pct=crop_pct, + pin_memory=args.pin_mem, + tf_preprocessing=args.tf_preprocessing) + + batch_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + + model.eval() + with torch.no_grad(): + # warmup, reduce variability of first batch time, especially for comparing torchscript vs non + input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda() + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) + with amp_autocast(): + model(input) + + end = time.time() + for batch_idx, (input, target) in enumerate(loader): + if args.no_prefetcher: + target = target.cuda() + input = input.cuda() + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) + + # compute output + with amp_autocast(): + output = model(input) + + if valid_labels is not None: + output = output[:, valid_labels] + loss = criterion(output, target) + + if real_labels is not None: + real_labels.add_result(output) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5)) + losses.update(loss.item(), input.size(0)) + top1.update(acc1.item(), input.size(0)) + top5.update(acc5.item(), input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if batch_idx % args.log_freq == 0: + _logger.info( + 'Test: [{0:>4d}/{1}] ' + 'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' + 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' + 'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) ' + 'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format( + batch_idx, len(loader), batch_time=batch_time, + rate_avg=input.size(0) / batch_time.avg, + loss=losses, top1=top1, top5=top5)) + + if real_labels is not None: + # real labels mode replaces topk values at the end + top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5) + else: + top1a, top5a = top1.avg, top5.avg + results = OrderedDict( + model=args.model, + top1=round(top1a, 4), top1_err=round(100 - top1a, 4), + top5=round(top5a, 4), top5_err=round(100 - top5a, 4), + param_count=round(param_count / 1e6, 2), + img_size=data_config['input_size'][-1], + crop_pct=crop_pct, + interpolation=data_config['interpolation']) + + _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format( + results['top1'], results['top1_err'], results['top5'], results['top5_err'])) + + return results + + +def _try_run(args, initial_batch_size): + batch_size = initial_batch_size + results = OrderedDict() + error_str = 'Unknown' + while batch_size: + args.batch_size = batch_size * args.num_gpu # multiply by num-gpu for DataParallel case + try: + torch.cuda.empty_cache() + results = validate(args) + return results + except RuntimeError as e: + error_str = str(e) + _logger.error(f'"{error_str}" while running validation.') + if not check_batch_size_retry(error_str): + break + batch_size = decay_batch_step(batch_size) + _logger.warning(f'Reducing batch size to {batch_size} for retry.') + results['error'] = error_str + _logger.error(f'{args.model} failed to validate ({error_str}).') + return results + + +def main(): + setup_default_logging() + args = parser.parse_args() + model_cfgs = [] + model_names = [] + if os.path.isdir(args.checkpoint): + # validate all checkpoints in a path with same model + checkpoints = glob.glob(args.checkpoint + '/*.pth.tar') + checkpoints += glob.glob(args.checkpoint + '/*.pth') + model_names = list_models(args.model) + model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)] + else: + if args.model == 'all': + # validate all models in a list of names with pretrained checkpoints + args.pretrained = True + model_names = list_models(pretrained=True, exclude_filters=['*_in21k', '*_in22k', '*_dino']) + model_cfgs = [(n, '') for n in model_names] + elif not is_model(args.model): + # model name doesn't exist, try as wildcard filter + model_names = list_models(args.model) + model_cfgs = [(n, '') for n in model_names] + + if not model_cfgs and os.path.isfile(args.model): + with open(args.model) as f: + model_names = [line.rstrip() for line in f] + model_cfgs = [(n, None) for n in model_names if n] + + if len(model_cfgs): + results_file = args.results_file or './results-all.csv' + _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names))) + results = [] + try: + initial_batch_size = args.batch_size + for m, c in model_cfgs: + args.model = m + args.checkpoint = c + r = _try_run(args, initial_batch_size) + if 'error' in r: + continue + if args.checkpoint: + r['checkpoint'] = args.checkpoint + results.append(r) + except KeyboardInterrupt as e: + pass + results = sorted(results, key=lambda x: x['top1'], reverse=True) + if len(results): + write_results(results_file, results) + else: + if args.retry: + results = _try_run(args, args.batch_size) + else: + results = validate(args) + # output results in JSON to stdout w/ delimiter for runner script + print(f'--result\n{json.dumps(results, indent=4)}') + + +def write_results(results_file, results): + with open(results_file, mode='w') as cf: + dw = csv.DictWriter(cf, fieldnames=results[0].keys()) + dw.writeheader() + for r in results: + dw.writerow(r) + cf.flush() + + +if __name__ == '__main__': + main() \ No newline at end of file