diff --git a/.github/ISSUE_TEMPLATE/compatibility.md b/.github/ISSUE_TEMPLATE/compatibility.md new file mode 100644 index 000000000..60a4632c9 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/compatibility.md @@ -0,0 +1,35 @@ +--- +name: Compatibility Report +about: Submit a compatibility report +title: "[Compatibility Report] Model ID" + +--- + + + +## Model + +REPLACE_WITH_MODEL_ID + +- [ ] This model was incompatible when it was introduced to TransformerLens + + + +The model seems to have worked as of REPLACE_WITH_LAST_COMPATIBLE_VERSION_NUMBER. It first started +showing signs of incompatibility in REPLACE_WITH_FIRST_INCOMPATIBLE_VERSION_NUMBER. + +### Example of some generations in transformers + + +### Code used to load the model in TransformerLens + + +### Example of some generations in TransformerLens diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 1b71d373e..fb686122d 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -123,6 +123,7 @@ jobs: # - "Activation_Patching_in_TL_Demo" # - "Attribution_Patching_Demo" - "ARENA_Content" + - "Colab_Compatibility" - "BERT" - "Exploratory_Analysis_Demo" # - "Grokking_Demo" diff --git a/debugging/hf-tl-logit-comparator.ipynb b/debugging/hf-tl-logit-comparator.ipynb new file mode 100644 index 000000000..ee445c397 --- /dev/null +++ b/debugging/hf-tl-logit-comparator.ipynb @@ -0,0 +1,265 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Logit Comparator for HuggingFace and TransformerLens Outputs\n", + "This notebook is a quick and dirty tool to compare the logit outputs of a HuggingFace model and a TransformerLens model via several different metrics. It is intended to help debug issues with the TransformerLens model, such as bugs in the model's implementation. If you identify any issues, please open an issue on the [GitHub repository](https://github.com/TransformerLensOrg/TransformerLens)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer, AutoModelForCausalLM\n", + "from transformer_lens import HookedTransformer\n", + "import torch\n", + "import torch.nn.functional as F\n", + "\n", + "if torch.backends.mps.is_available():\n", + " device = \"mps\"\n", + "else:\n", + " device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "\n", + "torch.set_grad_enabled(False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Comparator Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [], + "source": [ + "model_name = \"EleutherAI/pythia-2.8b\" # You can change this to any model name\n", + "sentence = \"The quick brown fox\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from huggingface_hub import login\n", + "login(token=\"\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get Transformers Logits" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from transformers import AutoTokenizer, AutoModelForCausalLM\n", + "\n", + "def load_model(model_name=\"gpt2\"):\n", + " tokenizer = AutoTokenizer.from_pretrained(model_name)\n", + " model = AutoModelForCausalLM.from_pretrained(model_name)\n", + " return model, tokenizer\n", + "\n", + "def get_logits(model, tokenizer, sentence, device):\n", + " # Tokenize the input sentence\n", + " inputs = tokenizer(sentence, return_tensors=\"pt\")\n", + " \n", + " # Move inputs to the device\n", + " inputs = {k: v.to(device) for k, v in inputs.items()}\n", + " \n", + " # Generate the logits\n", + " with torch.no_grad():\n", + " outputs = model(**inputs)\n", + " \n", + " # Get the logits for all tokens\n", + " logits = outputs.logits\n", + " \n", + " return logits\n", + "\n", + "model, tokenizer = load_model(model_name)\n", + "model = model.to(device)\n", + "\n", + "hf_logits = get_logits(model, tokenizer, sentence, device)[:, -1, :]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Get TransformerLens Logits" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = HookedTransformer.from_pretrained_no_processing(model_name, device=device)\n", + "tokens = model.to_tokens(sentence, prepend_bos=False)\n", + "tl_logits = model(tokens)[:, -1, :]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Compare Logit Distributions\n", + "Various metrics are used to compare the logit distributions of the two models. We don't yet have standard values for what constitutes a \"good\" logit comparison, so we are working on establishing benchmarks." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"HF Logits Shape: {hf_logits.shape}\")\n", + "print(f\"TL Logits Shape: {tl_logits.shape}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tensor Comparison" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "are_close = torch.allclose(tl_logits, hf_logits, rtol=1e-5, atol=1e-3)\n", + "print(f\"Are the logits close? 
{are_close}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Mean Squared Error" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Compare the logits with MSE\n", + "mse = torch.nn.functional.mse_loss(hf_logits, tl_logits)\n", + "print(f\"MSE: {mse}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Maximum Absolute Difference" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "max_diff = torch.max(torch.abs(tl_logits - hf_logits))\n", + "print(f\"Max Diff: {max_diff}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Cosine Similarity" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cosine_sim = F.cosine_similarity(tl_logits, hf_logits, dim=-1).mean()\n", + "print(f\"Cosine Sim: {cosine_sim}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### KL Divergence" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def kl_div(logits1: torch.Tensor, logits2: torch.Tensor) -> torch.Tensor:\n", + " probs1 = F.softmax(logits1, dim=-1)\n", + " probs2 = F.softmax(logits2, dim=-1)\n", + " return F.kl_div(probs1.log(), probs2, reduction='batchmean')\n", + "\n", + "kl_tl_hf = kl_div(tl_logits, hf_logits)\n", + "kl_hf_tl = kl_div(hf_logits, tl_logits)\n", + "print(f\"KL(TL||HF): {kl_tl_hf}\")\n", + "print(f\"KL(HF||TL): {kl_hf_tl}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sae-l", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/demos/Colab_Compatibility.ipynb b/demos/Colab_Compatibility.ipynb new file mode 100644 index 000000000..fca3304bb --- /dev/null +++ b/demos/Colab_Compatibility.ipynb @@ -0,0 +1,531 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running as a Jupyter notebook - intended for development only!\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_48359/2396058561.py:18: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + " ipython.magic(\"load_ext autoreload\")\n", + "/var/folders/m3/z6c6rcdj1rbb2jh9vqpgvxg40000gn/T/ipykernel_48359/2396058561.py:19: DeprecationWarning: `magic(...)` is deprecated since IPython 0.13 (warning added in 8.1), use run_line_magic(magic_name, parameter_s).\n", + " ipython.magic(\"autoreload 2\")\n" + ] + } + ], + "source": [ + "# NBVAL_IGNORE_OUTPUT\n", + "# Janky code to do different setup when run in a Colab notebook vs VSCode\n", + "import os\n", + "\n", + "IN_GITHUB = os.getenv(\"GITHUB_ACTIONS\") == \"true\"\n", + "\n", + "try:\n", + " import google.colab\n", + " IN_COLAB = True\n", + " print(\"Running as a Colab notebook\")\n", + 
"except:\n", + " IN_COLAB = False\n", + " print(\"Running as a Jupyter notebook - intended for development only!\")\n", + " from IPython import get_ipython\n", + "\n", + " ipython = get_ipython()\n", + " # Code to automatically update the HookedTransformer code as its edited without restarting the kernel\n", + " ipython.magic(\"load_ext autoreload\")\n", + " ipython.magic(\"autoreload 2\")\n", + " \n", + "\n", + "\n", + "if IN_COLAB or IN_GITHUB:\n", + " # %pip install sentencepiece # Llama tokenizer requires sentencepiece\n", + " %pip install transformers>=4.31.0 # Llama requires transformers>=4.31.0 and transformers in turn requires Python 3.8\n", + " %pip install torch\n", + " %pip install tiktoken\n", + " %pip install transformer_lens\n", + " %pip install transformers_stream_generator\n", + " # !huggingface-cli login --token NEEL'S TOKEN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "from transformer_lens import HookedTransformer, HookedEncoderDecoder, HookedEncoder, loading\n", + "from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer\n", + "from typing import List\n", + "import gc\n", + "\n", + "untested_models = []\n", + "untested_models.extend(loading.OFFICIAL_MODEL_NAMES)\n", + "\n", + "print(\"TransformerLens currently supports \" + str(len(untested_models)) + \" models out of the box.\")\n", + "\n", + "GENERATE = True\n", + "# Fill this in if you have llama weights uploaded, and you with to test those models\n", + "LLAMA_MODEL_PATH = \"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def mark_models_as_tested(model_set: List[str]) -> None:\n", + " for model in model_set:\n", + " untested_models.remove(model)\n", + " \n", + "\n", + "def run_set(model_set: List[str], device=\"cuda\") -> None:\n", + " for model in model_set:\n", + " print(\"Testing \" + model)\n", + " tl_model = HookedTransformer.from_pretrained_no_processing(model, device=device)\n", + " if GENERATE:\n", + " print(tl_model.generate(\"Hello my name is\"))\n", + " del tl_model\n", + " gc.collect()\n", + " if IN_COLAB:\n", + " %rm -rf /root/.cache/huggingface/hub/models*\n", + "\n", + "def run_llama_set(model_set: List[str], weight_root: str, device=\"cuda\") -> None:\n", + " for model in model_set:\n", + " print(\"Testing \" + model)\n", + " # to run this, make sure weight root is the root that contains all models with the \n", + " # sub directories sharing the same name as the model in the list of models\n", + " tokenizer = LlamaTokenizer.from_pretrained(weight_root + model)\n", + " hf_model = LlamaForCausalLM.from_pretrained(weight_root + model, low_cpu_mem_usage=True)\n", + " tl_model = HookedTransformer.from_pretrained_no_processing(\n", + " model, \n", + " hf_model=hf_model,\n", + " device=device,\n", + " fold_ln=False,\n", + " center_writing_weights=False,\n", + " center_unembed=False,\n", + " tokenizer=tokenizer,\n", + " )\n", + " if GENERATE:\n", + " print(tl_model.generate(\"Hello my name is\"))\n", + " del tl_model\n", + " gc.collect()\n", + " if IN_COLAB:\n", + " %rm -rf /root/.cache/huggingface/hub/models*\n", + "\n", + "\n", + "def run_encoder_decoder_set(model_set: List[str], device=\"cuda\") -> None:\n", + " for model in model_set:\n", + " print(\"Testing \" + model)\n", + " tokenizer = AutoTokenizer.from_pretrained(model)\n", + " tl_model = HookedEncoderDecoder.from_pretrained(model, device=device)\n", + " if 
GENERATE:\n", + " # Originally from the t5 demo\n", + " prompt = \"Hello, how are you? \"\n", + " inputs = tokenizer(prompt, return_tensors=\"pt\")\n", + " input_ids = inputs[\"input_ids\"]\n", + " attention_mask = inputs[\"attention_mask\"]\n", + " decoder_input_ids = torch.tensor([[tl_model.cfg.decoder_start_token_id]]).to(input_ids.device)\n", + "\n", + "\n", + " while True:\n", + " logits = tl_model.forward(input=input_ids, one_zero_attention_mask=attention_mask, decoder_input=decoder_input_ids)\n", + " # logits.shape == (batch_size (1), predicted_pos, vocab_size)\n", + "\n", + " token_idx = torch.argmax(logits[0, -1, :]).item()\n", + " print(\"generated token: \\\"\", tokenizer.decode(token_idx), \"\\\", token id: \", token_idx, sep=\"\")\n", + "\n", + " # append token to decoder_input_ids\n", + " decoder_input_ids = torch.cat([decoder_input_ids, torch.tensor([[token_idx]]).to(input_ids.device)], dim=-1)\n", + "\n", + " # break if End-Of-Sequence token generated\n", + " if token_idx == tokenizer.eos_token_id:\n", + " break\n", + " del tl_model\n", + " gc.collect()\n", + " if IN_COLAB:\n", + " %rm -rf /root/.cache/huggingface/hub/models*\n", + "\n", + "def run_encoder_only_set(model_set: List[str], device=\"cuda\") -> None:\n", + " for model in model_set:\n", + " print(\"Testing \" + model)\n", + " tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n", + " tl_model = HookedEncoder.from_pretrained(model, device=device)\n", + "\n", + " if GENERATE:\n", + " # Slightly adapted version of the BERT demo\n", + " prompt = \"The capital of France is [MASK].\"\n", + "\n", + " input_ids = tokenizer(prompt, return_tensors=\"pt\")[\"input_ids\"]\n", + "\n", + " logprobs = tl_model(input_ids)[input_ids == tokenizer.mask_token_id].log_softmax(dim=-1)\n", + " prediction = tokenizer.decode(logprobs.argmax(dim=-1).item())\n", + "\n", + " print(f\"Prompt: {prompt}\")\n", + " print(f'Prediction: \"{prediction}\"')\n", + "\n", + " del tl_model\n", + " gc.collect()\n", + " if IN_COLAB:\n", + " %rm -rf /root/.cache/huggingface/hub/models*" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The following models can run in the T4 free environment\n", + "free_compatible = [\n", + " \"ai-forever/mGPT\",\n", + " \"ArthurConmy/redwood_attn_2l\",\n", + " \"bigcode/santacoder\",\n", + " \"bigscience/bloom-1b1\",\n", + " \"bigscience/bloom-560m\",\n", + " \"distilgpt2\",\n", + " \"EleutherAI/gpt-neo-1.3B\",\n", + " \"EleutherAI/gpt-neo-125M\",\n", + " \"EleutherAI/gpt-neo-2.7B\",\n", + " \"EleutherAI/pythia-1.4b\",\n", + " \"EleutherAI/pythia-1.4b-deduped\",\n", + " \"EleutherAI/pythia-1.4b-deduped-v0\",\n", + " \"EleutherAI/pythia-1.4b-v0\",\n", + " \"EleutherAI/pythia-14m\",\n", + " \"EleutherAI/pythia-160m\",\n", + " \"EleutherAI/pythia-160m-deduped\",\n", + " \"EleutherAI/pythia-160m-deduped-v0\",\n", + " \"EleutherAI/pythia-160m-seed1\",\n", + " \"EleutherAI/pythia-160m-seed2\",\n", + " \"EleutherAI/pythia-160m-seed3\",\n", + " \"EleutherAI/pythia-160m-v0\",\n", + " \"EleutherAI/pythia-1b\",\n", + " \"EleutherAI/pythia-1b-deduped\",\n", + " \"EleutherAI/pythia-1b-deduped-v0\",\n", + " \"EleutherAI/pythia-1b-v0\",\n", + " \"EleutherAI/pythia-31m\",\n", + " \"EleutherAI/pythia-410m\",\n", + " \"EleutherAI/pythia-410m-deduped\",\n", + " \"EleutherAI/pythia-410m-deduped-v0\",\n", + " \"EleutherAI/pythia-410m-v0\",\n", + " \"EleutherAI/pythia-70m\",\n", + " \"EleutherAI/pythia-70m-deduped\",\n", + " \"EleutherAI/pythia-70m-deduped-v0\",\n", 
+ " \"EleutherAI/pythia-70m-v0\",\n", + " \"facebook/opt-1.3b\",\n", + " \"facebook/opt-125m\",\n", + " \"gpt2\",\n", + " \"gpt2-large\",\n", + " \"gpt2-medium\",\n", + " \"gpt2-xl\",\n", + " \"meta-llama/Llama-3.2-1B\",\n", + " \"meta-llama/Llama-3.2-1B-Instruct\",\n", + " \"microsoft/phi-1\",\n", + " \"microsoft/phi-1_5\",\n", + " \"NeelNanda/Attn-Only-2L512W-Shortformer-6B-big-lr\",\n", + " \"NeelNanda/Attn_Only_1L512W_C4_Code\",\n", + " \"NeelNanda/Attn_Only_2L512W_C4_Code\",\n", + " \"NeelNanda/Attn_Only_3L512W_C4_Code\",\n", + " \"NeelNanda/Attn_Only_4L512W_C4_Code\",\n", + " \"NeelNanda/GELU_1L512W_C4_Code\",\n", + " \"NeelNanda/GELU_2L512W_C4_Code\",\n", + " \"NeelNanda/GELU_3L512W_C4_Code\",\n", + " \"NeelNanda/GELU_4L512W_C4_Code\",\n", + " \"NeelNanda/SoLU_10L1280W_C4_Code\",\n", + " \"NeelNanda/SoLU_10L_v22_old\",\n", + " \"NeelNanda/SoLU_12L1536W_C4_Code\",\n", + " \"NeelNanda/SoLU_12L_v23_old\",\n", + " \"NeelNanda/SoLU_1L512W_C4_Code\",\n", + " \"NeelNanda/SoLU_1L512W_Wiki_Finetune\",\n", + " \"NeelNanda/SoLU_1L_v9_old\",\n", + " \"NeelNanda/SoLU_2L512W_C4_Code\",\n", + " \"NeelNanda/SoLU_2L_v10_old\",\n", + " \"NeelNanda/SoLU_3L512W_C4_Code\",\n", + " \"NeelNanda/SoLU_4L512W_C4_Code\",\n", + " \"NeelNanda/SoLU_4L512W_Wiki_Finetune\",\n", + " \"NeelNanda/SoLU_4L_v11_old\",\n", + " \"NeelNanda/SoLU_6L768W_C4_Code\",\n", + " \"NeelNanda/SoLU_6L_v13_old\",\n", + " \"NeelNanda/SoLU_8L1024W_C4_Code\",\n", + " \"NeelNanda/SoLU_8L_v21_old\",\n", + " \"Qwen/Qwen-1_8B\",\n", + " \"Qwen/Qwen-1_8B-Chat\",\n", + " \"Qwen/Qwen1.5-0.5B\",\n", + " \"Qwen/Qwen1.5-0.5B-Chat\",\n", + " \"Qwen/Qwen1.5-1.8B\",\n", + " \"Qwen/Qwen1.5-1.8B-Chat\",\n", + " \"Qwen/Qwen2-0.5B\",\n", + " \"Qwen/Qwen2-0.5B-Instruct\",\n", + " \"Qwen/Qwen2-1.5B\",\n", + " \"Qwen/Qwen2-1.5B-Instruct\",\n", + " \"roneneldan/TinyStories-1Layer-21M\",\n", + " \"roneneldan/TinyStories-1M\",\n", + " \"roneneldan/TinyStories-28M\",\n", + " \"roneneldan/TinyStories-2Layers-33M\",\n", + " \"roneneldan/TinyStories-33M\",\n", + " \"roneneldan/TinyStories-3M\",\n", + " \"roneneldan/TinyStories-8M\",\n", + " \"roneneldan/TinyStories-Instruct-1M\",\n", + " \"roneneldan/TinyStories-Instruct-28M\",\n", + " \"roneneldan/TinyStories-Instruct-2Layers-33M\",\n", + " \"roneneldan/TinyStories-Instruct-33M\",\n", + " \"roneneldan/TinyStories-Instruct-3M\",\n", + " \"roneneldan/TinyStories-Instruct-8M\",\n", + " \"roneneldan/TinyStories-Instuct-1Layer-21M\",\n", + " \"stanford-crfm/alias-gpt2-small-x21\",\n", + " \"stanford-crfm/arwen-gpt2-medium-x21\",\n", + " \"stanford-crfm/battlestar-gpt2-small-x49\",\n", + " \"stanford-crfm/beren-gpt2-medium-x49\",\n", + " \"stanford-crfm/caprica-gpt2-small-x81\",\n", + " \"stanford-crfm/celebrimbor-gpt2-medium-x81\",\n", + " \"stanford-crfm/darkmatter-gpt2-small-x343\",\n", + " \"stanford-crfm/durin-gpt2-medium-x343\",\n", + " \"stanford-crfm/eowyn-gpt2-medium-x777\",\n", + " \"stanford-crfm/expanse-gpt2-small-x777\",\n", + "]\n", + "\n", + "if IN_COLAB:\n", + " run_set(free_compatible)\n", + " \n", + "mark_models_as_tested(free_compatible)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "paid_gpu_models = [\n", + " \"01-ai/Yi-6B\",\n", + " \"01-ai/Yi-6B-Chat\",\n", + " \"bigscience/bloom-1b7\",\n", + " \"bigscience/bloom-3b\",\n", + " \"bigscience/bloom-7b1\",\n", + " \"codellama/CodeLlama-7b-hf\",\n", + " \"codellama/CodeLlama-7b-Instruct-hf\",\n", + " \"codellama/CodeLlama-7b-Python-hf\",\n", + " \"EleutherAI/pythia-2.8b\",\n", + " 
\"EleutherAI/pythia-2.8b-deduped\",\n", + " \"EleutherAI/pythia-2.8b-deduped-v0\",\n", + " \"EleutherAI/pythia-2.8b-v0\",\n", + " \"EleutherAI/pythia-6.9b\",\n", + " \"EleutherAI/pythia-6.9b-deduped\",\n", + " \"EleutherAI/pythia-6.9b-deduped-v0\",\n", + " \"EleutherAI/pythia-6.9b-v0\",\n", + " \"facebook/opt-2.7b\",\n", + " \"facebook/opt-6.7b\",\n", + " \"google/gemma-2-2b\",\n", + " \"google/gemma-2-2b-it\",\n", + " \"google/gemma-2b\",\n", + " \"google/gemma-2b-it\",\n", + " \"google/gemma-7b\",\n", + " \"google/gemma-7b-it\",\n", + " \"meta-llama/Llama-2-7b-chat-hf\",\n", + " \"meta-llama/Llama-2-7b-hf\",\n", + " \"meta-llama/Llama-3.1-8B\",\n", + " \"meta-llama/Llama-3.1-8B-Instruct\",\n", + " \"meta-llama/Llama-3.2-3B\",\n", + " \"meta-llama/Llama-3.2-3B-Instruct\",\n", + " \"meta-llama/Meta-Llama-3-8B\",\n", + " \"meta-llama/Meta-Llama-3-8B-Instruct\",\n", + " \"microsoft/phi-2\",\n", + " \"microsoft/Phi-3-mini-4k-instruct\",\n", + " \"mistralai/Mistral-7B-Instruct-v0.1\",\n", + " \"mistralai/Mistral-7B-v0.1\",\n", + " \"mistralai/Mistral-Nemo-Base-2407\",\n", + " \"Qwen/Qwen-7B\",\n", + " \"Qwen/Qwen-7B-Chat\",\n", + " \"Qwen/Qwen1.5-4B\",\n", + " \"Qwen/Qwen1.5-4B-Chat\",\n", + " \"Qwen/Qwen1.5-7B\",\n", + " \"Qwen/Qwen1.5-7B-Chat\",\n", + " \"Qwen/Qwen2-7B\",\n", + " \"Qwen/Qwen2-7B-Instruct\",\n", + " \"stabilityai/stablelm-base-alpha-3b\",\n", + " \"stabilityai/stablelm-base-alpha-7b\",\n", + " \"stabilityai/stablelm-tuned-alpha-3b\",\n", + " \"stabilityai/stablelm-tuned-alpha-7b\",\n", + "]\n", + "\n", + "if IN_COLAB:\n", + " run_set(paid_gpu_models)\n", + " \n", + "mark_models_as_tested(paid_gpu_models)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "paid_cpu_models = [\n", + " \"EleutherAI/gpt-j-6B\",\n", + " \"EleutherAI/gpt-neox-20b\",\n", + " \"EleutherAI/pythia-12b\",\n", + " \"EleutherAI/pythia-12b-deduped\",\n", + " \"EleutherAI/pythia-12b-deduped-v0\",\n", + " \"EleutherAI/pythia-12b-v0\",\n", + " \"facebook/opt-13b\",\n", + " \"google/gemma-2-9b\",\n", + " \"google/gemma-2-9b-it\",\n", + " \"meta-llama/Llama-2-13b-chat-hf\",\n", + " \"meta-llama/Llama-2-13b-hf\",\n", + " \"Qwen/Qwen-14B\",\n", + " \"Qwen/Qwen-14B-Chat\",\n", + " \"Qwen/Qwen1.5-14B\",\n", + " \"Qwen/Qwen1.5-14B-Chat\",\n", + "]\n", + "\n", + "if IN_COLAB:\n", + " run_set(paid_cpu_models, \"cpu\")\n", + " \n", + "mark_models_as_tested(paid_cpu_models)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "incompatible_models = [\n", + " \"01-ai/Yi-34B\",\n", + " \"01-ai/Yi-34B-Chat\",\n", + " \"facebook/opt-30b\",\n", + " \"facebook/opt-66b\",\n", + " \"google/gemma-2-27b\",\n", + " \"google/gemma-2-27b-it\",\n", + " \"meta-llama/Llama-2-70b-chat-hf\",\n", + " \"meta-llama/Llama-3.1-70B\",\n", + " \"meta-llama/Llama-3.1-70B-Instruct\",\n", + " \"meta-llama/Meta-Llama-3-70B\",\n", + " \"meta-llama/Meta-Llama-3-70B-Instruct\",\n", + " \"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n", + " \"mistralai/Mixtral-8x7B-v0.1\",\n", + "]\n", + "\n", + "mark_models_as_tested(incompatible_models)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "# The following models take a few extra steps to function. Check the official demo for more\n", + "# information on how to use. 7b and 13b will work in the paid environment. 
30b and 65b will not work\n", + "# in Colab\n", + "not_hosted_models = [\n", + " \"llama-7b-hf\",\n", + " \"llama-13b-hf\",\n", + " \"llama-30b-hf\",\n", + " \"llama-65b-hf\",\n", + "]\n", + "\n", + "if LLAMA_MODEL_PATH:\n", + " run_llama_set(not_hosted_models, LLAMA_MODEL_PATH)\n", + "\n", + "mark_models_as_tested(not_hosted_models)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "# These all work on the free version of Colab\n", + "encoder_decoders = [\n", + " \"google-t5/t5-base\",\n", + " \"google-t5/t5-large\",\n", + " \"google-t5/t5-small\",\n", + "]\n", + "if IN_COLAB:\n", + " run_encoder_decoder_set(encoder_decoders)\n", + "\n", + "mark_models_as_tested(encoder_decoders)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# This model works on the free version of Colab\n", + "encoder_only_models = [\"bert-base-cased\"]\n", + "\n", + "if IN_COLAB:\n", + " run_encoder_only_set(encoder_only_models)\n", + "\n", + "mark_models_as_tested(encoder_only_models)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "broken_models = [\n", + " \"Baidicoot/Othello-GPT-Transformer-Lens\",\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Baidicoot/Othello-GPT-Transformer-Lens\n" + ] + } + ], + "source": [ + "# Any models listed in the cell below have not been tested. This should always remain blank. If your\n", + "# PR fails due to this notebook, most likely you need to check any new model changes to ensure that\n", + "# this notebook is up to date.\n", + "print(*untested_models, sep = '\\n')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/source/_static/TransformerLens_Diagram.svg b/docs/source/_static/TransformerLens_Diagram.svg new file mode 100644 index 000000000..fb7a5c65d --- /dev/null +++ b/docs/source/_static/TransformerLens_Diagram.svg @@ -0,0 +1,12396 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 
[SVG markup omitted. The added file renders the "Diagram of GPT-2 style LLM in TransformerLens notation" (weight matrices in sans serif, activation tensors in Times New Roman): it traces the residual stream through the encoding step (InputTokens, W_Embedding, W_positional, hook_embed, hook_resid_pre), the attention block (ln1.hook_normalized, W_Query/W_Key/W_Value/W_Output with b_Query/b_Key/b_Value/b_Output, attn.hook_q/k/v/z, hook_attn_scores, attn.hook_pattern, hook_attn_out, hook_resid_mid, with RoPE and masked softmax annotations), the MLP block (ln2.hook_normalized, W_inmlp, GeLU, W_outmlp, mlp.hook_pre, mlp.hook_post, mlp.hook_out, hook_resid_post), and the final-layer unembedding (ln_final.hook_normalized, W_Unembed, b_Unembed, logits, probabilities), labelling each tensor with its dimensions (d_model, d_head, n_head, d_vocab, n_ctx, seq_len).]
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + hook_resid_pre + + + + + + diff --git a/docs/source/index.md b/docs/source/index.md index f1b8737d5..4851b4334 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -18,6 +18,8 @@ I used to work for the [Anthropic interpretability team](https://transformer-cir The core features were heavily inspired by the interface to [Anthropic's excellent Garcon tool](https://transformer-circuits.pub/2021/garcon/index.html). Credit to Nelson Elhage and Chris Olah for building Garcon and showing me the value of good infrastructure for enabling exploratory research! +A great place to start is to take a look at a helpful diagram of [all weight matrices and activation tensors with TransformerLens notation](_static/TransformerLens_Diagram.svg) courtesy of [Austin Kozlowski](https://github.com/akozlo). Another helpful tool to help you get going as quickly as possible is our [Colab Compatability Demo](https://github.com/TransformerLensOrg/TransformerLens/tree/main/demos/Colab_Compatibility.ipynb), which will give you a good idea of what you can do in various Colab environments. + ```{toctree} :hidden: :caption: Introduction diff --git a/poetry.lock b/poetry.lock index 422a4b2e7..300dd1138 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4867,21 +4867,21 @@ tutorials = ["matplotlib", "pandas", "tabulate"] [[package]] name = "typeguard" -version = "4.2.1" +version = "4.4.0" description = "Run-time type checker for Python" optional = false python-versions = ">=3.8" files = [ - {file = "typeguard-4.2.1-py3-none-any.whl", hash = "sha256:7da3bd46e61f03e0852f8d251dcbdc2a336aa495d7daff01e092b55327796eb8"}, - {file = "typeguard-4.2.1.tar.gz", hash = "sha256:c556a1b95948230510070ca53fa0341fb0964611bd05d598d87fb52115d65fee"}, + {file = "typeguard-4.4.0-py3-none-any.whl", hash = "sha256:8ca34c14043f53b2caae7040549ba431770869bcd6287cfa8239db7ecb882b4a"}, + {file = "typeguard-4.4.0.tar.gz", hash = "sha256:463bd8697a65a4aa576a63767c369b1ecfba8a5ba735edfe3223127b6ecfa28c"}, ] [package.dependencies] importlib-metadata = {version = ">=3.6", markers = "python_version < \"3.10\""} -typing-extensions = {version = ">=4.10.0", markers = "python_version < \"3.13\""} +typing-extensions = ">=4.10.0" [package.extras] -doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)"] +doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme (>=1.3.0)"] test = ["coverage[toml] (>=7)", "mypy (>=1.2.0)", "pytest (>=7)"] [[package]] @@ -5325,4 +5325,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.8,<4.0" -content-hash = "bf948abb46e5282633d5e369c78b5fb21e1a75e2221bcb1088630360893efca3" +content-hash = "fcebd987bb0fd59d2be08a9ffd6ea6e22373441f4d347d841669c69d5616e797" diff --git a/pyproject.toml b/pyproject.toml index 5abcfc481..558d9d7ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,7 @@ transformers=">=4.37.2" typing-extensions="*" wandb=">=0.13.5" + typeguard = "^4.2" [tool.poetry.group] [tool.poetry.group.dev.dependencies] diff --git a/tests/acceptance/test_hooked_transformer.py b/tests/acceptance/test_hooked_transformer.py index 9d9e2bb19..ac7555ad6 100644 --- a/tests/acceptance/test_hooked_transformer.py +++ b/tests/acceptance/test_hooked_transformer.py @@ -66,7 +66,7 @@ "redwood_attn_2l": 10.530948638916016, "solu-1l": 
5.256411552429199, "tiny-stories-33M": 12.203617095947266, - "bloom-560m": 4.1953, + "bloom-560m": 5.237126350402832, } no_processing = [ @@ -175,6 +175,26 @@ def test_from_pretrained_revision(): raise AssertionError("Should have raised an error") +def test_bloom_similarity_with_hf_model_with_kv_cache_activated(): + tf_model = HookedTransformer.from_pretrained( + "bigscience/bloom-560m", default_prepend_bos=False, device="cpu" + ) + hf_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m") + hf_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m") + + output_tf = tf_model.generate( + text, do_sample=False, use_past_kv_cache=True, verbose=False, max_new_tokens=10 + ) + output_hf_tokens = hf_model.generate( + hf_tokenizer(text, return_tensors="pt").input_ids, + do_sample=False, + max_new_tokens=10, + ) + output_hf_str = hf_tokenizer.decode(output_hf_tokens[0], skip_special_tokens=True) + + assert output_tf == output_hf_str + + def check_norm_folding( model_name, hf_model=None, diff --git a/tests/integration/test_kv_cache.py b/tests/integration/test_kv_cache.py index a98ba7de6..baab6696a 100644 --- a/tests/integration/test_kv_cache.py +++ b/tests/integration/test_kv_cache.py @@ -213,6 +213,28 @@ def test_freeze_cache(pretrained): assert not t.allclose(with_cache_logits_1, with_cache_2_logits_1, atol=atol) +def test_kv_cache_with_custom_attention_mask(pretrained): + model, atol = pretrained + prompt_pre = "An apple" + prompt_post = " a day keeps junk the" + prompt_whole = "An apple a day keeps the" + tokens_pre = model.to_tokens(prompt_pre) + tokens_post = model.to_tokens(prompt_post, prepend_bos=False) + tokens_whole = model.to_tokens(prompt_whole) + correct_logits = model(tokens_whole) + + past_kv_cache = HookedTransformerKeyValueCache.init_cache( + model.cfg, model.cfg.device, tokens_pre.shape[0] + ) + model(tokens_pre, past_kv_cache=past_kv_cache) + exp_logits = model( + tokens_post, + attention_mask=t.tensor([[1, 1, 1, 0, 1]], device=model.cfg.device), + past_kv_cache=past_kv_cache, + ) + assert t.allclose(correct_logits[:, -1], exp_logits[:, -1], atol=atol) + + def test_kv_cache_and_start_at_layer(pretrained): model, atol = pretrained pre_prompt = "I went to Staten Island," diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py index 3fdd1c1ed..8b07f5046 100644 --- a/transformer_lens/HookedTransformer.py +++ b/transformer_lens/HookedTransformer.py @@ -8,6 +8,7 @@ alteration of activations in individual components like attention heads and MLP layers, facilitating a deeper understanding of the internal workings of transformers like GPT-2. """ + import logging import os from typing import ( @@ -297,23 +298,25 @@ def input_to_embed( if tokens.device.type != self.cfg.device: tokens = tokens.to(devices.get_device_for_block_index(0, self.cfg)) - if attention_mask is not None: + if ( + (self.tokenizer and self.tokenizer.padding_side == "left") + or attention_mask is not None + or past_kv_cache is not None + ): + # This means we need to have an explicit attention mask. + if attention_mask is None: + # If the padding side is left or we are using caching, we need to compute the attention + # mask for the adjustment of absolute positional embeddings and attention masking so + # that pad tokens are not attended. 
+ if prepend_bos is USE_DEFAULT_VALUE: + prepend_bos = self.cfg.default_prepend_bos + attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos) + assert attention_mask.shape == tokens.shape, ( f"Attention mask shape {attention_mask.shape} does not match tokens shape " f"{tokens.shape}" ) attention_mask = attention_mask.to(devices.get_device_for_block_index(0, self.cfg)) - elif ( - self.tokenizer and self.tokenizer.padding_side == "left" - ) or past_kv_cache is not None: - # If the padding side is left or we are using caching, we need to compute the attention - # mask for the adjustment of absolute positional embeddings and attention masking so - # that pad tokens are not attended. - - if prepend_bos is USE_DEFAULT_VALUE: - prepend_bos = self.cfg.default_prepend_bos - attention_mask = utils.get_attention_mask(self.tokenizer, tokens, prepend_bos) - if past_kv_cache is not None: # past_kv_cache is not None, so we're doing caching. # We need to extend the previous attention_mask. @@ -1080,7 +1083,7 @@ def from_pretrained( tokenizer: Optional[PreTrainedTokenizerBase] = None, move_to_device: bool = True, fold_value_biases: bool = True, - default_prepend_bos: bool = True, + default_prepend_bos: Optional[bool] = None, default_padding_side: Literal["left", "right"] = "right", dtype="float32", first_n_layers: Optional[int] = None, @@ -1202,11 +1205,15 @@ def from_pretrained( remains exactly the same, and so is just broadcast across the destination positions. default_prepend_bos: Default behavior of whether to prepend the BOS token when the methods of HookedTransformer process input text to tokenize (only - when input is a string). Defaults to True - even for models not explicitly trained - with this, heads often use the first position as a resting position and accordingly - lose information from the first token, so this empirically seems to give better - results. To change the default behavior to False, pass in default_prepend_bos=False. - Note that you can also locally override the default behavior by passing in + when input is a string). + Resolution order for default_prepend_bos: + 1. If user passes value explicitly, use that value + 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False) + 3. Global default (True) + + Even for models not explicitly trained with the BOS token, heads often use the first position as a resting position + and accordingly lose information from the first token, so this empirically seems to give better + results. Note that you can also locally override the default behavior by passing in prepend_bos=True/False when you call a method that processes the input string. from_pretrained_kwargs: Any other optional argument passed to HuggingFace's from_pretrained (e.g. "cache_dir" or "torch_dtype"). Also passed to @@ -1220,6 +1227,10 @@ def from_pretrained( "right". first_n_layers: If specified, only load the first n layers of the model. """ + if model_name.lower().startswith("t5"): + raise RuntimeError( + "Execution stopped: Please use HookedEncoderDecoder to load T5 models instead of HookedTransformer." 
+ ) assert not ( from_pretrained_kwargs.get("load_in_8bit", False) @@ -1346,7 +1357,7 @@ def from_pretrained_no_processing( refactor_factored_attn_matrices=False, fold_value_biases=False, dtype=torch.float32, - default_prepend_bos=True, + default_prepend_bos=None, default_padding_side="right", **from_pretrained_kwargs, ): diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py index e2fdc532e..4458705de 100644 --- a/transformer_lens/HookedTransformerConfig.py +++ b/transformer_lens/HookedTransformerConfig.py @@ -181,6 +181,18 @@ class HookedTransformerConfig: output_logits_soft_cap (float): An optional softcap for output logits, currently only used in Gemma-2 (see attn_scores_soft_cap for details). Defaults to -1.0, which means not set. + use_NTK_by_parts_rope (bool): Whether to apply the "NTK-by-parts" method when using Rotary + Positional Embedding. This method adjusts the interpolation based on frequency factors + for different parts of the hidden dimensions. See Section 3.2 in + https://arxiv.org/pdf/2309.00071 for details. Defaults to False. + NTK_by_parts_low_freq_factor (float): The threshold applied to low-frequency hidden + dimensions during interpolation when using the "NTK-by-parts" method. Defaults to 1.0. + NTK_by_parts_high_freq_factor (float): The threshold applied to high-frequency hidden + dimensions during interpolation in the "NTK-by-parts" method. Defaults to 4.0. + NTK_by_parts_factor (float): The overall factor used in the "NTK-by-parts" method that + affects the rate of change between low and high-frequency interpolation strategies. + Defaults to 8.0. + """ @@ -246,6 +258,10 @@ class HookedTransformerConfig: use_normalization_before_and_after: bool = False attn_scores_soft_cap: float = -1.0 output_logits_soft_cap: float = -1.0 + use_NTK_by_parts_rope: bool = False + NTK_by_parts_low_freq_factor: float = 1.0 + NTK_by_parts_high_freq_factor: float = 4.0 + NTK_by_parts_factor: float = 8.0 def __post_init__(self): if self.n_heads == -1: diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index 3146de0c2..347548f34 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -1,3 +1,4 @@ +import math from abc import ABC from typing import Dict, Optional, Tuple, Union @@ -228,8 +229,9 @@ def forward( self.cfg.n_heads, key_ctx, self.cfg.device ) + # Take the last query_ctx positions so it also works with past_kv_cache attn_scores += self.alibi[ - :, :query_ctx, :key_ctx + :, -query_ctx:, :key_ctx ] # [batch, head_index, query_pos, key_pos] elif self.cfg.positional_embedding_type == "relative_positional_bias": if position_bias is None: @@ -295,17 +297,19 @@ def forward( ) ) else: + # Add singleton dimensions to make shapes compatible for broadcasting: w = einops.rearrange( self.W_O, - "head_index d_head d_model -> d_model head_index d_head", + "head_index d_head d_model -> 1 1 head_index d_head d_model", ) - result = self.hook_result( - einops.einsum( - z, - w, - "... head_index d_head, d_model head_index d_head -> ... 
head_index d_model", - ) - ) # [batch, pos, head_index, d_model] + z = einops.rearrange( + z, "batch pos head_index d_head -> batch pos head_index d_head 1" + ) + + # Multiply the z tensor by the W_O tensor, summing over the d_head dimension + unhooked_result = (z * w).sum(-2) + + result = self.hook_result(unhooked_result) # [batch, pos, head_index, d_model] out = ( einops.reduce(result, "batch position index model->batch position model", "sum") + self.b_O @@ -478,8 +482,33 @@ def calculate_sin_cos_rotary( pos = torch.arange(n_ctx, dtype=high_precision) dim = torch.arange(rotary_dim // 2, dtype=high_precision) - # A set of frequencies evenly spaced in log space - freq = base ** (dim / (rotary_dim / 2)) + # Llama-3.1 uses NTK-by-Parts Rotary Embedding introduced in Section 3.2 in https://arxiv.org/pdf/2309.00071 + # Implementation copied from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/modeling_rope_utils.py#L310 + if self.cfg.use_NTK_by_parts_rope: + inv_freq = 1.0 / ( + base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim) + ) + factor = self.cfg.NTK_by_parts_factor + low_freq_factor = self.cfg.NTK_by_parts_low_freq_factor + high_freq_factor = self.cfg.NTK_by_parts_high_freq_factor + old_context_len = n_ctx + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + wavelen = 2 * math.pi / inv_freq + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) + smooth_factor = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + smoothed_inv_freq = ( + 1 - smooth_factor + ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + freq = 1 / inv_freq_llama + else: + freq = base ** (dim / (rotary_dim / 2)) if self.cfg.rotary_adjacent_pairs: freq = einops.repeat(freq, "d -> (d 2)") else: @@ -663,7 +692,11 @@ def create_alibi_bias( n_heads, device ) - # The ALiBi bias is then m * slope_matrix - alibi_bias = torch.einsum("ij,k->kij", slope, multipliers) + # Add singleton dimensions to make shapes compatible for broadcasting: + slope = einops.rearrange(slope, "query key -> 1 query key") + multipliers = einops.rearrange(multipliers, "head_idx -> head_idx 1 1") + + # Element-wise multiplication of the slope and multipliers + alibi_bias = multipliers * slope return alibi_bias diff --git a/transformer_lens/hook_points.py b/transformer_lens/hook_points.py index ec718810a..a961403b0 100644 --- a/transformer_lens/hook_points.py +++ b/transformer_lens/hook_points.py @@ -89,6 +89,7 @@ def add_hook( is_permanent: bool = False, level: Optional[int] = None, prepend: bool = False, + skip_verbose_naming=False, ) -> None: """ Hook format is fn(activation, hook_name) @@ -108,9 +109,10 @@ def full_hook( module_output = module_output[0] return hook(module_output, hook=self) - full_hook.__name__ = ( - hook.__repr__() - ) # annotate the `full_hook` with the string representation of the `hook` function + if not skip_verbose_naming: + full_hook.__name__ = ( + hook.__repr__() + ) # annotate the `full_hook` with the string representation of the `hook` function if dir == "fwd": pt_handle = self.register_forward_hook(full_hook) @@ -261,6 +263,7 @@ def check_and_add_hook( is_permanent: bool = False, level: Union[int, None] = None, prepend: bool = False, + skip_verbose_naming: 
bool = False, ) -> None: """Runs checks on the hook, and then adds it to the hook point""" @@ -272,7 +275,14 @@ def check_and_add_hook( is_permanent=is_permanent, prepend=prepend, ) - hook_point.add_hook(hook, dir=dir, is_permanent=is_permanent, level=level, prepend=prepend) + hook_point.add_hook( + hook, + dir=dir, + is_permanent=is_permanent, + level=level, + prepend=prepend, + skip_verbose_naming=skip_verbose_naming, + ) def check_hooks_to_add( self, @@ -294,6 +304,7 @@ def add_hook( is_permanent: bool = False, level: Union[int, None] = None, prepend: bool = False, + skip_verbose_naming: bool = False, ) -> None: if isinstance(name, str): hook_point = self.mod_dict[name] @@ -308,6 +319,7 @@ def add_hook( is_permanent=is_permanent, level=level, prepend=prepend, + skip_verbose_naming=skip_verbose_naming, ) else: # Otherwise, name is a Boolean function on names @@ -321,6 +333,7 @@ def add_hook( is_permanent=is_permanent, level=level, prepend=prepend, + skip_verbose_naming=skip_verbose_naming, ) def add_perma_hook( @@ -331,15 +344,24 @@ def add_perma_hook( ) -> None: self.add_hook(name, hook, dir=dir, is_permanent=True) - def _enable_hook_with_name(self, name: str, hook: Callable, dir: Literal["fwd", "bwd"]): + def _enable_hook_with_name( + self, + name: str, + hook: Callable, + dir: Literal["fwd", "bwd"], + skip_verbose_naming: bool = False, + ): """This function takes a key for the mod_dict and enables the related hook for that module Args: name (str): The module name hook (Callable): The hook to add dir (Literal["fwd", "bwd"]): The direction for the hook + skip_verbose_naming (bool): If True, skips the assignment of the string representation of `hook` to `full_hook.__name__`. """ - self.mod_dict[name].add_hook(hook, dir=dir, level=self.context_level) + self.mod_dict[name].add_hook( + hook, dir=dir, level=self.context_level, skip_verbose_naming=skip_verbose_naming + ) def _enable_hooks_for_points( self, @@ -347,6 +369,7 @@ def _enable_hooks_for_points( enabled: Callable, hook: Callable, dir: Literal["fwd", "bwd"], + skip_verbose_naming: bool = False, ): """Enables hooks for a list of points @@ -355,24 +378,40 @@ def _enable_hooks_for_points( enabled (Callable): _description_ hook (Callable): _description_ dir (Literal["fwd", "bwd"]): _description_ + skip_verbose_naming (bool): If True, skips the assignment of the string representation of `hook` to `full_hook.__name__`. """ for hook_name, hook_point in hook_points: if enabled(hook_name): - hook_point.add_hook(hook, dir=dir, level=self.context_level) + hook_point.add_hook( + hook, dir=dir, level=self.context_level, skip_verbose_naming=skip_verbose_naming + ) - def _enable_hook(self, name: Union[str, Callable], hook: Callable, dir: Literal["fwd", "bwd"]): + def _enable_hook( + self, + name: Union[str, Callable], + hook: Callable, + dir: Literal["fwd", "bwd"], + skip_verbose_naming: bool = False, + ): """Enables an individual hook on a hook point Args: name (str): The name of the hook hook (Callable): The actual hook dir (Literal["fwd", "bwd"], optional): The direction of the hook. Defaults to "fwd". + skip_verbose_naming (bool): If True, skips the assignment of the string representation of `hook` to `full_hook.__name__`. 
""" if isinstance(name, str): - self._enable_hook_with_name(name=name, hook=hook, dir=dir) + self._enable_hook_with_name( + name=name, hook=hook, dir=dir, skip_verbose_naming=skip_verbose_naming + ) else: self._enable_hooks_for_points( - hook_points=self.hook_dict.items(), enabled=name, hook=hook, dir=dir + hook_points=self.hook_dict.items(), + enabled=name, + hook=hook, + dir=dir, + skip_verbose_naming=skip_verbose_naming, ) @contextmanager @@ -382,6 +421,7 @@ def hooks( bwd_hooks: List[Tuple[Union[str, Callable], Callable]] = [], reset_hooks_end: bool = True, clear_contexts: bool = False, + skip_verbose_naming: bool = False, ): """ A context manager for adding temporary hooks to the model. @@ -392,6 +432,7 @@ def hooks( bwd_hooks: Same as fwd_hooks, but for the backward pass. reset_hooks_end (bool): If True, removes all hooks added by this context manager when the context manager exits. clear_contexts (bool): If True, clears hook contexts whenever hooks are reset. + skip_verbose_naming (bool): If True, skips the assignment of the string representation of `hook` to `full_hook.__name__`. Example: @@ -404,9 +445,13 @@ def hooks( self.context_level += 1 for name, hook in fwd_hooks: - self._enable_hook(name=name, hook=hook, dir="fwd") + self._enable_hook( + name=name, hook=hook, dir="fwd", skip_verbose_naming=skip_verbose_naming + ) for name, hook in bwd_hooks: - self._enable_hook(name=name, hook=hook, dir="bwd") + self._enable_hook( + name=name, hook=hook, dir="bwd", skip_verbose_naming=skip_verbose_naming + ) yield self finally: if reset_hooks_end: diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index cc0295323..aa544786f 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -144,9 +144,9 @@ "meta-llama/Llama-2-13b-hf", "meta-llama/Llama-2-13b-chat-hf", "meta-llama/Llama-2-70b-chat-hf", - "CodeLlama-7b-hf", - "CodeLlama-7b-Python-hf", - "CodeLlama-7b-Instruct-hf", + "codellama/CodeLlama-7b-hf", + "codellama/CodeLlama-7b-Python-hf", + "codellama/CodeLlama-7b-Instruct-hf", "meta-llama/Meta-Llama-3-8B", "meta-llama/Meta-Llama-3-8B-Instruct", "meta-llama/Meta-Llama-3-70B", @@ -155,6 +155,10 @@ "meta-llama/Llama-3.2-3B", "meta-llama/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.2-3B-Instruct", + "meta-llama/Llama-3.1-70B", + "meta-llama/Llama-3.1-8B", + "meta-llama/Llama-3.1-8B-Instruct", + "meta-llama/Llama-3.1-70B-Instruct", "Baidicoot/Othello-GPT-Transformer-Lens", "bert-base-cased", "roneneldan/TinyStories-1M", @@ -177,6 +181,7 @@ "stabilityai/stablelm-tuned-alpha-7b", "mistralai/Mistral-7B-v0.1", "mistralai/Mistral-7B-Instruct-v0.1", + "mistralai/Mistral-Nemo-Base-2407", "mistralai/Mixtral-8x7B-v0.1", "mistralai/Mixtral-8x7B-Instruct-v0.1", "bigscience/bloom-560m", @@ -562,12 +567,12 @@ "meta-llama/Llama-2-13b-chat-hf", ], "meta-llama/Llama-2-70b-chat-hf": ["Llama-2-70b-chat", "meta-llama-2-70b-chat-hf"], - "CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"], - "CodeLlama-7b-Python-hf": [ + "codellama/CodeLlama-7b-hf": ["CodeLlamallama-2-7b", "codellama/CodeLlama-7b-hf"], + "codellama/CodeLlama-7b-Python-hf": [ "CodeLlama-7b-python", "codellama/CodeLlama-7b-Python-hf", ], - "CodeLlama-7b-Instruct-hf": [ + "codellama/CodeLlama-7b-Instruct-hf": [ "CodeLlama-7b-instruct", "codellama/CodeLlama-7b-Instruct-hf", ], @@ -604,6 +609,7 @@ ], "mistralai/Mistral-7B-v0.1": ["mistral-7b"], "mistralai/Mistral-7B-Instruct-v0.1": ["mistral-7b-instruct"], + 
"mistralai/Mistral-Nemo-Base-2407": ["mistral-nemo-base-2407"], "mistralai/Mixtral-8x7B-v0.1": ["mixtral", "mixtral-8x7b"], "mistralai/Mixtral-8x7B-Instruct-v0.1": [ "mixtral-instruct", @@ -755,7 +761,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "final_rms": True, "gated_mlp": True, } - elif official_model_name.startswith("CodeLlama-7b"): # same architecture CodeLlama and Llama-2 + elif official_model_name.startswith("codellama"): # same architecture CodeLlama and Llama-2 cfg_dict = { "d_model": 4096, "d_head": 4096 // 32, @@ -869,6 +875,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "rotary_dim": 128, "final_rms": True, "gated_mlp": True, + "rotary_base": 500000.0, } elif "Meta-Llama-3-70B" in official_model_name: cfg_dict = { @@ -888,6 +895,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "rotary_dim": 128, "final_rms": True, "gated_mlp": True, + "rotary_base": 500000.0, } elif "Llama-3.2-1B" in official_model_name: cfg_dict = { @@ -907,6 +915,11 @@ def convert_hf_model_config(model_name: str, **kwargs): "rotary_dim": 64, "final_rms": True, "gated_mlp": True, + "rotary_base": 500000.0, + "use_NTK_by_parts_rope": True, + "NTK_by_parts_low_freq_factor": 1.0, + "NTK_by_parts_high_freq_factor": 4.0, + "NTK_by_parts_factor": 32.0, } elif "Llama-3.2-3B" in official_model_name: cfg_dict = { @@ -926,14 +939,19 @@ def convert_hf_model_config(model_name: str, **kwargs): "rotary_dim": 128, "final_rms": True, "gated_mlp": True, + "rotary_base": 500000.0, + "use_NTK_by_parts_rope": True, + "NTK_by_parts_low_freq_factor": 1.0, + "NTK_by_parts_high_freq_factor": 4.0, + "NTK_by_parts_factor": 32.0, } - elif "Llama-3.2-1B-Instruct" in official_model_name: + elif "Llama-3.1-8B" in official_model_name: cfg_dict = { - "d_model": 2048, - "d_head": 64, + "d_model": 4096, + "d_head": 128, "n_heads": 32, - "d_mlp": 8192, - "n_layers": 16, + "d_mlp": 14336, + "n_layers": 32, "n_ctx": 2048, # capped due to memory issues "eps": 1e-5, "d_vocab": 128256, @@ -942,17 +960,22 @@ def convert_hf_model_config(model_name: str, **kwargs): "normalization_type": "RMS", "positional_embedding_type": "rotary", "rotary_adjacent_pairs": False, - "rotary_dim": 64, + "rotary_dim": 128, "final_rms": True, "gated_mlp": True, + "rotary_base": 500000.0, + "use_NTK_by_parts_rope": True, + "NTK_by_parts_low_freq_factor": 1.0, + "NTK_by_parts_high_freq_factor": 4.0, + "NTK_by_parts_factor": 8.0, } - elif "Llama-3.2-3B-Instruct" in official_model_name: + elif "Llama-3.1-70B" in official_model_name: cfg_dict = { - "d_model": 3072, + "d_model": 8192, "d_head": 128, - "n_heads": 24, - "d_mlp": 8192, - "n_layers": 28, + "n_heads": 64, + "d_mlp": 28672, + "n_layers": 80, "n_ctx": 2048, # capped due to memory issues "eps": 1e-5, "d_vocab": 128256, @@ -964,6 +987,11 @@ def convert_hf_model_config(model_name: str, **kwargs): "rotary_dim": 128, "final_rms": True, "gated_mlp": True, + "rotary_base": 500000.0, + "use_NTK_by_parts_rope": True, + "NTK_by_parts_low_freq_factor": 1.0, + "NTK_by_parts_high_freq_factor": 4.0, + "NTK_by_parts_factor": 8.0, } elif architecture == "GPTNeoForCausalLM": cfg_dict = { @@ -1070,24 +1098,27 @@ def convert_hf_model_config(model_name: str, **kwargs): "attention_dir": "bidirectional", } elif architecture == "MistralForCausalLM": + use_local_attn = True if hf_config.sliding_window else False cfg_dict = { - "d_model": 4096, - "d_head": 4096 // 32, - "n_heads": 32, - "d_mlp": 14336, - "n_layers": 32, + "d_model": hf_config.hidden_size, + "d_head": hf_config.head_dim + if 
hasattr(hf_config, "head_dim") and hf_config.head_dim > 0 + else hf_config.hidden_size // hf_config.num_attention_heads, + "n_heads": hf_config.num_attention_heads, + "d_mlp": hf_config.intermediate_size, + "n_layers": hf_config.num_hidden_layers, "n_ctx": 2048, # Capped due to memory issues - "d_vocab": 32000, - "act_fn": "silu", + "d_vocab": hf_config.vocab_size, + "act_fn": hf_config.hidden_act, + "window_size": hf_config.sliding_window, # None if no sliding window was used + "attn_types": ["local"] * hf_config.num_hidden_layers if use_local_attn else None, + "eps": hf_config.rms_norm_eps, + "rotary_base": hf_config.rope_theta, + "n_key_value_heads": hf_config.num_key_value_heads, + "use_local_attn": use_local_attn, "normalization_type": "RMS", "positional_embedding_type": "rotary", - "window_size": 4096, - "attn_types": ["local"] * 32, - "eps": 1e-05, - "n_key_value_heads": 8, "gated_mlp": True, - "use_local_attn": True, - "rotary_dim": 4096 // 32, } elif architecture == "MixtralForCausalLM": cfg_dict = { @@ -1467,7 +1498,7 @@ def get_pretrained_model_config( fold_ln: bool = False, device: Optional[Union[str, torch.device]] = None, n_devices: int = 1, - default_prepend_bos: bool = True, + default_prepend_bos: Optional[bool] = None, dtype: torch.dtype = torch.float32, first_n_layers: Optional[int] = None, **kwargs, @@ -1498,11 +1529,15 @@ def get_pretrained_model_config( n_devices (int, optional): The number of devices to split the model across. Defaults to 1. default_prepend_bos (bool, optional): Default behavior of whether to prepend the BOS token when the methods of HookedTransformer process input text to tokenize (only when input is a string). - Defaults to True - even for models not explicitly trained with this, heads often use the + Resolution order for default_prepend_bos: + 1. If user passes value explicitly, use that value + 2. Model-specific default from cfg_dict if it exists (e.g. for bloom models it's False) + 3. Global default (True) + + Even for models not explicitly trained with the BOS token, heads often use the first position as a resting position and accordingly lose information from the first token, - so this empirically seems to give better results. To change the default behavior to False, pass in - default_prepend_bos=False. Note that you can also locally override the default behavior by passing - in prepend_bos=True/False when you call a method that processes the input string. + so this empirically seems to give better results. Note that you can also locally override the default behavior + by passing in prepend_bos=True/False when you call a method that processes the input string. dtype (torch.dtype, optional): The dtype to load the TransformerLens model in. kwargs: Other optional arguments passed to HuggingFace's from_pretrained. Also given to other HuggingFace functions when compatible. @@ -1579,7 +1614,14 @@ def get_pretrained_model_config( cfg_dict["device"] = device cfg_dict["n_devices"] = n_devices - cfg_dict["default_prepend_bos"] = default_prepend_bos + + if default_prepend_bos is not None: + # User explicitly set prepend_bos behavior, override config/default value + cfg_dict["default_prepend_bos"] = default_prepend_bos + elif "default_prepend_bos" not in cfg_dict: + # No config value or user override, set default value (True) + cfg_dict["default_prepend_bos"] = True + if hf_cfg is not None: cfg_dict["load_in_4bit"] = hf_cfg.get("quantization_config", {}).get("load_in_4bit", False) if first_n_layers is not None:
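The `skip_verbose_naming` flag threaded through `check_and_add_hook`, `add_hook`, `hooks`, and the `_enable_hook*` helpers above is forwarded down to `HookPoint.add_hook`; per the new docstrings, it skips assigning the string representation of the hook to `full_hook.__name__`. A minimal usage sketch (the model name, hook name, and hook function are illustrative, not part of this change):

```python
import torch
from transformer_lens import HookedTransformer

# Illustrative model and hook point; any HookedTransformer and valid hook name works.
model = HookedTransformer.from_pretrained("gpt2")
tokens = model.to_tokens("The quick brown fox")

def zero_ablate(value, hook):
    # Zero out the activation at this hook point.
    return torch.zeros_like(value)

# skip_verbose_naming=True skips assigning the stringified hook to the wrapped
# hook's __name__, avoiding that work on every registration.
with model.hooks(
    fwd_hooks=[("blocks.0.hook_resid_pre", zero_ablate)],
    skip_verbose_naming=True,
):
    logits = model(tokens)
```

The verbose name is mainly useful for debugging, so skipping it is presumably most helpful when hooks are registered repeatedly in a tight loop.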
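With the registry changes above, the CodeLlama checkpoints are now referenced by their full `codellama/` organization prefix, and the Llama 3.1 and Mistral-Nemo-Base-2407 checkpoints are newly registered. A loading sketch (the Llama 3.1 weights are gated on the Hugging Face Hub, so this assumes you have been granted access and are logged in; memory requirements for these models are substantial):

```python
from transformer_lens import HookedTransformer

# CodeLlama now uses the organization-prefixed name.
code_llama = HookedTransformer.from_pretrained("codellama/CodeLlama-7b-hf")

# Newly registered checkpoints.
llama_31 = HookedTransformer.from_pretrained("meta-llama/Llama-3.1-8B")
nemo = HookedTransformer.from_pretrained("mistralai/Mistral-Nemo-Base-2407")
```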
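The Llama 3 / 3.1 / 3.2 configs above set `rotary_base` and the `use_NTK_by_parts_rope` family of parameters, which describe Llama 3.1-style RoPE scaling. The rotary embedding itself is implemented elsewhere in TransformerLens; the sketch below only illustrates the frequency adjustment these parameters parameterize, mirroring the scheme published with Llama 3.1. The function name and the `original_ctx=8192` value are assumptions for illustration, and the in-library implementation may differ in details:

```python
import math
import torch

def ntk_by_parts_inv_freq(
    rotary_dim: int,
    rotary_base: float = 500000.0,
    factor: float = 8.0,            # NTK_by_parts_factor
    low_freq_factor: float = 1.0,   # NTK_by_parts_low_freq_factor
    high_freq_factor: float = 4.0,  # NTK_by_parts_high_freq_factor
    original_ctx: int = 8192,       # assumed pre-scaling context length
) -> torch.Tensor:
    """Standard rotary inverse frequencies, rescaled "by parts" as in Llama 3.1."""
    inv_freq = 1.0 / (rotary_base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    wavelen = 2 * math.pi / inv_freq

    low_freq_wavelen = original_ctx / low_freq_factor
    high_freq_wavelen = original_ctx / high_freq_factor

    # Long wavelengths (low frequencies) are compressed by `factor`;
    # short wavelengths (high frequencies) are left untouched.
    adjusted = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)

    # The band in between is smoothly interpolated between the two regimes.
    smooth = (original_ctx / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    mid = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
    adjusted = torch.where(mid, (1 - smooth) * inv_freq / factor + smooth * inv_freq, adjusted)
    return adjusted
```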
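The `MistralForCausalLM` branch now reads its dimensions from `hf_config` instead of hard-coding Mistral-7B-v0.1 values, which is what lets Mistral-Nemo-Base-2407 share the branch: Nemo reportedly uses `head_dim=128` even though `hidden_size // num_attention_heads` would give 160, and has no sliding window. A self-contained sketch of the resolution logic, using illustrative values rather than a config fetched from the Hub:

```python
class IllustrativeMistralConfig:
    # Values illustrative of Mistral-Nemo-Base-2407, not read from the Hub.
    hidden_size = 5120
    head_dim = 128
    num_attention_heads = 32
    sliding_window = None  # unlike Mistral-7B-v0.1, no sliding window

hf_config = IllustrativeMistralConfig()

use_local_attn = True if hf_config.sliding_window else False
d_head = (
    hf_config.head_dim
    if hasattr(hf_config, "head_dim") and hf_config.head_dim > 0
    else hf_config.hidden_size // hf_config.num_attention_heads
)

print(d_head)          # 128, not hidden_size // n_heads == 160
print(use_local_attn)  # False, so attn_types stays None
```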
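`default_prepend_bos` in `get_pretrained_model_config` is now `Optional[bool]`, with the resolution order described in the updated docstring. A minimal sketch of that resolution (the helper name is illustrative; the real code writes the resolved value directly into `cfg_dict`):

```python
def resolve_default_prepend_bos(cfg_dict: dict, default_prepend_bos=None) -> bool:
    # 1) An explicit user value always wins.
    if default_prepend_bos is not None:
        return default_prepend_bos
    # 2) Otherwise keep a model-specific default already in cfg_dict
    #    (e.g. False for bloom models); 3) else fall back to the global default True.
    return cfg_dict.get("default_prepend_bos", True)

assert resolve_default_prepend_bos({}) is True
assert resolve_default_prepend_bos({"default_prepend_bos": False}) is False
assert resolve_default_prepend_bos({"default_prepend_bos": False}, default_prepend_bos=True) is True
```

As before, the resolved default can still be overridden per call by passing `prepend_bos=True/False` to any method that processes an input string.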