diff --git a/src/transformers/integrations/bitnet.py b/src/transformers/integrations/bitnet.py index 3386bdcb43b27c..0b50f9738afb69 100644 --- a/src/transformers/integrations/bitnet.py +++ b/src/transformers/integrations/bitnet.py @@ -127,6 +127,8 @@ class BitLinear(nn.Module): def __init__(self, in_features: int, out_features: int, bias: bool, device=None, dtype=None): super().__init__() self.dtype = dtype + self.in_features = in_features + self.out_features = out_features self.register_buffer( "weight", torch.zeros(