MPVAE: Disentangled Variational Autoencoder based Multi-Label Classification with Covariance-Aware Multivariate Probit Model
Junwen Bai, Shufeng Kong, Carla Gomes
IJCAI-PRICAI 2020
[paper]
In this paper, we propose the Multivariate Probit based Variational AutoEncoder (MPVAE), which 1) aligns the label embedding subspace with the feature embedding subspace and 2) handles correlations between labels via the classic Multivariate Probit model. MPVAE improves both embedding-space learning and label-correlation encoding. Furthermore, β-VAE brings disentanglement effects and can improve performance over a vanilla VAE.
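As a hedged illustration of the Multivariate Probit idea (not the authors' code): thresholding a correlated Gaussian at zero yields binary labels whose co-occurrence statistics reflect the Gaussian's covariance. The covariance matrix below is invented purely for this sketch.

```python
import numpy as np

# Illustrative sketch: correlated binary labels from a multivariate probit.
# The covariance matrix here is made up for the example.
rng = np.random.default_rng(0)
mean = np.zeros(3)
cov = np.array([[1.0, 0.8, 0.0],
                [0.8, 1.0, 0.0],
                [0.0, 0.0, 1.0]])  # labels 0 and 1 strongly correlated

# Sample latent Gaussian scores and threshold at zero to get binary labels.
z = rng.multivariate_normal(mean, cov, size=10000)
labels = (z > 0).astype(int)

# Labels 0 and 1 agree far more often than independent labels would (~50%).
print((labels[:, 0] == labels[:, 1]).mean())
print((labels[:, 0] == labels[:, 2]).mean())
```

Correlated entries in the covariance make the corresponding labels co-occur; this is the mechanism MPVAE uses to encode label correlations.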
- Python 3.7+
- TensorFlow 1.15.0
- numpy 1.17.3
- scikit-learn 0.22.1
Older versions might work as well.
A PyTorch implementation of MPVAE can be found here.
`git clone` this repo to your local machine.
All datasets can be downloaded from Google Drive or Baidu Drive.
The downloaded datasets are already in the format recognized by the code. They are organized as `npy` files, 4 per dataset: one contains the data entries, and the other 3 are the lists of indices for the train, validation, and test splits. For example, the mirflickr dataset has 4 `npy` files: `mirflickr_data.npy`, `mirflickr_train_idx.npy`, `mirflickr_val_idx.npy`, and `mirflickr_test_idx.npy`.
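To make the file layout concrete, here is a minimal loading sketch. The synthetic arrays below stand in for the contents of the four `npy` files; in practice you would `np.load` each file instead.

```python
import numpy as np

# Sketch of the split convention. The arrays below stand in for the four
# npy files (in practice: data = np.load('mirflickr_data.npy'), etc.).
data = np.arange(20).reshape(10, 2)       # plays the role of mirflickr_data.npy
train_idx = np.array([0, 1, 2, 3, 4, 5])  # plays the role of *_train_idx.npy
val_idx = np.array([6, 7])                # *_val_idx.npy
test_idx = np.array([8, 9])               # *_test_idx.npy

# Each index file simply selects rows of the data matrix.
train_data, val_data, test_data = data[train_idx], data[val_idx], data[test_idx]
print(train_data.shape, val_data.shape, test_data.shape)
```

The index files are plain integer arrays, so the same row-selection pattern works for any of the provided datasets.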
We use mirflickr as the running example here. Detailed descriptions of the FLAGS can be found in `config.py`.
To train the model, use the following script:
`./run_train_mirflickr.sh`
The best validation checkpoint will be written into `run_test_mirflickr.sh` automatically if one sets the flag `write_to_test_sh` to `True` and specifies the path to the test script with the flag `test_sh_path`.
To test the model, use the following script:
`./run_test_mirflickr.sh`
The default hyper-parameters should give reasonably good results.
If you have any questions, feel free to open an issue.
One can further check the scripts under the `scripts` folder, which contains tuned hyperparameters for most datasets.
If you find our paper interesting, or use the datasets we collected, please cite our paper:
@inproceedings{bai2021disentangled,
title={Disentangled variational autoencoder based multi-label classification with covariance-aware multivariate probit model},
author={Bai, Junwen and Kong, Shufeng and Gomes, Carla},
booktitle={Proceedings of the Twenty-Ninth International Conference on International Joint Conferences on Artificial Intelligence},
pages={4313--4321},
year={2021}
}