From c294318fca67dc32ddd74b957099f2cecdde3b2c Mon Sep 17 00:00:00 2001
From: Ryuichi Arafune
Date: Sat, 16 Mar 2024 09:27:48 +0900
Subject: [PATCH] ✨ Add a new function to tarpes.py, which until now had only
 been used privately.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

🔥 Remove transform_labels in plotting/utils.py
   It is not especially convenient and is not used on a daily basis.

💬 Update warning text in normalize_to_spectrum

💬 Update type hints
---
 src/arpes/analysis/__init__.py       |  2 +-
 src/arpes/analysis/decomposition.py  |  2 +-
 src/arpes/analysis/deconvolution.py  | 14 ++--
 src/arpes/analysis/savitzky_golay.py | 24 +++----
 src/arpes/analysis/statistics.py     | 21 +++++-
 src/arpes/analysis/tarpes.py         | 96 ++++++++++++++++++++++++++-
 src/arpes/constants.py               | 11 ++--
 src/arpes/laser.py                   | 37 ++++++++++-
 src/arpes/plotting/fermi_edge.py     | 10 +--
 src/arpes/plotting/movie.py          |  4 +-
 src/arpes/plotting/stack_plot.py     |  6 +-
 src/arpes/plotting/tof.py            | 17 ++++-
 src/arpes/plotting/utils.py          | 98 +++++++++++-----------------
 src/arpes/provenance.py              |  3 +
 src/arpes/utilities/normalize.py     | 32 +++++++--
 src/arpes/xarray_extensions.py       |  7 +-
 16 files changed, 273 insertions(+), 111 deletions(-)

diff --git a/src/arpes/analysis/__init__.py b/src/arpes/analysis/__init__.py
index 6f54b130..84bbc3da 100644
--- a/src/arpes/analysis/__init__.py
+++ b/src/arpes/analysis/__init__.py
@@ -30,5 +30,5 @@
     pocket_parameters,
     radial_edcs_along_pocket,
 )
-from .tarpes import normalized_relative_change, relative_change
+from .tarpes import build_crosscorrelation, normalized_relative_change, relative_change
 from .xps import approximate_core_levels
diff --git a/src/arpes/analysis/decomposition.py b/src/arpes/analysis/decomposition.py
index d0067420..dc3ae542 100644
--- a/src/arpes/analysis/decomposition.py
+++ b/src/arpes/analysis/decomposition.py
@@ -77,7 +77,7 @@ class FactorAnalysisParam(TypedDict, total=False):
     random_state: int | None
 
 
-class DecompositionParam(PCAParam, FastICAParam, NMFParam, FactorAnalysisParam):
+class DecompositionParam(PCAParam, FastICAParam, NMFParam, FactorAnalysisParam):  # type: ignore[misc]
     pass
 
 
diff --git a/src/arpes/analysis/deconvolution.py b/src/arpes/analysis/deconvolution.py
index 093d4615..91afd821 100644
--- a/src/arpes/analysis/deconvolution.py
+++ b/src/arpes/analysis/deconvolution.py
@@ -82,11 +82,8 @@ def deconvolve_ice(
         poly = np.poly1d(coefs)
         deconv[t] = poly(0)
 
-    if isinstance(data, np.ndarray):
-        result = deconv
-    else:
-        result = data.copy(deep=True)
-        result.values = deconv
+    result = data.copy(deep=True)
+    result.values = deconv
     return result
 
 
@@ -114,7 +111,15 @@ def deconvolve_rl(
 
 
 @update_provenance("Make 1D-Point Spread Function")
-def make_psf1d(data: xr.DataArray, dim: str, sigma: float) -> xr.DataArray:
+def make_psf1d(
+    data: xr.DataArray,
+    dim: str,
+    sigma: float,
+) -> xr.DataArray:
     """Produces a 1-dimensional gaussian point spread function for use in deconvolve_rl.
 
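+    A hypothetical usage sketch (the variable name and the sigma value are editorial assumptions):
+
+    >>> psf = make_psf1d(data=spectrum, dim="eV", sigma=0.01)
+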
     Args:
@@ -184,7 +189,6 @@ def make_psf(
             f" psf_coords[{k}]: ±{np.max(v):.3f}",
         )
     coords = np.meshgrid(*[psf_coords[dim] for dim in data.dims], indexing="ij")
-
    coords_for_pdf_pos = np.stack(coords, axis=-1)  # point distribution function (pdf)
     logger.debug(f"shape of coords_for_pdf_pos: {coords_for_pdf_pos.shape}")
     psf = xr.DataArray(
diff --git a/src/arpes/analysis/savitzky_golay.py b/src/arpes/analysis/savitzky_golay.py
index e6f28e8f..0be6a5af 100644
--- a/src/arpes/analysis/savitzky_golay.py
+++ b/src/arpes/analysis/savitzky_golay.py
@@ -3,11 +3,12 @@
 from __future__ import annotations
 
 from math import factorial
-from typing import TYPE_CHECKING, Literal
+from typing import TYPE_CHECKING, Literal, TypeVar
 
 import numpy as np
 import scipy.signal
 import xarray as xr
+from numpy.typing import NDArray
 
 from arpes.constants import TWO_DIMENSION
 from arpes.provenance import update_provenance
@@ -15,21 +16,21 @@
 if TYPE_CHECKING:
     from collections.abc import Hashable
 
-    from numpy.typing import NDArray
-
 __all__ = ("savitzky_golay",)
 
+T = TypeVar("T", xr.DataArray, NDArray[np.float_])
+
 
 @update_provenance("Savitzky Golay Filter")
 def savitzky_golay(  # noqa: PLR0913
-    data: xr.DataArray,
+    data: T,
     window_size: int,
     order: int,
-    deriv: int | Literal["col", "row", "both", None] = 0,
+    deriv: int | Literal["col", "row", "both"] | None = 0,
     rate: int = 1,
     dim: Hashable = "",
-) -> xr.DataArray:
+) -> T:
     """Implements a Savitzky Golay filter with given window size.
 
     You can specify "pass through" dimensions
@@ -38,6 +39,7 @@ def savitzky_golay(  # noqa: PLR0913
 
     Args:
         data: Input data.
+            This should be an xr.DataArray, though list[float] or np.ndarray is also accepted.
         window_size: Number of points in the window that the filter uses locally.
         order: The polynomial order used in the convolution.
         deriv: the order of the derivative to compute (default = 0 means only smoothing)
@@ -47,12 +49,12 @@ def savitzky_golay(  # noqa: PLR0913
     Returns:
         Smoothed data.
     """
-    if isinstance(
-        data,
-        list | np.ndarray,
-    ):
+    if isinstance(data, list):
+        assert isinstance(deriv, int)
+        return savitzky_golay_array(data, window_size, order, deriv, rate)
+    if isinstance(data, np.ndarray):
+        assert isinstance(deriv, int)
         return savitzky_golay_array(data, window_size, order, deriv, rate)
-
     if len(data.dims) == 1:
         assert isinstance(deriv, int)
         transformed_data = savitzky_golay_array(data.values, window_size, order, deriv, rate)
diff --git a/src/arpes/analysis/statistics.py b/src/arpes/analysis/statistics.py
index a5726982..da13d998 100644
--- a/src/arpes/analysis/statistics.py
+++ b/src/arpes/analysis/statistics.py
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
 from typing import TYPE_CHECKING
 
 import xarray as xr
@@ -14,10 +15,27 @@
 
 __all__ = ("mean_and_deviation",)
 
+LOGLEVELS = (DEBUG, INFO)
+LOGLEVEL = LOGLEVELS[1]
+logger = getLogger(__name__)
+fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
+formatter = Formatter(fmt)
+handler = StreamHandler()
+handler.setLevel(LOGLEVEL)
+logger.setLevel(LOGLEVEL)
+handler.setFormatter(formatter)
+logger.addHandler(handler)
+logger.propagate = False
+
 
 @update_provenance("Calculate mean and standard deviation for observation axis")
 @lift_dataarray_to_generic
-def mean_and_deviation(data: XrTypes, axis: str = "", name: str = "") -> xr.Dataset:
+def mean_and_deviation(
+    data: xr.DataArray,  # data.name is used.
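+    # NOTE (editorial): `name` falls back to `data.name` when not supplied; see Args below.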
+    axis: str = "",
+    name: str = "",
+) -> xr.Dataset:
     """Calculates the mean and standard deviation of a DataArray along an axis.
 
     The reduced axis corresponds to individual observations of a tensor/array valued quantity.
@@ -27,7 +45,7 @@
     If a name is not attached to the DataArray, it should be provided.
 
     Args:
-        data: The input data (Both DataArray and Dataset).
+        data: The input data.
         axis: The name of the dimension which we should perform the reduction along.
         name: The name of the variable which should be reduced. By default, uses `data.name`.
diff --git a/src/arpes/analysis/tarpes.py b/src/arpes/analysis/tarpes.py
index f0fa404e..bfa3b544 100644
--- a/src/arpes/analysis/tarpes.py
+++ b/src/arpes/analysis/tarpes.py
@@ -3,15 +3,116 @@
 from __future__ import annotations
 
 import warnings
+from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
+from typing import TYPE_CHECKING, TypeVar
 
 import numpy as np
 import xarray as xr
+from numpy.typing import NDArray
 
 from arpes.preparation import normalize_dim
 from arpes.provenance import update_provenance
 from arpes.utilities import normalize_to_spectrum
 
-__all__ = ("find_t0", "relative_change", "normalized_relative_change")
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+__all__ = (
+    "find_t0",
+    "relative_change",
+    "normalized_relative_change",
+    "build_crosscorrelation",
+    "delaytime_fs",
+    "position_to_delaytime",
+)
+
+
+LOGLEVELS = (DEBUG, INFO)
+LOGLEVEL = LOGLEVELS[1]
+logger = getLogger(__name__)
+fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
+formatter = Formatter(fmt)
+handler = StreamHandler()
+handler.setLevel(LOGLEVEL)
+logger.setLevel(LOGLEVEL)
+handler.setFormatter(formatter)
+logger.addHandler(handler)
+logger.propagate = False
+
+
+A = TypeVar("A", NDArray[np.float64], float)
+
+
+def build_crosscorrelation(
+    datalist: Sequence[xr.DataArray],
+    delayline_dim: str = "position",
+    delayline_origin: float = 0,
+    *,
+    convert_position_to_time: bool = True,
+) -> xr.DataArray:
+    """Build ('original dimension' + 1)D data from a series of cross-correlation measurements.
+
+    Args:
+        datalist (Sequence[xr.DataArray]): Data series from the cross-correlation experiments.
+        delayline_dim: The dimension name for the "delay line", which must be a key of
+            data.attrs. When this is the "position" dimension, the unit is assumed to be "mm".
+        delayline_origin (float): The delay-line value corresponding to zero delay.
+        convert_position_to_time (bool): If True, convert the delay-line position (mm) into a
+            delay time; set this to False when the stored values are already delay times.
+
+    Returns: xr.DataArray
+    """
+    cross_correlations = []
+
+    for spectrum in datalist:
+        spectrum_arr = (
+            spectrum if isinstance(spectrum, xr.DataArray) else normalize_to_spectrum(spectrum)
+        )
+        if convert_position_to_time:
+            delay_time = position_to_delaytime(
+                float(spectrum_arr.attrs[delayline_dim]),
+                delayline_origin,
+            )
+        else:
+            delay_time = spectrum_arr.attrs[delayline_dim] - delayline_origin
+        cross_correlations.append(
+            spectrum_arr.assign_coords({"delay": delay_time}).expand_dims("delay"),
+        )
+    return xr.concat(cross_correlations, dim="delay")
+
+
+def delaytime_fs(mirror_movement_um: A) -> A:
+    """Return the delay time from the mirror movement (not the mirror position).
+
+    Args:
+        mirror_movement_um (NDArray | float): Mirror movement in µm.
+
+    Returns: NDArray | float
+        Delay time in fs.
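+
+    Example:
+        A doctest-style sanity check (editorial note: the conversion constant is 1e9 / c,
+        i.e. light travels 1 µm in about 3.336 fs):
+
+        >>> delaytime_fs(1.0)
+        3.335640951981521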
+
+    """
+    return 3.335640951981521 * mirror_movement_um
+
+
+def position_to_delaytime(position_mm: A, delayline_offset_mm: float) -> A:
+    """Return the delay time from the mirror position.
+
+    Args:
+        position_mm (np.ndarray | float): Mirror position in mm.
+        delayline_offset_mm (float): Mirror position (mm) corresponding to zero delay.
+
+    Returns: np.ndarray | float
+        Delay time in fs.
+
+    """
+    return delaytime_fs(2 * (position_mm - delayline_offset_mm) * 1000)
 
 
 @update_provenance("Normalized subtraction map")
diff --git a/src/arpes/constants.py b/src/arpes/constants.py
index 1548ec41..28a6dbcc 100644
--- a/src/arpes/constants.py
+++ b/src/arpes/constants.py
@@ -21,10 +21,11 @@
 METERS_PER_SECOND_PER_EV_ANGSTROM = (
     151927  # converts from eV * angstrom to meters/second velocity units
 )
-HBAR = 1.0545718176461565 * 10 ** (-34)
-HBAR_PER_EV = 6.582119569 * 10 ** (
-    -16
-)  # gives the energy lifetime relationship via tau = -hbar / np.imag(self_energy)
+HBAR = 1.0545718176461565e-34
+HBAR_PER_EV = 6.582119569509067e-16
+# gives the energy lifetime relationship via tau = -hbar / np.imag(self_energy)
+
+
 BARE_ELECTRON_MASS = 9.109383701e-31  # kg
 HBAR_SQ_EV_PER_ELECTRON_MASS = 0.475600805657  # hbar^2 / m0 in eV^2 s^2 / kg
 HBAR_SQ_EV_PER_ELECTRON_MASS_ANGSTROM_SQ = 7.619964  # (hbar^2) / (m0 * angstrom ^2) in eV
@@ -32,7 +33,7 @@
 K_BOLTZMANN_EV_KELVIN = 8.617333262145178e-5  # in units of eV / Kelvin
 K_BOLTZMANN_MEV_KELVIN = 1000 * K_BOLTZMANN_EV_KELVIN  # meV / Kelvin
 
-HC = 1239.84172  # in units of eV * nm
+HC = 1239.8419843320028  # in units of eV * nm
 
 HEX_ALPHABET = "ABCDEF0123456789"
 ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
diff --git a/src/arpes/laser.py b/src/arpes/laser.py
index dafa0cf2..5272cb30 100644
--- a/src/arpes/laser.py
+++ b/src/arpes/laser.py
@@ -2,12 +2,45 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
+from typing import TYPE_CHECKING, TypeVar
+
+import numpy as np
+from numpy.typing import NDArray
+
+from .constants import HC
 
 if TYPE_CHECKING:
     import pint
 
-__all__ = ("electrons_per_pulse",)
+LOGLEVELS = (DEBUG, INFO)
+LOGLEVEL = LOGLEVELS[1]
+logger = getLogger(__name__)
+fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
+formatter = Formatter(fmt)
+handler = StreamHandler()
+handler.setLevel(LOGLEVEL)
+logger.setLevel(LOGLEVEL)
+handler.setFormatter(formatter)
+logger.addHandler(handler)
+logger.propagate = False
+
+
+__all__ = ("electrons_per_pulse", "wavelength_to_energy")
+
+A = TypeVar("A", NDArray[np.float64], float)
+
+
+def wavelength_to_energy(wavelength_nm: A) -> A:
+    """Return the photon energy of light with the given wavelength.
+
+    Args:
+        wavelength_nm (NDArray | float): Wavelength of the light in nm.
+
+    Returns: NDArray | float
+        Photon energy in eV.
+    """
+    return HC / wavelength_nm
 
 
 def electrons_per_pulse(
diff --git a/src/arpes/plotting/fermi_edge.py b/src/arpes/plotting/fermi_edge.py
index a9bb014f..02e77643 100644
--- a/src/arpes/plotting/fermi_edge.py
+++ b/src/arpes/plotting/fermi_edge.py
@@ -117,7 +117,11 @@ def fermi_edge_reference(
     """Fits for and plots results for the Fermi edge on a piece of data.
 
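+    A hypothetical usage sketch (the `cut` variable and keyword values are editorial assumptions):
+
+    >>> fermi_edge_reference(cut, title="Au Fermi edge reference", out="fermi_edge_check.png")
+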
     Args:
-        data: The data, this should be of type DataArray
+        data_arr: The data; this should be of type DataArray.
         title: A title to attach to the plot
         ax: The axes to plot to, if not specified will be generated
         out: Where to save the plot
@@ -133,7 +137,7 @@ def fermi_edge_reference(
     assert isinstance(data_arr, xr.DataArray)
     sum_dimensions: set[str] = {"cycle", "phi", "kp", "kx"}
     sum_dimensions.intersection_update(set(data_arr.dims))
-    summed_data = data.sum(*list(sum_dimensions))
+    summed_data = data_arr.sum(*list(sum_dimensions))
 
     broadcast_dimensions = [str(d) for d in summed_data.dims if str(d) != "eV"]
     msg = f"Could not product fermi edge reference. Too many dimensions: {broadcast_dimensions}"
@@ -156,14 +160,14 @@ def fermi_edge_reference(
         _, ax = plt.subplots(figsize=(8, 5))
 
     if not title:
-        title = data.S.label.replace("_", " ")
+        title = data_arr.S.label.replace("_", " ")
 
     centers.plot(ax=ax, **kwargs)
     widths.plot(ax=ax, **kwargs)
 
     if isinstance(ax, Axes):
-        ax.set_xlabel(label_for_dim(data, ax.get_xlabel()))
-        ax.set_ylabel(label_for_dim(data, ax.get_ylabel()))
+        ax.set_xlabel(label_for_dim(data_arr, ax.get_xlabel()))
+        ax.set_ylabel(label_for_dim(data_arr, ax.get_ylabel()))
         ax.set_title(title, font_size=14)
 
     if out:
diff --git a/src/arpes/plotting/movie.py b/src/arpes/plotting/movie.py
index d7198446..de214bb6 100644
--- a/src/arpes/plotting/movie.py
+++ b/src/arpes/plotting/movie.py
@@ -48,9 +48,7 @@ def plot_movie(
     Raises:
         TypeError: [TODO:description]
     """
-    if not isinstance(data, xr.DataArray):
-        msg = "You must provide a DataArray"
-        raise TypeError(msg)
+    assert isinstance(data, xr.DataArray), "You must provide a DataArray"
     fig, ax = fig_ax
     if ax is None:
         fig, ax = plt.subplots(figsize=(7, 7))
diff --git a/src/arpes/plotting/stack_plot.py b/src/arpes/plotting/stack_plot.py
index a5c322c2..8d611012 100644
--- a/src/arpes/plotting/stack_plot.py
+++ b/src/arpes/plotting/stack_plot.py
@@ -134,7 +134,7 @@ def offset_scatter_plot(
     if cbarmap is None:
         skip_colorbar = False
         cbar: colorbar.Colorbar | Callable[..., colorbar.Colorbar]
-        cmap: Callable[..., Callable[..., ColorType]]
+        cmap: Callable[..., ColorType] | Callable[..., Callable[..., ColorType]]
         try:
             cbar, cmap = colorbarmaps_for_axis[stack_axis]
         except KeyError:
@@ -187,7 +186,6 @@ def offset_scatter_plot(
             name_to_plot,
             ax=ax,
             color=cmap(coord[stack_axis]),
-            fmt="none",
         )
 
     ax.set_xlabel(other_dim)
@@ -474,8 +473,7 @@ def stack_dispersion_plot(  # noqa: PLR0913
     ax.set_xlabel(label_for_dim(data_arr, x_label))
     # set xlim with margin
     # 11/10 is the good value for margine
-    axis_min = min(lim)
-    axis_max = max(lim)
+    axis_min, axis_max = min(lim), max(lim)
     middle = (axis_min + axis_max) / 2
     ax.set_xlim(
         left=middle - (axis_max - axis_min) / 2 * 11 / 10,
diff --git a/src/arpes/plotting/tof.py b/src/arpes/plotting/tof.py
index c1e8fd84..7299483a 100644
--- a/src/arpes/plotting/tof.py
+++ b/src/arpes/plotting/tof.py
@@ -11,6 +11,7 @@
 
 from __future__ import annotations
 
+from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
 from typing import TYPE_CHECKING, Unpack
 
 import matplotlib.pyplot as plt
@@ -29,6 +30,19 @@
 
     from arpes._typing import MPLPlotKwargs, MPLPlotKwargsBasic
 
+LOGLEVELS = (DEBUG, INFO)
+LOGLEVEL = LOGLEVELS[1]
+logger = getLogger(__name__)
+fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
+formatter = Formatter(fmt)
+handler = StreamHandler()
+handler.setLevel(LOGLEVEL)
+logger.setLevel(LOGLEVEL)
+handler.setFormatter(formatter)
+logger.addHandler(handler)
+logger.propagate = False
+
+
 __all__ = (
"plot_with_std", "scatter_with_std", @@ -37,7 +51,7 @@ @save_plot_provenance def plot_with_std( - data_set: xr.Dataset, # dat_vars is used, + data_set: xr.Dataset, # data_vars is used, name_to_plot: str = "", ax: Axes | None = None, out: str | Path = "", @@ -100,7 +114,6 @@ def scatter_with_std( ax: Matplotlib Axes object out: (str | Path): Path name to output figure. figsize (tuple[float, float]): tuple for figure size. - fmt (str): THe form at for the data points/lines. **kwargs: pass to subplots if figsize is set as tuple, other kwargs are pass to ax.errorbar """ if not name_to_plot: diff --git a/src/arpes/plotting/utils.py b/src/arpes/plotting/utils.py index e7519c85..9d8391c2 100644 --- a/src/arpes/plotting/utils.py +++ b/src/arpes/plotting/utils.py @@ -5,6 +5,7 @@ import contextlib import datetime import errno +import functools import itertools import json import pickle @@ -41,7 +42,7 @@ from lmfit.model import Model from matplotlib.font_manager import FontProperties from matplotlib.image import AxesImage - from matplotlib.typing import ColorType, RGBAColorType, RGBColorType + from matplotlib.typing import ColorType from numpy.typing import NDArray from arpes._typing import ColorbarParam, DataType, MPLPlotKwargs, PLTSubplotParam, XrTypes @@ -64,7 +65,6 @@ "temperature_colormap_around", "temperature_colorbar", "temperature_colorbar_around", - "generic_colorbarmap", "generic_colorbarmap_for_data", "colorbarmaps_for_axis", # Axis generation @@ -104,7 +104,6 @@ "mod_plot_to_ax", # Data summaries "summarize", - "transform_labels", "v_gradient_fill", "h_gradient_fill", ) @@ -163,7 +162,7 @@ def h_gradient_fill( x1: float, x2: float, x_solid: float | None, - fill_color: RGBColorType = "red", + fill_color: ColorType = "red", ax: Axes | None = None, **kwargs: Unpack[GradientFillParam], ) -> AxesImage: # <== checkme! @@ -229,7 +228,7 @@ def v_gradient_fill( y1: float, y2: float, y_solid: float | None, - fill_color: RGBColorType = "red", + fill_color: ColorType = "red", ax: Axes | None = None, **kwargs: Unpack[GradientFillParam], ) -> AxesImage: @@ -407,30 +406,6 @@ def swap_axis_sides(ax: Axes) -> None: swap_yaxis_side(ax) -def transform_labels( - transform_fn: Callable[[str, bool], str], - fig: Figure | None = None, - *, - include_titles: bool = True, -) -> None: - """Apply a function to all axis labeled in a figure.""" - if fig is None: - fig = plt.gcf() - assert isinstance(fig, Figure) - axes = list(fig.get_axes()) - for ax in axes: - try: - ax.set_xlabel(transform_fn(ax.get_xlabel(), is_title=False)) - ax.set_ylabel(transform_fn(ax.get_xlabel(), is_title=False)) - if include_titles: - ax.set_title(transform_fn(ax.get_title(), is_title=True)) - except TypeError: - ax.set_xlabel(transform_fn(ax.get_xlabel())) - ax.set_ylabel(transform_fn(ax.get_xlabel())) - if include_titles: - ax.set_title(transform_fn(ax.get_title())) - - def summarize(data: xr.DataArray, axes: NDArray[np.object_] | None = None) -> NDArray[np.object_]: """Makes a summary plot with different marginal plots represented.""" data_arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data) @@ -515,7 +490,7 @@ def to_str(bound: float | None) -> str: return eV_annotation + phi_annotation -def frame_with(ax: Axes, color: RGBColorType = "red", linewidth: float = 2) -> None: +def frame_with(ax: Axes, color: ColorType = "red", linewidth: float = 2) -> None: """Makes thick, visually striking borders on a matplotlib plot. Very useful for color coding results in a slideshow. 
@@ -880,9 +859,9 @@ def resolve(name: Hashable, value: slice | int) -> NDArray[np.float_]:
 
 
 def generic_colormap(
-    low: float,
-    high: float,
-) -> Callable[..., ColorType]:
+    low: float = 0,
+    high: float = 1,
+) -> Callable[[float], ColorType]:
     """Generates a colormap from the cm.Blues palette, suitable for most purposes."""
     delta = high - low
     low = low - delta / 6
@@ -947,9 +926,9 @@ def get_color(value: float) -> tuple[float, float, float, float]:
 
 
 def generic_colorbar(
-    low: float,
-    high: float,
-    ax: Axes,
+    low: float = 0,
+    high: float = 1,
+    ax: Axes | None = None,
     **kwargs: Unpack[ColorbarParam],
 ) -> colorbar.Colorbar:
     """Generate colorbar.
@@ -968,6 +947,7 @@ def generic_colorbar(
     """
     delta = high - low
     low = low - delta / 6
     high = high + delta / 6
+    assert ax is not None
 
     return colorbar.Colorbar(ax, **kwargs)
@@ -1058,13 +1038,26 @@ def temperature_colorbar_around(
     return colorbar.Colorbar(ax, **kwargs)
 
 
+def polarization_colorbar(ax: Axes | None = None) -> colorbar.Colorbar:
+    """Makes a colorbar which is appropriate for "polarization" (e.g. spin) data."""
+    assert isinstance(ax, Axes)
+    return colorbar.Colorbar(
+        ax,
+        cmap="RdBu",
+        norm=colors.Normalize(vmin=-1, vmax=1),
+        orientation="horizontal",
+        label="Polarization",
+        ticks=[-1, 0, 1],
+    )
+
+
 colorbarmaps_for_axis: dict[
     str,
     tuple[
         Callable[..., colorbar.Colorbar],
         Callable[
             ...,
-            Callable[..., ColorType],
+            Callable[[float], ColorType],
         ],
     ],
 ] = {
@@ -1117,19 +1110,16 @@ def remove_colorbars(fig: Figure | None = None) -> None:
         logger.debug(f"Exception occurs: {err=}, {type(err)=}")
 
 
-generic_colorbarmap = (
-    generic_colorbar,
-    generic_colormap,
-)
-
-
 def generic_colorbarmap_for_data(
     data: xr.DataArray,
     ax: Axes,
     *,
     keep_ticks: bool = True,
     **kwargs: Unpack[ColorbarParam],
-) -> tuple[colorbar.Colorbar, Callable[..., ColorType]]:
+) -> tuple[
+    Callable[..., colorbar.Colorbar],
+    Callable[[float], ColorType],
+]:
     """Generates a colorbar and colormap which is useful in general context.
 
     Args:
@@ -1146,29 +1136,17 @@ def generic_colorbarmap_for_data(
     if keep_ticks:
         ticks = data.values.tolist()
     return (
-        generic_colorbar(
+        functools.partial(
+            generic_colorbar,
             low=low,
             high=high,
             ax=ax,
             ticks=kwargs.get("ticks", ticks),
         ),
         generic_colormap(low=low, high=high),
     )
 
 
-def polarization_colorbar(ax: Axes | None = None) -> colorbar.Colorbar:
-    """Makes a colorbar which is appropriate for "polarization" (e.g. 
spin) data."""
-    assert isinstance(ax, Axes)
-    return colorbar.Colorbar(
-        ax,
-        cmap="RdBu",
-        norm=colors.Normalize(vmin=-1, vmax=1),
-        orientation="horizontal",
-        label="Polarization",
-        ticks=[-1, 0, 1],
-    )
-
-
 def calculate_aspect_ratio(data: xr.DataArray) -> float:
     """Calculate the aspect ratio which should be used for plotting some data based on extent."""
     data_arr = data if isinstance(data, xr.DataArray) else normalize_to_spectrum(data)
diff --git a/src/arpes/provenance.py b/src/arpes/provenance.py
index 67277f7c..c1fa5215 100644
--- a/src/arpes/provenance.py
+++ b/src/arpes/provenance.py
@@ -81,6 +81,9 @@ class Provenance(TypedDict, total=False):
     transformed_vars: list[str]
 
     # occupation_ratio: float
+    #
+    correlation: bool
+    decomposition_cls: str
 
 
 def attach_id(data: XrTypes) -> None:
diff --git a/src/arpes/utilities/normalize.py b/src/arpes/utilities/normalize.py
index b66565cd..e03ea3ec 100644
--- a/src/arpes/utilities/normalize.py
+++ b/src/arpes/utilities/normalize.py
@@ -2,7 +2,9 @@
 
 from __future__ import annotations
 
+import inspect
 import warnings
+from logging import DEBUG, INFO, Formatter, StreamHandler, getLogger
 from typing import TYPE_CHECKING
 
 import xarray as xr
@@ -10,6 +12,19 @@
 if TYPE_CHECKING:
     from arpes._typing import XrTypes
 
+LOGLEVELS = (DEBUG, INFO)
+LOGLEVEL = LOGLEVELS[0]
+logger = getLogger(__name__)
+fmt = "%(asctime)s %(levelname)s %(name)s :%(message)s"
+formatter = Formatter(fmt)
+handler = StreamHandler()
+handler.setLevel(LOGLEVEL)
+logger.setLevel(LOGLEVEL)
+handler.setFormatter(formatter)
+logger.addHandler(handler)
+logger.propagate = False
+
+
 __all__ = (
     "normalize_to_spectrum",
     "normalize_to_dataset",
@@ -21,8 +36,20 @@ def normalize_to_spectrum(data: XrTypes | str) -> xr.DataArray:
     import arpes.xarray_extensions  # noqa: F401
     from arpes.io import load_data
 
-    msg = "Remember to use a DataArray not a Dataset, "
-    msg += "attempting to extract spectrum and copy attributes."
+    logger.debug(f"inspect.stack(): {inspect.stack()}")
+    if isinstance(data, str):
+        msg = "You passed a filename for the data as the argument of "
+        msg += f"{inspect.stack()[1].function} in {inspect.stack()[1].filename}.\n"
+        msg += "Remember to use a DataArray, not a Dataset or a filename; "
+        msg += "attempting to load the data and extract the spectrum.\n"
+        msg += "This may not be what you expected."
+        warnings.warn(msg, stacklevel=2)
+        return normalize_to_spectrum(load_data(data))
+
+    msg = "You used a Dataset as the argument of "
+    msg += f"{inspect.stack()[1].function} in {inspect.stack()[1].filename}.\n"
+    msg += "Remember to use a DataArray, not a Dataset; "
+    msg += "attempting to extract the spectrum and copy attributes.\n"
     warnings.warn(
         msg,
         stacklevel=2,
@@ -33,8 +60,6 @@ def normalize_to_spectrum(data: XrTypes | str) -> xr.DataArray:
         assert isinstance(data.up, xr.DataArray)
         return data.up
     return data.S.spectrum
-    if isinstance(data, str):
-        return normalize_to_spectrum(load_data(data))
     assert isinstance(data, xr.DataArray)
     return data
diff --git a/src/arpes/xarray_extensions.py b/src/arpes/xarray_extensions.py
index 9e5d732e..7b8e780c 100644
--- a/src/arpes/xarray_extensions.py
+++ b/src/arpes/xarray_extensions.py
@@ -817,9 +817,8 @@ def lookup_offset_coord(self, name: str) -> xr.DataArray | NDArray[np.float_] |
     def lookup_coord(self, name: str) -> xr.DataArray | float:
         if name in self._obj.coords:
             return unwrap_xarray_item(self._obj.coords[name])
-
-        msg = f"Could not find coordinate {name}. Check your endstation module."
- raise ValueError(msg) + self._obj.coords[name] = np.nan + return np.nan def lookup_offset(self, attr_name: str) -> float: symmetry_points = self.symmetry_points() @@ -1007,7 +1006,7 @@ def zero_spectrometer_edges( assert isinstance(self._obj, xr.DataArray) if low is not None: assert high is not None - assert len(low) == len(high) == 2 # noqa: PLR2004 + assert len(low) == len(high) == TWO_DIMENSION low_edges = low high_edges = high