diff --git a/foolbox/adversarial.py b/foolbox/adversarial.py
index aba527fa..8167f64d 100644
--- a/foolbox/adversarial.py
+++ b/foolbox/adversarial.py
@@ -6,9 +6,16 @@
 import numpy as np
 import numbers

+from .distances import Distance
 from .distances import MSE


+class StopAttack(Exception):
+    """Exception thrown to request early stopping of an attack
+    if a given (optional!) threshold is reached."""
+    pass
+
+
 class Adversarial(object):
     """Defines an adversarial that should be found and stores the result.

@@ -29,6 +36,17 @@ class Adversarial(object):
         The ground-truth label of the original image.
     distance : a :class:`Distance` class
         The measure used to quantify similarity between images.
+    threshold : float or :class:`Distance`
+        If not None, the attack will stop as soon as the adversarial
+        perturbation has a size smaller than this threshold. Can be
+        an instance of the :class:`Distance` class passed to the distance
+        argument, or a float assumed to have the same unit as the
+        given distance. If None, the attack will simply minimize
+        the distance as well as possible. Note that the threshold only
+        influences early stopping of the attack; the returned adversarial
+        does not necessarily have a smaller perturbation size than this
+        threshold; the `reached_threshold()` method can be used to check
+        if the threshold has been reached.

     """

@@ -38,6 +56,7 @@ def __init__(
             original_image,
             original_class,
             distance=MSE,
+            threshold=None,
             verbose=False):

         self.__model = model
@@ -46,6 +65,11 @@ def __init__(
         self.__original_image_for_distance = original_image
         self.__original_class = original_class
         self.__distance = distance
+
+        if threshold is not None and not isinstance(threshold, Distance):
+            threshold = distance(value=threshold)
+        self.__threshold = threshold
+
         self.verbose = verbose

         self.__best_adversarial = None
@@ -59,7 +83,13 @@ def __init__(
         self._best_gradient_calls = 0

         # check if the original image is already adversarial
-        self.predictions(original_image)
+        try:
+            self.predictions(original_image)
+        except StopAttack:
+            # if a threshold is specified and the original input is
+            # misclassified, this can already cause a StopAttack
+            # exception
+            assert self.distance.value == 0.

     def _reset(self):
         self.__best_adversarial = None
@@ -152,6 +182,12 @@ def normalized_distance(self, image):
             image,
             bounds=self.bounds())

+    def reached_threshold(self):
+        """Returns True if a threshold is given and the currently
+        best adversarial distance is smaller than the threshold."""
+        return self.__threshold is not None \
+            and self.__best_distance <= self.__threshold
+
     def __new_adversarial(self, image, predictions, in_bounds):
         image = image.copy()  # to prevent accidental inplace changes
         distance = self.normalized_distance(image)
@@ -167,6 +203,9 @@ def __new_adversarial(self, image, predictions, in_bounds):
             self._best_prediction_calls = self._total_prediction_calls
             self._best_gradient_calls = self._total_gradient_calls

+            if self.reached_threshold():
+                raise StopAttack
+
             return True, distance
         return False, distance
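The added threshold argument and the new reached_threshold() helper can be used directly on an Adversarial instance. The following is a minimal sketch of that usage; fmodel, image and label stand for a Foolbox-wrapped model and a correctly classified example, are assumed to exist and are not part of this diff:

    from foolbox.adversarial import Adversarial
    from foolbox.criteria import Misclassification
    from foolbox.distances import MSE

    # threshold given as a plain float, interpreted in units of the
    # distance measure (MSE here); fmodel, image, label are assumed
    adv = Adversarial(fmodel, Misclassification(), image, label,
                      distance=MSE, threshold=1e-5)

    # equivalently, the threshold can be passed as a Distance instance
    adv = Adversarial(fmodel, Misclassification(), image, label,
                      distance=MSE, threshold=MSE(value=1e-5))

    # after an attack has been run on adv, check whether the
    # early-stopping threshold was actually reached
    if adv.reached_threshold():
        print('perturbation size is below the requested threshold')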
diff --git a/foolbox/attacks/adef_attack.py b/foolbox/attacks/adef_attack.py
index 71de9e92..04481ac6 100644
--- a/foolbox/attacks/adef_attack.py
+++ b/foolbox/attacks/adef_attack.py
@@ -6,7 +6,6 @@

 from .base import Attack
 from .base import call_decorator
-from ..criteria import Misclassification


 def _transpose_image(image):
@@ -176,8 +175,7 @@ class ADefAttack(Attack):
     .. [2]_ https://gitlab.math.ethz.ch/tandrig/ADef/tree/master

     """

-    def __init__(self, model=None, criterion=Misclassification()):
-        super(ADefAttack, self).__init__(model=model, criterion=criterion)
+    def _initialize(self):
         self.vector_field = None

     @call_decorator
diff --git a/foolbox/attacks/base.py b/foolbox/attacks/base.py
index 0d7f146b..5e0b77c5 100644
--- a/foolbox/attacks/base.py
+++ b/foolbox/attacks/base.py
@@ -1,4 +1,5 @@
 import warnings
+import logging
 import functools
 import sys
 import abc
@@ -10,7 +11,9 @@
     ABC = abc.ABCMeta('ABC', (), {})

 from ..adversarial import Adversarial
+from ..adversarial import StopAttack
 from ..criteria import Misclassification
+from ..distances import MSE


 class Attack(ABC):
@@ -22,12 +25,27 @@ class Attack(ABC):

     Parameters
     ----------
-    model : :class:`adversarial.Model`
-        The default model to which the attack is applied if it is not called
-        with an :class:`Adversarial` instance.
-    criterion : :class:`adversarial.Criterion`
-        The default criterion that defines what is adversarial if the attack
-        is not called with an :class:`Adversarial` instance.
+    model : a :class:`Model` instance
+        The model that should be fooled by the adversarial.
+        Ignored if the attack is called with an :class:`Adversarial` instance.
+    criterion : a :class:`Criterion` instance
+        The criterion that determines which images are adversarial.
+        Ignored if the attack is called with an :class:`Adversarial` instance.
+    distance : a :class:`Distance` class
+        The measure used to quantify similarity between images.
+        Ignored if the attack is called with an :class:`Adversarial` instance.
+    threshold : float or :class:`Distance`
+        If not None, the attack will stop as soon as the adversarial
+        perturbation has a size smaller than this threshold. Can be
+        an instance of the :class:`Distance` class passed to the distance
+        argument, or a float assumed to have the same unit as the
+        given distance. If None, the attack will simply minimize
+        the distance as well as possible. Note that the threshold only
+        influences early stopping of the attack; the returned adversarial
+        does not necessarily have a smaller perturbation size than this
+        threshold; the `reached_threshold()` method can be used to check
+        if the threshold has been reached.
+        Ignored if the attack is called with an :class:`Adversarial` instance.

     Notes
     -----
@@ -36,9 +54,24 @@

     """

-    def __init__(self, model=None, criterion=Misclassification()):
+    def __init__(self,
+                 model=None, criterion=Misclassification(),
+                 distance=MSE, threshold=None):
         self._default_model = model
         self._default_criterion = criterion
+        self._default_distance = distance
+        self._default_threshold = threshold
+
+        # to customize the initialization in subclasses, please
+        # try to overwrite _initialize instead of __init__ if
+        # possible
+        self._initialize()
+
+    def _initialize(self):
+        """Additional initializer that can be overwritten by
+        subclasses without redefining the full __init__ method
+        including all arguments and documentation."""
+        pass

     @abstractmethod
     def __call__(self, input_or_adv, label=None, unpack=True, **kwargs):
@@ -80,12 +113,15 @@ def wrapper(self, input_or_adv, label=None, unpack=True, **kwargs):
         else:
             model = self._default_model
             criterion = self._default_criterion
+            distance = self._default_distance
+            threshold = self._default_threshold

             if model is None or criterion is None:
                 raise ValueError('The attack needs to be initialized'
                                  ' with a model and a criterion or it'
                                  ' needs to be called with an Adversarial'
                                  ' instance.')
-            a = Adversarial(model, criterion, input_or_adv, label)
+            a = Adversarial(model, criterion, input_or_adv, label,
+                            distance=distance, threshold=threshold)

         assert a is not None
@@ -93,9 +129,18 @@ def wrapper(self, input_or_adv, label=None, unpack=True, **kwargs):
             warnings.warn('Not running the attack because the original input'
                           ' is already misclassified and the adversarial thus'
                           ' has a distance of 0.')
+        elif a.reached_threshold():
+            warnings.warn('Not running the attack because the given threshold'
+                          ' has already been reached')
         else:
-            _ = call_fn(self, a, label=None, unpack=None, **kwargs)
-            assert _ is None, 'decorated __call__ method must return None'
+            try:
+                _ = call_fn(self, a, label=None, unpack=None, **kwargs)
+                assert _ is None, 'decorated __call__ method must return None'
+            except StopAttack:
+                # if a threshold is specified, StopAttack will be thrown
+                # when the threshold is reached; this allows early
+                # stopping of the attack
+                logging.info('threshold reached, stopping attack')

         if a.image is None:
             warnings.warn('{} did not find an adversarial, maybe the model'
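With the new base-class arguments, a threshold can also be set once when an attack is instantiated; it is then forwarded to the Adversarial instance created inside the decorated __call__, so every call can stop early. A sketch of the intended usage (fmodel, image and label are again assumed to exist and are not part of this diff):

    from foolbox.attacks import FGSM
    from foolbox.criteria import Misclassification
    from foolbox.distances import MSE

    # distance and threshold become per-attack defaults that are passed
    # on to the Adversarial object created by the call_decorator wrapper
    attack = FGSM(fmodel, Misclassification(), distance=MSE, threshold=1e-4)

    # stops as soon as an adversarial with MSE <= 1e-4 has been found;
    # without a threshold, the attack minimizes the distance as well as possible
    adversarial_image = attack(image, label=label)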
diff --git a/foolbox/attacks/boundary_attack.py b/foolbox/attacks/boundary_attack.py
index 95f03914..a1ba7f0b 100644
--- a/foolbox/attacks/boundary_attack.py
+++ b/foolbox/attacks/boundary_attack.py
@@ -16,7 +16,6 @@
 from .base import Attack
 from .base import call_decorator
 from .blended_noise import BlendedUniformNoiseAttack
-from ..criteria import Misclassification

 import numpy as np
 from numpy.linalg import norm
@@ -52,9 +51,6 @@ class BoundaryAttack(Attack):

     """

-    def __init__(self, model=None, criterion=Misclassification()):
-        super(BoundaryAttack, self).__init__(model=model, criterion=criterion)
-
     @call_decorator
     def __call__(
             self,
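Attacks such as BoundaryAttack that only overrode __init__ to forward model and criterion can now drop that boilerplate entirely; subclasses that do need extra state are expected to override _initialize instead, as ADefAttack does above. A sketch of the pattern for a hypothetical custom attack (MyAttack is illustrative and not part of foolbox):

    from foolbox.attacks.base import Attack, call_decorator

    class MyAttack(Attack):

        def _initialize(self):
            # called at the end of Attack.__init__; no need to redefine
            # __init__ with its model/criterion/distance/threshold arguments
            self.some_state = None

        @call_decorator
        def __call__(self, input_or_adv, label=None, unpack=True):
            a = input_or_adv
            # ... perturb a.original_image and call a.predictions(...) here;
            # a raised StopAttack is caught by the call_decorator wrapper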
diff --git a/foolbox/tests/test_attacks.py b/foolbox/tests/test_attacks.py
index cc427fea..23b1ee50 100644
--- a/foolbox/tests/test_attacks.py
+++ b/foolbox/tests/test_attacks.py
@@ -57,3 +57,63 @@ def test_base_attack(model, criterion, image, label):
     attack = attacks.FGSM()
     with pytest.raises(ValueError):
         attack(image, label=wrong_label)
+
+
+def test_early_stopping(bn_model, bn_criterion, bn_image, bn_label):
+    attack = attacks.FGSM()
+
+    model = bn_model
+    criterion = bn_criterion
+    image = bn_image
+    label = bn_label
+
+    wrong_label = label + 1
+    adv = Adversarial(model, criterion, image, wrong_label)
+    attack(adv)
+    assert adv.distance.value == 0
+    assert not adv.reached_threshold()  # because no threshold specified
+
+    adv = Adversarial(model, criterion, image, wrong_label, threshold=1e10)
+    attack(adv)
+    assert adv.distance.value == 0
+    assert adv.reached_threshold()
+
+    adv = Adversarial(model, criterion, image, label)
+    attack(adv)
+    assert adv.distance.value > 0
+    assert not adv.reached_threshold()  # because no threshold specified
+
+    c = adv._total_prediction_calls
+    d = adv.distance.value
+    large_d = 10 * d
+    small_d = d / 2
+
+    adv = Adversarial(model, criterion, image, label,
+                      threshold=adv._distance(value=large_d))
+    attack(adv)
+    assert 0 < adv.distance.value <= large_d
+    assert adv.reached_threshold()
+    assert adv._total_prediction_calls < c
+
+    adv = Adversarial(model, criterion, image, label,
+                      threshold=large_d)
+    attack(adv)
+    assert 0 < adv.distance.value <= large_d
+    assert adv.reached_threshold()
+    assert adv._total_prediction_calls < c
+
+    adv = Adversarial(model, criterion, image, label,
+                      threshold=small_d)
+    attack(adv)
+    assert small_d < adv.distance.value <= large_d
+    assert not adv.reached_threshold()
+    assert adv._total_prediction_calls == c
+    assert adv.distance.value == d
+
+    adv = Adversarial(model, criterion, image, label,
+                      threshold=adv._distance(value=large_d))
+    attack(adv)
+    assert adv.reached_threshold()
+    c = adv._total_prediction_calls
+    attack(adv)
+    assert adv._total_prediction_calls == c  # no new calls
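The tests above drive attacks through the decorated __call__, but the same mechanism also works when an Adversarial object is fed candidate images by hand: predictions() raises StopAttack as soon as the best distance drops below the threshold. A sketch, assuming adv was created with a threshold as in the earlier example and candidates is some iterable of perturbed images (neither is part of this diff):

    from foolbox.adversarial import StopAttack

    try:
        for candidate in candidates:
            adv.predictions(candidate)
    except StopAttack:
        # raised from within Adversarial once the threshold is reached
        pass

    if adv.reached_threshold():
        print('early stopping after', adv._total_prediction_calls,
              'forward passes')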