Skip to content

Commit

Permalink
xpu: support xpu backend from stock pytorch (>=2.4) (huggingface#31238)
Browse files Browse the repository at this point in the history
* xpu: support xpu backend from stock pytorch (>=2.4)

Fixes: huggingface#31237

XPU backend is available in the stock PyTorch starting from
version 2.4, see [1]. This commit extends huggingface transformers
to support XPU from both IPEX and the stock pytorch. IPEX is being
tried first.

See: pytorch/pytorch#114842
Requires: huggingface/accelerate#2825
Signed-off-by: Dmitry Rogozhkin <[email protected]>

* xpu: enable gpt2 and decision_transformer tests for xpu pytorch backend

Note that running xpu tests requires TRANSFORMERS_TEST_DEVICE_SPEC=spec.py
passed to the test runner:

  import torch
  DEVICE_NAME = 'xpu'
  MANUAL_SEED_FN = torch.xpu.manual_seed
  EMPTY_CACHE_FN = torch.xpu.empty_cache
  DEVICE_COUNT_FN = torch.xpu.device_count

Signed-off-by: Dmitry Rogozhkin <[email protected]>

---------

Signed-off-by: Dmitry Rogozhkin <[email protected]>
  • Loading branch information
dvrogozh authored Jun 14, 2024
1 parent 2081223 commit eed9ed6
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import torch
import torch.utils.checkpoint
from torch import nn
from torch.cuda.amp import autocast

from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
Expand Down Expand Up @@ -219,7 +218,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea
scale_factor /= float(self.layer_idx + 1)

# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
with autocast(enabled=False):
with torch.amp.autocast(query.device.type, enabled=False):
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
Expand Down
3 changes: 1 addition & 2 deletions src/transformers/models/gpt2/modeling_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.cuda.amp import autocast
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
Expand Down Expand Up @@ -249,7 +248,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea
scale_factor /= float(self.layer_idx + 1)

# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
with autocast(enabled=False):
with torch.amp.autocast(query.device.type, enabled=False):
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
Expand Down
15 changes: 8 additions & 7 deletions src/transformers/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,23 +813,24 @@ def require_torch_multi_npu(test_case):

def require_torch_xpu(test_case):
"""
Decorator marking a test that requires XPU and IPEX.
Decorator marking a test that requires XPU (in PyTorch).
These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match current PyTorch
version.
These tests are skipped when XPU backend is not available. XPU backend might be available either via stock
PyTorch (>=2.4) or via Intel Extension for PyTorch. In the latter case, if IPEX is installed, its version
must match match current PyTorch version.
"""
return unittest.skipUnless(is_torch_xpu_available(), "test requires IPEX and an XPU device")(test_case)
return unittest.skipUnless(is_torch_xpu_available(), "test requires XPU device")(test_case)


def require_torch_multi_xpu(test_case):
"""
Decorator marking a test that requires a multi-XPU setup with IPEX and at least one XPU device. These tests are
skipped on a machine without IPEX or multiple XPUs.
Decorator marking a test that requires a multi-XPU setup (in PyTorch). These tests are skipped on a machine without
multiple XPUs.
To run *only* the multi_xpu tests, assuming all test names contain multi_xpu: $ pytest -sv ./tests -k "multi_xpu"
"""
if not is_torch_xpu_available():
return unittest.skip("test requires IPEX and at least one XPU device")(test_case)
return unittest.skip("test requires PyTorch XPU")(test_case)

return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)

Expand Down
3 changes: 3 additions & 0 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
ExplicitEnum,
cached_property,
is_accelerate_available,
is_ipex_available,
is_safetensors_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
Expand Down Expand Up @@ -2136,6 +2137,8 @@ def _setup_devices(self) -> "torch.device":
if self.use_cpu:
device = torch.device("cpu")
elif is_torch_xpu_available():
if not is_ipex_available() and not is_accelerate_available("0.32.0.dev"):
raise ImportError("Using the XPU PyTorch backend requires `accelerate>=0.32.0.dev`")
device = torch.device("xpu:0")
torch.xpu.set_device(device)
elif is_torch_mlu_available():
Expand Down
11 changes: 8 additions & 3 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,13 +747,18 @@ def get_major_and_minor_from_version(full_version):

@lru_cache
def is_torch_xpu_available(check_device=False):
"Checks if `intel_extension_for_pytorch` is installed and potentially if a XPU is in the environment"
if not is_ipex_available():
"""
Checks if XPU acceleration is available either via `intel_extension_for_pytorch` or
via stock PyTorch (>=2.4) and potentially if a XPU is in the environment
"""
if not is_torch_available():
return False

import intel_extension_for_pytorch # noqa: F401
import torch

if is_ipex_available():
import intel_extension_for_pytorch # noqa: F401

if check_device:
try:
# Will raise a RuntimeError if no XPU is found
Expand Down

0 comments on commit eed9ed6

Please sign in to comment.