From c5273a313a1f472ec662d4ef476929ed883bbece Mon Sep 17 00:00:00 2001 From: duemoo Date: Fri, 4 Oct 2024 15:39:39 +0900 Subject: [PATCH] cleanup & fix bugs in ppl_analysis.py --- analysis/ppl_analysis.py | 368 +++------------------------------------ 1 file changed, 26 insertions(+), 342 deletions(-) diff --git a/analysis/ppl_analysis.py b/analysis/ppl_analysis.py index 9029de7..0c843ff 100644 --- a/analysis/ppl_analysis.py +++ b/analysis/ppl_analysis.py @@ -22,32 +22,6 @@ def update(num): NUM_REMOVED += num -def fraction_larger_than_first(lst): - if len(lst) <= 1: - return 0 # Avoid division by zero if the list has one or no elements - - first_value = lst[0] - count_larger = sum(1 for x in lst[1:] if x > first_value) - - return count_larger / (len(lst) - 1) - - -def check_success(ppls, threshold=0.2, debug=False): - # print(ppls) - # print((ppls[0]-min(ppls[:50])), '\t', abs(max(ppls[-50:])-min(ppls[-50:]))) - # assert False - if (ppls[0]-min(ppls[:25])) > 1.5*(abs(max(ppls[-25:])-min(ppls[-25:]))): - # if fraction_larger_than_first(ppls)ppls[1]: - print("Fail due to the initial ppl increase") - if fraction_larger_than_first(ppls)= lower_bound[group.name]) & (group['value'] <= upper_bound[group.name])] - ).reset_index(drop=True) - -# - # fig, axes = plt.subplots(2, 1, figsize=(12, 16), gridspec_kw={'height_ratios': [3, 1]}) -# - # Violin plot - sns.violinplot(x='category', y='value', data=filtered_df) - # axes[0, 0].set_xlabel('Bins of sim (intervals)') - # axes[0, 0].set_ylabel('ppl values') - # axes[0, 0].tick_params(axis='x', rotation=45) - - # # Histogram - # sns.histplot(ax=axes[1, i], x=sim_filtered, bins=bins, kde=False) - # axes[1, 0].set_ylabel('Count') - # axes[1, 0].set_xticks(bins) - # axes[1, 0].tick_params(axis='x', rotation=45) - - # plt.tight_layout() - - # Create the directory if it doesn't exist - os.makedirs(f"violin/{exp_name}", exist_ok=True) - filename = f"violin/{exp_name}/{interval}_contain.png" - plt.savefig(filename) - - - # print(len(ppl)) - # print(len(len_filtered)) - - for i, (len_filtered, ppl_filtered) in enumerate([(len_filtered, ppl)]): - # Binning the data - bins = np.linspace(min(len_filtered), max(len_filtered), num=9) - bin_labels = [f"{bins[i]:.2f} - {bins[i+1]:.2f}" for i in range(len(bins)-1)] - bin_map = {label: bins[i] for i, label in enumerate(bin_labels)} - sim_binned = np.digitize(len_filtered, bins, right=True) - sim_binned = [min(i, len(bin_labels)) for i in sim_binned] - - # Creating a DataFrame for plotting - data = pd.DataFrame({ - 'Group': [bin_labels[i-1] for i in sim_binned], - 'Value': ppl_filtered - }) - - # Add a numerical column for sorting - data['SortKey'] = data['Group'].map(bin_map) - data.sort_values('SortKey', inplace=True) - - # Now use 'Group' for plotting labels but sort by 'SortKey' - sns.violinplot(ax=axes[0], x='Group', y='Value', data=data, order=sorted(data['Group'].unique(), key=lambda x: bin_map[x])) - axes[0].set_xlabel('Length') - axes[0].set_ylabel('ppl values') - axes[0].tick_params(axis='x', rotation=45) - - # Histogram - sns.histplot(ax=axes[1], x=len_filtered, bins=bins, kde=False) - # axes[1, i].set_title(f'Histogram of {label}') - # axes[1, i].set_xlabel(label) - axes[1].set_ylabel('Count') - axes[1].set_xticks(bins) - axes[1].tick_params(axis='x', rotation=45) - - plt.tight_layout() - - # Create the directory if it doesn't exist - os.makedirs(f"violin/{exp_name}", exist_ok=True) - filename = f"violin/{exp_name}/{interval}_{mode}_len.png" - plt.savefig(filename) - - - -def draw_violin(sim_dict, ppl, ppl_success, hard, interval, exp_name): - sim_jaccard = sim_dict["jaccard"] - - sim_all_filtered, ppl_all_filtered = filter_data(ppl, sim_jaccard) - sim_success_filtered, ppl_success_filtered = filter_data(ppl_success, sim_jaccard) - - # Setup the figure with subplots - fig, axes = plt.subplots(2, 2, figsize=(24, 16), gridspec_kw={'height_ratios': [3, 1]}) - - for i, (sim_filtered, ppl_filtered) in enumerate([(sim_all_filtered, ppl_all_filtered), (sim_success_filtered, ppl_success_filtered)]): - # Binning the data - bins = np.linspace(min(sim_filtered), max(sim_filtered), num=9) - bin_labels = [f"{bins[i]:.2f} - {bins[i+1]:.2f}" for i in range(len(bins)-1)] - sim_binned = np.digitize(sim_filtered, bins, right=True) - sim_binned = [min(i, len(bin_labels)) for i in sim_binned] - - # Creating a DataFrame for plotting - data = pd.DataFrame({ - 'Group': [bin_labels[i-1] for i in sim_binned], - 'Value': ppl_filtered - }) - - data.sort_values('Group', inplace=True) - - # Violin plot - sns.violinplot(ax=axes[0, i], x='Group', y='Value', data=data) - # axes[0, i].set_title(f'Violin plo with outliers removed') - axes[0, i].set_xlabel('Bins of sim (intervals)') - axes[0, i].set_ylabel('ppl values') - axes[0, i].tick_params(axis='x', rotation=45) - - # Histogram - sns.histplot(ax=axes[1, i], x=sim_filtered, bins=bins, kde=False) - # axes[1, i].set_title(f'Histogram of {label}') - # axes[1, i].set_xlabel(label) - axes[1, i].set_ylabel('Count') - axes[1, i].set_xticks(bins) - axes[1, i].tick_params(axis='x', rotation=45) - - plt.tight_layout() - - # Create the directory if it doesn't exist - os.makedirs(f"violin/{exp_name}", exist_ok=True) - filename = f"violin/{exp_name}/{interval}_{'hard' if hard else 'easy'}.png" - plt.savefig(filename) - - def round(num): if num%10<5: return num//10*10-1 @@ -246,43 +77,6 @@ def levenshtein(s1, s2, debug=False): return previous_row[-1] -def spectrum_analysis(values): - """ - Perform linear detrending and Fourier analysis on a time-series data. - - :param values: List of floats representing the time-series data. - :return: Plot of the frequency spectrum. - """ - - # Time parameters (assuming equal spacing) - N = len(values) # Number of data points - T = 1.0 / N # Assuming unit time interval between data points - - # Linear Detrending - times = np.arange(N) - detrended = values - np.poly1d(np.polyfit(times, values, 1))(times) - - # Fourier Transform - freq_values = fftpack.fft(detrended) - freqs = fftpack.fftfreq(N, T) - freq_magnitudes = np.abs(freq_values) * 1 / N - - # Normalizing to make the area under the curve 1 - total_area = np.sum(freq_magnitudes) * (freqs[1] - freqs[0]) # Approximate the integral - normalized_magnitudes = freq_magnitudes / total_area - - # Plotting the Frequency Spectrum - # plt.figure(figsize=(10, 5)) - # plt.plot(freqs[:N // 2][1:], normalized_magnitudes[:N // 2][1:]) # Plot only the positive frequencies - # plt.xlabel('Frequency') - # plt.ylabel('Amplitude') - # plt.title('Frequency Spectrum') - # plt.grid(True) - # plt.show() - # plt.savefig('spectrum_mem.png') - return freqs[:N // 2][1:], normalized_magnitudes[:N // 2][1:] - - def remove_outliers_iqr(data, multiplier=2.0, log=False, is_retainability=False): # print(data) q1 = np.percentile(data, 25) @@ -320,76 +114,6 @@ def sort_idx(scores): return [index for index, value in sorted_pairs] -def get_perturb_indices(l, max_len=500, margin=25): - if len(l)==0: - return [] - else: - result = [] - for i in range(len(l)-1): - if l[i]+margin=80 @@ -603,7 +282,6 @@ def measure_scores(result, interval=50, skip_log_learnability=False, skip_log_fo gen_learnability_per_ex, gen_forgetting_per_ex, gen_init_per_ex, gen_last_per_ex, gen_learnability_step_per_ex, gen_forgetting_step_per_ex = get_probe_measurements(gen_ppls, gen_learnability_per_ex, gen_forgetting_per_ex, gen_learnability_step_per_ex, gen_forgetting_step_per_ex, gen_init_per_ex, gen_last_per_ex, interval, relative=relative, absolute=absolute, ex_idx=ex_idx, j=j, mode='gen', once=is_once) gen_hard_learnability_per_ex, gen_hard_forgetting_per_ex, gen_hard_last_per_ex, gen_hard_last_per_ex, gen_hard_learnability_step_per_ex, gen_hard_forgetting_step_per_ex = get_probe_measurements(gen_hard_ppls, gen_hard_learnability_per_ex, gen_hard_forgetting_per_ex, gen_hard_learnability_step_per_ex, gen_hard_forgetting_step_per_ex, gen_hard_init_per_ex, gen_hard_last_per_ex, interval, absolute=absolute, relative=relative, ex_idx=ex_idx, j=j, mode='hard-gen', once=is_once) - if ex_idx+1 in [40, 80, 120]: # remove outliers for k in mem_learnability_per_ex.keys(): @@ -614,10 +292,6 @@ def measure_scores(result, interval=50, skip_log_learnability=False, skip_log_fo gen_forgetting_per_ex[k] = remove_outliers_iqr(gen_forgetting_per_ex[k], log=log, is_retainability=k=='target') mem_forgetting_per_ex[k] = remove_outliers_iqr(mem_forgetting_per_ex[k], log=log, is_retainability=k=='target') gen_hard_forgetting_per_ex[k] = remove_outliers_iqr(gen_hard_forgetting_per_ex[k], log=log, is_retainability=k=='target') - - # print(len(zip(mem_learnability_step_per_ex))) - # print(len(zip(mem_learnability_step_per_ex[0]))) - # print(gen_forgetting_step_per_ex) mem_learnability_step_per_ex = [remove_outliers_iqr([ai for ai in a if ai is not None]) for a in zip(*mem_learnability_step_per_ex)] gen_learnability_step_per_ex = [remove_outliers_iqr([ai for ai in a if ai is not None]) for a in zip(*gen_learnability_step_per_ex)] @@ -632,6 +306,9 @@ def measure_scores(result, interval=50, skip_log_learnability=False, skip_log_fo forgetting_score["paraphrase"]["mem"] = mean(mem_forgetting_per_ex['target']) forgetting_score["paraphrase"]["gen"] = mean(gen_forgetting_per_ex['target']) forgetting_score["paraphrase"]["gen_hard"] = mean(gen_hard_forgetting_per_ex['target']) + learnability_score["paraphrase"]["mem"] = mean(mem_learnability_per_ex['target']) + learnability_score["paraphrase"]["gen"] = mean(gen_learnability_per_ex['target']) + learnability_score["paraphrase"]["gen_hard"] = mean(gen_hard_learnability_per_ex['target']) step_forgetting_score["paraphrase"]["mem"] = [mean(a) for a in mem_forgetting_step_per_ex] step_forgetting_score["paraphrase"]["gen"] = [mean(a) for a in gen_forgetting_step_per_ex] step_forgetting_score["paraphrase"]["gen_hard"] = [mean(a) for a in gen_hard_forgetting_step_per_ex] @@ -642,6 +319,9 @@ def measure_scores(result, interval=50, skip_log_learnability=False, skip_log_fo forgetting_score["duplication"]["mem"] = mean(mem_forgetting_per_ex['target']) forgetting_score["duplication"]["gen"] = mean(gen_forgetting_per_ex['target']) forgetting_score["duplication"]["gen_hard"] = mean(gen_hard_forgetting_per_ex['target']) + learnability_score["duplication"]["mem"] = mean(mem_learnability_per_ex['target']) + learnability_score["duplication"]["gen"] = mean(gen_learnability_per_ex['target']) + learnability_score["duplication"]["gen_hard"] = mean(gen_hard_learnability_per_ex['target']) step_forgetting_score["duplication"]["mem"] = [mean(a) for a in mem_forgetting_step_per_ex] step_forgetting_score["duplication"]["gen"] = [mean(a) for a in gen_forgetting_step_per_ex] step_forgetting_score["duplication"]["gen_hard"] = [mean(a) for a in gen_hard_forgetting_step_per_ex] @@ -652,6 +332,9 @@ def measure_scores(result, interval=50, skip_log_learnability=False, skip_log_fo forgetting_score["once"]["mem"] = mean(mem_forgetting_per_ex['target']) forgetting_score["once"]["gen"] = mean(gen_forgetting_per_ex['target']) forgetting_score["once"]["gen_hard"] = mean(gen_hard_forgetting_per_ex['target']) + learnability_score["once"]["mem"] = mean(mem_learnability_per_ex['target']) + learnability_score["once"]["gen"] = mean(gen_learnability_per_ex['target']) + learnability_score["once"]["gen_hard"] = mean(gen_hard_learnability_per_ex['target']) step_forgetting_score["once"]["mem"] = [a for a in mem_forgetting_step_per_ex] step_forgetting_score["once"]["gen"] = [a for a in gen_forgetting_step_per_ex] step_forgetting_score["once"]["gen_hard"] = [a for a in gen_hard_forgetting_step_per_ex] @@ -710,7 +393,8 @@ def measure_scores(result, interval=50, skip_log_learnability=False, skip_log_fo if skip_log_forgetting: with open(f"step_eval/{args.exp_name}", 'w') as f: json.dump({'effectivity': step_learnability_score, 'retainability': step_forgetting_score}, f, indent=4) - + with open(f"learnability_eval/{args.exp_name}", 'w') as f: + json.dump(learnability_score, f, indent=4) return forgetting_score @@ -806,8 +490,8 @@ def plot_perplexity(rows, cols, plot_number, steps, x_mem, x_gen, xlabel, ylabel # ax.xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: f'{int(x/1000)}k')) # Format tick labels as 'k' units - ymin, ymax = -2.0, 2.5 - # ymin, ymax = -0.1, 1.2 + # ymin, ymax = 0.0, 2.5 + ymin, ymax = -0.1, 1.8 ax.set_ylim(ymin, ymax) # ymin, ymax = ax.get_ylim() # ymax=500 @@ -1065,7 +749,7 @@ def main(args): # Add arguments - parser.add_argument('--base_dir', type=str) + parser.add_argument('--base_dir', type=str, default='/home/hoyeon/OLMo') parser.add_argument('--save_dir', type=str, default="figs") parser.add_argument('--exp_name', type=str, required=True) parser.add_argument('--mode', type=str, default="draw_figures")