diff --git a/accelerator/abstract_accelerator.py b/accelerator/abstract_accelerator.py
index fe0e66768d45..a87ff3c1d223 100644
--- a/accelerator/abstract_accelerator.py
+++ b/accelerator/abstract_accelerator.py
@@ -185,6 +185,10 @@ def lazy_call(self, callback):
     def communication_backend_name(self):
         ...
 
+    @abc.abstractmethod
+    def is_triton_supported(self):
+        ...
+
     # Tensor operations
     @property
     @abc.abstractmethod
diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py
index 11518d31e069..4de4ad93c2bb 100644
--- a/accelerator/cpu_accelerator.py
+++ b/accelerator/cpu_accelerator.py
@@ -182,6 +182,9 @@ def lazy_call(self, callback):
     def communication_backend_name(self):
         return self._communication_backend_name
 
+    def is_triton_supported(self):
+        return False
+
     # Data types
     def is_bf16_supported(self):
         return True
diff --git a/accelerator/cuda_accelerator.py b/accelerator/cuda_accelerator.py
index 9c1e0d22785e..045cce510a90 100644
--- a/accelerator/cuda_accelerator.py
+++ b/accelerator/cuda_accelerator.py
@@ -173,6 +173,13 @@ def lazy_call(self, callback):
     def communication_backend_name(self):
         return self._communication_backend_name
 
+    def is_triton_supported(self):
+        major, _ = torch.cuda.get_device_capability()
+        if major >= 8:
+            return True
+        else:
+            return False
+
     # Tensor operations
 
     @property
diff --git a/accelerator/mps_accelerator.py b/accelerator/mps_accelerator.py
index 8007a50c4bb6..63a92f250898 100644
--- a/accelerator/mps_accelerator.py
+++ b/accelerator/mps_accelerator.py
@@ -160,6 +160,9 @@ def lazy_call(self, callback):
     def communication_backend_name(self):
         return self._communication_backend_name
 
+    def is_triton_supported(self):
+        return False
+
     # Tensor operations
     @property
     def BFloat16Tensor(self):
diff --git a/accelerator/npu_accelerator.py b/accelerator/npu_accelerator.py
index 5678a0266386..206bc1dfaa1b 100644
--- a/accelerator/npu_accelerator.py
+++ b/accelerator/npu_accelerator.py
@@ -158,6 +158,9 @@ def lazy_call(self, callback):
     def communication_backend_name(self):
         return self._communication_backend_name
 
+    def is_triton_supported(self):
+        return False
+
     # Tensor operations
 
     @property
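Taken together, the accelerator changes define a single capability query: only the CUDA accelerator reports Triton support, and only on compute capability 8.0 (Ampere) or newer; the CPU, MPS and NPU accelerators always return False. A minimal sketch of how a caller is expected to consume the accessor (the helper name below is illustrative, not part of this diff):

    from deepspeed.accelerator import get_accelerator

    def can_use_triton_flash() -> bool:
        # True only on the CUDA accelerator with compute capability >= 8.0;
        # the CPU, MPS and NPU accelerators all report False.
        return get_accelerator().is_triton_supported()
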
diff --git a/deepspeed/ops/transformer/inference/triton/attention.py b/deepspeed/ops/transformer/inference/triton/attention.py
index dbccaf5bd470..c05370ec74e5 100644
--- a/deepspeed/ops/transformer/inference/triton/attention.py
+++ b/deepspeed/ops/transformer/inference/triton/attention.py
@@ -6,6 +6,8 @@
 import math
 import torch
 import torch.nn as nn
+import triton
+import triton.language as tl
 from deepspeed.accelerator import get_accelerator
 from deepspeed import comm as dist
 from deepspeed.ops.transformer.inference.op_binding import LinearOp, VectorMatMulOp, SoftmaxContextOp, QKVGemmOp
@@ -70,6 +72,9 @@ def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count
 
         self.mp_group = mp_group
         self.use_flash = False
+        # triton flash attention is enabled when the compute capability >= 8.0
+        if get_accelerator().is_triton_supported():
+            self.use_flash = True
 
         # used for quantization
         self.q_scales = q_scales
@@ -176,7 +181,7 @@ def forward(
         qkv = qkv_out[0]
 
         if use_triton_attention and (alibi is None):
-            context_layer = compute_attention(qkv=qkv,
+            context_layer = _triton_attention(qkv=qkv,
                                               input_mask=input_mask,
                                               scale=self.scale,
                                               layer_past=layer_past,
@@ -204,7 +209,7 @@ def forward(
 global inference_module
 
 
-def compute_attention(qkv,
+def _triton_attention(qkv,
                       input_mask,
                       layer_past,
                       alibi,
@@ -217,13 +222,166 @@
     if isinstance(qkv, list):
         qkv = qkv[0]
 
-    #assert layer_past is None, "layer_past not supported in triton yet"
     assert alibi is None, "layer_past not supported in alibi yet"
-    output = score_4d_matmul(qkv, head_size, triangular, scale)
-    if triangular:
-        output = softmax(output)
+
+    if use_triton_flash:
+        output = _triton_packed_flash(qkv,
+                                      head_size,
+                                      input_mask,
+                                      scale,
+                                      causal=triangular,
+                                      add_mask=(not triangular and input_mask is not None))
     else:
-        output = softmax(output, input_mask)
-    output = context_4d_matmul(output, qkv, head_size)
+        output = score_4d_matmul(qkv, head_size, triangular, scale)
+        if triangular:
+            output = softmax(output)
+        else:
+            output = softmax(output, input_mask)
+        output = context_4d_matmul(output, qkv, head_size)
 
     return output
+
+
+'''
+flash attention 2
+modified the triton kernel in
+https://github.com/openai/triton/blob/08c16589573621fcb8cd5a9c3b8a0537077f876d/python/tutorials/06-fused-attention.py
+'''
+
+
+@triton.jit
+def _flash_packed_kernel(
+    QKV,
+    mask,
+    ADD_MASK: tl.constexpr,
+    IS_CAUSAL: tl.constexpr,
+    sm_scale,
+    Out,
+    stride_qz,
+    stride_qn,
+    stride_qm,
+    stride_mz,
+    stride_oz,
+    stride_on,
+    Z,
+    H,
+    N_CTX,
+    P_SEQ,
+    hidden_size,
+    BLOCK_M: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    start_m = tl.program_id(0)
+    off_hz = tl.program_id(1)
+    batch = off_hz // H
+    head = off_hz % H
+
+    q_offset = batch * stride_qz + head * BLOCK_DMODEL
+    k_offset = q_offset + hidden_size
+    v_offset = k_offset + hidden_size
+
+    # initialize offsets
+    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_n = tl.arange(0, BLOCK_N)
+    offs_d = tl.arange(0, BLOCK_DMODEL)
+
+    q_ptrs = QKV + q_offset + offs_m[:, None] * stride_qn + offs_d[None, :]
+    k_ptrs = QKV + hidden_size + q_offset + offs_n[:, None] * stride_qn + offs_d[None, :]
+    v_ptrs = QKV + 2 * hidden_size + q_offset + offs_n[:, None] * stride_qn + offs_d[None, :]
+
+    # mask
+    off_mask = batch * stride_mz + offs_n[None, :]
+    mask_ptrs = mask + off_mask
+
+    # initialize pointer to m and l
+    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
+    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    # scale sm_scale by log_2(e) and use
+    # 2^x instead of exp in the loop because CSE and LICM
+    # don't work as expected with `exp` in the loop
+    qk_scale = sm_scale * 1.44269504
+    # load q: it will stay in SRAM throughout
+    q = tl.load(q_ptrs, mask=offs_m[:, None] < N_CTX, other=0.0)
+    q = (q * qk_scale).to(tl.float16)
+    # loop over k, v and update accumulator
+    lo = 0
+    hi = P_SEQ + (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + P_SEQ
+    for start_n in range(lo, hi, BLOCK_N):
+        # -- load k, v --
+        k = tl.load(k_ptrs + start_n * stride_qn, mask=(start_n + offs_n)[:, None] < N_CTX, other=0.0)
+        v = tl.load(v_ptrs + start_n * stride_qn, mask=(start_n + offs_n)[:, None] < N_CTX, other=0.0)
+        # -- compute qk ---
+        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)
+
+        if ADD_MASK:
+            mask_val = tl.load(mask_ptrs)
+            mask_ptrs += BLOCK_N
+            qk = qk + mask_val.to(tl.float32)
+
+        if IS_CAUSAL:
+            qk = tl.where(P_SEQ + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
+
+        qk += tl.dot(q, tl.trans(k), out_dtype=tl.float16)
+        qk += tl.where((start_n + offs_n)[None, :] < N_CTX, 0, minus_inf)
+        # -- compute scaling constant ---
+        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
+        alpha = tl.math.exp2(m_i - m_i_new)
+        p = tl.math.exp2(qk - m_i_new[:, None])
+        # -- scale and update acc --
+        acc_scale = l_i * 0 + alpha  # workaround some compiler bug
+        acc *= acc_scale[:, None]
+        acc += tl.dot(p.to(tl.float16), v.to(tl.float16))
+        # -- update m_i and l_i --
+        l_i = l_i * alpha + tl.sum(p, 1)
+        m_i = m_i_new
+
+    # write back l and m
+    acc = acc / l_i[:, None]
+    o_offset = batch * stride_oz + head * BLOCK_DMODEL
+    out_ptrs = Out + o_offset + (offs_m[:, None] * stride_on + offs_d[None, :])
+    tl.store(out_ptrs, acc.to(tl.float16), mask=offs_m[:, None] < N_CTX)
+
+
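+# _triton_packed_flash expects qkv packed as [q | k | v] along the last dimension,
+# i.e. a [batch, seq_len, 3 * heads * head_size] tensor, together with an additive
+# fp16 mask of shape [batch, 1, 1, seq_len] that is only read when add_mask=True.
+# It returns the attention output as a [batch, seq_len, heads * head_size] fp16 tensor.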
+def _triton_packed_flash(qkv, head_size, mask, sm_scale, causal=False, add_mask=True):
+    heads = qkv.shape[-1] // 3 // head_size
+    hidden_size = qkv.shape[-1] // 3
+
+    BLOCK_M = 128
+    BLOCK_N = 64 if head_size <= 64 else 32
+
+    o = torch.empty((qkv.shape[0], qkv.shape[1], hidden_size), device=qkv.device, dtype=torch.half)
+    if mask is None:
+        mask = torch.empty(0)
+        add_mask = False
+
+    grid = (triton.cdiv(qkv.shape[1], BLOCK_M), qkv.shape[0] * heads, 1)
+    num_stages = 4 if head_size <= 64 else 3
+    num_warps = 4
+    P_SEQ = 0
+
+    _flash_packed_kernel[grid](qkv,
+                               mask,
+                               add_mask,
+                               causal,
+                               sm_scale,
+                               o,
+                               qkv.stride(0),
+                               qkv.stride(1),
+                               qkv.stride(2),
+                               mask.stride(1) if add_mask else 0,
+                               o.stride(0),
+                               o.stride(1),
+                               qkv.shape[0],
+                               heads,
+                               qkv.shape[1],
+                               P_SEQ,
+                               hidden_size,
+                               BLOCK_M=BLOCK_M,
+                               BLOCK_N=BLOCK_N,
+                               BLOCK_DMODEL=head_size,
+                               num_warps=num_warps,
+                               num_stages=num_stages)
+
+    return o
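The kernel added above performs the standard flash-attention online softmax, carried out in base 2 (q is pre-scaled by sm_scale * log2(e), so exp2 of the scores equals the usual exponential of the scaled scores). The short PyTorch sketch below spells out the same per-tile recurrence for a single head; it is a readable specification of the bookkeeping, not code from this diff:

    import torch

    def streaming_softmax_attention(q, k, v, tile=64):
        # q, k, v: [seq, head_size] float32; returns softmax(q @ k.T) @ v computed tile by tile
        seq = q.shape[0]
        m = torch.full((seq, 1), float("-inf"))   # running row-wise max of the scores
        l = torch.zeros((seq, 1))                 # running row-wise sum of exponentials
        acc = torch.zeros_like(q)                 # running unnormalized output
        for start in range(0, seq, tile):
            s = q @ k[start:start + tile].T                                # scores for this K/V tile
            m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
            alpha = torch.exp(m - m_new)                                   # rescales the old partial sums
            p = torch.exp(s - m_new)
            acc = alpha * acc + p @ v[start:start + tile]
            l = alpha * l + p.sum(dim=-1, keepdim=True)
            m = m_new
        return acc / l
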
diff --git a/tests/unit/ops/transformer/inference/inference_test_utils.py b/tests/unit/ops/transformer/inference/inference_test_utils.py
index fb9518f6a39c..9c7b428c0e68 100644
--- a/tests/unit/ops/transformer/inference/inference_test_utils.py
+++ b/tests/unit/ops/transformer/inference/inference_test_utils.py
@@ -52,3 +52,14 @@ def assert_almost_equal(x, y, decimal=2, err_msg=''):
             y = y.float()
         y = y.cpu().detach().numpy()
     npt.assert_array_almost_equal(x, y, err_msg=err_msg, decimal=decimal)
+
+
+def max_diff(a, b):
+    a = a.to(torch.float32).flatten()
+    b = b.to(torch.float32).flatten()
+    diff = torch.abs(a - b)
+    max_diff_indices = torch.argsort(diff)[-1]
+    print("Max difference indices:", max_diff_indices)
+    print("Max difference values:", diff[max_diff_indices])
+    print(f"{a[max_diff_indices]} vs {b[max_diff_indices]}")
+    return max_diff_indices
diff --git a/tests/unit/ops/transformer/inference/test_attention.py b/tests/unit/ops/transformer/inference/test_attention.py
index db4221305a51..13abe8b915c7 100644
--- a/tests/unit/ops/transformer/inference/test_attention.py
+++ b/tests/unit/ops/transformer/inference/test_attention.py
@@ -6,6 +6,7 @@
 import pytest
 import torch
 import deepspeed
+from deepspeed.accelerator import get_accelerator
 
 from .inference_test_utils import assert_almost_equal
 
@@ -19,54 +20,72 @@ def ref_torch_attention(q, k, v, mask, sm_scale):
 
 # test attention operator
 @pytest.mark.inference_ops
-@pytest.mark.parametrize("Z", [1])  # batch
+@pytest.mark.parametrize("BATCH", [1])  # batch
 @pytest.mark.parametrize("H", [12])  # heads
-@pytest.mark.parametrize("N_CTX", [4, 128])  # sequence length
+@pytest.mark.parametrize("N_CTX", [16, 128])  # sequence length
 @pytest.mark.parametrize("D_HEAD", [64, 128])
 @pytest.mark.parametrize("causal", [True, False])
-def test_attention(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
+@pytest.mark.parametrize("use_flash", [True, False])
+def test_attention(BATCH, H, N_CTX, D_HEAD, causal, use_flash, dtype=torch.float16):
     if not deepspeed.HAS_TRITON:
         pytest.skip("triton has to be installed for the test")
 
+    minus_inf = -65504.0
+
     # skip autotune in testing
     from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul
     fp16_matmul.skip_autotune()
 
-    from deepspeed.ops.transformer.inference.triton.attention import compute_attention
+    from deepspeed.ops.transformer.inference.triton.attention import _triton_attention, _triton_packed_flash
     torch.manual_seed(20)
-    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)
-    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)
-    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)
+    q = torch.empty((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)
+    k = torch.empty((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)
+    v = torch.empty((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)
     sm_scale = 0.3
 
     # reference implementation
     p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
     score = p
-    mask = torch.zeros((Z, H, N_CTX, N_CTX), dtype=dtype, device="cuda")
+    mask = torch.zeros((BATCH, H, N_CTX, N_CTX), dtype=dtype, device="cuda")
     M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
     if causal:
-        for z in range(Z):
+        for z in range(BATCH):
             for h in range(H):
-                mask[:, :, M == 0] = float("-inf")
+                mask[:, :, M == 0] = minus_inf
     p = torch.softmax(p.float() + mask, dim=-1).half()
     softmax_out = p
     ref_out = torch.matmul(p, v)
     context = ref_out
 
     # adjust it to expected tensor format and run test
-    qkv = torch.randn((Z, N_CTX, 3 * H * D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
-    qkv[:, :, :H * D_HEAD] = q.permute(0, 2, 1, 3).contiguous().reshape((Z, N_CTX, H * D_HEAD))
-    qkv[:, :, 1 * H * D_HEAD:2 * H * D_HEAD] = k.permute(0, 2, 1, 3).contiguous().reshape((Z, N_CTX, H * D_HEAD))
-    qkv[:, :, 2 * H * D_HEAD:] = v.permute(0, 2, 1, 3).contiguous().reshape((Z, N_CTX, H * D_HEAD))
-    tri_out = compute_attention(qkv,
-                                input_mask=mask,
-                                layer_past=None,
-                                alibi=None,
-                                scale=sm_scale,
-                                head_size=D_HEAD,
-                                triangular=False,
-                                use_cuda_flash=False,
-                                use_triton_flash=False,
-                                use_ds_attention=False)
-    tri_out = tri_out.reshape((Z, N_CTX, H, D_HEAD)).permute(0, 2, 1, 3)
+    qkv = torch.randn((BATCH, N_CTX, 3 * H * D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
+    qkv[:, :, :H * D_HEAD] = q.permute(0, 2, 1, 3).contiguous().reshape((BATCH, N_CTX, H * D_HEAD))
+    qkv[:, :, 1 * H * D_HEAD:2 * H * D_HEAD] = k.permute(0, 2, 1, 3).contiguous().reshape((BATCH, N_CTX, H * D_HEAD))
+    qkv[:, :, 2 * H * D_HEAD:] = v.permute(0, 2, 1, 3).contiguous().reshape((BATCH, N_CTX, H * D_HEAD))
+
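+    # Flash path: compare _triton_packed_flash against the torch reference, adding an
+    # additive [BATCH, 1, 1, N_CTX] padding mask (minus_inf on padded positions) in the
+    # non-causal case. Non-flash path: exercise the original score_4d_matmul / softmax /
+    # context_4d_matmul kernels through _triton_attention.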
+    if use_flash:
+        if not get_accelerator().is_triton_supported():
+            pytest.skip("triton flash attention is supported when the compute capability >= 8.0")
+        triton_mask = torch.zeros((BATCH, 1, 1, N_CTX), dtype=dtype, device="cuda")
+        if not causal:
+            lengths = torch.randint(N_CTX - 8, N_CTX, (BATCH, 1), device='cuda')
+            for i, l in enumerate(lengths):
+                triton_mask[i, ..., l:] = minus_inf
+            mask = torch.zeros((BATCH, H, N_CTX, N_CTX), dtype=dtype, device="cuda")
+            for b in range(BATCH):
+                mask[b, :, :, lengths[b]:] = minus_inf
+            ref_out = ref_torch_attention(q, k, v, mask, sm_scale)
+        tri_out = _triton_packed_flash(qkv, D_HEAD, triton_mask, sm_scale, causal=causal, add_mask=(not causal))
+    else:
+        tri_out = _triton_attention(qkv,
+                                    input_mask=mask,
+                                    layer_past=None,
+                                    alibi=None,
+                                    scale=sm_scale,
+                                    head_size=D_HEAD,
+                                    triangular=False,
+                                    use_cuda_flash=False,
+                                    use_triton_flash=False,
+                                    use_ds_attention=False)
+    tri_out = tri_out.reshape((BATCH, N_CTX, H, D_HEAD)).permute(0, 2, 1, 3)
     assert_almost_equal(ref_out, tri_out)
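For reference, a minimal, hypothetical usage sketch of the new packed wrapper, mirroring the flash branch of the test above; it assumes a CUDA device with compute capability >= 8.0 and triton installed, and the tensor names and sizes are illustrative:

    import torch
    from deepspeed.ops.transformer.inference.triton.attention import _triton_packed_flash

    batch, seq, heads, head_size = 1, 128, 12, 64
    # qkv packed as [q | k | v] along the last dimension: [batch, seq, 3 * heads * head_size]
    qkv = torch.randn((batch, seq, 3 * heads * head_size), dtype=torch.float16, device="cuda")
    # additive padding mask of shape [batch, 1, 1, seq]; all zeros means nothing is masked out
    triton_mask = torch.zeros((batch, 1, 1, seq), dtype=torch.float16, device="cuda")
    sm_scale = 1.0 / head_size**0.5

    # causal self-attention over the packed QKV; output is [batch, seq, heads * head_size], fp16
    out = _triton_packed_flash(qkv, head_size, triton_mask, sm_scale, causal=True, add_mask=False)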