Bumping flash attention version to 2.6.3 and adding option for softcap in attention and lm_head logits. #1374

Merged 26 commits on Sep 23, 2024

Commits (all authored by ShashankMosaicML)
aeb650f  adding option for softcap in attention, updating flash attention (Jul 19, 2024)
8ed7e1c  fix (Jul 19, 2024)
4cf075c  adding test (Jul 19, 2024)
d1e738e  .. (Jul 19, 2024)
87d4114  .. (Jul 19, 2024)
debd411  adding test (Jul 19, 2024)
1e4a3aa  .. (Jul 19, 2024)
9260f19  Merge branch 'main' into soft_cap_attn (Jul 22, 2024)
65c5fa8  add logit softcapping (Jul 23, 2024)
059617d  .. (Jul 23, 2024)
13119e2  Merge branch 'main' into soft_cap_attn (Jul 23, 2024)
02b4f04  Merge branch 'main' into soft_cap_attn (Jul 23, 2024)
b06adf6  fix (Jul 23, 2024)
c72458a  Merge branch 'main' into soft_cap_attn (Jul 23, 2024)
258e048  Merge branch 'soft_cap_attn' of github.com:ShashankMosaicML/llm-found… (Jul 23, 2024)
b875fa3  Merge branch 'main' into soft_cap_attn (Jul 24, 2024)
ee72ff6  Merge branch 'main' into soft_cap_attn (Aug 6, 2024)
63d8676  Merge branch 'main' into soft_cap_attn (Aug 27, 2024)
65c17b0  Update configuration_mpt.py (Aug 27, 2024)
756e127  Merge branch 'main' into soft_cap_attn (Aug 27, 2024)
a3b568d  Merge branch 'main' into soft_cap_attn (Aug 28, 2024)
9c31f17  Merge branch 'main' into soft_cap_attn (Aug 30, 2024)
bf5e94e  .. (Aug 30, 2024)
9a68b8f  Merge branch 'main' into soft_cap_attn (Sep 20, 2024)
1a4123a  Merge branch 'main' into soft_cap_attn (Sep 22, 2024)
21dd8bd  Merge branch 'main' into soft_cap_attn (Sep 23, 2024)
24 changes: 21 additions & 3 deletions llmfoundry/models/layers/attention.py
@@ -112,6 +112,7 @@ def scaled_multihead_dot_product_attention(
dropout_p: float = 0.0,
training: bool = False,
needs_weights: bool = False,
attn_logit_softcapping: Optional[float] = None,
sliding_window_size: int = -1,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
@@ -149,6 +150,11 @@ def scaled_multihead_dot_product_attention(

attn_weight = q.matmul(k) * softmax_scale

if attn_logit_softcapping is not None:
attn_weight = attn_logit_softcapping * torch.tanh(
attn_weight / attn_logit_softcapping,
)

if attn_bias is not None:
# clamp to 0 necessary for torch 2.0 compile()
_s_q = max(0, attn_bias.size(2) - s_q)
@@ -264,6 +270,7 @@ def flash_attn_fn(
sliding_window_size: int = -1,
alibi_slopes: Optional[torch.Tensor] = None,
flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
attn_logit_softcapping: Optional[float] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
if key_padding_mask is not None:
@@ -381,13 +388,17 @@ def flash_attn_fn(
return_attn_probs=needs_weights,
)
elif is_flash_v2_installed():
alibi_kwargs = {}
extra_attn_kwargs = {}
if check_alibi_support('flash'):
alibi_kwargs = {'alibi_slopes': alibi_slopes}
extra_attn_kwargs['alibi_slopes'] = alibi_slopes
elif alibi_slopes is not None:
raise ValueError(
'alibi_slopes is only supported for flash-attn>=2.4.2',
)
if is_flash_v2_installed(
v2_version='v2.6.2',
) and attn_logit_softcapping is not None:
extra_attn_kwargs['softcap'] = attn_logit_softcapping
output_unpad = flash_attn_interface.flash_attn_varlen_func(
q=query_unpad,
k=key_unpad,
@@ -401,7 +412,7 @@ def flash_attn_fn(
causal=reset_is_causal,
return_attn_probs=needs_weights,
window_size=(sliding_window_size, sliding_window_size),
**alibi_kwargs,
**extra_attn_kwargs,
)
else:
raise RuntimeError(
@@ -448,6 +459,7 @@ def __init__(
bias: bool = True,
sliding_window_size: int = -1,
reuse_kv_layer_idx: Optional[int] = None,
attn_logit_softcapping: Optional[float] = None,
kv_dim: Optional[int] = None,
):
super().__init__()
@@ -463,6 +475,7 @@ def __init__(
self.kv_n_heads = kv_n_heads
self.sliding_window_size = sliding_window_size
self.reuse_kv_layer_idx = reuse_kv_layer_idx
self.attn_logit_softcapping = attn_logit_softcapping

self.kv_dim = kv_dim if kv_dim is not None else self.d_model
self.head_dim = d_model // n_heads
@@ -625,6 +638,7 @@ def forward(
dropout_p=self.attn_dropout_p,
training=self.training,
needs_weights=needs_weights,
attn_logit_softcapping=self.attn_logit_softcapping,
sliding_window_size=self.sliding_window_size,
**extra_attn_kwargs,
)
@@ -853,6 +867,7 @@ def __init__(
bias: bool = True,
sliding_window_size: int = -1,
reuse_kv_layer_idx: Optional[int] = None,
attn_logit_softcapping: Optional[float] = None,
kv_dim: Optional[int] = None,
):
super().__init__(
@@ -873,6 +888,7 @@ def __init__(
bias=bias,
sliding_window_size=sliding_window_size,
reuse_kv_layer_idx=reuse_kv_layer_idx,
attn_logit_softcapping=attn_logit_softcapping,
kv_dim=kv_dim,
)

@@ -902,6 +918,7 @@ def __init__(
bias: bool = True,
sliding_window_size: int = -1,
reuse_kv_layer_idx: Optional[int] = None,
attn_logit_softcapping: Optional[float] = None,
kv_dim: Optional[int] = None,
):
super().__init__(
@@ -922,6 +939,7 @@ def __init__(
bias=bias,
sliding_window_size=sliding_window_size,
reuse_kv_layer_idx=reuse_kv_layer_idx,
attn_logit_softcapping=attn_logit_softcapping,
kv_dim=kv_dim,
)

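For intuition, the transform added above bounds the attention logits smoothly to the open interval (-cap, cap): it is approximately the identity for small scores and saturates for large ones, so no single logit can dominate the softmax. A minimal standalone sketch of the same tanh-based capping (illustrative shapes only, not the llm-foundry code path):

import torch

def softcap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Smooth bound: ~identity when |logits| << cap, saturates toward +/- cap otherwise.
    return cap * torch.tanh(logits / cap)

# Illustrative attention scores with shape (batch, heads, query_len, key_len).
scores = 50.0 * torch.randn(2, 4, 8, 8)   # deliberately large magnitudes
capped = softcap(scores, cap=30.0)
assert capped.abs().max() < 30.0          # capped values never exceed the cap
probs = torch.softmax(capped, dim=-1)     # softmax then proceeds as usual
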
14 changes: 14 additions & 0 deletions llmfoundry/models/mpt/configuration_mpt.py
@@ -51,6 +51,7 @@ def __init__(
tie_word_embeddings: bool = True,
use_pad_tok_in_ffn: bool = True,
block_overrides: Optional[dict[str, Any]] = None,
final_logit_softcapping: Optional[float] = None,
**kwargs: Any,
):
"""The MPT configuration class.
@@ -148,6 +149,7 @@ def __init__(
reuse_kv_layer:
attn_config:
reuse_kv_layer_idx: -6 # Relative index of the layer whose kv cache to reuse
final_logit_softcapping (float | None): Softcapping threshold for final logit. Set to None to disable (default value None). Please see https://arxiv.org/pdf/2403.08295 for more details.
kwargs (Any): Other relevant keyword arguments.
"""
self.d_model = d_model
@@ -181,6 +183,7 @@ def __init__(
if block_overrides is not None:
self._validate_block_overrides(block_overrides)
self.block_overrides = block_overrides
self.final_logit_softcapping = final_logit_softcapping

if isinstance(fc_type, str):
fc_type = {'name': fc_type}
@@ -325,6 +328,17 @@ def _validate_config(self) -> None:
raise NotImplementedError(
'sliding window attention only implemented for torch attention and flash attention (v2.3.0 or higher).',
)
if self.attn_config['attn_logit_softcapping'] is not None:
if self.attn_config['attn_logit_softcapping'] <= 0:
raise ValueError(
'Attention attn_logit_softcapping should be positive.',
)
if self.attn_config[
'attn_impl'
] == 'flash' and not is_flash_v2_installed(v2_version='v2.6.2',):
raise NotImplementedError(
'Attention attn_logit_softcapping is only implemented with torch attention or flash attention v2.6.2 (or higher).',
)
if self.attn_config['kv_dim'] is not None and self.attn_config[
'fused_qkv']:
raise ValueError(
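
A hypothetical configuration sketch showing how the two new options could be set together. The import path, the partial attn_config dict (merged with the defaults from config_defaults.py), and the specific cap values are assumptions for illustration, not taken from this PR:

from llmfoundry.models.mpt import MPTConfig  # import path assumed

config = MPTConfig(
    d_model=768,
    n_heads=12,
    n_layers=12,
    # attn_logit_softcapping lives inside attn_config and must be positive;
    # with attn_impl='flash' it additionally requires flash-attn >= 2.6.2.
    attn_config={
        'attn_impl': 'flash',
        'attn_logit_softcapping': 50.0,
    },
    # final_logit_softcapping is a top-level field applied to the lm_head logits.
    final_logit_softcapping=30.0,
)
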
6 changes: 6 additions & 0 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -1071,6 +1071,7 @@ def __init__(self, config: MPTConfig):
f"{logit_scale=} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.",
)
self.logit_scale = logit_scale
self.final_logit_softcapping = config.final_logit_softcapping

@property
def backbone_model_class(self) -> type[MPTModel]:
@@ -1172,6 +1173,11 @@ def forward(
)
logits *= self.logit_scale

if self.final_logit_softcapping is not None:
logits = self.final_logit_softcapping * torch.tanh(
logits / self.final_logit_softcapping,
)

loss = None
if labels is not None:
_labels = torch.roll(labels, shifts=-1)
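
The ordering in the forward pass above is: compute the lm_head logits, apply logit_scale if set, then apply the tanh softcap, and only then compute the loss. A simplified sketch of that ordering (shapes and the plain cross-entropy are illustrative; the actual model rolls the labels and uses an ignore index):

import torch
import torch.nn.functional as F

vocab_size, cap, logit_scale = 32, 30.0, 1.0
logits = torch.randn(2, 8, vocab_size)        # (batch, seq, vocab)

logits = logits * logit_scale                 # optional scaling first
logits = cap * torch.tanh(logits / cap)       # then the final softcap

labels = torch.randint(0, vocab_size, (2, 8))
loss = F.cross_entropy(logits.reshape(-1, vocab_size), labels.reshape(-1))
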
1 change: 1 addition & 0 deletions llmfoundry/models/utils/config_defaults.py
@@ -18,6 +18,7 @@
'softmax_scale': None,
'attn_uses_sequence_id': False,
'sliding_window_size': -1,
'attn_logit_softcapping': None,
'alibi': False,
'alibi_bias_max': 8,
'rope': False,
2 changes: 1 addition & 1 deletion setup.py
@@ -104,7 +104,7 @@

# Flash 2 group kept for backwards compatibility
extra_deps['gpu-flash2'] = [
'flash-attn>=2.5.8,<3',
'flash-attn>=2.6.3,<3',
]

extra_deps['gpu'] = copy.deepcopy(extra_deps['gpu-flash2'])
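
The pin above raises the minimum flash-attn to 2.6.3, while the runtime checks in this PR only require 2.6.2 for the softcap kwarg. A quick sketch for checking what is installed locally (assumes the flash_attn package exposes __version__, as recent releases do):

from packaging import version

import flash_attn

installed = version.parse(flash_attn.__version__)
needed = version.parse('2.6.2')
print(
    f'flash-attn {installed}:',
    'softcap supported' if installed >= needed else 'softcap NOT supported',
)
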
99 changes: 97 additions & 2 deletions tests/models/layers/test_flash_attn.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import math
from typing import Optional

import pytest
import torch
@@ -334,5 +335,99 @@ def gen_bias():
_assert_approx_equal(value_1.grad, value_2.grad)


def _assert_approx_equal(value1: torch.Tensor, value2: torch.Tensor):
assert torch.norm(value2 - value1) <= 1e-2 + 1e-2 * torch.norm(value2)
@pytest.mark.gpu
@pytest.mark.skipif(
not is_flash_v2_installed(v2_version='v2.6.2'),
reason=
'attn_logit_softcapping only supported by Flash Attention after v2.6.2.',
)
@pytest.mark.parametrize(
'attn_logit_softcapping',
[None, 0.1, 1.0, 10.0, 100.0],
)
def test_attn_logit_softcapping(attn_logit_softcapping: Optional[float]):
# Test that attn_logit_softcapping in attention works as expected.
dtype = torch.bfloat16
device = 'cuda'
d = 128
seqlen_1 = 8
bsz = 2
n_heads = 4

query_1 = torch.randn(bsz, seqlen_1,
n_heads * d).to(dtype=dtype, device=device)
query_1.requires_grad = True
key_1 = torch.randn(bsz, seqlen_1,
n_heads * d).to(dtype=dtype, device=device)
key_1.requires_grad = True
value_1 = torch.randn(bsz, seqlen_1,
n_heads * d).to(dtype=dtype, device=device)
value_1.requires_grad = True
output_1, _, _ = flash_attn_fn(
query=query_1,
key=key_1,
value=value_1,
n_heads=n_heads,
kv_n_heads=n_heads,
past_key_value=None,
softmax_scale=1 / math.sqrt(d),
attn_bias=None,
key_padding_mask=None,
is_causal=True,
dropout_p=0.0,
training=False,
needs_weights=False,
flash_attn_padding_info=gen_flash_attn_padding_info(
bsz,
seqlen_1,
0,
query_1.device,
None,
None,
),
should_repeat_kv_for_gqa=True,
attn_logit_softcapping=attn_logit_softcapping,
)
output_1.sum().backward()

query_2 = query_1.detach().clone()
query_2.requires_grad = True
key_2 = key_1.detach().clone()
key_2.requires_grad = True
value_2 = value_1.detach().clone()
value_2.requires_grad = True
output_2, _, _ = scaled_multihead_dot_product_attention(
query=query_2,
key=key_2,
value=value_2,
n_heads=n_heads,
kv_n_heads=n_heads,
past_key_value=None,
softmax_scale=1 / math.sqrt(d),
key_padding_mask=None,
is_causal=True,
dropout_p=0.0,
training=False,
needs_weights=False,
attn_logit_softcapping=attn_logit_softcapping,
)
output_2.sum().backward()

_assert_approx_equal(output_1, output_2)
assert (query_2.grad is not None) and (query_1.grad is not None)
_assert_approx_equal(query_1.grad, query_2.grad)
assert (key_2.grad is not None) and (key_1.grad is not None)
_assert_approx_equal(key_1.grad, key_2.grad)
assert (value_2.grad is not None) and (value_1.grad is not None)
_assert_approx_equal(value_1.grad, value_2.grad)


def _assert_approx_equal(
value1: torch.Tensor,
value2: torch.Tensor,
atol: float = 1e-2,
rtol: float = 1e-2,
):
actual_difference = torch.norm(value2 - value1)
allowed_difference = atol + rtol * torch.norm(value2)
assert actual_difference < allowed_difference, f'{actual_difference=}, {allowed_difference=}'