diff --git a/classification.py b/classification.py
index f49149e..bfc1115 100644
--- a/classification.py
+++ b/classification.py
@@ -15,24 +15,24 @@
 # ---
 # %%
+from functools import partial
+from typing import List, Optional, Union
+from typeguard import typechecked
+import jaxtyping
 import torch
-
-import random
-
-from torch.utils.data import DataLoader
-from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForCausalLM
-from transformers import TrainingArguments, Trainer
+from torch import Tensor
+from datasets import load_from_disk
+from transformers import AutoTokenizer, AutoModelForCausalLM
 from transformer_lens import HookedTransformer
-from datasets import load_from_disk, Dataset, DatasetDict
 from tqdm.notebook import tqdm
 from utils.store import load_pickle, load_array
-from utils.ablation import ablate_resid_with_precalc_mean
+from utils.classifier import HookedClassifier
 # %%
 BATCH_SIZE = 5
 MODEL_NAME ="gpt2"
-DATASET_FOLDER = "sst2"
+DATASET_FOLDER = "data/sst2"
@@ -45,37 +45,60 @@
 def tokenize_function(examples):
     return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)
-
 tokenized_datasets = dataset.map(tokenize_function, batched=True)
 small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(7000))
 small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(2000))
+
 # %%
-model = AutoModelForCausalLM.from_pretrained("./gpt2_imdb_classifier")
-class_layer_weights = load_pickle("gpt2_imdb_classifier_classification_head_weights", 'gpt2')
-
-model = HookedTransformer.from_pretrained(
+model = HookedClassifier.from_pretrained(
+    "data/gpt2-small/gpt2_imdb_classifier",
+    "gpt2_imdb_classifier_classification_head_weights",
     "gpt2",
-    hf_model=model,
     center_unembed=True,
     center_writing_weights=True,
     fold_ln=True,
     refactor_factored_attn_matrices=True,
 )
+#%%
+model([small_eval_dataset[i]['text'] for i in range(5)]).shape
 # %%
 def get_classification_prediction(eval_dataset, dataset_idx, verbose=False):
-    logits, cache = model.run_with_cache(small_eval_dataset[dataset_idx]['text'])
+    _, cache = model.run_with_cache(eval_dataset[dataset_idx]['text'], return_type=None)
     last_token_act = cache['ln_final.hook_normalized'][0, -1, :]
     res = torch.softmax(torch.tensor(class_layer_weights['score.weight']) @ last_token_act.cpu(), dim=-1)
     if verbose:
-        print(f"Sentence: {small_eval_dataset[dataset_idx]['text']}")
-        print(f"Prediction: {res.argmax()} Label: {small_eval_dataset[dataset_idx]['label']}")
-
-    return res.argmax(), small_eval_dataset[dataset_idx]['label'], res
-
+        print(f"Sentence: {eval_dataset[dataset_idx]['text']}")
+        print(f"Prediction: {res.argmax()} Label: {eval_dataset[dataset_idx]['label']}")
+
+    return res.argmax(), eval_dataset[dataset_idx]['label'], res
+#%%
+get_classification_prediction(small_eval_dataset, 0, verbose=True)
+#%%
+def forward_override(
+    model: HookedTransformer,
+    input: Union[str, List[str], jaxtyping.Int[Tensor, 'batch pos']],
+    return_type: Optional[str] = 'logits',
+):
+    _, cache = model.run_with_cache(input, return_type=None)
+    last_token_act = cache['ln_final.hook_normalized'][0, -1, :]
+    logits = torch.softmax(
+        torch.tensor(class_layer_weights['score.weight']) @ last_token_act.cpu(),
+        dim=-1
+    )
+    if return_type == 'logits':
+        return logits
+    elif return_type == 'prediction':
+        return logits.argmax()
+#%%
+forward_override(model, small_eval_dataset[0]['text'], return_type='prediction')
+#%%
+model.forward = partial(forward_override, model)
+#%%
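+# Sanity check: model(...) now dispatches to forward_override, with the model
+# pre-bound via functools.partial so nn.Module.__call__ passes only the input.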
+model(small_eval_dataset[0]['text'])
 # %%
 def get_accuracy(eval_dataset, n=300):
     correct = 0
@@ -84,7 +107,6 @@ def get_accuracy(eval_dataset, n=300):
         if pred == label:
             correct += 1
     return correct / n
-
+#%%
 get_accuracy(small_eval_dataset)
-
 # %%
diff --git a/classifier_accuracy.py b/classifier_accuracy.py
index 2d903d2..83a169b 100644
--- a/classifier_accuracy.py
+++ b/classifier_accuracy.py
@@ -154,11 +154,19 @@ def plot_bin_proportions(
     ))
     fig.update_layout(
-        title=f"Proportion of Sentiment by Activation ({label})",
-        title_x=0.5,
+        title=dict(
+            text=f"Proportion of Sentiment by Activation ({label})",
+            x=0.5,
+            font=dict(
+                size=18
+            ),
+        ),
         showlegend=True,
         xaxis_title="Activation",
         yaxis_title="Cum. Label proportion",
+        font=dict(  # global font settings
+            size=24  # global font size
+        ),
     )
     return fig
@@ -177,7 +185,7 @@ def plot_bin_proportions(
         save_pdf(fig, out_name, model)
     fig.show()
-    activations = get_activations_cached(dataloader, direction_label)
+    activations = get_activations_cached(dataloader, direction_label, model)
     positive_threshold = activations[:, :, 1].flatten().quantile(.999).item()
     negative_threshold = activations[:, :, 1].flatten().quantile(.001).item()
diff --git a/dataset_stats.py b/dataset_stats.py
new file mode 100644
index 0000000..7d84758
--- /dev/null
+++ b/dataset_stats.py
@@ -0,0 +1,117 @@
+#%%
+import random
+import torch
+from transformer_lens import HookedTransformer
+from transformer_lens.utils import test_prompt, get_attention_mask, LocallyOverridenDefaults
+import plotly.express as px
+from utils.prompts import CleanCorruptedCacheResults, get_dataset, PromptType, ReviewScaffold, CleanCorruptedDataset
+#%%
+model = HookedTransformer.from_pretrained("EleutherAI/pythia-1.4b")
+#%%
+batch_size = 256
+device = torch.device("cuda")
+prompt_type = PromptType.TREEBANK_TEST
+scaffold = ReviewScaffold.CLASSIFICATION
+names_filter = lambda _: False
+clean_corrupt_data: CleanCorruptedDataset = get_dataset(
+    model, device, prompt_type=prompt_type, scaffold=scaffold,
+)
+#%%
+print(len(clean_corrupt_data))
+# #%%
+# get_attention_mask(model.tokenizer, clean_corrupt_data.clean_tokens, prepend_bos=False).sum(axis=1)
+#%%
+non_pad_tokens = clean_corrupt_data.get_num_non_pad_tokens()
+px.histogram(non_pad_tokens.cpu(), nbins=100)
+#%%
+# with LocallyOverridenDefaults(model, padding_side="left"):
+patching_dataset: CleanCorruptedCacheResults = clean_corrupt_data.restrict_by_padding(
+    0, 25
+).run_with_cache(
+    model,
+    names_filter=names_filter,
+    batch_size=batch_size,
+    device=device,
+    disable_tqdm=True,
+    center=True,
+)
+print(len(patching_dataset.clean_logit_diffs))
+print(patching_dataset)
+# %%
+len(patching_dataset.clean_logit_diffs), len(clean_corrupt_data.all_prompts)
+#%%
+patching_dataset.clean_logit_diffs[:10]
+#%%
+sample_index = random.randint(0, len(clean_corrupt_data.all_prompts) - 1)
+clean_corrupt_data.all_prompts[sample_index]
+#%%
+# With padding
+test_prompt(
+    model.to_string(clean_corrupt_data.clean_tokens[sample_index]),
+    model.to_string(clean_corrupt_data.answer_tokens[sample_index, 0, 0]),
+    model,
+    prepend_space_to_answer=False,
+)
+#%%
+# Without padding
+test_prompt(
+    clean_corrupt_data.all_prompts[sample_index],
+    model.to_string(clean_corrupt_data.answer_tokens[sample_index, 0, 0]),
+    model,
+    prepend_space_to_answer=False,
+)
+#%%
+with LocallyOverridenDefaults(model, padding_side="right"):
+    test_prompt(
+        model.to_string(clean_corrupt_data.clean_tokens[sample_index]),
+        model.to_string(clean_corrupt_data.answer_tokens[sample_index, 0, 0]),
+        model,
+        prepend_space_to_answer=False,
+    )
+#%%
+# Artificial left padding
+with LocallyOverridenDefaults(model, padding_side="left"):
+    test_prompt(
+        "".join([model.tokenizer.pad_token] * 2) + clean_corrupt_data.all_prompts[sample_index],
+        model.to_string(clean_corrupt_data.answer_tokens[sample_index, 0, 0]),
+        model,
+        prepend_space_to_answer=False,
+        prepend_bos=True,
+    )
+#%%
+# Artificial right padding
+with LocallyOverridenDefaults(model, padding_side="right"):
+    test_prompt(
+        clean_corrupt_data.all_prompts[sample_index] + "".join([model.tokenizer.pad_token] * 2),
+        model.to_string(clean_corrupt_data.answer_tokens[sample_index, 0, 0]),
+        model,
+        prepend_space_to_answer=False,
+        prepend_bos=True,
+    )
+#%%
+# Attention mask for right padding
+with LocallyOverridenDefaults(model, padding_side="right"):
+    print(get_attention_mask(
+        model.tokenizer,
+        model.to_tokens(
+            clean_corrupt_data.all_prompts[sample_index] +
+            "".join([model.tokenizer.pad_token] * 2),
+            prepend_bos=False
+        ),
+        prepend_bos=True,
+    ))
+#%%
+test_prompt(
+    clean_corrupt_data.all_prompts[0],
+    model.to_string(clean_corrupt_data.answer_tokens[0, 0, 0]),
+    model,
+    prepend_space_to_answer=False,
+)
+# %%
+test_prompt(
+    clean_corrupt_data.all_prompts[1],
+    model.to_string(clean_corrupt_data.answer_tokens[0, 0, 1]),
+    model,
+    prepend_space_to_answer=False,
+)
+# %%
diff --git a/direction_patching_results.py b/direction_patching_results.py
index a14c92e..46efb08 100644
--- a/direction_patching_results.py
+++ b/direction_patching_results.py
@@ -39,93 +39,6 @@ def get_cached_csv(
         index_col=0,
         header=[0, 1],
     )
 #%%
-# def export_results(
-#     results: pd.DataFrame, metric_label: str, use_heads_label: str
-# ) -> None:
-#     all_layers = pd.Series([extract_layer_from_string(label) for label in results.index])
-#     das_treebank_layers = all_layers[results.index.str.contains("das_treebank")]
-#     if len(das_treebank_layers) > 0:
-#         mask = ~results.index.str.contains("das") | all_layers.isin(das_treebank_layers)
-#         mask.index = results.index
-#         results = results.loc[mask]
-
-#     layers_style = (
-#         flatten_multiindex(results)
-#         .style
-#         .background_gradient(cmap="Reds", axis=None, low=0, high=1)
-#         .format("{:.1f}%")
-#         .set_caption(f"Direction patching ({metric_label}, {use_heads_label}) in {model}")
-#     )
-#     save_html(layers_style, f"direction_patching_{metric_label}_{use_heads_label}", model)
-#     display(layers_style)
-
-#     missing_data = (
-#         not results.columns.get_level_values(0).str.contains("treebank").any() or
-#         not results.columns.get_level_values(0).str.contains("simple").any()
-#     )
-#     if missing_data:
-#         return
-
-#     s_df = results[~results.index.str.contains("treebank")].copy()
-#     matches = s_df.index.str.extract(DIRECTION_PATTERN)
-#     multiindex = pd.MultiIndex.from_arrays(matches.values.T, names=['method', 'dataset', 'position', 'layer'])
-#     s_df.index = multiindex
-#     s_df = s_df.reset_index().groupby(['method', 'dataset', 'position']).max().drop('layer', axis=1, level=0)
-#     s_df = flatten_multiindex(s_df)
-#     s_df = s_df[["simple_test_ALL", "treebank_test_ALL"]]
-#     s_df.columns = s_df.columns.str.replace("test_", "").str.replace("_ALL", "")
-#     s_df.index = s_df.index.str.replace("_simple_train_ADJ", "")
-#     s_style = (
-#         s_df
-#         .style
-#         .background_gradient(cmap="Reds")
-#         .format("{:.1f}%")
-#         .set_caption(f"Direction patching ({metric_label}, {use_heads_label}) in {model.name}")
-#     )
-#     to_csv(s_df, f"direction_patching_{metric_label}_simple", model, index=True)
-#     save_html(
-#         s_style, f"direction_patching_{metric_label}_{use_heads_label}_simple", model,
-#         font_size=40,
-#     )
-#     display(s_style)
-
-#     t_df = results[results.index.str.contains("das_treebank") & ~results.index.str.contains("None")].copy()
-#     t_df = t_df.loc[:, t_df.columns.get_level_values(0).str.contains("treebank")]
-#     matches = t_df.index.str.extract(DIRECTION_PATTERN)
-#     multiindex = pd.MultiIndex.from_arrays(matches.values.T, names=['method', 'dataset', 'position', 'layer'])
-#     t_df.index = multiindex
-#     t_df = t_df.loc[t_df.index.get_level_values(-1).astype(int) < t_df.index.get_level_values(-1).astype(int).max() - 1]
-#     t_df.sort_index(level=3)
-#     t_df = flatten_multiindex(t_df)
-#     t_df.index = t_df.index.str.replace("das_treebank_train_ALL_0", "")
-#     t_df.columns = ["logit_diff"]
-#     t_df = t_df.T
-#     t_style = t_df.style.background_gradient(cmap="Reds").format("{:.1f}%")
-#     to_csv(t_df, f"direction_patching_{metric_label}_treebank", model, index=True)
-#     save_html(t_style, f"direction_patching_{metric_label}_{use_heads_label}_treebank", model)
-#     display(t_style)
-
-#     p_df = results[~results.index.str.contains("treebank")].copy()
-#     matches = p_df.index.str.extract(DIRECTION_PATTERN)
-#     multiindex = pd.MultiIndex.from_arrays(
-#         matches.values.T, names=['method', 'dataset', 'position', 'layer']
-#     )
-#     p_df.index = multiindex
-#     p_df = p_df[("treebank_test", "ALL")]
-#     p_df = p_df.reset_index()
-#     p_df.columns = p_df.columns.get_level_values(0)
-#     p_df.layer = p_df.layer.astype(int)
-#     fig = px.line(x="layer", y="treebank_test", color="method", data_frame=p_df)
-#     fig.update_layout(
-#         title="Out-of-distribution directional patching performance by method and layer"
-#     )
-#     fig.show()
-#     p_df = flatten_multiindex(p_df)
-#     if use_heads_label == "resid":
-#         to_csv(p_df, f"direction_patching_{metric_label}_layers", model, index=True) # FIXME: add {heads_label}
-#     save_html(fig, f"direction_patching_{metric_label}_{use_heads_label}_plot", model)
-#     save_pdf(fig, f"direction_patching_{metric_label}_{use_heads_label}_plot", model)
-#%%
 def concat_metric_data(
     models: Iterable[str], metric_labels: List[str], use_heads_label: str,
     scaffold: ReviewScaffold = ReviewScaffold.CLASSIFICATION,
@@ -170,7 +83,7 @@ def concat_metric_data(
         .set_caption(f"Direction patching ({metric_label}, {use_heads_label}) in {model}")
     )
     save_html(
-        s_style, f"direction_patching_{metric_label}_{use_heads_label}_simple", model,
+        s_style, f"direction_patching_{use_heads_label}_simple", model,
         font_size=40,
     )
     display(s_style)
@@ -183,11 +96,14 @@
 #%%
 def concat_layer_data(
     models: Iterable[str], metric_label: str, use_heads_label: str,
-    scaffold: ReviewScaffold
+    scaffold: ReviewScaffold = ReviewScaffold.CONTINUATION
 ):
     layer_data = []
     for model in models:
         results = get_cached_csv(metric_label, use_heads_label, scaffold, model)
+        if results.empty:
+            print(f"No results for {model}")
+            continue
         p_df = results[~results.index.str.contains("treebank")].copy()
         matches = p_df.index.str.extract(DIRECTION_PATTERN)
         multiindex = pd.MultiIndex.from_arrays(
@@ -199,7 +115,6 @@ def concat_layer_data(
         p_df.columns = p_df.columns.get_level_values(0)
         p_df.layer = p_df.layer.astype(int)
         p_df['model'] = model
-        p_df['max_layer'] = p_df.layer.max()
         layer_data.append(p_df)
     layer_df = pd.concat(layer_data)
     layer_df = layer_df.loc[layer_df.method.isin([
@@ -216,31 +131,35 @@
         }
     )
     fig.update_layout(
-        title="Out-of-distribution directional patching performance by method and layer",
method and layer", + title=dict( + text=f"Out-of-distribution direction patching performance by method and layer", + x=0.5, + ), width=1600, - height=500, - title_x=0.5, + height=400, font=dict( # global font settings - size=16 # global font size + size=24 # global font size ), ) for axis in fig.layout: if "xaxis" in axis: fig.layout[axis].matches = None - save_pdf(fig, f"direction_patching_{metric_label}_{use_heads_label}_facet_plot", model) - save_html(fig, f"direction_patching_{metric_label}_{use_heads_label}_facet_plot", model) - save_pdf(fig, f"direction_patching_{metric_label}_{use_heads_label}_facet_plot", model) + models_label = models[0].split("-")[0] + save_pdf(fig, f"direction_patching_{metric_label}_{use_heads_label}_{models_label}_facet_plot", model) + save_html(fig, f"direction_patching_{metric_label}_{use_heads_label}_{models_label}_facet_plot", model) + save_pdf(fig, f"direction_patching_{metric_label}_{use_heads_label}_{models_label}_facet_plot", model) fig.show() + save_pdf(fig, f"direction_patching_{metric_label}_{use_heads_label}_{models_label}_facet_plot", model) #%% concat_layer_data( - ["gpt2-small", "gpt2-medium", "gpt2-large", "gpt2-xl"], - "logit_diff", - "resid_gpt2" -) -#%% -concat_layer_data( - ["EleutherAI/pythia-160m", "EleutherAI/pythia-410m", "EleutherAI/pythia-1.4b", "EleutherAI/pythia-2.8b"], + [ + "gpt2-small", + # "gpt2-medium", + "pythia-160m", "pythia-410m", + "pythia-1.4b", + # "pythia-2.8b" + ], "logit_diff", - "resid_pythia" + "resid" ) #%% diff --git a/direction_patching_suite.py b/direction_patching_suite.py index bbfca86..554645b 100644 --- a/direction_patching_suite.py +++ b/direction_patching_suite.py @@ -22,7 +22,7 @@ from path_patching import act_patch, Node, IterNode from utils.prompts import CleanCorruptedCacheResults, get_dataset, PromptType, ReviewScaffold from utils.circuit_analysis import create_cache_for_dir_patching, logit_diff_denoising, prob_diff_denoising, logit_flip_denoising, PatchingMetric -from utils.store import save_array, load_array, save_html, save_pdf, to_csv, get_model_name, extract_layer_from_string, zero_pad_layer_string, DIRECTION_PATTERN, is_file, get_csv, get_csv_path, flatten_multiindex +from utils.store import save_array, load_array, save_html, save_pdf, to_csv, get_model_name, extract_layer_from_string, zero_pad_layer_string, DIRECTION_PATTERN, is_file, get_csv, get_csv_path, flatten_multiindex, save_text, load_text from utils.residual_stream import get_resid_name #%% torch.set_grad_enabled(False) @@ -36,24 +36,24 @@ # 'gpt2-medium', # 'gpt2-large', # 'gpt2-xl', - # 'EleutherAI/pythia-160m', - # 'EleutherAI/pythia-410m', + 'EleutherAI/pythia-160m', + 'EleutherAI/pythia-410m', 'EleutherAI/pythia-1.4b', - # 'EleutherAI/pythia-2.8b', + 'EleutherAI/pythia-2.8b', ] DIRECTION_GLOBS = [ # 'mean_diff_simple_train_ADJ*.npy', # 'pca_simple_train_ADJ*.npy', - # 'kmeans_simple_train_ADJ*.npy', - # 'logistic_regression_simple_train_ADJ*.npy', + 'kmeans_simple_train_ADJ*.npy', + 'logistic_regression_simple_train_ADJ*.npy', 'das_simple_train_ADJ_layer*.npy', # 'das2d_simple_train_ADJ*.npy', # 'das3d_simple_train_ADJ*.npy', - 'random_direction_layer*.npy', + # 'random_direction_layer*.npy', # 'das_treebank*.npy', ] PROMPT_TYPES = [ - PromptType.SIMPLE_TEST, + # PromptType.SIMPLE_TEST, PromptType.TREEBANK_TEST, # PromptType.SIMPLE_TRAIN, # PromptType.COMPLETION, @@ -61,10 +61,10 @@ # PromptType.SIMPLE_MOOD, # PromptType.SIMPLE_FRENCH, ] -SCAFFOLD = ReviewScaffold.CLASSIFICATION +SCAFFOLD = ReviewScaffold.CONTINUATION METRICS = [ 
     PatchingMetric.LOGIT_DIFF_DENOISING,
-    PatchingMetric.LOGIT_FLIP_DENOISING,
+    # PatchingMetric.LOGIT_FLIP_DENOISING,
     # PatchingMetric.PROB_DIFF_DENOISING,
 ]
 USE_HEADS = [False, ]
@@ -88,7 +88,7 @@ def get_directions(model: HookedTransformer) -> Tuple[List[np.ndarray], List[str]]:
         path
         for glob_str in DIRECTION_GLOBS
         for path in glob.glob(os.path.join('data', model_name, glob_str))
-        if "None" not in path and "_all_" not in path
+        if "None" not in path and "_all_" not in path and "_activations" not in path
     ]
     direction_labels = [os.path.split(path)[-1] for path in direction_paths]
     del direction_paths
@@ -274,17 +274,18 @@ def get_dataset_cached(
             model, "cpu", prompt_type=prompt_type, scaffold=scaffold
         )
-        # Filter by padding
-        clean_corrupt_data = clean_corrupt_data.restrict_by_padding(
-            min_tokens=min_tokens, max_tokens=max_tokens
-        )
+        # FIXME: need to uncomment if using max_tokens
+        # # Filter by padding
+        # clean_corrupt_data = clean_corrupt_data.restrict_by_padding(
+        #     min_tokens=min_tokens, max_tokens=max_tokens
+        # )
         dataset_cache[key] = clean_corrupt_data
     return dataset_cache[key]
 #%%
-def get_results_cached(
+def get_result_cached(
     patching_metric_base: PatchingMetric,
     prompt_type: PromptType,
     position: str,
@@ -302,13 +303,13 @@
     disable_tqdm: bool = True,
 ):
     use_csv = USE_CACHE and heads is None and not all_layers
-    csv_name = (
+    txt_name = (
         patching_metric_base.__name__.replace('_denoising', '') +
         f"_{prompt_type.value}_{scaffold}_{min_tokens}_{max_tokens}_{position}_"
-        f"{direction_label}.csv"
+        f"{direction_label}.txt"
     )
-    if use_csv and is_file(csv_name, model):
-        return get_csv(csv_name, model, index_col=0, header=[0, 1])
+    if use_csv and is_file(txt_name, model):
+        return float(load_text(txt_name, model))
     result = get_results_for_direction_and_position(
         patching_metric_base=patching_metric_base,
         prompt_type=prompt_type,
@@ -327,7 +328,7 @@
         disable_tqdm=disable_tqdm,
     )
     if use_csv:
-        to_csv(result, csv_name.replace(".csv", ""), model, index=True)
+        save_text(str(result), txt_name, model)
     return result
@@ -459,7 +460,7 @@ def get_results_for_metric(
         placeholders = ['ALL']
     for position in placeholders:
         column = pd.MultiIndex.from_tuples([(prompt_type.value, position)], names=['prompt', 'position'])
-        result = get_results_cached(
+        result = get_result_cached(
             patching_metric_base=patching_metric_base,
             prompt_type=prompt_type,
             position=position,
diff --git a/fit_directions.py b/fit_directions.py
index 3228388..55e868f 100644
--- a/fit_directions.py
+++ b/fit_directions.py
@@ -43,8 +43,8 @@
     # 'gpt2-xl',
     # 'EleutherAI/pythia-160m',
     # 'EleutherAI/pythia-410m',
-    'EleutherAI/pythia-1.4b',
-    # 'EleutherAI/pythia-2.8b',
+    # 'EleutherAI/pythia-1.4b',
+    'EleutherAI/pythia-2.8b',
 ]
 METHODS = [
     # ClassificationMethod.KMEANS,
@@ -53,8 +53,8 @@
     # ClassificationMethod.MEAN_DIFF,
     # ClassificationMethod.LOGISTIC_REGRESSION,
     GradientMethod.DAS,
-    GradientMethod.DAS2D,
-    GradientMethod.DAS3D,
+    # GradientMethod.DAS2D,
+    # GradientMethod.DAS3D,
 ]
 TRAIN_TYPES = [
     PromptType.SIMPLE_TRAIN,
@@ -100,11 +100,11 @@ def select_layers(
     if n_layers <= 12:
         return list(range(n_layers + 1))
     if n_layers <= 24:
-        return list(range(0, n_layers + 1, 1))
+        return list(range(0, n_layers + 1, 2))
     if n_layers <= 36:
         return list(range(0, n_layers + 1, 3))
     if n_layers <= 48:
-        return list(range(0, n_layers + 1, 4))
+        return list(range(0, n_layers + 1, 8))
 #%%
 # sweep_model = HookedClassifier.from_pretrained(
@@ -258,6 +258,10 @@
         **kwargs
     )
     print(f"Saving classification direction to {cls_path}")
classification direction to {cls_path}") +#%% +# ============================================================================ # +# # END OF ACTUAL DIRECTION FITTING + #%% # ============================================================================ # # Summary stats diff --git a/negation_experiment.py b/negation_experiment.py index b68c6a6..3ff0778 100644 --- a/negation_experiment.py +++ b/negation_experiment.py @@ -52,7 +52,7 @@ "I don't find it amusing at all.": 'amusing', "I don't like you.": 'like', "I don't respect that.": 'respect', - "I don't trust you.": 'trust', + # "I don't trust you.": 'trust', "It's hardly a success from my perspective.": 'success', "It's hardly a triumph in my eyes.": 'triumph', "It's hardly a victory in my eyes.": 'victory', @@ -72,6 +72,7 @@ "I don't want to be with you": 'want', "I don't want this.": 'want', } +print(len(NEGATIONS)) #%% device = "cuda" MODEL_NAME = "gpt2-small" @@ -86,12 +87,18 @@ def run_experiment( initial_layer: int, final_layer: int, ) -> pd.DataFrame: - direction_scores = {} - for direction_label in tqdm(DIRECTIONS): + direction_scores = [] + bar = tqdm(DIRECTIONS) + for direction_label in bar: + bar.set_description(direction_label) direction = load_array(direction_label, model) direction = torch.tensor(direction, device=device, dtype=torch.float32) activations = get_activations_cached(dataloader, direction_label, model) + mean = activations[:, :, :11].flatten().mean().item() std_dev = activations[:, :, :11].flatten().std().item() + print(direction_label, mean, std_dev) + raw_scores = [] + flip_scores = [] z_scores = [] for text, word in NEGATIONS.items(): str_tokens = [tok.strip() for tok in model.to_str_tokens(text)] @@ -106,11 +113,29 @@ def run_experiment( text_activations[0, word_idx, initial_layer] ).item() z_score = abs(act_change) / std_dev + flip_score = 0.5 * act_change / ( + mean - text_activations[0, word_idx, initial_layer] + ) z_scores.append(z_score) - overall_score = np.mean(z_scores) - direction_scores[direction_label] = overall_score - df = pd.DataFrame.from_dict(direction_scores, orient="index", columns=["z_score"]) - df = df.sort_values("z_score", ascending=False) + flip_scores.append(flip_score) + raw_scores.append(act_change) + if flip_score < -30: + print( + text, flip_score, + text_activations[0, word_idx, initial_layer], + text_activations[0, word_idx, final_layer], + ) + direction_scores.append([ + np.mean(raw_scores), + np.mean(z_scores), + np.mean(flip_scores), + ]) + df = pd.DataFrame( + direction_scores, + columns=["raw_score", "z_score", "flip_score"], + index=DIRECTIONS, + ) + df = df.sort_values("raw_score", ascending=False) to_csv(df, "negation_experiment", model) return df #%% diff --git a/neuron_directions.py b/neuron_directions.py index ae578fd..d08b40a 100644 --- a/neuron_directions.py +++ b/neuron_directions.py @@ -74,6 +74,8 @@ hover_name="neuron", marginal="box", nbins=200, + histnorm="percent", + barmode="overlay", ) for index, row in sim_df.iterrows(): if row["neuron"] in ("L6N828", "L3N1605", "L5N671", "L6N1237"): diff --git a/resample_ablation.py b/resample_ablation.py index d4e63f8..563fa47 100644 --- a/resample_ablation.py +++ b/resample_ablation.py @@ -35,7 +35,7 @@ from path_patching import act_patch, Node, IterNode # %% # Model loading device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") -MODEL_NAME = 'gpt2-small' +MODEL_NAME = 'pythia-2.8b' model = HookedTransformer.from_pretrained( MODEL_NAME, center_unembed=True, @@ -57,7 +57,7 @@ # %% example_prompt = 
 adj_token = example_prompt.index(' perfect')
-verb_token = example_prompt.index(' loved')
+verb_token = example_prompt.index(' enjoyed')
 s2_token = example_prompt.index(' movie', example_prompt.index(' movie') + 1)
 end_token = len(example_prompt) - 1
 # %%
@@ -100,18 +100,18 @@
 # ]
 circuit_heads = [
     # gpt2-small
-    (0, 4),
-    (7, 1),
-    (9, 2),
-    (10, 1),
-    (10, 4),
-    (11, 9),
-    (8, 5),
-    (9, 2),
-    (9, 10),
-    (6, 4),
-    (7, 1),
-    (7, 5),
+    # (0, 4),
+    # (7, 1),
+    # (9, 2),
+    # (10, 1),
+    # (10, 4),
+    # (11, 9),
+    # (8, 5),
+    # (9, 2),
+    # (9, 10),
+    # (6, 4),
+    # (7, 1),
+    # (7, 5),
@@ -121,7 +121,12 @@
     # (13, 13),
     # (18, 2),
     # (21, 0),
+
+    # pythia-2.8b
+    (17, 19), (22, 5), (14, 4), (20, 10), (12, 2), (10, 26),
+    (12, 4), (12, 17), (14, 2), (13, 20), (9, 29), (11, 16)
 ]
+
 non_circuit_heads = [
     (layer, head)
     for layer in range(model.cfg.n_layers)
@@ -138,19 +143,19 @@
 # %%
 # gpt2-small
 circuit_heads_positions = [
-    (0, 4, adj_token),
-    (0, 4, verb_token),
-    (7, 1, end_token),
-    (9, 2, end_token),
-    (10, 1, end_token),
-    (10, 4, end_token),
-    (11, 9, end_token),
-    (8, 5, end_token),
-    (9, 2, end_token),
-    (9, 10, end_token),
-    (6, 4, s2_token),
-    (7, 1, s2_token),
-    (7, 5, s2_token),
+    # (0, 4, adj_token),
+    # (0, 4, verb_token),
+    # (7, 1, end_token),
+    # (9, 2, end_token),
+    # (10, 1, end_token),
+    # (10, 4, end_token),
+    # (11, 9, end_token),
+    # (8, 5, end_token),
+    # (9, 2, end_token),
+    # (9, 10, end_token),
+    # (6, 4, s2_token),
+    # (7, 1, s2_token),
+    # (7, 5, s2_token),
 ]
 # pythia-1.4b
 # circuit_heads_positions = [
diff --git a/treebank_data_gen.py b/treebank_data_gen.py
index 6a5865e..838c9ab 100644
--- a/treebank_data_gen.py
+++ b/treebank_data_gen.py
@@ -32,6 +32,7 @@
 from transformer_lens import HookedTransformer
 import torch
 import plotly.express as px
+from tqdm.auto import tqdm
 from utils.treebank import get_merged_dataframe, convert_to_dataset_dict, create_datasets_for_model
 # %%
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -44,19 +45,19 @@
 convert_to_dataset_dict(sentence_phrase_df)
 # %%
 MODELS = [
-    # 'gpt2-small',
-    # 'gpt2-medium',
-    # 'gpt2-large',
-    # 'gpt2-xl',
-    # 'EleutherAI/pythia-160m',
-    # 'EleutherAI/pythia-410m',
+    'gpt2-small',
+    'gpt2-medium',
+    'gpt2-large',
+    'gpt2-xl',
+    'EleutherAI/pythia-160m',
+    'EleutherAI/pythia-410m',
     'EleutherAI/pythia-1.4b',
-    # 'EleutherAI/pythia-2.8b',
+    'EleutherAI/pythia-2.8b',
 ]
-for model in MODELS:
+for model in tqdm(MODELS):
     model = HookedTransformer.from_pretrained(model, device=device)
     create_datasets_for_model(
         model, sentence_phrase_df, padding_side="left", batch_size=16,
     )
-#%%
\ No newline at end of file
+#%%
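+# NOTE: padding_side="left" keeps each prompt's final token at the same batch
+# position, so last-token activations can be indexed uniformly downstream.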