Commit 573b2ad: update readme
mobicham committed Aug 28, 2024 · 1 parent aba8ebe
Showing 1 changed file with 13 additions and 16 deletions.

29 changes: 13 additions & 16 deletions in Readme.md
@@ -17,7 +17,7 @@ This repository contains the official implementation of Half-Quadratic Quantization
<li> HQQ is compatible with peft training.</li>
<li> We try to make HQQ fully compatible with `torch.compile` for faster inference and training.</li>
</ul>
<b>What is the quality of the quantized models? </b><br>
We have detailed benchmarks on both language and vision models. Please refer to our blog posts: <a href="https://mobiusml.github.io/hqq_blog/">HQQ</a>, <a href="https://mobiusml.github.io/1bit_blog/">HQQ+</a>.<br>

@@ -26,10 +26,10 @@ This repository contains the official implementation of Half-Quadratic Quantization

<b>What quantization settings should I use?</b><br>
You should start with `nbits=4, group_size=64, axis=1`. These settings offer a good balance between quality, VRAM usage and speed. If you want better results with the same VRAM usage, switch to `axis=0` and use the ATEN backend. If you want to use lower bits such as `nbits=2`, you should use `axis=0` with a low group-size via HQQ+, which means adding low-rank adapters and fine-tuning with a small dataset. <br>
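As a rough illustration of these defaults, here is a minimal sketch using the library's `BaseQuantizeConfig` helper (introduced later in this readme); the exact keyword names, in particular `axis`, are assumed here and may vary slightly between versions:
```Python
from hqq.core.quantize import BaseQuantizeConfig

# Recommended starting point: good balance between quality, VRAM usage and speed
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, axis=1)

# Better quality for the same VRAM usage (requires the ATEN backend, which only supports axis=0)
quant_config_axis0 = BaseQuantizeConfig(nbits=4, group_size=64, axis=0)
```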
<b>What does the `axis` parameter mean? </b><br>
The `axis` parameter is the axis along which grouping is performed. In general `axis=0` gives better results than `axis=1`, especially at lower bits. However, the optimized inference runtime only supports `axis=1` for the moment.<br>
<b>What is the difference between HQQ and HQQ+?</b><br>
HQQ+ is HQQ with trainable low-rank adapters to improve the quantization quality at lower bits.<br>

@@ -65,9 +65,6 @@ The quantization parameters are set as follows:

- ```nbits``` (int): supports 8, 4, 3, 2, 1 bits.
- ```group_size``` (int): no restrictions as long as ```weight.numel()``` is divisible by the ```group_size```.
- ```quant_zero``` (bool): if True, it quantizes the zero-point to 8-bit without grouping.
- ```quant_scale``` (bool): if True, it quantizes the scaling factor to 8-bit with a group_size of 128.
- ```offload_meta``` (bool): if True, meta-data is offloaded to the CPU.
- ```view_as_float``` (bool): if True, the quantized parameter is viewed as a float instead of an int type.

Setting ```offload_meta=True``` drastically decreases the GPU memory requirements but makes processing slower for smaller group-sizes. When turned on, you can run Llama2-70B and Mixtral with HQQ 2-bit using only 18.8GB and 13GB VRAM respectively.
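For a minimal sketch of how these parameters are typically used on a single layer, see below: it quantizes one `torch.nn.Linear` with `HQQLinear` using a config built from the settings above. The constructor arguments shown (`compute_dtype`, `device`) reflect common usage and are assumptions; the exact signature may differ between versions.

```Python
import torch
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear

# Quantization settings using the parameters described above
quant_config = BaseQuantizeConfig(nbits=4, group_size=64)

# A regular float16 linear layer to quantize
linear = torch.nn.Linear(4096, 4096, bias=False).half()

# Replace it with its HQQ-quantized counterpart
hqq_layer = HQQLinear(linear, quant_config=quant_config,
                      compute_dtype=torch.float16, device='cuda')

# The quantized layer can be called like the original nn.Linear
out = hqq_layer(torch.randn(1, 4096, dtype=torch.float16, device='cuda'))
```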
@@ -76,9 +73,9 @@ Setting ```offload_meta=True``` drastically decreases the GPU memory requirements
#### Native Backends
The following native backends can be used by the `HQQLinear` module:
```Python
HQQLinear.set_backend(HQQBackend.PYTORCH)         #Pytorch backend - Default
HQQLinear.set_backend(HQQBackend.PYTORCH_COMPILE) #Compiled Pytorch
HQQLinear.set_backend(HQQBackend.ATEN)            #Aten/CUDA backend - only axis=0 supported
```
The ```HQQBackend.ATEN``` backend is automatically installed and used by default when available.
Note that ```HQQBackend.ATEN``` only supports `axis=0`. For `axis=1` you need to use ```HQQBackend.PYTORCH``` or ```HQQBackend.PYTORCH_COMPILE```.
@@ -88,7 +85,7 @@ Below you can find the speed-up benchmark with various backends, ```HQQBackend.P
<div class="row"><center>
<div class="column">
<img src="https://github.com/mobiusml/hqq/blob/master/imgs/hqq_cuda_dequant_llama27b_titanrtx.png" alt="Titan RTX" style="width:48%">
<img src="https://github.com/mobiusml/hqq/blob/master/imgs/hqq_cuda_dequant_llama270b_a100.png" alt="A100" style="width:48%">
</div>
</center>
</div>
@@ -124,7 +121,7 @@ For usage with HF's transformers, see the example below from the <a href="https:
from transformers import AutoModelForCausalLM, HqqConfig

# All linear layers will use the same quantization config
quant_config = HqqConfig(nbits=4, group_size=64)

# Load and quantize
model = AutoModelForCausalLM.from_pretrained(
@@ -145,7 +142,7 @@ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=compute_dtype

#Quantize
from hqq.models.hf.base import AutoHQQHFModel
quant_config = BaseQuantizeConfig(nbits=4, group_size=64)
AutoHQQHFModel.quantize_model(model, quant_config=quant_config, compute_dtype=compute_dtype, device=device)
```
#### Save/Load
@@ -160,7 +157,7 @@ AutoHQQHFModel.save_quantized(model, save_dir)
model = AutoHQQHFModel.from_quantized(save_dir)
```
#### Setting a backend
You can set a native backend as follows:
```Python
HQQLinear.set_backend(HQQBackend.ATEN if axis==0 else HQQBackend.PYTORCH_COMPILE)
```
@@ -185,8 +182,8 @@ You can set up various quantization configurations for different layers by speci
#### Transformers 🤗
```Python
# Each linear layer with the same tag will use a dedicated quantization config
q4_config = {'nbits':4, 'group_size':64}
q3_config = {'nbits':3, 'group_size':32}

quant_config = HqqConfig(dynamic_config={
'self_attn.q_proj':q4_config,
@@ -202,8 +199,8 @@ quant_config = HqqConfig(dynamic_config={
#### HQQ lib
```Python
from hqq.core.quantize import *
q4_config = BaseQuantizeConfig(nbits=4, group_size=64)
q3_config = BaseQuantizeConfig(nbits=3, group_size=32)

quant_config = {'self_attn.q_proj':q4_config,
'self_attn.k_proj':q4_config,