-
-
Notifications
You must be signed in to change notification settings - Fork 428
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #232 from bveliqi/model_zoo_v2
Added model zoo
- Loading branch information
Showing
16 changed files
with
464 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
[submodule "foolbox/tests/data/model_repo"] | ||
path = foolbox/tests/data/model_repo | ||
url = https://github.com/bveliqi/foolbox-zoo-dummy.git |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
:mod:`foolbox.zoo` | ||
================================= | ||
|
||
.. automodule:: foolbox.zoo | ||
|
||
|
||
Get Model | ||
---------------- | ||
|
||
.. autofunction:: get_model | ||
|
||
|
||
|
||
Fetch Weights | ||
---------------- | ||
|
||
.. autofunction:: fetch_weights |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
========= | ||
Model Zoo | ||
========= | ||
|
||
This tutorial will show you how the model zoo can be used to run your attack against a robust model. | ||
|
||
Downloading a model | ||
=================== | ||
|
||
For this tutorial, we will download the `Madry et al. CIFAR10 challenge` robust model implemented in `TensorFlow` | ||
and run a `FGSM (GradienSignAttack)` against it. | ||
|
||
.. code-block:: python3 | ||
from foolbox import zoo | ||
# download the model | ||
model = zoo.get_model(url="https://github.com/bethgelab/cifar10_challenge.git") | ||
# read image and label | ||
image = ... | ||
label = ... | ||
# apply attack on source image | ||
attack = foolbox.attacks.FGSM(model) | ||
adversarial = attack(image[:,:,::-1], label) |
Submodule model_repo
added at
e1932b
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from foolbox.zoo import fetch_weights | ||
from foolbox.zoo.common import path_exists, home_directory_path, sha256_hash | ||
from foolbox.zoo.weights_fetcher import FOLDER | ||
|
||
import os | ||
import pytest | ||
import shutil | ||
|
||
import responses | ||
import io | ||
import zipfile | ||
|
||
|
||
@responses.activate | ||
def test_fetch_weights_unzipped(): | ||
weights_uri = 'http://localhost:8080/weights.zip' | ||
raw_body = _random_body(zipped=False) | ||
|
||
# mock server | ||
responses.add(responses.GET, weights_uri, | ||
body=raw_body, status=200, stream=True) | ||
|
||
expected_path = _expected_path(weights_uri) | ||
|
||
if path_exists(expected_path): | ||
shutil.rmtree(expected_path) # make sure path does not exist already | ||
|
||
file_path = fetch_weights(weights_uri) | ||
|
||
exists_locally = path_exists(expected_path) | ||
assert exists_locally | ||
assert expected_path in file_path | ||
|
||
|
||
@responses.activate | ||
def test_fetch_weights_zipped(): | ||
weights_uri = 'http://localhost:8080/weights.zip' | ||
|
||
# mock server | ||
raw_body = _random_body(zipped=True) | ||
responses.add(responses.GET, weights_uri, | ||
body=raw_body, status=200, stream=True, | ||
content_type='application/zip', | ||
headers={'Accept-Encoding': 'gzip, deflate'}) | ||
|
||
expected_path = _expected_path(weights_uri) | ||
|
||
if path_exists(expected_path): | ||
shutil.rmtree(expected_path) # make sure path does not exist already | ||
|
||
file_path = fetch_weights(weights_uri, unzip=True) | ||
|
||
exists_locally = path_exists(expected_path) | ||
assert exists_locally | ||
assert expected_path in file_path | ||
|
||
|
||
@responses.activate | ||
def test_fetch_weights_returns_404(): | ||
weights_uri = 'http://down:8080/weights.zip' | ||
|
||
# mock server | ||
responses.add(responses.GET, weights_uri, status=404) | ||
|
||
expected_path = _expected_path(weights_uri) | ||
|
||
if path_exists(expected_path): | ||
shutil.rmtree(expected_path) # make sure path does not exist already | ||
|
||
with pytest.raises(RuntimeError): | ||
fetch_weights(weights_uri, unzip=False) | ||
|
||
|
||
def test_no_uri_given(): | ||
assert fetch_weights(None) is None | ||
|
||
|
||
def _random_body(zipped=False): | ||
if zipped: | ||
data = io.BytesIO() | ||
with zipfile.ZipFile(data, mode='w') as z: | ||
z.writestr('test.txt', 'no real weights in here :)') | ||
data.seek(0) | ||
return data.getvalue() | ||
else: | ||
raw_body = os.urandom(1024) | ||
return raw_body | ||
|
||
|
||
def _expected_path(weights_uri): | ||
hash_digest = sha256_hash(weights_uri) | ||
local_path = home_directory_path(FOLDER, hash_digest) | ||
return local_path |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from foolbox.zoo import git_cloner | ||
import os | ||
import hashlib | ||
import pytest | ||
from foolbox.zoo.git_cloner import GitCloneError | ||
|
||
|
||
def test_git_clone(): | ||
# given | ||
git_uri = "https://github.com/bethgelab/convex_adversarial.git" | ||
expected_path = _expected_path(git_uri) | ||
|
||
# when | ||
path = git_cloner.clone(git_uri) | ||
|
||
# then | ||
assert path == expected_path | ||
|
||
|
||
def test_wrong_git_uri(): | ||
git_uri = "[email protected]:bethgelab/non-existing-repo.git" | ||
with pytest.raises(GitCloneError): | ||
git_cloner.clone(git_uri) | ||
|
||
|
||
def _expected_path(git_uri): | ||
home = os.path.expanduser('~') | ||
m = hashlib.sha256() | ||
m.update(git_uri.encode()) | ||
hash = m.hexdigest() | ||
expected_path = os.path.join(home, '.foolbox_zoo', hash) | ||
return expected_path |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from foolbox import zoo | ||
import numpy as np | ||
import foolbox | ||
import sys | ||
import pytest | ||
from foolbox.zoo.model_loader import ModelLoader | ||
from os.path import join, dirname | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def unload_foolbox_model_module(): | ||
# reload foolbox_model from scratch for every run | ||
# to ensure atomic tests without side effects | ||
module_names = ['foolbox_model', 'model'] | ||
for module_name in module_names: | ||
if module_name in sys.modules: | ||
del sys.modules[module_name] | ||
|
||
|
||
test_data = [ | ||
# private repo won't work on travis | ||
# ('https://github.com/bethgelab/AnalysisBySynthesis.git', (1, 28, 28)), | ||
# ('https://github.com/bethgelab/convex_adversarial.git', (1, 28, 28)), | ||
# ('https://github.com/bethgelab/mnist_challenge.git', 784) | ||
(join('file://', dirname(__file__), 'data/model_repo'), (3, 224, 224)) | ||
] | ||
|
||
|
||
@pytest.mark.parametrize("url, dim", test_data) | ||
def test_loading_model(url, dim): | ||
# download model | ||
model = zoo.get_model(url) | ||
|
||
# create a dummy image | ||
x = np.zeros(dim, dtype=np.float32) | ||
x[:] = np.random.randn(*x.shape) | ||
|
||
# run the model | ||
logits = model.predictions(x) | ||
probabilities = foolbox.utils.softmax(logits) | ||
predicted_class = np.argmax(logits) | ||
|
||
# sanity check | ||
assert predicted_class >= 0 | ||
assert np.sum(probabilities) >= 0.9999 | ||
|
||
# TODO: delete fmodel | ||
|
||
|
||
def test_non_default_module_throws_error(): | ||
with pytest.raises(RuntimeError): | ||
ModelLoader.get(key='other') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .zoo import get_model # noqa: F401 | ||
from .weights_fetcher import fetch_weights # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import hashlib | ||
import os | ||
|
||
|
||
def sha256_hash(git_uri): | ||
m = hashlib.sha256() | ||
m.update(git_uri.encode()) | ||
return m.hexdigest() | ||
|
||
|
||
def home_directory_path(folder, hash_digest): | ||
# does this work on all operating systems? | ||
home = os.path.expanduser('~') | ||
return os.path.join(home, folder, hash_digest) | ||
|
||
|
||
def path_exists(local_path): | ||
return os.path.exists(local_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from git import Repo | ||
import logging | ||
from .common import sha256_hash, home_directory_path, path_exists | ||
|
||
FOLDER = '.foolbox_zoo' | ||
|
||
|
||
class GitCloneError(RuntimeError): | ||
pass | ||
|
||
|
||
def clone(git_uri): | ||
""" | ||
Clone a remote git repository to a local path. | ||
:param git_uri: the URI to the git repository to be cloned | ||
:return: the generated local path where the repository has been cloned to | ||
""" | ||
hash_digest = sha256_hash(git_uri) | ||
local_path = home_directory_path(FOLDER, hash_digest) | ||
exists_locally = path_exists(local_path) | ||
|
||
if not exists_locally: | ||
_clone_repo(git_uri, local_path) | ||
else: | ||
logging.info( # pragma: no cover | ||
"Git repository already exists locally.") # pragma: no cover | ||
|
||
return local_path | ||
|
||
|
||
def _clone_repo(git_uri, local_path): | ||
logging.info("Cloning repo %s to %s", git_uri, local_path) | ||
try: | ||
Repo.clone_from(git_uri, local_path) | ||
except Exception as e: | ||
logging.exception("Failed to clone repository", e) | ||
raise GitCloneError("Failed to clone repository") | ||
logging.info("Cloned repo successfully.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import sys | ||
import importlib | ||
|
||
import abc | ||
abstractmethod = abc.abstractmethod | ||
if sys.version_info >= (3, 4): | ||
ABC = abc.ABC | ||
else: # pragma: no cover | ||
ABC = abc.ABCMeta('ABC', (), {}) | ||
|
||
|
||
class ModelLoader(ABC): | ||
|
||
@abstractmethod | ||
def load(self, path): | ||
""" | ||
Load a model from a local path, to which a git repository | ||
has been previously cloned to. | ||
:param path: the path to the local repository containing the code | ||
:return: a foolbox-wrapped model | ||
""" | ||
pass # pragma: no cover | ||
|
||
@staticmethod | ||
def get(key='default'): | ||
if key is 'default': | ||
return DefaultLoader() | ||
else: | ||
raise RuntimeError("No model loader for: %s".format(key)) | ||
|
||
@staticmethod | ||
def _import_module(path, module_name='foolbox_model'): | ||
sys.path.insert(0, path) | ||
module = importlib.import_module(module_name) | ||
print('imported module: {}'.format(module)) | ||
return module | ||
|
||
|
||
class DefaultLoader(ModelLoader): | ||
|
||
def load(self, path, module_name='foolbox_model'): | ||
module = ModelLoader._import_module(path, module_name) | ||
model = module.create() | ||
return model |
Oops, something went wrong.