From c5e48f49d8368216b4f99ef4023d2855f1ce3983 Mon Sep 17 00:00:00 2001 From: Omar Elayan <142979319+oelayan7@users.noreply.github.com> Date: Mon, 6 Jan 2025 20:54:57 +0200 Subject: [PATCH] Add fp8_gemm fallback for non-triton systems (#6916) - Removed try/except from __init__ file in fp_quantizer and added a single entry point instead - Renamed file fp8_gemm to fp8_gemm_triton, and the function matmul_fp8 to matmul_fp8_triton - Added a new entry point fp8_gemm with matmul_fp8 inside, and if the system supports triton it calls the triton implementation and if not it calls the fallback Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/ops/fp_quantizer/__init__.py | 7 +- deepspeed/ops/fp_quantizer/fp8_gemm.py | 163 +---------------- deepspeed/ops/fp_quantizer/fp8_gemm_triton.py | 171 ++++++++++++++++++ tests/unit/ops/fp_quantizer/test_fp8_gemm.py | 16 +- 4 files changed, 189 insertions(+), 168 deletions(-) create mode 100644 deepspeed/ops/fp_quantizer/fp8_gemm_triton.py diff --git a/deepspeed/ops/fp_quantizer/__init__.py b/deepspeed/ops/fp_quantizer/__init__.py index 51377bc6092c..f9cf23373c26 100644 --- a/deepspeed/ops/fp_quantizer/__init__.py +++ b/deepspeed/ops/fp_quantizer/__init__.py @@ -4,9 +4,4 @@ # DeepSpeed Team from .quantize import FP_Quantize, Quantizer - -try: - import triton - from .fp8_gemm import matmul_fp8 -except ImportError: - pass +from .fp8_gemm import matmul_fp8 diff --git a/deepspeed/ops/fp_quantizer/fp8_gemm.py b/deepspeed/ops/fp_quantizer/fp8_gemm.py index 55504e3af8c9..db4fa5ae2c92 100644 --- a/deepspeed/ops/fp_quantizer/fp8_gemm.py +++ b/deepspeed/ops/fp_quantizer/fp8_gemm.py @@ -11,161 +11,18 @@ ################################### import torch -import triton -import triton.language as tl -@triton.jit -def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, - stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, - quantization_group_size: tl.constexpr): - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m +def matmul_fp8(inp, weight, scale, quantization_group_size, quantizer): + from deepspeed import get_accelerator - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) + if not get_accelerator().is_triton_supported(): + return matmul_fp8_fallback(inp, weight, scale, quantization_group_size, quantizer) + else: + # Import dynamically to prevent failures on systems without triton. + from .fp8_gemm_triton import matmul_fp8_triton + return matmul_fp8_triton(inp, weight, scale, quantization_group_size) - inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( - (pid_n * BLOCK_SIZE_N) // quantization_group_size) - weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) - scale = tl.load(scale_ptr + weight_ptrs_offset) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - # Dequantize weight (fp8 -> bf16) - w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16) - w = (w + 0x3C00).to(tl.uint16) - w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16) - - inp_data += BLOCK_SIZE_K * stride_ak - weight_data += BLOCK_SIZE_K * stride_bk - weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K - weight = tl.load(weight_data, mask=weight_mask, other=0.0) - scale = tl.load(scale_ptr + (weight_ptrs_offset + - (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)), - mask=weight_mask, - other=0.0) - - accumulator += tl.dot(inp, w) - - out = accumulator.to(tl.bfloat16) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) - - -@triton.jit -def matmul_kernel_fp8_fp16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, - stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, - quantization_group_size: tl.constexpr): - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) - num_pid_in_group = GROUP_SIZE_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_SIZE_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) - pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) - - inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( - (pid_n * BLOCK_SIZE_N) // quantization_group_size) - - weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) - scale = tl.load(scale_ptr + weight_ptrs_offset) - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): - inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) - # Dequantize weight (fp8 -> fp16) - w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 7)).to(tl.uint16) - w = (w + 0x2000).to(tl.uint16) - w = (w.to(tl.float16, bitcast=True) * scale).to(tl.float16) - - inp_data += BLOCK_SIZE_K * stride_ak - weight_data += BLOCK_SIZE_K * stride_bk - - weight = tl.load(weight_data, mask=offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0) - scale = tl.load(scale_ptr + (weight_ptrs_offset + - (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size))) - - accumulator += tl.dot(inp, w) - - out = accumulator.to(tl.float16) - - offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] - tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) - - -def matmul_fp8(inp, weight, scale, quantization_group_size): - - assert inp.shape[1] == weight.shape[0], \ - f"Incompatible dimensions (input: {inp.shape}, weight: {weight.shape})" - - M, K = inp.shape - K, N = weight.shape - - out = torch.empty((M, N), device=inp.device, dtype=inp.dtype) - - # GEMM tuning parameters! - # TODO: Add a more configurable tuning for selecting the best GeMM - BLOCK_SIZE_M = 16 if M <= 16 else 32 if M <= 32 else 64 if M <= 64 else 128 - BLOCK_SIZE_N = 64 - BLOCK_SIZE_K = max(64, quantization_group_size) - GROUP_SIZE_M = 8 - num_stages = 4 - num_warps = 4 - if M >= 256: - BLOCK_SIZE_M = 256 - BLOCK_SIZE_N = 128 - BLOCK_SIZE_K = max(128, quantization_group_size) - num_stages = 3 - num_warps = 8 - - grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) - kernel = matmul_kernel_fp8_bf16 if inp.dtype == torch.bfloat16 else matmul_kernel_fp8_fp16 - kernel[grid](inp, - weight, - out, - scale, - M, - N, - K, - inp.stride(0), - inp.stride(1), - weight.stride(0), - weight.stride(1), - out.stride(0), - out.stride(1), - quantization_group_size=quantization_group_size, - BLOCK_SIZE_M=BLOCK_SIZE_M, - BLOCK_SIZE_N=BLOCK_SIZE_N, - BLOCK_SIZE_K=BLOCK_SIZE_K, - GROUP_SIZE_M=GROUP_SIZE_M, - num_stages=num_stages, - num_warps=num_warps) - return out +def matmul_fp8_fallback(inp, weight, scale, quantization_group_size, quantizer): + return torch.matmul(inp, quantizer.dequantize(weight, scale=scale)) diff --git a/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py b/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py new file mode 100644 index 000000000000..746e217d4194 --- /dev/null +++ b/deepspeed/ops/fp_quantizer/fp8_gemm_triton.py @@ -0,0 +1,171 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +######## Fused MoE kernel ######### +# These kernels are implemented for +# fusing GeMM with dequantization of +# fp8 weight data when using bit-16 +# activation. +################################### + +import torch +import triton +import triton.language as tl + + +@triton.jit +def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, + stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + quantization_group_size: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( + (pid_n * BLOCK_SIZE_N) // quantization_group_size) + + weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) + scale = tl.load(scale_ptr + weight_ptrs_offset) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + # Dequantize weight (fp8 -> bf16) + w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16) + w = (w + 0x3C00).to(tl.uint16) + w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16) + + inp_data += BLOCK_SIZE_K * stride_ak + weight_data += BLOCK_SIZE_K * stride_bk + weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K + weight = tl.load(weight_data, mask=weight_mask, other=0.0) + scale = tl.load(scale_ptr + (weight_ptrs_offset + + (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)), + mask=weight_mask, + other=0.0) + + accumulator += tl.dot(inp, w) + + out = accumulator.to(tl.bfloat16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) + + +@triton.jit +def matmul_kernel_fp8_fp16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk, + stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, + quantization_group_size: tl.constexpr): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ( + (pid_n * BLOCK_SIZE_N) // quantization_group_size) + + weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0) + scale = tl.load(scale_ptr + weight_ptrs_offset) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + # Dequantize weight (fp8 -> fp16) + w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 7)).to(tl.uint16) + w = (w + 0x2000).to(tl.uint16) + w = (w.to(tl.float16, bitcast=True) * scale).to(tl.float16) + + inp_data += BLOCK_SIZE_K * stride_ak + weight_data += BLOCK_SIZE_K * stride_bk + + weight = tl.load(weight_data, mask=offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0) + scale = tl.load(scale_ptr + (weight_ptrs_offset + + (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size))) + + accumulator += tl.dot(inp, w) + + out = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N)) + + +def matmul_fp8_triton(inp, weight, scale, quantization_group_size): + + assert inp.shape[1] == weight.shape[0], \ + f"Incompatible dimensions (input: {inp.shape}, weight: {weight.shape})" + + M, K = inp.shape + K, N = weight.shape + + out = torch.empty((M, N), device=inp.device, dtype=inp.dtype) + + # GEMM tuning parameters! + # TODO: Add a more configurable tuning for selecting the best GeMM + BLOCK_SIZE_M = 16 if M <= 16 else 32 if M <= 32 else 64 if M <= 64 else 128 + BLOCK_SIZE_N = 64 + BLOCK_SIZE_K = max(64, quantization_group_size) + GROUP_SIZE_M = 8 + num_stages = 4 + num_warps = 4 + if M >= 256: + BLOCK_SIZE_M = 256 + BLOCK_SIZE_N = 128 + BLOCK_SIZE_K = max(128, quantization_group_size) + num_stages = 3 + num_warps = 8 + + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + kernel = matmul_kernel_fp8_bf16 if inp.dtype == torch.bfloat16 else matmul_kernel_fp8_fp16 + kernel[grid](inp, + weight, + out, + scale, + M, + N, + K, + inp.stride(0), + inp.stride(1), + weight.stride(0), + weight.stride(1), + out.stride(0), + out.stride(1), + quantization_group_size=quantization_group_size, + BLOCK_SIZE_M=BLOCK_SIZE_M, + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + GROUP_SIZE_M=GROUP_SIZE_M, + num_stages=num_stages, + num_warps=num_warps) + return out diff --git a/tests/unit/ops/fp_quantizer/test_fp8_gemm.py b/tests/unit/ops/fp_quantizer/test_fp8_gemm.py index d66f7c8cb4cc..a4cf579f5943 100644 --- a/tests/unit/ops/fp_quantizer/test_fp8_gemm.py +++ b/tests/unit/ops/fp_quantizer/test_fp8_gemm.py @@ -14,6 +14,8 @@ from deepspeed.ops.fp_quantizer import FP_Quantize, matmul_fp8 +from deepspeed import get_accelerator + @pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"]) @pytest.mark.parametrize("q_bits", [8], ids=[ @@ -21,23 +23,19 @@ ]) @pytest.mark.parametrize("M", [1, 2, 4, 8, 32, 64, 128, 256, 512, 1024, 2048]) def test_fp_quant(dtype, q_bits, M): + device_name = get_accelerator().device_name() quantization_group_size = 128 fpq = FP_Quantize(group_size=quantization_group_size) N = 8192 H = 4096 - x = torch.randn(M, H, dtype=dtype, device='cuda') - weight_bf16 = torch.randn(H, N, dtype=dtype, device='cuda') + x = torch.randn(M, H, dtype=dtype, device=device_name) + weight_bf16 = torch.randn(H, N, dtype=dtype, device=device_name) - weight, _ = fpq.quantize(weight_bf16.data, q_bits=8, return_meta_tensor=True) + weight, _ = fpq.quantize(weight_bf16.data, q_bits=q_bits, return_meta_tensor=True) scale = fpq.get_scales() - out = matmul_fp8( - x, - weight, - scale, - quantization_group_size, - ) + out = matmul_fp8(x, weight, scale, quantization_group_size, fpq) out_q = torch.matmul(x, fpq.dequantize(weight, scale=fpq.scale))