Commit

Merge branch 'master' of https://github.com/mobiusml/hqq
mobicham committed Aug 28, 2024
2 parents 85f1a9c + 04c7fc4 commit 293812c
Showing 7 changed files with 53 additions and 42 deletions.
18 changes: 9 additions & 9 deletions examples/backends/bitblas_int4_demo.py
@@ -20,18 +20,18 @@

#Quantize
#all 4-bit
-quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, axis=1)
+quant_config = BaseQuantizeConfig(nbits=4, group_size=64, axis=1)

#Mixed 4-bit (bitblas) / 2-bit (ATEN)
# quant_config = {
-# "self_attn.q_proj": BaseQuantizeConfig(nbits=2, group_size=32, quant_scale=False, quant_zero=False, axis=0),
-# "self_attn.k_proj": BaseQuantizeConfig(nbits=2, group_size=32, quant_scale=False, quant_zero=False, axis=0),
-# "self_attn.v_proj": BaseQuantizeConfig(nbits=2, group_size=32, quant_scale=False, quant_zero=False, axis=0),
-# "self_attn.o_proj": BaseQuantizeConfig(nbits=2, group_size=32, quant_scale=False, quant_zero=False, axis=0),
-
-# "mlp.gate_proj": BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, axis=1),
-# "mlp.up_proj": BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, axis=1),
-# "mlp.down_proj": BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, axis=1),
+# "self_attn.q_proj": BaseQuantizeConfig(nbits=2, group_size=32, axis=0),
+# "self_attn.k_proj": BaseQuantizeConfig(nbits=2, group_size=32, axis=0),
+# "self_attn.v_proj": BaseQuantizeConfig(nbits=2, group_size=32, axis=0),
+# "self_attn.o_proj": BaseQuantizeConfig(nbits=2, group_size=32, axis=0),
+
+# "mlp.gate_proj": BaseQuantizeConfig(nbits=4, group_size=64, axis=1),
+# "mlp.up_proj": BaseQuantizeConfig(nbits=4, group_size=64, axis=1),
+# "mlp.down_proj": BaseQuantizeConfig(nbits=4, group_size=64, axis=1),
# }
# HQQLinear.set_backend(HQQBackend.ATEN)

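For readers following along: the commented-out mixed config above is meant to be passed as a per-layer dict. The sketch below shows that usage with the simplified arguments from this commit; it is not part of the diff, and the import paths and model id mirror the other examples in this repository rather than lines visible here, so treat them as assumptions.

#Sketch (not part of the commit): quantizing with the mixed 2-bit/4-bit config above.
#Import paths and the model id follow the other demos in this repo; adjust as needed.
import torch
from transformers import AutoModelForCausalLM
from hqq.models.hf.base import AutoHQQHFModel
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear, HQQBackend

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"  #any Llama-style causal LM
model    = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, attn_implementation="sdpa")

attn_cfg = BaseQuantizeConfig(nbits=2, group_size=32, axis=0)  #2-bit attention projections
mlp_cfg  = BaseQuantizeConfig(nbits=4, group_size=64, axis=1)  #4-bit MLP projections

quant_config = {
    "self_attn.q_proj": attn_cfg,
    "self_attn.k_proj": attn_cfg,
    "self_attn.v_proj": attn_cfg,
    "self_attn.o_proj": attn_cfg,
    "mlp.gate_proj": mlp_cfg,
    "mlp.up_proj": mlp_cfg,
    "mlp.down_proj": mlp_cfg,
}

AutoHQQHFModel.setup_model(model)  #as in the marlin/torchao demos below
AutoHQQHFModel.quantize_model(model, quant_config=quant_config, compute_dtype=torch.float16, device="cuda:0")
HQQLinear.set_backend(HQQBackend.ATEN)  #the backend the commented-out line above refers to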
2 changes: 1 addition & 1 deletion examples/backends/marlin_int4_demo.py
@@ -19,7 +19,7 @@
model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=cache_path, torch_dtype=compute_dtype, attn_implementation="sdpa")

#Quantize
-quant_config = BaseQuantizeConfig(nbits=4, group_size=None, quant_scale=False, quant_zero=False, axis=1)
+quant_config = BaseQuantizeConfig(nbits=4, group_size=None, axis=1)
AutoHQQHFModel.setup_model(model)
AutoHQQHFModel.quantize_model(model, quant_config=quant_config, compute_dtype=compute_dtype, device=device)
HQQLinear.set_backend(HQQBackend.PYTORCH)
2 changes: 1 addition & 1 deletion examples/backends/quantize_and_run.py
@@ -18,7 +18,7 @@
model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=cache_dir, torch_dtype=compute_dtype, attn_implementation="sdpa")
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)

-quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, axis=1)
+quant_config = BaseQuantizeConfig(nbits=4, group_size=64, axis=1)
AutoHQQHFModel.quantize_model(model, quant_config=quant_config, device=device, compute_dtype=compute_dtype)

#Use optimized inference kernels
2 changes: 1 addition & 1 deletion examples/backends/torchao_int4_demo.py
@@ -20,7 +20,7 @@
model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=cache_path, torch_dtype=compute_dtype, attn_implementation="sdpa")

#Quantize
-quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, axis=1)
+quant_config = BaseQuantizeConfig(nbits=4, group_size=64, axis=1)
AutoHQQHFModel.setup_model(model)
AutoHQQHFModel.quantize_model(model, quant_config=quant_config, compute_dtype=compute_dtype, device=device)
HQQLinear.set_backend(HQQBackend.PYTORCH)
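The backend hunks above all stop just before kernel selection. As a minimal, hedged sketch of that next step: the backend string 'bitblas' appears in transformers_demo.py below, while 'marlin' and 'torchao_int4' are assumed names for the other two demos and are not visible in this diff.

#Sketch (not in this diff): switch the already-quantized `model` from above to an optimized kernel.
#Only 'bitblas' is confirmed by this commit; 'marlin' and 'torchao_int4' are assumptions.
from hqq.utils.patching import prepare_for_inference
prepare_for_inference(model, backend="torchao_int4")  #or "marlin" / "bitblas"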
36 changes: 36 additions & 0 deletions examples/backends/transformers_demo.py
@@ -0,0 +1,36 @@
#Works with multi-gpu as well, tested with BitBlas

import torch, gc
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

device = 'auto'
dtype = torch.float16
model_id = 'meta-llama/Meta-Llama-3-8B-Instruct'
cache_dir = '.'

quant_config = HqqConfig(nbits=4, group_size=64, axis=1)

model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=dtype,
cache_dir=cache_dir,
device_map=device,
quantization_config=quant_config
)

tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)

#Patching
from hqq.utils.patching import *
from hqq.core.quantize import *
HQQLinear.set_backend(HQQBackend.PYTORCH)
prepare_for_inference(model, backend='bitblas', verbose=True) #Takes a while

#Import custom HF generator
from hqq.utils.generation_hf import HFGenerator

#Generate
gen = HFGenerator(model, tokenizer, max_new_tokens=1000, do_sample=True, compile=None) #Quick test - slower inference
#gen = HFGenerator(model, tokenizer, max_new_tokens=1000, do_sample=True, compile="partial").warmup() #Takes a while - fastest

out = gen.generate("Write an essay about large language models.", print_tokens=True)
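As an optional follow-up (not part of the committed file), the same quantized model also works with plain transformers generation instead of HFGenerator; this continuation reuses model and tokenizer from the script above, and the prompt and generation settings are arbitrary.

#Optional continuation: plain transformers generation with the quantized model above.
inputs = tokenizer("Write an essay about large language models.", return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=128, do_sample=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))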
10 changes: 5 additions & 5 deletions examples/llama2_benchmark/quant_llama2_hqq_demo.py
@@ -18,11 +18,11 @@
######################################################################################
from hqq.core.quantize import *

-#quant_config = BaseQuantizeConfig(nbits=8, group_size=128)
-quant_config = BaseQuantizeConfig(nbits=4, group_size=64)
-#quant_config = BaseQuantizeConfig(nbits=3, group_size=64)
-#quant_config = BaseQuantizeConfig(nbits=2, group_size=16)
-#quant_config = BaseQuantizeConfig(nbits=2, group_size=16, quant_scale=True) #scale is quantized to 8-bit/g=128
+#quant_config = BaseQuantizeConfig(nbits=8, group_size=128, axis=0)
+quant_config = BaseQuantizeConfig(nbits=4, group_size=64, axis=0)
+#quant_config = BaseQuantizeConfig(nbits=3, group_size=64, axis=0)
+#quant_config = BaseQuantizeConfig(nbits=2, group_size=16, axis=0)
+#quant_config = BaseQuantizeConfig(nbits=2, group_size=16, quant_scale=True, axis=0) #scale is quantized to 8-bit/g=128

model.quantize_model(quant_config=quant_config)

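For context on the axis argument this hunk pins down: the benchmark script now states axis=0 explicitly where it previously relied on the library default, while the backend demos earlier in this commit all use axis=1. A minimal side-by-side, copying only values that appear in the hunks above:

#Side-by-side of the two grouping axes used in this commit (values copied from the diffs).
from hqq.core.quantize import BaseQuantizeConfig
cfg_benchmark = BaseQuantizeConfig(nbits=4, group_size=64, axis=0)  #quant_llama2_hqq_demo.py
cfg_backends  = BaseQuantizeConfig(nbits=4, group_size=64, axis=1)  #bitblas/marlin/torchao demos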
25 changes: 0 additions & 25 deletions examples/vllm/llama2_example.py

This file was deleted.
