diff --git a/README.md b/README.md
index 45a0b3f..54c1097 100644
--- a/README.md
+++ b/README.md
@@ -1,212 +1,135 @@
-<!-- OLMo logo -->
-
-OLMo: Open Language Model
-
+How Do Large Language Models Acquire Factual Knowledge During Pretraining? (https://arxiv.org/abs/2406.11813)
+
-<!-- GitHub License, GitHub release, and Paper URL badges -->
-
-OLMo is a repository for training and using AI2's state-of-the-art open language models.
-It is built by scientists, for scientists.
-
-## Installation
-
-First install [PyTorch](https://pytorch.org) according to the instructions specific to your operating system.
+This repository contains the code, dataset, and experimental log data for the paper 'How Do Large Language Models Acquire Factual Knowledge During Pretraining?'. The code is based on the [OLMo](https://github.com/allenai/OLMo) project, with modifications that support knowledge injection during pre-training, plus additional analysis tools.
 
-To install from source (recommended for training/fine-tuning) run:
+You can also find the Fictional Knowledge dataset on Hugging Face: https://huggingface.co/datasets/kaist-ai/fictional-knowledge
 
-```bash
-git clone https://github.com/allenai/OLMo.git
-cd OLMo
-pip install -e .[all]
-```
+## Key Differences from the Original OLMo Repository
 
-Otherwise you can install the model code by itself directly from PyPI with:
+1. Modified `olmo/train.py` to:
+   - apply knowledge injection during pre-training, and
+   - log perplexity measurements on the fictional knowledge dataset probes during pretraining.
+2. Added post-processing of perplexity logs and analysis utilities in the `analysis/` folder.
+3. Added post-processed perplexity logs for the main experiments (three OLMo-7B intermediate checkpoints) in the `analysis/results/` folder; these can be downloaded with Git LFS.
+
+## Installation
 
 ```bash
-pip install ai2-olmo
+git clone <URL_OF_THIS_REPOSITORY>
+cd factual-knowledge-acquisition
+pip install -e .
 ```
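+To get a quick look at the probe data used by the analysis scripts, you can open the dataset locally. A minimal sketch (the path and field names follow `analysis/ppl_analysis.py`; run from the repository root):
+
+```python
+import json
+
+# Peek at the probe structure used throughout `analysis/`.
+with open("fictional_knowledge/fictional_knowledge_paraphrased.json") as f:
+    dataset = json.load(f)
+
+first = dataset[0]
+# Each example carries memorization, semantic (gen), and composition
+# (hard_gen) probes, each a list of input/target pairs.
+print(first["mem_input"][0], "->", first["mem_target"][0])
+print(first["gen_input"][0], "->", first["gen_target"][0])
+print(first["hard_gen_input"][0], "->", first["hard_gen_target"][0])
+```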
-## Models
-
-### Overview
-
-The core models in the OLMo family released so far are (all trained on the [Dolma dataset](https://huggingface.co/datasets/allenai/dolma)):
-
-| Model | Training Tokens | Context Length | Training Config | W&B Logs | Data Order File(s) ☨ |
-|-------|-----------------|:--------------:|-----------------|----------|--------------------|
-| [OLMo 1B](https://huggingface.co/allenai/OLMo-1B) | 3 Trillion | 2048 | [configs/official/OLMo-1B.yaml](https://github.com/allenai/OLMo/blob/main/configs/official/OLMo-1B.yaml) | [wandb.ai/…/OLMo-1B](https://wandb.ai/ai2-llm/OLMo-1B/reports/OLMo-1B--Vmlldzo2NzY1Njk1) | [epoch 1](https://olmo-checkpoints.org/ai2-llm/olmo-small/46zc5fly/train_data/global_indices.npy) |
-| [OLMo 7B](https://huggingface.co/allenai/OLMo-7B) | 2.5 Trillion | 2048 | [configs/official/OLMo-7B.yaml](https://github.com/allenai/OLMo/blob/main/configs/official/OLMo-7B.yaml) | [wandb.ai/…/OLMo-7B](https://wandb.ai/ai2-llm/OLMo-7B/reports/OLMo-7B--Vmlldzo2NzQyMzk5) | [epoch 1](https://olmo-checkpoints.org/ai2-llm/olmo-medium/wvc30anm/train_data/global_indices.npy), [epoch 2](https://olmo-checkpoints.org/ai2-llm/olmo-medium/wd2gxrza/train_data/global_indices.npy) |
-| [OLMo 7B Twin 2T](https://huggingface.co/allenai/OLMo-7B-Twin-2T) | 2 Trillion | 2048 | [configs/official/OLMo-7B.yaml](https://github.com/allenai/OLMo/blob/main/configs/official/OLMo-7B.yaml) | [wandb.ai/…/OLMo-7B-Twin-2T](https://wandb.ai/ai2-llm/OLMo-7B/reports/OLMo-7B-Twin-2T--Vmlldzo2NzU0NTIz) | [epoch 1](https://olmo-checkpoints.org/ai2-llm/olmo-medium/wvc30anm/train_data/global_indices.npy) |
-
-> ☨ *See [Inspecting training data](#inspecting-training-data) below for usage.*
-
-### Checkpoints
-
-URLs to checkpoints at intermediate steps of the models' trainings can be found in the csv files under [`checkpoints/official/`](https://github.com/allenai/OLMo/blob/main/checkpoints/official).
-These 'directory' URLs cannot currently be directly accessed, but files within the directory are publicly accessible. These URLs can also be provided to the training script to resume training from the checkpoint (see [Training](#training)). Each checkpoint directory consists of:
-
-- `config.yaml`: the config at that training step.
-- `model.pt`, `optim.pt`, `train.pt`: model, optimizer and training state at that training step.
-
-## Inference
-
-You can utilize our Hugging Face integration to run inference on the olmo checkpoints:
-
-```python
-from hf_olmo import * # registers the Auto* classes
+## Training
 
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-olmo = AutoModelForCausalLM.from_pretrained("allenai/OLMo-7B")
-tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-7B")
-
-message = ["Language modeling is "]
-inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
-response = olmo.generate(**inputs, max_new_tokens=100, do_sample=True, top_k=50, top_p=0.95)
-print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])
+1. Extract the Dolma corpus batches starting from step 360000:
+```bash
+python analysis/extract_data.py \
+    --input_path <PATH_TO_DOLMA_CORPUS> \
+    --output_path <OUTPUT_DIR> \
+    --start_step 360000
 ```
+- Note that the path to the full Dolma corpus should be specified in `configs/official/OLMo-7B.yaml` before running this code.
+- Running this will generate the `dolma_extracted/360000-363000.npy` file, which contains the tokenized batch sequences for training steps 360000-363000.
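+As a sanity check, you can inspect the extracted token batches before training. A minimal sketch (the array layout follows what `analysis/extract_data.py` passes to `np.save`; depending on chunking, the extractor may instead write `part-00000.npy`-style files):
+
+```python
+import numpy as np
+
+# Token IDs for the extracted batches, one row per 2048-token sequence.
+chunk = np.load("dolma_extracted/360000-363000.npy")
+print(chunk.shape, chunk.dtype)
+```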
-Alternatively, with the Hugging Face pipeline abstraction:
+2. Example training script:
 
-```python
-from transformers import pipeline
-olmo_pipe = pipeline("text-generation", model="allenai/OLMo-7B")
-print(olmo_pipe("Language modeling is"))
+```bash
+BASE_STEP=5000
+RUN_NAME=EXP_NAME
+
+export SCRATCH_DIR='PATH_TO_THE_REPOSITORY'
+
+python -m torch.distributed.run --nproc_per_node=4 ${SCRATCH_DIR}/scripts/train.py configs/official/OLMo-1B-105.yaml \
+    --run_name=${RUN_NAME} \
+    --data.paths=[${SCRATCH_DIR}/dolma_extracted/360000-363000.npy] \
+    --load_path=PATH_TO_THE_CHECKPOINT/step${BASE_STEP}-unsharded \
+    --base_step=${BASE_STEP} \
+    --inject_indices_map=${SCRATCH_DIR}/analysis/inject_indices_map/7b-360000.pkl \
+    --save_overwrite
 ```
 
-### Inference on finetuned checkpoints
+## Analysis
 
-If you finetune the model using the code above, you can use the conversion script to convert a native OLMo checkpoint to a Hugging Face-compatible checkpoint
+### 1. Post-process Perplexity Logs
 
 ```bash
-python hf_olmo/convert_olmo_to_hf.py --checkpoint-dir /path/to/checkpoint
-```
-
-### Quantization
-
-```python
-olmo = AutoModelForCausalLM.from_pretrained("allenai/OLMo-7B", torch_dtype=torch.float16, load_in_8bit=True) # requires bitsandbytes
+python analysis/postprocess_ppl.py \
+    --base_dir PATH_TO_THE_REPOSITORY \
+    --exp_name EXP_NAME
 ```
 
-The quantized model is more sensitive to typing / cuda, so it is recommended to pass the inputs as inputs.input_ids.to('cuda') to avoid potential issues.
-
-## Reproducibility
-
-### Training
-
-The configs used to train the official OLMo models are provided in the [`configs/official/`](https://github.com/allenai/OLMo/blob/main/configs/official) directory.
-
-Note that while the training and validation data is public and free to download, the paths to the data within those configs are pointed at a CloudFlare R2 bucket, which requires an API key for programmatic access.
-So in order to use any of these configs to reproduce a training run you'll first have to download the corresponding data to a location of your choosing and then update the paths in the config accordingly.
-
-You can derive the public HTTP URL from an R2 URL by replacing `r2://olmo-data` with `https://olmo-data.org`.
-For example, if the R2 data URL is:
-
-`r2://olmo-data/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-000-00000.npy`
+### 2. Run Analysis
 
-then the corresponding public URL is:
-
-`https://olmo-data.org/preprocessed/olmo-mix/v1_5/gpt-neox-20b-pii-special/part-000-00000.npy`
-
-Once you've updated the data paths in the config you can launch a training run via `torchrun`. For example, to launch the 1B model training on a single 8x GPU node, you would run:
+#### a. Draw Loss Figures (reproduce Fig. 2 in the paper)
 
 ```bash
-torchrun --nproc_per_node=8 scripts/train.py configs/official/OLMo-1B.yaml
+exp_name=EXP_NAME
+save_dir=PATH_TO_SAVED_FIGURES
+base_dir=PATH_TO_THE_REPOSITORY
+
+python analysis/ppl_analysis.py \
+    --mode=draw_figures \
+    --base_dir=${base_dir} \
+    --exp_name=${exp_name} \
+    --save_dir=${save_dir} \
+    --no_take_exp
 ```
 
-You can use the same method to launch multi-node jobs as well. See [the documentation](https://pytorch.org/docs/stable/elastic/run.html) for `torchrun` to understand the additional arguments you'll need to configure the rendezvous backend / endpoint.
-
-To resume training from a checkpoint, you can pass its path (local or URL)
-to `scripts/train.py` with the `--load_path` arguments. For example, to resume training from step 1000 of the OLMo 1B run:
+#### b. Measure Effectivity Scores
 
 ```bash
-torchrun --nproc_per_node=8 scripts/train.py configs/official/OLMo-1B.yaml --load_path https://olmo-checkpoints.org/ai2-llm/olmo-small/w1r5xfzt/step1000-unsharded
+exp_name=EXP_NAME
+base_dir=PATH_TO_THE_REPOSITORY
+
+python analysis/ppl_analysis.py \
+    --mode=measure_scores \
+    --base_dir=${base_dir} \
+    --skip_log_forgetting \
+    --absolute \
+    --no_take_exp \
+    --exp_name=${exp_name}
 ```
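+The scores printed by this command follow the logic of `get_probe_measurements` in `analysis/ppl_analysis.py`. Schematically, for one probe's perplexity series (a sketch for intuition only; the `margin` default below is illustrative, not the script's):
+
+```python
+# `ppl` is a list of probe perplexities, one per training step, with the
+# knowledge injected at step 0.
+def probe_scores(ppl, margin=10, interval=50):
+    init_ppl = ppl[0]                        # perplexity before injection
+    window = ppl[1:margin + 1]               # search window after injection
+    sp = min(range(len(window)), key=window.__getitem__) + 1
+    min_ppl = ppl[sp]                        # lowest perplexity reached
+    last_ppl = ppl[sp + interval]            # perplexity `interval` steps later
+    effectivity = init_ppl - min_ppl         # the --absolute variant
+    # Undefined when nothing was learned; the script skips that case.
+    retainability = (last_ppl - init_ppl) / (min_ppl - init_ppl)
+    return effectivity, retainability
+```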
-### Inspecting training data
-
-You may be interested in inspecting the exact tokens that composed a particular batch during the training of one of the OLMo models.
-We provide tools to do this, but first you'll need to download the data as above (unless you have an R2 API key) and update the corresponding config accordingly.
-
-Then take note of the URL of the data order file you want, which can be found in the [Models Overview](#models-overview) table. For example, the data order file for the first epoch of the OLMo-7B model is [https://olmo-checkpoints.org/ai2-llm/olmo-medium/wvc30anm/train_data/global_indices.npy](https://olmo-checkpoints.org/ai2-llm/olmo-small/46zc5fly/train_data/global_indices.npy).
-
-Once you have that you can use this snippet to inspect the data within a particular batch:
-
-```python
-import numpy as np
-from cached_path import cached_path
-
-from olmo.config import TrainConfig
-from olmo.data import build_memmap_dataset
-
-# Update these paths to what you want:
-data_order_file_path = cached_path("https://olmo-checkpoints.org/ai2-llm/olmo-medium/wvc30anm/train_data/global_indices.npy")
-train_config_path = "configs/official/OLMo-7B.yaml"
-
-
-cfg = TrainConfig.load(train_config_path)
-dataset = build_memmap_dataset(cfg, cfg.data)
-batch_size = cfg.global_train_batch_size
-global_indices = np.memmap(data_order_file_path, mode="r+", dtype=np.uint32)
-
-
-def get_batch_instances(batch_idx: int) -> list[list[int]]:
-    batch_start = batch_idx * batch_size
-    batch_end = (batch_idx + 1) * batch_size
-    batch_indices = global_indices[batch_start:batch_end]
-    batch_instances = []
-    for index in batch_indices:
-        token_ids = dataset[index]["input_ids"].tolist()
-        batch_instances.append(token_ids)
-    return batch_instances
-
+#### c. Measure Retainability Scores
 
-# Get all 2048 x 2048 token IDs in the first batch.
-get_batch_instances(0)
+```bash
+exp_name=EXP_NAME
+base_dir=PATH_TO_THE_REPOSITORY
+
+python analysis/ppl_analysis.py \
+    --mode=measure_scores \
+    --base_dir=${base_dir} \
+    --skip_log_effectivity \
+    --absolute \
+    --no_take_exp \
+    --exp_name=${exp_name}
 ```
+
+After running the retainability measurement, you'll find the log file in the `analysis/forgetting_measurements/` folder under the same `exp_name`.
 
-## Fine-tuning
-
-To fine-tune an OLMo model using our trainer you'll first need to prepare your dataset by tokenizing it and saving the token IDs to a flat numpy memory-mapped array. See [`scripts/prepare_tulu_data.py`](./scripts/prepare_tulu_data.py) for an example with the Tulu V2 dataset, which can be easily modified for other datasets.
-
-Next, prepare your training config. There are many examples in the [`configs/`](https://github.com/allenai/OLMo/blob/main/configs) directory that you can use as a starting point. The most important thing is to make sure the model parameters (the `model` field in the config) match up with the checkpoint you're starting from. To be safe you can always start from the config that comes with the model checkpoint. At a minimum you'll need to make the following changes to the config or provide the corresponding overrides from the command line:
-
-- Update `load_path` to point to the checkpoint you want to start from.
-- Set `reset_trainer_state` to `true`.
-- Update `data.paths` to point to the `token_ids.npy` file you generated.
-- Optionally update `data.label_mask_paths` to point to the `label_mask.npy` file you generated, unless you don't need special masking for the loss.
-- Update `evaluators` to add/remove in-loop evaluations.
+#### d. Draw Retainability Figures & Measure Decay Constants
 
-Once you're satisfied with your training config, you can launch the training job via `torchrun`. For example:
-
-```
-torchrun --nproc_per_node=8 scripts/train.py {path_to_train_config} \
-    --data.paths=[{path_to_data}/input_ids.npy] \
-    --data.label_mask_paths=[{path_to_data}/label_mask.npy] \
-    --load_path={path_to_checkpoint} \
-    --reset_trainer_state
+```bash
+python analysis/forgetting_plot.py \
+    --exp_name PATH_TO_FORGETTING_MEASUREMENT_FILE
 ```
 
-Note: passing CLI overrides like `--reset_trainer_state` is only necessary if you didn't update those fields in your config.
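+The measurement file consumed by `forgetting_plot.py` is a JSON log; a minimal sketch of how it is read (keys follow that script's `main`; the exact filename depends on your `exp_name` and is assumed here):
+
+```python
+import json
+
+# Inspect a retainability log produced by step (c).
+with open("analysis/forgetting_measurements/EXP_NAME.json") as f:
+    data = json.load(f)
+
+# Conditions such as "duplication", "paraphrase", "once", each holding
+# per-step score lists for the three probe types.
+for condition, scores in data.items():
+    print(condition, len(scores["mem"]), len(scores["gen"]), len(scores["gen_hard"]))
+```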
+## Citation
 
-## Evaluation
-
-Additional tools for evaluating OLMo models are available at the [OLMo Eval](https://github.com/allenai/ai2-olmo-eval) repo.
-
-## Citing
+If you use this code in your research, please cite both this repository and the original OLMo paper:
 
 ```bibtex
+@article{chang2024large,
+  title={How Do Large Language Models Acquire Factual Knowledge During Pretraining?},
+  author={Chang, Hoyeon and Park, Jinho and Ye, Seonghyeon and Yang, Sohee and Seo, Youngkyung and Chang, Du-Seong and Seo, Minjoon},
+  journal={arXiv preprint arXiv:2406.11813},
+  year={2024}
+}
+
 @article{OLMo,
   title={OLMo: Accelerating the Science of Language Models},
   author={Dirk Groeneveld and Iz Beltagy and Pete Walsh and Akshita Bhagia and Rodney Kinney and Oyvind Tafjord and A. Jha and Hamish Ivison and Ian Magnusson and Yizhong Wang and Shane Arora and David Atkinson and Russell Authur and Khyathi Raghavi Chandu and Arman Cohan and Jennifer Dumas and Yanai Elazar and Yuling Gu and Jack Hessel and Tushar Khot and William Merrill and Jacob Daniel Morrison and Niklas Muennighoff and Aakanksha Naik and Crystal Nam and Matthew E. Peters and Valentina Pyatkin and Abhilasha Ravichander and Dustin Schwenk and Saurabh Shah and Will Smith and Emma Strubell and Nishant Subramani and Mitchell Wortsman and Pradeep Dasigi and Nathan Lambert and Kyle Richardson and Luke Zettlemoyer and Jesse Dodge and Kyle Lo and Luca Soldaini and Noah A. Smith and Hanna Hajishirzi},
@@ -215,3 +138,7 @@ Additional tools for evaluating OLMo models are available at the [OLMo Eval](htt
   journal={arXiv preprint},
 }
 ```
+
+## License
+
+Apache 2.0
\ No newline at end of file
diff --git a/analysis/extract_data.py b/analysis/extract_data.py
index c276490..acc3dc5 100644
--- a/analysis/extract_data.py
+++ b/analysis/extract_data.py
@@ -37,15 +37,15 @@ def split_array(data, chunk_size):
 
 def save_chunks(data, chunk_size, directory='dolma_extracted'):
-    # if not os.path.exists(directory):
-    #     os.makedirs(directory)
+    if not os.path.exists(directory):
+        os.makedirs(directory)
     for i, chunk in enumerate(split_array(data, chunk_size)):
         filename = f"{directory}/part-{i:05d}.npy"
         np.save(filename, chunk)
         print(f"Saved {filename}")
 
-batch_indices = range(360000,361024)
+batch_indices = range(360000,363000)
 extracted_dataset = []
 print(batch_indices)
@@ -53,4 +53,4 @@ def save_chunks(data, chunk_size, directory='dolma_extracted'):
     extracted_dataset.extend(get_batch_instances(idx))
 
 print(f"len extracted data: {len(extracted_dataset)}")
-save_chunks(extracted_dataset, 1024)
+save_chunks(extracted_dataset, 3000)
diff --git a/analysis/forgetting_plot.py b/analysis/forgetting_plot.py
new file mode 100644
index 0000000..ebcdc81
--- /dev/null
+++ b/analysis/forgetting_plot.py
@@ -0,0 +1,222 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import argparse
+import json
+import os
+from scipy.optimize import curve_fit
+from sklearn.metrics import mean_squared_error
+import matplotlib.ticker as ticker
+from matplotlib.ticker import LogFormatter
+
+
+def exponential_decay(x, a, b, c):
+    return a * np.exp(-b * x) + c
+
+
+def logarithmic_decay(x, a, b):
+    return a * np.log(x) + b
+
+
+def fit_exp_linear(t, y, C=0):
+    # Fit y = A * exp(K * t) + C by fitting a line to log(y - C).
+    y = y - C
+    valid_indices = y > 0  # keep only points where the log is defined
+    y = y[valid_indices]
+    t = t[valid_indices]
+    y = np.log(y)
+    K, A_log = np.polyfit(t, y, 1)
+    A = np.exp(A_log)
+    return A, K
+
+
+def fit_models(x_data, y_data, fname):
+    # Fit exponential decay (in log space, via fit_exp_linear)
+    C0 = 0
+    A, K = fit_exp_linear(x_data, y_data, C0)
+    y_pred_exp = exponential_decay(x_data, A, -K, C0)
+    print(A, K)
+    rmse_exp = np.sqrt(mean_squared_error(y_data, y_pred_exp))
+
+    # Fit logarithmic decay (ensure no zero or negative x-values)
+    x_data_log = x_data[x_data > 0]
+    y_data_log = y_data[x_data > 0]
+    popt_log, pcov_log = curve_fit(logarithmic_decay, x_data_log, y_data_log, maxfev=10000)
+    y_pred_log = logarithmic_decay(x_data_log, *popt_log)
+    rmse_log = np.sqrt(mean_squared_error(y_data_log, y_pred_log))
+
+    # Print the results
+    print("Exponential Decay Fit: RMSE =", rmse_exp)
+    print("Logarithmic Decay Fit: RMSE =", rmse_log)
+
+    # Determine which model fits better
+    if rmse_exp < rmse_log:
+        print("Exponential decay provides a better fit.")
+    else:
+        print("Logarithmic decay provides a better fit.")
+
+    # Plotting
+    plt.figure(figsize=(10, 5))
+    plt.scatter(x_data, y_data, color='blue', alpha=0.05, label='Data Points')
+    plt.plot(x_data, y_pred_exp, 'r-', label=f'Exponential Fit: RMSE={rmse_exp:.2f}')
+    plt.plot(x_data_log, y_pred_log, 'g-', label=f'Logarithmic Fit: RMSE={rmse_log:.2f}')
+    plt.title('Fit of Exponential Decay and Logarithmic Decay Models')
+    plt.xlabel('t')
+    plt.ylabel('Y')
+    plt.legend()
+    plt.savefig(f'curve_fit/curve_fit_{fname}.png')
+
+
+def calculate_slope(steps, y_values):
+    # Slope of y against log(x); used as a decay-rate summary.
+    log_x = np.log(steps)
+    slope, _ = np.polyfit(log_x, y_values, 1)
+    return slope
+
+
+def plot_trendline(ax, x, y, label, color, args):
+    if args.tokens:
+        ext_x = [1000000, 10**12]
+    else:
+        ext_x = [1, 100, 1000, 10000, 10**6]  # Adjusted to include realistic x values
+
+    log_x = np.log10(x)
+    slope, intercept = np.polyfit(log_x, y, 1)
+    x_intercept = -intercept / slope
+
+    y_pred = np.polyval([slope, intercept], log_x)
+    # Calculate SSR (sum of squares of residuals)
+    ssr = np.sum((y - y_pred) ** 2)
+    # Calculate SST (total sum of squares)
+    sst = np.sum((y - np.mean(y)) ** 2)
+    # Calculate R^2
+    r_squared = 1 - (ssr / sst)
+    # Calculate Standard Error of the Slope
+    n = len(y)
+    se_slope = np.sqrt(ssr / (n - 2)) / np.sqrt(np.sum((log_x - np.mean(log_x)) ** 2))
+    # Calculate the range for the slope
+    slope_low = slope - se_slope
+    slope_high = slope + se_slope
+
+    trendline_y = np.log10(ext_x) * slope + intercept
+    ax.plot(ext_x, trendline_y, linestyle='dotted', color=color, linewidth=2)
+
+    print(f'Slope: {slope:.2f}, One Sigma: {se_slope:.4f}, R^2: {r_squared:.2f}')
+    print(f"x-intercept: {x_intercept:.2f}")
+    return slope, r_squared
+
+
+def plot(x_values, mem_data, gen_data, gen_hard_data, mode, reverse, args):
+    fig, ax = plt.subplots(figsize=(6, 5))
+    datasets = [mem_data, gen_data, gen_hard_data]
+    labels = ['Memorization', 'Semantic', 'Composition']
+    colors = ['blue', 'orange', 'red']
+
+    for index, y in enumerate(datasets):
+        if y[0] < 0:
+            y = [-i for i in y]
+
+        if reverse:
+            y = [(100 - i) / 100 for i in y]
+        else:
+            y = [i / 100 for i in y]
+
+        # Drop sentinel entries (-1.0 marks missing measurements).
+        filtered_x_values = []
+        filtered_y = []
+        for i in range(len(y)):
+            if y[i] != -1.0:
+                filtered_x_values.append(x_values[i])
+                filtered_y.append(y[i])
+
+        if args.tokens:
+            # Convert steps to tokens: batch size (in sequences) x 2048 tokens.
+            if '128' in args.exp_name:
+                print('bsize: 128')
+                filtered_x_values = [x * 128 * 2048 for x in filtered_x_values]
+            elif '512' in args.exp_name:
+                print('bsize: 512')
+                filtered_x_values = [x * 512 * 2048 for x in filtered_x_values]
+            else:
+                print('bsize: 2048')
+                filtered_x_values = [x * 2048 * 2048 for x in filtered_x_values]
+
+        slope, r_squared = plot_trendline(ax, np.array(filtered_x_values), np.array(filtered_y), labels[index], colors[index], args)
+        ax.plot(filtered_x_values, filtered_y, 'o', label=f'{labels[index]} (a={-slope:.2f})', color=colors[index], alpha=0.3, markersize=3)
+
+    ax.axhline(y=0, color='gray', linestyle='-', linewidth=0.5)
+    ax.set_ylim(bottom=0, top=1)
+    ax.set_xscale('log')
+    if args.tokens:
+        ax.set_xlabel('Tokens', fontsize=20)
+    else:
+        ax.set_xlabel(r'$t$', fontsize=20)
+    ax.set_ylabel(r'Avg. $\mathcal{R}(q,t)$', fontsize=20)
+    ax.set_title(mode.capitalize(), fontsize=20)
+    ax.legend(loc='upper right', prop={'size': 15}, markerscale=3)
+    ax.tick_params(axis='both', labelsize=14)
+
+    # `save_dir` is set at module level in the __main__ block below.
+    if args.tokens:
+        fname = f"{save_dir}/tokens/{args.exp_name.split('/')[-1][:-5]}_{mode}_tokens.pdf"
+    else:
+        if reverse:
+            fname = f"{save_dir}/{args.exp_name.split('/')[-1][:-5]}_{mode}_reversed.pdf"
+        else:
+            fname = f"{save_dir}/{args.exp_name.split('/')[-1][:-5]}_{mode}.pdf"
+    plt.tight_layout()
+    plt.savefig(fname)
+    print(f"fig saved to {fname}")
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--exp_name', type=str)
+    parser.add_argument('--tokens', action='store_true')
+    parser.add_argument('--reverse', action='store_true')
+    args = parser.parse_args()
+    base_dir = 'measure_data'
+    save_dir = 'learnability_plots/regularized' if 'regularized' in args.exp_name else 'learnability_plots/absolute'
+
+    with open(args.exp_name, 'r') as f:
+        data = json.load(f)
+
+    for k, v in data.items():
+        print(f'\n\n#####\t{k}\t#####\n')
+
+        mem_data, gen_data, gen_hard_data = v["mem"][1:], v["gen"][1:], v["gen_hard"][1:]
+        data_length = len(data["duplication"]["mem"])
+
+        steps = range(1, data_length)
+        plot(steps, mem_data, gen_data, gen_hard_data, mode=k, reverse=args.reverse, args=args)
+
+        print('\n\n\n')
\ No newline at end of file
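A standalone illustration (synthetic data) of the log-space exponential fit that `fit_exp_linear` implements above:

```python
import numpy as np

# y = A * exp(K * t)  =>  log(y) = K * t + log(A), a straight line in t.
t = np.linspace(1, 100, 200)
y = 2.0 * np.exp(-0.03 * t)

K, log_A = np.polyfit(t, np.log(y), 1)
print(K, np.exp(log_A))  # recovers K ≈ -0.03 and A ≈ 2.0
```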
diff --git a/analysis/inject_indices_map/7b-0-debug.pkl b/analysis/inject_indices_map/7b-0-debug.pkl
new file mode 100644
index 0000000..0198939
Binary files /dev/null and b/analysis/inject_indices_map/7b-0-debug.pkl differ
diff --git a/analysis/inject_indices_map/7b-360000-debug.pkl b/analysis/inject_indices_map/7b-360000-debug.pkl
new file mode 100644
index 0000000..57e268d
Binary files /dev/null and b/analysis/inject_indices_map/7b-360000-debug.pkl differ
diff --git a/analysis/ppl_analysis.py b/analysis/ppl_analysis.py
index 0c843ff..f279603 100644
--- a/analysis/ppl_analysis.py
+++ b/analysis/ppl_analysis.py
@@ -1,119 +1,38 @@
-import pickle as pkl
-import os
 import json
 import matplotlib.pyplot as plt
 import matplotlib.ticker as ticker
 from tqdm import tqdm
-from collections import defaultdict
-import argparse
 import statistics
 import numpy as np
-from scipy.stats import pearsonr
-from scipy import fftpack
-import fathon
-from fathon import fathonUtils as fu
-import powerlaw
-import pandas as pd
-import seaborn as sns
+import os
+import argparse
 
 NUM_REMOVED=0
 
 def update(num):
     global NUM_REMOVED
     NUM_REMOVED += num
 
-
-def filter_data(ppl_data, sim):
-    # Remove outliers based on IQR for ppl
-    Q1, Q3 = np.percentile(ppl_data, [25, 75])
-    IQR = Q3 - Q1
-    lower_bound = Q1 - 1.5 * IQR
-    upper_bound = Q3 + 1.5 * IQR
-    filtered_indices = [i for i, value in enumerate(ppl_data) if lower_bound <= value <= upper_bound]
-    sim_filtered = [sim[i] for i in filtered_indices]
-    ppl_filtered = [ppl_data[i] for i in filtered_indices]
-    return sim_filtered, ppl_filtered
-
-
-def round(num):
-    if num%10<5:
-        return num//10*10-1
-    else:
-        return num//10*10+10-1
-
-
-def mean_of_arrays(arrays):
-    """
-    Compute the mean of several 1D numpy arrays.
-
-    :param arrays: List of 1D numpy arrays, all of the same length.
-    :return: A 1D numpy array which is the mean of the input arrays.
-    """
-    stacked_arrays = np.stack(arrays)
-    mean_array = np.mean(stacked_arrays, axis=0)
-    return mean_array
-
-
-def levenshtein(s1, s2, debug=False):
-    if len(s1) < len(s2):
-        return levenshtein(s2, s1, debug)
-
-    if len(s2) == 0:
-        return len(s1)
-
-    previous_row = range(len(s2) + 1)
-    for i, c1 in enumerate(s1):
-        current_row = [i + 1]
-        for j, c2 in enumerate(s2):
-            insertions = previous_row[j + 1] + 1
-            deletions = current_row[j] + 1
-            substitutions = previous_row[j] + (c1 != c2)
-            current_row.append(min(insertions, deletions, substitutions))
-
-        if debug:
-            print(current_row[1:])
-
-        previous_row = current_row
-
-    return previous_row[-1]
-
-
 def remove_outliers_iqr(data, multiplier=2.0, log=False, is_retainability=False):
-    # print(data)
     q1 = np.percentile(data, 25)
     q3 = np.percentile(data, 75)
     iqr = q3 - q1
-
     lower_bound = q1 - multiplier * iqr
     upper_bound = q3 + multiplier * iqr
-
     filtered_data = [x for x in data if lower_bound <= x <= upper_bound]
     if log:
         print(f"{len(data)-len(filtered_data)} datapoints removed")
-
     if is_retainability:
         update(len(data)-len(filtered_data))
-    # print(f"{len(data)-len(filtered_data)}/{len(data)} datapoints removed")
-    # if (len(data)-len(filtered_data))/len(data)>0.1:
-    #     print("Warning: more than 10 percent of the data is removed as outliers")
     return filtered_data
 
-
 def load_json(path):
     with open(path) as f:
-        # data = [json.loads(l.strip()) for l in f]
         data = json.load(f)
     return data
 
-
 def mean(l):
     return sum(l)/len(l)
 
-
-def sort_idx(scores):
-    sorted_pairs = sorted(zip(scores, lst), key=lambda x: x[1], reverse=True)
-    return [index for index, value in sorted_pairs]
-
-
 def get_probe_measurements(ppls,
                            learnability_per_ex,
                            forgetting_per_ex,
@@ -130,34 +49,17 @@ def get_probe_measurements(ppls,
                            j=-1,
                            mode=None,
                            once=False):
-
-    # Find the stabilized point
-    last_train_idx=900 if ex_idx<80 else 0#Hard-coded
+    last_train_idx = 900 if ex_idx < 80 else 0
     for k, v in ppls.items():
-        if k=='def':
+        if k == 'def':
             continue
-        values=v[last_train_idx+1:last_train_idx+margin+1]
-        sp=min(range(len(values)), key=values.__getitem__)+last_train_idx+1
+        values = v[last_train_idx+1:last_train_idx+margin+1]
+        sp = min(range(len(values)), key=values.__getitem__) + last_train_idx + 1
         min_ppl = v[sp]
-        # init_ppl=ppls[train_idx[-1]-1]
-        init_ppl=v[0]
-        # print(interval)
-        last_ppl=v[sp+interval]
-        # if k=='target' and min_ppl > init_ppl:
-        #     print(ex_idx, '\t', k, '\t', j, '\t', min_ppl, '\t', init_ppl)
-        #     continue
-        # if k=='target' and sp==last_train_idx:
-        #     print('!!!!')
-        if not absolute:
-            # learnability_per_ex[k].append((1-min_ppl/init_ppl)*100)
-            pass
-        else:
-            # learnability_per_ex[k].append(init_ppl-min_ppl)
-            pass
+        init_ppl = v[0]
+        last_ppl = v[sp+interval]
+
         if not normalize:
             if not relative:
                 if not absolute:
@@ -169,46 +71,33 @@ def get_probe_measurements(ppls,
                     forgetting_per_ex[k].append((1-(last_ppl/init_ppl))*100)
                 else:
                     forgetting_per_ex[k].append(last_ppl-init_ppl)
-
         else:
-            if init_ppl==min_ppl:
+            if init_ppl == min_ppl:
                 continue
             forgetting_per_ex[k].append((last_ppl-init_ppl)/(min_ppl-init_ppl)*100)
 
         init_per_ex[k].append(init_ppl)
        last_per_ex[k].append(min_ppl)
 
-        if k=='target':
+        if k == 'target':
             step_learnability = []
             step_forgetting = []
-            num_injection = 10 if ex_idx<80 else 1
+            num_injection = 10 if ex_idx < 80 else 1
             for i in range(num_injection):
                 train_idx = i*100
-                values=v[train_idx+1:train_idx+margin+1]
-                sp=min(range(len(values)), key=values.__getitem__)+train_idx+1
+                values = v[train_idx+1:train_idx+margin+1]
+                sp = min(range(len(values)), key=values.__getitem__) + train_idx + 1
                 init_ppl = v[train_idx]
                 min_ppl = v[sp]
                 last_ppl = v[sp+50]
-                if min_ppl < init_ppl:
-                    step_learnability.append(init_ppl-min_ppl)
-                    learnability_per_ex[k].append(init_ppl-min_ppl)
-                else:
-                    step_learnability.append(init_ppl-min_ppl)
-                    learnability_per_ex[k].append(init_ppl-min_ppl)
-                    # step_learnability.append(None)
+                step_learnability.append(init_ppl-min_ppl)
+                learnability_per_ex[k].append(init_ppl-min_ppl)
+
                 if min_ppl < init_ppl:
                     retainability = (last_ppl-init_ppl)/(min_ppl-init_ppl)
-                    if retainability > 0:
-                        step_forgetting.append((last_ppl-init_ppl)/(min_ppl-init_ppl))
-                    else:
-                        # step_forgetting.append(0.0)
-                        step_forgetting.append((last_ppl-init_ppl)/(min_ppl-init_ppl))
-                # else:
+                    step_forgetting.append((last_ppl-init_ppl)/(min_ppl-init_ppl))
                 if min_ppl == init_ppl:
-                    # step_forgetting.append((last_ppl-init_ppl)/(min_ppl-init_ppl))
                     step_forgetting.append(None)
 
             learnability_step_per_ex.append(step_learnability)
@@ -220,7 +109,7 @@ def get_probe_measurements(ppls,
     return learnability_per_ex, forgetting_per_ex, init_per_ex, last_per_ex, learnability_step_per_ex, forgetting_step_per_ex
 
-def measure_scores(result, interval=50, skip_log_learnability=False, skip_log_forgetting=False, relative=False, absolute=False, log=False):
+def measure_scores(result, interval=50, skip_log_effectivity=False, skip_log_forgetting=False, relative=False, absolute=False, log=False):
     forgetting_score = {"duplication": {}, "paraphrase": {}, "once": {}}
     learnability_score = {"duplication": {}, "paraphrase": {}, "once": {}}
@@ -341,17 +230,8 @@
     step_learnability_score["once"]["mem"] = [a for a in mem_learnability_step_per_ex]
     step_learnability_score["once"]["gen"] = [a for a in gen_learnability_step_per_ex]
     step_learnability_score["once"]["gen_hard"] = [a for a in gen_hard_learnability_step_per_ex]
-
-    # store mean values
-    # mem_learnability.append({k: mean(v) for (k, v) in mem_learnability_per_ex.items()})
-    # gen_learnability.append({k: mean(v) for (k, v) in mem_learnability_per_ex.items()})
-    # gen_hard_learnability.append({k: mean(v) for (k, v) in mem_learnability_per_ex.items()})
-    # mem_forgetting.append({k: mean(v) for (k, v) in mem_learnability_per_ex.items()})
-    # gen_forgetting.append({k: mean(v) for (k, v) in mem_learnability_per_ex.items()})
-    # gen_hard_forgetting.append({k: mean(v) for (k, v) in mem_learnability_per_ex.items()})
-
-    # print(f"memorizability: mean {mean(memorizability)} / {statistics.pstdev(memorizability)}")
-    if not skip_log_learnability:
+
+    if not skip_log_effectivity:
         if ex_idx+1==40:
             print('==========\nParaphrased\n==========')
         elif ex_idx+1==80:
@@ -359,15 +239,11 @@
         else:
             print('==========\nOnce\n==========')
 
-        print(f"mem_learnability: {mean(mem_learnability_per_ex['target']):.2f}")
-        # print(f"mem_init: {mean(mem_init_per_ex['target']):.2f}\nmem_last: {mean(mem_last_per_ex['target']):.2f}")
-        # print(f"{statistics.pstdev(mem_learnability_per_ex['target']):.2f}")
+        print(f"mem_effectivity: {mean(mem_learnability_per_ex['target']):.2f}")
         print('-'*50)
-        print(f"gen_learnability: {mean(gen_learnability_per_ex['target']):.2f}")
-        # print(f"gen_init: {mean(gen_init_per_ex['target']):.2f}\ngen_last: {mean(gen_last_per_ex['target']):.2f}")
+        print(f"gen_effectivity: {mean(gen_learnability_per_ex['target']):.2f}")
         print('-'*50)
-        print(f"gen_hard_learnability: {mean(gen_hard_learnability_per_ex['target']):.2f}")
-        # print(f"gen_hard_init: {mean(gen_hard_init_per_ex['target']):.2f}\ngen_hard_last: {mean(gen_hard_last_per_ex['target']):.2f}")
+        print(f"gen_hard_effectivity: {mean(gen_hard_learnability_per_ex['target']):.2f}")
         print()
         print('='*50)
         print()
@@ -398,112 +274,34 @@
     return forgetting_score
 
-    # print(len(gen_learnability_all_per_ex))
-    # print(len(gen_learnability_easy_per_ex)+len(gen_learnability_hard_per_ex))
-
-    # # Filter out -1 and get the indices
-    # filtered_data_with_indices = [(value, index) for index, value in enumerate(gen_learnability_all_per_ex) if value != -1]
-
-    # # Sort the list by value
-    # sorted_data = sorted(filtered_data_with_indices)
-
-    # # Get the indices of the lowest and highest 10 values
-    # lowest_10_indices = [(index, value) for value, index in sorted_data[:10]]
-    # top_10_indices = [(index, value) for value, index in sorted_data[-10:]]
-
-    # # Store the information in a dictionary
-    # result = {
-    #     'top_10_indices': top_10_indices,
-    #     'lowest_10_indices': lowest_10_indices
-    # }
-
-    # with open('indices_info.json', 'w') as f:
-    #     json.dump(result, f, indent=4)
-
-    # Fit the data to a power-law distribution
-    # fit_powerlaw(pre_gen_fluc_per_ex, mode='pre_gen')
-    # fit_powerlaw(gen_fluc_per_ex, mode='gen')
-    # fit_powerlaw(pre_mem_fluc_per_ex, mode='pre-mem')
-    # fit_powerlaw(mem_fluc_per_ex, mode='mem')
-
-
-def plot_perplexity(rows, cols, plot_number, steps, x_mem, x_gen, xlabel, ylabel, scatter_data=None, local=False, x_hard_gen=None, avg=False, once=False, mode=None):
-    # steps = steps[:2000]
+def plot_perplexity(rows, cols, plot_number, steps, x_mem, x_gen, xlabel, ylabel, scatter_data=None, x_hard_gen=None, avg=False, once=False, mode=None):
     steps = range(0,2000)
     x_mem = x_mem[:2000]
     x_gen = x_gen[:2000]
     if x_hard_gen is not None:
         x_hard_gen = x_hard_gen[:2000]
     ax = plt.subplot(rows, cols, plot_number)
-    if local:
-        steps = range(-30,90)
-        ax.xaxis.set_major_locator(ticker.MultipleLocator(20))
-        xlabel = 'Training Steps (t)'
-        ax.set_xlabel(xlabel, fontsize=20)
-    else:
-        ax.xaxis.set_major_locator(ticker.MultipleLocator(500))
-        xlabel = xlabel.split('\n')
-        xlabel = "\n".join([""+x for x in xlabel])
-
-    if local:
-        # ymin, ymax = ax.get_ylim()
-        ymin, ymax = -0.2, 0.8
-        ax.set_ylim(ymin, ymax)
-        ax.plot(steps, [-x+x_mem[30] for x in x_mem], color='blue', label=r'$\ell(q)$')
-        ax.fill_between(steps, ymin, ymax, where=(np.array(steps) >= 0) & (np.array(steps) <= 50), color='lightgreen', alpha=0.5, label='Window')
-        ax.axvline(x=29, color='red', linestyle='-', linewidth=3)
-        ax.annotate(r'$t_{LAM}(q,i)$', xy=(30, ymin+0.02), xytext=(30, ymin + 0.12),
-                    color='red', arrowprops=dict(color='red', arrowstyle='simple'), fontsize=18)
-
-        # Add vertical line at x=-1
-        y_value_at_t29 = -x_mem[59] + x_mem[30] + 0.2
-        y_value_at_t60 = -x_mem[89] + x_mem[30] + 0.2
-        offset = 0.012
-        ax.axvline(x=-2, ymin=0.2, ymax=y_value_at_t29, color='dimgray', linestyle='-', linewidth=3)
-        ax.axhline(y=y_value_at_t29-0.2, xmin=32/120-offset, xmax=33/120-offset, color='dimgray', linestyle='-', linewidth=3)
-        ax.axhline(y=0, xmin=32/120-offset, xmax=33/120-offset, color='dimgray', linestyle='-', linewidth=3)
-        ax.text(-3, y_value_at_t29 / 2 - 0.1, r'$\mathcal{E}(q,i)$', color='dimgray', ha='right', va='center', fontsize=18)
-
-        offset = -0.455
-        ax.axvline(x=59, ymin=0.2, ymax=y_value_at_t60, color='navy', linestyle='-', linewidth=3)
-        ax.axhline(y=y_value_at_t60-0.2, xmin=32/120-offset, xmax=33/120-offset, color='navy', linestyle='-', linewidth=3)
-        ax.axhline(y=0, xmin=32/120-offset, xmax=33/120-offset, color='navy', linestyle='-', linewidth=3)
-        ax.text(71, y_value_at_t60 / 2 - 0.1, r'$\mathcal{R}(q,30)$', color='navy', ha='right', va='center', fontsize=18)
-
-        ax.axhline(y=y_value_at_t29-0.2, color='black', linestyle='-.', linewidth=1)
-        ax.axhline(y=0, color='black', linestyle='-.', linewidth=1)
-
-    else:
-        ax.plot(steps, [-x+x_mem[0] for x in x_mem], color='blue', label='Memorization')
-        ax.plot(steps, [-x+x_gen[0] for x in x_gen], color='orange', label='Semantic')
+    ax.xaxis.set_major_locator(ticker.MultipleLocator(500))
+    xlabel = xlabel.split('\n')
+    xlabel = "\n".join([""+x for x in xlabel])
+
+    ax.plot(steps, [-x+x_mem[0] for x in x_mem], color='blue', label='Memorization')
+    ax.plot(steps, [-x+x_gen[0] for x in x_gen], color='orange', label='Semantic')
+
     if avg:
         ax.plot(steps, [-x+x_hard_gen[0] for x in x_hard_gen], color='red', label='Composition')
 
     if scatter_data:
         x_vals, y_vals, colors, sizes = scatter_data
         ax.scatter(x_vals, y_vals, color=colors, s=sizes)
-
-    # Set major ticks formatter and locator
-    # ax.xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: f'{int(x/1000)}k'))  # Format tick labels as 'k' units
 
     # ymin, ymax = 0.0, 2.5
     ymin, ymax = -0.1, 1.8
     ax.set_ylim(ymin, ymax)
-    # ymin, ymax = ax.get_ylim()
-    x_positions = [0] if local or once else [steps[i] for i in [i*100 for i in range(10)]]
+    x_positions = [0] if once else [steps[i] for i in [i*100 for i in range(10)]]
     plt.vlines(x=x_positions, ymin=ymin, ymax=ymax, colors='black', linestyles='dotted', label='Injection', linewidth=3)
-    # ax.set_ylim(0, 500)
 
     if avg and mode=='once':
         ax.set_xlabel('Training Steps', fontsize=24)
-        # pass
-    if local:
-        ax.set_ylabel(r'$\Delta\ell(q)$', fontsize=24)
     else:
         ax.set_ylabel(r'Avg. $\Delta\ell(q)$', fontsize=24)
     ax.grid(True)
@@ -526,7 +324,6 @@ def plot_difference(rows, cols, plot_number, steps, x_mem, x_gen, xlabel, ylabel
     # Set major ticks formatter and locator
     ax.xaxis.set_major_locator(ticker.MultipleLocator(500))
-    # ax.xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, pos: f'{int(x/1000)}k'))  # Format tick labels as 'k' units
 
     ymin, ymax = ax.get_ylim()
     # ymax=500
@@ -539,129 +336,58 @@ def plot_difference(rows, cols, plot_number, steps, x_mem, x_gen, xlabel, ylabel
     ax.grid(True)
 
-def plot_ppl_with_trained_at(result, save_dir, min_step, draw_single, draw_avg):
+def plot_ppl_with_trained_at(result, save_dir, min_step):
     steps = [data["step"] for data in result]
     all_mem_ppls = []
     all_gen_ppls = []
     all_hard_gen_ppls = []
-    # gen_ppls = [instance["ppl_probe"] for instance in result]
 
     keys = ['mem_first', 'mem_target', 'mem_full', 'gen_first', 'gen_target', 'gen_full', 'gen_hard_first', 'gen_hard_target', 'gen_hard_full', 'def']
     ppl_data = {key: [instance[key] for instance in result] for key in keys}
-    # print(len(ppl_data['gen_hard_full']),len(ppl_data['gen_hard_full'][0]),len(ppl_data['gen_hard_full'][0][0]))
 
     for key in ppl_data:
         # if key != 'def':
         ppl_data[key] = list(map(list, zip(*ppl_data[key])))
-    # print(len(ppl_data['gen_hard_full'][0][0]))
-    # plt.figure(figsize=(16, 20))
 
     with open(os.path.join(args.base_dir, 'fictional_knowledge/fictional_knowledge_paraphrased.json'), 'r') as f:
        dataset = json.load(f)
+
+    par_ppl_mem_avg = np.mean(np.array([[mean(d) for d in ppl_data["mem_target"][ex_idx]] for ex_idx in range(0,40)]), axis=0)
+    par_ppl_gen_avg = np.mean(np.array([[mean(d) for d in ppl_data["gen_target"][ex_idx]] for ex_idx in range(0,40)]), axis=0)
+    par_ppl_hard_gen_avg = np.mean(np.array([[mean(d) for d in ppl_data["gen_hard_target"][ex_idx]] for ex_idx in range(0,40)]), axis=0)
 
-    # assert len(dataset)==len(all_gen_ppls[0])
-    if draw_avg:
-        par_ppl_mem_avg = np.mean(np.array([[mean(d) for d in ppl_data["mem_target"][ex_idx]] for ex_idx in range(0,40)]), axis=0)
-        par_ppl_gen_avg = np.mean(np.array([[mean(d) for d in ppl_data["gen_target"][ex_idx]] for ex_idx in range(0,40)]), axis=0)
-        par_ppl_hard_gen_avg = np.mean(np.array([[mean(d) for d in ppl_data["gen_hard_target"][ex_idx]] for ex_idx in range(0,40)]), axis=0)
-
-        dup_ppl_mem_avg = np.mean(np.array([[mean(d) for d in ppl_data["mem_target"][ex_idx]] for ex_idx in range(40,80)]), axis=0)
-        dup_ppl_gen_avg = np.mean(np.array([[mean(d) for d in ppl_data["gen_target"][ex_idx]] for ex_idx in range(40,80)]), axis=0)
-        dup_ppl_hard_gen_avg = np.mean(np.array([[mean(d) for d in ppl_data["gen_hard_target"][ex_idx]] for ex_idx in range(40,80)]), axis=0)
-
-        once_ppl_mem_avg = np.mean(np.array([[mean(d) for d in ppl_data["mem_target"][ex_idx]] for ex_idx in range(80,120)]), axis=0)
-        once_ppl_gen_avg = np.mean(np.array([[mean(d) for d in ppl_data["gen_target"][ex_idx]] for ex_idx in range(80,120)]), axis=0)
-        once_ppl_hard_gen_avg = np.mean(np.array([[mean(d) for d in ppl_data["gen_hard_target"][ex_idx]] for ex_idx in range(80,120)]), axis=0)
-
-        plt.figure(figsize=(30, 4))
-        plot_perplexity(1, 1, 1, steps, par_ppl_mem_avg, par_ppl_gen_avg, '', 'Log Probability', x_hard_gen=par_ppl_hard_gen_avg, avg=True, mode='par')
-        plt.savefig(os.path.join(save_dir, args.exp_name[:-5]+'_par.pdf'), bbox_inches='tight')
-        plt.close()
-
-        plt.figure(figsize=(30, 4))
-        plot_perplexity(1, 1, 1, steps, dup_ppl_mem_avg, dup_ppl_gen_avg, '', 'Log Probability', x_hard_gen=dup_ppl_hard_gen_avg, avg=True, mode='dup')
-        plt.savefig(os.path.join(save_dir, args.exp_name[:-5]+'_dup.pdf'), bbox_inches='tight')
-        plt.close()
-
-        plt.figure(figsize=(30, 4))
-        plot_perplexity(1, 1, 1, steps, once_ppl_mem_avg, once_ppl_gen_avg, '', 'Log Probability', x_hard_gen=once_ppl_hard_gen_avg, avg=True, once=True, mode='once')
-        plt.savefig(os.path.join(save_dir, args.exp_name[:-5]+'_once.pdf'), bbox_inches='tight')
-        plt.close()
+    dup_ppl_mem_avg = np.mean(np.array([[mean(d) for d in ppl_data["mem_target"][ex_idx]] for ex_idx in range(40,80)]), axis=0)
+    dup_ppl_gen_avg = np.mean(np.array([[mean(d) for d in ppl_data["gen_target"][ex_idx]] for ex_idx in range(40,80)]), axis=0)
+    dup_ppl_hard_gen_avg = np.mean(np.array([[mean(d) for d in ppl_data["gen_hard_target"][ex_idx]] for ex_idx in range(40,80)]), axis=0)
 
-        plt.figure(figsize=(30, 4))
-        plot_difference(1, 1, 1, steps, once_ppl_mem_avg, once_ppl_gen_avg, 'Step', 'Log Probability')
-        plt.savefig(os.path.join(save_dir, args.exp_name[:-5]+'_diff.pdf'), bbox_inches='tight')
-        plt.close()
+    once_ppl_mem_avg = np.mean(np.array([[mean(d) for d in ppl_data["mem_target"][ex_idx]] for ex_idx in range(80,120)]), axis=0)
+    once_ppl_gen_avg = np.mean(np.array([[mean(d) for d in ppl_data["gen_target"][ex_idx]] for ex_idx in range(80,120)]), axis=0)
+    once_ppl_hard_gen_avg = np.mean(np.array([[mean(d) for d in ppl_data["gen_hard_target"][ex_idx]] for ex_idx in range(80,120)]), axis=0)
 
+    plt.figure(figsize=(30, 4))
+    plot_perplexity(1, 1, 1, steps, par_ppl_mem_avg, par_ppl_gen_avg, '', 'Log Probability', x_hard_gen=par_ppl_hard_gen_avg, avg=True, mode='par')
+    plt.savefig(os.path.join(save_dir, args.exp_name[:-5]+'_par.pdf'), bbox_inches='tight')
+    plt.close()
+
+    plt.figure(figsize=(30, 4))
+    plot_perplexity(1, 1, 1, steps, dup_ppl_mem_avg, dup_ppl_gen_avg, '', 'Log Probability', x_hard_gen=dup_ppl_hard_gen_avg, avg=True, mode='dup')
+    plt.savefig(os.path.join(save_dir, args.exp_name[:-5]+'_dup.pdf'), bbox_inches='tight')
+    plt.close()
+
+    plt.figure(figsize=(30, 4))
+    plot_perplexity(1, 1, 1, steps, once_ppl_mem_avg, once_ppl_gen_avg, '', 'Log Probability', x_hard_gen=once_ppl_hard_gen_avg, avg=True, once=True, mode='once')
+    plt.savefig(os.path.join(save_dir, args.exp_name[:-5]+'_once.pdf'), bbox_inches='tight')
+    plt.close()
+
+    plt.figure(figsize=(30, 4))
+    plot_difference(1, 1, 1, steps, once_ppl_mem_avg, once_ppl_gen_avg, 'Step', 'Log Probability')
+    plt.savefig(os.path.join(save_dir, args.exp_name[:-5]+'_diff.pdf'), bbox_inches='tight')
+    plt.close()
 
-    else:
-        for ex_idx in tqdm(range(len(ppl_data['def']))):
-            num_probes = 5  # Assuming all results have the same structure
-
-            probes = list(zip([dataset[ex_idx]["mem_input"][i] + " " + "\""+dataset[ex_idx]["mem_target"][i]+"\"" for i in range(5)], [dataset[ex_idx]["gen_input"][i] + " " + "\""+dataset[ex_idx]["gen_target"][i]+"\"" for i in range(5)], [dataset[ex_idx]["hard_gen_input"][i] + " " + "\""+dataset[ex_idx]["hard_gen_target"][i]+"\"" for i in range(5)]))
-            texts = [f"Mem probe: {m}\nGen_probe: {g}" for (m, g, h) in probes]
-            hard_texts = [f"Hard_Gen_probe: {h}" for (m, g, h) in probes]
-            if draw_single:
-                if ex_idx!=6:
-                    continue
-                plt.figure(figsize=(15, 4))
-                ppl_mem_target = [d[0] for d in ppl_data["mem_target"][ex_idx]]
-                ppl_gen_target = [d[0] for d in ppl_data["gen_target"][ex_idx]]
-                plot_perplexity(1, 1, 1, steps[170:290], ppl_mem_target[170:290], ppl_gen_target[170:290], '', 'Log Probability', local=True)
-            else:
-                plt.figure(figsize=(60, 30))
-                # plt.figure(figsize=(16, 30))
-
-                for j in range(num_probes):
-                    # Generate subplot indices for the current subplot row
-                    first_subplot_idx = 2 + j * 6
-                    target_subplot_idx = 1 + j * 6
-                    full_subplot_idx = 3 + j * 6
-                    hard_gen_subplot_idx = 4 + j * 6
-
-                    # Get data for current probe across all examples
-                    ppl_mem_first = [d[j] for d in ppl_data["mem_first"][ex_idx]]
-                    ppl_mem_target = [d[j] for d in ppl_data["mem_target"][ex_idx]]
-                    ppl_mem_full = [d[j] for d in ppl_data["mem_full"][ex_idx]]
-
-                    ppl_gen_first = [d[j] for d in ppl_data["gen_first"][ex_idx]]
-                    ppl_gen_target = [d[j] for d in ppl_data["gen_target"][ex_idx]]
-                    ppl_gen_full = [d[j] for d in ppl_data["gen_full"][ex_idx]]
-
-                    ppl_hard_gen_first = [d[j] for d in ppl_data["gen_hard_first"][ex_idx]]
-                    ppl_hard_gen_target = [d[j] for d in ppl_data["gen_hard_target"][ex_idx]]
-                    ppl_hard_gen_full = [d[j] for d in ppl_data["gen_hard_full"][ex_idx]]
-
-                    ppl_def = [ppl_data["def"][ex_idx]]
-
-                    # Plot perplexity and differences
-                    plot_perplexity(num_probes+1, 6, target_subplot_idx, steps, ppl_mem_target, ppl_gen_target, texts[j], 'Perplexity (Target)')
-                    plot_perplexity(num_probes+1, 6, first_subplot_idx, steps, ppl_mem_first, ppl_gen_first, '', 'Perplexity (First)')
-                    plot_perplexity(num_probes+1, 6, full_subplot_idx, steps, ppl_mem_full, ppl_gen_full, '', 'Perplexity (Full)')
-                    # plot_difference(num_probes+1, 4, diff_subplot_idx, steps, x_mem, x_gen, '', 'Difference Perplexity')
-                    plot_difference(num_probes+1, 6, hard_gen_subplot_idx, steps, ppl_hard_gen_target, None, hard_texts[j], 'Hard-Gen Perplexity (Target)')
-                    plot_difference(num_probes+1, 6, hard_gen_subplot_idx+1, steps, ppl_hard_gen_first, None, '', 'Hard-Gen Perplexity (First)')
-                    plot_difference(num_probes+1, 6, hard_gen_subplot_idx+2, steps, ppl_hard_gen_full, None, '', 'Hard-Gen Perplexity (Full)')
-
-                def_idx = 1+5*6
-                plot_difference(num_probes+1, 6, def_idx, steps, ppl_def[0], None, '', 'Def Perplexity')
-
-                plt.tight_layout()  # Adjust layout to make room for all plots
-
-                # Annotate each row with descriptive text
-                # for i, text in enumerate(texts):
-                #     plt.figtext(0.9, 0.15 * (len(texts) - i), text, fontsize=12)
-
-                # Save the figure to a file
-                plt.savefig(os.path.join(save_dir, args.exp_name[:-5], str(ex_idx)+'.pdf'), bbox_inches='tight')
-                plt.close()
 
 def preprocess_result(result):
     new_result = []
     for res in result:
         instance = {k: v for (k,v) in res.items()}
-        # ['step', 'mem_first', 'mem_target', 'mem_full', 'gen_first', 'gen_target', 'gen_full', 'gen_hard_first', 'gen_hard_target', 'gen_hard_full', 'def']
         for k in instance.keys():
             if k=='step':
                 continue
@@ -688,14 +414,14 @@ def main(args):
     if args.mode=='draw_figures':
         os.makedirs(args.save_dir, exist_ok=True)
         plot_indices = range(156,196)
-        plot_ppl_with_trained_at(result, save_dir=args.save_dir, min_step=min_step, draw_single=args.draw_single, draw_avg=args.draw_avg)
+        plot_ppl_with_trained_at(result, save_dir=args.save_dir, min_step=min_step)
 
     elif args.mode=='measure_scores':
         if args.skip_log_forgetting:
             measure_scores(result,
                            interval=50,
-                           skip_log_learnability=args.skip_log_learnability,
+                           skip_log_effectivity=args.skip_log_effectivity,
                            skip_log_forgetting=args.skip_log_forgetting,
                            relative=args.relative,
                            absolute=args.absolute)
@@ -710,7 +436,7 @@ def main(args):
         for i in tqdm(range(interval)):
             single_result = measure_scores(result,
                                            interval=i,
-                                           skip_log_learnability=args.skip_log_learnability,
+                                           skip_log_effectivity=args.skip_log_effectivity,
                                            skip_log_forgetting=args.skip_log_forgetting,
                                            relative=args.relative,
                                            absolute=args.absolute)
@@ -744,20 +470,14 @@ def main(args):
 
     parser = argparse.ArgumentParser()
-    # exp_name = "ft_medium_8e-6"
-    # data_file = "./data/ecbd/all_ent_2020_2021_np_easy.json"
-
-    # Add arguments
     parser.add_argument('--base_dir', type=str, default='/home/hoyeon/OLMo')
     parser.add_argument('--save_dir', type=str, default="figs")
     parser.add_argument('--exp_name', type=str, required=True)
     parser.add_argument('--mode', type=str, default="draw_figures")
     parser.add_argument('--no_take_exp', action='store_true')
-    parser.add_argument('--skip_log_learnability', action='store_true')
+    parser.add_argument('--skip_log_effectivity', action='store_true')
    parser.add_argument('--skip_log_forgetting', action='store_true')
-    parser.add_argument('--draw_single', action='store_true')
-    parser.add_argument('--draw_avg', action='store_true')
     parser.add_argument('--relative', action='store_true')
     parser.add_argument('--absolute', action='store_true')
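For intuition, here is the IQR filter that `remove_outliers_iqr` implements above, applied to a toy series (standalone sketch using the function's default multiplier of 2.0):

```python
import numpy as np

# Keep values within [Q1 - m*IQR, Q3 + m*IQR].
data = [1.0, 1.2, 0.9, 1.1, 25.0]   # 25.0 is an outlier
q1, q3 = np.percentile(data, [25, 75])
iqr = q3 - q1
kept = [x for x in data if q1 - 2.0 * iqr <= x <= q3 + 2.0 * iqr]
print(kept)  # the 25.0 reading is dropped
```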
diff --git a/olmo/train.py b/olmo/train.py
index 58b4de8..b506795 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -385,8 +385,7 @@ def extract_step_number(path):
             log.info("Resetting learning rate...")
             new_learning_rate = self.scheduler.get_lr(
                 self.cfg.optimizer.learning_rate, self.scheduler_current, self.scheduler_max
-            ) #/(2048/self.cfg.global_train_batch_size)
-            # new_learning_rate = 3.0e-4/16 # Hard-coded (temporary)
+            )
             log.info(f"new_learning_rate: {new_learning_rate}")
             log.info(f"scheduler_current: {self.scheduler_current}")
             for group in self.optim.param_groups: