adding test for local global attention
ShashankMosaicML committed Dec 9, 2024
1 parent 5eca05f commit 718d89d
Showing 1 changed file with 135 additions and 0 deletions.
tests/models/layers/test_attention.py: 135 additions, 0 deletions
@@ -292,6 +292,141 @@ def test_sliding_window(sliding_window_size: int, attn_impl: str):
    _assert_approx_equal(value_1.grad, value_2.grad)


@pytest.mark.gpu
@pytest.mark.parametrize('sliding_window_size', [1, 4])
@pytest.mark.parametrize('global_window_size', [1, 4])
@pytest.mark.parametrize('attn_impl', ['flex'])
def test_local_global_window(
    sliding_window_size: int,
    global_window_size: int,
    attn_impl: str,
):
    # Test that local-global attention works as expected.
    if attn_impl == 'flex' and version.parse(
        torch.__version__.split('.dev')[0],
    ) < version.parse('2.5.1'):
        pytest.skip(
            f'FlexAttention requires torch 2.5.1 or later; found {torch.__version__}.',
        )
    global_window_size = torch.tensor(global_window_size, device='cuda')
    sliding_window_size = torch.tensor(sliding_window_size, device='cuda')
    dtype = torch.bfloat16
    device = 'cuda'
    d = 128
    n_heads = 8
    seqlen_1 = 128
    bsz = 1

    query_1 = torch.randn(bsz, seqlen_1,
                          n_heads * d).to(dtype=dtype, device=device)
    query_1.requires_grad = True
    key_1 = torch.randn(bsz, seqlen_1,
                        n_heads * d).to(dtype=dtype, device=device)
    key_1.requires_grad = True
    value_1 = torch.randn(bsz, seqlen_1,
                          n_heads * d).to(dtype=dtype, device=device)
    value_1.requires_grad = True

    # For the 'flex' implementation, the local-global pattern is configured via
    # a mask mod: keys inside the sliding window or inside the global prefix
    # remain attendable.
    attn_extra_kwargs = {}
    if attn_impl == 'flex':
        attn_extra_kwargs = {
            'compiled_flex_attention': compiled_flex_attention,
            'compiled_create_block_mask': compiled_create_block_mask,
            'flex_attn_compile': FLEX_ATTN_COMPILE,
            'sequence_id_info': {
                'pos_in_seq': None,
            },
            'flex_attn_mod_list': [{
                'mod_name': 'local_global_mask',
                'mod_kwargs': {
                    'sliding_window_size': sliding_window_size,
                    'global_window_size': global_window_size,
                },
            }],
        }

    output_1, _, _ = attention_implementations.get(attn_impl)(
        query=query_1,
        key=key_1,
        value=value_1,
        n_heads=n_heads,
        kv_n_heads=n_heads,
        past_key_value=None,
        softmax_scale=1 / math.sqrt(d),
        attn_bias=None,
        key_padding_mask=None,
        is_causal=True,
        dropout_p=0.0,
        training=False,
        needs_weights=False,
        sliding_window_size=-1,
        **attn_extra_kwargs,
    )

    output_1.sum().backward()

    query_2 = query_1.detach().clone()
    query_2.requires_grad = True
    key_2 = key_1.detach().clone()
    key_2.requires_grad = True
    value_2 = value_1.detach().clone()
    value_2.requires_grad = True

    # Reference bias, built by hand: keys in the global prefix
    # (position <= global_window_size) are attendable from every query.
    global_bias_2 = torch.where(
        torch.arange(seqlen_1)[None, None, None, :].to(
            dtype=dtype,
            device=device,
        ) <= global_window_size,
        torch.ones(1, 1, seqlen_1, seqlen_1).to(dtype=torch.bool, device=device),
        torch.zeros(1, 1, seqlen_1, seqlen_1).to(dtype=torch.bool, device=device),
    )

    # Sliding-window mask: keys more than sliding_window_size positions behind
    # the query are masked out; True marks attendable positions.
    window_mask_2 = torch.tril(
        torch.ones(seqlen_1, seqlen_1),
        diagonal=-(sliding_window_size + 1),
    ).to(dtype=torch.bool, device=device)
    window_mask_2 = torch.where(
        window_mask_2,
        torch.zeros_like(window_mask_2),
        torch.ones_like(window_mask_2),
    )
    # Combine: a key is attendable if it is in the global prefix or inside the
    # sliding window; everything else gets a large negative bias. The causal
    # mask is applied separately via is_causal=True below.
    attn_bias_2 = global_bias_2 | window_mask_2
    attn_bias_2 = torch.where(
        attn_bias_2,
        torch.zeros_like(attn_bias_2, dtype=dtype),
        torch.finfo(dtype).min,
    )
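    # Illustration only (not part of the original test): with
    # sliding_window_size=1 and global_window_size=1, the effective mask over a
    # toy length-6 sequence, once the causal mask is also applied, keeps these
    # key positions per query row (1 = attendable):
    #   q0: 1 0 0 0 0 0
    #   q1: 1 1 0 0 0 0
    #   q2: 1 1 1 0 0 0
    #   q3: 1 1 1 1 0 0
    #   q4: 1 1 0 1 1 0
    #   q5: 1 1 0 0 1 1
    # i.e. every query sees the first global_window_size + 1 keys plus the
    # sliding_window_size most recent keys and itself.
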
    output_2, _, _ = scaled_multihead_dot_product_attention(
        query=query_2,
        key=key_2,
        value=value_2,
        n_heads=n_heads,
        kv_n_heads=n_heads,
        past_key_value=None,
        softmax_scale=1 / math.sqrt(d),
        attn_bias=attn_bias_2,
        key_padding_mask=None,
        is_causal=True,
        dropout_p=0.0,
        training=False,
        needs_weights=False,
    )

    output_2.sum().backward()

    _assert_approx_equal(output_1, output_2)
    assert (query_2.grad is not None) and (query_1.grad is not None)
    _assert_approx_equal(query_1.grad, query_2.grad)
    assert (key_2.grad is not None) and (key_1.grad is not None)
    _assert_approx_equal(key_1.grad, key_2.grad)
    assert (value_2.grad is not None) and (value_1.grad is not None)
    _assert_approx_equal(value_1.grad, value_2.grad)


def _assert_approx_equal(value1: torch.Tensor, value2: torch.Tensor):
    assert torch.norm(value2 - value1) <= 1e-2 + 1e-2 * torch.norm(value2)
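
For reference, here is a minimal standalone sketch of the masking rule this test encodes, written against torch's public flex_attention API rather than the repo's attention_implementations wrapper. The mask_mod below is an assumption inferred from the hand-built reference bias in the test (a key is attendable if it lies in the causal sliding window or in the global prefix); it is not the repo's 'local_global_mask' implementation. Assumes torch >= 2.5.1 and a CUDA device.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def make_local_global_mask_mod(sliding_window_size: int, global_window_size: int):
    # Hypothetical helper: builds a mask_mod matching the test's reference bias.
    def mask_mod(b, h, q_idx, kv_idx):
        in_window = (kv_idx <= q_idx) & (q_idx - kv_idx <= sliding_window_size)
        in_global_prefix = kv_idx <= global_window_size
        return in_window | in_global_prefix

    return mask_mod


bsz, n_heads, seqlen, head_dim = 1, 2, 128, 16
q = torch.randn(bsz, n_heads, seqlen, head_dim, device='cuda', dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
block_mask = create_block_mask(
    make_local_global_mask_mod(sliding_window_size=4, global_window_size=4),
    B=None,
    H=None,
    Q_LEN=seqlen,
    KV_LEN=seqlen,
    device='cuda',
)
out = flex_attention(q, k, v, block_mask=block_mask)  # (bsz, n_heads, seqlen, head_dim)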
