The official implementation of CMX: Cross-Modal Fusion for RGB-X Semantic Segmentation with Transformers. More details can be found in our paper [PDF].
- Requirements
  - Python 3.7+
  - PyTorch 1.7.0 or higher
  - CUDA 10.2 or higher
We have tested the following versions of OS and software:
- OS: Ubuntu 18.04.6 LTS
- CUDA: 10.2
- PyTorch 1.8.2
- Python 3.8.11
- Install all dependencies. Install PyTorch, CUDA, and cuDNN first, then install the remaining dependencies via:
pip install -r requirements.txt
Organize the dataset folder in the following structure:
<datasets>
|-- <DatasetName1>
    |-- <RGBFolder>
        |-- <name1>.<ImageFormat>
        |-- <name2>.<ImageFormat>
        ...
    |-- <ModalXFolder>
        |-- <name1>.<ModalXFormat>
        |-- <name2>.<ModalXFormat>
        ...
    |-- <LabelFolder>
        |-- <name1>.<LabelFormat>
        |-- <name2>.<LabelFormat>
        ...
    |-- train.txt
    |-- test.txt
|-- <DatasetName2>
|-- ...
train.txt contains the names of the items in the training set, e.g.:
<name1>
<name2>
...
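The split files and folder layout above can be combined into lists of (RGB, modal-X, label) file paths with a few lines. This is only a sketch: the folder names and extensions below are placeholders standing in for `<RGBFolder>`, `<ModalXFolder>`, `<LabelFolder>`, and the format placeholders, not the repository's actual data loader.

```python
from pathlib import Path

def read_split(dataset_root, split_file="train.txt",
               rgb_dir="RGB", modal_dir="Modal", label_dir="Label",
               rgb_ext=".jpg", modal_ext=".png", label_ext=".png"):
    """Return (rgb, modal, label) path triples for one dataset split.

    All directory and extension arguments are illustrative defaults;
    substitute the names used by your dataset.
    """
    root = Path(dataset_root)
    # Each non-empty line of the split file is one sample name.
    names = [n.strip() for n in (root / split_file).read_text().splitlines()
             if n.strip()]
    return [(root / rgb_dir / (n + rgb_ext),
             root / modal_dir / (n + modal_ext),
             root / label_dir / (n + label_ext)) for n in names]
```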
For RGB-Depth semantic segmentation, HHA maps can be generated from depth maps following https://github.com/charlesCXK/Depth2HHA-python.
- Pretrained weights: Download the pretrained SegFormer backbone weights here: pretrained segformer.
- Config: Edit the config file configs.py, including the dataset and network settings.
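As a rough illustration of the kind of dataset and network settings a config file like configs.py collects, here is a sketch using `SimpleNamespace`. Every attribute name and value below is an assumption for illustration only, not the repository's actual configuration schema.

```python
from types import SimpleNamespace

# Hypothetical sketch only: attribute names are illustrative assumptions,
# not the actual fields of this repository's configs.py.
C = SimpleNamespace(
    dataset_name="NYUDepthv2",                 # assumed dataset identifier
    rgb_root="datasets/NYUDepthv2/RGB",        # <RGBFolder>
    x_root="datasets/NYUDepthv2/HHA",          # <ModalXFolder> (e.g. HHA maps)
    label_root="datasets/NYUDepthv2/Label",    # <LabelFolder>
    backbone="mit_b2",                         # SegFormer MiT-B2 backbone
    pretrained_model="pretrained/mit_b2.pth",  # downloaded SegFormer weights
    num_classes=40,                            # assumed class count
    batch_size=8,
    lr=6e-5,
)
```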
- Run multi-GPU distributed training:

  $ CUDA_VISIBLE_DEVICES="GPU IDs" python -m torch.distributed.launch --nproc_per_node="number of GPUs" train.py

  For example, to train on GPUs 0 and 1:

  $ CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py
- The TensorBoard file is saved in the log_<datasetName>_<backboneSize>/tb/ directory.
- Checkpoints are stored in the log_<datasetName>_<backboneSize>/checkpoints/ directory.
Run the evaluation with:

CUDA_VISIBLE_DEVICES="GPU IDs" python eval.py -d="Device ID" -e="epoch number or range"

To evaluate on multiple GPUs, specify multiple device IDs (e.g. -d=0,1,2).
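To illustrate what an "epoch number or range" argument such as -e="100" or -e="100-150" could mean, here is a hypothetical sketch of how such a value might be expanded into a list of epochs to evaluate. This is not eval.py's actual parsing logic, which may differ.

```python
def parse_epochs(spec):
    """Expand a hypothetical epoch spec: "100" -> [100], "100-102" -> [100, 101, 102]."""
    if "-" in spec:
        start, end = spec.split("-")
        return list(range(int(start), int(end) + 1))
    return [int(spec)]
```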
We offer pre-trained weights on different RGB-X datasets. (Some weights are not available yet; due to differences between training platforms, these weights may not load correctly.)
NYU Depth V2:

| Architecture | Backbone | mIoU (SS) | mIoU (MS & Flip) | Weight |
| --- | --- | --- | --- | --- |
| CMX (SegFormer) | MiT-B2 | 54.1% | 54.4% | NYU-MiT-B2 |
| CMX (SegFormer) | MiT-B4 | 56.0% | 56.3% | |
| CMX (SegFormer) | MiT-B5 | 56.8% | 56.9% | |

MFNet:

| Architecture | Backbone | mIoU | Weight |
| --- | --- | --- | --- |
| CMX (SegFormer) | MiT-B2 | 58.2% | MFNet-MiT-B2 |
| CMX (SegFormer) | MiT-B4 | 59.7% | |

ScanNet:

| Architecture | Backbone | mIoU | Weight |
| --- | --- | --- | --- |
| CMX (SegFormer) | MiT-B2 | 61.3% | ScanNet-MiT-B2 |

RGB-Event:

| Architecture | Backbone | mIoU | Weight |
| --- | --- | --- | --- |
| CMX (SegFormer) | MiT-B4 | 64.28% | RGBE-MiT-B4 |
If you find this repo useful, please consider referencing the following paper:
@article{liu2022cmx,
title={CMX: Cross-Modal Fusion for RGB-X Semantic Segmentation with Transformers},
author={Liu, Huayao and Zhang, Jiaming and Yang, Kailun and Hu, Xinxin and Stiefelhagen, Rainer},
journal={arXiv preprint arXiv:2203.04838},
year={2022}
}
Our code is heavily based on TorchSeg and SA-Gate, thanks for their excellent work!