Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable cross attention layers #1495

Merged
merged 6 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions llmfoundry/models/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,8 @@ def __init__(

def forward(
self,
x: torch.Tensor,
hidden_states: torch.Tensor,
gupta-abhay marked this conversation as resolved.
Show resolved Hide resolved
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attn_bias: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
Expand All @@ -588,7 +589,11 @@ def forward(
extra_kwargs = {}
if prev_layer_key_value is not None:
extra_kwargs['prev_layer_key_value'] = prev_layer_key_value
query, key, value = self.get_qkv(x, **extra_kwargs)
query, key, value = self.get_qkv(
hidden_states,
key_value_states,
**extra_kwargs,
)

if rotary_emb_w_meta_info is not None:
query, key, value = self._apply_rotary_embeddings(
Expand Down Expand Up @@ -625,14 +630,16 @@ def forward(

def get_qkv(
self,
x: torch.Tensor,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
prev_layer_key_value: Optional[tuple[torch.Tensor,
torch.Tensor]] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Computes and returns the query, key, and value tensors.

Args:
x (torch.Tensor): The input tensor.
hidden_states (torch.Tensor): The input query tensor.
key_value_states (Optional[torch.Tensor]): The input tensor for keys and values.
prev_layer_key_value (Optional[Tuple[torch.Tensor, torch.Tensor]]): The key value of the previous layer.

Returns:
Expand All @@ -647,7 +654,7 @@ def get_qkv(
)
key, value = prev_layer_key_value

query = self.Wq(x)
query = self.Wq(hidden_states)
if self.clip_qkv:
query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv)

Expand All @@ -662,7 +669,8 @@ def get_qkv(
return query, key, value

if self.fused_qkv:
qkv = self.Wqkv(x)
assert key_value_states is None, 'Cannot use separate hidden and key_value states for fused_qkv'
gupta-abhay marked this conversation as resolved.
Show resolved Hide resolved
qkv = self.Wqkv(hidden_states)

if self.clip_qkv:
qkv = qkv.clamp(min=-self.clip_qkv, max=self.clip_qkv)
Expand All @@ -676,9 +684,13 @@ def get_qkv(
dim=2,
)
else:
query = self.Wq(x)
key = self.Wk(x)
value = self.Wv(x)
query = self.Wq(hidden_states)
if key_value_states is not None:
key = self.Wk(key_value_states)
value = self.Wv(key_value_states)
else:
key = self.Wk(hidden_states)
value = self.Wv(hidden_states)

if self.clip_qkv:
query = query.clamp(min=-self.clip_qkv, max=self.clip_qkv)
Expand Down
52 changes: 52 additions & 0 deletions tests/models/layers/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,3 +271,55 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str):

def _assert_approx_equal(value1: torch.Tensor, value2: torch.Tensor):
assert torch.norm(value2 - value1) <= 1e-2 + 1e-2 * torch.norm(value2)


@pytest.mark.parametrize(
'attn_name',
['multihead_attention', 'grouped_query_attention', 'multiquery_attention'],
)
@pytest.mark.parametrize('dim', [1024])
def test_cross_attn_as_self_attn(attn_name: str, dim: int):
d_head = 128
n_heads = dim // d_head

generic_attn_kwargs = {
'd_model': dim,
'n_heads': n_heads,
'fc_type': {
'name': 'torch',
},
'device': 'cpu',
'attn_pdrop': 0.0,
'attn_impl': 'torch',
'qk_ln': False,
'qk_gn': False,
'clip_qkv': None,
'softmax_scale': None,
'sliding_window_size': -1,
}

if attn_name == 'grouped_query_attention':
kv_n_heads = 2
generic_attn_kwargs['kv_n_heads'] = kv_n_heads
elif attn_name == 'multiquery_attention':
kv_n_heads = 1
elif attn_name == 'multihead_attention':
kv_n_heads = n_heads
else:
raise ValueError(f'Unknown attention name: {attn_name}')

attn_config_unfused = generic_attn_kwargs.copy()
attn_config_unfused['fused_qkv'] = False

attn_layer = build_attention_layer(
name=attn_name,
attn_kwargs=attn_config_unfused,
)

x1 = torch.randn(1, 1, dim)
x2 = x1.detach().clone()

out_fused, _, _ = attn_layer(hidden_states=x1)
out_unfused, _, _ = attn_layer(hidden_states=x1, key_value_states=x2)

assert torch.allclose(out_fused, out_unfused)
Loading