From 15526f7daf9999749cb7327cd438a8126ab97d78 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 29 Aug 2024 05:51:08 +0000
Subject: [PATCH 1/6] adding key value states to enable cross-attention

---
 llmfoundry/models/layers/attention.py | 30 +++++++++++++++++++--------
 1 file changed, 21 insertions(+), 9 deletions(-)

diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index acf231558e..2559fc1790 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -572,7 +572,8 @@ def __init__(
 
     def forward(
         self,
-        x: torch.Tensor,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
         past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
         attn_bias: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
@@ -588,7 +589,11 @@ def forward(
         extra_kwargs = {}
         if prev_layer_key_value is not None:
             extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
-        query, key, value = self.get_qkv(x, **extra_kwargs)
+        query, key, value = self.get_qkv(
+            hidden_states,
+            key_value_states,
+            **extra_kwargs,
+        )
 
         if rotary_emb_w_meta_info is not None:
             query, key, value = self._apply_rotary_embeddings(
@@ -625,14 +630,16 @@ def forward(
 
     def get_qkv(
         self,
-        x: torch.Tensor,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
         prev_layer_key_value: Optional[tuple[torch.Tensor,
                                              torch.Tensor]] = None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Computes and returns the query, key, and value tensors.
 
         Args:
-            x (torch.Tensor): The input tensor.
+            hidden_states (torch.Tensor): The input query tensor.
+            key_value_states (Optional[torch.Tensor]): The input tensor for keys and values.
             prev_layer_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): The key value of the previous layer.
 
         Returns:
@@ -647,7 +654,7 @@ def get_qkv(
             )
             key, value = prev_layer_key_value
 
-            query = self.Wq(x)
+            query = self.Wq(hidden_states)
             if self.clip_qkv:
                 query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv)
 
@@ -662,7 +669,8 @@ def get_qkv(
             return query, key, value
 
         if self.fused_qkv:
-            qkv = self.Wqkv(x)
+            assert key_value_states is None, 'Cannot use separate hidden and key_value states for fused_qkv'
+            qkv = self.Wqkv(hidden_states)
 
             if self.clip_qkv:
                 qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
@@ -676,9 +684,13 @@ def get_qkv(
                 dim=2,
             )
         else:
-            query = self.Wq(x)
-            key = self.Wk(x)
-            value = self.Wv(x)
+            query = self.Wq(hidden_states)
+            if key_value_states is not None:
+                key = self.Wk(key_value_states)
+                value = self.Wv(key_value_states)
+            else:
+                key = self.Wk(hidden_states)
+                value = self.Wv(hidden_states)
 
             if self.clip_qkv:
                 query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv)

From 70a6f310e8fa3cf309402e01a68eb0340e5f3ef4 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 29 Aug 2024 06:10:40 +0000
Subject: [PATCH 2/6] tests for xattn

---
 tests/models/layers/test_attention.py | 67 +++++++++++++++++++++++++++
 1 file changed, 67 insertions(+)

diff --git a/tests/models/layers/test_attention.py b/tests/models/layers/test_attention.py
index c51a532092..2f58910f6a 100644
--- a/tests/models/layers/test_attention.py
+++ b/tests/models/layers/test_attention.py
@@ -271,3 +271,70 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str):
 
 def _assert_approx_equal(value1: torch.Tensor, value2: torch.Tensor):
     assert torch.norm(value2 - value1) <= 1e-2 + 1e-2 * torch.norm(value2)
+
+
+@pytest.mark.parametrize(
+    'attn_name',
+    ['multihead_attention', 'grouped_query_attention', 'multiquery_attention'],
+)
+@pytest.mark.parametrize('dim', [1024])
+def test_cross_attn_as_self_attn(attn_name: str, dim: int):
+    d_head = 128
+    n_heads = dim // d_head
+
+    generic_attn_kwargs = {
+        'd_model': dim,
+        'n_heads': n_heads,
+        'fc_type': {
+            'name': 'torch',
+        },
+        'device': 'cpu',
+        'attn_pdrop': 0.0,
+        'attn_impl': 'torch',
+        'qk_ln': False,
+        'qk_gn': False,
+        'clip_qkv': None,
+        'softmax_scale': None,
+        'sliding_window_size': -1,
+    }
+
+    if attn_name == 'grouped_query_attention':
+        kv_n_heads = 2
+        generic_attn_kwargs['kv_n_heads'] = kv_n_heads
+    elif attn_name == 'multiquery_attention':
+        kv_n_heads = 1
+    elif attn_name == 'multihead_attention':
+        kv_n_heads = n_heads
+    else:
+        raise ValueError(f'Unknown attention name: {attn_name}')
+
+    attn_config_fused = generic_attn_kwargs.copy()
+    attn_config_fused['fused_qkv'] = True
+
+    attn_config_unfused = generic_attn_kwargs.copy()
+    attn_config_unfused['fused_qkv'] = False
+
+    attn_fused = build_attention_layer(
+        name=attn_name,
+        attn_kwargs=attn_config_fused,
+    )
+    attn_unfused = build_attention_layer(
+        name=attn_name,
+        attn_kwargs=attn_config_unfused,
+    )
+
+    x1 = torch.randn(1, 1, dim)
+    x2 = x1.detach().clone()
+
+    out_fused, _, _ = attn_fused(x1)
+    out_unfused, _, _ = attn_unfused(x1, x2)
+
+    # Dummy loss function is simply the sum.
+    loss_fused = out_fused.sum()
+    loss_fused.backward()
+
+    loss_unfused = out_unfused.sum()
+    loss_unfused.backward()
+
+    assert torch.allclose(out_fused, out_unfused)
+    assert torch.allclose(loss_fused, loss_unfused)

From 7e7ece1ec76e2aedff33a442ac0098dd731b3c6b Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 29 Aug 2024 15:16:26 +0000
Subject: [PATCH 3/6] check for fwd passes only

---
 tests/models/layers/test_attention.py | 21 +++------------------
 1 file changed, 3 insertions(+), 18 deletions(-)

diff --git a/tests/models/layers/test_attention.py b/tests/models/layers/test_attention.py
index 2f58910f6a..37ef722edc 100644
--- a/tests/models/layers/test_attention.py
+++ b/tests/models/layers/test_attention.py
@@ -308,17 +308,10 @@ def test_cross_attn_as_self_attn(attn_name: str, dim: int):
     else:
         raise ValueError(f'Unknown attention name: {attn_name}')
 
-    attn_config_fused = generic_attn_kwargs.copy()
-    attn_config_fused['fused_qkv'] = True
-
     attn_config_unfused = generic_attn_kwargs.copy()
     attn_config_unfused['fused_qkv'] = False
 
-    attn_fused = build_attention_layer(
-        name=attn_name,
-        attn_kwargs=attn_config_fused,
-    )
-    attn_unfused = build_attention_layer(
+    attn_layer = build_attention_layer(
         name=attn_name,
         attn_kwargs=attn_config_unfused,
     )
@@ -326,15 +319,7 @@ def test_cross_attn_as_self_attn(attn_name: str, dim: int):
     x1 = torch.randn(1, 1, dim)
     x2 = x1.detach().clone()
 
-    out_fused, _, _ = attn_fused(x1)
-    out_unfused, _, _ = attn_unfused(x1, x2)
-
-    # Dummy loss function is simply the sum.
-    loss_fused = out_fused.sum()
-    loss_fused.backward()
-
-    loss_unfused = out_unfused.sum()
-    loss_unfused.backward()
+    out_fused, _, _ = attn_layer(hidden_states=x1)
+    out_unfused, _, _ = attn_layer(hidden_states=x1, key_value_states=x2)
 
     assert torch.allclose(out_fused, out_unfused)
-    assert torch.allclose(loss_fused, loss_unfused)

From f7bf206ec3bd26440eea5a7c66b64942ba063254 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 29 Aug 2024 18:10:28 +0000
Subject: [PATCH 4/6] adding kv_dim changes for more generic xattn layer

---
 llmfoundry/models/layers/attention.py      | 46 ++++++++-----
 llmfoundry/models/mpt/configuration_mpt.py |  1 +
 llmfoundry/models/utils/config_defaults.py |  1 +
 tests/models/layers/test_attention.py      | 75 ++++++++++++++++++++--
 4 files changed, 101 insertions(+), 22 deletions(-)

diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 2559fc1790..9ab561f80d 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -424,8 +424,9 @@ class GroupedQueryAttention(nn.Module):
     and Multi-query attention (MQA).
 
     This allows the user to set a variable of number of kv_n_heads, rather than
-    just n_heads or 1, as in MHA and MQA. Using torch attention
-    implementation enables user to also use additive bias.
+    just n_heads or 1, as in MHA and MQA. Using torch attention implementation
+    enables user to also use additive bias. This class also supports
+    cross-attention with different `in_features` for key and value fc projections.
     """
 
     def __init__(
@@ -447,6 +448,7 @@ def __init__(
         bias: bool = True,
         sliding_window_size: int = -1,
         reuse_kv_layer_idx: Optional[int] = None,
+        kv_dim: Optional[int] = None,
     ):
         super().__init__()
 
@@ -462,6 +464,12 @@ def __init__(
         self.sliding_window_size = sliding_window_size
         self.reuse_kv_layer_idx = reuse_kv_layer_idx
 
+        if kv_dim is not None:
+            self.kv_dim = kv_dim
+            assert fused_qkv is False, 'Cannot use separate kv_dim from d_model for fused_qkv'
+        else:
+            self.kv_dim = self.d_model
+
         self.head_dim = d_model // n_heads
 
         # Usually, fc_type dict should be passed in through MPTBlock's __init__ function.
@@ -523,13 +531,13 @@ def __init__(
             )
             self.Wk = build_fc(
                 name=fc_type_name,
-                in_features=self.d_model,
+                in_features=self.kv_dim,
                 out_features=self.kv_n_heads * self.head_dim,
                 fc_kwargs=fc_type,
             )
             self.Wv = build_fc(
                 name=fc_type_name,
-                in_features=self.d_model,
+                in_features=self.kv_dim,
                 out_features=self.kv_n_heads * self.head_dim,
                 fc_kwargs=fc_type,
             )
@@ -572,8 +580,7 @@ def __init__(
 
     def forward(
         self,
-        hidden_states: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
+        x: torch.Tensor,
         past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
         attn_bias: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
@@ -584,14 +591,15 @@ def forward(
         flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
         prev_layer_key_value: Optional[tuple[torch.Tensor,
                                              torch.Tensor]] = None,
+        key_value_states: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
         torch.Tensor, torch.Tensor]]]:
         extra_kwargs = {}
         if prev_layer_key_value is not None:
             extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
         query, key, value = self.get_qkv(
-            hidden_states,
-            key_value_states,
+            x=x,
+            key_value_states=key_value_states,
             **extra_kwargs,
         )
 
@@ -630,17 +638,17 @@ def forward(
 
     def get_qkv(
         self,
-        hidden_states: torch.Tensor,
-        key_value_states: Optional[torch.Tensor] = None,
+        x: torch.Tensor,
         prev_layer_key_value: Optional[tuple[torch.Tensor,
                                              torch.Tensor]] = None,
+        key_value_states: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Computes and returns the query, key, and value tensors.
 
         Args:
-            hidden_states (torch.Tensor): The input query tensor.
-            key_value_states (Optional[torch.Tensor]): The input tensor for keys and values.
+            x (torch.Tensor): The input query tensor.
             prev_layer_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): The key value of the previous layer.
+            key_value_states (Optional[torch.Tensor]): The input tensor for keys and values.
 
         Returns:
             query (torch.Tensor): The query tensor.
@@ -654,7 +662,7 @@ def get_qkv(
             )
             key, value = prev_layer_key_value
 
-            query = self.Wq(hidden_states)
+            query = self.Wq(x)
             if self.clip_qkv:
                 query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv)
 
@@ -670,7 +678,7 @@ def get_qkv(
 
         if self.fused_qkv:
             assert key_value_states is None, 'Cannot use separate hidden and key_value states for fused_qkv'
-            qkv = self.Wqkv(hidden_states)
+            qkv = self.Wqkv(x)
 
             if self.clip_qkv:
                 qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
@@ -684,13 +692,13 @@ def get_qkv(
                 dim=2,
             )
         else:
-            query = self.Wq(hidden_states)
+            query = self.Wq(x)
             if key_value_states is not None:
                 key = self.Wk(key_value_states)
                 value = self.Wv(key_value_states)
             else:
-                key = self.Wk(hidden_states)
-                value = self.Wv(hidden_states)
+                key = self.Wk(x)
+                value = self.Wv(x)
 
             if self.clip_qkv:
                 query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv)
@@ -844,6 +852,7 @@ def __init__(
         bias: bool = True,
         sliding_window_size: int = -1,
         reuse_kv_layer_idx: Optional[int] = None,
+        kv_dim: Optional[int] = None,
     ):
         super().__init__(
             d_model=d_model,
@@ -863,6 +872,7 @@ def __init__(
             bias=bias,
            sliding_window_size=sliding_window_size,
            reuse_kv_layer_idx=reuse_kv_layer_idx,
+            kv_dim=kv_dim,
        )
 
 
@@ -891,6 +901,7 @@ def __init__(
         bias: bool = True,
         sliding_window_size: int = -1,
         reuse_kv_layer_idx: Optional[int] = None,
+        kv_dim: Optional[int] = None,
     ):
         super().__init__(
             d_model=d_model,
@@ -910,6 +921,7 @@ def __init__(
             bias=bias,
             sliding_window_size=sliding_window_size,
             reuse_kv_layer_idx=reuse_kv_layer_idx,
+            kv_dim=kv_dim,
         )
 
 
diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index 759f347e89..c552f3125e 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -95,6 +95,7 @@ def __init__(
                     type (str): Can be one of 'no_scaling', 'linear', or 'dynamic'. 'no_scaling' uses the default implementation for rotary embeddings, 'linear' uses linear scaling as proposed by the Reddit user /u/kaiokendev, and 'dynamic' uses Dynamic NTK scaling as proposed by the Reddit users /u/bloc97 and /u/emozilla.
                     factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type.
             kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
+            kv_dim (Optional[int]): For cross-attention only, allow user to specify different input dimensions for kv projections.
         ffn_config (Dict): A dictionary used to configure the model's ffn module:
             ffn_type (str): type of ffn to use. Options: mptmlp, mptglu, te_ln_mlp
         init_device (str): The device to use for parameter initialization.
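Note on usage: with `kv_dim` plumbed through as above, `Wq` still consumes `d_model`-wide hidden states while `Wk`/`Wv` consume key/value states of width `kv_dim`, which is the shape contract a decoder-to-encoder cross-attention layer needs. The following is a minimal sketch of that call pattern, modeled on the `build_attention_layer` calls in the test patches; the 512-wide encoder tensor, the specific kwarg values, and the `layer_builders` import path are illustrative assumptions and not part of these patches.

```python
# Illustrative sketch (not part of the patch series): a 1024-dim decoder stream
# cross-attends over a 512-dim encoder stream via kv_dim.
import torch

from llmfoundry.models.layers.layer_builders import build_attention_layer  # assumed import path

attn_layer = build_attention_layer(
    name='grouped_query_attention',
    attn_kwargs={
        'd_model': 1024,
        'n_heads': 8,
        'kv_n_heads': 2,
        'attn_impl': 'torch',
        'fused_qkv': False,  # fused_qkv must be False whenever kv_dim is set
        'kv_dim': 512,       # width of the key/value (encoder) states
        'fc_type': {'name': 'torch'},
        'device': 'cpu',
        'attn_pdrop': 0.0,
        'qk_ln': False,
        'qk_gn': False,
        'clip_qkv': None,
        'softmax_scale': None,
        'sliding_window_size': -1,
    },
)

hidden_states = torch.randn(1, 3, 1024)     # queries come from the decoder stream
key_value_states = torch.randn(1, 3, 512)   # keys/values come from the encoder stream

out, _, _ = attn_layer(hidden_states, key_value_states=key_value_states)
print(out.shape)  # torch.Size([1, 3, 1024]); only Wk/Wv change in_features, so output width stays d_model
```

Because `kv_dim` only changes the `in_features` of `Wk` and `Wv`, the attention output keeps the `(batch, seq, d_model)` shape and the rest of the block is unaffected.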
diff --git a/llmfoundry/models/utils/config_defaults.py b/llmfoundry/models/utils/config_defaults.py
index 8a15f0d81a..bd3b29a479 100644
--- a/llmfoundry/models/utils/config_defaults.py
+++ b/llmfoundry/models/utils/config_defaults.py
@@ -32,6 +32,7 @@
         'type': 'no_scaling',
         'factor': 1.0,
     },
+    'kv_dim': None,
 }
 
 init_config_defaults: dict = {
diff --git a/tests/models/layers/test_attention.py b/tests/models/layers/test_attention.py
index 37ef722edc..63ecb17d78 100644
--- a/tests/models/layers/test_attention.py
+++ b/tests/models/layers/test_attention.py
@@ -5,6 +5,7 @@
 
 import pytest
 import torch
+from composer.utils import reproducibility
 
 from llmfoundry.models.layers.attention import (
     attention_implementations,
@@ -309,18 +310,82 @@ def test_cross_attn_as_self_attn(attn_name: str, dim: int):
     else:
         raise ValueError(f'Unknown attention name: {attn_name}')
 
-    attn_config_unfused = generic_attn_kwargs.copy()
-    attn_config_unfused['fused_qkv'] = False
+    attn_config = generic_attn_kwargs.copy()
+    attn_config['fused_qkv'] = False
 
     attn_layer = build_attention_layer(
         name=attn_name,
-        attn_kwargs=attn_config_unfused,
+        attn_kwargs=attn_config,
     )
 
     x1 = torch.randn(1, 1, dim)
     x2 = x1.detach().clone()
 
-    out_fused, _, _ = attn_layer(hidden_states=x1)
-    out_unfused, _, _ = attn_layer(hidden_states=x1, key_value_states=x2)
+    out_fused, _, _ = attn_layer(x1)
+    out_unfused, _, _ = attn_layer(x1, key_value_states=x2)
+
+    assert torch.allclose(out_fused, out_unfused)
+
+
+@pytest.mark.parametrize(
+    'attn_name',
+    ['multihead_attention', 'grouped_query_attention', 'multiquery_attention'],
+)
+@pytest.mark.parametrize('dim', [1024])
+def test_cross_attn_kv_dim(attn_name: str, dim: int):
+    d_head = 128
+    n_heads = dim // d_head
+
+    generic_attn_kwargs = {
+        'd_model': dim,
+        'n_heads': n_heads,
+        'fc_type': {
+            'name': 'torch',
+        },
+        'device': 'cpu',
+        'attn_pdrop': 0.0,
+        'attn_impl': 'torch',
+        'qk_ln': False,
+        'qk_gn': False,
+        'clip_qkv': None,
+        'softmax_scale': None,
+        'sliding_window_size': -1,
+    }
+
+    if attn_name == 'grouped_query_attention':
+        kv_n_heads = 2
+        generic_attn_kwargs['kv_n_heads'] = kv_n_heads
+    elif attn_name == 'multiquery_attention':
+        kv_n_heads = 1
+    elif attn_name == 'multihead_attention':
+        kv_n_heads = n_heads
+    else:
+        raise ValueError(f'Unknown attention name: {attn_name}')
+
+    # layer with only dim passed in
+    attn_config = generic_attn_kwargs.copy()
+    attn_config['fused_qkv'] = False
+
+    reproducibility.seed_all(42)
+    attn_layer_no_kv = build_attention_layer(
+        name=attn_name,
+        attn_kwargs=attn_config,
+    )
+
+    # layer with kv_dim = dim passed in
+    attn_config = generic_attn_kwargs.copy()
+    attn_config['fused_qkv'] = False
+    attn_config['kv_dim'] = dim
+
+    reproducibility.seed_all(42)
+    attn_layer_kv = build_attention_layer(
+        name=attn_name,
+        attn_kwargs=attn_config,
+    )
+
+    x1 = torch.randn(1, 1, dim)
+
+    out_fused, _, _ = attn_layer_no_kv(x1)
+    out_unfused, _, _ = attn_layer_kv(x1)
 
     assert torch.allclose(out_fused, out_unfused)

From 39720bed6f8a0fc7b514c72ce8898697ac029cbd Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 29 Aug 2024 18:32:33 +0000
Subject: [PATCH 5/6] move assert to config checks

---
 llmfoundry/models/layers/attention.py      | 7 +------
 llmfoundry/models/mpt/configuration_mpt.py | 5 +++++
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 9ab561f80d..8cb618af7f 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -464,12 +464,7 @@ def __init__(
         self.sliding_window_size = sliding_window_size
         self.reuse_kv_layer_idx = reuse_kv_layer_idx
 
-        if kv_dim is not None:
-            self.kv_dim = kv_dim
-            assert fused_qkv is False, 'Cannot use separate kv_dim from d_model for fused_qkv'
-        else:
-            self.kv_dim = self.d_model
-
+        self.kv_dim = kv_dim if kv_dim is not None else self.d_model
         self.head_dim = d_model // n_heads
 
         # Usually, fc_type dict should be passed in through MPTBlock's __init__ function.
diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index c552f3125e..6c7e415ac6 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -336,6 +336,11 @@ def _validate_config(self) -> None:
             raise NotImplementedError(
                 'sliding window attention only implemented for torch attention and flash attention (v2.3.0 or higher).',
             )
+        if self.attn_config['kv_dim'] is not None and self.attn_config[
+            'fused_qkv']:
+            raise ValueError(
+                'fused_qkv should be False when "kv_dim" is specified.',
+            )
         if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
             raise ValueError(
                 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!',

From bba0240d51ee984e1210eac0c9a6b871785b4c68 Mon Sep 17 00:00:00 2001
From: root
Date: Fri, 30 Aug 2024 04:18:23 +0000
Subject: [PATCH 6/6] change from assert to error

---
 llmfoundry/models/layers/attention.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 8cb618af7f..a1af2235cf 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -672,7 +672,10 @@ def get_qkv(
             return query, key, value
 
         if self.fused_qkv:
-            assert key_value_states is None, 'Cannot use separate hidden and key_value states for fused_qkv'
+            if key_value_states is not None:
+                raise ValueError(
+                    'Cannot use separate hidden and key_value states when fused_qkv = True.',
+                )
             qkv = self.Wqkv(x)
 
             if self.clip_qkv:
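Taken together, the last two patches replace the original asserts with a configuration-time check in `_validate_config` and a call-time `ValueError` in `get_qkv`. Below is a minimal sketch of the two failure modes those checks guard against; it assumes the patched modules above, an assumed `layer_builders` import path for `build_attention_layer`, and that `MPTConfig` fills unspecified `attn_config` keys from `attn_config_defaults` before `_validate_config` runs.

```python
# Illustrative sketch: exercising the ValueError paths added in patches 5 and 6.
import torch

from llmfoundry.models.layers.layer_builders import build_attention_layer  # assumed import path
from llmfoundry.models.mpt.configuration_mpt import MPTConfig

# Call-time check (patch 6): a fused-QKV layer cannot take separate key/value states.
fused_attn = build_attention_layer(
    name='multihead_attention',
    attn_kwargs={
        'd_model': 1024,
        'n_heads': 8,
        'attn_impl': 'torch',
        'fused_qkv': True,
        'fc_type': {'name': 'torch'},
        'device': 'cpu',
    },
)
x = torch.randn(1, 1, 1024)
try:
    fused_attn(x, key_value_states=x)
except ValueError as err:
    print(err)  # Cannot use separate hidden and key_value states when fused_qkv = True.

# Config-time check (patch 5): kv_dim requires fused_qkv=False.
try:
    MPTConfig(attn_config={'attn_impl': 'torch', 'kv_dim': 512, 'fused_qkv': True})
except ValueError as err:
    print(err)  # fused_qkv should be False when "kv_dim" is specified.
```

Raising `ValueError` instead of using `assert` keeps both guards active under `python -O`, where assert statements are stripped.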