Skip to content

Commit

Permalink
Committing everything for backup purposes
Browse files Browse the repository at this point in the history
  • Loading branch information
ojh31 committed Oct 5, 2023
1 parent 5d396a4 commit 69d9690
Show file tree
Hide file tree
Showing 10 changed files with 306 additions and 202 deletions.
68 changes: 45 additions & 23 deletions classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,24 @@
# ---

# %%
from functools import partial
from typing import List, Optional, Union
from typeguard import typechecked
import jaxtyping
import torch

import random

from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
from transformers import TrainingArguments, Trainer
from torch import Tensor
from datasets import load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformer_lens import HookedTransformer
from datasets import load_from_disk, Dataset, DatasetDict
from tqdm.notebook import tqdm
from utils.store import load_pickle, load_array
from utils.ablation import ablate_resid_with_precalc_mean
from utils.classifier import HookedClassifier


# %%
# Experiment configuration.
BATCH_SIZE = 5
MODEL_NAME = "gpt2"  # base architecture wrapped by HookedClassifier below
# Deduplicated: the diff paste carried both the old ("sst2") and new value;
# the new on-disk location is under data/.
DATASET_FOLDER = "data/sst2"


# %%
Expand All @@ -45,37 +45,60 @@
def tokenize_function(examples):
    """Tokenize the "text" column, padded/truncated to a fixed 512-token length."""
    batch_texts = examples["text"]
    return tokenizer(
        batch_texts,
        padding="max_length",
        truncation=True,
        max_length=512,
    )


# Tokenize the whole dataset in batches, then carve out fixed-size subsets.
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Fixed seed so the train/eval subsets are reproducible across runs.
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(7000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(2000))


# %%
# Load the fine-tuned HF checkpoint and the pickled classification-head
# weights, then wrap both in a HookedClassifier on the gpt2 architecture.
# (The diff paste carried a stale deleted `HookedTransformer.from_pretrained(`
# line here alongside the new call; only the HookedClassifier call is kept.)
model = AutoModelForCausalLM.from_pretrained("./gpt2_imdb_classifier")
class_layer_weights = load_pickle("gpt2_imdb_classifier_classification_head_weights", 'gpt2')

model = HookedClassifier.from_pretrained(
    "data/gpt2-small/gpt2_imdb_classifier",
    "gpt2_imdb_classifier_classification_head_weights",
    "gpt2",
    hf_model=model,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)
#%%
# Quick shape check on a small batch of eval sentences.
model([small_eval_dataset[i]['text'] for i in range(5)]).shape

# %%
def get_classification_prediction(eval_dataset, dataset_idx, verbose=False):
    """Classify one dataset example with the cached classification head.

    Runs the hooked model on ``eval_dataset[dataset_idx]['text']``, takes the
    ln_final-normalised activation at the final position, projects it through
    the saved ``score.weight`` classification head and softmaxes.

    Returns a tuple ``(predicted_class, true_label, probabilities)``.

    (The diff paste duplicated this body with stale deleted lines that read
    from the global ``small_eval_dataset``; only the parameterised version
    is kept.)
    """
    # return_type=None: we only need the activation cache, not the LM logits.
    _, cache = model.run_with_cache(eval_dataset[dataset_idx]['text'], return_type=None)
    last_token_act = cache['ln_final.hook_normalized'][0, -1, :]
    res = torch.softmax(
        torch.tensor(class_layer_weights['score.weight']) @ last_token_act.cpu(),
        dim=-1,
    )
    if verbose:
        print(f"Sentence: {eval_dataset[dataset_idx]['text']}")
        print(f"Prediction: {res.argmax()} Label: {eval_dataset[dataset_idx]['label']}")
    return res.argmax(), eval_dataset[dataset_idx]['label'], res
#%%
# Smoke test: print prediction vs. label for the first eval example.
get_classification_prediction(small_eval_dataset, 0, verbose=True)
#%%
def forward_override(
    model: HookedTransformer,
    input: Union[str, List[str], jaxtyping.Int[Tensor, 'batch pos']],
    return_type: Optional[str] = 'logits',
):
    """Forward pass that swaps the LM head for the saved classification head.

    Caches the model's activations, takes the ln_final-normalised activation
    at the last position of the first batch element, projects it through the
    stored ``score.weight`` matrix and softmaxes over classes.

    Returns the probability vector for ``return_type='logits'``, the argmax
    class index for ``return_type='prediction'``, and ``None`` otherwise.
    """
    _, cache = model.run_with_cache(input, return_type=None)
    final_act = cache['ln_final.hook_normalized'][0, -1, :]
    head_weights = torch.tensor(class_layer_weights['score.weight'])
    probs = torch.softmax(head_weights @ final_act.cpu(), dim=-1)
    if return_type == 'prediction':
        return probs.argmax()
    if return_type == 'logits':
        return probs
    # Any other return_type (including None) deliberately falls through -> None.
#%%
# Sanity check: call the override as a plain function first.
forward_override(model, small_eval_dataset[0]['text'], return_type='prediction')
#%%
# NOTE(review): assigning a plain function to an *instance* attribute does not
# bind it as a method — if model(...) dispatches to self.forward(input), this
# function would receive the input text as its `model` argument. Presumably
# HookedClassifier.__call__ passes the instance explicitly — confirm.
model.forward = forward_override
#%%
model(small_eval_dataset[0]['text'])
# %%
def get_accuracy(eval_dataset, n=300):
correct = 0
Expand All @@ -84,7 +107,6 @@ def get_accuracy(eval_dataset, n=300):
if pred == label:
correct += 1
return correct / n

#%%
get_accuracy(small_eval_dataset)

# %%
14 changes: 11 additions & 3 deletions classifier_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,19 @@ def plot_bin_proportions(
))

fig.update_layout(
title=f"Proportion of Sentiment by Activation ({label})",
title_x=0.5,
title=dict(
text=f"Proportion of Sentiment by Activation ({label})",
x=0.5,
font=dict(
size=18
),
),
showlegend=True,
xaxis_title="Activation",
yaxis_title="Cum. Label proportion",
font=dict( # global font settings
size=24 # global font size
),
)

return fig
Expand All @@ -177,7 +185,7 @@ def plot_bin_proportions(
save_pdf(fig, out_name, model)
fig.show()

activations = get_activations_cached(dataloader, direction_label)
activations = get_activations_cached(dataloader, direction_label, model)
positive_threshold = activations[:, :, 1].flatten().quantile(.999).item()
negative_threshold = activations[:, :, 1].flatten().quantile(.001).item()

Expand Down
117 changes: 117 additions & 0 deletions dataset_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
#%%
import random
import torch
from transformer_lens import HookedTransformer
from transformer_lens.utils import test_prompt, get_attention_mask, LocallyOverridenDefaults
import plotly.express as px
from utils.prompts import CleanCorruptedCacheResults, get_dataset, PromptType, ReviewScaffold, CleanCorruptedDataset
#%%
# Base model for the padding/prompt experiments below.
model = HookedTransformer.from_pretrained("EleutherAI/pythia-1.4b")
#%%
batch_size = 256
device = torch.device("cuda")
prompt_type = PromptType.TREEBANK_TEST
scaffold = ReviewScaffold.CLASSIFICATION
# No activations are needed here, so filter out every hook name.
names_filter = lambda _: False
clean_corrupt_data: CleanCorruptedDataset = get_dataset(
    model, device, prompt_type=prompt_type, scaffold=scaffold,
)
#%%
print(len(clean_corrupt_data))
# #%%
# get_attention_mask(model.tokenizer, clean_corrupt_data.clean_tokens, prepend_bos=False).sum(axis=1)
#%%
# Distribution of real (non-pad) prompt lengths.
non_pad_tokens = clean_corrupt_data.get_num_non_pad_tokens()
px.histogram(non_pad_tokens.cpu(), nbins=100)
#%%
# with LocallyOverridenDefaults(model, padding_side="left"):
# Keep only prompts with 0-25 pad tokens, then run the cached forward pass.
patching_dataset: CleanCorruptedCacheResults = clean_corrupt_data.restrict_by_padding(
    0, 25
).run_with_cache(
    model,
    names_filter=names_filter,
    batch_size=batch_size,
    device=device,
    disable_tqdm=True,
    center=True,
)
print(len(patching_dataset.clean_logit_diffs))
print(patching_dataset)
# %%
len(patching_dataset.clean_logit_diffs), len(clean_corrupt_data.all_prompts)
#%%
patching_dataset.clean_logit_diffs[:10]
#%%
# Pick a random prompt to inspect below.
# Fixed off-by-one: random.randint's upper bound is inclusive, so the
# original randint(0, len(...)) could index one past the end of the list.
sample_index = random.randint(0, len(clean_corrupt_data.all_prompts) - 1)
clean_corrupt_data.all_prompts[sample_index]
#%%
# With padding
test_prompt(
    model.to_string(clean_corrupt_data.clean_tokens[sample_index]),
    model.to_string(clean_corrupt_data.answer_tokens[sample_index, 0, 0]),
    model,
    prepend_space_to_answer=False,
)
#%%
# Without padding
test_prompt(
    clean_corrupt_data.all_prompts[sample_index],
    model.to_string(clean_corrupt_data.answer_tokens[sample_index, 0, 0]),
    model,
    prepend_space_to_answer=False,
)
#%%
# Same padded prompt, but tokenizing with right padding.
with LocallyOverridenDefaults(model, padding_side="right"):
    test_prompt(
        model.to_string(clean_corrupt_data.clean_tokens[sample_index]),
        model.to_string(clean_corrupt_data.answer_tokens[sample_index, 0, 0]),
        model,
        prepend_space_to_answer=False,
    )
#%%
# Artificial left padding
with LocallyOverridenDefaults(model, padding_side="left"):
    test_prompt(
        "".join([model.tokenizer.pad_token] * 2) + clean_corrupt_data.all_prompts[sample_index],
        model.to_string(clean_corrupt_data.answer_tokens[sample_index, 0, 0]),
        model,
        prepend_space_to_answer=False,
        prepend_bos=True,
    )
#%%
# Artificial right padding
with LocallyOverridenDefaults(model, padding_side="right"):
    test_prompt(
        clean_corrupt_data.all_prompts[sample_index] + "".join([model.tokenizer.pad_token] * 2),
        model.to_string(clean_corrupt_data.answer_tokens[sample_index, 0, 0]),
        model,
        prepend_space_to_answer=False,
        prepend_bos=True,
    )
#%%
# Attention mask for right padding
with LocallyOverridenDefaults(model, padding_side="right"):
    print(get_attention_mask(
        model.tokenizer,
        model.to_tokens(
            clean_corrupt_data.all_prompts[sample_index] +
            "".join([model.tokenizer.pad_token] * 2),
            prepend_bos=False
        ),
        prepend_bos=True,
    ))
#%%
test_prompt(
    clean_corrupt_data.all_prompts[0],
    model.to_string(clean_corrupt_data.answer_tokens[0, 0, 0]),
    model,
    prepend_space_to_answer=False,
)
# %%
# NOTE(review): this cell uses prompt index 1 but answer_tokens index
# [0, 0, 1] — presumably the first index should also be 1; confirm.
test_prompt(
    clean_corrupt_data.all_prompts[1],
    model.to_string(clean_corrupt_data.answer_tokens[0, 0, 1]),
    model,
    prepend_space_to_answer=False,
)
# %%
Loading

0 comments on commit 69d9690

Please sign in to comment.