diff --git a/misc/dtrj_lifetimes/compare_dtrj_lifetime_methods.py b/misc/dtrj_lifetimes/compare_dtrj_lifetime_methods.py index c5f476c8..7619c182 100644 --- a/misc/dtrj_lifetimes/compare_dtrj_lifetime_methods.py +++ b/misc/dtrj_lifetimes/compare_dtrj_lifetime_methods.py @@ -197,37 +197,59 @@ time_conv = 1 print("\n") - print("Calculating lifetimes directly from `dtrj` (Methods 1-2)...") + print("Calculating lifetimes directly from `dtrj`...") timer = datetime.now() dtrj = mdt.fh.load_dtrj(args.INFILE_DTRJ) n_frames = dtrj.shape[1] # Method 1: Calculate the average lifetime by counting the number of - # frames that a given compound stays in a given state. - lts_cnt, states_cnt = mdt.dtrj.lifetimes_per_state( - dtrj, return_states=True + # frames that a given compound stays in a given state including + # truncated states at the trajectory edges -> censored. + lts_cnt_cen, states_cnt_cen = mdt.dtrj.lifetimes_per_state( + dtrj, uncensored=False, return_states=True ) - lts_cnt = [lts * time_conv for lts in lts_cnt] - lts_cnt_mom1 = np.array([np.mean(lts) for lts in lts_cnt]) - lts_cnt_mom2 = np.array([np.mean(lts**2) for lts in lts_cnt]) - del lts_cnt + lts_cnt_cen = [lts * time_conv for lts in lts_cnt_cen] + lts_cnt_cen_mom1 = np.array([np.mean(lts) for lts in lts_cnt_cen]) + lts_cnt_cen_mom2 = np.array([np.mean(lts**2) for lts in lts_cnt_cen]) + del lts_cnt_cen - # Method 2: Calculate the transition rate as the number of + # Method 2: Calculate the average lifetime by counting the number of + # frames that a given compound stays in a given state excluding + # truncated states at the trajectory edges -> uncensored. + lts_cnt_unc, states_cnt_unc = mdt.dtrj.lifetimes_per_state( + dtrj, uncensored=True, return_states=True + ) + lts_cnt_unc = [lts * time_conv for lts in lts_cnt_unc] + lts_cnt_unc_mom1 = np.array( + [np.mean(lts) if len(lts) > 0 else np.nan for lts in lts_cnt_unc] + ) + lts_cnt_unc_mom2 = np.array( + [np.mean(lts**2) if len(lts) > 0 else np.nan for lts in lts_cnt_unc] + ) + if not np.array_equal(states_cnt_unc, states_cnt_cen): + raise ValueError( + "`states_cnt_unc` ({}) != `states_cnt_cen`" + " ({})".format(states_cnt_unc, states_cnt_cen) + ) + del lts_cnt_unc, states_cnt_unc + + # Method 3: Calculate the transition rate as the number of # transitions leading out of a given state divided by the number of # frames that compounds have spent in this state. The average # lifetime is calculated as the inverse transition rate. rates, states_k = mdt.dtrj.trans_rate_per_state(dtrj, return_states=True) lts_k = time_conv / rates - if not np.array_equal(states_k, states_cnt): + if not np.array_equal(states_k, states_cnt_cen): raise ValueError( - "`states_k` ({}) != `states_cnt` ({})".format(states_k, states_cnt) + "`states_k` ({}) != `states_cnt_cen`" + " ({})".format(states_k, states_cnt_cen) ) del dtrj, rates, states_k print("Elapsed time: {}".format(datetime.now() - timer)) print("Current memory usage: {:.2f} MiB".format(mdt.rti.mem_usage(proc))) print("\n") - print("Calculating lifetimes from the remain probability (Methods 3-6)...") + print("Calculating lifetimes from the remain probability...") timer = datetime.now() # Read remain probabilities from file. @@ -248,15 +270,16 @@ "Some state indices are not integers but floats. states =" " {}".format(states) ) - if not np.array_equal(states, states_cnt): + if not np.array_equal(states, states_cnt_cen): raise ValueError( - "`states` ({}) != `states_cnt` ({})".format(states, states_cnt) + "`states` ({}) != `states_cnt_cen`" + " ({})".format(states, states_cnt_cen) ) - del states_cnt + del states_cnt_cen states = states.astype(np.int32) n_states = len(states) - # Method 3: Set the lifetime to the lag time at which the remain + # Method 4: Set the lifetime to the lag time at which the remain # probability crosses 1/e. thresh = 1 / np.e ix_thresh = np.nanargmax(remain_props <= thresh, axis=0) @@ -302,7 +325,7 @@ ) ) - # Method 4: Calculate the lifetime as the integral of the remain + # Method 5: Calculate the lifetime as the integral of the remain # probability p(t): # = n * int_0^inf t^(n-1) * p(t) dt lts_int_mom1 = np.full(n_states, np.nan, dtype=np.float64) @@ -338,7 +361,7 @@ stop_fit = 2 fit_stop[i] = min(end_fit, stop_fit) - # Method 5: Fit the remain probability with a Kohlrausch function + # Method 6: Fit the remain probability with a Kohlrausch function # (stretched exponential) and calculate the lifetime as the integral # of the fit: # I_kww(t) = exp[-(t/tau0_kww)^beta_kww] @@ -380,7 +403,7 @@ lts_kww_mom1 = tau0_kww * gamma(1 + 1 / beta_kww) lts_kww_mom2 = tau0_kww**2 * gamma(1 + 2 / beta_kww) - # Method 6: Fit the remain probability with the survival function of + # Method 7: Fit the remain probability with the survival function of # a Burr Type XII distribution and calculate the lifetime as the # integral fo the fit: # I_bur(t) = 1 / [1 + (t/tau0_bur)^beta_bur]^delta_bur @@ -523,39 +546,42 @@ timer = datetime.now() data = [ states, # 1 - # Method 1 (counting). - lts_cnt_mom1, # 2 - lts_cnt_mom2, # 3 - # Method 2 (inverse transition rate). - lts_k, # 4 - # Method 3 (1/e criterion). - lts_e, # 5 - # Method 4 (direct integral). - lts_int_mom1, # 6 - lts_int_mom2, # 7 - # Method 5 (integral of Kohlrausch fit). - lts_kww_mom1, # 8 - lts_kww_mom2, # 9 - tau0_kww, # 10 - tau0_kww_sd, # 11 - beta_kww, # 12 - beta_kww_sd, # 13 - fit_r2_kww, # 14 - fit_rmse_kww, # 15 - # Method 6 (integral of Burr fit). - lts_bur_mom1, # 16 - lts_bur_mom2, # 17 - tau0_bur, # 18 - tau0_bur_sd, # 19 - beta_bur, # 20 - beta_bur_sd, # 21 - delta_bur, # 22 - delta_bur_sd, # 23 - fit_r2_bur, # 24 - fit_rmse_bur, # 25 + # Method 1 (censored counting). + lts_cnt_cen_mom1, # 2 + lts_cnt_cen_mom2, # 3 + # Method 2 (uncensored counting). + lts_cnt_unc_mom1, # 4 + lts_cnt_unc_mom2, # 5 + # Method 3 (inverse transition rate). + lts_k, # 6 + # Method 4 (1/e criterion). + lts_e, # 7 + # Method 5 (direct integral). + lts_int_mom1, # 8 + lts_int_mom2, # 9 + # Method 6 (integral of Kohlrausch fit). + lts_kww_mom1, # 10 + lts_kww_mom2, # 11 + tau0_kww, # 12 + tau0_kww_sd, # 13 + beta_kww, # 14 + beta_kww_sd, # 15 + fit_r2_kww, # 16 + fit_rmse_kww, # 17 + # Method 7 (integral of Burr fit). + lts_bur_mom1, # 18 + lts_bur_mom2, # 19 + tau0_bur, # 20 + tau0_bur_sd, # 21 + beta_bur, # 22 + beta_bur_sd, # 23 + delta_bur, # 24 + delta_bur_sd, # 25 + fit_r2_bur, # 26 + fit_rmse_bur, # 27 # Fit region - fit_start, # 26 - fit_stop - 1, # 27 + fit_start, # 28 + fit_stop - 1, # 29 ] if args.INFILE_PARAM is not None: data += params[1:12].tolist() @@ -565,8 +591,8 @@ "State lifetimes.\n" + "Average time that a given compound stays in a given state\n" + "calculated either directly from the discrete trajectory\n" - + "(Method 1-2) or from the corresponding remain probability\n" - + "function (Method 3-6). \n" + + "(Method 1-3) or from the corresponding remain probability\n" + + "function (Method 4-7). \n" + "\n" + "\n" + "Discrete trajectory: {:s}\n".format(args.INFILE_DTRJ) @@ -579,25 +605,35 @@ + "\n" ) header += ( - "Lifetimes are calculated using six different methods:\n" + "Lifetimes are calculated using different methods:\n" + "\n" - + "1) The average lifetime is calculated by counting how\n" - + " many frames a given compound stays in a given state. Note\n" - + " lifetimes calculated in this way can at maximum be as long as\n" - + " the trajectory and are usually biased to lower values because\n" - + " of edge effects (censoring).\n" + + "1) The average lifetime is calculated by counting how\n" + + " many frames a given compound stays in a given state including\n" + + " truncated states at the trajectory edges -> censored counting.\n" + + " Note that lifetimes calculated in this way are usually biased\n" + + " to lower values because of the limited length of the\n" + + " trajectory and because of truncation/censoring at the\n" + + " trajectory edges.\n" + "\n" - + "2) The average transition rate is calculated as the number of\n" + + "2) The average lifetime is calculated by counting how\n" + + " many frames a given compound stays in a given state excluding\n" + + " truncated states at the trajectory edges -> uncensored\n" + + " counting. Note that lifetimes calculated in this way are\n" + + " usually biased to lower values because of the limited length\n" + + " of the trajectory. Uncensored counting might waste a\n" + + " significant amount of the trajectory.\n" + + "\n" + + "3) The average transition rate is calculated as the number of\n" + " transitions leading out of a given state divided by the number\n" + " frames that compounds have spent in this state. The average\n" + " lifetime is calculated as the inverse transition rate:\n" + " = 1 / \n" + "\n" - + "3) The average lifetime is set to the lag time at which the\n" + + "4) The average lifetime is set to the lag time at which the\n" + " remain probability function p(t) crosses 1/e. If this never\n" + " happens, is set to NaN.\n" + "\n" - + "4) The remain probability function p(t) is interpreted as the\n" + + "5) The remain probability function p(t) is interpreted as the\n" + " survival function of the underlying lifetime distribution.\n" + " Thus, the lifetime can be calculated according to the\n" + " alternative expectation formula [1]:\n" @@ -605,7 +641,7 @@ + " If p(t) does not decay below the given threshold of\n" + " {:.4f}, is set to NaN.\n".format(args.INT_THRESH) + "\n" - + "5) The remain probability function p(t) is fitted by a Kohlrausch\n" + + "6) The remain probability function p(t) is fitted by a Kohlrausch\n" + " function (stretched exponential, survival function of the\n" + " Weibull distribution):\n" + " I_kww(t) = exp[-(t/tau0_kww)^beta_kww]\n" @@ -621,13 +657,13 @@ + " = tau0_kww^n * Gamma(1 + n/beta_kww)\n" + " where Gamma(z) is the gamma function.\n" + "\n" - + "6) The remain probability function p(t) is fitted by the survival\n" + + "7) The remain probability function p(t) is fitted by the survival\n" + " function of a Burr Type XII distribution:\n" + " I_bur(t) = 1 / [1 + (t/tau0_bur)^beta_bur]^delta_bur\n" + " Thereby, tau0_bur is confined to the interval [{}, {}]\n".format( bounds_bur[0][0], bounds_bur[1][0] ) - + " beta_bur is confined to the interval [{}, {}] and\n".format( + + " beta_bur is confined to the interval [{}, {}] and\n".format( bounds_bur[0][1], bounds_bur[1][1] ) + " beta_bur * delta_bur is confined to the interval\n" @@ -660,65 +696,69 @@ + "The columns contain:\n" + " 1 State index (zero-based)\n" + "\n" - + " Lifetime from Method 1 (counting)\n" - + " 2 1st moment / frames\n" - + " 3 2nd moment / frames^2\n" + + " Lifetime from Method 1 (censored counting)\n" + + " 2 1st moment / frames\n" + + " 3 2nd moment / frames^2\n" + + "\n" + + " Lifetime from Method 2 (uncensored counting)\n" + + " 4 1st moment / frames\n" + + " 5 2nd moment / frames^2\n" + "\n" - + " Lifetime from Method 2 (inverse transition rate)\n" - + " 4 / frames\n" + + " Lifetime from Method 3 (inverse transition rate)\n" + + " 6 / frames\n" + "\n" - + " Lifetime from Method 3 (1/e criterion)\n" - + " 5 / frames\n" + + " Lifetime from Method 4 (1/e criterion)\n" + + " 7 / frames\n" + "\n" - + " Lifetime from Method 4 (direct integral)\n" - + " 6 1st moment / frames\n" - + " 7 2nd moment / frames^2\n" + + " Lifetime from Method 5 (direct integral)\n" + + " 8 1st moment / frames\n" + + " 9 2nd moment / frames^2\n" + "\n" - + " Lifetime from Method 5 (integral of Kohlrausch fit)\n" - + " 8 1st moment / frames\n" - + " 9 2nd moment / frames^2\n" - + " 10 Fit parameter tau0_kww / frames\n" - + " 11 Standard deviation of tau0_kww / frames\n" - + " 12 Fit parameter beta_kww\n" - + " 13 Standard deviation of beta_kww\n" - + " 14 Coefficient of determination of the fit (R^2 value)\n" - + " 15 Root-mean-square error (RMSE) of the fit\n" + + " Lifetime from Method 6 (integral of Kohlrausch fit)\n" + + " 10 1st moment / frames\n" + + " 11 2nd moment / frames^2\n" + + " 12 Fit parameter tau0_kww / frames\n" + + " 13 Standard deviation of tau0_kww / frames\n" + + " 14 Fit parameter beta_kww\n" + + " 15 Standard deviation of beta_kww\n" + + " 16 Coefficient of determination of the fit (R^2 value)\n" + + " 17 Root-mean-square error (RMSE) of the fit\n" + "\n" - + " Lifetime from Method 6 (integral of Burr fit)\n" - + " 16 1st moment / frames\n" - + " 17 2nd moment / frames^2\n" - + " 18 Fit parameter tau0_burr / frames\n" - + " 19 Standard deviation of tau0_burr / frames\n" - + " 20 Fit parameter beta_burr\n" - + " 21 Standard deviation of beta_burr\n" - + " 22 Fit parameter delta_burr\n" - + " 23 Standard deviation of delta_burr\n" - + " 24 Coefficient of determination of the fit (R^2 value)\n" - + " 25 Root-mean-square error (RMSE) of the fit\n" + + " Lifetime from Method 7 (integral of Burr fit)\n" + + " 18 1st moment / frames\n" + + " 19 2nd moment / frames^2\n" + + " 20 Fit parameter tau0_burr / frames\n" + + " 21 Standard deviation of tau0_burr / frames\n" + + " 22 Fit parameter beta_burr\n" + + " 23 Standard deviation of beta_burr\n" + + " 24 Fit parameter delta_burr\n" + + " 25 Standard deviation of delta_burr\n" + + " 26 Coefficient of determination of the fit (R^2 value)\n" + + " 27 Root-mean-square error (RMSE) of the fit\n" + "\n" + " Fit region for all fitting methods\n" - + " 26 Start of fit region (inclusive) / frames\n" - + " 27 End of fit region (exclusive) / frames\n" + + " 28 Start of fit region (inclusive) / frames\n" + + " 29 End of fit region (exclusive) / frames\n" ) if args.INFILE_PARAM is not None: header += ( "\n" + " True state lifetimes\n" - + " 28 Shape parameter beta of the true distribution\n" - + " 29 Shape parameter delta of the true distribution\n" - + " 30 Scale parameter tau0 of the true distribution\n" - + " 31 1st moment of the true distribution / frames\n" - + " 32 2nd moment of the true distribution / frames^2\n" - + " 33 1st moment of the drawn lifetimes / frames\n" - + " 34 2nd moment of the drawn lifetimes / frames^2\n" - + " 35 1st moment of the uncensored lifetimes / frames\n" - + " 36 2nd moment of the uncensored lifetimes /" + + " 30 Shape parameter beta of the true distribution\n" + + " 31 Shape parameter delta of the true distribution\n" + + " 32 Scale parameter tau0 of the true distribution\n" + + " 33 1st moment of the true distribution / frames\n" + + " 34 2nd moment of the true distribution / frames^2\n" + + " 35 1st moment of the drawn lifetimes / frames\n" + + " 36 2nd moment of the drawn lifetimes / frames^2\n" + + " 37 1st moment of the uncensored lifetimes / frames\n" + + " 38 2nd moment of the uncensored lifetimes /" + " frames^2\n" - + " 37 1st moment of the censored lifetimes / frames\n" - + " 38 2nd moment of the censored lifetimes / frames^2\n" - + " 39 R^2 if the remain probability is seen as fit of the\n" + + " 39 1st moment of the censored lifetimes / frames\n" + + " 40 2nd moment of the censored lifetimes / frames^2\n" + + " 41 R^2 if the remain probability is seen as fit of the\n" + " survival function (SF) of the true distribution\n" - + " 40 RMSE of the remain probability to the true SF\n" + + " 42 RMSE of the remain probability to the true SF\n" ) header += "\n" + "Column number:\n" header += "{:>14d}".format(1) @@ -734,7 +774,7 @@ print("Creating plot(s)...") timer = datetime.now() lts_mom1s = [ - lts_cnt_mom1, + lts_cnt_cen_mom1, lts_k, lts_e, lts_int_mom1, @@ -774,7 +814,7 @@ with PdfPages(outfile) as pdf: # Plot lifetimes vs. state indices. fig, ax = plt.subplots(clear=True) - # Method 6 (integral of Burr fit). + # Method 7 (integral of Burr fit). ax.errorbar( states, lts_bur_mom1, @@ -784,7 +824,7 @@ marker="^", alpha=alpha, ) - # Method 5 (integral of Kohlrausch fit). + # Method 6 (integral of Kohlrausch fit). ax.errorbar( states, lts_kww_mom1, @@ -794,7 +834,7 @@ marker="v", alpha=alpha, ) - # Method 4 (direct integral) + # Method 5 (direct integral) ax.errorbar( states, lts_int_mom1, @@ -804,7 +844,7 @@ marker=">", alpha=alpha, ) - # Method 3 (1/e criterion). + # Method 4 (1/e criterion). ax.errorbar( states, lts_e, @@ -814,22 +854,32 @@ marker="<", alpha=alpha, ) - # Method 2 (inverse transition rate). + # Method 3 (inverse transition rate). ax.errorbar( states, lts_k, yerr=None, label="Rate", + color="tab:brown", + marker="p", + alpha=alpha, + ) + # Method 2 (uncensored counting). + ax.errorbar( + states, + lts_cnt_unc_mom1, + yerr=np.sqrt(lts_cnt_unc_mom2 - lts_cnt_unc_mom1**2), + label="Uncens. Count", color="tab:red", marker="h", alpha=alpha, ) - # Method 1 (counting). + # Method 1 (censored counting). ax.errorbar( states, - lts_cnt_mom1, - yerr=np.sqrt(lts_cnt_mom2 - lts_cnt_mom1**2), - label="Count", + lts_cnt_cen_mom1, + yerr=np.sqrt(lts_cnt_cen_mom2 - lts_cnt_cen_mom1**2), + label="Cens. Count", color="tab:orange", marker="H", alpha=alpha, @@ -891,7 +941,7 @@ # Plot fit parameter tau0. fig, ax = plt.subplots(clear=True) - # Method 6 (Burr fit). + # Method 7 (Burr fit). ax.errorbar( states, tau0_bur, @@ -900,7 +950,7 @@ color="tab:cyan", marker="^", ) - # Method 5 (Kohlrausch fit). + # Method 6 (Kohlrausch fit). ax.errorbar( states, tau0_kww, @@ -945,7 +995,7 @@ # Plot fit parameter beta. fig, ax = plt.subplots(clear=True) - # Method 6 (Burr fit). + # Method 7 (Burr fit). ax.errorbar( states, beta_bur, @@ -954,7 +1004,7 @@ color="tab:cyan", marker="^", ) - # Method 5 (Kohlrausch fit). + # Method 6 (Kohlrausch fit). ax.errorbar( states, beta_kww, @@ -999,7 +1049,7 @@ # Plot fit parameter delta. fig, ax = plt.subplots(clear=True) - # Method 6 (Burr fit). + # Method 7 (Burr fit). ax.errorbar( states, delta_bur, @@ -1044,7 +1094,7 @@ # Plot R^2 value of the fits. fig, ax = plt.subplots(clear=True) - # Method 6 (Burr fit). + # Method 7 (Burr fit). ax.plot( states, fit_r2_bur, @@ -1052,7 +1102,7 @@ color="tab:cyan", marker="^", ) - # Method 5 (Kohlrausch fit). + # Method 6 (Kohlrausch fit). ax.plot( states, fit_r2_kww, @@ -1092,7 +1142,7 @@ # Plot root-mean-square error. fig, ax = plt.subplots(clear=True) - # Method 6 (Burr fit). + # Method 7 (Burr fit). ax.plot( states, fit_rmse_bur, @@ -1100,7 +1150,7 @@ color="tab:cyan", marker="^", ) - # Method 5 (Kohlrausch fit). + # Method 6 (Kohlrausch fit). ax.plot( states, fit_rmse_kww,