diff --git a/src/transformers/integrations/spqr.py b/src/transformers/integrations/spqr.py index 011966154aad8c..58b71740d37c77 100644 --- a/src/transformers/integrations/spqr.py +++ b/src/transformers/integrations/spqr.py @@ -21,11 +21,11 @@ def replace_with_spqr_linear( - model, - quantization_config=None, - modules_to_not_convert=None, - current_key_name=None, - has_been_replaced=False, + model, + quantization_config=None, + modules_to_not_convert=None, + current_key_name=None, + has_been_replaced=False, ): """ Public method that recursively replaces the Linear layers of the given model with SpQR quantized layers. @@ -68,15 +68,19 @@ def replace_with_spqr_linear( shapes = quantization_config.shapes shapes_keys = shapes.keys() - shapes_valid = f"{tensor_name}.dense_weights.shape" in shapes_keys and \ - f"{tensor_name}.row_offsets.shape" in shapes_keys and \ - f"{tensor_name}.col_vals.shape" in shapes_keys and \ - f"{tensor_name}.in_perm.shape" in shapes_keys + shapes_valid = ( + f"{tensor_name}.dense_weights.shape" in shapes_keys + and f"{tensor_name}.row_offsets.shape" in shapes_keys + and f"{tensor_name}.col_vals.shape" in shapes_keys + and f"{tensor_name}.in_perm.shape" in shapes_keys + ) if not shapes_valid: - raise ValueError(f'The SpQR quantization config does not contain the shape ' - f'configuration for {tensor_name}. This indicates that the ' - f'configuration is either invalid or corrupted.') + raise ValueError( + f"The SpQR quantization config does not contain the shape " + f"configuration for {tensor_name}. This indicates that the " + f"configuration is either invalid or corrupted." + ) dense_weights_shape = shapes[f"{tensor_name}.dense_weights.shape"] row_offsets_shape = shapes[f"{tensor_name}.row_offsets.shape"]