Skip to content

Commit

Permalink
Add support for FP8 on compute capability >=8.0, <8.9
Browse files Browse the repository at this point in the history
Use FP8 GPTQ-Marlin kernels to enable FP8 support on CUDA GPUs
with compute capability >=8.0 and <8.9.

Co-authored-by: Daniël de Kok <[email protected]>
  • Loading branch information
flozi00 and danieldk committed Jul 10, 2024
1 parent 8511669 commit 74fdd83
Show file tree
Hide file tree
Showing 8 changed files with 1,458 additions and 4 deletions.
8 changes: 8 additions & 0 deletions server/marlin/marlin_kernels/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,11 @@ def marlin_gemm(
Matrix multiplication using Marlin kernels.
"""
...

# fp8 marlin
def fp8_marlin_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_scales: torch.Tensor,
    workspace: torch.Tensor,
    num_bits: int,
    size_m: int,
    size_n: int,
    size_k: int,
) -> torch.Tensor:
    """
    Matrix multiplication using Marlin kernels with FP8-quantized weights.

    This is a type stub only: the implementation is the native
    `fp8_marlin_gemm` function exported by the compiled extension module,
    so the body is intentionally left empty like the other stubs in this
    file.
    """
    ...
2 changes: 2 additions & 0 deletions server/marlin/marlin_kernels/ext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("gptq_marlin_repack", &gptq_marlin_repack,
"Repack GPTQ parameters for Marlin");
m.def("marlin_gemm", &marlin_gemm, "Marlin gemm");
// fp8_marlin Optimized Quantized GEMM for FP8 weight-only.
m.def("fp8_marlin_gemm", &fp8_marlin_gemm);
}
5 changes: 5 additions & 0 deletions server/marlin/marlin_kernels/ext.hh
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,9 @@ torch::Tensor marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight,
torch::Tensor &b_scales, torch::Tensor &workspace,
int64_t size_m, int64_t size_n, int64_t size_k);

// FP8 weight-only quantized GEMM using the Marlin kernels. Registered with
// pybind11 under the Python name `fp8_marlin_gemm` (see ext.cpp).
// NOTE(review): the definition presumably lives in fp8_marlin.cu -- confirm.
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k);

#endif
1,308 changes: 1,308 additions & 0 deletions server/marlin/marlin_kernels/fp8_marlin.cu

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions server/marlin/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
CUDAExtension(
name="marlin_kernels",
sources=[
"marlin_kernels/fp8_marlin.cu",
"marlin_kernels/gptq_marlin.cu",
"marlin_kernels/gptq_marlin_repack.cu",
"marlin_kernels/marlin_cuda_kernel.cu",
Expand Down
19 changes: 19 additions & 0 deletions server/text_generation_server/layers/fp8.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,23 @@
from enum import Enum, auto

import torch
from text_generation_server.utils.import_utils import SYSTEM


def get_fp8_linear() -> "type[torch.nn.Module]":
    """
    Return the FP8 linear layer class that is compatible with the current system.

    The return value is a class, not an instance; the caller is expected to
    instantiate it with the layer's weight and bias.

    Returns:
        `GPTQMarlinFP8Linear` on CUDA devices whose compute capability has
        major version 8 and minor version below 9, `Fp8Linear` otherwise.
    """

    if SYSTEM == "cuda":
        major, minor = torch.cuda.get_device_capability()
        if major == 8 and minor < 9:
            # Imported lazily: the marlin module itself imports from this
            # module, so a top-level import here would be circular.
            from text_generation_server.layers.marlin import GPTQMarlinFP8Linear

            return GPTQMarlinFP8Linear

    # On other systems let Torch decide if the hardware supports FP8.
    return Fp8Linear


def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
Expand Down
4 changes: 2 additions & 2 deletions server/text_generation_server/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ def get_linear(weight, bias, quantize):
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
)
elif quantize == "fp8":
from text_generation_server.layers.fp8 import Fp8Linear
from text_generation_server.layers.fp8 import get_fp8_linear

linear = Fp8Linear(weight, bias)
linear = get_fp8_linear()(weight, bias)
elif quantize == "bitsandbytes":
try:
from text_generation_server.layers.bnb import (
Expand Down
115 changes: 113 additions & 2 deletions server/text_generation_server/layers/marlin.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

from text_generation_server.utils.weights import Weights, WeightsLoader
import torch
import torch.nn as nn

from loguru import logger
from text_generation_server.layers.fp8 import fp8_quantize
from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.utils.log import log_once
from text_generation_server.utils.weights import Weights, WeightsLoader

try:
import marlin_kernels
Expand Down Expand Up @@ -455,6 +457,115 @@ def forward(self, A: torch.Tensor) -> torch.Tensor:
return C


class GPTQMarlinFP8Linear(nn.Module):
    """
    Linear layer backed by the FP8 GPTQ-Marlin kernel.

    The weight is quantized to FP8 and repacked into the Marlin layout once,
    at construction time; `forward` then dispatches to the native kernel.
    """

    def __init__(
        self,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
    ) -> None:
        super().__init__()

        _check_marlin_kernels()
        assert marlin_kernels is not None

        log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")

        # Quantize, then hand the half-precision scale to the repacking step.
        fp8_weight, fp8_scale = fp8_quantize(weight)
        packed_weight, packed_scales = repack_fp8_for_marlin(
            fp8_weight, fp8_scale.to(torch.float16)
        )

        # Recover the logical layer shape from the repacked tensors.
        n_in = packed_weight.shape[0] * MARLIN_TILE_SIZE
        n_out = packed_scales.shape[1]
        _check_valid_shape(in_features=n_in, out_features=n_out)

        self.qweight = packed_weight
        self.scales = packed_scales
        self.bias = bias

        # Scratch buffer handed to the kernel on every call.
        self.workspace = torch.zeros(
            n_out // 64 * 16, dtype=torch.int, device=packed_weight.device
        )

    def forward(self, A: torch.Tensor) -> torch.Tensor:
        assert marlin_kernels is not None

        n_out = self.scales.shape[1]

        # Collapse all leading dimensions so the kernel sees a 2-D input.
        flat = A.view(-1, A.shape[-1])
        n_rows, n_cols = flat.shape

        out = marlin_kernels.fp8_marlin_gemm(
            flat,
            self.qweight,
            self.scales,
            self.workspace,
            8,
            n_rows,
            n_out,
            n_cols,
        )

        # Restore the leading dimensions of the input.
        out = out.reshape(A.shape[:-1] + (n_out,))

        if self.bias is None:
            return out

        out += self.bias
        return out


def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
"""
Repack FP8 weights to gptq format (packed int32 elements).
"""
assert fp8_tensor.dtype == torch.float8_e4m3fn

if fp8_tensor.shape[0] % 4 != 0:
raise ValueError(
f"Leading tensor dimension is not divisable by 4: {fp8_tensor.shape[0]}"
)

# Reshape to prepare for packing
reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:])

# Convert fp8 to uint8 (byte) representation
byte_tensor = reshaped.view(torch.uint8)

# Pack 4 uint8 values into one int32
packed = torch.zeros(
fp8_tensor.shape[0] // 4,
fp8_tensor.shape[1],
dtype=torch.int32,
device=fp8_tensor.device,
)

for i in range(4):
packed.bitwise_or_(byte_tensor[:, i].to(torch.int32) << i * 8)

return packed


def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor):
    """
    Convert an FP8 weight and its scale into the GPTQ-Marlin layout.

    Returns a `(repacked_weight, scales)` tuple, where the scale has been
    repeated once per output feature and permuted with `permute_scales`.
    """

    n_out, n_in = weight.shape

    # A Torch linear weight is [out_features, in_features], while the packed
    # GPTQ layout is [in_features / pack_factor, out_features]. Transpose
    # first so that packing runs along the input dimension.
    packed = pack_fp8_as_int32(weight.t())

    # An empty permutation tensor is passed to the repack kernel.
    no_perm = torch.empty(0, dtype=torch.int, device=packed.device)
    marlin_weight = marlin_kernels.gptq_marlin_repack(packed, no_perm, n_in, n_out, 8)

    # Repeat the single scale value across all output features.
    repeated_scale = scale.reshape(1, 1).repeat(1, n_out)
    marlin_scales = permute_scales(repeated_scale)

    return marlin_weight, marlin_scales


@dataclass
class MarlinWeight:
"""
Expand Down

0 comments on commit 74fdd83

Please sign in to comment.