Add the fp8-quantized GeMM for dense linear layers #18

Open · wants to merge 9 commits into base: pp-staging
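This PR routes dense linear layers through DeepSpeed's fused fp8 GeMM (matmul_fp8, deepspeed>=0.14.2) instead of dequantizing the weight and calling F.linear. During loading, checkpoint tensors are gathered into a full-precision shadow buffer and quantized once the weight is complete, stored transposed when the fused kernel is enabled; the per-group scales are kept and passed to both the fused GeMM and dequantization. The dequantize-then-F.linear path stays available behind enable_fused_kernel=False, and the linear.py change passes None in place of quant_config to the parent constructor, leaving that layer type unquantized. In rough terms the fused path looks like the sketch below (a minimal illustration, not the exact diff; layer and x are assumed to be set up by create_weights and the weight loader):

from deepspeed.ops.fp_quantizer import matmul_fp8

def fused_linear(layer, x, bias=None):
    weight = layer.weight                      # fp8 payload, stored in transposed (in, out) layout
    scale = weight.fp_quantizer.get_scales()   # per-group quantization scales
    y = matmul_fp8(x, weight, scale, weight.fp_quantizer.group_size)
    return y if bias is None else y + bias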
vllm/model_executor/layers/linear.py: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
None)#quant_config)

# All the linear layers support the quant method.
assert self.quant_method is not None
vllm/model_executor/layers/quantization/deepspeedfp.py: 43 additions & 19 deletions
@@ -10,6 +10,10 @@
from vllm.model_executor.utils import set_weight_attrs
import gc

from vllm.distributed import get_tensor_model_parallel_world_size

g_matmul_fp8 = None

class DeepSpeedFPConfig(QuantizationConfig):
"""Config for DeepSpeed FP quantizer. It supports fp6 and fp8.

@@ -84,9 +88,10 @@ class DeepSpeedFPLinearMethod(LinearMethodBase):
quant_config: the DeepSpeedFP quantization config.
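        enable_fused_kernel: if True, run the GeMM directly on the fp8-quantized weight via DeepSpeed's matmul_fp8 instead of dequantizing the weight and calling F.linear.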
"""

def __init__(self, quant_config: DeepSpeedFPConfig):
def __init__(self, quant_config: DeepSpeedFPConfig, enable_fused_kernel=True):
self.quant_config = quant_config
self.weight = None
self.enable_fused_kernel = enable_fused_kernel

def create_weights(self,
layer: torch.nn.Module,
@@ -120,14 +125,23 @@ def state_dict(**kwargs):
layer.state_dict = state_dict

def quant_weight_loader(param, loaded_weight, *args, **kwargs):
# Calls the original weight loader (if any), quantizes the result,
# and then loads the quantized parameter.
if weight_loader is not None:
orig_param_data = param.data
param.data = param.ds_dequantize()
weight_loader(param, loaded_weight, *args, **kwargs)
param.data, loaded_weight = orig_param_data, param.data
param.ds_quantize_(loaded_weight.cuda())
            # Gather the (possibly fused and/or TP-sharded) checkpoint tensors into a
            # full-precision shadow buffer, and quantize once the weight is complete.
            if not hasattr(param, 'shadow_data'):
                param.shadow_data = torch.empty(
                    param.orig_shape,
                    dtype=loaded_weight.dtype,
                    device=loaded_weight.device)
                param.shadow_data.input_dim = param.input_dim
                param.shadow_data.output_dim = param.output_dim
            if not hasattr(param, 'loading_cnt'):
                # Loader calls expected before the weight is complete: the size ratio
                # between the checkpoint tensor and the per-rank parameter.
                tp_size = get_tensor_model_parallel_world_size()
                if loaded_weight.shape[0] != param.orig_shape[0]:
                    param.loading_cnt = loaded_weight.shape[0] // param.orig_shape[0] // tp_size
                elif loaded_weight.shape[1] != param.orig_shape[1]:
                    param.loading_cnt = loaded_weight.shape[1] // param.orig_shape[1] // tp_size
                else:
                    param.loading_cnt = 1
            weight_loader(param.shadow_data, loaded_weight, *args, **kwargs)
            param.loading_cnt -= 1
            loaded_weight = param.shadow_data
            if param.loading_cnt == 0:
                # The fused kernel consumes the weight in transposed (in, out) layout.
                param.ds_quantize_(
                    loaded_weight.transpose(-1, -2).contiguous().cuda()
                    if self.enable_fused_kernel else loaded_weight.cuda())


extra_weight_attrs["weight_loader"] = quant_weight_loader
set_weight_attrs(weight, extra_weight_attrs)
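For orientation, the shard-counting rule in quant_weight_loader works like the sketch below (the numbers are invented purely to illustrate the arithmetic and are not taken from this diff or any specific model):

# Purely illustrative numbers to show the counting rule:
tp_size = 2
orig_shape = (8192, 4096)       # param.orig_shape (per-rank shape)
loaded_shape = (32768, 4096)    # loaded_weight.shape from the checkpoint
loading_cnt = loaded_shape[0] // orig_shape[0] // tp_size    # 32768 // 8192 // 2 == 2
# ds_quantize_ runs only on the call that brings loading_cnt down to 0;
# when the shapes already match, loading_cnt == 1 and quantization happens immediately.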
@@ -137,8 +151,14 @@ def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = layer.weight
y = weight.ds_dequantize()
return F.linear(x, y, bias)
        if self.enable_fused_kernel:
            # Fused path: run the fp8 GeMM directly on the quantized weight,
            # passing the per-group scales kept by the quantizer.
            scale = weight.fp_quantizer.get_scales()
            y = g_matmul_fp8(x, weight, scale, weight.fp_quantizer.group_size)
            return y if bias is None else (y + bias)
        else:
            # Fallback path: dequantize and run a regular GEMM.
            y = weight.ds_dequantize()
            return F.linear(x, y, bias)


class DeepSpeedFPParameter(nn.Parameter):
@@ -151,37 +171,37 @@ class DeepSpeedFPParameter(nn.Parameter):
def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype,
quant_config: DeepSpeedFPConfig):
try:
global g_matmul_fp8
import deepspeed
if deepspeed.__version__ < "0.14.2":
raise ImportError("deepspeed version is too old. Please "
"install deepspeed>=0.14.2.")
from deepspeed.ops.fp_quantizer import FP_Quantize
from deepspeed.ops.fp_quantizer import FP_Quantize, matmul_fp8
g_matmul_fp8 = matmul_fp8
except ImportError as err:
raise ImportError("Please install deepspeed>=0.14.2 via "
"`pip install deepspeed>=0.14.2` to use "
"deepspeedfp quantizer.") from err
reduce_dim = -1
data = torch.empty(orig_shape, dtype=torch.uint8)
self = torch.Tensor._make_subclass(cls, data, data.requires_grad)
self.orig_shape = orig_shape
self.quant_config = quant_config
g_size = max(
[
2**i for i in range(4, int(math.log2(quant_config.group_size)+1)) \
if orig_shape[-1] % (2**i) == 0
if orig_shape[reduce_dim] % (2**i) == 0
]
)
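        # g_size is the largest power of two between 16 and quant_config.group_size that
        # evenly divides the reduction dimension. Hypothetical example (values not from
        # this diff): with group_size = 512 and orig_shape[reduce_dim] = 11008 = 256 * 43,
        # the qualifying powers of two are 16..256, so g_size = 256.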
self.fp_quantizer = FP_Quantize(group_size=g_size)
self.fp_quantizer.orig_shape = orig_shape
self.fp_quantizer.orig_dtype = params_dtype
self.fp_quantizer.num_groups = self.numel() // g_size
self.fp_quantizer.scales = torch.empty(orig_shape.numel() // g_size, 4,
self.fp_quantizer.scale = torch.empty(orig_shape.numel() // g_size, 4,
dtype=torch.uint8, device=self.data.device)
return self

def ds_quantize_(self, tensor: torch.Tensor):
assert tensor.device.type == "cuda" and tensor.dtype != torch.uint8
prev_data = self.data
q_data, _ = self.fp_quantizer.quantize(
q_data, self.scale = self.fp_quantizer.quantize(
tensor.data,
q_bits=self.quant_config.weight_bits,
return_meta_tensor=True
@@ -204,7 +224,10 @@ def ds_dequantize(self, fp_out=None) -> torch.Tensor:
"""
assert self.data.device.type == "cuda" and self.data.dtype == torch.uint8
return self.fp_quantizer.dequantize(
self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits)
self.data,
fp_out=fp_out,
q_bits=self.quant_config.weight_bits,
scale=self.scale)

def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor:
"""
@@ -216,4 +239,5 @@ def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor:
self.data,
indices,
fp_out=fp_out,
q_bits=self.quant_config.weight_bits)
q_bits=self.quant_config.weight_bits,
scale=self.scale)
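For reference, a minimal round-trip sketch of the quantize/dequantize API touched above. It assumes a CUDA device, deepspeed>=0.14.2, and that DeepSpeedFPConfig can be constructed with its defaults (fp8 weights); those defaults are an assumption, not something shown in this diff:

import torch
from vllm.model_executor.layers.quantization.deepspeedfp import (
    DeepSpeedFPConfig, DeepSpeedFPParameter)

quant_config = DeepSpeedFPConfig()   # assumed fp8 defaults
w = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")

param = DeepSpeedFPParameter(w.shape, params_dtype=w.dtype, quant_config=quant_config)
param.ds_quantize_(w)              # packs the fp8 payload into param.data, scales into param.scale
w_hat = param.ds_dequantize()      # reconstructs bf16 weights from payload + scales
print((w - w_hat).abs().max())     # quantization error should be small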