
Commit

ShashankMosaicML committed Aug 15, 2024
1 parent 409ed15 commit a90f249
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions llmfoundry/models/layers/attention.py
@@ -178,7 +178,7 @@ def scaled_multihead_dot_product_attention(
             min_val,
         )
 
-    if is_causal and (not q.size(2) == 1):
+    if is_causal and (not s_q == 1):
         s = max(s_q, s_k)
         causal_mask = attn_weight.new_ones(s, s, dtype=torch.float32)
         causal_mask = causal_mask.tril()
@@ -194,14 +194,21 @@ def scaled_multihead_dot_product_attention(
         window_mask = torch.ones((s_q, s_k),
                                  dtype=torch.bool,
                                  device=attn_weight.device)
-        window_mask = torch.tril(
-            window_mask,
-            diagonal=sliding_window_size,
-        )
-        window_mask = torch.triu(
-            window_mask,
-            diagonal=-sliding_window_size,
-        )
+        if (not s_q == 1):
+            if s_q != s_k:
+                raise ValueError(
+                    'Number of queries should be equal to the number of keys.',
+                )
+            window_mask = torch.tril(
+                window_mask,
+                diagonal=sliding_window_size,
+            )
+            window_mask = torch.triu(
+                window_mask,
+                diagonal=-sliding_window_size,
+            )
+        else:
+            window_mask[:, :-(sliding_window_size + 1)] = False
         window_mask = ~window_mask
         attn_weight = attn_weight.masked_fill(
             window_mask.view(1, 1, s_q, s_k),

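For reference, below is a minimal standalone sketch of the sliding-window masking logic this diff changes, written against plain PyTorch. The helper name build_sliding_window_mask and the literal sizes in the usage example are illustrative only and not part of llm-foundry; the variable names s_q, s_k, and sliding_window_size mirror the ones in attention.py. The point of the new branch is that when s_q == 1 (a single decoding query), the tril/triu band construction does not apply, so the mask instead keeps only the last sliding_window_size + 1 key positions.

import torch


def build_sliding_window_mask(
    s_q: int,
    s_k: int,
    sliding_window_size: int,
) -> torch.Tensor:
    # Illustrative helper (not from llm-foundry): returns a boolean mask where
    # True marks attention positions that should be masked out.
    window_mask = torch.ones((s_q, s_k), dtype=torch.bool)
    if not s_q == 1:
        if s_q != s_k:
            raise ValueError(
                'Number of queries should be equal to the number of keys.',
            )
        # Keep only keys within sliding_window_size positions of each query.
        window_mask = torch.tril(window_mask, diagonal=sliding_window_size)
        window_mask = torch.triu(window_mask, diagonal=-sliding_window_size)
    else:
        # Single-query case (e.g. one decoding step): the query sits at the last
        # position, so only the final sliding_window_size + 1 keys stay visible.
        window_mask[:, :-(sliding_window_size + 1)] = False
    return ~window_mask


# Example: one query over 5 cached keys with a window of 2 keeps the last 3 keys.
print(build_sliding_window_mask(1, 5, 2))
# tensor([[ True,  True, False, False, False]])

The s_q != s_k guard mirrors the ValueError added in the diff: the banded tril/triu construction only lines up when queries and keys cover the same positions, whereas the single-query branch works for any cache length s_k.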