Skip to content

Commit

Permalink
Check Marlin/Bitblas dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
mobicham committed Jul 16, 2024
1 parent 9838b84 commit e57104f
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
7 changes: 1 addition & 6 deletions hqq/backends/marlin.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
# Written by Dr. Hicham Badri @Mobius Labs GmbH - 2024
#####################################################
import torch

try:
import marlin
except Exception:
marlin = None
import marlin
from ..core.quantize import Quantizer


class MarlinLinear(torch.nn.Module):
def __init__(
self, W: torch.Tensor, scales: torch.Tensor, u=None, bias=None, groupsize=-1
Expand Down
17 changes: 12 additions & 5 deletions hqq/utils/patching.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,16 @@
from ..core.peft import HQQLinearLoRA
from ..models.hf.base import AutoHQQHFModel
from ..backends.torchao import patch_hqq_to_aoint4
from ..backends.marlin import patch_hqq_to_marlin
from ..backends.bitblas import patch_hqq_to_bitblas

# Optional inference backends: Marlin and BitBLAS are not hard dependencies.
# On import failure the patch function is set to None so callers (e.g.
# prepare_for_inference) can feature-detect the backend instead of crashing.
try:
    from ..backends.marlin import patch_hqq_to_marlin
except Exception as e:
    patch_hqq_to_marlin = None
    # Broad except is deliberate: the backend module may fail for reasons
    # other than a missing package (e.g. CUDA errors at import time), so we
    # report the actual cause alongside the install hint.
    print(
        "Failed to import the Marlin backend ("
        + str(e)
        + "). Check if marlin is correctly installed (https://github.com/IST-DASLab/marlin)."
    )

try:
    from ..backends.bitblas import patch_hqq_to_bitblas
except Exception as e:
    patch_hqq_to_bitblas = None
    print(
        "Failed to import the BitBLAS backend ("
        + str(e)
        + "). Check if BitBLAS is correctly installed (https://github.com/microsoft/BitBLAS)."
    )

def patch_linearlayers(model, fct, patch_param=None, verbose=False):
base_class = model.base_class if (hasattr(model, "base_class")) else AutoHQQHFModel
Expand Down Expand Up @@ -84,7 +91,7 @@ def prepare_for_inference(model, allow_merge=False, backend="default", verbose=F
patch_linearlayers(model, patch_lora_inference)
cleanup()

if backend == "bitblas":
if backend == "bitblas" and (patch_hqq_to_bitblas is not None):
patch_linearlayers(model, patch_hqq_to_bitblas, verbose=verbose)
cleanup()
if backend == "torchao_int4":
Expand All @@ -96,7 +103,7 @@ def prepare_for_inference(model, allow_merge=False, backend="default", verbose=F
verbose=verbose,
)
cleanup()
if backend == "marlin":
if backend == "marlin" and (patch_hqq_to_marlin is not None):
patch_linearlayers(model, patch_hqq_to_marlin, verbose=verbose)
cleanup()

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def run(self):
"transformers>=4.36.1",
"huggingface_hub",
"termcolor",
"bitblas",
#"bitblas",
#"timm",
],
)

0 comments on commit e57104f

Please sign in to comment.