diff --git a/docs/source/en/quantization.md b/docs/source/en/quantization.md
index ae4f44f6b800b7..22645505664607 100755
--- a/docs/source/en/quantization.md
+++ b/docs/source/en/quantization.md
@@ -642,6 +642,27 @@ double_quant_config = BitsAndBytesConfig(
 model_double_quant = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-13b", quantization_config=double_quant_config)
 ```
+### Dequantizing `bitsandbytes` models
+
+Once quantized, you can dequantize a model back to its original precision. Note that this might result in a small loss in model quality. Also make sure you have enough GPU RAM to fit the dequantized model.
+Below is how to dequantize a 4-bit model with `bitsandbytes`.
+
+```python
+from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
+
+model_id = "facebook/opt-125m"
+
+model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=BitsAndBytesConfig(load_in_4bit=True))
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+model.dequantize()
+
+text = tokenizer("Hello my name is", return_tensors="pt").to(0)
+
+out = model.generate(**text)
+print(tokenizer.decode(out[0]))
+```
+
 ## EETQ
 The [EETQ](https://github.com/NetEase-FuXi/EETQ) library supports int8 per-channel weight-only quantization for NVIDIA GPUS. The high-performance GEMM and GEMV kernels are from FasterTransformer and TensorRT-LLM. It requires no calibration dataset and does not need to pre-quantize your model. Moreover, the accuracy degradation is negligible owing to the per-channel quantization.
@@ -794,4 +815,4 @@ model = transformers.AutoModelForCausalLM.from_pretrained(
 ### Optimized Runtime
 HQQ supports various backends, including pure Pytorch and custom dequantization CUDA kernels. These backends are suitable for older gpus and peft/QLoRA training. For faster inference, HQQ supports 4-bit fused kernels (TorchAO and Marlin), reaching up to 200 tokens/sec on a single 4090.
-For more details on how to use the backends, please refer to https://github.com/mobiusml/hqq/?tab=readme-ov-file#backend
\ No newline at end of file
+For more details on how to use the backends, please refer to https://github.com/mobiusml/hqq/?tab=readme-ov-file#backend
diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py
index 239b72782ff922..9b838bd1608490 100755
--- a/src/transformers/integrations/__init__.py
+++ b/src/transformers/integrations/__init__.py
@@ -25,6 +25,7 @@
         "replace_with_awq_linear",
     ],
     "bitsandbytes": [
+        "dequantize_and_replace",
         "get_keys_to_not_convert",
         "replace_8bit_linear",
         "replace_with_bnb_linear",
@@ -105,6 +106,7 @@
         replace_with_awq_linear,
     )
     from .bitsandbytes import (
+        dequantize_and_replace,
         get_keys_to_not_convert,
         replace_8bit_linear,
         replace_with_bnb_linear,
diff --git a/src/transformers/integrations/bitsandbytes.py b/src/transformers/integrations/bitsandbytes.py
index f340c1db823731..74d1c92b11fc46 100644
--- a/src/transformers/integrations/bitsandbytes.py
+++ b/src/transformers/integrations/bitsandbytes.py
@@ -1,4 +1,5 @@
 import importlib.metadata
+import inspect
 import warnings
 from copy import deepcopy
 from inspect import signature
@@ -16,7 +17,9 @@
 from ..pytorch_utils import Conv1D
 
 if is_accelerate_available():
+    import accelerate
     from accelerate import init_empty_weights
+    from accelerate.hooks import add_hook_to_module, remove_hook_from_module
     from accelerate.utils import find_tied_parameters
 
 logger = logging.get_logger(__name__)
@@ -322,3 +325,141 @@ def get_keys_to_not_convert(model):
         filtered_module_names.append(name)
 
     return filtered_module_names
+
+
+# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
+def dequantize_bnb_weight(weight: torch.nn.Parameter, state=None):
+    """
+    Helper function to dequantize 4bit or 8bit bnb weights.
+
+    If the weight is not a bnb quantized weight, it will be returned as is.
+    """
+    if not isinstance(weight, torch.nn.Parameter):
+        raise TypeError(f"Input weight should be of type nn.Parameter, got {type(weight)} instead")
+
+    cls_name = weight.__class__.__name__
+    if cls_name not in ("Params4bit", "Int8Params"):
+        return weight
+
+    if cls_name == "Params4bit":
+        output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
+        logger.warning_once(
+            f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
+        )
+        return output_tensor
+
+    if state.SCB is None:
+        state.SCB = weight.SCB
+
+    im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
+    im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
+    im, Sim = bnb.functional.transform(im, "col32")
+    if state.CxB is None:
+        state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
+    out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
+    return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
+
+
+def _create_accelerate_new_hook(old_hook):
+    r"""
+    Creates a new hook based on the old hook. Use it only if you know what you are doing !
+    This method is a copy of: https://github.com/huggingface/peft/blob/748f7968f3a31ec06a1c2b0328993319ad9a150a/src/peft/utils/other.py#L245
+    with some changes
+    """
+    old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
+    old_hook_attr = old_hook.__dict__
+    filtered_old_hook_attr = {}
+    old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
+    for k in old_hook_attr.keys():
+        if k in old_hook_init_signature.parameters:
+            filtered_old_hook_attr[k] = old_hook_attr[k]
+    new_hook = old_hook_cls(**filtered_old_hook_attr)
+    return new_hook
+
+
+def _dequantize_and_replace(
+    model,
+    modules_to_not_convert=None,
+    current_key_name=None,
+    quantization_config=None,
+    has_been_replaced=False,
+):
+    """
+    Converts a quantized model into its dequantized original version. The newly converted model will have
+    some performance drop compared to the original model before quantization - use it only for specific use cases
+    such as QLoRA adapters merging.
+
+    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
+    """
+    quant_method = quantization_config.quantization_method()
+
+    target_cls = bnb.nn.Linear8bitLt if quant_method == "llm_int8" else bnb.nn.Linear4bit
+
+    for name, module in model.named_children():
+        if current_key_name is None:
+            current_key_name = []
+        current_key_name.append(name)
+
+        if isinstance(module, target_cls) and name not in modules_to_not_convert:
+            # Check if the current key is not in the `modules_to_not_convert`
+            current_key_name_str = ".".join(current_key_name)
+
+            if not any(
+                (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
+            ):
+                bias = getattr(module, "bias", None)
+
+                device = module.weight.device
+                with init_empty_weights():
+                    new_module = torch.nn.Linear(module.in_features, module.out_features, bias=bias is not None)
+
+                if quant_method == "llm_int8":
+                    state = module.state
+                else:
+                    state = None
+
+                new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
+
+                if bias is not None:
+                    new_module.bias = bias
+
+                # Create a new hook and attach it in case we use accelerate
+                if hasattr(module, "_hf_hook"):
+                    old_hook = module._hf_hook
+                    new_hook = _create_accelerate_new_hook(old_hook)
+
+                    remove_hook_from_module(module)
+                    add_hook_to_module(new_module, new_hook)
+
+                new_module.to(device)
+                model._modules[name] = new_module
+                has_been_replaced = True
+        if len(list(module.children())) > 0:
+            _, has_been_replaced = _dequantize_and_replace(
+                module,
+                modules_to_not_convert,
+                current_key_name,
+                quantization_config,
+                has_been_replaced=has_been_replaced,
+            )
+        # Remove the last key for recursion
+        current_key_name.pop(-1)
+    return model, has_been_replaced
+
+
+def dequantize_and_replace(
+    model,
+    modules_to_not_convert=None,
+    quantization_config=None,
+):
+    model, has_been_replaced = _dequantize_and_replace(
+        model,
+        modules_to_not_convert=modules_to_not_convert,
+        quantization_config=quantization_config,
+    )
+
+    if not has_been_replaced:
+        logger.warning(
+            "For some reason the model has not been properly dequantized. You might see unexpected behavior."
+        )
+
+    return model
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index d19f928340c1e0..37f35a3433641c 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1327,6 +1327,18 @@ def post_init(self):
         self.init_weights()
         self._backward_compatibility_gradient_checkpointing()
 
+    def dequantize(self):
+        """
+        Potentially dequantize the model in case it has been quantized by a quantization method that supports
+        dequantization.
+        """
+        hf_quantizer = getattr(self, "hf_quantizer", None)
+
+        if hf_quantizer is None:
+            raise ValueError("You need to first quantize your model in order to dequantize it")
+
+        return hf_quantizer.dequantize(self)
+
     def _backward_compatibility_gradient_checkpointing(self):
         if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
             self.gradient_checkpointing_enable()
diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py
index 226995eea0ca5b..1cfb0d58563ede 100644
--- a/src/transformers/quantizers/base.py
+++ b/src/transformers/quantizers/base.py
@@ -194,6 +194,23 @@ def postprocess_model(self, model: "PreTrainedModel", **kwargs):
         """
         return self._process_model_after_weight_loading(model, **kwargs)
 
+    def dequantize(self, model):
+        """
+        Potentially dequantize the model to retrieve the original model, with some loss in accuracy / performance.
+        Note that not all quantization schemes support this.
+        """
+        model = self._dequantize(model)
+
+        # Delete the quantizer so the model is no longer treated as quantized
+        del model.hf_quantizer
+
+        return model
+
+    def _dequantize(self, model):
+        raise NotImplementedError(
+            f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
+        )
+
     @abstractmethod
     def _process_model_before_weight_loading(self, model, **kwargs):
         ...
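For readers unfamiliar with the LLM.int8 layout used in `dequantize_bnb_weight` above: the 8-bit branch recovers the fp16 weight by pushing an identity matrix through bitsandbytes' int8 matmul, which is numerically equivalent to undoing the per-row absmax scaling. The sketch below illustrates that idea with plain PyTorch; it is an illustration of the math, not the library code, and the tensor names are made up.

```python
# Sketch of what the 8-bit branch of `dequantize_bnb_weight` computes, using
# plain PyTorch instead of bitsandbytes kernels. Names (W, SCB, W_int8) are
# illustrative only.
import torch

torch.manual_seed(0)
W = torch.randn(8, 16, dtype=torch.float16)  # original fp16 weight

# LLM.int8-style vector-wise quantization: int8 codes plus one fp16 scale per row (SCB).
SCB = W.abs().max(dim=1, keepdim=True).values
W_int8 = torch.round(W / SCB * 127).to(torch.int8)

# Dequantization undoes the scaling; this is what the igemmlt/mm_dequant calls
# above reconstruct via the identity-matrix matmul.
W_dequant = W_int8.to(torch.float16) * SCB / 127

print((W - W_dequant).abs().max())  # small residual quantization error, not exactly zero
```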
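The `modeling_utils.py` and `base.py` changes wire `PreTrainedModel.dequantize()` through `HfQuantizer.dequantize()` into the backend-specific `_dequantize()`. Below is a self-contained toy version of that delegation chain to make the control flow explicit; `TinyModel` and `TinyQuantizer` are stand-ins invented for this sketch, not transformers classes.

```python
# Toy sketch of the delegation chain added in this PR (stand-in classes only).


class TinyQuantizer:
    def dequantize(self, model):
        model = self._dequantize(model)
        # Mirror of base.py: once dequantized, drop the quantizer reference.
        del model.hf_quantizer
        return model

    def _dequantize(self, model):
        # Mirror of base.py: backends that cannot dequantize simply do not override this.
        raise NotImplementedError("This quantization scheme has no `dequantize` implementation.")


class TinyBnbQuantizer(TinyQuantizer):
    def _dequantize(self, model):
        model.is_quantized = False  # placeholder for the real weight replacement
        return model


class TinyModel:
    def __init__(self, quantizer=None):
        self.hf_quantizer = quantizer
        self.is_quantized = quantizer is not None

    def dequantize(self):
        # Mirror of modeling_utils.py: refuse to dequantize a non-quantized model.
        if getattr(self, "hf_quantizer", None) is None:
            raise ValueError("You need to first quantize your model in order to dequantize it")
        return self.hf_quantizer.dequantize(self)


model = TinyModel(TinyBnbQuantizer()).dequantize()
assert not model.is_quantized and not hasattr(model, "hf_quantizer")
```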
diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py
index c49610bbfe87a2..f204d4e02c56e0 100644
--- a/src/transformers/quantizers/quantizer_bnb_4bit.py
+++ b/src/transformers/quantizers/quantizer_bnb_4bit.py
@@ -312,3 +312,11 @@ def is_serializable(self):
     @property
     def is_trainable(self) -> bool:
         return True
+
+    def _dequantize(self, model):
+        from ..integrations import dequantize_and_replace
+
+        model = dequantize_and_replace(
+            model, self.modules_to_not_convert, quantization_config=self.quantization_config
+        )
+        return model
diff --git a/src/transformers/quantizers/quantizer_bnb_8bit.py b/src/transformers/quantizers/quantizer_bnb_8bit.py
index 2d24c3404972ab..906457f31052ef 100644
--- a/src/transformers/quantizers/quantizer_bnb_8bit.py
+++ b/src/transformers/quantizers/quantizer_bnb_8bit.py
@@ -281,3 +281,11 @@ def is_serializable(self):
     @property
     def is_trainable(self) -> bool:
         return version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.37.0")
+
+    def _dequantize(self, model):
+        from ..integrations import dequantize_and_replace
+
+        model = dequantize_and_replace(
+            model, self.modules_to_not_convert, quantization_config=self.quantization_config
+        )
+        return model
diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index 39d5598e37576a..443b1020a30e07 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -239,6 +239,23 @@ def test_generate_quality_config(self):
 
         self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
 
+    def test_generate_quality_dequantize(self):
+        r"""
+        Test that loading the model and dequantizing it produces correct results
+        """
+        bnb_config = BitsAndBytesConfig(load_in_4bit=True)
+
+        model_4bit = AutoModelForCausalLM.from_pretrained(
+            self.model_name, quantization_config=bnb_config, device_map="auto"
+        )
+
+        model_4bit.dequantize()
+
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
+        output_sequences = model_4bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
+
+        self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
+
     def test_device_and_dtype_assignment(self):
         r"""
         Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
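As a usage note alongside the 4-bit test above: a typical end-to-end workflow is to load a model in 4-bit or 8-bit, call `dequantize()`, and then treat the result as a regular full-precision checkpoint. Below is a hedged sketch of that flow; it assumes the small `facebook/opt-125m` checkpoint, a CUDA GPU with enough free memory for the dequantized weights, and the output directory name is just an example.

```python
# Illustrative sketch (not part of the PR): dequantize an 8-bit model and
# inspect the memory footprint before and after.
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_id = "facebook/opt-125m"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)
print(f"8-bit footprint: {model.get_memory_footprint() / 1e6:.0f} MB")

model.dequantize()  # quantized Linear layers are now plain nn.Linear again
print(f"dequantized footprint: {model.get_memory_footprint() / 1e6:.0f} MB")

# The dequantized model can then be used like a regular checkpoint, e.g. saved
# to disk (the path here is just an example).
model.save_pretrained("opt-125m-dequantized")
```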
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index 227273d278d146..8043a1201b765e 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -285,6 +285,23 @@ def test_generate_quality_config(self):
 
         self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
 
+    def test_generate_quality_dequantize(self):
+        r"""
+        Test that loading the model and dequantizing it produces correct results
+        """
+        bnb_config = BitsAndBytesConfig(load_in_8bit=True)
+
+        model_8bit = AutoModelForCausalLM.from_pretrained(
+            self.model_name, quantization_config=bnb_config, device_map="auto"
+        )
+
+        model_8bit.dequantize()
+
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
+        output_sequences = model_8bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
+
+        self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
+
     def test_raise_if_config_and_load_in_8bit(self):
         r"""
         Test that loading the model with the config and `load_in_8bit` raises an error
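Two follow-up assertions the tests above do not cover are sketched below in the same unittest style. These are suggestions rather than part of the PR; they assume a CUDA GPU with bitsandbytes installed, and the class and test names are made up.

```python
# Possible extra checks (illustrative only): dequantizing an unquantized model
# should raise, and no bitsandbytes Linear subclass should remain afterwards.
import unittest

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig


class Bnb8bitDequantizeExtraTest(unittest.TestCase):
    model_name = "facebook/opt-125m"  # small model, used here only for illustration

    def test_dequantize_without_quantization_raises(self):
        model = AutoModelForCausalLM.from_pretrained(self.model_name)
        with self.assertRaises(ValueError):
            model.dequantize()

    def test_no_bnb_linear_left_after_dequantize(self):
        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            device_map="auto",
        )
        model.dequantize()
        for module in model.modules():
            if isinstance(module, torch.nn.Linear):
                # Only plain nn.Linear layers (not bnb subclasses) should remain.
                self.assertEqual(type(module), torch.nn.Linear)


if __name__ == "__main__":
    unittest.main()
```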