Official Repository for the TMLR-2023 paper titled: "Beyond Distribution Shift: Spurious Features Through the Lens of Training Dynamics"

batmanlab/TMLR23_Dynamics_of_Spurious_Features


Beyond Distribution Shift: Spurious Features Through the Lens of Training Dynamics

Official PyTorch implementation of the TMLR paper:
Beyond Distribution Shift: Spurious Features Through the Lens of Training Dynamics
Nihal Murali¹, Aahlad Puli³, Ke Yu¹, Rajesh Ranganath³, Kayhan Batmanghelich²
¹ University of Pittsburgh (ISP), ² Boston University (ECE), ³ New York University (CS)

Table of Contents

  1. Objective
  2. Environment setup
  3. Downloading data
  4. Data Preprocessing
  5. Training pipeline
  6. How to Cite
  7. License and copyright

Objective

Deep Neural Networks (DNNs) are prone to learning spurious features that correlate with the label during training but are irrelevant to the learning problem. This hurts model generalization and poses problems when deploying such models in safety-critical applications. This paper aims to better understand the effects of spurious features through the lens of the learning dynamics of the internal neurons during the training process. We make the following observations: (1) While previous works highlight the harmful effects of spurious features on the generalization ability of DNNs, we emphasize that not all spurious features are harmful. Spurious features can be “benign” or “harmful” depending on whether they are “harder” or “easier” to learn than the core features for a given model. This definition is model- and dataset-dependent. (2) We build upon this premise and use instance difficulty methods (like Prediction Depth (Baldock et al., 2021)) to quantify “easiness” for a given model and to identify this behavior during the training phase. (3) We empirically show that the harmful spurious features can be detected by observing the learning dynamics of the DNN’s early layers. In other words, easy features learned by the initial layers of a DNN early in training can (potentially) hurt model generalization. We verify our claims on medical and vision datasets, both simulated and real, and justify the empirical success of our hypothesis by showing the theoretical connections between Prediction Depth and information-theoretic concepts like V-usable information (Ethayarajh et al., 2021). Lastly, our experiments show that monitoring only accuracy during training (as is common in machine learning pipelines) is insufficient to detect spurious features. We, therefore, highlight the need for monitoring early training dynamics using suitable instance difficulty metrics.


Environment setup

Run the commands below to set up a conda environment with the required libraries and packages.

conda env create --name TMLR23_spurious_dynamics -f environment.yml
conda activate TMLR23_spurious_dynamics

Downloading data

Download the datasets on which you want to compute Prediction Depth (PD) and detect spurious features from the sources below.

| Dataset | URL |
| --- | --- |
| NIH | Kaggle Link |
| MIMIC-CXR | PhysioNet portal |
| CheXpert | Stanford ML Group |
| GitHub-COVID | covid-chestxray-dataset GitHub |

Data Preprocessing

PD computation requires running a k-Nearest Neighbor (k-NN) classifier on the intermediate layers of the model, comparing the embedding of the input image against the embeddings of training data at each layer. The layer-wise embeddings that serve as training data for the k-NN are computed from a subset of the training data, given below. The subset must contain an equal number of positives and negatives.

| Link | Description |
| --- | --- |
| nih_full.csv | NIH full data (see Fig-8) |
| nih_subset.csv | NIH (subset): train embeddings for k-NN PD computations (see Fig-8) |
| chex_mimic_full.csv | Chex-MIMIC full data (see Fig-7a) |
| chex_mimic_subset.csv | Chex-MIMIC (subset): train embeddings for k-NN PD computations (see Fig-7a) |
| covid_full.csv | Covid full data (see Fig-7b) |
| covid_subset.csv | Covid (subset): train embeddings for k-NN PD computations (see Fig-7b) |
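
If you want to build a class-balanced subset for your own data, the requirement of equal positives and negatives can be sketched as below. This is only an illustration, not the script used to produce the CSV files above: the helper `balanced_subset`, the `label` and `path` keys, and the toy rows are all assumptions, since the column layout of the repository's CSV files is not described here.

```python
import random


def balanced_subset(rows, label_key="label", per_class=2, seed=0):
    """Sample `per_class` positives and `per_class` negatives from `rows`.

    `rows` is a list of dicts (for example, rows read with csv.DictReader).
    The key name `label` is a hypothetical choice for this sketch.
    """
    rng = random.Random(seed)  # fixed seed so the subset is reproducible
    positives = [r for r in rows if int(r[label_key]) == 1]
    negatives = [r for r in rows if int(r[label_key]) == 0]
    if min(len(positives), len(negatives)) < per_class:
        raise ValueError("not enough samples in one of the classes")
    subset = rng.sample(positives, per_class) + rng.sample(negatives, per_class)
    rng.shuffle(subset)  # avoid having all positives first
    return subset


# Toy rows standing in for a full-data CSV: 4 positives, 8 negatives.
rows = [{"path": f"img_{i}.png", "label": int(i % 3 == 0)} for i in range(12)]
subset = balanced_subset(rows, per_class=3)
```

The returned list can then be written out with `csv.DictWriter` and used in place of a `<dataset_name>_subset.csv` file.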

Training pipeline

  1. See the example config file in ./configs/nih.yaml. Replace data_file with the path to the <dataset_name>_full.csv file above, and change the other variables as appropriate.

Run the command below to train a DenseNet-121 model on the above dataset. Make sure to set the right values for the arguments.

python ./scripts/train_densenet.py --main_dir <path/to/repo>

  2. Analyse the checkpoint using PD plots

python ./scripts/analyse_ckpt.py --main_dir <path/to/repo> --expt_name nih-analysis --ckpt_path </path/to/checkpoint> --csv_train_embs </path/to/dataset_subset.csv> --csv_plot_pd </path/to/test/images/for/PD_plot>

The PD plot is saved as ./output/<expt_name>_pd_plot.svg

  3. Analyse the peaks in the PD plot with Grad-CAM

Follow the steps in the ./notebooks/pd_analysis.ipynb notebook.
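
To make the quantity behind the PD plot concrete, here is a minimal sketch of Prediction Depth for a single input, assuming the definition from Baldock et al. (2021): the earliest layer from which the k-NN prediction at that layer and at every later layer agrees with the model's final prediction. It is not the code in ./scripts/analyse_ckpt.py. The function names, the pure-Python distance computation, and the toy 1-D embeddings are assumptions made for illustration; a practical version would use vectorized k-NN on real layer embeddings.

```python
from collections import Counter


def knn_predict(query, train_embs, train_labels, k=3):
    """Majority label among the k training embeddings closest to `query`."""
    dists = [sum((q - t) ** 2 for q, t in zip(query, emb)) for emb in train_embs]
    nearest = sorted(range(len(train_embs)), key=lambda i: dists[i])[:k]
    return Counter(train_labels[i] for i in nearest).most_common(1)[0][0]


def prediction_depth(query_per_layer, train_per_layer, train_labels, final_pred, k=3):
    """Earliest layer index from which all k-NN probes agree with `final_pred`.

    Returns the number of layers if even the last probe disagrees.
    """
    preds = [
        knn_predict(q, t, train_labels, k)
        for q, t in zip(query_per_layer, train_per_layer)
    ]
    depth = len(preds)
    # Walk backwards from the last layer while the probes still agree.
    for layer in range(len(preds) - 1, -1, -1):
        if preds[layer] != final_pred:
            break
        depth = layer
    return depth


# Toy setup: 3 layers, 4 training points with 1-D embeddings per layer.
train_labels = [0, 0, 1, 1]
train_per_layer = [
    [[0.0], [1.0], [0.1], [1.1]],  # layer 0: classes interleaved
    [[0.0], [0.2], [1.0], [1.2]],  # layer 1: classes separated
    [[0.0], [0.1], [2.0], [2.1]],  # layer 2: classes separated
]
hard_query = [[0.02], [1.1], [2.0]]  # misclassified by the layer-0 probe
easy_query = [[1.09], [1.1], [2.0]]  # correct from the first layer on

hard_pd = prediction_depth(hard_query, train_per_layer, train_labels, final_pred=1, k=1)
easy_pd = prediction_depth(easy_query, train_per_layer, train_labels, final_pred=1, k=1)
```

In this toy setup the "easy" input gets a lower depth than the "hard" one; a PD plot is a histogram of such depths over a set of images.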

Other Experiments

  1. Dominoes experiments (see Fig-5 and Table-1 in the main paper)

     a. Follow the data-generation steps outlined in ./notebooks/domino_generation.ipynb
     b. Train models on the domino datasets by following the steps in ./notebooks/domino_training.ipynb

  2. Not all spurious features hurt generalization! (Toy-data experiments shown in Fig-6)

The ./notebooks/toy_data_expts.ipynb notebook has all the steps: data generation, model training, PD plots, and Grad-CAM/SHAP analysis.

How to Cite

  • TMLR23 Main Paper
@article{
  murali2023beyond,
  title={Beyond Distribution Shift: Spurious Features Through the Lens of Training Dynamics},
  author={Nihal Murali and Aahlad Manas Puli and Ke Yu and Rajesh Ranganath and Kayhan Batmanghelich},
  journal={Transactions on Machine Learning Research},
  issn={2835-8856},
  year={2023},
  url={https://openreview.net/forum?id=Tkvmt9nDmB},
  note={}
}
  • Shortcut paper published in Workshop on Spurious Correlations, Invariance and Stability, ICML 2023
@article{muralishortcut,
  title={Shortcut Learning Through the Lens of Training Dynamics},
  author={Murali, Nihal and Puli, Aahlad Manas and Yu, Ke and Ranganath, Rajesh and Batmanghelich, Kayhan}
}

License and copyright

Licensed under the MIT License

Copyright © Batman Lab, 2023
