This repository provides a PyTorch implementation of CAiD: Context-Aware Instance Discrimination for Self-supervised Learning in Medical Imaging.
This is the first work that quantitatively and systematically shows the general limitation of instance discrimination methods in learning a distinct set of features from unlabeled medical images due to their global anatomical similarity. To alleviate this problem, we propose Context-Aware instance Discrimination (CAiD), a simple yet powerful self-supervised framework that formulates an auxiliary context prediction task to equip instance discrimination learning with fine-grained contextual information captured from local regions of images.
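To make the two objectives concrete, here is a minimal NumPy sketch of how an instance-discrimination loss and a context-prediction loss might be combined. It assumes a MoCo-style InfoNCE loss (each query's positive is its own key, negatives come from a queue), treats context prediction as reconstructing the original image crop scored with mean squared error, and mixes the two with a hypothetical weight `lam`. These choices and all function names are illustrative assumptions, not the repository's implementation; NumPy is used only so the sketch runs without a GPU.

```python
import numpy as np


def info_nce(q, k, queue, temperature=0.2):
    """MoCo-style InfoNCE: row i of q is positive with row i of k; queue rows are negatives."""
    # L2-normalize so dot products are cosine similarities
    q = q / np.linalg.norm(q, axis=1, keepdims=True)
    k = k / np.linalg.norm(k, axis=1, keepdims=True)
    queue = queue / np.linalg.norm(queue, axis=1, keepdims=True)
    l_pos = np.sum(q * k, axis=1, keepdims=True)  # (N, 1) positive logits
    l_neg = q @ queue.T                           # (N, K) negative logits
    logits = np.concatenate([l_pos, l_neg], axis=1) / temperature
    # Cross-entropy with the positive at index 0, via a numerically stable log-softmax
    logits = logits - logits.max(axis=1, keepdims=True)
    log_prob_pos = logits[:, 0] - np.log(np.exp(logits).sum(axis=1))
    return -log_prob_pos.mean()


def context_loss(reconstruction, original):
    """Hypothetical context-prediction term: MSE between reconstruction and original crop."""
    return np.mean((reconstruction - original) ** 2)


def caid_style_loss(q, k, queue, reconstruction, original, lam=0.5, temperature=0.2):
    """Weighted sum of the two terms; lam is an assumed hyperparameter."""
    return lam * info_nce(q, k, queue, temperature) + (1 - lam) * context_loss(reconstruction, original)
```

With `lam=1` the sketch reduces to plain instance discrimination; lowering `lam` shifts weight toward the local, context-level signal that the paragraph above argues instance discrimination alone misses.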
CAiD: Context-Aware Instance Discrimination for Self-supervised Learning in Medical Imaging
Mohammad Reza Hosseinzadeh Taher¹, Fatemeh Haghighi¹, Michael B. Gotway², Jianming Liang¹
¹ Arizona State University, ² Mayo Clinic
Published in: International Conference on Medical Imaging with Deep Learning (MIDL), 2022.
- CAiD enriches existing instance discrimination methods.
- CAiD provides more separable features.
- CAiD provides more reusable low/mid-level features.
Credit to Scott Lowe for the MATLAB code of superbar.
- Linux
- Python
- Install PyTorch (pytorch.org)
Clone the repository and install the dependencies using the following commands:
$ git clone https://github.com/MR-HosseinzadehTaher/CAiD.git
$ cd CAiD/
$ pip install -r requirements.txt
We used the training set of the ChestX-ray14 dataset for pre-training CAiD models, which can be downloaded from this link.
- The downloaded ChestX-ray14 should have a directory structure as follows:
ChestX-ray14/
    |-- images/
        |-- 00000012_000.png
        |-- 00000017_002.png
        ...
We use 10% of the training data for validation. We also provide the lists of training and validation images in dataset/Xray14_train_official.txt and dataset/Xray14_val_official.txt, respectively. The training set is based on the official split provided by the ChestX-ray14 dataset. Training labels are not used during the pre-training stage. Only the path to the images folder is required for pre-training.
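A minimal sketch of turning one of these split files into a list of image paths is shown below. The line format is an assumption: it treats the first whitespace-separated token on each line as the image file name and ignores anything after it (for example, labels, which pre-training does not use). The helper name `read_image_list` is hypothetical.

```python
import os


def read_image_list(list_file, images_dir):
    """Return full image paths for the file names listed in a split file.

    Assumes (hypothetically) one image per line, with the file name as the
    first whitespace-separated token; any remaining tokens are ignored.
    """
    paths = []
    with open(list_file) as f:
        for line in f:
            tokens = line.split()
            if not tokens:  # skip blank lines
                continue
            paths.append(os.path.join(images_dir, tokens[0]))
    return paths
```

For example, `read_image_list("dataset/Xray14_train_official.txt", "/path/to/images/folder")` would yield paths under the `images/` directory laid out above.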
This implementation only supports multi-GPU DistributedDataParallel training, which is faster and simpler; single-GPU and DataParallel training are not supported. The instance discrimination setup follows MoCo. The checkpoint with the lowest validation loss is used for fine-tuning.
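Keeping the checkpoint with the lowest validation loss only requires remembering the best value seen so far. The small tracker below is a hypothetical helper, not code from this repository; in a training loop you would save the checkpoint whenever `update` returns `True`.

```python
import math


class LowestLossTracker:
    """Track the lowest validation loss seen so far across epochs."""

    def __init__(self):
        self.best_loss = math.inf
        self.best_epoch = None

    def update(self, epoch, val_loss):
        """Return True if val_loss is a new minimum, i.e. this epoch's checkpoint should be kept."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_epoch = epoch
            return True
        return False
```

Under DistributedDataParallel, it is common to let only the rank-0 process write the checkpoint so that the processes do not overwrite each other's files.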
To perform unsupervised pre-training of a U-Net model with a ResNet-50 backbone on ChestX-ray14 using 4 NVIDIA V100 GPUs, run the following command:
python main_CAiD_moco.py /path/to/images/folder --dist-url 'tcp://localhost:10001' --multiprocessing-distributed \
--world-size 1 --rank 0 --mlp --moco-t 0.2 --cos --mode caid
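Since the setup follows MoCo, the `--cos` flag presumably enables MoCo's cosine learning-rate schedule, which decays the learning rate from its base value to zero over training as lr = base_lr * 0.5 * (1 + cos(pi * epoch / total_epochs)). The function below is a sketch under that assumption; its name and the example values are illustrative. (Similarly, `--moco-t 0.2` sets the InfoNCE temperature.)

```python
import math


def cosine_lr(base_lr, epoch, total_epochs):
    """Cosine-annealed learning rate for a given epoch (MoCo-style, assumed for --cos)."""
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * epoch / total_epochs))
```

In a PyTorch loop, the value would be written into each entry of `optimizer.param_groups` (as `g["lr"] = lr`) at the start of every epoch.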
For downstream tasks, we use the code provided by a recent transfer learning benchmark in medical imaging.
CAiD provides a pre-trained U-Net model whose encoder can be used for downstream classification tasks and whose full encoder-decoder can be used for downstream segmentation tasks.
For classification tasks, a ResNet-50 encoder can be initialized with the pre-trained encoder of CAiD as follows:
import torch
import torchvision.models as models

num_classes = 14  # set to the number of target task classes (14 is only an example)
weight = "/path/to/CAiD/pretrained/model"  # path to CAiD pre-trained model

# ResNet-50 classifier; its fc layer is trained from scratch on the target task
model = models.resnet50(num_classes=num_classes)

state_dict = torch.load(weight, map_location="cpu")
if "state_dict" in state_dict:
    state_dict = state_dict["state_dict"]
# Strip wrapper prefixes so the keys match torchvision's ResNet-50 parameter names
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
state_dict = {k.replace("encoder.", ""): v for k, v in state_dict.items()}
# Keep only the encoder: drop the classifier, U-Net decoder, and segmentation head
for k in list(state_dict.keys()):
    if k.startswith('fc') or k.startswith('segmentation_head') or k.startswith('decoder'):
        del state_dict[k]

msg = model.load_state_dict(state_dict, strict=False)
print("=> loaded pre-trained model '{}'".format(weight))
print("missing keys:", msg.missing_keys)  # ideally only fc.weight and fc.bias
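The key clean-up above can be checked on key names alone, without loading any weights. The sketch below mirrors the same steps in plain Python; the function name and the example key names are hypothetical and only illustrate how a U-Net checkpoint key maps (or not) onto torchvision's ResNet-50 names. Note that `str.replace` removes the substring anywhere in a key, not just at the start.

```python
def remap_caid_keys(state_dict,
                    strip=("module.", "backbone.", "encoder."),
                    drop_prefixes=("fc", "segmentation_head", "decoder")):
    """Mirror the classification clean-up: strip wrapper substrings, then drop non-encoder keys."""
    for s in strip:
        state_dict = {k.replace(s, ""): v for k, v in state_dict.items()}
    return {k: v for k, v in state_dict.items() if not k.startswith(drop_prefixes)}
```

For example, a hypothetical key `module.backbone.encoder.conv1.weight` becomes `conv1.weight`, while decoder and segmentation-head keys are discarded.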
For segmentation tasks, a U-Net can be initialized with the pre-trained encoder and decoder of CAiD as follows:
import torch
import segmentation_models_pytorch as smp

backbone = 'resnet50'
weight = "/path/to/CAiD/pretrained/model"  # path to CAiD pre-trained model

# U-Net with a ResNet-50 encoder; its keys are named encoder.*, decoder.*, segmentation_head.*
model = smp.Unet(backbone)

state_dict = torch.load(weight, map_location="cpu")
if "state_dict" in state_dict:
    state_dict = state_dict["state_dict"]
# Strip wrapper prefixes; keep "encoder." and "decoder." since smp.Unet uses them
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
# Drop the classifier and segmentation head; the task-specific head is trained from scratch
for k in list(state_dict.keys()):
    if k.startswith('fc') or k.startswith('segmentation_head'):
        del state_dict[k]

msg = model.load_state_dict(state_dict, strict=False)
print("=> loaded pre-trained model '{}'".format(weight))
print("missing keys:", msg.missing_keys)
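Because both snippets load with `strict=False`, a mismatch in key names fails silently and leaves those layers at their initial values. A small check like the one below (a hypothetical helper, not part of this repository) flags any missing keys outside the parts that are expected to be trained from scratch: the segmentation head for the U-Net, or `fc` for the classifier.

```python
def unexpected_missing_keys(missing_keys, allowed_prefixes=("segmentation_head",)):
    """Return missing keys that fall outside the layers expected to be trained from scratch."""
    return [k for k in missing_keys if not k.startswith(allowed_prefixes)]
```

Usage: `assert not unexpected_missing_keys(msg.missing_keys)` after the segmentation loading code, or `unexpected_missing_keys(msg.missing_keys, allowed_prefixes=("fc",))` for the classification model.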
If you use this code or use our pre-trained weights for your research, please cite our paper:
@inproceedings{taher2022caid,
  title={CAiD: Context-Aware Instance Discrimination for Self-supervised Learning in Medical Imaging},
  author={Hosseinzadeh Taher, Mohammad Reza and Haghighi, Fatemeh and Gotway, Michael and Liang, Jianming},
  booktitle={Medical Imaging with Deep Learning},
  year={2022},
}
This research has been supported in part by ASU and Mayo Clinic through a Seed Grant and an Innovation Grant, and in part by the NIH under Award Number R01HL128785. The content is solely the responsibility of the authors and does not necessarily represent the official views of the NIH. This work utilized GPUs provided in part by ASU Research Computing and in part by the Extreme Science and Engineering Discovery Environment (XSEDE), funded by the National Science Foundation (NSF) under grant number ACI-1548562. The content of this paper is covered by patents pending. We built the U-Net architecture for segmentation tasks based on the released code of segmentation_models.pytorch. The instance discrimination is based on MoCo.
Released under the ASU GitHub Project License.