This repository is the code implementation of the paper RSMamba: Remote Sensing Image Classification with State Space Model, which is based on the MMPretrain project.
The current branch has been tested on Linux with PyTorch 2.x and CUDA 12.1. It supports Python 3.8+ and is compatible with most CUDA versions.
If you find this project helpful, please give us a star ⭐️. Your support is our greatest motivation.
Main Features
- Consistent API interface and usage with MMPretrain
- Open-sourced RSMamba models of different sizes in the paper
- Support for training and testing on multiple datasets
Updates
🌟 2024.03.28 Released the RSMamba project, which is fully consistent with the API interface and usage of MMPretrain.
🌟 2024.03.29 Open-sourced the weight files of RSMamba models of different sizes in the paper.
TODO
- Open-source model training parameters
Table of Contents
- Introduction
- Updates
- TODO
- Table of Contents
- Installation
- Dataset Preparation
- Model Training
- Model Testing
- Image Prediction
- FAQ
- Acknowledgements
- Citation
- License
- Contact Us
- Linux system. Windows is not tested; whether it works depends on whether causal-conv1d and mamba-ssm can be installed
- Python 3.8+, recommended 3.11
- PyTorch 2.0 or higher, recommended 2.2
- CUDA 11.7 or higher, recommended 12.1
- MMCV 2.0 or higher, recommended 2.1
It is recommended to use Miniconda for installation. The following commands will create a virtual environment named rsmamba and install PyTorch and MMCV. The installation steps below assume CUDA 12.1. If your CUDA version is not 12.1, adjust the commands accordingly.
Note: If you are experienced with PyTorch and have already installed it, you can skip to the next section. Otherwise, you can follow the steps below.
Step 0: Install Miniconda.
Step 1: Create a virtual environment named rsmamba and activate it.
conda create -n rsmamba python=3.11 -y
conda activate rsmamba
Step 2: Install PyTorch 2.1.x.
Linux/Windows:
pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu121
Or
conda install pytorch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 pytorch-cuda=12.1 -c pytorch -c nvidia -y
Step 3: Install MMCV 2.1.x.
pip install -U openmim
mim install mmcv==2.1.0
# or
pip install mmcv==2.1.0 -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.1/index.html
Step 4: Install other dependencies.
pip install -U mat4py ipdb modelindex
pip install transformers==4.39.2
pip install causal-conv1d==1.2.0.post2
pip install mamba-ssm==1.2.0.post1
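After Step 4, it can help to confirm that the pinned packages actually landed in the rsmamba environment. The following is a minimal sketch using only the Python standard library. The helper name check_versions and the PINNED table are illustrative and not part of this repository; the distribution names and versions are copied from the install commands above.

```python
from importlib import metadata

# Versions pinned by the installation commands above (distribution name -> version).
PINNED = {
    "torch": "2.1.2",
    "mmcv": "2.1.0",
    "transformers": "4.39.2",
    "causal-conv1d": "1.2.0.post2",
    "mamba-ssm": "1.2.0.post1",
}


def check_versions(pinned):
    """Return {name: (expected, installed)}; installed is None if the package is missing."""
    report = {}
    for name, expected in pinned.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        report[name] = (expected, installed)
    return report


if __name__ == "__main__":
    for name, (expected, installed) in check_versions(PINNED).items():
        # Local build tags such as "+cu121" are ignored when comparing.
        ok = installed is not None and installed.split("+")[0] == expected
        status = "OK" if ok else "CHECK"
        print(f"{name:15s} expected {expected:12s} installed {installed}  {status}")
```

A CHECK line does not necessarily mean the setup is broken (a different but compatible version may be installed); it only flags a difference from the versions listed in this guide.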
You can download or clone the RSMamba repository.
git clone git@github.com:KyanChen/RSMamba.git
cd RSMamba
We describe how to prepare the remote sensing image classification datasets used in the paper.
- Image and annotation download link: UC Merced Dataset.
- Image and annotation download link: AID Dataset.
- Image and annotation download link: NWPU RESISC45 Dataset.
Note: The data folder of this project provides a small number of image annotation examples for the above datasets.
You can also choose other sources to download the data, but you need to organize the dataset in the following format:
${DATASET_ROOT} # Dataset root directory, for example: /home/username/data/UC
├── airplane
│ ├── airplane01.tif
│ ├── airplane02.tif
│ └── ...
├── ...
├── ...
├── ...
└── ...
Note: The datainfo folder of this project provides the dataset split files. You can also use the Python script to split the dataset.
If you want to use other datasets, you can refer to the MMPretrain documentation for dataset preparation.
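For a dataset organized in the class-per-folder layout shown above, a split can be sketched as follows. This is an illustration under stated assumptions, not the project's own script: the function name split_dataset, the 0.8 training ratio, and the "class_name/file_name" entry format are all hypothetical, and the split files shipped in datainfo may use a different format.

```python
import random
from pathlib import Path


def split_dataset(root, train_ratio=0.8, seed=0):
    """Split a class-per-folder dataset into train and test lists.

    Each entry is "<class_name>/<file_name>". The split is done per class so
    that every class keeps roughly the same train/test proportion.
    """
    rng = random.Random(seed)  # fixed seed makes the split reproducible
    train, test = [], []
    for class_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        files = sorted(f.name for f in class_dir.iterdir() if f.is_file())
        rng.shuffle(files)
        n_train = int(len(files) * train_ratio)
        train += [f"{class_dir.name}/{name}" for name in files[:n_train]]
        test += [f"{class_dir.name}/{name}" for name in files[n_train:]]
    return train, test
```

Calling split_dataset("/home/username/data/UC") returns two lists that can be written to text files, one entry per line. Splitting per class rather than over the whole file list keeps small classes from ending up entirely in one subset.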
We provide the configuration files of the RSMamba models with different parameter sizes from the paper, which can be found in the configuration files folder. The Config file is fully consistent with the API interface and usage of MMPretrain. Below we explain some of the main parameters. If you want to know more about the parameters, you can refer to the MMPretrain documentation.
Parameter Parsing:
- work_dir: The output path of model training; generally no need to modify.
- code_root: The root directory of the code; change it to the absolute path of the root directory of this project.
- data_root: The root directory of the dataset; change it to the absolute path of the dataset root directory.
- batch_size: The batch size per GPU; adjust it according to the available GPU memory.
- max_epochs: The maximum number of training epochs; generally no need to modify.
- vis_backends/WandbVisBackend: Configuration of the web-based visualization tool. After uncommenting it, you need to register an account on the wandb official website; you can then view the visualization results during training in a web browser.
- model/backbone/arch: The type of the model's backbone network; set it according to the selected model. The options are b, l, and h.
- model/backbone/path_type: The path type of the model; set it according to the selected model.
- default_hooks-CheckpointHook: Configuration of checkpoint saving during training; generally no need to modify.
- num_classes: The number of categories in the dataset; set it according to your dataset.
- dataset_type: The type of the dataset; set it according to your dataset.
- resume: Whether to resume training; generally no need to modify.
- load_from: The path of the pre-trained checkpoint of the model; generally no need to modify.
- data_preprocessor/mean/std: The mean and standard deviation used in data preprocessing. They can be set according to the statistics of the dataset (refer to the Python script), but generally there is no need to modify them.
Some parameters are inherited from _base_; you can find them in the basic configuration files folder.
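To make the parameter list concrete, here is a hedged sketch of how these fields might look in an MMPretrain-style Python config. Only the parameter names come from the list above. Every value is a placeholder, and the exact nesting in the project's real config files may differ, so treat this as an orientation aid rather than a working RSMamba config.

```python
# Hypothetical values throughout; only the parameter names come from the list above.
work_dir = 'work_dirs/rsmamba/name_to_config'  # training outputs (logs, checkpoints)
code_root = '/path/to/RSMamba'                 # absolute path of this project
data_root = '/home/username/data/UC'           # absolute path of the dataset root
batch_size = 32                                # per-GPU batch size; lower it if memory runs out
max_epochs = 500                               # placeholder
num_classes = 21                               # UC Merced has 21 scene classes
dataset_type = 'CustomDataset'                 # placeholder dataset type
resume = False                                 # set to True to continue an interrupted run
load_from = None                               # or a path to a pre-trained checkpoint

vis_backends = [
    dict(type='LocalVisBackend'),
    # dict(type='WandbVisBackend'),            # uncomment after creating a wandb account
]

data_preprocessor = dict(
    num_classes=num_classes,
    mean=[123.675, 116.28, 103.53],            # ImageNet statistics as a placeholder
    std=[58.395, 57.12, 57.375],
)

model = dict(
    backbone=dict(
        arch='b',                              # one of 'b', 'l', 'h'
        # path_type=...,                       # set according to the selected model
    ),
    head=dict(num_classes=num_classes),        # assumed location of the class count
)

default_hooks = dict(
    checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=3),
)
```

Defining num_classes once and reusing it keeps the preprocessor and the head in sync when switching datasets.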
python tools/train.py configs/rsmamba/name_to_config.py # name_to_config.py is the configuration file you want to use
sh ./tools/dist_train.sh configs/rsmamba/name_to_config.py ${GPU_NUM} # name_to_config.py is the configuration file you want to use, GPU_NUM is the number of GPUs used
If you want to use other image classification models, you can refer to MMPretrain for model training, or you can put their Config files into the configs folder of this project and then train them as described above.
python tools/test.py configs/rsmamba/name_to_config.py ${CHECKPOINT_FILE} # name_to_config.py is the configuration file you want to use, CHECKPOINT_FILE is the checkpoint file you want to use
sh ./tools/dist_test.sh configs/rsmamba/name_to_config.py ${CHECKPOINT_FILE} ${GPU_NUM} # name_to_config.py is the configuration file you want to use, CHECKPOINT_FILE is the checkpoint file you want to use, GPU_NUM is the number of GPUs used
python demo/image_demo.py ${IMAGE_FILE} configs/rsmamba/name_to_config.py --checkpoint ${CHECKPOINT_FILE} --show-dir ${OUTPUT_DIR} # IMAGE_FILE is the image file you want to predict, name_to_config.py is the configuration file you want to use, CHECKPOINT_FILE is the checkpoint file you want to use, OUTPUT_DIR is the output path of the prediction result
python demo/image_demo.py ${IMAGE_DIR} configs/rsmamba/name_to_config.py --checkpoint ${CHECKPOINT_FILE} --show-dir ${OUTPUT_DIR} # IMAGE_DIR is the image folder you want to predict, name_to_config.py is the configuration file you want to use, CHECKPOINT_FILE is the checkpoint file you want to use, OUTPUT_DIR is the output path of the prediction result
We have listed some common problems and their corresponding solutions here. If you find that some problems are missing, please feel free to submit a PR to enrich this list. If you cannot get help here, please open an issue. Please fill in all the required information in the template, which will help us locate the problem more quickly.
We recommend that you do not install MMPretrain, as we have made some modifications to the code of MMPretrain, and installing MMPretrain may cause the code to run incorrectly. If you encounter an error that the module has not been registered, please check:
- If MMPretrain is installed, uninstall it
- Whether @MODELS.register_module() is added above the class definition; if not, add it
- Whether from .xxx import xxx is added in __init__.py; if not, add it
- Whether custom_imports = dict(imports=['mmpretrain.rsmamba'], allow_failed_imports=False) is added in the Config file; if not, add it
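The checks above all come down to one mechanism: a class only becomes known to the registry when its decorator runs, and the decorator only runs when the defining module is imported. The toy registry below is a simplified stand-in written for illustration (it is not MMEngine's implementation, and ToyBackbone is a hypothetical class), but it reproduces the same kind of "not registered" failure.

```python
class Registry:
    """Tiny stand-in for a module registry (not MMEngine's implementation)."""

    def __init__(self, name):
        self.name = name
        self._modules = {}

    def register_module(self):
        # Used as a decorator: stores the class under its own name.
        def decorator(cls):
            self._modules[cls.__name__] = cls
            return cls
        return decorator

    def build(self, cfg):
        cfg = dict(cfg)  # copy so the caller's dict is left untouched
        type_name = cfg.pop('type')
        if type_name not in self._modules:
            raise KeyError(f'{type_name} is not in the {self.name} registry')
        return self._modules[type_name](**cfg)


MODELS = Registry('models')


@MODELS.register_module()
class ToyBackbone:  # hypothetical class, for illustration only
    def __init__(self, arch='b'):
        self.arch = arch


# Works because the decorator above has already run in this module.
backbone = MODELS.build(dict(type='ToyBackbone', arch='l'))
print(backbone.arch)
```

If the class lived in another file that was never imported, the decorator would never execute and build would raise the KeyError. That is why both the import in __init__.py and the custom_imports entry in the Config file matter.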
If you encounter a Bad substitution error when running dist_train.sh, use bash dist_train.sh to run the script.
- If you encounter an error when installing causal-conv1d and mamba-ssm, check if your CUDA version matches the requirements of the installation package.
- If the problem persists, download the corresponding precompiled package, and then use pip install xxx.whl to install it. Refer to causal-conv1d and mamba-ssm.
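When picking a precompiled package, the wheel has to match the local Python version, platform, PyTorch version, and the CUDA version PyTorch was built with. The sketch below collects those facts in one place. The helper name wheel_tags is hypothetical, and which fields appear in a given wheel's file name is an assumption to verify against the release page of the package.

```python
import sys
import sysconfig


def wheel_tags():
    """Collect the environment facts that a precompiled wheel has to match."""
    tags = {
        # CPython tag, e.g. cp311 for Python 3.11
        "python": f"cp{sys.version_info.major}{sys.version_info.minor}",
        # Platform string, e.g. linux-x86_64
        "platform": sysconfig.get_platform(),
    }
    try:
        import torch
    except ImportError:
        tags["torch"] = None  # PyTorch is not installed in this environment
    else:
        tags["torch"] = torch.__version__
        tags["cuda"] = torch.version.cuda  # CUDA version PyTorch was built with
        tags["cxx11abi"] = torch.compiled_with_cxx11_abi()
    return tags


if __name__ == "__main__":
    for key, value in wheel_tags().items():
        print(f"{key}: {value}")
```

Run it inside the rsmamba environment and compare the printed values with the file names offered for download before running pip install xxx.whl.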
This project is developed based on MMPretrain. Thanks to the MMPretrain project for providing the code foundation.
If you use the code or performance benchmarks of this project in your research, please cite RSMamba using the following BibTeX entry.
@article{chen2024rsmamba,
title={RSMamba: Remote Sensing Image Classification with State Space Model},
author={Chen, Keyan and Chen, Bowen and Liu, Chenyang and Li, Wenyuan and Zou, Zhengxia and Shi, Zhenwei},
journal={arXiv preprint arXiv:2403.19654},
year={2024}
}
This project is licensed under the Apache 2.0 License.
If you have any other questions❓, please feel free to contact us 👬