From f09b7cb609d80b8031803f89255991dc8b35db69 Mon Sep 17 00:00:00 2001 From: Neo Zhang Jianyu Date: Fri, 5 Jul 2024 10:32:29 +0800 Subject: [PATCH] rm get_work_group_size() by local cache for performance (#8286) Co-authored-by: arthw <14088817+arthw@users.noreply.github.com> --- ggml/src/ggml-sycl.cpp | 10 ++++++---- ggml/src/ggml-sycl/common.hpp | 15 ++------------- ggml/src/ggml-sycl/norm.cpp | 18 +++++++++--------- 3 files changed, 17 insertions(+), 26 deletions(-) diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp index 76bad57e2320b..dde55335bb6da 100644 --- a/ggml/src/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl.cpp @@ -49,7 +49,7 @@ bool ggml_backend_is_sycl(ggml_backend_t backend); int ggml_backend_sycl_get_device(ggml_backend_t backend); static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer); static inline int get_sycl_env(const char *env_name, int default_val); -static inline int get_work_group_size(const sycl::device& device); + void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst, const void *ptr_src, size_t size) { @@ -1912,9 +1912,9 @@ static void soft_max_f32_submitter(const float * x, const float * mask, float * static void soft_max_f32_sycl(const float * x, const float * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, - queue_ptr stream) { + queue_ptr stream, int device) { int nth = WARP_SIZE; - int max_block_size = get_work_group_size(stream->get_device()); + int max_block_size = ggml_sycl_info().max_work_group_sizes[device]; while (nth < ncols_x && nth < max_block_size) nth *= 2; if (nth>max_block_size) nth = max_block_size; @@ -2156,6 +2156,8 @@ static ggml_sycl_device_info ggml_sycl_init() { info.devices[i].cc = 100 * prop.get_major_version() + 10 * prop.get_minor_version(); + + info.max_work_group_sizes[i] = prop.get_max_work_group_size(); } for (int id = 0; id < info.device_count; ++id) { @@ -3031,7 +3033,7 @@ inline void ggml_sycl_op_soft_max(ggml_backend_sycl_context & ctx, const ggml_te memcpy(&max_bias, dst->op_params + 1, sizeof(float)); soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, - nrows_x, nrows_y, scale, max_bias, main_stream); + nrows_x, nrows_y, scale, max_bias, main_stream, ctx.device); } inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 476d847ca575e..9a1c161b69db5 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -47,10 +47,6 @@ static int g_ggml_sycl_debug = 0; } \ }() -// #define DEBUG_SYCL_MALLOC - -static int g_work_group_size = 0; -// typedef sycl::half ggml_fp16_t; #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP #define VER_4VEC 610 // todo for hardward optimize. @@ -193,6 +189,8 @@ struct ggml_sycl_device_info { sycl_device_info devices[GGML_SYCL_MAX_DEVICES] = {}; std::array default_tensor_split = {}; + + int max_work_group_sizes[GGML_SYCL_MAX_DEVICES] = {0}; }; const ggml_sycl_device_info & ggml_sycl_info(); @@ -295,15 +293,6 @@ struct ggml_backend_sycl_context { } }; -// common host functions - -static inline int get_work_group_size(const sycl::device& device) { - dpct::device_info prop; - dpct::get_device_info(prop, device); - return prop.get_max_work_group_size(); -} - - // common device functions static __dpct_inline__ float warp_reduce_sum(float x, diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp index a77f7852ccecd..ed0fa7e31762b 100644 --- a/ggml/src/ggml-sycl/norm.cpp +++ b/ggml/src/ggml-sycl/norm.cpp @@ -181,7 +181,7 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa static void norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const float eps, - queue_ptr stream) { + queue_ptr stream, int device) { GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); @@ -197,7 +197,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols, }); } else { - const int work_group_size = get_work_group_size(stream->get_device()); + const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; const sycl::range<3> block_dims(1, 1, work_group_size); /* DPCT1049:17: The work-group size passed to the SYCL kernel may exceed @@ -222,7 +222,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols, static void group_norm_f32_sycl(const float* x, float* dst, const int num_groups, const int group_size, - const int ne_elements, queue_ptr stream) { + const int ne_elements, queue_ptr stream, int device) { static const float eps = 1e-6f; if (group_size < 1024) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); @@ -240,7 +240,7 @@ static void group_norm_f32_sycl(const float* x, float* dst, }); } else { - const int work_group_size = get_work_group_size(stream->get_device()); + const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; const sycl::range<3> block_dims(1, 1, work_group_size); /* DPCT1049:18: The work-group size passed to the SYCL kernel may exceed @@ -269,7 +269,7 @@ static void group_norm_f32_sycl(const float* x, float* dst, static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const float eps, - queue_ptr stream) { + queue_ptr stream, int device) { GGML_ASSERT(ncols % WARP_SIZE == 0); // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE); if (ncols < 1024) { @@ -286,7 +286,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, }); } else { - const int work_group_size = get_work_group_size(stream->get_device()); + const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; const sycl::range<3> block_dims(1, 1, work_group_size); /* DPCT1049:19: The work-group size passed to the SYCL kernel may exceed @@ -322,7 +322,7 @@ void ggml_sycl_op_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0, float eps; memcpy(&eps, dst->op_params, sizeof(float)); - norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream); + norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device); (void)src1; (void)dst; @@ -340,7 +340,7 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* int num_groups = dst->op_params[0]; int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); - group_norm_f32_sycl(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream); + group_norm_f32_sycl(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream, ctx.device); (void)src1; (void)dst; @@ -362,7 +362,7 @@ void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* sr float eps; memcpy(&eps, dst->op_params, sizeof(float)); - rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream); + rms_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device); (void)src1; (void)dst;