From 96a5cc1ba98030bd3fcb601ac2aca955f06768b2 Mon Sep 17 00:00:00 2001
From: mobicham
Date: Tue, 16 Jul 2024 08:13:53 +0000
Subject: [PATCH] check marlin/bitblas dependencies in base model

---
Notes (ignored by git-am):
  The bitblas and marlin backends are imported inside try/except blocks and
  appended to _HQQ_BACKEND_CLASSES only when the import succeeds. Any
  exception raised by either import is swallowed, so the module still loads
  with just the torchao backend class registered.

 hqq/models/base.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/hqq/models/base.py b/hqq/models/base.py
index 968c2b5..3074572 100755
--- a/hqq/models/base.py
+++ b/hqq/models/base.py
@@ -16,9 +16,21 @@
 from ..core.quantize import HQQLinear
 from ..core.peft import PeftUtils, _HQQ_LORA_CLASSES
 from ..backends.torchao import HQQLinearTorchWeightOnlynt4
-from ..backends.marlin import MarlinLinear
 
-_HQQ_BACKEND_CLASSES = [HQQLinearTorchWeightOnlynt4, MarlinLinear]
+_HQQ_BACKEND_CLASSES = [HQQLinearTorchWeightOnlynt4]
+
+try:
+    from ..backends.bitblas import HQQLinearBitBlas
+    _HQQ_BACKEND_CLASSES.append(HQQLinearBitBlas)
+except Exception:
+    pass
+
+try:
+    from ..backends.marlin import MarlinLinear
+    _HQQ_BACKEND_CLASSES.append(MarlinLinear)
+except Exception:
+    pass
+
 
 # Defined what is qualified as "linear layer"
 _QUANT_LAYERS = [nn.Linear, HQQLinear] + _HQQ_LORA_CLASSES + _HQQ_BACKEND_CLASSES