Skip to content

Commit

Permalink
Taking only first answer pair when summing for simplicity
Browse files Browse the repository at this point in the history
  • Loading branch information
ojh31 committed Oct 12, 2023
1 parent 01445ec commit 253296e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
16 changes: 9 additions & 7 deletions fit_directions.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,10 @@ def select_layers(
# epochs=1,
# )
#%%
def get_placeholder(prompt_type: PromptType) -> Optional[str]:
    """Return the first placeholder defined for *prompt_type*.

    Falls back to None when the prompt type declares no placeholders.
    """
    options = prompt_type.get_placeholders()
    return options[0] if options else None
#%%
# ============================================================================ #
Expand Down Expand Up @@ -216,14 +216,16 @@ def get_placeholder(prompt_type: PromptType):
# Don't train on verbs as sample size is too small
print("Skipping because train_pos is VRB")
continue
save_path = f"{method.value}_{train_type.value}_{train_pos}_layer{train_layer}.npy"
train_label = f"{train_type.value}"
if ALL_BUT_ONE:
train_label += "_all_but_one"
if train_pos is not None:
train_label += f"_{train_pos}"
train_label += f"_layer{train_layer}"
save_path = f"{method.value}_{train_label}.npy"
if SKIP_IF_EXISTS and is_file(save_path, model):
print(f"Skipping because file already exists: {save_path}")
continue
if train_pos == 'ALL':
train_pos = None
if test_pos == "ALL":
test_pos = None
train_types = [t for t in TRAIN_TYPES if t != train_type] if ALL_BUT_ONE else train_type
if isinstance(method, GradientMethod):
train_test_discrepancy = test_type != PromptType.NONE and (
Expand Down
4 changes: 3 additions & 1 deletion utils/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,10 +552,12 @@ def __add__(self, other: "CleanCorruptedDataset"):
else:
offset1 = 0
offset2 = 0
self_answers = self.answer_tokens[:, :1, :]
other_answers = other.answer_tokens[:, :1, :]
return CleanCorruptedDataset(
concatenate_tensors(self.clean_tokens, other.clean_tokens, self.tokenizer),
concatenate_tensors(self.corrupted_tokens, other.corrupted_tokens, self.tokenizer),
torch.cat([self.answer_tokens, other.answer_tokens]),
torch.cat([self_answers, other_answers]),
self.all_prompts + other.all_prompts,
torch.cat([self.is_positive, other.is_positive]),
torch.cat([self.position + offset1, other.position + offset2]),
Expand Down

0 comments on commit 253296e

Please sign in to comment.