From 93b8b002f93a03971270a083f7cee4df8bbe852e Mon Sep 17 00:00:00 2001 From: wengshiy Date: Wed, 15 May 2024 05:42:05 +0000 Subject: [PATCH] apply suggestion for device type --- accelerator/xpu_accelerator.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/accelerator/xpu_accelerator.py b/accelerator/xpu_accelerator.py index 22e9a055d669..3fed89d7200f 100644 --- a/accelerator/xpu_accelerator.py +++ b/accelerator/xpu_accelerator.py @@ -192,31 +192,31 @@ def supported_dtypes(self): @property def BFloat16Tensor(self): - return functools.partial(torch.tensor, dtype=torch.bfloat16, device='xpu') + return functools.partial(torch.tensor, dtype=torch.bfloat16, device=self._name) @property def ByteTensor(self): - return functools.partial(torch.tensor, dtype=torch.uint8, device='xpu') + return functools.partial(torch.tensor, dtype=torch.uint8, device=self._name) @property def DoubleTensor(self): - return functools.partial(torch.tensor, dtype=torch.double, device='xpu') + return functools.partial(torch.tensor, dtype=torch.double, device=self._name) @property def FloatTensor(self): - return functools.partial(torch.tensor, dtype=torch.float, device='xpu') + return functools.partial(torch.tensor, dtype=torch.float, device=self._name) @property def HalfTensor(self): - return functools.partial(torch.tensor, dtype=torch.half, device='xpu') + return functools.partial(torch.tensor, dtype=torch.half, device=self._name) @property def IntTensor(self): - return functools.partial(torch.tensor, dtype=torch.int, device='xpu') + return functools.partial(torch.tensor, dtype=torch.int, device=self._name) @property def LongTensor(self): - return functools.partial(torch.tensor, dtype=torch.long, device='xpu') + return functools.partial(torch.tensor, dtype=torch.long, device=self._name) def pin_memory(self, tensor, align_bytes=1): if align_bytes == 1: