diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py
index 7726dcb9a..410f98738 100644
--- a/vllm/model_executor/layers/linear.py
+++ b/vllm/model_executor/layers/linear.py
@@ -159,7 +159,7 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
-                         quant_config)
+                         None)#quant_config)
 
         # All the linear layer supports quant method.
         assert self.quant_method is not None
diff --git a/vllm/model_executor/layers/quantization/deepspeedfp.py b/vllm/model_executor/layers/quantization/deepspeedfp.py
index 108c20029..b3e5a6aae 100644
--- a/vllm/model_executor/layers/quantization/deepspeedfp.py
+++ b/vllm/model_executor/layers/quantization/deepspeedfp.py
@@ -10,6 +10,10 @@
 from vllm.model_executor.utils import set_weight_attrs
 
 import gc
+from vllm.distributed import get_tensor_model_parallel_world_size
+
+g_matmul_fp8 = None
+
 
 class DeepSpeedFPConfig(QuantizationConfig):
     """Config for DeepSpeed FP quantizer. It supports fp6 and fp8.
@@ -84,9 +88,10 @@ class DeepSpeedFPLinearMethod(LinearMethodBase):
         quant_config: the DeepSpeedFP quantization config.
     """
 
-    def __init__(self, quant_config: DeepSpeedFPConfig):
+    def __init__(self, quant_config: DeepSpeedFPConfig, enable_fused_kernel=True):
         self.quant_config = quant_config
         self.weight = None
+        self.enable_fused_kernel = enable_fused_kernel
 
     def create_weights(self,
                        layer: torch.nn.Module,
@@ -120,14 +125,23 @@ def state_dict(**kwargs):
         layer.state_dict = state_dict
 
         def quant_weight_loader(param, loaded_weight, *args, **kwargs):
-            # Calls the original weight loader (if any), quantizes the result,
-            # and then loads the quantized parameter.
             if weight_loader is not None:
-                orig_param_data = param.data
-                param.data = param.ds_dequantize()
-                weight_loader(param, loaded_weight, *args, **kwargs)
-                param.data, loaded_weight = orig_param_data, param.data
-            param.ds_quantize_(loaded_weight.cuda())
+                if not hasattr(param, 'shadow_data'):
+                    param.shadow_data = torch.empty(
+                        param.orig_shape,
+                        dtype=loaded_weight.dtype,
+                        device=loaded_weight.device)
+                    param.shadow_data.input_dim = param.input_dim
+                    param.shadow_data.output_dim = param.output_dim
+                    tp_size = get_tensor_model_parallel_world_size()
+                    param.loading_cnt = loaded_weight.shape[0] // param.orig_shape[0] // tp_size if loaded_weight.shape[0] != param.orig_shape[0] else \
+                        loaded_weight.shape[1] // param.orig_shape[1] // tp_size if loaded_weight.shape[1] != param.orig_shape[1] else 1
+                weight_loader(param.shadow_data, loaded_weight, *args, **kwargs)
+                param.loading_cnt -= 1
+                loaded_weight = param.shadow_data
+            if not hasattr(param, 'loading_cnt') or param.loading_cnt == 0:
+                param.ds_quantize_(loaded_weight.transpose(-1, -2).contiguous().cuda() if self.enable_fused_kernel else loaded_weight.cuda())
+
 
         extra_weight_attrs["weight_loader"] = quant_weight_loader
         set_weight_attrs(weight, extra_weight_attrs)
@@ -137,8 +151,14 @@ def apply(self,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         weight = layer.weight
-        y = weight.ds_dequantize()
-        return F.linear(x, y, bias)
+        if self.enable_fused_kernel:
+            scale = weight.fp_quantizer.get_scales()
+            y = g_matmul_fp8(x, weight, scale, weight.fp_quantizer.group_size)
+            return y if bias is None else (y + bias)
+        else:
+            weight = layer.weight
+            y = weight.ds_dequantize()
+            return F.linear(x, y, bias)
 
 
 class DeepSpeedFPParameter(nn.Parameter):
@@ -151,15 +171,15 @@ class DeepSpeedFPParameter(nn.Parameter):
     def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype,
                 quant_config: DeepSpeedFPConfig):
         try:
+            global g_matmul_fp8
             import deepspeed
-            if deepspeed.__version__ < "0.14.2":
-                raise ImportError("deepspeed version is wrong. Please "
-                                  "install deepspeed>=0.14.2.")
-            from deepspeed.ops.fp_quantizer import FP_Quantize
+            from deepspeed.ops.fp_quantizer import FP_Quantize, matmul_fp8
+            g_matmul_fp8 = matmul_fp8
         except ImportError as err:
             raise ImportError("Please install deepspeed>=0.14.2 via "
                               "`pip install deepspeed>=0.14.2` to use "
                               "deepspeedfp quantizer.") from err
+        reduce_dim = -1
         data = torch.empty(orig_shape, dtype=torch.uint8)
         self = torch.Tensor._make_subclass(cls, data, data.requires_grad)
         self.orig_shape = orig_shape
@@ -167,21 +187,21 @@ def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype,
         g_size = max(
             [
                 2**i for i in range(4, int(math.log2(quant_config.group_size)+1)) \
-                    if orig_shape[-1] % (2**i) == 0
+                    if orig_shape[reduce_dim] % (2**i) == 0
             ]
         )
         self.fp_quantizer = FP_Quantize(group_size=g_size)
         self.fp_quantizer.orig_shape = orig_shape
         self.fp_quantizer.orig_dtype = params_dtype
         self.fp_quantizer.num_groups = self.numel() // g_size
-        self.fp_quantizer.scales = torch.empty(orig_shape.numel() // g_size, 4,
+        self.fp_quantizer.scale = torch.empty(orig_shape.numel() // g_size, 4,
                                                dtype=torch.uint8, device=self.data.device)
         return self
 
     def ds_quantize_(self, tensor: torch.Tensor):
         assert tensor.device.type == "cuda" and tensor.dtype != torch.uint8
         prev_data = self.data
-        q_data, _ = self.fp_quantizer.quantize(
+        q_data, self.scale = self.fp_quantizer.quantize(
             tensor.data,
             q_bits=self.quant_config.weight_bits,
             return_meta_tensor=True
@@ -204,7 +224,10 @@ def ds_dequantize(self, fp_out=None) -> torch.Tensor:
         """
         assert self.data.device.type == "cuda" and self.data.dtype == torch.uint8
         return self.fp_quantizer.dequantize(
-            self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits)
+            self.data,
+            fp_out=fp_out,
+            q_bits=self.quant_config.weight_bits,
+            scale=self.scale)
 
     def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor:
         """
@@ -216,4 +239,5 @@ def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor:
             self.data,
             indices,
             fp_out=fp_out,
-            q_bits=self.quant_config.weight_bits)
+            q_bits=self.quant_config.weight_bits,
+            scale=self.scale)
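
Note (not part of the patch): the new quant_weight_loader defers quantization until every shard of a merged or tensor-parallel weight has landed in a full-size shadow buffer, then quantizes exactly once, transposing first when the fused-kernel layout is enabled. The standalone sketch below illustrates that counting scheme with plain PyTorch under those assumptions; make_loader, load_shard, and fake_quantize are illustrative names invented here, not symbols from the patch, and fake_quantize merely stands in for param.ds_quantize_().

import torch


def fake_quantize(w: torch.Tensor) -> torch.Tensor:
    # Placeholder for the real FP quantizer: just round-trip through fp16.
    return w.to(torch.float16).to(w.dtype)


def make_loader(shadow: torch.Tensor, num_shards: int, fused_layout: bool):
    # Track how many shards are still expected and where the final
    # quantized tensor ends up.
    state = {"remaining": num_shards, "quantized": None}

    def load_shard(shard: torch.Tensor, row_offset: int) -> None:
        # Copy this shard into its slice of the full-size shadow buffer.
        shadow[row_offset:row_offset + shard.shape[0]].copy_(shard)
        state["remaining"] -= 1
        if state["remaining"] == 0:
            # All shards present: optionally transpose for the fused kernel
            # layout, then quantize the whole weight once.
            full = shadow.transpose(-1, -2).contiguous() if fused_layout else shadow
            state["quantized"] = fake_quantize(full)

    return load_shard, state


# Example: a "merged" 4x8 weight arriving as two 2x8 shards.
shadow = torch.empty(4, 8)
load_shard, state = make_loader(shadow, num_shards=2, fused_layout=False)
load_shard(torch.randn(2, 8), row_offset=0)
assert state["quantized"] is None        # still waiting for the second shard
load_shard(torch.randn(2, 8), row_offset=2)
assert state["quantized"] is not None    # quantized only after the last shard

In the patch itself the countdown (loading_cnt) is derived from the ratio of the checkpoint tensor's shape to param.orig_shape divided by the tensor-parallel world size, and the quantized result is stored back into the DeepSpeedFPParameter rather than a dict.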