Skip to content

Commit

Permalink
use device reg for AMX backend
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
mingfeima committed Oct 16, 2024
1 parent 3b38fc0 commit 45451e2
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 72 deletions.
6 changes: 4 additions & 2 deletions ggml/include/ggml-amx.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@ extern "C" {
#endif

// buffer_type API
GGML_API ggml_backend_buffer_type_t ggml_backend_amx_buffer_type();
GGML_API ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void);

GGML_API bool ggml_backend_is_amx(ggml_backend_t backend);

// backend API
GGML_API ggml_backend_t ggml_backend_amx_init(int n_threads);
GGML_API ggml_backend_t ggml_backend_amx_init(void);

GGML_API void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads);

GGML_API ggml_backend_reg_t ggml_backend_amx_reg(void);

#ifdef __cplusplus
}
#endif
247 changes: 190 additions & 57 deletions ggml/src/ggml-amx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,57 +186,6 @@ static enum ggml_status ggml_backend_amx_graph_compute(ggml_backend_t backend, s
GGML_UNUSED(backend);
}

// Backend-level op support query for the AMX backend.
// Accepts no-op/layout ops unconditionally and 2D GEMM (GGML_OP_MUL_MAT) when
// the operands satisfy the AMX kernel constraints; everything else is rejected.
static bool ggml_backend_amx_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {

    // handle only 2d gemm for now
    auto is_contiguous_2d = [](const struct ggml_tensor * t) {
        return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
    };

    switch (op->op) {
        // layout/view ops carry no computation — always supported
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
        case GGML_OP_TRANSPOSE:
            return true;

        case GGML_OP_MUL_MAT: {
            const struct ggml_tensor * src0 = op->src[0];
            const struct ggml_tensor * src1 = op->src[1];

            const enum ggml_type type = src0->type;
            const int64_t ne0 = op->ne[0];

            // a gradient tensor on either operand means this graph is used for training
            bool is_training = src0->grad || src1->grad;

            // amx kernels enables for Q4_0, Q4_1, Q8_0, F16
            // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
            bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);

            bool can_use_amx =
                is_contiguous_2d(src0) &&       // src0 must be contiguous
                is_contiguous_2d(src1) &&       // src1 must be contiguous
                !is_training &&                 // inference only
                src1->type == GGML_TYPE_F32 &&  // src1 must be float32
                has_amx_kernels &&              // with amx kernel impls
                ne0 % (TILE_N * 2) == 0;        // out_features is 32x

            return can_use_amx;
        }
        default:
            return false;
    }

    GGML_UNUSED(backend);
}

// Backend-level buffer-type support query: only buffers created by the AMX
// buffer type are accepted (identified by comparing the get_name function pointer).
static bool ggml_backend_amx_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
    return buft->iface.get_name == ggml_backend_amx_buffer_type_get_name;

    GGML_UNUSED(backend);
}

static struct ggml_backend_i ggml_backend_amx_i = {
/* .get_name = */ ggml_backend_amx_name,
/* .free = */ ggml_backend_amx_free,
Expand All @@ -250,8 +199,8 @@ static struct ggml_backend_i ggml_backend_amx_i = {
/* .graph_plan_update = */ NULL,
/* .graph_plan_compute = */ NULL,
/* .graph_compute = */ ggml_backend_amx_graph_compute,
/* .supports_op = */ ggml_backend_amx_supports_op,
/* .supports_buft = */ ggml_backend_amx_supports_buft,
/* .supports_op = */ NULL,
/* .supports_buft = */ NULL,
/* .offload_op = */ NULL,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
Expand Down Expand Up @@ -279,7 +228,7 @@ static bool ggml_amx_init() {
#endif
}

ggml_backend_t ggml_backend_amx_init(int n_threads) {
ggml_backend_t ggml_backend_amx_init() {

// invoke a Linux system call to request access to AMX features
ggml_amx_init();
Expand All @@ -291,12 +240,10 @@ ggml_backend_t ggml_backend_amx_init(int n_threads) {
ggml_backend_t backend = new ggml_backend {
/* .guid = */ ggml_backend_amx_guid(),
/* .interface = */ ggml_backend_amx_i,
/* .device = */ NULL,
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0),
/* .context = */ ctx,
};

ggml_backend_amx_set_n_threads(backend, n_threads);

return backend;
}

Expand All @@ -311,6 +258,192 @@ void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
ctx->n_threads = n_threads;
}

// device interface

// short identifier of the AMX device
static const char * ggml_backend_amx_device_get_name(ggml_backend_dev_t dev) {
    GGML_UNUSED(dev);

    return "AMX";
}

// human-readable description of the AMX device
static const char * ggml_backend_amx_device_get_description(ggml_backend_dev_t dev) {
    GGML_UNUSED(dev);

    return "Intel Advanced Matrix Extensions";
}

// report free/total device memory
// TODO: query the actual host memory amounts; zero is reported for now
static void ggml_backend_amx_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
    GGML_UNUSED(dev);

    *free  = 0;
    *total = 0;
}

// AMX runs on the host CPU, so the device is classified as a CPU device
static enum ggml_backend_dev_type ggml_backend_amx_device_get_type(ggml_backend_dev_t dev) {
    GGML_UNUSED(dev);

    return GGML_BACKEND_DEVICE_TYPE_CPU;
}

// aggregate the individual device queries into a single property struct
static void ggml_backend_amx_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
    props->name        = ggml_backend_amx_device_get_name(dev);
    props->description = ggml_backend_amx_device_get_description(dev);
    props->type        = ggml_backend_amx_device_get_type(dev);
    ggml_backend_amx_device_get_memory(dev, &props->memory_free, &props->memory_total);

    // synchronous backend operating directly on host memory:
    // no async, no dedicated host buffers, no events; external host pointers are accepted
    props->caps.async                = false;
    props->caps.host_buffer          = false;
    props->caps.buffer_from_host_ptr = true;
    props->caps.events               = false;
}

// create a backend instance for this device; the AMX backend takes no init parameters
static ggml_backend_t ggml_backend_amx_device_init(ggml_backend_dev_t dev, const char * params) {
    GGML_UNUSED(dev);
    GGML_UNUSED(params);

    return ggml_backend_amx_init();
}

// default buffer type used for tensors allocated on this device
static ggml_backend_buffer_type_t ggml_backend_amx_device_get_buffer_type(ggml_backend_dev_t dev) {
    GGML_UNUSED(dev);

    return ggml_backend_amx_buffer_type();
}

// wrap an externally allocated host buffer in an AMX backend buffer;
// the pointer must already satisfy the tensor alignment requirement
static ggml_backend_buffer_t ggml_backend_amx_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
    GGML_UNUSED(dev);
    GGML_UNUSED(max_tensor_size);

    GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned");

    return ggml_backend_buffer_init(ggml_backend_amx_buffer_type(), ggml_backend_amx_buffer_interface, ptr, size);
}

// Op support query for the AMX device.
// Layout/no-op operations are always supported; GGML_OP_MUL_MAT is supported
// only for 2D contiguous operands that the AMX kernels can handle; anything
// else is rejected.
static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
    GGML_UNUSED(dev);

    // handle only 2d gemm for now
    auto is_2d_contiguous = [](const struct ggml_tensor * t) {
        return ggml_is_contiguous(t) && t->ne[2] == 1 && t->ne[3] == 1;
    };

    // layout/view ops carry no computation — always supported
    if (op->op == GGML_OP_NONE     || op->op == GGML_OP_RESHAPE ||
        op->op == GGML_OP_VIEW     || op->op == GGML_OP_PERMUTE ||
        op->op == GGML_OP_TRANSPOSE) {
        return true;
    }

    if (op->op != GGML_OP_MUL_MAT) {
        return false;
    }

    const struct ggml_tensor * src0 = op->src[0];
    const struct ggml_tensor * src1 = op->src[1];

    // inference only: a gradient on either operand means training
    if (src0->grad || src1->grad) {
        return false;
    }

    // amx kernels are enabled for Q4_0, Q4_1, Q8_0, F16
    // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
    const bool has_amx_kernels = qtype_has_amx_kernels(src0->type) || (src0->type == GGML_TYPE_F16);

    return is_2d_contiguous(src0) &&         // src0 must be contiguous 2D
           is_2d_contiguous(src1) &&         // src1 must be contiguous 2D
           src1->type == GGML_TYPE_F32 &&    // src1 must be float32
           has_amx_kernels &&                // an AMX kernel impl must exist for src0->type
           op->ne[0] % (TILE_N * 2) == 0;    // out_features is 32x
}

// only buffers created by the AMX buffer type are supported
// (identified by comparing the get_name function pointer)
static bool ggml_backend_amx_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
    GGML_UNUSED(dev);

    return buft->iface.get_name == ggml_backend_amx_buffer_type_get_name;
}

// device interface vtable for the AMX device;
// unimplemented features (host buffer type, offload, events) are left NULL
static const struct ggml_backend_device_i ggml_backend_amx_device_i = {
    /* .get_name             = */ ggml_backend_amx_device_get_name,
    /* .get_description      = */ ggml_backend_amx_device_get_description,
    /* .get_memory           = */ ggml_backend_amx_device_get_memory,
    /* .get_type             = */ ggml_backend_amx_device_get_type,
    /* .get_props            = */ ggml_backend_amx_device_get_props,
    /* .init_backend         = */ ggml_backend_amx_device_init,
    /* .get_buffer_type      = */ ggml_backend_amx_device_get_buffer_type,
    /* .get_host_buffer_type = */ NULL,
    /* .buffer_from_host_ptr = */ ggml_backend_amx_device_buffer_from_ptr,
    /* .supports_op          = */ ggml_backend_amx_device_supports_op,
    /* .supports_buft        = */ ggml_backend_amx_device_supports_buft,
    /* .offload_op           = */ NULL,
    /* .event_new            = */ NULL,
    /* .event_free           = */ NULL,
    /* .event_synchronize    = */ NULL,
};

// backend reg interface

// name of the AMX backend registry
static const char * ggml_backend_amx_reg_get_name(ggml_backend_reg_t reg) {
    GGML_UNUSED(reg);

    return "AMX";
}

// the AMX registry exposes exactly one device
static size_t ggml_backend_amx_reg_get_device_count(ggml_backend_reg_t reg) {
    GGML_UNUSED(reg);

    return 1;
}

// Return the single AMX device; `index` must be 0.
static ggml_backend_dev_t ggml_backend_amx_reg_get_device(ggml_backend_reg_t reg, size_t index) {
    GGML_ASSERT(index == 0);

    // process-wide singleton; the static initializer runs once, capturing the
    // `reg` passed on the first call (there is only one AMX registry)
    static ggml_backend_device ggml_backend_amx_device = {
        /* .iface   = */ ggml_backend_amx_device_i,
        /* .reg     = */ reg,
        /* .context = */ nullptr,
    };

    return &ggml_backend_amx_device;

    GGML_UNUSED(reg);
    GGML_UNUSED(index);
}

// look up optional backend entry points by name;
// exposes the n_threads setter so generic code can configure the backend
static void * ggml_backend_amx_get_proc_address(ggml_backend_reg_t reg, const char * name) {
    GGML_UNUSED(reg);

    if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
        return (void *)ggml_backend_amx_set_n_threads;
    }

    // unknown symbol
    return NULL;
}

// registry interface vtable for the AMX backend
static const struct ggml_backend_reg_i ggml_backend_amx_reg_i = {
    /* .get_name         = */ ggml_backend_amx_reg_get_name,
    /* .get_device_count = */ ggml_backend_amx_reg_get_device_count,
    /* .get_device       = */ ggml_backend_amx_reg_get_device,
    /* .get_proc_address = */ ggml_backend_amx_get_proc_address,
};

// public entry point: return the process-wide AMX backend registry singleton
ggml_backend_reg_t ggml_backend_amx_reg(void) {
    static struct ggml_backend_reg reg = {
        /* .iface   = */ ggml_backend_amx_reg_i,
        /* .context = */ NULL,
    };

    return &reg;
}

#else // if defined(__AMX_INT8__)

ggml_backend_t ggml_backend_amx_init(int n_threads) {
Expand Down
12 changes: 11 additions & 1 deletion ggml/src/ggml-backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,6 @@ bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type
if (backend->device) {
return ggml_backend_dev_supports_buft(backend->device, buft);
}

return backend->iface.supports_buft(backend, buft);
}

Expand Down Expand Up @@ -546,6 +545,14 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
#include "ggml-rpc.h"
#endif

#ifndef __AMX_INT8__
#undef GGML_USE_AMX
#endif

#ifdef GGML_USE_AMX
# include "ggml-amx.h"
#endif

struct ggml_backend_registry {
std::vector<ggml_backend_reg_t> backends;
std::vector<ggml_backend_dev_t> devices;
Expand All @@ -563,6 +570,9 @@ struct ggml_backend_registry {
#ifdef GGML_USE_RPC
register_backend(ggml_backend_rpc_reg());
#endif
#ifdef GGML_USE_AMX
register_backend(ggml_backend_amx_reg());
#endif

// TODO: sycl, vulkan, kompute, cann

Expand Down
12 changes: 0 additions & 12 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19520,18 +19520,6 @@ struct llama_context * llama_new_context_with_model(
}
#endif

#if defined(GGML_USE_AMX)
{
ggml_backend_t backend = ggml_backend_amx_init(cparams.n_threads);
if (backend == nullptr) {
LLAMA_LOG_ERROR("%s: failed to initialize AMX backend\n", __func__);
llama_free(ctx);
return nullptr;
}
ctx->backends.push_back(backend);
}
#endif

// add other backends (such as BLAS)
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
Expand Down

0 comments on commit 45451e2

Please sign in to comment.