diff --git a/.github/workflows/cd-build.yml b/.github/workflows/cd-build.yml
index fca9feb1..21b6cc3a 100644
--- a/.github/workflows/cd-build.yml
+++ b/.github/workflows/cd-build.yml
@@ -14,32 +14,28 @@ jobs:
steps:
- name: Checkout
- uses: actions/checkout@v2
+ uses: actions/checkout@v4
- - name: Set up Conda with Python 3.8
- uses: conda-incubator/setup-miniconda@v2
+ - uses: actions/setup-python@v4
with:
- auto-update-conda: true
- python-version: 3.8
- auto-activate-base: false
+ python-version: "3.10"
+ cache: pip
- name: Install dependencies
shell: bash -l {0}
run: |
python -m pip install --upgrade pip
- python -m pip install -r develop.txt
python -m pip install twine
- python -m pip install .
+ python -m pip install .[doc,test]
- name: Run Tests
shell: bash -l {0}
run: |
- python setup.py test
+ pytest
- name: Check distribution
shell: bash -l {0}
run: |
-        python setup.py sdist
+        python -m pip install build
+        python -m build
twine check dist/*
- name: Upload coverage to Codecov
@@ -57,20 +53,15 @@ jobs:
steps:
- name: Checkout
- uses: actions/checkout@v2
+ uses: actions/checkout@v4
- - name: Set up Conda with Python 3.8
- uses: conda-incubator/setup-miniconda@v2
- with:
- python-version: "3.8"
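+    - name: Set up Python
+      # Assumed step (not spelled out in this hunk), mirroring the other
+      # jobs, since the conda setup above was removed and pip is used below
+      uses: actions/setup-python@v4
+      with:
+        python-version: "3.10"
+        cache: pip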
- name: Install dependencies
shell: bash -l {0}
run: |
-        conda install -c conda-forge pandoc
+        sudo apt-get install -y pandoc
python -m pip install --upgrade pip
- python -m pip install -r docs/requirements.txt
- python -m pip install .
+ python -m pip install .[doc]
- name: Build API documentation
shell: bash -l {0}
diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml
index c4ba28a0..3a209d12 100644
--- a/.github/workflows/ci-build.yml
+++ b/.github/workflows/ci-build.yml
@@ -16,61 +16,41 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
- python-version: ["3.10"]
+ python-version: ["3.8", "3.9", "3.10"]
steps:
- - name: Checkout
- uses: actions/checkout@v2
-
- - name: Set up Conda with Python ${{ matrix.python-version }}
- uses: conda-incubator/setup-miniconda@v2
+ - uses: actions/checkout@v4
+ - uses: actions/setup-python@v4
with:
- auto-update-conda: true
python-version: ${{ matrix.python-version }}
- auto-activate-base: false
-
- - name: Check Conda
- shell: bash -l {0}
- run: |
- conda info
- conda list
- python --version
+ cache: pip
- name: Install Dependencies
shell: bash -l {0}
run: |
python --version
python -m pip install --upgrade pip
- python -m pip install -r develop.txt
- python -m pip install -r docs/requirements.txt
- python -m pip install astropy "scikit-image<0.20" scikit-learn matplotlib
- python -m pip install tensorflow>=2.4.1
- python -m pip install twine
- python -m pip install .
+ python -m pip install .[test]
+ python -m pip install astropy scikit-image scikit-learn matplotlib
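+        # quote the version specifier so bash does not treat '>' as a redirect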
+        python -m pip install "tensorflow>=2.4.1" torch
- name: Run Tests
shell: bash -l {0}
run: |
- export PATH=/usr/share/miniconda/bin:$PATH
pytest -n 2
- name: Save Test Results
if: always()
- uses: actions/upload-artifact@v2
+ uses: actions/upload-artifact@v4
with:
name: unit-test-results-${{ matrix.os }}-${{ matrix.python-version }}
- path: pytest.xml
-
- - name: Check Distribution
- shell: bash -l {0}
- run: |
- python setup.py sdist
- twine check dist/*
+ path: coverage.xml
- name: Check API Documentation build
shell: bash -l {0}
run: |
-        conda install -c conda-forge pandoc
+        # pandoc via apt on Linux runners, brew on macOS (the matrix covers both)
+        sudo apt-get install -y pandoc || brew install pandoc
+ pip install .[doc] ipykernel
         sphinx-apidoc -t docs/_templates -feTMo docs/source src/modopt
sphinx-build -b doctest -E docs/source docs/_build
@@ -81,38 +61,3 @@ jobs:
file: coverage.xml
flags: unittests
- test-basic:
- name: Basic Test Suite
- runs-on: ${{ matrix.os }}
-
- strategy:
- fail-fast: false
- matrix:
- os: [ubuntu-latest, macos-latest]
- python-version: ["3.7", "3.8", "3.9"]
-
- steps:
- - name: Checkout
- uses: actions/checkout@v2
-
- - name: Set up Conda with Python ${{ matrix.python-version }}
- uses: conda-incubator/setup-miniconda@v2
- with:
- auto-update-conda: true
- python-version: ${{ matrix.python-version }}
- auto-activate-base: false
-
- - name: Install Dependencies
- shell: bash -l {0}
- run: |
- python --version
- python -m pip install --upgrade pip
- python -m pip install -r develop.txt
- python -m pip install astropy "scikit-image<0.20" scikit-learn matplotlib
- python -m pip install .
-
- - name: Run Tests
- shell: bash -l {0}
- run: |
- export PATH=/usr/share/miniconda/bin:$PATH
- pytest -n 2
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
new file mode 100644
index 00000000..45fc2b23
--- /dev/null
+++ b/.github/workflows/style.yml
@@ -0,0 +1,38 @@
+name: Style checking
+
+on:
+ push:
+ branches: [ "master", "main", "develop" ]
+ pull_request:
+ branches: [ "master", "main", "develop" ]
+
+ workflow_dispatch:
+
+env:
+ PYTHON_VERSION: "3.10"
+
+jobs:
+ linter-check:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+ - name: Set up Python ${{ env.PYTHON_VERSION }}
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ env.PYTHON_VERSION }}
+ cache: pip
+
+ - name: Install Python deps
+ shell: bash
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install -e .[test,dev]
+
+ - name: Black Check
+ shell: bash
+ run: black . --diff --color --check
+
+ - name: ruff Check
+ shell: bash
+ run: ruff check
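+
+      # The same checks can be run locally (illustrative):
+      #   black . --check && ruff check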
diff --git a/.pylintrc b/.pylintrc
deleted file mode 100644
index 3ac9aef9..00000000
--- a/.pylintrc
+++ /dev/null
@@ -1,2 +0,0 @@
-[MASTER]
-ignore-patterns=**/docs/**/*.py
diff --git a/.pyup.yml b/.pyup.yml
deleted file mode 100644
index 8fdac7ff..00000000
--- a/.pyup.yml
+++ /dev/null
@@ -1,14 +0,0 @@
-# autogenerated pyup.io config file
-# see https://pyup.io/docs/configuration/ for all available options
-
-schedule: ''
-update: all
-label_prs: update
-assignees: sfarrens
-requirements:
- - requirements.txt:
- pin: False
- - develop.txt:
- pin: False
- - docs/requirements.txt:
- pin: True
diff --git a/MANIFEST.in b/MANIFEST.in
deleted file mode 100644
index 9a2f374e..00000000
--- a/MANIFEST.in
+++ /dev/null
@@ -1,5 +0,0 @@
-include requirements.txt
-include develop.txt
-include docs/requirements.txt
-include README.rst
-include LICENSE.txt
diff --git a/develop.txt b/develop.txt
deleted file mode 100644
index 6ff665eb..00000000
--- a/develop.txt
+++ /dev/null
@@ -1,12 +0,0 @@
-coverage>=5.5
-pytest>=6.2.2
-pytest-raises>=0.10
-pytest-cases>= 3.6
-pytest-xdist>= 3.0.1
-pytest-cov>=2.11.1
-pytest-emoji>=0.2.0
-pydocstyle==6.1.1
-pytest-pydocstyle>=2.2.0
-black
-isort
-pytest-black
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 46564b9f..69921008 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Python Template sphinx config
# Import relevant modules
@@ -9,56 +8,53 @@
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
-sys.path.insert(0, os.path.abspath('../..'))
+sys.path.insert(0, os.path.abspath("../.."))
# -- General configuration ------------------------------------------------
# General information about the project.
-project = 'modopt'
+project = "modopt"
mdata = metadata(project)
-author = mdata['Author']
-version = mdata['Version']
-copyright = '2020, {}'.format(author)
-gh_user = 'sfarrens'
-
-# If your documentation needs a minimal Sphinx version, state it here.
-needs_sphinx = '3.3'
+author = "Samuel Farrens, Pierre-Antoine Comby, Chaithya GR, Philippe Ciuciu"
+version = mdata["Version"]
+copyright = f"2020, {author}"
+gh_user = "sfarrens"
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
- 'sphinx.ext.autodoc',
- 'sphinx.ext.autosummary',
- 'sphinx.ext.coverage',
- 'sphinx.ext.doctest',
- 'sphinx.ext.ifconfig',
- 'sphinx.ext.intersphinx',
- 'sphinx.ext.mathjax',
- 'sphinx.ext.napoleon',
- 'sphinx.ext.todo',
- 'sphinx.ext.viewcode',
- 'sphinxawesome_theme',
- 'sphinxcontrib.bibtex',
- 'myst_parser',
- 'nbsphinx',
- 'nbsphinx_link',
- 'numpydoc',
- "sphinx_gallery.gen_gallery"
+ "sphinx.ext.autodoc",
+ "sphinx.ext.autosummary",
+ "sphinx.ext.coverage",
+ "sphinx.ext.doctest",
+ "sphinx.ext.ifconfig",
+ "sphinx.ext.intersphinx",
+ "sphinx.ext.mathjax",
+ "sphinx.ext.napoleon",
+ "sphinx.ext.todo",
+ "sphinx.ext.viewcode",
+ "sphinxawesome_theme.highlighting",
+ "sphinxcontrib.bibtex",
+ "myst_parser",
+ "nbsphinx",
+ "nbsphinx_link",
+ "numpydoc",
+ "sphinx_gallery.gen_gallery",
]
# Include module names for objects
add_module_names = False
# Set class documentation standard.
-autoclass_content = 'class'
+autoclass_content = "class"
# Audodoc options
autodoc_default_options = {
- 'member-order': 'bysource',
- 'private-members': True,
- 'show-inheritance': True
+ "member-order": "bysource",
+ "private-members": True,
+ "show-inheritance": True,
}
# Generate summaries
@@ -69,17 +65,17 @@
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
-source_suffix = ['.rst', '.md']
+source_suffix = [".rst", ".md"]
# The master toctree document.
-master_doc = 'index'
+master_doc = "index"
# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
show_authors = True
# The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'default'
+pygments_style = "default"
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
@@ -88,7 +84,7 @@
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
-html_theme = 'sphinxawesome_theme'
+html_theme = "sphinxawesome_theme"
# html_theme = 'sphinx_book_theme'
# Theme options are theme-specific and customize the look and feel of a theme
@@ -101,11 +97,10 @@
"breadcrumbs_separator": "/",
"show_prev_next": True,
"show_scrolltop": True,
-
}
html_collapsible_definitions = True
html_awesome_headerlinks = True
-html_logo = 'modopt_logo.jpg'
+html_logo = "modopt_logo.png"
 html_permalinks_icon = (
-    '…'
+    r"""…"""
 )
nbsphinx_prolog = nb_header_pt1 + nb_header_pt2
@@ -240,28 +233,28 @@ def add_notebooks(nb_path='../../notebooks'):
# Refer to the package libraries for type definitions
intersphinx_mapping = {
- 'python': ('http://docs.python.org/3', None),
- 'numpy': ('https://numpy.org/doc/stable/', None),
- 'scipy': ('https://docs.scipy.org/doc/scipy/reference', None),
- 'progressbar': ('https://progressbar-2.readthedocs.io/en/latest/', None),
- 'matplotlib': ('https://matplotlib.org', None),
- 'astropy': ('http://docs.astropy.org/en/latest/', None),
- 'cupy': ('https://docs-cupy.chainer.org/en/stable/', None),
- 'torch': ('https://pytorch.org/docs/stable/', None),
- 'sklearn': (
- 'http://scikit-learn.org/stable',
- (None, './_intersphinx/sklearn-objects.inv')
+ "python": ("http://docs.python.org/3", None),
+ "numpy": ("https://numpy.org/doc/stable/", None),
+ "scipy": ("https://docs.scipy.org/doc/scipy/reference", None),
+ "progressbar": ("https://progressbar-2.readthedocs.io/en/latest/", None),
+ "matplotlib": ("https://matplotlib.org", None),
+ "astropy": ("http://docs.astropy.org/en/latest/", None),
+ "cupy": ("https://docs-cupy.chainer.org/en/stable/", None),
+ "torch": ("https://pytorch.org/docs/stable/", None),
+ "sklearn": (
+ "http://scikit-learn.org/stable",
+ (None, "./_intersphinx/sklearn-objects.inv"),
),
- 'tensorflow': (
- 'https://www.tensorflow.org/api_docs/python',
+ "tensorflow": (
+ "https://www.tensorflow.org/api_docs/python",
(
- 'https://github.com/GPflow/tensorflow-intersphinx/'
- + 'raw/master/tf2_py_objects.inv')
- )
-
+ "https://github.com/GPflow/tensorflow-intersphinx/"
+ + "raw/master/tf2_py_objects.inv"
+ ),
+ ),
}
# -- BibTeX Setting ----------------------------------------------
-bibtex_bibfiles = ['refs.bib', 'my_ref.bib']
-bibtex_default_style = 'alpha'
+bibtex_bibfiles = ["refs.bib", "my_ref.bib"]
+bibtex_default_style = "alpha"
diff --git a/modopt/examples/README.rst b/examples/README.rst
similarity index 100%
rename from modopt/examples/README.rst
rename to examples/README.rst
diff --git a/modopt/examples/__init__.py b/examples/__init__.py
similarity index 100%
rename from modopt/examples/__init__.py
rename to examples/__init__.py
diff --git a/modopt/examples/conftest.py b/examples/conftest.py
similarity index 95%
rename from modopt/examples/conftest.py
rename to examples/conftest.py
index 73358679..f3ed371b 100644
--- a/modopt/examples/conftest.py
+++ b/examples/conftest.py
@@ -11,10 +11,12 @@
https://stackoverflow.com/questions/56807698/how-to-run-script-as-pytest-test
"""
+
from pathlib import Path
import runpy
import pytest
+
def pytest_collect_file(path, parent):
"""Pytest hook.
@@ -22,7 +24,7 @@ def pytest_collect_file(path, parent):
The new node needs to have the specified parent as parent.
"""
p = Path(path)
- if p.suffix == '.py' and 'example' in p.name:
+ if p.suffix == ".py" and "example" in p.name:
return Script.from_parent(parent, path=p, name=p.name)
@@ -33,6 +35,7 @@ def collect(self):
"""Collect the script as its own item."""
yield ScriptItem.from_parent(self, name=self.name)
+
class ScriptItem(pytest.Item):
"""Item script collected by pytest."""
diff --git a/modopt/examples/example_lasso_forward_backward.py b/examples/example_lasso_forward_backward.py
similarity index 95%
rename from modopt/examples/example_lasso_forward_backward.py
rename to examples/example_lasso_forward_backward.py
index 7f820000..f3e5091d 100644
--- a/modopt/examples/example_lasso_forward_backward.py
+++ b/examples/example_lasso_forward_backward.py
@@ -1,4 +1,3 @@
-# noqa: D205
"""
Solving the LASSO Problem with the Forward Backward Algorithm.
==============================================================
@@ -76,7 +75,7 @@
prox=prox_op,
cost=cost_op_fista,
metric_call_period=1,
- auto_iterate=False, # Just to give us the pleasure of doing things by ourself.
+    auto_iterate=False,  # Just to give us the pleasure of doing things ourselves.
)
fb_fista.iterate()
@@ -115,7 +114,7 @@
prox=prox_op,
cost=cost_op_pogm,
metric_call_period=1,
- auto_iterate=False, # Just to give us the pleasure of doing things by ourself.
+    auto_iterate=False,  # Just to give us the pleasure of doing things ourselves.
)
fb_pogm.iterate()
@@ -133,6 +132,7 @@
#
# sphinx_gallery_start_ignore
assert mse(fb_pogm.x_final, BETA_TRUE) < 1
+# sphinx_gallery_end_ignore
# %%
# Comparing the Two algorithms
diff --git a/modopt/__init__.py b/modopt/__init__.py
deleted file mode 100644
index 2c06c1db..00000000
--- a/modopt/__init__.py
+++ /dev/null
@@ -1,24 +0,0 @@
-# -*- coding: utf-8 -*-
-
-"""MODOPT PACKAGE.
-
-ModOpt is a series of Modular Optimisation tools for solving inverse problems.
-
-"""
-
-from warnings import warn
-
-from importlib_metadata import version
-
-from modopt.base import *
-
-try:
- _version = version('modopt')
-except Exception: # pragma: no cover
- _version = 'Unkown'
- warn(
- 'Could not extract package metadata. Make sure the package is '
- + 'correctly installed.',
- )
-
-__version__ = _version
diff --git a/modopt/tests/test_helpers/__init__.py b/modopt/tests/test_helpers/__init__.py
deleted file mode 100644
index 3886b877..00000000
--- a/modopt/tests/test_helpers/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .utils import failparam, skipparam, Dummy
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..721c8b37
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,58 @@
+[project]
+name = "modopt"
+description = "Modular Optimisation tools for solving inverse problems."
+version = "1.7.1"
+requires-python = ">=3.8"
+
+authors = [
+    {name="Samuel Farrens", email="samuel.farrens@cea.fr"},
+    {name="Chaithya GR", email="chaithyagr@gmail.com"},
+    {name="Pierre-Antoine Comby", email="pierre-antoine.comby@cea.fr"},
+    {name="Philippe Ciuciu", email="philippe.ciuciu@cea.fr"},
+]
+readme = "README.md"
+license = {file="LICENSE.txt"}
+
+dependencies = ["numpy", "scipy", "tqdm", "importlib_metadata"]
+
+[project.optional-dependencies]
+gpu = ["torch", "ptwt"]
+doc = [
+    "myst-parser",
+    "nbsphinx",
+    "nbsphinx-link",
+    "sphinx-gallery",
+    "numpydoc",
+    "sphinxawesome-theme",
+    "sphinxcontrib-bibtex",
+]
+dev = ["black", "ruff"]
+test = ["pytest<8.0.0", "pytest-cases", "pytest-cov", "pytest-xdist", "pytest-sugar"]
+
+[build-system]
+requires = ["setuptools", "setuptools-scm[toml]", "wheel"]
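+# Assumed backend (not stated in this file): without it pip falls back to
+# the legacy setuptools shim.
+build-backend = "setuptools.build_meta"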
+
+[tool.coverage.run]
+omit = ["*tests*", "*__init__*", "*setup.py*", "*_version.py*", "*example*"]
+
+[tool.coverage.report]
+precision = 2
+exclude_lines = ["pragma: no cover", "raise NotImplementedError"]
+
+[tool.black]
+
+
+[tool.ruff]
+exclude = ["examples", "docs"]
+[tool.ruff.lint]
+select = ["E", "F", "B", "Q", "UP", "D", "NPY", "RUF"]
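+# E/F: pycodestyle/pyflakes, B: bugbear, Q: quotes, UP: pyupgrade,
+# D: pydocstyle, NPY: NumPy-specific rules, RUF: Ruff-specific rules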
+
+ignore = ["F401"]  # we like the `try: import ... except: ...` pattern
+
+[tool.ruff.lint.pydocstyle]
+convention = "numpy"
+
+[tool.isort]
+profile = "black"
+
+[tool.pytest.ini_options]
+minversion = "6.0"
+norecursedirs = ["tests/test_helpers"]
+addopts = ["--cov=modopt", "--cov-report=term-missing", "--cov-report=xml"]
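+
+# Note: with these addopts a plain `pytest` run already writes terminal and
+# XML coverage reports; CI adds `-n 2` to spread tests over two xdist workers.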
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index 1f44de13..00000000
--- a/requirements.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-importlib_metadata>=3.7.0
-numpy>=1.19.5
-scipy>=1.5.4
-tqdm>=4.64.0
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index 100adb40..00000000
--- a/setup.cfg
+++ /dev/null
@@ -1,97 +0,0 @@
-[aliases]
-test=pytest
-
-[metadata]
-description_file = README.rst
-
-[darglint]
-docstring_style = numpy
-strictness = short
-
-[flake8]
-ignore =
- D107, #Justification: Don't need docstring for __init__ in numpydoc style
- RST304, #Justification: Need to use :cite: role for citations
- RST210, #Justification: RST210, RST213 Inconsistent with numpydoc
- RST213, # documentation for handling *args and **kwargs
- W503, #Justification: Have to choose one multiline operator format
- WPS202, #Todo: Rethink module size, possibly split large modules
- WPS337, #Todo: Consider simplifying multiline conditions.
- WPS338, #Todo: Consider changing method order
- WPS403, #Todo: Rethink no cover lines
- WPS421, #Todo: Review need for print statements
- WPS432, #Justification: Mathematical codes require "magic numbers"
- WPS433, #Todo: Rethink conditional imports
- WPS463, #Todo: Rename get_ methods
- WPS615, #Todo: Rename get_ methods
-per-file-ignores =
- #Justification: Needed for keeping package version and current API
- *__init__.py*: F401,F403,WPS347,WPS410,WPS412
- #Todo: Rethink conditional imports
- #Todo: How can we bypass mutable constants?
- modopt/base/backend.py: WPS229, WPS420, WPS407
- #Todo: Rethink conditional imports
- modopt/base/observable.py: WPS420,WPS604
- #Todo: Check string for log formatting
- modopt/interface/log.py: WPS323
- #Todo: Rethink conditional imports
- modopt/math/convolve.py: WPS301,WPS420
- #Todo: Rethink conditional imports
- modopt/math/matrix.py: WPS420
- #Todo: import has bad parenthesis
- modopt/opt/algorithms/__init__.py: F401,F403,WPS318, WPS319, WPS412, WPS410
- #Todo: x is a too short name.
- modopt/opt/algorithms/forward_backward.py: WPS111
- #Todo: u,v , A is a too short name.
- modopt/opt/algorithms/admm.py: WPS111, N803
- #Todo: Check need for del statement
- modopt/opt/algorithms/primal_dual.py: WPS111, WPS420
- #multiline parameters bug with tuples
- modopt/opt/algorithms/gradient_descent.py: WPS111, WPS420, WPS317
- #Todo: Consider changing costObj name
- modopt/opt/cost.py: N801,
- #Todo:
- # - Rethink subscript slice assignment
- # - Reduce complexity of KSupportNorm
- # - Check bitwise operations
- modopt/opt/proximity.py: WPS220,WPS231,WPS352,WPS362,WPS465,WPS506,WPS508
- #Todo: Consider changing cwbReweight name
- modopt/opt/reweight.py: N801
- #Justification: Needed to import matplotlib.pyplot
- modopt/plot/cost_plot.py: N802,WPS301
- #Todo: Investigate possible bug in find_n_pc function
- #Todo: Investigate darglint error
- modopt/signal/svd.py: WPS345, DAR000
- #Todo: Check security of using system executable call
- modopt/signal/wavelet.py: S404,S603
- #Todo: Clean up tests
- modopt/tests/*.py: E731,F401,WPS301,WPS420,WPS425,WPS437,WPS604
- #Todo: Import has bad parenthesis
- modopt/tests/test_base.py: WPS318,WPS319,E501,WPS301
-#WPS Settings
-max-arguments = 25
-max-attributes = 40
-max-cognitive-score = 20
-max-function-expressions = 20
-max-line-complexity = 30
-max-local-variables = 10
-max-methods = 20
-max-module-expressions = 20
-max-string-usages = 20
-max-raises = 5
-
-[tool:pytest]
-norecursedirs=tests/test_helpers
-testpaths =
- modopt
-addopts =
- --verbose
- --cov=modopt
- --cov-report=term-missing
- --cov-report=xml
- --junitxml=pytest.xml
- --pydocstyle
-
-[pydocstyle]
-convention=numpy
-add-ignore=D107
diff --git a/setup.py b/setup.py
deleted file mode 100644
index e6a8a9e6..00000000
--- a/setup.py
+++ /dev/null
@@ -1,73 +0,0 @@
-#! /usr/bin/env python
-# -*- coding: utf-8 -*-
-
-from setuptools import setup, find_packages
-import os
-
-# Set the package release version
-major = 1
-minor = 7
-patch = 1
-
-# Set the package details
-name = 'modopt'
-version = '.'.join(str(value) for value in (major, minor, patch))
-author = 'Samuel Farrens'
-email = 'samuel.farrens@cea.fr'
-gh_user = 'cea-cosmic'
-url = 'https://github.com/{0}/{1}'.format(gh_user, name)
-description = 'Modular Optimisation tools for soliving inverse problems.'
-license = 'MIT'
-
-# Set the package classifiers
-python_versions_supported = ['3.7', '3.8', '3.9', '3.10', '3.11']
-os_platforms_supported = ['Unix', 'MacOS']
-
-lc_str = 'License :: OSI Approved :: {0} License'
-ln_str = 'Programming Language :: Python'
-py_str = 'Programming Language :: Python :: {0}'
-os_str = 'Operating System :: {0}'
-
-classifiers = (
- [lc_str.format(license)]
- + [ln_str]
- + [py_str.format(ver) for ver in python_versions_supported]
- + [os_str.format(ops) for ops in os_platforms_supported]
-)
-
-# Source package description from README.md
-this_directory = os.path.abspath(os.path.dirname(__file__))
-with open(os.path.join(this_directory, 'README.md'), encoding='utf-8') as f:
- long_description = f.read()
-
-# Source package requirements from requirements.txt
-with open('requirements.txt') as open_file:
- install_requires = open_file.read()
-
-# Source test requirements from develop.txt
-with open('develop.txt') as open_file:
- tests_require = open_file.read()
-
-# Source doc requirements from docs/requirements.txt
-with open('docs/requirements.txt') as open_file:
- docs_require = open_file.read()
-
-
-setup(
- name=name,
- author=author,
- author_email=email,
- version=version,
- license=license,
- url=url,
- description=description,
- long_description=long_description,
- long_description_content_type='text/markdown',
- packages=find_packages(),
- install_requires=install_requires,
- python_requires='>=3.6',
- setup_requires=['pytest-runner'],
- tests_require=tests_require,
- extras_require={'develop': tests_require + docs_require},
- classifiers=classifiers,
-)
diff --git a/src/modopt/__init__.py b/src/modopt/__init__.py
new file mode 100644
index 00000000..5e8de1b6
--- /dev/null
+++ b/src/modopt/__init__.py
@@ -0,0 +1,25 @@
+"""MODOPT PACKAGE.
+
+ModOpt is a series of Modular Optimisation tools for solving inverse problems.
+
+"""
+
+from warnings import warn
+
+from importlib_metadata import version
+
+from modopt.base import np_adjust, transform, types, observable
+
+__all__ = ["np_adjust", "transform", "types", "observable"]
+
+try:
+ _version = version("modopt")
+except Exception: # pragma: no cover
+    _version = "Unknown"
+ warn(
+ "Could not extract package metadata. Make sure the package is "
+ + "correctly installed.",
+ stacklevel=1,
+ )
+
+__version__ = _version
diff --git a/modopt/base/__init__.py b/src/modopt/base/__init__.py
similarity index 71%
rename from modopt/base/__init__.py
rename to src/modopt/base/__init__.py
index 88424bae..c4c681d7 100644
--- a/modopt/base/__init__.py
+++ b/src/modopt/base/__init__.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""BASE ROUTINES.
This module contains submodules for basic operations such as type
@@ -9,4 +7,4 @@
"""
-__all__ = ['np_adjust', 'transform', 'types', 'observable']
+__all__ = ["np_adjust", "transform", "types", "observable"]
diff --git a/modopt/base/backend.py b/src/modopt/base/backend.py
similarity index 77%
rename from modopt/base/backend.py
rename to src/modopt/base/backend.py
index 1f4e9a72..485f649a 100644
--- a/modopt/base/backend.py
+++ b/src/modopt/base/backend.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""BACKEND MODULE.
This module contains methods for GPU Compatiblity.
@@ -26,22 +24,24 @@
# Handle the compatibility with variable
LIBRARIES = {
- 'cupy': None,
- 'tensorflow': None,
- 'numpy': np,
+ "cupy": None,
+ "tensorflow": None,
+ "numpy": np,
}
-if util.find_spec('cupy') is not None:
+if util.find_spec("cupy") is not None:
try:
import cupy as cp
- LIBRARIES['cupy'] = cp
+
+ LIBRARIES["cupy"] = cp
except ImportError:
pass
-if util.find_spec('tensorflow') is not None:
+if util.find_spec("tensorflow") is not None:
try:
from tensorflow.experimental import numpy as tnp
- LIBRARIES['tensorflow'] = tnp
+
+ LIBRARIES["tensorflow"] = tnp
except ImportError:
pass
@@ -66,12 +66,12 @@ def get_backend(backend):
"""
if backend not in LIBRARIES.keys() or LIBRARIES[backend] is None:
msg = (
- '{0} backend not possible, please ensure that '
- + 'the optional libraries are installed.\n'
- + 'Reverting to numpy.'
+ "{0} backend not possible, please ensure that "
+ + "the optional libraries are installed.\n"
+ + "Reverting to numpy."
)
warn(msg.format(backend))
- backend = 'numpy'
+ backend = "numpy"
return LIBRARIES[backend], backend
@@ -92,16 +92,16 @@ def get_array_module(input_data):
The numpy or cupy module
"""
- if LIBRARIES['tensorflow'] is not None:
- if isinstance(input_data, LIBRARIES['tensorflow'].ndarray):
- return LIBRARIES['tensorflow']
- if LIBRARIES['cupy'] is not None:
- if isinstance(input_data, LIBRARIES['cupy'].ndarray):
- return LIBRARIES['cupy']
+ if LIBRARIES["tensorflow"] is not None:
+ if isinstance(input_data, LIBRARIES["tensorflow"].ndarray):
+ return LIBRARIES["tensorflow"]
+ if LIBRARIES["cupy"] is not None:
+ if isinstance(input_data, LIBRARIES["cupy"].ndarray):
+ return LIBRARIES["cupy"]
return np
-def change_backend(input_data, backend='cupy'):
+def change_backend(input_data, backend="cupy"):
"""Move data to device.
This method changes the backend of an array. This can be used to copy data
@@ -151,13 +151,13 @@ def move_to_cpu(input_data):
"""
xp = get_array_module(input_data)
- if xp == LIBRARIES['numpy']:
+ if xp == LIBRARIES["numpy"]:
return input_data
- elif xp == LIBRARIES['cupy']:
+ elif xp == LIBRARIES["cupy"]:
return input_data.get()
- elif xp == LIBRARIES['tensorflow']:
+ elif xp == LIBRARIES["tensorflow"]:
return input_data.data.numpy()
- raise ValueError('Cannot identify the array type.')
+ raise ValueError("Cannot identify the array type.")
def convert_to_tensor(input_data):
@@ -184,9 +184,9 @@ def convert_to_tensor(input_data):
"""
if not import_torch:
raise ImportError(
- 'Required version of Torch package not found'
- + 'see documentation for details: https://cea-cosmic.'
- + 'github.io/ModOpt/#optional-packages',
+            "Required version of Torch package not found, "
+ + "see documentation for details: https://cea-cosmic."
+ + "github.io/ModOpt/#optional-packages",
)
xp = get_array_module(input_data)
@@ -220,9 +220,9 @@ def convert_to_cupy_array(input_data):
"""
if not import_torch:
raise ImportError(
- 'Required version of Torch package not found'
- + 'see documentation for details: https://cea-cosmic.'
- + 'github.io/ModOpt/#optional-packages',
+            "Required version of Torch package not found, "
+ + "see documentation for details: https://cea-cosmic."
+ + "github.io/ModOpt/#optional-packages",
)
if input_data.is_cuda:
diff --git a/modopt/base/np_adjust.py b/src/modopt/base/np_adjust.py
similarity index 96%
rename from modopt/base/np_adjust.py
rename to src/modopt/base/np_adjust.py
index 6d290e43..10cb5c29 100644
--- a/modopt/base/np_adjust.py
+++ b/src/modopt/base/np_adjust.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""NUMPY ADJUSTMENT ROUTINES.
This module contains methods for adjusting the default output for certain
@@ -154,8 +152,7 @@ def pad2d(input_data, padding):
padding = np.array(padding)
elif not isinstance(padding, np.ndarray):
raise ValueError(
- 'Padding must be an integer or a tuple (or list, np.ndarray) '
- + 'of itegers',
+ "Padding must be an integer or a tuple (or list, np.ndarray) of integers",
)
if padding.size == 1:
@@ -164,7 +161,7 @@ def pad2d(input_data, padding):
pad_x = (padding[0], padding[0])
pad_y = (padding[1], padding[1])
- return np.pad(input_data, (pad_x, pad_y), 'constant')
+ return np.pad(input_data, (pad_x, pad_y), "constant")
def ftr(input_data):
diff --git a/modopt/base/observable.py b/src/modopt/base/observable.py
similarity index 95%
rename from modopt/base/observable.py
rename to src/modopt/base/observable.py
index 6471ba58..bf8371c3 100644
--- a/modopt/base/observable.py
+++ b/src/modopt/base/observable.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""Observable.
This module contains observable classes
@@ -13,13 +11,13 @@
import numpy as np
-class SignalObject(object):
+class SignalObject:
"""Dummy class for signals."""
pass
-class Observable(object):
+class Observable:
"""Base class for observable classes.
This class defines a simple interface to add or remove observers
@@ -33,7 +31,6 @@ class Observable(object):
"""
def __init__(self, signals):
-
# Define class parameters
self._allowed_signals = []
self._observers = {}
@@ -177,7 +174,7 @@ def _remove_observer(self, signal, observer):
self._observers[signal].remove(observer)
-class MetricObserver(object):
+class MetricObserver:
"""Metric observer.
Wrapper of the metric to the observer object notify by the Observable
@@ -215,7 +212,6 @@ def __init__(
wind=6,
eps=1.0e-3,
):
-
self.name = name
self.metric = metric
self.mapping = mapping
@@ -264,9 +260,7 @@ def is_converge(self):
mid_idx = -(self.wind // 2)
old_mean = np.array(self.list_cv_values[start_idx:mid_idx]).mean()
current_mean = np.array(self.list_cv_values[mid_idx:]).mean()
- normalize_residual_metrics = (
- np.abs(old_mean - current_mean) / np.abs(old_mean)
- )
+ normalize_residual_metrics = np.abs(old_mean - current_mean) / np.abs(old_mean)
self.converge_flag = normalize_residual_metrics < self.eps
def retrieve_metrics(self):
@@ -287,7 +281,7 @@ def retrieve_metrics(self):
time_val -= time_val[0]
return {
- 'time': time_val,
- 'index': self.list_iters,
- 'values': self.list_cv_values,
+ "time": time_val,
+ "index": self.list_iters,
+ "values": self.list_cv_values,
}
diff --git a/modopt/base/transform.py b/src/modopt/base/transform.py
similarity index 87%
rename from modopt/base/transform.py
rename to src/modopt/base/transform.py
index 07ce846f..25ed102a 100644
--- a/modopt/base/transform.py
+++ b/src/modopt/base/transform.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""DATA TRANSFORM ROUTINES.
This module contains methods for transforming data.
@@ -53,18 +51,17 @@ def cube2map(data_cube, layout):
"""
if data_cube.ndim != 3:
- raise ValueError('The input data must have 3 dimensions.')
+ raise ValueError("The input data must have 3 dimensions.")
if data_cube.shape[0] != np.prod(layout):
raise ValueError(
- 'The desired layout must match the number of input '
- + 'data layers.',
+            "The desired layout must match the number of input data layers.",
)
- res = ([
+ res = [
np.hstack(data_cube[slice(layout[1] * elem, layout[1] * (elem + 1))])
for elem in range(layout[0])
- ])
+ ]
return np.vstack(res)
@@ -118,20 +115,24 @@ def map2cube(data_map, layout):
"""
if np.all(np.array(data_map.shape) % np.array(layout)):
raise ValueError(
- 'The desired layout must be a multiple of the number '
- + 'pixels in the data map.',
+            "The desired layout must be a multiple of the number of "
+            + "pixels in the data map.",
)
d_shape = np.array(data_map.shape) // np.array(layout)
- return np.array([
- data_map[(
- slice(i_elem * d_shape[0], (i_elem + 1) * d_shape[0]),
- slice(j_elem * d_shape[1], (j_elem + 1) * d_shape[1]),
- )]
- for i_elem in range(layout[0])
- for j_elem in range(layout[1])
- ])
+ return np.array(
+ [
+ data_map[
+ (
+ slice(i_elem * d_shape[0], (i_elem + 1) * d_shape[0]),
+ slice(j_elem * d_shape[1], (j_elem + 1) * d_shape[1]),
+ )
+ ]
+ for i_elem in range(layout[0])
+ for j_elem in range(layout[1])
+ ]
+ )
def map2matrix(data_map, layout):
@@ -186,9 +187,9 @@ def map2matrix(data_map, layout):
image_shape * (i_elem % layout[1] + 1),
)
data_matrix.append(
- (
- data_map[lower[0]:upper[0], lower[1]:upper[1]]
- ).reshape(image_shape ** 2),
+ (data_map[lower[0] : upper[0], lower[1] : upper[1]]).reshape(
+ image_shape**2
+ ),
)
return np.array(data_matrix).T
@@ -232,7 +233,7 @@ def matrix2map(data_matrix, map_shape):
# Get the shape and layout of the images
image_shape = np.sqrt(data_matrix.shape[0]).astype(int)
- layout = np.array(map_shape // np.repeat(image_shape, 2), dtype='int')
+ layout = np.array(map_shape // np.repeat(image_shape, 2), dtype="int")
# Map objects from matrix
data_map = np.zeros(map_shape)
@@ -248,7 +249,7 @@ def matrix2map(data_matrix, map_shape):
image_shape * (i_elem // layout[1] + 1),
image_shape * (i_elem % layout[1] + 1),
)
- data_map[lower[0]:upper[0], lower[1]:upper[1]] = temp[:, :, i_elem]
+ data_map[lower[0] : upper[0], lower[1] : upper[1]] = temp[:, :, i_elem]
return data_map.astype(int)
@@ -285,7 +286,7 @@ def cube2matrix(data_cube):
"""
return data_cube.reshape(
- [data_cube.shape[0]] + [np.prod(data_cube.shape[1:])],
+ [data_cube.shape[0], np.prod(data_cube.shape[1:])],
).T
@@ -330,4 +331,4 @@ def matrix2cube(data_matrix, im_shape):
cube2matrix : complimentary function
"""
- return data_matrix.T.reshape([data_matrix.shape[1]] + list(im_shape))
+ return data_matrix.T.reshape([data_matrix.shape[1], *list(im_shape)])
diff --git a/modopt/base/types.py b/src/modopt/base/types.py
similarity index 82%
rename from modopt/base/types.py
rename to src/modopt/base/types.py
index 16e06f15..9e9a15b9 100644
--- a/modopt/base/types.py
+++ b/src/modopt/base/types.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""TYPE HANDLING ROUTINES.
This module contains methods for handing object types.
@@ -30,7 +28,7 @@ def check_callable(input_obj):
For invalid input type
"""
if not callable(input_obj):
- raise TypeError('The input object must be a callable function.')
+ raise TypeError("The input object must be a callable function.")
return input_obj
@@ -71,14 +69,13 @@ def check_float(input_obj):
"""
if not isinstance(input_obj, (int, float, list, tuple, np.ndarray)):
- raise TypeError('Invalid input type.')
+ raise TypeError("Invalid input type.")
if isinstance(input_obj, int):
input_obj = float(input_obj)
elif isinstance(input_obj, (list, tuple)):
input_obj = np.array(input_obj, dtype=float)
- elif (
- isinstance(input_obj, np.ndarray)
- and (not np.issubdtype(input_obj.dtype, np.floating))
+ elif isinstance(input_obj, np.ndarray) and (
+ not np.issubdtype(input_obj.dtype, np.floating)
):
input_obj = input_obj.astype(float)
@@ -121,14 +118,13 @@ def check_int(input_obj):
"""
if not isinstance(input_obj, (int, float, list, tuple, np.ndarray)):
- raise TypeError('Invalid input type.')
+ raise TypeError("Invalid input type.")
if isinstance(input_obj, float):
input_obj = int(input_obj)
elif isinstance(input_obj, (list, tuple)):
input_obj = np.array(input_obj, dtype=int)
- elif (
- isinstance(input_obj, np.ndarray)
- and (not np.issubdtype(input_obj.dtype, np.integer))
+ elif isinstance(input_obj, np.ndarray) and (
+ not np.issubdtype(input_obj.dtype, np.integer)
):
input_obj = input_obj.astype(int)
@@ -160,19 +156,18 @@ def check_npndarray(input_obj, dtype=None, writeable=True, verbose=True):
"""
if not isinstance(input_obj, np.ndarray):
- raise TypeError('Input is not a numpy array.')
+ raise TypeError("Input is not a numpy array.")
-    if (
-        (not isinstance(dtype, type(None)))
-        and (not np.issubdtype(input_obj.dtype, dtype))
+    if (not isinstance(dtype, type(None))) and (
+        not np.issubdtype(input_obj.dtype, dtype)
     ):
-        raise (
-            TypeError(
-                'The numpy array elements are not of type: {0}'.format(dtype),
-            ),
-        )
+        raise TypeError(
+            f"The numpy array elements are not of type: {dtype}",
+        )
if not writeable and verbose and input_obj.flags.writeable:
- warn('Making input data immutable.')
+ warn("Making input data immutable.")
input_obj.flags.writeable = writeable
diff --git a/modopt/interface/__init__.py b/src/modopt/interface/__init__.py
similarity index 75%
rename from modopt/interface/__init__.py
rename to src/modopt/interface/__init__.py
index f9439747..a54f4bf5 100644
--- a/modopt/interface/__init__.py
+++ b/src/modopt/interface/__init__.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""INTERFACE ROUTINES.
This module contains submodules for error handling, logging and IO interaction.
@@ -8,4 +6,4 @@
"""
-__all__ = ['errors', 'log']
+__all__ = ["errors", "log"]
diff --git a/modopt/interface/errors.py b/src/modopt/interface/errors.py
similarity index 75%
rename from modopt/interface/errors.py
rename to src/modopt/interface/errors.py
index 0fbe7e71..84031e3c 100644
--- a/modopt/interface/errors.py
+++ b/src/modopt/interface/errors.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""ERROR HANDLING ROUTINES.
This module contains methods for handing warnings and errors.
@@ -34,16 +32,16 @@ def warn(warn_string, log=None):
"""
if import_fail:
- warn_txt = 'WARNING'
+ warn_txt = "WARNING"
else:
- warn_txt = colored('WARNING', 'yellow')
+ warn_txt = colored("WARNING", "yellow")
# Print warning to stdout.
- sys.stderr.write('{0}: {1}\n'.format(warn_txt, warn_string))
+ sys.stderr.write(f"{warn_txt}: {warn_string}\n")
# Check if a logging structure is provided.
if not isinstance(log, type(None)):
- warnings.warn(warn_string)
+ warnings.warn(warn_string, stacklevel=2)
def catch_error(exception, log=None):
@@ -61,17 +59,17 @@ def catch_error(exception, log=None):
"""
if import_fail:
- err_txt = 'ERROR'
+ err_txt = "ERROR"
else:
- err_txt = colored('ERROR', 'red')
+ err_txt = colored("ERROR", "red")
# Print exception to stdout.
- stream_txt = '{0}: {1}\n'.format(err_txt, exception)
+ stream_txt = f"{err_txt}: {exception}\n"
sys.stderr.write(stream_txt)
# Check if a logging structure is provided.
if not isinstance(log, type(None)):
- log_txt = 'ERROR: {0}\n'.format(exception)
+ log_txt = f"ERROR: {exception}\n"
log.exception(log_txt)
@@ -91,11 +89,11 @@ def file_name_error(file_name):
If file name not specified or file not found
"""
- if file_name == '' or file_name[0][0] == '-':
- raise IOError('Input file name not specified.')
+ if file_name == "" or file_name[0][0] == "-":
+ raise OSError("Input file name not specified.")
elif not os.path.isfile(file_name):
- raise IOError('Input file name {0} not found!'.format(file_name))
+ raise OSError(f"Input file name {file_name} not found!")
def is_exe(fpath):
@@ -136,7 +134,7 @@ def is_executable(exe_name):
"""
if not isinstance(exe_name, str):
- raise TypeError('Executable name must be a string.')
+ raise TypeError("Executable name must be a string.")
fpath, fname = os.path.split(exe_name)
@@ -146,11 +144,9 @@ def is_executable(exe_name):
else:
res = any(
is_exe(os.path.join(path, exe_name))
- for path in os.environ['PATH'].split(os.pathsep)
+ for path in os.environ["PATH"].split(os.pathsep)
)
if not res:
- message = (
- '{0} does not appear to be a valid executable on this system.'
- )
- raise IOError(message.format(exe_name))
+ message = "{0} does not appear to be a valid executable on this system."
+ raise OSError(message.format(exe_name))
diff --git a/modopt/interface/log.py b/src/modopt/interface/log.py
similarity index 77%
rename from modopt/interface/log.py
rename to src/modopt/interface/log.py
index 3b2fa77a..50c316b7 100644
--- a/modopt/interface/log.py
+++ b/src/modopt/interface/log.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""LOGGING ROUTINES.
This module contains methods for handing logging.
@@ -30,22 +28,22 @@ def set_up_log(filename, verbose=True):
"""
# Add file extension.
- filename = '{0}.log'.format(filename)
+ filename = f"{filename}.log"
if verbose:
- print('Preparing log file:', filename)
+ print("Preparing log file:", filename)
# Capture warnings.
logging.captureWarnings(True)
# Set output format.
formatter = logging.Formatter(
- fmt='%(asctime)s %(message)s',
- datefmt='%d/%m/%Y %H:%M:%S',
+ fmt="%(asctime)s %(message)s",
+ datefmt="%d/%m/%Y %H:%M:%S",
)
# Create file handler.
- fh = logging.FileHandler(filename=filename, mode='w')
+ fh = logging.FileHandler(filename=filename, mode="w")
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
@@ -55,7 +53,7 @@ def set_up_log(filename, verbose=True):
log.addHandler(fh)
# Send opening message.
- log.info('The log file has been set-up.')
+ log.info("The log file has been set-up.")
return log
@@ -74,10 +72,10 @@ def close_log(log, verbose=True):
"""
if verbose:
- print('Closing log file:', log.name)
+ print("Closing log file:", log.name)
# Send closing message.
- log.info('The log file has been closed.')
+ log.info("The log file has been closed.")
# Remove all handlers from log.
for log_handler in log.handlers:
diff --git a/modopt/math/__init__.py b/src/modopt/math/__init__.py
similarity index 64%
rename from modopt/math/__init__.py
rename to src/modopt/math/__init__.py
index a22c0c98..d5ffc67a 100644
--- a/modopt/math/__init__.py
+++ b/src/modopt/math/__init__.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""MATHEMATICS ROUTINES.
This module contains submodules for mathematical applications.
@@ -8,4 +6,4 @@
"""
-__all__ = ['convolve', 'matrix', 'stats', 'metrics']
+__all__ = ["convolve", "matrix", "stats", "metrics"]
diff --git a/modopt/math/convolve.py b/src/modopt/math/convolve.py
similarity index 87%
rename from modopt/math/convolve.py
rename to src/modopt/math/convolve.py
index a4322ff2..21dc8b4e 100644
--- a/modopt/math/convolve.py
+++ b/src/modopt/math/convolve.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""CONVOLUTION ROUTINES.
This module contains methods for convolution.
@@ -18,7 +16,7 @@
from astropy.convolution import convolve_fft
except ImportError: # pragma: no cover
import_astropy = False
- warn('astropy not found, will default to scipy for convolution')
+ warn("astropy not found, will default to scipy for convolution")
else:
import_astropy = True
try:
@@ -30,7 +28,7 @@
warn('Using pyFFTW "monkey patch" for scipy.fftpack')
-def convolve(input_data, kernel, method='scipy'):
+def convolve(input_data, kernel, method="scipy"):
"""Convolve data with kernel.
This method convolves the input data with a given kernel using FFT and
@@ -80,29 +78,29 @@ def convolve(input_data, kernel, method='scipy'):
"""
if input_data.ndim != kernel.ndim:
- raise ValueError('Data and kernel must have the same dimensions.')
+ raise ValueError("Data and kernel must have the same dimensions.")
- if method not in {'astropy', 'scipy'}:
+ if method not in {"astropy", "scipy"}:
raise ValueError('Invalid method. Options are "astropy" or "scipy".')
if not import_astropy: # pragma: no cover
- method = 'scipy'
+ method = "scipy"
- if method == 'astropy':
+ if method == "astropy":
return convolve_fft(
input_data,
kernel,
- boundary='wrap',
+ boundary="wrap",
crop=False,
- nan_treatment='fill',
+ nan_treatment="fill",
normalize_kernel=False,
)
- elif method == 'scipy':
- return scipy.signal.fftconvolve(input_data, kernel, mode='same')
+ elif method == "scipy":
+ return scipy.signal.fftconvolve(input_data, kernel, mode="same")
-def convolve_stack(input_data, kernel, rot_kernel=False, method='scipy'):
+def convolve_stack(input_data, kernel, rot_kernel=False, method="scipy"):
"""Convolve stack of data with stack of kernels.
This method convolves the input data with a given kernel using FFT and
@@ -156,7 +154,9 @@ def convolve_stack(input_data, kernel, rot_kernel=False, method='scipy'):
if rot_kernel:
kernel = rotate_stack(kernel)
- return np.array([
- convolve(data_i, kernel_i, method=method)
- for data_i, kernel_i in zip(input_data, kernel)
- ])
+ return np.array(
+ [
+ convolve(data_i, kernel_i, method=method)
+ for data_i, kernel_i in zip(input_data, kernel)
+ ]
+ )
diff --git a/modopt/math/matrix.py b/src/modopt/math/matrix.py
similarity index 90%
rename from modopt/math/matrix.py
rename to src/modopt/math/matrix.py
index 8361531d..b200f15d 100644
--- a/modopt/math/matrix.py
+++ b/src/modopt/math/matrix.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""MATRIX ROUTINES.
This module contains methods for matrix operations.
@@ -15,7 +13,7 @@
from modopt.base.backend import get_array_module, get_backend
-def gram_schmidt(matrix, return_opt='orthonormal'):
+def gram_schmidt(matrix, return_opt="orthonormal"):
r"""Gram-Schmit.
This method orthonormalizes the row vectors of the input matrix.
@@ -55,7 +53,7 @@ def gram_schmidt(matrix, return_opt='orthonormal'):
https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process
"""
- if return_opt not in {'orthonormal', 'orthogonal', 'both'}:
+ if return_opt not in {"orthonormal", "orthogonal", "both"}:
raise ValueError(
'Invalid return_opt, options are: "orthonormal", "orthogonal" or '
+ '"both"',
@@ -65,7 +63,6 @@ def gram_schmidt(matrix, return_opt='orthonormal'):
e_vec = []
for vector in matrix:
-
if u_vec:
u_now = vector - sum(project(u_i, vector) for u_i in u_vec)
else:
@@ -77,11 +74,11 @@ def gram_schmidt(matrix, return_opt='orthonormal'):
u_vec = np.array(u_vec)
e_vec = np.array(e_vec)
- if return_opt == 'orthonormal':
+ if return_opt == "orthonormal":
return e_vec
- elif return_opt == 'orthogonal':
+ elif return_opt == "orthogonal":
return u_vec
- elif return_opt == 'both':
+ elif return_opt == "both":
return u_vec, e_vec
@@ -201,7 +198,7 @@ def rot_matrix(angle):
return np.around(
np.array(
[[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]],
- dtype='float',
+ dtype="float",
),
10,
)
@@ -243,22 +240,21 @@ def rotate(matrix, angle):
shape = np.array(matrix.shape)
if shape[0] != shape[1]:
- raise ValueError('Input matrix must be square.')
+ raise ValueError("Input matrix must be square.")
shift = (shape - 1) // 2
index = (
- np.array(list(product(*np.array([np.arange(sval) for sval in shape]))))
- - shift
+ np.array(list(product(*np.array([np.arange(sval) for sval in shape])))) - shift
)
- new_index = np.array(np.dot(index, rot_matrix(angle)), dtype='int') + shift
+ new_index = np.array(np.dot(index, rot_matrix(angle)), dtype="int") + shift
new_index[new_index >= shape[0]] -= shape[0]
return matrix[tuple(zip(new_index.T))].reshape(shape.T)
-class PowerMethod(object):
+class PowerMethod:
"""Power method class.
This method performs implements power method to calculate the spectral
@@ -277,6 +273,8 @@ class PowerMethod(object):
initialisation (default is ``True``)
verbose : bool, optional
Optional verbosity (default is ``False``)
+    rng : int, xp.random.Generator or None, optional
+        Random number generator or seed (default is ``None``)
Examples
--------
@@ -301,16 +299,17 @@ def __init__(
data_shape,
data_type=float,
auto_run=True,
- compute_backend='numpy',
+ compute_backend="numpy",
verbose=False,
+ rng=None,
):
-
self._operator = operator
self._data_shape = data_shape
self._data_type = data_type
self._verbose = verbose
xp, compute_backend = get_backend(compute_backend)
self.xp = xp
+        self.rng = rng
self.compute_backend = compute_backend
if auto_run:
self.get_spec_rad()
@@ -327,7 +326,8 @@ def _set_initial_x(self):
Random values of the same shape as the input data
"""
- return self.xp.random.random(self._data_shape).astype(self._data_type)
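+        # `self.rng` may be a seed or a Generator; default_rng accepts
+        # either and returns a fresh Generator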
+ rng = self.xp.random.default_rng(self.rng)
+ return rng.random(self._data_shape).astype(self._data_type)
def get_spec_rad(self, tolerance=1e-6, max_iter=20, extra_factor=1.0):
"""Get spectral radius.
@@ -363,18 +363,14 @@ def get_spec_rad(self, tolerance=1e-6, max_iter=20, extra_factor=1.0):
x_new /= x_new_norm
- if (xp.abs(x_new_norm - x_old_norm) < tolerance):
- message = (
- ' - Power Method converged after {0} iterations!'
- )
+ if xp.abs(x_new_norm - x_old_norm) < tolerance:
+ message = " - Power Method converged after {0} iterations!"
if self._verbose:
print(message.format(i_elem + 1))
break
elif i_elem == max_iter - 1 and self._verbose:
- message = (
- ' - Power Method did not converge after {0} iterations!'
- )
+ message = " - Power Method did not converge after {0} iterations!"
print(message.format(max_iter))
xp.copyto(x_old, x_new)
diff --git a/modopt/math/metrics.py b/src/modopt/math/metrics.py
similarity index 91%
rename from modopt/math/metrics.py
rename to src/modopt/math/metrics.py
index 21952624..befd4fa4 100644
--- a/modopt/math/metrics.py
+++ b/src/modopt/math/metrics.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""METRICS.
This module contains classes of different metric functions for optimization.
@@ -71,15 +69,13 @@ def _preprocess_input(test, ref, mask=None):
The SNR
"""
- test = np.abs(np.copy(test)).astype('float64')
- ref = np.abs(np.copy(ref)).astype('float64')
+ test = np.abs(np.copy(test)).astype("float64")
+ ref = np.abs(np.copy(ref)).astype("float64")
test = min_max_normalize(test)
ref = min_max_normalize(ref)
if (not isinstance(mask, np.ndarray)) and (mask is not None):
- message = (
- 'Mask should be None, or a numpy.ndarray, got "{0}" instead.'
- )
+ message = 'Mask should be None, or a numpy.ndarray, got "{0}" instead.'
raise ValueError(message.format(mask))
if mask is None:
@@ -119,9 +115,9 @@ def ssim(test, ref, mask=None):
"""
if not import_skimage: # pragma: no cover
raise ImportError(
- 'Required version of Scikit-Image package not found'
- + 'see documentation for details: https://cea-cosmic.'
- + 'github.io/ModOpt/#optional-packages',
+        "Required version of Scikit-Image package not found, "
+ + "see documentation for details: https://cea-cosmic."
+ + "github.io/ModOpt/#optional-packages",
)
test, ref, mask = _preprocess_input(test, ref, mask)
@@ -270,6 +266,6 @@ def nrmse(test, ref, mask=None):
ref = mask * ref
num = np.sqrt(mse(test, ref))
- deno = np.sqrt(np.mean((np.square(test))))
+ deno = np.sqrt(np.mean(np.square(test)))
return num / deno
diff --git a/modopt/math/stats.py b/src/modopt/math/stats.py
similarity index 97%
rename from modopt/math/stats.py
rename to src/modopt/math/stats.py
index 59bf6759..8583a8c3 100644
--- a/modopt/math/stats.py
+++ b/src/modopt/math/stats.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""STATISTICS ROUTINES.
This module contains methods for basic statistics.
@@ -31,8 +29,8 @@ def gaussian_kernel(data_shape, sigma, norm="max"):
Desiered shape of the kernel
sigma : float
Standard deviation of the kernel
- norm : {'max', 'sum'}, optional
- Normalisation of the kerenl (options are ``'max'`` or ``'sum'``, default is ``'max'``)
+ norm : {'max', 'sum'}, optional, default='max'
+ Normalisation of the kernel
Returns
-------
diff --git a/modopt/opt/__init__.py b/src/modopt/opt/__init__.py
similarity index 59%
rename from modopt/opt/__init__.py
rename to src/modopt/opt/__init__.py
index 2fd3d747..62d1f388 100644
--- a/modopt/opt/__init__.py
+++ b/src/modopt/opt/__init__.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""OPTIMISATION PROBLEM MODULES.
This module contains submodules for solving optimisation problems.
@@ -8,4 +6,4 @@
"""
-__all__ = ['cost', 'gradient', 'linear', 'algorithms', 'proximity', 'reweight']
+__all__ = ["cost", "gradient", "linear", "algorithms", "proximity", "reweight"]
diff --git a/modopt/opt/algorithms/__init__.py b/src/modopt/opt/algorithms/__init__.py
similarity index 53%
rename from modopt/opt/algorithms/__init__.py
rename to src/modopt/opt/algorithms/__init__.py
index d4e7082b..ff79502c 100644
--- a/modopt/opt/algorithms/__init__.py
+++ b/src/modopt/opt/algorithms/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
r"""OPTIMISATION ALGORITHMS.
This module contains class implementations of various optimisation algoritms.
@@ -45,16 +44,32 @@
"""
-from modopt.opt.algorithms.base import SetUp
-from modopt.opt.algorithms.forward_backward import (FISTA, POGM,
- ForwardBackward,
- GenForwardBackward)
-from modopt.opt.algorithms.gradient_descent import (AdaGenericGradOpt,
- ADAMGradOpt,
- GenericGradOpt,
- MomentumGradOpt,
- RMSpropGradOpt,
- SAGAOptGradOpt,
- VanillaGenericGradOpt)
-from modopt.opt.algorithms.primal_dual import Condat
-from modopt.opt.algorithms.admm import ADMM, FastADMM
+from .forward_backward import FISTA, ForwardBackward, GenForwardBackward, POGM
+from .primal_dual import Condat
+from .gradient_descent import (
+ ADAMGradOpt,
+ AdaGenericGradOpt,
+ GenericGradOpt,
+ MomentumGradOpt,
+ RMSpropGradOpt,
+ SAGAOptGradOpt,
+ VanillaGenericGradOpt,
+)
+from .admm import ADMM, FastADMM
+
+__all__ = [
+ "FISTA",
+ "ForwardBackward",
+ "GenForwardBackward",
+ "POGM",
+ "Condat",
+ "ADAMGradOpt",
+ "AdaGenericGradOpt",
+ "GenericGradOpt",
+ "MomentumGradOpt",
+ "RMSpropGradOpt",
+ "SAGAOptGradOpt",
+ "VanillaGenericGradOpt",
+ "ADMM",
+ "FastADMM",
+]
diff --git a/modopt/opt/algorithms/admm.py b/src/modopt/opt/algorithms/admm.py
similarity index 95%
rename from modopt/opt/algorithms/admm.py
rename to src/modopt/opt/algorithms/admm.py
index b881b770..b2f45171 100644
--- a/modopt/opt/algorithms/admm.py
+++ b/src/modopt/opt/algorithms/admm.py
@@ -1,4 +1,5 @@
"""ADMM Algorithms."""
+
import numpy as np
from modopt.base.backend import get_array_module
@@ -67,7 +68,8 @@ def _calc_cost(self, u, v, **kwargs):
class ADMM(SetUp):
r"""Fast ADMM Optimisation Algorihm.
- This class implement the ADMM algorithm described in :cite:`Goldstein2014` (Algorithm 1).
+    This class implements the ADMM algorithm described in :cite:`Goldstein2014`
+ (Algorithm 1).
Parameters
----------
@@ -85,7 +87,7 @@ class ADMM(SetUp):
Constraint vector
optimizers: tuple
2-tuple of callable, that are the optimizers for the u and v.
- Each callable should access an init and obs argument and returns an estimate for:
+        Each callable should accept init and obs arguments and return an estimate for:
.. math:: u_{k+1} = \argmin H(u) + \frac{\tau}{2}\|A u - y\|^2
.. math:: v_{k+1} = \argmin G(v) + \frac{\tau}{2}\|Bv - y \|^2
cost_funcs: tuple
@@ -188,7 +190,7 @@ def iterate(self, max_iter=150):
self.retrieve_outputs()
# rename outputs as attributes
self.u_final = self._u_new
- self.x_final = self.u_final # for backward compatibility
+ self.x_final = self.u_final # for backward compatibility
self.v_final = self._v_new
def get_notify_observers_kwargs(self):
@@ -203,9 +205,9 @@ def get_notify_observers_kwargs(self):
The mapping between the iterated variables
"""
return {
- 'x_new': self._u_new,
- 'v_new': self._v_new,
- 'idx': self.idx,
+ "x_new": self._u_new,
+ "v_new": self._v_new,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -215,7 +217,7 @@ def retrieve_outputs(self):
y_final, metrics.
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
@@ -242,7 +244,7 @@ class FastADMM(ADMM):
Constraint vector
optimizers: tuple
2-tuple of callable, that are the optimizers for the u and v.
- Each callable should access an init and obs argument and returns an estimate for:
+        Each callable should accept init and obs arguments and return an estimate for:
.. math:: u_{k+1} = \argmin H(u) + \frac{\tau}{2}\|A u - y\|^2
.. math:: v_{k+1} = \argmin G(v) + \frac{\tau}{2}\|Bv - y \|^2
cost_funcs: tuple
@@ -256,7 +258,8 @@ class FastADMM(ADMM):
Notes
-----
- This is an accelerated version of the ADMM algorithm. The convergence hypothesis are stronger than for the ADMM algorithm.
+    This is an accelerated version of the ADMM algorithm. Its convergence
+    hypotheses are stronger than for the ADMM algorithm.
See Also
--------
diff --git a/modopt/opt/algorithms/base.py b/src/modopt/opt/algorithms/base.py
similarity index 87%
rename from modopt/opt/algorithms/base.py
rename to src/modopt/opt/algorithms/base.py
index c5a4b101..f7391063 100644
--- a/modopt/opt/algorithms/base.py
+++ b/src/modopt/opt/algorithms/base.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""Base SetUp for optimisation algorithms."""
from inspect import getmro
@@ -69,7 +68,7 @@ def __init__(
verbose=False,
progress=True,
step_size=None,
- compute_backend='numpy',
+ compute_backend="numpy",
**dummy_kwargs,
):
self.idx = 0
@@ -79,26 +78,26 @@ def __init__(
self.metrics = metrics
self.step_size = step_size
self._op_parents = (
- 'GradParent',
- 'ProximityParent',
- 'LinearParent',
- 'costObj',
+ "GradParent",
+ "ProximityParent",
+ "LinearParent",
+ "costObj",
)
self.metric_call_period = metric_call_period
# Declaration of observers for metrics
- super().__init__(['cv_metrics'])
+ super().__init__(["cv_metrics"])
for name, dic in self.metrics.items():
observer = MetricObserver(
name,
- dic['metric'],
- dic['mapping'],
- dic['cst_kwargs'],
- dic['early_stopping'],
+ dic["metric"],
+ dic["mapping"],
+ dic["cst_kwargs"],
+ dic["early_stopping"],
)
- self.add_observer('cv_metrics', observer)
+ self.add_observer("cv_metrics", observer)
xp, compute_backend = backend.get_backend(compute_backend)
self.xp = xp
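For reference, the `metrics` dictionary consumed by this `__init__` must carry the four keys unpacked into `MetricObserver` above. A sketch of one valid entry (the `mse` metric and its mapping are illustrative assumptions; only the key names come from the code):

    import numpy as np

    def mse(x_new, truth):
        return np.mean(np.abs(x_new - truth) ** 2)

    metrics = {
        "mse": {
            "metric": mse,                         # callable evaluated during iteration
            "mapping": {"x_new": "x_new"},         # observer kwarg -> metric argument
            "cst_kwargs": {"truth": np.zeros(8)},  # constant keyword arguments
            "early_stopping": False,               # whether it may stop iteration early
        },
    }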
@@ -111,14 +110,13 @@ def metrics(self):
@metrics.setter
def metrics(self, metrics):
-
if isinstance(metrics, type(None)):
self._metrics = {}
elif isinstance(metrics, dict):
self._metrics = metrics
else:
raise TypeError(
- 'Metrics must be a dictionary, not {0}.'.format(type(metrics)),
+ f"Metrics must be a dictionary, not {type(metrics)}.",
)
def any_convergence_flag(self):
@@ -132,9 +130,7 @@ def any_convergence_flag(self):
True if any convergence criteria met
"""
- return any(
- obs.converge_flag for obs in self._observers['cv_metrics']
- )
+ return any(obs.converge_flag for obs in self._observers["cv_metrics"])
def copy_data(self, input_data):
"""Copy Data.
@@ -152,10 +148,12 @@ def copy_data(self, input_data):
Copy of input data
"""
- return self.xp.copy(backend.change_backend(
- input_data,
- self.compute_backend,
- ))
+ return self.xp.copy(
+ backend.change_backend(
+ input_data,
+ self.compute_backend,
+ )
+ )
def _check_input_data(self, input_data):
"""Check input data type.
@@ -175,7 +173,7 @@ def _check_input_data(self, input_data):
"""
if not (isinstance(input_data, (self.xp.ndarray, np.ndarray))):
raise TypeError(
- 'Input data must be a numpy array or backend array',
+ "Input data must be a numpy array or backend array",
)
def _check_param(self, param_val):
@@ -195,7 +193,7 @@ def _check_param(self, param_val):
"""
if not isinstance(param_val, float):
- raise TypeError('Algorithm parameter must be a float value.')
+ raise TypeError("Algorithm parameter must be a float value.")
def _check_param_update(self, param_update):
"""Check algorithm parameter update methods.
@@ -213,14 +211,13 @@ def _check_param_update(self, param_update):
For invalid input type
"""
- param_conditions = (
- not isinstance(param_update, type(None))
- and not callable(param_update)
+ param_conditions = not isinstance(param_update, type(None)) and not callable(
+ param_update
)
if param_conditions:
raise TypeError(
- 'Algorithm parameter update must be a callabale function.',
+ "Algorithm parameter update must be a callabale function.",
)
def _check_operator(self, operator):
@@ -239,7 +236,7 @@ def _check_operator(self, operator):
tree = [op_obj.__name__ for op_obj in getmro(operator.__class__)]
if not any(parent in tree for parent in self._op_parents):
- message = '{0} does not inherit an operator parent.'
+ message = "{0} does not inherit an operator parent."
warn(message.format(str(operator.__class__)))
def _compute_metrics(self):
@@ -250,7 +247,7 @@ def _compute_metrics(self):
"""
kwargs = self.get_notify_observers_kwargs()
- self.notify_observers('cv_metrics', **kwargs)
+ self.notify_observers("cv_metrics", **kwargs)
def _iterations(self, max_iter, progbar=None):
"""Iterate method.
@@ -273,7 +270,6 @@ def _iterations(self, max_iter, progbar=None):
# We do not call metrics if metrics is empty or metric call
# period is None
if self.metrics and self.metric_call_period is not None:
-
metric_conditions = (
self.idx % self.metric_call_period == 0
or self.idx == (max_iter - 1)
@@ -285,7 +281,7 @@ def _iterations(self, max_iter, progbar=None):
if self.converge:
if self.verbose:
- print(' - Converged!')
+ print(" - Converged!")
break
if progbar:
diff --git a/modopt/opt/algorithms/forward_backward.py b/src/modopt/opt/algorithms/forward_backward.py
similarity index 87%
rename from modopt/opt/algorithms/forward_backward.py
rename to src/modopt/opt/algorithms/forward_backward.py
index 702799c6..31927eb0 100644
--- a/modopt/opt/algorithms/forward_backward.py
+++ b/src/modopt/opt/algorithms/forward_backward.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""Forward-Backward Algorithms."""
import numpy as np
@@ -9,7 +8,7 @@
from modopt.opt.linear import Identity
-class FISTA(object):
+class FISTA:
r"""FISTA.
This class is inherited by optimisation classes to speed up convergence
@@ -52,12 +51,12 @@ class FISTA(object):
"""
_restarting_strategies = (
- 'adaptive', # option 1 in alg 4
- 'adaptive-i',
- 'adaptive-1',
- 'adaptive-ii', # option 2 in alg 4
- 'adaptive-2',
- 'greedy', # alg 5
+ "adaptive", # option 1 in alg 4
+ "adaptive-i",
+ "adaptive-1",
+ "adaptive-ii", # option 2 in alg 4
+ "adaptive-2",
+ "greedy", # alg 5
None, # no restarting
)
@@ -73,26 +72,28 @@ def __init__(
r_lazy=4,
**kwargs,
):
-
if isinstance(a_cd, type(None)):
- self.mode = 'regular'
+ self.mode = "regular"
self.p_lazy = p_lazy
self.q_lazy = q_lazy
self.r_lazy = r_lazy
elif a_cd > 2:
- self.mode = 'CD'
+ self.mode = "CD"
self.a_cd = a_cd
self._n = 0
else:
raise ValueError(
- 'a_cd must either be None (for regular mode) or a number > 2',
+ "a_cd must either be None (for regular mode) or a number > 2",
)
if restart_strategy in self._restarting_strategies:
self._check_restart_params(
- restart_strategy, min_beta, s_greedy, xi_restart,
+ restart_strategy,
+ min_beta,
+ s_greedy,
+ xi_restart,
)
self.restart_strategy = restart_strategy
self.min_beta = min_beta
@@ -100,10 +101,10 @@ def __init__(
self.xi_restart = xi_restart
else:
- message = 'Restarting strategy must be one of {0}.'
+ message = "Restarting strategy must be one of {0}."
raise ValueError(
message.format(
- ', '.join(self._restarting_strategies),
+ ", ".join(self._restarting_strategies),
),
)
self._t_now = 1.0
@@ -155,22 +156,20 @@ def _check_restart_params(
if restart_strategy is None:
return True
- if self.mode != 'regular':
+ if self.mode != "regular":
raise ValueError(
- 'Restarting strategies can only be used with regular mode.',
+ "Restarting strategies can only be used with regular mode.",
)
- greedy_params_check = (
- min_beta is None or s_greedy is None or s_greedy <= 1
- )
+ greedy_params_check = min_beta is None or s_greedy is None or s_greedy <= 1
- if restart_strategy == 'greedy' and greedy_params_check:
+ if restart_strategy == "greedy" and greedy_params_check:
raise ValueError(
- 'You need a min_beta and an s_greedy > 1 for greedy restart.',
+ "You need a min_beta and an s_greedy > 1 for greedy restart.",
)
if xi_restart is None or xi_restart >= 1:
- raise ValueError('You need a xi_restart < 1 for restart.')
+ raise ValueError("You need a xi_restart < 1 for restart.")
return True
@@ -210,12 +209,12 @@ def is_restart(self, z_old, x_new, x_old):
criterion = xp.vdot(z_old - x_new, x_new - x_old) >= 0
if criterion:
- if 'adaptive' in self.restart_strategy:
+ if "adaptive" in self.restart_strategy:
self.r_lazy *= self.xi_restart
- if self.restart_strategy in {'adaptive-ii', 'adaptive-2'}:
+ if self.restart_strategy in {"adaptive-ii", "adaptive-2"}:
self._t_now = 1
- if self.restart_strategy == 'greedy':
+ if self.restart_strategy == "greedy":
cur_delta = xp.linalg.norm(x_new - x_old)
if self._delta0 is None:
self._delta0 = self.s_greedy * cur_delta
@@ -269,17 +268,17 @@ def update_lambda(self, *args, **kwargs):
Implements steps 3 and 4 from algorithm 10.7 in :cite:`bauschke2009`.
"""
- if self.restart_strategy == 'greedy':
+ if self.restart_strategy == "greedy":
return 2
# Steps 3 and 4 from alg.10.7.
self._t_prev = self._t_now
- if self.mode == 'regular':
- sqrt_part = self.r_lazy * self._t_prev ** 2 + self.q_lazy
+ if self.mode == "regular":
+ sqrt_part = self.r_lazy * self._t_prev**2 + self.q_lazy
self._t_now = self.p_lazy + np.sqrt(sqrt_part) * 0.5
- elif self.mode == 'CD':
+ elif self.mode == "CD":
self._t_now = (self._n + self.a_cd - 1) / self.a_cd
self._n += 1
@@ -344,18 +343,17 @@ def __init__(
x,
grad,
prox,
- cost='auto',
+ cost="auto",
beta_param=1.0,
lambda_param=1.0,
beta_update=None,
- lambda_update='fista',
+ lambda_update="fista",
auto_iterate=True,
metric_call_period=5,
metrics=None,
linear=None,
**kwargs,
):
-
# Set default algorithm properties
super().__init__(
metric_call_period=metric_call_period,
@@ -376,7 +374,7 @@ def __init__(
self._prox = prox
self._linear = linear
- if cost == 'auto':
+ if cost == "auto":
self._cost_func = costObj([self._grad, self._prox])
else:
self._cost_func = cost
@@ -384,7 +382,7 @@ def __init__(
# Check if there is a linear op, needed for metrics in the FB algorithm
if metrics and self._linear is None:
raise ValueError(
- 'When using metrics, you must pass a linear operator',
+ "When using metrics, you must pass a linear operator",
)
if self._linear is None:
@@ -400,7 +398,7 @@ def __init__(
# Set the algorithm parameter update methods
self._check_param_update(beta_update)
self._beta_update = beta_update
- if isinstance(lambda_update, str) and lambda_update == 'fista':
+ if isinstance(lambda_update, str) and lambda_update == "fista":
fista = FISTA(**kwargs)
self._lambda_update = fista.update_lambda
self._is_restart = fista.is_restart
@@ -462,9 +460,8 @@ def _update(self):
# Test cost function for convergence.
if self._cost_func:
- self.converge = (
- self.any_convergence_flag()
- or self._cost_func.get_cost(self._x_new)
+ self.converge = self.any_convergence_flag() or self._cost_func.get_cost(
+ self._x_new
)
def iterate(self, max_iter=150, progbar=None):
@@ -500,9 +497,9 @@ def get_notify_observers_kwargs(self):
"""
return {
- 'x_new': self._linear.adj_op(self._x_new),
- 'z_new': self._z_new,
- 'idx': self.idx,
+ "x_new": self._linear.adj_op(self._x_new),
+ "z_new": self._z_new,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -513,7 +510,7 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
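End-to-end, a minimal use of this class with identity operators (so the solver converges to the soft-thresholded observation), assuming the positional `(x, grad, prox)` constructor shown in the hunk above:

    import numpy as np
    from modopt.opt.algorithms import ForwardBackward
    from modopt.opt.gradient import GradBasic
    from modopt.opt.linear import Identity
    from modopt.opt.proximity import SparseThreshold

    y = np.random.default_rng(0).standard_normal(16)
    grad = GradBasic(y, lambda x: x, lambda x: x)   # op and trans_op are identity
    prox = SparseThreshold(Identity(), 1e-1)

    fb = ForwardBackward(np.zeros(16), grad, prox, beta_param=1.0, auto_iterate=False)
    fb.iterate(max_iter=50)
    x_hat = fb.x_final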
@@ -577,7 +574,7 @@ def __init__(
x,
grad,
prox_list,
- cost='auto',
+ cost="auto",
gamma_param=1.0,
lambda_param=1.0,
gamma_update=None,
@@ -589,7 +586,6 @@ def __init__(
linear=None,
**kwargs,
):
-
# Set default algorithm properties
super().__init__(
metric_call_period=metric_call_period,
@@ -602,22 +598,22 @@ def __init__(
self._x_old = self.xp.copy(x)
# Set the algorithm operators
- for operator in [grad, cost] + prox_list:
+ for operator in [grad, cost, *prox_list]:
self._check_operator(operator)
self._grad = grad
self._prox_list = self.xp.array(prox_list)
self._linear = linear
- if cost == 'auto':
- self._cost_func = costObj([self._grad] + prox_list)
+ if cost == "auto":
+ self._cost_func = costObj([self._grad, *prox_list])
else:
self._cost_func = cost
# Check if there is a linear op, needed for metrics in the FB algorithm
if metrics and self._linear is None:
raise ValueError(
- 'When using metrics, you must pass a linear operator',
+ "When using metrics, you must pass a linear operator",
)
if self._linear is None:
@@ -641,9 +637,7 @@ def __init__(
self._set_weights(weights)
# Set initial z
- self._z = self.xp.array([
- self._x_old for i in range(self._prox_list.size)
- ])
+ self._z = self.xp.array([self._x_old for i in range(self._prox_list.size)])
# Automatically run the algorithm
if auto_iterate:
@@ -673,25 +667,25 @@ def _set_weights(self, weights):
self._prox_list.size,
)
elif not isinstance(weights, (list, tuple, np.ndarray)):
- raise TypeError('Weights must be provided as a list.')
+ raise TypeError("Weights must be provided as a list.")
weights = self.xp.array(weights)
if not np.issubdtype(weights.dtype, np.floating):
- raise ValueError('Weights must be list of float values.')
+ raise ValueError("Weights must be list of float values.")
if weights.size != self._prox_list.size:
raise ValueError(
- 'The number of weights must match the number of proximity '
- + 'operators.',
+ "The number of weights must match the number of proximity "
+ + "operators.",
)
expected_weight_sum = 1.0
if self.xp.sum(weights) != expected_weight_sum:
raise ValueError(
- 'Proximity operator weights must sum to 1.0. Current sum of '
- + 'weights = {0}'.format(self.xp.sum(weights)),
+ "Proximity operator weights must sum to 1.0. Current sum of "
+ + f"weights = {self.xp.sum(weights)}",
)
self._weights = weights
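The checks above pin down the `weights` contract: a float sequence with one entry per proximity operator, summing exactly to 1.0 (note the exact equality test). For example:

    import numpy as np

    weights = np.array([0.7, 0.3])   # one float per proximity operator
    assert np.issubdtype(weights.dtype, np.floating)
    assert np.sum(weights) == 1.0    # exact sum enforced by _set_weights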
@@ -726,9 +720,7 @@ def _update(self):
# Update z values.
for i in range(self._prox_list.size):
- z_temp = (
- 2 * self._x_old - self._z[i] - self._gamma * self._grad.grad
- )
+ z_temp = 2 * self._x_old - self._z[i] - self._gamma * self._grad.grad
z_prox = self._prox_list[i].op(
z_temp,
extra_factor=self._gamma / self._weights[i],
@@ -784,9 +776,9 @@ def get_notify_observers_kwargs(self):
"""
return {
- 'x_new': self._linear.adj_op(self._x_new),
- 'z_new': self._z,
- 'idx': self.idx,
+ "x_new": self._linear.adj_op(self._x_new),
+ "z_new": self._z,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -797,7 +789,7 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
@@ -871,7 +863,7 @@ def __init__(
z,
grad,
prox,
- cost='auto',
+ cost="auto",
linear=None,
beta_param=1.0,
sigma_bar=1.0,
@@ -880,7 +872,6 @@ def __init__(
metrics=None,
**kwargs,
):
-
# Set default algorithm properties
super().__init__(
metric_call_period=metric_call_period,
@@ -905,7 +896,7 @@ def __init__(
self._grad = grad
self._prox = prox
self._linear = linear
- if cost == 'auto':
+ if cost == "auto":
self._cost_func = costObj([self._grad, self._prox])
else:
self._cost_func = cost
@@ -918,7 +909,7 @@ def __init__(
for param_val in (beta_param, sigma_bar):
self._check_param(param_val)
if sigma_bar < 0 or sigma_bar > 1:
- raise ValueError('The sigma bar parameter needs to be in [0, 1]')
+ raise ValueError("The sigma bar parameter needs to be in [0, 1]")
self._beta = self.step_size or beta_param
self._sigma_bar = sigma_bar
@@ -944,18 +935,18 @@ def _update(self):
"""
# Step 4 from alg. 3
self._grad.get_grad(self._x_old)
- #self._u_new = self._x_old - self._beta * self._grad.grad
- self._u_new = -self._beta * self._grad.grad
+ # self._u_new = self._x_old - self._beta * self._grad.grad
+ self._u_new = -self._beta * self._grad.grad
self._u_new += self._x_old
# Step 5 from alg. 3
- self._t_new = 0.5 * (1 + self.xp.sqrt(1 + 4 * self._t_old ** 2))
+ self._t_new = 0.5 * (1 + self.xp.sqrt(1 + 4 * self._t_old**2))
# Step 6 from alg. 3
t_shifted_ratio = (self._t_old - 1) / self._t_new
sigma_t_ratio = self._sigma * self._t_old / self._t_new
beta_xi_t_shifted_ratio = t_shifted_ratio * self._beta / self._xi
- self._z = - beta_xi_t_shifted_ratio * (self._x_old - self._z)
+ self._z = -beta_xi_t_shifted_ratio * (self._x_old - self._z)
self._z += self._u_new
self._z += t_shifted_ratio * (self._u_new - self._u_old)
self._z += sigma_t_ratio * (self._u_new - self._x_old)
@@ -968,20 +959,18 @@ def _update(self):
# Restarting and gamma-Decreasing
# Step 9 from alg. 3
- #self._g_new = self._grad.grad - (self._x_new - self._z) / self._xi
- self._g_new = (self._z - self._x_new)
+ # self._g_new = self._grad.grad - (self._x_new - self._z) / self._xi
+ self._g_new = self._z - self._x_new
self._g_new /= self._xi
self._g_new += self._grad.grad
# Step 10 from alg 3.
- #self._y_new = self._x_old - self._beta * self._g_new
- self._y_new = - self._beta * self._g_new
+ # self._y_new = self._x_old - self._beta * self._g_new
+ self._y_new = -self._beta * self._g_new
self._y_new += self._x_old
# Step 11 from alg. 3
- restart_crit = (
- self.xp.vdot(-self._g_new, self._y_new - self._y_old) < 0
- )
+ restart_crit = self.xp.vdot(-self._g_new, self._y_new - self._y_old) < 0
if restart_crit:
self._t_new = 1
self._sigma = 1
@@ -999,9 +988,8 @@ def _update(self):
# Test cost function for convergence.
if self._cost_func:
- self.converge = (
- self.any_convergence_flag()
- or self._cost_func.get_cost(self._x_new)
+ self.converge = self.any_convergence_flag() or self._cost_func.get_cost(
+ self._x_new
)
def iterate(self, max_iter=150, progbar=None):
@@ -1037,14 +1025,14 @@ def get_notify_observers_kwargs(self):
"""
return {
- 'u_new': self._u_new,
- 'x_new': self._linear.adj_op(self._x_new),
- 'y_new': self._y_new,
- 'z_new': self._z,
- 'xi': self._xi,
- 'sigma': self._sigma,
- 't': self._t_new,
- 'idx': self.idx,
+ "u_new": self._u_new,
+ "x_new": self._linear.adj_op(self._x_new),
+ "y_new": self._y_new,
+ "z_new": self._z,
+ "xi": self._xi,
+ "sigma": self._sigma,
+ "t": self._t_new,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -1055,6 +1043,6 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
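As with `ForwardBackward`, a short usage sketch. It assumes four initial points `u, x, y, z` precede `grad` in the constructor (only `z, grad, prox, ...` are visible in the hunk above, so the first three names are an assumption):

    import numpy as np
    from modopt.opt.algorithms import POGM
    from modopt.opt.gradient import GradBasic
    from modopt.opt.linear import Identity
    from modopt.opt.proximity import SparseThreshold

    y = np.random.default_rng(1).standard_normal(16)
    grad = GradBasic(y, lambda x: x, lambda x: x)
    prox = SparseThreshold(Identity(), 1e-1)
    x0 = np.zeros(16)

    pogm = POGM(x0.copy(), x0.copy(), x0.copy(), x0.copy(),  # u, x, y, z
                grad, prox, sigma_bar=0.96, auto_iterate=False)
    pogm.iterate(max_iter=50)
    x_hat = pogm.x_final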
diff --git a/modopt/opt/algorithms/gradient_descent.py b/src/modopt/opt/algorithms/gradient_descent.py
similarity index 95%
rename from modopt/opt/algorithms/gradient_descent.py
rename to src/modopt/opt/algorithms/gradient_descent.py
index f3fe4b10..0960be5a 100644
--- a/modopt/opt/algorithms/gradient_descent.py
+++ b/src/modopt/opt/algorithms/gradient_descent.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""Gradient Descent Algorithms."""
import numpy as np
@@ -103,7 +102,7 @@ def __init__(
self._check_operator(operator)
self._grad = grad
self._prox = prox
- if cost == 'auto':
+ if cost == "auto":
self._cost_func = costObj([self._grad, self._prox])
else:
self._cost_func = cost
@@ -157,9 +156,8 @@ def _update(self):
self._eta = self._eta_update(self._eta, self.idx)
# Test cost function for convergence.
if self._cost_func:
- self.converge = (
- self.any_convergence_flag()
- or self._cost_func.get_cost(self._x_new)
+ self.converge = self.any_convergence_flag() or self._cost_func.get_cost(
+ self._x_new
)
def _update_grad_dir(self, grad):
@@ -208,10 +206,10 @@ def get_notify_observers_kwargs(self):
"""
return {
- 'x_new': self._x_new,
- 'dir_grad': self._dir_grad,
- 'speed_grad': self._speed_grad,
- 'idx': self.idx,
+ "x_new": self._x_new,
+ "dir_grad": self._dir_grad,
+ "speed_grad": self._speed_grad,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -222,7 +220,7 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
@@ -308,7 +306,7 @@ class RMSpropGradOpt(GenericGradOpt):
def __init__(self, *args, gamma=0.5, **kwargs):
super().__init__(*args, **kwargs)
if gamma < 0 or gamma > 1:
- raise ValueError('gamma is outside of range [0,1]')
+ raise ValueError("gamma is outside of range [0,1]")
self._check_param(gamma)
self._gamma = gamma
@@ -405,9 +403,9 @@ def __init__(self, *args, gamma=0.9, beta=0.9, **kwargs):
self._check_param(gamma)
self._check_param(beta)
if gamma < 0 or gamma >= 1:
- raise ValueError('gamma is outside of range [0,1]')
+ raise ValueError("gamma is outside of range [0,1]")
if beta < 0 or beta >= 1:
- raise ValueError('beta is outside of range [0,1]')
+ raise ValueError("beta is outside of range [0,1]")
self._gamma = gamma
self._beta = beta
self._beta_pow = 1
diff --git a/modopt/opt/algorithms/primal_dual.py b/src/modopt/opt/algorithms/primal_dual.py
similarity index 89%
rename from modopt/opt/algorithms/primal_dual.py
rename to src/modopt/opt/algorithms/primal_dual.py
index d5bdd431..fee49a25 100644
--- a/modopt/opt/algorithms/primal_dual.py
+++ b/src/modopt/opt/algorithms/primal_dual.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""Primal-Dual Algorithms."""
from modopt.opt.algorithms.base import SetUp
@@ -81,7 +80,7 @@ def __init__(
prox,
prox_dual,
linear=None,
- cost='auto',
+ cost="auto",
reweight=None,
rho=0.5,
sigma=1.0,
@@ -96,7 +95,6 @@ def __init__(
metrics=None,
**kwargs,
):
-
# Set default algorithm properties
super().__init__(
metric_call_period=metric_call_period,
@@ -123,12 +121,14 @@ def __init__(
self._linear = Identity()
else:
self._linear = linear
- if cost == 'auto':
- self._cost_func = costObj([
- self._grad,
- self._prox,
- self._prox_dual,
- ])
+ if cost == "auto":
+ self._cost_func = costObj(
+ [
+ self._grad,
+ self._prox,
+ self._prox_dual,
+ ]
+ )
else:
self._cost_func = cost
@@ -187,22 +187,17 @@ def _update(self):
self._grad.get_grad(self._x_old)
x_prox = self._prox.op(
- self._x_old - self._tau * self._grad.grad - self._tau
- * self._linear.adj_op(self._y_old),
+ self._x_old
+ - self._tau * self._grad.grad
+ - self._tau * self._linear.adj_op(self._y_old),
)
# Step 2 from eq.9.
- y_temp = (
- self._y_old + self._sigma
- * self._linear.op(2 * x_prox - self._x_old)
- )
+ y_temp = self._y_old + self._sigma * self._linear.op(2 * x_prox - self._x_old)
- y_prox = (
- y_temp - self._sigma
- * self._prox_dual.op(
- y_temp / self._sigma,
- extra_factor=(1.0 / self._sigma),
- )
+ y_prox = y_temp - self._sigma * self._prox_dual.op(
+ y_temp / self._sigma,
+ extra_factor=(1.0 / self._sigma),
)
# Step 3 from eq.9.
@@ -220,9 +215,8 @@ def _update(self):
# Test cost function for convergence.
if self._cost_func:
- self.converge = (
- self.any_convergence_flag()
- or self._cost_func.get_cost(self._x_new, self._y_new)
+ self.converge = self.any_convergence_flag() or self._cost_func.get_cost(
+ self._x_new, self._y_new
)
def iterate(self, max_iter=150, n_rewightings=1, progbar=None):
@@ -267,7 +261,7 @@ def get_notify_observers_kwargs(self):
The mapping between the iterated variables
"""
- return {'x_new': self._x_new, 'y_new': self._y_new, 'idx': self.idx}
+ return {"x_new": self._x_new, "y_new": self._y_new, "idx": self.idx}
def retrieve_outputs(self):
"""Retrieve outputs.
@@ -277,6 +271,6 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
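A matching sketch for the Condat primal-dual solver, assuming primal and dual initial values `x, y` precede `grad` in the signature (the hunk above starts at `prox`), with `rho`, `sigma`, `tau` left at the defaults shown:

    import numpy as np
    from modopt.opt.algorithms import Condat
    from modopt.opt.gradient import GradBasic
    from modopt.opt.linear import Identity
    from modopt.opt.proximity import Positivity, SparseThreshold

    y = np.random.default_rng(2).standard_normal(16)
    grad = GradBasic(y, lambda x: x, lambda x: x)

    condat = Condat(
        np.zeros(16),                        # primal variable
        np.zeros(16),                        # dual variable
        grad,
        Positivity(),                        # prox
        SparseThreshold(Identity(), 1e-1),   # prox_dual
        auto_iterate=False,
    )
    condat.iterate(max_iter=50)
    x_hat = condat.x_final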
diff --git a/modopt/opt/cost.py b/src/modopt/opt/cost.py
similarity index 91%
rename from modopt/opt/cost.py
rename to src/modopt/opt/cost.py
index 688a3959..37771f16 100644
--- a/modopt/opt/cost.py
+++ b/src/modopt/opt/cost.py
@@ -81,7 +81,6 @@ def __init__(
verbose=True,
plot_output=None,
):
-
self.cost = initial_cost
self._cost_list = []
self._cost_interval = cost_interval
@@ -112,20 +111,19 @@ def _check_cost(self):
# Check if enough cost values have been collected
if len(self._test_list) == self._test_range:
-
# The mean of the first half of the test list
t1 = xp.mean(
- xp.array(self._test_list[len(self._test_list) // 2:]),
+ xp.array(self._test_list[len(self._test_list) // 2 :]),
axis=0,
)
# The mean of the second half of the test list
t2 = xp.mean(
- xp.array(self._test_list[:len(self._test_list) // 2]),
+ xp.array(self._test_list[: len(self._test_list) // 2]),
axis=0,
)
# Calculate the change across the test list
if xp.around(t1, decimals=16):
- cost_diff = (xp.linalg.norm(t1 - t2) / xp.linalg.norm(t1))
+ cost_diff = xp.linalg.norm(t1 - t2) / xp.linalg.norm(t1)
else:
cost_diff = 0
@@ -133,9 +131,9 @@ def _check_cost(self):
self._test_list = []
if self._verbose:
- print(' - CONVERGENCE TEST - ')
- print(' - CHANGE IN COST:', cost_diff)
- print('')
+ print(" - CONVERGENCE TEST - ")
+ print(" - CHANGE IN COST:", cost_diff)
+ print("")
# Check for convergence
return cost_diff <= self._tolerance
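Stated on its own, the test above is: split the stored window of cost values in half, compare the two half-window means, and converge when the relative change drops below the tolerance. A standalone restatement (not library code):

    import numpy as np

    def converged(test_list, tolerance=1e-4):
        half = len(test_list) // 2
        t1 = np.mean(np.array(test_list[half:]), axis=0)  # one half-window mean
        t2 = np.mean(np.array(test_list[:half]), axis=0)  # the other half
        if np.around(t1, decimals=16):
            cost_diff = np.linalg.norm(t1 - t2) / np.linalg.norm(t1)
        else:
            cost_diff = 0
        return cost_diff <= tolerance

    converged([1.0, 0.9, 0.81, 0.8])  # False: cost still decreasing
    converged([0.5, 0.5, 0.5, 0.5])   # True: relative change is zero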
@@ -176,8 +174,7 @@ def get_cost(self, *args, **kwargs):
"""
# Check if the cost should be calculated
test_conditions = (
- self._cost_interval is None
- or self._iteration % self._cost_interval
+ self._cost_interval is None or self._iteration % self._cost_interval
)
if test_conditions:
@@ -185,15 +182,15 @@ def get_cost(self, *args, **kwargs):
else:
if self._verbose:
- print(' - ITERATION:', self._iteration)
+ print(" - ITERATION:", self._iteration)
# Calculate the current cost
- self.cost = self._calc_cost(verbose=self._verbose, *args, **kwargs)
+ self.cost = self._calc_cost(*args, verbose=self._verbose, **kwargs)
self._cost_list.append(self.cost)
if self._verbose:
- print(' - COST:', self.cost)
- print('')
+ print(" - COST:", self.cost)
+ print("")
# Test for convergence
test_result = self._check_cost()
@@ -288,13 +285,11 @@ def _check_operators(self):
"""
if not isinstance(self._operators, (list, tuple, np.ndarray)):
- message = (
- 'Input operators must be provided as a list, not {0}'
- )
+ message = "Input operators must be provided as a list, not {0}"
raise TypeError(message.format(type(self._operators)))
for op in self._operators:
- if not hasattr(op, 'cost'):
+ if not hasattr(op, "cost"):
raise ValueError('Operators must contain "cost" method.')
op.cost = check_callable(op.cost)
diff --git a/modopt/opt/gradient.py b/src/modopt/opt/gradient.py
similarity index 97%
rename from modopt/opt/gradient.py
rename to src/modopt/opt/gradient.py
index caa8fa9d..fe9b87d8 100644
--- a/modopt/opt/gradient.py
+++ b/src/modopt/opt/gradient.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""GRADIENT CLASSES.
This module contains classes for defining algorithm gradients.
@@ -14,7 +12,7 @@
from modopt.base.types import check_callable, check_float, check_npndarray
-class GradParent(object):
+class GradParent:
"""Gradient Parent Class.
This class defines the basic methods that will be inherited by specific
@@ -71,7 +69,6 @@ def __init__(
input_data_writeable=False,
verbose=True,
):
-
self.verbose = verbose
self._input_data_writeable = input_data_writeable
self._grad_data_type = data_type
@@ -100,7 +97,6 @@ def obs_data(self):
@obs_data.setter
def obs_data(self, input_data):
-
if self._grad_data_type in {float, np.floating}:
input_data = check_float(input_data)
check_npndarray(
@@ -128,7 +124,6 @@ def op(self):
@op.setter
def op(self, operator):
-
self._op = check_callable(operator)
@property
@@ -147,7 +142,6 @@ def trans_op(self):
@trans_op.setter
def trans_op(self, operator):
-
self._trans_op = check_callable(operator)
@property
@@ -157,7 +151,6 @@ def get_grad(self):
@get_grad.setter
def get_grad(self, method):
-
self._get_grad = check_callable(method)
@property
@@ -167,7 +160,6 @@ def grad(self):
@grad.setter
def grad(self, input_value):
-
if self._grad_data_type in {float, np.floating}:
input_value = check_float(input_value)
self._grad = input_value
@@ -179,7 +171,6 @@ def cost(self):
@cost.setter
def cost(self, method):
-
self._cost = check_callable(method)
def trans_op_op(self, input_data):
@@ -243,7 +234,6 @@ class GradBasic(GradParent):
"""
def __init__(self, *args, **kwargs):
-
super().__init__(*args, **kwargs)
self.get_grad = self._get_grad_method
self.cost = self._cost_method
@@ -289,7 +279,7 @@ def _cost_method(self, *args, **kwargs):
"""
cost_val = 0.5 * np.linalg.norm(self.obs_data - self.op(args[0])) ** 2
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - DATA FIDELITY (X):', cost_val)
+ if kwargs.get("verbose"):
+ print(" - DATA FIDELITY (X):", cost_val)
return cost_val
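Concretely, for `GradBasic` with identity `op`/`trans_op` (positional `(input_data, op, trans_op)` as used throughout ModOpt), the gradient is `H^T(Hx - y)` and the cost printed above is `0.5 * ||y - Hx||^2`:

    import numpy as np
    from modopt.opt.gradient import GradBasic

    y = np.array([1.0, 2.0, 3.0])
    grad = GradBasic(y, lambda x: x, lambda x: x)

    grad.get_grad(np.zeros(3))
    print(grad.grad)                             # -> [-1. -2. -3.]
    print(grad.cost(np.zeros(3), verbose=True))  # prints the data fidelity, returns 7.0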
diff --git a/modopt/opt/linear/__init__.py b/src/modopt/opt/linear/__init__.py
similarity index 100%
rename from modopt/opt/linear/__init__.py
rename to src/modopt/opt/linear/__init__.py
diff --git a/modopt/opt/linear/base.py b/src/modopt/opt/linear/base.py
similarity index 92%
rename from modopt/opt/linear/base.py
rename to src/modopt/opt/linear/base.py
index e347970d..af748a73 100644
--- a/modopt/opt/linear/base.py
+++ b/src/modopt/opt/linear/base.py
@@ -5,7 +5,8 @@
from modopt.base.types import check_callable
from modopt.base.backend import get_array_module
-class LinearParent(object):
+
+class LinearParent:
"""Linear Operator Parent Class.
This class sets the structure for defining linear operator instances.
@@ -29,7 +30,6 @@ class LinearParent(object):
"""
def __init__(self, op, adj_op):
-
self.op = op
self.adj_op = adj_op
@@ -40,7 +40,6 @@ def op(self):
@op.setter
def op(self, operator):
-
self._op = check_callable(operator)
@property
@@ -50,7 +49,6 @@ def adj_op(self):
@adj_op.setter
def adj_op(self, operator):
-
self._adj_op = check_callable(operator)
@@ -66,10 +64,9 @@ class Identity(LinearParent):
"""
def __init__(self):
-
self.op = lambda input_data: input_data
self.adj_op = self.op
- self.cost= lambda *args, **kwargs: 0
+ self.cost = lambda *args, **kwargs: 0
class MatrixOperator(LinearParent):
@@ -126,7 +123,6 @@ class LinearCombo(LinearParent):
"""
def __init__(self, operators, weights=None):
-
operators, weights = self._check_inputs(operators, weights)
self.operators = operators
self.weights = weights
@@ -159,14 +155,13 @@ def _check_type(self, input_val):
"""
if not isinstance(input_val, (list, tuple, np.ndarray)):
raise TypeError(
- 'Invalid input type, input must be a list, tuple or numpy '
- + 'array.',
+ "Invalid input type, input must be a list, tuple or numpy " + "array.",
)
input_val = np.array(input_val)
if not input_val.size:
- raise ValueError('Input list is empty.')
+ raise ValueError("Input list is empty.")
return input_val
@@ -199,11 +194,10 @@ def _check_inputs(self, operators, weights):
operators = self._check_type(operators)
for operator in operators:
-
- if not hasattr(operator, 'op'):
+ if not hasattr(operator, "op"):
raise ValueError('Operators must contain "op" method.')
- if not hasattr(operator, 'adj_op'):
+ if not hasattr(operator, "adj_op"):
raise ValueError('Operators must contain "adj_op" method.')
operator.op = check_callable(operator.op)
@@ -214,12 +208,11 @@ def _check_inputs(self, operators, weights):
if weights.size != operators.size:
raise ValueError(
- 'The number of weights must match the number of '
- + 'operators.',
+ "The number of weights must match the number of " + "operators.",
)
if not np.issubdtype(weights.dtype, np.floating):
- raise TypeError('The weights must be a list of float values.')
+ raise TypeError("The weights must be a list of float values.")
return operators, weights
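Since `LinearParent` only asks for two callables, wiring up a concrete matrix operator and checking the adjoint identity takes a few lines:

    import numpy as np
    from modopt.opt.linear.base import LinearParent

    M = np.arange(6.0).reshape(2, 3)
    lin = LinearParent(lambda x: M @ x, lambda y: M.T @ y)

    x, y = np.ones(3), np.ones(2)
    # Adjoint test: <M x, y> == <x, M^T y>
    assert np.isclose(np.vdot(lin.op(x), y), np.vdot(x, lin.adj_op(y)))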
diff --git a/modopt/opt/linear/wavelet.py b/src/modopt/opt/linear/wavelet.py
similarity index 91%
rename from modopt/opt/linear/wavelet.py
rename to src/modopt/opt/linear/wavelet.py
index 5feead66..8012a072 100644
--- a/modopt/opt/linear/wavelet.py
+++ b/src/modopt/opt/linear/wavelet.py
@@ -45,8 +45,7 @@ class WaveletConvolve(LinearParent):
"""
- def __init__(self, filters, method='scipy'):
-
+ def __init__(self, filters, method="scipy"):
self._filters = check_float(filters)
self.op = lambda input_data: filter_convolve_stack(
input_data,
@@ -61,13 +60,11 @@ def __init__(self, filters, method='scipy'):
)
-
-
class WaveletTransform(LinearParent):
"""
2D and 3D wavelet transform class.
- This is a wrapper around either Pywavelet (CPU) or Pytorch Wavelet (GPU using Pytorch).
+ This is a wrapper around either PyWavelets (CPU) or PyTorch wavelets (GPU).
Parameters
----------
@@ -81,33 +78,41 @@ class WaveletTransform(LinearParent):
mode: str, default "zero"
Boundary Condition mode
compute_backend: str, "numpy" or "cupy", default "numpy"
- Backend library to use. "cupy" also requires a working installation of PyTorch and pytorch wavelets.
+ Backend library to use. "cupy" also requires a working installation of PyTorch
+ and PyTorch wavelets (ptwt).
**kwargs: extra kwargs for Pywavelet or Pytorch Wavelet
"""
- def __init__(self,
+
+ def __init__(
+ self,
wavelet_name,
shape,
level=4,
mode="symmetric",
compute_backend="numpy",
- **kwargs):
-
+ **kwargs,
+ ):
if compute_backend == "cupy" and ptwt_available:
- self.operator = CupyWaveletTransform(wavelet=wavelet_name, shape=shape, level=level, mode=mode)
+ self.operator = CupyWaveletTransform(
+ wavelet=wavelet_name, shape=shape, level=level, mode=mode
+ )
elif compute_backend == "numpy" and pywt_available:
- self.operator = CPUWaveletTransform(wavelet_name=wavelet_name, shape=shape, level=level, **kwargs)
+ self.operator = CPUWaveletTransform(
+ wavelet_name=wavelet_name, shape=shape, level=level, **kwargs
+ )
else:
raise ValueError(f"Compute Backend {compute_backend} not available")
-
self.op = self.operator.op
self.adj_op = self.operator.adj_op
@property
def coeffs_shape(self):
+ """Get the coeffs shapes."""
return self.operator.coeffs_shape
+
class CPUWaveletTransform(LinearParent):
"""
2D and 3D wavelet transform class.
@@ -159,7 +164,8 @@ def __init__(
):
if not pywt_available:
raise ImportError(
- "PyWavelet and/or joblib are not available. Please install it to use WaveletTransform."
+ "PyWavelet and/or joblib are not available. "
+ "Please install it to use WaveletTransform."
)
if wavelet_name not in pywt.wavelist(kind="all"):
raise ValueError(
@@ -193,7 +199,9 @@ def __init__(
self.n_batch = n_batch
if self.n_batch == 1 and self.n_jobs != 1:
- warnings.warn("Making n_jobs = 1 for WaveletTransform as n_batchs = 1")
+ warnings.warn(
+ "Making n_jobs = 1 for WaveletTransform as n_batchs = 1", stacklevel=1
+ )
self.n_jobs = 1
self.backend = backend
n_proc = self.n_jobs
@@ -273,22 +281,22 @@ def _adj_op(self, coeffs):
class TorchWaveletTransform:
"""Wavelet transform using pytorch."""
- wavedec3_keys = ["aad", "ada", "add", "daa", "dad", "dda", "ddd"]
+ wavedec3_keys = ("aad", "ada", "add", "daa", "dad", "dda", "ddd")
def __init__(
self,
- shape: tuple[int, ...],
- wavelet: str,
- level: int,
- mode: str,
+ shape,
+ wavelet,
+ level,
+ mode,
):
self.wavelet = wavelet
self.level = level
self.shape = shape
self.mode = mode
- self.coeffs_shape = None # will be set after op.
+ self.coeffs_shape = None # will be set after op.
- def op(self, data: torch.Tensor) -> list[torch.Tensor]:
+ def op(self, data):
"""Apply the wavelet decomposition on.
Parameters
@@ -350,7 +358,7 @@ def op(self, data: torch.Tensor) -> list[torch.Tensor]:
)
return [coeffs_[0]] + [cc for c in coeffs_[1:] for cc in c.values()]
- def adj_op(self, coeffs: list[torch.Tensor]) -> torch.Tensor:
+ def adj_op(self, coeffs):
"""Apply the wavelet recomposition.
Parameters
@@ -409,20 +417,22 @@ class CupyWaveletTransform(LinearParent):
def __init__(
self,
- shape: tuple[int, ...],
- wavelet: str,
- level: int,
- mode: str,
+ shape,
+ wavelet,
+ level,
+ mode,
):
self.wavelet = wavelet
self.level = level
self.shape = shape
self.mode = mode
- self.operator = TorchWaveletTransform(shape=shape, wavelet=wavelet, level=level,mode=mode)
- self.coeffs_shape = None # will be set after op
+ self.operator = TorchWaveletTransform(
+ shape=shape, wavelet=wavelet, level=level, mode=mode
+ )
+ self.coeffs_shape = None # will be set after op
- def op(self, data: cp.array) -> cp.ndarray:
+ def op(self, data):
"""Define the wavelet operator.
This method returns the input data convolved with the wavelet filter.
@@ -452,7 +462,7 @@ def op(self, data: cp.array) -> cp.ndarray:
return ret
- def adj_op(self, data: cp.ndarray) -> cp.ndarray:
+ def adj_op(self, data):
"""Define the wavelet adjoint operator.
This method returns the reconstructed image.
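Usage sketch for the CPU path of the `WaveletTransform` wrapper (assumes PyWavelets is installed; `"db4"` is just one name from `pywt.wavelist`):

    import numpy as np
    from modopt.opt.linear.wavelet import WaveletTransform

    wt = WaveletTransform("db4", shape=(64, 64), level=3, compute_backend="numpy")

    image = np.random.default_rng(3).standard_normal((64, 64))
    coeffs = wt.op(image)      # decomposition; wt.coeffs_shape is set afterwards
    recon = wt.adj_op(coeffs)  # reconstruction back to shape (64, 64)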
diff --git a/modopt/opt/proximity.py b/src/modopt/opt/proximity.py
similarity index 89%
rename from modopt/opt/proximity.py
rename to src/modopt/opt/proximity.py
index fc81a753..dea862ca 100644
--- a/modopt/opt/proximity.py
+++ b/src/modopt/opt/proximity.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""PROXIMITY OPERATORS.
This module contains classes of proximity operators for optimisation.
@@ -32,7 +30,7 @@
from modopt.signal.svd import svd_thresh, svd_thresh_coef, svd_thresh_coef_fast
-class ProximityParent(object):
+class ProximityParent:
"""Proximity Operator Parent Class.
This class sets the structure for defining proximity operator instances.
@@ -48,7 +46,6 @@ class ProximityParent(object):
"""
def __init__(self, op, cost):
-
self.op = op
self.cost = cost
@@ -59,7 +56,6 @@ def op(self):
@op.setter
def op(self, operator):
-
self._op = check_callable(operator)
@property
@@ -79,7 +75,6 @@ def cost(self):
@cost.setter
def cost(self, method):
-
self._cost = check_callable(method)
@@ -99,9 +94,8 @@ class IdentityProx(ProximityParent):
"""
def __init__(self):
-
- self.op = lambda x_val: x_val
- self.cost = lambda x_val: 0
+ self.op = lambda x_val, *args, **kwargs: x_val
+ self.cost = lambda x_val, *args, **kwargs: 0
class Positivity(ProximityParent):
@@ -117,10 +111,25 @@ class Positivity(ProximityParent):
"""
def __init__(self):
-
- self.op = lambda input_data: positive(input_data)
self.cost = self._cost_method
+ def op(self, input_data, *args, **kwargs):
+ """
+ Make the data positive.
+
+ Parameters
+ ----------
+ input_data: np.ndarray
+ Input array
+ *args, **kwargs: no effect.
+
+ Returns
+ -------
+ np.ndarray
+ Positive data.
+ """
+ return positive(input_data)
+
def _cost_method(self, *args, **kwargs):
"""Calculate positivity component of the cost.
@@ -140,8 +149,8 @@ def _cost_method(self, *args, **kwargs):
``0.0``
"""
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - Min (X):', np.min(args[0]))
+ if kwargs.get("verbose"):
+ print(" - Min (X):", np.min(args[0]))
return 0
@@ -167,8 +176,7 @@ class SparseThreshold(ProximityParent):
"""
- def __init__(self, linear, weights, thresh_type='soft'):
-
+ def __init__(self, linear, weights, thresh_type="soft"):
self._linear = linear
self.weights = weights
self._thresh_type = thresh_type
@@ -221,8 +229,8 @@ def _cost_method(self, *args, **kwargs):
if isinstance(cost_val, xp.ndarray):
cost_val = cost_val.item()
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - L1 NORM (X):', cost_val)
+ if kwargs.get("verbose"):
+ print(" - L1 NORM (X):", cost_val)
return cost_val
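On its own, `SparseThreshold` with an identity linear operator is plain soft-thresholding; the `extra_factor` argument above is how algorithms fold in their step size:

    import numpy as np
    from modopt.opt.linear import Identity
    from modopt.opt.proximity import SparseThreshold

    prox = SparseThreshold(Identity(), 0.5, thresh_type="soft")

    x = np.array([-2.0, -0.25, 0.25, 2.0])
    print(prox.op(x))                  # soft threshold at 0.5 -> [-1.5 -0. 0. 1.5]
    print(prox.cost(x, verbose=True))  # L1 cost: sum |0.5 * x| = 2.25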
@@ -273,12 +281,11 @@ class LowRankMatrix(ProximityParent):
def __init__(
self,
threshold,
- thresh_type='soft',
- lowr_type='standard',
+ thresh_type="soft",
+ lowr_type="standard",
initial_rank=None,
operator=None,
):
-
self.thresh = threshold
self.thresh_type = thresh_type
self.lowr_type = lowr_type
@@ -315,13 +322,13 @@ def _op_method(self, input_data, extra_factor=1.0, rank=None):
"""
# Update threshold with extra factor.
threshold = self.thresh * extra_factor
- if self.lowr_type == 'standard' and self.rank is None and rank is None:
+ if self.lowr_type == "standard" and self.rank is None and rank is None:
data_matrix = svd_thresh(
cube2matrix(input_data),
threshold,
thresh_type=self.thresh_type,
)
- elif self.lowr_type == 'standard':
+ elif self.lowr_type == "standard":
data_matrix, update_rank = svd_thresh_coef_fast(
cube2matrix(input_data),
threshold,
@@ -331,7 +338,7 @@ def _op_method(self, input_data, extra_factor=1.0, rank=None):
)
self.rank = update_rank # save for future use
- elif self.lowr_type == 'ngole':
+ elif self.lowr_type == "ngole":
data_matrix = svd_thresh_coef(
cube2matrix(input_data),
self.operator,
@@ -339,7 +346,7 @@ def _op_method(self, input_data, extra_factor=1.0, rank=None):
thresh_type=self.thresh_type,
)
else:
- raise ValueError('lowr_type should be standard or ngole')
+ raise ValueError("lowr_type should be standard or ngole")
# Return updated data.
return matrix2cube(data_matrix, input_data.shape[1:])
@@ -365,8 +372,8 @@ def _cost_method(self, *args, **kwargs):
"""
cost_val = self.thresh * nuclear_norm(cube2matrix(args[0]))
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - NUCLEAR NORM (X):', cost_val)
+ if kwargs.get("verbose"):
+ print(" - NUCLEAR NORM (X):", cost_val)
return cost_val
@@ -470,7 +477,6 @@ class ProximityCombo(ProximityParent):
"""
def __init__(self, operators):
-
operators = self._check_operators(operators)
self.operators = operators
self.op = self._op_method
@@ -506,19 +512,19 @@ def _check_operators(self, operators):
"""
if not isinstance(operators, (list, tuple, np.ndarray)):
raise TypeError(
- 'Invalid input type, operators must be a list, tuple or '
- + 'numpy array.',
+ "Invalid input type, operators must be a list, tuple or "
+ + "numpy array.",
)
operators = np.array(operators)
if not operators.size:
- raise ValueError('Operator list is empty.')
+ raise ValueError("Operator list is empty.")
for operator in operators:
- if not hasattr(operator, 'op'):
+ if not hasattr(operator, "op"):
raise ValueError('Operators must contain "op" method.')
- if not hasattr(operator, 'cost'):
+ if not hasattr(operator, "cost"):
raise ValueError('Operators must contain "cost" method.')
operator.op = check_callable(operator.op)
operator.cost = check_callable(operator.cost)
@@ -573,10 +579,12 @@ def _cost_method(self, *args, **kwargs):
Combined cost components
"""
- return np.sum([
- operator.cost(input_data)
- for operator, input_data in zip(self.operators, args[0])
- ])
+ return np.sum(
+ [
+ operator.cost(input_data)
+ for operator, input_data in zip(self.operators, args[0])
+ ]
+ )
class OrderedWeightedL1Norm(ProximityParent):
@@ -617,16 +625,16 @@ class OrderedWeightedL1Norm(ProximityParent):
def __init__(self, weights):
if not import_sklearn: # pragma: no cover
raise ImportError(
- 'Required version of Scikit-Learn package not found see '
- + 'documentation for details: '
- + 'https://cea-cosmic.github.io/ModOpt/#optional-packages',
+ "Required version of Scikit-Learn package not found see "
+ + "documentation for details: "
+ + "https://cea-cosmic.github.io/ModOpt/#optional-packages",
)
if np.max(np.diff(weights)) > 0:
- raise ValueError('Weights must be non increasing')
+ raise ValueError("Weights must be non increasing")
self.weights = weights.flatten()
if (self.weights < 0).any():
raise ValueError(
- 'The weight values must be provided in descending order',
+ "The weight values must be provided in descending order",
)
self.op = self._op_method
self.cost = self._cost_method
@@ -664,7 +672,9 @@ def _op_method(self, input_data, extra_factor=1.0):
# Projection onto the monotone non-negative cone using
# isotonic_regression
data_abs = isotonic_regression(
- data_abs - threshold, y_min=0, increasing=False,
+ data_abs - threshold,
+ y_min=0,
+ increasing=False,
)
# Unsorting the data
@@ -672,7 +682,7 @@ def _op_method(self, input_data, extra_factor=1.0):
data_abs_unsorted[data_abs_sort_idx] = data_abs
# Putting the sign back
- with np.errstate(invalid='ignore'):
+ with np.errstate(invalid="ignore"):
sign_data = data_squeezed / np.abs(data_squeezed)
# Removing NAN caused by the sign
@@ -702,8 +712,8 @@ def _cost_method(self, *args, **kwargs):
self.weights * np.sort(np.squeeze(np.abs(args[0]))[::-1]),
)
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - OWL NORM (X):', cost_val)
+ if kwargs.get("verbose"):
+ print(" - OWL NORM (X):", cost_val)
return cost_val
@@ -734,8 +744,7 @@ class Ridge(ProximityParent):
"""
- def __init__(self, linear, weights, thresh_type='soft'):
-
+ def __init__(self, linear, weights, thresh_type="soft"):
self._linear = linear
self.weights = weights
self.op = self._op_method
@@ -786,8 +795,8 @@ def _cost_method(self, *args, **kwargs):
np.abs(self.weights * self._linear.op(args[0]) ** 2),
)
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - L2 NORM (X):', cost_val)
+ if kwargs.get("verbose"):
+ print(" - L2 NORM (X):", cost_val)
return cost_val
@@ -822,7 +831,6 @@ class ElasticNet(ProximityParent):
"""
def __init__(self, linear, alpha, beta):
-
self._linear = linear
self.alpha = alpha
self.beta = beta
@@ -848,8 +856,8 @@ def _op_method(self, input_data, extra_factor=1.0):
"""
soft_threshold = self.beta * extra_factor
- normalization = (self.alpha * 2 * extra_factor + 1)
- return thresh(input_data, soft_threshold, 'soft') / normalization
+ normalization = self.alpha * 2 * extra_factor + 1
+ return thresh(input_data, soft_threshold, "soft") / normalization
def _cost_method(self, *args, **kwargs):
"""Calculate Ridge component of the cost.
@@ -875,8 +883,8 @@ def _cost_method(self, *args, **kwargs):
+ np.abs(self.beta * self._linear.op(args[0])),
)
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - ELASTIC NET (X):', cost_val)
+ if kwargs.get("verbose"):
+ print(" - ELASTIC NET (X):", cost_val)
return cost_val
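A worked instance of the prox above: with `alpha = 0.25`, `beta = 0.5` and `extra_factor = 1`, the soft threshold is 0.5 and the normalisation is `2 * 0.25 + 1 = 1.5`, so an input of 2.0 maps to `(2.0 - 0.5) / 1.5 = 1.0`:

    import numpy as np
    from modopt.opt.linear import Identity
    from modopt.opt.proximity import ElasticNet

    prox = ElasticNet(Identity(), alpha=0.25, beta=0.5)
    print(prox.op(np.array([2.0, -2.0, 0.2])))  # -> [ 1. -1.  0.]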
@@ -942,7 +950,7 @@ def k_value(self):
def k_value(self, k_val):
if k_val < 1:
raise ValueError(
- 'The k parameter should be greater or equal than 1',
+ "The k parameter should be greater or equal than 1",
)
self._k_value = k_val
@@ -987,7 +995,7 @@ def _compute_theta(self, input_data, alpha, extra_factor=1.0):
alpha_beta = alpha_input - self.beta * extra_factor
theta = alpha_beta * ((alpha_beta <= 1) & (alpha_beta >= 0))
theta = np.nan_to_num(theta)
- theta += (alpha_input > (self.beta * extra_factor + 1))
+ theta += alpha_input > (self.beta * extra_factor + 1)
return theta
def _interpolate(self, alpha0, alpha1, sum0, sum1):
@@ -1078,12 +1086,10 @@ def _binary_search(self, input_data, alpha, extra_factor=1.0):
midpoint = 0
while (first_idx <= last_idx) and not found and (cnt < alpha.shape[0]):
-
midpoint = (first_idx + last_idx) // 2
cnt += 1
if prev_midpoint == midpoint:
-
# Particular case
sum0 = self._compute_theta(
data_abs,
@@ -1096,11 +1102,11 @@ def _binary_search(self, input_data, alpha, extra_factor=1.0):
extra_factor,
).sum()
- if (np.abs(sum0 - self._k_value) <= tolerance):
+ if np.abs(sum0 - self._k_value) <= tolerance:
found = True
midpoint = first_idx
- if (np.abs(sum1 - self._k_value) <= tolerance):
+ if np.abs(sum1 - self._k_value) <= tolerance:
found = True
midpoint = last_idx - 1
# -1 because output is index such that
@@ -1145,13 +1151,17 @@ def _binary_search(self, input_data, alpha, extra_factor=1.0):
if found:
return (
- midpoint, alpha[midpoint], alpha[midpoint + 1], sum0, sum1,
+ midpoint,
+ alpha[midpoint],
+ alpha[midpoint + 1],
+ sum0,
+ sum1,
)
raise ValueError(
- 'Cannot find the coordinate of alpha (i) such '
- + 'that sum(theta(alpha[i])) =< k and '
- + 'sum(theta(alpha[i+1])) >= k ',
+ "Cannot find the coordinate of alpha (i) such "
+ + "that sum(theta(alpha[i])) =< k and "
+ + "sum(theta(alpha[i+1])) >= k ",
)
def _find_alpha(self, input_data, extra_factor=1.0):
@@ -1175,15 +1185,13 @@ def _find_alpha(self, input_data, extra_factor=1.0):
data_size = input_data.shape[0]
# Computes the alpha^i points line 1 in Algorithm 1.
- alpha = np.zeros((data_size * 2))
+ alpha = np.zeros(data_size * 2)
data_abs = np.abs(input_data)
- alpha[:data_size] = (
- (self.beta * extra_factor)
- / (data_abs + sys.float_info.epsilon)
+ alpha[:data_size] = (self.beta * extra_factor) / (
+ data_abs + sys.float_info.epsilon
)
- alpha[data_size:] = (
- (self.beta * extra_factor + 1)
- / (data_abs + sys.float_info.epsilon)
+ alpha[data_size:] = (self.beta * extra_factor + 1) / (
+ data_abs + sys.float_info.epsilon
)
alpha = np.sort(np.unique(alpha))
@@ -1220,8 +1228,8 @@ def _op_method(self, input_data, extra_factor=1.0):
k_max = np.prod(data_shape)
if self._k_value > k_max:
warn(
- 'K value of the K-support norm is greater than the input '
- + 'dimension, its value will be set to {0}'.format(k_max),
+ "K value of the K-support norm is greater than the input "
+ + f"dimension, its value will be set to {k_max}",
)
self._k_value = k_max
@@ -1233,8 +1241,7 @@ def _op_method(self, input_data, extra_factor=1.0):
# Computes line 5. in Algorithm 1.
rslt = np.nan_to_num(
- (input_data.flatten() * theta)
- / (theta + self.beta * extra_factor),
+ (input_data.flatten() * theta) / (theta + self.beta * extra_factor),
)
return rslt.reshape(data_shape)
@@ -1275,25 +1282,20 @@ def _find_q(self, sorted_data):
found = True
q_val = 0
- elif (
- (sorted_data[self._k_value - 1:].sum())
- <= sorted_data[self._k_value - 1]
- ):
+ elif (sorted_data[self._k_value - 1 :].sum()) <= sorted_data[self._k_value - 1]:
found = True
q_val = self._k_value - 1
while (
- not found and not cnt == self._k_value
+ not found
+ and not cnt == self._k_value
and (first_idx <= last_idx < self._k_value)
):
-
q_val = (first_idx + last_idx) // 2
cnt += 1
l1_part = sorted_data[q_val:].sum() / (self._k_value - q_val)
- if (
- sorted_data[q_val + 1] <= l1_part <= sorted_data[q_val]
- ):
+ if sorted_data[q_val + 1] <= l1_part <= sorted_data[q_val]:
found = True
else:
@@ -1328,15 +1330,12 @@ def _cost_method(self, *args, **kwargs):
data_abs = data_abs[ix] # Sorted absolute value of the data
q_val = self._find_q(data_abs)
cost_val = (
- (
- np.sum(data_abs[:q_val] ** 2) * 0.5
- + np.sum(data_abs[q_val:]) ** 2
- / (self._k_value - q_val)
- ) * self.beta
- )
+ np.sum(data_abs[:q_val] ** 2) * 0.5
+ + np.sum(data_abs[q_val:]) ** 2 / (self._k_value - q_val)
+ ) * self.beta
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - K-SUPPORT NORM (X):', cost_val)
+ if kwargs.get("verbose"):
+ print(" - K-SUPPORT NORM (X):", cost_val)
return cost_val
@@ -1381,7 +1380,7 @@ def __init__(self, weights):
self.op = self._op_method
self.cost = self._cost_method
- def _op_method(self, input_data, extra_factor=1.0):
+ def _op_method(self, input_data, *args, extra_factor=1.0, **kwargs):
"""Operator.
This method returns the input data thresholded by the weights.
@@ -1392,6 +1391,7 @@ def _op_method(self, input_data, extra_factor=1.0):
Input data array
extra_factor : float
Additional multiplication factor (default is ``1.0``)
+ *args, **kwargs: no effects
Returns
-------
@@ -1407,7 +1407,7 @@ def _op_method(self, input_data, extra_factor=1.0):
(1.0 - self.weights * extra_factor / denominator),
)
- def _cost_method(self, input_data):
+ def _cost_method(self, input_data, *args, **kwargs):
"""Calculate the group LASSO component of the cost.
This method calculate the cost function of the proximable part.
@@ -1417,6 +1417,8 @@ def _cost_method(self, input_data):
input_data : numpy.ndarray
Input array of the sparse code
+ *args, **kwargs: no effects.
+
Returns
-------
float
diff --git a/modopt/opt/reweight.py b/src/modopt/opt/reweight.py
similarity index 91%
rename from modopt/opt/reweight.py
rename to src/modopt/opt/reweight.py
index 8c4f2449..4a9bf44b 100644
--- a/modopt/opt/reweight.py
+++ b/src/modopt/opt/reweight.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""REWEIGHTING CLASSES.
This module contains classes for reweighting optimisation implementations.
@@ -13,7 +11,7 @@
from modopt.base.types import check_float
-class cwbReweight(object):
+class cwbReweight:
"""Candes, Wakin and Boyd reweighting class.
This class implements the reweighting scheme described in
@@ -45,7 +43,6 @@ class cwbReweight(object):
"""
def __init__(self, weights, thresh_factor=1.0, verbose=False):
-
self.weights = check_float(weights)
self.original_weights = np.copy(self.weights)
self.thresh_factor = check_float(thresh_factor)
@@ -81,7 +78,7 @@ def reweight(self, input_data):
"""
if self.verbose:
- print(' - Reweighting: {0}'.format(self._rw_num))
+ print(f" - Reweighting: {self._rw_num}")
self._rw_num += 1
@@ -89,7 +86,7 @@ def reweight(self, input_data):
if input_data.shape != self.weights.shape:
raise ValueError(
- 'Input data must have the same shape as the initial weights.',
+ "Input data must have the same shape as the initial weights.",
)
thresh_weights = self.thresh_factor * self.original_weights
diff --git a/modopt/plot/__init__.py b/src/modopt/plot/__init__.py
similarity index 73%
rename from modopt/plot/__init__.py
rename to src/modopt/plot/__init__.py
index 28d60be6..f31ed596 100644
--- a/modopt/plot/__init__.py
+++ b/src/modopt/plot/__init__.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""PLOTTING ROUTINES.
This module contains submodules for plotting applications.
@@ -8,4 +6,4 @@
"""
-__all__ = ['cost_plot']
+__all__ = ["cost_plot"]
diff --git a/modopt/plot/cost_plot.py b/src/modopt/plot/cost_plot.py
similarity index 66%
rename from modopt/plot/cost_plot.py
rename to src/modopt/plot/cost_plot.py
index aa855eaa..7fb7e39b 100644
--- a/modopt/plot/cost_plot.py
+++ b/src/modopt/plot/cost_plot.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""PLOTTING ROUTINES.
This module contains methods for making plots.
@@ -37,20 +35,20 @@ def plotCost(cost_list, output=None):
"""
if import_fail:
- raise ImportError('Matplotlib package not found')
+ raise ImportError("Matplotlib package not found")
else:
if isinstance(output, type(None)):
- file_name = 'cost_function.png'
+ file_name = "cost_function.png"
else:
- file_name = '{0}_cost_function.png'.format(output)
+ file_name = f"{output}_cost_function.png"
plt.figure()
- plt.plot(np.log10(cost_list), 'r-')
- plt.title('Cost Function')
- plt.xlabel('Iteration')
- plt.ylabel(r'$\log_{10}$ Cost')
+ plt.plot(np.log10(cost_list), "r-")
+ plt.title("Cost Function")
+ plt.xlabel("Iteration")
+ plt.ylabel(r"$\log_{10}$ Cost")
plt.savefig(file_name)
plt.close()
- print(' - Saving cost function data to:', file_name)
+ print(" - Saving cost function data to:", file_name)
diff --git a/modopt/signal/__init__.py b/src/modopt/signal/__init__.py
similarity index 58%
rename from modopt/signal/__init__.py
rename to src/modopt/signal/__init__.py
index dbc6d053..6bf0912b 100644
--- a/modopt/signal/__init__.py
+++ b/src/modopt/signal/__init__.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""SIGNAL PROCESSING ROUTINES.
This module contains submodules for signal processing.
@@ -8,4 +6,4 @@
"""
-__all__ = ['filter', 'noise', 'positivity', 'svd', 'validation', 'wavelet']
+__all__ = ["filter", "noise", "positivity", "svd", "validation", "wavelet"]
diff --git a/modopt/signal/filter.py b/src/modopt/signal/filter.py
similarity index 96%
rename from modopt/signal/filter.py
rename to src/modopt/signal/filter.py
index 84dd8160..33c3c105 100644
--- a/modopt/signal/filter.py
+++ b/src/modopt/signal/filter.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""FILTER ROUTINES.
This module contains methods for filtering data.
@@ -81,7 +79,7 @@ def mex_hat(data_point, sigma):
sigma = check_float(sigma)
xs = (data_point / sigma) ** 2
- factor = 2 * (3 * sigma) ** -0.5 * np.pi ** -0.25
+ factor = 2 * (3 * sigma) ** -0.5 * np.pi**-0.25
return factor * (1 - xs) * np.exp(-0.5 * xs)
diff --git a/modopt/signal/noise.py b/src/modopt/signal/noise.py
similarity index 86%
rename from modopt/signal/noise.py
rename to src/modopt/signal/noise.py
index a59d5553..28307f52 100644
--- a/modopt/signal/noise.py
+++ b/src/modopt/signal/noise.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""NOISE ROUTINES.
This module contains methods for adding and removing noise from data.
@@ -8,14 +6,12 @@
"""
-from builtins import zip
-
import numpy as np
from modopt.base.backend import get_array_module
-def add_noise(input_data, sigma=1.0, noise_type='gauss'):
+def add_noise(input_data, sigma=1.0, noise_type="gauss", rng=None):
"""Add noise to data.
This method adds Gaussian or Poisson noise to the input data.
@@ -29,6 +25,9 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'):
default is ``1.0``)
noise_type : {'gauss', 'poisson'}
Type of noise to be added (default is ``'gauss'``)
+ rng: np.random.Generator or int
+ A random number generator or a seed to initialize one.
+
Returns
-------
@@ -68,9 +67,12 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'):
array([ 3.24869073, -1.22351283, -1.0563435 , -2.14593724, 1.73081526])
"""
+ if not isinstance(rng, np.random.Generator):
+ rng = np.random.default_rng(rng)
+
input_data = np.array(input_data)
- if noise_type not in {'gauss', 'poisson'}:
+ if noise_type not in {"gauss", "poisson"}:
raise ValueError(
'Invalid noise type. Options are "gauss" or "poisson"',
)
@@ -78,15 +80,14 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'):
if isinstance(sigma, (list, tuple, np.ndarray)):
if len(sigma) != input_data.shape[0]:
raise ValueError(
- 'Number of sigma values must match first dimension of input '
- + 'data',
+ "Number of sigma values must match first dimension of input " + "data",
)
- if noise_type == 'gauss':
- random = np.random.randn(*input_data.shape)
+ if noise_type == "gauss":
+ random = rng.standard_normal(input_data.shape)
- elif noise_type == 'poisson':
- random = np.random.poisson(np.abs(input_data))
+ elif noise_type == "poisson":
+ random = rng.poisson(np.abs(input_data))
if isinstance(sigma, (int, float)):
return input_data + sigma * random
@@ -96,7 +97,7 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'):
return input_data + noise
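With the new `rng` parameter, noise is reproducible; an integer seed and a pre-built generator give identical output:

    import numpy as np
    from modopt.signal.noise import add_noise

    data = np.zeros(5)
    a = add_noise(data, sigma=2.0, rng=42)
    b = add_noise(data, sigma=2.0, rng=np.random.default_rng(42))
    assert np.array_equal(a, b)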
-def thresh(input_data, threshold, threshold_type='hard'):
+def thresh(input_data, threshold, threshold_type="hard"):
r"""Threshold data.
This method performs hard or soft thresholding on the input data.
@@ -169,12 +170,12 @@ def thresh(input_data, threshold, threshold_type='hard'):
input_data = xp.array(input_data)
- if threshold_type not in {'hard', 'soft'}:
+ if threshold_type not in {"hard", "soft"}:
raise ValueError(
'Invalid threshold type. Options are "hard" or "soft"',
)
- if threshold_type == 'soft':
+ if threshold_type == "soft":
denominator = xp.maximum(xp.finfo(np.float64).eps, xp.abs(input_data))
max_value = xp.maximum((1.0 - threshold / denominator), 0)
diff --git a/modopt/signal/positivity.py b/src/modopt/signal/positivity.py
similarity index 90%
rename from modopt/signal/positivity.py
rename to src/modopt/signal/positivity.py
index c19ba62c..8d7aa46c 100644
--- a/modopt/signal/positivity.py
+++ b/src/modopt/signal/positivity.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""POSITIVITY.
This module contains a function that retains only positive coefficients in
@@ -47,7 +45,7 @@ def pos_recursive(input_data):
Positive coefficients
"""
- if input_data.dtype == 'O':
+ if input_data.dtype == "O":
res = np.array([pos_recursive(elem) for elem in input_data], dtype="object")
else:
@@ -97,15 +95,15 @@ def positive(input_data, ragged=False):
"""
if not isinstance(input_data, (int, float, list, tuple, np.ndarray)):
raise TypeError(
- 'Invalid data type, input must be `int`, `float`, `list`, '
- + '`tuple` or `np.ndarray`.',
+ "Invalid data type, input must be `int`, `float`, `list`, "
+ + "`tuple` or `np.ndarray`.",
)
if isinstance(input_data, (int, float)):
return pos_thresh(input_data)
if ragged:
- input_data = np.array(input_data, dtype='object')
+ input_data = np.array(input_data, dtype="object")
else:
input_data = np.array(input_data)
diff --git a/modopt/signal/svd.py b/src/modopt/signal/svd.py
similarity index 85%
rename from modopt/signal/svd.py
rename to src/modopt/signal/svd.py
index f3d40a51..cf147503 100644
--- a/modopt/signal/svd.py
+++ b/src/modopt/signal/svd.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""SVD ROUTINES.
This module contains methods for thresholding singular values.
@@ -52,8 +50,8 @@ def find_n_pc(u_vec, factor=0.5):
"""
if np.sqrt(u_vec.shape[0]) % 1:
raise ValueError(
- 'Invalid left singular vector. The size of the first '
- + 'dimenion of ``u_vec`` must be perfect square.',
+ "Invalid left singular vector. The size of the first "
+ + "dimenion of ``u_vec`` must be perfect square.",
)
# Get the shape of the array
@@ -69,13 +67,12 @@ def find_n_pc(u_vec, factor=0.5):
]
# Return the required number of principal components.
- return np.sum([
- (
- u_val[tuple(zip(array_shape // 2))] ** 2 <= factor
- * np.sum(u_val ** 2),
- )
- for u_val in u_auto
- ])
+ return np.sum(
+ [
+ (u_val[tuple(zip(array_shape // 2))] ** 2 <= factor * np.sum(u_val**2),)
+ for u_val in u_auto
+ ]
+ )
def calculate_svd(input_data):
@@ -101,17 +98,17 @@ def calculate_svd(input_data):
"""
if (not isinstance(input_data, np.ndarray)) or (input_data.ndim != 2):
- raise TypeError('Input data must be a 2D np.ndarray.')
+ raise TypeError("Input data must be a 2D np.ndarray.")
return svd(
input_data,
check_finite=False,
- lapack_driver='gesvd',
+ lapack_driver="gesvd",
full_matrices=False,
)
-def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
+def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type="hard"):
"""Threshold the singular values.
This method thresholds the input data using singular value decomposition.
@@ -156,16 +153,11 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
"""
less_than_zero = isinstance(n_pc, int) and n_pc <= 0
- str_not_all = isinstance(n_pc, str) and n_pc != 'all'
+ str_not_all = isinstance(n_pc, str) and n_pc != "all"
- if (
- (not isinstance(n_pc, (int, str, type(None))))
- or less_than_zero
- or str_not_all
- ):
+ if (not isinstance(n_pc, (int, str, type(None)))) or less_than_zero or str_not_all:
raise ValueError(
- 'Invalid value for "n_pc", specify a positive integer value or '
- + '"all"',
+ 'Invalid value for "n_pc", specify a positive integer value or "all"',
)
# Get SVD of input data.
@@ -176,15 +168,14 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
# Find the required number of principal components if not specified.
if isinstance(n_pc, type(None)):
n_pc = find_n_pc(u_vec, factor=0.1)
- print('xxxx', n_pc, u_vec)
# If the number of PCs is too large use all of the singular values.
- if (
- (isinstance(n_pc, int) and n_pc >= s_values.size)
- or (isinstance(n_pc, str) and n_pc == 'all')
+ if (isinstance(n_pc, int) and n_pc >= s_values.size) or (
+ isinstance(n_pc, str) and n_pc == "all"
):
n_pc = s_values.size
- warn('Using all singular values.')
+ warn("Using all singular values.")
threshold = s_values[n_pc - 1]
@@ -192,7 +183,7 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
s_new = thresh(s_values, threshold, thresh_type)
if np.all(s_new == s_values):
- warn('No change to singular values.')
+ warn("No change to singular values.")
# Diagonalize the svd
s_new = np.diag(s_new)
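The steps above amount to a low-rank reconstruction: singular values below the ``n_pc``-th are zeroed and the matrix is rebuilt. An illustrative sketch, with the threshold rule simplified to "keep the leading values":

```python
import numpy as np
from scipy.linalg import svd

a = np.arange(12.0).reshape(3, 4)  # rank-2 matrix
u, s, v = svd(a, full_matrices=False, lapack_driver="gesvd")

n_pc = 2
s_new = np.where(s >= s[n_pc - 1], s, 0)  # keep the n_pc leading singular values
a_low_rank = u @ np.diag(s_new) @ v

assert np.allclose(a_low_rank, a)  # a is rank 2, so nothing is lost here
```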
@@ -206,7 +197,7 @@ def svd_thresh_coef_fast(
threshold,
n_vals=-1,
extra_vals=5,
- thresh_type='hard',
+ thresh_type="hard",
):
"""Threshold the singular values coefficients.
@@ -241,7 +232,7 @@ def svd_thresh_coef_fast(
ok = False
while not ok:
(u_vec, s_values, v_vec) = svds(input_data, k=n_vals)
- ok = (s_values[0] <= threshold or n_vals == min(input_data.shape) - 1)
+ ok = s_values[0] <= threshold or n_vals == min(input_data.shape) - 1
n_vals = min(n_vals + extra_vals, *input_data.shape)
s_values = thresh(
@@ -259,7 +250,7 @@ def svd_thresh_coef_fast(
)
-def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'):
+def svd_thresh_coef(input_data, operator, threshold, thresh_type="hard"):
"""Threshold the singular values coefficients.
This method thresholds the input data using singular value decomposition.
@@ -287,7 +278,7 @@ def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'):
"""
if not callable(operator):
- raise TypeError('Operator must be a callable function.')
+ raise TypeError("Operator must be a callable function.")
# Get SVD of data matrix
u_vec, s_values, v_vec = calculate_svd(input_data)
@@ -302,10 +293,9 @@ def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'):
array_shape = np.repeat(int(np.sqrt(u_vec.shape[0])), 2)
# Compute threshold matrix.
- ti = np.array([
- np.linalg.norm(elem)
- for elem in operator(matrix2cube(u_vec, array_shape))
- ])
+ ti = np.array(
+ [np.linalg.norm(elem) for elem in operator(matrix2cube(u_vec, array_shape))]
+ )
threshold *= np.repeat(ti, a_matrix.shape[1]).reshape(a_matrix.shape)
# Threshold coefficients.
diff --git a/modopt/signal/validation.py b/src/modopt/signal/validation.py
similarity index 80%
rename from modopt/signal/validation.py
rename to src/modopt/signal/validation.py
index 422a987b..cdf69b7d 100644
--- a/modopt/signal/validation.py
+++ b/src/modopt/signal/validation.py
@@ -16,6 +16,7 @@ def transpose_test(
x_args=None,
y_shape=None,
y_args=None,
+ rng=None,
):
"""Transpose test.
@@ -36,6 +37,8 @@ def transpose_test(
Shape of transpose operator input data (default is ``None``)
y_args : tuple, optional
Arguments to be passed to transpose operator (default is ``None``)
+ rng : numpy.random.Generator or int, optional
+ Initialized random number generator or seed (default is ``None``)
Raises
------
@@ -54,7 +57,7 @@ def transpose_test(
"""
if not callable(operator) or not callable(operator_t):
- raise TypeError('The input operators must be callable functions.')
+ raise TypeError("The input operators must be callable functions.")
if isinstance(y_shape, type(None)):
y_shape = x_shape
@@ -62,9 +65,11 @@ def transpose_test(
if isinstance(y_args, type(None)):
y_args = x_args
+ if not isinstance(rng, np.random.Generator):
+ rng = np.random.default_rng(rng)
# Generate random arrays.
- x_val = np.random.ranf(x_shape)
- y_val = np.random.ranf(y_shape)
+ x_val = rng.random(x_shape)
+ y_val = rng.random(y_shape)
# Calculate
mx_y = np.sum(np.multiply(operator(x_val, x_args), y_val))
@@ -73,4 +78,4 @@ def transpose_test(
x_mty = np.sum(np.multiply(x_val, operator_t(y_val, y_args)))
# Test the difference between the two.
- print(' - | - | =', np.abs(mx_y - x_mty))
+ print(" - | - | =", np.abs(mx_y - x_mty))
diff --git a/modopt/signal/wavelet.py b/src/modopt/signal/wavelet.py
similarity index 88%
rename from modopt/signal/wavelet.py
rename to src/modopt/signal/wavelet.py
index bc4ffc70..b55b78d9 100644
--- a/modopt/signal/wavelet.py
+++ b/src/modopt/signal/wavelet.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""WAVELET MODULE.
This module contains methods for performing wavelet transformations using
@@ -58,20 +56,20 @@ def execute(command_line):
"""
if not isinstance(command_line, str):
- raise TypeError('Command line must be a string.')
+ raise TypeError("Command line must be a string.")
command = command_line.split()
process = sp.Popen(command, stdout=sp.PIPE, stderr=sp.PIPE)
stdout, stderr = process.communicate()
- return stdout.decode('utf-8'), stderr.decode('utf-8')
+ return stdout.decode("utf-8"), stderr.decode("utf-8")
def call_mr_transform(
input_data,
- opt='',
- path='./',
+ opt="",
+ path="./",
remove_files=True,
): # pragma: no cover
"""Call ``mr_transform``.
@@ -127,26 +125,23 @@ def call_mr_transform(
"""
if not import_astropy:
- raise ImportError('Astropy package not found.')
+ raise ImportError("Astropy package not found.")
if (not isinstance(input_data, np.ndarray)) or (input_data.ndim != 2):
- raise ValueError('Input data must be a 2D numpy array.')
+ raise ValueError("Input data must be a 2D numpy array.")
- executable = 'mr_transform'
+ executable = "mr_transform"
# Make sure mr_transform is installed.
is_executable(executable)
# Create a unique string using the current date and time.
- unique_string = (
- datetime.now().strftime('%Y.%m.%d_%H.%M.%S')
- + str(getrandbits(128))
- )
+ unique_string = datetime.now().strftime("%Y.%m.%d_%H.%M.%S") + str(getrandbits(128))
# Set the output file names.
- file_name = '{0}mr_temp_{1}'.format(path, unique_string)
- file_fits = '{0}.fits'.format(file_name)
- file_mr = '{0}.mr'.format(file_name)
+ file_name = f"{path}mr_temp_{unique_string}"
+ file_fits = f"{file_name}.fits"
+ file_mr = f"{file_name}.mr"
# Write the input data to a fits file.
fits.writeto(file_fits, input_data)
@@ -155,15 +150,15 @@ def call_mr_transform(
opt = opt.split()
# Prepare command and execute it
- command_line = ' '.join([executable] + opt + [file_fits, file_mr])
+ command_line = " ".join([executable, *opt, file_fits, file_mr])
stdout, _ = execute(command_line)
# Check for errors
- if any(word in stdout for word in ('bad', 'Error', 'Sorry')):
+ if any(word in stdout for word in ("bad", "Error", "Sorry")):
remove(file_fits)
message = '{0} raised following exception: "{1}"'
raise RuntimeError(
- message.format(executable, stdout.rstrip('\n')),
+ message.format(executable, stdout.rstrip("\n")),
)
# Retrieve wavelet transformed data.
@@ -198,12 +193,12 @@ def trim_filter(filter_array):
min_idx = np.min(non_zero_indices, axis=-1)
max_idx = np.max(non_zero_indices, axis=-1)
- return filter_array[min_idx[0]:max_idx[0] + 1, min_idx[1]:max_idx[1] + 1]
+ return filter_array[min_idx[0] : max_idx[0] + 1, min_idx[1] : max_idx[1] + 1]
def get_mr_filters(
data_shape,
- opt='',
+ opt="",
coarse=False,
trim=False,
): # pragma: no cover
@@ -256,7 +251,7 @@ def get_mr_filters(
return mr_filters[:-1]
-def filter_convolve(input_data, filters, filter_rot=False, method='scipy'):
+def filter_convolve(input_data, filters, filter_rot=False, method="scipy"):
"""Filter convolve.
This method convolves the input image with the wavelet filters.
@@ -315,16 +310,14 @@ def filter_convolve(input_data, filters, filter_rot=False, method='scipy'):
axis=0,
)
- return np.array([
- convolve(input_data, filt, method=method) for filt in filters
- ])
+ return np.array([convolve(input_data, filt, method=method) for filt in filters])
def filter_convolve_stack(
input_data,
filters,
filter_rot=False,
- method='scipy',
+ method="scipy",
):
"""Filter convolve.
@@ -366,7 +359,9 @@ def filter_convolve_stack(
"""
# Return the convolved data cube.
- return np.array([
- filter_convolve(elem, filters, filter_rot=filter_rot, method=method)
- for elem in input_data
- ])
+ return np.array(
+ [
+ filter_convolve(elem, filters, filter_rot=filter_rot, method=method)
+ for elem in input_data
+ ]
+ )
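``filter_convolve_stack`` thus maps ``filter_convolve`` over the first axis of a data cube, producing one response per (frame, filter) pair. Roughly equivalent to the nested comprehension below, using ``scipy.ndimage`` directly and ignoring the ``method`` and ``filter_rot`` options:

```python
import numpy as np
from scipy.ndimage import convolve

rng = np.random.default_rng(0)
cube = rng.random((2, 8, 8))  # stack of two images
filters = np.stack([np.ones((3, 3)) / 9, np.eye(3) / 3])

out = np.array(
    [[convolve(image, filt) for filt in filters] for image in cube]
)
print(out.shape)  # (2, 2, 8, 8) -> (frame, filter, row, column)
```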
diff --git a/modopt/tests/test_algorithms.py b/tests/test_algorithms.py
similarity index 96%
rename from modopt/tests/test_algorithms.py
rename to tests/test_algorithms.py
index 5671b8e3..63847764 100644
--- a/modopt/tests/test_algorithms.py
+++ b/tests/test_algorithms.py
@@ -1,5 +1,3 @@
-# -*- coding: utf-8 -*-
-
"""UNIT TESTS FOR Algorithms.
This module contains unit tests for the modopt.opt module.
@@ -11,18 +9,13 @@
import numpy as np
import numpy.testing as npt
-import pytest
from modopt.opt import algorithms, cost, gradient, linear, proximity, reweight
from pytest_cases import (
- case,
fixture,
- fixture_ref,
- lazy_value,
parametrize,
parametrize_with_cases,
)
-from test_helpers import Dummy
SKLEARN_AVAILABLE = True
try:
@@ -31,6 +24,9 @@
SKLEARN_AVAILABLE = False
+rng = np.random.default_rng()
+
+
@fixture
def idty():
"""Identity function."""
@@ -80,7 +76,7 @@ def build_kwargs(kwargs, use_metrics):
@parametrize(use_metrics=[True, False])
class AlgoCases:
- """Cases for algorithms.
+ r"""Cases for algorithms.
Most of the tests solve the trivial problem
@@ -91,7 +87,7 @@ class AlgoCases:
"""
data1 = np.arange(9).reshape(3, 3).astype(float)
- data2 = data1 + np.random.randn(*data1.shape) * 1e-6
+ data2 = data1 + rng.standard_normal(data1.shape) * 1e-6
max_iter = 20
@parametrize(
@@ -111,8 +107,7 @@ class AlgoCases:
]
)
def case_forward_backward(self, kwargs, idty, use_metrics):
- """Forward Backward case.
- """
+ """Forward Backward case."""
update_kwargs = build_kwargs(kwargs, use_metrics)
algo = algorithms.ForwardBackward(
self.data1,
@@ -242,9 +237,11 @@ def case_grad(self, GradDescent, use_metrics, idty):
)
algo.iterate()
return algo, update_kwargs
- @parametrize(admm=[algorithms.ADMM,algorithms.FastADMM])
+
+ @parametrize(admm=[algorithms.ADMM, algorithms.FastADMM])
def case_admm(self, admm, use_metrics, idty):
"""ADMM setup."""
+
def optim1(init, obs):
return obs
@@ -265,6 +262,7 @@ def optim2(init, obs):
algo.iterate()
return algo, update_kwargs
+
@parametrize_with_cases("algo, kwargs", cases=AlgoCases)
def test_algo(algo, kwargs):
"""Test algorithms."""
diff --git a/modopt/tests/test_base.py b/tests/test_base.py
similarity index 99%
rename from modopt/tests/test_base.py
rename to tests/test_base.py
index e32ff94b..62e09095 100644
--- a/modopt/tests/test_base.py
+++ b/tests/test_base.py
@@ -5,6 +5,7 @@
Samuel Farrens
Pierre-Antoine Comby
"""
+
import numpy as np
import numpy.testing as npt
import pytest
@@ -148,7 +149,7 @@ def test_matrix2cube(self):
class TestType:
"""Test for type module."""
- data_list = list(range(5))
+ data_list = list(range(5)) # noqa: RUF012
data_int = np.arange(5)
data_flt = np.arange(5).astype(float)
diff --git a/tests/test_helpers/__init__.py b/tests/test_helpers/__init__.py
new file mode 100644
index 00000000..0ded847a
--- /dev/null
+++ b/tests/test_helpers/__init__.py
@@ -0,0 +1,5 @@
+"""Utilities for tests."""
+
+from .utils import Dummy, failparam, skipparam
+
+__all__ = ["Dummy", "failparam", "skipparam"]
diff --git a/modopt/tests/test_helpers/utils.py b/tests/test_helpers/utils.py
similarity index 75%
rename from modopt/tests/test_helpers/utils.py
rename to tests/test_helpers/utils.py
index d8227640..41f948a6 100644
--- a/modopt/tests/test_helpers/utils.py
+++ b/tests/test_helpers/utils.py
@@ -1,9 +1,11 @@
"""
Some helper functions for the test parametrization.
+
They should be used inside a ``@pytest.mark.parametrize`` call.
:Author: Pierre-Antoine Comby
"""
+
import pytest
@@ -11,13 +13,15 @@ def failparam(*args, raises=None):
"""Return a pytest parameterization that should raise an error."""
if not issubclass(raises, Exception):
raise ValueError("raises should be an expected Exception.")
- return pytest.param(*args, marks=pytest.mark.raises(exception=raises))
+ return pytest.param(*args, marks=[pytest.mark.xfail(exception=raises)])
def skipparam(*args, cond=True, reason=""):
"""Return a pytest parameterization that should be skip if cond is valid."""
- return pytest.param(*args, marks=pytest.mark.skipif(cond, reason=reason))
+ return pytest.param(*args, marks=[pytest.mark.skipif(cond, reason=reason)])
class Dummy:
+ """Dummy Class."""
+
pass
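A sketch of how these helpers slot into a ``parametrize`` call (the test body below is illustrative, not from the suite):

```python
import pytest

from test_helpers import failparam, skipparam

@pytest.mark.parametrize(
    ("value", "expected"),
    [
        (2, 4),
        failparam("two", None, raises=TypeError),  # expected to fail
        skipparam(10, 100, cond=True, reason="example skip"),
    ],
)
def test_square(value, expected):
    assert value * value == expected
```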
diff --git a/modopt/tests/test_math.py b/tests/test_math.py
similarity index 95%
rename from modopt/tests/test_math.py
rename to tests/test_math.py
index ea177b15..5c466e5e 100644
--- a/modopt/tests/test_math.py
+++ b/tests/test_math.py
@@ -6,6 +6,7 @@
Samuel Farrens
Pierre-Antoine Comby
"""
+
import pytest
from test_helpers import failparam, skipparam
@@ -28,6 +29,8 @@
else:
SKIMAGE_AVAILABLE = True
+rng = np.random.default_rng(1)
+
class TestConvolve:
"""Test convolve functions."""
@@ -135,18 +138,15 @@ class TestMatrix:
),
)
- @pytest.fixture
+ @pytest.fixture(scope="module")
def pm_instance(self, request):
"""Power Method instance."""
- np.random.seed(1)
pm = matrix.PowerMethod(
lambda x_val: x_val.dot(x_val.T),
self.array33.shape,
- auto_run=request.param,
verbose=True,
+ rng=np.random.default_rng(0),
)
- if not request.param:
- pm.get_spec_rad(max_iter=1)
return pm
@pytest.mark.parametrize(
@@ -194,12 +194,7 @@ def test_rotate(self):
npt.assert_raises(ValueError, matrix.rotate, self.array23, np.pi / 2)
- @pytest.mark.parametrize(
- ("pm_instance", "value"),
- [(True, 1.0), (False, 0.8675467477372257)],
- indirect=["pm_instance"],
- )
- def test_power_method(self, pm_instance, value):
+ def test_power_method(self, pm_instance, value=1):
"""Test power method."""
npt.assert_almost_equal(pm_instance.spec_rad, value)
npt.assert_almost_equal(pm_instance.inv_spec_rad, 1 / value)
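For context, ``PowerMethod`` estimates the spectral radius with a power iteration along these lines (an illustrative sketch, not modopt's exact implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((3, 3))

def op(x):
    # x -> A^T A x, whose spectral radius is the largest eigenvalue of A^T A
    return a.T @ (a @ x)

x = rng.random(3)
for _ in range(100):
    x = op(x)
    x /= np.linalg.norm(x)  # renormalise to avoid overflow

spec_rad = np.linalg.norm(op(x))
assert np.isclose(spec_rad, np.linalg.eigvalsh(a.T @ a).max(), rtol=1e-3)
```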
@@ -210,8 +205,8 @@ class TestMetrics:
data1 = np.arange(49).reshape(7, 7)
mask = np.ones(data1.shape)
- ssim_res = 0.8963363560519094
- ssim_mask_res = 0.805154442543846
+ ssim_res = 0.8958315888566867
+ ssim_mask_res = 0.8023827544418249
snr_res = 10.134554256920536
psnr_res = 14.860761791850397
mse_res = 0.03265305507330247
diff --git a/modopt/tests/test_opt.py b/tests/test_opt.py
similarity index 92%
rename from modopt/tests/test_opt.py
rename to tests/test_opt.py
index 7c30186e..2ea58c27 100644
--- a/modopt/tests/test_opt.py
+++ b/tests/test_opt.py
@@ -36,11 +36,28 @@
except ImportError:
PYWT_AVAILABLE = False
+rng = np.random.default_rng()
+
+
# Basic functions to be used as operators or as dummy functions
-func_identity = lambda x_val: x_val
-func_double = lambda x_val: x_val * 2
-func_sq = lambda x_val: x_val**2
-func_cube = lambda x_val: x_val**3
+def func_identity(x_val, *args, **kwargs):
+ """Return x."""
+ return x_val
+
+
+def func_double(x_val, *args, **kwargs):
+ """Double x."""
+ return x_val * 2
+
+
+def func_sq(x_val, *args, **kwargs):
+ """Square x."""
+ return x_val**2
+
+
+def func_cube(x_val, *args, **kwargs):
+ """Cube x."""
+ return x_val**3
@case(tags="cost")
@@ -183,19 +200,35 @@ def case_linear_wavelet_convolve(self):
@parametrize(
compute_backend=[
- pytest.param("numpy", marks=pytest.mark.skipif(not PYWT_AVAILABLE, reason="PyWavelet not available.")),
- pytest.param("cupy", marks=pytest.mark.skipif(not PTWT_AVAILABLE, reason="Pytorch Wavelet not available."))
- ])
- def case_linear_wavelet_transform(self, compute_backend="numpy"):
+ pytest.param(
+ "numpy",
+ marks=pytest.mark.skipif(
+ not PYWT_AVAILABLE, reason="PyWavelet not available."
+ ),
+ ),
+ pytest.param(
+ "cupy",
+ marks=pytest.mark.skipif(
+ not PTWT_AVAILABLE, reason="Pytorch Wavelet not available."
+ ),
+ ),
+ ]
+ )
+ def case_linear_wavelet_transform(self, compute_backend):
+ """Case linear wavelet operator."""
linop = linear.WaveletTransform(
wavelet_name="haar",
shape=(8, 8),
level=2,
)
data_op = np.arange(64).reshape(8, 8).astype(float)
- res_op, slices, shapes = pywt.ravel_coeffs(pywt.wavedecn(data_op, "haar", level=2))
+ res_op, slices, shapes = pywt.ravel_coeffs(
+ pywt.wavedecn(data_op, "haar", level=2)
+ )
data_adj_op = linop.op(data_op)
- res_adj_op = pywt.waverecn(pywt.unravel_coeffs(data_adj_op, slices, shapes, "wavedecn"), "haar")
+ res_adj_op = pywt.waverecn(
+ pywt.unravel_coeffs(data_adj_op, slices, shapes, "wavedecn"), "haar"
+ )
return linop, data_op, data_adj_op, res_op, res_adj_op
@parametrize(weights=[[1.0, 1.0], None])
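The case relies on the ``pywt`` ravel/unravel round trip; in isolation it looks like this (Haar is orthogonal, so the reconstruction is exact):

```python
import numpy as np
import pywt

data = np.arange(64).reshape(8, 8).astype(float)

coeffs = pywt.wavedecn(data, "haar", level=2)
flat, slices, shapes = pywt.ravel_coeffs(coeffs)

restored = pywt.waverecn(
    pywt.unravel_coeffs(flat, slices, shapes, "wavedecn"), "haar"
)
assert np.allclose(restored, data)
```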
@@ -299,6 +332,7 @@ class ProxCases:
array233 = np.arange(18).reshape(2, 3, 3).astype(float)
array233_2 = np.array(
+ [
[
[2.73843189, 3.14594066, 3.55344943],
[3.9609582, 4.36846698, 4.77597575],
@@ -309,6 +343,7 @@ class ProxCases:
[11.67394789, 12.87497954, 14.07601119],
[15.27704284, 16.47807449, 17.67910614],
],
+ ]
)
array233_3 = np.array(
[
@@ -428,7 +463,7 @@ def case_prox_grouplasso(self, use_weights):
else:
weights = np.tile(np.zeros((3, 3)), (4, 1, 1))
- random_data = 3 * np.random.random(weights[0].shape)
+ random_data = 3 * rng.random(weights[0].shape)
random_data_tile = np.tile(random_data, (weights.shape[0], 1, 1))
if use_weights:
gl_result_data = 2 * random_data_tile - 3
diff --git a/modopt/tests/test_signal.py b/tests/test_signal.py
similarity index 96%
rename from modopt/tests/test_signal.py
rename to tests/test_signal.py
index 202e541b..6dbb0bba 100644
--- a/modopt/tests/test_signal.py
+++ b/tests/test_signal.py
@@ -16,7 +16,8 @@
class TestFilter:
- """Test filter module"""
+ """Test filter module."""
+
@pytest.mark.parametrize(
("norm", "result"), [(True, 0.24197072451914337), (False, 0.60653065971263342)]
)
@@ -24,7 +25,6 @@ def test_gaussian_filter(self, norm, result):
"""Test gaussian filter."""
npt.assert_almost_equal(filter.gaussian_filter(1, 1, norm=norm), result)
-
def test_mex_hat(self):
"""Test mexican hat filter."""
npt.assert_almost_equal(
@@ -32,7 +32,6 @@ def test_mex_hat(self):
-0.35213905225713371,
)
-
def test_mex_hat_dir(self):
"""Test directional mexican hat filter."""
npt.assert_almost_equal(
@@ -46,13 +45,13 @@ class TestNoise:
data1 = np.arange(9).reshape(3, 3).astype(float)
data2 = np.array(
- [[0, 2.0, 2.0], [4.0, 5.0, 10], [11.0, 15.0, 18.0]],
+ [[0, 3.0, 4.0], [6.0, 9.0, 8.0], [14.0, 14.0, 17.0]],
)
data3 = np.array(
[
- [1.62434536, 0.38824359, 1.47182825],
- [1.92703138, 4.86540763, 2.6984613],
- [7.74481176, 6.2387931, 8.3190391],
+ [0.3455842, 1.8216181, 2.3304371],
+ [1.6968428, 4.9053559, 5.4463746],
+ [5.4630468, 7.5811181, 8.3645724],
]
)
data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]])
@@ -71,9 +70,10 @@ class TestNoise:
)
def test_add_noise(self, data, noise_type, sigma, data_noise):
"""Test add_noise."""
- np.random.seed(1)
+ rng = np.random.default_rng(1)
npt.assert_almost_equal(
- noise.add_noise(data, sigma=sigma, noise_type=noise_type), data_noise
+ noise.add_noise(data, sigma=sigma, noise_type=noise_type, rng=rng),
+ data_noise,
)
@pytest.mark.parametrize(
@@ -86,13 +86,16 @@ def test_thresh(self, threshold_type, result):
noise.thresh(self.data1, 5, threshold_type=threshold_type), result
)
+
class TestPositivity:
"""Test positivity module."""
+
data1 = np.arange(9).reshape(3, 3).astype(float)
data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]])
data5 = np.array(
[[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]],
)
+
@pytest.mark.parametrize(
("value", "expected"),
[
@@ -231,6 +234,7 @@ def test_svd_thresh_coef(self, data, operator):
# TODO test_svd_thresh_coef_fast
+
class TestValidation:
"""Test validation Module."""
@@ -238,13 +242,13 @@ class TestValidation:
def test_transpose_test(self):
"""Test transpose_test."""
- np.random.seed(2)
npt.assert_equal(
validation.transpose_test(
lambda x_val, y_val: x_val.dot(y_val),
lambda x_val, y_val: x_val.dot(y_val.T),
self.array33.shape,
x_args=self.array33,
+ rng=2,
),
None,
)