diff --git a/utils/neuroscope.py b/utils/neuroscope.py
index 083cde5..a004da8 100644
--- a/utils/neuroscope.py
+++ b/utils/neuroscope.py
@@ -103,13 +103,13 @@ def get_activations_cached(
             print("Loading activations from file")
         sentiment_activations_np = load_array(path, model)
         sentiment_activations: Float[Tensor, "row pos layer"] = torch.tensor(
-            sentiment_activations_np, dtype=torch.float32
+            sentiment_activations_np, dtype=model.cfg.dtype
         )
     else:
         if verbose:
             print("Computing activations")
         direction = load_array(direction_label + ".npy", model)
-        direction = torch.tensor(direction, dtype=torch.float32)
+        direction = torch.tensor(direction, dtype=model.cfg.dtype)
         direction /= direction.norm()
         sentiment_activations: Float[
             Tensor, "row pos layer"
@@ -386,9 +386,9 @@ def _plot_topk(
     base_activations: Float[Tensor, "row pos"] = all_activations[:, :, base_layer]
     activations = activations - base_activations
     if largest:
-        ignore_value = torch.tensor(-np.inf, device=device, dtype=torch.float32)
+        ignore_value = torch.tensor(-np.inf, device=device, dtype=model.cfg.dtype)
     else:
-        ignore_value = torch.tensor(np.inf, device=device, dtype=torch.float32)
+        ignore_value = torch.tensor(np.inf, device=device, dtype=model.cfg.dtype)
     # create a mask for the inclusions/exclusions
     if exclusions is not None:
         mask: Bool[Tensor, "row pos"] = get_batch_pos_mask(
@@ -424,7 +424,7 @@ def _plot_topk(
     ]
     # Print the most positive and negative examples and their activations
    print(f"Top {k} most {label} examples:")
-    zeros = torch.zeros((1, layers), device=device, dtype=torch.float32)
+    zeros = torch.zeros((1, layers), device=device, dtype=model.cfg.dtype)
     texts = [model.tokenizer.bos_token]
     text_to_not_repeat = set()
     acts = [zeros]
@@ -545,9 +545,9 @@ def _plot_top_p(
     label = "positive" if largest else "negative"
     activations: Float[Tensor, "batch pos"] = all_activations[:, :, layer]
     if largest:
-        ignore_value = torch.tensor(-np.inf, device=device, dtype=torch.float32)
+        ignore_value = torch.tensor(-np.inf, device=device, dtype=model.cfg.dtype)
     else:
-        ignore_value = torch.tensor(np.inf, device=device, dtype=torch.float32)
+        ignore_value = torch.tensor(np.inf, device=device, dtype=model.cfg.dtype)
     # create a mask for the inclusions/exclusions
     if exclusions is not None:
         mask: Bool[Tensor, "row pos"] = get_batch_pos_mask(
@@ -576,7 +576,7 @@ def _plot_top_p(
     # Print the most positive and negative examples and their activations
     print(f"Top {k} most {label} examples:")
     zeros = torch.zeros(
-        (1, all_activations.shape[-1]), device=device, dtype=torch.float32
+        (1, all_activations.shape[-1]), device=device, dtype=model.cfg.dtype
     )
     texts = [model.tokenizer.bos_token]
     text_to_not_repeat = set()
diff --git a/utils/prompts.py b/utils/prompts.py
index b71001f..9957721 100644
--- a/utils/prompts.py
+++ b/utils/prompts.py
@@ -783,7 +783,6 @@ def forward(
         requires_grad: bool = True,
         device: Optional[torch.device] = None,
         disable_tqdm: Optional[bool] = None,
-        dtype: torch.dtype = torch.float32,
         center: bool = True,
     ):
         return self.run_with_cache(
@@ -793,7 +792,6 @@
             requires_grad=requires_grad,
             device=device,
             disable_tqdm=disable_tqdm,
-            dtype=dtype,
             center=center,
         )
 
@@ -806,7 +804,6 @@ def run_with_cache(
         device: Optional[torch.device] = None,
         disable_tqdm: Optional[bool] = None,
         leave_tqdm: bool = False,
-        dtype: torch.dtype = torch.float32,
         center: bool = True,
     ):
         """
@@ -830,6 +827,7 @@
 
         # Initialise arrays
         total_samples = len(dataloader.dataset)
+        dtype = model.cfg.dtype
         clean_logit_diffs = torch.zeros(total_samples, dtype=dtype, device="cpu")
         corrupted_logit_diffs = torch.zeros(total_samples, dtype=dtype, device="cpu")
         clean_prob_diffs = torch.zeros(total_samples, dtype=dtype, device="cpu")
diff --git a/utils/residual_stream.py b/utils/residual_stream.py
index 3b7c1d9..e49745f 100644
--- a/utils/residual_stream.py
+++ b/utils/residual_stream.py
@@ -12,20 +12,19 @@
 def get_resid_name(layer: int, model: HookedTransformer) -> Tuple[str, int]:
-    resid_type = 'resid_pre'
+    resid_type = "resid_pre"
     if layer == model.cfg.n_layers:
-        resid_type = 'resid_post'
+        resid_type = "resid_post"
         layer -= 1
     return get_act_name(resid_type, layer), layer
 
 
 class ResidualStreamDataset:
-
     @typechecked
     def __init__(
-        self, 
-        prompt_strings: List[str], 
-        prompt_tokens: Int[Tensor, "batch pos"], 
+        self,
+        prompt_strings: List[str],
+        prompt_tokens: Int[Tensor, "batch pos"],
         binary_labels: Bool[Tensor, "batch"],
         model: HookedTransformer,
         prompt_type: PromptType,
@@ -37,12 +36,16 @@ def __init__(
         self._binary_labels = binary_labels
         self.model = model
         self.prompt_type = prompt_type
-        example_str_tokens: List[str] = model.to_str_tokens(prompt_strings[0]) # type: ignore
+        example_str_tokens: List[str] = model.to_str_tokens(prompt_strings[0])  # type: ignore
         self.example = [f"{i}:{tok}" for i, tok in enumerate(example_str_tokens)]
-        self.placeholder_dict = prompt_type.get_placeholder_positions(example_str_tokens)
-        label_positions = [pos for _, positions in self.placeholder_dict.items() for pos in positions]
+        self.placeholder_dict = prompt_type.get_placeholder_positions(
+            example_str_tokens
+        )
+        label_positions = [
+            pos for _, positions in self.placeholder_dict.items() for pos in positions
+        ]
         self.str_labels = [
-            ''.join([model.to_str_tokens(prompt)[pos] for pos in label_positions]) # type: ignore
+            "".join([model.to_str_tokens(prompt)[pos] for pos in label_positions])  # type: ignore
             for prompt in prompt_strings
         ]
 
@@ -52,33 +55,34 @@ def binary_labels(self) -> Bool[Tensor, "batch"]:
 
     def __len__(self) -> int:
         return len(self.prompt_strings)
-
-    def __eq__(self, other: 'ResidualStreamDataset') -> bool:
+
+    def __eq__(self, other: "ResidualStreamDataset") -> bool:
         return set(self.prompt_strings) == set(other.prompt_strings)
-
+
     def get_dataloader(self, batch_size: int) -> torch.utils.data.DataLoader:
         assert batch_size is not None, "get_dataloader: must specify batch size"
         token_answer_dataset = TensorDataset(
-            self.prompt_tokens, 
+            self.prompt_tokens,
+        )
+        token_answer_dataloader = DataLoader(
+            token_answer_dataset, batch_size=batch_size
         )
-        token_answer_dataloader = DataLoader(token_answer_dataset, batch_size=batch_size)
         return token_answer_dataloader
-
+
     def run_with_cache(
-        self, 
-        names_filter: Callable, 
+        self,
+        names_filter: Callable,
         batch_size: int,
         requires_grad: bool = True,
         device: Optional[torch.device] = None,
         disable_tqdm: Optional[bool] = None,
         leave_tqdm: bool = False,
-        dtype: torch.dtype = torch.float32,
     ):
         """
         Note that variable names here assume denoising, i.e.
         corrupted -> clean
         """
         if device is None:
-            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         was_grad_enabled = torch.is_grad_enabled()
         torch.set_grad_enabled(False)
         model = self.model.eval().requires_grad_(False)
@@ -96,20 +100,22 @@ def run_with_cache(
                 f"Running with cache: model={model.cfg.model_name}, "
                 f"batch_size={batch_size}"
             )
-        for idx, (prompt_tokens, ) in enumerate(bar):
+        for idx, (prompt_tokens,) in enumerate(bar):
             prompt_tokens = prompt_tokens.to(device)
             with torch.inference_mode():
                 # forward pass
                 _, fwd_cache = model.run_with_cache(
                     prompt_tokens, names_filter=names_filter, return_type=None
                 )
-            fwd_cache.to('cpu')
+            fwd_cache.to("cpu")
             # Initialise the buffer tensors if necessary
             if not buffer_initialized:
                 for k, v in fwd_cache.items():
                     act_dict[k] = torch.zeros(
-                        (total_samples, *v.shape[1:]), dtype=dtype, device='cpu'
+                        (total_samples, *v.shape[1:]),
+                        dtype=self.model.cfg.dtype,
+                        device="cpu",
                     )
                 buffer_initialized = True
@@ -119,19 +125,25 @@
             for k, v in fwd_cache.items():
                 act_dict[k][start_idx:end_idx] = v
         act_cache = ActivationCache(
-            {k: v.detach().clone().requires_grad_(requires_grad) for k, v in act_dict.items()},
-            model=model
+            {
+                k: v.detach().clone().requires_grad_(requires_grad)
+                for k, v in act_dict.items()
+            },
+            model=model,
         )
-        act_cache.to('cpu')
+        act_cache.to("cpu")
         torch.set_grad_enabled(was_grad_enabled)
         model = model.train().requires_grad_(requires_grad)
         return None, act_cache
-
+
     @typechecked
     def embed(
-        self, position_type: Union[str, None], layer: int,
-        batch_size: int = 64, seed: int = 0,
+        self,
+        position_type: Union[str, None],
+        layer: int,
+        batch_size: int = 64,
+        seed: int = 0,
     ) -> Float[Tensor, "batch d_model"]:
         """
         Returns a dataset of embeddings at the specified position and layer.
@@ -146,28 +158,34 @@ def embed(
         )
         hook, _ = get_resid_name(layer, self.model)
         _, cache = self.run_with_cache(
-            names_filter = lambda name: hook == name, batch_size=batch_size
+            names_filter=lambda name: hook == name, batch_size=batch_size
         )
         out: Float[Tensor, "batch pos d_model"] = cache[hook]
         if position_type is None:
             # Step 1: Identify non-zero positions in the tensor
-            non_pad_mask: Bool[Tensor, "batch pos"] = self.prompt_tokens != self.model.tokenizer.pad_token_id
+            non_pad_mask: Bool[Tensor, "batch pos"] = (
+                self.prompt_tokens != self.model.tokenizer.pad_token_id
+            )
             # Step 2: Check if values at these positions are not constant across batches
             non_constant_mask: Bool[Tensor, "pos"] = (
                 self.prompt_tokens != self.prompt_tokens[0]
             ).any(dim=0)
-            valid_positions: Bool[Tensor, "batch pos"] = non_pad_mask & non_constant_mask
+            valid_positions: Bool[Tensor, "batch pos"] = (
+                non_pad_mask & non_constant_mask
+            )
             # Step 3: Randomly sample from these positions for each batch
-            embed_position: Int[Tensor, "batch"] = torch.multinomial(
-                valid_positions.float(), 1
-            ).squeeze().to(device=out.device)
+            embed_position: Int[Tensor, "batch"] = (
+                torch.multinomial(valid_positions.float(), 1)
+                .squeeze()
+                .to(device=out.device)
+            )
             return out[torch.arange(len(out)), embed_position, :].detach().cpu()
         else:
             embed_pos: int = self.placeholder_dict[position_type][-1]
             return out[:, embed_pos, :].detach().cpu()
-
+
     @classmethod
     def get_dataset(
         cls,
@@ -175,7 +193,7 @@
         device: torch.device,
         prompt_type: PromptType = PromptType.SIMPLE_TRAIN,
         scaffold: Optional[ReviewScaffold] = None,
-    ) -> Union[None, 'ResidualStreamDataset']:
+    ) -> Union[None, "ResidualStreamDataset"]:
         """
         N.B. labels assume that first batch corresponds to 1
         """
@@ -184,10 +202,17 @@ def get_dataset(
         clean_corrupt_data = get_dataset(
             model, device, prompt_type=prompt_type, scaffold=scaffold
         )
-        clean_labels = clean_corrupt_data.answer_tokens[:, 0, 0] == clean_corrupt_data.answer_tokens[0, 0, 0]
-
-        assert len(clean_corrupt_data.all_prompts) == len(clean_corrupt_data.answer_tokens)
-        assert len(clean_corrupt_data.all_prompts) == len(clean_corrupt_data.clean_tokens)
+        clean_labels = (
+            clean_corrupt_data.answer_tokens[:, 0, 0]
+            == clean_corrupt_data.answer_tokens[0, 0, 0]
+        )
+
+        assert len(clean_corrupt_data.all_prompts) == len(
+            clean_corrupt_data.answer_tokens
+        )
+        assert len(clean_corrupt_data.all_prompts) == len(
+            clean_corrupt_data.clean_tokens
+        )
         return cls(
             clean_corrupt_data.all_prompts,
             clean_corrupt_data.clean_tokens,
@@ -195,17 +220,16 @@
             model,
             prompt_type,
         )
 
-
     def _get_labels_by_class(self, label: int) -> List[str]:
         return [
-            string.strip() for string, one_hot in zip(self.str_labels, self.binary_labels) if one_hot == label
+            string.strip()
+            for string, one_hot in zip(self.str_labels, self.binary_labels)
+            if one_hot == label
         ]
-
+
     def get_positive_negative_labels(self):
         return (
             self._get_labels_by_class(1),
             self._get_labels_by_class(0),
         )
-
-
\ No newline at end of file
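
The common thread across all three files is that buffers and loaded arrays previously created as torch.float32 now inherit the model's configured dtype (model.cfg.dtype), and the redundant dtype parameters are dropped, so nothing silently upcasts when the model runs in half precision. Below is a minimal sketch of the intended behaviour, assuming TransformerLens's HookedTransformer; none of these lines are part of the diff, and the tensor names are purely illustrative.

import torch
from transformer_lens import HookedTransformer

# cfg.dtype records the dtype the model's weights were instantiated with.
model = HookedTransformer.from_pretrained("gpt2")

# Creating buffers the way the patched helpers now do keeps them aligned with the model,
# e.g. a direction vector and a sentinel value like those used in the plotting utilities.
direction = torch.zeros(model.cfg.d_model, dtype=model.cfg.dtype)
ignore_value = torch.tensor(-float("inf"), dtype=model.cfg.dtype)
assert direction.dtype == ignore_value.dtype == model.cfg.dtype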