bitsandbytes: simplify 8bit dequantization (#35068)
matthewdouglas authored Dec 23, 2024
1 parent 05260a1 commit 401aa39
Showing 1 changed file with 8 additions and 7 deletions.
src/transformers/integrations/bitsandbytes.py
@@ -363,13 +363,14 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", state=None):
     if state.SCB is None:
         state.SCB = weight.SCB
 
-    im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
-    im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
-    im, Sim = bnb.functional.transform(im, "col32")
-    if state.CxB is None:
-        state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
-    out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
-    return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t().to(dtype)
+    if hasattr(bnb.functional, "int8_vectorwise_dequant"):
+        # Use bitsandbytes API if available (requires v0.45.0+)
+        dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
+    else:
+        # Multiply by (scale/127) to dequantize.
+        dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3
+
+    return dequantized.to(dtype)
 
 
 def _create_accelerate_new_hook(old_hook):
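For reference, the removed path recovered the dequantized weight indirectly: it int8-quantized an identity matrix (double_quant), converted the operands to tiled layouts (transform with "col32"/formatB), multiplied them with the int8 kernel (igemmlt), and rescaled the int32 result (mm_dequant), so that the identity-matrix product reproduced the weight. The replacement dequantizes directly: state.SCB holds LLM.int8()'s per-row absmax scales and the int8 values lie in [-127, 127], so each original value is approximately int8_value * scale / 127 (the constant 7.874015718698502e-3 ≈ 1/127).

A minimal self-contained sketch of this row-wise scheme, round-tripping a random matrix; the helper names and the quantize step are illustrative, not transformers or bitsandbytes API:

import torch

def int8_rowwise_quant(W: torch.Tensor):
    """Per-row absmax quantization, the scheme LLM.int8() uses for weights."""
    scb = W.abs().amax(dim=1).float()  # per-row absmax scales (the role of state.SCB)
    q = torch.round(W.float() * 127.0 / scb.view(-1, 1)).to(torch.int8)
    return q, scb

def int8_rowwise_dequant(q: torch.Tensor, scb: torch.Tensor, dtype=torch.float16):
    """The fallback path from the diff: int8 value * (scale / 127)."""
    return (q * scb.view(-1, 1) * 7.874015718698502e-3).to(dtype)

W = torch.randn(4, 8, dtype=torch.float16)
q, scb = int8_rowwise_quant(W)
W_hat = int8_rowwise_dequant(q, scb)
print((W - W_hat).abs().max())  # per-element error is at most half a step, scale / 254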

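On bitsandbytes v0.45.0 or newer, the new fast path can be exercised directly; a hedged usage sketch mirroring the guard in the committed code (the two-argument call shape is taken from the diff, and the synthetic tensors stand in for weight.data and state.SCB):

import torch
import bitsandbytes as bnb

W_int8 = torch.randint(-127, 128, (4, 8), dtype=torch.int8)
SCB = torch.rand(4) * 2.0  # per-row absmax scales

if hasattr(bnb.functional, "int8_vectorwise_dequant"):  # bitsandbytes v0.45.0+
    out = bnb.functional.int8_vectorwise_dequant(W_int8, SCB)
else:
    out = W_int8 * SCB.view(-1, 1) * 7.874015718698502e-3  # same arithmetic, no library call
print(out.shape)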