diff --git a/arpes/_typing.py b/arpes/_typing.py index dcc02366..eefa1cc0 100644 --- a/arpes/_typing.py +++ b/arpes/_typing.py @@ -11,7 +11,7 @@ from __future__ import annotations import uuid -from typing import TYPE_CHECKING, Literal, Required, TypedDict, TypeVar +from typing import TYPE_CHECKING, Literal, Required, TypeAlias, TypedDict, TypeVar import xarray as xr @@ -20,11 +20,28 @@ ## if TYPE_CHECKING: + from collections.abc import Callable, Sequence from pathlib import Path import numpy as np from _typeshed import Incomplete - from numpy.typing import NDArray + from matplotlib.artist import Artist + from matplotlib.backend_bases import Event + from matplotlib.colors import Colormap + from matplotlib.figure import Figure + from matplotlib.patheffects import AbstractPathEffect + from matplotlib.transforms import BboxBase, Transform + from matplotlib.typing import ( + CapStyleType, + ColorType, + DrawStyleType, + FillStyleType, + JoinStyleType, + LineStyleType, + MarkerType, + MarkEveryType, + ) + from numpy.typing import ArrayLike, NDArray __all__ = [ "DataType", @@ -41,7 +58,7 @@ ] DataType = TypeVar("DataType", xr.DataArray, xr.Dataset) -NormalizableDataType = DataType | str | uuid.UUID +NormalizableDataType: TypeAlias = DataType | str | uuid.UUID xr_types = (xr.DataArray, xr.Dataset) @@ -241,3 +258,108 @@ class SPECTROMETER(ANALYZERINFO, COORDINATES, total=False): class ARPESAttrs(TypedDict, total=False): pass + + +class MPLPlotKwargs(TypedDict, total=False): + scalex: bool + scaley: bool + + agg_filter: Callable[[NDArray[np.float_], int], tuple[NDArray[np.float_], int, int]] + alpha: float | None + animated: bool + antialiased: bool + aa: bool + clip_box: BboxBase | None + clip_on: bool + # clip_path: Path | None color: ColorType + c: ColorType + dash_capstyle: CapStyleType + dash_joinstyle: JoinStyleType + dashes: Sequence[float | None] + data: NDArray[np.float_] + drawstyle: DrawStyleType + figure: Figure + fillstyle: FillStyleType + gapcolor: ColorType | None + gid: str + in_layout: bool + label: str + linestyle: LineStyleType + ls: LineStyleType + linewidth: float + lw: float + marker: MarkerType + markeredgecolor: ColorType + mec: ColorType + markeredgewidth: float + mew: float + markerfacecolor: ColorType + mfc: ColorType + markerfacecoloralt: ColorType + mfcalt: ColorType + markersize: float + ms: float + markevery: MarkEveryType + mouseover: bool + path_effects: list[AbstractPathEffect] + picker: float | Callable[[Artist, Event], tuple[bool, dict]] + pickradius: float + rasterized: bool + sketch_params: tuple[float, float, float] + scale: float + length: float + randomness: float + snap: bool | None + solid_capstyle: CapStyleType + solid_joinstyle: JoinStyleType + url: str + visible: bool + xdata: NDArray[np.float_] + ydata: NDArray[np.float_] + zorder: float + + +class PColorMeshKwargs(TypedDict, total=False): + agg_filter: Callable[[NDArray[np.float_], int], tuple[NDArray[np.float_], int, int]] + alpha: float | None + animated: bool + antialiased: bool + aa: bool + array: ArrayLike + + capstyle: CapStyleType + + clim: tuple[float, float] + clip_box: BboxBase | None + clip_on: bool + cmap: Colormap | str | None + color: ColorType + edgecolor: ColorType + ec: ColorType + facecolor: ColorType + facecolors: ColorType + fc: ColorType + figure: Figure + gid: str + hatch: Literal["/", "\\", "|", "-", "+", "x", "o", "O", ".", "*"] + in_layout: bool + joinstyle: JoinStyleType + label: str + linestyle: LineStyleType + linewidth: float | list[float] + linewidths: float | list[float] + lw: float | list[float] + mouseover: bool + offsets: NDArray[np.float_] + path_effects: list[AbstractPathEffect] + picker: None | bool | float + rasterized: bool + sketch_params: tuple[float, float, float] + scale: float + randomness: float + snap: bool | None + transform: Transform + url: str + urls: list[str | None] + visible: bool + zorder: float diff --git a/arpes/all.py b/arpes/all.py deleted file mode 100644 index e9fb4034..00000000 --- a/arpes/all.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Convenience import module for PyARPES.""" - - -import arpes.config -from arpes.analysis.all import * -from arpes.fits import * -from arpes.plotting.all import * -from arpes.utilities.conversion import * -from arpes.workflow import * - -arpes.config.load_plugins() diff --git a/arpes/analysis/all.py b/arpes/analysis/all.py deleted file mode 100644 index 8a11f3c7..00000000 --- a/arpes/analysis/all.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Convenience import module for analysis tools.""" -from .band_analysis import * -from .decomposition import * -from .deconvolution import * -from .derivative import * -from .filters import * -from .gap import * -from .general import * -from .kfermi import * -from .mask import * -from .pocket import * -from .savitzky_golay import * -from .tarpes import * -from .xps import * diff --git a/arpes/analysis/band_analysis.py b/arpes/analysis/band_analysis.py index dbf1e412..7016a9c4 100644 --- a/arpes/analysis/band_analysis.py +++ b/arpes/analysis/band_analysis.py @@ -50,7 +50,7 @@ def fit_for_effective_mass(data: DataType, fit_kwargs: dict | None = None) -> fl assert isinstance(fit_kwargs, dict) data_array = normalize_to_spectrum(data) mom_dim = next( - d for d in ["kp", "kx", "ky", "kz", "phi", "beta", "theta"] if d in data_array.dims + dim for dim in ["kp", "kx", "ky", "kz", "phi", "beta", "theta"] if dim in data_array.dims ) results = broadcast_model( @@ -61,7 +61,7 @@ def fit_for_effective_mass(data: DataType, fit_kwargs: dict | None = None) -> fl ) if mom_dim in {"phi", "beta", "theta"}: forward = convert_coordinates_to_kspace_forward(data_array) - final_mom = next(d for d in ["kx", "ky", "kp", "kz"] if d in forward) + final_mom = next(dim for dim in ["kx", "ky", "kp", "kz"] if dim in forward) eVs = results.F.p("a_center").values kps = [ forward[final_mom].sel(eV=eV, **dict([[mom_dim, ang]]), method="nearest") diff --git a/arpes/analysis/self_energy.py b/arpes/analysis/self_energy.py index 4501d6b4..41eeacc7 100644 --- a/arpes/analysis/self_energy.py +++ b/arpes/analysis/self_energy.py @@ -1,7 +1,7 @@ """Contains self-energy analysis routines.""" from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal, TypeAlias import lmfit as lf import numpy as np @@ -22,8 +22,8 @@ ) -BareBandType = xr.DataArray | str | lf.model.ModelResult -DispersionType = xr.DataArray | xr.Dataset +BareBandType: TypeAlias = xr.DataArray | str | lf.model.ModelResult +DispersionType: TypeAlias = xr.DataArray | xr.Dataset def get_peak_parameter(data: xr.DataArray, parameter_name: str) -> xr.DataArray: @@ -62,7 +62,7 @@ def get_peak_parameter(data: xr.DataArray, parameter_name: str) -> xr.DataArray: ) -def local_fermi_velocity(bare_band: xr.DataArray): +def local_fermi_velocity(bare_band: xr.DataArray) -> float: """Calculates the band velocity under assumptions of a linear bare band.""" fitted_model = LinearModel().guess_fit(bare_band) raw_velocity = fitted_model.params["slope"].value @@ -172,9 +172,9 @@ def quasiparticle_mean_free_path( def to_self_energy( dispersion: xr.DataArray, bare_band: BareBandType | None = None, + fermi_velocity: float | None = None, *, k_independent: bool = True, - fermi_velocity=None, ) -> xr.Dataset: r"""Converts MDC fit results into the self energy. @@ -198,8 +198,8 @@ def to_self_energy( Args: dispersion bare_band - k_independent fermi_velocity + k_independent: bool Returns: The equivalent self energy from the bare band and the measured dispersion. @@ -219,6 +219,7 @@ def to_self_energy( if fermi_velocity is None: fermi_velocity = local_fermi_velocity(estimated_bare_band) + assert isinstance(fermi_velocity, float) imaginary_part = get_peak_parameter(dispersion, "fwhm") / 2 centers = get_peak_parameter(dispersion, "center") @@ -243,7 +244,7 @@ def to_self_energy( def fit_for_self_energy( data: xr.DataArray, - method="mdc", + method: Literal["mdc", "edc"] = "mdc", bare_band: BareBandType | None = None, **kwargs: Incomplete, ) -> xr.Dataset: diff --git a/arpes/endstations/fits_utils.py b/arpes/endstations/fits_utils.py index c0eeea92..dd7cad0a 100644 --- a/arpes/endstations/fits_utils.py +++ b/arpes/endstations/fits_utils.py @@ -5,10 +5,11 @@ import warnings from ast import literal_eval from collections.abc import Callable, Iterable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeAlias import numpy as np from numpy import ndarray +from numpy._typing import NDArray from arpes.trace import traceable from arpes.utilities.funcutils import collect_leaves, iter_leaves @@ -33,7 +34,7 @@ "Z": "z", } -CoordsDict = dict[str, ndarray] +CoordsDict: TypeAlias = dict[str, NDArray[np.float_]] Dimension = str diff --git a/arpes/endstations/plugin/fallback.py b/arpes/endstations/plugin/fallback.py index 8074c416..1d1427df 100644 --- a/arpes/endstations/plugin/fallback.py +++ b/arpes/endstations/plugin/fallback.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from _typeshed import Incomplete + from arpes.endstations import SCANDESC __all__ = ("FallbackEndstation",) diff --git a/arpes/fits/zones.py b/arpes/fits/zones.py index 31628196..172ead27 100644 --- a/arpes/fits/zones.py +++ b/arpes/fits/zones.py @@ -31,7 +31,10 @@ def k_points_residual( - paramters, coords_dataset, high_symmetry_points, dimensionality: int = 2 + paramters, + coords_dataset, + high_symmetry_points, + dimensionality: int = 2, ) -> NDArray[np.float_]: momentum_coordinates = convert_coordinates(coords_dataset) if dimensionality == 2: diff --git a/arpes/plotting/all.py b/arpes/plotting/all.py deleted file mode 100644 index 1bea6a2f..00000000 --- a/arpes/plotting/all.py +++ /dev/null @@ -1,26 +0,0 @@ -"""Import many useful standard tools.""" -from .annotations import * -from .band_tool import * -from .bands import * -from .basic import * -from .comparison_tool import * -from .curvature_tool import * -from .dispersion import * -from .dos import * -from .dyn_tool import * -from .fermi_edge import * -from .fermi_surface import * -from .fit_inspection_tool import * - -# 'Tools' -# Note, we lift Bokeh imports into definitions in case people don't want to install Bokeh -# and also because of an undesirable interaction between pytest and Bokeh due to Bokeh's use -# of jinja2. -from .interactive import * -from .mask_tool import * -from .movie import * -from .parameter import * -from .path_tool import * -from .spatial import * -from .spin import * -from .stack_plot import * diff --git a/arpes/plotting/fermi_surface.py b/arpes/plotting/fermi_surface.py index 878ef043..09083997 100644 --- a/arpes/plotting/fermi_surface.py +++ b/arpes/plotting/fermi_surface.py @@ -1,22 +1,31 @@ """Simple plotting routes related constant energy slices and Fermi surfaces.""" -from pathlib import Path + +from __future__ import annotations + from typing import TYPE_CHECKING +import holoviews as hv # pylint: disable=import-error import matplotlib.patches import matplotlib.path import matplotlib.pyplot as plt import numpy as np import xarray as xr -from _typeshed import Incomplete from matplotlib.axes import Axes -from arpes._typing import DataType from arpes.plotting.utils import path_for_holoviews, path_for_plot from arpes.provenance import save_plot_provenance from arpes.utilities import normalize_to_spectrum if TYPE_CHECKING: + from pathlib import Path + + from matplotlib.colors import Colormap from matplotlib.figure import Figure + from matplotlib.typing import ColorType + from numpy.typing import NDArray + + from arpes._typing import DataType + __all__ = ( "fermi_surface_slices", @@ -29,17 +38,14 @@ def fermi_surface_slices( arr: xr.DataArray, n_slices: int = 9, ev_per_slice: float = 0.02, - bin: float = 0.01, + binning: float = 0.01, out: str | Path = "", - **kwargs: Incomplete, -): +) -> hv.Layout | Path: """Plots many constant energy slices in an axis grid.""" - import holoviews as hv # pylint: disable=import-error - slices = [] for i in range(n_slices): high = -ev_per_slice * i - low = high - bin + low = high - binning image = hv.Image( arr.sum( [d for d in arr.dims if d not in ["theta", "beta", "phi", "eV", "kp", "kx", "ky"]], @@ -63,16 +69,16 @@ def fermi_surface_slices( @save_plot_provenance def magnify_circular_regions_plot( data: DataType, - magnified_points, - mag=10, - radius=0.05, - cmap="viridis", - color=None, - edgecolor="red", + magnified_points: NDArray[np.float_] | list[float], + mag: float = 10, + radius: float = 0.05, + cmap: Colormap | ColorType = "viridis", + color: ColorType | None = None, + edgecolor: ColorType = "red", out: str | Path = "", ax: Axes | None = None, - **kwargs: Incomplete, -): + **kwargs: tuple[float, float], +) -> tuple[Figure, Axes] | Path: """Plots a Fermi surface with inset points magnified in an inset.""" data_arr = normalize_to_spectrum(data) assert isinstance(data_arr, xr.DataArray) @@ -81,6 +87,8 @@ def magnify_circular_regions_plot( if ax is None: fig, ax = plt.subplots(figsize=kwargs.pop("figsize", (7, 5))) + assert isinstance(ax, Axes) + mesh = data_arr.plot(ax=ax, cmap=cmap) clim = list(mesh.get_clim()) clim[1] = clim[1] / mag diff --git a/arpes/plotting/stack_plot.py b/arpes/plotting/stack_plot.py index b26d0c6a..20a77d36 100644 --- a/arpes/plotting/stack_plot.py +++ b/arpes/plotting/stack_plot.py @@ -5,7 +5,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Literal, TypedDict +from typing import TYPE_CHECKING, Literal import matplotlib as mpl import matplotlib.pyplot as plt @@ -34,7 +34,9 @@ from _typeshed import Incomplete from matplotlib.figure import Figure - from matplotlib.typing import RGBAColorType + from matplotlib.typing import ( + ColorType, + ) from numpy.typing import NDArray from arpes._typing import DataType @@ -45,20 +47,12 @@ ) -class MPLPlotKwargs(TypedDict, total=False): - scalex: bool - scaley: bool - alpha: float | None - animated: bool - antialiased: bool - - @save_plot_provenance def offset_scatter_plot( data: xr.Dataset, name_to_plot: str = "", stack_axis: str = "", - cbarmap: tuple[colorbar.Colorbar, Callable[[float], RGBAColorType]] | None = None, + cbarmap: tuple[colorbar.Colorbar, Callable[[float], ColorType]] | None = None, ax: Axes | None = None, out: str | Path = "", scale_coordinate: float = 0.5, @@ -74,7 +68,7 @@ def offset_scatter_plot( data(xr.Dataset): _description_ name_to_plot(str): name of the spectrum (in many case 'spectrum' is set), by default "" stack_axis(str): _description_, by default "" - cbarmap(tuple[colorbar.Colorbar, Callable[[float], RGBAColorType]] | None): _description_, + cbarmap(tuple[colorbar.Colorbar, Callable[[float], ColorType]] | None): _description_, by default None ax(Axes | None): _description_, by default None out(str | Path): _description @@ -202,7 +196,7 @@ def offset_scatter_plot( def flat_stack_plot( data: DataType, stack_axis: str = "", - color: RGBAColorType | Colormap = "viridis", + color: ColorType | Colormap = "viridis", ax: Axes | None = None, mode: Literal["line", "scatter"] = "line", fermi_level: float | None = None, @@ -214,7 +208,7 @@ def flat_stack_plot( Args: data(DataType): ARPES data (xr.DataArray is prepfered) stack_axis(str): axis for stacking, by default "" - color(RGBAColorType|Colormap): Colormap + color(ColorType|Colormap): Colormap ax (Axes | None): matplotlib Axes, by default None mode(Literal["line", "scatter"]): plot style (line/scatter), by default "line" fermi_level(float|None): Value corresponding to the Fermi level to Draw the line, diff --git a/arpes/utilities/conversion/bounds_calculations.py b/arpes/utilities/conversion/bounds_calculations.py index 186f43c7..ca68de50 100644 --- a/arpes/utilities/conversion/bounds_calculations.py +++ b/arpes/utilities/conversion/bounds_calculations.py @@ -101,13 +101,13 @@ def full_angles_to_k( # noqa: PLR0913 def euler_to_kx( - kinetic_energy: float, - phi: float, - beta: float, + kinetic_energy: NDArray[np.float_], + phi: NDArray[np.float_] | float, + beta: NDArray[np.float_] | float, theta: float = 0, *, slit_is_vertical: bool = False, -) -> float: +) -> NDArray[np.float_]: """Calculates kx from the phi/beta Euler angles given the experimental geometry.""" if slit_is_vertical: return K_INV_ANGSTROM * np.sqrt(kinetic_energy) * np.sin(beta) * np.cos(phi) @@ -115,13 +115,13 @@ def euler_to_kx( def euler_to_ky( - kinetic_energy: float, - phi: float, - beta: float, + kinetic_energy: NDArray[np.float_], + phi: NDArray[np.float_] | float, + beta: NDArray[np.float_] | float, theta: float = 0, *, slit_is_vertical: bool = False, -) -> float: +) -> NDArray[np.float_]: """Calculates ky from the phi/beta Euler angles given the experimental geometry.""" if slit_is_vertical: return ( @@ -133,14 +133,14 @@ def euler_to_ky( def euler_to_kz( # noqa: PLR0913 - kinetic_energy: float, - phi: float, - beta: float, + kinetic_energy: NDArray[np.float_], + phi: NDArray[np.float_] | float, + beta: NDArray[np.float_] | float, theta: float = 0, inner_potential: float = 10, *, slit_is_vertical: bool = False, -) -> float: +) -> NDArray[np.float_]: """Calculates kz from the phi/beta Euler angles given the experimental geometry.""" if slit_is_vertical: beta_term = -np.sin(theta) * np.sin(phi) + np.cos(theta) * np.cos(beta) * np.cos(phi) diff --git a/arpes/utilities/conversion/fast_interp.py b/arpes/utilities/conversion/fast_interp.py index 4d1d61d0..002cfcd3 100644 --- a/arpes/utilities/conversion/fast_interp.py +++ b/arpes/utilities/conversion/fast_interp.py @@ -233,7 +233,7 @@ def __post_init__(self) -> None: self.data = self.data.astype(np.float64, copy=False) @classmethod - def from_arrays(cls, xyz: list[NDArray[np.float_]], data: NDArray[np.float_]): + def from_arrays(cls: type, xyz: list[NDArray[np.float_]], data: NDArray[np.float_]): """Initializes the interpreter from a coordinate and data array. Args: diff --git a/arpes/utilities/conversion/forward.py b/arpes/utilities/conversion/forward.py index 8b37b314..50e3f6a7 100644 --- a/arpes/utilities/conversion/forward.py +++ b/arpes/utilities/conversion/forward.py @@ -32,7 +32,6 @@ if TYPE_CHECKING: from collections.abc import Callable, Sequence - from _typeshed import Incomplete from numpy.typing import NDArray from arpes._typing import DataType @@ -53,7 +52,7 @@ def convert_coordinate_forward( *, trace: Callable = None, # noqa: RUF013 **k_coords: NDArray[np.float_], -): +) -> dict[str, float]: """Inverse/forward transform for the small angle volumetric k-conversion code. This differs from the other forward transforms here which are exact, @@ -86,7 +85,7 @@ def convert_coordinate_forward( data: The data defining the coordinate offsets and experiment geometry. coords: The coordinates of a point in angle-space to be converted. trace: Used for performance tracing and debugging. - k_coords: + k_coords: Coordinate for k-axis Returns: The location of the desired coordinate in momentum. @@ -150,7 +149,7 @@ def convert_through_angular_pair( # noqa: PLR0913 relative_coords: bool = True, trace: Callable = None, # noqa: RUF013 **k_coords: NDArray[np.float_], -): +) -> dict[str, float]: """Converts the lower dimensional ARPES cut passing through `first_point` and `second_point`. This is a sibling method to `convert_through_angular_point`. A point and a `chi` angle @@ -308,7 +307,6 @@ def convert_coordinates( arr: DataType, *, collapse_parallel: bool = False, - **kwargs: Incomplete, ) -> xr.Dataset: """Converts coordinates forward in momentum.""" @@ -390,17 +388,21 @@ def expand_to(cname, c): @update_provenance("Forward convert coordinates to momentum") -def convert_coordinates_to_kspace_forward(arr: DataType, **kwargs: Incomplete): - """Forward converts all the individual coordinates of the data array.""" +def convert_coordinates_to_kspace_forward(arr: DataType) -> xr.Dataset | None: + """Forward converts all the individual coordinates of the data array. + + Args: + arr: [TODO:description] + """ arr = arr.copy(deep=True) skip = {"eV", "cycle", "delay", "T"} keep = { "eV", } - all = {k: v for k, v in arr.indexes.items() if k not in skip} + all_indexes = {k: v for k, v in arr.indexes.items() if k not in skip} kept = {k: v for k, v in arr.indexes.items() if k in keep} - momentum_compatibles: list[str] = list(all.keys()) + momentum_compatibles: list[str] = list(all_indexes.keys()) momentum_compatibles.sort() if not momentum_compatibles: return None @@ -416,7 +418,7 @@ def convert_coordinates_to_kspace_forward(arr: DataType, **kwargs: Incomplete): ("hv", "phi", "theta"): ["kx", "ky", "kz"], ("hv", "phi", "psi"): ["kx", "ky", "kz"], ("chi", "hv", "phi"): ["kx", "ky", "kz"], - }.get(tuple(momentum_compatibles)) + }.get(tuple(momentum_compatibles), []) full_old_dims: list[str] = momentum_compatibles + list(kept.keys()) projection_vectors: NDArray[np.float_] = np.ndarray( shape=tuple(len(arr.coords[d]) for d in full_old_dims), @@ -523,7 +525,6 @@ def broadcast_by_dim_location( # for now we are setting the theta angle to zero, this only has an effect for # vertical slit analyzers, and then only when the tilt angle is very large - # TODO: check me raw_translated = { "kx": euler_to_kx( kinetic_energy, diff --git a/arpes/utilities/funcutils.py b/arpes/utilities/funcutils.py index 816f3ce4..d90c62a0 100644 --- a/arpes/utilities/funcutils.py +++ b/arpes/utilities/funcutils.py @@ -9,10 +9,12 @@ import xarray as xr if TYPE_CHECKING: - from collections.abc import Callable, Generator, Iterator + from collections.abc import Callable, Generator, Iterator, Sequence + import numpy as np from _typeshed import Incomplete from numpy import ndarray + from numpy._typing import NDArray from arpes._typing import DataType @@ -25,26 +27,28 @@ ] -def cycle(sequence) -> Generator: +def cycle(sequence: Sequence) -> Generator: """Infinitely cycles a sequence.""" while True: yield from sequence -def group_by(grouping, sequence): +def group_by(grouping: int | Generator, sequence: Sequence) -> list: """Permits partitining a sequence into sets of items, for instance by taking two at a time.""" if isinstance(grouping, int): base_seq = [False] * grouping base_seq[-1] = True - grouping = cycle(base_seq) + grouping_cycle = cycle(base_seq) + else: + grouping_cycle = grouping groups = [] current_group = [] for elem in sequence: current_group.append(elem) - if (callable(grouping) and grouping(elem)) or next(grouping): + if (callable(grouping_cycle) and grouping_cycle(elem)) or next(grouping_cycle): groups.append(current_group) current_group = [] @@ -71,7 +75,7 @@ def collect_leaves(tree: dict[str, Any], is_leaf: Any = None) -> dict: A dictionary with the leaves and their direct parent key. """ - def reducer(dd: dict, item: tuple[str, ndarray]) -> dict: + def reducer(dd: dict, item: tuple[str, NDArray[np.float_]]) -> dict: dd[item[0]].append(item[1]) return dd diff --git a/arpes/utilities/geometry.py b/arpes/utilities/geometry.py index a9603d87..6b0d23fd 100644 --- a/arpes/utilities/geometry.py +++ b/arpes/utilities/geometry.py @@ -1,5 +1,6 @@ """Geometry and intersection utilities.""" import numpy as np +from numpy.typing import NDArray from scipy.spatial import ConvexHull __all__ = ( @@ -9,18 +10,24 @@ ) -def point_plane_intersection(plane_normal, plane_point, line_a, line_b, epsilon=1e-6): +def point_plane_intersection( + plane_normal: NDArray[np.float_], + plane_point: NDArray[np.float_], + line_a: NDArray[np.float_], + line_b: NDArray[np.float_], + epsilon: float = 1e-6, +) -> NDArray[np.float_] | None: """Determines the point plane intersection. The plane is defined by a point and a normal vector while the line is defined by line_a and line_b. All should be numpy arrays. Args: - plane_normal - plane_point - line_a - line_b - epsilon + plane_normal: The normal vector of the plane. + plane_point: The point in the plane. + line_a: The line A. + line_b: The line B. + epsilon: Precision of the line difference. Returns: The intersection point of the point and plane. @@ -35,7 +42,14 @@ def point_plane_intersection(plane_normal, plane_point, line_a, line_b, epsilon= return delta + projection * line_direction + plane_point -def segment_contains_point(line_a, line_b, point_along_line, check=False, epsilon=1e-6): +def segment_contains_point( + line_a: NDArray[np.float_], + line_b: NDArray[np.float_], + point_along_line: NDArray[np.float_] | None, + epsilon: float = 1e-6, + *, + check: bool = False, +) -> bool: """Determines whether a segment contains a point that also lies along the line. If asked to check, it will also return false if the point does not lie along the line. @@ -53,7 +67,7 @@ def segment_contains_point(line_a, line_b, point_along_line, check=False, epsilo return 0 - epsilon < delta.dot(delta_p) / delta.dot(delta) < 1 + epsilon -def polyhedron_intersect_plane(poly_faces, plane_normal, plane_point, epsilon=1e-6): +def polyhedron_intersect_plane(poly_faces, plane_normal, plane_point, epsilon: float = 1e-6): """Determines the intersection of a convex polyhedron intersecting a plane. The polyhedron faces should be given by a list of np.arrays, where each np.array at diff --git a/arpes/workflow.py b/arpes/workflow.py index e417aedc..b6f75c53 100644 --- a/arpes/workflow.py +++ b/arpes/workflow.py @@ -82,7 +82,7 @@ def _open_path(p: Path | str) -> None: @with_workspace -def go_to_workspace(workspace=None): +def go_to_workspace(workspace=None) -> None: """Opens the workspace folder, otherwise opens the location of the running notebook.""" path = Path.cwd() @@ -101,7 +101,7 @@ def go_to_cwd() -> None: _open_path(Path.cwd()) -def go_to_figures(): +def go_to_figures() -> None: """Opens the figures folder. If in a workspace, opens the figures folder for the current workspace and the current day, @@ -114,7 +114,7 @@ def go_to_figures(): _open_path(path) -def get_running_context(): +def get_running_context() -> tuple[Incomplete, Path]: return get_notebook_name(), Path.cwd() diff --git a/arpes/xarray_extensions.py b/arpes/xarray_extensions.py index 99f0f081..245ae8b1 100644 --- a/arpes/xarray_extensions.py +++ b/arpes/xarray_extensions.py @@ -66,6 +66,7 @@ reference_scan_fermi_surface, ) from arpes.plotting.parameter import plot_parameter +from arpes.plotting.spin import spin_polarized_spectrum from arpes.plotting.utils import fancy_labels, remove_colorbars from arpes.utilities import apply_dataarray from arpes.utilities.collections import MappableDict @@ -341,13 +342,13 @@ def transpose_to_back(self, dim: str): def select_around_data( self, points: dict[str, Any] | xr.Dataset, - radius: dict[str, float] | None = None, # radius={"phi": 0.005} + radius: dict[str, float] | float | None = None, # radius={"phi": 0.005} *, fast: bool = False, safe: bool = True, mode: Literal["sum", "mean"] = "sum", **kwargs: Incomplete, - ): + ) -> xr.DataArray: """Performs a binned selection around a point or points. Can be used to perform a selection along one axis as a function of another, integrating a @@ -3285,7 +3286,7 @@ def polarization_plot(self, **kwargs: IncompleteMPL) -> Axes: if out is not None and isinstance(out, bool): out = f"{self.label}_spin_polarization.png" kwargs["out"] = out - return plotting.spin.spin_polarized_spectrum(self._obj, **kwargs) + return spin_polarized_spectrum(self._obj, **kwargs) @property def is_spatial(self) -> bool: @@ -3618,7 +3619,7 @@ def _radian_to_degree(self) -> None: if angle in spectrum.coords: spectrum.coords[angle] = np.deg2rad(spectrum.coords[angle]) - def __init__(self, xarray_obj: xr.Dataset) -> None: + def __init__(self, xarray_obj: xr.DataArray) -> None: """Initialization hook for xarray. This should never need to be called directly. diff --git a/tests/test_basic_data_loading.py b/tests/test_basic_data_loading.py index 3111372a..4cdebe77 100644 --- a/tests/test_basic_data_loading.py +++ b/tests/test_basic_data_loading.py @@ -37,7 +37,7 @@ class TestMetadata: data = None - scenarios = [ + scenarios: ClassVar = [ # Lanzara Group "Main Chamber" ( "main_chamber_load_cut",