From 2b1764c0ea1b351386d47925b4eaf1ed60163632 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Fri, 19 Jan 2024 13:03:52 +0100 Subject: [PATCH] Merge Develop for a new release (#322) * Add support for tensorflow backend which allows for differentiability (#112) * Added support for tensorflow * Updates to get tests passing * Or --> And * Moving modopt to allow working with tensorflow * Fix issues with wos * Fix all flakes finally! * Update modopt/base/backend.py Co-authored-by: Samuel Farrens * Update modopt/base/backend.py Co-authored-by: Samuel Farrens * Minute updates to codes * Add dynamic module * Fix docu * Fix PEP Co-authored-by: chaithyagr Co-authored-by: Samuel Farrens * Fix 115 (#116) * Fix issues * Add right tests * Fix PEP Co-authored-by: chaithyagr * Minor bug fix, remove elif (#124) Co-authored-by: chaithyagr * Add tests for modopt.base.backend and fix minute bug uncovered (#126) * Minor bug fix, remove elif * Add tests for backend * Fix tests * Add tests * Remove cupy * PEP fixes * Fix PEP * Fix PEP and update * Final PEP * Update setup.cfg Co-authored-by: Samuel Farrens * Update test_base.py Co-authored-by: chaithyagr Co-authored-by: Samuel Farrens * Release cleanup (#128) * updated GPU dependencies * added logo to manifest * updated package version and release date * Unpin package dependencies (#189) * unpinned dependencies * updated pinned documentation dependency versions * Add Gradient descent algorithms (#196) * Version 1.5.1 patch release (#114) * Add support for tensorflow backend which allows for differentiability (#112) * Added support for tensorflow * Updates to get tests passing * Or --> And * Moving modopt to allow working with tensorflow * Fix issues with wos * Fix all flakes finally! * Update modopt/base/backend.py Co-authored-by: Samuel Farrens * Update modopt/base/backend.py Co-authored-by: Samuel Farrens * Minute updates to codes * Add dynamic module * Fix docu * Fix PEP Co-authored-by: chaithyagr Co-authored-by: Samuel Farrens * Fix 115 (#116) * Fix issues * Add right tests * Fix PEP Co-authored-by: chaithyagr * Minor bug fix, remove elif (#124) Co-authored-by: chaithyagr * Add tests for modopt.base.backend and fix minute bug uncovered (#126) * Minor bug fix, remove elif * Add tests for backend * Fix tests * Add tests * Remove cupy * PEP fixes * Fix PEP * Fix PEP and update * Final PEP * Update setup.cfg Co-authored-by: Samuel Farrens * Update test_base.py Co-authored-by: chaithyagr Co-authored-by: Samuel Farrens * Release cleanup (#128) * updated GPU dependencies * added logo to manifest * updated package version and release date Co-authored-by: Chaithya G R Co-authored-by: chaithyagr * make algorithms a module. * add Gradient Descent Algorithms * enforce WPS compliance. * add test for gradient descent * Docstrings improvements * Add See Also and minor corrections * add idx initialisation for all algorithms. * fix merge error * fix typo Co-authored-by: Samuel Farrens Co-authored-by: Chaithya G R Co-authored-by: chaithyagr * Release cleanup (#198) * started clean up for next release * update progress * further clean up * additional clean up * cleaned up link to logo * fixed index.rst * fixed conflict * Fast Singular Value Thresholding (#209) * add SingularValueThreshold This Method provides 10x faster SVT estimation than the LowRankMatrix Operator. * linting * add test for fast computation. * flake8 compliance * Ignore DAR000 Error. * Update modopt/signal/svd.py tuples in docstring Co-authored-by: Samuel Farrens * Update modopt/signal/svd.py typo Co-authored-by: Samuel Farrens * Update modopt/opt/proximity.py typo Co-authored-by: Samuel Farrens * update docstring * fix isort * Update modopt/signal/svd.py Co-authored-by: Samuel Farrens * Update modopt/signal/svd.py Co-authored-by: Samuel Farrens * run isort Co-authored-by: Samuel Farrens * added writeable input data array feature for benchopt (#213) * removed flake8 limit * updated patch version * [lint] pydocstyle compliance. (#228) * [lint] pydocstyle compliance. * use pytest-pydocstyle * Power method: fix #211 (#212) * Correct the norm update for Power Method x_new should be divided by its norm, not by x_old_norm. * fix test value We are testing for eigen value of Identity. It should be one. * fix WPS350 * fix test value for unconverged case Co-authored-by: Samuel Farrens * Switch from progressbar to tqdm (#231) * switch from progressbar to tqdm. The progress bar can be provided externally for nested usage. * exposes the progress bar argument. * Child classes better have to implement these. (my linter was complaining) * update docs for progress bar using tqdm. * fix WPS errors * drop progressbar requirement, add tqdm. * [lint] disable warning for non implemented function. * simplify progbar check and argument passthrough * Update README for tqdm dependency (#240) Remote progressbar, use tqdm. * add small help for the metric argument. (#241) * add small help for the metric argument. * RST validation * use single quote * use double backticks. Co-authored-by: Samuel Farrens * add implementation for admm and fast admm. Based on Goldstein2014 * add Goldstein ref. * WPS compliance. * Abstract class for cost function. * add custom cost operator for admm. * fix WPS compliance. * Ci update (#268) * update python version support. * use string for CI. * remove flake8 and wemake-python-styleguide This anticipates the change to black formatting. * remove wps checks * apparently conda does not support 3.11 for now * remove all linting testing. * fix np.int warning/error * fix dtype error * fix precision for doctest * added black and isort support * Update python version in README * add 3.7 for test back * don't test 3.10 twice * Test rewrite (#266) * add MatrixOperator. * move base test to pytest. * [fixme] remove flake8 and emoji config. * rewrite test_math module using pytest. * use fail/skipparam helper function. * generalize usage of failparam * refactor test_signal. * refactor test_signal, the end. * lint * fix missing parameter. * add dummy object test helper. * rewrite test for cost and gradients. * show missing lines in coverage reports * rewrite of proximity operators testing. * add fail low rank method. * add cases for algorithms test * add algorithm test. * add pytest-cases and pytest-xdists support. * add support for testing metrics. * improve base module coverage. * test for wrong mask in metric module. * add docstring. * update email adress and authors field. * 100% coverage for transform module. * move linear operator to class * update docstring. * paramet(e)rization. * update docstring. * improve test_helper module. * raises should be specified for each failparam call. * encapsulate module's test in classes. * skip test if sklearn is not installed. * pin pydocstyle * removed unnormalised Gaussian kernel option and corresponding test * Restrict scikit-image version for testing * added fix for basic test suite * set behaviour for different astropy versions * updated docstring for gaussian_kernel * Use example scripts as tests. (#277) * Initialize the example module. * do not export the assert statements. * add matplotlib as requirement. * add support for sphinx-gallery * Update modopt/examples/README.rst Co-authored-by: Samuel Farrens * Update modopt/examples/__init__.py Co-authored-by: Samuel Farrens * Update modopt/examples/conftest.py Co-authored-by: Samuel Farrens * Update modopt/examples/example_lasso_forward_backward.py Co-authored-by: Samuel Farrens * Update modopt/examples/example_lasso_forward_backward.py Co-authored-by: Samuel Farrens * ignore auto_example folder * doc formatting. * add pogm and basic comparison. * fix: add matplotlib for the plotting in examples scripts. * fix: add matplotlib for basic ci too. * ci: run pytest with xdist for faster testing --------- Co-authored-by: Samuel Farrens * fix: specify data_range for ssim. Refs: #290 * typos. * feat(test): add test for admm. * feat(admm): improve doc. * refactor: rename abstract cost to CostParent. * feat: add test for fast admm. * feat(admm): improve docstrings. * style: remove extra line.c * feat: make POGM more memory efficient. * feat: add a dummy cost for the identity operator. * feat: create a linear operator module, add wavelet transform. * feat: add test case for wavelet transform. * Update setup.py --------- Co-authored-by: chaithyagr Co-authored-by: Samuel Farrens Co-authored-by: Pierre-Antoine Comby <77174042+paquiteau@users.noreply.github.com> Co-authored-by: Pierre-antoine Comby Co-authored-by: Pierre-Antoine Comby --- .github/workflows/cd-build.yml | 4 +- .github/workflows/ci-build.yml | 21 +- .gitignore | 1 + README.md | 4 +- develop.txt | 13 +- docs/requirements.txt | 1 + docs/source/conf.py | 13 + docs/source/refs.bib | 12 + docs/source/toc.rst | 1 + modopt/examples/README.rst | 5 + modopt/examples/__init__.py | 10 + modopt/examples/conftest.py | 46 + .../example_lasso_forward_backward.py | 153 ++ modopt/math/matrix.py | 17 +- modopt/math/metrics.py | 6 +- modopt/math/stats.py | 46 +- modopt/opt/algorithms/__init__.py | 3 +- modopt/opt/algorithms/admm.py | 337 ++++ modopt/opt/algorithms/base.py | 73 +- modopt/opt/algorithms/forward_backward.py | 38 +- modopt/opt/algorithms/primal_dual.py | 11 +- modopt/opt/cost.py | 152 +- modopt/opt/linear/__init__.py | 21 + modopt/opt/{linear.py => linear/base.py} | 58 +- modopt/opt/linear/wavelet.py | 216 +++ modopt/opt/proximity.py | 2 +- modopt/signal/filter.py | 8 +- modopt/signal/positivity.py | 2 +- modopt/signal/svd.py | 4 +- modopt/tests/test_algorithms.py | 673 +++----- modopt/tests/test_base.py | 435 ++---- modopt/tests/test_helpers/__init__.py | 1 + modopt/tests/test_helpers/utils.py | 23 + modopt/tests/test_math.py | 661 +++----- modopt/tests/test_opt.py | 1390 +++++------------ modopt/tests/test_signal.py | 561 +++---- requirements.txt | 2 +- setup.cfg | 12 +- setup.py | 4 +- 39 files changed, 2439 insertions(+), 2601 deletions(-) create mode 100644 modopt/examples/README.rst create mode 100644 modopt/examples/__init__.py create mode 100644 modopt/examples/conftest.py create mode 100644 modopt/examples/example_lasso_forward_backward.py create mode 100644 modopt/opt/algorithms/admm.py create mode 100644 modopt/opt/linear/__init__.py rename modopt/opt/{linear.py => linear/base.py} (84%) create mode 100644 modopt/opt/linear/wavelet.py create mode 100644 modopt/tests/test_helpers/__init__.py create mode 100644 modopt/tests/test_helpers/utils.py diff --git a/.github/workflows/cd-build.yml b/.github/workflows/cd-build.yml index 1e49f8bc..fca9feb1 100644 --- a/.github/workflows/cd-build.yml +++ b/.github/workflows/cd-build.yml @@ -62,9 +62,7 @@ jobs: - name: Set up Conda with Python 3.8 uses: conda-incubator/setup-miniconda@v2 with: - auto-update-conda: true - python-version: 3.8 - auto-activate-base: false + python-version: "3.8" - name: Install dependencies shell: bash -l {0} diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml index 3ffcb6f4..c4ba28a0 100644 --- a/.github/workflows/ci-build.yml +++ b/.github/workflows/ci-build.yml @@ -16,21 +16,12 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest] - python-version: [3.8] + python-version: ["3.10"] steps: - name: Checkout uses: actions/checkout@v2 - - name: Report WPS Errors - uses: wemake-services/wemake-python-styleguide@0.14.1 - continue-on-error: true - with: - reporter: 'github-pr-review' - path: './modopt' - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - name: Set up Conda with Python ${{ matrix.python-version }} uses: conda-incubator/setup-miniconda@v2 with: @@ -52,7 +43,7 @@ jobs: python -m pip install --upgrade pip python -m pip install -r develop.txt python -m pip install -r docs/requirements.txt - python -m pip install astropy scikit-image scikit-learn + python -m pip install astropy "scikit-image<0.20" scikit-learn matplotlib python -m pip install tensorflow>=2.4.1 python -m pip install twine python -m pip install . @@ -61,7 +52,7 @@ jobs: shell: bash -l {0} run: | export PATH=/usr/share/miniconda/bin:$PATH - python setup.py test + pytest -n 2 - name: Save Test Results if: always() @@ -98,7 +89,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest] - python-version: [3.6, 3.7, 3.9] + python-version: ["3.7", "3.8", "3.9"] steps: - name: Checkout @@ -117,11 +108,11 @@ jobs: python --version python -m pip install --upgrade pip python -m pip install -r develop.txt - python -m pip install astropy scikit-image scikit-learn + python -m pip install astropy "scikit-image<0.20" scikit-learn matplotlib python -m pip install . - name: Run Tests shell: bash -l {0} run: | export PATH=/usr/share/miniconda/bin:$PATH - python setup.py test + pytest -n 2 diff --git a/.gitignore b/.gitignore index 06dff8db..f9eaaa68 100644 --- a/.gitignore +++ b/.gitignore @@ -73,6 +73,7 @@ instance/ docs/_build/ docs/source/fortuna.* docs/source/scripts.* +docs/source/auto_examples/ docs/source/*.nblink # PyBuilder diff --git a/README.md b/README.md index acb316ad..223d0b73 100644 --- a/README.md +++ b/README.md @@ -37,11 +37,11 @@ All packages required by ModOpt should be installed automatically. Optional pack In order to run the code in this repository the following packages must be installed: -* [Python](https://www.python.org/) [> 3.6] +* [Python](https://www.python.org/) [> 3.7] * [importlib_metadata](https://importlib-metadata.readthedocs.io/en/latest/) [==3.7.0] * [Numpy](http://www.numpy.org/) [==1.19.5] * [Scipy](http://www.scipy.org/) [==1.5.4] -* [Progressbar 2](https://progressbar-2.readthedocs.io/) [==3.53.1] +* [tqdm](https://tqdm.github.io/) [>=4.64.0] ### Optional Packages diff --git a/develop.txt b/develop.txt index 8beef0ff..6ff665eb 100644 --- a/develop.txt +++ b/develop.txt @@ -1,9 +1,12 @@ coverage>=5.5 -flake8>=4 -nose>=1.3.7 pytest>=6.2.2 +pytest-raises>=0.10 +pytest-cases>= 3.6 +pytest-xdist>= 3.0.1 pytest-cov>=2.11.1 -pytest-pep8>=1.0.6 pytest-emoji>=0.2.0 -pytest-flake8>=1.0.7 -wemake-python-styleguide>=0.15.2 +pydocstyle==6.1.1 +pytest-pydocstyle>=2.2.0 +black +isort +pytest-black diff --git a/docs/requirements.txt b/docs/requirements.txt index 4d2a14fb..c9e29c88 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -6,3 +6,4 @@ numpydoc==1.1.0 sphinx==4.3.1 sphinxcontrib-bibtex==2.4.1 sphinxawesome-theme==3.2.1 +sphinx-gallery==0.11.1 diff --git a/docs/source/conf.py b/docs/source/conf.py index fb954f6d..46564b9f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -45,6 +45,7 @@ 'nbsphinx', 'nbsphinx_link', 'numpydoc', + "sphinx_gallery.gen_gallery" ] # Include module names for objects @@ -145,6 +146,18 @@ # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. html_show_copyright = True + + +# -- Options for Sphinx Gallery ---------------------------------------------- + +sphinx_gallery_conf = { + "examples_dirs": ["../../modopt/examples/"], + "filename_pattern": "/example_", + "ignore_pattern": r"/(__init__|conftest)\.py", +} + + + # -- Options for nbshpinx output ------------------------------------------ diff --git a/docs/source/refs.bib b/docs/source/refs.bib index d8365e71..7782ca52 100644 --- a/docs/source/refs.bib +++ b/docs/source/refs.bib @@ -207,3 +207,15 @@ @article{zou2005 journal = {Journal of the Royal Statistical Society Series B}, doi = {10.1111/j.1467-9868.2005.00527.x} } + +@article{Goldstein2014, + author={Goldstein, Tom and O’Donoghue, Brendan and Setzer, Simon and Baraniuk, Richard}, + year={2014}, + month={Jan}, + pages={1588–1623}, + title={Fast Alternating Direction Optimization Methods}, + journal={SIAM Journal on Imaging Sciences}, + volume={7}, + ISSN={1936-4954}, + doi={10/gdwr49}, +} diff --git a/docs/source/toc.rst b/docs/source/toc.rst index 84a6af87..ef5753f5 100644 --- a/docs/source/toc.rst +++ b/docs/source/toc.rst @@ -25,6 +25,7 @@ plugin_example notebooks + auto_examples/index .. toctree:: :hidden: diff --git a/modopt/examples/README.rst b/modopt/examples/README.rst new file mode 100644 index 00000000..e6ffbe27 --- /dev/null +++ b/modopt/examples/README.rst @@ -0,0 +1,5 @@ +======== +Examples +======== + +This is a collection of Python scripts demonstrating the use of ModOpt. diff --git a/modopt/examples/__init__.py b/modopt/examples/__init__.py new file mode 100644 index 00000000..d7e77357 --- /dev/null +++ b/modopt/examples/__init__.py @@ -0,0 +1,10 @@ +"""EXAMPLES. + +This module contains documented examples that demonstrate the usage of various +ModOpt tools. + +These examples also serve as integration tests for various methods. + +:Author: Pierre-Antoine Comby + +""" diff --git a/modopt/examples/conftest.py b/modopt/examples/conftest.py new file mode 100644 index 00000000..73358679 --- /dev/null +++ b/modopt/examples/conftest.py @@ -0,0 +1,46 @@ +"""TEST CONFIGURATION. + +This module contains methods for configuring the testing of the example +scripts. + +:Author: Pierre-Antoine Comby + +Notes +----- +Based on: +https://stackoverflow.com/questions/56807698/how-to-run-script-as-pytest-test + +""" +from pathlib import Path +import runpy +import pytest + +def pytest_collect_file(path, parent): + """Pytest hook. + + Create a collector for the given path, or None if not relevant. + The new node needs to have the specified parent as parent. + """ + p = Path(path) + if p.suffix == '.py' and 'example' in p.name: + return Script.from_parent(parent, path=p, name=p.name) + + +class Script(pytest.File): + """Script files collected by pytest.""" + + def collect(self): + """Collect the script as its own item.""" + yield ScriptItem.from_parent(self, name=self.name) + +class ScriptItem(pytest.Item): + """Item script collected by pytest.""" + + def runtest(self): + """Run the script as a test.""" + runpy.run_path(str(self.path)) + + def repr_failure(self, excinfo): + """Return only the error traceback of the script.""" + excinfo.traceback = excinfo.traceback.cut(path=self.path) + return super().repr_failure(excinfo) diff --git a/modopt/examples/example_lasso_forward_backward.py b/modopt/examples/example_lasso_forward_backward.py new file mode 100644 index 00000000..7f820000 --- /dev/null +++ b/modopt/examples/example_lasso_forward_backward.py @@ -0,0 +1,153 @@ +# noqa: D205 +""" +Solving the LASSO Problem with the Forward Backward Algorithm. +============================================================== + +This an example to show how to solve an example LASSO Problem +using the Forward-Backward Algorithm. + +In this example we are going to use: + - Modopt Operators (Linear, Gradient, Proximal) + - Modopt implementation of solvers + - Modopt Metric API. +TODO: add reference to LASSO paper. +""" + +import numpy as np +import matplotlib.pyplot as plt + +from modopt.opt.algorithms import ForwardBackward, POGM +from modopt.opt.cost import costObj +from modopt.opt.linear import LinearParent, Identity +from modopt.opt.gradient import GradBasic +from modopt.opt.proximity import SparseThreshold +from modopt.math.matrix import PowerMethod +from modopt.math.stats import mse + +# %% +# Here we create a instance of the LASSO Problem + +BETA_TRUE = np.array( + [3.0, 1.5, 0, 0, 2, 0, 0, 0] +) # 8 original values from lLASSO Paper +DIM = len(BETA_TRUE) + + +rng = np.random.default_rng() +sigma_noise = 1 +obs = 20 +# create a measurement matrix with decaying covariance matrix. +cov = 0.4 ** abs((np.arange(DIM) * np.ones((DIM, DIM))).T - np.arange(DIM)) +x = rng.multivariate_normal(np.zeros(DIM), cov, obs) + +y = x @ BETA_TRUE +y_noise = y + (sigma_noise * np.random.standard_normal(obs)) + + +# %% +# Next we create Operators for solving the problem. + +# MatrixOperator could also work here. +lin_op = LinearParent(lambda b: x @ b, lambda bb: x.T @ bb) +grad_op = GradBasic(y_noise, op=lin_op.op, trans_op=lin_op.adj_op) + +prox_op = SparseThreshold(Identity(), 1, thresh_type="soft") + +# %% +# In order to get the best convergence rate, we first determine the Lipschitz constant of the gradient Operator +# + +calc_lips = PowerMethod(grad_op.trans_op_op, 8, data_type="float32", auto_run=True) +lip = calc_lips.spec_rad +print("lipschitz constant:", lip) + +# %% +# Solving using FISTA algorithm +# ----------------------------- +# +# TODO: Add description/Reference of FISTA. + +cost_op_fista = costObj([grad_op, prox_op], verbose=False) + +fb_fista = ForwardBackward( + np.zeros(8), + beta_param=1 / lip, + grad=grad_op, + prox=prox_op, + cost=cost_op_fista, + metric_call_period=1, + auto_iterate=False, # Just to give us the pleasure of doing things by ourself. +) + +fb_fista.iterate() + +# %% +# After the run we can have a look at the results + +print(fb_fista.x_final) +mse_fista = mse(fb_fista.x_final, BETA_TRUE) +plt.stem(fb_fista.x_final, label="estimation", linefmt="C0-") +plt.stem(BETA_TRUE, label="reference", linefmt="C1-") +plt.legend() +plt.title(f"FISTA Estimation MSE={mse_fista:.4f}") + +# sphinx_gallery_start_ignore +assert mse(fb_fista.x_final, BETA_TRUE) < 1 +# sphinx_gallery_end_ignore + + +# %% +# Solving Using the POGM Algorithm +# -------------------------------- +# +# TODO: Add description/Reference to POGM. + + +cost_op_pogm = costObj([grad_op, prox_op], verbose=False) + +fb_pogm = POGM( + np.zeros(8), + np.zeros(8), + np.zeros(8), + np.zeros(8), + beta_param=1 / lip, + grad=grad_op, + prox=prox_op, + cost=cost_op_pogm, + metric_call_period=1, + auto_iterate=False, # Just to give us the pleasure of doing things by ourself. +) + +fb_pogm.iterate() + +# %% +# After the run we can have a look at the results + +print(fb_pogm.x_final) +mse_pogm = mse(fb_pogm.x_final, BETA_TRUE) + +plt.stem(fb_pogm.x_final, label="estimation", linefmt="C0-") +plt.stem(BETA_TRUE, label="reference", linefmt="C1-") +plt.legend() +plt.title(f"FISTA Estimation MSE={mse_pogm:.4f}") +# +# sphinx_gallery_start_ignore +assert mse(fb_pogm.x_final, BETA_TRUE) < 1 + +# %% +# Comparing the Two algorithms +# ---------------------------- + +plt.figure() +plt.semilogy(cost_op_fista._cost_list, label="FISTA convergence") +plt.semilogy(cost_op_pogm._cost_list, label="POGM convergence") +plt.xlabel("iterations") +plt.ylabel("Cost Function") +plt.legend() +plt.show() + + +# %% +# We can see that the two algorithm converges quickly, and POGM requires less iterations. +# However the POGM iterations are more costly, so a proper benchmark with time measurement is needed. +# Check the benchopt benchmark for more details. diff --git a/modopt/math/matrix.py b/modopt/math/matrix.py index 939cf41f..8361531d 100644 --- a/modopt/math/matrix.py +++ b/modopt/math/matrix.py @@ -285,9 +285,9 @@ class PowerMethod(object): >>> np.random.seed(1) >>> pm = PowerMethod(lambda x: x.dot(x.T), (3, 3)) >>> np.around(pm.spec_rad, 6) - 0.904292 + 1.0 >>> np.around(pm.inv_spec_rad, 6) - 1.105837 + 1.0 Notes ----- @@ -348,17 +348,21 @@ def get_spec_rad(self, tolerance=1e-6, max_iter=20, extra_factor=1.0): # Set (or reset) values of x. x_old = self._set_initial_x() + xp = get_array_module(x_old) + x_old_norm = xp.linalg.norm(x_old) + + x_old /= x_old_norm + # Iterate until the L2 norm of x converges. for i_elem in range(max_iter): - xp = get_array_module(x_old) - x_old_norm = xp.linalg.norm(x_old) - - x_new = self._operator(x_old) / x_old_norm + x_new = self._operator(x_old) x_new_norm = xp.linalg.norm(x_new) + x_new /= x_new_norm + if (xp.abs(x_new_norm - x_old_norm) < tolerance): message = ( ' - Power Method converged after {0} iterations!' @@ -374,6 +378,7 @@ def get_spec_rad(self, tolerance=1e-6, max_iter=20, extra_factor=1.0): print(message.format(max_iter)) xp.copyto(x_old, x_new) + x_old_norm = x_new_norm self.spec_rad = x_new_norm * extra_factor self.inv_spec_rad = 1.0 / self.spec_rad diff --git a/modopt/math/metrics.py b/modopt/math/metrics.py index 1b870e23..21952624 100644 --- a/modopt/math/metrics.py +++ b/modopt/math/metrics.py @@ -23,7 +23,7 @@ def min_max_normalize(img): """Min-Max Normalize. - Centre and normalize a given array. + Normalize a given array in the [0,1] range. Parameters ---------- @@ -33,7 +33,7 @@ def min_max_normalize(img): Returns ------- numpy.ndarray - Centred and normalized array + normalized array """ min_img = img.min() @@ -126,7 +126,7 @@ def ssim(test, ref, mask=None): test, ref, mask = _preprocess_input(test, ref, mask) test = move_to_cpu(test) - assim, ssim_value = compare_ssim(test, ref, full=True) + assim, ssim_value = compare_ssim(test, ref, full=True, data_range=1.0) if mask is None: return assim diff --git a/modopt/math/stats.py b/modopt/math/stats.py index 3ac818a7..59bf6759 100644 --- a/modopt/math/stats.py +++ b/modopt/math/stats.py @@ -11,6 +11,8 @@ import numpy as np try: + from packaging import version + from astropy import __version__ as astropy_version from astropy.convolution import Gaussian2DKernel except ImportError: # pragma: no cover import_astropy = False @@ -18,7 +20,7 @@ import_astropy = True -def gaussian_kernel(data_shape, sigma, norm='max'): +def gaussian_kernel(data_shape, sigma, norm="max"): """Gaussian kernel. This method produces a Gaussian kerenal of a specified size and dispersion. @@ -29,9 +31,8 @@ def gaussian_kernel(data_shape, sigma, norm='max'): Desiered shape of the kernel sigma : float Standard deviation of the kernel - norm : {'max', 'sum', 'none'}, optional - Normalisation of the kerenl (options are ``'max'``, ``'sum'`` or - ``'none'``, default is ``'max'``) + norm : {'max', 'sum'}, optional + Normalisation of the kerenl (options are ``'max'`` or ``'sum'``, default is ``'max'``) Returns ------- @@ -60,22 +61,22 @@ def gaussian_kernel(data_shape, sigma, norm='max'): """ if not import_astropy: # pragma: no cover - raise ImportError('Astropy package not found.') + raise ImportError("Astropy package not found.") - if norm not in {'max', 'sum', 'none'}: + if norm not in {"max", "sum"}: raise ValueError('Invalid norm, options are "max", "sum" or "none".') kernel = np.array( Gaussian2DKernel(sigma, x_size=data_shape[1], y_size=data_shape[0]), ) - if norm == 'max': + if norm == "max": return kernel / np.max(kernel) - elif norm == 'sum': + elif version.parse(astropy_version) < version.parse("5.2"): return kernel / np.sum(kernel) - elif norm == 'none': + else: return kernel @@ -147,7 +148,7 @@ def mse(data1, data2): return np.mean((data1 - data2) ** 2) -def psnr(data1, data2, method='starck', max_pix=255): +def psnr(data1, data2, method="starck", max_pix=255): r"""Peak Signal-to-Noise Ratio. This method calculates the Peak Signal-to-Noise Ratio between two data @@ -202,23 +203,21 @@ def psnr(data1, data2, method='starck', max_pix=255): 10\log_{10}(\mathrm{MSE})) """ - if method == 'starck': - return ( - 20 * np.log10( - (data1.shape[0] * np.abs(np.max(data1) - np.min(data1))) - / np.linalg.norm(data1 - data2), - ) + if method == "starck": + return 20 * np.log10( + (data1.shape[0] * np.abs(np.max(data1) - np.min(data1))) + / np.linalg.norm(data1 - data2), ) - elif method == 'wiki': - return (20 * np.log10(max_pix) - 10 * np.log10(mse(data1, data2))) + elif method == "wiki": + return 20 * np.log10(max_pix) - 10 * np.log10(mse(data1, data2)) raise ValueError( 'Invalid PSNR method. Options are "starck" and "wiki"', ) -def psnr_stack(data1, data2, metric=np.mean, method='starck'): +def psnr_stack(data1, data2, metric=np.mean, method="starck"): """Peak Signa-to-Noise for stack of images. This method calculates the PSNRs for two stacks of 2D arrays. @@ -261,12 +260,11 @@ def psnr_stack(data1, data2, metric=np.mean, method='starck'): """ if data1.ndim != 3 or data2.ndim != 3: - raise ValueError('Input data must be a 3D np.ndarray') + raise ValueError("Input data must be a 3D np.ndarray") - return metric([ - psnr(i_elem, j_elem, method=method) - for i_elem, j_elem in zip(data1, data2) - ]) + return metric( + [psnr(i_elem, j_elem, method=method) for i_elem, j_elem in zip(data1, data2)] + ) def sigma_mad(input_data): diff --git a/modopt/opt/algorithms/__init__.py b/modopt/opt/algorithms/__init__.py index e0ac2572..d4e7082b 100644 --- a/modopt/opt/algorithms/__init__.py +++ b/modopt/opt/algorithms/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -r"""OPTIMISATION ALGOTITHMS. +r"""OPTIMISATION ALGORITHMS. This module contains class implementations of various optimisation algoritms. @@ -57,3 +57,4 @@ SAGAOptGradOpt, VanillaGenericGradOpt) from modopt.opt.algorithms.primal_dual import Condat +from modopt.opt.algorithms.admm import ADMM, FastADMM diff --git a/modopt/opt/algorithms/admm.py b/modopt/opt/algorithms/admm.py new file mode 100644 index 00000000..b881b770 --- /dev/null +++ b/modopt/opt/algorithms/admm.py @@ -0,0 +1,337 @@ +"""ADMM Algorithms.""" +import numpy as np + +from modopt.base.backend import get_array_module +from modopt.opt.algorithms.base import SetUp +from modopt.opt.cost import CostParent + + +class ADMMcostObj(CostParent): + r"""Cost Object for the ADMM problem class. + + Parameters + ---------- + cost_funcs: 2-tuples of callable + f and g function. + A : OperatorBase + First Operator + B : OperatorBase + Second Operator + b : numpy.ndarray + Observed data + **kwargs : dict + Extra parameters for cost operator configuration + + Notes + ----- + Compute :math:`f(u)+g(v) + \tau \| Au +Bv - b\|^2` + + See Also + -------- + CostParent: parent class + """ + + def __init__(self, cost_funcs, A, B, b, tau, **kwargs): + super().__init__(*kwargs) + self.cost_funcs = cost_funcs + self.A = A + self.B = B + self.b = b + self.tau = tau + + def _calc_cost(self, u, v, **kwargs): + """Calculate the cost. + + This method calculates the cost from each of the input operators. + + Parameters + ---------- + u: numpy.ndarray + First primal variable of ADMM + v: numpy.ndarray + Second primal variable of ADMM + + Returns + ------- + float + Cost value + + """ + xp = get_array_module(u) + cost = self.cost_funcs[0](u) + cost += self.cost_funcs[1](v) + cost += self.tau * xp.linalg.norm(self.A.op(u) + self.B.op(v) - self.b) + return cost + + +class ADMM(SetUp): + r"""Fast ADMM Optimisation Algorihm. + + This class implement the ADMM algorithm described in :cite:`Goldstein2014` (Algorithm 1). + + Parameters + ---------- + u: numpy.ndarray + Initial value for first primal variable of ADMM + v: numpy.ndarray + Initial value for second primal variable of ADMM + mu: numpy.ndarray + Initial value for lagrangian multiplier. + A : modopt.opt.linear.LinearOperator + Linear operator for u + B: modopt.opt.linear.LinearOperator + Linear operator for v + b : numpy.ndarray + Constraint vector + optimizers: tuple + 2-tuple of callable, that are the optimizers for the u and v. + Each callable should access an init and obs argument and returns an estimate for: + .. math:: u_{k+1} = \argmin H(u) + \frac{\tau}{2}\|A u - y\|^2 + .. math:: v_{k+1} = \argmin G(v) + \frac{\tau}{2}\|Bv - y \|^2 + cost_funcs: tuple + 2-tuple of callable, that compute values of H and G. + tau: float, default=1 + Coupling parameter for ADMM. + + Notes + ----- + The algorithm solve the problem: + + .. math:: u, v = \arg\min H(u) + G(v) + \frac\tau2 \|Au + Bv - b \|_2^2 + + with the following augmented lagrangian: + + .. math :: \mathcal{L}_{\tau}(u,v, \lambda) = H(u) + G(v) + +\langle\lambda |Au + Bv -b \rangle + \frac\tau2 \| Au + Bv -b \|^2 + + To allow easy iterative solving, the change of variable + :math:`\mu=\lambda/\tau` is used. Hence, the lagrangian of interest is: + + .. math :: \tilde{\mathcal{L}}_{\tau}(u,v, \mu) = H(u) + G(v) + + \frac\tau2 \left(\|\mu + Au +Bv - b\|^2 - \|\mu\|^2\right) + + See Also + -------- + SetUp: parent class + """ + + def __init__( + self, + u, + v, + mu, + A, + B, + b, + optimizers, + tau=1, + cost_funcs=None, + **kwargs, + ): + super().__init__(**kwargs) + self.A = A + self.B = B + self.b = b + self._opti_H = optimizers[0] + self._opti_G = optimizers[1] + self._tau = tau + if cost_funcs is not None: + self._cost_func = ADMMcostObj(cost_funcs, A, B, b, tau) + else: + self._cost_func = None + + # init iteration variables. + self._u_old = self.xp.copy(u) + self._u_new = self.xp.copy(u) + self._v_old = self.xp.copy(v) + self._v_new = self.xp.copy(v) + self._mu_new = self.xp.copy(mu) + self._mu_old = self.xp.copy(mu) + + def _update(self): + self._u_new = self._opti_H( + init=self._u_old, + obs=self.B.op(self._v_old) + self._u_old - self.b, + ) + tmp = self.A.op(self._u_new) + self._v_new = self._opti_G( + init=self._v_old, + obs=tmp + self._u_old - self.b, + ) + + self._mu_new = self._mu_old + (tmp + self.B.op(self._v_new) - self.b) + + # update cycle + self._u_old = self.xp.copy(self._u_new) + self._v_old = self.xp.copy(self._v_new) + self._mu_old = self.xp.copy(self._mu_new) + + # Test cost function for convergence. + if self._cost_func: + self.converge = self.any_convergence_flag() + self.converge |= self._cost_func.get_cost(self._u_new, self._v_new) + + def iterate(self, max_iter=150): + """Iterate. + + This method calls update until either convergence criteria is met or + the maximum number of iterations is reached. + + Parameters + ---------- + max_iter : int, optional + Maximum number of iterations (default is ``150``) + """ + self._run_alg(max_iter) + + # retrieve metrics results + self.retrieve_outputs() + # rename outputs as attributes + self.u_final = self._u_new + self.x_final = self.u_final # for backward compatibility + self.v_final = self._v_new + + def get_notify_observers_kwargs(self): + """Notify observers. + + Return the mapping between the metrics call and the iterated + variables. + + Returns + ------- + dict + The mapping between the iterated variables + """ + return { + 'x_new': self._u_new, + 'v_new': self._v_new, + 'idx': self.idx, + } + + def retrieve_outputs(self): + """Retrieve outputs. + + Declare the outputs of the algorithms as attributes: x_final, + y_final, metrics. + """ + metrics = {} + for obs in self._observers['cv_metrics']: + metrics[obs.name] = obs.retrieve_metrics() + self.metrics = metrics + + +class FastADMM(ADMM): + r"""Fast ADMM Optimisation Algorihm. + + This class implement the fast ADMM algorithm + (Algorithm 8 from :cite:`Goldstein2014`) + + Parameters + ---------- + u: numpy.ndarray + Initial value for first primal variable of ADMM + v: numpy.ndarray + Initial value for second primal variable of ADMM + mu: numpy.ndarray + Initial value for lagrangian multiplier. + A : modopt.opt.linear.LinearOperator + Linear operator for u + B: modopt.opt.linear.LinearOperator + Linear operator for v + b : numpy.ndarray + Constraint vector + optimizers: tuple + 2-tuple of callable, that are the optimizers for the u and v. + Each callable should access an init and obs argument and returns an estimate for: + .. math:: u_{k+1} = \argmin H(u) + \frac{\tau}{2}\|A u - y\|^2 + .. math:: v_{k+1} = \argmin G(v) + \frac{\tau}{2}\|Bv - y \|^2 + cost_funcs: tuple + 2-tuple of callable, that compute values of H and G. + tau: float, default=1 + Coupling parameter for ADMM. + eta: float, default=0.999 + Convergence parameter for ADMM. + alpha: float, default=1. + Initial value for the FISTA-like acceleration parameter. + + Notes + ----- + This is an accelerated version of the ADMM algorithm. The convergence hypothesis are stronger than for the ADMM algorithm. + + See Also + -------- + ADMM: parent class + """ + + def __init__( + self, + u, + v, + mu, + A, + B, + b, + optimizers, + cost_funcs=None, + alpha=1, + eta=0.999, + tau=1, + **kwargs, + ): + super().__init__( + u=u, + v=b, + mu=mu, + A=A, + B=B, + b=b, + optimizers=optimizers, + cost_funcs=cost_funcs, + **kwargs, + ) + self._c_old = np.inf + self._c_new = 0 + self._eta = eta + self._alpha_old = alpha + self._alpha_new = alpha + self._v_hat = self.xp.copy(self._v_new) + self._mu_hat = self.xp.copy(self._mu_new) + + def _update(self): + # Classical ADMM steps + self._u_new = self._opti_H( + init=self._u_old, + obs=self.B.op(self._v_hat) + self._u_old - self.b, + ) + tmp = self.A.op(self._u_new) + self._v_new = self._opti_G( + init=self._v_hat, + obs=tmp + self._u_old - self.b, + ) + + self._mu_new = self._mu_hat + (tmp + self.B.op(self._v_new) - self.b) + + # restarting condition + self._c_new = self.xp.linalg.norm(self._mu_new - self._mu_hat) + self._c_new += self._tau * self.xp.linalg.norm( + self.B.op(self._v_new - self._v_hat), + ) + if self._c_new < self._eta * self._c_old: + self._alpha_new = 1 + np.sqrt(1 + 4 * self._alpha_old**2) + beta = (self._alpha_new - 1) / self._alpha_old + self._v_hat = self._v_new + (self._v_new - self._v_old) * beta + self._mu_hat = self._mu_new + (self._mu_new - self._mu_old) * beta + else: + # reboot to old iteration + self._alpha_new = 1 + self._v_hat = self._v_old + self._mu_hat = self._mu_old + self._c_new = self._c_old / self._eta + + self.xp.copyto(self._u_old, self._u_new) + self.xp.copyto(self._v_old, self._v_new) + self.xp.copyto(self._mu_old, self._mu_new) + # Test cost function for convergence. + if self._cost_func: + self.converge = self.any_convergence_flag() + self.convergd |= self._cost_func.get_cost(self._u_new, self._v_new) diff --git a/modopt/opt/algorithms/base.py b/modopt/opt/algorithms/base.py index 85c36306..c5a4b101 100644 --- a/modopt/opt/algorithms/base.py +++ b/modopt/opt/algorithms/base.py @@ -4,7 +4,7 @@ from inspect import getmro import numpy as np -from progressbar import ProgressBar +from tqdm.auto import tqdm from modopt.base import backend from modopt.base.observable import MetricObserver, Observable @@ -12,17 +12,17 @@ class SetUp(Observable): - r"""Algorithm Set-Up. + """Algorithm Set-Up. This class contains methods for checking the set-up of an optimisation - algotithm and produces warnings if they do not comply. + algorithm and produces warnings if they do not comply. Parameters ---------- metric_call_period : int, optional Metric call period (default is ``5``) metrics : dict, optional - Metrics to be used (default is ``\{\}``) + Metrics to be used (default is ``None``) verbose : bool, optional Option for verbose output (default is ``False``) progress : bool, optional @@ -34,11 +34,32 @@ class SetUp(Observable): use_gpu : bool, optional Option to use available GPU + Notes + ----- + If provided, the ``metrics`` argument should be a nested dictionary of the + following form:: + + metrics = { + 'metric_name': { + 'metric': callable, + 'mapping': {'x_new': 'test'}, + 'cst_kwargs': {'ref': ref_image}, + 'early_stopping': False, + } + } + + Where ``callable`` is a function with arguments being for instance + ``test`` and ``ref``. The mapping of the argument uses the same keys as the + output of ``get_notify_observer_kwargs``, ``cst_kwargs`` defines constant + arguments that will always be passed to the metric call. + If ``early_stopping`` is True, the metric will be used to check for + convergence of the algorithm, in that case it is recommended to have + ``metric_call_period = 1`` + See Also -------- modopt.base.observable.Observable : parent class modopt.base.observable.MetricObserver : definition of metrics - """ def __init__( @@ -240,9 +261,8 @@ def _iterations(self, max_iter, progbar=None): ---------- max_iter : int Maximum number of iterations - progbar : progressbar.bar.ProgressBar - Progress bar (default is ``None``) - + progbar: tqdm.tqdm + Progress bar handle (default is ``None``) """ for idx in range(max_iter): self.idx = idx @@ -268,10 +288,10 @@ def _iterations(self, max_iter, progbar=None): print(' - Converged!') break - if not isinstance(progbar, type(None)): - progbar.update(idx) + if progbar: + progbar.update() - def _run_alg(self, max_iter): + def _run_alg(self, max_iter, progbar=None): """Run algorithm. Run the update step of a given algorithm up to the maximum number of @@ -281,17 +301,34 @@ def _run_alg(self, max_iter): ---------- max_iter : int Maximum number of iterations + progbar: tqdm.tqdm + Progress bar handle (default is ``None``) See Also -------- - progressbar.bar.ProgressBar + tqdm.tqdm """ - if self.progress: - with ProgressBar( - redirect_stdout=True, - max_value=max_iter, - ) as progbar: - self._iterations(max_iter, progbar=progbar) + if self.progress and progbar is None: + with tqdm(total=max_iter) as pb: + self._iterations(max_iter, progbar=pb) + elif progbar: + self._iterations(max_iter, progbar=progbar) else: self._iterations(max_iter) + + def _update(self): + raise NotImplementedError + + def get_notify_observers_kwargs(self): + """Notify Observers. + + Return the mapping between the metrics call and the iterated + variables. + + Raises + ------ + NotImplementedError + This method should be overriden by subclasses. + """ + raise NotImplementedError diff --git a/modopt/opt/algorithms/forward_backward.py b/modopt/opt/algorithms/forward_backward.py index e18f66c3..702799c6 100644 --- a/modopt/opt/algorithms/forward_backward.py +++ b/modopt/opt/algorithms/forward_backward.py @@ -467,7 +467,7 @@ def _update(self): or self._cost_func.get_cost(self._x_new) ) - def iterate(self, max_iter=150): + def iterate(self, max_iter=150, progbar=None): """Iterate. This method calls update until either the convergence criteria is met @@ -477,9 +477,10 @@ def iterate(self, max_iter=150): ---------- max_iter : int, optional Maximum number of iterations (default is ``150``) - + progbar: tqdm.tqdm + Progress bar handle (default is ``None``) """ - self._run_alg(max_iter) + self._run_alg(max_iter, progbar) # retrieve metrics results self.retrieve_outputs() @@ -750,7 +751,7 @@ def _update(self): if self._cost_func: self.converge = self._cost_func.get_cost(self._x_new) - def iterate(self, max_iter=150): + def iterate(self, max_iter=150, progbar=None): """Iterate. This method calls update until either convergence criteria is met or @@ -760,9 +761,10 @@ def iterate(self, max_iter=150): ---------- max_iter : int, optional Maximum number of iterations (default is ``150``) - + progbar: tqdm.tqdm + Progress bar handle (default is ``None``) """ - self._run_alg(max_iter) + self._run_alg(max_iter, progbar) # retrieve metrics results self.retrieve_outputs() @@ -815,9 +817,9 @@ class POGM(SetUp): Initial guess for the :math:`y` variable z : numpy.ndarray Initial guess for the :math:`z` variable - grad + grad : GradBasic Gradient operator class - prox + prox : ProximalParent Proximity operator class cost : class instance or str, optional Cost function class instance (default is ``'auto'``); Use ``'auto'`` to @@ -942,7 +944,9 @@ def _update(self): """ # Step 4 from alg. 3 self._grad.get_grad(self._x_old) - self._u_new = self._x_old - self._beta * self._grad.grad + #self._u_new = self._x_old - self._beta * self._grad.grad + self._u_new = -self._beta * self._grad.grad + self._u_new += self._x_old # Step 5 from alg. 3 self._t_new = 0.5 * (1 + self.xp.sqrt(1 + 4 * self._t_old ** 2)) @@ -964,10 +968,15 @@ def _update(self): # Restarting and gamma-Decreasing # Step 9 from alg. 3 - self._g_new = self._grad.grad - (self._x_new - self._z) / self._xi + #self._g_new = self._grad.grad - (self._x_new - self._z) / self._xi + self._g_new = (self._z - self._x_new) + self._g_new /= self._xi + self._g_new += self._grad.grad # Step 10 from alg 3. - self._y_new = self._x_old - self._beta * self._g_new + #self._y_new = self._x_old - self._beta * self._g_new + self._y_new = - self._beta * self._g_new + self._y_new += self._x_old # Step 11 from alg. 3 restart_crit = ( @@ -995,7 +1004,7 @@ def _update(self): or self._cost_func.get_cost(self._x_new) ) - def iterate(self, max_iter=150): + def iterate(self, max_iter=150, progbar=None): """Iterate. This method calls update until either convergence criteria is met or @@ -1005,9 +1014,10 @@ def iterate(self, max_iter=150): ---------- max_iter : int, optional Maximum number of iterations (default is ``150``) - + progbar: tqdm.tqdm + Progress bar handle (default is ``None``) """ - self._run_alg(max_iter) + self._run_alg(max_iter, progbar) # retrieve metrics results self.retrieve_outputs() diff --git a/modopt/opt/algorithms/primal_dual.py b/modopt/opt/algorithms/primal_dual.py index c8566969..d5bdd431 100644 --- a/modopt/opt/algorithms/primal_dual.py +++ b/modopt/opt/algorithms/primal_dual.py @@ -225,7 +225,7 @@ def _update(self): or self._cost_func.get_cost(self._x_new, self._y_new) ) - def iterate(self, max_iter=150, n_rewightings=1): + def iterate(self, max_iter=150, n_rewightings=1, progbar=None): """Iterate. This method calls update until either convergence criteria is met or @@ -237,14 +237,17 @@ def iterate(self, max_iter=150, n_rewightings=1): Maximum number of iterations (default is ``150``) n_rewightings : int, optional Number of reweightings to perform (default is ``1``) - + progbar: tqdm.tqdm + Progress bar handle (default is ``None``) """ - self._run_alg(max_iter) + self._run_alg(max_iter, progbar) if not isinstance(self._reweight, type(None)): for _ in range(n_rewightings): self._reweight.reweight(self._linear.op(self._x_new)) - self._run_alg(max_iter) + if progbar: + progbar.reset(total=max_iter) + self._run_alg(max_iter, progbar) # retrieve metrics results self.retrieve_outputs() diff --git a/modopt/opt/cost.py b/modopt/opt/cost.py index 3cdfcc50..688a3959 100644 --- a/modopt/opt/cost.py +++ b/modopt/opt/cost.py @@ -6,6 +6,8 @@ """ +import abc + import numpy as np from modopt.base.backend import get_array_module @@ -13,8 +15,8 @@ from modopt.plot.cost_plot import plotCost -class costObj(object): - """Generic cost function object. +class CostParent(abc.ABC): + """Abstract cost function object. This class updates the cost according to the input operator classes and tests for convergence. @@ -40,7 +42,8 @@ class costObj(object): Notes ----- - The costFunc class must contain a method called ``cost``. + All child classes should implement a ``_calc_cost`` method (returning + a float) or a ``get_cost`` for more complex behavior on convergence test. Examples -------- @@ -71,7 +74,6 @@ class costObj(object): def __init__( self, - operators, initial_cost=1e6, tolerance=1e-4, cost_interval=1, @@ -80,9 +82,6 @@ def __init__( plot_output=None, ): - self._operators = operators - if not isinstance(operators, type(None)): - self._check_operators() self.cost = initial_cost self._cost_list = [] self._cost_interval = cost_interval @@ -93,30 +92,6 @@ def __init__( self._plot_output = plot_output self._verbose = verbose - def _check_operators(self): - """Check operators. - - This method checks if the input operators have a ``cost`` method. - - Raises - ------ - TypeError - For invalid operators type - ValueError - For operators without ``cost`` method - - """ - if not isinstance(self._operators, (list, tuple, np.ndarray)): - message = ( - 'Input operators must be provided as a list, not {0}' - ) - raise TypeError(message.format(type(self._operators))) - - for op in self._operators: - if not hasattr(op, 'cost'): - raise ValueError('Operators must contain "cost" method.') - op.cost = check_callable(op.cost) - def _check_cost(self): """Check cost function. @@ -167,6 +142,7 @@ def _check_cost(self): return False + @abc.abstractmethod def _calc_cost(self, *args, **kwargs): """Calculate the cost. @@ -178,14 +154,7 @@ def _calc_cost(self, *args, **kwargs): Positional arguments **kwargs : dict Keyword arguments - - Returns - ------- - float - Cost value - """ - return np.sum([op.cost(*args, **kwargs) for op in self._operators]) def get_cost(self, *args, **kwargs): """Get cost function. @@ -241,3 +210,110 @@ def plot_cost(self): # pragma: no cover """ plotCost(self._cost_list, self._plot_output) + + +class costObj(CostParent): + """Abstract cost function object. + + This class updates the cost according to the input operator classes and + tests for convergence. + + Parameters + ---------- + opertors : list, tuple or numpy.ndarray + List of operators classes containing ``cost`` method + initial_cost : float, optional + Initial value of the cost (default is ``1e6``) + tolerance : float, optional + Tolerance threshold for convergence (default is ``1e-4``) + cost_interval : int, optional + Iteration interval to calculate cost (default is ``1``). + If ``cost_interval`` is ``None`` the cost is never calculated, + thereby saving on computation time. + test_range : int, optional + Number of cost values to be used in test (default is ``4``) + verbose : bool, optional + Option for verbose output (default is ``True``) + plot_output : str, optional + Output file name for cost function plot + + Examples + -------- + >>> from modopt.opt.cost import * + >>> class dummy(object): + ... def cost(self, x): + ... return x ** 2 + ... + ... + >>> inst = costObj([dummy(), dummy()]) + >>> inst.get_cost(2) + - ITERATION: 1 + - COST: 8 + + False + >>> inst.get_cost(2) + - ITERATION: 2 + - COST: 8 + + False + >>> inst.get_cost(2) + - ITERATION: 3 + - COST: 8 + + False + """ + + def __init__( + self, + operators, + **kwargs, + ): + super().__init__(**kwargs) + + self._operators = operators + if not isinstance(operators, type(None)): + self._check_operators() + + def _check_operators(self): + """Check operators. + + This method checks if the input operators have a ``cost`` method. + + Raises + ------ + TypeError + For invalid operators type + ValueError + For operators without ``cost`` method + + """ + if not isinstance(self._operators, (list, tuple, np.ndarray)): + message = ( + 'Input operators must be provided as a list, not {0}' + ) + raise TypeError(message.format(type(self._operators))) + + for op in self._operators: + if not hasattr(op, 'cost'): + raise ValueError('Operators must contain "cost" method.') + op.cost = check_callable(op.cost) + + def _calc_cost(self, *args, **kwargs): + """Calculate the cost. + + This method calculates the cost from each of the input operators. + + Parameters + ---------- + *args : tuple + Positional arguments + **kwargs : dict + Keyword arguments + + Returns + ------- + float + Cost value + + """ + return np.sum([op.cost(*args, **kwargs) for op in self._operators]) diff --git a/modopt/opt/linear/__init__.py b/modopt/opt/linear/__init__.py new file mode 100644 index 00000000..d5c0d21f --- /dev/null +++ b/modopt/opt/linear/__init__.py @@ -0,0 +1,21 @@ +"""LINEAR OPERATORS. + +This module contains linear operator classes. + +:Author: Samuel Farrens +:Author: Pierre-Antoine Comby +""" + +from .base import LinearParent, Identity, MatrixOperator, LinearCombo + +from .wavelet import WaveletConvolve, WaveletTransform + + +__all__ = [ + "LinearParent", + "Identity", + "MatrixOperator", + "LinearCombo", + "WaveletConvolve", + "WaveletTransform", +] diff --git a/modopt/opt/linear.py b/modopt/opt/linear/base.py similarity index 84% rename from modopt/opt/linear.py rename to modopt/opt/linear/base.py index d8679998..e347970d 100644 --- a/modopt/opt/linear.py +++ b/modopt/opt/linear/base.py @@ -1,18 +1,9 @@ -# -*- coding: utf-8 -*- - -"""LINEAR OPERATORS. - -This module contains linear operator classes. - -:Author: Samuel Farrens - -""" +"""Base classes for linear operators.""" import numpy as np -from modopt.base.types import check_callable, check_float -from modopt.signal.wavelet import filter_convolve_stack - +from modopt.base.types import check_callable +from modopt.base.backend import get_array_module class LinearParent(object): """Linear Operator Parent Class. @@ -78,42 +69,24 @@ def __init__(self): self.op = lambda input_data: input_data self.adj_op = self.op + self.cost= lambda *args, **kwargs: 0 -class WaveletConvolve(LinearParent): - """Wavelet Convolution Class. - - This class defines the wavelet transform operators via convolution with - predefined filters. - - Parameters - ---------- - filters: numpy.ndarray - Array of wavelet filter coefficients - method : str, optional - Convolution method (default is ``'scipy'``) - - See Also - -------- - LinearParent : parent class - modopt.signal.wavelet.filter_convolve_stack : wavelet filter convolution +class MatrixOperator(LinearParent): + """ + Matrix Operator class. + This class transforms an array into a suitable linear operator. """ - def __init__(self, filters, method='scipy'): + def __init__(self, array): + self.op = lambda x: array @ x + xp = get_array_module(array) - self._filters = check_float(filters) - self.op = lambda input_data: filter_convolve_stack( - input_data, - self._filters, - method=method, - ) - self.adj_op = lambda input_data: filter_convolve_stack( - input_data, - self._filters, - filter_rot=True, - method=method, - ) + if xp.any(xp.iscomplex(array)): + self.adj_op = lambda x: array.T.conjugate() @ x + else: + self.adj_op = lambda x: array.T @ x class LinearCombo(LinearParent): @@ -150,7 +123,6 @@ class LinearCombo(LinearParent): See Also -------- LinearParent : parent class - """ def __init__(self, operators, weights=None): diff --git a/modopt/opt/linear/wavelet.py b/modopt/opt/linear/wavelet.py new file mode 100644 index 00000000..6e22a2b0 --- /dev/null +++ b/modopt/opt/linear/wavelet.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 +"""Wavelet operator, using either scipy filter or pywavelet.""" +import warnings + +import numpy as np + +from modopt.base.types import check_float +from modopt.signal.wavelet import filter_convolve_stack + +from .base import LinearParent + +pywt_available = True +try: + import pywt + from joblib import Parallel, cpu_count, delayed +except ImportError: + pywt_available = False + + +class WaveletConvolve(LinearParent): + """Wavelet Convolution Class. + + This class defines the wavelet transform operators via convolution with + predefined filters. + + Parameters + ---------- + filters: numpy.ndarray + Array of wavelet filter coefficients + method : str, optional + Convolution method (default is ``'scipy'``) + + See Also + -------- + LinearParent : parent class + modopt.signal.wavelet.filter_convolve_stack : wavelet filter convolution + + """ + + def __init__(self, filters, method='scipy'): + + self._filters = check_float(filters) + self.op = lambda input_data: filter_convolve_stack( + input_data, + self._filters, + method=method, + ) + self.adj_op = lambda input_data: filter_convolve_stack( + input_data, + self._filters, + filter_rot=True, + method=method, + ) + + + +class WaveletTransform(LinearParent): + """ + 2D and 3D wavelet transform class. + + This is a light wrapper around PyWavelet, with multicoil support. + + Parameters + ---------- + wavelet_name: str + the wavelet name to be used during the decomposition. + shape: tuple[int,...] + Shape of the input data. The shape should be a tuple of length 2 or 3. + It should not contains coils or batch dimension. + nb_scales: int, default 4 + the number of scales in the decomposition. + n_batchs: int, default 1 + the number of channel/ batch dimension + n_jobs: int, default 1 + the number of cores to use for multichannel. + backend: str, default "threading" + the backend to use for parallel multichannel linear operation. + verbose: int, default 0 + the verbosity level. + + Attributes + ---------- + nb_scale: int + number of scale decomposed in wavelet space. + n_jobs: int + number of jobs for parallel computation + n_batchs: int + number of coils use f + backend: str + Backend use for parallel computation + verbose: int + Verbosity level + """ + + def __init__( + self, + wavelet_name, + shape, + level=4, + n_batch=1, + n_jobs=1, + decimated=True, + backend="threading", + mode="symmetric", + ): + if not pywt_available: + raise ImportError( + "PyWavelet and/or joblib are not available. Please install it to use WaveletTransform." + ) + if wavelet_name not in pywt.wavelist(kind="all"): + raise ValueError( + "Invalid wavelet name. Availables are ``pywt.waveletlist(kind='all')``" + ) + + self.wavelet = wavelet_name + if isinstance(shape, int): + shape = (shape,) + self.shape = shape + self.n_jobs = n_jobs + self.mode = mode + self.level = level + if not decimated: + raise NotImplementedError( + "Undecimated Wavelet Transform is not implemented yet." + ) + ca, *cds = pywt.wavedecn_shapes( + self.shape, wavelet=self.wavelet, mode=self.mode, level=self.level + ) + self.coeffs_shape = [ca] + [s for cd in cds for s in cd.values()] + + if len(shape) > 1: + self.dwt = pywt.wavedecn + self.idwt = pywt.waverecn + self._pywt_fun = "wavedecn" + else: + self.dwt = pywt.wavedec + self.idwt = pywt.waverec + self._pywt_fun = "wavedec" + + self.n_batch = n_batch + if self.n_batch == 1 and self.n_jobs != 1: + warnings.warn("Making n_jobs = 1 for WaveletTransform as n_batchs = 1") + self.n_jobs = 1 + self.backend = backend + n_proc = self.n_jobs + if n_proc < 0: + n_proc = cpu_count() + self.n_jobs + 1 + + def op(self, data): + """Define the wavelet operator. + + This method returns the input data convolved with the wavelet filter. + + Parameters + ---------- + data: ndarray or Image + input 2D data array. + + Returns + ------- + coeffs: ndarray + the wavelet coefficients. + """ + if self.n_batch > 1: + coeffs, self.coeffs_slices, self.raw_coeffs_shape = zip( + *Parallel( + n_jobs=self.n_jobs, backend=self.backend, verbose=self.verbose + )(delayed(self._op)(data[i]) for i in np.arange(self.n_batch)) + ) + coeffs = np.asarray(coeffs) + else: + coeffs, self.coeffs_slices, self.raw_coeffs_shape = self._op(data) + return coeffs + + def _op(self, data): + """Single coil wavelet transform.""" + return pywt.ravel_coeffs( + self.dwt(data, mode=self.mode, level=self.level, wavelet=self.wavelet) + ) + + def adj_op(self, coeffs): + """Define the wavelet adjoint operator. + + This method returns the reconstructed image. + + Parameters + ---------- + coeffs: ndarray + the wavelet coefficients. + + Returns + ------- + data: ndarray + the reconstructed data. + """ + if self.n_batch > 1: + images = Parallel( + n_jobs=self.n_jobs, backend=self.backend, verbose=self.verbose + )( + delayed(self._adj_op)(coeffs[i], self.coeffs_shape[i]) + for i in np.arange(self.n_batch) + ) + images = np.asarray(images) + else: + images = self._adj_op(coeffs) + return images + + def _adj_op(self, coeffs): + """Single coil inverse wavelet transform.""" + return self.idwt( + pywt.unravel_coeffs( + coeffs, self.coeffs_slices, self.raw_coeffs_shape, self._pywt_fun + ), + wavelet=self.wavelet, + mode=self.mode, + ) diff --git a/modopt/opt/proximity.py b/modopt/opt/proximity.py index f8f368ef..e8492367 100644 --- a/modopt/opt/proximity.py +++ b/modopt/opt/proximity.py @@ -993,7 +993,7 @@ def _interpolate(self, alpha0, alpha1, sum0, sum1): :math:`\sum\theta(\alpha^*)=k` via a linear interpolation. Parameters - ----------- + ---------- alpha0: float A value for wich :math:`\sum\theta(\alpha^0) \leq k` alpha1: float diff --git a/modopt/signal/filter.py b/modopt/signal/filter.py index 8e24768c..84dd8160 100644 --- a/modopt/signal/filter.py +++ b/modopt/signal/filter.py @@ -73,8 +73,8 @@ def mex_hat(data_point, sigma): Examples -------- >>> from modopt.signal.filter import mex_hat - >>> mex_hat(2, 1) - -0.3521390522571337 + >>> round(mex_hat(2, 1), 15) + -0.352139052257134 """ data_point = check_float(data_point) @@ -108,8 +108,8 @@ def mex_hat_dir(data_gauss, data_mex, sigma): Examples -------- >>> from modopt.signal.filter import mex_hat_dir - >>> mex_hat_dir(1, 2, 1) - 0.17606952612856686 + >>> round(mex_hat_dir(1, 2, 1), 16) + 0.1760695261285668 """ data_gauss = check_float(data_gauss) diff --git a/modopt/signal/positivity.py b/modopt/signal/positivity.py index e4ec098d..c19ba62c 100644 --- a/modopt/signal/positivity.py +++ b/modopt/signal/positivity.py @@ -48,7 +48,7 @@ def pos_recursive(input_data): """ if input_data.dtype == 'O': - res = np.array([pos_recursive(elem) for elem in input_data]) + res = np.array([pos_recursive(elem) for elem in input_data], dtype="object") else: res = pos_thresh(input_data) diff --git a/modopt/signal/svd.py b/modopt/signal/svd.py index 6dcb9eda..f3d40a51 100644 --- a/modopt/signal/svd.py +++ b/modopt/signal/svd.py @@ -57,7 +57,7 @@ def find_n_pc(u_vec, factor=0.5): ) # Get the shape of the array - array_shape = np.repeat(np.int(np.sqrt(u_vec.shape[0])), 2) + array_shape = np.repeat(int(np.sqrt(u_vec.shape[0])), 2) # Find the auto correlation of the left singular vector. u_auto = [ @@ -299,7 +299,7 @@ def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'): a_matrix = np.dot(s_values, v_vec) # Get the shape of the array - array_shape = np.repeat(np.int(np.sqrt(u_vec.shape[0])), 2) + array_shape = np.repeat(int(np.sqrt(u_vec.shape[0])), 2) # Compute threshold matrix. ti = np.array([ diff --git a/modopt/tests/test_algorithms.py b/modopt/tests/test_algorithms.py index 7ff96a8b..5671b8e3 100644 --- a/modopt/tests/test_algorithms.py +++ b/modopt/tests/test_algorithms.py @@ -1,470 +1,279 @@ # -*- coding: utf-8 -*- -"""UNIT TESTS FOR OPT.ALGORITHMS. +"""UNIT TESTS FOR Algorithms. -This module contains unit tests for the modopt.opt.algorithms module. - -:Author: Samuel Farrens +This module contains unit tests for the modopt.opt module. +:Authors: + Samuel Farrens + Pierre-Antoine Comby """ -from unittest import TestCase - import numpy as np import numpy.testing as npt - +import pytest from modopt.opt import algorithms, cost, gradient, linear, proximity, reweight - -# Basic functions to be used as operators or as dummy functions -func_identity = lambda x_val: x_val -func_double = lambda x_val: x_val * 2 -func_sq = lambda x_val: x_val ** 2 -func_cube = lambda x_val: x_val ** 3 - - -class Dummy(object): - """Dummy class for tests.""" - - pass - - -class AlgorithmTestCase(TestCase): - """Test case for algorithms module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.data2 = self.data1 + np.random.randn(*self.data1.shape) * 1e-6 - self.data3 = np.arange(9).reshape(3, 3).astype(float) + 1 - - grad_inst = gradient.GradBasic( - self.data1, - func_identity, - func_identity, - ) - +from pytest_cases import ( + case, + fixture, + fixture_ref, + lazy_value, + parametrize, + parametrize_with_cases, +) + +from test_helpers import Dummy + +SKLEARN_AVAILABLE = True +try: + import sklearn +except ImportError: + SKLEARN_AVAILABLE = False + + +@fixture +def idty(): + """Identity function.""" + return lambda x: x + + +@fixture +def reweight_op(): + """Reweight operator.""" + data3 = np.arange(9).reshape(3, 3).astype(float) + 1 + return reweight.cwbReweight(data3) + + +def build_kwargs(kwargs, use_metrics): + """Build the kwargs for each algorithm, replacing placeholders by true values. + + This function has to be call for each test, as direct parameterization somehow + is not working with pytest-xdist and pytest-cases. + It also adds dummy metric measurement to validate the metric api. + """ + update_value = { + "idty": lambda x: x, + "lin_idty": linear.Identity(), + "reweight_op": reweight.cwbReweight( + np.arange(9).reshape(3, 3).astype(float) + 1 + ), + } + new_kwargs = dict() + print(kwargs) + # update the value of the dict is possible. + for key in kwargs: + new_kwargs[key] = update_value.get(kwargs[key], kwargs[key]) + + if use_metrics: + new_kwargs["linear"] = linear.Identity() + new_kwargs["metrics"] = { + "diff": { + "metric": lambda test, ref: np.sum(test - ref), + "mapping": {"x_new": "test"}, + "cst_kwargs": {"ref": np.arange(9).reshape((3, 3))}, + "early_stopping": False, + } + } + + return new_kwargs + + +@parametrize(use_metrics=[True, False]) +class AlgoCases: + """Cases for algorithms. + + Most of the test solves the trivial problem + + .. math:: + \\min_x \\frac{1}{2} \\| y - x \\|_2^2 \\quad\\text{s.t.} x \\geq 0 + + More complex and concrete usecases are shown in examples. + """ + + data1 = np.arange(9).reshape(3, 3).astype(float) + data2 = data1 + np.random.randn(*data1.shape) * 1e-6 + max_iter = 20 + + @parametrize( + kwargs=[ + {"beta_update": "idty", "auto_iterate": False, "cost": None}, + {"beta_update": "idty"}, + {"cost": None, "lambda_update": None}, + {"beta_update": "idty", "a_cd": 3}, + {"beta_update": "idty", "r_lazy": 3, "p_lazy": 0.7, "q_lazy": 0.7}, + {"restart_strategy": "adaptive", "xi_restart": 0.9}, + { + "restart_strategy": "greedy", + "xi_restart": 0.9, + "min_beta": 1.0, + "s_greedy": 1.1, + }, + ] + ) + def case_forward_backward(self, kwargs, idty, use_metrics): + """Forward Backward case. + """ + update_kwargs = build_kwargs(kwargs, use_metrics) + algo = algorithms.ForwardBackward( + self.data1, + grad=gradient.GradBasic(self.data1, idty, idty), + prox=proximity.Positivity(), + **update_kwargs, + ) + if update_kwargs.get("auto_iterate", None) is False: + algo.iterate(self.max_iter) + return algo, update_kwargs + + @parametrize( + kwargs=[ + { + "cost": None, + "auto_iterate": False, + "gamma_update": "idty", + "beta_update": "idty", + }, + {"gamma_update": "idty", "lambda_update": "idty"}, + {"cost": True}, + {"cost": True, "step_size": 2}, + ] + ) + def case_gen_forward_backward(self, kwargs, use_metrics, idty): + """General FB setup.""" + update_kwargs = build_kwargs(kwargs, use_metrics) + grad_inst = gradient.GradBasic(self.data1, idty, idty) prox_inst = proximity.Positivity() prox_dual_inst = proximity.IdentityProx() - linear_inst = linear.Identity() - reweight_inst = reweight.cwbReweight(self.data3) - cost_inst = cost.costObj([grad_inst, prox_inst, prox_dual_inst]) - self.setup = algorithms.SetUp() - self.max_iter = 20 - - self.fb_all_iter = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=None, - auto_iterate=False, - beta_update=func_identity, - ) - self.fb_all_iter.iterate(self.max_iter) - - self.fb1 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - ) - - self.fb2 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - lambda_update=None, - ) - - self.fb3 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - a_cd=3, - ) - - self.fb4 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - r_lazy=3, - p_lazy=0.7, - q_lazy=0.7, - ) - - self.fb5 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - restart_strategy='adaptive', - xi_restart=0.9, - ) - - self.fb6 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - restart_strategy='greedy', - xi_restart=0.9, - min_beta=1.0, - s_greedy=1.1, - ) - - self.gfb_all_iter = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - cost=None, - auto_iterate=False, - gamma_update=func_identity, - beta_update=func_identity, - ) - self.gfb_all_iter.iterate(self.max_iter) - - self.gfb1 = algorithms.GenForwardBackward( + if update_kwargs.get("cost", None) is True: + update_kwargs["cost"] = cost.costObj([grad_inst, prox_inst, prox_dual_inst]) + algo = algorithms.GenForwardBackward( self.data1, grad=grad_inst, prox_list=[prox_inst, prox_dual_inst], - gamma_update=func_identity, - lambda_update=func_identity, - ) - - self.gfb2 = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - cost=cost_inst, - ) - - self.gfb3 = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - cost=cost_inst, - step_size=2, - ) - - self.condat_all_iter = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - cost=None, - prox_dual=prox_dual_inst, - sigma_update=func_identity, - tau_update=func_identity, - rho_update=func_identity, - auto_iterate=False, - ) - self.condat_all_iter.iterate(self.max_iter) - - self.condat1 = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - prox_dual=prox_dual_inst, - sigma_update=func_identity, - tau_update=func_identity, - rho_update=func_identity, - ) - - self.condat2 = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - prox_dual=prox_dual_inst, - linear=linear_inst, - cost=cost_inst, - reweight=reweight_inst, - ) + **update_kwargs, + ) + if update_kwargs.get("auto_iterate", None) is False: + algo.iterate(self.max_iter) + return algo, update_kwargs + + @parametrize( + kwargs=[ + { + "sigma_dual": "idty", + "tau_update": "idty", + "rho_update": "idty", + "auto_iterate": False, + }, + { + "sigma_dual": "idty", + "tau_update": "idty", + "rho_update": "idty", + }, + { + "linear": "lin_idty", + "cost": True, + "reweight": "reweight_op", + }, + ] + ) + def case_condat(self, kwargs, use_metrics, idty): + """Condat Vu Algorithm setup.""" + update_kwargs = build_kwargs(kwargs, use_metrics) + grad_inst = gradient.GradBasic(self.data1, idty, idty) + prox_inst = proximity.Positivity() + prox_dual_inst = proximity.IdentityProx() + if update_kwargs.get("cost", None) is True: + update_kwargs["cost"] = cost.costObj([grad_inst, prox_inst, prox_dual_inst]) - self.condat3 = algorithms.Condat( + algo = algorithms.Condat( self.data1, self.data2, grad=grad_inst, prox=prox_inst, prox_dual=prox_dual_inst, - linear=Dummy(), - cost=cost_inst, - auto_iterate=False, + **update_kwargs, ) + if update_kwargs.get("auto_iterate", None) is False: + algo.iterate(self.max_iter) + return algo, update_kwargs - self.pogm_all_iter = algorithms.POGM( + @parametrize(kwargs=[{"auto_iterate": False, "cost": None}, {}]) + def case_pogm(self, kwargs, use_metrics, idty): + """POGM setup.""" + update_kwargs = build_kwargs(kwargs, use_metrics) + grad_inst = gradient.GradBasic(self.data1, idty, idty) + prox_inst = proximity.Positivity() + algo = algorithms.POGM( u=self.data1, x=self.data1, y=self.data1, z=self.data1, grad=grad_inst, prox=prox_inst, - auto_iterate=False, - cost=None, + **update_kwargs, ) - self.pogm_all_iter.iterate(self.max_iter) - self.pogm1 = algorithms.POGM( - u=self.data1, - x=self.data1, - y=self.data1, - z=self.data1, - grad=grad_inst, - prox=prox_inst, - ) + if update_kwargs.get("auto_iterate", None) is False: + algo.iterate(self.max_iter) + return algo, update_kwargs - self.vanilla_grad = algorithms.VanillaGenericGradOpt( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - ) - self.ada_grad = algorithms.AdaGenericGradOpt( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - ) - self.adam_grad = algorithms.ADAMGradOpt( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - ) - self.momentum_grad = algorithms.MomentumGradOpt( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - ) - self.rms_grad = algorithms.RMSpropGradOpt( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - ) - self.saga_grad = algorithms.SAGAOptGradOpt( + @parametrize( + GradDescent=[ + algorithms.VanillaGenericGradOpt, + algorithms.AdaGenericGradOpt, + algorithms.ADAMGradOpt, + algorithms.MomentumGradOpt, + algorithms.RMSpropGradOpt, + algorithms.SAGAOptGradOpt, + ] + ) + def case_grad(self, GradDescent, use_metrics, idty): + """Gradient Descent algorithm test.""" + update_kwargs = build_kwargs({}, use_metrics) + grad_inst = gradient.GradBasic(self.data1, idty, idty) + prox_inst = proximity.Positivity() + cost_inst = cost.costObj([grad_inst, prox_inst]) + + algo = GradDescent( self.data1, grad=grad_inst, prox=prox_inst, cost=cost_inst, + **update_kwargs, ) + algo.iterate() + return algo, update_kwargs + @parametrize(admm=[algorithms.ADMM,algorithms.FastADMM]) + def case_admm(self, admm, use_metrics, idty): + """ADMM setup.""" + def optim1(init, obs): + return obs - self.dummy = Dummy() - self.dummy.cost = func_identity - self.setup._check_operator(self.dummy.cost) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.setup = None - self.fb_all_iter = None - self.fb1 = None - self.fb2 = None - self.gfb_all_iter = None - self.gfb1 = None - self.gfb2 = None - self.condat_all_iter = None - self.condat1 = None - self.condat2 = None - self.condat3 = None - self.pogm1 = None - self.pogm_all_iter = None - self.dummy = None - - def test_set_up(self): - """Test set_up.""" - npt.assert_raises(TypeError, self.setup._check_input_data, 1) - - npt.assert_raises(TypeError, self.setup._check_param, 1) - - npt.assert_raises(TypeError, self.setup._check_param_update, 1) - - def test_all_iter(self): - """Test if all opt run for all iterations.""" - opts = [ - self.fb_all_iter, - self.gfb_all_iter, - self.condat_all_iter, - self.pogm_all_iter, - ] - for opt in opts: - npt.assert_equal(opt.idx, self.max_iter - 1) - - def test_forward_backward(self): - """Test forward_backward.""" - npt.assert_array_equal( - self.fb1.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb2.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb3.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb4.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb5.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb6.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) + def optim2(init, obs): + return obs - def test_gen_forward_backward(self): - """Test gen_forward_backward.""" - npt.assert_array_equal( - self.gfb1.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_array_equal( - self.gfb2.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_array_equal( - self.gfb3.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_equal( - self.gfb3.step_size, - 2, - err_msg='Incorrect step size.', - ) - - npt.assert_raises( - TypeError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=1, - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[1], - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[0.5, 0.5], - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[0.5], - ) - - def test_condat(self): - """Test gen_condat.""" - npt.assert_almost_equal( - self.condat1.x_final, - self.data1, - err_msg='Incorrect Condat result.', - ) - - npt.assert_almost_equal( - self.condat2.x_final, - self.data1, - err_msg='Incorrect Condat result.', - ) - - def test_pogm(self): - """Test pogm.""" - npt.assert_almost_equal( - self.pogm1.x_final, - self.data1, - err_msg='Incorrect POGM result.', - ) - - def test_ada_grad(self): - """Test ADA Gradient Descent.""" - self.ada_grad.iterate() - npt.assert_almost_equal( - self.ada_grad.x_final, - self.data1, - err_msg='Incorrect ADAGrad results.', - ) - - def test_adam_grad(self): - """Test ADAM Gradient Descent.""" - self.adam_grad.iterate() - npt.assert_almost_equal( - self.adam_grad.x_final, - self.data1, - err_msg='Incorrect ADAMGrad results.', - ) - - def test_momemtum_grad(self): - """Test Momemtum Gradient Descent.""" - self.momentum_grad.iterate() - npt.assert_almost_equal( - self.momentum_grad.x_final, - self.data1, - err_msg='Incorrect MomentumGrad results.', - ) - - def test_rmsprop_grad(self): - """Test RMSProp Gradient Descent.""" - self.rms_grad.iterate() - npt.assert_almost_equal( - self.rms_grad.x_final, - self.data1, - err_msg='Incorrect RMSPropGrad results.', - ) - - def test_saga_grad(self): - """Test SAGA Descent.""" - self.saga_grad.iterate() - npt.assert_almost_equal( - self.saga_grad.x_final, - self.data1, - err_msg='Incorrect SAGA Grad results.', - ) - - def test_vanilla_grad(self): - """Test Vanilla Gradient Descent.""" - self.vanilla_grad.iterate() - npt.assert_almost_equal( - self.vanilla_grad.x_final, - self.data1, - err_msg='Incorrect VanillaGrad results.', - ) + update_kwargs = build_kwargs({}, use_metrics) + algo = admm( + u=self.data1, + v=self.data1, + mu=np.zeros_like(self.data1), + A=linear.Identity(), + B=linear.Identity(), + b=self.data1, + optimizers=(optim1, optim2), + **update_kwargs, + ) + algo.iterate() + return algo, update_kwargs + +@parametrize_with_cases("algo, kwargs", cases=AlgoCases) +def test_algo(algo, kwargs): + """Test algorithms.""" + if kwargs.get("auto_iterate") is False: + # algo already run + npt.assert_almost_equal(algo.idx, AlgoCases.max_iter - 1) + else: + npt.assert_almost_equal(algo.x_final, AlgoCases.data1) + + if kwargs.get("metrics"): + print(algo.metrics) + npt.assert_almost_equal(algo.metrics["diff"]["values"][-1], 0, 3) diff --git a/modopt/tests/test_base.py b/modopt/tests/test_base.py index 873a4506..e32ff94b 100644 --- a/modopt/tests/test_base.py +++ b/modopt/tests/test_base.py @@ -1,192 +1,139 @@ -# -*- coding: utf-8 -*- - -"""UNIT TESTS FOR BASE. - -This module contains unit tests for the modopt.base module. - -:Author: Samuel Farrens - """ +Test for base module. -from builtins import range -from unittest import TestCase, skipIf - +:Authors: + Samuel Farrens + Pierre-Antoine Comby +""" import numpy as np import numpy.testing as npt +import pytest +from test_helpers import failparam, skipparam -from modopt.base import np_adjust, transform, types -from modopt.base.backend import (LIBRARIES, change_backend, get_array_module, - get_backend) +from modopt.base import backend, np_adjust, transform, types +from modopt.base.backend import LIBRARIES -class NPAdjustTestCase(TestCase): - """Test case for np_adjust module.""" +class TestNpAdjust: + """Test for npadjust.""" - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape((3, 3)) - self.data2 = np.arange(18).reshape((2, 3, 3)) - self.data3 = np.array([ + array33 = np.arange(9).reshape((3, 3)) + array233 = np.arange(18).reshape((2, 3, 3)) + arraypad = np.array( + [ [0, 0, 0, 0, 0], [0, 0, 1, 2, 0], [0, 3, 4, 5, 0], [0, 6, 7, 8, 0], [0, 0, 0, 0, 0], - ]) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None + ] + ) def test_rotate(self): """Test rotate.""" npt.assert_array_equal( - np_adjust.rotate(self.data1), - np.array([[8, 7, 6], [5, 4, 3], [2, 1, 0]]), - err_msg='Incorrect rotation', + np_adjust.rotate(self.array33), + np.rot90(np.rot90(self.array33)), + err_msg="Incorrect rotation.", ) def test_rotate_stack(self): """Test rotate_stack.""" npt.assert_array_equal( - np_adjust.rotate_stack(self.data2), - np.array([ - [[8, 7, 6], [5, 4, 3], [2, 1, 0]], - [[17, 16, 15], [14, 13, 12], [11, 10, 9]], - ]), - err_msg='Incorrect stack rotation', + np_adjust.rotate_stack(self.array233), + np.rot90(self.array233, k=2, axes=(1, 2)), + err_msg="Incorrect stack rotation.", ) - def test_pad2d(self): + @pytest.mark.parametrize( + "padding", + [ + 1, + [1, 1], + np.array([1, 1]), + failparam("1", raises=ValueError), + ], + ) + def test_pad2d(self, padding): """Test pad2d.""" - npt.assert_array_equal( - np_adjust.pad2d(self.data1, (1, 1)), - self.data3, - err_msg='Incorrect padding', - ) - - npt.assert_array_equal( - np_adjust.pad2d(self.data1, 1), - self.data3, - err_msg='Incorrect padding', - ) - - npt.assert_array_equal( - np_adjust.pad2d(self.data1, np.array([1, 1])), - self.data3, - err_msg='Incorrect padding', - ) - - npt.assert_raises(ValueError, np_adjust.pad2d, self.data1, '1') + npt.assert_equal(np_adjust.pad2d(self.array33, padding), self.arraypad) def test_fancy_transpose(self): - """Test fancy_transpose.""" + """Test fancy transpose.""" npt.assert_array_equal( - np_adjust.fancy_transpose(self.data2), - np.array([ - [[0, 3, 6], [9, 12, 15]], - [[1, 4, 7], [10, 13, 16]], - [[2, 5, 8], [11, 14, 17]], - ]), - err_msg='Incorrect fancy transpose', + np_adjust.fancy_transpose(self.array233), + np.array( + [ + [[0, 3, 6], [9, 12, 15]], + [[1, 4, 7], [10, 13, 16]], + [[2, 5, 8], [11, 14, 17]], + ] + ), + err_msg="Incorrect fancy transpose", ) def test_ftr(self): """Test ftr.""" npt.assert_array_equal( - np_adjust.ftr(self.data2), - np.array([ - [[0, 3, 6], [9, 12, 15]], - [[1, 4, 7], [10, 13, 16]], - [[2, 5, 8], [11, 14, 17]], - ]), - err_msg='Incorrect fancy transpose: ftr', + np_adjust.ftr(self.array233), + np.array( + [ + [[0, 3, 6], [9, 12, 15]], + [[1, 4, 7], [10, 13, 16]], + [[2, 5, 8], [11, 14, 17]], + ] + ), + err_msg="Incorrect fancy transpose: ftr", ) def test_ftl(self): - """Test ftl.""" - npt.assert_array_equal( - np_adjust.ftl(self.data2), - np.array([ - [[0, 9], [1, 10], [2, 11]], - [[3, 12], [4, 13], [5, 14]], - [[6, 15], [7, 16], [8, 17]], - ]), - err_msg='Incorrect fancy transpose: ftl', - ) - - -class TransformTestCase(TestCase): - """Test case for transform module.""" - - def setUp(self): - """Set test parameter values.""" - self.cube = np.arange(16).reshape((4, 2, 2)) - self.map = np.array( - [[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]], - ) - self.matrix = np.array( - [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]], - ) - self.layout = (2, 2) - - def tearDown(self): - """Unset test parameter values.""" - self.cube = None - self.map = None - self.layout = None - - def test_cube2map(self): + """Test fancy transpose left.""" + npt.assert_array_equal( + np_adjust.ftl(self.array233), + np.array( + [ + [[0, 9], [1, 10], [2, 11]], + [[3, 12], [4, 13], [5, 14]], + [[6, 15], [7, 16], [8, 17]], + ] + ), + err_msg="Incorrect fancy transpose: ftl", + ) + + +class TestTransforms: + """Test for the transform module.""" + + cube = np.arange(16).reshape((4, 2, 2)) + map = np.array([[0, 1, 4, 5], [2, 3, 6, 7], [8, 9, 12, 13], [10, 11, 14, 15]]) + matrix = np.array([[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]) + layout = (2, 2) + fail_layout = (3, 3) + + @pytest.mark.parametrize( + ("func", "indata", "layout", "outdata"), + [ + (transform.cube2map, cube, layout, map), + failparam(transform.cube2map, np.eye(2), layout, map, raises=ValueError), + (transform.map2cube, map, layout, cube), + (transform.map2matrix, map, layout, matrix), + (transform.matrix2map, matrix, matrix.shape, map), + ], + ) + def test_map(self, func, indata, layout, outdata): """Test cube2map.""" npt.assert_array_equal( - transform.cube2map(self.cube, self.layout), - self.map, - err_msg='Incorrect transformation: cube2map', - ) - - npt.assert_raises( - ValueError, - transform.cube2map, - self.map, - self.layout, - ) - - npt.assert_raises(ValueError, transform.cube2map, self.cube, (3, 3)) - - def test_map2cube(self): - """Test map2cube.""" - npt.assert_array_equal( - transform.map2cube(self.map, self.layout), - self.cube, - err_msg='Incorrect transformation: map2cube', - ) - - npt.assert_raises(ValueError, transform.map2cube, self.map, (3, 3)) - - def test_map2matrix(self): - """Test map2matrix.""" - npt.assert_array_equal( - transform.map2matrix(self.map, self.layout), - self.matrix, - err_msg='Incorrect transformation: map2matrix', - ) - - def test_matrix2map(self): - """Test matrix2map.""" - npt.assert_array_equal( - transform.matrix2map(self.matrix, self.map.shape), - self.map, - err_msg='Incorrect transformation: matrix2map', + func(indata, layout), + outdata, ) + if func.__name__ != "map2matrix": + npt.assert_raises(ValueError, func, indata, self.fail_layout) def test_cube2matrix(self): """Test cube2matrix.""" npt.assert_array_equal( transform.cube2matrix(self.cube), self.matrix, - err_msg='Incorrect transformation: cube2matrix', ) def test_matrix2cube(self): @@ -194,136 +141,78 @@ def test_matrix2cube(self): npt.assert_array_equal( transform.matrix2cube(self.matrix, self.cube[0].shape), self.cube, - err_msg='Incorrect transformation: matrix2cube', - ) - - -class TypesTestCase(TestCase): - """Test case for types module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = list(range(5)) - self.data2 = np.arange(5) - self.data3 = np.arange(5).astype(float) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None - - def test_check_float(self): - """Test check_float.""" - npt.assert_array_equal( - types.check_float(1.0), - 1.0, - err_msg='Float check failed', - ) - - npt.assert_array_equal( - types.check_float(1), - 1.0, - err_msg='Float check failed', - ) - - npt.assert_array_equal( - types.check_float(self.data1), - self.data3, - err_msg='Float check failed', - ) - - npt.assert_array_equal( - types.check_float(self.data2), - self.data3, - err_msg='Float check failed', - ) - - npt.assert_raises(TypeError, types.check_float, '1') - - def test_check_int(self): - """Test check_int.""" - npt.assert_array_equal( - types.check_int(1), - 1, - err_msg='Float check failed', - ) - - npt.assert_array_equal( - types.check_int(1.0), - 1, - err_msg='Float check failed', - ) - - npt.assert_array_equal( - types.check_int(self.data1), - self.data2, - err_msg='Float check failed', - ) - - npt.assert_array_equal( - types.check_int(self.data3), - self.data2, - err_msg='Int check failed', - ) - - npt.assert_raises(TypeError, types.check_int, '1') - - def test_check_npndarray(self): + err_msg="Incorrect transformation: matrix2cube", + ) + + +class TestType: + """Test for type module.""" + + data_list = list(range(5)) + data_int = np.arange(5) + data_flt = np.arange(5).astype(float) + + @pytest.mark.parametrize( + ("data", "checked"), + [ + (1.0, 1.0), + (1, 1.0), + (data_list, data_flt), + (data_int, data_flt), + failparam("1.0", 1.0, raises=TypeError), + ], + ) + def test_check_float(self, data, checked): + """Test check float.""" + npt.assert_array_equal(types.check_float(data), checked) + + @pytest.mark.parametrize( + ("data", "checked"), + [ + (1.0, 1), + (1, 1), + (data_list, data_int), + (data_flt, data_int), + failparam("1", None, raises=TypeError), + ], + ) + def test_check_int(self, data, checked): + """Test check int.""" + npt.assert_array_equal(types.check_int(data), checked) + + @pytest.mark.parametrize( + ("data", "dtype"), [(data_flt, np.integer), (data_int, np.floating)] + ) + def test_check_npndarray(self, data, dtype): """Test check_npndarray.""" npt.assert_raises( TypeError, types.check_npndarray, - self.data3, - dtype=np.integer, - ) - - -class TestBackend(TestCase): - """Test the backend codes.""" - - def setUp(self): - """Set test parameter values.""" - self.input = np.array([10, 10]) - - @skipIf(LIBRARIES['tensorflow'] is None, 'tensorflow library not installed') - def test_tf_backend(self): - """Test tensorflow backend.""" - xp, backend = get_backend('tensorflow') - if backend != 'tensorflow' or xp != LIBRARIES['tensorflow']: - raise AssertionError('tensorflow get_backend fails!') - tf_input = change_backend(self.input, 'tensorflow') - if ( - get_array_module(LIBRARIES['tensorflow'].ones(1)) != LIBRARIES['tensorflow'] - or get_array_module(tf_input) != LIBRARIES['tensorflow'] - ): - raise AssertionError('tensorflow backend fails!') - - @skipIf(LIBRARIES['cupy'] is None, 'cupy library not installed') - def test_cp_backend(self): - """Test cupy backend.""" - xp, backend = get_backend('cupy') - if backend != 'cupy' or xp != LIBRARIES['cupy']: - raise AssertionError('cupy get_backend fails!') - cp_input = change_backend(self.input, 'cupy') - if ( - get_array_module(LIBRARIES['cupy'].ones(1)) != LIBRARIES['cupy'] - or get_array_module(cp_input) != LIBRARIES['cupy'] - ): - raise AssertionError('cupy backend fails!') - - def test_np_backend(self): - """Test numpy backend.""" - xp, backend = get_backend('numpy') - if backend != 'numpy' or xp != LIBRARIES['numpy']: - raise AssertionError('numpy get_backend fails!') - np_input = change_backend(self.input, 'numpy') - if ( - get_array_module(LIBRARIES['numpy'].ones(1)) != LIBRARIES['numpy'] - or get_array_module(np_input) != LIBRARIES['numpy'] - ): - raise AssertionError('numpy backend fails!') - - def tearDown(self): - """Tear Down of objects.""" - self.input = None + data, + dtype=dtype, + ) + + def test_check_callable(self): + """Test callable.""" + npt.assert_raises(TypeError, types.check_callable, 1) + + +@pytest.mark.parametrize( + "backend_name", + [ + skipparam(name, cond=LIBRARIES[name] is None, reason=f"{name} not installed") + for name in LIBRARIES + ], +) +def test_tf_backend(backend_name): + """Test Modopt computational backends.""" + xp, checked_backend_name = backend.get_backend(backend_name) + if checked_backend_name != backend_name or xp != LIBRARIES[backend_name]: + raise AssertionError(f"{backend_name} get_backend fails!") + xp_input = backend.change_backend(np.array([10, 10]), backend_name) + if ( + backend.get_array_module(LIBRARIES[backend_name].ones(1)) + != backend.LIBRARIES[backend_name] + or backend.get_array_module(xp_input) != LIBRARIES[backend_name] + ): + raise AssertionError(f"{backend_name} backend fails!") diff --git a/modopt/tests/test_helpers/__init__.py b/modopt/tests/test_helpers/__init__.py new file mode 100644 index 00000000..3886b877 --- /dev/null +++ b/modopt/tests/test_helpers/__init__.py @@ -0,0 +1 @@ +from .utils import failparam, skipparam, Dummy diff --git a/modopt/tests/test_helpers/utils.py b/modopt/tests/test_helpers/utils.py new file mode 100644 index 00000000..d8227640 --- /dev/null +++ b/modopt/tests/test_helpers/utils.py @@ -0,0 +1,23 @@ +""" +Some helper functions for the test parametrization. +They should be used inside ``@pytest.mark.parametrize`` call. + +:Author: Pierre-Antoine Comby +""" +import pytest + + +def failparam(*args, raises=None): + """Return a pytest parameterization that should raise an error.""" + if not issubclass(raises, Exception): + raise ValueError("raises should be an expected Exception.") + return pytest.param(*args, marks=pytest.mark.raises(exception=raises)) + + +def skipparam(*args, cond=True, reason=""): + """Return a pytest parameterization that should be skip if cond is valid.""" + return pytest.param(*args, marks=pytest.mark.skipif(cond, reason=reason)) + + +class Dummy: + pass diff --git a/modopt/tests/test_math.py b/modopt/tests/test_math.py index 99908e02..ea177b15 100644 --- a/modopt/tests/test_math.py +++ b/modopt/tests/test_math.py @@ -1,215 +1,181 @@ -# -*- coding: utf-8 -*- - """UNIT TESTS FOR MATH. This module contains unit tests for the modopt.math module. -:Author: Samuel Farrens - +:Authors: + Samuel Farrens + Pierre-Antoine Comby """ - -from unittest import TestCase, skipIf, skipUnless +import pytest +from test_helpers import failparam, skipparam import numpy as np import numpy.testing as npt + from modopt.math import convolve, matrix, metrics, stats try: import astropy except ImportError: # pragma: no cover - import_astropy = False + ASTROPY_AVAILABLE = False else: # pragma: no cover - import_astropy = True + ASTROPY_AVAILABLE = True try: from skimage.metrics import structural_similarity as compare_ssim except ImportError: # pragma: no cover - import_skimage = False + SKIMAGE_AVAILABLE = False else: - import_skimage = True - - -class ConvolveTestCase(TestCase): - """Test case for convolve module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(18).reshape(2, 3, 3) - self.data2 = self.data1 + 1 - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - - @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover - def test_convolve_astropy(self): - """Test convolve using astropy.""" - npt.assert_allclose( - convolve.convolve(self.data1[0], self.data2[0], method='astropy'), - np.array([ - [210.0, 201.0, 210.0], - [129.0, 120.0, 129.0], - [210.0, 201.0, 210.0], - ]), - err_msg='Incorrect convolution: astropy', - ) - - npt.assert_raises( - ValueError, - convolve.convolve, - self.data1[0], - self.data2, - ) - - npt.assert_raises( - ValueError, - convolve.convolve, - self.data1[0], - self.data2[0], - method='bla', - ) - - def test_convolve_scipy(self): - """Test convolve using scipy.""" - npt.assert_allclose( - convolve.convolve(self.data1[0], self.data2[0], method='scipy'), - np.array([ + SKIMAGE_AVAILABLE = True + + +class TestConvolve: + """Test convolve functions.""" + + array233 = np.arange(18).reshape((2, 3, 3)) + array233_1 = array233 + 1 + result_astropy = np.array( + [ + [210.0, 201.0, 210.0], + [129.0, 120.0, 129.0], + [210.0, 201.0, 210.0], + ] + ) + result_scipy = np.array( + [ + [ [14.0, 35.0, 38.0], [57.0, 120.0, 111.0], [110.0, 197.0, 158.0], - ]), - err_msg='Incorrect convolution: scipy', - ) - - def test_convolve_stack(self): - """Test convolve_stack.""" + ], + [ + [518.0, 845.0, 614.0], + [975.0, 1578.0, 1137.0], + [830.0, 1331.0, 950.0], + ], + ] + ) + + result_rot_kernel = np.array( + [ + [ + [66.0, 115.0, 82.0], + [153.0, 240.0, 159.0], + [90.0, 133.0, 82.0], + ], + [ + [714.0, 1087.0, 730.0], + [1125.0, 1698.0, 1131.0], + [738.0, 1105.0, 730.0], + ], + ] + ) + + @pytest.mark.parametrize( + ("input_data", "kernel", "method", "result"), + [ + skipparam( + array233[0], + array233_1[0], + "astropy", + result_astropy, + cond=not ASTROPY_AVAILABLE, + reason="astropy not available", + ), + failparam( + array233[0], array233_1, "astropy", result_astropy, raises=ValueError + ), + failparam( + array233[0], array233_1[0], "fail!", result_astropy, raises=ValueError + ), + (array233[0], array233_1[0], "scipy", result_scipy[0]), + ], + ) + def test_convolve(self, input_data, kernel, method, result): + """Test convolve function.""" + npt.assert_allclose(convolve.convolve(input_data, kernel, method), result) + + @pytest.mark.parametrize( + ("result", "rot_kernel"), + [ + (result_scipy, False), + (result_rot_kernel, True), + ], + ) + def test_convolve_stack(self, result, rot_kernel): + """Test convolve stack function.""" npt.assert_allclose( - convolve.convolve_stack(self.data1, self.data2), - np.array([ - [ - [14.0, 35.0, 38.0], - [57.0, 120.0, 111.0], - [110.0, 197.0, 158.0], - ], - [ - [518.0, 845.0, 614.0], - [975.0, 1578.0, 1137.0], - [830.0, 1331.0, 950.0], - ], - ]), - err_msg='Incorrect convolution: stack', + convolve.convolve_stack( + self.array233, self.array233_1, rot_kernel=rot_kernel + ), + result, ) - def test_convolve_stack_rot(self): - """Test convolve_stack rotated.""" - npt.assert_allclose( - convolve.convolve_stack(self.data1, self.data2, rot_kernel=True), - np.array([ - [ - [66.0, 115.0, 82.0], - [153.0, 240.0, 159.0], - [90.0, 133.0, 82.0], - ], - [ - [714.0, 1087.0, 730.0], - [1125.0, 1698.0, 1131.0], - [738.0, 1105.0, 730.0], - ], - ]), - err_msg='Incorrect convolution: stack rot', - ) +class TestMatrix: + """Test matrix module.""" -class MatrixTestCase(TestCase): - """Test case for matrix module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3) - self.data2 = np.arange(3) - self.data3 = np.arange(6).reshape(2, 3) - np.random.seed(1) - self.pmInstance1 = matrix.PowerMethod( - lambda x_val: x_val.dot(x_val.T), - self.data1.shape, - verbose=True, - ) - np.random.seed(1) - self.pmInstance2 = matrix.PowerMethod( - lambda x_val: x_val.dot(x_val.T), - self.data1.shape, - auto_run=False, - verbose=True, - ) - self.pmInstance2.get_spec_rad(max_iter=1) - self.gram_schmidt_out = ( - np.array([ + array3 = np.arange(3) + array33 = np.arange(9).reshape((3, 3)) + array23 = np.arange(6).reshape((2, 3)) + gram_schmidt_out = ( + np.array( + [ [0, 1.0, 2.0], [3.0, 1.2, -6e-1], [-1.77635684e-15, 0, 0], - ]), - np.array([ + ] + ), + np.array( + [ [0, 0.4472136, 0.89442719], [0.91287093, 0.36514837, -0.18257419], [-1.0, 0, 0], - ]), - ) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None - self.pmInstance1 = None - self.pmInstance2 = None - self.gram_schmidt_out = None - - def test_gram_schmidt_orthonormal(self): - """Test gram_schmidt with orthonormal output.""" - npt.assert_allclose( - matrix.gram_schmidt(self.data1), - self.gram_schmidt_out[1], - err_msg='Incorrect Gram-Schmidt: orthonormal', - ) + ] + ), + ) - npt.assert_raises( - ValueError, - matrix.gram_schmidt, - self.data1, - return_opt='bla', - ) - - def test_gram_schmidt_orthogonal(self): - """Test gram_schmidt with orthogonal output.""" - npt.assert_allclose( - matrix.gram_schmidt(self.data1, return_opt='orthogonal'), - self.gram_schmidt_out[0], - err_msg='Incorrect Gram-Schmidt: orthogonal', + @pytest.fixture + def pm_instance(self, request): + """Power Method instance.""" + np.random.seed(1) + pm = matrix.PowerMethod( + lambda x_val: x_val.dot(x_val.T), + self.array33.shape, + auto_run=request.param, + verbose=True, ) - - def test_gram_schmidt_both(self): - """Test gram_schmidt with both outputs.""" + if not request.param: + pm.get_spec_rad(max_iter=1) + return pm + + @pytest.mark.parametrize( + ("return_opt", "output"), + [ + ("orthonormal", gram_schmidt_out[1]), + ("orthogonal", gram_schmidt_out[0]), + ("both", gram_schmidt_out), + failparam("fail!", gram_schmidt_out, raises=ValueError), + ], + ) + def test_gram_schmidt(self, return_opt, output): + """Test gram schmidt.""" npt.assert_allclose( - matrix.gram_schmidt(self.data1, return_opt='both'), - self.gram_schmidt_out, - err_msg='Incorrect Gram-Schmidt: both', + matrix.gram_schmidt(self.array33, return_opt=return_opt), output ) def test_nuclear_norm(self): - """Test nuclear_norm.""" + """Test nuclear norm.""" npt.assert_almost_equal( - matrix.nuclear_norm(self.data1), + matrix.nuclear_norm(self.array33), 15.49193338482967, - err_msg='Incorrect nuclear norm', ) def test_project(self): """Test project.""" npt.assert_array_equal( - matrix.project(self.data2, self.data2 + 3), + matrix.project(self.array3, self.array3 + 3), np.array([0, 2.8, 5.6]), - err_msg='Incorrect projection', ) def test_rot_matrix(self): @@ -217,280 +183,149 @@ def test_rot_matrix(self): npt.assert_allclose( matrix.rot_matrix(np.pi / 6), np.array([[0.8660254, -0.5], [0.5, 0.8660254]]), - err_msg='Incorrect rotation matrix', ) def test_rotate(self): """Test rotate.""" npt.assert_array_equal( - matrix.rotate(self.data1, np.pi / 2), + matrix.rotate(self.array33, np.pi / 2), np.array([[2, 5, 8], [1, 4, 7], [0, 3, 6]]), - err_msg='Incorrect rotation', - ) - - npt.assert_raises(ValueError, matrix.rotate, self.data3, np.pi / 2) - - def test_powermethod_converged(self): - """Test PowerMethod converged.""" - npt.assert_almost_equal( - self.pmInstance1.spec_rad, - 0.90429242629600837, - err_msg='Incorrect spectral radius: converged', ) - npt.assert_almost_equal( - self.pmInstance1.inv_spec_rad, - 1.1058369736612865, - err_msg='Incorrect inverse spectral radius: converged', - ) - - def test_powermethod_unconverged(self): - """Test PowerMethod unconverged.""" - npt.assert_almost_equal( - self.pmInstance2.spec_rad, - 0.92048833577059219, - err_msg='Incorrect spectral radius: unconverged', - ) - - npt.assert_almost_equal( - self.pmInstance2.inv_spec_rad, - 1.0863798715741946, - err_msg='Incorrect inverse spectral radius: unconverged', - ) - - -class MetricsTestCase(TestCase): - """Test case for metrics module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(49).reshape(7, 7) - self.mask = np.ones(self.data1.shape) - self.ssim_res = 0.8963363560519094 - self.ssim_mask_res = 0.805154442543846 - self.snr_res = 10.134554256920536 - self.psnr_res = 14.860761791850397 - self.mse_res = 0.03265305507330247 - self.nrmse_res = 0.31136678840022625 - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.mask = None - self.ssim_res = None - self.ssim_mask_res = None - self.psnr_res = None - self.mse_res = None - self.nrmse_res = None - - @skipIf(import_skimage, 'skimage is installed.') # pragma: no cover - def test_ssim_skimage_error(self): - """Test ssim skimage error.""" - npt.assert_raises(ImportError, metrics.ssim, self.data1, self.data1) - - @skipUnless(import_skimage, 'skimage not installed.') # pragma: no cover - def test_ssim(self): + npt.assert_raises(ValueError, matrix.rotate, self.array23, np.pi / 2) + + @pytest.mark.parametrize( + ("pm_instance", "value"), + [(True, 1.0), (False, 0.8675467477372257)], + indirect=["pm_instance"], + ) + def test_power_method(self, pm_instance, value): + """Test power method.""" + npt.assert_almost_equal(pm_instance.spec_rad, value) + npt.assert_almost_equal(pm_instance.inv_spec_rad, 1 / value) + + +class TestMetrics: + """Test metrics module.""" + + data1 = np.arange(49).reshape(7, 7) + mask = np.ones(data1.shape) + ssim_res = 0.8963363560519094 + ssim_mask_res = 0.805154442543846 + snr_res = 10.134554256920536 + psnr_res = 14.860761791850397 + mse_res = 0.03265305507330247 + nrmse_res = 0.31136678840022625 + + @pytest.mark.skipif(not SKIMAGE_AVAILABLE, reason="skimage not installed") + @pytest.mark.parametrize( + ("data1", "data2", "result", "mask"), + [ + (data1, data1**2, ssim_res, None), + (data1, data1**2, ssim_mask_res, mask), + failparam(data1, data1, None, 1, raises=ValueError), + ], + ) + def test_ssim(self, data1, data2, result, mask): """Test ssim.""" - npt.assert_almost_equal( - metrics.ssim(self.data1, self.data1 ** 2), - self.ssim_res, - err_msg='Incorrect SSIM result', - ) + npt.assert_almost_equal(metrics.ssim(data1, data2, mask=mask), result) - npt.assert_almost_equal( - metrics.ssim(self.data1, self.data1 ** 2, mask=self.mask), - self.ssim_mask_res, - err_msg='Incorrect SSIM result', - ) - - npt.assert_raises( - ValueError, - metrics.ssim, - self.data1, - self.data1, - mask=1, - ) + @pytest.mark.skipif(SKIMAGE_AVAILABLE, reason="skimage installed") + def test_ssim_fail(self): + """Test ssim.""" + npt.assert_raises(ImportError, metrics.ssim, self.data1, self.data1) - def test_snr(self): + @pytest.mark.parametrize( + ("metric", "data", "result", "mask"), + [ + (metrics.snr, data1, snr_res, None), + (metrics.snr, data1, snr_res, mask), + (metrics.psnr, data1, psnr_res, None), + (metrics.psnr, data1, psnr_res, mask), + (metrics.mse, data1, mse_res, None), + (metrics.mse, data1, mse_res, mask), + (metrics.nrmse, data1, nrmse_res, None), + (metrics.nrmse, data1, nrmse_res, mask), + failparam(metrics.snr, data1, snr_res, "maskfail", raises=ValueError), + ], + ) + def test_metric(self, metric, data, result, mask): """Test snr.""" - npt.assert_almost_equal( - metrics.snr(self.data1, self.data1 ** 2), - self.snr_res, - err_msg='Incorrect SNR result', - ) - - npt.assert_almost_equal( - metrics.snr(self.data1, self.data1 ** 2, mask=self.mask), - self.snr_res, - err_msg='Incorrect SNR result', - ) - - def test_psnr(self): - """Test psnr.""" - npt.assert_almost_equal( - metrics.psnr(self.data1, self.data1 ** 2), - self.psnr_res, - err_msg='Incorrect PSNR result', - ) - - npt.assert_almost_equal( - metrics.psnr(self.data1, self.data1 ** 2, mask=self.mask), - self.psnr_res, - err_msg='Incorrect PSNR result', - ) - - def test_mse(self): - """Test mse.""" - npt.assert_almost_equal( - metrics.mse(self.data1, self.data1 ** 2), - self.mse_res, - err_msg='Incorrect MSE result', - ) - - npt.assert_almost_equal( - metrics.mse(self.data1, self.data1 ** 2, mask=self.mask), - self.mse_res, - err_msg='Incorrect MSE result', - ) - - def test_nrmse(self): - """Test nrmse.""" - npt.assert_almost_equal( - metrics.nrmse(self.data1, self.data1 ** 2), - self.nrmse_res, - err_msg='Incorrect NRMSE result', - ) - - npt.assert_almost_equal( - metrics.nrmse(self.data1, self.data1 ** 2, mask=self.mask), - self.nrmse_res, - err_msg='Incorrect NRMSE result', - ) - - -class StatsTestCase(TestCase): - """Test case for stats module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3) - self.data2 = np.arange(18).reshape(2, 3, 3) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - - @skipIf(import_astropy, 'Astropy is installed.') # pragma: no cover - def test_gaussian_kernel_astropy_error(self): - """Test gaussian_kernel astropy error.""" - npt.assert_raises( - ImportError, - stats.gaussian_kernel, - self.data1.shape, - 1, - ) - - @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover - def test_gaussian_kernel_max(self): - """Test gaussian_kernel with max norm.""" + npt.assert_almost_equal(metric(data, data**2, mask=mask), result) + + +class TestStats: + """Test stats module.""" + + array33 = np.arange(9).reshape(3, 3) + array233 = np.arange(18).reshape(2, 3, 3) + + @pytest.mark.skipif(not ASTROPY_AVAILABLE, reason="astropy not installed") + @pytest.mark.parametrize( + ("norm", "result"), + [ + ( + "max", + np.array( + [ + [0.36787944, 0.60653066, 0.36787944], + [0.60653066, 1.0, 0.60653066], + [0.36787944, 0.60653066, 0.36787944], + ] + ), + ), + ( + "sum", + np.array( + [ + [0.07511361, 0.1238414, 0.07511361], + [0.1238414, 0.20417996, 0.1238414], + [0.07511361, 0.1238414, 0.07511361], + ] + ), + ), + failparam("fail", None, raises=ValueError), + ], + ) + def test_gaussian_kernel(self, norm, result): + """Test Gaussian kernel.""" npt.assert_allclose( - stats.gaussian_kernel(self.data1.shape, 1), - np.array([ - [0.36787944, 0.60653066, 0.36787944], - [0.60653066, 1.0, 0.60653066], - [0.36787944, 0.60653066, 0.36787944], - ]), - err_msg='Incorrect gaussian kernel: max norm', + stats.gaussian_kernel(self.array33.shape, 1, norm=norm), result ) - npt.assert_raises( - ValueError, - stats.gaussian_kernel, - self.data1.shape, - 1, - norm='bla', - ) - - @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover - def test_gaussian_kernel_sum(self): - """Test gaussian_kernel with sum norm.""" - npt.assert_allclose( - stats.gaussian_kernel(self.data1.shape, 1, norm='sum'), - np.array([ - [0.07511361, 0.1238414, 0.07511361], - [0.1238414, 0.20417996, 0.1238414], - [0.07511361, 0.1238414, 0.07511361], - ]), - err_msg='Incorrect gaussian kernel: sum norm', - ) - - @skipUnless(import_astropy, 'Astropy not installed.') # pragma: no cover - def test_gaussian_kernel_none(self): - """Test gaussian_kernel with no norm.""" - npt.assert_allclose( - stats.gaussian_kernel(self.data1.shape, 1, norm='none'), - np.array([ - [0.05854983, 0.09653235, 0.05854983], - [0.09653235, 0.15915494, 0.09653235], - [0.05854983, 0.09653235, 0.05854983], - ]), - err_msg='Incorrect gaussian kernel: sum norm', - ) + @pytest.mark.skipif(ASTROPY_AVAILABLE, reason="astropy installed") + def test_import_astropy(self): + """Test missing astropy.""" + npt.assert_raises(ImportError, stats.gaussian_kernel, self.array33.shape, 1) def test_mad(self): """Test mad.""" - npt.assert_equal( - stats.mad(self.data1), - 2.0, - err_msg='Incorrect median absolute deviation', - ) - - def test_mse(self): - """Test mse.""" - npt.assert_equal( - stats.mse(self.data1, self.data1 + 2), - 4.0, - err_msg='Incorrect mean squared error', - ) + npt.assert_equal(stats.mad(self.array33), 2.0) - def test_psnr_starck(self): - """Test psnr.""" + def test_sigma_mad(self): + """Test sigma_mad.""" npt.assert_almost_equal( - stats.psnr(self.data1, self.data1 + 2), - 12.041199826559248, - err_msg='Incorrect PSNR: starck', - ) - - npt.assert_raises( - ValueError, - stats.psnr, - self.data1, - self.data1, - method='bla', + stats.sigma_mad(self.array33), + 2.9651999999999998, ) - def test_psnr_wiki(self): - """Test psnr wiki method.""" - npt.assert_almost_equal( - stats.psnr(self.data1, self.data1 + 2, method='wiki'), - 42.110203695399477, - err_msg='Incorrect PSNR: wiki', - ) + @pytest.mark.parametrize( + ("data1", "data2", "method", "result"), + [ + (array33, array33 + 2, "starck", 12.041199826559248), + failparam(array33, array33, "fail", 0, raises=ValueError), + (array33, array33 + 2, "wiki", 42.110203695399477), + ], + ) + def test_psnr(self, data1, data2, method, result): + """Test PSNR.""" + npt.assert_almost_equal(stats.psnr(data1, data2, method=method), result) def test_psnr_stack(self): """Test psnr stack.""" npt.assert_almost_equal( - stats.psnr_stack(self.data2, self.data2 + 2), + stats.psnr_stack(self.array233, self.array233 + 2), 12.041199826559248, - err_msg='Incorrect PSNR stack', ) - npt.assert_raises(ValueError, stats.psnr_stack, self.data1, self.data1) - - def test_sigma_mad(self): - """Test sigma_mad.""" - npt.assert_almost_equal( - stats.sigma_mad(self.data1), - 2.9651999999999998, - err_msg='Incorrect sigma from MAD', - ) + npt.assert_raises(ValueError, stats.psnr_stack, self.array33, self.array33) diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py index d5547783..4a82e33c 100644 --- a/modopt/tests/test_opt.py +++ b/modopt/tests/test_opt.py @@ -1,718 +1,293 @@ -# -*- coding: utf-8 -*- - """UNIT TESTS FOR OPT. -This module contains unit tests for the modopt.opt module. - -:Author: Samuel Farrens +This module contains tests for the modopt.opt module. +:Authors: + Samuel Farrens + Pierre-Antoine Comby """ -from builtins import zip -from unittest import TestCase, skipIf, skipUnless - import numpy as np import numpy.testing as npt +import pytest +from pytest_cases import parametrize, parametrize_with_cases, case, fixture, fixture_ref + +from modopt.opt import cost, gradient, linear, proximity, reweight -from modopt.opt import algorithms, cost, gradient, linear, proximity, reweight +from test_helpers import Dummy +SKLEARN_AVAILABLE = True try: import sklearn -except ImportError: # pragma: no cover - import_sklearn = False -else: - import_sklearn = True +except ImportError: + SKLEARN_AVAILABLE = False +PYWT_AVAILABLE = True +try: + import pywt + import joblib +except ImportError: + PYWT_AVAILABLE = False # Basic functions to be used as operators or as dummy functions func_identity = lambda x_val: x_val func_double = lambda x_val: x_val * 2 -func_sq = lambda x_val: x_val ** 2 -func_cube = lambda x_val: x_val ** 3 - - -class Dummy(object): - """Dummy class for tests.""" - - pass - - -class AlgorithmTestCase(TestCase): - """Test case for algorithms module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.data2 = self.data1 + np.random.randn(*self.data1.shape) * 1e-6 - self.data3 = np.arange(9).reshape(3, 3).astype(float) + 1 - - grad_inst = gradient.GradBasic( - self.data1, - func_identity, - func_identity, - ) - - prox_inst = proximity.Positivity() - prox_dual_inst = proximity.IdentityProx() - linear_inst = linear.Identity() - reweight_inst = reweight.cwbReweight(self.data3) - cost_inst = cost.costObj([grad_inst, prox_inst, prox_dual_inst]) - self.setup = algorithms.SetUp() - self.max_iter = 20 - - self.fb_all_iter = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=None, - auto_iterate=False, - beta_update=func_identity, - ) - self.fb_all_iter.iterate(self.max_iter) - - self.fb1 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - ) - - self.fb2 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - cost=cost_inst, - lambda_update=None, - ) - - self.fb3 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - a_cd=3, - ) - - self.fb4 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - beta_update=func_identity, - r_lazy=3, - p_lazy=0.7, - q_lazy=0.7, - ) - - self.fb5 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - restart_strategy='adaptive', - xi_restart=0.9, - ) - - self.fb6 = algorithms.ForwardBackward( - self.data1, - grad=grad_inst, - prox=prox_inst, - restart_strategy='greedy', - xi_restart=0.9, - min_beta=1.0, - s_greedy=1.1, - ) - - self.gfb_all_iter = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - cost=None, - auto_iterate=False, - gamma_update=func_identity, - beta_update=func_identity, - ) - self.gfb_all_iter.iterate(self.max_iter) - - self.gfb1 = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - gamma_update=func_identity, - lambda_update=func_identity, - ) - - self.gfb2 = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - cost=cost_inst, - ) - - self.gfb3 = algorithms.GenForwardBackward( - self.data1, - grad=grad_inst, - prox_list=[prox_inst, prox_dual_inst], - cost=cost_inst, - step_size=2, - ) - - self.condat_all_iter = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - cost=None, - prox_dual=prox_dual_inst, - sigma_update=func_identity, - tau_update=func_identity, - rho_update=func_identity, - auto_iterate=False, - ) - self.condat_all_iter.iterate(self.max_iter) - - self.condat1 = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - prox_dual=prox_dual_inst, - sigma_update=func_identity, - tau_update=func_identity, - rho_update=func_identity, - ) - - self.condat2 = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - prox_dual=prox_dual_inst, - linear=linear_inst, - cost=cost_inst, - reweight=reweight_inst, - ) - - self.condat3 = algorithms.Condat( - self.data1, - self.data2, - grad=grad_inst, - prox=prox_inst, - prox_dual=prox_dual_inst, - linear=Dummy(), - cost=cost_inst, - auto_iterate=False, - ) - - self.pogm_all_iter = algorithms.POGM( - u=self.data1, - x=self.data1, - y=self.data1, - z=self.data1, - grad=grad_inst, - prox=prox_inst, - auto_iterate=False, - cost=None, - ) - self.pogm_all_iter.iterate(self.max_iter) - - self.pogm1 = algorithms.POGM( - u=self.data1, - x=self.data1, - y=self.data1, - z=self.data1, - grad=grad_inst, - prox=prox_inst, - ) - - self.dummy = Dummy() - self.dummy.cost = func_identity - self.setup._check_operator(self.dummy.cost) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.setup = None - self.fb_all_iter = None - self.fb1 = None - self.fb2 = None - self.gfb_all_iter = None - self.gfb1 = None - self.gfb2 = None - self.condat_all_iter = None - self.condat1 = None - self.condat2 = None - self.condat3 = None - self.pogm1 = None - self.pogm_all_iter = None - self.dummy = None - - def test_set_up(self): - """Test set_up.""" - npt.assert_raises(TypeError, self.setup._check_input_data, 1) - - npt.assert_raises(TypeError, self.setup._check_param, 1) - - npt.assert_raises(TypeError, self.setup._check_param_update, 1) - - def test_all_iter(self): - """Test if all opt run for all iterations.""" - opts = [ - self.fb_all_iter, - self.gfb_all_iter, - self.condat_all_iter, - self.pogm_all_iter, - ] - for opt in opts: - npt.assert_equal(opt.idx, self.max_iter - 1) - - def test_forward_backward(self): - """Test forward_backward.""" - npt.assert_array_equal( - self.fb1.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb2.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb3.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb4.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb5.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - npt.assert_array_equal( - self.fb6.x_final, - self.data1, - err_msg='Incorrect ForwardBackward result.', - ) - - def test_gen_forward_backward(self): - """Test gen_forward_backward.""" - npt.assert_array_equal( - self.gfb1.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_array_equal( - self.gfb2.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_array_equal( - self.gfb3.x_final, - self.data1, - err_msg='Incorrect GenForwardBackward result.', - ) - - npt.assert_equal( - self.gfb3.step_size, - 2, - err_msg='Incorrect step size.', - ) - - npt.assert_raises( - TypeError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=1, - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[1], - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[0.5, 0.5], - ) - - npt.assert_raises( - ValueError, - algorithms.GenForwardBackward, - self.data1, - self.dummy, - [self.dummy], - weights=[0.5], - ) - - def test_condat(self): - """Test gen_condat.""" - npt.assert_almost_equal( - self.condat1.x_final, - self.data1, - err_msg='Incorrect Condat result.', - ) +func_sq = lambda x_val: x_val**2 +func_cube = lambda x_val: x_val**3 + + +@case(tags="cost") +@parametrize( + ("cost_interval", "n_calls", "converged"), + [(1, 1, False), (1, 2, True), (2, 5, False), (None, 6, False)], +) +def case_cost_op(cost_interval, n_calls, converged): + """Case function for costs.""" + dummy_inst1 = Dummy() + dummy_inst1.cost = func_sq + dummy_inst2 = Dummy() + dummy_inst2.cost = func_cube + + cost_obj = cost.costObj([dummy_inst1, dummy_inst2], cost_interval=cost_interval) + + for _ in range(n_calls + 1): + cost_obj.get_cost(2) + return cost_obj, converged + + +@parametrize_with_cases("cost_obj, converged", cases=".", has_tag="cost") +def test_costs(cost_obj, converged): + """Test cost.""" + npt.assert_equal(cost_obj.get_cost(2), converged) + if cost_obj._cost_interval: + npt.assert_equal(cost_obj.cost, 12) + + +def test_raise_cost(): + """Test error raising for cost.""" + npt.assert_raises(TypeError, cost.costObj, 1) + npt.assert_raises(ValueError, cost.costObj, [Dummy(), Dummy()]) + + +@case(tags="grad") +@parametrize(call=("op", "trans_op", "trans_op_op")) +def case_grad_parent(call): + """Case for gradient parent.""" + input_data = np.arange(9).reshape(3, 3) + callables = { + "op": func_sq, + "trans_op": func_cube, + "get_grad": func_identity, + "cost": lambda input_val: 1.0, + } + + grad_op = gradient.GradParent( + input_data, + **callables, + data_type=np.floating, + ) + if call != "trans_op_op": + result = callables[call](input_data) + else: + result = callables["trans_op"](callables["op"](input_data)) + + grad_call = getattr(grad_op, call)(input_data) + return grad_call, result + + +@parametrize_with_cases("grad_values, result", cases=".", has_tag="grad") +def test_grad_op(grad_values, result): + """Test Gradient operator.""" + npt.assert_equal(grad_values, result) + + +@pytest.fixture +def grad_basic(): + """Case for GradBasic.""" + input_data = np.arange(9).reshape(3, 3) + grad_op = gradient.GradBasic( + input_data, + func_sq, + func_cube, + verbose=True, + ) + grad_op.get_grad(input_data) + return grad_op + + +def test_grad_basic(grad_basic): + """Test grad basic.""" + npt.assert_array_equal( + grad_basic.grad, + np.array( + [ + [0, 0, 8.0], + [2.16000000e2, 1.72800000e3, 8.0e3], + [2.70000000e4, 7.40880000e4, 1.75616000e5], + ] + ), + err_msg="Incorrect gradient.", + ) - npt.assert_almost_equal( - self.condat2.x_final, - self.data1, - err_msg='Incorrect Condat result.', - ) - def test_pogm(self): - """Test pogm.""" - npt.assert_almost_equal( - self.pogm1.x_final, - self.data1, - err_msg='Incorrect POGM result.', - ) +def test_grad_basic_cost(grad_basic): + """Test grad_basic cost.""" + npt.assert_almost_equal(grad_basic.cost(np.arange(9).reshape(3, 3)), 3192.0) -class CostTestCase(TestCase): - """Test case for cost module.""" +def test_grad_op_raises(): + """Test raise error.""" + npt.assert_raises( + TypeError, + gradient.GradParent, + 1, + func_sq, + func_cube, + ) - def setUp(self): - """Set test parameter values.""" - dummy_inst1 = Dummy() - dummy_inst1.cost = func_sq - dummy_inst2 = Dummy() - dummy_inst2.cost = func_cube - self.inst1 = cost.costObj([dummy_inst1, dummy_inst2]) - self.inst2 = cost.costObj([dummy_inst1, dummy_inst2], cost_interval=2) - # Test that by default cost of False if interval is None - self.inst_none = cost.costObj( - [dummy_inst1, dummy_inst2], - cost_interval=None, - ) - for _ in range(2): - self.inst1.get_cost(2) - for _ in range(6): - self.inst2.get_cost(2) - self.inst_none.get_cost(2) - self.dummy = Dummy() - - def tearDown(self): - """Unset test parameter values.""" - self.inst = None - - def test_cost_object(self): - """Test cost_object.""" - npt.assert_equal( - self.inst1.get_cost(2), - False, - err_msg='Incorrect cost test result.', - ) - npt.assert_equal( - self.inst1.get_cost(2), - True, - err_msg='Incorrect cost test result.', - ) - npt.assert_equal( - self.inst_none.get_cost(2), - False, - err_msg='Incorrect cost test result.', - ) - - npt.assert_equal(self.inst1.cost, 12, err_msg='Incorrect cost value.') +############# +# LINEAR OP # +############# - npt.assert_equal(self.inst2.cost, 12, err_msg='Incorrect cost value.') - npt.assert_raises(TypeError, cost.costObj, 1) +class LinearCases: + """Linear operator cases.""" - npt.assert_raises(ValueError, cost.costObj, [self.dummy, self.dummy]) + def case_linear_identity(self): + """Case linear operator identity.""" + linop = linear.Identity() + data_op, data_adj_op, res_op, res_adj_op = 1, 1, 1, 1 -class GradientTestCase(TestCase): - """Test case for gradient module.""" + return linop, data_op, data_adj_op, res_op, res_adj_op - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.gp = gradient.GradParent( - self.data1, - func_sq, - func_cube, - func_identity, - lambda input_val: 1.0, - data_type=np.floating, - ) - self.gp.grad = self.gp.get_grad(self.data1) - self.gb = gradient.GradBasic( - self.data1, - func_sq, - func_cube, - ) - self.gb.get_grad(self.data1) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.gp = None - self.gb = None - - def test_grad_parent_operators(self): - """Test GradParent.""" - npt.assert_array_equal( - self.gp.op(self.data1), - np.array([[0, 1.0, 4.0], [9.0, 16.0, 25.0], [36.0, 49.0, 64.0]]), - err_msg='Incorrect gradient operation.', - ) - - npt.assert_array_equal( - self.gp.trans_op(self.data1), - np.array( - [[0, 1.0, 8.0], [27.0, 64.0, 125.0], [216.0, 343.0, 512.0]], - ), - err_msg='Incorrect gradient transpose operation.', + def case_linear_wavelet_convolve(self): + """Case linear operator wavelet.""" + linop = linear.WaveletConvolve( + filters=np.arange(8).reshape(2, 2, 2).astype(float) ) + data_op = np.arange(4).reshape(1, 2, 2).astype(float) + data_adj_op = np.arange(8).reshape(1, 2, 2, 2).astype(float) + res_op = np.array([[[[0, 0], [0, 4.0]], [[0, 4.0], [8.0, 28.0]]]]) + res_adj_op = np.array([[[28.0, 62.0], [68.0, 140.0]]]) - npt.assert_array_equal( - self.gp.trans_op_op(self.data1), - np.array([ - [0, 1.0, 6.40000000e1], - [7.29000000e2, 4.09600000e3, 1.56250000e4], - [4.66560000e4, 1.17649000e5, 2.62144000e5], - ]), - err_msg='Incorrect gradient transpose operation operation.', - ) + return linop, data_op, data_adj_op, res_op, res_adj_op - npt.assert_equal( - self.gp.cost(self.data1), - 1.0, - err_msg='Incorrect cost.', + @pytest.mark.skipif(not PYWT_AVAILABLE, reason="PyWavelet not available.") + def case_linear_wavelet_transform(self): + linop = linear.WaveletTransform( + wavelet_name="haar", + shape=(8, 8), + level=2, ) + data_op = np.arange(64).reshape(8, 8).astype(float) + res_op, slices, shapes = pywt.ravel_coeffs(pywt.wavedecn(data_op, "haar", level=2)) + data_adj_op = linop.op(data_op) + res_adj_op = pywt.waverecn(pywt.unravel_coeffs(data_adj_op, slices, shapes, "wavedecn"), "haar") + return linop, data_op, data_adj_op, res_op, res_adj_op - npt.assert_raises( - TypeError, - gradient.GradParent, - 1, + @parametrize(weights=[[1.0, 1.0], None]) + def case_linear_combo(self, weights): + """Case linear operator combo with weights.""" + parent = linear.LinearParent( func_sq, func_cube, ) + linop = linear.LinearCombo([parent, parent], weights) - def test_grad_basic_gradient(self): - """Test GradBasic.""" - npt.assert_array_equal( - self.gb.grad, - np.array([ - [0, 0, 8.0], - [2.16000000e2, 1.72800000e3, 8.0e3], - [2.70000000e4, 7.40880000e4, 1.75616000e5], - ]), - err_msg='Incorrect gradient.', + data_op, data_adj_op, res_op, res_adj_op = ( + 2, + np.array([2, 2]), + np.array([4, 4]), + 8.0 * (2 if weights else 1), ) + return linop, data_op, data_adj_op, res_op, res_adj_op -class LinearTestCase(TestCase): - """Test case for linear module.""" + @parametrize(factor=[1, 1 + 1j]) + def case_linear_matrix(self, factor): + """Case linear operator from matrix.""" + linop = linear.MatrixOperator(np.eye(5) * factor) + data_op = np.arange(5) + data_adj_op = np.arange(5) + res_op = np.arange(5) * factor + res_adj_op = np.arange(5) * np.conjugate(factor) - def setUp(self): - """Set test parameter values.""" - self.parent = linear.LinearParent( - func_sq, - func_cube, - ) - self.ident = linear.Identity() - filters = np.arange(8).reshape(2, 2, 2).astype(float) - self.wave = linear.WaveletConvolve(filters) - self.combo = linear.LinearCombo([self.parent, self.parent]) - self.combo_weight = linear.LinearCombo( - [self.parent, self.parent], - [1.0, 1.0], - ) - self.data1 = np.arange(18).reshape(2, 3, 3).astype(float) - self.data2 = np.arange(4).reshape(1, 2, 2).astype(float) - self.data3 = np.arange(8).reshape(1, 2, 2, 2).astype(float) - self.data4 = np.array([[[[0, 0], [0, 4.0]], [[0, 4.0], [8.0, 28.0]]]]) - self.data5 = np.array([[[28.0, 62.0], [68.0, 140.0]]]) - self.dummy = Dummy() - - def tearDown(self): - """Unset test parameter values.""" - self.parent = None - self.ident = None - self.combo = None - self.combo_weight = None - self.data1 = None - self.data2 = None - self.data3 = None - self.data4 = None - self.data5 = None - self.dummy = None - - def test_linear_parent(self): - """Test LinearParent.""" - npt.assert_equal( - self.parent.op(2), - 4, - err_msg='Incorrect linear parent operation.', - ) + return linop, data_op, data_adj_op, res_op, res_adj_op - npt.assert_equal( - self.parent.adj_op(2), - 8, - err_msg='Incorrect linear parent adjoint operation.', - ) - npt.assert_raises(TypeError, linear.LinearParent, 0, 0) +@fixture +@parametrize_with_cases( + "linop, data_op, data_adj_op, res_op, res_adj_op", cases=LinearCases +) +def lin_adj_op(linop, data_op, data_adj_op, res_op, res_adj_op): + """Get adj_op relative data.""" + return linop.adj_op, data_adj_op, res_adj_op - def test_identity(self): - """Test Identity.""" - npt.assert_equal( - self.ident.op(1.0), - 1.0, - err_msg='Incorrect identity operation.', - ) - npt.assert_equal( - self.ident.adj_op(1.0), - 1.0, - err_msg='Incorrect identity adjoint operation.', - ) +@fixture +@parametrize_with_cases( + "linop, data_op, data_adj_op, res_op, res_adj_op", cases=LinearCases +) +def lin_op(linop, data_op, data_adj_op, res_op, res_adj_op): + """Get op relative data.""" + return linop.op, data_op, res_op - def test_wavelet_convolve(self): - """Test WaveletConvolve.""" - npt.assert_almost_equal( - self.wave.op(self.data2), - self.data4, - err_msg='Incorrect wavelet convolution operation.', - ) - npt.assert_almost_equal( - self.wave.adj_op(self.data3), - self.data5, - err_msg='Incorrect wavelet convolution adjoint operation.', - ) +@parametrize( + ("action", "data", "result"), [fixture_ref(lin_op), fixture_ref(lin_adj_op)] +) +def test_linear_operator(action, data, result): + """Test linear operator.""" + npt.assert_almost_equal(action(data), result) - def test_linear_combo(self): - """Test LinearCombo.""" - npt.assert_equal( - self.combo.op(2), - np.array([4, 4]).astype(object), - err_msg='Incorrect combined linear operation', - ) - npt.assert_equal( - self.combo.adj_op([2, 2]), - 8.0, - err_msg='Incorrect combined linear adjoint operation', - ) +dummy_with_op = Dummy() +dummy_with_op.op = lambda x: x - npt.assert_raises(TypeError, linear.LinearCombo, self.parent) - npt.assert_raises(ValueError, linear.LinearCombo, []) +@pytest.mark.parametrize( + ("args", "error"), + [ + ([linear.LinearParent(func_sq, func_cube)], TypeError), + ([[]], ValueError), + ([[Dummy()]], ValueError), + ([[dummy_with_op]], ValueError), + ([[]], ValueError), + ([[linear.LinearParent(func_sq, func_cube)] * 2, [1.0]], ValueError), + ([[linear.LinearParent(func_sq, func_cube)] * 2, ["1", "1"]], TypeError), + ], +) +def test_linear_combo_errors(args, error): + """Test linear combo_errors.""" + npt.assert_raises(error, linear.LinearCombo, *args) - npt.assert_raises(ValueError, linear.LinearCombo, [self.dummy]) - self.dummy.op = func_identity +############# +# Proximity # +############# - npt.assert_raises(ValueError, linear.LinearCombo, [self.dummy]) - def test_linear_combo_weight(self): - """Test LinearCombo with weight .""" - npt.assert_equal( - self.combo_weight.op(2), - np.array([4, 4]).astype(object), - err_msg='Incorrect combined linear operation', - ) - - npt.assert_equal( - self.combo_weight.adj_op([2, 2]), - 16.0, - err_msg='Incorrect combined linear adjoint operation', - ) +class ProxCases: + """Class containing all proximal operator cases. - npt.assert_raises( - ValueError, - linear.LinearCombo, - [self.parent, self.parent], - [1.0], - ) - - npt.assert_raises( - TypeError, - linear.LinearCombo, - [self.parent, self.parent], - ['1', '1'], - ) + Each case should return 4 parameters: + 1. The proximal operator + 2. test input data + 3. Expected result data + 4. Expected cost value. + """ + weights = np.ones(9).reshape(3, 3).astype(float) * 3 + array33 = np.arange(9).reshape(3, 3).astype(float) + array33_st = np.array([[-0, -0, -0], [0, 1.0, 2.0], [3.0, 4.0, 5.0]]) + array33_st2 = array33_st * -1 -class ProximityTestCase(TestCase): - """Test case for proximity module.""" + array33_support = np.asarray([[0, 0, 0], [0, 1.0, 1.25], [1.5, 1.75, 2.0]]) - def setUp(self): - """Set test parameter values.""" - self.parent = proximity.ProximityParent( - func_sq, - func_double, - ) - self.identity = proximity.IdentityProx() - self.positivity = proximity.Positivity() - weights = np.ones(9).reshape(3, 3).astype(float) * 3 - self.sparsethresh = proximity.SparseThreshold( - linear.Identity(), - weights, - ) - self.lowrank = proximity.LowRankMatrix(10.0, thresh_type='hard') - self.lowrank_rank = proximity.LowRankMatrix( - 10.0, - initial_rank=1, - thresh_type='hard', - ) - self.lowrank_ngole = proximity.LowRankMatrix( - 10.0, - lowr_type='ngole', - operator=func_double, - ) - self.linear_comp = proximity.LinearCompositionProx( - linear_op=linear.Identity(), - prox_op=self.sparsethresh, - ) - self.combo = proximity.ProximityCombo([self.identity, self.positivity]) - if import_sklearn: - self.owl = proximity.OrderedWeightedL1Norm(weights.flatten()) - self.ridge = proximity.Ridge(linear.Identity(), weights) - self.elasticnet_alpha0 = proximity.ElasticNet( - linear.Identity(), - alpha=0, - beta=weights, - ) - self.elasticnet_beta0 = proximity.ElasticNet( - linear.Identity(), - alpha=weights, - beta=0, - ) - self.one_support = proximity.KSupportNorm(beta=0.2, k_value=1) - self.five_support_norm = proximity.KSupportNorm(beta=3, k_value=5) - self.d_support = proximity.KSupportNorm(beta=3.0 * 2, k_value=19) - self.group_lasso = proximity.GroupLASSO( - weights=np.tile(weights, (4, 1, 1)), - ) - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.data2 = np.array([[-0, -0, -0], [0, 1.0, 2.0], [3.0, 4.0, 5.0]]) - self.data3 = np.arange(18).reshape(2, 3, 3).astype(float) - self.data4 = np.array([ + array233 = np.arange(18).reshape(2, 3, 3).astype(float) + array233_2 = np.array( [ [2.73843189, 3.14594066, 3.55344943], [3.9609582, 4.36846698, 4.77597575], @@ -723,349 +298,230 @@ def setUp(self): [11.67394789, 12.87497954, 14.07601119], [15.27704284, 16.47807449, 17.67910614], ], - ]) - self.data5 = np.array([ + ] + ) + array233_3 = np.array( + [ [[0, 0, 0], [0, 0, 0], [0, 0, 0]], [ [4.00795282, 4.60438026, 5.2008077], [5.79723515, 6.39366259, 6.99009003], [7.58651747, 8.18294492, 8.77937236], ], - ]) - self.data6 = self.data3 * -1 - self.data7 = self.combo.op(self.data6) - self.data8 = np.empty(2, dtype=np.ndarray) - self.data8[0] = np.array( - [[-0, -1.0, -2.0], [-3.0, -4.0, -5.0], [-6.0, -7.0, -8.0]], - ) - self.data8[1] = np.array( - [[-0, -0, -0], [-0, -0, -0], [-0, -0, -0]], - ) - self.data9 = self.data1 * (1 + 1j) - self.data10 = self.data9 / (2 * 3 + 1) - self.data11 = np.asarray( - [[0, 0, 0], [0, 1.0, 1.25], [1.5, 1.75, 2.0]], - ) - self.random_data = 3 * np.random.random( - self.group_lasso.weights[0].shape, - ) - self.random_data_tile = np.tile( - self.random_data, - (self.group_lasso.weights.shape[0], 1, 1), - ) - self.gl_result_data = 2 * self.random_data_tile - 3 - self.gl_result_data = np.array( - (self.gl_result_data * (self.gl_result_data > 0).astype('int')) - / 2, - ) - - self.dummy = Dummy() - - def tearDown(self): - """Unset test parameter values.""" - self.parent = None - self.identity = None - self.positivity = None - self.sparsethresh = None - self.lowrank = None - self.lowrank_rank = None - self.lowrank_ngole = None - self.combo = None - self.data1 = None - self.data2 = None - self.data3 = None - self.data4 = None - self.data5 = None - self.data6 = None - self.data7 = None - self.data8 = None - self.dummy = None - self.random_data = None - self.random_data_tile = None - self.gl_result_data = None - - def test_proximity_parent(self): - """Test ProximityParent.""" - npt.assert_equal( - self.parent.op(3), + ] + ) + + def case_prox_parent(self): + """Case prox parent.""" + return ( + proximity.ProximityParent( + func_sq, + func_double, + ), + 3, 9, - err_msg='Inccoret proximity parent operation.', - ) - - npt.assert_equal( - self.parent.cost(3), 6, - err_msg='Incorrect proximity parent cost.', - ) - - def test_identity(self): - """Test IdentityProx.""" - npt.assert_equal( - self.identity.op(3), - 3, - err_msg='Incorrect proximity identity operation.', - ) - - npt.assert_equal( - self.identity.cost(3), - 0, - err_msg='Incorrect proximity identity cost.', - ) - - def test_positivity(self): - """Test Positivity.""" - npt.assert_equal( - self.positivity.op(-3), - 0, - err_msg='Incorrect proximity positivity operation.', - ) - - npt.assert_equal( - self.positivity.cost(-3, verbose=True), - 0, - err_msg='Incorrect proximity positivity cost.', ) - def test_sparse_threshold(self): - """Test SparseThreshold.""" - npt.assert_array_equal( - self.sparsethresh.op(self.data1), - self.data2, - err_msg='Incorrect sparse threshold operation.', - ) - - npt.assert_equal( - self.sparsethresh.cost(self.data1, verbose=True), - 108.0, - err_msg='Incorrect sparse threshold cost.', - ) - - def test_low_rank_matrix(self): - """Test LowRankMatrix.""" - npt.assert_almost_equal( - self.lowrank.op(self.data3), - self.data4, - err_msg='Incorrect low rank operation: standard', - ) - - npt.assert_almost_equal( - self.lowrank_rank.op(self.data3), - self.data4, - err_msg='Incorrect low rank operation: standard with rank', - ) - npt.assert_almost_equal( - self.lowrank_ngole.op(self.data3), - self.data5, - err_msg='Incorrect low rank operation: ngole', - ) - - npt.assert_almost_equal( - self.lowrank.cost(self.data3, verbose=True), - 469.39132942464983, - err_msg='Incorrect low rank cost.', - ) - - def test_linear_comp_prox(self): - """Test LinearCompositionProx.""" - npt.assert_array_equal( - self.linear_comp.op(self.data1), - self.data2, - err_msg='Incorrect sparse threshold operation.', - ) - - npt.assert_equal( - self.linear_comp.cost(self.data1, verbose=True), - 108.0, - err_msg='Incorrect sparse threshold cost.', + def case_prox_identity(self): + """Case prox identity.""" + return proximity.IdentityProx(), 3, 3, 0 + + def case_prox_positivity(self): + """Case prox positivity.""" + return proximity.Positivity(), -3, 0, 0 + + def case_prox_sparsethresh(self): + """Case prox sparsethreshosld.""" + return ( + proximity.SparseThreshold(linear.Identity(), weights=self.weights), + self.array33, + self.array33_st, + 108, + ) + + @parametrize( + "lowr_type, initial_rank, operator, result, cost", + [ + ("standard", None, None, array233_2, 469.3913294246498), + ("standard", 1, None, array233_2, 469.3913294246498), + ("ngole", None, func_double, array233_3, 469.3913294246498), + ], + ) + def case_prox_lowrank(self, lowr_type, initial_rank, operator, result, cost): + """Case prox lowrank.""" + return ( + proximity.LowRankMatrix( + 10, + lowr_type=lowr_type, + initial_rank=initial_rank, + operator=operator, + thresh_type="hard" if lowr_type == "standard" else "soft", + ), + self.array233, + result, + cost, ) - def test_proximity_combo(self): - """Test ProximityCombo.""" - for data7, data8 in zip(self.data7, self.data8): - npt.assert_array_equal( - data7, - data8, - err_msg='Incorrect combined operation', + def case_prox_linear_comp(self): + """Case prox linear comp.""" + return ( + proximity.LinearCompositionProx( + linear_op=linear.Identity(), prox_op=self.case_prox_sparsethresh()[0] + ), + self.array33, + self.array33_st, + 108, + ) + + def case_prox_ridge(self): + """Case prox ridge.""" + return ( + proximity.Ridge(linear.Identity(), self.weights), + self.array33 * (1 + 1j), + self.array33 * (1 + 1j) / 7, + 1224, + ) + + @parametrize("alpha, beta", [(0, weights), (weights, 0)]) + def case_prox_elasticnet(self, alpha, beta): + """Case prox elastic net.""" + if np.all(alpha == 0): + data = self.case_prox_sparsethresh()[1:] + else: + data = self.case_prox_ridge()[1:] + return (proximity.ElasticNet(linear.Identity(), alpha, beta), *data) + + @parametrize( + "beta, k_value, data, result, cost", + [ + (0.2, 1, array33.flatten(), array33_st.flatten(), 259.2), + (3, 5, array33.flatten(), array33_support.flatten(), 684.0), + ( + 6.0, + 9, + array33.flatten() * (1 + 1j), + array33.flatten() * (1 + 1j) / 7, + 1224, + ), + ], + ) + def case_prox_Ksupport(self, beta, k_value, data, result, cost): + """Case prox K-support norm.""" + return (proximity.KSupportNorm(beta=beta, k_value=k_value), data, result, cost) + + @parametrize(use_weights=[True, False]) + def case_prox_grouplasso(self, use_weights): + """Case GroupLasso proximity.""" + if use_weights: + weights = np.tile(self.weights, (4, 1, 1)) + else: + weights = np.tile(np.zeros((3, 3)), (4, 1, 1)) + + random_data = 3 * np.random.random(weights[0].shape) + random_data_tile = np.tile(random_data, (weights.shape[0], 1, 1)) + if use_weights: + gl_result_data = 2 * random_data_tile - 3 + gl_result_data = ( + np.array(gl_result_data * (gl_result_data > 0).astype("int")) / 2 ) - - npt.assert_equal( - self.combo.cost(self.data6), - 0, - err_msg='Incorrect combined cost.', - ) - - npt.assert_raises(TypeError, proximity.ProximityCombo, 1) - - npt.assert_raises(ValueError, proximity.ProximityCombo, []) - - npt.assert_raises(ValueError, proximity.ProximityCombo, [self.dummy]) - - self.dummy.op = func_identity - - npt.assert_raises(ValueError, proximity.ProximityCombo, [self.dummy]) - - @skipIf(import_sklearn, 'sklearn is installed.') # pragma: no cover - def test_owl_sklearn_error(self): - """Test OrderedWeightedL1Norm with Scikit-Learn.""" - npt.assert_raises(ImportError, proximity.OrderedWeightedL1Norm, 1) - - @skipUnless(import_sklearn, 'sklearn not installed.') # pragma: no cover - def test_sparse_owl(self): - """Test OrderedWeightedL1Norm.""" - npt.assert_array_equal( - self.owl.op(self.data1.flatten()), - self.data2.flatten(), - err_msg='Incorrect sparse threshold operation.', - ) - - npt.assert_equal( - self.owl.cost(self.data1.flatten(), verbose=True), + cost = np.sum(random_data_tile) * 6 + else: + gl_result_data = random_data_tile + cost = 0 + return ( + proximity.GroupLASSO( + weights=weights, + ), + random_data_tile, + gl_result_data, + cost, + ) + + @pytest.mark.skipif(not SKLEARN_AVAILABLE, reason="sklearn not available.") + def case_prox_owl(self): + """Case prox for Ordered Weighted L1 Norm.""" + return ( + proximity.OrderedWeightedL1Norm(self.weights.flatten()), + self.array33.flatten(), + self.array33_st.flatten(), 108.0, - err_msg='Incorrect sparse threshold cost.', ) - npt.assert_raises( - ValueError, - proximity.OrderedWeightedL1Norm, - np.arange(10), - ) - def test_ridge(self): - """Test Ridge.""" - npt.assert_array_equal( - self.ridge.op(self.data9), - self.data10, - err_msg='Incorect shrinkage operation.', - ) +@parametrize_with_cases("operator, input_data, op_result, cost_result", cases=ProxCases) +def test_prox_op(operator, input_data, op_result, cost_result): + """Test proximity operator op.""" + npt.assert_almost_equal(operator.op(input_data), op_result) - npt.assert_equal( - self.ridge.cost(self.data9, verbose=True), - 408.0 * 3.0, - err_msg='Incorect shrinkage cost.', - ) - def test_elastic_net_alpha0(self): - """Test ElasticNet.""" - npt.assert_array_equal( - self.elasticnet_alpha0.op(self.data1), - self.data2, - err_msg='Incorect sparse threshold operation ElasticNet class.', - ) +@parametrize_with_cases("operator, input_data, op_result, cost_result", cases=ProxCases) +def test_prox_cost(operator, input_data, op_result, cost_result): + """Test proximity operator cost.""" + npt.assert_almost_equal(operator.cost(input_data, verbose=True), cost_result) - npt.assert_equal( - self.elasticnet_alpha0.cost(self.data1), - 108.0, - err_msg='Incorect shrinkage cost in ElasticNet class.', - ) - def test_elastic_net_beta0(self): - """Test ElasticNet with beta=0.""" - npt.assert_array_equal( - self.elasticnet_beta0.op(self.data9), - self.data10, - err_msg='Incorect ridge operation ElasticNet class.', - ) +@parametrize( + "arg, error", + [ + (1, TypeError), + ([], ValueError), + ([Dummy()], ValueError), + ([dummy_with_op], ValueError), + ], +) +def test_error_prox_combo(arg, error): + """Test errors for proximity combo.""" + npt.assert_raises(error, proximity.ProximityCombo, arg) - npt.assert_equal( - self.elasticnet_beta0.cost(self.data9, verbose=True), - 408.0 * 3.0, - err_msg='Incorect shrinkage cost in ElasticNet class.', - ) - def test_one_support_norm(self): - """Test KSupportNorm with k=1.""" - npt.assert_allclose( - self.one_support.op(self.data1.flatten()), - self.data2.flatten(), - err_msg='Incorect sparse threshold operation for 1-support norm', - rtol=1e-6, - ) - - npt.assert_equal( - self.one_support.cost(self.data1.flatten(), verbose=True), - 259.2, - err_msg='Incorect sparse threshold cost.', - ) +@pytest.mark.skipif(SKLEARN_AVAILABLE, reason="sklearn is installed") +def test_fail_sklearn(): + """Test fail OWL with sklearn.""" + npt.assert_raises(ImportError, proximity.OrderedWeightedL1Norm, 1) - npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0) - def test_five_support_norm(self): - """Test KSupportNorm with k=5.""" - npt.assert_allclose( - self.five_support_norm.op(self.data1.flatten()), - self.data11.flatten(), - err_msg='Incorect sparse Ksupport norm operation', - rtol=1e-6, - ) +@pytest.mark.skipif(not SKLEARN_AVAILABLE, reason="sklearn is not installed.") +def test_fail_owl(): + """Test errors for Ordered Weighted L1 Norm.""" + npt.assert_raises( + ValueError, + proximity.OrderedWeightedL1Norm, + np.arange(10), + ) - npt.assert_equal( - self.five_support_norm.cost(self.data1.flatten(), verbose=True), - 684.0, - err_msg='Incorrect 5-support norm cost.', - ) + npt.assert_raises( + ValueError, + proximity.OrderedWeightedL1Norm, + -np.arange(10), + ) - npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0) - def test_d_support_norm(self): - """Test KSupportNorm with k=19.""" - npt.assert_allclose( - self.d_support.op(self.data9.flatten()), - self.data10.flatten(), - err_msg='Incorect shrinkage operation for d-support norm', - rtol=1e-6, - ) +def test_fail_lowrank(): + """Test fail for lowrank.""" + prox_op = proximity.LowRankMatrix(10, lowr_type="fail") + npt.assert_raises(ValueError, prox_op.op, 0) - npt.assert_almost_equal( - self.d_support.cost(self.data9.flatten(), verbose=True), - 408.0 * 3.0, - err_msg='Incorrect shrinkage cost for d-support norm.', - ) - npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0) +def test_fail_Ksupport_norm(): + """Test fail for K-support norm.""" + npt.assert_raises(ValueError, proximity.KSupportNorm, 0, 0) - def test_group_lasso(self): - """Test GroupLASSO.""" - npt.assert_allclose( - self.group_lasso.op(self.random_data_tile), - self.gl_result_data, - ) - npt.assert_equal( - self.group_lasso.cost(self.random_data_tile), - np.sum(6 * self.random_data_tile), - ) - # Check that for 0 weights operator doesnt change result - self.group_lasso.weights = np.zeros_like(self.group_lasso.weights) - npt.assert_equal( - self.group_lasso.op(self.random_data_tile), - self.random_data_tile, - ) - npt.assert_equal(self.group_lasso.cost(self.random_data_tile), 0) +def test_reweight(): + """Test for reweight module.""" + data1 = np.arange(9).reshape(3, 3).astype(float) + 1 + data2 = np.array( + [[0.5, 1.0, 1.5], [2.0, 2.5, 3.0], [3.5, 4.0, 4.5]], + ) -class ReweightTestCase(TestCase): - """Test case for reweight module.""" + rw = reweight.cwbReweight(data1) + rw.reweight(data1) - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) + 1 - self.data2 = np.array( - [[0.5, 1.0, 1.5], [2.0, 2.5, 3.0], [3.5, 4.0, 4.5]], - ) - self.rw = reweight.cwbReweight(self.data1) - self.rw.reweight(self.data1) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.rw = None - - def test_cwbreweight(self): - """Test cwbReweight.""" - npt.assert_array_equal( - self.rw.weights, - self.data2, - err_msg='Incorrect CWB re-weighting.', - ) + npt.assert_array_equal( + rw.weights, + data2, + err_msg="Incorrect CWB re-weighting.", + ) - npt.assert_raises(ValueError, self.rw.reweight, self.data1[0]) + npt.assert_raises(ValueError, rw.reweight, data1[0]) diff --git a/modopt/tests/test_signal.py b/modopt/tests/test_signal.py index 7490b98c..202e541b 100644 --- a/modopt/tests/test_signal.py +++ b/modopt/tests/test_signal.py @@ -1,322 +1,240 @@ -# -*- coding: utf-8 -*- - """UNIT TESTS FOR SIGNAL. This module contains unit tests for the modopt.signal module. -:Author: Samuel Farrens - +:Authors: + Samuel Farrens + Pierre-Antoine Comby """ -from unittest import TestCase - import numpy as np import numpy.testing as npt +import pytest +from test_helpers import failparam from modopt.signal import filter, noise, positivity, svd, validation, wavelet -class FilterTestCase(TestCase): - """Test case for filter module.""" - - def test_guassian_filter(self): - """Test guassian_filter.""" - npt.assert_almost_equal( - filter.gaussian_filter(1, 1), - 0.24197072451914337, - err_msg='Incorrect Gaussian filter', - ) +class TestFilter: + """Test filter module""" + @pytest.mark.parametrize( + ("norm", "result"), [(True, 0.24197072451914337), (False, 0.60653065971263342)] + ) + def test_gaussian_filter(self, norm, result): + """Test gaussian filter.""" + npt.assert_almost_equal(filter.gaussian_filter(1, 1, norm=norm), result) - npt.assert_almost_equal( - filter.gaussian_filter(1, 1, norm=False), - 0.60653065971263342, - err_msg='Incorrect Gaussian filter', - ) def test_mex_hat(self): - """Test mex_hat.""" + """Test mexican hat filter.""" npt.assert_almost_equal( filter.mex_hat(2, 1), -0.35213905225713371, - err_msg='Incorrect Mexican hat filter', ) + def test_mex_hat_dir(self): - """Test mex_hat_dir.""" + """Test directional mexican hat filter.""" npt.assert_almost_equal( filter.mex_hat_dir(1, 2, 1), 0.17606952612856686, - err_msg='Incorrect directional Mexican hat filter', ) -class NoiseTestCase(TestCase): - """Test case for noise module.""" +class TestNoise: + """Test noise module.""" - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.data2 = np.array( - [[0, 2.0, 2.0], [4.0, 5.0, 10], [11.0, 15.0, 18.0]], - ) - self.data3 = np.array([ + data1 = np.arange(9).reshape(3, 3).astype(float) + data2 = np.array( + [[0, 2.0, 2.0], [4.0, 5.0, 10], [11.0, 15.0, 18.0]], + ) + data3 = np.array( + [ [1.62434536, 0.38824359, 1.47182825], [1.92703138, 4.86540763, 2.6984613], [7.74481176, 6.2387931, 8.3190391], - ]) - self.data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]]) - self.data5 = np.array( - [[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]], - ) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None - self.data4 = None - self.data5 = None - - def test_add_noise_poisson(self): - """Test add_noise with Poisson noise.""" - np.random.seed(1) - npt.assert_array_equal( - noise.add_noise(self.data1, noise_type='poisson'), - self.data2, - err_msg='Incorrect noise: Poisson', - ) - - npt.assert_raises( - ValueError, - noise.add_noise, - self.data1, - noise_type='bla', - ) - - npt.assert_raises(ValueError, noise.add_noise, self.data1, (1, 1)) - - def test_add_noise_gaussian(self): - """Test add_noise with Gaussian noise.""" - np.random.seed(1) - npt.assert_almost_equal( - noise.add_noise(self.data1), - self.data3, - err_msg='Incorrect noise: Gaussian', - ) - + ] + ) + data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]]) + data5 = np.array( + [[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]], + ) + + @pytest.mark.parametrize( + ("data", "noise_type", "sigma", "data_noise"), + [ + (data1, "poisson", 1, data2), + (data1, "gauss", 1, data3), + (data1, "gauss", (1, 1, 1), data3), + failparam(data1, "fail", 1, data1, raises=ValueError), + ], + ) + def test_add_noise(self, data, noise_type, sigma, data_noise): + """Test add_noise.""" np.random.seed(1) npt.assert_almost_equal( - noise.add_noise(self.data1, sigma=(1, 1, 1)), - self.data3, - err_msg='Incorrect noise: Gaussian', - ) - - def test_thresh_hard(self): - """Test thresh with hard threshold.""" - npt.assert_array_equal( - noise.thresh(self.data1, 5), - self.data4, - err_msg='Incorrect threshold: hard', - ) - - npt.assert_raises( - ValueError, - noise.thresh, - self.data1, - 5, - threshold_type='bla', + noise.add_noise(data, sigma=sigma, noise_type=noise_type), data_noise ) - def test_thresh_soft(self): - """Test thresh with soft threshold.""" + @pytest.mark.parametrize( + ("threshold_type", "result"), + [("hard", data4), ("soft", data5), failparam("fail", None, raises=ValueError)], + ) + def test_thresh(self, threshold_type, result): + """Test threshold.""" npt.assert_array_equal( - noise.thresh(self.data1, 5, threshold_type='soft'), - self.data5, - err_msg='Incorrect threshold: soft', + noise.thresh(self.data1, 5, threshold_type=threshold_type), result ) - -class PositivityTestCase(TestCase): - """Test case for positivity module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3) - 5 - self.data2 = np.array([[0, 0, 0], [0, 0, 0], [1, 2, 3]]) - self.data3 = np.array( - [np.arange(5) - 3, np.arange(4) - 2], - dtype=object, - ) - self.data4 = np.array( - [np.array([0, 0, 0, 0, 1]), np.array([0, 0, 0, 1])], +class TestPositivity: + """Test positivity module.""" + data1 = np.arange(9).reshape(3, 3).astype(float) + data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]]) + data5 = np.array( + [[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]], + ) + @pytest.mark.parametrize( + ("value", "expected"), + [ + (-1.0, -float(0)), + (-1, 0), + (data1 - 5, data5), + ( + np.array([np.arange(3) - 1, np.arange(2) - 1], dtype=object), + np.array([np.array([0, 0, 1]), np.array([0, 0])], dtype=object), + ), + failparam("-1", None, raises=TypeError), + ], + ) + def test_positive(self, value, expected): + """Test positive.""" + if isinstance(value, np.ndarray) and value.dtype == "O": + for v, e in zip(positivity.positive(value), expected): + npt.assert_array_equal(v, e) + else: + npt.assert_array_equal(positivity.positive(value), expected) + + +class TestSVD: + """Test for svd module.""" + + @pytest.fixture + def data(self): + """Initialize test data.""" + data1 = np.arange(18).reshape(9, 2).astype(float) + data2 = np.arange(32).reshape(16, 2).astype(float) + data3 = np.array( + [ + np.array( + [ + [-0.01744594, -0.61438865], + [-0.08435304, -0.50397984], + [-0.15126014, -0.39357102], + [-0.21816724, -0.28316221], + [-0.28507434, -0.17275339], + [-0.35198144, -0.06234457], + [-0.41888854, 0.04806424], + [-0.48579564, 0.15847306], + [-0.55270274, 0.26888188], + ] + ), + np.array([42.23492742, 1.10041151]), + np.array( + [ + [-0.67608034, -0.73682791], + [0.73682791, -0.67608034], + ] + ), + ], dtype=object, ) - self.pos_dtype_obj = positivity.positive(self.data3) - self.err = 'Incorrect positivity' - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - - def test_positivity(self): - """Test positivity.""" - npt.assert_equal(positivity.positive(-1), 0, err_msg=self.err) - - npt.assert_equal( - positivity.positive(-1.0), - -float(0), - err_msg=self.err, + data4 = np.array( + [ + [-1.05426832e-16, 1.0], + [2.0, 3.0], + [4.0, 5.0], + [6.0, 7.0], + [8.0, 9.0], + [1.0e1, 1.1e1], + [1.2e1, 1.3e1], + [1.4e1, 1.5e1], + [1.6e1, 1.7e1], + ] ) - npt.assert_equal( - positivity.positive(self.data1), - self.data2, - err_msg=self.err, + data5 = np.array( + [ + [0.49815487, 0.54291537], + [2.40863386, 2.62505584], + [4.31911286, 4.70719631], + [6.22959185, 6.78933678], + [8.14007085, 8.87147725], + [10.05054985, 10.95361772], + [11.96102884, 13.03575819], + [13.87150784, 15.11789866], + [15.78198684, 17.20003913], + ] ) + return (data1, data2, data3, data4, data5) - for expected, output in zip(self.data4, self.pos_dtype_obj): - print(expected, output) - npt.assert_array_equal(expected, output, err_msg=self.err) + @pytest.fixture + def svd0(self, data): + """Compute SVD of first data sample.""" + return svd.calculate_svd(data[0]) - npt.assert_raises(TypeError, positivity.positive, '-1') - - -class SVDTestCase(TestCase): - """Test case for svd module.""" - - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(18).reshape(9, 2).astype(float) - self.data2 = np.arange(32).reshape(16, 2).astype(float) - self.data3 = np.array( - [ - np.array([ - [-0.01744594, -0.61438865], - [-0.08435304, -0.50397984], - [-0.15126014, -0.39357102], - [-0.21816724, -0.28316221], - [-0.28507434, -0.17275339], - [-0.35198144, -0.06234457], - [-0.41888854, 0.04806424], - [-0.48579564, 0.15847306], - [-0.55270274, 0.26888188], - ]), - np.array([42.23492742, 1.10041151]), - np.array([ - [-0.67608034, -0.73682791], - [0.73682791, -0.67608034], - ]), - ], - dtype=object, - ) - self.data4 = np.array([ - [-1.05426832e-16, 1.0], - [2.0, 3.0], - [4.0, 5.0], - [6.0, 7.0], - [8.0, 9.0], - [1.0e1, 1.1e1], - [1.2e1, 1.3e1], - [1.4e1, 1.5e1], - [1.6e1, 1.7e1], - ]) - self.data5 = np.array([ - [0.49815487, 0.54291537], - [2.40863386, 2.62505584], - [4.31911286, 4.70719631], - [6.22959185, 6.78933678], - [8.14007085, 8.87147725], - [10.05054985, 10.95361772], - [11.96102884, 13.03575819], - [13.87150784, 15.11789866], - [15.78198684, 17.20003913], - ]) - self.svd = svd.calculate_svd(self.data1) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None - self.data4 = None - self.svd = None - - def test_find_n_pc(self): - """Test find_n_pc.""" + def test_find_n_pc(self, data): + """Test find number of principal component.""" npt.assert_equal( - svd.find_n_pc(svd.svd(self.data2)[0]), + svd.find_n_pc(svd.svd(data[1])[0]), 2, - err_msg='Incorrect number of principal components.', + err_msg="Incorrect number of principal components.", ) + def test_n_pc_fail_non_square(self): + """Test find_n_pc.""" npt.assert_raises(ValueError, svd.find_n_pc, np.arange(3)) - def test_calculate_svd(self): + def test_calculate_svd(self, data, svd0): """Test calculate_svd.""" + errors = [] + for i, name in enumerate("USV"): + try: + npt.assert_almost_equal(svd0[i], data[2][i]) + except AssertionError: + errors.append(name) + if errors: + raise AssertionError("Incorrect SVD calculation for: " + ", ".join(errors)) + + @pytest.mark.parametrize( + ("n_pc", "idx_res"), + [(None, 3), (1, 4), ("all", 0), failparam("fail", 1, raises=ValueError)], + ) + def test_svd_thresh(self, data, n_pc, idx_res): + """Test svd_tresh.""" npt.assert_almost_equal( - self.svd[0], - np.array(self.data3)[0], - err_msg='Incorrect SVD calculation: U', - ) - - npt.assert_almost_equal( - self.svd[1], - np.array(self.data3)[1], - err_msg='Incorrect SVD calculation: S', - ) - - npt.assert_almost_equal( - self.svd[2], - np.array(self.data3)[2], - err_msg='Incorrect SVD calculation: V', - ) - - def test_svd_thresh(self): - """Test svd_thresh.""" - npt.assert_almost_equal( - svd.svd_thresh(self.data1), - self.data4, - err_msg='Incorrect SVD tresholding', - ) - - npt.assert_almost_equal( - svd.svd_thresh(self.data1, n_pc=1), - self.data5, - err_msg='Incorrect SVD tresholding', - ) - - npt.assert_almost_equal( - svd.svd_thresh(self.data1, n_pc='all'), - self.data1, - err_msg='Incorrect SVD tresholding', + svd.svd_thresh(data[0], n_pc=n_pc), + data[idx_res], ) + def test_svd_tresh_invalid_type(self): + """Test svd_tresh failure.""" npt.assert_raises(TypeError, svd.svd_thresh, 1) - npt.assert_raises(ValueError, svd.svd_thresh, self.data1, n_pc='bla') - - def test_svd_thresh_coef(self): - """Test svd_thresh_coef.""" + @pytest.mark.parametrize("operator", [lambda x: x, failparam(0, raises=TypeError)]) + def test_svd_thresh_coef(self, data, operator): + """Test svd_tresh_coef.""" npt.assert_almost_equal( - svd.svd_thresh_coef(self.data1, lambda x_val: x_val, 0), - self.data1, - err_msg='Incorrect SVD coefficient tresholding', + svd.svd_thresh_coef(data[0], operator, 0), + data[0], + err_msg="Incorrect SVD coefficient tresholding", ) - npt.assert_raises(TypeError, svd.svd_thresh_coef, self.data1, 0, 0) - + # TODO test_svd_thresh_coef_fast -class ValidationTestCase(TestCase): - """Test case for validation module.""" +class TestValidation: + """Test validation Module.""" - def setUp(self): - """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None + array33 = np.arange(9).reshape(3, 3) def test_transpose_test(self): """Test transpose_test.""" @@ -325,90 +243,81 @@ def test_transpose_test(self): validation.transpose_test( lambda x_val, y_val: x_val.dot(y_val), lambda x_val, y_val: x_val.dot(y_val.T), - self.data1.shape, - x_args=self.data1, + self.array33.shape, + x_args=self.array33, ), None, ) - npt.assert_raises( - TypeError, - validation.transpose_test, - 0, - 0, - self.data1.shape, - x_args=self.data1, - ) - -class WaveletTestCase(TestCase): - """Test case for wavelet module.""" +class TestWavelet: + """Test Wavelet Module.""" - def setUp(self): + @pytest.fixture + def data(self): """Set test parameter values.""" - self.data1 = np.arange(9).reshape(3, 3).astype(float) - self.data2 = np.arange(36).reshape(4, 3, 3).astype(float) - self.data3 = np.array([ - [ - [6.0, 20, 26.0], - [36.0, 84.0, 84.0], - [90, 164.0, 134.0], - ], + data1 = np.arange(9).reshape(3, 3).astype(float) + data2 = np.arange(36).reshape(4, 3, 3).astype(float) + data3 = np.array( [ - [78.0, 155.0, 134.0], - [225.0, 408.0, 327.0], - [270, 461.0, 350], - ], + [ + [6.0, 20, 26.0], + [36.0, 84.0, 84.0], + [90, 164.0, 134.0], + ], + [ + [78.0, 155.0, 134.0], + [225.0, 408.0, 327.0], + [270, 461.0, 350], + ], + [ + [150, 290, 242.0], + [414.0, 732.0, 570], + [450, 758.0, 566.0], + ], + [ + [222.0, 425.0, 350], + [603.0, 1056.0, 813.0], + [630, 1055.0, 782.0], + ], + ] + ) + + data4 = np.array( [ - [150, 290, 242.0], - [414.0, 732.0, 570], - [450, 758.0, 566.0], - ], + [6496.0, 9796.0, 6544.0], + [9924.0, 14910, 9924.0], + [6544.0, 9796.0, 6496.0], + ] + ) + + data5 = np.array( [ - [222.0, 425.0, 350], - [603.0, 1056.0, 813.0], - [630, 1055.0, 782.0], - ], - ]) - - self.data4 = np.array([ - [6496.0, 9796.0, 6544.0], - [9924.0, 14910, 9924.0], - [6544.0, 9796.0, 6496.0], - ]) - - self.data5 = np.array([ - [[0, 1.0, 4.0], [3.0, 10, 13.0], [6.0, 19.0, 22.0]], - [[3.0, 10, 13.0], [24.0, 46.0, 40], [45.0, 82.0, 67.0]], - [[6.0, 19.0, 22.0], [45.0, 82.0, 67.0], [84.0, 145.0, 112.0]], - ]) - - def tearDown(self): - """Unset test parameter values.""" - self.data1 = None - self.data2 = None - self.data3 = None - self.data4 = None - self.data5 = None - - def test_filter_convolve(self): - """Test filter_convolve.""" - npt.assert_almost_equal( - wavelet.filter_convolve(self.data1, self.data2), - self.data3, - err_msg='Inccorect filter comvolution.', + [[0, 1.0, 4.0], [3.0, 10, 13.0], [6.0, 19.0, 22.0]], + [[3.0, 10, 13.0], [24.0, 46.0, 40], [45.0, 82.0, 67.0]], + [[6.0, 19.0, 22.0], [45.0, 82.0, 67.0], [84.0, 145.0, 112.0]], + ] ) + return (data1, data2, data3, data4, data5) + @pytest.mark.parametrize( + ("idx_data", "idx_filter", "idx_res", "filter_rot"), + [(0, 1, 2, False), (1, 1, 3, True)], + ) + def test_filter_convolve(self, data, idx_data, idx_filter, idx_res, filter_rot): + """Test filter_convolve.""" npt.assert_almost_equal( - wavelet.filter_convolve(self.data2, self.data2, filter_rot=True), - self.data4, - err_msg='Inccorect filter comvolution.', + wavelet.filter_convolve( + data[idx_data], data[idx_filter], filter_rot=filter_rot + ), + data[idx_res], + err_msg="Inccorect filter comvolution.", ) - def test_filter_convolve_stack(self): + def test_filter_convolve_stack(self, data): """Test filter_convolve_stack.""" npt.assert_almost_equal( - wavelet.filter_convolve_stack(self.data1, self.data1), - self.data5, - err_msg='Inccorect filter stack comvolution.', + wavelet.filter_convolve_stack(data[0], data[0]), + data[4], + err_msg="Inccorect filter stack comvolution.", ) diff --git a/requirements.txt b/requirements.txt index 63a404ba..1f44de13 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ importlib_metadata>=3.7.0 numpy>=1.19.5 scipy>=1.5.4 -progressbar2>=3.53.1 +tqdm>=4.64.0 diff --git a/setup.cfg b/setup.cfg index cabd35a0..100adb40 100644 --- a/setup.cfg +++ b/setup.cfg @@ -42,6 +42,8 @@ per-file-ignores = modopt/opt/algorithms/__init__.py: F401,F403,WPS318, WPS319, WPS412, WPS410 #Todo: x is a too short name. modopt/opt/algorithms/forward_backward.py: WPS111 + #Todo: u,v , A is a too short name. + modopt/opt/algorithms/admm.py: WPS111, N803 #Todo: Check need for del statement modopt/opt/algorithms/primal_dual.py: WPS111, WPS420 #multiline parameters bug with tuples @@ -79,13 +81,17 @@ max-string-usages = 20 max-raises = 5 [tool:pytest] +norecursedirs=tests/test_helpers testpaths = modopt addopts = --verbose - --emoji - --flake8 --cov=modopt - --cov-report=term + --cov-report=term-missing --cov-report=xml --junitxml=pytest.xml + --pydocstyle + +[pydocstyle] +convention=numpy +add-ignore=D107 diff --git a/setup.py b/setup.py index c93dd020..e6a8a9e6 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ # Set the package release version major = 1 -minor = 6 +minor = 7 patch = 1 # Set the package details @@ -20,7 +20,7 @@ license = 'MIT' # Set the package classifiers -python_versions_supported = ['3.6', '3.7', '3.8', '3.9'] +python_versions_supported = ['3.7', '3.8', '3.9', '3.10', '3.11'] os_platforms_supported = ['Unix', 'MacOS'] lc_str = 'License :: OSI Approved :: {0} License'