From 7cafa0f1ead80745162cd5807798c06fc4f9d91e Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Thu, 8 Feb 2024 16:52:48 +0900 Subject: [PATCH 1/7] =?UTF-8?q?=F0=9F=92=AC=20=20Update=20type=20hints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- arpes/plotting/bz.py | 70 +++++++++++++++++++------------------- arpes/utilities/bz.py | 1 + arpes/utilities/bz_spec.py | 1 + arpes/widgets.py | 20 +++++------ tests/test_bz.py | 7 ---- 5 files changed, 47 insertions(+), 52 deletions(-) diff --git a/arpes/plotting/bz.py b/arpes/plotting/bz.py index a9c1dc13..9cb97d04 100644 --- a/arpes/plotting/bz.py +++ b/arpes/plotting/bz.py @@ -1,6 +1,5 @@ """Utilities related to plotting Brillouin zones and data onto them.""" -# pylint: disable=import-error from __future__ import annotations import itertools @@ -19,7 +18,7 @@ from scipy.spatial.transform import Rotation from arpes.analysis.mask import apply_mask_to_coords -from arpes.utilities import normalize_to_spectrum +from arpes.constants import TWO_DIMENSION from arpes.utilities.bz import build_2dbz_poly, hex_cell_2d, process_kpath from arpes.utilities.bz_spec import A_GRAPHENE, A_WS2, A_WSe2 from arpes.utilities.geometry import polyhedron_intersect_plane @@ -50,7 +49,7 @@ "overplot_standard", ) -overplot_library = { +overplot_library: dict[str, Callable[..., dict[str, list[list[float]]]]] = { "graphene": lambda: {"cell": hex_cell_2d(A_GRAPHENE)}, "ws2": lambda: {"cell": hex_cell_2d(A_WS2)}, "wwe2": lambda: {"cell": hex_cell_2d(A_WSe2)}, @@ -71,13 +70,13 @@ def segments_standard( name: str = "graphene", - rotate: float = 0.0, + rotate_rad: float = 0.0, ) -> tuple[list[NDArray[np.float_]], list[NDArray[np.float_]]]: name = name.lower() - specification = overplot_library[name]() + specification: dict[str, list[list[float]]] = overplot_library[name]() transformations = [] - if rotate: - transformations = [Rotation.from_rotvec([0, 0, rotate])] + if rotate_rad: + transformations = [Rotation.from_rotvec([0, 0, rotate_rad])] return bz2d_segments(specification["cell"], transformations) @@ -142,7 +141,7 @@ def apply(self, vectors: ArrayLike, *, inverse: bool = False) -> NDArray[np.floa """ vectors = np.asarray(vectors) - if vectors.ndim > 2 or vectors.shape[-1] not in {2, 3}: # noqa: PLR2004 + if vectors.ndim > TWO_DIMENSION or vectors.shape[-1] not in {2, 3}: msg = "Expected a 2D or 3D vector (2 or 3,)" msg += f" of list of vectors (N, 2 or 3,), instead receivied: {vectors.shape}" raise ValueError( @@ -188,13 +187,24 @@ def apply_transformations( def plot_plane_to_bz( - cell: Sequence[Sequence[float]], + cell: Sequence[Sequence[float]] | NDArray[np.float_], plane: str | list[NDArray[np.float_]], ax: Axes, special_points: dict[str, NDArray[np.float_]] | None = None, facecolor: ColorType = "red", ) -> None: - """Plots a 2D cut plane onto a Brillouin zone.""" + """Plots a 2D cut plane onto a Brillouin zone. + + Args: + cell: [TODO:description] + plane: [TODO:description] + ax: [TODO:description] + special_points: [TODO:description] + facecolor: [TODO:description] + + Returns: + [TODO:description] + """ from ase.dft.bz import bz_vertices if isinstance(plane, str): @@ -218,19 +228,19 @@ def plot_plane_to_bz( def plot_data_to_bz( data: DataType, - cell: Sequence[Sequence[float]], + cell: Sequence[Sequence[float]] | NDArray[np.float_], **kwargs: Incomplete, ): """A dimension agnostic tool used to plot ARPES data onto a Brillouin zone.""" - if len(data) == 3: # noqa: PLR2004 + if len(data) == TWO_DIMENSION + 1: return plot_data_to_bz3d(data, cell, **kwargs) return plot_data_to_bz2d(data, cell, **kwargs) def plot_data_to_bz2d( # noqa: PLR0913 - data: DataType, - cell: Sequence[Sequence[float]], + data_array: xr.DataArray, + cell: Sequence[Sequence[float]] | NDArray[np.float_], rotate: float | None = None, shift: NDArray[np.float_] | None = None, scale: float | None = None, @@ -242,10 +252,8 @@ def plot_data_to_bz2d( # noqa: PLR0913 **kwargs: Incomplete, ) -> Path | tuple[Figure, Axes]: """Plots data onto a 2D Brillouin zone.""" - data_array = normalize_to_spectrum(data) - assert data_array.S.is_kspace, "You must k-space convert data before plotting to BZs" - assert isinstance(data_array, xr.DataArray) + assert isinstance(data_array, xr.DataArray), "data_array must be xr.DataArray, not Dataset" if bz_number is None: bz_number = (0, 0) @@ -256,7 +264,7 @@ def plot_data_to_bz2d( # noqa: PLR0913 bz2d_plot(cell, paths="all", ax=ax) assert isinstance(ax, Axes) - if len(cell) == 2: # noqa: PLR2004 + if len(cell) == TWO_DIMENSION: cell = [[*list(c), 0] for c in cell] + [[0, 0, 1]] icell = np.linalg.inv(cell).T @@ -307,7 +315,7 @@ def plot_data_to_bz2d( # noqa: PLR0913 def plot_data_to_bz3d( data: DataType, - cell: Sequence[Sequence[float]], + cell: Sequence[Sequence[float]] | NDArray[np.float_], **kwargs: Incomplete, ) -> None: """Plots ARPES data onto a 3D Brillouin zone.""" @@ -320,23 +328,25 @@ def plot_data_to_bz3d( raise NotImplementedError(msg) -def bz_plot(cell: Sequence[Sequence[float]], *args, **kwargs: Incomplete) -> Axes: +def bz_plot( + cell: Sequence[Sequence[float]] | NDArray[np.float_], *args, **kwargs: Incomplete +) -> Axes: """Dimension generic BZ plot which uses the cell dimension to delegate.""" logger.debug(f"size of cell is: {format(len(cell))}") - if len(cell) > 2: # noqa: PLR2004 + if len(cell) > TWO_DIMENSION: return bz3d_plot(cell, *args, **kwargs) return bz2d_plot(cell, *args, **kwargs) def bz3d_plot( - cell: Sequence[Sequence[float]], + cell: Sequence[Sequence[float]] | NDArray[np.float_], paths: str | list[str | float] | None = None, kpoints: Sequence[Sequence[float]] | None = None, ax: Axes | None = None, elev: float | None = None, scale: float = 1, - repeat: tuple[int, int, int] | None = None, + repeat: tuple[int, int, int] = (1, 1, 1), transformations: list[Transformation] | None = None, *, vectors: bool = False, @@ -407,13 +417,6 @@ def draw(self, renderer) -> None: maxp = 0.0 - if repeat is None: - repeat = ( - 1, - 1, - 1, - ) - dx, dy, dz = icell[0], icell[1], icell[2] rep_x: int | tuple[int, int] rep_y: int | tuple[int, int] @@ -644,11 +647,11 @@ def bz2d_segments( return segments_x, segments_y -def twocell_to_bz1(cell): +def twocell_to_bz1(cell: Sequence[Sequence[float]] | NDArray[np.float_]): from ase.dft.bz import bz_vertices # 2d in x-y plane - if len(cell) > 2: # noqa: PLR2004 + if len(cell) > TWO_DIMENSION: assert all(abs(cell[2][0:2]) < 1e-6) # noqa: PLR2004 assert all(abs(cell.T[2][0:2]) < 1e-6) # noqa: PLR2004 else: @@ -689,9 +692,6 @@ def bz2d_plot( logger.debug(f"hide_ax: {hide_ax}") logger.debug(f"vectors: {vectors}") logger.debug(f"set_equal_aspect: {set_equal_aspect}") - if kwargs: - for k, v in kwargs.items(): - logger.debug(f"kwargs: kyes: {k}, value: {v}") kpoints = points bz1, icell, cell = twocell_to_bz1(cell) if ax is None: diff --git a/arpes/utilities/bz.py b/arpes/utilities/bz.py index 46f606e3..647168e9 100644 --- a/arpes/utilities/bz.py +++ b/arpes/utilities/bz.py @@ -16,6 +16,7 @@ import matplotlib.path import numpy as np +from ase.dft.kpoints import get_special_points from arpes.constants import TWO_DIMENSION diff --git a/arpes/utilities/bz_spec.py b/arpes/utilities/bz_spec.py index 1f04addb..f5da3298 100644 --- a/arpes/utilities/bz_spec.py +++ b/arpes/utilities/bz_spec.py @@ -14,6 +14,7 @@ 3. The material inner potential, if available 4. The material name """ + from __future__ import annotations import functools diff --git a/arpes/widgets.py b/arpes/widgets.py index c29e4d3d..5c1173a1 100644 --- a/arpes/widgets.py +++ b/arpes/widgets.py @@ -68,7 +68,7 @@ from collections.abc import Callable from _typeshed import Incomplete - from matplotlib.backend_bases import MouseEvent + from matplotlib.backend_bases import Event, MouseEvent from matplotlib.collections import Collection from matplotlib.colors import Colormap from numpy.typing import NDArray @@ -586,7 +586,7 @@ def on_add_new_peak(selection) -> None: data_view.attach_selector(on_select=on_add_new_peak) ctx["data"] = data - def on_copy_settings(event: MouseEvent) -> None: + def on_copy_settings(event: Event) -> None: """[TODO:summary]. [TODO:description] @@ -607,7 +607,7 @@ def on_copy_settings(event: MouseEvent) -> None: @popout def pca_explorer( - pca: DataType, + pca: xr.DataArray, # values is used data: xr.DataArray, # values is used component_dim: str = "components", initial_values: list[float] | None = None, @@ -738,7 +738,7 @@ def set_axes(component_x, component_y) -> None: ax_components.set_ylabel("$e_" + str(component_y) + "$") update_from_selection([]) - def on_change_axes(event: MouseEvent) -> None: + def on_change_axes(event: Event) -> None: """[TODO:summary]. [TODO:description] @@ -904,7 +904,7 @@ def update_kspace_plot() -> None: ) sliders[convert_dim].on_changed(update_kspace_plot) - def compute_offsets() -> dict[str, float]: + def _compute_offsets() -> dict[str, float]: """[TODO:summary]. Returns: @@ -912,7 +912,7 @@ def compute_offsets() -> dict[str, float]: """ return {k: v.val for k, v in sliders.items()} - def on_copy_settings(event: MouseEvent) -> None: + def on_copy_settings(event: Event) -> None: """[TODO:summary]. Args: @@ -921,9 +921,9 @@ def on_copy_settings(event: MouseEvent) -> None: Returns: [TODO:description] """ - pyperclip.copy(pprint.pformat(compute_offsets())) + pyperclip.copy(pprint.pformat(_compute_offsets())) - def apply_offsets(event: MouseEvent) -> None: + def apply_offsets(event: Event) -> None: """[TODO:summary]. Args: @@ -932,7 +932,7 @@ def apply_offsets(event: MouseEvent) -> None: Returns: [TODO:description] """ - for name, offset in compute_offsets().items(): + for name, offset in _compute_offsets().items(): original_data.attrs[f"{name}_offset"] = offset try: for s in original_data.S.spectra: @@ -1039,7 +1039,7 @@ def pick_gamma(data: DataType, **kwargs: Incomplete) -> DataType: dims = data.dims assert len(dims) == TWO_DIMENSION - def onclick(event: MouseEvent) -> None: + def onclick(event: Event) -> None: """[TODO:summary]. [TODO:description] diff --git a/tests/test_bz.py b/tests/test_bz.py index 86a7e535..e27bdb5a 100644 --- a/tests/test_bz.py +++ b/tests/test_bz.py @@ -1,17 +1,10 @@ """Test for utilitiy.bz.""" - import numpy as np from arpes.utilities import bz -def test_as_3d() -> None: - """Test for 'bz.as_3d'.""" - points_2d = np.array([[0, 1], [2, 3]]) - np.testing.assert_array_equal(bz.as_3d(points_2d), np.array([[0, 1, 0], [2, 3, 0]])) - - def test_as_2d() -> None: """Test for ''bz.as_2d'.""" points_3d = np.array([[0, 1, 2], [2, 3, 4]]) From e56706c7f033aaf5999ba02c94e257dded1b79d0 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Thu, 8 Feb 2024 17:03:08 +0900 Subject: [PATCH 2/7] =?UTF-8?q?=F0=9F=94=A5=20=20Remove=20as=5F3d=20in=20u?= =?UTF-8?q?tility/bz.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- arpes/utilities/bz.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/arpes/utilities/bz.py b/arpes/utilities/bz.py index 647168e9..1ba284a2 100644 --- a/arpes/utilities/bz.py +++ b/arpes/utilities/bz.py @@ -61,29 +61,28 @@ class SpecialPoint(NamedTuple): name: str negate: bool - bz_coord: NDArray[np.float_] | list[float] | tuple[float, ...] + bz_coord: NDArray[np.float_] | Sequence[float] | tuple[float, float, float] -def as_3d(points_2d: ArrayLike, *, padding: bool = False) -> NDArray[np.float_]: - """Takes a 2D points list and zero pads to convert to a 3D representation. +def make_special_points(cell: Sequence[Sequence[float]] | NDArray[np.float_]) -> list[SpecialPoint]: + """Make a list of Special Points from the cell vectors. Args: - points_2d (ArrayLike): ArrayLike object that represents the 2D-lattice. - padding (bool): if True, pad a placeholder unit vector that are zero. + cell (Sequence[Sequence[float]] | NDArray[float]): Matrix of the cell (3x3). + if (2x2) matrix (0, 0, 1) is padded. Returns: - [TODO:description] + list[SpecialPoint] """ - np_points = np.array(points_2d) - assert np_points.ndim == TWO_DIMENSION - if padding: - return np.vstack( - ( - np.concatenate([np_points, np_points[:, 0][:, None] * 0], axis=1), - [0, 0, 0], - ), + cell_array = np.array(cell) + if cell_array.shape == (2, 2): + cell_array = np.array( + [[*c, 0] for c in cell_array] + [[0, 0, 1]], ) - return np.concatenate([np_points, np_points[:, 0][:, None] * 0], axis=1) + assert cell_array.shape == (3, 3), "cell must be set as 3D." + return [ + SpecialPoint(name=k, negate=False, bz_coord=v) for k, v in get_special_points(cell).items() + ] def as_2d(points_3d: ArrayLike) -> NDArray[np.float_]: From 393c511e001b50445d0baae80269a02b167ea99a Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Thu, 8 Feb 2024 17:04:32 +0900 Subject: [PATCH 3/7] =?UTF-8?q?=F0=9F=8E=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- arpes/utilities/bz.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/arpes/utilities/bz.py b/arpes/utilities/bz.py index 1ba284a2..6d3460dc 100644 --- a/arpes/utilities/bz.py +++ b/arpes/utilities/bz.py @@ -254,7 +254,7 @@ def hex_cell(a: float = 1, c: float = 1) -> list[list[float]]: Returns: [TODO:description] """ - return [[a, 0, 0], [-0.5 * a, 3**0.5 / 2 * a, 0], [0, 0, c]] + return [[a, 0, 0], [-0.5 * a, np.sqrt(3) / 2 * a, 0], [0, 0, c]] def hex_cell_2d(a: float = 1) -> list[list[float]]: @@ -264,9 +264,9 @@ def hex_cell_2d(a: float = 1) -> list[list[float]]: a: lattice constant of along a-axis. Returns: - [TODO:description] + list of list(2x2-list) that represent 2D triangular lattice. """ - return [[a, 0], [-0.5 * a, 3**0.5 / 2 * a]] + return [[a, 0], [-0.5 * a, np.sqrt(3) / 2 * a]] def flat_bz_indices_list( From f124fda7e5d9cab7d2dd525cdf6bc6c27f67ed4e Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Thu, 8 Feb 2024 18:28:07 +0900 Subject: [PATCH 4/7] =?UTF-8?q?=E2=8F=AA=20=20Revert=20convert=5Ffor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- arpes/utilities/conversion/kx_ky_conversion.py | 15 +++++++++++++-- arpes/utilities/conversion/kz_conversion.py | 9 ++++++++- arpes/utilities/conversion/trapezoid.py | 8 +++++++- 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/arpes/utilities/conversion/kx_ky_conversion.py b/arpes/utilities/conversion/kx_ky_conversion.py index 7bd4bd43..4231689f 100644 --- a/arpes/utilities/conversion/kx_ky_conversion.py +++ b/arpes/utilities/conversion/kx_ky_conversion.py @@ -212,8 +212,11 @@ def kspace_to_phi( self, binding_energy: NDArray[np.float_], kp: NDArray[np.float_], + *args: Incomplete, ) -> NDArray[np.float_]: """Converts from momentum back to the analyzer angular axis.""" + # Dont remove *args even if not used. + del args if self.phi is not None: return self.phi if self.is_slit_vertical: @@ -247,12 +250,16 @@ def kspace_to_phi( def conversion_for(self, dim: str) -> Callable[[NDArray[np.float_]], NDArray[np.float_]]: """Looks up the appropriate momentum-to-angle conversion routine by dimension name.""" + + def _with_identity(*args: NDArray[np.float_]) -> NDArray[np.float_]: + return self.identity_transform(dim, *args) + return { "eV": self.kspace_to_BE, "phi": self.kspace_to_phi, }.get( dim, - self.identity_transform, + _with_identity, ) @@ -377,13 +384,17 @@ def compute_k_tot(self, binding_energy: NDArray[np.float_]) -> None: def conversion_for(self, dim: str) -> Callable[[NDArray[np.float_]], NDArray[np.float_]]: """Looks up the appropriate momentum-to-angle conversion routine by dimension name.""" + + def _with_identity(*args: NDArray[np.float_]) -> NDArray[np.float_]: + return self.identity_transform(dim, *args) + return { "eV": self.kspace_to_BE, "phi": self.kspace_to_phi, "theta": self.kspace_to_perp_angle, "psi": self.kspace_to_perp_angle, "beta": self.kspace_to_perp_angle, - }.get(dim, self.identity_transform) + }.get(dim, _with_identity) @property def needs_rotation(self) -> bool: diff --git a/arpes/utilities/conversion/kz_conversion.py b/arpes/utilities/conversion/kz_conversion.py index 455e1163..810e2c2d 100644 --- a/arpes/utilities/conversion/kz_conversion.py +++ b/arpes/utilities/conversion/kz_conversion.py @@ -194,8 +194,15 @@ def kspace_to_phi( def conversion_for(self, dim: str) -> Callable[[NDArray[np.float_]], NDArray[np.float_]]: """Looks up the appropriate momentum-to-angle conversion routine by dimension name.""" + + def _with_identity(*args: NDArray[np.float_]) -> NDArray[np.float_]: + return self.identity_transform(dim, *args) + return { "eV": self.kspace_to_BE, "hv": self.kspace_to_hv, "phi": self.kspace_to_phi, - }.get(dim, self.identity_transform) + }.get( + dim, + _with_identity, + ) diff --git a/arpes/utilities/conversion/trapezoid.py b/arpes/utilities/conversion/trapezoid.py index 395c026b..e9dee486 100644 --- a/arpes/utilities/conversion/trapezoid.py +++ b/arpes/utilities/conversion/trapezoid.py @@ -139,9 +139,15 @@ def get_coordinates(self, *args: Incomplete, **kwargs: Incomplete) -> Indexes: return self.arr.indexes def conversion_for(self, dim: str) -> Callable[..., NDArray[np.float_]]: + def _with_identity(*args: NDArray[np.float_]) -> NDArray[np.float_]: + return self.identity_transform(dim, *args) + return { "phi": self.phi_to_phi, - }.get(dim, self.identity_transform) + }.get( + dim, + _with_identity, + ) def phi_to_phi( self, From 97cbc5bc0eb6cf2faa2bc1dfbeace02eada023a9 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Thu, 8 Feb 2024 19:25:05 +0900 Subject: [PATCH 5/7] =?UTF-8?q?=F0=9F=92=AC=20=20Update=20type=20hints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- arpes/analysis/statistics.py | 11 ++++++--- arpes/plotting/qt_tool/__init__.py | 6 ++--- arpes/plotting/spin.py | 4 ++-- arpes/plotting/tof.py | 2 +- arpes/xarray_extensions.py | 38 +++++++++++++++++------------- 5 files changed, 35 insertions(+), 26 deletions(-) diff --git a/arpes/analysis/statistics.py b/arpes/analysis/statistics.py index 1511700e..a5726982 100644 --- a/arpes/analysis/statistics.py +++ b/arpes/analysis/statistics.py @@ -1,17 +1,23 @@ """Contains utilities for performing statistical operations in spectra and DataArrays.""" + from __future__ import annotations +from typing import TYPE_CHECKING + import xarray as xr from arpes.provenance import update_provenance from arpes.utilities import lift_dataarray_to_generic +if TYPE_CHECKING: + from arpes._typing import XrTypes + __all__ = ("mean_and_deviation",) @update_provenance("Calculate mean and standard deviation for observation axis") @lift_dataarray_to_generic -def mean_and_deviation(data: xr.DataArray, axis: str = "", name: str = "") -> xr.Dataset: +def mean_and_deviation(data: XrTypes, axis: str = "", name: str = "") -> xr.Dataset: """Calculates the mean and standard deviation of a DataArray along an axis. The reduced axis corresponds to individual observations of a tensor/array valued quantity. @@ -21,7 +27,7 @@ def mean_and_deviation(data: xr.DataArray, axis: str = "", name: str = "") -> xr If a name is not attached to the DataArray, it should be provided. Args: - data: The input data. + data: The input data (Both DataArray and Dataset). axis: The name of the dimension which we should perform the reduction along. name: The name of the variable which should be reduced. By default, uses `data.name`. @@ -30,7 +36,6 @@ def mean_and_deviation(data: xr.DataArray, axis: str = "", name: str = "") -> xr relevant variable in the input DataArray. (Dimension is reduced.) """ preferred_axes = ["bootstrap", "cycle", "idx"] - assert isinstance(data, xr.DataArray) name = str(data.name) if data.name == "" else name if not axis: diff --git a/arpes/plotting/qt_tool/__init__.py b/arpes/plotting/qt_tool/__init__.py index 7bef86b1..e1dbae09 100644 --- a/arpes/plotting/qt_tool/__init__.py +++ b/arpes/plotting/qt_tool/__init__.py @@ -7,7 +7,7 @@ import warnings import weakref from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger -from typing import TYPE_CHECKING, reveal_type +from typing import TYPE_CHECKING import dill import matplotlib as mpl @@ -39,7 +39,7 @@ from PySide6.QtGui import QKeyEvent from PySide6.QtWidgets import QWidget - from arpes._typing import DataType, XrTypes + from arpes._typing import XrTypes LOGLEVEL = (DEBUG, INFO)[1] logger = getLogger(__name__) @@ -539,7 +539,7 @@ def set_data(self, data: XrTypes) -> None: self._binning = [1 for _ in self.data.dims] -def _qt_tool(data: DataType, **kwargs: Incomplete) -> None: +def _qt_tool(data: XrTypes, **kwargs: Incomplete) -> None: """Starts the qt_tool using an input spectrum.""" with contextlib.suppress(TypeError): data = dill.loads(data) diff --git a/arpes/plotting/spin.py b/arpes/plotting/spin.py index e9451045..592cc19a 100644 --- a/arpes/plotting/spin.py +++ b/arpes/plotting/spin.py @@ -1,7 +1,8 @@ """Some general plotting routines for presentation of spin-ARPES data.""" + from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, reveal_type import matplotlib as mpl import matplotlib.colors @@ -154,7 +155,6 @@ def spin_polarized_spectrum( # noqa: PLR0913 """Plots a simple spin polarized spectrum using curves for the up and down components.""" if ax is None: _, ax = plt.subplots(2, 1, sharex=True) - assert isinstance(ax, Axes) if stats: spin_dr = bootstrap(lambda x: x)(spin_dr, N=100) pol = mean_and_deviation(to_intensity_polarization(spin_dr)) diff --git a/arpes/plotting/tof.py b/arpes/plotting/tof.py index f1d35290..b989743c 100644 --- a/arpes/plotting/tof.py +++ b/arpes/plotting/tof.py @@ -81,7 +81,7 @@ def plot_with_std( @save_plot_provenance def scatter_with_std( - data: xr.Dataset, + data: xr.Dataset, # data_vars is used. name_to_plot: str = "", ax: Axes | None = None, out: str | Path = "", diff --git a/arpes/xarray_extensions.py b/arpes/xarray_extensions.py index 936fb004..daadd366 100644 --- a/arpes/xarray_extensions.py +++ b/arpes/xarray_extensions.py @@ -48,7 +48,7 @@ from collections.abc import Collection, Hashable, Mapping, Sequence from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, TypeAlias, Unpack +from typing import TYPE_CHECKING, Any, Literal, Self, TypeAlias, Unpack import matplotlib.pyplot as plt import numpy as np @@ -1780,9 +1780,9 @@ def _experimentalinfo_to_dict(conditions: EXPERIMENTINFO) -> dict[str, str]: if isinstance(v, xr.DataArray): min_hv = float(v.min()) max_hv = float(v.max()) - transformed_dict[k] = ( - f" from {min_hv} to {max_hv} eV" - ) + transformed_dict[ + k + ] = f" from {min_hv} to {max_hv} eV" elif isinstance(v, float) and not np.isnan(v): transformed_dict[k] = f"{v} eV" return transformed_dict @@ -1928,8 +1928,12 @@ def _degree_to_radian(self) -> None: class ARPESDataArrayAccessor(ARPESAccessorBase): """Spectrum related accessor for `xr.DataArray`.""" + def __init__(self, xarray_obj: xr.DataArray) -> None: + """Initialize.""" + self._obj = xarray_obj + def plot( - self, + self: Self, *args: Incomplete, **kwargs: Incomplete, ) -> None: @@ -1945,14 +1949,14 @@ def plot( with plt.rc_context(rc={"text.usetex": False}): self._obj.plot(*args, **kwargs) - def show(self, *, detached: bool = False, **kwargs: Incomplete) -> None: + def show(self: Self, *, detached: bool = False, **kwargs: Incomplete) -> None: """Opens the Qt based image tool.""" from .plotting.qt_tool import qt_tool qt_tool(self._obj, detached=detached, **kwargs) def fs_plot( - self, + self: Self, pattern: str = "{}.png", **kwargs: Incomplete, ) -> Path | None | tuple[Figure, Axes]: @@ -1964,7 +1968,7 @@ def fs_plot( return labeled_fermi_surface(self._obj, **kwargs) def fermi_edge_reference_plot( - self, + self: Self, pattern: str = "{}.png", **kwargs: str | Normalize | None, ) -> Path | None: @@ -1985,7 +1989,7 @@ def fermi_edge_reference_plot( return fermi_edge_reference(self._obj, **kwargs) def _referenced_scans_for_spatial_plot( - self, + self: Self, *, use_id: bool = True, pattern: str = "{}.png", @@ -2008,7 +2012,7 @@ def _referenced_scans_for_spatial_plot( return reference_scan_spatial(self._obj, out=out) def _referenced_scans_for_map_plot( - self, + self: Self, pattern: str = "{}.png", *, use_id: bool = True, @@ -2029,7 +2033,7 @@ class HvRefScanParam(LabeledFermiSurfaceParam): bkg_subtraction: float def _referenced_scans_for_hv_map_plot( - self, + self: Self, pattern: str = "{}.png", *, use_id: bool = True, @@ -2045,7 +2049,7 @@ def _referenced_scans_for_hv_map_plot( return hv_reference_scan(self._obj, **kwargs) def _simple_spectrum_reference_plot( - self, + self: Self, *, use_id: bool = True, pattern: str = "{}.png", @@ -2059,7 +2063,7 @@ def _simple_spectrum_reference_plot( return fancy_dispersion(self._obj, **kwargs) - def cut_nan_coords(self) -> XrTypes: + def cut_nan_coords(self: Self) -> xr.DataArray: """Selects data where coordinates are not `nan`. Returns (xr.DataArray): @@ -2124,7 +2128,6 @@ def switch_energy_notation(self, nonlinear_order: int = 1) -> None: Args: nonlinear_order (int): order of the nonliniarity, default to 1 """ - assert isinstance(self._obj, xr.DataArray | xr.Dataset) if self._obj.coords["hv"].ndim == 0: if self.energy_notation == "Binding": self._obj.coords["eV"] = ( @@ -3319,7 +3322,7 @@ def __getattr__(self, item: str) -> dict: """ return getattr(self._obj.S.spectrum.S, item) - def polarization_plot(self, **kwargs: IncompleteMPL) -> Axes: + def polarization_plot(self, **kwargs: IncompleteMPL) -> list[Axes] | Path: """Creates a spin polarization plot. Returns: @@ -3448,7 +3451,7 @@ def scan_degrees_of_freedom(self) -> set[str]: """ return self.degrees_of_freedom.difference(self.spectrum_degrees_of_freedom) - def reference_plot(self, **kwargs: IncompleteMPL) -> None: + def reference_plot(self: Self, **kwargs: IncompleteMPL) -> None: """Creates reference plots for a dataset. A bit of a misnomer because this actually makes many plots. For full datasets, @@ -3486,7 +3489,7 @@ def reference_plot(self, **kwargs: IncompleteMPL) -> None: if figure_item not in self._obj.data_vars: continue name = name_normalization.get(figure_item, figure_item) - data_var = self._obj[figure_item] + data_var: xr.DataArray = self._obj[figure_item] out = f"{self.label}_{name}_spec_integrated_reference.png" scan_var_reference_plot(data_var, title=f"Reference {name}", out=out) @@ -3673,4 +3676,5 @@ def __init__(self, xarray_obj: xr.Dataset) -> None: Args: xarray_obj: The parent object which this is an accessor for """ + self._obj: xr.Dataset super().__init__(xarray_obj) From 31d0cb999cf109cf1e517676012fe93bcb8390d1 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Fri, 9 Feb 2024 13:52:12 +0900 Subject: [PATCH 6/7] =?UTF-8?q?=F0=9F=92=AC=20=20Update=20type=20hints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit mypy (819) ruff (304) --- arpes/analysis/band_analysis.py | 2 +- arpes/analysis/decomposition.py | 8 ++-- arpes/analysis/deconvolution.py | 28 ++++++------- arpes/analysis/derivative.py | 9 ++-- arpes/analysis/gap.py | 18 +++++--- arpes/analysis/general.py | 12 +++--- arpes/analysis/mask.py | 5 +-- arpes/analysis/pocket.py | 15 +++---- arpes/analysis/resolution.py | 9 ++-- arpes/analysis/sarpes.py | 5 ++- arpes/analysis/self_energy.py | 12 +++--- arpes/analysis/shirley.py | 19 +++++---- arpes/analysis/tarpes.py | 16 +++----- arpes/bootstrap.py | 7 +++- arpes/corrections/__init__.py | 8 ++-- arpes/corrections/background.py | 13 ++---- arpes/corrections/fermi_edge_corrections.py | 11 ++--- arpes/deep_learning/formatters.py | 15 ++++--- arpes/deep_learning/interpret.py | 37 ++++++++++------- arpes/deep_learning/models/regression.py | 27 +++++++----- arpes/deep_learning/transforms.py | 23 ++++++----- arpes/endstations/plugin/BL10_SARPES.py | 1 + arpes/endstations/plugin/IF_UMCS.py | 17 ++++---- arpes/endstations/plugin/MBS.py | 12 +++--- arpes/endstations/plugin/igor_plugin.py | 18 +++++++- arpes/fits/utilities.py | 7 +--- arpes/models/band.py | 1 + arpes/plotting/bz.py | 41 +++++++++++-------- arpes/plotting/dos.py | 10 +++-- arpes/plotting/dynamic_tool.py | 5 ++- arpes/plotting/fermi_surface.py | 6 +-- arpes/plotting/parameter.py | 21 ++++++---- arpes/plotting/qt_ktool/__init__.py | 8 ++-- arpes/plotting/qt_tool/__init__.py | 6 +-- arpes/plotting/spatial.py | 10 ++--- arpes/plotting/spin.py | 2 +- arpes/plotting/stack_plot.py | 10 ++--- arpes/plotting/utils.py | 13 +++--- arpes/preparation/axis_preparation.py | 7 +++- arpes/provenance.py | 8 ++-- arpes/utilities/conversion/base.py | 1 + .../conversion/bounds_calculations.py | 4 +- arpes/utilities/conversion/core.py | 18 ++++---- arpes/utilities/conversion/forward.py | 12 +++--- .../utilities/conversion/kx_ky_conversion.py | 4 +- arpes/utilities/conversion/kz_conversion.py | 2 +- .../utilities/conversion/remap_manipulator.py | 1 + arpes/utilities/selections.py | 11 ++--- arpes/widgets.py | 4 +- arpes/xarray_extensions.py | 4 +- 50 files changed, 312 insertions(+), 251 deletions(-) diff --git a/arpes/analysis/band_analysis.py b/arpes/analysis/band_analysis.py index c95759cf..bdd739f7 100644 --- a/arpes/analysis/band_analysis.py +++ b/arpes/analysis/band_analysis.py @@ -38,7 +38,7 @@ def fit_for_effective_mass( - data: XrTypes, + data: xr.DataArray, fit_kwargs: dict | None = None, ) -> float: """Fits for the effective mass in a piece of data. diff --git a/arpes/analysis/decomposition.py b/arpes/analysis/decomposition.py index 6256da74..8deba529 100644 --- a/arpes/analysis/decomposition.py +++ b/arpes/analysis/decomposition.py @@ -14,7 +14,6 @@ import xarray as xr from _typeshed import Incomplete - from arpes._typing import DataType __all__ = ( "nmf_along", "pca_along", @@ -24,7 +23,7 @@ def decomposition_along( - data: DataType, + data: xr.DataArray, axes: list[str], decomposition_cls: type[sklearn.decomposition], *, @@ -69,11 +68,12 @@ def decomposition_along( from sklearn.pipeline import make_pipeline from sklearn.preprocessing import StandardScaler + data = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) if len(axes) > 1: - flattened_data: xr.DataArray = normalize_to_spectrum(data).stack(fit_axis=axes) + flattened_data: xr.DataArray = data.stack(fit_axis=axes) stacked = True else: - flattened_data = normalize_to_spectrum(data).S.transpose_to_back(axes[0]) + flattened_data = data.S.transpose_to_back(axes[0]) stacked = False if len(flattened_data.dims) != TWO_DIMENSION: diff --git a/arpes/analysis/deconvolution.py b/arpes/analysis/deconvolution.py index af2b8066..56ad03ba 100644 --- a/arpes/analysis/deconvolution.py +++ b/arpes/analysis/deconvolution.py @@ -36,7 +36,7 @@ @update_provenance("Approximate Iterative Deconvolution") def deconvolve_ice( - data: DataType, + data: xr.DataArray, psf: NDArray[np.float_], n_iterations: int = 5, deg: int | None = None, @@ -55,8 +55,8 @@ def deconvolve_ice( Returns: The deconvoled data in the same format. """ - arr = normalize_to_spectrum(data).values - + data = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data).values + arr = data.values if deg is None: deg = n_iterations - 3 iteration_steps = list(range(1, n_iterations + 1)) @@ -73,17 +73,17 @@ def deconvolve_ice( poly = np.poly1d(coefs) deconv[t] = poly(0) - if type(data) is np.ndarray: + if isinstance(data, np.ndarray): result = deconv else: - result = normalize_to_spectrum(data).copy(deep=True) + result = data.copy(deep=True) result.values = deconv return result @update_provenance("Lucy Richardson Deconvolution") def deconvolve_rl( - data: DataType, + data: xr.DataArray, psf: xr.DataArray | None = None, n_iterations: int = 10, axis: str = "", @@ -106,7 +106,7 @@ def deconvolve_rl( Returns: The Richardson-Lucy deconvolved data. """ - arr = normalize_to_spectrum(data) + arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) if psf is None and axis != "" and sigma != 0: # if no psf is provided and we have the information to make a 1d one @@ -233,7 +233,7 @@ def wrap_progress( result = u[-1] else: # data.dims == 1 - if type(arr) is not np.ndarray: + if not isinstance(arr, np.ndarray): arr = arr.values u = [arr] for _ in range(n_iterations): @@ -241,17 +241,17 @@ def wrap_progress( u.append(u[-1] * scipy.ndimage.convolve(arr / c, np.flip(psf, 0), mode=mode)) # not yet tested to ensure flip correct for asymmetric psf # note: need to explicitly specify axis number in np.flip in lower versions of numpy - if type(data) is np.ndarray: + if isinstance(data, np.ndarray): result = u[-1].copy() else: - result = normalize_to_spectrum(data).copy(deep=True) + result = data.copy(deep=True) result.values = u[-1] with contextlib.suppress(Exception): return result.transpose(*arr.dims) @update_provenance("Make 1D-Point Spread Function") -def make_psf1d(data: DataType, dim: str, sigma: float) -> xr.DataArray: +def make_psf1d(data: xr.DataArray, dim: str, sigma: float) -> xr.DataArray: """Produces a 1-dimensional gaussian point spread function for use in deconvolve_rl. Args: @@ -262,7 +262,7 @@ def make_psf1d(data: DataType, dim: str, sigma: float) -> xr.DataArray: Returns: A one dimensional point spread array. """ - arr = normalize_to_spectrum(data) + arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) psf = arr.copy(deep=True) * 0 + 1 other_dims = list(arr.dims) other_dims.remove(dim) @@ -272,7 +272,7 @@ def make_psf1d(data: DataType, dim: str, sigma: float) -> xr.DataArray: @update_provenance("Make Point Spread Function") -def make_psf(data: DataType, sigmas: dict[str, float]) -> xr.DataArray: +def make_psf(data: xr.DataArray, sigmas: dict[str, float]) -> xr.DataArray: """Produces an n-dimensional gaussian point spread function for use in deconvolve_rl. Not yet operational. @@ -286,7 +286,7 @@ def make_psf(data: DataType, sigmas: dict[str, float]) -> xr.DataArray: """ raise NotImplementedError - arr = normalize_to_spectrum(data) + arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) dims = arr.dims psf = arr.copy(deep=True) * 0 + 1 diff --git a/arpes/analysis/derivative.py b/arpes/analysis/derivative.py index 7ed53c4f..620a0770 100644 --- a/arpes/analysis/derivative.py +++ b/arpes/analysis/derivative.py @@ -16,7 +16,6 @@ from numpy.typing import NDArray - from arpes._typing import DataType __all__ = ( "curvature2d", @@ -81,7 +80,7 @@ def _vector_diff( @update_provenance("Minimum Gradient") def minimum_gradient( - data: DataType, + data: xr.DataArray, *, smooth_fn: Callable[[xr.DataArray], xr.DataArray] | None = None, delta: DELTA = 1, @@ -99,7 +98,7 @@ def warpped_filter(arr: xr.DataArray): Returns: The gradient of the original intensity, which enhances the peak position. """ - arr = normalize_to_spectrum(data) + arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) assert isinstance(arr, xr.DataArray) smooth_ = _nothing_to_array if smooth_fn is None else smooth_fn arr = smooth_(arr) @@ -107,7 +106,7 @@ def warpped_filter(arr: xr.DataArray): @update_provenance("Gradient Modulus") -def _gradient_modulus(data: DataType, *, delta: DELTA = 1) -> xr.DataArray: +def _gradient_modulus(data: xr.DataArray, *, delta: DELTA = 1) -> xr.DataArray: """Helper function for minimum gradient. Args: @@ -117,7 +116,7 @@ def _gradient_modulus(data: DataType, *, delta: DELTA = 1) -> xr.DataArray: Returns: xr.DataArray [TODO:description] """ - spectrum = normalize_to_spectrum(data) + spectrum = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) assert isinstance(spectrum, xr.DataArray) values: NDArray[np.float_] = spectrum.values gradient_vector = np.zeros(shape=(8, *values.shape)) diff --git a/arpes/analysis/gap.py b/arpes/analysis/gap.py index f08f97f8..fa17a6e1 100644 --- a/arpes/analysis/gap.py +++ b/arpes/analysis/gap.py @@ -66,7 +66,11 @@ def determine_broadened_fermi_distribution( "vary": False, } - reference_data_array = normalize_to_spectrum(reference_data) + reference_data_array = ( + reference_data + if isinstance(reference_data, xr.DataArray) + else normalize_to_spectrum(reference_data) + ) sum_dims = list(reference_data_array.dims) sum_dims.remove("eV") @@ -183,10 +187,12 @@ def normalize_by_fermi_dirac( def _shift_energy_interpolate( - data: DataType, + data: xr.DataArray, shift: xr.DataArray | None = None, ) -> xr.DataArray: - data_arr = normalize_to_spectrum(data).S.transpose_to_front("eV") + if not isinstance(data, xr.DataArray): + data = normalize_to_spectrum(data) + data_arr = data.S.transpose_to_front("eV") new_data = data_arr.copy(deep=True) new_axis = new_data.coords["eV"] @@ -221,7 +227,7 @@ def _shift_energy_interpolate( @update_provenance("Symmetrize") def symmetrize( - data: DataType, + data: xr.DataArray, *, subpixel: bool = False, full_spectrum: bool = False, @@ -243,7 +249,9 @@ def symmetrize( Returns: The symmetrized data. """ - data = normalize_to_spectrum(data).S.transpose_to_front("eV") + if not isinstance(data, xr.DataArray): + data = normalize_to_spectrum(data) + data = data.S.transpose_to_front("eV") if subpixel or full_spectrum: data = _shift_energy_interpolate(data) diff --git a/arpes/analysis/general.py b/arpes/analysis/general.py index 3e94ff63..96127d3e 100644 --- a/arpes/analysis/general.py +++ b/arpes/analysis/general.py @@ -19,7 +19,7 @@ from .filters import gaussian_filter_arr if TYPE_CHECKING: - from arpes._typing import DataType + from arpes._typing import DataType, XrTypes __all__ = ( "normalize_by_fermi_distribution", @@ -32,7 +32,7 @@ @update_provenance("Fit Fermi Edge") def fit_fermi_edge( - data: DataType, + data: XrTypes, energy_range: slice | None = None, ) -> xr.Dataset: """Fits a Fermi edge. @@ -59,7 +59,7 @@ def fit_fermi_edge( @update_provenance("Normalized by the 1/Fermi Dirac Distribution at sample temp") def normalize_by_fermi_distribution( - data: DataType, + data: xr.DataArray, max_gain: float = 0, rigid_shift: float = 0, instrumental_broadening: float = 0, @@ -86,7 +86,7 @@ def normalize_by_fermi_distribution( Returns: Normalized DataArray """ - data_array = normalize_to_spectrum(data) + data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) if not total_broadening: distrib = fermi_distribution( data_array.coords["eV"].values - rigid_shift, @@ -113,7 +113,7 @@ def normalize_by_fermi_distribution( @update_provenance("Symmetrize about axis") def symmetrize_axis( - data: DataType, + data: XrTypes, axis_name: str, flip_axes: list[str] | None = None, ) -> xr.DataArray: @@ -153,7 +153,7 @@ def symmetrize_axis( @update_provenance("Condensed array") -def condense(data: xr.DataArray) -> xr.DataArray: +def condense(data: DataType) -> DataType: """Clips the data so that only regions where there is substantial weight are included. In practice this usually means selecting along the ``eV`` axis, although other selections diff --git a/arpes/analysis/mask.py b/arpes/analysis/mask.py index 0e7036d0..0e357b1a 100644 --- a/arpes/analysis/mask.py +++ b/arpes/analysis/mask.py @@ -15,7 +15,6 @@ from _typeshed import Incomplete from numpy.typing import NDArray - from arpes._typing import DataType __all__ = ( "polys_to_mask", @@ -135,7 +134,7 @@ def apply_mask_to_coords( @update_provenance("Apply boolean mask to data") def apply_mask( - data: DataType, + data: xr.DataArray, mask: dict[str, Incomplete], replace: float = np.nan, radius=None, @@ -169,7 +168,7 @@ def apply_mask( Returns: Data with values masked out. """ - data_array = normalize_to_spectrum(data) + data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) fermi = mask.get("fermi") if isinstance(mask, dict): diff --git a/arpes/analysis/pocket.py b/arpes/analysis/pocket.py index 36aeb142..4c0ded11 100644 --- a/arpes/analysis/pocket.py +++ b/arpes/analysis/pocket.py @@ -84,7 +84,7 @@ def pocket_parameters( @update_provenance("Collect EDCs projected at an angle from pocket") def radial_edcs_along_pocket( - data: XrTypes, + data: xr.DataArray, angle: float, radii: tuple[float, float] = (0.0, 5.0), n_points: int = 0, @@ -113,7 +113,7 @@ def radial_edcs_along_pocket( A 2D array which has an angular coordinate around the pocket center. """ inner_radius, outer_radius = radii - data_array = normalize_to_spectrum(data) + data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) fermi_surface_dims = list(data_array.dims) assert "eV" in fermi_surface_dims @@ -158,7 +158,7 @@ def radial_edcs_along_pocket( def curves_along_pocket( - data: XrTypes, + data: xr.DataArray, n_points: int = 0, inner_radius: float = 0.0, outer_radius: float = 5.0, @@ -185,7 +185,7 @@ def curves_along_pocket( A tuple of two lists. The first list contains the slices and the second the coordinates of each slice around the pocket center. """ - data_array = normalize_to_spectrum(data) + data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) assert isinstance(data_array, xr.DataArray) fermi_surface_dims = list(data_array.dims) if "eV" in fermi_surface_dims: @@ -237,7 +237,7 @@ def slice_at_angle(theta: float) -> xr.DataArray: def find_kf_by_mdc( - slice_data: XrTypes, + slice_data: xr.DataArray, offset: float = 0, **kwargs: Incomplete, ) -> float: @@ -254,8 +254,9 @@ def find_kf_by_mdc( Returns: The fitting Fermi momentum. """ - if isinstance(slice_data, xr.Dataset): - slice_arr = normalize_to_spectrum(slice_data) + slice_arr = ( + slice_data if isinstance(slice_data, xr.DataArray) else normalize_to_spectrum(slice_data) + ) assert isinstance(slice_arr, xr.DataArray) diff --git a/arpes/analysis/resolution.py b/arpes/analysis/resolution.py index e30bd09d..b89151b8 100644 --- a/arpes/analysis/resolution.py +++ b/arpes/analysis/resolution.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any import numpy as np +import xarray as xr # all resolutions are given by (photon energy, entrance slit, exit slit size) from arpes.constants import K_BOLTZMANN_MEV_KELVIN @@ -138,7 +139,7 @@ def analyzer_resolution( } -def analyzer_resolution_estimate(data: DataType, *, meV: bool = False) -> float: # noqa: N803 +def analyzer_resolution_estimate(data: xr.DataArray, *, meV: bool = False) -> float: # noqa: N803 """Estimates the energy resolution of the analyzer. For hemispherical analyzers, this can be determined by the slit @@ -151,7 +152,7 @@ def analyzer_resolution_estimate(data: DataType, *, meV: bool = False) -> float: Returns: The resolution in eV units. """ - data_array = normalize_to_spectrum(data) + data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) endstation = data_array.S.endstation spectrometer_info = SPECTROMETER_INFORMATION[endstation] @@ -209,8 +210,8 @@ def energy_resolution_from_beamline_slit( return by_area[low] + (by_area[high] - by_area[low]) * (slit_area - low) / (high - low) -def beamline_resolution_estimate(data: DataType, *, meV: bool = False) -> None: # noqa: N803 - data_array = normalize_to_spectrum(data) +def beamline_resolution_estimate(data: xr.DataArray, *, meV: bool = False) -> None: # noqa: N803 + data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) resolution_table: dict[ str, dict[tuple[float, tuple[float, float]], float], diff --git a/arpes/analysis/sarpes.py b/arpes/analysis/sarpes.py index 3f2fba6d..58804d10 100644 --- a/arpes/analysis/sarpes.py +++ b/arpes/analysis/sarpes.py @@ -1,4 +1,5 @@ """Contains very basic spin-ARPES analysis routines.""" + from __future__ import annotations from typing import TYPE_CHECKING @@ -62,7 +63,7 @@ def to_up_down(data: DataType) -> xr.Dataset: @update_provenance("Convert up-down spin channels to polarization") def to_intensity_polarization( - data: DataType, + data: xr.Dataset, *, perform_sherman_correction: bool = False, ) -> xr.Dataset: @@ -79,7 +80,7 @@ def to_intensity_polarization( Returns: The data after conversion to intensity-polarization representation. """ - data_set = normalize_to_dataset(data) + data_set = data if isinstance(data, xr.Dataset) else normalize_to_dataset(data) assert isinstance(data_set, xr.Dataset) assert "up" in data_set.data_vars diff --git a/arpes/analysis/self_energy.py b/arpes/analysis/self_energy.py index 4af5817d..ad9cf030 100644 --- a/arpes/analysis/self_energy.py +++ b/arpes/analysis/self_energy.py @@ -28,7 +28,7 @@ def get_peak_parameter( - data: xr.DataArray, # values is used + data: xr.DataArray, parameter_name: str, ) -> xr.DataArray: """Extracts a parameter from a potentially prefixed peak-like component. @@ -175,7 +175,7 @@ def quasiparticle_mean_free_path( def to_self_energy( - dispersion: xr.Dataset, + dispersion: xr.DataArray, bare_band: BareBandType | None = None, fermi_velocity: float = 0, *, @@ -220,7 +220,7 @@ def to_self_energy( dispersion = dispersion.results from_mdcs = "eV" in dispersion.dims # if eV is in the dimensions, then we fitted MDCs - estimated_bare_band = estimate_bare_band(dispersion, bare_band) + estimated_bare_band = estimate_bare_band(dispersion, bare_band_specification="ransac_linear") if not fermi_velocity: fermi_velocity = local_fermi_velocity(estimated_bare_band) @@ -272,8 +272,8 @@ def fit_for_self_energy( **kwargs, ) else: - possible_mometum_dims = ("phi", "theta", "psi", "beta", "kp", "kx", "ky", "kz") - mom_axes = set(data.dims).intersection(possible_mometum_dims) + possible_mometum_dims = {"phi", "theta", "psi", "beta", "kp", "kx", "ky", "kz"} + mom_axes = {str(dim) for dim in data.dims}.intersection(possible_mometum_dims) if len(mom_axes) > 1: msg = "Too many possible momentum dimensions, please clarify." @@ -285,4 +285,4 @@ def fit_for_self_energy( **kwargs, ) - return to_self_energy(fit_results, bare_band=bare_band) + return to_self_energy(fit_results.results, bare_band=bare_band) diff --git a/arpes/analysis/shirley.py b/arpes/analysis/shirley.py index 6de0c370..8c4449f3 100644 --- a/arpes/analysis/shirley.py +++ b/arpes/analysis/shirley.py @@ -1,4 +1,5 @@ """Contains routines for calculating and removing the classic Shirley background.""" + from __future__ import annotations import warnings @@ -13,7 +14,6 @@ if TYPE_CHECKING: from numpy.typing import NDArray - from arpes._typing import DataType __all__ = ( "calculate_shirley_background", @@ -23,7 +23,7 @@ @update_provenance("Remove Shirley background") -def remove_shirley_background(xps: DataType, **kwargs: float) -> xr.DataArray: +def remove_shirley_background(xps: xr.DataArray, **kwargs: float) -> xr.DataArray: """Calculates and removes a Shirley background from a spectrum. Only the background corrected spectrum is retrieved. @@ -35,7 +35,7 @@ def remove_shirley_background(xps: DataType, **kwargs: float) -> xr.DataArray: Returns: The the input array with a Shirley background subtracted. """ - xps_array = normalize_to_spectrum(xps) + xps_array = xps if isinstance(xps, xr.DataArray) else normalize_to_spectrum(xps) return xps_array - calculate_shirley_background(xps_array, **kwargs) @@ -89,7 +89,7 @@ def _calculate_shirley_background_full_range( @update_provenance("Calculate full range Shirley background") def calculate_shirley_background_full_range( - xps: DataType, + xps: xr.DataArray, eps: float = 1e-7, max_iters: int = 50, n_samples: int = 5, @@ -121,8 +121,11 @@ def calculate_shirley_background_full_range( Returns: A monotonic Shirley backgruond over the entire energy range. """ - xps_array = normalize_to_spectrum(xps).copy(deep=True) - assert isinstance(xps_array, xr.DataArray) + xps_array = ( + xps.copy(deep=True) + if isinstance(xps, xr.DataArray) + else normalize_to_spectrum(xps).copy(deep=True) + ) core_dims = [d for d in xps_array.dims if d != "eV"] return xr.apply_ufunc( @@ -140,7 +143,7 @@ def calculate_shirley_background_full_range( @update_provenance("Calculate limited range Shirley background") def calculate_shirley_background( - xps: DataType, + xps: xr.DataArray, energy_range: slice | None = None, eps: float = 1e-7, max_iters: int = 50, @@ -166,7 +169,7 @@ def calculate_shirley_background( if energy_range is None: energy_range = slice(None, None) - xps_array = normalize_to_spectrum(xps) + xps_array = xps if isinstance(xps, xr.DataArray) else normalize_to_spectrum(xps) assert isinstance(xps_array, xr.DataArray) xps_for_calc = xps_array.sel(eV=energy_range) diff --git a/arpes/analysis/tarpes.py b/arpes/analysis/tarpes.py index 0080b285..f0fa404e 100644 --- a/arpes/analysis/tarpes.py +++ b/arpes/analysis/tarpes.py @@ -3,7 +3,6 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING import numpy as np import xarray as xr @@ -12,15 +11,12 @@ from arpes.provenance import update_provenance from arpes.utilities import normalize_to_spectrum -if TYPE_CHECKING: - from arpes._typing import DataType, XrTypes - __all__ = ("find_t0", "relative_change", "normalized_relative_change") @update_provenance("Normalized subtraction map") def normalized_relative_change( - data: XrTypes, + data: xr.DataArray, t0: float | None = None, buffer: float = 0.3, *, @@ -41,7 +37,7 @@ def normalized_relative_change( Returns: The normalized data. """ - spectrum = normalize_to_spectrum(data) + spectrum = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) assert isinstance(spectrum, xr.DataArray) if normalize_delay: spectrum = normalize_dim(spectrum, "delay") @@ -55,7 +51,7 @@ def normalized_relative_change( @update_provenance("Created simple subtraction map") def relative_change( - data: xr.Dataset | xr.DataArray, + data: xr.DataArray, t0: float | None = None, buffer: float = 0.3, *, @@ -73,7 +69,7 @@ def relative_change( Returns: The normalized data. """ - spectrum = normalize_to_spectrum(data) + spectrum = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) assert isinstance(spectrum, xr.DataArray) if normalize_delay: spectrum = normalize_dim(spectrum, "delay") @@ -90,7 +86,7 @@ def relative_change( return spectrum - before_t0.mean("delay") -def find_t0(data: DataType, e_bound: float = 0.02) -> float: +def find_t0(data: xr.DataArray, e_bound: float = 0.02) -> float: """Finds the effective t0 by fitting excited carriers. Args: @@ -105,7 +101,7 @@ def find_t0(data: DataType, e_bound: float = 0.02) -> float: "This function will be deprecated, because it's not so physically correct.", stacklevel=2, ) - spectrum = normalize_to_spectrum(data) + spectrum = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) assert isinstance(spectrum, xr.DataArray) assert "delay" in spectrum.dims assert "eV" in spectrum.dims diff --git a/arpes/bootstrap.py b/arpes/bootstrap.py index 05051861..ffa00be9 100644 --- a/arpes/bootstrap.py +++ b/arpes/bootstrap.py @@ -64,7 +64,10 @@ @update_provenance("Estimate prior") -def estimate_prior_adjustment(data: DataType, region: dict[str, Any] | str | None = None) -> float: +def estimate_prior_adjustment( + data: xr.DataArray, + region: dict[str, Any] | str | None = None, +) -> float: r"""Estimates distribution generating the intensity histogram of pixels in a spectrum. In a perfectly linear, single-electron @@ -83,7 +86,7 @@ def estimate_prior_adjustment(data: DataType, region: dict[str, Any] | str | Non Returns: sigma / mu, the adjustment factor for the Poisson distribution """ - data_array = normalize_to_spectrum(data) + data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) if region is None: region = "copper_prior" diff --git a/arpes/corrections/__init__.py b/arpes/corrections/__init__.py index dba41204..e4a01160 100644 --- a/arpes/corrections/__init__.py +++ b/arpes/corrections/__init__.py @@ -47,16 +47,16 @@ def __hash__(self): return hash(frozenset(self.items())) -def reference_key(data: XrTypes) -> HashableDict: +def reference_key(data: xr.DataArray) -> HashableDict: """Calculates a key/hash for data determining reference/correction equality.""" - data_array = normalize_to_dataset(data) + data_array = data if isinstance(data, xr.DataArray) else normalize_to_dataset(data) assert isinstance(data_array, xr.DataArray) return HashableDict(data_array.S.reference_settings) -def correction_from_reference_set(data: XrTypes, reference_set): +def correction_from_reference_set(data: xr.DataArray, reference_set): """Determines which correction to use from a set of references.""" - data_array = normalize_to_dataset(data) + data_array = data if isinstance(data, xr.DataArray) else normalize_to_dataset(data) correction = None for k, corr in reference_set.items(): if deep_equals(dict(reference_key(data_array)), dict(k)): diff --git a/arpes/corrections/background.py b/arpes/corrections/background.py index 44d1aec0..2991c18e 100644 --- a/arpes/corrections/background.py +++ b/arpes/corrections/background.py @@ -2,23 +2,18 @@ from __future__ import annotations -from typing import TYPE_CHECKING - import numpy as np import xarray as xr from arpes.provenance import update_provenance from arpes.utilities import normalize_to_spectrum -if TYPE_CHECKING: - from arpes._typing import DataType - __all__ = ("remove_incoherent_background",) @update_provenance("Remove incoherent background from above Fermi level") def remove_incoherent_background( - data: DataType, + data: xr.DataArray, *, set_zero: bool = True, ) -> xr.DataArray: @@ -30,21 +25,19 @@ def remove_incoherent_background( pulses). Args: - data (DataType): input ARPES data + data (XrTypes): input ARPES data set_zero (bool): set zero if the negative value is obtained after background subtraction. Returns: Data with a background subtracted. """ - data_array = normalize_to_spectrum(data) - assert isinstance(data_array, xr.DataArray) + data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) approximate_fermi_energy_level = data_array.S.find_spectrum_energy_edges().max() background = data_array.sel(eV=slice(approximate_fermi_energy_level + 0.1, None)) density = background.sum("eV") / (np.logical_not(np.isnan(background)) * 1).sum("eV") new = data_array - density - assert isinstance(new, xr.DataArray) if set_zero: new.values[new.values < 0] = 0 diff --git a/arpes/corrections/fermi_edge_corrections.py b/arpes/corrections/fermi_edge_corrections.py index 81dda952..034cf9b8 100644 --- a/arpes/corrections/fermi_edge_corrections.py +++ b/arpes/corrections/fermi_edge_corrections.py @@ -18,7 +18,6 @@ from _typeshed import Incomplete - def _exclude_from_set(excluded): def exclude(_): return list(set(_).difference(excluded)) @@ -123,9 +122,7 @@ def apply_direct_fermi_edge_correction( provenance_context: PROVENANCE = { "what": "Shifted Fermi edge to align at 0 along hv axis", "by": "apply_photon_energy_fermi_edge_correction", - "correction": list( - correction.values if isinstance(correction, xr.DataArray) else correction, - ), + "correction": correction, # TODO: NEED check } provenance(corrected_arr, arr, provenance_context) @@ -140,7 +137,7 @@ def build_direct_fermi_edge_correction( along: str = "phi", *, plot: bool = False, -) -> xr.DataArray: +) -> xr.Dataset: """Builds a direct fermi edge correction stencil. This means that fits are performed at each value of the 'phi' coordinate @@ -150,7 +147,7 @@ def build_direct_fermi_edge_correction( Args: arr (xr.DataArray) : input DataArray - energy_range (slice): Energy range, which is used in xr.DataArray.sel(). + energy_range (slice): Energy range, which is used in xr.DataArray.sel(). defautl (-0.1, 0.1) plot (bool): if True, show the plot along (str): axis for non energy axis @@ -167,7 +164,7 @@ def build_direct_fermi_edge_correction( def sieve(_, v) -> bool: return v.item().params["center"].stderr < 0.001 # noqa: PLR2004 - corrections: xr.DataArray = edge_fit.G.filter_coord(along, sieve).G.map( + corrections = edge_fit.G.filter_coord(along, sieve).G.map( lambda x: x.params["center"].value, ) diff --git a/arpes/deep_learning/formatters.py b/arpes/deep_learning/formatters.py index ebaa6f00..d441dabf 100644 --- a/arpes/deep_learning/formatters.py +++ b/arpes/deep_learning/formatters.py @@ -1,7 +1,12 @@ """Provides plotting formatters for different kinds of data and targets.""" + from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from _typeshed import Incomplete + from matplotlib.axes import Axes __all__ = [ "SpectrumFormatter", @@ -12,7 +17,7 @@ class SpectrumFormatter: """Knows how to plot an ARPES spectrum onto an interpretation plot.""" - def show(self, data, ax=None) -> None: + def show(self, data: Incomplete, ax: Axes) -> None: """Just imshow the data for now with no other decoration.""" spectrum, row = data ax.imshow(spectrum, origin="lower") @@ -21,15 +26,15 @@ def show(self, data, ax=None) -> None: class FloatTitleFormatter: """Plots a floating point target as a title annotation onto a plot for its parent item.""" - context: dict[str, Any] = None + context: dict[str, Any] title_formatter: str = r"{label}={data:.3f}" @property - def computed_context(self) -> dict[str, Any]: + def computed_context(self) -> dict[str, str | bool]: """Annotate whether this is a ground truth or predicted value.""" return {"label": "True" if self.context.get("is_ground_truth", False) else "Pred"} - def show(self, data, ax=None) -> None: + def show(self, data: Incomplete, ax: Axes) -> None: """Sets the title for the parent data axis to be the formatted float value.""" title = ax.get_title() context = { diff --git a/arpes/deep_learning/interpret.py b/arpes/deep_learning/interpret.py index 9568ea95..42882654 100644 --- a/arpes/deep_learning/interpret.py +++ b/arpes/deep_learning/interpret.py @@ -3,6 +3,7 @@ This borrows ideas heavily from fastai which provides interpreter classes for different kinds of models. """ + from __future__ import annotations import math @@ -13,12 +14,13 @@ import numpy as np import torch import tqdm -from torch.utils.data.dataset import Subset +from torch.utils.data.dataset import Dataset, Subset if TYPE_CHECKING: import pytorch_lightning as pl + from _typeshed import Incomplete + from matplotlib.axes import Axes from torch.utils.data import DataLoader - __all__ = [ "Interpretation", "InterpretationItem", @@ -36,7 +38,7 @@ class InterpretationItem: parent_dataloader: DataLoader @property - def dataset(self): + def dataset(self) -> Dataset: """Fetches the original dataset used to train and containing this item. We need to unwrap the dataset in case we are actually dealing @@ -49,14 +51,20 @@ def dataset(self): dataset. """ dset = self.parent_dataloader.dataset - if isinstance(dset, Subset): dset = dset.dataset assert dset.is_indexed is True return dset - def show(self, input_formatter, target_formatter, ax=None, pullback=True) -> None: + def show( + self, + input_formatter: Incomplete, + target_formatter: Incomplete, + ax: Axes | None = None, + *, + pullback: bool = True, + ) -> None: """Plots item onto the provided axes. See also the `show` method of `Interpretation`.""" if ax is None: _, ax = plt.subplots() @@ -87,7 +95,7 @@ def show(self, input_formatter, target_formatter, ax=None, pullback=True) -> Non ) target_formatter.show(predicted, ax) - def decodes_target(self, value: Any) -> Any: + def decodes_target(self, value: Incomplete) -> Incomplete: """Pulls the predicted target backwards through the transformation stack. Pullback continues until an irreversible transform is met in order @@ -126,7 +134,7 @@ def items(self) -> list[InterpretationItem]: return self.val_item_lists[self.val_index] - def top_losses(self, ascending=False) -> list[InterpretationItem]: + def top_losses(self, *, ascending: bool = False) -> list[InterpretationItem]: """Orders the items by loss.""" def key(item): @@ -136,10 +144,10 @@ def key(item): def show( self, - n_items: int | tuple[int, int] | None = 9, + n_items: int | tuple[int, int] = 9, items: list[InterpretationItem] | None = None, - input_formatter=None, - target_formatter=None, + input_formatter: Incomplete = None, + target_formatter: Incomplete = None, ) -> None: """Plots a subset of the interpreted items. @@ -152,7 +160,7 @@ def show( layout = None if items is None: - if isinstance(n_items, tuple | list): + if isinstance(n_items, tuple): layout = n_items else: n_rows = int(math.ceil(n_items**0.5)) @@ -164,6 +172,7 @@ def show( n_rows = int(math.ceil(n_items**0.5)) layout = (n_rows, n_rows) + assert isinstance(n_items, int) _, axes = plt.subplots(*layout, figsize=(layout[0] * 3, layout[1] * 4)) items_with_nones = list(items) + [None] * (np.prod(layout) - n_items) @@ -176,7 +185,7 @@ def show( plt.tight_layout() @classmethod - def from_trainer(cls, trainer: pl.Trainer): + def from_trainer(cls: type[Incomplete], trainer: pl.Trainer) -> list[InterpretationItem]: """Builds an interpreter from an instance of a `pytorch_lightning.Trainer`.""" return cls(trainer.model, trainer.train_dataloader, trainer.val_dataloaders) @@ -206,15 +215,15 @@ def dataloader_to_item_list(self, dataloader: DataLoader) -> list[Interpretation InterpretationItem( torch.squeeze(yi), torch.squeeze(yi_hat), - torch.squeeze(loss), int(index), + torch.squeeze(loss), dataloader, ), ) return items - def __post_init__(self): + def __post_init__(self) -> None: """Populates train_items and val_item_lists. This is done by iterating through the dataloaders and pushing data through the models. diff --git a/arpes/deep_learning/models/regression.py b/arpes/deep_learning/models/regression.py index 7a6b2dbe..8f9cf43d 100644 --- a/arpes/deep_learning/models/regression.py +++ b/arpes/deep_learning/models/regression.py @@ -1,9 +1,15 @@ """Very simple regression baselines.""" + from __future__ import annotations +from typing import TYPE_CHECKING + import pytorch_lightning as pl -from torch import nn, optim -from torch.nn import functional +from torch import Tensor, nn, optim +from torch.nn import Linear, functional + +if TYPE_CHECKING: + from _typeshed import Incomplete __all__ = ["BaselineRegression", "LinearRegression"] @@ -20,24 +26,25 @@ def __init__(self) -> None: self.linear = nn.Linear(self.input_dimensions, self.output_dimensions) self.criterion = functional.mse_loss - def forward(self, x): + def forward(self, x: Incomplete) -> Linear: """Calculate the model output for the minibatch `x`.""" flat_x = x.view(x.size(0), -1) return self.linear(flat_x) - def training_step(self, batch, batch_index): + def training_step(self, batch: Incomplete) -> Tensor: """Perform one training minibatch.""" x, y = batch return self.criterion(self(x), y) - def validation_step(self, batch, batch_index): + def validation_step(self, batch: Incomplete) -> Tensor: """Perform one validation minibatch and record the validation loss.""" x, y = batch loss = self.criterion(self(x), y) self.log("val_loss", loss) + return loss - def configure_optimizers(self): + def configure_optimizers(self) -> optim.Adam: """Use standard optimizer settings.""" return optim.Adam(self.parameters(), lr=3e-3) @@ -56,25 +63,25 @@ def __init__(self) -> None: self.l3 = nn.Linear(128, self.output_dimensions) self.criterion = functional.mse_loss - def forward(self, x): + def forward(self, x: Incomplete) -> Linear: """Calculate the model output for the minibatch `x`.""" flat_x = x.view(x.size(0), -1) h1 = functional.relu(self.l1(flat_x)) h2 = functional.relu(self.l2(h1)) return self.l3(h2) - def training_step(self, batch, batch_index): + def training_step(self, batch: tuple[Incomplete, Incomplete]) -> Tensor: """Perform one training minibatch.""" x, y = batch return self.criterion(self(x).squeeze(), y) - def validation_step(self, batch, batch_index): + def validation_step(self, batch: tuple[Incomplete, Incomplete]) -> Tensor: """Perform one validation minibatch and record the validation loss.""" x, y = batch loss = self.criterion(self(x).squeeze(), y) self.log("val_loss", loss) return loss - def configure_optimizers(self): + def configure_optimizers(self) -> optim.Adam: """Use standard optimizer settings.""" return optim.Adam(self.parameters(), lr=3e-3) diff --git a/arpes/deep_learning/transforms.py b/arpes/deep_learning/transforms.py index 7a48ba72..c96313d1 100644 --- a/arpes/deep_learning/transforms.py +++ b/arpes/deep_learning/transforms.py @@ -1,11 +1,12 @@ """Implements transform pipelines for pytorch_lightning with basic inverse transform.""" + from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import Field, dataclass, field from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from collections.abc import Callable + from _typeshed import Incomplete __all__ = ["ComposeBoth", "ReversibleLambda", "Identity"] @@ -13,13 +14,13 @@ class Identity: """Represents a reversible identity transform.""" - def encodes(self, x): + def encodes(self, x: Incomplete) -> Incomplete: return x - def __call__(self, x): + def __call__(self, x: Incomplete) -> Incomplete: return x - def decodes(self, x): + def decodes(self, x: Incomplete) -> Incomplete: return x def __repr__(self) -> str: @@ -33,10 +34,10 @@ def __repr__(self) -> str: class ReversibleLambda: """A reversible anonymous function, so long as the caller supplies an inverse.""" - encodes: Callable = field(repr=False) - decodes: Callable = field(default=lambda x: x, repr=False) + encodes: Field = field(repr=False) + decodes: Field = field(default=lambda x: x, repr=False) - def __call__(self, value): + def __call__(self, value: Incomplete) -> Field[Incomplete]: """Apply the inner lambda to the data in forward pass.""" return self.encodes(value) @@ -47,7 +48,7 @@ class ComposeBoth: transforms: list[Any] - def __post_init__(self): + def __post_init__(self) -> None: """Replace missing transforms with identities.""" safe_transforms = [] for t in self.transforms: @@ -60,7 +61,7 @@ def __post_init__(self): self.original_transforms = self.transforms self.transforms = safe_transforms - def __call__(self, x, y): + def __call__(self, x: Incomplete, y: Incomplete) -> Incomplete: """If this transform has separate data and target functions, apply separately. Otherwise, we apply the single transform to both the data and the target. @@ -74,7 +75,7 @@ def __call__(self, x, y): return x, y - def decodes_target(self, y): + def decodes_target(self, y: Incomplete) -> Incomplete: """Pull the target back in the transform stack as far as possible. This is necessary only for the predicted target because diff --git a/arpes/endstations/plugin/BL10_SARPES.py b/arpes/endstations/plugin/BL10_SARPES.py index 120bc8f8..965b5cc0 100644 --- a/arpes/endstations/plugin/BL10_SARPES.py +++ b/arpes/endstations/plugin/BL10_SARPES.py @@ -79,6 +79,7 @@ def load_single_frame( """Loads all regions for a single .pxt frame, and perform per-frame normalization.""" from arpes.load_pxt import find_ses_files_associated, read_single_pxt + del kwargs if scan_desc is None: scan_desc = {} original_data_loc = scan_desc.get("path", scan_desc.get("file")) diff --git a/arpes/endstations/plugin/IF_UMCS.py b/arpes/endstations/plugin/IF_UMCS.py index bdf7dac0..6d912e5a 100644 --- a/arpes/endstations/plugin/IF_UMCS.py +++ b/arpes/endstations/plugin/IF_UMCS.py @@ -1,4 +1,5 @@ """Implements data loading for the IF UMCS Lublin ARPES group.""" + from __future__ import annotations from pathlib import Path @@ -22,7 +23,7 @@ __all__ = ("IF_UMCS",) -class IF_UMCS(HemisphericalEndstation, SingleFileEndstation): +class IF_UMCS(HemisphericalEndstation, SingleFileEndstation): # noqa: N801 """Implements loading xy text files from the Specs Prodigy software.""" PRINCIPAL_NAME = "IF_UMCS" @@ -47,10 +48,10 @@ class IF_UMCS(HemisphericalEndstation, SingleFileEndstation): } def load_single_frame( - self, - frame_path: str | Path = "", - scan_desc: SCANDESC | None = None, - **kwargs: str | float, + self, + frame_path: str | Path = "", + scan_desc: SCANDESC | None = None, + **kwargs: str | float, ) -> xr.Dataset: """Load single xy file.""" if scan_desc is None: @@ -65,9 +66,9 @@ def load_single_frame( raise RuntimeError(msg) def postprocess_final( - self, - data: xr.Dataset, - scan_desc: SCANDESC | None = None, + self, + data: xr.Dataset, + scan_desc: SCANDESC | None = None, ) -> xr.Dataset: """Add missing parameters.""" if scan_desc is None: diff --git a/arpes/endstations/plugin/MBS.py b/arpes/endstations/plugin/MBS.py index 83fd026c..ad135178 100644 --- a/arpes/endstations/plugin/MBS.py +++ b/arpes/endstations/plugin/MBS.py @@ -121,17 +121,17 @@ def load_single_frame( with Path(frame_path).open() as f: lines = f.readlines() - lines = [_.strip() for _ in lines] + lines = [line.strip() for line in lines] data_index = lines.index("DATA:") header = lines[:data_index] data = lines[data_index + 1 :] data_array = np.array([[float(f) for f in d] for d in [d.split() for d in data]]) del data - header = [h.split("\t") for h in header] - header = [h for h in header if len(h) == 2] - alt = [h for h in header if len(h) == 1] - header.append(["alt", str(alt)]) - attrs = clean_keys(dict(header)) + headers = [h.split("\t") for h in header] + headers = [h for h in headers if len(h) == len(("item", "value"))] + alt = [h for h in headers if len(h) == len(("only_item",))] + headers.append(["alt", str(alt)]) + attrs = clean_keys(dict(headers)) eV_axis = np.linspace( float(attrs["start_k_e_"]), diff --git a/arpes/endstations/plugin/igor_plugin.py b/arpes/endstations/plugin/igor_plugin.py index 859b6de6..824e7cc8 100644 --- a/arpes/endstations/plugin/igor_plugin.py +++ b/arpes/endstations/plugin/igor_plugin.py @@ -3,8 +3,10 @@ This does not load data according to the PyARPES data model, so you should ideally use a specific data loader where it is available. """ + from __future__ import annotations +from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger from typing import TYPE_CHECKING, ClassVar import xarray as xr @@ -27,6 +29,19 @@ __all__ = ("IgorEndstation",) +LOGLEVELS = (DEBUG, INFO) +LOGLEVEL = LOGLEVELS[1] +logger = getLogger(__name__) +fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s" +formatter = Formatter(fmt) +handler = StreamHandler() +handler.setLevel(LOGLEVEL) +logger.setLevel(LOGLEVEL) +handler.setFormatter(formatter) +logger.addHandler(handler) +logger.propagate = False + + class IgorEndstation(SingleFileEndstation): """A generic file loader for PXT files. @@ -68,7 +83,8 @@ def load_single_frame( **kwargs: Incomplete, ) -> xr.Dataset: """Igor .pxt and .ibws are single files so we just read the one passed here.""" - print(frame_path, scan_desc) + del kwargs + logger.info(f"frame_path: {frame_path}, scan_desc: {scan_desc}") pxt_data = read_single_pxt(frame_path) return xr.Dataset({"spectrum": pxt_data}, attrs=pxt_data.attrs) diff --git a/arpes/fits/utilities.py b/arpes/fits/utilities.py index 178968fe..94744797 100644 --- a/arpes/fits/utilities.py +++ b/arpes/fits/utilities.py @@ -33,7 +33,6 @@ import lmfit - from arpes._typing import XrTypes __all__ = ("broadcast_model", "result_to_hints") @@ -121,7 +120,7 @@ def read_token(token: str) -> str | float: @update_provenance("Broadcast a curve fit along several dimensions") def broadcast_model( # noqa: PLR0913 model_cls: type[lmfit.Model] | Sequence[type[lmfit.Model]] | str, - data: XrTypes, + data: xr.DataArray, broadcast_dims: str | list[str], params: dict | None = None, weights: xr.DataArray | None = None, @@ -166,9 +165,7 @@ def broadcast_model( # noqa: PLR0913 broadcast_dims = [broadcast_dims] logger.debug("Normalizing to spectrum") - data_array = normalize_to_spectrum(data) - del data - assert isinstance(data_array, xr.DataArray) + data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) cs = {} for dim in broadcast_dims: cs[dim] = data_array.coords[dim] diff --git a/arpes/models/band.py b/arpes/models/band.py index 27d737a4..1ec829c1 100644 --- a/arpes/models/band.py +++ b/arpes/models/band.py @@ -1,4 +1,5 @@ """Rudimentary band analysis code.""" + from __future__ import annotations from typing import TYPE_CHECKING diff --git a/arpes/plotting/bz.py b/arpes/plotting/bz.py index 9cb97d04..81b9ee39 100644 --- a/arpes/plotting/bz.py +++ b/arpes/plotting/bz.py @@ -56,7 +56,7 @@ } -LOGLEVEL = (DEBUG, INFO)[1] +LOGLEVEL = (DEBUG, INFO)[0] logger = getLogger(__name__) fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s" formatter = Formatter(fmt) @@ -329,7 +329,9 @@ def plot_data_to_bz3d( def bz_plot( - cell: Sequence[Sequence[float]] | NDArray[np.float_], *args, **kwargs: Incomplete + cell: Sequence[Sequence[float]] | NDArray[np.float_], + *args, + **kwargs: Incomplete, ) -> Axes: """Dimension generic BZ plot which uses the cell dimension to delegate.""" logger.debug(f"size of cell is: {format(len(cell))}") @@ -540,12 +542,12 @@ def annotate_special_paths( **kwargs: Incomplete, ) -> None: """Annotates user indicated paths in k-space by plotting lines (or points) over the BZ.""" - logger.debug(f"ax: {ax}") - logger.debug(f"paths: {paths}") - logger.debug(f"cell: {cell}") - logger.debug(f"offset: {offset}") - logger.debug(f"special_points: {special_points}") - logger.debug(f"labels: {labels}") + logger.debug(f"annotate-ax: {ax}") + logger.debug(f"annotate-paths: {paths}") + logger.debug(f"annotate-cell: {cell}") + logger.debug(f"annotate-offset: {offset}") + logger.debug(f"annotate-special_points: {special_points}") + logger.debug(f"annotate-labels: {labels}") if kwargs: for k, v in kwargs.items(): logger.debug(f"kwargs: kyes: {k}, value: {v}") @@ -559,12 +561,14 @@ def annotate_special_paths( labels = paths converted_paths = process_kpath(paths, cell, special_points=special_points) + logger.debug(f"converted_paths: {converted_paths}") if not isinstance(labels[0], list): labels = [labels] labels = [list(label) for label in labels] paths = list(zip(labels, converted_paths, strict=True)) + logger.debug(f"paths in annotate_special_paths {paths}") fontsize = kwargs.pop("fontsize", 14) if offset is None: @@ -647,7 +651,7 @@ def bz2d_segments( return segments_x, segments_y -def twocell_to_bz1(cell: Sequence[Sequence[float]] | NDArray[np.float_]): +def twocell_to_bz1(cell: NDArray[np.float_]): from ase.dft.bz import bz_vertices # 2d in x-y plane @@ -655,7 +659,7 @@ def twocell_to_bz1(cell: Sequence[Sequence[float]] | NDArray[np.float_]): assert all(abs(cell[2][0:2]) < 1e-6) # noqa: PLR2004 assert all(abs(cell.T[2][0:2]) < 1e-6) # noqa: PLR2004 else: - cell = [[*list(c), 0] for c in cell] + [[0, 0, 1]] + cell = np.array([[*list(c), 0] for c in cell] + [[0, 0, 1]]) icell = np.linalg.inv(cell).T try: bz1 = bz_vertices(icell[:3, :3], dim=2) @@ -684,16 +688,17 @@ def bz2d_plot( Plots a Brillouin zone corresponding to a given unit cell """ - logger.debug(f"cell: {cell}") - logger.debug(f"paths: {paths}") - logger.debug(f"points: {points}") - logger.debug(f"repeat: {repeat}") - logger.debug(f"transformations: {transformations}") - logger.debug(f"hide_ax: {hide_ax}") - logger.debug(f"vectors: {vectors}") - logger.debug(f"set_equal_aspect: {set_equal_aspect}") + logger.debug(f"bz2d_plot-cell: {cell}") + logger.debug(f"bz2d_plot-paths: {paths}") + logger.debug(f"bz2d_plot-points: {points}") + logger.debug(f"bz2d_plot-repeat: {repeat}") + logger.debug(f"bz2d_plot-transformations: {transformations}") + logger.debug(f"bz2d_plot-hide_ax: {hide_ax}") + logger.debug(f"bz2d_plot-vectors: {vectors}") + logger.debug(f"bz2d_plot-set_equal_aspect: {set_equal_aspect}") kpoints = points bz1, icell, cell = twocell_to_bz1(cell) + logger.debug(f"bz1 : {bz1}") if ax is None: ax = plt.axes() diff --git a/arpes/plotting/dos.py b/arpes/plotting/dos.py index 3858c609..95e38807 100644 --- a/arpes/plotting/dos.py +++ b/arpes/plotting/dos.py @@ -1,10 +1,10 @@ """Plotting utilities related to density of states plots.""" + from __future__ import annotations from typing import TYPE_CHECKING import matplotlib as mpl -import numpy as np import xarray as xr from matplotlib import colors, gridspec from matplotlib import pyplot as plt @@ -59,15 +59,17 @@ def plot_core_levels( # noqa: PLR0913 @save_plot_provenance def plot_dos( - data: DataType, + data: xr.DataArray, title: str = "", out: str | Path = "", norm: Normalize | None = None, dos_pow: float = 1, ) -> Path | tuple[Figure, Axes, Colorbar]: """Plots the density of states (momentum integrated) image next to the original spectrum.""" - data_arr = normalize_to_spectrum(data) + data_arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) + assert isinstance(data_arr, xr.DataArray) + fig = plt.figure(figsize=(14, 6)) fig.subplots_adjust(hspace=0.00) gs = gridspec.GridSpec(2, 1, height_ratios=[3, 1]) @@ -75,7 +77,7 @@ def plot_dos( ax0 = plt.subplot(gs[0]) axes = (ax0, plt.subplot(gs[1], sharex=ax0)) - data_arr.values[np.isnan(data_arr.values)] = 0 # <== FIXME CONSIDER xr.DataArray fillna(0) + data_arr.fillna(0) cbar_axes = mpl.colorbar.make_axes(axes, pad=0.01) mesh = data_arr.plot(ax=axes[0], norm=norm or colors.PowerNorm(gamma=0.15)) diff --git a/arpes/plotting/dynamic_tool.py b/arpes/plotting/dynamic_tool.py index 9d323419..08b49353 100644 --- a/arpes/plotting/dynamic_tool.py +++ b/arpes/plotting/dynamic_tool.py @@ -7,6 +7,7 @@ from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger from typing import TYPE_CHECKING, Any +import xarray as xr from more_itertools import ichunked from PySide6 import QtWidgets @@ -185,8 +186,8 @@ def before_show(self) -> None: self.update_data() self.window.setWindowTitle(f"Interactive {self._function.__name__}") - def set_data(self, data: DataType) -> None: - self.data = normalize_to_spectrum(data) + def set_data(self, data: xr.DataArray) -> None: + self.data = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) def make_dynamic(fn: Callable[..., Any], data: DataType) -> None: diff --git a/arpes/plotting/fermi_surface.py b/arpes/plotting/fermi_surface.py index f9c48e62..29a33584 100644 --- a/arpes/plotting/fermi_surface.py +++ b/arpes/plotting/fermi_surface.py @@ -26,8 +26,6 @@ from matplotlib.typing import ColorType from numpy.typing import NDArray - from arpes._typing import DataType - __all__ = ( "fermi_surface_slices", @@ -82,7 +80,7 @@ def fermi_surface_slices( @save_plot_provenance def magnify_circular_regions_plot( - data: DataType, + data: xr.DataArray, magnified_points: NDArray[np.float_] | list[float], mag: float = 10, radius: float = 0.05, @@ -111,7 +109,7 @@ def magnify_circular_regions_plot( Returns: [TODO:description] """ - data_arr = normalize_to_spectrum(data) + data_arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) assert isinstance(data_arr, xr.DataArray) fig: Figure | None = None diff --git a/arpes/plotting/parameter.py b/arpes/plotting/parameter.py index ed7628bd..8908a6e1 100644 --- a/arpes/plotting/parameter.py +++ b/arpes/plotting/parameter.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Unpack import matplotlib.pyplot as plt @@ -14,38 +14,43 @@ import numpy as np import xarray as xr from matplotlib.axes import Axes - from matplotlib.typing import RGBColorType from numpy.typing import NDArray + from arpes._typing import MPLPlotKwargs + __all__ = ("plot_parameter",) @save_plot_provenance -def plot_parameter( +def plot_parameter( # noqa: PLR0913 fit_data: xr.DataArray, param_name: str, ax: Axes | None = None, - fillstyle: Literal["full", "left", "right", "bottom", "top", "none"] = "none", shift: float = 0, x_shift: float = 0, - markersize: int = 8, *, two_sigma: bool = False, - **kwargs: tuple | RGBColorType, + figsize: tuple[float, float] = (7, 5), + **kwargs: Unpack[MPLPlotKwargs], ) -> Axes: """Makes a simple scatter plot of a parameter from an `broadcast_fit` result.""" if ax is None: - _, ax = plt.subplots(figsize=kwargs.pop("figsize", (7, 5))) + _, ax = plt.subplots(figsize=figsize) ds = fit_data.F.param_as_dataset(param_name) x_name = ds.value.dims[0] x: NDArray[np.float_] = ds.coords[x_name].values + kwargs.setdefault("fillstyle", "none") + kwargs.setdefault("markersize", 8) + + fillstyle = kwargs.pop("fillstyle") + markersize = kwargs.pop("markersize") color = kwargs.get("color") e_width = None l_width = None if two_sigma: - _, __, lines = ax.errorbar( + _, _, lines = ax.errorbar( x + x_shift, ds.value.values + shift, yerr=2 * ds.error.values, diff --git a/arpes/plotting/qt_ktool/__init__.py b/arpes/plotting/qt_ktool/__init__.py index c1cb84cf..aa7aa55a 100644 --- a/arpes/plotting/qt_ktool/__init__.py +++ b/arpes/plotting/qt_ktool/__init__.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING import numpy as np +import xarray as xr from more_itertools import ichunked from PySide6 import QtWidgets @@ -17,7 +18,6 @@ if TYPE_CHECKING: from collections.abc import Callable, Sequence - import xarray as xr from _typeshed import Incomplete from matplotlib.colors import Colormap from PySide6.QtWidgets import QGridLayout @@ -177,7 +177,7 @@ def set_data(self, data: xr.DataArray) -> None: Above what happens in QtTool, we try to extract a Fermi surface, and repopulate the conversion. """ - original_data = normalize_to_spectrum(data) + original_data = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) self.original_data: xr.DataArray = original_data if len(data.dims) > 2: # noqa: PLR2004 @@ -210,9 +210,9 @@ def set_data(self, data: xr.DataArray) -> None: } -def ktool(data: XrTypes, **kwargs: Incomplete) -> KTool: +def ktool(data: xr.DataArray, **kwargs: Incomplete) -> KTool: """Start the momentum conversion tool.""" - data_arr = normalize_to_spectrum(data) + data_arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) tool = KTool(**kwargs) tool.set_data(data_arr) tool.start() diff --git a/arpes/plotting/qt_tool/__init__.py b/arpes/plotting/qt_tool/__init__.py index e1dbae09..9d99e990 100644 --- a/arpes/plotting/qt_tool/__init__.py +++ b/arpes/plotting/qt_tool/__init__.py @@ -13,6 +13,7 @@ import matplotlib as mpl import numpy as np import pyqtgraph as pg +import xarray as xr from PySide6 import QtCore, QtWidgets from PySide6.QtWidgets import QGridLayout @@ -33,7 +34,6 @@ from .BinningInfoWidget import BinningInfoWidget if TYPE_CHECKING: - import xarray as xr from _typeshed import Incomplete from PySide6.QtCore import QEvent from PySide6.QtGui import QKeyEvent @@ -527,9 +527,9 @@ def reset_intensity(self) -> None: """Autoscales intensity in each marginal plot.""" self.update_cursor_position(self.context["cursor"], force=True, keep_levels=False) - def set_data(self, data: XrTypes) -> None: + def set_data(self, data: xr.DataArray) -> None: """Sets the current data to a new value and resets binning.""" - data_arr = normalize_to_spectrum(data) + data_arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) if np.any(np.isnan(data_arr)): warnings.warn("Nan values encountered, copying data and assigning zeros.", stacklevel=2) diff --git a/arpes/plotting/spatial.py b/arpes/plotting/spatial.py index e07916f7..4f96e58b 100644 --- a/arpes/plotting/spatial.py +++ b/arpes/plotting/spatial.py @@ -41,7 +41,7 @@ @save_plot_provenance def plot_spatial_reference( - reference_map: DataType, + reference_map: xr.DataArray, data_list: list[DataType], offset_list: list[dict[str, Any] | None] | None = None, annotation_list: list[str] | None = None, @@ -68,8 +68,8 @@ def plot_spatial_reference( if annotation_list is None: annotation_list = [str(i + 1) for i in range(len(data_list))] - - normalize_to_spectrum(reference_map) + if not isinstance(reference_map, xr.DataArray): + reference_map = normalize_to_spectrum(reference_map) n_references = len(data_list) if n_references == 1 and plot_refs: @@ -216,14 +216,14 @@ def plot_spatial_reference( @save_plot_provenance def reference_scan_spatial( - data: DataType, + data: xr.DataArray, out: str | Path = "", ) -> Path | tuple[Figure, NDArray[np.object_]]: """Plots the spatial content of a dataset, useful as a quick reference. Warning: Not work correctly. (Because S.referenced_scans has been removed.) """ - data_arr = normalize_to_spectrum(data) + data_arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) assert isinstance(data_arr, xr.DataArray) diff --git a/arpes/plotting/spin.py b/arpes/plotting/spin.py index 592cc19a..f2f77291 100644 --- a/arpes/plotting/spin.py +++ b/arpes/plotting/spin.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, reveal_type +from typing import TYPE_CHECKING import matplotlib as mpl import matplotlib.colors diff --git a/arpes/plotting/stack_plot.py b/arpes/plotting/stack_plot.py index 97811e58..9eb3c6aa 100644 --- a/arpes/plotting/stack_plot.py +++ b/arpes/plotting/stack_plot.py @@ -207,7 +207,7 @@ def offset_scatter_plot( @save_plot_provenance def flat_stack_plot( # noqa: PLR0913 - data: XrTypes, + data: xr.DataArray, *, stack_axis: str = "", ax: Axes | None = None, @@ -244,8 +244,8 @@ def flat_stack_plot( # noqa: PLR0913 NotImplementedError _description_ """ - data_array = normalize_to_spectrum(data) - assert isinstance(data_array, xr.DataArray) + data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) + two_dimensional = 2 if len(data_array.dims) != two_dimensional: @@ -535,7 +535,7 @@ def _scale_factor( def _rebinning( - data: XrTypes, + data: xr.DataArray, stack_axis: str, max_stacks: int, ) -> tuple[xr.DataArray, str, str]: @@ -545,7 +545,7 @@ def _rebinning( 2. determine the stack axis 3. determine the name of the other. """ - data_arr = normalize_to_spectrum(data) + data_arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) assert isinstance(data_arr, xr.DataArray) data_arr_must_be_two_dimensional = 2 assert len(data_arr.dims) == data_arr_must_be_two_dimensional diff --git a/arpes/plotting/utils.py b/arpes/plotting/utils.py index f0298608..16c79782 100644 --- a/arpes/plotting/utils.py +++ b/arpes/plotting/utils.py @@ -433,10 +433,9 @@ def transform_labels( ax.set_title(transform_fn(ax.get_title())) -def summarize(data: DataType, axes: NDArray[np.object_] | None = None) -> NDArray[np.object_]: +def summarize(data: xr.DataArray, axes: NDArray[np.object_] | None = None) -> NDArray[np.object_]: """Makes a summary plot with different marginal plots represented.""" - data_arr = normalize_to_spectrum(data) - assert isinstance(data_arr, xr.DataArray) + data_arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) axes_shapes_for_dims = { 1: (1, 1), 2: (1, 1), @@ -1159,11 +1158,11 @@ def polarization_colorbar(ax: Axes | None = None) -> colorbar.Colorbar: ) -def calculate_aspect_ratio(data: DataType) -> float: +def calculate_aspect_ratio(data: xr.DataArray) -> float: """Calculate the aspect ratio which should be used for plotting some data based on extent.""" - data_arr = normalize_to_spectrum(data) - assert isinstance(data_arr, xr.DataArray) - assert len(data.dims) == TwoDimensional + data_arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) + + assert len(data.dims_arr) == TwoDimensional x_extent = np.ptp(data_arr.coords[data_arr.dims[0]].values) y_extent = np.ptp(data_arr.coords[data_arr.dims[1]].values) diff --git a/arpes/preparation/axis_preparation.py b/arpes/preparation/axis_preparation.py index ef01151b..af7ba4e6 100644 --- a/arpes/preparation/axis_preparation.py +++ b/arpes/preparation/axis_preparation.py @@ -73,7 +73,12 @@ def sort_axis(data: xr.DataArray, axis_name: str) -> xr.DataArray: @update_provenance("Flip data along axis") -def flip_axis(arr: xr.DataArray, axis_name: str, *, flip_data: bool = True) -> xr.DataArray: +def flip_axis( + arr: xr.DataArray, # valuse is used + axis_name: str, + *, + flip_data: bool = True, +) -> xr.DataArray: """Flips the coordinate values along an axis w/o changing the data as well. Args: diff --git a/arpes/provenance.py b/arpes/provenance.py index e7fd725e..4166fbca 100644 --- a/arpes/provenance.py +++ b/arpes/provenance.py @@ -130,7 +130,7 @@ def update_provenance( what: str, *, keep_parent_ref: bool = False, -) -> Callable[[Callable[P, XrTypes]], Callable[P, XrTypes]]: +) -> Callable[[Callable[P, R]], Callable[P, R]]: """A decorator that promotes a function to one that records data provenance. Args: @@ -142,8 +142,8 @@ def update_provenance( """ def update_provenance_decorator( - fn: Callable[P, XrTypes], - ) -> Callable[P, XrTypes]: + fn: Callable[P, R], + ) -> Callable[P, R]: """[TODO:summary]. Args: @@ -151,7 +151,7 @@ def update_provenance_decorator( """ @functools.wraps(fn) - def func_wrapper(*args: P.args, **kwargs: P.kwargs) -> XrTypes: + def func_wrapper(*args: P.args, **kwargs: P.kwargs) -> R: arg_parents = [ v for v in args if isinstance(v, xr.Dataset | xr.Dataset) and "id" in v.attrs ] diff --git a/arpes/utilities/conversion/base.py b/arpes/utilities/conversion/base.py index 22daa3b5..b43de710 100644 --- a/arpes/utilities/conversion/base.py +++ b/arpes/utilities/conversion/base.py @@ -124,6 +124,7 @@ def conversion_for( ) -> Callable[[NDArray[np.float_]], NDArray[np.float_]]: """Fetches the method responsible for calculating `dim` from momentum coordinates.""" assert isinstance(dim, str) + return self.kspace_to_BE def identity_transform(self, axis_name: str, *args: Incomplete) -> NDArray[np.float_]: """Just returns the coordinate requested from args. diff --git a/arpes/utilities/conversion/bounds_calculations.py b/arpes/utilities/conversion/bounds_calculations.py index 69958185..06815550 100644 --- a/arpes/utilities/conversion/bounds_calculations.py +++ b/arpes/utilities/conversion/bounds_calculations.py @@ -251,7 +251,9 @@ def calculate_kp_bounds(arr: xr.DataArray) -> tuple[float, float]: return round(np.min(kps), 2), round(np.max(kps), 2) -def calculate_kx_ky_bounds(arr: xr.DataArray) -> tuple[tuple[float, float], tuple[float, float]]: +def calculate_kx_ky_bounds( + arr: xr.DataArray, +) -> tuple[tuple[np.float_, np.float_], tuple[np.float_, np.float_]]: """Calculates the kx and ky range for a dataset with a fixed photon energy. This is used to infer the gridding that should be used for a k-space conversion. diff --git a/arpes/utilities/conversion/core.py b/arpes/utilities/conversion/core.py index 7c2ac1ce..a44477c0 100644 --- a/arpes/utilities/conversion/core.py +++ b/arpes/utilities/conversion/core.py @@ -56,7 +56,7 @@ LOGLEVELS = (DEBUG, INFO) -LOGLEVEL = LOGLEVELS[0] +LOGLEVEL = LOGLEVELS[1] logger = getLogger(__name__) fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s" formatter = Formatter(fmt) @@ -105,7 +105,7 @@ def grid_interpolator_from_dataarray( ) -def slice_along_path( +def slice_along_path( # noqa: PLR0913 arr: xr.DataArray, interpolation_points: NDArray[np.float_] | None = None, axis_name: str = "", @@ -317,7 +317,7 @@ def interpolated_coordinate_to_raw(*coordinates: NDArray[np.float_]) -> NDArray[ @update_provenance("Automatically k-space converted") -def convert_to_kspace( +def convert_to_kspace( # noqa: PLR0913 arr: xr.DataArray, bounds: dict[MOMENTUM, tuple[float, float]] | None = None, resolution: dict[MOMENTUM, float] | None = None, @@ -349,7 +349,6 @@ def convert_to_kspace( Examples: Convert a 2D cut with automatically inferred range and resolution. - >>> convert_to_kspace(arpes.io.load_example_data()) # doctest: +SKIP xr.DataArray(...) @@ -470,7 +469,11 @@ def convert_to_kspace( { "dims": converted_dims, "transforms": dict( - zip(arr.dims, [converter.conversion_for(dim) for dim in arr.dims], strict=True), + zip( + (str(dim) for dim in arr.dims), + [converter.conversion_for(dim) for dim in arr.dims], + strict=True, + ), ), }, ) @@ -562,10 +565,9 @@ def acceptable_coordinate(c: NDArray[np.float_] | xr.DataArray) -> bool: Returns: [TODO:description] """ - try: + if isinstance(c, xr.DataArray): return bool(set(c.dims).issubset(coordinate_transform["dims"])) - except AttributeError: - return True + return True target_coordinates = {k: v for k, v in target_coordinates.items() if acceptable_coordinate(v)} data = xr.DataArray( diff --git a/arpes/utilities/conversion/forward.py b/arpes/utilities/conversion/forward.py index 02528e69..8d8f64a4 100644 --- a/arpes/utilities/conversion/forward.py +++ b/arpes/utilities/conversion/forward.py @@ -48,7 +48,7 @@ LOGLEVELS = (DEBUG, INFO) -LOGLEVEL = LOGLEVELS[0] +LOGLEVEL = LOGLEVELS[1] logger = getLogger(__name__) fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s" formatter = Formatter(fmt) @@ -61,7 +61,7 @@ def convert_coordinate_forward( - data: XrTypes, + data: xr.DataArray, coords: dict[str, float], **k_coords: Unpack[KspaceCoords], ) -> dict[str, float]: @@ -94,14 +94,15 @@ def convert_coordinate_forward( Another approach would be to write down the exact small angle approximated transforms. Args: - data (DataType): The data defining the coordinate offsets and experiment geometry. + data (XrTypes): The data defining the coordinate offsets and experiment geometry. + (should be DataArray) coords (dict[str, float]): The coordinates of a *point* in angle-space to be converted. k_coords: Coordinate for k-axis Returns: The location of the desired coordinate in momentum. """ - data_arr = normalize_to_spectrum(data) + data_arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) if "eV" in coords: coords = dict(coords) energy_coord = coords.pop("eV") @@ -148,7 +149,7 @@ def convert_through_angular_pair( # noqa: PLR0913 *, relative_coords: bool = True, **k_coords: NDArray[np.float_], -) -> dict[str, float]: +) -> xr.DataArray: """Converts the lower dimensional ARPES cut passing through `first_point` and `second_point`. This is a sibling method to `convert_through_angular_point`. A point and a `chi` angle @@ -232,7 +233,6 @@ def convert_through_angular_pair( # noqa: PLR0913 **transverse_specification, kx=parallel_axis, ).mean(list(transverse_specification.keys())) - logger.debug("Annotating the requested point momentum values.") return converted_data.assign_attrs( { diff --git a/arpes/utilities/conversion/kx_ky_conversion.py b/arpes/utilities/conversion/kx_ky_conversion.py index 4231689f..9cd4bc5e 100644 --- a/arpes/utilities/conversion/kx_ky_conversion.py +++ b/arpes/utilities/conversion/kx_ky_conversion.py @@ -254,7 +254,7 @@ def conversion_for(self, dim: str) -> Callable[[NDArray[np.float_]], NDArray[np. def _with_identity(*args: NDArray[np.float_]) -> NDArray[np.float_]: return self.identity_transform(dim, *args) - return { + return { # type: ignore[return-value] "eV": self.kspace_to_BE, "phi": self.kspace_to_phi, }.get( @@ -388,7 +388,7 @@ def conversion_for(self, dim: str) -> Callable[[NDArray[np.float_]], NDArray[np. def _with_identity(*args: NDArray[np.float_]) -> NDArray[np.float_]: return self.identity_transform(dim, *args) - return { + return { # type: ignore[return-value] "eV": self.kspace_to_BE, "phi": self.kspace_to_phi, "theta": self.kspace_to_perp_angle, diff --git a/arpes/utilities/conversion/kz_conversion.py b/arpes/utilities/conversion/kz_conversion.py index 810e2c2d..c41f4af4 100644 --- a/arpes/utilities/conversion/kz_conversion.py +++ b/arpes/utilities/conversion/kz_conversion.py @@ -198,7 +198,7 @@ def conversion_for(self, dim: str) -> Callable[[NDArray[np.float_]], NDArray[np. def _with_identity(*args: NDArray[np.float_]) -> NDArray[np.float_]: return self.identity_transform(dim, *args) - return { + return { # type: ignore[return-value] "eV": self.kspace_to_BE, "hv": self.kspace_to_hv, "phi": self.kspace_to_phi, diff --git a/arpes/utilities/conversion/remap_manipulator.py b/arpes/utilities/conversion/remap_manipulator.py index e68cfe7c..f250df4c 100644 --- a/arpes/utilities/conversion/remap_manipulator.py +++ b/arpes/utilities/conversion/remap_manipulator.py @@ -1,4 +1,5 @@ """Contains utilities to determine equivalent coordinates between pairs of scans.""" + from __future__ import annotations from copy import deepcopy diff --git a/arpes/utilities/selections.py b/arpes/utilities/selections.py index 15b91a98..efb12d7b 100644 --- a/arpes/utilities/selections.py +++ b/arpes/utilities/selections.py @@ -91,7 +91,7 @@ def _normalize_point( def select_disk_mask( - data: DataType, + data: xr.DataArray, radius: float, outer_radius: float | None = None, around: dict | xr.Dataset | None = None, @@ -125,8 +125,9 @@ def select_disk_mask( if outer_radius is not None and radius > outer_radius: radius, outer_radius = outer_radius, radius - data_array = normalize_to_spectrum(data) - around = _normalize_point(data, around, **kwargs) + data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) + + around = _normalize_point(data_array, around, **kwargs) raveled = data_array.G.ravel() @@ -146,7 +147,7 @@ def select_disk_mask( def select_disk( - data: DataType, + data: xr.DataArray, radius: float, outer_radius: float | None = None, around: dict | xr.Dataset | None = None, @@ -176,7 +177,7 @@ def select_disk( invert: Whether to invert the mask, i.e. everything but the annulus kwargs: The central point, otherwise specified by `around` """ - data_array = normalize_to_spectrum(data) + data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) around = _normalize_point(data_array, around, **kwargs) mask = select_disk_mask(data_array, radius, outer_radius=outer_radius, around=around, flat=True) diff --git a/arpes/widgets.py b/arpes/widgets.py index 5c1173a1..262bb3a6 100644 --- a/arpes/widgets.py +++ b/arpes/widgets.py @@ -798,7 +798,7 @@ def on_select_summed(region) -> None: @popout def kspace_tool( - data: DataType, + data: xr.DataArray, overplot_bz: Callable[[Axes], None] | list[Callable[[Axes], None]] | None = None, bounds: dict[MOMENTUM, tuple[float, float]] | None = None, resolution: dict | None = None, @@ -825,7 +825,7 @@ def kspace_tool( """ """A utility for assigning coordinate offsets using a live momentum conversion.""" original_data = data - data_array = normalize_to_spectrum(data) + data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) assert isinstance(data_array, xr.DataArray) if len(data_array.dims) > TWO_DIMENSION: diff --git a/arpes/xarray_extensions.py b/arpes/xarray_extensions.py index daadd366..e61cdb53 100644 --- a/arpes/xarray_extensions.py +++ b/arpes/xarray_extensions.py @@ -3158,11 +3158,11 @@ def param_as_dataset(self, param_name: str) -> xr.Dataset: }, ) - def show(self, *, detached: bool = False) -> None: + def show(self) -> None: """Opens a Qt based interactive fit inspection tool.""" from .plotting.fit_tool import fit_tool - fit_tool(self._obj, detached=detached) + fit_tool(self._obj) def best_fits(self) -> xr.DataArray: """Orders the fits into a raveled array by the MSE error.""" From d9c9af7f8fd1539c705329e4883d3755f389fe31 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Fri, 9 Feb 2024 20:38:10 +0900 Subject: [PATCH 7/7] =?UTF-8?q?=F0=9F=92=AC=20=20Update=20type=20hints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- arpes/utilities/conversion/base.py | 14 +++++++++++++- arpes/utilities/conversion/bounds_calculations.py | 6 +++--- arpes/utilities/conversion/core.py | 14 +++++++++++--- arpes/xarray_extensions.py | 12 ++++++------ 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/arpes/utilities/conversion/base.py b/arpes/utilities/conversion/base.py index b43de710..706775ab 100644 --- a/arpes/utilities/conversion/base.py +++ b/arpes/utilities/conversion/base.py @@ -34,7 +34,19 @@ K_SPACE_BORDER = 0.02 -MOMENTUM_BREAKPOINTS = [0.0005, 0.001, 0.002, 0.005, 0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1] +MOMENTUM_BREAKPOINTS: list[float] = [ + 0.0005, + 0.001, + 0.002, + 0.005, + 0.01, + 0.02, + 0.05, + 0.1, + 0.2, + 0.5, + 1, +] class CoordinateConverter: diff --git a/arpes/utilities/conversion/bounds_calculations.py b/arpes/utilities/conversion/bounds_calculations.py index 06815550..f987a5b4 100644 --- a/arpes/utilities/conversion/bounds_calculations.py +++ b/arpes/utilities/conversion/bounds_calculations.py @@ -253,7 +253,7 @@ def calculate_kp_bounds(arr: xr.DataArray) -> tuple[float, float]: def calculate_kx_ky_bounds( arr: xr.DataArray, -) -> tuple[tuple[np.float_, np.float_], tuple[np.float_, np.float_]]: +) -> tuple[tuple[float, float], tuple[float, float]]: """Calculates the kx and ky range for a dataset with a fixed photon energy. This is used to infer the gridding that should be used for a k-space conversion. @@ -316,6 +316,6 @@ def calculate_kx_ky_bounds( * np.sin(sampled_beta_values) ) return ( - (round(np.min(kxs), 2), round(np.max(kxs), 2)), - (round(np.min(kys), 2), round(np.max(kys), 2)), + (round(np.min(kxs), 2).astype(float), round(np.max(kxs), 2).astype(float)), + (round(np.min(kys), 2).astype(float), round(np.max(kys), 2).astype(float)), ) diff --git a/arpes/utilities/conversion/core.py b/arpes/utilities/conversion/core.py index a44477c0..9e22612e 100644 --- a/arpes/utilities/conversion/core.py +++ b/arpes/utilities/conversion/core.py @@ -31,6 +31,7 @@ from typing import TYPE_CHECKING, Literal import numpy as np +from numpy.typing import ArrayLike import xarray as xr from scipy.interpolate import RegularGridInterpolator @@ -114,7 +115,7 @@ def slice_along_path( # noqa: PLR0913 *, extend_to_edge: bool = False, shift_gamma: bool = True, -) -> xr.DataArray: +) -> xr.Dataset: """Gets a cut along a path specified by waypoints in an array. TODO: There might be a little bug here where the last coordinate has a value of 0, @@ -215,7 +216,10 @@ def slice_along_path( # noqa: PLR0913 path_segments = list(pairwise(parsed_interpolation_points)) - def required_sampling_density(waypoint_a: Mapping, waypoint_b: Mapping) -> float: + def required_sampling_density( + waypoint_a: Mapping[Hashable, float], + waypoint_b: Mapping[Hashable, float], + ) -> float: ks = waypoint_a.keys() dist = _element_distance(waypoint_a, waypoint_b) delta = np.array([waypoint_a[k] - waypoint_b[k] for k in ks]) @@ -295,6 +299,7 @@ def interpolated_coordinate_to_raw(*coordinates: NDArray[np.float_]) -> NDArray[ }, as_dataset=True, ) + assert isinstance(converted_ds, xr.Dataset) if ( axis_name in arr.dims and len(parsed_interpolation_points) == 2 # noqa: PLR2004 @@ -697,6 +702,9 @@ def _extract_symmetry_point( return dict(zip([d for d in arr.dims if d in raw_point], S, strict=False)) -def _element_distance(waypoint_a: Mapping, waypoint_b: Mapping) -> np.float_: +def _element_distance( + waypoint_a: Mapping[Hashable, float], + waypoint_b: Mapping[Hashable, float], +) -> np.float_: delta = np.array([waypoint_a[k] - waypoint_b[k] for k in waypoint_a]) return np.linalg.norm(delta) diff --git a/arpes/xarray_extensions.py b/arpes/xarray_extensions.py index e61cdb53..badc07b8 100644 --- a/arpes/xarray_extensions.py +++ b/arpes/xarray_extensions.py @@ -82,7 +82,7 @@ from .utilities.xarray import unwrap_xarray_item if TYPE_CHECKING: - from collections.abc import Callable, Generator + from collections.abc import Callable, Generator, Iterator import lmfit from _typeshed import Incomplete @@ -145,7 +145,7 @@ def _iter_groups( grouped: dict[str, Sequence[float] | float], -) -> Generator[tuple[str, float], None, None]: +) -> Iterator[tuple[str, float]]: """Iterates through a flattened sequence. Sequentially yields keys and values from each sequence associated with a key. @@ -723,17 +723,17 @@ def symmetry_points( return self._calculate_symmetry_points(symmetry_points, **kwargs) @property - def iter_own_symmetry_points(self) -> Generator[tuple[str, float], None, None]: + def iter_own_symmetry_points(self) -> Iterator[tuple[str, float]]: sym_points, _ = self.symmetry_points() return _iter_groups(sym_points) @property - def iter_projected_symmetry_points(self) -> Generator[tuple[str, float], None, None]: + def iter_projected_symmetry_points(self) -> Iterator[tuple[str, float]]: _, sym_points = self.symmetry_points() return _iter_groups(sym_points) @property - def iter_symmetry_points(self) -> Generator[tuple[str, float], None, None]: + def iter_symmetry_points(self) -> Iterator[tuple[str, float]]: yield from self.iter_own_symmetry_points yield from self.iter_projected_symmetry_points @@ -2759,7 +2759,7 @@ def enumerate_iter_coords( def iter_coords( self, dim_names: tuple[str | Hashable, ...] = (), - ) -> Generator[dict[Hashable, float], None, None]: + ) -> Iterator[dict[Hashable, float]]: """[TODO:summary]. Args: