Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable cross attention layers #1495

Merged
merged 6 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 27 additions & 8 deletions llmfoundry/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,9 @@ class GroupedQueryAttention(nn.Module):
and Multi-query attention (MQA).

This allows the user to set a variable of number of kv_n_heads, rather than
just n_heads or 1, as in MHA and MQA. Using torch attention
implementation enables user to also use additive bias.
just n_heads or 1, as in MHA and MQA. Using torch attention implementation
enables user to also use additive bias. This class also supports
cross-attention with different `in_features` for key and value fc projections.
"""

def __init__(
Expand All @@ -447,6 +448,7 @@ def __init__(
bias: bool = True,
sliding_window_size: int = -1,
reuse_kv_layer_idx: Optional[int] = None,
kv_dim: Optional[int] = None,
):
super().__init__()

Expand All @@ -462,6 +464,7 @@ def __init__(
self.sliding_window_size = sliding_window_size
self.reuse_kv_layer_idx = reuse_kv_layer_idx

self.kv_dim = kv_dim if kv_dim is not None else self.d_model
self.head_dim = d_model // n_heads

# Usually, fc_type dict should be passed in through MPTBlock's __init__ function.
Expand Down Expand Up @@ -523,13 +526,13 @@ def __init__(
)
self.Wk = build_fc(
name=fc_type_name,
in_features=self.d_model,
in_features=self.kv_dim,
out_features=self.kv_n_heads * self.head_dim,
fc_kwargs=fc_type,
)
self.Wv = build_fc(
name=fc_type_name,
in_features=self.d_model,
in_features=self.kv_dim,
out_features=self.kv_n_heads * self.head_dim,
fc_kwargs=fc_type,
)
Expand Down Expand Up @@ -583,12 +586,17 @@ def forward(
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
prev_layer_key_value: Optional[tuple[torch.Tensor,
torch.Tensor]] = None,
key_value_states: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
torch.Tensor, torch.Tensor]]]:
extra_kwargs = {}
if prev_layer_key_value is not None:
extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
query, key, value = self.get_qkv(x, **extra_kwargs)
query, key, value = self.get_qkv(
x=x,
key_value_states=key_value_states,
**extra_kwargs,
)

if rotary_emb_w_meta_info is not None:
query, key, value = self._apply_rotary_embeddings(
Expand Down Expand Up @@ -628,12 +636,14 @@ def get_qkv(
x: torch.Tensor,
prev_layer_key_value: Optional[tuple[torch.Tensor,
torch.Tensor]] = None,
key_value_states: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Computes and returns the query, key, and value tensors.

Args:
x (torch.Tensor): The input tensor.
x (torch.Tensor): The input query tensor.
prev_layer_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): The key value of the previous layer.
key_value_states (Optional[torch.Tensor]): The input tensor for keys and values.

Returns:
query (torch.Tensor): The query tensor.
Expand Down Expand Up @@ -662,6 +672,7 @@ def get_qkv(
return query, key, value

if self.fused_qkv:
assert key_value_states is None, 'Cannot use separate hidden and key_value states for fused_qkv'
gupta-abhay marked this conversation as resolved.
Show resolved Hide resolved
qkv = self.Wqkv(x)

if self.clip_qkv:
Expand All @@ -677,8 +688,12 @@ def get_qkv(
)
else:
query = self.Wq(x)
key = self.Wk(x)
value = self.Wv(x)
if key_value_states is not None:
key = self.Wk(key_value_states)
value = self.Wv(key_value_states)
else:
key = self.Wk(x)
value = self.Wv(x)

if self.clip_qkv:
query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv)
Expand Down Expand Up @@ -832,6 +847,7 @@ def __init__(
bias: bool = True,
sliding_window_size: int = -1,
reuse_kv_layer_idx: Optional[int] = None,
kv_dim: Optional[int] = None,
):
super().__init__(
d_model=d_model,
Expand All @@ -851,6 +867,7 @@ def __init__(
bias=bias,
sliding_window_size=sliding_window_size,
reuse_kv_layer_idx=reuse_kv_layer_idx,
kv_dim=kv_dim,
)


Expand Down Expand Up @@ -879,6 +896,7 @@ def __init__(
bias: bool = True,
sliding_window_size: int = -1,
reuse_kv_layer_idx: Optional[int] = None,
kv_dim: Optional[int] = None,
):
super().__init__(
d_model=d_model,
Expand All @@ -898,6 +916,7 @@ def __init__(
bias=bias,
sliding_window_size=sliding_window_size,
reuse_kv_layer_idx=reuse_kv_layer_idx,
kv_dim=kv_dim,
)


Expand Down
6 changes: 6 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(
type (str): Can be one of 'no_scaling', 'linear', or 'dynamic'. 'no_scaling' uses the default implementation for rotary embeddings, 'linear' uses linear scaling as proposed by the Reddit user /u/kaiokendev, and 'dynamic' uses Dynamic NTK scaling as proposed by the Reddit users /u/bloc97 and /u/emozilla.
factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type.
kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
kv_dim (Optional[int]): For cross-attention only, allow user to specify different input dimensions for kv projections.
ffn_config (Dict): A dictionary used to configure the model's ffn module:
ffn_type (str): type of ffn to use. Options: mptmlp, mptglu, te_ln_mlp
init_device (str): The device to use for parameter initialization.
Expand Down Expand Up @@ -335,6 +336,11 @@ def _validate_config(self) -> None:
raise NotImplementedError(
'sliding window attention only implemented for torch attention and flash attention (v2.3.0 or higher).',
)
if self.attn_config['kv_dim'] is not None and self.attn_config[
'fused_qkv']:
raise ValueError(
'fused_qkv should be False when "kv_dim" is specified.',
)
if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
raise ValueError(
'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!',
Expand Down
1 change: 1 addition & 0 deletions llmfoundry/models/utils/config_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
'type': 'no_scaling',
'factor': 1.0,
},
'kv_dim': None,
}

init_config_defaults: dict = {
Expand Down
117 changes: 117 additions & 0 deletions tests/models/layers/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
import torch
from composer.utils import reproducibility

from llmfoundry.models.layers.attention import (
attention_implementations,
Expand Down Expand Up @@ -271,3 +272,119 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str):

def _assert_approx_equal(value1: torch.Tensor, value2: torch.Tensor):
assert torch.norm(value2 - value1) <= 1e-2 + 1e-2 * torch.norm(value2)


@pytest.mark.parametrize(
'attn_name',
['multihead_attention', 'grouped_query_attention', 'multiquery_attention'],
)
@pytest.mark.parametrize('dim', [1024])
def test_cross_attn_as_self_attn(attn_name: str, dim: int):
d_head = 128
n_heads = dim // d_head

generic_attn_kwargs = {
'd_model': dim,
'n_heads': n_heads,
'fc_type': {
'name': 'torch',
},
'device': 'cpu',
'attn_pdrop': 0.0,
'attn_impl': 'torch',
'qk_ln': False,
'qk_gn': False,
'clip_qkv': None,
'softmax_scale': None,
'sliding_window_size': -1,
}

if attn_name == 'grouped_query_attention':
kv_n_heads = 2
generic_attn_kwargs['kv_n_heads'] = kv_n_heads
elif attn_name == 'multiquery_attention':
kv_n_heads = 1
elif attn_name == 'multihead_attention':
kv_n_heads = n_heads
else:
raise ValueError(f'Unknown attention name: {attn_name}')

attn_config = generic_attn_kwargs.copy()
attn_config['fused_qkv'] = False

attn_layer = build_attention_layer(
name=attn_name,
attn_kwargs=attn_config,
)

x1 = torch.randn(1, 1, dim)
x2 = x1.detach().clone()

out_fused, _, _ = attn_layer(x1)
out_unfused, _, _ = attn_layer(x1, key_value_states=x2)

assert torch.allclose(out_fused, out_unfused)


@pytest.mark.parametrize(
'attn_name',
['multihead_attention', 'grouped_query_attention', 'multiquery_attention'],
)
@pytest.mark.parametrize('dim', [1024])
def test_cross_attn_kv_dim(attn_name: str, dim: int):
d_head = 128
n_heads = dim // d_head

generic_attn_kwargs = {
'd_model': dim,
'n_heads': n_heads,
'fc_type': {
'name': 'torch',
},
'device': 'cpu',
'attn_pdrop': 0.0,
'attn_impl': 'torch',
'qk_ln': False,
'qk_gn': False,
'clip_qkv': None,
'softmax_scale': None,
'sliding_window_size': -1,
}

if attn_name == 'grouped_query_attention':
kv_n_heads = 2
generic_attn_kwargs['kv_n_heads'] = kv_n_heads
elif attn_name == 'multiquery_attention':
kv_n_heads = 1
elif attn_name == 'multihead_attention':
kv_n_heads = n_heads
else:
raise ValueError(f'Unknown attention name: {attn_name}')

# layer with only dim passed in
attn_config = generic_attn_kwargs.copy()
attn_config['fused_qkv'] = False

reproducibility.seed_all(42)
attn_layer_no_kv = build_attention_layer(
name=attn_name,
attn_kwargs=attn_config,
)

# layer with kv_dim = dim passed in
attn_config = generic_attn_kwargs.copy()
attn_config['fused_qkv'] = False
attn_config['kv_dim'] = dim

reproducibility.seed_all(42)
attn_layer_kv = build_attention_layer(
name=attn_name,
attn_kwargs=attn_config,
)

x1 = torch.randn(1, 1, dim)

out_fused, _, _ = attn_layer_no_kv(x1)
out_unfused, _, _ = attn_layer_kv(x1)

assert torch.allclose(out_fused, out_unfused)
Loading