diff --git a/ggml/src/ggml-amx.cpp b/ggml/src/ggml-amx.cpp
index 25f46e77c74bf1..7494b0219e3676 100644
--- a/ggml/src/ggml-amx.cpp
+++ b/ggml/src/ggml-amx.cpp
@@ -188,36 +188,45 @@ static enum ggml_status ggml_backend_amx_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
 }
 
 static bool ggml_backend_amx_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    if (op->op != GGML_OP_MUL_MAT) {
-        return false;
-    }
+    // handle only 2d gemm for now
+    auto is_contiguous_2d = [](const struct ggml_tensor * t) {
+        return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
+    };
 
-    const struct ggml_tensor * src0 = op->src[0];
-    const struct ggml_tensor * src1 = op->src[1];
+    switch (op->op) {
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            return true;
 
-    const enum ggml_type type = src0->type;
-    const int64_t ne0 = op->ne[0];
+        case GGML_OP_MUL_MAT: {
+            const struct ggml_tensor * src0 = op->src[0];
+            const struct ggml_tensor * src1 = op->src[1];
 
-    bool is_training = src0->grad || src1->grad;
+            const enum ggml_type type = src0->type;
+            const int64_t ne0 = op->ne[0];
 
-    // amx kernels enables for Q4_0, Q4_1, Q8_0, F16
-    // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
-    bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
+            bool is_training = src0->grad || src1->grad;
 
-    // handle only 2d gemm for now
-    auto is_contiguous_2d = [](const struct ggml_tensor * t) {
-        return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
-    };
+            // amx kernels enables for Q4_0, Q4_1, Q8_0, F16
+            // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
+            bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
 
-    bool can_use_amx =
-        is_contiguous_2d(src0) &&              // src0 must be contiguous
-        is_contiguous_2d(src1) &&              // src1 must be contiguous
-        !is_training &&                        // inference only
-        src1->type == GGML_TYPE_F32 &&         // src1 must be float32
-        has_amx_kernels &&                     // with amx kernel impls
-        ne0 % (TILE_N * 2) == 0;               // out_features is 32x
+            bool can_use_amx =
+                is_contiguous_2d(src0) &&              // src0 must be contiguous
+                is_contiguous_2d(src1) &&              // src1 must be contiguous
+                !is_training &&                        // inference only
+                src1->type == GGML_TYPE_F32 &&         // src1 must be float32
+                has_amx_kernels &&                     // with amx kernel impls
+                ne0 % (TILE_N * 2) == 0;               // out_features is 32x
 
-    return can_use_amx;
+            return can_use_amx;
+        }
+        default:
+            return false;
+    }
 
     GGML_UNUSED(backend);
 }
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index 0fa5004e443fbc..15d650150a5f34 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -1496,12 +1496,10 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * tensor) {
         }
     }
 
-#ifndef GGML_USE_AMX
     if (tensor->buffer || (tensor->view_src && tensor->view_src->buffer)) {
         // since the tensor is pre-allocated, it cannot be moved to another backend
         GGML_ABORT("pre-allocated tensor in a backend that cannot run the operation");
     }
-#endif
 
     // graph input
     if (tensor->flags & GGML_TENSOR_FLAG_INPUT) {