diff --git a/fit_directions.py b/fit_directions.py index 46e811f..ef23b4e 100644 --- a/fit_directions.py +++ b/fit_directions.py @@ -245,13 +245,15 @@ def get_placeholder(prompt_type: PromptType) -> Optional[str]: epochs = 1 if "treebank" in train_type.value else 64, batch_size=BATCH_SIZES[model_name], d_das=method.get_dimension(), + train_label=train_type.value, ) print(f"Saving DAS direction to {das_path}") torch.cuda.empty_cache() print("Emptied CUDA cache") else: trainset = ResidualStreamDataset.get_dataset( - model, device, prompt_type=train_types, scaffold=SCAFFOLD + model, device, prompt_type=train_types, scaffold=SCAFFOLD, + label=train_type.value, ) testset = ResidualStreamDataset.get_dataset( model, device, prompt_type=test_type, scaffold=SCAFFOLD diff --git a/utils/classification.py b/utils/classification.py index 380c2c3..1e824c0 100644 --- a/utils/classification.py +++ b/utils/classification.py @@ -86,6 +86,7 @@ def _fit( method: ClassificationMethod = ClassificationMethod.KMEANS, ): assert train_data is not None + assert train_layer is not None if test_data is None: test_data = train_data if test_layer is None: @@ -273,8 +274,12 @@ def _fit_logistic_regression( def train_classifying_direction( - train_data: ResidualStreamDataset, train_pos: Union[str, None], train_layer: int, - test_data: Union[ResidualStreamDataset, None], test_pos: Union[str, None], test_layer: int, + train_data: ResidualStreamDataset, + train_pos: Union[str, None], + train_layer: int, + test_data: Union[ResidualStreamDataset, None], + test_pos: Union[str, None], + test_layer: int, method: ClassificationMethod, **kwargs, ): @@ -298,28 +303,32 @@ def train_classifying_direction( warnings.simplefilter("error", ConvergenceWarning) # Turn the warning into an error try: train_line, correct, total, accuracy = fitting_method( - train_data, train_pos, train_layer, - test_data, test_pos, test_layer, + train_data, + train_layer, + test_data, + test_layer, **kwargs, ) test_line, _, _, _ = fitting_method( - test_data, test_pos, test_layer, - test_data, test_pos, test_layer, + test_data, + test_layer, + test_data, + test_layer, **kwargs, ) except ConvergenceWarning: print( f"Convergence warning for {method.value}; " - f"train type:{train_data.prompt_type.value}, pos: {train_pos}, layer:{train_layer}, " - f"test type:{test_data.prompt_type.value}, pos: {test_pos}, layer:{test_layer}, " + f"train type:{train_data.prompt_type}, pos: {train_pos}, layer:{train_layer}, " + f"test type:{test_data.prompt_type}, pos: {test_pos}, layer:{test_layer}, " f"kwargs: {kwargs}\n" f"train str_labels:{train_data.str_labels}\n" f"test str_labels:{test_data.str_labels}\n" ) return # write line to file - train_pos_str = train_pos if train_pos is not None else "ALL" - array_path = f"{method.value}_{train_data.prompt_type.value}_{train_pos_str}_layer{train_layer}" + train_pos_str = f"_{train_pos}" if train_pos is not None else "" + array_path = f"{method.value}_{train_data.label}{train_pos_str}_layer{train_layer}" save_array(train_line, array_path, model) cosine_sim = safe_cosine_sim( diff --git a/utils/das.py b/utils/das.py index 9938076..06881a2 100644 --- a/utils/das.py +++ b/utils/das.py @@ -427,6 +427,7 @@ def get_das_dataset( device: Optional[torch.device] = None, requires_grad: bool = True, verbose: bool = False, + label: Optional[str] = None, ): """ Wrapper for utils.prompts.get_dataset that returns a dataset in a useful form for DAS @@ -437,7 +438,7 @@ def get_das_dataset( return DataLoader([]), None, None clean_corrupt_data = get_dataset( model, device, prompt_type=prompt_type, scaffold=scaffold, - position=position + position=position, label=label, ) if max_dataset_size is not None: clean_corrupt_data = clean_corrupt_data.get_subset( @@ -492,6 +493,7 @@ def train_das_subspace( data_requires_grad: bool = False, verbose: bool = False, d_das: int = 1, + train_label: Optional[str] = None, **config_arg, ) -> Tuple[Float[Tensor, "batch d_model"], str]: """ @@ -505,7 +507,7 @@ def train_das_subspace( train_type, position=train_pos, layer=train_layer, model=model, batch_size=batch_size, max_dataset_size=max_dataset_size, scaffold=scaffold, device=device, requires_grad=data_requires_grad, - verbose=verbose, + verbose=verbose, label=train_label ) if test_type != train_type or test_pos != train_pos or test_layer != train_layer: testloader, loss_fn_val, test_position = get_das_dataset( @@ -539,9 +541,9 @@ def train_das_subspace( verbose=verbose, **config, ) - train_pos = train_pos if train_pos is not None else 'ALL' d_das_str = f'{d_das}d' if d_das > 1 else '' - save_path = f'das{d_das_str}_{train_type.value}_{train_pos}_layer{train_layer}' + train_pos_str = f'_{train_pos}' if train_pos is not None else '' + save_path = f'das{d_das_str}_{train_label}{train_pos_str}_layer{train_layer}' save_array( directions.detach().cpu().squeeze(1).numpy(), save_path, diff --git a/utils/prompts.py b/utils/prompts.py index 3e4ac81..45645df 100644 --- a/utils/prompts.py +++ b/utils/prompts.py @@ -498,6 +498,7 @@ def __init__( position: Union[None, str, Int[Tensor, "batch"]] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None, prompt_type: Optional[PromptType] = None, + label: Optional[str] = None, ): assert len(clean_tokens) == len(corrupted_tokens) assert len(clean_tokens) == len(answer_tokens) @@ -541,11 +542,13 @@ def __init__( self.prompt_type = prompt_type self.is_positive = is_positive self.position = position + self.label = label def __add__(self, other: "CleanCorruptedDataset"): assert isinstance(other, CleanCorruptedDataset) assert self.tokenizer is not None assert self.tokenizer == other.tokenizer + assert self.label == other.label seq_len1 = self.clean_tokens.shape[1] seq_len2 = other.clean_tokens.shape[1] max_seq_len = max(seq_len1, seq_len2) @@ -568,6 +571,7 @@ def __add__(self, other: "CleanCorruptedDataset"): torch.cat([self.position + offset1, other.position + offset2]), self.tokenizer, None, + self.label, ) def get_subset(self, indices: List[int]): @@ -580,6 +584,7 @@ def get_subset(self, indices: List[int]): self.position[indices], self.tokenizer, self.prompt_type, + self.label, ) def get_num_pad_tokens(self) -> Int[Tensor, "batch"]: @@ -868,6 +873,7 @@ def get_dataset( comparison: Tuple[str, str] = ("positive", "negative"), scaffold: Optional[ReviewScaffold] = None, position: Optional[Union[str, List[str], List[None]]] = None, + label: Optional[str] = None, ) -> CleanCorruptedDataset: if isinstance(prompt_type, PromptType): assert position is None or isinstance(position, str) @@ -879,6 +885,7 @@ def get_dataset( comparison=comparison, scaffold=scaffold, position=position, + label=label, ) assert len(prompt_type) > 0 if position is None: @@ -892,6 +899,7 @@ def get_dataset( comparison=comparison, scaffold=scaffold, position=position[0], + label=label, ) for pt, pos in zip(prompt_type[1:], position[1:]): assert isinstance(pt, PromptType) @@ -903,6 +911,7 @@ def get_dataset( comparison=comparison, scaffold=scaffold, position=pos, + label=label, ) return out @@ -915,8 +924,10 @@ def _get_dataset( comparison: Tuple[str, str] = ("positive", "negative"), scaffold: Optional[ReviewScaffold] = None, position: Optional[str] = None, + label: Optional[str] = None, ) -> CleanCorruptedDataset: - prompt_type = PromptType(prompt_type) + if label is None and isinstance(prompt_type, PromptType): + label = prompt_type.value if prompt_type in ( PromptType.TREEBANK_TRAIN, PromptType.TREEBANK_TEST, PromptType.TREEBANK_DEV ): @@ -976,6 +987,7 @@ def _get_dataset( tokenizer=model.tokenizer, position=position, prompt_type=prompt_type, + label=label, ) diff --git a/utils/residual_stream.py b/utils/residual_stream.py index 7e36ddc..1efc76c 100644 --- a/utils/residual_stream.py +++ b/utils/residual_stream.py @@ -37,6 +37,7 @@ def __init__( position: Int[Tensor, "batch"], model: HookedTransformer, prompt_type: Union[PromptType, List[PromptType]], + label: str, ) -> None: assert len(prompt_strings) == len(prompt_tokens) assert len(prompt_strings) == len(is_positive) @@ -46,6 +47,7 @@ def __init__( self.position = position self.model = model self.prompt_type = prompt_type + self.label = label label_tensor = self.prompt_tokens[ torch.arange(len(self.prompt_tokens)), self.position ].cpu().detach() @@ -169,6 +171,7 @@ def get_dataset( prompt_type: Union[PromptType, List[PromptType]] = PromptType.SIMPLE_TRAIN, scaffold: Optional[ReviewScaffold] = None, position: Optional[Union[str, List[str]]] = None, + label: Optional[str] = None, ) -> Union[None, 'ResidualStreamDataset']: """ @@ -178,7 +181,7 @@ def get_dataset( return None clean_corrupt_data = get_dataset( model, device, prompt_type=prompt_type, scaffold=scaffold, - position=position, + position=position, label=label, ) assert len(clean_corrupt_data.all_prompts) == len(clean_corrupt_data.answer_tokens) @@ -190,6 +193,7 @@ def get_dataset( clean_corrupt_data.position, model, prompt_type, + label, )