Fix openvino quantization config (huggingface#773)
* enable string quant method

* fix

* fix docstrings

* format

* awq should be set to None for int8 quantization
echarlaix authored Jun 21, 2024
1 parent c19723e commit 80e9bf6
Showing 4 changed files with 25 additions and 12 deletions.
3 changes: 1 addition & 2 deletions optimum/commands/export/openvino.py
@@ -20,7 +20,6 @@
from typing import TYPE_CHECKING, Optional

from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from transformers.utils.quantization_config import QuantizationMethod

from ...exporters import TasksManager
from ...intel.utils.import_utils import DIFFUSERS_IMPORT_ERROR, is_diffusers_available
@@ -289,7 +288,7 @@ def _get_default_int4_config(model_id_or_path, library_name):
"all_layers": None if is_int8 else self.args.all_layers,
"dataset": self.args.dataset,
"num_samples": self.args.num_samples,
"quant_method": QuantizationMethod.AWQ if self.args.awq else None,
"quant_method": "awq" if self.args.awq else "default",
"sensitivity_metric": self.args.sensitivity_metric,
"scale_estimation": self.args.scale_estimation,
}
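Note: with this change the export command no longer needs `transformers.utils.quantization_config.QuantizationMethod`; it passes a plain string ("awq" or "default") that `OVWeightQuantizationConfig` converts to the `OVQuantizationMethod` enum (see configuration.py below). A minimal sketch of the resulting flow, with illustrative values standing in for the parsed CLI arguments:

    from optimum.intel import OVWeightQuantizationConfig

    awq = True  # stands in for self.args.awq
    quantization_config = OVWeightQuantizationConfig.from_dict(
        {
            "bits": 4,
            # Plain string instead of QuantizationMethod.AWQ:
            "quant_method": "awq" if awq else "default",
        }
    )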
9 changes: 5 additions & 4 deletions optimum/intel/openvino/configuration.py
@@ -20,7 +20,7 @@

import torch
from transformers import PretrainedConfig
from transformers.utils.quantization_config import QuantizationConfigMixin, QuantizationMethod
from transformers.utils.quantization_config import QuantizationConfigMixin

from optimum.configuration_utils import BaseConfig

@@ -78,6 +78,7 @@
class OVQuantizationMethod(str, Enum):
DEFAULT = "default"
HYBRID = "hybrid"
AWQ = "awq"


@dataclass
@@ -171,7 +172,7 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase):
entries provided via this argument are used to create an instance of `nncf.IgnoredScope` class.
num_samples (`int`, *optional*):
The maximum number of samples composing the calibration dataset.
quant_method (`str`, defaults of OVQuantizationMethod.DEFAULT):
quant_method (`str` or `OVQuantizationMethod`, defaults to OVQuantizationMethod.DEFAULT):
Weight compression method to apply. Possible options:
- "default": default weight quantization will be applied.
- "awq": compressed weights will be computed according to the Activation-Aware-Quantization (AWQ)
@@ -199,7 +200,7 @@ def __init__(
sensitivity_metric: Optional[str] = None,
ignored_scope: Optional[dict] = None,
num_samples: Optional[int] = None,
quant_method: Union[QuantizationMethod, OVQuantizationMethod] = OVQuantizationMethod.DEFAULT,
quant_method: Union[str, OVQuantizationMethod] = OVQuantizationMethod.DEFAULT,
scale_estimation: bool = None,
**kwargs,
):
@@ -210,7 +211,7 @@
self.ratio = ratio
self.all_layers = all_layers
self.sensitivity_metric = sensitivity_metric
self.quant_method = quant_method
self.quant_method = OVQuantizationMethod(quant_method) if isinstance(quant_method, str) else quant_method
self.scale_estimation = scale_estimation
self.post_init()

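Note: `__init__` now normalizes `quant_method`, so the string and enum forms are interchangeable, and an invalid string raises a `ValueError` from the `Enum` constructor. A quick sketch, assuming `post_init` accepts these illustrative argument values:

    from optimum.intel import OVWeightQuantizationConfig
    from optimum.intel.openvino.configuration import OVQuantizationMethod

    config = OVWeightQuantizationConfig(bits=4, dataset="c4", quant_method="awq")
    assert config.quant_method is OVQuantizationMethod.AWQ  # string coerced to the enum member

    config = OVWeightQuantizationConfig(bits=4, quant_method=OVQuantizationMethod.HYBRID)
    assert config.quant_method is OVQuantizationMethod.HYBRID  # enum passes through unchanged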
3 changes: 1 addition & 2 deletions optimum/intel/openvino/quantization.py
@@ -38,7 +38,6 @@
from transformers import AutoTokenizer, DataCollator, PreTrainedModel, default_data_collator
from transformers.pytorch_utils import Conv1D
from transformers.utils import is_accelerate_available
from transformers.utils.quantization_config import QuantizationMethod

from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed
from optimum.exporters.tasks import TasksManager
@@ -828,7 +827,7 @@ def _weight_only_quantization(
group_size=config.group_size,
all_layers=config.all_layers,
sensitivity_metric=sensitivity_metric,
awq=config.quant_method == QuantizationMethod.AWQ or None,
awq=config.quant_method == OVQuantizationMethod.AWQ or None,
ignored_scope=config.get_ignored_scope_instance(),
dataset=dataset,
subset_size=config.num_samples if config.num_samples else 128,
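Note: the `== ... or None` idiom means nncf's `compress_weights` receives `awq=True` or `awq=None`, never `awq=False`; the commit message's last bullet ("awq should be set to None for int8 quantization") refers to this. A small illustration of the truthiness involved:

    from optimum.intel.openvino.configuration import OVQuantizationMethod

    for method in (OVQuantizationMethod.AWQ, OVQuantizationMethod.DEFAULT):
        awq = (method == OVQuantizationMethod.AWQ) or None
        print(method.value, "->", awq)  # prints: awq -> True, then: default -> None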
22 changes: 18 additions & 4 deletions tests/openvino/test_quantization.py
@@ -235,6 +235,20 @@ class OVWeightCompressionTest(unittest.TestCase):
),
16,
),
(
OVModelForCausalLM,
"llama_awq",
dict(
bits=4,
sym=True,
group_size=16,
ratio=0.8,
sensitivity_metric="mean_activation_magnitude",
dataset="c4",
quant_method="awq",
),
16,
),
)

SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION = (
@@ -413,9 +427,9 @@ def test_ovmodel_hybrid_quantization_with_custom_dataset(
]
model = model_cls.from_pretrained(model_id, export=True)
quantizer = OVQuantizer(model)
quantization_config = OVWeightQuantizationConfig(
bits=8, num_samples=3, quant_method=OVQuantizationMethod.HYBRID
)
quantization_config = OVWeightQuantizationConfig(bits=8, num_samples=3, quant_method="hybrid")
self.assertEqual(quantization_config.quant_method, OVQuantizationMethod.HYBRID)

quantizer.quantize(ov_config=OVConfig(quantization_config=quantization_config), calibration_dataset=dataset)
num_fake_quantize, num_int8, num_int4 = get_num_quantized_nodes(model.unet)
self.assertEqual(expected_num_fake_quantize, num_fake_quantize)
@@ -454,7 +468,7 @@ def test_ovmodel_4bit_auto_compression_with_config(
with tempfile.TemporaryDirectory() as tmp_dir:
quantization_config = OVWeightQuantizationConfig.from_dict(quantization_config)
model = model_cls.from_pretrained(model_id, export=True, quantization_config=quantization_config)
if quantization_config.quant_method == QuantizationMethod.AWQ or quantization_config.scale_estimation:
if quantization_config.quant_method == OVQuantizationMethod.AWQ or quantization_config.scale_estimation:
# TODO: Check that AWQ and SE was actually applied
pass

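Note: end to end, the new test case corresponds to loading a model with 4-bit AWQ weight compression via the string form of `quant_method`. A sketch under the same settings ("llama_awq" in the test is an internal alias; the model id below is a placeholder for any causal LM exportable to OpenVINO):

    from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig

    qconfig = OVWeightQuantizationConfig(
        bits=4,
        sym=True,
        group_size=16,
        ratio=0.8,
        sensitivity_metric="mean_activation_magnitude",
        dataset="c4",
        quant_method="awq",  # coerced to OVQuantizationMethod.AWQ
    )
    model = OVModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # placeholder model id
        export=True,
        quantization_config=qconfig,
    )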
