diff --git a/ggml-metal.m b/ggml-metal.m index a0efda0baa2d5..3646c40dfd181 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -24,10 +24,7 @@ #define UNUSED(x) (void)(x) -#define GGML_METAL_MAX_KERNELS 256 - struct ggml_metal_kernel { - id function; id pipeline; }; @@ -159,11 +156,10 @@ id device; id queue; - id library; dispatch_queue_t d_queue; - struct ggml_metal_kernel kernels[GGML_METAL_MAX_KERNELS]; + struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT]; bool support_simdgroup_reduction; bool support_simdgroup_mm; @@ -246,6 +242,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ ctx->queue = [ctx->device newCommandQueue]; ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); + id metal_library; + // load library { NSBundle * bundle = nil; @@ -260,7 +258,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ // pre-compiled library found NSURL * libURL = [NSURL fileURLWithPath:libPath]; GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]); - ctx->library = [ctx->device newLibraryWithURL:libURL error:&error]; + metal_library = [ctx->device newLibraryWithURL:libURL error:&error]; if (error) { GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); return NULL; @@ -302,7 +300,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ //[options setFastMathEnabled:false]; - ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error]; + metal_library = [ctx->device newLibraryWithSource:src options:options error:&error]; if (error) { GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]); return NULL; @@ -367,8 +365,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ { NSError * error = nil; - for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) { - ctx->kernels[i].function = nil; + for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) { ctx->kernels[i].pipeline = nil; } @@ -380,10 +377,12 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ #define GGML_METAL_ADD_KERNEL(e, name, supported) \ if (supported) { \ struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \ - kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \ - kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \ + id metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \ + kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \ + [metal_function release]; \ if (error) { \ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ + [metal_library release]; \ return NULL; \ } \ } else { \ @@ -512,23 +511,17 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true); } + [metal_library release]; return ctx; } static void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_LOG_INFO("%s: deallocating\n", __func__); - for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) { - if (ctx->kernels[i].pipeline) { - [ctx->kernels[i].pipeline release]; - } - - if (ctx->kernels[i].function) { - [ctx->kernels[i].function release]; - } + for (int i = 0; i < GGML_METAL_KERNEL_TYPE_COUNT; ++i) { + [ctx->kernels[i].pipeline release]; } - [ctx->library release]; [ctx->queue release]; [ctx->device release];