Skip to content

Commit

Permalink
Separate IO and Quantization args
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Jul 12, 2024
1 parent 9787b75 commit 974f086
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions scripts/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,10 @@ class QuantMode(Enum):


@dataclass
class QuantizationArguments:
class IOArguments:
"""
Arguments for quantizing ONNX models
Arguments to specify input and output folders
"""

input_folder: str = field(
metadata={
"help": "Path of the input folder containing the .onnx models to quantize"
Expand All @@ -52,6 +51,13 @@ class QuantizationArguments:
"help": "Path of the output folder where the quantized .onnx models will be saved"
}
)

@dataclass
class QuantizationArguments:
"""
Arguments for quantizing ONNX models
"""

modes: QuantMode = field(
default=QUANTIZE_OPTIONS,
metadata={
Expand Down Expand Up @@ -220,20 +226,27 @@ def quantize_bnb4(

def main():

parser = HfArgumentParser(QuantizationArguments)
(quantization_args,) = parser.parse_args_into_dataclasses()
parser = HfArgumentParser((IOArguments, QuantizationArguments))
io_args, quantization_args = parser.parse_args_into_dataclasses()

# (Step 1) Validate the arguments
if not quantization_args.modes:
raise ValueError("At least one quantization mode must be specified")

input_folder = io_args.input_folder
if not os.path.exists(input_folder):
raise ValueError(f"Input folder {input_folder} does not exist")

model_names_or_paths = [
os.path.join(quantization_args.input_folder, file)
for file in os.listdir(quantization_args.input_folder)
os.path.join(input_folder, file)
for file in os.listdir(input_folder)
if file.endswith(".onnx")
]
if not model_names_or_paths:
raise ValueError(f"No .onnx models found in {quantization_args.input_folder}")
raise ValueError(f"No .onnx models found in {input_folder}")

output_folder = io_args.output_folder
os.makedirs(output_folder, exist_ok=True)

# (Step 2) Quantize the models
for model_path in (progress_models := tqdm(model_names_or_paths)):
Expand All @@ -246,7 +259,7 @@ def main():
mode = QuantMode(mode)
suffix = QUANTIZE_SUFFIX_MAPPING.get(mode, mode.value)
save_path = os.path.join(
quantization_args.output_folder,
output_folder,
f"{file_name_without_extension}_{suffix}.onnx",
)

Expand Down

0 comments on commit 974f086

Please sign in to comment.