Skip to content

Commit

Permalink
added backward method support to CompositeModel wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonas Rauber committed Oct 12, 2018
1 parent a829ff0 commit 704f4b7
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 1 deletion.
2 changes: 1 addition & 1 deletion foolbox/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def gradient(self, image, label):
_, gradient = self.predictions_and_gradient(image, label)
return gradient

# TODO: make this an abstract method once support is added to all models
@abstractmethod
def backward(self, gradient, image):
"""Backpropagates the gradient of some loss w.r.t. the logits
through the network and returns the gradient of that loss w.r.t
Expand Down
5 changes: 5 additions & 0 deletions foolbox/models/mxnet_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,8 @@ def _loss_fn(self, image, label):
loss = mx.nd.softmax_cross_entropy(logits, label)
loss.backward()
return loss.asnumpy()

def backward(self, gradient, image): # pragma: no cover
# TODO: backward functionality has not yet been implemented
# for MXNetGluonModel
raise NotImplementedError
3 changes: 3 additions & 0 deletions foolbox/models/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ def predictions_and_gradient(self, image, label):
def gradient(self, image, label):
return self.backward_model.gradient(image, label)

def backward(self, gradient, image):
return self.backward_model.backward(gradient, image)

def __enter__(self):
assert self.forward_model.__enter__() == self.forward_model
assert self.backward_model.__enter__() == self.backward_model
Expand Down
5 changes: 5 additions & 0 deletions foolbox/tests/test_model_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def test_diff_wrapper(bn_model, bn_image, bn_label):


def test_composite_model(gl_bn_model, bn_model, bn_image, bn_label):
num_classes = 10
test_grad = np.random.rand(num_classes).astype(np.float32)
model = CompositeModel(gl_bn_model, bn_model)
with model:
assert gl_bn_model.num_classes() == model.num_classes()
Expand All @@ -48,6 +50,9 @@ def test_composite_model(gl_bn_model, bn_model, bn_image, bn_label):
assert np.all(
bn_model.gradient(bn_image, bn_label) ==
model.gradient(bn_image, bn_label))
assert np.all(
bn_model.backward(test_grad, bn_image) ==
model.backward(test_grad, bn_image))
assert np.all(
gl_bn_model.predictions(bn_image) ==
model.predictions_and_gradient(bn_image, bn_label)[0])
Expand Down

0 comments on commit 704f4b7

Please sign in to comment.