
Commit

[Fix] all tensors not same device (#5)
* fix device error

* update gptqmodel version

* fix test
ZX-ModelCloud authored Dec 16, 2024
1 parent 3603a0b commit 32b0e7d
Showing 4 changed files with 32 additions and 16 deletions.
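
The commit title refers to the device-mismatch RuntimeError PyTorch raises when a module and the tensors fed to it live on different devices ("Expected all tensors to be on the same device ..."), which could surface during GPTQ calibration when blocks and calibration data ended up on different devices. A minimal illustration of that failure mode and of the move-to-the-module's-device pattern this commit applies (toy layer and shapes, not code from this commit):

```python
import torch
import torch.nn as nn

if torch.cuda.is_available():
    layer = nn.Linear(4, 4).to("cuda:0")  # module lives on the GPU
    x = torch.randn(2, 4)                 # input left on the CPU
    try:
        layer(x)                          # raises the device-mismatch RuntimeError
    except RuntimeError as err:
        print(err)
    # Moving the input to the module's device first makes the call succeed.
    out = layer(x.to(next(layer.parameters()).device))
    print(out.device)                     # cuda:0
```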
27 changes: 13 additions & 14 deletions optimum/gptq/quantizer.py
@@ -53,7 +53,7 @@
from gptqmodel import exllama_set_max_input_length
from gptqmodel.quantization import GPTQ
from gptqmodel.utils.importer import hf_select_quant_linear
- from gptqmodel.utils.model import hf_convert_gptq_v1_to_v2_format, hf_convert_gptq_v2_to_v1_format
+ from gptqmodel.utils.model import hf_convert_gptq_v1_to_v2_format, hf_convert_gptq_v2_to_v1_format, nested_move_to
from gptqmodel.utils.model import hf_gptqmodel_post_init as gptq_post_init
from gptqmodel.version import __version__ as gptqmodel_version

@@ -511,9 +511,11 @@ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):

blocks = recurse_getattr(model, self.block_name_to_quantize)

+ cur_layer_device = get_device(blocks[0])
+
if not has_device_map:
# put modules from module_name_preceding_first_block on cuda or xpu or cpu
- to_device = 0 if has_device_more_than_cpu() else "cpu"
+ to_device = cur_layer_device
for module_name in self.module_name_preceding_first_block:
module = recurse_getattr(model, module_name)
if module is None:
@@ -525,26 +527,22 @@ def store_input_hook(_, input, *args):
kwargs = args[0]
if input is None:
if "hidden_states" in kwargs:
- input = (kwargs["hidden_states"],)
+ input = (nested_move_to(kwargs["hidden_states"], cur_layer_device),)
else:
raise ValueError("No input value found in the foward pass")
layer_inputs.append(input)
other_kwargs = {}
for k, v in kwargs.items(): # make sure other arguments also be captured
if k not in ["hidden_states"]:
- other_kwargs[k] = v
+ other_kwargs[k] = nested_move_to(v, cur_layer_device)
layer_input_kwargs.append(other_kwargs)
raise ValueError

if self.cache_block_outputs:
handle = blocks[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
for data in dataset:
for k, v in data.items():
- # put the data on gpu, we won't put them back to cpu
- if (not has_device_map or device.type == "cpu") and has_device_more_than_cpu():
- data[k] = v.to(0)
- else:
- data[k] = v.to(device)
+ data[k] = nested_move_to(v, cur_layer_device)
try:
model(**data)
except ValueError:
@@ -571,11 +569,7 @@ def store_input_hook(_, input, *args):
handle = block.register_forward_pre_hook(store_input_hook, with_kwargs=True)
for data in dataset:
for k, v in data.items():
- # put the data on gpu, we won't put them back to cpu
- if (not has_device_map or device.type == "cpu") and has_device_more_than_cpu():
- data[k] = v.to(0)
- else:
- data[k] = v.to(device)
+ data[k] = nested_move_to(v, cur_layer_device)
try:
model(**data)
except ValueError:
@@ -587,6 +581,7 @@ def store_input_hook(_, input, *args):
if (not has_device_map or get_device(block) == torch.device("cpu")) and has_device_more_than_cpu():
block = block.to(0)
layers = get_layers(block)
+ block_device = get_device(block)
if isinstance(self.modules_in_block_to_quantize, list) and len(self.modules_in_block_to_quantize) > 0:
if self.true_sequential:
layers_name_list = self.modules_in_block_to_quantize
@@ -620,6 +615,10 @@ def tmp(_, input, output):
for j in range(len(dataset)):
# the args are already on the gpu
# don't need to store the output
+ layer_inputs[j] = nested_move_to(layer_inputs[j], block_device)
+ for k, v in layer_input_kwargs[j].items():
+ layer_input_kwargs[j][k] = nested_move_to(v, block_device)
+
block(*layer_inputs[j], **layer_input_kwargs[j])
# remove hook
for h in handles:
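The quantizer change above replaces the hard-coded `.to(0)` / `.to(device)` calls with moves to `cur_layer_device` (the device of the first block) and `block_device` (the device of the block currently being quantized), both inside the `store_input_hook` pre-hook and when replaying the captured inputs. A self-contained sketch of that capture-and-replay pattern, with a toy module, a simplified `get_device`, and the same recursion as the `nested_move_to` helper added in `optimum/gptq/utils.py` below (not the full quantizer):

```python
import torch
import torch.nn as nn


def get_device(obj):
    # Simplified stand-in: device of a tensor, or of a module's first parameter.
    return obj.device if isinstance(obj, torch.Tensor) else next(obj.parameters()).device


def nested_move_to(v, device):
    # Same recursion as the helper added in optimum/gptq/utils.py below.
    if isinstance(v, torch.Tensor):
        return v.to(device)
    if isinstance(v, (list, tuple)):
        return type(v)([nested_move_to(e, device) for e in v])
    return v


block = nn.Linear(8, 8).to("cuda:0" if torch.cuda.is_available() else "cpu")
cur_layer_device = get_device(block)
layer_inputs, layer_input_kwargs = [], []


def store_input_hook(_, args, kwargs):
    # Capture calibration inputs already aligned with the block's device,
    # then abort the forward pass (only the inputs are needed here).
    layer_inputs.append(nested_move_to(args, cur_layer_device))
    layer_input_kwargs.append({k: nested_move_to(v, cur_layer_device) for k, v in kwargs.items()})
    raise ValueError


handle = block.register_forward_pre_hook(store_input_hook, with_kwargs=True)
for data in [{"input": torch.randn(2, 8)}]:  # stand-in calibration dataset (CPU tensors)
    try:
        block(nested_move_to(data["input"], cur_layer_device))
    except ValueError:
        pass
handle.remove()

# Replay: captured inputs are moved to the block's current device before calling it again.
block_device = get_device(block)
out = block(*nested_move_to(layer_inputs[0], block_device), **layer_input_kwargs[0])
print(out.device)
```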
14 changes: 14 additions & 0 deletions optimum/gptq/utils.py
@@ -113,3 +113,17 @@ def get_seqlen(model: nn.Module):
"We couldn't get the model sequence length. Setting it to 2048. You can overwrite this value by passing `model_seqlen` in` GPTQQuantizer`"
)
return 2048
+
+ def move_to(obj: torch.Tensor | nn.Module, device: torch.device):
+ if get_device(obj) != device:
+ obj = obj.to(device)
+ return obj
+
+
+ def nested_move_to(v, device):
+ if isinstance(v, torch.Tensor):
+ return move_to(v, device)
+ elif isinstance(v, (list, tuple)):
+ return type(v)([nested_move_to(e, device) for e in v])
+ else:
+ return v
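
For reference, the new `nested_move_to` only touches tensors: lists and tuples are rebuilt recursively, and anything else is returned unchanged, so mixed keyword arguments captured from a block's forward pass survive the move. A small usage sketch (illustrative names and shapes, assuming an optimum install that includes this commit):

```python
import torch

from optimum.gptq.utils import nested_move_to  # helper added by this commit

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

batch = {
    "hidden_states": torch.randn(1, 4, 16),
    "attention_mask": torch.ones(1, 4, dtype=torch.bool),
    "past_key_value": (torch.zeros(1, 2, 4, 8), torch.zeros(1, 2, 4, 8)),  # nested tuple of tensors
    "use_cache": True,  # non-tensor values pass through unchanged
}
moved = {k: nested_move_to(v, device) for k, v in batch.items()}

assert moved["hidden_states"].device == device
assert moved["past_key_value"][0].device == device
assert moved["use_cache"] is True
```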
2 changes: 1 addition & 1 deletion optimum/utils/import_utils.py
@@ -52,7 +52,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
TRANSFORMERS_MINIMUM_VERSION = version.parse("4.25.0")
DIFFUSERS_MINIMUM_VERSION = version.parse("0.22.0")
AUTOGPTQ_MINIMUM_VERSION = version.parse("0.4.99") # Allows 0.5.0.dev0
- GPTQMODEL_MINIMUM_VERSION = version.parse("1.4.1") # Allows 1.4.0.dev0
+ GPTQMODEL_MINIMUM_VERSION = version.parse("1.4.2")


# This is the minimal required version to support some ONNX Runtime features
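
The gptqmodel floor moves to 1.4.2, presumably so the `nested_move_to` import added in `quantizer.py` above is guaranteed to exist in installed gptqmodel builds. For context, a constant like this is typically enforced by comparing it against the installed distribution's version; a sketch along those lines (hypothetical helper, not the exact optimum check):

```python
import importlib.metadata

from packaging import version

GPTQMODEL_MINIMUM_VERSION = version.parse("1.4.2")


def gptqmodel_meets_minimum() -> bool:
    # Returns True only if gptqmodel is installed and at least the required version.
    try:
        installed = version.parse(importlib.metadata.version("gptqmodel"))
    except importlib.metadata.PackageNotFoundError:
        return False
    return installed >= GPTQMODEL_MINIMUM_VERSION
```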
5 changes: 4 additions & 1 deletion tests/gptq/test_quantization.py
@@ -316,7 +316,10 @@ def test_exllama_serialization(self):
# quantized models are more compatible with device map than
# device context managers (they're never used in transformers testing suite)
_ = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map={"": self.device_for_inference})
- _ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})
+ if is_gptqmodel_available():
+ _ = GPTQModel.load(tmpdirname, device_map={"": self.device_for_inference})
+ else:
+ _ = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": self.device_for_inference})


class GPTQTestNoBlockCaching(GPTQTestCUDA):
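
The exllama serialization test now reloads the quantized checkpoint through whichever backend is installed: `GPTQModel.load` when gptqmodel is available, otherwise the auto-gptq path it used before. A condensed sketch of that branch outside the test harness, with a plain import check standing in for `is_gptqmodel_available()` (`tmpdirname` and the device index are stand-ins for the test fixtures):

```python
import torch

tmpdirname = "/path/to/quantized/checkpoint"                      # stand-in for the test's tmp dir
device_for_inference = 0 if torch.cuda.is_available() else "cpu"  # stand-in for self.device_for_inference

try:
    from gptqmodel import GPTQModel
    model = GPTQModel.load(tmpdirname, device_map={"": device_for_inference})
except ImportError:
    from auto_gptq import AutoGPTQForCausalLM
    model = AutoGPTQForCausalLM.from_quantized(tmpdirname, device_map={"": device_for_inference})
```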
