diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index acf231558e..a1af2235cf 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -424,8 +424,9 @@ class GroupedQueryAttention(nn.Module):
     and Multi-query attention (MQA).
 
     This allows the user to set a variable of number of kv_n_heads, rather than
-    just n_heads or 1, as in MHA and MQA. Using torch attention
-    implementation enables user to also use additive bias.
+    just n_heads or 1, as in MHA and MQA. Using the torch attention implementation
+    enables the user to also use an additive bias. This class also supports
+    cross-attention with different `in_features` for the key and value fc projections.
     """
 
     def __init__(
@@ -447,6 +448,7 @@ def __init__(
         bias: bool = True,
         sliding_window_size: int = -1,
         reuse_kv_layer_idx: Optional[int] = None,
+        kv_dim: Optional[int] = None,
     ):
         super().__init__()
 
@@ -462,6 +464,7 @@ def __init__(
 
         self.sliding_window_size = sliding_window_size
         self.reuse_kv_layer_idx = reuse_kv_layer_idx
+        self.kv_dim = kv_dim if kv_dim is not None else self.d_model
         self.head_dim = d_model // n_heads
 
         # Usually, fc_type dict should be passed in through MPTBlock's __init__ function.
@@ -523,13 +526,13 @@ def __init__(
             )
             self.Wk = build_fc(
                 name=fc_type_name,
-                in_features=self.d_model,
+                in_features=self.kv_dim,
                 out_features=self.kv_n_heads * self.head_dim,
                 fc_kwargs=fc_type,
             )
             self.Wv = build_fc(
                 name=fc_type_name,
-                in_features=self.d_model,
+                in_features=self.kv_dim,
                 out_features=self.kv_n_heads * self.head_dim,
                 fc_kwargs=fc_type,
             )
@@ -583,12 +586,17 @@ def forward(
         flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
         prev_layer_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        key_value_states: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
         torch.Tensor, torch.Tensor]]]:
         extra_kwargs = {}
         if prev_layer_key_value is not None:
             extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
 
-        query, key, value = self.get_qkv(x, **extra_kwargs)
+        query, key, value = self.get_qkv(
+            x=x,
+            key_value_states=key_value_states,
+            **extra_kwargs,
+        )
 
         if rotary_emb_w_meta_info is not None:
             query, key, value = self._apply_rotary_embeddings(
@@ -628,12 +636,14 @@ def get_qkv(
         x: torch.Tensor,
         prev_layer_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
+        key_value_states: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Computes and returns the query, key, and value tensors.
 
         Args:
-            x (torch.Tensor): The input tensor.
+            x (torch.Tensor): The input query tensor.
             prev_layer_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]):
                 The key value of the previous layer.
+            key_value_states (Optional[torch.Tensor]): The input tensor for keys and values.
 
         Returns:
             query (torch.Tensor): The query tensor.
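As a reference point, a minimal standalone sketch of what the unfused projection path above computes when `key_value_states` is provided: queries are projected from `x` (width `d_model`), while keys and values are projected from the cross-attention input (width `kv_dim`). The sizes and tensor names below are illustrative assumptions, not part of the patch.

    import torch
    import torch.nn as nn

    # Assumed sizes: d_model for the query stream, kv_dim for the context
    # stream that feeds the key and value projections.
    d_model, kv_dim, n_heads, kv_n_heads, head_dim = 1024, 512, 8, 2, 128

    Wq = nn.Linear(d_model, n_heads * head_dim)    # in_features stays d_model
    Wk = nn.Linear(kv_dim, kv_n_heads * head_dim)  # in_features follows kv_dim
    Wv = nn.Linear(kv_dim, kv_n_heads * head_dim)

    x = torch.randn(2, 16, d_model)                # query-side hidden states
    key_value_states = torch.randn(2, 16, kv_dim)  # cross-attention context

    query = Wq(x)                # shape: (2, 16, n_heads * head_dim)
    key = Wk(key_value_states)   # shape: (2, 16, kv_n_heads * head_dim)
    value = Wv(key_value_states) # shape: (2, 16, kv_n_heads * head_dim)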
@@ -662,6 +672,10 @@ def get_qkv(
             return query, key, value
 
         if self.fused_qkv:
+            if key_value_states is not None:
+                raise ValueError(
+                    'Cannot use separate hidden and key_value states when fused_qkv = True.',
+                )
             qkv = self.Wqkv(x)
 
             if self.clip_qkv:
@@ -677,8 +691,12 @@ def get_qkv(
             )
         else:
             query = self.Wq(x)
-            key = self.Wk(x)
-            value = self.Wv(x)
+            if key_value_states is not None:
+                key = self.Wk(key_value_states)
+                value = self.Wv(key_value_states)
+            else:
+                key = self.Wk(x)
+                value = self.Wv(x)
 
             if self.clip_qkv:
                 query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv)
@@ -832,6 +850,7 @@ def __init__(
         bias: bool = True,
         sliding_window_size: int = -1,
         reuse_kv_layer_idx: Optional[int] = None,
+        kv_dim: Optional[int] = None,
     ):
         super().__init__(
             d_model=d_model,
@@ -851,6 +870,7 @@ def __init__(
             bias=bias,
             sliding_window_size=sliding_window_size,
             reuse_kv_layer_idx=reuse_kv_layer_idx,
+            kv_dim=kv_dim,
         )
 
 
@@ -879,6 +899,7 @@ def __init__(
         bias: bool = True,
         sliding_window_size: int = -1,
         reuse_kv_layer_idx: Optional[int] = None,
+        kv_dim: Optional[int] = None,
     ):
         super().__init__(
             d_model=d_model,
@@ -898,6 +919,7 @@ def __init__(
             bias=bias,
             sliding_window_size=sliding_window_size,
             reuse_kv_layer_idx=reuse_kv_layer_idx,
+            kv_dim=kv_dim,
         )
 
 
diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index 759f347e89..6c7e415ac6 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -95,6 +95,7 @@ def __init__(
                     type (str): Can be one of 'no_scaling', 'linear', or 'dynamic'. 'no_scaling' uses the default implementation for rotary embeddings, 'linear' uses linear scaling as proposed by the Reddit user /u/kaiokendev, and 'dynamic' uses Dynamic NTK scaling as proposed by the Reddit users /u/bloc97 and /u/emozilla.
                     factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type.
                 kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify a number of kv heads.
+                kv_dim (Optional[int]): For cross-attention only, allow user to specify a different input dimension for the kv projections.
             ffn_config (Dict): A dictionary used to configure the model's ffn module:
                 ffn_type (str): type of ffn to use. Options: mptmlp, mptglu, te_ln_mlp
             init_device (str): The device to use for parameter initialization.
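To make the new option concrete, here is a hedged sketch of an `attn_config` that enables cross-attention projections with a narrower key/value input. The key names come from the attn_config options documented above; the values are illustrative assumptions only. Note that `fused_qkv` must be disabled, which the validation added below enforces.

    # Illustrative only: override kv_dim so Wk/Wv expect 512-dim key/value inputs.
    attn_config = {
        'attn_type': 'grouped_query_attention',
        'attn_impl': 'torch',
        'kv_n_heads': 2,
        'fused_qkv': False,  # kv_dim is rejected when fused_qkv is True
        'kv_dim': 512,
    }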
@@ -335,6 +336,11 @@ def _validate_config(self) -> None:
             raise NotImplementedError(
                 'sliding window attention only implemented for torch attention and flash attention (v2.3.0 or higher).',
             )
+        if self.attn_config['kv_dim'] is not None and self.attn_config[
+            'fused_qkv']:
+            raise ValueError(
+                'fused_qkv should be False when "kv_dim" is specified.',
+            )
         if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
             raise ValueError(
                 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!',
diff --git a/llmfoundry/models/utils/config_defaults.py b/llmfoundry/models/utils/config_defaults.py
index 8a15f0d81a..bd3b29a479 100644
--- a/llmfoundry/models/utils/config_defaults.py
+++ b/llmfoundry/models/utils/config_defaults.py
@@ -32,6 +32,7 @@
         'type': 'no_scaling',
         'factor': 1.0,
     },
+    'kv_dim': None,
 }
 
 init_config_defaults: dict = {
diff --git a/tests/models/layers/test_attention.py b/tests/models/layers/test_attention.py
index c51a532092..63ecb17d78 100644
--- a/tests/models/layers/test_attention.py
+++ b/tests/models/layers/test_attention.py
@@ -5,6 +5,7 @@
 
 import pytest
 import torch
+from composer.utils import reproducibility
 
 from llmfoundry.models.layers.attention import (
     attention_implementations,
@@ -271,3 +272,119 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str):
 
 def _assert_approx_equal(value1: torch.Tensor, value2: torch.Tensor):
     assert torch.norm(value2 - value1) <= 1e-2 + 1e-2 * torch.norm(value2)
+
+
+@pytest.mark.parametrize(
+    'attn_name',
+    ['multihead_attention', 'grouped_query_attention', 'multiquery_attention'],
+)
+@pytest.mark.parametrize('dim', [1024])
+def test_cross_attn_as_self_attn(attn_name: str, dim: int):
+    d_head = 128
+    n_heads = dim // d_head
+
+    generic_attn_kwargs = {
+        'd_model': dim,
+        'n_heads': n_heads,
+        'fc_type': {
+            'name': 'torch',
+        },
+        'device': 'cpu',
+        'attn_pdrop': 0.0,
+        'attn_impl': 'torch',
+        'qk_ln': False,
+        'qk_gn': False,
+        'clip_qkv': None,
+        'softmax_scale': None,
+        'sliding_window_size': -1,
+    }
+
+    if attn_name == 'grouped_query_attention':
+        kv_n_heads = 2
+        generic_attn_kwargs['kv_n_heads'] = kv_n_heads
+    elif attn_name == 'multiquery_attention':
+        kv_n_heads = 1
+    elif attn_name == 'multihead_attention':
+        kv_n_heads = n_heads
+    else:
+        raise ValueError(f'Unknown attention name: {attn_name}')
+
+    attn_config = generic_attn_kwargs.copy()
+    attn_config['fused_qkv'] = False
+
+    attn_layer = build_attention_layer(
+        name=attn_name,
+        attn_kwargs=attn_config,
+    )
+
+    x1 = torch.randn(1, 1, dim)
+    x2 = x1.detach().clone()
+
+    out_self_attn, _, _ = attn_layer(x1)
+    out_cross_attn, _, _ = attn_layer(x1, key_value_states=x2)
+
+    assert torch.allclose(out_self_attn, out_cross_attn)
+
+
+@pytest.mark.parametrize(
+    'attn_name',
+    ['multihead_attention', 'grouped_query_attention', 'multiquery_attention'],
+)
+@pytest.mark.parametrize('dim', [1024])
+def test_cross_attn_kv_dim(attn_name: str, dim: int):
+    d_head = 128
+    n_heads = dim // d_head
+
+    generic_attn_kwargs = {
+        'd_model': dim,
+        'n_heads': n_heads,
+        'fc_type': {
+            'name': 'torch',
+        },
+        'device': 'cpu',
+        'attn_pdrop': 0.0,
+        'attn_impl': 'torch',
+        'qk_ln': False,
+        'qk_gn': False,
+        'clip_qkv': None,
+        'softmax_scale': None,
+        'sliding_window_size': -1,
+    }
+
+    if attn_name == 'grouped_query_attention':
+        kv_n_heads = 2
+        generic_attn_kwargs['kv_n_heads'] = kv_n_heads
+    elif attn_name == 'multiquery_attention':
+        kv_n_heads = 1
+    elif attn_name == 'multihead_attention':
+        kv_n_heads = n_heads
+    else:
+        raise ValueError(f'Unknown attention name: {attn_name}')
+
+    # layer with only dim passed in
+    attn_config = generic_attn_kwargs.copy()
+    attn_config['fused_qkv'] = False
+
+    reproducibility.seed_all(42)
+    attn_layer_no_kv = build_attention_layer(
+        name=attn_name,
+        attn_kwargs=attn_config,
+    )
+
+    # layer with kv_dim = dim passed in
+    attn_config = generic_attn_kwargs.copy()
+    attn_config['fused_qkv'] = False
+    attn_config['kv_dim'] = dim
+
+    reproducibility.seed_all(42)
+    attn_layer_kv = build_attention_layer(
+        name=attn_name,
+        attn_kwargs=attn_config,
+    )
+
+    x1 = torch.randn(1, 1, dim)
+
+    out_no_kv_dim, _, _ = attn_layer_no_kv(x1)
+    out_kv_dim, _, _ = attn_layer_kv(x1)
+
+    assert torch.allclose(out_no_kv_dim, out_kv_dim)
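Finally, a usage sketch that mirrors the tests above but exercises a key/value stream narrower than the query stream. The import path, layer sizes, and tensor shapes are assumptions for illustration rather than part of the patch.

    import torch

    # Assumed import path for the same helper the tests above use.
    from llmfoundry.models.layers.layer_builders import build_attention_layer

    d_model, kv_dim = 1024, 512

    attn_layer = build_attention_layer(
        name='grouped_query_attention',
        attn_kwargs={
            'd_model': d_model,
            'n_heads': 8,
            'kv_n_heads': 2,
            'fc_type': {'name': 'torch'},
            'device': 'cpu',
            'attn_pdrop': 0.0,
            'attn_impl': 'torch',
            'qk_ln': False,
            'qk_gn': False,
            'clip_qkv': None,
            'softmax_scale': None,
            'sliding_window_size': -1,
            'fused_qkv': False,  # required when kv_dim is set
            'kv_dim': kv_dim,
        },
    )

    x = torch.randn(1, 4, d_model)       # query-side hidden states
    context = torch.randn(1, 4, kv_dim)  # key/value states from another stream

    out, _, _ = attn_layer(x, key_value_states=context)
    assert out.shape == (1, 4, d_model)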