Skip to content

Commit

Permalink
Changing __repr__ in torchao to show quantized Linear (#34202)
Browse files Browse the repository at this point in the history
* Changing __repr__ in torchao

* small update

* make style

* small update

* add LinearActivationQuantizedTensor

* remove some cases

* update imports & handle return None

* update
  • Loading branch information
MekkCyber authored Nov 5, 2024
1 parent f2d5dfb commit d2bae7e
Showing 1 changed file with 31 additions and 5 deletions.
36 changes: 31 additions & 5 deletions src/transformers/quantizers/quantizer_torchao.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import types
from typing import TYPE_CHECKING, Union

from packaging import version
Expand All @@ -30,9 +31,7 @@

if is_torch_available():
import torch

if is_torchao_available():
from torchao.quantization import quantize_
import torch.nn as nn

logger = logging.get_logger(__name__)

Expand All @@ -46,6 +45,25 @@ def find_parent(model, name):
return parent


def _quantization_type(weight):
    """Describe how a torchao tensor subclass is quantized, as a short string.

    Args:
        weight: The tensor to describe.

    Returns:
        ``"<ClassName>(<details>)"`` when ``weight`` is an ``AffineQuantizedTensor``;
        ``"<ClassName>(activation=<func>, weight=<nested description>)"`` when it is a
        ``LinearActivationQuantizedTensor`` (the wrapped weight is described recursively);
        ``None`` for any other type.
    """
    # torchao is imported lazily so that this module can be loaded without it.
    from torchao.dtypes import AffineQuantizedTensor
    from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor

    if isinstance(weight, AffineQuantizedTensor):
        cls_name = weight.__class__.__name__
        details = weight._quantization_type()
        return f"{cls_name}({details})"

    if isinstance(weight, LinearActivationQuantizedTensor):
        cls_name = weight.__class__.__name__
        activation = weight.input_quant_func
        # The wrapper holds the real (possibly quantized) weight; describe it with the same rules.
        inner = _quantization_type(weight.original_weight_tensor)
        return f"{cls_name}(activation={activation}, weight={inner})"

    # Neither known torchao subclass: callers treat ``None`` as "no description available".
    return None


def _linear_extra_repr(self):
    """Build the ``extra_repr`` text for a linear module whose weight may be torchao-quantized.

    Meant to be bound onto a module instance (for example with ``types.MethodType``), so ``self``
    is the module whose ``weight`` is inspected.

    Returns:
        A string of the form
        ``in_features=<weight.shape[1]>, out_features=<weight.shape[0]>, weight=<description>``,
        where the description comes from ``_quantization_type`` and is rendered as ``None`` when
        that helper does not recognize the weight type.
    """
    weight = _quantization_type(self.weight)
    # Formatting ``None`` inside an f-string already yields "None", so the unrecognized-weight
    # case needs no dedicated branch.
    return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}"


class TorchAoHfQuantizer(HfQuantizer):
"""
Quantizer for torchao: https://github.com/pytorch/ao/
Expand Down Expand Up @@ -152,9 +170,17 @@ def create_quantized_param(
Each nn.Linear layer that needs to be quantized is processed here.
First, we set the value of the weight tensor, then we move it to the target device. Finally, we quantize the module.
"""
from torchao.quantization import quantize_

module, tensor_name = get_module_from_name(model, param_name)
module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
quantize_(module, self.quantization_config.get_apply_tensor_subclass())

if self.pre_quantized:
module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
if isinstance(module, nn.Linear):
module.extra_repr = types.MethodType(_linear_extra_repr, module)
else:
module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
quantize_(module, self.quantization_config.get_apply_tensor_subclass())

def _process_model_after_weight_loading(self, model):
"""No process required for torchao quantized model"""
Expand Down

0 comments on commit d2bae7e

Please sign in to comment.