Adding sliding window attn to scaled_multihead_dot_product_attention #1455

30 changes: 28 additions & 2 deletions llmfoundry/models/layers/attention.py
@@ -112,6 +112,7 @@ def scaled_multihead_dot_product_attention(
dropout_p: float = 0.0,
training: bool = False,
needs_weights: bool = False,
sliding_window_size: int = -1,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:

@@ -177,7 +178,7 @@ def scaled_multihead_dot_product_attention(
min_val,
)

if is_causal and (not q.size(2) == 1):
if is_causal and (not s_q == 1):
s = max(s_q, s_k)
causal_mask = attn_weight.new_ones(s, s, dtype=torch.float32)
causal_mask = causal_mask.tril()
@@ -189,6 +190,31 @@
min_val,
)

if sliding_window_size != -1:
window_mask = torch.ones((s_q, s_k),
dtype=torch.bool,
device=attn_weight.device)
if (not s_q == 1):
if s_q != s_k:
raise ValueError(
'Number of queries should be equal to the number of keys.',
)
window_mask = torch.tril(
window_mask,
diagonal=sliding_window_size,
)
window_mask = torch.triu(
window_mask,
diagonal=-sliding_window_size,
)
else:
window_mask[:, :-(sliding_window_size + 1)] = False
window_mask = ~window_mask
attn_weight = attn_weight.masked_fill(
window_mask.view(1, 1, s_q, s_k),
min_val,
)

attn_weight = torch.softmax(attn_weight, dim=-1)

if dropout_p:
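
The new branch above builds the sliding-window mask for the torch implementation: a banded tril/triu mask during prefill, and a "keep only the trailing sliding_window_size + 1 keys" mask for the single-query decode step. A minimal standalone sketch of that logic (illustrative sizes, not the library function itself):

import torch

def sliding_window_keep_mask(s_q: int, s_k: int, window: int) -> torch.Tensor:
    """Bool mask that is True where a query is allowed to attend."""
    keep = torch.ones((s_q, s_k), dtype=torch.bool)
    if s_q != 1:
        # Prefill/training: queries and keys are aligned, so keep the band
        # |i - j| <= window. The causal mask applied earlier in the function
        # already removes the upper (future) half of this band.
        keep = torch.tril(keep, diagonal=window)
        keep = torch.triu(keep, diagonal=-window)
    else:
        # Single-query decoding: the query is the last position, so only the
        # trailing `window + 1` keys stay visible.
        keep[:, :-(window + 1)] = False
    return keep

print(sliding_window_keep_mask(5, 5, 2).long())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1],
#         [0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1]])

In the diff, the complement of this mask (~window_mask) is filled with min_val before the softmax; the new test below checks that this matches the flash implementation's sliding-window output.
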
@@ -591,6 +617,7 @@ def forward(
dropout_p=self.attn_dropout_p,
training=self.training,
needs_weights=needs_weights,
sliding_window_size=self.sliding_window_size,
**extra_attn_kwargs,
)

@@ -771,7 +798,6 @@ def get_implementation_specific_args(
if self.attn_impl == 'flash':
extra_attn_kwargs = {
'should_repeat_kv_for_gqa': not is_flash_v2_installed(),
'sliding_window_size': self.sliding_window_size,
'alibi_slopes': alibi_slopes,
'flash_attn_padding_info': flash_attn_padding_info,
'key_padding_mask': None,
9 changes: 4 additions & 5 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -329,12 +329,11 @@ def _validate_config(self) -> None:
raise ImportError(
'If using the dail implementation of rope, the flash_attn library v2.0.1 or higher must be installed. Please check the instructions at https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md#what-kinds-of-positional-embeddings-does-llm-foundry-support',
)
if self.attn_config['sliding_window_size'] != -1 and not (
self.attn_config['attn_impl'] == 'flash' and
is_flash_v2_installed(v2_version='v2.3.0')
):
if self.attn_config['sliding_window_size'] != -1 and self.attn_config[
'attn_impl'
] == 'flash' and not is_flash_v2_installed(v2_version='v2.3.0',):
raise NotImplementedError(
'sliding window only implemented with flash attention v2.3.0 or higher.',
'sliding window attention only implemented for torch attention and flash attention (v2.3.0 or higher).',
)
if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
raise ValueError(
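
The rewritten condition above changes the scope of the error: previously a non-default sliding_window_size was rejected unless the backend was flash with v2.3.0+, while now only flash below v2.3.0 is rejected, so the torch implementation added in attention.py passes validation. A small restatement of the before/after predicates, assuming plain booleans in place of the real MPTConfig attributes:

def rejects_before(sliding_window_size: int, attn_impl: str, flash_ge_230: bool) -> bool:
    # Old behavior: anything other than flash >= v2.3.0 was rejected.
    return sliding_window_size != -1 and not (attn_impl == 'flash' and flash_ge_230)

def rejects_after(sliding_window_size: int, attn_impl: str, flash_ge_230: bool) -> bool:
    # New behavior: only flash < v2.3.0 is rejected; 'torch' is now allowed.
    return sliding_window_size != -1 and attn_impl == 'flash' and not flash_ge_230

assert rejects_before(4, 'torch', False)      # previously blocked
assert not rejects_after(4, 'torch', False)   # now allowed
assert rejects_after(4, 'flash', False)       # still requires flash >= v2.3.0
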
113 changes: 113 additions & 0 deletions tests/models/layers/test_attention.py
@@ -1,10 +1,17 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import math

import pytest
import torch

from llmfoundry.models.layers.attention import (
attention_implementations,
scaled_multihead_dot_product_attention,
)
from llmfoundry.models.layers.layer_builders import build_attention_layer
from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info


@pytest.mark.parametrize(
@@ -158,3 +165,109 @@ def test_unfused_wqkv(attn_name: str, dim: int):
assert isinstance(attn_fused.Wqkv.weight.grad, torch.Tensor)
assert isinstance(combined_grad, torch.Tensor)
assert torch.allclose(attn_fused.Wqkv.weight.grad, combined_grad)


@pytest.mark.gpu
@pytest.mark.parametrize('sliding_window_size', [1, 4, 8])
@pytest.mark.parametrize('attn_impl', ['flash', 'torch'])
def test_sliding_window(sliding_window_size: int, attn_impl: str):
# Test that sliding window attention works as expected.
dtype = torch.bfloat16
device = 'cuda'
d = 128
n_heads = 8
seqlen_1 = 8
bsz = 2

query_1 = torch.randn(bsz, seqlen_1,
n_heads * d).to(dtype=dtype, device=device)
query_1.requires_grad = True
key_1 = torch.randn(bsz, seqlen_1,
n_heads * d).to(dtype=dtype, device=device)
key_1.requires_grad = True
value_1 = torch.randn(bsz, seqlen_1,
n_heads * d).to(dtype=dtype, device=device)
value_1.requires_grad = True

attn_extra_kwargs = {}
if attn_impl == 'flash':
attn_extra_kwargs = {
'flash_attn_padding_info':
gen_flash_attn_padding_info(
bsz,
seqlen_1,
0,
query_1.device,
None,
None,
),
'should_repeat_kv_for_gqa':
True,
}

output_1, _, _ = attention_implementations.get(attn_impl)(
query=query_1,
key=key_1,
value=value_1,
n_heads=n_heads,
kv_n_heads=n_heads,
past_key_value=None,
softmax_scale=1 / math.sqrt(d),
attn_bias=None,
key_padding_mask=None,
is_causal=True,
dropout_p=0.0,
training=False,
needs_weights=False,
sliding_window_size=sliding_window_size,
**attn_extra_kwargs,
)

output_1.sum().backward()

query_2 = query_1.detach().clone()
query_2.requires_grad = True
key_2 = key_1.detach().clone()
key_2.requires_grad = True
value_2 = value_1.detach().clone()
value_2.requires_grad = True

attn_bias_2 = torch.zeros(1, 1, seqlen_1,
seqlen_1).to(dtype=dtype, device=device)

window_mask_2 = torch.tril(
torch.ones(seqlen_1, seqlen_1),
diagonal=-(sliding_window_size + 1),
).to(dtype=dtype, device=device) * torch.finfo(attn_bias_2.dtype).min
attn_bias_2 = attn_bias_2 + window_mask_2
output_2, _, _ = scaled_multihead_dot_product_attention(
query=query_2,
key=key_2,
value=value_2,
n_heads=n_heads,
kv_n_heads=n_heads,
past_key_value=None,
softmax_scale=1 / math.sqrt(d),
attn_bias=attn_bias_2,
key_padding_mask=None,
is_causal=True,
dropout_p=0.0,
training=False,
needs_weights=False,
)

output_2.sum().backward()

print(torch.max(output_1 - output_2))

_assert_approx_equal(output_1, output_2)
assert (query_2.grad is not None) and (query_1.grad is not None)
_assert_approx_equal(query_1.grad, query_2.grad)
assert (key_2.grad is not None) and (key_1.grad is not None)
_assert_approx_equal(key_1.grad, key_2.grad)
assert (value_2.grad is not None) and (value_1.grad is not None)
_assert_approx_equal(value_1.grad, value_2.grad)


def _assert_approx_equal(value1: torch.Tensor, value2: torch.Tensor):
assert torch.norm(value2 - value1) <= 1e-2 + 1e-2 * torch.norm(value2)
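
The reference output in the test above is produced without the sliding_window_size argument; instead, out-of-window positions are pushed to the dtype minimum through attn_bias_2. A short illustration (hypothetical small sizes) of why diagonal=-(sliding_window_size + 1) selects exactly the keys outside the window:

import torch

seqlen, window = 5, 2

# 1s strictly below the band: keys more than `window` positions behind the
# query, i.e. j <= i - (window + 1). Those entries receive a large negative bias.
out_of_window = torch.tril(torch.ones(seqlen, seqlen), diagonal=-(window + 1))
attn_bias = out_of_window * torch.finfo(torch.float32).min

print(out_of_window.long())
# tensor([[0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0],
#         [1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0]])
# Combined with is_causal=True (which removes j > i), the surviving positions
# are exactly i - window <= j <= i, i.e. the sliding-window band.
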
98 changes: 0 additions & 98 deletions tests/models/layers/test_flash_attn.py
@@ -218,104 +218,6 @@ def test_seq_id_masking_FA_v2():
)


@pytest.mark.gpu
@pytest.mark.skipif(
not is_flash_v2_installed(v2_version='v2.3.0'),
reason=
'Sliding window attention only supported by Flash Attention after v2.3.0.',
)
@pytest.mark.parametrize('sliding_window_size', [1, 4, 8])
def test_sliding_window(sliding_window_size: int):
# Test that sliding window attention works as expected.
dtype = torch.bfloat16
device = 'cuda'
d = 128
n_heads = 8
seqlen_1 = 8
bsz = 2

query_1 = torch.randn(bsz, seqlen_1,
n_heads * d).to(dtype=dtype, device=device)
query_1.requires_grad = True
key_1 = torch.randn(bsz, seqlen_1,
n_heads * d).to(dtype=dtype, device=device)
key_1.requires_grad = True
value_1 = torch.randn(bsz, seqlen_1,
n_heads * d).to(dtype=dtype, device=device)
value_1.requires_grad = True

output_1, _, _ = flash_attn_fn(
query=query_1,
key=key_1,
value=value_1,
n_heads=n_heads,
kv_n_heads=n_heads,
past_key_value=None,
softmax_scale=1 / math.sqrt(d),
attn_bias=None,
key_padding_mask=None,
is_causal=True,
dropout_p=0.0,
training=False,
needs_weights=False,
flash_attn_padding_info=gen_flash_attn_padding_info(
bsz,
seqlen_1,
0,
query_1.device,
None,
None,
),
should_repeat_kv_for_gqa=True,
sliding_window_size=sliding_window_size,
)

output_1.sum().backward()

query_2 = query_1.detach().clone()
query_2.requires_grad = True
key_2 = key_1.detach().clone()
key_2.requires_grad = True
value_2 = value_1.detach().clone()
value_2.requires_grad = True

attn_bias_2 = torch.zeros(1, 1, seqlen_1,
seqlen_1).to(dtype=dtype, device=device)

window_mask_2 = torch.tril(
torch.ones(seqlen_1, seqlen_1),
diagonal=-(sliding_window_size + 1),
).to(dtype=dtype, device=device) * torch.finfo(attn_bias_2.dtype).min
attn_bias_2 = attn_bias_2 + window_mask_2
output_2, _, _ = scaled_multihead_dot_product_attention(
query=query_2,
key=key_2,
value=value_2,
n_heads=n_heads,
kv_n_heads=n_heads,
past_key_value=None,
softmax_scale=1 / math.sqrt(d),
attn_bias=attn_bias_2,
key_padding_mask=None,
is_causal=True,
dropout_p=0.0,
training=False,
needs_weights=False,
)

output_2.sum().backward()

print(torch.max(output_1 - output_2))

_assert_approx_equal(output_1, output_2)
assert (query_2.grad is not None) and (query_1.grad is not None)
_assert_approx_equal(query_1.grad, query_2.grad)
assert (key_2.grad is not None) and (key_1.grad is not None)
_assert_approx_equal(key_1.grad, key_2.grad)
assert (value_2.grad is not None) and (value_1.grad is not None)
_assert_approx_equal(value_1.grad, value_2.grad)


@pytest.mark.gpu
@pytest.mark.skipif(
not check_alibi_support('flash'),
3 changes: 3 additions & 0 deletions tests/models/layers/test_flash_torch.py
@@ -77,6 +77,7 @@ def allclose_helper(
)
@pytest.mark.parametrize('attn_uses_sequence_id', [True, False])
@pytest.mark.parametrize('pad_attention_mask', [True, False])
@pytest.mark.parametrize('sliding_window_size', [-1, 2])
def test_attn_impl(
attn_impl_0: str,
attn_impl_1: str,
@@ -87,6 +88,7 @@ def test_attn_impl(
attn_type: str,
attn_uses_sequence_id: bool,
pad_attention_mask: bool,
sliding_window_size: int,
device: str = 'cuda',
):
"""Compare all attn impl with each other.
@@ -122,6 +124,7 @@
'clip_qkv': clip_qkv,
'qk_ln': qk_ln,
'qk_gn': qk_gn,
'sliding_window_size': sliding_window_size,
})

n, s, f = 2, 4, cfg.d_model