[Hardware][XPU] AWQ/GPTQ support for xpu backend (vllm-project#10107)
Signed-off-by: yan ma <[email protected]>
Signed-off-by: rickyx <[email protected]>
yma11 authored and rickyyx committed Nov 20, 2024
1 parent 3b2b0ae commit dc6fbc1
Showing 7 changed files with 146 additions and 52 deletions.
8 changes: 4 additions & 4 deletions docs/source/quantization/supported_hardware.rst
@@ -27,7 +27,7 @@ The table below shows the compatibility of various quantization implementations
- ✅︎
- ✅︎
- ✗
-
- ✅︎
- ✅︎
- ✗
- ✗
@@ -38,8 +38,8 @@ The table below shows the compatibility of various quantization implementations
- ✅︎
- ✅︎
- ✗
-
-
- ✅︎
- ✅︎
- ✗
- ✗
* - Marlin (GPTQ/AWQ/FP8)
@@ -129,4 +129,4 @@ Notes:

Please note that this compatibility chart may be subject to change as vLLM continues to evolve and expand its support for different hardware platforms and quantization methods.

For the most up-to-date information on hardware support and quantization methods, please check the `quantization directory <https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization>`_ or consult with the vLLM development team.
10 changes: 6 additions & 4 deletions tests/quantization/test_ipex_quant.py
@@ -1,5 +1,5 @@
"""Test model set-up and inference for quantized HF models supported
on the CPU backend using IPEX (including AWQ).
on the CPU/GPU backend using IPEX (including AWQ/GPTQ).
Validating the configuration and printing results for manual checking.
@@ -11,13 +11,15 @@
from vllm.platforms import current_platform

MODELS = [
"casperhansen/llama-3-8b-instruct-awq",
"AMead10/Llama-3.2-1B-Instruct-AWQ",
"shuyuej/Llama-3.2-1B-Instruct-GPTQ", # with g_idx
]
DTYPE = ["bfloat16"]


@pytest.mark.skipif(not current_platform.is_cpu(),
reason="only supports the CPU backend.")
@pytest.mark.skipif(not current_platform.is_cpu()
and not current_platform.is_xpu(),
reason="only supports Intel CPU/XPU backend.")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", DTYPE)
def test_ipex_quant(vllm_runner, model, dtype):
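For a quick manual check outside pytest, the same checkpoints can be loaded through vLLM's offline API. A minimal sketch (the prompt and sampling settings are arbitrary); on an Intel CPU/XPU the AWQ/GPTQ scheme is picked up from the checkpoint's own quantization config via the IPEX override shown further down in this diff:

from vllm import LLM, SamplingParams

# One of the quantized checkpoints exercised by the test above.
llm = LLM(model="AMead10/Llama-3.2-1B-Instruct-AWQ", dtype="bfloat16")

outputs = llm.generate(["What does AWQ quantization do?"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
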
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/linear.py
@@ -27,7 +27,7 @@
"AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
"MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod",
"TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod",
"ModelOptFp8LinearMethod", "IPEXAWQLinearMethod"
"ModelOptFp8LinearMethod", "IPEXAWQLinearMethod", "IPEXGPTQLinearMethod"
]


1 change: 0 additions & 1 deletion vllm/model_executor/layers/quantization/gptq.py
@@ -210,7 +210,6 @@ def create_weights(

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# for torch.compile
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
4 changes: 4 additions & 0 deletions vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -23,6 +23,7 @@
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)
@@ -134,6 +135,9 @@ def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
sym = quant_config.get("sym")
desc_act = quant_config.get("desc_act")

if not current_platform.is_cuda():
return False

if quant_method != "gptq":
return False

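Together with the IPEX changes below, this early return means a GPTQ checkpoint on a non-CUDA platform is never promoted to the Marlin kernels. A rough sketch of the resulting routing; the helper function is illustrative only (it is not vLLM's actual config-resolution code), but the two checks it calls are the ones touched by this commit:

from typing import Any, Dict, Optional

from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig


def resolve_quant_backend(hf_quant_cfg: Dict[str, Any],
                          user_quant: Optional[str] = None) -> str:
    """Illustrative only: roughly how an awq/gptq checkpoint gets routed."""
    # After this commit, Marlin is considered only on CUDA.
    if GPTQMarlinConfig.is_gptq_marlin_compatible(hf_quant_cfg):
        return "gptq_marlin"
    # On Intel CPU/XPU, IPEXConfig.override_quantization_method() now claims
    # both awq and gptq checkpoints (see ipex_quant.py below).
    override = IPEXConfig.override_quantization_method(hf_quant_cfg, user_quant)
    return override or hf_quant_cfg.get("quant_method", "").lower()
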
169 changes: 128 additions & 41 deletions vllm/model_executor/layers/quantization/ipex_quant.py
@@ -2,51 +2,57 @@

import torch

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod,
is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.platforms import current_platform

MIN_IPEX_VERSION = "2.5.0"


class IPEXConfig(QuantizationConfig):
"""INT8 quantization config class using IPEX for the CPU backend,
including AWQ.
"""INT8 quantization config class using IPEX for the CPU/XPU backend,
including AWQ, GPTQ.
"""

IPEX_QUANT_METHOD_MAP = {
"awq": 1,
"gptq": 2,
"gptq": 0,
}

def __init__(
self,
method: str,
weight_bits: int,
group_size: int,
modules_to_not_convert: Optional[List[str]] = None,
desc_act: Optional[bool] = None,
lm_head_quantized: Optional[bool] = None,
) -> None:
self.method = method
self.weight_bits = weight_bits
self.group_size = group_size
self.modules_to_not_convert = modules_to_not_convert or []
self.desc_act = desc_act
self.lm_head_quantized = lm_head_quantized
self.pack_factor = 32 // self.weight_bits

if self.weight_bits not in [4]:
raise ValueError(f"IPEX quantization supports weight bits [4], "
f"but got {self.weight_bits}.")

if self.method == "awq":
self.quant_method = IPEXAWQLinearMethod
else:
raise ValueError(f"IPEX quantization supports [awq], "
if self.method not in ["awq", "gptq"]:
raise ValueError(f"IPEX quantization supports [awq, gptq], "
f"but got {self.method}.")

def __repr__(self) -> str:
return (f"IPEXConfig(method={self.method}"
return (f"IPEXConfig(method={self.method},"
f"weight_bits={self.weight_bits}, "
f"group_size={self.group_size}")
f"group_size={self.group_size})")

def get_ipex_quant_method_id(self) -> int:
return IPEXConfig.IPEX_QUANT_METHOD_MAP[self.method]

@classmethod
def get_name(cls) -> str:
@@ -70,32 +76,114 @@ def get_config_filenames() -> List[str]:
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "IPEXConfig":
method = cls.get_from_keys(config, ["quant_method"]).lower()
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
return cls(method, weight_bits, group_size)
if method == "awq":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config,
["q_group_size", "group_size"])
modules_to_not_convert = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None)
return cls(method, weight_bits, group_size, modules_to_not_convert,
False, False)
# otherwise for gptq
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
desc_act = cls.get_from_keys_or(config, ["desc_act"], default=False)
return cls(method, weight_bits, group_size, [], desc_act,
lm_head_quantized)

@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
if not current_platform.is_cpu():
if not current_platform.is_cpu() and not current_platform.is_xpu():
return None

quant_method = hf_quant_cfg.get("quant_method", "").lower()

if quant_method in ["awq"]:
if quant_method in ["awq", "gptq"]:
return cls.get_name()

return None

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["LinearMethodBase"]:
if isinstance(layer, LinearBase):
return self.quant_method(self)
if self.method == "awq":
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
return IPEXAWQLinearMethod(self)
if self.method == "gptq":
return IPEXGPTQLinearMethod(self)
return None


class IPEXGPTQLinearMethod(GPTQLinearMethod):
"""GPTQ linear method using IPEX for the CPU/XPU backend.
"""

def __init__(self, quant_config: IPEXConfig):
self.quant_config = quant_config # type: ignore

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
bias = layer.bias if not layer.skip_bias_add else None

try:
import intel_extension_for_pytorch as ipex
if ipex.__version__ < MIN_IPEX_VERSION:
raise ImportError(
"intel_extension_for_pytorch version is "
"wrong. Please install "
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.")
except ImportError as err:
raise ImportError(
"Please install "
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via "
f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`"
" to use IPEX-AWQ linear method.") from err
# Using the compute dtype (lowp_mode) as INT8 to leverage instructions
# with better performance.
lowp_mode = ipex.quantization.WoqLowpMode.INT8
# The weight will be de-packed from INT4 to INT8.
weight_dtype = ipex.quantization.WoqWeightDtype.INT4
# The float activation will be quantized (dynamic, per-token) to INT8.
act_quant_mode = ipex.quantization.WoqActQuantMode.PER_BATCH_IC_BLOCK

qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
weight_dtype=weight_dtype,
lowp_mode=lowp_mode,
act_quant_mode=act_quant_mode,
group_size=self.quant_config.group_size,
)
layer.ipex_output_size = layer.qweight.shape[-1]
g_idx = layer.g_idx if self.quant_config.desc_act else None
layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \
IPEXWeightOnlyQuantizedLinear.from_weight(
layer.qweight,
layer.scales,
layer.qzeros,
layer.qweight.size(0),
layer.ipex_output_size,
qconfig=qconfig,
g_idx=g_idx,
bias=bias,
group_size=self.quant_config.group_size,
quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["gptq"]
)

def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
out = layer.ipex_qlinear(reshaped_x)
if bias is not None:
out.add_(bias)
return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))


class IPEXAWQLinearMethod(AWQLinearMethod):
"""AWQ linear method using IPEX for the CPU backend.
"""AWQ linear method using IPEX for the CPU/XPU backend.
"""

def __init__(self, quant_config: IPEXConfig):
@@ -108,15 +196,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

try:
import intel_extension_for_pytorch as ipex
if ipex.__version__ < "2.4.0":
raise ImportError("intel_extension_for_pytorch version is "
"wrong. Please install "
"intel_extension_for_pytorch>=2.4.0.")
if ipex.__version__ < MIN_IPEX_VERSION:
raise ImportError(
"intel_extension_for_pytorch version is "
"wrong. Please install "
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION}.")
except ImportError as err:
raise ImportError(
"Please install "
"intel_extension_for_pytorch>=2.4.0 via "
"`pip install intel_extension_for_pytorch>=2.4.0`"
f"intel_extension_for_pytorch>={MIN_IPEX_VERSION} via "
f"`pip install intel_extension_for_pytorch>={MIN_IPEX_VERSION}`"
" to use IPEX-AWQ linear method.") from err

# Using the compute dtype (lowp_mode) as INT8 to leverage instructions
Expand All @@ -136,25 +225,23 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

layer.ipex_output_size = layer.qweight.size(
1) * self.quant_config.pack_factor
layer.ipex_qlinear = ipex.nn.modules.weight_only_quantization.\
WeightOnlyQuantizedLinear.from_weight(
layer.qweight,
layer.scales,
layer.qzeros,
layer.qweight.size(0),
layer.ipex_output_size,
qconfig=qconfig,
bias=bias,
group_size=self.quant_config.group_size,
quant_method=
self.quant_config.get_ipex_quant_method_id() # type: ignore
)
layer.ipex_qlinear = ipex.llm.quantization.woq_linear. \
IPEXWeightOnlyQuantizedLinear.from_weight(
layer.qweight,
layer.scales,
layer.qzeros,
layer.qweight.size(0),
layer.ipex_output_size,
qconfig=qconfig,
bias=bias,
group_size=self.quant_config.group_size,
quant_method=IPEXConfig.IPEX_QUANT_METHOD_MAP["awq"] # type: ignore
)

def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
out = layer.ipex_qlinear(reshaped_x)

return out.reshape(x.shape[:-1] + (layer.ipex_output_size, ))
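For reference, the two quantization-config shapes the reworked from_config() accepts look roughly like this; the keys are the ones read above, while the concrete values are illustrative only:

from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig

# AWQ-style checkpoint config: bits under "w_bit"/"bits", group size under
# "q_group_size"/"group_size", plus an optional skip list.
awq_cfg = IPEXConfig.from_config({
    "quant_method": "awq",
    "w_bit": 4,
    "q_group_size": 128,
    "modules_to_not_convert": ["lm_head"],
})

# GPTQ-style checkpoint config: "bits", "group_size", and the optional
# "desc_act" / "lm_head" flags (both default to False when absent).
gptq_cfg = IPEXConfig.from_config({
    "quant_method": "gptq",
    "bits": 4,
    "group_size": 128,
    "desc_act": True,   # act-order checkpoints pass g_idx through to IPEX
})

# Both resolve to 4-bit weight-only configs, so pack_factor == 32 // 4 == 8.
print(awq_cfg, gptq_cfg, awq_cfg.pack_factor)
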
4 changes: 3 additions & 1 deletion vllm/model_executor/model_loader/loader.py
@@ -29,6 +29,8 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase)
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
serialize_vllm_model, tensorizer_weights_iterator)
@@ -348,7 +350,7 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:

for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
if isinstance(quant_method, QuantizeMethodBase):
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
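In effect, the post-load hook now runs only for objects that really are quantize methods. A condensed re-statement of the guarded loop above (the device-scope handling that the original code wraps around it is elided):

import torch

from vllm.model_executor.layers.quantization.base_config import (
    QuantizeMethodBase)


def run_post_load_hooks(model: torch.nn.Module) -> None:
    """Sketch of the new guard; mirrors the loop shown in this hunk."""
    for _, module in model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        # Previously any non-None attribute passed the check; now a module
        # whose `quant_method` is not a QuantizeMethodBase instance is simply
        # skipped instead of failing on the method call below.
        if isinstance(quant_method, QuantizeMethodBase):
            quant_method.process_weights_after_loading(module)
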
