diff --git a/src/arpes/plotting/fermi_surface.py b/src/arpes/plotting/fermi_surface.py index 7109d2d6..a9c7a7ae 100644 --- a/src/arpes/plotting/fermi_surface.py +++ b/src/arpes/plotting/fermi_surface.py @@ -86,8 +86,8 @@ def magnify_circular_regions_plot( # noqa: PLR0913 radius: float = 0.05, # below this two can be treated as kwargs? cmap: Colormap | ColorType = "viridis", - color: ColorType | None = None, - edgecolor: ColorType = "red", + color: ColorType | list[ColorType] = "blue", + edgecolor: ColorType | list[ColorType] = "red", out: str | Path = "", ax: Axes | None = None, **kwargs: tuple[float, float], @@ -122,14 +122,13 @@ def magnify_circular_regions_plot( # noqa: PLR0913 clim = list(mesh.get_clim()) clim[1] = clim[1] / mag - mask = np.zeros(shape=(len(data_arr.values.ravel()),)) pts = np.zeros( shape=( len(data_arr.values.ravel()), 2, ), ) - mask = mask > 0 + mask = np.zeros(shape=len(data_arr.values.ravel())) > 0 raveled = data_arr.G.ravel() pts[:, 0] = raveled[data_arr.dims[0]] @@ -150,7 +149,6 @@ def magnify_circular_regions_plot( # noqa: PLR0913 if not isinstance(color, list): color = [color for _ in range(len(magnified_points))] - assert isinstance(color, list) pts[:, 1] = (pts[:, 1]) / (xlim[1] - xlim[0]) pts[:, 0] = (pts[:, 0]) / (ylim[1] - ylim[0]) @@ -168,7 +166,7 @@ def magnify_circular_regions_plot( # noqa: PLR0913 linewidth=2, zorder=4, ) - patchfake = matplotlib.patches.Ellipse([point[1], point[0]], radius, radius) + patchfake = matplotlib.patches.Ellipse((point[1], point[0]), radius, radius) ax.add_patch(patch) mask = np.logical_or(mask, patchfake.contains_points(pts)) diff --git a/src/arpes/plotting/utils.py b/src/arpes/plotting/utils.py index c9c196cb..6523eeab 100644 --- a/src/arpes/plotting/utils.py +++ b/src/arpes/plotting/utils.py @@ -11,11 +11,11 @@ import re import warnings from collections import Counter -from collections.abc import Generator, Hashable, Iterable, Iterator, Sequence +from collections.abc import Callable, Generator, Hashable, Iterable, Iterator, Sequence from datetime import UTC from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Unpack, reveal_type +from typing import TYPE_CHECKING, Any, Literal, Unpack import matplotlib as mpl import matplotlib.pyplot as plt @@ -37,6 +37,7 @@ if TYPE_CHECKING: from _typeshed import Incomplete from lmfit.model import Model + from matplotlib.collections import PathCollection from matplotlib.font_manager import FontProperties from matplotlib.image import AxesImage from matplotlib.typing import ColorType @@ -576,7 +577,7 @@ def lineplot_arr( xs = None if arr is not None: - fn = plt.plot + fn: Callable[..., list[Line2D]] | Callable[..., PathCollection] = plt.plot if method == "scatter": fn = plt.scatter @@ -613,7 +614,7 @@ def plot_arr( if n_dims == TWO_DIMENSION: quad = None if arr is not None: - ax, quad = imshow_arr(arr, ax=ax, over=over, **kwargs) + fig, quad = imshow_arr(arr, ax=ax, over=over, **kwargs) if mask is not None: over = quad if over is None else over imshow_mask(mask, ax=ax, over=over, **kwargs) diff --git a/src/arpes/utilities/selections.py b/src/arpes/utilities/selections.py index 4dec5dd6..82c6e84a 100644 --- a/src/arpes/utilities/selections.py +++ b/src/arpes/utilities/selections.py @@ -40,7 +40,7 @@ def ravel_from_mask(data: DataType, mask: XrTypes) -> DataType: Returns: Raveled data with masked points removed. """ - return data.stack(stacked=mask.dims).where(mask.stack(stacked=mask.dims), drop=True) + return data.stack(stacked=list(mask.dims)).where(mask.stack(stacked=list(mask.dims)), drop=True) def unravel_from_mask( @@ -65,27 +65,30 @@ def unravel_from_mask( dest = template * 0 + 1 dest_mask = np.logical_not( np.isnan( - template.stack(stacked=template.dims).where(mask.stack(stacked=template.dims)).values, + template.stack(stacked=list(template.dims)) + .where(mask.stack(stacked=list(template.dims))) + .values, ), ) - dest = (dest * default).stack(stacked=template.dims) + dest = (dest * default).stack(stacked=list(template.dims)) dest.values[dest_mask] = values return dest.unstack("stacked") def _normalize_point( data: xr.DataArray, - around: dict[str, xr.DataArray] | xr.Dataset, - **kwargs: Incomplete, + around: dict[str, xr.DataArray] | xr.Dataset | None, + **kwargs: NDArray[np.float_] | float, ) -> dict[str, xr.DataArray]: - collected_kwargs = {k: kwargs[k] for k in data.dims if k in kwargs} + collected_kwargs = {k: kwargs[str(k)] for k in data.dims if k in kwargs} if around: if isinstance(around, xr.Dataset): - around = unwrap_xarray_dict({d: around[d] for d in data.dims}) + around = unwrap_xarray_dict({str(d): around[d] for d in data.dims}) else: around = collected_kwargs + assert isinstance(around, dict) assert set(around.keys()) == set(data.dims) return around diff --git a/src/arpes/utilities/xarray.py b/src/arpes/utilities/xarray.py index 96f47f1a..0ab34703 100644 --- a/src/arpes/utilities/xarray.py +++ b/src/arpes/utilities/xarray.py @@ -64,7 +64,7 @@ def unwrap_xarray_item(item: xr.DataArray) -> xr.DataArray | float: def unwrap_xarray_dict( input_dict: dict[str, xr.DataArray], -) -> dict[str, xr.DataArray | NDArray[np.float_] | float]: +) -> dict[str, xr.DataArray | float]: """Returns the attributes as unwrapped values rather than item() instances. Useful for unwrapping coordinate dicts where the values might be a bare type: diff --git a/src/arpes/xarray_extensions.py b/src/arpes/xarray_extensions.py index d0a6215a..c17c31a8 100644 --- a/src/arpes/xarray_extensions.py +++ b/src/arpes/xarray_extensions.py @@ -115,7 +115,6 @@ PColorMeshKwargs, SampleInfo, ScanInfo, - Spectrometer, XrTypes, ) from .provenance import Provenance