diff --git a/utils/das.py b/utils/das.py
index 2551715..9938076 100644
--- a/utils/das.py
+++ b/utils/das.py
@@ -71,7 +71,7 @@ def hook_fn_base(
     resid: Float[Tensor, "batch pos d_model"],
     hook: HookPoint,
     layer: int,
-    position: Optional[Int[Tensor, "batch"]],
+    position: Optional[Union[int, Int[Tensor, "batch"]]],
     new_value: Float[Tensor, "batch *pos d_model"]
 ):
     batch_size, seq_len, d_model = resid.shape
@@ -81,6 +81,13 @@ def hook_fn_base(
     if position is None:
         assert new_value.shape == resid.shape
         return new_value
+    if isinstance(position, int):
+        position = torch.tensor(position, device=resid.device)
+        position = einops.repeat(
+            position,
+            " -> batch",
+            batch=batch_size,
+        )
     new_value_repeat = einops.repeat(
         new_value,
         "batch d_model -> batch pos d_model",
diff --git a/utils/prompts.py b/utils/prompts.py
index 8f035f5..3e4ac81 100644
--- a/utils/prompts.py
+++ b/utils/prompts.py
@@ -512,9 +512,12 @@ def __init__(
             prompt_type is not None and len(prompt_type.get_placeholders()) == 0
         )
         if position is None and no_placeholders:
-            mask = get_attention_mask(
-                tokenizer, clean_tokens, prepend_bos=False
-            )
+            if tokenizer is None:
+                mask = torch.ones_like(clean_tokens, dtype=torch.bool)
+            else:
+                mask = get_attention_mask(
+                    tokenizer, clean_tokens, prepend_bos=False
+                )
             position = get_final_non_pad_token(mask)
         elif position is None and prompt_type is not None:
             # Default to first placeholder position
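
A minimal sketch (not part of the patch) of the new int-position path added to hook_fn_base: a plain Python int is promoted to a 0-d tensor on the residual stream's device and broadcast to one index per batch element, matching the existing Int[Tensor, "batch"] handling. The batch_size, position, and device values below are illustrative, not taken from the codebase.

    import torch
    import einops

    batch_size = 4                    # illustrative batch size
    position = 7                      # hypothetical scalar position from a caller
    device = torch.device("cpu")      # stand-in for resid.device

    # Promote the int to a tensor, then repeat it so every batch element
    # uses the same position index (mirrors the added branch in hook_fn_base).
    position_t = torch.tensor(position, device=device)
    position_t = einops.repeat(position_t, " -> batch", batch=batch_size)
    print(position_t)                 # tensor([7, 7, 7, 7])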