diff --git a/circuit_analysis_sentiment_classification_pythia1_4b.ipynb b/circuit_analysis_sentiment_classification_pythia1_4b.ipynb index b10f1bf..6f62e66 100644 --- a/circuit_analysis_sentiment_classification_pythia1_4b.ipynb +++ b/circuit_analysis_sentiment_classification_pythia1_4b.ipynb @@ -16082,18 +16082,6 @@ "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.16" } }, "nbformat": 4, diff --git a/circuit_analysis_sentiment_continuation_pythia1_4b.ipynb b/circuit_analysis_sentiment_continuation_pythia1_4b.ipynb index 6bd4f58..3a735ce 100644 --- a/circuit_analysis_sentiment_continuation_pythia1_4b.ipynb +++ b/circuit_analysis_sentiment_continuation_pythia1_4b.ipynb @@ -52022,18 +52022,6 @@ "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.16" } }, "nbformat": 4, diff --git a/circuit_analysis_simple_sentiment_gpt2_small.ipynb b/circuit_analysis_simple_sentiment_gpt2_small.ipynb new file mode 100644 index 0000000..041af8d --- /dev/null +++ b/circuit_analysis_simple_sentiment_gpt2_small.ipynb @@ -0,0 +1,23777 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "2236c42b", + "metadata": { + "id": "CbZUo-Tev4QM" + }, + "source": [ + "# Sentiment Analysis for GPT-2" + ] + }, + { + "cell_type": "markdown", + "id": "a4a82132", + "metadata": { + "id": "5vLV3GuDd415" + }, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "67f34f16", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "README.md\t\t miniconda.sh quick_start_pytorch.ipynb wandb\n", + "eliciting-latent-sentiment miniconda3\t quick_start_pytorch_images\n" + ] + } + ], + "source": [ + "!ls" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "9f8e4ed7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/notebooks/eliciting-latent-sentiment\n" + ] + } + ], + "source": [ + "%cd eliciting-latent-sentiment" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "0c3702fd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting fancy_einsum==0.0.3\n", + " Downloading fancy_einsum-0.0.3-py3-none-any.whl (6.2 kB)\n", + "Installing collected packages: fancy_einsum\n", + "Successfully installed fancy_einsum-0.0.3\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mCollecting transformer_lens\n", + " Downloading transformer_lens-1.6.1-py3-none-any.whl (109 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m109.9/109.9 kB\u001b[0m \u001b[31m24.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting beartype<0.15.0,>=0.14.1\n", + " Downloading beartype-0.14.1-py3-none-any.whl (739 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m739.7/739.7 kB\u001b[0m \u001b[31m92.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: torch>=1.10 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (1.12.1+cu116)\n", + "Requirement already satisfied: tqdm>=4.64.1 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (4.64.1)\n", + "Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (1.23.4)\n", + "Collecting transformers>=4.25.1\n", + " Downloading transformers-4.33.2-py3-none-any.whl (7.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m92.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m:00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hCollecting datasets>=2.7.1\n", + " Downloading datasets-2.14.5-py3-none-any.whl (519 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m519.6/519.6 kB\u001b[0m \u001b[31m86.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting einops>=0.6.0\n", + " Downloading einops-0.6.1-py3-none-any.whl (42 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m42.2/42.2 kB\u001b[0m \u001b[31m14.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: pandas>=1.1.5 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (1.5.0)\n", + "Requirement already satisfied: fancy-einsum>=0.0.3 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (0.0.3)\n", + "Collecting wandb>=0.13.5\n", + " Downloading wandb-0.15.11-py3-none-any.whl (2.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.1/2.1 MB\u001b[0m \u001b[31m105.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: rich>=12.6.0 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (13.2.0)\n", + "Collecting jaxtyping>=0.2.11\n", + " Downloading jaxtyping-0.2.22-py3-none-any.whl (25 kB)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (5.4.1)\n", + "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (10.0.1)\n", + "Requirement already satisfied: fsspec[http]<2023.9.0,>=2023.1.0 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (2023.1.0)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (3.2.0)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (3.8.3)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (23.0)\n", + "Collecting huggingface-hub<1.0.0,>=0.14.0\n", + " Downloading huggingface_hub-0.17.2-py3-none-any.whl (294 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m294.9/294.9 kB\u001b[0m \u001b[31m64.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (2.28.2)\n", + "Requirement already satisfied: multiprocess in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (0.70.13)\n", + "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (0.3.5.1)\n", + "Collecting typeguard>=2.13.3\n", + " Downloading typeguard-4.1.5-py3-none-any.whl (34 kB)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.1 in /usr/local/lib/python3.9/dist-packages (from jaxtyping>=0.2.11->transformer_lens) (4.4.0)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas>=1.1.5->transformer_lens) (2022.7.1)\n", + "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.9/dist-packages (from pandas>=1.1.5->transformer_lens) (2.8.2)\n", + "Requirement already satisfied: markdown-it-py<3.0.0,>=2.1.0 in /usr/local/lib/python3.9/dist-packages (from rich>=12.6.0->transformer_lens) (2.1.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.6.0 in /usr/local/lib/python3.9/dist-packages (from rich>=12.6.0->transformer_lens) (2.14.0)\n", + "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.9/dist-packages (from transformers>=4.25.1->transformer_lens) (0.12.1)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.9/dist-packages (from transformers>=4.25.1->transformer_lens) (2022.10.31)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from transformers>=4.25.1->transformer_lens) (3.9.0)\n", + "Collecting safetensors>=0.3.1\n", + " Downloading safetensors-0.3.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m102.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: pathtools in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (0.1.2)\n", + "Requirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (8.1.3)\n", + "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (3.1.30)\n", + "Requirement already satisfied: sentry-sdk>=1.0.0 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (1.14.0)\n", + "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (5.9.4)\n", + "Collecting appdirs>=1.4.3\n", + " Downloading appdirs-1.4.4-py2.py3-none-any.whl (9.6 kB)\n", + "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.15.0 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (3.19.6)\n", + "Requirement already satisfied: setproctitle in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (1.3.2)\n", + "Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (0.4.0)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (66.1.1)\n", + "Requirement already satisfied: six>=1.4.0 in /usr/lib/python3/dist-packages (from docker-pycreds>=0.4.0->wandb>=0.13.5->transformer_lens) (1.14.0)\n", + "Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (2.1.1)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.3.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (6.0.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (4.0.2)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.3.3)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.8.2)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (18.2.0)\n", + "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.9/dist-packages (from GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer_lens) (4.0.10)\n", + "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.9/dist-packages (from markdown-it-py<3.0.0,>=2.1.0->rich>=12.6.0->transformer_lens) (0.1.2)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/lib/python3/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (2019.11.28)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/lib/python3/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (2.8)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (1.26.14)\n", + "Requirement already satisfied: importlib-metadata>=3.6 in /usr/local/lib/python3.9/dist-packages (from typeguard>=2.13.3->jaxtyping>=0.2.11->transformer_lens) (6.0.0)\n", + "Collecting typing-extensions>=3.7.4.1\n", + " Downloading typing_extensions-4.8.0-py3-none-any.whl (31 kB)\n", + "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.9/dist-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer_lens) (5.0.0)\n", + "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.9/dist-packages (from importlib-metadata>=3.6->typeguard>=2.13.3->jaxtyping>=0.2.11->transformer_lens) (3.11.0)\n", + "Installing collected packages: safetensors, appdirs, typing-extensions, einops, beartype, typeguard, huggingface-hub, wandb, transformers, jaxtyping, datasets, transformer_lens\n", + " Attempting uninstall: typing-extensions\n", + " Found existing installation: typing_extensions 4.4.0\n", + " Uninstalling typing_extensions-4.4.0:\n", + " Successfully uninstalled typing_extensions-4.4.0\n", + " Attempting uninstall: huggingface-hub\n", + " Found existing installation: huggingface-hub 0.12.0\n", + " Uninstalling huggingface-hub-0.12.0:\n", + " Successfully uninstalled huggingface-hub-0.12.0\n", + " Attempting uninstall: wandb\n", + " Found existing installation: wandb 0.13.4\n", + " Uninstalling wandb-0.13.4:\n", + " Successfully uninstalled wandb-0.13.4\n", + " Attempting uninstall: transformers\n", + " Found existing installation: transformers 4.21.3\n", + " Uninstalling transformers-4.21.3:\n", + " Successfully uninstalled transformers-4.21.3\n", + " Attempting uninstall: datasets\n", + " Found existing installation: datasets 2.4.0\n", + " Uninstalling datasets-2.4.0:\n", + " Successfully uninstalled datasets-2.4.0\n", + "Successfully installed appdirs-1.4.4 beartype-0.14.1 datasets-2.14.5 einops-0.6.1 huggingface-hub-0.17.2 jaxtyping-0.2.22 safetensors-0.3.3 transformer_lens-1.6.1 transformers-4.33.2 typeguard-4.1.5 typing-extensions-4.8.0 wandb-0.15.11\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mCollecting jaxtyping==0.2.13\n", + " Downloading jaxtyping-0.2.13-py3-none-any.whl (19 kB)\n", + "Requirement already satisfied: typeguard>=2.13.3 in /usr/local/lib/python3.9/dist-packages (from jaxtyping==0.2.13) (4.1.5)\n", + "Requirement already satisfied: numpy>=1.20.0 in /usr/local/lib/python3.9/dist-packages (from jaxtyping==0.2.13) (1.23.4)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.1 in /usr/local/lib/python3.9/dist-packages (from jaxtyping==0.2.13) (4.8.0)\n", + "Requirement already satisfied: importlib-metadata>=3.6 in /usr/local/lib/python3.9/dist-packages (from typeguard>=2.13.3->jaxtyping==0.2.13) (6.0.0)\n", + "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.9/dist-packages (from importlib-metadata>=3.6->typeguard>=2.13.3->jaxtyping==0.2.13) (3.11.0)\n", + "Installing collected packages: jaxtyping\n", + " Attempting uninstall: jaxtyping\n", + " Found existing installation: jaxtyping 0.2.22\n", + " Uninstalling jaxtyping-0.2.22:\n", + " Successfully uninstalled jaxtyping-0.2.22\n", + "Successfully installed jaxtyping-0.2.13\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mRequirement already satisfied: einops in /usr/local/lib/python3.9/dist-packages (0.6.1)\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mCollecting protobuf==3.20.*\n", + " Downloading protobuf-3.20.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m3.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0mm\n", + "\u001b[?25hInstalling collected packages: protobuf\n", + " Attempting uninstall: protobuf\n", + " Found existing installation: protobuf 3.19.6\n", + " Uninstalling protobuf-3.19.6:\n", + " Successfully uninstalled protobuf-3.19.6\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "tensorflow 2.9.2 requires protobuf<3.20,>=3.9.2, but you have protobuf 3.20.3 which is incompatible.\n", + "tensorboard 2.9.1 requires protobuf<3.20,>=3.9.2, but you have protobuf 3.20.3 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed protobuf-3.20.3\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mCollecting plotly\n", + " Downloading plotly-5.17.0-py2.py3-none-any.whl (15.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m15.6/15.6 MB\u001b[0m \u001b[31m87.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.9/dist-packages (from plotly) (23.0)\n", + "Collecting tenacity>=6.2.0\n", + " Downloading tenacity-8.2.3-py3-none-any.whl (24 kB)\n", + "Installing collected packages: tenacity, plotly\n", + "Successfully installed plotly-5.17.0 tenacity-8.2.3\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mCollecting torchtyping\n", + " Downloading torchtyping-0.1.4-py3-none-any.whl (17 kB)\n", + "Requirement already satisfied: typeguard>=2.11.1 in /usr/local/lib/python3.9/dist-packages (from torchtyping) (4.1.5)\n", + "Requirement already satisfied: torch>=1.7.0 in /usr/local/lib/python3.9/dist-packages (from torchtyping) (1.12.1+cu116)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (4.8.0)\n", + "Requirement already satisfied: importlib-metadata>=3.6 in /usr/local/lib/python3.9/dist-packages (from typeguard>=2.11.1->torchtyping) (6.0.0)\n", + "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.9/dist-packages (from importlib-metadata>=3.6->typeguard>=2.11.1->torchtyping) (3.11.0)\n", + "Installing collected packages: torchtyping\n", + "Successfully installed torchtyping-0.1.4\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mCollecting git+https://github.com/neelnanda-io/neel-plotly.git\n", + " Cloning https://github.com/neelnanda-io/neel-plotly.git to /tmp/pip-req-build-9z19ecnw\n", + " Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/neel-plotly.git /tmp/pip-req-build-9z19ecnw\n", + " Resolved https://github.com/neelnanda-io/neel-plotly.git to commit 6dc24b26f8dec991908479d7445dae496b3430b7\n", + " Preparing metadata (setup.py) ... \u001b[?25ldone\n", + "\u001b[?25hRequirement already satisfied: einops in /usr/local/lib/python3.9/dist-packages (from neel-plotly==0.0.0) (0.6.1)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.9/dist-packages (from neel-plotly==0.0.0) (1.23.4)\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.9/dist-packages (from neel-plotly==0.0.0) (1.12.1+cu116)\n", + "Requirement already satisfied: plotly in /usr/local/lib/python3.9/dist-packages (from neel-plotly==0.0.0) (5.17.0)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.9/dist-packages (from neel-plotly==0.0.0) (4.64.1)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.9/dist-packages (from neel-plotly==0.0.0) (1.5.0)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas->neel-plotly==0.0.0) (2022.7.1)\n", + "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.9/dist-packages (from pandas->neel-plotly==0.0.0) (2.8.2)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.9/dist-packages (from plotly->neel-plotly==0.0.0) (23.0)\n", + "Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.9/dist-packages (from plotly->neel-plotly==0.0.0) (8.2.3)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (4.8.0)\n", + "Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.1->pandas->neel-plotly==0.0.0) (1.14.0)\n", + "Building wheels for collected packages: neel-plotly\n", + " Building wheel for neel-plotly (setup.py) ... \u001b[?25ldone\n", + "\u001b[?25h Created wheel for neel-plotly: filename=neel_plotly-0.0.0-py3-none-any.whl size=10186 sha256=84c81b2669dc62c405ebd7a7d20a9fb33c9ff0f554f6e0652a869e66f6ab7777\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-4ijm16m2/wheels/e1/3c/c0/b5897c402b85e7fc329feb205ad5948b518f0423d891a79f7f\n", + "Successfully built neel-plotly\n", + "Installing collected packages: neel-plotly\n", + "Successfully installed neel-plotly-0.0.0\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mCollecting circuitsvis\n", + " Downloading circuitsvis-1.41.0-py3-none-any.whl (1.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m84.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting importlib-metadata<6.0.0,>=5.1.0\n", + " Downloading importlib_metadata-5.2.0-py3-none-any.whl (21 kB)\n", + "Requirement already satisfied: torch>=1.10 in /usr/local/lib/python3.9/dist-packages (from circuitsvis) (1.12.1+cu116)\n", + "Requirement already satisfied: numpy<2.0,>=1.21 in /usr/local/lib/python3.9/dist-packages (from circuitsvis) (1.23.4)\n", + "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.9/dist-packages (from importlib-metadata<6.0.0,>=5.1.0->circuitsvis) (3.11.0)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->circuitsvis) (4.8.0)\n", + "Installing collected packages: importlib-metadata, circuitsvis\n", + " Attempting uninstall: importlib-metadata\n", + " Found existing installation: importlib-metadata 6.0.0\n", + " Uninstalling importlib-metadata-6.0.0:\n", + " Successfully uninstalled importlib-metadata-6.0.0\n", + "Successfully installed circuitsvis-1.41.0 importlib-metadata-5.2.0\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0mCollecting kaleido\n", + " Downloading kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl (79.9 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m79.9/79.9 MB\u001b[0m \u001b[31m25.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hInstalling collected packages: kaleido\n", + "Successfully installed kaleido-0.2.1\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", + "\u001b[0m" + ] + } + ], + "source": [ + "#!pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python\n", + "!pip install fancy_einsum==0.0.3\n", + "!pip install transformer_lens\n", + "!pip install jaxtyping==0.2.13\n", + "!pip install einops\n", + "!pip install protobuf==3.20.*\n", + "!pip install plotly\n", + "!pip install torchtyping\n", + "!pip install git+https://github.com/neelnanda-io/neel-plotly.git\n", + "!pip install circuitsvis\n", + "!pip install -U kaleido \n", + "# !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs\n", + "# %pip install git+https://github.com/neelnanda-io/PySvelte.git\n", + "# %pip install typeguard==2.13.3" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ad57a333", + "metadata": {}, + "outputs": [], + "source": [ + "from IPython import get_ipython\n", + "ipython = get_ipython()\n", + "ipython.run_line_magic(\"load_ext\", \"autoreload\")\n", + "ipython.run_line_magic(\"autoreload\", \"2\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "c752b5c4", + "metadata": { + "id": "8QQvkqmWcB2v" + }, + "outputs": [], + "source": [ + "import os\n", + "import pathlib\n", + "from typing import List, Optional, Union\n", + "\n", + "import torch\n", + "import numpy as np\n", + "import yaml\n", + "\n", + "import einops\n", + "from fancy_einsum import einsum\n", + "\n", + "\n", + "import circuitsvis as cv\n", + "\n", + "import transformer_lens.utils as utils\n", + "from transformer_lens.hook_points import (\n", + " HookedRootModule,\n", + " HookPoint,\n", + ") # Hooking utilities\n", + "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache\n", + "\n", + "from torch import Tensor\n", + "from tqdm.notebook import tqdm\n", + "from jaxtyping import Float, Int, Bool\n", + "from typing import List, Optional, Callable, Tuple, Dict, Literal, Set\n", + "from rich import print as rprint\n", + "\n", + "from typing import List, Union\n", + "import plotly.express as px\n", + "import plotly.graph_objects as go\n", + "from plotly.subplots import make_subplots\n", + "import re\n", + "\n", + "from functools import partial\n", + "\n", + "from torchtyping import TensorType as TT\n", + "\n", + "from path_patching import Node, IterNode, path_patch, act_patch\n", + "from neel_plotly import imshow as imshow_n\n", + "\n", + "from utils.visualization import get_attn_head_patterns\n", + "from utils.prompts import get_dataset\n", + "from utils.circuit_analysis import get_logit_diff, logit_diff_denoising, logit_diff_noising" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "0c0d646b", + "metadata": {}, + "outputs": [], + "source": [ + "torch.set_grad_enabled(False)\n", + "\n", + "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "19389c35", + "metadata": {}, + "outputs": [], + "source": [ + "update_layout_set = {\n", + " \"xaxis_range\", \"yaxis_range\", \"hovermode\", \"xaxis_title\", \"yaxis_title\", \"colorbar\", \"colorscale\", \"coloraxis\", \"title_x\", \"bargap\", \"bargroupgap\", \"xaxis_tickformat\",\n", + " \"yaxis_tickformat\", \"title_y\", \"legend_title_text\", \"xaxis_showgrid\", \"xaxis_gridwidth\", \"xaxis_gridcolor\", \"yaxis_showgrid\", \"yaxis_gridwidth\", \"yaxis_gridcolor\",\n", + " \"showlegend\", \"xaxis_tickmode\", \"yaxis_tickmode\", \"xaxis_tickangle\", \"yaxis_tickangle\", \"margin\", \"xaxis_visible\", \"yaxis_visible\", \"bargap\", \"bargroupgap\"\n", + "}\n", + "\n", + "def imshow_p(tensor, renderer=None, **kwargs):\n", + " kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}\n", + " kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}\n", + " facet_labels = kwargs_pre.pop(\"facet_labels\", None)\n", + " border = kwargs_pre.pop(\"border\", False)\n", + " if \"color_continuous_scale\" not in kwargs_pre:\n", + " kwargs_pre[\"color_continuous_scale\"] = \"RdBu\"\n", + " if \"margin\" in kwargs_post and isinstance(kwargs_post[\"margin\"], int):\n", + " kwargs_post[\"margin\"] = dict.fromkeys(list(\"tblr\"), kwargs_post[\"margin\"])\n", + " fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, **kwargs_pre)\n", + " if facet_labels:\n", + " for i, label in enumerate(facet_labels):\n", + " fig.layout.annotations[i]['text'] = label\n", + " if border:\n", + " fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=True)\n", + " fig.update_yaxes(showline=True, linewidth=1, linecolor='black', mirror=True)\n", + " # things like `xaxis_tickmode` should be applied to all subplots. This is super janky lol but I'm under time pressure\n", + " for setting in [\"tickangle\"]:\n", + " if f\"xaxis_{setting}\" in kwargs_post:\n", + " i = 2\n", + " while f\"xaxis{i}\" in fig[\"layout\"]:\n", + " kwargs_post[f\"xaxis{i}_{setting}\"] = kwargs_post[f\"xaxis_{setting}\"]\n", + " i += 1\n", + " fig.update_layout(**kwargs_post)\n", + " fig.show(renderer=renderer)\n", + "\n", + "def hist_p(tensor, renderer=None, **kwargs):\n", + " kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}\n", + " kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}\n", + " names = kwargs_pre.pop(\"names\", None)\n", + " if \"barmode\" not in kwargs_post:\n", + " kwargs_post[\"barmode\"] = \"overlay\"\n", + " if \"bargap\" not in kwargs_post:\n", + " kwargs_post[\"bargap\"] = 0.0\n", + " if \"margin\" in kwargs_post and isinstance(kwargs_post[\"margin\"], int):\n", + " kwargs_post[\"margin\"] = dict.fromkeys(list(\"tblr\"), kwargs_post[\"margin\"])\n", + " fig = px.histogram(x=tensor, **kwargs_pre).update_layout(**kwargs_post)\n", + " if names is not None:\n", + " for i in range(len(fig.data)):\n", + " fig.data[i][\"name\"] = names[i // 2]\n", + " fig.show(renderer)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "bb406a6b", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "import plotly.io as pio\n", + "def imshow_p(tensor, renderer=None, save_path=None, **kwargs):\n", + " kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}\n", + " kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}\n", + " facet_labels = kwargs_pre.pop(\"facet_labels\", None)\n", + " border = kwargs_pre.pop(\"border\", False)\n", + " if \"color_continuous_scale\" not in kwargs_pre:\n", + " kwargs_pre[\"color_continuous_scale\"] = \"RdBu\"\n", + " if \"margin\" in kwargs_post and isinstance(kwargs_post[\"margin\"], int):\n", + " kwargs_post[\"margin\"] = dict.fromkeys(list(\"tblr\"), kwargs_post[\"margin\"])\n", + " fig = px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, **kwargs_pre)\n", + " if facet_labels:\n", + " for i, label in enumerate(facet_labels):\n", + " fig.layout.annotations[i]['text'] = label\n", + " if border:\n", + " fig.update_xaxes(showline=True, linewidth=1, linecolor='black', mirror=True)\n", + " fig.update_yaxes(showline=True, linewidth=1, linecolor='black', mirror=True)\n", + " # things like `xaxis_tickmode` should be applied to all subplots. This is super janky lol but I'm under time pressure\n", + " for setting in [\"tickangle\"]:\n", + " if f\"xaxis_{setting}\" in kwargs_post:\n", + " i = 2\n", + " while f\"xaxis{i}\" in fig[\"layout\"]:\n", + " kwargs_post[f\"xaxis{i}_{setting}\"] = kwargs_post[f\"xaxis_{setting}\"]\n", + " i += 1\n", + " fig.update_layout(**kwargs_post)\n", + " fig.show(renderer=renderer)\n", + "\n", + " # Save the figure as a PDF if save_path is provided\n", + " if save_path:\n", + " pio.write_image(fig, save_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "2702fb84", + "metadata": { + "id": "0c0JbzPpI0-D" + }, + "outputs": [], + "source": [ + "def imshow(tensor, renderer=None, xaxis=\"\", yaxis=\"\", **kwargs):\n", + " px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale=\"RdBu\", labels={\"x\":xaxis, \"y\":yaxis}, **kwargs).show(renderer)\n", + "\n", + "def line(tensor, renderer=None, **kwargs):\n", + " px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)\n", + "\n", + "def two_lines(tensor1, tensor2, renderer=None, **kwargs):\n", + " px.line(y=[utils.to_numpy(tensor1), utils.to_numpy(tensor2)], **kwargs).show(renderer)\n", + "\n", + "def scatter(x, y, xaxis=\"\", yaxis=\"\", caxis=\"\", renderer=None, **kwargs):\n", + " x = utils.to_numpy(x)\n", + " y = utils.to_numpy(y)\n", + " px.scatter(y=y, x=x, labels={\"x\":xaxis, \"y\":yaxis, \"color\":caxis}, **kwargs).show(renderer)" + ] + }, + { + "cell_type": "markdown", + "id": "0937dfa1", + "metadata": { + "id": "y5jV1EnY0dpf" + }, + "source": [ + "## Exploratory Analysis\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "ecc9c64c", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "bjeWvBNOn2VT", + "outputId": "dff069ed-56d5-414d-d649-2c70f073b1fc" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d044dfd83f3c47aca765252b0196a1bd", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading (…)lve/main/config.json: 0%| | 0.00/665 [00:00', 'I', ' thought', ' this', ' movie', ' was', ' lousy', ',', ' I', ' hated', ' it', '.', ' ', '\\n', 'Conclusion', ':', ' This', ' movie', ' is']\n", + "Tokenized answer: [' terrible']\n" + ] + }, + { + "data": { + "text/html": [ + "
Performance on answer token:\n",
+       "Rank: 1        Logit: 14.10 Prob:  4.47% Token: | terrible|\n",
+       "
\n" + ], + "text/plain": [ + "Performance on answer token:\n", + "\u001b[1mRank: \u001b[0m\u001b[1;36m1\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m14.10\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m4.47\u001b[0m\u001b[1m% Token: | terrible|\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Top 0th token. Logit: 14.87 Prob: 9.68% Token: | a|\n", + "Top 1th token. Logit: 14.10 Prob: 4.47% Token: | terrible|\n", + "Top 2th token. Logit: 13.93 Prob: 3.77% Token: | bad|\n", + "Top 3th token. Logit: 13.90 Prob: 3.67% Token: | so|\n", + "Top 4th token. Logit: 13.87 Prob: 3.57% Token: | not|\n", + "Top 5th token. Logit: 13.80 Prob: 3.30% Token: | awful|\n", + "Top 6th token. Logit: 13.28 Prob: 1.98% Token: | pretty|\n", + "Top 7th token. Logit: 13.25 Prob: 1.91% Token: | just|\n", + "Top 8th token. Logit: 13.16 Prob: 1.75% Token: | really|\n", + "Top 9th token. Logit: 13.13 Prob: 1.69% Token: | the|\n" + ] + }, + { + "data": { + "text/html": [ + "
Ranks of the answer tokens: [(' terrible', 1)]\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' terrible'\u001b[0m, \u001b[1;36m1\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True, top_k=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "6d1d540d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tokenized prompt: ['<|endoftext|>', 'I', ' thought', ' this', ' movie', ' was', ' pass', 'able', ',', ' I', ' watched', ' it', '.', ' ', '\\n', 'Conclusion', ':', ' This', ' movie', ' is']\n", + "Tokenized answer: [' average']\n" + ] + }, + { + "data": { + "text/html": [ + "
Performance on answer token:\n",
+       "Rank: 155      Logit:  9.94 Prob:  0.07% Token: | average|\n",
+       "
\n" + ], + "text/plain": [ + "Performance on answer token:\n", + "\u001b[1mRank: \u001b[0m\u001b[1;36m155\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m9.94\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m0.07\u001b[0m\u001b[1m% Token: | average|\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Top 0th token. Logit: 15.18 Prob: 13.60% Token: | a|\n", + "Top 1th token. Logit: 13.98 Prob: 4.09% Token: | not|\n", + "Top 2th token. Logit: 13.57 Prob: 2.70% Token: | the|\n", + "Top 3th token. Logit: 13.48 Prob: 2.48% Token: | great|\n", + "Top 4th token. Logit: 13.47 Prob: 2.44% Token: | an|\n", + "Top 5th token. Logit: 13.46 Prob: 2.44% Token: | so|\n", + "Top 6th token. Logit: 13.43 Prob: 2.36% Token: | pretty|\n", + "Top 7th token. Logit: 13.42 Prob: 2.34% Token: | one|\n", + "Top 8th token. Logit: 13.22 Prob: 1.90% Token: | just|\n", + "Top 9th token. Logit: 13.11 Prob: 1.71% Token: | really|\n" + ] + }, + { + "data": { + "text/html": [ + "
Ranks of the answer tokens: [(' average', 155)]\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' average'\u001b[0m, \u001b[1;36m155\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "example_prompt = \"I thought this movie was passable, I watched it. \\nConclusion: This movie is\"\n", + "example_answer = \" average\"\n", + "utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True, top_k=10)" + ] + }, + { + "cell_type": "markdown", + "id": "1d590c01", + "metadata": {}, + "source": [ + "### Dataset Construction" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "ff537629", + "metadata": { + "lines_to_next_cell": 0 + }, + "outputs": [], + "source": [ + "#pos_answers = [\" Positive\", \" amazing\", \" good\"]\n", + "#neg_answers = [\" Negative\", \" terrible\", \" bad\"]\n", + "\n", + "TASK = \"simple_sentiment\"\n", + "\n", + "clean_corrupt_data = get_dataset(\n", + " model, device, 1, comparison=(\"positive\", \"negative\")\n", + ")\n", + "all_prompts = clean_corrupt_data.all_prompts\n", + "clean_tokens = clean_corrupt_data.clean_tokens\n", + "corrupted_tokens = clean_corrupt_data.corrupted_tokens\n", + "answer_tokens = clean_corrupt_data.answer_tokens" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "d6d177c3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(30, torch.Size([30, 1, 2]), torch.Size([30, 19]), torch.Size([30, 19]))" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(all_prompts), answer_tokens.shape, clean_tokens.shape, corrupted_tokens.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "7cbdbfea", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "I thought this movie was perfect, I enjoyed it. \n", + "Conclusion: This movie is\n", + "tensor(2.8339, device='cuda:0')\n", + "I thought this movie was dreadful, I hated it. \n", + "Conclusion: This movie is\n", + "tensor(1.7237, device='cuda:0')\n", + "I thought this movie was fantastic, I loved it. \n", + "Conclusion: This movie is\n", + "tensor(2.6454, device='cuda:0')\n", + "I thought this movie was bad, I disliked it. \n", + "Conclusion: This movie is\n", + "tensor(0.7387, device='cuda:0')\n", + "I thought this movie was marvelous, I liked it. \n", + "Conclusion: This movie is\n", + "tensor(2.4686, device='cuda:0')\n", + "I thought this movie was miserable, I despised it. \n", + "Conclusion: This movie is\n", + "tensor(1.2674, device='cuda:0')\n", + "I thought this movie was good, I appreciated it. \n", + "Conclusion: This movie is\n", + "tensor(2.1577, device='cuda:0')\n", + "I thought this movie was horrific, I hated it. \n", + "Conclusion: This movie is\n", + "tensor(2.0667, device='cuda:0')\n", + "I thought this movie was remarkable, I admired it. \n", + "Conclusion: This movie is\n", + "tensor(2.3184, device='cuda:0')\n", + "I thought this movie was terrible, I disliked it. \n", + "Conclusion: This movie is\n", + "tensor(1.1251, device='cuda:0')\n", + "I thought this movie was wonderful, I enjoyed it. \n", + "Conclusion: This movie is\n", + "tensor(2.7509, device='cuda:0')\n", + "I thought this movie was disgusting, I despised it. \n", + "Conclusion: This movie is\n", + "tensor(1.7025, device='cuda:0')\n", + "I thought this movie was fabulous, I loved it. \n", + "Conclusion: This movie is\n", + "tensor(2.7947, device='cuda:0')\n", + "I thought this movie was disastrous, I hated it. \n", + "Conclusion: This movie is\n", + "tensor(1.8795, device='cuda:0')\n", + "I thought this movie was outstanding, I liked it. \n", + "Conclusion: This movie is\n", + "tensor(3.0432, device='cuda:0')\n", + "I thought this movie was horrendous, I disliked it. \n", + "Conclusion: This movie is\n", + "tensor(1.8827, device='cuda:0')\n", + "I thought this movie was awesome, I appreciated it. \n", + "Conclusion: This movie is\n", + "tensor(2.6942, device='cuda:0')\n", + "I thought this movie was offensive, I despised it. \n", + "Conclusion: This movie is\n", + "tensor(1.4065, device='cuda:0')\n", + "I thought this movie was exceptional, I admired it. \n", + "Conclusion: This movie is\n", + "tensor(2.5653, device='cuda:0')\n", + "I thought this movie was wretched, I hated it. \n", + "Conclusion: This movie is\n", + "tensor(1.5566, device='cuda:0')\n", + "I thought this movie was incredible, I enjoyed it. \n", + "Conclusion: This movie is\n", + "tensor(2.6843, device='cuda:0')\n", + "I thought this movie was awful, I disliked it. \n", + "Conclusion: This movie is\n", + "tensor(1.5595, device='cuda:0')\n", + "I thought this movie was extraordinary, I loved it. \n", + "Conclusion: This movie is\n", + "tensor(2.4112, device='cuda:0')\n", + "I thought this movie was unpleasant, I despised it. \n", + "Conclusion: This movie is\n", + "tensor(1.4727, device='cuda:0')\n", + "I thought this movie was amazing, I liked it. \n", + "Conclusion: This movie is\n", + "tensor(2.7210, device='cuda:0')\n", + "I thought this movie was horrible, I hated it. \n", + "Conclusion: This movie is\n", + "tensor(1.6034, device='cuda:0')\n", + "I thought this movie was lovely, I appreciated it. \n", + "Conclusion: This movie is\n", + "tensor(2.4772, device='cuda:0')\n", + "I thought this movie was mediocre, I disliked it. \n", + "Conclusion: This movie is\n", + "tensor(0.3483, device='cuda:0')\n", + "I thought this movie was brilliant, I admired it. \n", + "Conclusion: This movie is\n", + "tensor(2.3589, device='cuda:0')\n", + "I thought this movie was disappointing, I despised it. \n", + "Conclusion: This movie is\n", + "tensor(0.7886, device='cuda:0')\n" + ] + } + ], + "source": [ + "for i in range(len(all_prompts)):\n", + " logits, _ = model.run_with_cache(all_prompts[i])\n", + " print(all_prompts[i])\n", + " print(get_logit_diff(logits, answer_tokens[i].unsqueeze(0)))" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "399cd571", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(2.0016, device='cuda:0')" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n", + "clean_logit_diff = get_logit_diff(clean_logits, answer_tokens, per_prompt=False)\n", + "clean_logit_diff" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "1f4aba44", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(-2.0016, device='cuda:0')" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)\n", + "corrupted_logit_diff = get_logit_diff(corrupted_logits, answer_tokens, per_prompt=False)\n", + "corrupted_logit_diff" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "e5a745c0", + "metadata": {}, + "outputs": [], + "source": [ + "def logit_diff_denoising(\n", + " logits: Float[Tensor, \"batch seq d_vocab\"],\n", + " answer_tokens: Float[Tensor, \"batch 2\"] = answer_tokens,\n", + " flipped_logit_diff: float = corrupted_logit_diff,\n", + " clean_logit_diff: float = clean_logit_diff,\n", + ") -> Float[Tensor, \"\"]:\n", + " '''\n", + " Linear function of logit diff, calibrated so that it equals 0 when performance is\n", + " same as on flipped input, and 1 when performance is same as on clean input.\n", + " '''\n", + " patched_logit_diff = get_logit_diff(logits, answer_tokens)\n", + " return ((patched_logit_diff - flipped_logit_diff) / (clean_logit_diff - flipped_logit_diff)).item()\n", + "\n", + "def logit_diff_noising(\n", + " logits: Float[Tensor, \"batch seq d_vocab\"],\n", + " clean_logit_diff: float = clean_logit_diff,\n", + " corrupted_logit_diff: float = corrupted_logit_diff,\n", + " answer_tokens: Float[Tensor, \"batch 2\"] = answer_tokens,\n", + " ) -> float:\n", + " '''\n", + " We calibrate this so that the value is 0 when performance isn't harmed (i.e. same as IOI dataset),\n", + " and -1 when performance has been destroyed (i.e. is same as ABC dataset).\n", + " '''\n", + " patched_logit_diff = get_logit_diff(logits, answer_tokens)\n", + " return ((patched_logit_diff - clean_logit_diff) / (clean_logit_diff - corrupted_logit_diff)).item()" + ] + }, + { + "cell_type": "markdown", + "id": "cc78321c", + "metadata": { + "id": "TfiWnZtelFMV" + }, + "source": [ + "### Direct Logit Attribution" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "3c3e4b68", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "bt_jzrazlMAK", + "outputId": "39683745-1153-4a0f-bdbf-5f3be977abe3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Answer residual directions shape: torch.Size([30, 2, 768])\n", + "Logit difference directions shape: torch.Size([30, 768])\n" + ] + } + ], + "source": [ + "answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)\n", + "answer_residual_directions = answer_residual_directions.mean(dim=1)\n", + "print(\"Answer residual directions shape:\", answer_residual_directions.shape)\n", + "logit_diff_directions = answer_residual_directions[:, 0] - answer_residual_directions[:, 1]\n", + "print(\"Logit difference directions shape:\", logit_diff_directions.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "819b8cc5", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "LsDE7VUGIX8l", + "outputId": "226c2ad4-fb5b-44f4-b872-1d06eee7cbd5" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Final residual stream shape: torch.Size([30, 19, 768])\n", + "Calculated average logit diff: 2.0015628337860107\n", + "Original logit difference: 2.001563787460327\n" + ] + } + ], + "source": [ + "# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type]. \n", + "final_residual_stream = clean_cache[\"resid_post\", -1]\n", + "print(\"Final residual stream shape:\", final_residual_stream.shape)\n", + "final_token_residual_stream = final_residual_stream[:, -1, :]\n", + "# Apply LayerNorm scaling\n", + "# pos_slice is the subset of the positions we take - here the final token of each prompt\n", + "scaled_final_token_residual_stream = clean_cache.apply_ln_to_stack(final_token_residual_stream, layer = -1, pos_slice=-1)\n", + "\n", + "average_logit_diff = einsum(\"batch d_model, batch d_model -> \", scaled_final_token_residual_stream, logit_diff_directions)/len(all_prompts)\n", + "print(\"Calculated average logit diff:\", average_logit_diff.item())\n", + "print(\"Original logit difference:\",clean_logit_diff.item())" + ] + }, + { + "cell_type": "markdown", + "id": "e67ff7e5", + "metadata": { + "id": "Nb2nC45lIohT" + }, + "source": [ + "#### Logit Lens" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "2233318a", + "metadata": { + "id": "DvRDK2krIrid" + }, + "outputs": [], + "source": [ + "def residual_stack_to_logit_diff(residual_stack: TT[\"components\", \"batch\", \"d_model\"], cache: ActivationCache) -> float:\n", + " scaled_residual_stack = clean_cache.apply_ln_to_stack(residual_stack, layer = -1, pos_slice=-1)\n", + " return einsum(\"... batch d_model, batch d_model -> ...\", scaled_residual_stack, logit_diff_directions)/len(all_prompts)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "27fdbe69", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 542 + }, + "id": "7vxP1pNuPMhr", + "outputId": "616ac0ef-ddd2-4b1e-bccd-8ee3a3ebce23" + }, + "outputs": [ + { + "data": { + "text/html": [ + " \n", + " " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "hovertemplate": "%{hovertext}

x=%{x}
y=%{y}", + "hovertext": [ + "0_pre", + "0_mid", + "1_pre", + "1_mid", + "2_pre", + "2_mid", + "3_pre", + "3_mid", + "4_pre", + "4_mid", + "5_pre", + "5_mid", + "6_pre", + "6_mid", + "7_pre", + "7_mid", + "8_pre", + "8_mid", + "9_pre", + "9_mid", + "10_pre", + "10_mid", + "11_pre", + "11_mid", + "final_post" + ], + "legendgroup": "", + "line": { + "color": "#636efa", + "dash": "solid" + }, + "marker": { + "symbol": "circle" + }, + "mode": "lines", + "name": "", + "orientation": "v", + "showlegend": false, + "type": "scatter", + "x": [ + 0, + 0.5, + 1, + 1.5, + 2, + 2.5, + 3, + 3.5, + 4, + 4.5, + 5, + 5.5, + 6, + 6.5, + 7, + 7.5, + 8, + 8.5, + 9, + 9.5, + 10, + 10.5, + 11, + 11.5, + 12 + ], + "xaxis": "x", + "y": [ + -0.0002836537314578891, + 0.015164419077336788, + 0.013736815191805363, + 0.026919517666101456, + 0.01977485790848732, + 0.03669975697994232, + 0.020886223763227463, + 0.02444334886968136, + 0.01876016892492771, + 0.029971612617373466, + 0.03129005804657936, + 0.05612754076719284, + 0.0446736104786396, + 0.10143791139125824, + 0.096422478556633, + 0.20625260472297668, + 0.2378016859292984, + 0.35053136944770813, + 0.45829394459724426, + 0.7616435885429382, + 0.9647932648658752, + 1.617396593093872, + 1.9062424898147583, + 2.124509334564209, + 2.0015628337860107 + ], + "yaxis": "y" + } + ], + "layout": { + "legend": { + "tracegroupgap": 0 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Logit Difference From Accumulate Residual Stream" + }, + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "title": { + "text": "x" + } + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "title": { + "text": "y" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "accumulated_residual, labels = clean_cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)\n", + "logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, clean_cache)\n", + "line(logit_lens_logit_diffs, x=np.arange(model.cfg.n_layers*2+1)/2, hover_name=labels, title=\"Logit Difference From Accumulate Residual Stream\")" + ] + }, + { + "cell_type": "markdown", + "id": "3e63928b", + "metadata": { + "id": "s60emfYIbTuT" + }, + "source": [ + "#### Layer Attribution" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "3e9f7a27", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 542 + }, + "id": "yGgAVYgIJi9Z", + "outputId": "2d6b1ffe-b701-419d-a786-24f0d24d2b54" + }, + "outputs": [ + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "hovertemplate": "%{hovertext}

x=%{x}
y=%{y}", + "hovertext": [ + "embed", + "pos_embed", + "0_attn_out", + "0_mlp_out", + "1_attn_out", + "1_mlp_out", + "2_attn_out", + "2_mlp_out", + "3_attn_out", + "3_mlp_out", + "4_attn_out", + "4_mlp_out", + "5_attn_out", + "5_mlp_out", + "6_attn_out", + "6_mlp_out", + "7_attn_out", + "7_mlp_out", + "8_attn_out", + "8_mlp_out", + "9_attn_out", + "9_mlp_out", + "10_attn_out", + "10_mlp_out", + "11_attn_out", + "11_mlp_out" + ], + "legendgroup": "", + "line": { + "color": "#636efa", + "dash": "solid" + }, + "marker": { + "symbol": "circle" + }, + "mode": "lines", + "name": "", + "orientation": "v", + "showlegend": false, + "type": "scatter", + "x": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25 + ], + "xaxis": "x", + "y": [ + -0.000278800813248381, + -0.000004853525751968846, + 0.015448074787855148, + -0.001427602139301598, + 0.013182703405618668, + -0.007144661154597998, + 0.016924899071455002, + -0.015813538804650307, + 0.0035571292974054813, + -0.005683178082108498, + 0.01121144276112318, + 0.0013184386771172285, + 0.02483748272061348, + -0.01145392656326294, + 0.05676430091261864, + -0.005015416070818901, + 0.10983012616634369, + 0.03154907748103142, + 0.11272970587015152, + 0.10776256769895554, + 0.30334964394569397, + 0.2031497061252594, + 0.6526029706001282, + 0.2888459265232086, + 0.21826733648777008, + -0.12294676899909973 + ], + "yaxis": "y" + } + ], + "layout": { + "legend": { + "tracegroupgap": 0 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Logit Difference From Each Layer" + }, + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "title": { + "text": "x" + } + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "title": { + "text": "y" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "per_layer_residual, labels = clean_cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)\n", + "per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, clean_cache)\n", + "\n", + "line(per_layer_logit_diffs, hover_name=labels, title=\"Logit Difference From Each Layer\")" + ] + }, + { + "cell_type": "markdown", + "id": "01a066eb", + "metadata": {}, + "source": [ + "#### Head Attribution" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "17600841", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Tried to stack head results when they weren't cached. Computing head results now\n" + ] + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
color: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0.0010459227487444878, + 0.0003711625759024173, + 0.0014351740246638656, + -0.00028568002744577825, + 0.0005666174693033099, + 0.0002843957918230444, + 0.0033294728491455317, + 0.0002486975281499326, + 0.0004194938519503921, + 0.0000095533641797374, + 0.00009443467570235953, + 0.00005089359183330089 + ], + [ + 0.00029732173425145447, + 0.000958481163252145, + -0.00040032266406342387, + -0.00014667196955997497, + 0.0001732840173644945, + 0.0014571092324331403, + 0.0036868511233478785, + 0.004298475570976734, + -0.0005763894878327847, + -0.0008393435273319483, + -0.00022251349582802504, + -0.0027903111185878515 + ], + [ + 0.0005717608146369457, + 0.008333437144756317, + 0.0003366676392033696, + 0.00032111862674355507, + 0.00041764360503293574, + -0.00021099513105582446, + 0.0002676621952559799, + -0.0006087180809117854, + -0.00035023957025259733, + -0.000043039803131250665, + -0.00026268954388797283, + -0.00007108243880793452 + ], + [ + 0.00010563228715909645, + -0.00010872280108742416, + 0.00030111338128335774, + -8.081519808911253e-7, + 0.0002908433089032769, + 0.0029528227169066668, + 0.00010164394188905135, + -0.0003412188380025327, + -0.0003320052637718618, + 0.0006695032934658229, + -0.0018640790367498994, + -0.00008829677244648337 + ], + [ + 0.00008301740308525041, + 0.00006009411663399078, + 0.002549499738961458, + -0.0002249304234283045, + -0.0005816780612803996, + 0.00018672971054911613, + -0.00016864322242327034, + -0.0012501024175435305, + 0.0023646687623113394, + 0.001445892034098506, + 0.0002246722433483228, + 0.0004394247953314334 + ], + [ + 0.00045759862405247986, + 0.001828972832299769, + 0.0006541974726133049, + 0.0014741927152499557, + 0.00012481046724133193, + 0.00620291568338871, + 0.00004533326136879623, + 0.0013662336859852076, + -0.00045632332330569625, + 0.001549292472191155, + 0.0030067134648561478, + -0.004089352209120989 + ], + [ + 0.0002479908289387822, + 0.0001336046407232061, + 0.0017048768931999803, + 0.002524170558899641, + 0.02109912782907486, + -0.00252012861892581, + -0.006416620686650276, + 0.005769579671323299, + -0.000575427373405546, + 0.003223165636882186, + 0.0029180559795349836, + 0.0003498493751976639 + ], + [ + -0.00011949524923693389, + 0.026559170335531235, + 0.001183997024782002, + -0.0018259993521496654, + -0.0001383726776111871, + 0.017751073464751244, + -0.006150953937321901, + 0.0006273160688579082, + 0.0019553727470338345, + 0.007698599249124527, + 0.006984383333474398, + 0.00024263723753392696 + ], + [ + -0.0006382332067005336, + 0.0005390838487073779, + 0.020748836919665337, + 0.006996408570557833, + 0.00640723155811429, + 0.03623078390955925, + 0.001006894395686686, + 0.0031054068822413683, + 0.0032007494010031223, + 0.008640527725219727, + -0.034349020570516586, + 0.0042087906040251255 + ], + [ + 0.0009978149319067597, + 0.013555935584008694, + 0.09221717715263367, + 0.024947049096226692, + 0.0008158513810485601, + -0.01794683001935482, + 0.011314213275909424, + -0.017806988209486008, + 0.0009780247928574681, + 0.02014842815697193, + 0.022655341774225235, + -0.0002975918760057539 + ], + [ + 0.02260882779955864, + 0.08542339503765106, + -0.0013184607960283756, + -0.003500083228573203, + 0.24834668636322021, + 0.005109731573611498, + -0.0013253865763545036, + -0.06047582998871803, + -0.0005456183571368456, + 0.007450374774634838, + 0.0026964799035340548, + 0.021582823246717453 + ], + [ + 0.004116719588637352, + 0.004617986734956503, + 0.0011007598368451, + 0.012473659589886665, + 0.0017329364782199264, + 0.003249760251492262, + 0.0011255929712206125, + 0.002364318585023284, + 0.018851356580853462, + 0.030858764424920082, + 0.0020210484508424997, + 0.02690068632364273 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Logit Difference From Each Head" + }, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "scaleanchor": "y", + "title": { + "text": "Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Layer" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def imshow(tensor, renderer=None, **kwargs):\n", + " px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale=\"RdBu\", **kwargs).show(renderer)\n", + "\n", + "per_head_residual, labels = clean_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)\n", + "per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, clean_cache)\n", + "per_head_logit_diffs = einops.rearrange(per_head_logit_diffs, \"(layer head_index) -> layer head_index\", layer=model.cfg.n_layers, head_index=model.cfg.n_heads)\n", + "per_head_logit_diffs_pct = per_head_logit_diffs/clean_logit_diff\n", + "imshow(per_head_logit_diffs_pct, labels={\"x\":\"Head\", \"y\":\"Layer\"}, title=\"Logit Difference From Each Head\")" + ] + }, + { + "cell_type": "markdown", + "id": "1c54c587", + "metadata": {}, + "source": [ + "### Activation Patching" + ] + }, + { + "cell_type": "markdown", + "id": "f7822cad", + "metadata": {}, + "source": [ + "#### Attention Heads" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "bd049183", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e0350aa00214417cad38d8bc0d97e3a8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/144 [00:00Layer: %{y}
Logit diff variation: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0.12577775120735168, + 1.180384635925293, + 0.20882254838943481, + 0.9277065396308899, + 1.731764554977417, + 0.5056005120277405, + 1.1103265285491943, + 0.19368287920951843, + 0.1091848611831665, + 0.18261700868606567, + 0.3207947313785553, + 0.11203768849372864 + ], + [ + 0.07135356217622757, + -0.016551190987229347, + -0.0518631786108017, + -0.04373945668339729, + -0.03796827793121338, + -0.4464890956878662, + 0.4070020616054535, + 0.2848395109176636, + -0.23305077850818634, + -0.060844536870718, + 0.00825475063174963, + 0.4159923791885376 + ], + [ + -0.021250324323773384, + 1.8155388832092285, + -0.015693554654717445, + 0.048230137676000595, + 0.023864924907684326, + 0.0637807548046112, + -0.015818627551198006, + -0.297695130109787, + 0.05132120102643967, + 0.015229002572596073, + -0.06724703311920166, + -0.19233091175556183 + ], + [ + 0.11498580873012543, + -0.3324532210826874, + 0.057047709822654724, + 0.0781223401427269, + -0.2108118087053299, + 0.6913115382194519, + 0.2046266794204712, + -0.12819281220436096, + 0.02726569026708603, + 0.19937069714069366, + 0.054391421377658844, + 1.8789801597595215 + ], + [ + -0.008087988011538982, + 0.18980863690376282, + 0.4819321036338806, + -0.1604316085577011, + -0.9266522526741028, + 0.10933375358581543, + 0.0009767526062205434, + 0.23822936415672302, + 0.6611126065254211, + 1.4624457359313965, + 0.1593417078256607, + -0.01554466038942337 + ], + [ + -0.13844873011112213, + 0.14638781547546387, + 0.09883366525173187, + 0.3744535744190216, + 0.3810943067073822, + 1.511167287826538, + 0.0353119894862175, + 1.2763357162475586, + -0.5201386213302612, + -0.3617379367351532, + 1.0785164833068848, + -3.233813524246216 + ], + [ + 0.31471386551856995, + -0.01723015308380127, + -0.39588260650634766, + -0.03893312066793442, + 10.121282577514648, + -1.8731374740600586, + -2.3234381675720215, + 4.13507604598999, + 0.3490312397480011, + 0.4989180564880371, + 0.9424054026603699, + 0.30433881282806396 + ], + [ + 0.14420798420906067, + 16.014633178710938, + 0.19756606221199036, + -1.86379873752594, + 0.00816541351377964, + 12.193460464477539, + -7.824681758880615, + -0.10995613038539886, + 0.7687519192695618, + 2.0217974185943604, + 1.0177136659622192, + 0.040308911353349686 + ], + [ + -0.0067241075448691845, + 0.12224296480417252, + 3.576972246170044, + 0.9758681058883667, + 2.0006184577941895, + 9.18878173828125, + 0.35875704884529114, + 0.5208027362823486, + -0.2424192577600479, + 1.3535646200180054, + -7.396166801452637, + 1.1787915229797363 + ], + [ + 0.09612376987934113, + 1.0050873756408691, + 14.706995964050293, + 3.8923442363739014, + 0.13363642990589142, + -3.089003801345825, + 1.8870115280151367, + -3.740604877471924, + 0.17754562199115753, + 2.3362433910369873, + 3.903806209564209, + -0.10991444438695908 + ], + [ + 2.95804762840271, + 10.064723014831543, + -0.09472117573022842, + -0.2169462889432907, + 30.165334701538086, + 0.7441484928131104, + 0.10620397329330444, + -7.503061771392822, + 0.00011316035670461133, + 1.1169403791427612, + 0.6312471628189087, + 2.3635923862457275 + ], + [ + 0.44658440351486206, + 0.27343711256980896, + 0.1924440860748291, + 1.2078677415847778, + 0.2075212001800537, + 0.3728693425655365, + 0.07559707760810852, + 0.29023250937461853, + 1.7939757108688354, + 2.7831850051879883, + 0.21529056131839752, + 2.6623356342315674 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorbar": { + "ticksuffix": "%", + "title": { + "text": "Logit diff variation" + } + }, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "margin": { + "l": 100, + "r": 100 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Patching output of attention heads (corrupted -> clean)" + }, + "width": 600, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "scaleanchor": "y", + "showline": true, + "title": { + "text": "Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "showline": true, + "title": { + "text": "Layer" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "imshow_p(\n", + " results['z'] * 100,\n", + " title=\"Patching output of attention heads (corrupted -> clean)\",\n", + " labels={\"x\": \"Head\", \"y\": \"Layer\", \"color\": \"Logit diff variation\"},\n", + " coloraxis=dict(colorbar_ticksuffix = \"%\"),\n", + " border=True,\n", + " width=600,\n", + " margin={\"r\": 100, \"l\": 100}\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "b9919628", + "metadata": {}, + "source": [ + "#### Head Output by Component" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "13f5d49a", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c147d7f5d3794d79b91751f0fe53c9c0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/720 [00:00Layer: %{y}
Logit diff variation: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0.12577775120735168, + 1.180384635925293, + 0.20882254838943481, + 0.9277065396308899, + 1.731764554977417, + 0.5056005120277405, + 1.1103265285491943, + 0.19368287920951843, + 0.1091848611831665, + 0.18261700868606567, + 0.3207947313785553, + 0.11203768849372864 + ], + [ + 0.07135356217622757, + -0.016551190987229347, + -0.0518631786108017, + -0.04373945668339729, + -0.03796827793121338, + -0.4464890956878662, + 0.4070020616054535, + 0.2848395109176636, + -0.23305077850818634, + -0.060844536870718, + 0.00825475063174963, + 0.4159923791885376 + ], + [ + -0.021250324323773384, + 1.8155388832092285, + -0.015693554654717445, + 0.048230137676000595, + 0.023864924907684326, + 0.0637807548046112, + -0.015818627551198006, + -0.297695130109787, + 0.05132120102643967, + 0.015229002572596073, + -0.06724703311920166, + -0.19233091175556183 + ], + [ + 0.11498580873012543, + -0.3324532210826874, + 0.057047709822654724, + 0.0781223401427269, + -0.2108118087053299, + 0.6913115382194519, + 0.2046266794204712, + -0.12819281220436096, + 0.02726569026708603, + 0.19937069714069366, + 0.054391421377658844, + 1.8789801597595215 + ], + [ + -0.008087988011538982, + 0.18980863690376282, + 0.4819321036338806, + -0.1604316085577011, + -0.9266522526741028, + 0.10933375358581543, + 0.0009767526062205434, + 0.23822936415672302, + 0.6611126065254211, + 1.4624457359313965, + 0.1593417078256607, + -0.01554466038942337 + ], + [ + -0.13844873011112213, + 0.14638781547546387, + 0.09883366525173187, + 0.3744535744190216, + 0.3810943067073822, + 1.511167287826538, + 0.0353119894862175, + 1.2763357162475586, + -0.5201386213302612, + -0.3617379367351532, + 1.0785164833068848, + -3.233813524246216 + ], + [ + 0.31471386551856995, + -0.01723015308380127, + -0.39588260650634766, + -0.03893312066793442, + 10.121282577514648, + -1.8731374740600586, + -2.3234381675720215, + 4.13507604598999, + 0.3490312397480011, + 0.4989180564880371, + 0.9424054026603699, + 0.30433881282806396 + ], + [ + 0.14420798420906067, + 16.014633178710938, + 0.19756606221199036, + -1.86379873752594, + 0.00816541351377964, + 12.193460464477539, + -7.824681758880615, + -0.10995613038539886, + 0.7687519192695618, + 2.0217974185943604, + 1.0177136659622192, + 0.040308911353349686 + ], + [ + -0.0067241075448691845, + 0.12224296480417252, + 3.576972246170044, + 0.9758681058883667, + 2.0006184577941895, + 9.18878173828125, + 0.35875704884529114, + 0.5208027362823486, + -0.2424192577600479, + 1.3535646200180054, + -7.396166801452637, + 1.1787915229797363 + ], + [ + 0.09612376987934113, + 1.0050873756408691, + 14.706995964050293, + 3.8923442363739014, + 0.13363642990589142, + -3.089003801345825, + 1.8870115280151367, + -3.740604877471924, + 0.17754562199115753, + 2.3362433910369873, + 3.903806209564209, + -0.10991444438695908 + ], + [ + 2.95804762840271, + 10.064723014831543, + -0.09472117573022842, + -0.2169462889432907, + 30.165334701538086, + 0.7441484928131104, + 0.10620397329330444, + -7.503061771392822, + 0.00011316035670461133, + 1.1169403791427612, + 0.6312471628189087, + 2.3635923862457275 + ], + [ + 0.44658440351486206, + 0.27343711256980896, + 0.1924440860748291, + 1.2078677415847778, + 0.2075212001800537, + 0.3728693425655365, + 0.07559707760810852, + 0.29023250937461853, + 1.7939757108688354, + 2.7831850051879883, + 0.21529056131839752, + 2.6623356342315674 + ] + ] + }, + { + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
Logit diff variation: %{z}", + "name": "1", + "type": "heatmap", + "xaxis": "x2", + "yaxis": "y2", + "z": [ + [ + -0.07349467277526855, + 0.9017361998558044, + -0.02865935117006302, + -0.0008635922567918897, + 0.085984006524086, + 0.4380348026752472, + 0.0004288181953597814, + 0.0485725961625576, + 0.05918882042169571, + 0.02085724100470543, + -0.3354906737804413, + 0.01602707989513874 + ], + [ + 0.01819499395787716, + 0.006021322216838598, + 0.000762343464884907, + 0.014794227667152882, + -0.0012149849208071828, + -0.10078418999910355, + -0.011959263123571873, + 0.10816939175128937, + 0.07057931274175644, + -0.025979237630963326, + -0.004151198081672192, + 0.22581149637699127 + ], + [ + 0.06940899044275284, + -0.04033868759870529, + 0.027962520718574524, + 0.11956285685300827, + -0.0002680113830137998, + 0.01663457229733467, + 0.06894145905971527, + 0.003388855140656233, + 0.046773940324783325, + 0.446048378944397, + 0.016402296721935272, + -0.004085684660822153 + ], + [ + -0.28234702348709106, + -0.13405928015708923, + -0.0051458184607326984, + 0.0257648266851902, + -0.17217646539211273, + -0.03253062441945076, + 0.0735095664858818, + 0.34669652581214905, + 0.02128605917096138, + 0.010696631856262684, + 0.10814854502677917, + 0.172834575176239 + ], + [ + -0.022054359316825867, + 0.03246511146426201, + 0.034073177725076675, + -0.04666375741362572, + 0.1053403839468956, + -0.07318497449159622, + 0.2610281705856323, + -0.08334558457136154, + -0.15732863545417786, + 0.07538862526416779, + 0.06773243099451065, + -0.005235155578702688 + ], + [ + -0.0779198408126831, + -0.01547914557158947, + 0.03623513877391815, + -0.1484127789735794, + 0.056094780564308167, + 0.0041809771209955215, + 0.0733993798494339, + -0.03279267996549606, + -0.013609022833406925, + -0.49003198742866516, + -0.2630978226661682, + -0.19028806686401367 + ], + [ + 0.13868099451065063, + -0.02708701603114605, + -0.20113955438137054, + 0.27251994609832764, + -0.0035675293765962124, + -0.27257949113845825, + -0.04697345942258835, + 0.11462251096963882, + 0.03200651705265045, + -0.02890949323773384, + 0.07216652482748032, + 0.11085546761751175 + ], + [ + -0.015175399370491505, + 0.05800361931324005, + -0.010529869236052036, + -0.00031565784593112767, + 0.025109687820076942, + 0.05078815668821335, + -0.8372735381126404, + 0.008010562509298325, + 0.09209168702363968, + 0.006331024691462517, + 0.05672014132142067, + 0.0029421693179756403 + ], + [ + 0.07274424284696579, + -0.0013460126938298345, + 0.010267813690006733, + 0.16904370486736298, + 0.04226241633296013, + 0.10128745436668396, + 0.007980783469974995, + 0.009392309933900833, + 0.029219195246696472, + -0.061106596142053604, + -0.8731274604797363, + -0.032566361129283905 + ], + [ + -0.2924659252166748, + -0.006473963614553213, + -0.09370868653059006, + 0.06666932255029678, + 0.009696056134998798, + -0.2734609544277191, + 0.03427567705512047, + -0.29372259974479675, + -0.08423895388841629, + 0.15104526281356812, + -0.010470311157405376, + 0.04041016101837158 + ], + [ + -0.06095174327492714, + 0.8251593708992004, + -0.1264835000038147, + -0.06235731393098831, + -0.6888845562934875, + 0.004794425796717405, + -0.008302396163344383, + -2.9320385456085205, + 0.009809216484427452, + -0.01819499395787716, + 0.07773521542549133, + 0.23616565763950348 + ], + [ + 0.12859781086444855, + -0.06721130013465881, + -0.12922316789627075, + -0.2794346213340759, + -0.09065335988998413, + -0.05369756743311882, + 0.015985390171408653, + -0.11192154884338379, + 0.10037324577569962, + 0.026693934574723244, + 0.07902166992425919, + 0.18010960519313812 + ] + ] + }, + { + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
Logit diff variation: %{z}", + "name": "2", + "type": "heatmap", + "xaxis": "x3", + "yaxis": "y3", + "z": [ + [ + -0.07702647149562836, + 0.7390384078025818, + 0.08303588628768921, + 0.023686250671744347, + 0.07085923105478287, + 0.6564045548439026, + 0.05893569812178612, + 0.10956602543592453, + -0.022090094164013863, + 0.05924540385603905, + -0.38862839341163635, + 0.01329931989312172 + ], + [ + 0.07303905487060547, + 0.018105657771229744, + 0.005264934618026018, + -0.02890353836119175, + -0.00028587880660779774, + 0.03708086162805557, + -0.03416847437620163, + -0.12602490186691284, + 0.04281928390264511, + -0.003954656887799501, + 0.0013221894623711705, + 0.21736912429332733 + ], + [ + 0.0161521527916193, + -0.09671637415885925, + 0.04384070634841919, + 0.14180779457092285, + 0.032923709601163864, + 0.05022532865405083, + 0.08648727089166641, + -0.010964643210172653, + 0.07863453775644302, + 0.37278297543525696, + 0.009743702597916126, + -0.04992754012346268 + ], + [ + -0.27113819122314453, + -0.0075162299908697605, + -0.012340434826910496, + 0.04754522070288658, + -0.06202974542975426, + -0.008028429932892323, + 0.13442856073379517, + 0.4154265820980072, + 0.15190884470939636, + -0.03073197230696678, + 0.018653592094779015, + -0.06339362263679504 + ], + [ + 0.07051677256822586, + -0.02460940182209015, + 0.061529457569122314, + 0.030463960021734238, + -0.15311191976070404, + 0.03925473242998123, + 0.05310496687889099, + 0.07523972541093826, + 0.0013400568859651685, + -0.08001628518104553, + 0.08992377668619156, + 0.00813563447445631 + ], + [ + -0.06361398845911026, + 0.015997301787137985, + 0.028337735682725906, + 0.09976872056722641, + 0.0010482223005965352, + 0.08965576440095901, + -0.024394990876317024, + 0.043459534645080566, + -0.035889700055122375, + -0.40881267189979553, + -0.2477080374956131, + -0.12352942675352097 + ], + [ + -0.018474917858839035, + 0.04201525077223778, + 0.09416729211807251, + 0.09076056629419327, + 0.3261341154575348, + -0.17066964507102966, + 0.16999365389347076, + 0.2988118529319763, + -0.009523337706923485, + 0.003996347542852163, + 0.04282226040959358, + -0.0105536924675107 + ], + [ + 0.022292591631412506, + -0.12782952189445496, + -0.009350619278848171, + -0.03735483065247536, + -0.03839114308357239, + -0.16217070817947388, + 0.26807689666748047, + -0.09412559121847153, + -0.017974628135561943, + 0.048161644488573074, + 0.13011355698108673, + -0.007677036803215742 + ], + [ + 0.025359831750392914, + 0.008064163848757744, + -0.20918585360050201, + 0.37377163767814636, + 0.04850112646818161, + 0.0031744458246976137, + -0.028706997632980347, + 0.020887020975351334, + 0.07497172057628632, + -0.04954637214541435, + -0.7434516549110413, + 0.03575867414474487 + ], + [ + -0.16834092140197754, + -0.029058387503027916, + 0.21169325709342957, + -0.04983820393681526, + -0.041750214993953705, + -0.13812710344791412, + 0.21468603610992432, + -0.15130135416984558, + 0.12721607089042664, + 0.04680967703461647, + 0.012691828422248363, + -0.03980862349271774 + ], + [ + 0.45723336935043335, + 0.6883752942085266, + 0.06435251235961914, + 0.009821128100156784, + -0.22919736802577972, + 0.04609498009085655, + -0.041101034730672836, + -2.9632232189178467, + 0.0222628116607666, + -0.041142724454402924, + 0.39502495527267456, + 0.07443569600582123 + ], + [ + 0.07854520529508591, + 0.21149076521396637, + 0.25584661960601807, + 0.09261579811573029, + 0.033703919500112534, + 0.06016557291150093, + 0.08903933316469193, + -0.016497589647769928, + 0.07007306814193726, + 0.2068869024515152, + -0.2717992663383484, + -0.019326597452163696 + ] + ] + }, + { + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
Logit diff variation: %{z}", + "name": "3", + "type": "heatmap", + "xaxis": "x4", + "yaxis": "y4", + "z": [ + [ + 0.1901034563779831, + 1.1824214458465576, + 0.1779029667377472, + 0.9468603730201721, + 1.6572273969650269, + 0.5072205066680908, + 1.019545078277588, + 0.1847432255744934, + 0.011584047228097916, + 0.07389073818922043, + 0.11215680092573166, + 0.09106431156396866 + ], + [ + 0.08467670530080795, + 0.00837386678904295, + -0.05978440120816231, + -0.0068491799756884575, + -0.01765301637351513, + -0.34689009189605713, + 0.43098610639572144, + 0.2698189616203308, + -0.31453219056129456, + -0.0353119894862175, + 0.010214211419224739, + 0.4096643030643463 + ], + [ + 0.09668361395597458, + 1.7086559534072876, + 0.019916223362088203, + 0.02818884141743183, + -0.0020904885604977608, + 0.07935518771409988, + -0.19144350290298462, + -0.2864505648612976, + 0.05688392370939255, + -0.013805563561618328, + -0.12365449219942093, + -0.09039130806922913 + ], + [ + 0.13979175686836243, + -0.2196740359067917, + 0.08399476855993271, + 0.019576743245124817, + -0.16501757502555847, + 0.7133748531341553, + 0.2026166021823883, + -0.14043796062469482, + 0.05855154991149902, + 0.2155645340681076, + -0.17813228070735931, + 1.840761661529541 + ], + [ + 0.048590462654829025, + 0.10924740135669708, + 0.3835391700267792, + -0.15394572913646698, + -1.0083482265472412, + 0.10374124348163605, + -0.18349845707416534, + 0.34444525837898254, + 0.8261212110519409, + 1.4322736263275146, + 0.02400190755724907, + -0.007009986322373152 + ], + [ + -0.04226241633296013, + 0.1567687839269638, + 0.08916738629341125, + 0.4007335901260376, + 0.3749776780605316, + 1.3276777267456055, + 0.005628238897770643, + 1.189422607421875, + -0.5122114419937134, + -0.3900578022003174, + 1.2196662425994873, + -3.3266823291778564 + ], + [ + 0.15332037210464478, + -0.05459689348936081, + -0.19211651384830475, + -0.30415716767311096, + 8.271307945251465, + -1.491465449333191, + -2.2332375049591064, + 3.297427177429199, + 0.31371623277664185, + 0.48932623863220215, + 0.8141946792602539, + 0.10165075957775116 + ], + [ + 0.1253012865781784, + 15.300347328186035, + 0.21093091368675232, + -1.8412083387374878, + 0.02176252380013466, + 11.945406913757324, + -7.0668110847473145, + -0.011893749237060547, + 0.6837416887283325, + 1.832366943359375, + 0.8392865657806396, + 0.04465069621801376 + ], + [ + -0.10165969282388687, + 0.09860139340162277, + 3.8938212394714355, + 0.4259326159954071, + 1.9136905670166016, + 9.020977020263672, + 0.3810228407382965, + 0.48829886317253113, + -0.3380814492702484, + 1.3208613395690918, + -6.203188896179199, + 1.1227473020553589 + ], + [ + 0.4415844678878784, + 1.088304877281189, + 12.362054824829102, + 3.9100539684295654, + 0.18860556185245514, + -2.8436007499694824, + 1.4235721826553345, + -3.2842769622802734, + 0.11654028296470642, + 2.039667844772339, + 3.928046226501465, + -0.12851443886756897 + ], + [ + 2.1602132320404053, + 8.6521635055542, + -0.06693733483552933, + -0.15057474374771118, + 30.01460075378418, + 0.7048073410987854, + 0.1213168352842331, + -5.723662853240967, + -0.04095809534192085, + 1.197829246520996, + 0.1722598373889923, + 2.182297706604004 + ], + [ + 0.22440890967845917, + 0.11963432282209396, + 0.0697663426399231, + 1.382420539855957, + 0.258863240480423, + 0.40013203024864197, + -0.019332554191350937, + 0.44331759214401245, + 1.7227057218551636, + 2.4907219409942627, + 0.13799011707305908, + 2.6035876274108887 + ] + ] + }, + { + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
Logit diff variation: %{z}", + "name": "4", + "type": "heatmap", + "xaxis": "x5", + "yaxis": "y5", + "z": [ + [ + -0.14063450694084167, + -0.0024597488809376955, + 0.02696194499731064, + -0.026294894516468048, + -0.01753390021622181, + -0.0035913523752242327, + 0.059167977422475815, + 0.06137460470199585, + 0.08738064020872116, + 0.11621866375207901, + 0.3217953145503998, + 0.016241490840911865 + ], + [ + -0.011226698756217957, + -0.032971352338790894, + 0.010190388187766075, + -0.03907010331749916, + -0.007051676977425814, + -0.09838996082544327, + -0.0021619584877043962, + -0.05574636906385422, + 0.08942348510026932, + -0.03802783787250519, + -0.0025431301910430193, + 0.012477419339120388 + ], + [ + -0.03605646640062332, + 0.0768001452088356, + -0.05713406577706337, + 0.03060690127313137, + 0.04522543027997017, + -0.018701238557696342, + 0.16880249977111816, + -0.015264736488461494, + -0.004937365185469389, + 0.045561932027339935, + 0.05202100798487663, + -0.09795518219470978 + ], + [ + -0.018647635355591774, + -0.12508389353752136, + -0.05032062903046608, + 0.050287868827581406, + -0.03150622546672821, + -0.04686030000448227, + 0.003978480119258165, + 0.0062952893786132336, + -0.03313811868429184, + -0.02182803675532341, + 0.17301325500011444, + -0.038611505180597305 + ], + [ + -0.027349073439836502, + 0.06411129981279373, + 0.09419706463813782, + -0.016944274306297302, + 0.0263604074716568, + 0.01538980845361948, + 0.2362430989742279, + -0.03846260905265808, + -0.12372001260519028, + -0.03351929038763046, + 0.14378511905670166, + -0.008784817531704903 + ], + [ + -0.12869906425476074, + -0.0003394810773897916, + 0.007403070107102394, + -0.0585336834192276, + -0.04071986302733421, + 0.1420162469148636, + 0.025204982608556747, + -0.03801592439413071, + 0.017343314364552498, + 0.026836872100830078, + -0.3180401623249054, + -0.14100971817970276 + ], + [ + 0.13171270489692688, + 0.0009052828536368906, + -0.12039666622877121, + 0.3311399519443512, + 0.5857537388801575, + -0.36949241161346436, + -0.050713710486888885, + 0.24135616421699524, + 0.0426793247461319, + -0.012578667141497135, + 0.11260050535202026, + 0.17679816484451294 + ], + [ + 0.01577693596482277, + -0.026396144181489944, + -0.01911814510822296, + 0.023013243451714516, + -0.013990193605422974, + 0.20152370631694794, + -0.4022254943847656, + -0.07430466264486313, + 0.08478986471891403, + 0.07110044360160828, + 0.181595578789711, + -0.0047170002944767475 + ], + [ + 0.08578746020793915, + 0.011679340153932571, + -0.12774017453193665, + 0.5820224285125732, + 0.10111772269010544, + 0.14183460175991058, + -0.020398642867803574, + 0.029421694576740265, + 0.046351078897714615, + -0.07420341670513153, + -1.0696929693222046, + 0.00970796775072813 + ], + [ + -0.35632410645484924, + -0.06625837087631226, + 0.26142722368240356, + 0.017122948542237282, + -0.08171964436769485, + -0.2943717837333679, + 0.284410685300827, + -0.22675548493862152, + 0.04146731644868851, + 0.22240476310253143, + -0.02038077637553215, + -0.004133331123739481 + ], + [ + 0.4102718234062195, + 1.1754621267318726, + -0.061469897627830505, + -0.05417998880147934, + -0.08996844291687012, + 0.05145223066210747, + -0.015145620331168175, + -1.1293702125549316, + 0.03253062441945076, + -0.05925434082746506, + 0.5251951217651367, + 0.13170674443244934 + ], + [ + 0.22043636441230774, + 0.14476187527179718, + 0.11746937781572342, + -0.13005103170871735, + -0.05740803852677345, + -0.012197495438158512, + 0.08416152745485306, + -0.21202678978443146, + 0.07903655618429184, + 0.20578508079051971, + 0.08860158175230026, + 0.0277481097728014 + ] + ] + } + ], + "layout": { + "annotations": [ + { + "font": {}, + "showarrow": false, + "text": "Output", + "x": 0.09200000000000001, + "xanchor": "center", + "xref": "paper", + "y": 1, + "yanchor": "bottom", + "yref": "paper" + }, + { + "font": {}, + "showarrow": false, + "text": "Query", + "x": 0.29600000000000004, + "xanchor": "center", + "xref": "paper", + "y": 1, + "yanchor": "bottom", + "yref": "paper" + }, + { + "font": {}, + "showarrow": false, + "text": "Key", + "x": 0.5, + "xanchor": "center", + "xref": "paper", + "y": 1, + "yanchor": "bottom", + "yref": "paper" + }, + { + "font": {}, + "showarrow": false, + "text": "Value", + "x": 0.7040000000000002, + "xanchor": "center", + "xref": "paper", + "y": 1, + "yanchor": "bottom", + "yref": "paper" + }, + { + "font": {}, + "showarrow": false, + "text": "Pattern", + "x": 0.908, + "xanchor": "center", + "xref": "paper", + "y": 1, + "yanchor": "bottom", + "yref": "paper" + } + ], + "coloraxis": { + "cmid": 0, + "colorbar": { + "ticksuffix": "%", + "title": { + "text": "Logit diff variation" + } + }, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "margin": { + "l": 100, + "r": 100 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Patching output of attention heads (corrupted -> clean)" + }, + "width": 1500, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 0.18400000000000002 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "scaleanchor": "y", + "showline": true, + "title": { + "text": "Head" + } + }, + "xaxis2": { + "anchor": "y2", + "domain": [ + 0.20400000000000001, + 0.388 + ], + "linecolor": "black", + "linewidth": 1, + "matches": "x", + "mirror": true, + "showline": true, + "title": { + "text": "Head" + } + }, + "xaxis3": { + "anchor": "y3", + "domain": [ + 0.40800000000000003, + 0.5920000000000001 + ], + "linecolor": "black", + "linewidth": 1, + "matches": "x", + "mirror": true, + "showline": true, + "title": { + "text": "Head" + } + }, + "xaxis4": { + "anchor": "y4", + "domain": [ + 0.6120000000000001, + 0.7960000000000002 + ], + "linecolor": "black", + "linewidth": 1, + "matches": "x", + "mirror": true, + "showline": true, + "title": { + "text": "Head" + } + }, + "xaxis5": { + "anchor": "y5", + "domain": [ + 0.8160000000000001, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "matches": "x", + "mirror": true, + "showline": true, + "title": { + "text": "Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "showline": true, + "title": { + "text": "Layer" + } + }, + "yaxis2": { + "anchor": "x2", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "matches": "y", + "mirror": true, + "showline": true, + "showticklabels": false + }, + "yaxis3": { + "anchor": "x3", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "matches": "y", + "mirror": true, + "showline": true, + "showticklabels": false + }, + "yaxis4": { + "anchor": "x4", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "matches": "y", + "mirror": true, + "showline": true, + "showticklabels": false + }, + "yaxis5": { + "anchor": "x5", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "matches": "y", + "mirror": true, + "showline": true, + "showticklabels": false + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "assert results.keys() == {\"z\", \"q\", \"k\", \"v\", \"pattern\"}\n", + "assert all([r.shape == (12, 12) for r in results.values()])\n", + "\n", + "imshow_p(\n", + " torch.stack(tuple(results.values())) * 100,\n", + " facet_col=0,\n", + " facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n", + " title=\"Patching output of attention heads (corrupted -> clean)\",\n", + " labels={\"x\": \"Head\", \"y\": \"Layer\", \"color\": \"Logit diff variation\"},\n", + " coloraxis=dict(colorbar_ticksuffix = \"%\"),\n", + " border=True,\n", + " width=1500,\n", + " margin={\"r\": 100, \"l\": 100}\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "10d4eeae", + "metadata": {}, + "source": [ + "#### Residual Stream & Layer Outputs" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "289e90b9", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5d162378d096405788e621e4cb8fc144", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/228 [00:00Layer: %{y}
Logit diff variation: %{z}", + "name": "0", + "type": "heatmap", + "x": [ + "<|endoftext|> 0", + "I 1", + " thought 2", + " this 3", + " movie 4", + " was 5", + " perfect 6", + ", 7", + " I 8", + " enjoyed 9", + " it 10", + ". 11", + " 12", + "\n 13", + "Conclusion 14", + ": 15", + " This 16", + " movie 17", + " is 18" + ], + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + 0, + 0, + 0, + 0, + 0, + 0, + 77.5570297241211, + 0, + 0, + 22.381868362426758, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 76.68247985839844, + 0.07959342002868652, + 0.004591928329318762, + 21.26034164428711, + 0.02425205148756504, + 0.16954101622104645, + -0.022649940103292465, + 0.21061226725578308, + 0.7519029378890991, + 0.03731909766793251, + 0.08272916078567505, + 0.20496021211147308, + 0.1951599270105362 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 76.79911041259766, + 0.05866173282265663, + -0.008028429932892323, + 21.252254486083984, + 0.021988844498991966, + 0.15230490267276764, + -0.07903952896595001, + 0.2978529632091522, + 0.5368863344192505, + 0.07450418919324875, + 0.12928272783756256, + -0.08395903557538986, + 0.5783745050430298 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 76.34866333007812, + 0.03513331338763237, + 0.019642256200313568, + 19.5383358001709, + 0.29270118474960327, + 0.22370612621307373, + -0.12388677895069122, + 0.48620834946632385, + 0.3005777597427368, + 0.4718608260154724, + 0.45084574818611145, + 0.04873638227581978, + 1.1573266983032227 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 75.78780364990234, + 0.15951143205165863, + -0.005705664400011301, + 15.774530410766602, + 1.0175796747207642, + 0.8160231709480286, + -0.07753271609544754, + 0.7346847057342529, + 0.3590101897716522, + 0.6689444780349731, + 0.3919309079647064, + 0.15402911603450775, + 1.4761561155319214 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 75.1021957397461, + 0.5611235499382019, + -0.061047036200761795, + 12.449518203735352, + 1.441415786743164, + 2.05889630317688, + -0.03675329312682152, + 0.6397371888160706, + 0.1405928134918213, + 0.6875028014183044, + 0.619207501411438, + 0.2772309482097626, + 2.6521542072296143 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 73.67809295654297, + 1.273473858833313, + 0.11838360130786896, + 10.870401382446289, + 2.4883129596710205, + 2.4590728282928467, + -0.009165988303720951, + 0.6481884717941284, + -0.17122353613376617, + 0.29423776268959045, + 0.5014165043830872, + 0.60455322265625, + 4.389368534088135 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 64.56055450439453, + 1.6691628694534302, + 0.28693297505378723, + 7.733194351196289, + 2.853076457977295, + 3.099557638168335, + -0.07635346055030823, + 0.7694815397262573, + 0.46634575724601746, + 2.561486005783081, + 1.417119026184082, + 3.155592918395996, + 9.936020851135254 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 49.828426361083984, + 0.7582399249076843, + 0.4623791575431824, + 7.454215049743652, + 2.0036709308624268, + 1.2147228717803955, + 0.059838004410266876, + 0.3577386140823364, + 0.08632050454616547, + 1.1394087076187134, + 3.253187656402588, + 12.648087501525879, + 19.922006607055664 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 50.04413604736328, + 0.8563589453697205, + 0.40553396940231323, + 6.299482345581055, + 1.670020580291748, + 1.2485578060150146, + 0.11272558569908142, + 0.20368865132331848, + 0.11295190453529358, + 0.8178516030311584, + 0.9781670570373535, + 6.798254489898682, + 30.042600631713867 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 37.30668640136719, + 0.9599482417106628, + 0.3977169692516327, + 5.570315361022949, + 1.1437265872955322, + 0.5409839749336243, + 0.07948026061058044, + 0.048947811126708984, + 0.06985865533351898, + 0.3580631911754608, + 0.4941355586051941, + 1.5961358547210693, + 51.67811584472656 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 4.039702415466309, + 0.6095114350318909, + 0.26986661553382874, + 0.5296500325202942, + 0.36603206396102905, + 0.35008537769317627, + 0.08724961429834366, + 0.04926048964262009, + 0.03646741434931755, + 0.4526176154613495, + 0.44532474875450134, + 1.4484466314315796, + 91.33246612548828 + ] + ] + } + ], + "layout": { + "annotations": [ + { + "font": {}, + "showarrow": false, + "text": "resid_pre", + "x": 0.5, + "xanchor": "center", + "xref": "paper", + "y": 1, + "yanchor": "bottom", + "yref": "paper" + } + ], + "coloraxis": { + "cmax": 50, + "cmid": 0, + "cmin": -50, + "colorbar": { + "ticksuffix": "%", + "title": { + "text": "Logit diff variation" + } + }, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "margin": { + "l": 100, + "r": 100 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Patching at resid stream (corrupted -> clean)" + }, + "width": 600, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "scaleanchor": "y", + "showline": true, + "tickangle": 45, + "title": { + "text": "Sequence position" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "showline": true, + "title": { + "text": "Layer" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from utils.visualization import imshow_p\n", + "assert results.keys() == {\"resid_pre\"}#, \"attn_out\", \"mlp_out\"}\n", + "labels = [f\"{tok} {i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]\n", + "imshow_p(\n", + " torch.stack([r.T for r in results.values()]) * 100, # we transpose so layer is on the y-axis\n", + " save_path=f\"data/images/{model.cfg.model_name}-{TASK}-act_patching_at_resid_stream.pdf\",\n", + " facet_col=0,\n", + " facet_labels=[\"resid_pre\"],#, \"attn_out\", \"mlp_out\"],\n", + " title=\"Patching at resid stream (corrupted -> clean)\",\n", + " labels={\"x\": \"Sequence position\", \"y\": \"Layer\", \"color\": \"Logit diff variation\"},\n", + " x=labels,\n", + " xaxis_tickangle=45,\n", + " coloraxis=dict(colorbar_ticksuffix = \"%\"),\n", + " border=True,\n", + " width=600,\n", + " zmin=-50,\n", + " zmax=50,\n", + " margin={\"r\": 100, \"l\": 100}\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a0fb4e39", + "metadata": {}, + "source": [ + "### Circuit Analysis With Patch Patching & Attn Visualization" + ] + }, + { + "cell_type": "markdown", + "id": "5ce20b66", + "metadata": {}, + "source": [ + "#### Heads Influencing Logit Diff" + ] + }, + { + "cell_type": "markdown", + "id": "ac3f4e6a", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "c9244261", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b779ab0222e44b0c822d3e92d4677285", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/144 [00:00Layer: %{y}
Logit diff variation: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + -0.0007157690706662834, + -3.573485116703523e-7, + -0.0011894047493115067, + 0.0000010124874734174227, + -0.00005145818431628868, + -0.0000011911616866200347, + -0.0015604813816025853, + -0.00003877231210935861, + 0.00013966370897833258, + 0.00025216891663149, + 0.00004591928154695779, + -0.00029076257487758994 + ], + [ + -0.0001762919273460284, + -0.00037497770972549915, + 0.00018463005835656077, + 0.000017807866242947057, + -0.0005345933604985476, + -0.0008914058562368155, + -0.001125766895711422, + -0.0030335313640534878, + -0.0002229854726465419, + 0.0010292232036590576, + -0.00007903357618488371, + 0.0014531576307490468 + ], + [ + -0.0002946933964267373, + -0.007101914379745722, + -0.00007093367457855493, + -0.0002232832630397752, + 0.0003209585265722126, + -0.00027456277166493237, + 0.0006283377879299223, + -0.00009862818842520937, + -0.0001343630428891629, + 0.0004736058763228357, + 0.0004264954477548599, + 0.000050147908041253686 + ], + [ + -0.00005526990207727067, + -0.00032548492890782654, + 0.000060749243857571855, + -0.00018993072444573045, + -0.0005049334140494466, + -0.003343084594234824, + -0.0007884299266152084, + 0.000880208914168179, + -0.0001340056915068999, + -0.0004902523942291737, + 0.0019501103088259697, + -0.0006515058921650052 + ], + [ + -0.0002072025672532618, + -0.0006151159177534282, + -0.00426179775968194, + -0.0006773838540539145, + 0.001716761733405292, + -0.00026944075943902135, + -0.0005243195919319987, + 0.0010350599186494946, + -0.0041090017184615135, + -0.0018389452015981078, + -0.00022655895736534148, + -0.0006305712158791721 + ], + [ + 0.000018463006199453957, + -0.0011811857111752033, + -0.00016408252122346312, + -0.002718707313761115, + -0.0001687876065261662, + -0.004968901164829731, + -0.00036586530040949583, + -0.0008713050046935678, + -0.0013075977331027389, + -0.0028924383223056793, + -0.004494818858802319, + 0.004149768967181444 + ], + [ + -0.0005766413523815572, + -0.0006771754124201834, + -0.00006509698869194835, + -0.004977060481905937, + -0.03924282267689705, + 0.005858192685991526, + 0.01418369822204113, + -0.01260767225176096, + -0.000037998059269739315, + -0.003518423531204462, + -0.0057170698419213295, + -0.0007825336651876569 + ], + [ + -0.000013877033779863268, + -0.0481199249625206, + -0.0015227512922137976, + 0.0030029781628400087, + -0.00013025352382101119, + -0.035314518958330154, + 0.008920609951019287, + 0.00020333129214122891, + -0.0032958847004920244, + -0.013623971492052078, + -0.009546416811645031, + -0.0004524329851847142 + ], + [ + -0.0008567430195398629, + -0.0006085942732170224, + -0.027136806398630142, + -0.00671839015558362, + -0.01092024240642786, + -0.07323157787322998, + -0.0024589449167251587, + -0.004399198107421398, + -0.004958240315318108, + -0.011074736714363098, + 0.04941051825881004, + -0.008584225550293922 + ], + [ + -0.003471223870292306, + -0.009412023238837719, + -0.14556105434894562, + -0.03468424454331398, + -0.00046112845302559435, + 0.029237421229481697, + -0.023402277380228043, + 0.02771475911140442, + -0.0019462688360363245, + -0.0262917373329401, + -0.0334610715508461, + 0.0008673443808220327 + ], + [ + -0.027836615219712257, + -0.09413765370845795, + 0.0011034326162189245, + 0.0023843483068048954, + -0.28532126545906067, + -0.006456781178712845, + -0.0014429136645048857, + 0.07169155776500702, + 0.00003490103699732572, + -0.009984585456550121, + -0.005747087299823761, + -0.022625520825386047 + ], + [ + -0.003022334538400173, + -0.0028683471027761698, + -0.0031714679207652807, + -0.013026573695242405, + -0.0020223245956003666, + -0.003956890199333429, + -0.0007881916826590896, + -0.0037273236084729433, + -0.01088197622448206, + -0.02766377665102482, + 0.001457862788811326, + -0.017311837524175644 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorbar": { + "title": { + "text": "Logit diff variation" + } + }, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "margin": { + "l": 100, + "r": 100 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Direct effect on logit diff (patch from head output -> final resid)" + }, + "width": 600, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "scaleanchor": "y", + "showline": true, + "title": { + "text": "Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "showline": true, + "title": { + "text": "Layer" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" + ] + } + ], + "source": [ + "imshow_p(\n", + " results['z'],\n", + " save_path=f\"data/images/{model.cfg.model_name}-{TASK}-path_patching_attn_heads.pdf\",\n", + " title=\"Direct effect on logit diff (patch from head output -> final resid)\",\n", + " labels={\"x\": \"Head\", \"y\": \"Layer\", \"color\": \"Logit diff variation\"},\n", + " border=True,\n", + " width=600,\n", + " margin={\"r\": 100, \"l\": 100}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "ea933314", + "metadata": {}, + "outputs": [], + "source": [ + "from utils.visualization import (\n", + " plot_attention_heads,\n", + " scatter_attention_and_contribution\n", + ")\n", + "\n", + "import circuitsvis as cv" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "fbcabf4d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total logit diff contribution above threshold: 0.94\n" + ] + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "alignmentgroup": "True", + "hovertemplate": "Logit Diff=%{x}
Attention Head=%{y}", + "legendgroup": "", + "marker": { + "color": "#636efa", + "pattern": { + "shape": "" + } + }, + "name": "", + "offsetgroup": "", + "orientation": "h", + "showlegend": false, + "textposition": "auto", + "type": "bar", + "x": [ + 0.022625520825386047, + 0.023402277380228043, + 0.0262917373329401, + 0.027136806398630142, + 0.02766377665102482, + 0.027836615219712257, + 0.0334610715508461, + 0.03468424454331398, + 0.035314518958330154, + 0.03924282267689705, + 0.0481199249625206, + 0.07323157787322998, + 0.09413765370845795, + 0.14556105434894562, + 0.28532126545906067 + ], + "xaxis": "x", + "y": [ + "Layer 10, Head 11", + "Layer 9, Head 6", + "Layer 9, Head 9", + "Layer 8, Head 2", + "Layer 11, Head 9", + "Layer 10, Head 0", + "Layer 9, Head 10", + "Layer 9, Head 3", + "Layer 7, Head 5", + "Layer 6, Head 4", + "Layer 7, Head 1", + "Layer 8, Head 5", + "Layer 10, Head 1", + "Layer 9, Head 2", + "Layer 10, Head 4" + ], + "yaxis": "y" + } + ], + "layout": { + "barmode": "relative", + "legend": { + "tracegroupgap": 0 + }, + "margin": { + "t": 60 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "xaxis": { + "anchor": "y", + "domain": [ + 0, + 1 + ], + "range": [ + 0, + 0.5 + ], + "title": { + "text": "Logit Diff" + } + }, + "yaxis": { + "anchor": "x", + "domain": [ + 0, + 1 + ], + "title": { + "text": "Attention Head" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plot_attention_heads(-results['z'].cuda(), top_n=15, range_x=[0, 0.5])" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "26a5ffbe", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from utils.visualization import get_attn_head_patterns\n", + "\n", + "top_k = 5\n", + "top_heads = torch.topk(-results['z'].flatten(), k=top_k).indices.cpu().numpy()\n", + "heads = [(head // model.cfg.n_heads, head % model.cfg.n_heads) for head in top_heads]\n", + "tokens, attn, names = get_attn_head_patterns(model, all_prompts[1], heads)\n", + "cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "177acf22", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "from visualization import scatter_attention_and_contribution_sentiment\n", + "\n", + "from plotly.subplots import make_subplots\n", + "\n", + "# Get the figures\n", + "fig1 = scatter_attention_and_contribution_sentiment(model, (7, 1), all_prompts, [6 for _ in range(len(all_prompts))], answer_residual_directions, return_fig=True)\n", + "fig2 = scatter_attention_and_contribution_sentiment(model, (9, 2), all_prompts, [6 for _ in range(len(all_prompts))], answer_residual_directions, return_fig=True)\n", + "fig3 = scatter_attention_and_contribution_sentiment(model, (10, 1), all_prompts, [6 for _ in range(len(all_prompts))], answer_residual_directions, return_fig=True)\n", + "fig4 = scatter_attention_and_contribution_sentiment(model, (10, 4), all_prompts, [6 for _ in range(len(all_prompts))], answer_residual_directions, return_fig=True)\n", + "\n", + "# Create subplot\n", + "fig = make_subplots(rows=2, cols=2, subplot_titles=(\"Head 7.1\", \"Head 9.2\", \"Head 10.1\", \"Head 10.4\"))\n", + "\n", + "# Add each figure's data to the subplot\n", + "for i, subplot_fig in enumerate([fig1, fig2, fig3, fig4], start=1):\n", + " row = (i-1)//2 + 1\n", + " col = (i-1)%2 + 1\n", + " for trace in subplot_fig['data']:\n", + " # Only show legend for the first subplot\n", + " trace.showlegend = (i == 1)\n", + " fig.add_trace(trace, row=row, col=col)\n", + "\n", + "# Update layout\n", + "fig.update_layout(height=600, title_text=\"DAE Heads\")\n", + "\n", + "# Update axes labels\n", + "for i in range(1, 3):\n", + " for j in range(1, 3):\n", + " fig.update_xaxes(title_text=\"Attn Prob on Word\", row=i, col=j)\n", + " fig.update_yaxes(title_text=\"Dot w Sentiment Output Embed\", row=i, col=j)\n", + "\n", + "fig.show()" + ] + }, + { + "cell_type": "markdown", + "id": "e0c9ce5b", + "metadata": {}, + "source": [ + "#### Direct Attribute Extraction Heads" + ] + }, + { + "cell_type": "markdown", + "id": "b5ab4d39", + "metadata": {}, + "source": [ + "##### Overall" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "f2ebdf04", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "928d1aa380d649e593be7f8646d0c0f3", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/144 [00:00Layer: %{y}
Logit diff variation: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + -0.043900266289711, + -0.8076581954956055, + -0.08461714535951614, + -1.1449445486068726, + -1.7893661260604858, + -0.5216960906982422, + -0.2742113769054413, + -0.024871455505490303, + -0.0024537930730730295, + 0.003335252869874239, + -0.13965179026126862, + -0.08820850402116776 + ], + [ + 0.025848209857940674, + 0.054328881204128265, + 0.0033412084449082613, + 0.013603066094219685, + 0.03332870081067085, + 0.44592925906181335, + -0.16542555391788483, + -0.11087332665920258, + 0.00933870766311884, + 0.0032101809047162533, + 0.017152728512883186, + -0.4313434362411499 + ], + [ + 0.016676263883709908, + -0.4144915044307709, + -0.003793849842622876, + -0.05691668391227722, + -0.030874911695718765, + 0.018337933346629143, + 0.027980387210845947, + 0.11327947676181793, + 0.03005300834774971, + 0.03030315227806568, + 0.019124099984765053, + 0.024746384471654892 + ], + [ + -0.09031685441732407, + 0.03968950733542442, + -0.04892994090914726, + 0.0020726213697344065, + 0.017462430521845818, + -0.01058347150683403, + -0.018409403041005135, + -0.010172520764172077, + 0.04252447187900543, + -0.06567171961069107, + -0.060076236724853516, + -1.0905025005340576 + ], + [ + -0.04570785164833069, + -0.03495464101433754, + 0.03163725510239601, + 0.06805702298879623, + 0.37976616621017456, + -0.04359056055545807, + 0.059397272765636444, + -0.22010880708694458, + -0.05012408271431923, + -0.6134363412857056, + -0.046410635113716125, + -0.05093109607696533 + ], + [ + 0.052929267287254333, + 0.0006194040761329234, + -0.1372813880443573, + -0.014955035410821438, + -0.13389253616333008, + -0.20677971839904785, + -0.00807607639580965, + -0.5959084033966064, + 0.20474281907081604, + 0.14360645413398743, + -0.08353616297245026, + 1.3018503189086914 + ], + [ + -0.11434556543827057, + 0.000744476041290909, + 0.1754104644060135, + 0.4612357020378113, + -3.0545434951782227, + 0.6577892899513245, + 0.4239642024040222, + -1.5858709812164307, + -0.2600722908973694, + -0.04261678457260132, + -0.18388855457305908, + -0.27552464604377747 + ], + [ + -0.032143499702215195, + -1.2332544326782227, + -0.010756189934909344, + 0.29170358180999756, + 0.0029600367415696383, + -0.6928094029426575, + 0.2500724792480469, + 0.0006134482682682574, + -0.32403767108917236, + 0.005711620207875967, + -0.04009747877717018, + 0.005592504050582647 + ], + [ + 0.11791309714317322, + 0.00933870766311884, + -0.6571072936058044, + -0.1169363409280777, + -0.45855557918548584, + -0.6660350561141968, + 0.007814020849764347, + -0.05507931485772133, + 0.01670008711516857, + -0.4150930345058441, + 0.699801504611969, + -0.13125112652778625 + ], + [ + 0.16058646142482758, + -0.0642959326505661, + -1.9615365266799927, + -0.6590429544448853, + -0.041565585881471634, + 0.6797840595245361, + -0.24200831353664398, + 0.8228247165679932, + -0.026265116408467293, + -0.151742085814476, + -0.4979711174964905, + -0.013442259281873703 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorbar": { + "ticksuffix": "%", + "title": { + "text": "Logit diff variation" + } + }, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "margin": { + "l": 100, + "r": 100 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Direct effect on Sentiment Attenders' values" + }, + "width": 700, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "scaleanchor": "y", + "showline": true, + "title": { + "text": "Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "showline": true, + "title": { + "text": "Layer" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "imshow_p(\n", + " results[\"z\"][:10] * 100,\n", + " title=f\"Direct effect on Sentiment Attenders' values\",\n", + " labels={\"x\": \"Head\", \"y\": \"Layer\", \"color\": \"Logit diff variation\"},\n", + " coloraxis=dict(colorbar_ticksuffix = \"%\"),\n", + " border=True,\n", + " width=700,\n", + " margin={\"r\": 100, \"l\": 100}\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "8056cf28", + "metadata": {}, + "source": [ + "##### Key Positions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "45313431", + "metadata": {}, + "outputs": [], + "source": [ + "results = path_patch(\n", + " model,\n", + " orig_input=clean_tokens,\n", + " new_input=corrupted_tokens,\n", + " sender_nodes=IterNode(\"z\", seq_pos=\"each\"),\n", + " receiver_nodes=[Node(\"v\", layer, head=head) for layer, head in DAE_HEADS],\n", + " patching_metric=logit_diff_noising,\n", + " verbose=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e3d66695", + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(10, results[\"z\"].shape[0]):\n", + " imshow_p(\n", + " results[\"z\"][i][:10] * 100,\n", + " title=f\"Direct effect on Sentiment Attenders' values from position {i} ({tokens[i]})\",\n", + " labels={\"x\": \"Head\", \"y\": \"Layer\", \"color\": \"Logit diff variation\"},\n", + " coloraxis=dict(colorbar_ticksuffix = \"%\"),\n", + " border=True,\n", + " width=700,\n", + " margin={\"r\": 100, \"l\": 100}\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7e217fab", + "metadata": {}, + "outputs": [], + "source": [ + "plot_attention_heads(-results['z'].cuda(), top_n=15, range_x=[0, 0.1])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7ed71842", + "metadata": {}, + "outputs": [], + "source": [ + "top_k = 4\n", + "top_heads = torch.topk(-results['z'].flatten(), k=top_k).indices.cpu().numpy()\n", + "heads = [(head // model.cfg.n_heads, head % model.cfg.n_heads) for head in top_heads]\n", + "tokens, attn, names = get_attn_head_patterns(model, all_prompts[0], heads)\n", + "cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)" + ] + }, + { + "cell_type": "markdown", + "id": "2468658e", + "metadata": {}, + "source": [ + "#### Intermediate Attribute Extraction Heads" + ] + }, + { + "cell_type": "markdown", + "id": "167a05ee", + "metadata": {}, + "source": [ + "##### Overall" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "c045c04b", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7b0d1a3e10a94888ab42806966fbfec1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/144 [00:00Layer: %{y}
Logit diff variation: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + -0.004216712433844805, + 0.0004466856480576098, + 0.0006015366525389254, + -0.00004764646655530669, + -0.004788469988852739, + 0.000005955808319413336, + -0.031095273792743683, + -0.005050525534898043, + -0.008647833950817585, + 0.00303150643594563, + -0.019326597452163696, + -0.00889202207326889 + ], + [ + 0.004901630338281393, + 0.007712772116065025, + 0.00014889521116856486, + -0.0000416906586906407, + -0.023436104878783226, + -0.0008338132174685597, + -0.01326954085379839, + 0.017492210492491722, + -0.010023625567555428, + 0.003746203612536192, + 0.0018701237859204412, + 0.012137937359511852 + ], + [ + -0.007677036803215742, + -0.149583101272583, + 0.0014770404668524861, + 0.011679340153932571, + 0.002662246348336339, + -0.012834767811000347, + -0.006384626496583223, + 0.010130830109119415, + -0.0025193069595843554, + 0.011107582598924637, + 0.014413056895136833, + -0.004151198081672192 + ], + [ + -0.011036112904548645, + 0.021202677860856056, + 0.009106430225074291, + 0.0007921225042082369, + 0.0235194880515337, + 0.010607294738292694, + -0.014520260505378246, + 0.004526414442807436, + -0.006122571416199207, + -0.011500665917992592, + 0.0243771243840456, + -0.026890475302934647 + ], + [ + -0.00545552046969533, + 0.005693752784281969, + -0.09458121657371521, + -0.0055210343562066555, + 0.04214330017566681, + -0.0019594610203057528, + 0.004603839945048094, + -0.007581744343042374, + -0.030261460691690445, + -0.08964980393648148, + -0.0032280480954796076, + 0.006956384517252445 + ], + [ + -0.003239959944039583, + 0.006217863876372576, + -0.003472236217930913, + -0.03126204013824463, + -0.0038117174990475178, + -0.36781883239746094, + 0.0036985569167882204, + -0.02883206680417061, + 0.04476981237530708, + 0.22402773797512054, + -0.18110719323158264, + 0.159371480345726 + ], + [ + -0.008230927400290966, + -0.008624010719358921, + 0.03735483065247536, + -0.024948880076408386, + -1.8061554431915283, + 0.177620068192482, + 0.21879258751869202, + -0.1933731883764267, + -0.015276649035513401, + -0.08716028183698654, + -0.0988009050488472, + -0.007992695085704327 + ], + [ + -0.032619964331388474, + -6.772891521453857, + -0.015348118729889393, + 0.5428838729858398, + -0.004979055840522051, + -4.109564304351807, + 2.6856706142425537, + -0.021845905110239983, + -0.036133889108896255, + -0.3091183602809906, + -0.019362332299351692, + -0.00985686294734478 + ], + [ + -0.006366759072989225, + -0.011732942424714565, + -0.15355563163757324, + -0.0243116095662117, + -0.27287429571151733, + -0.8576394319534302, + -0.15307322144508362, + -0.029117945581674576, + 0.21484386920928955, + -0.029850512742996216, + 1.0204145908355713, + -0.08999822288751602 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorbar": { + "ticksuffix": "%", + "title": { + "text": "Logit diff variation" + } + }, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "margin": { + "l": 100, + "r": 100 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Direct effect on Intermediate AE Heads' values" + }, + "width": 700, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "scaleanchor": "y", + "showline": true, + "title": { + "text": "Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "showline": true, + "title": { + "text": "Layer" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "imshow_p(\n", + " results[\"z\"][:10] * 100,\n", + " title=f\"Direct effect on Intermediate AE Heads' values\",\n", + " labels={\"x\": \"Head\", \"y\": \"Layer\", \"color\": \"Logit diff variation\"},\n", + " coloraxis=dict(colorbar_ticksuffix = \"%\"),\n", + " border=True,\n", + " width=700,\n", + " margin={\"r\": 100, \"l\": 100}\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "9f1163d8", + "metadata": {}, + "source": [ + "##### By Position" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "a229a9dd", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "17b8f120b50a4cf1bc45cc834ce07f87", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/2736 [00:00Layer: %{y}
Logit diff variation: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + -0.0020368865225464106, + 0.000005955808319413336, + 0.00004764646655530669, + -0.0007921225042082369, + -0.0029302577022463083, + 0, + -0.004907586146146059, + -0.0018641679780557752, + 0.0005955808446742594, + 0.0021262236405164003, + -0.0012983662309125066, + 0.0007385202334262431 + ], + [ + -0.0021619584877043962, + -0.002310853684321046, + 0.00027396719087846577, + 0.00035734850098378956, + -0.0013400568859651685, + -0.00041095077176578343, + -0.0009767526062205434, + -0.0013877033488824964, + 0.00036330430884845555, + 0.000023823233277653344, + 0.00008933712524594739, + 0.0017033611657097936 + ], + [ + 0.003102976130321622, + -0.005789045710116625, + -0.008230927400290966, + -0.0025193069595843554, + -0.002686069579795003, + -0.0003811717324424535, + 0.0037700266111642122, + -0.0002799229696393013, + -0.00895158015191555, + -0.009017093107104301, + -0.003388855140656233, + 0.0005717576132155955 + ], + [ + -0.000023823233277653344, + -0.000762343464884907, + -0.003978480119258165, + -0.007730639539659023, + -0.0008338132174685597, + -0.0003037462302017957, + 0.0004526414268184453, + 0.0005241111502982676, + -0.010339283384382725, + -0.005961764138191938, + 0.0059260292910039425, + -0.07328324019908905 + ], + [ + -0.010976554825901985, + -0.0038712755776941776, + -0.00041690660873427987, + -0.0011196918785572052, + 0.009993846528232098, + -0.0013936591567471623, + -0.0003513926931191236, + 0.002650334732607007, + -0.0034484132193028927, + -0.02569335699081421, + -0.0008933712961152196, + -0.009332751855254173 + ], + [ + 0.00009529293311061338, + 0.00039903915603645146, + -0.00230489787645638, + 0.0032876061741262674, + -0.013102779164910316, + -0.0035437061451375484, + -0.0020190190989524126, + -0.08545691519975662, + 0.007456671912223101, + 0.0039248778484761715, + 0.00378789403475821, + 0.012483375146985054 + ], + [ + -0.004478767514228821, + 0.0029123902786523104, + 0.008123721927404404, + 0.0042107561603188515, + -0.029713526368141174, + 0.014293940737843513, + 0.017212286591529846, + -0.035038020461797714, + -0.012590578757226467, + 0.0005836692289449275, + -0.003692601341754198, + -0.023001331835985184 + ], + [ + -0.017676837742328644, + -0.038784224539995193, + -0.0006194040761329234, + 0.03168490156531334, + 0.0047170002944767475, + -0.033221498131752014, + 0.020827462896704674, + 0.003912965767085552, + -0.030886821448802948, + -0.00877290591597557, + -0.0008040341781452298, + -0.00023823234369046986 + ], + [ + 0.0022691627964377403, + -0.0063191126100718975, + -0.006104703992605209, + 0.004198845010250807, + -0.040061745792627335, + -0.001971372403204441, + 0.0024299698416143656, + -0.002626511501148343, + 0.014520260505378246, + -0.010922952555119991, + 0.014043795876204967, + -0.008600186556577682 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorbar": { + "ticksuffix": "%", + "title": { + "text": "Logit diff variation" + } + }, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "margin": { + "l": 100, + "r": 100 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Direct effect on Intermediate AE Heads' values from position 10 ( it)" + }, + "width": 700, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "scaleanchor": "y", + "showline": true, + "title": { + "text": "Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "showline": true, + "title": { + "text": "Layer" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
Logit diff variation: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + -0.006700284779071808, + -0.000005955808319413336, + -0.0035496617201715708, + 0.0029362135101109743, + -0.0005538901896215975, + -0.0003335252695251256, + -0.024877412244677544, + -0.0017331402050331235, + -0.0005538901896215975, + 0.0018165216315537691, + -0.0064501408487558365, + -0.002340632723644376 + ], + [ + -0.007462628185749054, + 0.0021143120247870684, + 0.000059558085922617465, + 0.0006313156918622553, + -0.0060034547932446, + 0.0006074924604035914, + -0.0012209407286718488, + -0.016527367755770683, + 0.003430545562878251, + 0.0021381352562457323, + 0.0006432273075915873, + 0.005622283089905977 + ], + [ + -0.009749658405780792, + -0.023358680307865143, + 0.0022096047177910805, + 0.0032340039033442736, + -0.0025193069595843554, + -0.003359076101332903, + 0.0007921225042082369, + -0.0029064344707876444, + -0.009565028361976147, + -0.00916003342717886, + 0.0012566755758598447, + 0.0019832842517644167 + ], + [ + -0.00014889521116856486, + 0.004919497761875391, + -0.007105279713869095, + 0.0019773284438997507, + -0.0002918346144724637, + -0.011286256834864616, + 0.0012149849208071828, + -0.018302198499441147, + -0.0028766554314643145, + -0.006497786846011877, + 0.01338865701109171, + -0.10252626240253448 + ], + [ + -0.028522364795207977, + -0.0007325644255615771, + 0.0016735821263864636, + -0.01058347150683403, + 0.05964742228388786, + -0.0067419749684631824, + 0.006515654269605875, + 0.02274523302912712, + -0.023138314485549927, + -0.15918684005737305, + -0.001858212286606431, + -0.029117945581674576 + ], + [ + 0.0002977904223371297, + 0.0001250719651579857, + -0.018820354714989662, + 0.01718250662088394, + -0.03560382127761841, + -0.0082249715924263, + 0.0007325644255615771, + -0.03273908048868179, + 0.00527684623375535, + -0.009743702597916126, + -0.010690676048398018, + 0.05736634507775307 + ], + [ + -0.0036985569167882204, + 0.002686069579795003, + 0.002656290540471673, + -0.02951103076338768, + -0.16286753118038177, + 0.03656866401433945, + 0.08688928931951523, + -0.22858095169067383, + -0.0030612857080996037, + 0.00023823234369046986, + -0.012215363793075085, + -0.014025929383933544 + ], + [ + -0.005675885360687971, + -0.07935220748186111, + -0.0012507197679951787, + 0.114625483751297, + 0.015550616197288036, + -0.11176372319459915, + 0.18044313788414001, + 0.009154077619314194, + -0.03127394989132881, + 0.015383852645754814, + -0.0017569635529071093, + 0.0010243990691378713 + ], + [ + -0.002227472374215722, + -0.01297175046056509, + -0.02194119803607464, + 0.0027992299292236567, + -0.010285681113600731, + -0.10371146351099014, + -0.012167716398835182, + -0.0019475494045764208, + 0.016729865223169327, + -0.01597347855567932, + 0.06490640342235565, + -0.017444562166929245 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorbar": { + "ticksuffix": "%", + "title": { + "text": "Logit diff variation" + } + }, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "margin": { + "l": 100, + "r": 100 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Direct effect on Intermediate AE Heads' values from position 11 (.)" + }, + "width": 700, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "scaleanchor": "y", + "showline": true, + "title": { + "text": "Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "showline": true, + "title": { + "text": "Layer" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
Logit diff variation: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + -0.000762343464884907, + -0.00020845330436713994, + -0.0031923132482916117, + 0.00001786742541298736, + -0.0012864546151831746, + -0.000011911616638826672, + 0.0002680113830137998, + -0.0010065316455438733, + -0.0013757917331531644, + -0.004252447281032801, + 0.0008814596221782267, + 0.0012626313837245107 + ], + [ + -0.0011911616893485188, + 0.0006908737705089152, + -0.0002799229696393013, + -0.00006551389378728345, + -0.0019475494045764208, + -0.0009231503354385495, + -0.0009529293747618794, + -0.0018105657072737813, + 0.00041690660873427987, + -0.0008993271039798856, + 0.0006908737705089152, + 0.00381767307408154 + ], + [ + -0.00024418815155513585, + -0.005574636626988649, + -0.0010780013399198651, + -0.001524686929769814, + -0.00015485101903323084, + 0.0009350618929602206, + 0.0005658018053509295, + -0.0005479343817569315, + -0.0017152727814391255, + -0.0013043220387771726, + 0.000029779042961308733, + -0.0008159457356669009 + ], + [ + -0.0005538901896215975, + 0.0006074924604035914, + 0.0009529293747618794, + 0.00007742550951661542, + 0.0012864546151831746, + 0.00150086369831115, + 0.00035734850098378956, + -0.0003335252695251256, + -0.002739671850576997, + -0.0004228623874951154, + 0.0022393837571144104, + -0.002191737527027726 + ], + [ + -0.0005062437267042696, + -0.0009946200298145413, + 0.0007146970019675791, + -0.0007980783120729029, + 0.0004526414268184453, + -0.0005062437267042696, + 0.001858212286606431, + 0.0008874154300428927, + 0.0006074924604035914, + 0.002257251413539052, + -0.0003752159245777875, + -0.004645530600100756 + ], + [ + -0.00027396719087846577, + 0, + -0.0017688751686364412, + -0.0013698359252884984, + -0.000744476041290909, + -0.0010303547605872154, + 0.0014174823882058263, + -0.0015961566241458058, + 0.0008457247749902308, + -0.0000416906586906407, + 0.0030017273966223, + 0.0051160394214093685 + ], + [ + -0.00036330430884845555, + 0.0006253598839975893, + -0.004437077324837446, + 0.0036687778774648905, + -0.0011792500736191869, + 0.003966568503528833, + 0.002721804426982999, + -0.011065891943871975, + 0.001125647802837193, + 0.00001786742541298736, + -0.0017271845135837793, + -0.0016735821263864636 + ], + [ + -0.0043417843990027905, + -0.0017867425922304392, + -0.00018463005835656077, + 0.004812293220311403, + -0.00005360227441997267, + -0.006759842857718468, + 0.020958488807082176, + 0.0009707967401482165, + -0.001554465969093144, + 0.0001310277875745669, + -0.00007146970165194944, + -0.000023823233277653344 + ], + [ + 0.0003097020380664617, + -0.00001786742541298736, + -0.0001250719651579857, + -0.00011316035670461133, + -0.0003752159245777875, + 0.004073773045092821, + 0.0009410177008248866, + -0.0001727184426272288, + -0.00032756946166045964, + -0.0025193069595843554, + 0.0012685871915891767, + -0.000029779042961308733 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorbar": { + "ticksuffix": "%", + "title": { + "text": "Logit diff variation" + } + }, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "margin": { + "l": 100, + "r": 100 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Direct effect on Intermediate AE Heads' values from position 12 ( )" + }, + "width": 700, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "scaleanchor": "y", + "showline": true, + "title": { + "text": "Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "showline": true, + "title": { + "text": "Layer" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
Logit diff variation: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + -0.0060034547932446, + -0.00005360227441997267, + -0.00609279191121459, + 0.00019058586622122675, + -0.0005479343817569315, + 0.000005955808319413336, + -0.003525838488712907, + -0.0028647438157349825, + -0.0013460126938298345, + -0.00005360227441997267, + 0.0024657046888023615, + 0.002638423116877675 + ], + [ + 0.002322765300050378, + 0.005014790687710047, + 0.0011792500736191869, + 0.0014234381960704923, + -0.012715651653707027, + -0.004413254093378782, + -0.001911814440973103, + -0.025371745228767395, + -0.0013698359252884984, + 0.0016557147027924657, + 0.0024418814573436975, + 0.0038355407305061817 + ], + [ + -0.006057057064026594, + -0.057845789939165115, + -0.0009767526062205434, + -0.0010482223005965352, + 0.0002144090976798907, + -0.003525838488712907, + 0.0036687778774648905, + -0.004514502827078104, + -0.0003454368852544576, + -0.0032280480954796076, + -0.006604991387575865, + 0.0017331402050331235 + ], + [ + -0.00001786742541298736, + 0.0012388081522658467, + 0.0006074924604035914, + -0.0012983662309125066, + 0.0034245899878442287, + -0.008028429932892323, + -0.0013519685016945004, + 0.0008099899860098958, + -0.008201148360967636, + 0.0003513926931191236, + 0.014895477332174778, + -0.016110461205244064 + ], + [ + -0.002697981195524335, + -0.003960612695664167, + -0.0033412084449082613, + -0.0012209407286718488, + 0.010916996747255325, + -0.0024359256494790316, + 0.0029600367415696383, + 0.01898116245865822, + -0.010357150807976723, + -0.01052391342818737, + -0.02188163995742798, + -0.0015782893169671297 + ], + [ + 0.0008576363907195628, + 0.000029779042961308733, + -0.0039487010799348354, + -0.014073574915528297, + -0.004985011648386717, + -0.00007146970165194944, + -0.010637073777616024, + -0.0020904885604977608, + -0.0006730063469149172, + -0.016843024641275406, + -0.00892180111259222, + 0.0143356304615736 + ], + [ + -0.0001667626347625628, + 0.0002501439303159714, + 0.01692045107483864, + -0.029106035828590393, + -0.04865895211696625, + 0.03428759053349495, + 0.04844454675912857, + -0.14312700927257538, + -0.003043418051674962, + 0.00041690660873427987, + -0.008969447575509548, + -0.0024716604966670275 + ], + [ + -0.0038295849226415157, + -0.016038991510868073, + -0.0003454368852544576, + 0.017462430521845818, + -0.003085108706727624, + -0.06007326394319534, + 0.0915050357580185, + 0.0019594610203057528, + -0.000750431849155575, + -0.00227511883713305, + 0.000029779042961308733, + 0.0009231503354385495 + ], + [ + -0.003013639012351632, + -0.0003811717324424535, + -0.01995791308581829, + -0.0007206528098322451, + -0.0009291061433032155, + -0.013835342600941658, + -0.0014115265803411603, + -0.0010303547605872154, + -0.001923726056702435, + -0.0012030733050778508, + 0.02472851797938347, + -0.006152350455522537 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorbar": { + "ticksuffix": "%", + "title": { + "text": "Logit diff variation" + } + }, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "margin": { + "l": 100, + "r": 100 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Direct effect on Intermediate AE Heads' values from position 13 (\n)" + }, + "width": 700, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "scaleanchor": "y", + "showline": true, + "title": { + "text": "Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "showline": true, + "title": { + "text": "Layer" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
Logit diff variation: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + -0.01772448606789112, + 0.00005360227441997267, + -0.005640150513499975, + -0.0002560997672844678, + -0.0015604218933731318, + -0.000011911616638826672, + 0.004413254093378782, + -0.004633618984371424, + -0.0010124874534085393, + -0.00446090055629611, + 0.005354271735996008, + 0.000744476041290909 + ], + [ + 0.0007146970019675791, + 0.0010780013399198651, + 0.0009588851826265454, + 0.0020666655618697405, + 0.00218578171916306, + 0.0027932741213589907, + -0.009291061200201511, + -0.003305473830550909, + 0.005330448504537344, + 0.0014591730432584882, + 0.00035734850098378956, + -0.0029540809337049723 + ], + [ + -0.0016735821263864636, + -0.017629193142056465, + -0.0003335252695251256, + 0.00221556075848639, + 0.00019058586622122675, + -0.0031923132482916117, + 0.005348315928131342, + 0.0009350618929602206, + 0.0024657046888023615, + -0.0028051857370883226, + -0.00007742550951661542, + 0.0029719483572989702 + ], + [ + -0.00022632071340922266, + 0.0031923132482916117, + 0.0031744458246976137, + -0.000023823233277653344, + 0.0091004753485322, + -0.02139922045171261, + 0.0001667626347625628, + -0.0009707967401482165, + -0.011351770721375942, + -0.0004526414268184453, + 0.005699708592146635, + -0.020225925371050835 + ], + [ + -0.0012626313837245107, + -0.0002799229696393013, + -0.00463957479223609, + -0.0008993271039798856, + 0.012149848975241184, + 0.001095868763513863, + 0.0037581149954348803, + 0.0015604218933731318, + -0.032798636704683304, + -0.017819778993725777, + 0.005860515404492617, + -0.005634194705635309 + ], + [ + -0.0005538901896215975, + 0.000005955808319413336, + -0.009225546382367611, + -0.010440532118082047, + -0.004359651356935501, + -0.0006253598839975893, + 0.0032518713269382715, + -0.006491831503808498, + 0.0011494709178805351, + -0.0031267995946109295, + -0.013692403212189674, + 0.07854520529508591 + ], + [ + -0.004985011648386717, + -0.004961188416928053, + 0.01838558167219162, + -0.024895280599594116, + -0.05080602318048477, + 0.04863513261079788, + 0.023775586858391762, + -0.11539378762245178, + -0.00757578806951642, + -0.00005360227441997267, + -0.004889718722552061, + -0.006557344924658537 + ], + [ + -0.010130830109119415, + -0.046103913336992264, + -0.00035734850098378956, + 0.03429950028657913, + 0.0021500466391444206, + -0.0792420282959938, + 0.21386712789535522, + 0.0003752159245777875, + -0.002233428182080388, + -0.01157809142023325, + -0.0006849179626442492, + 0.0008219015435315669 + ], + [ + 0.003114887746050954, + 0.0004288181953597814, + -0.017337357625365257, + -0.0032340039033442736, + -0.006479919888079166, + -0.004621707368642092, + -0.005669929552823305, + 0.00007146970165194944, + 0.010982510633766651, + -0.011762721464037895, + 0.026157911866903305, + -0.01983284205198288 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorbar": { + "ticksuffix": "%", + "title": { + "text": "Logit diff variation" + } + }, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "margin": { + "l": 100, + "r": 100 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Direct effect on Intermediate AE Heads' values from position 14 (Conclusion)" + }, + "width": 700, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "scaleanchor": "y", + "showline": true, + "title": { + "text": "Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "showline": true, + "title": { + "text": "Layer" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
Logit diff variation: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + -0.0013936591567471623, + -0.000005955808319413336, + -0.010488178580999374, + 0.00008933712524594739, + -0.0027873183134943247, + -0.0001250719651579857, + -0.0007266086176969111, + -0.0029719483572989702, + -0.0002680113830137998, + -0.0012328523444011807, + -0.0003930833481717855, + -0.0007206528098322451 + ], + [ + 0.0027932741213589907, + -0.001137559418566525, + 0.0034186341799795628, + 0.0014710846589878201, + -0.021970978006720543, + -0.008242838084697723, + -0.002030930481851101, + -0.03139902278780937, + -0.00019058586622122675, + 0.005425741430372, + -0.0007921225042082369, + 0.01861785724759102 + ], + [ + -0.00038712756941094995, + -0.06325366348028183, + 0.00381767307408154, + -0.004032081924378872, + 0.011113538406789303, + -0.0035615733359009027, + -0.002287030452862382, + 0.0013162336545065045, + -0.007033809553831816, + 0.0004943321109749377, + -0.004562149289995432, + 0.0027873183134943247 + ], + [ + -0.001107780379243195, + 0.005771178286522627, + 0.00007742550951661542, + -0.006992118898779154, + 0.0021798256784677505, + -0.028456851840019226, + -0.006146394181996584, + -0.0028290089685469866, + -0.002709892811253667, + -0.01887395605444908, + -0.00023227653582580388, + 0.0036687778774648905 + ], + [ + -0.0039248778484761715, + -0.004014214966446161, + -0.0037581149954348803, + 0.0027575392741709948, + 0.017122948542237282, + -0.005586548242717981, + 0.0018046098994091153, + 0.011167140677571297, + -0.07067757844924927, + -0.016962140798568726, + -0.005020746495574713, + -0.00233467691577971 + ], + [ + -0.0036389988381415606, + 0.000023823233277653344, + -0.02049989253282547, + -0.015336207114160061, + -0.012227274477481842, + -0.008338131941854954, + -0.001095868763513863, + -0.010273769497871399, + -0.0004347740323282778, + 0.007194616831839085, + -0.05531159043312073, + 0.1065613254904747 + ], + [ + -0.0021143120247870684, + -0.0034841480664908886, + -0.008355999365448952, + -0.13864824175834656, + -0.3815380036830902, + 0.17281374335289001, + 0.0623990036547184, + -0.3717496395111084, + -0.007337555754929781, + -0.0006908737705089152, + -0.02200671099126339, + -0.023382503539323807 + ], + [ + -0.011262433603405952, + -0.6075252294540405, + -0.0016140240477398038, + 0.08838419616222382, + -0.0010780013399198651, + -0.3567320704460144, + 1.0412063598632812, + 0.008219015784561634, + -0.0033531200606375933, + -0.02244744263589382, + -0.0020547539461404085, + 0.005419785622507334 + ], + [ + -0.003519882680848241, + -0.0021262236405164003, + -0.08456950634717941, + -0.03936789184808731, + -0.02085128426551819, + -0.057592667639255524, + -0.013311232440173626, + -0.001518731121905148, + 0.02321574091911316, + -0.01964821107685566, + 0.16784659028053284, + -0.05110083892941475 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorbar": { + "ticksuffix": "%", + "title": { + "text": "Logit diff variation" + } + }, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "margin": { + "l": 100, + "r": 100 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Direct effect on Intermediate AE Heads' values from position 15 (:)" + }, + "width": 700, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "scaleanchor": "y", + "showline": true, + "title": { + "text": "Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "showline": true, + "title": { + "text": "Layer" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
Logit diff variation: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + -0.0051458184607326984, + 0.00001786742541298736, + -0.0012507197679951787, + 0.000011911616638826672, + 0.0012209407286718488, + 0, + -0.03301900252699852, + -0.0010482223005965352, + -0.0016140240477398038, + -0.006366759072989225, + -0.008070120587944984, + 0.005038613919168711 + ], + [ + -0.0031625342089682817, + 0.001971372403204441, + -0.0004824204952456057, + 0.0020071074832230806, + -0.008367910981178284, + -0.012394038029015064, + -0.007004030514508486, + -0.05059756711125374, + -0.0014472614275291562, + 0.004836116451770067, + 0.00014293940330389887, + 0.014972901903092861 + ], + [ + -0.005812868941575289, + -0.10299380123615265, + 0.0020785771775990725, + -0.0024180582258850336, + -0.00009529293311061338, + -0.0038712755776941776, + 0.008707392029464245, + -0.002310853684321046, + -0.0018701237859204412, + 0.002001151442527771, + -0.000005955808319413336, + 0.0038117174990475178 + ], + [ + 0.0009588851826265454, + -0.0006074924604035914, + 0.0031982690561562777, + -0.0047408235259354115, + -0.0008755038143135607, + -0.020648788660764694, + 0.00011911617184523493, + 0.002251295605674386, + -0.0011852058814838529, + -0.008701436221599579, + 0.021250324323773384, + -0.01046435534954071 + ], + [ + -0.001941593480296433, + -0.009934288449585438, + -0.030797483399510384, + 0.0024240140337496996, + 0.015562526881694794, + -0.002221516566351056, + -0.0006432273075915873, + 0.005401918198913336, + -0.03572889417409897, + -0.04288479685783386, + -0.006301245652139187, + -0.004526414442807436 + ], + [ + -0.003805761691182852, + 0.0002799229696393013, + -0.0005360227660275996, + 0.005348315928131342, + -0.001858212286606431, + -0.013686447404325008, + -0.004967144224792719, + -0.009034961462020874, + -0.0025252627674490213, + 0.017218241468071938, + -0.07222311198711395, + 0.11154633015394211 + ], + [ + 0.0024657046888023615, + -0.002632467309013009, + -0.001923726056702435, + -0.07311350107192993, + -0.31993114948272705, + 0.1175914779305458, + 0.10545949637889862, + -0.12301722168922424, + -0.006235731299966574, + -0.0008814596221782267, + -0.04462984949350357, + -0.011071847751736641 + ], + [ + -0.0019058587495237589, + -0.9684025049209595, + -0.0008814596221782267, + 0.2527823746204376, + 0.0019594610203057528, + -1.6596992015838623, + 1.1019316911697388, + 0.005360227543860674, + -0.004883762914687395, + -0.19863811135292053, + -0.0055448575876653194, + -0.0016735821263864636 + ], + [ + -0.006622858811169863, + -0.0010482223005965352, + -0.024472415447235107, + -0.0010899128392338753, + -0.03688431903719902, + -0.06362292170524597, + -0.009570984169840813, + -0.0008814596221782267, + 0.04236366227269173, + -0.012167716398835182, + 0.11373211443424225, + -0.026473568752408028 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorbar": { + "ticksuffix": "%", + "title": { + "text": "Logit diff variation" + } + }, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "margin": { + "l": 100, + "r": 100 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Direct effect on Intermediate AE Heads' values from position 16 ( This)" + }, + "width": 700, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "scaleanchor": "y", + "showline": true, + "title": { + "text": "Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "showline": true, + "title": { + "text": "Layer" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
Logit diff variation: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + -0.004216712433844805, + 0.0004466856480576098, + 0.0006015366525389254, + -0.00004764646655530669, + -0.004788469988852739, + 0.000005955808319413336, + -0.031095273792743683, + -0.005050525534898043, + -0.008647833950817585, + 0.00303150643594563, + -0.019326597452163696, + -0.00889202207326889 + ], + [ + 0.004901630338281393, + 0.007712772116065025, + 0.00014889521116856486, + -0.0000416906586906407, + -0.023436104878783226, + -0.0008338132174685597, + -0.01326954085379839, + 0.017492210492491722, + -0.010023625567555428, + 0.003746203612536192, + 0.0018701237859204412, + 0.012137937359511852 + ], + [ + -0.007677036803215742, + -0.149583101272583, + 0.0014770404668524861, + 0.011679340153932571, + 0.002662246348336339, + -0.012834767811000347, + -0.006384626496583223, + 0.010130830109119415, + -0.0025193069595843554, + 0.011107582598924637, + 0.014413056895136833, + -0.004151198081672192 + ], + [ + -0.011036112904548645, + 0.021202677860856056, + 0.009106430225074291, + 0.0007921225042082369, + 0.0235194880515337, + 0.010607294738292694, + -0.014520260505378246, + 0.004526414442807436, + -0.006122571416199207, + -0.011500665917992592, + 0.0243771243840456, + -0.026890475302934647 + ], + [ + -0.00545552046969533, + 0.005693752784281969, + -0.09458121657371521, + -0.0055210343562066555, + 0.04214330017566681, + -0.0019594610203057528, + 0.004603839945048094, + -0.007581744343042374, + -0.030261460691690445, + -0.08964980393648148, + -0.0032280480954796076, + 0.006956384517252445 + ], + [ + -0.003239959944039583, + 0.006217863876372576, + -0.003472236217930913, + -0.03126204013824463, + -0.0038117174990475178, + -0.36781883239746094, + 0.0036985569167882204, + -0.02883206680417061, + 0.04476981237530708, + 0.22402773797512054, + -0.18110719323158264, + 0.159371480345726 + ], + [ + -0.008230927400290966, + -0.008624010719358921, + 0.03735483065247536, + -0.024948880076408386, + -1.8061554431915283, + 0.177620068192482, + 0.21879258751869202, + -0.1933731883764267, + -0.015276649035513401, + -0.08716028183698654, + -0.0988009050488472, + -0.007992695085704327 + ], + [ + -0.032619964331388474, + -6.772891521453857, + -0.015348118729889393, + 0.5428838729858398, + -0.004979055840522051, + -4.109564304351807, + 2.6856706142425537, + -0.021845905110239983, + -0.036133889108896255, + -0.3091183602809906, + -0.019362332299351692, + -0.00985686294734478 + ], + [ + -0.006366759072989225, + -0.011732942424714565, + -0.15355563163757324, + -0.0243116095662117, + -0.27287429571151733, + -0.8576394319534302, + -0.15307322144508362, + -0.029117945581674576, + 0.21484386920928955, + -0.029850512742996216, + 1.0204145908355713, + -0.08999822288751602 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorbar": { + "ticksuffix": "%", + "title": { + "text": "Logit diff variation" + } + }, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "margin": { + "l": 100, + "r": 100 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Direct effect on Intermediate AE Heads' values from position 17 ( movie)" + }, + "width": 700, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "scaleanchor": "y", + "showline": true, + "title": { + "text": "Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "showline": true, + "title": { + "text": "Layer" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.plotly.v1+json": { + "config": { + "plotlyServerURL": "https://plot.ly" + }, + "data": [ + { + "coloraxis": "coloraxis", + "hovertemplate": "Head: %{x}
Layer: %{y}
Logit diff variation: %{z}", + "name": "0", + "type": "heatmap", + "xaxis": "x", + "yaxis": "y", + "z": [ + [ + -0.0029123902786523104, + -0.000023823233277653344, + -0.004002302885055542, + 0, + -0.0006730063469149172, + -0.000011911616638826672, + -0.009029005654156208, + 0.0001310277875745669, + 0.0008695480646565557, + 0.001858212286606431, + -0.00150086369831115, + -0.0012804988073185086 + ], + [ + -0.0017986542079597712, + 0.0009469735086895525, + -0.000023823233277653344, + 0.0004049949930049479, + -0.003013639012351632, + -0.002322765300050378, + -0.0018105657072737813, + -0.01448452565819025, + 0.0031684900168329477, + 0.001173294265754521, + -0.0000416906586906407, + 0.004877807106822729 + ], + [ + -0.0013698359252884984, + -0.032876063138246536, + -0.0007027853862382472, + -0.0020785771775990725, + 0.0014591730432584882, + 0.0007802109466865659, + 0.0023704117629677057, + -0.0005300668999552727, + -0.0025014395359903574, + 0.00015485101903323084, + -0.0017093170899897814, + 0.0013460126938298345 + ], + [ + -0.00041690660873427987, + -0.00028587880660779774, + 0.0012924104230478406, + -0.0008159457356669009, + 0.002727760234847665, + -0.01536598615348339, + -0.0017152727814391255, + -0.0009767526062205434, + -0.0005002878606319427, + -0.0025073953438550234, + 0.0039070104248821735, + 0.00016080682689789683 + ], + [ + -0.0014591730432584882, + -0.0021738701034337282, + -0.01844513975083828, + -0.004139286931604147, + 0.006051101256161928, + -0.003365031909197569, + -0.0007087411941029131, + 0.0004764646873809397, + -0.012614401988685131, + -0.007325644604861736, + 0.0010839571477845311, + -0.0032042248640209436 + ], + [ + 0.0013043220387771726, + -0.0032042248640209436, + 0.0010065316455438733, + -0.01197713054716587, + -0.0011792500736191869, + -0.035109490156173706, + -0.0009767526062205434, + -0.002733716042712331, + 0.01070258766412735, + -0.0024716604966670275, + -0.032631874084472656, + 0.03319767862558365 + ], + [ + -0.002680113771930337, + -0.0020726213697344065, + -0.001137559418566525, + -0.018516607582569122, + -0.20272976160049438, + 0.027259735390543938, + 0.07687757909297943, + -0.05859026685357094, + -0.005652062129229307, + -0.024400947615504265, + -0.021482601761817932, + -0.008272617124021053 + ], + [ + 0.000011911616638826672, + -0.3026682138442993, + -0.010869350284337997, + 0.047854918986558914, + -0.0003811717324424535, + -0.27531617879867554, + 0.4322427809238434, + 0.011566179804503918, + -0.013662624172866344, + -0.03506780043244362, + -0.0293323565274477, + 0.003460324602201581 + ], + [ + 0.0028349647764116526, + -0.0028111415449529886, + -0.01983879692852497, + -0.001155426842160523, + -0.04279844090342522, + -0.13683171570301056, + -0.004002302885055542, + -0.005008834879845381, + 0.03773004561662674, + -0.020345041528344154, + 0.12124238908290863, + -0.031357333064079285 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ] + ] + } + ], + "layout": { + "coloraxis": { + "cmid": 0, + "colorbar": { + "ticksuffix": "%", + "title": { + "text": "Logit diff variation" + } + }, + "colorscale": [ + [ + 0, + "rgb(103,0,31)" + ], + [ + 0.1, + "rgb(178,24,43)" + ], + [ + 0.2, + "rgb(214,96,77)" + ], + [ + 0.3, + "rgb(244,165,130)" + ], + [ + 0.4, + "rgb(253,219,199)" + ], + [ + 0.5, + "rgb(247,247,247)" + ], + [ + 0.6, + "rgb(209,229,240)" + ], + [ + 0.7, + "rgb(146,197,222)" + ], + [ + 0.8, + "rgb(67,147,195)" + ], + [ + 0.9, + "rgb(33,102,172)" + ], + [ + 1, + "rgb(5,48,97)" + ] + ] + }, + "margin": { + "l": 100, + "r": 100 + }, + "template": { + "data": { + "bar": [ + { + "error_x": { + "color": "#2a3f5f" + }, + "error_y": { + "color": "#2a3f5f" + }, + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "bar" + } + ], + "barpolar": [ + { + "marker": { + "line": { + "color": "#E5ECF6", + "width": 0.5 + }, + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "barpolar" + } + ], + "carpet": [ + { + "aaxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "baxis": { + "endlinecolor": "#2a3f5f", + "gridcolor": "white", + "linecolor": "white", + "minorgridcolor": "white", + "startlinecolor": "#2a3f5f" + }, + "type": "carpet" + } + ], + "choropleth": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "choropleth" + } + ], + "contour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "contour" + } + ], + "contourcarpet": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "contourcarpet" + } + ], + "heatmap": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmap" + } + ], + "heatmapgl": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "heatmapgl" + } + ], + "histogram": [ + { + "marker": { + "pattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + } + }, + "type": "histogram" + } + ], + "histogram2d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2d" + } + ], + "histogram2dcontour": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "histogram2dcontour" + } + ], + "mesh3d": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "type": "mesh3d" + } + ], + "parcoords": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "parcoords" + } + ], + "pie": [ + { + "automargin": true, + "type": "pie" + } + ], + "scatter": [ + { + "fillpattern": { + "fillmode": "overlay", + "size": 10, + "solidity": 0.2 + }, + "type": "scatter" + } + ], + "scatter3d": [ + { + "line": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatter3d" + } + ], + "scattercarpet": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattercarpet" + } + ], + "scattergeo": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergeo" + } + ], + "scattergl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattergl" + } + ], + "scattermapbox": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scattermapbox" + } + ], + "scatterpolar": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolar" + } + ], + "scatterpolargl": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterpolargl" + } + ], + "scatterternary": [ + { + "marker": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "type": "scatterternary" + } + ], + "surface": [ + { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + }, + "colorscale": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "type": "surface" + } + ], + "table": [ + { + "cells": { + "fill": { + "color": "#EBF0F8" + }, + "line": { + "color": "white" + } + }, + "header": { + "fill": { + "color": "#C8D4E3" + }, + "line": { + "color": "white" + } + }, + "type": "table" + } + ] + }, + "layout": { + "annotationdefaults": { + "arrowcolor": "#2a3f5f", + "arrowhead": 0, + "arrowwidth": 1 + }, + "autotypenumbers": "strict", + "coloraxis": { + "colorbar": { + "outlinewidth": 0, + "ticks": "" + } + }, + "colorscale": { + "diverging": [ + [ + 0, + "#8e0152" + ], + [ + 0.1, + "#c51b7d" + ], + [ + 0.2, + "#de77ae" + ], + [ + 0.3, + "#f1b6da" + ], + [ + 0.4, + "#fde0ef" + ], + [ + 0.5, + "#f7f7f7" + ], + [ + 0.6, + "#e6f5d0" + ], + [ + 0.7, + "#b8e186" + ], + [ + 0.8, + "#7fbc41" + ], + [ + 0.9, + "#4d9221" + ], + [ + 1, + "#276419" + ] + ], + "sequential": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ], + "sequentialminus": [ + [ + 0, + "#0d0887" + ], + [ + 0.1111111111111111, + "#46039f" + ], + [ + 0.2222222222222222, + "#7201a8" + ], + [ + 0.3333333333333333, + "#9c179e" + ], + [ + 0.4444444444444444, + "#bd3786" + ], + [ + 0.5555555555555556, + "#d8576b" + ], + [ + 0.6666666666666666, + "#ed7953" + ], + [ + 0.7777777777777778, + "#fb9f3a" + ], + [ + 0.8888888888888888, + "#fdca26" + ], + [ + 1, + "#f0f921" + ] + ] + }, + "colorway": [ + "#636efa", + "#EF553B", + "#00cc96", + "#ab63fa", + "#FFA15A", + "#19d3f3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52" + ], + "font": { + "color": "#2a3f5f" + }, + "geo": { + "bgcolor": "white", + "lakecolor": "white", + "landcolor": "#E5ECF6", + "showlakes": true, + "showland": true, + "subunitcolor": "white" + }, + "hoverlabel": { + "align": "left" + }, + "hovermode": "closest", + "mapbox": { + "style": "light" + }, + "paper_bgcolor": "white", + "plot_bgcolor": "#E5ECF6", + "polar": { + "angularaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "radialaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "scene": { + "xaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "yaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + }, + "zaxis": { + "backgroundcolor": "#E5ECF6", + "gridcolor": "white", + "gridwidth": 2, + "linecolor": "white", + "showbackground": true, + "ticks": "", + "zerolinecolor": "white" + } + }, + "shapedefaults": { + "line": { + "color": "#2a3f5f" + } + }, + "ternary": { + "aaxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "baxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + }, + "bgcolor": "#E5ECF6", + "caxis": { + "gridcolor": "white", + "linecolor": "white", + "ticks": "" + } + }, + "title": { + "x": 0.05 + }, + "xaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + }, + "yaxis": { + "automargin": true, + "gridcolor": "white", + "linecolor": "white", + "ticks": "", + "title": { + "standoff": 15 + }, + "zerolinecolor": "white", + "zerolinewidth": 2 + } + } + }, + "title": { + "text": "Direct effect on Intermediate AE Heads' values from position 18 ( is)" + }, + "width": 700, + "xaxis": { + "anchor": "y", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "scaleanchor": "y", + "showline": true, + "title": { + "text": "Head" + } + }, + "yaxis": { + "anchor": "x", + "autorange": "reversed", + "constrain": "domain", + "domain": [ + 0, + 1 + ], + "linecolor": "black", + "linewidth": 1, + "mirror": true, + "showline": true, + "title": { + "text": "Layer" + } + } + } + }, + "text/html": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "for i in range(10, results[\"z\"].shape[0]):\n", + " imshow_p(\n", + " results[\"z\"][i][:10] * 100,\n", + " title=f\"Direct effect on Intermediate AE Heads' values from position {i} ({tokens[i]})\",\n", + " labels={\"x\": \"Head\", \"y\": \"Layer\", \"color\": \"Logit diff variation\"},\n", + " coloraxis=dict(colorbar_ticksuffix = \"%\"),\n", + " border=True,\n", + " width=700,\n", + " margin={\"r\": 100, \"l\": 100}\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fccbc4d", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "49553014", + "metadata": {}, + "source": [ + "#### Various Attention Patterns" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "4f33c2da", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "heads = [(0, 4), (3, 11), (5, 7), (6, 4), (6, 7), (6,11), (7, 1), (7, 5), (8, 4), (8, 5), (9, 2), (9, 10), (10, 1), (10, 4), (11, 9)]\n", + "#heads = [(6, 4),(7, 1),(7,5)]\n", + "attn_list = []\n", + "for i in range(len(all_prompts)):\n", + " tokens, attn, names = get_attn_head_patterns(model, all_prompts[i], heads)\n", + " attn_list.append(attn)\n", + "tokens, _, names = get_attn_head_patterns(model, all_prompts[0], heads)\n", + "attn = torch.stack(attn_list, dim=0).mean(dim=0)\n", + "cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "4e3f0979", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "heads = [(0, 4), (3, 11), (5, 7), (6, 4), (6, 7), (6,11), (7, 1), (7, 5), (8, 4), (8, 5), (9, 2), (9, 10), (10, 1), (10, 4), (11, 9)]\n", + "#heads = [(6, 4),(7, 1),(7,5)]\n", + "attn_list = []\n", + "for i in range(len(all_prompts)):\n", + " tokens, attn, names = get_attn_head_patterns(model, all_prompts[i], heads)\n", + " attn_list.append(attn)\n", + "tokens, _, names = get_attn_head_patterns(model, all_prompts[1], heads)\n", + "attn = torch.stack(attn_list, dim=0).mean(dim=0)\n", + "cv.attention.attention_heads(tokens=tokens, attention=attn, attention_head_names=names)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a3dba6cf", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "jupytext": { + "custom_cell_magics": "kql" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/circuit_for_mood_inference_pythia2_8b_commas.ipynb b/circuit_for_mood_inference_pythia2_8b_commas.ipynb index d728e51..35413cb 100644 --- a/circuit_for_mood_inference_pythia2_8b_commas.ipynb +++ b/circuit_for_mood_inference_pythia2_8b_commas.ipynb @@ -66,69 +66,70 @@ "output_type": "stream", "text": [ "Collecting git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python\n", - " Cloning https://github.com/callummcdougall/CircuitsVis.git to /tmp/pip-req-build-75tggcq6\n", - " Running command git clone --filter=blob:none --quiet https://github.com/callummcdougall/CircuitsVis.git /tmp/pip-req-build-75tggcq6\n", + " Cloning https://github.com/callummcdougall/CircuitsVis.git to /tmp/pip-req-build-xv5h08f4\n", + " Running command git clone --filter=blob:none --quiet https://github.com/callummcdougall/CircuitsVis.git /tmp/pip-req-build-xv5h08f4\n", " Resolved https://github.com/callummcdougall/CircuitsVis.git to commit df9bfc252807e8b1c3a26c3c4796c18342c7fc71\n", " Installing build dependencies ... \u001b[?25ldone\n", "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n", "\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n", - "\u001b[?25hRequirement already satisfied: importlib-metadata<6.0.0,>=5.1.0 in /usr/local/lib/python3.9/dist-packages (from circuitsvis==0.0.0) (5.2.0)\n", - "Requirement already satisfied: numpy<2.0,>=1.21 in /usr/local/lib/python3.9/dist-packages (from circuitsvis==0.0.0) (1.23.4)\n", - "Collecting torch<3.0,>=2.0\n", + "\u001b[?25hCollecting torch<3.0,>=2.0\n", " Downloading torch-2.0.1-cp39-cp39-manylinux1_x86_64.whl (619.9 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m619.9/619.9 MB\u001b[0m \u001b[31m4.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.9/dist-packages (from importlib-metadata<6.0.0,>=5.1.0->circuitsvis==0.0.0) (3.11.0)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m619.9/619.9 MB\u001b[0m \u001b[31m4.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hCollecting importlib-metadata<6.0.0,>=5.1.0\n", + " Downloading importlib_metadata-5.2.0-py3-none-any.whl (21 kB)\n", + "Requirement already satisfied: numpy<2.0,>=1.21 in /usr/local/lib/python3.9/dist-packages (from circuitsvis==0.0.0) (1.23.4)\n", + "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.9/dist-packages (from importlib-metadata<6.0.0,>=5.1.0->circuitsvis==0.0.0) (3.11.0)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/dist-packages (from torch<3.0,>=2.0->circuitsvis==0.0.0) (4.4.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from torch<3.0,>=2.0->circuitsvis==0.0.0) (3.9.0)\n", + "Collecting nvidia-curand-cu11==10.2.10.91\n", + " Downloading nvidia_curand_cu11-10.2.10.91-py3-none-manylinux1_x86_64.whl (54.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m54.6/54.6 MB\u001b[0m \u001b[31m38.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hCollecting nvidia-cuda-nvrtc-cu11==11.7.99\n", + " Downloading nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl (21.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.0/21.0 MB\u001b[0m \u001b[31m61.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: networkx in /usr/local/lib/python3.9/dist-packages (from torch<3.0,>=2.0->circuitsvis==0.0.0) (3.0)\n", "Collecting nvidia-cudnn-cu11==8.5.0.96\n", " Downloading nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl (557.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m557.1/557.1 MB\u001b[0m \u001b[31m4.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m557.1/557.1 MB\u001b[0m \u001b[31m4.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting nvidia-cuda-runtime-cu11==11.7.99\n", " Downloading nvidia_cuda_runtime_cu11-11.7.99-py3-none-manylinux1_x86_64.whl (849 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m849.3/849.3 kB\u001b[0m \u001b[31m100.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m849.3/849.3 kB\u001b[0m \u001b[31m73.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nvidia-nccl-cu11==2.14.3\n", + " Downloading nvidia_nccl_cu11-2.14.3-py3-none-manylinux1_x86_64.whl (177.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m177.1/177.1 MB\u001b[0m \u001b[31m13.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hCollecting sympy\n", + " Downloading sympy-1.12-py3-none-any.whl (5.7 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m89.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mta \u001b[36m0:00:01\u001b[0m\n", "\u001b[?25hCollecting nvidia-cusolver-cu11==11.4.0.1\n", " Downloading nvidia_cusolver_cu11-11.4.0.1-2-py3-none-manylinux1_x86_64.whl (102.6 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m102.6/102.6 MB\u001b[0m \u001b[31m22.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", - "\u001b[?25hCollecting nvidia-cusparse-cu11==11.7.4.91\n", - " Downloading nvidia_cusparse_cu11-11.7.4.91-py3-none-manylinux1_x86_64.whl (173.2 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m173.2/173.2 MB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", - "\u001b[?25hCollecting nvidia-cufft-cu11==10.9.0.58\n", - " Downloading nvidia_cufft_cu11-10.9.0.58-py3-none-manylinux1_x86_64.whl (168.4 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m168.4/168.4 MB\u001b[0m \u001b[31m15.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", - "\u001b[?25hCollecting triton==2.0.0\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m102.6/102.6 MB\u001b[0m \u001b[31m23.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: jinja2 in /usr/local/lib/python3.9/dist-packages (from torch<3.0,>=2.0->circuitsvis==0.0.0) (3.1.2)\n", + "Collecting triton==2.0.0\n", " Downloading triton-2.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (63.3 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m63.3/63.3 MB\u001b[0m \u001b[31m35.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from torch<3.0,>=2.0->circuitsvis==0.0.0) (3.9.0)\n", - "Collecting nvidia-cuda-cupti-cu11==11.7.101\n", - " Downloading nvidia_cuda_cupti_cu11-11.7.101-py3-none-manylinux1_x86_64.whl (11.8 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m11.8/11.8 MB\u001b[0m \u001b[31m81.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", - "\u001b[?25hCollecting sympy\n", - " Downloading sympy-1.12-py3-none-any.whl (5.7 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.7/5.7 MB\u001b[0m \u001b[31m124.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m\n", - "\u001b[?25hCollecting nvidia-cuda-nvrtc-cu11==11.7.99\n", - " Downloading nvidia_cuda_nvrtc_cu11-11.7.99-2-py3-none-manylinux1_x86_64.whl (21.0 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.0/21.0 MB\u001b[0m \u001b[31m87.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m63.3/63.3 MB\u001b[0m \u001b[31m32.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting nvidia-cublas-cu11==11.10.3.66\n", " Downloading nvidia_cublas_cu11-11.10.3.66-py3-none-manylinux1_x86_64.whl (317.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m317.1/317.1 MB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", - "\u001b[?25hCollecting nvidia-nccl-cu11==2.14.3\n", - " Downloading nvidia_nccl_cu11-2.14.3-py3-none-manylinux1_x86_64.whl (177.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m177.1/177.1 MB\u001b[0m \u001b[31m13.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m317.1/317.1 MB\u001b[0m \u001b[31m8.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hCollecting nvidia-cuda-cupti-cu11==11.7.101\n", + " Downloading nvidia_cuda_cupti_cu11-11.7.101-py3-none-manylinux1_x86_64.whl (11.8 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m11.8/11.8 MB\u001b[0m \u001b[31m89.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hCollecting nvidia-cufft-cu11==10.9.0.58\n", + " Downloading nvidia_cufft_cu11-10.9.0.58-py3-none-manylinux1_x86_64.whl (168.4 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m168.4/168.4 MB\u001b[0m \u001b[31m14.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hCollecting nvidia-cusparse-cu11==11.7.4.91\n", + " Downloading nvidia_cusparse_cu11-11.7.4.91-py3-none-manylinux1_x86_64.whl (173.2 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m173.2/173.2 MB\u001b[0m \u001b[31m14.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting nvidia-nvtx-cu11==11.7.91\n", " Downloading nvidia_nvtx_cu11-11.7.91-py3-none-manylinux1_x86_64.whl (98 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m98.6/98.6 kB\u001b[0m \u001b[31m37.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting nvidia-curand-cu11==10.2.10.91\n", - " Downloading nvidia_curand_cu11-10.2.10.91-py3-none-manylinux1_x86_64.whl (54.6 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m54.6/54.6 MB\u001b[0m \u001b[31m40.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: networkx in /usr/local/lib/python3.9/dist-packages (from torch<3.0,>=2.0->circuitsvis==0.0.0) (3.0)\n", - "Requirement already satisfied: jinja2 in /usr/local/lib/python3.9/dist-packages (from torch<3.0,>=2.0->circuitsvis==0.0.0) (3.1.2)\n", - "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/dist-packages (from torch<3.0,>=2.0->circuitsvis==0.0.0) (4.8.0)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m98.6/98.6 kB\u001b[0m \u001b[31m36.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: setuptools in /usr/local/lib/python3.9/dist-packages (from nvidia-cublas-cu11==11.10.3.66->torch<3.0,>=2.0->circuitsvis==0.0.0) (66.1.1)\n", "Requirement already satisfied: wheel in /usr/local/lib/python3.9/dist-packages (from nvidia-cublas-cu11==11.10.3.66->torch<3.0,>=2.0->circuitsvis==0.0.0) (0.35.1)\n", - "Requirement already satisfied: setuptools in /usr/local/lib/python3.9/dist-packages (from nvidia-cublas-cu11==11.10.3.66->torch<3.0,>=2.0->circuitsvis==0.0.0) (66.1.1)\n", "Collecting cmake\n", " Downloading cmake-3.27.5-py2.py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (26.1 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m26.1/26.1 MB\u001b[0m \u001b[31m69.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m26.1/26.1 MB\u001b[0m \u001b[31m62.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", "\u001b[?25hCollecting lit\n", - " Downloading lit-16.0.6.tar.gz (153 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m153.7/153.7 kB\u001b[0m \u001b[31m46.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + " Downloading lit-17.0.1.tar.gz (154 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m154.7/154.7 kB\u001b[0m \u001b[31m50.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Installing build dependencies ... \u001b[?25ldone\n", "\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n", "\u001b[?25h Installing backend dependencies ... \u001b[?25ldone\n", @@ -136,156 +137,228 @@ "\u001b[?25hRequirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.9/dist-packages (from jinja2->torch<3.0,>=2.0->circuitsvis==0.0.0) (2.1.2)\n", "Collecting mpmath>=0.19\n", " Downloading mpmath-1.3.0-py3-none-any.whl (536 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m536.2/536.2 kB\u001b[0m \u001b[31m85.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m536.2/536.2 kB\u001b[0m \u001b[31m84.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hBuilding wheels for collected packages: circuitsvis, lit\n", " Building wheel for circuitsvis (pyproject.toml) ... \u001b[?25ldone\n", "\u001b[?25h Created wheel for circuitsvis: filename=circuitsvis-0.0.0-py3-none-any.whl size=6170923 sha256=7101b262d0720ea697e8ca0933183e55c9d6b28a592d89e9c7f49c863ad572c6\n", - " Stored in directory: /tmp/pip-ephem-wheel-cache-o8xgq9ys/wheels/94/79/66/781b85e0732736078188d905010db6471f2787826da308336a\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-4f4iz7kv/wheels/94/79/66/781b85e0732736078188d905010db6471f2787826da308336a\n", " Building wheel for lit (pyproject.toml) ... \u001b[?25ldone\n", - "\u001b[?25h Created wheel for lit: filename=lit-16.0.6-py3-none-any.whl size=93584 sha256=ece179b1af5bba4148a8a54a5bd442d8a4605667b39881c9f78655639b837b10\n", - " Stored in directory: /root/.cache/pip/wheels/dd/a1/9c/f4e974f934c7a715a884a029e8b2b0b438486e654058fe8c80\n", + "\u001b[?25h Created wheel for lit: filename=lit-17.0.1-py3-none-any.whl size=93254 sha256=84ac8fe7caef55b2f2661a72c7060fe8f6e7d52bf8c9f2b81fb485a47c40ab49\n", + " Stored in directory: /root/.cache/pip/wheels/cb/cb/3e/f4a695a76daf22029d463b6ddcd3a58357c5fbf011b6aeb989\n", "Successfully built circuitsvis lit\n", - "Installing collected packages: mpmath, lit, cmake, sympy, nvidia-nvtx-cu11, nvidia-nccl-cu11, nvidia-cusparse-cu11, nvidia-curand-cu11, nvidia-cufft-cu11, nvidia-cuda-runtime-cu11, nvidia-cuda-nvrtc-cu11, nvidia-cuda-cupti-cu11, nvidia-cublas-cu11, nvidia-cusolver-cu11, nvidia-cudnn-cu11, triton, torch, circuitsvis\n", + "Installing collected packages: mpmath, lit, cmake, sympy, nvidia-nvtx-cu11, nvidia-nccl-cu11, nvidia-cusparse-cu11, nvidia-curand-cu11, nvidia-cufft-cu11, nvidia-cuda-runtime-cu11, nvidia-cuda-nvrtc-cu11, nvidia-cuda-cupti-cu11, nvidia-cublas-cu11, importlib-metadata, nvidia-cusolver-cu11, nvidia-cudnn-cu11, triton, torch, circuitsvis\n", + " Attempting uninstall: importlib-metadata\n", + " Found existing installation: importlib-metadata 6.0.0\n", + " Uninstalling importlib-metadata-6.0.0:\n", + " Successfully uninstalled importlib-metadata-6.0.0\n", " Attempting uninstall: torch\n", " Found existing installation: torch 1.12.1+cu116\n", " Uninstalling torch-1.12.1+cu116:\n", " Successfully uninstalled torch-1.12.1+cu116\n", - " Attempting uninstall: circuitsvis\n", - " Found existing installation: circuitsvis 1.41.0\n", - " Uninstalling circuitsvis-1.41.0:\n", - " Successfully uninstalled circuitsvis-1.41.0\n", "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", "torchvision 0.13.1+cu116 requires torch==1.12.1, but you have torch 2.0.1 which is incompatible.\n", "torchaudio 0.12.1+cu116 requires torch==1.12.1, but you have torch 2.0.1 which is incompatible.\u001b[0m\u001b[31m\n", - "\u001b[0mSuccessfully installed circuitsvis-0.0.0 cmake-3.27.5 lit-16.0.6 mpmath-1.3.0 nvidia-cublas-cu11-11.10.3.66 nvidia-cuda-cupti-cu11-11.7.101 nvidia-cuda-nvrtc-cu11-11.7.99 nvidia-cuda-runtime-cu11-11.7.99 nvidia-cudnn-cu11-8.5.0.96 nvidia-cufft-cu11-10.9.0.58 nvidia-curand-cu11-10.2.10.91 nvidia-cusolver-cu11-11.4.0.1 nvidia-cusparse-cu11-11.7.4.91 nvidia-nccl-cu11-2.14.3 nvidia-nvtx-cu11-11.7.91 sympy-1.12 torch-2.0.1 triton-2.0.0\n", + "\u001b[0mSuccessfully installed circuitsvis-0.0.0 cmake-3.27.5 importlib-metadata-5.2.0 lit-17.0.1 mpmath-1.3.0 nvidia-cublas-cu11-11.10.3.66 nvidia-cuda-cupti-cu11-11.7.101 nvidia-cuda-nvrtc-cu11-11.7.99 nvidia-cuda-runtime-cu11-11.7.99 nvidia-cudnn-cu11-8.5.0.96 nvidia-cufft-cu11-10.9.0.58 nvidia-curand-cu11-10.2.10.91 nvidia-cusolver-cu11-11.4.0.1 nvidia-cusparse-cu11-11.7.4.91 nvidia-nccl-cu11-2.14.3 nvidia-nvtx-cu11-11.7.91 sympy-1.12 torch-2.0.1 triton-2.0.0\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0mRequirement already satisfied: transformer_lens in /usr/local/lib/python3.9/dist-packages (1.6.1)\n", + "\u001b[0mCollecting transformer_lens\n", + " Downloading transformer_lens-1.6.1-py3-none-any.whl (109 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m109.9/109.9 kB\u001b[0m \u001b[31m23.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting beartype<0.15.0,>=0.14.1\n", + " Downloading beartype-0.14.1-py3-none-any.whl (739 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m739.7/739.7 kB\u001b[0m \u001b[31m96.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting einops>=0.6.0\n", + " Downloading einops-0.6.1-py3-none-any.whl (42 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m42.2/42.2 kB\u001b[0m \u001b[31m17.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (1.23.4)\n", "Requirement already satisfied: pandas>=1.1.5 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (1.5.0)\n", - "Requirement already satisfied: einops>=0.6.0 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (0.6.1)\n", - "Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (1.23.4)\n", + "Collecting transformers>=4.25.1\n", + " Downloading transformers-4.33.3-py3-none-any.whl (7.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.6/7.6 MB\u001b[0m \u001b[31m109.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: torch>=1.10 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (2.0.1)\n", + "Collecting fancy-einsum>=0.0.3\n", + " Downloading fancy_einsum-0.0.3-py3-none-any.whl (6.2 kB)\n", + "Collecting wandb>=0.13.5\n", + " Downloading wandb-0.15.11-py3-none-any.whl (2.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.1/2.1 MB\u001b[0m \u001b[31m93.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting jaxtyping>=0.2.11\n", + " Downloading jaxtyping-0.2.22-py3-none-any.whl (25 kB)\n", + "Collecting datasets>=2.7.1\n", + " Downloading datasets-2.14.5-py3-none-any.whl (519 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m519.6/519.6 kB\u001b[0m \u001b[31m89.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: tqdm>=4.64.1 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (4.64.1)\n", "Requirement already satisfied: rich>=12.6.0 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (13.2.0)\n", - "Requirement already satisfied: transformers>=4.25.1 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (4.33.2)\n", - "Requirement already satisfied: jaxtyping>=0.2.11 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (0.2.13)\n", - "Requirement already satisfied: wandb>=0.13.5 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (0.15.11)\n", - "Requirement already satisfied: fancy-einsum>=0.0.3 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (0.0.3)\n", - "Requirement already satisfied: tqdm>=4.64.1 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (4.64.1)\n", - "Requirement already satisfied: datasets>=2.7.1 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (2.14.5)\n", - "Requirement already satisfied: beartype<0.15.0,>=0.14.1 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (0.14.1)\n", - "Requirement already satisfied: torch>=1.10 in /usr/local/lib/python3.9/dist-packages (from transformer_lens) (2.0.1)\n", - "Requirement already satisfied: xxhash in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (3.2.0)\n", - "Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (0.17.2)\n", "Requirement already satisfied: multiprocess in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (0.70.13)\n", - "Requirement already satisfied: aiohttp in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (3.8.3)\n", - "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (10.0.1)\n", - "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (0.3.5.1)\n", "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (2.28.2)\n", "Requirement already satisfied: fsspec[http]<2023.9.0,>=2023.1.0 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (2023.1.0)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (3.2.0)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (3.8.3)\n", + "Collecting huggingface-hub<1.0.0,>=0.14.0\n", + " Downloading huggingface_hub-0.17.3-py3-none-any.whl (295 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m295.0/295.0 kB\u001b[0m \u001b[31m74.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (5.4.1)\n", + "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (0.3.5.1)\n", + "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (10.0.1)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (23.0)\n", - "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.9/dist-packages (from datasets>=2.7.1->transformer_lens) (5.4.1)\n", - "Requirement already satisfied: typing-extensions>=3.7.4.1 in /usr/local/lib/python3.9/dist-packages (from jaxtyping>=0.2.11->transformer_lens) (4.8.0)\n", - "Requirement already satisfied: typeguard>=2.13.3 in /usr/local/lib/python3.9/dist-packages (from jaxtyping>=0.2.11->transformer_lens) (4.1.5)\n", - "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas>=1.1.5->transformer_lens) (2022.7.1)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.1 in /usr/local/lib/python3.9/dist-packages (from jaxtyping>=0.2.11->transformer_lens) (4.4.0)\n", + "Collecting typeguard>=2.13.3\n", + " Downloading typeguard-4.1.5-py3-none-any.whl (34 kB)\n", "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.9/dist-packages (from pandas>=1.1.5->transformer_lens) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas>=1.1.5->transformer_lens) (2022.7.1)\n", "Requirement already satisfied: markdown-it-py<3.0.0,>=2.1.0 in /usr/local/lib/python3.9/dist-packages (from rich>=12.6.0->transformer_lens) (2.1.0)\n", "Requirement already satisfied: pygments<3.0.0,>=2.6.0 in /usr/local/lib/python3.9/dist-packages (from rich>=12.6.0->transformer_lens) (2.14.0)\n", - "Requirement already satisfied: networkx in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (3.0)\n", - "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (8.5.0.96)\n", - "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (2.0.0)\n", - "Requirement already satisfied: sympy in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (1.12)\n", - "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (11.10.3.66)\n", "Requirement already satisfied: nvidia-curand-cu11==10.2.10.91 in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (10.2.10.91)\n", - "Requirement already satisfied: nvidia-nccl-cu11==2.14.3 in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (2.14.3)\n", "Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (10.9.0.58)\n", - "Requirement already satisfied: nvidia-cusolver-cu11==11.4.0.1 in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (11.4.0.1)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (1.12)\n", + "Requirement already satisfied: nvidia-nccl-cu11==2.14.3 in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (2.14.3)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (11.7.99)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (11.7.99)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu11==11.7.101 in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (11.7.101)\n", + "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (8.5.0.96)\n", + "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (11.10.3.66)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (3.9.0)\n", "Requirement already satisfied: nvidia-cusparse-cu11==11.7.4.91 in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (11.7.4.91)\n", "Requirement already satisfied: nvidia-nvtx-cu11==11.7.91 in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (11.7.91)\n", + "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (2.0.0)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (3.0)\n", + "Requirement already satisfied: nvidia-cusolver-cu11==11.4.0.1 in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (11.4.0.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (3.1.2)\n", - "Requirement already satisfied: nvidia-cuda-cupti-cu11==11.7.101 in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (11.7.101)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (3.9.0)\n", - "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (11.7.99)\n", - "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /usr/local/lib/python3.9/dist-packages (from torch>=1.10->transformer_lens) (11.7.99)\n", - "Requirement already satisfied: setuptools in /usr/local/lib/python3.9/dist-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.10->transformer_lens) (66.1.1)\n", "Requirement already satisfied: wheel in /usr/local/lib/python3.9/dist-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.10->transformer_lens) (0.35.1)\n", - "Requirement already satisfied: lit in /usr/local/lib/python3.9/dist-packages (from triton==2.0.0->torch>=1.10->transformer_lens) (16.0.6)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.9/dist-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.10->transformer_lens) (66.1.1)\n", + "Requirement already satisfied: lit in /usr/local/lib/python3.9/dist-packages (from triton==2.0.0->torch>=1.10->transformer_lens) (17.0.1)\n", "Requirement already satisfied: cmake in /usr/local/lib/python3.9/dist-packages (from triton==2.0.0->torch>=1.10->transformer_lens) (3.27.5)\n", - "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.9/dist-packages (from transformers>=4.25.1->transformer_lens) (0.12.1)\n", - "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.9/dist-packages (from transformers>=4.25.1->transformer_lens) (0.3.3)\n", + "Collecting safetensors>=0.3.1\n", + " Downloading safetensors-0.3.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m107.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.9/dist-packages (from transformers>=4.25.1->transformer_lens) (0.12.1)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.9/dist-packages (from transformers>=4.25.1->transformer_lens) (2022.10.31)\n", - "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.15.0 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (3.20.3)\n", - "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (3.1.30)\n", "Requirement already satisfied: pathtools in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (0.1.2)\n", - "Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (1.4.4)\n", "Requirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (8.1.3)\n", - "Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (0.4.0)\n", + "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (3.1.30)\n", "Requirement already satisfied: sentry-sdk>=1.0.0 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (1.14.0)\n", - "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (5.9.4)\n", "Requirement already satisfied: setproctitle in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (1.3.2)\n", + "Collecting appdirs>=1.4.3\n", + " Downloading appdirs-1.4.4-py2.py3-none-any.whl (9.6 kB)\n", + "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (5.9.4)\n", + "Requirement already satisfied: docker-pycreds>=0.4.0 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (0.4.0)\n", + "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.15.0 in /usr/local/lib/python3.9/dist-packages (from wandb>=0.13.5->transformer_lens) (3.19.6)\n", "Requirement already satisfied: six>=1.4.0 in /usr/lib/python3/dist-packages (from docker-pycreds>=0.4.0->wandb>=0.13.5->transformer_lens) (1.14.0)\n", - "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.8.2)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (6.0.4)\n", - "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.3.1)\n", "Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (2.1.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (18.2.0)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.8.2)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.3.1)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (1.3.3)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (4.0.2)\n", - "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.9/dist-packages (from aiohttp->datasets>=2.7.1->transformer_lens) (18.2.0)\n", "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.9/dist-packages (from GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer_lens) (4.0.10)\n", "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.9/dist-packages (from markdown-it-py<3.0.0,>=2.1.0->rich>=12.6.0->transformer_lens) (0.1.2)\n", "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.9/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (1.26.14)\n", - "Requirement already satisfied: idna<4,>=2.5 in /usr/lib/python3/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (2.8)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/lib/python3/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (2019.11.28)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/lib/python3/dist-packages (from requests>=2.19.0->datasets>=2.7.1->transformer_lens) (2.8)\n", "Requirement already satisfied: importlib-metadata>=3.6 in /usr/local/lib/python3.9/dist-packages (from typeguard>=2.13.3->jaxtyping>=0.2.11->transformer_lens) (5.2.0)\n", + "Collecting typing-extensions>=3.7.4.1\n", + " Downloading typing_extensions-4.8.0-py3-none-any.whl (31 kB)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.9/dist-packages (from jinja2->torch>=1.10->transformer_lens) (2.1.2)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.9/dist-packages (from sympy->torch>=1.10->transformer_lens) (1.3.0)\n", "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.9/dist-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb>=0.13.5->transformer_lens) (5.0.0)\n", "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.9/dist-packages (from importlib-metadata>=3.6->typeguard>=2.13.3->jaxtyping>=0.2.11->transformer_lens) (3.11.0)\n", + "Installing collected packages: safetensors, appdirs, typing-extensions, fancy-einsum, einops, beartype, typeguard, huggingface-hub, wandb, transformers, jaxtyping, datasets, transformer_lens\n", + " Attempting uninstall: typing-extensions\n", + " Found existing installation: typing_extensions 4.4.0\n", + " Uninstalling typing_extensions-4.4.0:\n", + " Successfully uninstalled typing_extensions-4.4.0\n", + " Attempting uninstall: huggingface-hub\n", + " Found existing installation: huggingface-hub 0.12.0\n", + " Uninstalling huggingface-hub-0.12.0:\n", + " Successfully uninstalled huggingface-hub-0.12.0\n", + " Attempting uninstall: wandb\n", + " Found existing installation: wandb 0.13.4\n", + " Uninstalling wandb-0.13.4:\n", + " Successfully uninstalled wandb-0.13.4\n", + " Attempting uninstall: transformers\n", + " Found existing installation: transformers 4.21.3\n", + " Uninstalling transformers-4.21.3:\n", + " Successfully uninstalled transformers-4.21.3\n", + " Attempting uninstall: datasets\n", + " Found existing installation: datasets 2.4.0\n", + " Uninstalling datasets-2.4.0:\n", + " Successfully uninstalled datasets-2.4.0\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "torchvision 0.13.1+cu116 requires torch==1.12.1, but you have torch 2.0.1 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed appdirs-1.4.4 beartype-0.14.1 datasets-2.14.5 einops-0.6.1 fancy-einsum-0.0.3 huggingface-hub-0.17.3 jaxtyping-0.2.22 safetensors-0.3.3 transformer_lens-1.6.1 transformers-4.33.3 typeguard-4.1.5 typing-extensions-4.8.0 wandb-0.15.11\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0mRequirement already satisfied: jaxtyping==0.2.13 in /usr/local/lib/python3.9/dist-packages (0.2.13)\n", - "Requirement already satisfied: typeguard>=2.13.3 in /usr/local/lib/python3.9/dist-packages (from jaxtyping==0.2.13) (4.1.5)\n", - "Requirement already satisfied: typing-extensions>=3.7.4.1 in /usr/local/lib/python3.9/dist-packages (from jaxtyping==0.2.13) (4.8.0)\n", + "\u001b[0mCollecting jaxtyping==0.2.13\n", + " Downloading jaxtyping-0.2.13-py3-none-any.whl (19 kB)\n", "Requirement already satisfied: numpy>=1.20.0 in /usr/local/lib/python3.9/dist-packages (from jaxtyping==0.2.13) (1.23.4)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.1 in /usr/local/lib/python3.9/dist-packages (from jaxtyping==0.2.13) (4.8.0)\n", + "Requirement already satisfied: typeguard>=2.13.3 in /usr/local/lib/python3.9/dist-packages (from jaxtyping==0.2.13) (4.1.5)\n", "Requirement already satisfied: importlib-metadata>=3.6 in /usr/local/lib/python3.9/dist-packages (from typeguard>=2.13.3->jaxtyping==0.2.13) (5.2.0)\n", "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.9/dist-packages (from importlib-metadata>=3.6->typeguard>=2.13.3->jaxtyping==0.2.13) (3.11.0)\n", + "Installing collected packages: jaxtyping\n", + " Attempting uninstall: jaxtyping\n", + " Found existing installation: jaxtyping 0.2.22\n", + " Uninstalling jaxtyping-0.2.22:\n", + " Successfully uninstalled jaxtyping-0.2.22\n", + "Successfully installed jaxtyping-0.2.13\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mRequirement already satisfied: einops in /usr/local/lib/python3.9/dist-packages (0.6.1)\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0mRequirement already satisfied: protobuf==3.20.* in /usr/local/lib/python3.9/dist-packages (3.20.3)\n", + "\u001b[0mCollecting protobuf==3.20.*\n", + " Downloading protobuf-3.20.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m41.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: protobuf\n", + " Attempting uninstall: protobuf\n", + " Found existing installation: protobuf 3.19.6\n", + " Uninstalling protobuf-3.19.6:\n", + " Successfully uninstalled protobuf-3.19.6\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "tensorflow 2.9.2 requires protobuf<3.20,>=3.9.2, but you have protobuf 3.20.3 which is incompatible.\n", + "tensorboard 2.9.1 requires protobuf<3.20,>=3.9.2, but you have protobuf 3.20.3 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed protobuf-3.20.3\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0mRequirement already satisfied: plotly in /usr/local/lib/python3.9/dist-packages (5.17.0)\n", - "Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.9/dist-packages (from plotly) (8.2.3)\n", + "\u001b[0mCollecting plotly\n", + " Downloading plotly-5.17.0-py2.py3-none-any.whl (15.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m15.6/15.6 MB\u001b[0m \u001b[31m58.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hCollecting tenacity>=6.2.0\n", + " Downloading tenacity-8.2.3-py3-none-any.whl (24 kB)\n", "Requirement already satisfied: packaging in /usr/local/lib/python3.9/dist-packages (from plotly) (23.0)\n", + "Installing collected packages: tenacity, plotly\n", + "Successfully installed plotly-5.17.0 tenacity-8.2.3\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0mRequirement already satisfied: torchtyping in /usr/local/lib/python3.9/dist-packages (0.1.4)\n", + "\u001b[0mCollecting torchtyping\n", + " Downloading torchtyping-0.1.4-py3-none-any.whl (17 kB)\n", "Requirement already satisfied: typeguard>=2.11.1 in /usr/local/lib/python3.9/dist-packages (from torchtyping) (4.1.5)\n", "Requirement already satisfied: torch>=1.7.0 in /usr/local/lib/python3.9/dist-packages (from torchtyping) (2.0.1)\n", - "Requirement already satisfied: nvidia-curand-cu11==10.2.10.91 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (10.2.10.91)\n", - "Requirement already satisfied: jinja2 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (3.1.2)\n", - "Requirement already satisfied: nvidia-nvtx-cu11==11.7.91 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (11.7.91)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (1.12)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu11==11.7.101 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (11.7.101)\n", "Requirement already satisfied: nvidia-cusparse-cu11==11.7.4.91 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (11.7.4.91)\n", - "Requirement already satisfied: networkx in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (3.0)\n", + "Requirement already satisfied: nvidia-nccl-cu11==2.14.3 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (2.14.3)\n", + "Requirement already satisfied: nvidia-nvtx-cu11==11.7.91 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (11.7.91)\n", + "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (2.0.0)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (3.1.2)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (11.7.99)\n", + "Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (10.9.0.58)\n", "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (8.5.0.96)\n", + "Requirement already satisfied: nvidia-curand-cu11==10.2.10.91 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (10.2.10.91)\n", + "Requirement already satisfied: nvidia-cusolver-cu11==11.4.0.1 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (11.4.0.1)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (11.7.99)\n", "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (11.10.3.66)\n", - "Requirement already satisfied: sympy in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (1.12)\n", - "Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (10.9.0.58)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (3.9.0)\n", - "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (2.0.0)\n", - "Requirement already satisfied: nvidia-cusolver-cu11==11.4.0.1 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (11.4.0.1)\n", - "Requirement already satisfied: nvidia-nccl-cu11==2.14.3 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (2.14.3)\n", - "Requirement already satisfied: nvidia-cuda-cupti-cu11==11.7.101 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (11.7.101)\n", "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (4.8.0)\n", - "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (11.7.99)\n", - "Requirement already satisfied: setuptools in /usr/local/lib/python3.9/dist-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.7.0->torchtyping) (66.1.1)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.9/dist-packages (from torch>=1.7.0->torchtyping) (3.0)\n", "Requirement already satisfied: wheel in /usr/local/lib/python3.9/dist-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.7.0->torchtyping) (0.35.1)\n", - "Requirement already satisfied: lit in /usr/local/lib/python3.9/dist-packages (from triton==2.0.0->torch>=1.7.0->torchtyping) (16.0.6)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.9/dist-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.7.0->torchtyping) (66.1.1)\n", + "Requirement already satisfied: lit in /usr/local/lib/python3.9/dist-packages (from triton==2.0.0->torch>=1.7.0->torchtyping) (17.0.1)\n", "Requirement already satisfied: cmake in /usr/local/lib/python3.9/dist-packages (from triton==2.0.0->torch>=1.7.0->torchtyping) (3.27.5)\n", "Requirement already satisfied: importlib-metadata>=3.6 in /usr/local/lib/python3.9/dist-packages (from typeguard>=2.11.1->torchtyping) (5.2.0)\n", "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.9/dist-packages (from importlib-metadata>=3.6->typeguard>=2.11.1->torchtyping) (3.11.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.9/dist-packages (from jinja2->torch>=1.7.0->torchtyping) (2.1.2)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.9/dist-packages (from sympy->torch>=1.7.0->torchtyping) (1.3.0)\n", + "Installing collected packages: torchtyping\n", + "Successfully installed torchtyping-0.1.4\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0mCollecting git+https://github.com/neelnanda-io/neel-plotly.git\n", - " Cloning https://github.com/neelnanda-io/neel-plotly.git to /tmp/pip-req-build-_dov21ld\n", - " Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/neel-plotly.git /tmp/pip-req-build-_dov21ld\n", + " Cloning https://github.com/neelnanda-io/neel-plotly.git to /tmp/pip-req-build-ej11fi36\n", + " Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/neel-plotly.git /tmp/pip-req-build-ej11fi36\n", " Resolved https://github.com/neelnanda-io/neel-plotly.git to commit 6dc24b26f8dec991908479d7445dae496b3430b7\n", " Preparing metadata (setup.py) ... \u001b[?25ldone\n", "\u001b[?25hRequirement already satisfied: einops in /usr/local/lib/python3.9/dist-packages (from neel-plotly==0.0.0) (0.6.1)\n", @@ -296,34 +369,45 @@ "Requirement already satisfied: pandas in /usr/local/lib/python3.9/dist-packages (from neel-plotly==0.0.0) (1.5.0)\n", "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.9/dist-packages (from pandas->neel-plotly==0.0.0) (2.8.2)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/dist-packages (from pandas->neel-plotly==0.0.0) (2022.7.1)\n", - "Requirement already satisfied: packaging in /usr/local/lib/python3.9/dist-packages (from plotly->neel-plotly==0.0.0) (23.0)\n", "Requirement already satisfied: tenacity>=6.2.0 in /usr/local/lib/python3.9/dist-packages (from plotly->neel-plotly==0.0.0) (8.2.3)\n", - "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (11.10.3.66)\n", - "Requirement already satisfied: networkx in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (3.0)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.9/dist-packages (from plotly->neel-plotly==0.0.0) (23.0)\n", + "Requirement already satisfied: nvidia-cusparse-cu11==11.7.4.91 in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (11.7.4.91)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (1.12)\n", + "Requirement already satisfied: nvidia-nvtx-cu11==11.7.91 in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (11.7.91)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (3.1.2)\n", "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (4.8.0)\n", - "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (2.0.0)\n", - "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (11.7.99)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu11==11.7.101 in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (11.7.101)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (3.0)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (11.7.99)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (3.9.0)\n", "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (8.5.0.96)\n", - "Requirement already satisfied: sympy in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (1.12)\n", - "Requirement already satisfied: nvidia-cusparse-cu11==11.7.4.91 in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (11.7.4.91)\n", - "Requirement already satisfied: nvidia-nccl-cu11==2.14.3 in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (2.14.3)\n", - "Requirement already satisfied: nvidia-nvtx-cu11==11.7.91 in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (11.7.91)\n", - "Requirement already satisfied: nvidia-cusolver-cu11==11.4.0.1 in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (11.4.0.1)\n", + "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (2.0.0)\n", "Requirement already satisfied: nvidia-curand-cu11==10.2.10.91 in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (10.2.10.91)\n", + "Requirement already satisfied: nvidia-cusolver-cu11==11.4.0.1 in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (11.4.0.1)\n", + "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (11.10.3.66)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (11.7.99)\n", + "Requirement already satisfied: nvidia-nccl-cu11==2.14.3 in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (2.14.3)\n", "Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (10.9.0.58)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (3.9.0)\n", - "Requirement already satisfied: jinja2 in /usr/local/lib/python3.9/dist-packages (from torch->neel-plotly==0.0.0) (3.1.2)\n", - "Requirement already satisfied: setuptools in /usr/local/lib/python3.9/dist-packages (from nvidia-cublas-cu11==11.10.3.66->torch->neel-plotly==0.0.0) (66.1.1)\n", "Requirement already satisfied: wheel in /usr/local/lib/python3.9/dist-packages (from nvidia-cublas-cu11==11.10.3.66->torch->neel-plotly==0.0.0) (0.35.1)\n", - "Requirement already satisfied: lit in /usr/local/lib/python3.9/dist-packages (from triton==2.0.0->torch->neel-plotly==0.0.0) (16.0.6)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.9/dist-packages (from nvidia-cublas-cu11==11.10.3.66->torch->neel-plotly==0.0.0) (66.1.1)\n", "Requirement already satisfied: cmake in /usr/local/lib/python3.9/dist-packages (from triton==2.0.0->torch->neel-plotly==0.0.0) (3.27.5)\n", + "Requirement already satisfied: lit in /usr/local/lib/python3.9/dist-packages (from triton==2.0.0->torch->neel-plotly==0.0.0) (17.0.1)\n", "Requirement already satisfied: six>=1.5 in /usr/lib/python3/dist-packages (from python-dateutil>=2.8.1->pandas->neel-plotly==0.0.0) (1.14.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.9/dist-packages (from jinja2->torch->neel-plotly==0.0.0) (2.1.2)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.9/dist-packages (from sympy->torch->neel-plotly==0.0.0) (1.3.0)\n", + "Building wheels for collected packages: neel-plotly\n", + " Building wheel for neel-plotly (setup.py) ... \u001b[?25ldone\n", + "\u001b[?25h Created wheel for neel-plotly: filename=neel_plotly-0.0.0-py3-none-any.whl size=10186 sha256=4521b03c4082e2532341a0c18d55bc6cc1e65f682e953baf8c05055c68f186f6\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-jlecotuo/wheels/e1/3c/c0/b5897c402b85e7fc329feb205ad5948b518f0423d891a79f7f\n", + "Successfully built neel-plotly\n", + "Installing collected packages: neel-plotly\n", + "Successfully installed neel-plotly-0.0.0\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0mRequirement already satisfied: kaleido in /usr/local/lib/python3.9/dist-packages (0.2.1)\n", + "\u001b[0mCollecting kaleido\n", + " Downloading kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl (79.9 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m79.9/79.9 MB\u001b[0m \u001b[31m24.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n", + "\u001b[?25hInstalling collected packages: kaleido\n", + "Successfully installed kaleido-0.2.1\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", "\u001b[0m" ] @@ -519,7 +603,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "a7fd7a58e465451dab304adc728e3450", + "model_id": "73d62146761a46219c18780d053c2687", "version_major": 2, "version_minor": 0 }, @@ -533,7 +617,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "a676658b091d43d997ea0a9e8e948ad6", + "model_id": "88bf40cc8eec4473b14c92099eb51423", "version_major": 2, "version_minor": 0 }, @@ -547,7 +631,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "21cd938cdbaa4b58995aaf7bb75ad7fd", + "model_id": "ff695288ae6a4f05b011407432e83d93", "version_major": 2, "version_minor": 0 }, @@ -561,7 +645,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "a965afff2a884accb2618686a99dfafe", + "model_id": "9693966eda78459da3bb3d65f78c6236", "version_major": 2, "version_minor": 0 }, @@ -575,7 +659,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "61aaeb12fe3e47fc8030a68df1308514", + "model_id": "02a6328905414ec087609d8a1889f39b", "version_major": 2, "version_minor": 0 }, @@ -11358,10 +11442,39 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 30, "id": "4c000ecf", "metadata": {}, "outputs": [ + { + "data": { + "text/html": [ + " \n", + " " + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "application/vnd.plotly.v1+json": { @@ -13383,9 +13496,9 @@ } }, "text/html": [ - "
" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# V-weighted version\n", - "name_1_attn = plot_attention(\n", + "#name_1_attn = plot_attention(\n", + "plot_attention(\n", " model, \n", " all_prompts[0],\n", - " MID_IAM_CM_ATTN_HEADS,\n", + " DE_CM_ATTN_HEADS,\n", " name_1_cache,\n", " weighted=True)" ] @@ -36546,19 +36682,19 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 31, "id": "0b343e9f", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "76c1846d5e384b9186a3e66dd192d336", + "model_id": "8d13eaaaa4e1418aa1ce511519e9b1b1", "version_major": 2, "version_minor": 0 }, "text/plain": [ - " 0%| | 0/1024 [00:00