Allow compressed-tensors quantized model to be trained (#34520)
* populate quantization_config for kv-cache-scheme only configs

* make compressed-tensors quantized models trainable

* populate versions on quant config

* pass oneshot then finetune

* remove breakpoint

* SunMarc comments and fix to_dict logic

* lint

* lint

* test

* comment

* comments
horheynm authored Nov 28, 2024
1 parent 44af935 commit 57ca9e6
Showing 4 changed files with 36 additions and 12 deletions.
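
In practice, the change lets a compressed-tensors checkpoint be fine-tuned directly with `Trainer`, with no PEFT adapters attached. A minimal sketch, assuming the compressed-tensors package is installed; the model id "my-org/llm-compressed" is a placeholder, and the one-sample dataset exists only to keep the sketch self-contained:

from datasets import Dataset
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# Placeholder id: any checkpoint saved with a compressed-tensors
# quantization_config gets an hf_quantizer attached on load.
model = AutoModelForCausalLM.from_pretrained("my-org/llm-compressed")

# Toy dataset so the sketch runs end to end.
train_ds = Dataset.from_dict({"input_ids": [[1, 2, 3, 4]], "labels": [[1, 2, 3, 4]]})

# Before this commit, constructing Trainer here raised "You cannot perform
# fine-tuning on purely quantized models ..." unless PEFT adapters were
# attached; is_qat_trainable now bypasses that guard.
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out", per_device_train_batch_size=1),
    train_dataset=train_ds,
)
trainer.train()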
5 changes: 5 additions & 0 deletions src/transformers/quantizers/base.py
@@ -226,6 +226,11 @@ def _dequantize(self, model):
             f"{self.quantization_config.quant_method} has no implementation of `dequantize`, please raise an issue on GitHub."
         )
 
+    @property
+    def is_qat_trainable(self) -> bool:
+        """Flag indicating whether the quantized model can carry out quantization aware training"""
+        return False
+
     @abstractmethod
     def _process_model_before_weight_loading(self, model, **kwargs): ...
 
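The base class defaults the new flag to False, so every existing quantizer keeps its current behavior; a backend opts into full fine-tuning by overriding the property. A hypothetical sketch (this subclass is illustrative only, not part of the commit):

from transformers.quantizers.base import HfQuantizer

class MyQatQuantizer(HfQuantizer):
    """Illustrative quantizer whose quantized weights stay trainable."""

    def _process_model_before_weight_loading(self, model, **kwargs):
        pass  # a real backend swaps in quantized modules here

    def _process_model_after_weight_loading(self, model, **kwargs):
        pass

    @property
    def is_trainable(self) -> bool:
        return True

    @property
    def is_qat_trainable(self) -> bool:
        return True  # opt in: Trainer allows full fine-tuning without PEFT

    def is_serializable(self, safe_serialization=None) -> bool:
        return True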
17 changes: 12 additions & 5 deletions src/transformers/quantizers/quantizer_compressed_tensors.py
@@ -65,12 +65,19 @@ def _process_model_before_weight_loading(self, model, **kwargs):
         ct_quantization_config = self.compressor.quantization_config
         apply_quantization_config(model, ct_quantization_config, run_compressed=True)
 
-    def _process_model_after_weight_loading(self, model, **kwargs):
+    def _process_model_after_weight_loading(self, model, **kwargs) -> None:
         pass
 
     @property
-    def is_trainable(self):
-        return False
+    def is_trainable(self) -> bool:
+        """Models quantized using compressed tensors can be finetuned"""
+        return True
 
-    def is_serializable(self, safe_serialization=None):
-        return False
+    @property
+    def is_qat_trainable(self) -> bool:
+        """Loaded Models can carry out quantization aware training"""
+        return True
+
+    def is_serializable(self, safe_serialization=None) -> bool:
+        """Models quantized using compressed tensors can be saved to disk"""
+        return True
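
Together, the three overrides make a loaded compressed-tensors model report itself as trainable and saveable. A quick check on a loaded model (the model id is a placeholder):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("my-org/llm-compressed")  # placeholder id

quantizer = model.hf_quantizer          # attached by from_pretrained for quantized checkpoints
print(quantizer.is_trainable)           # True: fine-tuning is allowed
print(quantizer.is_qat_trainable)       # True: even without PEFT adapters
print(quantizer.is_serializable())      # True: save_pretrained works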
6 changes: 5 additions & 1 deletion src/transformers/trainer.py
@@ -540,14 +540,18 @@ def __init__(
             getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
         )
 
+        _is_model_quantized_and_qat_trainable = getattr(model, "hf_quantizer", None) is not None and getattr(
+            model.hf_quantizer, "is_qat_trainable", False
+        )
+
         # Filter out quantized + compiled models
         if _is_quantized_and_base_model and hasattr(model, "_orig_mod"):
             raise ValueError(
                 "You cannot fine-tune quantized model with `torch.compile()` make sure to pass a non-compiled model when fine-tuning a quantized model with PEFT"
             )
 
         # At this stage the model is already loaded
-        if _is_quantized_and_base_model and not _is_peft_model(model):
+        if _is_quantized_and_base_model and not _is_peft_model(model) and not _is_model_quantized_and_qat_trainable:
             raise ValueError(
                 "You cannot perform fine-tuning on purely quantized models. Please attach trainable adapters on top of"
                 " the quantized model to correctly perform fine-tuning. Please see: https://huggingface.co/docs/transformers/peft"
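The guard's behavior can be reproduced in isolation with stand-in objects; a self-contained sketch of the boolean logic above (the `_is_peft_model` clause is omitted since the stand-in is not a PEFT model):

class FakeQuantizer:
    is_trainable = True
    is_qat_trainable = True  # what the compressed-tensors quantizer now reports

class FakeModel:
    hf_quantizer = FakeQuantizer()

model = FakeModel()

_is_quantized_and_base_model = (
    getattr(model, "hf_quantizer", None) is not None and model.hf_quantizer.is_trainable
)
_is_model_quantized_and_qat_trainable = getattr(model, "hf_quantizer", None) is not None and getattr(
    model.hf_quantizer, "is_qat_trainable", False
)

# The ValueError fires only for a quantized, non-PEFT model whose quantizer
# does not support QAT; the new clause keeps it from firing here.
print(_is_quantized_and_base_model and not _is_model_quantized_and_qat_trainable)  # False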
20 changes: 14 additions & 6 deletions src/transformers/utils/quantization_config.py
@@ -1150,6 +1150,7 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
         Returns:
             [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters.
         """
+
         if "quantization_config" in config_dict:
             config_dict = dict(
                 sparsity_config=config_dict.get("sparsity_config"),
@@ -1160,16 +1161,23 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
 
     def to_dict(self) -> Dict[str, Any]:
         """
         Quantization config to be added to config.json
+        Serializes this instance to a Python dictionary. Returns:
+            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
         """
-        quantization_config = self.quantization_config.dict() if self.quantization_config is not None else None
-        sparsity_config = self.sparsity_config.dict() if self.sparsity_config is not None else None
+        quantization_config = {}
+        if self.quantization_config is not None:
+            quantization_config = self.quantization_config.dict()
+        else:
+            quantization_config["quant_method"] = QuantizationMethod.COMPRESSED_TENSORS
 
-        return {
-            "quantization_config": quantization_config,
-            "sparsity_config": sparsity_config,
-        }
+        if self.sparsity_config is not None:
+            quantization_config["sparsity_config"] = self.sparsity_config.dict()
+        else:
+            quantization_config["sparsity_config"] = {}
+
+        return quantization_config
 
     def to_diff_dict(self) -> Dict[str, Any]:
         """
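With the rework, `to_dict()` always returns a flat dictionary that carries a `quant_method`, even when `self.quantization_config` is None because no quantization settings were supplied. A hedged sketch, assuming the compressed-tensors package is installed; the printed shape is the expected structure, not a captured transcript:

from transformers.utils.quantization_config import CompressedTensorsConfig

# No quantization settings: self.quantization_config stays None, so the
# else-branches above fill in quant_method and an empty sparsity_config.
cfg = CompressedTensorsConfig()

print(cfg.to_dict())
# expected: {"quant_method": "compressed-tensors", "sparsity_config": {}}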
