Skip to content

Commit

Permalink
[functions.py]: Function fit_kww: Improve Initial Guess
Browse files Browse the repository at this point in the history
Improve the initial guess of tau in `mdtools.functions.fit_kww`.
  • Loading branch information
andthum committed Sep 10, 2023
1 parent 20eb854 commit 716a999
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions src/mdtools/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,8 +379,9 @@ def fit_kww(xdata, ydata, ysd=None, return_valid=False, **kwargs):
Additional keyword arguments (besides `xdata`, `ydata` and
`sigma`) to parse to :func:`scipy.optimize.curve_fit`. See
there for possible options. By default, `absolute_sigma` is set
to ``True``, `p0` is set to ``[len(ydata) * np.mean(ydata), 1]``
and `bounds` is set to ``([0, 0], [np.inf, 1])``.
to ``True``, `p0` is set to ``[tau_init, 1]`` and `bounds` is
set to ``([0, 0], [np.inf, 1])``. ``tau_init`` is the point at
which `ydata` falls below :math:`1/e`.
Returns
-------
Expand Down Expand Up @@ -422,17 +423,32 @@ def fit_kww(xdata, ydata, ysd=None, return_valid=False, **kwargs):
return np.full(2, np.nan), np.full(2, np.nan), valid
else:
return np.full(2, np.nan), np.full(2, np.nan)
xdata = xdata[valid]
ydata = ydata[valid]
if ysd is not None:
ysd = ysd[valid]
ysd[ysd == 0] = CURVE_FIT_ZERO_SD

# Initial guess for tau.
# If beta = 1, f(t=tau) = 1/e => tau = f^(-1)(1/e)
thresh = 1 / np.e
ix = np.argmin(ydata[valid] <= thresh)
if ydata[ix] > thresh:
# `ydata` never falls below `thresh`.
# Linearly extrapolate `ydata` to `thresh`.
tau_init = (1 + ydata[-1] - thresh) * xdata[-1]
else:
tau_init = ydata[ix]
# Ensure that `tau_init` is greater than zero.
tau_init = max(tau_init, CURVE_FIT_ZERO_SD)

kwargs.setdefault("absolute_sigma", True)
kwargs.setdefault("p0", [len(ydata) * np.mean(ydata), 1])
kwargs.setdefault("p0", [tau_init, 1])
kwargs.setdefault("bounds", ([0, 0], [np.inf, 1]))

try:
popt, pcov = optimize.curve_fit(
f=kww, xdata=xdata[valid], ydata=ydata[valid], sigma=ysd, **kwargs
f=kww, xdata=xdata, ydata=ydata, sigma=ysd, **kwargs
)
except (ValueError, RuntimeError) as err:
print()
Expand Down

0 comments on commit 716a999

Please sign in to comment.