From 2fae1c351f29b2fb1ceef39a69706b3f198e4659 Mon Sep 17 00:00:00 2001 From: Kimoon Han Date: Sat, 8 Jun 2024 17:52:13 -0700 Subject: [PATCH] feat(analysis.gold): add function for quick resolution fitting --- src/erlab/analysis/__init__.py | 4 +- src/erlab/analysis/gold.py | 174 ++++++++++++++++++++++++++++++++- 2 files changed, 173 insertions(+), 5 deletions(-) diff --git a/src/erlab/analysis/__init__.py b/src/erlab/analysis/__init__.py index 88bd1978..5bae9ed6 100644 --- a/src/erlab/analysis/__init__.py +++ b/src/erlab/analysis/__init__.py @@ -21,9 +21,9 @@ """ -__all__ = ["correct_with_edge", "shift", "slice_along_path"] +__all__ = ["correct_with_edge", "quick_resolution", "shift", "slice_along_path"] from erlab.analysis import fit, gold, image, interpolate, mask, transform # noqa: F401 -from erlab.analysis.gold import correct_with_edge +from erlab.analysis.gold import correct_with_edge, quick_resolution from erlab.analysis.interpolate import slice_along_path from erlab.analysis.utils import shift diff --git a/src/erlab/analysis/gold.py b/src/erlab/analysis/gold.py index 97a9b96d..72a08b6d 100644 --- a/src/erlab/analysis/gold.py +++ b/src/erlab/analysis/gold.py @@ -5,6 +5,8 @@ "edge", "poly", "poly_from_edge", + "quick_fit", + "quick_resolution", "resolution", "resolution_roi", "spline_from_edge", @@ -17,8 +19,9 @@ import lmfit.model import matplotlib import matplotlib.figure -import matplotlib.patches as mpatches +import matplotlib.patches import matplotlib.pyplot as plt +import matplotlib.transforms import numpy as np import numpy.typing as npt import scipy.interpolate @@ -363,7 +366,7 @@ def _plot_gold_fit(fig, gold, angle_range, eV_range, center_arr, center_stderr, ax2 = fig.add_subplot(gs[1, 1], sharex=ax1) gold.qplot(ax=ax0, cmap="copper", gamma=0.5) - rect = mpatches.Rectangle( + rect = matplotlib.patches.Rectangle( (angle_range[0], eV_range[0]), np.diff(angle_range)[0], np.diff(eV_range)[0], @@ -540,6 +543,171 @@ def spline( return spl +def quick_fit( + darr: xr.DataArray, + eV_range: tuple[float, float] | None = None, + method: str = "leastsq", + temp: float | None = None, + resolution: float | None = None, + fix_temp: bool = True, + fix_center: bool = False, + fix_resolution: bool = False, + bkg_slope: bool = True, +) -> xr.Dataset: + """Perform a quick Fermi edge fit on the given data. + + The data is averaged over all dimensions except the energy prior to fitting. + + Parameters + ---------- + darr + The input data to be fitted. + eV_range + The energy range to consider for fitting. If `None`, the entire energy range is + used. Defaults to `None`. + method + The fitting method to use that is compatible with `lmfit`. Defaults to + "leastsq". + temp + The temperature value to use for fitting. If `None`, the temperature is inferred + from the data attributes. + resolution + The initial resolution value to use for fitting. If `None`, the resolution is + set to 0.02, or to the ``'TotalResolution'`` attribute if present. + fix_temp + Whether to fix the temperature value during fitting. Defaults to `True`. + fix_center + Whether to fix the Fermi level during fitting. If `True`, the Fermi level is + fixed to 0. Defaults to `False`. + fix_resolution + Whether to fix the resolution value during fitting. Defaults to `False`. + bkg_slope + Whether to include a linear background above the Fermi level. If `False`, the + background above the Fermi level is fit with a constant. Defaults to `True`. + + Returns + ------- + result : xarray.Dataset + The result of the fit. + + """ + data = darr.mean([d for d in darr.dims if d != "eV"]) + + if eV_range is not None: + data_fit = data.sel(eV=slice(*eV_range)) + else: + data_fit = data + + if temp is None: + if "temp_sample" in data.attrs: + temp = float(data.attrs["temp_sample"]) + else: + raise ValueError( + "Temperature not found in data attributes, please provide manually" + ) + + if resolution is None: + if "TotalResolution" in data.attrs: + resolution = float(data.attrs["TotalResolution"]) * 1e-3 + else: + resolution = 0.02 + + params = { + "temp": {"value": temp, "vary": not fix_temp, "min": 0}, + "resolution": {"value": resolution, "vary": not fix_resolution, "min": 0}, + } + + if not bkg_slope: + params["back1"] = {"value": 0, "vary": False} + + if fix_center: + params["center"] = {"value": 0, "vary": False} + + return data_fit.modelfit( + "eV", model=FermiEdgeModel(), method=method, params=params, guess=True + ) + + +def quick_resolution( + darr: xr.DataArray, + ax: matplotlib.axes.Axes | None = None, + **kwargs, +) -> xr.Dataset: + """Fit a Fermi edge to the given data and plot the results. + + This function is a wrapper around `quick_fit` that plots the data and the obtained + resolution. The data is averaged over all dimensions except the energy prior to + fitting. + + Parameters + ---------- + darr + The input data to be fitted. + ax + The axis to plot the data and fit on. If `None`, the current axis is used. + Defaults to `None`. + **kwargs + Additional keyword arguments to `quick_fit`. + + Returns + ------- + result : xarray.Dataset + The result of the fit. + + """ + if ax is None: + ax = plt.gca() + + darr = darr.mean([d for d in darr.dims if d != "eV"]) + result = quick_fit(darr, **kwargs) + ax.plot( + darr.eV, darr, ".", mec="0.6", alpha=1, mfc="none", ms=5, mew=0.3, label="Data" + ) + + result.modelfit_best_fit.qplot(ax=ax, c="r", label="Fit") + + ax.set_ylabel("Intensity (arb. units)") + if (darr.eV[0] * darr.eV[-1]) < 0: + ax.set_xlabel("$E - E_F$ (eV)") + else: + ax.set_xlabel(r"$E_{kin}$ (eV)") + + coeffs = result.modelfit_coefficients + center = result.modelfit_results.item().uvars["center"] + resolution = result.modelfit_results.item().uvars["resolution"] + + ax.text( + 0, + 0, + "\n".join( + [ + f"$T ={coeffs.sel(param='temp'):.3f}$ K", + f"$E_F = {center * 1e3:L}$ meV" + if center < 0.1 + else f"$E_F = {center:L}$ eV", + f"$\\Delta E = {resolution * 1e3:L}$ meV", + ] + ), + ha="left", + va="baseline", + transform=ax.transAxes + + matplotlib.transforms.ScaledTranslation( + 6 / 72, 6 / 72, ax.figure.dpi_scale_trans + ), + ) + ax.set_xlim(darr.eV[[0, -1]]) + ax.set_title("") + ax.axvline(coeffs.sel(param="center"), ls="--", c="k", lw=0.4, alpha=0.5) + ax.axvspan( + (center - resolution).n, + (center + resolution).n, + color="r", + alpha=0.2, + label="FWHM", + ) + return result + + def resolution( gold: xr.DataArray, angle_range: tuple[float, float], @@ -585,7 +753,7 @@ def resolution( plt.show() ax = plt.gca() gold_corr.qplot(ax=ax, cmap="copper", gamma=0.5) - rect = mpatches.Rectangle( + rect = matplotlib.patches.Rectangle( (angle_range[0], eV_range_fit[0]), np.diff(angle_range)[0], np.diff(eV_range_fit)[0],