diff --git a/fit_directions.py b/fit_directions.py
index cff130a..46e811f 100644
--- a/fit_directions.py
+++ b/fit_directions.py
@@ -168,10 +168,10 @@ def select_layers(
 #     epochs=1,
 # )
 #%%
-def get_placeholder(prompt_type: PromptType):
+def get_placeholder(prompt_type: PromptType) -> Optional[str]:
     placeholders = prompt_type.get_placeholders()
     if len(placeholders) == 0:
-        placeholders = ["ALL"]
+        return None
     return placeholders[0]
 #%%
 # ============================================================================ #
@@ -216,14 +216,16 @@ def get_placeholder(prompt_type: PromptType):
             # Don't train on verbs as sample size is too small
             print("Skipping because train_pos is VRB")
             continue
-        save_path = f"{method.value}_{train_type.value}_{train_pos}_layer{train_layer}.npy"
+        train_label = f"{train_type.value}"
+        if ALL_BUT_ONE:
+            train_label += "_all_but_one"
+        if train_pos is not None:
+            train_label += f"_{train_pos}"
+        train_label += f"_layer{train_layer}"
+        save_path = f"{method.value}_{train_label}.npy"
         if SKIP_IF_EXISTS and is_file(save_path, model):
             print(f"Skipping because file already exists: {save_path}")
             continue
-        if train_pos == 'ALL':
-            train_pos = None
-        if test_pos == "ALL":
-            test_pos = None
         train_types = [t for t in TRAIN_TYPES if t != train_type] if ALL_BUT_ONE else train_type
         if isinstance(method, GradientMethod):
             train_test_discrepancy = test_type != PromptType.NONE and (
diff --git a/utils/prompts.py b/utils/prompts.py
index a87b389..14222fc 100644
--- a/utils/prompts.py
+++ b/utils/prompts.py
@@ -552,10 +552,12 @@ def __add__(self, other: "CleanCorruptedDataset"):
         else:
             offset1 = 0
             offset2 = 0
+        self_answers = self.answer_tokens[:, :1, :]
+        other_answers = other.answer_tokens[:, :1, :]
         return CleanCorruptedDataset(
             concatenate_tensors(self.clean_tokens, other.clean_tokens, self.tokenizer),
             concatenate_tensors(self.corrupted_tokens, other.corrupted_tokens, self.tokenizer),
-            torch.cat([self.answer_tokens, other.answer_tokens]),
+            torch.cat([self_answers, other_answers]),
             self.all_prompts + other.all_prompts,
             torch.cat([self.is_positive, other.is_positive]),
             torch.cat([self.position + offset1, other.position + offset2]),
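
For reference, below is a minimal, self-contained sketch (not part of the diff) of the filename-labelling logic the fit_directions.py hunk introduces, showing how the position suffix is dropped when get_placeholder() returns None. The helper name build_save_path and the example values ("das", "simple_train", "ADJ", layer 0) are assumptions for illustration only, not names taken from the repository.

from typing import Optional

ALL_BUT_ONE = False  # mirrors the flag checked in the diff

def build_save_path(
    method_value: str,
    train_type_value: str,
    train_pos: Optional[str],
    train_layer: int,
) -> str:
    # Assemble the label piece by piece, as the new hunk does:
    # the position suffix is only appended when a placeholder exists.
    train_label = f"{train_type_value}"
    if ALL_BUT_ONE:
        train_label += "_all_but_one"
    if train_pos is not None:
        train_label += f"_{train_pos}"
    train_label += f"_layer{train_layer}"
    return f"{method_value}_{train_label}.npy"

# "das_simple_train_ADJ_layer0.npy" when get_placeholder() returns a position,
# "das_simple_train_layer0.npy" when it returns None.
print(build_save_path("das", "simple_train", "ADJ", 0))
print(build_save_path("das", "simple_train", None, 0))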