From 64f64b041ab2b4a44bb4006f0527803c1d5c67cf Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Wed, 4 Dec 2024 13:39:08 +0100
Subject: [PATCH 01/12] Support AWQ models

---
 optimum/exporters/openvino/__main__.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py
index e4fe2a7a41..0e913e8637 100644
--- a/optimum/exporters/openvino/__main__.py
+++ b/optimum/exporters/openvino/__main__.py
@@ -242,7 +242,7 @@ def main_export(
             trust_remote_code=trust_remote_code,
         )
         quantization_config = getattr(config, "quantization_config", None)
-        do_gptq_patching = quantization_config and quantization_config["quant_method"] == "gptq"
+        do_gptq_patching = quantization_config and quantization_config["quant_method"] in ["gptq", "awq"]
         model_type = config.model_type.replace("_", "-")
         if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
             custom_architecture = True
@@ -291,7 +291,6 @@ def main_export(
         if (
             dtype is None
             and framework == "pt"
-            and not do_gptq_patching
             and (
                 task.startswith("text-generation")
                 or getattr(config, "model_type", None) in MULTI_MODAL_TEXT_GENERATION_MODELS
@@ -311,7 +310,6 @@ def main_export(
                 loading_kwargs["torch_dtype"] = dtype
         # Patch the modules to export of GPTQ models w/o GPU
         if do_gptq_patching:
-            torch.set_default_dtype(torch.float32)
             orig_cuda_check = torch.cuda.is_available
             torch.cuda.is_available = lambda: True

From 86d9328ab6cfde60a97c99492c282abbe8cbd2d5 Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Thu, 5 Dec 2024 16:37:02 +0100
Subject: [PATCH 02/12] Add tests

---
 optimum/exporters/openvino/convert.py |  7 +++-
 tests/openvino/test_modeling.py       | 60 ++++++++++++++++++++++++---
 tests/openvino/utils_tests.py         |  3 +-
 3 files changed, 62 insertions(+), 8 deletions(-)

diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py
index 6012e6cfb5..3f8a73df6f 100644
--- a/optimum/exporters/openvino/convert.py
+++ b/optimum/exporters/openvino/convert.py
@@ -447,8 +447,11 @@ def ts_patched_forward(*args, **kwargs):
         if patch_16bit_model:
             from openvino.frontend.pytorch.patch_model import unpatch_model

-            unpatch_model(model, "_openvino_module_extension_patch_orig_forward")
-            model.to(torch.float32)
+            unpatch_model(model, "_openvino_module_extension_patch_orig_forward")
+            for m in model.modules():
+                if (any(p.dtype in [torch.float16, torch.bfloat16] for p in m.parameters())
+                        or any(b.dtype in [torch.float16, torch.bfloat16] for b in m.buffers())):
+                    m.float()

         return export_pytorch_via_onnx(
             model,
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index 240f4f9e3f..71139547f5 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -872,13 +872,14 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "gpt_neo",
         "gpt_neox",
         "llama",
-        # "llama_gptq",
         "marian",
         "minicpm",
         "mistral",
         "mixtral",
+        "mixtral_awq",
         "mpt",
         "opt",
+        "opt_gptq",
         "pegasus",
         "qwen",
         "phi",
@@ -949,9 +950,6 @@ def test_compare_to_transformers(self, model_arch):
         if is_openvino_version("<", "2024.1"):
             not_stateful.extend(["llama", "gemma", "gpt_bigcode"])

-        if "gptq" in model_arch:
-            self.skipTest("GPTQ model loading unsupported with AutoModelForCausalLM")
-
         set_seed(SEED)

         model_kwargs = {}
@@ -978,6 +976,46 @@ def test_compare_to_transformers(self, model_arch):
         if is_stateful:
             self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0)

+        if "awq" in model_arch or "gptq" in model_arch:
+            orig_cuda_is_available = torch.cuda.is_available
+            torch.cuda.is_available = lambda: True
+            # infer in FP32
+            model_kwargs["torch_dtype"] = torch.float32
+
+        if "awq" in model_arch:
+            # patch GEMM module to allow inference without CUDA GPU
+            from awq.modules.linear.gemm import WQLinearMMFunction
+            from awq.utils.packing_utils import dequantize_gemm
+
+            def new_forward(
+                ctx,
+                x,
+                qweight,
+                qzeros,
+                scales,
+                w_bit=4,
+                group_size=128,
+                bias=None,
+                out_features=0,
+            ):
+                ctx.out_features = out_features
+
+                out_shape = x.shape[:-1] + (out_features,)
+                x = x.to(torch.float16)
+
+                out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
+                out = torch.matmul(x, out)
+
+                out = out + bias if bias is not None else out
+                out = out.reshape(out_shape)
+
+                if len(out.shape) == 2:
+                    out = out.unsqueeze(0)
+                return out
+
+            orig_gemm_forward = WQLinearMMFunction.forward
+            WQLinearMMFunction.forward = new_forward
+
         set_seed(SEED)
         transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
         if model_arch in ["qwen", "arctic", "glm4"]:
             transformers_model.to(torch.float32)

         with torch.no_grad():
             transformers_outputs = transformers_model(**tokens)
@@ -988,10 +1026,14 @@ def test_compare_to_transformers(self, model_arch):

         # Compare tensor outputs
         atol = 1e-3 if model_arch == "minicpm" else 1e-4
+        # quantized models have higher tolerance
+        if "awq" in model_arch:
+            atol = 1e-2
+        elif "gptq" in model_arch:
+            atol = 0.6
         self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, equal_nan=True, atol=atol))

         # Qwen tokenizer does not support padding
-
         if model_arch in ["qwen"]:
             return
@@ -1026,11 +1068,19 @@ def test_compare_to_transformers(self, model_arch):
             additional_inputs = {"past_key_values": DynamicCache()}
         transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config, **additional_inputs)
+        print(f"ov_outputs: {ov_outputs}")
+        print(f"transformers_outputs: {transformers_outputs}")
         self.assertTrue(
             torch.allclose(ov_outputs, transformers_outputs),
             "OV output {ov_outputs}\nTransformers output {transformers_output}",
         )

+        if "awq" in model_arch:
+            WQLinearMMFunction.forward = orig_gemm_forward
+
+        if "awq" in model_arch or "gptq" in model_arch:
+            torch.cuda.is_available = orig_cuda_is_available
+
         del transformers_model
         del ov_model
         gc.collect()
diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py
index 17d9dd1fbe..a725cb3d2d 100644
--- a/tests/openvino/utils_tests.py
+++ b/tests/openvino/utils_tests.py
@@ -77,12 +77,12 @@
     "longt5": "hf-internal-testing/tiny-random-longt5",
     "llama": "HuggingFaceM4/tiny-random-LlamaForCausalLM",
     "llama_awq": "HuggingFaceH4/tiny-random-LlamaForCausalLM",
-    "llama_gptq": "hf-internal-testing/TinyLlama-1.1B-Chat-v0.3-GPTQ",
     "llava": "katuni4ka/tiny-random-llava",
     "llava_next": "katuni4ka/tiny-random-llava-next",
     "m2m_100": "hf-internal-testing/tiny-random-m2m_100",
     "opt": "hf-internal-testing/tiny-random-OPTModel",
     "opt125m": "facebook/opt-125m",
+    "opt_gptq": "katuni4ka/opt-125m-gptq",
     "marian": "sshleifer/tiny-marian-en-de",
     "mbart": "hf-internal-testing/tiny-random-mbart",
     "minicpm": "katuni4ka/tiny-random-minicpm",
@@ -91,6 +91,7 @@
     "mistral": "echarlaix/tiny-random-mistral",
     "mistral-nemo": "katuni4ka/tiny-random-mistral-nemo",
     "mixtral": "TitanML/tiny-mixtral",
+    "mixtral_awq": "TitanML/tiny-mixtral-AWQ-4bit",
     "mobilebert": "hf-internal-testing/tiny-random-MobileBertModel",
     "mobilenet_v1": "google/mobilenet_v1_0.75_192",
     "mobilenet_v2": "hf-internal-testing/tiny-random-MobileNetV2Model",

From decbcc203bc5928a1b52c65a1e8ff917ff66d1e1 Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Thu, 5 Dec 2024 17:47:57 +0100
Subject: [PATCH 03/12] Add dependencies

---
 setup.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/setup.py b/setup.py
index cd49ea041a..8e1661f4e2 100644
--- a/setup.py
+++ b/setup.py
@@ -38,6 +38,8 @@
 ]

 TESTS_REQUIRE = [
+    "auto-gptq",
+    "autoawq",
     "accelerate",
     "pytest>=7.2.0,<8.0.0",
     "parameterized",

From 9fb1da4b3f0c0f8160f865de17eb1f70f2d56ad5 Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Wed, 11 Dec 2024 17:24:23 +0100
Subject: [PATCH 04/12] Fix tests

---
 optimum/exporters/openvino/convert.py |  7 +--
 tests/openvino/test_modeling.py       | 67 +++++++--------------------
 tests/openvino/utils_tests.py         | 58 ++++++++++++++++++++++-
 3 files changed, 78 insertions(+), 54 deletions(-)

diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py
index 3f8a73df6f..e7cdbfbc9a 100644
--- a/optimum/exporters/openvino/convert.py
+++ b/optimum/exporters/openvino/convert.py
@@ -447,10 +447,11 @@ def ts_patched_forward(*args, **kwargs):
         if patch_16bit_model:
             from openvino.frontend.pytorch.patch_model import unpatch_model

-            unpatch_model(model, "_openvino_module_extension_patch_orig_forward")
+            unpatch_model(model, "_openvino_module_extension_patch_orig_forward")
             for m in model.modules():
-                if (any(p.dtype in [torch.float16, torch.bfloat16] for p in m.parameters())
-                        or any(b.dtype in [torch.float16, torch.bfloat16] for b in m.buffers())):
+                if any(p.dtype in [torch.float16, torch.bfloat16] for p in m.parameters(False)) or any(
+                    b.dtype in [torch.float16, torch.bfloat16] for b in m.buffers(False)
+                ):
                     m.float()

         return export_pytorch_via_onnx(
             model,
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index 71139547f5..38ebb13bfa 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -62,7 +62,7 @@
 )
 from transformers.onnx.utils import get_preprocessor
 from transformers.testing_utils import slow
-from utils_tests import MODEL_NAMES, TEST_IMAGE_URL
+from utils_tests import MODEL_NAMES, TEST_IMAGE_URL, mock_torch_cuda_is_available, patch_awq_for_inference

 from optimum.exporters.openvino.model_patcher import patch_update_causal_mask
 from optimum.intel import (
@@ -977,52 +977,18 @@ def test_compare_to_transformers(self, model_arch):
             self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0)

         if "awq" in model_arch or "gptq" in model_arch:
-            orig_cuda_is_available = torch.cuda.is_available
-            torch.cuda.is_available = lambda: True
             # infer in FP32
             model_kwargs["torch_dtype"] = torch.float32

-        if "awq" in model_arch:
-            # patch GEMM module to allow inference without CUDA GPU
-            from awq.modules.linear.gemm import WQLinearMMFunction
-            from awq.utils.packing_utils import dequantize_gemm
-
-            def new_forward(
-                ctx,
-                x,
-                qweight,
-                qzeros,
-                scales,
-                w_bit=4,
-                group_size=128,
-                bias=None,
-                out_features=0,
-            ):
-                ctx.out_features = out_features
-
-                out_shape = x.shape[:-1] + (out_features,)
-                x = x.to(torch.float16)
-
-                out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
-                out = torch.matmul(x, out)
-
-                out = out + bias if bias is not None else out
-                out = out.reshape(out_shape)
-
-                if len(out.shape) == 2:
-                    out = out.unsqueeze(0)
-                return out
-
-            orig_gemm_forward = WQLinearMMFunction.forward
-            WQLinearMMFunction.forward = new_forward
-
         set_seed(SEED)
-        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
+        with mock_torch_cuda_is_available("awq" in model_arch or "gptq" in model_arch):
+            transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
         if model_arch in ["qwen", "arctic", "glm4"]:
             transformers_model.to(torch.float32)

         with torch.no_grad():
-            transformers_outputs = transformers_model(**tokens)
+            with patch_awq_for_inference("awq" in model_arch):
+                transformers_outputs = transformers_model(**tokens)

         # Compare tensor outputs
         atol = 1e-3 if model_arch == "minicpm" else 1e-4
@@ -1067,18 +1033,13 @@ def test_compare_to_transformers(self, model_arch):
             additional_inputs = {"past_key_values": DynamicCache()}
-        transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config, **additional_inputs)
+        with patch_awq_for_inference("awq" in model_arch):
+            transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config, **additional_inputs)
         print(f"ov_outputs: {ov_outputs}")
         print(f"transformers_outputs: {transformers_outputs}")
         self.assertTrue(
             torch.allclose(ov_outputs, transformers_outputs),
             "OV output {ov_outputs}\nTransformers output {transformers_output}",
         )

-        if "awq" in model_arch:
-            WQLinearMMFunction.forward = orig_gemm_forward
-
-        if "awq" in model_arch or "gptq" in model_arch:
-            torch.cuda.is_available = orig_cuda_is_available
-
         del transformers_model
         del ov_model
         gc.collect()
@@ -1311,8 +1272,13 @@ def test_beam_search(self, model_arch):
         ov_model_stateless = OVModelForCausalLM.from_pretrained(
             model_id, export=True, use_cache=True, stateful=False, **model_kwargs
         )
+        if "awq" in model_arch or "gptq" in model_arch:
+            # infer in FP32
+            model_kwargs["torch_dtype"] = torch.float32
+
         set_seed(SEED)
-        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
+        with mock_torch_cuda_is_available("awq" in model_arch or "gptq" in model_arch):
+            transformers_model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
         if model_arch == "arctic":
             transformers_model.to(torch.float32)
@@ -1343,9 +1304,10 @@ def test_beam_search(self, model_arch):
         if model_arch == "gemma2":
             additional_inputs = {"past_key_values": DynamicCache()}
-        transformers_outputs = transformers_model.generate(
-            **tokens, generation_config=gen_config, **additional_inputs
-        )
+        with patch_awq_for_inference("awq" in model_arch):
+            transformers_outputs = transformers_model.generate(
+                **tokens, generation_config=gen_config, **additional_inputs
+            )
         set_seed(SEED)
         ov_stateful_outputs = ov_model_stateful.generate(**tokens, generation_config=gen_config)
         self.assertTrue(
diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py
index a725cb3d2d..fba13326f1 100644
--- a/tests/openvino/utils_tests.py
+++ b/tests/openvino/utils_tests.py
@@ -15,6 +15,7 @@
 import numpy as np
 import openvino as ov
 import torch
+from contextlib import contextmanager


 MODEL_NAMES = {
@@ -82,7 +83,7 @@
     "m2m_100": "hf-internal-testing/tiny-random-m2m_100",
     "opt": "hf-internal-testing/tiny-random-OPTModel",
     "opt125m": "facebook/opt-125m",
-    "opt_gptq": "katuni4ka/opt-125m-gptq",
+    "opt_gptq": "ybelkada/opt-125m-gptq-4bit",
     "marian": "sshleifer/tiny-marian-en-de",
     "mbart": "hf-internal-testing/tiny-random-mbart",
     "minicpm": "katuni4ka/tiny-random-minicpm",
@@ -219,3 +220,58 @@ def get_num_quantized_nodes(model):
             if type_name == "nf4":
                 num_weight_nodes["nf4"] += 1
     return num_fake_quantize, num_weight_nodes
+
+
+@contextmanager
+def mock_torch_cuda_is_available(to_patch):
+    original_is_available = torch.cuda.is_available
+    if to_patch:
+        torch.cuda.is_available = lambda: True
+    try:
+        yield
+    finally:
+        if to_patch:
+            torch.cuda.is_available = original_is_available
+
+
+@contextmanager
+def patch_awq_for_inference(to_patch):
+    orig_gemm_forward = None
+    if to_patch:
+        # patch GEMM module to allow inference without CUDA GPU
+        from awq.modules.linear.gemm import WQLinearMMFunction
+        from awq.utils.packing_utils import dequantize_gemm
+
+        def new_forward(
+            ctx,
+            x,
+            qweight,
+            qzeros,
+            scales,
+            w_bit=4,
+            group_size=128,
+            bias=None,
+            out_features=0,
+        ):
+            ctx.out_features = out_features
+
+            out_shape = x.shape[:-1] + (out_features,)
+            x = x.to(torch.float16)
+
+            out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
+            out = torch.matmul(x, out)
+
+            out = out + bias if bias is not None else out
+            out = out.reshape(out_shape)
+
+            if len(out.shape) == 2:
+                out = out.unsqueeze(0)
+            return out
+
+        orig_gemm_forward = WQLinearMMFunction.forward
+        WQLinearMMFunction.forward = new_forward
+    try:
+        yield
+    finally:
+        if orig_gemm_forward is not None:
+            WQLinearMMFunction.forward = orig_gemm_forward

From 04d0cf90aa468e16f3ca6324e7879196b9dbccfc Mon Sep 17 00:00:00 2001
From: eaidova
Date: Tue, 17 Dec 2024 19:35:19 +0400
Subject: [PATCH 05/12] enable awq export only if ov support it

---
 optimum/exporters/openvino/__main__.py | 5 ++++-
 tests/openvino/test_modeling.py        | 8 ++++++--
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py
index 0e913e8637..42d3d94064 100644
--- a/optimum/exporters/openvino/__main__.py
+++ b/optimum/exporters/openvino/__main__.py
@@ -242,7 +242,10 @@ def main_export(
             trust_remote_code=trust_remote_code,
         )
         quantization_config = getattr(config, "quantization_config", None)
-        do_gptq_patching = quantization_config and quantization_config["quant_method"] in ["gptq", "awq"]
+        supported_quant_methods = ["gptq"]
+        if is_openvino_version(">=", "2024.6.0"):
+            supported_quant_methods.append("awq")
+        do_gptq_patching = quantization_config and quantization_config["quant_method"] in supported_quant_methods
         model_type = config.model_type.replace("_", "-")
         if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
             custom_architecture = True
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index 38ebb13bfa..e1f1cecda5 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -876,7 +876,6 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "minicpm",
         "mistral",
         "mixtral",
-        "mixtral_awq",
         "mpt",
         "opt",
         "opt_gptq",
@@ -917,6 +916,9 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "minicpm3",
     )

+    if is_openvino_version(">=", "2024.6.0"):
+        SUPPORTED_ARCHITECTURES += ("mixtral_awq",)
+
     GENERATION_LENGTH = 100
     REMOTE_CODE_MODELS = (
         "chatglm",
@@ -1034,7 +1036,9 @@ def test_compare_to_transformers(self, model_arch):
             additional_inputs = {"past_key_values": DynamicCache()}
         with patch_awq_for_inference("awq" in model_arch):
-            transformers_outputs = transformers_model.generate(**tokens, generation_config=gen_config, **additional_inputs)
+            transformers_outputs = transformers_model.generate(
+                **tokens, generation_config=gen_config, **additional_inputs
+            )
         print(f"ov_outputs: {ov_outputs}")
         print(f"transformers_outputs: {transformers_outputs}")
         self.assertTrue(

From df97004c83502313e46490a007585c6254834e6c Mon Sep 17 00:00:00 2001
From: Ekaterina Aidova
Date: Tue, 17 Dec 2024 19:47:22 +0400
Subject: [PATCH 06/12] fix style (#2)

---
 tests/openvino/utils_tests.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py
index fba13326f1..0e748e7148 100644
--- a/tests/openvino/utils_tests.py
+++ b/tests/openvino/utils_tests.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from contextlib import contextmanager
+
 import numpy as np
 import openvino as ov
 import torch
-from contextlib import contextmanager


 MODEL_NAMES = {

From cf2fc8b6e33d3138aded3466a53fa22df6841138 Mon Sep 17 00:00:00 2001
From: Ekaterina Aidova
Date: Tue, 17 Dec 2024 20:33:51 +0400
Subject: [PATCH 07/12] disable awq and gptq install for old torch (#3)

* fix style

* disable autogptq and autoawq install for old transformers testing
---
 .github/workflows/test_openvino.yml      | 5 +++++
 .github/workflows/test_openvino_slow.yml | 5 +++++
 setup.py                                 | 2 --
 tests/openvino/test_modeling.py          | 2 +-
 4 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/test_openvino.yml b/.github/workflows/test_openvino.yml
index e2889cb4e0..eca8233988 100644
--- a/.github/workflows/test_openvino.yml
+++ b/.github/workflows/test_openvino.yml
@@ -49,6 +49,11 @@ jobs:
         name: Downgrade Transformers and Accelerate
         run: |
           pip install transformers==${{ matrix.transformers-version }} accelerate==0.*
+
+      - if: ${{ matrix.transformers-version == 'latest' && matrix.test-pattern == '*modeling*'}}
+        name: Install auto-gptq, autoawq
+        run: |
+          pip install auto-gptq autoawq --extra-index-url https://download.pytorch.org/whl/cpu

       - if: ${{ matrix.test-pattern == '*modeling*' }}
         name: Uninstall NNCF
diff --git a/.github/workflows/test_openvino_slow.yml b/.github/workflows/test_openvino_slow.yml
index bf52413a7d..ccb564bb33 100644
--- a/.github/workflows/test_openvino_slow.yml
+++ b/.github/workflows/test_openvino_slow.yml
@@ -56,6 +56,11 @@ jobs:
         name: Downgrade Transformers and Accelerate
         run: pip install transformers==${{ matrix.transformers-version }} accelerate==0.*

+      - if: ${{ matrix.transformers-version == 'latest' }}
+        name: Install auto-gptq, autoawq
+        run: |
+          pip install auto-gptq autoawq --extra-index-url https://download.pytorch.org/whl/cpu
+
       - name: Pip freeze
         run: pip freeze

diff --git a/setup.py b/setup.py
index 8e1661f4e2..cd49ea041a 100644
--- a/setup.py
+++ b/setup.py
@@ -38,8 +38,6 @@
 ]

 TESTS_REQUIRE = [
-    "auto-gptq",
-    "autoawq",
     "accelerate",
     "pytest>=7.2.0,<8.0.0",
     "parameterized",
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index e1f1cecda5..8927da1ab4 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -878,7 +878,6 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "mixtral",
         "mpt",
         "opt",
-        "opt_gptq",
         "pegasus",
         "qwen",
         "phi",
@@ -915,6 +914,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "exaone",
         "mistral-nemo",
         "minicpm3",
+        "opt_gptq",
     )

     if is_openvino_version(">=", "2024.6.0"):

From ab6ac99ea5a60c119f2cc60de2982e9d7b04685f Mon Sep 17 00:00:00 2001
From: Ekaterina Aidova
Date: Wed, 18 Dec 2024 12:50:56 +0400
Subject: [PATCH 08/12] separate common quant models patching and gptq (#4)

---
 optimum/exporters/openvino/__main__.py | 36 ++++++++++++++------------
 1 file changed, 20 insertions(+), 16 deletions(-)

diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py
index 7940ef567c..3015b20e83 100644
--- a/optimum/exporters/openvino/__main__.py
+++ b/optimum/exporters/openvino/__main__.py
@@ -232,6 +232,7 @@ def main_export(
     )

     do_gptq_patching = False
+    do_quant_patching = False
     custom_architecture = False
     patch_16bit = False
     loading_kwargs = model_loading_kwargs or {}
@@ -250,7 +251,8 @@ def main_export(
         supported_quant_methods = ["gptq"]
         if is_openvino_version(">=", "2024.6.0"):
             supported_quant_methods.append("awq")
-        do_gptq_patching = quantization_config and quantization_config["quant_method"] in supported_quant_methods
+        do_quant_patching = quantization_config and quantization_config["quant_method"] in supported_quant_methods
+        do_gptq_patching = do_quant_patching and quantization_config["quant_method"] == "gptq"
         model_type = config.model_type.replace("_", "-")
         if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
             custom_architecture = True
@@ -317,27 +319,28 @@ def main_export(
             patch_16bit = True
         loading_kwargs["torch_dtype"] = dtype
     # Patch the modules to export of GPTQ models w/o GPU
-    if do_gptq_patching:
+    if do_quant_patching:
         orig_cuda_check = torch.cuda.is_available
         torch.cuda.is_available = lambda: True

-        from optimum.gptq import GPTQQuantizer
+        if do_gptq_patching:
+            from optimum.gptq import GPTQQuantizer

-        orig_post_init_model = GPTQQuantizer.post_init_model
+            orig_post_init_model = GPTQQuantizer.post_init_model

-        def post_init_model(self, model):
-            from auto_gptq import exllama_set_max_input_length
+            def post_init_model(self, model):
+                from auto_gptq import exllama_set_max_input_length

-            class StoreAttr(object):
-                pass
+                class StoreAttr(object):
+                    pass

-            model.quantize_config = StoreAttr()
-            model.quantize_config.desc_act = self.desc_act
-            if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
-                model = exllama_set_max_input_length(model, self.max_input_length)
-            return model
+                model.quantize_config = StoreAttr()
+                model.quantize_config.desc_act = self.desc_act
+                if self.desc_act and not self.disable_exllama and self.max_input_length is not None:
+                    model = exllama_set_max_input_length(model, self.max_input_length)
+                return model

-        GPTQQuantizer.post_init_model = post_init_model
+            GPTQQuantizer.post_init_model = post_init_model
     elif library_name == "diffusers" and is_openvino_version(">=", "2024.6"):
         dtype = deduce_diffusers_dtype(
             model_name_or_path,
@@ -486,9 +489,10 @@ class StoreAttr(object):
             compressed_submodel_path.with_suffix(".bin").rename(submodel_path.with_suffix(".bin"))

     # Unpatch modules after GPTQ export
-    if do_gptq_patching:
+    if do_quant_patching:
         torch.cuda.is_available = orig_cuda_check
-        GPTQQuantizer.post_init_model = orig_post_init_model
+        if do_gptq_patching:
+            GPTQQuantizer.post_init_model = orig_post_init_model


 def maybe_convert_tokenizers(library_name: str, output: Path, model=None, preprocessors=None, task=None):

From: Ekaterina Aidova
Date: Wed, 18 Dec 2024 14:34:09 +0400
Subject: [PATCH 09/12] disable windows install (#5)

* separate common quant models patching and gptq

* disable awq windows
---
 .github/workflows/test_openvino_slow.yml | 2 +-
 tests/openvino/test_modeling.py          | 9 +++++++--
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/test_openvino_slow.yml b/.github/workflows/test_openvino_slow.yml
index a6eb83e5de..a4e8a046b5 100644
--- a/.github/workflows/test_openvino_slow.yml
+++ b/.github/workflows/test_openvino_slow.yml
@@ -49,7 +49,7 @@ jobs:
         name: Install specific dependencies and versions required for older transformers
         run: pip install transformers==${{ matrix.transformers-version }} accelerate==0.* peft==0.13.*, diffusers==0.30.* transformers_stream_generator

-      - if: ${{ matrix.transformers-version == 'latest' }}
+      - if: ${{ matrix.transformers-version == 'latest' && matrix.os != 'windows-2019' }}
        name: Install auto-gptq, autoawq
        run: |
          pip install auto-gptq autoawq --extra-index-url https://download.pytorch.org/whl/cpu
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index 8927da1ab4..8a3e907533 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -15,6 +15,7 @@
 import copy
 import gc
 import os
+import platform
 import tempfile
 import time
 import unittest
@@ -914,10 +915,14 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "exaone",
         "mistral-nemo",
         "minicpm3",
-        "opt_gptq",
     )

-    if is_openvino_version(">=", "2024.6.0"):
-        SUPPORTED_ARCHITECTURES += ("mixtral_awq",)
+    # gptq and awq install disabled for windows test environment
+    if platform.system() != "Windows":
+        SUPPORTED_ARCHITECTURES += ("opt_gptq",)
+
+        # autoawq install disabled for windows test environment
+        if is_openvino_version(">=", "2024.6.0") and platform.system() != "Windows":
+            SUPPORTED_ARCHITECTURES += ("mixtral_awq",)

     GENERATION_LENGTH = 100
     REMOTE_CODE_MODELS = (

From ff66f43e802b8ffd24299ce03bbf6776419a7ff6 Mon Sep 17 00:00:00 2001
From: Ekaterina Aidova
Date: Thu, 19 Dec 2024 15:04:55 +0400
Subject: [PATCH 10/12] skip logits check for quantized models (#6)

---
 tests/openvino/test_modeling.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index 8a3e907533..023fbe8bd0 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -999,12 +999,9 @@ def test_compare_to_transformers(self, model_arch):

         # Compare tensor outputs
         atol = 1e-3 if model_arch == "minicpm" else 1e-4
-        # quantized models have higher tolerance
-        if "awq" in model_arch:
-            atol = 1e-2
-        elif "gptq" in model_arch:
-            atol = 0.6
-        self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, equal_nan=True, atol=atol))
+        # quantized models have different logits value range
+        if "awq" not in model_arch and "gptq" not in model_arch:
+            self.assertTrue(torch.allclose(ov_outputs.logits, transformers_outputs.logits, equal_nan=True, atol=atol))

         # Qwen tokenizer does not support padding
         if model_arch in ["qwen"]:

From e8be988931b9d7038c655b18e0f6d08e1f49be7c Mon Sep 17 00:00:00 2001
From: eaidova
Date: Thu, 19 Dec 2024 15:17:17 +0400
Subject: [PATCH 11/12] fix test after rebase

---
 tests/openvino/utils_tests.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py
index d19f66659b..2011e11f0c 100644
--- a/tests/openvino/utils_tests.py
+++ b/tests/openvino/utils_tests.py
@@ -12,9 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import unittest
-from typing import Dict, List, Union
-
 from contextlib import contextmanager
+from typing import Dict, List, Union

 import numpy as np
 import openvino as ov

From 5d8bcb7bd0db9d86c05903e0a99f16ef17c5f0ff Mon Sep 17 00:00:00 2001
From: eaidova
Date: Thu, 19 Dec 2024 16:18:26 +0400
Subject: [PATCH 12/12] fix testing condition for 2024.6 and unpatch in case
 if failed

---
 optimum/exporters/openvino/__main__.py | 252 +++++++++++++------------
 tests/openvino/test_modeling.py        |   6 +-
 2 files changed, 132 insertions(+), 126 deletions(-)

diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py
index 3015b20e83..09f2eaa10b 100644
--- a/optimum/exporters/openvino/__main__.py
+++ b/optimum/exporters/openvino/__main__.py
@@ -355,144 +355,150 @@ class StoreAttr(object):
                 loading_kwargs["torch_dtype"] = dtype
             patch_16bit = True

-    if library_name == "open_clip":
-        model = _OpenClipForZeroShotImageClassification.from_pretrained(model_name_or_path, cache_dir=cache_dir)
-    else:
-        model = TasksManager.get_model_from_task(
-            task,
-            model_name_or_path,
-            subfolder=subfolder,
-            revision=revision,
-            cache_dir=cache_dir,
-            token=token,
-            local_files_only=local_files_only,
-            force_download=force_download,
-            trust_remote_code=trust_remote_code,
-            framework=framework,
-            device=device,
-            library_name=library_name,
-            **loading_kwargs,
-        )
-
-    needs_pad_token_id = task == "text-classification" and getattr(model.config, "pad_token_id", None) is None
-
-    if needs_pad_token_id:
-        if pad_token_id is not None:
-            model.config.pad_token_id = pad_token_id
-        else:
-            tok = AutoTokenizer.from_pretrained(model_name_or_path)
-            pad_token_id = getattr(tok, "pad_token_id", None)
-            if pad_token_id is None:
-                raise ValueError(
-                    "Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
-                )
-            model.config.pad_token_id = pad_token_id
-
-    if hasattr(model.config, "export_model_type"):
-        model_type = model.config.export_model_type.replace("_", "-")
-    else:
-        model_type = model.config.model_type.replace("_", "-")
-
-    if (
-        not custom_architecture
-        and library_name != "diffusers"
-        and task + "-with-past"
-        in TasksManager.get_supported_tasks_for_model_type(model_type, exporter="openvino", library_name=library_name)
-    ):
-        # Make -with-past the default if --task was not explicitely specified
-        if original_task == "auto":
-            task = task + "-with-past"
-        else:
-            logger.info(
-                f"The task `{task}` was manually specified, and past key values will not be reused in the decoding."
-                f" if needed, please pass `--task {task}-with-past` to export using the past key values."
-            )
-
-    if original_task == "auto":
-        synonyms_for_task = sorted(TasksManager.synonyms_for_task(task))
-        if synonyms_for_task:
-            synonyms_for_task = ", ".join(synonyms_for_task)
-            possible_synonyms = f" (possible synonyms are: {synonyms_for_task})"
-        else:
-            possible_synonyms = ""
-        logger.info(f"Automatic task detection to {task}{possible_synonyms}.")
-
-    preprocessors = maybe_load_preprocessors(
-        model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
-    )
-
-    submodel_paths = export_from_model(
-        model=model,
-        output=output,
-        task=task,
-        ov_config=ov_config,
-        stateful=stateful,
-        model_kwargs=model_kwargs,
-        custom_export_configs=custom_export_configs,
-        fn_get_submodels=fn_get_submodels,
-        preprocessors=preprocessors,
-        device=device,
-        trust_remote_code=trust_remote_code,
-        patch_16bit_model=patch_16bit,
-        **kwargs_shapes,
-    )
-
-    if convert_tokenizer:
-        maybe_convert_tokenizers(library_name, output, model, preprocessors, task=task)
-
-    clear_class_registry()
-    del model
-    gc.collect()
-
-    for submodel_path in submodel_paths:
-        submodel_path = Path(output) / submodel_path
-        submodel = core.read_model(submodel_path)
-
-        quantization_config = None
-        if ov_config is None:
-            num_parameters = 0
-            for op in submodel.get_ops():
-                if op.get_type_name() == "Constant" and op.get_element_type() in [Type.f16, Type.f32, Type.bf16]:
-                    num_parameters += reduce(operator.mul, op.shape, 1)
-                del op
-                if num_parameters >= _MAX_UNCOMPRESSED_SIZE:
-                    if is_nncf_available():
-                        quantization_config = {"bits": 8, "sym": False}
-                        logger.info("The model weights will be quantized to int8_asym.")
-                    else:
-                        logger.warning(
-                            "The model will be converted with no weights quantization. Quantization of the weights to int8 "
-                            "requires nncf. Please install it with `pip install nncf`"
-                        )
-                        break
-        else:
-            quantization_config = ov_config.quantization_config
-        if quantization_config is None:
-            del submodel
-            gc.collect()
-            continue
-
-        if not is_nncf_available():
-            raise ImportError("Quantization of the weights requires nncf, please install it with `pip install nncf`")
-
-        from optimum.intel.openvino.quantization import _weight_only_quantization
-
-        _weight_only_quantization(submodel, quantization_config)
-        compressed_submodel_path = submodel_path.parent / f"{submodel_path.stem}_compressed.xml"
-        save_model(submodel, compressed_submodel_path, compress_to_fp16=False)
-        del submodel
-        gc.collect()
-
-        submodel_path.unlink()
-        submodel_path.with_suffix(".bin").unlink()
-        compressed_submodel_path.rename(submodel_path)
-        compressed_submodel_path.with_suffix(".bin").rename(submodel_path.with_suffix(".bin"))
-
-    # Unpatch modules after GPTQ export
-    if do_quant_patching:
-        torch.cuda.is_available = orig_cuda_check
-        if do_gptq_patching:
-            GPTQQuantizer.post_init_model = orig_post_init_model
+    try:
+        if library_name == "open_clip":
+            model = _OpenClipForZeroShotImageClassification.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+        else:
+            model = TasksManager.get_model_from_task(
+                task,
+                model_name_or_path,
+                subfolder=subfolder,
+                revision=revision,
+                cache_dir=cache_dir,
+                token=token,
+                local_files_only=local_files_only,
+                force_download=force_download,
+                trust_remote_code=trust_remote_code,
+                framework=framework,
+                device=device,
+                library_name=library_name,
+                **loading_kwargs,
+            )
+
+        needs_pad_token_id = task == "text-classification" and getattr(model.config, "pad_token_id", None) is None
+
+        if needs_pad_token_id:
+            if pad_token_id is not None:
+                model.config.pad_token_id = pad_token_id
+            else:
+                tok = AutoTokenizer.from_pretrained(model_name_or_path)
+                pad_token_id = getattr(tok, "pad_token_id", None)
+                if pad_token_id is None:
+                    raise ValueError(
+                        "Could not infer the pad token id, which is needed in this case, please provide it with the --pad_token_id argument"
+                    )
+                model.config.pad_token_id = pad_token_id
+
+        if hasattr(model.config, "export_model_type"):
+            model_type = model.config.export_model_type.replace("_", "-")
+        else:
+            model_type = model.config.model_type.replace("_", "-")
+
+        if (
+            not custom_architecture
+            and library_name != "diffusers"
+            and task + "-with-past"
+            in TasksManager.get_supported_tasks_for_model_type(
+                model_type, exporter="openvino", library_name=library_name
+            )
+        ):
+            # Make -with-past the default if --task was not explicitely specified
+            if original_task == "auto":
+                task = task + "-with-past"
+            else:
+                logger.info(
+                    f"The task `{task}` was manually specified, and past key values will not be reused in the decoding."
+                    f" if needed, please pass `--task {task}-with-past` to export using the past key values."
+                )
+
+        if original_task == "auto":
+            synonyms_for_task = sorted(TasksManager.synonyms_for_task(task))
+            if synonyms_for_task:
+                synonyms_for_task = ", ".join(synonyms_for_task)
+                possible_synonyms = f" (possible synonyms are: {synonyms_for_task})"
+            else:
+                possible_synonyms = ""
+            logger.info(f"Automatic task detection to {task}{possible_synonyms}.")
+
+        preprocessors = maybe_load_preprocessors(
+            model_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
+        )
+
+        submodel_paths = export_from_model(
+            model=model,
+            output=output,
+            task=task,
+            ov_config=ov_config,
+            stateful=stateful,
+            model_kwargs=model_kwargs,
+            custom_export_configs=custom_export_configs,
+            fn_get_submodels=fn_get_submodels,
+            preprocessors=preprocessors,
+            device=device,
+            trust_remote_code=trust_remote_code,
+            patch_16bit_model=patch_16bit,
+            **kwargs_shapes,
+        )
+
+        if convert_tokenizer:
+            maybe_convert_tokenizers(library_name, output, model, preprocessors, task=task)
+
+        clear_class_registry()
+        del model
+        gc.collect()
+
+        for submodel_path in submodel_paths:
+            submodel_path = Path(output) / submodel_path
+            submodel = core.read_model(submodel_path)
+
+            quantization_config = None
+            if ov_config is None:
+                num_parameters = 0
+                for op in submodel.get_ops():
+                    if op.get_type_name() == "Constant" and op.get_element_type() in [Type.f16, Type.f32, Type.bf16]:
+                        num_parameters += reduce(operator.mul, op.shape, 1)
+                    del op
+                    if num_parameters >= _MAX_UNCOMPRESSED_SIZE:
+                        if is_nncf_available():
+                            quantization_config = {"bits": 8, "sym": False}
+                            logger.info("The model weights will be quantized to int8_asym.")
+                        else:
+                            logger.warning(
+                                "The model will be converted with no weights quantization. Quantization of the weights to int8 "
+                                "requires nncf. Please install it with `pip install nncf`"
+                            )
+                            break
+            else:
+                quantization_config = ov_config.quantization_config
+            if quantization_config is None:
+                del submodel
+                gc.collect()
+                continue
+
+            if not is_nncf_available():
+                raise ImportError(
+                    "Quantization of the weights requires nncf, please install it with `pip install nncf`"
+                )
+
+            from optimum.intel.openvino.quantization import _weight_only_quantization
+
+            _weight_only_quantization(submodel, quantization_config)
+            compressed_submodel_path = submodel_path.parent / f"{submodel_path.stem}_compressed.xml"
+            save_model(submodel, compressed_submodel_path, compress_to_fp16=False)
+            del submodel
+            gc.collect()
+
+            submodel_path.unlink()
+            submodel_path.with_suffix(".bin").unlink()
+            compressed_submodel_path.rename(submodel_path)
+            compressed_submodel_path.with_suffix(".bin").rename(submodel_path.with_suffix(".bin"))
+
+    finally:
+        # Unpatch modules after quantized model export
+        if do_quant_patching:
+            torch.cuda.is_available = orig_cuda_check
+            if do_gptq_patching:
+                GPTQQuantizer.post_init_model = orig_post_init_model


 def maybe_convert_tokenizers(library_name: str, output: Path, model=None, preprocessors=None, task=None):
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index 95efa4a45b..32da813391 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -921,9 +921,9 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
     if platform.system() != "Windows":
         SUPPORTED_ARCHITECTURES += ("opt_gptq",)

-        # autoawq install disabled for windows test environment
-        if is_openvino_version(">=", "2024.6.0") and platform.system() != "Windows":
-            SUPPORTED_ARCHITECTURES += ("mixtral_awq",)
+    # autoawq install disabled for windows test environment
+    if is_openvino_version(">=", "2024.6.0") and platform.system() != "Windows":
+        SUPPORTED_ARCHITECTURES += ("mixtral_awq",)

     GENERATION_LENGTH = 100
     REMOTE_CODE_MODELS = (