From 058d1a83b81b0e1897d4cd0551a9cf3396511c64 Mon Sep 17 00:00:00 2001 From: Curt Tigges Date: Thu, 12 Oct 2023 17:14:13 -0700 Subject: [PATCH] Continued rebase; Misc changes for comma experiments --- comma_ablation_experiments copy.ipynb | 5789 +++++++++++++++++++++++++ projection_neuroscope.py | 82 +- 2 files changed, 5835 insertions(+), 36 deletions(-) create mode 100644 comma_ablation_experiments copy.ipynb diff --git a/comma_ablation_experiments copy.ipynb b/comma_ablation_experiments copy.ipynb new file mode 100644 index 0000000..767a7a4 --- /dev/null +++ b/comma_ablation_experiments copy.ipynb @@ -0,0 +1,5789 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "README.md\t\t miniconda.sh quick_start_pytorch.ipynb wandb\n", + "eliciting-latent-sentiment miniconda3\t quick_start_pytorch_images\n" + ] + } + ], + "source": [ + "!ls" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/notebooks/eliciting-latent-sentiment\n" + ] + } + ], + "source": [ + "%cd eliciting-latent-sentiment" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: fancy_einsum==0.0.3 in /usr/local/lib/python3.9/dist-packages (0.0.3)\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: transformer_lens in /usr/local/lib/python3.9/dist-packages (0.0.0)\n", + "Requirement already satisfied: datasets>=2.7.1 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (2.14.5)\n", + "Requirement already satisfied: jaxtyping>=0.2.11 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (0.2.13)\n", + "Requirement already satisfied: pandas>=1.1.5 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (1.5.0)\n", + "Requirement already satisfied: tqdm>=4.64.1 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (4.64.1)\n", + "Requirement already satisfied: fancy-einsum>=0.0.3 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (0.0.3)\n", + "Requirement already satisfied: beartype<0.15.0,>=0.14.1 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (0.14.1)\n", + "Requirement already satisfied: rich>=12.6.0 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (13.2.0)\n", + "Requirement already satisfied: transformers>=4.25.1 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (4.33.2)\n", + "Requirement already satisfied: torch>=1.10 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (1.12.1+cu116)\n", + "Requirement already satisfied: wandb>=0.13.5 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (0.15.10)\n", + "Requirement already satisfied: einops>=0.6.0 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (0.6.1)\n", + "Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (1.23.4)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (23.0)\n", + 
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (5.4.1)\n", + "Requirement already satisfied: multiprocess in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (0.70.13)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (3.2.0)\n", + "Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (0.17.2)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (3.8.3)\n", + "Requirement already satisfied: fsspec[http]<2023.9.0,>=2023.1.0 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (2023.1.0)\n", + "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (10.0.1)\n", + "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (2.28.2)\n", + "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (0.3.5.1)\n", + "Requirement already satisfied: typeguard>=2.13.3 in /usr/local/lib/python3.9/dist-packages (from jaxtyping>=0.2.11->transformer_lens) (4.1.5)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.1 in /usr/local/lib/python3.9/dist-packages (from jaxtyping>=0.2.11->transformer_lens) (4.8.0)\n", + "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.9/dist-packages (from pandas>=1.1.5->transformer_lens) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas>=1.1.5->transformer_lens) (2022.7.1)\n", + "Requirement already satisfied: markdown-it-py<3.0.0,>=2.1.0 in /usr/local/lib/python3.9/dist-packages (from rich>=12.6.0->transformer_lens) (2.1.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.6.0 in /usr/local/lib/python3.9/dist-packages (from rich>=12.6.0->transformer_lens) (2.14.0)\n", + "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.9/dist-packages (from transformers>=4.25.1->transformer_lens) (0.12.1)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from transformers>=4.25.1->transformer_lens) (3.9.0)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.9/dist-packages (from transformers>=4.25.1->transformer_lens) (2022.10.31)\n", + "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.9/dist-packages (from transformers>=4.25.1->transformer_lens) (0.3.3)\n", + "Requirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (8.1.3)\n", + "Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (1.4.4)\n", + "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (5.9.4)\n", + "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (3.1.30)\n", + "Requirement already satisfied: pathtools in /usr/local/lib/python3.9/dist-packages (from 
wandb>=0.13.5->transformer_lens) (0.1.2)\n", + "Requirement already satisfied: sentry-sdk>=1.0.0 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (1.14.0)\n", + "Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (0.4.0)\n", + "Requirement already satisfied: setproctitle in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (1.3.2)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (66.1.1)\n", + "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.15.0 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (3.20.3)\n", + "Requirement already satisfied: six>=1.4.0 in /usr/lib/python3/dist-packages (from docker-pycreds>=0.4.0->wandb>=0.13.5->transformer_lens) (1.14.0)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (18.2.0)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (4.0.2)\n", + "Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (2.1.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (6.0.4)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.3.3)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.8.2)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.3.1)\n", + "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.9/dist-packages (from GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer_lens) (4.0.10)\n", + "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.9/dist-packages (from markdown-it-py<3.0.0,>=2.1.0->rich>=12.6.0->transformer_lens) (0.1.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/lib/python3/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (2.8)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/lib/python3/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (2019.11.28)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (1.26.14)\n", + "Requirement already satisfied: importlib-metadata>=3.6 in /usr/local/lib/python3.9/dist-packages (from typeguard>=2.13.3->jaxtyping>=0.2.11->transformer_lens) (5.2.0)\n", + "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.9/dist-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer_lens) (5.0.0)\n", + "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.9/dist-packages (from importlib-metadata>=3.6->typeguard>=2.13.3->jaxtyping>=0.2.11->transformer_lens) (3.11.0)\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: jaxtyping==0.2.13 in /usr/local/lib/python3.9/dist-packages (0.2.13)\n", + "Requirement already satisfied: numpy>=1.20.0 in /usr/local/lib/python3.9/dist-packages (from jaxtyping==0.2.13) (1.23.4)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.1 in /usr/local/lib/python3.9/dist-packages (from jaxtyping==0.2.13) (4.8.0)\n", + "Requirement already satisfied: typeguard>=2.13.3 in /usr/local/lib/python3.9/dist-packages (from jaxtyping==0.2.13) (4.1.5)\n", + "Requirement already satisfied: importlib-metadata>=3.6 in /usr/local/lib/python3.9/dist-packages (from typeguard>=2.13.3->jaxtyping==0.2.13) (5.2.0)\n", + "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.9/dist-packages (from importlib-metadata>=3.6->typeguard>=2.13.3->jaxtyping==0.2.13) (3.11.0)\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: einops in /usr/local/lib/python3.9/dist-packages (0.6.1)\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: protobuf==3.20.* in /usr/local/lib/python3.9/dist-packages (3.20.3)\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: plotly in /usr/local/lib/python3.9/dist-packages (5.17.0)\n", + "Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.9/dist-packages (from plotly) (8.2.3)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.9/dist-packages (from plotly) (23.0)\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: torchtyping in /usr/local/lib/python3.9/dist-packages (0.1.4)\n", + "Requirement already satisfied: typeguard>=2.11.1 in /usr/local/lib/python3.9/dist-packages (from torchtyping) (4.1.5)\n", + "Requirement already satisfied: torch>=1.7.0 in /usr/local/lib/python3.9/dist-packages (from torchtyping) (1.12.1+cu116)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (4.8.0)\n", + "Requirement already satisfied: importlib-metadata>=3.6 in /usr/local/lib/python3.9/dist-packages (from typeguard>=2.11.1->torchtyping) (5.2.0)\n", + "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.9/dist-packages (from importlib-metadata>=3.6->typeguard>=2.11.1->torchtyping) (3.11.0)\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mCollecting git+https://github.com/neelnanda-io/neel-plotly.git\n", + " Cloning https://github.com/neelnanda-io/neel-plotly.git to /tmp/pip-req-build-1silpp2s\n", + " Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/neel-plotly.git /tmp/pip-req-build-1silpp2s\n", + " Resolved https://github.com/neelnanda-io/neel-plotly.git to commit 6dc24b26f8dec991908479d7445dae496b3430b7\n", + " Preparing metadata (setup.py) ... \u001b[?25ldone\n", + "\u001b[?25hRequirement already satisfied: einops in /usr/local/lib/python3.9/dist-packages (from neel-plotly==0.0.0) (0.6.1)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.9/dist-packages (from neel-plotly==0.0.0) (1.23.4)\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.9/dist-packages (from neel-plotly==0.0.0) (1.12.1+cu116)\n", + "Requirement already satisfied: plotly in /usr/local/lib/python3.9/dist-packages (from neel-plotly==0.0.0) (5.17.0)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.9/dist-packages (from neel-plotly==0.0.0) (4.64.1)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.9/dist-packages (from neel-plotly==0.0.0) (1.5.0)\n", + "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.9/dist-packages (from pandas->neel-plotly==0.0.0) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas->neel-plotly==0.0.0) (2022.7.1)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.9/dist-packages (from plotly->neel-plotly==0.0.0) (23.0)\n", + "Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.9/dist-packages (from plotly->neel-plotly==0.0.0) (8.2.3)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (4.8.0)\n", + "Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.1->pandas->neel-plotly==0.0.0) (1.14.0)\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: circuitsvis in /usr/local/lib/python3.9/dist-packages (1.41.0)\n", + "Requirement already satisfied: numpy<2.0,>=1.21 in /usr/local/lib/python3.9/dist-packages (from circuitsvis) (1.23.4)\n", + "Requirement already satisfied: importlib-metadata<6.0.0,>=5.1.0 in /usr/local/lib/python3.9/dist-packages (from circuitsvis) (5.2.0)\n", + "Requirement already satisfied: torch>=1.10 in /usr/local/lib/python3.9/dist-packages (from circuitsvis) (1.12.1+cu116)\n", + "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.9/dist-packages (from importlib-metadata<6.0.0,>=5.1.0->circuitsvis) (3.11.0)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->circuitsvis) (4.8.0)\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "#!pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python\n", + "!pip install fancy_einsum==0.0.3\n", + "!pip install transformer_lens\n", + "!pip install jaxtyping==0.2.13\n", + "!pip install einops\n", + "!pip install protobuf==3.20.*\n", + "!pip install plotly\n", + "!pip install torchtyping\n", + "!pip install git+https://github.com/neelnanda-io/neel-plotly.git\n", + "!pip install circuitsvis\n", + "# !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", + "# %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", + "# %pip install typeguard==2.13.3" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython import get_ipython\n", + "ipython = get_ipython()\n", + "ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", + "ipython.run_line_magic(\"autoreload\", \"2\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import einops\n", + "from functools import partial\n", + "import torch\n", + "import datasets\n", + "from torch import Tensor\n", + "from torch.utils.data import DataLoader\n", + "from datasets import load_dataset, concatenate_datasets\n", + "from jaxtyping import Float, Int, Bool\n", + "from typing import Dict, Iterable, List, Tuple, Union\n", + "from transformer_lens import HookedTransformer\n", + "from transformer_lens.utils import get_dataset, tokenize_and_concatenate, get_act_name, test_prompt\n", + "from transformer_lens.hook_points import HookPoint\n", + "from tqdm.notebook import tqdm\n", + "import pandas as pd\n", + "from circuitsvis.activations import text_neuron_activations\n", + "from utils.store import load_array, save_html, save_array, is_file, get_model_name, clean_label, save_text\n", + "from utils.circuit_analysis import get_logit_diff" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def find_positions(tensor, token_ids=[11, 13]):\n", + " positions = []\n", + " for batch_item in tensor:\n", + " token_positions = {token_id: [] for token_id in token_ids}\n", + " for position, token in enumerate(batch_item):\n", + " if token.item() in token_ids:\n", + " token_positions[token.item()].append(position)\n", + " positions.append([token_positions[token_id] for token_id in token_ids])\n", + " return positions" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "def zero_attention_pos_hook(\n", + " pattern: Float[Tensor, \"batch head seq_Q seq_K\"], hook: HookPoint,\n", + " pos_by_batch: List[List[int]], layer: int = 0, head_idx: int = 0,\n", + ") -> Float[Tensor, \"batch head seq_Q seq_K\"]:\n", + " \"\"\"Zero-ablates an attention pattern tensor at a particular position\"\"\"\n", + " assert 'pattern' in hook.name\n", + "\n", + " batch_size = pattern.shape[0]\n", + " assert len(pos_by_batch) == batch_size\n", + "\n", + " for i in range(batch_size):\n", + " for p in pos_by_batch[i]:\n", + " pattern[i, head_idx, p, p] = 0\n", + " \n", + " return pattern" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def names_filter(name: str):\n", + " \"\"\"Filter for the names of the activations we want to keep to study the 
resid stream.\"\"\"\n",
+    "    return name.endswith('resid_post') or name == get_act_name('resid_pre', 0)\n",
+    "\n",
+    "def get_layerwise_token_mean_activations(model: HookedTransformer, data_loader: DataLoader, token_id: int = 13) -> Float[Tensor, \"layer d_model\"]:\n",
+    "    \"\"\"Get the mean residual stream value of a given token id at each layer\"\"\"\n",
+    "    num_layers = model.cfg.n_layers\n",
+    "    d_model = model.cfg.d_model\n",
+    "\n",
+    "    activation_sums = torch.zeros((num_layers, d_model), device=device)\n",
+    "    comma_counts = [0] * num_layers\n",
+    "\n",
+    "    token_mean_values = torch.zeros((num_layers, d_model))\n",
+    "    for _, batch_value in tqdm(enumerate(data_loader), total=len(data_loader)):\n",
+    "        batch_tokens = batch_value['tokens'].to(device)\n",
+    "\n",
+    "        # get positions of the target token (comma, id 13, by default) in the batch\n",
+    "        punct_pos = find_positions(batch_tokens, token_ids=[token_id])\n",
+    "\n",
+    "        _, cache = model.run_with_cache(\n",
+    "            batch_tokens,\n",
+    "            names_filter=names_filter\n",
+    "        )\n",
+    "\n",
+    "        for i in range(batch_tokens.shape[0]):\n",
+    "            for p in punct_pos[i][0]:\n",
+    "                for layer in range(num_layers):\n",
+    "                    activation_sums[layer] += cache[f\"blocks.{layer}.hook_resid_post\"][i, p, :]\n",
+    "                    comma_counts[layer] += 1\n",
+    "\n",
+    "    for layer in range(num_layers):\n",
+    "        token_mean_values[layer] = activation_sums[layer] / comma_counts[layer]\n",
+    "\n",
+    "    return token_mean_values"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def compute_zeroed_attn_modified_loss(model: HookedTransformer, data_loader: DataLoader) -> Tuple[List[Tensor], Tensor]:\n",
+    "    loss_list = []\n",
+    "    for _, batch_value in tqdm(enumerate(data_loader), total=len(data_loader)):\n",
+    "        batch_tokens = batch_value['tokens'].to(device)\n",
+    "\n",
+    "        # get positions of all comma (id 13) tokens in the batch\n",
+    "        punct_pos = find_positions(batch_tokens, token_ids=[13])\n",
+    "\n",
+    "        # get the loss for each token in the batch\n",
+    "        initial_loss = model(batch_tokens, return_type=\"loss\", prepend_bos=False, loss_per_token=True)\n",
+    "\n",
+    "        # add hooks that zero the attention paid to the comma positions\n",
+    "        for layer, head in heads_to_ablate:\n",
+    "            ablate_punct = partial(zero_attention_pos_hook, pos_by_batch=punct_pos, layer=layer, head_idx=head)\n",
+    "            model.blocks[layer].attn.hook_pattern.add_hook(ablate_punct)\n",
+    "\n",
+    "        # get the loss for each token when run with hooks\n",
+    "        hooked_loss = model(batch_tokens, return_type=\"loss\", prepend_bos=False, loss_per_token=True)\n",
+    "\n",
+    "        # compute the percent difference between the two losses\n",
+    "        loss_diff = (hooked_loss - initial_loss) / initial_loss\n",
+    "\n",
+    "        loss_list.append(loss_diff)\n",
+    "\n",
+    "    model.reset_hooks()\n",
+    "    return loss_list, batch_tokens"
+   ]
+  },
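+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note on alignment (illustrative sketch): with `loss_per_token=True`, TransformerLens returns one loss per *predicted* token, so `loss[:, p]` scores the prediction of token `p + 1`. This is why the function below zeroes the loss at the comma positions themselves, which discounts the token immediately following each comma."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# quick check of the per-token loss alignment (assumes `model` is already loaded)\n",
+    "demo_tokens = model.to_tokens(\"Hello, world. Goodbye, world.\", prepend_bos=False)\n",
+    "demo_loss = model(demo_tokens, return_type=\"loss\", loss_per_token=True)\n",
+    "# one loss per predicted token: seq_len - 1 entries\n",
+    "assert demo_loss.shape[1] == demo_tokens.shape[1] - 1"
+   ]
+  },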
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from utils.ablation import ablate_resid_with_precalc_mean\n",
+    "\n",
+    "def compute_mean_ablation_modified_loss(model: HookedTransformer, data_loader: DataLoader, cached_means, target_token_ids) -> Tuple[List[Tensor], List[Tensor]]:\n",
+    "    loss_diff_list = []\n",
+    "    orig_loss_list = []\n",
+    "    for _, batch_value in tqdm(enumerate(data_loader), total=len(data_loader)):\n",
+    "        if isinstance(batch_value['tokens'], list):\n",
+    "            batch_tokens = torch.stack(batch_value['tokens']).to(device)\n",
+    "        else:\n",
+    "            batch_tokens = batch_value['tokens'].to(device)\n",
+    "\n",
+    "        batch_tokens = einops.rearrange(batch_tokens, 'seq batch -> batch seq')\n",
+    "        punct_pos = batch_value['positions']\n",
+    "        print(f\"punct_pos: {punct_pos}\")\n",
+    "\n",
+    "        # get the loss for each token in the batch\n",
+    "        initial_loss = model(batch_tokens, return_type=\"loss\", prepend_bos=False, loss_per_token=True)\n",
+    "        print(f\"initial loss shape: {initial_loss.shape}\")\n",
+    "        orig_loss_list.append(initial_loss)\n",
+    "\n",
+    "        # add hooks that mean-ablate the residual stream at the comma positions\n",
+    "        for layer, head in heads_to_ablate:\n",
+    "            mean_ablate_comma = partial(ablate_resid_with_precalc_mean_no_batch, cached_means=cached_means, pos_by_batch=punct_pos, layer=layer)\n",
+    "            model.blocks[layer].hook_resid_post.add_hook(mean_ablate_comma)\n",
+    "\n",
+    "        # get the loss for each token when run with hooks\n",
+    "        print(f\"batch tokens shape: {batch_tokens.shape}\")\n",
+    "\n",
+    "        hooked_loss = model(batch_tokens, return_type=\"loss\", prepend_bos=False, loss_per_token=True)\n",
+    "        print(f\"hooked loss shape: {hooked_loss.shape}\")\n",
+    "\n",
+    "        # compute the difference between the two losses\n",
+    "        loss_diff = hooked_loss - initial_loss\n",
+    "\n",
+    "        # zero the loss at each comma position, i.e. the loss on the token right after the comma\n",
+    "        for p in punct_pos:\n",
+    "            print(f\"zeroing {p}\")\n",
+    "            loss_diff[0, p] = 0\n",
+    "\n",
+    "        loss_diff_list.append(loss_diff)\n",
+    "\n",
+    "    model.reset_hooks()\n",
+    "    return loss_diff_list, orig_loss_list"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def ablate_resid_with_precalc_mean_no_batch(\n",
+    "    component: Float[Tensor, \"batch ...\"],\n",
+    "    hook: HookPoint,\n",
+    "    cached_means: Float[Tensor, \"layer ...\"],\n",
+    "    pos_by_batch: List[Tensor],\n",
+    "    layer: int = 0,\n",
+    ") -> Float[Tensor, \"batch ...\"]:\n",
+    "    \"\"\"\n",
+    "    Replaces the residual stream at the given positions with a precalculated mean value\n",
+    "\n",
+    "    :param component: the residual stream tensor to modify\n",
+    "    :param cached_means: per-layer mean activation vectors to substitute in\n",
+    "    :param pos_by_batch: the sequence positions to overwrite (applied to every batch item)\n",
+    "    :return: the modified residual stream tensor\n",
+    "    \"\"\"\n",
+    "    assert 'resid' in hook.name\n",
+    "\n",
+    "    for p in pos_by_batch:\n",
+    "        component[:, p] = cached_means[layer]\n",
+    "\n",
+    "    return component"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from datasets import Dataset, Features, Sequence, Value\n",
+    "\n",
+    "def convert_to_tensors(dataset, column_name='tokens'):\n",
+    "    final_batches = []\n",
+    "\n",
+    "    for batch in dataset:\n",
+    "        trimmed_batch = batch[column_name]\n",
+    "        final_batches.append(trimmed_batch)\n",
+    "\n",
+    "    # Convert list of batches to tensors\n",
+    "    final_batches = [torch.tensor(batch, dtype=torch.long) for batch in final_batches]\n",
+    "    # Create a new dataset with specified features\n",
+    "    features = Features({\"tokens\": Sequence(Value(\"int64\"))})\n",
+    "    final_dataset = Dataset.from_dict({\"tokens\": final_batches}, features=features)\n",
+    "\n",
+    "    final_dataset.set_format(type=\"torch\", columns=[\"tokens\"])\n",
+    "\n",
+    "    return final_dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def plot_neuroscope(\n",
+    "    tokens: Int[Tensor, \"batch pos\"], centred: bool = False, activations: Float[Tensor, \"pos layer 1\"] = None,\n",
+    "    verbose=False,\n",
+    "):\n",
+    "    
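\"\"\"Visualise each token's activations across layers with circuitsvis text_neuron_activations.\"\"\"\n",
+    "    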
\n", + " str_tokens = model.to_str_tokens(tokens, prepend_bos=False)\n", + "\n", + " if verbose:\n", + " print(f\"Tokens shape: {tokens.shape}\")\n", + " \n", + " if centred:\n", + " if verbose:\n", + " print(\"Centering activations\")\n", + " layer_means = einops.reduce(activations, \"pos layer 1 -> 1 layer 1\", reduction=\"mean\")\n", + " layer_means = einops.repeat(layer_means, \"1 layer 1 -> pos layer 1\", pos=activations.shape[0])\n", + " activations -= layer_means\n", + " elif verbose:\n", + " print(\"Activations already centered\")\n", + " assert (\n", + " activations.ndim == 3\n", + " ), f\"activations must be of shape [tokens x layers x neurons], found {activations.shape}\"\n", + " assert len(str_tokens) == activations.shape[0], (\n", + " f\"tokens and activations must have the same length, found tokens={len(str_tokens)} and acts={activations.shape[0]}, \"\n", + " f\"tokens={str_tokens}, \"\n", + " f\"activations={activations.shape}\"\n", + "\n", + " )\n", + " return text_neuron_activations(\n", + " tokens=str_tokens, \n", + " activations=activations,\n", + " first_dimension_name=\"Layer (resid_pre)\",\n", + " second_dimension_name=\"Model\",\n", + " second_dimension_labels=[\"pythia-2.8b\"],\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Comma Ablation on Natural Text" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4f6d1769f95d4ca58c10f7e68ec52953", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading (…)lve/main/config.json: 0%| | 0.00/571 [00:00= max_length:\n", + " final_batch = token_buffer[:max_length]\n", + " token_buffer = token_buffer[max_length:]\n", + " final_batches.append(final_batch)\n", + " \n", + " # Handle any remaining tokens\n", + " if len(token_buffer) > 0:\n", + " final_batches.append(token_buffer)\n", + " \n", + " # Convert list of batches to tensors\n", + " final_batches = [torch.tensor(batch) for batch in final_batches]\n", + " \n", + " # Create a new dataset with specified features\n", + " features = Features({\"tokens\": Sequence(Value(\"int64\"))})\n", + " final_dataset = Dataset.from_dict({\"tokens\": final_batches}, features=features)\n", + "\n", + " final_dataset.set_format(type=\"torch\", columns=[\"tokens\"])\n", + " \n", + " return final_dataset\n", + "\n", + "# # Example usage\n", + "# tokenizer = AutoTokenizer.from_pretrained(\"gpt2\") # Make sure the tokenizer has bos_token_id and eos_token_id\n", + "# text_dataset = Dataset.from_dict({\"text\": [\"This is a sample text.\", \"Another sample text.\"]}) # Example dataset\n", + "# tokenized_dataset = tokenize_and_concatenate2(text_dataset, tokenizer, max_length=1024)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(\"The Rock is destined to be the 21st Century's new ''Conan'' and that he's going to make a splash even greater than Arnold Schwarzenegger, Jean-Claud Van Damme or Steven Segal.\",\n", + " 1)" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sst_data['train'][0]['text'], sst_data['train'][0]['label']" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor 
incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum..\n" + ] + } + ], + "source": [ + "lorem_ipsum = \"Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.\"\n", + "lorem_ipsum_tokens = model.to_tokens(lorem_ipsum, prepend_bos=False)\n", + "lorem_ipsum_insert = \" \" + model.to_string(lorem_ipsum_tokens[:10])[0] + \".\"\n", + "print(lorem_ipsum_insert)" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "..........\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "cf4c255f10584a55a20bdbe428b3d32e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Map: 0%| | 0/7864 [00:00 1.0\n", + " return example\n", + "\n", + "# Use the map function to apply the filter_function\n", + "sst_data_with_flag_train = sst_data['train'].map(filter_function, keep_in_memory=True)\n", + "sst_data_with_flag_val = sst_data['dev'].map(filter_function, keep_in_memory=True)\n", + "sst_data_with_flag_test = sst_data['test'].map(filter_function, keep_in_memory=True)\n", + "sst_data_with_flag = concatenate_datasets([sst_data_with_flag_train, sst_data_with_flag_val, sst_data_with_flag_test])\n", + "\n", + "# Use the filter function to keep only the examples where 'keep_example' is True\n", + "sst_zero_shot = sst_data_with_flag.filter(lambda x: x['keep_example'])\n", + "\n", + "# save dataset\n", + "#sst_zero_shot.save_to_disk(\"sst_zero_shot\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 75, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f794fcf6dce141b1900944da821fed29", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Map: 0%| | 0/4772 [00:00 0 else False\n", + "\n", + " return {'positions': positions, 'has_token': has_token}\n", + "\n", + "def convert_answers(example, pos_answer_id=29071, neg_answer_id=32725):\n", + " if example['label'] == 1:\n", + " answers = torch.tensor([pos_answer_id, neg_answer_id])\n", + " else:\n", + " answers = torch.tensor([neg_answer_id, pos_answer_id])\n", + "\n", + " return {'answers': answers}\n", + "\n", + "\n", + "dataset = sst_zero_shot.map(concatenate_classification_prompts, batched=False)\n", + "dataset = dataset.map(tokenize_function, batched=False)\n", + "dataset = dataset.map(convert_answers, batched=False)\n", + "dataset = dataset.rename_column(\"input_ids\", \"tokens\")\n", + "dataset.set_format(type=\"torch\", columns=[\"tokens\", \"attention_mask\", \"label\", \"answers\"])\n", + "dataset = dataset.map(find_dataset_positions, batched=False)\n", + "dataset = dataset.filter(lambda example: example['has_token']==True)\n", + "dataset\n" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "metadata": {}, + 
"outputs": [ + { + "data": { + "text/plain": [ + "(\"The Rock is destined to be the 21st Century's new ''Conan'' and that he's going to make a splash even greater than Arnold Schwarzenegger, Jean-Claud Van Damme or Steven Segal........... Review Sentiment:111111111111\",\n", + " [' Positive'])" + ] + }, + "execution_count": 76, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.to_string(dataset[0]['tokens']), model.to_str_tokens(dataset[0]['answers'][0])" + ] + }, + { + "cell_type": "code", + "execution_count": 77, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(0.9671, device='cuda:0')" + ] + }, + "execution_count": 77, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from utils.circuit_analysis import get_logit_diff\n", + "logits, cache = model.run_with_cache(dataset['tokens'][0])\n", + "get_logit_diff(logits, dataset['answers'][0].unsqueeze(0).unsqueeze(0).to(device))" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "760bfc70e0234f7eb5f2ce14627d4ba4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Filter: 0%| | 0/2683 [00:00 Float[Tensor, \"batch ...\"]:\n", + " \"\"\"\n", + " Mean-ablates a batch tensor\n", + "\n", + " :param component: the tensor to compute the mean over the batch dim of\n", + " :return: the mean over the cache component of the tensor\n", + " \"\"\"\n", + " assert 'resid' in hook.name\n", + "\n", + " # Identify the positions where pos_by_batch is 1\n", + " batch_indices, sequence_positions = torch.where(pos_by_batch == 1)\n", + "\n", + " # Replace the corresponding positions in component with cached_means[layer]\n", + " component[batch_indices, sequence_positions] = cached_means[layer]\n", + "\n", + " return component" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "metadata": {}, + "outputs": [], + "source": [ + "def compute_mean_ablation_modified_logit_diff(model: HookedTransformer, data_loader: DataLoader, cached_means, target_token_ids) -> float:\n", + " \n", + " orig_ld_list = []\n", + " ablated_ld_list = []\n", + " freeze_ablated_ld_list = []\n", + " \n", + " for _, batch_value in tqdm(enumerate(data_loader), total=len(data_loader)):\n", + " batch_tokens = batch_value['tokens'].to(device)\n", + " punct_pos = batch_value['positions'].to(device)\n", + "\n", + " # get the logit diff for the last token in each sequence\n", + " orig_logits, clean_cache = model.run_with_cache(batch_tokens, return_type=\"logits\", prepend_bos=False)\n", + " orig_ld = compute_last_position_logit_diff(orig_logits, batch_value['attention_mask'], batch_value['answers'])\n", + " orig_ld_list.append(orig_ld)\n", + " \n", + " # repeat with commas ablated\n", + " for layer in layers_to_ablate:\n", + " mean_ablate_comma = partial(ablate_resid_with_precalc_mean, cached_means=cached_means, pos_by_batch=punct_pos, layer=layer)\n", + " model.blocks[layer].hook_resid_post.add_hook(mean_ablate_comma)\n", + " \n", + " ablated_logits = model(batch_tokens, return_type=\"logits\", prepend_bos=False)\n", + " ablated_ld = compute_last_position_logit_diff(ablated_logits, batch_value['attention_mask'], batch_value['answers'])\n", + " ablated_ld_list.append(ablated_ld)\n", + " \n", + " model.reset_hooks()\n", + "\n", + " # repeat with attention frozen and commas ablated\n", + " for layer, head in heads_to_freeze:\n", + " freeze_attn = 
+  {
+   "cell_type": "code",
+   "execution_count": 82,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def compute_mean_ablation_modified_logit_diff(model: HookedTransformer, data_loader: DataLoader, cached_means, target_token_ids) -> Tuple[Tensor, Tensor, Tensor]:\n",
+    "    orig_ld_list = []\n",
+    "    ablated_ld_list = []\n",
+    "    freeze_ablated_ld_list = []\n",
+    "\n",
+    "    for _, batch_value in tqdm(enumerate(data_loader), total=len(data_loader)):\n",
+    "        batch_tokens = batch_value['tokens'].to(device)\n",
+    "        punct_pos = batch_value['positions'].to(device)\n",
+    "\n",
+    "        # get the logit diff for the last token in each sequence\n",
+    "        orig_logits, clean_cache = model.run_with_cache(batch_tokens, return_type=\"logits\", prepend_bos=False)\n",
+    "        orig_ld = compute_last_position_logit_diff(orig_logits, batch_value['attention_mask'], batch_value['answers'])\n",
+    "        orig_ld_list.append(orig_ld)\n",
+    "\n",
+    "        # repeat with commas ablated\n",
+    "        for layer in layers_to_ablate:\n",
+    "            mean_ablate_comma = partial(ablate_resid_with_precalc_mean, cached_means=cached_means, pos_by_batch=punct_pos, layer=layer)\n",
+    "            model.blocks[layer].hook_resid_post.add_hook(mean_ablate_comma)\n",
+    "\n",
+    "        ablated_logits = model(batch_tokens, return_type=\"logits\", prepend_bos=False)\n",
+    "        ablated_ld = compute_last_position_logit_diff(ablated_logits, batch_value['attention_mask'], batch_value['answers'])\n",
+    "        ablated_ld_list.append(ablated_ld)\n",
+    "\n",
+    "        model.reset_hooks()\n",
+    "\n",
+    "        # repeat with attention frozen and commas ablated\n",
+    "        for layer, head in heads_to_freeze:\n",
+    "            freeze_attn = partial(freeze_attn_pattern_hook, cache=clean_cache, layer=layer, head_idx=head)\n",
+    "            model.blocks[layer].attn.hook_pattern.add_hook(freeze_attn)\n",
+    "\n",
+    "        for layer in layers_to_ablate:\n",
+    "            mean_ablate_comma = partial(ablate_resid_with_precalc_mean, cached_means=cached_means, pos_by_batch=punct_pos, layer=layer)\n",
+    "            model.blocks[layer].hook_resid_post.add_hook(mean_ablate_comma)\n",
+    "\n",
+    "        freeze_ablated_logits = model(batch_tokens, return_type=\"logits\", prepend_bos=False)\n",
+    "        freeze_ablated_ld = compute_last_position_logit_diff(freeze_ablated_logits, batch_value['attention_mask'], batch_value['answers'])\n",
+    "        freeze_ablated_ld_list.append(freeze_ablated_ld)\n",
+    "\n",
+    "        model.reset_hooks()\n",
+    "\n",
+    "    return torch.cat(orig_ld_list), torch.cat(ablated_ld_list), torch.cat(freeze_ablated_ld_list)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Comma Mean Ablation"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "##### Positive Prompt Results"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 83,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "84b7c0a8e6b74a2daf5c51b2967c9b70",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/12 [00:00 torch.Tensor:\n",
+    "    \"\"\"\n",
+    "    Ablates a batch tensor by removing the influence of a direction vector from it.\n",
+    "\n",
+    "    Args:\n",
+    "        component: the residual stream tensor to remove the direction from\n",
+    "        direction_vector: the direction vector to remove from the component\n",
+    "        multiplier: the multiplier to apply to the direction vector\n",
+    "        pos_by_batch: the positions to ablate\n",
+    "        layer: the layer to ablate\n",
+    "\n",
+    "    Returns:\n",
+    "        the ablated component\n",
+    "    \"\"\"\n",
+    "    assert 'resid' in hook.name\n",
+    "\n",
+    "    # Normalize the direction vector to make sure it's a unit vector\n",
+    "    D_normalized = direction_vector[layer] / torch.norm(direction_vector[layer])\n",
+    "\n",
+    "    # Calculate the projection of component onto direction_vector\n",
+    "    proj = einops.einsum(component, D_normalized, \"b s d, d -> b s\").unsqueeze(-1) * D_normalized\n",
+    "\n",
+    "    # Ablate the direction from component\n",
+    "    component_ablated = component.clone()  # Create a copy to ensure original is not modified\n",
+    "    if pos_by_batch is not None:\n",
+    "        batch_indices, sequence_positions = torch.where(pos_by_batch == 1)\n",
+    "        component_ablated[batch_indices, sequence_positions] = component[batch_indices, sequence_positions] - multiplier * proj[batch_indices, sequence_positions]\n",
+    "\n",
+    "    # Check that positions not in (batch_indices, sequence_positions) were not ablated\n",
+    "    check_mask = torch.ones_like(component, dtype=torch.bool)\n",
+    "    check_mask[batch_indices, sequence_positions] = 0\n",
+    "    if not torch.all(component[check_mask] == component_ablated[check_mask]):\n",
+    "        raise ValueError(\"Positions outside of specified (batch_indices, sequence_positions) were ablated!\")\n",
+    "\n",
+    "    return component_ablated"
+   ]
+  },
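+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A quick sanity check (illustrative, using random tensors rather than model activations) of the projection arithmetic in `ablate_resid_with_direction`: after subtracting the projection onto a unit direction, the result should have zero component along that direction."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# illustrative check of the projection removal on random data\n",
+    "x = torch.randn(2, 5, 8)\n",
+    "d_hat = torch.nn.functional.normalize(torch.randn(8), dim=0)\n",
+    "proj = einops.einsum(x, d_hat, \"b s d, d -> b s\").unsqueeze(-1) * d_hat\n",
+    "# the ablated tensor has (numerically) zero component along d_hat\n",
+    "assert torch.allclose(einops.einsum(x - proj, d_hat, \"b s d, d -> b s\"), torch.zeros(2, 5), atol=1e-5)"
+   ]
+  },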
+  {
+   "cell_type": "code",
+   "execution_count": 89,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def compute_directional_ablation_modified_logit_diff(model: HookedTransformer, data_loader: DataLoader, direction_vectors, multiplier=1.0, target_token_ids=13) -> Tuple[Tensor, Tensor, Tensor]:\n",
+    "    orig_ld_list = []\n",
+    "    ablated_ld_list = []\n",
+    "    freeze_ablated_ld_list = []\n",
+    "\n",
+    "    for _, batch_value in tqdm(enumerate(data_loader), total=len(data_loader)):\n",
+    "        batch_tokens = batch_value['tokens'].to(device)\n",
+    "        labels = batch_value['label'].to(device)\n",
+    "        punct_pos = batch_value['positions'].to(device)\n",
+    "\n",
+    "        # get the logit diff for the last token in each sequence\n",
+    "        orig_logits, clean_cache = model.run_with_cache(batch_tokens, return_type=\"logits\", prepend_bos=False)\n",
+    "        orig_ld = compute_last_position_logit_diff(orig_logits, batch_value['attention_mask'], batch_value['answers'])\n",
+    "        orig_ld_list.append(orig_ld)\n",
+    "\n",
+    "        # repeat with commas ablated\n",
+    "        for layer in layers_to_ablate:\n",
+    "            dir_ablate_comma = partial(ablate_resid_with_direction, labels=labels, direction_vector=direction_vectors, multiplier=multiplier, pos_by_batch=punct_pos, layer=layer)\n",
+    "            model.blocks[layer].hook_resid_post.add_hook(dir_ablate_comma)\n",
+    "\n",
+    "        ablated_logits = model(batch_tokens, return_type=\"logits\", prepend_bos=False)\n",
+    "        # check to see if ablated_logits has any nan values\n",
+    "        if torch.isnan(ablated_logits).any():\n",
+    "            print(\"ablated logits has nan values\")\n",
+    "        ablated_ld = compute_last_position_logit_diff(ablated_logits, batch_value['attention_mask'], batch_value['answers'])\n",
+    "        ablated_ld_list.append(ablated_ld)\n",
+    "\n",
+    "        model.reset_hooks()\n",
+    "\n",
+    "        # repeat with attention frozen and commas ablated\n",
+    "        for layer, head in heads_to_freeze:\n",
+    "            freeze_attn = partial(freeze_attn_pattern_hook, cache=clean_cache, layer=layer, head_idx=head)\n",
+    "            model.blocks[layer].attn.hook_pattern.add_hook(freeze_attn)\n",
+    "\n",
+    "        for layer in layers_to_ablate:\n",
+    "            dir_ablate_comma = partial(ablate_resid_with_direction, labels=labels, direction_vector=direction_vectors, multiplier=multiplier, pos_by_batch=punct_pos, layer=layer)\n",
+    "            model.blocks[layer].hook_resid_post.add_hook(dir_ablate_comma)\n",
+    "\n",
+    "        freeze_ablated_logits = model(batch_tokens, return_type=\"logits\", prepend_bos=False)\n",
+    "        freeze_ablated_ld = compute_last_position_logit_diff(freeze_ablated_logits, batch_value['attention_mask'], batch_value['answers'])\n",
+    "        freeze_ablated_ld_list.append(freeze_ablated_ld)\n",
+    "\n",
+    "        model.reset_hooks()\n",
+    "\n",
+    "    return torch.cat(orig_ld_list), torch.cat(ablated_ld_list), torch.cat(freeze_ablated_ld_list)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 90,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model.reset_hooks()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "##### Positive Prompt Results"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 91,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "94f760539b464013844167b6bd7bae3a",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/12 [00:00 float:\n",
+    "    orig_ld_list = []\n",
+    "    ablated_ld_list = []\n",
+    "    freeze_ablated_ld_list = []\n",
+    "\n",
+    "    for _, batch_value in tqdm(enumerate(data_loader), total=len(data_loader)):\n",
+    "        batch_tokens = batch_value['tokens'].to(device)\n",
+    "        labels = batch_value['label'].to(device)\n",
+    "        # here the attention mask doubles as the position mask, so every non-pad position is ablated\n",
+    "        punct_pos = batch_value['attention_mask'].to(device)\n",
+    "\n",
+    "        # get the logit diff for the last token 
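(the classification answer) 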
in each sequence\n", + " orig_logits, clean_cache = model.run_with_cache(batch_tokens, return_type=\"logits\", prepend_bos=False)\n", + " orig_ld = compute_last_position_logit_diff(orig_logits, batch_value['attention_mask'], batch_value['answers'])\n", + " orig_ld_list.append(orig_ld)\n", + " \n", + " # repeat with commas ablated\n", + " for layer in layers_to_ablate:\n", + " dir_ablate_comma = partial(ablate_resid_with_direction, labels=labels, direction_vector=direction_vectors, multiplier=multiplier, pos_by_batch=punct_pos, layer=layer)\n", + " model.blocks[layer].hook_resid_post.add_hook(dir_ablate_comma)\n", + " \n", + " ablated_logits = model(batch_tokens, return_type=\"logits\", prepend_bos=False)\n", + " # check to see if ablated_logits has any nan values\n", + " if torch.isnan(ablated_logits).any():\n", + " print(\"ablated logits has nan values\")\n", + " ablated_ld = compute_last_position_logit_diff(ablated_logits, batch_value['attention_mask'], batch_value['answers'])\n", + " ablated_ld_list.append(ablated_ld)\n", + " \n", + " model.reset_hooks()\n", + "\n", + " # repeat with attention frozen and commas ablated\n", + " for layer, head in heads_to_freeze:\n", + " freeze_attn = partial(freeze_attn_pattern_hook, cache=clean_cache, layer=layer, head_idx=head)\n", + " model.blocks[layer].attn.hook_pattern.add_hook(freeze_attn)\n", + "\n", + " for layer in layers_to_ablate:\n", + " dir_ablate_comma = partial(ablate_resid_with_direction, labels=labels, direction_vector=direction_vectors, multiplier=multiplier, pos_by_batch=punct_pos, layer=layer)\n", + " model.blocks[layer].hook_resid_post.add_hook(dir_ablate_comma)\n", + " \n", + " freeze_ablated_logits = model(batch_tokens, return_type=\"logits\", prepend_bos=False)\n", + " freeze_ablated_ld = compute_last_position_logit_diff(freeze_ablated_logits, batch_value['attention_mask'], batch_value['answers'])\n", + " freeze_ablated_ld_list.append(freeze_ablated_ld)\n", + " \n", + " model.reset_hooks()\n", + "\n", + " return torch.cat(orig_ld_list), torch.cat(ablated_ld_list), torch.cat(freeze_ablated_ld_list)" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "metadata": {}, + "outputs": [], + "source": [ + "model.reset_hooks()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Positive Prompt Results" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4a656989b098492caf42184ca5c8c5ed", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/12 [00:00 (batch item) token\")\n", + "loss_change_by_token_by_row = loss_change_by_token_by_row.unsqueeze(2)\n", + "loss_change_by_token_by_row.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 145, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'data/pythia-2.8b/loss_change_by_token_by_row_sst.npy'" + ] + }, + "execution_count": 145, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "save_array(loss_change_by_token_by_row, 'loss_change_by_token_by_row_sst.npy', model)" + ] + }, + { + "cell_type": "code", + "execution_count": 146, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(8.0657)" + ] + }, + "execution_count": 146, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loss_change_by_token_by_row.sum()" + ] + }, + { + "cell_type": "code", + "execution_count": 153, + "metadata": {}, + "outputs": [ + 
{ + "data": { + "text/plain": [ + "tensor([[ 0.0000],\n", + " [ 0.0000],\n", + " [ 0.0000],\n", + " [ 0.0000],\n", + " [ 0.0000],\n", + " [ 0.4284],\n", + " [ 0.4954],\n", + " [ 0.3777],\n", + " [ 0.7570],\n", + " [ 0.6130],\n", + " [ 0.0000],\n", + " [ 0.4558],\n", + " [ 0.1512],\n", + " [-0.5029],\n", + " [-0.4370],\n", + " [ 0.0000],\n", + " [ 0.5531],\n", + " [-0.4186],\n", + " [ 0.0285],\n", + " [-0.0923],\n", + " [ 0.0000],\n", + " [ 0.0777],\n", + " [ 0.0282],\n", + " [ 0.0304],\n", + " [ 0.0000],\n", + " [ 0.3828],\n", + " [-0.1881],\n", + " [-0.0471],\n", + " [-0.1300],\n", + " [-0.0969],\n", + " [ 0.0932],\n", + " [ 0.0637]])" + ] + }, + "execution_count": 153, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loss_change_by_token_by_row[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 150, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Top 40 most positive examples:\n", + "Example: Yas, Activation: 2.0176, Batch: 76, Pos: 31\n", + "Example: ., Activation: 1.3744, Batch: 33, Pos: 29\n", + "Example: II, Activation: 1.2075, Batch: 7, Pos: 28\n", + "Example: ,, Activation: 1.1415, Batch: 14, Pos: 17\n", + "Example: weight, Activation: 0.8663, Batch: 18, Pos: 17\n", + "Example: roles, Activation: 0.7570, Batch: 1, Pos: 8\n", + "Example: echoes, Activation: 0.7151, Batch: 2, Pos: 14\n", + "Example: buy, Activation: 0.6960, Batch: 35, Pos: 9\n", + "Example: 'm, Activation: 0.6842, Batch: 48, Pos: 31\n", + "Example: .'', Activation: 0.6383, Batch: 10, Pos: 12\n", + "Example: ., Activation: 0.6022, Batch: 25, Pos: 26\n", + "Example: Min, Activation: 0.5999, Batch: 4, Pos: 7\n", + "Example: <|endoftext|>, Activation: 0.5333, Batch: 0, Pos: 25\n", + "Example: as, Activation: 0.4545, Batch: 39, Pos: 29\n", + "Example: it, Activation: 0.3246, Batch: 53, Pos: 23\n", + "Example: string, Activation: 0.2914, Batch: 3, Pos: 12\n", + "Example: the, Activation: 0.2676, Batch: 27, Pos: 26\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Top 40 most negative examples:\n", + "Example: part, Activation: -1.3371, Batch: 0, Pos: 16\n", + "Example: ultimately, Activation: -0.9171, Batch: 14, Pos: 19\n", + "Example: Report, Activation: -0.7714, Batch: 4, Pos: 9\n", + "Example: Cris, Activation: -0.5618, Batch: 2, Pos: 22\n", + "Example: and, Activation: -0.5029, Batch: 1, Pos: 13\n", + "Example: still, Activation: -0.4266, Batch: 10, Pos: 7\n", + "Example: seem, Activation: -0.3844, Batch: 24, Pos: 7\n", + "Example: dialogue, Activation: -0.3544, Batch: 18, Pos: 23\n", + "Example: erving, Activation: -0.3078, Batch: 25, Pos: 23\n", + "Example: perhaps, Activation: -0.3037, Batch: 53, Pos: 26\n", + "Example: protagon, Activation: -0.2978, Batch: 39, Pos: 31\n", + "Example: itus, Activation: -0.2772, Batch: 33, Pos: 26\n", + "Example: camp, Activation: -0.2713, Batch: 7, Pos: 23\n", + "Example: earth, Activation: -0.2532, Batch: 27, Pos: 30\n", + "Example: ache, Activation: -0.1927, Batch: 76, Pos: 29\n", + "Example: the, Activation: -0.1866, Batch: 16, Pos: 23\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from utils.neuroscope import plot_topk\n", + "loss_change_by_token = torch.from_numpy(load_array('loss_change_by_token_by_row_sst.npy', model))\n", + "plot_topk(\n", + " activations=loss_change_by_token_by_row, \n", + " dataloader=subset_data_loader_tkns, \n", + " window_size=64, \n", + " model=model, \n", + " k=40, \n", + " centred=False, \n", + " #exclusions=[\" '\", \" ,\", \",\", \".\",\" .\",\" Fig\", \"'t\", \" Pinterest\", \" Kampf\", \"m\", \"uk\", \" Kamp\", \"com\", \"edu\", \"S\", \"youtube\", \"twitter\", \"0\", \"js\", \"py\", \" Protein\", \" Fiber\", \" Carbohydrates\", \" Sugar\", \" Grant\", \" Pub\", \",\"]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Comma Ablation for Classification" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "import random\n", + "\n", + "from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM\n", + "from transformers import TrainingArguments, Trainer\n", + "from transformer_lens import HookedTransformer\n", + "from datasets import load_dataset, Dataset, DatasetDict\n", + "from tqdm.notebook import tqdm\n", + "from utils.store import load_pickle, load_array\n", + "from utils.ablation import ablate_resid_with_precalc_mean" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "aa26894959d64ac296e1cb2a60ee4322", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading builder script: 0%| | 0.00/4.31k [00:00
These Lines are Just Filler. The movie was bad. Why I have to expand on that I don't know. This is already a waste of my time. I just wanted to warn others. Avoid this movie. The acting sucks and the writing is just moronic. Bad in every way. The only nice thing about the movie are Deniz Akkaya's breasts. Even that was ruined though by a terrible and unneeded rape scene. The movie is a poorly contrived and totally unbelievable piece of garbage.<br /><br />
OK now I am just going to rag on IMDb for this stupid rule of 10 lines of text minimum. First I waste my time watching this offal. Then feeling compelled to warn others I create an account with IMDb only to discover that I have to write a friggen essay on the film just to express how bad I think it is. Totally unnecessary.\",\n", + " 'label': 0}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset = load_dataset(\"imdb\")\n", + "dataset[\"train\"][100]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### GPT-2 Classifier" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a1d270f58324445a8d73448eed87a48a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading (…)lve/main/config.json: 0%| | 0.00/665 [00:00 float:\n", + "\n", + " for _, item in tqdm(enumerate(dataset), total=len(dataset)):\n", + " batch_tokens = model.to_tokens(item['text'], prepend_bos=False)\n", + " print(batch_tokens.shape)\n", + "\n", + " # # get positions of all 13 and 15 token ids in batch\n", + " # punct_pos = find_positions(batch_tokens, token_ids=target_token_ids)\n", + "\n", + " # # get the loss for each token in the batch\n", + " # initial_loss = model(batch_tokens, return_type=\"loss\", prepend_bos=False, loss_per_token=True)\n", + " # orig_loss_list.append(initial_loss)\n", + " \n", + " # # add hooks for the activations of the 13 and 15 tokens\n", + " # for layer, head in heads_to_ablate:\n", + " # mean_ablate_comma = partial(ablate_resid_with_precalc_mean, cached_means=cached_means, pos_by_batch=punct_pos, layer=layer)\n", + " # model.blocks[layer].hook_resid_post.add_hook(mean_ablate_comma)\n", + "\n", + " # # get the loss for each token when run with hooks\n", + " # hooked_loss = model.run_with_cache(batch_tokens, return_type=\"loss\", prepend_bos=False, loss_per_token=True)\n", + "\n", + " # # compute the percent difference between the two losses\n", + " # loss_diff = hooked_loss - initial_loss\n", + " # loss_diff_list.append(loss_diff)\n", + "\n", + " # model.reset_hooks()\n", + " # return loss_diff_list, orig_loss_list" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "ename": "FileNotFoundError", + "evalue": "[Errno 2] No such file or directory: 'data/gpt2-small/comma_mean_values.npy'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/curttigges/proj/eliciting-latent-sentiment/owt_comma_ablation.ipynb Cell 42\u001b[0m in \u001b[0;36m2\n\u001b[1;32m 1\u001b[0m \u001b[39m# load the files\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m comma_mean_values \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mfrom_numpy(load_array(\u001b[39m'\u001b[39;49m\u001b[39mcomma_mean_values.npy\u001b[39;49m\u001b[39m'\u001b[39;49m, model))\u001b[39m.\u001b[39mto(device)\n\u001b[1;32m 3\u001b[0m period_mean_values \u001b[39m=\u001b[39m torch\u001b[39m.\u001b[39mfrom_numpy(load_array(\u001b[39m'\u001b[39m\u001b[39mperiod_mean_values.npy\u001b[39m\u001b[39m'\u001b[39m, model))\u001b[39m.\u001b[39mto(device)\n", + "File \u001b[0;32m/notebooks/eliciting-latent-sentiment/utils/store.py:186\u001b[0m, in \u001b[0;36mload_array\u001b[0;34m(label, 
model)\u001b[0m\n\u001b[1;32m 184\u001b[0m model_path \u001b[39m=\u001b[39m os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(\u001b[39m'\u001b[39m\u001b[39mdata\u001b[39m\u001b[39m'\u001b[39m, model)\n\u001b[1;32m 185\u001b[0m path \u001b[39m=\u001b[39m os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(model_path, label \u001b[39m+\u001b[39m \u001b[39m'\u001b[39m\u001b[39m.npy\u001b[39m\u001b[39m'\u001b[39m)\n\u001b[0;32m--> 186\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mopen\u001b[39;49m(path, \u001b[39m'\u001b[39;49m\u001b[39mrb\u001b[39;49m\u001b[39m'\u001b[39;49m) \u001b[39mas\u001b[39;00m f:\n\u001b[1;32m 187\u001b[0m array \u001b[39m=\u001b[39m np\u001b[39m.\u001b[39mload(f)\n\u001b[1;32m 188\u001b[0m \u001b[39mreturn\u001b[39;00m array\n", + "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'data/gpt2-small/comma_mean_values.npy'" + ] + } + ], + "source": [ + "# load the files\n", + "comma_mean_values = torch.from_numpy(load_array('comma_mean_values.npy', model)).to(device)\n", + "period_mean_values = torch.from_numpy(load_array('period_mean_values.npy', model)).to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ea6ec2bc38df4f9d87fcdfd05723f17b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/2000 [00:00