diff --git a/src/blas/backends/rocblas/rocblas_extensions.cpp b/src/blas/backends/rocblas/rocblas_extensions.cpp index 315f9ce30..59c20b03f 100644 --- a/src/blas/backends/rocblas/rocblas_extensions.cpp +++ b/src/blas/backends/rocblas/rocblas_extensions.cpp @@ -88,27 +88,43 @@ void gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose tra throw unimplemented("blas", "gemmt", "for column_major layout"); } -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} - -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} - -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - sycl::buffer, 1> &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} - -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - sycl::buffer, 1> &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} +template +void omatcopy(const char *func_name, Func func, sycl::queue &queue, transpose trans, int64_t m, + int64_t n, T alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, + int64_t ldb) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb); + const T beta = 0; + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + const int64_t logical_m = (trans == oneapi::mkl::transpose::nontrans ? m : n); + const int64_t logical_n = (trans == oneapi::mkl::transpose::nontrans ? 
n : m); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + rocblas_status err; + ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(trans), + get_rocblas_operation(trans), logical_m, logical_n, + (rocDataType *)&alpha, a_, lda, (rocDataType *)&beta, nullptr, + lda, b_, ldb); + }); + }); +} + +#define OMATCOPY_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { \ + omatcopy(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, b, ldb); \ + } + +OMATCOPY_LAUNCHER(float, rocblas_sgeam) +OMATCOPY_LAUNCHER(double, rocblas_dgeam) +OMATCOPY_LAUNCHER(std::complex, rocblas_cgeam) +OMATCOPY_LAUNCHER(std::complex, rocblas_zgeam) + +#undef OMATCOPY_LAUNCHER void imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, sycl::buffer &ab, int64_t lda, int64_t ldb) { @@ -130,31 +146,43 @@ void imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::co throw unimplemented("blas", "imatcopy", "for column_major layout"); } -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - float alpha, sycl::buffer &a, int64_t lda, float beta, - sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} - -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - double alpha, sycl::buffer &a, int64_t lda, double beta, - sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} - -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - 
std::complex beta, sycl::buffer, 1> &b, int64_t ldb, - sycl::buffer, 1> &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} - -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - std::complex beta, sycl::buffer, 1> &b, int64_t ldb, - sycl::buffer, 1> &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} +template +void omatadd(const char *func_name, Func func, sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, T alpha, sycl::buffer &a, int64_t lda, + T beta, sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb, ldc); + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + auto c_acc = c.template get_access(cgh); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + auto c_ = sc.get_mem(c_acc); + rocblas_status err; + ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(transa), + get_rocblas_operation(transb), m, n, (rocDataType *)&alpha, + a_, lda, (rocDataType *)&beta, b_, ldb, c_, ldc); + }); + }); +} + +#define OMATADD_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + TYPE alpha, sycl::buffer &a, int64_t lda, TYPE beta, \ + sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { \ + omatadd(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, lda, \ + beta, b, ldb, c, ldc); \ + } + +OMATADD_LAUNCHER(float, rocblas_sgeam) +OMATADD_LAUNCHER(double, rocblas_dgeam) +OMATADD_LAUNCHER(std::complex, rocblas_cgeam) +OMATADD_LAUNCHER(std::complex, 
rocblas_zgeam) + +#undef OMATADD_LAUNCHER // USM APIs @@ -220,31 +248,45 @@ sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transp throw unimplemented("blas", "gemmt", "for column_major layout"); } -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - const float *a, int64_t lda, float *b, int64_t ldb, +template +sycl::event omatcopy(const char *func_name, Func func, sycl::queue &queue, transpose trans, + int64_t m, int64_t n, T alpha, const T *a, int64_t lda, T *b, int64_t ldb, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} - -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - const double *a, int64_t lda, double *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} - -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} - -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb); + const T beta = 0; + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + const int64_t logical_m = (trans == oneapi::mkl::transpose::nontrans ? m : n); + const int64_t logical_n = (trans == oneapi::mkl::transpose::nontrans ? 
n : m); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + rocblas_status err; + ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(trans), + get_rocblas_operation(trans), logical_m, logical_n, + (rocDataType *)&alpha, a_, lda, (rocDataType *)&beta, nullptr, + lda, b_, ldb); + }); + }); + return done; +} + +#define OMATCOPY_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + const TYPE *a, int64_t lda, TYPE *b, int64_t ldb, \ + const std::vector &dependencies) { \ + return omatcopy(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, b, \ + ldb, dependencies); \ + } + +OMATCOPY_LAUNCHER_USM(float, rocblas_sgeam) +OMATCOPY_LAUNCHER_USM(double, rocblas_dgeam) +OMATCOPY_LAUNCHER_USM(std::complex, rocblas_cgeam) +OMATCOPY_LAUNCHER_USM(std::complex, rocblas_zgeam) + +#undef OMATCOPY_LAUNCHER_USM sycl::event imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, float *ab, int64_t lda, int64_t ldb, @@ -270,35 +312,44 @@ sycl::event imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, throw unimplemented("blas", "imatcopy", "for column_major layout"); } -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - float alpha, const float *a, int64_t lda, float beta, const float *b, - int64_t ldb, float *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} - -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - double alpha, const double *a, int64_t lda, double beta, const double *b, - int64_t ldb, double *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for column_major 
layout"); -} - -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex beta, const std::complex *b, int64_t ldb, - std::complex *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} - -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex beta, const std::complex *b, int64_t ldb, - std::complex *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} +template +inline sycl::event omatadd(const char *func_name, Func func, sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, T alpha, const T *a, int64_t lda, + T beta, const T *b, int64_t ldb, T *c, int64_t ldc, + const std::vector &dependencies) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb, ldc); + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + auto c_ = reinterpret_cast(c); + rocblas_status err; + ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(transa), + get_rocblas_operation(transb), m, n, (rocDataType *)&alpha, + a_, lda, (rocDataType *)&beta, b_, ldb, c_, ldc); + }); + }); + return done; +} + +#define OMATADD_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, TYPE alpha, const TYPE *a, int64_t lda, TYPE beta, \ + const TYPE *b, int64_t ldb, TYPE *c, int64_t ldc, \ + const std::vector &dependencies) { \ + return omatadd(#ROCBLAS_ROUTINE, 
ROCBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, \ + lda, beta, b, ldb, c, ldc, dependencies); \ + } + +OMATADD_LAUNCHER_USM(float, rocblas_sgeam) +OMATADD_LAUNCHER_USM(double, rocblas_dgeam) +OMATADD_LAUNCHER_USM(std::complex, rocblas_cgeam) +OMATADD_LAUNCHER_USM(std::complex, rocblas_zgeam) + +#undef OMATADD_LAUNCHER_USM } // namespace column_major namespace row_major { @@ -361,27 +412,43 @@ void gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose tra throw unimplemented("blas", "gemmt", "for row_major layout"); } -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); -} - -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); -} - -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - sycl::buffer, 1> &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); -} - -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - sycl::buffer, 1> &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); -} +template +void omatcopy(const char *func_name, Func func, sycl::queue &queue, transpose trans, int64_t m, + int64_t n, T alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, + int64_t ldb) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb); + const T beta = 0; + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + const int64_t logical_m = (trans == oneapi::mkl::transpose::nontrans ? 
n : m); + const int64_t logical_n = (trans == oneapi::mkl::transpose::nontrans ? m : n); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + rocblas_status err; + ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(trans), + get_rocblas_operation(trans), logical_m, logical_n, + (rocDataType *)&alpha, a_, lda, (rocDataType *)&beta, nullptr, + lda, b_, ldb); + }); + }); +} + +#define OMATCOPY_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { \ + omatcopy(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, b, ldb); \ + } + +OMATCOPY_LAUNCHER(float, rocblas_sgeam) +OMATCOPY_LAUNCHER(double, rocblas_dgeam) +OMATCOPY_LAUNCHER(std::complex, rocblas_cgeam) +OMATCOPY_LAUNCHER(std::complex, rocblas_zgeam) + +#undef OMATCOPY_LAUNCHER void imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, sycl::buffer &ab, int64_t lda, int64_t ldb) { @@ -403,31 +470,43 @@ void imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::co throw unimplemented("blas", "imatcopy", "for row_major layout"); } -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - float alpha, sycl::buffer &a, int64_t lda, float beta, - sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} - -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - double alpha, sycl::buffer &a, int64_t lda, double beta, - sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} - -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, 
int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - std::complex beta, sycl::buffer, 1> &b, int64_t ldb, - sycl::buffer, 1> &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} - -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - std::complex beta, sycl::buffer, 1> &b, int64_t ldb, - sycl::buffer, 1> &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} +template +void omatadd(const char *func_name, Func func, sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, T alpha, sycl::buffer &a, int64_t lda, + T beta, sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb, ldc); + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + auto c_acc = c.template get_access(cgh); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + auto c_ = sc.get_mem(c_acc); + rocblas_status err; + ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(transa), + get_rocblas_operation(transb), n, m, (rocDataType *)&alpha, + a_, lda, (rocDataType *)&beta, b_, ldb, c_, ldc); + }); + }); +} + +#define OMATADD_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + TYPE alpha, sycl::buffer &a, int64_t lda, TYPE beta, \ + sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { \ + omatadd(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, lda, \ + beta, b, ldb, c, ldc); \ + } + +OMATADD_LAUNCHER(float, rocblas_sgeam) +OMATADD_LAUNCHER(double, rocblas_dgeam) 
+OMATADD_LAUNCHER(std::complex, rocblas_cgeam) +OMATADD_LAUNCHER(std::complex, rocblas_zgeam) + +#undef OMATADD_LAUNCHER // USM APIs @@ -493,31 +572,45 @@ sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transp throw unimplemented("blas", "gemmt", "for row_major layout"); } -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - const float *a, int64_t lda, float *b, int64_t ldb, +template +sycl::event omatcopy(const char *func_name, Func func, sycl::queue &queue, transpose trans, + int64_t m, int64_t n, T alpha, const T *a, int64_t lda, T *b, int64_t ldb, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); -} - -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - const double *a, int64_t lda, double *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); -} - -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); -} - -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); -} + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb); + const T beta = 0; + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + const int64_t logical_m = (trans == oneapi::mkl::transpose::nontrans ? n : m); + const int64_t logical_n = (trans == oneapi::mkl::transpose::nontrans ? 
m : n); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + rocblas_status err; + ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(trans), + get_rocblas_operation(trans), logical_m, logical_n, + (rocDataType *)&alpha, a_, lda, (rocDataType *)&beta, nullptr, + lda /* B is unused (beta == 0): pass A's leading dimension, as the other omatcopy variants do */, b_, ldb); + }); + }); + return done; +} + +#define OMATCOPY_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + const TYPE *a, int64_t lda, TYPE *b, int64_t ldb, \ + const std::vector &dependencies) { \ + return omatcopy(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, b, \ + ldb, dependencies); \ + } + +OMATCOPY_LAUNCHER_USM(float, rocblas_sgeam) +OMATCOPY_LAUNCHER_USM(double, rocblas_dgeam) +OMATCOPY_LAUNCHER_USM(std::complex, rocblas_cgeam) +OMATCOPY_LAUNCHER_USM(std::complex, rocblas_zgeam) + +#undef OMATCOPY_LAUNCHER_USM sycl::event imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, float *ab, int64_t lda, int64_t ldb, @@ -543,35 +636,44 @@ sycl::event imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, throw unimplemented("blas", "imatcopy", "for row_major layout"); } -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - float alpha, const float *a, int64_t lda, float beta, const float *b, - int64_t ldb, float *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} - -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - double alpha, const double *a, int64_t lda, double beta, const double *b, - int64_t ldb, double *c, int64_t ldc, - const std::vector 
&dependencies) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -}
- -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex beta, const std::complex *b, int64_t ldb, - std::complex *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} - -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex beta, const std::complex *b, int64_t ldb, - std::complex *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} +template +inline sycl::event omatadd(const char *func_name, Func func, sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, T alpha, const T *a, int64_t lda, + T beta, const T *b, int64_t ldb, T *c, int64_t ldc, + const std::vector &dependencies) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb, ldc); + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + auto c_ = reinterpret_cast(c); + rocblas_status err; + ROCBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(transa), + get_rocblas_operation(transb), n, m, (rocDataType *)&alpha, + a_, lda, (rocDataType *)&beta, b_, ldb, c_, ldc); + }); + }); + return done; +} + +#define OMATADD_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, TYPE alpha, const TYPE *a, int64_t lda, TYPE beta, \ + const TYPE *b, int64_t ldb, TYPE *c, int64_t ldc, \ + const std::vector &dependencies) { \ + return omatadd(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, 
queue, transa, transb, m, n, alpha, a, \ + lda, beta, b, ldb, c, ldc, dependencies); \ + } + +OMATADD_LAUNCHER_USM(float, rocblas_sgeam) +OMATADD_LAUNCHER_USM(double, rocblas_dgeam) +OMATADD_LAUNCHER_USM(std::complex, rocblas_cgeam) +OMATADD_LAUNCHER_USM(std::complex, rocblas_zgeam) + +#undef OMATADD_LAUNCHER_USM } // namespace row_major } // namespace rocblas diff --git a/src/blas/backends/rocblas/rocblas_helper.hpp b/src/blas/backends/rocblas/rocblas_helper.hpp index 75490e333..601a02a14 100644 --- a/src/blas/backends/rocblas/rocblas_helper.hpp +++ b/src/blas/backends/rocblas/rocblas_helper.hpp @@ -172,6 +172,16 @@ class hip_error : virtual public std::runtime_error { hipError_t hip_err; \ HIP_ERROR_FUNC(hipStreamSynchronize, hip_err, currentStreamId); +#define ROCBLAS_ERROR_FUNC_T_SYNC(name, func, err, handle, ...) \ + err = func(handle, __VA_ARGS__); \ + if (err != rocblas_status_success) { \ + throw rocblas_error(std::string(name) + std::string(" : "), err); \ + } \ + hipStream_t currentStreamId; \ + ROCBLAS_ERROR_FUNC(rocblas_get_stream, err, handle, &currentStreamId); \ + hipError_t hip_err; \ + HIP_ERROR_FUNC(hipStreamSynchronize, hip_err, currentStreamId); + inline rocblas_operation get_rocblas_operation(oneapi::mkl::transpose trn) { switch (trn) { case oneapi::mkl::transpose::nontrans: return rocblas_operation_none;