Skip to content

Commit

Permalink
added dupes to error message
Browse files Browse the repository at this point in the history
  • Loading branch information
ojh31 committed Oct 12, 2023
1 parent 72cd307 commit 8d554df
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions utils/residual_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,18 @@
from transformer_lens import HookedTransformer, ActivationCache
from transformer_lens.utils import get_act_name
from tqdm.auto import tqdm
from collections import Counter
from utils.prompts import ReviewScaffold, get_dataset, PromptType


def find_duplicates(lst):
# Count occurrences of each item in the list
counter = Counter(lst)
# Find items where count > 1, indicating they are duplicates
duplicates = [item for item, count in counter.items() if count > 1]
return duplicates


def get_resid_name(layer: int, model: HookedTransformer) -> Tuple[str, int]:
resid_type = 'resid_pre'
if layer == model.cfg.n_layers:
Expand Down Expand Up @@ -55,6 +64,19 @@ def __init__(
f"{tok}"
for tok in model.to_str_tokens(label_tensor)
]
assert len(str_tokens) == len(self.prompt_tokens)
str_tokens_dups = find_duplicates(str_tokens)
assert len(str_tokens_dups) == 0, (
"to_string must return a list of unique strings of the "
"same length as the input tensor.\n"
f"to_string dupes: {str_tokens_dups}, "
f"to_string shape: {len(str_tokens)}, "
f"tensor shape: {self.prompt_tokens.shape}\n"
f"Full output: {str_tokens}\n"
f"Tensor: {label_tensor}\n"
f"Position: {position}\n"
f"Prompt type: {prompt_type}\n"
)
to_str_check = (
len(str_tokens) == len(self.prompt_tokens) and
len(set(str_tokens)) == len(str_tokens)
Expand Down

0 comments on commit 8d554df

Please sign in to comment.