Skip to content

Commit

Permalink
Check if the config contains proper shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
elvircrn committed Dec 13, 2024
1 parent 0702927 commit 48804ce
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions src/transformers/integrations/spqr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@


def replace_with_spqr_linear(
model,
quantization_config=None,
modules_to_not_convert=None,
current_key_name=None,
has_been_replaced=False,
model,
quantization_config=None,
modules_to_not_convert=None,
current_key_name=None,
has_been_replaced=False,
):
"""
Public method that recursively replaces the Linear layers of the given model with SpQR quantized layers.
Expand Down Expand Up @@ -64,10 +64,24 @@ def replace_with_spqr_linear(
if ".".join(current_key_name) + ".weight" not in modules_to_not_convert:
with init_empty_weights():
tensor_name = ".".join(current_key_name)
dense_weights_shape = quantization_config.shapes[f"{tensor_name}.dense_weights.shape"]
row_offsets_shape = quantization_config.shapes[f"{tensor_name}.row_offsets.shape"]
col_vals_shape = quantization_config.shapes[f"{tensor_name}.col_vals.shape"]
in_perm_shape = quantization_config.shapes[f"{tensor_name}.in_perm.shape"]

shapes = quantization_config.shapes
shapes_keys = shapes.keys()

shapes_valid = f"{tensor_name}.dense_weights.shape" in shapes_keys and \
f"{tensor_name}.row_offsets.shape" in shapes_keys and \
f"{tensor_name}.col_vals.shape" in shapes_keys and \
f"{tensor_name}.in_perm.shape" in shapes_keys

if not shapes_valid:
raise ValueError(f'The SpQR quantization config does not contain the shape '
f'configuration for {tensor_name}. This indicates that the '
f'configuration is either invalid or corrupted.')

dense_weights_shape = shapes[f"{tensor_name}.dense_weights.shape"]
row_offsets_shape = shapes[f"{tensor_name}.row_offsets.shape"]
col_vals_shape = shapes[f"{tensor_name}.col_vals.shape"]
in_perm_shape = shapes[f"{tensor_name}.in_perm.shape"]

in_features = module.in_features
out_features = module.out_features
Expand Down

0 comments on commit 48804ce

Please sign in to comment.