This repository has been archived by the owner on Sep 25, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 4
/
08_07_PyTorch_Intel_XPU-orig.py
119 lines (106 loc) · 3.75 KB
/
08_07_PyTorch_Intel_XPU-orig.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import warnings
warnings.filterwarnings("ignore")
import argparse
import torch
# ANSI SGR escape sequences used to tint terminal output via ``colorize``;
# RESET restores the terminal's default attributes after a colored span.
GREEN, YELLOW, RED, RESET = (
    "\033[32m",  # foreground green (SGR 32)
    "\033[33m",  # foreground yellow (SGR 33)
    "\033[31m",  # foreground red (SGR 31)
    "\033[0m",  # reset all attributes (SGR 0)
)
def colorize(text, color):
    """Return *text* prefixed by *color* and suffixed by ``RESET``.

    Args:
        text: Object to embed; it is formatted exactly as an f-string
            replacement field would format it.
        color: Escape-sequence prefix (e.g. ``GREEN``, ``YELLOW``, ``RED``).

    Returns:
        A single string: *color*, then *text*, then ``RESET``.
    """
    return "{}{}{}".format(color, text, RESET)
def test_random_multiplication(dtype=torch.float32, device="xpu"):
    """Multiply two random 1x1 tensors of ``dtype`` on ``device`` and print them.

    This is a smoke test: it only checks that the element-wise product can be
    computed on the device and copied back to the host; the values are not
    validated.

    Args:
        dtype: torch dtype the inputs are cast to before multiplying.
        device: Target device string; defaults to ``"xpu"``.

    Raises:
        SystemExit: With status 1, after printing the error, if any step fails.
    """
    dtype_name = str(dtype).split(".")[-1]
    try:
        print(colorize(f"Random {dtype_name} multiplication:", GREEN))
        if dtype.is_floating_point:
            x = torch.rand(1, 1)
            y = torch.rand(1, 1)
        else:
            # torch.rand samples from [0, 1), which truncates to 0 for integer
            # dtypes; draw small integers instead so the product is non-trivial
            # while the largest product (9 * 9 = 81) still fits in int8.
            x = torch.randint(0, 10, (1, 1))
            y = torch.randint(0, 10, (1, 1))
        x = x.to(device, dtype=dtype)
        y = y.to(device, dtype=dtype)
        z = x * y
        print(" Input x:", x.cpu())
        print(" Input y:", y.cpu())
        print(" Output z:", z.cpu())
    except Exception as e:
        print(
            colorize(
                f"Error during {dtype_name} random multiplication: {e}",
                RED,
            )
        )
        # Equivalent to the site-provided exit(1), without depending on it.
        raise SystemExit(1)
def test_specific_multiplication(dtype=torch.float32, device="xpu"):
    """Check a known element-wise product of ``dtype`` tensors on ``device``.

    Computes ``[[1, 2]] * [[3, 4]]`` and compares it against ``[[3, 8]]`` with
    ``torch.allclose``. These small integers are exactly representable in the
    integer and floating-point dtypes this script tests, so a mismatch points
    at the device computation rather than rounding.

    Args:
        dtype: torch dtype the inputs are cast to before multiplying.
        device: Target device string; defaults to ``"xpu"``.

    Raises:
        SystemExit: With status 1 if the result is wrong or any step fails.
    """
    dtype_name = str(dtype).split(".")[-1]
    try:
        print(colorize(f"Specific {dtype_name} multiplication:", GREEN))
        x = torch.tensor([[1.0, 2.0]]).to(device, dtype=dtype)
        y = torch.tensor([[3.0, 4.0]]).to(device, dtype=dtype)
        z_expected = torch.tensor([[3.0, 8.0]]).to(device, dtype=dtype)
        z = x * y
        print(" Input x:", x.cpu())
        print(" Input y:", y.cpu())
        print(" Output z:", z.cpu())
        is_correct = torch.allclose(z, z_expected)
    except Exception as e:
        print(
            colorize(
                f"Error during {dtype_name} specific multiplication: {e}",
                RED,
            )
        )
        raise SystemExit(1)
    if is_correct:
        print("Calculation is correct")
    else:
        print("Calculation is incorrect")
        # A wrong result is a test failure: exit non-zero like the error path
        # so the run is not later reported as successful.
        raise SystemExit(1)
def main(args):
    """Run the XPU multiplication smoke tests for a range of dtypes.

    Args:
        args: Parsed CLI namespace; only ``args.seed`` (int) is read.

    Prints a success banner when every dtype passes. Any failure (Intel
    Extension for PyTorch not importable, no XPU device, or another exception)
    is reported in red and terminates the process with exit status 1.
    """
    try:
        print(f"torch version: {torch.__version__}")
        torch.manual_seed(args.seed)
        import intel_extension_for_pytorch as ipex

        if not ipex.xpu.is_available():
            print("Warning: Intel XPU device is not available")
            raise RuntimeError("Intel XPU device not detected")

        # Seed every XPU generator with the requested seed. The previous
        # seed_all() call re-seeds with a random value per the IPEX/torch RNG
        # API, which ignores --seed; it also ran before the availability check.
        ipex.xpu.manual_seed_all(args.seed)
        print(f"ipex version: {ipex.__version__}")
        device_name = ipex.xpu.get_device_name()
        print(f"Intel XPU device is available, Device name: {device_name}")

        # Queried once; the answer is reused for the warning and the skip test.
        has_fp64 = ipex.xpu.has_fp64_dtype()
        if not has_fp64:
            print(
                colorize(
                    "Warning: Native FP64 type not supported on this platform",
                    YELLOW,
                )
            )

        data_types = (
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.float16,
            torch.float32,
            torch.bfloat16,
            torch.float64,
        )
        for dtype in data_types:
            if dtype == torch.float64 and not has_fp64:
                print(
                    colorize(
                        "Skipping direct FP64 multiplication tests, as the device doesn't support it.",
                        YELLOW,
                    )
                )
                continue
            # Each helper exits the process itself on failure.
            test_random_multiplication(dtype)
            test_specific_multiplication(dtype)
        print(colorize("PyTorch XPU tests successful!", GREEN))
    except ImportError as e:
        print(colorize(f"Failed to import Intel Extension for PyTorch: {e}", RED))
        raise SystemExit(1)
    except Exception as e:
        print(colorize(f"An error occurred during the test: {e}", RED))
        raise SystemExit(1)
# Script entry point: parse the CLI and hand the namespace to main().
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Test Intel XPU device")
    cli.add_argument("--seed", type=int, default=42, help="Random seed")
    main(cli.parse_args())