diff --git a/torchao/quantization/qat/__init__.py b/torchao/quantization/qat/__init__.py
index 5f8f7ca35f..56a911e8d3 100644
--- a/torchao/quantization/qat/__init__.py
+++ b/torchao/quantization/qat/__init__.py
@@ -7,6 +7,7 @@
 from .linear import (
     Int4WeightOnlyQATQuantizer,
     Int8DynActInt4WeightQATQuantizer,
+    SelfCompressionQATQuantizer,
 )
 
 __all__ = [
@@ -14,4 +15,5 @@
     "Int4WeightOnlyQATQuantizer",
     "Int4WeightOnlyEmbeddingQATQuantizer",
     "Int8DynActInt4WeightQATQuantizer",
+    "SelfCompressionQATQuantizer",
 ]
diff --git a/torchao/quantization/qat/linear.py b/torchao/quantization/qat/linear.py
index cbe6296407..c2f84e9b10 100644
--- a/torchao/quantization/qat/linear.py
+++ b/torchao/quantization/qat/linear.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Optional
+from typing import Any, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
@@ -30,6 +30,82 @@
     _get_qmin_qmax,
 )
 
+
+class SelfCompressionQATQuantizer(torch.nn.Module):
+    """
+    Converts a model into one trained with Self-Compression
+    (https://arxiv.org/pdf/2301.13142): each linear layer learns
+    per-channel bit depths and scale exponents alongside its weights,
+    and the forward pass additionally returns a compression loss that
+    penalizes the total compressed size of those layers.
+
+    Example usage::
+
+        quantizer = SelfCompressionQATQuantizer(model)
+        quantizer.prepare(model)
+        output, compression_loss = quantizer(example_input)
+        # gamma trades model size against task accuracy
+        loss = task_loss(output, target) + gamma * compression_loss
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        super().__init__()
+        self.model = model
+        self.compression_layers = []
+
+    def prepare(self, module: torch.nn.Module):
+        """
+        Recursively swaps every `torch.nn.Linear` in `module` for a
+        `SelfCompressionLinear`, copying over its weights and bias.
+        """
+        for name, child in module.named_children():
+            if isinstance(child, torch.nn.Linear):
+                self_compress_linear = SelfCompressionLinear(
+                    child.in_features,
+                    child.out_features,
+                    child.bias is not None,
+                )
+                self.compression_layers.append(self_compress_linear)
+                setattr(module, name, self_compress_linear)
+
+                self_compress_linear.weight.data.copy_(child.weight.data)
+                if child.bias is not None:
+                    self_compress_linear.bias.data.copy_(child.bias.data)
+            else:
+                # recurse so linear layers nested inside submodules
+                # (e.g. transformer blocks) are also swapped
+                self.prepare(child)
+
+        # TODO: calculate only over linear layers?
+        # the outermost call runs last, so this ends up holding the
+        # parameter count of the full model
+        self.total_params = sum(p.numel() for p in module.parameters())
+
+    def forward(self, x) -> Tuple[Any, torch.Tensor]:
+        # total compressed size of all swapped layers, normalized by the
+        # parameter count; torch.stack (rather than torch.tensor) keeps
+        # the loss differentiable w.r.t. the learned bit depths
+        compression_loss = (
+            torch.stack(
+                [layer.get_layer_size() for layer in self.compression_layers]
+            ).sum()
+            / self.total_params
+        )
+        return self.model(x), compression_loss
+
+
+class SelfCompressionLinear(torch.nn.Linear):
+    """
+    Linear layer whose weights are fake quantized with a learned
+    per-output-channel bit depth and scale exponent.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: bool = False,
+        *args,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            in_features,
+            out_features,
+            bias,
+            *args,
+            **kwargs,
+        )
+
+        # learned per-channel scale exponents e and bit depths b;
+        # each weight row is quantized to b bits with scale 2 ** e
+        self.float_exponents = torch.nn.Parameter(torch.full((out_features, 1), -8.0))
+        self.bit_depth = torch.nn.Parameter(torch.full((out_features, 1), 2.0))
+
+    def get_layer_size(self) -> torch.Tensor:
+        """
+        Returns the compressed size of the weight matrix in bits: the
+        (non-negative) bit depth of each output channel times the number
+        of weights in that channel.
+        """
+        return self.bit_depth.relu().sum() * self.weight.shape[1]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass with fake quantized weights
+        """
+        bits = self.bit_depth.relu()
+        # scale the weights by 2 ** -e, then clamp them to the signed
+        # integer range [-2 ** (b - 1), 2 ** (b - 1) - 1]
+        scaled_weight = 2.0 ** -self.float_exponents * self.weight
+        quant_weight = torch.minimum(
+            torch.maximum(scaled_weight, -(2.0 ** (bits - 1))),
+            2.0 ** (bits - 1) - 1,
+        )
+        # round with a straight-through estimator so gradients flow
+        # through the non-differentiable rounding step
+        rounded_weight = (quant_weight.round() - quant_weight).detach() + quant_weight
+        # rescale back to float; the bias is left unquantized
+        return F.linear(x, 2.0 ** self.float_exponents * rounded_weight, self.bias)
+
+
 class FakeQuantizedLinear(torch.nn.Linear):
     """