👌 Remove unused args and kwargs in kspace_to_*
    * Excepting kspace_to_BE(*args)

⬆️ Follow the tqdm 5.0 deprecation (tqdm_notebook → tqdm.notebook.tqdm)
💬 Update type hints
arafune committed Sep 29, 2023
1 parent 920f500 commit ea41530
Showing 13 changed files with 92 additions and 73 deletions.
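The tqdm bullet above tracks an upstream deprecation: the top-level tqdm_notebook alias emits a TqdmDeprecationWarning and is slated for removal in tqdm 5.0 in favor of the dedicated notebook module. A minimal sketch of the swap (the loop body is illustrative):

    # Before (deprecated, to be removed in tqdm 5.0):
    # from tqdm import tqdm_notebook
    # After:
    from tqdm.notebook import tqdm

    for _ in tqdm(range(100), desc="Working..."):
        pass  # per-iteration work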
74 changes: 49 additions & 25 deletions arpes/analysis/deconvolution.py
@@ -2,19 +2,22 @@
from __future__ import annotations

import contextlib
from typing import TYPE_CHECKING
import warnings
from typing import TYPE_CHECKING, Literal

import numpy as np
import scipy
import scipy.ndimage
import xarray as xr
from tqdm import tqdm_notebook
from tqdm.notebook import tqdm

from arpes.fits.fit_models.functional_forms import gaussian
from arpes.provenance import update_provenance
from arpes.utilities import normalize_to_spectrum

if TYPE_CHECKING:
from collections.abc import Iterable

from _typeshed import Incomplete

from arpes._typing import DataType
@@ -27,7 +30,12 @@


@update_provenance("Approximate Iterative Deconvolution")
def deconvolve_ice(data: DataType, psf, n_iterations: int = 5, deg: int | None = None) -> DataType:
def deconvolve_ice(
data: DataType,
psf: xr.DataArray,
n_iterations: int = 5,
deg: int | None = None,
) -> DataType:
"""Deconvolves data by a given point spread function.
The iterative convolution extrapolation method is used.
@@ -74,20 +82,21 @@ def deconvolve_ice(data: DataType, psf, n_iterations: int = 5, deg: int | None =
@update_provenance("Lucy Richardson Deconvolution")
def deconvolve_rl(
data: DataType,
psf=None,
n_iterations=10,
axis=None,
sigma=None,
mode="reflect",
psf: xr.DataArray | None = None,
n_iterations: int = 10,
axis: str = "",
sigma: float = 0,
mode: Literal["reflect", "constant", "nearest", "mirror", "wrap"] = "reflect",
*,
progress: bool = True,
) -> DataType:
) -> xr.DataArray:
"""Deconvolves data by a given point spread function using the Richardson-Lucy method.
Args:
data
axis
sigma
mode
mode: pass to ndimage.convolve
progress
psf: for 1d, if not specified, must specify axis and sigma
n_iterations: the number of convolutions to use for the fit
@@ -97,21 +106,37 @@ def deconvolve_rl(
"""
arr = normalize_to_spectrum(data)

if psf is None and axis is not None and sigma is not None:
if psf is None and axis != "" and sigma != 0:
# if no psf is provided and we have the information to make a 1d one
# note: this assumes gaussian psf
psf = make_psf1d(data=arr, dim=axis, sigma=sigma)

if len(data.dims) > 1:
if axis is not None:
if not axis:
# perform one-dimensional deconvolution of multidimensional data

# support for progress bars
def wrap_progress(x, *args: Incomplete, **kwargs: Incomplete):
def wrap_progress(
x: Iterable[int],
*args: Incomplete,
**kwargs: Incomplete,
) -> Iterable[int]:
if args:
for arg in args:
warnings.warn(
f"unused args is set in deconvolution.py/wrap_progress: {arg}",
stacklevel=2,
)
if kwargs:
for k, v in kwargs.items():
warnings.warn(
f"unused args is set in deconvolution.py/wrap_progress: {k}: {v}",
stacklevel=2,
)
return x

if progress:
wrap_progress = tqdm_notebook
wrap_progress = tqdm

# dimensions over which to iterate
other_dim = list(data.dims)
@@ -142,7 +167,7 @@ def wrap_progress(x, *args: Incomplete, **kwargs: Incomplete):
data=iteration,
psf=psf,
n_iterations=n_iterations,
axis=None,
axis="",
mode=mode,
)
# build results out of these pieces
@@ -173,7 +198,7 @@ def wrap_progress(x, *args: Incomplete, **kwargs: Incomplete):
data=iteration1,
psf=psf,
n_iterations=n_iterations,
axis=None,
axis="",
mode=mode,
)
# build results out of these pieces
@@ -184,7 +209,7 @@ def wrap_progress(x, *args: Incomplete, **kwargs: Incomplete):
# separate code
msg = "high-dimensional data not yet supported"
raise NotImplementedError(msg)
elif axis is None:
elif not axis:
# crude attempt to perform multidimensional deconvolution.
# not clear if this is currently working
# TODO: may be able to do this as a sequence of one-dimensional deconvolutions, assuming
@@ -224,13 +249,13 @@ def wrap_progress(x, *args: Incomplete, **kwargs: Incomplete):


@update_provenance("Make 1D-Point Spread Function")
def make_psf1d(data: DataType, dim, sigma):
def make_psf1d(data: DataType, dim: str, sigma: float) -> xr.DataArray:
"""Produces a 1-dimensional gaussian point spread function for use in deconvolve_rl.
Args:
data
dim
sigma
data (DataType): xarray object
dim (str): dimension name
sigma (float): sigma value
Returns:
A one dimensional point spread array.
@@ -245,15 +270,14 @@ def make_psf1d(data: DataType, dim, sigma):


@update_provenance("Make Point Spread Function")
def make_psf(data: DataType, sigmas):
def make_psf(data: DataType, sigmas: dict[str, float]) -> xr.DataArray:
"""Produces an n-dimensional gaussian point spread function for use in deconvolve_rl.
Not yet operational.
Args:
data
dim
sigma
data (DataType):
sigmas (dict[str, float]): sigma values for each dimension.
Returns:
The PSF to use.
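With the new sentinel defaults (axis="" and sigma=0 in place of None), a one-dimensional Richardson-Lucy call needs only an axis name and a PSF width. A usage sketch on synthetic data; the coordinate name, peak shape, and sigma are illustrative, not taken from the commit:

    import numpy as np
    import xarray as xr

    from arpes.analysis.deconvolution import deconvolve_rl

    # Synthetic 1D energy cut: a Gaussian peak on a constant background.
    ev = np.linspace(-0.2, 0.1, 301)
    cut = xr.DataArray(
        np.exp(-((ev + 0.05) ** 2) / 0.001) + 0.1,
        coords={"eV": ev},
        dims=["eV"],
    )

    # With no explicit psf, axis and sigma must both be set so that
    # make_psf1d can build the 1D Gaussian PSF internally.
    result = deconvolve_rl(cut, axis="eV", sigma=0.01, n_iterations=10)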
6 changes: 3 additions & 3 deletions arpes/bootstrap.py
@@ -22,7 +22,7 @@
import numpy as np
import scipy.stats
import xarray as xr
from tqdm import tqdm_notebook
from tqdm.notebook import tqdm

from arpes.analysis.sarpes import to_intensity_polarization
from arpes.provenance import update_provenance
@@ -176,7 +176,7 @@ def bootstrap_counts(data: DataType, N: int = 1000, name: str | None = None) ->
desc_fragment = f" {name}"

resampled_sets = []
for _ in tqdm_notebook(range(N), desc=f"Resampling{desc_fragment}..."):
for _ in tqdm(range(N), desc=f"Resampling{desc_fragment}..."):
resampled_sets.append(resample_true_counts(data))

resampled_arr = np.stack([s.values for s in resampled_sets], axis=0)
@@ -349,7 +349,7 @@ def get_label(i):
"Fair warning 2: Ensure that the data to resample is in a DataArray and not a Dataset",
)

for _ in tqdm_notebook(range(N), desc="Resampling..."):
for _ in tqdm(range(N), desc="Resampling..."):
new_args = list(args)
new_kwargs = copy.copy(kwargs)
for i in resample_indices:
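The resampling loops keep their shape under the new import. A self-contained sketch of the pattern, with a toy counts spectrum standing in for real data and a direct Poisson redraw standing in for resample_true_counts:

    import numpy as np
    from tqdm.notebook import tqdm

    rng = np.random.default_rng(0)
    counts = rng.poisson(lam=50.0, size=128)  # toy counts spectrum

    # Each bootstrap iteration redraws every channel from a Poisson
    # distribution whose mean is the observed count.
    resampled = [rng.poisson(counts) for _ in tqdm(range(1000), desc="Resampling...")]
    resampled_arr = np.stack(resampled, axis=0)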
2 changes: 1 addition & 1 deletion arpes/endstations/plugin/SSRF_NSRL.py
@@ -74,7 +74,7 @@ class DA30_L(SingleFileEndstation):
"chi", # convert kspcae need them
]

RENAME_KEYS: ClassVar[dict[str, float | int | str]] = {
RENAME_KEYS: ClassVar[dict[str, float | str]] = {
"sample": "sample_name",
"spectrum_name": "spectrum_type",
"low_energy": "sweep_low_energy",
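Dropping int from the RENAME_KEYS annotation is harmless for callers: under PEP 484's numeric tower, an int value is acceptable wherever float is declared, so the extra union member was redundant. A minimal sketch of that typing rule (class and keys hypothetical):

    from typing import ClassVar

    class Example:
        # int values still type-check against float | str (PEP 484
        # treats int as compatible with float in annotations).
        RENAME_KEYS: ClassVar[dict[str, float | str]] = {
            "sample": "sample_name",
            "count": 3,      # int accepted where float is declared
            "offset": 0.5,
        }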
8 changes: 4 additions & 4 deletions arpes/fits/utilities.py
@@ -19,7 +19,7 @@
import numpy as np
import xarray as xr
from packaging import version
from tqdm import tqdm_notebook
from tqdm.notebook import tqdm

import arpes.fits.fit_models
from arpes.provenance import update_provenance
@@ -29,7 +29,7 @@
from . import mp_fits

if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Callable, Iterable

import lmfit

@@ -186,10 +186,10 @@ def broadcast_model(
# parse_model just returns model_cls as is.

if progress:
wrap_progress = tqdm_notebook
wrap_progress = tqdm
else:

def wrap_progress(x, *_, **__):
def wrap_progress(x: Iterable[int], *_, **__) -> Iterable[int]:
return x

serialize = parallelize
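The no-op fallback mirrors tqdm's call shape, so either can wrap the iteration source unchanged. The pattern, sketched standalone:

    from collections.abc import Iterable

    from tqdm.notebook import tqdm

    def identity_progress(x: Iterable, *_, **__) -> Iterable:
        """Stand-in for tqdm when progress reporting is turned off."""
        return x

    progress = False
    wrap_progress = tqdm if progress else identity_progress

    for _ in wrap_progress(range(3), desc="Fitting..."):
        pass  # one fit per broadcast coordinate in the real broadcast_model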
2 changes: 1 addition & 1 deletion arpes/plotting/dispersion.py
@@ -23,8 +23,8 @@

import xarray as xr
from _typeshed import Incomplete
from matplotlib.figure import Figure, FigureBase
from matplotlib.colors import Normalize
from matplotlib.figure import Figure, FigureBase

from arpes._typing import DataType

3 changes: 2 additions & 1 deletion arpes/plotting/fermi_surface.py
@@ -61,7 +61,7 @@ def fermi_surface_slices(
if out:
renderer = hv.renderer("matplotlib").instance(fig="svg", holomap="gif")
filename = path_for_plot(out)
renderer.save(layout, path_for_holoviews(filename))
renderer.save(layout, path_for_holoviews(str(filename)))
return filename
return layout

@@ -72,6 +72,7 @@ def magnify_circular_regions_plot(
magnified_points: NDArray[np.float_] | list[float],
mag: float = 10,
radius: float = 0.05,
# below this can be treated as kwargs?
cmap: Colormap | ColorType = "viridis",
color: ColorType | None = None,
edgecolor: ColorType = "red",
10 changes: 5 additions & 5 deletions arpes/plotting/spin.py
@@ -60,7 +60,7 @@ def spin_colored_spectrum(
if len(intensity.dims) == 1:
inset_ax = inset_axes(ax, width="30%", height="5%", loc=1)
coord = intensity.coords[intensity.dims[0]]
points = np.array([coord.values, intensity.values]).G.reshape(-1, 1, 2)
points = np.array([coord.values, intensity.values]).reshape(-1, 1, 2)
pol.values[np.isnan(pol.values)] = 0
pol.values[pol.values > 1] = 1
pol.values[pol.values < -1] = -1
Expand All @@ -83,7 +83,7 @@ def spin_colored_spectrum(
polarization_colorbar(inset_ax)

if out:
savefig(out, dpi=400)
savefig(str(out), dpi=400)
plt.clf()
return path_for_plot(out)
plt.show()
@@ -113,7 +113,7 @@ def spin_difference_spectrum(
if len(intensity.dims) == 1:
inset_ax = inset_axes(ax, width="30%", height="5%", loc=1)
coord = intensity.coords[intensity.dims[0]]
points = np.array([coord.values, intensity.values]).G.reshape(-1, 1, 2)
points = np.array([coord.values, intensity.values]).reshape(-1, 1, 2)
pol.values[np.isnan(pol.values)] = 0
pol.values[pol.values > 1] = 1
pol.values[pol.values < -1] = -1
Expand All @@ -136,7 +136,7 @@ def spin_difference_spectrum(
polarization_colorbar(inset_ax)

if out:
savefig(out, dpi=400)
savefig(str(out), dpi=400)
plt.clf()
return path_for_plot(out)
plt.show()
@@ -228,7 +228,7 @@ def spin_polarized_spectrum(
plt.tight_layout()

if out:
savefig(out, dpi=400)
savefig(str(out), dpi=400)
plt.clf()
return path_for_plot(out)

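Dropping the stray .G accessor leaves a plain NumPy reshape when building the colored-line segments. For reference, the standard matplotlib LineCollection recipe pairs consecutive points into (N - 1, 2, 2) segments; a minimal sketch with toy values (note the transpose the textbook version includes):

    import numpy as np

    x = np.linspace(0.0, 1.0, 6)   # coordinate values, illustrative
    y = np.sin(2 * np.pi * x)      # intensity values, illustrative

    # (2, N) stack -> transpose to (N, 2) pairs -> view as (N, 1, 2)
    # points -> join neighbors into (N - 1, 2, 2) segments.
    points = np.array([x, y]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    print(segments.shape)  # (5, 2, 2)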
18 changes: 8 additions & 10 deletions arpes/plotting/utils.py
@@ -1,7 +1,6 @@
"""Contains many common utility functions for managing matplotlib."""
from __future__ import annotations

import collections
import contextlib
import datetime
import errno
@@ -12,7 +11,7 @@
import re
import warnings
from collections import Counter
from collections.abc import Sequence
from collections.abc import Generator, Iterable, Sequence
from datetime import UTC
from logging import DEBUG, Formatter, StreamHandler, getLogger
from pathlib import Path
@@ -310,7 +309,7 @@ def simple_ax_grid(


@contextlib.contextmanager
def dark_background(overrides: dict[str, Incomplete]):
def dark_background(overrides: dict[str, Incomplete]) -> Generator:
"""Context manager for plotting "dark mode"."""
defaults = {
"axes.edgecolor": "white",
@@ -784,8 +783,8 @@ def dos_axes(
def inset_cut_locator(
data: DataType,
reference_data: DataType,
ax: Axes | None = None,
location: dict[str, Incomplete] | None = None,
ax: Axes,
location: dict[str, Incomplete],
color: RGBColorType = "red",
**kwargs: Incomplete,
) -> None:
@@ -818,10 +817,9 @@ def inset_cut_locator(
"beta": lambda: reference_data.S.beta,
"phi": lambda: reference_data.S.phi,
}

missing_dims = [d for d in data.dims if d not in location]
missing_values = {d: missing_dim_resolvers[d]() for d in missing_dims}
ordered_selector = [location.get(d, missing_values.get(d)) for d in data.dims]
missing_dims = [dim for dim in data.dims if dim not in location]
missing_values = {dim: missing_dim_resolvers[dim]() for dim in missing_dims}
ordered_selector = [location.get(dim, missing_values.get(dim)) for dim in data.dims]

n = 200

@@ -839,7 +837,7 @@ def resolve(name: str, value: slice | int) -> NDArray[np.float_]:

return np.ones((n,)) * value

n_cut_dims = len([d for d in ordered_selector if isinstance(d, collections.Iterable | slice)])
n_cut_dims = len([d for d in ordered_selector if isinstance(d, Iterable | slice)])
ordered_selector = [resolve(d, v) for d, v in zip(data.dims, ordered_selector)]

if missing_dims:
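The last hunk fixes a real breakage: collections.Iterable was removed in Python 3.10, so the ABC must come from collections.abc, and isinstance() there accepts the PEP 604 union directly. A minimal demonstration:

    from collections.abc import Iterable

    # Valid on Python 3.10+: union types work with isinstance().
    print(isinstance([0.1, 0.2], Iterable | slice))   # True
    print(isinstance(slice(0, 5), Iterable | slice))  # True
    print(isinstance(3.14, Iterable | slice))         # False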