Skip to content

Commit

Permalink
FEAT / Bitsandbytes: Add dequantize API for bitsandbytes quantized …
Browse files Browse the repository at this point in the history
…models (#30806)

* add  method

* change method name

* more comments

* Apply suggestions from code review

Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: amyeroberts <[email protected]>

* fixup

* add docstrings and fix comment

* warn users on the de-quantized dtype

* Update src/transformers/quantizers/base.py

Co-authored-by: Marc Sun <[email protected]>

* Update src/transformers/integrations/bitsandbytes.py

Co-authored-by: amyeroberts <[email protected]>

* final suggestion - use private method

---------

Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
  • Loading branch information
3 people authored May 15, 2024
1 parent 58faa7b commit 3f43582
Show file tree
Hide file tree
Showing 9 changed files with 244 additions and 1 deletion.
23 changes: 22 additions & 1 deletion docs/source/en/quantization.md
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,27 @@ double_quant_config = BitsAndBytesConfig(
model_double_quant = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b", quantization_config=double_quant_config)
```

### Dequantizing `bitsandbytes` models

Once quantized, you can dequantize the model to the original precision. Note this might result in a small quality loss of the model. Make also sure to have enough GPU RAM to fit the dequantized model.
Below is how to perform dequantization on a 4-bit model using `bitsandbytes`.

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer

model_id = "facebook/opt-125m"

model = AutoModelForCausalLM.from_pretrained(model_id, BitsAndBytesConfig(load_in_4bit=True))
tokenizer = AutoTokenizer.from_pretrained(model_id)

model.dequantize()

text = tokenizer("Hello my name is", return_tensors="pt").to(0)

out = model.generate(**text)
print(tokenizer.decode(out[0]))
```

## EETQ
The [EETQ](https://github.com/NetEase-FuXi/EETQ) library supports int8 per-channel weight-only quantization for NVIDIA GPUS. The high-performance GEMM and GEMV kernels are from FasterTransformer and TensorRT-LLM. It requires no calibration dataset and does not need to pre-quantize your model. Moreover, the accuracy degradation is negligible owing to the per-channel quantization.

Expand Down Expand Up @@ -794,4 +815,4 @@ model = transformers.AutoModelForCausalLM.from_pretrained(
### Optimized Runtime
HQQ supports various backends, including pure Pytorch and custom dequantization CUDA kernels. These backends are suitable for older gpus and peft/QLoRA training.
For faster inference, HQQ supports 4-bit fused kernels (TorchAO and Marlin), reaching up to 200 tokens/sec on a single 4090.
For more details on how to use the backends, please refer to https://github.com/mobiusml/hqq/?tab=readme-ov-file#backend
For more details on how to use the backends, please refer to https://github.com/mobiusml/hqq/?tab=readme-ov-file#backend
2 changes: 2 additions & 0 deletions src/transformers/integrations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"replace_with_awq_linear",
],
"bitsandbytes": [
"dequantize_and_replace",
"get_keys_to_not_convert",
"replace_8bit_linear",
"replace_with_bnb_linear",
Expand Down Expand Up @@ -105,6 +106,7 @@
replace_with_awq_linear,
)
from .bitsandbytes import (
dequantize_and_replace,
get_keys_to_not_convert,
replace_8bit_linear,
replace_with_bnb_linear,
Expand Down
141 changes: 141 additions & 0 deletions src/transformers/integrations/bitsandbytes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import importlib.metadata
import inspect
import warnings
from copy import deepcopy
from inspect import signature
Expand All @@ -16,7 +17,9 @@
from ..pytorch_utils import Conv1D

if is_accelerate_available():
import accelerate
from accelerate import init_empty_weights
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
from accelerate.utils import find_tied_parameters

logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -322,3 +325,141 @@ def get_keys_to_not_convert(model):
filtered_module_names.append(name)

return filtered_module_names


# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
"""
Helper function to dequantize 4bit or 8bit bnb weights.
If the weight is not a bnb quantized weight, it will be returned as is.
"""
if not isinstance(weight, torch.nn.Parameter):
raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead")

cls_name = weight.__class__.__name__
if cls_name not in ("Params4bit", "Int8Params"):
return weight

if cls_name == "Params4bit":
output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
logger.warning_once(
f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
)
return output_tensor

if state.SCB is None:
state.SCB = weight.SCB

im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
im, Sim = bnb.functional.transform(im, "col32")
if state.CxB is None:
state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()


def _create_accelerate_new_hook(old_hook):
r"""
Creates a new hook based on the old hook. Use it only if you know what you are doing !
This method is a copy of: https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245
with some changes
"""
old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
old_hook_attr = old_hook.__dict__
filtered_old_hook_attr = {}
old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
for k in old_hook_attr.keys():
if k in old_hook_init_signature.parameters:
filtered_old_hook_attr[k] = old_hook_attr[k]
new_hook = old_hook_cls(**filtered_old_hook_attr)
return new_hook


def _dequantize_and_replace(
model,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
has_been_replaced=False,
):
"""
Converts a quantized model into its dequantized original version. The newly converted model will have
some performance drop compared to the original model before quantization - use it only for specific usecases
such as QLoRA adapters merging.
Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
"""
quant_method = quantization_config.quantization_method()

target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit

for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)

if isinstance(module, target_cls) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
current_key_name_str = ".".join(current_key_name)

if not any(
(key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
):
bias = getattr(module, "bias", None)

device = module.weight.device
with init_empty_weights():
new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None)

if quant_method == "llm_int8":
state = module.state
else:
state = None

new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))

if bias is not None:
new_module.bias = bias

# Create a new hook and attach it in case we use accelerate
if hasattr(module, "_hf_hook"):
old_hook = module._hf_hook
new_hook = _create_accelerate_new_hook(old_hook)

remove_hook_from_module(module)
add_hook_to_module(new_module, new_hook)

new_module.to(device)
model._modules[name] = new_module
if len(list(module.children())) > 0:
_, has_been_replaced = _dequantize_and_replace(
module,
modules_to_not_convert,
current_key_name,
quantization_config,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced


def dequantize_and_replace(
model,
modules_to_not_convert=None,
quantization_config=None,
):
model, has_been_replaced = _dequantize_and_replace(
model,
modules_to_not_convert=modules_to_not_convert,
quantization_config=quantization_config,
)

if not has_been_replaced:
logger.warning(
"For some reason the model has not been properly dequantized. You might see unexpected behavior."
)

return model
12 changes: 12 additions & 0 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1327,6 +1327,18 @@ def post_init(self):
self.init_weights()
self._backward_compatibility_gradient_checkpointing()

def dequantize(self):
"""
Potentially dequantize the model in case it has been quantized by a quantization method that support
dequantization.
"""
hf_quantizer = getattr(self, "hf_quantizer", None)

if hf_quantizer is None:
raise ValueError("You need to first quantize your model in order to dequantize it")

return hf_quantizer.dequantize(self)

def _backward_compatibility_gradient_checkpointing(self):
if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
self.gradient_checkpointing_enable()
Expand Down
17 changes: 17 additions & 0 deletions src/transformers/quantizers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,23 @@ def postprocess_model(self, model: "PreTrainedModel", **kwargs):
"""
return self._process_model_after_weight_loading(model, **kwargs)

def dequantize(self, model):
"""
Potentially dequantize the model to retrive the original model, with some loss in accuracy / performance.
Note not all quantization schemes support this.
"""
model = self._dequantize(model)

# Delete quantizer and quantization config
del model.hf_quantizer

return model

def _dequantize(self, model):
raise NotImplementedError(
f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
)

@abstractmethod
def _process_model_before_weight_loading(self, model, **kwargs):
...
Expand Down
8 changes: 8 additions & 0 deletions src/transformers/quantizers/quantizer_bnb_4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,11 @@ def is_serializable(self):
@property
def is_trainable(self) -> bool:
return True

def _dequantize(self, model):
from ..integrations import dequantize_and_replace

model = dequantize_and_replace(
model, self.modules_to_not_convert, quantization_config=self.quantization_config
)
return model
8 changes: 8 additions & 0 deletions src/transformers/quantizers/quantizer_bnb_8bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,11 @@ def is_serializable(self):
@property
def is_trainable(self) -> bool:
return version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.37.0")

def _dequantize(self, model):
from ..integrations import dequantize_and_replace

model = dequantize_and_replace(
model, self.modules_to_not_convert, quantization_config=self.quantization_config
)
return model
17 changes: 17 additions & 0 deletions tests/quantization/bnb/test_4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,23 @@ def test_generate_quality_config(self):

self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

def test_generate_quality_dequantize(self):
r"""
Test that loading the model and unquantize it produce correct results
"""
bnb_config = BitsAndBytesConfig(load_in_4bit=True)

model_4bit = AutoModelForCausalLM.from_pretrained(
self.model_name, quantization_config=bnb_config, device_map="auto"
)

model_4bit.dequantize()

encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model_4bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)

self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

def test_device_and_dtype_assignment(self):
r"""
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
Expand Down
17 changes: 17 additions & 0 deletions tests/quantization/bnb/test_mixed_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,23 @@ def test_generate_quality_config(self):

self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

def test_generate_quality_dequantize(self):
r"""
Test that loading the model and dequantizing it produce correct results
"""
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

model_8bit = AutoModelForCausalLM.from_pretrained(
self.model_name, quantization_config=bnb_config, device_map="auto"
)

model_8bit.dequantize()

encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model_8bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)

self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

def test_raise_if_config_and_load_in_8bit(self):
r"""
Test that loading the model with the config and `load_in_8bit` raises an error
Expand Down

0 comments on commit 3f43582

Please sign in to comment.