diff --git a/server/text_generation_server/layers/awq/quantize/qmodule.py b/server/text_generation_server/layers/awq/quantize/qmodule.py
index ca8caf5080c..c859db1be6b 100644
--- a/server/text_generation_server/layers/awq/quantize/qmodule.py
+++ b/server/text_generation_server/layers/awq/quantize/qmodule.py
@@ -1,6 +1,7 @@
 # Copied logic from https://github.com/mit-han-lab/llm-awq/blob/f084f40bd996f3cf3a0633c1ad7d9d476c318aaa/awq/quantize/qmodule.py
 
 import math
+from typing import Optional
 import torch
 import torch.nn as nn
 import awq_inference_engine  # with CUDA kernels
@@ -17,7 +18,9 @@ class WQLinear(nn.Module):
-    def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias):
+    def __init__(
+        self, w_bit, group_size, qweight, qzeros, scales, bias: Optional[torch.Tensor]
+    ):
         super().__init__()
 
         if w_bit not in [4]:
@@ -35,10 +38,7 @@ def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias):
         self.qweight = qweight
         self.qzeros = qzeros
         self.scales = scales
-        if bias:
-            self.bias = bias
-        else:
-            self.bias = None
+        self.bias = bias
 
     @torch.no_grad()
     def forward(self, x):
diff --git a/server/text_generation_server/layers/linear.py b/server/text_generation_server/layers/linear.py
index d40b192f653..207383a50f5 100644
--- a/server/text_generation_server/layers/linear.py
+++ b/server/text_generation_server/layers/linear.py
@@ -217,7 +217,7 @@ def get_linear(weight, bias, quantize):
                 qweight=weight.qweight,
                 qzeros=weight.qzeros,
                 scales=weight.scales,
-                bias=bias is not None,
+                bias=bias,
             )
         except ImportError:
             raise NotImplementedError(
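Why this fix matters: before the patch, get_linear passed `bias=bias is not None` (a Python bool) while WQLinear.__init__ truth-tested the argument, so `self.bias` ended up as `True` rather than the actual bias tensor whenever a bias was present. The patch threads the Optional[torch.Tensor] through unchanged. A minimal standalone sketch of the pitfall (not part of the patch; the tensor size is illustrative):

    import torch

    bias = torch.zeros(4096)  # hypothetical bias tensor

    # Old call site collapsed the tensor to a bool before it reached the layer:
    flag = bias is not None   # True -- the real tensor is lost at this point

    # And truth-testing a real multi-element tensor raises outright:
    try:
        if bias:              # what the old `if bias:` would do with a tensor
            pass
    except RuntimeError as e:
        print(e)              # ambiguous-truth-value error for multi-element tensors

    # New behavior: store the Optional tensor as-is, as the patched __init__ does.
    stored_bias = bias

Typing the parameter as Optional[torch.Tensor] also lets downstream code use the explicit `if self.bias is not None` check instead of relying on truthiness.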