diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index e9821cd36..d33dd1bc5 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -513,7 +513,7 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
         ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype

         if any(ctx.needs_input_grad[:2]):
-            ctx.tensors = (A, B)
+            ctx.tensors = (None, B)
         else:
             ctx.tensors = (None, None)

@@ -526,7 +526,7 @@ def backward(ctx, grad_output):
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None

        req_gradA, _, _, req_gradBias, _ = ctx.needs_input_grad
-        A, B = ctx.tensors
+        _, B = ctx.tensors

        grad_A, grad_B, grad_bias = None, None, None
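
The change drops the activation A from the tensors stashed for backward. For out = A @ B, grad_A = grad_output @ B.T depends only on B, and this op never computes grad_B (the quantized weight is frozen), so saving A only keeps the activation alive in memory until backward runs. Below is a minimal sketch of the same save-only-what-backward-needs pattern, assuming a plain-tensor B rather than bitsandbytes' quantized weight; MatMulSaveOnlyB is a hypothetical name for illustration, not the library's MatMul4Bit:

    import torch

    class MatMulSaveOnlyB(torch.autograd.Function):
        @staticmethod
        def forward(ctx, A, B):
            # Save only what backward needs: grad_A uses B, never A.
            if any(ctx.needs_input_grad[:2]):
                ctx.tensors = (None, B)
            else:
                ctx.tensors = (None, None)
            return torch.matmul(A, B)

        @staticmethod
        def backward(ctx, grad_output):
            _, B = ctx.tensors
            grad_A = None
            if ctx.needs_input_grad[0]:
                # grad_A = dL/dout @ B^T; A itself is never referenced.
                grad_A = torch.matmul(grad_output, B.t().to(grad_output.dtype))
            # grad_B stays None: the weight is treated as frozen here.
            return grad_A, None

    A = torch.randn(4, 8, requires_grad=True)
    B = torch.randn(8, 3)  # frozen weight, no grad needed
    out = MatMulSaveOnlyB.apply(A, B)
    out.sum().backward()   # A.grad is computed using only the saved B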