Fix optimum compat (#3)
* add meta info

* cleanup

* cleanup

* The value of quantizer should be an array

* Update quantizer.py

* If is_auto_gptq_available(), also write "auto_gptq:version" to "quantizer"

* If is_auto_gptq_available(), also write "auto_gptq:version" to "quantizer"

* Update quantizer.py

* cleanup

* comment on meta

* hf_select_quant_linear pass checkpoint_format

* add todo fix

* move convert code to quantizer.save()

* Update quantizer.py

* Optimize hf_convert_gptq_v2_to_v1_format()

* Optimize hf_convert_gptq_v1_to_v2_format()

* fix GPTQTestCUDA

* hf_select_quant_linear() always sets pack=True

* gptqmodel.hf_select_quant_linear() now does not select ExllamaV2

* gptqmodel.hf_select_quant_linear() now does not select ExllamaV2

* GPTQQuantizer add backend

* lowercase checkpoint_format and backend

* cleanup

* move backend to bottom

* no need to check gptqmodel version for ipex support

* Update import_utils.py

* Update quantizer.py

* fix UnboundLocalError: cannot access local variable 'version' where it is not associated with a value

* make version var short

* Update import_utils.py

* fix unittest

* use assertLessEqual

---------

Co-authored-by: Qubitium-ModelCloud <[email protected]>
Co-authored-by: LRL <[email protected]>
3 people authored Dec 5, 2024
1 parent aa3d558 commit 5979473
Showing 3 changed files with 89 additions and 41 deletions.
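
For orientation, here is a minimal usage sketch of the new `checkpoint_format`, `meta`, and `backend` arguments added to `GPTQQuantizer` in this commit. The model name, dataset, and backend value are illustrative assumptions, not part of the change.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.gptq import GPTQQuantizer

# Placeholder model; any causal LM supported by optimum's GPTQ path would do.
model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

quantizer = GPTQQuantizer(
    bits=4,
    dataset="c4",
    group_size=128,
    sym=True,
    checkpoint_format="gptq",  # "gptq" (v1) is readable by both gptqmodel and auto-gptq
    backend="auto",            # gptqmodel kernel selection; auto-gptq only accepts None or "auto_trainable"
    meta=None,                 # tooling versions are appended automatically by to_dict()
)
quantized_model = quantizer.quantize_model(model, tokenizer)
quantizer.save(quantized_model, "opt-125m-gptq")
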
71 changes: 50 additions & 21 deletions optimum/gptq/quantizer.py
@@ -32,6 +32,7 @@
from .constants import GPTQ_CONFIG
from .data import get_dataset, prepare_dataset
from .utils import get_block_name_with_pattern, get_device, get_layers, get_preceding_modules, get_seqlen
from ..version import __version__ as optimum_version


if is_accelerate_available():
@@ -46,13 +47,15 @@
from auto_gptq.modeling._utils import autogptq_post_init as gptq_post_init
from auto_gptq.quantization import GPTQ
from auto_gptq.utils.import_utils import dynamically_import_QuantLinear as hf_select_quant_linear
from auto_gptq import __version__ as autogptq_version

if is_gptqmodel_available():
from gptqmodel import exllama_set_max_input_length
from gptqmodel.quantization import GPTQ
from gptqmodel.utils.importer import hf_select_quant_linear
from gptqmodel.utils.model import hf_convert_gptq_v1_to_v2_format, hf_convert_gptq_v2_to_v1_format
from gptqmodel.utils.model import hf_gptqmodel_post_init as gptq_post_init
from gptqmodel.version import __version__ as gptqmodel_version

logger = getLogger(__name__)

@@ -80,15 +83,17 @@ def __init__(
desc_act: bool = False,
sym: bool = True,
true_sequential: bool = True,
use_cuda_fp16: bool = False,
checkpoint_format: str = "gptq",
meta: Optional[Dict[str, any]] = None,
backend: Optional[str] = None,
use_cuda_fp16: bool = False,
model_seqlen: Optional[int] = None,
block_name_to_quantize: Optional[str] = None,
module_name_preceding_first_block: Optional[List[str]] = None,
batch_size: int = 1,
pad_token_id: Optional[int] = None,
disable_exllama: bool = False,
exllama_config: Dict[str, Any] = None,
exllama_config: Optional[Dict[str, Any]] = None,
max_input_length: Optional[int] = None,
cache_block_outputs: Optional[bool] = True,
modules_in_block_to_quantize: Optional[List[List[str]]] = None,
@@ -117,6 +122,14 @@ def __init__(
Whether to perform sequential quantization even within a single Transformer block.
Instead of quantizing the entire block at once, we perform layer-wise quantization.
As a result, each layer undergoes quantization using inputs that have passed through the previously quantized layers.
checkpoint_format (`str`, *optional*, defaults to `gptq`):
GPTQ weight format. `gptq` (v1) is supported by both gptqmodel and auto-gptq; `gptq_v2` is gptqmodel-only.
meta (`Dict[str, any]`, *optional*):
Properties, such as tooling:version, that do not directly contribute to quantization or quant inference are stored in meta.
e.g. `meta.quantizer`: ["optimum:_version_", "gptqmodel:_version_"]
backend (`str`, *optional*):
Controls which gptq kernel is used. Valid values for gptqmodel are `auto`, `auto_trainable` and more. For auto-gptq, the only
valid values are None and `auto_trainable`. Ref gptqmodel backends: https://github.com/ModelCloud/GPTQModel/blob/main/gptqmodel/utils/backend.py
use_cuda_fp16 (`bool`, defaults to `False`):
Whether or not to use optimized cuda kernel for fp16 model. Need to have model in fp16.
model_seqlen (`Optional[int]`, defaults to `None`):
@@ -152,6 +165,9 @@ def __init__(
self.desc_act = desc_act
self.sym = sym
self.true_sequential = true_sequential
self.checkpoint_format = checkpoint_format.lower()
self.meta = meta
self.backend = backend.lower() if backend is not None else None
self.use_cuda_fp16 = use_cuda_fp16
self.model_seqlen = model_seqlen
self.block_name_to_quantize = block_name_to_quantize
@@ -164,7 +180,6 @@ def __init__(
self.quant_method = QuantizationMethod.GPTQ
self.cache_block_outputs = cache_block_outputs
self.modules_in_block_to_quantize = modules_in_block_to_quantize
self.checkpoint_format = checkpoint_format

self.serialization_keys = [
"bits",
@@ -177,6 +192,7 @@ def __init__(
"quant_method",
"modules_in_block_to_quantize",
"checkpoint_format",
"meta",
]

if self.bits not in [2, 3, 4, 8]:
@@ -198,15 +214,17 @@ def __init__(
)
self.exllama_version = self.exllama_config["version"]

def select_quant_linear(self, pack: bool, device_map: Union[str, dict]):
def select_quant_linear(self, device_map: Union[str, dict]):
if is_gptqmodel_available():
self.quant_linear = hf_select_quant_linear(
bits=self.bits,
group_size=self.group_size,
desc_act=self.desc_act,
sym=self.sym,
checkpoint_format=self.checkpoint_format,
meta=self.meta,
device_map=device_map,
pack=pack,
backend=self.backend,
)
else:
self.quant_linear = hf_select_quant_linear(
@@ -225,6 +243,20 @@ def to_dict(self):
gptq_dict = {}
for key in self.serialization_keys:
gptq_dict[key] = getattr(self, key)

if gptq_dict.get("meta") is None:
gptq_dict["meta"] = {}

meta = gptq_dict["meta"]
# store both optimum:version and gptq_lib:version into quantize_config.meta.quantizer
if meta.get("quantizer") is None:
meta["quantizer"] = [f"optimum:{optimum_version}"]

if is_gptqmodel_available():
meta["quantizer"].append(f"gptqmodel:{gptqmodel_version}")
elif is_auto_gptq_available():
meta["quantizer"].append(f"auto_gptq:{autogptq_version}")

return gptq_dict

@classmethod
@@ -263,7 +295,7 @@ def convert_model(self, model: nn.Module, **kwargs):
)
del layers_to_be_replaced[name]

self.select_quant_linear(pack=False, device_map=kwargs.get("device_map", None))
self.select_quant_linear(device_map=kwargs.get("device_map", None))

self._replace_by_quant_layers(model, layers_to_be_replaced)

@@ -379,10 +411,7 @@ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):
gptq_supports_cpu = (
is_auto_gptq_available()
and version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
) or (
is_gptqmodel_available()
and version.parse(importlib.metadata.version("gptqmodel")) > version.parse("1.3.1")
)
) or is_gptqmodel_available()

if not gptq_supports_cpu and not torch.cuda.is_available():
raise RuntimeError(
@@ -663,18 +692,12 @@ def tmp(_, input, output):
# Step 5: Any post-initialization that require device information, for example buffers initialization on device.
model = self.post_init_model(model)

# convert gptqmodel internal gptq_v2 format to v1 for saving/compat
# sym=False is valid for gptq_v2 format only. for sym=True, need to convert to v1 before save.
if self.sym and self.checkpoint_format == "gptq_v2":
model = hf_convert_gptq_v2_to_v1_format(model, self.bits, self.quant_linear)
self.checkpoint_format = "gptq"

torch.cuda.empty_cache()
if hasattr(torch, "xpu"):
torch.xpu.empty_cache()
return model

def post_init_model(self, model, **kwargs):
def post_init_model(self, model):
"""
Post-initialization that require device information, for example buffers initialization on device.
@@ -695,8 +718,8 @@ def post_init_model(self, model, **kwargs):
class StoreAttr(object):
pass

if is_gptqmodel_available() and self.checkpoint_format == "gptq":
model = hf_convert_gptq_v1_to_v2_format(model, self.bits, self.quant_linear)
if is_gptqmodel_available():
model, _ = hf_convert_gptq_v1_to_v2_format(model, self.bits, self.quant_linear, self.checkpoint_format, self.meta)

model.quantize_config = StoreAttr()
model.quantize_config.desc_act = self.desc_act
@@ -727,7 +750,7 @@ def pack_model(
layers = get_layers(model)
layers = {n: layers[n] for n in quantizers}

self.select_quant_linear(pack=True, device_map=model.hf_device_map)
self.select_quant_linear(device_map=model.hf_device_map)

self._replace_by_quant_layers(model, quantizers)
qlayers = get_layers(model, [self.quant_linear])
@@ -765,6 +788,12 @@ def save(self, model: nn.Module, save_dir: str, max_shard_size: str = "10GB", sa
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
"""

# convert gptqmodel internal gptq_v2 format to v1 for max compatibility
model, converted = hf_convert_gptq_v2_to_v1_format(model, self.sym, self.bits, self.quant_linear, self.checkpoint_format, self.meta)
if converted:
self.checkpoint_format = "gptq"

os.makedirs(save_dir, exist_ok=True)
model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
with open(os.path.join(save_dir, GPTQ_CONFIG), "w", encoding="utf-8") as f:
@@ -871,7 +900,7 @@ def load_quantized_model(
quantizer.exllama_version = quantizer.exllama_config["version"]
quantizer.max_input_length = max_input_length

model = quantizer.convert_model(model)
model = quantizer.convert_model(model, device_map=device_map)

if no_split_module_classes is None:
no_split_module_classes = quantizer.get_no_split_module_classes(model)
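
To make the `to_dict()` change above concrete, here is a rough sketch of the quantization config it now emits when gptqmodel is installed; only a few keys are shown and the version strings are placeholders.

# Approximate shape of GPTQQuantizer.to_dict() after this commit (illustrative values only).
expected_config = {
    "bits": 4,
    "group_size": 128,
    "sym": True,
    "desc_act": False,
    "quant_method": "gptq",
    "checkpoint_format": "gptq",
    "meta": {
        # both optimum and the backing gptq library record their versions here
        "quantizer": ["optimum:1.24.0", "gptqmodel:1.4.0"],
    },
}
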
16 changes: 12 additions & 4 deletions optimum/utils/import_utils.py
@@ -52,6 +52,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
TRANSFORMERS_MINIMUM_VERSION = version.parse("4.25.0")
DIFFUSERS_MINIMUM_VERSION = version.parse("0.22.0")
AUTOGPTQ_MINIMUM_VERSION = version.parse("0.4.99") # Allows 0.5.0.dev0
GPTQMODEL_MINIMUM_VERSION = version.parse("1.3.99") # Allows 1.4.0.dev0


# This is the minimal required version to support some ONNX Runtime features
@@ -139,17 +140,24 @@ def is_datasets_available():

def is_auto_gptq_available():
if _auto_gptq_available:
version_autogptq = version.parse(importlib_metadata.version("auto_gptq"))
if AUTOGPTQ_MINIMUM_VERSION < version_autogptq:
v = version.parse(importlib_metadata.version("auto_gptq"))
if v >= AUTOGPTQ_MINIMUM_VERSION:
return True
else:
raise ImportError(
f"Found an incompatible version of auto-gptq. Found version {version_autogptq}, but only version above {AUTOGPTQ_MINIMUM_VERSION} are supported"
f"Found an incompatible version of auto-gptq. Found version {v}, but only version >= {AUTOGPTQ_MINIMUM_VERSION} are supported"
)


def is_gptqmodel_available():
return _gptqmodel_available
if _gptqmodel_available:
v = version.parse(importlib_metadata.version("gptqmodel"))
if v >= GPTQMODEL_MINIMUM_VERSION:
return True
else:
raise ImportError(
f"Found an incompatible version of gptqmodel. Found version {v}, but only version >= {GPTQMODEL_MINIMUM_VERSION} are supported"
)


@contextmanager
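
As a quick usage sketch of the stricter availability checks above: both helpers now return True for a compatible install, evaluate as false when the package is missing, and raise ImportError for an incompatible version. The fallback helper below is hypothetical, not part of the commit.

from optimum.utils.import_utils import is_auto_gptq_available, is_gptqmodel_available

def pick_gptq_backend() -> str:
    # Prefer gptqmodel, fall back to auto-gptq; fail loudly if neither compatible library is present.
    if is_gptqmodel_available():
        return "gptqmodel"
    if is_auto_gptq_available():
        return "auto_gptq"
    raise RuntimeError("Install gptqmodel >= 1.4.0 or auto-gptq >= 0.5.0 to use GPTQ quantization.")
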
43 changes: 27 additions & 16 deletions tests/gptq/test_quantization.py
@@ -55,6 +55,7 @@ class GPTQTest(unittest.TestCase):
bits = 4
group_size = 128
desc_act = False
sym = True
disable_exllama = True
exllama_config = None
cache_block_outputs = True
@@ -73,6 +74,7 @@ def setUpClass(cls):
"""

cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
cls.config = AutoConfig.from_pretrained(cls.model_name)

cls.model_fp16 = AutoModelForCausalLM.from_pretrained(
cls.model_name, torch_dtype=torch.float16, device_map=cls.device_map_for_quantization
@@ -87,6 +89,7 @@ def setUpClass(cls):
dataset=cls.dataset,
group_size=cls.group_size,
desc_act=cls.desc_act,
sym=cls.sym,
disable_exllama=cls.disable_exllama,
exllama_config=cls.exllama_config,
cache_block_outputs=cls.cache_block_outputs,
@@ -116,13 +119,20 @@ def test_quantized_layers_class(self):
"""

if is_gptqmodel_available():
if hasattr(self.config, "quantization_config"):
checkpoint_format = self.config.quantization_config.get("checkpoint_format")
meta = self.config.quantization_config.get("meta")
else:
checkpoint_format = "gptq"
meta = None
QuantLinear = hf_select_quant_linear(
bits=self.bits,
group_size=self.group_size,
desc_act=self.desc_act,
sym=True,
sym=self.sym,
device_map=self.device_map_for_quantization,
pack=False,
checkpoint_format=checkpoint_format,
meta=meta,
)
else:
QuantLinear = hf_select_quant_linear(
Expand All @@ -133,10 +143,10 @@ def test_quantized_layers_class(self):
disable_exllama=self.disable_exllama or self.exllama_config["version"] != 1,
disable_exllamav2=self.disable_exllama or self.exllama_config["version"] != 2,
)
self.assertTrue(self.quantized_model.model.layers[0].mlp.gate_proj.__class__ == QuantLinear)
self.assertEqual(self.quantized_model.model.layers[0].mlp.gate_proj.__class__, QuantLinear)

def check_quantized_layers_type(self, model, value):
self.assertTrue(model.model.layers[0].mlp.gate_proj.QUANT_TYPE == value)
self.assertEqual(model.model.layers[0].mlp.gate_proj.QUANT_TYPE, value)

def test_serialization(self):
"""
@@ -161,7 +171,7 @@ def test_serialization(self):
if is_auto_gptq_available() and not is_gptqmodel_available():
quant_type = "cuda-old" if self.disable_exllama else "exllama"
else:
quant_type = "ipex" if self.device_map_for_quantization == "cpu" else "cuda"
quant_type = "ipex" if self.device_map_for_quantization == "cpu" else "exllama"

self.check_quantized_layers_type(quantized_model_from_saved, quant_type)

@@ -179,16 +189,19 @@ def test_serialization(self):
class GPTQTestCUDA(GPTQTest):
device_map_for_quantization = "cuda"
device_for_inference = 0
expected_compression_ratio = 1.66
expected_compression_ratio = 1.2577
expected_fp16_perplexity = 38
expected_quantized_perplexity = 45


def test_perplexity(self):
"""
A simple test to check if the model conversion has been done correctly by checking
the perplexity of the converted models
"""

self.assertEqual(int(self.fp16_ppl), self.expected_fp16_perplexity)
self.assertEqual(int(self.quantized_ppl), self.expected_quantized_perplexity)
self.assertLessEqual(int(self.fp16_ppl), self.expected_fp16_perplexity)
self.assertLessEqual(int(self.quantized_ppl), self.expected_quantized_perplexity)


class GPTQTestExllama(GPTQTestCUDA):
@@ -199,6 +212,7 @@ class GPTQTestExllama(GPTQTestCUDA):
class GPTQTestActOrder(GPTQTestCUDA):
disable_exllama = True
desc_act = True
expected_quantized_perplexity = 46

def test_serialization(self):
# act_order don't work with qlinear_cuda kernel
@@ -282,7 +296,6 @@ def test_exllama_serialization(self):
"""
Test the serialization of the model and the loading of the quantized weights with exllamav2 kernel
"""

with tempfile.TemporaryDirectory() as tmpdirname:
self.quantizer.save(self.quantized_model, tmpdirname)
self.quantized_model.config.save_pretrained(tmpdirname)
@@ -296,16 +309,13 @@ def test_exllama_serialization(self):
save_folder=tmpdirname,
device_map={"": self.device_for_inference},
)
self.check_quantized_layers_type(quantized_model_from_saved, "exllamav2")
self.check_quantized_layers_type(quantized_model_from_saved, "exllama" if is_gptqmodel_available else "exllamav2")

# transformers and auto-gptq compatibility
# quantized models are more compatible with device map than
# device context managers (they're never used in transformers testing suite)
_ = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map={"": self.device_for_inference})
if is_gptqmodel_available():
_ = GPTQModel.load(tmpdirname, device_map={"": self.device_for_inference})
else:
_ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})
_ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})


class GPTQTestNoBlockCaching(GPTQTestCUDA):
Expand All @@ -318,11 +328,12 @@ class GPTQTestModuleQuant(GPTQTestCUDA):
["self_attn.q_proj"],
["mlp.gate_proj"],
]
expected_compression_ratio = 1.577
expected_compression_ratio = 1.068
expected_quantized_perplexity = 39

def test_not_converted_layers(self):
# self_attention.dense should not be converted
self.assertTrue(self.quantized_model.model.layers[0].self_attn.k_proj.__class__.__name__ == "Linear")
self.assertEqual(self.quantized_model.model.layers[0].self_attn.k_proj.__class__.__name__, "Linear")


class GPTQUtilsTest(unittest.TestCase):
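
Finally, since `save()` now converts gptqmodel's internal gptq_v2 weights back to the v1 `gptq` format, a checkpoint written through this path should load from either stack, which is what the serialization tests above exercise. A minimal sketch, reusing the placeholder directory from the earlier example:

from transformers import AutoModelForCausalLM
from optimum.utils.import_utils import is_gptqmodel_available

# "opt-125m-gptq" is the placeholder save directory used in the earlier sketch.
model = AutoModelForCausalLM.from_pretrained("opt-125m-gptq", device_map="auto")

if is_gptqmodel_available():
    from gptqmodel import GPTQModel
    model = GPTQModel.load("opt-125m-gptq", device_map={"": 0})
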
