
Commit

Updated direction pattern regex
ojh31 committed Oct 9, 2023
1 parent 9ff7674 commit b0fc56f
Showing 3 changed files with 17 additions and 15 deletions.
3 changes: 2 additions & 1 deletion utils/prompts.py
@@ -540,6 +540,7 @@ def run_with_cache(
requires_grad: bool = True,
device: Optional[torch.device] = None,
disable_tqdm: Optional[bool] = None,
+ leave_tqdm: bool = False,
dtype: torch.dtype = torch.float32,
center: bool = True,
):
@@ -573,7 +574,7 @@ def run_with_cache(
clean_dict = dict()
if disable_tqdm is None:
disable_tqdm = len(dataloader) == 1
- bar = tqdm(dataloader, disable=disable_tqdm)
+ bar = tqdm(dataloader, disable=disable_tqdm, leave=leave_tqdm)
bar.set_description(
f"Running with cache: model={model.cfg.model_name}, "
f"batch_size={batch_size}"
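
For context on the prompts.py change above: the new leave_tqdm parameter is forwarded to tqdm's leave argument, which controls whether a finished progress bar stays on screen. A minimal standalone sketch of that behaviour (not repository code; the loop contents are illustrative only):

from tqdm import tqdm
import time

# leave=False: the bar is erased once the loop completes, which keeps output
# tidy when run_with_cache is called repeatedly (e.g. once per layer or seed).
for _ in tqdm(range(3), desc="transient bar", leave=False):
    time.sleep(0.1)

# leave=True: the completed bar remains visible after the loop finishes.
for _ in tqdm(range(3), desc="persistent bar", leave=True):
    time.sleep(0.1)
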
26 changes: 14 additions & 12 deletions utils/residual_stream.py
@@ -1,4 +1,4 @@
- from typing import Callable, List, Tuple, Union
+ from typing import Callable, List, Optional, Sequence, Tuple, Union
import einops
from jaxtyping import Int, Float, Bool
from typeguard import typechecked
@@ -37,12 +37,12 @@ def __init__(
self._binary_labels = binary_labels
self.model = model
self.prompt_type = prompt_type
- example_str_tokens = model.to_str_tokens(prompt_strings[0])
+ example_str_tokens: List[str] = model.to_str_tokens(prompt_strings[0]) # type: ignore
self.example = [f"{i}:{tok}" for i, tok in enumerate(example_str_tokens)]
self.placeholder_dict = prompt_type.get_placeholder_positions(example_str_tokens)
label_positions = [pos for _, positions in self.placeholder_dict.items() for pos in positions]
self.str_labels = [
- ''.join([model.to_str_tokens(prompt)[pos] for pos in label_positions])
+ ''.join([model.to_str_tokens(prompt)[pos] for pos in label_positions]) # type: ignore
for prompt in prompt_strings
]

@@ -53,7 +53,7 @@ def binary_labels(self) -> Bool[Tensor, "batch"]:
def __len__(self) -> int:
return len(self.prompt_strings)

- def __eq__(self, other: object) -> bool:
+ def __eq__(self, other: 'ResidualStreamDataset') -> bool:
return set(self.prompt_strings) == set(other.prompt_strings)

def get_dataloader(self, batch_size: int) -> torch.utils.data.DataLoader:
@@ -69,8 +69,9 @@ def run_with_cache(
names_filter: Callable,
batch_size: int,
requires_grad: bool = True,
- device: torch.device = None,
- disable_tqdm: bool = None,
+ device: Optional[torch.device] = None,
+ disable_tqdm: Optional[bool] = None,
+ leave_tqdm: bool = False,
dtype: torch.dtype = torch.float32,
):
"""
@@ -90,7 +91,7 @@ def run_with_cache(
total_samples = len(dataloader.dataset)
if disable_tqdm is None:
disable_tqdm = len(dataloader) > 1
- bar = tqdm(dataloader, disable=disable_tqdm)
+ bar = tqdm(dataloader, disable=disable_tqdm, leave=leave_tqdm)
bar.set_description(
f"Running with cache: model={model.cfg.model_name}, "
f"batch_size={batch_size}"
@@ -125,7 +126,7 @@ def run_with_cache(
torch.set_grad_enabled(was_grad_enabled)
model = model.train().requires_grad_(requires_grad)

- return _, act_cache
+ return None, act_cache

@typechecked
def embed(
@@ -136,6 +137,7 @@ def embed(
Returns a dataset of embeddings at the specified position and layer.
Useful for training classifiers on the residual stream.
"""
+ assert self.model.tokenizer is not None, "embed: model must have tokenizer"
torch.manual_seed(seed)
assert 0 <= layer <= self.model.cfg.n_layers
assert position_type is None or position_type in self.placeholder_dict.keys(), (
@@ -163,17 +165,17 @@ def embed(
).squeeze().to(device=out.device)
return out[torch.arange(len(out)), embed_position, :].detach().cpu()
else:
- embed_position = self.placeholder_dict[position_type][-1]
- return out[:, embed_position, :].detach().cpu()
+ embed_pos: int = self.placeholder_dict[position_type][-1]
+ return out[:, embed_pos, :].detach().cpu()

@classmethod
def get_dataset(
cls,
model: HookedTransformer,
device: torch.device,
prompt_type: PromptType = PromptType.SIMPLE_TRAIN,
- scaffold: ReviewScaffold = None,
- ) -> 'ResidualStreamDataset':
+ scaffold: Optional[ReviewScaffold] = None,
+ ) -> Union[None, 'ResidualStreamDataset']:
"""
N.B. labels assume that first batch corresponds to 1
"""
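
Most of the residual_stream.py edits above are typing fixes: arguments that default to None gain Optional[...] annotations, and get_dataset is now annotated as possibly returning None. A generic, self-contained sketch of that pattern (illustrative names, not repository code):

from typing import Optional

def pick_device(device: Optional[str] = None) -> str:
    # Resolve the Optional argument to a concrete value before using it.
    return device if device is not None else "cpu"

def maybe_lookup(table: dict, key: str) -> Optional[int]:
    # May return None; callers are expected to check before using the value.
    return table.get(key)

value = maybe_lookup({"a": 1}, "b")
if value is not None:
    print(value + 1)
print(pick_device())
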
3 changes: 1 addition & 2 deletions utils/store.py
@@ -16,7 +16,7 @@

DIRECTION_PATTERN = (
r'^(kmeans|pca|das|das2d|das3d|logistic_regression|mean_diff|random_direction)_'
- r'(?:(simple_train|treebank_train)_(ADJ|ALL)_)?'
+ r'(?:(simple_adverb|simple_book|simple_product|simple_train|simple_res|treebank_train)_(ADJ|ADV|ALL|FEEL|NOUN|VRB)_)'
r'layer(\d*)'
r'\.npy$'
)
@@ -311,7 +311,6 @@ def save_text(


def load_text(
- text: str,
label: str,
model: Union[HookedTransformer, str]
):
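
To illustrate the DIRECTION_PATTERN update above: the dataset/position segment is no longer optional (the trailing ? was dropped) and both alternations gained new options. A small check against made-up filenames (the pattern is copied from the diff; the filenames are hypothetical):

import re

DIRECTION_PATTERN = (
    r'^(kmeans|pca|das|das2d|das3d|logistic_regression|mean_diff|random_direction)_'
    r'(?:(simple_adverb|simple_book|simple_product|simple_train|simple_res|treebank_train)_(ADJ|ADV|ALL|FEEL|NOUN|VRB)_)'
    r'layer(\d*)'
    r'\.npy$'
)

names = [
    "das_simple_train_ADJ_layer0.npy",       # matches under both old and new patterns
    "kmeans_simple_adverb_ADV_layer12.npy",  # matches only under the new pattern
    "pca_layer3.npy",                        # matched the old pattern; now rejected (segment required)
]
for name in names:
    m = re.match(DIRECTION_PATTERN, name)
    print(name, "->", m.groups() if m else "no match")
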
