diff --git a/neural_speed/__init__.py b/neural_speed/__init__.py
index afc1a6100..a850c6097 100644
--- a/neural_speed/__init__.py
+++ b/neural_speed/__init__.py
@@ -83,7 +83,7 @@ def get_model_type(model_config):
             model_type = "chatglm2"
         return model_type
 
-    def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False,
+    def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False, use_autoround=False,
              weight_dtype="int4", alg="sym", group_size=32, scale_dtype="fp32",
              compute_dtype="int8", use_ggml=False):
         self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
@@ -107,6 +107,8 @@ def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False,
             quant_desc = "gptq"
         if use_awq:
             quant_desc = "awq"
+        if use_autoround:
+            quant_desc = "autoround"
         quant_bin = "{}/ne_{}_q_{}.bin".format(output_path, model_type, quant_desc)
 
         if not use_quant:
@@ -119,8 +121,8 @@ def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False,
                   format(self.bin_file))
             return
 
-        if use_gptq or use_awq:
-            convert_model(model_name, quant_bin, "f32")
+        if use_gptq or use_awq or use_autoround:
+            convert_model(model_name, quant_bin, use_quantized_model=True)
             return
 
         if not os.path.exists(fp32_bin):
diff --git a/neural_speed/convert/__init__.py b/neural_speed/convert/__init__.py
index da272ce32..9f063a5ec 100644
--- a/neural_speed/convert/__init__.py
+++ b/neural_speed/convert/__init__.py
@@ -22,12 +22,11 @@ model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder", "whisper": "whisper"}
 
 
-def convert_model(model, outfile, outtype, whisper_repo_path=None):
+def convert_model(model, outfile, outtype="f32", whisper_repo_path=None, use_quantized_model=False):
     config = AutoConfig.from_pretrained(model, trust_remote_code=True)
     model_type = model_maps.get(config.model_type, config.model_type)
 
-    quantized_model = 'gptq' in str(model).lower() or 'awq' in str(model).lower()
-    if quantized_model:
+    if use_quantized_model:
         path = Path(Path(__file__).parent.absolute(), "convert_quantized_{}.py".format(model_type))
     else:
         path = Path(Path(__file__).parent.absolute(), "convert_{}.py".format(model_type))
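
For reference, a minimal usage sketch of the new flags, assuming the `Model` wrapper exported by `neural_speed`; the model id and output file name below are placeholders, not part of the patch:

```python
# Sketch only: model ids and output paths are hypothetical placeholders.
from neural_speed import Model
from neural_speed.convert import convert_model

# High-level path: init() now maps use_autoround=True to the "autoround"
# quant_desc and routes through convert_model(..., use_quantized_model=True)
# instead of first dumping an f32 model and re-quantizing it.
model = Model()
model.init("Intel/Llama-2-7b-autoround",  # hypothetical pre-quantized checkpoint
           use_quant=True,
           use_autoround=True,
           weight_dtype="int4",
           group_size=32)

# Direct path: the explicit use_quantized_model flag replaces the old
# 'gptq'/'awq' substring sniffing on the model name.
convert_model("Intel/Llama-2-7b-autoround", "ne_llama_q_autoround.bin",
              use_quantized_model=True)
```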