From 21de42f05c297e6d165def90a5db95d5637b6d6c Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 19 Dec 2024 18:30:09 +0800 Subject: [PATCH] Enable GPTQModel (#2064) * align gptq check to transformers for supporting cpu * fix comment * gptqmodel Signed-off-by: jiqing-feng * compatible with auto-gptq Signed-off-by: jiqing-feng * fix compatible with auto-gptq Signed-off-by: jiqing-feng * fix compatible with auto-gptq linear Signed-off-by: jiqing-feng * revert unrelated changes Signed-off-by: jiqing-feng * gptqmodel need use checkpoint_format (#1) * need checkpoint_format * default value of checkpoint_format is gptq * fix quantize * fix quantize * fix quantize * Update quantizer.py * need convert to v1 before gptqmodel save * back checkpoint_format to gptq after convert * cleanup code * sym=False is not supported with auto-gptq * add comments * cleanup code * Update quantizer.py * always convert v2 to v1 if checkpoint_format = "gptq" * Update quantizer.py --------- Co-authored-by: ZX-ModelCloud Co-authored-by: Qubitium-ModelCloud * Mod backend code (#2) * keep gptq_v2 if sym is false * use hf_convert_gptq_v1_to_v2_format, hf_convert_gptq_v2_to_v1_format, and hf_gptqmodel_post_init * no need check backend * use device_map * cleanup * Update quantizer.py * move import --------- Co-authored-by: Qubitium-ModelCloud * fix format and log Signed-off-by: jiqing-feng * fix version check Signed-off-by: jiqing-feng * enable gptqmodel tests Signed-off-by: jiqing-feng * update check quant type Signed-off-by: jiqing-feng * Fix optimum compat (#3) * add meta info * cleanup * cleanup * The value of quantizer should be an array * Update quantizer.py * If is_auto_gptq_available() also writes "auto_gptq:version" to "quantizer" * If is_auto_gptq_available() also writes "auto_gptq:version" to "quantizer" * Update quantizer.py * cleanup * comment on meta * hf_select_quant_linear pass checkpoint_format * add todo fix * move convert code to quantizer.save() * Update quantizer.py * Optimize hf_convert_gptq_v2_to_v1_format() * Optimize hf_convert_gptq_v1_to_v2_format() * fix GPTQTestCUDA * hf_select_quant_linear() always set pack=True * gptqmodel.hf_select_quant_linear() now does not select ExllamaV2 * gptqmodel.hf_select_quant_linear() now does not select ExllamaV2 * GPTQQuantizer add backend * lower checkpoint_format and backend * cleanup * move backend to bottom * no need to check gptqmodel version for ipex support * Update import_utils.py * Update quantizer.py * fix UnboundLocalError: cannot access local variable 'version' where it is not associated with a value * make version var short * Update import_utils.py * fix unittest * use assertLessEqual --------- Co-authored-by: Qubitium-ModelCloud Co-authored-by: LRL * fix format and convert v2 to v1 Signed-off-by: jiqing-feng * [Fix] all tensors not same device (#5) * fix device error * update gptqmodel version * fix test * fix format Signed-off-by: jiqing-feng * add gptqmodel tests which contains cpu Signed-off-by: jiqing-feng * fix all auto-gptq tests Signed-off-by: jiqing-feng * revert tests Signed-off-by: jiqing-feng * rm gptqmodel yaml Signed-off-by: jiqing-feng * fix comment Signed-off-by: jiqing-feng * enable real cpu tests by fp32 Signed-off-by: jiqing-feng * fix test model name Signed-off-by: jiqing-feng * keep the original device setting when using auto-gptq Signed-off-by: jiqing-feng * Update optimum/gptq/quantizer.py Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> * Update optimum/gptq/quantizer.py Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> --------- Signed-off-by: jiqing-feng Co-authored-by: LRL-ModelCloud <165116337+LRL-ModelCloud@users.noreply.github.com> Co-authored-by: ZX-ModelCloud Co-authored-by: Qubitium-ModelCloud Co-authored-by: ZX-ModelCloud <165115237+ZX-ModelCloud@users.noreply.github.com> Co-authored-by: LRL Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> --- optimum/gptq/quantizer.py | 253 ++++++++++++++++++++++++++-------- optimum/gptq/utils.py | 15 ++ optimum/utils/__init__.py | 1 + optimum/utils/import_utils.py | 19 ++- 4 files changed, 227 insertions(+), 61 deletions(-) diff --git a/optimum/gptq/quantizer.py b/optimum/gptq/quantizer.py index 849d8821ebf..844da3e3157 100644 --- a/optimum/gptq/quantizer.py +++ b/optimum/gptq/quantizer.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib import json import os from enum import Enum @@ -19,17 +20,26 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch +from packaging import version from torch import nn from tqdm.auto import tqdm from transformers import AutoTokenizer from transformers.pytorch_utils import Conv1D from transformers.utils.quantization_config import QuantizationMethod -from ..utils import is_accelerate_available, is_auto_gptq_available +from ..utils import is_accelerate_available, is_auto_gptq_available, is_gptqmodel_available from ..utils.modeling_utils import recurse_getattr +from ..version import __version__ as optimum_version from .constants import GPTQ_CONFIG from .data import get_dataset, prepare_dataset -from .utils import get_block_name_with_pattern, get_device, get_layers, get_preceding_modules, get_seqlen +from .utils import ( + get_block_name_with_pattern, + get_device, + get_layers, + get_preceding_modules, + get_seqlen, + nested_move_to, +) if is_accelerate_available(): @@ -40,14 +50,27 @@ from accelerate.hooks import remove_hook_from_module if is_auto_gptq_available(): + from auto_gptq import __version__ as autogptq_version from auto_gptq import exllama_set_max_input_length - from auto_gptq.modeling._utils import autogptq_post_init + from auto_gptq.modeling._utils import autogptq_post_init as gptq_post_init from auto_gptq.quantization import GPTQ - from auto_gptq.utils.import_utils import dynamically_import_QuantLinear + from auto_gptq.utils.import_utils import dynamically_import_QuantLinear as hf_select_quant_linear + +if is_gptqmodel_available(): + from gptqmodel import exllama_set_max_input_length + from gptqmodel.quantization import GPTQ + from gptqmodel.utils.importer import hf_select_quant_linear + from gptqmodel.utils.model import hf_convert_gptq_v1_to_v2_format, hf_convert_gptq_v2_to_v1_format + from gptqmodel.utils.model import hf_gptqmodel_post_init as gptq_post_init + from gptqmodel.version import __version__ as gptqmodel_version logger = getLogger(__name__) +def has_device_more_than_cpu(): + return torch.cuda.is_available() or (hasattr(torch, "xpu") and torch.xpu.is_available()) + + class ExllamaVersion(int, Enum): ONE = 1 TWO = 2 @@ -74,10 +97,13 @@ def __init__( batch_size: int = 1, pad_token_id: Optional[int] = None, disable_exllama: bool = False, - exllama_config: Dict[str, Any] = None, + exllama_config: Optional[Dict[str, Any]] = None, max_input_length: Optional[int] = None, cache_block_outputs: Optional[bool] = True, modules_in_block_to_quantize: Optional[List[List[str]]] = None, + checkpoint_format: str = "gptq", + meta: Optional[Dict[str, any]] = None, + backend: Optional[str] = None, *args, **kwargs, ): @@ -129,6 +155,13 @@ def __init__( List list of module names to quantize in the block specified. This argument is useful to exclude certain linear modules from being quantized. The block to quantize can be specified by setting `block_name_to_quantize`. We will quantize each list sequentially. If not set, we will quantize all linear layers. Example: `inside_layer_modules=[["self_attention.query_key_value"], ["mlp.dense_h_to_4h"]]` + checkpoint_format (`str`, *optional*, defaults to `gptq`): + GPTQ weight format. `gptq`(v1) is supported by both gptqmodel and auto-gptq. `gptq_v2` is gptqmodel only. + meta (`Dict[str, any]`, *optional*): + Properties, such as tooling:version, that do not directly contributes to quantization or quant inference are stored in meta. + i.e. `meta.quantizer`: ["optimum:_version_", "gptqmodel:_version_"] + backend (`str`, *optional*): + Controls which gptq kernel to be used. Valid values for gptqmodel are `auto`, `auto_trainable` and more. For auto-gptq, only valid value is None and `auto_trainable`. Ref gptqmodel backends: https://github.com/ModelCloud/GPTQModel/blob/main/gptqmodel/utils/backend.py """ self.bits = bits @@ -138,6 +171,9 @@ def __init__( self.desc_act = desc_act self.sym = sym self.true_sequential = true_sequential + self.checkpoint_format = checkpoint_format.lower() + self.meta = meta + self.backend = backend.lower() if backend is not None else None self.use_cuda_fp16 = use_cuda_fp16 self.model_seqlen = model_seqlen self.block_name_to_quantize = block_name_to_quantize @@ -161,6 +197,8 @@ def __init__( "true_sequential", "quant_method", "modules_in_block_to_quantize", + "checkpoint_format", + "meta", ] if self.bits not in [2, 3, 4, 8]: @@ -182,6 +220,28 @@ def __init__( ) self.exllama_version = self.exllama_config["version"] + def select_quant_linear(self, device_map: Union[str, dict]): + if is_gptqmodel_available(): + self.quant_linear = hf_select_quant_linear( + bits=self.bits, + group_size=self.group_size, + desc_act=self.desc_act, + sym=self.sym, + checkpoint_format=self.checkpoint_format, + meta=self.meta, + device_map=device_map, + backend=self.backend, + ) + else: + self.quant_linear = hf_select_quant_linear( + use_triton=False, + desc_act=self.desc_act, + group_size=self.group_size, + bits=self.bits, + disable_exllama=self.disable_exllama or self.exllama_version != ExllamaVersion.ONE, + disable_exllamav2=self.disable_exllama or self.exllama_version != ExllamaVersion.TWO, + ) + def to_dict(self): """ Returns the args in dict format. @@ -189,6 +249,20 @@ def to_dict(self): gptq_dict = {} for key in self.serialization_keys: gptq_dict[key] = getattr(self, key) + + if gptq_dict.get("meta") is None: + gptq_dict["meta"] = {} + + meta = gptq_dict["meta"] + # store both optimum:version and gptq_lib:version into quantize_config.meta.quantizer + if meta.get("quantizer") is None: + meta["quantizer"] = [f"optimum:{optimum_version}"] + + if is_gptqmodel_available(): + meta["quantizer"].append(f"gptqmodel:{gptqmodel_version}") + elif is_auto_gptq_available(): + meta["quantizer"].append(f"auto_gptq:{autogptq_version}") + return gptq_dict @classmethod @@ -205,7 +279,7 @@ def from_dict(cls, config_dict: Dict[str, Any]): """ return cls(**config_dict) - def convert_model(self, model: nn.Module): + def convert_model(self, model: nn.Module, **kwargs): """ Convert the model to a GPTQ model by getting and replacing the layers. @@ -226,7 +300,11 @@ def convert_model(self, model: nn.Module): f"Quantization disabled for {name} (only modules_in_block_to_quantize={self.modules_in_block_to_quantize} are quantized)" ) del layers_to_be_replaced[name] + + self.select_quant_linear(device_map=kwargs.get("device_map", None)) + self._replace_by_quant_layers(model, layers_to_be_replaced) + return model def get_no_split_module_classes(self, model): @@ -253,15 +331,7 @@ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: st name (`str`, defaults to `""`): To keep track of the name of the current module """ - QuantLinear = dynamically_import_QuantLinear( - use_triton=False, - desc_act=self.desc_act, - group_size=self.group_size, - bits=self.bits, - disable_exllama=self.disable_exllama or self.exllama_version != ExllamaVersion.ONE, - disable_exllamav2=self.disable_exllama or self.exllama_version != ExllamaVersion.TWO, - ) - if isinstance(module, QuantLinear): + if isinstance(module, self.quant_linear): return for attr in dir(module): layer = getattr(module, attr) @@ -279,20 +349,37 @@ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: st in_features = layer.weight.shape[0] out_features = layer.weight.shape[1] bias = layer.bias is not None - if not (self.desc_act) or self.group_size == -1: - new_layer = QuantLinear( + if is_gptqmodel_available(): + new_layer = self.quant_linear( self.bits, self.group_size, + self.desc_act, + self.sym, in_features, out_features, bias, - use_cuda_fp16=self.use_cuda_fp16, weight_dtype=layer.weight.dtype, ) else: - new_layer = QuantLinear( - self.bits, self.group_size, in_features, out_features, bias, weight_dtype=layer.weight.dtype - ) + if not (self.desc_act) or self.group_size == -1: + new_layer = self.quant_linear( + self.bits, + self.group_size, + in_features, + out_features, + bias, + use_cuda_fp16=self.use_cuda_fp16, + weight_dtype=layer.weight.dtype, + ) + else: + new_layer = self.quant_linear( + self.bits, + self.group_size, + in_features, + out_features, + bias, + weight_dtype=layer.weight.dtype, + ) new_layer.device = device setattr(module, attr, new_layer.to(device)) for name1, child in module.named_children(): @@ -318,13 +405,41 @@ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None): `nn.Module`: The quantized model """ - if not is_auto_gptq_available(): - raise RuntimeError("auto-gptq is required in order to perform quantzation : `pip install auto-gptq`") - if not torch.cuda.is_available(): - raise RuntimeError("No GPU found. A GPU is needed to quantize model.") + if not is_auto_gptq_available() and not is_gptqmodel_available(): + raise RuntimeError( + "gptqmodel or auto-gptq is required in order to perform gptq quantzation: `pip install gptqmodel` or `pip install auto-gptq`. Please notice that auto-gptq will be deprecated in the future." + ) + elif is_gptqmodel_available() and is_auto_gptq_available(): + logger.warning( + "Detected gptqmodel and auto-gptq, will use gptqmodel. The auto_gptq will be deprecated in the future." + ) + + gptq_supports_cpu = ( + is_auto_gptq_available() + and version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2") + ) or is_gptqmodel_available() + + if not gptq_supports_cpu and not torch.cuda.is_available(): + raise RuntimeError( + "No cuda gpu or cpu support using Intel/IPEX found. A gpu or cpu with Intel/IPEX is required for quantization." + ) + + if not self.sym and not is_gptqmodel_available(): + raise ValueError( + "Asymmetric sym=False quantization is not supported with auto-gptq. Please use gptqmodel: `pip install gptqmodel`" + ) + + if self.checkpoint_format == "gptq_v2" and not is_gptqmodel_available(): + raise ValueError( + "gptq_v2 format only supported with gptqmodel. Please install gptqmodel: `pip install gptqmodel`" + ) model.eval() + # gptqmodel internal is gptq_v2 for asym support, gptq(v1) can only support sym=True + if is_gptqmodel_available() and self.checkpoint_format != "gptq_v2": + self.checkpoint_format = "gptq_v2" + # For Transformer model has_config = False has_device_map = False @@ -403,27 +518,32 @@ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None): blocks = recurse_getattr(model, self.block_name_to_quantize) + cur_layer_device = get_device(blocks[0]) + if not is_gptqmodel_available(): + cur_layer_device = 0 + if not has_device_map: - # put modules from module_name_preceding_first_block on cuda + # put modules from module_name_preceding_first_block on cuda or xpu or cpu + to_device = cur_layer_device for module_name in self.module_name_preceding_first_block: module = recurse_getattr(model, module_name) if module is None: raise ValueError(f"Module {module_name} was not found in model") - module = module.to(0) - blocks[0] = blocks[0].to(0) + module = module.to(to_device) + blocks[0] = blocks[0].to(to_device) def store_input_hook(_, input, *args): kwargs = args[0] if input is None: if "hidden_states" in kwargs: - input = (kwargs["hidden_states"],) + input = (nested_move_to(kwargs["hidden_states"], cur_layer_device),) else: raise ValueError("No input value found in the foward pass") layer_inputs.append(input) other_kwargs = {} for k, v in kwargs.items(): # make sure other arguments also be captured if k not in ["hidden_states"]: - other_kwargs[k] = v + other_kwargs[k] = nested_move_to(v, cur_layer_device) layer_input_kwargs.append(other_kwargs) raise ValueError @@ -431,11 +551,7 @@ def store_input_hook(_, input, *args): handle = blocks[0].register_forward_pre_hook(store_input_hook, with_kwargs=True) for data in dataset: for k, v in data.items(): - # put the data on gpu, we won't put them back to cpu - if not has_device_map or device.type == "cpu": - data[k] = v.to(0) - else: - data[k] = v.to(device) + data[k] = nested_move_to(v, cur_layer_device) try: model(**data) except ValueError: @@ -450,6 +566,8 @@ def store_input_hook(_, input, *args): raise ValueError(f"Module {module_name} was not found in model") torch.cuda.empty_cache() + if hasattr(torch, "xpu") and torch.xpu.is_available(): + torch.xpu.empty_cache() # Step 3: Quantize the blocks quantizers = {} @@ -460,11 +578,7 @@ def store_input_hook(_, input, *args): handle = block.register_forward_pre_hook(store_input_hook, with_kwargs=True) for data in dataset: for k, v in data.items(): - # put the data on gpu, we won't put them back to cpu - if not has_device_map or device.type == "cpu": - data[k] = v.to(0) - else: - data[k] = v.to(device) + data[k] = nested_move_to(v, cur_layer_device) try: model(**data) except ValueError: @@ -473,9 +587,12 @@ def store_input_hook(_, input, *args): # move block to cuda if needed # in case we have offload modules, we need to put them on cuda because of GPTQ object - if not has_device_map or get_device(block) == torch.device("cpu"): + if (not has_device_map or get_device(block) == torch.device("cpu")) and has_device_more_than_cpu(): block = block.to(0) layers = get_layers(block) + block_device = get_device(block) + if not is_gptqmodel_available(): + block_device = 0 if isinstance(self.modules_in_block_to_quantize, list) and len(self.modules_in_block_to_quantize) > 0: if self.true_sequential: layers_name_list = self.modules_in_block_to_quantize @@ -509,15 +626,20 @@ def tmp(_, input, output): for j in range(len(dataset)): # the args are already on the gpu # don't need to store the output + layer_inputs[j] = nested_move_to(layer_inputs[j], block_device) + for k, v in layer_input_kwargs[j].items(): + layer_input_kwargs[j][k] = nested_move_to(v, block_device) + block(*layer_inputs[j], **layer_input_kwargs[j]) # remove hook for h in handles: h.remove() for name in subset_name_list: logger.info(f"Quantizing {name} in block {i + 1}/{len(blocks)}...") - scale, zero, g_idx = gptq[name].fasterquant( + quant_outputs = gptq[name].fasterquant( percdamp=self.damp_percent, group_size=self.group_size, actorder=self.desc_act ) + scale, zero, g_idx = quant_outputs[0], quant_outputs[1], quant_outputs[2] quantizers[f"{self.block_name_to_quantize}.{i}.{name}"] = ( gptq[name].quantizer, scale, @@ -543,11 +665,13 @@ def tmp(_, input, output): del layer_inputs layer_inputs = [] torch.cuda.empty_cache() + if hasattr(torch, "xpu") and torch.xpu.is_available(): + torch.xpu.empty_cache() if self.bits == 4: # device not on gpu if device.type != "cuda" or (has_device_map and any(d in devices for d in ["cpu", "disk", "hpu"])): - if not self.disable_exllama: + if not self.disable_exllama and not is_gptqmodel_available(): logger.warning( "Found modules on cpu/disk. Using Exllama/Exllamav2 backend requires all the modules to be on GPU. Setting `disable_exllama=True`" ) @@ -578,6 +702,8 @@ def tmp(_, input, output): model = self.post_init_model(model) torch.cuda.empty_cache() + if hasattr(torch, "xpu") and torch.xpu.is_available(): + torch.xpu.empty_cache() return model def post_init_model(self, model): @@ -601,9 +727,14 @@ def post_init_model(self, model): class StoreAttr(object): pass + if is_gptqmodel_available(): + model, _ = hf_convert_gptq_v1_to_v2_format( + model, self.bits, self.quant_linear, self.checkpoint_format, self.meta + ) + model.quantize_config = StoreAttr() model.quantize_config.desc_act = self.desc_act - model = autogptq_post_init(model, use_act_order=self.desc_act) + model = gptq_post_init(model, use_act_order=self.desc_act) if ( self.desc_act and (not self.disable_exllama and self.exllama_version == ExllamaVersion.ONE) @@ -626,19 +757,14 @@ def pack_model( quantizers (`Dict[str,Tuple]`): A mapping of the layer name and the data needed to pack the layer """ - QuantLinear = dynamically_import_QuantLinear( - use_triton=False, - desc_act=self.desc_act, - group_size=self.group_size, - bits=self.bits, - disable_exllama=self.disable_exllama or self.exllama_version != ExllamaVersion.ONE, - disable_exllamav2=self.disable_exllama or self.exllama_version != ExllamaVersion.TWO, - ) logger.info("Packing model...") layers = get_layers(model) layers = {n: layers[n] for n in quantizers} + + self.select_quant_linear(device_map=model.hf_device_map) + self._replace_by_quant_layers(model, quantizers) - qlayers = get_layers(model, [QuantLinear]) + qlayers = get_layers(model, [self.quant_linear]) for name in qlayers: logger.info(name) quantizers[name], scale, zero, g_idx = quantizers[name] @@ -673,6 +799,15 @@ def save(self, model: nn.Module, save_dir: str, max_shard_size: str = "10GB", sa Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). """ + + # convert gptqmodel internal gptq_v2 format to v1 for max compatibility + if is_gptqmodel_available(): + model, converted = hf_convert_gptq_v2_to_v1_format( + model, self.sym, self.bits, self.quant_linear, self.checkpoint_format, self.meta + ) + if converted: + self.checkpoint_format = "gptq" + os.makedirs(save_dir, exist_ok=True) model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization) with open(os.path.join(save_dir, GPTQ_CONFIG), "w", encoding="utf-8") as f: @@ -736,10 +871,12 @@ def load_quantized_model( Returns: `nn.Module`: The quantized model """ - if not torch.cuda.is_available(): - raise RuntimeError("No GPU found. A GPU is needed to run quantized model.") - if not is_auto_gptq_available(): - raise RuntimeError("auto-gptq is required in order to load quantized weights : `pip install auto-gptq`") + if not torch.cuda.is_available() and not is_gptqmodel_available(): + raise RuntimeError("No GPU found. A GPU is needed to run quantized model by auto_gptq.") + if not is_auto_gptq_available() and not is_gptqmodel_available(): + raise RuntimeError( + "gptqmodel (`pip install gptqmodel`) or auto-gptq (`pip install auto-gptq`) is required in order to load quantized weights. Please notice that auto-gptq will be deprecated in the future." + ) if not is_accelerate_available(): raise RuntimeError( "You need to install accelerate in order to load and dispatch weights to" @@ -777,7 +914,7 @@ def load_quantized_model( quantizer.exllama_version = quantizer.exllama_config["version"] quantizer.max_input_length = max_input_length - model = quantizer.convert_model(model) + model = quantizer.convert_model(model, device_map=device_map) if no_split_module_classes is None: no_split_module_classes = quantizer.get_no_split_module_classes(model) diff --git a/optimum/gptq/utils.py b/optimum/gptq/utils.py index a5f9afdaaef..732ecbd66b9 100644 --- a/optimum/gptq/utils.py +++ b/optimum/gptq/utils.py @@ -113,3 +113,18 @@ def get_seqlen(model: nn.Module): "We couldn't get the model sequence length. Setting it to 2048. You can overwrite this value by passing `model_seqlen` in` GPTQQuantizer`" ) return 2048 + + +def move_to(obj: torch.Tensor, device: torch.device): + if get_device(obj) != device: + obj = obj.to(device) + return obj + + +def nested_move_to(v, device): + if isinstance(v, torch.Tensor): + return move_to(v, device) + elif isinstance(v, (list, tuple)): + return type(v)([nested_move_to(e, device) for e in v]) + else: + return v diff --git a/optimum/utils/__init__.py b/optimum/utils/__init__.py index 2aa90253d08..e2b53a7dbc7 100644 --- a/optimum/utils/__init__.py +++ b/optimum/utils/__init__.py @@ -37,6 +37,7 @@ is_auto_gptq_available, is_datasets_available, is_diffusers_available, + is_gptqmodel_available, is_onnx_available, is_onnxruntime_available, is_pydantic_available, diff --git a/optimum/utils/import_utils.py b/optimum/utils/import_utils.py index 405e3815b33..d0f4c85db2b 100644 --- a/optimum/utils/import_utils.py +++ b/optimum/utils/import_utils.py @@ -52,6 +52,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ TRANSFORMERS_MINIMUM_VERSION = version.parse("4.25.0") DIFFUSERS_MINIMUM_VERSION = version.parse("0.22.0") AUTOGPTQ_MINIMUM_VERSION = version.parse("0.4.99") # Allows 0.5.0.dev0 +GPTQMODEL_MINIMUM_VERSION = version.parse("1.4.2") # This is the minimal required version to support some ONNX Runtime features @@ -67,6 +68,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ _accelerate_available = _is_package_available("accelerate") _diffusers_available = _is_package_available("diffusers") _auto_gptq_available = _is_package_available("auto_gptq") +_gptqmodel_available = _is_package_available("gptqmodel") _timm_available = _is_package_available("timm") _sentence_transformers_available = _is_package_available("sentence_transformers") _datasets_available = _is_package_available("datasets") @@ -138,12 +140,23 @@ def is_datasets_available(): def is_auto_gptq_available(): if _auto_gptq_available: - version_autogptq = version.parse(importlib_metadata.version("auto_gptq")) - if AUTOGPTQ_MINIMUM_VERSION < version_autogptq: + v = version.parse(importlib_metadata.version("auto_gptq")) + if v >= AUTOGPTQ_MINIMUM_VERSION: return True else: raise ImportError( - f"Found an incompatible version of auto-gptq. Found version {version_autogptq}, but only version above {AUTOGPTQ_MINIMUM_VERSION} are supported" + f"Found an incompatible version of auto-gptq. Found version {v}, but only version >= {AUTOGPTQ_MINIMUM_VERSION} are supported" + ) + + +def is_gptqmodel_available(): + if _gptqmodel_available: + v = version.parse(importlib_metadata.version("gptqmodel")) + if v >= GPTQMODEL_MINIMUM_VERSION: + return True + else: + raise ImportError( + f"Found an incompatible version of gptqmodel. Found version {v}, but only version >= {GPTQMODEL_MINIMUM_VERSION} are supported" )