diff --git a/accelerator/xpu_accelerator.py b/accelerator/xpu_accelerator.py
index 9c4a9c903f96..22e9a055d669 100644
--- a/accelerator/xpu_accelerator.py
+++ b/accelerator/xpu_accelerator.py
@@ -7,6 +7,7 @@
 from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
 import intel_extension_for_pytorch as ipex  # noqa: F401 # type: ignore
 import oneccl_bindings_for_pytorch  # noqa: F401 # type: ignore
+import functools


 class XPU_Accelerator(DeepSpeedAccelerator):
@@ -191,31 +192,31 @@ def supported_dtypes(self):

     @property
     def BFloat16Tensor(self):
-        return torch.xpu.BFloat16Tensor
+        return functools.partial(torch.tensor, dtype=torch.bfloat16, device='xpu')

     @property
     def ByteTensor(self):
-        return torch.xpu.ByteTensor
+        return functools.partial(torch.tensor, dtype=torch.uint8, device='xpu')

     @property
     def DoubleTensor(self):
-        return torch.xpu.DoubleTensor
+        return functools.partial(torch.tensor, dtype=torch.double, device='xpu')

     @property
     def FloatTensor(self):
-        return torch.xpu.FloatTensor
+        return functools.partial(torch.tensor, dtype=torch.float, device='xpu')

     @property
     def HalfTensor(self):
-        return torch.xpu.HalfTensor
+        return functools.partial(torch.tensor, dtype=torch.half, device='xpu')

     @property
     def IntTensor(self):
-        return torch.xpu.IntTensor
+        return functools.partial(torch.tensor, dtype=torch.int, device='xpu')

     @property
     def LongTensor(self):
-        return torch.xpu.LongTensor
+        return functools.partial(torch.tensor, dtype=torch.long, device='xpu')

     def pin_memory(self, tensor, align_bytes=1):
         if align_bytes == 1:
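
For reference, a minimal standalone sketch of what the rewritten properties return (the names `FloatTensor` and `x` below are illustrative, and it assumes an XPU device is visible to PyTorch via intel_extension_for_pytorch): each property now yields a `functools.partial` over `torch.tensor` that builds tensors with the requested dtype directly on the XPU, matching the data-construction use of the legacy `torch.xpu.FloatTensor` types. One caveat worth noting: the legacy types also accepted a bare size (e.g. a call with `3` allocated an uninitialized length-3 tensor), whereas `torch.tensor(3)` creates a 0-dim tensor holding the value 3, so callers relying on size-style construction would see different behavior.

```python
import functools

import torch
import intel_extension_for_pytorch  # noqa: F401  # registers the 'xpu' device

# Illustrative equivalent of the new FloatTensor property.
FloatTensor = functools.partial(torch.tensor, dtype=torch.float, device='xpu')

x = FloatTensor([1.0, 2.0, 3.0])
print(x.dtype)   # torch.float32
print(x.device)  # xpu:0
```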