From 8880d2e77c56cb158c96a264cc7f784b1e72189a Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Thu, 14 Mar 2024 17:15:14 +0400 Subject: [PATCH] Add openvino export configs (#568) * add openvino export configs * more libs * more libs * mixtral and model patcher * chatglm export * rework chatglm config * more testing models * rework config registration * add chatglm in tests * Update tests/openvino/test_modeling.py * fix style * gemma * add test models * qwen * fix failed tests * add comment for gemma --- optimum/exporters/openvino/__init__.py | 2 + optimum/exporters/openvino/__main__.py | 18 +- optimum/exporters/openvino/convert.py | 49 +-- optimum/exporters/openvino/model_configs.py | 391 +++++++++++++++++ optimum/exporters/openvino/model_patcher.py | 441 +++++++++++++++++++- optimum/intel/openvino/modeling_decoder.py | 4 +- optimum/intel/openvino/quantization.py | 2 +- setup.py | 4 +- tests/openvino/test_modeling.py | 72 +++- tests/openvino/utils_tests.py | 8 + 10 files changed, 933 insertions(+), 58 deletions(-) create mode 100644 optimum/exporters/openvino/model_configs.py diff --git a/optimum/exporters/openvino/__init__.py b/optimum/exporters/openvino/__init__.py index 9664f6ae6d..94ea4f103b 100644 --- a/optimum/exporters/openvino/__init__.py +++ b/optimum/exporters/openvino/__init__.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import optimum.exporters.openvino.model_configs + from .__main__ import main_export from .convert import export, export_from_model, export_models, export_pytorch_via_onnx from .stateful import ensure_stateful_is_available, patch_stateful diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index 1c695e2f19..02268a3604 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -58,7 +58,7 @@ def main_export( local_files_only: bool = False, use_auth_token: Optional[Union[bool, str]] = None, model_kwargs: Optional[Dict[str, Any]] = None, - custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None, + custom_export_configs: Optional[Dict[str, "OnnxConfig"]] = None, fn_get_submodels: Optional[Callable] = None, compression_option: Optional[str] = None, compression_ratio: Optional[float] = None, @@ -112,11 +112,11 @@ def main_export( when running `transformers-cli login` (stored in `~/.huggingface`). model_kwargs (`Optional[Dict[str, Any]]`, defaults to `None`): Experimental usage: keyword arguments to pass to the model during - the export. This argument should be used along the `custom_onnx_configs` argument + the export. This argument should be used along the `custom_export_configs` argument in case, for example, the model inputs/outputs are changed (for example, if `model_kwargs={"output_attentions": True}` is passed). - custom_onnx_configs (`Optional[Dict[str, OnnxConfig]]`, defaults to `None`): - Experimental usage: override the default ONNX config used for the given model. This argument may be useful for advanced users that desire a finer-grained control on the export. An example is available [here](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model). + custom_export_configs (`Optional[Dict[str, OnnxConfig]]`, defaults to `None`): + Experimental usage: override the default export config used for the given model. This argument may be useful for advanced users that desire a finer-grained control on the export. 
An example is available [here](https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model). fn_get_submodels (`Optional[Callable]`, defaults to `None`): Experimental usage: Override the default submodels that are used at the export. This is especially useful when exporting a custom architecture that needs to split the ONNX (e.g. encoder-decoder). If unspecified with custom models, optimum will try to use the default submodels used for the given task, with no guarantee of success. @@ -134,7 +134,7 @@ def main_export( ```python >>> from optimum.exporters.openvino import main_export - >>> main_export("gpt2", output="gpt2_onnx/") + >>> main_export("gpt2", output="gpt2_ov/") ``` """ @@ -206,14 +206,14 @@ def main_export( if model_type not in TasksManager._SUPPORTED_MODEL_TYPE: custom_architecture = True elif task not in TasksManager.get_supported_tasks_for_model_type( - model_type, exporter="onnx", library_name=library_name + model_type, exporter="openvino", library_name=library_name ): if original_task == "auto": autodetected_message = " (auto-detected)" else: autodetected_message = "" model_tasks = TasksManager.get_supported_tasks_for_model_type( - model_type, exporter="onnx", library_name=library_name + model_type, exporter="openvino", library_name=library_name ) raise ValueError( f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum OpenVINO exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}." @@ -288,7 +288,7 @@ class StoreAttr(object): not custom_architecture and library_name != "diffusers" and task + "-with-past" - in TasksManager.get_supported_tasks_for_model_type(model_type, exporter="onnx", library_name=library_name) + in TasksManager.get_supported_tasks_for_model_type(model_type, exporter="openvino", library_name=library_name) ): # Make -with-past the default if --task was not explicitely specified if original_task == "auto": @@ -319,7 +319,7 @@ class StoreAttr(object): ov_config=ov_config, stateful=stateful, model_kwargs=model_kwargs, - custom_onnx_configs=custom_onnx_configs, + custom_export_configs=custom_export_configs, fn_get_submodels=fn_get_submodels, preprocessors=preprocessors, device=device, diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index 5353912d48..dfca80f001 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -32,10 +32,11 @@ from optimum.exporters.onnx.convert import check_dummy_inputs_are_allowed from optimum.exporters.onnx.convert import export_pytorch as export_pytorch_to_onnx from optimum.exporters.onnx.convert import export_tensorflow as export_tensorflow_onnx +from optimum.exporters.utils import _get_submodels_and_export_configs from optimum.utils import DEFAULT_DUMMY_SHAPES, is_diffusers_available from optimum.utils.save_utils import maybe_save_preprocessors -from ...intel.utils.import_utils import is_nncf_available, is_optimum_version +from ...intel.utils.import_utils import is_nncf_available from .model_patcher import patch_model_with_bettertransformer from .stateful import ensure_export_task_support_stateful, ensure_stateful_is_available, patch_stateful from .utils import ( @@ -48,13 +49,6 @@ ) -if is_optimum_version(">=", "1.16.99"): - from optimum.exporters.onnx.utils import 
_get_submodels_and_onnx_configs - -else: - from optimum.exporters.onnx.__main__ import _get_submodels_and_onnx_configs - - UNSUPPORTED_TOKENIZER_CLASSES = (T5Tokenizer, T5TokenizerFast) @@ -418,7 +412,7 @@ def ts_patched_forward(*args, **kwargs): def export_models( - models_and_onnx_configs: Dict[ + models_and_export_configs: Dict[ str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"] ], output_dir: Path, @@ -434,7 +428,7 @@ def export_models( Export the models to OpenVINO IR format Args: - models_and_onnx_configs (Dict[ str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]): + models_and_export_configs (Dict[ str, Tuple[Union["PreTrainedModel", "TFPreTrainedModel", "ModelMixin"], "OnnxConfig"]): output_dir (Path): output directory for saving models opset (Optional[int], optional, Default to None): ONNX export opset output_names (Optional[List[str]], optional, Defaults to None): model output names @@ -459,20 +453,20 @@ def export_models( outputs = [] - if output_names is not None and len(output_names) != len(models_and_onnx_configs): + if output_names is not None and len(output_names) != len(models_and_export_configs): raise ValueError( - f"Provided custom names {output_names} for the export of {len(models_and_onnx_configs)} models. Please provide the same number of names as models to export." + f"Provided custom names {output_names} for the export of {len(models_and_export_configs)} models. Please provide the same number of names as models to export." ) - for i, model_name in enumerate(models_and_onnx_configs.keys()): - submodel, sub_onnx_config = models_and_onnx_configs[model_name] + for i, model_name in enumerate(models_and_export_configs.keys()): + submodel, sub_export_config = models_and_export_configs[model_name] output_name = output_names[i] if output_names is not None else Path(model_name + ".xml") output_path = output_dir / output_name output_path.parent.mkdir(parents=True, exist_ok=True) outputs.append( export( model=submodel, - config=sub_onnx_config, + config=sub_export_config, output=output_path, opset=opset, device=device, @@ -495,7 +489,7 @@ def export_from_model( stateful: bool = True, opset: Optional[int] = None, model_kwargs: Optional[Dict[str, Any]] = None, - custom_onnx_configs: Optional[Dict[str, "OnnxConfig"]] = None, + custom_export_configs: Optional[Dict[str, "OnnxConfig"]] = None, fn_get_submodels: Optional[Callable] = None, preprocessors: List = None, device: str = "cpu", @@ -524,14 +518,14 @@ def export_from_model( task = TasksManager._infer_task_from_model_or_model_class(model=model) except (ValueError, KeyError) as e: raise RuntimeError( - f"The model task could not be automatically inferred in `onnx_export_from_model`. Please provide the argument `task` with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}" + f"The model task could not be automatically inferred in `export_from_model`. Please provide the argument `task` with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}" ) if ( not custom_architecture and library_name != "diffusers" and task + "-with-past" - in TasksManager.get_supported_tasks_for_model_type(model_type, "onnx", library_name=library_name) + in TasksManager.get_supported_tasks_for_model_type(model_type, "openvino", library_name=library_name) ): # -with-past is the default. 
task = task + "-with-past" @@ -541,9 +535,9 @@ def export_from_model( stateful = stateful and ensure_export_task_support_stateful(task) # TODO: support onnx_config.py in the model repo - if custom_architecture and custom_onnx_configs is None: + if custom_architecture and custom_export_configs is None: raise ValueError( - f"Trying to export a {model_type} model, that is a custom or unsupported architecture, but no custom onnx configuration was passed as `custom_onnx_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type {model_type} to be supported natively in the ONNX export." + f"Trying to export a {model_type} model, that is a custom or unsupported architecture, but no custom export configuration was passed as `custom_export_configs`. Please refer to https://huggingface.co/docs/optimum/main/en/exporters/onnx/usage_guides/export_a_model#custom-export-of-transformers-models for an example on how to export custom models. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the model type {model_type} to be supported natively in the ONNX export." ) if task.startswith("text-generation") and model.config.is_encoder_decoder: @@ -569,11 +563,11 @@ def export_from_model( kwargs_shapes[input_name] if input_name in kwargs_shapes else DEFAULT_DUMMY_SHAPES[input_name] ) - onnx_config, models_and_onnx_configs = _get_submodels_and_onnx_configs( + export_config, models_and_export_configs = _get_submodels_and_export_configs( model=model, task=task, monolith=False, - custom_onnx_configs=custom_onnx_configs if custom_onnx_configs is not None else {}, + custom_export_configs=custom_export_configs if custom_export_configs is not None else {}, custom_architecture=custom_architecture, fn_get_submodels=fn_get_submodels, preprocessors=preprocessors, @@ -581,6 +575,7 @@ def export_from_model( model_kwargs=model_kwargs, _variant="default", legacy=False, + exporter="openvino", ) if ov_config is None: @@ -612,18 +607,18 @@ def export_from_model( model_name_or_path = model.config._name_or_path maybe_save_preprocessors(model_name_or_path, output) - files_subpaths = ["openvino_" + model_name + ".xml" for model_name in models_and_onnx_configs.keys()] + files_subpaths = ["openvino_" + model_name + ".xml" for model_name in models_and_export_configs.keys()] else: # save the subcomponent configuration - for model_name in models_and_onnx_configs: - subcomponent = models_and_onnx_configs[model_name][0] + for model_name in models_and_export_configs: + subcomponent = models_and_export_configs[model_name][0] if hasattr(subcomponent, "save_config"): subcomponent.save_config(output / model_name) elif hasattr(subcomponent, "config") and hasattr(subcomponent.config, "save_pretrained"): subcomponent.config.save_pretrained(output / model_name) - files_subpaths = [os.path.join(name_dir, OV_XML_FILE_NAME) for name_dir in models_and_onnx_configs] + files_subpaths = [os.path.join(name_dir, OV_XML_FILE_NAME) for name_dir in models_and_export_configs] # Saving the additional components needed to perform inference. 
model.scheduler.save_pretrained(output.joinpath("scheduler")) @@ -643,7 +638,7 @@ def export_from_model( model.save_config(output) export_models( - models_and_onnx_configs=models_and_onnx_configs, + models_and_export_configs=models_and_export_configs, output_dir=output, output_names=files_subpaths, input_shapes=input_shapes, diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py new file mode 100644 index 0000000000..b6536512b1 --- /dev/null +++ b/optimum/exporters/openvino/model_configs.py @@ -0,0 +1,391 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union + +from packaging import version +from transformers.utils import is_tf_available + +from optimum.exporters.onnx.config import TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig +from optimum.exporters.onnx.model_configs import GemmaOnnxConfig +from optimum.exporters.tasks import TasksManager +from optimum.utils import DEFAULT_DUMMY_SHAPES +from optimum.utils.input_generators import ( + DummyInputGenerator, + DummyPastKeyValuesGenerator, + DummyTextInputGenerator, + MistralDummyPastKeyValuesGenerator, +) +from optimum.utils.normalized_config import NormalizedTextConfig + +from .model_patcher import ( + BaichuanModelPatcher, + ChatGLMModelPatcher, + GemmaModelPatcher, + MixtralModelPatcher, + QwenModelPatcher, +) + + +def init_model_configs(): + supported_model_types = [ + "_SUPPORTED_MODEL_TYPE", + "_DIFFUSERS_SUPPORTED_MODEL_TYPE", + "_TIMM_SUPPORTED_MODEL_TYPE", + "_SENTENCE_TRANSFORMERS_SUPPORTED_MODEL_TYPE", + ] + + for supported_models_config in supported_model_types: + supported_models = getattr(TasksManager, supported_models_config) + for model, export_configs in supported_models.items(): + if "onnx" not in export_configs: + continue + onnx_config = export_configs["onnx"] + supported_models[model]["openvino"] = deepcopy(onnx_config) + + setattr(TasksManager, supported_models_config, supported_models) + + +init_model_configs() + + +if TYPE_CHECKING: + from transformers.modeling_utils import PreTrainedModel + + from optimum.exporters.onnx.model_patcher import ModelPatcher + + if is_tf_available(): + from transformers.modeling_tf_utils import TFPreTrainedModel + + +register_in_tasks_manager = TasksManager.create_register("openvino", overwrite_existing=True) + + +@register_in_tasks_manager("baichuan", *["text-generation", "text-generation-with-past"], library_name="transformers") +class BaichaunOpenVINOConfig(TextDecoderOnnxConfig): + DEFAULT_ONNX_OPSET = 13 + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( + num_layers="num_hidden_layers", num_attention_heads="num_attention_heads", hidden_size="hidden_size" + ) + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return BaichuanModelPatcher(self, model, 
model_kwargs=model_kwargs) + + +@register_in_tasks_manager("qwen2", *["text-generation", "text-generation-with-past"], library_name="transformers") +class Qwen2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + +@register_in_tasks_manager("minicpm", *["text-generation", "text-generation-with-past"], library_name="transformers") +class MiniCPMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + +@register_in_tasks_manager("stablelm", *["text-generation", "text-generation-with-past"], library_name="transformers") +class StableLMOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + + +class ChatGLM2DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + random_batch_size_range: Optional[Tuple[int, int]] = None, + random_sequence_length_range: Optional[Tuple[int, int]] = None, + **kwargs, + ): + super().__init__( + task=task, + normalized_config=normalized_config, + batch_size=batch_size, + sequence_length=sequence_length, + random_batch_size_range=random_batch_size_range, + random_sequence_length_range=random_sequence_length_range, + ) + self.multi_query_group_num = normalized_config.multi_query_group_num + self.head_dim = normalized_config.kv_channels + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + past_key_shape = ( + self.sequence_length, + self.batch_size, + self.multi_query_group_num, + self.head_dim, + ) + past_value_shape = ( + self.sequence_length, + self.batch_size, + self.multi_query_group_num, + self.head_dim, + ) + return [ + ( + self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype), + self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype), + ) + for _ in range(self.num_layers) + ] + + +@register_in_tasks_manager("chatglm", *["text-generation", "text-generation-with-past"], library_name="transformers") +class ChatGLM2OpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(vocab_size="padded_vocab_size", num_layers="num_layers") + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, ChatGLM2DummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = ChatGLM2DummyPastKeyValuesGenerator + + def generate_dummy_inputs(self, framework: str = "pt", **kwargs): + dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs) + + dummy_inputs = {} + input_names = [key for key in self.inputs.keys() if not key.startswith("past_key_values")] + if self.use_past_in_inputs and self.use_cache_branch is not False: + input_names.append("past_key_values") + + for input_name in 
input_names: + input_was_inserted = False + for dummy_input_gen in dummy_inputs_generators: + if dummy_input_gen.supports_input(input_name): + dummy_inputs[input_name] = self.overwrite_shape_and_generate_input( + dummy_input_gen, + input_name, + framework, + input_shapes=kwargs, + ) + input_was_inserted = True + break + if not input_was_inserted: + raise RuntimeError( + f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.' + ) + + # refer to https://github.com/huggingface/optimum/pull/764 + if ( + self.use_past_in_inputs + and self.PAD_ATTENTION_MASK_TO_PAST + and self.use_cache_branch is not False + and "attention_mask" in dummy_inputs + ): + # Obtain the past sequence length from the value instead of the key (Bloom). ChatGLM has seq_len in 0 dim instead of -2 + past_present_length = dummy_inputs["input_ids"].shape[1] + dummy_inputs["past_key_values"][0][1].shape[0] + + dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim( + dummy_inputs["attention_mask"], + desired_length=past_present_length, + dim=1, + dtype=dummy_inputs["attention_mask"].dtype, + ) + + return dummy_inputs + + def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): + """ + Fills `input_or_outputs` mapping with past_key_values dynamic axes considering the direction. + + Args: + inputs_or_outputs (`Dict[str, Dict[int, str]]`): The mapping to fill. + direction (`str`): + either "inputs" or "outputs", it specifies whether `input_or_outputs` is the input mapping or the + output mapping, this is important for axes naming. + """ + if direction not in ["inputs", "outputs"]: + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') + + if direction == "inputs": + decoder_sequence_name = "past_sequence_length" + name = "past_key_values" + else: + decoder_sequence_name = "past_sequence_length + present_lenght" + name = "present" + + for i in range(self._normalized_config.num_layers): + inputs_or_outputs[f"{name}.{i}.key"] = {1: "batch_size", 0: decoder_sequence_name} + inputs_or_outputs[f"{name}.{i}.value"] = {1: "batch_size", 0: decoder_sequence_name} + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return ChatGLMModelPatcher(self, model, model_kwargs=model_kwargs) + + +@register_in_tasks_manager("mixtral", *["text-generation", "text-generation-with-past"], library_name="transformers") +class MixtralOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + # This is because of the patching of torch.triu in AttentionMaskConverter, that exists from transformers>=4.35 + MIN_TRANSFORMERS_VERSION = version.parse("4.34.99") + + # The ONNX export of this architecture needs the Trilu operator support, available since opset 14 + DEFAULT_ONNX_OPSET = 14 + DUMMY_INPUT_GENERATOR_CLASSES = ( + MistralDummyPastKeyValuesGenerator, + ) + TextDecoderOnnxConfig.DUMMY_INPUT_GENERATOR_CLASSES + DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(num_key_value_heads="num_key_value_heads", allow_new=True) + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return MixtralModelPatcher(self, model, model_kwargs=model_kwargs) + + +@register_in_tasks_manager( + "gemma", + *[ + "feature-extraction", + 
"feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text-classification", + ], + library_name="transformers", +) +class GemmaOpenVINOConfig(GemmaOnnxConfig): + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return GemmaModelPatcher(self, model, model_kwargs=model_kwargs) + + +class QwenDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + random_batch_size_range: Optional[Tuple[int, int]] = None, + random_sequence_length_range: Optional[Tuple[int, int]] = None, + **kwargs, + ): + super().__init__( + task=task, + normalized_config=normalized_config, + batch_size=batch_size, + sequence_length=sequence_length, + random_batch_size_range=random_batch_size_range, + random_sequence_length_range=random_sequence_length_range, + ) + self.kv_channels = normalized_config.kv_channels + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + past_key_shape = (self.batch_size, self.sequence_length, self.num_attention_heads, self.kv_channels) + past_value_shape = (self.batch_size, self.sequence_length, self.num_attention_heads, self.kv_channels) + return [ + ( + self.random_float_tensor(past_key_shape, framework=framework, dtype=float_dtype), + self.random_float_tensor(past_value_shape, framework=framework, dtype=float_dtype), + ) + for _ in range(self.num_layers) + ] + + +@register_in_tasks_manager("qwen", *["text-generation", "text-generation-with-past"]) +class QwenOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig): + DEFAULT_ONNX_OPSET = 14 + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( + num_layers="num_hidden_layers", num_attention_heads="num_attention_heads", hidden_size="hidden_size" + ) + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, QwenDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = QwenDummyPastKeyValuesGenerator + no_position_ids = False + + def generate_dummy_inputs(self, framework: str = "pt", **kwargs): + dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs) + + dummy_inputs = {} + input_names = [key for key in self.inputs.keys() if not key.startswith("past_key_values")] + if self.use_past_in_inputs and self.use_cache_branch is not False: + input_names.append("past_key_values") + + for input_name in input_names: + input_was_inserted = False + for dummy_input_gen in dummy_inputs_generators: + if dummy_input_gen.supports_input(input_name): + dummy_inputs[input_name] = self.overwrite_shape_and_generate_input( + dummy_input_gen, + input_name, + framework, + input_shapes=kwargs, + ) + input_was_inserted = True + break + if not input_was_inserted: + raise RuntimeError( + f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.' + ) + + # refer to https://github.com/huggingface/optimum/pull/764 + if ( + self.use_past_in_inputs + and self.PAD_ATTENTION_MASK_TO_PAST + and self.use_cache_branch is not False + and "attention_mask" in dummy_inputs + ): + # Obtain the past sequence length from the value instead of the key (Bloom). 
Qwen has seq_len in 1 dim instead of -2 + past_present_length = dummy_inputs["input_ids"].shape[1] + dummy_inputs["past_key_values"][0][1].shape[1] + + dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim( + dummy_inputs["attention_mask"], + desired_length=past_present_length, + dim=1, + dtype=dummy_inputs["attention_mask"].dtype, + ) + + return dummy_inputs + + def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): + """ + Fills `input_or_outputs` mapping with past_key_values dynamic axes considering the direction. + + Args: + inputs_or_outputs (`Dict[str, Dict[int, str]]`): The mapping to fill. + direction (`str`): + either "inputs" or "outputs", it specifies whether `input_or_outputs` is the input mapping or the + output mapping, this is important for axes naming. + """ + if direction not in ["inputs", "outputs"]: + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') + + if direction == "inputs": + decoder_sequence_name = "past_sequence_length" + name = "past_key_values" + else: + decoder_sequence_name = "past_sequence_length + 1" + name = "present" + + for i in range(self._normalized_config.num_layers): + inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch_size", 1: decoder_sequence_name} + inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 1: decoder_sequence_name} + + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> "ModelPatcher": + return QwenModelPatcher(self, model, model_kwargs=model_kwargs) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 91dc48df05..371fee732a 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -1,4 +1,4 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,15 @@ # limitations under the License. 
import logging as log +import types +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +import torch +import torch.nn.functional as F +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.utils import is_tf_available + +from optimum.exporters.onnx.model_patcher import DecoderModelPatcher from optimum.intel.utils.import_utils import ( _openvino_version, _torch_version, @@ -24,6 +32,15 @@ ) +if TYPE_CHECKING: + from transformers.modeling_utils import PreTrainedModel + + from optimum.exporters.onnx.config import OnnxConfig + + if is_tf_available(): + from transformers.modeling_tf_utils import TFPreTrainedModel + + def patch_model_with_bettertransformer(model): COLOR_RED = "\033[1;31m" COLOR_RESET = "\033[0m" @@ -71,3 +88,425 @@ def patch_model_with_bettertransformer(model): return model return model + + +def _mixtral_sparse_moe_block_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. 
We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class MixtralModelPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + for layer in self._model.model.layers: + layer.block_sparse_moe._unpatched_forward = layer.block_sparse_moe.forward + layer.block_sparse_moe.forward = types.MethodType( + _mixtral_sparse_moe_block_forward, layer.block_sparse_moe + ) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for layer in self._model.model.layers: + layer.block_sparse_moe.forward = layer.block_sparse_moe._unpatched_forward + + +def _chatglm_transformer_forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, +): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt( + batch_size=batch_size, + device=input_ids.device, + dtype=inputs_embeds.dtype, + ) + if attention_mask is not None: + attention_mask = torch.cat( + [ + attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask, + ], + dim=-1, + ) + + if full_attention_mask is None: + if past_key_values is not None: + full_attention_mask = torch.ones( + batch_size, + seq_length, + seq_length, + device=input_ids.device, + dtype=torch.float, + ) * float("-inf") + full_attention_mask.triu_(diagonal=1) + past_length = 0 + if past_key_values: + past_length = past_key_values[0][0].shape[0] + if past_length: + full_attention_mask = torch.cat( + ( + torch.zeros(batch_size, seq_length, past_length, device=input_ids.device), + full_attention_mask, + ), + dim=-1, + ) + full_attention_mask.unsqueeze_(1) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. 
+ hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, + full_attention_mask, + rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + ) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +def _chatglm2_get_context_layer(query_layer: torch.Tensor, key_layer: torch.Tensor, value_layer: torch.Tensor): + mask = torch.zeros((query_layer.shape[-2], key_layer.shape[-2]), dtype=query_layer.dtype) + if query_layer.shape[2] == key_layer.shape[2]: + tmp_mask = torch.ones((query_layer.shape[-2], key_layer.shape[-2]), dtype=torch.bool).triu(diagonal=1) + mask.masked_fill_(tmp_mask, float("-inf")) + + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, attn_mask=mask + ) + return context_layer + + +def _chatglm2_core_attention_forward(self, query_layer, key_layer, value_layer, attention_mask): + query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] + if attention_mask is None: + context_layer = _chatglm2_get_context_layer(query_layer, key_layer, value_layer) + else: + context_layer = torch.nn.functional.scaled_dot_product_attention( + query_layer, key_layer, value_layer, attention_mask + ) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + + return context_layer + + +class ChatGLMModelPatcher(DecoderModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Dict[str, Any], + ): + super().__init__(config, model, model_kwargs) + + self.original_chatglm_transformer_forward = model.transformer.forward + + def __enter__(self): + super().__enter__() + self._model.transformer.forward = types.MethodType(_chatglm_transformer_forward, self._model.transformer) + for block in self._model.transformer.encoder.layers: + block.self_attention.core_attention._orig_forward = block.self_attention.core_attention.forward + block.self_attention.core_attention.forward = types.MethodType( + _chatglm2_core_attention_forward, block.self_attention.core_attention + ) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + self._model.transformer.forward = self.original_chatglm_transformer_forward + for block in self._model.transformer.encoder.layers: + block.self_attention.core_attention.forward = block.self_attention.core_attention._orig_forward + + +class GemmaModelPatcher(DecoderModelPatcher): + def __enter__(self): + super().__enter__() + + # init inv_freq for torchscript tracing + # https://github.com/huggingface/transformers/blob/ed74d97871468f3a4695ede50abdc0b55717a84d/src/transformers/models/gemma/modeling_gemma.py#L108 + for layer in self._model.model.layers: + if layer.self_attn.rotary_emb.inv_freq is None: + rotary_emb = layer.self_attn.rotary_emb + layer.self_attn.rotary_emb.inv_freq = 1.0 / ( + rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim) + ) + + +SUPPORT_SDPA = is_torch_version(">", "2.1.0") + + +def 
_qwen_rotate_half(x): + from einops import rearrange + + x = rearrange(x, "... (j d) -> ... j d", j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def _qwen_apply_rotary_pos_emb(t, freqs): + cos, sin = freqs + rot_dim = freqs[0].shape[-1] + cos, sin = freqs + t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:] + t_ = t_.float() + t_pass_ = t_pass_.float() + t_ = (t_ * cos) + (_qwen_rotate_half(t_) * sin) + return torch.cat((t_, t_pass_), dim=-1).type_as(t) + + +def _qwen_quantize_cache_v(fdata, bits, qmax, qmin): + # b, s, head, h-dim->b, head, s, h-dim + qtype = torch.uint8 + device = fdata.device + shape = fdata.shape + + fdata_cal = torch.flatten(fdata, 2) + fmax = torch.amax(fdata_cal, dim=-1, keepdim=True) + fmin = torch.amin(fdata_cal, dim=-1, keepdim=True) + # Compute params + if qmax.device != fmax.device: + qmax = qmax.to(device) + qmin = qmin.to(device) + scale = (fmax - fmin) / (qmax - qmin) + zero = qmin - fmin / scale + scale = scale.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous() + zero = zero.unsqueeze(-1).repeat(1, 1, shape[2], 1).contiguous() + # Quantize + res_data = fdata / scale + zero + qdata = torch.clamp(res_data, qmin, qmax).to(qtype) + return qdata.contiguous(), scale, zero + + +def _qwen_attention_forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + rotary_pos_emb_list: Optional[List[List[torch.Tensor]]] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, +): + mixed_x_layer = self.c_attn(hidden_states) + + query, key, value = mixed_x_layer.split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if rotary_pos_emb_list is not None: + cur_len = query.shape[1] + if len(rotary_pos_emb_list) == 1: + rotary_pos_emb = rotary_pos_emb_list[0] + rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] + rotary_pos_emb = (rotary_pos_emb,) * 2 + q_pos_emb, k_pos_emb = rotary_pos_emb + # Slice the pos emb for current inference + query = _qwen_apply_rotary_pos_emb(query, q_pos_emb) + key = _qwen_apply_rotary_pos_emb(key, k_pos_emb) + else: + query_list = [] + key_list = [] + for i, rotary_pos_emb in enumerate(rotary_pos_emb_list): + rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb] + rotary_pos_emb = (rotary_pos_emb,) * 2 + q_pos_emb, k_pos_emb = rotary_pos_emb + # Slice the pos emb for current inference + query_list += [_qwen_apply_rotary_pos_emb(query[i : i + 1, :, :], q_pos_emb)] + key_list += [_qwen_apply_rotary_pos_emb(key[i : i + 1, :, :], k_pos_emb)] + query = torch.cat(query_list, dim=0) + key = torch.cat(key_list, dim=0) + + if self.use_cache_quantization: + key = _qwen_quantize_cache_v(key.permute(0, 2, 1, 3), bits=8, qmin=self.cache_qmin, qmax=self.cache_qmax) + value = _qwen_quantize_cache_v(value.permute(0, 2, 1, 3), bits=8, qmin=self.cache_qmin, qmax=self.cache_qmax) + + if layer_past is not None: + past_key, past_value = layer_past[0], layer_past[1] + if self.use_cache_quantization: + # use_cache_quantization: + # present=((q_key,key_scale,key_zero_point), + # (q_value,value_scale,value_zero_point)) + key = ( + 
torch.cat((past_key[0], key[0]), dim=2), + torch.cat((past_key[1], key[1]), dim=2), + torch.cat((past_key[2], key[2]), dim=2), + ) + value = ( + torch.cat((past_value[0], value[0]), dim=2), + torch.cat((past_value[1], value[1]), dim=2), + torch.cat((past_value[2], value[2]), dim=2), + ) + else: + # not use_cache_quantization: + # present=(key,value) + key = torch.cat((past_key, key), dim=1) + value = torch.cat((past_value, value), dim=1) + + if use_cache: + present = (key, value) + else: + present = None + + if self.use_logn_attn and not self.training: + if self.use_cache_quantization: + seq_start = key[0].size(2) - query.size(1) + seq_end = key[0].size(2) + else: + seq_start = key.size(1) - query.size(1) + seq_end = key.size(1) + logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :].type_as(query) + query = query * logn_tensor.expand_as(query) + + if self.use_flash_attn and not self.is_fp32 and query.is_cuda: + q, k, v = query, key, value + attn_output = self.core_attention_flash(q, k, v, attention_mask=attention_mask) + else: + registered_causal_mask = torch.tril( + torch.ones((key.size(1), key.size(1)), dtype=torch.bool, device=key.device) + ).view(1, 1, key.size(1), key.size(1)) + query = query.permute(0, 2, 1, 3) + if not self.use_cache_quantization: + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + + if not self.use_cache_quantization and SUPPORT_SDPA: + causal_mask = registered_causal_mask[:, :, key.size(-2) - query.size(-2) : key.size(-2), : key.size(-2)] + if attention_mask is not None: + attention_mask = attention_mask.expand(-1, -1, causal_mask.size(2), -1).masked_fill( + ~causal_mask, torch.finfo(query.dtype).min + ) + else: + attention_mask = causal_mask + attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask).transpose(1, 2) + attn_weight = None + else: + attn_output, attn_weight = self._attn(query, key, value, registered_causal_mask, attention_mask, head_mask) + context_layer = self._merge_heads(attn_output, self.num_heads, self.head_dim) + + attn_output = self.c_proj(context_layer) + + outputs = (attn_output, present) + if output_attentions: + if self.use_flash_attn and not self.is_fp32: + raise ValueError("Cannot output attentions while using flash-attn") + else: + outputs += (attn_weight,) + + return outputs + + +class QwenModelPatcher(DecoderModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Dict[str, Any], + ): + super().__init__(config, model, model_kwargs) + + self.original_fp16 = model.config.fp16 + self.original_bf16 = model.config.bf16 + model.config.bf16 = False + model.config.fp16 = False + if self.original_fp16 or self.original_bf16: + model.to(torch.float32) + model.transformer.rotary_emb(2048) + + def __enter__(self): + super().__enter__() + for block in self._model.transformer.h: + block.attn._orig_forward = block.attn.forward + block.attn.forward = types.MethodType(_qwen_attention_forward, block.attn) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + for block in self._model.transformer.h: + block.attn.forward = block.attn._orig_forward + self._model.config.bf16 = self.original_bf16 + self._model.config.fp16 = self.original_fp16 + + +class BaichuanModelPatcher(DecoderModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Dict[str, Any], + ): + super().__init__(config, model, 
model_kwargs) + # model has first inference buffers initialization + if self._model.lm_head.first_flag: + self._model(torch.ones((1, 10), dtype=torch.int64), torch.ones((1, 10), dtype=torch.int64)) diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 53aa05bc5a..832c132615 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -316,7 +316,9 @@ def _reshape( shapes[inputs][0] = -1 input_name = inputs.get_any_name() if input_name.startswith("past_key_values"): - if len(inputs.partial_shape) == 3 and input_name.endswith("value"): + if ( + len(inputs.partial_shape) == 3 and input_name.endswith("value") + ) or self.config.model_type == "chatglm": shapes[inputs][1] = -1 else: shapes[inputs][2] = -1 diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index c46f29092b..2022a495d8 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -350,7 +350,7 @@ def _quantize_torchmodel( model_type = self.model.config.model_type.replace("_", "-") onnx_config_class = TasksManager.get_exporter_config_constructor( - exporter="onnx", + exporter="openvino", model=self.model, task=self.task, model_type=model_type, diff --git a/setup.py b/setup.py index 3a1e1123d0..5c6cf76404 100644 --- a/setup.py +++ b/setup.py @@ -28,8 +28,8 @@ INSTALL_REQUIRE = [ "torch>=1.11", - "optimum~=1.17", "transformers>=4.36.0,<4.39.0", + "optimum @ git+https://github.com/huggingface/optimum.git#egg=optimum", "datasets>=1.4.0", "sentencepiece", "scipy", @@ -50,6 +50,8 @@ "timm", "invisible-watermark>=0.2.0", "auto-gptq", + "transformers_stream_generator", + "einops", ] QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241"] diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 2188b7061f..9df6c73214 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -28,6 +28,7 @@ from parameterized import parameterized from PIL import Image from transformers import ( + AutoConfig, AutoFeatureExtractor, AutoModel, AutoModelForAudioClassification, @@ -52,7 +53,6 @@ from transformers.onnx.utils import get_preprocessor from utils_tests import MODEL_NAMES -from optimum.exporters.onnx import MODEL_TYPES_REQUIRING_POSITION_IDS from optimum.intel import ( OVModelForAudioClassification, OVModelForAudioFrameClassification, @@ -473,73 +473,101 @@ def test_pipeline(self, model_arch): class OVModelForCausalLMIntegrationTest(unittest.TestCase): SUPPORTED_ARCHITECTURES = ( "bart", + "baichuan2", "gpt_bigcode", "blenderbot", "blenderbot-small", "bloom", + "chatglm", "codegen", # "data2vec-text", # TODO : enable when enabled in exporters + "gemma", "gpt2", "gpt_neo", "gpt_neox", "llama", # "llama_gptq", "marian", + "minicpm", "mistral", + "mixtral", "mpt", "opt", "pegasus", + "qwen", + "qwen2", + "stablelm", ) GENERATION_LENGTH = 100 IS_SUPPORT_STATEFUL = is_openvino_version(">=", "2023.3") + REMOTE_CODE_MODELS = ("chatglm", "minicpm", "baichuan2", "jais", "qwen") @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] + not_stateful = ["gpt_bigcode"] + if is_openvino_version("<", "2024.0"): + not_stateful.append("mixtral") + + if is_openvino_version("<", "2024.1"): + not_stateful.extend(["llama", "gemma"]) if "gptq" in model_arch: self.skipTest("GPTQ model loading unsupported with AutoModelForCausalLM") set_seed(SEED) - ov_model = 
OVModelForCausalLM.from_pretrained(model_id, export=True, ov_config=F32_CONFIG) + + model_kwargs = {} + if model_arch in self.REMOTE_CODE_MODELS: + model_kwargs = { + "config": AutoConfig.from_pretrained(model_id, trust_remote_code=True), + "trust_remote_code": True, + } + ov_model = OVModelForCausalLM.from_pretrained(model_id, export=True, ov_config=F32_CONFIG, **model_kwargs) self.assertIsInstance(ov_model.config, PretrainedConfig) self.assertTrue(ov_model.use_cache) - - transformers_model = AutoModelForCausalLM.from_pretrained(model_id) - tokenizer = AutoTokenizer.from_pretrained(model_id) + self.assertEqual( + ov_model.stateful, self.IS_SUPPORT_STATEFUL and ov_model.config.model_type not in not_stateful + ) + transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs) + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS) + if model_arch == "qwen": + transformers_model.to(torch.float32) tokens = tokenizer( "This is a sample", return_tensors="pt", return_token_type_ids=False if model_arch == "llama" else None ) - position_ids = None - if model_arch.replace("_", "-") in MODEL_TYPES_REQUIRING_POSITION_IDS: - input_shape = tokens["input_ids"].shape - position_ids = torch.arange(0, input_shape[-1], dtype=torch.long).unsqueeze(0).view(-1, input_shape[-1]) - ov_outputs = ov_model(**tokens, position_ids=position_ids) + ov_outputs = ov_model(**tokens) self.assertTrue("logits" in ov_outputs) self.assertIsInstance(ov_outputs.logits, torch.Tensor) self.assertTrue("past_key_values" in ov_outputs) self.assertIsInstance(ov_outputs.past_key_values, tuple) - - is_stateful = ov_model.config.model_type not in {"gpt_bigcode", "llama"} and self.IS_SUPPORT_STATEFUL + is_stateful = ov_model.config.model_type not in not_stateful and self.IS_SUPPORT_STATEFUL self.assertEqual(ov_model.stateful, is_stateful) if is_stateful: self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0) - with torch.no_grad(): transformers_outputs = transformers_model(**tokens) # Compare tensor outputs - self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, atol=1e-4)) + self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, equal_nan=True, atol=1e-4)) del transformers_model del ov_model gc.collect() @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_pipeline(self, model_arch): + model_kwargs = {} model_id = MODEL_NAMES[model_arch] - tokenizer = AutoTokenizer.from_pretrained(model_id) - model = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=False, compile=False) + if model_arch in self.REMOTE_CODE_MODELS: + model_kwargs = { + "config": AutoConfig.from_pretrained(model_id, trust_remote_code=True), + "trust_remote_code": True, + } + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS) + model = OVModelForCausalLM.from_pretrained( + model_id, export=True, use_cache=False, compile=False, **model_kwargs + ) model.config.encoder_no_repeat_ngram_size = 0 model.to("cpu") model.half() @@ -556,8 +584,16 @@ def test_pipeline(self, model_arch): def test_multiple_inputs(self, model_arch): model_id = MODEL_NAMES[model_arch] set_seed(SEED) - model = OVModelForCausalLM.from_pretrained(model_id, export=True, compile=False) - tokenizer = AutoTokenizer.from_pretrained(model_id) + if model_arch == "qwen": + self.skipTest("Qwen tokenizer does not support padding") + model_kwargs = {} + if model_arch in 
self.REMOTE_CODE_MODELS: + model_kwargs = { + "config": AutoConfig.from_pretrained(model_id, trust_remote_code=True), + "trust_remote_code": True, + } + model = OVModelForCausalLM.from_pretrained(model_id, export=True, compile=False, **model_kwargs) + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=model_arch in self.REMOTE_CODE_MODELS) tokenizer.pad_token = tokenizer.eos_token texts = ["this is a simple input", "this is a second simple input", "this is a third simple input"] tokens = tokenizer(texts, padding=True, return_tensors="pt") diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index 97c8a92836..ad3cd03d3d 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -22,12 +22,14 @@ "beit": "hf-internal-testing/tiny-random-BeitForImageClassification", "bert": "hf-internal-testing/tiny-random-bert", "bart": "hf-internal-testing/tiny-random-bart", + "baichuan2": "katuni4ka/tiny-random-baichuan2", "bigbird_pegasus": "hf-internal-testing/tiny-random-bigbird_pegasus", "blenderbot-small": "hf-internal-testing/tiny-random-BlenderbotModel", "blenderbot": "hf-internal-testing/tiny-random-BlenderbotModel", "bloom": "hf-internal-testing/tiny-random-BloomModel", "camembert": "hf-internal-testing/tiny-random-camembert", "convbert": "hf-internal-testing/tiny-random-ConvBertForSequenceClassification", + "chatglm": "katuni4ka/tiny-random-chatglm2", "codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM", "data2vec_text": "hf-internal-testing/tiny-random-Data2VecTextModel", "data2vec_vision": "hf-internal-testing/tiny-random-Data2VecVisionModel", @@ -38,6 +40,7 @@ "convnext": "hf-internal-testing/tiny-random-convnext", "distilbert": "hf-internal-testing/tiny-random-distilbert", "electra": "hf-internal-testing/tiny-random-electra", + "gemma": "fxmarty/tiny-random-GemmaForCausalLM", "flaubert": "hf-internal-testing/tiny-random-flaubert", "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel", "gpt2": "hf-internal-testing/tiny-random-gpt2", @@ -55,7 +58,9 @@ "opt125m": "facebook/opt-125m", "marian": "sshleifer/tiny-marian-en-de", "mbart": "hf-internal-testing/tiny-random-mbart", + "minicpm": "katuni4ka/tiny-random-minicpm", "mistral": "echarlaix/tiny-random-mistral", + "mixtral": "TitanML/tiny-mixtral", "mobilebert": "hf-internal-testing/tiny-random-MobileBertModel", "mobilenet_v1": "google/mobilenet_v1_0.75_192", "mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model", @@ -66,6 +71,8 @@ "pegasus": "hf-internal-testing/tiny-random-pegasus", "pix2struct": "fxmarty/pix2struct-tiny-random", "poolformer": "hf-internal-testing/tiny-random-PoolFormerModel", + "qwen": "katuni4ka/tiny-random-qwen", + "qwen2": "Qwen/Qwen1.5-0.5B", "resnet": "hf-internal-testing/tiny-random-resnet", "roberta": "hf-internal-testing/tiny-random-roberta", "roformer": "hf-internal-testing/tiny-random-roformer", @@ -76,6 +83,7 @@ "stable-diffusion": "hf-internal-testing/tiny-stable-diffusion-torch", "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl", "stable-diffusion-xl-refiner": "echarlaix/tiny-random-stable-diffusion-xl-refiner", + "stablelm": "hf-internal-testing/tiny-random-StableLmForCausalLM", "latent-consistency": "echarlaix/tiny-random-latent-consistency", "sew": "hf-internal-testing/tiny-random-SEWModel", "sew_d": "asapp/sew-d-tiny-100k-ft-ls100h",
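---

Usage notes for the new OpenVINO export configs.

Custom architectures are registered against a dedicated "openvino" exporter inside `TasksManager`: `init_model_configs()` first copies every existing ONNX entry so previously supported models keep working, and `register_in_tasks_manager` then adds or overrides configs. A minimal sketch of registering a decoder-only architecture, mirroring the Qwen2/StableLM entries above (the model type `"my-decoder"` is a placeholder, not a real architecture):

```python
from optimum.exporters.onnx.config import TextDecoderWithPositionIdsOnnxConfig
from optimum.exporters.tasks import TasksManager
from optimum.utils.input_generators import DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator
from optimum.utils.normalized_config import NormalizedTextConfig

# Same helper as defined in optimum/exporters/openvino/model_configs.py
register_in_tasks_manager = TasksManager.create_register("openvino", overwrite_existing=True)


@register_in_tasks_manager("my-decoder", "text-generation", "text-generation-with-past", library_name="transformers")
class MyDecoderOpenVINOConfig(TextDecoderWithPositionIdsOnnxConfig):
    DEFAULT_ONNX_OPSET = 14  # Trilu is needed for the causal mask, as for the other new configs
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MistralDummyPastKeyValuesGenerator)
    DUMMY_PKV_GENERATOR_CLASS = MistralDummyPastKeyValuesGenerator
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
```

Once registered, the entry is discoverable through `TasksManager.get_supported_tasks_for_model_type("my-decoder", exporter="openvino", library_name="transformers")`, which is how `main_export` and `export_from_model` now resolve tasks; the renamed `custom_export_configs` argument remains available to override the config for a single export.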
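ChatGLM2 stores its KV cache with the sequence axis first, so `ChatGLM2DummyPastKeyValuesGenerator` produces tensors of shape `(sequence_length, batch_size, multi_query_group_num, kv_channels)`, `add_past_key_values` labels axis 0 as `past_sequence_length`, and the `chatglm` branch added to `_reshape` in `modeling_decoder.py` relaxes axis 1 instead of axis 2. A small sketch of the resulting shapes and of the attention-mask padding done in `generate_dummy_inputs` (the dimensions are illustrative, not the real chatglm2 ones):

```python
import torch

batch_size, past_length = 2, 8
multi_query_group_num, kv_channels = 2, 16  # illustrative values
num_layers = 4

# One (key, value) pair per layer, sequence axis first as in ChatGLM2DummyPastKeyValuesGenerator
past_key_values = [
    (
        torch.rand(past_length, batch_size, multi_query_group_num, kv_channels),
        torch.rand(past_length, batch_size, multi_query_group_num, kv_channels),
    )
    for _ in range(num_layers)
]

# The attention mask is padded to past + present length; the past length is read
# from dim 0 of the value tensor rather than dim -2 as for most decoders
input_ids = torch.ones(batch_size, 1, dtype=torch.int64)
past_present_length = input_ids.shape[1] + past_key_values[0][1].shape[0]
attention_mask = torch.ones(batch_size, past_present_length, dtype=torch.int64)
print(attention_mask.shape)  # torch.Size([2, 9])
```

Qwen follows the same pattern but keeps `(batch_size, sequence_length, num_attention_heads, kv_channels)` and therefore concatenates past and present along dim 1.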
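`_mixtral_sparse_moe_block_forward` reimplements the sparse MoE block while avoiding the data-dependent control flow of the upstream implementation, so the expert dispatch can be traced: experts are selected with `topk`, turned into a one-hot mask, and the per-expert outputs are scattered back with `index_add_`. A self-contained toy version of that dispatch, with stand-in linear experts and much smaller sizes than a real Mixtral block:

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)

hidden_dim, num_experts, top_k = 8, 4, 2  # toy sizes
batch_size, sequence_length = 2, 3

gate = nn.Linear(hidden_dim, num_experts, bias=False)
experts = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim, bias=False) for _ in range(num_experts)])

hidden_states = torch.randn(batch_size, sequence_length, hidden_dim)
flat = hidden_states.view(-1, hidden_dim)

router_logits = gate(flat)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
routing_weights = (routing_weights / routing_weights.sum(dim=-1, keepdim=True)).to(flat.dtype)

final = torch.zeros_like(flat)
# (num_experts, top_k, tokens): which tokens each expert serves, and at which top-k slot
expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)

for expert_idx in range(num_experts):
    idx, top_x = torch.where(expert_mask[expert_idx])
    current_state = flat[None, top_x].reshape(-1, hidden_dim)  # empty slice if the expert is unused
    current_out = experts[expert_idx](current_state) * routing_weights[top_x, idx, None]
    final.index_add_(0, top_x, current_out)

print(final.view(batch_size, sequence_length, hidden_dim).shape)  # torch.Size([2, 3, 8])
```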
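`_chatglm2_core_attention_forward` routes ChatGLM2 attention through `torch.nn.functional.scaled_dot_product_attention`, building an additive causal mask only for the prefill step (when query and key lengths match) and leaving the mask all-zero while decoding with a cache. A minimal sketch of `_chatglm2_get_context_layer`, assuming tensors already permuted to `(batch, heads, seq, head_dim)` as in the patched forward:

```python
import torch

def causal_sdpa(query, key, value):
    # Additive mask: -inf above the diagonal during prefill, all zeros while decoding
    mask = torch.zeros((query.shape[-2], key.shape[-2]), dtype=query.dtype)
    if query.shape[2] == key.shape[2]:
        causal = torch.ones((query.shape[-2], key.shape[-2]), dtype=torch.bool).triu(diagonal=1)
        mask.masked_fill_(causal, float("-inf"))
    return torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=mask)

query = torch.randn(1, 2, 4, 8)  # (batch, heads, seq, head_dim)
key = torch.randn(1, 2, 4, 8)
value = torch.randn(1, 2, 4, 8)
print(causal_sdpa(query, key, value).shape)  # torch.Size([1, 2, 4, 8])
```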
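`GemmaModelPatcher` fills `rotary_emb.inv_freq` before tracing because it can be left as `None` by recent `transformers` releases, which would break the TorchScript graph. The filled value is the standard rotary-embedding inverse frequency; a sketch with illustrative sizes (real Gemma checkpoints use a larger head dimension):

```python
import torch

dim, base = 64, 10000.0  # illustrative rotary dimension and base

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
print(inv_freq.shape)             # torch.Size([32]), one frequency per pair of channels
print(inv_freq[0], inv_freq[-1])  # 1.0 down to the smallest frequency
```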
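The patched Qwen attention keeps the remote-code `use_cache_quantization` path, where `_qwen_quantize_cache_v` stores the KV cache as `uint8` with a per-(batch, head) scale and zero point. A round-trip sketch of that affine quantization scheme (the dequantization step is added here for illustration and is not part of the patch):

```python
import torch

def quantize_cache(fdata, qmin=torch.tensor(0.0), qmax=torch.tensor(255.0)):
    # fdata: (batch, heads, seq, head_dim); scale and zero point are computed per (batch, head)
    flat = torch.flatten(fdata, 2)
    fmax = torch.amax(flat, dim=-1, keepdim=True)
    fmin = torch.amin(flat, dim=-1, keepdim=True)
    scale = (fmax - fmin) / (qmax - qmin)
    zero = qmin - fmin / scale
    scale = scale.unsqueeze(-1).repeat(1, 1, fdata.shape[2], 1)
    zero = zero.unsqueeze(-1).repeat(1, 1, fdata.shape[2], 1)
    qdata = torch.clamp(fdata / scale + zero, qmin, qmax).to(torch.uint8)
    return qdata, scale, zero

value = torch.randn(1, 2, 4, 8)
qdata, scale, zero = quantize_cache(value)
dequantized = (qdata.float() - zero) * scale
print((dequantized - value).abs().max())  # error on the order of one quantization step
```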
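Several of the newly tested architectures ship their modeling code on the Hub (`baichuan2`, `chatglm`, `minicpm`, `qwen`), which is why the tests now pass a preloaded `AutoConfig` together with `trust_remote_code=True`. The same pattern applies outside the tests; a sketch using one of the tiny checkpoints listed in `tests/openvino/utils_tests.py`:

```python
from transformers import AutoConfig, AutoTokenizer

from optimum.intel import OVModelForCausalLM

model_id = "katuni4ka/tiny-random-chatglm2"  # tiny test checkpoint

config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
model = OVModelForCausalLM.from_pretrained(model_id, export=True, config=config, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

inputs = tokenizer("This is a sample", return_tensors="pt")
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
print(outputs.logits.shape)
```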