From 48f4dbb08e25865dfb3e0a4c6c4d3090cf494e69 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Fri, 15 Mar 2024 11:44:53 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=AC=20=20Update=20type=20hints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/arpes/_typing.py | 74 +++++++++++++++++++--- src/arpes/endstations/__init__.py | 7 +- src/arpes/endstations/fits_utils.py | 2 +- src/arpes/endstations/nexus_utils.py | 2 +- src/arpes/plotting/basic_tools/__init__.py | 2 +- src/arpes/plotting/stack_plot.py | 8 +-- src/arpes/plotting/utils.py | 45 +++++++------ src/arpes/utilities/xarray.py | 4 +- 8 files changed, 102 insertions(+), 42 deletions(-) diff --git a/src/arpes/_typing.py b/src/arpes/_typing.py index 552b3e8f..2f78e58a 100644 --- a/src/arpes/_typing.py +++ b/src/arpes/_typing.py @@ -12,14 +12,7 @@ from __future__ import annotations import uuid -from typing import ( - TYPE_CHECKING, - Literal, - Required, - TypeAlias, - TypedDict, - TypeVar, -) +from typing import TYPE_CHECKING, Any, Literal, Required, TypeAlias, TypedDict, TypeVar import xarray as xr @@ -334,6 +327,7 @@ class Spectrometer(AnalyzerInfo, Coordinates, DAQInfo, total=False): mstar: float dof_type: dict[str, list[str]] length: float + probe_linewidth: float class ExperimentInfo( @@ -379,6 +373,70 @@ class QPushButtonArgs(TypedDict, total=False): # +class Line2DProperty(TypedDict, total=False): + agg_filter: Callable[[NDArray[np.float_], int], tuple[NDArray[np.float_], int, int]] + alpha: float | None + animated: bool + antialiased: bool | list[bool] + clip_box: BboxBase | None + clip_on: bool + clip_path: mpl.path.Path | Patch | Transform | None + color: ColorType + c: ColorType + dash_capstyple: CapStyleType + dash_joinstyle: JoinStyleType + dashes: LineStyleType + drawstyle: DrawStyleType + ds: DrawStyleType + figure: Figure + fillstyle: FillStyleType + gapcolor: ColorType | None + gid: str + in_layout: bool + label: Any + linestyle: LineStyleType + ls: LineStyleType + marker: MarkerType + markeredgecolor: ColorType + mec: ColorType + markeredgewidth: float + mew: ColorType + markerfacecloralt: ColorType + mfcalt: ColorType + markersize: float + ms: float + markevery: MarkEveryType + mouseover: bool + path_effects: list[AbstractPathEffect] + picker: float | Callable[[Artist, Event], tuple[bool, dict]] + pickradius: float + rasterized: bool + sketch_params: tuple[float, float, float] + snap: bool | None + solid_capstyle: CapStyleType + solid_joinstyle: JoinStyleType + url: str + visible: bool + zorder: float + + +class PolyCollectionProperty(Line2DProperty, total=False): + array: ArrayLike | None + clim: tuple[float, float] + cmap: Colormap | str | None + edgecolor: ColorType | list[ColorType] + ec: ColorType | list[ColorType] + facecolor: ColorType | list[ColorType] + fc: ColorType | list[ColorType] + hatch: Literal["/", "\\", "|", "-", "+", "x", "o", "O", ".", "*"] + norm: Normalize | str | None + offset_transform: Transform + # offsets: (N, 2) or (2, ) array-likel + sizes: NDArray[np.float_] | None + transform: Transform + urls: list[str] | None + + class MPLPlotKwargsBasic(TypedDict, total=False): """Kwargs for Axes.plot & Axes.fill_between.""" diff --git a/src/arpes/endstations/__init__.py b/src/arpes/endstations/__init__.py index e17baa73..84aa56f6 100644 --- a/src/arpes/endstations/__init__.py +++ b/src/arpes/endstations/__init__.py @@ -176,7 +176,7 @@ def is_file_accepted( return False try: - _ = cls.find_first_file(int(file)) + _ = cls.find_first_file(int(file)) # type: ignore[arg-type] except ValueError: return False return True @@ -382,10 +382,7 @@ def postprocess_final( for k, v in self.MERGE_ATTRS.items(): a_data.attrs.setdefault(k, v) - for a_data in ls: - a_data = _ensure_coords(a_data, self.ENSURE_COORDS_EXIST) - - for a_data in ls: + for a_data in [_ensure_coords(a_data, self.ENSURE_COORDS_EXIST) for a_data in ls]: if "chi" in a_data.coords and "chi_offset" not in a_data.attrs: a_data.attrs["chi_offset"] = a_data.coords["chi"].item() diff --git a/src/arpes/endstations/fits_utils.py b/src/arpes/endstations/fits_utils.py index c9c31d82..f7c11ab9 100644 --- a/src/arpes/endstations/fits_utils.py +++ b/src/arpes/endstations/fits_utils.py @@ -136,7 +136,7 @@ def extract_coords( scan_coords[name] = np.linspace(start, end, n, endpoint=True) else: - logger("Loop is tabulated and is region based") + logger.debug("Loop is tabulated and is region based") name, n = ( attrs[f"NM_{loop}_0"], attrs[f"NMPOS_{loop}"], diff --git a/src/arpes/endstations/nexus_utils.py b/src/arpes/endstations/nexus_utils.py index d0012ab1..26f06542 100644 --- a/src/arpes/endstations/nexus_utils.py +++ b/src/arpes/endstations/nexus_utils.py @@ -17,7 +17,7 @@ import xarray as xr -__all__ = ["read_data_attributes_from"] +__all__ = ("read_data_attributes_from",) LOGLEVELS = (DEBUG, INFO) LOGLEVEL = LOGLEVELS[1] diff --git a/src/arpes/plotting/basic_tools/__init__.py b/src/arpes/plotting/basic_tools/__init__.py index 7db031ee..e646f33a 100644 --- a/src/arpes/plotting/basic_tools/__init__.py +++ b/src/arpes/plotting/basic_tools/__init__.py @@ -90,7 +90,7 @@ def __init__(self) -> None: def layout(self) -> QGridLayout: return self.main_layout - def set_data(self, data: XrTypes) -> None: + def set_data(self, data: xr.DataArray) -> None: self.data = normalize_to_spectrum(data) def transpose_to_front(self, dim: Hashable) -> None: diff --git a/src/arpes/plotting/stack_plot.py b/src/arpes/plotting/stack_plot.py index c1bb1109..836c7636 100644 --- a/src/arpes/plotting/stack_plot.py +++ b/src/arpes/plotting/stack_plot.py @@ -7,7 +7,7 @@ import contextlib from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger -from typing import TYPE_CHECKING, Literal, Unpack, reveal_type +from typing import TYPE_CHECKING, Literal, Unpack import matplotlib as mpl import matplotlib.colorbar @@ -69,9 +69,7 @@ def offset_scatter_plot( data: xr.Dataset, name_to_plot: str = "", stack_axis: str = "", - cbarmap: ( - tuple[Callable[..., colorbar.Colorbar], Callable[..., Callable[..., ColorType]]] | None - ) = None, + cbarmap: tuple[colorbar.Colorbar, Callable[..., ColorType]] | None = None, ax: Axes | None = None, out: str | Path = "", scale_coordinate: float = 0.5, @@ -135,6 +133,8 @@ def offset_scatter_plot( skip_colorbar = True if cbarmap is None: skip_colorbar = False + cbar: colorbar.Colorbar | Callable[..., colorbar.Colorbar] + cmap: Colormap try: cbar, cmap = colorbarmaps_for_axis[stack_axis] except KeyError: diff --git a/src/arpes/plotting/utils.py b/src/arpes/plotting/utils.py index e4b557dc..e9a49ec0 100644 --- a/src/arpes/plotting/utils.py +++ b/src/arpes/plotting/utils.py @@ -15,7 +15,7 @@ from datetime import UTC from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Unpack, reveal_type +from typing import TYPE_CHECKING, Any, Literal, Unpack import matplotlib as mpl import matplotlib.pyplot as plt @@ -698,11 +698,11 @@ def imshow_mask( def imshow_arr( - arr: XrTypes, + arr: xr.DataArray, ax: Axes | None = None, over: AxesImage | None = None, **kwargs: Unpack[IMshowParam], -) -> tuple[Figure, Axes]: +) -> tuple[Figure, AxesImage]: """Similar to plt.imshow but users different default origin, and sets appropriate extents. Args: @@ -729,7 +729,7 @@ def imshow_arr( "extent": (y[0], y[-1], x[0], x[-1]), } for k, v in default_kwargs.items(): - kwargs.setdefault(str(k), v) + kwargs.setdefault(k, v) # type: ignore[misc] if over is None: if kwargs["alpha"] != 1: if isinstance(kwargs["cmap"], str): @@ -755,8 +755,6 @@ def imshow_arr( kwargs["aspect"] = ax.get_aspect() quad = ax.imshow( arr.values, - extent=over.get_extent(), - aspect=ax.get_aspect(), **kwargs, ) @@ -841,7 +839,7 @@ def inset_cut_locator( n = 200 - def resolve(name: str, value: slice | int) -> NDArray[np.float_]: + def resolve(name: Hashable, value: slice | int) -> NDArray[np.float_]: if isinstance(value, slice): low = value.start high = value.stop @@ -881,13 +879,13 @@ def resolve(name: str, value: slice | int) -> NDArray[np.float_]: pass -def generic_colormap(low: float, high: float) -> Callable[..., RGBAColorType]: +def generic_colormap(low: float, high: float) -> Callable[..., ColorType]: """Generates a colormap from the cm.Blues palette, suitable for most purposes.""" delta = high - low low = low - delta / 6 high = high + delta / 6 - def get_color(value: float) -> RGBAColorType: + def get_color(value: float) -> ColorType: return mpl.colormaps.get_cmap("Blues")( float((value - low) / (high - low)), ) @@ -898,10 +896,10 @@ def get_color(value: float) -> RGBAColorType: def phase_angle_colormap( low: float = 0, high: float = np.pi * 2, -) -> Callable[[float], RGBAColorType]: +) -> Callable[[float], ColorType]: """Generates a colormap suitable for angular data or data on a unit circle like a phase.""" - def get_color(value: float) -> RGBAColorType: + def get_color(value: float) -> ColorType: return mpl.colormaps.get_cmap("twilight_shifted")(float((value - low) / (high - low))) return get_color @@ -910,10 +908,10 @@ def get_color(value: float) -> RGBAColorType: def delay_colormap( low: float = -1, high: float = 1, -) -> Callable[[float], RGBAColorType]: +) -> Callable[[float], ColorType]: """Generates a colormap suitable for pump-probe delay data.""" - def get_color(value: float) -> RGBAColorType: + def get_color(value: float) -> ColorType: return mpl.colormaps.get_cmap("coolwarm")( float((value - low) / (high - low)), ) @@ -925,10 +923,10 @@ def temperature_colormap( low: float = 0, high: float = 300, cmap: Colormap = mpl.colormaps["Blues_r"], -) -> Callable[[float], RGBAColorType]: +) -> Callable[[float], ColorType]: """Generates a colormap suitable for temperature data with fixed extent.""" - def get_color(value: float) -> RGBAColorType: + def get_color(value: float) -> ColorType: return cmap(float((value - low) / (high - low))) return get_color @@ -937,10 +935,10 @@ def get_color(value: float) -> RGBAColorType: def temperature_colormap_around( central: float, region: float = 50, -) -> Callable[[float], RGBAColorType]: +) -> Callable[[float], ColorType]: """Generates a colormap suitable for temperature data around a central value.""" - def get_color(value: float) -> RGBAColorType: + def get_color(value: float) -> ColorType: return mpl.colormaps.get_cmap("RdBu_r")(float((value - central) / region)) return get_color @@ -1103,10 +1101,13 @@ def remove_colorbars(fig: Figure | None = None) -> None: """ # TODO: after colorbar removal, plots should be relaxed/rescaled to occupy space previously # allocated to colorbars for now, can follow this with plt.tight_layout() + COLORBAR_ASPECT_RATIO = 20 try: if fig is not None: for ax in fig.axes: - if ax.get_aspect() >= 20: # a bit of a hack + aspect_ragio = ax.get_aspect() + assert isinstance(aspect_ragio, float) + if aspect_ragio >= COLORBAR_ASPECT_RATIO: ax.remove() else: remove_colorbars(plt.gcf()) @@ -1190,7 +1191,7 @@ def __init__( size: float = 1, extent: float = 0.03, label: str = "", - loc: int = 2, + loc: str = "uppder left", ax: Axes | None = None, pad: float = 0.4, borderpad: float = 0.5, @@ -1343,7 +1344,11 @@ def extract(for_data: XrTypes) -> dict[str, Any]: ) with Path(provenance_path).open("w") as f: - json.dump(provenance_context, f, indent=2) + json.dump( + provenance_context, + f, + indent=2, + ) plt.savefig(full_path, dpi=dpi, **kwargs) diff --git a/src/arpes/utilities/xarray.py b/src/arpes/utilities/xarray.py index 8265a2df..a7cf0264 100644 --- a/src/arpes/utilities/xarray.py +++ b/src/arpes/utilities/xarray.py @@ -122,12 +122,12 @@ def g( """[TODO:summary]. Args: - arr (xr.DataArray): [TODO:description] + arr (xr.DataArray): ARPES Data *args: Pass to function f **kwargs: Pass to function f Returns: - [TODO:description] + xr.DataArray """ return xr.DataArray( arr.values,