diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 748d87726..0e0d233df 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -221,7 +221,7 @@ def backward(ctx, grad_output):
 
 def supports_igemmlt(device: torch.device) -> bool:
     """check if this device supports the optimized int8 kernel"""
-    if device == torch.device("cpu"):
+    if device.type in ("cpu", "xpu"):
         return True
     if torch.version.hip:
         return False if BNB_HIP_VERSION < 601 else True
@@ -321,7 +321,7 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
 
         # Cast A to fp16
         A_dtype = torch.float16
-        if A.device == torch.device("cpu"):
+        if A.device.type in ("cpu", "xpu"):
             A_dtype = torch.bfloat16
         if A.dtype != A_dtype:
             warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to {A_dtype} during quantization")
diff --git a/bitsandbytes/backends/cpu_xpu_common.py b/bitsandbytes/backends/cpu_xpu_common.py
index 4c4902f9d..cd8f863c7 100644
--- a/bitsandbytes/backends/cpu_xpu_common.py
+++ b/bitsandbytes/backends/cpu_xpu_common.py
@@ -535,6 +535,7 @@ def gemm_4bit_impl(
     if ipex_cpu and _ipex_cpu_version_prereq(2, 3) and hasattr(state, "op_context"):
         assert state.op_context is not None
         output = torch.ops.torch_ipex.ipex_woq_linear(A, state.op_context.get_data_handle())
+    # TODO: Support XPU optimization path
     else:
         dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize).t()
         output = torch.matmul(A, dqB.to(A.dtype))
diff --git a/bitsandbytes/backends/xpu.py b/bitsandbytes/backends/xpu.py
index 02774fd1d..4681e7297 100644
--- a/bitsandbytes/backends/xpu.py
+++ b/bitsandbytes/backends/xpu.py
@@ -30,6 +30,9 @@ def assert_on_xpu(tensors):
 
 
 class XPUBackend(Backend):
+    mm_dequant_compute_dtype = torch.bfloat16
+    mm_dequant_output_dtype = torch.bfloat16
+
     def double_quant(
         self,
         A: torch.Tensor,
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 6cf64df28..d486dc474 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -1800,7 +1800,7 @@ class COOSparseTensor:
     def __init__(self, rows, cols, nnz, rowidx, colidx, values):
         assert rowidx.dtype == torch.int32
         assert colidx.dtype == torch.int32
-        if values.device == torch.device("cpu"):
+        if values.device.type in ("cpu", "xpu"):
             assert values.dtype in [torch.bfloat16, torch.half, torch.float]
         else:
             assert values.dtype == torch.float16
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index 43110c3c1..e0b5a5040 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -625,7 +625,7 @@ def cpu(self):
 
     def xpu(self):
         # we store the 8-bit rows-major weight
-        B = self.data.contiguous().bfloat16().cpu()
+        B = self.data.contiguous().bfloat16().xpu()
        CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
         if CBt is not None:
             del CBt
@@ -661,6 +661,12 @@ def to(self, *args, **kwargs):
                 return self
             else:
                 return self.cpu()
+        elif device.type == "xpu":
+            if self.data.dtype == torch.int8:
+                self.CB = self.data
+                return self
+            else:
+                return self.xpu()
         else:
             new_param = Int8Params(
                 super().to(device=device, dtype=dtype, non_blocking=non_blocking),
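
For reference, a minimal usage sketch of the path these hunks enable (not part of the patch). It assumes a PyTorch build with XPU support and this multi-backend build of bitsandbytes installed; the layer sizes and input shape are arbitrary. Moving a Linear8bitLt module to "xpu" dispatches through the new Int8Params.to() branch, which quantizes the weight via Int8Params.xpu(); the forward pass then takes the bfloat16 compute path from the autograd hunk.

import torch
import bitsandbytes as bnb

# Int8 linear layer; has_fp16_weights=False keeps the quantized int8 weight.
linear = bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False)

# Routes through Int8Params.to() -> Int8Params.xpu(), which casts the weight
# to bfloat16 on the XPU device and quantizes it with bnb.functional.double_quant.
linear = linear.to("xpu")

# On the cpu/xpu path, MatMul8bitLt computes in bfloat16 (see the A_dtype hunk).
x = torch.randn(4, 64, device="xpu", dtype=torch.bfloat16)
out = linear(x)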