From 753499946b4383c0b4041ac1e0e51d38cab7f0d9 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Wed, 14 Feb 2024 18:59:32 +0100
Subject: [PATCH 01/34] use a pyproject.toml file.
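
Move all packaging metadata and tool configuration (coverage, black,
ruff, isort, pytest) into a single PEP 621 pyproject.toml, replacing
setup.py, setup.cfg, MANIFEST.in, requirements.txt and develop.txt.
As a sketch of the resulting workflow (the extras names are taken from
the new file), a development install becomes:

    python -m pip install -e ".[dev,doc]"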
---
.github/workflows/ci-build.yml | 2 +-
.pylintrc | 2 -
.pyup.yml | 14 -----
MANIFEST.in | 5 --
develop.txt | 12 -----
pyproject.toml | 56 ++++++++++++++++++++
requirements.txt | 4 --
setup.cfg | 97 ----------------------------------
setup.py | 73 -------------------------
9 files changed, 57 insertions(+), 208 deletions(-)
delete mode 100644 .pylintrc
delete mode 100644 .pyup.yml
delete mode 100644 MANIFEST.in
delete mode 100644 develop.txt
create mode 100644 pyproject.toml
delete mode 100644 requirements.txt
delete mode 100644 setup.cfg
delete mode 100644 setup.py
diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml
index c4ba28a0..88129d45 100644
--- a/.github/workflows/ci-build.yml
+++ b/.github/workflows/ci-build.yml
@@ -44,7 +44,7 @@ jobs:
python -m pip install -r develop.txt
python -m pip install -r docs/requirements.txt
python -m pip install astropy "scikit-image<0.20" scikit-learn matplotlib
- python -m pip install tensorflow>=2.4.1
+ python -m pip install tensorflow>=2.4.1 torch
python -m pip install twine
python -m pip install .
diff --git a/.pylintrc b/.pylintrc
deleted file mode 100644
index 3ac9aef9..00000000
--- a/.pylintrc
+++ /dev/null
@@ -1,2 +0,0 @@
-[MASTER]
-ignore-patterns=**/docs/**/*.py
diff --git a/.pyup.yml b/.pyup.yml
deleted file mode 100644
index 8fdac7ff..00000000
--- a/.pyup.yml
+++ /dev/null
@@ -1,14 +0,0 @@
-# autogenerated pyup.io config file
-# see https://pyup.io/docs/configuration/ for all available options
-
-schedule: ''
-update: all
-label_prs: update
-assignees: sfarrens
-requirements:
- - requirements.txt:
- pin: False
- - develop.txt:
- pin: False
- - docs/requirements.txt:
- pin: True
diff --git a/MANIFEST.in b/MANIFEST.in
deleted file mode 100644
index 9a2f374e..00000000
--- a/MANIFEST.in
+++ /dev/null
@@ -1,5 +0,0 @@
-include requirements.txt
-include develop.txt
-include docs/requirements.txt
-include README.rst
-include LICENSE.txt
diff --git a/develop.txt b/develop.txt
deleted file mode 100644
index 6ff665eb..00000000
--- a/develop.txt
+++ /dev/null
@@ -1,12 +0,0 @@
-coverage>=5.5
-pytest>=6.2.2
-pytest-raises>=0.10
-pytest-cases>= 3.6
-pytest-xdist>= 3.0.1
-pytest-cov>=2.11.1
-pytest-emoji>=0.2.0
-pydocstyle==6.1.1
-pytest-pydocstyle>=2.2.0
-black
-isort
-pytest-black
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..71bdce82
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,56 @@
+[project]
+name="modopt"
+description = 'Modular Optimisation tools for solving inverse problems.'
+version = "1.7.1"
+requires-python= ">=3.8"
+
+authors = [{name="Samuel Farrens", email="samuel.farrens@cea.fr"},
+{name="Chaithya GR", email="chaithyagr@gmail.com"},
+{name="Pierre-Antoine Comby", email="pierre-antoine.comby@cea.fr"}
+]
+readme="README.md"
+license={file="LICENCE.txt"}
+
+dependencies = ["numpy", "scipy", "tqdm"]
+
+[project.optional-dependencies]
+gpu=["torch", "ptwt"]
+doc=["myst-parser==0.16.1",
+"nbsphinx==0.8.7",
+"nbsphinx-link==1.3.0",
+"sphinx-gallery==0.11.1",
+"sphinxawesome-theme==3.2.1",
+"sphinxcontrib-bibtex"]
+dev=["black", "pytest<8.0.0", "pytest-cases", "pytest-cov", "pytest-sugar", "ruff"]
+
+[build-system]
+requires=["setuptools", "setuptools-scm[toml]", "wheel"]
+
+[tool.setuptools]
+packages=["modopt"]
+
+[tool.coverage.run]
+omit = ["*tests*", "*__init__*", "*setup.py*", "*_version.py*", "*example*"]
+
+[tool.coverage.report]
+precision = 2
+exclude_lines = ["pragma: no cover", "raise NotImplementedError"]
+
+[tool.black]
+
+[tool.ruff]
+
+src=["modopt"]
+select = ["E", "F", "B", "Q", "UP", "D", "NPY", "RUF"]
+
+[tool.ruff.pydocstyle]
+convention="numpy"
+
+[tool.isort]
+profile="black"
+
+[tool.pytest.ini_options]
+minversion = "6.0"
+norecursedirs = ["tests/helpers"]
+testpaths=["modopt"]
+addopts = ["--verbose", "--cov=modopt", "--cov-report=term-missing", "--cov-report=xml", "--junitxml=pytest.xml"]
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index 1f44de13..00000000
--- a/requirements.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-importlib_metadata>=3.7.0
-numpy>=1.19.5
-scipy>=1.5.4
-tqdm>=4.64.0
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index 100adb40..00000000
--- a/setup.cfg
+++ /dev/null
@@ -1,97 +0,0 @@
-[aliases]
-test=pytest
-
-[metadata]
-description_file = README.rst
-
-[darglint]
-docstring_style = numpy
-strictness = short
-
-[flake8]
-ignore =
- D107, #Justification: Don't need docstring for __init__ in numpydoc style
- RST304, #Justification: Need to use :cite: role for citations
- RST210, #Justification: RST210, RST213 Inconsistent with numpydoc
- RST213, # documentation for handling *args and **kwargs
- W503, #Justification: Have to choose one multiline operator format
- WPS202, #Todo: Rethink module size, possibly split large modules
- WPS337, #Todo: Consider simplifying multiline conditions.
- WPS338, #Todo: Consider changing method order
- WPS403, #Todo: Rethink no cover lines
- WPS421, #Todo: Review need for print statements
- WPS432, #Justification: Mathematical codes require "magic numbers"
- WPS433, #Todo: Rethink conditional imports
- WPS463, #Todo: Rename get_ methods
- WPS615, #Todo: Rename get_ methods
-per-file-ignores =
- #Justification: Needed for keeping package version and current API
- *__init__.py*: F401,F403,WPS347,WPS410,WPS412
- #Todo: Rethink conditional imports
- #Todo: How can we bypass mutable constants?
- modopt/base/backend.py: WPS229, WPS420, WPS407
- #Todo: Rethink conditional imports
- modopt/base/observable.py: WPS420,WPS604
- #Todo: Check string for log formatting
- modopt/interface/log.py: WPS323
- #Todo: Rethink conditional imports
- modopt/math/convolve.py: WPS301,WPS420
- #Todo: Rethink conditional imports
- modopt/math/matrix.py: WPS420
- #Todo: import has bad parenthesis
- modopt/opt/algorithms/__init__.py: F401,F403,WPS318, WPS319, WPS412, WPS410
- #Todo: x is a too short name.
- modopt/opt/algorithms/forward_backward.py: WPS111
- #Todo: u,v , A is a too short name.
- modopt/opt/algorithms/admm.py: WPS111, N803
- #Todo: Check need for del statement
- modopt/opt/algorithms/primal_dual.py: WPS111, WPS420
- #multiline parameters bug with tuples
- modopt/opt/algorithms/gradient_descent.py: WPS111, WPS420, WPS317
- #Todo: Consider changing costObj name
- modopt/opt/cost.py: N801,
- #Todo:
- # - Rethink subscript slice assignment
- # - Reduce complexity of KSupportNorm
- # - Check bitwise operations
- modopt/opt/proximity.py: WPS220,WPS231,WPS352,WPS362,WPS465,WPS506,WPS508
- #Todo: Consider changing cwbReweight name
- modopt/opt/reweight.py: N801
- #Justification: Needed to import matplotlib.pyplot
- modopt/plot/cost_plot.py: N802,WPS301
- #Todo: Investigate possible bug in find_n_pc function
- #Todo: Investigate darglint error
- modopt/signal/svd.py: WPS345, DAR000
- #Todo: Check security of using system executable call
- modopt/signal/wavelet.py: S404,S603
- #Todo: Clean up tests
- modopt/tests/*.py: E731,F401,WPS301,WPS420,WPS425,WPS437,WPS604
- #Todo: Import has bad parenthesis
- modopt/tests/test_base.py: WPS318,WPS319,E501,WPS301
-#WPS Settings
-max-arguments = 25
-max-attributes = 40
-max-cognitive-score = 20
-max-function-expressions = 20
-max-line-complexity = 30
-max-local-variables = 10
-max-methods = 20
-max-module-expressions = 20
-max-string-usages = 20
-max-raises = 5
-
-[tool:pytest]
-norecursedirs=tests/test_helpers
-testpaths =
- modopt
-addopts =
- --verbose
- --cov=modopt
- --cov-report=term-missing
- --cov-report=xml
- --junitxml=pytest.xml
- --pydocstyle
-
-[pydocstyle]
-convention=numpy
-add-ignore=D107
diff --git a/setup.py b/setup.py
deleted file mode 100644
index e6a8a9e6..00000000
--- a/setup.py
+++ /dev/null
@@ -1,73 +0,0 @@
-#! /usr/bin/env python
-# -*- coding: utf-8 -*-
-
-from setuptools import setup, find_packages
-import os
-
-# Set the package release version
-major = 1
-minor = 7
-patch = 1
-
-# Set the package details
-name = 'modopt'
-version = '.'.join(str(value) for value in (major, minor, patch))
-author = 'Samuel Farrens'
-email = 'samuel.farrens@cea.fr'
-gh_user = 'cea-cosmic'
-url = 'https://github.com/{0}/{1}'.format(gh_user, name)
-description = 'Modular Optimisation tools for soliving inverse problems.'
-license = 'MIT'
-
-# Set the package classifiers
-python_versions_supported = ['3.7', '3.8', '3.9', '3.10', '3.11']
-os_platforms_supported = ['Unix', 'MacOS']
-
-lc_str = 'License :: OSI Approved :: {0} License'
-ln_str = 'Programming Language :: Python'
-py_str = 'Programming Language :: Python :: {0}'
-os_str = 'Operating System :: {0}'
-
-classifiers = (
- [lc_str.format(license)]
- + [ln_str]
- + [py_str.format(ver) for ver in python_versions_supported]
- + [os_str.format(ops) for ops in os_platforms_supported]
-)
-
-# Source package description from README.md
-this_directory = os.path.abspath(os.path.dirname(__file__))
-with open(os.path.join(this_directory, 'README.md'), encoding='utf-8') as f:
- long_description = f.read()
-
-# Source package requirements from requirements.txt
-with open('requirements.txt') as open_file:
- install_requires = open_file.read()
-
-# Source test requirements from develop.txt
-with open('develop.txt') as open_file:
- tests_require = open_file.read()
-
-# Source doc requirements from docs/requirements.txt
-with open('docs/requirements.txt') as open_file:
- docs_require = open_file.read()
-
-
-setup(
- name=name,
- author=author,
- author_email=email,
- version=version,
- license=license,
- url=url,
- description=description,
- long_description=long_description,
- long_description_content_type='text/markdown',
- packages=find_packages(),
- install_requires=install_requires,
- python_requires='>=3.6',
- setup_requires=['pytest-runner'],
- tests_require=tests_require,
- extras_require={'develop': tests_require + docs_require},
- classifiers=classifiers,
-)
From 02a16d4588cada4edde0a4616b5e27d11e8489dd Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Wed, 14 Feb 2024 19:05:05 +0100
Subject: [PATCH 02/34] feat: add a style-checking CI workflow.
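
The workflow checks out the code, installs the package with its
development extras, and runs black and ruff on pushes and pull
requests targeting the main branches. The same checks can be
reproduced locally, roughly (mirroring the workflow steps):

    python -m pip install -e ".[dev]"
    black . --diff --color --check
    ruff check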
---
.github/workflows/style.yml | 38 +++++++++++++++++++++++++++++++++++++
1 file changed, 38 insertions(+)
create mode 100644 .github/workflows/style.yml
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
new file mode 100644
index 00000000..04ce6da6
--- /dev/null
+++ b/.github/workflows/style.yml
@@ -0,0 +1,38 @@
+name: Style checking
+
+on:
+ push:
+ branches: [ "master", "main", "develop" ]
+ pull_request:
+ branches: [ "master", "main", "develop" ]
+
+ workflow_dispatch:
+
+env:
+ PYTHON_VERSION: "3.10"
+
+jobs:
+ linter-check:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v3
+ - name: Set up Python ${{ env.PYTHON_VERSION }}
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ env.PYTHON_VERSION }}
+ cache: pip
+
+ - name: Install Python deps
+ shell: bash
+ run: |
+ python -m pip install --upgrade pip
+ python -m pip install -e .[test,dev]
+
+ - name: Black Check
+ shell: bash
+ run: black . --diff --color --check
+
+ - name: ruff Check
+ shell: bash
+ run: ruff check
From d9aa9643815e4415686bac98d92edc9f38fc9ca6 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Fri, 16 Feb 2024 21:18:17 +0100
Subject: [PATCH 03/34] fix: missing bracket.
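
The nested list literal for array233_2 in ProxCases was missing its
outermost opening "[" after np.array(; this adds it back so the
literal closes cleanly.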
---
modopt/tests/test_opt.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py
index 7c30186e..1c2e7824 100644
--- a/modopt/tests/test_opt.py
+++ b/modopt/tests/test_opt.py
@@ -299,6 +299,7 @@ class ProxCases:
array233 = np.arange(18).reshape(2, 3, 3).astype(float)
array233_2 = np.array(
+ [
[
[2.73843189, 3.14594066, 3.55344943],
[3.9609582, 4.36846698, 4.77597575],
From 4d37bef1744a920fb17df8a8ad457fcb312537c1 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Fri, 16 Feb 2024 22:37:45 +0100
Subject: [PATCH 04/34] run black.
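
Purely mechanical reformatting of the code base with black (quote
style, line wrapping, trailing commas); no behaviour change is
intended. Roughly equivalent to running black with its default
settings (the [tool.black] section in pyproject.toml is left empty):

    black .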
---
docs/source/conf.py | 169 +++++++++---------
modopt/__init__.py | 8 +-
modopt/base/__init__.py | 2 +-
modopt/base/backend.py | 58 +++---
modopt/base/np_adjust.py | 6 +-
modopt/base/observable.py | 10 +-
modopt/base/transform.py | 43 ++---
modopt/base/types.py | 27 ++-
modopt/examples/conftest.py | 5 +-
.../example_lasso_forward_backward.py | 4 +-
modopt/interface/__init__.py | 2 +-
modopt/interface/errors.py | 28 ++-
modopt/interface/log.py | 16 +-
modopt/math/__init__.py | 2 +-
modopt/math/convolve.py | 32 ++--
modopt/math/matrix.py | 31 ++--
modopt/math/metrics.py | 14 +-
modopt/opt/__init__.py | 2 +-
modopt/opt/algorithms/__init__.py | 25 +--
modopt/opt/algorithms/admm.py | 11 +-
modopt/opt/algorithms/base.py | 55 +++---
modopt/opt/algorithms/forward_backward.py | 161 ++++++++---------
modopt/opt/algorithms/gradient_descent.py | 23 ++-
modopt/opt/algorithms/primal_dual.py | 44 +++--
modopt/opt/cost.py | 27 ++-
modopt/opt/gradient.py | 4 +-
modopt/opt/linear/base.py | 17 +-
modopt/opt/linear/wavelet.py | 29 +--
modopt/opt/proximity.py | 142 ++++++++-------
modopt/opt/reweight.py | 4 +-
modopt/plot/__init__.py | 2 +-
modopt/plot/cost_plot.py | 16 +-
modopt/signal/__init__.py | 2 +-
modopt/signal/filter.py | 2 +-
modopt/signal/noise.py | 17 +-
modopt/signal/positivity.py | 8 +-
modopt/signal/svd.py | 60 +++----
modopt/signal/validation.py | 4 +-
modopt/signal/wavelet.py | 53 +++---
modopt/tests/test_algorithms.py | 8 +-
modopt/tests/test_base.py | 1 +
modopt/tests/test_helpers/utils.py | 1 +
modopt/tests/test_math.py | 1 +
modopt/tests/test_opt.py | 8 +-
modopt/tests/test_signal.py | 7 +-
45 files changed, 589 insertions(+), 602 deletions(-)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 46564b9f..987576a9 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -9,56 +9,56 @@
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
-sys.path.insert(0, os.path.abspath('../..'))
+sys.path.insert(0, os.path.abspath("../.."))
# -- General configuration ------------------------------------------------
# General information about the project.
-project = 'modopt'
+project = "modopt"
mdata = metadata(project)
-author = mdata['Author']
-version = mdata['Version']
-copyright = '2020, {}'.format(author)
-gh_user = 'sfarrens'
+author = mdata["Author"]
+version = mdata["Version"]
+copyright = "2020, {}".format(author)
+gh_user = "sfarrens"
# If your documentation needs a minimal Sphinx version, state it here.
-needs_sphinx = '3.3'
+needs_sphinx = "3.3"
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
- 'sphinx.ext.autodoc',
- 'sphinx.ext.autosummary',
- 'sphinx.ext.coverage',
- 'sphinx.ext.doctest',
- 'sphinx.ext.ifconfig',
- 'sphinx.ext.intersphinx',
- 'sphinx.ext.mathjax',
- 'sphinx.ext.napoleon',
- 'sphinx.ext.todo',
- 'sphinx.ext.viewcode',
- 'sphinxawesome_theme',
- 'sphinxcontrib.bibtex',
- 'myst_parser',
- 'nbsphinx',
- 'nbsphinx_link',
- 'numpydoc',
- "sphinx_gallery.gen_gallery"
+ "sphinx.ext.autodoc",
+ "sphinx.ext.autosummary",
+ "sphinx.ext.coverage",
+ "sphinx.ext.doctest",
+ "sphinx.ext.ifconfig",
+ "sphinx.ext.intersphinx",
+ "sphinx.ext.mathjax",
+ "sphinx.ext.napoleon",
+ "sphinx.ext.todo",
+ "sphinx.ext.viewcode",
+ "sphinxawesome_theme",
+ "sphinxcontrib.bibtex",
+ "myst_parser",
+ "nbsphinx",
+ "nbsphinx_link",
+ "numpydoc",
+ "sphinx_gallery.gen_gallery",
]
# Include module names for objects
add_module_names = False
# Set class documentation standard.
-autoclass_content = 'class'
+autoclass_content = "class"
# Audodoc options
autodoc_default_options = {
- 'member-order': 'bysource',
- 'private-members': True,
- 'show-inheritance': True
+ "member-order": "bysource",
+ "private-members": True,
+ "show-inheritance": True,
}
# Generate summaries
@@ -69,17 +69,17 @@
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
-source_suffix = ['.rst', '.md']
+source_suffix = [".rst", ".md"]
# The master toctree document.
-master_doc = 'index'
+master_doc = "index"
# If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default.
show_authors = True
# The name of the Pygments (syntax highlighting) style to use.
-pygments_style = 'default'
+pygments_style = "default"
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
@@ -88,7 +88,7 @@
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
-html_theme = 'sphinxawesome_theme'
+html_theme = "sphinxawesome_theme"
# html_theme = 'sphinx_book_theme'
# Theme options are theme-specific and customize the look and feel of a theme
@@ -101,11 +101,10 @@
"breadcrumbs_separator": "/",
"show_prev_next": True,
"show_scrolltop": True,
-
}
html_collapsible_definitions = True
html_awesome_headerlinks = True
-html_logo = 'modopt_logo.jpg'
+html_logo = "modopt_logo.jpg"
html_permalinks_icon = (
'
'''
+ r""" """
+ r""""""
)
nbsphinx_prolog = nb_header_pt1 + nb_header_pt2
@@ -240,28 +237,28 @@ def add_notebooks(nb_path='../../notebooks'):
# Refer to the package libraries for type definitions
intersphinx_mapping = {
- 'python': ('http://docs.python.org/3', None),
- 'numpy': ('https://numpy.org/doc/stable/', None),
- 'scipy': ('https://docs.scipy.org/doc/scipy/reference', None),
- 'progressbar': ('https://progressbar-2.readthedocs.io/en/latest/', None),
- 'matplotlib': ('https://matplotlib.org', None),
- 'astropy': ('http://docs.astropy.org/en/latest/', None),
- 'cupy': ('https://docs-cupy.chainer.org/en/stable/', None),
- 'torch': ('https://pytorch.org/docs/stable/', None),
- 'sklearn': (
- 'http://scikit-learn.org/stable',
- (None, './_intersphinx/sklearn-objects.inv')
+ "python": ("http://docs.python.org/3", None),
+ "numpy": ("https://numpy.org/doc/stable/", None),
+ "scipy": ("https://docs.scipy.org/doc/scipy/reference", None),
+ "progressbar": ("https://progressbar-2.readthedocs.io/en/latest/", None),
+ "matplotlib": ("https://matplotlib.org", None),
+ "astropy": ("http://docs.astropy.org/en/latest/", None),
+ "cupy": ("https://docs-cupy.chainer.org/en/stable/", None),
+ "torch": ("https://pytorch.org/docs/stable/", None),
+ "sklearn": (
+ "http://scikit-learn.org/stable",
+ (None, "./_intersphinx/sklearn-objects.inv"),
),
- 'tensorflow': (
- 'https://www.tensorflow.org/api_docs/python',
+ "tensorflow": (
+ "https://www.tensorflow.org/api_docs/python",
(
- 'https://github.com/GPflow/tensorflow-intersphinx/'
- + 'raw/master/tf2_py_objects.inv')
- )
-
+ "https://github.com/GPflow/tensorflow-intersphinx/"
+ + "raw/master/tf2_py_objects.inv"
+ ),
+ ),
}
# -- BibTeX Setting ----------------------------------------------
-bibtex_bibfiles = ['refs.bib', 'my_ref.bib']
-bibtex_default_style = 'alpha'
+bibtex_bibfiles = ["refs.bib", "my_ref.bib"]
+bibtex_default_style = "alpha"
diff --git a/modopt/__init__.py b/modopt/__init__.py
index 2c06c1db..d446e15d 100644
--- a/modopt/__init__.py
+++ b/modopt/__init__.py
@@ -13,12 +13,12 @@
from modopt.base import *
try:
- _version = version('modopt')
+ _version = version("modopt")
except Exception: # pragma: no cover
- _version = 'Unkown'
+ _version = "Unkown"
warn(
- 'Could not extract package metadata. Make sure the package is '
- + 'correctly installed.',
+ "Could not extract package metadata. Make sure the package is "
+ + "correctly installed.",
)
__version__ = _version
diff --git a/modopt/base/__init__.py b/modopt/base/__init__.py
index 88424bae..d75ff315 100644
--- a/modopt/base/__init__.py
+++ b/modopt/base/__init__.py
@@ -9,4 +9,4 @@
"""
-__all__ = ['np_adjust', 'transform', 'types', 'observable']
+__all__ = ["np_adjust", "transform", "types", "observable"]
diff --git a/modopt/base/backend.py b/modopt/base/backend.py
index 1f4e9a72..fd933ebb 100644
--- a/modopt/base/backend.py
+++ b/modopt/base/backend.py
@@ -26,22 +26,24 @@
# Handle the compatibility with variable
LIBRARIES = {
- 'cupy': None,
- 'tensorflow': None,
- 'numpy': np,
+ "cupy": None,
+ "tensorflow": None,
+ "numpy": np,
}
-if util.find_spec('cupy') is not None:
+if util.find_spec("cupy") is not None:
try:
import cupy as cp
- LIBRARIES['cupy'] = cp
+
+ LIBRARIES["cupy"] = cp
except ImportError:
pass
-if util.find_spec('tensorflow') is not None:
+if util.find_spec("tensorflow") is not None:
try:
from tensorflow.experimental import numpy as tnp
- LIBRARIES['tensorflow'] = tnp
+
+ LIBRARIES["tensorflow"] = tnp
except ImportError:
pass
@@ -66,12 +68,12 @@ def get_backend(backend):
"""
if backend not in LIBRARIES.keys() or LIBRARIES[backend] is None:
msg = (
- '{0} backend not possible, please ensure that '
- + 'the optional libraries are installed.\n'
- + 'Reverting to numpy.'
+ "{0} backend not possible, please ensure that "
+ + "the optional libraries are installed.\n"
+ + "Reverting to numpy."
)
warn(msg.format(backend))
- backend = 'numpy'
+ backend = "numpy"
return LIBRARIES[backend], backend
@@ -92,16 +94,16 @@ def get_array_module(input_data):
The numpy or cupy module
"""
- if LIBRARIES['tensorflow'] is not None:
- if isinstance(input_data, LIBRARIES['tensorflow'].ndarray):
- return LIBRARIES['tensorflow']
- if LIBRARIES['cupy'] is not None:
- if isinstance(input_data, LIBRARIES['cupy'].ndarray):
- return LIBRARIES['cupy']
+ if LIBRARIES["tensorflow"] is not None:
+ if isinstance(input_data, LIBRARIES["tensorflow"].ndarray):
+ return LIBRARIES["tensorflow"]
+ if LIBRARIES["cupy"] is not None:
+ if isinstance(input_data, LIBRARIES["cupy"].ndarray):
+ return LIBRARIES["cupy"]
return np
-def change_backend(input_data, backend='cupy'):
+def change_backend(input_data, backend="cupy"):
"""Move data to device.
This method changes the backend of an array. This can be used to copy data
@@ -151,13 +153,13 @@ def move_to_cpu(input_data):
"""
xp = get_array_module(input_data)
- if xp == LIBRARIES['numpy']:
+ if xp == LIBRARIES["numpy"]:
return input_data
- elif xp == LIBRARIES['cupy']:
+ elif xp == LIBRARIES["cupy"]:
return input_data.get()
- elif xp == LIBRARIES['tensorflow']:
+ elif xp == LIBRARIES["tensorflow"]:
return input_data.data.numpy()
- raise ValueError('Cannot identify the array type.')
+ raise ValueError("Cannot identify the array type.")
def convert_to_tensor(input_data):
@@ -184,9 +186,9 @@ def convert_to_tensor(input_data):
"""
if not import_torch:
raise ImportError(
- 'Required version of Torch package not found'
- + 'see documentation for details: https://cea-cosmic.'
- + 'github.io/ModOpt/#optional-packages',
+ "Required version of Torch package not found"
+ + "see documentation for details: https://cea-cosmic."
+ + "github.io/ModOpt/#optional-packages",
)
xp = get_array_module(input_data)
@@ -220,9 +222,9 @@ def convert_to_cupy_array(input_data):
"""
if not import_torch:
raise ImportError(
- 'Required version of Torch package not found'
- + 'see documentation for details: https://cea-cosmic.'
- + 'github.io/ModOpt/#optional-packages',
+ "Required version of Torch package not found"
+ + "see documentation for details: https://cea-cosmic."
+ + "github.io/ModOpt/#optional-packages",
)
if input_data.is_cuda:
diff --git a/modopt/base/np_adjust.py b/modopt/base/np_adjust.py
index 6d290e43..31a785f5 100644
--- a/modopt/base/np_adjust.py
+++ b/modopt/base/np_adjust.py
@@ -154,8 +154,8 @@ def pad2d(input_data, padding):
padding = np.array(padding)
elif not isinstance(padding, np.ndarray):
raise ValueError(
- 'Padding must be an integer or a tuple (or list, np.ndarray) '
- + 'of itegers',
+ "Padding must be an integer or a tuple (or list, np.ndarray) "
+ + "of itegers",
)
if padding.size == 1:
@@ -164,7 +164,7 @@ def pad2d(input_data, padding):
pad_x = (padding[0], padding[0])
pad_y = (padding[1], padding[1])
- return np.pad(input_data, (pad_x, pad_y), 'constant')
+ return np.pad(input_data, (pad_x, pad_y), "constant")
def ftr(input_data):
diff --git a/modopt/base/observable.py b/modopt/base/observable.py
index 6471ba58..2f69a1a7 100644
--- a/modopt/base/observable.py
+++ b/modopt/base/observable.py
@@ -264,9 +264,7 @@ def is_converge(self):
mid_idx = -(self.wind // 2)
old_mean = np.array(self.list_cv_values[start_idx:mid_idx]).mean()
current_mean = np.array(self.list_cv_values[mid_idx:]).mean()
- normalize_residual_metrics = (
- np.abs(old_mean - current_mean) / np.abs(old_mean)
- )
+ normalize_residual_metrics = np.abs(old_mean - current_mean) / np.abs(old_mean)
self.converge_flag = normalize_residual_metrics < self.eps
def retrieve_metrics(self):
@@ -287,7 +285,7 @@ def retrieve_metrics(self):
time_val -= time_val[0]
return {
- 'time': time_val,
- 'index': self.list_iters,
- 'values': self.list_cv_values,
+ "time": time_val,
+ "index": self.list_iters,
+ "values": self.list_cv_values,
}
diff --git a/modopt/base/transform.py b/modopt/base/transform.py
index 07ce846f..fedd5efb 100644
--- a/modopt/base/transform.py
+++ b/modopt/base/transform.py
@@ -53,18 +53,17 @@ def cube2map(data_cube, layout):
"""
if data_cube.ndim != 3:
- raise ValueError('The input data must have 3 dimensions.')
+ raise ValueError("The input data must have 3 dimensions.")
if data_cube.shape[0] != np.prod(layout):
raise ValueError(
- 'The desired layout must match the number of input '
- + 'data layers.',
+ "The desired layout must match the number of input " + "data layers.",
)
- res = ([
+ res = [
np.hstack(data_cube[slice(layout[1] * elem, layout[1] * (elem + 1))])
for elem in range(layout[0])
- ])
+ ]
return np.vstack(res)
@@ -118,20 +117,24 @@ def map2cube(data_map, layout):
"""
if np.all(np.array(data_map.shape) % np.array(layout)):
raise ValueError(
- 'The desired layout must be a multiple of the number '
- + 'pixels in the data map.',
+ "The desired layout must be a multiple of the number "
+ + "pixels in the data map.",
)
d_shape = np.array(data_map.shape) // np.array(layout)
- return np.array([
- data_map[(
- slice(i_elem * d_shape[0], (i_elem + 1) * d_shape[0]),
- slice(j_elem * d_shape[1], (j_elem + 1) * d_shape[1]),
- )]
- for i_elem in range(layout[0])
- for j_elem in range(layout[1])
- ])
+ return np.array(
+ [
+ data_map[
+ (
+ slice(i_elem * d_shape[0], (i_elem + 1) * d_shape[0]),
+ slice(j_elem * d_shape[1], (j_elem + 1) * d_shape[1]),
+ )
+ ]
+ for i_elem in range(layout[0])
+ for j_elem in range(layout[1])
+ ]
+ )
def map2matrix(data_map, layout):
@@ -186,9 +189,9 @@ def map2matrix(data_map, layout):
image_shape * (i_elem % layout[1] + 1),
)
data_matrix.append(
- (
- data_map[lower[0]:upper[0], lower[1]:upper[1]]
- ).reshape(image_shape ** 2),
+ (data_map[lower[0] : upper[0], lower[1] : upper[1]]).reshape(
+ image_shape**2
+ ),
)
return np.array(data_matrix).T
@@ -232,7 +235,7 @@ def matrix2map(data_matrix, map_shape):
# Get the shape and layout of the images
image_shape = np.sqrt(data_matrix.shape[0]).astype(int)
- layout = np.array(map_shape // np.repeat(image_shape, 2), dtype='int')
+ layout = np.array(map_shape // np.repeat(image_shape, 2), dtype="int")
# Map objects from matrix
data_map = np.zeros(map_shape)
@@ -248,7 +251,7 @@ def matrix2map(data_matrix, map_shape):
image_shape * (i_elem // layout[1] + 1),
image_shape * (i_elem % layout[1] + 1),
)
- data_map[lower[0]:upper[0], lower[1]:upper[1]] = temp[:, :, i_elem]
+ data_map[lower[0] : upper[0], lower[1] : upper[1]] = temp[:, :, i_elem]
return data_map.astype(int)
diff --git a/modopt/base/types.py b/modopt/base/types.py
index 16e06f15..7ea805ad 100644
--- a/modopt/base/types.py
+++ b/modopt/base/types.py
@@ -30,7 +30,7 @@ def check_callable(input_obj):
For invalid input type
"""
if not callable(input_obj):
- raise TypeError('The input object must be a callable function.')
+ raise TypeError("The input object must be a callable function.")
return input_obj
@@ -71,14 +71,13 @@ def check_float(input_obj):
"""
if not isinstance(input_obj, (int, float, list, tuple, np.ndarray)):
- raise TypeError('Invalid input type.')
+ raise TypeError("Invalid input type.")
if isinstance(input_obj, int):
input_obj = float(input_obj)
elif isinstance(input_obj, (list, tuple)):
input_obj = np.array(input_obj, dtype=float)
- elif (
- isinstance(input_obj, np.ndarray)
- and (not np.issubdtype(input_obj.dtype, np.floating))
+ elif isinstance(input_obj, np.ndarray) and (
+ not np.issubdtype(input_obj.dtype, np.floating)
):
input_obj = input_obj.astype(float)
@@ -121,14 +120,13 @@ def check_int(input_obj):
"""
if not isinstance(input_obj, (int, float, list, tuple, np.ndarray)):
- raise TypeError('Invalid input type.')
+ raise TypeError("Invalid input type.")
if isinstance(input_obj, float):
input_obj = int(input_obj)
elif isinstance(input_obj, (list, tuple)):
input_obj = np.array(input_obj, dtype=int)
- elif (
- isinstance(input_obj, np.ndarray)
- and (not np.issubdtype(input_obj.dtype, np.integer))
+ elif isinstance(input_obj, np.ndarray) and (
+ not np.issubdtype(input_obj.dtype, np.integer)
):
input_obj = input_obj.astype(int)
@@ -160,19 +158,18 @@ def check_npndarray(input_obj, dtype=None, writeable=True, verbose=True):
"""
if not isinstance(input_obj, np.ndarray):
- raise TypeError('Input is not a numpy array.')
+ raise TypeError("Input is not a numpy array.")
- if (
- (not isinstance(dtype, type(None)))
- and (not np.issubdtype(input_obj.dtype, dtype))
+ if (not isinstance(dtype, type(None))) and (
+ not np.issubdtype(input_obj.dtype, dtype)
):
raise (
TypeError(
- 'The numpy array elements are not of type: {0}'.format(dtype),
+ "The numpy array elements are not of type: {0}".format(dtype),
),
)
if not writeable and verbose and input_obj.flags.writeable:
- warn('Making input data immutable.')
+ warn("Making input data immutable.")
input_obj.flags.writeable = writeable
diff --git a/modopt/examples/conftest.py b/modopt/examples/conftest.py
index 73358679..f3ed371b 100644
--- a/modopt/examples/conftest.py
+++ b/modopt/examples/conftest.py
@@ -11,10 +11,12 @@
https://stackoverflow.com/questions/56807698/how-to-run-script-as-pytest-test
"""
+
from pathlib import Path
import runpy
import pytest
+
def pytest_collect_file(path, parent):
"""Pytest hook.
@@ -22,7 +24,7 @@ def pytest_collect_file(path, parent):
The new node needs to have the specified parent as parent.
"""
p = Path(path)
- if p.suffix == '.py' and 'example' in p.name:
+ if p.suffix == ".py" and "example" in p.name:
return Script.from_parent(parent, path=p, name=p.name)
@@ -33,6 +35,7 @@ def collect(self):
"""Collect the script as its own item."""
yield ScriptItem.from_parent(self, name=self.name)
+
class ScriptItem(pytest.Item):
"""Item script collected by pytest."""
diff --git a/modopt/examples/example_lasso_forward_backward.py b/modopt/examples/example_lasso_forward_backward.py
index 7f820000..c28b0499 100644
--- a/modopt/examples/example_lasso_forward_backward.py
+++ b/modopt/examples/example_lasso_forward_backward.py
@@ -76,7 +76,7 @@
prox=prox_op,
cost=cost_op_fista,
metric_call_period=1,
- auto_iterate=False, # Just to give us the pleasure of doing things by ourself.
+ auto_iterate=False, # Just to give us the pleasure of doing things by ourself.
)
fb_fista.iterate()
@@ -115,7 +115,7 @@
prox=prox_op,
cost=cost_op_pogm,
metric_call_period=1,
- auto_iterate=False, # Just to give us the pleasure of doing things by ourself.
+ auto_iterate=False, # Just to give us the pleasure of doing things by ourself.
)
fb_pogm.iterate()
diff --git a/modopt/interface/__init__.py b/modopt/interface/__init__.py
index f9439747..55904ca1 100644
--- a/modopt/interface/__init__.py
+++ b/modopt/interface/__init__.py
@@ -8,4 +8,4 @@
"""
-__all__ = ['errors', 'log']
+__all__ = ["errors", "log"]
diff --git a/modopt/interface/errors.py b/modopt/interface/errors.py
index 0fbe7e71..eb4aa4ca 100644
--- a/modopt/interface/errors.py
+++ b/modopt/interface/errors.py
@@ -34,12 +34,12 @@ def warn(warn_string, log=None):
"""
if import_fail:
- warn_txt = 'WARNING'
+ warn_txt = "WARNING"
else:
- warn_txt = colored('WARNING', 'yellow')
+ warn_txt = colored("WARNING", "yellow")
# Print warning to stdout.
- sys.stderr.write('{0}: {1}\n'.format(warn_txt, warn_string))
+ sys.stderr.write("{0}: {1}\n".format(warn_txt, warn_string))
# Check if a logging structure is provided.
if not isinstance(log, type(None)):
@@ -61,17 +61,17 @@ def catch_error(exception, log=None):
"""
if import_fail:
- err_txt = 'ERROR'
+ err_txt = "ERROR"
else:
- err_txt = colored('ERROR', 'red')
+ err_txt = colored("ERROR", "red")
# Print exception to stdout.
- stream_txt = '{0}: {1}\n'.format(err_txt, exception)
+ stream_txt = "{0}: {1}\n".format(err_txt, exception)
sys.stderr.write(stream_txt)
# Check if a logging structure is provided.
if not isinstance(log, type(None)):
- log_txt = 'ERROR: {0}\n'.format(exception)
+ log_txt = "ERROR: {0}\n".format(exception)
log.exception(log_txt)
@@ -91,11 +91,11 @@ def file_name_error(file_name):
If file name not specified or file not found
"""
- if file_name == '' or file_name[0][0] == '-':
- raise IOError('Input file name not specified.')
+ if file_name == "" or file_name[0][0] == "-":
+ raise IOError("Input file name not specified.")
elif not os.path.isfile(file_name):
- raise IOError('Input file name {0} not found!'.format(file_name))
+ raise IOError("Input file name {0} not found!".format(file_name))
def is_exe(fpath):
@@ -136,7 +136,7 @@ def is_executable(exe_name):
"""
if not isinstance(exe_name, str):
- raise TypeError('Executable name must be a string.')
+ raise TypeError("Executable name must be a string.")
fpath, fname = os.path.split(exe_name)
@@ -146,11 +146,9 @@ def is_executable(exe_name):
else:
res = any(
is_exe(os.path.join(path, exe_name))
- for path in os.environ['PATH'].split(os.pathsep)
+ for path in os.environ["PATH"].split(os.pathsep)
)
if not res:
- message = (
- '{0} does not appear to be a valid executable on this system.'
- )
+ message = "{0} does not appear to be a valid executable on this system."
raise IOError(message.format(exe_name))
diff --git a/modopt/interface/log.py b/modopt/interface/log.py
index 3b2fa77a..a02428d9 100644
--- a/modopt/interface/log.py
+++ b/modopt/interface/log.py
@@ -30,22 +30,22 @@ def set_up_log(filename, verbose=True):
"""
# Add file extension.
- filename = '{0}.log'.format(filename)
+ filename = "{0}.log".format(filename)
if verbose:
- print('Preparing log file:', filename)
+ print("Preparing log file:", filename)
# Capture warnings.
logging.captureWarnings(True)
# Set output format.
formatter = logging.Formatter(
- fmt='%(asctime)s %(message)s',
- datefmt='%d/%m/%Y %H:%M:%S',
+ fmt="%(asctime)s %(message)s",
+ datefmt="%d/%m/%Y %H:%M:%S",
)
# Create file handler.
- fh = logging.FileHandler(filename=filename, mode='w')
+ fh = logging.FileHandler(filename=filename, mode="w")
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
@@ -55,7 +55,7 @@ def set_up_log(filename, verbose=True):
log.addHandler(fh)
# Send opening message.
- log.info('The log file has been set-up.')
+ log.info("The log file has been set-up.")
return log
@@ -74,10 +74,10 @@ def close_log(log, verbose=True):
"""
if verbose:
- print('Closing log file:', log.name)
+ print("Closing log file:", log.name)
# Send closing message.
- log.info('The log file has been closed.')
+ log.info("The log file has been closed.")
# Remove all handlers from log.
for log_handler in log.handlers:
diff --git a/modopt/math/__init__.py b/modopt/math/__init__.py
index a22c0c98..8e92aa50 100644
--- a/modopt/math/__init__.py
+++ b/modopt/math/__init__.py
@@ -8,4 +8,4 @@
"""
-__all__ = ['convolve', 'matrix', 'stats', 'metrics']
+__all__ = ["convolve", "matrix", "stats", "metrics"]
diff --git a/modopt/math/convolve.py b/modopt/math/convolve.py
index a4322ff2..528b2338 100644
--- a/modopt/math/convolve.py
+++ b/modopt/math/convolve.py
@@ -18,7 +18,7 @@
from astropy.convolution import convolve_fft
except ImportError: # pragma: no cover
import_astropy = False
- warn('astropy not found, will default to scipy for convolution')
+ warn("astropy not found, will default to scipy for convolution")
else:
import_astropy = True
try:
@@ -30,7 +30,7 @@
warn('Using pyFFTW "monkey patch" for scipy.fftpack')
-def convolve(input_data, kernel, method='scipy'):
+def convolve(input_data, kernel, method="scipy"):
"""Convolve data with kernel.
This method convolves the input data with a given kernel using FFT and
@@ -80,29 +80,29 @@ def convolve(input_data, kernel, method='scipy'):
"""
if input_data.ndim != kernel.ndim:
- raise ValueError('Data and kernel must have the same dimensions.')
+ raise ValueError("Data and kernel must have the same dimensions.")
- if method not in {'astropy', 'scipy'}:
+ if method not in {"astropy", "scipy"}:
raise ValueError('Invalid method. Options are "astropy" or "scipy".')
if not import_astropy: # pragma: no cover
- method = 'scipy'
+ method = "scipy"
- if method == 'astropy':
+ if method == "astropy":
return convolve_fft(
input_data,
kernel,
- boundary='wrap',
+ boundary="wrap",
crop=False,
- nan_treatment='fill',
+ nan_treatment="fill",
normalize_kernel=False,
)
- elif method == 'scipy':
- return scipy.signal.fftconvolve(input_data, kernel, mode='same')
+ elif method == "scipy":
+ return scipy.signal.fftconvolve(input_data, kernel, mode="same")
-def convolve_stack(input_data, kernel, rot_kernel=False, method='scipy'):
+def convolve_stack(input_data, kernel, rot_kernel=False, method="scipy"):
"""Convolve stack of data with stack of kernels.
This method convolves the input data with a given kernel using FFT and
@@ -156,7 +156,9 @@ def convolve_stack(input_data, kernel, rot_kernel=False, method='scipy'):
if rot_kernel:
kernel = rotate_stack(kernel)
- return np.array([
- convolve(data_i, kernel_i, method=method)
- for data_i, kernel_i in zip(input_data, kernel)
- ])
+ return np.array(
+ [
+ convolve(data_i, kernel_i, method=method)
+ for data_i, kernel_i in zip(input_data, kernel)
+ ]
+ )
diff --git a/modopt/math/matrix.py b/modopt/math/matrix.py
index 8361531d..6ddb3f2f 100644
--- a/modopt/math/matrix.py
+++ b/modopt/math/matrix.py
@@ -15,7 +15,7 @@
from modopt.base.backend import get_array_module, get_backend
-def gram_schmidt(matrix, return_opt='orthonormal'):
+def gram_schmidt(matrix, return_opt="orthonormal"):
r"""Gram-Schmit.
This method orthonormalizes the row vectors of the input matrix.
@@ -55,7 +55,7 @@ def gram_schmidt(matrix, return_opt='orthonormal'):
https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process
"""
- if return_opt not in {'orthonormal', 'orthogonal', 'both'}:
+ if return_opt not in {"orthonormal", "orthogonal", "both"}:
raise ValueError(
'Invalid return_opt, options are: "orthonormal", "orthogonal" or '
+ '"both"',
@@ -77,11 +77,11 @@ def gram_schmidt(matrix, return_opt='orthonormal'):
u_vec = np.array(u_vec)
e_vec = np.array(e_vec)
- if return_opt == 'orthonormal':
+ if return_opt == "orthonormal":
return e_vec
- elif return_opt == 'orthogonal':
+ elif return_opt == "orthogonal":
return u_vec
- elif return_opt == 'both':
+ elif return_opt == "both":
return u_vec, e_vec
@@ -201,7 +201,7 @@ def rot_matrix(angle):
return np.around(
np.array(
[[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]],
- dtype='float',
+ dtype="float",
),
10,
)
@@ -243,16 +243,15 @@ def rotate(matrix, angle):
shape = np.array(matrix.shape)
if shape[0] != shape[1]:
- raise ValueError('Input matrix must be square.')
+ raise ValueError("Input matrix must be square.")
shift = (shape - 1) // 2
index = (
- np.array(list(product(*np.array([np.arange(sval) for sval in shape]))))
- - shift
+ np.array(list(product(*np.array([np.arange(sval) for sval in shape])))) - shift
)
- new_index = np.array(np.dot(index, rot_matrix(angle)), dtype='int') + shift
+ new_index = np.array(np.dot(index, rot_matrix(angle)), dtype="int") + shift
new_index[new_index >= shape[0]] -= shape[0]
return matrix[tuple(zip(new_index.T))].reshape(shape.T)
@@ -301,7 +300,7 @@ def __init__(
data_shape,
data_type=float,
auto_run=True,
- compute_backend='numpy',
+ compute_backend="numpy",
verbose=False,
):
@@ -363,18 +362,14 @@ def get_spec_rad(self, tolerance=1e-6, max_iter=20, extra_factor=1.0):
x_new /= x_new_norm
- if (xp.abs(x_new_norm - x_old_norm) < tolerance):
- message = (
- ' - Power Method converged after {0} iterations!'
- )
+ if xp.abs(x_new_norm - x_old_norm) < tolerance:
+ message = " - Power Method converged after {0} iterations!"
if self._verbose:
print(message.format(i_elem + 1))
break
elif i_elem == max_iter - 1 and self._verbose:
- message = (
- ' - Power Method did not converge after {0} iterations!'
- )
+ message = " - Power Method did not converge after {0} iterations!"
print(message.format(max_iter))
xp.copyto(x_old, x_new)
diff --git a/modopt/math/metrics.py b/modopt/math/metrics.py
index 21952624..93f7ce06 100644
--- a/modopt/math/metrics.py
+++ b/modopt/math/metrics.py
@@ -71,15 +71,13 @@ def _preprocess_input(test, ref, mask=None):
The SNR
"""
- test = np.abs(np.copy(test)).astype('float64')
- ref = np.abs(np.copy(ref)).astype('float64')
+ test = np.abs(np.copy(test)).astype("float64")
+ ref = np.abs(np.copy(ref)).astype("float64")
test = min_max_normalize(test)
ref = min_max_normalize(ref)
if (not isinstance(mask, np.ndarray)) and (mask is not None):
- message = (
- 'Mask should be None, or a numpy.ndarray, got "{0}" instead.'
- )
+ message = 'Mask should be None, or a numpy.ndarray, got "{0}" instead.'
raise ValueError(message.format(mask))
if mask is None:
@@ -119,9 +117,9 @@ def ssim(test, ref, mask=None):
"""
if not import_skimage: # pragma: no cover
raise ImportError(
- 'Required version of Scikit-Image package not found'
- + 'see documentation for details: https://cea-cosmic.'
- + 'github.io/ModOpt/#optional-packages',
+ "Required version of Scikit-Image package not found"
+ + "see documentation for details: https://cea-cosmic."
+ + "github.io/ModOpt/#optional-packages",
)
test, ref, mask = _preprocess_input(test, ref, mask)
diff --git a/modopt/opt/__init__.py b/modopt/opt/__init__.py
index 2fd3d747..8b285bee 100644
--- a/modopt/opt/__init__.py
+++ b/modopt/opt/__init__.py
@@ -8,4 +8,4 @@
"""
-__all__ = ['cost', 'gradient', 'linear', 'algorithms', 'proximity', 'reweight']
+__all__ = ["cost", "gradient", "linear", "algorithms", "proximity", "reweight"]
diff --git a/modopt/opt/algorithms/__init__.py b/modopt/opt/algorithms/__init__.py
index d4e7082b..6a29325a 100644
--- a/modopt/opt/algorithms/__init__.py
+++ b/modopt/opt/algorithms/__init__.py
@@ -46,15 +46,20 @@
"""
from modopt.opt.algorithms.base import SetUp
-from modopt.opt.algorithms.forward_backward import (FISTA, POGM,
- ForwardBackward,
- GenForwardBackward)
-from modopt.opt.algorithms.gradient_descent import (AdaGenericGradOpt,
- ADAMGradOpt,
- GenericGradOpt,
- MomentumGradOpt,
- RMSpropGradOpt,
- SAGAOptGradOpt,
- VanillaGenericGradOpt)
+from modopt.opt.algorithms.forward_backward import (
+ FISTA,
+ POGM,
+ ForwardBackward,
+ GenForwardBackward,
+)
+from modopt.opt.algorithms.gradient_descent import (
+ AdaGenericGradOpt,
+ ADAMGradOpt,
+ GenericGradOpt,
+ MomentumGradOpt,
+ RMSpropGradOpt,
+ SAGAOptGradOpt,
+ VanillaGenericGradOpt,
+)
from modopt.opt.algorithms.primal_dual import Condat
from modopt.opt.algorithms.admm import ADMM, FastADMM
diff --git a/modopt/opt/algorithms/admm.py b/modopt/opt/algorithms/admm.py
index b881b770..4fd8074e 100644
--- a/modopt/opt/algorithms/admm.py
+++ b/modopt/opt/algorithms/admm.py
@@ -1,4 +1,5 @@
"""ADMM Algorithms."""
+
import numpy as np
from modopt.base.backend import get_array_module
@@ -188,7 +189,7 @@ def iterate(self, max_iter=150):
self.retrieve_outputs()
# rename outputs as attributes
self.u_final = self._u_new
- self.x_final = self.u_final # for backward compatibility
+ self.x_final = self.u_final # for backward compatibility
self.v_final = self._v_new
def get_notify_observers_kwargs(self):
@@ -203,9 +204,9 @@ def get_notify_observers_kwargs(self):
The mapping between the iterated variables
"""
return {
- 'x_new': self._u_new,
- 'v_new': self._v_new,
- 'idx': self.idx,
+ "x_new": self._u_new,
+ "v_new": self._v_new,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -215,7 +216,7 @@ def retrieve_outputs(self):
y_final, metrics.
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
diff --git a/modopt/opt/algorithms/base.py b/modopt/opt/algorithms/base.py
index c5a4b101..e2b9017d 100644
--- a/modopt/opt/algorithms/base.py
+++ b/modopt/opt/algorithms/base.py
@@ -69,7 +69,7 @@ def __init__(
verbose=False,
progress=True,
step_size=None,
- compute_backend='numpy',
+ compute_backend="numpy",
**dummy_kwargs,
):
self.idx = 0
@@ -79,26 +79,26 @@ def __init__(
self.metrics = metrics
self.step_size = step_size
self._op_parents = (
- 'GradParent',
- 'ProximityParent',
- 'LinearParent',
- 'costObj',
+ "GradParent",
+ "ProximityParent",
+ "LinearParent",
+ "costObj",
)
self.metric_call_period = metric_call_period
# Declaration of observers for metrics
- super().__init__(['cv_metrics'])
+ super().__init__(["cv_metrics"])
for name, dic in self.metrics.items():
observer = MetricObserver(
name,
- dic['metric'],
- dic['mapping'],
- dic['cst_kwargs'],
- dic['early_stopping'],
+ dic["metric"],
+ dic["mapping"],
+ dic["cst_kwargs"],
+ dic["early_stopping"],
)
- self.add_observer('cv_metrics', observer)
+ self.add_observer("cv_metrics", observer)
xp, compute_backend = backend.get_backend(compute_backend)
self.xp = xp
@@ -118,7 +118,7 @@ def metrics(self, metrics):
self._metrics = metrics
else:
raise TypeError(
- 'Metrics must be a dictionary, not {0}.'.format(type(metrics)),
+ "Metrics must be a dictionary, not {0}.".format(type(metrics)),
)
def any_convergence_flag(self):
@@ -132,9 +132,7 @@ def any_convergence_flag(self):
True if any convergence criteria met
"""
- return any(
- obs.converge_flag for obs in self._observers['cv_metrics']
- )
+ return any(obs.converge_flag for obs in self._observers["cv_metrics"])
def copy_data(self, input_data):
"""Copy Data.
@@ -152,10 +150,12 @@ def copy_data(self, input_data):
Copy of input data
"""
- return self.xp.copy(backend.change_backend(
- input_data,
- self.compute_backend,
- ))
+ return self.xp.copy(
+ backend.change_backend(
+ input_data,
+ self.compute_backend,
+ )
+ )
def _check_input_data(self, input_data):
"""Check input data type.
@@ -175,7 +175,7 @@ def _check_input_data(self, input_data):
"""
if not (isinstance(input_data, (self.xp.ndarray, np.ndarray))):
raise TypeError(
- 'Input data must be a numpy array or backend array',
+ "Input data must be a numpy array or backend array",
)
def _check_param(self, param_val):
@@ -195,7 +195,7 @@ def _check_param(self, param_val):
"""
if not isinstance(param_val, float):
- raise TypeError('Algorithm parameter must be a float value.')
+ raise TypeError("Algorithm parameter must be a float value.")
def _check_param_update(self, param_update):
"""Check algorithm parameter update methods.
@@ -213,14 +213,13 @@ def _check_param_update(self, param_update):
For invalid input type
"""
- param_conditions = (
- not isinstance(param_update, type(None))
- and not callable(param_update)
+ param_conditions = not isinstance(param_update, type(None)) and not callable(
+ param_update
)
if param_conditions:
raise TypeError(
- 'Algorithm parameter update must be a callabale function.',
+ "Algorithm parameter update must be a callabale function.",
)
def _check_operator(self, operator):
@@ -239,7 +238,7 @@ def _check_operator(self, operator):
tree = [op_obj.__name__ for op_obj in getmro(operator.__class__)]
if not any(parent in tree for parent in self._op_parents):
- message = '{0} does not inherit an operator parent.'
+ message = "{0} does not inherit an operator parent."
warn(message.format(str(operator.__class__)))
def _compute_metrics(self):
@@ -250,7 +249,7 @@ def _compute_metrics(self):
"""
kwargs = self.get_notify_observers_kwargs()
- self.notify_observers('cv_metrics', **kwargs)
+ self.notify_observers("cv_metrics", **kwargs)
def _iterations(self, max_iter, progbar=None):
"""Iterate method.
@@ -285,7 +284,7 @@ def _iterations(self, max_iter, progbar=None):
if self.converge:
if self.verbose:
- print(' - Converged!')
+ print(" - Converged!")
break
if progbar:
diff --git a/modopt/opt/algorithms/forward_backward.py b/modopt/opt/algorithms/forward_backward.py
index 702799c6..d34125fa 100644
--- a/modopt/opt/algorithms/forward_backward.py
+++ b/modopt/opt/algorithms/forward_backward.py
@@ -52,12 +52,12 @@ class FISTA(object):
"""
_restarting_strategies = (
- 'adaptive', # option 1 in alg 4
- 'adaptive-i',
- 'adaptive-1',
- 'adaptive-ii', # option 2 in alg 4
- 'adaptive-2',
- 'greedy', # alg 5
+ "adaptive", # option 1 in alg 4
+ "adaptive-i",
+ "adaptive-1",
+ "adaptive-ii", # option 2 in alg 4
+ "adaptive-2",
+ "greedy", # alg 5
None, # no restarting
)
@@ -75,24 +75,27 @@ def __init__(
):
if isinstance(a_cd, type(None)):
- self.mode = 'regular'
+ self.mode = "regular"
self.p_lazy = p_lazy
self.q_lazy = q_lazy
self.r_lazy = r_lazy
elif a_cd > 2:
- self.mode = 'CD'
+ self.mode = "CD"
self.a_cd = a_cd
self._n = 0
else:
raise ValueError(
- 'a_cd must either be None (for regular mode) or a number > 2',
+ "a_cd must either be None (for regular mode) or a number > 2",
)
if restart_strategy in self._restarting_strategies:
self._check_restart_params(
- restart_strategy, min_beta, s_greedy, xi_restart,
+ restart_strategy,
+ min_beta,
+ s_greedy,
+ xi_restart,
)
self.restart_strategy = restart_strategy
self.min_beta = min_beta
@@ -100,10 +103,10 @@ def __init__(
self.xi_restart = xi_restart
else:
- message = 'Restarting strategy must be one of {0}.'
+ message = "Restarting strategy must be one of {0}."
raise ValueError(
message.format(
- ', '.join(self._restarting_strategies),
+ ", ".join(self._restarting_strategies),
),
)
self._t_now = 1.0
@@ -155,22 +158,20 @@ def _check_restart_params(
if restart_strategy is None:
return True
- if self.mode != 'regular':
+ if self.mode != "regular":
raise ValueError(
- 'Restarting strategies can only be used with regular mode.',
+ "Restarting strategies can only be used with regular mode.",
)
- greedy_params_check = (
- min_beta is None or s_greedy is None or s_greedy <= 1
- )
+ greedy_params_check = min_beta is None or s_greedy is None or s_greedy <= 1
- if restart_strategy == 'greedy' and greedy_params_check:
+ if restart_strategy == "greedy" and greedy_params_check:
raise ValueError(
- 'You need a min_beta and an s_greedy > 1 for greedy restart.',
+ "You need a min_beta and an s_greedy > 1 for greedy restart.",
)
if xi_restart is None or xi_restart >= 1:
- raise ValueError('You need a xi_restart < 1 for restart.')
+ raise ValueError("You need a xi_restart < 1 for restart.")
return True
@@ -210,12 +211,12 @@ def is_restart(self, z_old, x_new, x_old):
criterion = xp.vdot(z_old - x_new, x_new - x_old) >= 0
if criterion:
- if 'adaptive' in self.restart_strategy:
+ if "adaptive" in self.restart_strategy:
self.r_lazy *= self.xi_restart
- if self.restart_strategy in {'adaptive-ii', 'adaptive-2'}:
+ if self.restart_strategy in {"adaptive-ii", "adaptive-2"}:
self._t_now = 1
- if self.restart_strategy == 'greedy':
+ if self.restart_strategy == "greedy":
cur_delta = xp.linalg.norm(x_new - x_old)
if self._delta0 is None:
self._delta0 = self.s_greedy * cur_delta
@@ -269,17 +270,17 @@ def update_lambda(self, *args, **kwargs):
Implements steps 3 and 4 from algoritm 10.7 in :cite:`bauschke2009`.
"""
- if self.restart_strategy == 'greedy':
+ if self.restart_strategy == "greedy":
return 2
# Steps 3 and 4 from alg.10.7.
self._t_prev = self._t_now
- if self.mode == 'regular':
- sqrt_part = self.r_lazy * self._t_prev ** 2 + self.q_lazy
+ if self.mode == "regular":
+ sqrt_part = self.r_lazy * self._t_prev**2 + self.q_lazy
self._t_now = self.p_lazy + np.sqrt(sqrt_part) * 0.5
- elif self.mode == 'CD':
+ elif self.mode == "CD":
self._t_now = (self._n + self.a_cd - 1) / self.a_cd
self._n += 1
@@ -344,11 +345,11 @@ def __init__(
x,
grad,
prox,
- cost='auto',
+ cost="auto",
beta_param=1.0,
lambda_param=1.0,
beta_update=None,
- lambda_update='fista',
+ lambda_update="fista",
auto_iterate=True,
metric_call_period=5,
metrics=None,
@@ -376,7 +377,7 @@ def __init__(
self._prox = prox
self._linear = linear
- if cost == 'auto':
+ if cost == "auto":
self._cost_func = costObj([self._grad, self._prox])
else:
self._cost_func = cost
@@ -384,7 +385,7 @@ def __init__(
# Check if there is a linear op, needed for metrics in the FB algoritm
if metrics and self._linear is None:
raise ValueError(
- 'When using metrics, you must pass a linear operator',
+ "When using metrics, you must pass a linear operator",
)
if self._linear is None:
@@ -400,7 +401,7 @@ def __init__(
# Set the algorithm parameter update methods
self._check_param_update(beta_update)
self._beta_update = beta_update
- if isinstance(lambda_update, str) and lambda_update == 'fista':
+ if isinstance(lambda_update, str) and lambda_update == "fista":
fista = FISTA(**kwargs)
self._lambda_update = fista.update_lambda
self._is_restart = fista.is_restart
@@ -462,9 +463,8 @@ def _update(self):
# Test cost function for convergence.
if self._cost_func:
- self.converge = (
- self.any_convergence_flag()
- or self._cost_func.get_cost(self._x_new)
+ self.converge = self.any_convergence_flag() or self._cost_func.get_cost(
+ self._x_new
)
def iterate(self, max_iter=150, progbar=None):
@@ -500,9 +500,9 @@ def get_notify_observers_kwargs(self):
"""
return {
- 'x_new': self._linear.adj_op(self._x_new),
- 'z_new': self._z_new,
- 'idx': self.idx,
+ "x_new": self._linear.adj_op(self._x_new),
+ "z_new": self._z_new,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -513,7 +513,7 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
@@ -577,7 +577,7 @@ def __init__(
x,
grad,
prox_list,
- cost='auto',
+ cost="auto",
gamma_param=1.0,
lambda_param=1.0,
gamma_update=None,
@@ -609,7 +609,7 @@ def __init__(
self._prox_list = self.xp.array(prox_list)
self._linear = linear
- if cost == 'auto':
+ if cost == "auto":
self._cost_func = costObj([self._grad] + prox_list)
else:
self._cost_func = cost
@@ -617,7 +617,7 @@ def __init__(
# Check if there is a linear op, needed for metrics in the FB algoritm
if metrics and self._linear is None:
raise ValueError(
- 'When using metrics, you must pass a linear operator',
+ "When using metrics, you must pass a linear operator",
)
if self._linear is None:
@@ -641,9 +641,7 @@ def __init__(
self._set_weights(weights)
# Set initial z
- self._z = self.xp.array([
- self._x_old for i in range(self._prox_list.size)
- ])
+ self._z = self.xp.array([self._x_old for i in range(self._prox_list.size)])
# Automatically run the algorithm
if auto_iterate:
@@ -673,25 +671,25 @@ def _set_weights(self, weights):
self._prox_list.size,
)
elif not isinstance(weights, (list, tuple, np.ndarray)):
- raise TypeError('Weights must be provided as a list.')
+ raise TypeError("Weights must be provided as a list.")
weights = self.xp.array(weights)
if not np.issubdtype(weights.dtype, np.floating):
- raise ValueError('Weights must be list of float values.')
+ raise ValueError("Weights must be list of float values.")
if weights.size != self._prox_list.size:
raise ValueError(
- 'The number of weights must match the number of proximity '
- + 'operators.',
+ "The number of weights must match the number of proximity "
+ + "operators.",
)
expected_weight_sum = 1.0
if self.xp.sum(weights) != expected_weight_sum:
raise ValueError(
- 'Proximity operator weights must sum to 1.0. Current sum of '
- + 'weights = {0}'.format(self.xp.sum(weights)),
+ "Proximity operator weights must sum to 1.0. Current sum of "
+ + "weights = {0}".format(self.xp.sum(weights)),
)
self._weights = weights
@@ -726,9 +724,7 @@ def _update(self):
# Update z values.
for i in range(self._prox_list.size):
- z_temp = (
- 2 * self._x_old - self._z[i] - self._gamma * self._grad.grad
- )
+ z_temp = 2 * self._x_old - self._z[i] - self._gamma * self._grad.grad
z_prox = self._prox_list[i].op(
z_temp,
extra_factor=self._gamma / self._weights[i],
@@ -784,9 +780,9 @@ def get_notify_observers_kwargs(self):
"""
return {
- 'x_new': self._linear.adj_op(self._x_new),
- 'z_new': self._z,
- 'idx': self.idx,
+ "x_new": self._linear.adj_op(self._x_new),
+ "z_new": self._z,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -797,7 +793,7 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
@@ -871,7 +867,7 @@ def __init__(
z,
grad,
prox,
- cost='auto',
+ cost="auto",
linear=None,
beta_param=1.0,
sigma_bar=1.0,
@@ -905,7 +901,7 @@ def __init__(
self._grad = grad
self._prox = prox
self._linear = linear
- if cost == 'auto':
+ if cost == "auto":
self._cost_func = costObj([self._grad, self._prox])
else:
self._cost_func = cost
@@ -918,7 +914,7 @@ def __init__(
for param_val in (beta_param, sigma_bar):
self._check_param(param_val)
if sigma_bar < 0 or sigma_bar > 1:
- raise ValueError('The sigma bar parameter needs to be in [0, 1]')
+ raise ValueError("The sigma bar parameter needs to be in [0, 1]")
self._beta = self.step_size or beta_param
self._sigma_bar = sigma_bar
@@ -944,18 +940,18 @@ def _update(self):
"""
# Step 4 from alg. 3
self._grad.get_grad(self._x_old)
- #self._u_new = self._x_old - self._beta * self._grad.grad
- self._u_new = -self._beta * self._grad.grad
+ # self._u_new = self._x_old - self._beta * self._grad.grad
+ self._u_new = -self._beta * self._grad.grad
self._u_new += self._x_old
# Step 5 from alg. 3
- self._t_new = 0.5 * (1 + self.xp.sqrt(1 + 4 * self._t_old ** 2))
+ self._t_new = 0.5 * (1 + self.xp.sqrt(1 + 4 * self._t_old**2))
# Step 6 from alg. 3
t_shifted_ratio = (self._t_old - 1) / self._t_new
sigma_t_ratio = self._sigma * self._t_old / self._t_new
beta_xi_t_shifted_ratio = t_shifted_ratio * self._beta / self._xi
- self._z = - beta_xi_t_shifted_ratio * (self._x_old - self._z)
+ self._z = -beta_xi_t_shifted_ratio * (self._x_old - self._z)
self._z += self._u_new
self._z += t_shifted_ratio * (self._u_new - self._u_old)
self._z += sigma_t_ratio * (self._u_new - self._x_old)
@@ -968,20 +964,18 @@ def _update(self):
# Restarting and gamma-Decreasing
# Step 9 from alg. 3
- #self._g_new = self._grad.grad - (self._x_new - self._z) / self._xi
- self._g_new = (self._z - self._x_new)
+ # self._g_new = self._grad.grad - (self._x_new - self._z) / self._xi
+ self._g_new = self._z - self._x_new
self._g_new /= self._xi
self._g_new += self._grad.grad
# Step 10 from alg 3.
- #self._y_new = self._x_old - self._beta * self._g_new
- self._y_new = - self._beta * self._g_new
+ # self._y_new = self._x_old - self._beta * self._g_new
+ self._y_new = -self._beta * self._g_new
self._y_new += self._x_old
# Step 11 from alg. 3
- restart_crit = (
- self.xp.vdot(-self._g_new, self._y_new - self._y_old) < 0
- )
+ restart_crit = self.xp.vdot(-self._g_new, self._y_new - self._y_old) < 0
if restart_crit:
self._t_new = 1
self._sigma = 1
@@ -999,9 +993,8 @@ def _update(self):
# Test cost function for convergence.
if self._cost_func:
- self.converge = (
- self.any_convergence_flag()
- or self._cost_func.get_cost(self._x_new)
+ self.converge = self.any_convergence_flag() or self._cost_func.get_cost(
+ self._x_new
)
def iterate(self, max_iter=150, progbar=None):
@@ -1037,14 +1030,14 @@ def get_notify_observers_kwargs(self):
"""
return {
- 'u_new': self._u_new,
- 'x_new': self._linear.adj_op(self._x_new),
- 'y_new': self._y_new,
- 'z_new': self._z,
- 'xi': self._xi,
- 'sigma': self._sigma,
- 't': self._t_new,
- 'idx': self.idx,
+ "u_new": self._u_new,
+ "x_new": self._linear.adj_op(self._x_new),
+ "y_new": self._y_new,
+ "z_new": self._z,
+ "xi": self._xi,
+ "sigma": self._sigma,
+ "t": self._t_new,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -1055,6 +1048,6 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
diff --git a/modopt/opt/algorithms/gradient_descent.py b/modopt/opt/algorithms/gradient_descent.py
index f3fe4b10..d3af1686 100644
--- a/modopt/opt/algorithms/gradient_descent.py
+++ b/modopt/opt/algorithms/gradient_descent.py
@@ -103,7 +103,7 @@ def __init__(
self._check_operator(operator)
self._grad = grad
self._prox = prox
- if cost == 'auto':
+ if cost == "auto":
self._cost_func = costObj([self._grad, self._prox])
else:
self._cost_func = cost
@@ -157,9 +157,8 @@ def _update(self):
self._eta = self._eta_update(self._eta, self.idx)
# Test cost function for convergence.
if self._cost_func:
- self.converge = (
- self.any_convergence_flag()
- or self._cost_func.get_cost(self._x_new)
+ self.converge = self.any_convergence_flag() or self._cost_func.get_cost(
+ self._x_new
)
def _update_grad_dir(self, grad):
@@ -208,10 +207,10 @@ def get_notify_observers_kwargs(self):
"""
return {
- 'x_new': self._x_new,
- 'dir_grad': self._dir_grad,
- 'speed_grad': self._speed_grad,
- 'idx': self.idx,
+ "x_new": self._x_new,
+ "dir_grad": self._dir_grad,
+ "speed_grad": self._speed_grad,
+ "idx": self.idx,
}
def retrieve_outputs(self):
@@ -222,7 +221,7 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
@@ -308,7 +307,7 @@ class RMSpropGradOpt(GenericGradOpt):
def __init__(self, *args, gamma=0.5, **kwargs):
super().__init__(*args, **kwargs)
if gamma < 0 or gamma > 1:
- raise ValueError('gamma is outside of range [0,1]')
+ raise ValueError("gamma is outside of range [0,1]")
self._check_param(gamma)
self._gamma = gamma
@@ -405,9 +404,9 @@ def __init__(self, *args, gamma=0.9, beta=0.9, **kwargs):
self._check_param(gamma)
self._check_param(beta)
if gamma < 0 or gamma >= 1:
- raise ValueError('gamma is outside of range [0,1]')
+ raise ValueError("gamma is outside of range [0,1]")
if beta < 0 or beta >= 1:
- raise ValueError('beta is outside of range [0,1]')
+ raise ValueError("beta is outside of range [0,1]")
self._gamma = gamma
self._beta = beta
self._beta_pow = 1
diff --git a/modopt/opt/algorithms/primal_dual.py b/modopt/opt/algorithms/primal_dual.py
index d5bdd431..179ddf95 100644
--- a/modopt/opt/algorithms/primal_dual.py
+++ b/modopt/opt/algorithms/primal_dual.py
@@ -81,7 +81,7 @@ def __init__(
prox,
prox_dual,
linear=None,
- cost='auto',
+ cost="auto",
reweight=None,
rho=0.5,
sigma=1.0,
@@ -123,12 +123,14 @@ def __init__(
self._linear = Identity()
else:
self._linear = linear
- if cost == 'auto':
- self._cost_func = costObj([
- self._grad,
- self._prox,
- self._prox_dual,
- ])
+ if cost == "auto":
+ self._cost_func = costObj(
+ [
+ self._grad,
+ self._prox,
+ self._prox_dual,
+ ]
+ )
else:
self._cost_func = cost
@@ -187,22 +189,17 @@ def _update(self):
self._grad.get_grad(self._x_old)
x_prox = self._prox.op(
- self._x_old - self._tau * self._grad.grad - self._tau
- * self._linear.adj_op(self._y_old),
+ self._x_old
+ - self._tau * self._grad.grad
+ - self._tau * self._linear.adj_op(self._y_old),
)
# Step 2 from eq.9.
- y_temp = (
- self._y_old + self._sigma
- * self._linear.op(2 * x_prox - self._x_old)
- )
+ y_temp = self._y_old + self._sigma * self._linear.op(2 * x_prox - self._x_old)
- y_prox = (
- y_temp - self._sigma
- * self._prox_dual.op(
- y_temp / self._sigma,
- extra_factor=(1.0 / self._sigma),
- )
+ y_prox = y_temp - self._sigma * self._prox_dual.op(
+ y_temp / self._sigma,
+ extra_factor=(1.0 / self._sigma),
)
# Step 3 from eq.9.
@@ -220,9 +217,8 @@ def _update(self):
# Test cost function for convergence.
if self._cost_func:
- self.converge = (
- self.any_convergence_flag()
- or self._cost_func.get_cost(self._x_new, self._y_new)
+ self.converge = self.any_convergence_flag() or self._cost_func.get_cost(
+ self._x_new, self._y_new
)
def iterate(self, max_iter=150, n_rewightings=1, progbar=None):
@@ -267,7 +263,7 @@ def get_notify_observers_kwargs(self):
The mapping between the iterated variables
"""
- return {'x_new': self._x_new, 'y_new': self._y_new, 'idx': self.idx}
+ return {"x_new": self._x_new, "y_new": self._y_new, "idx": self.idx}
def retrieve_outputs(self):
"""Retrieve outputs.
@@ -277,6 +273,6 @@ def retrieve_outputs(self):
"""
metrics = {}
- for obs in self._observers['cv_metrics']:
+ for obs in self._observers["cv_metrics"]:
metrics[obs.name] = obs.retrieve_metrics()
self.metrics = metrics
diff --git a/modopt/opt/cost.py b/modopt/opt/cost.py
index 688a3959..4bead130 100644
--- a/modopt/opt/cost.py
+++ b/modopt/opt/cost.py
@@ -115,17 +115,17 @@ def _check_cost(self):
# The mean of the first half of the test list
t1 = xp.mean(
- xp.array(self._test_list[len(self._test_list) // 2:]),
+ xp.array(self._test_list[len(self._test_list) // 2 :]),
axis=0,
)
# The mean of the second half of the test list
t2 = xp.mean(
- xp.array(self._test_list[:len(self._test_list) // 2]),
+ xp.array(self._test_list[: len(self._test_list) // 2]),
axis=0,
)
# Calculate the change across the test list
if xp.around(t1, decimals=16):
- cost_diff = (xp.linalg.norm(t1 - t2) / xp.linalg.norm(t1))
+ cost_diff = xp.linalg.norm(t1 - t2) / xp.linalg.norm(t1)
else:
cost_diff = 0
@@ -133,9 +133,9 @@ def _check_cost(self):
self._test_list = []
if self._verbose:
- print(' - CONVERGENCE TEST - ')
- print(' - CHANGE IN COST:', cost_diff)
- print('')
+ print(" - CONVERGENCE TEST - ")
+ print(" - CHANGE IN COST:", cost_diff)
+ print("")
# Check for convergence
return cost_diff <= self._tolerance
@@ -176,8 +176,7 @@ def get_cost(self, *args, **kwargs):
"""
# Check if the cost should be calculated
test_conditions = (
- self._cost_interval is None
- or self._iteration % self._cost_interval
+ self._cost_interval is None or self._iteration % self._cost_interval
)
if test_conditions:
@@ -185,15 +184,15 @@ def get_cost(self, *args, **kwargs):
else:
if self._verbose:
- print(' - ITERATION:', self._iteration)
+ print(" - ITERATION:", self._iteration)
# Calculate the current cost
self.cost = self._calc_cost(verbose=self._verbose, *args, **kwargs)
self._cost_list.append(self.cost)
if self._verbose:
- print(' - COST:', self.cost)
- print('')
+ print(" - COST:", self.cost)
+ print("")
# Test for convergence
test_result = self._check_cost()
@@ -288,13 +287,11 @@ def _check_operators(self):
"""
if not isinstance(self._operators, (list, tuple, np.ndarray)):
- message = (
- 'Input operators must be provided as a list, not {0}'
- )
+ message = "Input operators must be provided as a list, not {0}"
raise TypeError(message.format(type(self._operators)))
for op in self._operators:
- if not hasattr(op, 'cost'):
+ if not hasattr(op, "cost"):
raise ValueError('Operators must contain "cost" method.')
op.cost = check_callable(op.cost)
diff --git a/modopt/opt/gradient.py b/modopt/opt/gradient.py
index caa8fa9d..8d7bacc7 100644
--- a/modopt/opt/gradient.py
+++ b/modopt/opt/gradient.py
@@ -289,7 +289,7 @@ def _cost_method(self, *args, **kwargs):
"""
cost_val = 0.5 * np.linalg.norm(self.obs_data - self.op(args[0])) ** 2
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - DATA FIDELITY (X):', cost_val)
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - DATA FIDELITY (X):", cost_val)
return cost_val
diff --git a/modopt/opt/linear/base.py b/modopt/opt/linear/base.py
index e347970d..79f05d85 100644
--- a/modopt/opt/linear/base.py
+++ b/modopt/opt/linear/base.py
@@ -5,6 +5,7 @@
from modopt.base.types import check_callable
from modopt.base.backend import get_array_module
+
class LinearParent(object):
"""Linear Operator Parent Class.
@@ -69,7 +70,7 @@ def __init__(self):
self.op = lambda input_data: input_data
self.adj_op = self.op
- self.cost= lambda *args, **kwargs: 0
+ self.cost = lambda *args, **kwargs: 0
class MatrixOperator(LinearParent):
@@ -159,14 +160,13 @@ def _check_type(self, input_val):
"""
if not isinstance(input_val, (list, tuple, np.ndarray)):
raise TypeError(
- 'Invalid input type, input must be a list, tuple or numpy '
- + 'array.',
+ "Invalid input type, input must be a list, tuple or numpy " + "array.",
)
input_val = np.array(input_val)
if not input_val.size:
- raise ValueError('Input list is empty.')
+ raise ValueError("Input list is empty.")
return input_val
@@ -200,10 +200,10 @@ def _check_inputs(self, operators, weights):
for operator in operators:
- if not hasattr(operator, 'op'):
+ if not hasattr(operator, "op"):
raise ValueError('Operators must contain "op" method.')
- if not hasattr(operator, 'adj_op'):
+ if not hasattr(operator, "adj_op"):
raise ValueError('Operators must contain "adj_op" method.')
operator.op = check_callable(operator.op)
@@ -214,12 +214,11 @@ def _check_inputs(self, operators, weights):
if weights.size != operators.size:
raise ValueError(
- 'The number of weights must match the number of '
- + 'operators.',
+ "The number of weights must match the number of " + "operators.",
)
if not np.issubdtype(weights.dtype, np.floating):
- raise TypeError('The weights must be a list of float values.')
+ raise TypeError("The weights must be a list of float values.")
return operators, weights
diff --git a/modopt/opt/linear/wavelet.py b/modopt/opt/linear/wavelet.py
index 5feead66..e1608ff4 100644
--- a/modopt/opt/linear/wavelet.py
+++ b/modopt/opt/linear/wavelet.py
@@ -45,7 +45,7 @@ class WaveletConvolve(LinearParent):
"""
- def __init__(self, filters, method='scipy'):
+ def __init__(self, filters, method="scipy"):
self._filters = check_float(filters)
self.op = lambda input_data: filter_convolve_stack(
@@ -61,8 +61,6 @@ def __init__(self, filters, method='scipy'):
)
-
-
class WaveletTransform(LinearParent):
"""
2D and 3D wavelet transform class.
@@ -85,22 +83,28 @@ class WaveletTransform(LinearParent):
**kwargs: extra kwargs for Pywavelet or Pytorch Wavelet
"""
- def __init__(self,
+
+ def __init__(
+ self,
wavelet_name,
shape,
level=4,
mode="symmetric",
compute_backend="numpy",
- **kwargs):
+ **kwargs,
+ ):
if compute_backend == "cupy" and ptwt_available:
- self.operator = CupyWaveletTransform(wavelet=wavelet_name, shape=shape, level=level, mode=mode)
+ self.operator = CupyWaveletTransform(
+ wavelet=wavelet_name, shape=shape, level=level, mode=mode
+ )
elif compute_backend == "numpy" and pywt_available:
- self.operator = CPUWaveletTransform(wavelet_name=wavelet_name, shape=shape, level=level, **kwargs)
+ self.operator = CPUWaveletTransform(
+ wavelet_name=wavelet_name, shape=shape, level=level, **kwargs
+ )
else:
raise ValueError(f"Compute Backend {compute_backend} not available")
-
self.op = self.operator.op
self.adj_op = self.operator.adj_op
@@ -108,6 +112,7 @@ def __init__(self,
def coeffs_shape(self):
return self.operator.coeffs_shape
+
class CPUWaveletTransform(LinearParent):
"""
2D and 3D wavelet transform class.
@@ -286,7 +291,7 @@ def __init__(
self.level = level
self.shape = shape
self.mode = mode
- self.coeffs_shape = None # will be set after op.
+ self.coeffs_shape = None # will be set after op.
def op(self, data: torch.Tensor) -> list[torch.Tensor]:
"""Apply the wavelet decomposition on.
@@ -419,8 +424,10 @@ def __init__(
self.shape = shape
self.mode = mode
- self.operator = TorchWaveletTransform(shape=shape, wavelet=wavelet, level=level,mode=mode)
- self.coeffs_shape = None # will be set after op
+ self.operator = TorchWaveletTransform(
+ shape=shape, wavelet=wavelet, level=level, mode=mode
+ )
+ self.coeffs_shape = None # will be set after op
def op(self, data: cp.array) -> cp.ndarray:
"""Define the wavelet operator.
diff --git a/modopt/opt/proximity.py b/modopt/opt/proximity.py
index fc81a753..b562f77a 100644
--- a/modopt/opt/proximity.py
+++ b/modopt/opt/proximity.py
@@ -140,8 +140,8 @@ def _cost_method(self, *args, **kwargs):
``0.0``
"""
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - Min (X):', np.min(args[0]))
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - Min (X):", np.min(args[0]))
return 0
@@ -167,7 +167,7 @@ class SparseThreshold(ProximityParent):
"""
- def __init__(self, linear, weights, thresh_type='soft'):
+ def __init__(self, linear, weights, thresh_type="soft"):
self._linear = linear
self.weights = weights
@@ -221,8 +221,8 @@ def _cost_method(self, *args, **kwargs):
if isinstance(cost_val, xp.ndarray):
cost_val = cost_val.item()
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - L1 NORM (X):', cost_val)
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - L1 NORM (X):", cost_val)
return cost_val
@@ -273,8 +273,8 @@ class LowRankMatrix(ProximityParent):
def __init__(
self,
threshold,
- thresh_type='soft',
- lowr_type='standard',
+ thresh_type="soft",
+ lowr_type="standard",
initial_rank=None,
operator=None,
):
@@ -315,13 +315,13 @@ def _op_method(self, input_data, extra_factor=1.0, rank=None):
"""
# Update threshold with extra factor.
threshold = self.thresh * extra_factor
- if self.lowr_type == 'standard' and self.rank is None and rank is None:
+ if self.lowr_type == "standard" and self.rank is None and rank is None:
data_matrix = svd_thresh(
cube2matrix(input_data),
threshold,
thresh_type=self.thresh_type,
)
- elif self.lowr_type == 'standard':
+ elif self.lowr_type == "standard":
data_matrix, update_rank = svd_thresh_coef_fast(
cube2matrix(input_data),
threshold,
@@ -331,7 +331,7 @@ def _op_method(self, input_data, extra_factor=1.0, rank=None):
)
self.rank = update_rank # save for future use
- elif self.lowr_type == 'ngole':
+ elif self.lowr_type == "ngole":
data_matrix = svd_thresh_coef(
cube2matrix(input_data),
self.operator,
@@ -339,7 +339,7 @@ def _op_method(self, input_data, extra_factor=1.0, rank=None):
thresh_type=self.thresh_type,
)
else:
- raise ValueError('lowr_type should be standard or ngole')
+ raise ValueError("lowr_type should be standard or ngole")
# Return updated data.
return matrix2cube(data_matrix, input_data.shape[1:])
@@ -365,8 +365,8 @@ def _cost_method(self, *args, **kwargs):
"""
cost_val = self.thresh * nuclear_norm(cube2matrix(args[0]))
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - NUCLEAR NORM (X):', cost_val)
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - NUCLEAR NORM (X):", cost_val)
return cost_val
@@ -506,19 +506,19 @@ def _check_operators(self, operators):
"""
if not isinstance(operators, (list, tuple, np.ndarray)):
raise TypeError(
- 'Invalid input type, operators must be a list, tuple or '
- + 'numpy array.',
+ "Invalid input type, operators must be a list, tuple or "
+ + "numpy array.",
)
operators = np.array(operators)
if not operators.size:
- raise ValueError('Operator list is empty.')
+ raise ValueError("Operator list is empty.")
for operator in operators:
- if not hasattr(operator, 'op'):
+ if not hasattr(operator, "op"):
raise ValueError('Operators must contain "op" method.')
- if not hasattr(operator, 'cost'):
+ if not hasattr(operator, "cost"):
raise ValueError('Operators must contain "cost" method.')
operator.op = check_callable(operator.op)
operator.cost = check_callable(operator.cost)
@@ -573,10 +573,12 @@ def _cost_method(self, *args, **kwargs):
Combined cost components
"""
- return np.sum([
- operator.cost(input_data)
- for operator, input_data in zip(self.operators, args[0])
- ])
+ return np.sum(
+ [
+ operator.cost(input_data)
+ for operator, input_data in zip(self.operators, args[0])
+ ]
+ )
class OrderedWeightedL1Norm(ProximityParent):
@@ -617,16 +619,16 @@ class OrderedWeightedL1Norm(ProximityParent):
def __init__(self, weights):
if not import_sklearn: # pragma: no cover
raise ImportError(
- 'Required version of Scikit-Learn package not found see '
- + 'documentation for details: '
- + 'https://cea-cosmic.github.io/ModOpt/#optional-packages',
+ "Required version of Scikit-Learn package not found see "
+ + "documentation for details: "
+ + "https://cea-cosmic.github.io/ModOpt/#optional-packages",
)
if np.max(np.diff(weights)) > 0:
- raise ValueError('Weights must be non increasing')
+ raise ValueError("Weights must be non increasing")
self.weights = weights.flatten()
if (self.weights < 0).any():
raise ValueError(
- 'The weight values must be provided in descending order',
+ "The weight values must be provided in descending order",
)
self.op = self._op_method
self.cost = self._cost_method
@@ -664,7 +666,9 @@ def _op_method(self, input_data, extra_factor=1.0):
# Projection onto the monotone non-negative cone using
# isotonic_regression
data_abs = isotonic_regression(
- data_abs - threshold, y_min=0, increasing=False,
+ data_abs - threshold,
+ y_min=0,
+ increasing=False,
)
# Unsorting the data
@@ -672,7 +676,7 @@ def _op_method(self, input_data, extra_factor=1.0):
data_abs_unsorted[data_abs_sort_idx] = data_abs
# Putting the sign back
- with np.errstate(invalid='ignore'):
+ with np.errstate(invalid="ignore"):
sign_data = data_squeezed / np.abs(data_squeezed)
# Removing NAN caused by the sign
@@ -702,8 +706,8 @@ def _cost_method(self, *args, **kwargs):
self.weights * np.sort(np.squeeze(np.abs(args[0]))[::-1]),
)
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - OWL NORM (X):', cost_val)
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - OWL NORM (X):", cost_val)
return cost_val
@@ -734,7 +738,7 @@ class Ridge(ProximityParent):
"""
- def __init__(self, linear, weights, thresh_type='soft'):
+ def __init__(self, linear, weights, thresh_type="soft"):
self._linear = linear
self.weights = weights
@@ -786,8 +790,8 @@ def _cost_method(self, *args, **kwargs):
np.abs(self.weights * self._linear.op(args[0]) ** 2),
)
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - L2 NORM (X):', cost_val)
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - L2 NORM (X):", cost_val)
return cost_val
@@ -848,8 +852,8 @@ def _op_method(self, input_data, extra_factor=1.0):
"""
soft_threshold = self.beta * extra_factor
- normalization = (self.alpha * 2 * extra_factor + 1)
- return thresh(input_data, soft_threshold, 'soft') / normalization
+ normalization = self.alpha * 2 * extra_factor + 1
+ return thresh(input_data, soft_threshold, "soft") / normalization
def _cost_method(self, *args, **kwargs):
"""Calculate Ridge component of the cost.
@@ -875,8 +879,8 @@ def _cost_method(self, *args, **kwargs):
+ np.abs(self.beta * self._linear.op(args[0])),
)
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - ELASTIC NET (X):', cost_val)
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - ELASTIC NET (X):", cost_val)
return cost_val
@@ -942,7 +946,7 @@ def k_value(self):
def k_value(self, k_val):
if k_val < 1:
raise ValueError(
- 'The k parameter should be greater or equal than 1',
+ "The k parameter should be greater or equal than 1",
)
self._k_value = k_val
@@ -987,7 +991,7 @@ def _compute_theta(self, input_data, alpha, extra_factor=1.0):
alpha_beta = alpha_input - self.beta * extra_factor
theta = alpha_beta * ((alpha_beta <= 1) & (alpha_beta >= 0))
theta = np.nan_to_num(theta)
- theta += (alpha_input > (self.beta * extra_factor + 1))
+ theta += alpha_input > (self.beta * extra_factor + 1)
return theta
def _interpolate(self, alpha0, alpha1, sum0, sum1):
@@ -1096,11 +1100,11 @@ def _binary_search(self, input_data, alpha, extra_factor=1.0):
extra_factor,
).sum()
- if (np.abs(sum0 - self._k_value) <= tolerance):
+ if np.abs(sum0 - self._k_value) <= tolerance:
found = True
midpoint = first_idx
- if (np.abs(sum1 - self._k_value) <= tolerance):
+ if np.abs(sum1 - self._k_value) <= tolerance:
found = True
midpoint = last_idx - 1
# -1 because output is index such that
@@ -1145,13 +1149,17 @@ def _binary_search(self, input_data, alpha, extra_factor=1.0):
if found:
return (
- midpoint, alpha[midpoint], alpha[midpoint + 1], sum0, sum1,
+ midpoint,
+ alpha[midpoint],
+ alpha[midpoint + 1],
+ sum0,
+ sum1,
)
raise ValueError(
- 'Cannot find the coordinate of alpha (i) such '
- + 'that sum(theta(alpha[i])) =< k and '
- + 'sum(theta(alpha[i+1])) >= k ',
+ "Cannot find the coordinate of alpha (i) such "
+ + "that sum(theta(alpha[i])) =< k and "
+ + "sum(theta(alpha[i+1])) >= k ",
)
def _find_alpha(self, input_data, extra_factor=1.0):
@@ -1177,13 +1185,11 @@ def _find_alpha(self, input_data, extra_factor=1.0):
# Computes the alpha^i points line 1 in Algorithm 1.
alpha = np.zeros((data_size * 2))
data_abs = np.abs(input_data)
- alpha[:data_size] = (
- (self.beta * extra_factor)
- / (data_abs + sys.float_info.epsilon)
+ alpha[:data_size] = (self.beta * extra_factor) / (
+ data_abs + sys.float_info.epsilon
)
- alpha[data_size:] = (
- (self.beta * extra_factor + 1)
- / (data_abs + sys.float_info.epsilon)
+ alpha[data_size:] = (self.beta * extra_factor + 1) / (
+ data_abs + sys.float_info.epsilon
)
alpha = np.sort(np.unique(alpha))
@@ -1220,8 +1226,8 @@ def _op_method(self, input_data, extra_factor=1.0):
k_max = np.prod(data_shape)
if self._k_value > k_max:
warn(
- 'K value of the K-support norm is greater than the input '
- + 'dimension, its value will be set to {0}'.format(k_max),
+ "K value of the K-support norm is greater than the input "
+ + "dimension, its value will be set to {0}".format(k_max),
)
self._k_value = k_max
@@ -1233,8 +1239,7 @@ def _op_method(self, input_data, extra_factor=1.0):
# Computes line 5. in Algorithm 1.
rslt = np.nan_to_num(
- (input_data.flatten() * theta)
- / (theta + self.beta * extra_factor),
+ (input_data.flatten() * theta) / (theta + self.beta * extra_factor),
)
return rslt.reshape(data_shape)
@@ -1275,15 +1280,13 @@ def _find_q(self, sorted_data):
found = True
q_val = 0
- elif (
- (sorted_data[self._k_value - 1:].sum())
- <= sorted_data[self._k_value - 1]
- ):
+ elif (sorted_data[self._k_value - 1 :].sum()) <= sorted_data[self._k_value - 1]:
found = True
q_val = self._k_value - 1
while (
- not found and not cnt == self._k_value
+ not found
+ and not cnt == self._k_value
and (first_idx <= last_idx < self._k_value)
):
@@ -1291,9 +1294,7 @@ def _find_q(self, sorted_data):
cnt += 1
l1_part = sorted_data[q_val:].sum() / (self._k_value - q_val)
- if (
- sorted_data[q_val + 1] <= l1_part <= sorted_data[q_val]
- ):
+ if sorted_data[q_val + 1] <= l1_part <= sorted_data[q_val]:
found = True
else:
@@ -1328,15 +1329,12 @@ def _cost_method(self, *args, **kwargs):
data_abs = data_abs[ix] # Sorted absolute value of the data
q_val = self._find_q(data_abs)
cost_val = (
- (
- np.sum(data_abs[:q_val] ** 2) * 0.5
- + np.sum(data_abs[q_val:]) ** 2
- / (self._k_value - q_val)
- ) * self.beta
- )
+ np.sum(data_abs[:q_val] ** 2) * 0.5
+ + np.sum(data_abs[q_val:]) ** 2 / (self._k_value - q_val)
+ ) * self.beta
- if 'verbose' in kwargs and kwargs['verbose']:
- print(' - K-SUPPORT NORM (X):', cost_val)
+ if "verbose" in kwargs and kwargs["verbose"]:
+ print(" - K-SUPPORT NORM (X):", cost_val)
return cost_val
diff --git a/modopt/opt/reweight.py b/modopt/opt/reweight.py
index 8c4f2449..7ff9aac4 100644
--- a/modopt/opt/reweight.py
+++ b/modopt/opt/reweight.py
@@ -81,7 +81,7 @@ def reweight(self, input_data):
"""
if self.verbose:
- print(' - Reweighting: {0}'.format(self._rw_num))
+ print(" - Reweighting: {0}".format(self._rw_num))
self._rw_num += 1
@@ -89,7 +89,7 @@ def reweight(self, input_data):
if input_data.shape != self.weights.shape:
raise ValueError(
- 'Input data must have the same shape as the initial weights.',
+ "Input data must have the same shape as the initial weights.",
)
thresh_weights = self.thresh_factor * self.original_weights
diff --git a/modopt/plot/__init__.py b/modopt/plot/__init__.py
index 28d60be6..da6e096c 100644
--- a/modopt/plot/__init__.py
+++ b/modopt/plot/__init__.py
@@ -8,4 +8,4 @@
"""
-__all__ = ['cost_plot']
+__all__ = ["cost_plot"]
diff --git a/modopt/plot/cost_plot.py b/modopt/plot/cost_plot.py
index aa855eaa..36958450 100644
--- a/modopt/plot/cost_plot.py
+++ b/modopt/plot/cost_plot.py
@@ -37,20 +37,20 @@ def plotCost(cost_list, output=None):
"""
if import_fail:
- raise ImportError('Matplotlib package not found')
+ raise ImportError("Matplotlib package not found")
else:
if isinstance(output, type(None)):
- file_name = 'cost_function.png'
+ file_name = "cost_function.png"
else:
- file_name = '{0}_cost_function.png'.format(output)
+ file_name = "{0}_cost_function.png".format(output)
plt.figure()
- plt.plot(np.log10(cost_list), 'r-')
- plt.title('Cost Function')
- plt.xlabel('Iteration')
- plt.ylabel(r'$\log_{10}$ Cost')
+ plt.plot(np.log10(cost_list), "r-")
+ plt.title("Cost Function")
+ plt.xlabel("Iteration")
+ plt.ylabel(r"$\log_{10}$ Cost")
plt.savefig(file_name)
plt.close()
- print(' - Saving cost function data to:', file_name)
+ print(" - Saving cost function data to:", file_name)
diff --git a/modopt/signal/__init__.py b/modopt/signal/__init__.py
index dbc6d053..09b2d2c4 100644
--- a/modopt/signal/__init__.py
+++ b/modopt/signal/__init__.py
@@ -8,4 +8,4 @@
"""
-__all__ = ['filter', 'noise', 'positivity', 'svd', 'validation', 'wavelet']
+__all__ = ["filter", "noise", "positivity", "svd", "validation", "wavelet"]
diff --git a/modopt/signal/filter.py b/modopt/signal/filter.py
index 84dd8160..2c7d8626 100644
--- a/modopt/signal/filter.py
+++ b/modopt/signal/filter.py
@@ -81,7 +81,7 @@ def mex_hat(data_point, sigma):
sigma = check_float(sigma)
xs = (data_point / sigma) ** 2
- factor = 2 * (3 * sigma) ** -0.5 * np.pi ** -0.25
+ factor = 2 * (3 * sigma) ** -0.5 * np.pi**-0.25
return factor * (1 - xs) * np.exp(-0.5 * xs)
diff --git a/modopt/signal/noise.py b/modopt/signal/noise.py
index a59d5553..fadf5308 100644
--- a/modopt/signal/noise.py
+++ b/modopt/signal/noise.py
@@ -15,7 +15,7 @@
from modopt.base.backend import get_array_module
-def add_noise(input_data, sigma=1.0, noise_type='gauss'):
+def add_noise(input_data, sigma=1.0, noise_type="gauss"):
"""Add noise to data.
This method adds Gaussian or Poisson noise to the input data.
@@ -70,7 +70,7 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'):
"""
input_data = np.array(input_data)
- if noise_type not in {'gauss', 'poisson'}:
+ if noise_type not in {"gauss", "poisson"}:
raise ValueError(
'Invalid noise type. Options are "gauss" or "poisson"',
)
@@ -78,14 +78,13 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'):
if isinstance(sigma, (list, tuple, np.ndarray)):
if len(sigma) != input_data.shape[0]:
raise ValueError(
- 'Number of sigma values must match first dimension of input '
- + 'data',
+ "Number of sigma values must match first dimension of input " + "data",
)
- if noise_type == 'gauss':
+ if noise_type == "gauss":
random = np.random.randn(*input_data.shape)
- elif noise_type == 'poisson':
+ elif noise_type == "poisson":
random = np.random.poisson(np.abs(input_data))
if isinstance(sigma, (int, float)):
@@ -96,7 +95,7 @@ def add_noise(input_data, sigma=1.0, noise_type='gauss'):
return input_data + noise
-def thresh(input_data, threshold, threshold_type='hard'):
+def thresh(input_data, threshold, threshold_type="hard"):
r"""Threshold data.
This method performs hard or soft thresholding on the input data.
@@ -169,12 +168,12 @@ def thresh(input_data, threshold, threshold_type='hard'):
input_data = xp.array(input_data)
- if threshold_type not in {'hard', 'soft'}:
+ if threshold_type not in {"hard", "soft"}:
raise ValueError(
'Invalid threshold type. Options are "hard" or "soft"',
)
- if threshold_type == 'soft':
+ if threshold_type == "soft":
denominator = xp.maximum(xp.finfo(np.float64).eps, xp.abs(input_data))
max_value = xp.maximum((1.0 - threshold / denominator), 0)
diff --git a/modopt/signal/positivity.py b/modopt/signal/positivity.py
index c19ba62c..5c4b795b 100644
--- a/modopt/signal/positivity.py
+++ b/modopt/signal/positivity.py
@@ -47,7 +47,7 @@ def pos_recursive(input_data):
Positive coefficients
"""
- if input_data.dtype == 'O':
+ if input_data.dtype == "O":
res = np.array([pos_recursive(elem) for elem in input_data], dtype="object")
else:
@@ -97,15 +97,15 @@ def positive(input_data, ragged=False):
"""
if not isinstance(input_data, (int, float, list, tuple, np.ndarray)):
raise TypeError(
- 'Invalid data type, input must be `int`, `float`, `list`, '
- + '`tuple` or `np.ndarray`.',
+ "Invalid data type, input must be `int`, `float`, `list`, "
+ + "`tuple` or `np.ndarray`.",
)
if isinstance(input_data, (int, float)):
return pos_thresh(input_data)
if ragged:
- input_data = np.array(input_data, dtype='object')
+ input_data = np.array(input_data, dtype="object")
else:
input_data = np.array(input_data)
diff --git a/modopt/signal/svd.py b/modopt/signal/svd.py
index f3d40a51..cc204817 100644
--- a/modopt/signal/svd.py
+++ b/modopt/signal/svd.py
@@ -52,8 +52,8 @@ def find_n_pc(u_vec, factor=0.5):
"""
if np.sqrt(u_vec.shape[0]) % 1:
raise ValueError(
- 'Invalid left singular vector. The size of the first '
- + 'dimenion of ``u_vec`` must be perfect square.',
+ "Invalid left singular vector. The size of the first "
+ + "dimenion of ``u_vec`` must be perfect square.",
)
# Get the shape of the array
@@ -69,13 +69,12 @@ def find_n_pc(u_vec, factor=0.5):
]
# Return the required number of principal components.
- return np.sum([
- (
- u_val[tuple(zip(array_shape // 2))] ** 2 <= factor
- * np.sum(u_val ** 2),
- )
- for u_val in u_auto
- ])
+ return np.sum(
+ [
+ (u_val[tuple(zip(array_shape // 2))] ** 2 <= factor * np.sum(u_val**2),)
+ for u_val in u_auto
+ ]
+ )
def calculate_svd(input_data):
@@ -101,17 +100,17 @@ def calculate_svd(input_data):
"""
if (not isinstance(input_data, np.ndarray)) or (input_data.ndim != 2):
- raise TypeError('Input data must be a 2D np.ndarray.')
+ raise TypeError("Input data must be a 2D np.ndarray.")
return svd(
input_data,
check_finite=False,
- lapack_driver='gesvd',
+ lapack_driver="gesvd",
full_matrices=False,
)
-def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
+def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type="hard"):
"""Threshold the singular values.
This method thresholds the input data using singular value decomposition.
@@ -156,16 +155,11 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
"""
less_than_zero = isinstance(n_pc, int) and n_pc <= 0
- str_not_all = isinstance(n_pc, str) and n_pc != 'all'
+ str_not_all = isinstance(n_pc, str) and n_pc != "all"
- if (
- (not isinstance(n_pc, (int, str, type(None))))
- or less_than_zero
- or str_not_all
- ):
+ if (not isinstance(n_pc, (int, str, type(None)))) or less_than_zero or str_not_all:
raise ValueError(
- 'Invalid value for "n_pc", specify a positive integer value or '
- + '"all"',
+ 'Invalid value for "n_pc", specify a positive integer value or ' + '"all"',
)
# Get SVD of input data.
@@ -176,15 +170,14 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
# Find the required number of principal components if not specified.
if isinstance(n_pc, type(None)):
n_pc = find_n_pc(u_vec, factor=0.1)
- print('xxxx', n_pc, u_vec)
+ print("xxxx", n_pc, u_vec)
# If the number of PCs is too large use all of the singular values.
- if (
- (isinstance(n_pc, int) and n_pc >= s_values.size)
- or (isinstance(n_pc, str) and n_pc == 'all')
+ if (isinstance(n_pc, int) and n_pc >= s_values.size) or (
+ isinstance(n_pc, str) and n_pc == "all"
):
n_pc = s_values.size
- warn('Using all singular values.')
+ warn("Using all singular values.")
threshold = s_values[n_pc - 1]
@@ -192,7 +185,7 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
s_new = thresh(s_values, threshold, thresh_type)
if np.all(s_new == s_values):
- warn('No change to singular values.')
+ warn("No change to singular values.")
# Diagonalize the svd
s_new = np.diag(s_new)
@@ -206,7 +199,7 @@ def svd_thresh_coef_fast(
threshold,
n_vals=-1,
extra_vals=5,
- thresh_type='hard',
+ thresh_type="hard",
):
"""Threshold the singular values coefficients.
@@ -241,7 +234,7 @@ def svd_thresh_coef_fast(
ok = False
while not ok:
(u_vec, s_values, v_vec) = svds(input_data, k=n_vals)
- ok = (s_values[0] <= threshold or n_vals == min(input_data.shape) - 1)
+ ok = s_values[0] <= threshold or n_vals == min(input_data.shape) - 1
n_vals = min(n_vals + extra_vals, *input_data.shape)
s_values = thresh(
@@ -259,7 +252,7 @@ def svd_thresh_coef_fast(
)
-def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'):
+def svd_thresh_coef(input_data, operator, threshold, thresh_type="hard"):
"""Threshold the singular values coefficients.
This method thresholds the input data using singular value decomposition.
@@ -287,7 +280,7 @@ def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'):
"""
if not callable(operator):
- raise TypeError('Operator must be a callable function.')
+ raise TypeError("Operator must be a callable function.")
# Get SVD of data matrix
u_vec, s_values, v_vec = calculate_svd(input_data)
@@ -302,10 +295,9 @@ def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'):
array_shape = np.repeat(int(np.sqrt(u_vec.shape[0])), 2)
# Compute threshold matrix.
- ti = np.array([
- np.linalg.norm(elem)
- for elem in operator(matrix2cube(u_vec, array_shape))
- ])
+ ti = np.array(
+ [np.linalg.norm(elem) for elem in operator(matrix2cube(u_vec, array_shape))]
+ )
threshold *= np.repeat(ti, a_matrix.shape[1]).reshape(a_matrix.shape)
# Threshold coefficients.
diff --git a/modopt/signal/validation.py b/modopt/signal/validation.py
index 422a987b..68c1e726 100644
--- a/modopt/signal/validation.py
+++ b/modopt/signal/validation.py
@@ -54,7 +54,7 @@ def transpose_test(
"""
if not callable(operator) or not callable(operator_t):
- raise TypeError('The input operators must be callable functions.')
+ raise TypeError("The input operators must be callable functions.")
if isinstance(y_shape, type(None)):
y_shape = x_shape
@@ -73,4 +73,4 @@ def transpose_test(
x_mty = np.sum(np.multiply(x_val, operator_t(y_val, y_args)))
# Test the difference between the two.
- print(' - | - | =', np.abs(mx_y - x_mty))
+ print(" - | - | =", np.abs(mx_y - x_mty))
diff --git a/modopt/signal/wavelet.py b/modopt/signal/wavelet.py
index bc4ffc70..72d608e7 100644
--- a/modopt/signal/wavelet.py
+++ b/modopt/signal/wavelet.py
@@ -58,20 +58,20 @@ def execute(command_line):
"""
if not isinstance(command_line, str):
- raise TypeError('Command line must be a string.')
+ raise TypeError("Command line must be a string.")
command = command_line.split()
process = sp.Popen(command, stdout=sp.PIPE, stderr=sp.PIPE)
stdout, stderr = process.communicate()
- return stdout.decode('utf-8'), stderr.decode('utf-8')
+ return stdout.decode("utf-8"), stderr.decode("utf-8")
def call_mr_transform(
input_data,
- opt='',
- path='./',
+ opt="",
+ path="./",
remove_files=True,
): # pragma: no cover
"""Call ``mr_transform``.
@@ -127,26 +127,23 @@ def call_mr_transform(
"""
if not import_astropy:
- raise ImportError('Astropy package not found.')
+ raise ImportError("Astropy package not found.")
if (not isinstance(input_data, np.ndarray)) or (input_data.ndim != 2):
- raise ValueError('Input data must be a 2D numpy array.')
+ raise ValueError("Input data must be a 2D numpy array.")
- executable = 'mr_transform'
+ executable = "mr_transform"
# Make sure mr_transform is installed.
is_executable(executable)
# Create a unique string using the current date and time.
- unique_string = (
- datetime.now().strftime('%Y.%m.%d_%H.%M.%S')
- + str(getrandbits(128))
- )
+ unique_string = datetime.now().strftime("%Y.%m.%d_%H.%M.%S") + str(getrandbits(128))
# Set the output file names.
- file_name = '{0}mr_temp_{1}'.format(path, unique_string)
- file_fits = '{0}.fits'.format(file_name)
- file_mr = '{0}.mr'.format(file_name)
+ file_name = "{0}mr_temp_{1}".format(path, unique_string)
+ file_fits = "{0}.fits".format(file_name)
+ file_mr = "{0}.mr".format(file_name)
# Write the input data to a fits file.
fits.writeto(file_fits, input_data)
@@ -155,15 +152,15 @@ def call_mr_transform(
opt = opt.split()
# Prepare command and execute it
- command_line = ' '.join([executable] + opt + [file_fits, file_mr])
+ command_line = " ".join([executable] + opt + [file_fits, file_mr])
stdout, _ = execute(command_line)
# Check for errors
- if any(word in stdout for word in ('bad', 'Error', 'Sorry')):
+ if any(word in stdout for word in ("bad", "Error", "Sorry")):
remove(file_fits)
message = '{0} raised following exception: "{1}"'
raise RuntimeError(
- message.format(executable, stdout.rstrip('\n')),
+ message.format(executable, stdout.rstrip("\n")),
)
# Retrieve wavelet transformed data.
@@ -198,12 +195,12 @@ def trim_filter(filter_array):
min_idx = np.min(non_zero_indices, axis=-1)
max_idx = np.max(non_zero_indices, axis=-1)
- return filter_array[min_idx[0]:max_idx[0] + 1, min_idx[1]:max_idx[1] + 1]
+ return filter_array[min_idx[0] : max_idx[0] + 1, min_idx[1] : max_idx[1] + 1]
def get_mr_filters(
data_shape,
- opt='',
+ opt="",
coarse=False,
trim=False,
): # pragma: no cover
@@ -256,7 +253,7 @@ def get_mr_filters(
return mr_filters[:-1]
-def filter_convolve(input_data, filters, filter_rot=False, method='scipy'):
+def filter_convolve(input_data, filters, filter_rot=False, method="scipy"):
"""Filter convolve.
This method convolves the input image with the wavelet filters.
@@ -315,16 +312,14 @@ def filter_convolve(input_data, filters, filter_rot=False, method='scipy'):
axis=0,
)
- return np.array([
- convolve(input_data, filt, method=method) for filt in filters
- ])
+ return np.array([convolve(input_data, filt, method=method) for filt in filters])
def filter_convolve_stack(
input_data,
filters,
filter_rot=False,
- method='scipy',
+ method="scipy",
):
"""Filter convolve.
@@ -366,7 +361,9 @@ def filter_convolve_stack(
"""
# Return the convolved data cube.
- return np.array([
- filter_convolve(elem, filters, filter_rot=filter_rot, method=method)
- for elem in input_data
- ])
+ return np.array(
+ [
+ filter_convolve(elem, filters, filter_rot=filter_rot, method=method)
+ for elem in input_data
+ ]
+ )
diff --git a/modopt/tests/test_algorithms.py b/modopt/tests/test_algorithms.py
index 5671b8e3..ca0cd666 100644
--- a/modopt/tests/test_algorithms.py
+++ b/modopt/tests/test_algorithms.py
@@ -111,8 +111,7 @@ class AlgoCases:
]
)
def case_forward_backward(self, kwargs, idty, use_metrics):
- """Forward Backward case.
- """
+ """Forward Backward case."""
update_kwargs = build_kwargs(kwargs, use_metrics)
algo = algorithms.ForwardBackward(
self.data1,
@@ -242,9 +241,11 @@ def case_grad(self, GradDescent, use_metrics, idty):
)
algo.iterate()
return algo, update_kwargs
- @parametrize(admm=[algorithms.ADMM,algorithms.FastADMM])
+
+ @parametrize(admm=[algorithms.ADMM, algorithms.FastADMM])
def case_admm(self, admm, use_metrics, idty):
"""ADMM setup."""
+
def optim1(init, obs):
return obs
@@ -265,6 +266,7 @@ def optim2(init, obs):
algo.iterate()
return algo, update_kwargs
+
@parametrize_with_cases("algo, kwargs", cases=AlgoCases)
def test_algo(algo, kwargs):
"""Test algorithms."""
diff --git a/modopt/tests/test_base.py b/modopt/tests/test_base.py
index e32ff94b..298253d6 100644
--- a/modopt/tests/test_base.py
+++ b/modopt/tests/test_base.py
@@ -5,6 +5,7 @@
Samuel Farrens
Pierre-Antoine Comby
"""
+
import numpy as np
import numpy.testing as npt
import pytest
diff --git a/modopt/tests/test_helpers/utils.py b/modopt/tests/test_helpers/utils.py
index d8227640..895b2371 100644
--- a/modopt/tests/test_helpers/utils.py
+++ b/modopt/tests/test_helpers/utils.py
@@ -4,6 +4,7 @@
:Author: Pierre-Antoine Comby
"""
+
import pytest
diff --git a/modopt/tests/test_math.py b/modopt/tests/test_math.py
index ea177b15..e7536b03 100644
--- a/modopt/tests/test_math.py
+++ b/modopt/tests/test_math.py
@@ -6,6 +6,7 @@
Samuel Farrens
Pierre-Antoine Comby
"""
+
import pytest
from test_helpers import failparam, skipparam
diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py
index 1c2e7824..e77074ab 100644
--- a/modopt/tests/test_opt.py
+++ b/modopt/tests/test_opt.py
@@ -193,9 +193,13 @@ def case_linear_wavelet_transform(self, compute_backend="numpy"):
level=2,
)
data_op = np.arange(64).reshape(8, 8).astype(float)
- res_op, slices, shapes = pywt.ravel_coeffs(pywt.wavedecn(data_op, "haar", level=2))
+ res_op, slices, shapes = pywt.ravel_coeffs(
+ pywt.wavedecn(data_op, "haar", level=2)
+ )
data_adj_op = linop.op(data_op)
- res_adj_op = pywt.waverecn(pywt.unravel_coeffs(data_adj_op, slices, shapes, "wavedecn"), "haar")
+ res_adj_op = pywt.waverecn(
+ pywt.unravel_coeffs(data_adj_op, slices, shapes, "wavedecn"), "haar"
+ )
return linop, data_op, data_adj_op, res_op, res_adj_op
@parametrize(weights=[[1.0, 1.0], None])
diff --git a/modopt/tests/test_signal.py b/modopt/tests/test_signal.py
index 202e541b..b3787fc6 100644
--- a/modopt/tests/test_signal.py
+++ b/modopt/tests/test_signal.py
@@ -17,6 +17,7 @@
class TestFilter:
"""Test filter module"""
+
@pytest.mark.parametrize(
("norm", "result"), [(True, 0.24197072451914337), (False, 0.60653065971263342)]
)
@@ -24,7 +25,6 @@ def test_gaussian_filter(self, norm, result):
"""Test gaussian filter."""
npt.assert_almost_equal(filter.gaussian_filter(1, 1, norm=norm), result)
-
def test_mex_hat(self):
"""Test mexican hat filter."""
npt.assert_almost_equal(
@@ -32,7 +32,6 @@ def test_mex_hat(self):
-0.35213905225713371,
)
-
def test_mex_hat_dir(self):
"""Test directional mexican hat filter."""
npt.assert_almost_equal(
@@ -86,13 +85,16 @@ def test_thresh(self, threshold_type, result):
noise.thresh(self.data1, 5, threshold_type=threshold_type), result
)
+
class TestPositivity:
"""Test positivity module."""
+
data1 = np.arange(9).reshape(3, 3).astype(float)
data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]])
data5 = np.array(
[[0, 0, 0], [0, 0, 0], [1.0, 2.0, 3.0]],
)
+
@pytest.mark.parametrize(
("value", "expected"),
[
@@ -231,6 +233,7 @@ def test_svd_thresh_coef(self, data, operator):
# TODO test_svd_thresh_coef_fast
+
class TestValidation:
"""Test validation Module."""
From 067c40c83972990099fa95c2d9d5d7338dc5d3fe Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Fri, 16 Feb 2024 21:22:47 +0100
Subject: [PATCH 05/34] run ruff --fix --unsafe-fixes .
---
docs/source/conf.py | 23 +++++++++----------
modopt/__init__.py | 1 -
modopt/base/__init__.py | 1 -
modopt/base/backend.py | 1 -
modopt/base/np_adjust.py | 1 -
modopt/base/observable.py | 7 +++---
modopt/base/transform.py | 5 ++--
modopt/base/types.py | 3 +--
.../example_lasso_forward_backward.py | 1 -
modopt/interface/__init__.py | 1 -
modopt/interface/errors.py | 13 +++++------
modopt/interface/log.py | 3 +--
modopt/math/__init__.py | 1 -
modopt/math/convolve.py | 1 -
modopt/math/matrix.py | 3 +--
modopt/math/metrics.py | 3 +--
modopt/math/stats.py | 1 -
modopt/opt/__init__.py | 1 -
modopt/opt/algorithms/__init__.py | 19 ---------------
modopt/opt/algorithms/base.py | 3 +--
modopt/opt/algorithms/forward_backward.py | 9 ++++----
modopt/opt/algorithms/gradient_descent.py | 1 -
modopt/opt/algorithms/primal_dual.py | 1 -
modopt/opt/gradient.py | 5 ++--
modopt/opt/linear/base.py | 2 +-
modopt/opt/proximity.py | 21 ++++++++---------
modopt/opt/reweight.py | 5 ++--
modopt/plot/__init__.py | 1 -
modopt/plot/cost_plot.py | 3 +--
modopt/signal/__init__.py | 1 -
modopt/signal/filter.py | 1 -
modopt/signal/noise.py | 2 --
modopt/signal/positivity.py | 1 -
modopt/signal/svd.py | 1 -
modopt/signal/wavelet.py | 9 ++++----
modopt/tests/test_algorithms.py | 8 +------
modopt/tests/test_helpers/__init__.py | 1 -
modopt/tests/test_opt.py | 12 ++++++----
modopt/tests/test_signal.py | 2 +-
39 files changed, 61 insertions(+), 117 deletions(-)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 987576a9..cd39ee08 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Python Template sphinx config
# Import relevant modules
@@ -19,7 +18,7 @@
mdata = metadata(project)
author = mdata["Author"]
version = mdata["Version"]
-copyright = "2020, {}".format(author)
+copyright = f"2020, {author}"
gh_user = "sfarrens"
# If your documentation needs a minimal Sphinx version, state it here.
@@ -117,7 +116,7 @@
)
# The name for this set of Sphinx documents. If None, it defaults to
# " v documentation".
-html_title = "{0} v{1}".format(project, version)
+html_title = f"{project} v{version}"
# A shorter title for the navigation bar. Default is the same as html_title.
# html_short_title = None
@@ -174,20 +173,20 @@ def add_notebooks(nb_path="../../notebooks"):
nb_name = nb.rstrip(nb_ext)
nb_link_file_name = nb_name + ".nblink"
- print("Writing {0}".format(nb_link_file_name))
+ print(f"Writing {nb_link_file_name}")
with open(nb_link_file_name, "w") as nb_link_file:
nb_link_file.write(nb_link_format.format(nb_path, nb))
- print("Looking for {0} in {1}".format(nb_name, nb_rst_file_name))
- with open(nb_rst_file_name, "r") as nb_rst_file:
+ print(f"Looking for {nb_name} in {nb_rst_file_name}")
+ with open(nb_rst_file_name) as nb_rst_file:
check_name = nb_name not in nb_rst_file.read()
if check_name:
- print("Adding {0} to {1}".format(nb_name, nb_rst_file_name))
+ print(f"Adding {nb_name} to {nb_rst_file_name}")
with open(nb_rst_file_name, "a") as nb_rst_file:
if list_pos == 0:
nb_rst_file.write("\n")
- nb_rst_file.write(" {0}\n".format(nb_name))
+ nb_rst_file.write(f" {nb_name}\n")
return nbs
@@ -220,14 +219,14 @@ def add_notebooks(nb_path="../../notebooks"):
"""
nb_header_pt2 = (
r""" """
r""""""
)
diff --git a/modopt/__init__.py b/modopt/__init__.py
index d446e15d..958f3ace 100644
--- a/modopt/__init__.py
+++ b/modopt/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""MODOPT PACKAGE.
diff --git a/modopt/base/__init__.py b/modopt/base/__init__.py
index d75ff315..e7df6c37 100644
--- a/modopt/base/__init__.py
+++ b/modopt/base/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""BASE ROUTINES.
diff --git a/modopt/base/backend.py b/modopt/base/backend.py
index fd933ebb..b4987942 100644
--- a/modopt/base/backend.py
+++ b/modopt/base/backend.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""BACKEND MODULE.
diff --git a/modopt/base/np_adjust.py b/modopt/base/np_adjust.py
index 31a785f5..586a1ee0 100644
--- a/modopt/base/np_adjust.py
+++ b/modopt/base/np_adjust.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""NUMPY ADJUSTMENT ROUTINES.
diff --git a/modopt/base/observable.py b/modopt/base/observable.py
index 2f69a1a7..69c6b238 100644
--- a/modopt/base/observable.py
+++ b/modopt/base/observable.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""Observable.
@@ -13,13 +12,13 @@
import numpy as np
-class SignalObject(object):
+class SignalObject:
"""Dummy class for signals."""
pass
-class Observable(object):
+class Observable:
"""Base class for observable classes.
This class defines a simple interface to add or remove observers
@@ -177,7 +176,7 @@ def _remove_observer(self, signal, observer):
self._observers[signal].remove(observer)
-class MetricObserver(object):
+class MetricObserver:
"""Metric observer.
Wrapper of the metric to the observer object notified by the Observable
diff --git a/modopt/base/transform.py b/modopt/base/transform.py
index fedd5efb..1dc9039a 100644
--- a/modopt/base/transform.py
+++ b/modopt/base/transform.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""DATA TRANSFORM ROUTINES.
@@ -288,7 +287,7 @@ def cube2matrix(data_cube):
"""
return data_cube.reshape(
- [data_cube.shape[0]] + [np.prod(data_cube.shape[1:])],
+ [data_cube.shape[0], np.prod(data_cube.shape[1:])],
).T
@@ -333,4 +332,4 @@ def matrix2cube(data_matrix, im_shape):
cube2matrix : complementary function
"""
- return data_matrix.T.reshape([data_matrix.shape[1]] + list(im_shape))
+ return data_matrix.T.reshape([data_matrix.shape[1], *list(im_shape)])
diff --git a/modopt/base/types.py b/modopt/base/types.py
index 7ea805ad..5ed24ec3 100644
--- a/modopt/base/types.py
+++ b/modopt/base/types.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""TYPE HANDLING ROUTINES.
@@ -165,7 +164,7 @@ def check_npndarray(input_obj, dtype=None, writeable=True, verbose=True):
):
raise (
TypeError(
- "The numpy array elements are not of type: {0}".format(dtype),
+ f"The numpy array elements are not of type: {dtype}",
),
)
diff --git a/modopt/examples/example_lasso_forward_backward.py b/modopt/examples/example_lasso_forward_backward.py
index c28b0499..7e650e05 100644
--- a/modopt/examples/example_lasso_forward_backward.py
+++ b/modopt/examples/example_lasso_forward_backward.py
@@ -1,4 +1,3 @@
-# noqa: D205
"""
Solving the LASSO Problem with the Forward Backward Algorithm.
==============================================================
diff --git a/modopt/interface/__init__.py b/modopt/interface/__init__.py
index 55904ca1..529816ee 100644
--- a/modopt/interface/__init__.py
+++ b/modopt/interface/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""INTERFACE ROUTINES.
diff --git a/modopt/interface/errors.py b/modopt/interface/errors.py
index eb4aa4ca..5c84ad0e 100644
--- a/modopt/interface/errors.py
+++ b/modopt/interface/errors.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""ERROR HANDLING ROUTINES.
@@ -39,7 +38,7 @@ def warn(warn_string, log=None):
warn_txt = colored("WARNING", "yellow")
# Print warning to stderr.
- sys.stderr.write("{0}: {1}\n".format(warn_txt, warn_string))
+ sys.stderr.write(f"{warn_txt}: {warn_string}\n")
# Check if a logging structure is provided.
if not isinstance(log, type(None)):
@@ -66,12 +65,12 @@ def catch_error(exception, log=None):
err_txt = colored("ERROR", "red")
# Print exception to stderr.
- stream_txt = "{0}: {1}\n".format(err_txt, exception)
+ stream_txt = f"{err_txt}: {exception}\n"
sys.stderr.write(stream_txt)
# Check if a logging structure is provided.
if not isinstance(log, type(None)):
- log_txt = "ERROR: {0}\n".format(exception)
+ log_txt = f"ERROR: {exception}\n"
log.exception(log_txt)
@@ -92,10 +91,10 @@ def file_name_error(file_name):
"""
if file_name == "" or file_name[0][0] == "-":
- raise IOError("Input file name not specified.")
+ raise OSError("Input file name not specified.")
elif not os.path.isfile(file_name):
- raise IOError("Input file name {0} not found!".format(file_name))
+ raise OSError(f"Input file name {file_name} not found!")
def is_exe(fpath):
@@ -151,4 +150,4 @@ def is_executable(exe_name):
if not res:
message = "{0} does not appear to be a valid executable on this system."
- raise IOError(message.format(exe_name))
+ raise OSError(message.format(exe_name))
diff --git a/modopt/interface/log.py b/modopt/interface/log.py
index a02428d9..d3e0d8e9 100644
--- a/modopt/interface/log.py
+++ b/modopt/interface/log.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""LOGGING ROUTINES.
@@ -30,7 +29,7 @@ def set_up_log(filename, verbose=True):
"""
# Add file extension.
- filename = "{0}.log".format(filename)
+ filename = f"{filename}.log"
if verbose:
print("Preparing log file:", filename)
diff --git a/modopt/math/__init__.py b/modopt/math/__init__.py
index 8e92aa50..0423a333 100644
--- a/modopt/math/__init__.py
+++ b/modopt/math/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""MATHEMATICS ROUTINES.
diff --git a/modopt/math/convolve.py b/modopt/math/convolve.py
index 528b2338..ac1cf84c 100644
--- a/modopt/math/convolve.py
+++ b/modopt/math/convolve.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""CONVOLUTION ROUTINES.
diff --git a/modopt/math/matrix.py b/modopt/math/matrix.py
index 6ddb3f2f..a2419a6c 100644
--- a/modopt/math/matrix.py
+++ b/modopt/math/matrix.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""MATRIX ROUTINES.
@@ -257,7 +256,7 @@ def rotate(matrix, angle):
return matrix[tuple(zip(new_index.T))].reshape(shape.T)
-class PowerMethod(object):
+class PowerMethod:
"""Power method class.
This method implements the power method to calculate the spectral
diff --git a/modopt/math/metrics.py b/modopt/math/metrics.py
index 93f7ce06..8f797f02 100644
--- a/modopt/math/metrics.py
+++ b/modopt/math/metrics.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""METRICS.
@@ -268,6 +267,6 @@ def nrmse(test, ref, mask=None):
ref = mask * ref
num = np.sqrt(mse(test, ref))
- deno = np.sqrt(np.mean((np.square(test))))
+ deno = np.sqrt(np.mean(np.square(test)))
return num / deno
diff --git a/modopt/math/stats.py b/modopt/math/stats.py
index 59bf6759..b3ee0d8b 100644
--- a/modopt/math/stats.py
+++ b/modopt/math/stats.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""STATISTICS ROUTINES.
diff --git a/modopt/opt/__init__.py b/modopt/opt/__init__.py
index 8b285bee..86564f90 100644
--- a/modopt/opt/__init__.py
+++ b/modopt/opt/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""OPTIMISATION PROBLEM MODULES.
diff --git a/modopt/opt/algorithms/__init__.py b/modopt/opt/algorithms/__init__.py
index 6a29325a..ce6c5e56 100644
--- a/modopt/opt/algorithms/__init__.py
+++ b/modopt/opt/algorithms/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
r"""OPTIMISATION ALGORITHMS.
This module contains class implementations of various optimisation algorithms.
@@ -45,21 +44,3 @@
"""
-from modopt.opt.algorithms.base import SetUp
-from modopt.opt.algorithms.forward_backward import (
- FISTA,
- POGM,
- ForwardBackward,
- GenForwardBackward,
-)
-from modopt.opt.algorithms.gradient_descent import (
- AdaGenericGradOpt,
- ADAMGradOpt,
- GenericGradOpt,
- MomentumGradOpt,
- RMSpropGradOpt,
- SAGAOptGradOpt,
- VanillaGenericGradOpt,
-)
-from modopt.opt.algorithms.primal_dual import Condat
-from modopt.opt.algorithms.admm import ADMM, FastADMM
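With the package-level re-exports deleted here, downstream code would presumably need to import the solvers from their submodules directly. A minimal usage sketch, assuming these re-exports are not restored later in the series (all names are taken from the deleted import block above):

# Hypothetical downstream import style once modopt.opt.algorithms no longer re-exports.
from modopt.opt.algorithms.base import SetUp
from modopt.opt.algorithms.forward_backward import FISTA, POGM, ForwardBackward, GenForwardBackward
from modopt.opt.algorithms.gradient_descent import ADAMGradOpt, GenericGradOpt, RMSpropGradOpt
from modopt.opt.algorithms.primal_dual import Condat
from modopt.opt.algorithms.admm import ADMM, FastADMM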
diff --git a/modopt/opt/algorithms/base.py b/modopt/opt/algorithms/base.py
index e2b9017d..dbb73be0 100644
--- a/modopt/opt/algorithms/base.py
+++ b/modopt/opt/algorithms/base.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""Base SetUp for optimisation algorithms."""
from inspect import getmro
@@ -118,7 +117,7 @@ def metrics(self, metrics):
self._metrics = metrics
else:
raise TypeError(
- "Metrics must be a dictionary, not {0}.".format(type(metrics)),
+ f"Metrics must be a dictionary, not {type(metrics)}.",
)
def any_convergence_flag(self):
diff --git a/modopt/opt/algorithms/forward_backward.py b/modopt/opt/algorithms/forward_backward.py
index d34125fa..4c1cb35c 100644
--- a/modopt/opt/algorithms/forward_backward.py
+++ b/modopt/opt/algorithms/forward_backward.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""Forward-Backward Algorithms."""
import numpy as np
@@ -9,7 +8,7 @@
from modopt.opt.linear import Identity
-class FISTA(object):
+class FISTA:
r"""FISTA.
This class is inherited by optimisation classes to speed up convergence
@@ -602,7 +601,7 @@ def __init__(
self._x_old = self.xp.copy(x)
# Set the algorithm operators
- for operator in [grad, cost] + prox_list:
+ for operator in [grad, cost, *prox_list]:
self._check_operator(operator)
self._grad = grad
@@ -610,7 +609,7 @@ def __init__(
self._linear = linear
if cost == "auto":
- self._cost_func = costObj([self._grad] + prox_list)
+ self._cost_func = costObj([self._grad, *prox_list])
else:
self._cost_func = cost
@@ -689,7 +688,7 @@ def _set_weights(self, weights):
if self.xp.sum(weights) != expected_weight_sum:
raise ValueError(
"Proximity operator weights must sum to 1.0. Current sum of "
- + "weights = {0}".format(self.xp.sum(weights)),
+ + f"weights = {self.xp.sum(weights)}",
)
self._weights = weights
diff --git a/modopt/opt/algorithms/gradient_descent.py b/modopt/opt/algorithms/gradient_descent.py
index d3af1686..0960be5a 100644
--- a/modopt/opt/algorithms/gradient_descent.py
+++ b/modopt/opt/algorithms/gradient_descent.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""Gradient Descent Algorithms."""
import numpy as np
diff --git a/modopt/opt/algorithms/primal_dual.py b/modopt/opt/algorithms/primal_dual.py
index 179ddf95..24908993 100644
--- a/modopt/opt/algorithms/primal_dual.py
+++ b/modopt/opt/algorithms/primal_dual.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""Primal-Dual Algorithms."""
from modopt.opt.algorithms.base import SetUp
diff --git a/modopt/opt/gradient.py b/modopt/opt/gradient.py
index 8d7bacc7..bd214f21 100644
--- a/modopt/opt/gradient.py
+++ b/modopt/opt/gradient.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""GRADIENT CLASSES.
@@ -14,7 +13,7 @@
from modopt.base.types import check_callable, check_float, check_npndarray
-class GradParent(object):
+class GradParent:
"""Gradient Parent Class.
This class defines the basic methods that will be inherited by specific
@@ -289,7 +288,7 @@ def _cost_method(self, *args, **kwargs):
"""
cost_val = 0.5 * np.linalg.norm(self.obs_data - self.op(args[0])) ** 2
- if "verbose" in kwargs and kwargs["verbose"]:
+ if kwargs.get("verbose"):
print(" - DATA FIDELITY (X):", cost_val)
return cost_val
diff --git a/modopt/opt/linear/base.py b/modopt/opt/linear/base.py
index 79f05d85..9fa35187 100644
--- a/modopt/opt/linear/base.py
+++ b/modopt/opt/linear/base.py
@@ -6,7 +6,7 @@
from modopt.base.backend import get_array_module
-class LinearParent(object):
+class LinearParent:
"""Linear Operator Parent Class.
This class sets the structure for defining linear operator instances.
diff --git a/modopt/opt/proximity.py b/modopt/opt/proximity.py
index b562f77a..91a99f2a 100644
--- a/modopt/opt/proximity.py
+++ b/modopt/opt/proximity.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""PROXIMITY OPERATORS.
@@ -32,7 +31,7 @@
from modopt.signal.svd import svd_thresh, svd_thresh_coef, svd_thresh_coef_fast
-class ProximityParent(object):
+class ProximityParent:
"""Proximity Operator Parent Class.
This class sets the structure for defining proximity operator instances.
@@ -140,7 +139,7 @@ def _cost_method(self, *args, **kwargs):
``0.0``
"""
- if "verbose" in kwargs and kwargs["verbose"]:
+ if kwargs.get("verbose"):
print(" - Min (X):", np.min(args[0]))
return 0
@@ -221,7 +220,7 @@ def _cost_method(self, *args, **kwargs):
if isinstance(cost_val, xp.ndarray):
cost_val = cost_val.item()
- if "verbose" in kwargs and kwargs["verbose"]:
+ if kwargs.get("verbose"):
print(" - L1 NORM (X):", cost_val)
return cost_val
@@ -365,7 +364,7 @@ def _cost_method(self, *args, **kwargs):
"""
cost_val = self.thresh * nuclear_norm(cube2matrix(args[0]))
- if "verbose" in kwargs and kwargs["verbose"]:
+ if kwargs.get("verbose"):
print(" - NUCLEAR NORM (X):", cost_val)
return cost_val
@@ -706,7 +705,7 @@ def _cost_method(self, *args, **kwargs):
self.weights * np.sort(np.squeeze(np.abs(args[0]))[::-1]),
)
- if "verbose" in kwargs and kwargs["verbose"]:
+ if kwargs.get("verbose"):
print(" - OWL NORM (X):", cost_val)
return cost_val
@@ -790,7 +789,7 @@ def _cost_method(self, *args, **kwargs):
np.abs(self.weights * self._linear.op(args[0]) ** 2),
)
- if "verbose" in kwargs and kwargs["verbose"]:
+ if kwargs.get("verbose"):
print(" - L2 NORM (X):", cost_val)
return cost_val
@@ -879,7 +878,7 @@ def _cost_method(self, *args, **kwargs):
+ np.abs(self.beta * self._linear.op(args[0])),
)
- if "verbose" in kwargs and kwargs["verbose"]:
+ if kwargs.get("verbose"):
print(" - ELASTIC NET (X):", cost_val)
return cost_val
@@ -1183,7 +1182,7 @@ def _find_alpha(self, input_data, extra_factor=1.0):
data_size = input_data.shape[0]
# Computes the alpha^i points line 1 in Algorithm 1.
- alpha = np.zeros((data_size * 2))
+ alpha = np.zeros(data_size * 2)
data_abs = np.abs(input_data)
alpha[:data_size] = (self.beta * extra_factor) / (
data_abs + sys.float_info.epsilon
@@ -1227,7 +1226,7 @@ def _op_method(self, input_data, extra_factor=1.0):
if self._k_value > k_max:
warn(
"K value of the K-support norm is greater than the input "
- + "dimension, its value will be set to {0}".format(k_max),
+ + f"dimension, its value will be set to {k_max}",
)
self._k_value = k_max
@@ -1333,7 +1332,7 @@ def _cost_method(self, *args, **kwargs):
+ np.sum(data_abs[q_val:]) ** 2 / (self._k_value - q_val)
) * self.beta
- if "verbose" in kwargs and kwargs["verbose"]:
+ if kwargs.get("verbose"):
print(" - K-SUPPORT NORM (X):", cost_val)
return cost_val
diff --git a/modopt/opt/reweight.py b/modopt/opt/reweight.py
index 7ff9aac4..8d120101 100644
--- a/modopt/opt/reweight.py
+++ b/modopt/opt/reweight.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""REWEIGHTING CLASSES.
@@ -13,7 +12,7 @@
from modopt.base.types import check_float
-class cwbReweight(object):
+class cwbReweight:
"""Candes, Wakin and Boyd reweighting class.
This class implements the reweighting scheme described in
@@ -81,7 +80,7 @@ def reweight(self, input_data):
"""
if self.verbose:
- print(" - Reweighting: {0}".format(self._rw_num))
+ print(f" - Reweighting: {self._rw_num}")
self._rw_num += 1
diff --git a/modopt/plot/__init__.py b/modopt/plot/__init__.py
index da6e096c..f6b39978 100644
--- a/modopt/plot/__init__.py
+++ b/modopt/plot/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""PLOTTING ROUTINES.
diff --git a/modopt/plot/cost_plot.py b/modopt/plot/cost_plot.py
index 36958450..2274f35d 100644
--- a/modopt/plot/cost_plot.py
+++ b/modopt/plot/cost_plot.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""PLOTTING ROUTINES.
@@ -43,7 +42,7 @@ def plotCost(cost_list, output=None):
if isinstance(output, type(None)):
file_name = "cost_function.png"
else:
- file_name = "{0}_cost_function.png".format(output)
+ file_name = f"{output}_cost_function.png"
plt.figure()
plt.plot(np.log10(cost_list), "r-")
diff --git a/modopt/signal/__init__.py b/modopt/signal/__init__.py
index 09b2d2c4..2aee1987 100644
--- a/modopt/signal/__init__.py
+++ b/modopt/signal/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""SIGNAL PROCESSING ROUTINES.
diff --git a/modopt/signal/filter.py b/modopt/signal/filter.py
index 2c7d8626..0e50d28f 100644
--- a/modopt/signal/filter.py
+++ b/modopt/signal/filter.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""FILTER ROUTINES.
diff --git a/modopt/signal/noise.py b/modopt/signal/noise.py
index fadf5308..b43a0b61 100644
--- a/modopt/signal/noise.py
+++ b/modopt/signal/noise.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""NOISE ROUTINES.
@@ -8,7 +7,6 @@
"""
-from builtins import zip
import numpy as np
diff --git a/modopt/signal/positivity.py b/modopt/signal/positivity.py
index 5c4b795b..f3f312d3 100644
--- a/modopt/signal/positivity.py
+++ b/modopt/signal/positivity.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""POSITIVITY.
diff --git a/modopt/signal/svd.py b/modopt/signal/svd.py
index cc204817..dd080306 100644
--- a/modopt/signal/svd.py
+++ b/modopt/signal/svd.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""SVD ROUTINES.
diff --git a/modopt/signal/wavelet.py b/modopt/signal/wavelet.py
index 72d608e7..d624db3a 100644
--- a/modopt/signal/wavelet.py
+++ b/modopt/signal/wavelet.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""WAVELET MODULE.
@@ -141,9 +140,9 @@ def call_mr_transform(
unique_string = datetime.now().strftime("%Y.%m.%d_%H.%M.%S") + str(getrandbits(128))
    # Set the output file names.
- file_name = "{0}mr_temp_{1}".format(path, unique_string)
- file_fits = "{0}.fits".format(file_name)
- file_mr = "{0}.mr".format(file_name)
+ file_name = f"{path}mr_temp_{unique_string}"
+ file_fits = f"{file_name}.fits"
+ file_mr = f"{file_name}.mr"
# Write the input data to a fits file.
fits.writeto(file_fits, input_data)
@@ -152,7 +151,7 @@ def call_mr_transform(
opt = opt.split()
# Prepare command and execute it
- command_line = " ".join([executable] + opt + [file_fits, file_mr])
+ command_line = " ".join([executable, *opt, file_fits, file_mr])
stdout, _ = execute(command_line)
# Check for errors
diff --git a/modopt/tests/test_algorithms.py b/modopt/tests/test_algorithms.py
index ca0cd666..c1e676a5 100644
--- a/modopt/tests/test_algorithms.py
+++ b/modopt/tests/test_algorithms.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
"""UNIT TESTS FOR Algorithms.
@@ -11,18 +10,13 @@
import numpy as np
import numpy.testing as npt
-import pytest
from modopt.opt import algorithms, cost, gradient, linear, proximity, reweight
from pytest_cases import (
- case,
fixture,
- fixture_ref,
- lazy_value,
parametrize,
parametrize_with_cases,
)
-from test_helpers import Dummy
SKLEARN_AVAILABLE = True
try:
@@ -80,7 +74,7 @@ def build_kwargs(kwargs, use_metrics):
@parametrize(use_metrics=[True, False])
class AlgoCases:
- """Cases for algorithms.
+ r"""Cases for algorithms.
Most of the test solves the trivial problem
diff --git a/modopt/tests/test_helpers/__init__.py b/modopt/tests/test_helpers/__init__.py
index 3886b877..e69de29b 100644
--- a/modopt/tests/test_helpers/__init__.py
+++ b/modopt/tests/test_helpers/__init__.py
@@ -1 +0,0 @@
-from .utils import failparam, skipparam, Dummy
diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py
index e77074ab..7dc27871 100644
--- a/modopt/tests/test_opt.py
+++ b/modopt/tests/test_opt.py
@@ -37,10 +37,14 @@
PYWT_AVAILABLE = False
# Basic functions to be used as operators or as dummy functions
-func_identity = lambda x_val: x_val
-func_double = lambda x_val: x_val * 2
-func_sq = lambda x_val: x_val**2
-func_cube = lambda x_val: x_val**3
+def func_identity(x_val):
+ return x_val
+def func_double(x_val):
+ return x_val * 2
+def func_sq(x_val):
+ return x_val ** 2
+def func_cube(x_val):
+ return x_val ** 3
@case(tags="cost")
diff --git a/modopt/tests/test_signal.py b/modopt/tests/test_signal.py
index b3787fc6..cdd95277 100644
--- a/modopt/tests/test_signal.py
+++ b/modopt/tests/test_signal.py
@@ -16,7 +16,7 @@
class TestFilter:
- """Test filter module"""
+ """Test filter module."""
@pytest.mark.parametrize(
("norm", "result"), [(True, 0.24197072451914337), (False, 0.60653065971263342)]
From 29833bc0f5f6cb85b7add3af7ad947d79440ed6b Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Fri, 16 Feb 2024 21:25:05 +0100
Subject: [PATCH 06/34] move to a src layout.
---
{modopt/examples => examples}/README.rst | 0
{modopt/examples => examples}/__init__.py | 0
{modopt/examples => examples}/conftest.py | 0
.../example_lasso_forward_backward.py | 0
pyproject.toml | 11 +++--------
{modopt => src/modopt}/__init__.py | 0
{modopt => src/modopt}/base/__init__.py | 0
{modopt => src/modopt}/base/backend.py | 0
{modopt => src/modopt}/base/np_adjust.py | 0
{modopt => src/modopt}/base/observable.py | 0
{modopt => src/modopt}/base/transform.py | 0
{modopt => src/modopt}/base/types.py | 0
{modopt => src/modopt}/interface/__init__.py | 0
{modopt => src/modopt}/interface/errors.py | 0
{modopt => src/modopt}/interface/log.py | 0
{modopt => src/modopt}/math/__init__.py | 0
{modopt => src/modopt}/math/convolve.py | 0
{modopt => src/modopt}/math/matrix.py | 0
{modopt => src/modopt}/math/metrics.py | 0
{modopt => src/modopt}/math/stats.py | 0
{modopt => src/modopt}/opt/__init__.py | 0
{modopt => src/modopt}/opt/algorithms/__init__.py | 0
{modopt => src/modopt}/opt/algorithms/admm.py | 0
{modopt => src/modopt}/opt/algorithms/base.py | 0
.../modopt}/opt/algorithms/forward_backward.py | 0
.../modopt}/opt/algorithms/gradient_descent.py | 0
{modopt => src/modopt}/opt/algorithms/primal_dual.py | 0
{modopt => src/modopt}/opt/cost.py | 0
{modopt => src/modopt}/opt/gradient.py | 0
{modopt => src/modopt}/opt/linear/__init__.py | 0
{modopt => src/modopt}/opt/linear/base.py | 0
{modopt => src/modopt}/opt/linear/wavelet.py | 0
{modopt => src/modopt}/opt/proximity.py | 0
{modopt => src/modopt}/opt/reweight.py | 0
{modopt => src/modopt}/plot/__init__.py | 0
{modopt => src/modopt}/plot/cost_plot.py | 0
{modopt => src/modopt}/signal/__init__.py | 0
{modopt => src/modopt}/signal/filter.py | 0
{modopt => src/modopt}/signal/noise.py | 0
{modopt => src/modopt}/signal/positivity.py | 0
{modopt => src/modopt}/signal/svd.py | 0
{modopt => src/modopt}/signal/validation.py | 0
{modopt => src/modopt}/signal/wavelet.py | 0
{modopt/tests => tests}/test_algorithms.py | 0
{modopt/tests => tests}/test_base.py | 0
{modopt/tests => tests}/test_helpers/__init__.py | 0
{modopt/tests => tests}/test_helpers/utils.py | 0
{modopt/tests => tests}/test_math.py | 0
{modopt/tests => tests}/test_opt.py | 0
{modopt/tests => tests}/test_signal.py | 0
50 files changed, 3 insertions(+), 8 deletions(-)
rename {modopt/examples => examples}/README.rst (100%)
rename {modopt/examples => examples}/__init__.py (100%)
rename {modopt/examples => examples}/conftest.py (100%)
rename {modopt/examples => examples}/example_lasso_forward_backward.py (100%)
rename {modopt => src/modopt}/__init__.py (100%)
rename {modopt => src/modopt}/base/__init__.py (100%)
rename {modopt => src/modopt}/base/backend.py (100%)
rename {modopt => src/modopt}/base/np_adjust.py (100%)
rename {modopt => src/modopt}/base/observable.py (100%)
rename {modopt => src/modopt}/base/transform.py (100%)
rename {modopt => src/modopt}/base/types.py (100%)
rename {modopt => src/modopt}/interface/__init__.py (100%)
rename {modopt => src/modopt}/interface/errors.py (100%)
rename {modopt => src/modopt}/interface/log.py (100%)
rename {modopt => src/modopt}/math/__init__.py (100%)
rename {modopt => src/modopt}/math/convolve.py (100%)
rename {modopt => src/modopt}/math/matrix.py (100%)
rename {modopt => src/modopt}/math/metrics.py (100%)
rename {modopt => src/modopt}/math/stats.py (100%)
rename {modopt => src/modopt}/opt/__init__.py (100%)
rename {modopt => src/modopt}/opt/algorithms/__init__.py (100%)
rename {modopt => src/modopt}/opt/algorithms/admm.py (100%)
rename {modopt => src/modopt}/opt/algorithms/base.py (100%)
rename {modopt => src/modopt}/opt/algorithms/forward_backward.py (100%)
rename {modopt => src/modopt}/opt/algorithms/gradient_descent.py (100%)
rename {modopt => src/modopt}/opt/algorithms/primal_dual.py (100%)
rename {modopt => src/modopt}/opt/cost.py (100%)
rename {modopt => src/modopt}/opt/gradient.py (100%)
rename {modopt => src/modopt}/opt/linear/__init__.py (100%)
rename {modopt => src/modopt}/opt/linear/base.py (100%)
rename {modopt => src/modopt}/opt/linear/wavelet.py (100%)
rename {modopt => src/modopt}/opt/proximity.py (100%)
rename {modopt => src/modopt}/opt/reweight.py (100%)
rename {modopt => src/modopt}/plot/__init__.py (100%)
rename {modopt => src/modopt}/plot/cost_plot.py (100%)
rename {modopt => src/modopt}/signal/__init__.py (100%)
rename {modopt => src/modopt}/signal/filter.py (100%)
rename {modopt => src/modopt}/signal/noise.py (100%)
rename {modopt => src/modopt}/signal/positivity.py (100%)
rename {modopt => src/modopt}/signal/svd.py (100%)
rename {modopt => src/modopt}/signal/validation.py (100%)
rename {modopt => src/modopt}/signal/wavelet.py (100%)
rename {modopt/tests => tests}/test_algorithms.py (100%)
rename {modopt/tests => tests}/test_base.py (100%)
rename {modopt/tests => tests}/test_helpers/__init__.py (100%)
rename {modopt/tests => tests}/test_helpers/utils.py (100%)
rename {modopt/tests => tests}/test_math.py (100%)
rename {modopt/tests => tests}/test_opt.py (100%)
rename {modopt/tests => tests}/test_signal.py (100%)
diff --git a/modopt/examples/README.rst b/examples/README.rst
similarity index 100%
rename from modopt/examples/README.rst
rename to examples/README.rst
diff --git a/modopt/examples/__init__.py b/examples/__init__.py
similarity index 100%
rename from modopt/examples/__init__.py
rename to examples/__init__.py
diff --git a/modopt/examples/conftest.py b/examples/conftest.py
similarity index 100%
rename from modopt/examples/conftest.py
rename to examples/conftest.py
diff --git a/modopt/examples/example_lasso_forward_backward.py b/examples/example_lasso_forward_backward.py
similarity index 100%
rename from modopt/examples/example_lasso_forward_backward.py
rename to examples/example_lasso_forward_backward.py
diff --git a/pyproject.toml b/pyproject.toml
index 71bdce82..37da48a3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,9 +26,6 @@ dev=["black", "pytest<8.0.0", "pytest-cases", "pytest-cov", "pytest-sugar", "ruf
[build-system]
requires=["setuptools", "setuptools-scm[toml]", "wheel"]
-[tool.setuptools]
-packages=["modopt"]
-
[tool.coverage.run]
omit = ["*tests*", "*__init__*", "*setup.py*", "*_version.py*", "*example*"]
@@ -38,12 +35,10 @@ exclude_lines = ["pragma: no cover", "raise NotImplementedError"]
[tool.black]
-[tool.ruff]
-
-src=["modopt"]
+[lint]
select = ["E", "F", "B", "Q", "UP", "D", "NPY", "RUF"]
-[tool.ruff.pydocstyle]
+[lint.pydocstyle]
convention="numpy"
[tool.isort]
@@ -51,6 +46,6 @@ profile="black"
[tool.pytest.ini_options]
minversion = "6.0"
-norecursedirs = ["tests/helpers"]
+norecursedirs = ["tests/test_helpers"]
testpaths=["modopt"]
addopts = ["--verbose", "--cov=modopt", "--cov-report=term-missing", "--cov-report=xml", "--junitxml=pytest.xml"]
diff --git a/modopt/__init__.py b/src/modopt/__init__.py
similarity index 100%
rename from modopt/__init__.py
rename to src/modopt/__init__.py
diff --git a/modopt/base/__init__.py b/src/modopt/base/__init__.py
similarity index 100%
rename from modopt/base/__init__.py
rename to src/modopt/base/__init__.py
diff --git a/modopt/base/backend.py b/src/modopt/base/backend.py
similarity index 100%
rename from modopt/base/backend.py
rename to src/modopt/base/backend.py
diff --git a/modopt/base/np_adjust.py b/src/modopt/base/np_adjust.py
similarity index 100%
rename from modopt/base/np_adjust.py
rename to src/modopt/base/np_adjust.py
diff --git a/modopt/base/observable.py b/src/modopt/base/observable.py
similarity index 100%
rename from modopt/base/observable.py
rename to src/modopt/base/observable.py
diff --git a/modopt/base/transform.py b/src/modopt/base/transform.py
similarity index 100%
rename from modopt/base/transform.py
rename to src/modopt/base/transform.py
diff --git a/modopt/base/types.py b/src/modopt/base/types.py
similarity index 100%
rename from modopt/base/types.py
rename to src/modopt/base/types.py
diff --git a/modopt/interface/__init__.py b/src/modopt/interface/__init__.py
similarity index 100%
rename from modopt/interface/__init__.py
rename to src/modopt/interface/__init__.py
diff --git a/modopt/interface/errors.py b/src/modopt/interface/errors.py
similarity index 100%
rename from modopt/interface/errors.py
rename to src/modopt/interface/errors.py
diff --git a/modopt/interface/log.py b/src/modopt/interface/log.py
similarity index 100%
rename from modopt/interface/log.py
rename to src/modopt/interface/log.py
diff --git a/modopt/math/__init__.py b/src/modopt/math/__init__.py
similarity index 100%
rename from modopt/math/__init__.py
rename to src/modopt/math/__init__.py
diff --git a/modopt/math/convolve.py b/src/modopt/math/convolve.py
similarity index 100%
rename from modopt/math/convolve.py
rename to src/modopt/math/convolve.py
diff --git a/modopt/math/matrix.py b/src/modopt/math/matrix.py
similarity index 100%
rename from modopt/math/matrix.py
rename to src/modopt/math/matrix.py
diff --git a/modopt/math/metrics.py b/src/modopt/math/metrics.py
similarity index 100%
rename from modopt/math/metrics.py
rename to src/modopt/math/metrics.py
diff --git a/modopt/math/stats.py b/src/modopt/math/stats.py
similarity index 100%
rename from modopt/math/stats.py
rename to src/modopt/math/stats.py
diff --git a/modopt/opt/__init__.py b/src/modopt/opt/__init__.py
similarity index 100%
rename from modopt/opt/__init__.py
rename to src/modopt/opt/__init__.py
diff --git a/modopt/opt/algorithms/__init__.py b/src/modopt/opt/algorithms/__init__.py
similarity index 100%
rename from modopt/opt/algorithms/__init__.py
rename to src/modopt/opt/algorithms/__init__.py
diff --git a/modopt/opt/algorithms/admm.py b/src/modopt/opt/algorithms/admm.py
similarity index 100%
rename from modopt/opt/algorithms/admm.py
rename to src/modopt/opt/algorithms/admm.py
diff --git a/modopt/opt/algorithms/base.py b/src/modopt/opt/algorithms/base.py
similarity index 100%
rename from modopt/opt/algorithms/base.py
rename to src/modopt/opt/algorithms/base.py
diff --git a/modopt/opt/algorithms/forward_backward.py b/src/modopt/opt/algorithms/forward_backward.py
similarity index 100%
rename from modopt/opt/algorithms/forward_backward.py
rename to src/modopt/opt/algorithms/forward_backward.py
diff --git a/modopt/opt/algorithms/gradient_descent.py b/src/modopt/opt/algorithms/gradient_descent.py
similarity index 100%
rename from modopt/opt/algorithms/gradient_descent.py
rename to src/modopt/opt/algorithms/gradient_descent.py
diff --git a/modopt/opt/algorithms/primal_dual.py b/src/modopt/opt/algorithms/primal_dual.py
similarity index 100%
rename from modopt/opt/algorithms/primal_dual.py
rename to src/modopt/opt/algorithms/primal_dual.py
diff --git a/modopt/opt/cost.py b/src/modopt/opt/cost.py
similarity index 100%
rename from modopt/opt/cost.py
rename to src/modopt/opt/cost.py
diff --git a/modopt/opt/gradient.py b/src/modopt/opt/gradient.py
similarity index 100%
rename from modopt/opt/gradient.py
rename to src/modopt/opt/gradient.py
diff --git a/modopt/opt/linear/__init__.py b/src/modopt/opt/linear/__init__.py
similarity index 100%
rename from modopt/opt/linear/__init__.py
rename to src/modopt/opt/linear/__init__.py
diff --git a/modopt/opt/linear/base.py b/src/modopt/opt/linear/base.py
similarity index 100%
rename from modopt/opt/linear/base.py
rename to src/modopt/opt/linear/base.py
diff --git a/modopt/opt/linear/wavelet.py b/src/modopt/opt/linear/wavelet.py
similarity index 100%
rename from modopt/opt/linear/wavelet.py
rename to src/modopt/opt/linear/wavelet.py
diff --git a/modopt/opt/proximity.py b/src/modopt/opt/proximity.py
similarity index 100%
rename from modopt/opt/proximity.py
rename to src/modopt/opt/proximity.py
diff --git a/modopt/opt/reweight.py b/src/modopt/opt/reweight.py
similarity index 100%
rename from modopt/opt/reweight.py
rename to src/modopt/opt/reweight.py
diff --git a/modopt/plot/__init__.py b/src/modopt/plot/__init__.py
similarity index 100%
rename from modopt/plot/__init__.py
rename to src/modopt/plot/__init__.py
diff --git a/modopt/plot/cost_plot.py b/src/modopt/plot/cost_plot.py
similarity index 100%
rename from modopt/plot/cost_plot.py
rename to src/modopt/plot/cost_plot.py
diff --git a/modopt/signal/__init__.py b/src/modopt/signal/__init__.py
similarity index 100%
rename from modopt/signal/__init__.py
rename to src/modopt/signal/__init__.py
diff --git a/modopt/signal/filter.py b/src/modopt/signal/filter.py
similarity index 100%
rename from modopt/signal/filter.py
rename to src/modopt/signal/filter.py
diff --git a/modopt/signal/noise.py b/src/modopt/signal/noise.py
similarity index 100%
rename from modopt/signal/noise.py
rename to src/modopt/signal/noise.py
diff --git a/modopt/signal/positivity.py b/src/modopt/signal/positivity.py
similarity index 100%
rename from modopt/signal/positivity.py
rename to src/modopt/signal/positivity.py
diff --git a/modopt/signal/svd.py b/src/modopt/signal/svd.py
similarity index 100%
rename from modopt/signal/svd.py
rename to src/modopt/signal/svd.py
diff --git a/modopt/signal/validation.py b/src/modopt/signal/validation.py
similarity index 100%
rename from modopt/signal/validation.py
rename to src/modopt/signal/validation.py
diff --git a/modopt/signal/wavelet.py b/src/modopt/signal/wavelet.py
similarity index 100%
rename from modopt/signal/wavelet.py
rename to src/modopt/signal/wavelet.py
diff --git a/modopt/tests/test_algorithms.py b/tests/test_algorithms.py
similarity index 100%
rename from modopt/tests/test_algorithms.py
rename to tests/test_algorithms.py
diff --git a/modopt/tests/test_base.py b/tests/test_base.py
similarity index 100%
rename from modopt/tests/test_base.py
rename to tests/test_base.py
diff --git a/modopt/tests/test_helpers/__init__.py b/tests/test_helpers/__init__.py
similarity index 100%
rename from modopt/tests/test_helpers/__init__.py
rename to tests/test_helpers/__init__.py
diff --git a/modopt/tests/test_helpers/utils.py b/tests/test_helpers/utils.py
similarity index 100%
rename from modopt/tests/test_helpers/utils.py
rename to tests/test_helpers/utils.py
diff --git a/modopt/tests/test_math.py b/tests/test_math.py
similarity index 100%
rename from modopt/tests/test_math.py
rename to tests/test_math.py
diff --git a/modopt/tests/test_opt.py b/tests/test_opt.py
similarity index 100%
rename from modopt/tests/test_opt.py
rename to tests/test_opt.py
diff --git a/modopt/tests/test_signal.py b/tests/test_signal.py
similarity index 100%
rename from modopt/tests/test_signal.py
rename to tests/test_signal.py
From 536366d621d7f71cfd5f1135a25329aff51a819b Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Fri, 16 Feb 2024 22:47:11 +0100
Subject: [PATCH 07/34] fix: make the majority of tests pass.
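Re-export the algorithm classes explicitly (with an ``__all__``) and drop the
``torch``/``cupy`` type annotations from the wavelet operators, since those
backends are optional and may be missing at import time. A minimal sketch of
the guarded-import idiom this relies on (the helper below is illustrative,
not part of ModOpt):

    import numpy as np

    try:
        import torch  # optional GPU backend
        TORCH_AVAILABLE = True
    except ImportError:
        TORCH_AVAILABLE = False

    def as_backend_array(data):
        """Return a torch tensor if torch is installed, else a NumPy array."""
        if TORCH_AVAILABLE:
            return torch.as_tensor(data)
        return np.asarray(data)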
---
src/modopt/opt/algorithms/__init__.py | 29 +++++++++++++++++++++++++
src/modopt/opt/linear/wavelet.py | 8 +++----
tests/test_helpers/__init__.py | 5 +++++
tests/test_opt.py | 31 +++++++++++++++++++++------
4 files changed, 63 insertions(+), 10 deletions(-)
diff --git a/src/modopt/opt/algorithms/__init__.py b/src/modopt/opt/algorithms/__init__.py
index ce6c5e56..ff79502c 100644
--- a/src/modopt/opt/algorithms/__init__.py
+++ b/src/modopt/opt/algorithms/__init__.py
@@ -44,3 +44,32 @@
"""
+from .forward_backward import FISTA, ForwardBackward, GenForwardBackward, POGM
+from .primal_dual import Condat
+from .gradient_descent import (
+ ADAMGradOpt,
+ AdaGenericGradOpt,
+ GenericGradOpt,
+ MomentumGradOpt,
+ RMSpropGradOpt,
+ SAGAOptGradOpt,
+ VanillaGenericGradOpt,
+)
+from .admm import ADMM, FastADMM
+
+__all__ = [
+ "FISTA",
+ "ForwardBackward",
+ "GenForwardBackward",
+ "POGM",
+ "Condat",
+ "ADAMGradOpt",
+ "AdaGenericGradOpt",
+ "GenericGradOpt",
+ "MomentumGradOpt",
+ "RMSpropGradOpt",
+ "SAGAOptGradOpt",
+ "VanillaGenericGradOpt",
+ "ADMM",
+ "FastADMM",
+]
diff --git a/src/modopt/opt/linear/wavelet.py b/src/modopt/opt/linear/wavelet.py
index e1608ff4..8dc44fd3 100644
--- a/src/modopt/opt/linear/wavelet.py
+++ b/src/modopt/opt/linear/wavelet.py
@@ -293,7 +293,7 @@ def __init__(
self.mode = mode
self.coeffs_shape = None # will be set after op.
- def op(self, data: torch.Tensor) -> list[torch.Tensor]:
+ def op(self, data):
"""Apply the wavelet decomposition on.
Parameters
@@ -355,7 +355,7 @@ def op(self, data: torch.Tensor) -> list[torch.Tensor]:
)
return [coeffs_[0]] + [cc for c in coeffs_[1:] for cc in c.values()]
- def adj_op(self, coeffs: list[torch.Tensor]) -> torch.Tensor:
+ def adj_op(self, coeffs):
"""Apply the wavelet recomposition.
Parameters
@@ -429,7 +429,7 @@ def __init__(
)
self.coeffs_shape = None # will be set after op
- def op(self, data: cp.array) -> cp.ndarray:
+ def op(self, data):
"""Define the wavelet operator.
This method returns the input data convolved with the wavelet filter.
@@ -459,7 +459,7 @@ def op(self, data: cp.array) -> cp.ndarray:
return ret
- def adj_op(self, data: cp.ndarray) -> cp.ndarray:
+ def adj_op(self, data):
"""Define the wavelet adjoint operator.
This method returns the reconstructed image.
diff --git a/tests/test_helpers/__init__.py b/tests/test_helpers/__init__.py
index e69de29b..0ded847a 100644
--- a/tests/test_helpers/__init__.py
+++ b/tests/test_helpers/__init__.py
@@ -0,0 +1,5 @@
+"""Utilities for tests."""
+
+from .utils import Dummy, failparam, skipparam
+
+__all__ = ["Dummy", "failparam", "skipparam"]
diff --git a/tests/test_opt.py b/tests/test_opt.py
index 7dc27871..e31d3a49 100644
--- a/tests/test_opt.py
+++ b/tests/test_opt.py
@@ -36,15 +36,22 @@
except ImportError:
PYWT_AVAILABLE = False
+
# Basic functions to be used as operators or as dummy functions
def func_identity(x_val):
return x_val
+
+
def func_double(x_val):
return x_val * 2
+
+
def func_sq(x_val):
- return x_val ** 2
+ return x_val**2
+
+
def func_cube(x_val):
- return x_val ** 3
+ return x_val**3
@case(tags="cost")
@@ -187,10 +194,21 @@ def case_linear_wavelet_convolve(self):
@parametrize(
compute_backend=[
- pytest.param("numpy", marks=pytest.mark.skipif(not PYWT_AVAILABLE, reason="PyWavelet not available.")),
- pytest.param("cupy", marks=pytest.mark.skipif(not PTWT_AVAILABLE, reason="Pytorch Wavelet not available."))
- ])
- def case_linear_wavelet_transform(self, compute_backend="numpy"):
+ pytest.param(
+ "numpy",
+ marks=pytest.mark.skipif(
+ not PYWT_AVAILABLE, reason="PyWavelet not available."
+ ),
+ ),
+ pytest.param(
+ "cupy",
+ marks=pytest.mark.skipif(
+ not PTWT_AVAILABLE, reason="Pytorch Wavelet not available."
+ ),
+ ),
+ ]
+ )
+ def case_linear_wavelet_transform(self, compute_backend):
linop = linear.WaveletTransform(
wavelet_name="haar",
shape=(8, 8),
@@ -318,6 +336,7 @@ class ProxCases:
[11.67394789, 12.87497954, 14.07601119],
[15.27704284, 16.47807449, 17.67910614],
],
+ ]
)
array233_3 = np.array(
[
From 0c3b13ce83e1d7d4147be1d050a4bb387b10310a Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Fri, 16 Feb 2024 22:55:46 +0100
Subject: [PATCH 08/34] typo
---
src/modopt/base/np_adjust.py | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/src/modopt/base/np_adjust.py b/src/modopt/base/np_adjust.py
index 586a1ee0..10cb5c29 100644
--- a/src/modopt/base/np_adjust.py
+++ b/src/modopt/base/np_adjust.py
@@ -1,4 +1,3 @@
-
"""NUMPY ADJUSTMENT ROUTINES.
This module contains methods for adjusting the default output for certain
@@ -153,8 +152,7 @@ def pad2d(input_data, padding):
padding = np.array(padding)
elif not isinstance(padding, np.ndarray):
raise ValueError(
- "Padding must be an integer or a tuple (or list, np.ndarray) "
- + "of itegers",
+ "Padding must be an integer or a tuple (or list, np.ndarray) of integers",
)
if padding.size == 1:
From a79113b03ddaae8fa786ec1584270cf558b68cb8 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Fri, 16 Feb 2024 23:04:13 +0100
Subject: [PATCH 09/34] fix test setup
---
pyproject.toml | 3 +--
tests/test_helpers/utils.py | 4 ++--
2 files changed, 3 insertions(+), 4 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 37da48a3..dd08cbd0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,5 +47,4 @@ profile="black"
[tool.pytest.ini_options]
minversion = "6.0"
norecursedirs = ["tests/test_helpers"]
-testpaths=["modopt"]
-addopts = ["--verbose", "--cov=modopt", "--cov-report=term-missing", "--cov-report=xml", "--junitxml=pytest.xml"]
+addopts = ["--cov=modopt", "--cov-report=term-missing", "--cov-report=xml"]
diff --git a/tests/test_helpers/utils.py b/tests/test_helpers/utils.py
index 895b2371..049d257f 100644
--- a/tests/test_helpers/utils.py
+++ b/tests/test_helpers/utils.py
@@ -12,12 +12,12 @@ def failparam(*args, raises=None):
"""Return a pytest parameterization that should raise an error."""
if not issubclass(raises, Exception):
raise ValueError("raises should be an expected Exception.")
- return pytest.param(*args, marks=pytest.mark.raises(exception=raises))
+ return pytest.param(*args, marks=[pytest.mark.xfail(exception=raises)])
def skipparam(*args, cond=True, reason=""):
"""Return a pytest parameterization that should be skip if cond is valid."""
- return pytest.param(*args, marks=pytest.mark.skipif(cond, reason=reason))
+ return pytest.param(*args, marks=[pytest.mark.skipif(cond, reason=reason)])
class Dummy:
From bcbe1f3ff469ca6cd12dc4e306b0ce2c97b513d0 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Fri, 16 Feb 2024 23:13:04 +0100
Subject: [PATCH 10/34] slim down the CI.
---
.github/workflows/cd-build.yml | 9 ++---
.github/workflows/ci-build.yml | 72 ++++------------------------------
pyproject.toml | 3 +-
3 files changed, 13 insertions(+), 71 deletions(-)
diff --git a/.github/workflows/cd-build.yml b/.github/workflows/cd-build.yml
index fca9feb1..ded7159d 100644
--- a/.github/workflows/cd-build.yml
+++ b/.github/workflows/cd-build.yml
@@ -27,19 +27,17 @@ jobs:
shell: bash -l {0}
run: |
python -m pip install --upgrade pip
- python -m pip install -r develop.txt
python -m pip install twine
- python -m pip install .
+ python -m pip install .[doc,test]
- name: Run Tests
shell: bash -l {0}
run: |
- python setup.py test
+ pytest
- name: Check distribution
shell: bash -l {0}
run: |
- python setup.py sdist
twine check dist/*
- name: Upload coverage to Codecov
@@ -69,8 +67,7 @@ jobs:
run: |
conda install -c conda-forge pandoc
python -m pip install --upgrade pip
- python -m pip install -r docs/requirements.txt
- python -m pip install .
+ python -m pip install .[doc]
- name: Build API documentation
shell: bash -l {0}
diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml
index 88129d45..9d4226cb 100644
--- a/.github/workflows/ci-build.yml
+++ b/.github/workflows/ci-build.yml
@@ -16,42 +16,27 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
- python-version: ["3.10"]
+ python-version: ["3.8", "3.9", "3.10"]
steps:
- - name: Checkout
- uses: actions/checkout@v2
-
- - name: Set up Conda with Python ${{ matrix.python-version }}
- uses: conda-incubator/setup-miniconda@v2
+ - uses: actions/checkout@v3
+ - uses: actions/setup-python@v4
with:
- auto-update-conda: true
python-version: ${{ matrix.python-version }}
- auto-activate-base: false
-
- - name: Check Conda
- shell: bash -l {0}
- run: |
- conda info
- conda list
- python --version
+ cache: pip
- name: Install Dependencies
shell: bash -l {0}
run: |
python --version
python -m pip install --upgrade pip
- python -m pip install -r develop.txt
- python -m pip install -r docs/requirements.txt
- python -m pip install astropy "scikit-image<0.20" scikit-learn matplotlib
+ python -m pip install .[test]
+ python -m pip install astropy scikit-image scikit-learn matplotlib
python -m pip install tensorflow>=2.4.1 torch
- python -m pip install twine
- python -m pip install .
- name: Run Tests
shell: bash -l {0}
run: |
- export PATH=/usr/share/miniconda/bin:$PATH
pytest -n 2
- name: Save Test Results
@@ -59,18 +44,12 @@ jobs:
uses: actions/upload-artifact@v2
with:
name: unit-test-results-${{ matrix.os }}-${{ matrix.python-version }}
- path: pytest.xml
-
- - name: Check Distribution
- shell: bash -l {0}
- run: |
- python setup.py sdist
- twine check dist/*
+ path: .coverage.xml
- name: Check API Documentation build
shell: bash -l {0}
run: |
- conda install -c conda-forge pandoc
+ pip install .[doc]
sphinx-apidoc -t docs/_templates -feTMo docs/source modopt
sphinx-build -b doctest -E docs/source docs/_build
@@ -81,38 +60,3 @@ jobs:
file: coverage.xml
flags: unittests
- test-basic:
- name: Basic Test Suite
- runs-on: ${{ matrix.os }}
-
- strategy:
- fail-fast: false
- matrix:
- os: [ubuntu-latest, macos-latest]
- python-version: ["3.7", "3.8", "3.9"]
-
- steps:
- - name: Checkout
- uses: actions/checkout@v2
-
- - name: Set up Conda with Python ${{ matrix.python-version }}
- uses: conda-incubator/setup-miniconda@v2
- with:
- auto-update-conda: true
- python-version: ${{ matrix.python-version }}
- auto-activate-base: false
-
- - name: Install Dependencies
- shell: bash -l {0}
- run: |
- python --version
- python -m pip install --upgrade pip
- python -m pip install -r develop.txt
- python -m pip install astropy "scikit-image<0.20" scikit-learn matplotlib
- python -m pip install .
-
- - name: Run Tests
- shell: bash -l {0}
- run: |
- export PATH=/usr/share/miniconda/bin:$PATH
- pytest -n 2
diff --git a/pyproject.toml b/pyproject.toml
index dd08cbd0..9c8f4bee 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,8 @@ doc=["myst-parser==0.16.1",
"sphinx-gallery==0.11.1",
"sphinxawesome-theme==3.2.1",
"sphinxcontrib-bibtex"]
-dev=["black", "pytest<8.0.0", "pytest-cases", "pytest-cov", "pytest-sugar", "ruff"]
+dev=["black", "ruff"]
+test=["pytest<8.0.0", "pytest-cases", "pytest-cov", "pytest-sugar"]
[build-system]
requires=["setuptools", "setuptools-scm[toml]", "wheel"]
From 5fb4cb22850de137022cc61d2268f79a264286a4 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Sun, 18 Feb 2024 16:15:46 +0100
Subject: [PATCH 11/34] black 24
---
src/modopt/__init__.py | 1 -
src/modopt/base/__init__.py | 1 -
src/modopt/base/backend.py | 1 -
src/modopt/base/observable.py | 1 -
src/modopt/base/transform.py | 1 -
src/modopt/base/types.py | 1 -
src/modopt/interface/__init__.py | 1 -
src/modopt/interface/errors.py | 1 -
src/modopt/interface/log.py | 1 -
src/modopt/math/__init__.py | 1 -
src/modopt/math/convolve.py | 1 -
src/modopt/math/matrix.py | 1 -
src/modopt/math/metrics.py | 1 -
src/modopt/math/stats.py | 1 -
src/modopt/opt/__init__.py | 1 -
src/modopt/opt/gradient.py | 1 -
src/modopt/opt/proximity.py | 1 -
src/modopt/opt/reweight.py | 1 -
src/modopt/plot/__init__.py | 1 -
src/modopt/plot/cost_plot.py | 1 -
src/modopt/signal/__init__.py | 1 -
src/modopt/signal/filter.py | 1 -
src/modopt/signal/noise.py | 2 --
src/modopt/signal/positivity.py | 1 -
src/modopt/signal/svd.py | 1 -
src/modopt/signal/wavelet.py | 1 -
tests/test_algorithms.py | 1 -
27 files changed, 28 deletions(-)
diff --git a/src/modopt/__init__.py b/src/modopt/__init__.py
index 958f3ace..354c31d0 100644
--- a/src/modopt/__init__.py
+++ b/src/modopt/__init__.py
@@ -1,4 +1,3 @@
-
"""MODOPT PACKAGE.
ModOpt is a series of Modular Optimisation tools for solving inverse problems.
diff --git a/src/modopt/base/__init__.py b/src/modopt/base/__init__.py
index e7df6c37..c4c681d7 100644
--- a/src/modopt/base/__init__.py
+++ b/src/modopt/base/__init__.py
@@ -1,4 +1,3 @@
-
"""BASE ROUTINES.
This module contains submodules for basic operations such as type
diff --git a/src/modopt/base/backend.py b/src/modopt/base/backend.py
index b4987942..485f649a 100644
--- a/src/modopt/base/backend.py
+++ b/src/modopt/base/backend.py
@@ -1,4 +1,3 @@
-
"""BACKEND MODULE.
This module contains methods for GPU compatibility.
diff --git a/src/modopt/base/observable.py b/src/modopt/base/observable.py
index 69c6b238..15996dfa 100644
--- a/src/modopt/base/observable.py
+++ b/src/modopt/base/observable.py
@@ -1,4 +1,3 @@
-
"""Observable.
This module contains observable classes
diff --git a/src/modopt/base/transform.py b/src/modopt/base/transform.py
index 1dc9039a..25ed102a 100644
--- a/src/modopt/base/transform.py
+++ b/src/modopt/base/transform.py
@@ -1,4 +1,3 @@
-
"""DATA TRANSFORM ROUTINES.
This module contains methods for transforming data.
diff --git a/src/modopt/base/types.py b/src/modopt/base/types.py
index 5ed24ec3..9e9a15b9 100644
--- a/src/modopt/base/types.py
+++ b/src/modopt/base/types.py
@@ -1,4 +1,3 @@
-
"""TYPE HANDLING ROUTINES.
This module contains methods for handling object types.
diff --git a/src/modopt/interface/__init__.py b/src/modopt/interface/__init__.py
index 529816ee..a54f4bf5 100644
--- a/src/modopt/interface/__init__.py
+++ b/src/modopt/interface/__init__.py
@@ -1,4 +1,3 @@
-
"""INTERFACE ROUTINES.
This module contains submodules for error handling, logging and IO interaction.
diff --git a/src/modopt/interface/errors.py b/src/modopt/interface/errors.py
index 5c84ad0e..93e9ed1b 100644
--- a/src/modopt/interface/errors.py
+++ b/src/modopt/interface/errors.py
@@ -1,4 +1,3 @@
-
"""ERROR HANDLING ROUTINES.
This module contains methods for handling warnings and errors.
diff --git a/src/modopt/interface/log.py b/src/modopt/interface/log.py
index d3e0d8e9..50c316b7 100644
--- a/src/modopt/interface/log.py
+++ b/src/modopt/interface/log.py
@@ -1,4 +1,3 @@
-
"""LOGGING ROUTINES.
This module contains methods for handling logging.
diff --git a/src/modopt/math/__init__.py b/src/modopt/math/__init__.py
index 0423a333..d5ffc67a 100644
--- a/src/modopt/math/__init__.py
+++ b/src/modopt/math/__init__.py
@@ -1,4 +1,3 @@
-
"""MATHEMATICS ROUTINES.
This module contains submodules for mathematical applications.
diff --git a/src/modopt/math/convolve.py b/src/modopt/math/convolve.py
index ac1cf84c..21dc8b4e 100644
--- a/src/modopt/math/convolve.py
+++ b/src/modopt/math/convolve.py
@@ -1,4 +1,3 @@
-
"""CONVOLUTION ROUTINES.
This module contains methods for convolution.
diff --git a/src/modopt/math/matrix.py b/src/modopt/math/matrix.py
index a2419a6c..ef59f785 100644
--- a/src/modopt/math/matrix.py
+++ b/src/modopt/math/matrix.py
@@ -1,4 +1,3 @@
-
"""MATRIX ROUTINES.
This module contains methods for matrix operations.
diff --git a/src/modopt/math/metrics.py b/src/modopt/math/metrics.py
index 8f797f02..befd4fa4 100644
--- a/src/modopt/math/metrics.py
+++ b/src/modopt/math/metrics.py
@@ -1,4 +1,3 @@
-
"""METRICS.
This module contains classes of different metric functions for optimization.
diff --git a/src/modopt/math/stats.py b/src/modopt/math/stats.py
index b3ee0d8b..022e5f3c 100644
--- a/src/modopt/math/stats.py
+++ b/src/modopt/math/stats.py
@@ -1,4 +1,3 @@
-
"""STATISTICS ROUTINES.
This module contains methods for basic statistics.
diff --git a/src/modopt/opt/__init__.py b/src/modopt/opt/__init__.py
index 86564f90..62d1f388 100644
--- a/src/modopt/opt/__init__.py
+++ b/src/modopt/opt/__init__.py
@@ -1,4 +1,3 @@
-
"""OPTIMISATION PROBLEM MODULES.
This module contains submodules for solving optimisation problems.
diff --git a/src/modopt/opt/gradient.py b/src/modopt/opt/gradient.py
index bd214f21..3c5f0031 100644
--- a/src/modopt/opt/gradient.py
+++ b/src/modopt/opt/gradient.py
@@ -1,4 +1,3 @@
-
"""GRADIENT CLASSES.
This module contains classes for defining algorithm gradients.
diff --git a/src/modopt/opt/proximity.py b/src/modopt/opt/proximity.py
index 91a99f2a..10e69a98 100644
--- a/src/modopt/opt/proximity.py
+++ b/src/modopt/opt/proximity.py
@@ -1,4 +1,3 @@
-
"""PROXIMITY OPERATORS.
This module contains classes of proximity operators for optimisation.
diff --git a/src/modopt/opt/reweight.py b/src/modopt/opt/reweight.py
index 8d120101..b37fc6fb 100644
--- a/src/modopt/opt/reweight.py
+++ b/src/modopt/opt/reweight.py
@@ -1,4 +1,3 @@
-
"""REWEIGHTING CLASSES.
This module contains classes for reweighting optimisation implementations.
diff --git a/src/modopt/plot/__init__.py b/src/modopt/plot/__init__.py
index f6b39978..f31ed596 100644
--- a/src/modopt/plot/__init__.py
+++ b/src/modopt/plot/__init__.py
@@ -1,4 +1,3 @@
-
"""PLOTTING ROUTINES.
This module contains submodules for plotting applications.
diff --git a/src/modopt/plot/cost_plot.py b/src/modopt/plot/cost_plot.py
index 2274f35d..7fb7e39b 100644
--- a/src/modopt/plot/cost_plot.py
+++ b/src/modopt/plot/cost_plot.py
@@ -1,4 +1,3 @@
-
"""PLOTTING ROUTINES.
This module contains methods for making plots.
diff --git a/src/modopt/signal/__init__.py b/src/modopt/signal/__init__.py
index 2aee1987..6bf0912b 100644
--- a/src/modopt/signal/__init__.py
+++ b/src/modopt/signal/__init__.py
@@ -1,4 +1,3 @@
-
"""SIGNAL PROCESSING ROUTINES.
This module contains submodules for signal processing.
diff --git a/src/modopt/signal/filter.py b/src/modopt/signal/filter.py
index 0e50d28f..33c3c105 100644
--- a/src/modopt/signal/filter.py
+++ b/src/modopt/signal/filter.py
@@ -1,4 +1,3 @@
-
"""FILTER ROUTINES.
This module contains methods for filtering data.
diff --git a/src/modopt/signal/noise.py b/src/modopt/signal/noise.py
index b43a0b61..2594fc62 100644
--- a/src/modopt/signal/noise.py
+++ b/src/modopt/signal/noise.py
@@ -1,4 +1,3 @@
-
"""NOISE ROUTINES.
This module contains methods for adding and removing noise from data.
@@ -7,7 +6,6 @@
"""
-
import numpy as np
from modopt.base.backend import get_array_module
diff --git a/src/modopt/signal/positivity.py b/src/modopt/signal/positivity.py
index f3f312d3..8d7aa46c 100644
--- a/src/modopt/signal/positivity.py
+++ b/src/modopt/signal/positivity.py
@@ -1,4 +1,3 @@
-
"""POSITIVITY.
This module contains a function that retains only positive coefficients in
diff --git a/src/modopt/signal/svd.py b/src/modopt/signal/svd.py
index dd080306..cf147503 100644
--- a/src/modopt/signal/svd.py
+++ b/src/modopt/signal/svd.py
@@ -1,4 +1,3 @@
-
"""SVD ROUTINES.
This module contains methods for thresholding singular values.
diff --git a/src/modopt/signal/wavelet.py b/src/modopt/signal/wavelet.py
index d624db3a..b55b78d9 100644
--- a/src/modopt/signal/wavelet.py
+++ b/src/modopt/signal/wavelet.py
@@ -1,4 +1,3 @@
-
"""WAVELET MODULE.
This module contains methods for performing wavelet transformations using
diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py
index c1e676a5..fe5f92a6 100644
--- a/tests/test_algorithms.py
+++ b/tests/test_algorithms.py
@@ -1,4 +1,3 @@
-
"""UNIT TESTS FOR Algorithms.
This module contains unit tests for the modopt.opt module.
From 44e7cf4a6c52ddf41fbfa2ac2aadcfc1d254d1ff Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 10:39:12 +0100
Subject: [PATCH 12/34] update ruff config.
---
pyproject.toml | 9 +++++++--
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index 9c8f4bee..835dae3d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,10 +36,15 @@ exclude_lines = ["pragma: no cover", "raise NotImplementedError"]
[tool.black]
-[lint]
+
+[tool.ruff]
+exclude = ["examples", "docs"]
+[tool.ruff.lint]
select = ["E", "F", "B", "Q", "UP", "D", "NPY", "RUF"]
-[lint.pydocstyle]
+ignore = ["F401"] # we like the try: import ... expect: ...
+
+[tool.ruff.lint.pydocstyle]
convention="numpy"
[tool.isort]
From 67342e79d4a71917fa8480b8919fd61eb66094bb Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 10:39:18 +0100
Subject: [PATCH 13/34] fix: F403
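F403 comes from the ``from modopt.base import *`` wildcard, which hides which
names actually land in the package namespace. The fix imports the submodules
explicitly and declares ``__all__``; a short sketch of the pattern (mirroring
the change below):

    # before: implicit re-exports, linters cannot verify the names
    # from modopt.base import *

    # after: the public surface is spelled out
    from modopt.base import np_adjust, transform, types, wrappers, observable

    __all__ = ["np_adjust", "transform", "types", "wrappers", "observable"]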
---
src/modopt/__init__.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/src/modopt/__init__.py b/src/modopt/__init__.py
index 354c31d0..e8ae1c3c 100644
--- a/src/modopt/__init__.py
+++ b/src/modopt/__init__.py
@@ -8,7 +8,9 @@
from importlib_metadata import version
-from modopt.base import *
+from modopt.base import np_adjust, transform, types, wrappers, observable
+
+__all__ = ["np_adjust", "transform", "types", "wrappers", "observable"]
try:
_version = version("modopt")
From 14542cc81e942d077a937afe101ba638cd8f46b2 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 10:41:54 +0100
Subject: [PATCH 14/34] fix: D1** errors.
---
src/modopt/opt/linear/wavelet.py | 1 +
tests/test_helpers/utils.py | 3 +++
tests/test_opt.py | 5 +++++
3 files changed, 9 insertions(+)
diff --git a/src/modopt/opt/linear/wavelet.py b/src/modopt/opt/linear/wavelet.py
index 8dc44fd3..9fb64b33 100644
--- a/src/modopt/opt/linear/wavelet.py
+++ b/src/modopt/opt/linear/wavelet.py
@@ -110,6 +110,7 @@ def __init__(
@property
def coeffs_shape(self):
+ """Get the coeffs shapes."""
return self.operator.coeffs_shape
diff --git a/tests/test_helpers/utils.py b/tests/test_helpers/utils.py
index 049d257f..41f948a6 100644
--- a/tests/test_helpers/utils.py
+++ b/tests/test_helpers/utils.py
@@ -1,5 +1,6 @@
"""
Some helper functions for the test parametrization.
+
They should be used inside ``@pytest.mark.parametrize`` call.
:Author: Pierre-Antoine Comby
@@ -21,4 +22,6 @@ def skipparam(*args, cond=True, reason=""):
class Dummy:
+ """Dummy Class."""
+
pass
diff --git a/tests/test_opt.py b/tests/test_opt.py
index e31d3a49..0a73e835 100644
--- a/tests/test_opt.py
+++ b/tests/test_opt.py
@@ -39,18 +39,22 @@
# Basic functions to be used as operators or as dummy functions
def func_identity(x_val):
+ """Return x."""
return x_val
def func_double(x_val):
+ """Double x."""
return x_val * 2
def func_sq(x_val):
+ """Square x."""
return x_val**2
def func_cube(x_val):
+ """Cube x."""
return x_val**3
@@ -209,6 +213,7 @@ def case_linear_wavelet_convolve(self):
]
)
def case_linear_wavelet_transform(self, compute_backend):
+ """Case linear wavelet operator."""
linop = linear.WaveletTransform(
wavelet_name="haar",
shape=(8, 8),
From 99208cd2362247a0a92ffb6ddcb823b80b8dbbea Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 10:51:44 +0100
Subject: [PATCH 15/34] fix: E501
---
src/modopt/math/stats.py | 4 ++--
src/modopt/opt/algorithms/admm.py | 10 ++++++----
src/modopt/opt/linear/wavelet.py | 8 +++++---
3 files changed, 13 insertions(+), 9 deletions(-)
diff --git a/src/modopt/math/stats.py b/src/modopt/math/stats.py
index 022e5f3c..8583a8c3 100644
--- a/src/modopt/math/stats.py
+++ b/src/modopt/math/stats.py
@@ -29,8 +29,8 @@ def gaussian_kernel(data_shape, sigma, norm="max"):
        Desired shape of the kernel
sigma : float
Standard deviation of the kernel
- norm : {'max', 'sum'}, optional
- Normalisation of the kerenl (options are ``'max'`` or ``'sum'``, default is ``'max'``)
+ norm : {'max', 'sum'}, optional, default='max'
+ Normalisation of the kernel
Returns
-------
diff --git a/src/modopt/opt/algorithms/admm.py b/src/modopt/opt/algorithms/admm.py
index 4fd8074e..b2f45171 100644
--- a/src/modopt/opt/algorithms/admm.py
+++ b/src/modopt/opt/algorithms/admm.py
@@ -68,7 +68,8 @@ def _calc_cost(self, u, v, **kwargs):
class ADMM(SetUp):
r"""Fast ADMM Optimisation Algorihm.
- This class implement the ADMM algorithm described in :cite:`Goldstein2014` (Algorithm 1).
+    This class implements the ADMM algorithm described in :cite:`Goldstein2014`
+ (Algorithm 1).
Parameters
----------
@@ -86,7 +87,7 @@ class ADMM(SetUp):
Constraint vector
optimizers: tuple
        2-tuple of callables that are the optimizers for the u and v.
- Each callable should access an init and obs argument and returns an estimate for:
+        Each callable should accept init and obs arguments and return an estimate for:
.. math:: u_{k+1} = \argmin H(u) + \frac{\tau}{2}\|A u - y\|^2
.. math:: v_{k+1} = \argmin G(v) + \frac{\tau}{2}\|Bv - y \|^2
cost_funcs: tuple
@@ -243,7 +244,7 @@ class FastADMM(ADMM):
Constraint vector
optimizers: tuple
        2-tuple of callables that are the optimizers for the u and v.
- Each callable should access an init and obs argument and returns an estimate for:
+        Each callable should accept init and obs arguments and return an estimate for:
.. math:: u_{k+1} = \argmin H(u) + \frac{\tau}{2}\|A u - y\|^2
.. math:: v_{k+1} = \argmin G(v) + \frac{\tau}{2}\|Bv - y \|^2
cost_funcs: tuple
@@ -257,7 +258,8 @@ class FastADMM(ADMM):
Notes
-----
- This is an accelerated version of the ADMM algorithm. The convergence hypothesis are stronger than for the ADMM algorithm.
+    This is an accelerated version of the ADMM algorithm. The convergence hypotheses are
+ stronger than for the ADMM algorithm.
See Also
--------
diff --git a/src/modopt/opt/linear/wavelet.py b/src/modopt/opt/linear/wavelet.py
index 9fb64b33..ae92efa7 100644
--- a/src/modopt/opt/linear/wavelet.py
+++ b/src/modopt/opt/linear/wavelet.py
@@ -65,7 +65,7 @@ class WaveletTransform(LinearParent):
"""
2D and 3D wavelet transform class.
- This is a wrapper around either Pywavelet (CPU) or Pytorch Wavelet (GPU using Pytorch).
+ This is a wrapper around either Pywavelet (CPU) or Pytorch Wavelet (GPU).
Parameters
----------
@@ -79,7 +79,8 @@ class WaveletTransform(LinearParent):
mode: str, default "zero"
Boundary Condition mode
compute_backend: str, "numpy" or "cupy", default "numpy"
- Backend library to use. "cupy" also requires a working installation of PyTorch and pytorch wavelets.
+ Backend library to use. "cupy" also requires a working installation of PyTorch
+ and PyTorch wavelets (ptwt).
**kwargs: extra kwargs for Pywavelet or Pytorch Wavelet
"""
@@ -165,7 +166,8 @@ def __init__(
):
if not pywt_available:
raise ImportError(
- "PyWavelet and/or joblib are not available. Please install it to use WaveletTransform."
+ "PyWavelet and/or joblib are not available. "
+ "Please install it to use WaveletTransform."
)
if wavelet_name not in pywt.wavelist(kind="all"):
raise ValueError(
From fdd6d1ee283bd2fd2c7c6eaa0793615225bd1524 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 11:29:24 +0100
Subject: [PATCH 16/34] fix: NPY002
Replace legacy NumPy random calls with Generator objects.
This adds the possibility to pass a seeded rng properly throughout ModOpt.
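A minimal usage sketch of the new ``rng`` keyword (the data values are
illustrative only):

    import numpy as np
    from modopt.signal.noise import add_noise

    rng = np.random.default_rng(1)  # seeded Generator, reproducible noise
    data = np.arange(9).reshape(3, 3).astype(float)
    noisy = add_noise(data, sigma=2.0, rng=rng)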
---
src/modopt/math/matrix.py | 7 ++++++-
src/modopt/signal/noise.py | 12 +++++++++---
src/modopt/signal/validation.py | 9 +++++++--
tests/test_algorithms.py | 5 ++++-
tests/test_math.py | 16 +++++-----------
tests/test_signal.py | 15 ++++++++-------
6 files changed, 39 insertions(+), 25 deletions(-)
diff --git a/src/modopt/math/matrix.py b/src/modopt/math/matrix.py
index ef59f785..e52cbfc5 100644
--- a/src/modopt/math/matrix.py
+++ b/src/modopt/math/matrix.py
@@ -274,6 +274,8 @@ class PowerMethod:
initialisation (default is ``True``)
verbose : bool, optional
Optional verbosity (default is ``False``)
+ rng: int, xp.random.Generator or None (default is ``None``)
+ Random number generator or seed.
Examples
--------
@@ -300,6 +302,7 @@ def __init__(
auto_run=True,
compute_backend="numpy",
verbose=False,
+ rng=None,
):
self._operator = operator
@@ -308,6 +311,7 @@ def __init__(
self._verbose = verbose
xp, compute_backend = get_backend(compute_backend)
self.xp = xp
+        self.rng = rng
self.compute_backend = compute_backend
if auto_run:
self.get_spec_rad()
@@ -324,7 +328,8 @@ def _set_initial_x(self):
Random values of the same shape as the input data
"""
- return self.xp.random.random(self._data_shape).astype(self._data_type)
+ rng = self.xp.random.default_rng(self.rng)
+ return rng.random(self._data_shape).astype(self._data_type)
def get_spec_rad(self, tolerance=1e-6, max_iter=20, extra_factor=1.0):
"""Get spectral radius.
diff --git a/src/modopt/signal/noise.py b/src/modopt/signal/noise.py
index 2594fc62..28307f52 100644
--- a/src/modopt/signal/noise.py
+++ b/src/modopt/signal/noise.py
@@ -11,7 +11,7 @@
from modopt.base.backend import get_array_module
-def add_noise(input_data, sigma=1.0, noise_type="gauss"):
+def add_noise(input_data, sigma=1.0, noise_type="gauss", rng=None):
"""Add noise to data.
This method adds Gaussian or Poisson noise to the input data.
@@ -25,6 +25,9 @@ def add_noise(input_data, sigma=1.0, noise_type="gauss"):
default is ``1.0``)
noise_type : {'gauss', 'poisson'}
Type of noise to be added (default is ``'gauss'``)
+ rng: np.random.Generator or int
+ A Random number generator or a seed to initialize one.
+
Returns
-------
@@ -64,6 +67,9 @@ def add_noise(input_data, sigma=1.0, noise_type="gauss"):
array([ 3.24869073, -1.22351283, -1.0563435 , -2.14593724, 1.73081526])
"""
+ if not isinstance(rng, np.random.Generator):
+ rng = np.random.default_rng(rng)
+
input_data = np.array(input_data)
if noise_type not in {"gauss", "poisson"}:
@@ -78,10 +84,10 @@ def add_noise(input_data, sigma=1.0, noise_type="gauss"):
)
if noise_type == "gauss":
- random = np.random.randn(*input_data.shape)
+ random = rng.standard_normal(input_data.shape)
elif noise_type == "poisson":
- random = np.random.poisson(np.abs(input_data))
+ random = rng.poisson(np.abs(input_data))
if isinstance(sigma, (int, float)):
return input_data + sigma * random
diff --git a/src/modopt/signal/validation.py b/src/modopt/signal/validation.py
index 68c1e726..66485a54 100644
--- a/src/modopt/signal/validation.py
+++ b/src/modopt/signal/validation.py
@@ -16,6 +16,7 @@ def transpose_test(
x_args=None,
y_shape=None,
y_args=None,
+ rng=None,
):
"""Transpose test.
@@ -36,6 +37,8 @@ def transpose_test(
Shape of transpose operator input data (default is ``None``)
y_args : tuple, optional
Arguments to be passed to transpose operator (default is ``None``)
+ rng: np.random.Generator or int or None (default is ``None``)
+ Initialized random number generator or seed.
Raises
------
@@ -62,9 +65,11 @@ def transpose_test(
if isinstance(y_args, type(None)):
y_args = x_args
+ if not isinstance(rng, np.random.Generator):
+ rng = np.random.default_rng(rng)
# Generate random arrays.
- x_val = np.random.ranf(x_shape)
- y_val = np.random.ranf(y_shape)
+ x_val = rng.random(x_shape)
+ y_val = rng.random(y_shape)
# Calculate
mx_y = np.sum(np.multiply(operator(x_val, x_args), y_val))
diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py
index fe5f92a6..63847764 100644
--- a/tests/test_algorithms.py
+++ b/tests/test_algorithms.py
@@ -24,6 +24,9 @@
SKLEARN_AVAILABLE = False
+rng = np.random.default_rng()
+
+
@fixture
def idty():
"""Identity function."""
@@ -84,7 +87,7 @@ class AlgoCases:
"""
data1 = np.arange(9).reshape(3, 3).astype(float)
- data2 = data1 + np.random.randn(*data1.shape) * 1e-6
+ data2 = data1 + rng.standard_normal(data1.shape) * 1e-6
max_iter = 20
@parametrize(
diff --git a/tests/test_math.py b/tests/test_math.py
index e7536b03..37b71b1e 100644
--- a/tests/test_math.py
+++ b/tests/test_math.py
@@ -29,6 +29,8 @@
else:
SKIMAGE_AVAILABLE = True
+rng = np.random.default_rng(1)
+
class TestConvolve:
"""Test convolve functions."""
@@ -136,18 +138,15 @@ class TestMatrix:
),
)
- @pytest.fixture
+ @pytest.fixture(scope="module")
def pm_instance(self, request):
"""Power Method instance."""
- np.random.seed(1)
pm = matrix.PowerMethod(
lambda x_val: x_val.dot(x_val.T),
self.array33.shape,
- auto_run=request.param,
verbose=True,
+ rng=np.random.default_rng(0),
)
- if not request.param:
- pm.get_spec_rad(max_iter=1)
return pm
@pytest.mark.parametrize(
@@ -195,12 +194,7 @@ def test_rotate(self):
npt.assert_raises(ValueError, matrix.rotate, self.array23, np.pi / 2)
- @pytest.mark.parametrize(
- ("pm_instance", "value"),
- [(True, 1.0), (False, 0.8675467477372257)],
- indirect=["pm_instance"],
- )
- def test_power_method(self, pm_instance, value):
+ def test_power_method(self, pm_instance, value=1):
"""Test power method."""
npt.assert_almost_equal(pm_instance.spec_rad, value)
npt.assert_almost_equal(pm_instance.inv_spec_rad, 1 / value)
diff --git a/tests/test_signal.py b/tests/test_signal.py
index cdd95277..6dbb0bba 100644
--- a/tests/test_signal.py
+++ b/tests/test_signal.py
@@ -45,13 +45,13 @@ class TestNoise:
data1 = np.arange(9).reshape(3, 3).astype(float)
data2 = np.array(
- [[0, 2.0, 2.0], [4.0, 5.0, 10], [11.0, 15.0, 18.0]],
+ [[0, 3.0, 4.0], [6.0, 9.0, 8.0], [14.0, 14.0, 17.0]],
)
data3 = np.array(
[
- [1.62434536, 0.38824359, 1.47182825],
- [1.92703138, 4.86540763, 2.6984613],
- [7.74481176, 6.2387931, 8.3190391],
+ [0.3455842, 1.8216181, 2.3304371],
+ [1.6968428, 4.9053559, 5.4463746],
+ [5.4630468, 7.5811181, 8.3645724],
]
)
data4 = np.array([[0, 0, 0], [0, 0, 5.0], [6.0, 7.0, 8.0]])
@@ -70,9 +70,10 @@ class TestNoise:
)
def test_add_noise(self, data, noise_type, sigma, data_noise):
"""Test add_noise."""
- np.random.seed(1)
+ rng = np.random.default_rng(1)
npt.assert_almost_equal(
- noise.add_noise(data, sigma=sigma, noise_type=noise_type), data_noise
+ noise.add_noise(data, sigma=sigma, noise_type=noise_type, rng=rng),
+ data_noise,
)
@pytest.mark.parametrize(
@@ -241,13 +242,13 @@ class TestValidation:
def test_transpose_test(self):
"""Test transpose_test."""
- np.random.seed(2)
npt.assert_equal(
validation.transpose_test(
lambda x_val, y_val: x_val.dot(y_val),
lambda x_val, y_val: x_val.dot(y_val.T),
self.array33.shape,
x_args=self.array33,
+ rng=2,
),
None,
)
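Background for the ``rng`` arguments introduced above: ``np.random.default_rng`` accepts ``None`` (fresh entropy), an integer seed, or an already-constructed ``Generator`` and returns a ``Generator`` in every case, so callers can pass either a seed or a shared generator. A minimal standalone sketch of the pattern (illustrative, not ModOpt code):

    import numpy as np

    def make_rng(rng=None):
        # None -> unseeded Generator, int -> seeded Generator,
        # Generator -> returned as-is so a caller can share one instance.
        if not isinstance(rng, np.random.Generator):
            rng = np.random.default_rng(rng)
        return rng

    assert make_rng(42).random() == make_rng(42).random()  # seeds reproduce
    gen = np.random.default_rng(0)
    assert make_rng(gen) is gen                             # generators pass through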
From 22c016fcc81b18af0794dd07fb144ed8409b2566 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 11:31:50 +0100
Subject: [PATCH 17/34] fix: RUF012
Mutable class attributes; we don't want type annotations.
---
src/modopt/opt/linear/wavelet.py | 2 +-
tests/test_base.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/modopt/opt/linear/wavelet.py b/src/modopt/opt/linear/wavelet.py
index ae92efa7..ff434287 100644
--- a/src/modopt/opt/linear/wavelet.py
+++ b/src/modopt/opt/linear/wavelet.py
@@ -281,7 +281,7 @@ def _adj_op(self, coeffs):
class TorchWaveletTransform:
"""Wavelet transform using pytorch."""
- wavedec3_keys = ["aad", "ada", "add", "daa", "dad", "dda", "ddd"]
+ wavedec3_keys = ("aad", "ada", "add", "daa", "dad", "dda", "ddd")
def __init__(
self,
diff --git a/tests/test_base.py b/tests/test_base.py
index 298253d6..62e09095 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -149,7 +149,7 @@ def test_matrix2cube(self):
class TestType:
"""Test for type module."""
- data_list = list(range(5))
+ data_list = list(range(5)) # noqa: RUF012
data_int = np.arange(5)
data_flt = np.arange(5).astype(float)
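RUF012 flags mutable values assigned to class attributes. The usual remedy is a ``ClassVar`` annotation, but since this project is dropping annotations, the patch instead makes the attribute immutable (a tuple) or suppresses the rule with ``# noqa``. Both options, purely illustrative:

    from typing import ClassVar, List

    class Annotated:
        # what RUF012 would normally suggest
        keys: ClassVar[List[str]] = ["aad", "ada", "add"]

    class AnnotationFree:
        # the route taken in this patch: an immutable tuple needs no annotation
        keys = ("aad", "ada", "add")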
From 59f58dddf5132829a115d7b89945321ea864ee0b Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 12:17:56 +0100
Subject: [PATCH 18/34] fix: B026
---
src/modopt/opt/cost.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/modopt/opt/cost.py b/src/modopt/opt/cost.py
index 4bead130..5c5701b1 100644
--- a/src/modopt/opt/cost.py
+++ b/src/modopt/opt/cost.py
@@ -187,7 +187,7 @@ def get_cost(self, *args, **kwargs):
print(" - ITERATION:", self._iteration)
# Calculate the current cost
- self.cost = self._calc_cost(verbose=self._verbose, *args, **kwargs)
+ self.cost = self._calc_cost(*args, verbose=self._verbose, **kwargs)
self._cost_list.append(self.cost)
if self._verbose:
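B026 targets star-argument unpacking placed after a keyword argument. The call is legal, but the unpacked values still bind positionally, so they can collide with the keyword; putting ``*args`` first keeps the binding order obvious. A contrived example with a hypothetical function:

    def calc_cost(data, verbose=False):
        return data, verbose

    args = ("obs",)
    calc_cost(verbose=True, *args)  # B026: legal, but a second element in args
                                    # would raise "multiple values for 'verbose'"
    calc_cost(*args, verbose=True)  # preferred: unpack first, keywords after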
From b9827ee38bbac3ff7b6d1d899491fc556d52fca6 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 12:19:45 +0100
Subject: [PATCH 19/34] fix: B028
---
src/modopt/__init__.py | 1 +
src/modopt/interface/errors.py | 2 +-
src/modopt/opt/linear/wavelet.py | 4 +++-
3 files changed, 5 insertions(+), 2 deletions(-)
diff --git a/src/modopt/__init__.py b/src/modopt/__init__.py
index e8ae1c3c..e93e45ff 100644
--- a/src/modopt/__init__.py
+++ b/src/modopt/__init__.py
@@ -19,6 +19,7 @@
warn(
"Could not extract package metadata. Make sure the package is "
+ "correctly installed.",
+ stacklevel=1,
)
__version__ = _version
diff --git a/src/modopt/interface/errors.py b/src/modopt/interface/errors.py
index 93e9ed1b..84031e3c 100644
--- a/src/modopt/interface/errors.py
+++ b/src/modopt/interface/errors.py
@@ -41,7 +41,7 @@ def warn(warn_string, log=None):
# Check if a logging structure is provided.
if not isinstance(log, type(None)):
- warnings.warn(warn_string)
+ warnings.warn(warn_string, stacklevel=2)
def catch_error(exception, log=None):
diff --git a/src/modopt/opt/linear/wavelet.py b/src/modopt/opt/linear/wavelet.py
index ff434287..fa450ba2 100644
--- a/src/modopt/opt/linear/wavelet.py
+++ b/src/modopt/opt/linear/wavelet.py
@@ -201,7 +201,9 @@ def __init__(
self.n_batch = n_batch
if self.n_batch == 1 and self.n_jobs != 1:
- warnings.warn("Making n_jobs = 1 for WaveletTransform as n_batchs = 1")
+ warnings.warn(
+ "Making n_jobs = 1 for WaveletTransform as n_batchs = 1", stacklevel=1
+ )
self.n_jobs = 1
self.backend = backend
n_proc = self.n_jobs
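B028 asks every ``warnings.warn`` call to pass an explicit ``stacklevel``. With the default ``stacklevel=1`` the warning is attributed to the ``warn()`` line inside the library; ``stacklevel=2`` points it at the caller instead, which is usually what users need. A small hypothetical illustration:

    import warnings

    def old_api():
        # stacklevel=2 -> the reported file/line is the user's call, not this one
        warnings.warn("old_api() is deprecated, use new_api()",
                      DeprecationWarning, stacklevel=2)

    def user_code():
        old_api()  # the emitted warning will point at this line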
From 0811b7a9582a289ba4f9b24ba12a8257e2528bbf Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 12:20:37 +0100
Subject: [PATCH 20/34] fix: ruff format
---
src/modopt/base/observable.py | 2 --
src/modopt/math/matrix.py | 2 --
src/modopt/opt/algorithms/base.py | 2 --
src/modopt/opt/algorithms/forward_backward.py | 4 ----
src/modopt/opt/algorithms/primal_dual.py | 1 -
src/modopt/opt/cost.py | 2 --
src/modopt/opt/gradient.py | 8 --------
src/modopt/opt/linear/base.py | 6 ------
src/modopt/opt/linear/wavelet.py | 2 --
src/modopt/opt/proximity.py | 13 -------------
src/modopt/opt/reweight.py | 1 -
11 files changed, 43 deletions(-)
diff --git a/src/modopt/base/observable.py b/src/modopt/base/observable.py
index 15996dfa..bf8371c3 100644
--- a/src/modopt/base/observable.py
+++ b/src/modopt/base/observable.py
@@ -31,7 +31,6 @@ class Observable:
"""
def __init__(self, signals):
-
# Define class parameters
self._allowed_signals = []
self._observers = {}
@@ -213,7 +212,6 @@ def __init__(
wind=6,
eps=1.0e-3,
):
-
self.name = name
self.metric = metric
self.mapping = mapping
diff --git a/src/modopt/math/matrix.py b/src/modopt/math/matrix.py
index e52cbfc5..b200f15d 100644
--- a/src/modopt/math/matrix.py
+++ b/src/modopt/math/matrix.py
@@ -63,7 +63,6 @@ def gram_schmidt(matrix, return_opt="orthonormal"):
e_vec = []
for vector in matrix:
-
if u_vec:
u_now = vector - sum(project(u_i, vector) for u_i in u_vec)
else:
@@ -304,7 +303,6 @@ def __init__(
verbose=False,
rng=None,
):
-
self._operator = operator
self._data_shape = data_shape
self._data_type = data_type
diff --git a/src/modopt/opt/algorithms/base.py b/src/modopt/opt/algorithms/base.py
index dbb73be0..f7391063 100644
--- a/src/modopt/opt/algorithms/base.py
+++ b/src/modopt/opt/algorithms/base.py
@@ -110,7 +110,6 @@ def metrics(self):
@metrics.setter
def metrics(self, metrics):
-
if isinstance(metrics, type(None)):
self._metrics = {}
elif isinstance(metrics, dict):
@@ -271,7 +270,6 @@ def _iterations(self, max_iter, progbar=None):
# We do not call metrics if metrics is empty or metric call
# period is None
if self.metrics and self.metric_call_period is not None:
-
metric_conditions = (
self.idx % self.metric_call_period == 0
or self.idx == (max_iter - 1)
diff --git a/src/modopt/opt/algorithms/forward_backward.py b/src/modopt/opt/algorithms/forward_backward.py
index 4c1cb35c..31927eb0 100644
--- a/src/modopt/opt/algorithms/forward_backward.py
+++ b/src/modopt/opt/algorithms/forward_backward.py
@@ -72,7 +72,6 @@ def __init__(
r_lazy=4,
**kwargs,
):
-
if isinstance(a_cd, type(None)):
self.mode = "regular"
self.p_lazy = p_lazy
@@ -355,7 +354,6 @@ def __init__(
linear=None,
**kwargs,
):
-
# Set default algorithm properties
super().__init__(
metric_call_period=metric_call_period,
@@ -588,7 +586,6 @@ def __init__(
linear=None,
**kwargs,
):
-
# Set default algorithm properties
super().__init__(
metric_call_period=metric_call_period,
@@ -875,7 +872,6 @@ def __init__(
metrics=None,
**kwargs,
):
-
# Set default algorithm properties
super().__init__(
metric_call_period=metric_call_period,
diff --git a/src/modopt/opt/algorithms/primal_dual.py b/src/modopt/opt/algorithms/primal_dual.py
index 24908993..fee49a25 100644
--- a/src/modopt/opt/algorithms/primal_dual.py
+++ b/src/modopt/opt/algorithms/primal_dual.py
@@ -95,7 +95,6 @@ def __init__(
metrics=None,
**kwargs,
):
-
# Set default algorithm properties
super().__init__(
metric_call_period=metric_call_period,
diff --git a/src/modopt/opt/cost.py b/src/modopt/opt/cost.py
index 5c5701b1..37771f16 100644
--- a/src/modopt/opt/cost.py
+++ b/src/modopt/opt/cost.py
@@ -81,7 +81,6 @@ def __init__(
verbose=True,
plot_output=None,
):
-
self.cost = initial_cost
self._cost_list = []
self._cost_interval = cost_interval
@@ -112,7 +111,6 @@ def _check_cost(self):
# Check if enough cost values have been collected
if len(self._test_list) == self._test_range:
-
# The mean of the first half of the test list
t1 = xp.mean(
xp.array(self._test_list[len(self._test_list) // 2 :]),
diff --git a/src/modopt/opt/gradient.py b/src/modopt/opt/gradient.py
index 3c5f0031..fe9b87d8 100644
--- a/src/modopt/opt/gradient.py
+++ b/src/modopt/opt/gradient.py
@@ -69,7 +69,6 @@ def __init__(
input_data_writeable=False,
verbose=True,
):
-
self.verbose = verbose
self._input_data_writeable = input_data_writeable
self._grad_data_type = data_type
@@ -98,7 +97,6 @@ def obs_data(self):
@obs_data.setter
def obs_data(self, input_data):
-
if self._grad_data_type in {float, np.floating}:
input_data = check_float(input_data)
check_npndarray(
@@ -126,7 +124,6 @@ def op(self):
@op.setter
def op(self, operator):
-
self._op = check_callable(operator)
@property
@@ -145,7 +142,6 @@ def trans_op(self):
@trans_op.setter
def trans_op(self, operator):
-
self._trans_op = check_callable(operator)
@property
@@ -155,7 +151,6 @@ def get_grad(self):
@get_grad.setter
def get_grad(self, method):
-
self._get_grad = check_callable(method)
@property
@@ -165,7 +160,6 @@ def grad(self):
@grad.setter
def grad(self, input_value):
-
if self._grad_data_type in {float, np.floating}:
input_value = check_float(input_value)
self._grad = input_value
@@ -177,7 +171,6 @@ def cost(self):
@cost.setter
def cost(self, method):
-
self._cost = check_callable(method)
def trans_op_op(self, input_data):
@@ -241,7 +234,6 @@ class GradBasic(GradParent):
"""
def __init__(self, *args, **kwargs):
-
super().__init__(*args, **kwargs)
self.get_grad = self._get_grad_method
self.cost = self._cost_method
diff --git a/src/modopt/opt/linear/base.py b/src/modopt/opt/linear/base.py
index 9fa35187..af748a73 100644
--- a/src/modopt/opt/linear/base.py
+++ b/src/modopt/opt/linear/base.py
@@ -30,7 +30,6 @@ class LinearParent:
"""
def __init__(self, op, adj_op):
-
self.op = op
self.adj_op = adj_op
@@ -41,7 +40,6 @@ def op(self):
@op.setter
def op(self, operator):
-
self._op = check_callable(operator)
@property
@@ -51,7 +49,6 @@ def adj_op(self):
@adj_op.setter
def adj_op(self, operator):
-
self._adj_op = check_callable(operator)
@@ -67,7 +64,6 @@ class Identity(LinearParent):
"""
def __init__(self):
-
self.op = lambda input_data: input_data
self.adj_op = self.op
self.cost = lambda *args, **kwargs: 0
@@ -127,7 +123,6 @@ class LinearCombo(LinearParent):
"""
def __init__(self, operators, weights=None):
-
operators, weights = self._check_inputs(operators, weights)
self.operators = operators
self.weights = weights
@@ -199,7 +194,6 @@ def _check_inputs(self, operators, weights):
operators = self._check_type(operators)
for operator in operators:
-
if not hasattr(operator, "op"):
raise ValueError('Operators must contain "op" method.')
diff --git a/src/modopt/opt/linear/wavelet.py b/src/modopt/opt/linear/wavelet.py
index fa450ba2..e554150e 100644
--- a/src/modopt/opt/linear/wavelet.py
+++ b/src/modopt/opt/linear/wavelet.py
@@ -46,7 +46,6 @@ class WaveletConvolve(LinearParent):
"""
def __init__(self, filters, method="scipy"):
-
self._filters = check_float(filters)
self.op = lambda input_data: filter_convolve_stack(
input_data,
@@ -94,7 +93,6 @@ def __init__(
compute_backend="numpy",
**kwargs,
):
-
if compute_backend == "cupy" and ptwt_available:
self.operator = CupyWaveletTransform(
wavelet=wavelet_name, shape=shape, level=level, mode=mode
diff --git a/src/modopt/opt/proximity.py b/src/modopt/opt/proximity.py
index 10e69a98..204a168d 100644
--- a/src/modopt/opt/proximity.py
+++ b/src/modopt/opt/proximity.py
@@ -46,7 +46,6 @@ class ProximityParent:
"""
def __init__(self, op, cost):
-
self.op = op
self.cost = cost
@@ -57,7 +56,6 @@ def op(self):
@op.setter
def op(self, operator):
-
self._op = check_callable(operator)
@property
@@ -77,7 +75,6 @@ def cost(self):
@cost.setter
def cost(self, method):
-
self._cost = check_callable(method)
@@ -97,7 +94,6 @@ class IdentityProx(ProximityParent):
"""
def __init__(self):
-
self.op = lambda x_val: x_val
self.cost = lambda x_val: 0
@@ -115,7 +111,6 @@ class Positivity(ProximityParent):
"""
def __init__(self):
-
self.op = lambda input_data: positive(input_data)
self.cost = self._cost_method
@@ -166,7 +161,6 @@ class SparseThreshold(ProximityParent):
"""
def __init__(self, linear, weights, thresh_type="soft"):
-
self._linear = linear
self.weights = weights
self._thresh_type = thresh_type
@@ -276,7 +270,6 @@ def __init__(
initial_rank=None,
operator=None,
):
-
self.thresh = threshold
self.thresh_type = thresh_type
self.lowr_type = lowr_type
@@ -468,7 +461,6 @@ class ProximityCombo(ProximityParent):
"""
def __init__(self, operators):
-
operators = self._check_operators(operators)
self.operators = operators
self.op = self._op_method
@@ -737,7 +729,6 @@ class Ridge(ProximityParent):
"""
def __init__(self, linear, weights, thresh_type="soft"):
-
self._linear = linear
self.weights = weights
self.op = self._op_method
@@ -824,7 +815,6 @@ class ElasticNet(ProximityParent):
"""
def __init__(self, linear, alpha, beta):
-
self._linear = linear
self.alpha = alpha
self.beta = beta
@@ -1080,12 +1070,10 @@ def _binary_search(self, input_data, alpha, extra_factor=1.0):
midpoint = 0
while (first_idx <= last_idx) and not found and (cnt < alpha.shape[0]):
-
midpoint = (first_idx + last_idx) // 2
cnt += 1
if prev_midpoint == midpoint:
-
# Particular case
sum0 = self._compute_theta(
data_abs,
@@ -1287,7 +1275,6 @@ def _find_q(self, sorted_data):
and not cnt == self._k_value
and (first_idx <= last_idx < self._k_value)
):
-
q_val = (first_idx + last_idx) // 2
cnt += 1
l1_part = sorted_data[q_val:].sum() / (self._k_value - q_val)
diff --git a/src/modopt/opt/reweight.py b/src/modopt/opt/reweight.py
index b37fc6fb..4a9bf44b 100644
--- a/src/modopt/opt/reweight.py
+++ b/src/modopt/opt/reweight.py
@@ -43,7 +43,6 @@ class cwbReweight:
"""
def __init__(self, weights, thresh_factor=1.0, verbose=False):
-
self.weights = check_float(weights)
self.original_weights = np.copy(self.weights)
self.thresh_factor = check_float(thresh_factor)
From 067e29f18a67ecb33c067538ee6245df75e93688 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 13:59:01 +0100
Subject: [PATCH 21/34] fix: NPY002
---
tests/test_opt.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/tests/test_opt.py b/tests/test_opt.py
index 0a73e835..039cb9fd 100644
--- a/tests/test_opt.py
+++ b/tests/test_opt.py
@@ -36,6 +36,8 @@
except ImportError:
PYWT_AVAILABLE = False
+rng = np.random.default_rng()
+
# Basic functions to be used as operators or as dummy functions
def func_identity(x_val):
@@ -461,7 +463,7 @@ def case_prox_grouplasso(self, use_weights):
else:
weights = np.tile(np.zeros((3, 3)), (4, 1, 1))
- random_data = 3 * np.random.random(weights[0].shape)
+ random_data = 3 * rng.random(weights[0].shape)
random_data_tile = np.tile(random_data, (weights.shape[0], 1, 1))
if use_weights:
gl_result_data = 2 * random_data_tile - 3
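NPY002 flags the legacy global-state API (``np.random.seed``, ``np.random.random``, ``np.random.randn``, ...); the replacement is a ``Generator`` created once with ``default_rng``, as the module-level ``rng`` above does. Rough call-for-call equivalents, for reference:

    import numpy as np

    rng = np.random.default_rng(1)

    rng.random((3, 3))            # was np.random.random / np.random.ranf
    rng.standard_normal((3, 3))   # was np.random.randn(3, 3)
    rng.poisson(lam=2.0, size=3)  # was np.random.poisson(2.0, 3)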
From 0040d1a7def41be40089e377325447a1c459684f Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 14:23:37 +0100
Subject: [PATCH 22/34] proj: add pytest-xdist for parallel testing.
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 835dae3d..1e9a7f20 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ doc=["myst-parser==0.16.1",
"sphinxawesome-theme==3.2.1",
"sphinxcontrib-bibtex"]
dev=["black", "ruff"]
-test=["pytest<8.0.0", "pytest-cases", "pytest-cov", "pytest-sugar"]
+test=["pytest<8.0.0", "pytest-cases", "pytest-cov", "pytest-xdist", "pytest-sugar"]
[build-system]
requires=["setuptools", "setuptools-scm[toml]", "wheel"]
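With pytest-xdist added to the test extra, the suite can be spread across several workers, e.g. ``pytest -n auto`` to use one worker per available CPU core.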
From 4189da6307237041875d89f0c9c82de47f5c9f9e Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 14:32:17 +0100
Subject: [PATCH 23/34] remove type annotations.
---
src/modopt/opt/linear/wavelet.py | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/src/modopt/opt/linear/wavelet.py b/src/modopt/opt/linear/wavelet.py
index e554150e..8012a072 100644
--- a/src/modopt/opt/linear/wavelet.py
+++ b/src/modopt/opt/linear/wavelet.py
@@ -285,10 +285,10 @@ class TorchWaveletTransform:
def __init__(
self,
- shape: tuple[int, ...],
- wavelet: str,
- level: int,
- mode: str,
+ shape,
+ wavelet,
+ level,
+ mode,
):
self.wavelet = wavelet
self.level = level
@@ -417,10 +417,10 @@ class CupyWaveletTransform(LinearParent):
def __init__(
self,
- shape: tuple[int, ...],
- wavelet: str,
- level: int,
- mode: str,
+ shape,
+ wavelet,
+ level,
+ mode,
):
self.wavelet = wavelet
self.level = level
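For context: built-in generics such as ``tuple[int, ...]`` are evaluated when the ``def`` statement runs and are only subscriptable on Python 3.9+, so these signatures fail at import time on 3.8. Dropping the annotations is the simplest fix; a sketch of the deferred-evaluation alternative that this patch does not take:

    from __future__ import annotations  # annotations stay unevaluated, so 3.8 accepts them

    class TorchWaveletTransform:
        def __init__(self, shape: tuple[int, ...], wavelet: str, level: int, mode: str):
            self.shape, self.wavelet, self.level, self.mode = shape, wavelet, level, mode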
From 81953519875e012a0e1e8752f7db5638f1c45696 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 14:33:34 +0100
Subject: [PATCH 24/34] proj: use importlib_metadata for python 3.8 compat.
---
pyproject.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/pyproject.toml b/pyproject.toml
index 1e9a7f20..6f9637d1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,7 +11,7 @@ authors = [{name="Samuel Farrens", email="samuel.farrens@cea.fr"},
readme="README.md"
license={file="LICENCE.txt"}
-dependencies = ["numpy", "scipy", "tqdm"]
+dependencies = ["numpy", "scipy", "tqdm", "importlib_metadata"]
[project.optional-dependencies]
gpu=["torch", "ptwt"]
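The stdlib ``importlib.metadata`` exists from Python 3.8, but parts of its API stabilised later, so projects that still support 3.8 often depend on the ``importlib_metadata`` backport. The usual lookup pattern, sketched here rather than copied from ``modopt/__init__.py``:

    try:
        from importlib.metadata import PackageNotFoundError, version
    except ImportError:
        from importlib_metadata import PackageNotFoundError, version

    try:
        __version__ = version("modopt")
    except PackageNotFoundError:
        __version__ = "0.0.0"  # fallback when the package is not installed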
From 51ffef3b714351b46a515e578b14629fb2e8bb52 Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 14:42:21 +0100
Subject: [PATCH 25/34] fix: update ssim values
---
tests/test_math.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/tests/test_math.py b/tests/test_math.py
index 37b71b1e..5c466e5e 100644
--- a/tests/test_math.py
+++ b/tests/test_math.py
@@ -205,8 +205,8 @@ class TestMetrics:
data1 = np.arange(49).reshape(7, 7)
mask = np.ones(data1.shape)
- ssim_res = 0.8963363560519094
- ssim_mask_res = 0.805154442543846
+ ssim_res = 0.8958315888566867
+ ssim_mask_res = 0.8023827544418249
snr_res = 10.134554256920536
psnr_res = 14.860761791850397
mse_res = 0.03265305507330247
From 495680387323817a1e897582f428b87653e9172f Mon Sep 17 00:00:00 2001
From: Pierre-antoine Comby
Date: Mon, 19 Feb 2024 15:32:51 +0100
Subject: [PATCH 26/34] adapt doc to new format.
---
docs/source/conf.py | 9 +++------
examples/example_lasso_forward_backward.py | 1 +
pyproject.toml | 11 ++++++-----
3 files changed, 10 insertions(+), 11 deletions(-)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index cd39ee08..e9d88229 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -21,9 +21,6 @@
copyright = f"2020, {author}"
gh_user = "sfarrens"
-# If your documentation needs a minimal Sphinx version, state it here.
-needs_sphinx = "3.3"
-
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
@@ -38,7 +35,7 @@
"sphinx.ext.napoleon",
"sphinx.ext.todo",
"sphinx.ext.viewcode",
- "sphinxawesome_theme",
+ "sphinxawesome_theme.highlighting",
"sphinxcontrib.bibtex",
"myst_parser",
"nbsphinx",
@@ -103,7 +100,7 @@
}
html_collapsible_definitions = True
html_awesome_headerlinks = True
-html_logo = "modopt_logo.jpg"
+html_logo = "modopt_logo.png"
html_permalinks_icon = (
'