Skip to content

Commit

Permalink
[OV] Introduce --quant-mode cli argument enabling full quantization…
Browse files Browse the repository at this point in the history
… via optimum-cli (#1061)

* Introduce --quant-mode cli argument

* Make int8 by default

* Add a test

* Add documentation

* Fix command

* Replace 'int8/int8' by 'int8'

* Add missing docstring

* Add trust_remote_code

* Fix condition

* Trigger Tests

* Trigger Tests
  • Loading branch information
nikita-savelyevv authored Dec 20, 2024
1 parent 0a09651 commit ea6fa42
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 11 deletions.
20 changes: 17 additions & 3 deletions docs/source/openvino/export.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ Check out the help for more options:

```text
usage: optimum-cli export openvino [-h] -m MODEL [--task TASK] [--framework {pt,tf}] [--trust-remote-code]
[--weight-format {fp32,fp16,int8,int4,mxfp4,nf4}]
[--weight-format {fp32,fp16,int8,int4,mxfp4,nf4}] [--quant-mode {int8}]
[--library {transformers,diffusers,timm,sentence_transformers,open_clip}]
[--cache_dir CACHE_DIR] [--pad-token-id PAD_TOKEN_ID] [--ratio RATIO] [--sym]
[--group-size GROUP_SIZE] [--backup-precision {none,int8_sym,int8_asym}]
[--dataset DATASET] [--all-layers] [--awq] [--scale-estimation] [--gptq]
[--lora-correction] [--sensitivity-metric SENSITIVITY_METRIC]
[--num-samples NUM_SAMPLES] [--disable-stateful] [--disable-convert-tokenizer]
[--smooth-quant-alpha SMOOTH_QUANT_ALPHA]
output

optional arguments:
Expand Down Expand Up @@ -66,6 +67,10 @@ Optional arguments:
on your local machine arbitrary code present in the model repository.
--weight-format {fp32,fp16,int8,int4,mxfp4,nf4}
The weight format of the exported model.
--quant-mode {int8}
Quantization precision mode. This is used for applying full model quantization including
activations. The only currently supported choice is 'int8' for int8 quantization of both
weights and activations.
--library {transformers,diffusers,timm,sentence_transformers,open_clip}
The library used to load the model before export. If not provided, will attempt to infer the
local checkpoint's library
Expand Down Expand Up @@ -102,8 +107,8 @@ Optional arguments:
weight compression is applied, they are compressed to INT8.
--awq Whether to apply AWQ algorithm. AWQ improves generation quality of INT4-compressed LLMs, but
requires additional time for tuning weights on a calibration dataset. To run AWQ, please also
provide a dataset argument. Note: it is possible that there will be no matching patterns in the
model to apply AWQ, in such case it will be skipped.
provide a dataset argument. Note: it is possible that there will be no matching patterns in
the model to apply AWQ, in such case it will be skipped.
--scale-estimation Indicates whether to apply a scale estimation algorithm that minimizes the L2 error between
the original and compressed layers. Providing a dataset is required to run scale estimation.
Please note, that applying scale estimation takes additional memory and time.
Expand All @@ -128,6 +133,9 @@ Optional arguments:
OpenVINO native inference code that expects KV-cache inputs and outputs in the model.
--disable-convert-tokenizer
Do not add converted tokenizer and detokenizer OpenVINO models.
--smooth-quant-alpha SMOOTH_QUANT_ALPHA
SmoothQuant alpha parameter that improves the distribution of activations before MatMul layers
and reduces quantization error. Valid only when activations quantization is enabled.
```

You can also apply fp16, 8-bit or 4-bit weight-only quantization on the Linear, Convolutional and Embedding layers when exporting your model by setting `--weight-format` to respectively `fp16`, `int8` or `int4`.
Expand Down Expand Up @@ -158,6 +166,12 @@ Models larger than 1 billion parameters are exported to the OpenVINO format with
</Tip>


Besides weight-only quantization, you can also apply full model quantization including activations by setting `--quant-mode` to `int8`. This will quantize both weights and activations of Linear, Convolutional and some other layers to int8. Currently this is only supported for speech-to-text models. Please see example below.

```bash
optimum-cli export openvino -m openai/whisper-large-v3-turbo --quant-mode int8 --dataset librispeech --num-samples 32 --smooth-quant-alpha 0.9 ./whisper-large-v3-turbo
```

### Decoder models

For models with a decoder, we enable the re-use of past keys and values by default. This allows to avoid recomputing the same intermediate activations at each generation step. To export the model without, you will need to remove the `-with-past` suffix when specifying the task.
Expand Down
69 changes: 64 additions & 5 deletions optimum/commands/export/openvino.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ def parse_args_openvino(parser: "ArgumentParser"):
default=None,
help="The weight format of the exported model.",
)
optional_group.add_argument(
"--quant-mode",
type=str,
choices=["int8"],
default=None,
help=(
"Quantization precision mode. This is used for applying full model quantization including activations. "
"The only currently supported choice is 'int8' for int8 quantization of both weights and activations."
),
)
optional_group.add_argument(
"--library",
type=str,
Expand Down Expand Up @@ -228,6 +238,15 @@ def parse_args_openvino(parser: "ArgumentParser"):
action="store_true",
help="Do not add converted tokenizer and detokenizer OpenVINO models.",
)
optional_group.add_argument(
"--smooth-quant-alpha",
type=float,
default=None,
help=(
"SmoothQuant alpha parameter that improves the distribution of activations before MatMul layers and "
"reduces quantization error. Valid only when activations quantization is enabled."
),
)


def no_compression_parameter_provided(args):
Expand All @@ -252,6 +271,20 @@ def no_compression_parameter_provided(args):
)


def no_quantization_parameter_provided(args):
return all(
(
it is None
for it in (
args.sym,
args.dataset,
args.num_samples,
args.smooth_quant_alpha,
)
)
)


class OVExportCommand(BaseOptimumCLICommand):
COMMAND = CommandInfo(name="openvino", help="Export PyTorch models to OpenVINO IR.")

Expand Down Expand Up @@ -291,16 +324,21 @@ def run(self):
else:
library_name = self.args.library

if self.args.weight_format is None:
if self.args.weight_format is None and self.args.quant_mode is None:
ov_config = None
if not no_compression_parameter_provided(self.args):
raise ValueError(
"Some compression parameters are provided, but the weight format is not specified. "
"Please provide it with --weight-format argument."
)
if not no_quantization_parameter_provided(self.args):
raise ValueError(
"Some quantization parameters are provided, but the quantization mode is not specified. "
"Please provide it with --quant-mode argument."
)
elif self.args.weight_format in {"fp16", "fp32"}:
ov_config = OVConfig(dtype=self.args.weight_format)
else:
elif self.args.weight_format is not None:
# For int4 quantization if no parameter is provided, then use the default config if exists
if no_compression_parameter_provided(self.args) and self.args.weight_format == "int4":
quantization_config = get_default_int4_config(self.args.model)
Expand All @@ -326,6 +364,21 @@ def run(self):
if quantization_config.get("dataset", None) is not None:
quantization_config["trust_remote_code"] = self.args.trust_remote_code
ov_config = OVConfig(quantization_config=quantization_config)
else:
if self.args.quant_mode != "int8":
raise ValueError("Only 'int8' quantization mode is currently supported.")

quantization_config = {
"weight_format": self.args.quant_mode,
"activation_format": self.args.quant_mode,
"bits": 8,
"sym": self.args.sym or False,
"dataset": self.args.dataset,
"num_samples": self.args.num_samples,
"smooth_quant_alpha": self.args.smooth_quant_alpha,
"trust_remote_code": self.args.trust_remote_code,
}
ov_config = OVConfig(quantization_config=quantization_config)

quantization_config = ov_config.quantization_config if ov_config else None
quantize_with_dataset = quantization_config and getattr(quantization_config, "dataset", None) is not None
Expand Down Expand Up @@ -368,17 +421,23 @@ def run(self):
model.save_pretrained(self.args.output)
if not self.args.disable_convert_tokenizer:
maybe_convert_tokenizers(library_name, self.args.output, model, task=task)
elif (task.startswith("text-generation") and quantize_with_dataset) or (
task == "image-text-to-text" and quantization_config is not None
elif (
quantize_with_dataset
and (task.startswith("text-generation") or task == "automatic-speech-recognition")
or (task == "image-text-to-text" and quantization_config is not None)
):
if task.startswith("text-generation"):
from optimum.intel import OVModelForCausalLM

model_cls = OVModelForCausalLM
else:
elif task == "image-text-to-text":
from optimum.intel import OVModelForVisualCausalLM

model_cls = OVModelForVisualCausalLM
else:
from optimum.intel import OVModelForSpeechSeq2Seq

model_cls = OVModelForSpeechSeq2Seq

# In this case, to apply quantization an instance of a model class is required
model = model_cls.from_pretrained(
Expand Down
32 changes: 30 additions & 2 deletions optimum/intel/openvino/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def __init__(
tokenizer: Optional[str] = None,
processor: Optional[str] = None,
trust_remote_code: bool = False,
weight_format: Optional[str] = None,
**kwargs,
):
"""
Expand All @@ -279,6 +280,18 @@ def __init__(
entries provided via this argument are used to create an instance of `nncf.IgnoredScope` class.
num_samples (`int`, *optional*):
The maximum number of samples composing the calibration dataset.
dataset (`str or List[str]`, *optional*):
The dataset used for data-aware optimization with NNCF.
tokenizer (`str`, *optional*):
The tokenizer used to process the dataset.
processor (`str`, *optional*):
A transformers processor used to process the dataset inputs.
trust_remote_code (`bool`, defaults to `False`):
Allows to use custom code for the modeling hosted in the model repository. This option should only be
set for repositories you trust and in which you have read the code, as it will execute on your local
machine arbitrary code present in the model repository.
weight_format (`str`, *optional*):
Data format weights are compressed to.
"""
self.bits = bits
self.sym = sym
Expand All @@ -287,6 +300,7 @@ def __init__(
self.tokenizer = tokenizer
self.processor = processor
self.trust_remote_code = trust_remote_code
self.weight_format = weight_format

if isinstance(ignored_scope, nncf.IgnoredScope):
ignored_scope = ignored_scope.__dict__
Expand Down Expand Up @@ -370,7 +384,7 @@ class OVWeightQuantizationConfig(OVQuantizationConfigBase):
scale_estimation (`bool`, *optional*):
Indicates whether to apply a scale estimation algorithm that minimizes the L2 error between the original and
compressed layers. Providing a dataset is required to run scale estimation.
weight_format (`str`, defaults to 'int'):
weight_format (`str`, *optional*):
Data format weights are compressed to. Possible values: ['int4', 'int8', 'mxfp4', 'nf4'].
qptq (`bool`, *optional*):
Whether to apply GPTQ algorithm. GPTQ optimizes compressed weights in a layer-wise fashion to minimize the
Expand Down Expand Up @@ -425,14 +439,14 @@ def __init__(
tokenizer=tokenizer,
processor=processor,
trust_remote_code=trust_remote_code,
weight_format=weight_format,
)
self.group_size = group_size or (-1 if bits == 8 else 128)
self.ratio = ratio
self.all_layers = all_layers
self.sensitivity_metric = sensitivity_metric
self.quant_method = OVQuantizationMethod(quant_method) if isinstance(quant_method, str) else quant_method
self.scale_estimation = scale_estimation
self.weight_format = weight_format
self.gptq = gptq
self.lora_correction = lora_correction
self.backup_precision = backup_precision
Expand Down Expand Up @@ -578,6 +592,8 @@ def __init__(
processor: Optional[str] = None,
trust_remote_code: bool = False,
smooth_quant_alpha: Optional[float] = None,
weight_format: Optional[str] = "int8",
activation_format: Optional[str] = "int8",
**kwargs,
):
"""
Expand Down Expand Up @@ -621,6 +637,10 @@ def __init__(
smooth_quant_alpha (`float`, *optional*):
SmoothQuant alpha parameter that improves the distribution of activations before MatMul layers and
reduces quantization error.
weight_format (`str`, defaults to "int8"):
Data format weights are quantized to. Possible values: ['int8'].
activation_format (`str`, defaults to "int8"):
Data format activations are compressed to. Possible values: ['int8'].
"""
super().__init__(
bits=bits,
Expand All @@ -631,11 +651,13 @@ def __init__(
tokenizer=tokenizer,
processor=processor,
trust_remote_code=trust_remote_code,
weight_format=weight_format,
)
self.model_type = model_type
self.fast_bias_correction = fast_bias_correction
self.overflow_fix = overflow_fix
self.smooth_quant_alpha = smooth_quant_alpha
self.activation_format = activation_format
self.post_init()

def post_init(self):
Expand All @@ -659,6 +681,12 @@ def post_init(self):
f"SmoothQuant alpha parameter must be in range [0, 1], but found {self.smooth_quant_alpha}"
)

if self.weight_format != "int8":
raise ValueError("Only 'int8' weight format is currently supported.")

if self.activation_format != "int8":
raise ValueError("Only 'int8' activation format is currently supported.")


class OVConfig(BaseConfig):
CONFIG_NAME = "openvino_config.json"
Expand Down
5 changes: 5 additions & 0 deletions optimum/intel/openvino/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,11 @@ def _quantize_ovbasemodel(
if calibration_dataset is None:
raise ValueError("Calibration dataset is required to run quantization.")

if quantization_config.weight_format != "int8":
raise ValueError("Only 'int8' weight format is currently supported.")
if quantization_config.activation_format != "int8":
raise ValueError("Only 'int8' activation format is currently supported.")

# Quantize model(s)
if isinstance(self.model, _OVModelForWhisper):
self._quantize_whisper_model(quantization_config, calibration_dataset, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions optimum/intel/openvino/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@
"open_clip_text": "OVModelOpenCLIPText",
"open_clip_vision": "OVModelOpenCLIPVisual",
"open_clip": "OVModelOpenCLIPForZeroShotImageClassification",
"automatic-speech-recognition": "OVModelForSpeechSeq2Seq",
}


Expand Down
39 changes: 38 additions & 1 deletion tests/openvino/test_exporters_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import subprocess
import unittest
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Tuple

from parameterized import parameterized
from transformers import AutoModelForCausalLM
Expand All @@ -37,6 +37,7 @@
OVModelForQuestionAnswering,
OVModelForSeq2SeqLM,
OVModelForSequenceClassification,
OVModelForSpeechSeq2Seq,
OVModelForTokenClassification,
OVModelForVisualCausalLM,
OVModelOpenCLIPForZeroShotImageClassification,
Expand Down Expand Up @@ -109,6 +110,16 @@ class OVCLIExportTestCase(unittest.TestCase):
SUPPORTED_SD_HYBRID_ARCHITECTURES.append(("stable-diffusion-3", 9, 65))
SUPPORTED_SD_HYBRID_ARCHITECTURES.append(("flux", 7, 56))

SUPPORTED_QUANTIZATION_ARCHITECTURES = [
(
"automatic-speech-recognition",
"whisper",
"--quant-mode int8 --dataset librispeech --num-samples 1 --smooth-quant-alpha 0.9 --trust-remote-code",
(14, 22, 21) if is_transformers_version("<=", "4.36.0") else (14, 22, 25),
(14, 21, 17) if is_transformers_version("<=", "4.36.0") else (14, 22, 18),
),
]

TEST_4BIT_CONFIGURATIONS = [
("text-generation-with-past", "opt125m", "int4 --sym --group-size 128", [{"int8": 4, "int4": 72}]),
("text-generation-with-past", "opt125m", "int4 --group-size 64", [{"int8": 4, "int4": 144}]),
Expand Down Expand Up @@ -391,6 +402,32 @@ def test_exporters_cli_4bit(
"--lora-correction" not in option or b"with correction of low-rank adapters" in result.stdout
)

@parameterized.expand(SUPPORTED_QUANTIZATION_ARCHITECTURES)
def test_exporters_cli_full_quantization(
self,
task: str,
model_type: str,
option: str,
expected_num_fq_nodes_per_model: Tuple[int],
expected_num_weight_nodes_per_model: Tuple[int],
):
with TemporaryDirectory() as tmpdir:
subprocess.run(
f"optimum-cli export openvino --model {MODEL_NAMES[model_type]} {option} {tmpdir}",
shell=True,
check=True,
)
model = eval(_HEAD_TO_AUTOMODELS[task]).from_pretrained(tmpdir)

submodels = []
if task == "automatic-speech-recognition":
submodels = [model.encoder, model.decoder, model.decoder_with_past]
self.assertEqual(len(expected_num_fq_nodes_per_model), len(submodels))
for i, model in enumerate(submodels):
actual_num_fq_nodes, actual_num_weight_nodes = get_num_quantized_nodes(model)
self.assertEqual(expected_num_fq_nodes_per_model[i], actual_num_fq_nodes)
self.assertEqual(expected_num_weight_nodes_per_model[i], actual_num_weight_nodes["int8"])

def test_exporters_cli_int4_with_local_model_and_default_config(self):
with TemporaryDirectory() as tmpdir:
pt_model = AutoModelForCausalLM.from_pretrained(MODEL_NAMES["falcon-40b"])
Expand Down

0 comments on commit ea6fa42

Please sign in to comment.