Filtered and labelled scatterplot
ojh31 committed Oct 16, 2023
1 parent 143adee commit 25dc2b4
Showing 1 changed file with 66 additions and 44 deletions.
110 changes: 66 additions & 44 deletions fit_one_sided_directions.py
@@ -7,7 +7,7 @@
 from torch.utils.data import DataLoader
 from datasets import load_dataset
 from jaxtyping import Float, Int, Bool
-from typing import Dict, Iterable, List, Tuple, Union
+from typing import Dict, Iterable, List, Optional, Tuple, Union
 from transformer_lens import HookedTransformer
 from transformer_lens.utils import get_dataset, tokenize_and_concatenate, get_act_name, test_prompt
 from transformer_lens.hook_points import HookPoint
@@ -132,44 +132,67 @@
 sentiment_direction = sentiment_direction.to(device)
 torch.cosine_similarity(is_valenced_direction, sentiment_direction, dim=0)
 #%%
-all_tokens = torch.tensor([], dtype=torch.int32, device=device,)
-val_scores = torch.tensor([], dtype=torch.float32, device=device,)
-sent_scores = torch.tensor([], dtype=torch.float32, device=device,)
-for batch in tqdm(data_loader):
-    batch_tokens = batch['tokens'].to(device)
-    _, cache = model.run_with_cache(
-        batch_tokens,
-        return_type=None,
-        names_filter = lambda name: name == ACT_NAME
-    )
-    val_score = einops.einsum(
-        cache[ACT_NAME],
-        is_valenced_direction,
-        "batch pos d_model, d_model -> batch pos",
-    )
-    sent_score = einops.einsum(
-        cache[ACT_NAME],
-        sentiment_direction,
-        "batch pos d_model, d_model -> batch pos",
-    )
-    val_score = einops.rearrange(
-        val_score, "batch pos -> (batch pos)"
-    )
-    sent_score = einops.rearrange(
-        sent_score, "batch pos -> (batch pos)"
-    )
-    flat_tokens = einops.rearrange(
-        batch_tokens, "batch pos -> (batch pos)"
-    )
-    all_tokens = torch.cat([all_tokens, flat_tokens])
-    val_scores = torch.cat([val_scores, val_score])
-    sent_scores = torch.cat([sent_scores, sent_score])
-    if len(all_tokens) > 10_000:
-        break
-val_scores = val_scores.cpu().numpy()
-sent_scores = sent_scores.cpu().numpy()
-all_tokens = all_tokens.cpu().numpy()
-print(val_scores.shape, sent_scores.shape, all_tokens.shape)
+def get_token_sentiment_valence(
+    max_tokens: int = 10_000,
+    max_sentiment: Optional[float] = None,
+    min_valence: Optional[float] = None,
+):
+    all_tokens = torch.tensor([], dtype=torch.int32, device=device,)
+    val_scores = torch.tensor([], dtype=torch.float32, device=device,)
+    sent_scores = torch.tensor([], dtype=torch.float32, device=device,)
+    for batch in tqdm(data_loader):
+        batch_tokens = batch['tokens'].to(device)
+        _, cache = model.run_with_cache(
+            batch_tokens,
+            return_type=None,
+            names_filter = lambda name: name == ACT_NAME
+        )
+        val_score = einops.einsum(
+            cache[ACT_NAME],
+            is_valenced_direction,
+            "batch pos d_model, d_model -> batch pos",
+        )
+        sent_score = einops.einsum(
+            cache[ACT_NAME],
+            sentiment_direction,
+            "batch pos d_model, d_model -> batch pos",
+        )
+        val_score = einops.rearrange(
+            val_score, "batch pos -> (batch pos)"
+        )
+        sent_score = einops.rearrange(
+            sent_score, "batch pos -> (batch pos)"
+        )
+        flat_tokens = einops.rearrange(
+            batch_tokens, "batch pos -> (batch pos)"
+        )
+        mask = torch.ones_like(flat_tokens, dtype=torch.bool)
+        if max_sentiment is not None:
+            mask &= sent_score.abs() < max_sentiment
+        if min_valence is not None:
+            mask &= val_score > min_valence
+        flat_tokens = flat_tokens[mask]
+        val_score = val_score[mask]
+        sent_score = sent_score[mask]
+        all_tokens = torch.cat([all_tokens, flat_tokens])
+        val_scores = torch.cat([val_scores, val_score])
+        sent_scores = torch.cat([sent_scores, sent_score])
+        if len(all_tokens) > max_tokens:
+            break
+    val_scores = val_scores.cpu().numpy()
+    sent_scores = sent_scores.cpu().numpy()
+    all_tokens = all_tokens.cpu().numpy()
+    print(val_scores.shape, sent_scores.shape, all_tokens.shape)
+    return val_scores, sent_scores, all_tokens
+#%%
+val_scores, sent_scores, all_tokens = get_token_sentiment_valence(
+    max_tokens=100,
+    max_sentiment=0.5,
+    min_valence=20,
+)
+if len(all_tokens) <= 1_000:
+    all_tokens = model.to_str_tokens(all_tokens)
+    save_text("\n".join(all_tokens), "valenced_tokens", model)
 #%%
 fig = px.scatter(
     x=val_scores,
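
As an aside on the scoring in the hunk above: each einsum projects the cached residual-stream activations onto a fixed direction, i.e. it takes one dot product per (batch, pos) token, and the new max_sentiment / min_valence arguments simply threshold those projections before the tokens are kept. A minimal sketch of the same pattern with toy tensors (the shapes, the normalisation step, and the cutoff value are illustrative assumptions, not taken from the repository):

import torch
import einops

# Toy stand-ins for a batch of cached activations and a direction vector.
resid = torch.randn(2, 5, 768)            # (batch, pos, d_model)
direction = torch.randn(768)
direction = direction / direction.norm()  # treat it as a unit direction

# Same pattern as the diff: dot product of every token's activation
# with the direction, giving one score per (batch, pos).
score = einops.einsum(
    resid, direction,
    "batch pos d_model, d_model -> batch pos",
)
flat_score = einops.rearrange(score, "batch pos -> (batch pos)")

# Threshold filtering in the style of get_token_sentiment_valence;
# the cutoff here is arbitrary and purely for the example.
mask = flat_score > 1.0
kept = flat_score[mask]
print(score.shape, flat_score.shape, kept.shape)

Tokens whose scores survive the mask are the ones the commit keeps for the labelled scatterplot.
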
@@ -181,14 +204,13 @@
     title=dict(
         text="Valenced vs Sentiment activations",
         x=0.5,
-    )
+    ),
+    font=dict(
+        size=8,
+    ),
 )
 fig.show()
 #%%
-model.to_str_tokens(np.array([
-    49435, 35333, 42262, 44406, 32554, 25872, 23590, 39609
-]))
-#%%
 save_array(
     positive_direction.cpu().numpy(), "mean_diff_positive_layer01", model
 )
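
The commit title mentions a labelled scatterplot, but most of the px.scatter arguments sit in the collapsed region of the diff. As a rough, hypothetical sketch of how the filtered scores and decoded tokens could be combined into such a plot with plotly express (the text= and labels= arguments and the toy data below are assumptions for illustration, not necessarily what the repository passes):

import numpy as np
import plotly.express as px

# Toy stand-ins for the filtered outputs of get_token_sentiment_valence
# and the decoded tokens; real values would come from the code above.
val_scores = np.array([21.0, 23.5, 25.2])
sent_scores = np.array([0.1, -0.3, 0.4])
str_tokens = [" great", " terrible", " fine"]

fig = px.scatter(
    x=val_scores,
    y=sent_scores,
    text=str_tokens,  # label each point with its decoded token
    labels=dict(x="valence score", y="sentiment score"),
)
fig.update_layout(
    title=dict(text="Valenced vs Sentiment activations", x=0.5),
    font=dict(size=8),
)
fig.show()
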
