Merge pull request #213 from jonasrauber/early_stopping
support for early stopping when reaching a certain perturbation size
jonasrauber authored Sep 27, 2018
2 parents 2d3d4bb + dd74319 commit c8af62a
Showing 5 changed files with 156 additions and 18 deletions.
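The feature added by this PR can be exercised roughly as follows; a minimal sketch assuming an already-wrapped foolbox model fmodel and a correctly classified pair (image, label) exist, with 1e-4 as an arbitrary illustrative threshold:

from foolbox import attacks
from foolbox.criteria import Misclassification
from foolbox.distances import MSE

# fmodel, image and label are placeholders for an existing foolbox model
# and a correctly classified input
attack = attacks.FGSM(fmodel, Misclassification(),
                      distance=MSE, threshold=1e-4)

# the attack stops as soon as it finds an adversarial whose MSE to the
# original image is at most the threshold, instead of minimizing it further
adversarial_image = attack(image, label)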
41 changes: 40 additions & 1 deletion foolbox/adversarial.py
@@ -6,9 +6,16 @@
import numpy as np
import numbers

from .distances import Distance
from .distances import MSE


class StopAttack(Exception):
"""Exception thrown to request early stopping of an attack
if a given (optional!) threshold is reached."""
pass


class Adversarial(object):
"""Defines an adversarial that should be found and stores the result.
@@ -29,6 +36,17 @@ class Adversarial(object):
The ground-truth label of the original image.
distance : a :class:`Distance` class
The measure used to quantify similarity between images.
threshold : float or :class:`Distance`
If not None, the attack will stop as soon as the adversarial
perturbation has a size smaller than this threshold. Can be
an instance of the :class:`Distance` class passed to the distance
argument, or a float assumed to have the same unit as
the given distance. If None, the attack will simply minimize
the distance as well as possible. Note that the threshold only
influences early stopping of the attack; the returned adversarial
does not necessarily have a perturbation size smaller than this
threshold; the `reached_threshold()` method can be used to check
if the threshold has been reached.
"""
def __init__(
@@ -38,6 +56,7 @@ def __init__(
original_image,
original_class,
distance=MSE,
threshold=None,
verbose=False):

self.__model = model
@@ -46,6 +65,11 @@ def __init__(
self.__original_image_for_distance = original_image
self.__original_class = original_class
self.__distance = distance

if threshold is not None and not isinstance(threshold, Distance):
threshold = distance(value=threshold)
self.__threshold = threshold

self.verbose = verbose

self.__best_adversarial = None
@@ -59,7 +83,13 @@ def __init__(
self._best_gradient_calls = 0

# check if the original image is already adversarial
self.predictions(original_image)
try:
self.predictions(original_image)
except StopAttack:
# if a threshold is specified and the original input is
# misclassified, this can already cause a StopAttack
# exception
assert self.distance.value == 0.

def _reset(self):
self.__best_adversarial = None
@@ -152,6 +182,12 @@ def normalized_distance(self, image):
image,
bounds=self.bounds())

def reached_threshold(self):
"""Returns True if a threshold is given and the currently
best adversarial distance is smaller than or equal to the threshold."""
return self.__threshold is not None \
and self.__best_distance <= self.__threshold

def __new_adversarial(self, image, predictions, in_bounds):
image = image.copy() # to prevent accidental inplace changes
distance = self.normalized_distance(image)
@@ -167,6 +203,9 @@ def __new_adversarial(self, image, predictions, in_bounds):
self._best_prediction_calls = self._total_prediction_calls
self._best_gradient_calls = self._total_gradient_calls

if self.reached_threshold():
raise StopAttack

return True, distance
return False, distance

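For reference, the threshold and the new reached_threshold() helper can also be used on an Adversarial instance directly; again a sketch, with fmodel, image and label as placeholders:

from foolbox.adversarial import Adversarial
from foolbox.attacks import FGSM
from foolbox.criteria import Misclassification
from foolbox.distances import MSE

# a plain float threshold is interpreted in the unit of the given
# distance; passing MSE(value=1e-4) instead would be equivalent
adv = Adversarial(fmodel, Misclassification(), image, label,
                  distance=MSE, threshold=1e-4)

# StopAttack is raised and caught inside the attack once the
# threshold is reached
FGSM()(adv)

if adv.reached_threshold():
    print('stopped early at distance', adv.distance.value)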
4 changes: 1 addition & 3 deletions foolbox/attacks/adef_attack.py
@@ -6,7 +6,6 @@

from .base import Attack
from .base import call_decorator
from ..criteria import Misclassification


def _transpose_image(image):
@@ -176,8 +175,7 @@ class ADefAttack(Attack):
.. [2]_ https://gitlab.math.ethz.ch/tandrig/ADef/tree/master
"""

def __init__(self, model=None, criterion=Misclassification()):
super(ADefAttack, self).__init__(model=model, criterion=criterion)
def _initialize(self):
self.vector_field = None

@call_decorator
65 changes: 55 additions & 10 deletions foolbox/attacks/base.py
@@ -1,4 +1,5 @@
import warnings
import logging
import functools
import sys
import abc
@@ -10,7 +11,9 @@
ABC = abc.ABCMeta('ABC', (), {})

from ..adversarial import Adversarial
from ..adversarial import StopAttack
from ..criteria import Misclassification
from ..distances import MSE


class Attack(ABC):
@@ -22,12 +25,27 @@ class Attack(ABC):
Parameters
----------
model : :class:`adversarial.Model`
The default model to which the attack is applied if it is not called
with an :class:`Adversarial` instance.
criterion : :class:`adversarial.Criterion`
The default criterion that defines what is adversarial if the attack
is not called with an :class:`Adversarial` instance.
model : a :class:`Model` instance
The model that should be fooled by the adversarial.
Ignored if the attack is called with an :class:`Adversarial` instance.
criterion : a :class:`Criterion` instance
The criterion that determines which images are adversarial.
Ignored if the attack is called with an :class:`Adversarial` instance.
distance : a :class:`Distance` class
The measure used to quantify similarity between images.
Ignored if the attack is called with an :class:`Adversarial` instance.
threshold : float or :class:`Distance`
If not None, the attack will stop as soon as the adversarial
perturbation has a size smaller than this threshold. Can be
an instance of the :class:`Distance` class passed to the distance
argument, or a float assumed to have the same unit as
the given distance. If None, the attack will simply minimize
the distance as well as possible. Note that the threshold only
influences early stopping of the attack; the returned adversarial
does not necessarily have a perturbation size smaller than this
threshold; the `reached_threshold()` method can be used to check
if the threshold has been reached.
Ignored if the attack is called with an :class:`Adversarial` instance.
Notes
-----
@@ -36,9 +54,24 @@
"""

def __init__(self, model=None, criterion=Misclassification()):
def __init__(self,
model=None, criterion=Misclassification(),
distance=MSE, threshold=None):
self._default_model = model
self._default_criterion = criterion
self._default_distance = distance
self._default_threshold = threshold

# to customize the initialization in subclasses, please
# try to override _initialize instead of __init__ if
# possible
self._initialize()

def _initialize(self):
"""Additional initializer that can be overwritten by
subclasses without redefining the full __init__ method
including all arguments and documentation."""
pass

@abstractmethod
def __call__(self, input_or_adv, label=None, unpack=True, **kwargs):
@@ -80,22 +113,34 @@ def wrapper(self, input_or_adv, label=None, unpack=True, **kwargs):
else:
model = self._default_model
criterion = self._default_criterion
distance = self._default_distance
threshold = self._default_threshold
if model is None or criterion is None:
raise ValueError('The attack needs to be initialized'
' with a model and a criterion or it'
' needs to be called with an Adversarial'
' instance.')
a = Adversarial(model, criterion, input_or_adv, label)
a = Adversarial(model, criterion, input_or_adv, label,
distance=distance, threshold=threshold)

assert a is not None

if a.distance.value == 0.:
warnings.warn('Not running the attack because the original input'
' is already misclassified and the adversarial thus'
' has a distance of 0.')
elif a.reached_threshold():
warnings.warn('Not running the attack because the given threshold'
' is already reached')
else:
_ = call_fn(self, a, label=None, unpack=None, **kwargs)
assert _ is None, 'decorated __call__ method must return None'
try:
_ = call_fn(self, a, label=None, unpack=None, **kwargs)
assert _ is None, 'decorated __call__ method must return None'
except StopAttack:
# if a threshold is specified, StopAttack will be thrown
# when the threshold is reached; thus we can do early
# stopping of the attack
logging.info('threshold reached, stopping attack')

if a.image is None:
warnings.warn('{} did not find an adversarial, maybe the model'
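As the base.py change above suggests, attack subclasses now hook into _initialize instead of redefining __init__; a hypothetical MyAttack (not part of this PR) illustrating the pattern:

from foolbox.attacks.base import Attack, call_decorator


class MyAttack(Attack):
    """Hypothetical attack illustrating the _initialize hook."""

    def _initialize(self):
        # called at the end of Attack.__init__, so the model, criterion,
        # distance and threshold arguments need not be repeated here
        self.some_state = None

    @call_decorator
    def __call__(self, input_or_adv, label=None, unpack=True):
        a = input_or_adv
        # perturb a.original_image and evaluate candidates via
        # a.predictions(); a StopAttack raised when the threshold is
        # reached is caught by the call_decorator wrapper
        self.some_state = a.distance.value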
4 changes: 0 additions & 4 deletions foolbox/attacks/boundary_attack.py
@@ -16,7 +16,6 @@
from .base import Attack
from .base import call_decorator
from .blended_noise import BlendedUniformNoiseAttack
from ..criteria import Misclassification

import numpy as np
from numpy.linalg import norm
@@ -52,9 +51,6 @@ class BoundaryAttack(Attack):
"""

def __init__(self, model=None, criterion=Misclassification()):
super(BoundaryAttack, self).__init__(model=model, criterion=criterion)

@call_decorator
def __call__(
self,
60 changes: 60 additions & 0 deletions foolbox/tests/test_attacks.py
@@ -57,3 +57,63 @@ def test_base_attack(model, criterion, image, label):
attack = attacks.FGSM()
with pytest.raises(ValueError):
attack(image, label=wrong_label)


def test_early_stopping(bn_model, bn_criterion, bn_image, bn_label):
attack = attacks.FGSM()

model = bn_model
criterion = bn_criterion
image = bn_image
label = bn_label

wrong_label = label + 1
adv = Adversarial(model, criterion, image, wrong_label)
attack(adv)
assert adv.distance.value == 0
assert not adv.reached_threshold() # because no threshold specified

adv = Adversarial(model, criterion, image, wrong_label, threshold=1e10)
attack(adv)
assert adv.distance.value == 0
assert adv.reached_threshold()

adv = Adversarial(model, criterion, image, label)
attack(adv)
assert adv.distance.value > 0
assert not adv.reached_threshold() # because no threshold specified

c = adv._total_prediction_calls
d = adv.distance.value
large_d = 10 * d
small_d = d / 2

adv = Adversarial(model, criterion, image, label,
threshold=adv._distance(value=large_d))
attack(adv)
assert 0 < adv.distance.value <= large_d
assert adv.reached_threshold()
assert adv._total_prediction_calls < c

adv = Adversarial(model, criterion, image, label,
threshold=large_d)
attack(adv)
assert 0 < adv.distance.value <= large_d
assert adv.reached_threshold()
assert adv._total_prediction_calls < c

adv = Adversarial(model, criterion, image, label,
threshold=small_d)
attack(adv)
assert small_d < adv.distance.value <= large_d
assert not adv.reached_threshold()
assert adv._total_prediction_calls == c
assert adv.distance.value == d

adv = Adversarial(model, criterion, image, label,
threshold=adv._distance(value=large_d))
attack(adv)
assert adv.reached_threshold()
c = adv._total_prediction_calls
attack(adv)
assert adv._total_prediction_calls == c # no new calls
