Official PyTorch implementation of the TMLR paper:
Beyond Distribution Shift: Spurious Features Through the Lens of Training Dynamics
Nihal Murali¹, Aahlad Puli³, Ke Yu¹, Rajesh Ranganath³, Kayhan Batmanghelich²
¹ University of Pittsburgh (ISP), ² Boston University (ECE), ³ New York University (CS)
- Objective
- Environment setup
- Downloading data
- Data Preprocessing
- Training pipeline
- How to Cite
- License and copyright
Deep Neural Networks (DNNs) are prone to learning spurious features that correlate with the label during training but are irrelevant to the learning problem. This hurts model generalization and poses problems when deploying them in safety-critical applications. This paper aims to better understand the effects of spurious features through the lens of the learning dynamics of the internal neurons during the training process. We make the following observations: (1) While previous works highlight the harmful effects of spurious features on the generalization ability of DNNs, we emphasize that not all spurious features are harmful. Spurious features can be “benign” or “harmful” depending on whether they are “harder” or “easier” to learn than the core features for a given model. This definition is model and dataset dependent. (2) We build upon this premise and use instance difficulty methods (like Prediction Depth (Baldock et al., 2021)) to quantify “easiness” for a given model and to identify this behavior during the training phase. (3) We empirically show that the harmful spurious features can be detected by observing the learning dynamics of the DNN’s early layers. In other words, easy features learned by the initial layers of a DNN early during the training can (potentially) hurt model generalization. We verify our claims on medical and vision datasets, both simulated and real, and justify the empirical success of our hypothesis by showing the theoretical connections between Prediction Depth and information-theoretic concepts like V-usable information (Ethayarajh et al., 2021). Lastly, our experiments show that monitoring only accuracy during training (as is common in machine learning pipelines) is insufficient to detect spurious features. We, therefore, highlight the need for monitoring early training dynamics using suitable instance difficulty metrics.
Run the bash commands below to set up a conda environment with the required libraries/packages.
conda env create --name TMLR23_spurious_dynamics -f environment.yml
conda activate TMLR23_spurious_dynamics
Download the data files below for the dataset(s) on which you want to compute Prediction Depth (PD) and detect spurious features.
Dataset | URL |
---|---|
NIH | Kaggle Link |
MIMIC-CXR | PhysioNet portal |
CheXpert | Stanford ML Group |
GitHub-COVID | covid-chestxray-dataset GitHub |
PD computation runs a k-Nearest Neighbor (k-NN) classifier on the model's intermediate-layer embeddings, comparing each input image against the embeddings of a reference subset of the training data. The CSV files below provide this subset, which must contain an equal number of positives and negatives; a minimal sketch of the PD computation follows the table.
Link | Description |
---|---|
nih_full.csv | NIH full data (see Fig-8) |
nih_subset.csv | NIH (subset): train embeddings for k-NN PD computations (see Fig-8) |
chex_mimic_full.csv | Chex-MIMIC full data (see Fig-7a) |
chex_mimic_subset.csv | Chex-MIMIC (subset): train embeddings for k-NN PD computations (see Fig-7a) |
covid_full.csv | Covid full data (see Fig-7b) |
covid_subset.csv | Covid (subset): train embeddings for k-NN PD computations (see Fig-7b) |
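For intuition, here is a minimal sketch (not the repo's implementation) of estimating PD with layer-wise k-NN probes: an example's prediction depth is the earliest probe layer from which all subsequent k-NN predictions already agree with the prediction at the final probe. It assumes you have extracted per-layer embeddings yourself; `train_embs`/`test_embs` are hypothetical inputs, and the probe layers, value of k, and tie-breaking used in ./scripts/analyse_ckpt.py may differ.

```python
# Minimal sketch of Prediction Depth (PD) via layer-wise k-NN probes.
# `train_embs[l]` / `test_embs[l]` are (N, D_l) embedding matrices extracted
# at probe layer l by the caller (hypothetical inputs, not produced here).
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

def prediction_depth(train_embs, train_labels, test_embs, k=30):
    n_layers = len(train_embs)
    # k-NN prediction for every test example at every probe layer
    layer_preds = []
    for l in range(n_layers):
        knn = KNeighborsClassifier(n_neighbors=k)
        knn.fit(train_embs[l], train_labels)
        layer_preds.append(knn.predict(test_embs[l]))
    layer_preds = np.stack(layer_preds)        # shape: (n_layers, n_test)
    final = layer_preds[-1]                    # prediction at the last probe
    # PD = earliest layer from which all later probes agree with the final one
    depths = np.full(final.shape, n_layers, dtype=int)
    for l in range(n_layers - 1, -1, -1):
        agree = layer_preds[l] == final
        depths = np.where(agree & (depths == l + 1), l, depths)
    return depths
```

Groups of examples that get unusually low depths early in training are the signal the paper uses to flag potentially harmful (easy-to-learn) spurious features.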
- See the example config file in ./configs/nih.yaml. Replace data_file with the path to the corresponding <dataset_name>_full.csv file from above, and change the other variables as appropriate.
- Run the command below to train a DenseNet-121 model on the dataset configured above. Make sure to set the right values for the arguments.
python ./scripts/train_densenet.py --main_dir <path/to/repo>
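For reference, a DenseNet-121 binary classifier can be set up in plain torchvision as sketched below (torchvision >= 0.13 for the `weights` argument). This is not the exact training script; the data loader, loss, and hyperparameters here are illustrative assumptions, so defer to ./scripts/train_densenet.py and the YAML config for the real settings.

```python
# Hedged sketch of a DenseNet-121 binary classifier (illustrative only).
import torch
import torch.nn as nn
from torchvision import models

model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)
model.classifier = nn.Linear(model.classifier.in_features, 1)  # single-logit binary head

criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

def train_one_epoch(model, loader, device="cuda"):
    # `loader` is a placeholder DataLoader built from <dataset_name>_full.csv
    model.to(device).train()
    for images, labels in loader:
        images, labels = images.to(device), labels.float().to(device)
        optimizer.zero_grad()
        loss = criterion(model(images).squeeze(1), labels)
        loss.backward()
        optimizer.step()
```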
- Analyse the checkpoint using PD plots
python ./scripts/analyse_ckpt.py --main_dir <path/to/repo> --expt_name nih-analysis --ckpt_path </path/to/checkpoint> --csv_train_embs </path/to/dataset_subset.csv> --csv_plot_pd </path/to/test/images/for/PD_plot>
The PD plot is saved as ./output/<expt_name>_pd_plot.svg
- Analyse the peaks in PD plot with Grad-CAM
Follow the steps in the ./notebooks/pd_analysis.ipynb notebook.
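As a rough illustration of this step (the notebook may rely on an existing Grad-CAM package instead), a heatmap for a chosen convolutional layer can be computed with plain forward/backward hooks:

```python
# Hand-rolled Grad-CAM sketch for inspecting examples that peak in the PD plot.
import torch
import torch.nn.functional as F

def grad_cam(model, image, target_layer, class_idx=0):
    """image: (1, C, H, W) tensor; target_layer: a conv module inside `model`."""
    feats, grads = {}, {}

    def fwd_hook(module, inputs, output):
        feats["act"] = output
    def bwd_hook(module, grad_input, grad_output):
        grads["grad"] = grad_output[0]

    h1 = target_layer.register_forward_hook(fwd_hook)
    h2 = target_layer.register_full_backward_hook(bwd_hook)
    try:
        model.eval()
        logits = model(image)
        model.zero_grad()
        logits[0, class_idx].backward()
    finally:
        h1.remove()
        h2.remove()

    act, grad = feats["act"], grads["grad"]           # (1, K, h, w)
    weights = grad.mean(dim=(2, 3), keepdim=True)     # global-average-pooled gradients
    cam = F.relu((weights * act).sum(dim=1, keepdim=True))
    cam = F.interpolate(cam, size=image.shape[-2:], mode="bilinear", align_corners=False)
    return (cam / (cam.max() + 1e-8)).squeeze().detach()  # (H, W) heatmap in [0, 1]
```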
- Dominoes Experiments (see Fig-5 and Table-1 in Main Paper)
  a. Follow the data-generation steps outlined in ./notebooks/domino_generation.ipynb
  b. Train models on the domino datasets by following the steps in ./notebooks/domino_training.ipynb
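For orientation, the sketch below builds domino images in the usual way for such experiments: an "easy" top half (e.g., an MNIST digit) stacked above a "harder" bottom half (e.g., a CIFAR-10 image), where the top half agrees with the label only a chosen fraction of the time. The datasets, classes, and correlation levels actually used for Fig-5/Table-1 are defined in the notebook; the function and argument names here are illustrative.

```python
# Hedged sketch of domino construction: a spuriously-correlated "easy" top half
# stacked above the "hard" bottom half that carries the true label.
import numpy as np
import torch
import torch.nn.functional as F

def make_domino(top_img, bottom_img, size=32):
    """Stack a (1,H,W) or (3,H,W) top image above a (3,H,W) bottom image."""
    top = top_img.repeat(3, 1, 1) if top_img.shape[0] == 1 else top_img
    top = F.interpolate(top.unsqueeze(0), size=(size, size)).squeeze(0)
    bottom = F.interpolate(bottom_img.unsqueeze(0), size=(size, size)).squeeze(0)
    return torch.cat([top, bottom], dim=1)            # (3, 2*size, size)

def pair_with_correlation(top_by_class, bottom_imgs, bottom_labels, corr=0.95, seed=0):
    """Pick a top image whose class matches the label with probability `corr`."""
    rng = np.random.default_rng(seed)
    classes = sorted(top_by_class)
    dominoes = []
    for img, y in zip(bottom_imgs, bottom_labels):
        c = y if rng.random() < corr else rng.choice([k for k in classes if k != y])
        top = top_by_class[c][rng.integers(len(top_by_class[c]))]
        dominoes.append(make_domino(top, img))
    return torch.stack(dominoes), torch.as_tensor(bottom_labels)
```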
- Not all spurious features hurt generalization! (Toy data experiments shown in Fig-6)
The ./notebooks/toy_data_expts.ipynb notebook contains all the steps: data generation, model training, PD plots, and Grad-CAM/SHAP analysis.
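As a rough analogue (the exact construction behind Fig-6 lives in the notebook), a 2-D toy dataset with one core coordinate and one spurious coordinate of tunable "easiness" can be generated as follows; all names and margins here are illustrative assumptions.

```python
# Toy 2-D sketch: a core feature plus a spurious feature whose margin (easiness)
# relative to the core feature controls whether it is benign or harmful.
import numpy as np

def toy_spurious_data(n=5000, corr=0.95, core_margin=0.5, spurious_margin=2.0, seed=0):
    rng = np.random.default_rng(seed)
    y = rng.integers(0, 2, size=n) * 2 - 1                 # labels in {-1, +1}
    core = y * core_margin + rng.normal(0.0, 1.0, size=n)  # always label-relevant
    # The spurious coordinate agrees with the label only `corr` of the time.
    agree = np.where(rng.random(n) < corr, y, -y)
    spur = agree * spurious_margin + rng.normal(0.0, 1.0, size=n)
    return np.stack([core, spur], axis=1), y

# A large `spurious_margin` makes the spurious coordinate easier to learn than
# the core one -- the regime the paper's analysis flags as harmful.
X, y = toy_spurious_data()
```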
- TMLR23 Main Paper
@article{murali2023beyond,
title={Beyond Distribution Shift: Spurious Features Through the Lens of Training Dynamics},
author={Nihal Murali and Aahlad Manas Puli and Ke Yu and Rajesh Ranganath and Kayhan Batmanghelich},
journal={Transactions on Machine Learning Research},
issn={2835-8856},
year={2023},
url={https://openreview.net/forum?id=Tkvmt9nDmB},
}
- Shortcut paper published in the Workshop on Spurious Correlations, Invariance and Stability, ICML 2023
@inproceedings{muralishortcut,
title={Shortcut Learning Through the Lens of Training Dynamics},
author={Murali, Nihal and Puli, Aahlad Manas and Yu, Ke and Ranganath, Rajesh and Batmanghelich, Kayhan},
booktitle={ICML Workshop on Spurious Correlations, Invariance and Stability},
year={2023}
}
Licensed under the MIT License
Copyright © Batman Lab, 2023