forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
DeviceAccelerator.cpp
65 lines (55 loc) · 2.28 KB
/
DeviceAccelerator.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
#include <torch/csrc/DeviceAccelerator.h>
#include <torch/csrc/utils/device_lazy_init.h>
namespace torch::accelerator {

// Registers the `_accelerator_*` Python bindings on `module`.
//
// Every binding except `_accelerator_deviceCount` calls
// `getAccelerator(true)`, which raises when no accelerator is available; that
// is why the returned optional can be unwrapped with `.value()` directly.
// Bindings that touch device state call `maybe_initialize_device` first, so the
// first such call may trigger backend initialization.
void initModule(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  // Returns the current accelerator as a `c10::Device`.
  // Raises if no accelerator is currently available.
  m.def("_accelerator_getAccelerator", []() {
    return c10::Device(at::accelerator::getAccelerator(true).value());
  });

  // Returns the number of devices of the current accelerator, or 0 when no
  // accelerator is available (this binding does not raise in that case).
  m.def("_accelerator_deviceCount", []() -> c10::DeviceIndex {
    const auto device_type = at::accelerator::getAccelerator(false);
    if (!device_type.has_value()) {
      // Nothing to initialize or count: report zero explicitly instead of
      // relying on the callees' handling of a missing accelerator.
      return 0;
    }
    torch::utils::maybe_initialize_device(device_type.value());
    return at::accelerator::deviceCount();
  });

  // Sets the current device index. A negative index is a no-op.
  m.def("_accelerator_setDeviceIndex", [](c10::DeviceIndex device_index) {
    // Check for an accelerator before the negative-index early return so that
    // a missing accelerator is reported regardless of the index passed.
    const auto device_type = at::accelerator::getAccelerator(true).value();
    if (device_index < 0) {
      return;
    }
    torch::utils::maybe_initialize_device(device_type);
    at::accelerator::setDeviceIndex(device_index);
  });

  // Returns the current device index of the accelerator.
  m.def("_accelerator_getDeviceIndex", []() {
    const auto device_type = at::accelerator::getAccelerator(true).value();
    torch::utils::maybe_initialize_device(device_type);
    return at::accelerator::getDeviceIndex();
  });

  // Makes `stream` the current stream. If the stream lives on a different
  // device index, that index first becomes the current device.
  m.def("_accelerator_setStream", [](c10::Stream stream) {
    const auto device_type = at::accelerator::getAccelerator(true).value();
    // Validate before any side effect (lazy init, device switch): a stream from
    // another backend would otherwise reach setDeviceIndex/setCurrentStream
    // with an index/stream that does not belong to the current accelerator.
    TORCH_CHECK(
        stream.device_type() == device_type,
        "stream's device type ",
        c10::DeviceTypeName(stream.device_type()),
        " doesn't match the current accelerator ",
        c10::DeviceTypeName(device_type));
    torch::utils::maybe_initialize_device(device_type);
    if (at::accelerator::getDeviceIndex() != stream.device_index()) {
      at::accelerator::setDeviceIndex(stream.device_index());
    }
    at::accelerator::setCurrentStream(stream);
  });

  // Returns the current stream for `device_index` on the accelerator.
  m.def("_accelerator_getStream", [](c10::DeviceIndex device_index) {
    const auto device_type = at::accelerator::getAccelerator(true).value();
    torch::utils::maybe_initialize_device(device_type);
    return at::accelerator::getCurrentStream(device_index);
  });

  // Blocks until all work on `device_index` has completed. The GIL is
  // released while waiting so other Python threads can make progress.
  m.def("_accelerator_synchronizeDevice", [](c10::DeviceIndex device_index) {
    const auto device_type = at::accelerator::getAccelerator(true).value();
    // Skip only for lazily-initialized backends that were never initialized:
    // presumably no work can have been queued on them yet. Backends without
    // lazy-init support are never initialized through the lazy-init path, so
    // gating on is_device_initialized alone would turn synchronize into a
    // silent no-op for them.
    // NOTE(review): confirm against torch/csrc/utils/device_lazy_init.cpp.
    if (torch::utils::is_device_lazy_init_supported(device_type) &&
        !torch::utils::is_device_initialized(device_type)) {
      return;
    }
    torch::utils::maybe_initialize_device(device_type);
    {
      py::gil_scoped_release no_gil;
      at::accelerator::synchronizeDevice(device_index);
    }
  });
}

} // namespace torch::accelerator