diff --git a/arpes/_typing.py b/arpes/_typing.py
index acdf1dca..2778b4a8 100644
--- a/arpes/_typing.py
+++ b/arpes/_typing.py
@@ -262,6 +262,13 @@ class _BEAMLINEINFO(TypedDict, total=False):
monochrometer_info: dict[str, float]
+class BeamLineSettings(TypedDict, total=False):
+ exit_slit: float | str
+ entrance_slit: float | str
+ hv: float
+ grating: str | None
+
+
class LIGHTSOURCEINFO(_PROBEINFO, _PUMPINFO, _BEAMLINEINFO, total=False):
polarization: float | tuple[float, float] | str
photon_flux: float
diff --git a/arpes/analysis/resolution.py b/arpes/analysis/resolution.py
index b89151b8..fcf570e9 100644
--- a/arpes/analysis/resolution.py
+++ b/arpes/analysis/resolution.py
@@ -167,9 +167,9 @@ def analyzer_resolution_estimate(data: xr.DataArray, *, meV: bool = False) -> fl
def energy_resolution_from_beamline_slit(
- table: dict[str, str | float | None],
+ table: dict[tuple[float, tuple[float, float]], float],
photon_energy: float,
- exit_slit_size: str | float | tuple[float, float] | None,
+ exit_slit_size: tuple[float, float],
) -> float:
"""Calculates the energy resolution contribution from the beamline slits.
@@ -184,7 +184,9 @@ def energy_resolution_from_beamline_slit(
Returns:
The energy broadening in eV.
"""
- by_slits = {k[1]: v for k, v in table.items() if k[0] == photon_energy}
+ by_slits: dict[tuple[float, float], float] = {
+ k[1]: v for k, v in table.items() if k[0] == photon_energy
+ }
if exit_slit_size in by_slits:
return by_slits[exit_slit_size]
@@ -210,19 +212,23 @@ def energy_resolution_from_beamline_slit(
return by_area[low] + (by_area[high] - by_area[low]) * (slit_area - low) / (high - low)
-def beamline_resolution_estimate(data: xr.DataArray, *, meV: bool = False) -> None: # noqa: N803
+def beamline_resolution_estimate(
+ data: xr.DataArray,
+ *,
+ meV: bool = False, # noqa: N803
+) -> float:
data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
- resolution_table: dict[
- str,
- dict[tuple[float, tuple[float, float]], float],
- ] = ENDSTATIONS_BEAMLINE_RESOLUTION[data_array.S.endstation]
+ resolution_table: dict[str, dict[tuple[float, tuple[float, float]], float]] = (
+ ENDSTATIONS_BEAMLINE_RESOLUTION[data_array.S.endstation]
+ )
if isinstance(next(iter(resolution_table.keys())), str):
# need grating information
settings = data_array.S.beamline_settings
- resolution_table = resolution_table[settings["grating"]]
-
- all_keys = list(resolution_table.keys())
+ resolution_table_selected: dict[tuple[float, tuple[float, float]], float] = (
+ resolution_table[settings["grating"]]
+ )
+ all_keys = list(resolution_table_selected.keys())
hvs = {k[0] for k in all_keys}
low_hv = max(hv for hv in hvs if hv < settings["hv"])
@@ -232,9 +238,16 @@ def beamline_resolution_estimate(data: xr.DataArray, *, meV: bool = False) -> No
settings["entrance_slit"],
settings["exit_slit"],
)
- low_hv_res = energy_resolution_from_beamline_slit(resolution_table, low_hv, slit_size)
- high_hv_res = energy_resolution_from_beamline_slit(resolution_table, high_hv, slit_size)
-
+ low_hv_res = energy_resolution_from_beamline_slit(
+ resolution_table_selected,
+ low_hv,
+ slit_size,
+ )
+ high_hv_res = energy_resolution_from_beamline_slit(
+ resolution_table_selected,
+ high_hv,
+ slit_size,
+ )
# interpolate between nearest values
return low_hv_res + (high_hv_res - low_hv_res) * (settings["hv"] - low_hv) / (
high_hv - low_hv
@@ -249,7 +262,7 @@ def thermal_broadening_estimate(data: DataType, *, meV: bool = False) -> float:
def total_resolution_estimate(
- data: DataType,
+ data: xr.DataArray,
*,
include_thermal_broadening: bool = False,
meV: bool = False, # noqa: N803
@@ -262,7 +275,7 @@ def total_resolution_estimate(
Returns:
The estimated total resolution broadening.
"""
- thermal_broadening = 0
+ thermal_broadening = 0.0
if include_thermal_broadening:
thermal_broadening = thermal_broadening_estimate(data, meV=meV)
return math.sqrt(
diff --git a/arpes/plotting/dispersion.py b/arpes/plotting/dispersion.py
index fe37ef46..e046b20f 100644
--- a/arpes/plotting/dispersion.py
+++ b/arpes/plotting/dispersion.py
@@ -28,7 +28,7 @@
from matplotlib.figure import Figure, FigureBase
from numpy.typing import NDArray
- from arpes._typing import DataType, PColorMeshKwargs
+ from arpes._typing import PColorMeshKwargs, XrTypes
from arpes.models.band import Band
__all__ = [
@@ -304,8 +304,7 @@ def mask_for(x: NDArray[np.float_]) -> NDArray[np.float_]:
@save_plot_provenance
def hv_reference_scan(
- data: DataType,
- out: str | Path = "",
+ data: XrTypes,
e_cut: float = -0.05,
bkg_subtraction: float = 0.8,
**kwargs: Unpack[LabeledFermiSurfaceParam],
@@ -316,7 +315,11 @@ def hv_reference_scan(
fs.data -= bkg_subtraction * np.mean(fs.data)
fs.data[fs.data < 0] = 0
- _, ax = labeled_fermi_surface(fs, **kwargs)
+ out = kwargs.pop("out", None)
+
+ lfs = labeled_fermi_surface(fs, **kwargs)
+ assert isinstance(lfs, tuple)
+ _, ax = lfs
all_scans = data.attrs["df"]
all_scans = all_scans[all_scans.id != data.attrs["id"]]
@@ -330,7 +333,7 @@ def hv_reference_scan(
scans_by_hv[round(scan.S.hv)].append(scan.S.label.replace("_", " "))
- dim_order = ax.dim_order
+ dim_order = [ax.get_xlabel(), ax.get_ylabel()]
handles = []
handle_labels = []
@@ -357,19 +360,19 @@ def hv_reference_scan(
return path_for_plot(out)
plt.show()
- return None
+ return ax
class LabeledFermiSurfaceParam(TypedDict, total=False):
include_symmetry_points: bool
include_bz: bool
fermi_energy: float
+ out: str | Path
@save_plot_provenance
def reference_scan_fermi_surface(
- data: DataType,
- out: str | Path = "",
+ data: XrTypes,
**kwargs: Unpack[LabeledFermiSurfaceParam],
) -> Path | Axes:
"""A reference plot for Fermi surfaces. Used internally by other code.
@@ -377,14 +380,19 @@ def reference_scan_fermi_surface(
Warning: Not work correctly. (Because S.referenced_scans has been removed.)
"""
fs = data.S.fermi_surface
- _, ax = labeled_fermi_surface(fs, **kwargs)
+
+ out = kwargs.pop("out", None)
+ lfs = labeled_fermi_surface(fs, **kwargs)
+ assert isinstance(lfs, tuple)
+ _, ax = lfs
referenced_scans = data.S.referenced_scans
handles = []
for index, row in referenced_scans.iterrows():
scan = load_data(row.id)
remapped_coords = remap_coords_to(scan, data)
- dim_order = ax.dim_order
+
+ dim_order = [ax.get_xlabel(), ax.get_ylabel()]
ls = ax.plot(
remapped_coords[dim_order[0]],
remapped_coords[dim_order[1]],
@@ -447,7 +455,7 @@ def labeled_fermi_surface( # noqa: PLR0913
if include_symmetry_points:
for point_name, point_location in data.S.iter_symmetry_points:
warnings.warn("Symmetry point locations are not k-converted", stacklevel=2)
- coords = [point_location[d] for d in dim_order]
+ coords = tuple([point_location[d] for d in dim_order])
ax.plot(*coords, marker=".", color=marker_color)
ax.annotate(
label_for_symmetry_point(point_name),
diff --git a/arpes/plotting/dos.py b/arpes/plotting/dos.py
index 95e38807..a9184154 100644
--- a/arpes/plotting/dos.py
+++ b/arpes/plotting/dos.py
@@ -23,8 +23,6 @@
from matplotlib.colors import Normalize
from matplotlib.figure import Figure
- from arpes._typing import DataType
-
__all__ = (
"plot_dos",
"plot_core_levels",
@@ -33,7 +31,7 @@
@save_plot_provenance
def plot_core_levels( # noqa: PLR0913
- data: DataType,
+ data: xr.DataArray,
title: str = "",
out: str | Path = "",
norm: Normalize | None = None,
@@ -43,7 +41,9 @@ def plot_core_levels( # noqa: PLR0913
promenance: int = 5,
) -> Path | tuple[Axes, Colorbar]:
"""Plots an XPS curve and approximate core level locations."""
- _, axes, cbar = plot_dos(data=data, title=title, out="", norm=norm, dos_pow=dos_pow)
+ plotdos = plot_dos(data=data, title=title, out="", norm=norm, dos_pow=dos_pow)
+ assert isinstance(plotdos, tuple)
+ _, axes, cbar = plotdos
if core_levels is None:
core_levels = approximate_core_levels(data, binning=binning, promenance=promenance)
@@ -64,7 +64,7 @@ def plot_dos(
out: str | Path = "",
norm: Normalize | None = None,
dos_pow: float = 1,
-) -> Path | tuple[Figure, Axes, Colorbar]:
+) -> Path | tuple[Figure, tuple[Axes], Colorbar]:
"""Plots the density of states (momentum integrated) image next to the original spectrum."""
data_arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
diff --git a/arpes/plotting/fermi_edge.py b/arpes/plotting/fermi_edge.py
index 8f0b6154..c22b5ccd 100644
--- a/arpes/plotting/fermi_edge.py
+++ b/arpes/plotting/fermi_edge.py
@@ -110,7 +110,7 @@ def fermi_edge_reference(
data: xr.DataArray,
title: str = "",
ax: Axes | None = None,
- out: str = "",
+ out: str | Path = "",
**kwargs: Incomplete,
) -> Path | Axes:
"""Fits for and plots results for the Fermi edge on a piece of data.
diff --git a/arpes/utilities/bz.py b/arpes/utilities/bz.py
index 6d3460dc..a85c5613 100644
--- a/arpes/utilities/bz.py
+++ b/arpes/utilities/bz.py
@@ -461,7 +461,7 @@ def reduced_bz_axis_to(
point_names = _POINT_NAMES_FOR_SYMMETRY[symmetry]
symmetry_points, _ = data.S.symmetry_points()
- points = {k: v[0] for k, v in symmetry_points.items() if k in point_names}
+ points = {k: v for k, v in symmetry_points.items() if k in point_names}
coords_by_point = {
k: np.array([v.get(d, 0) for d in data.dims if d in v or include_E and d == "eV"])
@@ -500,7 +500,7 @@ def reduced_bz_axes(data: XrTypes) -> tuple[NDArray[np.float_], NDArray[np.float
point_names = _POINT_NAMES_FOR_SYMMETRY[symmetry]
symmetry_points, _ = data.S.symmetry_points()
- points = {k: v[0] for k, v in symmetry_points.items() if k in point_names}
+ points = {k: v for k, v in symmetry_points.items() if k in point_names}
coords_by_point = {k: np.array([v[d] for d in data.dims if d in v]) for k, v in points.items()}
if symmetry == "rect":
@@ -537,7 +537,7 @@ def axis_along(data: XrTypes, symbol: str) -> float:
point_names = _POINT_NAMES_FOR_SYMMETRY[symmetry]
symmetry_points, _ = data.S.symmetry_points()
- points = {k: v[0] for k, v in symmetry_points.items() if k in point_names}
+ points = {k: v for k, v in symmetry_points.items() if k in point_names}
coords_by_point = {k: np.array([v[d] for d in data.dims if d in v]) for k, v in points.items()}
@@ -576,7 +576,7 @@ def reduced_bz_poly(data: XrTypes, *, scale_zone: bool = False) -> NDArray[np.fl
dy = 3 * dy
symmetry_points, _ = data.S.symmetry_points()
- points = {k: v[0] for k, v in symmetry_points.items() if k in point_names}
+ points = {k: v for k, v in symmetry_points.items() if k in point_names}
coords_by_point = {
k: np.array([v.get(d, 0) for d in data.dims if d in v]) for k, v in points.items()
}
@@ -625,7 +625,7 @@ def reduced_bz_E_mask(
point_names = _POINT_NAMES_FOR_SYMMETRY[symmetry]
symmetry_points, _ = data.S.symmetry_points()
- points = {k: v[0] for k, v in symmetry_points.items() if k in point_names}
+ points = {k: v for k, v in symmetry_points.items() if k in point_names}
coords_by_point = {
k: np.array([v.get(d, 0) for d in data.dims if d in v or d == "eV"])
for k, v in points.items()
diff --git a/arpes/utilities/conversion/core.py b/arpes/utilities/conversion/core.py
index 9e22612e..5e768762 100644
--- a/arpes/utilities/conversion/core.py
+++ b/arpes/utilities/conversion/core.py
@@ -31,7 +31,6 @@
from typing import TYPE_CHECKING, Literal
import numpy as np
-from numpy.typing import ArrayLike
import xarray as xr
from scipy.interpolate import RegularGridInterpolator
diff --git a/arpes/xarray_extensions.py b/arpes/xarray_extensions.py
index badc07b8..e47f76d0 100644
--- a/arpes/xarray_extensions.py
+++ b/arpes/xarray_extensions.py
@@ -39,16 +39,24 @@
from __future__ import annotations
-import collections
import contextlib
import copy
import itertools
import warnings
-from collections import OrderedDict, defaultdict
+from collections import OrderedDict
from collections.abc import Collection, Hashable, Mapping, Sequence
from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Literal, Self, TypeAlias, Unpack
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Literal,
+ Self,
+ TypeAlias,
+ TypedDict,
+ TypeGuard,
+ Unpack,
+)
import matplotlib.pyplot as plt
import numpy as np
@@ -60,6 +68,7 @@
import arpes.utilities.math
from arpes.constants import TWO_DIMENSION
+from ._typing import MPLPlotKwargs
from .analysis import param_getter, param_stderr_getter, rebin
from .models.band import MultifitBand
from .plotting.dispersion import (
@@ -102,6 +111,7 @@
SAMPLEINFO,
SCANINFO,
SPECTROMETER,
+ BeamLineSettings,
DataType,
PColorMeshKwargs,
XrTypes,
@@ -144,7 +154,7 @@
def _iter_groups(
- grouped: dict[str, Sequence[float] | float],
+ grouped: dict[str, NDArray[np.float_] | Sequence[float] | float],
) -> Iterator[tuple[str, float]]:
"""Iterates through a flattened sequence.
@@ -157,7 +167,7 @@ def _iter_groups(
ToDo: Not tested
"""
for k, value_or_list in grouped.items():
- if isinstance(value_or_list, Sequence):
+ if isinstance(value_or_list, Sequence | np.ndarray):
for list_item in value_or_list:
yield k, list_item
else:
@@ -167,7 +177,20 @@ def _iter_groups(
class ARPESAccessorBase:
"""Base class for the xarray extensions in PyARPES."""
- def along(self, directions: NDArray[np.float_], **kwargs: Incomplete) -> xr.DataArray:
+ class _SliceAlongPathKwags(TypedDict, total=False):
+ arr: xr.DataArray
+ interpolation_points: NDArray[np.float_] | None
+ axis_name: str
+ resolution: float
+ n_points: int | None
+ extend_to_edge: bool
+ shift_gamma: bool
+
+ def along(
+ self,
+ directions: NDArray[np.float_],
+ **kwargs: Unpack[_SliceAlongPathKwags],
+ ) -> xr.Dataset:
"""TODO: Need description.
ToDo: Test
@@ -394,7 +417,17 @@ def spectrum_type(self) -> Literal["cut", "map", "hv_map", "ucut", "spem", "xps"
("eV", "kp", "kz"): "hv_map",
}
dims: tuple = tuple(sorted(self._obj.dims))
- return dim_types.get(dims)
+ dim_type = dim_types.get(dims)
+
+ def _dim_type_check(
+ dim_type: str | None,
+ ) -> TypeGuard[Literal["cut", "map", "hv_map", "ucut", "spem", "xps"]]:
+ return dim_type in ("cut", "map", "hv_map", "ucut", "spem", "xps")
+
+ if _dim_type_check(dim_type):
+ return dim_type
+ msg = "Dimension type may be incorrect"
+ raise TypeError(msg)
@property
def is_differentiated(self) -> bool:
@@ -441,8 +474,8 @@ def transpose_to_back(self, dim: str) -> XrTypes:
def select_around_data(
self,
- points: dict[str, xr.DataArray] | xr.Dataset | dict | tuple[float, ...] | list[float],
- radius: dict[str, float] | float | None = None, # radius={"phi": 0.005}
+ points: dict[Hashable, xr.DataArray],
+ radius: dict[Hashable, float] | float | None = None, # radius={"phi": 0.005}
*,
mode: Literal["sum", "mean"] = "sum",
**kwargs: Incomplete,
@@ -485,6 +518,7 @@ def select_around_data(
points = dict(zip(points, self._obj.dims, strict=True))
if isinstance(points, xr.Dataset):
points = {str(k): points[k].item() for k in points.data_vars}
+
radius = self._radius(points, radius, **kwargs)
assert isinstance(radius, dict)
@@ -536,7 +570,7 @@ def select_around_data(
def select_around(
self,
points: dict[Hashable, float] | xr.Dataset,
- radius: dict[Hashable, float] | float | None = None,
+ radius: dict[Hashable, float] | float,
*,
mode: Literal["sum", "mean"] = "sum",
**kwargs: Incomplete,
@@ -577,6 +611,7 @@ def select_around(
if isinstance(points, xr.Dataset):
points = {k: points[k].item() for k in points.data_vars}
logger.debug(f"points: {points}")
+ assert isinstance(points, dict)
radius = self._radius(points, radius, **kwargs)
logger.debug(f"radius: {radius}")
nearest_sel_params = {}
@@ -609,7 +644,7 @@ def select_around(
def _radius(
self,
points: dict[Hashable, float],
- radius: float | dict[Hashable, float] | None,
+ radius: float | dict[Hashable, float],
**kwargs: float,
) -> dict[Hashable, float]:
"""Helper function. Generate radius dict.
@@ -643,73 +678,15 @@ def short_history(self, key: str = "by") -> list:
Args:
key (str): [TODO:description]
"""
- return [h["record"][key] if isinstance(h, dict) else h for h in self.history]
-
- def _calculate_symmetry_points(
- self,
- symmetry_points: dict[str, list[dict[str, float]]] | dict[str, dict[str, float]],
- epsilon: float = 0.01,
- ) -> tuple[defaultdict[str, list[dict[str, float]]], defaultdict[str, list[dict[str, float]]]]:
- # For each symmetry point, we need to determine if it is projected or not
- # if it is projected, we need to calculate its projected coordinates
- """[TODO:summary].
-
- Args:
- symmetry_points: [TODO:description]
- epsilon: [TODO:description]
- """
- points = collections.defaultdict(list)
- projected_points = collections.defaultdict(list)
-
- fixed_coords = {k: v for k, v in self._obj.coords.items() if k not in self._obj.indexes}
- index_coords = self._obj.indexes
-
- for point, locations in symmetry_points.items():
- location_list = locations if isinstance(locations, list) else [locations]
- for location in location_list:
- # determine whether the location needs to be projected
- projected = False
- skip = False
- for axis_name, value in location.items():
- if (
- axis_name in fixed_coords
- and np.abs(value - fixed_coords[axis_name]) > epsilon
- ):
- projected = True
- if axis_name not in fixed_coords and axis_name not in index_coords:
- # cannot even hope to do anything here, we don't have enough info
- skip = True
- if skip:
- continue
- location.copy() # <== CHECK ME! Original: new_location = location.copy()
- if projected:
- # Go and do the projection, for now we will assume we just get it by
- # replacing the value of the mismatched coordinates.
- # This does not work if the coordinate system is not orthogonal
- for axis in location:
- if axis in fixed_coords:
- fixed_coords[axis]
- # <== CHECK ME! Original new_locationn = fixed_coords[axis]
- projected_points[point].append(location)
- else:
- points[point].append(location)
-
- return points, projected_points
+ return [h["record"][key] if isinstance(h, dict) else h for h in self.history] # type: ignore[literal-required]
def symmetry_points(
self,
- *,
- raw: bool = False,
- **kwargs: float,
- ) -> (
- dict[str, dict[str, float]]
- | tuple[defaultdict[str, list[dict[str, float]]], defaultdict[str, list[dict[str, float]]]]
- ):
- """[TODO:summary].
+ ) -> dict[str, dict[str, float]]:
+ """Return the dict object about symmetry point such as G-point in the ARPES data.
- Args:
- raw (bool): [TODO:description]
- kwargs: pass to _calculate_symmetry_points (epsilon)
+ The original version was something complicated, but the coding seemed to be in
+ process and the purpose was unclear, so it was streamlined considerably.
"""
symmetry_points: dict[str, dict[str, float]] = {}
# An example of "symmetry_points": symmetry_points = {"G": {"phi": 0.405}}
@@ -717,26 +694,13 @@ def symmetry_points(
symmetry_points.update(our_symmetry_points)
- if raw:
- return symmetry_points
-
- return self._calculate_symmetry_points(symmetry_points, **kwargs)
+ return symmetry_points
@property
def iter_own_symmetry_points(self) -> Iterator[tuple[str, float]]:
- sym_points, _ = self.symmetry_points()
- return _iter_groups(sym_points)
-
- @property
- def iter_projected_symmetry_points(self) -> Iterator[tuple[str, float]]:
- _, sym_points = self.symmetry_points()
+ sym_points = self.symmetry_points()
return _iter_groups(sym_points)
- @property
- def iter_symmetry_points(self) -> Iterator[tuple[str, float]]:
- yield from self.iter_own_symmetry_points
- yield from self.iter_projected_symmetry_points
-
@property
def history(self) -> list[PROVENANCE | None]:
provenance_recorded = self._obj.attrs.get("provenance", None)
@@ -861,17 +825,18 @@ def lookup_coord(self, name: str) -> xr.DataArray | float:
raise ValueError(msg)
def lookup_offset(self, attr_name: str) -> float:
- symmetry_points = self.symmetry_points(raw=True)
+ symmetry_points = self.symmetry_points()
+ assert isinstance(symmetry_points, dict)
if "G" in symmetry_points:
- gamma_point = symmetry_points["G"]
+ gamma_point = symmetry_points["G"] # {"phi": 0.405} (cut)
if attr_name in gamma_point:
- return unwrap_xarray_item(gamma_point[attr_name])
+ return gamma_point[attr_name]
offset_name = attr_name + "_offset"
if offset_name in self._obj.attrs:
- return unwrap_xarray_item(self._obj.attrs[offset_name])
+ return self._obj.attrs[offset_name]
- return unwrap_xarray_item(self._obj.attrs.get("data_preparation", {}).get(offset_name, 0))
+ return self._obj.attrs.get("data_preparation", {}).get(offset_name, 0)
@property
def beta_offset(self) -> float:
@@ -1307,26 +1272,15 @@ def reference_settings(self) -> dict[str, Any]:
return settings
@property
- def beamline_settings(self) -> dict[str, Any]:
- find_keys = {
- "entrance_slit": {
- "entrance_slit",
- },
- "exit_slit": {
- "exit_slit",
- },
- "hv": {
- "hv",
- "photon_energy",
- },
- "grating": {},
- }
- settings = {}
- for key, options in find_keys.items():
- for option in options:
- if option in self._obj.attrs:
- settings[key] = self._obj.attrs[option]
- break
+ def beamline_settings(self) -> BeamLineSettings:
+ settings: BeamLineSettings = {}
+ settings["entrance_slit"] = self._obj.attrs.get("entrance_slit", np.nan)
+ settings["exit_slit"] = self._obj.attrs.get("exit_slit", np.nan)
+ settings["hv"] = self._obj.attrs.get(
+ "exit_slit",
+ self._obj.attrs.get("photon_energy", np.nan),
+ )
+ settings["grating"] = self._obj.attrs.get("grating", None)
return settings
@@ -1456,8 +1410,8 @@ def sample_info(self) -> SAMPLEINFO:
@property
def scan_info(self) -> SCANINFO:
scan_info: SCANINFO = {
- "time": self._obj.attrs.get("time"),
- "date": self._obj.attrs.get("date"),
+ "time": self._obj.attrs.get("time", None),
+ "date": self._obj.attrs.get("date", None),
"type": self.scan_type,
"spectrum_type": self.spectrum_type,
"experimenter": self._obj.attrs.get("experimenter"),
@@ -1537,11 +1491,11 @@ def analyzer_info(self) -> ANALYZERINFO:
analyzer_info: ANALYZERINFO = {
"lens_mode": self._obj.attrs.get("lens_mode"),
"lens_mode_name": self._obj.attrs.get("lens_mode_name"),
- "acquisition_mode": self._obj.attrs.get("acquisition_mode"),
+ "acquisition_mode": self._obj.attrs.get("acquisition_mode", None),
"pass_energy": self._obj.attrs.get("pass_energy", np.nan),
- "slit_shape": self._obj.attrs.get("slit_shape"),
+ "slit_shape": self._obj.attrs.get("slit_shape", None),
"slit_width": self._obj.attrs.get("slit_width", np.nan),
- "slit_number": self._obj.attrs.get("slit_number"),
+ "slit_number": self._obj.attrs.get("slit_number", np.nan),
"lens_table": self._obj.attrs.get("lens_table"),
"analyzer_type": self._obj.attrs.get("analyzer_type"),
"mcp_voltage": self._obj.attrs.get("mcp_voltage", np.nan),
@@ -1556,7 +1510,7 @@ def daq_info(self) -> DAQINFO:
"daq_type": self._obj.attrs.get("daq_type"),
"region": self._obj.attrs.get("daq_region"),
"region_name": self._obj.attrs.get("daq_region_name"),
- "center_energy": self._obj.attrs.get("daq_center_energy"),
+ "center_energy": self._obj.attrs.get("daq_center_energy", np.nan),
"prebinning": self.prebinning,
"trapezoidal_correction_strategy": self._obj.attrs.get(
"trapezoidal_correction_strategy",
@@ -1578,8 +1532,8 @@ def beamline_info(self) -> LIGHTSOURCEINFO:
"undulator_info": self.undulator_info,
"repetition_rate": self._obj.attrs.get("repetition_rate", np.nan),
"beam_current": self._obj.attrs.get("beam_current", np.nan),
- "entrance_slit": self._obj.attrs.get("entrance_slit"),
- "exit_slit": self._obj.attrs.get("exit_slit"),
+ "entrance_slit": self._obj.attrs.get("entrance_slit", None),
+ "exit_slit": self._obj.attrs.get("exit_slit", None),
"monochromator_info": self.monochromator_info,
}
return beamline_info
@@ -1780,9 +1734,9 @@ def _experimentalinfo_to_dict(conditions: EXPERIMENTINFO) -> dict[str, str]:
if isinstance(v, xr.DataArray):
min_hv = float(v.min())
max_hv = float(v.max())
- transformed_dict[
- k
- ] = f" from {min_hv} to {max_hv} eV"
+ transformed_dict[k] = (
+ f" from {min_hv} to {max_hv} eV"
+ )
elif isinstance(v, float) and not np.isnan(v):
transformed_dict[k] = f"{v} eV"
return transformed_dict
@@ -1958,20 +1912,21 @@ def show(self: Self, *, detached: bool = False, **kwargs: Incomplete) -> None:
def fs_plot(
self: Self,
pattern: str = "{}.png",
- **kwargs: Incomplete,
- ) -> Path | None | tuple[Figure, Axes]:
+ **kwargs: Unpack[LabeledFermiSurfaceParam],
+ ) -> Path | tuple[Figure | None, Axes]:
"""Provides a reference plot of the approximate Fermi surface."""
out = kwargs.get("out")
if out is not None and isinstance(out, bool):
out = pattern.format(f"{self.label}_fs")
kwargs["out"] = out
+ assert isinstance(self._obj, xr.DataArray)
return labeled_fermi_surface(self._obj, **kwargs)
def fermi_edge_reference_plot(
self: Self,
pattern: str = "{}.png",
**kwargs: str | Normalize | None,
- ) -> Path | None:
+ ) -> Path | Axes:
"""Provides a reference plot for a Fermi edge reference.
Args:
@@ -1985,7 +1940,7 @@ def fermi_edge_reference_plot(
if out is not None and isinstance(out, bool):
out = pattern.format(f"{self.label}_fermi_edge_reference")
kwargs["out"] = out
-
+ assert isinstance(self._obj, xr.DataArray)
return fermi_edge_reference(self._obj, **kwargs)
def _referenced_scans_for_spatial_plot(
@@ -1993,7 +1948,7 @@ def _referenced_scans_for_spatial_plot(
*,
use_id: bool = True,
pattern: str = "{}.png",
- out: str | bool = "",
+ out: str | Path = "",
) -> Path | tuple[Figure, NDArray[np.object_]]:
"""[TODO:summary].
@@ -2037,8 +1992,8 @@ def _referenced_scans_for_hv_map_plot(
pattern: str = "{}.png",
*,
use_id: bool = True,
- **kwargs: IncompleteMPL,
- ) -> Path | None:
+ **kwargs: Unpack[HvRefScanParam],
+ ) -> Path | Axes:
out = kwargs.get("out")
label = self._obj.attrs["id"] if use_id else self.label
if out is not None and isinstance(out, bool):
@@ -2232,6 +2187,7 @@ def correct_angle_by(
)
self._obj.attrs[angle_for_correction] = 0
return
+ #
if angle_for_correction == "beta" and self._obj.S.is_slit_vertical:
self._obj.coords["phi"] = (
self._obj.coords["phi"] + self._obj.attrs[angle_for_correction]
@@ -2239,22 +2195,14 @@ def correct_angle_by(
self._obj.coords[angle_for_correction] = 0
self._obj.attrs[angle_for_correction] = 0
return
- if (
- angle_for_correction == "beta"
- and not self._obj.S.is_slit_vertical
- and "psi" in self._obj.dims
- ):
+ if angle_for_correction == "beta" and not self._obj.S.is_slit_vertical:
self._obj.coords["psi"] = (
self._obj.coords["psi"] + self._obj.attrs[angle_for_correction]
)
self._obj.coords[angle_for_correction] = 0
self._obj.attrs[angle_for_correction] = 0
return
- if (
- angle_for_correction == "theta"
- and self._obj.S.is_slit_vertical
- and "psi" in self._obj.dims
- ):
+ if angle_for_correction == "theta" and self._obj.S.is_slit_vertical:
self._obj.coords["psi"] = (
self._obj.coords["psi"] + self._obj.attrs[angle_for_correction]
)
@@ -2800,7 +2748,7 @@ def stride(
self,
*args: str | list[str] | tuple[str, ...],
generic_dim_names: bool = True,
- ) -> dict[str, float] | list[float] | float:
+ ) -> dict[Hashable, float] | list[float] | float:
"""Return the stride in each dimension.
Note that the stride defined in this method is just a difference between first two values.
@@ -2825,15 +2773,13 @@ def stride(
result: dict[Hashable, float] = dict(zip(dim_names, indexed_strides, strict=True))
if args:
- if len(args) == 1:
- if not isinstance(args[0], str): # suppose args is list / tuple
- result = [result[selected_names] for selected_names in args[0]]
- else:
- result = result[args[0]]
- else:
- # if passed several names as arguments
- result = [result[selected_names] for selected_names in args]
-
+ if isinstance(args[0], str):
+ return (
+ result[args[0]]
+ if len(args) == 1
+ else [result[str(selected_names)] for selected_names in args]
+ )
+ return [result[selected_names] for selected_names in args[0]]
return result
def shift_by( # noqa: PLR0913
@@ -3130,7 +3076,15 @@ def __init__(self, xarray_obj: xr.DataArray) -> None:
"""
self._obj = xarray_obj
- def plot_param(self, param_name: str, **kwargs: tuple[int, int] | RGBColorType) -> None:
+ class _PlotParamKwargs(MPLPlotKwargs, total=False):
+
+ ax: Axes | None
+ shift: float
+ x_shift: float
+ two_sigma: bool
+ figsize: tuple[float, float]
+
+ def plot_param(self, param_name: str, **kwargs: Unpack[_PlotParamKwargs]) -> None:
"""Creates a scatter plot of a parameter from a multidimensional curve fit.
Args:
diff --git a/tests/test_basic_data_loading.py b/tests/test_basic_data_loading.py
index bfb931bd..7279c759 100644
--- a/tests/test_basic_data_loading.py
+++ b/tests/test_basic_data_loading.py
@@ -86,7 +86,7 @@ class TestMetadata:
"pass_energy": np.nan,
"slit_shape": None,
"slit_width": np.nan,
- "slit_number": None,
+ "slit_number": np.nan,
"lens_table": None,
"analyzer_type": "hemispherical",
"mcp_voltage": np.nan,
@@ -107,7 +107,7 @@ class TestMetadata:
},
"frames_per_slice": 500,
"frame_duration": np.nan,
- "center_energy": None,
+ "center_energy": np.nan,
},
"laser_info": {
"pump_wavelength": np.nan,
diff --git a/tests/test_xarray_extensions.py b/tests/test_xarray_extensions.py
index 7ec5799d..18f90d5b 100644
--- a/tests/test_xarray_extensions.py
+++ b/tests/test_xarray_extensions.py
@@ -130,7 +130,13 @@ def test_spectrometer_setting(self, dataset_cut: xr.Dataset) -> None:
def test_beamline_settings_reference_settings(self, dataset_cut: xr.Dataset) -> None:
"""Test for beamline settings."""
- assert dataset_cut.S.beamline_settings == dataset_cut.S.reference_settings == {"hv": 5.93}
+ assert dataset_cut.S.beamline_settings == {
+ "entrance_slit": np.nan,
+ "exit_slit": np.nan,
+ "hv": np.nan,
+ "grating": None,
+ }
+ assert dataset_cut.S.reference_settings == {"hv": 5.93}
def test_full_coords(self, dataset_cut: xr.Dataset) -> None:
"""Test for full coords."""