diff --git a/ggml/include/ggml-amx.h b/ggml/include/ggml-amx.h index b961be8084a7c6..22b3f70f43a674 100644 --- a/ggml/include/ggml-amx.h +++ b/ggml/include/ggml-amx.h @@ -9,15 +9,17 @@ extern "C" { #endif // buffer_type API -GGML_API ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(); +GGML_API ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void); GGML_API bool ggml_backend_is_amx(ggml_backend_t backend); // backend API -GGML_API ggml_backend_t ggml_backend_amx_init(int n_threads); +GGML_API ggml_backend_t ggml_backend_amx_init(void); GGML_API void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads); +GGML_API ggml_backend_reg_t ggml_backend_amx_reg(void); + #ifdef __cplusplus } #endif diff --git a/ggml/src/ggml-amx.cpp b/ggml/src/ggml-amx.cpp index 7494b0219e3676..344a2aa49c60fd 100644 --- a/ggml/src/ggml-amx.cpp +++ b/ggml/src/ggml-amx.cpp @@ -186,57 +186,6 @@ static enum ggml_status ggml_backend_amx_graph_compute(ggml_backend_t backend, s GGML_UNUSED(backend); } -static bool ggml_backend_amx_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { - - // handle only 2d gemm for now - auto is_contiguous_2d = [](const struct ggml_tensor * t) { - return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1; - }; - - switch (op->op) { - case GGML_OP_NONE: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_PERMUTE: - case GGML_OP_TRANSPOSE: - return true; - - case GGML_OP_MUL_MAT: { - const struct ggml_tensor * src0 = op->src[0]; - const struct ggml_tensor * src1 = op->src[1]; - - const enum ggml_type type = src0->type; - const int64_t ne0 = op->ne[0]; - - bool is_training = src0->grad || src1->grad; - - // amx kernels enables for Q4_0, Q4_1, Q8_0, F16 - // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256 - bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16); - - bool can_use_amx = - is_contiguous_2d(src0) && // src0 must be contiguous - is_contiguous_2d(src1) && // src1 must be contiguous - !is_training && // inference only - src1->type == GGML_TYPE_F32 && // src1 must be float32 - has_amx_kernels && // with amx kernel impls - ne0 % (TILE_N * 2) == 0; // out_features is 32x - - return can_use_amx; - } - default: - return false; - } - - GGML_UNUSED(backend); -} - -static bool ggml_backend_amx_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { - return buft->iface.get_name == ggml_backend_amx_buffer_type_get_name; - - GGML_UNUSED(backend); -} - static struct ggml_backend_i ggml_backend_amx_i = { /* .get_name = */ ggml_backend_amx_name, /* .free = */ ggml_backend_amx_free, @@ -250,8 +199,8 @@ static struct ggml_backend_i ggml_backend_amx_i = { /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_amx_graph_compute, - /* .supports_op = */ ggml_backend_amx_supports_op, - /* .supports_buft = */ ggml_backend_amx_supports_buft, + /* .supports_op = */ NULL, + /* .supports_buft = */ NULL, /* .offload_op = */ NULL, /* .event_record = */ NULL, /* .event_wait = */ NULL, @@ -279,7 +228,7 @@ static bool ggml_amx_init() { #endif } -ggml_backend_t ggml_backend_amx_init(int n_threads) { +ggml_backend_t ggml_backend_amx_init() { // invoke a Linux system call to request access to AMX features ggml_amx_init(); @@ -291,12 +240,10 @@ ggml_backend_t ggml_backend_amx_init(int n_threads) { ggml_backend_t backend = new ggml_backend { /* .guid = */ ggml_backend_amx_guid(), /* .interface = */ ggml_backend_amx_i, - /* .device = */ NULL, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0), /* .context = */ ctx, }; - ggml_backend_amx_set_n_threads(backend, n_threads); - return backend; } @@ -311,6 +258,192 @@ void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) { ctx->n_threads = n_threads; } +// device interface + +static const char * ggml_backend_amx_device_get_name(ggml_backend_dev_t dev) { + return "AMX"; + + GGML_UNUSED(dev); +} + +static const char * ggml_backend_amx_device_get_description(ggml_backend_dev_t dev) { + return "Intel Advanced Matrix Extensions"; + + GGML_UNUSED(dev); +} + +static void ggml_backend_amx_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + // TODO + *free = 0; + *total = 0; + + GGML_UNUSED(dev); +} + +static enum ggml_backend_dev_type ggml_backend_amx_device_get_type(ggml_backend_dev_t dev) { + return GGML_BACKEND_DEVICE_TYPE_CPU; + + GGML_UNUSED(dev); +} + +static void ggml_backend_amx_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_amx_device_get_name(dev); + props->description = ggml_backend_amx_device_get_description(dev); + props->type = ggml_backend_amx_device_get_type(dev); + ggml_backend_amx_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ true, + /* .events = */ false, + }; +} + +static ggml_backend_t ggml_backend_amx_device_init(ggml_backend_dev_t dev, const char * params) { + return ggml_backend_amx_init(); + + GGML_UNUSED(dev); + GGML_UNUSED(params); +} + +static ggml_backend_buffer_type_t ggml_backend_amx_device_get_buffer_type(ggml_backend_dev_t dev) { + return ggml_backend_amx_buffer_type(); + + GGML_UNUSED(dev); +} + +static ggml_backend_buffer_t ggml_backend_amx_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned"); + return ggml_backend_buffer_init(ggml_backend_amx_buffer_type(), ggml_backend_amx_buffer_interface, ptr, size); + + GGML_UNUSED(dev); + GGML_UNUSED(max_tensor_size); +} + +static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + + // handle only 2d gemm for now + auto is_contiguous_2d = [](const struct ggml_tensor * t) { + return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1; + }; + + switch (op->op) { + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + return true; + + case GGML_OP_MUL_MAT: { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + + const enum ggml_type type = src0->type; + const int64_t ne0 = op->ne[0]; + + bool is_training = src0->grad || src1->grad; + + // amx kernels enables for Q4_0, Q4_1, Q8_0, F16 + // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256 + bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16); + + bool can_use_amx = + is_contiguous_2d(src0) && // src0 must be contiguous + is_contiguous_2d(src1) && // src1 must be contiguous + !is_training && // inference only + src1->type == GGML_TYPE_F32 && // src1 must be float32 + has_amx_kernels && // with amx kernel impls + ne0 % (TILE_N * 2) == 0; // out_features is 32x + + return can_use_amx; + } + default: + return false; + } + + GGML_UNUSED(dev); +} + +static bool ggml_backend_amx_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_amx_buffer_type_get_name; + + GGML_UNUSED(dev); +} + +static const struct ggml_backend_device_i ggml_backend_amx_device_i = { + /* .get_name = */ ggml_backend_amx_device_get_name, + /* .get_description = */ ggml_backend_amx_device_get_description, + /* .get_memory = */ ggml_backend_amx_device_get_memory, + /* .get_type = */ ggml_backend_amx_device_get_type, + /* .get_props = */ ggml_backend_amx_device_get_props, + /* .init_backend = */ ggml_backend_amx_device_init, + /* .get_buffer_type = */ ggml_backend_amx_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_amx_device_buffer_from_ptr, + /* .supports_op = */ ggml_backend_amx_device_supports_op, + /* .supports_buft = */ ggml_backend_amx_device_supports_buft, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +// backend reg interface + +static const char * ggml_backend_amx_reg_get_name(ggml_backend_reg_t reg) { + return "AMX"; + + GGML_UNUSED(reg); +} + +static size_t ggml_backend_amx_reg_get_device_count(ggml_backend_reg_t reg) { + return 1; + + GGML_UNUSED(reg); +} + +static ggml_backend_dev_t ggml_backend_amx_reg_get_device(ggml_backend_reg_t reg, size_t index) { + GGML_ASSERT(index == 0); + + static ggml_backend_device ggml_backend_amx_device = { + /* .iface = */ ggml_backend_amx_device_i, + /* .reg = */ reg, + /* .context = */ nullptr, + }; + + return &ggml_backend_amx_device; + + GGML_UNUSED(reg); + GGML_UNUSED(index); +} + +static void * ggml_backend_amx_get_proc_address(ggml_backend_reg_t reg, const char * name) { + if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) { + return (void *)ggml_backend_amx_set_n_threads; + } + return NULL; + + GGML_UNUSED(reg); + GGML_UNUSED(name); +} + +static const struct ggml_backend_reg_i ggml_backend_amx_reg_i = { + /* .get_name = */ ggml_backend_amx_reg_get_name, + /* .get_device_count = */ ggml_backend_amx_reg_get_device_count, + /* .get_device = */ ggml_backend_amx_reg_get_device, + /* .get_proc_address = */ ggml_backend_amx_get_proc_address, +}; + +ggml_backend_reg_t ggml_backend_amx_reg(void) { + static struct ggml_backend_reg ggml_backend_amx_reg = { + /* .iface = */ ggml_backend_amx_reg_i, + /* .context = */ NULL, + }; + + return &ggml_backend_amx_reg; +} + #else // if defined(__AMX_INT8__) ggml_backend_t ggml_backend_amx_init(int n_threads) { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 15d650150a5f34..504e0f3605bebd 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -329,7 +329,6 @@ bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type if (backend->device) { return ggml_backend_dev_supports_buft(backend->device, buft); } - return backend->iface.supports_buft(backend, buft); } @@ -546,6 +545,14 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na #include "ggml-rpc.h" #endif +#ifndef __AMX_INT8__ +#undef GGML_USE_AMX +#endif + +#ifdef GGML_USE_AMX +# include "ggml-amx.h" +#endif + struct ggml_backend_registry { std::vector backends; std::vector devices; @@ -563,6 +570,9 @@ struct ggml_backend_registry { #ifdef GGML_USE_RPC register_backend(ggml_backend_rpc_reg()); #endif +#ifdef GGML_USE_AMX + register_backend(ggml_backend_amx_reg()); +#endif // TODO: sycl, vulkan, kompute, cann diff --git a/src/llama.cpp b/src/llama.cpp index 5b65949cd373da..5fbbf1796b0396 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -19520,18 +19520,6 @@ struct llama_context * llama_new_context_with_model( } #endif -#if defined(GGML_USE_AMX) - { - ggml_backend_t backend = ggml_backend_amx_init(cparams.n_threads); - if (backend == nullptr) { - LLAMA_LOG_ERROR("%s: failed to initialize AMX backend\n", __func__); - llama_free(ctx); - return nullptr; - } - ctx->backends.push_back(backend); - } -#endif - // add other backends (such as BLAS) for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { ggml_backend_dev_t dev = ggml_backend_dev_get(i);