
Commit

Fixed tests
ojh31 committed Oct 12, 2023
1 parent 253296e commit 1e874b5
Showing 4 changed files with 37 additions and 20 deletions.
10 changes: 5 additions & 5 deletions tests/test_circuit_analysis.py
@@ -29,7 +29,7 @@
from utils.circuit_analysis import (
get_logit_diff, residual_stack_to_logit_diff, cache_to_logit_diff,
project_to_subspace, create_cache_for_dir_patching, get_prob_diff,
- get_final_non_pad_token,
+ get_final_non_pad_logits,
)
import unittest

@@ -218,7 +218,7 @@ def test_simple_case(self):
[0.3, 0.2, 0.1]
])
self.assertTrue(torch.allclose(
- get_final_non_pad_token(self.logits, attention_mask), expected
+ get_final_non_pad_logits(self.logits, attention_mask), expected
))

def test_all_zeros(self):
@@ -229,7 +229,7 @@ def test_all_zeros(self):
])
# check raises assertion error
with self.assertRaises(AssertionError):
- get_final_non_pad_token(self.logits, attention_mask)
+ get_final_non_pad_logits(self.logits, attention_mask)

def test_all_ones(self):
# Test case 3: All ones in attention mask
@@ -242,7 +242,7 @@ def test_all_ones(self):
[0.3, 0.2, 0.1]
])
self.assertTrue(torch.allclose(
- get_final_non_pad_token(self.logits, attention_mask), expected
+ get_final_non_pad_logits(self.logits, attention_mask), expected
))

def test_single_batch(self):
@@ -257,7 +257,7 @@ def test_single_batch(self):
[0.7, 0.8]
])
self.assertTrue(torch.allclose(
- get_final_non_pad_token(logits, attention_mask), expected
+ get_final_non_pad_logits(logits, attention_mask), expected
))


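The tests above pin down the contract of the renamed helper: given logits of shape (batch, pos, d_vocab) and an attention mask, get_final_non_pad_logits should return each sequence's logits at its last non-padding position, and raise an AssertionError when a row of the mask is all zeros. A minimal sketch written against that contract (an illustration, not the repo's utils.circuit_analysis implementation):

import torch
from torch import Tensor

def final_non_pad_logits(logits: Tensor, attention_mask: Tensor) -> Tensor:
    # Hypothetical re-implementation for illustration only.
    # logits: (batch, pos, d_vocab); attention_mask: (batch, pos), 1 = real token.
    assert (attention_mask.sum(dim=-1) > 0).all(), "each sequence needs a non-pad token"
    # cumulative count of real tokens, zeroed at pad positions, peaks exactly
    # at the last real token, so argmax picks that index per row
    last_idx = (attention_mask.cumsum(dim=-1) * attention_mask).argmax(dim=-1)  # (batch,)
    return logits[torch.arange(logits.shape[0]), last_idx]  # (batch, d_vocab)

Because the index is computed per row, a helper like this handles left- as well as right-padded batches.
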
5 changes: 3 additions & 2 deletions tests/test_das.py
@@ -42,7 +42,9 @@ def test_act_patch_simple(self):
orig_input = "test"
new_value = torch.full((3, 5), 2)
patching_metric = lambda x: torch.sum(x)
- out = act_patch_simple(model, orig_input, new_value, layer, 1, patching_metric)
+ out = act_patch_simple(
+ model, orig_input, new_value, layer, patching_metric
+ )
self.assertEqual(out.item(), 30)

def test_training_config(self):
@@ -71,7 +73,6 @@ def test_train_das_direction(self):
'attn-only-1l',
device=device,
).train()
- model.name = 'test'
direction, save_path = train_das_subspace(
model, device,
PromptType.SIMPLE, 'ADJ', 0,
19 changes: 14 additions & 5 deletions utils/das.py
@@ -86,10 +86,19 @@ def hook_fn_base(
"batch d_model -> batch pos d_model",
pos=seq_len,
)
- position_mask = einops.repeat(
- torch.arange(seq_len, device=resid.device) == position,
- "pos -> batch pos d_model",
+ seq_len_rep = einops.repeat(
+ torch.arange(seq_len, device=resid.device),
+ "pos -> batch pos",
batch=batch_size,
+ )
+ position_rep = einops.repeat(
+ position,
+ "batch -> batch pos",
+ pos=seq_len,
+ )
+ position_mask = einops.repeat(
+ position_rep == seq_len_rep,
+ "batch pos -> batch pos d_model",
d_model=d_model,
)
out = torch.where(
@@ -449,8 +458,8 @@ def get_das_dataset(
orig_resid: Float[Tensor, "batch *pos d_model"] = results.corrupted_cache[act_name]
new_resid: Float[Tensor, "batch *pos d_model"] = results.clean_cache[act_name]
if orig_resid.ndim == 3:
- orig_resid = orig_resid[:, clean_corrupt_data.position, :]
- new_resid = new_resid[:, clean_corrupt_data.position, :]
+ orig_resid = orig_resid[torch.arange(len(orig_resid)), clean_corrupt_data.position, :]
+ new_resid = new_resid[torch.arange(len(new_resid)), clean_corrupt_data.position, :]
# Create a TensorDataset from the tensors
das_dataset = TensorDataset(
clean_corrupt_data.corrupted_tokens.detach().cpu(),
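
The two utils/das.py hunks above make the same move: `position` becomes a per-example tensor of token positions rather than a single int, so the hook builds its splice mask by comparing a (batch, pos) grid of indices against each example's position, and get_das_dataset selects the residual at each example's own position via advanced indexing. A self-contained sketch of both patterns (shapes and the torch.where argument order are assumptions, since the diff's torch.where call is truncated; this is an illustration, not the repo's code):

import torch
import einops

batch_size, seq_len, d_model = 4, 6, 8
resid = torch.randn(batch_size, seq_len, d_model)       # residual stream activations
new_value = torch.randn(batch_size, seq_len, d_model)   # patched values (assumed shape)
position = torch.tensor([1, 3, 2, 5])                   # one target position per example

# (batch, pos) grid of indices, compared against each example's position,
# then broadcast over d_model -- mirroring the new hook_fn_base
seq_len_rep = einops.repeat(
    torch.arange(seq_len), "pos -> batch pos", batch=batch_size
)
position_rep = einops.repeat(position, "batch -> batch pos", pos=seq_len)
position_mask = einops.repeat(
    position_rep == seq_len_rep, "batch pos -> batch pos d_model", d_model=d_model
)
out = torch.where(position_mask, new_value, resid)  # splice in new_value at the target position

# get_das_dataset's new indexing: each example's residual at its own position
per_example = resid[torch.arange(len(resid)), position, :]  # (batch, d_model)

Presumably this lets different examples in one batch be patched at different token positions, which a single scalar position could not express.
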
23 changes: 15 additions & 8 deletions utils/prompts.py
@@ -508,7 +508,10 @@ def __init__(
is_positive = (
answer_tokens[:, 0, 0] == answer_tokens[0, 0, 0]
)
- if position is None and prompt_type is not None and len(prompt_type.get_placeholders()) == 0:
+ no_placeholders = prompt_type is None or (
+ prompt_type is not None and len(prompt_type.get_placeholders()) == 0
+ )
+ if position is None and no_placeholders:
mask = get_attention_mask(
tokenizer, clean_tokens, prepend_bos=False
)
@@ -520,14 +523,13 @@
if isinstance(position, str):
assert tokenizer is not None and prompt_type is not None
example = tokenizer.tokenize(all_prompts[0])
+ example = [t.replace("Ġ", " ") for t in example]
placeholders = prompt_type.get_placeholder_positions(example)
pos: int = placeholders[position][-1]
- pos_idx = einops.repeat(
- torch.arange(clean_tokens.shape[1]),
- "pos -> batch pos",
- batch=clean_tokens.shape[0],
+ position = torch.full_like(
+ is_positive, pos, dtype=torch.int32, device=is_positive.device
)
- position = pos_idx == pos
+ assert isinstance(position, torch.Tensor)
self.clean_tokens = clean_tokens
self.corrupted_tokens = corrupted_tokens
self.answer_tokens = answer_tokens
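
In the hunk above, a string position (e.g. 'ADJ') is resolved to a single token index from the tokenized example prompt (after stripping the GPT-2-style "Ġ" prefix) and then expanded into a per-example integer tensor with torch.full_like, replacing the old boolean (batch, pos) mask. A small illustration with made-up values (not the repo's code):

import torch

pos = 3                                          # e.g. the placeholder's last token index
is_positive = torch.tensor([True, False, True])  # one flag per prompt, as in the diff

# one integer position per example, borrowing is_positive's shape and device
position = torch.full_like(is_positive, pos, dtype=torch.int32)
print(position)  # tensor([3, 3, 3], dtype=torch.int32)

Downstream code (see the das.py changes above) can then index per example with this position tensor instead of applying a boolean mask.
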
@@ -844,10 +846,14 @@ def set_accuracy(self):
def center_logit_diffs(self):
answer_tokens = self.dataset.answer_tokens
self.corrupted_logit_diffs, self.corrupted_logit_bias = center_logit_diffs(
- self.corrupted_logit_diffs, answer_tokens
+ self.corrupted_logit_diffs,
+ answer_tokens,
+ is_positive=self.dataset.is_positive,
)
self.clean_logit_diffs, self.clean_logit_bias = center_logit_diffs(
- self.clean_logit_diffs, answer_tokens
+ self.clean_logit_diffs,
+ answer_tokens,
+ is_positive=self.dataset.is_positive,
)
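
center_logit_diffs now also receives the per-prompt is_positive labels. A plausible reading is that the centring bias is computed per sentiment class rather than pooled over the whole batch; a group-wise centring sketch under that assumption (the repo's center_logit_diffs also takes answer_tokens and may compute its bias differently):

import torch
from torch import Tensor

def center_logit_diffs_by_class(logit_diffs: Tensor, is_positive: Tensor):
    # Hypothetical helper: subtract each class's mean so positive- and
    # negative-labelled prompts are centred at zero separately.
    pos_bias = logit_diffs[is_positive].mean()
    neg_bias = logit_diffs[~is_positive].mean()
    bias = torch.where(is_positive, pos_bias, neg_bias)
    return logit_diffs - bias, bias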


@@ -966,6 +972,7 @@ def _get_dataset(
corrupted_tokens=corrupted_tokens,
tokenizer=model.tokenizer,
position=position,
+ prompt_type=prompt_type,
)


