From 2c10d580732a5f3f1c8cb10ef9abd27cf19ce1a9 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Fri, 29 Sep 2023 17:28:32 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=AC=20=20update=20type=20annotation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- arpes/plotting/dispersion.py | 14 ++++++++++---- arpes/plotting/fermi_edge.py | 5 +++-- arpes/plotting/spatial.py | 2 -- arpes/xarray_extensions.py | 31 ++++++++++++++++++------------- 4 files changed, 31 insertions(+), 21 deletions(-) diff --git a/arpes/plotting/dispersion.py b/arpes/plotting/dispersion.py index 0bc228a0..eee29c34 100644 --- a/arpes/plotting/dispersion.py +++ b/arpes/plotting/dispersion.py @@ -3,7 +3,7 @@ import warnings from collections import defaultdict -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, TypedDict, Unpack import matplotlib.pyplot as plt import numpy as np @@ -290,7 +290,7 @@ def hv_reference_scan( out: str | Path = "", e_cut: float = -0.05, bkg_subtraction: float = 0.8, - **kwargs: Incomplete, + **kwargs: Unpack[LabeledFermiSurfaceParam], ) -> Path | None: """A reference plot for photon energy scans. Used internally by other code.""" fs = data.S.fat_sel(eV=e_cut) @@ -342,11 +342,18 @@ def hv_reference_scan( return None +class LabeledFermiSurfaceParam(TypedDict, total=False): + out: str | Path + include_symmetry_points: bool + include_bz: bool + fermi_energy: float + + @save_plot_provenance def reference_scan_fermi_surface( data: DataType, out: str | Path = "", - **kwargs: Incomplete, + **kwargs: Unpack[LabeledFermiSurfaceParam], ) -> Path | None: """A reference plot for Fermi surfaces. Used internally by other code.""" fs = data.S.fermi_surface @@ -386,7 +393,6 @@ def labeled_fermi_surface( include_bz: bool = True, out: str | Path = "", fermi_energy: float = 0, - **kwargs: Incomplete, ) -> Path | None | tuple[Figure, Axes]: """Plots a Fermi surface with high symmetry points annotated onto it.""" fig = None diff --git a/arpes/plotting/fermi_edge.py b/arpes/plotting/fermi_edge.py index 1a7743af..3fca36d7 100644 --- a/arpes/plotting/fermi_edge.py +++ b/arpes/plotting/fermi_edge.py @@ -18,6 +18,7 @@ if TYPE_CHECKING: from pathlib import Path + from matplotlib.colors import Normalize from numpy.typing import NDArray __all__ = ["fermi_edge_reference", "plot_fit"] @@ -109,7 +110,7 @@ def fermi_edge_reference( title: str = "", ax: Axes | None = None, out: str = "", - norm=None, + norm: Normalize | None = None, ) -> Path | None: """Fits for and plots results for the Fermi edge on a piece of data. @@ -118,7 +119,7 @@ def fermi_edge_reference( title: A title to attach to the plot ax: The axes to plot to, if not specified will be generated out: Where to save the plot - norm ([TODO:type]): [TODO:description] + norm (matplotlib.colors.Normalize): [TODO:description] Returns: [TODO:description] diff --git a/arpes/plotting/spatial.py b/arpes/plotting/spatial.py index b6195768..e52f0581 100644 --- a/arpes/plotting/spatial.py +++ b/arpes/plotting/spatial.py @@ -28,7 +28,6 @@ if TYPE_CHECKING: from pathlib import Path - from _typeshed import Incomplete from matplotlib.axes import Axes from matplotlib.figure import Figure from numpy.typing import NDArray @@ -204,7 +203,6 @@ def plot_spatial_reference( def reference_scan_spatial( data: DataType, out: str | Path = "", - **kwargs: Incomplete, ) -> Path | tuple[Figure, NDArray[Axes]]: """Plots the spatial content of a dataset, useful as a quick reference.""" data_arr = normalize_to_spectrum(data) diff --git a/arpes/xarray_extensions.py b/arpes/xarray_extensions.py index ffa9168c..b954dbe0 100644 --- a/arpes/xarray_extensions.py +++ b/arpes/xarray_extensions.py @@ -45,7 +45,7 @@ import warnings from collections import OrderedDict from logging import DEBUG, Formatter, StreamHandler, getLogger -from typing import TYPE_CHECKING, Any, Literal, TypeAlias +from typing import TYPE_CHECKING, Any, Literal, TypeAlias, Unpack import lmfit import matplotlib.pyplot as plt @@ -61,6 +61,7 @@ from arpes.analysis.general import rebin from arpes.models.band import MultifitBand from arpes.plotting.dispersion import ( + LabeledFermiSurfaceParam, fancy_dispersion, hv_reference_scan, labeled_fermi_surface, @@ -84,6 +85,7 @@ from matplotlib.axes import Axes from matplotlib.figure import Figure from matplotlib.typing import RGBColorType + from matplotlibn.colors import Normalize from numpy.typing import DTypeLike, NDArray from arpes._typing import ANGLE, SPECTROMETER, DataType @@ -1976,10 +1978,11 @@ def plot( ) -> Incomplete: """Utility delegate to `xr.DataArray.plot` which rasterizes`. - [TODO:description] - Args: - rasterized: [TODO:description] + rasterized (bool): if True, rasterized (Not vector) drawing + *args: Pass to xr.DataArray.plot + *kwargs: Pass to xr.DataArray.plot + """ object_is_two_dimensional = 2 if len(self._obj.dims) == object_is_two_dimensional and "rasterized" not in kwargs: @@ -2023,12 +2026,10 @@ def fs_plot( def fermi_edge_reference_plot( self, pattern: str = "{}.png", - **kwargs: Incomplete, + **kwargs: str | Normalize | None, ) -> Path | None: """Provides a reference plot for a Fermi edge reference. - [TODO:description] - Args: pattern ([TODO:type]): [TODO:description] kwargs: pass to plotting.fermi_edge.fermi_edge_reference @@ -2048,8 +2049,8 @@ def _referenced_scans_for_spatial_plot( *, use_id: bool = True, pattern: str = "{}.png", - **kwargs: Incomplete, - ): + **kwargs: str, + ) -> Path | tuple[Figure, NDArray[Axes]]: """[TODO:summary]. [TODO:description] @@ -2065,15 +2066,15 @@ def _referenced_scans_for_spatial_plot( out = pattern.format(f"{label}_reference_scan_fs") kwargs["out"] = out - return plotting.reference_scan_spatial(self._obj, **kwargs) + return plotting.spatial.reference_scan_spatial(self._obj, **kwargs) def _referenced_scans_for_map_plot( self, pattern: str = "{}.png", *, use_id: bool = True, - **kwargs: IncompleteMPL, - ): + **kwargs: Unpack[LabeledFermiSurfaceParam], + ) -> Path | None: out = kwargs.get("out") label = self._obj.attrs["id"] if use_id else self.label if out is not None and isinstance(out, bool): @@ -2082,13 +2083,17 @@ def _referenced_scans_for_map_plot( return reference_scan_fermi_surface(self._obj, **kwargs) + class HvRefScanParam(LabeledFermiSurfaceParam): + e_cut: float + bkg_subtraction: float + def _referenced_scans_for_hv_map_plot( self, pattern: str = "{}.png", *, use_id: bool = True, **kwargs: IncompleteMPL, - ): + ) -> Path | None: out = kwargs.get("out") label = self._obj.attrs["id"] if use_id else self.label if out is not None and isinstance(out, bool):