Add the fp8-quantized GeMM for dense linear layers #18

Open
wants to merge 9 commits into base: pp-staging
5 changes: 5 additions & 0 deletions vllm/model_executor/layers/fused_fp8/__init__.py
@@ -0,0 +1,5 @@
from vllm.model_executor.layers.fused_fp8.fused_fp8_gemm import matmul_fp8

__all__ = [
"matmul_fp8",
]
114 changes: 114 additions & 0 deletions vllm/model_executor/layers/fused_fp8/fused_fp8_gemm.py
@@ -0,0 +1,114 @@
import torch
import triton
import triton.language as tl

def get_autotune_config():
    return [
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),
        # Good config for fp8 inputs.
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),
    ]

@triton.autotune(
    configs=get_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel_fp8(
        a_ptr, b_ptr, c_ptr, scale_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn,
        stride_cm, stride_cn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr,
        quantization_group_size: tl.constexpr,
):
    # Map the 1-D program id onto a grouped 2-D tile grid; grouped ordering
    # improves L2 reuse across the M dimension.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # Pointers to the first K-block of A and B for this output tile.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Offset of the quantization scale for each output column of B.
    b_ptrs_offset = offs_bn[None, :] * (stride_bn // quantization_group_size)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        scale = tl.load(scale_ptr + b_ptrs_offset)
        # Dequantize weight (fp8 -> bf16): move the sign bit to bit 15, place the
        # exponent/mantissa payload into the bf16 fields, rebias the exponent
        # (0x3C00 adds 120 to the bf16 exponent field), then apply the per-group
        # quantization scale.
        b = ((b & 0x80) << 8) | ((b & 0x7f) << 4)
        b = (b + 0x3C00).to(tl.uint16)
        b = (b.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16)

        accumulator = tl.dot(a, b, accumulator)

        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    c = accumulator.to(tl.float16)

    # Write back the output tile with bounds masking.
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)

def matmul_fp8(a, b, scale, quantization_group_size):
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    # Launch one program per output tile; the autotuner picks the block sizes.
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel_fp8[grid](
        a, b, c, scale,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        quantization_group_size=quantization_group_size,
    )
    return c
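
For orientation, here is a minimal, hypothetical invocation sketch, not part of the PR. It only illustrates the expected shapes and dtypes: w_fp8 stands in for the uint8-packed fp8 weight laid out as (K, N) so that a.shape[1] == b.shape[0] holds, and scales / group_size are placeholders for the values produced by DeepSpeed's FP_Quantize; in the PR the real arguments come from layer.weight, layer.weight.quantization_scales(), and layer.weight.fp_quantizer.group_size.

import torch
from vllm.model_executor.layers.fused_fp8 import matmul_fp8

M, K, N = 16, 4096, 11008           # tokens, in_features, out_features (assumed)
group_size = 512                    # assumed quantization group size

x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)        # activations
w_fp8 = torch.randint(0, 256, (K, N), device="cuda",
                      dtype=torch.uint8)                           # placeholder fp8 codes
scales = torch.ones(N * (K // group_size), device="cuda",
                    dtype=torch.bfloat16)                          # placeholder per-group scales

y = matmul_fp8(x, w_fp8, scales, group_size)                       # (M, N), fp16 output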
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/linear.py
@@ -159,7 +159,7 @@ def __init__(
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                         quant_config)
                         None)#quant_config)

        # All the linear layer supports quant method.
        assert self.quant_method is not None
28 changes: 22 additions & 6 deletions vllm/model_executor/layers/quantization/deepspeedfp.py
@@ -10,6 +10,8 @@
from vllm.model_executor.utils import set_weight_attrs
import gc

from vllm.model_executor.layers.fused_fp8 import matmul_fp8

class DeepSpeedFPConfig(QuantizationConfig):
"""Config for DeepSpeed FP quantizer. It supports fp6 and fp8.

@@ -84,9 +86,10 @@ class DeepSpeedFPLinearMethod(LinearMethodBase):
        quant_config: the DeepSpeedFP quantization config.
    """

    def __init__(self, quant_config: DeepSpeedFPConfig):
    def __init__(self, quant_config: DeepSpeedFPConfig, enable_fused_kernel=True):
        self.quant_config = quant_config
        self.weight = None
        self.enable_fused_kernel = enable_fused_kernel

    def create_weights(self,
                       layer: torch.nn.Module,
@@ -96,6 +99,7 @@ def create_weights(self,
                       output_size: int,
                       params_dtype: torch.dtype,
                       weight_loader=None,
                       transposed=True,
                       **extra_weight_attrs):
        del output_size
        del input_size
@@ -104,6 +108,7 @@ def create_weights(self,
            torch.Size((output_size_per_partition, input_size_per_partition)),
            params_dtype=params_dtype,
            quant_config=self.quant_config,
            transposed=transposed
        )
        set_weight_attrs(weight, {
            "input_dim": 1,
@@ -136,9 +141,16 @@ def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        weight = layer.weight
        y = weight.ds_dequantize()
        return F.linear(x, y, bias)
        if self.enable_fused_kernel:
            return matmul_fp8(
                x, layer.weight,
                layer.weight.quantization_scales(),
                layer.weight.fp_quantizer.group_size,
            )
        else:
            weight = layer.weight
            y = weight.ds_dequantize()
            return F.linear(x, y, bias)
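
A simple parity check could exercise both branches of apply() on the same input. The sketch below is hypothetical (check_fused_vs_fallback is not part of the PR) and assumes the packed weight, scales, and dequantized weight layouts line up so the two paths are directly comparable; bias is omitted, as in the fused branch.

import torch
import torch.nn.functional as F
from vllm.model_executor.layers.fused_fp8 import matmul_fp8

def check_fused_vs_fallback(layer: torch.nn.Module, x: torch.Tensor) -> None:
    # Fused Triton path (enable_fused_kernel=True).
    y_fused = matmul_fp8(
        x, layer.weight,
        layer.weight.quantization_scales(),
        layer.weight.fp_quantizer.group_size,
    )
    # Existing dequantize + F.linear fallback (enable_fused_kernel=False).
    y_ref = F.linear(x, layer.weight.ds_dequantize())
    torch.testing.assert_close(y_fused, y_ref.to(y_fused.dtype), rtol=0, atol=5e-2)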


class DeepSpeedFPParameter(nn.Parameter):
@@ -149,7 +161,7 @@ class DeepSpeedFPParameter(nn.Parameter):
    """

    def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype,
                quant_config: DeepSpeedFPConfig):
                quant_config: DeepSpeedFPConfig, transposed=False):
        try:
            import deepspeed
            if deepspeed.__version__ < "0.14.2":
@@ -160,14 +172,18 @@ def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype,
            raise ImportError("Please install deepspeed>=0.14.2 via "
                              "`pip install deepspeed>=0.14.2` to use "
                              "deepspeedfp quantizer.") from err
        reduce_dim = -1
        if transposed:
            orig_shape = (orig[:-2]+(orig[-1],orig[-2]))
Collaborator

orig should be orig_shape?

            reduce_dim = -2
        data = torch.empty(orig_shape, dtype=torch.uint8)
        self = torch.Tensor._make_subclass(cls, data, data.requires_grad)
        self.orig_shape = orig_shape
        self.quant_config = quant_config
        g_size = max(
            [
                2**i for i in range(4, int(math.log2(quant_config.group_size)+1)) \
                if orig_shape[-1] % (2**i) == 0
                if orig_shape[reduce_dim] % (2**i) == 0
            ]
        )
        self.fp_quantizer = FP_Quantize(group_size=g_size)
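
For clarity, the group-size selection above can be read as the following standalone sketch (a hypothetical helper, not part of the PR): pick the largest power of two, from 16 up to the configured group size, that evenly divides the grouped dimension (the last dimension normally, the second-to-last when the weight is stored transposed).

import math

def pick_group_size(config_group_size: int, grouped_dim: int) -> int:
    # Powers of two from 2**4 = 16 up to config_group_size that divide the dimension.
    candidates = [
        2 ** i
        for i in range(4, int(math.log2(config_group_size) + 1))
        if grouped_dim % (2 ** i) == 0
    ]
    return max(candidates)

# Example: pick_group_size(512, 11008) == 256, since 11008 = 2**8 * 43.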