From 622b9e9f94b57e11a1d9b5ebfc3aaded3dfd7362 Mon Sep 17 00:00:00 2001 From: skar0 Date: Thu, 12 Oct 2023 17:49:52 +0100 Subject: [PATCH] Set label in tests and added batch number to str_labels --- tests/test_prompts.py | 5 +++-- utils/residual_stream.py | 9 +++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/test_prompts.py b/tests/test_prompts.py index 211a0f5..d6cdb16 100644 --- a/tests/test_prompts.py +++ b/tests/test_prompts.py @@ -16,8 +16,8 @@ def setUp(self): self.answer_tokens = torch.tensor([[2, 3], [4, 5]], dtype=torch.int32) self.all_prompts = ["prompt1", "prompt2"] self.dataset = CleanCorruptedDataset( - self.clean_tokens, self.corrupted_tokens, self.answer_tokens, self.all_prompts, - tokenizer=None, + self.clean_tokens, self.corrupted_tokens, self.answer_tokens, + self.all_prompts, tokenizer=None, label="test", ) def test_initialization(self): @@ -59,6 +59,7 @@ def setUp(self): torch.tensor([[2, 3], [4, 5]], dtype=torch.int), ["prompt1", "prompt2"], tokenizer=None, + label="test", ) corrupted_cache = ActivationCache({}, model=None) # Assuming a mock model clean_cache = ActivationCache({}, model=None) # Assuming a mock model diff --git a/utils/residual_stream.py b/utils/residual_stream.py index 8934a36..37db73e 100644 --- a/utils/residual_stream.py +++ b/utils/residual_stream.py @@ -51,9 +51,10 @@ def __init__( label_tensor = self.prompt_tokens[ torch.arange(len(self.prompt_tokens)), self.position ].cpu().detach() - str_tokens = model.to_str_tokens(label_tensor) - assert isinstance(str_tokens, list), "to_string must return a list" - assert isinstance(str_tokens[0], str), "to_string must return a list of strings" + str_tokens = [ + f"{i}:{tok}" + for i, tok in enumerate(model.to_str_tokens(label_tensor)) + ] to_str_check = ( len(str_tokens) == len(self.prompt_tokens) and len(set(str_tokens)) == len(str_tokens) @@ -69,7 +70,7 @@ def __init__( f"Position: {position}\n" f"Prompt type: {prompt_type}\n" ) - self.str_labels = str_tokens # type: ignore + self.str_labels = str_tokens @property def is_positive(self) -> Bool[Tensor, "batch"]: