Commit
Found bad cosine sims for stable-lm-base-alpha-3b
ojh31 committed Nov 21, 2023
1 parent 9c0f06f commit 30df221
Showing 5 changed files with 183 additions and 145 deletions.
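The title refers to bad cosine similarities between the fitted sentiment directions for stablelm-base-alpha-3b, and the diff below compensates by flipping the sign of the DAS and PCA activations before plotting. A minimal sketch of the kind of check that could surface an anti-aligned direction, assuming each direction is saved as a 1-D tensor of size d_model (the load_direction helper and file names are hypothetical, not the repository's own):

import torch
import torch.nn.functional as F

def load_direction(name: str) -> torch.Tensor:
    # Hypothetical loader: returns a 1-D direction vector of shape [d_model].
    return torch.load(f"{name}.pt")

reference = load_direction("mean_diff_simple_train_ADJ_layer1")
for name in ["das_simple_train_ADJ_layer1", "pca_simple_train_ADJ_layer1"]:
    sim = F.cosine_similarity(load_direction(name), reference, dim=0).item()
    if sim < 0:
        # Anti-aligned with the reference: downstream activations need a sign flip.
        print(f"{name}: cosine sim {sim:.2f} (flipping sign)")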
45 changes: 12 additions & 33 deletions classifier_accuracy.py
@@ -63,15 +63,15 @@
torch.set_grad_enabled(False)
# %%
DIRECTIONS = [
# "kmeans_simple_train_ADJ_layer1",
"kmeans_simple_train_ADJ_layer1",
"pca_simple_train_ADJ_layer1",
# "mean_diff_simple_train_ADJ_layer1",
# "logistic_regression_simple_train_ADJ_layer1",
# "das_simple_train_ADJ_layer1",
"mean_diff_simple_train_ADJ_layer1",
"logistic_regression_simple_train_ADJ_layer1",
"das_simple_train_ADJ_layer1",
]
# %%
device = "cuda"
MODEL_NAME = "pythia-2.8b"
MODEL_NAME = "stablelm-base-alpha-3b"
BATCH_SIZE = 16
model = HookedTransformer.from_pretrained(
MODEL_NAME,
@@ -119,31 +119,6 @@ def extract_text_window(
return str_tokens # type: ignore


-def extract_text_window(
-    batch: int,
-    pos: int,
-    dataloader: torch.utils.data.DataLoader,
-    model: HookedTransformer,
-    window_size: int = 10,
-) -> List[str]:
-    """Helper function to get the text window around a position in a batch (used in topk plotting)"""
-    assert model.tokenizer is not None
-    expected_size = 2 * window_size + 1
-    lb, ub = get_window(batch, pos, dataloader=dataloader, window_size=window_size)
-    tokens = dataloader.dataset[batch]["tokens"][lb:ub]
-    str_tokens = model.to_str_tokens(tokens, prepend_bos=False)
-    padding_to_add = expected_size - len(str_tokens)
-    if padding_to_add > 0 and model.tokenizer.padding_side == "right":
-        str_tokens += [model.tokenizer.bos_token] * padding_to_add
-    elif padding_to_add > 0 and model.tokenizer.padding_side == "left":
-        str_tokens = [model.tokenizer.bos_token] * padding_to_add + str_tokens
-    assert len(str_tokens) == expected_size, (
-        f"Expected text window of size {expected_size}, "
-        f"found {len(str_tokens)}: {str_tokens}"
-    )
-    return str_tokens  # type: ignore
-
-
# %%
def sample_by_bin(
data: Float[Tensor, "batch pos"],
@@ -221,7 +196,9 @@ def sample_by_bin(


# %% # Save samples
-for direction_label in tqdm(DIRECTIONS):
+bar = tqdm(DIRECTIONS)
+for direction_label in bar:
+    bar.set_description(direction_label)
samples_path = direction_label + "_bin_samples.csv"
if is_file(samples_path, model):
continue
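The loop change keeps a handle on the tqdm bar so the progress bar shows which direction is currently being processed. A self-contained illustration of the same pattern (the labels here are placeholders):

from tqdm import tqdm

bar = tqdm(["kmeans", "pca", "mean_diff", "logistic_regression", "das"])
for label in bar:
    bar.set_description(label)  # displays the current item next to the progress bar
    # ... per-direction work goes here ...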
@@ -245,8 +222,6 @@ def plot_bin_proportions(
assert nbins is not None
if df.activation.dtype == pd.StringDtype:
df.activation = df.activation.map(lambda x: eval(x).item()).astype(float)
# if "das" in label:
# df.activation *= -1
sentiments = sorted(df["sentiment"].unique())
df = df.sort_values(by="activation").reset_index(drop=True)
df["activation_cut"] = pd.cut(df.activation, bins=nbins)
@@ -309,6 +284,8 @@ def plot_bin_proportions(
"labelled_" + direction_label + "_bin_samples", model
)
out_name = direction_label + "_bin_proportions"
if "das" in direction_label or "pca" in direction_label:
labelled_bin_samples.activation *= -1
fig = plot_bin_proportions(
labelled_bin_samples, f"{direction_label.split('_')[0]}, {model.cfg.model_name}"
)
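plot_bin_proportions itself is mostly collapsed in this view; from the visible fragments it sorts the samples by activation, cuts them into nbins equal-width bins, and plots the proportion of each sentiment label within each bin. A rough sketch of that per-bin tabulation, assuming a DataFrame with activation and sentiment columns (the actual plotting code is not shown here, and the default bin count is a placeholder):

import pandas as pd

def bin_proportions(df: pd.DataFrame, nbins: int = 20) -> pd.DataFrame:
    # Assumed reimplementation of the tabulation behind the plot:
    # bin the activations, then compute each sentiment's share within every bin.
    df = df.sort_values("activation").reset_index(drop=True)
    df["activation_cut"] = pd.cut(df.activation, bins=nbins)
    counts = df.groupby(["activation_cut", "sentiment"], observed=True).size()
    return counts.groupby(level=0, observed=True).transform(lambda s: s / s.sum()).unstack()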
@@ -320,6 +297,8 @@ def plot_bin_proportions(
activations = get_activations_cached(dataloader, direction_label, model)
flat = activations[:, :, 1].flatten().cpu()
flat = flat[flat != 0]
if "das" in direction_label or "pca" in direction_label:
flat *= -1
positive_threshold = flat.quantile(0.999).item()
negative_threshold = flat.quantile(0.001).item()
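The 0.999 and 0.001 quantiles of the non-zero activations give thresholds for picking out the most extreme tokens; the code that consumes them is collapsed below this hunk. A hypothetical sketch of how such thresholds could be applied (the helper and its use are assumptions, not the repository's code):

import torch

def extreme_positions(acts: torch.Tensor, lo: float, hi: float) -> torch.Tensor:
    # Return [n, 2] (batch, pos) indices whose activation lies outside the
    # [lo, hi] band, ignoring zero padding positions.
    mask = (acts != 0) & ((acts <= lo) | (acts >= hi))
    return mask.nonzero(as_tuple=False)

# e.g. extreme_positions(activations[:, :, 1].cpu(), negative_threshold, positive_threshold)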
