Skip to content

Commit

Permalink
Merge pull request #232 from bveliqi/model_zoo_v2
Browse files Browse the repository at this point in the history
Added model zoo
  • Loading branch information
jonasrauber authored Nov 16, 2018
2 parents 6a7628e + 219add5 commit fe0307f
Show file tree
Hide file tree
Showing 16 changed files with 464 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "foolbox/tests/data/model_repo"]
path = foolbox/tests/data/model_repo
url = https://github.com/bveliqi/foolbox-zoo-dummy.git
2 changes: 2 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ You might want to have a look at our recently announced `Robust Vision Benchmark
user/tutorial
user/examples
user/adversarial
user/zoo
user/development
user/faq

Expand All @@ -44,6 +45,7 @@ You might want to have a look at our recently announced `Robust Vision Benchmark

modules/models
modules/criteria
modules/zoo
modules/distances
modules/attacks
modules/adversarial
Expand Down
17 changes: 17 additions & 0 deletions docs/modules/zoo.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
:mod:`foolbox.zoo`
=================================

.. automodule:: foolbox.zoo


Get Model
----------------

.. autofunction:: get_model



Fetch Weights
----------------

.. autofunction:: fetch_weights
26 changes: 26 additions & 0 deletions docs/user/zoo.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
=========
Model Zoo
=========

This tutorial will show you how the model zoo can be used to run your attack against a robust model.

Downloading a model
===================

For this tutorial, we will download the `Madry et al. CIFAR10 challenge` robust model implemented in `TensorFlow`
and run a `FGSM (GradienSignAttack)` against it.

.. code-block:: python3
from foolbox import zoo
# download the model
model = zoo.get_model(url="https://github.com/bethgelab/cifar10_challenge.git")
# read image and label
image = ...
label = ...
# apply attack on source image
attack = foolbox.attacks.FGSM(model)
adversarial = attack(image[:,:,::-1], label)
1 change: 1 addition & 0 deletions foolbox/tests/data/model_repo
Submodule model_repo added at e1932b
93 changes: 93 additions & 0 deletions foolbox/tests/test_fetch_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from foolbox.zoo import fetch_weights
from foolbox.zoo.common import path_exists, home_directory_path, sha256_hash
from foolbox.zoo.weights_fetcher import FOLDER

import os
import pytest
import shutil

import responses
import io
import zipfile


@responses.activate
def test_fetch_weights_unzipped():
weights_uri = 'http://localhost:8080/weights.zip'
raw_body = _random_body(zipped=False)

# mock server
responses.add(responses.GET, weights_uri,
body=raw_body, status=200, stream=True)

expected_path = _expected_path(weights_uri)

if path_exists(expected_path):
shutil.rmtree(expected_path) # make sure path does not exist already

file_path = fetch_weights(weights_uri)

exists_locally = path_exists(expected_path)
assert exists_locally
assert expected_path in file_path


@responses.activate
def test_fetch_weights_zipped():
weights_uri = 'http://localhost:8080/weights.zip'

# mock server
raw_body = _random_body(zipped=True)
responses.add(responses.GET, weights_uri,
body=raw_body, status=200, stream=True,
content_type='application/zip',
headers={'Accept-Encoding': 'gzip, deflate'})

expected_path = _expected_path(weights_uri)

if path_exists(expected_path):
shutil.rmtree(expected_path) # make sure path does not exist already

file_path = fetch_weights(weights_uri, unzip=True)

exists_locally = path_exists(expected_path)
assert exists_locally
assert expected_path in file_path


@responses.activate
def test_fetch_weights_returns_404():
weights_uri = 'http://down:8080/weights.zip'

# mock server
responses.add(responses.GET, weights_uri, status=404)

expected_path = _expected_path(weights_uri)

if path_exists(expected_path):
shutil.rmtree(expected_path) # make sure path does not exist already

with pytest.raises(RuntimeError):
fetch_weights(weights_uri, unzip=False)


def test_no_uri_given():
assert fetch_weights(None) is None


def _random_body(zipped=False):
if zipped:
data = io.BytesIO()
with zipfile.ZipFile(data, mode='w') as z:
z.writestr('test.txt', 'no real weights in here :)')
data.seek(0)
return data.getvalue()
else:
raw_body = os.urandom(1024)
return raw_body


def _expected_path(weights_uri):
hash_digest = sha256_hash(weights_uri)
local_path = home_directory_path(FOLDER, hash_digest)
return local_path
32 changes: 32 additions & 0 deletions foolbox/tests/test_git_cloner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from foolbox.zoo import git_cloner
import os
import hashlib
import pytest
from foolbox.zoo.git_cloner import GitCloneError


def test_git_clone():
# given
git_uri = "https://github.com/bethgelab/convex_adversarial.git"
expected_path = _expected_path(git_uri)

# when
path = git_cloner.clone(git_uri)

# then
assert path == expected_path


def test_wrong_git_uri():
git_uri = "[email protected]:bethgelab/non-existing-repo.git"
with pytest.raises(GitCloneError):
git_cloner.clone(git_uri)


def _expected_path(git_uri):
home = os.path.expanduser('~')
m = hashlib.sha256()
m.update(git_uri.encode())
hash = m.hexdigest()
expected_path = os.path.join(home, '.foolbox_zoo', hash)
return expected_path
52 changes: 52 additions & 0 deletions foolbox/tests/test_model_zoo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from foolbox import zoo
import numpy as np
import foolbox
import sys
import pytest
from foolbox.zoo.model_loader import ModelLoader
from os.path import join, dirname


@pytest.fixture(autouse=True)
def unload_foolbox_model_module():
# reload foolbox_model from scratch for every run
# to ensure atomic tests without side effects
module_names = ['foolbox_model', 'model']
for module_name in module_names:
if module_name in sys.modules:
del sys.modules[module_name]


test_data = [
# private repo won't work on travis
# ('https://github.com/bethgelab/AnalysisBySynthesis.git', (1, 28, 28)),
# ('https://github.com/bethgelab/convex_adversarial.git', (1, 28, 28)),
# ('https://github.com/bethgelab/mnist_challenge.git', 784)
(join('file://', dirname(__file__), 'data/model_repo'), (3, 224, 224))
]


@pytest.mark.parametrize("url, dim", test_data)
def test_loading_model(url, dim):
# download model
model = zoo.get_model(url)

# create a dummy image
x = np.zeros(dim, dtype=np.float32)
x[:] = np.random.randn(*x.shape)

# run the model
logits = model.predictions(x)
probabilities = foolbox.utils.softmax(logits)
predicted_class = np.argmax(logits)

# sanity check
assert predicted_class >= 0
assert np.sum(probabilities) >= 0.9999

# TODO: delete fmodel


def test_non_default_module_throws_error():
with pytest.raises(RuntimeError):
ModelLoader.get(key='other')
2 changes: 2 additions & 0 deletions foolbox/zoo/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .zoo import get_model # noqa: F401
from .weights_fetcher import fetch_weights # noqa: F401
18 changes: 18 additions & 0 deletions foolbox/zoo/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import hashlib
import os


def sha256_hash(git_uri):
m = hashlib.sha256()
m.update(git_uri.encode())
return m.hexdigest()


def home_directory_path(folder, hash_digest):
# does this work on all operating systems?
home = os.path.expanduser('~')
return os.path.join(home, folder, hash_digest)


def path_exists(local_path):
return os.path.exists(local_path)
39 changes: 39 additions & 0 deletions foolbox/zoo/git_cloner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from git import Repo
import logging
from .common import sha256_hash, home_directory_path, path_exists

FOLDER = '.foolbox_zoo'


class GitCloneError(RuntimeError):
pass


def clone(git_uri):
"""
Clone a remote git repository to a local path.
:param git_uri: the URI to the git repository to be cloned
:return: the generated local path where the repository has been cloned to
"""
hash_digest = sha256_hash(git_uri)
local_path = home_directory_path(FOLDER, hash_digest)
exists_locally = path_exists(local_path)

if not exists_locally:
_clone_repo(git_uri, local_path)
else:
logging.info( # pragma: no cover
"Git repository already exists locally.") # pragma: no cover

return local_path


def _clone_repo(git_uri, local_path):
logging.info("Cloning repo %s to %s", git_uri, local_path)
try:
Repo.clone_from(git_uri, local_path)
except Exception as e:
logging.exception("Failed to clone repository", e)
raise GitCloneError("Failed to clone repository")
logging.info("Cloned repo successfully.")
45 changes: 45 additions & 0 deletions foolbox/zoo/model_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import sys
import importlib

import abc
abstractmethod = abc.abstractmethod
if sys.version_info >= (3, 4):
ABC = abc.ABC
else: # pragma: no cover
ABC = abc.ABCMeta('ABC', (), {})


class ModelLoader(ABC):

@abstractmethod
def load(self, path):
"""
Load a model from a local path, to which a git repository
has been previously cloned to.
:param path: the path to the local repository containing the code
:return: a foolbox-wrapped model
"""
pass # pragma: no cover

@staticmethod
def get(key='default'):
if key is 'default':
return DefaultLoader()
else:
raise RuntimeError("No model loader for: %s".format(key))

@staticmethod
def _import_module(path, module_name='foolbox_model'):
sys.path.insert(0, path)
module = importlib.import_module(module_name)
print('imported module: {}'.format(module))
return module


class DefaultLoader(ModelLoader):

def load(self, path, module_name='foolbox_model'):
module = ModelLoader._import_module(path, module_name)
model = module.create()
return model
Loading

0 comments on commit fe0307f

Please sign in to comment.