adds triton flash attention 2 kernel (#4337)
* initial commit

* temp commit: needs debugging

* packed flash attn with mask works

* clean-up

* add bert/roberta tests to test_inference

* is_triton_supported added to Accelerator class
clean-up and formatting

* triton supports flash attention when compute capability >= 8.0

* formatting

* fix comments

* cleanup

* cleanup flash kernel

* fix according to the PR comment

---------

Co-authored-by: Stephen Youn <[email protected]>
Co-authored-by: Lev Kurilenko <[email protected]>
3 people authored Sep 20, 2023
1 parent 4fc2c8e commit 0e0748c
Showing 8 changed files with 241 additions and 33 deletions.
4 changes: 4 additions & 0 deletions accelerator/abstract_accelerator.py
@@ -185,6 +185,10 @@ def lazy_call(self, callback):
    def communication_backend_name(self):
        ...

    @abc.abstractmethod
    def is_triton_supported(self):
        ...

    # Tensor operations
    @property
    @abc.abstractmethod
3 changes: 3 additions & 0 deletions accelerator/cpu_accelerator.py
@@ -182,6 +182,9 @@ def lazy_call(self, callback):
    def communication_backend_name(self):
        return self._communication_backend_name

    def is_triton_supported(self):
        return False

    # Data types
    def is_bf16_supported(self):
        return True
7 changes: 7 additions & 0 deletions accelerator/cuda_accelerator.py
@@ -173,6 +173,13 @@ def lazy_call(self, callback):
    def communication_backend_name(self):
        return self._communication_backend_name

    def is_triton_supported(self):
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            return True
        else:
            return False

    # Tensor operations

    @property
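For reference, a minimal usage sketch (not part of this diff) of the new accelerator query: only CUDA devices with compute capability >= 8.0 report Triton support here, while the CPU, MPS and NPU accelerators below return False.

import torch
from deepspeed.accelerator import get_accelerator

# Gate the Triton flash-attention path on the new accelerator API.
use_triton_flash = get_accelerator().is_triton_supported()
print("Triton flash attention enabled:", use_triton_flash)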
3 changes: 3 additions & 0 deletions accelerator/mps_accelerator.py
@@ -160,6 +160,9 @@ def lazy_call(self, callback):
    def communication_backend_name(self):
        return self._communication_backend_name

    def is_triton_supported(self):
        return False

    # Tensor operations
    @property
    def BFloat16Tensor(self):
3 changes: 3 additions & 0 deletions accelerator/npu_accelerator.py
@@ -158,6 +158,9 @@ def lazy_call(self, callback):
    def communication_backend_name(self):
        return self._communication_backend_name

    def is_triton_supported(self):
        return False

    # Tensor operations

    @property
174 changes: 166 additions & 8 deletions deepspeed/ops/transformer/inference/triton/attention.py
@@ -6,6 +6,8 @@
import math
import torch
import torch.nn as nn
import triton
import triton.language as tl
from deepspeed.accelerator import get_accelerator
from deepspeed import comm as dist
from deepspeed.ops.transformer.inference.op_binding import LinearOp, VectorMatMulOp, SoftmaxContextOp, QKVGemmOp
@@ -70,6 +72,9 @@ def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count

        self.mp_group = mp_group
        self.use_flash = False
        # triton flash attention is enabled when the compute capability >= 8.0
        if get_accelerator().is_triton_supported():
            self.use_flash = True

        # used for quantization
        self.q_scales = q_scales
@@ -176,7 +181,7 @@ def forward(
        qkv = qkv_out[0]

        if use_triton_attention and (alibi is None):
-           context_layer = compute_attention(qkv=qkv,
+           context_layer = _triton_attention(qkv=qkv,
                                              input_mask=input_mask,
                                              scale=self.scale,
                                              layer_past=layer_past,
@@ -204,7 +209,7 @@ def forward(
global inference_module


-def compute_attention(qkv,
+def _triton_attention(qkv,
                      input_mask,
                      layer_past,
                      alibi,
@@ -217,13 +222,166 @@ def compute_attention(qkv,
    if isinstance(qkv, list):
        qkv = qkv[0]

    #assert layer_past is None, "layer_past not supported in triton yet"
    assert alibi is None, "layer_past not supported in alibi yet"
-   output = score_4d_matmul(qkv, head_size, triangular, scale)
-   if triangular:
-       output = softmax(output)
+
+   if use_triton_flash:
+       output = _triton_packed_flash(qkv,
+                                     head_size,
+                                     input_mask,
+                                     scale,
+                                     causal=triangular,
+                                     add_mask=(not triangular and input_mask is not None))
    else:
-       output = softmax(output, input_mask)
-   output = context_4d_matmul(output, qkv, head_size)
+       output = score_4d_matmul(qkv, head_size, triangular, scale)
+       if triangular:
+           output = softmax(output)
+       else:
+           output = softmax(output, input_mask)
+       output = context_4d_matmul(output, qkv, head_size)

    return output
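For context, a minimal calling sketch (not part of the commit; shapes and arguments mirror the unit test added further down) for the renamed `_triton_attention` entry point, which takes a packed QKV tensor of shape [batch, seq_len, 3 * heads * head_size]:

import torch
from deepspeed.ops.transformer.inference.triton.attention import _triton_attention

batch, seq_len, heads, head_size = 1, 128, 12, 64
qkv = torch.randn((batch, seq_len, 3 * heads * head_size), dtype=torch.float16, device="cuda")
# Additive attention mask: 0 keeps a position, a large negative value drops it.
mask = torch.zeros((batch, heads, seq_len, seq_len), dtype=torch.float16, device="cuda")

out = _triton_attention(qkv,
                        input_mask=mask,
                        layer_past=None,
                        alibi=None,
                        scale=1.0 / head_size**0.5,
                        head_size=head_size,
                        triangular=False,
                        use_cuda_flash=False,
                        use_triton_flash=False,  # True routes to the packed flash kernel below
                        use_ds_attention=False)
# out has shape [batch, seq_len, heads * head_size]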


'''
flash attention 2
modified the triton kernel in
https://github.com/openai/triton/blob/08c16589573621fcb8cd5a9c3b8a0537077f876d/python/tutorials/06-fused-attention.py
'''


@triton.jit
def _flash_packed_kernel(
    QKV,
    mask,
    ADD_MASK: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    sm_scale,
    Out,
    stride_qz,
    stride_qn,
    stride_qm,
    stride_mz,
    stride_oz,
    stride_on,
    Z,
    H,
    N_CTX,
    P_SEQ,
    hidden_size,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    batch = off_hz // H
    head = off_hz % H

    q_offset = batch * stride_qz + head * BLOCK_DMODEL
    k_offset = q_offset + hidden_size
    v_offset = k_offset + hidden_size

    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    q_ptrs = QKV + q_offset + offs_m[:, None] * stride_qn + offs_d[None, :]
    k_ptrs = QKV + hidden_size + q_offset + offs_n[:, None] * stride_qn + offs_d[None, :]
    v_ptrs = QKV + 2 * hidden_size + q_offset + offs_n[:, None] * stride_qn + offs_d[None, :]

    # mask
    off_mask = batch * stride_mz + offs_n[None, :]
    mask_ptrs = mask + off_mask

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
    # load q: it will stay in SRAM throughout
    q = tl.load(q_ptrs, mask=offs_m[:, None] < N_CTX, other=0.0)
    q = (q * qk_scale).to(tl.float16)
    # loop over k, v and update accumulator
    lo = 0
    hi = P_SEQ + (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + P_SEQ
    for start_n in range(lo, hi, BLOCK_N):
        # -- load k, v --
        k = tl.load(k_ptrs + start_n * stride_qn, mask=(start_n + offs_n)[:, None] < N_CTX, other=0.0)
        v = tl.load(v_ptrs + start_n * stride_qn, mask=(start_n + offs_n)[:, None] < N_CTX, other=0.0)
        # -- compute qk ---
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)

        if ADD_MASK:
            mask_val = tl.load(mask_ptrs)
            mask_ptrs += BLOCK_N
            qk = qk + mask_val.to(tl.float32)

        if IS_CAUSAL:
            qk = tl.where(P_SEQ + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))

        qk += tl.dot(q, tl.trans(k), out_dtype=tl.float16)
        # NOTE: minus_inf is a large negative constant defined at module scope in this
        # file (outside the lines shown in this hunk); it masks out-of-range keys.
        qk += tl.where((start_n + offs_n)[None, :] < N_CTX, 0, minus_inf)
        # -- compute scaling constant ---
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])
        # -- scale and update acc --
        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
        acc *= acc_scale[:, None]
        acc += tl.dot(p.to(tl.float16), v.to(tl.float16))
        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new

    # write back l and m
    acc = acc / l_i[:, None]
    o_offset = batch * stride_oz + head * BLOCK_DMODEL
    out_ptrs = Out + o_offset + (offs_m[:, None] * stride_on + offs_d[None, :])
    tl.store(out_ptrs, acc.to(tl.float16), mask=offs_m[:, None] < N_CTX)
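The loop above is the standard flash-attention online-softmax recurrence: per query block it keeps a running max m_i, a running denominator l_i, and an unnormalized accumulator acc, rescaling the old state whenever a new key block raises the max. A tiny PyTorch illustration of the same recurrence (not part of the commit; single head, fp32, no masking, helper name is made up):

import torch

def streaming_attention(q, k, v, block=32):
    # q, k, v: [seq, d]; process keys in blocks like the kernel's start_n loop.
    seq, d = q.shape
    scale = 1.0 / d**0.5
    m = torch.full((seq,), float("-inf"))   # running row max
    l = torch.zeros(seq)                    # running softmax denominator
    acc = torch.zeros(seq, d)               # unnormalized output accumulator
    for s in range(0, seq, block):
        scores = (q @ k[s:s + block].T) * scale            # [seq, block]
        m_new = torch.maximum(m, scores.max(dim=1).values)
        alpha = torch.exp(m - m_new)                       # rescale factor for old state
        p = torch.exp(scores - m_new[:, None])
        acc = acc * alpha[:, None] + p @ v[s:s + block]
        l = l * alpha + p.sum(dim=1)
        m = m_new
    return acc / l[:, None]

# Matches plain softmax attention up to floating-point error:
q, k, v = (torch.randn(128, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) / 64**0.5, dim=-1) @ v
torch.testing.assert_close(streaming_attention(q, k, v), ref, atol=1e-5, rtol=1e-4)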


def _triton_packed_flash(qkv, head_size, mask, sm_scale, causal=False, add_mask=True):
    heads = qkv.shape[-1] // 3 // head_size
    hidden_size = qkv.shape[-1] // 3

    BLOCK_M = 128
    BLOCK_N = 64 if head_size <= 64 else 32

    o = torch.empty((qkv.shape[0], qkv.shape[1], hidden_size), device=qkv.device, dtype=torch.half)
    if mask is None:
        mask = torch.empty(0)
        add_mask = False

    grid = (triton.cdiv(qkv.shape[1], BLOCK_M), qkv.shape[0] * heads, 1)
    num_stages = 4 if head_size <= 64 else 3
    num_warps = 4
    P_SEQ = 0

    _flash_packed_kernel[grid](qkv,
                               mask,
                               add_mask,
                               causal,
                               sm_scale,
                               o,
                               qkv.stride(0),
                               qkv.stride(1),
                               qkv.stride(2),
                               mask.stride(1) if add_mask else 0,
                               o.stride(0),
                               o.stride(1),
                               qkv.shape[0],
                               heads,
                               qkv.shape[1],
                               P_SEQ,
                               hidden_size,
                               BLOCK_M=BLOCK_M,
                               BLOCK_N=BLOCK_N,
                               BLOCK_DMODEL=head_size,
                               num_warps=num_warps,
                               num_stages=num_stages)

    return o
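A usage sketch for the packed-flash wrapper above, checked against an unfused PyTorch reference (not part of the commit; the packed layout and mask convention follow the unit test below, the tolerance is an assumption):

import math
import torch
from deepspeed.ops.transformer.inference.triton.attention import _triton_packed_flash

batch, seq_len, heads, head_size = 1, 128, 12, 64
sm_scale = 1.0 / math.sqrt(head_size)

# QKV packed along the last dim as [Q | K | V], each of width heads * head_size.
qkv = torch.randn((batch, seq_len, 3 * heads * head_size), dtype=torch.float16, device="cuda")
# Additive padding mask of shape [batch, 1, 1, seq_len]: 0 keeps a key, -65504.0 masks it out.
pad_mask = torch.zeros((batch, 1, 1, seq_len), dtype=torch.float16, device="cuda")

out = _triton_packed_flash(qkv, head_size, pad_mask, sm_scale, causal=False, add_mask=True)

# Unfused PyTorch reference for comparison.
q, k, v = qkv.view(batch, seq_len, 3, heads, head_size).permute(2, 0, 3, 1, 4)
scores = torch.matmul(q, k.transpose(-2, -1)) * sm_scale + pad_mask
probs = torch.softmax(scores.float(), dim=-1).half()
ref = torch.matmul(probs, v).permute(0, 2, 1, 3).reshape(batch, seq_len, heads * head_size)
torch.testing.assert_close(ref, out, atol=2e-2, rtol=0)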
11 changes: 11 additions & 0 deletions tests/unit/ops/transformer/inference/inference_test_utils.py
@@ -52,3 +52,14 @@ def assert_almost_equal(x, y, decimal=2, err_msg=''):
            y = y.float()
        y = y.cpu().detach().numpy()
    npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal)


def max_diff(a, b):
    a = a.to(torch.float32).flatten()
    b = b.to(torch.float32).flatten()
    diff = torch.abs(a - b)
    max_diff_indices = torch.argsort(diff)[-1]
    print("Max difference indices:", max_diff_indices)
    print("Max difference values:", diff[max_diff_indices])
    print(f"{a[max_diff_indices]} vs {b[max_diff_indices]}")
    return max_diff_indices
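An illustrative use of the new debugging helper (not part of the commit; the tensors are made up), printing the worst-offending element before the usual closeness assertion:

import torch
from .inference_test_utils import assert_almost_equal, max_diff

ref = torch.randn(4, 16, dtype=torch.float16)
out = ref + 1e-3 * torch.randn_like(ref)

idx = max_diff(ref, out)       # prints index and size of the largest absolute difference
assert_almost_equal(ref, out)  # decimal=2 tolerance, as defined above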
69 changes: 44 additions & 25 deletions tests/unit/ops/transformer/inference/test_attention.py
@@ -6,6 +6,7 @@
import pytest
import torch
import deepspeed
from deepspeed.accelerator import get_accelerator
from .inference_test_utils import assert_almost_equal


@@ -19,54 +20,72 @@ def ref_torch_attention(q, k, v, mask, sm_scale):

# test attention operator
@pytest.mark.inference_ops
-@pytest.mark.parametrize("Z", [1])  # batch
+@pytest.mark.parametrize("BATCH", [1])  # batch
@pytest.mark.parametrize("H", [12])  # heads
-@pytest.mark.parametrize("N_CTX", [4, 128])  # sequence length
+@pytest.mark.parametrize("N_CTX", [16, 128])  # sequence length
@pytest.mark.parametrize("D_HEAD", [64, 128])
@pytest.mark.parametrize("causal", [True, False])
-def test_attention(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
+@pytest.mark.parametrize("use_flash", [True, False])
+def test_attention(BATCH, H, N_CTX, D_HEAD, causal, use_flash, dtype=torch.float16):
    if not deepspeed.HAS_TRITON:
        pytest.skip("triton has to be installed for the test")

+   minus_inf = -65504.0
+
    # skip autotune in testing
    from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul
    fp16_matmul.skip_autotune()

-   from deepspeed.ops.transformer.inference.triton.attention import compute_attention
+   from deepspeed.ops.transformer.inference.triton.attention import _triton_attention, _triton_packed_flash
    torch.manual_seed(20)
-   q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)
-   k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)
-   v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)
+   q = torch.empty((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)
+   k = torch.empty((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)
+   v = torch.empty((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)
    sm_scale = 0.3

    # reference implementation
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    score = p
-   mask = torch.zeros((Z, H, N_CTX, N_CTX), dtype=dtype, device="cuda")
+   mask = torch.zeros((BATCH, H, N_CTX, N_CTX), dtype=dtype, device="cuda")
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    if causal:
-       for z in range(Z):
+       for z in range(BATCH):
            for h in range(H):
-               mask[:, :, M == 0] = float("-inf")
+               mask[:, :, M == 0] = minus_inf
    p = torch.softmax(p.float() + mask, dim=-1).half()
    softmax_out = p
    ref_out = torch.matmul(p, v)
    context = ref_out

    # adjust it to expected tensor format and run test
-   qkv = torch.randn((Z, N_CTX, 3 * H * D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
-   qkv[:, :, :H * D_HEAD] = q.permute(0, 2, 1, 3).contiguous().reshape((Z, N_CTX, H * D_HEAD))
-   qkv[:, :, 1 * H * D_HEAD:2 * H * D_HEAD] = k.permute(0, 2, 1, 3).contiguous().reshape((Z, N_CTX, H * D_HEAD))
-   qkv[:, :, 2 * H * D_HEAD:] = v.permute(0, 2, 1, 3).contiguous().reshape((Z, N_CTX, H * D_HEAD))
-   tri_out = compute_attention(qkv,
-                               input_mask=mask,
-                               layer_past=None,
-                               alibi=None,
-                               scale=sm_scale,
-                               head_size=D_HEAD,
-                               triangular=False,
-                               use_cuda_flash=False,
-                               use_triton_flash=False,
-                               use_ds_attention=False)
-   tri_out = tri_out.reshape((Z, N_CTX, H, D_HEAD)).permute(0, 2, 1, 3)
+   qkv = torch.randn((BATCH, N_CTX, 3 * H * D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
+   qkv[:, :, :H * D_HEAD] = q.permute(0, 2, 1, 3).contiguous().reshape((BATCH, N_CTX, H * D_HEAD))
+   qkv[:, :, 1 * H * D_HEAD:2 * H * D_HEAD] = k.permute(0, 2, 1, 3).contiguous().reshape((BATCH, N_CTX, H * D_HEAD))
+   qkv[:, :, 2 * H * D_HEAD:] = v.permute(0, 2, 1, 3).contiguous().reshape((BATCH, N_CTX, H * D_HEAD))
+
+   if use_flash:
+       if not get_accelerator().is_triton_supported():
+           pytest.skip("triton flash attention is supported when the compute capability > 8.0")
+       triton_mask = torch.zeros((BATCH, 1, 1, N_CTX), dtype=dtype, device="cuda")
+       if not causal:
+           lengths = torch.randint(N_CTX - 8, N_CTX, (BATCH, 1), device='cuda')
+           for i, l in enumerate(lengths):
+               triton_mask[i, ..., l:] = minus_inf
+           mask = torch.zeros((BATCH, H, N_CTX, N_CTX), dtype=dtype, device="cuda")
+           for b in range(BATCH):
+               mask[b, :, :, lengths[b]:] = minus_inf
+           ref_out = ref_torch_attention(q, k, v, mask, sm_scale)
+       tri_out = _triton_packed_flash(qkv, D_HEAD, triton_mask, sm_scale, causal=causal, add_mask=(not causal))
+   else:
+       tri_out = _triton_attention(qkv,
+                                   input_mask=mask,
+                                   layer_past=None,
+                                   alibi=None,
+                                   scale=sm_scale,
+                                   head_size=D_HEAD,
+                                   triangular=False,
+                                   use_cuda_flash=False,
+                                   use_triton_flash=False,
+                                   use_ds_attention=False)
+   tri_out = tri_out.reshape((BATCH, N_CTX, H, D_HEAD)).permute(0, 2, 1, 3)
    assert_almost_equal(ref_out, tri_out)
