diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index fb6f73ecf9..12a0e5e3f2 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -191,16 +191,17 @@ def scaled_multihead_dot_product_attention( ) if sliding_window_size != -1: - window_mask = torch.ones_like(attn_weight) + window_mask = torch.ones((s_q, s_k), + dtype=torch.bool, + device=attn_weight.device) window_mask = torch.tril( window_mask, diagonal=sliding_window_size, - ) # TODO: check if it should be sliding_window_size + 1 or sliding_window_size - 1 or sliding_window_size + ) window_mask = torch.triu( window_mask, - diagonal=sliding_window_size, - ) # TODO: check if it should be sliding_window_size + 1 or sliding_window_size - 1 or sliding_window_size - window_mask = window_mask.to(torch.bool) + diagonal=-sliding_window_size, + ) window_mask = window_mask[-s_q:, -s_k:] window_mask = ~window_mask attn_weight = attn_weight.masked_fill( diff --git a/tests/models/layers/test_attention.py b/tests/models/layers/test_attention.py index bee6f8809e..c51a532092 100644 --- a/tests/models/layers/test_attention.py +++ b/tests/models/layers/test_attention.py @@ -189,6 +189,22 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str): n_heads * d).to(dtype=dtype, device=device) value_1.requires_grad = True + attn_extra_kwargs = {} + if attn_impl == 'flash': + attn_extra_kwargs = { + 'flash_attn_padding_info': + gen_flash_attn_padding_info( + bsz, + seqlen_1, + 0, + query_1.device, + None, + None, + ), + 'should_repeat_kv_for_gqa': + True, + } + output_1, _, _ = attention_implementations.get(attn_impl)( query=query_1, key=key_1, @@ -203,16 +219,8 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str): dropout_p=0.0, training=False, needs_weights=False, - flash_attn_padding_info=gen_flash_attn_padding_info( - bsz, - seqlen_1, - 0, - query_1.device, - None, - None, - ), - should_repeat_kv_for_gqa=True, sliding_window_size=sliding_window_size, + **attn_extra_kwargs, ) output_1.sum().backward()