Skip to content

Commit

Permalink
Merge pull request ztxz16#345 from siemonchan/master
Browse files Browse the repository at this point in the history
提升torch2flm转换时4bit模型的量化精度
  • Loading branch information
ztxz16 authored Oct 13, 2023
2 parents 2527d8f + 3397281 commit c98d016
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions tools/fastllm_pytools/torch2flm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,18 @@ def write_int8(fo, v):
fo.write(v.data)

def write_int4(fo, v):
c_min = np.expand_dims(-np.abs(v).max(axis = -1), -1)
c_max = np.expand_dims(np.abs(v).max(axis = -1), -1)
c_scale = c_max / 7.0
c_min = c_scale * -8.0
# c_min = np.expand_dims(-np.abs(v).max(axis = -1), -1)
# c_max = np.expand_dims(np.abs(v).max(axis = -1), -1)
# c_scale = c_max / 7.0
# c_min = c_scale * -8.0

c_min = np.expand_dims(v.min(axis = -1), -1)
c_max = np.expand_dims(v.max(axis = -1), -1)
c_scale = (c_max - c_min) / 15.0
c_zero = np.round(0.0 - c_min / c_scale)
c_zero = c_zero.clip(0, 15)
c_min = -c_scale * c_zero

v = (v - c_min) / c_scale
v = (v + 0.5).astype(np.int8).clip(0, 15).astype(np.uint8)
v = v[:, 0::2] * 16 + v[:, 1::2]
Expand Down

0 comments on commit c98d016

Please sign in to comment.