Skip to content

Commit

Permalink
🔥 Be slim: S.symmetry_points just refers attrs.
Browse files Browse the repository at this point in the history
Consider the projected symmetry points etc after if needed.
(At least, the current algorithm cannot be be used without any
concern.)

* and update type hints, as usual.
  • Loading branch information
arafune committed Feb 11, 2024
1 parent 617cc57 commit 3d502b4
Show file tree
Hide file tree
Showing 10 changed files with 182 additions and 195 deletions.
7 changes: 7 additions & 0 deletions arpes/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,13 @@ class _BEAMLINEINFO(TypedDict, total=False):
monochrometer_info: dict[str, float]


class BeamLineSettings(TypedDict, total=False):
exit_slit: float | str
entrance_slit: float | str
hv: float
grating: str | None


class LIGHTSOURCEINFO(_PROBEINFO, _PUMPINFO, _BEAMLINEINFO, total=False):
polarization: float | tuple[float, float] | str
photon_flux: float
Expand Down
45 changes: 29 additions & 16 deletions arpes/analysis/resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ def analyzer_resolution_estimate(data: xr.DataArray, *, meV: bool = False) -> fl


def energy_resolution_from_beamline_slit(
table: dict[str, str | float | None],
table: dict[tuple[float, tuple[float, float]], float],
photon_energy: float,
exit_slit_size: str | float | tuple[float, float] | None,
exit_slit_size: tuple[float, float],
) -> float:
"""Calculates the energy resolution contribution from the beamline slits.
Expand All @@ -184,7 +184,9 @@ def energy_resolution_from_beamline_slit(
Returns:
The energy broadening in eV.
"""
by_slits = {k[1]: v for k, v in table.items() if k[0] == photon_energy}
by_slits: dict[tuple[float, float], float] = {
k[1]: v for k, v in table.items() if k[0] == photon_energy
}
if exit_slit_size in by_slits:
return by_slits[exit_slit_size]

Expand All @@ -210,19 +212,23 @@ def energy_resolution_from_beamline_slit(
return by_area[low] + (by_area[high] - by_area[low]) * (slit_area - low) / (high - low)


def beamline_resolution_estimate(data: xr.DataArray, *, meV: bool = False) -> None: # noqa: N803
def beamline_resolution_estimate(
data: xr.DataArray,
*,
meV: bool = False, # noqa: N803
) -> float:
data_array = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
resolution_table: dict[
str,
dict[tuple[float, tuple[float, float]], float],
] = ENDSTATIONS_BEAMLINE_RESOLUTION[data_array.S.endstation]
resolution_table: dict[str, dict[tuple[float, tuple[float, float]], float]] = (
ENDSTATIONS_BEAMLINE_RESOLUTION[data_array.S.endstation]
)

if isinstance(next(iter(resolution_table.keys())), str):
# need grating information
settings = data_array.S.beamline_settings
resolution_table = resolution_table[settings["grating"]]

all_keys = list(resolution_table.keys())
resolution_table_selected: dict[tuple[float, tuple[float, float]], float] = (
resolution_table[settings["grating"]]
)
all_keys = list(resolution_table_selected.keys())
hvs = {k[0] for k in all_keys}

low_hv = max(hv for hv in hvs if hv < settings["hv"])
Expand All @@ -232,9 +238,16 @@ def beamline_resolution_estimate(data: xr.DataArray, *, meV: bool = False) -> No
settings["entrance_slit"],
settings["exit_slit"],
)
low_hv_res = energy_resolution_from_beamline_slit(resolution_table, low_hv, slit_size)
high_hv_res = energy_resolution_from_beamline_slit(resolution_table, high_hv, slit_size)

low_hv_res = energy_resolution_from_beamline_slit(
resolution_table_selected,
low_hv,
slit_size,
)
high_hv_res = energy_resolution_from_beamline_slit(
resolution_table_selected,
high_hv,
slit_size,
)
# interpolate between nearest values
return low_hv_res + (high_hv_res - low_hv_res) * (settings["hv"] - low_hv) / (
high_hv - low_hv
Expand All @@ -249,7 +262,7 @@ def thermal_broadening_estimate(data: DataType, *, meV: bool = False) -> float:


def total_resolution_estimate(
data: DataType,
data: xr.DataArray,
*,
include_thermal_broadening: bool = False,
meV: bool = False, # noqa: N803
Expand All @@ -262,7 +275,7 @@ def total_resolution_estimate(
Returns:
The estimated total resolution broadening.
"""
thermal_broadening = 0
thermal_broadening = 0.0
if include_thermal_broadening:
thermal_broadening = thermal_broadening_estimate(data, meV=meV)
return math.sqrt(
Expand Down
30 changes: 19 additions & 11 deletions arpes/plotting/dispersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from matplotlib.figure import Figure, FigureBase
from numpy.typing import NDArray

from arpes._typing import DataType, PColorMeshKwargs
from arpes._typing import PColorMeshKwargs, XrTypes
from arpes.models.band import Band

__all__ = [
Expand Down Expand Up @@ -304,8 +304,7 @@ def mask_for(x: NDArray[np.float_]) -> NDArray[np.float_]:

@save_plot_provenance
def hv_reference_scan(
data: DataType,
out: str | Path = "",
data: XrTypes,
e_cut: float = -0.05,
bkg_subtraction: float = 0.8,
**kwargs: Unpack[LabeledFermiSurfaceParam],
Expand All @@ -316,7 +315,11 @@ def hv_reference_scan(
fs.data -= bkg_subtraction * np.mean(fs.data)
fs.data[fs.data < 0] = 0

_, ax = labeled_fermi_surface(fs, **kwargs)
out = kwargs.pop("out", None)

lfs = labeled_fermi_surface(fs, **kwargs)
assert isinstance(lfs, tuple)
_, ax = lfs

all_scans = data.attrs["df"]
all_scans = all_scans[all_scans.id != data.attrs["id"]]
Expand All @@ -330,7 +333,7 @@ def hv_reference_scan(

scans_by_hv[round(scan.S.hv)].append(scan.S.label.replace("_", " "))

dim_order = ax.dim_order
dim_order = [ax.get_xlabel(), ax.get_ylabel()]
handles = []
handle_labels = []

Expand All @@ -357,34 +360,39 @@ def hv_reference_scan(
return path_for_plot(out)

plt.show()
return None
return ax


class LabeledFermiSurfaceParam(TypedDict, total=False):
include_symmetry_points: bool
include_bz: bool
fermi_energy: float
out: str | Path


@save_plot_provenance
def reference_scan_fermi_surface(
data: DataType,
out: str | Path = "",
data: XrTypes,
**kwargs: Unpack[LabeledFermiSurfaceParam],
) -> Path | Axes:
"""A reference plot for Fermi surfaces. Used internally by other code.
Warning: Not work correctly. (Because S.referenced_scans has been removed.)
"""
fs = data.S.fermi_surface
_, ax = labeled_fermi_surface(fs, **kwargs)

out = kwargs.pop("out", None)
lfs = labeled_fermi_surface(fs, **kwargs)
assert isinstance(lfs, tuple)
_, ax = lfs

referenced_scans = data.S.referenced_scans
handles = []
for index, row in referenced_scans.iterrows():
scan = load_data(row.id)
remapped_coords = remap_coords_to(scan, data)
dim_order = ax.dim_order

dim_order = [ax.get_xlabel(), ax.get_ylabel()]
ls = ax.plot(
remapped_coords[dim_order[0]],
remapped_coords[dim_order[1]],
Expand Down Expand Up @@ -447,7 +455,7 @@ def labeled_fermi_surface( # noqa: PLR0913
if include_symmetry_points:
for point_name, point_location in data.S.iter_symmetry_points:
warnings.warn("Symmetry point locations are not k-converted", stacklevel=2)
coords = [point_location[d] for d in dim_order]
coords = tuple([point_location[d] for d in dim_order])
ax.plot(*coords, marker=".", color=marker_color)
ax.annotate(
label_for_symmetry_point(point_name),
Expand Down
10 changes: 5 additions & 5 deletions arpes/plotting/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
from matplotlib.colors import Normalize
from matplotlib.figure import Figure

from arpes._typing import DataType

__all__ = (
"plot_dos",
"plot_core_levels",
Expand All @@ -33,7 +31,7 @@

@save_plot_provenance
def plot_core_levels( # noqa: PLR0913
data: DataType,
data: xr.DataArray,
title: str = "",
out: str | Path = "",
norm: Normalize | None = None,
Expand All @@ -43,7 +41,9 @@ def plot_core_levels( # noqa: PLR0913
promenance: int = 5,
) -> Path | tuple[Axes, Colorbar]:
"""Plots an XPS curve and approximate core level locations."""
_, axes, cbar = plot_dos(data=data, title=title, out="", norm=norm, dos_pow=dos_pow)
plotdos = plot_dos(data=data, title=title, out="", norm=norm, dos_pow=dos_pow)
assert isinstance(plotdos, tuple)
_, axes, cbar = plotdos

if core_levels is None:
core_levels = approximate_core_levels(data, binning=binning, promenance=promenance)
Expand All @@ -64,7 +64,7 @@ def plot_dos(
out: str | Path = "",
norm: Normalize | None = None,
dos_pow: float = 1,
) -> Path | tuple[Figure, Axes, Colorbar]:
) -> Path | tuple[Figure, tuple[Axes], Colorbar]:
"""Plots the density of states (momentum integrated) image next to the original spectrum."""
data_arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)

Expand Down
2 changes: 1 addition & 1 deletion arpes/plotting/fermi_edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def fermi_edge_reference(
data: xr.DataArray,
title: str = "",
ax: Axes | None = None,
out: str = "",
out: str | Path = "",
**kwargs: Incomplete,
) -> Path | Axes:
"""Fits for and plots results for the Fermi edge on a piece of data.
Expand Down
10 changes: 5 additions & 5 deletions arpes/utilities/bz.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def reduced_bz_axis_to(
point_names = _POINT_NAMES_FOR_SYMMETRY[symmetry]

symmetry_points, _ = data.S.symmetry_points()
points = {k: v[0] for k, v in symmetry_points.items() if k in point_names}
points = {k: v for k, v in symmetry_points.items() if k in point_names}

coords_by_point = {
k: np.array([v.get(d, 0) for d in data.dims if d in v or include_E and d == "eV"])
Expand Down Expand Up @@ -500,7 +500,7 @@ def reduced_bz_axes(data: XrTypes) -> tuple[NDArray[np.float_], NDArray[np.float
point_names = _POINT_NAMES_FOR_SYMMETRY[symmetry]

symmetry_points, _ = data.S.symmetry_points()
points = {k: v[0] for k, v in symmetry_points.items() if k in point_names}
points = {k: v for k, v in symmetry_points.items() if k in point_names}

coords_by_point = {k: np.array([v[d] for d in data.dims if d in v]) for k, v in points.items()}
if symmetry == "rect":
Expand Down Expand Up @@ -537,7 +537,7 @@ def axis_along(data: XrTypes, symbol: str) -> float:
point_names = _POINT_NAMES_FOR_SYMMETRY[symmetry]

symmetry_points, _ = data.S.symmetry_points()
points = {k: v[0] for k, v in symmetry_points.items() if k in point_names}
points = {k: v for k, v in symmetry_points.items() if k in point_names}

coords_by_point = {k: np.array([v[d] for d in data.dims if d in v]) for k, v in points.items()}

Expand Down Expand Up @@ -576,7 +576,7 @@ def reduced_bz_poly(data: XrTypes, *, scale_zone: bool = False) -> NDArray[np.fl
dy = 3 * dy

symmetry_points, _ = data.S.symmetry_points()
points = {k: v[0] for k, v in symmetry_points.items() if k in point_names}
points = {k: v for k, v in symmetry_points.items() if k in point_names}
coords_by_point = {
k: np.array([v.get(d, 0) for d in data.dims if d in v]) for k, v in points.items()
}
Expand Down Expand Up @@ -625,7 +625,7 @@ def reduced_bz_E_mask(
point_names = _POINT_NAMES_FOR_SYMMETRY[symmetry]

symmetry_points, _ = data.S.symmetry_points()
points = {k: v[0] for k, v in symmetry_points.items() if k in point_names}
points = {k: v for k, v in symmetry_points.items() if k in point_names}
coords_by_point = {
k: np.array([v.get(d, 0) for d in data.dims if d in v or d == "eV"])
for k, v in points.items()
Expand Down
1 change: 0 additions & 1 deletion arpes/utilities/conversion/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from typing import TYPE_CHECKING, Literal

import numpy as np
from numpy.typing import ArrayLike
import xarray as xr
from scipy.interpolate import RegularGridInterpolator

Expand Down
Loading

0 comments on commit 3d502b4

Please sign in to comment.