
Commit

Added handling for tokenizer is None and position is int
ojh31 committed Oct 12, 2023
1 parent 1e874b5 commit 62039cb
Showing 2 changed files with 14 additions and 4 deletions.
utils/das.py: 9 changes (8 additions & 1 deletion)
@@ -71,7 +71,7 @@ def hook_fn_base(
     resid: Float[Tensor, "batch pos d_model"],
     hook: HookPoint,
     layer: int,
-    position: Optional[Int[Tensor, "batch"]],
+    position: Optional[Union[int, Int[Tensor, "batch"]]],
     new_value: Float[Tensor, "batch *pos d_model"]
 ):
     batch_size, seq_len, d_model = resid.shape
@@ -81,6 +81,13 @@ def hook_fn_base(
     if position is None:
         assert new_value.shape == resid.shape
         return new_value
+    if isinstance(position, int):
+        position = torch.tensor(position, device=resid.device)
+        position = einops.repeat(
+            position,
+            " -> batch",
+            batch=batch_size,
+        )
     new_value_repeat = einops.repeat(
         new_value,
         "batch d_model -> batch pos d_model",
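For reference, a minimal sketch of what the new isinstance branch enables: a plain int position is promoted to a per-batch index tensor before patching. The shapes, example values, and the final indexing line are illustrative assumptions, not the repository's full hook_fn_base.

import torch
import einops

# Illustrative shapes; the real hook receives these from the model at runtime.
batch_size, seq_len, d_model = 4, 10, 16
resid = torch.randn(batch_size, seq_len, d_model)
new_value = torch.randn(batch_size, d_model)  # one replacement vector per batch element

position = 3  # a plain int is now accepted, not only an Int[Tensor, "batch"]
if isinstance(position, int):
    # Same promotion as the added branch: scalar tensor -> (batch,) index tensor.
    position = torch.tensor(position, device=resid.device)
    position = einops.repeat(position, " -> batch", batch=batch_size)

# Illustrative use of the per-batch indices: patch one residual-stream slot per example.
resid[torch.arange(batch_size), position] = new_value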
utils/prompts.py: 9 changes (6 additions & 3 deletions)
@@ -512,9 +512,12 @@ def __init__(
             prompt_type is not None and len(prompt_type.get_placeholders()) == 0
         )
         if position is None and no_placeholders:
-            mask = get_attention_mask(
-                tokenizer, clean_tokens, prepend_bos=False
-            )
+            if tokenizer is None:
+                mask = torch.ones_like(clean_tokens, dtype=torch.bool)
+            else:
+                mask = get_attention_mask(
+                    tokenizer, clean_tokens, prepend_bos=False
+                )
             position = get_final_non_pad_token(mask)
         elif position is None and prompt_type is not None:
             # Default to first placeholder position
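A minimal sketch of the new fallback path: with no tokenizer there is no pad token id, so the mask is all ones and the final non-pad position is simply the last index. get_final_non_pad_token is the repository's helper; the sum-based line below is an illustrative stand-in, assuming it returns the index of the last True entry in each row of the mask.

import torch

clean_tokens = torch.randint(0, 100, (2, 7))  # illustrative batch of token ids
tokenizer = None

if tokenizer is None:
    # No tokenizer means no pad token id, so treat every position as a real token.
    mask = torch.ones_like(clean_tokens, dtype=torch.bool)

# Stand-in for get_final_non_pad_token: index of the last True entry per row.
position = mask.sum(dim=-1) - 1  # tensor([6, 6]) here, i.e. the final sequence position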
