diff --git a/foolbox/tests/test_models.py b/foolbox/tests/test_models.py index b80bb8e6..90539382 100644 --- a/foolbox/tests/test_models.py +++ b/foolbox/tests/test_models.py @@ -44,6 +44,9 @@ def num_classes(self): def predictions_and_gradient(self, image, label): return 'predictions', 'gradient' + def backward(self, gradient, image): + return image + model = TestModel(bounds=(0, 1), channel_axis=1) image = np.ones((28, 28, 1), dtype=np.float32)