Skip to content

Commit

Permalink
feat(analysis.gold): add function for quick resolution fitting
Browse files Browse the repository at this point in the history
  • Loading branch information
kmnhan committed Jun 9, 2024
1 parent 2ccd8ad commit 2fae1c3
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/erlab/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
"""

__all__ = ["correct_with_edge", "shift", "slice_along_path"]
__all__ = ["correct_with_edge", "quick_resolution", "shift", "slice_along_path"]

from erlab.analysis import fit, gold, image, interpolate, mask, transform # noqa: F401
from erlab.analysis.gold import correct_with_edge
from erlab.analysis.gold import correct_with_edge, quick_resolution
from erlab.analysis.interpolate import slice_along_path
from erlab.analysis.utils import shift
174 changes: 171 additions & 3 deletions src/erlab/analysis/gold.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
"edge",
"poly",
"poly_from_edge",
"quick_fit",
"quick_resolution",
"resolution",
"resolution_roi",
"spline_from_edge",
Expand All @@ -17,8 +19,9 @@
import lmfit.model
import matplotlib
import matplotlib.figure
import matplotlib.patches as mpatches
import matplotlib.patches
import matplotlib.pyplot as plt
import matplotlib.transforms
import numpy as np
import numpy.typing as npt
import scipy.interpolate
Expand Down Expand Up @@ -363,7 +366,7 @@ def _plot_gold_fit(fig, gold, angle_range, eV_range, center_arr, center_stderr,
ax2 = fig.add_subplot(gs[1, 1], sharex=ax1)

gold.qplot(ax=ax0, cmap="copper", gamma=0.5)
rect = mpatches.Rectangle(
rect = matplotlib.patches.Rectangle(
(angle_range[0], eV_range[0]),
np.diff(angle_range)[0],
np.diff(eV_range)[0],
Expand Down Expand Up @@ -540,6 +543,171 @@ def spline(
return spl


def quick_fit(
darr: xr.DataArray,
eV_range: tuple[float, float] | None = None,
method: str = "leastsq",
temp: float | None = None,
resolution: float | None = None,
fix_temp: bool = True,
fix_center: bool = False,
fix_resolution: bool = False,
bkg_slope: bool = True,
) -> xr.Dataset:
"""Perform a quick Fermi edge fit on the given data.
The data is averaged over all dimensions except the energy prior to fitting.
Parameters
----------
darr
The input data to be fitted.
eV_range
The energy range to consider for fitting. If `None`, the entire energy range is
used. Defaults to `None`.
method
The fitting method to use that is compatible with `lmfit`. Defaults to
"leastsq".
temp
The temperature value to use for fitting. If `None`, the temperature is inferred
from the data attributes.
resolution
The initial resolution value to use for fitting. If `None`, the resolution is
set to 0.02, or to the ``'TotalResolution'`` attribute if present.
fix_temp
Whether to fix the temperature value during fitting. Defaults to `True`.
fix_center
Whether to fix the Fermi level during fitting. If `True`, the Fermi level is
fixed to 0. Defaults to `False`.
fix_resolution
Whether to fix the resolution value during fitting. Defaults to `False`.
bkg_slope
Whether to include a linear background above the Fermi level. If `False`, the
background above the Fermi level is fit with a constant. Defaults to `True`.
Returns
-------
result : xarray.Dataset
The result of the fit.
"""
data = darr.mean([d for d in darr.dims if d != "eV"])

if eV_range is not None:
data_fit = data.sel(eV=slice(*eV_range))
else:
data_fit = data

if temp is None:
if "temp_sample" in data.attrs:
temp = float(data.attrs["temp_sample"])
else:
raise ValueError(
"Temperature not found in data attributes, please provide manually"
)

if resolution is None:
if "TotalResolution" in data.attrs:
resolution = float(data.attrs["TotalResolution"]) * 1e-3
else:
resolution = 0.02

params = {
"temp": {"value": temp, "vary": not fix_temp, "min": 0},
"resolution": {"value": resolution, "vary": not fix_resolution, "min": 0},
}

if not bkg_slope:
params["back1"] = {"value": 0, "vary": False}

if fix_center:
params["center"] = {"value": 0, "vary": False}

return data_fit.modelfit(
"eV", model=FermiEdgeModel(), method=method, params=params, guess=True
)


def quick_resolution(
darr: xr.DataArray,
ax: matplotlib.axes.Axes | None = None,
**kwargs,
) -> xr.Dataset:
"""Fit a Fermi edge to the given data and plot the results.
This function is a wrapper around `quick_fit` that plots the data and the obtained
resolution. The data is averaged over all dimensions except the energy prior to
fitting.
Parameters
----------
darr
The input data to be fitted.
ax
The axis to plot the data and fit on. If `None`, the current axis is used.
Defaults to `None`.
**kwargs
Additional keyword arguments to `quick_fit`.
Returns
-------
result : xarray.Dataset
The result of the fit.
"""
if ax is None:
ax = plt.gca()

darr = darr.mean([d for d in darr.dims if d != "eV"])
result = quick_fit(darr, **kwargs)
ax.plot(
darr.eV, darr, ".", mec="0.6", alpha=1, mfc="none", ms=5, mew=0.3, label="Data"
)

result.modelfit_best_fit.qplot(ax=ax, c="r", label="Fit")

ax.set_ylabel("Intensity (arb. units)")
if (darr.eV[0] * darr.eV[-1]) < 0:
ax.set_xlabel("$E - E_F$ (eV)")
else:
ax.set_xlabel(r"$E_{kin}$ (eV)")

coeffs = result.modelfit_coefficients
center = result.modelfit_results.item().uvars["center"]
resolution = result.modelfit_results.item().uvars["resolution"]

ax.text(
0,
0,
"\n".join(
[
f"$T ={coeffs.sel(param='temp'):.3f}$ K",
f"$E_F = {center * 1e3:L}$ meV"
if center < 0.1
else f"$E_F = {center:L}$ eV",
f"$\\Delta E = {resolution * 1e3:L}$ meV",
]
),
ha="left",
va="baseline",
transform=ax.transAxes
+ matplotlib.transforms.ScaledTranslation(
6 / 72, 6 / 72, ax.figure.dpi_scale_trans
),
)
ax.set_xlim(darr.eV[[0, -1]])
ax.set_title("")
ax.axvline(coeffs.sel(param="center"), ls="--", c="k", lw=0.4, alpha=0.5)
ax.axvspan(
(center - resolution).n,
(center + resolution).n,
color="r",
alpha=0.2,
label="FWHM",
)
return result


def resolution(
gold: xr.DataArray,
angle_range: tuple[float, float],
Expand Down Expand Up @@ -585,7 +753,7 @@ def resolution(
plt.show()
ax = plt.gca()
gold_corr.qplot(ax=ax, cmap="copper", gamma=0.5)
rect = mpatches.Rectangle(
rect = matplotlib.patches.Rectangle(
(angle_range[0], eV_range_fit[0]),
np.diff(angle_range)[0],
np.diff(eV_range_fit)[0],
Expand Down

0 comments on commit 2fae1c3

Please sign in to comment.