This repository contains the implementation of RankSim (ICML 2022) on the IMDB-WIKI-DIR dataset.
The imbalanced regression framework and LDS+FDS are based on the public repository of Yang et al., ICML 2021.
The blackbox combinatorial solver is based on the public repository of Vlastelica et al., ICLR 2020.
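For orientation, below is a minimal sketch of the RankSim regularizer: for each anchor in a batch, the ranking of its pairwise feature similarities is pushed to match the ranking of its pairwise label similarities, and the non-differentiable ranking step is differentiated with the blackbox interpolation trick of Vlastelica et al. (the interpolation strength is the `--interpolation_lambda` flag). All names here are illustrative, and details such as the per-batch subsampling of unique labels used in the repository are omitted.

```python
import torch
import torch.nn.functional as F

def rank(seq):
    """Rank of each entry of a 1-D tensor (0 = smallest)."""
    return torch.argsort(torch.argsort(seq)).float()

class TrueRanker(torch.autograd.Function):
    """Blackbox-differentiable ranking via the interpolation trick
    of Vlastelica et al. (ICLR 2020)."""
    @staticmethod
    def forward(ctx, seq, lambda_val):
        r = rank(seq)
        ctx.lambda_val = lambda_val
        ctx.save_for_backward(seq, r)
        return r

    @staticmethod
    def backward(ctx, grad_output):
        seq, r = ctx.saved_tensors
        # Rank a lambda-perturbed copy of the input; the rank difference
        # serves as an informative piecewise-linear gradient.
        r_prime = rank(seq + ctx.lambda_val * grad_output)
        return -(r - r_prime) / (ctx.lambda_val + 1e-8), None

def ranksim_loss(features, labels, interpolation_lambda=2.0):
    """Per-anchor MSE between feature-similarity and label-similarity ranks."""
    z = F.normalize(features, dim=1)
    feat_sim = z @ z.t()                               # cosine similarity matrix
    label_sim = -torch.abs(labels.view(-1, 1) - labels.view(1, -1)).float()
    loss = 0.0
    for i in range(feat_sim.size(0)):
        pred = TrueRanker.apply(feat_sim[i], interpolation_lambda)
        target = rank(label_sim[i])                    # constant rank target
        loss = loss + F.mse_loss(pred, target)
    return loss / feat_sim.size(0)
```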
- Download and extract the IMDB faces and WIKI faces using
```bash
python download_imdb_wiki.py
```
- We use the standard train/val/test split file (`imdb_wiki.csv` in folder `./data`) provided by Yang et al. (ICML 2021), which sets up balanced val/test sets. To reproduce the results in the paper, please use this file directly. You can also regenerate it (the sketch after these commands illustrates the balancing idea) using
```bash
python data/create_imdb_wiki.py
python data/preprocess_imdb_wiki.py
```
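For intuition only, here is a rough sketch of how a balanced val/test split can be carved out of raw meta data: draw the same number of images per age for val and test, and keep the remainder, with its natural imbalance, as the training set. The file name and per-age quotas below are hypothetical, not the repository's actual values.

```python
import pandas as pd

df = pd.read_csv('data/imdb_wiki_raw.csv')         # hypothetical raw meta file
VAL_PER_AGE, TEST_PER_AGE = 20, 20                 # assumed per-age quotas

val_parts, test_parts = [], []
for _, group in df.groupby('age'):
    group = group.sample(frac=1, random_state=42)  # shuffle within each age
    val_parts.append(group.iloc[:VAL_PER_AGE])
    test_parts.append(group.iloc[VAL_PER_AGE:VAL_PER_AGE + TEST_PER_AGE])

val, test = pd.concat(val_parts), pd.concat(test_parts)
train = df.drop(val.index.union(test.index))       # train keeps the imbalance
```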
- PyTorch (>= 1.2, tested on 1.6)
- numpy, pandas, scipy, tqdm, matplotlib, PIL, wget
- `train.py`: main training and evaluation script
- `create_imdb_wiki.py`: create IMDB-WIKI raw meta data
- `preprocess_imdb_wiki.py`: create IMDB-WIKI-DIR meta file `imdb_wiki.csv` with balanced val/test set
- `--data_dir`: data directory to place data and meta file
- `--reweight`: cost-sensitive re-weighting scheme to use
- `--loss`: training loss type
- `--regularization_weight`: gamma, weight of the regularization term (default 100.0)
- `--interpolation_lambda`: lambda, interpolation strength parameter (default 2.0)

The sketch below shows how gamma and lambda enter a single training step.
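This builds on the `ranksim_loss` sketch above; that `model` returns both features and predictions is an assumption about the interface, not the repository's exact code.

```python
import torch
import torch.nn.functional as F

def train_step(model, optimizer, images, targets,
               gamma=100.0,   # --regularization_weight
               lam=2.0):      # --interpolation_lambda
    # Assumed interface: the network returns (features, predictions).
    # ranksim_loss: see the RankSim sketch earlier in this README.
    features, preds = model(images)
    task_loss = F.l1_loss(preds.squeeze(-1), targets)
    # RankSim enters as an additive regularizer scaled by gamma; lambda only
    # affects the gradient of the blackbox ranking op, not the forward value.
    loss = task_loss + gamma * ranksim_loss(features, targets, lam)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```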
To use the vanilla model:
```bash
python train.py --batch_size 256 --lr 1e-3
```
To use square-root inverse-frequency re-weighting (SQINV):
```bash
python train.py --batch_size 256 --lr 1e-3 --reweight sqrt_inv
```
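For reference, a minimal sketch of what square-root inverse-frequency weighting computes (illustrative, not the repository's exact implementation):

```python
import numpy as np

def sqrt_inv_weights(labels):
    # Weight each sample by 1 / sqrt(count of its label), then rescale so
    # the weights average to 1 over the training set.
    labels = np.asarray(labels, dtype=int)
    counts = np.bincount(labels).astype(float)
    per_label = 1.0 / np.sqrt(np.clip(counts, 1, None))
    w = per_label[labels]
    return w * len(w) / w.sum()
```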
To use LDS (Yang et al., ICML 2021) with the originally reported hyperparameters:
```bash
python train.py --batch_size 256 --lr 1e-3 --reweight sqrt_inv --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2
```
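Conceptually, LDS smooths the empirical label histogram with a kernel before inverting it, so that rare labels borrow density from their neighbours. A rough sketch using `scipy` (the repository constructs the kernel window explicitly; this is an approximation of the same idea):

```python
import numpy as np
from scipy.ndimage import gaussian_filter1d

def lds_weights(labels, lds_ks=5, lds_sigma=2.0):
    # Smooth the label histogram with a Gaussian (kernel size 5, sigma 2),
    # then weight each sample by the inverse sqrt of the smoothed density.
    labels = np.asarray(labels, dtype=int)
    hist = np.bincount(labels).astype(float)
    smoothed = gaussian_filter1d(hist, sigma=lds_sigma,
                                 truncate=(lds_ks // 2) / lds_sigma)
    w = 1.0 / np.sqrt(np.clip(smoothed[labels], 1e-8, None))
    return w * len(w) / w.sum()
```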
To use FDS (Yang et al., ICML 2021) with the originally reported hyperparameters:
```bash
python train.py --batch_size 256 --lr 1e-3 --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2
```
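FDS operates on features rather than sample weights: per-label-bin feature statistics are smoothed across neighbouring bins and used to re-standardize each sample's features. Below is a compressed sketch of the calibration step only; the full FDS module also maintains running statistics with momentum across epochs, and the tensor shapes here are assumptions.

```python
import torch
import torch.nn.functional as F

def gaussian_kernel(ks=5, sigma=2.0):
    x = torch.arange(ks).float() - ks // 2
    k = torch.exp(-x ** 2 / (2 * sigma ** 2))
    return k / k.sum()

def fds_calibrate(features, bin_idx, running_mean, running_var, ks=5, sigma=2.0):
    """Re-standardize features with statistics smoothed across label bins.

    features:     (batch, feat_dim)
    bin_idx:      (batch,) long tensor, label bin of each sample
    running_mean: (n_bins, feat_dim) per-bin feature means
    running_var:  (n_bins, feat_dim) per-bin feature variances
    """
    k = gaussian_kernel(ks, sigma).view(1, 1, -1)
    pad = ks // 2
    # Smooth each feature dimension's statistics along the label-bin axis.
    smooth = lambda s: F.conv1d(s.t().unsqueeze(1), k, padding=pad).squeeze(1).t()
    mean_s, var_s = smooth(running_mean), smooth(running_var)
    mu, var = running_mean[bin_idx], running_var[bin_idx]
    mu_s, v_s = mean_s[bin_idx], var_s[bin_idx]
    return (features - mu) / (var + 1e-5).sqrt() * (v_s + 1e-5).sqrt() + mu_s
```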
To use RankSim (gamma: 100.0, lambda: 2.0) with the vanilla model:
```bash
python train.py --batch_size 256 --lr 1e-3 --regularization_weight=100.0 --interpolation_lambda=2.0
```
To use RankSim with SQINV re-weighting:
```bash
python train.py --batch_size 256 --lr 1e-3 --reweight sqrt_inv --regularization_weight=100.0 --interpolation_lambda=2.0
```
To use RankSim with Focal-R loss:
```bash
python train.py --loss focal_l1 --batch_size 256 --lr 1e-3 --regularization_weight=100.0 --interpolation_lambda=2.0
```
To use RankSim (gamma: 100.0, lambda: 2.0) with LDS (Gaussian kernel, kernel size: 5, sigma: 2):
```bash
python train.py --batch_size 256 --lr 1e-3 --reweight sqrt_inv --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2 --regularization_weight=100.0 --interpolation_lambda=2.0
```
To use RankSim (gamma: 100.0, lambda: 2.0) with FDS (Gaussian kernel, kernel size: 5, sigma: 2):
```bash
python train.py --batch_size 256 --lr 1e-3 --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2 --regularization_weight=100.0 --interpolation_lambda=2.0
```
To use RankSim (gamma: 100.0, lambda: 2.0) with LDS (Gaussian kernel, kernel size: 5, sigma: 2) and FDS (Gaussian kernel, kernel size: 5, sigma: 2):
```bash
python train.py --batch_size 256 --lr 1e-3 --reweight sqrt_inv --lds --lds_kernel gaussian --lds_ks 5 --lds_sigma 2 --fds --fds_kernel gaussian --fds_ks 5 --fds_sigma 2 --regularization_weight=100.0 --interpolation_lambda=2.0
```
NOTE: We find that other batch sizes and learning rates (e.g., batch size 64 with learning rate 2.5e-4) can sometimes improve performance. You can try different settings by changing the arguments; e.g., to run SQINV + RankSim with batch size 64 and learning rate 2.5e-4:
```bash
python train.py --batch_size 64 --lr 2.5e-4 --reweight sqrt_inv --regularization_weight=100.0 --interpolation_lambda=2.0
```
Instead of training the model yourself, you can evaluate and reproduce our results directly using the pretrained weights linked below:
```bash
python train.py --evaluate [...evaluation model arguments...] --resume <path_to_evaluation_ckpt>
```
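If you want to inspect a downloaded checkpoint before evaluating, something like the following works; the file name is a placeholder and the stored key names (e.g. `state_dict`) are assumptions about the saved format, not confirmed from the repository.

```python
import torch

# Hypothetical inspection snippet: load the checkpoint on CPU and list its keys.
ckpt = torch.load('ranksim_checkpoint.pth.tar', map_location='cpu')
print(ckpt.keys())  # check which keys (e.g. 'state_dict') the checkpoint stores
```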
- Focal-R + LDS + FDS + RankSim, MAE All-shot 7.67 (weights)
- RRT + FDS + RankSim, MAE All-shot 7.35 (best MAE All-shot) (weights)
- SQINV + RankSim, MAE All-shot 7.42 (weights)
- SQINV + LDS + FDS + RankSim, MAE All-shot 7.69, MAE Few-shot 21.43 (best MAE Few-shot) (weights)