From 143adee49097fb057acceafe3b5885639d198ba4 Mon Sep 17 00:00:00 2001
From: skar0
Date: Mon, 16 Oct 2023 16:18:06 +0100
Subject: [PATCH] Generated scatterplot

---
 fit_one_sided_directions.py | 146 ++++++++++++++++++++++++++----------
 1 file changed, 106 insertions(+), 40 deletions(-)

diff --git a/fit_one_sided_directions.py b/fit_one_sided_directions.py
index b0caa17b..3171c8e9 100644
--- a/fit_one_sided_directions.py
+++ b/fit_one_sided_directions.py
@@ -1,6 +1,7 @@
 #%%
 import einops
 from functools import partial
+import numpy as np
 import torch
 from torch import Tensor
 from torch.utils.data import DataLoader
@@ -13,6 +14,7 @@
 from tqdm.notebook import tqdm
 import pandas as pd
 import yaml
+import plotly.express as px
 from utils.store import load_array, save_html, save_array, is_file, get_model_name, clean_label, save_text
 #%%
 torch.set_grad_enabled(False)
@@ -22,48 +24,48 @@
 )
 #%%
 ACT_NAME = get_act_name("resid_post", 0)
-# #%%
-# BATCH_SIZE = 64
-# owt_data = load_dataset("stas/openwebtext-10k", split="train")
-# dataset = tokenize_and_concatenate(owt_data, model.tokenizer)
-# data_loader = DataLoader(
-#     dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True
-# )
-# #%% # Neutral
-# count = 0
-# total = torch.zeros(model.cfg.d_model)
-# for batch in tqdm(data_loader):
-#     _, cache = model.run_with_cache(
-#         batch['tokens'],
-#         return_type=None,
-#         names_filter = lambda name: name == ACT_NAME
-#     )
-#     count += 1
-#     total += cache[ACT_NAME][:, 1, :].mean(dim=0).cpu()
-# neutral_activation = total / count
-# print(neutral_activation.shape, neutral_activation.norm())
+#%%
+BATCH_SIZE = 64
+owt_data = load_dataset("stas/openwebtext-10k", split="train")
+dataset = tokenize_and_concatenate(owt_data, model.tokenizer)
+data_loader = DataLoader(
+    dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True
+)
+#%% # Neutral
+count = 0
+total = torch.zeros(model.cfg.d_model)
+for batch in tqdm(data_loader):
+    _, cache = model.run_with_cache(
+        batch['tokens'],
+        return_type=None,
+        names_filter = lambda name: name == ACT_NAME
+    )
+    count += 1
+    total += cache[ACT_NAME][:, 1, :].mean(dim=0).cpu()
+neutral_activation = total / count
+print(neutral_activation.shape, neutral_activation.norm())
 #%% Handmade prompts
 with open("prompts.yaml", "r") as f:
     prompt_dict = yaml.safe_load(f)
 #%% Handmade neutral
-neutral_str_tokens = prompt_dict['neutral_adjectives']
-neutral_single_tokens = []
-for token in neutral_str_tokens:
-    token = " " + token
-    if len(model.to_str_tokens(token, prepend_bos=False)) == 1:
-        neutral_single_tokens.append(token)
-neutral_tokens = model.to_tokens(
-    neutral_single_tokens,
-    prepend_bos=True,
-)
-assert neutral_tokens.shape[1] == 2
-_, neutral_cache = model.run_with_cache(
-    neutral_tokens,
-    return_type=None,
-    names_filter = lambda name: name == ACT_NAME
-)
-neutral_activation = neutral_cache[ACT_NAME][:, -1].mean(dim=0).cpu()
-print(neutral_activation.shape, neutral_activation.norm())
+# neutral_str_tokens = prompt_dict['neutral_adjectives']
+# neutral_single_tokens = []
+# for token in neutral_str_tokens:
+#     token = " " + token
+#     if len(model.to_str_tokens(token, prepend_bos=False)) == 1:
+#         neutral_single_tokens.append(token)
+# neutral_tokens = model.to_tokens(
+#     neutral_single_tokens,
+#     prepend_bos=True,
+# )
+# assert neutral_tokens.shape[1] == 2
+# _, neutral_cache = model.run_with_cache(
+#     neutral_tokens,
+#     return_type=None,
+#     names_filter = lambda name: name == ACT_NAME
+# )
+# neutral_activation = neutral_cache[ACT_NAME][:, -1].mean(dim=0).cpu()
+# print(neutral_activation.shape, neutral_activation.norm())
 #%% # Positive
 #%%
 positive_str_tokens = (
@@ -120,6 +122,72 @@
 negative_direction = negative_activation - neutral_activation
 positive_direction = positive_direction / positive_direction.norm()
 negative_direction = negative_direction / negative_direction.norm()
+torch.cosine_similarity(positive_direction, negative_direction, dim=0)
+#%%
+is_valenced_direction = positive_direction + negative_direction
+is_valenced_direction = is_valenced_direction / is_valenced_direction.norm()
+is_valenced_direction = is_valenced_direction.to(device)
+sentiment_direction = positive_direction - negative_direction
+sentiment_direction = sentiment_direction / sentiment_direction.norm()
+sentiment_direction = sentiment_direction.to(device)
+torch.cosine_similarity(is_valenced_direction, sentiment_direction, dim=0)
+#%%
+all_tokens = torch.tensor([], dtype=torch.int32, device=device,)
+val_scores = torch.tensor([], dtype=torch.float32, device=device,)
+sent_scores = torch.tensor([], dtype=torch.float32, device=device,)
+for batch in tqdm(data_loader):
+    batch_tokens = batch['tokens'].to(device)
+    _, cache = model.run_with_cache(
+        batch_tokens,
+        return_type=None,
+        names_filter = lambda name: name == ACT_NAME
+    )
+    val_score = einops.einsum(
+        cache[ACT_NAME],
+        is_valenced_direction,
+        "batch pos d_model, d_model -> batch pos",
+    )
+    sent_score = einops.einsum(
+        cache[ACT_NAME],
+        sentiment_direction,
+        "batch pos d_model, d_model -> batch pos",
+    )
+    val_score = einops.rearrange(
+        val_score, "batch pos -> (batch pos)"
+    )
+    sent_score = einops.rearrange(
+        sent_score, "batch pos -> (batch pos)"
+    )
+    flat_tokens = einops.rearrange(
+        batch_tokens, "batch pos -> (batch pos)"
+    )
+    all_tokens = torch.cat([all_tokens, flat_tokens])
+    val_scores = torch.cat([val_scores, val_score])
+    sent_scores = torch.cat([sent_scores, sent_score])
+    if len(all_tokens) > 10_000:
+        break
+val_scores = val_scores.cpu().numpy()
+sent_scores = sent_scores.cpu().numpy()
+all_tokens = all_tokens.cpu().numpy()
+print(val_scores.shape, sent_scores.shape, all_tokens.shape)
+#%%
+fig = px.scatter(
+    x=val_scores,
+    y=sent_scores,
+    text=all_tokens,
+    labels=dict(x="Valenced", y="Sentiment"),
+)
+fig.update_layout(
+    title=dict(
+        text="Valenced vs Sentiment activations",
+        x=0.5,
+    )
+)
+fig.show()
+#%%
+model.to_str_tokens(np.array([
+    49435, 35333, 42262, 44406, 32554, 25872, 23590, 39609
+]))
 #%%
 save_array(
     positive_direction.cpu().numpy(), "mean_diff_positive_layer01", model
@@ -127,7 +195,5 @@
 )
 save_array(
     negative_direction.cpu().numpy(), "mean_diff_negative_layer01", model
 )
-#%% # compute cosine similarity
-torch.cosine_similarity(positive_direction, negative_direction, dim=0)
 # %%
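
For reference, a minimal, self-contained sketch of the projection step the
new cells perform: each token's residual-stream activation is scored by its
dot product with two unit directions, the normalised sum ("valenced") and
difference ("sentiment") of the positive and negative mean-difference
directions. All names below are illustrative stand-ins rather than
identifiers from this repository, and the random tensors stand in for the
real cached activations and fitted directions; only torch and einops are
assumed.

    import einops
    import torch

    d_model = 768                             # residual width, illustrative only
    resid = torch.randn(4, 16, d_model)       # fake [batch, pos, d_model] activations
    positive_direction = torch.randn(d_model)
    negative_direction = torch.randn(d_model)
    positive_direction = positive_direction / positive_direction.norm()
    negative_direction = negative_direction / negative_direction.norm()

    # As in the patch: sum and difference of the two unit directions,
    # each re-normalised before projecting.
    is_valenced = positive_direction + negative_direction
    is_valenced = is_valenced / is_valenced.norm()
    sentiment = positive_direction - negative_direction
    sentiment = sentiment / sentiment.norm()

    # One scalar score per (batch, pos) token for each direction; flattened,
    # these become the x/y coordinates of the scatterplot.
    val_score = einops.einsum(resid, is_valenced, "batch pos d_model, d_model -> batch pos")
    sent_score = einops.einsum(resid, sentiment, "batch pos d_model, d_model -> batch pos")
    print(val_score.shape, sent_score.shape)  # torch.Size([4, 16]) twice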