Skip to content

Commit

Permalink
Add xpu support for int8
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuhong61 authored and shiyan1121 committed Aug 20, 2024
1 parent 3f48fd4 commit c9cbb3e
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 4 deletions.
4 changes: 2 additions & 2 deletions bitsandbytes/autograd/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def backward(ctx, grad_output):

def supports_igemmlt(device: torch.device) -> bool:
"""check if this device supports the optimized int8 kernel"""
if device == torch.device("cpu"):
if device == torch.device("cpu") or torch.device("xpu"):
return True
if torch.version.hip:
return False if BNB_HIP_VERSION < 601 else True
Expand Down Expand Up @@ -321,7 +321,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):

# Cast A to fp16
A_dtype = torch.float16
if A.device == torch.device("cpu"):
if A.device == torch.device("cpu") or torch.device("xpu"):
A_dtype = torch.bfloat16
if A.dtype != A_dtype:
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to {A_dtype} during quantization")
Expand Down
1 change: 1 addition & 0 deletions bitsandbytes/backends/cpu_xpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,7 @@ def gemm_4bit_impl(
if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(state, "op_context"):
assert state.op_context is not None
output = torch.ops.torch_ipex.ipex_woq_linear(A, state.op_context.get_data_handle())
# TODO: Support XPU optimization path
else:
dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize).t()
output = torch.matmul(A, dqB.to(A.dtype))
Expand Down
3 changes: 3 additions & 0 deletions bitsandbytes/backends/xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def assert_on_xpu(tensors):


class XPUBackend(Backend):
mm_dequant_compute_dtype = torch.bfloat16
mm_dequant_output_dtype = torch.bfloat16

def double_quant(
self,
A: torch.Tensor,
Expand Down
2 changes: 1 addition & 1 deletion bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1800,7 +1800,7 @@ class COOSparseTensor:
def __init__(self, rows, cols, nnz, rowidx, colidx, values):
assert rowidx.dtype == torch.int32
assert colidx.dtype == torch.int32
if values.device == torch.device("cpu"):
if values.device == torch.device("cpu") or torch.device("xpu"):
assert values.dtype in [torch.bfloat16, torch.half, torch.float]
else:
assert values.dtype == torch.float16
Expand Down
8 changes: 7 additions & 1 deletion bitsandbytes/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@ def cpu(self):

def xpu(self):
# we store the 8-bit rows-major weight
B = self.data.contiguous().bfloat16().cpu()
B = self.data.contiguous().bfloat16().xpu()
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
if CBt is not None:
del CBt
Expand Down Expand Up @@ -661,6 +661,12 @@ def to(self, *args, **kwargs):
return self
else:
return self.cpu()
elif device.type == "xpu":
if self.data.dtype == torch.int8:
self.CB = self.data
return self
else:
return self.xpu()
else:
new_param = Int8Params(
super().to(device=device, dtype=dtype, non_blocking=non_blocking),
Expand Down

0 comments on commit c9cbb3e

Please sign in to comment.