From 93b22d8f429023723ab7a5204dc9292ff25c67c1 Mon Sep 17 00:00:00 2001 From: Shashank Rajput <144760128+ShashankMosaicML@users.noreply.github.com> Date: Thu, 15 Aug 2024 14:53:14 -0700 Subject: [PATCH] Adding sliding window attn to scaled_multihead_dot_product_attention (#1455) --- llmfoundry/models/layers/attention.py | 30 +++++- llmfoundry/models/mpt/configuration_mpt.py | 9 +- tests/models/layers/test_attention.py | 113 +++++++++++++++++++++ tests/models/layers/test_flash_attn.py | 98 ------------------ tests/models/layers/test_flash_torch.py | 3 + 5 files changed, 148 insertions(+), 105 deletions(-) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index db8ddce4d3..acf231558e 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -112,6 +112,7 @@ def scaled_multihead_dot_product_attention( dropout_p: float = 0.0, training: bool = False, needs_weights: bool = False, + sliding_window_size: int = -1, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]: @@ -177,7 +178,7 @@ def scaled_multihead_dot_product_attention( min_val, ) - if is_causal and (not q.size(2) == 1): + if is_causal and (not s_q == 1): s = max(s_q, s_k) causal_mask = attn_weight.new_ones(s, s, dtype=torch.float32) causal_mask = causal_mask.tril() @@ -189,6 +190,31 @@ def scaled_multihead_dot_product_attention( min_val, ) + if sliding_window_size != -1: + window_mask = torch.ones((s_q, s_k), + dtype=torch.bool, + device=attn_weight.device) + if (not s_q == 1): + if s_q != s_k: + raise ValueError( + 'Number of queries should be equal to the number of keys.', + ) + window_mask = torch.tril( + window_mask, + diagonal=sliding_window_size, + ) + window_mask = torch.triu( + window_mask, + diagonal=-sliding_window_size, + ) + else: + window_mask[:, :-(sliding_window_size + 1)] = False + window_mask = ~window_mask + attn_weight = attn_weight.masked_fill( + window_mask.view(1, 1, s_q, s_k), + min_val, + ) + attn_weight = torch.softmax(attn_weight, dim=-1) if dropout_p: @@ -591,6 +617,7 @@ def forward( dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, + sliding_window_size=self.sliding_window_size, **extra_attn_kwargs, ) @@ -771,7 +798,6 @@ def get_implementation_specific_args( if self.attn_impl == 'flash': extra_attn_kwargs = { 'should_repeat_kv_for_gqa': not is_flash_v2_installed(), - 'sliding_window_size': self.sliding_window_size, 'alibi_slopes': alibi_slopes, 'flash_attn_padding_info': flash_attn_padding_info, 'key_padding_mask': None, diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index a9fa2f4c16..759f347e89 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -329,12 +329,11 @@ def _validate_config(self) -> None: raise ImportError( 'If using the dail implementation of rope, the flash_attn library v2.0.1 or higher must be installed. 
Please check the instructions at https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md#what-kinds-of-positional-embeddings-does-llm-foundry-support', ) - if self.attn_config['sliding_window_size'] != -1 and not ( - self.attn_config['attn_impl'] == 'flash' and - is_flash_v2_installed(v2_version='v2.3.0') - ): + if self.attn_config['sliding_window_size'] != -1 and self.attn_config[ + 'attn_impl' + ] == 'flash' and not is_flash_v2_installed(v2_version='v2.3.0',): raise NotImplementedError( - 'sliding window only implemented with flash attention v2.3.0 or higher.', + 'sliding window attention only implemented for torch attention and flash attention (v2.3.0 or higher).', ) if self.embedding_fraction > 1 or self.embedding_fraction <= 0: raise ValueError( diff --git a/tests/models/layers/test_attention.py b/tests/models/layers/test_attention.py index bdffe2b49f..c51a532092 100644 --- a/tests/models/layers/test_attention.py +++ b/tests/models/layers/test_attention.py @@ -1,10 +1,17 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import math + import pytest import torch +from llmfoundry.models.layers.attention import ( + attention_implementations, + scaled_multihead_dot_product_attention, +) from llmfoundry.models.layers.layer_builders import build_attention_layer +from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info @pytest.mark.parametrize( @@ -158,3 +165,109 @@ def test_unfused_wqkv(attn_name: str, dim: int): assert isinstance(attn_fused.Wqkv.weight.grad, torch.Tensor) assert isinstance(combined_grad, torch.Tensor) assert torch.allclose(attn_fused.Wqkv.weight.grad, combined_grad) + + +@pytest.mark.gpu +@pytest.mark.parametrize('sliding_window_size', [1, 4, 8]) +@pytest.mark.parametrize('attn_impl', ['flash', 'torch']) +def test_sliding_window(sliding_window_size: int, attn_impl: str): + # Test that sliding window attention works as expected. 
+ dtype = torch.bfloat16 + device = 'cuda' + d = 128 + n_heads = 8 + seqlen_1 = 8 + bsz = 2 + + query_1 = torch.randn(bsz, seqlen_1, + n_heads * d).to(dtype=dtype, device=device) + query_1.requires_grad = True + key_1 = torch.randn(bsz, seqlen_1, + n_heads * d).to(dtype=dtype, device=device) + key_1.requires_grad = True + value_1 = torch.randn(bsz, seqlen_1, + n_heads * d).to(dtype=dtype, device=device) + value_1.requires_grad = True + + attn_extra_kwargs = {} + if attn_impl == 'flash': + attn_extra_kwargs = { + 'flash_attn_padding_info': + gen_flash_attn_padding_info( + bsz, + seqlen_1, + 0, + query_1.device, + None, + None, + ), + 'should_repeat_kv_for_gqa': + True, + } + + output_1, _, _ = attention_implementations.get(attn_impl)( + query=query_1, + key=key_1, + value=value_1, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=None, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + sliding_window_size=sliding_window_size, + **attn_extra_kwargs, + ) + + output_1.sum().backward() + + query_2 = query_1.detach().clone() + query_2.requires_grad = True + key_2 = key_1.detach().clone() + key_2.requires_grad = True + value_2 = value_1.detach().clone() + value_2.requires_grad = True + + attn_bias_2 = torch.zeros(1, 1, seqlen_1, + seqlen_1).to(dtype=dtype, device=device) + + window_mask_2 = torch.tril( + torch.ones(seqlen_1, seqlen_1), + diagonal=-(sliding_window_size + 1), + ).to(dtype=dtype, device=device) * torch.finfo(attn_bias_2.dtype).min + attn_bias_2 = attn_bias_2 + window_mask_2 + output_2, _, _ = scaled_multihead_dot_product_attention( + query=query_2, + key=key_2, + value=value_2, + n_heads=n_heads, + kv_n_heads=n_heads, + past_key_value=None, + softmax_scale=1 / math.sqrt(d), + attn_bias=attn_bias_2, + key_padding_mask=None, + is_causal=True, + dropout_p=0.0, + training=False, + needs_weights=False, + ) + + output_2.sum().backward() + + print(torch.max(output_1 - output_2)) + + _assert_approx_equal(output_1, output_2) + assert (query_2.grad is not None) and (query_1.grad is not None) + _assert_approx_equal(query_1.grad, query_2.grad) + assert (key_2.grad is not None) and (key_1.grad is not None) + _assert_approx_equal(key_1.grad, key_2.grad) + assert (value_2.grad is not None) and (value_1.grad is not None) + _assert_approx_equal(value_1.grad, value_2.grad) + + +def _assert_approx_equal(value1: torch.Tensor, value2: torch.Tensor): + assert torch.norm(value2 - value1) <= 1e-2 + 1e-2 * torch.norm(value2) diff --git a/tests/models/layers/test_flash_attn.py b/tests/models/layers/test_flash_attn.py index dcce0fe118..987ea7160a 100644 --- a/tests/models/layers/test_flash_attn.py +++ b/tests/models/layers/test_flash_attn.py @@ -218,104 +218,6 @@ def test_seq_id_masking_FA_v2(): ) -@pytest.mark.gpu -@pytest.mark.skipif( - not is_flash_v2_installed(v2_version='v2.3.0'), - reason= - 'Sliding window attention only supported by Flash Attention after v2.3.0.', -) -@pytest.mark.parametrize('sliding_window_size', [1, 4, 8]) -def test_sliding_window(sliding_window_size: int): - # Test that sliding window attention works as expected. 
- dtype = torch.bfloat16 - device = 'cuda' - d = 128 - n_heads = 8 - seqlen_1 = 8 - bsz = 2 - - query_1 = torch.randn(bsz, seqlen_1, - n_heads * d).to(dtype=dtype, device=device) - query_1.requires_grad = True - key_1 = torch.randn(bsz, seqlen_1, - n_heads * d).to(dtype=dtype, device=device) - key_1.requires_grad = True - value_1 = torch.randn(bsz, seqlen_1, - n_heads * d).to(dtype=dtype, device=device) - value_1.requires_grad = True - - output_1, _, _ = flash_attn_fn( - query=query_1, - key=key_1, - value=value_1, - n_heads=n_heads, - kv_n_heads=n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=None, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - flash_attn_padding_info=gen_flash_attn_padding_info( - bsz, - seqlen_1, - 0, - query_1.device, - None, - None, - ), - should_repeat_kv_for_gqa=True, - sliding_window_size=sliding_window_size, - ) - - output_1.sum().backward() - - query_2 = query_1.detach().clone() - query_2.requires_grad = True - key_2 = key_1.detach().clone() - key_2.requires_grad = True - value_2 = value_1.detach().clone() - value_2.requires_grad = True - - attn_bias_2 = torch.zeros(1, 1, seqlen_1, - seqlen_1).to(dtype=dtype, device=device) - - window_mask_2 = torch.tril( - torch.ones(seqlen_1, seqlen_1), - diagonal=-(sliding_window_size + 1), - ).to(dtype=dtype, device=device) * torch.finfo(attn_bias_2.dtype).min - attn_bias_2 = attn_bias_2 + window_mask_2 - output_2, _, _ = scaled_multihead_dot_product_attention( - query=query_2, - key=key_2, - value=value_2, - n_heads=n_heads, - kv_n_heads=n_heads, - past_key_value=None, - softmax_scale=1 / math.sqrt(d), - attn_bias=attn_bias_2, - key_padding_mask=None, - is_causal=True, - dropout_p=0.0, - training=False, - needs_weights=False, - ) - - output_2.sum().backward() - - print(torch.max(output_1 - output_2)) - - _assert_approx_equal(output_1, output_2) - assert (query_2.grad is not None) and (query_1.grad is not None) - _assert_approx_equal(query_1.grad, query_2.grad) - assert (key_2.grad is not None) and (key_1.grad is not None) - _assert_approx_equal(key_1.grad, key_2.grad) - assert (value_2.grad is not None) and (value_1.grad is not None) - _assert_approx_equal(value_1.grad, value_2.grad) - - @pytest.mark.gpu @pytest.mark.skipif( not check_alibi_support('flash'), diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py index 4bfdfb84dc..01a6a7576d 100644 --- a/tests/models/layers/test_flash_torch.py +++ b/tests/models/layers/test_flash_torch.py @@ -77,6 +77,7 @@ def allclose_helper( ) @pytest.mark.parametrize('attn_uses_sequence_id', [True, False]) @pytest.mark.parametrize('pad_attention_mask', [True, False]) +@pytest.mark.parametrize('sliding_window_size', [-1, 2]) def test_attn_impl( attn_impl_0: str, attn_impl_1: str, @@ -87,6 +88,7 @@ def test_attn_impl( attn_type: str, attn_uses_sequence_id: bool, pad_attention_mask: bool, + sliding_window_size: int, device: str = 'cuda', ): """Compare all attn impl with each other. @@ -122,6 +124,7 @@ def test_attn_impl( 'clip_qkv': clip_qkv, 'qk_ln': qk_ln, 'qk_gn': qk_gn, + 'sliding_window_size': sliding_window_size, }) n, s, f = 2, 4, cfg.d_model
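
For reference, below is a minimal standalone sketch of the banded mask that the new torch-path code in llmfoundry/models/layers/attention.py constructs. The helper name build_sliding_window_mask is illustrative only (it is not part of llm-foundry's API); its body simply mirrors the tril/triu logic added in this patch.

import torch


def build_sliding_window_mask(
    s_q: int,
    s_k: int,
    sliding_window_size: int,
) -> torch.Tensor:
    """Return a boolean mask that is True where attention must be blocked.

    Mirrors the torch-path logic in this patch: for the prefill case
    (s_q == s_k), keep only the band of half-width `sliding_window_size`
    around the diagonal; for single-token decoding (s_q == 1), keep only
    the last `sliding_window_size + 1` key positions.
    """
    window_mask = torch.ones((s_q, s_k), dtype=torch.bool)
    if s_q != 1:
        if s_q != s_k:
            raise ValueError(
                'Number of queries should be equal to the number of keys.',
            )
        window_mask = torch.tril(window_mask, diagonal=sliding_window_size)
        window_mask = torch.triu(window_mask, diagonal=-sliding_window_size)
    else:
        window_mask[:, :-(sliding_window_size + 1)] = False
    # True entries are later filled with the dtype's min value before the
    # softmax, exactly like the causal mask built earlier in the patched
    # scaled_multihead_dot_product_attention.
    return ~window_mask


if __name__ == '__main__':
    # With seqlen 8 and window 4, query i may attend to keys in [i - 4, i + 4];
    # the causal mask applied earlier in the function removes the future half,
    # which is why the new test can emulate the same behaviour with an additive
    # bias built from torch.tril(..., diagonal=-(sliding_window_size + 1)).
    print(build_sliding_window_mask(8, 8, 4).int())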