diff --git a/foolbox/models/base.py b/foolbox/models/base.py
index 950433a3..315a095a 100644
--- a/foolbox/models/base.py
+++ b/foolbox/models/base.py
@@ -246,7 +246,7 @@ def gradient(self, image, label):
         _, gradient = self.predictions_and_gradient(image, label)
         return gradient
 
-    # TODO: make this an abstract method once support is added to all models
+    @abstractmethod
     def backward(self, gradient, image):
         """Backpropagates the gradient of some loss w.r.t. the logits
         through the network and returns the gradient of that loss w.r.t
diff --git a/foolbox/models/mxnet_gluon.py b/foolbox/models/mxnet_gluon.py
index 7fef7cfd..22823868 100644
--- a/foolbox/models/mxnet_gluon.py
+++ b/foolbox/models/mxnet_gluon.py
@@ -88,3 +88,8 @@ def _loss_fn(self, image, label):
             loss = mx.nd.softmax_cross_entropy(logits, label)
             loss.backward()
         return loss.asnumpy()
+
+    def backward(self, gradient, image):  # pragma: no cover
+        # TODO: backward functionality has not yet been implemented
+        # for MXNetGluonModel
+        raise NotImplementedError
diff --git a/foolbox/models/wrappers.py b/foolbox/models/wrappers.py
index f58a4ad0..a1ffe363 100644
--- a/foolbox/models/wrappers.py
+++ b/foolbox/models/wrappers.py
@@ -154,6 +154,9 @@ def predictions_and_gradient(self, image, label):
     def gradient(self, image, label):
         return self.backward_model.gradient(image, label)
 
+    def backward(self, gradient, image):
+        return self.backward_model.backward(gradient, image)
+
     def __enter__(self):
         assert self.forward_model.__enter__() == self.forward_model
         assert self.backward_model.__enter__() == self.backward_model
diff --git a/foolbox/tests/test_model_wrappers.py b/foolbox/tests/test_model_wrappers.py
index 57a26d4c..91f6e2fa 100644
--- a/foolbox/tests/test_model_wrappers.py
+++ b/foolbox/tests/test_model_wrappers.py
@@ -39,6 +39,8 @@ def test_diff_wrapper(bn_model, bn_image, bn_label):
 
 
 def test_composite_model(gl_bn_model, bn_model, bn_image, bn_label):
+    num_classes = 10
+    test_grad = np.random.rand(num_classes).astype(np.float32)
     model = CompositeModel(gl_bn_model, bn_model)
     with model:
         assert gl_bn_model.num_classes() == model.num_classes()
@@ -48,6 +50,9 @@ def test_composite_model(gl_bn_model, bn_model, bn_image, bn_label):
         assert np.all(
             bn_model.gradient(bn_image, bn_label) ==
             model.gradient(bn_image, bn_label))
+        assert np.all(
+            bn_model.backward(test_grad, bn_image) ==
+            model.backward(test_grad, bn_image))
         assert np.all(
             gl_bn_model.predictions(bn_image) ==
             model.predictions_and_gradient(bn_image, bn_label)[0])
diff --git a/foolbox/tests/test_models.py b/foolbox/tests/test_models.py
index b80bb8e6..90539382 100644
--- a/foolbox/tests/test_models.py
+++ b/foolbox/tests/test_models.py
@@ -44,6 +44,9 @@ def num_classes(self):
 
         def predictions_and_gradient(self, image, label):
             return 'predictions', 'gradient'
 
+        def backward(self, gradient, image):
+            return image
+
     model = TestModel(bounds=(0, 1), channel_axis=1)
     image = np.ones((28, 28, 1), dtype=np.float32)
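
For context, the backward contract introduced here takes the gradient of some loss w.r.t. the model's logits and returns the gradient of that loss w.r.t. the input image, i.e. it performs one step of the chain rule through the network. A minimal sketch of what a concrete implementation could look like, using a hypothetical toy linear model (LinearModel is not part of foolbox; it exists only to illustrate the contract):

import numpy as np


class LinearModel:
    # Hypothetical toy model with logits = W @ x.flatten() + b,
    # used only to illustrate the backward contract.

    def __init__(self, weights, bias):
        self.weights = weights  # shape: (num_classes, num_features)
        self.bias = bias        # shape: (num_classes,)

    def predictions(self, image):
        return self.weights @ image.reshape(-1) + self.bias

    def backward(self, gradient, image):
        # One step of the chain rule: for logits = W @ x + b,
        # d(loss)/d(x) = W.T @ d(loss)/d(logits), reshaped back
        # to the input image's shape and dtype.
        grad_input = self.weights.T @ gradient
        return grad_input.reshape(image.shape).astype(image.dtype)


num_classes, height, width = 10, 28, 28
model = LinearModel(
    weights=np.random.rand(num_classes, height * width).astype(np.float32),
    bias=np.zeros(num_classes, dtype=np.float32),
)
image = np.ones((height, width, 1), dtype=np.float32)
loss_grad = np.random.rand(num_classes).astype(np.float32)

input_grad = model.backward(loss_grad, image)
assert input_grad.shape == image.shape
assert input_grad.dtype == image.dtype

Note that CompositeModel.backward simply delegates to the backward model, which is exactly what the new assertion in test_model_wrappers.py checks: the composite's result must match bn_model.backward(test_grad, bn_image) elementwise.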
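
Because backward is now declared with @abstractmethod in base.py, a subclass that omits it fails at instantiation time rather than when the method is first called. A reduced stand-in sketch of that behavior (DifferentiableModelSketch and IncompleteModel are hypothetical names, and foolbox's actual base class carries more abstract methods than shown here):

from abc import ABC, abstractmethod


class DifferentiableModelSketch(ABC):
    # Stand-in for foolbox's DifferentiableModel, reduced to the one
    # method this change makes abstract.

    @abstractmethod
    def backward(self, gradient, image):
        raise NotImplementedError


class IncompleteModel(DifferentiableModelSketch):
    # Deliberately omits backward.
    pass


try:
    IncompleteModel()  # raises TypeError: can't instantiate abstract class
except TypeError as err:
    print(err)

This is also why the diff has to touch MXNetGluonModel: it overrides backward with an explicit NotImplementedError so the class stays instantiable while backward support is pending.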