From b0fc56f86a215bb1d85e19501d862e05a2bfaf05 Mon Sep 17 00:00:00 2001 From: skar0 Date: Mon, 9 Oct 2023 14:06:28 +0100 Subject: [PATCH] Updated direction pattern regex --- utils/prompts.py | 3 ++- utils/residual_stream.py | 26 ++++++++++++++------------ utils/store.py | 3 +-- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/utils/prompts.py b/utils/prompts.py index a743640..3338104 100644 --- a/utils/prompts.py +++ b/utils/prompts.py @@ -540,6 +540,7 @@ def run_with_cache( requires_grad: bool = True, device: Optional[torch.device] = None, disable_tqdm: Optional[bool] = None, + leave_tqdm: bool = False, dtype: torch.dtype = torch.float32, center: bool = True, ): @@ -573,7 +574,7 @@ def run_with_cache( clean_dict = dict() if disable_tqdm is None: disable_tqdm = len(dataloader) == 1 - bar = tqdm(dataloader, disable=disable_tqdm) + bar = tqdm(dataloader, disable=disable_tqdm, leave=leave_tqdm) bar.set_description( f"Running with cache: model={model.cfg.model_name}, " f"batch_size={batch_size}" diff --git a/utils/residual_stream.py b/utils/residual_stream.py index 46a90c8..3b7c1d9 100644 --- a/utils/residual_stream.py +++ b/utils/residual_stream.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Tuple, Union +from typing import Callable, List, Optional, Sequence, Tuple, Union import einops from jaxtyping import Int, Float, Bool from typeguard import typechecked @@ -37,12 +37,12 @@ def __init__( self._binary_labels = binary_labels self.model = model self.prompt_type = prompt_type - example_str_tokens = model.to_str_tokens(prompt_strings[0]) + example_str_tokens: List[str] = model.to_str_tokens(prompt_strings[0]) # type: ignore self.example = [f"{i}:{tok}" for i, tok in enumerate(example_str_tokens)] self.placeholder_dict = prompt_type.get_placeholder_positions(example_str_tokens) label_positions = [pos for _, positions in self.placeholder_dict.items() for pos in positions] self.str_labels = [ - ''.join([model.to_str_tokens(prompt)[pos] for pos in label_positions]) + ''.join([model.to_str_tokens(prompt)[pos] for pos in label_positions]) # type: ignore for prompt in prompt_strings ] @@ -53,7 +53,7 @@ def binary_labels(self) -> Bool[Tensor, "batch"]: def __len__(self) -> int: return len(self.prompt_strings) - def __eq__(self, other: object) -> bool: + def __eq__(self, other: 'ResidualStreamDataset') -> bool: return set(self.prompt_strings) == set(other.prompt_strings) def get_dataloader(self, batch_size: int) -> torch.utils.data.DataLoader: @@ -69,8 +69,9 @@ def run_with_cache( names_filter: Callable, batch_size: int, requires_grad: bool = True, - device: torch.device = None, - disable_tqdm: bool = None, + device: Optional[torch.device] = None, + disable_tqdm: Optional[bool] = None, + leave_tqdm: bool = False, dtype: torch.dtype = torch.float32, ): """ @@ -90,7 +91,7 @@ def run_with_cache( total_samples = len(dataloader.dataset) if disable_tqdm is None: disable_tqdm = len(dataloader) > 1 - bar = tqdm(dataloader, disable=disable_tqdm) + bar = tqdm(dataloader, disable=disable_tqdm, leave=leave_tqdm) bar.set_description( f"Running with cache: model={model.cfg.model_name}, " f"batch_size={batch_size}" @@ -125,7 +126,7 @@ def run_with_cache( torch.set_grad_enabled(was_grad_enabled) model = model.train().requires_grad_(requires_grad) - return _, act_cache + return None, act_cache @typechecked def embed( @@ -136,6 +137,7 @@ def embed( Returns a dataset of embeddings at the specified position and layer. Useful for training classifiers on the residual stream. """ + assert self.model.tokenizer is not None, "embed: model must have tokenizer" torch.manual_seed(seed) assert 0 <= layer <= self.model.cfg.n_layers assert position_type is None or position_type in self.placeholder_dict.keys(), ( @@ -163,8 +165,8 @@ def embed( ).squeeze().to(device=out.device) return out[torch.arange(len(out)), embed_position, :].detach().cpu() else: - embed_position = self.placeholder_dict[position_type][-1] - return out[:, embed_position, :].detach().cpu() + embed_pos: int = self.placeholder_dict[position_type][-1] + return out[:, embed_pos, :].detach().cpu() @classmethod def get_dataset( @@ -172,8 +174,8 @@ def get_dataset( model: HookedTransformer, device: torch.device, prompt_type: PromptType = PromptType.SIMPLE_TRAIN, - scaffold: ReviewScaffold = None, - ) -> 'ResidualStreamDataset': + scaffold: Optional[ReviewScaffold] = None, + ) -> Union[None, 'ResidualStreamDataset']: """ N.B. labels assume that first batch corresponds to 1 """ diff --git a/utils/store.py b/utils/store.py index 144a3f1..ce0a067 100644 --- a/utils/store.py +++ b/utils/store.py @@ -16,7 +16,7 @@ DIRECTION_PATTERN = ( r'^(kmeans|pca|das|das2d|das3d|logistic_regression|mean_diff|random_direction)_' - r'(?:(simple_train|treebank_train)_(ADJ|ALL)_)?' + r'(?:(simple_adverb|simple_book|simple_product|simple_train|simple_res|treebank_train)_(ADJ|ADV|ALL|FEEL|NOUN|VRB)_)' r'layer(\d*)' r'\.npy$' ) @@ -311,7 +311,6 @@ def save_text( def load_text( - text: str, label: str, model: Union[HookedTransformer, str] ):