This repository contains the code needed to evaluate models trained in Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples (Gowal et al., 2020) and in Fixing Data Augmentation to Improve Adversarial Robustness (Rebuffi et al., 2021).
We have released our top-performing models in two formats compatible with JAX and PyTorch. This repository also contains our model definitions.
Models can be downloaded from the links in the tables below. Clean and robust accuracies are measured on the full test set, and robust accuracy is measured using AutoAttack. The following table lists the models from Gowal et al., 2020.
dataset | norm | radius | architecture | extra data | clean | robust | link |
---|---|---|---|---|---|---|---|
CIFAR-10 | ℓ∞ | 8 / 255 | WRN-70-16 | ✓ | 91.10% | 65.88% | jax, pt |
CIFAR-10 | ℓ∞ | 8 / 255 | WRN-28-10 | ✓ | 89.48% | 62.80% | jax, pt |
CIFAR-10 | ℓ∞ | 8 / 255 | WRN-70-16 | ✗ | 85.29% | 57.20% | jax, pt |
CIFAR-10 | ℓ∞ | 8 / 255 | WRN-34-20 | ✗ | 85.64% | 56.86% | jax, pt |
CIFAR-10 | ℓ2 | 128 / 255 | WRN-70-16 | ✓ | 94.74% | 80.53% | jax, pt |
CIFAR-10 | ℓ2 | 128 / 255 | WRN-70-16 | ✗ | 90.90% | 74.50% | jax, pt |
CIFAR-100 | ℓ∞ | 8 / 255 | WRN-70-16 | ✓ | 69.15% | 36.88% | jax, pt |
CIFAR-100 | ℓ∞ | 8 / 255 | WRN-70-16 | ✗ | 60.86% | 30.03% | jax, pt |
MNIST | ℓ∞ | 0.3 | WRN-28-10 | ✗ | 99.26% | 96.34% | jax, pt |
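AutoAttack runs an ensemble of attacks and counts a test point as robust only if it is classified correctly on the clean input and no attack in the ensemble succeeds. The sketch below shows only that bookkeeping, assuming the per-attack predictions have already been computed. It does not implement the attacks, and `clean_and_robust_accuracy` is a hypothetical helper rather than part of this repository.

```python
import numpy as np


def clean_and_robust_accuracy(labels, clean_preds, attack_preds):
    """Return (clean accuracy, worst-case robust accuracy).

    labels: int array of shape [N].
    clean_preds: int array of shape [N], predictions on unperturbed inputs.
    attack_preds: list of int arrays of shape [N], one per attack, holding
      predictions on the adversarial inputs produced by that attack.
    """
    labels = np.asarray(labels)
    clean_correct = np.asarray(clean_preds) == labels
    # A point is robust only if it is correct on the clean input AND
    # under every attack in the ensemble (worst case over attacks).
    robust_correct = clean_correct.copy()
    for preds in attack_preds:
        robust_correct &= np.asarray(preds) == labels
    return clean_correct.mean(), robust_correct.mean()


clean, robust = clean_and_robust_accuracy(
    labels=[0, 1, 2, 3],
    clean_preds=[0, 1, 2, 0],
    attack_preds=[[0, 1, 0, 3], [0, 2, 2, 3]])
print(clean, robust)  # 0.75 0.25
```

Taking the intersection over attacks is why robust accuracy can never exceed clean accuracy in the tables.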
The following table contains the models from Rebuffi et al., 2021.
dataset | norm | radius | architecture | extra data | clean | robust | link |
---|---|---|---|---|---|---|---|
CIFAR-10 | ℓ∞ | 8 / 255 | WRN-106-16 | ✗ | 88.50% | 64.64% | jax, pt |
CIFAR-10 | ℓ∞ | 8 / 255 | WRN-70-16 | ✗ | 88.54% | 64.25% | jax, pt |
CIFAR-10 | ℓ∞ | 8 / 255 | WRN-28-10 | ✗ | 87.33% | 60.75% | jax, pt |
CIFAR-10 | ℓ2 | 128 / 255 | WRN-70-16 | ✗ | 92.41% | 80.42% | jax, pt |
CIFAR-10 | ℓ2 | 128 / 255 | WRN-28-10 | ✗ | 91.79% | 78.80% | jax, pt |
CIFAR-100 | ℓ∞ | 8 / 255 | WRN-70-16 | ✗ | 63.56% | 34.64% | jax, pt |
CIFAR-100 | ℓ∞ | 8 / 255 | WRN-28-10 | ✗ | 62.41% | 32.06% | jax, pt |
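The "norm" and "radius" columns define the threat model: an adversarial input may differ from the clean image by a perturbation whose ℓ∞ norm (largest per-pixel change) or ℓ2 norm (Euclidean length) is at most the radius. The NumPy sketch below projects an arbitrary perturbation onto these balls. It assumes pixel values scaled to [0, 1] (consistent with radii such as 8 / 255) and images shaped [N, H, W, C]. The function names are hypothetical and not taken from this repository.

```python
import numpy as np


def project_linf(delta, eps):
    # Clip every coordinate to [-eps, eps].
    return np.clip(delta, -eps, eps)


def project_l2(delta, eps):
    # Rescale each example's perturbation so its l2 norm is at most eps.
    flat = delta.reshape(delta.shape[0], -1)
    norms = np.linalg.norm(flat, axis=1, keepdims=True)
    factor = np.minimum(1.0, eps / np.maximum(norms, 1e-12))
    return (flat * factor).reshape(delta.shape)


def make_adversarial(x, delta, norm, eps):
    """Return x + delta with delta projected onto the given norm ball."""
    if norm == 'linf':
        delta = project_linf(delta, eps)
    elif norm == 'l2':
        delta = project_l2(delta, eps)
    else:
        raise ValueError(f'unknown norm: {norm}')
    # Keep pixels in the valid range; clipping can only shrink each
    # coordinate of the perturbation, so it stays inside the ball.
    return np.clip(x + delta, 0.0, 1.0)


rng = np.random.default_rng(0)
x = rng.uniform(size=(2, 32, 32, 3))
delta = rng.normal(size=x.shape)
adv_linf = make_adversarial(x, delta, 'linf', 8 / 255)
adv_l2 = make_adversarial(x, delta, 'l2', 128 / 255)
```

An attack such as PGD alternates gradient steps on `delta` with exactly this kind of projection; AutoAttack's internals differ, so treat this only as an illustration of the constraint.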
The following has been tested using Python 3.9.2.
Running `run.sh` will create and activate a virtualenv, install all necessary dependencies and run a test program to ensure that you can import all the modules.
```shell
# Run from the parent directory.
sh adversarial_robustness/run.sh
```
To run the provided code, use this virtualenv:

```shell
source /tmp/adversarial_robustness_venv/bin/activate
```
You may want to edit `requirements.txt` before running `run.sh` if GPU support is needed (e.g., use `jaxlib==0.1.67+cuda111`). See JAX's installation instructions for more details.
Once downloaded, a model can be evaluated by running the `eval.py` script in either the `jax` or `pytorch` folder. E.g.:
```shell
cd jax
python3 eval.py \
  --ckpt=${PATH_TO_CHECKPOINT} --depth=70 --width=16 --dataset=cifar10
```
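The `--depth` and `--width` flags correspond to the two numbers in the table's architecture column: WRN-70-16 is evaluated with `--depth=70 --width=16`, as in the command above. A small hypothetical helper that builds the command from a table row might look like the following. Only `cifar10` is confirmed as a `--dataset` value in this README; check the flag definitions in `eval.py` before using other dataset names. The checkpoint path in the example is made up.

```python
import shlex


def eval_command(architecture, dataset, ckpt_path):
    """Build an eval.py command line for a 'WRN-<depth>-<width>' model.

    Hypothetical helper, not part of this repository.
    """
    prefix, depth, width = architecture.split('-')
    if prefix != 'WRN':
        raise ValueError(f'unexpected architecture: {architecture}')
    args = [
        'python3', 'eval.py',
        f'--ckpt={ckpt_path}',
        f'--depth={int(depth)}',
        f'--width={int(width)}',
        f'--dataset={dataset}',
    ]
    # Quote each argument so paths with spaces survive the shell.
    return ' '.join(shlex.quote(a) for a in args)


print(eval_command('WRN-70-16', 'cifar10', '/tmp/ckpt.pt'))
# python3 eval.py --ckpt=/tmp/ckpt.pt --depth=70 --width=16 --dataset=cifar10
```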
These models are also directly available within RobustBench's model zoo.
We also provide a training pipeline that reproduces results from both publications. This pipeline uses Jaxline and is written using JAX and Haiku. To train a model, modify the configuration in the `get_config()` function of `jax/experiment.py` and issue the following command from within the virtualenv created above:
```shell
cd jax
python3 train.py --config=experiment.py
```
The training pipeline can run with multiple worker machines and multiple devices (either GPU or TPU). See Jaxline for more details.
We do not provide a PyTorch implementation of our training pipeline. However, you may find one on GitHub, e.g., adversarial_robustness_pytorch (by Rahul Rade).
Gowal et al. (2020) use samples extracted from TinyImages-80M. Unfortunately, the official TinyImages-80M dataset has since been withdrawn (due to the presence of offensive images). As such, we cannot provide a download link to our extracted data until we have manually verified that none of the extracted images are offensive. If you want to reproduce our setup, consider the generated datasets below. We are also happy to help, so feel free to reach out to Sven Gowal directly.
Rebuffi et al. (2021) use samples generated by a Denoising Diffusion Probabilistic Model (DDPM; Ho et al., 2020) to improve robustness. The DDPM is trained solely on the original training data and does not use any external data. The following table links to datasets of 1M generated samples for CIFAR-10, CIFAR-100 and SVHN.
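This paragraph does not specify how generated samples are combined with the original training set. One common approach is to draw each training batch partly from the original data and partly from the generated data. The sketch below does this with a hypothetical `original_fraction` hyperparameter, sampling indices with replacement for simplicity. It is an illustration under those assumptions, not the mixing scheme implemented in `jax/experiment.py`.

```python
import numpy as np


def mixed_batches(orig_images, orig_labels, gen_images, gen_labels,
                  batch_size, original_fraction, seed=0):
    """Yield (images, labels) batches mixing original and generated data.

    original_fraction: hypothetical share of each batch drawn from the
      original training set; the rest comes from the generated set.
    """
    rng = np.random.default_rng(seed)
    n_orig = int(round(batch_size * original_fraction))
    n_gen = batch_size - n_orig
    while True:
        # Sample indices with replacement from each source.
        i = rng.integers(0, len(orig_images), size=n_orig)
        j = rng.integers(0, len(gen_images), size=n_gen)
        images = np.concatenate([orig_images[i], gen_images[j]])
        labels = np.concatenate([orig_labels[i], gen_labels[j]])
        # Shuffle so the two sources are interleaved within the batch.
        perm = rng.permutation(batch_size)
        yield images[perm], labels[perm]
```

For example, with `batch_size=10` and `original_fraction=0.3`, every batch contains 3 original and 7 generated examples.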
dataset | model | size | link |
---|---|---|---|
CIFAR-10 | DDPM | 1M | npz |
CIFAR-100 | DDPM | 1M | npz |
SVHN | DDPM | 1M | npz |
To load a dataset, use NumPy. E.g.:

```python
import numpy as np

npzfile = np.load('cifar10_ddpm.npz')
images = npzfile['image']  # Generated images.
labels = npzfile['label']  # Corresponding class labels.
```
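Since each file holds 1M samples, it can be convenient to validate the arrays and optionally keep a random subset. The sketch below is a hypothetical helper (`load_generated` is not part of this repository). It relies only on the `image` and `label` keys shown above and makes no assumption about array shapes or dtypes.

```python
import numpy as np


def load_generated(path, num_samples=None, seed=0):
    """Load generated images and labels from an .npz file.

    Expects the 'image' and 'label' keys used by the released files.
    If num_samples is given, returns a random subset without replacement.
    """
    with np.load(path) as npzfile:
        images = npzfile['image']
        labels = npzfile['label']
    if len(images) != len(labels):
        raise ValueError(
            f'{len(images)} images but {len(labels)} labels in {path}')
    if num_samples is not None:
        rng = np.random.default_rng(seed)
        idx = rng.choice(len(images), size=num_samples, replace=False)
        images, labels = images[idx], labels[idx]
    return images, labels
```

Using the file as a context manager closes it once the arrays have been read into memory.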
If you use this code (or any derived code), data or these models in your work, please cite the relevant accompanying paper:
```
@article{gowal2020uncovering,
  title={Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples},
  author={Gowal, Sven and Qin, Chongli and Uesato, Jonathan and Mann, Timothy and Kohli, Pushmeet},
  journal={arXiv preprint arXiv:2010.03593},
  year={2020},
  url={https://arxiv.org/pdf/2010.03593}
}
```
and/or
```
@article{rebuffi2021fixing,
  title={Fixing Data Augmentation to Improve Adversarial Robustness},
  author={Rebuffi, Sylvestre-Alvise and Gowal, Sven and Calian, Dan A. and Stimberg, Florian and Wiles, Olivia and Mann, Timothy},
  journal={arXiv preprint arXiv:2103.01946},
  year={2021},
  url={https://arxiv.org/pdf/2103.01946}
}
```
This is not an official Google product.