default to exllama when exllamav2 is disabled (#1494)
* fix logic

* simplify tests
SunMarc authored Oct 30, 2023
1 parent 01dd5c3 commit 8e7588b
Showing 3 changed files with 21 additions and 11 deletions.
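In practical terms, the change is visible in the documentation update below: previously, turning exllamav2 off required enabling exllama explicitly; now `disable_exllama` defaults to `None` and falls back to the exllama kernel on its own. A minimal before/after sketch (it assumes the `empty_model` and `save_folder` set up in the documentation example further down):

```py
from optimum.gptq import load_quantized_model

# Before this commit: exllama had to be enabled explicitly when disabling exllamav2.
quantized_model = load_quantized_model(
    empty_model, save_folder=save_folder, device_map="auto", disable_exllama=False, disable_exllamav2=True
)

# After this commit: disabling exllamav2 is enough; disable_exllama defaults to None
# and resolves to the exllama kernel automatically.
quantized_model = load_quantized_model(
    empty_model, save_folder=save_folder, device_map="auto", disable_exllamav2=True
)
```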
8 changes: 4 additions & 4 deletions docs/source/llm_quantization/usage_guides/quantization.mdx
@@ -76,7 +76,7 @@ quantized_model = load_quantized_model(empty_model, save_folder=save_folder, dev

### Exllama kernels for faster inference

-With the release of the exllamav2 kernel, you can get faster inference speed compared to the exllama kernels for 4-bit model. It is activated by default: `disable_exllamav2=False` in [`~optimum.gptq.load_quantized_model`]. In order to use these kernels, you need to have the entire model on gpus.
+With the release of exllamav2 kernels, you can get faster inference speed compared to exllama kernels for 4-bit model. It is activated by default: `disable_exllamav2=False` in [`~optimum.gptq.load_quantized_model`]. In order to use these kernels, you need to have the entire model on gpus.

```py
from optimum.gptq import GPTQQuantizer, load_quantized_model
@@ -89,7 +89,7 @@ empty_model.tie_weights()
quantized_model = load_quantized_model(empty_model, save_folder=save_folder, device_map="auto")
```

-If you wish to use exllama kernels, you will have to disable the exllamav2 kernel and activate the exllama kernel:
+If you wish to use exllama kernels, you will have to disable exllamav2 kernels:

```py
from optimum.gptq import GPTQQuantizer, load_quantized_model
@@ -99,10 +99,10 @@ from accelerate import init_empty_weights
with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
empty_model.tie_weights()
-quantized_model = load_quantized_model(empty_model, save_folder=save_folder, device_map="auto", disable_exllama=False, disable_exllamav2=True)
+quantized_model = load_quantized_model(empty_model, save_folder=save_folder, device_map="auto", disable_exllamav2=True)
```

-Note that only 4-bit models are supported with exllama/exllamav2 kernels for now. Furthermore, it is recommended to disable the exllama/exllamav2 kernel when you are finetuning your model with peft.
+Note that only 4-bit models are supported with exllama/exllamav2 kernels for now. Furthermore, it is recommended to disable exllama/exllamav2 kernels when you are finetuning your model with peft.

You can find the benchmark of these kernels [here](https://github.com/huggingface/optimum/tree/main/tests/benchmark#gptq-benchmark)
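To illustrate the fine-tuning note above, here is a minimal sketch that disables both kernels before attaching PEFT adapters; `model_name` and `save_folder` are placeholders for a model that already has a GPTQ checkpoint saved by `GPTQQuantizer`:

```py
import torch
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM

from optimum.gptq import load_quantized_model

model_name = "facebook/opt-125m"          # placeholder model id
save_folder = "/path/to/quantized_model"  # placeholder path to the saved GPTQ weights

with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
empty_model.tie_weights()

# Both kernels disabled, as recommended before finetuning with peft.
quantized_model = load_quantized_model(
    empty_model,
    save_folder=save_folder,
    device_map="auto",
    disable_exllama=True,
    disable_exllamav2=True,
)
```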

20 changes: 16 additions & 4 deletions optimum/gptq/quantizer.py
@@ -69,7 +69,7 @@ def __init__(
module_name_preceding_first_block: Optional[List[str]] = None,
batch_size: int = 1,
pad_token_id: Optional[int] = None,
-disable_exllama: bool = True,
+disable_exllama: Optional[bool] = None,
disable_exllamav2: bool = False,
max_input_length: Optional[int] = None,
*args,
@@ -108,7 +108,7 @@ def __init__(
The batch size of the dataset
pad_token_id (`Optional[int]`, defaults to `None`):
The pad token id. Needed to prepare the dataset when `batch_size` > 1.
-disable_exllama (`bool`, defaults to `True`):
+disable_exllama (`Optional[bool]`, defaults to `None`):
Whether to use exllama backend. Only works with `bits` = 4.
disable_exllamav2 (`bool`, defaults to `False`):
Whether to use exllamav2 backend. Only works with `bits` = 4.
@@ -145,6 +145,12 @@ def __init__(
raise ValueError(
    "disable_exllamav2 and disable_exllama are both set to `False`. Please disable one of the kernels."
)
+# If disable_exllamav2 is True, we want to fall back on the exllama kernel and not the cuda/cuda_old ones.
+if self.disable_exllama is None:
+    if self.disable_exllamav2:
+        self.disable_exllama = False
+    else:
+        self.disable_exllama = True

def to_dict(self):
    """
@@ -598,7 +604,7 @@ def load_quantized_model(
offload_folder: Optional[str] = None,
offload_buffers: Optional[str] = None,
offload_state_dict: bool = False,
-disable_exllama: bool = True,
+disable_exllama: Optional[bool] = None,
disable_exllamav2: bool = False,
max_input_length: Optional[int] = None,
):
@@ -633,7 +639,7 @@ def load_quantized_model(
If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if
the weight of the CPU state dict + the biggest shard does not fit. Will default to `True` if the device map
picked contains `"disk"` values.
-disable_exllama (`bool`, defaults to `False`):
+disable_exllama (`Optional[bool]`, defaults to `None`):
Whether to use exllama backend. Only works with `bits` = 4.
disable_exllama (`bool`, defaults to `False`):
Whether to use exllamav2 backend. Only works with `bits` = 4.
@@ -657,6 +663,12 @@ def load_quantized_model(
device_map = {"": torch.cuda.current_device()}
logger.info("The device_map was not initialized." "Setting device_map to `{'':torch.cuda.current_device()}`.")

+if disable_exllama is None:
+    if disable_exllamav2:
+        disable_exllama = False
+    else:
+        disable_exllama = True

# this branch will check if model is from huggingface
try:
    if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
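The new default can be read as a simple rule: when the caller leaves `disable_exllama` unset, the exllama kernel is enabled exactly when exllamav2 is disabled. A standalone sketch of that resolution (the helper name is illustrative, not part of optimum):

```py
from typing import Optional


def resolve_disable_exllama(disable_exllama: Optional[bool], disable_exllamav2: bool) -> bool:
    # Mirrors the fallback added in this commit: an unset disable_exllama
    # becomes the opposite of disable_exllamav2.
    if disable_exllama is None:
        disable_exllama = not disable_exllamav2
    return disable_exllama


assert resolve_disable_exllama(None, True) is False    # exllamav2 disabled -> exllama kernel used
assert resolve_disable_exllama(None, False) is True    # exllamav2 enabled  -> exllama kernel skipped
assert resolve_disable_exllama(True, True) is True     # explicit values are left untouched
```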
4 changes: 1 addition & 3 deletions tests/gptq/test_quantization.py
@@ -187,7 +187,7 @@ def test_exllama_serialization(self):
)
empty_model.tie_weights()
quantized_model_from_saved = load_quantized_model(
-    empty_model, save_folder=tmpdirname, device_map={"": 0}, disable_exllama=False, disable_exllamav2=True
+    empty_model, save_folder=tmpdirname, device_map={"": 0}, disable_exllamav2=True
)
self.check_inference_correctness(quantized_model_from_saved)

@@ -209,7 +209,6 @@ def test_exllama_max_input_length(self):
empty_model,
save_folder=tmpdirname,
device_map={"": 0},
-disable_exllama=False,
max_input_length=4028,
disable_exllamav2=True,
)
@@ -258,7 +257,6 @@ def test_exllama_serialization(self):
empty_model,
save_folder=tmpdirname,
device_map={"": 0},
-disable_exllamav2=False,
)
self.check_inference_correctness(quantized_model_from_saved)

