Skip to content
This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

Add proximal opeartor for the 2D total variation penalty #101

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions examples/plot_2d_total_variation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""
Example use of total variation denoising.

We recover an image that was corrupted with gaussian noise.
For this we use the proximal operator of the
total variation penalty (TotalVariation2DPenalty), which solves
a problem of the form

argmin_x ||corrupted_image - x||^2 + alpha * TV(x)

where TV(x) is the 2-dimensional total variation penalty.
"""
import pylab as plt
import numpy as np
from scipy import misc
from lightning.impl.penalty import TotalVariation2DPenalty

face = misc.imresize(misc.face(gray=True), 0.2)
face = face.astype(np.float) / 255.

# add gaussian noise to the origin image
data = face + 0.2 * np.random.randn(*face.shape)
f, ax = plt.subplots(1, 5, sharey=False)

ax[0].set_title('original')
ax[0].imshow(data, interpolation='nearest', cmap=plt.cm.gray)
ax[0].set_xticks(())
ax[0].set_yticks(())

for i, alpha in enumerate(np.logspace(-1, -0.5, 4)):
print('Computing denoising for alpha=%s' % alpha)
denoised = TotalVariation2DPenalty(*face.shape).projection([data.ravel()], alpha, 1.0)
ax[i+1].set_title(r'$\alpha$=%.2f' % alpha)
ax[i+1].imshow(denoised.reshape(face.shape), interpolation='nearest', cmap=plt.cm.gray)
ax[i+1].set_xticks(())
ax[i+1].set_yticks(())
plt.show()
26 changes: 16 additions & 10 deletions lightning/impl/fista.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .penalty import SimplexConstraint
from .penalty import L1BallConstraint
from .penalty import TotalVariation1DPenalty
from .penalty import TotalVariation2DPenalty


class _BaseFista(object):
Expand All @@ -29,14 +30,15 @@ def _get_penalty(self):
if hasattr(self.penalty, 'projection'):
return self.penalty
penalties = {
"l1": L1Penalty(),
"l1/l2": L1L2Penalty(),
"trace": TracePenalty(),
"simplex": SimplexConstraint(),
"l1-ball": L1BallConstraint(),
"tv1d": TotalVariation1DPenalty()
"l1": L1Penalty,
"l1/l2": L1L2Penalty,
"trace": TracePenalty,
"simplex": SimplexConstraint,
"l1-ball": L1BallConstraint,
"tv1d": TotalVariation1DPenalty,
"tv2d": TotalVariation2DPenalty
}
return penalties[self.penalty]
return penalties[self.penalty](*self.prox_args)

def _get_objective(self, df, y, loss):
return self.C * loss.objective(df, y)
Expand Down Expand Up @@ -76,8 +78,9 @@ def _fit(self, X, y, n_vectors):

t = 1.0
for it in xrange(self.max_iter):

if self.verbose >= 1:
print("Iter", it + 1, obj)
print("Iter=%s, \tloss=%s" % (it+1, obj))

# Save current values
t_old = t
Expand Down Expand Up @@ -188,7 +191,7 @@ class FistaClassifier(BaseClassifier, _BaseFista):

def __init__(self, C=1.0, alpha=1.0, loss="squared_hinge", penalty="l1",
multiclass=False, max_iter=100, max_steps=30, eta=2.0,
sigma=1e-5, callback=None, verbose=0):
sigma=1e-5, callback=None, verbose=0, prox_args=()):
self.C = C
self.alpha = alpha
self.loss = loss
Expand All @@ -200,6 +203,7 @@ def __init__(self, C=1.0, alpha=1.0, loss="squared_hinge", penalty="l1",
self.sigma = sigma
self.callback = callback
self.verbose = verbose
self.prox_args = prox_args

def _get_loss(self):
if self.multiclass:
Expand Down Expand Up @@ -277,7 +281,8 @@ class FistaRegressor(BaseRegressor, _BaseFista):
"""

def __init__(self, C=1.0, alpha=1.0, penalty="l1", max_iter=100,
max_steps=30, eta=2.0, sigma=1e-5, callback=None, verbose=0):
max_steps=30, eta=2.0, sigma=1e-5, callback=None, verbose=0,
prox_args=()):
self.C = C
self.alpha = alpha
self.penalty = penalty
Expand All @@ -287,6 +292,7 @@ def __init__(self, C=1.0, alpha=1.0, penalty="l1", max_iter=100,
self.sigma = sigma
self.callback = callback
self.verbose = verbose
self.prox_args = prox_args

def _get_loss(self):
return Squared()
Expand Down
78 changes: 75 additions & 3 deletions lightning/impl/penalty.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,24 @@
"""
In this file there are defined several proximal operators for common penalty
functions. These objects implement the following methods:

* projection(self, coef, alpha, L)
returns the value of the proximal operator for the penalty,
where coef is a ndarray that contains the coefficients of the
model, alpha is the amount of regularization and 1/L is the
step size

* regularization(self, coef)
returns the value of the penalty at coef, where coef is a
ndarray.
"""
# Author: Mathieu Blondel
# Fabian Pedregosa
# License: BSD

import numpy as np
from scipy.linalg import svd
from lightning.impl.prox_fast import prox_tv1d
from lightning.impl.prox_fast import prox_tv1d, prox_tv2d


class L1Penalty(object):
Expand Down Expand Up @@ -87,11 +102,68 @@ def regularization(self, coef):


class TotalVariation1DPenalty(object):
"""
Proximal operator for the 1-D total variation penalty (also known
as fussed lasso)
"""
def projection(self, coef, alpha, L):
tmp = coef.copy()
tmp = np.empty_like(coef)
for i in range(tmp.shape[0]):
prox_tv1d(tmp[i, :], alpha / L) # operates inplace
tmp[i] = prox_tv1d(coef[i], alpha / L)
return tmp

def regularization(self, coef):
return np.sum(np.abs(np.diff(coef)))


class TotalVariation2DPenalty(object):
"""
Proximal operator for the 2-D total variation penalty. This
proximal operator is computed approximately using the
Douglas-Rachford algorithm.

Parameters
----------
n_rows: int
number of rows in the image

n_cols: int
number of columns in the image

max_iter: int
maximum number of iterations to compute this proximal
operator.

Misc
-----
Note that n_rows * n_cols needs to be equal to the size
of the vector of coefficients.

References
----------
Barbero, Alvaro, and Suvrit Sra. "Modular proximal optimization for
multidimensional total-variation regularization." arXiv preprint
arXiv:1411.0589 (2014).
"""
def __init__(self, n_rows, n_cols, max_iter=1000, tol=1e-12):
self.n_rows = n_rows
self.n_cols = n_cols
self.max_iter = max_iter
self.tol = tol

def projection(self, coef, alpha, L):
tmp = np.empty_like(coef)
for i in range(tmp.shape[0]):
tmp[i] = prox_tv2d(
coef[i].reshape((self.n_rows, self.n_cols)),
alpha / L, self.max_iter, self.tol).ravel()
return tmp

def regularization(self, coef):
out = 0.0
for i in range(coef.shape[0]):
img = coef[i].reshape((self.n_rows, self.n_cols))
tmp1 = np.abs(np.diff(img, axis=0))
tmp2 = np.abs(np.diff(img, axis=1))
out += tmp1.sum() + tmp2.sum()
return out
Loading