diff --git a/.github/workflows/test_openvino.yml b/.github/workflows/test_openvino.yml
index bf9460c75a..6d709eecfd 100644
--- a/.github/workflows/test_openvino.yml
+++ b/.github/workflows/test_openvino.yml
@@ -32,7 +32,7 @@ jobs:
python -m pip install --upgrade pip
# install PyTorch CPU version to avoid installing CUDA packages on GitHub runner without GPU
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
- pip install .[openvino,openvino-tokenizers,nncf,tests,diffusers]
+ pip install .[openvino,openvino-tokenizers,tests,diffusers] onnxruntime
- name: Test with Pytest
run: |
pytest tests/openvino/ --ignore test_modeling_basic
diff --git a/README.md b/README.md
index 7b762cce26..7905cefded 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@
Intel [Neural Compressor](https://www.intel.com/content/www/us/en/developer/tools/oneapi/neural-compressor.html) is an open-source library enabling the usage of the most popular compression techniques such as quantization, pruning and knowledge distillation. It supports automatic accuracy-driven tuning strategies so that users can easily generate quantized models. Users can easily apply static, dynamic and aware-training quantization approaches while specifying an expected accuracy criterion. It also supports different weight pruning techniques enabling the creation of pruned models given a predefined sparsity target.
-[OpenVINO](https://docs.openvino.ai/latest/index.html) is an open-source toolkit that enables high performance inference capabilities for Intel CPUs, GPUs, and special DL inference accelerators ([see](https://docs.openvino.ai/latest/openvino_docs_OV_UG_supported_plugins_Supported_Devices.html) the full list of supported devices). It is supplied with a set of tools to optimize your models with compression techniques such as quantization, pruning and knowledge distillation. Optimum Intel provides a simple interface to optimize your Transformers and Diffusers models, convert them to the OpenVINO Intermediate Representation (IR) format and run inference using OpenVINO Runtime.
+[OpenVINO](https://docs.openvino.ai) is an open-source toolkit that enables high-performance inference on Intel CPUs, GPUs, and special DL inference accelerators ([see](https://docs.openvino.ai/2024/about-openvino/compatibility-and-support/supported-devices.html) the full list of supported devices). It is supplied with a set of tools to optimize your models with compression techniques such as quantization, pruning and knowledge distillation. Optimum Intel provides a simple interface to optimize your Transformers and Diffusers models, convert them to the OpenVINO Intermediate Representation (IR) format and run inference using OpenVINO Runtime.
## Installation
@@ -20,7 +20,7 @@ To install the latest release of 🤗 Optimum Intel with the corresponding requi
| Accelerator | Installation |
|:-----------------------------------------------------------------------------------------------------------------|:---------------------------------------------------------------------|
| [Intel Neural Compressor](https://www.intel.com/content/www/us/en/developer/tools/oneapi/neural-compressor.html) | `pip install --upgrade-strategy eager "optimum[neural-compressor]"` |
-| [OpenVINO](https://docs.openvino.ai/latest/index.html) | `pip install --upgrade-strategy eager "optimum[openvino,nncf]"` |
+| [OpenVINO](https://docs.openvino.ai) | `pip install --upgrade-strategy eager "optimum[openvino]"` |
| [Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/#introduction) | `pip install --upgrade-strategy eager "optimum[ipex]"` |
The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version.
@@ -68,11 +68,11 @@ For more details on the supported compression techniques, please refer to the [d
## OpenVINO
-Below are the examples of how to use OpenVINO and its [NNCF](https://docs.openvino.ai/latest/tmo_introduction.html) framework to accelerate inference.
+Below are examples of how to use OpenVINO and its [NNCF](https://docs.openvino.ai/2024/openvino-workflow/model-optimization-guide/compressing-models-during-training.html) framework to accelerate inference.
#### Export:
-It is possible to export your model to the [OpenVINO](https://docs.openvino.ai/2023.1/openvino_ir.html) IR format with the CLI :
+It is possible to export your model to the [OpenVINO IR](https://docs.openvino.ai/2024/documentation/openvino-ir-format.html) format with the CLI:
```plain
optimum-cli export openvino --model gpt2 ov_model
diff --git a/docs/source/index.mdx b/docs/source/index.mdx
index cbec79baa9..643b9be044 100644
--- a/docs/source/index.mdx
+++ b/docs/source/index.mdx
@@ -21,7 +21,7 @@ limitations under the License.
[Intel Neural Compressor](https://www.intel.com/content/www/us/en/developer/tools/oneapi/neural-compressor.html) is an open-source library enabling the usage of the most popular compression techniques such as quantization, pruning and knowledge distillation. It supports automatic accuracy-driven tuning strategies in order for users to easily generate quantized model. The users can easily apply static, dynamic and aware-training quantization approaches while giving an expected accuracy criteria. It also supports different weight pruning techniques enabling the creation of pruned model giving a predefined sparsity target.
-[OpenVINO](https://docs.openvino.ai/latest/index.html) is an open-source toolkit that enables high performance inference capabilities for Intel CPUs, GPUs, and special DL inference accelerators ([see](https://docs.openvino.ai/latest/openvino_docs_OV_UG_supported_plugins_Supported_Devices.html) the full list of supported devices). It is supplied with a set of tools to optimize your models with compression techniques such as quantization, pruning and knowledge distillation. Optimum Intel provides a simple interface to optimize your Transformers and Diffusers models, convert them to the OpenVINO Intermediate Representation (IR) format and run inference using OpenVINO Runtime.
+[OpenVINO](https://docs.openvino.ai) is an open-source toolkit that enables high-performance inference on Intel CPUs, GPUs, and special DL inference accelerators ([see](https://docs.openvino.ai/2024/about-openvino/compatibility-and-support/supported-devices.html) the full list of supported devices). It is supplied with a set of tools to optimize your models with compression techniques such as quantization, pruning and knowledge distillation. Optimum Intel provides a simple interface to optimize your Transformers and Diffusers models, convert them to the OpenVINO Intermediate Representation (IR) format and run inference using OpenVINO Runtime.
@@ -34,4 +34,4 @@ limitations under the License.
Learn how to run inference with OpenVINO Runtime and to apply quantization, pruning and knowledge distillation on your model to further speed up inference.
-
\ No newline at end of file
+
diff --git a/docs/source/inference.mdx b/docs/source/inference.mdx
index a9ee5529da..65480c1d2f 100644
--- a/docs/source/inference.mdx
+++ b/docs/source/inference.mdx
@@ -13,7 +13,8 @@ Optimum Intel can be used to load optimized models from the [Hugging Face Hub](h
## Transformers models
-You can now easily perform inference with OpenVINO Runtime on a variety of Intel processors ([see](https://docs.openvino.ai/latest/openvino_docs_OV_UG_supported_plugins_Supported_Devices.html) the full list of supported devices).
+You can now easily perform inference with OpenVINO Runtime on a variety of Intel processors
+([see](https://docs.openvino.ai/2024/about-openvino/compatibility-and-support/supported-devices.html) the full list of supported devices).
For that, just replace the `AutoModelForXxx` class with the corresponding `OVModelForXxx` class.
As shown in the table below, each task is associated with a class enabling your model to be loaded automatically.
@@ -33,7 +34,7 @@ As shown in the table below, each task is associated with a class enabling to au
### Export
-It is possible to export your model to the [OpenVINO](https://docs.openvino.ai/2023.1/openvino_ir.html) IR format with the CLI :
+It is possible to export your model to the [OpenVINO IR](https://docs.openvino.ai/2024/documentation/openvino-ir-format.html) format with the CLI:
```bash
optimum-cli export openvino --model gpt2 ov_model
@@ -110,7 +111,7 @@ By default the quantization scheme will be [assymmetric](https://github.com/open
For INT4 quantization you can also specify the following arguments :
* The `--group-size` parameter will define the group size to use for quantization; `-1` will result in per-column quantization.
-* The `--ratio` CLI parameter controls the ratio between 4-bit and 8-bit quantization. If set to 0.9, it means that 90% of the layers will be quantized to `int4` while 10% will be quantized to `int8`.
+* The `--ratio` parameter controls the ratio between 4-bit and 8-bit quantization. If set to 0.9, it means that 90% of the layers will be quantized to `int4` while 10% will be quantized to `int8`.
Smaller `group_size` and `ratio` values usually improve accuracy at the expense of model size and inference latency.
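+For instance, a 4-bit export that tunes both parameters might look like the following sketch (the model name and parameter values are only illustrative):
+
+```bash
+optimum-cli export openvino --model gpt2 --weight-format int4 --ratio 0.9 --group-size 128 ov_model
+```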
@@ -122,8 +123,11 @@ from optimum.intel import OVModelForCausalLM
model = OVModelForCausalLM.from_pretrained(model_id, load_in_8bit=True)
```
-> **NOTE:** `load_in_8bit` is enabled by default for the models larger than 1 billion parameters.
+
+`load_in_8bit` is enabled by default for models larger than 1 billion parameters.
+
+
To apply quantization on both weights and activations, you can use the `OVQuantizer`; more information can be found in the [documentation](https://huggingface.co/docs/optimum/main/en/intel/optimization_ov#optimization).
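+As a minimal sketch (the model, dataset and parameter choices are only illustrative), static quantization with `OVQuantizer` could look like:
+
+```python
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from optimum.intel import OVQuantizer
+
+model_id = "distilbert-base-uncased-finetuned-sst-2-english"
+model = AutoModelForSequenceClassification.from_pretrained(model_id)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+quantizer = OVQuantizer.from_pretrained(model)
+# Build a small calibration dataset used to estimate the activation quantization parameters
+calibration_dataset = quantizer.get_calibration_dataset(
+    "glue",
+    dataset_config_name="sst2",
+    preprocess_function=lambda examples: tokenizer(examples["sentence"], padding="max_length", max_length=128, truncation=True),
+    num_samples=300,
+    dataset_split="train",
+)
+# Quantize both weights and activations and export the result to the OpenVINO IR format
+quantizer.quantize(calibration_dataset=calibration_dataset, save_directory="ov_quantized_model")
+```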
@@ -179,7 +183,7 @@ model.reshape(1,128)
model.compile()
```
-To run inference on Intel integrated or discrete GPU, use `.to("gpu")`. On GPU, models run in FP16 precision by default. (See [OpenVINO documentation](https://docs.openvino.ai/nightly/openvino_docs_install_guides_configurations_for_intel_gpu.html) about installing drivers for GPU inference).
+To run inference on an Intel integrated or discrete GPU, use `.to("gpu")`. On GPU, models run in FP16 precision by default (see the [OpenVINO documentation](https://docs.openvino.ai/2024/get-started/configurations/configurations-intel-gpu.html) for installing GPU drivers).
```python
# Static shapes speed up inference
@@ -468,7 +472,7 @@ image = refiner(prompt=prompt, image=image[None, :]).images[0]
```
-## Latent Consistency Models
+### Latent Consistency Models
| Task | Auto Class |
@@ -476,7 +480,7 @@ image = refiner(prompt=prompt, image=image[None, :]).images[0]
| `text-to-image` | `OVLatentConsistencyModelPipeline` |
-### Text-to-Image
+#### Text-to-Image
Here is an example of how you can load a Latent Consistency Model (LCM) from [SimianLuo/LCM_Dreamshaper_v7](https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7) and run inference using OpenVINO:
diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx
index cf8688d105..c29f5ceb95 100644
--- a/docs/source/installation.mdx
+++ b/docs/source/installation.mdx
@@ -21,7 +21,7 @@ To install the latest release of 🤗 Optimum Intel with the corresponding requi
| Accelerator | Installation |
|:-----------------------------------------------------------------------------------------------------------------------|:-------------------------------------------------------------------|
| [Intel Neural Compressor (INC)](https://www.intel.com/content/www/us/en/developer/tools/oneapi/neural-compressor.html) | `pip install --upgrade-strategy eager "optimum[neural-compressor]"`|
-| [Intel OpenVINO](https://docs.openvino.ai/latest/index.html) | `pip install --upgrade-strategy eager "optimum[openvino,nncf]"` |
+| [Intel OpenVINO](https://docs.openvino.ai) | `pip install --upgrade-strategy eager "optimum[openvino]"` |
The `--upgrade-strategy eager` option is needed to ensure `optimum-intel` is upgraded to the latest version.
@@ -42,4 +42,4 @@ or to install from source including dependencies:
python -m pip install "optimum-intel[extras]"@git+https://github.com/huggingface/optimum-intel.git
```
-where `extras` can be one or more of `neural-compressor`, `openvino`, `nncf`.
\ No newline at end of file
+where `extras` can be one or more of `neural-compressor`, `openvino`, `nncf`.
diff --git a/docs/source/optimization_ov.mdx b/docs/source/optimization_ov.mdx
index 77dab40159..70c98f14f7 100644
--- a/docs/source/optimization_ov.mdx
+++ b/docs/source/optimization_ov.mdx
@@ -38,8 +38,6 @@ save_dir = "ptq_model"
def preprocess_function(examples, tokenizer):
return tokenizer(examples["sentence"], padding="max_length", max_length=128, truncation=True)
-# Load the default quantization configuration detailing the quantization we wish to apply
-quantization_config = OVConfig()
# Instantiate our OVQuantizer using the desired configuration
quantizer = OVQuantizer.from_pretrained(model)
# Create the calibration dataset used to perform static quantization
@@ -52,7 +50,6 @@ calibration_dataset = quantizer.get_calibration_dataset(
)
# Apply static quantization and export the resulting quantized model to OpenVINO IR format
quantizer.quantize(
- quantization_config=quantization_config,
calibration_dataset=calibration_dataset,
save_directory=save_dir,
)
@@ -72,7 +69,28 @@ from optimum.intel import OVModelForCausalLM
model = OVModelForCausalLM.from_pretrained(model_id, load_in_8bit=True)
```
-> **NOTE:** `load_in_8bit` is enabled by default for models larger than 1 billion parameters.
+## Hybrid quantization
+
+Traditional optimization methods like post-training 8-bit quantization do not work well for Stable Diffusion models and can lead to poor generation results. On the other hand, weight compression does not improve performance significantly when applied to Stable Diffusion models, as the size of the activations is comparable to that of the weights.
+The UNet model takes up most of the overall execution time of the pipeline. Thus, optimizing just this one model brings substantial benefits in terms of inference speed while keeping acceptable accuracy without fine-tuning. Quantizing the rest of the diffusion pipeline does not significantly improve inference performance but could potentially lead to substantial accuracy degradation.
+Therefore, the approach is to apply quantization in *hybrid mode* to the UNet model and weight-only quantization to the rest of the pipeline components. Hybrid mode quantizes the weights of the MatMul and Embedding layers as well as the activations of the other layers, preserving accuracy after optimization while reducing the model size.
+The `quantization_config` defines the optimization parameters for the Stable Diffusion pipeline. To enable hybrid quantization, specify a quantization dataset in the `quantization_config`. Otherwise, weight-only quantization to the specified data type (8 or 4 bits) is applied to the UNet model.
+
+```python
+from optimum.intel import OVStableDiffusionPipeline, OVWeightQuantizationConfig
+
+model = OVStableDiffusionPipeline.from_pretrained(
+ model_id,
+ export=True,
+ quantization_config=OVWeightQuantizationConfig(bits=8, dataset="conceptual_captions"),
+)
+```
+
+
+
+`load_in_8bit` is enabled by default for models larger than 1 billion parameters.
+
+
For the 4-bit weight quantization you can use the `quantization_config` to specify the optimization parameters, for example:
@@ -81,7 +99,17 @@ from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig
model = OVModelForCausalLM.from_pretrained(
model_id,
- export=True,
+ quantization_config=OVWeightQuantizationConfig(bits=4),
+)
+```
+
+You can tune the quantization parameters to achieve a better performance-accuracy trade-off as follows:
+
+```python
+from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig
+
+model = OVModelForCausalLM.from_pretrained(
+ model_id,
quantization_config=OVWeightQuantizationConfig(bits=4, sym=False, ratio=0.8, dataset="ptb"),
)
```
diff --git a/examples/openvino/image-classification/run_image_classification.py b/examples/openvino/image-classification/run_image_classification.py
index 8a7c009e46..5f98d95cb5 100644
--- a/examples/openvino/image-classification/run_image_classification.py
+++ b/examples/openvino/image-classification/run_image_classification.py
@@ -151,12 +151,12 @@ class ModelArguments:
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
)
feature_extractor_name: str = field(default=None, metadata={"help": "Name or path of preprocessor config."})
- use_auth_token: bool = field(
- default=False,
+ token: str = field(
+ default=None,
metadata={
"help": (
- "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
- "with private models)."
+ "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
+ "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
)
},
)
@@ -239,8 +239,7 @@ def main():
data_args.dataset_name,
data_args.dataset_config_name,
cache_dir=model_args.cache_dir,
- task="image-classification",
- use_auth_token=True if model_args.use_auth_token else None,
+ token=model_args.token,
)
else:
data_files = {}
@@ -252,7 +251,6 @@ def main():
"imagefolder",
data_files=data_files,
cache_dir=model_args.cache_dir,
- task="image-classification",
)
# If we don't have a validation split, split off a percentage of train as validation.
@@ -287,7 +285,7 @@ def compute_metrics(p):
finetuning_task="image-classification",
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
- use_auth_token=True if model_args.use_auth_token else None,
+ token=model_args.token,
)
model = AutoModelForImageClassification.from_pretrained(
model_args.model_name_or_path,
@@ -295,7 +293,7 @@ def compute_metrics(p):
config=config,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
- use_auth_token=True if model_args.use_auth_token else None,
+ token=model_args.token,
ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
)
@@ -311,7 +309,7 @@ def compute_metrics(p):
model_args.feature_extractor_name or model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
revision=model_args.model_revision,
- use_auth_token=True if model_args.use_auth_token else None,
+ token=model_args.token,
)
# Define torchvision transforms to be applied to each image.
diff --git a/optimum/commands/export/openvino.py b/optimum/commands/export/openvino.py
index 010c47218a..5617ca55f3 100644
--- a/optimum/commands/export/openvino.py
+++ b/optimum/commands/export/openvino.py
@@ -162,13 +162,12 @@ def run(self):
)
self.args.weight_format = "int8"
- weight_format = self.args.weight_format or "fp32"
-
- ov_config = None
- if weight_format in {"fp16", "fp32"}:
- ov_config = OVConfig(dtype=weight_format)
+ if self.args.weight_format is None:
+ ov_config = None
+ elif self.args.weight_format in {"fp16", "fp32"}:
+ ov_config = OVConfig(dtype=self.args.weight_format)
else:
- is_int8 = weight_format == "int8"
+ is_int8 = self.args.weight_format == "int8"
# For int4 quantization, if no parameter is provided, then use the default config if it exists
if (
@@ -187,12 +186,12 @@ def run(self):
"group_size": -1 if is_int8 else self.args.group_size,
}
- if weight_format in {"int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"}:
+ if self.args.weight_format in {"int4_sym_g128", "int4_asym_g128", "int4_sym_g64", "int4_asym_g64"}:
logger.warning(
- f"--weight-format {weight_format} is deprecated, possible choices are fp32, fp16, int8, int4"
+ f"--weight-format {self.args.weight_format} is deprecated, possible choices are fp32, fp16, int8, int4"
)
- quantization_config["sym"] = "asym" not in weight_format
- quantization_config["group_size"] = 128 if "128" in weight_format else 64
+ quantization_config["sym"] = "asym" not in self.args.weight_format
+ quantization_config["group_size"] = 128 if "128" in self.args.weight_format else 64
ov_config = OVConfig(quantization_config=quantization_config)
if self.args.convert_tokenizer:
diff --git a/optimum/exporters/ipex/__init__.py b/optimum/exporters/ipex/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/optimum/exporters/ipex/model_patcher.py b/optimum/exporters/ipex/model_patcher.py
new file mode 100644
index 0000000000..60ff3b721b
--- /dev/null
+++ b/optimum/exporters/ipex/model_patcher.py
@@ -0,0 +1,91 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers.models.llama.modeling_llama import (
+ LlamaAttention,
+ LlamaDecoderLayer,
+ LlamaForCausalLM,
+ LlamaModel,
+ LlamaRMSNorm,
+)
+
+from optimum.intel.utils.import_utils import is_ipex_version
+
+from .modeling_utils import (
+ _IPEXLlamaDecoderLayerRef,
+ _llama_attn_forward,
+ _llama_layer_norm_forward,
+ _llama_model_forward,
+)
+
+
+_IPEX_EXPORTED_ARCH = ("LlamaForCausalLM",)
+_IPEX_EXPORTED_TASK = ("text-generation",)
+
+
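+# Rebind `new_function` to the module instance `m` as a method named `func_name`.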
+def convert_func(m, func_name, new_function):
+ bound_method = new_function.__get__(m, m.__class__)
+ setattr(m, func_name, bound_method)
+
+
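+# Recursively replace the method `new_function_name` with `new_function` on every submodule of type `target_m`.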
+def convert_functions(m, target_m, new_function_name, new_function):
+ for _, sub_m in m.named_children():
+ if isinstance(sub_m, target_m):
+ convert_func(sub_m, new_function_name, new_function)
+ convert_functions(sub_m, target_m, new_function_name, new_function)
+
+
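+# Recursively wrap every submodule of type `target_m` in an instance of `new_class`.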
+def convert_class(m, target_m, new_class, config, distributed=False):
+ for name, sub_m in m.named_children():
+ if isinstance(sub_m, target_m):
+ new_m = new_class(sub_m, config, distributed)
+ setattr(m, name, new_m)
+ convert_class(sub_m, target_m, new_class, config, distributed)
+
+
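+# Recursively attach `new_op` as the attribute `new_op_name` on every submodule of type `target_m`.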
+def patch_op(m, target_m, new_op_name, new_op):
+ for name, sub_m in m.named_children():
+ if isinstance(sub_m, target_m):
+ setattr(sub_m, new_op_name, new_op)
+ patch_op(sub_m, target_m, new_op_name, new_op)
+
+
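+# Register the IPEX fused ops (rotary embedding, indirect-access KV cache) on the attention modules
+# and swap in the IPEX-optimized forwards and decoder layer implementation.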
+def _patch_llama_model(model):
+ if is_ipex_version("<", "2.5.0"):
+ raise ImportError("Only ipex version > 2.3.0 supports RotaryEmbedding and IndirectAccessKVCache")
+
+ from intel_extension_for_pytorch.llm.modules import IndirectAccessKVCache, RotaryEmbedding
+
+ ipex_rope = RotaryEmbedding(
+ model.config.max_position_embeddings,
+ model.config.hidden_size // model.config.num_attention_heads,
+ model.config.rope_theta,
+ model.config.architectures[0],
+ )
+ ipex_scale_dot_product = IndirectAccessKVCache(text_max_length=model.config.max_position_embeddings)
+ patch_op(model, LlamaAttention, "ipex_rope", ipex_rope)
+ patch_op(model, LlamaAttention, "ipex_scale_dot_product", ipex_scale_dot_product)
+
+ convert_functions(model, LlamaModel, "forward", _llama_model_forward)
+ convert_functions(model, LlamaAttention, "forward", _llama_attn_forward)
+ convert_functions(model, LlamaRMSNorm, "forward", _llama_layer_norm_forward)
+
+ convert_class(model, LlamaDecoderLayer, _IPEXLlamaDecoderLayerRef, model.config)
+ return model
+
+
+def _patch_model(model):
+ if isinstance(model, LlamaForCausalLM):
+ model = _patch_llama_model(model)
+ return model
diff --git a/optimum/exporters/ipex/modeling_utils.py b/optimum/exporters/ipex/modeling_utils.py
new file mode 100644
index 0000000000..f75e559eaf
--- /dev/null
+++ b/optimum/exporters/ipex/modeling_utils.py
@@ -0,0 +1,307 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.modeling_outputs import BaseModelOutputWithPast
+from transformers.models.llama.modeling_llama import repeat_kv
+
+from optimum.intel.utils.import_utils import is_ipex_version
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L83
+def _llama_layer_norm_forward(self, hidden_states):
+ return torch.ops.torch_ipex.rmsnorm(hidden_states, self.weight, self.variance_epsilon)
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L321
+def _llama_attn_forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query = self.q_proj(hidden_states)
+ key = self.k_proj(hidden_states)
+ value = self.v_proj(hidden_states)
+
+ kv_seq_len = q_len + past_key_value[0].size(-2) if past_key_value is not None else q_len
+
+ query = query.view(bsz, q_len, self.num_heads, self.head_dim)
+ key = key.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+ value = value.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+ # Use the IPEX fused op to apply rotary position embeddings more efficiently.
+ key = self.ipex_rope(
+ key,
+ position_ids,
+ self.num_key_value_heads,
+ self.head_dim,
+ self.head_dim // 2,
+ self.head_dim,
+ kv_seq_len,
+ )
+ query = self.ipex_rope(
+ query,
+ position_ids,
+ self.num_heads,
+ self.head_dim,
+ self.head_dim // 2,
+ self.head_dim,
+ kv_seq_len,
+ )
+
+ if use_cache:
+ # This IPEX op pre-allocates buffers for past_key_values and uses the beam index history
+ # to decide which beam should be used, making the attention scaled dot-product more efficient.
+ (attn_output, attn_weights, past_key_value) = self.ipex_scale_dot_product(
+ query,
+ key,
+ value,
+ math.sqrt(self.head_dim),
+ past_key_value,
+ None,
+ attention_mask,
+ )
+ else:
+ value_states = value.transpose(1, 2)
+ query_states = query.transpose(1, 2)
+ key_states = key.transpose(1, 2)
+ kv_seq_len = key_states.shape[-2]
+
+ past_key_value = None
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attention_mask is not None:
+ attn_weights = torch.tensor(attn_weights) + torch.tensor(attention_mask)
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L1130
+def _llama_model_forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape[:2]
+ elif inputs_embeds is not None:
+ batch_size, seq_length = inputs_embeds.shape[:2]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ past_key_values_length = 0
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ if getattr(self.config, "_flash_attn_2_enabled", False):
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+ )
+
+ # embed positions
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = () if use_cache else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/llama/modeling_llama.py#L694
+class _IPEXLlamaDecoderLayerRef(nn.Module):
+ def __init__(self, module, config, distributed=False):
+ if is_ipex_version("<", "2.5.0"):
+ raise ImportError("Only ipex version > 2.3.0 supports Linear2SiluMul and LinearAdd")
+
+ from intel_extension_for_pytorch.llm.modules import Linear2SiluMul, LinearAdd
+
+ super().__init__()
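+ # Copy the attributes and non-forward methods of the original decoder layer so the patched layer stays functionally equivalent.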
+ for k, v in module.__dict__.items():
+ setattr(self, k, v)
+ for k, v in module.__class__.__dict__.items():
+ if k.startswith("__") or k.startswith("forward"):
+ continue
+ setattr(self.__class__, k, getattr(module.__class__, k))
+ self.distributed = distributed
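+ # Outside of distributed (tensor-parallel) execution, fuse the residual additions and the SiLU-gated MLP into single IPEX linear ops.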
+ if not self.distributed:
+ self.mha_linear_add = LinearAdd(module.self_attn.o_proj)
+ self.mlp_linear_add = LinearAdd(module.mlp.down_proj)
+ del self.__dict__["_modules"]["self_attn"].o_proj
+ del self.__dict__["_modules"]["mlp"].down_proj
+ self.linear_silu_mul = Linear2SiluMul(module.mlp.gate_proj, module.mlp.up_proj)
+ del self.__dict__["_modules"]["mlp"].gate_proj
+ del self.__dict__["_modules"]["mlp"].up_proj
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*):
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
+ query_sequence_length, key_sequence_length)` if default attention is used.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
+ """
+
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ if not self.distributed:
+ hidden_states = self.mha_linear_add(hidden_states, residual)
+ else:
+ hidden_states = self.self_attn.o_proj(hidden_states)
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+
+ mlp_gate = self.linear_silu_mul(hidden_states)
+
+ if not self.distributed:
+ hidden_states = self.mlp_linear_add(mlp_gate, residual)
+ else:
+ hidden_states = self.mlp.down_proj(mlp_gate)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py
index c3d84df043..e1f735c9a9 100644
--- a/optimum/exporters/openvino/__main__.py
+++ b/optimum/exporters/openvino/__main__.py
@@ -21,23 +21,12 @@
from optimum.exporters import TasksManager
from optimum.exporters.onnx.base import OnnxConfig
+from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
from optimum.utils.save_utils import maybe_load_preprocessors
-from ...intel.utils.import_utils import (
- is_optimum_version,
- is_transformers_version,
-)
-from .convert import export_from_model, export_tokenizer
-
-if is_optimum_version(">=", "1.16.0"):
- from optimum.exporters.onnx.constants import SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
-else:
- # Copied from https://github.com/huggingface/optimum/blob/main/optimum/exporters/onnx/constants.py
- SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED = [
- "bart",
- "whisper",
- ]
+from ...intel.utils.import_utils import is_transformers_version
+from .convert import export_from_model, export_tokenizer
if TYPE_CHECKING:
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index f953771a7a..91dc48df05 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -15,30 +15,49 @@
import logging as log
from optimum.intel.utils.import_utils import (
+ _openvino_version,
_torch_version,
_transformers_version,
+ is_openvino_version,
is_torch_version,
is_transformers_version,
)
def patch_model_with_bettertransformer(model):
+ COLOR_RED = "\033[1;31m"
+ COLOR_RESET = "\033[0m"
+
# check that the model has not yet been patched
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
return model
if is_transformers_version("<", "4.36") or is_torch_version("<", "2.1.1"):
- COLOR_RED = "\033[1;31m"
- COLOR_RESET = "\033[0m"
log.warn(
COLOR_RED
+ "[WARNING] For good performance with stateful models, transformers>=4.36.2 and PyTorch>=2.1.1 are required. "
f"This Python environment has Transformers {_transformers_version} and PyTorch {_torch_version}. "
"Consider upgrading PyTorch and Transformers, for example by running "
- "`pip install --upgrade --upgrade-strategy eager optimum[openvino,nncf]`, and export the model again"
+ "`pip install --upgrade --upgrade-strategy eager optimum[openvino]`, and export the model again"
+ COLOR_RESET
)
+ if (
+ getattr(model.config, "model_type") in {"gpt_bigcode", "llama"}
+ and is_transformers_version(">=", "4.38")
+ and is_openvino_version("<", "2024.1.0-14612")
+ ):
+ # display commit-id only when a nightly/prerelease of OpenVINO is installed.
+ display_version = (
+ _openvino_version.split("-")[0] if is_openvino_version("<=", "2024.0.0-14509") else _openvino_version
+ )
+ log.warn(
+ COLOR_RED + f"[WARNING] Stateful models are not supported for Llama and GPTBigCode with Transformers "
+ f"{_transformers_version} and OpenVINO {display_version}. For good performance, consider using a nightly OpenVINO build: "
+ "https://docs.openvino.ai/2024/get-started/install-openvino.html. For models that do not need transformers "
+ "4.38+, it is also an option to downgrade transformers: `pip install transformers==4.37.2`" + COLOR_RESET
+ )
+
# model already has required SDPA implementation
if getattr(model, "_supports_sdpa", False) and getattr(model.config, "_attn_implementation", "eager") == "sdpa":
return model
diff --git a/optimum/intel/__init__.py b/optimum/intel/__init__.py
index 93a4417bfc..59059d688d 100644
--- a/optimum/intel/__init__.py
+++ b/optimum/intel/__init__.py
@@ -18,6 +18,7 @@
from transformers.utils import OptionalDependencyNotAvailable, _LazyModule
from .utils import (
+ is_accelerate_available,
is_diffusers_available,
is_ipex_available,
is_neural_compressor_available,
@@ -29,6 +30,7 @@
_import_structure = {
"openvino": [],
+ "utils.dummy_openvino_and_nncf_objects": [],
}
try:
@@ -57,13 +59,19 @@
if not (is_openvino_available() and is_nncf_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
- _import_structure["utils.dummy_openvino_and_nncf_objects"] = [
- "OVQuantizer",
- "OVTrainer",
- "OVTrainingArguments",
- ]
+ _import_structure["utils.dummy_openvino_and_nncf_objects"].extend(["OVQuantizer", "OVTrainingArguments"])
+else:
+ _import_structure["openvino"].extend(["OVQuantizer", "OVTrainingArguments"])
+
+
+try:
+ if not (is_openvino_available() and is_nncf_available() and is_accelerate_available()):
+ raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+ _import_structure["utils.dummy_openvino_and_nncf_objects"].extend(["OVTrainer"])
else:
- _import_structure["openvino"].extend(["OVQuantizer", "OVTrainer", "OVTrainingArguments"])
+ _import_structure["openvino"].extend(["OVTrainer"])
+
try:
if not (is_openvino_available() and is_diffusers_available()):
@@ -145,6 +153,7 @@
"INCSeq2SeqTrainer",
"INCTrainer",
]
+
try:
if not (is_neural_compressor_available() and is_diffusers_available()):
raise OptionalDependencyNotAvailable()
@@ -177,13 +186,17 @@
if not (is_openvino_available() and is_nncf_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
- from .utils.dummy_openvino_and_nncf_objects import (
- OVQuantizer,
- OVTrainer,
- OVTrainingArguments,
- )
+ from .utils.dummy_openvino_and_nncf_objects import OVQuantizer, OVTrainingArguments
+ else:
+ from .openvino import OVQuantizer, OVTrainingArguments
+
+ try:
+ if not (is_openvino_available() and is_nncf_available() and is_accelerate_available()):
+ raise OptionalDependencyNotAvailable()
+ except OptionalDependencyNotAvailable:
+ from .utils.dummy_openvino_and_nncf_objects import OVTrainer
else:
- from .openvino import OVQuantizer, OVTrainer, OVTrainingArguments
+ from .openvino import OVTrainer
try:
if not (is_openvino_available() and is_diffusers_available()):
diff --git a/optimum/intel/generation/modeling.py b/optimum/intel/generation/modeling.py
index 0abdafe666..3d9c657626 100644
--- a/optimum/intel/generation/modeling.py
+++ b/optimum/intel/generation/modeling.py
@@ -105,13 +105,13 @@ def __init__(
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)
self.model_dtype = kwargs.get("model_dtype", None)
- logger.warning(
- f"The class `{self.__class__}` has been depreciated and will be removed in optimum-intel v1.14, please use IPEXModel instead"
- )
if isinstance(model, torch.jit.ScriptModule):
self.input_names = {
inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"
}
+ logger.warning(
+ f"The class `{self.__class__}` has been depreciated for TorchScript model, please use `IPEXModelForCausalLM` instead"
+ )
else:
self.input_names = set()
diff --git a/optimum/intel/ipex/modeling_base.py b/optimum/intel/ipex/modeling_base.py
index 2b6b569343..00fe3de115 100644
--- a/optimum/intel/ipex/modeling_base.py
+++ b/optimum/intel/ipex/modeling_base.py
@@ -22,6 +22,8 @@
import intel_extension_for_pytorch as ipex
import torch
from huggingface_hub import hf_hub_download
+from intel_extension_for_pytorch.cpu._auto_kernel_selection import _enable_tpp
+from intel_extension_for_pytorch.transformers.optimize import get_dummy_input
from transformers import (
AutoConfig,
AutoModel,
@@ -45,19 +47,69 @@
from optimum.modeling_base import OptimizedModel
from optimum.utils import NormalizedConfigManager
-from ..generation.modeling import jit_trace, prepare_jit_inputs
-from ..utils.import_utils import is_torch_version, is_transformers_version
+from ...exporters.ipex.model_patcher import _IPEX_EXPORTED_TASK, _patch_model
+from ..generation.modeling import prepare_jit_inputs
+from ..utils.import_utils import is_ipex_version, is_torch_version, is_transformers_version
from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS, patch_decoder_attention_mask
logger = logging.getLogger(__name__)
+_IPEX_SUPPORT_MODEL_TYPES = ("llama",)
+
+
+def _is_patched_with_ipex(model, task):
+ if is_ipex_version("<", "2.5.0"):
+ return False
+
+ if isinstance(model, torch.jit.ScriptModule):
+ for node in model.graph.nodes():
+ # JIT records where each node was created, so we can check whether the graph was produced by the IPEX exporter.
+ if "torch_ipex::rotary_position_embedding" in node.__str__():
+ return True
+ return False
+ else:
+ return model.config.model_type in _IPEX_SUPPORT_MODEL_TYPES and task in _IPEX_EXPORTED_TASK
+
+
+def ipex_jit_trace(model, task, use_cache):
+ # torch >= 2.1.0 is required to support `example_kwarg_inputs` in jit.trace
+ if is_torch_version("<", "2.1.0"):
+ raise ImportError("`torch>=2.1.0` is needed to trace your model")
+
+ if _is_patched_with_ipex(model, task):
+ model = _patch_model(model)
+ sample_inputs = get_dummy_input(model, return_dict=True)
+ # Use Tensor Processing Primitives to accelerate linear, see https://arxiv.org/abs/2104.05755.
+ _enable_tpp()
+ else:
+ model = patch_decoder_attention_mask(model)
+ sample_inputs = prepare_jit_inputs(model, task, use_cache)
+
+ model.config.return_dict = False
+
+ model = ipex.optimize(model.eval(), dtype=model.dtype, inplace=True)
+ with torch.no_grad():
+ trace_model = torch.jit.trace(
+ model,
+ example_kwarg_inputs=sample_inputs,
+ strict=False,
+ check_trace=False,
+ )
+ trace_model = torch.jit.freeze(trace_model)
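+ # Run a couple of warmup forwards so the frozen graph finishes its profiling and fusion passes.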
+ trace_model(**sample_inputs)
+ trace_model(**sample_inputs)
+
+ return trace_model
+
+
class IPEXModel(OptimizedModel):
auto_model_class = AutoModel
export_feature = "feature-extraction"
base_model_prefix = "ipex_model"
main_input_name = "input_ids"
+ output_name = "last_hidden_state"
def __init__(
self,
@@ -73,6 +125,7 @@ def __init__(
self._dtype = self.config.torch_dtype if self.config.torch_dtype is not None else torch.float32
self.model.to(self._device)
self.model_save_dir = model_save_dir
+ self._is_ipex_exported = _is_patched_with_ipex(model, self.export_feature)
self.input_names = {
inputs.debugName().split(".")[0] for inputs in model.graph.inputs() if inputs.debugName() != "self"
@@ -90,13 +143,13 @@ def _from_transformers(
cls,
model_id: str,
config: PretrainedConfig,
+ use_cache: bool = True,
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
force_download: bool = False,
cache_dir: Optional[str] = None,
subfolder: str = "",
local_files_only: bool = False,
- use_cache: bool = True,
torch_dtype: Optional[Union[str, "torch.dtype"]] = None,
trust_remote_code: bool = False,
):
@@ -116,14 +169,13 @@ def _from_transformers(
}
model = TasksManager.get_model_from_task(task, model_id, **model_kwargs)
- model = patch_decoder_attention_mask(model)
- model = ipex.optimize(model, dtype=torch_dtype, level="O1", auto_kernel_selection=True)
- traced_model = jit_trace(model, task, use_cache)
+ traced_model = ipex_jit_trace(model, task, use_cache)
save_dir = TemporaryDirectory()
save_dir_path = Path(save_dir.name)
torch.jit.save(traced_model, save_dir_path / WEIGHTS_NAME)
config.torchscript = True
+ config.torch_dtype = torch_dtype
return cls._from_pretrained(
model_id=save_dir_path,
@@ -134,6 +186,7 @@ def _from_transformers(
cache_dir=cache_dir,
local_files_only=local_files_only,
use_cache=use_cache,
+ model_dtype=torch_dtype,
)
@classmethod
@@ -193,7 +246,12 @@ def forward(
inputs["token_type_ids"] = token_type_ids
outputs = self._call_model(**inputs)
- return ModelOutput(**outputs) if isinstance(outputs, dict) else ModelOutput(logits=outputs[0])
+ if isinstance(outputs, dict):
+ model_output = ModelOutput(**outputs)
+ else:
+ model_output = ModelOutput()
+ model_output[self.output_name] = outputs[0]
+ return model_output
def eval(self):
self.model.eval()
@@ -207,6 +265,13 @@ def device(self) -> torch.device:
def dtype(self) -> torch.dtype:
return self._dtype
+ @property
+ def model_dtype(self):
+ logger.warning(
+ "access to the `model_dtype` attribute is deprecated and will be removed after v1.18.0, please use `_dtype` instead."
+ )
+ return self._dtype
+
def to(self, device: Union[torch.device, str]):
self._device = device if isinstance(device, torch.device) else torch.device(device)
self.model.to(self._device)
@@ -217,7 +282,7 @@ def can_generate(self):
def _call_model(self, *args, **kwargs):
try:
- with torch.autocast(self.device.type, self.dtype):
+ with torch.autocast(self.device.type, self.dtype), torch.no_grad():
out = self.model(*args, **kwargs)
except RuntimeError:
out = self.model(*args, **kwargs)
@@ -226,25 +291,30 @@ def _call_model(self, *args, **kwargs):
def _init_warmup(self):
# warmup, the first 2 forwards of an IPEX model include some preprocessing steps and
# the results of the compute are unpredictable
- use_cache = "past_key_values" in self.input_names
- dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
- for _ in range(2):
- self(**dummy_inputs)
+ # TODO : add warmup for IPEX exported model
+ if not self._is_ipex_exported:
+ use_cache = "past_key_values" in self.input_names
+ dummy_inputs = prepare_jit_inputs(self, self.export_feature, use_cache)
+ for _ in range(2):
+ self(**dummy_inputs)
class IPEXModelForSequenceClassification(IPEXModel):
auto_model_class = AutoModelForSequenceClassification
export_feature = "text-classification"
+ output_name = "logits"
class IPEXModelForTokenClassification(IPEXModel):
auto_model_class = AutoModelForTokenClassification
export_feature = "token-classification"
+ output_name = "logits"
class IPEXModelForMaskedLM(IPEXModel):
auto_model_class = AutoModelForMaskedLM
export_feature = "fill-mask"
+ output_name = "logits"
class IPEXModelForImageClassification(IPEXModel):
@@ -325,10 +395,10 @@ def __init__(
):
# Perform the initial warmup at the end of __init__
super().__init__(model, config, model_save_dir=model_save_dir, warmup=False)
+ GenerationMixin.__init__(self)
model_type = config.model_type.replace("_", "-")
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(model_type)(config)
- self.model_dtype = kwargs.get("model_dtype", self.dtype)
self.use_cache = "past_key_values" in self.input_names
if use_cache ^ self.use_cache:
@@ -348,7 +418,15 @@ def __init__(
)
except AttributeError:
self.model_cls = get_model_class(self.config, AutoModelForCausalLM._model_mapping)
- self._reorder_cache = self.model_cls._reorder_cache.__get__(self)
+
+ if self._is_ipex_exported:
+ self._reorder_cache = _ipex_reorder_cache
+ else:
+ # Check if _reorder_cache is a static method
+ if isinstance(self.model_cls.__dict__["_reorder_cache"], staticmethod):
+ self._reorder_cache = self.model_cls._reorder_cache
+ else:
+ self._reorder_cache = self.model_cls._reorder_cache.__get__(self)
if is_transformers_version(">=", "4.38.0") and model_type in {"llama", "phi", "persimmon"}:
self.prepare_inputs_for_generation = _prepare_inputs_for_generation_for_llama
@@ -374,7 +452,25 @@ def _prepare_past_key_values(self, input_ids):
else:
num_attention_heads = self.normalized_config.num_attention_heads
- if model_type == "bloom":
+ if self._is_ipex_exported:
+ # Indirect-access KV cache has a different data layout compared with most transformers models,
+ # see https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/llm.html#indirect-access-kv-cache
+ beam_idx_tmp = torch.zeros(
+ (self.config.max_position_embeddings, input_ids.shape[0]), dtype=torch.long
+ ).contiguous()
+ past_key_values = tuple(
+ [
+ (
+ torch.zeros(1, 0, 0, 1, dtype=torch.long).contiguous(),
+ torch.zeros([1, 1, 1, 1]).contiguous(),
+ torch.zeros([1, 1, 1, 1]).contiguous(),
+ beam_idx_tmp,
+ )
+ for i in range(num_layers)
+ ]
+ )
+ return past_key_values
+ elif model_type == "bloom":
shape_key = (batch_size * num_attention_heads, d_k, 0)
shape_value = (batch_size * num_attention_heads, 0, d_k)
key = torch.empty(size=shape_key, dtype=self.model_dtype, device=self._device)
@@ -496,3 +592,23 @@ def _prepare_inputs_for_generation_for_llama(
}
)
return model_inputs
+
+
+def _ipex_reorder_cache(
+ past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
+) -> Tuple[Tuple[torch.Tensor]]:
+ # IPEX-patched models use an indirect-access KV cache, which has a different shape from other transformers models
+ if len(past_key_values[0]) == 4 and past_key_values[0][0].shape[-1] == 1:
+ for layer_past in past_key_values:
+ layer_past[3][layer_past[0].size(-2) - 1] = beam_idx
+ return past_key_values
+ elif len(past_key_values[0]) == 8:
+ for layer_past in past_key_values:
+ layer_past[3][layer_past[0].size(-2) - 1] = beam_idx
+ layer_past[7][layer_past[0].size(-2) - 1] = beam_idx
+ return past_key_values
+ else:
+ return tuple(
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+ for layer_past in past_key_values
+ )
diff --git a/optimum/intel/openvino/__init__.py b/optimum/intel/openvino/__init__.py
index a6227615a2..1df932771a 100644
--- a/optimum/intel/openvino/__init__.py
+++ b/optimum/intel/openvino/__init__.py
@@ -14,7 +14,7 @@
import logging
-from ..utils.import_utils import is_diffusers_available, is_nncf_available
+from ..utils.import_utils import is_accelerate_available, is_diffusers_available, is_nncf_available
from .utils import (
OV_DECODER_NAME,
OV_DECODER_WITH_PAST_NAME,
@@ -37,9 +37,11 @@
patch_torch_operators()
from .quantization import OVQuantizer
- from .trainer import OVTrainer
from .training_args import OVTrainingArguments
+ if is_accelerate_available():
+ from .trainer import OVTrainer
+
from .configuration import OVConfig, OVWeightQuantizationConfig
from .modeling import (
diff --git a/optimum/intel/openvino/configuration.py b/optimum/intel/openvino/configuration.py
index 9f3e3a06ca..40a60bb58e 100644
--- a/optimum/intel/openvino/configuration.py
+++ b/optimum/intel/openvino/configuration.py
@@ -114,7 +114,7 @@ def __init__(
**kwargs,
):
super().__init__()
- self.compression = compression or DEFAULT_QUANTIZATION_CONFIG
+ self.compression = compression
self.input_info = input_info
self.save_onnx_model = save_onnx_model
self._enable_standard_onnx_export_option()
@@ -167,7 +167,7 @@ class OVWeightQuantizationConfig(QuantizationConfigMixin):
bits (`int`, defaults to 8):
The number of bits to quantize to.
- sym (`bool`, *optional*, defaults to `False`):
+ sym (`bool`, defaults to `False`):
Whether to use symmetric quantization.
tokenizer (`str` or `PreTrainedTokenizerBase`, *optional*):
The tokenizer used to process the dataset. You can pass either:
@@ -177,23 +177,24 @@ class OVWeightQuantizationConfig(QuantizationConfigMixin):
user or organization name, like `dbmdz/bert-base-german-cased`.
- A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
- dataset (`Union[List[str]]`, *optional*):
- The dataset used for data-aware compression. You can provide your own dataset in a list of string or just use the
- the one from the list ['wikitext2','c4','c4-new','ptb','ptb-new']
- group_size (`int`, *optional*, defaults to 128):
- The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
- ratio (`float`, *optional*, defaults to 1.0):
+ dataset (`str or List[str]`, *optional*):
+ The dataset used for data-aware compression or quantization with NNCF. You can provide your own dataset
+ in a list of strings or use one from the list ['wikitext2','c4','c4-new','ptb','ptb-new'] for LLMs
+ or ['conceptual_captions','laion/220k-GPT4Vision-captions-from-LIVIS','laion/filtered-wit'] for diffusion models.
+ ratio (`float`, defaults to 1.0):
The ratio between baseline and backup precisions (e.g. 0.9 means 90% of layers quantized to INT4_ASYM
and the rest to INT8_ASYM).
+ group_size (`int`, *optional*):
+ The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
all_layers (`bool`, *optional*):
Defines how many layers are compressed to 4-bits while the rest are kept in 8-bit precision.
- sensitivity_metric (`nncf.SensitivityMetric`, *optional*):
+ sensitivity_metric (`str`, *optional*):
The sensitivity metric for assigning quantization precision to layers. In order to
preserve the accuracy of the model, the more sensitive layers receive a higher precision.
- awq (`bool`, *optional*):
- Enables AWQ method to unify weight ranges and improve overall model accuracy.
- ignored_scope (`nncf.IgnoredScope`, *optional*):
+ ignored_scope (`dict`, *optional*):
An ignored scope that defines the list of model control flow graph nodes to be ignored during quantization.
+ num_samples (`int`, *optional*):
+ The maximum number of samples composing the calibration dataset.
"""
@@ -202,12 +203,13 @@ def __init__(
bits: int = 8,
sym: bool = False,
tokenizer: Optional[Any] = None,
- dataset: Optional[str] = None,
+ dataset: Optional[Union[str, List[str]]] = None,
ratio: float = 1.0,
group_size: Optional[int] = None,
all_layers: Optional[bool] = None,
sensitivity_metric: Optional[str] = None,
ignored_scope: Optional[dict] = None,
+ num_samples: Optional[int] = None,
**kwargs,
):
self.bits = bits
@@ -219,6 +221,7 @@ def __init__(
self.all_layers = all_layers
self.sensitivity_metric = sensitivity_metric
self.ignored_scope = ignored_scope
+ self.num_samples = num_samples
self.quant_method = "default" # TODO : enable AWQ after nncf v2.9.0 release
self.post_init()
@@ -231,10 +234,16 @@ def post_init(self):
if self.group_size is not None and self.group_size != -1 and self.group_size <= 0:
raise ValueError("`group_size` must be greater than 0 or equal to -1")
if self.dataset is not None and isinstance(self.dataset, str):
- if self.dataset not in ["wikitext2", "c4", "c4-new", "ptb", "ptb-new"]:
+ llm_datasets = ["wikitext2", "c4", "c4-new", "ptb", "ptb-new"]
+ stable_diffusion_datasets = [
+ "conceptual_captions",
+ "laion/220k-GPT4Vision-captions-from-LIVIS",
+ "laion/filtered-wit",
+ ]
+ if self.dataset not in llm_datasets + stable_diffusion_datasets:
raise ValueError(
f"""You have entered a string value for dataset. You can only choose between
- ['wikitext2','c4','c4-new','ptb','ptb-new'], but we found {self.dataset}"""
+ {llm_datasets} for LLMs or {stable_diffusion_datasets} for diffusion models, but we found {self.dataset}"""
)
if self.bits not in [4, 8]:
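For reference, a minimal usage sketch of the reworked `OVWeightQuantizationConfig`; the values below are illustrative, and only the `dataset`/`num_samples` semantics come from the hunk above:

```python
from optimum.intel import OVWeightQuantizationConfig

# 4-bit data-aware weight compression for an LLM, capped at 64 calibration samples.
llm_config = OVWeightQuantizationConfig(
    bits=4,
    sym=True,
    group_size=128,
    ratio=0.8,
    dataset="wikitext2",
    num_samples=64,
)

# 8-bit quantization of a diffusion model driven by one of the predefined datasets.
sd_config = OVWeightQuantizationConfig(bits=8, dataset="conceptual_captions", num_samples=200)
```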
diff --git a/optimum/intel/openvino/modeling.py b/optimum/intel/openvino/modeling.py
index 7831305d5f..357ca94c07 100644
--- a/optimum/intel/openvino/modeling.py
+++ b/optimum/intel/openvino/modeling.py
@@ -434,8 +434,8 @@ def _from_transformers(
save_dir = TemporaryDirectory()
save_dir_path = Path(save_dir.name)
- # If load_in_8bit or quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
- if load_in_8bit is None or not quantization_config:
+ # If neither load_in_8bit nor quantization_config is specified, ov_config is set to None and will be set by default in convert depending on the model size
+ if load_in_8bit is None and not quantization_config:
ov_config = None
else:
ov_config = OVConfig(dtype="fp32")
diff --git a/optimum/intel/openvino/modeling_base.py b/optimum/intel/openvino/modeling_base.py
index 51633b0210..15f1fc4f1c 100644
--- a/optimum/intel/openvino/modeling_base.py
+++ b/optimum/intel/openvino/modeling_base.py
@@ -57,6 +57,7 @@ def __init__(
dynamic_shapes: bool = True,
ov_config: Optional[Dict[str, str]] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+ quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None,
**kwargs,
):
self.config = config
@@ -91,6 +92,10 @@ def __init__(
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
+ self._openvino_config = None
+ if quantization_config:
+ self._openvino_config = OVConfig(quantization_config=quantization_config)
+
@staticmethod
def load_model(file_name: Union[str, Path], quantization_config: Union[OVWeightQuantizationConfig, Dict] = None):
"""
@@ -143,6 +148,15 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
dst_path = os.path.join(save_directory, OV_XML_FILE_NAME)
openvino.save_model(self.model, dst_path, compress_to_fp16=False)
+ self._save_openvino_config(save_directory)
+
+ def _save_openvino_config(self, save_directory: Union[str, Path]):
+ if self._openvino_config is not None:
+ if not isinstance(self._openvino_config.quantization_config.dataset, (str, type(None))):
+ self._openvino_config.quantization_config.dataset = None
+
+ self._openvino_config.save_pretrained(save_directory)
+
@classmethod
def _from_pretrained(
cls,
@@ -203,12 +217,28 @@ def _from_pretrained(
local_files_only=local_files_only,
)
- # Give default quantization config if not provided and load_in_8bit=True
- if load_in_8bit:
- quantization_config = quantization_config or {"bits": 8}
+ quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
model = cls.load_model(model_cache_path, quantization_config=quantization_config)
- return cls(model, config=config, model_save_dir=model_cache_path.parent, **kwargs)
+ return cls(
+ model,
+ config=config,
+ model_save_dir=model_cache_path.parent,
+ quantization_config=quantization_config,
+ **kwargs,
+ )
+
+ @staticmethod
+ def _prepare_weight_quantization_config(
+ quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None, load_in_8bit: bool = False
+ ):
+ # Give default quantization config if not provided and load_in_8bit=True
+ if not quantization_config and load_in_8bit:
+ quantization_config = OVWeightQuantizationConfig(bits=8)
+ elif isinstance(quantization_config, dict):
+ quantization_config = OVWeightQuantizationConfig.from_dict(quantization_config)
+
+ return quantization_config
@staticmethod
def _cached_file(
@@ -284,8 +314,8 @@ def _from_transformers(
save_dir = TemporaryDirectory()
save_dir_path = Path(save_dir.name)
- # If load_in_8bit or quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
- if load_in_8bit is None or not quantization_config:
+ # If neither load_in_8bit nor quantization_config is specified, ov_config is set to None and will be set by default in convert depending on the model size
+ if load_in_8bit is None and not quantization_config:
ov_config = None
else:
ov_config = OVConfig(dtype="fp32")
@@ -358,7 +388,7 @@ def compile(self):
if (
"CACHE_DIR" not in self.ov_config.keys()
and not str(self.model_save_dir).startswith(gettempdir())
- and self._device.lower() == "gpu"
+ and "gpu" in self._device.lower()
):
# Set default CACHE_DIR only if it is not set, if the model is not in a temporary directory, and device is GPU
cache_dir = Path(self.model_save_dir).joinpath("model_cache")
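A sketch of what the new `_prepare_weight_quantization_config` helper normalizes; it is a private static method, so treat this as an illustration of the intended behaviour rather than public API:

```python
from optimum.intel import OVWeightQuantizationConfig
from optimum.intel.openvino.modeling_base import OVBaseModel

# load_in_8bit=True with no explicit config falls back to a default 8-bit config.
config = OVBaseModel._prepare_weight_quantization_config(None, load_in_8bit=True)
assert isinstance(config, OVWeightQuantizationConfig) and config.bits == 8

# A plain dict is converted into an OVWeightQuantizationConfig.
config = OVBaseModel._prepare_weight_quantization_config({"bits": 4, "group_size": 64})
assert config.bits == 4 and config.group_size == 64
```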
diff --git a/optimum/intel/openvino/modeling_base_seq2seq.py b/optimum/intel/openvino/modeling_base_seq2seq.py
index df9449b0b5..28e112c4d9 100644
--- a/optimum/intel/openvino/modeling_base_seq2seq.py
+++ b/optimum/intel/openvino/modeling_base_seq2seq.py
@@ -58,6 +58,7 @@ def __init__(
dynamic_shapes: bool = True,
ov_config: Optional[Dict[str, str]] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+ quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
**kwargs,
):
self.config = config
@@ -76,6 +77,9 @@ def __init__(
self.decoder_model = decoder
self.decoder_with_past_model = decoder_with_past
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
+ self._openvino_config = None
+ if quantization_config:
+ self._openvino_config = OVConfig(quantization_config=quantization_config)
def _save_pretrained(self, save_directory: Union[str, Path]):
"""
@@ -96,6 +100,8 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
dst_path = os.path.join(save_directory, dst_file_name)
openvino.save_model(src_file, dst_path, compress_to_fp16=False)
+ self._save_openvino_config(save_directory)
+
@classmethod
def _from_pretrained(
cls,
@@ -155,9 +161,7 @@ def _from_pretrained(
decoder_with_past_file_name = decoder_with_past_file_name or default_decoder_with_past_file_name
decoder_with_past = None
- # Give default quantization config if not provided and load_in_8bit=True
- if load_in_8bit:
- quantization_config = quantization_config or {"bits": 8}
+ quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
# Load model from a local directory
if os.path.isdir(model_id):
@@ -205,6 +209,7 @@ def _from_pretrained(
decoder_with_past=decoder_with_past,
config=config,
model_save_dir=model_save_dir,
+ quantization_config=quantization_config,
**kwargs,
)
@@ -253,8 +258,8 @@ def _from_transformers(
if use_cache:
task = task + "-with-past"
- # If load_in_8bit or quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
- if load_in_8bit is None or not quantization_config:
+ # If neither load_in_8bit nor quantization_config is specified, ov_config is set to None and will be set by default in convert depending on the model size
+ if load_in_8bit is None and not quantization_config:
ov_config = None
else:
ov_config = OVConfig(dtype="fp32")
diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index c0274d3f5b..53aa05bc5a 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
import logging
import os
from pathlib import Path
@@ -100,6 +101,7 @@ def __init__(
dynamic_shapes: bool = True,
ov_config: Optional[Dict[str, str]] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+ quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None,
**kwargs,
):
if not dynamic_shapes:
@@ -117,6 +119,7 @@ def __init__(
dynamic_shapes=False,
ov_config=ov_config,
model_save_dir=model_save_dir,
+ quantization_config=quantization_config,
**kwargs,
)
@@ -224,6 +227,8 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
dst_path = os.path.join(save_directory, OV_XML_FILE_NAME)
openvino.save_model(model_to_save, dst_path, compress_to_fp16=False)
+ self._save_openvino_config(save_directory)
+
@classmethod
def _from_transformers(
cls,
@@ -255,11 +260,11 @@ def _from_transformers(
if use_cache:
task = task + "-with-past"
- # If load_in_8bit or quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
- if load_in_8bit is None or not quantization_config:
- ov_config = None
+ # If neither load_in_8bit nor quantization_config is specified, ov_config is set to None and will be set by default in convert depending on the model size
+ if load_in_8bit is None and not quantization_config:
+ ov_export_config = None
else:
- ov_config = OVConfig(dtype="fp32")
+ ov_export_config = OVConfig(dtype="fp32")
stateful = kwargs.pop("stateful", ensure_stateful_is_available(warn=False) and use_cache)
@@ -274,7 +279,7 @@ def _from_transformers(
local_files_only=local_files_only,
force_download=force_download,
trust_remote_code=trust_remote_code,
- ov_config=ov_config,
+ ov_config=ov_export_config,
stateful=stateful,
)
@@ -576,15 +581,10 @@ def _from_pretrained(
local_files_only=local_files_only,
)
- # Give default quantization config if not provided and load_in_8bit=True
- if load_in_8bit:
- quantization_config = quantization_config or {"bits": 8}
+ if isinstance(quantization_config, dict) and quantization_config == {"bits": 4}:
+ quantization_config = _DEFAULT_4BIT_CONFIGS.get(config.name_or_path, quantization_config)
- if isinstance(quantization_config, dict):
- if quantization_config == {"bits": 4} and config.name_or_path in _DEFAULT_4BIT_CONFIGS:
- quantization_config = _DEFAULT_4BIT_CONFIGS[config.name_or_path]
-
- quantization_config = OVWeightQuantizationConfig.from_dict(quantization_config)
+ quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
load_in_4bit = quantization_config.bits == 4 if quantization_config else False
model = cls.load_model(model_cache_path, quantization_config=None if load_in_4bit else quantization_config)
@@ -603,7 +603,12 @@ def _from_pretrained(
enable_compilation = kwargs.pop("compile", True) and not load_in_4bit
causal_model = init_cls(
- model=model, config=config, model_save_dir=model_cache_path.parent, compile=enable_compilation, **kwargs
+ model=model,
+ config=config,
+ model_save_dir=model_cache_path.parent,
+ compile=enable_compilation,
+ quantization_config=quantization_config,
+ **kwargs,
)
if load_in_4bit:
@@ -630,8 +635,10 @@ def _from_pretrained(
# from optimum.gptq.utils import get_seqlen
# seqlen = get_seqlen(causal_model)
- dataset = get_dataset(quantization_config.dataset, tokenizer, seqlen=32)
+ nsamples = quantization_config.num_samples if quantization_config.num_samples else 128
+ dataset = get_dataset(quantization_config.dataset, tokenizer, seqlen=32, nsamples=nsamples)
dataset = prepare_dataset(dataset)
+ quantization_config = copy.deepcopy(quantization_config)
quantization_config.dataset = nncf.Dataset(dataset, lambda x: causal_model.prepare_inputs(**x))
_weight_only_quantization(model, quantization_config)
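Taken together, the decoder changes enable the following user-facing flow (a sketch only: the model id is an example and the tokenizer is passed explicitly to keep the data-aware path self-contained):

```python
from optimum.intel import OVModelForCausalLM, OVWeightQuantizationConfig

quantization_config = OVWeightQuantizationConfig(
    bits=4, dataset="wikitext2", tokenizer="gpt2", num_samples=64
)
model = OVModelForCausalLM.from_pretrained("gpt2", export=True, quantization_config=quantization_config)
model.save_pretrained("gpt2-ov-int4")  # the quantization config is saved alongside the IR
```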
diff --git a/optimum/intel/openvino/modeling_diffusion.py b/optimum/intel/openvino/modeling_diffusion.py
index 5633f852a8..f0fea5a8ce 100644
--- a/optimum/intel/openvino/modeling_diffusion.py
+++ b/optimum/intel/openvino/modeling_diffusion.py
@@ -16,6 +16,7 @@
import logging
import os
import shutil
+from copy import deepcopy
from pathlib import Path
from tempfile import TemporaryDirectory, gettempdir
from typing import Any, Dict, List, Optional, Union
@@ -57,7 +58,13 @@
from .configuration import OVConfig, OVWeightQuantizationConfig
from .loaders import OVTextualInversionLoaderMixin
from .modeling_base import OVBaseModel
-from .utils import ONNX_WEIGHTS_NAME, OV_TO_NP_TYPE, OV_XML_FILE_NAME, _print_compiled_model_properties
+from .utils import (
+ ONNX_WEIGHTS_NAME,
+ OV_TO_NP_TYPE,
+ OV_XML_FILE_NAME,
+ PREDEFINED_SD_DATASETS,
+ _print_compiled_model_properties,
+)
core = Core()
@@ -87,15 +94,25 @@ def __init__(
compile: bool = True,
ov_config: Optional[Dict[str, str]] = None,
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+ quantization_config: Optional[Union[OVWeightQuantizationConfig, Dict]] = None,
**kwargs,
):
self._internal_dict = config
self._device = device.upper()
self.is_dynamic = dynamic_shapes
self.ov_config = ov_config if ov_config is not None else {}
- self._model_save_dir = (
- Path(model_save_dir.name) if isinstance(model_save_dir, TemporaryDirectory) else model_save_dir
- )
+
+ # This attribute is needed to keep a reference to the temporary directory, since garbage collection
+ # would end up removing the directory containing the underlying OpenVINO model
+ self._model_save_dir_tempdirectory_instance = None
+ if isinstance(model_save_dir, TemporaryDirectory):
+ self._model_save_dir_tempdirectory_instance = model_save_dir
+ self._model_save_dir = Path(model_save_dir.name)
+ elif isinstance(model_save_dir, str):
+ self._model_save_dir = Path(model_save_dir)
+ else:
+ self._model_save_dir = model_save_dir
+
self.vae_decoder = OVModelVaeDecoder(vae_decoder, self)
self.unet = OVModelUnet(unet, self)
self.text_encoder = OVModelTextEncoder(text_encoder, self) if text_encoder is not None else None
@@ -140,6 +157,10 @@ def __init__(
self._internal_dict.pop("vae", None)
+ self._openvino_config = None
+ if quantization_config:
+ self._openvino_config = OVConfig(quantization_config=quantization_config)
+
def _save_pretrained(self, save_directory: Union[str, Path]):
"""
Saves the model to the OpenVINO IR format so that it can be re-loaded using the
@@ -177,6 +198,8 @@ def _save_pretrained(self, save_directory: Union[str, Path]):
if self.tokenizer_2 is not None:
self.tokenizer_2.save_pretrained(save_directory / "tokenizer_2")
+ self._save_openvino_config(save_directory)
+
@classmethod
def _from_pretrained(
cls,
@@ -257,13 +280,20 @@ def _from_pretrained(
else:
kwargs[name] = load_method(new_model_save_dir)
- # Give default quantization config if not provided and load_in_8bit=True
- if load_in_8bit:
- quantization_config = quantization_config or {"bits": 8}
-
- unet = cls.load_model(
- new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name, quantization_config
- )
+ quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)
+
+ unet_path = new_model_save_dir / DIFFUSION_MODEL_UNET_SUBFOLDER / unet_file_name
+ if quantization_config is not None and quantization_config.dataset is not None:
+ # Load the UNet model uncompressed so that hybrid quantization can be applied later
+ unet = cls.load_model(unet_path)
+ # Apply weight-only compression to the other `components` without using the dataset
+ weight_quantization_params = {
+ param: value for param, value in quantization_config.__dict__.items() if param != "dataset"
+ }
+ weight_quantization_config = OVWeightQuantizationConfig.from_dict(weight_quantization_params)
+ else:
+ weight_quantization_config = quantization_config
+ unet = cls.load_model(unet_path, weight_quantization_config)
components = {
"vae_encoder": new_model_save_dir / DIFFUSION_MODEL_VAE_ENCODER_SUBFOLDER / vae_encoder_file_name,
@@ -273,12 +303,93 @@ def _from_pretrained(
}
for key, value in components.items():
- components[key] = cls.load_model(value, quantization_config) if value.is_file() else None
+ components[key] = cls.load_model(value, weight_quantization_config) if value.is_file() else None
if model_save_dir is None:
model_save_dir = new_model_save_dir
- return cls(unet=unet, config=config, model_save_dir=model_save_dir, **components, **kwargs)
+ if quantization_config is not None and quantization_config.dataset is not None:
+ sd_model = cls(unet=unet, config=config, model_save_dir=model_save_dir, **components, **kwargs)
+
+ supported_pipelines = (
+ OVStableDiffusionPipeline,
+ OVStableDiffusionXLPipeline,
+ OVLatentConsistencyModelPipeline,
+ )
+ if not isinstance(sd_model, supported_pipelines):
+ raise NotImplementedError(f"Quantization in hybrid mode is not supported for {cls.__name__}")
+
+ nsamples = quantization_config.num_samples if quantization_config.num_samples else 200
+ unet_inputs = sd_model._prepare_unet_inputs(quantization_config.dataset, nsamples)
+
+ from .quantization import _hybrid_quantization
+
+ unet = _hybrid_quantization(sd_model.unet.model, weight_quantization_config, dataset=unet_inputs)
+
+ return cls(
+ unet=unet,
+ config=config,
+ model_save_dir=model_save_dir,
+ quantization_config=quantization_config,
+ **components,
+ **kwargs,
+ )
+
+ def _prepare_unet_inputs(
+ self,
+ dataset: Union[str, List[Any]],
+ num_samples: int,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ seed: Optional[int] = 42,
+ **kwargs,
+ ) -> Dict[str, Any]:
+ self.compile()
+
+ size = self.unet.config.get("sample_size", 64) * self.vae_scale_factor
+ height = height or min(size, 512)
+ width = width or min(size, 512)
+
+ if isinstance(dataset, str):
+ dataset = deepcopy(dataset)
+ available_datasets = PREDEFINED_SD_DATASETS.keys()
+ if dataset not in available_datasets:
+ raise ValueError(
+ f"""You have entered a string value for dataset. You can only choose between
+ {list(available_datasets)}, but {dataset} was found"""
+ )
+
+ from datasets import load_dataset
+
+ dataset_metadata = PREDEFINED_SD_DATASETS[dataset]
+ dataset = load_dataset(dataset, split=dataset_metadata["split"], streaming=True).shuffle(seed=seed)
+ input_names = dataset_metadata["inputs"]
+ dataset = dataset.select_columns(list(input_names.values()))
+
+ def transform_fn(data_item):
+ return {inp_name: data_item[column] for inp_name, column in input_names.items()}
+
+ else:
+
+ def transform_fn(data_item):
+ return data_item if isinstance(data_item, (list, dict)) else [data_item]
+
+ from .quantization import InferRequestWrapper
+
+ calibration_data = []
+ self.unet.request = InferRequestWrapper(self.unet.request, calibration_data)
+
+ for inputs in dataset:
+ inputs = transform_fn(inputs)
+ if isinstance(inputs, dict):
+ self.__call__(**inputs, height=height, width=width)
+ else:
+ self.__call__(*inputs, height=height, width=width)
+ if len(calibration_data) > num_samples:
+ break
+
+ self.unet.request = self.unet.request.request
+ return calibration_data[:num_samples]
@classmethod
def _from_transformers(
@@ -301,8 +412,8 @@ def _from_transformers(
save_dir = TemporaryDirectory()
save_dir_path = Path(save_dir.name)
- # If load_in_8bit or quantization_config not specified then ov_config is set to None and will be set by default in convert depending on the model size
- if load_in_8bit is None or not quantization_config:
+ # If neither load_in_8bit nor quantization_config is specified, ov_config is set to None and will be set by default in convert depending on the model size
+ if load_in_8bit is None and not quantization_config:
ov_config = None
else:
ov_config = OVConfig(dtype="fp32")
@@ -558,7 +669,7 @@ def _compile(self):
if (
"CACHE_DIR" not in self.ov_config.keys()
and not str(self._model_dir).startswith(gettempdir())
- and self.device.lower() == "gpu"
+ and self.device.lower().split(":")[0] == "gpu"
):
self.ov_config["CACHE_DIR"] = os.path.join(self._model_dir, self._model_name, "model_cache")
diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py
index 617d898be5..d68cbc75ed 100644
--- a/optimum/intel/openvino/modeling_seq2seq.py
+++ b/optimum/intel/openvino/modeling_seq2seq.py
@@ -451,7 +451,7 @@ def _compile(self):
if (
"CACHE_DIR" not in ov_config.keys()
and not str(self.parent_model.model_save_dir).startswith(gettempdir())
- and self._device.lower() == "gpu"
+ and "gpu" in self._device.lower()
):
cache_dir = Path(self.parent_model.model_save_dir).joinpath("model_cache")
ov_config["CACHE_DIR"] = str(cache_dir)
@@ -563,7 +563,7 @@ def _compile(self):
if (
"CACHE_DIR" not in ov_config.keys()
and not str(self.parent_model.model_save_dir).startswith(gettempdir())
- and self._device.lower() == "gpu"
+ and "gpu" in self._device.lower()
):
cache_dir = Path(self.parent_model.model_save_dir).joinpath("model_cache")
ov_config["CACHE_DIR"] = str(cache_dir)
diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py
index 5ec4eac556..c46f29092b 100644
--- a/optimum/intel/openvino/quantization.py
+++ b/optimum/intel/openvino/quantization.py
@@ -12,17 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import copy
import inspect
import logging
import os
+from collections import deque
from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Union
import nncf
import openvino
import torch
import transformers
from nncf import CompressWeightsMode, IgnoredScope, NNCFConfig, SensitivityMetric
+from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters
from nncf.torch import create_compressed_model, register_default_init_args, register_module
from nncf.torch.dynamic_graph.io_handling import wrap_nncf_model_inputs_with_objwalk
from nncf.torch.initialization import PTInitializingDataLoader
@@ -44,7 +47,7 @@
from ..utils.constant import _TASK_ALIASES
from ..utils.import_utils import DATASETS_IMPORT_ERROR, is_datasets_available
from ..utils.modeling_utils import get_model_device
-from .configuration import OVConfig, OVWeightQuantizationConfig
+from .configuration import DEFAULT_QUANTIZATION_CONFIG, OVConfig, OVWeightQuantizationConfig
from .modeling_base import OVBaseModel
from .utils import (
MAX_ONNX_OPSET,
@@ -55,8 +58,7 @@
if is_datasets_available():
- if TYPE_CHECKING:
- from datasets import Dataset
+ from datasets import Dataset
register_module(ignored_algorithms=[])(Conv1D)
@@ -87,11 +89,14 @@ def __init__(self, request, data_cache=None):
self.data_cache = data_cache
def __call__(self, *args, **kwargs):
- self.data_cache.append(*args)
+ # If __call__ is invoked then self.request must be an instance of CompiledModel
+ signature = inspect.signature(self.request)
+ bound_args = signature.bind(*args, **kwargs).arguments
+ self.data_cache.append(copy.deepcopy(bound_args["inputs"]))
return self.request(*args, **kwargs)
def infer(self, inputs: Any = None, share_inputs: bool = False):
- self.data_cache.append(inputs)
+ self.data_cache.append(copy.deepcopy(inputs))
return self.request.infer(inputs, share_inputs)
def start_async(
@@ -102,7 +107,7 @@ def start_async(
*,
shared_memory: Any = None,
):
- self.data_cache.append(inputs)
+ self.data_cache.append(copy.deepcopy(inputs))
self.request.infer(inputs, share_inputs, share_outputs=True)
def wait(self):
@@ -143,6 +148,7 @@ def __init__(self, model: transformers.PreTrainedModel, task: Optional[str] = No
)
self.task = task or feature
self.seed = seed
+ # TODO : deprecate input_names
self.input_names = None
signature = inspect.signature(self.model.forward)
self._signature_columns = list(signature.parameters.keys())
@@ -231,8 +237,11 @@ def quantize(
)
ov_config = ov_config or quantization_config
- if ov_config is not None and not isinstance(ov_config, OVConfig):
- raise TypeError(f"`ov_config` should be an `OVConfig`, but got: {type(ov_config)} instead.")
+ if ov_config is not None:
+ if not isinstance(ov_config, OVConfig):
+ raise TypeError(f"`ov_config` should be an `OVConfig`, but got: {type(ov_config)} instead.")
+ elif ov_config.compression is None:
+ ov_config.compression = DEFAULT_QUANTIZATION_CONFIG
if isinstance(self.model, OVBaseModel):
self._quantize_ovbasemodel(
@@ -351,7 +360,7 @@ def _quantize_torchmodel(
logger.info(
"No configuration describing the quantization process was provided, a default OVConfig will be generated."
)
- ov_config = OVConfig()
+ ov_config = OVConfig(compression=DEFAULT_QUANTIZATION_CONFIG)
onnx_file_name = (
ONNX_WEIGHTS_NAME
if file_name is None and ov_config.save_onnx_model
@@ -519,9 +528,15 @@ def _get_calibration_dataloader(
data_collator: Optional[DataCollator] = None,
) -> OVDataLoader:
data_collator = data_collator if data_collator is not None else default_data_collator
+
+ if not is_datasets_available() or not isinstance(calibration_dataset, Dataset):
+ logger.warning(
+ "`remove_unused_columns` set to `False` as calibration_dataset is not an instance of `datasets.Dataset`"
+ )
+ remove_unused_columns = False
+
if remove_unused_columns:
calibration_dataset = self._remove_unused_columns(calibration_dataset)
- self.input_names = calibration_dataset.column_names
generator = torch.Generator()
generator.manual_seed(self.seed)
sampler = RandomSampler(calibration_dataset, generator=generator)
@@ -537,7 +552,7 @@ def _remove_unused_columns(self, dataset: "Dataset"):
def _weight_only_quantization(
model: openvino.runtime.Model, quantization_config: Union[OVWeightQuantizationConfig, Dict]
-):
+) -> openvino.runtime.Model:
config = quantization_config
if isinstance(config, dict):
config = OVWeightQuantizationConfig.from_dict(quantization_config)
@@ -551,7 +566,8 @@ def _weight_only_quantization(
from optimum.gptq.data import get_dataset, prepare_dataset
- dataset = get_dataset(config.dataset, tokenizer, seqlen=32)
+ nsamples = config.num_samples if config.num_samples else 128
+ dataset = get_dataset(config.dataset, tokenizer, seqlen=32, nsamples=nsamples)
dataset = prepare_dataset(dataset)
sensitivity_metric = None
@@ -577,4 +593,92 @@ def _weight_only_quantization(
# awq=config.quant_method == "awq", # TODO : remove and add it back once nncf v2.9.0
ignored_scope=ignored_scope,
dataset=dataset,
+ # subset_size=config.num_samples if config.num_samples else 128, # TODO : enable from nncf v2.9.0
+ )
+
+
+def _get_operation_const_op(operation, const_port_id: int):
+ node = operation.input_value(const_port_id).get_node()
+ queue = deque([node])
+ constant_node = None
+ allowed_propagation_types_list = ["Convert", "FakeQuantize", "Reshape"]
+
+ while len(queue) != 0:
+ curr_node = queue.popleft()
+ if curr_node.get_type_name() == "Constant":
+ constant_node = curr_node
+ break
+ if len(curr_node.inputs()) == 0:
+ break
+ if curr_node.get_type_name() in allowed_propagation_types_list:
+ queue.append(curr_node.input_value(0).get_node())
+
+ return constant_node
+
+
+def _is_embedding(node) -> bool:
+ allowed_types_list = ["f16", "f32", "f64"]
+ const_port_id = 0
+ input_tensor = node.input_value(const_port_id)
+ if input_tensor.get_element_type().get_type_name() in allowed_types_list:
+ const_node = _get_operation_const_op(node, const_port_id)
+ if const_node is not None:
+ return True
+
+ return False
+
+
+def _collect_ops_with_weights(model):
+ ops_with_weights = []
+ for op in model.get_ops():
+ if op.get_type_name() == "MatMul":
+ constant_node_0 = _get_operation_const_op(op, const_port_id=0)
+ constant_node_1 = _get_operation_const_op(op, const_port_id=1)
+ if constant_node_0 or constant_node_1:
+ ops_with_weights.append(op.get_friendly_name())
+ if op.get_type_name() == "Gather" and _is_embedding(op):
+ ops_with_weights.append(op.get_friendly_name())
+
+ return ops_with_weights
+
+
+def _hybrid_quantization(
+ model: openvino.runtime.Model, quantization_config: OVWeightQuantizationConfig, dataset: Dict[str, Any]
+) -> openvino.runtime.Model:
+ """
+ Quantize a model in hybrid mode with NNCF, which means quantizing the weights of MatMul and
+ Embedding layers and the activations of the other layers.
+ The optimization specifications are defined in `quantization_config`.
+
+ Args:
+ model (`openvino.runtime.Model`):
+ The OpenVINO Runtime model for applying hybrid quantization.
+ quantization_config (`OVWeightQuantizationConfig`):
+ The configuration containing the parameters related to quantization.
+ dataset (`Dict[str, Any]`):
+ The dataset used for hybrid quantization.
+ Returns:
+ The OpenVINO Runtime model with applied hybrid quantization.
+ """
+ ops_to_compress = _collect_ops_with_weights(model)
+
+ ignored_scope = quantization_config.ignored_scope if isinstance(quantization_config.ignored_scope, dict) else {}
+ ptq_ignored_scope = nncf.IgnoredScope(**ignored_scope)
+ ptq_ignored_scope.names += ops_to_compress
+
+ wc_quantization_config = copy.deepcopy(quantization_config)
+ wc_quantization_config.ignored_scope = ignored_scope
+ wc_quantization_config.ignored_scope["types"] = ignored_scope.get("types", []) + ["Convolution"]
+ compressed_model = _weight_only_quantization(model, wc_quantization_config)
+
+ subset_size = quantization_config.num_samples if quantization_config.num_samples else 200
+ quantized_model = nncf.quantize(
+ model=compressed_model,
+ calibration_dataset=nncf.Dataset(dataset),
+ model_type=nncf.ModelType.TRANSFORMER,
+ ignored_scope=ptq_ignored_scope,
+ # The SQ algo should be disabled for MatMul nodes because their weights are already compressed
+ advanced_parameters=nncf.AdvancedQuantizationParameters(AdvancedSmoothQuantParameters(matmul=-1)),
+ subset_size=subset_size,
)
+ return quantized_model
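A sketch of how the updated `InferRequestWrapper` is meant to be used to collect calibration samples; it mirrors the new whisper test further below, and the deep copies ensure cached inputs are not mutated by later requests:

```python
from transformers import AutoProcessor
from optimum.intel import OVModelForSpeechSeq2Seq
from optimum.intel.openvino.quantization import InferRequestWrapper

model_id = "openai/whisper-tiny.en"
ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True, compile=True)
processor = AutoProcessor.from_pretrained(model_id)

calibration_data = []
# Wrap the compiled request so that every input dict is deep-copied into calibration_data.
ov_model.decoder_with_past.request = InferRequestWrapper(ov_model.decoder_with_past.request, calibration_data)
# ... run ov_model.generate(...) on a few audio samples here ...
ov_model.decoder_with_past.request = ov_model.decoder_with_past.request.request  # unwrap when done
```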
diff --git a/optimum/intel/openvino/trainer.py b/optimum/intel/openvino/trainer.py
index 5c7d392292..b7d110c96a 100644
--- a/optimum/intel/openvino/trainer.py
+++ b/optimum/intel/openvino/trainer.py
@@ -89,7 +89,7 @@
from ..utils.constant import _TASK_ALIASES
from ..utils.import_utils import is_transformers_version
-from .configuration import OVConfig
+from .configuration import DEFAULT_QUANTIZATION_CONFIG, OVConfig
from .quantization import OVDataLoader
from .training_args import OVTrainingArguments
from .utils import (
@@ -225,37 +225,41 @@ def __init__(
self.teacher.eval()
self.compression_controller = None
- if self.ov_config is not None and self.args.do_train:
- self._set_task()
- train_dataloader = self.get_train_dataloader()
- model_inputs = next(iter(train_dataloader))
- for label_name in self.label_names:
- model_inputs.pop(label_name)
- force_batch_one = self._is_pruning_enabled()
- self.ov_config.add_input_info(model_inputs, force_batch_one)
- nncf_config = NNCFConfig.from_dict(self.ov_config.__dict__)
- nncf_config.register_extra_structs(
- [
- QuantizationRangeInitArgs(OVDataLoader(train_dataloader)),
- BNAdaptationInitArgs(OVDataLoader(train_dataloader)),
- ]
- )
+ if self.ov_config is not None:
+ if self.ov_config.compression is None:
+ self.ov_config.compression = DEFAULT_QUANTIZATION_CONFIG
+
+ if self.args.do_train:
+ self._set_task()
+ train_dataloader = self.get_train_dataloader()
+ model_inputs = next(iter(train_dataloader))
+ for label_name in self.label_names:
+ model_inputs.pop(label_name)
+ force_batch_one = self._is_pruning_enabled()
+ self.ov_config.add_input_info(model_inputs, force_batch_one)
+ nncf_config = NNCFConfig.from_dict(self.ov_config.__dict__)
+ nncf_config.register_extra_structs(
+ [
+ QuantizationRangeInitArgs(OVDataLoader(train_dataloader)),
+ BNAdaptationInitArgs(OVDataLoader(train_dataloader)),
+ ]
+ )
- # Configure NNCF logging
- # Disable nncf logging to stdout except error
- # but to file nncf_output.log
- nncf_config["log_dir"] = args.output_dir
- nncf_log_file_handler = logging.logging.FileHandler(os.path.join(args.output_dir, NNCF_LOG_FILE_NAME))
- nncf_log_file_handler.setFormatter(logging.logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
- nncf_logger.addHandler(nncf_log_file_handler)
- set_log_level(logging.ERROR)
- nncf_logger.setLevel(logging.INFO)
- nncf_log_file_handler.setLevel(logging.INFO)
-
- self.compression_controller, self.model = create_compressed_model(self.model, nncf_config)
- self.model_wrapped = self.model
- # TODO : To deprecate once support transformers > 4.30.0
- self.deepspeed = None
+ # Configure NNCF logging
+ # Disable nncf logging to stdout except error
+ # but to file nncf_output.log
+ nncf_config["log_dir"] = args.output_dir
+ nncf_log_file_handler = logging.logging.FileHandler(os.path.join(args.output_dir, NNCF_LOG_FILE_NAME))
+ nncf_log_file_handler.setFormatter(logging.logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
+ nncf_logger.addHandler(nncf_log_file_handler)
+ set_log_level(logging.ERROR)
+ nncf_logger.setLevel(logging.INFO)
+ nncf_log_file_handler.setLevel(logging.INFO)
+
+ self.compression_controller, self.model = create_compressed_model(self.model, nncf_config)
+ self.model_wrapped = self.model
+ # TODO : To deprecate once support transformers > 4.30.0
+ self.deepspeed = None
def _set_signature_columns_if_needed(self):
if self._signature_columns is None:
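The net effect of the trainer/quantizer changes, assuming `OVConfig()` now leaves `compression` unset by default (which is what the new fallback implies):

```python
from optimum.intel import OVConfig
from optimum.intel.openvino.configuration import DEFAULT_QUANTIZATION_CONFIG

ov_config = OVConfig()
# OVQuantizer.quantize() and OVTrainer now fill in the default lazily:
if ov_config.compression is None:
    ov_config.compression = DEFAULT_QUANTIZATION_CONFIG
```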
diff --git a/optimum/intel/openvino/utils.py b/optimum/intel/openvino/utils.py
index 49aec81e57..a0439d2129 100644
--- a/optimum/intel/openvino/utils.py
+++ b/optimum/intel/openvino/utils.py
@@ -20,7 +20,7 @@
import numpy as np
from huggingface_hub import model_info
-from openvino.runtime import Type, properties
+from openvino.runtime import Core, Type, properties
from transformers.onnx.utils import ParameterFormat, compute_serialized_parameters_size
@@ -99,6 +99,13 @@
}
+PREDEFINED_SD_DATASETS = {
+ "conceptual_captions": {"split": "train", "inputs": {"prompt": "caption"}},
+ "laion/220k-GPT4Vision-captions-from-LIVIS": {"split": "train", "inputs": {"prompt": "caption"}},
+ "laion/filtered-wit": {"split": "train", "inputs": {"prompt": "caption"}},
+}
+
+
def use_external_data_format(num_parameters: int) -> bool:
"""
Returns whether or not the model requires using external data format for the ONNX export
@@ -148,3 +155,9 @@ def _print_compiled_model_properties(compiled_model):
logger.info(f" {k}: {value}")
except Exception:
logger.error(f"[error] Get property of '{k}' failed")
+ try:
+ logger.info("EXECUTION_DEVICES:")
+ for device in compiled_model.get_property("EXECUTION_DEVICES"):
+ logger.info(f" {device}: {Core().get_property(device, 'FULL_DEVICE_NAME')}")
+ except Exception:
+ logger.error("[error] Get FULL_DEVICE_NAME failed")
diff --git a/optimum/intel/utils/__init__.py b/optimum/intel/utils/__init__.py
index 4e7522ee77..d77588f896 100644
--- a/optimum/intel/utils/__init__.py
+++ b/optimum/intel/utils/__init__.py
@@ -16,6 +16,7 @@
_neural_compressor_version,
_torch_version,
compare_versions,
+ is_accelerate_available,
is_diffusers_available,
is_ipex_available,
is_neural_compressor_available,
diff --git a/optimum/intel/utils/dummy_openvino_and_nncf_objects.py b/optimum/intel/utils/dummy_openvino_and_nncf_objects.py
index 45c390aff2..8ae3135667 100644
--- a/optimum/intel/utils/dummy_openvino_and_nncf_objects.py
+++ b/optimum/intel/utils/dummy_openvino_and_nncf_objects.py
@@ -27,14 +27,14 @@ def from_pretrained(cls, *args, **kwargs):
class OVTrainer(metaclass=DummyObject):
- _backends = ["openvino", "nncf"]
+ _backends = ["openvino", "nncf", "accelerate"]
def __init__(self, *args, **kwargs):
- requires_backends(self, ["openvino", "nncf"])
+ requires_backends(self, ["openvino", "nncf", "accelerate"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
- requires_backends(cls, ["openvino", "nncf"])
+ requires_backends(cls, ["openvino", "nncf", "accelerate"])
class OVQuantizer(metaclass=DummyObject):
diff --git a/optimum/intel/utils/import_utils.py b/optimum/intel/utils/import_utils.py
index 9e0fd3f4c0..3599104c75 100644
--- a/optimum/intel/utils/import_utils.py
+++ b/optimum/intel/utils/import_utils.py
@@ -178,6 +178,16 @@
_datasets_available = False
+_accelerate_available = importlib.util.find_spec("accelerate") is not None
+_accelerate_version = "N/A"
+
+if _accelerate_available:
+ try:
+ _accelerate_version = importlib_metadata.version("accelerate")
+ except importlib_metadata.PackageNotFoundError:
+ _accelerate_available = False
+
+
def is_transformers_available():
return _transformers_available
@@ -218,6 +228,10 @@ def is_datasets_available():
return _datasets_available
+def is_accelerate_available():
+ return _accelerate_available
+
+
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
"""
@@ -339,6 +353,11 @@ def is_timm_version(operation: str, version: str):
`pip install datasets`. Please note that you may need to restart your runtime after installation.
"""
+ACCELERATE_IMPORT_ERROR = """
+{0} requires the accelerate library but it was not found in your environment. You can install it with pip:
+`pip install accelerate`. Please note that you may need to restart your runtime after installation.
+"""
+
BACKENDS_MAPPING = OrderedDict(
[
("diffusers", (is_diffusers_available, DIFFUSERS_IMPORT_ERROR)),
@@ -346,6 +365,7 @@ def is_timm_version(operation: str, version: str):
("nncf", (is_nncf_available, NNCF_IMPORT_ERROR)),
("openvino", (is_openvino_available, OPENVINO_IMPORT_ERROR)),
("neural_compressor", (is_neural_compressor_available, NEURAL_COMPRESSOR_IMPORT_ERROR)),
+ ("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)),
]
)
diff --git a/setup.py b/setup.py
index 93c370d7c4..0b065325ff 100644
--- a/setup.py
+++ b/setup.py
@@ -18,10 +18,11 @@
"datasets>=1.4.0",
"sentencepiece",
"scipy",
- "accelerate", # transformers 4.29 require accelerate for PyTorch
+ "onnx",
]
TESTS_REQUIRE = [
+ "accelerate",
"pytest",
"parameterized",
"Pillow",
@@ -39,14 +40,10 @@
QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241"]
EXTRAS_REQUIRE = {
- "neural-compressor": [
- "neural-compressor>=2.2.0",
- "onnx",
- "onnxruntime<1.15.0",
- ],
- "openvino": ["openvino>=2023.3", "onnx", "onnxruntime", "openvino-tokenizers[transformers]"],
+ "neural-compressor": ["neural-compressor>=2.2.0", "onnxruntime<1.15.0", "accelerate"],
+ "openvino": ["openvino>=2023.3", "nncf>=2.8.1", "openvino-tokenizers[transformers]"],
"nncf": ["nncf>=2.8.1"],
- "ipex": ["intel-extension-for-pytorch", "onnx"],
+ "ipex": ["intel-extension-for-pytorch"],
"diffusers": ["diffusers"],
"quality": QUALITY_REQUIRE,
"tests": TESTS_REQUIRE,
diff --git a/tests/generation/test_modeling.py b/tests/generation/test_modeling.py
index b97fd66a83..9b637d322d 100644
--- a/tests/generation/test_modeling.py
+++ b/tests/generation/test_modeling.py
@@ -31,6 +31,7 @@
"gpt_neo": "hf-internal-testing/tiny-random-GPTNeoModel",
"mistral": "echarlaix/tiny-random-mistral",
"llama": "fxmarty/tiny-llama-fast-tokenizer",
+ "llama2": "Jiqing/tiny_random_llama2",
"gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
}
@@ -54,6 +55,7 @@ class ModelingIntegrationTest(unittest.TestCase):
"gpt_neo",
"mistral",
"llama",
+ "llama2",
# "gpt_bigcode",
)
diff --git a/tests/ipex/test_inference.py b/tests/ipex/test_inference.py
index bc1890453d..e120514506 100644
--- a/tests/ipex/test_inference.py
+++ b/tests/ipex/test_inference.py
@@ -42,6 +42,7 @@
"gpt_neox": "hf-internal-testing/tiny-random-GPTNeoXForCausalLM",
"gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
"llama": "fxmarty/tiny-llama-fast-tokenizer",
+ "llama2": "Jiqing/tiny_random_llama2",
"opt": "hf-internal-testing/tiny-random-OPTModel",
"mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
}
@@ -66,6 +67,7 @@ class IPEXIntegrationTest(unittest.TestCase):
"gpt_neo",
# "gpt_bigcode",
"llama",
+ "llama2",
"opt",
"mpt",
)
diff --git a/tests/ipex/test_modeling.py b/tests/ipex/test_modeling.py
index ffc2ca6a89..68119287d8 100644
--- a/tests/ipex/test_modeling.py
+++ b/tests/ipex/test_modeling.py
@@ -26,6 +26,7 @@
AutoModelForCausalLM,
AutoModelForQuestionAnswering,
AutoTokenizer,
+ GenerationConfig,
PretrainedConfig,
pipeline,
set_seed,
@@ -42,6 +43,8 @@
IPEXModelForSequenceClassification,
IPEXModelForTokenClassification,
)
+from optimum.intel.utils.import_utils import is_ipex_version
+from optimum.utils.testing_utils import grid_parameters
SEED = 42
@@ -67,6 +70,7 @@
"gptj": "hf-internal-testing/tiny-random-GPTJModel",
"levit": "hf-internal-testing/tiny-random-LevitModel",
"llama": "fxmarty/tiny-llama-fast-tokenizer",
+ "llama2": "Jiqing/tiny_random_llama2",
"marian": "sshleifer/tiny-marian-en-de",
"mbart": "hf-internal-testing/tiny-random-mbart",
"mistral": "echarlaix/tiny-random-mistral",
@@ -209,11 +213,13 @@ class IPEXModelForCausalLMTest(unittest.TestCase):
"gpt_neo",
"gpt_neox",
"llama",
+ "llama2",
"mistral",
# "phi",
"mpt",
"opt",
)
+ IPEX_PATCHED_SUPPORTED_ARCHITECTURES = ("llama",)
GENERATION_LENGTH = 100
SPEEDUP_CACHE = 1.0
@@ -226,7 +232,9 @@ def test_compare_to_transformers(self, model_arch):
self.assertTrue(ipex_model.use_cache)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokens = tokenizer(
- "This is a sample", return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None
+ "This is a sample",
+ return_tensors="pt",
+ return_token_type_ids=False if model_arch in ("llama", "llama2") else None,
)
position_ids = None
if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS:
@@ -255,6 +263,41 @@ def test_pipeline(self, model_arch):
self.assertEqual(pipe.device, model.device)
self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs))
+ @parameterized.expand(
+ grid_parameters(
+ {
+ "model_arch": IPEX_PATCHED_SUPPORTED_ARCHITECTURES,
+ "use_cache": [True, False],
+ }
+ )
+ )
+ @unittest.skipIf(is_ipex_version("<", "2.5.0"), reason="Only ipex version >= 2.5.0 supports ipex model patching")
+ def test_ipex_patching_beam_search(self, test_name, model_arch, use_cache):
+ model_id = MODEL_NAMES[model_arch]
+ set_seed(SEED)
+ model = IPEXModelForCausalLM.from_pretrained(model_id, export=True, use_cache=use_cache)
+ self.assertEqual(model.use_cache, use_cache)
+ transformers_model = AutoModelForCausalLM.from_pretrained(model_id)
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ tokenizer.pad_token = tokenizer.eos_token
+ # Test with batch sizes 1 and 2.
+ texts = ["This is a sample", ["This is the first input", "This is the second input"]]
+ generation_configs = (
+ GenerationConfig(max_new_tokens=4, num_beams=2, do_sample=True),
+ GenerationConfig(max_new_tokens=4, num_beams=4, do_sample=True),
+ GenerationConfig(max_new_tokens=4, num_beams=8, do_sample=True),
+ GenerationConfig(max_new_tokens=4, num_beams=32, do_sample=True),
+ GenerationConfig(max_new_tokens=4, do_sample=not use_cache, top_p=1.0, top_k=5, penalty_alpha=0.6),
+ GenerationConfig(max_new_tokens=4, do_sample=True, top_p=0.9, top_k=0),
+ )
+ for text in texts:
+ tokens = tokenizer(text, padding=True, return_tensors="pt")
+ for generation_config in generation_configs:
+ outputs = model.generate(**tokens, generation_config=generation_config)
+ transformers_outputs = transformers_model.generate(**tokens, generation_config=generation_config)
+ self.assertIsInstance(outputs, torch.Tensor)
+ self.assertTrue(torch.equal(outputs, transformers_outputs))
+
def test_compare_with_and_without_past_key_values(self):
model_id = "echarlaix/tiny-random-gpt2-torchscript"
tokenizer = AutoTokenizer.from_pretrained(model_id)
diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py
index 07a9f14774..c7fb00e12d 100644
--- a/tests/openvino/test_quantization.py
+++ b/tests/openvino/test_quantization.py
@@ -16,10 +16,12 @@
import tempfile
import unittest
+from collections import defaultdict
from functools import partial
import evaluate
import numpy as np
+import torch
from datasets import load_dataset
from parameterized import parameterized
import openvino.runtime as ov
@@ -30,12 +32,14 @@
AutoModelForCausalLM,
AutoModelForTokenClassification,
AutoTokenizer,
+ AutoProcessor,
TrainingArguments,
default_data_collator,
)
from optimum.intel import (
OVConfig,
+ OVLatentConsistencyModelPipeline,
OVModelForAudioClassification,
OVModelForCausalLM,
OVModelForFeatureExtraction,
@@ -45,6 +49,7 @@
OVModelForSeq2SeqLM,
OVModelForSequenceClassification,
OVModelForTokenClassification,
+ OVModelForSpeechSeq2Seq,
OVStableDiffusionPipeline,
OVStableDiffusionXLPipeline,
OVQuantizer,
@@ -52,8 +57,8 @@
OVWeightQuantizationConfig,
)
-
-from optimum.intel.openvino.configuration import INT8_WEIGHT_COMPRESSION_CONFIG
+from optimum.intel.openvino.configuration import INT8_WEIGHT_COMPRESSION_CONFIG, DEFAULT_QUANTIZATION_CONFIG
+from optimum.intel.openvino.quantization import InferRequestWrapper
from optimum.intel.utils.import_utils import is_openvino_version
from utils_tests import MODEL_NAMES, get_num_quantized_nodes, _ARCHITECTURES_TO_EXPECTED_INT8
@@ -106,9 +111,8 @@ def preprocess_function(examples, tokenizer):
self.assertTrue("logits" in outputs)
# Verify that the configuration is correctly saved and loaded
- expected_config = OVConfig()
loaded_config = OVConfig.from_pretrained(tmp_dir)
- self.assertEqual(expected_config.to_dict()["compression"], loaded_config.to_dict()["compression"])
+ self.assertEqual(DEFAULT_QUANTIZATION_CONFIG, loaded_config.to_dict()["compression"])
@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_QUANTIZED_MATMULS)
def test_ovmodel_static_quantization(self, model_cls, model_name, expected_fake_quantize, expected_int8):
@@ -151,16 +155,16 @@ class OVWeightCompressionTest(unittest.TestCase):
# TODO : add models
SUPPORTED_ARCHITECTURES_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS = (
(OVModelForSequenceClassification, "hf-internal-testing/tiny-random-bert", 70, 70),
- (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 44, 46),
+ (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 44, 44),
)
- SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 64, 365),)
- SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_AUTOCOMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 6, 379),)
+ SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_COMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 62, 86),)
+ SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_AUTOCOMPRESSED_MATMULS = ((OVModelForCausalLM, "opt125m", 0, 148),)
SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_AUTO_COMPRESSED_MATMULS = (
- (OVModelForCausalLM, "hf-internal-testing/tiny-random-OPTForCausalLM", 16, 136),
+ (OVModelForCausalLM, "hf-internal-testing/tiny-random-OPTForCausalLM", 14, 50),
)
SUPPORTED_ARCHITECTURES_STATEFUL_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS = (
- (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 44, 46),
+ (OVModelForCausalLM, "hf-internal-testing/tiny-random-gpt2", 44, 44),
)
LOAD_IN_4_BITS_SCOPE = (
@@ -168,7 +172,7 @@ class OVWeightCompressionTest(unittest.TestCase):
OVModelForCausalLM,
"hf-internal-testing/tiny-random-gpt2",
dict(bits=4, sym=False, group_size=-1, ratio=0.8),
- 16,
+ 14,
),
(
OVModelForCausalLM,
@@ -179,13 +183,13 @@ class OVWeightCompressionTest(unittest.TestCase):
group_size=32,
ignored_scope={"names": ["__module.model.transformer.h.2.mlp.c_fc/aten::addmm/MatMul"]},
),
- 6,
+ 4,
),
(
OVModelForCausalLM,
"hf-internal-testing/tiny-random-gpt2",
dict(bits=4, sym=False, group_size=-1, ratio=0.8, all_layers=True),
- 22,
+ 18,
),
(
OVModelForCausalLM,
@@ -198,7 +202,7 @@ class OVWeightCompressionTest(unittest.TestCase):
sensitivity_metric="mean_activation_magnitude",
dataset="ptb",
),
- 16,
+ 14,
),
(
OVModelForCausalLM,
@@ -212,7 +216,7 @@ class OVWeightCompressionTest(unittest.TestCase):
dataset="ptb",
awq=True,
),
- 16,
+ 14,
),
)
@@ -230,8 +234,16 @@ class OVWeightCompressionTest(unittest.TestCase):
(OVStableDiffusionXLPipeline, "stable-diffusion-xl"),
)
+ SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION = (
+ (OVStableDiffusionPipeline, "stable-diffusion", 72, 195),
+ (OVStableDiffusionXLPipeline, "stable-diffusion-xl", 84, 331),
+ (OVLatentConsistencyModelPipeline, "latent-consistency", 50, 135),
+ )
+
IS_SUPPORT_STATEFUL = is_openvino_version(">=", "2023.3")
+ DEFAULT_INT4_CONFIG = {"bits": 4, "sym": True, "group_size": 64, "all_layers": True}
+
@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_8BIT_COMPRESSED_MATMULS)
def test_automodel_weight_compression(self, model_cls, model_name, expected_pt_int8, expected_ov_int8):
task = model_cls.export_feature
@@ -331,6 +343,8 @@ def test_ovmodel_8bit_weight_compression_stateful(self, model_cls, model_id, exp
@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION)
def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type):
model = model_cls.from_pretrained(MODEL_NAMES[model_type], export=True, load_in_8bit=True, stateful=False)
+ self.assertEqual(model._openvino_config.quantization_config.bits, 8)
+ self.assertEqual(model._openvino_config.dtype, "int8")
if model.export_feature.startswith("text2text-generation"):
models = [model.encoder, model.decoder, model.decoder_with_past]
@@ -345,13 +359,46 @@ def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type):
_, num_int8, _ = get_num_quantized_nodes(model)
self.assertEqual(expected_ov_int8[i], num_int8)
+ @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION)
+ def test_ovmodel_hybrid_quantization(self, model_cls, model_type, expected_num_fake_quantize, expected_ov_int8):
+ model_id = MODEL_NAMES[model_type]
+ quantization_config = OVWeightQuantizationConfig(bits=8, dataset="conceptual_captions", num_samples=2)
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ model = model_cls.from_pretrained(model_id, export=True, quantization_config=quantization_config)
+
+ num_fake_quantize, num_int8, num_int4 = get_num_quantized_nodes(model.unet)
+ self.assertEqual(expected_num_fake_quantize, num_fake_quantize)
+ self.assertEqual(expected_ov_int8, num_int8)
+ self.assertEqual(0, num_int4)
+
+ model.save_pretrained(tmp_dir)
+
+ @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION[-1:])
+ def test_ovmodel_hybrid_quantization_with_custom_dataset(
+ self, model_cls, model_type, expected_num_fake_quantize, expected_ov_int8
+ ):
+ model_id = MODEL_NAMES[model_type]
+ dataset = [
+ "dream rose covered with clean crystal, sharp edges, transparent, beautiful, highly detailed, high render"
+ ]
+ model = model_cls.from_pretrained(
+ model_id,
+ export=True,
+ quantization_config=OVWeightQuantizationConfig(bits=8, dataset=dataset, num_samples=3),
+ )
+ num_fake_quantize, num_int8, num_int4 = get_num_quantized_nodes(model.unet)
+ self.assertEqual(expected_num_fake_quantize, num_fake_quantize)
+ self.assertEqual(expected_ov_int8, num_int8)
+ self.assertEqual(0, num_int4)
+
@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_AUTOCOMPRESSED_MATMULS)
+ @unittest.mock.patch.dict(
+ "optimum.intel.openvino.configuration._DEFAULT_4BIT_CONFIGS", {"facebook/opt-125m": DEFAULT_INT4_CONFIG}
+ )
def test_ovmodel_4bit_auto_compression(self, model_cls, model_type, expected_ov_int8, expected_ov_int4):
with tempfile.TemporaryDirectory() as tmp_dir:
model_id = MODEL_NAMES[model_type]
- model = model_cls.from_pretrained(
- model_id, export=True, quantization_config=OVWeightQuantizationConfig(bits=4)
- )
+ model = model_cls.from_pretrained(model_id, export=True, quantization_config={"bits": 4})
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
@@ -361,6 +408,13 @@ def test_ovmodel_4bit_auto_compression(self, model_cls, model_type, expected_ov_
self.assertEqual(expected_ov_int8, num_int8)
model.save_pretrained(tmp_dir)
+ openvino_config = OVConfig.from_pretrained(tmp_dir)
+ self.assertEqual(openvino_config.quantization_config["bits"], 4)
+ self.assertEqual(openvino_config.dtype, "int4")
+ if model_id == "facebook/opt-125m":
+ for key, value in self.DEFAULT_INT4_CONFIG.items():
+ self.assertEqual(value, openvino_config.quantization_config[key])
+
@parameterized.expand(LOAD_IN_4_BITS_SCOPE)
def test_ovmodel_4bit_auto_compression_with_config(
self, model_cls, model_id, quantization_config, expected_ov_int4
@@ -375,8 +429,9 @@ def test_ovmodel_4bit_auto_compression_with_config(
self.assertEqual(expected_ov_int4, num_int4)
model.save_pretrained(tmp_dir)
- ov_config = OVConfig(quantization_config=quantization_config)
- ov_config.save_pretrained(tmp_dir)
+ openvino_config = OVConfig.from_pretrained(tmp_dir)
+ self.assertEqual(openvino_config.quantization_config["bits"], 4)
+ self.assertEqual(openvino_config.dtype, "int4")
@parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_EXPECTED_4BIT_AUTO_COMPRESSED_MATMULS)
def test_ovmodel_4bit_auto_compression_with_custom_dataset(
@@ -443,36 +498,64 @@ def test_ovmodel_load_with_uncompressed_weights(self, model_cls, model_type):
self.assertEqual(0, num_int8)
def test_ovmodel_load_large_model_with_default_compressed_weights(self):
- with unittest.mock.patch("transformers.modeling_utils.ModuleUtilsMixin") as model_mixin_patch:
- model_mixin_patch.num_parameters.return_value = 2e9
+ with unittest.mock.patch("torch.nn.Module.parameters") as model_parameters:
+ mock_tensor = unittest.mock.Mock()
+ mock_tensor.numel = lambda: 2000000000
+ mock_tensor.requires_grad = True
+ model_parameters.return_value = [mock_tensor]
with unittest.mock.patch("openvino.runtime.ie_api.Core.read_model") as core_patch:
with unittest.mock.patch("optimum.exporters.openvino.convert._save_model") as save_model_patch:
_ = OVModelForCausalLM.from_pretrained(
MODEL_NAMES["llama"], export=True, compile=False, use_cache=False
)
- saving_params = {
- "model": unittest.mock.ANY,
- "path": unittest.mock.ANY,
- "compression_option": "int8",
- "compression_ratio": None,
- }
- save_model_patch.aasert_called_with(saving_params)
+ save_model_patch.assert_called_with(
+ unittest.mock.ANY, unittest.mock.ANY, ov_config=OVConfig(quantization_config={"bits": 8})
+ )
def test_ovmodel_load_large_model_with_uncompressed_weights(self):
- with unittest.mock.patch("transformers.modeling_utils.ModuleUtilsMixin") as model_mixin_patch:
- model_mixin_patch.num_parameters.return_value = 2e9
+ with unittest.mock.patch("torch.nn.Module.parameters") as model_parameters:
+ mock_tensor = unittest.mock.Mock()
+ mock_tensor.numel = lambda: 2000000000
+ mock_tensor.requires_grad = True
+ model_parameters.return_value = [mock_tensor]
with unittest.mock.patch("openvino.runtime.ie_api.Core.read_model") as core_patch:
with unittest.mock.patch("optimum.exporters.openvino.convert._save_model") as save_model_patch:
_ = OVModelForCausalLM.from_pretrained(
MODEL_NAMES["llama"], export=True, load_in_8bit=False, compile=False, use_cache=False
)
- saving_params = {
- "model": unittest.mock.ANY,
- "path": unittest.mock.ANY,
- "compression_option": "fp32",
- "compression_ratio": None,
- }
- save_model_patch.aasert_called_with(saving_params)
+ save_model_patch.assert_called_with(
+ unittest.mock.ANY, unittest.mock.ANY, ov_config=OVConfig(dtype="fp32")
+ )
+
+ def test_ovmodel_load_large_model_with_additional_quantization_config(self):
+ with unittest.mock.patch("torch.nn.Module.parameters") as model_parameters:
+ mock_tensor = unittest.mock.Mock()
+ mock_tensor.numel = lambda: 2000000000
+ mock_tensor.requires_grad = True
+ model_parameters.return_value = [mock_tensor]
+ with unittest.mock.patch("openvino.runtime.ie_api.Core.read_model") as core_patch:
+ with unittest.mock.patch("optimum.exporters.openvino.convert._save_model") as save_model_patch:
+ with unittest.mock.patch("nncf.compress_weights") as compress_weights_patch:
+ _ = OVModelForCausalLM.from_pretrained(
+ MODEL_NAMES["llama"],
+ export=True,
+ compile=False,
+ use_cache=False,
+ quantization_config=OVWeightQuantizationConfig(bits=4, sym=True, group_size=-1, ratio=0.8),
+ )
+ # quantization will be performed later, using load_model
+ save_model_patch.assert_called_with(
+ unittest.mock.ANY, unittest.mock.ANY, ov_config=OVConfig(dtype="fp32")
+ )
+ compression_params = {
+ "mode": nncf.CompressWeightsMode.INT4_SYM,
+ "ratio": 0.8,
+ "group_size": -1,
+ "all_layers": None,
+ "sensitivity_metric": None,
+ "dataset": None,
+ "ignored_scope": None,
+ }
+ compress_weights_patch.assert_called_with(unittest.mock.ANY, **compression_params)
class OVQuantizerQATest(unittest.TestCase):
@@ -589,3 +672,38 @@ def compute_metrics(p):
tokens = tokenizer("This is a sample input", return_tensors="pt")
outputs = model(**tokens)
self.assertTrue("logits" in outputs)
+
+
+class InferRequestWrapperTest(unittest.TestCase):
+ MODEL_ID = ("openai/whisper-tiny.en",)
+
+ @staticmethod
+ def _generate_random_audio_data(processor):
+ t = np.linspace(0, 1.0, int(1000), endpoint=False)
+ audio_data = 0.5 * np.sin((2 + np.random.random()) * np.pi * t)
+ input_features = processor(
+ audio_data,
+ sampling_rate=16000,
+ return_tensors="pt",
+ ).input_features
+ return input_features
+
+ @parameterized.expand(MODEL_ID)
+ def test_calibration_data_uniqueness(self, model_id):
+ ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True, compile=True)
+ processor = AutoProcessor.from_pretrained(model_id)
+
+ calibration_data = []
+ ov_model.decoder_with_past.request = InferRequestWrapper(ov_model.decoder_with_past.request, calibration_data)
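+        # The wrapper should collect the inputs of every inference call on decoder_with_past into calibration_data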
+ for _ in range(2):
+ input_features = self._generate_random_audio_data(processor)
+ ov_model.generate(input_features)
+
+ data_hashes_per_key = defaultdict(list)
+ for inputs_dict in calibration_data:
+ for k, v in inputs_dict.items():
+ x = (v.numpy() if isinstance(v, torch.Tensor) else v).copy()
+ data_hashes_per_key[k].append(hash(x.tobytes()))
+ for k, data_hashes in data_hashes_per_key.items():
+            # Not all hashes can be equal, since the calibration dataset contains at least 2 different samples
+ self.assertTrue(any(data_hashes[0] != it for it in data_hashes))
diff --git a/tests/openvino/test_stable_diffusion.py b/tests/openvino/test_stable_diffusion.py
index d8cef2e027..ab6f6f21a6 100644
--- a/tests/openvino/test_stable_diffusion.py
+++ b/tests/openvino/test_stable_diffusion.py
@@ -28,7 +28,6 @@
from diffusers.utils import load_image
from diffusers.utils.testing_utils import floats_tensor
from openvino.runtime.ie_api import CompiledModel
-from packaging.version import Version, parse
from parameterized import parameterized
from utils_tests import MODEL_NAMES, SEED
@@ -46,13 +45,8 @@
OVModelVaeDecoder,
OVModelVaeEncoder,
)
-from optimum.onnxruntime import (
- ORTStableDiffusionImg2ImgPipeline,
- ORTStableDiffusionInpaintPipeline,
- ORTStableDiffusionXLImg2ImgPipeline,
- ORTStableDiffusionXLPipeline,
-)
-from optimum.utils.import_utils import _diffusers_version
+from optimum.intel.utils.import_utils import is_diffusers_version
+from optimum.utils.import_utils import is_onnxruntime_available
F32_CONFIG = {"INFERENCE_PRECISION_HINT": "f32"}
@@ -167,7 +161,6 @@ def generate_inputs(self, height=128, width=128, batch_size=1):
class OVStableDiffusionImg2ImgPipelineTest(OVStableDiffusionPipelineBaseTest):
SUPPORTED_ARCHITECTURES = ("stable-diffusion",)
MODEL_CLASS = OVStableDiffusionImg2ImgPipeline
- ORT_MODEL_CLASS = ORTStableDiffusionImg2ImgPipeline
TASK = "image-to-image"
@parameterized.expand(SUPPORTED_ARCHITECTURES)
@@ -298,11 +291,13 @@ def test_height_width_properties(self, model_arch: str):
class OVStableDiffusionInpaintPipelineTest(OVStableDiffusionPipelineBaseTest):
SUPPORTED_ARCHITECTURES = ("stable-diffusion",)
MODEL_CLASS = OVStableDiffusionInpaintPipeline
- ORT_MODEL_CLASS = ORTStableDiffusionInpaintPipeline
TASK = "inpaint"
@parameterized.expand(SUPPORTED_ARCHITECTURES)
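+    # onnxruntime is an optional dependency, so the comparison against the ORT pipeline is skipped when it is missing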
+ @unittest.skipIf(not is_onnxruntime_available(), "this test requires onnxruntime")
def test_compare_diffusers_pipeline(self, model_arch: str):
+ from optimum.onnxruntime import ORTStableDiffusionInpaintPipeline
+
model_id = MODEL_NAMES[model_arch]
pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True, ov_config=F32_CONFIG)
batch_size, num_images, height, width = 1, 1, 64, 64
@@ -329,7 +324,7 @@ def test_compare_diffusers_pipeline(self, model_arch: str):
outputs = pipeline(**inputs, latents=latents).images
self.assertEqual(outputs.shape, (batch_size * num_images, height, width, 3))
- ort_pipeline = self.ORT_MODEL_CLASS.from_pretrained(model_id, export=True)
+ ort_pipeline = ORTStableDiffusionInpaintPipeline.from_pretrained(model_id, export=True)
ort_outputs = ort_pipeline(**inputs, latents=latents).images
self.assertTrue(np.allclose(outputs, ort_outputs, atol=1e-1))
@@ -358,7 +353,6 @@ def generate_inputs(self, height=128, width=128, batch_size=1):
class OVtableDiffusionXLPipelineTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = ("stable-diffusion-xl",)
MODEL_CLASS = OVStableDiffusionXLPipeline
- ORT_MODEL_CLASS = ORTStableDiffusionXLPipeline
PT_MODEL_CLASS = StableDiffusionXLPipeline
TASK = "text-to-image"
@@ -444,7 +438,6 @@ def test_num_images_per_prompt_static_model(self, model_arch: str):
class OVStableDiffusionXLImg2ImgPipelineTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = ("stable-diffusion-xl", "stable-diffusion-xl-refiner")
MODEL_CLASS = OVStableDiffusionXLImg2ImgPipeline
- ORT_MODEL_CLASS = ORTStableDiffusionXLImg2ImgPipeline
PT_MODEL_CLASS = StableDiffusionXLImg2ImgPipeline
TASK = "image-to-image"
@@ -489,7 +482,7 @@ class OVLatentConsistencyModelPipelineTest(unittest.TestCase):
TASK = "text-to-image"
@parameterized.expand(SUPPORTED_ARCHITECTURES)
- @unittest.skipIf(parse(_diffusers_version) <= Version("0.21.4"), "not supported with this diffusers version")
+ @unittest.skipIf(is_diffusers_version("<=", "0.21.4"), "not supported with this diffusers version")
def test_compare_to_diffusers(self, model_arch: str):
ov_pipeline = self.MODEL_CLASS.from_pretrained(MODEL_NAMES[model_arch], export=True, ov_config=F32_CONFIG)
self.assertIsInstance(ov_pipeline.text_encoder, OVModelTextEncoder)
@@ -532,7 +525,7 @@ def test_compare_to_diffusers(self, model_arch: str):
self.assertEqual(pipeline.device.type, ov_pipeline.device)
@parameterized.expand(SUPPORTED_ARCHITECTURES)
- @unittest.skipIf(parse(_diffusers_version) <= Version("0.21.4"), "not supported with this diffusers version")
+ @unittest.skipIf(is_diffusers_version("<=", "0.21.4"), "not supported with this diffusers version")
def test_num_images_per_prompt_static_model(self, model_arch: str):
model_id = MODEL_NAMES[model_arch]
pipeline = self.MODEL_CLASS.from_pretrained(model_id, export=True, compile=False, dynamic_shapes=False)
diff --git a/tests/openvino/test_training.py b/tests/openvino/test_training.py
index 937c0bf3f5..80298faf2b 100644
--- a/tests/openvino/test_training.py
+++ b/tests/openvino/test_training.py
@@ -365,7 +365,7 @@ def tearDown(self):
"default_quantization,structured_movement_sparsity": OVTrainerTestDescriptor(
model_id="hf-internal-testing/tiny-random-bert",
nncf_compression_config=[DEFAULT_QUANTIZATION_CONFIG, STRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT],
- expected_fake_quantize=44,
+ expected_fake_quantize=34,
expected_int8=32,
expected_binary_masks=60,
compression_metrics=["compression_loss"],
@@ -376,7 +376,7 @@ def tearDown(self):
CUSTOMIZED_QUANTIZATION_CONFIG,
STRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT,
],
- expected_fake_quantize=44,
+ expected_fake_quantize=34,
expected_int8=32,
expected_binary_masks=60,
compression_metrics=["compression_loss"],
@@ -385,7 +385,7 @@ def tearDown(self):
model_id="hf-internal-testing/tiny-random-bert",
teacher_model_id="hf-internal-testing/tiny-random-bert",
nncf_compression_config=[DEFAULT_QUANTIZATION_CONFIG, STRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT],
- expected_fake_quantize=44,
+ expected_fake_quantize=34,
expected_int8=32,
expected_binary_masks=60,
compression_metrics=["compression_loss", "distillation_loss", "task_loss"],
@@ -397,7 +397,7 @@ def tearDown(self):
CUSTOMIZED_QUANTIZATION_CONFIG,
STRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_BERT,
],
- expected_fake_quantize=44,
+ expected_fake_quantize=34,
expected_int8=32,
expected_binary_masks=60,
compression_metrics=["compression_loss", "distillation_loss", "task_loss"],
@@ -749,7 +749,7 @@ def check_ovmodel_reshaping(self, ovmodel: OVModel):
"quantization,structured_movement_sparsity": OVTrainerTestDescriptor(
model_id="hf-internal-testing/tiny-random-Wav2Vec2Model",
nncf_compression_config=[QUANTIZATION_CONFIG_FOR_WAV2VEC2, STRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_WAV2VEC2],
- expected_fake_quantize=48,
+ expected_fake_quantize=40,
expected_int8=30,
expected_binary_masks=48,
compression_metrics=["compression_loss"],
@@ -766,7 +766,7 @@ def check_ovmodel_reshaping(self, ovmodel: OVModel):
model_id="hf-internal-testing/tiny-random-Wav2Vec2Model",
teacher_model_id="hf-internal-testing/tiny-random-Wav2Vec2Model",
nncf_compression_config=[QUANTIZATION_CONFIG_FOR_WAV2VEC2, STRUCTURED_MOVEMENT_SPARSITY_CONFIG_FOR_WAV2VEC2],
- expected_fake_quantize=48,
+ expected_fake_quantize=40,
expected_int8=30,
expected_binary_masks=48,
compression_metrics=["compression_loss", "distillation_loss", "task_loss"],
diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py
index 8fabb34e38..97c8a92836 100644
--- a/tests/openvino/utils_tests.py
+++ b/tests/openvino/utils_tests.py
@@ -102,12 +102,12 @@
SEED = 42
_ARCHITECTURES_TO_EXPECTED_INT8 = {
- "bert": (70,),
+ "bert": (68,),
"roberta": (68,),
"albert": (84,),
"vit": (64,),
"blenderbot": (70,),
- "gpt2": (46,),
+ "gpt2": (44,),
"wav2vec2": (34,),
"distilbert": (66,),
"t5": (64, 104, 84),
@@ -116,7 +116,7 @@
"stable-diffusion-xl-refiner": (366, 34, 42, 66),
}
-_ARCHITECTURES_TO_EXPECTED_INT4_INT8 = {"opt125m": (64, 477)}
+_ARCHITECTURES_TO_EXPECTED_INT4_INT8 = {"opt125m": (62, 86)}
def get_num_quantized_nodes(ov_model):
@@ -127,8 +127,8 @@ def get_num_quantized_nodes(ov_model):
if "FakeQuantize" in elem.name:
num_fake_quantize += 1
for i in range(elem.get_output_size()):
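+            # check element type names exactly (i8/u8, i4/u4) rather than via substring matching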
- if "8" in elem.get_output_element_type(i).get_type_name():
+ if elem.get_output_element_type(i).get_type_name() in ["i8", "u8"]:
num_int8 += 1
- if "4" in elem.get_output_element_type(i).get_type_name():
+ if elem.get_output_element_type(i).get_type_name() in ["i4", "u4"]:
num_int4 += 1
return num_fake_quantize, num_int8, num_int4