From cbbda3e43284c49a02732375cfcabc61e4923046 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Tue, 28 May 2024 18:03:08 +0200 Subject: [PATCH] Fix ort config instantiation (from_pretrained) and saving (save_pretrained) (#1865) * fix ort config instatiation (from_dict) and saving (to_dict) * added tests for quantization with ort config * style * handle empty quant dictionary --- .github/workflows/test_cli.yml | 33 ++++++++++--------- optimum/onnxruntime/configuration.py | 49 ++++++++++++++++++++++++++-- tests/cli/test_cli.py | 31 +++++++++--------- 3 files changed, 80 insertions(+), 33 deletions(-) diff --git a/.github/workflows/test_cli.yml b/.github/workflows/test_cli.yml index 7eae0186076..ecb19d23aa3 100644 --- a/.github/workflows/test_cli.yml +++ b/.github/workflows/test_cli.yml @@ -4,9 +4,9 @@ name: Optimum CLI / Python - Test on: push: - branches: [ main ] + branches: [main] pull_request: - branches: [ main ] + branches: [main] concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -22,17 +22,20 @@ jobs: runs-on: ${{ matrix.os }} steps: - - uses: actions/checkout@v2 - - name: Setup Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install .[tests,exporters,exporters-tf] - - name: Test with unittest - working-directory: tests - run: | - python -m unittest discover -s cli -p 'test_*.py' + - name: Checkout code + uses: actions/checkout@v4 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + pip install --upgrade pip + pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + pip install .[tests,exporters,exporters-tf] + + - name: Test with pytest + run: | + pytest tests/cli -s -vvvv --durations=0 diff --git a/optimum/onnxruntime/configuration.py b/optimum/onnxruntime/configuration.py index c11cf58b8b0..2e3d9f32d6a 100644 --- a/optimum/onnxruntime/configuration.py +++ b/optimum/onnxruntime/configuration.py @@ -18,7 +18,7 @@ from dataclasses import asdict, dataclass, field from enum import Enum from pathlib import Path -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union from datasets import Dataset from packaging.version import Version, parse @@ -298,6 +298,15 @@ def __post_init__(self): ) self.operators_to_quantize = operators_to_quantize + if isinstance(self.format, str): + self.format = QuantFormat[self.format] + if isinstance(self.mode, str): + self.mode = QuantizationMode[self.mode] + if isinstance(self.activations_dtype, str): + self.activations_dtype = QuantType[self.activations_dtype] + if isinstance(self.weights_dtype, str): + self.weights_dtype = QuantType[self.weights_dtype] + @staticmethod def quantization_type_str(activations_dtype: QuantType, weights_dtype: QuantType) -> str: return ( @@ -984,8 +993,28 @@ def __init__( self.opset = opset self.use_external_data_format = use_external_data_format self.one_external_file = one_external_file - self.optimization = self.dataclass_to_dict(optimization) - self.quantization = self.dataclass_to_dict(quantization) + + if isinstance(optimization, dict) and optimization: + self.optimization = OptimizationConfig(**optimization) + elif isinstance(optimization, OptimizationConfig): + self.optimization = optimization + elif not optimization: + self.optimization = None + else: + raise ValueError( + f"Optional argument `optimization` must be a dictionary or an instance of OptimizationConfig, got {type(optimization)}" + ) + if isinstance(quantization, dict) and quantization: + self.quantization = QuantizationConfig(**quantization) + elif isinstance(quantization, QuantizationConfig): + self.quantization = quantization + elif not quantization: + self.quantization = None + else: + raise ValueError( + f"Optional argument `quantization` must be a dictionary or an instance of QuantizationConfig, got {type(quantization)}" + ) + self.optimum_version = kwargs.pop("optimum_version", None) @staticmethod @@ -1002,3 +1031,17 @@ def dataclass_to_dict(config) -> dict: v = [elem.name if isinstance(elem, Enum) else elem for elem in v] new_config[k] = v return new_config + + def to_dict(self) -> Dict[str, Any]: + dict_config = { + "opset": self.opset, + "use_external_data_format": self.use_external_data_format, + "one_external_file": self.one_external_file, + "optimization": self.dataclass_to_dict(self.optimization), + "quantization": self.dataclass_to_dict(self.quantization), + } + + if self.optimum_version: + dict_config["optimum_version"] = self.optimum_version + + return dict_config diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 2e64dc9cdfb..ca4ebf8bd23 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -21,10 +21,8 @@ import unittest from pathlib import Path -from onnxruntime import __version__ as ort_version -from packaging.version import Version, parse - import optimum.commands +from optimum.onnxruntime.configuration import AutoQuantizationConfig, ORTConfig CLI_WIH_CUSTOM_COMMAND_PATH = Path(__file__).parent / "cli_with_custom_command.py" @@ -83,30 +81,33 @@ def test_optimize_commands(self): def test_quantize_commands(self): with tempfile.TemporaryDirectory() as tempdir: + ort_config = ORTConfig(quantization=AutoQuantizationConfig.avx2(is_static=False)) + ort_config.save_pretrained(tempdir) + # First export a tiny encoder, decoder only and encoder-decoder export_commands = [ - f"optimum-cli export onnx --model hf-internal-testing/tiny-random-BertModel {tempdir}/encoder", + f"optimum-cli export onnx --model hf-internal-testing/tiny-random-bert {tempdir}/encoder", f"optimum-cli export onnx --model hf-internal-testing/tiny-random-gpt2 {tempdir}/decoder", - # f"optimum-cli export onnx --model hf-internal-testing/tiny-random-t5 {tempdir}/encoder-decoder", + f"optimum-cli export onnx --model hf-internal-testing/tiny-random-t5 {tempdir}/encoder-decoder", ] quantize_commands = [ f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/encoder --avx2 -o {tempdir}/quantized_encoder", f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/decoder --avx2 -o {tempdir}/quantized_decoder", - # f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/encoder-decoder --avx2 -o {tempdir}/quantized_encoder_decoder", + f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/encoder-decoder --avx2 -o {tempdir}/quantized_encoder_decoder", ] - if parse(ort_version) != Version("1.16.0") and parse(ort_version) != Version("1.17.0"): - # Failing on onnxruntime==1.17.0, will be fixed on 1.17.1: https://github.com/microsoft/onnxruntime/pull/19421 - export_commands.append( - f"optimum-cli export onnx --model hf-internal-testing/tiny-random-t5 {tempdir}/encoder-decoder" - ) - quantize_commands.append( - f"optimum-cli onnxruntime quantize --onnx_model {tempdir}/encoder-decoder --avx2 -o {tempdir}/quantized_encoder_decoder" - ) + quantize_with_config_commands = [ + f"optimum-cli onnxruntime quantize --onnx_model hf-internal-testing/tiny-random-bert --c {tempdir}/ort_config.json -o {tempdir}/quantized_encoder_with_config", + f"optimum-cli onnxruntime quantize --onnx_model hf-internal-testing/tiny-random-gpt2 --c {tempdir}/ort_config.json -o {tempdir}/quantized_decoder_with_config", + f"optimum-cli onnxruntime quantize --onnx_model hf-internal-testing/tiny-random-t5 --c {tempdir}/ort_config.json -o {tempdir}/quantized_encoder_decoder_with_config", + ] - for export, quantize in zip(export_commands, quantize_commands): + for export, quantize, quantize_with_config in zip( + export_commands, quantize_commands, quantize_with_config_commands + ): subprocess.run(export, shell=True, check=True) subprocess.run(quantize, shell=True, check=True) + subprocess.run(quantize_with_config, shell=True, check=True) def _run_command_and_check_content(self, command: str, content: str) -> bool: proc = subprocess.Popen(command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)