Skip to content

Commit

Permalink
..
Browse files Browse the repository at this point in the history
  • Loading branch information
ShashankMosaicML committed Aug 15, 2024
1 parent de76124 commit 83cbd98
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 14 deletions.
11 changes: 6 additions & 5 deletions llmfoundry/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,16 +191,17 @@ def scaled_multihead_dot_product_attention(
)

if sliding_window_size != -1:
window_mask = torch.ones_like(attn_weight)
window_mask = torch.ones((s_q, s_k),
dtype=torch.bool,
device=attn_weight.device)
window_mask = torch.tril(
window_mask,
diagonal=sliding_window_size,
) # TODO: check if it should be sliding_window_size + 1 or sliding_window_size - 1 or sliding_window_size
)
window_mask = torch.triu(
window_mask,
diagonal=sliding_window_size,
) # TODO: check if it should be sliding_window_size + 1 or sliding_window_size - 1 or sliding_window_size
window_mask = window_mask.to(torch.bool)
diagonal=-sliding_window_size,
)
window_mask = window_mask[-s_q:, -s_k:]
window_mask = ~window_mask
attn_weight = attn_weight.masked_fill(
Expand Down
26 changes: 17 additions & 9 deletions tests/models/layers/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,22 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str):
n_heads * d).to(dtype=dtype, device=device)
value_1.requires_grad = True

attn_extra_kwargs = {}
if attn_impl == 'flash':
attn_extra_kwargs = {
'flash_attn_padding_info':
gen_flash_attn_padding_info(
bsz,
seqlen_1,
0,
query_1.device,
None,
None,
),
'should_repeat_kv_for_gqa':
True,
}

output_1, _, _ = attention_implementations.get(attn_impl)(
query=query_1,
key=key_1,
Expand All @@ -203,16 +219,8 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str):
dropout_p=0.0,
training=False,
needs_weights=False,
flash_attn_padding_info=gen_flash_attn_padding_info(
bsz,
seqlen_1,
0,
query_1.device,
None,
None,
),
should_repeat_kv_for_gqa=True,
sliding_window_size=sliding_window_size,
**attn_extra_kwargs,
)

output_1.sum().backward()
Expand Down

0 comments on commit 83cbd98

Please sign in to comment.