Add hybrid quantization for StableDiffusion pipelines #584

Merged · 11 commits · Mar 12, 2024
17 changes: 17 additions & 0 deletions docs/source/optimization_ov.mdx
@@ -69,6 +69,23 @@ from optimum.intel import OVModelForCausalLM
model = OVModelForCausalLM.from_pretrained(model_id, load_in_8bit=True)
```

## Hybrid quantization

Traditional optimization methods such as post-training 8-bit quantization do not work well for Stable Diffusion models and can lead to poor generation results. On the other hand, weight compression does not improve performance significantly when applied to Stable Diffusion models, as the size of the activations is comparable to that of the weights.
The UNet model takes up most of the overall execution time of the pipeline, so optimizing just this one model brings substantial benefits in inference speed while keeping acceptable accuracy without fine-tuning. Quantizing the rest of the diffusion pipeline does not significantly improve inference performance but could lead to a substantial degradation of accuracy.
Therefore, the proposal is to apply quantization in *hybrid mode* to the UNet model and weight-only quantization to the rest of the pipeline components. Hybrid mode quantizes the weights of MatMul and Embedding layers and the activations of other layers, which preserves accuracy after optimization while reducing the model size.
The `quantization_config` defines the optimization parameters for the Stable Diffusion pipeline. To enable hybrid quantization, specify the calibration dataset in the `quantization_config`. Otherwise, weight-only quantization to a specified data type (8 or 4 bits) is applied to the UNet model.

```python
from optimum.intel import OVStableDiffusionPipeline, OVWeightQuantizationConfig

model = OVStableDiffusionPipeline.from_pretrained(
model_id,
export=True,
quantization_config=OVWeightQuantizationConfig(bits=8, dataset="conceptual_captions"),
)
```
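
For comparison, here is a minimal sketch of the weight-only path (no `dataset`), reusing the same `model_id` placeholder as above; the inference call mirrors the Diffusers API. Prompt and variable names are illustrative, not part of this PR:

```python
from optimum.intel import OVStableDiffusionPipeline, OVWeightQuantizationConfig

# Without a dataset, only the weights are compressed, here to 8 bits (4 bits is also possible).
model = OVStableDiffusionPipeline.from_pretrained(
    model_id,
    export=True,
    quantization_config=OVWeightQuantizationConfig(bits=8),
)
image = model("sailing ship in storm by Leonardo da Vinci").images[0]  # illustrative prompt
```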

<Tip warning={true}>

`load_in_8bit` is enabled by default for models larger than 1 billion parameters.
37 changes: 23 additions & 14 deletions optimum/intel/openvino/configuration.py
@@ -167,7 +167,7 @@ class OVWeightQuantizationConfig(QuantizationConfigMixin):

bits (`int`, defaults to 8):
The number of bits to quantize to.
sym (`bool`, *optional*, defaults to `False`):
sym (`bool`, defaults to `False`):
Whether to use symmetric quantization.
tokenizer (`str` or `PreTrainedTokenizerBase`, *optional*):
The tokenizer used to process the dataset. You can pass either:
@@ -177,23 +177,24 @@ class OVWeightQuantizationConfig(QuantizationConfigMixin):
user or organization name, like `dbmdz/bert-base-german-cased`.
- A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
dataset (`Union[List[str]]`, *optional*):
The dataset used for data-aware compression. You can provide your own dataset in a list of strings or just use
the one from the list ['wikitext2','c4','c4-new','ptb','ptb-new']
group_size (`int`, *optional*, defaults to 128):
The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
ratio (`float`, *optional*, defaults to 1.0):
dataset (`str or List[str]`, *optional*):
The dataset used for data-aware compression or quantization with NNCF. You can provide your own dataset
in a list of strings or just use one from the list ['wikitext2','c4','c4-new','ptb','ptb-new'] for LLMs
or ['conceptual_captions','laion/220k-GPT4Vision-captions-from-LIVIS','laion/filtered-wit'] for diffusion models.
ratio (`float`, defaults to 1.0):
The ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to INT4_ASYM
and the rest to INT8_ASYM).
group_size (`int`, *optional*):
The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
all_layers (`bool`, *optional*):
Defines how many layers are compressed to 4-bits while the rest are kept in 8-bit precision.
sensitivity_metric (`nncf.SensitivityMetric`, *optional*):
sensitivity_metric (`str`, *optional*):
The sensitivity metric for assigning quantization precision to layers. In order to
preserve the accuracy of the model, the more sensitive layers receive a higher precision.
awq (`bool`, *optional*):
Enables AWQ method to unify weight ranges and improve overall model accuracy.
ignored_scope (`nncf.IgnoredScope`, *optional*):
ignored_scope (`dict`, *optional*):
An ignored scope that defines the list of model control flow graph nodes to be ignored during quantization.
num_samples (`int`, *optional*):
The maximum number of samples composing the calibration dataset.

"""

@@ -202,12 +203,13 @@ def __init__(
bits: int = 8,
sym: bool = False,
tokenizer: Optional[Any] = None,
dataset: Optional[str] = None,
dataset: Optional[Union[str, List[str]]] = None,
ratio: float = 1.0,
group_size: Optional[int] = None,
all_layers: Optional[bool] = None,
sensitivity_metric: Optional[str] = None,
ignored_scope: Optional[dict] = None,
num_samples: Optional[int] = None,
**kwargs,
):
self.bits = bits
@@ -219,6 +221,7 @@ def __init__(
self.all_layers = all_layers
self.sensitivity_metric = sensitivity_metric
self.ignored_scope = ignored_scope
self.num_samples = num_samples
self.quant_method = "default" # TODO : enable AWQ after nncf v2.9.0 release
self.post_init()

@@ -231,10 +234,16 @@ def post_init(self):
if self.group_size is not None and self.group_size != -1 and self.group_size <= 0:
raise ValueError("`group_size` must be greater than 0 or equal to -1")
if self.dataset is not None and isinstance(self.dataset, str):
if self.dataset not in ["wikitext2", "c4", "c4-new", "ptb", "ptb-new"]:
llm_datasets = ["wikitext2", "c4", "c4-new", "ptb", "ptb-new"]
stable_diffusion_datasets = [
"conceptual_captions",
"laion/220k-GPT4Vision-captions-from-LIVIS",
"laion/filtered-wit",
]
if self.dataset not in llm_datasets + stable_diffusion_datasets:
raise ValueError(
f"""You have entered a string value for dataset. You can only choose between
['wikitext2','c4','c4-new','ptb','ptb-new'], but we found {self.dataset}"""
{llm_datasets} for LLMs or {stable_diffusion_datasets} for diffusion models, but we found {self.dataset}"""
)

if self.bits not in [4, 8]:
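
To illustrate the options documented above, a hedged sketch of constructing the extended config with one of the new diffusion datasets and the `num_samples` cap (the values are illustrative, not defaults mandated by this PR):

```python
from optimum.intel import OVWeightQuantizationConfig

# 8-bit hybrid setup: a diffusion calibration dataset plus an explicit sample cap.
config = OVWeightQuantizationConfig(
    bits=8,
    dataset="laion/220k-GPT4Vision-captions-from-LIVIS",
    num_samples=200,  # diffusion pipelines fall back to 200 when unset
)
# Passing an unknown dataset name raises the ValueError added in post_init above.
```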
3 changes: 2 additions & 1 deletion optimum/intel/openvino/modeling_decoder.py
@@ -635,7 +635,8 @@ def _from_pretrained(
# from optimum.gptq.utils import get_seqlen

# seqlen = get_seqlen(causal_model)
dataset = get_dataset(quantization_config.dataset, tokenizer, seqlen=32)
nsamples = quantization_config.num_samples if quantization_config.num_samples else 128
dataset = get_dataset(quantization_config.dataset, tokenizer, seqlen=32, nsamples=nsamples)
dataset = prepare_dataset(dataset)
quantization_config = copy.deepcopy(quantization_config)
quantization_config.dataset = nncf.Dataset(dataset, lambda x: causal_model.prepare_inputs(**x))
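
The change above threads `num_samples` into the GPTQ-style calibration set for causal LMs. A hedged usage sketch under the assumption that a tokenizer is supplied through the config (model and tokenizer ids are illustrative):

```python
from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig

# `num_samples` now limits how many calibration sequences get_dataset draws (128 if unset).
config = OVWeightQuantizationConfig(
    bits=4,
    dataset="wikitext2",
    tokenizer="gpt2",  # illustrative; data-aware compression needs a tokenizer to build the dataset
    num_samples=64,
)
model = OVModelForCausalLM.from_pretrained("gpt2", export=True, quantization_config=config)
```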
101 changes: 96 additions & 5 deletions optimum/intel/openvino/modeling_diffusion.py
@@ -16,6 +16,7 @@
import logging
import os
import shutil
from copy import deepcopy
from pathlib import Path
from tempfile import TemporaryDirectory, gettempdir
from typing import Any, Dict, List, Optional, Union
@@ -57,7 +58,13 @@
from .configuration import OVConfig, OVWeightQuantizationConfig
from .loaders import OVTextualInversionLoaderMixin
from .modeling_base import OVBaseModel
from .utils import ONNX_WEIGHTS_NAME, OV_TO_NP_TYPE, OV_XML_FILE_NAME, _print_compiled_model_properties
from .utils import (
ONNX_WEIGHTS_NAME,
OV_TO_NP_TYPE,
OV_XML_FILE_NAME,
PREDEFINED_SD_DATASETS,
_print_compiled_model_properties,
)


core = Core()
@@ -274,9 +281,19 @@ def _from_pretrained(
kwargs[name] = load_method(new_model_save_dir)

quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
unet = cls.load_model(
new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name, quantization_config
)

unet_path = new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name
if quantization_config is not None and quantization_config.dataset is not None:
# Load the UNet model uncompressed so that hybrid quantization can be applied later
unet = cls.load_model(unet_path)
# Apply weights compression to other `components` without dataset
weight_quantization_params = {
param: value for param, value in quantization_config.__dict__.items() if param != "dataset"
}
weight_quantization_config = OVWeightQuantizationConfig.from_dict(weight_quantization_params)
else:
weight_quantization_config = quantization_config
unet = cls.load_model(unet_path, weight_quantization_config)

components = {
"vae_encoder": new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name,
@@ -286,11 +303,29 @@
}

for key, value in components.items():
components[key] = cls.load_model(value, quantization_config) if value.is_file() else None
components[key] = cls.load_model(value, weight_quantization_config) if value.is_file() else None

if model_save_dir is None:
model_save_dir = new_model_save_dir

if quantization_config is not None and quantization_config.dataset is not None:
sd_model = cls(unet=unet, config=config, model_save_dir=model_save_dir, **components, **kwargs)

supported_pipelines = (
OVStableDiffusionPipeline,
OVStableDiffusionXLPipeline,
OVLatentConsistencyModelPipeline,
)
if not isinstance(sd_model, supported_pipelines):
raise NotImplementedError(f"Quantization in hybrid mode is not supported for {cls.__name__}")

nsamples = quantization_config.num_samples if quantization_config.num_samples else 200
unet_inputs = sd_model._prepare_unet_inputs(quantization_config.dataset, nsamples)

from .quantization import _hybrid_quantization

unet = _hybrid_quantization(sd_model.unet.model, weight_quantization_config, dataset=unet_inputs)

return cls(
unet=unet,
config=config,
@@ -300,6 +335,62 @@
**kwargs,
)

def _prepare_unet_inputs(
self,
dataset: Union[str, List[Any]],
num_samples: int,
height: Optional[int] = None,
width: Optional[int] = None,
seed: Optional[int] = 42,
**kwargs,
) -> Dict[str, Any]:
self.compile()

size = self.unet.config.get("sample_size", 64) * self.vae_scale_factor
height = height or min(size, 512)
width = width or min(size, 512)

if isinstance(dataset, str):
dataset = deepcopy(dataset)
available_datasets = PREDEFINED_SD_DATASETS.keys()
if dataset not in available_datasets:
raise ValueError(
f"""You have entered a string value for dataset. You can only choose between
{list(available_datasets)}, but {dataset} was found"""
)

from datasets import load_dataset

dataset_metadata = PREDEFINED_SD_DATASETS[dataset]
dataset = load_dataset(dataset, split=dataset_metadata["split"], streaming=True).shuffle(seed=seed)
input_names = dataset_metadata["inputs"]
dataset = dataset.select_columns(list(input_names.values()))

def transform_fn(data_item):
return {inp_name: data_item[column] for inp_name, column in input_names.items()}

else:

def transform_fn(data_item):
return data_item if isinstance(data_item, (list, dict)) else [data_item]

from .quantization import InferRequestWrapper

calibration_data = []
self.unet.request = InferRequestWrapper(self.unet.request, calibration_data)

for inputs in dataset:
inputs = transform_fn(inputs)
if isinstance(inputs, dict):
self.__call__(**inputs, height=height, width=width)
else:
self.__call__(*inputs, height=height, width=width)
if len(calibration_data) > num_samples:
break

self.unet.request = self.unet.request.request
return calibration_data[:num_samples]

@classmethod
def _from_transformers(
cls,
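
`_prepare_unet_inputs` also accepts an already-built list, so calibration prompts do not have to come from a predefined dataset. A hedged sketch under that assumption (the prompts and `model_id` are illustrative placeholders):

```python
from optimum.intel import OVStableDiffusionPipeline, OVWeightQuantizationConfig

# A list dataset bypasses the `datasets` download: each prompt is run through the
# pipeline while the wrapped UNet infer request records calibration inputs.
prompts = [
    "a photo of an astronaut riding a horse on mars",
    "a watercolor painting of a fox in a forest",
]
pipe = OVStableDiffusionPipeline.from_pretrained(
    model_id,  # illustrative placeholder
    export=True,
    quantization_config=OVWeightQuantizationConfig(bits=8, dataset=prompts, num_samples=2),
)
```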
95 changes: 93 additions & 2 deletions optimum/intel/openvino/quantization.py
@@ -16,6 +16,7 @@
import inspect
import logging
import os
from collections import deque
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Union

@@ -24,6 +25,7 @@
import torch
import transformers
from nncf import CompressWeightsMode, IgnoredScope, NNCFConfig, SensitivityMetric
from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
from nncf.torch import create_compressed_model, register_default_init_args, register_module
from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_inputs_with_objwalk
from nncf.torch.initialization import PTInitializingDataLoader
@@ -550,7 +552,7 @@ def _remove_unused_columns(self, dataset: "Dataset"):

def _weight_only_quantization(
model: openvino.runtime.Model, quantization_config: Union[OVWeightQuantizationConfig, Dict]
):
) -> openvino.runtime.Model:
config = quantization_config
if isinstance(config, dict):
config = OVWeightQuantizationConfig.from_dict(quantization_config)
@@ -564,7 +566,8 @@ def _weight_only_quantization(

from optimum.gptq.data import get_dataset, prepare_dataset

dataset = get_dataset(config.dataset, tokenizer, seqlen=32)
nsamples = config.num_samples if config.num_samples else 128
dataset = get_dataset(config.dataset, tokenizer, seqlen=32, nsamples=nsamples)
dataset = prepare_dataset(dataset)

sensitivity_metric = None
@@ -590,4 +593,92 @@
# awq=config.quant_method == "awq", # TODO : remove and add it back once nncf v2.9.0
ignored_scope=ignored_scope,
dataset=dataset,
# subset_size=config.num_samples if config.num_samples else 128, # TODO : enable from nncf v2.9.0
)


def _get_operation_const_op(operation, const_port_id: int):
node = operation.input_value(const_port_id).get_node()
queue = deque([node])
constant_node = None
allowed_propagation_types_list = ["Convert", "FakeQuantize", "Reshape"]

while len(queue) != 0:
curr_node = queue.popleft()
if curr_node.get_type_name() == "Constant":
constant_node = curr_node
break
if len(curr_node.inputs()) == 0:
break
if curr_node.get_type_name() in allowed_propagation_types_list:
queue.append(curr_node.input_value(0).get_node())

return constant_node


def _is_embedding(node) -> bool:
allowed_types_list = ["f16", "f32", "f64"]
const_port_id = 0
input_tensor = node.input_value(const_port_id)
if input_tensor.get_element_type().get_type_name() in allowed_types_list:
const_node = _get_operation_const_op(node, const_port_id)
if const_node is not None:
return True

return False


def _collect_ops_with_weights(model):
ops_with_weights = []
for op in model.get_ops():
if op.get_type_name() == "MatMul":
constant_node_0 = _get_operation_const_op(op, const_port_id=0)
constant_node_1 = _get_operation_const_op(op, const_port_id=1)
if constant_node_0 or constant_node_1:
ops_with_weights.append(op.get_friendly_name())
if op.get_type_name() == "Gather" and _is_embedding(op):
ops_with_weights.append(op.get_friendly_name())

return ops_with_weights


def _hybrid_quantization(
model: openvino.runtime.Model, quantization_config: OVWeightQuantizationConfig, dataset: Dict[str, Any]
) -> openvino.runtime.Model:
"""
Quantize a model in hybrid mode with NNCF, which means that we quantize:
weights of MatMul and Embedding layers and activations of other layers.
The optimization specifications are defined in `quantization_config`.

Args:
model (`openvino.runtime.Model`):
The OpenVINO Runtime model for applying hybrid quantization.
quantization_config (`OVWeightQuantizationConfig`):
The configuration containing the parameters related to quantization.
dataset (`Dict[str, Any]`):
The dataset used for hybrid quantization.
Returns:
The OpenVINO Runtime model with applied hybrid quantization.
"""
ops_to_compress = _collect_ops_with_weights(model)

ignored_scope = quantization_config.ignored_scope if isinstance(quantization_config.ignored_scope, dict) else {}
ptq_ignored_scope = nncf.IgnoredScope(**ignored_scope)
ptq_ignored_scope.names += ops_to_compress

wc_quantization_config = copy.deepcopy(quantization_config)
wc_quantization_config.ignored_scope = ignored_scope
wc_quantization_config.ignored_scope["types"] = ignored_scope.get("types", []) + ["Convolution"]
compressed_model = _weight_only_quantization(model, wc_quantization_config)

subset_size = quantization_config.num_samples if quantization_config.num_samples else 200
quantized_model = nncf.quantize(
model=compressed_model,
calibration_dataset=nncf.Dataset(dataset),
model_type=nncf.ModelType.TRANSFORMER,
ignored_scope=ptq_ignored_scope,
# The SQ algo should be disabled for MatMul nodes because their weights are already compressed
advanced_parameters=nncf.AdvancedQuantizationParameters(AdvancedSmoothQuantParameters(matmul=-1)),
subset_size=subset_size,
)
return quantized_model
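
Since `_hybrid_quantization` reads `ignored_scope` as a plain dict, excluding specific nodes from both the weight-compression and activation-quantization passes could look like the hedged sketch below (the node name and `model_id` are purely illustrative):

```python
from optimum.intel import OVStableDiffusionPipeline, OVWeightQuantizationConfig

# Entries in `ignored_scope` are forwarded to the weight-compression pass and, merged with
# the collected MatMul/Gather names, to nncf.quantize for the activation pass.
config = OVWeightQuantizationConfig(
    bits=8,
    dataset="conceptual_captions",
    ignored_scope={"names": ["__module.conv_out/aten::_convolution/Convolution"]},  # illustrative node name
)
pipe = OVStableDiffusionPipeline.from_pretrained(model_id, export=True, quantization_config=config)  # model_id is a placeholder
```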