diff --git a/src/transformers/integrations/spqr.py b/src/transformers/integrations/spqr.py
index fe382fa8e50f10..9a77292215334a 100644
--- a/src/transformers/integrations/spqr.py
+++ b/src/transformers/integrations/spqr.py
@@ -13,7 +13,8 @@
 # limitations under the License.
 "SpQR (Sparse-Quantized Representation) integration file"
 
-from ..utils import is_torch_available, is_spqr_available, is_accelerate_available
+from ..utils import is_accelerate_available, is_spqr_available, is_torch_available
+
 
 if is_torch_available():
     import torch.nn as nn
diff --git a/src/transformers/quantizers/quantizer_spqr.py b/src/transformers/quantizers/quantizer_spqr.py
index f3f273bf60dd54..60cc1bca9b279b 100644
--- a/src/transformers/quantizers/quantizer_spqr.py
+++ b/src/transformers/quantizers/quantizer_spqr.py
@@ -52,9 +52,7 @@ def validate_environment(self, *args, **kwargs):
     def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
         if torch_dtype is None:
             torch_dtype = torch.float16
-            logger.info(
-                "Assuming SpQR inference on GPU and loading the model in `torch.float16`."
-            )
+            logger.info("Assuming SpQR inference on GPU and loading the model in `torch.float16`.")
         elif torch_dtype != torch.float16:
             raise ValueError(
                 "You cannot use any type other than torch.float16 for SpQR. Please either leave it None or set it to"