diff --git a/neural_speed/__init__.py b/neural_speed/__init__.py
index 069a1b9d3..b1e2a2551 100644
--- a/neural_speed/__init__.py
+++ b/neural_speed/__init__.py
@@ -77,7 +77,7 @@ def get_model_type(model_config):
             model_type = "chatglm2"
         return model_type
 
-    def init(self, model_name, use_quant=True, use_cache=False, use_gptq=False, use_awq=False,
+    def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False,
              weight_dtype="int4", alg="sym", group_size=32, scale_dtype="fp32", compute_dtype="int8", use_ggml=False):
         self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
@@ -107,14 +107,17 @@ def init(self, model_name, use_quant=True, use_cache=False, use_gptq=False, use_
             self.bin_file = fp32_bin
         else:
             self.bin_file = quant_bin
-        if use_cache and os.path.exists(self.bin_file):
+
+        if os.path.exists(self.bin_file):
+            print("{} existed, will use cache file. Otherwise please remove the file".
+                  format(self.bin_file))
             return
 
         if use_gptq or use_awq:
             convert_model(model_name, quant_bin, "f32")
             return
 
-        if not use_cache or not os.path.exists(fp32_bin):
+        if not os.path.exists(fp32_bin):
             convert_model(model_name, fp32_bin, "f32")
 
         assert os.path.exists(fp32_bin), "Fail to convert pytorch model"
@@ -127,8 +130,7 @@ def init(self, model_name, use_quant=True, use_cache=False, use_gptq=False, use_
         assert os.path.exists(quant_bin), "Fail to quantize model"
 
         # clean
-        if not use_cache:
-            os.remove(fp32_bin)
+        os.remove(fp32_bin)
 
     def init_from_bin(self, model_type, model_path, **generate_kwargs):
         self.__import_package(model_type)
diff --git a/scripts/cal_diff.py b/scripts/cal_diff.py
index 0af966299..a281c679c 100644
--- a/scripts/cal_diff.py
+++ b/scripts/cal_diff.py
@@ -35,10 +35,10 @@ def cmpData(numa, numb):
 args = parser.parse_args()
 
 woq_configs = {
-    "fp32": {"use_cache":True, "not_quant":True},
-    # "ggml_int4": {"compute_dtype":"int8", "weight_dtype":"int4", "use_cache":True, "use_ggml":True},
-    "jblas_int4": {"compute_dtype":"int8", "weight_dtype":"int4", "use_cache":True},
-    # "jblas_int8": {"compute_dtype":"bf16", "weight_dtype":"int8", "use_cache":True},
+    "fp32": {"not_quant":True},
+    # "ggml_int4": {"compute_dtype":"int8", "weight_dtype":"int4", "use_ggml":True},
+    "jblas_int4": {"compute_dtype":"int8", "weight_dtype":"int4"},
+    # "jblas_int8": {"compute_dtype":"bf16", "weight_dtype":"int8"},
 }
 prompt = "What is the meaning of life?"
diff --git a/scripts/perplexity.py b/scripts/perplexity.py
index 97bae5297..4a1d19c60 100644
--- a/scripts/perplexity.py
+++ b/scripts/perplexity.py
@@ -105,7 +105,7 @@ def perplexity(model_name, dataset_name, **kwargs):
     init_kwargs = {
         k: kwargs[k]
         for k in kwargs
-        if k in ['use_cache', 'compute_dtype', 'weight_dtype', 'scale_dtype', 'group_size', 'use_ggml']
+        if k in ['compute_dtype', 'weight_dtype', 'scale_dtype', 'group_size', 'use_ggml']
     }
 
     model.init(model_name, **init_kwargs)
@@ -186,7 +186,6 @@ def add_quant_args(parser: argparse.ArgumentParser):
                        type=str,
                        help="path to quantized weight; other quant args will be ignored if specified",
                        default="")
-    group.add_argument('--use_cache', action="store_true", help="Use local quantized model if file exists")
     group.add_argument(
         "--weight_dtype",
         choices=["int4", "int8"],
diff --git a/tests/test_python_api.py b/tests/test_python_api.py
index 0118b7266..6d71be919 100644
--- a/tests/test_python_api.py
+++ b/tests/test_python_api.py
@@ -53,7 +53,7 @@ def test_llm_runtime(self):
         print(tokenizer.decode(pt_generate_ids))
 
         # check output ids
-        woq_config_fp32 = {"use_quant":False, "compute_dtype":"fp32", "weight_dtype":"fp32", "use_cache":True, "use_ggml":False, "group_size":128}
+        woq_config_fp32 = {"use_quant":False, "compute_dtype":"fp32", "weight_dtype":"fp32", "use_ggml":False, "group_size":128}
         itrex_model = Model()
         itrex_model.init(model_name, use_quant=False)
@@ -65,10 +65,10 @@ def test_llm_runtime(self):
 
         # check diff of logits
         woq_configs = {
-            "fp32": {"use_cache":True, "use_quant":False},
-            # "ggml_int4": {"compute_dtype":"int8", "weight_dtype":"int4", "use_cache":True, "use_ggml":True},
-            "jblas_int4": {"compute_dtype":"int8", "weight_dtype":"int4", "use_cache":True},
-            # "jblas_int8": {"compute_dtype":"bf16", "weight_dtype":"int8", "use_cache":True},
+            "fp32": {"use_quant":False},
+            # "ggml_int4": {"compute_dtype":"int8", "weight_dtype":"int4", "use_ggml":True},
+            "jblas_int4": {"compute_dtype":"int8", "weight_dtype":"int4"},
+            # "jblas_int8": {"compute_dtype":"bf16", "weight_dtype":"int8"},
         }
         for config_type in woq_configs:
             itrex_model = Model()
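Note: with this change `Model.init()` no longer accepts `use_cache`; reuse of a previously converted/quantized binary is decided solely by whether that file already exists on disk, and callers remove the file to force a fresh conversion. A minimal sketch of the resulting call pattern (the import path and argument values are illustrative assumptions, not part of this diff; the diff itself only shows `Model()` being constructed directly in the tests):

    from neural_speed import Model  # assumed export path

    model = Model()
    # First run converts and quantizes the checkpoint; later runs find the existing
    # .bin file and return early with the "existed, will use cache file" message.
    # Delete the .bin file to rebuild it.
    model.init(model_name, use_quant=True, weight_dtype="int4", alg="sym",
               group_size=32, scale_dtype="fp32", compute_dtype="int8",
               use_ggml=False)

Here `model_name` stands for whatever Hugging Face model id the caller was already passing; the keyword arguments mirror the updated `init()` signature above.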