From 0bf8c1fbf8d79055fc9fed047913df93866d25e1 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Wed, 27 Sep 2023 12:08:50 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A8=20=20stop=20using=20cm.get=5Fcmap?= =?UTF-8?q?=20function=20because=20it=20is=20deprecated.=20=F0=9F=92=AC=20?= =?UTF-8?q?=20update=20type=20hints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Changes.md | 3 + arpes/_typing.py | 139 ++++++++++++ arpes/analysis/mask.py | 22 +- arpes/deep_learning/models/regression.py | 10 +- arpes/fits/broadcast_common.py | 19 +- arpes/fits/lmfit_plot.py | 21 +- arpes/optics.py | 12 - arpes/plotting/annotations.py | 66 ++++-- arpes/plotting/bands.py | 2 +- arpes/plotting/basic.py | 2 +- arpes/plotting/bz.py | 7 +- arpes/plotting/false_color.py | 26 ++- arpes/plotting/fermi_surface.py | 31 ++- arpes/plotting/fit_tool/__init__.py | 4 +- arpes/plotting/movie.py | 4 +- arpes/plotting/qt_tool/__init__.py | 8 +- arpes/plotting/spatial.py | 5 +- arpes/plotting/spin.py | 12 +- arpes/plotting/stack_plot.py | 34 ++- arpes/plotting/utils.py | 212 ++++++++++-------- arpes/utilities/attrs.py | 12 - arpes/utilities/bz.py | 59 +++-- arpes/utilities/collections.py | 12 +- arpes/utilities/funcutils.py | 28 +-- arpes/utilities/jupyter.py | 25 +-- arpes/utilities/math.py | 6 +- arpes/utilities/qt/app.py | 5 +- arpes/utilities/string.py | 10 +- arpes/widgets.py | 2 +- arpes/workflow.py | 2 - arpes/xarray_extensions.py | 41 ++-- tests/conftest.py | 5 +- tests/test_basic_data_loading.py | 42 ++-- tests/test_curve_fitting.py | 4 +- tests/test_derivative_analysis.py | 12 +- tests/test_direct_and_example_data_loading.py | 13 +- tests/test_generic_utilities.py | 1 + tests/test_momentum_conversion.py | 6 +- tests/test_montage.py | 1 + tests/test_qt.py | 1 + tests/test_time_configuration.py | 3 + 41 files changed, 598 insertions(+), 331 deletions(-) diff --git a/Changes.md b/Changes.md index 54c7fd78..81a8ab5c 100644 --- a/Changes.md +++ b/Changes.md @@ -15,6 +15,9 @@ Major Changes from 3.0.1 - Remove arpes.all - Certainly, this it is indeed a lazy and carefree approach, but it's too rough method that leads to a bugs and does not mathc the current pythonic style. +- Remove overlapped_stack_dispersion_plot + - use stack_dispersion_plot with appropriate args + Fix from 3.0.1 - bug of concatenating in broadcast_model diff --git a/arpes/_typing.py b/arpes/_typing.py index 103f28e4..1929c140 100644 --- a/arpes/_typing.py +++ b/arpes/_typing.py @@ -380,3 +380,142 @@ class ColorbarParam(TypedDict, total=False): boundaries: None | Sequence[float] values: None | Sequence[float] location: None | Literal["left", "right", "top", "bottom"] + + +class MPLTextParam(TypedDict, total=False): + agg_filter: Callable[[NDArray[np.float_], int], tuple[NDArray[np.float_], int, int]] + alpha: float | None + animated: bool + antialiased: bool + backgroundcolor: ColorType + color: ColorType + c: ColorType + figure: Figure + fontfamily: str + family: str + fontname: str + fontproperties: str | Path + font: str | Path + font_properties: str | Path + fontsize: float | Literal[ + "xx-small", + "x-small", + "small", + "medium", + "large", + "x-large", + "xx-large", + ] + size: float | Literal[ + "xx-small", + "x-small", + "small", + "medium", + "large", + "x-large", + "xx-large", + ] + fontstretch: float | Literal[ + "ultra-condensed", + "extra-condensed", + "condensed", + "semi-condensed", + "normal", + "semi-expanded", + "expanded", + "extra-expanded", + "ultra-expanded", + ] + stretch: float | Literal[ + "ultra-condensed", + "extra-condensed", + "condensed", + "semi-condensed", + "normal", + "semi-expanded", + "expanded", + "extra-expanded", + "ultra-expanded", + ] + fontstyle: Literal["normal", "italic", "oblique"] + style: Literal["normal", "italic", "oblique"] + fontvariant: Literal["normal", "small-caps"] + variant: Literal["normal", "small-caps"] + fontweight: float | Literal[ + "ultralight", + "light", + "normal", + "regular", + "book", + "medium", + "roman", + "semibold", + "demibold", + "demi", + "bold", + "heavy", + "extra bold", + "black", + ] + weight: float | Literal[ + "ultralight", + "light", + "normal", + "regular", + "book", + "medium", + "roman", + "semibold", + "demibold", + "demi", + "bold", + "heavy", + "extra bold", + "black", + ] + gid: str + horizontalalignment: Literal["left", "center", "right"] + ha: Literal["left", "center", "right"] + in_layout: bool + label: str + linespacing: float + math_fontfamily: str + mouseover: bool + multialignment: Literal["left", "right", "center"] + ma: Literal["left", "right", "center"] + parse_math: bool + path_effects: list[AbstractPathEffect] + picker: None | bool | float | Callable + position: tuple[float, float] + rasterized: bool + rotation: float | Literal["vertical", "horizontal"] + rotation_mode: Literal[None, "default", "anchor"] + sketch_params: tuple[float, float, float] + scale: float + length: float + randomness: float + snap: bool | None + text: str + transform: Transform + transform_rotates_text: bool + url: str + usetex: bool | None + verticalalignment: Literal["bottom", "baseline", "center", "center_baseline", "top"] + va: Literal["bottom", "baseline", "center", "center_baseline", "top"] + visible: bool + wrap: bool + zorder: float + + +class PLTSubplotParam(TypedDict, total=False): + sharex: bool | Literal["none", "all", "row", "col"] + sharey: bool | Literal["none", "all", "row", "col"] + squeeze: bool + width_ratios: Sequence[float] | None + height_ratios: Sequence[float] | None + subplot_kw: dict + gridspec_kw: dict + + +class IMshowParam(TypedDict, total=False): + pass diff --git a/arpes/analysis/mask.py b/arpes/analysis/mask.py index d90d4024..6e7c2d8a 100644 --- a/arpes/analysis/mask.py +++ b/arpes/analysis/mask.py @@ -1,12 +1,20 @@ """Utilities for applying masks to data.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + import numpy as np -import xarray as xr from matplotlib.path import Path -from arpes._typing import DataType from arpes.provenance import update_provenance from arpes.utilities import normalize_to_spectrum +if TYPE_CHECKING: + import xarray as xr + from _typeshed import Incomplete + + from arpes._typing import DataType + __all__ = ( "polys_to_mask", "apply_mask", @@ -15,7 +23,7 @@ ) -def raw_poly_to_mask(poly) -> dict: +def raw_poly_to_mask(poly: Incomplete) -> dict[str, Incomplete]: """Converts a polygon into a mask definition. There's not currently much metadata attached to masks, but this is @@ -37,7 +45,13 @@ def raw_poly_to_mask(poly) -> dict: } -def polys_to_mask(mask_dict, coords, shape, radius=None, invert=False): +def polys_to_mask( + mask_dict: dict[str, Incomplete], + coords, + shape, + radius=None, + invert=False, +) -> NDArray[np.float_] | NDArray[np.bool_]: """Converts a mask definition in terms of the underlying polygon to a True/False mask array. Uses the coordinates and shape of the target data in order to determine which pixels diff --git a/arpes/deep_learning/models/regression.py b/arpes/deep_learning/models/regression.py index 155fdddc..cb7e359f 100644 --- a/arpes/deep_learning/models/regression.py +++ b/arpes/deep_learning/models/regression.py @@ -1,8 +1,8 @@ """Very simple regression baselines.""" import pytorch_lightning as pl -import torch.nn.functional as F from torch import nn, optim +from torch.nn import functional __all__ = ["BaselineRegression", "LinearRegression"] @@ -17,7 +17,7 @@ def __init__(self) -> None: """Generate network components and use the mean squared error loss.""" super().__init__() self.linear = nn.Linear(self.input_dimensions, self.output_dimensions) - self.criterion = F.mse_loss + self.criterion = functional.mse_loss def forward(self, x): """Calculate the model output for the minibatch `x`.""" @@ -53,13 +53,13 @@ def __init__(self) -> None: self.l1 = nn.Linear(self.input_dimensions, 256) self.l2 = nn.Linear(256, 128) self.l3 = nn.Linear(128, self.output_dimensions) - self.criterion = F.mse_loss + self.criterion = functional.mse_loss def forward(self, x): """Calculate the model output for the minibatch `x`.""" flat_x = x.view(x.size(0), -1) - h1 = F.relu(self.l1(flat_x)) - h2 = F.relu(self.l2(h1)) + h1 = functional.relu(self.l1(flat_x)) + h2 = functional.relu(self.l2(h1)) return self.l3(h2) def training_step(self, batch, batch_index): diff --git a/arpes/fits/broadcast_common.py b/arpes/fits/broadcast_common.py index 56c00fc0..12e0cf2d 100644 --- a/arpes/fits/broadcast_common.py +++ b/arpes/fits/broadcast_common.py @@ -5,19 +5,21 @@ import operator import warnings from string import ascii_lowercase -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import lmfit import xarray as xr if TYPE_CHECKING: - from arpes.fits.fit_models.x_model_mixin import XModelMixin + from collections.abc import Iterable + from _typeshed import Incomplete -def unwrap_params(params, iter_coordinate): + +def unwrap_params(params: dict[str, Any], iter_coordinate: Incomplete) -> dict[str, Any]: """Inspects arraylike parameters and extracts appropriate value for current fit.""" - def transform_or_walk(v): + def transform_or_walk(v: dict | xr.DataArray | Iterable[float]): if isinstance(v, dict): return unwrap_params(v, iter_coordinate) @@ -48,7 +50,7 @@ def apply_window(data: xr.DataArray, cut_coords: dict[str, float | slice], windo return cut_data, original_cut_data -def _parens_to_nested(items): +def _parens_to_nested(items: list) -> list: """Turns a flat list with parentheses tokens into a nested list.""" parens = [ ( @@ -72,7 +74,9 @@ def _parens_to_nested(items): return items -def reduce_model_with_operators(models: tuple | list[XModelMixin]) -> XModelMixin: +def reduce_model_with_operators( + models: tuple[Incomplete, ...] | list[Incomplete], +) -> Incomplete: """Combine models according to mathematical operators.""" if isinstance(models, tuple): return models[0](prefix=f"{models[1]}_", nan_policy="omit") @@ -82,7 +86,8 @@ def reduce_model_with_operators(models: tuple | list[XModelMixin]) -> XModelMixi left, op, right = models[0], models[1], models[2:] left, right = reduce_model_with_operators(left), reduce_model_with_operators(right) - + assert left is not None + assert right is not None if op == "+": return left + right if op == "*": diff --git a/arpes/fits/lmfit_plot.py b/arpes/fits/lmfit_plot.py index 495183b9..b48e5a02 100644 --- a/arpes/fits/lmfit_plot.py +++ b/arpes/fits/lmfit_plot.py @@ -5,7 +5,7 @@ """ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Literal, TypedDict +from typing import TYPE_CHECKING, Literal, TypedDict, Unpack import matplotlib.pyplot as plt import xarray as xr @@ -13,9 +13,9 @@ if TYPE_CHECKING: import numpy as np + from _typeshed import Incomplete from matplotlib.figure import Figure from numpy.typing import NDArray - from typing_extensions import Unpack original_plot = model.ModelResult.plot @@ -31,12 +31,12 @@ class ModelResultPlotKwargs(TypedDict, total=False): yerr: NDArray[np.float_] numpoints: int fig: Figure - data_kws: dict - fit_kws: dict - init_kws: dict - ax_res_kws: dict - ax_fit_kws: dict - fig_kws: dict + data_kws: dict[str, Incomplete] + fit_kws: dict[str, Incomplete] + init_kws: dict[str, Incomplete] + ax_res_kws: dict[str, Incomplete] + ax_fit_kws: dict[str, Incomplete] + fig_kws: dict[str, Incomplete] show_init: bool parse_complex: Literal["abs", "real", "imag", "angle"] title: str @@ -51,8 +51,9 @@ def transform_lmfit_titles(label: str = "", *, is_title: bool = False) -> str: def patched_plot( - self: Any, **kwargs: Unpack[ModelResultPlotKwargs] -) -> Figure | Literal[False]: # noqa: ANN401 + self: Incomplete, + **kwargs: Unpack[ModelResultPlotKwargs], +) -> Figure | Literal[False]: """A patch for `lmfit` summary plots in PyARPES. Scientists like to have LaTeX in their plots, diff --git a/arpes/optics.py b/arpes/optics.py index 4b332396..bb47e1e0 100644 --- a/arpes/optics.py +++ b/arpes/optics.py @@ -18,8 +18,6 @@ import numpy as np __all__ = ( - "waist", - "waist_R", "rayleigh_range", "lens_transfer", "magnification", @@ -28,16 +26,6 @@ ) -def waist(wavelength: float, z: float, z_R: float) -> float: - """Calculates the waist size from the measurements at a distance from the waist.""" - raise NotImplementedError - - -def waist_R(waist_0: float, m_squared: float = 1.0) -> float: - """Calculates the width of t he beam a distance from the waist.""" - raise NotImplementedError - - def waist_from_rr(wavelength: float, rayleigh_rng: float) -> float: """Calculates the waist parameters from the Rayleigh range.""" return np.sqrt((wavelength * rayleigh_rng) / np.pi) diff --git a/arpes/plotting/annotations.py b/arpes/plotting/annotations.py index 917a35fa..2e01e0d3 100644 --- a/arpes/plotting/annotations.py +++ b/arpes/plotting/annotations.py @@ -1,16 +1,22 @@ """Annotations onto plots for experimental conditions or locations.""" from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Unpack import numpy as np +from matplotlib.axes import Axes3D from arpes.plotting.utils import name_for_dim, unit_for_dim from arpes.utilities.conversion.forward import convert_coordinates_to_kspace_forward if TYPE_CHECKING: + from collections.abc import Sequence + from _typeshed import Incomplete from matplotlib.axes import Axes + from numpy.typing import NDArray + + from arpes._typing import DataType, MPLTextParam __all__ = ( "annotate_cuts", @@ -18,16 +24,18 @@ "annotate_experimental_conditions", ) +TWODimensional = 2 + def annotate_experimental_conditions( ax: Axes, - data, - desc, + data: DataType, + desc: list[str | float] | float | str, *, show: bool = False, orientation: str = "top", - **kwargs: Incomplete, -): + **kwargs: Unpack[MPLTextParam], +) -> None: """Renders information about the experimental conditions onto a set of axes. Also adjust the axes limits and hides the axes. @@ -59,12 +67,12 @@ def annotate_experimental_conditions( delta = 1 current = 0 - fontsize = kwargs.pop("fontsize", 16) + fontsize = kwargs.get("fontsize", 16) delta = fontsize * delta conditions = data.S.experimental_conditions - def render_polarization(c) -> str: + def render_polarization(c: dict[str, Incomplete]) -> str: pol = c["polarization"] if pol in ["lc", "rc"]: return "\\textbf{" + pol.upper() + "}" @@ -86,7 +94,7 @@ def render_polarization(c) -> str: return prefix + "\\textbf{" + pol + "}" - def render_photon(c) -> str: + def render_photon(c: dict[str, float]) -> str: return "\\textbf{" + str(c["hv"]) + " eV" renderers = { @@ -103,11 +111,18 @@ def render_photon(c) -> str: item = item.replace("_", " ").lower() - ax.text(0, current, renderers[item](conditions), fontsize=fontsize, **kwargs) + ax.text(0, current, renderers[item](conditions), **kwargs) current += delta -def annotate_cuts(ax: Axes, data, plotted_axes, include_text_labels=False, **kwargs: Incomplete): +def annotate_cuts( + ax: Axes, + data: DataType, + plotted_axes: NDArray[np.object_], + *, + include_text_labels: bool = False, + **kwargs: Incomplete, +) -> None: """Annotates a cut location onto a plot. Example: @@ -121,7 +136,7 @@ def annotate_cuts(ax: Axes, data, plotted_axes, include_text_labels=False, **kwa kwargs: Defines the coordinates of the cut location """ converted_coordinates = convert_coordinates_to_kspace_forward(data) - assert len(plotted_axes) == 2 + assert len(plotted_axes) == TWODimensional for k, v in kwargs.items(): if not isinstance(v, tuple | list | np.ndarray): @@ -145,7 +160,13 @@ def annotate_cuts(ax: Axes, data, plotted_axes, include_text_labels=False, **kwa ) -def annotate_point(ax: Axes, location, label, delta=None, **kwargs: Incomplete): +def annotate_point( + ax: Axes | Axes3D, + location: Sequence[float], + label: str, + delta: tuple[float, ...] = (), + **kwargs: Unpack[MPLTextParam], +) -> None: """Annotates a point or high symmetry location into a plot.""" label = { "G": "$\\Gamma$", @@ -155,21 +176,22 @@ def annotate_point(ax: Axes, location, label, delta=None, **kwargs: Incomplete): "M": r"\textbf{M}", }.get(label, label) - if delta is None: + if not delta: delta = ( -0.05, 0.05, ) + if "color" not in kwargs: + kwargs["color"] = "red" - c = kwargs.pop("color", "red") - - if len(delta) == 2: + if len(delta) == TWODimensional: dx, dy = tuple(delta) - x, y = tuple(location) - ax.plot([x], [y], "o", c=c) - ax.text(x + dx, y + dy, label, color=c, **kwargs) + pos_x, pos_y = tuple(location) + ax.plot([pos_x], [pos_y], "o", c=kwargs["color"]) + ax.text(pos_x + dx, pos_y + dy, label, **kwargs) else: + assert isinstance(ax, Axes3D) dx, dy, dz = tuple(delta) - x, y, z = tuple(location) - ax.plot([x], [y], [z], "o", c=c) - ax.text(x + dx, y + dy, z + dz, label, color=c, **kwargs) + pos_x, pos_y, pos_z = tuple(location) + ax.plot([pos_x], [pos_y], [pos_z], "o", c=kwargs["color"]) + ax.text(pos_x + dx, pos_y + dy, pos_z + dz, label, **kwargs) diff --git a/arpes/plotting/bands.py b/arpes/plotting/bands.py index 589f1608..261dac38 100644 --- a/arpes/plotting/bands.py +++ b/arpes/plotting/bands.py @@ -3,11 +3,11 @@ import matplotlib.pyplot as plt from _typeshed import Incomplete +from build.lib.arpes.typing import DataType from matplotlib.axes import Axes from matplotlib.colors import Normalize from arpes.provenance import save_plot_provenance -from build.lib.arpes.typing import DataType from .utils import label_for_colorbar, path_for_plot diff --git a/arpes/plotting/basic.py b/arpes/plotting/basic.py index 098edda0..8b7231de 100644 --- a/arpes/plotting/basic.py +++ b/arpes/plotting/basic.py @@ -11,7 +11,7 @@ __all__ = ["make_reference_plots"] -def make_reference_plots(df: pd.DataFrame | None = None, *, with_kspace: bool = False): +def make_reference_plots(df: pd.DataFrame, *, with_kspace: bool = False) -> None: """Makes standard reference plots for orienting oneself.""" try: df = df[df.spectrum_type != "xps_spectrum"] diff --git a/arpes/plotting/bz.py b/arpes/plotting/bz.py index 270a6fcd..c1e1c54c 100644 --- a/arpes/plotting/bz.py +++ b/arpes/plotting/bz.py @@ -10,6 +10,7 @@ import matplotlib.pyplot as plt import numpy as np import xarray as xr +from matplotlib.axes import Axes from matplotlib.patches import FancyArrowPatch from mpl_toolkits.mplot3d import proj3d from mpl_toolkits.mplot3d.art3d import Poly3DCollection @@ -26,7 +27,6 @@ from pathlib import Path from _typeshed import Incomplete - from matplotlib.axes import Axes from matplotlib.figure import Figure from matplotlib.typing import RGBColorType from numpy.typing import NDArray @@ -215,6 +215,7 @@ def plot_data_to_bz2d( if ax is None: fig, ax = plt.subplots(figsize=(9, 9)) bz2d_plot(cell, paths="all", ax=ax) + assert isinstance(ax, Axes) if len(cell) == 2: cell = [[*list(c), 0] for c in cell] + [[0, 0, 1]] @@ -242,9 +243,9 @@ def plot_data_to_bz2d( built_mask = apply_mask_to_coords(raveled, build_2dbz_poly(cell=cell), dims) copied[built_mask.T] = np.nan - cmap = kwargs.get("cmap", matplotlib.cm.Blues) + cmap = kwargs.get("cmap", matplotlib.colormaps["Blues"]) if isinstance(cmap, str): - cmap = matplotlib.cm.get_cmap(cmap) + cmap = matplotlib.colormaps.get_cmap(cmap) cmap.set_bad((1, 1, 1, 0)) diff --git a/arpes/plotting/false_color.py b/arpes/plotting/false_color.py index 3b47bfa3..ad3bf413 100644 --- a/arpes/plotting/false_color.py +++ b/arpes/plotting/false_color.py @@ -7,6 +7,7 @@ import matplotlib.pyplot as plt import numpy as np import xarray as xr +from matplotlib.axes import Axes from arpes.plotting.utils import imshow_arr, path_for_plot from arpes.provenance import save_plot_provenance @@ -16,12 +17,12 @@ from pathlib import Path from _typeshed import Incomplete - from matplotlib.axes import Axes from matplotlib.figure import Figure + from numpy.typing import NDArray @save_plot_provenance -def false_color_plot( +def false_color_plot( # noqa: PLR0913 data_r: xr.Dataset, data_g: xr.Dataset, data_b: xr.Dataset, @@ -29,35 +30,38 @@ def false_color_plot( out: str | Path = "", *, invert: bool = False, - pmin=0, - pmax=1, + pmin: float = 0, + pmax: float = 1, **kwargs: Incomplete, ) -> Path | tuple[Figure | None, Axes]: """Plots a spectrum in false color after conversion to R, G, B arrays.""" - data_r, data_g, data_b = (normalize_to_spectrum(d) for d in (data_r, data_g, data_b)) + data_r_arr, data_g_arr, data_b_arr = ( + normalize_to_spectrum(d) for d in (data_r, data_g, data_b) + ) fig: Figure | None = None if ax is None: fig, ax = plt.subplots(figsize=kwargs.pop("figsize", (7, 5))) + assert isinstance(ax, Axes) - def normalize_channel(channel): + def normalize_channel(channel: NDArray[np.float_]) -> NDArray[np.float_]: channel -= np.percentile(channel, 100 * pmin) channel[channel > np.percentile(channel, 100 * pmax)] = np.percentile(channel, 100 * pmax) return channel / np.max(channel) - cs = dict(data_r.coords) + cs = dict(data_r_arr.coords) cs["dim_color"] = [1, 2, 3] arr = xr.DataArray( np.stack( [ - normalize_channel(data_r.values), - normalize_channel(data_g.values), - normalize_channel(data_b.values), + normalize_channel(data_r_arr.values), + normalize_channel(data_g_arr.values), + normalize_channel(data_b_arr.values), ], axis=-1, ), coords=cs, - dims=[*list(data_r.dims), "dim_color"], + dims=[*list(data_r_arr.dims), "dim_color"], ) if invert: diff --git a/arpes/plotting/fermi_surface.py b/arpes/plotting/fermi_surface.py index 09083997..a9e139c4 100644 --- a/arpes/plotting/fermi_surface.py +++ b/arpes/plotting/fermi_surface.py @@ -78,18 +78,34 @@ def magnify_circular_regions_plot( out: str | Path = "", ax: Axes | None = None, **kwargs: tuple[float, float], -) -> tuple[Figure, Axes] | Path: - """Plots a Fermi surface with inset points magnified in an inset.""" +) -> tuple[Figure | None, Axes] | Path: + """Plots a Fermi surface with inset points magnified in an inset. + + Args: + data: [TODO:description] + magnified_points: [TODO:description] + mag: [TODO:description] + radius: [TODO:description] + cmap: [TODO:description] + color: [TODO:description] + edgecolor: [TODO:description] + out: [TODO:description] + ax: [TODO:description] + kwargs: [TODO:description] + + Returns: + [TODO:description] + """ data_arr = normalize_to_spectrum(data) assert isinstance(data_arr, xr.DataArray) - fig: Figure + fig: Figure | None = None if ax is None: fig, ax = plt.subplots(figsize=kwargs.pop("figsize", (7, 5))) assert isinstance(ax, Axes) - mesh = data_arr.plot(ax=ax, cmap=cmap) + mesh = data_arr.S.plot(ax=ax, cmap=cmap) clim = list(mesh.get_clim()) clim[1] = clim[1] / mag @@ -121,13 +137,14 @@ def magnify_circular_regions_plot( if not isinstance(color, list): color = [color for _ in range(len(magnified_points))] + assert isinstance(color, list) pts[:, 1] = (pts[:, 1]) / (xlim[1] - xlim[0]) pts[:, 0] = (pts[:, 0]) / (ylim[1] - ylim[0]) print(np.min(pts[:, 1]), np.max(pts[:, 1])) print(np.min(pts[:, 0]), np.max(pts[:, 0])) - for c, ec, point in zip(color, edgecolor, magnified_points): + for c, ec, point in zip(color, edgecolor, magnified_points, strict=True): patch = matplotlib.patches.Ellipse( point, width, @@ -145,14 +162,14 @@ def magnify_circular_regions_plot( data_masked = data_arr.copy(deep=True) data_masked.values = np.array(data_masked.values, dtype=np.float_) - cm = matplotlib.cm.get_cmap(name="viridis") + cm = matplotlib.colormaps.get_cmap(cmap="viridis") cm.set_bad(color=(1, 1, 1, 0)) data_masked.values[ np.swapaxes(np.logical_not(mask.reshape(data_arr.values.shape[::-1])), 0, 1) ] = np.nan aspect = ax.get_aspect() - extent = [xlim[0], xlim[1], ylim[0], ylim[1]] + extent = (xlim[0], xlim[1], ylim[0], ylim[1]) ax.imshow(data_masked.values, cmap=cm, extent=extent, zorder=3, clim=clim, origin="lower") ax.set_aspect(aspect) diff --git a/arpes/plotting/fit_tool/__init__.py b/arpes/plotting/fit_tool/__init__.py index 20f094c3..a4edfba8 100644 --- a/arpes/plotting/fit_tool/__init__.py +++ b/arpes/plotting/fit_tool/__init__.py @@ -8,6 +8,7 @@ from dataclasses import dataclass import dill +import matplotlib as mpl import numpy as np import pyqtgraph as pg import xarray as xr @@ -451,9 +452,8 @@ def before_show(self): """Lifecycle hook for configuration before app show.""" self.configure_image_widgets() self.add_contextual_widgets() - import matplotlib.cm - self.set_colormap(matplotlib.cm.viridis) + self.set_colormap(mpl.colormaps["viridis"]) def after_show(self): """Initialize application state after app show. diff --git a/arpes/plotting/movie.py b/arpes/plotting/movie.py index 2ddbea80..9451673f 100644 --- a/arpes/plotting/movie.py +++ b/arpes/plotting/movie.py @@ -29,8 +29,8 @@ def plot_movie( """Make an animated plot of a 3D dataset using one dimension as "time". Args: - data: [TODO:description] - time_dim: [TODO:description] + data (xr.DataArray): ARPES data + time_dim (str): dimension name about time interval: [TODO:description] fig: [TODO:description] ax: [TODO:description] diff --git a/arpes/plotting/qt_tool/__init__.py b/arpes/plotting/qt_tool/__init__.py index 46055a54..1a64f4c9 100644 --- a/arpes/plotting/qt_tool/__init__.py +++ b/arpes/plotting/qt_tool/__init__.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING import dill +import matplotlib as mpl import numpy as np import pyqtgraph as pg from PyQt5 import QtCore, QtGui, QtWidgets @@ -465,12 +466,11 @@ def before_show(self): """Lifecycle hook for configuration before app show.""" self.configure_image_widgets() self.add_contextual_widgets() - import matplotlib.cm - if self.data.min() >= 0.0: - self.set_colormap(matplotlib.cm.viridis) + if self.data.min() >= 0: + self.set_colormap(mpl.colormaps["viridis"]) else: - self.set_colormap(matplotlib.cm.RdBu_r) + self.set_colormap(mpl.colormaps["RdBu_r"]) def after_show(self): """Initialize application state after app show. diff --git a/arpes/plotting/spatial.py b/arpes/plotting/spatial.py index 0c9c6adb..b6195768 100644 --- a/arpes/plotting/spatial.py +++ b/arpes/plotting/spatial.py @@ -5,11 +5,12 @@ import itertools from typing import TYPE_CHECKING, Any +import matplotlib as mpl import matplotlib.patheffects as path_effects import matplotlib.pyplot as plt import numpy as np import xarray as xr -from matplotlib import cm, gridspec, patches +from matplotlib import gridspec, patches from arpes.io import load_data from arpes.plotting.annotations import annotate_point @@ -98,7 +99,7 @@ def plot_spatial_reference( assert len(reference_map.dims) == two_dimension reference_map.S.plot(ax=ax, cmap="Blues") - cmap = cm.get_cmap("Reds") + cmap = mpl.colormaps.get_cmap("Reds") rendered_annotations = [] for i, (data, offset, annotation) in enumerate(zip(data_list, offset_list, annotation_list)): if offset is None: diff --git a/arpes/plotting/spin.py b/arpes/plotting/spin.py index a13b08e8..3c2a46c9 100644 --- a/arpes/plotting/spin.py +++ b/arpes/plotting/spin.py @@ -3,10 +3,10 @@ from typing import TYPE_CHECKING +import matplotlib as mpl import matplotlib.colors import matplotlib.pyplot as plt import numpy as np -from matplotlib import cm from matplotlib.axes import Axes from matplotlib.collections import LineCollection from matplotlib.figure import Figure @@ -64,10 +64,10 @@ def spin_colored_spectrum( pol.values[np.isnan(pol.values)] = 0 pol.values[pol.values > 1] = 1 pol.values[pol.values < -1] = -1 - pol_colors = cm.get_cmap("RdBu")(pol.values[:-1]) + pol_colors = mpl.colormaps.get_cmap("RdBu")(pol.values[:-1]) if scatter: - pol_colors = cm.get_cmap("RdBu")(pol.values) + pol_colors = mpl.colormaps.get_cmap("RdBu")(pol.values) ax.scatter(coord.values, intensity.values, c=pol_colors, s=1.5) else: segments = np.concatenate([points[:-1], points[1:]], axis=1) @@ -117,10 +117,10 @@ def spin_difference_spectrum( pol.values[np.isnan(pol.values)] = 0 pol.values[pol.values > 1] = 1 pol.values[pol.values < -1] = -1 - pol_colors = cm.get_cmap("RdBu")(pol.values[:-1]) + pol_colors = mpl.colormaps.get_cmap("RdBu")(pol.values[:-1]) if scatter: - pol_colors = cm.get_cmap("RdBu")(pol.values) + pol_colors = mpl.colormaps.get_cmap("RdBu")(pol.values) ax.scatter(coord.values, intensity.values, c=pol_colors, s=1.5) else: segments = np.concatenate([points[:-1], points[1:]], axis=1) @@ -260,7 +260,7 @@ def polarization_intensity_to_color( # use the 98th percentile data if not provided vmax = np.percentile(data.intensity.values, 98) - rgbas = cm.RdBu((data.polarization.values / pmax + 1) / 2) + rgbas = mpl.colormaps["RdBu"]((data.polarization.values / pmax + 1) / 2) slices = [slice(None) for _ in data.polarization.dims] + [slice(0, 3)] rgbs = rgbas[slices] diff --git a/arpes/plotting/stack_plot.py b/arpes/plotting/stack_plot.py index 20a77d36..b0be3d60 100644 --- a/arpes/plotting/stack_plot.py +++ b/arpes/plotting/stack_plot.py @@ -8,7 +8,10 @@ from typing import TYPE_CHECKING, Literal import matplotlib as mpl +import matplotlib.colorbar +import matplotlib.colors import matplotlib.pyplot as plt +import matplotlib.ticker import numpy as np import xarray as xr from matplotlib import colorbar @@ -34,9 +37,7 @@ from _typeshed import Incomplete from matplotlib.figure import Figure - from matplotlib.typing import ( - ColorType, - ) + from matplotlib.typing import ColorType, RGBAColorType from numpy.typing import NDArray from arpes._typing import DataType @@ -52,7 +53,8 @@ def offset_scatter_plot( data: xr.Dataset, name_to_plot: str = "", stack_axis: str = "", - cbarmap: tuple[colorbar.Colorbar, Callable[[float], ColorType]] | None = None, + cbarmap: tuple[Callable[..., colorbar.Colorbar], Callable[..., Callable[[float], ColorType]]] + | None = None, ax: Axes | None = None, out: str | Path = "", scale_coordinate: float = 0.5, @@ -271,15 +273,15 @@ def flat_stack_plot( **kwargs, ) try: - mpl.colorbar.Colorbar( + matplotlib.colorbar.Colorbar( ax_inset, orientation="horizontal", label=label_for_dim(data_array, stack_axis), - norm=mpl.colors.Normalize( + norm=matplotlib.colors.Normalize( vmin=data_array.coords[stack_axis].min().values, vmax=data_array.coords[stack_axis].max().values, ), - ticks=mpl.ticker.MaxNLocator(2), + ticks=matplotlib.ticker.MaxNLocator(2), cmap=color, ) except ValueError: @@ -306,9 +308,9 @@ def stack_dispersion_plot( ax: Axes | None = None, out: str | Path = "", max_stacks: int = 100, - scale_factor: float | None = None, + scale_factor: float = 0, *, - color: RGBAColorType | Colormap = "black", + color: ColorType | Colormap = "black", mode: Literal["line", "fill_between", "hide_line", "scatter"] = "line", offset_correction: Literal["zero", "constant", "constant_right"] | None = "zero", shift: float = 0, @@ -354,7 +356,7 @@ def stack_dispersion_plot( cvalues: NDArray[np.float_] = data_arr.coords[other_axis].values - if scale_factor is None: + if not scale_factor: scale_factor = _scale_factor( data_arr, stack_axis=stack_axis, @@ -417,7 +419,7 @@ def stack_dispersion_plot( x_label = other_axis y_label = stack_axis - yticker = mpl.ticker.MaxNLocator(5) + yticker = matplotlib.ticker.MaxNLocator(5) y_tick_region = [ i for i in yticker.tick_values( @@ -460,7 +462,7 @@ def stack_dispersion_plot( return fig, ax -def _y_shiftedt( +def _y_shifted( offset_correction: Literal["zero", "constant", "constant_right"] | None, marginal: xr.DataArray, coord_value: NDArray[np.float_], @@ -553,7 +555,7 @@ def _rebinning(data: DataType, stack_axis: str, max_stacks: int) -> tuple[xr.Dat def _color_for_plot( - color: Colormap | RGBAColorType, + color: Colormap | ColorType, i: int, num_plot: int, ) -> RGBAColorType: @@ -570,9 +572,3 @@ def _color_for_plot( return color msg = "color arg should be the cmap or color name or tuple as the color" raise TypeError(msg) - - -def overlapped_stack_dispersion_plot(*args: Incomplete, **kwargs: Incomplete) -> None: - """Leave it as backward compatibility.""" - msg = "use stack_dispersion_plot instead" - raise RuntimeError(msg) diff --git a/arpes/plotting/utils.py b/arpes/plotting/utils.py index 2f5f9308..b85e1bb2 100644 --- a/arpes/plotting/utils.py +++ b/arpes/plotting/utils.py @@ -15,21 +15,19 @@ from collections.abc import Sequence from datetime import UTC from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Literal, Unpack import matplotlib as mpl -import matplotlib.cm import matplotlib.pyplot as plt import numpy as np import xarray as xr -from matplotlib import cm, colorbar, colors, gridspec +from matplotlib import colorbar, colors, gridspec from matplotlib.axes import Axes from matplotlib.colors import Colormap from matplotlib.figure import Figure from matplotlib.lines import Line2D from arpes import VERSION -from arpes._typing import DataType from arpes.config import CONFIG, SETTINGS, attempt_determine_workspace, is_using_tex from arpes.utilities import normalize_to_spectrum from arpes.utilities.jupyter import get_notebook_name, get_recent_history @@ -38,12 +36,13 @@ from collections.abc import Callable from _typeshed import Incomplete + from lmfit.model import Model + from matplotlib.font_manager import FontProperties from matplotlib.image import AxesImage - from matplotlib.typing import RGBAColorType, RGBColorType + from matplotlib.typing import ColorType, RGBAColorType, RGBColorType from numpy.typing import NDArray - from typing_extensions import Unpack - from arpes._typing import ColorbarParam + from arpes._typing import ColorbarParam, DataType, MPLPlotKwargs, PLTSubplotParam __all__ = ( # General + IO @@ -108,6 +107,8 @@ "h_gradient_fill", ) +TwoDimensional = 2 + @contextlib.contextmanager def unchanged_limits(ax: Axes): @@ -120,13 +121,18 @@ def unchanged_limits(ax: Axes): ax.set_ylim(bottom=ylim[0], top=ylim[1]) -def mod_plot_to_ax(data: xr.DataArray, ax: Axes, mod, **kwargs: str | float) -> None: +def mod_plot_to_ax( + data: xr.DataArray, + ax: Axes, + mod: Model, + **kwargs: Unpack[MPLPlotKwargs], +) -> None: """Plots a model onto an axis using the data range from the passed data. Args: data(xr.DataArray): ARPES data ax (Axes): matplotlib Axes object - mod () <= FIXME + mod (lmfit.model.Model): Fitting model function **kwargs(): pass to "ax.plot" """ assert isinstance(data, xr.DataArray) @@ -137,13 +143,12 @@ def mod_plot_to_ax(data: xr.DataArray, ax: Axes, mod, **kwargs: str | float) -> ax.plot(xs, ys, **kwargs) -def h_gradient_fill( # noqa: PLR0913 +def h_gradient_fill( x1: float, x2: float, x_solid: float | None, fill_color: RGBColorType = "red", ax: Axes | None = None, - alpha: float = 1.0, **kwargs: str | float | Literal["pre", "post", "mid"], # zorder ) -> AxesImage: # <== checkme! """Fills a gradient between x1 and x2. @@ -157,7 +162,6 @@ def h_gradient_fill( # noqa: PLR0913 x_solid: fill_color (str): Color name, pass it as "c" in mpl.colors.to_rgb ax(Axes): matplotlib Axes object - alpha(float) **kwargs: Pass to im.show (Z order can be set here.) Returns: @@ -167,9 +171,9 @@ def h_gradient_fill( # noqa: PLR0913 ax = plt.gca() assert isinstance(ax, Axes) + alpha = float(kwargs.get("alpha", 1.0)) xlim, ylim = ax.get_xlim(), ax.get_ylim() assert fill_color - assert isinstance(alpha, float) z = np.empty((1, 100, 4), dtype=float) @@ -181,7 +185,7 @@ def h_gradient_fill( # noqa: PLR0913 im: AxesImage = ax.imshow( z, aspect="auto", - extent=[xmin, xmax, ymin, ymax], + extent=(xmin, xmax, ymin, ymax), origin="lower", **kwargs, ) @@ -201,7 +205,6 @@ def v_gradient_fill( y_solid: float | None, fill_color: RGBColorType = "red", ax: Axes | None = None, - alpha: float = 1.0, **kwargs: str | float, ) -> AxesImage: """Fills a gradient vertically between y1 and y2. @@ -215,7 +218,6 @@ def v_gradient_fill( y_solid: (float|solid) fill_color (str): Color name, pass it as "c" in mpl.colors.to_rgb (Default "red") ax(Axes): matplotlib Axes object - alpha (float): pass to plt.fill_between. **kwargs: (str|float): pass to ax.imshow Returns: @@ -224,10 +226,11 @@ def v_gradient_fill( if ax is None: ax = plt.gca() + alpha = float(kwargs.get("alpha", 1.0)) assert isinstance(ax, Axes) + xlim, ylim = ax.get_xlim(), ax.get_ylim() assert fill_color - assert isinstance(alpha, float) z = np.empty((100, 1, 4), dtype=float) @@ -239,7 +242,7 @@ def v_gradient_fill( im: AxesImage = ax.imshow( z, aspect="auto", - extent=[xmin, xmax, ymin, ymax], + extent=(xmin, xmax, ymin, ymax), origin="lower", **kwargs, ) @@ -256,8 +259,8 @@ def v_gradient_fill( def simple_ax_grid( n_axes: int, figsize: tuple[float, float] = (0, 0), - **kwargs: Incomplete, -) -> tuple[Figure, NDArray[Axes], NDArray[Axes]]: + **kwargs: Unpack[PLTSubplotParam], +) -> tuple[Figure, NDArray[np.object_], NDArray[np.object_]]: """Generates a square-ish set of axes and hides the extra ones. It would be nice to accept an "aspect ratio" item that will attempt to fix the @@ -294,7 +297,7 @@ def simple_ax_grid( @contextlib.contextmanager -def dark_background(overrides): +def dark_background(overrides: dict[str, Incomplete]): """Context manager for plotting "dark mode".""" defaults = { "axes.edgecolor": "white", @@ -370,7 +373,7 @@ def swap_axis_sides(ax: Axes) -> None: def transform_labels( - transform_fn: Callable, + transform_fn: Callable[..., str], fig: Figure | None = None, *, include_titles: bool = True, @@ -393,7 +396,7 @@ def transform_labels( ax.set_title(transform_fn(ax.get_title())) -def summarize(data: DataType, axes: np.ndarray | None = None): +def summarize(data: DataType, axes: NDArray[np.object_] | None = None) -> NDArray[np.object_]: """Makes a summary plot with different marginal plots represented.""" data_arr = normalize_to_spectrum(data) assert isinstance(data_arr, xr.DataArray) @@ -430,7 +433,9 @@ def sum_annotation( """Annotates that a given axis was summed over by listing the integration range.""" eV_annotation, phi_annotation = "", "" - def to_str(bound: float) -> str: + assert "use_tex" in SETTINGS + + def to_str(bound: float | None) -> str: if bound is None: return "" @@ -454,7 +459,9 @@ def mean_annotation(eV: slice | None = None, phi: slice | None = None) -> str: """Annotates that a given axis was meant (summed) over by listing the integration range.""" eV_annotation, phi_annotation = "", "" - def to_str(bound: float) -> str: + assert "use_tex" in SETTINGS + + def to_str(bound: float | None) -> str: if bound is None: return "" @@ -556,10 +563,10 @@ def quick_tex(latex_fragment: str, ax: Axes | None = None, fontsize: int = 30) - def lineplot_arr( - arr: xr.DataArray, + arr: xr.DataArray | xr.Dataset, ax: Axes | None = None, method: Literal["plot", "scatter"] = "plot", - mask=None, + mask: list[slice] | None = None, mask_kwargs: Incomplete | None = None, **kwargs: Incomplete, ) -> Axes: @@ -593,20 +600,21 @@ def lineplot_arr( def plot_arr( - arr=DataType | None, + arr: xr.DataArray | xr.Dataset, ax: Axes | None = None, - over=None, + over: AxesImage | None = None, mask: DataType | None = None, **kwargs: Incomplete, -) -> Axes: +) -> Axes | None: """Convenience method to plot an array with a mask over some other data.""" to_plot = arr if mask is None else mask + assert isinstance(to_plot, xr.DataArray | xr.Dataset) try: n_dims = len(to_plot.dims) except AttributeError: n_dims = 1 - if n_dims == 2: + if n_dims == TwoDimensional: quad = None if arr is not None: ax, quad = imshow_arr(arr, ax=ax, over=over, **kwargs) @@ -633,7 +641,7 @@ def imshow_mask( ax = plt.gca() assert isinstance(ax, Axes) if isinstance(cmap, str): - cmap = cm.get_cmap(name=cmap) + cmap = mpl.colormaps.get_cmap(cmap=cmap) assert isinstance(cmap, Colormap) cmap.set_bad("k", alpha=0) @@ -652,7 +660,7 @@ def imshow_mask( def imshow_arr( - arr: xr.DataArray, + arr: xr.DataArray | xr.Dataset, ax: Axes | None = None, over: AxesImage | None = None, origin: Literal["lower", "upper"] = "lower", @@ -680,10 +688,10 @@ def imshow_arr( Returns: The axes and quadmesh instance. """ - assert isinstance(arr, xr.DataArray) if ax is None: fig, ax = plt.subplots() assert isinstance(ax, Axes) + x, y = arr.coords[arr.dims[0]].values, arr.coords[arr.dims[1]].values extent = [y[0], y[-1], x[0], x[-1]] @@ -694,9 +702,9 @@ def imshow_arr( if vmax is None: vmax = arr.max().item() if isinstance(cmap, str): - cmap = cm.get_cmap(cmap) + cmap = mpl.colormaps.get_cmap(cmap) norm = colors.Normalize(vmin=vmin, vmax=vmax) - mappable = cm.ScalarMappable(cmap=cmap, norm=norm) + mappable = mpl.cm.ScalarMappable(cmap=cmap, norm=norm) mapped_colors = mappable.to_rgba(arr.values) mapped_colors[:, :, 3] = alpha quad = ax.imshow(mapped_colors, origin=origin, extent=extent, aspect=aspect, **kwargs) @@ -727,8 +735,6 @@ def imshow_arr( def dos_axes( orientation: str = "horiz", figsize: tuple[int, int] | tuple[()] = (), - *, - with_cbar: bool = True, ) -> tuple[Figure, tuple[Axes, ...]]: """Makes axes corresponding to density of states data. @@ -766,10 +772,10 @@ def inset_cut_locator( data: DataType, reference_data: DataType, ax: Axes | None = None, - location=None, + location: dict[str, Incomplete] | None = None, color: RGBColorType = "red", **kwargs: Incomplete, -): +) -> None: """Plots a reference cut location over a figure. Another approach is to separately plot the locator and add it in Illustrator or @@ -827,7 +833,7 @@ def resolve(name: str, value: slice | int) -> NDArray[np.float_]: assert reference_data is not None print(missing_dims) - if n_cut_dims == 2: + if n_cut_dims == TwoDimensional: # a region cut, illustrate with a rect or by suppressing background return @@ -839,14 +845,14 @@ def resolve(name: str, value: slice | int) -> NDArray[np.float_]: pass -def generic_colormap(low: float, high: float) -> Callable[[float], RGBAColorType]: +def generic_colormap(low: float, high: float) -> Callable[..., RGBAColorType]: """Generates a colormap from the cm.Blues palette, suitable for most purposes.""" delta = high - low low = low - delta / 6 high = high + delta / 6 def get_color(value: float) -> RGBAColorType: - return mpl.cm.Blues(float((value - low) / (high - low))) + return mpl.colormaps.get_cmap("Blues")(float((value - low) / (high - low))) return get_color @@ -858,7 +864,7 @@ def phase_angle_colormap( """Generates a colormap suitable for angular data or data on a unit circle like a phase.""" def get_color(value: float) -> RGBAColorType: - return cm.twilight_shifted(float((value - low) / (high - low))) + return mpl.colormaps.get_cmap("twilight_shifted")(float((value - low) / (high - low))) return get_color @@ -867,7 +873,7 @@ def delay_colormap(low: float = -1, high: float = 1) -> Callable[[float], RGBACo """Generates a colormap suitable for pump-probe delay data.""" def get_color(value: float) -> RGBAColorType: - return cm.coolwarm(float((value - low) / (high - low))) + return mpl.colormaps.get_cmap("coolwarm")(float((value - low) / (high - low))) return get_color @@ -875,7 +881,7 @@ def get_color(value: float) -> RGBAColorType: def temperature_colormap( low: float = 0, high: float = 300, - cmap: Colormap = matplotlib.cm.Blues_r, + cmap: Colormap = mpl.colormaps["Blues_r"], ) -> Callable[[float], RGBAColorType]: """Generates a colormap suitable for temperature data with fixed extent.""" @@ -885,11 +891,14 @@ def get_color(value: float) -> RGBAColorType: return get_color -def temperature_colormap_around(central, region: float = 50) -> Callable[[float], RGBAColorType]: +def temperature_colormap_around( + central: float, + region: float = 50, +) -> Callable[[float], RGBAColorType]: """Generates a colormap suitable for temperature data around a central value.""" def get_color(value: float) -> RGBAColorType: - return cm.RdBu_r(float((value - central) / region)) + return mpl.colormaps.get_cmap("RdBu_r")(float((value - central) / region)) return get_color @@ -898,9 +907,7 @@ def generic_colorbar( low: float, high: float, ax: Axes, - label: str = "", cmap: str | Colormap = "Blues", - ticks=None, **kwargs: Unpack[ColorbarParam], ) -> colorbar.Colorbar: """Generate colorbar. @@ -909,14 +916,13 @@ def generic_colorbar( low(float): value for lowest value of the colorbar high(float): value for hightst value of the colorbar ax(Axes): Matplotlib Axes object - label(str): label name cmap(str | Colormap): color map - **kwags: Pass to ColoarbarBase + **kwargs: Pass to ColoarbarBase """ + ticks = kwargs.get("ticks", [low, high]) extra_kwargs = { "orientation": "horizontal", - "label": label, - "ticks": ticks if ticks is not None else [low, high], + "ticks": ticks, } delta = high - low @@ -925,7 +931,7 @@ def generic_colorbar( extra_kwargs.update(kwargs) if isinstance(cmap, str): - cmap = cm.get_cmap(cmap) + cmap = mpl.colormaps.get_cmap(cmap) return colorbar.Colorbar( ax, cmap=cmap, @@ -942,9 +948,12 @@ def phase_angle_colorbar( ) -> colorbar.Colorbar: """Generates a colorbar suitable for plotting an angle or value on a unit circle.""" assert isinstance(ax, Axes) + assert "use_tex" in SETTINGS + label = kwargs.get("label", "Angle") + extra_kwargs = { "orientation": "horizontal", - "label": "Angle", + "label": label, "ticks": ["0", r"$\pi$", r"$2\pi$"], } @@ -954,7 +963,7 @@ def phase_angle_colorbar( extra_kwargs.update(kwargs) return colorbar.Colorbar( ax, - cmap=cm.get_cmap("twilight_shifted"), + cmap=mpl.colormaps.get_cmap("twilight_shifted"), norm=colors.Normalize(vmin=low, vmax=high), **extra_kwargs, ) @@ -970,11 +979,11 @@ def temperature_colorbar( """Generates a colorbar suitable for temperature data with fixed extent.""" assert isinstance(ax, Axes) if isinstance(cmap, str): - cmap = cm.get_cmap(cmap) - + cmap = mpl.colormaps.get_cmap(cmap) + label = kwargs.get("label", "Temperature (K)") extra_kwargs = { "orientation": "horizontal", - "label": "Temperature (K)", + "label": label, "ticks": [low, high], } extra_kwargs.update(kwargs) @@ -998,13 +1007,14 @@ def delay_colorbar( TODO make this nonsequential for use in case where you want to have a long time period after the delay or before. """ + label = kwargs.get("label", "Probe pulse delay (fs)") extra_kwargs = { "orientation": "horizontal", - "label": "Probe Pulse Delay (ps)", + "label": label, "ticks": [low, 0, high], } extra_kwargs.update(kwargs) - cmap = cm.get_cmap("coolwarm") + cmap = mpl.colormaps.get_cmap("coolwarm") return colorbar.Colorbar( ax, cmap=cmap, @@ -1014,24 +1024,25 @@ def delay_colorbar( def temperature_colorbar_around( - central, - range=50, + central: float, + temperature_range: float = 50, ax: Axes | None = None, **kwargs: Incomplete, ) -> colorbar.Colorbar: """Generates a colorbar suitable for temperature axes around a central value.""" assert isinstance(ax, Axes) + label = kwargs.get("label", "Temperature (K)") extra_kwargs = { "orientation": "horizontal", - "label": "Temperature (K)", - "ticks": [central - range, central + range], + "label": label, + "ticks": [central - temperature_range, central + temperature_range], } extra_kwargs.update(kwargs) - cmap = cm.get_cmap("RdBu_r") + cmap = mpl.colormaps.get_cmap("RdBu_r") return colorbar.Colorbar( ax, cmap=cmap, - norm=colors.Normalize(vmin=central - range, vmax=central + range), + norm=colors.Normalize(vmin=central - temperature_range, vmax=central + temperature_range), **extra_kwargs, ) @@ -1039,8 +1050,11 @@ def temperature_colorbar_around( colorbarmaps_for_axis: dict[ str, tuple[ - Callable[[float, float, Axes | None, str | Colormap, Any], colorbar.Colorbar], - Callable[[float, float, Colormap], Callable[[float], RGBAColorType]], + Callable[..., colorbar.Colorbar], + Callable[ + ..., + Callable[..., RGBAColorType], + ], ], ] = { "temp": ( @@ -1075,7 +1089,7 @@ def get_colorbars(fig: Figure | None = None) -> list[Axes]: return colorbars -def remove_colorbars(fig: Figure | None = None): +def remove_colorbars(fig: Figure | None = None) -> None: """Removes colorbars from given (or, if no given figure, current) matplotlib figure. Args: @@ -1106,7 +1120,7 @@ def generic_colorbarmap_for_data( *, keep_ticks: bool = True, **kwargs: Incomplete, -) -> tuple[colorbar.Colorbar, Callable[[float], RGBAColorType]]: +) -> tuple[colorbar.Colorbar, Callable[..., RGBAColorType]]: """Generates a colorbar and colormap which is useful in general context. Args: @@ -1128,7 +1142,7 @@ def generic_colorbarmap_for_data( ) -def polarization_colorbar(ax: Axes | None = None): +def polarization_colorbar(ax: Axes | None = None) -> colorbar.Colorbar: """Makes a colorbar which is appropriate for "polarization" (e.g. spin) data.""" assert isinstance(ax, Axes) return colorbar.Colorbar( @@ -1141,11 +1155,11 @@ def polarization_colorbar(ax: Axes | None = None): ) -def calculate_aspect_ratio(data: DataType): +def calculate_aspect_ratio(data: DataType) -> float: """Calculate the aspect ratio which should be used for plotting some data based on extent.""" data_arr = normalize_to_spectrum(data) assert isinstance(data_arr, xr.DataArray) - assert len(data.dims) == 2 + assert len(data.dims) == TwoDimensional x_extent = np.ptp(data_arr.coords[data_arr.dims[0]].values) y_extent = np.ptp(data_arr.coords[data_arr.dims[1]].values) @@ -1162,18 +1176,19 @@ class AnchoredHScaleBar(mpl.offsetbox.AnchoredOffsetbox): def __init__( self, - size=1, - extent=0.03, - label="", - loc=2, + size: float = 1, + extent: float = 0.03, + label: str = "", + loc: int = 2, ax: Axes | None = None, - pad=0.4, - borderpad=0.5, - ppad=0, - sep=2, - prop=None, - label_color=None, - frameon=True, + pad: float = 0.4, + borderpad: float = 0.5, + ppad: float = 0, + sep: int = 2, + prop: FontProperties | None = None, + label_color: ColorType | None = None, + *, + frameon: bool = True, **kwargs: Incomplete, ) -> None: """Setup the scale bar and coordinate transforms to the parent axis.""" @@ -1213,7 +1228,7 @@ def __init__( ) -def load_data_for_figure(p: str | Path): +def load_data_for_figure(p: str | Path) -> None: """Tries to load the data associated with a given figure by unpickling the saved data.""" path = str(p) stem = os.path.splitext(path)[0] @@ -1231,13 +1246,14 @@ def load_data_for_figure(p: str | Path): def savefig( - desired_path, + desired_path: str, dpi: int = 400, data=None, save_data=None, - paper=False, + *, + paper: bool = False, **kwargs: Incomplete, -): +) -> None: """The PyARPES preferred figure saving routine. Provides a number of conveniences over matplotlib's `savefig`: @@ -1270,9 +1286,9 @@ def savefig( high_dpi = max(dpi, 400) formats_for_paper = ["pdf", "png"] # not including SVG anymore because files too large - for format in formats_for_paper: + for the_format in formats_for_paper: savefig( - f"{desired_path}-PAPER.{format}", + f"{desired_path}-PAPER.{the_format}", dpi=high_dpi, data=data, paper=False, @@ -1284,7 +1300,7 @@ def savefig( return full_path = path_for_plot(desired_path) - provenance_path = full_path + ".provenance.json" + provenance_path = str(full_path) + ".provenance.json" provenance_context = { "VERSION": VERSION, "time": datetime.datetime.now(UTC).isoformat(), @@ -1317,7 +1333,7 @@ def extract(for_data): }, ) - with open(provenance_path, "w") as f: + with Path(provenance_path).open("w") as f: json.dump(provenance_context, f, indent=2) plt.savefig(full_path, dpi=dpi, **kwargs) @@ -1378,6 +1394,8 @@ def path_for_holoviews(desired_path): def name_for_dim(dim_name: str, *, escaped: bool = True) -> str: """Alternate variant of `label_for_dim`.""" + assert "use_tex" in SETTINGS + if SETTINGS["use_tex"]: name = { "temperature": "Temperature", @@ -1419,6 +1437,7 @@ def name_for_dim(dim_name: str, *, escaped: bool = True) -> str: def unit_for_dim(dim_name: str, *, escaped: bool = True) -> str: """Calculate LaTeX or fancy display label for the unit associated to a dimension.""" + assert "use_tex" in SETTINGS if SETTINGS["use_tex"]: unit = { "temperature": "K", @@ -1622,6 +1641,7 @@ def fancy_labels( def label_for_symmetry_point(point_name: str) -> str: """Determines the LaTeX label for a symmetry point shortcode.""" + assert "use_tex" in SETTINGS if SETTINGS["use_tex"]: proper_names = {"G": r"$\Gamma$", "X": r"X", "Y": r"Y", "K": r"K"} else: @@ -1691,14 +1711,14 @@ def draw(self) -> None: self.handles.append(handle) @property - def data_units_per_pixel(self): + def data_units_per_pixel(self) -> tuple[float, float]: """Gets the data/pixel conversion ratio.""" trans = self.ax.transData.transform inverse = (trans((1, 1)) - trans((0, 0))) * self.ppd return (1 / inverse[0], 1 / inverse[1]) def normalize_line_args(self, args): - def is_data_type(value): + def is_data_type(value) -> bool: return isinstance(value, np.array | np.ndarray | list | tuple) assert is_data_type(args[0]) diff --git a/arpes/utilities/attrs.py b/arpes/utilities/attrs.py index 789091cf..a068f805 100644 --- a/arpes/utilities/attrs.py +++ b/arpes/utilities/attrs.py @@ -3,7 +3,6 @@ This is useful for comparing two pieces of data, or working on implementing a data loading plugin. """ -from pprint import pprint from typing import Any import numpy as np @@ -18,7 +17,6 @@ def diff_attrs( a: DataType, b: DataType, *, - should_print: bool = True, skip_nan: bool = False, skip_composite: bool = True, ) -> None | tuple[dict[str, Any], dict[str, Any], pd.DataFrame]: @@ -70,14 +68,4 @@ def should_skip(k: str) -> bool: }, ).set_index("key") - if should_print: - print("A has:") - pprint(a_has) - - print("\nB has:") - pprint(b_has) - - print("\nDifferences:") - print(diff.to_string()) - return None return a_has, b_has, diff diff --git a/arpes/utilities/bz.py b/arpes/utilities/bz.py index 37499896..6c8d426e 100644 --- a/arpes/utilities/bz.py +++ b/arpes/utilities/bz.py @@ -11,14 +11,17 @@ import itertools import re -from collections import Counter, namedtuple -from typing import TYPE_CHECKING +from collections import Counter +from typing import TYPE_CHECKING, NamedTuple import matplotlib.path import numpy as np if TYPE_CHECKING: from _typeshed import Incomplete + from numpy.typing import NDArray + + from arpes._typing import DataType __all__ = ( "bz_symmetry", @@ -49,20 +52,26 @@ "hex": {"G", "X", "BX"}, } -SpecialPoint = namedtuple("SpecialPoint", ("name", "negate", "bz_coord")) +TWO_DIMENSIONAL = 2 + + +class SpecialPoint(NamedTuple): + name: str + negate: bool + bz_coord: NDArray[np.float_] | list[float] | tuple[float, ...] -def as_3d(points): +def as_3d(points: NDArray[np.float_]) -> NDArray[np.float_]: """Takes a 2D points list and zero pads to convert to a 3D representation.""" return np.concatenate([points, points[:, 0][:, None] * 0], axis=1) -def as_2d(points): +def as_2d(points: NDArray[np.float_]) -> NDArray[np.float_]: """Takes a 3D points and converts to a 2D representation by dropping the z coordinates.""" return points[:, :2] -def parse_single_path(path): +def parse_single_path(path: str) -> list[SpecialPoint]: """Converts a path given by high symmetry point names to numerical coordinate arrays.""" # first tokenize tokens = [name for name in re.split(r"([A-Z][a-z0-9]*(?:\([0-9,\s]+\))?)", path) if name] @@ -79,7 +88,7 @@ def parse_single_path(path): negate = True rest = rest[1:] - bz_coords = ( + bz_coords: tuple[int, ...] = ( 0, 0, 0, @@ -88,14 +97,14 @@ def parse_single_path(path): rest = "".join(c for c in rest if c not in "( \t\n\r)") bz_coords = tuple([int(c) for c in rest.split(",")]) - if len(bz_coords) == 2: + if len(bz_coords) == TWO_DIMENSIONAL: bz_coords = (*list(bz_coords), 0) points.append(SpecialPoint(name=name, negate=negate, bz_coord=bz_coords)) return points -def parse_path(paths): +def parse_path(paths: str | list[str]) -> list[list[SpecialPoint]]: """Converts paths to arrays with the coordinate locations for those paths.""" if isinstance(paths, str): # some manual string work in order to make sure we do not split on commas inside BZ indices @@ -116,7 +125,11 @@ def parse_path(paths): return [parse_single_path(p) for p in paths] -def special_point_to_vector(special_point, icell, special_points): +def special_point_to_vector( + special_point: SpecialPoint, + icell: Incomplete, + special_points: dict[str, NDArray[np.float_]], +) -> NDArray[np.float_]: """Converts a single special point to its coordinate vector.""" base = np.dot(icell.T, special_points[special_point.name]) @@ -127,9 +140,13 @@ def special_point_to_vector(special_point, icell, special_points): return base + coord.dot(icell) -def process_kpath(paths, cell, special_points=None): +def process_kpath( + paths: str | list[str], + cell: Incomplete, + special_points: dict[str, NDArray[np.float_]] | None = None, +) -> list[list[NDArray[np.float_]]]: """Converts paths consistign of point definitions to raw coordinates.""" - if len(cell) == 2: + if len(cell) == TWO_DIMENSIONAL: cell = [[*c, 0] for c in cell] + [0, 0, 0] icell = np.linalg.inv(cell).T @@ -218,7 +235,11 @@ def flat_bz_indices_list(bz_indices_list=None): return indices -def generate_2d_equivalent_points(points, icell, bz_indices_list=None): +def generate_2d_equivalent_points( + points: NDArray[np.float_], + icell: NDArray[np.float_], + bz_indices_list=None, +) -> NDArray[np.float_]: """Generates the equivalent points in higher order Brillouin zones.""" points_list = [] for x, y in flat_bz_indices_list(bz_indices_list): @@ -239,7 +260,11 @@ def generate_2d_equivalent_points(points, icell, bz_indices_list=None): return np.unique(np.concatenate(points_list), axis=0) -def build_2dbz_poly(vertices=None, icell=None, cell=None): +def build_2dbz_poly( + vertices: NDArray[np.float_] | None = None, + icell: NDArray[np.float_] | None = None, + cell=None, +): """Converts brillouin zone or equivalent information to a polygon mask. This mask can be used to mask away data outside the zone boundary. @@ -353,7 +378,7 @@ def axis_along(data, S): return max_dim -def reduced_bz_poly(data, *, scale_zone: bool = False): +def reduced_bz_poly(data: DataType, *, scale_zone: bool = False) -> NDArray[np.float_]: """Returns a polynomial representing the reduce first Brillouin zone.""" symmetry = bz_symmetry(data.S.iter_own_symmetry_points) point_names = _POINT_NAMES_FOR_SYMMETRY[symmetry] @@ -465,7 +490,7 @@ def reduced_bz_selection(data): return data -def bz_cutter(symmetry_points, reduced=True): +def bz_cutter(symmetry_points, *, reduced: bool = True): """Cuts data so that it areas outside the Brillouin zone are masked away. TODO: UNFINISHED. @@ -474,7 +499,7 @@ def bz_cutter(symmetry_points, reduced=True): def build_bz_mask(data): pass - def cutter(data, cut_value=np.nan): + def cutter(data, cut_value: float = np.nan): mask = build_bz_mask(data) out = data.copy() diff --git a/arpes/utilities/collections.py b/arpes/utilities/collections.py index 75c71bad..9135a98a 100644 --- a/arpes/utilities/collections.py +++ b/arpes/utilities/collections.py @@ -30,7 +30,7 @@ def __sub__(self, other: MappableDict): return MappableDict({k: self.get(k) - other.get(k) for k in self}) - def __mul__(self, other: MappableDict): + def __mul__(self, other: MappableDict) -> MappableDict: """Applies `*` onto values.""" if set(self.keys()) != set(other.keys()): msg = "You can only multiply two MappableDicts with the same keys." @@ -38,7 +38,7 @@ def __mul__(self, other: MappableDict): return MappableDict({k: self.get(k) * other.get(k) for k in self}) - def __truediv__(self, other: MappableDict): + def __truediv__(self, other: MappableDict) -> MappableDict: """Applies `/` onto values.""" if set(self.keys()) != set(other.keys()): msg = "You can only divide two MappableDicts with the same keys." @@ -46,7 +46,7 @@ def __truediv__(self, other: MappableDict): return MappableDict({k: self.get(k) / other.get(k) for k in self}) - def __floordiv__(self, other: MappableDict): + def __floordiv__(self, other: MappableDict) -> MappableDict: """Applies `//` onto values.""" if set(self.keys()) != set(other.keys()): msg = "You can only divide (//) two MappableDicts with the same keys." @@ -54,12 +54,12 @@ def __floordiv__(self, other: MappableDict): return MappableDict({k: self.get(k) // other.get(k) for k in self}) - def __neg__(self): + def __neg__(self) -> MappableDict: """Applies unary negation onto values.""" return MappableDict({k: -self.get(k) for k in self}) -def deep_update(destination: Any, source: Any) -> dict[str, Any]: +def deep_update(destination: dict[str, Any], source: dict[str, Any]) -> dict[str, Any]: """Doesn't clobber keys further down trees like doing a shallow update would. Instead recurse down from the root and update as appropriate. @@ -80,7 +80,7 @@ def deep_update(destination: Any, source: Any) -> dict[str, Any]: return destination -def deep_equals(a: Any, b: Any) -> bool: +def deep_equals(a: Any, b: Any) -> bool | None: """An equality check that looks into common collection types.""" if not isinstance(b, type(a)): return False diff --git a/arpes/utilities/funcutils.py b/arpes/utilities/funcutils.py index d90c62a0..7113a748 100644 --- a/arpes/utilities/funcutils.py +++ b/arpes/utilities/funcutils.py @@ -58,7 +58,7 @@ def group_by(grouping: int | Generator, sequence: Sequence) -> list: return groups -def collect_leaves(tree: dict[str, Any], is_leaf: Any = None) -> dict: +def collect_leaves(tree: dict[str, Any], is_leaf: Incomplete = None) -> dict: """Produces a flat representation of the leaves. Leaves with the same key are collected into a list in the order of appearance, @@ -84,7 +84,7 @@ def reducer(dd: dict, item: tuple[str, NDArray[np.float_]]) -> dict: def iter_leaves( tree: dict[str, Any], - is_leaf: Callable | None = None, + is_leaf: Callable[..., bool] | None = None, ) -> Iterator[tuple[str, ndarray]]: """Iterates across the leaves of a nested dictionary. @@ -100,7 +100,7 @@ def iter_leaves( """ if is_leaf is None: - def is_leaf(x): + def is_leaf(x: dict) -> bool: return not isinstance(x, dict) for k, v in tree.items(): @@ -110,7 +110,7 @@ def is_leaf(x): yield from iter_leaves(v) -def lift_dataarray_to_generic(f): +def lift_dataarray_to_generic(func: Callable[..., DataType]) -> Callable[..., DataType]: """A functorial decorator that lifts functions to operate over xarray types. (xr.DataArray, *args, **kwargs) -> xr.DataArray @@ -123,12 +123,12 @@ def lift_dataarray_to_generic(f): i.e. one that will operate either over xr.DataArrays or xr.Datasets. """ - @functools.wraps(f) - def func_wrapper(data: DataType, *args: Incomplete, **kwargs: Incomplete): + @functools.wraps(func) + def func_wrapper(data: DataType, *args: Incomplete, **kwargs: Incomplete) -> DataType: if isinstance(data, xr.DataArray): - return f(data, *args, **kwargs) + return func(data, *args, **kwargs) assert isinstance(data, xr.Dataset) - new_vars = {datavar: f(data[datavar], *args, **kwargs) for datavar in data.data_vars} + new_vars = {datavar: func(data[datavar], *args, **kwargs) for datavar in data.data_vars} for var_name, var in new_vars.items(): if isinstance(var, xr.DataArray) and var.name is None: @@ -148,22 +148,22 @@ class Debounce: slider. """ - def __init__(self, period) -> None: + def __init__(self, period: float) -> None: """Sets up the internal state for debounce tracking.""" self.period = period # never call the wrapped function more often than this (in seconds) self.count = 0 # how many times have we successfully called the function self.count_rejected = 0 # how many times have we rejected the call self.last = None # the last time it was called - def reset(self): + def reset(self) -> None: """Force a reset of the timer, aka the next call will always work.""" self.last = None - def __call__(self, f): + def __call__(self, func: Callable[..., Any]) -> Callable[..., None]: """The wrapper call which defers execution if the function was actually called recently.""" - @functools.wraps(f) - def wrapped(*args: Incomplete, **kwargs: Incomplete): + @functools.wraps(func) + def wrapped(*args: Incomplete, **kwargs: Incomplete) -> None: now = time.time() willcall = False if self.last is not None: @@ -177,7 +177,7 @@ def wrapped(*args: Incomplete, **kwargs: Incomplete): # set these first in case we throw an exception self.last = now # don't use time.time() self.count += 1 - f(*args, **kwargs) # call wrapped function + func(*args, **kwargs) # call wrapped function else: self.count_rejected += 1 diff --git a/arpes/utilities/jupyter.py b/arpes/utilities/jupyter.py index a9167f5d..470bb979 100644 --- a/arpes/utilities/jupyter.py +++ b/arpes/utilities/jupyter.py @@ -24,7 +24,7 @@ ) -def wrap_tqdm(x, interactive=True, *args: Incomplete, **kwargs: Incomplete): +def wrap_tqdm(x, *args: Incomplete, interactive: bool = True, **kwargs: Incomplete): """Wraps with tqdm_notebook but supports disabling with a flag.""" if not interactive: return x @@ -74,11 +74,10 @@ def get_notebook_name() -> str | None: can only return None. """ jupyter_info = get_full_notebook_information() - - try: - return jupyter_info["session"]["notebook"]["name"].split(".")[0] - except (KeyError, TypeError): + if type(jupyter_info) is None: return None + assert jupyter_info is not None + return jupyter_info["session"]["notebook"]["name"].split(".")[0] def generate_logfile_path() -> Path: @@ -92,12 +91,12 @@ def generate_logfile_path() -> Path: return Path("logs") / full_name -def get_recent_history(n_items=10) -> list[str]: +def get_recent_history(n_items: int = 10) -> list[str]: """Fetches recent cell evaluations for context on provenance outputs.""" try: - import IPython + from IPython.core.getipython import get_ipython - ipython = IPython.get_ipython() + ipython = get_ipython() return [ l[-1] for l in list(ipython.history_manager.get_tail(n=n_items, include_latest=True)) @@ -106,19 +105,19 @@ def get_recent_history(n_items=10) -> list[str]: return ["No accessible history."] -def get_recent_logs(n_bytes=1000) -> list[str]: +def get_recent_logs(n_bytes: int = 1000) -> list[str]: """Fetches a recent chunk of user logs. Used to populate a context on provenance outputs.""" import arpes.config try: - import IPython + from IPython.core.getipython import get_ipython - ipython = IPython.get_ipython() + ipython = get_ipython() if arpes.config.CONFIG["LOGGING_STARTED"]: logging_file = arpes.config.CONFIG["LOGGING_FILE"] + assert isinstance(logging_file, str | Path) - print(logging_file) - with open(logging_file, "rb") as file: + with Path(logging_file).open("rb") as file: try: file.seek(-n_bytes, os.SEEK_END) except OSError: diff --git a/arpes/utilities/math.py b/arpes/utilities/math.py index de955bd0..73fb808e 100644 --- a/arpes/utilities/math.py +++ b/arpes/utilities/math.py @@ -16,7 +16,7 @@ from numpy.typing import NDArray -def derivative(f: Callable, arg_idx: int = 0) -> float: +def derivative(f: Callable[..., float], arg_idx: int = 0) -> float: """Defines a simple midpoint derivative.""" def d(*args: Incomplete): @@ -32,7 +32,7 @@ def d(*args: Incomplete): return d -def polarization(up, down): +def polarization(up: NDArray[np.float_], down: NDArray[np.float_]) -> NDArray[np.float_]: """The equivalent normalized difference for a two component signal.""" return (up - down) / (up + down) @@ -58,7 +58,7 @@ def shift_by( arr: NDArray[np.float_], value: xr.DataArray | NDArray[np.float_], axis: float = 0, - by_axis=0, + by_axis: float = 0, **kwargs: Incomplete, ) -> NDArray[np.float_]: """Shifts slices of `arr` perpendicular to `by_axis` by `value`. diff --git a/arpes/utilities/qt/app.py b/arpes/utilities/qt/app.py index 9a3f5285..b451971f 100644 --- a/arpes/utilities/qt/app.py +++ b/arpes/utilities/qt/app.py @@ -5,6 +5,7 @@ from collections import defaultdict from typing import TYPE_CHECKING, Any +import matplotlib as mpl import numpy as np import pyqtgraph as pg import xarray as xr @@ -120,10 +121,8 @@ def build_pg_cmap(colormap): def set_colormap(self, colormap): """Finds all `DataArrayImageView` instances and sets their color palette.""" - import matplotlib.cm - if isinstance(colormap, str): - colormap = matplotlib.cm.get_cmap(colormap) + colormap = mpl.colormaps.get_cmap(colormap) cmap = self.build_pg_cmap(colormap) for view in self.views.values(): diff --git a/arpes/utilities/string.py b/arpes/utilities/string.py index 433d315d..b1202e77 100644 --- a/arpes/utilities/string.py +++ b/arpes/utilities/string.py @@ -11,12 +11,10 @@ def safe_decode(input_bytes: bytes, prefer: str = "") -> str | None: if prefer: codecs = [prefer, *codecs] - - for codec in codecs: - try: + try: + for codec in codecs: return input_bytes.decode(codec) - except UnicodeDecodeError: - pass + except UnicodeDecodeError: + pass - input_bytes.decode("utf-8") # COULD NOT DETERMINE CODEC, RAISE return None diff --git a/arpes/widgets.py b/arpes/widgets.py index 34f89c81..a7b38c82 100644 --- a/arpes/widgets.py +++ b/arpes/widgets.py @@ -304,7 +304,7 @@ def data(self, new_data: xr.DataArray) -> None: @property def mask_cmap(self): if self._mask_cmap is None: - self._mask_cmap = mpl.cm.get_cmap(self.mask_kwargs.pop("cmap", "Reds")) + self._mask_cmap = mpl.colormaps.get_cmap(self.mask_kwargs.pop("cmap", "Reds")) self._mask_cmap.set_bad("k", alpha=0) return self._mask_cmap diff --git a/arpes/workflow.py b/arpes/workflow.py index b6f75c53..5f3ebceb 100644 --- a/arpes/workflow.py +++ b/arpes/workflow.py @@ -78,8 +78,6 @@ def _open_path(p: Path | str) -> None: if "win" in sys.platform: subprocess.Popen(rf"explorer {p}") - print(p) - @with_workspace def go_to_workspace(workspace=None) -> None: diff --git a/arpes/xarray_extensions.py b/arpes/xarray_extensions.py index afa61365..00ac353e 100644 --- a/arpes/xarray_extensions.py +++ b/arpes/xarray_extensions.py @@ -979,7 +979,7 @@ def inner_potential(self) -> float: return self._obj.attrs["inner_potential"] return 10 - def find_spectrum_energy_edges(self, *, indices: bool = False): + def find_spectrum_energy_edges(self, *, indices: bool = False) -> NDArray[np.float_]: assert isinstance(self._obj, xr.Dataset | xr.DataArray) energy_marginal = self._obj.sum([d for d in self._obj.dims if d not in ["eV"]]) @@ -1001,7 +1001,11 @@ def find_spectrum_energy_edges(self, *, indices: bool = False): delta = self._obj.G.stride(generic_dim_names=False) return edges * delta["eV"] + self._obj.coords["eV"].values[0] - def find_spectrum_angular_edges_full(self, *, indices: bool = False): + def find_spectrum_angular_edges_full( + self, + *, + indices: bool = False, + ) -> tuple[NDArray[np.float_], NDArray[np.float_], xr.DataArray]: # as a first pass, we need to find the bottom of the spectrum, we will use this # to select the active region and then to rebin into course steps in energy from 0 # down to this region @@ -1013,7 +1017,7 @@ def find_spectrum_angular_edges_full(self, *, indices: bool = False): if high_edge - low_edge < 0.15: # Doesn't look like the automatic inference of the energy edge was valid - high_edge = 0 + high_edge = 0.0 low_edge = np.min(self._obj.coords["eV"].values) angular_dim = "pixel" if "pixel" in self._obj.dims else "phi" @@ -1066,7 +1070,13 @@ def find_spectrum_angular_edges_full(self, *, indices: bool = False): return low_edges, high_edges, rebinned.coords["eV"] - def zero_spectrometer_edges(self, cut_margin=None, interp_range=None, low=None, high=None): + def zero_spectrometer_edges( + self, + cut_margin=None, + interp_range=None, + low=None, + high=None, + ) -> xr.DataArray | xr.Dataset: if low is not None: assert high is not None assert len(low) == len(high) == 2 @@ -1191,7 +1201,10 @@ def meso_effective_selector(self) -> slice: return slice(np.max(energy_edge) - 0.3, np.max(energy_edge) - 0.1) def region_sel(self, *regions: Incomplete) -> xr.Dataset | xr.DataArray: - def process_region_selector(selector: slice | DesignatedRegions, dimension_name: str): + def process_region_selector( + selector: slice | DesignatedRegions, + dimension_name: str, + ) -> slice | Callable[..., slice]: if isinstance(selector, slice): return selector @@ -1240,7 +1253,7 @@ def process_region_selector(selector: slice | DesignatedRegions, dimension_name: obj = self._obj - def unpack_dim(dim_name): + def unpack_dim(dim_name: str) -> str: if dim_name == "angular": return "pixel" if "pixel" in obj.dims else "phi" @@ -1618,7 +1631,7 @@ def beamline_info(self) -> dict[str, xr.DataArray | NDArray[np.float_] | float]: ) @property - def sweep_settings(self) -> dict[str, xr.DataArray | NDArray[np.float_] | float]: + def sweep_settings(self) -> dict[str, xr.DataArray | NDArray[np.float_] | float | None]: """For datasets acquired with swept acquisition settings, provides those settings.""" return { "high_energy": self._obj.attrs.get("sweep_high_energy"), @@ -1631,16 +1644,16 @@ def sweep_settings(self) -> dict[str, xr.DataArray | NDArray[np.float_] | float] def probe_polarization(self) -> tuple[float, float]: """Provides the probe polarization of the UV/x-ray source.""" return ( - self._obj.attrs.get("probe_polarization_theta"), - self._obj.attrs.get("probe_polarization_alpha"), + self._obj.attrs.get("probe_polarization_theta", np.nan), + self._obj.attrs.get("probe_polarization_alpha", np.nan), ) @property def pump_polarization(self) -> tuple[float, float]: """For Tr-ARPES experiments, provides the pump polarization.""" return ( - self._obj.attrs.get("pump_polarization_theta"), - self._obj.attrs.get("pump_polarization_alpha"), + self._obj.attrs.get("pump_polarization_theta", np.nan), + self._obj.attrs.get("pump_polarization_alpha", np.nan), ) @property @@ -1657,11 +1670,11 @@ def prebinning(self) -> dict[str, Any]: def monochromator_info(self) -> dict[str, float]: """Details about the monochromator used on the UV/x-ray source.""" return { - "grating_lines_per_mm": self._obj.attrs.get("grating_lines_per_mm"), + "grating_lines_per_mm": self._obj.attrs.get("grating_lines_per_mm", np.nan), } @property - def undulator_info(self) -> dict[str, str | float]: + def undulator_info(self) -> dict[str, str | float | None]: """Details about the undulator for data performed at an undulator source.""" return { "gap": self._obj.attrs.get("undulator_gap"), @@ -1672,7 +1685,7 @@ def undulator_info(self) -> dict[str, str | float]: } @property - def analyzer_detail(self) -> dict[str, str | float]: + def analyzer_detail(self) -> dict[str, str | float | None]: """Details about the analyzer, its capabilities, and metadata.""" return { "name": self._obj.attrs.get("analyzer_name"), diff --git a/tests/conftest.py b/tests/conftest.py index ce4affaf..3dec17f0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -21,10 +21,14 @@ class EXPECTEDD(TypedDict, total=False): + """TypedDict for expected.""" + scan_info: SCANINFO class SCENARIO(TypedDict, total=False): + """TypedDict for SCENARIO.""" + file: str @@ -87,5 +91,4 @@ def load(path: str) -> xr.DataArray | xr.Dataset: arpes.config.load_plugins() yield sandbox arpes.config.CONFIG["WORKSPACE"] = None - arpes.config.update_configuration(user_path=None) arpes.endstations._ENDSTATION_ALIASES = {} diff --git a/tests/test_basic_data_loading.py b/tests/test_basic_data_loading.py index 4cdebe77..44935ace 100644 --- a/tests/test_basic_data_loading.py +++ b/tests/test_basic_data_loading.py @@ -1,6 +1,8 @@ +"""Test for basic data loading.""" +from __future__ import annotations + import contextlib -from collections.abc import Generator -from typing import ClassVar +from typing import TYPE_CHECKING, Any, ClassVar import numpy as np import pytest @@ -8,8 +10,11 @@ from arpes.utilities.conversion import convert_to_kspace +if TYPE_CHECKING: + from _typeshed import Incomplete + -def pytest_generate_tests(metafunc): +def pytest_generate_tests(metafunc: Incomplete): """[TODO:summary]. [TODO:description] @@ -56,7 +61,7 @@ class TestMetadata: "temperature": None, "temperature_cryotip": None, "pressure": None, - "polarization": (None, None), + "polarization": (np.nan, np.nan), "photon_flux": None, "photocurrent": None, "probe": None, @@ -109,7 +114,7 @@ class TestMetadata: "pump_profile": None, "pump_linewidth": None, "pump_temporal_width": None, - "pump_polarization": (None, None), + "pump_polarization": (np.nan, np.nan), "probe_wavelength": None, "probe_energy": 5.93, "probe_fluence": None, @@ -118,7 +123,7 @@ class TestMetadata: "probe_profile": None, "probe_linewidth": 0.015, "probe_temporal_width": None, - "probe_polarization": (None, None), + "probe_polarization": (np.nan, np.nan), "repetition_rate": np.nan, }, "sample_info": { @@ -175,7 +180,7 @@ class TestMetadata: "work_function": 4.401, }, "beamline_info": { - "hv": 90, + "hv": 90.0, "beam_current": 500.761, "linewidth": None, "photon_polarization": (0, 0), @@ -185,12 +190,12 @@ class TestMetadata: "harmonic": 2, "type": "elliptically_polarized_undulator", "gap": 41.720, - "z": 0, + "z": 0.0, "polarization": 0, }, "repetition_rate": 5e8, "monochromator_info": { - "grating_lines_per_mm": None, + "grating_lines_per_mm": np.nan, }, }, "daq_info": { @@ -236,7 +241,7 @@ class TestMetadata: "temperature": None, "temperature_cryotip": None, "pressure": None, - "polarization": (None, None), + "polarization": (np.nan, np.nan), "photon_flux": None, "photocurrent": None, "probe": None, @@ -267,7 +272,7 @@ class TestMetadata: "hv": pytest.approx(125, 1e-2), "linewidth": None, "beam_current": pytest.approx(500.44, 1e-2), - "photon_polarization": (None, None), + "photon_polarization": (np.nan, np.nan), "repetition_rate": 5e8, "entrance_slit": None, "exit_slit": None, @@ -314,7 +319,7 @@ class TestMetadata: def test_load_file_and_basic_attributes( self, - sandbox_configuration: Generator, + sandbox_configuration: Incomplete, file: str, expected: dict[str, str | None | dict[str, float]], ) -> None: @@ -344,7 +349,7 @@ class TestBasicDataLoading: data = None - scenarios: ClassVar[list] = [ + scenarios: ClassVar[list[Incomplete]] = [ # Lanzara Group "Main Chamber" ( "main_chamber_load_cut", @@ -617,7 +622,12 @@ class TestBasicDataLoading: ), ] - def test_load_file_and_basic_attributes(self, sandbox_configuration, file, expected): + def test_load_file_and_basic_attributes( + self, + sandbox_configuration: Incomplete, + file: str, + expected: dict[str, Any], + ) -> None: """[TODO:summary]. [TODO:description] @@ -661,8 +671,8 @@ def test_load_file_and_basic_attributes(self, sandbox_configuration, file, expec for d in by_dims ] - assert list(zip(by_dims, ranges)) == list( - zip(by_dims, [expected["coords"][d] for d in by_dims]), + assert list(zip(by_dims, ranges, strict=True)) == list( + zip(by_dims, [expected["coords"][d] for d in by_dims], strict=True), ) for k, v in expected["coords"].items(): if isinstance(v, float): diff --git a/tests/test_curve_fitting.py b/tests/test_curve_fitting.py index c119e4f1..aff05323 100644 --- a/tests/test_curve_fitting.py +++ b/tests/test_curve_fitting.py @@ -7,6 +7,8 @@ from arpes.fits.utilities import broadcast_model from arpes.io import example_data +TOLERANCE = 1e-4 + @pytest.mark.skip() def test_broadcast_fitting() -> None: @@ -17,4 +19,4 @@ def test_broadcast_fitting() -> None: fit_results = broadcast_model([AffineBroadenedFD], near_ef, "phi") - assert np.abs(fit_results.F.p("a_fd_center").values.mean() + 0.00287) < 1e-4 + assert np.abs(fit_results.F.p("a_fd_center").values.mean() + 0.00287) < TOLERANCE diff --git a/tests/test_derivative_analysis.py b/tests/test_derivative_analysis.py index a30f5d2e..c93deb4d 100644 --- a/tests/test_derivative_analysis.py +++ b/tests/test_derivative_analysis.py @@ -1,12 +1,20 @@ +"""Test for derivative procedure.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + import numpy as np import pytest -import xarray as xr from arpes.analysis.derivative import dn_along_axis from arpes.analysis.filters import gaussian_filter_arr +if TYPE_CHECKING: + import xarray as xr + from _typeshed import Incomplete + -def test_dataarray_derivatives(sandbox_configuration) -> None: +def test_dataarray_derivatives(sandbox_configuration: Incomplete) -> None: """Test for derivativation of xarray. Nick ran into an issue where he could not call dn_along_axis with a smooth function that diff --git a/tests/test_direct_and_example_data_loading.py b/tests/test_direct_and_example_data_loading.py index f7ae9e60..e272f2bb 100644 --- a/tests/test_direct_and_example_data_loading.py +++ b/tests/test_direct_and_example_data_loading.py @@ -1,4 +1,8 @@ +"""Test for data loading.""" +from __future__ import annotations + from pathlib import Path +from typing import TYPE_CHECKING import numpy as np import xarray as xr @@ -6,8 +10,11 @@ from arpes.endstations.plugin.ALG_main import ALGMainChamber from arpes.io import load_data, load_example_data +if TYPE_CHECKING: + from _typeshed import Incomplete + -def test_load_data(sandbox_configuration) -> None: +def test_load_data(sandbox_configuration: Incomplete) -> None: """[TODO:summary]. [TODO:description] @@ -25,7 +32,7 @@ def test_load_data(sandbox_configuration) -> None: assert data.spectrum.shape == (240, 240) -def test_load_data_with_plugin_specified(sandbox_configuration) -> None: +def test_load_data_with_plugin_specified(sandbox_configuration: Incomplete) -> None: """[TODO:summary]. [TODO:description] @@ -45,7 +52,7 @@ def test_load_data_with_plugin_specified(sandbox_configuration) -> None: assert np.all(data.spectrum.values == directly_specified_data.spectrum.values) -def test_load_example_data(sandbox_configuration) -> None: +def test_load_example_data(sandbox_configuration: Incomplete) -> None: """[TODO:summary]. [TODO:description] diff --git a/tests/test_generic_utilities.py b/tests/test_generic_utilities.py index c9516d27..505592b8 100644 --- a/tests/test_generic_utilities.py +++ b/tests/test_generic_utilities.py @@ -1,3 +1,4 @@ +"""Test for generic utility.""" import pytest from arpes.utilities import deep_equals, deep_update diff --git a/tests/test_momentum_conversion.py b/tests/test_momentum_conversion.py index d25e4f29..fd10007a 100644 --- a/tests/test_momentum_conversion.py +++ b/tests/test_momentum_conversion.py @@ -8,7 +8,8 @@ from arpes.utilities.conversion.forward import convert_through_angular_point -def load_energy_corrected() -> xr.Dataset: +def load_energy_corrected() -> xr.DataArray: + """Loading map data (example_data.map).""" return example_data.map.spectrum @@ -48,8 +49,7 @@ def test_cut_momentum_conversion_ranges() -> None: ",", "", ).split() - expected_values = [int(m) for m in expected_values] - assert kdata.argmax(dim="eV").values.tolist() == expected_values + assert kdata.argmax(dim="eV").values.tolist() == [int(m) for m in expected_values] def test_fermi_surface_conversion() -> None: diff --git a/tests/test_montage.py b/tests/test_montage.py index e69de29b..109f0e4c 100644 --- a/tests/test_montage.py +++ b/tests/test_montage.py @@ -0,0 +1 @@ +"""Unit test for utility/combine.py/concat_along_phi.""" diff --git a/tests/test_qt.py b/tests/test_qt.py index cfb5c994..49fbb437 100644 --- a/tests/test_qt.py +++ b/tests/test_qt.py @@ -1,3 +1,4 @@ +"""Unit test for qt related.""" from typing import TYPE_CHECKING from PyQt5 import QtCore diff --git a/tests/test_time_configuration.py b/tests/test_time_configuration.py index 5d506a85..faf6d018 100644 --- a/tests/test_time_configuration.py +++ b/tests/test_time_configuration.py @@ -1,3 +1,4 @@ +"""test for time configuration.""" import os.path import arpes.config @@ -15,7 +16,9 @@ def test_patched_config(sandbox_configuration) -> None: [TODO:description] """ sandbox_configuration.with_workspace("basic") + assert "name" in arpes.config.CONFIG["WORKSPACE"] assert arpes.config.CONFIG["WORKSPACE"]["name"] == "basic" + assert "path" in arpes.config.CONFIG["WORKSPACE"] assert str(arpes.config.CONFIG["WORKSPACE"]["path"]).split(os.sep)[-2:] == ["datasets", "basic"]