Skip to content

Commit

Permalink
Merge pull request #25 from curt-tigges/feature/patching-random-direc…
Browse files Browse the repository at this point in the history
…tion

Feature/patching random direction
  • Loading branch information
ojh31 authored Oct 6, 2023
2 parents 31b591b + 69d9690 commit 6eeea98
Show file tree
Hide file tree
Showing 15 changed files with 779 additions and 388 deletions.
68 changes: 45 additions & 23 deletions classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,24 @@
# ---

# %%
from functools import partial
from typing import List, Optional, Union
from typeguard import typechecked
import jaxtyping
import torch

import random

from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
from transformers import TrainingArguments, Trainer
from torch import Tensor
from datasets import load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformer_lens import HookedTransformer
from datasets import load_from_disk, Dataset, DatasetDict
from tqdm.notebook import tqdm
from utils.store import load_pickle, load_array
from utils.ablation import ablate_resid_with_precalc_mean
from utils.classifier import HookedClassifier


# %%
BATCH_SIZE = 5
MODEL_NAME ="gpt2"
DATASET_FOLDER = "sst2"
DATASET_FOLDER = "data/sst2"


# %%
Expand All @@ -45,37 +45,60 @@
def tokenize_function(examples):
    """Tokenize the ``"text"`` field of a batch of examples.

    Every sequence is padded to, and truncated at, 512 tokens so that all
    rows produced by ``dataset.map`` have the same length.
    """
    texts = examples["text"]
    return tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=512,
    )


tokenized_datasets = dataset.map(tokenize_function, batched=True)

small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(7000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(2000))


# %%
model = AutoModelForCausalLM.from_pretrained("./gpt2_imdb_classifier")
class_layer_weights = load_pickle("gpt2_imdb_classifier_classification_head_weights", 'gpt2')

model = HookedTransformer.from_pretrained(
model = HookedClassifier.from_pretrained(
"data/gpt2-small/gpt2_imdb_classifier",
"gpt2_imdb_classifier_classification_head_weights",
"gpt2",
hf_model=model,
center_unembed=True,
center_writing_weights=True,
fold_ln=True,
refactor_factored_attn_matrices=True,
)
#%%
model([small_eval_dataset[i]['text'] for i in range(5)]).shape

# %%
def get_classification_prediction(eval_dataset, dataset_idx, verbose=False):
    """Classify one example of ``eval_dataset`` with the classification head.

    The text is run through ``model`` once, the final-layer-norm activation at
    the last position is projected with the ``score.weight`` matrix from
    ``class_layer_weights``, and a softmax is taken over the classes.

    Args:
        eval_dataset: Indexable dataset whose items expose ``'text'`` and
            ``'label'`` entries.
        dataset_idx: Index of the example to classify.
        verbose: When true, print the sentence, the prediction and the label.

    Returns:
        A tuple ``(predicted_class, label, class_probabilities)``.
    """
    # Always read from the dataset that was passed in, never from a global,
    # so the function reports on the data the caller asked about.
    example = eval_dataset[dataset_idx]

    # Only the activation cache is used; no model output is requested.
    _, cache = model.run_with_cache(example['text'], return_type=None)
    last_token_act = cache['ln_final.hook_normalized'][0, -1, :]
    res = torch.softmax(
        torch.tensor(class_layer_weights['score.weight']) @ last_token_act.cpu(),
        dim=-1,
    )
    prediction = res.argmax()

    if verbose:
        print(f"Sentence: {example['text']}")
        print(f"Prediction: {prediction} Label: {example['label']}")

    return prediction, example['label'], res
#%%
get_classification_prediction(small_eval_dataset, 0, verbose=True)
#%%
def forward_override(
    model: HookedTransformer,
    input: Union[str, List[str], jaxtyping.Int[Tensor, 'batch pos']],
    return_type: Optional[str] = 'logits',
):
    """Run ``model`` and score the last position with the classification head.

    Args:
        model: Model exposing ``run_with_cache``.
        input: Text, list of texts, or token tensor forwarded to
            ``model.run_with_cache``. Only batch element 0 is scored.
        return_type: ``'logits'`` returns the per-class scores after softmax
            (i.e. probabilities, despite the name); ``'prediction'`` returns
            the argmax class index; ``None`` returns ``None``.

    Returns:
        A tensor of class probabilities, a scalar tensor holding the predicted
        class, or ``None``, depending on ``return_type``.

    Raises:
        ValueError: If ``return_type`` is not one of the supported values.

    NOTE(review): a plain function assigned to ``model.forward`` is not bound,
    so ``model`` is not passed automatically; ``run_with_cache`` presumably
    also calls ``forward`` internally — confirm before installing this
    function as an override.
    """
    # Fail fast instead of silently returning None for a misspelled option.
    if return_type not in ('logits', 'prediction', None):
        raise ValueError(
            f"Unsupported return_type {return_type!r}; "
            "expected 'logits', 'prediction' or None"
        )

    _, cache = model.run_with_cache(input, return_type=None)
    last_token_act = cache['ln_final.hook_normalized'][0, -1, :]
    class_probs = torch.softmax(
        torch.tensor(class_layer_weights['score.weight']) @ last_token_act.cpu(),
        dim=-1
    )
    if return_type == 'logits':
        return class_probs
    if return_type == 'prediction':
        return class_probs.argmax()
    return None
#%%
forward_override(model, small_eval_dataset[0]['text'], return_type='prediction')
#%%
model.forward = forward_override
#%%
model(small_eval_dataset[0]['text'])
# %%
def get_accuracy(eval_dataset, n=300):
correct = 0
Expand All @@ -84,7 +107,6 @@ def get_accuracy(eval_dataset, n=300):
if pred == label:
correct += 1
return correct / n

#%%
get_accuracy(small_eval_dataset)

# %%
14 changes: 11 additions & 3 deletions classifier_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,19 @@ def plot_bin_proportions(
))

fig.update_layout(
title=f"Proportion of Sentiment by Activation ({label})",
title_x=0.5,
title=dict(
text=f"Proportion of Sentiment by Activation ({label})",
x=0.5,
font=dict(
size=18
),
),
showlegend=True,
xaxis_title="Activation",
yaxis_title="Cum. Label proportion",
font=dict( # global font settings
size=24 # global font size
),
)

return fig
Expand All @@ -177,7 +185,7 @@ def plot_bin_proportions(
save_pdf(fig, out_name, model)
fig.show()

activations = get_activations_cached(dataloader, direction_label)
activations = get_activations_cached(dataloader, direction_label, model)
positive_threshold = activations[:, :, 1].flatten().quantile(.999).item()
negative_threshold = activations[:, :, 1].flatten().quantile(.001).item()

Expand Down
70 changes: 26 additions & 44 deletions compare_lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
# %%
MODEL_NAME = "gpt2-small"
PATTERN = (
r'^(kmeans|pca|das|logistic_regression|mean_diff)_'
r'(simple_train|treebank_train)_'
r'(ADJ|ALL)_'
r'^(kmeans_simple_train_ADJ|pca_simple_train_ADJ|das_simple_train_ADJ|logistic_regression_simple_train_ADJ|mean_diff_simple_train_ADJ|random_direction)_'
r'layer(\d*)'
r'\.npy$'
)
Expand All @@ -22,19 +20,26 @@ def get_directions():
matching_files = [filename for filename in os.listdir(dir_path) if re.match(PATTERN, filename)]
return sorted(matching_files)
#%%
def parse_filename(filename):
    """Split a direction filename into its components using ``PATTERN``.

    Args:
        filename: Name of a saved direction file.

    Returns:
        A tuple ``(method, data, position, layer)`` where ``layer`` is an int.

    Raises:
        ValueError: If ``filename`` does not match ``PATTERN`` (or the layer
            group cannot be converted to an integer).
    """
    match = re.search(PATTERN, filename)
    if match is None:
        # re.search returns None on no match; report the offending name
        # rather than failing with an AttributeError on ``None.groups()``.
        raise ValueError(f"Filename {filename!r} does not match PATTERN")
    method, data, position, layer = match.groups()
    return method, data, position, int(layer)
def clean_labels(labels: pd.Series):
    """Shorten direction filenames into display labels.

    Applies a fixed sequence of literal substring replacements (order matters,
    e.g. ``"simple_train_"`` is removed before ``"_ADJ"``). Works on any
    object exposing the pandas ``.str`` accessor, so a ``pd.Index`` such as
    ``df.columns`` is accepted as well as a ``pd.Series``.

    Args:
        labels: String labels to clean.

    Returns:
        An object of the same kind as ``labels`` holding the cleaned strings.
    """
    replacements = (
        ("simple_train_", ""),
        ("logistic_regression", "LR"),
        ("das", "DAS"),
        ("kmeans", "K_means"),
        ("pca", "PCA"),
        ("treebank_train_ALL", "treebank"),
        ("mean_diff", "Mean_diff"),
        ("random_direction", "Random"),
        ("_ADJ", ""),
        (".npy", ""),
        ("000", "00"),
    )
    for old, new in replacements:
        # regex=False keeps every pattern literal regardless of the pandas
        # version's default; otherwise the "." in ".npy" would match any
        # character under regex semantics.
        labels = labels.str.replace(old, new, regex=False)
    return labels
#%%
direction_labels = get_directions()
direction_layers = [extract_layer_from_string(label) for label in direction_labels]
directions = [load_array(filename, MODEL_NAME).squeeze() for filename in direction_labels]
direction_labels = [zero_pad_layer_string(label) for label in direction_labels]
label_tuples = [parse_filename(label) for label in direction_labels]
treebank_layers = [layer for i, layer in enumerate(direction_layers) if label_tuples[i][1] == 'treebank_train']
label_multiindex = pd.MultiIndex.from_tuples(label_tuples, names=['method', 'data', 'position', 'layer'])
stacked = np.stack(directions)
stacked = (stacked.T / np.linalg.norm(stacked, axis=1)).T
similarities: Float[np.ndarray, "direction d_model"] = einops.einsum(
Expand All @@ -52,47 +57,24 @@ def move_row_to_end(row, df):
rows.remove(row)
rows.append(row)
return df.loc[rows]
#%%
def clean_labels(labels: pd.Series):
    """Shorten direction labels for display.

    The substitutions are applied one after another in the order listed, so
    later patterns see the output of earlier ones.
    """
    substitutions = [
        ("simple_train_", ""),
        ("logistic_regression", "LR"),
        ("das", "DAS"),
        ("kmeans", "K_means"),
        ("pca", "PCA"),
        ("treebank_train_ALL", "treebank"),
        ("_0", ""),
        ("mean_diff", "Mean_diff"),
        ("_ADJ", ""),
    ]
    cleaned = labels
    for pattern, replacement in substitutions:
        cleaned = cleaned.str.replace(pattern, replacement)
    return cleaned
# %%
df = pd.DataFrame(
similarities,
columns=label_multiindex,
index=label_multiindex,
columns=direction_labels,
index=direction_labels,
).sort_index(axis=0).sort_index(axis=1)
df.columns = clean_labels(df.columns)
df.index = clean_labels(df.index)
df = df.loc[
df.index.get_level_values(-1).isin([0]),
df.columns.get_level_values(-1).isin([0]),
]
# df = df.loc[
# (df.index.get_level_values(1) != 'treebank_train'),
# (df.columns.get_level_values(1) != 'treebank_train')
# ]
df = df.loc[
(df.index.get_level_values(1) == 'treebank_train') | (df.index.get_level_values(2) == 'ADJ'),
(df.columns.get_level_values(1) == 'treebank_train') | (df.columns.get_level_values(2) == 'ADJ')
]
df = df.loc[
(df.index.get_level_values(1) != 'treebank_train') | (df.index.get_level_values(0) == 'das'),
(df.columns.get_level_values(1) != 'treebank_train') | (df.columns.get_level_values(0) == 'das')
df.index.str.contains("_layer00"),
df.columns.str.contains("_layer00")
]
df = flatten_multiindex(df)
df.columns = df.columns.str.replace("_layer00", "")
df.index = df.index.str.replace("_layer00", "")
df = df.abs()
df
#%%
# df = df.drop(columns=['das_treebank_train_ALL_0'], index=['das_treebank_train_ALL_0'])
df.columns = clean_labels(df.columns)
df.index = clean_labels(df.index)
styled = (
df.style
.background_gradient(cmap='Reds', vmin=0, vmax=1)
Expand Down
117 changes: 117 additions & 0 deletions dataset_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#%%
import random
import torch
from transformer_lens import HookedTransformer
from transformer_lens.utils import test_prompt, get_attention_mask, LocallyOverridenDefaults
import plotly.express as px
from utils.prompts import CleanCorruptedCacheResults, get_dataset, PromptType, ReviewScaffold, CleanCorruptedDataset
#%%
model = HookedTransformer.from_pretrained("EleutherAI/pythia-1.4b")
#%%
batch_size = 256
device = torch.device("cuda")
prompt_type = PromptType.TREEBANK_TEST
scaffold = ReviewScaffold.CLASSIFICATION
names_filter = lambda _: False
clean_corrupt_data: CleanCorruptedDataset = get_dataset(
model, device, prompt_type=prompt_type, scaffold=scaffold,
)
#%%
print(len(clean_corrupt_data))
# #%%
# get_attention_mask(model.tokenizer, clean_corrupt_data.clean_tokens, prepend_bos=False).sum(axis=1)
#%%
non_pad_tokens = clean_corrupt_data.get_num_non_pad_tokens()
px.histogram(non_pad_tokens.cpu(), nbins=100)
#%%
# with LocallyOverridenDefaults(model, padding_side="left"):
patching_dataset: CleanCorruptedCacheResults = clean_corrupt_data.restrict_by_padding(
0, 25
).run_with_cache(
model,
names_filter=names_filter,
batch_size=batch_size,
device=device,
disable_tqdm=True,
center=True,
)
print(len(patching_dataset.clean_logit_diffs))
print(patching_dataset)
# %%
len(patching_dataset.clean_logit_diffs), len(clean_corrupt_data.all_prompts)
#%%
patching_dataset.clean_logit_diffs[:10]
#%%
# random.randint includes its upper bound, so randint(0, len(x)) can yield
# len(x) and index one past the end; randrange excludes the upper bound.
sample_index = random.randrange(len(clean_corrupt_data.all_prompts))
clean_corrupt_data.all_prompts[sample_index]
#%%
# With padding
test_prompt(
model.to_string(clean_corrupt_data.clean_tokens[sample_index]),
model.to_string(clean_corrupt_data.answer_tokens[sample_index, 0, 0]),
model,
prepend_space_to_answer=False,
)
#%%
# Without padding
test_prompt(
clean_corrupt_data.all_prompts[sample_index],
model.to_string(clean_corrupt_data.answer_tokens[sample_index, 0, 0]),
model,
prepend_space_to_answer=False,
)
#%%
with LocallyOverridenDefaults(model, padding_side="right"):
test_prompt(
model.to_string(clean_corrupt_data.clean_tokens[sample_index]),
model.to_string(clean_corrupt_data.answer_tokens[sample_index, 0, 0]),
model,
prepend_space_to_answer=False,
)
#%%
# Artificial left padding
with LocallyOverridenDefaults(model, padding_side="left"):
test_prompt(
"".join([model.tokenizer.pad_token] * 2) + clean_corrupt_data.all_prompts[sample_index],
model.to_string(clean_corrupt_data.answer_tokens[sample_index, 0, 0]),
model,
prepend_space_to_answer=False,
prepend_bos=True,
)
#%%
# Artificial right padding
with LocallyOverridenDefaults(model, padding_side="right"):
test_prompt(
clean_corrupt_data.all_prompts[sample_index] + "".join([model.tokenizer.pad_token] * 2),
model.to_string(clean_corrupt_data.answer_tokens[sample_index, 0, 0]),
model,
prepend_space_to_answer=False,
prepend_bos=True,
)
#%%
# Attention mask for right padding
with LocallyOverridenDefaults(model, padding_side="right"):
print(get_attention_mask(
model.tokenizer,
model.to_tokens(
clean_corrupt_data.all_prompts[sample_index] +
"".join([model.tokenizer.pad_token] * 2),
prepend_bos=False
),
prepend_bos=True,
))
#%%
test_prompt(
clean_corrupt_data.all_prompts[0],
model.to_string(clean_corrupt_data.answer_tokens[0, 0, 0]),
model,
prepend_space_to_answer=False,
)
# %%
test_prompt(
clean_corrupt_data.all_prompts[1],
model.to_string(clean_corrupt_data.answer_tokens[0, 0, 1]),
model,
prepend_space_to_answer=False,
)
# %%
Loading

0 comments on commit 6eeea98

Please sign in to comment.