Skip to content

Commit

Permalink
Ruff test
Browse files Browse the repository at this point in the history
  • Loading branch information
elvircrn committed Dec 13, 2024
1 parent 48804ce commit c10cde7
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions src/transformers/integrations/spqr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@


def replace_with_spqr_linear(
model,
quantization_config=None,
modules_to_not_convert=None,
current_key_name=None,
has_been_replaced=False,
model,
quantization_config=None,
modules_to_not_convert=None,
current_key_name=None,
has_been_replaced=False,
):
"""
Public method that recursively replaces the Linear layers of the given model with SpQR quantized layers.
Expand Down Expand Up @@ -68,15 +68,19 @@ def replace_with_spqr_linear(
shapes = quantization_config.shapes
shapes_keys = shapes.keys()

shapes_valid = f"{tensor_name}.dense_weights.shape" in shapes_keys and \
f"{tensor_name}.row_offsets.shape" in shapes_keys and \
f"{tensor_name}.col_vals.shape" in shapes_keys and \
f"{tensor_name}.in_perm.shape" in shapes_keys
shapes_valid = (
f"{tensor_name}.dense_weights.shape" in shapes_keys
and f"{tensor_name}.row_offsets.shape" in shapes_keys
and f"{tensor_name}.col_vals.shape" in shapes_keys
and f"{tensor_name}.in_perm.shape" in shapes_keys
)

if not shapes_valid:
raise ValueError(f'The SpQR quantization config does not contain the shape '
f'configuration for {tensor_name}. This indicates that the '
f'configuration is either invalid or corrupted.')
raise ValueError(
f"The SpQR quantization config does not contain the shape "
f"configuration for {tensor_name}. This indicates that the "
f"configuration is either invalid or corrupted."
)

dense_weights_shape = shapes[f"{tensor_name}.dense_weights.shape"]
row_offsets_shape = shapes[f"{tensor_name}.row_offsets.shape"]
Expand Down

0 comments on commit c10cde7

Please sign in to comment.