diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py index f6bf431aa028d4..9a03eb25f4de0d 100644 --- a/src/transformers/quantizers/quantizer_torchao.py +++ b/src/transformers/quantizers/quantizer_torchao.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import importlib +import types from typing import TYPE_CHECKING, Union from packaging import version @@ -30,9 +31,7 @@ if is_torch_available(): import torch - -if is_torchao_available(): - from torchao.quantization import quantize_ + import torch.nn as nn logger = logging.get_logger(__name__) @@ -46,6 +45,25 @@ def find_parent(model, name): return parent +def _quantization_type(weight): + from torchao.dtypes import AffineQuantizedTensor + from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor + + if isinstance(weight, AffineQuantizedTensor): + return f"{weight.__class__.__name__}({weight._quantization_type()})" + + if isinstance(weight, LinearActivationQuantizedTensor): + return f"{weight.__class__.__name__}(activation={weight.input_quant_func}, weight={_quantization_type(weight.original_weight_tensor)})" + + +def _linear_extra_repr(self): + weight = _quantization_type(self.weight) + if weight is None: + return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight=None" + else: + return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}" + + class TorchAoHfQuantizer(HfQuantizer): """ Quantizer for torchao: https://github.com/pytorch/ao/ @@ -152,9 +170,17 @@ def create_quantized_param( Each nn.Linear layer that needs to be quantized is processsed here. First, we set the value the weight tensor, then we move it to the target device. Finally, we quantize the module. """ + from torchao.quantization import quantize_ + module, tensor_name = get_module_from_name(model, param_name) - module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device) - quantize_(module, self.quantization_config.get_apply_tensor_subclass()) + + if self.pre_quantized: + module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device)) + if isinstance(module, nn.Linear): + module.extra_repr = types.MethodType(_linear_extra_repr, module) + else: + module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device) + quantize_(module, self.quantization_config.get_apply_tensor_subclass()) def _process_model_after_weight_loading(self, model): """No process required for torchao quantized model"""