Feat/return text & graph instead of print & show (#292)
L-M-Sherlock authored Jun 10, 2023
1 parent 57fa242 commit 1f402f8
Showing 4 changed files with 457 additions and 427 deletions.
702 changes: 364 additions & 338 deletions fsrs4anki_optimizer.ipynb

Large diffs are not rendered by default.

15 changes: 11 additions & 4 deletions package/fsrs4anki_optimizer/__main__.py
@@ -75,17 +75,21 @@ def remembered_fallback_prompt(key: str, pretty: str = None):

optimizer = fsrs4anki_optimizer.Optimizer()
optimizer.anki_extract(args.filename)
- optimizer.create_time_series(
+ analysis = optimizer.create_time_series(
remembered_fallbacks["timezone"],
remembered_fallbacks["revlog_start_date"],
remembered_fallbacks["next_day"]
)
+ print(analysis)

optimizer.define_model()
optimizer.train()

optimizer.predict_memory_states()
- optimizer.find_optimal_retention(show_graphs)
+ figures = optimizer.find_optimal_retention()
+ if show_graphs:
+     for f in figures:
+         f.show()

optimizer.preview(optimizer.optimal_retention)

@@ -109,6 +113,9 @@ def remembered_fallback_prompt(key: str, pretty: str = None):
with open(args.out, "a+") as f:
f.write(profile)

- optimizer.evaluate()
- optimizer.calibration_graph()
+ if show_graphs:
+     optimizer.evaluate()
+     for f in optimizer.calibration_graph():
+         f.show()
+     for f in optimizer.compare_with_sm2():
+         f.show()
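
With graphs and analysis text now returned instead of shown, the CLI above owns every display decision. A minimal sketch of a headless caller that saves everything to disk instead; the file names and argument values are placeholders, not part of this commit:

    import fsrs4anki_optimizer

    optimizer = fsrs4anki_optimizer.Optimizer()
    optimizer.anki_extract("collection.apkg")          # path is illustrative
    analysis = optimizer.create_time_series("Europe/London", "2022-01-01", 4)
    with open("analysis.txt", "w") as f:
        f.write(analysis)          # the returned string replaces the old print()

    optimizer.define_model()
    optimizer.train()
    optimizer.predict_memory_states()
    for i, fig in enumerate(optimizer.find_optimal_retention()):
        fig.savefig(f"retention_{i}.png")   # Figure.savefig instead of f.show()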
165 changes: 81 additions & 84 deletions package/fsrs4anki_optimizer/fsrs4anki_optimizer.py
@@ -17,19 +17,10 @@
from sklearn.model_selection import StratifiedGroupKFold
from sklearn.metrics import mean_squared_error, r2_score
from itertools import accumulate
+ from tqdm.auto import tqdm
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

- def is_interactive(): # https://stackoverflow.com/questions/15411967/how-can-i-check-if-code-is-executed-in-the-ipython-notebook
-     import __main__ as main
-     return not hasattr(main, '__file__')
-
- if is_interactive():
-     from tqdm import notebook
- else:
-     # Export cli module pretending to be notebook if not in notebook
-     from tqdm import cli as notebook

class FSRS(nn.Module):
def __init__(self, w):
super(FSRS, self).__init__()
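
The import hunk above replaces the hand-rolled is_interactive() shim with tqdm.auto, which resolves to the notebook widget under Jupyter and the console bar otherwise. It also enables the tqdm.write calls used below, which print above a live bar instead of corrupting its line. A self-contained sketch of that pattern (demo code, not from this repo):

    from tqdm.auto import tqdm   # picks the notebook or console flavour itself
    import time

    pbar = tqdm(desc="demo", total=3)
    for i in range(3):
        time.sleep(0.1)
        tqdm.write(f"step {i} done")   # log line rendered above the bar
        pbar.update(1)
    pbar.close()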
@@ -208,7 +199,7 @@ def train(self):
best_loss = weighted_loss
best_w = w

pbar = notebook.tqdm(desc="pre-train", colour="red", total=len(self.pre_train_data_loader) * self.n_epoch)
pbar = tqdm(desc="pre-train", colour="red", total=len(self.pre_train_data_loader) * self.n_epoch, )
for k in range(self.n_epoch):
for i, batch in enumerate(self.pre_train_data_loader):
self.model.train()
@@ -229,7 +220,7 @@ def train(self):
print(f"{name}: {list(map(lambda x: round(float(x), 4),param))}")

epoch_len = len(self.next_train_data_loader)
pbar = notebook.tqdm(desc="train", colour="red", total=epoch_len*self.n_epoch)
pbar = tqdm(desc="train", colour="red", total=epoch_len*self.n_epoch)
print_len = max(self.batch_nums*self.n_epoch // 10, 1)
for k in range(self.n_epoch):
weighted_loss, w = self.eval()
@@ -255,9 +246,9 @@ def train(self):
pbar.update(real_batch_size)

if (k * self.batch_nums + i + 1) % print_len == 0:
print(f"iteration: {k * epoch_len + (i + 1) * self.batch_size}")
tqdm.write(f"iteration: {k * epoch_len + (i + 1) * self.batch_size}")
for name, param in self.model.named_parameters():
print(f"{name}: {list(map(lambda x: round(float(x), 4),param))}")
tqdm.write(f"{name}: {list(map(lambda x: round(float(x), 4),param))}")
pbar.close()

weighted_loss, w = self.eval()
@@ -277,7 +268,7 @@ def eval(self):
retentions = torch.exp(np.log(0.9) * delta_ts / stabilities)
tran_loss = self.loss_fn(retentions, labels)/len(self.train_set)
self.avg_train_losses.append(tran_loss)
print(f"Loss in trainset: {tran_loss:.4f}")
tqdm.write(f"Loss in trainset: {tran_loss:.4f}")

sequences, delta_ts, labels, seq_lens = self.test_set.x_train, self.test_set.t_train, self.test_set.y_train, self.test_set.seq_len
real_batch_size = seq_lens.shape[0]
@@ -286,7 +277,7 @@ def eval(self):
retentions = torch.exp(np.log(0.9) * delta_ts / stabilities)
test_loss = self.loss_fn(retentions, labels)/len(self.test_set)
self.avg_eval_losses.append(test_loss)
print(f"Loss in testset: {test_loss:.4f}")
tqdm.write(f"Loss in testset: {test_loss:.4f}")

w = list(map(lambda x: round(float(x), 4), dict(self.model.named_parameters())['w'].data))

@@ -295,12 +286,14 @@ def eval(self):
return weighted_loss, w

def plot(self):
- plt.plot(self.avg_train_losses, label='train')
- plt.plot(self.avg_eval_losses, label='test')
- plt.xlabel('epoch')
- plt.ylabel('loss')
- plt.legend()
- plt.show()
+ fig = plt.figure()
+ ax = fig.gca()
+ ax.plot(self.avg_train_losses, label='train')
+ ax.plot(self.avg_eval_losses, label='test')
+ ax.set_xlabel('epoch')
+ ax.set_ylabel('loss')
+ ax.legend()
+ return fig

class Collection:
def __init__(self, w):
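
Trainer.plot() now hands the Figure back instead of calling plt.show(), so the caller decides whether to display, embed, or save it. One caveat: plt.figure() still registers the figure with pyplot's figure manager, so batch callers may want to close it after saving. A sketch, assuming a trained Trainer instance named trainer:

    import matplotlib.pyplot as plt

    fig = trainer.plot()                     # returns a Figure, renders nothing
    fig.savefig("loss_curves.png", dpi=150)
    plt.close(fig)                   # release pyplot's reference in long loops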
@@ -322,8 +315,7 @@ def batch_predict(self, dataset):
"""Used to store all the results from FSRS related functions"""
class Optimizer:
def __init__(self) -> None:
- notebook.tqdm.pandas()
- pass
+ tqdm.pandas()

@staticmethod
def anki_extract(filename: str):
@@ -432,7 +424,7 @@ def cal_stability(group: pd.DataFrame) -> pd.DataFrame:
df.sort_values(by=['r_history'], inplace=True, ignore_index=True)

if df.shape[0] > 0:
- for idx in notebook.tqdm(df.index):
+ for idx in tqdm(df.index):
item = df.loc[idx]
index = df[(df['i'] == item['i'] + 1) & (df['r_history'].str.startswith(item['r_history']))].index
df.loc[index, 'last_stability'] = item['stability']
@@ -441,9 +433,10 @@ def cal_stability(group: pd.DataFrame) -> pd.DataFrame:
df['last_recall'] = df['r_history'].map(lambda x: x[-1])
df = df[df.groupby(['i', 'r_history'], group_keys=False)['group_cnt'].transform(max) == df['group_cnt']]
df.to_csv('./stability_for_analysis.tsv', sep='\t', index=None)
print("1:again, 2:hard, 3:good, 4:easy\n")
print(df[df['r_history'].str.contains(r'^[1-4][^124]*$', regex=True)][['r_history', 'avg_interval', 'avg_retention', 'stability', 'factor', 'group_cnt']].to_string(index=False))
print("Analysis saved!")
caption = "1:again, 2:hard, 3:good, 4:easy\n"
analysis = df[df['r_history'].str.contains(r'^[1-4][^124]*$', regex=True)][['r_history', 'avg_interval', 'avg_retention', 'stability', 'factor', 'group_cnt']].to_string(index=False)
return caption + analysis

def define_model(self):
"""Step 3"""
@@ -475,6 +468,7 @@ def train(self, lr: float = 4e-2, n_epoch: int = 3, n_splits: int = 3, batch_siz
print("Tensorized!")

w = []
+ plots = []
if n_splits > 1:
sgkf = StratifiedGroupKFold(n_splits=n_splits)
for train_index, test_index in sgkf.split(self.dataset, self.dataset['i'], self.dataset['group']):
@@ -483,17 +477,18 @@ def train(self, lr: float = 4e-2, n_epoch: int = 3, n_splits: int = 3, batch_siz
test_set = self.dataset.iloc[test_index].copy()
trainer = Trainer(train_set, test_set, self.init_w, n_epoch=n_epoch, lr=lr, batch_size=batch_size)
w.append(trainer.train())
- trainer.plot()
+ plots.append(trainer.plot())
else:
trainer = Trainer(self.dataset, self.dataset, self.init_w, n_epoch=n_epoch, lr=lr, batch_size=batch_size)
w.append(trainer.train())
- trainer.plot()
+ plots.append(trainer.plot())

w = np.array(w)
avg_w = np.round(np.mean(w, axis=0), 4)
self.w = avg_w.tolist()

print("\nTraining finished!")
+ return plots

def preview(self, requestRetention: float):
my_collection = Collection(self.w)
Expand Down Expand Up @@ -563,7 +558,7 @@ def predict_memory_states(self):
if i+1 in self.difficulty_distribution.index:
self.difficulty_distribution_padding[i] = self.difficulty_distribution.loc[i+1]

- def find_optimal_retention(self, graph=True):
+ def find_optimal_retention(self):
"""should not be called before predict_memory_states"""

base = 1.01
Expand Down Expand Up @@ -614,7 +609,7 @@ def cal_next_recall_stability(s, r, d, response):
print(f"terminal stability: {stability_list.max(): .2f}")
df = pd.DataFrame(columns=["retention", "difficulty", "time"])

- for percentage in notebook.tqdm(range(96, 66, -2)):
+ for percentage in tqdm(range(96, 66, -2)):
recall = percentage / 100
time_list = np.zeros((d_range, index_len))
time_list[:,:-1] = max_time
Expand Down Expand Up @@ -646,24 +641,25 @@ def cal_next_recall_stability(s, r, d, response):
print("expected_time.csv saved.")

optimal_retention_list = np.zeros(10)
+ fig = plt.figure()
+ ax = fig.gca()
for d in range(1, d_range+1):
retention = df[df["difficulty"] == d]["retention"]
cost = df[df["difficulty"] == d]["time"]
optimal_retention = retention.iat[cost.argmin()]
optimal_retention_list[d-1] = optimal_retention
- plt.plot(retention, cost, label=f"d={d}, r={optimal_retention}")
+ ax.plot(retention, cost, label=f"d={d}, r={optimal_retention}")

self.optimal_retention = np.inner(self.difficulty_distribution_padding, optimal_retention_list)

print(f"\n-----suggested retention (experimental): {self.optimal_retention:.2f}-----")

- if graph:
-     plt.ylabel("expected time (second)")
-     plt.xlabel("retention")
-     plt.legend()
-     plt.grid()
-     plt.semilogy()
-     plt.show()
+ ax.set_ylabel("expected time (second)")
+ ax.set_xlabel("retention")
+ ax.legend()
+ ax.grid()
+ ax.semilogy()
+ return (fig, )

def evaluate(self):
my_collection = Collection(self.init_w)
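
Note the one-element tuple (fig, ) returned above: wrapping the single figure keeps every graph-producing method iterable, so callers such as __main__.py can use one uniform pattern:

    for fig in optimizer.find_optimal_retention():
        fig.show()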
@@ -692,14 +688,13 @@ def evaluate(self):
del tmp

def calibration_graph(self):
- plot_brier(self.dataset['p'], self.dataset['y'], bins=40)
- plt.show()
+ fig1 = plot_brier(self.dataset['p'], self.dataset['y'], bins=40)

def to_percent(temp, position):
return '%1.0f' % (100 * temp) + '%'

- fig = plt.figure(1)
- ax1 = fig.add_subplot(111)
+ fig2 = plt.figure()
+ ax1 = fig2.add_subplot(111)
ax2 = ax1.twinx()
lns = []

@@ -725,13 +720,12 @@ def to_percent(temp, position):

labs = [l.get_label() for l in lns]
ax2.legend(lns, labs, loc='lower right')
- plt.grid(linestyle='--')
- plt.gca().yaxis.set_major_formatter(ticker.FuncFormatter(to_percent))
- plt.gca().xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
- plt.show()
+ ax2.grid(linestyle='--')
+ ax2.yaxis.set_major_formatter(ticker.FuncFormatter(to_percent))
+ ax2.xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))

- fig = plt.figure(1)
- ax1 = fig.add_subplot(111)
+ fig3 = plt.figure()
+ ax1 = fig3.add_subplot(111)
ax2 = ax1.twinx()
lns = []

@@ -756,10 +750,11 @@ def to_percent(temp, position):

labs = [l.get_label() for l in lns]
ax2.legend(lns, labs, loc='lower right')
- plt.grid(linestyle='--')
- plt.gca().yaxis.set_major_formatter(ticker.FuncFormatter(to_percent))
- plt.gca().xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))
- plt.show()
+ ax2.grid(linestyle='--')
+ ax2.yaxis.set_major_formatter(ticker.FuncFormatter(to_percent))
+ ax2.xaxis.set_major_formatter(ticker.FormatStrFormatter('%d'))

+ return fig1, fig2, fig3

def bw_matrix(self):
B_W_Metric_raw = self.dataset[['difficulty', 'stability', 'p', 'y']].copy()
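
In calibration_graph above, the two plt.figure(1) calls had to go: pyplot reuses a numbered figure, so both twin-axis charts would have been drawn onto the same Figure object and the method could not have returned three distinct graphs. Unnumbered plt.figure() calls always create fresh figures, as this standalone check illustrates:

    import matplotlib.pyplot as plt

    a = plt.figure(1)
    b = plt.figure(1)
    assert a is b            # a numbered figure is reused on the second call
    c = plt.figure()
    d = plt.figure()
    assert c is not d        # unnumbered calls always create a new figure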
@@ -779,37 +774,38 @@ def compare_with_sm2(self):
self.dataset['log_loss'] = self.dataset.apply(lambda row: - np.log(row['sm2_p']) if row['y'] == 1 else - np.log(1 - row['sm2_p']), axis=1)
print(f"Loss of SM-2: {self.dataset['log_loss'].mean():.4f}")
cross_comparison = self.dataset[['sm2_p', 'p', 'y']].copy()
- plot_brier(cross_comparison['sm2_p'], cross_comparison['y'], bins=40)
+ fig1 = plot_brier(cross_comparison['sm2_p'], cross_comparison['y'], bins=40)

- plt.figure(figsize=(6, 6))
+ fig2 = plt.figure(figsize=(6, 6))
+ ax = fig2.gca()

cross_comparison['SM2_B-W'] = cross_comparison['sm2_p'] - cross_comparison['y']
cross_comparison['SM2_bin'] = cross_comparison['sm2_p'].map(lambda x: round(x, 1))
cross_comparison['FSRS_B-W'] = cross_comparison['p'] - cross_comparison['y']
cross_comparison['FSRS_bin'] = cross_comparison['p'].map(lambda x: round(x, 1))

- plt.axhline(y = 0.0, color = 'black', linestyle = '-')
+ ax.axhline(y = 0.0, color = 'black', linestyle = '-')

cross_comparison_group = cross_comparison.groupby(by='SM2_bin').agg({'y': ['mean'], 'FSRS_B-W': ['mean'], 'p': ['mean', 'count']})
print(f"Universal Metric of FSRS: {mean_squared_error(cross_comparison_group['y', 'mean'], cross_comparison_group['p', 'mean'], sample_weight=cross_comparison_group['p', 'count'], squared=False):.4f}")
cross_comparison_group['p', 'percent'] = cross_comparison_group['p', 'count'] / cross_comparison_group['p', 'count'].sum()
- plt.scatter(cross_comparison_group.index, cross_comparison_group['FSRS_B-W', 'mean'], s=cross_comparison_group['p', 'percent'] * 1024, alpha=0.5)
- plt.plot(cross_comparison_group['FSRS_B-W', 'mean'], label='FSRS by SM2')
+ ax.scatter(cross_comparison_group.index, cross_comparison_group['FSRS_B-W', 'mean'], s=cross_comparison_group['p', 'percent'] * 1024, alpha=0.5)
+ ax.plot(cross_comparison_group['FSRS_B-W', 'mean'], label='FSRS by SM2')

cross_comparison_group = cross_comparison.groupby(by='FSRS_bin').agg({'y': ['mean'], 'SM2_B-W': ['mean'], 'sm2_p': ['mean', 'count']})
print(f"Universal Metric of SM2: {mean_squared_error(cross_comparison_group['y', 'mean'], cross_comparison_group['sm2_p', 'mean'], sample_weight=cross_comparison_group['sm2_p', 'count'], squared=False):.4f}")
cross_comparison_group['sm2_p', 'percent'] = cross_comparison_group['sm2_p', 'count'] / cross_comparison_group['sm2_p', 'count'].sum()
- plt.scatter(cross_comparison_group.index, cross_comparison_group['SM2_B-W', 'mean'], s=cross_comparison_group['sm2_p', 'percent'] * 1024, alpha=0.5)
- plt.plot(cross_comparison_group['SM2_B-W', 'mean'], label='SM2 by FSRS')
-
- plt.legend(loc='lower center')
- plt.grid(linestyle='--')
- plt.title("SM2 vs. FSRS")
- plt.xlabel('Predicted R')
- plt.ylabel('B-W Metric')
- plt.xlim(0, 1)
- plt.xticks(np.arange(0, 1.1, 0.1))
- plt.show()
+ ax.scatter(cross_comparison_group.index, cross_comparison_group['SM2_B-W', 'mean'], s=cross_comparison_group['sm2_p', 'percent'] * 1024, alpha=0.5)
+ ax.plot(cross_comparison_group['SM2_B-W', 'mean'], label='SM2 by FSRS')
+
+ ax.legend(loc='lower center')
+ ax.grid(linestyle='--')
+ ax.set_title("SM2 vs. FSRS")
+ ax.set_xlabel('Predicted R')
+ ax.set_ylabel('B-W Metric')
+ ax.set_xlim(0, 1)
+ ax.set_xticks(np.arange(0, 1.1, 0.1))
+ return fig1, fig2

# code from https://github.com/papousek/duolingo-halflife-regression/blob/master/evaluation.py
def load_brier(predictions, real, bins=20):
@@ -849,27 +845,28 @@ def plot_brier(predictions, real, bins=20):
rmse = np.sqrt(mean_squared_error(bin_correct_means, bin_prediction_means, sample_weight=bin_counts))
print(f"R-squared: {r2:.4f}")
print(f"RMSE: {rmse:.4f}")
- plt.figure()
- ax = plt.gca()
- ax.set_xlim([0, 1])
- ax.set_ylim([0, 1])
- plt.grid(True)
+ fig = plt.figure()
+ ax1 = fig.add_subplot(111)
+ ax1.set_xlim([0, 1])
+ ax1.set_ylim([0, 1])
+ ax1.grid(True)
fit_wls = sm.WLS(bin_correct_means, sm.add_constant(bin_prediction_means), weights=bin_counts).fit()
print(fit_wls.params)
y_regression = [fit_wls.params[0] + fit_wls.params[1]*x for x in bin_prediction_means]
- plt.plot(bin_prediction_means, y_regression, label='Weighted Least Squares Regression', color="green")
- plt.plot(bin_prediction_means, bin_correct_means, label='Actual Calibration', color="#1f77b4")
- plt.plot((0, 1), (0, 1), label='Perfect Calibration', color="#ff7f0e")
+ ax1.plot(bin_prediction_means, y_regression, label='Weighted Least Squares Regression', color="green")
+ ax1.plot(bin_prediction_means, bin_correct_means, label='Actual Calibration', color="#1f77b4")
+ ax1.plot((0, 1), (0, 1), label='Perfect Calibration', color="#ff7f0e")
bin_count = brier['detail']['bin_count']
counts = np.array(bin_counts)
bins = (np.arange(bin_count) + 0.5) / bin_count
- plt.legend(loc='upper center')
- plt.xlabel('Predicted R')
- plt.ylabel('Actual R')
- plt.twinx()
- plt.ylabel('Number of reviews')
- plt.bar(bins, counts, width=(0.8 / bin_count), ec='k', lw=.2, alpha=0.5, label='Number of reviews')
- plt.legend(loc='lower center')
+ ax1.legend(loc='upper center')
+ ax1.set_xlabel('Predicted R')
+ ax1.set_ylabel('Actual R')
+ ax2 = ax1.twinx()
+ ax2.set_ylabel('Number of reviews')
+ ax2.bar(bins, counts, width=(0.8 / bin_count), ec='k', lw=.2, alpha=0.5, label='Number of reviews')
+ ax2.legend(loc='lower center')
+ return fig

def sm2(history):
ivl = 0
