Skip to content
This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

Permalink
load processed model automatically
Browse files Browse the repository at this point in the history
Signed-off-by: zhenwei-intel <[email protected]>
  • Loading branch information
zhenwei-intel committed Jan 23, 2024
1 parent 51088a2 commit f22176f
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 16 deletions.
12 changes: 7 additions & 5 deletions neural_speed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def get_model_type(model_config):
model_type = "chatglm2"
return model_type

def init(self, model_name, use_quant=True, use_cache=False, use_gptq=False, use_awq=False,
def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False,
weight_dtype="int4", alg="sym", group_size=32,
scale_dtype="fp32", compute_dtype="int8", use_ggml=False):
self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
Expand Down Expand Up @@ -107,14 +107,17 @@ def init(self, model_name, use_quant=True, use_cache=False, use_gptq=False, use_
self.bin_file = fp32_bin
else:
self.bin_file = quant_bin
if use_cache and os.path.exists(self.bin_file):

if os.path.exists(self.bin_file):
print("{} existed, will use cache file. Otherwise please remove the file".
format(self.bin_file))
return

if use_gptq or use_awq:
convert_model(model_name, quant_bin, "f32")
return

if not use_cache or not os.path.exists(fp32_bin):
if not os.path.exists(fp32_bin):
convert_model(model_name, fp32_bin, "f32")
assert os.path.exists(fp32_bin), "Fail to convert pytorch model"

Expand All @@ -127,8 +130,7 @@ def init(self, model_name, use_quant=True, use_cache=False, use_gptq=False, use_
assert os.path.exists(quant_bin), "Fail to quantize model"

# clean
if not use_cache:
os.remove(fp32_bin)
os.remove(fp32_bin)

def init_from_bin(self, model_type, model_path, **generate_kwargs):
self.__import_package(model_type)
Expand Down
8 changes: 4 additions & 4 deletions scripts/cal_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def cmpData(numa, numb):
args = parser.parse_args()

woq_configs = {
"fp32": {"use_cache":True, "not_quant":True},
# "ggml_int4": {"compute_dtype":"int8", "weight_dtype":"int4", "use_cache":True, "use_ggml":True},
"jblas_int4": {"compute_dtype":"int8", "weight_dtype":"int4", "use_cache":True},
# "jblas_int8": {"compute_dtype":"bf16", "weight_dtype":"int8", "use_cache":True},
"fp32": {"not_quant":True},
# "ggml_int4": {"compute_dtype":"int8", "weight_dtype":"int4", "use_ggml":True},
"jblas_int4": {"compute_dtype":"int8", "weight_dtype":"int4"},
# "jblas_int8": {"compute_dtype":"bf16", "weight_dtype":"int8"},
}
prompt = "What is the meaning of life?"

Expand Down
3 changes: 1 addition & 2 deletions scripts/perplexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def perplexity(model_name, dataset_name, **kwargs):
init_kwargs = {
k: kwargs[k]
for k in kwargs
if k in ['use_cache', 'compute_dtype', 'weight_dtype', 'scale_dtype', 'group_size', 'use_ggml']
if k in ['compute_dtype', 'weight_dtype', 'scale_dtype', 'group_size', 'use_ggml']
}
model.init(model_name, **init_kwargs)

Expand Down Expand Up @@ -186,7 +186,6 @@ def add_quant_args(parser: argparse.ArgumentParser):
type=str,
help="path to quantized weight; other quant args will be ignored if specified",
default="")
group.add_argument('--use_cache', action="store_true", help="Use local quantized model if file exists")
group.add_argument(
"--weight_dtype",
choices=["int4", "int8"],
Expand Down
10 changes: 5 additions & 5 deletions tests/test_python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_llm_runtime(self):
print(tokenizer.decode(pt_generate_ids))

# check output ids
woq_config_fp32 = {"use_quant":False, "compute_dtype":"fp32", "weight_dtype":"fp32", "use_cache":True, "use_ggml":False, "group_size":128}
woq_config_fp32 = {"use_quant":False, "compute_dtype":"fp32", "weight_dtype":"fp32", "use_ggml":False, "group_size":128}
itrex_model = Model()

itrex_model.init(model_name, use_quant=False)
Expand All @@ -65,10 +65,10 @@ def test_llm_runtime(self):

# check diff of logits
woq_configs = {
"fp32": {"use_cache":True, "use_quant":False},
# "ggml_int4": {"compute_dtype":"int8", "weight_dtype":"int4", "use_cache":True, "use_ggml":True},
"jblas_int4": {"compute_dtype":"int8", "weight_dtype":"int4", "use_cache":True},
# "jblas_int8": {"compute_dtype":"bf16", "weight_dtype":"int8", "use_cache":True},
"fp32": {"use_quant":False},
# "ggml_int4": {"compute_dtype":"int8", "weight_dtype":"int4", "use_ggml":True},
"jblas_int4": {"compute_dtype":"int8", "weight_dtype":"int4"},
# "jblas_int8": {"compute_dtype":"bf16", "weight_dtype":"int8"},
}
for config_type in woq_configs:
itrex_model = Model()
Expand Down

0 comments on commit f22176f

Please sign in to comment.