diff --git a/src/transformers/quantizers/quantizer_awq.py b/src/transformers/quantizers/quantizer_awq.py index 4dd818f6465df9..d7a756b23a07e7 100644 --- a/src/transformers/quantizers/quantizer_awq.py +++ b/src/transformers/quantizers/quantizer_awq.py @@ -52,6 +52,10 @@ def validate_environment(self, device_map, **kwargs): if not is_accelerate_available(): raise ImportError("Loading an AWQ quantized model requires accelerate (`pip install accelerate`)") + if self.quantization_config.version == AWQLinearVersion.GEMM and not torch.cuda.is_available(): + logger.warning_once("No CUDA found, replace GEMM with IPEX version to support non-cuda AWQ model.") + self.quantization_config.version = AWQLinearVersion.IPEX + if self.quantization_config.version == AWQLinearVersion.IPEX: if version.parse(importlib.metadata.version("autoawq")) < version.parse("0.2.6"): raise RuntimeError(