From 3d502b4a87dc7e647872bb9bbfbb0a9ddc75cdfd Mon Sep 17 00:00:00 2001 From: Ryuichi Arafune Date: Sat, 10 Feb 2024 13:54:26 +0900 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A5=20=20Be=20slim:=20S.symmetry=5Fpoi?= =?UTF-8?q?nts=20just=20refers=20attrs.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Consider the projected symmetry points etc after if needed. (At least, the current algorithm cannot be be used without any concern.) * and update type hints, as usual. --- arpes/_typing.py | 7 + arpes/analysis/resolution.py | 45 +++-- arpes/plotting/dispersion.py | 30 ++-- arpes/plotting/dos.py | 10 +- arpes/plotting/fermi_edge.py | 2 +- arpes/utilities/bz.py | 10 +- arpes/utilities/conversion/core.py | 1 - arpes/xarray_extensions.py | 260 ++++++++++++----------------- tests/test_basic_data_loading.py | 4 +- tests/test_xarray_extensions.py | 8 +- 10 files changed, 182 insertions(+), 195 deletions(-) diff --git a/arpes/_typing.py b/arpes/_typing.py index acdf1dca..2778b4a8 100644 --- a/arpes/_typing.py +++ b/arpes/_typing.py @@ -262,6 +262,13 @@ class _BEAMLINEINFO(TypedDict, total=False): monochrometer_info: dict[str, float] +class BeamLineSettings(TypedDict, total=False): + exit_slit: float | str + entrance_slit: float | str + hv: float + grating: str | None + + class LIGHTSOURCEINFO(_PROBEINFO, _PUMPINFO, _BEAMLINEINFO, total=False): polarization: float | tuple[float, float] | str photon_flux: float diff --git a/arpes/analysis/resolution.py b/arpes/analysis/resolution.py index b89151b8..fcf570e9 100644 --- a/arpes/analysis/resolution.py +++ b/arpes/analysis/resolution.py @@ -167,9 +167,9 @@ def analyzer_resolution_estimate(data: xr.DataArray, *, meV: bool = False) -> fl def energy_resolution_from_beamline_slit( - table: dict[str, str | float | None], + table: dict[tuple[float, tuple[float, float]], float], photon_energy: float, - exit_slit_size: str | float | tuple[float, float] | None, + exit_slit_size: tuple[float, float], ) -> float: """Calculates the energy resolution contribution from the beamline slits. @@ -184,7 +184,9 @@ def energy_resolution_from_beamline_slit( Returns: The energy broadening in eV. """ - by_slits = {k[1]: v for k, v in table.items() if k[0] == photon_energy} + by_slits: dict[tuple[float, float], float] = { + k[1]: v for k, v in table.items() if k[0] == photon_energy + } if exit_slit_size in by_slits: return by_slits[exit_slit_size] @@ -210,19 +212,23 @@ def energy_resolution_from_beamline_slit( return by_area[low] + (by_area[high] - by_area[low]) * (slit_area - low) / (high - low) -def beamline_resolution_estimate(data: xr.DataArray, *, meV: bool = False) -> None: # noqa: N803 +def beamline_resolution_estimate( + data: xr.DataArray, + *, + meV: bool = False, # noqa: N803 +) -> float: data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) - resolution_table: dict[ - str, - dict[tuple[float, tuple[float, float]], float], - ] = ENDSTATIONS_BEAMLINE_RESOLUTION[data_array.S.endstation] + resolution_table: dict[str, dict[tuple[float, tuple[float, float]], float]] = ( + ENDSTATIONS_BEAMLINE_RESOLUTION[data_array.S.endstation] + ) if isinstance(next(iter(resolution_table.keys())), str): # need grating information settings = data_array.S.beamline_settings - resolution_table = resolution_table[settings["grating"]] - - all_keys = list(resolution_table.keys()) + resolution_table_selected: dict[tuple[float, tuple[float, float]], float] = ( + resolution_table[settings["grating"]] + ) + all_keys = list(resolution_table_selected.keys()) hvs = {k[0] for k in all_keys} low_hv = max(hv for hv in hvs if hv < settings["hv"]) @@ -232,9 +238,16 @@ def beamline_resolution_estimate(data: xr.DataArray, *, meV: bool = False) -> No settings["entrance_slit"], settings["exit_slit"], ) - low_hv_res = energy_resolution_from_beamline_slit(resolution_table, low_hv, slit_size) - high_hv_res = energy_resolution_from_beamline_slit(resolution_table, high_hv, slit_size) - + low_hv_res = energy_resolution_from_beamline_slit( + resolution_table_selected, + low_hv, + slit_size, + ) + high_hv_res = energy_resolution_from_beamline_slit( + resolution_table_selected, + high_hv, + slit_size, + ) # interpolate between nearest values return low_hv_res + (high_hv_res - low_hv_res) * (settings["hv"] - low_hv) / ( high_hv - low_hv @@ -249,7 +262,7 @@ def thermal_broadening_estimate(data: DataType, *, meV: bool = False) -> float: def total_resolution_estimate( - data: DataType, + data: xr.DataArray, *, include_thermal_broadening: bool = False, meV: bool = False, # noqa: N803 @@ -262,7 +275,7 @@ def total_resolution_estimate( Returns: The estimated total resolution broadening. """ - thermal_broadening = 0 + thermal_broadening = 0.0 if include_thermal_broadening: thermal_broadening = thermal_broadening_estimate(data, meV=meV) return math.sqrt( diff --git a/arpes/plotting/dispersion.py b/arpes/plotting/dispersion.py index fe37ef46..e046b20f 100644 --- a/arpes/plotting/dispersion.py +++ b/arpes/plotting/dispersion.py @@ -28,7 +28,7 @@ from matplotlib.figure import Figure, FigureBase from numpy.typing import NDArray - from arpes._typing import DataType, PColorMeshKwargs + from arpes._typing import PColorMeshKwargs, XrTypes from arpes.models.band import Band __all__ = [ @@ -304,8 +304,7 @@ def mask_for(x: NDArray[np.float_]) -> NDArray[np.float_]: @save_plot_provenance def hv_reference_scan( - data: DataType, - out: str | Path = "", + data: XrTypes, e_cut: float = -0.05, bkg_subtraction: float = 0.8, **kwargs: Unpack[LabeledFermiSurfaceParam], @@ -316,7 +315,11 @@ def hv_reference_scan( fs.data -= bkg_subtraction * np.mean(fs.data) fs.data[fs.data < 0] = 0 - _, ax = labeled_fermi_surface(fs, **kwargs) + out = kwargs.pop("out", None) + + lfs = labeled_fermi_surface(fs, **kwargs) + assert isinstance(lfs, tuple) + _, ax = lfs all_scans = data.attrs["df"] all_scans = all_scans[all_scans.id != data.attrs["id"]] @@ -330,7 +333,7 @@ def hv_reference_scan( scans_by_hv[round(scan.S.hv)].append(scan.S.label.replace("_", " ")) - dim_order = ax.dim_order + dim_order = [ax.get_xlabel(), ax.get_ylabel()] handles = [] handle_labels = [] @@ -357,19 +360,19 @@ def hv_reference_scan( return path_for_plot(out) plt.show() - return None + return ax class LabeledFermiSurfaceParam(TypedDict, total=False): include_symmetry_points: bool include_bz: bool fermi_energy: float + out: str | Path @save_plot_provenance def reference_scan_fermi_surface( - data: DataType, - out: str | Path = "", + data: XrTypes, **kwargs: Unpack[LabeledFermiSurfaceParam], ) -> Path | Axes: """A reference plot for Fermi surfaces. Used internally by other code. @@ -377,14 +380,19 @@ def reference_scan_fermi_surface( Warning: Not work correctly. (Because S.referenced_scans has been removed.) """ fs = data.S.fermi_surface - _, ax = labeled_fermi_surface(fs, **kwargs) + + out = kwargs.pop("out", None) + lfs = labeled_fermi_surface(fs, **kwargs) + assert isinstance(lfs, tuple) + _, ax = lfs referenced_scans = data.S.referenced_scans handles = [] for index, row in referenced_scans.iterrows(): scan = load_data(row.id) remapped_coords = remap_coords_to(scan, data) - dim_order = ax.dim_order + + dim_order = [ax.get_xlabel(), ax.get_ylabel()] ls = ax.plot( remapped_coords[dim_order[0]], remapped_coords[dim_order[1]], @@ -447,7 +455,7 @@ def labeled_fermi_surface( # noqa: PLR0913 if include_symmetry_points: for point_name, point_location in data.S.iter_symmetry_points: warnings.warn("Symmetry point locations are not k-converted", stacklevel=2) - coords = [point_location[d] for d in dim_order] + coords = tuple([point_location[d] for d in dim_order]) ax.plot(*coords, marker=".", color=marker_color) ax.annotate( label_for_symmetry_point(point_name), diff --git a/arpes/plotting/dos.py b/arpes/plotting/dos.py index 95e38807..a9184154 100644 --- a/arpes/plotting/dos.py +++ b/arpes/plotting/dos.py @@ -23,8 +23,6 @@ from matplotlib.colors import Normalize from matplotlib.figure import Figure - from arpes._typing import DataType - __all__ = ( "plot_dos", "plot_core_levels", @@ -33,7 +31,7 @@ @save_plot_provenance def plot_core_levels( # noqa: PLR0913 - data: DataType, + data: xr.DataArray, title: str = "", out: str | Path = "", norm: Normalize | None = None, @@ -43,7 +41,9 @@ def plot_core_levels( # noqa: PLR0913 promenance: int = 5, ) -> Path | tuple[Axes, Colorbar]: """Plots an XPS curve and approximate core level locations.""" - _, axes, cbar = plot_dos(data=data, title=title, out="", norm=norm, dos_pow=dos_pow) + plotdos = plot_dos(data=data, title=title, out="", norm=norm, dos_pow=dos_pow) + assert isinstance(plotdos, tuple) + _, axes, cbar = plotdos if core_levels is None: core_levels = approximate_core_levels(data, binning=binning, promenance=promenance) @@ -64,7 +64,7 @@ def plot_dos( out: str | Path = "", norm: Normalize | None = None, dos_pow: float = 1, -) -> Path | tuple[Figure, Axes, Colorbar]: +) -> Path | tuple[Figure, tuple[Axes], Colorbar]: """Plots the density of states (momentum integrated) image next to the original spectrum.""" data_arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) diff --git a/arpes/plotting/fermi_edge.py b/arpes/plotting/fermi_edge.py index 8f0b6154..c22b5ccd 100644 --- a/arpes/plotting/fermi_edge.py +++ b/arpes/plotting/fermi_edge.py @@ -110,7 +110,7 @@ def fermi_edge_reference( data: xr.DataArray, title: str = "", ax: Axes | None = None, - out: str = "", + out: str | Path = "", **kwargs: Incomplete, ) -> Path | Axes: """Fits for and plots results for the Fermi edge on a piece of data. diff --git a/arpes/utilities/bz.py b/arpes/utilities/bz.py index 6d3460dc..a85c5613 100644 --- a/arpes/utilities/bz.py +++ b/arpes/utilities/bz.py @@ -461,7 +461,7 @@ def reduced_bz_axis_to( point_names = _POINT_NAMES_FOR_SYMMETRY[symmetry] symmetry_points, _ = data.S.symmetry_points() - points = {k: v[0] for k, v in symmetry_points.items() if k in point_names} + points = {k: v for k, v in symmetry_points.items() if k in point_names} coords_by_point = { k: np.array([v.get(d, 0) for d in data.dims if d in v or include_E and d == "eV"]) @@ -500,7 +500,7 @@ def reduced_bz_axes(data: XrTypes) -> tuple[NDArray[np.float_], NDArray[np.float point_names = _POINT_NAMES_FOR_SYMMETRY[symmetry] symmetry_points, _ = data.S.symmetry_points() - points = {k: v[0] for k, v in symmetry_points.items() if k in point_names} + points = {k: v for k, v in symmetry_points.items() if k in point_names} coords_by_point = {k: np.array([v[d] for d in data.dims if d in v]) for k, v in points.items()} if symmetry == "rect": @@ -537,7 +537,7 @@ def axis_along(data: XrTypes, symbol: str) -> float: point_names = _POINT_NAMES_FOR_SYMMETRY[symmetry] symmetry_points, _ = data.S.symmetry_points() - points = {k: v[0] for k, v in symmetry_points.items() if k in point_names} + points = {k: v for k, v in symmetry_points.items() if k in point_names} coords_by_point = {k: np.array([v[d] for d in data.dims if d in v]) for k, v in points.items()} @@ -576,7 +576,7 @@ def reduced_bz_poly(data: XrTypes, *, scale_zone: bool = False) -> NDArray[np.fl dy = 3 * dy symmetry_points, _ = data.S.symmetry_points() - points = {k: v[0] for k, v in symmetry_points.items() if k in point_names} + points = {k: v for k, v in symmetry_points.items() if k in point_names} coords_by_point = { k: np.array([v.get(d, 0) for d in data.dims if d in v]) for k, v in points.items() } @@ -625,7 +625,7 @@ def reduced_bz_E_mask( point_names = _POINT_NAMES_FOR_SYMMETRY[symmetry] symmetry_points, _ = data.S.symmetry_points() - points = {k: v[0] for k, v in symmetry_points.items() if k in point_names} + points = {k: v for k, v in symmetry_points.items() if k in point_names} coords_by_point = { k: np.array([v.get(d, 0) for d in data.dims if d in v or d == "eV"]) for k, v in points.items() diff --git a/arpes/utilities/conversion/core.py b/arpes/utilities/conversion/core.py index 9e22612e..5e768762 100644 --- a/arpes/utilities/conversion/core.py +++ b/arpes/utilities/conversion/core.py @@ -31,7 +31,6 @@ from typing import TYPE_CHECKING, Literal import numpy as np -from numpy.typing import ArrayLike import xarray as xr from scipy.interpolate import RegularGridInterpolator diff --git a/arpes/xarray_extensions.py b/arpes/xarray_extensions.py index badc07b8..e47f76d0 100644 --- a/arpes/xarray_extensions.py +++ b/arpes/xarray_extensions.py @@ -39,16 +39,24 @@ from __future__ import annotations -import collections import contextlib import copy import itertools import warnings -from collections import OrderedDict, defaultdict +from collections import OrderedDict from collections.abc import Collection, Hashable, Mapping, Sequence from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, Self, TypeAlias, Unpack +from typing import ( + TYPE_CHECKING, + Any, + Literal, + Self, + TypeAlias, + TypedDict, + TypeGuard, + Unpack, +) import matplotlib.pyplot as plt import numpy as np @@ -60,6 +68,7 @@ import arpes.utilities.math from arpes.constants import TWO_DIMENSION +from ._typing import MPLPlotKwargs from .analysis import param_getter, param_stderr_getter, rebin from .models.band import MultifitBand from .plotting.dispersion import ( @@ -102,6 +111,7 @@ SAMPLEINFO, SCANINFO, SPECTROMETER, + BeamLineSettings, DataType, PColorMeshKwargs, XrTypes, @@ -144,7 +154,7 @@ def _iter_groups( - grouped: dict[str, Sequence[float] | float], + grouped: dict[str, NDArray[np.float_] | Sequence[float] | float], ) -> Iterator[tuple[str, float]]: """Iterates through a flattened sequence. @@ -157,7 +167,7 @@ def _iter_groups( ToDo: Not tested """ for k, value_or_list in grouped.items(): - if isinstance(value_or_list, Sequence): + if isinstance(value_or_list, Sequence | np.ndarray): for list_item in value_or_list: yield k, list_item else: @@ -167,7 +177,20 @@ def _iter_groups( class ARPESAccessorBase: """Base class for the xarray extensions in PyARPES.""" - def along(self, directions: NDArray[np.float_], **kwargs: Incomplete) -> xr.DataArray: + class _SliceAlongPathKwags(TypedDict, total=False): + arr: xr.DataArray + interpolation_points: NDArray[np.float_] | None + axis_name: str + resolution: float + n_points: int | None + extend_to_edge: bool + shift_gamma: bool + + def along( + self, + directions: NDArray[np.float_], + **kwargs: Unpack[_SliceAlongPathKwags], + ) -> xr.Dataset: """TODO: Need description. ToDo: Test @@ -394,7 +417,17 @@ def spectrum_type(self) -> Literal["cut", "map", "hv_map", "ucut", "spem", "xps" ("eV", "kp", "kz"): "hv_map", } dims: tuple = tuple(sorted(self._obj.dims)) - return dim_types.get(dims) + dim_type = dim_types.get(dims) + + def _dim_type_check( + dim_type: str | None, + ) -> TypeGuard[Literal["cut", "map", "hv_map", "ucut", "spem", "xps"]]: + return dim_type in ("cut", "map", "hv_map", "ucut", "spem", "xps") + + if _dim_type_check(dim_type): + return dim_type + msg = "Dimension type may be incorrect" + raise TypeError(msg) @property def is_differentiated(self) -> bool: @@ -441,8 +474,8 @@ def transpose_to_back(self, dim: str) -> XrTypes: def select_around_data( self, - points: dict[str, xr.DataArray] | xr.Dataset | dict | tuple[float, ...] | list[float], - radius: dict[str, float] | float | None = None, # radius={"phi": 0.005} + points: dict[Hashable, xr.DataArray], + radius: dict[Hashable, float] | float | None = None, # radius={"phi": 0.005} *, mode: Literal["sum", "mean"] = "sum", **kwargs: Incomplete, @@ -485,6 +518,7 @@ def select_around_data( points = dict(zip(points, self._obj.dims, strict=True)) if isinstance(points, xr.Dataset): points = {str(k): points[k].item() for k in points.data_vars} + radius = self._radius(points, radius, **kwargs) assert isinstance(radius, dict) @@ -536,7 +570,7 @@ def select_around_data( def select_around( self, points: dict[Hashable, float] | xr.Dataset, - radius: dict[Hashable, float] | float | None = None, + radius: dict[Hashable, float] | float, *, mode: Literal["sum", "mean"] = "sum", **kwargs: Incomplete, @@ -577,6 +611,7 @@ def select_around( if isinstance(points, xr.Dataset): points = {k: points[k].item() for k in points.data_vars} logger.debug(f"points: {points}") + assert isinstance(points, dict) radius = self._radius(points, radius, **kwargs) logger.debug(f"radius: {radius}") nearest_sel_params = {} @@ -609,7 +644,7 @@ def select_around( def _radius( self, points: dict[Hashable, float], - radius: float | dict[Hashable, float] | None, + radius: float | dict[Hashable, float], **kwargs: float, ) -> dict[Hashable, float]: """Helper function. Generate radius dict. @@ -643,73 +678,15 @@ def short_history(self, key: str = "by") -> list: Args: key (str): [TODO:description] """ - return [h["record"][key] if isinstance(h, dict) else h for h in self.history] - - def _calculate_symmetry_points( - self, - symmetry_points: dict[str, list[dict[str, float]]] | dict[str, dict[str, float]], - epsilon: float = 0.01, - ) -> tuple[defaultdict[str, list[dict[str, float]]], defaultdict[str, list[dict[str, float]]]]: - # For each symmetry point, we need to determine if it is projected or not - # if it is projected, we need to calculate its projected coordinates - """[TODO:summary]. - - Args: - symmetry_points: [TODO:description] - epsilon: [TODO:description] - """ - points = collections.defaultdict(list) - projected_points = collections.defaultdict(list) - - fixed_coords = {k: v for k, v in self._obj.coords.items() if k not in self._obj.indexes} - index_coords = self._obj.indexes - - for point, locations in symmetry_points.items(): - location_list = locations if isinstance(locations, list) else [locations] - for location in location_list: - # determine whether the location needs to be projected - projected = False - skip = False - for axis_name, value in location.items(): - if ( - axis_name in fixed_coords - and np.abs(value - fixed_coords[axis_name]) > epsilon - ): - projected = True - if axis_name not in fixed_coords and axis_name not in index_coords: - # cannot even hope to do anything here, we don't have enough info - skip = True - if skip: - continue - location.copy() # <== CHECK ME! Original: new_location = location.copy() - if projected: - # Go and do the projection, for now we will assume we just get it by - # replacing the value of the mismatched coordinates. - # This does not work if the coordinate system is not orthogonal - for axis in location: - if axis in fixed_coords: - fixed_coords[axis] - # <== CHECK ME! Original new_locationn = fixed_coords[axis] - projected_points[point].append(location) - else: - points[point].append(location) - - return points, projected_points + return [h["record"][key] if isinstance(h, dict) else h for h in self.history] # type: ignore[literal-required] def symmetry_points( self, - *, - raw: bool = False, - **kwargs: float, - ) -> ( - dict[str, dict[str, float]] - | tuple[defaultdict[str, list[dict[str, float]]], defaultdict[str, list[dict[str, float]]]] - ): - """[TODO:summary]. + ) -> dict[str, dict[str, float]]: + """Return the dict object about symmetry point such as G-point in the ARPES data. - Args: - raw (bool): [TODO:description] - kwargs: pass to _calculate_symmetry_points (epsilon) + The original version was something complicated, but the coding seemed to be in + process and the purpose was unclear, so it was streamlined considerably. """ symmetry_points: dict[str, dict[str, float]] = {} # An example of "symmetry_points": symmetry_points = {"G": {"phi": 0.405}} @@ -717,26 +694,13 @@ def symmetry_points( symmetry_points.update(our_symmetry_points) - if raw: - return symmetry_points - - return self._calculate_symmetry_points(symmetry_points, **kwargs) + return symmetry_points @property def iter_own_symmetry_points(self) -> Iterator[tuple[str, float]]: - sym_points, _ = self.symmetry_points() - return _iter_groups(sym_points) - - @property - def iter_projected_symmetry_points(self) -> Iterator[tuple[str, float]]: - _, sym_points = self.symmetry_points() + sym_points = self.symmetry_points() return _iter_groups(sym_points) - @property - def iter_symmetry_points(self) -> Iterator[tuple[str, float]]: - yield from self.iter_own_symmetry_points - yield from self.iter_projected_symmetry_points - @property def history(self) -> list[PROVENANCE | None]: provenance_recorded = self._obj.attrs.get("provenance", None) @@ -861,17 +825,18 @@ def lookup_coord(self, name: str) -> xr.DataArray | float: raise ValueError(msg) def lookup_offset(self, attr_name: str) -> float: - symmetry_points = self.symmetry_points(raw=True) + symmetry_points = self.symmetry_points() + assert isinstance(symmetry_points, dict) if "G" in symmetry_points: - gamma_point = symmetry_points["G"] + gamma_point = symmetry_points["G"] # {"phi": 0.405} (cut) if attr_name in gamma_point: - return unwrap_xarray_item(gamma_point[attr_name]) + return gamma_point[attr_name] offset_name = attr_name + "_offset" if offset_name in self._obj.attrs: - return unwrap_xarray_item(self._obj.attrs[offset_name]) + return self._obj.attrs[offset_name] - return unwrap_xarray_item(self._obj.attrs.get("data_preparation", {}).get(offset_name, 0)) + return self._obj.attrs.get("data_preparation", {}).get(offset_name, 0) @property def beta_offset(self) -> float: @@ -1307,26 +1272,15 @@ def reference_settings(self) -> dict[str, Any]: return settings @property - def beamline_settings(self) -> dict[str, Any]: - find_keys = { - "entrance_slit": { - "entrance_slit", - }, - "exit_slit": { - "exit_slit", - }, - "hv": { - "hv", - "photon_energy", - }, - "grating": {}, - } - settings = {} - for key, options in find_keys.items(): - for option in options: - if option in self._obj.attrs: - settings[key] = self._obj.attrs[option] - break + def beamline_settings(self) -> BeamLineSettings: + settings: BeamLineSettings = {} + settings["entrance_slit"] = self._obj.attrs.get("entrance_slit", np.nan) + settings["exit_slit"] = self._obj.attrs.get("exit_slit", np.nan) + settings["hv"] = self._obj.attrs.get( + "exit_slit", + self._obj.attrs.get("photon_energy", np.nan), + ) + settings["grating"] = self._obj.attrs.get("grating", None) return settings @@ -1456,8 +1410,8 @@ def sample_info(self) -> SAMPLEINFO: @property def scan_info(self) -> SCANINFO: scan_info: SCANINFO = { - "time": self._obj.attrs.get("time"), - "date": self._obj.attrs.get("date"), + "time": self._obj.attrs.get("time", None), + "date": self._obj.attrs.get("date", None), "type": self.scan_type, "spectrum_type": self.spectrum_type, "experimenter": self._obj.attrs.get("experimenter"), @@ -1537,11 +1491,11 @@ def analyzer_info(self) -> ANALYZERINFO: analyzer_info: ANALYZERINFO = { "lens_mode": self._obj.attrs.get("lens_mode"), "lens_mode_name": self._obj.attrs.get("lens_mode_name"), - "acquisition_mode": self._obj.attrs.get("acquisition_mode"), + "acquisition_mode": self._obj.attrs.get("acquisition_mode", None), "pass_energy": self._obj.attrs.get("pass_energy", np.nan), - "slit_shape": self._obj.attrs.get("slit_shape"), + "slit_shape": self._obj.attrs.get("slit_shape", None), "slit_width": self._obj.attrs.get("slit_width", np.nan), - "slit_number": self._obj.attrs.get("slit_number"), + "slit_number": self._obj.attrs.get("slit_number", np.nan), "lens_table": self._obj.attrs.get("lens_table"), "analyzer_type": self._obj.attrs.get("analyzer_type"), "mcp_voltage": self._obj.attrs.get("mcp_voltage", np.nan), @@ -1556,7 +1510,7 @@ def daq_info(self) -> DAQINFO: "daq_type": self._obj.attrs.get("daq_type"), "region": self._obj.attrs.get("daq_region"), "region_name": self._obj.attrs.get("daq_region_name"), - "center_energy": self._obj.attrs.get("daq_center_energy"), + "center_energy": self._obj.attrs.get("daq_center_energy", np.nan), "prebinning": self.prebinning, "trapezoidal_correction_strategy": self._obj.attrs.get( "trapezoidal_correction_strategy", @@ -1578,8 +1532,8 @@ def beamline_info(self) -> LIGHTSOURCEINFO: "undulator_info": self.undulator_info, "repetition_rate": self._obj.attrs.get("repetition_rate", np.nan), "beam_current": self._obj.attrs.get("beam_current", np.nan), - "entrance_slit": self._obj.attrs.get("entrance_slit"), - "exit_slit": self._obj.attrs.get("exit_slit"), + "entrance_slit": self._obj.attrs.get("entrance_slit", None), + "exit_slit": self._obj.attrs.get("exit_slit", None), "monochromator_info": self.monochromator_info, } return beamline_info @@ -1780,9 +1734,9 @@ def _experimentalinfo_to_dict(conditions: EXPERIMENTINFO) -> dict[str, str]: if isinstance(v, xr.DataArray): min_hv = float(v.min()) max_hv = float(v.max()) - transformed_dict[ - k - ] = f" from {min_hv} to {max_hv} eV" + transformed_dict[k] = ( + f" from {min_hv} to {max_hv} eV" + ) elif isinstance(v, float) and not np.isnan(v): transformed_dict[k] = f"{v} eV" return transformed_dict @@ -1958,20 +1912,21 @@ def show(self: Self, *, detached: bool = False, **kwargs: Incomplete) -> None: def fs_plot( self: Self, pattern: str = "{}.png", - **kwargs: Incomplete, - ) -> Path | None | tuple[Figure, Axes]: + **kwargs: Unpack[LabeledFermiSurfaceParam], + ) -> Path | tuple[Figure | None, Axes]: """Provides a reference plot of the approximate Fermi surface.""" out = kwargs.get("out") if out is not None and isinstance(out, bool): out = pattern.format(f"{self.label}_fs") kwargs["out"] = out + assert isinstance(self._obj, xr.DataArray) return labeled_fermi_surface(self._obj, **kwargs) def fermi_edge_reference_plot( self: Self, pattern: str = "{}.png", **kwargs: str | Normalize | None, - ) -> Path | None: + ) -> Path | Axes: """Provides a reference plot for a Fermi edge reference. Args: @@ -1985,7 +1940,7 @@ def fermi_edge_reference_plot( if out is not None and isinstance(out, bool): out = pattern.format(f"{self.label}_fermi_edge_reference") kwargs["out"] = out - + assert isinstance(self._obj, xr.DataArray) return fermi_edge_reference(self._obj, **kwargs) def _referenced_scans_for_spatial_plot( @@ -1993,7 +1948,7 @@ def _referenced_scans_for_spatial_plot( *, use_id: bool = True, pattern: str = "{}.png", - out: str | bool = "", + out: str | Path = "", ) -> Path | tuple[Figure, NDArray[np.object_]]: """[TODO:summary]. @@ -2037,8 +1992,8 @@ def _referenced_scans_for_hv_map_plot( pattern: str = "{}.png", *, use_id: bool = True, - **kwargs: IncompleteMPL, - ) -> Path | None: + **kwargs: Unpack[HvRefScanParam], + ) -> Path | Axes: out = kwargs.get("out") label = self._obj.attrs["id"] if use_id else self.label if out is not None and isinstance(out, bool): @@ -2232,6 +2187,7 @@ def correct_angle_by( ) self._obj.attrs[angle_for_correction] = 0 return + # if angle_for_correction == "beta" and self._obj.S.is_slit_vertical: self._obj.coords["phi"] = ( self._obj.coords["phi"] + self._obj.attrs[angle_for_correction] @@ -2239,22 +2195,14 @@ def correct_angle_by( self._obj.coords[angle_for_correction] = 0 self._obj.attrs[angle_for_correction] = 0 return - if ( - angle_for_correction == "beta" - and not self._obj.S.is_slit_vertical - and "psi" in self._obj.dims - ): + if angle_for_correction == "beta" and not self._obj.S.is_slit_vertical: self._obj.coords["psi"] = ( self._obj.coords["psi"] + self._obj.attrs[angle_for_correction] ) self._obj.coords[angle_for_correction] = 0 self._obj.attrs[angle_for_correction] = 0 return - if ( - angle_for_correction == "theta" - and self._obj.S.is_slit_vertical - and "psi" in self._obj.dims - ): + if angle_for_correction == "theta" and self._obj.S.is_slit_vertical: self._obj.coords["psi"] = ( self._obj.coords["psi"] + self._obj.attrs[angle_for_correction] ) @@ -2800,7 +2748,7 @@ def stride( self, *args: str | list[str] | tuple[str, ...], generic_dim_names: bool = True, - ) -> dict[str, float] | list[float] | float: + ) -> dict[Hashable, float] | list[float] | float: """Return the stride in each dimension. Note that the stride defined in this method is just a difference between first two values. @@ -2825,15 +2773,13 @@ def stride( result: dict[Hashable, float] = dict(zip(dim_names, indexed_strides, strict=True)) if args: - if len(args) == 1: - if not isinstance(args[0], str): # suppose args is list / tuple - result = [result[selected_names] for selected_names in args[0]] - else: - result = result[args[0]] - else: - # if passed several names as arguments - result = [result[selected_names] for selected_names in args] - + if isinstance(args[0], str): + return ( + result[args[0]] + if len(args) == 1 + else [result[str(selected_names)] for selected_names in args] + ) + return [result[selected_names] for selected_names in args[0]] return result def shift_by( # noqa: PLR0913 @@ -3130,7 +3076,15 @@ def __init__(self, xarray_obj: xr.DataArray) -> None: """ self._obj = xarray_obj - def plot_param(self, param_name: str, **kwargs: tuple[int, int] | RGBColorType) -> None: + class _PlotParamKwargs(MPLPlotKwargs, total=False): + + ax: Axes | None + shift: float + x_shift: float + two_sigma: bool + figsize: tuple[float, float] + + def plot_param(self, param_name: str, **kwargs: Unpack[_PlotParamKwargs]) -> None: """Creates a scatter plot of a parameter from a multidimensional curve fit. Args: diff --git a/tests/test_basic_data_loading.py b/tests/test_basic_data_loading.py index bfb931bd..7279c759 100644 --- a/tests/test_basic_data_loading.py +++ b/tests/test_basic_data_loading.py @@ -86,7 +86,7 @@ class TestMetadata: "pass_energy": np.nan, "slit_shape": None, "slit_width": np.nan, - "slit_number": None, + "slit_number": np.nan, "lens_table": None, "analyzer_type": "hemispherical", "mcp_voltage": np.nan, @@ -107,7 +107,7 @@ class TestMetadata: }, "frames_per_slice": 500, "frame_duration": np.nan, - "center_energy": None, + "center_energy": np.nan, }, "laser_info": { "pump_wavelength": np.nan, diff --git a/tests/test_xarray_extensions.py b/tests/test_xarray_extensions.py index 7ec5799d..18f90d5b 100644 --- a/tests/test_xarray_extensions.py +++ b/tests/test_xarray_extensions.py @@ -130,7 +130,13 @@ def test_spectrometer_setting(self, dataset_cut: xr.Dataset) -> None: def test_beamline_settings_reference_settings(self, dataset_cut: xr.Dataset) -> None: """Test for beamline settings.""" - assert dataset_cut.S.beamline_settings == dataset_cut.S.reference_settings == {"hv": 5.93} + assert dataset_cut.S.beamline_settings == { + "entrance_slit": np.nan, + "exit_slit": np.nan, + "hv": np.nan, + "grating": None, + } + assert dataset_cut.S.reference_settings == {"hv": 5.93} def test_full_coords(self, dataset_cut: xr.Dataset) -> None: """Test for full coords."""