
Commit

Merge branch 'main' into uv
dakinggg authored Aug 30, 2024
2 parents 2ff4b39 + d0dc82d commit b7e3039
Showing 4 changed files with 154 additions and 8 deletions.
38 changes: 30 additions & 8 deletions llmfoundry/models/layers/attention.py
@@ -424,8 +424,9 @@ class GroupedQueryAttention(nn.Module):
    and Multi-query attention (MQA).
    This allows the user to set a variable of number of kv_n_heads, rather than
-   just n_heads or 1, as in MHA and MQA. Using torch attention
-   implementation enables user to also use additive bias.
+   just n_heads or 1, as in MHA and MQA. Using torch attention implementation
+   enables user to also use additive bias. This class also supports
+   cross-attention with different `in_features` for key and value fc projections.
    """

def __init__(
Expand All @@ -447,6 +448,7 @@ def __init__(
bias: bool = True,
sliding_window_size: int = -1,
reuse_kv_layer_idx: Optional[int] = None,
kv_dim: Optional[int] = None,
):
super().__init__()

Expand All @@ -462,6 +464,7 @@ def __init__(
self.sliding_window_size = sliding_window_size
self.reuse_kv_layer_idx = reuse_kv_layer_idx

self.kv_dim = kv_dim if kv_dim is not None else self.d_model
self.head_dim = d_model // n_heads

# Usually, fc_type dict should be passed in through MPTBlock's __init__ function.
@@ -523,13 +526,13 @@ def __init__(
)
self.Wk = build_fc(
name=fc_type_name,
-    in_features=self.d_model,
+    in_features=self.kv_dim,
out_features=self.kv_n_heads * self.head_dim,
fc_kwargs=fc_type,
)
self.Wv = build_fc(
name=fc_type_name,
-    in_features=self.d_model,
+    in_features=self.kv_dim,
out_features=self.kv_n_heads * self.head_dim,
fc_kwargs=fc_type,
)
@@ -583,12 +586,17 @@ def forward(
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
prev_layer_key_value: Optional[tuple[torch.Tensor,
torch.Tensor]] = None,
key_value_states: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
torch.Tensor, torch.Tensor]]]:
extra_kwargs = {}
if prev_layer_key_value is not None:
extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
-        query, key, value = self.get_qkv(x, **extra_kwargs)
+        query, key, value = self.get_qkv(
+            x=x,
+            key_value_states=key_value_states,
+            **extra_kwargs,
+        )

if rotary_emb_w_meta_info is not None:
query, key, value = self._apply_rotary_embeddings(
@@ -628,12 +636,14 @@ def get_qkv(
x: torch.Tensor,
prev_layer_key_value: Optional[tuple[torch.Tensor,
torch.Tensor]] = None,
key_value_states: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Computes and returns the query, key, and value tensors.
Args:
-        x (torch.Tensor): The input tensor.
+        x (torch.Tensor): The input query tensor.
prev_layer_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): The key value of the previous layer.
key_value_states (Optional[torch.Tensor]): The input tensor for keys and values.
Returns:
query (torch.Tensor): The query tensor.
@@ -662,6 +672,10 @@ def get_qkv(
return query, key, value

if self.fused_qkv:
if key_value_states is not None:
raise ValueError(
'Cannot use separate hidden and key_value states when fused_qkv = True.',
)
qkv = self.Wqkv(x)

if self.clip_qkv:
@@ -677,8 +691,12 @@
)
else:
query = self.Wq(x)
-            key = self.Wk(x)
-            value = self.Wv(x)
+            if key_value_states is not None:
+                key = self.Wk(key_value_states)
+                value = self.Wv(key_value_states)
+            else:
+                key = self.Wk(x)
+                value = self.Wv(x)

if self.clip_qkv:
query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv)
@@ -832,6 +850,7 @@ def __init__(
bias: bool = True,
sliding_window_size: int = -1,
reuse_kv_layer_idx: Optional[int] = None,
kv_dim: Optional[int] = None,
):
super().__init__(
d_model=d_model,
@@ -851,6 +870,7 @@ def __init__(
bias=bias,
sliding_window_size=sliding_window_size,
reuse_kv_layer_idx=reuse_kv_layer_idx,
kv_dim=kv_dim,
)


@@ -879,6 +899,7 @@ def __init__(
bias: bool = True,
sliding_window_size: int = -1,
reuse_kv_layer_idx: Optional[int] = None,
kv_dim: Optional[int] = None,
):
super().__init__(
d_model=d_model,
@@ -898,6 +919,7 @@ def __init__(
bias=bias,
sliding_window_size=sliding_window_size,
reuse_kv_layer_idx=reuse_kv_layer_idx,
kv_dim=kv_dim,
)


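Taken together, the attention.py changes let a layer project keys and values from a second input stream whose feature size (kv_dim) differs from d_model, while queries still come from x. The following is a minimal usage sketch, not part of this commit: it mirrors the new tests below, assumes build_attention_layer is importable from llmfoundry.models.layers.layer_builders, and the 512-dim "encoder" tensor is a made-up example.

import torch

from llmfoundry.models.layers.layer_builders import build_attention_layer

attn_layer = build_attention_layer(
    name='grouped_query_attention',
    attn_kwargs={
        'd_model': 1024,
        'n_heads': 8,
        'kv_n_heads': 2,
        'fc_type': {'name': 'torch'},
        'device': 'cpu',
        'attn_pdrop': 0.0,
        'attn_impl': 'torch',
        'qk_ln': False,
        'qk_gn': False,
        'clip_qkv': None,
        'softmax_scale': None,
        'sliding_window_size': -1,
        'fused_qkv': False,  # kv_dim only works with unfused q/k/v projections
        'kv_dim': 512,       # Wk and Wv now expect 512 input features
    },
)

hidden = torch.randn(1, 8, 1024)         # query-side states, d_model features
encoder_states = torch.randn(1, 8, 512)  # key/value-side states, kv_dim features

# key_value_states routes the second tensor through Wk/Wv; Wq still sees `hidden`.
context, _, _ = attn_layer(hidden, key_value_states=encoder_states)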
6 changes: 6 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -95,6 +95,7 @@ def __init__(
type (str): Can be one of 'no_scaling', 'linear', or 'dynamic'. 'no_scaling' uses the default implementation for rotary embeddings, 'linear' uses linear scaling as proposed by the Reddit user /u/kaiokendev, and 'dynamic' uses Dynamic NTK scaling as proposed by the Reddit users /u/bloc97 and /u/emozilla.
factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type.
kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
kv_dim (Optional[int]): For cross-attention only, allow user to specify different input dimensions for kv projections.
ffn_config (Dict): A dictionary used to configure the model's ffn module:
ffn_type (str): type of ffn to use. Options: mptmlp, mptglu, te_ln_mlp
init_device (str): The device to use for parameter initialization.
@@ -324,6 +325,11 @@ def _validate_config(self) -> None:
raise NotImplementedError(
'sliding window attention only implemented for torch attention and flash attention (v2.3.0 or higher).',
)
if self.attn_config['kv_dim'] is not None and self.attn_config[
'fused_qkv']:
raise ValueError(
'fused_qkv should be False when "kv_dim" is specified.',
)
if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
raise ValueError(
'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!',
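At the config level the new knob is attn_config['kv_dim'], and the added validation requires fused_qkv to be disabled whenever it is set. A hypothetical sketch of a config that passes the check, assuming (as the existing config code does) that keys omitted from attn_config are filled from the defaults below:

from llmfoundry.models.mpt.configuration_mpt import MPTConfig

cfg = MPTConfig(
    d_model=1024,
    n_heads=8,
    attn_config={
        'attn_type': 'grouped_query_attention',
        'attn_impl': 'torch',
        'kv_n_heads': 2,
        'fused_qkv': False,  # leaving this True with kv_dim set raises ValueError
        'kv_dim': 512,       # keys/values are projected from 512-dim inputs
    },
)

If fused_qkv were left True, the same config would hit the new 'fused_qkv should be False when "kv_dim" is specified.' error.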
1 change: 1 addition & 0 deletions llmfoundry/models/utils/config_defaults.py
@@ -32,6 +32,7 @@
'type': 'no_scaling',
'factor': 1.0,
},
'kv_dim': None,
}

init_config_defaults: dict = {
117 changes: 117 additions & 0 deletions tests/models/layers/test_attention.py
@@ -5,6 +5,7 @@

import pytest
import torch
from composer.utils import reproducibility

from llmfoundry.models.layers.attention import (
attention_implementations,
@@ -271,3 +272,119 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str):

def _assert_approx_equal(value1: torch.Tensor, value2: torch.Tensor):
assert torch.norm(value2 - value1) <= 1e-2 + 1e-2 * torch.norm(value2)


@pytest.mark.parametrize(
'attn_name',
['multihead_attention', 'grouped_query_attention', 'multiquery_attention'],
)
@pytest.mark.parametrize('dim', [1024])
def test_cross_attn_as_self_attn(attn_name: str, dim: int):
d_head = 128
n_heads = dim // d_head

generic_attn_kwargs = {
'd_model': dim,
'n_heads': n_heads,
'fc_type': {
'name': 'torch',
},
'device': 'cpu',
'attn_pdrop': 0.0,
'attn_impl': 'torch',
'qk_ln': False,
'qk_gn': False,
'clip_qkv': None,
'softmax_scale': None,
'sliding_window_size': -1,
}

if attn_name == 'grouped_query_attention':
kv_n_heads = 2
generic_attn_kwargs['kv_n_heads'] = kv_n_heads
elif attn_name == 'multiquery_attention':
kv_n_heads = 1
elif attn_name == 'multihead_attention':
kv_n_heads = n_heads
else:
raise ValueError(f'Unknown attention name: {attn_name}')

attn_config = generic_attn_kwargs.copy()
attn_config['fused_qkv'] = False

attn_layer = build_attention_layer(
name=attn_name,
attn_kwargs=attn_config,
)

x1 = torch.randn(1, 1, dim)
x2 = x1.detach().clone()

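# Cross-attention over a copy of the same tensor should reproduce plain self-attention.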
out_fused, _, _ = attn_layer(x1)
out_unfused, _, _ = attn_layer(x1, key_value_states=x2)

assert torch.allclose(out_fused, out_unfused)


@pytest.mark.parametrize(
'attn_name',
['multihead_attention', 'grouped_query_attention', 'multiquery_attention'],
)
@pytest.mark.parametrize('dim', [1024])
def test_cross_attn_kv_dim(attn_name: str, dim: int):
d_head = 128
n_heads = dim // d_head

generic_attn_kwargs = {
'd_model': dim,
'n_heads': n_heads,
'fc_type': {
'name': 'torch',
},
'device': 'cpu',
'attn_pdrop': 0.0,
'attn_impl': 'torch',
'qk_ln': False,
'qk_gn': False,
'clip_qkv': None,
'softmax_scale': None,
'sliding_window_size': -1,
}

if attn_name == 'grouped_query_attention':
kv_n_heads = 2
generic_attn_kwargs['kv_n_heads'] = kv_n_heads
elif attn_name == 'multiquery_attention':
kv_n_heads = 1
elif attn_name == 'multihead_attention':
kv_n_heads = n_heads
else:
raise ValueError(f'Unknown attention name: {attn_name}')

# layer with only dim passed in
attn_config = generic_attn_kwargs.copy()
attn_config['fused_qkv'] = False

reproducibility.seed_all(42)
attn_layer_no_kv = build_attention_layer(
name=attn_name,
attn_kwargs=attn_config,
)

# layer with kv_dim = dim passed in
attn_config = generic_attn_kwargs.copy()
attn_config['fused_qkv'] = False
attn_config['kv_dim'] = dim

reproducibility.seed_all(42)
attn_layer_kv = build_attention_layer(
name=attn_name,
attn_kwargs=attn_config,
)

x1 = torch.randn(1, 1, dim)

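# Both layers were seeded identically, so kv_dim == d_model should change nothing.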
out_fused, _, _ = attn_layer_no_kv(x1)
out_unfused, _, _ = attn_layer_kv(x1)

assert torch.allclose(out_fused, out_unfused)

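The two tests above cover the equivalence cases; the new error path in get_qkv (passing key_value_states to a fused-qkv layer) could be exercised by a companion test along these lines. This is a sketch only, not part of the commit, and it assumes the same build_attention_layer helper used by the tests above, imported here from llmfoundry.models.layers.layer_builders:

import pytest
import torch

from llmfoundry.models.layers.layer_builders import build_attention_layer


def test_cross_attn_rejects_fused_qkv():
    # A fused Wqkv projection cannot consume a separate key/value tensor,
    # so get_qkv should raise the new ValueError.
    attn_layer = build_attention_layer(
        name='multihead_attention',
        attn_kwargs={
            'd_model': 256,
            'n_heads': 2,
            'fc_type': {'name': 'torch'},
            'device': 'cpu',
            'attn_pdrop': 0.0,
            'attn_impl': 'torch',
            'qk_ln': False,
            'qk_gn': False,
            'clip_qkv': None,
            'softmax_scale': None,
            'sliding_window_size': -1,
            'fused_qkv': True,
        },
    )
    x = torch.randn(1, 1, 256)
    with pytest.raises(ValueError):
        attn_layer(x, key_value_states=x)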