configuring more tests with torch 2.5.1 and 2.6.0.dev

ShashankMosaicML committed Dec 5, 2024
1 parent 434aa83 commit 216fcb9

Showing 2 changed files with 52 additions and 52 deletions.
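
Both files add the same module-level guard ahead of the tests: the FlexAttention helpers are wrapped with torch.compile only when the installed torch is 2.6.0 or newer, and the plain flex_attention / create_block_mask functions are used otherwise. A minimal, self-contained sketch of that guard (assuming the helpers come from torch.nn.attention.flex_attention and that version parsing uses the packaging library, as elsewhere in these tests) could look like this:

    import torch
    from packaging import version
    from torch.nn.attention.flex_attention import create_block_mask, flex_attention

    # Default to the uncompiled helpers.
    compiled_flex_attention = flex_attention
    compiled_create_block_mask = create_block_mask

    # Strip any ".dev" suffix (e.g. "2.6.0.dev20241205") before comparing,
    # and compile only on torch >= 2.6.0.
    if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.6.0'):
        compiled_flex_attention = torch.compile(flex_attention)
        compiled_create_block_mask = torch.compile(create_block_mask)

Keeping the uncompiled fallback lets the same tests run on torch 2.5.1, where FlexAttention is available but these helpers are not compiled.
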
tests/models/layers/test_flash_attn.py: 52 changes (26 additions, 26 deletions)
@@ -20,6 +20,12 @@
)
from llmfoundry.models.mpt.modeling_mpt import gen_flash_attn_padding_info

+compiled_flex_attention = flex_attention
+compiled_create_block_mask = create_block_mask
+if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.6.0'):
+    compiled_flex_attention = torch.compile(flex_attention)
+    compiled_create_block_mask = torch.compile(create_block_mask)


@pytest.mark.gpu
@pytest.mark.skipif(
@@ -33,9 +39,9 @@ def test_gqa_kv_repetition(attn_impl: str, kv_n_heads: int):
# whether we repeat the kv_n_heads explicitly or flash attention v2 handles it on its own.
if attn_impl == 'flex' and version.parse(
torch.__version__.split('.dev')[0],
-) < version.parse('2.6.0'):
+) < version.parse('2.5.1'):
pytest.skip(
-'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
)
d = 128
n_heads = 8
@@ -67,9 +73,8 @@ def test_gqa_kv_repetition(attn_impl: str, kv_n_heads: int):
}
elif attn_impl == 'flex':
extra_attn_kwargs = {
-'compiled_flex_attention':
-flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
-'compiled_create_block_mask': torch.compile(create_block_mask),
+'compiled_flex_attention': compiled_flex_attention,
+'compiled_create_block_mask': compiled_create_block_mask,
'sequence_id_info': {},
}

@@ -115,9 +120,8 @@ def test_gqa_kv_repetition(attn_impl: str, kv_n_heads: int):
}
elif attn_impl == 'flex':
extra_attn_kwargs = {
-'compiled_flex_attention':
-flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
-'compiled_create_block_mask': torch.compile(create_block_mask),
+'compiled_flex_attention': compiled_flex_attention,
+'compiled_create_block_mask': compiled_create_block_mask,
'sequence_id_info': {},
}

@@ -156,9 +160,9 @@ def test_seq_id_masking_FA_v2(attn_impl: str):
# Test that flash attention v2 with sequence id masking works correctly.
if attn_impl == 'flex' and version.parse(
torch.__version__.split('.dev')[0],
-) < version.parse('2.6.0'):
+) < version.parse('2.5.1'):
pytest.skip(
-'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
)
d = 128
n_heads = 4
@@ -198,9 +202,8 @@ def test_seq_id_masking_FA_v2(attn_impl: str):
extra_attn_kwargs['flash_attn_padding_info'] = flash_attn_padding_info_1
elif attn_impl == 'flex':
extra_attn_kwargs = {
-'compiled_flex_attention':
-flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
-'compiled_create_block_mask': torch.compile(create_block_mask),
+'compiled_flex_attention': compiled_flex_attention,
+'compiled_create_block_mask': compiled_create_block_mask,
'sequence_id_info': {
'sequence_id': sequence_id,
},
@@ -246,9 +249,8 @@ def test_seq_id_masking_FA_v2(attn_impl: str):
] = flash_attn_padding_info_2
elif attn_impl == 'flex':
extra_attn_kwargs = {
-'compiled_flex_attention':
-flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
-'compiled_create_block_mask': torch.compile(create_block_mask),
+'compiled_flex_attention': compiled_flex_attention,
+'compiled_create_block_mask': compiled_create_block_mask,
'sequence_id_info': {
'sequence_id': sequence_id,
},
@@ -300,9 +302,9 @@ def test_alibi_bias(attn_impl: str, n_heads: int):
# Test that sliding window attention works as expected.
if attn_impl == 'flex' and version.parse(
torch.__version__.split('.dev')[0],
-) < version.parse('2.6.0'):
+) < version.parse('2.5.1'):
pytest.skip(
-'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
)
dtype = torch.bfloat16
device = 'cuda'
@@ -342,9 +344,8 @@ def test_alibi_bias(attn_impl: str, n_heads: int):
}
elif attn_impl == 'flex':
extra_attn_kwargs = {
-'compiled_flex_attention':
-flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
-'compiled_create_block_mask': torch.compile(create_block_mask),
+'compiled_flex_attention': compiled_flex_attention,
+'compiled_create_block_mask': compiled_create_block_mask,
'sequence_id_info': {},
}
output_1, _, _ = attention_implementations.get(attn_impl)(
@@ -444,9 +445,9 @@ def test_attn_logit_softcapping(
# Test that attn_logit_softcapping in attention works as expected.
if attn_impl == 'flex' and version.parse(
torch.__version__.split('.dev')[0],
-) < version.parse('2.6.0'):
+) < version.parse('2.5.1'):
pytest.skip(
-'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
)
if attn_impl == 'flex' and attn_logit_softcapping is not None:
if int(attn_logit_softcapping) != attn_logit_softcapping:
@@ -489,9 +490,8 @@ def test_attn_logit_softcapping(
}
elif attn_impl == 'flex':
extra_attn_kwargs = {
-'compiled_flex_attention':
-flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
-'compiled_create_block_mask': torch.compile(create_block_mask),
+'compiled_flex_attention': compiled_flex_attention,
+'compiled_create_block_mask': compiled_create_block_mask,
'sequence_id_info': {},
}
output_1, _, _ = attention_implementations.get(attn_impl)(
tests/models/layers/test_flash_torch.py: 52 changes (26 additions, 26 deletions)
@@ -21,6 +21,12 @@
gen_sequence_id_info,
)

+compiled_flex_attention = flex_attention
+compiled_create_block_mask = create_block_mask
+if version.parse(torch.__version__.split('.dev')[0]) >= version.parse('2.6.0'):
+    compiled_flex_attention = torch.compile(flex_attention)
+    compiled_create_block_mask = torch.compile(create_block_mask)


def allclose_helper(
t0: torch.Tensor,
@@ -104,9 +110,9 @@ def test_attn_impl(
"""
if (attn_impl_0 == 'flex' or attn_impl_1 == 'flex') and version.parse(
torch.__version__.split('.dev')[0],
-) < version.parse('2.6.0'):
+) < version.parse('2.5.1'):
pytest.skip(
-'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
)
alibi = pos_emb_config['alibi']
rope = pos_emb_config['rope']
@@ -293,9 +299,8 @@ def gen_bias(attn_impl: str):
extra_kwargs = {}
if attn_impl_0 == 'flex':
extra_kwargs['flex_attn_kwargs'] = {
-'compiled_flex_attention':
-flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
-'compiled_create_block_mask': torch.compile(create_block_mask),
+'compiled_flex_attention': compiled_flex_attention,
+'compiled_create_block_mask': compiled_create_block_mask,
'sequence_id_info': {},
}
if sequence_id is not None:
@@ -325,9 +330,8 @@ def gen_bias(attn_impl: str):
extra_kwargs = {}
if attn_impl_1 == 'flex':
extra_kwargs['flex_attn_kwargs'] = {
-'compiled_flex_attention':
-flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
-'compiled_create_block_mask': torch.compile(create_block_mask),
+'compiled_flex_attention': compiled_flex_attention,
+'compiled_create_block_mask': compiled_create_block_mask,
}
if sequence_id is not None:
extra_kwargs['flex_attn_kwargs']['sequence_id_info'] = {
@@ -386,9 +390,9 @@ def test_vs_mha(attn_impl: str, device: str = 'cuda'):
"""Compare diff attn_impl to torch.nn.MultiheadAttention."""
if attn_impl == 'flex' and version.parse(
torch.__version__.split('.dev')[0],
-) < version.parse('2.6.0'):
+) < version.parse('2.5.1'):
pytest.skip(
-'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
)
from llmfoundry.models.layers import attention

@@ -447,9 +451,8 @@ def gen_tca_mask():
extra_kwargs = {}
if attn_impl == 'flex':
extra_kwargs['flex_attn_kwargs'] = {
-'compiled_flex_attention':
-flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
-'compiled_create_block_mask': torch.compile(create_block_mask),
+'compiled_flex_attention': compiled_flex_attention,
+'compiled_create_block_mask': compiled_create_block_mask,
'sequence_id_info': {},
}
y0, _, _ = mmhsa(
@@ -517,9 +520,9 @@ def test_grouped_attention_heads(
"""Ensure grouped_query_attention runs w/ diff n_heads & kv_n_heads."""
if attn_impl == 'flex' and version.parse(
torch.__version__.split('.dev')[0],
-) < version.parse('2.6.0'):
+) < version.parse('2.5.1'):
pytest.skip(
-'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
)
from llmfoundry.models.layers import attention

@@ -555,9 +558,8 @@ def test_grouped_attention_heads(
extra_kwargs = {}
if attn_impl == 'flex':
extra_kwargs['flex_attn_kwargs'] = {
-'compiled_flex_attention':
-flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
-'compiled_create_block_mask': torch.compile(create_block_mask),
+'compiled_flex_attention': compiled_flex_attention,
+'compiled_create_block_mask': compiled_create_block_mask,
'sequence_id_info': {},
}
y0, _, _ = mmhsa(
@@ -637,9 +639,9 @@ def test_reuse_prev_layer_kv_cache(
"""Checks reusing previous layer's kv cache."""
if attn_impl == 'flex' and version.parse(
torch.__version__.split('.dev')[0],
-) < version.parse('2.6.0'):
+) < version.parse('2.5.1'):
pytest.skip(
-'FlexAttention is not supported in torch version {torch.__version__}<2.6.0.',
+'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
)
alibi = pos_emb_config['alibi']
rope = pos_emb_config['rope']
@@ -777,9 +779,8 @@ def gen_bias(attn_impl: str):
extra_kwargs = {}
if attn_impl == 'flex':
extra_kwargs['flex_attn_kwargs'] = {
-'compiled_flex_attention':
-flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
-'compiled_create_block_mask': torch.compile(create_block_mask),
+'compiled_flex_attention': compiled_flex_attention,
+'compiled_create_block_mask': compiled_create_block_mask,
'sequence_id_info': {
'sequence_id': sequence_id,
},
@@ -811,9 +812,8 @@ def gen_bias(attn_impl: str):
extra_kwargs = {}
if attn_impl == 'flex':
extra_kwargs['flex_attn_kwargs'] = {
-'compiled_flex_attention':
-flex_attention, # TODO: torch.compile(flex_attention) doesn't work, maybe because the data dims are too small for compiled kernels. Confirm this hypothesis.
-'compiled_create_block_mask': torch.compile(create_block_mask),
+'compiled_flex_attention': compiled_flex_attention,
+'compiled_create_block_mask': compiled_create_block_mask,
'sequence_id_info': {
'sequence_id': sequence_id,
},
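
With this guard in place, the per-test skip threshold in both files drops from torch 2.6.0 to 2.5.1: on torch 2.5.1 the 'flex' tests run against the uncompiled helpers, while on 2.6.0.dev builds and later they exercise the compiled ones. The skip condition repeated throughout the diff reduces to roughly the following helper (a sketch; maybe_skip_flex is a hypothetical name, and the skip message is written here as an f-string):

    import pytest
    import torch
    from packaging import version

    def maybe_skip_flex(attn_impl: str) -> None:
        # Hypothetical helper: FlexAttention tests require torch >= 2.5.1.
        if attn_impl == 'flex' and version.parse(
            torch.__version__.split('.dev')[0],
        ) < version.parse('2.5.1'):
            pytest.skip(
                f'FlexAttention is not supported in torch version {torch.__version__}<2.5.1.',
            )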
