This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Refact convert scripts #35

Closed · wants to merge 7 commits
25 changes: 25 additions & 0 deletions neural_speed/__init__.py
@@ -128,6 +128,31 @@ def init(self, model_name, use_quant=True, use_cache=False, use_gptq=False, use_
if not use_cache:
os.remove(fp32_bin)

def init2(self, model_name, not_quant=False, use_cache=False, use_gptq=False, use_awq=False,
weight_dtype="int4", alg="sym", group_size=32,
scale_dtype="fp32", compute_dtype="int8", use_ggml=False):
self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model_type = Model.get_model_type(self.config)
self.__import_package(model_type)

# check cache and quantization
output_path = "runtime_outs"
os.makedirs(output_path, exist_ok=True)
quant_bin = "{}/ne_{}_q.bin".format(output_path, model_type)
self.bin_file = quant_bin

if use_cache and os.path.exists(self.bin_file):
return

from neural_speed.convert import convert_model
from neural_speed.convert.common import QuantConfig
quant_config = QuantConfig(weight_dtype=weight_dtype, alg=alg, group_size=group_size,
scale_dtype=scale_dtype, compute_dtype=compute_dtype, use_ggml=use_ggml,
not_quant=not_quant, use_gptq=use_gptq, use_awq=use_awq)
convert_model(model_name, quant_bin, quant_config)
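A minimal usage sketch of the new `init2` path (illustrative only: the model id is a placeholder, and `Model` is the class this method lives on in `neural_speed/__init__.py`):

```python
from neural_speed import Model

model = Model()
# init2 converts and quantizes straight to runtime_outs/ne_<model_type>_q.bin
# via QuantConfig, skipping the separate fp32-bin step that init() uses.
model.init2("meta-llama/Llama-2-7b-hf", weight_dtype="int4", alg="sym",
            group_size=32, scale_dtype="fp32", compute_dtype="int8")
```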


def init_from_bin(self, model_type, model_path, **generate_kwargs):
self.__import_package(model_type)
self.model = self.module.Model()
20 changes: 6 additions & 14 deletions neural_speed/convert/__init__.py
@@ -18,24 +18,16 @@
from pathlib import Path
from transformers import AutoConfig
import subprocess
from .convert_llama import convert_llama
from .convert_gptj import convert_gptj
from .convert_chatglm import convert_chatglm

model_maps = {"gpt_neox": "gptneox", "gpt_bigcode": "starcoder", "whisper": "whisper"}


def convert_model(model, outfile, outtype, whisper_repo_path=None):
def convert_model(model, outfile, quant_config, whisper_repo_path=None):
config = AutoConfig.from_pretrained(model, trust_remote_code=True)
model_type = model_maps.get(config.model_type, config.model_type)

quantized_model = 'gptq' in str(model).lower() or 'awq' in str(model).lower()
if quantized_model:
path = Path(Path(__file__).parent.absolute(), "convert_quantized_{}.py".format(model_type))
else:
path = Path(Path(__file__).parent.absolute(), "convert_{}.py".format(model_type))
cmd = []
cmd.extend(["python", path])
cmd.extend(["--outfile", outfile])
cmd.extend(["--outtype", outtype])
cmd.extend([model])

print("cmd:", cmd)
subprocess.run(cmd)
func = eval(f"convert_{model_type}")
func(model, outfile, quant_config)
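With the subprocess indirection gone, `convert_model` can be driven directly; a sketch, assuming a placeholder model id and output path:

```python
from neural_speed.convert import convert_model
from neural_speed.convert.common import QuantConfig

# Mirrors the config Model.init2 builds above; all values are that signature's defaults.
cfg = QuantConfig(weight_dtype="int4", alg="sym", group_size=32,
                  scale_dtype="fp32", compute_dtype="int8", use_ggml=False,
                  not_quant=False, use_gptq=False, use_awq=False)
convert_model("meta-llama/Llama-2-7b-hf", "ne_llama_q.bin", cfg)
```

Note that `eval(f"convert_{model_type}")` resolves against the converters imported at the top of the module, so only model types with a matching `convert_*` import can be dispatched.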
161 changes: 88 additions & 73 deletions neural_speed/convert/common.py
@@ -21,9 +21,11 @@
import struct
import json
import warnings
from dataclasses import dataclass
from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List,
Literal, Optional, Sequence, Tuple, TypeVar, Union)
from sentencepiece import SentencePieceProcessor # type: ignore
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM

GGML_QK8_0 = 32
GGML_QK4_0 = 32
@@ -33,7 +33,41 @@

GGML_QK4_0_TYPE = 2
GGML_QK4_1_TYPE = 3
GGML_QJBLAS_TYPE = 13
GGML_QJBLAS_TYPE = 19

@dataclass
class QuantConfig:
weight_dtype: str
alg: str
group_size: int
scale_dtype: str
compute_dtype: str
use_ggml: bool
not_quant: bool
use_gptq: bool
use_awq: bool

# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
The tables also avoid mapping to whitespace/control characters that the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
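A quick round-trip check of the mapping (illustrative, not part of the diff):

```python
from neural_speed.convert.common import bytes_to_unicode

byte_encoder = bytes_to_unicode()
byte_decoder = {v: k for k, v in byte_encoder.items()}

raw = "héllo".encode("utf-8")
# Every byte 0..255 maps to a distinct printable unicode character...
encoded = "".join(byte_encoder[b] for b in raw)
# ...and the mapping inverts losslessly.
assert bytes(byte_decoder[c] for c in encoded).decode("utf-8") == "héllo"
```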

def quantize_q4_0(tensor: torch.Tensor) -> torch.CharTensor:
# equivalent to ggml_quantize_q4_0 in ggml.c
@@ -216,12 +252,19 @@ def unpack_awq_weight(qweight, scales, qzeros, q_config)

return weight, scales, zeros

def write_header(fout, shape, dst_name, ftype_cur):
def write_header(fout, shape, dst_name, ftype_cur, align=True):
sname = dst_name.encode('utf-8')
fout.write(struct.pack("iii", len(shape), len(sname), ftype_cur))
fout.write(struct.pack("i" * len(shape), *shape[::-1]))
fout.write(sname)
fout.seek((fout.tell() + 31) & -32)
if align:
fout.seek((fout.tell() + 31) & -32)
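The `(tell + 31) & -32` idiom rounds the current file offset up to the next 32-byte boundary, presumably so tensor data starts aligned; a small sanity check, for illustration:

```python
# (x + 31) & -32 == smallest multiple of 32 that is >= x
for offset, aligned in [(0, 0), (1, 32), (32, 32), (100, 128)]:
    assert (offset + 31) & -32 == aligned
```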


def load_hf_model(model_path):
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
return model, config, None


def find_quantized_model_file(model_path):
@@ -245,30 +288,64 @@ def load_quantized_model(model_path)
else:
print("unknown input model path, only support .safetensors or .pt file.")

with open(model_path + '/config.json', "r", encoding="utf-8") as f:
config = json.load(f)

config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if not isinstance(config, dict):
config = config.to_dict()
quantize_config = config["quantization_config"]
if "zero_point" in quantize_config:
quantize_config["sym"] = not quantize_config["zero_point"]
return model, config, config["quantization_config"]
return model, config, quantize_config


def convert_fp32_tensor(src_name, dst_name, model, fout):
def convert_fp32_to_q4_tensor(src_name, dst_name, model, fout, n_head=0, n_head2=0, permute_func=None):
if ".weight" not in src_name:
src_name = src_name + ".weight"
v = model[src_name]
shape = v.shape
# print("Processing non-Q4 variable: " + src_name +
# " with shape: ", shape, " and type: ", v.dtype)
v = v.to(torch.float32)

ftype_cur = {torch.float16: 1, torch.float32: 0}[v.dtype]
if permute_func:
v = permute_func(v, n_head, n_head2).contiguous()

qv = quantize_q4_0(v)
ftype_cur = GGML_QK4_0_TYPE

# header
write_header(fout, shape, dst_name, ftype_cur)

# data
v.numpy().tofile(fout)
print(f"converting {dst_name} float tensor")
qv.numpy().tofile(fout)
print(f"converting {dst_name} float to q4_0 tensor")


def convert_fp32_to_jblas_tensor(src_name, dst_name, model, fout, n_head=0, n_head2=0, permute_func=None):
import neural_speed.llama_cpp as cpp_model
if ".weight" not in src_name:
src_name = src_name + ".weight"
v = model[src_name]
shape = v.shape
v = v.to(torch.float32)

if permute_func:
v = permute_func(v, n_head, n_head2).contiguous()

ftype_cur = GGML_QJBLAS_TYPE

# header
write_header(fout, shape, dst_name, ftype_cur)

# pack int weight in bestla format
dst = np.zeros((v.shape[0], v.shape[1] * 4), dtype=np.int8)
byte_size = cpp_model.Model.np_bestla_quantize(v.numpy(), dst,
weight_dtype="int4",
group_size=32,
alg="sym",
compute_dtype="int8")
dst.flatten()[:byte_size].tofile(fout)
# data
print(f"converting {dst_name} float to jblas tensor")

def convert_q4_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head2=0, permute_func=None):
qzeros = model[f"{src_name}.qzeros"]
@@ -352,65 +429,3 @@ def convert_q4_f32_tensor(src_name, dst_name, model, fout, q_config, n_head, n_h

print(f"converting {dst_name} qauntized tensor to fp32 tensor")


def convert_q4_bestla_tensor(src_name, dst_name, model, fout, q_config, n_head, n_head_kv=0, permute_func=None):
# unpack weight and repack into jblas format
import neural_speed.llama_cpp as cpp_model
qzeros = model[f"{src_name}.qzeros"]
zeros = qzeros_to_zeros(qzeros)
scales = model[f"{src_name}.scales"]
qweight = model[f"{src_name}.qweight"]

int_weight, gptq_scales, gptq_zeros = unpack_weight(qweight, scales, qzeros, q_config)
int_weight = int_weight.view(-1, int_weight.shape[-1])

# permute_func for llama-like model
if permute_func:
int_weight = permute_func(int_weight.t(), n_head, n_head_kv).t().contiguous()
gptq_scales = permute_func(gptq_scales.t(), n_head, n_head_kv).t().contiguous()
gptq_zeros = permute_func(gptq_zeros.t(), n_head, n_head_kv).t().contiguous()

# shuffle weight in GPTQ when act order is on
if 'desc_act' in q_config and q_config['desc_act']:
g_idx = model[f"{src_name}.g_idx"]
int_weight2 = int_weight.clone()
group_size = q_config['group_size']
group_dict = {}
for i in range(len(g_idx)):
group_idx = g_idx[i].item()
if group_idx not in group_dict:
target_idx = group_idx * group_size
group_dict[group_idx] = 0
else:
group_dict[group_idx] = group_dict[group_idx] + 1
target_idx = group_idx * group_size + group_dict[group_idx]
int_weight2[target_idx] = int_weight[i]
int_weight = int_weight2

shape = int_weight.shape
write_header(fout, shape[::-1], dst_name, GGML_QJBLAS_TYPE)

if q_config['bits'] == 4:
int_weight = (int_weight - 8) * 16
gptq_scales = gptq_scales / 16
gptq_zeros = (gptq_zeros - 8) * 16
dst = np.zeros((int_weight.shape[0], int_weight.shape[1] * 4), dtype=np.int8)
int_weight = np.ascontiguousarray(int_weight.numpy())
gptq_scales = np.ascontiguousarray((gptq_scales.float()).numpy())
if q_config['sym']:
gptq_zeros = np.empty(0, dtype=np.int8)
else:
gptq_zeros = np.ascontiguousarray(gptq_zeros.numpy())
if 'desc_act' in q_config and q_config['desc_act']:
g_idx = np.ascontiguousarray(g_idx.numpy())
else:
g_idx = np.empty(0, dtype=np.int32)

# pack int weight in bestla format
byte_size = cpp_model.Model.np_bestla_qpack(int_weight, gptq_scales, gptq_zeros, g_idx, dst,
weight_dtype="int4" if q_config['bits'] == 4 else "int8",
group_size=q_config['group_size'],
alg="sym" if q_config['sym'] else "asym",
compute_dtype="int8")
dst.flatten()[:byte_size].tofile(fout)
print(f"converting {dst_name} qauntized tensor to bestla q4 block")