This repository provides an end-to-end pipeline to classify spinal stenosis severity across different vertebral levels using a hybrid approach that integrates pre-trained vision models and dedicated vertebral level embeddings. The method leverages the feature extraction capabilities of pre-trained backbones while enriching them with spinal-level context, ultimately improving classification performance on the RSNA Spinal Stenosis dataset.
We implemented six distinct pre-trained vision models as backbones:
- Vision Transformer (ViT-B/16) [1]
- Swin Transformer (Swin-B) [2]
- BEiT [3]
- EfficientNet V2-M [4]
- ResNet-152 [5]
- ConvNeXt Base [6]
All models were pre-trained on ImageNet-1K. The original classification heads were removed, and their feature extraction layers were retained to produce feature vectors of dimensions ranging from 768 to 2048, depending on the backbone.
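As an illustration, the sketch below shows one common way to turn a torchvision backbone into such a feature extractor; ResNet-152 is shown, and the `fc` attribute name follows torchvision's model definition (this is a minimal example, not necessarily the repository's exact code).

```python
import torch
import torch.nn as nn
import torchvision.models as models

# Load an ImageNet-1K pre-trained backbone and replace its classification
# head with an identity, so the forward pass returns the pooled features.
backbone = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1)
feature_dim = backbone.fc.in_features  # 2048 for ResNet-152
backbone.fc = nn.Identity()

backbone.eval()
with torch.no_grad():
    features = backbone(torch.randn(1, 3, 224, 224))
print(features.shape)  # torch.Size([1, 2048])
```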
To include vertebral level information, we introduced a dedicated embedding module:
- Maps each spinal level (L1/L2 through L5/S1) to a 256-dimensional embedding vector.
- Concatenates the level embedding with the backbone’s feature vector.
- The combined feature vector is passed through a Multi-Layer Perceptron (MLP) with the following structure (see the sketch after this list):
  - Input: backbone features concatenated with the 256-D level embedding
  - Layer sizes: 512 → 256 → 3
  - Each hidden layer is followed by Layer Normalization, GELU activation, and a dropout layer for regularization.
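A minimal sketch of this embedding-plus-MLP head, assuming the architecture described above; the class name and the dropout rate are illustrative assumptions, not the repository's exact implementation.

```python
import torch
import torch.nn as nn

NUM_LEVELS = 5   # L1/L2 through L5/S1
NUM_CLASSES = 3  # severity grades

class LevelEmbeddingHead(nn.Module):
    """Fuses backbone features with a learned vertebral-level embedding,
    then classifies with a 512 -> 256 -> 3 MLP."""

    def __init__(self, feature_dim: int, embed_dim: int = 256, dropout: float = 0.1):
        super().__init__()
        self.level_embedding = nn.Embedding(NUM_LEVELS, embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(feature_dim + embed_dim, 512),
            nn.LayerNorm(512), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.LayerNorm(256), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(256, NUM_CLASSES),
        )

    def forward(self, features: torch.Tensor, level_ids: torch.Tensor) -> torch.Tensor:
        # features: (B, feature_dim); level_ids: (B,) integer level indices
        fused = torch.cat([features, self.level_embedding(level_ids)], dim=1)
        return self.mlp(fused)

head = LevelEmbeddingHead(feature_dim=2048)
logits = head(torch.randn(4, 2048), torch.tensor([0, 1, 2, 4]))
print(logits.shape)  # torch.Size([4, 3])
```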
Training configuration (a code sketch follows this list):
- Optimizer: AdamW (learning_rate=1e-4, weight_decay=0.02)
- Batch Size: 32
- Number of Epochs: 30
- Loss Function: Cross-Entropy Loss with class weights to address class imbalance.
- Data Augmentation:
  - Random horizontal flips
  - Random rotations
  - Normalization using ImageNet statistics: mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
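The configuration above could be assembled as follows. This is a sketch: the class-weight values and the rotation range are placeholders, since the source does not specify them, and the `model` stand-in represents the backbone plus head.

```python
import torch
import torch.nn as nn
from torchvision import transforms

model = nn.Linear(2048, 3)  # stand-in for the backbone + head

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.02)

# Class weights counteract imbalance; these values are illustrative,
# not the dataset's actual distribution.
class_weights = torch.tensor([0.5, 1.0, 2.0])
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Training-time augmentation with ImageNet normalization.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=10),  # rotation range is an assumption
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```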
Fine-tuning strategy (sketched after this list):
- Initially freeze backbone weights to retain pre-trained knowledge.
- Selectively unfreeze the last 32 layers for fine-tuning.
- Fully train the level embedding module and classification head.
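A sketch of this freezing strategy, assuming "layers" refers to parameter tensors in registration order; the repository may count layers differently.

```python
import torchvision.models as models

backbone = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1)

# Freeze the whole backbone to retain pre-trained knowledge...
for param in backbone.parameters():
    param.requires_grad = False

# ...then unfreeze the last 32 parameter tensors for fine-tuning.
# (The level embedding module and classification head are separate
# modules and remain fully trainable.)
for param in list(backbone.parameters())[-32:]:
    param.requires_grad = True
```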
Model performance is evaluated using the following metrics:
- Accuracy
- Precision
- Recall
- F1-Score
Confusion matrices are generated for each severity grade to provide detailed insights into model predictions. All evaluations are conducted on held-out validation and test sets to ensure unbiased assessments.
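A sketch of computing these metrics with scikit-learn (not listed in the dependencies below, so this assumes it is installed); the label arrays are placeholders for real held-out predictions.

```python
from sklearn.metrics import (accuracy_score, confusion_matrix,
                             precision_recall_fscore_support)

# Placeholder labels; in practice these come from running the model
# on the held-out validation or test set.
y_true = [0, 1, 2, 2, 1, 0]
y_pred = [0, 1, 2, 1, 1, 0]

accuracy = accuracy_score(y_true, y_pred)
precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, average="macro", zero_division=0)
cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])

print(f"acc={accuracy:.3f}  precision={precision:.3f}  "
      f"recall={recall:.3f}  f1={f1:.3f}")
print(cm)
```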
Requirements:
- Python 3.8+
- PyTorch
- torchvision
- transformers (Hugging Face)
- pandas
- argparse (part of the Python standard library; no separate installation needed)
Install the required dependencies via:
```bash
pip install torch torchvision transformers pandas
```
Then clone the repository:

```bash
git clone https://github.com/Shijia1997/RSNA_spinal_stenosis_prediction.git
```
To train the CNN-based backbones (ConvNeXt, EfficientNet V2-M, ResNet-152):

```bash
python train_models.py \
    --models convnext efficientnet resnet152 \
    --batch_size 32 \
    --num_epochs 30 \
    --learning_rate 1e-4 \
    --weight_decay 0.02 \
    --level_embeddings 256 \
    --freeze_backbone True \
    --unfreeze_last_n 32
```
To train the transformer-based backbones (ViT, Swin, BEiT):

```bash
python train_models.py \
    --models vit swin beit \
    --batch_size 32 \
    --num_epochs 30 \
    --learning_rate 1e-4 \
    --weight_decay 0.02 \
    --level_embeddings 256 \
    --freeze_backbone True \
    --unfreeze_last_n 32
```
After training, the following outputs will be generated:
- Metrics and confusion matrices: saved in the `./results` directory.
- Trained models: saved in the `./trained_models` directory.
- Plots and visualizations: saved in the `./plots` directory.
References:
1. Dosovitskiy, A., et al. "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale." ICLR, 2021.
2. Liu, Z., et al. "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows." ICCV, 2021.
3. Bao, H., et al. "BEiT: BERT Pre-Training of Image Transformers." ICLR, 2022.
4. Tan, M., and Le, Q. "EfficientNetV2: Smaller Models and Faster Training." ICML, 2021.
5. He, K., et al. "Deep Residual Learning for Image Recognition." CVPR, 2016.
6. Liu, Z., et al. "A ConvNet for the 2020s." CVPR, 2022.