
Commit

Generated scatterplot
ojh31 committed Oct 16, 2023
1 parent 0dc9853 commit 143adee
Showing 1 changed file with 106 additions and 40 deletions.
146 changes: 106 additions & 40 deletions fit_one_sided_directions.py
@@ -1,6 +1,7 @@
#%%
import einops
from functools import partial
import numpy as np
import torch
from torch import Tensor
from torch.utils.data import DataLoader
@@ -13,6 +14,7 @@
from tqdm.notebook import tqdm
import pandas as pd
import yaml
import plotly.express as px
from utils.store import load_array, save_html, save_array, is_file, get_model_name, clean_label, save_text
#%%
torch.set_grad_enabled(False)
@@ -22,48 +24,48 @@
)
#%%
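# Residual stream activation after block 0 ('blocks.0.hook_resid_post')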
ACT_NAME = get_act_name("resid_post", 0)
# #%%
# BATCH_SIZE = 64
# owt_data = load_dataset("stas/openwebtext-10k", split="train")
# dataset = tokenize_and_concatenate(owt_data, model.tokenizer)
# data_loader = DataLoader(
# dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True
# )
# #%% # Neutral
# count = 0
# total = torch.zeros(model.cfg.d_model)
# for batch in tqdm(data_loader):
# _, cache = model.run_with_cache(
# batch['tokens'],
# return_type=None,
# names_filter = lambda name: name == ACT_NAME
# )
# count += 1
# total += cache[ACT_NAME][:, 1, :].mean(dim=0).cpu()
# neutral_activation = total / count
# print(neutral_activation.shape, neutral_activation.norm())
#%%
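# Load OpenWebText-10k and tokenize/concatenate it into fixed-length sequences for batched forward passes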
BATCH_SIZE = 64
owt_data = load_dataset("stas/openwebtext-10k", split="train")
dataset = tokenize_and_concatenate(owt_data, model.tokenizer)
data_loader = DataLoader(
dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True
)
#%% # Neutral
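# Neutral baseline: mean layer-0 resid_post activation at sequence position 1, averaged across OWT batches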
count = 0
total = torch.zeros(model.cfg.d_model)
for batch in tqdm(data_loader):
_, cache = model.run_with_cache(
batch['tokens'],
return_type=None,
names_filter = lambda name: name == ACT_NAME
)
count += 1
total += cache[ACT_NAME][:, 1, :].mean(dim=0).cpu()
neutral_activation = total / count
print(neutral_activation.shape, neutral_activation.norm())
#%% Handmade prompts
with open("prompts.yaml", "r") as f:
prompt_dict = yaml.safe_load(f)
#%% Handmade neutral
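# Handmade neutral baseline: keep only adjectives that remain a single token once a leading space is prepended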
neutral_str_tokens = prompt_dict['neutral_adjectives']
neutral_single_tokens = []
for token in neutral_str_tokens:
token = " " + token
if len(model.to_str_tokens(token, prepend_bos=False)) == 1:
neutral_single_tokens.append(token)
neutral_tokens = model.to_tokens(
neutral_single_tokens,
prepend_bos=True,
)
assert neutral_tokens.shape[1] == 2
_, neutral_cache = model.run_with_cache(
neutral_tokens,
return_type=None,
names_filter = lambda name: name == ACT_NAME
)
neutral_activation = neutral_cache[ACT_NAME][:, -1].mean(dim=0).cpu()
print(neutral_activation.shape, neutral_activation.norm())
# neutral_str_tokens = prompt_dict['neutral_adjectives']
# neutral_single_tokens = []
# for token in neutral_str_tokens:
# token = " " + token
# if len(model.to_str_tokens(token, prepend_bos=False)) == 1:
# neutral_single_tokens.append(token)
# neutral_tokens = model.to_tokens(
# neutral_single_tokens,
# prepend_bos=True,
# )
# assert neutral_tokens.shape[1] == 2
# _, neutral_cache = model.run_with_cache(
# neutral_tokens,
# return_type=None,
# names_filter = lambda name: name == ACT_NAME
# )
# neutral_activation = neutral_cache[ACT_NAME][:, -1].mean(dim=0).cpu()
# print(neutral_activation.shape, neutral_activation.norm())
#%% # Positive
#%%
positive_str_tokens = (
@@ -120,14 +122,78 @@
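# Each direction is its activation minus the neutral baseline, rescaled to unit norm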
negative_direction = negative_activation - neutral_activation
positive_direction = positive_direction / positive_direction.norm()
negative_direction = negative_direction / negative_direction.norm()
torch.cosine_similarity(positive_direction, negative_direction, dim=0)
#%%
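# 'Valenced' direction: normalized sum of the positive and negative directions (their shared component);
# 'sentiment' direction: their normalized difference (what separates positive from negative)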
is_valenced_direction = positive_direction + negative_direction
is_valenced_direction = is_valenced_direction / is_valenced_direction.norm()
is_valenced_direction = is_valenced_direction.to(device)
sentiment_direction = positive_direction - negative_direction
sentiment_direction = sentiment_direction / sentiment_direction.norm()
sentiment_direction = sentiment_direction.to(device)
torch.cosine_similarity(is_valenced_direction, sentiment_direction, dim=0)
#%%
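# Project each OWT token's layer-0 resid_post activation onto the two directions, collecting per-token scores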
all_tokens = torch.tensor([], dtype=torch.int32, device=device,)
val_scores = torch.tensor([], dtype=torch.float32, device=device,)
sent_scores = torch.tensor([], dtype=torch.float32, device=device,)
for batch in tqdm(data_loader):
batch_tokens = batch['tokens'].to(device)
_, cache = model.run_with_cache(
batch_tokens,
return_type=None,
names_filter = lambda name: name == ACT_NAME
)
val_score = einops.einsum(
cache[ACT_NAME],
is_valenced_direction,
"batch pos d_model, d_model -> batch pos",
)
sent_score = einops.einsum(
cache[ACT_NAME],
sentiment_direction,
"batch pos d_model, d_model -> batch pos",
)
val_score = einops.rearrange(
val_score, "batch pos -> (batch pos)"
)
sent_score = einops.rearrange(
sent_score, "batch pos -> (batch pos)"
)
flat_tokens = einops.rearrange(
batch_tokens, "batch pos -> (batch pos)"
)
all_tokens = torch.cat([all_tokens, flat_tokens])
val_scores = torch.cat([val_scores, val_score])
sent_scores = torch.cat([sent_scores, sent_score])
if len(all_tokens) > 10_000:
break
val_scores = val_scores.cpu().numpy()
sent_scores = sent_scores.cpu().numpy()
all_tokens = all_tokens.cpu().numpy()
print(val_scores.shape, sent_scores.shape, all_tokens.shape)
#%%
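# Scatter valenced score against sentiment score for each token (point text is the raw token id)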
fig = px.scatter(
x=val_scores,
y=sent_scores,
text=all_tokens,
labels=dict(x="Valenced", y="Sentiment"),
)
fig.update_layout(
title=dict(
text="Valenced vs Sentiment activations",
x=0.5,
)
)
fig.show()
#%%
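# Decode a handful of token ids, presumably points of interest picked off the scatter plot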
model.to_str_tokens(np.array([
49435, 35333, 42262, 44406, 32554, 25872, 23590, 39609
]))
#%%
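# Persist the mean-difference directions for reuse elsewhere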
save_array(
positive_direction.cpu().numpy(), "mean_diff_positive_layer01", model
)
save_array(
negative_direction.cpu().numpy(), "mean_diff_negative_layer01", model
)
#%% # compute cosine similarity
torch.cosine_similarity(positive_direction, negative_direction, dim=0)

# %%
