Skip to content

Commit

Permalink
sycl : temporary fix for performance regression
Browse files Browse the repository at this point in the history
  • Loading branch information
Alcpz committed Nov 18, 2024
1 parent 9901068 commit c3f6678
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions ggml/src/ggml-sycl/ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "ggml-sycl/backend.hpp"
#include "ggml-sycl/presets.hpp"
#include "ggml-sycl/gemm.hpp"
#include "ggml.h"

// File-local (internal linkage) flag, initially false.
// NOTE(review): presumably set once the SYCL backend has been loaded/initialized;
// the code that writes and reads it is outside this view -- confirm before relying on it.
static bool g_sycl_loaded = false;

Expand Down Expand Up @@ -3446,22 +3447,38 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;

// printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
// printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
// printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
// printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
// printf("src0 is contiguous %d, transposed %d, permuted = %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_is_permuted(src0), ggml_type_name(src0->type), src0->name);
// printf("src1 is contiguous %d, transposed %d, permuted = %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_is_permuted(src1), ggml_type_name(src1->type), src1->name);



if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
// KQ single-batch
// printf("MUL_MAT KQ single-batch\n");
ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst);
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
// KQV single-batch
// printf("MUL_MAT KQV single-batch\n");
ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
} else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
// KQ + KQV multi-batch
// printf("MUL_MAT KQ + KQV multi-batch\n");
ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
} else if (use_dequantize_mul_mat_vec) {
// printf("MUL_MAT dmmv\n");
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
} else if (use_mul_mat_vec_q) {
// printf("MUL_MAT mmvq\n");
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
} else if (use_mul_mat_q) {
// printf("MUL_MAT mmq\n");
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
} else {
// printf("MUL_MAT ELSE\n");
ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
}
}
Expand Down Expand Up @@ -4350,9 +4367,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
if (op->op == GGML_OP_MUL_MAT) {
a = op->src[0];
b = op->src[1];
if (ggml_is_permuted(a) || ggml_is_permuted(b)) {
if (ggml_is_permuted(a)) {
// TODO: fix like https://github.com/ggerganov/llama.cpp/pull/10021
return false;
if (a->nb[0] <= a->nb[1] && a->nb[3] <= a->nb[2]) return false; // 0,1,3,2 Unsupported
if (b->type != GGML_TYPE_F32) return false;
}
} else {
a = op->src[2];
Expand Down

0 comments on commit c3f6678

Please sign in to comment.