fix supports_op
mingfeima committed Oct 15, 2024
1 parent ae93769 commit 3b38fc0
Showing 2 changed files with 32 additions and 25 deletions.
55 changes: 32 additions & 23 deletions ggml/src/ggml-amx.cpp
@@ -188,36 +188,45 @@ static enum ggml_status ggml_backend_amx_graph_compute(ggml_backend_t backend, s

 static bool ggml_backend_amx_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    if (op->op != GGML_OP_MUL_MAT) {
-        return false;
-    }
+    // handle only 2d gemm for now
+    auto is_contiguous_2d = [](const struct ggml_tensor * t) {
+        return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
+    };
 
-    const struct ggml_tensor * src0 = op->src[0];
-    const struct ggml_tensor * src1 = op->src[1];
+    switch (op->op) {
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            return true;
 
-    const enum ggml_type type = src0->type;
-    const int64_t ne0 = op->ne[0];
+        case GGML_OP_MUL_MAT: {
+            const struct ggml_tensor * src0 = op->src[0];
+            const struct ggml_tensor * src1 = op->src[1];
 
-    bool is_training = src0->grad || src1->grad;
+            const enum ggml_type type = src0->type;
+            const int64_t ne0 = op->ne[0];
 
-    // amx kernels enabled for Q4_0, Q4_1, Q8_0, F16;
-    // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
-    bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
+            bool is_training = src0->grad || src1->grad;
 
-    // handle only 2d gemm for now
-    auto is_contiguous_2d = [](const struct ggml_tensor * t) {
-        return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
-    };
+            // amx kernels enabled for Q4_0, Q4_1, Q8_0, F16;
+            // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
+            bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
 
-    bool can_use_amx =
-        is_contiguous_2d(src0) &&       // src0 must be contiguous
-        is_contiguous_2d(src1) &&       // src1 must be contiguous
-        !is_training &&                 // inference only
-        src1->type == GGML_TYPE_F32 &&  // src1 must be float32
-        has_amx_kernels &&              // with amx kernel impls
-        ne0 % (TILE_N * 2) == 0;        // out_features is a multiple of 32
+            bool can_use_amx =
+                is_contiguous_2d(src0) &&       // src0 must be contiguous
+                is_contiguous_2d(src1) &&       // src1 must be contiguous
+                !is_training &&                 // inference only
+                src1->type == GGML_TYPE_F32 &&  // src1 must be float32
+                has_amx_kernels &&              // with amx kernel impls
+                ne0 % (TILE_N * 2) == 0;        // out_features is a multiple of 32
 
-    return can_use_amx;
+            return can_use_amx;
+        }
+        default:
+            return false;
+    }
 
     GGML_UNUSED(backend);
 }
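The change above widens supports_op: previously it rejected every op except GGML_OP_MUL_MAT; now GGML_OP_NONE, GGML_OP_RESHAPE, GGML_OP_VIEW, GGML_OP_PERMUTE, and GGML_OP_TRANSPOSE are accepted outright, while GGML_OP_MUL_MAT keeps its original checks. A minimal standalone sketch of the ne0 % (TILE_N * 2) == 0 guard follows; it assumes TILE_N = 16 (the per-tile output width used by ggml's AMX kernels), a value not shown in this diff:

#include <cstdint>
#include <cstdio>

// Sketch of the out_features check in the MUL_MAT case above.
// TILE_N = 16 is an assumption; the real constant lives in ggml's
// AMX sources, not in this diff.
constexpr int64_t TILE_N = 16;

static bool out_features_ok(int64_t ne0) {
    // each kernel step writes two tiles of output columns side by side,
    // so ne0 (out_features) must be a multiple of TILE_N * 2 = 32
    return ne0 % (TILE_N * 2) == 0;
}

int main() {
    std::printf("4096 -> %d\n", out_features_ok(4096)); // 1: 4096 = 32 * 128
    std::printf("4100 -> %d\n", out_features_ok(4100)); // 0: not a multiple of 32
    return 0;
}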
2 changes: 0 additions & 2 deletions ggml/src/ggml-backend.cpp
@@ -1496,12 +1496,10 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
         }
     }
 
-#ifndef GGML_USE_AMX
     if (tensor->buffer || (tensor->view_src && tensor->view_src->buffer)) {
         // since the tensor is pre-allocated, it cannot be moved to another backend
         GGML_ABORT("pre-allocated tensor in a backend that cannot run the operation");
     }
-#endif
 
     // graph input
     if (tensor->flags & GGML_TENSOR_FLAG_INPUT) {
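This hunk removes the AMX-specific exemption from the pre-allocated-tensor abort: now that the AMX backend reports view-style ops as supported, a tensor pre-allocated in an AMX buffer can be assigned a backend that runs its ops, so the abort is safe to apply unconditionally. A toy sketch of the first-fit backend selection this relies on follows; the toy_* types are invented for illustration and are not ggml's actual scheduler code:

#include <cstdio>
#include <functional>
#include <string>
#include <vector>

// Invented stand-ins for illustration only; not ggml types.
enum toy_op { OP_NONE, OP_VIEW, OP_MUL_MAT, OP_SOFT_MAX };

struct toy_tensor  { toy_op op; };
struct toy_backend {
    std::string name;
    std::function<bool(const toy_tensor &)> supports_op;
};

// First backend whose supports_op() accepts the op wins. Returning -1 for a
// pre-allocated tensor is the situation the (now unconditional) abort catches.
static int pick_backend(const std::vector<toy_backend> & backends, const toy_tensor & t) {
    for (size_t i = 0; i < backends.size(); i++) {
        if (backends[i].supports_op(t)) {
            return (int) i;
        }
    }
    return -1;
}

int main() {
    const std::vector<toy_backend> backends = {
        // before this commit, an "amx"-like backend would reject OP_VIEW as well
        { "amx", [](const toy_tensor & t) { return t.op == OP_NONE || t.op == OP_VIEW || t.op == OP_MUL_MAT; } },
        { "cpu", [](const toy_tensor &)   { return true; } },
    };

    const toy_tensor view_op = { OP_VIEW };
    const int idx = pick_backend(backends, view_op);
    std::printf("OP_VIEW scheduled on: %s\n", idx >= 0 ? backends[idx].name.c_str() : "(none: abort)");
    return 0;
}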
