Skip to content

Commit

Permalink
Aligned to cuda accelerator
Browse files Browse the repository at this point in the history
  • Loading branch information
shiyang-weng committed May 14, 2024
1 parent 82ce4ae commit 78ccc6e
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions accelerator/xpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore
import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
import functools


class XPU_Accelerator(DeepSpeedAccelerator):
Expand Down Expand Up @@ -191,31 +192,31 @@ def supported_dtypes(self):

@property
def BFloat16Tensor(self):
return torch.xpu.BFloat16Tensor
return functools.partial(torch.tensor, dtype=torch.bfloat16, device='xpu')

@property
def ByteTensor(self):
return torch.xpu.ByteTensor
return functools.partial(torch.tensor, dtype=torch.uint8, device='xpu')

@property
def DoubleTensor(self):
return torch.xpu.DoubleTensor
return functools.partial(torch.tensor, dtype=torch.double, device='xpu')

@property
def FloatTensor(self):
return torch.xpu.FloatTensor
return functools.partial(torch.tensor, dtype=torch.float, device='xpu')

@property
def HalfTensor(self):
return torch.xpu.HalfTensor
return functools.partial(torch.tensor, dtype=torch.half, device='xpu')

@property
def IntTensor(self):
return torch.xpu.IntTensor
return functools.partial(torch.tensor, dtype=torch.int, device='xpu')

@property
def LongTensor(self):
return torch.xpu.LongTensor
return functools.partial(torch.tensor, dtype=torch.long, device='xpu')

def pin_memory(self, tensor, align_bytes=1):
if align_bytes == 1:
Expand Down

0 comments on commit 78ccc6e

Please sign in to comment.