From 0dcacb17bc179875ad9e5867b68d123c622a2655 Mon Sep 17 00:00:00 2001 From: ThanatosShinji Date: Sun, 2 Jun 2024 20:34:21 +0800 Subject: [PATCH] support 2 layers of sycl --- CMakePresets.json | 2 +- neural_speed/core/layers/ne_bestla.cpp | 15 + neural_speed/core/layers/ne_bestla_sycl.cpp | 1 - neural_speed/core/ne_bestla.h | 1 + neural_speed/core/ne_layers.c | 306 +++++--------------- neural_speed/models/llama/llama_utils.cpp | 2 +- 6 files changed, 91 insertions(+), 236 deletions(-) diff --git a/CMakePresets.json b/CMakePresets.json index d8b7ea609..c4c4a70e2 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -152,7 +152,7 @@ "description": "x64 SYCL", "inherits": "x64-debug-sycl", "cacheVariables": { - "CMAKE_BUILD_TYPE": "Release" + "CMAKE_BUILD_TYPE": "RelWithDebInfo" } } ] diff --git a/neural_speed/core/layers/ne_bestla.cpp b/neural_speed/core/layers/ne_bestla.cpp index d8c217af0..7c3196e62 100644 --- a/neural_speed/core/layers/ne_bestla.cpp +++ b/neural_speed/core/layers/ne_bestla.cpp @@ -173,6 +173,21 @@ static inline int ne_nrows(const struct ne_tensor* tensor) { return tensor->ne[1] * tensor->ne[2] * tensor->ne[3]; } +ne_backend bestla_backend_support(struct ne_tensor* src0, struct ne_tensor* src1, enum ne_op op) { + ne_backend bk = NE_BACKEND_CPU; + switch (op) { + case NE_OP_MUL_MAT: { + struct ne_tensor* wei = src0; + if (src0->type == NE_TYPE_BTLA) { + bk = NE_BACKEND_SYCL; + } + } break; + default: + break; + } + return bk; +} + bool bestla_sycl_support(struct ne_tensor* node) { bool support = false; switch (node->op) { diff --git a/neural_speed/core/layers/ne_bestla_sycl.cpp b/neural_speed/core/layers/ne_bestla_sycl.cpp index 7d022d720..422f90b3e 100644 --- a/neural_speed/core/layers/ne_bestla_sycl.cpp +++ b/neural_speed/core/layers/ne_bestla_sycl.cpp @@ -79,7 +79,6 @@ void bestla_device_memcpy(void* dstptr, const void* srcptr, size_t size, void* q if (queue && srcptr && dstptr) { auto ptr = (sycl::queue*)queue; ptr->memcpy(dstptr, srcptr, size); - ptr->wait(); } } diff --git a/neural_speed/core/ne_bestla.h b/neural_speed/core/ne_bestla.h index df833e152..afdace37e 100644 --- a/neural_speed/core/ne_bestla.h +++ b/neural_speed/core/ne_bestla.h @@ -80,6 +80,7 @@ void bestla_mul(int batch, int vsize, const float* tensor, const float* vector, void bestla_add(int batch, int vsize, const float* tensor, const float* vector, int vstep, float* out); bool bestla_sycl_support(struct ne_tensor* node); +enum ne_backend bestla_backend_support(struct ne_tensor* src0, struct ne_tensor* src1, enum ne_op op); bool bestla_support(struct ne_tensor* node, int n_threads, size_t* workspace, size_t* dev_workspace); #ifdef NS_SYCL diff --git a/neural_speed/core/ne_layers.c b/neural_speed/core/ne_layers.c index e9913bff4..a6649183c 100644 --- a/neural_speed/core/ne_layers.c +++ b/neural_speed/core/ne_layers.c @@ -1204,6 +1204,10 @@ struct ne_tensor* ne_dup_tensor(struct ne_context* ctx, const struct ne_tensor* return ne_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, NULL, src->size, src->backend); } +struct ne_tensor* ne_dup_tensor(struct ne_context* ctx, const struct ne_tensor* src, enum ne_backend bk) { + return ne_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, NULL, src->size, bk); +} + struct ne_tensor* ne_set_zero(struct ne_tensor* tensor) { memset(tensor->data, 0, ne_nbytes(tensor)); return tensor; @@ -1447,6 +1451,16 @@ struct ne_tensor* ne_view_tensor(struct ne_context* ctx, const struct ne_tensor* return result; } +struct ne_tensor* ne_view_tensor(struct ne_context* 
ctx, const struct ne_tensor* src, enum ne_backend bk) { + struct ne_tensor* result = ne_new_tensor_impl(ctx, src->type, src->n_dims, src->ne, src->data, src->size, bk); + + result->nb[0] = src->nb[0]; + result->nb[1] = src->nb[1]; + result->nb[2] = src->nb[2]; + result->nb[3] = src->nb[3]; + + return result; +} //////////////////////////////////////////////////////////////////////////////// #ifdef NS_TP_MODEL // ne_dump_tensor @@ -1480,9 +1494,6 @@ struct ne_tensor* ne_dup_impl(struct ne_context* ctx, struct ne_tensor* a, bool result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -1509,16 +1520,14 @@ struct ne_tensor* ne_add_impl(struct ne_context* ctx, struct ne_tensor* a, struc if (!inplace && (a->grad || b->grad)) { is_node = true; } - - struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a) : ne_dup_tensor(ctx, a); + enum ne_op op = NE_OP_ADD; + enum ne_backend bk = bestla_backend_support(a, b, op); + struct ne_tensor* result = inplace ? ne_view_tensor(ctx, a, bk) : ne_dup_tensor(ctx, a, bk); result->op = NE_OP_ADD; - result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; + result->grad = NULL; result->src0 = a; result->src1 = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -1548,9 +1557,6 @@ struct ne_tensor* ne_add1_impl(struct ne_context* ctx, struct ne_tensor* a, stru result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -1596,9 +1602,6 @@ struct ne_tensor* ne_acc_impl(struct ne_context* ctx, struct ne_tensor* a, struc result->src0 = a; result->src1 = b; result->opt[0] = c; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -1755,9 +1758,6 @@ struct ne_tensor* ne_sub_impl(struct ne_context* ctx, struct ne_tensor* a, struc result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -1794,9 +1794,6 @@ struct ne_tensor* ne_mul_impl(struct ne_context* ctx, struct ne_tensor* a, struc result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -1808,9 +1805,6 @@ struct ne_tensor* ne_tanh(struct ne_context* ctx, struct ne_tensor* a) { result->op = NE_OP_TANH; result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -1843,9 +1837,6 @@ struct ne_tensor* ne_div_impl(struct ne_context* ctx, struct ne_tensor* a, struc result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -1872,9 +1863,6 @@ struct ne_tensor* ne_sqr_impl(struct ne_context* ctx, struct ne_tensor* a, bool result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -1897,9 +1885,6 @@ struct ne_tensor* ne_sqrt_impl(struct ne_context* ctx, struct ne_tensor* a, bool result->grad = is_node ? 
ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -1922,9 +1907,6 @@ struct ne_tensor* ne_log_impl(struct ne_context* ctx, struct ne_tensor* a, bool result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -1947,9 +1929,6 @@ struct ne_tensor* ne_sum(struct ne_context* ctx, struct ne_tensor* a) { result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -1973,9 +1952,6 @@ struct ne_tensor* ne_sum_rows(struct ne_context* ctx, struct ne_tensor* a) { result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -1996,9 +1972,6 @@ struct ne_tensor* ne_mean(struct ne_context* ctx, struct ne_tensor* a) { result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2023,9 +1996,6 @@ struct ne_tensor* ne_repeat(struct ne_context* ctx, struct ne_tensor* a, struct result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2044,9 +2014,6 @@ struct ne_tensor* ne_abs_impl(struct ne_context* ctx, struct ne_tensor* a, bool result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2069,9 +2036,6 @@ struct ne_tensor* ne_sgn_impl(struct ne_context* ctx, struct ne_tensor* a, bool result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2094,9 +2058,6 @@ struct ne_tensor* ne_neg_impl(struct ne_context* ctx, struct ne_tensor* a, bool result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2119,9 +2080,6 @@ struct ne_tensor* ne_step_impl(struct ne_context* ctx, struct ne_tensor* a, bool result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2144,9 +2102,6 @@ struct ne_tensor* ne_relu_impl(struct ne_context* ctx, struct ne_tensor* a, bool result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2169,9 +2124,6 @@ struct ne_tensor* ne_gelu_impl(struct ne_context* ctx, struct ne_tensor* a, bool result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2194,9 +2146,6 @@ struct ne_tensor* ne_silu_impl(struct ne_context* ctx, struct ne_tensor* a, bool result->grad = is_node ? 
ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2220,9 +2169,6 @@ struct ne_tensor* ne_silu_back(struct ne_context* ctx, struct ne_tensor* a, stru result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2243,9 +2189,6 @@ struct ne_tensor* ne_norm_impl(struct ne_context* ctx, struct ne_tensor* a, bool result->op = NE_OP_NORM; result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2271,9 +2214,6 @@ struct ne_tensor* ne_rms_norm_impl(struct ne_context* ctx, struct ne_tensor* a, result->op = NE_OP_RMS_NORM; result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2299,9 +2239,6 @@ struct ne_tensor* ne_rms_norm_back(struct ne_context* ctx, struct ne_tensor* a, result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2323,9 +2260,6 @@ struct ne_tensor* ne_mul_mat(struct ne_context* ctx, struct ne_tensor* a, struct result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2350,9 +2284,6 @@ struct ne_tensor* ne_mul_mat_with_bias(struct ne_context* ctx, struct ne_tensor* result->src0 = w; result->src1 = a; result->opt[0] = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2388,9 +2319,6 @@ struct ne_tensor* ne_mul_mat_id(struct ne_context* ctx, struct ne_tensor* const NE_ASSERT(!ne_is_transposed(a)); result->opt[i] = a; } - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2435,9 +2363,6 @@ struct ne_tensor* ne_mul_id_ffn_silu(struct ne_context* ctx, struct ne_tensor* c } result->opt[24] = tmp; result->opt[25] = tmp1; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } // struct ne_tensor *result = ne_ffn_silu(ctx,gate[row_id], down[row_id],up[row_id], b); return result; } @@ -2494,9 +2419,6 @@ struct ne_tensor* ne_argsort(struct ne_context* ctx, struct ne_tensor* a) { result->op = NE_OP_ARGSORT; result->grad = is_node ? 
ne_dup_tensor(ctx, result) : NULL; result->src0 = a; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2509,9 +2431,6 @@ struct ne_tensor* ne_top_k(struct ne_context* ctx, struct ne_tensor* a, int k) { result = ne_view_4d(ctx, result, k, result->ne[1], result->ne[2], result->ne[3], result->nb[1], result->nb[2], result->nb[3], 0); - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } // ne_mul_qkv @@ -2541,9 +2460,6 @@ struct ne_tensor* ne_mul_qkv(struct ne_context* ctx, struct ne_tensor* qw, struc result->src1 = qw; result->opt[0] = kw; result->opt[1] = vw; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2574,9 +2490,6 @@ struct ne_tensor* ne_ffn_silu(struct ne_context* ctx, struct ne_tensor* w1, stru result->opt[1] = w3; result->opt[2] = tmp; result->opt[3] = tmp1; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2603,9 +2516,6 @@ struct ne_tensor* ne_ffn_add_gelu(struct ne_context* ctx, struct ne_tensor* w1, result->opt[1] = b1; result->opt[2] = b2; result->opt[3] = tmp; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2630,9 +2540,6 @@ struct ne_tensor* ne_ffn_gelu(struct ne_context* ctx, struct ne_tensor* w1, stru result->src1 = w1; result->opt[0] = w2; result->opt[1] = tmp; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2661,9 +2568,6 @@ struct ne_tensor* ne_ffn_gelu_mul(struct ne_context* ctx, struct ne_tensor* w1, result->opt[1] = w3; result->opt[2] = tmp; result->opt[3] = tmp1; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } // ne_scale @@ -2684,9 +2588,6 @@ struct ne_tensor* ne_scale_impl(struct ne_context* ctx, struct ne_tensor* a, str result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2730,9 +2631,6 @@ struct ne_tensor* ne_set_impl(struct ne_context* ctx, struct ne_tensor* a, struc result->src0 = a; result->src1 = b; result->opt[0] = c; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2782,9 +2680,6 @@ struct ne_tensor* ne_cpy_impl(struct ne_context* ctx, struct ne_tensor* a, struc result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2811,9 +2706,6 @@ struct ne_tensor* ne_cont_impl(struct ne_context* ctx, struct ne_tensor* a, bool result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2845,9 +2737,6 @@ struct ne_tensor* ne_reshape(struct ne_context* ctx, struct ne_tensor* a, struct result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2868,9 +2757,6 @@ struct ne_tensor* ne_reshape_1d(struct ne_context* ctx, struct ne_tensor* a, int result->grad = is_node ? 
ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2891,9 +2777,6 @@ struct ne_tensor* ne_reshape_2d(struct ne_context* ctx, struct ne_tensor* a, int result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2914,9 +2797,6 @@ struct ne_tensor* ne_reshape_3d(struct ne_context* ctx, struct ne_tensor* a, int result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2938,9 +2818,6 @@ struct ne_tensor* ne_reshape_4d(struct ne_context* ctx, struct ne_tensor* a, int result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2964,9 +2841,6 @@ struct ne_tensor* ne_view_1d(struct ne_context* ctx, struct ne_tensor* a, int64_ if (is_node) { memcpy(result->padding, &offset, sizeof(offset)); } - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -2996,9 +2870,6 @@ struct ne_tensor* ne_view_2d(struct ne_context* ctx, struct ne_tensor* a, int64_ if (is_node) { memcpy(result->padding, &offset, sizeof(offset)); } - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3028,9 +2899,6 @@ struct ne_tensor* ne_view_3d(struct ne_context* ctx, struct ne_tensor* a, int64_ if (is_node) { memcpy(result->padding, &offset, sizeof(offset)); } - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3060,9 +2928,6 @@ struct ne_tensor* ne_view_4d(struct ne_context* ctx, struct ne_tensor* a, int64_ if (is_node) { memcpy(result->padding, &offset, sizeof(offset)); } - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3123,9 +2988,6 @@ struct ne_tensor* ne_permute(struct ne_context* ctx, struct ne_tensor* a, int ax result->padding[2] = axis2; result->padding[3] = axis3; } - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3150,9 +3012,6 @@ struct ne_tensor* ne_transpose(struct ne_context* ctx, struct ne_tensor* a) { result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3181,9 +3040,6 @@ struct ne_tensor* ne_get_rows(struct ne_context* ctx, struct ne_tensor* a, struc result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3209,9 +3065,6 @@ struct ne_tensor* ne_get_rows_back(struct ne_context* ctx, struct ne_tensor* a, result->src0 = a; result->src1 = b; result->opt[0] = c; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3232,9 +3085,6 @@ struct ne_tensor* ne_diag(struct ne_context* ctx, struct ne_tensor* a) { result->grad = is_node ? 
ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3272,9 +3122,6 @@ struct ne_tensor* ne_diag_mask_inf_impl(struct ne_context* ctx, struct ne_tensor result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3321,9 +3168,6 @@ struct ne_tensor* ne_diag_mask_zero_impl(struct ne_context* ctx, struct ne_tenso result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3374,9 +3218,6 @@ struct ne_tensor* ne_padding_mask_inf_impl(struct ne_context* ctx, struct ne_ten result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3403,9 +3244,6 @@ struct ne_tensor* ne_soft_max_impl(struct ne_context* ctx, struct ne_tensor* a, result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = NULL; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3473,9 +3311,6 @@ struct ne_tensor* ne_rope_impl(struct ne_context* ctx, struct ne_tensor* a, int result->src1 = b; result->opt[0] = cossin; result->opt[1] = factor; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3541,9 +3376,6 @@ struct ne_tensor* ne_rope_back(struct ne_context* ctx, struct ne_tensor* a, int result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3596,9 +3428,6 @@ struct ne_tensor* ne_alibi(struct ne_context* ctx, struct ne_tensor* a, int n_pa result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3628,9 +3457,6 @@ struct ne_tensor* ne_clamp(struct ne_context* ctx, struct ne_tensor* a, float mi result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3659,9 +3485,6 @@ struct ne_tensor* ne_conv_1d_1s(struct ne_context* ctx, struct ne_tensor* a, str result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3690,9 +3513,6 @@ struct ne_tensor* ne_conv_1d_2s(struct ne_context* ctx, struct ne_tensor* a, str result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3727,9 +3547,6 @@ NE_API struct ne_tensor* ne_conv_1d(struct ne_context* ctx, struct ne_tensor* a, result->grad = is_node ? 
ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->src1 = b; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3772,9 +3589,6 @@ struct ne_tensor* ne_flash_attn(struct ne_context* ctx, struct ne_tensor* q, str result->opt[1] = tmp_t; *(float*)result->padding = scale; *(ne_attn_flags_t*)&result->padding[sizeof(scale)] = flags; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3800,9 +3614,6 @@ struct ne_tensor* ne_flash_attn_kv_update(struct ne_context* ctx, struct ne_tens result->src0 = cache; result->src1 = cur; result->opt[0] = params; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } struct ne_tensor* ne_flash_attn_update_k(struct ne_context* ctx, struct ne_tensor* cache, struct ne_tensor* cur, @@ -3838,9 +3649,6 @@ struct ne_tensor* ne_flash_ff(struct ne_context* ctx, struct ne_tensor* a, struc result->opt[0] = b1; result->opt[1] = c0; result->opt[2] = c1; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3863,9 +3671,6 @@ struct ne_tensor* ne_map_unary_impl_f32(struct ne_context* ctx, struct ne_tensor result->grad = is_node ? ne_dup_tensor(ctx, result) : NULL; result->src0 = a; result->opt[0] = addr_tensor; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -3899,9 +3704,6 @@ struct ne_tensor* ne_map_binary_impl_f32(struct ne_context* ctx, struct ne_tenso result->src0 = a; result->src1 = b; result->opt[0] = addr_tensor; - if (!bestla_sycl_support(result)) { - result->backend = NE_BACKEND_CPU; - } return result; } @@ -4608,8 +4410,39 @@ static void ne_compute_forward_dup(const struct ne_compute_params* params, const static void ne_compute_forward_add_f32(const struct ne_compute_params* params, const struct ne_tensor* src0, const struct ne_tensor* src1, struct ne_tensor* dst) { NE_ASSERT(ne_can_repeat_rows(src1, src0) && ne_are_same_shape(src0, dst)); + char* wsptr = (char*)params->wdata; + float* src0ptr = src0->backend == NE_BACKEND_CPU ? (float*)src0->data : (float*)wsptr; + if (src0->backend != NE_BACKEND_CPU) { + wsptr += src0->size; + } + float* src1ptr = src1->backend == NE_BACKEND_CPU ? (float*)src1->data : (float*)wsptr; + if (src1->backend != NE_BACKEND_CPU) { + wsptr += src1->size; + } + float* dstptr = dst->backend == NE_BACKEND_CPU ? (float*)dst->data : (float*)wsptr; + if (params->type == NE_TASK_INIT) { + if (params->ith == 0) { + bool sync = src1->backend != NE_BACKEND_CPU || src0->backend != NE_BACKEND_CPU; + if (sync) { + bestla_device_sync(params->dev_queue); + if (src0->backend != NE_BACKEND_CPU) { + bestla_device_memcpy(src0ptr, src0->data, src0->size, params->dev_queue); + } + if (src1->backend != NE_BACKEND_CPU) { + bestla_device_memcpy(src1ptr, src1->data, src1->size, params->dev_queue); + } + bestla_device_sync(params->dev_queue); + } + } + return; + } - if (params->type == NE_TASK_INIT || params->type == NE_TASK_FINALIZE) { + if (params->type == NE_TASK_FINALIZE) { + if (params->ith == 0) { + if (src1->backend != NE_BACKEND_CPU) { + bestla_device_memcpy_sync(dst->data, dstptr, dst->size, params->dev_queue); + } + } return; } @@ -4649,7 +4482,7 @@ static void ne_compute_forward_add_f32(const struct ne_compute_params* params, c if ((ne_nrows(src1) == 1 || ne_nrows(src1) == ne_nrows(src0)) && ne10 == ne00) { if (nb10 == sizeof(float)) { int step1 = ne11 == 1 ? 
0 : ne10;
-      bestla_add(nr, ne00, (const float*)src0->data, (const float*)src1->data, step1, (float*)dst->data);
+      bestla_add(nr, ne00, (const float*)src0ptr, (const float*)src1ptr, step1, dstptr);
       return;
     }
   }
@@ -4665,9 +4498,9 @@ static void ne_compute_forward_add_f32(const struct ne_compute_params* params, c
       const int64_t i12 = i02 % ne12;
       const int64_t i11 = i01 % ne11;

-      float* dst_ptr = (float*)((char*)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1);
-      float* src0_ptr = (float*)((char*)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01);
-      float* src1_ptr = (float*)((char*)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11);
+      float* dst_ptr = (float*)((char*)dstptr + i03 * nb3 + i02 * nb2 + i01 * nb1);
+      float* src0_ptr = (float*)((char*)src0ptr + i03 * nb03 + i02 * nb02 + i01 * nb01);
+      float* src1_ptr = (float*)((char*)src1ptr + i13 * nb13 + i12 * nb12 + i11 * nb11);

       ne_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
     }
@@ -4684,11 +4517,11 @@ static void ne_compute_forward_add_f32(const struct ne_compute_params* params, c
       const int64_t i12 = i02 % ne12;
       const int64_t i11 = i01 % ne11;

-      float* dst_ptr = (float*)((char*)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1);
-      float* src0_ptr = (float*)((char*)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01);
+      float* dst_ptr = (float*)((char*)dstptr + i03 * nb3 + i02 * nb2 + i01 * nb1);
+      float* src0_ptr = (float*)((char*)src0ptr + i03 * nb03 + i02 * nb02 + i01 * nb01);
       for (int64_t i0 = 0; i0 < ne00; i0++) {
-        float* src1_ptr = (float*)((char*)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11 + i0 * nb10);
+        float* src1_ptr = (float*)((char*)src1ptr + i13 * nb13 + i12 * nb12 + i11 * nb11 + i0 * nb10);

         dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
       }
@@ -7292,8 +7125,19 @@ static void ne_compute_forward_mul_mat_q_f32_bestla(const struct ne_compute_para
   // nb01 >= nb00 - src0 is not transposed
   // compute by src0 rows

-
+  int8_t* devwptr = (int8_t*)params->dev_wdata;
+  float* actptr = src1->backend == NE_BACKEND_CPU ? (float*)devwptr : (float*)src1->data;
+  if (src1->backend == NE_BACKEND_CPU) {
+    devwptr += src1->size;
+  }
   if (params->type == NE_TASK_INIT) {
+#ifdef NS_SYCL
+    if (params->ith == 0) {
+      if (dst->backend == NE_BACKEND_SYCL && src1->backend == NE_BACKEND_CPU) {
+        bestla_device_memcpy(actptr, src1->data, src1->size, params->dev_queue);
+      }
+    }
+#endif
     return;
   }

@@ -7302,13 +7146,6 @@ static void ne_compute_forward_mul_mat_q_f32_bestla(const struct ne_compute_para
   }
   if (dst->backend == NE_BACKEND_SYCL) {
 #ifdef NS_SYCL
-    float* actptr = (float*)src1->data;
-    int8_t* devwptr = (int8_t*)params->dev_wdata;
-    if (src1->backend == NE_BACKEND_CPU) {
-      actptr = (float*)params->dev_wdata;
-      bestla_device_memcpy(actptr, src1->data, src1->size, params->dev_queue);
-      devwptr = (int8_t*)(actptr) + src1->size;
-    }
     bestla_device_f32f32_forward(actptr, src0->data, (float*)dst->data, ne1, ne0, ne10, nb11 / ne_element_size(src1),
                                  nb1 / ne_element_size(dst), devwptr, params->dev_queue);
 #else
@@ -11876,9 +11713,12 @@ bool ne_support(struct ne_tensor* node, int n_threads, size_t* workspace, size_t
     if (node->src0->backend == NE_BACKEND_SYCL) {
       ws_h += node->src0->size;
     }
-    if (node->src1->backend == NE_BACKEND_SYCL) {
+    if (node->src1 && node->src1->backend == NE_BACKEND_SYCL) {
       ws_h += node->src1->size;
     }
+    if (node->backend == NE_BACKEND_SYCL) {
+      ws_h += node->size;
+    }
     *workspace = ws_h;
     *dev_workspace = ws_d;
     return support;
diff --git a/neural_speed/models/llama/llama_utils.cpp b/neural_speed/models/llama/llama_utils.cpp
index 3dbc53ec7..43c20abc4 100644
--- a/neural_speed/models/llama/llama_utils.cpp
+++ b/neural_speed/models/llama/llama_utils.cpp
@@ -212,7 +212,7 @@ void Llama::load(model_context* ctx, model_progress_callback progress_callback,
         layer.ffn_gate_exp[x] =
             ml->get_tensor(layers_i + ".ffn_gate." + std::to_string(x) + ".weight", {n_embd, n_ff}, backend);
         layer.ffn_down_exp[x] =
-            ml->get_tensor(layers_i + ".ffn_down." + std::to_string(x) + ".weight", {n_ff, n_embd}, backend);
+            ml->get_tensor(layers_i + ".ffn_down." + std::to_string(x) + ".weight", {n_ff, n_embd}, NE_BACKEND_SYCL);
         layer.ffn_up_exp[x] =
             ml->get_tensor(layers_i + ".ffn_up." + std::to_string(x) + ".weight", {n_embd, n_ff}, backend);
       }
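-- 
Notes on the mechanics below; every sketch is illustrative, self-contained C++, not code from the patch.

bestla_backend_support() (ne_bestla.cpp) picks a node's backend from its op and weight type at graph-build time; ne_add_impl now consults it instead of patching result->backend after the fact, which is why the per-op "if (!bestla_sycl_support(result))" blocks are removed throughout ne_layers.c. A minimal sketch of the same decision rule; the enum definitions and backend_for() are stand-ins, not the project's API:

  #include <cstdio>

  // Illustrative mirrors of the enums the patch consults; names follow the
  // patch, values are arbitrary.
  enum ne_backend { NE_BACKEND_CPU, NE_BACKEND_SYCL };
  enum ne_op { NE_OP_ADD, NE_OP_MUL_MAT };
  enum ne_type { NE_TYPE_F32, NE_TYPE_BTLA };
  struct ne_tensor { ne_type type; };

  // Same rule as bestla_backend_support(): only NE_OP_MUL_MAT whose weight
  // (src0) is BesTLA-packed runs on SYCL; every other op stays on the CPU.
  ne_backend backend_for(const ne_tensor& src0, ne_op op) {
    if (op == NE_OP_MUL_MAT && src0.type == NE_TYPE_BTLA) return NE_BACKEND_SYCL;
    return NE_BACKEND_CPU;
  }

  int main() {
    ne_tensor packed{NE_TYPE_BTLA}, plain{NE_TYPE_F32};
    std::printf("mul_mat(btla) -> %d\n", backend_for(packed, NE_OP_MUL_MAT));  // 1: SYCL
    std::printf("add(f32)      -> %d\n", backend_for(plain, NE_OP_ADD));       // 0: CPU
  }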
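Dropping ptr->wait() turns bestla_device_memcpy() (ne_bestla_sycl.cpp) into an enqueue-only copy, which is why the new add path brackets its transfers with bestla_device_sync(). The contract in standard SYCL terms, independent of BesTLA:

  #include <sycl/sycl.hpp>
  #include <vector>

  int main() {
    sycl::queue q;
    std::vector<float> host(1024, 1.0f);
    float* dev = sycl::malloc_device<float>(host.size(), q);

    // Like the patched bestla_device_memcpy(): enqueue and return at once.
    q.memcpy(dev, host.data(), host.size() * sizeof(float));
    // More copies or kernels can be queued here without stalling the host.

    // The barrier the removed ptr->wait() used to impose after every copy;
    // bestla_device_sync()/bestla_device_memcpy_sync() now play this role.
    q.wait();

    sycl::free(dev, q);
  }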
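The rewritten ne_compute_forward_add_f32 is the CPU-fallback half: NE_TASK_INIT stages device-resident operands into the node's host workspace, the unchanged CPU kernel then runs on plain host pointers, and NE_TASK_FINALIZE copies the result back out. A simplified model of that pattern, with std::memcpy standing in for the queue-based device copies and the copy-back keyed on the destination's backend:

  #include <cstddef>
  #include <cstring>

  enum backend { CPU, DEV };
  struct tensor { backend bk; float* data; std::size_t n; };  // n = element count

  // One node's INIT / COMPUTE / FINALIZE, collapsed into a single call.
  void add_with_staging(const tensor& a, const tensor& b, tensor& dst, char* ws) {
    auto stage = [&ws](const tensor& t) {  // INIT: device operand -> host ws
      if (t.bk == CPU) return t.data;
      float* h = reinterpret_cast<float*>(ws);
      ws += t.n * sizeof(float);
      std::memcpy(h, t.data, t.n * sizeof(float));  // device->host in real code
      return h;
    };
    const float* ap = stage(a);
    const float* bp = stage(b);
    float* dp = dst.bk == CPU ? dst.data : reinterpret_cast<float*>(ws);

    for (std::size_t i = 0; i < dst.n; ++i) dp[i] = ap[i] + bp[i];  // CPU kernel

    if (dst.bk != CPU)  // FINALIZE: host -> device
      std::memcpy(dst.data, dp, dst.n * sizeof(float));
  }

  int main() {
    float x[4] = {1, 2, 3, 4}, y[4] = {10, 20, 30, 40}, out[4];
    tensor a{CPU, x, 4}, b{CPU, y, 4}, d{CPU, out, 4};
    alignas(float) char ws[64];
    add_with_staging(a, b, d, ws);  // out = {11, 22, 33, 44}
  }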
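On the device path, ne_compute_forward_mul_mat_q_f32_bestla now fixes the workspace layout before the task switch: a CPU-resident activation (src1) is copied to the head of params->dev_wdata during NE_TASK_INIT, and kernel scratch begins immediately after it. The pointer arithmetic in isolation; the sizes here are made up:

  #include <cstddef>
  #include <cstdint>
  #include <cstdio>

  int main() {
    const std::size_t src1_bytes = 4096;   // hypothetical activation size
    std::uint8_t dev_wdata[8192];          // stands in for params->dev_wdata
    std::uint8_t device_src1[4096];        // stands in for device-resident src1
    const bool src1_on_cpu = true;

    // Mirrors the patch: CPU activations land at the head of the device
    // workspace; device-resident activations are used in place.
    std::uint8_t* act = src1_on_cpu ? dev_wdata : device_src1;
    std::uint8_t* scratch = src1_on_cpu ? dev_wdata + src1_bytes : dev_wdata;

    std::printf("act offset=%td, scratch offset=%td\n",
                act - dev_wdata, scratch - dev_wdata);  // 0 and 4096
  }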
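Finally, the ne_support change keeps the host-workspace accounting in step with that staging: one slot per SYCL-resident source (src1 can now be absent, hence the added null check) plus, new in this patch, one slot for a SYCL-resident destination. The same bookkeeping as a standalone function with simplified types:

  #include <cstddef>

  enum backend { CPU, SYCL_DEV };
  struct node { backend bk; std::size_t size; const node* src0; const node* src1; };

  std::size_t host_workspace_bytes(const node& n) {
    std::size_t ws = 0;
    if (n.src0 && n.src0->bk == SYCL_DEV) ws += n.src0->size;  // stage src0
    if (n.src1 && n.src1->bk == SYCL_DEV) ws += n.src1->size;  // src1 may be null
    if (n.bk == SYCL_DEV) ws += n.size;                        // stage the result
    return ws;
  }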