diff --git a/accelerator/xpu_accelerator.py b/accelerator/xpu_accelerator.py index 22e9a055d669..3fed89d7200f 100644 --- a/accelerator/xpu_accelerator.py +++ b/accelerator/xpu_accelerator.py @@ -192,31 +192,31 @@ def supported_dtypes(self): @property def BFloat16Tensor(self): - return functools.partial(torch.tensor, dtype=torch.bfloat16, device='xpu') + return functools.partial(torch.tensor, dtype=torch.bfloat16, device=self._name) @property def ByteTensor(self): - return functools.partial(torch.tensor, dtype=torch.uint8, device='xpu') + return functools.partial(torch.tensor, dtype=torch.uint8, device=self._name) @property def DoubleTensor(self): - return functools.partial(torch.tensor, dtype=torch.double, device='xpu') + return functools.partial(torch.tensor, dtype=torch.double, device=self._name) @property def FloatTensor(self): - return functools.partial(torch.tensor, dtype=torch.float, device='xpu') + return functools.partial(torch.tensor, dtype=torch.float, device=self._name) @property def HalfTensor(self): - return functools.partial(torch.tensor, dtype=torch.half, device='xpu') + return functools.partial(torch.tensor, dtype=torch.half, device=self._name) @property def IntTensor(self): - return functools.partial(torch.tensor, dtype=torch.int, device='xpu') + return functools.partial(torch.tensor, dtype=torch.int, device=self._name) @property def LongTensor(self): - return functools.partial(torch.tensor, dtype=torch.long, device='xpu') + return functools.partial(torch.tensor, dtype=torch.long, device=self._name) def pin_memory(self, tensor, align_bytes=1): if align_bytes == 1: