diff --git a/.github/workflows/cd-build.yml b/.github/workflows/cd-build.yml index fca9feb1..21b6cc3a 100644 --- a/.github/workflows/cd-build.yml +++ b/.github/workflows/cd-build.yml @@ -14,32 +14,28 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 - - name: Set up Conda with Python 3.8 - uses: conda-incubator/setup-miniconda@v2 + - uses: actions/setup-python@v4 with: - auto-update-conda: true - python-version: 3.8 - auto-activate-base: false + python-version: "3.10" + cache: pip - name: Install dependencies shell: bash -l {0} run: | python -m pip install --upgrade pip - python -m pip install -r develop.txt python -m pip install twine - python -m pip install . + python -m pip install .[doc,test] - name: Run Tests shell: bash -l {0} run: | - python setup.py test + pytest - name: Check distribution shell: bash -l {0} run: | - python setup.py sdist twine check dist/* - name: Upload coverage to Codecov @@ -57,20 +53,15 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v2 + uses: actions/checkout@v4 - - name: Set up Conda with Python 3.8 - uses: conda-incubator/setup-miniconda@v2 - with: - python-version: "3.8" - name: Install dependencies shell: bash -l {0} run: | conda install -c conda-forge pandoc python -m pip install --upgrade pip - python -m pip install -r docs/requirements.txt - python -m pip install . + python -m pip install .[doc] - name: Build API documentation shell: bash -l {0} diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml index c4ba28a0..3a209d12 100644 --- a/.github/workflows/ci-build.yml +++ b/.github/workflows/ci-build.yml @@ -16,61 +16,41 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest] - python-version: ["3.10"] + python-version: ["3.8", "3.9", "3.10"] steps: - - name: Checkout - uses: actions/checkout@v2 - - - name: Set up Conda with Python ${{ matrix.python-version }} - uses: conda-incubator/setup-miniconda@v2 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v4 with: - auto-update-conda: true python-version: ${{ matrix.python-version }} - auto-activate-base: false - - - name: Check Conda - shell: bash -l {0} - run: | - conda info - conda list - python --version + cache: pip - name: Install Dependencies shell: bash -l {0} run: | python --version python -m pip install --upgrade pip - python -m pip install -r develop.txt - python -m pip install -r docs/requirements.txt - python -m pip install astropy "scikit-image<0.20" scikit-learn matplotlib - python -m pip install tensorflow>=2.4.1 - python -m pip install twine - python -m pip install . + python -m pip install .[test] + python -m pip install astropy scikit-image scikit-learn matplotlib + python -m pip install tensorflow>=2.4.1 torch - name: Run Tests shell: bash -l {0} run: | - export PATH=/usr/share/miniconda/bin:$PATH pytest -n 2 - name: Save Test Results if: always() - uses: actions/upload-artifact@v2 + uses: actions/upload-artifact@v4 with: name: unit-test-results-${{ matrix.os }}-${{ matrix.python-version }} - path: pytest.xml - - - name: Check Distribution - shell: bash -l {0} - run: | - python setup.py sdist - twine check dist/* + path: coverage.xml - name: Check API Documentation build shell: bash -l {0} run: | - conda install -c conda-forge pandoc + apt install pandoc + pip install .[doc] ipykernel sphinx-apidoc -t docs/_templates -feTMo docs/source modopt sphinx-build -b doctest -E docs/source docs/_build @@ -81,38 +61,3 @@ jobs: file: coverage.xml flags: unittests - test-basic: - name: Basic Test Suite - runs-on: ${{ matrix.os }} - - strategy: - fail-fast: false - matrix: - os: [ubuntu-latest, macos-latest] - python-version: ["3.7", "3.8", "3.9"] - - steps: - - name: Checkout - uses: actions/checkout@v2 - - - name: Set up Conda with Python ${{ matrix.python-version }} - uses: conda-incubator/setup-miniconda@v2 - with: - auto-update-conda: true - python-version: ${{ matrix.python-version }} - auto-activate-base: false - - - name: Install Dependencies - shell: bash -l {0} - run: | - python --version - python -m pip install --upgrade pip - python -m pip install -r develop.txt - python -m pip install astropy "scikit-image<0.20" scikit-learn matplotlib - python -m pip install . - - - name: Run Tests - shell: bash -l {0} - run: | - export PATH=/usr/share/miniconda/bin:$PATH - pytest -n 2 diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml new file mode 100644 index 00000000..45fc2b23 --- /dev/null +++ b/.github/workflows/style.yml @@ -0,0 +1,38 @@ +name: Style checking + +on: + push: + branches: [ "master", "main", "develop" ] + pull_request: + branches: [ "master", "main", "develop" ] + + workflow_dispatch: + +env: + PYTHON_VERSION: "3.10" + +jobs: + linter-check: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Set up Python ${{ env.PYTHON_VERSION }} + uses: actions/setup-python@v4 + with: + python-version: ${{ env.PYTHON_VERSION }} + cache: pip + + - name: Install Python deps + shell: bash + run: | + python -m pip install --upgrade pip + python -m pip install -e .[test,dev] + + - name: Black Check + shell: bash + run: black . --diff --color --check + + - name: ruff Check + shell: bash + run: ruff check diff --git a/.pylintrc b/.pylintrc deleted file mode 100644 index 3ac9aef9..00000000 --- a/.pylintrc +++ /dev/null @@ -1,2 +0,0 @@ -[MASTER] -ignore-patterns=**/docs/**/*.py diff --git a/.pyup.yml b/.pyup.yml deleted file mode 100644 index 8fdac7ff..00000000 --- a/.pyup.yml +++ /dev/null @@ -1,14 +0,0 @@ -# autogenerated pyup.io config file -# see https://pyup.io/docs/configuration/ for all available options - -schedule: '' -update: all -label_prs: update -assignees: sfarrens -requirements: - - requirements.txt: - pin: False - - develop.txt: - pin: False - - docs/requirements.txt: - pin: True diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 9a2f374e..00000000 --- a/MANIFEST.in +++ /dev/null @@ -1,5 +0,0 @@ -include requirements.txt -include develop.txt -include docs/requirements.txt -include README.rst -include LICENSE.txt diff --git a/develop.txt b/develop.txt deleted file mode 100644 index 6ff665eb..00000000 --- a/develop.txt +++ /dev/null @@ -1,12 +0,0 @@ -coverage>=5.5 -pytest>=6.2.2 -pytest-raises>=0.10 -pytest-cases>= 3.6 -pytest-xdist>= 3.0.1 -pytest-cov>=2.11.1 -pytest-emoji>=0.2.0 -pydocstyle==6.1.1 -pytest-pydocstyle>=2.2.0 -black -isort -pytest-black diff --git a/docs/source/conf.py b/docs/source/conf.py index 46564b9f..69921008 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Python Template sphinx config # Import relevant modules @@ -9,56 +8,53 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. -sys.path.insert(0, os.path.abspath('../..')) +sys.path.insert(0, os.path.abspath("../..")) # -- General configuration ------------------------------------------------ # General information about the project. -project = 'modopt' +project = "modopt" mdata = metadata(project) -author = mdata['Author'] -version = mdata['Version'] -copyright = '2020, {}'.format(author) -gh_user = 'sfarrens' - -# If your documentation needs a minimal Sphinx version, state it here. -needs_sphinx = '3.3' +author = "Samuel Farrens, Pierre-Antoine Comby, Chaithya GR, Philippe Ciuciu" +version = mdata["Version"] +copyright = f"2020, {author}" +gh_user = "sfarrens" # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.coverage', - 'sphinx.ext.doctest', - 'sphinx.ext.ifconfig', - 'sphinx.ext.intersphinx', - 'sphinx.ext.mathjax', - 'sphinx.ext.napoleon', - 'sphinx.ext.todo', - 'sphinx.ext.viewcode', - 'sphinxawesome_theme', - 'sphinxcontrib.bibtex', - 'myst_parser', - 'nbsphinx', - 'nbsphinx_link', - 'numpydoc', - "sphinx_gallery.gen_gallery" + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.coverage", + "sphinx.ext.doctest", + "sphinx.ext.ifconfig", + "sphinx.ext.intersphinx", + "sphinx.ext.mathjax", + "sphinx.ext.napoleon", + "sphinx.ext.todo", + "sphinx.ext.viewcode", + "sphinxawesome_theme.highlighting", + "sphinxcontrib.bibtex", + "myst_parser", + "nbsphinx", + "nbsphinx_link", + "numpydoc", + "sphinx_gallery.gen_gallery", ] # Include module names for objects add_module_names = False # Set class documentation standard. -autoclass_content = 'class' +autoclass_content = "class" # Audodoc options autodoc_default_options = { - 'member-order': 'bysource', - 'private-members': True, - 'show-inheritance': True + "member-order": "bysource", + "private-members": True, + "show-inheritance": True, } # Generate summaries @@ -69,17 +65,17 @@ # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: -source_suffix = ['.rst', '.md'] +source_suffix = [".rst", ".md"] # The master toctree document. -master_doc = 'index' +master_doc = "index" # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. show_authors = True # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'default' +pygments_style = "default" # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = True @@ -88,7 +84,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'sphinxawesome_theme' +html_theme = "sphinxawesome_theme" # html_theme = 'sphinx_book_theme' # Theme options are theme-specific and customize the look and feel of a theme @@ -101,11 +97,10 @@ "breadcrumbs_separator": "/", "show_prev_next": True, "show_scrolltop": True, - } html_collapsible_definitions = True html_awesome_headerlinks = True -html_logo = 'modopt_logo.jpg' +html_logo = "modopt_logo.png" html_permalinks_icon = ( '' @@ -118,7 +113,7 @@ ) # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -html_title = '{0} v{1}'.format(project, version) +html_title = f"{project} v{version}" # A shorter title for the navigation bar. Default is the same as html_title. # html_short_title = None @@ -134,7 +129,7 @@ # If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # using the given strftime format. -html_last_updated_fmt = '%d %b, %Y' +html_last_updated_fmt = "%d %b, %Y" # If true, SmartyPants will be used to convert quotes and dashes to # typographically correct entities. @@ -147,27 +142,25 @@ html_show_copyright = True - # -- Options for Sphinx Gallery ---------------------------------------------- sphinx_gallery_conf = { - "examples_dirs": ["../../modopt/examples/"], + "examples_dirs": ["../../examples/"], "filename_pattern": "/example_", "ignore_pattern": r"/(__init__|conftest)\.py", } - # -- Options for nbshpinx output ------------------------------------------ # Custom fucntion to find notebooks, create .nblink files and update the # notebooks.rst file -def add_notebooks(nb_path='../../notebooks'): +def add_notebooks(nb_path="../../notebooks"): - print('Looking for notebooks') - nb_ext = '.ipynb' - nb_rst_file_name = 'notebooks.rst' + print("Looking for notebooks") + nb_ext = ".ipynb" + nb_rst_file_name = "notebooks.rst" nb_link_format = '{{\n "path": "{0}/{1}"\n}}' nbs = sorted([nb for nb in os.listdir(nb_path) if nb.endswith(nb_ext)]) @@ -176,21 +169,21 @@ def add_notebooks(nb_path='../../notebooks'): nb_name = nb.rstrip(nb_ext) - nb_link_file_name = nb_name + '.nblink' - print('Writing {0}'.format(nb_link_file_name)) - with open(nb_link_file_name, 'w') as nb_link_file: + nb_link_file_name = nb_name + ".nblink" + print(f"Writing {nb_link_file_name}") + with open(nb_link_file_name, "w") as nb_link_file: nb_link_file.write(nb_link_format.format(nb_path, nb)) - print('Looking for {0} in {1}'.format(nb_name, nb_rst_file_name)) - with open(nb_rst_file_name, 'r') as nb_rst_file: + print(f"Looking for {nb_name} in {nb_rst_file_name}") + with open(nb_rst_file_name) as nb_rst_file: check_name = nb_name not in nb_rst_file.read() if check_name: - print('Adding {0} to {1}'.format(nb_name, nb_rst_file_name)) - with open(nb_rst_file_name, 'a') as nb_rst_file: + print(f"Adding {nb_name} to {nb_rst_file_name}") + with open(nb_rst_file_name, "a") as nb_rst_file: if list_pos == 0: - nb_rst_file.write('\n') - nb_rst_file.write(' {0}\n'.format(nb_name)) + nb_rst_file.write("\n") + nb_rst_file.write(f" {nb_name}\n") return nbs @@ -198,13 +191,13 @@ def add_notebooks(nb_path='../../notebooks'): # Add notebooks add_notebooks() -binder = 'https://mybinder.org/v2/gh' -binder_badge = 'https://mybinder.org/badge_logo.svg' -github = 'https://github.com/' -github_badge = 'https://badgen.net/badge/icon/github?icon=github&label' +binder = "https://mybinder.org/v2/gh" +binder_badge = "https://mybinder.org/badge_logo.svg" +github = "https://github.com/" +github_badge = "https://badgen.net/badge/icon/github?icon=github&label" # Remove promts and add binder badge -nb_header_pt1 = r''' +nb_header_pt1 = r""" {% if env.metadata[env.docname]['nbsphinx-link-target'] %} {% set docpath = env.metadata[env.docname]['nbsphinx-link-target'] %} {% else %} @@ -220,18 +213,18 @@ def add_notebooks(nb_path='../../notebooks'): } -''' +""" nb_header_pt2 = ( - r'''

''' - r'''''' + - r'''Binder badge
''' - r'''
GitHub badge'''.format(github_badge) + - r'''

''' + r"""

""" + rf"""""" + + rf"""Binder badge
""" + r"""
GitHub badge""" + + r"""

""" ) nbsphinx_prolog = nb_header_pt1 + nb_header_pt2 @@ -240,28 +233,28 @@ def add_notebooks(nb_path='../../notebooks'): # Refer to the package libraries for type definitions intersphinx_mapping = { - 'python': ('http://docs.python.org/3', None), - 'numpy': ('https://numpy.org/doc/stable/', None), - 'scipy': ('https://docs.scipy.org/doc/scipy/reference', None), - 'progressbar': ('https://progressbar-2.readthedocs.io/en/latest/', None), - 'matplotlib': ('https://matplotlib.org', None), - 'astropy': ('http://docs.astropy.org/en/latest/', None), - 'cupy': ('https://docs-cupy.chainer.org/en/stable/', None), - 'torch': ('https://pytorch.org/docs/stable/', None), - 'sklearn': ( - 'http://scikit-learn.org/stable', - (None, './_intersphinx/sklearn-objects.inv') + "python": ("http://docs.python.org/3", None), + "numpy": ("https://numpy.org/doc/stable/", None), + "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), + "progressbar": ("https://progressbar-2.readthedocs.io/en/latest/", None), + "matplotlib": ("https://matplotlib.org", None), + "astropy": ("http://docs.astropy.org/en/latest/", None), + "cupy": ("https://docs-cupy.chainer.org/en/stable/", None), + "torch": ("https://pytorch.org/docs/stable/", None), + "sklearn": ( + "http://scikit-learn.org/stable", + (None, "./_intersphinx/sklearn-objects.inv"), ), - 'tensorflow': ( - 'https://www.tensorflow.org/api_docs/python', + "tensorflow": ( + "https://www.tensorflow.org/api_docs/python", ( - 'https://github.com/GPflow/tensorflow-intersphinx/' - + 'raw/master/tf2_py_objects.inv') - ) - + "https://github.com/GPflow/tensorflow-intersphinx/" + + "raw/master/tf2_py_objects.inv" + ), + ), } # -- BibTeX Setting ---------------------------------------------- -bibtex_bibfiles = ['refs.bib', 'my_ref.bib'] -bibtex_default_style = 'alpha' +bibtex_bibfiles = ["refs.bib", "my_ref.bib"] +bibtex_default_style = "alpha" diff --git a/modopt/examples/README.rst b/examples/README.rst similarity index 100% rename from modopt/examples/README.rst rename to examples/README.rst diff --git a/modopt/examples/__init__.py b/examples/__init__.py similarity index 100% rename from modopt/examples/__init__.py rename to examples/__init__.py diff --git a/modopt/examples/conftest.py b/examples/conftest.py similarity index 95% rename from modopt/examples/conftest.py rename to examples/conftest.py index 73358679..f3ed371b 100644 --- a/modopt/examples/conftest.py +++ b/examples/conftest.py @@ -11,10 +11,12 @@ https://stackoverflow.com/questions/56807698/how-to-run-script-as-pytest-test """ + from pathlib import Path import runpy import pytest + def pytest_collect_file(path, parent): """Pytest hook. @@ -22,7 +24,7 @@ def pytest_collect_file(path, parent): The new node needs to have the specified parent as parent. """ p = Path(path) - if p.suffix == '.py' and 'example' in p.name: + if p.suffix == ".py" and "example" in p.name: return Script.from_parent(parent, path=p, name=p.name) @@ -33,6 +35,7 @@ def collect(self): """Collect the script as its own item.""" yield ScriptItem.from_parent(self, name=self.name) + class ScriptItem(pytest.Item): """Item script collected by pytest.""" diff --git a/modopt/examples/example_lasso_forward_backward.py b/examples/example_lasso_forward_backward.py similarity index 95% rename from modopt/examples/example_lasso_forward_backward.py rename to examples/example_lasso_forward_backward.py index 7f820000..f3e5091d 100644 --- a/modopt/examples/example_lasso_forward_backward.py +++ b/examples/example_lasso_forward_backward.py @@ -1,4 +1,3 @@ -# noqa: D205 """ Solving the LASSO Problem with the Forward Backward Algorithm. ============================================================== @@ -76,7 +75,7 @@ prox=prox_op, cost=cost_op_fista, metric_call_period=1, - auto_iterate=False, # Just to give us the pleasure of doing things by ourself. + auto_iterate=False, # Just to give us the pleasure of doing things by ourself. ) fb_fista.iterate() @@ -115,7 +114,7 @@ prox=prox_op, cost=cost_op_pogm, metric_call_period=1, - auto_iterate=False, # Just to give us the pleasure of doing things by ourself. + auto_iterate=False, # Just to give us the pleasure of doing things by ourself. ) fb_pogm.iterate() @@ -133,6 +132,7 @@ # # sphinx_gallery_start_ignore assert mse(fb_pogm.x_final, BETA_TRUE) < 1 +# sphinx_gallery_end_ignore # %% # Comparing the Two algorithms diff --git a/modopt/__init__.py b/modopt/__init__.py deleted file mode 100644 index 2c06c1db..00000000 --- a/modopt/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -# -*- coding: utf-8 -*- - -"""MODOPT PACKAGE. - -ModOpt is a series of Modular Optimisation tools for solving inverse problems. - -""" - -from warnings import warn - -from importlib_metadata import version - -from modopt.base import * - -try: - _version = version('modopt') -except Exception: # pragma: no cover - _version = 'Unkown' - warn( - 'Could not extract package metadata. Make sure the package is ' - + 'correctly installed.', - ) - -__version__ = _version diff --git a/modopt/tests/test_helpers/__init__.py b/modopt/tests/test_helpers/__init__.py deleted file mode 100644 index 3886b877..00000000 --- a/modopt/tests/test_helpers/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .utils import failparam, skipparam, Dummy diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..721c8b37 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,58 @@ +[project] +name="modopt" +description = 'Modular Optimisation tools for soliving inverse problems.' +version = "1.7.1" +requires-python= ">=3.8" + +authors = [{name="Samuel Farrens", email="samuel.farrens@cea.fr"}, +{name="Chaithya GR", email="chaithyagr@gmail.com"}, +{name="Pierre-Antoine Comby", email="pierre-antoine.comby@cea.fr"}, +{name="Philippe Ciuciu", email="philippe.ciuciu@cea.fr"} +] +readme="README.md" +license={file="LICENCE.txt"} + +dependencies = ["numpy", "scipy", "tqdm", "importlib_metadata"] + +[project.optional-dependencies] +gpu=["torch", "ptwt"] +doc=["myst-parser", +"nbsphinx", +"nbsphinx-link", +"sphinx-gallery", +"numpydoc", +"sphinxawesome-theme", +"sphinxcontrib-bibtex"] +dev=["black", "ruff"] +test=["pytest<8.0.0", "pytest-cases", "pytest-cov", "pytest-xdist", "pytest-sugar"] + +[build-system] +requires=["setuptools", "setuptools-scm[toml]", "wheel"] + +[tool.coverage.run] +omit = ["*tests*", "*__init__*", "*setup.py*", "*_version.py*", "*example*"] + +[tool.coverage.report] +precision = 2 +exclude_lines = ["pragma: no cover", "raise NotImplementedError"] + +[tool.black] + + +[tool.ruff] +exclude = ["examples", "docs"] +[tool.ruff.lint] +select = ["E", "F", "B", "Q", "UP", "D", "NPY", "RUF"] + +ignore = ["F401"] # we like the try: import ... expect: ... + +[tool.ruff.lint.pydocstyle] +convention="numpy" + +[tool.isort] +profile="black" + +[tool.pytest.ini_options] +minversion = "6.0" +norecursedirs = ["tests/test_helpers"] +addopts = ["--cov=modopt", "--cov-report=term-missing", "--cov-report=xml"] diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 1f44de13..00000000 --- a/requirements.txt +++ /dev/null @@ -1,4 +0,0 @@ -importlib_metadata>=3.7.0 -numpy>=1.19.5 -scipy>=1.5.4 -tqdm>=4.64.0 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 100adb40..00000000 --- a/setup.cfg +++ /dev/null @@ -1,97 +0,0 @@ -[aliases] -test=pytest - -[metadata] -description_file = README.rst - -[darglint] -docstring_style = numpy -strictness = short - -[flake8] -ignore = - D107, #Justification: Don't need docstring for __init__ in numpydoc style - RST304, #Justification: Need to use :cite: role for citations - RST210, #Justification: RST210, RST213 Inconsistent with numpydoc - RST213, # documentation for handling *args and **kwargs - W503, #Justification: Have to choose one multiline operator format - WPS202, #Todo: Rethink module size, possibly split large modules - WPS337, #Todo: Consider simplifying multiline conditions. - WPS338, #Todo: Consider changing method order - WPS403, #Todo: Rethink no cover lines - WPS421, #Todo: Review need for print statements - WPS432, #Justification: Mathematical codes require "magic numbers" - WPS433, #Todo: Rethink conditional imports - WPS463, #Todo: Rename get_ methods - WPS615, #Todo: Rename get_ methods -per-file-ignores = - #Justification: Needed for keeping package version and current API - *__init__.py*: F401,F403,WPS347,WPS410,WPS412 - #Todo: Rethink conditional imports - #Todo: How can we bypass mutable constants? - modopt/base/backend.py: WPS229, WPS420, WPS407 - #Todo: Rethink conditional imports - modopt/base/observable.py: WPS420,WPS604 - #Todo: Check string for log formatting - modopt/interface/log.py: WPS323 - #Todo: Rethink conditional imports - modopt/math/convolve.py: WPS301,WPS420 - #Todo: Rethink conditional imports - modopt/math/matrix.py: WPS420 - #Todo: import has bad parenthesis - modopt/opt/algorithms/__init__.py: F401,F403,WPS318, WPS319, WPS412, WPS410 - #Todo: x is a too short name. - modopt/opt/algorithms/forward_backward.py: WPS111 - #Todo: u,v , A is a too short name. - modopt/opt/algorithms/admm.py: WPS111, N803 - #Todo: Check need for del statement - modopt/opt/algorithms/primal_dual.py: WPS111, WPS420 - #multiline parameters bug with tuples - modopt/opt/algorithms/gradient_descent.py: WPS111, WPS420, WPS317 - #Todo: Consider changing costObj name - modopt/opt/cost.py: N801, - #Todo: - # - Rethink subscript slice assignment - # - Reduce complexity of KSupportNorm - # - Check bitwise operations - modopt/opt/proximity.py: WPS220,WPS231,WPS352,WPS362,WPS465,WPS506,WPS508 - #Todo: Consider changing cwbReweight name - modopt/opt/reweight.py: N801 - #Justification: Needed to import matplotlib.pyplot - modopt/plot/cost_plot.py: N802,WPS301 - #Todo: Investigate possible bug in find_n_pc function - #Todo: Investigate darglint error - modopt/signal/svd.py: WPS345, DAR000 - #Todo: Check security of using system executable call - modopt/signal/wavelet.py: S404,S603 - #Todo: Clean up tests - modopt/tests/*.py: E731,F401,WPS301,WPS420,WPS425,WPS437,WPS604 - #Todo: Import has bad parenthesis - modopt/tests/test_base.py: WPS318,WPS319,E501,WPS301 -#WPS Settings -max-arguments = 25 -max-attributes = 40 -max-cognitive-score = 20 -max-function-expressions = 20 -max-line-complexity = 30 -max-local-variables = 10 -max-methods = 20 -max-module-expressions = 20 -max-string-usages = 20 -max-raises = 5 - -[tool:pytest] -norecursedirs=tests/test_helpers -testpaths = - modopt -addopts = - --verbose - --cov=modopt - --cov-report=term-missing - --cov-report=xml - --junitxml=pytest.xml - --pydocstyle - -[pydocstyle] -convention=numpy -add-ignore=D107 diff --git a/setup.py b/setup.py deleted file mode 100644 index e6a8a9e6..00000000 --- a/setup.py +++ /dev/null @@ -1,73 +0,0 @@ -#! /usr/bin/env python -# -*- coding: utf-8 -*- - -from setuptools import setup, find_packages -import os - -# Set the package release version -major = 1 -minor = 7 -patch = 1 - -# Set the package details -name = 'modopt' -version = '.'.join(str(value) for value in (major, minor, patch)) -author = 'Samuel Farrens' -email = 'samuel.farrens@cea.fr' -gh_user = 'cea-cosmic' -url = 'https://github.com/{0}/{1}'.format(gh_user, name) -description = 'Modular Optimisation tools for soliving inverse problems.' -license = 'MIT' - -# Set the package classifiers -python_versions_supported = ['3.7', '3.8', '3.9', '3.10', '3.11'] -os_platforms_supported = ['Unix', 'MacOS'] - -lc_str = 'License :: OSI Approved :: {0} License' -ln_str = 'Programming Language :: Python' -py_str = 'Programming Language :: Python :: {0}' -os_str = 'Operating System :: {0}' - -classifiers = ( - [lc_str.format(license)] - + [ln_str] - + [py_str.format(ver) for ver in python_versions_supported] - + [os_str.format(ops) for ops in os_platforms_supported] -) - -# Source package description from README.md -this_directory = os.path.abspath(os.path.dirname(__file__)) -with open(os.path.join(this_directory, 'README.md'), encoding='utf-8') as f: - long_description = f.read() - -# Source package requirements from requirements.txt -with open('requirements.txt') as open_file: - install_requires = open_file.read() - -# Source test requirements from develop.txt -with open('develop.txt') as open_file: - tests_require = open_file.read() - -# Source doc requirements from docs/requirements.txt -with open('docs/requirements.txt') as open_file: - docs_require = open_file.read() - - -setup( - name=name, - author=author, - author_email=email, - version=version, - license=license, - url=url, - description=description, - long_description=long_description, - long_description_content_type='text/markdown', - packages=find_packages(), - install_requires=install_requires, - python_requires='>=3.6', - setup_requires=['pytest-runner'], - tests_require=tests_require, - extras_require={'develop': tests_require + docs_require}, - classifiers=classifiers, -) diff --git a/src/modopt/__init__.py b/src/modopt/__init__.py new file mode 100644 index 00000000..5e8de1b6 --- /dev/null +++ b/src/modopt/__init__.py @@ -0,0 +1,25 @@ +"""MODOPT PACKAGE. + +ModOpt is a series of Modular Optimisation tools for solving inverse problems. + +""" + +from warnings import warn + +from importlib_metadata import version + +from modopt.base import np_adjust, transform, types, observable + +__all__ = ["np_adjust", "transform", "types", "observable"] + +try: + _version = version("modopt") +except Exception: # pragma: no cover + _version = "Unkown" + warn( + "Could not extract package metadata. Make sure the package is " + + "correctly installed.", + stacklevel=1, + ) + +__version__ = _version diff --git a/modopt/base/__init__.py b/src/modopt/base/__init__.py similarity index 71% rename from modopt/base/__init__.py rename to src/modopt/base/__init__.py index 88424bae..c4c681d7 100644 --- a/modopt/base/__init__.py +++ b/src/modopt/base/__init__.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """BASE ROUTINES. This module contains submodules for basic operations such as type @@ -9,4 +7,4 @@ """ -__all__ = ['np_adjust', 'transform', 'types', 'observable'] +__all__ = ["np_adjust", "transform", "types", "observable"] diff --git a/modopt/base/backend.py b/src/modopt/base/backend.py similarity index 77% rename from modopt/base/backend.py rename to src/modopt/base/backend.py index 1f4e9a72..485f649a 100644 --- a/modopt/base/backend.py +++ b/src/modopt/base/backend.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """BACKEND MODULE. This module contains methods for GPU Compatiblity. @@ -26,22 +24,24 @@ # Handle the compatibility with variable LIBRARIES = { - 'cupy': None, - 'tensorflow': None, - 'numpy': np, + "cupy": None, + "tensorflow": None, + "numpy": np, } -if util.find_spec('cupy') is not None: +if util.find_spec("cupy") is not None: try: import cupy as cp - LIBRARIES['cupy'] = cp + + LIBRARIES["cupy"] = cp except ImportError: pass -if util.find_spec('tensorflow') is not None: +if util.find_spec("tensorflow") is not None: try: from tensorflow.experimental import numpy as tnp - LIBRARIES['tensorflow'] = tnp + + LIBRARIES["tensorflow"] = tnp except ImportError: pass @@ -66,12 +66,12 @@ def get_backend(backend): """ if backend not in LIBRARIES.keys() or LIBRARIES[backend] is None: msg = ( - '{0} backend not possible, please ensure that ' - + 'the optional libraries are installed.\n' - + 'Reverting to numpy.' + "{0} backend not possible, please ensure that " + + "the optional libraries are installed.\n" + + "Reverting to numpy." ) warn(msg.format(backend)) - backend = 'numpy' + backend = "numpy" return LIBRARIES[backend], backend @@ -92,16 +92,16 @@ def get_array_module(input_data): The numpy or cupy module """ - if LIBRARIES['tensorflow'] is not None: - if isinstance(input_data, LIBRARIES['tensorflow'].ndarray): - return LIBRARIES['tensorflow'] - if LIBRARIES['cupy'] is not None: - if isinstance(input_data, LIBRARIES['cupy'].ndarray): - return LIBRARIES['cupy'] + if LIBRARIES["tensorflow"] is not None: + if isinstance(input_data, LIBRARIES["tensorflow"].ndarray): + return LIBRARIES["tensorflow"] + if LIBRARIES["cupy"] is not None: + if isinstance(input_data, LIBRARIES["cupy"].ndarray): + return LIBRARIES["cupy"] return np -def change_backend(input_data, backend='cupy'): +def change_backend(input_data, backend="cupy"): """Move data to device. This method changes the backend of an array. This can be used to copy data @@ -151,13 +151,13 @@ def move_to_cpu(input_data): """ xp = get_array_module(input_data) - if xp == LIBRARIES['numpy']: + if xp == LIBRARIES["numpy"]: return input_data - elif xp == LIBRARIES['cupy']: + elif xp == LIBRARIES["cupy"]: return input_data.get() - elif xp == LIBRARIES['tensorflow']: + elif xp == LIBRARIES["tensorflow"]: return input_data.data.numpy() - raise ValueError('Cannot identify the array type.') + raise ValueError("Cannot identify the array type.") def convert_to_tensor(input_data): @@ -184,9 +184,9 @@ def convert_to_tensor(input_data): """ if not import_torch: raise ImportError( - 'Required version of Torch package not found' - + 'see documentation for details: https://cea-cosmic.' - + 'github.io/ModOpt/#optional-packages', + "Required version of Torch package not found" + + "see documentation for details: https://cea-cosmic." + + "github.io/ModOpt/#optional-packages", ) xp = get_array_module(input_data) @@ -220,9 +220,9 @@ def convert_to_cupy_array(input_data): """ if not import_torch: raise ImportError( - 'Required version of Torch package not found' - + 'see documentation for details: https://cea-cosmic.' - + 'github.io/ModOpt/#optional-packages', + "Required version of Torch package not found" + + "see documentation for details: https://cea-cosmic." + + "github.io/ModOpt/#optional-packages", ) if input_data.is_cuda: diff --git a/modopt/base/np_adjust.py b/src/modopt/base/np_adjust.py similarity index 96% rename from modopt/base/np_adjust.py rename to src/modopt/base/np_adjust.py index 6d290e43..10cb5c29 100644 --- a/modopt/base/np_adjust.py +++ b/src/modopt/base/np_adjust.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """NUMPY ADJUSTMENT ROUTINES. This module contains methods for adjusting the default output for certain @@ -154,8 +152,7 @@ def pad2d(input_data, padding): padding = np.array(padding) elif not isinstance(padding, np.ndarray): raise ValueError( - 'Padding must be an integer or a tuple (or list, np.ndarray) ' - + 'of itegers', + "Padding must be an integer or a tuple (or list, np.ndarray) of integers", ) if padding.size == 1: @@ -164,7 +161,7 @@ def pad2d(input_data, padding): pad_x = (padding[0], padding[0]) pad_y = (padding[1], padding[1]) - return np.pad(input_data, (pad_x, pad_y), 'constant') + return np.pad(input_data, (pad_x, pad_y), "constant") def ftr(input_data): diff --git a/modopt/base/observable.py b/src/modopt/base/observable.py similarity index 95% rename from modopt/base/observable.py rename to src/modopt/base/observable.py index 6471ba58..bf8371c3 100644 --- a/modopt/base/observable.py +++ b/src/modopt/base/observable.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """Observable. This module contains observable classes @@ -13,13 +11,13 @@ import numpy as np -class SignalObject(object): +class SignalObject: """Dummy class for signals.""" pass -class Observable(object): +class Observable: """Base class for observable classes. This class defines a simple interface to add or remove observers @@ -33,7 +31,6 @@ class Observable(object): """ def __init__(self, signals): - # Define class parameters self._allowed_signals = [] self._observers = {} @@ -177,7 +174,7 @@ def _remove_observer(self, signal, observer): self._observers[signal].remove(observer) -class MetricObserver(object): +class MetricObserver: """Metric observer. Wrapper of the metric to the observer object notify by the Observable @@ -215,7 +212,6 @@ def __init__( wind=6, eps=1.0e-3, ): - self.name = name self.metric = metric self.mapping = mapping @@ -264,9 +260,7 @@ def is_converge(self): mid_idx = -(self.wind // 2) old_mean = np.array(self.list_cv_values[start_idx:mid_idx]).mean() current_mean = np.array(self.list_cv_values[mid_idx:]).mean() - normalize_residual_metrics = ( - np.abs(old_mean - current_mean) / np.abs(old_mean) - ) + normalize_residual_metrics = np.abs(old_mean - current_mean) / np.abs(old_mean) self.converge_flag = normalize_residual_metrics < self.eps def retrieve_metrics(self): @@ -287,7 +281,7 @@ def retrieve_metrics(self): time_val -= time_val[0] return { - 'time': time_val, - 'index': self.list_iters, - 'values': self.list_cv_values, + "time": time_val, + "index": self.list_iters, + "values": self.list_cv_values, } diff --git a/modopt/base/transform.py b/src/modopt/base/transform.py similarity index 87% rename from modopt/base/transform.py rename to src/modopt/base/transform.py index 07ce846f..25ed102a 100644 --- a/modopt/base/transform.py +++ b/src/modopt/base/transform.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """DATA TRANSFORM ROUTINES. This module contains methods for transforming data. @@ -53,18 +51,17 @@ def cube2map(data_cube, layout): """ if data_cube.ndim != 3: - raise ValueError('The input data must have 3 dimensions.') + raise ValueError("The input data must have 3 dimensions.") if data_cube.shape[0] != np.prod(layout): raise ValueError( - 'The desired layout must match the number of input ' - + 'data layers.', + "The desired layout must match the number of input " + "data layers.", ) - res = ([ + res = [ np.hstack(data_cube[slice(layout[1] * elem, layout[1] * (elem + 1))]) for elem in range(layout[0]) - ]) + ] return np.vstack(res) @@ -118,20 +115,24 @@ def map2cube(data_map, layout): """ if np.all(np.array(data_map.shape) % np.array(layout)): raise ValueError( - 'The desired layout must be a multiple of the number ' - + 'pixels in the data map.', + "The desired layout must be a multiple of the number " + + "pixels in the data map.", ) d_shape = np.array(data_map.shape) // np.array(layout) - return np.array([ - data_map[( - slice(i_elem * d_shape[0], (i_elem + 1) * d_shape[0]), - slice(j_elem * d_shape[1], (j_elem + 1) * d_shape[1]), - )] - for i_elem in range(layout[0]) - for j_elem in range(layout[1]) - ]) + return np.array( + [ + data_map[ + ( + slice(i_elem * d_shape[0], (i_elem + 1) * d_shape[0]), + slice(j_elem * d_shape[1], (j_elem + 1) * d_shape[1]), + ) + ] + for i_elem in range(layout[0]) + for j_elem in range(layout[1]) + ] + ) def map2matrix(data_map, layout): @@ -186,9 +187,9 @@ def map2matrix(data_map, layout): image_shape * (i_elem % layout[1] + 1), ) data_matrix.append( - ( - data_map[lower[0]:upper[0], lower[1]:upper[1]] - ).reshape(image_shape ** 2), + (data_map[lower[0] : upper[0], lower[1] : upper[1]]).reshape( + image_shape**2 + ), ) return np.array(data_matrix).T @@ -232,7 +233,7 @@ def matrix2map(data_matrix, map_shape): # Get the shape and layout of the images image_shape = np.sqrt(data_matrix.shape[0]).astype(int) - layout = np.array(map_shape // np.repeat(image_shape, 2), dtype='int') + layout = np.array(map_shape // np.repeat(image_shape, 2), dtype="int") # Map objects from matrix data_map = np.zeros(map_shape) @@ -248,7 +249,7 @@ def matrix2map(data_matrix, map_shape): image_shape * (i_elem // layout[1] + 1), image_shape * (i_elem % layout[1] + 1), ) - data_map[lower[0]:upper[0], lower[1]:upper[1]] = temp[:, :, i_elem] + data_map[lower[0] : upper[0], lower[1] : upper[1]] = temp[:, :, i_elem] return data_map.astype(int) @@ -285,7 +286,7 @@ def cube2matrix(data_cube): """ return data_cube.reshape( - [data_cube.shape[0]] + [np.prod(data_cube.shape[1:])], + [data_cube.shape[0], np.prod(data_cube.shape[1:])], ).T @@ -330,4 +331,4 @@ def matrix2cube(data_matrix, im_shape): cube2matrix : complimentary function """ - return data_matrix.T.reshape([data_matrix.shape[1]] + list(im_shape)) + return data_matrix.T.reshape([data_matrix.shape[1], *list(im_shape)]) diff --git a/modopt/base/types.py b/src/modopt/base/types.py similarity index 82% rename from modopt/base/types.py rename to src/modopt/base/types.py index 16e06f15..9e9a15b9 100644 --- a/modopt/base/types.py +++ b/src/modopt/base/types.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """TYPE HANDLING ROUTINES. This module contains methods for handing object types. @@ -30,7 +28,7 @@ def check_callable(input_obj): For invalid input type """ if not callable(input_obj): - raise TypeError('The input object must be a callable function.') + raise TypeError("The input object must be a callable function.") return input_obj @@ -71,14 +69,13 @@ def check_float(input_obj): """ if not isinstance(input_obj, (int, float, list, tuple, np.ndarray)): - raise TypeError('Invalid input type.') + raise TypeError("Invalid input type.") if isinstance(input_obj, int): input_obj = float(input_obj) elif isinstance(input_obj, (list, tuple)): input_obj = np.array(input_obj, dtype=float) - elif ( - isinstance(input_obj, np.ndarray) - and (not np.issubdtype(input_obj.dtype, np.floating)) + elif isinstance(input_obj, np.ndarray) and ( + not np.issubdtype(input_obj.dtype, np.floating) ): input_obj = input_obj.astype(float) @@ -121,14 +118,13 @@ def check_int(input_obj): """ if not isinstance(input_obj, (int, float, list, tuple, np.ndarray)): - raise TypeError('Invalid input type.') + raise TypeError("Invalid input type.") if isinstance(input_obj, float): input_obj = int(input_obj) elif isinstance(input_obj, (list, tuple)): input_obj = np.array(input_obj, dtype=int) - elif ( - isinstance(input_obj, np.ndarray) - and (not np.issubdtype(input_obj.dtype, np.integer)) + elif isinstance(input_obj, np.ndarray) and ( + not np.issubdtype(input_obj.dtype, np.integer) ): input_obj = input_obj.astype(int) @@ -160,19 +156,18 @@ def check_npndarray(input_obj, dtype=None, writeable=True, verbose=True): """ if not isinstance(input_obj, np.ndarray): - raise TypeError('Input is not a numpy array.') + raise TypeError("Input is not a numpy array.") - if ( - (not isinstance(dtype, type(None))) - and (not np.issubdtype(input_obj.dtype, dtype)) + if (not isinstance(dtype, type(None))) and ( + not np.issubdtype(input_obj.dtype, dtype) ): raise ( TypeError( - 'The numpy array elements are not of type: {0}'.format(dtype), + f"The numpy array elements are not of type: {dtype}", ), ) if not writeable and verbose and input_obj.flags.writeable: - warn('Making input data immutable.') + warn("Making input data immutable.") input_obj.flags.writeable = writeable diff --git a/modopt/interface/__init__.py b/src/modopt/interface/__init__.py similarity index 75% rename from modopt/interface/__init__.py rename to src/modopt/interface/__init__.py index f9439747..a54f4bf5 100644 --- a/modopt/interface/__init__.py +++ b/src/modopt/interface/__init__.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """INTERFACE ROUTINES. This module contains submodules for error handling, logging and IO interaction. @@ -8,4 +6,4 @@ """ -__all__ = ['errors', 'log'] +__all__ = ["errors", "log"] diff --git a/modopt/interface/errors.py b/src/modopt/interface/errors.py similarity index 75% rename from modopt/interface/errors.py rename to src/modopt/interface/errors.py index 0fbe7e71..84031e3c 100644 --- a/modopt/interface/errors.py +++ b/src/modopt/interface/errors.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """ERROR HANDLING ROUTINES. This module contains methods for handing warnings and errors. @@ -34,16 +32,16 @@ def warn(warn_string, log=None): """ if import_fail: - warn_txt = 'WARNING' + warn_txt = "WARNING" else: - warn_txt = colored('WARNING', 'yellow') + warn_txt = colored("WARNING", "yellow") # Print warning to stdout. - sys.stderr.write('{0}: {1}\n'.format(warn_txt, warn_string)) + sys.stderr.write(f"{warn_txt}: {warn_string}\n") # Check if a logging structure is provided. if not isinstance(log, type(None)): - warnings.warn(warn_string) + warnings.warn(warn_string, stacklevel=2) def catch_error(exception, log=None): @@ -61,17 +59,17 @@ def catch_error(exception, log=None): """ if import_fail: - err_txt = 'ERROR' + err_txt = "ERROR" else: - err_txt = colored('ERROR', 'red') + err_txt = colored("ERROR", "red") # Print exception to stdout. - stream_txt = '{0}: {1}\n'.format(err_txt, exception) + stream_txt = f"{err_txt}: {exception}\n" sys.stderr.write(stream_txt) # Check if a logging structure is provided. if not isinstance(log, type(None)): - log_txt = 'ERROR: {0}\n'.format(exception) + log_txt = f"ERROR: {exception}\n" log.exception(log_txt) @@ -91,11 +89,11 @@ def file_name_error(file_name): If file name not specified or file not found """ - if file_name == '' or file_name[0][0] == '-': - raise IOError('Input file name not specified.') + if file_name == "" or file_name[0][0] == "-": + raise OSError("Input file name not specified.") elif not os.path.isfile(file_name): - raise IOError('Input file name {0} not found!'.format(file_name)) + raise OSError(f"Input file name {file_name} not found!") def is_exe(fpath): @@ -136,7 +134,7 @@ def is_executable(exe_name): """ if not isinstance(exe_name, str): - raise TypeError('Executable name must be a string.') + raise TypeError("Executable name must be a string.") fpath, fname = os.path.split(exe_name) @@ -146,11 +144,9 @@ def is_executable(exe_name): else: res = any( is_exe(os.path.join(path, exe_name)) - for path in os.environ['PATH'].split(os.pathsep) + for path in os.environ["PATH"].split(os.pathsep) ) if not res: - message = ( - '{0} does not appear to be a valid executable on this system.' - ) - raise IOError(message.format(exe_name)) + message = "{0} does not appear to be a valid executable on this system." + raise OSError(message.format(exe_name)) diff --git a/modopt/interface/log.py b/src/modopt/interface/log.py similarity index 77% rename from modopt/interface/log.py rename to src/modopt/interface/log.py index 3b2fa77a..50c316b7 100644 --- a/modopt/interface/log.py +++ b/src/modopt/interface/log.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """LOGGING ROUTINES. This module contains methods for handing logging. @@ -30,22 +28,22 @@ def set_up_log(filename, verbose=True): """ # Add file extension. - filename = '{0}.log'.format(filename) + filename = f"{filename}.log" if verbose: - print('Preparing log file:', filename) + print("Preparing log file:", filename) # Capture warnings. logging.captureWarnings(True) # Set output format. formatter = logging.Formatter( - fmt='%(asctime)s %(message)s', - datefmt='%d/%m/%Y %H:%M:%S', + fmt="%(asctime)s %(message)s", + datefmt="%d/%m/%Y %H:%M:%S", ) # Create file handler. - fh = logging.FileHandler(filename=filename, mode='w') + fh = logging.FileHandler(filename=filename, mode="w") fh.setLevel(logging.DEBUG) fh.setFormatter(formatter) @@ -55,7 +53,7 @@ def set_up_log(filename, verbose=True): log.addHandler(fh) # Send opening message. - log.info('The log file has been set-up.') + log.info("The log file has been set-up.") return log @@ -74,10 +72,10 @@ def close_log(log, verbose=True): """ if verbose: - print('Closing log file:', log.name) + print("Closing log file:", log.name) # Send closing message. - log.info('The log file has been closed.') + log.info("The log file has been closed.") # Remove all handlers from log. for log_handler in log.handlers: diff --git a/modopt/math/__init__.py b/src/modopt/math/__init__.py similarity index 64% rename from modopt/math/__init__.py rename to src/modopt/math/__init__.py index a22c0c98..d5ffc67a 100644 --- a/modopt/math/__init__.py +++ b/src/modopt/math/__init__.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """MATHEMATICS ROUTINES. This module contains submodules for mathematical applications. @@ -8,4 +6,4 @@ """ -__all__ = ['convolve', 'matrix', 'stats', 'metrics'] +__all__ = ["convolve", "matrix", "stats", "metrics"] diff --git a/modopt/math/convolve.py b/src/modopt/math/convolve.py similarity index 87% rename from modopt/math/convolve.py rename to src/modopt/math/convolve.py index a4322ff2..21dc8b4e 100644 --- a/modopt/math/convolve.py +++ b/src/modopt/math/convolve.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """CONVOLUTION ROUTINES. This module contains methods for convolution. @@ -18,7 +16,7 @@ from astropy.convolution import convolve_fft except ImportError: # pragma: no cover import_astropy = False - warn('astropy not found, will default to scipy for convolution') + warn("astropy not found, will default to scipy for convolution") else: import_astropy = True try: @@ -30,7 +28,7 @@ warn('Using pyFFTW "monkey patch" for scipy.fftpack') -def convolve(input_data, kernel, method='scipy'): +def convolve(input_data, kernel, method="scipy"): """Convolve data with kernel. This method convolves the input data with a given kernel using FFT and @@ -80,29 +78,29 @@ def convolve(input_data, kernel, method='scipy'): """ if input_data.ndim != kernel.ndim: - raise ValueError('Data and kernel must have the same dimensions.') + raise ValueError("Data and kernel must have the same dimensions.") - if method not in {'astropy', 'scipy'}: + if method not in {"astropy", "scipy"}: raise ValueError('Invalid method. Options are "astropy" or "scipy".') if not import_astropy: # pragma: no cover - method = 'scipy' + method = "scipy" - if method == 'astropy': + if method == "astropy": return convolve_fft( input_data, kernel, - boundary='wrap', + boundary="wrap", crop=False, - nan_treatment='fill', + nan_treatment="fill", normalize_kernel=False, ) - elif method == 'scipy': - return scipy.signal.fftconvolve(input_data, kernel, mode='same') + elif method == "scipy": + return scipy.signal.fftconvolve(input_data, kernel, mode="same") -def convolve_stack(input_data, kernel, rot_kernel=False, method='scipy'): +def convolve_stack(input_data, kernel, rot_kernel=False, method="scipy"): """Convolve stack of data with stack of kernels. This method convolves the input data with a given kernel using FFT and @@ -156,7 +154,9 @@ def convolve_stack(input_data, kernel, rot_kernel=False, method='scipy'): if rot_kernel: kernel = rotate_stack(kernel) - return np.array([ - convolve(data_i, kernel_i, method=method) - for data_i, kernel_i in zip(input_data, kernel) - ]) + return np.array( + [ + convolve(data_i, kernel_i, method=method) + for data_i, kernel_i in zip(input_data, kernel) + ] + ) diff --git a/modopt/math/matrix.py b/src/modopt/math/matrix.py similarity index 90% rename from modopt/math/matrix.py rename to src/modopt/math/matrix.py index 8361531d..b200f15d 100644 --- a/modopt/math/matrix.py +++ b/src/modopt/math/matrix.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """MATRIX ROUTINES. This module contains methods for matrix operations. @@ -15,7 +13,7 @@ from modopt.base.backend import get_array_module, get_backend -def gram_schmidt(matrix, return_opt='orthonormal'): +def gram_schmidt(matrix, return_opt="orthonormal"): r"""Gram-Schmit. This method orthonormalizes the row vectors of the input matrix. @@ -55,7 +53,7 @@ def gram_schmidt(matrix, return_opt='orthonormal'): https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process """ - if return_opt not in {'orthonormal', 'orthogonal', 'both'}: + if return_opt not in {"orthonormal", "orthogonal", "both"}: raise ValueError( 'Invalid return_opt, options are: "orthonormal", "orthogonal" or ' + '"both"', @@ -65,7 +63,6 @@ def gram_schmidt(matrix, return_opt='orthonormal'): e_vec = [] for vector in matrix: - if u_vec: u_now = vector - sum(project(u_i, vector) for u_i in u_vec) else: @@ -77,11 +74,11 @@ def gram_schmidt(matrix, return_opt='orthonormal'): u_vec = np.array(u_vec) e_vec = np.array(e_vec) - if return_opt == 'orthonormal': + if return_opt == "orthonormal": return e_vec - elif return_opt == 'orthogonal': + elif return_opt == "orthogonal": return u_vec - elif return_opt == 'both': + elif return_opt == "both": return u_vec, e_vec @@ -201,7 +198,7 @@ def rot_matrix(angle): return np.around( np.array( [[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]], - dtype='float', + dtype="float", ), 10, ) @@ -243,22 +240,21 @@ def rotate(matrix, angle): shape = np.array(matrix.shape) if shape[0] != shape[1]: - raise ValueError('Input matrix must be square.') + raise ValueError("Input matrix must be square.") shift = (shape - 1) // 2 index = ( - np.array(list(product(*np.array([np.arange(sval) for sval in shape])))) - - shift + np.array(list(product(*np.array([np.arange(sval) for sval in shape])))) - shift ) - new_index = np.array(np.dot(index, rot_matrix(angle)), dtype='int') + shift + new_index = np.array(np.dot(index, rot_matrix(angle)), dtype="int") + shift new_index[new_index >= shape[0]] -= shape[0] return matrix[tuple(zip(new_index.T))].reshape(shape.T) -class PowerMethod(object): +class PowerMethod: """Power method class. This method performs implements power method to calculate the spectral @@ -277,6 +273,8 @@ class PowerMethod(object): initialisation (default is ``True``) verbose : bool, optional Optional verbosity (default is ``False``) + rng: int, xp.random.Generator or None (default is ``None``) + Random number generator or seed. Examples -------- @@ -301,16 +299,17 @@ def __init__( data_shape, data_type=float, auto_run=True, - compute_backend='numpy', + compute_backend="numpy", verbose=False, + rng=None, ): - self._operator = operator self._data_shape = data_shape self._data_type = data_type self._verbose = verbose xp, compute_backend = get_backend(compute_backend) self.xp = xp + self.rng = None self.compute_backend = compute_backend if auto_run: self.get_spec_rad() @@ -327,7 +326,8 @@ def _set_initial_x(self): Random values of the same shape as the input data """ - return self.xp.random.random(self._data_shape).astype(self._data_type) + rng = self.xp.random.default_rng(self.rng) + return rng.random(self._data_shape).astype(self._data_type) def get_spec_rad(self, tolerance=1e-6, max_iter=20, extra_factor=1.0): """Get spectral radius. @@ -363,18 +363,14 @@ def get_spec_rad(self, tolerance=1e-6, max_iter=20, extra_factor=1.0): x_new /= x_new_norm - if (xp.abs(x_new_norm - x_old_norm) < tolerance): - message = ( - ' - Power Method converged after {0} iterations!' - ) + if xp.abs(x_new_norm - x_old_norm) < tolerance: + message = " - Power Method converged after {0} iterations!" if self._verbose: print(message.format(i_elem + 1)) break elif i_elem == max_iter - 1 and self._verbose: - message = ( - ' - Power Method did not converge after {0} iterations!' - ) + message = " - Power Method did not converge after {0} iterations!" print(message.format(max_iter)) xp.copyto(x_old, x_new) diff --git a/modopt/math/metrics.py b/src/modopt/math/metrics.py similarity index 91% rename from modopt/math/metrics.py rename to src/modopt/math/metrics.py index 21952624..befd4fa4 100644 --- a/modopt/math/metrics.py +++ b/src/modopt/math/metrics.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """METRICS. This module contains classes of different metric functions for optimization. @@ -71,15 +69,13 @@ def _preprocess_input(test, ref, mask=None): The SNR """ - test = np.abs(np.copy(test)).astype('float64') - ref = np.abs(np.copy(ref)).astype('float64') + test = np.abs(np.copy(test)).astype("float64") + ref = np.abs(np.copy(ref)).astype("float64") test = min_max_normalize(test) ref = min_max_normalize(ref) if (not isinstance(mask, np.ndarray)) and (mask is not None): - message = ( - 'Mask should be None, or a numpy.ndarray, got "{0}" instead.' - ) + message = 'Mask should be None, or a numpy.ndarray, got "{0}" instead.' raise ValueError(message.format(mask)) if mask is None: @@ -119,9 +115,9 @@ def ssim(test, ref, mask=None): """ if not import_skimage: # pragma: no cover raise ImportError( - 'Required version of Scikit-Image package not found' - + 'see documentation for details: https://cea-cosmic.' - + 'github.io/ModOpt/#optional-packages', + "Required version of Scikit-Image package not found" + + "see documentation for details: https://cea-cosmic." + + "github.io/ModOpt/#optional-packages", ) test, ref, mask = _preprocess_input(test, ref, mask) @@ -270,6 +266,6 @@ def nrmse(test, ref, mask=None): ref = mask * ref num = np.sqrt(mse(test, ref)) - deno = np.sqrt(np.mean((np.square(test)))) + deno = np.sqrt(np.mean(np.square(test))) return num / deno diff --git a/modopt/math/stats.py b/src/modopt/math/stats.py similarity index 97% rename from modopt/math/stats.py rename to src/modopt/math/stats.py index 59bf6759..8583a8c3 100644 --- a/modopt/math/stats.py +++ b/src/modopt/math/stats.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """STATISTICS ROUTINES. This module contains methods for basic statistics. @@ -31,8 +29,8 @@ def gaussian_kernel(data_shape, sigma, norm="max"): Desiered shape of the kernel sigma : float Standard deviation of the kernel - norm : {'max', 'sum'}, optional - Normalisation of the kerenl (options are ``'max'`` or ``'sum'``, default is ``'max'``) + norm : {'max', 'sum'}, optional, default='max' + Normalisation of the kernel Returns ------- diff --git a/modopt/opt/__init__.py b/src/modopt/opt/__init__.py similarity index 59% rename from modopt/opt/__init__.py rename to src/modopt/opt/__init__.py index 2fd3d747..62d1f388 100644 --- a/modopt/opt/__init__.py +++ b/src/modopt/opt/__init__.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """OPTIMISATION PROBLEM MODULES. This module contains submodules for solving optimisation problems. @@ -8,4 +6,4 @@ """ -__all__ = ['cost', 'gradient', 'linear', 'algorithms', 'proximity', 'reweight'] +__all__ = ["cost", "gradient", "linear", "algorithms", "proximity", "reweight"] diff --git a/modopt/opt/algorithms/__init__.py b/src/modopt/opt/algorithms/__init__.py similarity index 53% rename from modopt/opt/algorithms/__init__.py rename to src/modopt/opt/algorithms/__init__.py index d4e7082b..ff79502c 100644 --- a/modopt/opt/algorithms/__init__.py +++ b/src/modopt/opt/algorithms/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- r"""OPTIMISATION ALGORITHMS. This module contains class implementations of various optimisation algoritms. @@ -45,16 +44,32 @@ """ -from modopt.opt.algorithms.base import SetUp -from modopt.opt.algorithms.forward_backward import (FISTA, POGM, - ForwardBackward, - GenForwardBackward) -from modopt.opt.algorithms.gradient_descent import (AdaGenericGradOpt, - ADAMGradOpt, - GenericGradOpt, - MomentumGradOpt, - RMSpropGradOpt, - SAGAOptGradOpt, - VanillaGenericGradOpt) -from modopt.opt.algorithms.primal_dual import Condat -from modopt.opt.algorithms.admm import ADMM, FastADMM +from .forward_backward import FISTA, ForwardBackward, GenForwardBackward, POGM +from .primal_dual import Condat +from .gradient_descent import ( + ADAMGradOpt, + AdaGenericGradOpt, + GenericGradOpt, + MomentumGradOpt, + RMSpropGradOpt, + SAGAOptGradOpt, + VanillaGenericGradOpt, +) +from .admm import ADMM, FastADMM + +__all__ = [ + "FISTA", + "ForwardBackward", + "GenForwardBackward", + "POGM", + "Condat", + "ADAMGradOpt", + "AdaGenericGradOpt", + "GenericGradOpt", + "MomentumGradOpt", + "RMSpropGradOpt", + "SAGAOptGradOpt", + "VanillaGenericGradOpt", + "ADMM", + "FastADMM", +] diff --git a/modopt/opt/algorithms/admm.py b/src/modopt/opt/algorithms/admm.py similarity index 95% rename from modopt/opt/algorithms/admm.py rename to src/modopt/opt/algorithms/admm.py index b881b770..b2f45171 100644 --- a/modopt/opt/algorithms/admm.py +++ b/src/modopt/opt/algorithms/admm.py @@ -1,4 +1,5 @@ """ADMM Algorithms.""" + import numpy as np from modopt.base.backend import get_array_module @@ -67,7 +68,8 @@ def _calc_cost(self, u, v, **kwargs): class ADMM(SetUp): r"""Fast ADMM Optimisation Algorihm. - This class implement the ADMM algorithm described in :cite:`Goldstein2014` (Algorithm 1). + This class implement the ADMM algorithm described in :cite:`Goldstein2014` + (Algorithm 1). Parameters ---------- @@ -85,7 +87,7 @@ class ADMM(SetUp): Constraint vector optimizers: tuple 2-tuple of callable, that are the optimizers for the u and v. - Each callable should access an init and obs argument and returns an estimate for: + Each callable should access init and obs argument and returns an estimate for: .. math:: u_{k+1} = \argmin H(u) + \frac{\tau}{2}\|A u - y\|^2 .. math:: v_{k+1} = \argmin G(v) + \frac{\tau}{2}\|Bv - y \|^2 cost_funcs: tuple @@ -188,7 +190,7 @@ def iterate(self, max_iter=150): self.retrieve_outputs() # rename outputs as attributes self.u_final = self._u_new - self.x_final = self.u_final # for backward compatibility + self.x_final = self.u_final # for backward compatibility self.v_final = self._v_new def get_notify_observers_kwargs(self): @@ -203,9 +205,9 @@ def get_notify_observers_kwargs(self): The mapping between the iterated variables """ return { - 'x_new': self._u_new, - 'v_new': self._v_new, - 'idx': self.idx, + "x_new": self._u_new, + "v_new": self._v_new, + "idx": self.idx, } def retrieve_outputs(self): @@ -215,7 +217,7 @@ def retrieve_outputs(self): y_final, metrics. """ metrics = {} - for obs in self._observers['cv_metrics']: + for obs in self._observers["cv_metrics"]: metrics[obs.name] = obs.retrieve_metrics() self.metrics = metrics @@ -242,7 +244,7 @@ class FastADMM(ADMM): Constraint vector optimizers: tuple 2-tuple of callable, that are the optimizers for the u and v. - Each callable should access an init and obs argument and returns an estimate for: + Each callable should access init and obs argument and returns an estimate for: .. math:: u_{k+1} = \argmin H(u) + \frac{\tau}{2}\|A u - y\|^2 .. math:: v_{k+1} = \argmin G(v) + \frac{\tau}{2}\|Bv - y \|^2 cost_funcs: tuple @@ -256,7 +258,8 @@ class FastADMM(ADMM): Notes ----- - This is an accelerated version of the ADMM algorithm. The convergence hypothesis are stronger than for the ADMM algorithm. + This is an accelerated version of the ADMM algorithm. The convergence hypothesis are + stronger than for the ADMM algorithm. See Also -------- diff --git a/modopt/opt/algorithms/base.py b/src/modopt/opt/algorithms/base.py similarity index 87% rename from modopt/opt/algorithms/base.py rename to src/modopt/opt/algorithms/base.py index c5a4b101..f7391063 100644 --- a/modopt/opt/algorithms/base.py +++ b/src/modopt/opt/algorithms/base.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Base SetUp for optimisation algorithms.""" from inspect import getmro @@ -69,7 +68,7 @@ def __init__( verbose=False, progress=True, step_size=None, - compute_backend='numpy', + compute_backend="numpy", **dummy_kwargs, ): self.idx = 0 @@ -79,26 +78,26 @@ def __init__( self.metrics = metrics self.step_size = step_size self._op_parents = ( - 'GradParent', - 'ProximityParent', - 'LinearParent', - 'costObj', + "GradParent", + "ProximityParent", + "LinearParent", + "costObj", ) self.metric_call_period = metric_call_period # Declaration of observers for metrics - super().__init__(['cv_metrics']) + super().__init__(["cv_metrics"]) for name, dic in self.metrics.items(): observer = MetricObserver( name, - dic['metric'], - dic['mapping'], - dic['cst_kwargs'], - dic['early_stopping'], + dic["metric"], + dic["mapping"], + dic["cst_kwargs"], + dic["early_stopping"], ) - self.add_observer('cv_metrics', observer) + self.add_observer("cv_metrics", observer) xp, compute_backend = backend.get_backend(compute_backend) self.xp = xp @@ -111,14 +110,13 @@ def metrics(self): @metrics.setter def metrics(self, metrics): - if isinstance(metrics, type(None)): self._metrics = {} elif isinstance(metrics, dict): self._metrics = metrics else: raise TypeError( - 'Metrics must be a dictionary, not {0}.'.format(type(metrics)), + f"Metrics must be a dictionary, not {type(metrics)}.", ) def any_convergence_flag(self): @@ -132,9 +130,7 @@ def any_convergence_flag(self): True if any convergence criteria met """ - return any( - obs.converge_flag for obs in self._observers['cv_metrics'] - ) + return any(obs.converge_flag for obs in self._observers["cv_metrics"]) def copy_data(self, input_data): """Copy Data. @@ -152,10 +148,12 @@ def copy_data(self, input_data): Copy of input data """ - return self.xp.copy(backend.change_backend( - input_data, - self.compute_backend, - )) + return self.xp.copy( + backend.change_backend( + input_data, + self.compute_backend, + ) + ) def _check_input_data(self, input_data): """Check input data type. @@ -175,7 +173,7 @@ def _check_input_data(self, input_data): """ if not (isinstance(input_data, (self.xp.ndarray, np.ndarray))): raise TypeError( - 'Input data must be a numpy array or backend array', + "Input data must be a numpy array or backend array", ) def _check_param(self, param_val): @@ -195,7 +193,7 @@ def _check_param(self, param_val): """ if not isinstance(param_val, float): - raise TypeError('Algorithm parameter must be a float value.') + raise TypeError("Algorithm parameter must be a float value.") def _check_param_update(self, param_update): """Check algorithm parameter update methods. @@ -213,14 +211,13 @@ def _check_param_update(self, param_update): For invalid input type """ - param_conditions = ( - not isinstance(param_update, type(None)) - and not callable(param_update) + param_conditions = not isinstance(param_update, type(None)) and not callable( + param_update ) if param_conditions: raise TypeError( - 'Algorithm parameter update must be a callabale function.', + "Algorithm parameter update must be a callabale function.", ) def _check_operator(self, operator): @@ -239,7 +236,7 @@ def _check_operator(self, operator): tree = [op_obj.__name__ for op_obj in getmro(operator.__class__)] if not any(parent in tree for parent in self._op_parents): - message = '{0} does not inherit an operator parent.' + message = "{0} does not inherit an operator parent." warn(message.format(str(operator.__class__))) def _compute_metrics(self): @@ -250,7 +247,7 @@ def _compute_metrics(self): """ kwargs = self.get_notify_observers_kwargs() - self.notify_observers('cv_metrics', **kwargs) + self.notify_observers("cv_metrics", **kwargs) def _iterations(self, max_iter, progbar=None): """Iterate method. @@ -273,7 +270,6 @@ def _iterations(self, max_iter, progbar=None): # We do not call metrics if metrics is empty or metric call # period is None if self.metrics and self.metric_call_period is not None: - metric_conditions = ( self.idx % self.metric_call_period == 0 or self.idx == (max_iter - 1) @@ -285,7 +281,7 @@ def _iterations(self, max_iter, progbar=None): if self.converge: if self.verbose: - print(' - Converged!') + print(" - Converged!") break if progbar: diff --git a/modopt/opt/algorithms/forward_backward.py b/src/modopt/opt/algorithms/forward_backward.py similarity index 87% rename from modopt/opt/algorithms/forward_backward.py rename to src/modopt/opt/algorithms/forward_backward.py index 702799c6..31927eb0 100644 --- a/modopt/opt/algorithms/forward_backward.py +++ b/src/modopt/opt/algorithms/forward_backward.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Forward-Backward Algorithms.""" import numpy as np @@ -9,7 +8,7 @@ from modopt.opt.linear import Identity -class FISTA(object): +class FISTA: r"""FISTA. This class is inherited by optimisation classes to speed up convergence @@ -52,12 +51,12 @@ class FISTA(object): """ _restarting_strategies = ( - 'adaptive', # option 1 in alg 4 - 'adaptive-i', - 'adaptive-1', - 'adaptive-ii', # option 2 in alg 4 - 'adaptive-2', - 'greedy', # alg 5 + "adaptive", # option 1 in alg 4 + "adaptive-i", + "adaptive-1", + "adaptive-ii", # option 2 in alg 4 + "adaptive-2", + "greedy", # alg 5 None, # no restarting ) @@ -73,26 +72,28 @@ def __init__( r_lazy=4, **kwargs, ): - if isinstance(a_cd, type(None)): - self.mode = 'regular' + self.mode = "regular" self.p_lazy = p_lazy self.q_lazy = q_lazy self.r_lazy = r_lazy elif a_cd > 2: - self.mode = 'CD' + self.mode = "CD" self.a_cd = a_cd self._n = 0 else: raise ValueError( - 'a_cd must either be None (for regular mode) or a number > 2', + "a_cd must either be None (for regular mode) or a number > 2", ) if restart_strategy in self._restarting_strategies: self._check_restart_params( - restart_strategy, min_beta, s_greedy, xi_restart, + restart_strategy, + min_beta, + s_greedy, + xi_restart, ) self.restart_strategy = restart_strategy self.min_beta = min_beta @@ -100,10 +101,10 @@ def __init__( self.xi_restart = xi_restart else: - message = 'Restarting strategy must be one of {0}.' + message = "Restarting strategy must be one of {0}." raise ValueError( message.format( - ', '.join(self._restarting_strategies), + ", ".join(self._restarting_strategies), ), ) self._t_now = 1.0 @@ -155,22 +156,20 @@ def _check_restart_params( if restart_strategy is None: return True - if self.mode != 'regular': + if self.mode != "regular": raise ValueError( - 'Restarting strategies can only be used with regular mode.', + "Restarting strategies can only be used with regular mode.", ) - greedy_params_check = ( - min_beta is None or s_greedy is None or s_greedy <= 1 - ) + greedy_params_check = min_beta is None or s_greedy is None or s_greedy <= 1 - if restart_strategy == 'greedy' and greedy_params_check: + if restart_strategy == "greedy" and greedy_params_check: raise ValueError( - 'You need a min_beta and an s_greedy > 1 for greedy restart.', + "You need a min_beta and an s_greedy > 1 for greedy restart.", ) if xi_restart is None or xi_restart >= 1: - raise ValueError('You need a xi_restart < 1 for restart.') + raise ValueError("You need a xi_restart < 1 for restart.") return True @@ -210,12 +209,12 @@ def is_restart(self, z_old, x_new, x_old): criterion = xp.vdot(z_old - x_new, x_new - x_old) >= 0 if criterion: - if 'adaptive' in self.restart_strategy: + if "adaptive" in self.restart_strategy: self.r_lazy *= self.xi_restart - if self.restart_strategy in {'adaptive-ii', 'adaptive-2'}: + if self.restart_strategy in {"adaptive-ii", "adaptive-2"}: self._t_now = 1 - if self.restart_strategy == 'greedy': + if self.restart_strategy == "greedy": cur_delta = xp.linalg.norm(x_new - x_old) if self._delta0 is None: self._delta0 = self.s_greedy * cur_delta @@ -269,17 +268,17 @@ def update_lambda(self, *args, **kwargs): Implements steps 3 and 4 from algoritm 10.7 in :cite:`bauschke2009`. """ - if self.restart_strategy == 'greedy': + if self.restart_strategy == "greedy": return 2 # Steps 3 and 4 from alg.10.7. self._t_prev = self._t_now - if self.mode == 'regular': - sqrt_part = self.r_lazy * self._t_prev ** 2 + self.q_lazy + if self.mode == "regular": + sqrt_part = self.r_lazy * self._t_prev**2 + self.q_lazy self._t_now = self.p_lazy + np.sqrt(sqrt_part) * 0.5 - elif self.mode == 'CD': + elif self.mode == "CD": self._t_now = (self._n + self.a_cd - 1) / self.a_cd self._n += 1 @@ -344,18 +343,17 @@ def __init__( x, grad, prox, - cost='auto', + cost="auto", beta_param=1.0, lambda_param=1.0, beta_update=None, - lambda_update='fista', + lambda_update="fista", auto_iterate=True, metric_call_period=5, metrics=None, linear=None, **kwargs, ): - # Set default algorithm properties super().__init__( metric_call_period=metric_call_period, @@ -376,7 +374,7 @@ def __init__( self._prox = prox self._linear = linear - if cost == 'auto': + if cost == "auto": self._cost_func = costObj([self._grad, self._prox]) else: self._cost_func = cost @@ -384,7 +382,7 @@ def __init__( # Check if there is a linear op, needed for metrics in the FB algoritm if metrics and self._linear is None: raise ValueError( - 'When using metrics, you must pass a linear operator', + "When using metrics, you must pass a linear operator", ) if self._linear is None: @@ -400,7 +398,7 @@ def __init__( # Set the algorithm parameter update methods self._check_param_update(beta_update) self._beta_update = beta_update - if isinstance(lambda_update, str) and lambda_update == 'fista': + if isinstance(lambda_update, str) and lambda_update == "fista": fista = FISTA(**kwargs) self._lambda_update = fista.update_lambda self._is_restart = fista.is_restart @@ -462,9 +460,8 @@ def _update(self): # Test cost function for convergence. if self._cost_func: - self.converge = ( - self.any_convergence_flag() - or self._cost_func.get_cost(self._x_new) + self.converge = self.any_convergence_flag() or self._cost_func.get_cost( + self._x_new ) def iterate(self, max_iter=150, progbar=None): @@ -500,9 +497,9 @@ def get_notify_observers_kwargs(self): """ return { - 'x_new': self._linear.adj_op(self._x_new), - 'z_new': self._z_new, - 'idx': self.idx, + "x_new": self._linear.adj_op(self._x_new), + "z_new": self._z_new, + "idx": self.idx, } def retrieve_outputs(self): @@ -513,7 +510,7 @@ def retrieve_outputs(self): """ metrics = {} - for obs in self._observers['cv_metrics']: + for obs in self._observers["cv_metrics"]: metrics[obs.name] = obs.retrieve_metrics() self.metrics = metrics @@ -577,7 +574,7 @@ def __init__( x, grad, prox_list, - cost='auto', + cost="auto", gamma_param=1.0, lambda_param=1.0, gamma_update=None, @@ -589,7 +586,6 @@ def __init__( linear=None, **kwargs, ): - # Set default algorithm properties super().__init__( metric_call_period=metric_call_period, @@ -602,22 +598,22 @@ def __init__( self._x_old = self.xp.copy(x) # Set the algorithm operators - for operator in [grad, cost] + prox_list: + for operator in [grad, cost, *prox_list]: self._check_operator(operator) self._grad = grad self._prox_list = self.xp.array(prox_list) self._linear = linear - if cost == 'auto': - self._cost_func = costObj([self._grad] + prox_list) + if cost == "auto": + self._cost_func = costObj([self._grad, *prox_list]) else: self._cost_func = cost # Check if there is a linear op, needed for metrics in the FB algoritm if metrics and self._linear is None: raise ValueError( - 'When using metrics, you must pass a linear operator', + "When using metrics, you must pass a linear operator", ) if self._linear is None: @@ -641,9 +637,7 @@ def __init__( self._set_weights(weights) # Set initial z - self._z = self.xp.array([ - self._x_old for i in range(self._prox_list.size) - ]) + self._z = self.xp.array([self._x_old for i in range(self._prox_list.size)]) # Automatically run the algorithm if auto_iterate: @@ -673,25 +667,25 @@ def _set_weights(self, weights): self._prox_list.size, ) elif not isinstance(weights, (list, tuple, np.ndarray)): - raise TypeError('Weights must be provided as a list.') + raise TypeError("Weights must be provided as a list.") weights = self.xp.array(weights) if not np.issubdtype(weights.dtype, np.floating): - raise ValueError('Weights must be list of float values.') + raise ValueError("Weights must be list of float values.") if weights.size != self._prox_list.size: raise ValueError( - 'The number of weights must match the number of proximity ' - + 'operators.', + "The number of weights must match the number of proximity " + + "operators.", ) expected_weight_sum = 1.0 if self.xp.sum(weights) != expected_weight_sum: raise ValueError( - 'Proximity operator weights must sum to 1.0. Current sum of ' - + 'weights = {0}'.format(self.xp.sum(weights)), + "Proximity operator weights must sum to 1.0. Current sum of " + + f"weights = {self.xp.sum(weights)}", ) self._weights = weights @@ -726,9 +720,7 @@ def _update(self): # Update z values. for i in range(self._prox_list.size): - z_temp = ( - 2 * self._x_old - self._z[i] - self._gamma * self._grad.grad - ) + z_temp = 2 * self._x_old - self._z[i] - self._gamma * self._grad.grad z_prox = self._prox_list[i].op( z_temp, extra_factor=self._gamma / self._weights[i], @@ -784,9 +776,9 @@ def get_notify_observers_kwargs(self): """ return { - 'x_new': self._linear.adj_op(self._x_new), - 'z_new': self._z, - 'idx': self.idx, + "x_new": self._linear.adj_op(self._x_new), + "z_new": self._z, + "idx": self.idx, } def retrieve_outputs(self): @@ -797,7 +789,7 @@ def retrieve_outputs(self): """ metrics = {} - for obs in self._observers['cv_metrics']: + for obs in self._observers["cv_metrics"]: metrics[obs.name] = obs.retrieve_metrics() self.metrics = metrics @@ -871,7 +863,7 @@ def __init__( z, grad, prox, - cost='auto', + cost="auto", linear=None, beta_param=1.0, sigma_bar=1.0, @@ -880,7 +872,6 @@ def __init__( metrics=None, **kwargs, ): - # Set default algorithm properties super().__init__( metric_call_period=metric_call_period, @@ -905,7 +896,7 @@ def __init__( self._grad = grad self._prox = prox self._linear = linear - if cost == 'auto': + if cost == "auto": self._cost_func = costObj([self._grad, self._prox]) else: self._cost_func = cost @@ -918,7 +909,7 @@ def __init__( for param_val in (beta_param, sigma_bar): self._check_param(param_val) if sigma_bar < 0 or sigma_bar > 1: - raise ValueError('The sigma bar parameter needs to be in [0, 1]') + raise ValueError("The sigma bar parameter needs to be in [0, 1]") self._beta = self.step_size or beta_param self._sigma_bar = sigma_bar @@ -944,18 +935,18 @@ def _update(self): """ # Step 4 from alg. 3 self._grad.get_grad(self._x_old) - #self._u_new = self._x_old - self._beta * self._grad.grad - self._u_new = -self._beta * self._grad.grad + # self._u_new = self._x_old - self._beta * self._grad.grad + self._u_new = -self._beta * self._grad.grad self._u_new += self._x_old # Step 5 from alg. 3 - self._t_new = 0.5 * (1 + self.xp.sqrt(1 + 4 * self._t_old ** 2)) + self._t_new = 0.5 * (1 + self.xp.sqrt(1 + 4 * self._t_old**2)) # Step 6 from alg. 3 t_shifted_ratio = (self._t_old - 1) / self._t_new sigma_t_ratio = self._sigma * self._t_old / self._t_new beta_xi_t_shifted_ratio = t_shifted_ratio * self._beta / self._xi - self._z = - beta_xi_t_shifted_ratio * (self._x_old - self._z) + self._z = -beta_xi_t_shifted_ratio * (self._x_old - self._z) self._z += self._u_new self._z += t_shifted_ratio * (self._u_new - self._u_old) self._z += sigma_t_ratio * (self._u_new - self._x_old) @@ -968,20 +959,18 @@ def _update(self): # Restarting and gamma-Decreasing # Step 9 from alg. 3 - #self._g_new = self._grad.grad - (self._x_new - self._z) / self._xi - self._g_new = (self._z - self._x_new) + # self._g_new = self._grad.grad - (self._x_new - self._z) / self._xi + self._g_new = self._z - self._x_new self._g_new /= self._xi self._g_new += self._grad.grad # Step 10 from alg 3. - #self._y_new = self._x_old - self._beta * self._g_new - self._y_new = - self._beta * self._g_new + # self._y_new = self._x_old - self._beta * self._g_new + self._y_new = -self._beta * self._g_new self._y_new += self._x_old # Step 11 from alg. 3 - restart_crit = ( - self.xp.vdot(-self._g_new, self._y_new - self._y_old) < 0 - ) + restart_crit = self.xp.vdot(-self._g_new, self._y_new - self._y_old) < 0 if restart_crit: self._t_new = 1 self._sigma = 1 @@ -999,9 +988,8 @@ def _update(self): # Test cost function for convergence. if self._cost_func: - self.converge = ( - self.any_convergence_flag() - or self._cost_func.get_cost(self._x_new) + self.converge = self.any_convergence_flag() or self._cost_func.get_cost( + self._x_new ) def iterate(self, max_iter=150, progbar=None): @@ -1037,14 +1025,14 @@ def get_notify_observers_kwargs(self): """ return { - 'u_new': self._u_new, - 'x_new': self._linear.adj_op(self._x_new), - 'y_new': self._y_new, - 'z_new': self._z, - 'xi': self._xi, - 'sigma': self._sigma, - 't': self._t_new, - 'idx': self.idx, + "u_new": self._u_new, + "x_new": self._linear.adj_op(self._x_new), + "y_new": self._y_new, + "z_new": self._z, + "xi": self._xi, + "sigma": self._sigma, + "t": self._t_new, + "idx": self.idx, } def retrieve_outputs(self): @@ -1055,6 +1043,6 @@ def retrieve_outputs(self): """ metrics = {} - for obs in self._observers['cv_metrics']: + for obs in self._observers["cv_metrics"]: metrics[obs.name] = obs.retrieve_metrics() self.metrics = metrics diff --git a/modopt/opt/algorithms/gradient_descent.py b/src/modopt/opt/algorithms/gradient_descent.py similarity index 95% rename from modopt/opt/algorithms/gradient_descent.py rename to src/modopt/opt/algorithms/gradient_descent.py index f3fe4b10..0960be5a 100644 --- a/modopt/opt/algorithms/gradient_descent.py +++ b/src/modopt/opt/algorithms/gradient_descent.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Gradient Descent Algorithms.""" import numpy as np @@ -103,7 +102,7 @@ def __init__( self._check_operator(operator) self._grad = grad self._prox = prox - if cost == 'auto': + if cost == "auto": self._cost_func = costObj([self._grad, self._prox]) else: self._cost_func = cost @@ -157,9 +156,8 @@ def _update(self): self._eta = self._eta_update(self._eta, self.idx) # Test cost function for convergence. if self._cost_func: - self.converge = ( - self.any_convergence_flag() - or self._cost_func.get_cost(self._x_new) + self.converge = self.any_convergence_flag() or self._cost_func.get_cost( + self._x_new ) def _update_grad_dir(self, grad): @@ -208,10 +206,10 @@ def get_notify_observers_kwargs(self): """ return { - 'x_new': self._x_new, - 'dir_grad': self._dir_grad, - 'speed_grad': self._speed_grad, - 'idx': self.idx, + "x_new": self._x_new, + "dir_grad": self._dir_grad, + "speed_grad": self._speed_grad, + "idx": self.idx, } def retrieve_outputs(self): @@ -222,7 +220,7 @@ def retrieve_outputs(self): """ metrics = {} - for obs in self._observers['cv_metrics']: + for obs in self._observers["cv_metrics"]: metrics[obs.name] = obs.retrieve_metrics() self.metrics = metrics @@ -308,7 +306,7 @@ class RMSpropGradOpt(GenericGradOpt): def __init__(self, *args, gamma=0.5, **kwargs): super().__init__(*args, **kwargs) if gamma < 0 or gamma > 1: - raise ValueError('gamma is outside of range [0,1]') + raise ValueError("gamma is outside of range [0,1]") self._check_param(gamma) self._gamma = gamma @@ -405,9 +403,9 @@ def __init__(self, *args, gamma=0.9, beta=0.9, **kwargs): self._check_param(gamma) self._check_param(beta) if gamma < 0 or gamma >= 1: - raise ValueError('gamma is outside of range [0,1]') + raise ValueError("gamma is outside of range [0,1]") if beta < 0 or beta >= 1: - raise ValueError('beta is outside of range [0,1]') + raise ValueError("beta is outside of range [0,1]") self._gamma = gamma self._beta = beta self._beta_pow = 1 diff --git a/modopt/opt/algorithms/primal_dual.py b/src/modopt/opt/algorithms/primal_dual.py similarity index 89% rename from modopt/opt/algorithms/primal_dual.py rename to src/modopt/opt/algorithms/primal_dual.py index d5bdd431..fee49a25 100644 --- a/modopt/opt/algorithms/primal_dual.py +++ b/src/modopt/opt/algorithms/primal_dual.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Primal-Dual Algorithms.""" from modopt.opt.algorithms.base import SetUp @@ -81,7 +80,7 @@ def __init__( prox, prox_dual, linear=None, - cost='auto', + cost="auto", reweight=None, rho=0.5, sigma=1.0, @@ -96,7 +95,6 @@ def __init__( metrics=None, **kwargs, ): - # Set default algorithm properties super().__init__( metric_call_period=metric_call_period, @@ -123,12 +121,14 @@ def __init__( self._linear = Identity() else: self._linear = linear - if cost == 'auto': - self._cost_func = costObj([ - self._grad, - self._prox, - self._prox_dual, - ]) + if cost == "auto": + self._cost_func = costObj( + [ + self._grad, + self._prox, + self._prox_dual, + ] + ) else: self._cost_func = cost @@ -187,22 +187,17 @@ def _update(self): self._grad.get_grad(self._x_old) x_prox = self._prox.op( - self._x_old - self._tau * self._grad.grad - self._tau - * self._linear.adj_op(self._y_old), + self._x_old + - self._tau * self._grad.grad + - self._tau * self._linear.adj_op(self._y_old), ) # Step 2 from eq.9. - y_temp = ( - self._y_old + self._sigma - * self._linear.op(2 * x_prox - self._x_old) - ) + y_temp = self._y_old + self._sigma * self._linear.op(2 * x_prox - self._x_old) - y_prox = ( - y_temp - self._sigma - * self._prox_dual.op( - y_temp / self._sigma, - extra_factor=(1.0 / self._sigma), - ) + y_prox = y_temp - self._sigma * self._prox_dual.op( + y_temp / self._sigma, + extra_factor=(1.0 / self._sigma), ) # Step 3 from eq.9. @@ -220,9 +215,8 @@ def _update(self): # Test cost function for convergence. if self._cost_func: - self.converge = ( - self.any_convergence_flag() - or self._cost_func.get_cost(self._x_new, self._y_new) + self.converge = self.any_convergence_flag() or self._cost_func.get_cost( + self._x_new, self._y_new ) def iterate(self, max_iter=150, n_rewightings=1, progbar=None): @@ -267,7 +261,7 @@ def get_notify_observers_kwargs(self): The mapping between the iterated variables """ - return {'x_new': self._x_new, 'y_new': self._y_new, 'idx': self.idx} + return {"x_new": self._x_new, "y_new": self._y_new, "idx": self.idx} def retrieve_outputs(self): """Retrieve outputs. @@ -277,6 +271,6 @@ def retrieve_outputs(self): """ metrics = {} - for obs in self._observers['cv_metrics']: + for obs in self._observers["cv_metrics"]: metrics[obs.name] = obs.retrieve_metrics() self.metrics = metrics diff --git a/modopt/opt/cost.py b/src/modopt/opt/cost.py similarity index 91% rename from modopt/opt/cost.py rename to src/modopt/opt/cost.py index 688a3959..37771f16 100644 --- a/modopt/opt/cost.py +++ b/src/modopt/opt/cost.py @@ -81,7 +81,6 @@ def __init__( verbose=True, plot_output=None, ): - self.cost = initial_cost self._cost_list = [] self._cost_interval = cost_interval @@ -112,20 +111,19 @@ def _check_cost(self): # Check if enough cost values have been collected if len(self._test_list) == self._test_range: - # The mean of the first half of the test list t1 = xp.mean( - xp.array(self._test_list[len(self._test_list) // 2:]), + xp.array(self._test_list[len(self._test_list) // 2 :]), axis=0, ) # The mean of the second half of the test list t2 = xp.mean( - xp.array(self._test_list[:len(self._test_list) // 2]), + xp.array(self._test_list[: len(self._test_list) // 2]), axis=0, ) # Calculate the change across the test list if xp.around(t1, decimals=16): - cost_diff = (xp.linalg.norm(t1 - t2) / xp.linalg.norm(t1)) + cost_diff = xp.linalg.norm(t1 - t2) / xp.linalg.norm(t1) else: cost_diff = 0 @@ -133,9 +131,9 @@ def _check_cost(self): self._test_list = [] if self._verbose: - print(' - CONVERGENCE TEST - ') - print(' - CHANGE IN COST:', cost_diff) - print('') + print(" - CONVERGENCE TEST - ") + print(" - CHANGE IN COST:", cost_diff) + print("") # Check for convergence return cost_diff <= self._tolerance @@ -176,8 +174,7 @@ def get_cost(self, *args, **kwargs): """ # Check if the cost should be calculated test_conditions = ( - self._cost_interval is None - or self._iteration % self._cost_interval + self._cost_interval is None or self._iteration % self._cost_interval ) if test_conditions: @@ -185,15 +182,15 @@ def get_cost(self, *args, **kwargs): else: if self._verbose: - print(' - ITERATION:', self._iteration) + print(" - ITERATION:", self._iteration) # Calculate the current cost - self.cost = self._calc_cost(verbose=self._verbose, *args, **kwargs) + self.cost = self._calc_cost(*args, verbose=self._verbose, **kwargs) self._cost_list.append(self.cost) if self._verbose: - print(' - COST:', self.cost) - print('') + print(" - COST:", self.cost) + print("") # Test for convergence test_result = self._check_cost() @@ -288,13 +285,11 @@ def _check_operators(self): """ if not isinstance(self._operators, (list, tuple, np.ndarray)): - message = ( - 'Input operators must be provided as a list, not {0}' - ) + message = "Input operators must be provided as a list, not {0}" raise TypeError(message.format(type(self._operators))) for op in self._operators: - if not hasattr(op, 'cost'): + if not hasattr(op, "cost"): raise ValueError('Operators must contain "cost" method.') op.cost = check_callable(op.cost) diff --git a/modopt/opt/gradient.py b/src/modopt/opt/gradient.py similarity index 97% rename from modopt/opt/gradient.py rename to src/modopt/opt/gradient.py index caa8fa9d..fe9b87d8 100644 --- a/modopt/opt/gradient.py +++ b/src/modopt/opt/gradient.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """GRADIENT CLASSES. This module contains classses for defining algorithm gradients. @@ -14,7 +12,7 @@ from modopt.base.types import check_callable, check_float, check_npndarray -class GradParent(object): +class GradParent: """Gradient Parent Class. This class defines the basic methods that will be inherited by specific @@ -71,7 +69,6 @@ def __init__( input_data_writeable=False, verbose=True, ): - self.verbose = verbose self._input_data_writeable = input_data_writeable self._grad_data_type = data_type @@ -100,7 +97,6 @@ def obs_data(self): @obs_data.setter def obs_data(self, input_data): - if self._grad_data_type in {float, np.floating}: input_data = check_float(input_data) check_npndarray( @@ -128,7 +124,6 @@ def op(self): @op.setter def op(self, operator): - self._op = check_callable(operator) @property @@ -147,7 +142,6 @@ def trans_op(self): @trans_op.setter def trans_op(self, operator): - self._trans_op = check_callable(operator) @property @@ -157,7 +151,6 @@ def get_grad(self): @get_grad.setter def get_grad(self, method): - self._get_grad = check_callable(method) @property @@ -167,7 +160,6 @@ def grad(self): @grad.setter def grad(self, input_value): - if self._grad_data_type in {float, np.floating}: input_value = check_float(input_value) self._grad = input_value @@ -179,7 +171,6 @@ def cost(self): @cost.setter def cost(self, method): - self._cost = check_callable(method) def trans_op_op(self, input_data): @@ -243,7 +234,6 @@ class GradBasic(GradParent): """ def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) self.get_grad = self._get_grad_method self.cost = self._cost_method @@ -289,7 +279,7 @@ def _cost_method(self, *args, **kwargs): """ cost_val = 0.5 * np.linalg.norm(self.obs_data - self.op(args[0])) ** 2 - if 'verbose' in kwargs and kwargs['verbose']: - print(' - DATA FIDELITY (X):', cost_val) + if kwargs.get("verbose"): + print(" - DATA FIDELITY (X):", cost_val) return cost_val diff --git a/modopt/opt/linear/__init__.py b/src/modopt/opt/linear/__init__.py similarity index 100% rename from modopt/opt/linear/__init__.py rename to src/modopt/opt/linear/__init__.py diff --git a/modopt/opt/linear/base.py b/src/modopt/opt/linear/base.py similarity index 92% rename from modopt/opt/linear/base.py rename to src/modopt/opt/linear/base.py index e347970d..af748a73 100644 --- a/modopt/opt/linear/base.py +++ b/src/modopt/opt/linear/base.py @@ -5,7 +5,8 @@ from modopt.base.types import check_callable from modopt.base.backend import get_array_module -class LinearParent(object): + +class LinearParent: """Linear Operator Parent Class. This class sets the structure for defining linear operator instances. @@ -29,7 +30,6 @@ class LinearParent(object): """ def __init__(self, op, adj_op): - self.op = op self.adj_op = adj_op @@ -40,7 +40,6 @@ def op(self): @op.setter def op(self, operator): - self._op = check_callable(operator) @property @@ -50,7 +49,6 @@ def adj_op(self): @adj_op.setter def adj_op(self, operator): - self._adj_op = check_callable(operator) @@ -66,10 +64,9 @@ class Identity(LinearParent): """ def __init__(self): - self.op = lambda input_data: input_data self.adj_op = self.op - self.cost= lambda *args, **kwargs: 0 + self.cost = lambda *args, **kwargs: 0 class MatrixOperator(LinearParent): @@ -126,7 +123,6 @@ class LinearCombo(LinearParent): """ def __init__(self, operators, weights=None): - operators, weights = self._check_inputs(operators, weights) self.operators = operators self.weights = weights @@ -159,14 +155,13 @@ def _check_type(self, input_val): """ if not isinstance(input_val, (list, tuple, np.ndarray)): raise TypeError( - 'Invalid input type, input must be a list, tuple or numpy ' - + 'array.', + "Invalid input type, input must be a list, tuple or numpy " + "array.", ) input_val = np.array(input_val) if not input_val.size: - raise ValueError('Input list is empty.') + raise ValueError("Input list is empty.") return input_val @@ -199,11 +194,10 @@ def _check_inputs(self, operators, weights): operators = self._check_type(operators) for operator in operators: - - if not hasattr(operator, 'op'): + if not hasattr(operator, "op"): raise ValueError('Operators must contain "op" method.') - if not hasattr(operator, 'adj_op'): + if not hasattr(operator, "adj_op"): raise ValueError('Operators must contain "adj_op" method.') operator.op = check_callable(operator.op) @@ -214,12 +208,11 @@ def _check_inputs(self, operators, weights): if weights.size != operators.size: raise ValueError( - 'The number of weights must match the number of ' - + 'operators.', + "The number of weights must match the number of " + "operators.", ) if not np.issubdtype(weights.dtype, np.floating): - raise TypeError('The weights must be a list of float values.') + raise TypeError("The weights must be a list of float values.") return operators, weights diff --git a/modopt/opt/linear/wavelet.py b/src/modopt/opt/linear/wavelet.py similarity index 91% rename from modopt/opt/linear/wavelet.py rename to src/modopt/opt/linear/wavelet.py index 5feead66..8012a072 100644 --- a/modopt/opt/linear/wavelet.py +++ b/src/modopt/opt/linear/wavelet.py @@ -45,8 +45,7 @@ class WaveletConvolve(LinearParent): """ - def __init__(self, filters, method='scipy'): - + def __init__(self, filters, method="scipy"): self._filters = check_float(filters) self.op = lambda input_data: filter_convolve_stack( input_data, @@ -61,13 +60,11 @@ def __init__(self, filters, method='scipy'): ) - - class WaveletTransform(LinearParent): """ 2D and 3D wavelet transform class. - This is a wrapper around either Pywavelet (CPU) or Pytorch Wavelet (GPU using Pytorch). + This is a wrapper around either Pywavelet (CPU) or Pytorch Wavelet (GPU). Parameters ---------- @@ -81,33 +78,41 @@ class WaveletTransform(LinearParent): mode: str, default "zero" Boundary Condition mode compute_backend: str, "numpy" or "cupy", default "numpy" - Backend library to use. "cupy" also requires a working installation of PyTorch and pytorch wavelets. + Backend library to use. "cupy" also requires a working installation of PyTorch + and PyTorch wavelets (ptwt). **kwargs: extra kwargs for Pywavelet or Pytorch Wavelet """ - def __init__(self, + + def __init__( + self, wavelet_name, shape, level=4, mode="symmetric", compute_backend="numpy", - **kwargs): - + **kwargs, + ): if compute_backend == "cupy" and ptwt_available: - self.operator = CupyWaveletTransform(wavelet=wavelet_name, shape=shape, level=level, mode=mode) + self.operator = CupyWaveletTransform( + wavelet=wavelet_name, shape=shape, level=level, mode=mode + ) elif compute_backend == "numpy" and pywt_available: - self.operator = CPUWaveletTransform(wavelet_name=wavelet_name, shape=shape, level=level, **kwargs) + self.operator = CPUWaveletTransform( + wavelet_name=wavelet_name, shape=shape, level=level, **kwargs + ) else: raise ValueError(f"Compute Backend {compute_backend} not available") - self.op = self.operator.op self.adj_op = self.operator.adj_op @property def coeffs_shape(self): + """Get the coeffs shapes.""" return self.operator.coeffs_shape + class CPUWaveletTransform(LinearParent): """ 2D and 3D wavelet transform class. @@ -159,7 +164,8 @@ def __init__( ): if not pywt_available: raise ImportError( - "PyWavelet and/or joblib are not available. Please install it to use WaveletTransform." + "PyWavelet and/or joblib are not available. " + "Please install it to use WaveletTransform." ) if wavelet_name not in pywt.wavelist(kind="all"): raise ValueError( @@ -193,7 +199,9 @@ def __init__( self.n_batch = n_batch if self.n_batch == 1 and self.n_jobs != 1: - warnings.warn("Making n_jobs = 1 for WaveletTransform as n_batchs = 1") + warnings.warn( + "Making n_jobs = 1 for WaveletTransform as n_batchs = 1", stacklevel=1 + ) self.n_jobs = 1 self.backend = backend n_proc = self.n_jobs @@ -273,22 +281,22 @@ def _adj_op(self, coeffs): class TorchWaveletTransform: """Wavelet transform using pytorch.""" - wavedec3_keys = ["aad", "ada", "add", "daa", "dad", "dda", "ddd"] + wavedec3_keys = ("aad", "ada", "add", "daa", "dad", "dda", "ddd") def __init__( self, - shape: tuple[int, ...], - wavelet: str, - level: int, - mode: str, + shape, + wavelet, + level, + mode, ): self.wavelet = wavelet self.level = level self.shape = shape self.mode = mode - self.coeffs_shape = None # will be set after op. + self.coeffs_shape = None # will be set after op. - def op(self, data: torch.Tensor) -> list[torch.Tensor]: + def op(self, data): """Apply the wavelet decomposition on. Parameters @@ -350,7 +358,7 @@ def op(self, data: torch.Tensor) -> list[torch.Tensor]: ) return [coeffs_[0]] + [cc for c in coeffs_[1:] for cc in c.values()] - def adj_op(self, coeffs: list[torch.Tensor]) -> torch.Tensor: + def adj_op(self, coeffs): """Apply the wavelet recomposition. Parameters @@ -409,20 +417,22 @@ class CupyWaveletTransform(LinearParent): def __init__( self, - shape: tuple[int, ...], - wavelet: str, - level: int, - mode: str, + shape, + wavelet, + level, + mode, ): self.wavelet = wavelet self.level = level self.shape = shape self.mode = mode - self.operator = TorchWaveletTransform(shape=shape, wavelet=wavelet, level=level,mode=mode) - self.coeffs_shape = None # will be set after op + self.operator = TorchWaveletTransform( + shape=shape, wavelet=wavelet, level=level, mode=mode + ) + self.coeffs_shape = None # will be set after op - def op(self, data: cp.array) -> cp.ndarray: + def op(self, data): """Define the wavelet operator. This method returns the input data convolved with the wavelet filter. @@ -452,7 +462,7 @@ def op(self, data: cp.array) -> cp.ndarray: return ret - def adj_op(self, data: cp.ndarray) -> cp.ndarray: + def adj_op(self, data): """Define the wavelet adjoint operator. This method returns the reconstructed image. diff --git a/modopt/opt/proximity.py b/src/modopt/opt/proximity.py similarity index 89% rename from modopt/opt/proximity.py rename to src/modopt/opt/proximity.py index fc81a753..dea862ca 100644 --- a/modopt/opt/proximity.py +++ b/src/modopt/opt/proximity.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """PROXIMITY OPERATORS. This module contains classes of proximity operators for optimisation. @@ -32,7 +30,7 @@ from modopt.signal.svd import svd_thresh, svd_thresh_coef, svd_thresh_coef_fast -class ProximityParent(object): +class ProximityParent: """Proximity Operator Parent Class. This class sets the structure for defining proximity operator instances. @@ -48,7 +46,6 @@ class ProximityParent(object): """ def __init__(self, op, cost): - self.op = op self.cost = cost @@ -59,7 +56,6 @@ def op(self): @op.setter def op(self, operator): - self._op = check_callable(operator) @property @@ -79,7 +75,6 @@ def cost(self): @cost.setter def cost(self, method): - self._cost = check_callable(method) @@ -99,9 +94,8 @@ class IdentityProx(ProximityParent): """ def __init__(self): - - self.op = lambda x_val: x_val - self.cost = lambda x_val: 0 + self.op = lambda x_val, *args, **kwargs: x_val + self.cost = lambda x_val, *args, **kwargs: 0 class Positivity(ProximityParent): @@ -117,10 +111,25 @@ class Positivity(ProximityParent): """ def __init__(self): - - self.op = lambda input_data: positive(input_data) self.cost = self._cost_method + def op(self, input_data, *args, **kwargs): + """ + Make the data positive. + + Parameters + ---------- + input_data: np.ndarray + Input array + *args, **kwargs: dummy. + + Returns + ------- + np.ndarray + Positive data. + """ + return positive(input_data) + def _cost_method(self, *args, **kwargs): """Calculate positivity component of the cost. @@ -140,8 +149,8 @@ def _cost_method(self, *args, **kwargs): ``0.0`` """ - if 'verbose' in kwargs and kwargs['verbose']: - print(' - Min (X):', np.min(args[0])) + if kwargs.get("verbose"): + print(" - Min (X):", np.min(args[0])) return 0 @@ -167,8 +176,7 @@ class SparseThreshold(ProximityParent): """ - def __init__(self, linear, weights, thresh_type='soft'): - + def __init__(self, linear, weights, thresh_type="soft"): self._linear = linear self.weights = weights self._thresh_type = thresh_type @@ -221,8 +229,8 @@ def _cost_method(self, *args, **kwargs): if isinstance(cost_val, xp.ndarray): cost_val = cost_val.item() - if 'verbose' in kwargs and kwargs['verbose']: - print(' - L1 NORM (X):', cost_val) + if kwargs.get("verbose"): + print(" - L1 NORM (X):", cost_val) return cost_val @@ -273,12 +281,11 @@ class LowRankMatrix(ProximityParent): def __init__( self, threshold, - thresh_type='soft', - lowr_type='standard', + thresh_type="soft", + lowr_type="standard", initial_rank=None, operator=None, ): - self.thresh = threshold self.thresh_type = thresh_type self.lowr_type = lowr_type @@ -315,13 +322,13 @@ def _op_method(self, input_data, extra_factor=1.0, rank=None): """ # Update threshold with extra factor. threshold = self.thresh * extra_factor - if self.lowr_type == 'standard' and self.rank is None and rank is None: + if self.lowr_type == "standard" and self.rank is None and rank is None: data_matrix = svd_thresh( cube2matrix(input_data), threshold, thresh_type=self.thresh_type, ) - elif self.lowr_type == 'standard': + elif self.lowr_type == "standard": data_matrix, update_rank = svd_thresh_coef_fast( cube2matrix(input_data), threshold, @@ -331,7 +338,7 @@ def _op_method(self, input_data, extra_factor=1.0, rank=None): ) self.rank = update_rank # save for future use - elif self.lowr_type == 'ngole': + elif self.lowr_type == "ngole": data_matrix = svd_thresh_coef( cube2matrix(input_data), self.operator, @@ -339,7 +346,7 @@ def _op_method(self, input_data, extra_factor=1.0, rank=None): thresh_type=self.thresh_type, ) else: - raise ValueError('lowr_type should be standard or ngole') + raise ValueError("lowr_type should be standard or ngole") # Return updated data. return matrix2cube(data_matrix, input_data.shape[1:]) @@ -365,8 +372,8 @@ def _cost_method(self, *args, **kwargs): """ cost_val = self.thresh * nuclear_norm(cube2matrix(args[0])) - if 'verbose' in kwargs and kwargs['verbose']: - print(' - NUCLEAR NORM (X):', cost_val) + if kwargs.get("verbose"): + print(" - NUCLEAR NORM (X):", cost_val) return cost_val @@ -470,7 +477,6 @@ class ProximityCombo(ProximityParent): """ def __init__(self, operators): - operators = self._check_operators(operators) self.operators = operators self.op = self._op_method @@ -506,19 +512,19 @@ def _check_operators(self, operators): """ if not isinstance(operators, (list, tuple, np.ndarray)): raise TypeError( - 'Invalid input type, operators must be a list, tuple or ' - + 'numpy array.', + "Invalid input type, operators must be a list, tuple or " + + "numpy array.", ) operators = np.array(operators) if not operators.size: - raise ValueError('Operator list is empty.') + raise ValueError("Operator list is empty.") for operator in operators: - if not hasattr(operator, 'op'): + if not hasattr(operator, "op"): raise ValueError('Operators must contain "op" method.') - if not hasattr(operator, 'cost'): + if not hasattr(operator, "cost"): raise ValueError('Operators must contain "cost" method.') operator.op = check_callable(operator.op) operator.cost = check_callable(operator.cost) @@ -573,10 +579,12 @@ def _cost_method(self, *args, **kwargs): Combinded cost components """ - return np.sum([ - operator.cost(input_data) - for operator, input_data in zip(self.operators, args[0]) - ]) + return np.sum( + [ + operator.cost(input_data) + for operator, input_data in zip(self.operators, args[0]) + ] + ) class OrderedWeightedL1Norm(ProximityParent): @@ -617,16 +625,16 @@ class OrderedWeightedL1Norm(ProximityParent): def __init__(self, weights): if not import_sklearn: # pragma: no cover raise ImportError( - 'Required version of Scikit-Learn package not found see ' - + 'documentation for details: ' - + 'https://cea-cosmic.github.io/ModOpt/#optional-packages', + "Required version of Scikit-Learn package not found see " + + "documentation for details: " + + "https://cea-cosmic.github.io/ModOpt/#optional-packages", ) if np.max(np.diff(weights)) > 0: - raise ValueError('Weights must be non increasing') + raise ValueError("Weights must be non increasing") self.weights = weights.flatten() if (self.weights < 0).any(): raise ValueError( - 'The weight values must be provided in descending order', + "The weight values must be provided in descending order", ) self.op = self._op_method self.cost = self._cost_method @@ -664,7 +672,9 @@ def _op_method(self, input_data, extra_factor=1.0): # Projection onto the monotone non-negative cone using # isotonic_regression data_abs = isotonic_regression( - data_abs - threshold, y_min=0, increasing=False, + data_abs - threshold, + y_min=0, + increasing=False, ) # Unsorting the data @@ -672,7 +682,7 @@ def _op_method(self, input_data, extra_factor=1.0): data_abs_unsorted[data_abs_sort_idx] = data_abs # Putting the sign back - with np.errstate(invalid='ignore'): + with np.errstate(invalid="ignore"): sign_data = data_squeezed / np.abs(data_squeezed) # Removing NAN caused by the sign @@ -702,8 +712,8 @@ def _cost_method(self, *args, **kwargs): self.weights * np.sort(np.squeeze(np.abs(args[0]))[::-1]), ) - if 'verbose' in kwargs and kwargs['verbose']: - print(' - OWL NORM (X):', cost_val) + if kwargs.get("verbose"): + print(" - OWL NORM (X):", cost_val) return cost_val @@ -734,8 +744,7 @@ class Ridge(ProximityParent): """ - def __init__(self, linear, weights, thresh_type='soft'): - + def __init__(self, linear, weights, thresh_type="soft"): self._linear = linear self.weights = weights self.op = self._op_method @@ -786,8 +795,8 @@ def _cost_method(self, *args, **kwargs): np.abs(self.weights * self._linear.op(args[0]) ** 2), ) - if 'verbose' in kwargs and kwargs['verbose']: - print(' - L2 NORM (X):', cost_val) + if kwargs.get("verbose"): + print(" - L2 NORM (X):", cost_val) return cost_val @@ -822,7 +831,6 @@ class ElasticNet(ProximityParent): """ def __init__(self, linear, alpha, beta): - self._linear = linear self.alpha = alpha self.beta = beta @@ -848,8 +856,8 @@ def _op_method(self, input_data, extra_factor=1.0): """ soft_threshold = self.beta * extra_factor - normalization = (self.alpha * 2 * extra_factor + 1) - return thresh(input_data, soft_threshold, 'soft') / normalization + normalization = self.alpha * 2 * extra_factor + 1 + return thresh(input_data, soft_threshold, "soft") / normalization def _cost_method(self, *args, **kwargs): """Calculate Ridge component of the cost. @@ -875,8 +883,8 @@ def _cost_method(self, *args, **kwargs): + np.abs(self.beta * self._linear.op(args[0])), ) - if 'verbose' in kwargs and kwargs['verbose']: - print(' - ELASTIC NET (X):', cost_val) + if kwargs.get("verbose"): + print(" - ELASTIC NET (X):", cost_val) return cost_val @@ -942,7 +950,7 @@ def k_value(self): def k_value(self, k_val): if k_val < 1: raise ValueError( - 'The k parameter should be greater or equal than 1', + "The k parameter should be greater or equal than 1", ) self._k_value = k_val @@ -987,7 +995,7 @@ def _compute_theta(self, input_data, alpha, extra_factor=1.0): alpha_beta = alpha_input - self.beta * extra_factor theta = alpha_beta * ((alpha_beta <= 1) & (alpha_beta >= 0)) theta = np.nan_to_num(theta) - theta += (alpha_input > (self.beta * extra_factor + 1)) + theta += alpha_input > (self.beta * extra_factor + 1) return theta def _interpolate(self, alpha0, alpha1, sum0, sum1): @@ -1078,12 +1086,10 @@ def _binary_search(self, input_data, alpha, extra_factor=1.0): midpoint = 0 while (first_idx <= last_idx) and not found and (cnt < alpha.shape[0]): - midpoint = (first_idx + last_idx) // 2 cnt += 1 if prev_midpoint == midpoint: - # Particular case sum0 = self._compute_theta( data_abs, @@ -1096,11 +1102,11 @@ def _binary_search(self, input_data, alpha, extra_factor=1.0): extra_factor, ).sum() - if (np.abs(sum0 - self._k_value) <= tolerance): + if np.abs(sum0 - self._k_value) <= tolerance: found = True midpoint = first_idx - if (np.abs(sum1 - self._k_value) <= tolerance): + if np.abs(sum1 - self._k_value) <= tolerance: found = True midpoint = last_idx - 1 # -1 because output is index such that @@ -1145,13 +1151,17 @@ def _binary_search(self, input_data, alpha, extra_factor=1.0): if found: return ( - midpoint, alpha[midpoint], alpha[midpoint + 1], sum0, sum1, + midpoint, + alpha[midpoint], + alpha[midpoint + 1], + sum0, + sum1, ) raise ValueError( - 'Cannot find the coordinate of alpha (i) such ' - + 'that sum(theta(alpha[i])) =< k and ' - + 'sum(theta(alpha[i+1])) >= k ', + "Cannot find the coordinate of alpha (i) such " + + "that sum(theta(alpha[i])) =< k and " + + "sum(theta(alpha[i+1])) >= k ", ) def _find_alpha(self, input_data, extra_factor=1.0): @@ -1175,15 +1185,13 @@ def _find_alpha(self, input_data, extra_factor=1.0): data_size = input_data.shape[0] # Computes the alpha^i points line 1 in Algorithm 1. - alpha = np.zeros((data_size * 2)) + alpha = np.zeros(data_size * 2) data_abs = np.abs(input_data) - alpha[:data_size] = ( - (self.beta * extra_factor) - / (data_abs + sys.float_info.epsilon) + alpha[:data_size] = (self.beta * extra_factor) / ( + data_abs + sys.float_info.epsilon ) - alpha[data_size:] = ( - (self.beta * extra_factor + 1) - / (data_abs + sys.float_info.epsilon) + alpha[data_size:] = (self.beta * extra_factor + 1) / ( + data_abs + sys.float_info.epsilon ) alpha = np.sort(np.unique(alpha)) @@ -1220,8 +1228,8 @@ def _op_method(self, input_data, extra_factor=1.0): k_max = np.prod(data_shape) if self._k_value > k_max: warn( - 'K value of the K-support norm is greater than the input ' - + 'dimension, its value will be set to {0}'.format(k_max), + "K value of the K-support norm is greater than the input " + + f"dimension, its value will be set to {k_max}", ) self._k_value = k_max @@ -1233,8 +1241,7 @@ def _op_method(self, input_data, extra_factor=1.0): # Computes line 5. in Algorithm 1. rslt = np.nan_to_num( - (input_data.flatten() * theta) - / (theta + self.beta * extra_factor), + (input_data.flatten() * theta) / (theta + self.beta * extra_factor), ) return rslt.reshape(data_shape) @@ -1275,25 +1282,20 @@ def _find_q(self, sorted_data): found = True q_val = 0 - elif ( - (sorted_data[self._k_value - 1:].sum()) - <= sorted_data[self._k_value - 1] - ): + elif (sorted_data[self._k_value - 1 :].sum()) <= sorted_data[self._k_value - 1]: found = True q_val = self._k_value - 1 while ( - not found and not cnt == self._k_value + not found + and not cnt == self._k_value and (first_idx <= last_idx < self._k_value) ): - q_val = (first_idx + last_idx) // 2 cnt += 1 l1_part = sorted_data[q_val:].sum() / (self._k_value - q_val) - if ( - sorted_data[q_val + 1] <= l1_part <= sorted_data[q_val] - ): + if sorted_data[q_val + 1] <= l1_part <= sorted_data[q_val]: found = True else: @@ -1328,15 +1330,12 @@ def _cost_method(self, *args, **kwargs): data_abs = data_abs[ix] # Sorted absolute value of the data q_val = self._find_q(data_abs) cost_val = ( - ( - np.sum(data_abs[:q_val] ** 2) * 0.5 - + np.sum(data_abs[q_val:]) ** 2 - / (self._k_value - q_val) - ) * self.beta - ) + np.sum(data_abs[:q_val] ** 2) * 0.5 + + np.sum(data_abs[q_val:]) ** 2 / (self._k_value - q_val) + ) * self.beta - if 'verbose' in kwargs and kwargs['verbose']: - print(' - K-SUPPORT NORM (X):', cost_val) + if kwargs.get("verbose"): + print(" - K-SUPPORT NORM (X):", cost_val) return cost_val @@ -1381,7 +1380,7 @@ def __init__(self, weights): self.op = self._op_method self.cost = self._cost_method - def _op_method(self, input_data, extra_factor=1.0): + def _op_method(self, input_data, *args, extra_factor=1.0, **kwargs): """Operator. This method returns the input data thresholded by the weights. @@ -1392,6 +1391,7 @@ def _op_method(self, input_data, extra_factor=1.0): Input data array extra_factor : float Additional multiplication factor (default is ``1.0``) + *args, **kwargs: no effects Returns ------- @@ -1407,7 +1407,7 @@ def _op_method(self, input_data, extra_factor=1.0): (1.0 - self.weights * extra_factor / denominator), ) - def _cost_method(self, input_data): + def _cost_method(self, input_data, *args, **kwargs): """Calculate the group LASSO component of the cost. This method calculate the cost function of the proximable part. @@ -1417,6 +1417,8 @@ def _cost_method(self, input_data): input_data : numpy.ndarray Input array of the sparse code + *args, **kwargs: no effects. + Returns ------- float diff --git a/modopt/opt/reweight.py b/src/modopt/opt/reweight.py similarity index 91% rename from modopt/opt/reweight.py rename to src/modopt/opt/reweight.py index 8c4f2449..4a9bf44b 100644 --- a/modopt/opt/reweight.py +++ b/src/modopt/opt/reweight.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """REWEIGHTING CLASSES. This module contains classes for reweighting optimisation implementations. @@ -13,7 +11,7 @@ from modopt.base.types import check_float -class cwbReweight(object): +class cwbReweight: """Candes, Wakin and Boyd reweighting class. This class implements the reweighting scheme described in @@ -45,7 +43,6 @@ class cwbReweight(object): """ def __init__(self, weights, thresh_factor=1.0, verbose=False): - self.weights = check_float(weights) self.original_weights = np.copy(self.weights) self.thresh_factor = check_float(thresh_factor) @@ -81,7 +78,7 @@ def reweight(self, input_data): """ if self.verbose: - print(' - Reweighting: {0}'.format(self._rw_num)) + print(f" - Reweighting: {self._rw_num}") self._rw_num += 1 @@ -89,7 +86,7 @@ def reweight(self, input_data): if input_data.shape != self.weights.shape: raise ValueError( - 'Input data must have the same shape as the initial weights.', + "Input data must have the same shape as the initial weights.", ) thresh_weights = self.thresh_factor * self.original_weights diff --git a/modopt/plot/__init__.py b/src/modopt/plot/__init__.py similarity index 73% rename from modopt/plot/__init__.py rename to src/modopt/plot/__init__.py index 28d60be6..f31ed596 100644 --- a/modopt/plot/__init__.py +++ b/src/modopt/plot/__init__.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """PLOTTING ROUTINES. This module contains submodules for plotting applications. @@ -8,4 +6,4 @@ """ -__all__ = ['cost_plot'] +__all__ = ["cost_plot"] diff --git a/modopt/plot/cost_plot.py b/src/modopt/plot/cost_plot.py similarity index 66% rename from modopt/plot/cost_plot.py rename to src/modopt/plot/cost_plot.py index aa855eaa..7fb7e39b 100644 --- a/modopt/plot/cost_plot.py +++ b/src/modopt/plot/cost_plot.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """PLOTTING ROUTINES. This module contains methods for making plots. @@ -37,20 +35,20 @@ def plotCost(cost_list, output=None): """ if import_fail: - raise ImportError('Matplotlib package not found') + raise ImportError("Matplotlib package not found") else: if isinstance(output, type(None)): - file_name = 'cost_function.png' + file_name = "cost_function.png" else: - file_name = '{0}_cost_function.png'.format(output) + file_name = f"{output}_cost_function.png" plt.figure() - plt.plot(np.log10(cost_list), 'r-') - plt.title('Cost Function') - plt.xlabel('Iteration') - plt.ylabel(r'$\log_{10}$ Cost') + plt.plot(np.log10(cost_list), "r-") + plt.title("Cost Function") + plt.xlabel("Iteration") + plt.ylabel(r"$\log_{10}$ Cost") plt.savefig(file_name) plt.close() - print(' - Saving cost function data to:', file_name) + print(" - Saving cost function data to:", file_name) diff --git a/modopt/signal/__init__.py b/src/modopt/signal/__init__.py similarity index 58% rename from modopt/signal/__init__.py rename to src/modopt/signal/__init__.py index dbc6d053..6bf0912b 100644 --- a/modopt/signal/__init__.py +++ b/src/modopt/signal/__init__.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """SIGNAL PROCESSING ROUTINES. This module contains submodules for signal processing. @@ -8,4 +6,4 @@ """ -__all__ = ['filter', 'noise', 'positivity', 'svd', 'validation', 'wavelet'] +__all__ = ["filter", "noise", "positivity", "svd", "validation", "wavelet"] diff --git a/modopt/signal/filter.py b/src/modopt/signal/filter.py similarity index 96% rename from modopt/signal/filter.py rename to src/modopt/signal/filter.py index 84dd8160..33c3c105 100644 --- a/modopt/signal/filter.py +++ b/src/modopt/signal/filter.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """FILTER ROUTINES. This module contains methods for distance measurements in cosmology. @@ -81,7 +79,7 @@ def mex_hat(data_point, sigma): sigma = check_float(sigma) xs = (data_point / sigma) ** 2 - factor = 2 * (3 * sigma) ** -0.5 * np.pi ** -0.25 + factor = 2 * (3 * sigma) ** -0.5 * np.pi**-0.25 return factor * (1 - xs) * np.exp(-0.5 * xs) diff --git a/modopt/signal/noise.py b/src/modopt/signal/noise.py similarity index 86% rename from modopt/signal/noise.py rename to src/modopt/signal/noise.py index a59d5553..28307f52 100644 --- a/modopt/signal/noise.py +++ b/src/modopt/signal/noise.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """NOISE ROUTINES. This module contains methods for adding and removing noise from data. @@ -8,14 +6,12 @@ """ -from builtins import zip - import numpy as np from modopt.base.backend import get_array_module -def add_noise(input_data, sigma=1.0, noise_type='gauss'): +def add_noise(input_data, sigma=1.0, noise_type="gauss", rng=None): """Add noise to data. This method adds Gaussian or Poisson noise to the input data. @@ -29,6 +25,9 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'): default is ``1.0``) noise_type : {'gauss', 'poisson'} Type of noise to be added (default is ``'gauss'``) + rng: np.random.Generator or int + A Random number generator or a seed to initialize one. + Returns ------- @@ -68,9 +67,12 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'): array([ 3.24869073, -1.22351283, -1.0563435 , -2.14593724, 1.73081526]) """ + if not isinstance(rng, np.random.Generator): + rng = np.random.default_rng(rng) + input_data = np.array(input_data) - if noise_type not in {'gauss', 'poisson'}: + if noise_type not in {"gauss", "poisson"}: raise ValueError( 'Invalid noise type. Options are "gauss" or "poisson"', ) @@ -78,15 +80,14 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'): if isinstance(sigma, (list, tuple, np.ndarray)): if len(sigma) != input_data.shape[0]: raise ValueError( - 'Number of sigma values must match first dimension of input ' - + 'data', + "Number of sigma values must match first dimension of input " + "data", ) - if noise_type == 'gauss': - random = np.random.randn(*input_data.shape) + if noise_type == "gauss": + random = rng.standard_normal(input_data.shape) - elif noise_type == 'poisson': - random = np.random.poisson(np.abs(input_data)) + elif noise_type == "poisson": + random = rng.poisson(np.abs(input_data)) if isinstance(sigma, (int, float)): return input_data + sigma * random @@ -96,7 +97,7 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'): return input_data + noise -def thresh(input_data, threshold, threshold_type='hard'): +def thresh(input_data, threshold, threshold_type="hard"): r"""Threshold data. This method perfoms hard or soft thresholding on the input data. @@ -169,12 +170,12 @@ def thresh(input_data, threshold, threshold_type='hard'): input_data = xp.array(input_data) - if threshold_type not in {'hard', 'soft'}: + if threshold_type not in {"hard", "soft"}: raise ValueError( 'Invalid threshold type. Options are "hard" or "soft"', ) - if threshold_type == 'soft': + if threshold_type == "soft": denominator = xp.maximum(xp.finfo(np.float64).eps, xp.abs(input_data)) max_value = xp.maximum((1.0 - threshold / denominator), 0) diff --git a/modopt/signal/positivity.py b/src/modopt/signal/positivity.py similarity index 90% rename from modopt/signal/positivity.py rename to src/modopt/signal/positivity.py index c19ba62c..8d7aa46c 100644 --- a/modopt/signal/positivity.py +++ b/src/modopt/signal/positivity.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """POSITIVITY. This module contains a function that retains only positive coefficients in @@ -47,7 +45,7 @@ def pos_recursive(input_data): Positive coefficients """ - if input_data.dtype == 'O': + if input_data.dtype == "O": res = np.array([pos_recursive(elem) for elem in input_data], dtype="object") else: @@ -97,15 +95,15 @@ def positive(input_data, ragged=False): """ if not isinstance(input_data, (int, float, list, tuple, np.ndarray)): raise TypeError( - 'Invalid data type, input must be `int`, `float`, `list`, ' - + '`tuple` or `np.ndarray`.', + "Invalid data type, input must be `int`, `float`, `list`, " + + "`tuple` or `np.ndarray`.", ) if isinstance(input_data, (int, float)): return pos_thresh(input_data) if ragged: - input_data = np.array(input_data, dtype='object') + input_data = np.array(input_data, dtype="object") else: input_data = np.array(input_data) diff --git a/modopt/signal/svd.py b/src/modopt/signal/svd.py similarity index 85% rename from modopt/signal/svd.py rename to src/modopt/signal/svd.py index f3d40a51..cf147503 100644 --- a/modopt/signal/svd.py +++ b/src/modopt/signal/svd.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """SVD ROUTINES. This module contains methods for thresholding singular values. @@ -52,8 +50,8 @@ def find_n_pc(u_vec, factor=0.5): """ if np.sqrt(u_vec.shape[0]) % 1: raise ValueError( - 'Invalid left singular vector. The size of the first ' - + 'dimenion of ``u_vec`` must be perfect square.', + "Invalid left singular vector. The size of the first " + + "dimenion of ``u_vec`` must be perfect square.", ) # Get the shape of the array @@ -69,13 +67,12 @@ def find_n_pc(u_vec, factor=0.5): ] # Return the required number of principal components. - return np.sum([ - ( - u_val[tuple(zip(array_shape // 2))] ** 2 <= factor - * np.sum(u_val ** 2), - ) - for u_val in u_auto - ]) + return np.sum( + [ + (u_val[tuple(zip(array_shape // 2))] ** 2 <= factor * np.sum(u_val**2),) + for u_val in u_auto + ] + ) def calculate_svd(input_data): @@ -101,17 +98,17 @@ def calculate_svd(input_data): """ if (not isinstance(input_data, np.ndarray)) or (input_data.ndim != 2): - raise TypeError('Input data must be a 2D np.ndarray.') + raise TypeError("Input data must be a 2D np.ndarray.") return svd( input_data, check_finite=False, - lapack_driver='gesvd', + lapack_driver="gesvd", full_matrices=False, ) -def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'): +def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type="hard"): """Threshold the singular values. This method thresholds the input data using singular value decomposition. @@ -156,16 +153,11 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'): """ less_than_zero = isinstance(n_pc, int) and n_pc <= 0 - str_not_all = isinstance(n_pc, str) and n_pc != 'all' + str_not_all = isinstance(n_pc, str) and n_pc != "all" - if ( - (not isinstance(n_pc, (int, str, type(None)))) - or less_than_zero - or str_not_all - ): + if (not isinstance(n_pc, (int, str, type(None)))) or less_than_zero or str_not_all: raise ValueError( - 'Invalid value for "n_pc", specify a positive integer value or ' - + '"all"', + 'Invalid value for "n_pc", specify a positive integer value or ' + '"all"', ) # Get SVD of input data. @@ -176,15 +168,14 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'): # Find the required number of principal components if not specified. if isinstance(n_pc, type(None)): n_pc = find_n_pc(u_vec, factor=0.1) - print('xxxx', n_pc, u_vec) + print("xxxx", n_pc, u_vec) # If the number of PCs is too large use all of the singular values. - if ( - (isinstance(n_pc, int) and n_pc >= s_values.size) - or (isinstance(n_pc, str) and n_pc == 'all') + if (isinstance(n_pc, int) and n_pc >= s_values.size) or ( + isinstance(n_pc, str) and n_pc == "all" ): n_pc = s_values.size - warn('Using all singular values.') + warn("Using all singular values.") threshold = s_values[n_pc - 1] @@ -192,7 +183,7 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'): s_new = thresh(s_values, threshold, thresh_type) if np.all(s_new == s_values): - warn('No change to singular values.') + warn("No change to singular values.") # Diagonalize the svd s_new = np.diag(s_new) @@ -206,7 +197,7 @@ def svd_thresh_coef_fast( threshold, n_vals=-1, extra_vals=5, - thresh_type='hard', + thresh_type="hard", ): """Threshold the singular values coefficients. @@ -241,7 +232,7 @@ def svd_thresh_coef_fast( ok = False while not ok: (u_vec, s_values, v_vec) = svds(input_data, k=n_vals) - ok = (s_values[0] <= threshold or n_vals == min(input_data.shape) - 1) + ok = s_values[0] <= threshold or n_vals == min(input_data.shape) - 1 n_vals = min(n_vals + extra_vals, *input_data.shape) s_values = thresh( @@ -259,7 +250,7 @@ def svd_thresh_coef_fast( ) -def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'): +def svd_thresh_coef(input_data, operator, threshold, thresh_type="hard"): """Threshold the singular values coefficients. This method thresholds the input data using singular value decomposition. @@ -287,7 +278,7 @@ def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'): """ if not callable(operator): - raise TypeError('Operator must be a callable function.') + raise TypeError("Operator must be a callable function.") # Get SVD of data matrix u_vec, s_values, v_vec = calculate_svd(input_data) @@ -302,10 +293,9 @@ def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'): array_shape = np.repeat(int(np.sqrt(u_vec.shape[0])), 2) # Compute threshold matrix. - ti = np.array([ - np.linalg.norm(elem) - for elem in operator(matrix2cube(u_vec, array_shape)) - ]) + ti = np.array( + [np.linalg.norm(elem) for elem in operator(matrix2cube(u_vec, array_shape))] + ) threshold *= np.repeat(ti, a_matrix.shape[1]).reshape(a_matrix.shape) # Threshold coefficients. diff --git a/modopt/signal/validation.py b/src/modopt/signal/validation.py similarity index 80% rename from modopt/signal/validation.py rename to src/modopt/signal/validation.py index 422a987b..cdf69b7d 100644 --- a/modopt/signal/validation.py +++ b/src/modopt/signal/validation.py @@ -16,6 +16,7 @@ def transpose_test( x_args=None, y_shape=None, y_args=None, + rng=None, ): """Transpose test. @@ -36,6 +37,8 @@ def transpose_test( Shape of transpose operator input data (default is ``None``) y_args : tuple, optional Arguments to be passed to transpose operator (default is ``None``) + rng: numpy.random.Generator or int or None (default is ``None``) + Initialized random number generator or seed. Raises ------ @@ -54,7 +57,7 @@ def transpose_test( """ if not callable(operator) or not callable(operator_t): - raise TypeError('The input operators must be callable functions.') + raise TypeError("The input operators must be callable functions.") if isinstance(y_shape, type(None)): y_shape = x_shape @@ -62,9 +65,11 @@ def transpose_test( if isinstance(y_args, type(None)): y_args = x_args + if not isinstance(rng, np.random.Generator): + rng = np.random.default_rng(rng) # Generate random arrays. - x_val = np.random.ranf(x_shape) - y_val = np.random.ranf(y_shape) + x_val = rng.random(x_shape) + y_val = rng.random(y_shape) # Calculate mx_y = np.sum(np.multiply(operator(x_val, x_args), y_val)) @@ -73,4 +78,4 @@ def transpose_test( x_mty = np.sum(np.multiply(x_val, operator_t(y_val, y_args))) # Test the difference between the two. - print(' - | - | =', np.abs(mx_y - x_mty)) + print(" - | - | =", np.abs(mx_y - x_mty)) diff --git a/modopt/signal/wavelet.py b/src/modopt/signal/wavelet.py similarity index 88% rename from modopt/signal/wavelet.py rename to src/modopt/signal/wavelet.py index bc4ffc70..b55b78d9 100644 --- a/modopt/signal/wavelet.py +++ b/src/modopt/signal/wavelet.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """WAVELET MODULE. This module contains methods for performing wavelet transformations using @@ -58,20 +56,20 @@ def execute(command_line): """ if not isinstance(command_line, str): - raise TypeError('Command line must be a string.') + raise TypeError("Command line must be a string.") command = command_line.split() process = sp.Popen(command, stdout=sp.PIPE, stderr=sp.PIPE) stdout, stderr = process.communicate() - return stdout.decode('utf-8'), stderr.decode('utf-8') + return stdout.decode("utf-8"), stderr.decode("utf-8") def call_mr_transform( input_data, - opt='', - path='./', + opt="", + path="./", remove_files=True, ): # pragma: no cover """Call ``mr_transform``. @@ -127,26 +125,23 @@ def call_mr_transform( """ if not import_astropy: - raise ImportError('Astropy package not found.') + raise ImportError("Astropy package not found.") if (not isinstance(input_data, np.ndarray)) or (input_data.ndim != 2): - raise ValueError('Input data must be a 2D numpy array.') + raise ValueError("Input data must be a 2D numpy array.") - executable = 'mr_transform' + executable = "mr_transform" # Make sure mr_transform is installed. is_executable(executable) # Create a unique string using the current date and time. - unique_string = ( - datetime.now().strftime('%Y.%m.%d_%H.%M.%S') - + str(getrandbits(128)) - ) + unique_string = datetime.now().strftime("%Y.%m.%d_%H.%M.%S") + str(getrandbits(128)) # Set the ouput file names. - file_name = '{0}mr_temp_{1}'.format(path, unique_string) - file_fits = '{0}.fits'.format(file_name) - file_mr = '{0}.mr'.format(file_name) + file_name = f"{path}mr_temp_{unique_string}" + file_fits = f"{file_name}.fits" + file_mr = f"{file_name}.mr" # Write the input data to a fits file. fits.writeto(file_fits, input_data) @@ -155,15 +150,15 @@ def call_mr_transform( opt = opt.split() # Prepare command and execute it - command_line = ' '.join([executable] + opt + [file_fits, file_mr]) + command_line = " ".join([executable, *opt, file_fits, file_mr]) stdout, _ = execute(command_line) # Check for errors - if any(word in stdout for word in ('bad', 'Error', 'Sorry')): + if any(word in stdout for word in ("bad", "Error", "Sorry")): remove(file_fits) message = '{0} raised following exception: "{1}"' raise RuntimeError( - message.format(executable, stdout.rstrip('\n')), + message.format(executable, stdout.rstrip("\n")), ) # Retrieve wavelet transformed data. @@ -198,12 +193,12 @@ def trim_filter(filter_array): min_idx = np.min(non_zero_indices, axis=-1) max_idx = np.max(non_zero_indices, axis=-1) - return filter_array[min_idx[0]:max_idx[0] + 1, min_idx[1]:max_idx[1] + 1] + return filter_array[min_idx[0] : max_idx[0] + 1, min_idx[1] : max_idx[1] + 1] def get_mr_filters( data_shape, - opt='', + opt="", coarse=False, trim=False, ): # pragma: no cover @@ -256,7 +251,7 @@ def get_mr_filters( return mr_filters[:-1] -def filter_convolve(input_data, filters, filter_rot=False, method='scipy'): +def filter_convolve(input_data, filters, filter_rot=False, method="scipy"): """Filter convolve. This method convolves the input image with the wavelet filters. @@ -315,16 +310,14 @@ def filter_convolve(input_data, filters, filter_rot=False, method='scipy'): axis=0, ) - return np.array([ - convolve(input_data, filt, method=method) for filt in filters - ]) + return np.array([convolve(input_data, filt, method=method) for filt in filters]) def filter_convolve_stack( input_data, filters, filter_rot=False, - method='scipy', + method="scipy", ): """Filter convolve. @@ -366,7 +359,9 @@ def filter_convolve_stack( """ # Return the convolved data cube. - return np.array([ - filter_convolve(elem, filters, filter_rot=filter_rot, method=method) - for elem in input_data - ]) + return np.array( + [ + filter_convolve(elem, filters, filter_rot=filter_rot, method=method) + for elem in input_data + ] + ) diff --git a/modopt/tests/test_algorithms.py b/tests/test_algorithms.py similarity index 96% rename from modopt/tests/test_algorithms.py rename to tests/test_algorithms.py index 5671b8e3..63847764 100644 --- a/modopt/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """UNIT TESTS FOR Algorithms. This module contains unit tests for the modopt.opt module. @@ -11,18 +9,13 @@ import numpy as np import numpy.testing as npt -import pytest from modopt.opt import algorithms, cost, gradient, linear, proximity, reweight from pytest_cases import ( - case, fixture, - fixture_ref, - lazy_value, parametrize, parametrize_with_cases, ) -from test_helpers import Dummy SKLEARN_AVAILABLE = True try: @@ -31,6 +24,9 @@ SKLEARN_AVAILABLE = False +rng = np.random.default_rng() + + @fixture def idty(): """Identity function.""" @@ -80,7 +76,7 @@ def build_kwargs(kwargs, use_metrics): @parametrize(use_metrics=[True, False]) class AlgoCases: - """Cases for algorithms. + r"""Cases for algorithms. Most of the test solves the trivial problem @@ -91,7 +87,7 @@ class AlgoCases: """ data1 = np.arange(9).reshape(3, 3).astype(float) - data2 = data1 + np.random.randn(*data1.shape) * 1e-6 + data2 = data1 + rng.standard_normal(data1.shape) * 1e-6 max_iter = 20 @parametrize( @@ -111,8 +107,7 @@ class AlgoCases: ] ) def case_forward_backward(self, kwargs, idty, use_metrics): - """Forward Backward case. - """ + """Forward Backward case.""" update_kwargs = build_kwargs(kwargs, use_metrics) algo = algorithms.ForwardBackward( self.data1, @@ -242,9 +237,11 @@ def case_grad(self, GradDescent, use_metrics, idty): ) algo.iterate() return algo, update_kwargs - @parametrize(admm=[algorithms.ADMM,algorithms.FastADMM]) + + @parametrize(admm=[algorithms.ADMM, algorithms.FastADMM]) def case_admm(self, admm, use_metrics, idty): """ADMM setup.""" + def optim1(init, obs): return obs @@ -265,6 +262,7 @@ def optim2(init, obs): algo.iterate() return algo, update_kwargs + @parametrize_with_cases("algo, kwargs", cases=AlgoCases) def test_algo(algo, kwargs): """Test algorithms.""" diff --git a/modopt/tests/test_base.py b/tests/test_base.py similarity index 99% rename from modopt/tests/test_base.py rename to tests/test_base.py index e32ff94b..62e09095 100644 --- a/modopt/tests/test_base.py +++ b/tests/test_base.py @@ -5,6 +5,7 @@ Samuel Farrens Pierre-Antoine Comby """ + import numpy as np import numpy.testing as npt import pytest @@ -148,7 +149,7 @@ def test_matrix2cube(self): class TestType: """Test for type module.""" - data_list = list(range(5)) + data_list = list(range(5)) # noqa: RUF012 data_int = np.arange(5) data_flt = np.arange(5).astype(float) diff --git a/tests/test_helpers/__init__.py b/tests/test_helpers/__init__.py new file mode 100644 index 00000000..0ded847a --- /dev/null +++ b/tests/test_helpers/__init__.py @@ -0,0 +1,5 @@ +"""Utilities for tests.""" + +from .utils import Dummy, failparam, skipparam + +__all__ = ["Dummy", "failparam", "skipparam"] diff --git a/modopt/tests/test_helpers/utils.py b/tests/test_helpers/utils.py similarity index 75% rename from modopt/tests/test_helpers/utils.py rename to tests/test_helpers/utils.py index d8227640..41f948a6 100644 --- a/modopt/tests/test_helpers/utils.py +++ b/tests/test_helpers/utils.py @@ -1,9 +1,11 @@ """ Some helper functions for the test parametrization. + They should be used inside ``@pytest.mark.parametrize`` call. :Author: Pierre-Antoine Comby """ + import pytest @@ -11,13 +13,15 @@ def failparam(*args, raises=None): """Return a pytest parameterization that should raise an error.""" if not issubclass(raises, Exception): raise ValueError("raises should be an expected Exception.") - return pytest.param(*args, marks=pytest.mark.raises(exception=raises)) + return pytest.param(*args, marks=[pytest.mark.xfail(exception=raises)]) def skipparam(*args, cond=True, reason=""): """Return a pytest parameterization that should be skip if cond is valid.""" - return pytest.param(*args, marks=pytest.mark.skipif(cond, reason=reason)) + return pytest.param(*args, marks=[pytest.mark.skipif(cond, reason=reason)]) class Dummy: + """Dummy Class.""" + pass diff --git a/modopt/tests/test_math.py b/tests/test_math.py similarity index 95% rename from modopt/tests/test_math.py rename to tests/test_math.py index ea177b15..5c466e5e 100644 --- a/modopt/tests/test_math.py +++ b/tests/test_math.py @@ -6,6 +6,7 @@ Samuel Farrens Pierre-Antoine Comby """ + import pytest from test_helpers import failparam, skipparam @@ -28,6 +29,8 @@ else: SKIMAGE_AVAILABLE = True +rng = np.random.default_rng(1) + class TestConvolve: """Test convolve functions.""" @@ -135,18 +138,15 @@ class TestMatrix: ), ) - @pytest.fixture + @pytest.fixture(scope="module") def pm_instance(self, request): """Power Method instance.""" - np.random.seed(1) pm = matrix.PowerMethod( lambda x_val: x_val.dot(x_val.T), self.array33.shape, - auto_run=request.param, verbose=True, + rng=np.random.default_rng(0), ) - if not request.param: - pm.get_spec_rad(max_iter=1) return pm @pytest.mark.parametrize( @@ -194,12 +194,7 @@ def test_rotate(self): npt.assert_raises(ValueError, matrix.rotate, self.array23, np.pi / 2) - @pytest.mark.parametrize( - ("pm_instance", "value"), - [(True, 1.0), (False, 0.8675467477372257)], - indirect=["pm_instance"], - ) - def test_power_method(self, pm_instance, value): + def test_power_method(self, pm_instance, value=1): """Test power method.""" npt.assert_almost_equal(pm_instance.spec_rad, value) npt.assert_almost_equal(pm_instance.inv_spec_rad, 1 / value) @@ -210,8 +205,8 @@ class TestMetrics: data1 = np.arange(49).reshape(7, 7) mask = np.ones(data1.shape) - ssim_res = 0.8963363560519094 - ssim_mask_res = 0.805154442543846 + ssim_res = 0.8958315888566867 + ssim_mask_res = 0.8023827544418249 snr_res = 10.134554256920536 psnr_res = 14.860761791850397 mse_res = 0.03265305507330247 diff --git a/modopt/tests/test_opt.py b/tests/test_opt.py similarity index 92% rename from modopt/tests/test_opt.py rename to tests/test_opt.py index 7c30186e..2ea58c27 100644 --- a/modopt/tests/test_opt.py +++ b/tests/test_opt.py @@ -36,11 +36,28 @@ except ImportError: PYWT_AVAILABLE = False +rng = np.random.default_rng() + + # Basic functions to be used as operators or as dummy functions -func_identity = lambda x_val: x_val -func_double = lambda x_val: x_val * 2 -func_sq = lambda x_val: x_val**2 -func_cube = lambda x_val: x_val**3 +def func_identity(x_val, *args, **kwargs): + """Return x.""" + return x_val + + +def func_double(x_val, *args, **kwargs): + """Double x.""" + return x_val * 2 + + +def func_sq(x_val, *args, **kwargs): + """Square x.""" + return x_val**2 + + +def func_cube(x_val, *args, **kwargs): + """Cube x.""" + return x_val**3 @case(tags="cost") @@ -183,19 +200,35 @@ def case_linear_wavelet_convolve(self): @parametrize( compute_backend=[ - pytest.param("numpy", marks=pytest.mark.skipif(not PYWT_AVAILABLE, reason="PyWavelet not available.")), - pytest.param("cupy", marks=pytest.mark.skipif(not PTWT_AVAILABLE, reason="Pytorch Wavelet not available.")) - ]) - def case_linear_wavelet_transform(self, compute_backend="numpy"): + pytest.param( + "numpy", + marks=pytest.mark.skipif( + not PYWT_AVAILABLE, reason="PyWavelet not available." + ), + ), + pytest.param( + "cupy", + marks=pytest.mark.skipif( + not PTWT_AVAILABLE, reason="Pytorch Wavelet not available." + ), + ), + ] + ) + def case_linear_wavelet_transform(self, compute_backend): + """Case linear wavelet operator.""" linop = linear.WaveletTransform( wavelet_name="haar", shape=(8, 8), level=2, ) data_op = np.arange(64).reshape(8, 8).astype(float) - res_op, slices, shapes = pywt.ravel_coeffs(pywt.wavedecn(data_op, "haar", level=2)) + res_op, slices, shapes = pywt.ravel_coeffs( + pywt.wavedecn(data_op, "haar", level=2) + ) data_adj_op = linop.op(data_op) - res_adj_op = pywt.waverecn(pywt.unravel_coeffs(data_adj_op, slices, shapes, "wavedecn"), "haar") + res_adj_op = pywt.waverecn( + pywt.unravel_coeffs(data_adj_op, slices, shapes, "wavedecn"), "haar" + ) return linop, data_op, data_adj_op, res_op, res_adj_op @parametrize(weights=[[1.0, 1.0], None]) @@ -299,6 +332,7 @@ class ProxCases: array233 = np.arange(18).reshape(2, 3, 3).astype(float) array233_2 = np.array( + [ [ [2.73843189, 3.14594066, 3.55344943], [3.9609582, 4.36846698, 4.77597575], @@ -309,6 +343,7 @@ class ProxCases: [11.67394789, 12.87497954, 14.07601119], [15.27704284, 16.47807449, 17.67910614], ], + ] ) array233_3 = np.array( [ @@ -428,7 +463,7 @@ def case_prox_grouplasso(self, use_weights): else: weights = np.tile(np.zeros((3, 3)), (4, 1, 1)) - random_data = 3 * np.random.random(weights[0].shape) + random_data = 3 * rng.random(weights[0].shape) random_data_tile = np.tile(random_data, (weights.shape[0], 1, 1)) if use_weights: gl_result_data = 2 * random_data_tile - 3 diff --git a/modopt/tests/test_signal.py b/tests/test_signal.py similarity index 96% rename from modopt/tests/test_signal.py rename to tests/test_signal.py index 202e541b..6dbb0bba 100644 --- a/modopt/tests/test_signal.py +++ b/tests/test_signal.py @@ -16,7 +16,8 @@ class TestFilter: - """Test filter module""" + """Test filter module.""" + @pytest.mark.parametrize( ("norm", "result"), [(True, 0.24197072451914337), (False, 0.60653065971263342)] ) @@ -24,7 +25,6 @@ def test_gaussian_filter(self, norm, result): """Test gaussian filter.""" npt.assert_almost_equal(filter.gaussian_filter(1, 1, norm=norm), result) - def test_mex_hat(self): """Test mexican hat filter.""" npt.assert_almost_equal( @@ -32,7 +32,6 @@ def test_mex_hat(self): -0.35213905225713371, ) - def test_mex_hat_dir(self): """Test directional mexican hat filter.""" npt.assert_almost_equal( @@ -46,13 +45,13 @@ class TestNoise: data1 = np.arange(9).reshape(3, 3).astype(float) data2 = np.array( - [[0, 2.0, 2.0], [4.0, 5.0, 10], [11.0, 15.0, 18.0]], + [[0, 3.0, 4.0], [6.0, 9.0, 8.0], [14.0, 14.0, 17.0]], ) data3 = np.array( [ - [1.62434536, 0.38824359, 1.47182825], - [1.92703138, 4.86540763, 2.6984613], - [7.74481176, 6.2387931, 8.3190391], + [0.3455842, 1.8216181, 2.3304371], + [1.6968428, 4.9053559, 5.4463746], + [5.4630468, 7.5811181, 8.3645724], ] ) data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]]) @@ -71,9 +70,10 @@ class TestNoise: ) def test_add_noise(self, data, noise_type, sigma, data_noise): """Test add_noise.""" - np.random.seed(1) + rng = np.random.default_rng(1) npt.assert_almost_equal( - noise.add_noise(data, sigma=sigma, noise_type=noise_type), data_noise + noise.add_noise(data, sigma=sigma, noise_type=noise_type, rng=rng), + data_noise, ) @pytest.mark.parametrize( @@ -86,13 +86,16 @@ def test_thresh(self, threshold_type, result): noise.thresh(self.data1, 5, threshold_type=threshold_type), result ) + class TestPositivity: """Test positivity module.""" + data1 = np.arange(9).reshape(3, 3).astype(float) data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]]) data5 = np.array( [[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]], ) + @pytest.mark.parametrize( ("value", "expected"), [ @@ -231,6 +234,7 @@ def test_svd_thresh_coef(self, data, operator): # TODO test_svd_thresh_coef_fast + class TestValidation: """Test validation Module.""" @@ -238,13 +242,13 @@ class TestValidation: def test_transpose_test(self): """Test transpose_test.""" - np.random.seed(2) npt.assert_equal( validation.transpose_test( lambda x_val, y_val: x_val.dot(y_val), lambda x_val, y_val: x_val.dot(y_val.T), self.array33.shape, x_args=self.array33, + rng=2, ), None, )