From ebcbf535c0075cf6bd2281a3387fbe031e971018 Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Tue, 12 Mar 2024 13:32:53 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=AC=20=20Update=20type=20hints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/arpes/_typing.py | 6 +- src/arpes/analysis/decomposition.py | 12 +- src/arpes/bootstrap.py | 9 +- src/arpes/load_pxt.py | 4 +- src/arpes/plotting/annotations.py | 41 +++---- src/arpes/plotting/false_color.py | 5 +- src/arpes/plotting/fermi_edge.py | 11 +- src/arpes/plotting/movie.py | 4 +- src/arpes/plotting/parameter.py | 5 +- src/arpes/utilities/conversion/core.py | 2 +- src/arpes/widgets.py | 157 ++++++++----------------- 11 files changed, 102 insertions(+), 154 deletions(-) diff --git a/src/arpes/_typing.py b/src/arpes/_typing.py index bc85a1bc..8c111bdd 100644 --- a/src/arpes/_typing.py +++ b/src/arpes/_typing.py @@ -48,6 +48,7 @@ MarkerType, MarkEveryType, ) + from matplotlib.widgets import Button from numpy.typing import ArrayLike, NDArray from PySide6 import QtCore from PySide6.QtGui import QIcon, QPixmap @@ -145,7 +146,7 @@ class CURRENTCONTEXT(TypedDict, total=False): integration_region: dict[Incomplete, Incomplete] original_data: XrTypes data: XrTypes - widgets: list[mpl.widgets.AxesWidget] + widgets: list[dict[str, mpl.widgets.AxesWidget] | Button] points: list[Incomplete] rect_next: bool # @@ -325,6 +326,8 @@ class DAQINFO(TypedDict, total=False): class SPECTROMETER(ANALYZERINFO, COORDINATES, DAQINFO, total=False): + name: str + type: str rad_per_pixel: float dof: list[str] scan_dof: list[str] @@ -569,7 +572,6 @@ class PLTSubplotParam(TypedDict, total=False): class AxesImageParam(TypedDict, total=False): - ax: Axes cmap: str | Colormap norm: str | Normalize interpolation: Literal[ diff --git a/src/arpes/analysis/decomposition.py b/src/arpes/analysis/decomposition.py index d02ee693..53153743 100644 --- a/src/arpes/analysis/decomposition.py +++ b/src/arpes/analysis/decomposition.py @@ -81,6 +81,10 @@ class DecompositionParam(PCAParam, FastICAParam, NMFParam, FactorAnalysisParam): pass +class DecompositionParamBase(TypedDict, total=False): + n_composition: int | None + + def decomposition_along( data: xr.DataArray, axes: list[str], @@ -177,7 +181,7 @@ def decomposition_along( @wraps(decomposition_along) def pca_along( - *args: xr.DataArray | list[str], + *args: * tuple[xr.DataArray, list[str]], **kwargs: Unpack[PCAParam], ) -> tuple[xr.DataArray, sklearn.decomposition.PCA]: """Specializes `decomposition_along` with `sklearn.decomposition.PCA`.""" @@ -188,7 +192,7 @@ def pca_along( @wraps(decomposition_along) def factor_analysis_along( - *args: xr.DataArray | list[str], + *args: * tuple[xr.DataArray, list[str]], **kwargs: Unpack[FactorAnalysisParam], ) -> tuple[xr.DataArray, sklearn.decomposition.FactorAnalysis]: """Specializes `decomposition_along` with `sklearn.decomposition.FactorAnalysis`.""" @@ -197,7 +201,7 @@ def factor_analysis_along( @wraps(decomposition_along) def ica_along( - *args: xr.DataArray | list[str], + *args: * tuple[xr.DataArray, list[str]], **kwargs: Unpack[FastICAParam], ) -> tuple[xr.DataArray, sklearn.decomposition.FastICA]: """Specializes `decomposition_along` with `sklearn.decomposition.FastICA`.""" @@ -206,7 +210,7 @@ def ica_along( @wraps(decomposition_along) def nmf_along( - *args: xr.DataArray | list[str], + *args: * tuple[xr.DataArray, list[str]], **kwargs: Unpack[NMFParam], ) -> tuple[xr.DataArray, sklearn.decomposition.NMF]: """Specializes `decomposition_along` with `sklearn.decomposition.NMF`.""" diff --git a/src/arpes/bootstrap.py b/src/arpes/bootstrap.py index f697af7a..50778a71 100644 --- a/src/arpes/bootstrap.py +++ b/src/arpes/bootstrap.py @@ -39,7 +39,6 @@ from _typeshed import Incomplete from numpy.typing import NDArray - from ._typing import DataType __all__ = ( "bootstrap", "estimate_prior_adjustment", @@ -131,8 +130,9 @@ def resample( data: xr.DataArray, prior_adjustment: float = 1, ) -> xr.DataArray: + rg = np.random.default_rng() resampled = xr.DataArray( - np.random.Generator.poisson( + rg.poisson( lam=data.values * prior_adjustment, size=data.values.shape, ), @@ -159,8 +159,9 @@ def resample_true_counts(data: xr.DataArray) -> xr.DataArray: Returns: Poisson resampled data. """ + rg = np.random.default_rng() resampled = xr.DataArray( - np.random.Generator.poisson( + rg.poisson( lam=data.values, size=data.values.shape, ), @@ -178,7 +179,7 @@ def resample_true_counts(data: xr.DataArray) -> xr.DataArray: @update_provenance("Bootstrap true electron counts") @lift_dataarray_to_generic def bootstrap_counts( - data: DataType, + data: xr.DataArray, n_samples: int = 1000, name: str | None = None, ) -> xr.Dataset: diff --git a/src/arpes/load_pxt.py b/src/arpes/load_pxt.py index 0998dd44..0fed1fcc 100644 --- a/src/arpes/load_pxt.py +++ b/src/arpes/load_pxt.py @@ -7,7 +7,7 @@ import warnings from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, TypeAlias import numpy as np import xarray as xr @@ -18,7 +18,7 @@ from _typeshed import Incomplete from ._typing import DataType -Wave = Any # really, igor.Wave but we do not assume installation +Wave: TypeAlias = Any # really, igor.Wave but we do not assume installation __all__ = ( "read_single_pxt", diff --git a/src/arpes/plotting/annotations.py b/src/arpes/plotting/annotations.py index 484420b7..1e1d51f1 100644 --- a/src/arpes/plotting/annotations.py +++ b/src/arpes/plotting/annotations.py @@ -19,7 +19,7 @@ from numpy.typing import NDArray - from arpes._typing import EXPERIMENTINFO, DataType, MPLTextParam + from arpes._typing import EXPERIMENTINFO, DataType, MPLTextParam, XrTypes __all__ = ( "annotate_cuts", @@ -27,6 +27,18 @@ "annotate_experimental_conditions", ) +font_scalings = { # see matplotlib.font_manager + "xx-small": 0.579, + "x-small": 0.694, + "small": 0.833, + "medium": 1.0, + "large": 1.200, + "x-large": 1.440, + "xx-large": 1.728, + "larger": 1.2, + "smaller": 0.833, +} + # TODO @: Useless: Revision required # * In order not to use data axis, set transform = ax.Transform @@ -80,32 +92,13 @@ def annotate_experimental_conditions( "large", "x-large", "xx-large", + "larger", "smaller", ] ) = kwargs.get("fontsize", 16) if isinstance(fontsize_keyword, float): fontsize = fontsize_keyword - elif fontsize_keyword in ( - "xx-small", - "x-small", - "small", - "medium", - "large", - "x-large", - "xx-large", - "smaller", - ): - font_scalings = { # see matplotlib.font_manager - "xx-small": 0.579, - "x-small": 0.694, - "small": 0.833, - "medium": 1.0, - "large": 1.200, - "x-large": 1.440, - "xx-large": 1.728, - "larger": 1.2, - "smaller": 0.833, - } + elif fontsize_keyword in font_scalings: fontsize = mpl.rc_params()["font.size"] * font_scalings[fontsize_keyword] else: err_msg = "Incorrect font size setting" @@ -162,7 +155,7 @@ def _render_photon(c: dict[str, float]) -> str: def annotate_cuts( ax: Axes, - data: DataType, + data: XrTypes, plotted_axes: NDArray[np.object_], *, include_text_labels: bool = False, @@ -183,7 +176,7 @@ def annotate_cuts( from arpes.utilities.conversion.forward import convert_coordinates_to_kspace_forward converted_coordinates = convert_coordinates_to_kspace_forward(data) - assert converted_coordinates, xr.Dataset | xr.DataArray + assert isinstance(converted_coordinates, xr.Dataset) assert len(plotted_axes) == TWO_DIMENSION for k, v in kwargs.items(): diff --git a/src/arpes/plotting/false_color.py b/src/arpes/plotting/false_color.py index 21c048c5..5b600b99 100644 --- a/src/arpes/plotting/false_color.py +++ b/src/arpes/plotting/false_color.py @@ -18,7 +18,6 @@ if TYPE_CHECKING: from pathlib import Path - from _typeshed import Incomplete from matplotlib.figure import Figure from numpy.typing import NDArray @@ -31,7 +30,7 @@ def false_color_plot( *, invert: bool = False, pmin_pmax: tuple[float, float] = (0, 1), - **kwargs: Incomplete, + figsize: tuple[float, float] = (7, 5), ) -> Path | tuple[Figure | None, Axes]: """Plots a spectrum in false color after conversion to R, G, B arrays.""" data_r_arr, data_g_arr, data_b_arr = (normalize_to_spectrum(d) for d in data_rgb) @@ -39,7 +38,7 @@ def false_color_plot( fig: Figure | None = None if ax is None: - fig, ax = plt.subplots(figsize=kwargs.pop("figsize", (7, 5))) + fig, ax = plt.subplots(figsize=figsize) assert isinstance(ax, Axes) def normalize_channel(channel: NDArray[np.float_]) -> NDArray[np.float_]: diff --git a/src/arpes/plotting/fermi_edge.py b/src/arpes/plotting/fermi_edge.py index 02e77643..d91ba7b4 100644 --- a/src/arpes/plotting/fermi_edge.py +++ b/src/arpes/plotting/fermi_edge.py @@ -19,6 +19,7 @@ if TYPE_CHECKING: from pathlib import Path + from _typeshed import Incomplete from numpy.typing import NDArray from arpes._typing import MPLPlotKwargs @@ -117,7 +118,7 @@ def fermi_edge_reference( """Fits for and plots results for the Fermi edge on a piece of data. Args: - data_arr: The data, this should be of type DataArray + data: The data, this should be of type DataArray title: A title to attach to the plot ax: The axes to plot to, if not specified will be generated out: Where to save the plot @@ -133,7 +134,7 @@ def fermi_edge_reference( assert isinstance(data_arr, xr.DataArray) sum_dimensions: set[str] = {"cycle", "phi", "kp", "kx"} sum_dimensions.intersection_update(set(data_arr.dims)) - summed_data = data_arr.sum(*list(sum_dimensions)) + summed_data = data.sum(*list(sum_dimensions)) broadcast_dimensions = [str(d) for d in summed_data.dims if str(d) != "eV"] msg = f"Could not product fermi edge reference. Too many dimensions: {broadcast_dimensions}" @@ -156,14 +157,14 @@ def fermi_edge_reference( _, ax = plt.subplots(figsize=(8, 5)) if not title: - title = data_arr.S.label.replace("_", " ") + title = data.S.label.replace("_", " ") centers.plot(ax=ax, **kwargs) widths.plot(ax=ax, **kwargs) if isinstance(ax, Axes): - ax.set_xlabel(label_for_dim(data_arr, ax.get_xlabel())) - ax.set_ylabel(label_for_dim(data_arr, ax.get_ylabel())) + ax.set_xlabel(label_for_dim(data, ax.get_xlabel())) + ax.set_ylabel(label_for_dim(data, ax.get_ylabel())) ax.set_title(title, font_size=14) if out: diff --git a/src/arpes/plotting/movie.py b/src/arpes/plotting/movie.py index 3f379121..d7198446 100644 --- a/src/arpes/plotting/movie.py +++ b/src/arpes/plotting/movie.py @@ -89,7 +89,7 @@ def init() -> tuple[QuadMesh]: def animate(i: int) -> tuple[QuadMesh]: coordinate = animation_coords[i] data_for_plot = data.sel({time_dim: coordinate}) - plot.set_array(data_for_plot.values.G.ravel()) + plot.set_array(data_for_plot.values.ravel()) return (plot,) anim = animation.FuncAnimation( @@ -104,7 +104,7 @@ def animate(i: int) -> tuple[QuadMesh]: animation_writer = animation.writers["ffmpeg"] writer = animation_writer( - fps=1000 / interval_ms, + fps=int(1000 / interval_ms), metadata={"artist": "Me"}, bitrate=1800, ) diff --git a/src/arpes/plotting/parameter.py b/src/arpes/plotting/parameter.py index 8908a6e1..5f755c5a 100644 --- a/src/arpes/plotting/parameter.py +++ b/src/arpes/plotting/parameter.py @@ -49,12 +49,13 @@ def plot_parameter( # noqa: PLR0913 color = kwargs.get("color") e_width = None l_width = None + if "fmt" not in kwargs: + kwargs["fmt"] = "" if two_sigma: _, _, lines = ax.errorbar( x + x_shift, ds.value.values + shift, yerr=2 * ds.error.values, - fmt="", elinewidth=1, linewidth=0, c=color, @@ -64,11 +65,11 @@ def plot_parameter( # noqa: PLR0913 e_width = 2 l_width = 0 + kwargs["fmt"] = "s" ax.errorbar( x + x_shift, ds.value.values + shift, yerr=ds.error.values, - fmt="s", color=color, elinewidth=e_width, linewidth=l_width, diff --git a/src/arpes/utilities/conversion/core.py b/src/arpes/utilities/conversion/core.py index a58fe62a..a659e6dd 100644 --- a/src/arpes/utilities/conversion/core.py +++ b/src/arpes/utilities/conversion/core.py @@ -324,11 +324,11 @@ def interpolated_coordinate_to_raw(*coordinates: NDArray[np.float_]) -> NDArray[ @update_provenance("Automatically k-space converted") def convert_to_kspace( # noqa: PLR0913 arr: xr.DataArray, + *, bounds: dict[MOMENTUM, tuple[float, float]] | None = None, resolution: dict[MOMENTUM, float] | None = None, calibration: DetectorCalibration | None = None, coords: dict[MOMENTUM, NDArray[np.float_]] | None = None, - *, allow_chunks: bool = False, **kwargs: NDArray[np.float_], ) -> xr.DataArray: diff --git a/src/arpes/widgets.py b/src/arpes/widgets.py index 38531b9f..7b9e0cde 100644 --- a/src/arpes/widgets.py +++ b/src/arpes/widgets.py @@ -42,8 +42,9 @@ import numpy as np import xarray as xr from matplotlib import gridspec -from matplotlib.axes import Axes +from matplotlib.backend_bases import MouseEvent from matplotlib.figure import Figure +from matplotlib.image import AxesImage from matplotlib.path import Path from matplotlib.widgets import ( Button, @@ -67,7 +68,8 @@ from collections.abc import Callable from _typeshed import Incomplete - from matplotlib.backend_bases import Event, MouseEvent + from matplotlib.axes import Axes + from matplotlib.backend_bases import Event from matplotlib.collections import Collection from matplotlib.colors import Colormap from numpy.typing import NDArray @@ -253,14 +255,9 @@ def __init__( def handle_select(self, event_click: MouseEvent, event_release: MouseEvent) -> None: """[TODO:summary]. - [TODO:description] - Args: event_click: [TODO:description] event_release: [TODO:description] - - Returns: - [TODO:description] """ dims = self.data.dims @@ -287,13 +284,8 @@ def attach_selector(self, on_select) -> None: # data should already have been set """[TODO:summary]. - [TODO:description] - Args: on_select ([TODO:type]): [TODO:description] - - Returns: - [TODO:description] """ assert self.n_dims is not None @@ -335,8 +327,6 @@ def data(self) -> xr.DataArray: def data(self, new_data: xr.DataArray) -> None: """[TODO:summary]. - [TODO:description] - Args: self ([TODO:type]): [TODO:description] new_data: [TODO:description] @@ -358,9 +348,9 @@ def data(self, new_data: xr.DataArray) -> None: self.ax_kwargs.pop("cmap", None) x, y = self.data.coords[self.data.dims[0]].values, self.data.values self._axis_image = self.ax.plot(x, y, **self.ax_kwargs) - self.ax.set_xlabel(self.data.dims[0]) - cs = self.data.coords[self.data.dims[0]].values - self.ax.set_xlim([np.min(cs), np.max(cs)]) + self.ax.set_xlabel(str(self.data.dims[0])) + cs: NDArray[np.float_] = self.data.coords[self.data.dims[0]].values + self.ax.set_xlim(left=float(np.min(cs)), right=float(np.max(cs))) fancy_labels(self.ax) if self.n_dims == TWO_DIMENSION: @@ -368,8 +358,8 @@ def data(self, new_data: xr.DataArray) -> None: self._data.coords[self._data.dims[0]].values, self._data.coords[self._data.dims[1]].values, ) - extent = [y[0], y[-1], x[0], x[-1]] - assert isinstance(self._axis_image, Axes) + extent = (y[0], y[-1], x[0], x[-1]) + assert isinstance(self._axis_image, AxesImage) self._axis_image.set_extent(extent) self._axis_image.set_data(self._data.values) else: @@ -390,8 +380,6 @@ def data(self, new_data: xr.DataArray) -> None: def mask_cmap(self) -> Colormap: """[TODO:summary]. - [TODO:description] - Args: self ([TODO:type]): [TODO:description] @@ -419,14 +407,9 @@ def mask(self): # noqa: ANN202 def mask(self, new_mask) -> None: """[TODO:summary]. - [TODO:description] - Args: self ([TODO:type]): [TODO:description] new_mask ([TODO:type]): [TODO:description] - - Returns: - [TODO:description] """ if np.array(new_mask).shape != self.data.values.shape: # should be indices then @@ -442,6 +425,7 @@ def mask(self, new_mask) -> None: if self.n_dims == TWO_DIMENSION: if self._mask_image is None: + assert isinstance(self._axis_image, AxesImage) self._mask_image = self.ax.imshow( for_mask.T, cmap=self.mask_cmap, @@ -472,30 +456,29 @@ def mask(self, new_mask) -> None: def autoscale(self) -> None: """[TODO:summary]. - [TODO:description] - Returns: [TODO:description] """ if self.n_dims == TWO_DIMENSION: + assert isinstance(self._axis_image, AxesImage) self._axis_image.autoscale() else: pass @popout -def fit_initializer(data: DataType) -> dict[str, Incomplete]: +def fit_initializer(data: xr.DataArray) -> dict[str, Button | xr.DataArray]: """A tool for initializing lineshape fitting. - [TODO:description] - Args: - data: [TODO:description] + data: (xr.DataArray) + Because broadcast_model is used internally. data must be xr.DataArray Returns: [TODO:description] """ - ctx = {} + assert isinstance(data, xr.DataArray) + ctx: dict[str, Button | xr.DataArray] = {} gs = gridspec.GridSpec(2, 2) ax_initial = plt.subplot(gs[0, 0]) ax_fitted = plt.subplot(gs[0, 1]) @@ -507,7 +490,7 @@ def fit_initializer(data: DataType) -> dict[str, Incomplete]: prefixes = "abcdefghijklmnopqrstuvwxyz" model_settings: list[dict[str, dict[str, float]]] = [] model_defs = [] - for_fit: DataType = data.expand_dims("fit_dim") + for_fit: xr.DataArray = data.expand_dims("fit_dim") for_fit.coords["fit_dim"] = np.array([0]) data_view = DataArrayView(ax_initial) @@ -518,8 +501,6 @@ def fit_initializer(data: DataType) -> dict[str, Incomplete]: def compute_parameters() -> dict: """[TODO:summary]. - [TODO:description] - Returns: [TODO:description] """ @@ -532,13 +513,8 @@ def compute_parameters() -> dict: def on_add_new_peak(selection) -> None: """[TODO:summary]. - [TODO:description] - Args: selection ([TODO:type]): [TODO:description] - - Returns: - [TODO:description] """ amplitude = data.sel(selection).mean().item() @@ -589,16 +565,12 @@ def on_add_new_peak(selection) -> None: def on_copy_settings(event: Event) -> None: """[TODO:summary]. - [TODO:description] - Args: event: [TODO:description] - - Returns: - [TODO:description] """ import pyperclip + logger.debug(f"event: {event}") pyperclip.copy(pprint.pformat(compute_parameters())) copy_settings_button = Button(ax_test, "Copy Settings") @@ -609,8 +581,8 @@ def on_copy_settings(event: Event) -> None: @popout def pca_explorer( - pca: xr.DataArray, # values is used - data: xr.DataArray, # values is used + pca: xr.DataArray, + data: xr.DataArray, component_dim: str = "components", initial_values: list[float] | None = None, *, @@ -647,15 +619,15 @@ def pca_explorer( } arpes.config.CONFIG["CURRENT_CONTEXT"] = context - def compute_for_scatter() -> tuple[XrTypes, int]: + def compute_for_scatter() -> tuple[xr.DataArray, NDArray[np.float_]]: """[TODO:summary]. - [TODO:description] - - Returns: (tuple[XrTypes, int] + Returns: (tuple[xr.DataArray, int] [TODO:description] """ - for_scatter = pca.copy(deep=True).isel({component_dim: context["selected_components"]}) + for_scatter: xr.DataArray = pca.copy(deep=True).isel( + {component_dim: context["selected_components"]}, + ) for_scatter = for_scatter.S.transpose_to_back(component_dim) size: NDArray[np.float_] = data.mean(other_dims).stack(pca_dims=pca_dims).values @@ -665,7 +637,7 @@ def compute_for_scatter() -> tuple[XrTypes, int]: # ===== Set up axes ====== gs = gridspec.GridSpec(2, 2) - ax_components = plt.subplot(gs[0, 0]) + ax_components: Axes = plt.subplot(gs[0, 0]) ax_sum_selected = plt.subplot(gs[0, 1]) ax_map = plt.subplot(gs[1, 0]) @@ -686,13 +658,8 @@ def update_from_selection(ind: Incomplete) -> None: # Calculate the new data """[TODO:summary]. - [TODO:description] - Args: ind: [TODO:description] - - Returns: - [TODO:description] """ if ind is None or not len(ind): context["selected_indices"] = [] @@ -711,17 +678,12 @@ def update_from_selection(ind: Incomplete) -> None: map_view.mask = ind selected_view.data = context["sum_data"] - def set_axes(component_x, component_y) -> None: + def set_axes(component_x: float, component_y: float) -> None: """[TODO:summary]. - [TODO:description] - Args: - component_x ([TODO:type]): [TODO:description] - component_y ([TODO:type]): [TODO:description] - - Returns: - [TODO:description] + component_x (int): [TODO:description] + component_y (int): [TODO:description] """ ax_components.clear() context["selected_components"] = [component_x, component_y] @@ -736,21 +698,17 @@ def set_axes(component_x, component_y) -> None: pts, on_select=update_from_selection, ) - ax_components.set_xlabel("$e_" + str(component_x) + "$") - ax_components.set_ylabel("$e_" + str(component_y) + "$") + ax_components.set_xlabel(f"$e_{component_x}$") + ax_components.set_ylabel(f"$e_{component_y}$") update_from_selection([]) def on_change_axes(event: Event) -> None: """[TODO:summary]. - [TODO:description] - Args: event: [TODO:description] - - Returns: - [TODO:description] """ + logger.debug(f"event: {event}") try: val_x = int(context["axis_X_input"].text) val_y = int(context["axis_Y_input"].text) @@ -780,13 +738,8 @@ def clamp(x: int, low: int, high: int) -> int: def on_select_summed(region) -> None: """[TODO:summary]. - [TODO:description] - Args: region ([TODO:type]): [TODO:description] - - Returns: - [TODO:description] """ context["integration_region"] = region update_from_selection(context["selected_indices"]) @@ -807,9 +760,7 @@ def kspace_tool( coords: dict[str, NDArray[np.float_] | xr.DataArray] | None = None, **kwargs: Incomplete, ) -> CURRENTCONTEXT: - """[TODO:summary]. - - [TODO:description] + """A utility for assigning coordinate offsets using a live momentum conversion. Args: data: [TODO:description] @@ -825,7 +776,6 @@ def kspace_tool( Raises: ValueError: [TODO:description] """ - """A utility for assigning coordinate offsets using a live momentum conversion.""" original_data = data data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) @@ -851,7 +801,11 @@ def kspace_tool( fn(ax_converted) n_widget_axes = 8 - gs_widget = gridspec.GridSpecFromSubplotSpec(n_widget_axes, 1, subplot_spec=gs[:, 2]) + gs_widget = gridspec.GridSpecFromSubplotSpec( + n_widget_axes, + 1, + subplot_spec=gs[:, 2], + ) widget_axes = [plt.subplot(gs_widget[i, 0]) for i in range(n_widget_axes)] for _ in widget_axes[:-2]: @@ -919,12 +873,11 @@ def on_copy_settings(event: Event) -> None: Args: event: [TODO:description] - - Returns: - [TODO:description] """ import pyperclip + logger.debug(f"event: {event}") + pyperclip.copy(pprint.pformat(_compute_offsets())) def apply_offsets(event: Event) -> None: @@ -932,10 +885,8 @@ def apply_offsets(event: Event) -> None: Args: event: [TODO:description] - - Returns: - [TODO:description] """ + logger.debug(f"event: {event}") for name, offset in _compute_offsets().items(): original_data.attrs[f"{name}_offset"] = offset try: @@ -966,7 +917,7 @@ def apply_offsets(event: Event) -> None: @popout def pick_rectangles( - data: DataType, + data: XrTypes, **kwargs: Incomplete, ) -> list[list[float]]: """A utility allowing for selection of rectangular regions. @@ -987,17 +938,13 @@ def pick_rectangles( data.S.plot(**kwargs) ax = fig.gca() - def onclick(event: MouseEvent) -> None: + def onclick(event: Event) -> None: """[TODO:summary]. - [TODO:description] - Args: event: [TODO:description] - - Returns: - [TODO:description] """ + assert isinstance(event, MouseEvent) ctx["points"].append([event.xdata, event.ydata]) if ctx["rect_next"]: p1, p2 = ctx["points"][-2], ctx["points"][-1] @@ -1030,8 +977,6 @@ def onclick(event: MouseEvent) -> None: def pick_gamma(data: DataType, **kwargs: Incomplete) -> DataType: """[TODO:summary]. - [TODO:description] - Args: data: [TODO:description] kwargs: [TODO:description] @@ -1049,14 +994,13 @@ def pick_gamma(data: DataType, **kwargs: Incomplete) -> DataType: def onclick(event: Event) -> None: """[TODO:summary]. - [TODO:description] - Args: event: [TODO:description] Returns: [TODO:description] """ + assert isinstance(event, MouseEvent) data.attrs["symmetry_points"] = {"G": {}} logger.info(event.x, event.xdata, event.y, event.ydata) @@ -1076,7 +1020,7 @@ def onclick(event: Event) -> None: @popout def pick_points( - data_or_str: str | pathlib.Path, + data_or_str: str | pathlib.Path | xr.DataArray, **kwargs: Incomplete, ) -> list[float]: """A utility allowing for selection of points in a dataset. @@ -1096,9 +1040,11 @@ def pick_points( fig = plt.figure() if using_image_data: + assert isinstance(data_or_str, str | pathlib.Path) data = imread_to_xarray(data_or_str) plt.imshow(data.values) else: + assert isinstance(data_or_str, xr.DataArray) data = data_or_str data.S.plot(**kwargs) @@ -1117,15 +1063,16 @@ def pick_points( width = 0.03 * maxd / dx * (xlim[1] - xlim[0]) height = 0.03 * maxd / dy * (ylim[1] - ylim[0]) - def onclick(event: MouseEvent) -> None: + def onclick(event: Event) -> None: """[TODO:summary]. Args: event: [TODO:description] - """ + assert isinstance(event, MouseEvent) ctx["points"].append([event.xdata, event.ydata]) - + assert event.xdata is not None + assert event.ydata is not None circ = mpl.patches.Ellipse( ( event.xdata,