Skip to content

Commit

Permalink
apply suggestion for device type
Browse files Browse the repository at this point in the history
  • Loading branch information
shiyang-weng committed May 15, 2024
1 parent 78ccc6e commit 93b8b00
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions accelerator/xpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,31 +192,31 @@ def supported_dtypes(self):

@property
def BFloat16Tensor(self):
return functools.partial(torch.tensor, dtype=torch.bfloat16, device='xpu')
return functools.partial(torch.tensor, dtype=torch.bfloat16, device=self._name)

@property
def ByteTensor(self):
return functools.partial(torch.tensor, dtype=torch.uint8, device='xpu')
return functools.partial(torch.tensor, dtype=torch.uint8, device=self._name)

@property
def DoubleTensor(self):
return functools.partial(torch.tensor, dtype=torch.double, device='xpu')
return functools.partial(torch.tensor, dtype=torch.double, device=self._name)

@property
def FloatTensor(self):
return functools.partial(torch.tensor, dtype=torch.float, device='xpu')
return functools.partial(torch.tensor, dtype=torch.float, device=self._name)

@property
def HalfTensor(self):
return functools.partial(torch.tensor, dtype=torch.half, device='xpu')
return functools.partial(torch.tensor, dtype=torch.half, device=self._name)

@property
def IntTensor(self):
return functools.partial(torch.tensor, dtype=torch.int, device='xpu')
return functools.partial(torch.tensor, dtype=torch.int, device=self._name)

@property
def LongTensor(self):
return functools.partial(torch.tensor, dtype=torch.long, device='xpu')
return functools.partial(torch.tensor, dtype=torch.long, device=self._name)

def pin_memory(self, tensor, align_bytes=1):
if align_bytes == 1:
Expand Down

0 comments on commit 93b8b00

Please sign in to comment.