From 53009ee147b7e571e15a57a49772d3459027c84e Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Tue, 20 Feb 2024 15:05:28 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=AC=20=20Update=20type=20hints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../local_config.example.py | 1 - src/arpes/_typing.py | 7 +-- src/arpes/analysis/band_analysis.py | 2 +- src/arpes/analysis/deconvolution.py | 8 ++- src/arpes/analysis/savitzky_golay.py | 4 +- src/arpes/analysis/xps.py | 1 - src/arpes/plotting/fermi_edge.py | 24 ++++---- src/arpes/plotting/stack_plot.py | 12 ++-- src/arpes/plotting/tof.py | 33 ++++++---- src/arpes/plotting/utils.py | 8 +-- src/arpes/provenance.py | 2 + src/arpes/utilities/collections.py | 24 ++++---- src/arpes/utilities/jupyter.py | 33 +++++++++- src/arpes/utilities/xarray.py | 2 +- src/arpes/xarray_extensions.py | 60 +++++++++---------- 15 files changed, 132 insertions(+), 89 deletions(-) diff --git a/resources/example_configuration/local_config.example.py b/resources/example_configuration/local_config.example.py index 3e70bc33..2f396cce 100644 --- a/resources/example_configuration/local_config.example.py +++ b/resources/example_configuration/local_config.example.py @@ -6,6 +6,5 @@ "marginal_width": 300, "palette": "magma", }, - "xarray_repr_mod": True, "DEBUG": True, } diff --git a/src/arpes/_typing.py b/src/arpes/_typing.py index 2778b4a8..ddf9067e 100644 --- a/src/arpes/_typing.py +++ b/src/arpes/_typing.py @@ -371,7 +371,7 @@ class QPushButtonARGS(TypedDict, total=False): # -class MPLPlotKwagsBasic(TypedDict, total=False): +class MPLPlotKwargsBasic(TypedDict, total=False): """Kwargs for Axes.plot & Axes.fill_between.""" agg_filter: Callable[[NDArray[np.float_], int], tuple[NDArray[np.float_], int, int]] @@ -402,10 +402,10 @@ class MPLPlotKwagsBasic(TypedDict, total=False): visible: bool -class MPLPlotKwargs(MPLPlotKwagsBasic, total=False): +class MPLPlotKwargs(MPLPlotKwargsBasic, total=False): scalex: bool scaley: bool - + fmt: str dash_capstyle: CapStyleType dash_joinstyle: JoinStyleType dashes: Sequence[float | None] @@ -440,7 +440,6 @@ class ColorbarParam(TypedDict, total=False): alpha: float orientation: None | Literal["vertical", "horizontal"] ticklocation: Literal["auto", "right", "top", "bottom"] - drawedge: bool extend: Literal["neither", "both", "min", "max"] extendfrac: None | Literal["auto"] | float | tuple[float, float] | list[float] spacing: Literal["uniform", "proportional"] diff --git a/src/arpes/analysis/band_analysis.py b/src/arpes/analysis/band_analysis.py index 1bce88f4..d712ba6e 100644 --- a/src/arpes/analysis/band_analysis.py +++ b/src/arpes/analysis/band_analysis.py @@ -327,7 +327,7 @@ def resolve_partial_bands_from_description( "band": band, "name": f"{name}_{i}", "params": _build_params( - old_params=params, + params=params, center=band_center, center_stray=params.get("stray", stray), marginal=marginal, diff --git a/src/arpes/analysis/deconvolution.py b/src/arpes/analysis/deconvolution.py index 5a94658b..8d793b2b 100644 --- a/src/arpes/analysis/deconvolution.py +++ b/src/arpes/analysis/deconvolution.py @@ -139,6 +139,7 @@ def make_psf( sigmas: dict[Hashable, float], *, fwhm: bool = True, + clip: float | None = None, ) -> xr.DataArray: """Produces an n-dimensional gaussian point spread function for use in deconvolve_rl. @@ -146,6 +147,7 @@ def make_psf( data (DataType): input data sigmas (dict[str, float]): sigma values for each dimension. fwhm (bool): if True, sigma is FWHM, not the standard deviation. + clip (float | bool): clip the region by sigma-unit. Returns: The PSF to use. @@ -184,7 +186,7 @@ def make_psf( coords_for_pdf_pos = np.stack(coords, axis=-1) # point distribution function (pdf) logger.debug(f"shape of coords_for_pdf_pos: {coords_for_pdf_pos.shape}") - return xr.DataArray( + psf = xr.DataArray( multivariate_normal(mean=np.zeros(len(sigmas)), cov=cov).pdf( coords_for_pdf_pos, ), @@ -192,3 +194,7 @@ def make_psf( coords=psf_coords, name="PSF", ) + if clip: + clipping_region = {k: slice(-clip * v, clip * v) for k, v in sigmas.items()} + return psf.sel(clipping_region) + return psf diff --git a/src/arpes/analysis/savitzky_golay.py b/src/arpes/analysis/savitzky_golay.py index 4bc51466..e6f28e8f 100644 --- a/src/arpes/analysis/savitzky_golay.py +++ b/src/arpes/analysis/savitzky_golay.py @@ -294,7 +294,9 @@ def savitzky_golay_array( half_window = (window_size - 1) // 2 # precompute coefficients b = np.mat([[k**i for i in order_range] for k in range(-half_window, half_window + 1)]) - m = np.linalg.pinv(b).A[deriv] * rate**deriv * factorial(deriv) + b_inv = np.linalg.pinv(b) + assert isinstance(b_inv, np.matrix) + m = b_inv.A[deriv] * rate**deriv * factorial(deriv) # pad the signal at the extremes with # values taken from the signal itself firstvals = y[0] - np.abs(y[1 : half_window + 1][::-1] - y[0]) diff --git a/src/arpes/analysis/xps.py b/src/arpes/analysis/xps.py index 6f61db20..9475054e 100644 --- a/src/arpes/analysis/xps.py +++ b/src/arpes/analysis/xps.py @@ -38,7 +38,6 @@ def local_minima(a: NDArray[np.float_], promenance: int = 3) -> NDArray[np.float Returns: A mask where the local minima are True and other values are False. """ - conditions = a == a for i in range(1, promenance + 1): current_conditions = np.r_[[False] * i, a[i:] < a[:-i]] & np.r_[a[:-i] < a[i:], [False] * i] conditions = conditions & current_conditions diff --git a/src/arpes/plotting/fermi_edge.py b/src/arpes/plotting/fermi_edge.py index c22b5ccd..02e77643 100644 --- a/src/arpes/plotting/fermi_edge.py +++ b/src/arpes/plotting/fermi_edge.py @@ -3,7 +3,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Unpack import matplotlib.pyplot as plt import numpy as np @@ -19,9 +19,10 @@ if TYPE_CHECKING: from pathlib import Path - from _typeshed import Incomplete from numpy.typing import NDArray + from arpes._typing import MPLPlotKwargs + __all__ = ["fermi_edge_reference", "plot_fit"] @@ -107,16 +108,16 @@ def plot_fit( @save_plot_provenance def fermi_edge_reference( - data: xr.DataArray, + data_arr: xr.DataArray, title: str = "", ax: Axes | None = None, out: str | Path = "", - **kwargs: Incomplete, + **kwargs: Unpack[MPLPlotKwargs], ) -> Path | Axes: """Fits for and plots results for the Fermi edge on a piece of data. Args: - data: The data, this should be of type DataArray + data_arr: The data, this should be of type DataArray title: A title to attach to the plot ax: The axes to plot to, if not specified will be generated out: Where to save the plot @@ -129,10 +130,10 @@ def fermi_edge_reference( "Not automatically correcting for slit shape distortions to the Fermi edge", stacklevel=2, ) - assert isinstance(data, xr.DataArray) + assert isinstance(data_arr, xr.DataArray) sum_dimensions: set[str] = {"cycle", "phi", "kp", "kx"} - sum_dimensions.intersection_update(set(data.dims)) - summed_data = data.sum(*list(sum_dimensions)) + sum_dimensions.intersection_update(set(data_arr.dims)) + summed_data = data_arr.sum(*list(sum_dimensions)) broadcast_dimensions = [str(d) for d in summed_data.dims if str(d) != "eV"] msg = f"Could not product fermi edge reference. Too many dimensions: {broadcast_dimensions}" @@ -155,18 +156,17 @@ def fermi_edge_reference( _, ax = plt.subplots(figsize=(8, 5)) if not title: - title = data.S.label.replace("_", " ") + title = data_arr.S.label.replace("_", " ") centers.plot(ax=ax, **kwargs) widths.plot(ax=ax, **kwargs) if isinstance(ax, Axes): - ax.set_xlabel(label_for_dim(data, ax.get_xlabel())) - ax.set_ylabel(label_for_dim(data, ax.get_ylabel())) + ax.set_xlabel(label_for_dim(data_arr, ax.get_xlabel())) + ax.set_ylabel(label_for_dim(data_arr, ax.get_ylabel())) ax.set_title(title, font_size=14) if out: plt.savefig(path_for_plot(out), dpi=400) return path_for_plot(out) - return ax diff --git a/src/arpes/plotting/stack_plot.py b/src/arpes/plotting/stack_plot.py index 9eb3c6aa..8018b021 100644 --- a/src/arpes/plotting/stack_plot.py +++ b/src/arpes/plotting/stack_plot.py @@ -39,10 +39,10 @@ from pathlib import Path from matplotlib.figure import Figure - from matplotlib.typing import ColorType, RGBAColorType + from matplotlib.typing import ColorType, RGBAColorType, RGBColorType from numpy.typing import NDArray - from arpes._typing import LegendLocation, MPLPlotKwagsBasic, XrTypes + from arpes._typing import LegendLocation, MPLPlotKwargsBasic __all__ = ( "stack_dispersion_plot", "flat_stack_plot", @@ -217,7 +217,7 @@ def flat_stack_plot( # noqa: PLR0913 title: str = "", out: str | Path = "", loc: LegendLocation = "upper left", - **kwargs: Unpack[MPLPlotKwagsBasic], + **kwargs: Unpack[MPLPlotKwargsBasic], ) -> Path | tuple[Figure | None, Axes]: """Generates a stack plot with all the lines distinguished by color rather than offset. @@ -313,7 +313,7 @@ def flat_stack_plot( # noqa: PLR0913 @save_plot_provenance def stack_dispersion_plot( # noqa: PLR0913 - data: XrTypes, + data: xr.DataArray, *, stack_axis: str = "", ax: Axes | None = None, @@ -326,7 +326,7 @@ def stack_dispersion_plot( # noqa: PLR0913 negate: bool = False, figsize: tuple[float, float] = (7, 7), title: str = "", - **kwargs: Unpack[MPLPlotKwagsBasic], + **kwargs: Unpack[MPLPlotKwargsBasic], ) -> Path | tuple[Figure | None, Axes]: """Generates a stack plot with all the lines distinguished by offset (and color). @@ -573,7 +573,7 @@ def _color_for_plot( color: Colormap | ColorType, i: int, num_plot: int, -) -> RGBAColorType: +) -> RGBAColorType | RGBColorType: if isinstance(color, Colormap): cmap = color return cmap(np.abs(i / num_plot)) diff --git a/src/arpes/plotting/tof.py b/src/arpes/plotting/tof.py index b989743c..c1e8fd84 100644 --- a/src/arpes/plotting/tof.py +++ b/src/arpes/plotting/tof.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Unpack import matplotlib.pyplot as plt import numpy as np @@ -27,7 +27,7 @@ import xarray as xr from matplotlib.figure import Figure - from arpes._typing import DataType + from arpes._typing import MPLPlotKwargs, MPLPlotKwargsBasic __all__ = ( "plot_with_std", @@ -37,37 +37,41 @@ @save_plot_provenance def plot_with_std( - data: DataType, + data_set: xr.Dataset, # dat_vars is used, name_to_plot: str = "", ax: Axes | None = None, out: str | Path = "", - **kwargs: tuple[int, int] | float | str, + figsize: tuple[float, float] = (7, 5), + **kwargs: Unpack[MPLPlotKwargs], ) -> Path | tuple[Figure | None, Axes]: """Makes a fill-between line plot with error bars from associated statistical errors. Args: - data(xr.Dataset): ARPES data that 'mean_and_deviation' is applied. + data_set (xr.Dataset): ARPES data that 'mean_and_deviation' is applied. name_to_plot(str): data name to plot, in most case "spectrum" is used. ax: Matplotlib Axes object out: (str | Path): Path name to output figure. + figsize (tuple[float, float]): figure size **kwargs: pass to subplots if figsize is set as tuple, other kwargs are pass to ax.fill_between/xr.DataArray.plot """ if not name_to_plot: - var_names = [k for k in data.data_vars if "_std" not in str(k)] + var_names = [k for k in data_set.data_vars if "_std" not in str(k)] assert len(var_names) == 1 name_to_plot = str(var_names[0]) - assert (name_to_plot + "_std") in data.data_vars, "Has 'mean_and_deviation' been applied?" + assert ( + name_to_plot + "_std" + ) in data_set.data_vars, "Has 'mean_and_deviation' been applied?" fig: Figure | None = None if ax is None: - fig, ax = plt.subplots(figsize=kwargs.pop("figsize", (7, 5))) + fig, ax = plt.subplots(figsize=figsize) assert isinstance(ax, Axes) - data.data_vars[name_to_plot].plot(ax=ax, **kwargs) - x, y = data.data_vars[name_to_plot].G.to_arrays() + data_set.data_vars[name_to_plot].plot(ax=ax, **kwargs) + x, y = data_set.data_vars[name_to_plot].G.to_arrays() - std = data.data_vars[name_to_plot + "_std"].values + std = data_set.data_vars[name_to_plot + "_std"].values ax.fill_between(x, y - std, y + std, alpha=0.3, **kwargs) if out: @@ -85,7 +89,8 @@ def scatter_with_std( name_to_plot: str = "", ax: Axes | None = None, out: str | Path = "", - **kwargs: tuple[int, int] | float | str, + figsize: tuple[float, float] = (7, 5), + **kwargs: Unpack[MPLPlotKwargsBasic], ) -> Path | tuple[Figure | None, Axes]: """Makes a scatter plot of data with error bars generated from associated statistical errors. @@ -94,6 +99,8 @@ def scatter_with_std( name_to_plot(str): data name to plot, in most case "spectrum" is used. ax: Matplotlib Axes object out: (str | Path): Path name to output figure. + figsize (tuple[float, float]): tuple for figure size. + fmt (str): THe form at for the data points/lines. **kwargs: pass to subplots if figsize is set as tuple, other kwargs are pass to ax.errorbar """ if not name_to_plot: @@ -106,7 +113,7 @@ def scatter_with_std( fig: Figure | None = None if ax is None: - fig, ax = plt.subplots(figsize=kwargs.pop("figsize", (7, 5))) + fig, ax = plt.subplots(figsize=figsize) assert isinstance(ax, Axes) x, y = data.data_vars[name_to_plot].G.to_arrays() diff --git a/src/arpes/plotting/utils.py b/src/arpes/plotting/utils.py index 16c79782..32e7dada 100644 --- a/src/arpes/plotting/utils.py +++ b/src/arpes/plotting/utils.py @@ -995,7 +995,7 @@ def temperature_colorbar( low: float = 0, high: float = 300, ax: Axes | None = None, - **kwargs: Incomplete, + **kwargs: Unpack[ColorbarParam], ) -> colorbar.Colorbar: """Generates a colorbar suitable for temperature data with fixed extent.""" assert isinstance(ax, Axes) @@ -1017,7 +1017,7 @@ def delay_colorbar( low: float = -1, high: float = 1, ax: Axes | None = None, - **kwargs: Incomplete, + **kwargs: Unpack[ColorbarParam], ) -> colorbar.Colorbar: assert isinstance(ax, Axes) """Generates a colorbar suitable for delay data. @@ -1037,7 +1037,7 @@ def temperature_colorbar_around( central: float, temperature_range: float = 50, ax: Axes | None = None, - **kwargs: Incomplete, + **kwargs: Unpack[ColorbarParam], ) -> colorbar.Colorbar: """Generates a colorbar suitable for temperature axes around a central value.""" assert isinstance(ax, Axes) @@ -1122,7 +1122,7 @@ def generic_colorbarmap_for_data( ax: Axes, *, keep_ticks: bool = True, - **kwargs: Incomplete, + **kwargs: Unpack[ColorbarParam], ) -> tuple[colorbar.Colorbar, Callable[..., RGBAColorType]]: """Generates a colorbar and colormap which is useful in general context. diff --git a/src/arpes/provenance.py b/src/arpes/provenance.py index 4166fbca..32241b19 100644 --- a/src/arpes/provenance.py +++ b/src/arpes/provenance.py @@ -77,6 +77,8 @@ class PROVENANCE(TypedDict, total=False): # old_axis: str new_axis: str + # + occupation_ratio: float def attach_id(data: XrTypes) -> None: diff --git a/src/arpes/utilities/collections.py b/src/arpes/utilities/collections.py index 1fdf5053..7a5e6a0f 100644 --- a/src/arpes/utilities/collections.py +++ b/src/arpes/utilities/collections.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Mapping, Sequence -from typing import Any, TypeVar +from typing import TypeVar __all__ = ( "deep_equals", @@ -13,7 +13,7 @@ T = TypeVar("T") -def deep_update(destination: dict[str, Any], source: dict[str, Any]) -> dict[str, Any]: +def deep_update(destination: dict[str, T], source: dict[str, T]) -> dict[str, T]: """Doesn't clobber keys further down trees like doing a shallow update would. Instead recurse down from the root and update as appropriate. @@ -51,14 +51,18 @@ def deep_equals( return all(deep_equals(item_a, item_b) for item_a, item_b in zip(a, b, strict=True)) if isinstance(a, Mapping) and isinstance(b, Mapping): - if set(a.keys()) != set(b.keys()): - return False + return _deep_equals_dict(a, b) + raise TypeError - for k in a: - item_a, item_b = a[k], b[k] - if not deep_equals(item_a, item_b): - return False +def _deep_equals_dict(a: Mapping, b: Mapping) -> bool: + if set(a.keys()) != set(b.keys()): + return False - return True - raise TypeError + for k in a: + item_a, item_b = a[k], b[k] + + if not deep_equals(item_a, item_b): + return False + + return True diff --git a/src/arpes/utilities/jupyter.py b/src/arpes/utilities/jupyter.py index 913a20ae..f2363c61 100644 --- a/src/arpes/utilities/jupyter.py +++ b/src/arpes/utilities/jupyter.py @@ -10,7 +10,7 @@ from datetime import UTC from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypedDict from tqdm.notebook import tqdm @@ -53,7 +53,34 @@ def wrap_tqdm( return tqdm(x, *args, **kwargs) -def get_full_notebook_information() -> dict[str, dict[str, str | int | bool]] | None: +class ServerInfo(TypedDict, total=False): + base_url: str + password: bool + pid: int + port: int + root_dir: str + secure: bool + sock: str + token: str + url: str + version: str + + +class SessionInfo(TypedDict, total=False): + id: str + path: str + name: str + type: str + kernel: dict[str, str | int] + notebook: dict[str, str] + + +class NoteBookInfomation(TypedDict, total=False): + server: ServerInfo + session: SessionInfo + + +def get_full_notebook_information() -> NoteBookInfomation | None: """Javascriptless method to fetch current notebook sessions and the one matching this kernel. Returns: @@ -113,7 +140,7 @@ def get_notebook_name() -> str | None: """ jupyter_info = get_full_notebook_information() if jupyter_info: - return jupyter_info["session"]["notebook"]["name"].split(".")[0] + return Path(jupyter_info["session"]["notebook"]["name"]).stem return None diff --git a/src/arpes/utilities/xarray.py b/src/arpes/utilities/xarray.py index f65ee522..8265a2df 100644 --- a/src/arpes/utilities/xarray.py +++ b/src/arpes/utilities/xarray.py @@ -158,7 +158,7 @@ def g( data: DataType, *args: Incomplete, **kwargs: Incomplete, - ) -> xr.Dataset: + ) -> DataType: """[TODO:summary]. Args: diff --git a/src/arpes/xarray_extensions.py b/src/arpes/xarray_extensions.py index 994d9be7..6e7968e6 100644 --- a/src/arpes/xarray_extensions.py +++ b/src/arpes/xarray_extensions.py @@ -97,10 +97,10 @@ from _typeshed import Incomplete from matplotlib import animation from matplotlib.axes import Axes - from matplotlib.colors import Normalize from matplotlib.figure import Figure from matplotlib.typing import RGBColorType from numpy.typing import DTypeLike, NDArray + from xarray.core.coordinates import DataArrayCoordinates, DatasetCoordinates from ._typing import ( ANALYZERINFO, @@ -178,8 +178,6 @@ class ARPESAccessorBase: """Base class for the xarray extensions in PyARPES.""" class _SliceAlongPathKwags(TypedDict, total=False): - arr: xr.DataArray - interpolation_points: NDArray[np.float_] | None axis_name: str resolution: float n_points: int | None @@ -196,7 +194,7 @@ def along( ToDo: Test """ assert isinstance(self._obj, xr.DataArray) - return slice_along_path(self._obj, directions, **kwargs) + return slice_along_path(self._obj, interpolation_points=directions, **kwargs) def find(self, name: str) -> list[str]: """Return the property names containing the "name". @@ -1353,7 +1351,9 @@ def sample_angles( ) @property - def full_coords(self) -> dict[str, float | xr.DataArray]: + def full_coords( + self, + ) -> dict[str, float | xr.DataArray | DataArrayCoordinates | DatasetCoordinates]: """[TODO:summary]. Args: @@ -1362,7 +1362,10 @@ def full_coords(self) -> dict[str, float | xr.DataArray]: Returns: [TODO:description] """ - full_coords: dict[str, float | xr.DataArray] = {} + full_coords: dict[ + str, + float | xr.DataArray | DataArrayCoordinates | DatasetCoordinates, + ] = {} full_coords.update(dict(zip(["x", "y", "z"], self.sample_pos, strict=True))) full_coords.update( @@ -1379,7 +1382,6 @@ def full_coords(self) -> dict[str, float | xr.DataArray]: "hv": self.hv, }, ) - full_coords.update(self._obj.coords) return full_coords @@ -1650,7 +1652,7 @@ def dict_to_html(d: Mapping[str, float | str]) -> str: def _repr_html_full_coords( self, - coords: dict[str, float | xr.DataArray], + coords: dict[str, float | xr.DataArray | DataArrayCoordinates | DatasetCoordinates], ) -> str: significant_coords = {} for k, v in coords.items(): @@ -1724,9 +1726,9 @@ def _experimentalinfo_to_dict(conditions: EXPERIMENTINFO) -> dict[str, str]: if isinstance(v, xr.DataArray): min_hv = float(v.min()) max_hv = float(v.max()) - transformed_dict[k] = ( - f" from {min_hv} to {max_hv} eV" - ) + transformed_dict[ + k + ] = f" from {min_hv} to {max_hv} eV" elif isinstance(v, float) and not np.isnan(v): transformed_dict[k] = f"{v} eV" return transformed_dict @@ -1915,23 +1917,23 @@ def fs_plot( def fermi_edge_reference_plot( self: Self, pattern: str = "{}.png", - **kwargs: str | Normalize | None, + out: str | Path = "", + **kwargs: Unpack[MPLPlotKwargs], ) -> Path | Axes: """Provides a reference plot for a Fermi edge reference. Args: pattern ([TODO:type]): [TODO:description] + out (str | Path): Path name for output figure. kwargs: pass to plotting.fermi_edge.fermi_edge_reference Returns: [TODO:description] """ - out = kwargs.get("out") if out is not None and isinstance(out, bool): out = pattern.format(f"{self.label}_fermi_edge_reference") - kwargs["out"] = out assert isinstance(self._obj, xr.DataArray) - return fermi_edge_reference(self._obj, **kwargs) + return fermi_edge_reference(self._obj, out=out, **kwargs) def _referenced_scans_for_spatial_plot( self: Self, @@ -1998,15 +2000,15 @@ def _simple_spectrum_reference_plot( *, use_id: bool = True, pattern: str = "{}.png", - **kwargs: IncompleteMPL, + out: str | Path = "", + **kwargs: Unpack[PColorMeshKwargs], ) -> Axes | Path: - out = kwargs.get("out") + assert isinstance(self._obj, xr.DataArray) label = self._obj.attrs["id"] if use_id else self.label - if out is not None and isinstance(out, bool): + if isinstance(out, bool): out = pattern.format(f"{label}_spectrum_reference") - kwargs["out"] = out - return fancy_dispersion(self._obj, **kwargs) + return fancy_dispersion(self._obj, out=out, **kwargs) def cut_nan_coords(self: Self) -> xr.DataArray: """Selects data where coordinates are not `nan`. @@ -2015,6 +2017,7 @@ def cut_nan_coords(self: Self) -> xr.DataArray: The subset of the data where coordinates are not `nan`. """ slices = {} + assert isinstance(self._obj, xr.DataArray) for cname, cvalue in self._obj.coords.items(): try: end_ind = np.where(np.isnan(cvalue.values))[0][0] @@ -2027,7 +2030,7 @@ def cut_nan_coords(self: Self) -> xr.DataArray: def reference_plot( self, **kwargs: Unpack[LabeledFermiSurfaceParam] | Unpack[PColorMeshKwargs], - ) -> Axes | Path: + ) -> Axes | Path | tuple[Figure, NDArray[np.object_]]: """Generates a reference plot for this piece of data according to its spectrum type. Args: @@ -2125,6 +2128,7 @@ def corrected_angle_by( "beta", "theta", ) + assert isinstance(self._obj, xr.DataArray) assert angle_for_correction in self._obj.attrs arr: xr.DataArray = self._obj.copy(deep=True) arr.S.correct_angle_by(angle_for_correction) @@ -2383,7 +2387,7 @@ def coordinatize(self, as_coordinate_name: str | None = None) -> XrTypes: if as_coordinate_name is None: as_coordinate_name = str(dim) - o = self._obj.rename(dict([[dim, as_coordinate_name]])).copy(deep=True) + o = self._obj.rename({dim: as_coordinate_name}).copy(deep=True) o.coords[as_coordinate_name] = o.values return o @@ -2476,7 +2480,7 @@ def as_movie( out: str | bool = "", **kwargs: Unpack[PColorMeshKwargs], ) -> Path | animation.FuncAnimation: - assert isinstance(self._obj, xr.DataArray | xr.Dataset) + assert isinstance(self._obj, xr.DataArray) if isinstance(out, bool) and out is True: out = pattern.format(f"{self._obj.S.label}_animation") @@ -2938,10 +2942,10 @@ def __init__(self, xarray_obj: xr.Dataset) -> None: def eval(self, *args: Incomplete, **kwargs: Incomplete) -> xr.DataArray: return self._obj.results.G.map(lambda x: x.eval(*args, **kwargs)) - def show(self, *, detached: bool = False) -> None: + def show(self) -> None: from .plotting.fit_tool import fit_tool - fit_tool(self._obj, detached=detached) + fit_tool(self._obj) @property def broadcast_dimensions(self) -> list[str]: @@ -3101,12 +3105,6 @@ def param_as_dataset(self, param_name: str) -> xr.Dataset: }, ) - def show(self) -> None: - """Opens a Qt based interactive fit inspection tool.""" - from .plotting.fit_tool import fit_tool - - fit_tool(self._obj) - def best_fits(self) -> xr.DataArray: """Orders the fits into a raveled array by the MSE error.""" return self.order_stacked_fits(ascending=True)