
Commit

add_exllamav2
SunMarc committed Sep 27, 2023
1 parent 915e182 commit 5db2541
Showing 3 changed files with 89 additions and 16 deletions.
15 changes: 14 additions & 1 deletion docs/source/llm_quantization/usage_guides/quantization.mdx
@@ -89,8 +89,21 @@ empty_model.tie_weights()
quantized_model = load_quantized_model(empty_model, save_folder=save_folder, device_map="auto", disable_exllama=False)
```

Note that only 4-bit models are supported with exllama kernels for now. Furthermore, it is recommended to disable the exllama kernel when you are finetuning your model with peft.
With the release of the exllamav2 kernel, you can get faster inference speeds than with the exllama kernel. You just need to
pass `disable_exllamav2=False` to [`~optimum.gptq.load_quantized_model`]:

```py
from optimum.gptq import GPTQQuantizer, load_quantized_model
import torch

from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM

with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
empty_model.tie_weights()
quantized_model = load_quantized_model(empty_model, save_folder=save_folder, device_map="auto", disable_exllamav2=False)
```

Note that only 4-bit models are supported with the exllama/exllamav2 kernels for now. Furthermore, it is recommended to disable the exllama/exllamav2 kernels when you are fine-tuning your model with PEFT.
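
For instance, a minimal sketch (not part of this commit) of loading a quantized checkpoint with both kernels switched off before attaching PEFT adapters, reusing `model_name` and `save_folder` from the snippets above:

```py
# Hedged sketch: load the quantized model with both exllama kernels disabled,
# as recommended before fine-tuning with PEFT adapters.
import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from optimum.gptq import load_quantized_model

with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_config(
        AutoConfig.from_pretrained(model_name), torch_dtype=torch.float16
    )
empty_model.tie_weights()
quantized_model = load_quantized_model(
    empty_model,
    save_folder=save_folder,
    device_map="auto",
    disable_exllama=True,     # keep the v1 kernel off while training
    disable_exllamav2=True,   # keep the v2 kernel off while training
)
```
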
#### Fine-tune a quantized model

With the official support of adapters in the Hugging Face ecosystem, you can fine-tune models that have been quantized with GPTQ.
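
A minimal sketch of that workflow, assuming the `peft` library's LoRA API (illustrative only, not part of the diff), applied to the quantized model loaded with the kernels disabled above:

```py
# Hedged sketch: attach a LoRA adapter to the quantized model and train only
# the adapter weights; the hyperparameters here are placeholders.
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(quantized_model, lora_config)
peft_model.print_trainable_parameters()
```
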
45 changes: 33 additions & 12 deletions optimum/gptq/quantizer.py
@@ -69,7 +69,8 @@ def __init__(
module_name_preceding_first_block: Optional[List[str]] = None,
batch_size: int = 1,
pad_token_id: Optional[int] = None,
disable_exllama: bool = False,
disable_exllama: bool = True,
disable_exllamav2: bool = False,
max_input_length: Optional[int] = None,
*args,
**kwargs,
@@ -107,8 +108,10 @@ def __init__(
The batch size of the dataset
pad_token_id (`Optional[int]`, defaults to `None`):
The pad token id. Needed to prepare the dataset when `batch_size` > 1.
disable_exllama (`bool`, defaults to `False`):
disable_exllama (`bool`, defaults to `True`):
Whether to use exllama backend. Only works with `bits` = 4.
disable_exllamav2 (`bool`, defaults to `False`):
Whether to use exllamav2 backend. Only works with `bits` = 4.
max_input_length (`Optional[int]`, defaults to `None`):
The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input length.
It is specific to the exllama backend with act-order.
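
For context (not part of the diff), a hedged sketch of how these two flags are passed when constructing the quantizer; the calibration dataset name is a placeholder:

```python
# Hypothetical construction sketch reflecting the defaults documented above.
from optimum.gptq import GPTQQuantizer

quantizer = GPTQQuantizer(
    bits=4,                   # the exllama/exllamav2 kernels only support 4-bit
    dataset="c4",             # placeholder calibration dataset
    group_size=128,
    disable_exllama=True,     # v1 kernel off by default
    disable_exllamav2=False,  # v2 kernel on by default
)
```
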
@@ -128,6 +131,7 @@ def __init__(
self.batch_size = batch_size
self.pad_token_id = pad_token_id
self.disable_exllama = disable_exllama
self.disable_exllamav2 = disable_exllamav2
self.max_input_length = max_input_length
self.quant_method = QuantizationMethod.GPTQ

@@ -137,6 +141,9 @@ def __init__(
raise ValueError("group_size must be greater than 0 or equal to -1")
if not (0 < self.damp_percent < 1):
raise ValueError("damp_percent must between 0 and 1.")
if not self.disable_exllamav2 and not self.disable_exllama:
logger.warning("You have activated exllama and exllamav2 backend. Setting `disable_exllama=True` and keeping `disable_exllamav2=False`")
self.disable_exllama = True

def to_dict(self):
"""
@@ -205,6 +212,7 @@ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: str
group_size=self.group_size,
bits=self.bits,
disable_exllama=self.disable_exllama,
disable_exllamav2=self.disable_exllamav2,
)
if isinstance(module, QuantLinear):
return
@@ -440,13 +448,21 @@ def tmp(_, input, output):
layer_inputs, layer_outputs = layer_outputs, []
torch.cuda.empty_cache()

if self.bits == 4 and not self.disable_exllama:
if self.bits == 4:
# device not on gpu
if device == torch.device("cpu") or (has_device_map and any(d in devices for d in ["cpu", "disk"])):
logger.warning(
"Found modules on cpu/disk. Using Exllama backend requires all the modules to be on GPU. Setting `disable_exllama=True`"
)
self.disable_exllama = True
elif self.desc_act:
if not self.disable_exllama:
logger.warning(
"Found modules on cpu/disk. Using Exllama backend requires all the modules to be on GPU. Setting `disable_exllama=True`"
)
self.disable_exllama = True
if not self.disable_exllamav2:
logger.warning(
"Found modules on cpu/disk. Using Exllamav2 backend requires all the modules to be on GPU. Setting `disable_exllamav2=True`"
)
self.disable_exllamav2 = True
# act order and exllama
elif self.desc_act and not self.disable_exllama:
logger.warning(
"Using Exllama backend with act_order will reorder the weights offline, thus you will not be able to save the model with the right weights."
"Setting `disable_exllama=True`. You should only use Exllama backend with act_order for inference. "
@@ -475,13 +491,13 @@ def post_init_model(self, model):
model (`nn.Module`):
The input model
"""
if self.bits == 4 and not self.disable_exllama:
if self.bits == 4 and (not self.disable_exllama or not self.disable_exllamav2):
if get_device(model) == torch.device("cpu") or (
hasattr(model, "hf_device_map") and any(d in model.hf_device_map for d in ["cpu", "disk"])
):
raise ValueError(
"Found modules on cpu/disk. Using Exllama backend requires all the modules to be on GPU."
"You can deactivate exllama backend by setting `disable_exllama=True` in the quantization config object"
"Found modules on cpu/disk. Using Exllama or Exllamav2 backend requires all the modules to be on GPU."
"You can deactivate exllama backend by setting `disable_exllama=True` or `disable_exllamav2=True` in the quantization config object"
)

class StoreAttr(object):
@@ -514,6 +530,7 @@ def pack_model(
group_size=self.group_size,
bits=self.bits,
disable_exllama=self.disable_exllama,
disable_exllamav2=self.disable_exllamav2,
)
logger.info("Packing model...")
layers = get_layers(model)
@@ -579,7 +596,8 @@ def load_quantized_model(
offload_folder: Optional[str] = None,
offload_buffers: Optional[str] = None,
offload_state_dict: bool = False,
disable_exllama: bool = False,
disable_exllama: bool = True,
disable_exllamav2: bool = False,
max_input_length: Optional[int] = None,
):
"""
@@ -615,6 +633,8 @@
picked contains `"disk"` values.
disable_exllama (`bool`, defaults to `True`):
Whether to use exllama backend. Only works with `bits` = 4.
disable_exllamav2 (`bool`, defaults to `False`):
Whether to use exllamav2 backend. Only works with `bits` = 4.
max_input_length (`Optional[int]`, defaults to `None`):
The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input length.
It is specific to the exllama backend with act-order.
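
As a reference, a hedged sketch (not part of the diff) of calling `load_quantized_model` with these arguments; `empty_model` and `save_folder` are assumed to be prepared as in the docs example:

```python
# Hedged usage sketch: the new defaults pick the exllamav2 kernel; the commented
# variant opts back into the v1 kernel with an enlarged buffer for act-order models.
from optimum.gptq import load_quantized_model

model = load_quantized_model(empty_model, save_folder=save_folder, device_map="auto")

# model = load_quantized_model(
#     empty_model, save_folder=save_folder, device_map="auto",
#     disable_exllama=False, disable_exllamav2=True, max_input_length=4096,
# )
```
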
@@ -648,6 +668,7 @@
) from err
quantizer = GPTQQuantizer.from_dict(quantize_config_dict)
quantizer.disable_exllama = disable_exllama
quantizer.disable_exllamav2 = disable_exllamav2
quantizer.max_input_length = max_input_length

model = quantizer.convert_model(model)
45 changes: 42 additions & 3 deletions tests/gptq/test_quantization.py
@@ -46,6 +46,7 @@ class GPTQTest(unittest.TestCase):
group_size = 128
desc_act = False
disable_exllama = True
disable_exllamav2 = True

dataset = [
"auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
@@ -69,6 +70,7 @@ def setUpClass(cls):
group_size=cls.group_size,
desc_act=cls.desc_act,
disable_exllama=cls.disable_exllama,
disable_exllamav2=cls.disable_exllamav2,
)

cls.quantized_model = cls.quantizer.quantize_model(cls.model_fp16, cls.tokenizer)
@@ -96,6 +98,7 @@ def test_quantized_layers_class(self):
group_size=self.group_size,
bits=self.bits,
disable_exllama=self.disable_exllama,
disable_exllamav2=self.disable_exllamav2,
)
self.assertTrue(self.quantized_model.transformer.h[0].mlp.dense_4h_to_h.__class__ == QuantLinear)

@@ -133,13 +136,14 @@ def test_serialization(self):
)
empty_model.tie_weights()
quantized_model_from_saved = load_quantized_model(
empty_model, save_folder=tmpdirname, device_map={"": 0}, disable_exllama=self.disable_exllama
empty_model, save_folder=tmpdirname, device_map={"": 0}, disable_exllama=self.disable_exllama, disable_exllamav2=self.disable_exllamav2
)
self.check_inference_correctness(quantized_model_from_saved)


class GPTQTestExllama(GPTQTest):
disable_exllama = False
disable_exllamav2 = True
EXPECTED_OUTPUTS = set()
EXPECTED_OUTPUTS.add("Hello my name is John, I am a professional photographer and I")
EXPECTED_OUTPUTS.add("Hello my name is jay and i am a student at university.")
@@ -153,6 +157,7 @@ class GPTQTestActOrder(GPTQTest):
EXPECTED_OUTPUTS.add("Hello my name is nathalie, I am a young girl from")

disable_exllama = True
disable_exllamav2 = True
desc_act = True

def test_generate_quality(self):
@@ -178,7 +183,7 @@ def test_exllama_serialization(self):
)
empty_model.tie_weights()
quantized_model_from_saved = load_quantized_model(
empty_model, save_folder=tmpdirname, device_map={"": 0}, disable_exllama=False
empty_model, save_folder=tmpdirname, device_map={"": 0}, disable_exllama=False, disable_exllamav2=True
)
self.check_inference_correctness(quantized_model_from_saved)

@@ -197,7 +202,7 @@ def test_exllama_max_input_length(self):
)
empty_model.tie_weights()
quantized_model_from_saved = load_quantized_model(
empty_model, save_folder=tmpdirname, device_map={"": 0}, disable_exllama=False, max_input_length=4028
empty_model, save_folder=tmpdirname, device_map={"": 0}, disable_exllama=False, max_input_length=4028, disable_exllamav2=True
)

prompt = "I am in Paris and" * 1000
@@ -213,6 +218,40 @@
quantized_model_from_saved.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)



class GPTQTestExllamav2(GPTQTest):
desc_act = False
disable_exllama = True
disable_exllamav2 = True

def test_generate_quality(self):
# don't need to test
pass

def test_serialization(self):
# don't need to test
pass

def test_exllama_serialization(self):
"""
Test the serialization of the model and the loading of the quantized weights with exllamav2 kernel
"""
from accelerate import init_empty_weights

with tempfile.TemporaryDirectory() as tmpdirname:
self.quantizer.save(self.quantized_model, tmpdirname)
self.quantized_model.config.save_pretrained(tmpdirname)
with init_empty_weights():
empty_model = AutoModelForCausalLM.from_config(
AutoConfig.from_pretrained(self.model_name), torch_dtype=torch.float16
)
empty_model.tie_weights()
quantized_model_from_saved = load_quantized_model(
empty_model, save_folder=tmpdirname, device_map={"": 0}, disable_exllamav2=False,
)
self.check_inference_correctness(quantized_model_from_saved)


class GPTQUtilsTest(unittest.TestCase):
"""
Test utilities
