From 974f0862226c9db4545454c31ae75c9e4666913c Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Fri, 12 Jul 2024 09:10:15 +0000 Subject: [PATCH] Separate IO and Quantization args --- scripts/quantize.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/scripts/quantize.py b/scripts/quantize.py index b8e4bdf56..4cdad4bf5 100644 --- a/scripts/quantize.py +++ b/scripts/quantize.py @@ -37,11 +37,10 @@ class QuantMode(Enum): @dataclass -class QuantizationArguments: +class IOArguments: """ - Arguments for quantizing ONNX models + Arguments to specify input and output folders """ - input_folder: str = field( metadata={ "help": "Path of the input folder containing the .onnx models to quantize" @@ -52,6 +51,13 @@ class QuantizationArguments: "help": "Path of the output folder where the quantized .onnx models will be saved" } ) + +@dataclass +class QuantizationArguments: + """ + Arguments for quantizing ONNX models + """ + modes: QuantMode = field( default=QUANTIZE_OPTIONS, metadata={ @@ -220,20 +226,27 @@ def quantize_bnb4( def main(): - parser = HfArgumentParser(QuantizationArguments) - (quantization_args,) = parser.parse_args_into_dataclasses() + parser = HfArgumentParser((IOArguments, QuantizationArguments)) + io_args, quantization_args = parser.parse_args_into_dataclasses() # (Step 1) Validate the arguments if not quantization_args.modes: raise ValueError("At least one quantization mode must be specified") + input_folder = io_args.input_folder + if not os.path.exists(input_folder): + raise ValueError(f"Input folder {input_folder} does not exist") + model_names_or_paths = [ - os.path.join(quantization_args.input_folder, file) - for file in os.listdir(quantization_args.input_folder) + os.path.join(input_folder, file) + for file in os.listdir(input_folder) if file.endswith(".onnx") ] if not model_names_or_paths: - raise ValueError(f"No .onnx models found in {quantization_args.input_folder}") + raise ValueError(f"No .onnx models found in {input_folder}") + + output_folder = io_args.output_folder + os.makedirs(output_folder, exist_ok=True) # (Step 2) Quantize the models for model_path in (progress_models := tqdm(model_names_or_paths)): @@ -246,7 +259,7 @@ def main(): mode = QuantMode(mode) suffix = QUANTIZE_SUFFIX_MAPPING.get(mode, mode.value) save_path = os.path.join( - quantization_args.output_folder, + output_folder, f"{file_name_without_extension}_{suffix}.onnx", )