diff --git a/src/transformers/quantizers/quantizer_hqq.py b/src/transformers/quantizers/quantizer_hqq.py index 06949d059a5de3..14be75369dec0e 100755 --- a/src/transformers/quantizers/quantizer_hqq.py +++ b/src/transformers/quantizers/quantizer_hqq.py @@ -101,7 +101,7 @@ def check_quantized_param( ) -> bool: module, tensor_name = get_module_from_name(model, param_name) - return isinstance(module, torch.nn.Linear) + return isinstance(module, torch.nn.Linear) and (tensor_name == "weight") def create_quantized_param( self,