diff --git a/src/transformers/integrations/spqr.py b/src/transformers/integrations/spqr.py
index 9a77292215334a..011966154aad8c 100644
--- a/src/transformers/integrations/spqr.py
+++ b/src/transformers/integrations/spqr.py
@@ -21,11 +21,11 @@ def replace_with_spqr_linear(
-    model,
-    quantization_config=None,
-    modules_to_not_convert=None,
-    current_key_name=None,
-    has_been_replaced=False,
+    model,
+    quantization_config=None,
+    modules_to_not_convert=None,
+    current_key_name=None,
+    has_been_replaced=False,
 ):
     """
     Public method that recursively replaces the Linear layers of the given model with SpQR quantized layers.
@@ -64,10 +64,24 @@ def replace_with_spqr_linear(
             if ".".join(current_key_name) + ".weight" not in modules_to_not_convert:
                 with init_empty_weights():
                     tensor_name = ".".join(current_key_name)
-                    dense_weights_shape = quantization_config.shapes[f"{tensor_name}.dense_weights.shape"]
-                    row_offsets_shape = quantization_config.shapes[f"{tensor_name}.row_offsets.shape"]
-                    col_vals_shape = quantization_config.shapes[f"{tensor_name}.col_vals.shape"]
-                    in_perm_shape = quantization_config.shapes[f"{tensor_name}.in_perm.shape"]
+
+                    shapes = quantization_config.shapes
+                    shapes_keys = shapes.keys()
+
+                    shapes_valid = f"{tensor_name}.dense_weights.shape" in shapes_keys and \
+                        f"{tensor_name}.row_offsets.shape" in shapes_keys and \
+                        f"{tensor_name}.col_vals.shape" in shapes_keys and \
+                        f"{tensor_name}.in_perm.shape" in shapes_keys
+
+                    if not shapes_valid:
+                        raise ValueError(f'The SpQR quantization config does not contain the shape '
+                                         f'configuration for {tensor_name}. This indicates that the '
+                                         f'configuration is either invalid or corrupted.')
+
+                    dense_weights_shape = shapes[f"{tensor_name}.dense_weights.shape"]
+                    row_offsets_shape = shapes[f"{tensor_name}.row_offsets.shape"]
+                    col_vals_shape = shapes[f"{tensor_name}.col_vals.shape"]
+                    in_perm_shape = shapes[f"{tensor_name}.in_perm.shape"]
                     in_features = module.in_features
                     out_features = module.out_features
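
Below is a minimal, self-contained sketch of the validation behavior this hunk introduces, shown outside the diff for clarity. The flat key layout of the `shapes` dict (keys of the form `<tensor_name>.<component>.shape`) is taken from the patch; the helper name `validate_spqr_shapes`, the `REQUIRED_COMPONENTS` tuple, and the example tensor name and shape values are illustrative assumptions, not part of the patch.

# Hypothetical standalone helper mirroring the check added in the diff above.
# Only the key layout comes from the patch; names here are illustrative.
REQUIRED_COMPONENTS = ("dense_weights", "row_offsets", "col_vals", "in_perm")


def validate_spqr_shapes(shapes, tensor_name):
    """Raise ValueError unless every required shape key exists for tensor_name."""
    missing = [
        f"{tensor_name}.{component}.shape"
        for component in REQUIRED_COMPONENTS
        if f"{tensor_name}.{component}.shape" not in shapes
    ]
    if missing:
        # Same message as the patch: fail fast with a clear error instead of
        # surfacing a bare KeyError on the first dict lookup.
        raise ValueError(f'The SpQR quantization config does not contain the shape '
                         f'configuration for {tensor_name}. This indicates that the '
                         f'configuration is either invalid or corrupted.')


# Example: a config covering only two of the four required entries now fails
# fast with an actionable message (tensor name and shape values are made up).
shapes = {
    "model.layers.0.self_attn.q_proj.dense_weights.shape": [4096, 64],
    "model.layers.0.self_attn.q_proj.row_offsets.shape": [4097],
}
try:
    validate_spqr_shapes(shapes, "model.layers.0.self_attn.q_proj")
except ValueError as err:
    print(err)  # the config above lacks the col_vals and in_perm shape entries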