Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move lapack_info_check inside of onemkl_cusolver_host_task #238

Open
wants to merge 12 commits into
base: develop
Choose a base branch
from
27 changes: 9 additions & 18 deletions src/lapack/backends/cusolver/cusolver_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,26 +184,25 @@ inline void getrf_batch(const char *func_name, Func func, sycl::queue &queue, st
// Create new buffer with 32-bit ints then copy over results
std::uint64_t ipiv_size = stride_ipiv * batch_size;
sycl::buffer<int> ipiv32(sycl::range<1>{ ipiv_size });
sycl::buffer<int> devInfo{ batch_size };

queue.submit([&](sycl::handler &cgh) {
auto a_acc = a.template get_access<sycl::access::mode::read_write>(cgh);
auto ipiv32_acc = ipiv32.template get_access<sycl::access::mode::write>(cgh);
auto devInfo_acc = devInfo.template get_access<sycl::access::mode::write>(cgh);
auto scratch_acc = scratchpad.template get_access<sycl::access::mode::write>(cgh);
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = sc.get_mem<cuDataType *>(a_acc);
auto ipiv_ = sc.get_mem<int *>(ipiv32_acc);
auto devInfo_ = sc.get_mem<int *>(devInfo_acc);
auto scratch_ = sc.get_mem<cuDataType *>(scratch_acc);
cusolverStatus_t err;
int *dev_info_d = create_dev_info(batch_size);

// Uses scratch so sync between each cuSolver call
for (std::int64_t i = 0; i < batch_size; ++i) {
CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_ + stride_a * i,
lda, scratch_, ipiv_ + stride_ipiv * i, devInfo_ + i);
lda, scratch_, ipiv_ + stride_ipiv * i, dev_info_d + i);
}
lapack_info_check_and_free(dev_info_d, __func__, func_name, batch_size);
});
});

Expand All @@ -215,7 +214,6 @@ inline void getrf_batch(const char *func_name, Func func, sycl::queue &queue, st
[=](sycl::id<1> index) { ipiv_acc[index] = ipiv32_acc[index]; });
});

lapack_info_check(queue, devInfo, __func__, func_name, batch_size);
}

#define GETRF_STRIDED_BATCH_LAUNCHER(TYPE, CUSOLVER_ROUTINE) \
Expand Down Expand Up @@ -571,7 +569,6 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu
// Allocate memory with 32-bit ints then copy over results
std::uint64_t ipiv_size = stride_ipiv * batch_size;
int *ipiv32 = (int *)malloc_device(sizeof(int) * ipiv_size, queue);
int *devInfo = (int *)malloc_device(sizeof(int) * batch_size, queue);

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dependencies loop can be simplified in the batch with depends_on_events(cgh, dependencies);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice one. I've made these changes.

Expand All @@ -581,16 +578,17 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu
onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) {
auto handle = sc.get_handle(queue);
auto a_ = reinterpret_cast<cuDataType *>(a);
auto devInfo_ = reinterpret_cast<int *>(devInfo);
auto scratchpad_ = reinterpret_cast<cuDataType *>(scratchpad);
auto ipiv_ = reinterpret_cast<int *>(ipiv32);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can also remove this unnecessary reinterpret cast as ipiv32 is already an int *

cusolverStatus_t err;
int *dev_info_d = create_dev_info(batch_size);

// Uses scratch so sync between each cuSolver call
for (int64_t i = 0; i < batch_size; ++i) {
CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_ + stride_a * i,
lda, scratchpad_, ipiv_ + stride_ipiv * i, devInfo_ + i);
lda, scratchpad_, ipiv_ + stride_ipiv * i, dev_info_d + i);
}
lapack_info_check_and_free(dev_info_d, __func__, func_name, batch_size);
});
});

Expand All @@ -607,10 +605,6 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu
cgh.host_task([=](sycl::interop_handle ih) { sycl::free(ipiv32, queue); });
});

// lapack_info_check calls queue.wait()
lapack_info_check(queue, devInfo, __func__, func_name, batch_size);
sycl::free(devInfo, queue);

return done_casting;
}

Expand Down Expand Up @@ -656,7 +650,6 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu
for (int64_t group_id = 0; group_id < group_count; ++group_id)
for (int64_t local_id = 0; local_id < group_sizes[group_id]; ++local_id, ++global_id)
ipiv32[global_id] = (int *)malloc_device(sizeof(int) * n[group_id], queue);
int *devInfo = (int *)malloc_device(sizeof(int) * batch_size, queue);

auto done = queue.submit([&](sycl::handler &cgh) {
int64_t num_events = dependencies.size();
Expand All @@ -669,16 +662,18 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu
auto scratch_ = reinterpret_cast<cuDataType *>(scratchpad);
int64_t global_id = 0;
cusolverStatus_t err;
int *dev_info_d = create_dev_info(batch_size);

// Uses scratch so sync between each cuSolver call
for (int64_t group_id = 0; group_id < group_count; ++group_id) {
for (int64_t local_id = 0; local_id < group_sizes[group_id];
++local_id, ++global_id) {
CUSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m[group_id],
n[group_id], a_[global_id], lda[group_id], scratch_,
ipiv32[global_id], devInfo + global_id);
ipiv32[global_id], dev_info_d + global_id);
}
}
lapack_info_check_and_free(dev_info_d, __func__, func_name, batch_size);
});
});

Expand Down Expand Up @@ -712,10 +707,6 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu
});
});

// lapack_info_check calls queue.wait()
lapack_info_check(queue, devInfo, __func__, func_name, batch_size);
sycl::free(devInfo, queue);

return done_freeing;
}

Expand Down
56 changes: 36 additions & 20 deletions src/lapack/backends/cusolver/cusolver_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,30 +280,46 @@ struct CudaEquivalentType<std::complex<double>> {

/* devinfo */

inline void get_cusolver_devinfo(sycl::queue &queue, sycl::buffer<int> &devInfo,
std::vector<int> &dev_info_) {
sycl::host_accessor<int, 1, sycl::access::mode::read> dev_info_acc{ devInfo };
for (unsigned int i = 0; i < dev_info_.size(); ++i)
dev_info_[i] = dev_info_acc[i];
// Accepts a int*, copies the memory from device to host,
// checks value does not indicate an error, frees the device memory
inline void lapack_info_check_and_free(int *dev_info_d, const char *func_name,
const char *cufunc_name, int num_elements = 1) {
int *dev_info_h = (int *)malloc(sizeof(int) * num_elements);
cuMemcpyDtoH(dev_info_h, reinterpret_cast<CUdeviceptr>(dev_info_d), sizeof(int) * num_elements);
for (uint32_t i = 0; i < num_elements; ++i) {
if (dev_info_h[i] > 0)
throw oneapi::mkl::lapack::computation_error(
func_name,
std::string(cufunc_name) + " failed with info = " + std::to_string(dev_info_h[i]),
dev_info_h[i]);
}
cuMemFree(reinterpret_cast<CUdeviceptr>(dev_info_d));
}

inline void get_cusolver_devinfo(sycl::queue &queue, const int *devInfo,
std::vector<int> &dev_info_) {
queue.wait();
queue.memcpy(dev_info_.data(), devInfo, sizeof(int));
// Allocates and returns a CUDA device pointer for cuSolver dev_info
inline int *create_dev_info(int num_elements = 1) {
CUdeviceptr dev_info_d;
cuMemAlloc(&dev_info_d, sizeof(int) * num_elements);
return reinterpret_cast<int *>(dev_info_d);
}

template <typename DEVINFO_T>
inline void lapack_info_check(sycl::queue &queue, DEVINFO_T devinfo, const char *func_name,
const char *cufunc_name, int dev_info_size = 1) {
std::vector<int> dev_info_(dev_info_size);
get_cusolver_devinfo(queue, devinfo, dev_info_);
for (const auto &val : dev_info_) {
if (val > 0)
throw oneapi::mkl::lapack::computation_error(
func_name, std::string(cufunc_name) + " failed with info = " + std::to_string(val),
val);
}
// Helper function for waiting on a vector of sycl events
inline void depends_on_events(sycl::handler &cgh,
const std::vector<sycl::event> &dependencies = {}) {
for (auto &e : dependencies)
cgh.depends_on(e);
}

// Asynchronously frees sycl USM `ptr` after waiting on events `dependencies`
template <typename T>
inline sycl::event free_async(sycl::queue &queue, T *ptr,
const std::vector<sycl::event> &dependencies = {}) {
sycl::event done = queue.submit([&](sycl::handler &cgh) {
depends_on_events(cgh, dependencies);

cgh.host_task([=](sycl::interop_handle ih) { sycl::free(ptr, queue); });
});
return done;
}

/* batched helpers */
Expand Down
Loading