Merge branch 'main' into envlogger
josejg committed Aug 16, 2024
2 parents bc3cd9d + 93b22d8 commit 370ceaf
Showing 5 changed files with 148 additions and 105 deletions.
30 changes: 28 additions & 2 deletions llmfoundry/models/layers/attention.py
@@ -112,6 +112,7 @@ def scaled_multihead_dot_product_attention(
dropout_p: float = 0.0,
training: bool = False,
needs_weights: bool = False,
sliding_window_size: int = -1,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:

@@ -177,7 +178,7 @@ def scaled_multihead_dot_product_attention(
min_val,
)

if is_causal and (not q.size(2) == 1):
if is_causal and (not s_q == 1):
s = max(s_q, s_k)
causal_mask = attn_weight.new_ones(s, s, dtype=torch.float32)
causal_mask = causal_mask.tril()
@@ -189,6 +190,31 @@ def scaled_multihead_dot_product_attention(
min_val,
)

if sliding_window_size != -1:
window_mask = torch.ones((s_q, s_k),
dtype=torch.bool,
device=attn_weight.device)
if (not s_q == 1):
if s_q != s_k:
raise ValueError(
'Number of queries should be equal to the number of keys.',
)
window_mask = torch.tril(
window_mask,
diagonal=sliding_window_size,
)
window_mask = torch.triu(
window_mask,
diagonal=-sliding_window_size,
)
else:
window_mask[:, :-(sliding_window_size + 1)] = False
window_mask = ~window_mask
attn_weight = attn_weight.masked_fill(
window_mask.view(1, 1, s_q, s_k),
min_val,
)

attn_weight = torch.softmax(attn_weight, dim=-1)

if dropout_p:
@@ -591,6 +617,7 @@ def forward(
dropout_p=self.attn_dropout_p,
training=self.training,
needs_weights=needs_weights,
sliding_window_size=self.sliding_window_size,
**extra_attn_kwargs,
)

@@ -771,7 +798,6 @@ def get_implementation_specific_args(
if self.attn_impl == 'flash':
extra_attn_kwargs = {
'should_repeat_kv_for_gqa': not is_flash_v2_installed(),
'sliding_window_size': self.sliding_window_size,
'alibi_slopes': alibi_slopes,
'flash_attn_padding_info': flash_attn_padding_info,
'key_padding_mask': None,
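For reference, a minimal standalone sketch (not part of the commit) of the window mask that the torch attention path above constructs. True marks keys within sliding_window_size positions of the query; the code above then inverts the mask and fills the complement with the dtype minimum before the softmax.

import torch

# Illustrative values only: the boolean window mask for s_q == s_k == 5 and
# sliding_window_size == 2, built the same way as in the torch path above.
s_q = s_k = 5
sliding_window_size = 2

window_mask = torch.ones((s_q, s_k), dtype=torch.bool)
window_mask = torch.tril(window_mask, diagonal=sliding_window_size)
window_mask = torch.triu(window_mask, diagonal=-sliding_window_size)
print(window_mask.long())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1],
#         [0, 1, 1, 1, 1],
#         [0, 0, 1, 1, 1]])

When is_causal is set (and s_q > 1), the causal mask applied earlier already removes the upper triangle, so only the triu with the negative diagonal changes the result; the tril with a positive diagonal only matters for the non-causal case.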
9 changes: 4 additions & 5 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -329,12 +329,11 @@ def _validate_config(self) -> None:
raise ImportError(
'If using the dail implementation of rope, the flash_attn library v2.0.1 or higher must be installed. Please check the instructions at https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md#what-kinds-of-positional-embeddings-does-llm-foundry-support',
)
if self.attn_config['sliding_window_size'] != -1 and not (
self.attn_config['attn_impl'] == 'flash' and
is_flash_v2_installed(v2_version='v2.3.0')
):
if self.attn_config['sliding_window_size'] != -1 and self.attn_config[
'attn_impl'
] == 'flash' and not is_flash_v2_installed(v2_version='v2.3.0',):
raise NotImplementedError(
'sliding window only implemented with flash attention v2.3.0 or higher.',
'sliding window attention only implemented for torch attention and flash attention (v2.3.0 or higher).',
)
if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
raise ValueError(
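With the relaxed check above, a sliding window can now be configured for the torch implementation as well. A hedged sketch of the config-side usage follows; the model dimensions are illustrative placeholders, and only attn_impl and sliding_window_size relate directly to the diff.

from llmfoundry.models.mpt.configuration_mpt import MPTConfig

# Sketch under assumptions: sliding_window_size != -1 with attn_impl == 'torch'
# is expected to pass _validate_config, while attn_impl == 'flash' still
# requires flash-attn v2.3.0 or higher.
cfg = MPTConfig(
    d_model=128,
    n_heads=8,
    n_layers=2,
    max_seq_len=256,
    attn_config={
        'attn_impl': 'torch',
        'sliding_window_size': 4,  # -1 (the default) disables the window
    },
)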
113 changes: 113 additions & 0 deletions tests/models/layers/test_attention.py
@@ -1,10 +1,17 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import math

import pytest
import torch

from llmfoundry.models.layers.attention import (
attention_implementations,
scaled_multihead_dot_product_attention,
)
from llmfoundry.models.layers.layer_builders import build_attention_layer
from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info


@pytest.mark.parametrize(
@@ -158,3 +165,109 @@ def test_unfused_wqkv(attn_name: str, dim: int):
assert isinstance(attn_fused.Wqkv.weight.grad, torch.Tensor)
assert isinstance(combined_grad, torch.Tensor)
assert torch.allclose(attn_fused.Wqkv.weight.grad, combined_grad)


@pytest.mark.gpu
@pytest.mark.parametrize('sliding_window_size', [1, 4, 8])
@pytest.mark.parametrize('attn_impl', ['flash', 'torch'])
def test_sliding_window(sliding_window_size: int, attn_impl: str):
# Test that sliding window attention works as expected.
dtype = torch.bfloat16
device = 'cuda'
d = 128
n_heads = 8
seqlen_1 = 8
bsz = 2

query_1 = torch.randn(bsz, seqlen_1,
n_heads * d).to(dtype=dtype, device=device)
query_1.requires_grad = True
key_1 = torch.randn(bsz, seqlen_1,
n_heads * d).to(dtype=dtype, device=device)
key_1.requires_grad = True
value_1 = torch.randn(bsz, seqlen_1,
n_heads * d).to(dtype=dtype, device=device)
value_1.requires_grad = True

attn_extra_kwargs = {}
if attn_impl == 'flash':
attn_extra_kwargs = {
'flash_attn_padding_info':
gen_flash_attn_padding_info(
bsz,
seqlen_1,
0,
query_1.device,
None,
None,
),
'should_repeat_kv_for_gqa':
True,
}

output_1, _, _ = attention_implementations.get(attn_impl)(
query=query_1,
key=key_1,
value=value_1,
n_heads=n_heads,
kv_n_heads=n_heads,
past_key_value=None,
softmax_scale=1 / math.sqrt(d),
attn_bias=None,
key_padding_mask=None,
is_causal=True,
dropout_p=0.0,
training=False,
needs_weights=False,
sliding_window_size=sliding_window_size,
**attn_extra_kwargs,
)

output_1.sum().backward()

query_2 = query_1.detach().clone()
query_2.requires_grad = True
key_2 = key_1.detach().clone()
key_2.requires_grad = True
value_2 = value_1.detach().clone()
value_2.requires_grad = True

attn_bias_2 = torch.zeros(1, 1, seqlen_1,
seqlen_1).to(dtype=dtype, device=device)

window_mask_2 = torch.tril(
torch.ones(seqlen_1, seqlen_1),
diagonal=-(sliding_window_size + 1),
).to(dtype=dtype, device=device) * torch.finfo(attn_bias_2.dtype).min
attn_bias_2 = attn_bias_2 + window_mask_2
output_2, _, _ = scaled_multihead_dot_product_attention(
query=query_2,
key=key_2,
value=value_2,
n_heads=n_heads,
kv_n_heads=n_heads,
past_key_value=None,
softmax_scale=1 / math.sqrt(d),
attn_bias=attn_bias_2,
key_padding_mask=None,
is_causal=True,
dropout_p=0.0,
training=False,
needs_weights=False,
)

output_2.sum().backward()

print(torch.max(output_1 - output_2))

_assert_approx_equal(output_1, output_2)
assert (query_2.grad is not None) and (query_1.grad is not None)
_assert_approx_equal(query_1.grad, query_2.grad)
assert (key_2.grad is not None) and (key_1.grad is not None)
_assert_approx_equal(key_1.grad, key_2.grad)
assert (value_2.grad is not None) and (value_1.grad is not None)
_assert_approx_equal(value_1.grad, value_2.grad)


def _assert_approx_equal(value1: torch.Tensor, value2: torch.Tensor):
assert torch.norm(value2 - value1) <= 1e-2 + 1e-2 * torch.norm(value2)
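The reference path in the new test encodes the window as an additive bias rather than a boolean mask. A small sketch (illustrative sizes, not part of the commit) of that bias for seqlen 4 and sliding_window_size 1, where 1 marks positions pushed to the dtype minimum:

import torch

# Keys more than sliding_window_size positions behind the query receive the
# dtype minimum; the is_causal flag in the attention call masks the future.
seqlen, sliding_window_size = 4, 1
min_val = torch.finfo(torch.float32).min

window_bias = torch.tril(
    torch.ones(seqlen, seqlen),
    diagonal=-(sliding_window_size + 1),
) * min_val
print((window_bias < 0).long())
# tensor([[0, 0, 0, 0],
#         [0, 0, 0, 0],
#         [1, 0, 0, 0],
#         [1, 1, 0, 0]])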
98 changes: 0 additions & 98 deletions tests/models/layers/test_flash_attn.py
@@ -218,104 +218,6 @@ def test_seq_id_masking_FA_v2():
)


@pytest.mark.gpu
@pytest.mark.skipif(
not is_flash_v2_installed(v2_version='v2.3.0'),
reason=
'Sliding window attention only supported by Flash Attention after v2.3.0.',
)
@pytest.mark.parametrize('sliding_window_size', [1, 4, 8])
def test_sliding_window(sliding_window_size: int):
# Test that sliding window attention works as expected.
dtype = torch.bfloat16
device = 'cuda'
d = 128
n_heads = 8
seqlen_1 = 8
bsz = 2

query_1 = torch.randn(bsz, seqlen_1,
n_heads * d).to(dtype=dtype, device=device)
query_1.requires_grad = True
key_1 = torch.randn(bsz, seqlen_1,
n_heads * d).to(dtype=dtype, device=device)
key_1.requires_grad = True
value_1 = torch.randn(bsz, seqlen_1,
n_heads * d).to(dtype=dtype, device=device)
value_1.requires_grad = True

output_1, _, _ = flash_attn_fn(
query=query_1,
key=key_1,
value=value_1,
n_heads=n_heads,
kv_n_heads=n_heads,
past_key_value=None,
softmax_scale=1 / math.sqrt(d),
attn_bias=None,
key_padding_mask=None,
is_causal=True,
dropout_p=0.0,
training=False,
needs_weights=False,
flash_attn_padding_info=gen_flash_attn_padding_info(
bsz,
seqlen_1,
0,
query_1.device,
None,
None,
),
should_repeat_kv_for_gqa=True,
sliding_window_size=sliding_window_size,
)

output_1.sum().backward()

query_2 = query_1.detach().clone()
query_2.requires_grad = True
key_2 = key_1.detach().clone()
key_2.requires_grad = True
value_2 = value_1.detach().clone()
value_2.requires_grad = True

attn_bias_2 = torch.zeros(1, 1, seqlen_1,
seqlen_1).to(dtype=dtype, device=device)

window_mask_2 = torch.tril(
torch.ones(seqlen_1, seqlen_1),
diagonal=-(sliding_window_size + 1),
).to(dtype=dtype, device=device) * torch.finfo(attn_bias_2.dtype).min
attn_bias_2 = attn_bias_2 + window_mask_2
output_2, _, _ = scaled_multihead_dot_product_attention(
query=query_2,
key=key_2,
value=value_2,
n_heads=n_heads,
kv_n_heads=n_heads,
past_key_value=None,
softmax_scale=1 / math.sqrt(d),
attn_bias=attn_bias_2,
key_padding_mask=None,
is_causal=True,
dropout_p=0.0,
training=False,
needs_weights=False,
)

output_2.sum().backward()

print(torch.max(output_1 - output_2))

_assert_approx_equal(output_1, output_2)
assert (query_2.grad is not None) and (query_1.grad is not None)
_assert_approx_equal(query_1.grad, query_2.grad)
assert (key_2.grad is not None) and (key_1.grad is not None)
_assert_approx_equal(key_1.grad, key_2.grad)
assert (value_2.grad is not None) and (value_1.grad is not None)
_assert_approx_equal(value_1.grad, value_2.grad)


@pytest.mark.gpu
@pytest.mark.skipif(
not check_alibi_support('flash'),
3 changes: 3 additions & 0 deletions tests/models/layers/test_flash_torch.py
@@ -77,6 +77,7 @@ def allclose_helper(
)
@pytest.mark.parametrize('attn_uses_sequence_id', [True, False])
@pytest.mark.parametrize('pad_attention_mask', [True, False])
@pytest.mark.parametrize('sliding_window_size', [-1, 2])
def test_attn_impl(
attn_impl_0: str,
attn_impl_1: str,
@@ -87,6 +88,7 @@ def test_attn_impl(
attn_type: str,
attn_uses_sequence_id: bool,
pad_attention_mask: bool,
sliding_window_size: int,
device: str = 'cuda',
):
"""Compare all attn impl with each other.
@@ -122,6 +124,7 @@ def test_attn_impl(
'clip_qkv': clip_qkv,
'qk_ln': qk_ln,
'qk_gn': qk_gn,
'sliding_window_size': sliding_window_size,
})

n, s, f = 2, 4, cfg.d_model
