Showing 569 changed files with 85,430 additions and 42 deletions.
@@ -0,0 +1,141 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

./outputs
outputs/*
figures/*
plots/*
.idea/*
checkpoints/
*.out
*.err
.vscode/
*.xml
*.pth
@@ -1,84 +1,170 @@
# Fast Vision Transformers with HiLo Attention 👋

[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)

This is the official PyTorch implementation of [Fast Vision Transformers with HiLo Attention](https://arxiv.org/abs/2205.13213).

By [Zizheng Pan](https://scholar.google.com.au/citations?user=w_VMopoAAAAJ&hl=en), [Jianfei Cai](https://scholar.google.com/citations?user=N6czCoUAAAAJ&hl=en), and [Bohan Zhuang](https://scholar.google.com.au/citations?user=DFuDBBwAAAAJ).

## A Gentle Introduction

![hilo](.github/arch.png)

We introduce LITv2, a simple and effective ViT that performs favourably against existing state-of-the-art methods across a spectrum of model sizes while running faster.

![hilo](.github/hilo.png)

The core of LITv2 is **HiLo attention**. HiLo is inspired by the insight that high frequencies in an image capture local fine details while low frequencies capture global structures, yet a standard multi-head self-attention layer ignores this distinction. We therefore disentangle the high- and low-frequency patterns within an attention layer by splitting the heads into two groups: one group encodes high frequencies via self-attention inside each local window, while the other models global structure by attending each query position in the feature map to low-frequency keys average-pooled from each window.
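
To make the head-splitting concrete, below is a minimal, self-contained PyTorch sketch of the idea. It is an illustration rather than the repository's actual HiLo module: shapes are simplified, there is no relative position bias, and the `alpha` split ratio and class name are assumptions for the example.

```python
# Minimal sketch of the HiLo head-splitting idea (simplified; not the official implementation).
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleHiLoAttention(nn.Module):
    """Splits heads into a high-frequency group (window self-attention)
    and a low-frequency group (attention to window-pooled keys/values)."""

    def __init__(self, dim, num_heads=8, window=2, alpha=0.5):
        super().__init__()
        self.window = window
        self.head_dim = dim // num_heads
        self.lo_heads = int(num_heads * alpha)       # heads assigned to low-frequency (global) attention
        self.hi_heads = num_heads - self.lo_heads    # heads assigned to high-frequency (local) attention
        self.lo_dim = self.lo_heads * self.head_dim
        self.hi_dim = self.hi_heads * self.head_dim
        self.scale = self.head_dim ** -0.5

        # Separate projections for the two branches.
        self.hi_qkv = nn.Linear(dim, self.hi_dim * 3)
        self.hi_proj = nn.Linear(self.hi_dim, self.hi_dim)
        self.lo_q = nn.Linear(dim, self.lo_dim)
        self.lo_kv = nn.Linear(dim, self.lo_dim * 2)
        self.lo_proj = nn.Linear(self.lo_dim, self.lo_dim)

    def forward(self, x):
        # x: (B, H, W, C) feature map, with H and W divisible by the window size.
        B, H, W, C = x.shape
        w = self.window

        # --- Hi-Fi branch: self-attention inside each non-overlapping local window ---
        xw = x.view(B, H // w, w, W // w, w, C).permute(0, 1, 3, 2, 4, 5)
        xw = xw.reshape(B * (H // w) * (W // w), w * w, C)                 # (B*nW, w*w, C)
        qkv = self.hi_qkv(xw).reshape(xw.shape[0], w * w, 3, self.hi_heads, self.head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)                               # each (B*nW, heads, w*w, hd)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        hi = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(xw.shape[0], w * w, self.hi_dim)
        hi = self.hi_proj(hi)
        hi = hi.view(B, H // w, W // w, w, w, self.hi_dim).permute(0, 1, 3, 2, 4, 5)
        hi = hi.reshape(B, H, W, self.hi_dim)

        # --- Lo-Fi branch: every query attends to keys/values average-pooled from each window ---
        pooled = F.avg_pool2d(x.permute(0, 3, 1, 2), kernel_size=w).flatten(2).transpose(1, 2)  # (B, nW, C)
        q = self.lo_q(x.reshape(B, H * W, C)).reshape(B, H * W, self.lo_heads, self.head_dim).transpose(1, 2)
        kv = self.lo_kv(pooled).reshape(B, -1, 2, self.lo_heads, self.head_dim)
        k, v = kv.permute(2, 0, 3, 1, 4)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        lo = (attn.softmax(dim=-1) @ v).transpose(1, 2).reshape(B, H, W, self.lo_dim)
        lo = self.lo_proj(lo)

        # Concatenate the two groups of heads along the channel dimension.
        return torch.cat([hi, lo], dim=-1)


if __name__ == "__main__":
    x = torch.randn(2, 14, 14, 96)
    print(SimpleHiLoAttention(dim=96, num_heads=8, window=2)(x).shape)  # torch.Size([2, 14, 14, 96])
```

In the sketch, `alpha` controls how many heads go to the low-frequency branch; the high-frequency branch only ever attends within a small window, which is what keeps it fast on high-resolution inputs.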
## News

- **16/06/2022.** We release the source code for classification/detection/segmentation, along with the pretrained weights. Any issues are welcome!

## Installation

### Requirements

- Linux with Python ≥ 3.6
- PyTorch 1.8.1
- CUDA 11.1
- An NVIDIA GPU

### Conda environment setup

**Note**: You can use the same environment to debug [LITv1](https://github.com/ziplab/LIT). Otherwise, you can create a new Python virtual environment with the following script.

```bash
conda create -n lit python=3.7
conda activate lit

# Install PyTorch and TorchVision
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html

pip install timm==0.3.2
pip install ninja
pip install tensorboard

# Install NVIDIA apex
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
cd ../
rm -rf apex/

# Build Deformable Convolution
cd mm_modules/DCN
python setup.py build install

pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8
```
# Getting Started

For image classification on ImageNet, please refer to [classification](https://github.com/ziplab/LITv2/tree/main/classification).

For object detection on COCO 2017, please refer to [detection](https://github.com/ziplab/LITv2/tree/main/detection).

For semantic segmentation on ADE20K, please refer to [segmentation](https://github.com/ziplab/LITv2/tree/main/segmentation).
## Results and Model Zoo

**Note:** For your convenience, you can find all models and logs on [Google Drive](https://drive.google.com/drive/folders/1VAtrPWEqxi-6q6luwEVdYvkBYedApwbU?usp=sharing) (4.8G in total). Alternatively, we also provide download links from GitHub.
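
As a quick sanity check after downloading, a classification checkpoint can be inspected with plain PyTorch. This is a hedged sketch: the filename below is a placeholder for whichever weight file you downloaded, and the exact key layout may differ per model.

```python
# Hedged sketch: inspect a downloaded checkpoint with plain PyTorch.
# "litv2_s.pth" is a placeholder filename for whichever weight file you downloaded.
import torch

ckpt = torch.load("litv2_s.pth", map_location="cpu")

# Checkpoints are often saved as {"model": state_dict, ...}; fall back to the raw dict otherwise.
state_dict = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt

num_params = sum(p.numel() for p in state_dict.values())
print(f"{len(state_dict)} tensors, {num_params / 1e6:.1f}M parameters")
print(list(state_dict.keys())[:5])  # peek at the first few parameter names
```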
### Image Classification on ImageNet-1K

All models are trained for 300 epochs with a total batch size of 1024 on 8 V100 GPUs.

| Model | Resolution | Params (M) | FLOPs (G) | Throughput (imgs/s) | Train Mem (GB) | Test Mem (GB) | Top-1 (%) | Download |
| ------- | ---------- | ---------- | --------- | ------------------- | -------------- | ------------- | --------- | ----------- |
| LITv2-S | 224 | 28 | 3.7 | 1,471 | 5.1 | 1.2 | 82.0 | model & log |
| LITv2-M | 224 | 49 | 7.5 | 812 | 8.8 | 1.4 | 83.3 | model & log |
| LITv2-B | 224 | 87 | 13.2 | 602 | 12.2 | 2.1 | 83.6 | model & log |
| LITv2-B | 384 | 87 | 39.7 | 198 | 35.8 | 4.6 | 84.7 | model |

> By default, the throughput and memory footprint are tested on one RTX 3090 with a batch size of 64. Memory is measured as the peak usage reported by `torch.cuda.max_memory_allocated()`. Throughput is averaged over 30 runs.
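
If you want to reproduce this kind of measurement yourself, a minimal sketch of the protocol above could look like the following. The `benchmark` helper and the ResNet-50 stand-in are illustrative assumptions, not the repository's benchmarking script; substitute a LITv2 model built from this repo.

```python
# Hedged sketch of the measurement protocol described above (not the official script).
import time
import torch
import torchvision


@torch.no_grad()
def benchmark(model, batch_size=64, resolution=224, runs=30, warmup=10):
    device = torch.device("cuda")
    model = model.to(device).eval()
    x = torch.randn(batch_size, 3, resolution, resolution, device=device)

    torch.cuda.reset_peak_memory_stats(device)
    for _ in range(warmup):                      # warm-up iterations are not timed
        model(x)
    torch.cuda.synchronize(device)

    start = time.time()
    for _ in range(runs):                        # throughput is averaged over `runs` forward passes
        model(x)
    torch.cuda.synchronize(device)
    elapsed = time.time() - start

    throughput = runs * batch_size / elapsed                            # images per second
    peak_mem_gb = torch.cuda.max_memory_allocated(device) / 1024 ** 3   # peak memory in GB
    return throughput, peak_mem_gb


if __name__ == "__main__":
    # Any ImageNet classifier works here; replace with a LITv2 model for the numbers above.
    model = torchvision.models.resnet50()
    print(benchmark(model))
```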
### Object Detection on COCO 2017

All models are trained with a 1x schedule (12 epochs) and a total batch size of 16 on 8 V100 GPUs.

#### RetinaNet

| Backbone | Window Size | Params (M) | FLOPs (G) | FPS | box AP | Config | Download |
| -------- | ----------- | ---------- | --------- | ---- | ------ | ------ | ----------- |
| LITv2-S | 2 | 38 | 242 | 18.7 | 44.0 | link | model & log |
| LITv2-S | 4 | 38 | 230 | 20.4 | 43.7 | link | model & log |
| LITv2-M | 2 | 59 | 348 | 12.2 | 46.0 | link | model & log |
| LITv2-M | 4 | 59 | 312 | 14.8 | 45.8 | link | model & log |
| LITv2-B | 2 | 97 | 481 | 9.5 | 46.7 | link | model & log |
| LITv2-B | 4 | 97 | 430 | 11.8 | 46.3 | link | model & log |

#### Mask R-CNN

| Backbone | Window Size | Params (M) | FLOPs (G) | FPS | box AP | mask AP | Config | Download |
| -------- | ----------- | ---------- | --------- | ---- | ------ | ------- | ------ | ----------- |
| LITv2-S | 2 | 47 | 261 | 18.7 | 44.9 | 40.8 | link | model & log |
| LITv2-S | 4 | 47 | 249 | 21.9 | 44.7 | 40.7 | link | model & log |
| LITv2-M | 2 | 68 | 367 | 12.6 | 46.8 | 42.3 | link | model & log |
| LITv2-M | 4 | 68 | 315 | 16.0 | 46.5 | 42.0 | link | model & log |
| LITv2-B | 2 | 106 | 500 | 9.3 | 47.3 | 42.6 | link | model & log |
| LITv2-B | 4 | 106 | 449 | 11.5 | 46.8 | 42.3 | link | model & log |
### Semantic Segmentation on ADE20K

All models are trained for 80K iterations with a total batch size of 16 on 8 V100 GPUs.

| Backbone | Params (M) | FLOPs (G) | FPS | mIoU | Config | Download |
| -------- | ---------- | --------- | ---- | ---- | ------ | ----------- |
| LITv2-S | 31 | 41 | 42.6 | 44.3 | link | model & log |
| LITv2-M | 52 | 63 | 28.5 | 45.7 | link | model & log |
| LITv2-B | 90 | 93 | 27.5 | 47.2 | link | model & log |
## Citation

If you use LITv2 in your research, please consider citing the following BibTeX entry and giving us a star 🌟.

```BibTeX
@article{pan2022hilo,
  title={Fast Vision Transformers with HiLo Attention},
  author={Pan, Zizheng and Cai, Jianfei and Zhuang, Bohan},
  journal={arXiv preprint arXiv:2205.13213},
  year={2022}
}
```

If you find the code useful, please also consider the following BibTeX entry.

```BibTeX
@inproceedings{pan2022litv1,
  title={Less is More: Pay Less Attention in Vision Transformers},
  author={Pan, Zizheng and Zhuang, Bohan and He, Haoyu and Liu, Jing and Cai, Jianfei},
  booktitle={AAAI},
  year={2022}
}
```
## License

This repository is released under the Apache 2.0 license as found in the [LICENSE](https://github.com/ziplab/LITv2/blob/main/LICENSE) file.

## Acknowledgement

This repository is built upon [DeiT](https://github.com/facebookresearch/deit), [Swin](https://github.com/microsoft/Swin-Transformer) and [LIT](https://github.com/ziplab/LIT). We thank the authors for their open-sourced code.