diff --git a/src/brevitas/core/quant/int_base.py b/src/brevitas/core/quant/int_base.py
index 338e5a433..678b71009 100644
--- a/src/brevitas/core/quant/int_base.py
+++ b/src/brevitas/core/quant/int_base.py
@@ -8,7 +8,6 @@
 import brevitas
 from brevitas.core.function_wrapper import RoundSte
 from brevitas.core.function_wrapper import TensorClamp
-from brevitas.core.quant.delay import DelayWrapper
 from brevitas.function.ops import max_int
 from brevitas.function.ops import min_int
 
@@ -53,14 +52,12 @@ def __init__(
             signed: bool,
             input_view_impl: Module,
             float_to_int_impl: Module = RoundSte(),
-            tensor_clamp_impl: Module = TensorClamp(),
-            quant_delay_steps: int = 0):
+            tensor_clamp_impl: Module = TensorClamp()):
         super(IntQuant, self).__init__()
         self.float_to_int_impl = float_to_int_impl
         self.tensor_clamp_impl = tensor_clamp_impl
         self.signed = signed
         self.narrow_range = narrow_range
-        self.delay_wrapper = DelayWrapper(quant_delay_steps)
         self.input_view_impl = input_view_impl
 
     @brevitas.jit.script_method
@@ -87,7 +84,6 @@ def forward(self, scale: Tensor, zero_point: Tensor, bit_width: Tensor, x: Tenso
         y_int = self.to_int(scale, zero_point, bit_width, x)
         y = y_int - zero_point
         y = y * scale
-        y = self.delay_wrapper(x, y)
         return y
 
 
@@ -129,14 +125,12 @@ def __init__(
             signed: bool,
             input_view_impl: Module,
             float_to_int_impl: Module = RoundSte(),
-            tensor_clamp_impl: Module = TensorClamp(),
-            quant_delay_steps: int = 0):
+            tensor_clamp_impl: Module = TensorClamp()):
         super(DecoupledIntQuant, self).__init__()
         self.float_to_int_impl = float_to_int_impl
         self.tensor_clamp_impl = tensor_clamp_impl
         self.signed = signed
         self.narrow_range = narrow_range
-        self.delay_wrapper = DelayWrapper(quant_delay_steps)
         self.input_view_impl = input_view_impl
 
     @brevitas.jit.script_method
@@ -172,5 +166,4 @@ def forward(
         y_int = self.to_int(pre_scale, pre_zero_point, bit_width, x)
         y = y_int - zero_point
         y = y * scale
-        y = self.delay_wrapper(x, y)
         return y
diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py
index f28233aed..96a2a05b6 100644
--- a/src/brevitas/proxy/parameter_quant.py
+++ b/src/brevitas/proxy/parameter_quant.py
@@ -16,6 +16,7 @@
 from brevitas import config
 from brevitas import is_dynamo_compiling
 from brevitas.core.function_wrapper.misc import Identity
+from brevitas.core.quant.delay import DelayWrapper
 from brevitas.function import max_int
 from brevitas.inject import BaseInjector as Injector
 from brevitas.quant_tensor import _unpack_quant_tensor
@@ -94,6 +95,7 @@ def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
         self.cache_inference_quant_weight_metadata_only = False
         self.cache_class = None  # To be redefined by each class
         self.quant_tensor_class = None  # To be redefined by each class
+        self.delay_wrapper = DelayWrapper(quant_injector.quant_delay_steps)
 
     @property
     def input_view_impl(self):
@@ -136,7 +138,8 @@ def forward(self, x: torch.Tensor) -> Union[Tensor, QuantTensor]:
             else:
                 out = self.create_quant_tensor(out)
         else:
-            out = self.tensor_quant(x)
+            quantized_out = self.tensor_quant(x)
+            out = self.delay_wrapper(x, quantized_out)
             if is_dynamo_compiling():
                 out = out[0]
             else:
diff --git a/tests/brevitas/proxy/test_proxy.py b/tests/brevitas/proxy/test_proxy.py
index 28c3eed9e..4f25bf2f4 100644
--- a/tests/brevitas/proxy/test_proxy.py
+++ b/tests/brevitas/proxy/test_proxy.py
@@ -80,3 +80,27 @@ def test_dynamic_act_proxy(self):
 
         model.act_quant.disable_quant = True
         assert model.act_quant.bit_width() is None
+
+    def test_delay_wrapper_in_weight_proxy(self):
+        model = QuantLinear(10, 5, weight_quant=Int8WeightPerTensorFloat)
+        assert model.weight_quant.delay_wrapper is not None
+
+        model.weight_quant.delay_wrapper.quant_delay_steps = 5
+        for _ in range(5):
+            quantized_out = model.weight_quant(model.weight)
+            assert torch.equal(quantized_out, model.weight)
+
+        quantized_out = model.weight_quant(model.weight)
+        assert not torch.equal(quantized_out, model.weight)
+
+    def test_delay_wrapper_in_bias_proxy(self):
+        model = QuantLinear(10, 5, bias_quant=Int8BiasPerTensorFloatInternalScaling)
+        assert model.bias_quant.delay_wrapper is not None
+
+        model.bias_quant.delay_wrapper.quant_delay_steps = 5
+        for _ in range(5):
+            quantized_out = model.bias_quant(model.bias)
+            assert torch.equal(quantized_out, model.bias)
+
+        quantized_out = model.bias_quant(model.bias)
+        assert not torch.equal(quantized_out, model.bias)