diff --git a/tests/test_circuit_analysis.py b/tests/test_circuit_analysis.py index 2db4e17..0fa2360 100644 --- a/tests/test_circuit_analysis.py +++ b/tests/test_circuit_analysis.py @@ -29,7 +29,7 @@ from utils.circuit_analysis import ( get_logit_diff, residual_stack_to_logit_diff, cache_to_logit_diff, project_to_subspace, create_cache_for_dir_patching, get_prob_diff, - get_final_non_pad_token, + get_final_non_pad_logits, ) import unittest @@ -218,7 +218,7 @@ def test_simple_case(self): [0.3, 0.2, 0.1] ]) self.assertTrue(torch.allclose( - get_final_non_pad_token(self.logits, attention_mask), expected + get_final_non_pad_logits(self.logits, attention_mask), expected )) def test_all_zeros(self): @@ -229,7 +229,7 @@ def test_all_zeros(self): ]) # check raises assertion error with self.assertRaises(AssertionError): - get_final_non_pad_token(self.logits, attention_mask) + get_final_non_pad_logits(self.logits, attention_mask) def test_all_ones(self): # Test case 3: All ones in attention mask @@ -242,7 +242,7 @@ def test_all_ones(self): [0.3, 0.2, 0.1] ]) self.assertTrue(torch.allclose( - get_final_non_pad_token(self.logits, attention_mask), expected + get_final_non_pad_logits(self.logits, attention_mask), expected )) def test_single_batch(self): @@ -257,7 +257,7 @@ def test_single_batch(self): [0.7, 0.8] ]) self.assertTrue(torch.allclose( - get_final_non_pad_token(logits, attention_mask), expected + get_final_non_pad_logits(logits, attention_mask), expected )) diff --git a/tests/test_das.py b/tests/test_das.py index 4defd64..6530687 100644 --- a/tests/test_das.py +++ b/tests/test_das.py @@ -42,7 +42,9 @@ def test_act_patch_simple(self): orig_input = "test" new_value = torch.full((3, 5), 2) patching_metric = lambda x: torch.sum(x) - out = act_patch_simple(model, orig_input, new_value, layer, 1, patching_metric) + out = act_patch_simple( + model, orig_input, new_value, layer, patching_metric + ) self.assertEqual(out.item(), 30) def test_training_config(self): @@ -71,7 +73,6 @@ def test_train_das_direction(self): 'attn-only-1l', device=device, ).train() - model.name = 'test' direction, save_path = train_das_subspace( model, device, PromptType.SIMPLE, 'ADJ', 0, diff --git a/utils/das.py b/utils/das.py index 049943d..2551715 100644 --- a/utils/das.py +++ b/utils/das.py @@ -86,10 +86,19 @@ def hook_fn_base( "batch d_model -> batch pos d_model", pos=seq_len, ) - position_mask = einops.repeat( - torch.arange(seq_len, device=resid.device) == position, - "pos -> batch pos d_model", + seq_len_rep = einops.repeat( + torch.arange(seq_len, device=resid.device), + "pos -> batch pos", batch=batch_size, + ) + position_rep = einops.repeat( + position, + "batch -> batch pos", + pos=seq_len, + ) + position_mask = einops.repeat( + position_rep == seq_len_rep, + "batch pos -> batch pos d_model", d_model=d_model, ) out = torch.where( @@ -449,8 +458,8 @@ def get_das_dataset( orig_resid: Float[Tensor, "batch *pos d_model"] = results.corrupted_cache[act_name] new_resid: Float[Tensor, "batch *pos d_model"] = results.clean_cache[act_name] if orig_resid.ndim == 3: - orig_resid = orig_resid[:, clean_corrupt_data.position, :] - new_resid = new_resid[:, clean_corrupt_data.position, :] + orig_resid = orig_resid[torch.arange(len(orig_resid)), clean_corrupt_data.position, :] + new_resid = new_resid[torch.arange(len(new_resid)), clean_corrupt_data.position, :] # Create a TensorDataset from the tensors das_dataset = TensorDataset( clean_corrupt_data.corrupted_tokens.detach().cpu(), diff --git a/utils/prompts.py b/utils/prompts.py index 14222fc..8f035f5 100644 --- a/utils/prompts.py +++ b/utils/prompts.py @@ -508,7 +508,10 @@ def __init__( is_positive = ( answer_tokens[:, 0, 0] == answer_tokens[0, 0, 0] ) - if position is None and prompt_type is not None and len(prompt_type.get_placeholders()) == 0: + no_placeholders = prompt_type is None or ( + prompt_type is not None and len(prompt_type.get_placeholders()) == 0 + ) + if position is None and no_placeholders: mask = get_attention_mask( tokenizer, clean_tokens, prepend_bos=False ) @@ -520,14 +523,13 @@ def __init__( if isinstance(position, str): assert tokenizer is not None and prompt_type is not None example = tokenizer.tokenize(all_prompts[0]) + example = [t.replace("Ġ", " ") for t in example] placeholders = prompt_type.get_placeholder_positions(example) pos: int = placeholders[position][-1] - pos_idx = einops.repeat( - torch.arange(clean_tokens.shape[1]), - "pos -> batch pos", - batch=clean_tokens.shape[0], + position = torch.full_like( + is_positive, pos, dtype=torch.int32, device=is_positive.device ) - position = pos_idx == pos + assert isinstance(position, torch.Tensor) self.clean_tokens = clean_tokens self.corrupted_tokens = corrupted_tokens self.answer_tokens = answer_tokens @@ -844,10 +846,14 @@ def set_accuracy(self): def center_logit_diffs(self): answer_tokens = self.dataset.answer_tokens self.corrupted_logit_diffs, self.corrupted_logit_bias = center_logit_diffs( - self.corrupted_logit_diffs, answer_tokens + self.corrupted_logit_diffs, + answer_tokens, + is_positive=self.dataset.is_positive, ) self.clean_logit_diffs, self.clean_logit_bias = center_logit_diffs( - self.clean_logit_diffs, answer_tokens + self.clean_logit_diffs, + answer_tokens, + is_positive=self.dataset.is_positive, ) @@ -966,6 +972,7 @@ def _get_dataset( corrupted_tokens=corrupted_tokens, tokenizer=model.tokenizer, position=position, + prompt_type=prompt_type, )