Skip to content

Commit

Permalink
DOC add Mandrill example (#10)
Browse files Browse the repository at this point in the history
Co-authored-by: Hande Gözükan <[email protected]>
Co-authored-by: Romain Primet <[email protected]>
  • Loading branch information
3 people authored Feb 17, 2021
1 parent a4a6214 commit 1b54bac
Show file tree
Hide file tree
Showing 18 changed files with 153 additions and 34 deletions.
6 changes: 4 additions & 2 deletions .github/workflows/deploy_ghpages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@ jobs:
- uses: actions/checkout@v2
- name: Generate HTML docs
uses: ammaraskar/sphinx-action@master
env:
ALLOW: --allow-run-as-root
with:
docs-folder: "docs/"
pre-build-command: |
apt-get update
apt-get install -y libopenmpi-dev openmpi-bin
pip install -e .[doc]
echo "localhost slots=50">hostfile
- name: Upload generated HTML as artifact
uses: actions/upload-artifact@v2
with:
Expand Down Expand Up @@ -51,5 +54,4 @@ jobs:
directory: gh-pages
github_token: ${{ secrets.GITHUB_TOKEN }}




2 changes: 1 addition & 1 deletion .github/workflows/unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ on: [push, pull_request]

jobs:
test:
runs-on: ubuntu-latest
runs-on: ubuntu-20.04
defaults:
run:
shell: bash -l {0}
Expand Down
5 changes: 2 additions & 3 deletions benchmarks/compare_cdl.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@

import pandas
import numpy as np
import matplotlib.pyplot as plt
from collections import namedtuple

from dicodile.dicodile import dicodile
from dicodile import dicodile
from dicodile.data.images import get_hubble
from dicodile.utils.viz import median_curve
from dicodile.utils.dictionary import get_lambda_max
Expand Down Expand Up @@ -50,7 +49,7 @@ def run_one(method, n_atoms, atom_support, reg, z_positive, n_workers, n_iter,
'CBPDN': {'rho': 50.0*reg_ + 0.5,
'NonNegCoef': z_positive},
'DictSize': D_init_.shape,
}
}
opt = ConvBPDNDictLearn_Consensus.Options(options)
cdl = ConvBPDNDictLearn_Consensus(
D_init_, X_, lmbda=reg_, nproc=n_workers, opt=opt, dimK=1, dimN=2)
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/dicodile_hubble.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from scipy import sparse

from dicodile.dicodile import dicodile
from dicodile import dicodile
from dicodile.data.images import get_hubble
from dicodile.utils.viz import plot_atom_and_coefs

Expand Down
2 changes: 1 addition & 1 deletion benchmarks/dicodile_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from sklearn.feature_extraction.image import extract_patches_2d

# Import for CDL
from dicodile.dicodile import dicodile
from dicodile import dicodile

# Import to initiate the dictionary
from dicodile.utils.dictionary import prox_d
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/scaling_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from benchmarks.parallel_resource_balance import ParallelResourceBalance

from dicodile.update_z.dicod import dicod
from dicodile.data.images import get_mandril
from dicodile.data.images import fetch_mandrill
from dicodile.utils import check_random_state
from dicodile.utils.dictionary import get_lambda_max
from dicodile.utils.dictionary import init_dictionary
Expand Down Expand Up @@ -78,7 +78,7 @@ def run_one_scaling_2d(n_atoms, atom_support, reg, n_workers, strategy, tol,
# Generate a problem
print(colorify(79*"=" + f"\n{tag} Start with {n_workers} workers\n" +
79*"="))
X = get_mandril()
X = fetch_mandrill()
D = init_dictionary(X, n_atoms, atom_support, random_state=random_state)
reg_ = reg * get_lambda_max(X, D).max()

Expand Down
4 changes: 2 additions & 2 deletions benchmarks/scaling_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from benchmarks.parallel_resource_balance import ParallelResourceBalance

from dicodile.update_z.dicod import dicod
from dicodile.data.images import get_mandril
from dicodile.data.images import fetch_mandrill
from dicodile.utils import check_random_state
from dicodile.utils.dictionary import get_lambda_max
from dicodile.utils.dictionary import init_dictionary
Expand Down Expand Up @@ -80,7 +80,7 @@ def run_one_grid(n_atoms, atom_support, reg, n_workers, grid, tol,
# Generate a problem
print(colorify(79*"=" + f"\n{tag} Start with {n_workers} workers\n" +
79*"="))
X = get_mandril()
X = fetch_mandrill()
D = init_dictionary(X, n_atoms, atom_support, random_state=random_state)
reg_ = reg * get_lambda_max(X, D).max()

Expand Down
4 changes: 2 additions & 2 deletions benchmarks/soft_lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


from dicodile.update_z.dicod import dicod
from dicodile.data.images import get_mandril
from dicodile.data.images import fetch_mandrill
from dicodile.utils.segmentation import Segmentation
from dicodile.utils.dictionary import get_lambda_max
from dicodile.utils.dictionary import init_dictionary
Expand All @@ -19,7 +19,7 @@ def run_without_soft_lock(n_atoms=25, atom_support=(12, 12), reg=.01,
tol=5e-2, n_workers=100, random_state=60):
rng = np.random.RandomState(random_state)

X = get_mandril()
X = fetch_mandrill()
D_init = init_dictionary(X, n_atoms, atom_support, random_state=rng)
lmbd_max = get_lambda_max(X, D_init).max()
reg_ = reg * lmbd_max
Expand Down
3 changes: 3 additions & 0 deletions dicodile/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ._dicodile import dicodile

__all__ = ['dicodile']
1 change: 0 additions & 1 deletion dicodile/dicodile.py → dicodile/_dicodile.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import time
import numpy as np

Expand Down
12 changes: 6 additions & 6 deletions dicodile/data/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@
from .home import DATA_HOME


def get_mandril():
def fetch_mandrill():

mandril_dir = DATA_HOME / "images" / "standard_images"
mandril_dir.mkdir(parents=True, exist_ok=True)
mandril = download(
mandrill_dir = DATA_HOME / "images" / "standard_images"
mandrill_dir.mkdir(parents=True, exist_ok=True)
mandrill = download(
"http://sipi.usc.edu/database/download.php?vol=misc&img=4.2.03",
mandril_dir / "mandril_color.tif"
mandrill_dir / "mandrill_color.tif"
)

X = plt.imread(mandril) / 255
X = plt.imread(mandrill) / 255
return X.swapaxes(0, 2)


Expand Down
6 changes: 0 additions & 6 deletions dicodile/data/tests/test_mandril.py

This file was deleted.

6 changes: 6 additions & 0 deletions dicodile/data/tests/test_mandrill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from dicodile.data.images import fetch_mandrill


def test_fetch_mandrill():
data = fetch_mandrill()
assert(3, 512, 512) == data.shape
7 changes: 3 additions & 4 deletions dicodile/tests/test_dicodile.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from dicodile.dicodile import dicodile
from dicodile import dicodile
from dicodile.data.simulate import simulate_data


from dicodile.utils.testing import is_deacreasing


Expand All @@ -11,6 +10,6 @@ def test_dicodile():
n_channels=3, noise_level=1e-5, random_state=42)

pobj, times, D_hat, z_hat = dicodile(
X, D, reg=.1, z_positive=True, n_iter=10, eps=1e-4,
n_workers=1, verbose=2, tol=1e-10)
X, D, reg=.1, z_positive=True, n_iter=10, eps=1e-4,
n_workers=1, verbose=2, tol=1e-10)
assert is_deacreasing(pobj)
3 changes: 2 additions & 1 deletion dicodile/utils/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def display_dictionaries(*list_D, styles=None, axes=None, filename=None):
n_dict = len(list_D)
D_0 = list_D[0]

if styles is None and n_dict > 1:
if styles is None and n_dict >= 1:
styles = [dict(color=f'C{i}') for i in range(n_dict)]

# compute layout
Expand Down Expand Up @@ -202,6 +202,7 @@ class RotationAwareAnnotation(mpl_text.Annotation):
Key-word arguments for the Annotation. List of available kwargs:
https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.annotate.html
"""

def __init__(self, text, anchor_pt, next_pt, ax=None, **kwargs):
# Get the Artiste to draw on
self.ax = ax or plt.gca()
Expand Down
18 changes: 17 additions & 1 deletion docs/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,31 @@

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXOPTS +=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
HOSTFILE = ../hostfile
ALLOW_AS_ROOT = ${ALLOW}

# Put it first so that "make" without argument is like "make help".
.PHONY: help
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: clean
clean:
rm -rf $(BUILDDIR)/*
rm -rf auto_examples/*
rm -rf generated/*


.PHONY: html
html:
mpiexec -np 1 $(ALLOW_AS_ROOT) --hostfile $(HOSTFILE) $(SPHINXBUILD) -b html $(SOURCEDIR) $(BUILDDIR)/html $(SPHINXOPTS) ${0}
@echo
@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
Expand Down
2 changes: 1 addition & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
API Documentation
=================

.. currentmodule:: dicodile.dicodile
.. currentmodule:: dicodile


.. autosummary::
Expand Down
100 changes: 100 additions & 0 deletions examples/plot_mandrill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""DiCoDiLe on the Mandrill image
==============================
This example illlustrates reconstruction of `Mandrill image
<http://sipi.usc.edu/database/download.php?vol=misc&img=4.2.03>`_
using DiCoDiLe algorithm with default soft_lock value "border" and 9
workers.
""" # noqa

import numpy as np
import matplotlib.pyplot as plt

from dicodile.data.images import fetch_mandrill

from dicodile.utils.dictionary import init_dictionary
from dicodile.utils.viz import display_dictionaries
from dicodile.utils.csc import reconstruct

from dicodile import dicodile


###############################################################################
# We will first download the Mandrill image.

X = fetch_mandrill()

plt.axis('off')
plt.imshow(X.swapaxes(0, 2))


###############################################################################
# We will create a random dictionary of **K = 25** patches of size **8x8**
# from the original Mandrill image to be used for sparse coding.

# set dictionary size
n_atoms = 25

# set individual atom (patch) size
atom_support = (8, 8)

D_init = init_dictionary(X, n_atoms, atom_support, random_state=60)

###############################################################################
# We are going to run `dicodile` with **9** workers on **3x3** grids.

# number of iterations for dicodile
n_iter = 3

# number of iterations for csc (dicodile_z)
max_iter = 10000

# number of splits along each dimension
w_world = 3

# number of workers
n_workers = w_world * w_world

# coordinate selection strategy for coordinate descent
strategy = 'greedy'

###############################################################################
# Run `dicodile`.

pobj, times, D_hat, z_hat = dicodile(X, D_init, n_iter=n_iter,
n_workers=n_workers,
strategy=strategy,
dicod_kwargs={"max_iter": max_iter},
verbose=6)


print("[DICOD] final cost : {}".format(pobj))

###############################################################################
# Plot and compare the initial dictionary `D_init` with the
# dictionary `D_hat` improved by `dicodile`.

# normalize dictionaries
normalized_D_init = D_init / D_init.max()
normalized_D_hat = D_hat / D_hat.max()

display_dictionaries(normalized_D_init, normalized_D_hat)


###############################################################################
# Reconstruct the image from `z_hat` and `D_hat`.

X_hat = reconstruct(z_hat, D_hat)
X_hat = np.clip(X_hat, 0, 1)


###############################################################################
# Plot the reconstructed image.

fig = plt.figure("recovery")

ax = plt.subplot()
ax.imshow(X_hat.swapaxes(0, 2))
ax.axis('off')
plt.tight_layout()

0 comments on commit 1b54bac

Please sign in to comment.