Set label in tests and added batch number to str_labels
ojh31 committed Oct 12, 2023
1 parent 9f87140 commit 622b9e9
Showing 2 changed files with 8 additions and 6 deletions.
tests/test_prompts.py (5 changes: 3 additions & 2 deletions)
@@ -16,8 +16,8 @@ def setUp(self):
         self.answer_tokens = torch.tensor([[2, 3], [4, 5]], dtype=torch.int32)
         self.all_prompts = ["prompt1", "prompt2"]
         self.dataset = CleanCorruptedDataset(
-            self.clean_tokens, self.corrupted_tokens, self.answer_tokens, self.all_prompts,
-            tokenizer=None,
+            self.clean_tokens, self.corrupted_tokens, self.answer_tokens,
+            self.all_prompts, tokenizer=None, label="test",
         )
 
     def test_initialization(self):
@@ -59,6 +59,7 @@ def setUp(self):
             torch.tensor([[2, 3], [4, 5]], dtype=torch.int),
             ["prompt1", "prompt2"],
             tokenizer=None,
+            label="test",
         )
         corrupted_cache = ActivationCache({}, model=None) # Assuming a mock model
         clean_cache = ActivationCache({}, model=None) # Assuming a mock model
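
Note on the test change above: both fixtures now pass an explicit label keyword to CleanCorruptedDataset. A minimal standalone sketch of the new call shape follows; the import path and the clean/corrupted token values are placeholders for illustration, not taken from this diff.

    import torch
    # CleanCorruptedDataset comes from this project's prompt utilities;
    # the exact module path is not shown in this commit.

    dataset = CleanCorruptedDataset(
        torch.tensor([[0, 1], [2, 3]], dtype=torch.int32),  # clean_tokens (placeholder values)
        torch.tensor([[1, 0], [3, 2]], dtype=torch.int32),  # corrupted_tokens (placeholder values)
        torch.tensor([[2, 3], [4, 5]], dtype=torch.int32),  # answer_tokens, as in the fixture above
        ["prompt1", "prompt2"],                              # all_prompts
        tokenizer=None,
        label="test",
    )
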
utils/residual_stream.py (9 changes: 5 additions & 4 deletions)
@@ -51,9 +51,10 @@ def __init__(
         label_tensor = self.prompt_tokens[
             torch.arange(len(self.prompt_tokens)), self.position
         ].cpu().detach()
-        str_tokens = model.to_str_tokens(label_tensor)
-        assert isinstance(str_tokens, list), "to_string must return a list"
-        assert isinstance(str_tokens[0], str), "to_string must return a list of strings"
+        str_tokens = [
+            f"{i}:{tok}"
+            for i, tok in enumerate(model.to_str_tokens(label_tensor))
+        ]
         to_str_check = (
             len(str_tokens) == len(self.prompt_tokens) and
             len(set(str_tokens)) == len(str_tokens)
@@ -69,7 +70,7 @@ def __init__(
             f"Position: {position}\n"
             f"Prompt type: {prompt_type}\n"
         )
-        self.str_labels = str_tokens # type: ignore
+        self.str_labels = str_tokens
 
     @property
     def is_positive(self) -> Bool[Tensor, "batch"]:
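
Why the batch index was added to str_labels: the uniqueness half of to_str_check (len(set(str_tokens)) == len(str_tokens)) fails whenever two prompts share the same token at the labelled position. Prefixing each label with its batch index keeps every label distinct and records which row it came from. A self-contained illustration, with a stand-in for model.to_str_tokens and made-up token strings:

    def fake_to_str_tokens(label_tensor):
        # Stand-in for model.to_str_tokens: two prompts end in the same token.
        return [" movie", " movie", " film"]

    str_tokens = [
        f"{i}:{tok}"
        for i, tok in enumerate(fake_to_str_tokens(None))
    ]
    print(str_tokens)  # ['0: movie', '1: movie', '2: film']
    assert len(set(str_tokens)) == len(str_tokens)  # the uniqueness check now passes
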
