Replaced references to float32 in utils to use model.cfg.dtype
ojh31 committed Nov 22, 2023
1 parent 4aef8cb commit 9ddca9b
Showing 3 changed files with 79 additions and 57 deletions.
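
The substantive change is a single pattern: tensors that were previously created with a hard-coded dtype=torch.float32 now take their dtype from the model configuration (the residual_stream.py diff also includes some unrelated reformatting). Below is a minimal before/after sketch of that pattern, assuming a TransformerLens HookedTransformer, which is the model class these utils are written against; the model name and array are placeholders, not taken from the repository.

import numpy as np
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2")  # hypothetical example model
values = np.zeros((4, 8))

# Before: precision pinned to float32 regardless of how the model was loaded.
pinned = torch.tensor(values, dtype=torch.float32)

# After: the tensor matches the model's configured dtype (e.g. float16 for a
# half-precision model), keeping it consistent with cached activations.
matched = torch.tensor(values, dtype=model.cfg.dtype)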
utils/neuroscope.py (16 changes: 8 additions & 8 deletions)
@@ -103,13 +103,13 @@ def get_activations_cached(
print("Loading activations from file")
sentiment_activations_np = load_array(path, model)
sentiment_activations: Float[Tensor, "row pos layer"] = torch.tensor(
sentiment_activations_np, dtype=torch.float32
sentiment_activations_np, dtype=model.cfg.dtype
)
else:
if verbose:
print("Computing activations")
direction = load_array(direction_label + ".npy", model)
direction = torch.tensor(direction, dtype=torch.float32)
direction = torch.tensor(direction, dtype=model.cfg.dtype)
direction /= direction.norm()
sentiment_activations: Float[
Tensor, "row pos layer"
@@ -386,9 +386,9 @@ def _plot_topk(
base_activations: Float[Tensor, "row pos"] = all_activations[:, :, base_layer]
activations = activations - base_activations
if largest:
ignore_value = torch.tensor(-np.inf, device=device, dtype=torch.float32)
ignore_value = torch.tensor(-np.inf, device=device, dtype=model.cfg.dtype)
else:
ignore_value = torch.tensor(np.inf, device=device, dtype=torch.float32)
ignore_value = torch.tensor(np.inf, device=device, dtype=model.cfg.dtype)
# create a mask for the inclusions/exclusions
if exclusions is not None:
mask: Bool[Tensor, "row pos"] = get_batch_pos_mask(
@@ -424,7 +424,7 @@ def _plot_topk(
]
# Print the most positive and negative examples and their activations
print(f"Top {k} most {label} examples:")
zeros = torch.zeros((1, layers), device=device, dtype=torch.float32)
zeros = torch.zeros((1, layers), device=device, dtype=model.cfg.dtype)
texts = [model.tokenizer.bos_token]
text_to_not_repeat = set()
acts = [zeros]
@@ -545,9 +545,9 @@ def _plot_top_p(
label = "positive" if largest else "negative"
activations: Float[Tensor, "batch pos"] = all_activations[:, :, layer]
if largest:
ignore_value = torch.tensor(-np.inf, device=device, dtype=torch.float32)
ignore_value = torch.tensor(-np.inf, device=device, dtype=model.cfg.dtype)
else:
ignore_value = torch.tensor(np.inf, device=device, dtype=torch.float32)
ignore_value = torch.tensor(np.inf, device=device, dtype=model.cfg.dtype)
# create a mask for the inclusions/exclusions
if exclusions is not None:
mask: Bool[Tensor, "row pos"] = get_batch_pos_mask(
@@ -576,7 +576,7 @@ def _plot_top_p(
# Print the most positive and negative examples and their activations
print(f"Top {k} most {label} examples:")
zeros = torch.zeros(
(1, all_activations.shape[-1]), device=device, dtype=torch.float32
(1, all_activations.shape[-1]), device=device, dtype=model.cfg.dtype
)
texts = [model.tokenizer.bos_token]
text_to_not_repeat = set()
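In utils/neuroscope.py the replacements touch the sentinel values used to mask out excluded positions before a top-k selection, as well as the zero-filled activation rows the plotting helpers create; both now share the model's dtype instead of float32. Below is a minimal sketch of that masking pattern with hypothetical tensors, not taken from the file.

import torch

dtype = torch.float32                      # stand-in for model.cfg.dtype
activations = torch.randn(4, 10, dtype=dtype)
exclusion_mask = torch.zeros(4, 10, dtype=torch.bool)
exclusion_mask[:, 0] = True                # e.g. never select the first position

ignore_value = torch.tensor(-float("inf"), dtype=dtype)  # largest=True case
masked = activations.masked_fill(exclusion_mask, ignore_value.item())
top_values, flat_indices = masked.flatten().topk(k=5, largest=True)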
utils/prompts.py (4 changes: 1 addition & 3 deletions)
@@ -783,7 +783,6 @@ def forward(
requires_grad: bool = True,
device: Optional[torch.device] = None,
disable_tqdm: Optional[bool] = None,
dtype: torch.dtype = torch.float32,
center: bool = True,
):
return self.run_with_cache(
@@ -793,7 +792,6 @@ def forward(
requires_grad=requires_grad,
device=device,
disable_tqdm=disable_tqdm,
dtype=dtype,
center=center,
)

@@ -806,7 +804,6 @@ def run_with_cache(
device: Optional[torch.device] = None,
disable_tqdm: Optional[bool] = None,
leave_tqdm: bool = False,
dtype: torch.dtype = torch.float32,
center: bool = True,
):
"""
@@ -830,6 +827,7 @@ def run_with_cache(

# Initialise arrays
total_samples = len(dataloader.dataset)
dtype = model.cfg.dtype
clean_logit_diffs = torch.zeros(total_samples, dtype=dtype, device="cpu")
corrupted_logit_diffs = torch.zeros(total_samples, dtype=dtype, device="cpu")
clean_prob_diffs = torch.zeros(total_samples, dtype=dtype, device="cpu")
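In utils/prompts.py the explicit dtype parameter is removed from forward and run_with_cache, and the result buffers are instead allocated in the model's configured dtype. A sketch of the new allocation pattern follows, using a hypothetical helper name and sizes.

import torch

def allocate_result_buffers(model, total_samples: int):
    # dtype now comes from the model config rather than a keyword argument
    dtype = model.cfg.dtype
    clean_logit_diffs = torch.zeros(total_samples, dtype=dtype, device="cpu")
    corrupted_logit_diffs = torch.zeros(total_samples, dtype=dtype, device="cpu")
    return clean_logit_diffs, corrupted_logit_diffs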
utils/residual_stream.py (116 changes: 70 additions & 46 deletions)
@@ -12,20 +12,19 @@


def get_resid_name(layer: int, model: HookedTransformer) -> Tuple[str, int]:
resid_type = 'resid_pre'
resid_type = "resid_pre"
if layer == model.cfg.n_layers:
resid_type = 'resid_post'
resid_type = "resid_post"
layer -= 1
return get_act_name(resid_type, layer), layer


class ResidualStreamDataset:

@typechecked
def __init__(
self,
prompt_strings: List[str],
prompt_tokens: Int[Tensor, "batch pos"],
self,
prompt_strings: List[str],
prompt_tokens: Int[Tensor, "batch pos"],
binary_labels: Bool[Tensor, "batch"],
model: HookedTransformer,
prompt_type: PromptType,
@@ -37,12 +36,16 @@ def __init__(
self._binary_labels = binary_labels
self.model = model
self.prompt_type = prompt_type
example_str_tokens: List[str] = model.to_str_tokens(prompt_strings[0]) # type: ignore
example_str_tokens: List[str] = model.to_str_tokens(prompt_strings[0]) # type: ignore
self.example = [f"{i}:{tok}" for i, tok in enumerate(example_str_tokens)]
self.placeholder_dict = prompt_type.get_placeholder_positions(example_str_tokens)
label_positions = [pos for _, positions in self.placeholder_dict.items() for pos in positions]
self.placeholder_dict = prompt_type.get_placeholder_positions(
example_str_tokens
)
label_positions = [
pos for _, positions in self.placeholder_dict.items() for pos in positions
]
self.str_labels = [
''.join([model.to_str_tokens(prompt)[pos] for pos in label_positions]) # type: ignore
"".join([model.to_str_tokens(prompt)[pos] for pos in label_positions]) # type: ignore
for prompt in prompt_strings
]

@@ -52,33 +55,34 @@ def binary_labels(self) -> Bool[Tensor, "batch"]:

def __len__(self) -> int:
return len(self.prompt_strings)
def __eq__(self, other: 'ResidualStreamDataset') -> bool:

def __eq__(self, other: "ResidualStreamDataset") -> bool:
return set(self.prompt_strings) == set(other.prompt_strings)

def get_dataloader(self, batch_size: int) -> torch.utils.data.DataLoader:
assert batch_size is not None, "get_dataloader: must specify batch size"
token_answer_dataset = TensorDataset(
self.prompt_tokens,
self.prompt_tokens,
)
token_answer_dataloader = DataLoader(
token_answer_dataset, batch_size=batch_size
)
token_answer_dataloader = DataLoader(token_answer_dataset, batch_size=batch_size)
return token_answer_dataloader

def run_with_cache(
self,
names_filter: Callable,
self,
names_filter: Callable,
batch_size: int,
requires_grad: bool = True,
device: Optional[torch.device] = None,
disable_tqdm: Optional[bool] = None,
leave_tqdm: bool = False,
dtype: torch.dtype = torch.float32,
):
"""
Note that variable names here assume denoising, i.e. corrupted -> clean
"""
if device is None:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
was_grad_enabled = torch.is_grad_enabled()
torch.set_grad_enabled(False)
model = self.model.eval().requires_grad_(False)
@@ -96,20 +100,22 @@ def run_with_cache(
f"Running with cache: model={model.cfg.model_name}, "
f"batch_size={batch_size}"
)
for idx, (prompt_tokens, ) in enumerate(bar):
for idx, (prompt_tokens,) in enumerate(bar):
prompt_tokens = prompt_tokens.to(device)
with torch.inference_mode():
# forward pass
_, fwd_cache = model.run_with_cache(
prompt_tokens, names_filter=names_filter, return_type=None
)
fwd_cache.to('cpu')
fwd_cache.to("cpu")

# Initialise the buffer tensors if necessary
if not buffer_initialized:
for k, v in fwd_cache.items():
act_dict[k] = torch.zeros(
(total_samples, *v.shape[1:]), dtype=dtype, device='cpu'
(total_samples, *v.shape[1:]),
dtype=self.model.dtype,
device="cpu",
)
buffer_initialized = True

@@ -119,19 +125,25 @@ def run_with_cache(
for k, v in fwd_cache.items():
act_dict[k][start_idx:end_idx] = v
act_cache = ActivationCache(
{k: v.detach().clone().requires_grad_(requires_grad) for k, v in act_dict.items()},
model=model
{
k: v.detach().clone().requires_grad_(requires_grad)
for k, v in act_dict.items()
},
model=model,
)
act_cache.to('cpu')
act_cache.to("cpu")
torch.set_grad_enabled(was_grad_enabled)
model = model.train().requires_grad_(requires_grad)

return None, act_cache

@typechecked
def embed(
self, position_type: Union[str, None], layer: int,
batch_size: int = 64, seed: int = 0,
self,
position_type: Union[str, None],
layer: int,
batch_size: int = 64,
seed: int = 0,
) -> Float[Tensor, "batch d_model"]:
"""
Returns a dataset of embeddings at the specified position and layer.
@@ -146,36 +158,42 @@ def embed(
)
hook, _ = get_resid_name(layer, self.model)
_, cache = self.run_with_cache(
names_filter = lambda name: hook == name, batch_size=batch_size
names_filter=lambda name: hook == name, batch_size=batch_size
)
out: Float[Tensor, "batch pos d_model"] = cache[hook]
if position_type is None:
# Step 1: Identify non-zero positions in the tensor
non_pad_mask: Bool[Tensor, "batch pos"] = self.prompt_tokens != self.model.tokenizer.pad_token_id
non_pad_mask: Bool[Tensor, "batch pos"] = (
self.prompt_tokens != self.model.tokenizer.pad_token_id
)

# Step 2: Check if values at these positions are not constant across batches
non_constant_mask: Bool[Tensor, "pos"] = (
self.prompt_tokens != self.prompt_tokens[0]
).any(dim=0)
valid_positions: Bool[Tensor, "batch pos"] = non_pad_mask & non_constant_mask
valid_positions: Bool[Tensor, "batch pos"] = (
non_pad_mask & non_constant_mask
)

# Step 3: Randomly sample from these positions for each batch
embed_position: Int[Tensor, "batch"] = torch.multinomial(
valid_positions.float(), 1
).squeeze().to(device=out.device)
embed_position: Int[Tensor, "batch"] = (
torch.multinomial(valid_positions.float(), 1)
.squeeze()
.to(device=out.device)
)
return out[torch.arange(len(out)), embed_position, :].detach().cpu()
else:
embed_pos: int = self.placeholder_dict[position_type][-1]
return out[:, embed_pos, :].detach().cpu()

@classmethod
def get_dataset(
cls,
model: HookedTransformer,
device: torch.device,
prompt_type: PromptType = PromptType.SIMPLE_TRAIN,
scaffold: Optional[ReviewScaffold] = None,
) -> Union[None, 'ResidualStreamDataset']:
) -> Union[None, "ResidualStreamDataset"]:
"""
N.B. labels assume that first batch corresponds to 1
"""
@@ -184,28 +202,34 @@ def get_dataset(
clean_corrupt_data = get_dataset(
model, device, prompt_type=prompt_type, scaffold=scaffold
)
clean_labels = clean_corrupt_data.answer_tokens[:, 0, 0] == clean_corrupt_data.answer_tokens[0, 0, 0]

assert len(clean_corrupt_data.all_prompts) == len(clean_corrupt_data.answer_tokens)
assert len(clean_corrupt_data.all_prompts) == len(clean_corrupt_data.clean_tokens)
clean_labels = (
clean_corrupt_data.answer_tokens[:, 0, 0]
== clean_corrupt_data.answer_tokens[0, 0, 0]
)

assert len(clean_corrupt_data.all_prompts) == len(
clean_corrupt_data.answer_tokens
)
assert len(clean_corrupt_data.all_prompts) == len(
clean_corrupt_data.clean_tokens
)
return cls(
clean_corrupt_data.all_prompts,
clean_corrupt_data.clean_tokens,
clean_labels,
model,
prompt_type,
)


def _get_labels_by_class(self, label: int) -> List[str]:
return [
string.strip() for string, one_hot in zip(self.str_labels, self.binary_labels) if one_hot == label
string.strip()
for string, one_hot in zip(self.str_labels, self.binary_labels)
if one_hot == label
]

def get_positive_negative_labels(self):
return (
self._get_labels_by_class(1),
self._get_labels_by_class(0),
)
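
The run_with_cache method in utils/residual_stream.py accumulates activations batch by batch into CPU buffers that are pre-allocated on the first batch, now in the model's dtype rather than float32. Below is a stripped-down sketch of that accumulation loop, with hypothetical shapes and hook name and no real model.

import torch

total_samples, batch_size, d_model = 8, 4, 16
dtype = torch.float32                      # stand-in for the model's dtype
act_dict = {}
buffer_initialized = False

for start_idx in range(0, total_samples, batch_size):
    # stand-in for model.run_with_cache(...) on one batch
    fwd_cache = {"blocks.0.hook_resid_pre": torch.randn(batch_size, 3, d_model)}
    if not buffer_initialized:
        for k, v in fwd_cache.items():
            act_dict[k] = torch.zeros(
                (total_samples, *v.shape[1:]), dtype=dtype, device="cpu"
            )
        buffer_initialized = True
    end_idx = start_idx + batch_size
    for k, v in fwd_cache.items():
        act_dict[k][start_idx:end_idx] = v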


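The position sampling in embed (Steps 1 to 3 in its comments) can also be seen in isolation: build a mask of positions that are neither padding nor constant across the batch, then draw one valid position per row with torch.multinomial. Below is a small sketch with hypothetical token IDs and a pad token id of 0.

import torch

prompt_tokens = torch.tensor([
    [5, 7, 9, 0, 0],
    [5, 2, 4, 8, 0],
])
pad_token_id = 0

non_pad_mask = prompt_tokens != pad_token_id                        # [batch, pos]
non_constant_mask = (prompt_tokens != prompt_tokens[0]).any(dim=0)  # [pos]
valid_positions = non_pad_mask & non_constant_mask                  # [batch, pos]

# one uniformly sampled valid position per row
embed_position = torch.multinomial(valid_positions.float(), 1).squeeze()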