diff --git a/neural_speed/__init__.py b/neural_speed/__init__.py
index afc1a6100..9349f153f 100644
--- a/neural_speed/__init__.py
+++ b/neural_speed/__init__.py
@@ -72,6 +72,8 @@ def __import_package(self, model_type):
             import neural_speed.phi_cpp as cpp_model
         elif model_type == "whisper":
             import neural_speed.whisper_cpp as cpp_model
+        elif model_type == "llama_yarn":
+            import neural_speed.llama_yarn_cpp as cpp_model
         else:
             raise TypeError("Unspported model type {}!".format(model_type))
         self.module = cpp_model
@@ -81,6 +83,9 @@ def get_model_type(model_config):
         model_type = model_maps.get(model_config.model_type, model_config.model_type)
         if model_type == "chatglm" and "chatglm2" in model_config._name_or_path:
             model_type = "chatglm2"
+        elif model_type == "llama" and model_config.rope_scaling != None:
+            if model_config.rope_scaling["type"] == "yarn":
+                model_type = "llama_yarn"
         return model_type
 
     def init(self, model_name, use_quant=True, use_gptq=False, use_awq=False,
diff --git a/neural_speed/application/CMakeLists.txt b/neural_speed/application/CMakeLists.txt
index 3782e5e1d..3d94b23ce 100644
--- a/neural_speed/application/CMakeLists.txt
+++ b/neural_speed/application/CMakeLists.txt
@@ -70,6 +70,7 @@ compile_quant(quant_mistral quant_model.cpp mistral llama)
 compile_quant(quant_qwen quant_model.cpp qwen qwen)
 compile_quant(quant_phi quant_model.cpp phi phi)
 compile_quant(quant_whisper quant_whisper.cpp whisper whisper)
+compile_quant(quant_llama_yarn quant_model.cpp llama_yarn llama_yarn)
 
 # all models running
 if (NS_PYTHON_API)
@@ -93,8 +94,7 @@ set(mymap_mistral 14)
 set(mymap_qwen 15)
 set(mymap_phi 16)
 set(mymap_whisper 17)
-
-
+set(mymap_llama_yarn 18)
 
 function(compile_run TARGET MAIN_CPP MAIN_PY MODEL_NAME MODEL_LIB)
   add_executable_w_warning(${TARGET} ${MAIN_CPP})
@@ -129,6 +129,7 @@ compile_run(run_baichuan main_run.cpp main_pybind.cpp baichuan baichuan)
 compile_run(run_mistral main_run.cpp main_pybind.cpp mistral llama)
 compile_run(run_qwen main_run.cpp main_pybind.cpp qwen qwen)
 compile_run(run_phi main_run.cpp main_pybind.cpp phi phi)
+compile_run(run_llama_yarn main_run.cpp main_pybind.cpp llama_yarn llama_yarn)
 
 # speech recognition
 compile_run(run_whisper audio_run.cpp whisper_pybind.cpp whisper whisper)
diff --git a/neural_speed/application/main_pybind.cpp b/neural_speed/application/main_pybind.cpp
index 6a4973676..0f312e9ce 100644
--- a/neural_speed/application/main_pybind.cpp
+++ b/neural_speed/application/main_pybind.cpp
@@ -899,6 +899,10 @@ PYBIND11_MODULE(phi_cpp, m)
 
 PYBIND11_MODULE(whisper_cpp, m)
 
+#elif MODEL_NAME_ID == 18
+
+PYBIND11_MODULE(llama_yarn_cpp, m)
+
 #endif
 {
   m.doc() = "cpp model python binding";
diff --git a/neural_speed/convert/__init__.py b/neural_speed/convert/__init__.py
index da272ce32..50d9d6754 100644
--- a/neural_speed/convert/__init__.py
+++ b/neural_speed/convert/__init__.py
@@ -25,6 +25,9 @@ def convert_model(model, outfile, outtype, whisper_repo_path=None):
     config = AutoConfig.from_pretrained(model, trust_remote_code=True)
     model_type = model_maps.get(config.model_type, config.model_type)
+    if model_type == "llama" and config.rope_scaling != None:
+        if (config.rope_scaling["type"] == "yarn"):
+            model_type = "llama_yarn"
 
     quantized_model = 'gptq' in str(model).lower() or 'awq' in str(model).lower()
     if quantized_model:
diff --git a/neural_speed/convert/convert_llama_yarn.py b/neural_speed/convert/convert_llama_yarn.py
index efe795a02..70ca89fea 100644
--- a/neural_speed/convert/convert_llama_yarn.py
+++ b/neural_speed/convert/convert_llama_yarn.py
@@ -1,3 +1,16 @@
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 import io
 import os
diff --git a/neural_speed/models/CMakeLists.txt b/neural_speed/models/CMakeLists.txt
index f62b64e41..b21d02019 100644
--- a/neural_speed/models/CMakeLists.txt
+++ b/neural_speed/models/CMakeLists.txt
@@ -35,3 +35,4 @@ add_model(whisper whisper/whisper.cpp whisper/whisper_utils.cpp ${MODEL_UTILS_SOURCE})
 add_model(chatglm chatglm/chatglm.cpp chatglm/chatglm_utils.cpp ${MODEL_UTILS_SOURCE})
 add_model(chatglm2 chatglm/chatglm2.cpp chatglm/chatglm2_utils.cpp ${MODEL_UTILS_SOURCE})
 add_model(phi phi/phi.cpp phi/phi_utils.cpp ${MODEL_UTILS_SOURCE})
+add_model(llama_yarn llama/llama_yarn.cpp llama/llama_yarn_utils.cpp ${MODEL_UTILS_SOURCE})
diff --git a/neural_speed/models/llama/llama_yarn.cpp b/neural_speed/models/llama/llama_yarn.cpp
index 9ca196d97..9ede640ff 100644
--- a/neural_speed/models/llama/llama_yarn.cpp
+++ b/neural_speed/models/llama/llama_yarn.cpp
@@ -33,7 +33,7 @@
 #include "core/data_types.h"
 #include "core/layers/mha_dense.h"
 #include "core/ne.h"
-#include "core/ne_jblas.h"
+#include "core/ne_bestla.h"
 #include "core/ne_layers.h"
 #include "models/model_utils/model_config.h"
 #include "models/model_utils/model_files.h"
@@ -87,7 +87,7 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp
   int n_head_kv = hparams.n_head_kv;
   bool enable_tp = false;
-#ifdef NE_TP_MODEL
+#ifdef NS_TP_MODEL
   parallel_context* p_ctx = init_parallel_context();
   int32_t world_size = get_tp_size(p_ctx);
   int32_t rank = get_tp_rank(p_ctx);
@@ -119,27 +119,27 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp
   ne_cgraph gf = {};
   gf.n_threads = N >= 32 && ne_cpu_has_blas() ? 1 : n_threads;
 
-  const bool run_mha_reordered = kv_self.k->type == NE_TYPE_JBLAS;
+  const bool run_mha_reordered = kv_self.k->type == NE_TYPE_BTLA;
   kv_cache_info_t kv_cache_info = {0, 0};
   if (run_mha_reordered) {
-    NE_ASSERT(kv_self.v->type == NE_TYPE_JBLAS);  // kv type should be the same
+    NE_ASSERT(kv_self.v->type == NE_TYPE_BTLA);  // kv type should be the same
     attn_shape_t attn_shape = {
         /* .batch_size = */ 1,
         /* .head_num = */ n_head,
         /* .heads_kv = */ n_head_kv,
         /* .head_size = */ head_size,
-        /* .sl_q = */ N,  // Note: make sure that jblas reordered attn supports next token inferencing
+        /* .sl_q = */ N,  // Note: make sure that bestla reordered attn supports next token inferencing
         /* .sl_kv = */ n_cached,
     };
-    NE_ASSERT(("jblas managed kv-cache not supported; use `--memory-f16 / --memory-f32` instead",
-               jblas_reordered_attn_fp32_support(&attn_shape)));
+    NE_ASSERT(("bestla managed kv-cache not supported; use `--memory-f16 / --memory-f32` instead",
+               bestla_reordered_attn_fp32_support(&attn_shape)));
     kv_shape_t kv_shape{
         /* .heads_kv = */ static_cast<uint32_t>(n_head_kv),
         /* .head_size = */ static_cast<uint32_t>(head_size),
         /* .sl_kv_max = */ static_cast<uint32_t>(n_ctx),
     };
-    jblas_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info);
+    bestla_reordered_attn_fp32_batch_kv_info(&kv_shape, &kv_cache_info);
   }
 
   struct ne_tensor* embd = ne_new_tensor_1d(ctx0, NE_TYPE_I32, N, NE_SIZE_CALC);
@@ -168,7 +168,7 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp
   }
 #endif
 
-#ifdef NE_TP_MODEL
+#ifdef NS_TP_MODEL
   if (enable_tp) {
     // need to broadcast the ids
     broadcast(p_ctx, reinterpret_cast<float*>(embd->data), N * ne_element_size(embd));
@@ -191,7 +191,7 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp
       cur = ne_mul(ctx0, cur, model.layers[il].norm[0]);
     }
     ne_tensor *Qcur, *Kcur, *Vcur;
-    if (jblas_fusion_QKV_f32f32_support(model.layers[il].attn[0]->data, model.layers[il].attn[2]->data,
+    if (bestla_fusion_QKV_f32f32_support(model.layers[il].attn[0]->data, model.layers[il].attn[2]->data,
                                         model.layers[il].attn[4]->data, N, model.layers[il].attn[0]->ne[1],
                                         model.layers[il].attn[0]->ne[0]) &&
         n_head == n_head_kv) {  // fused execution of QKV
@@ -335,14 +335,14 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp
       {
         const auto k_cache = ne_view_3d(ctx0, kv_self.k,              // tensor
                                         head_size, n_ctx, n_head_kv,  // ne
-                                        0, 0,                         // nb (jblas managed)
+                                        0, 0,                         // nb (bestla managed)
                                         il * k_size);                 // offset
         ne_build_forward_expand(&gf, ne_flash_attn_update_k(ctx0, k_cache, Kcur, n_past, is_ring_full));
         const auto v_cache = ne_view_3d(ctx0, kv_self.v,              // tensor
                                         head_size, n_ctx, n_head_kv,  // ne
-                                        0, 0,                         // nb (jblas managed)
+                                        0, 0,                         // nb (bestla managed)
                                         il * v_size);                 // offset
-        // jblas alway view V as (D, n_head, seq)
+        // bestla alway view V as (D, n_head, seq)
         const auto Vcur_plain =
             ne_reshape_3d(ctx0, ne_view_1d(ctx0, Vcur, n_embd_gqa * N, 0), n_embd_gqa / n_head_kv, n_head_kv, N);
         ne_build_forward_expand(&gf, ne_flash_attn_update_v(ctx0, v_cache, Vcur_plain, n_past, is_ring_full));
@@ -354,7 +354,7 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp
       struct ne_tensor* K =
          ne_view_3d(ctx0, kv_self.k,                                              // tensor
                     head_size, n_cached, n_head_kv,                               // ne
-                    kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num,   // nb (jblas managed)
+                    kv_cache_info.stride_k_sl, kv_cache_info.stride_k_head_num,   // nb (bestla managed)
                     il * k_size);                                                 // offset
       *reinterpret_cast<ATTN_FWD_LAYOUT*>(&K->nb[0]) = kv_cache_info.k_layout;  // us nb0 for layout
       if (is_ring_full) {
@@ -371,7 +371,7 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp
      struct ne_tensor* V =
          ne_view_3d(ctx0, kv_self.v,                                                     // tensor
                     n_cached, head_size, n_head_kv,                                      // ne
-                    kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num,   // nb (jblas managed)
+                    kv_cache_info.stride_v_head_size, kv_cache_info.stride_v_head_num,   // nb (bestla managed)
                     il * v_size);                                                        // offset
      *reinterpret_cast<ATTN_FWD_LAYOUT*>(&V->nb[0]) = kv_cache_info.v_layout;  // us nb0 for layout
      ne_set_name(V, "V");
@@ -385,8 +385,9 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp
       // projection (no bias)
       cur = ne_mul_mat(ctx0, model.layers[il].attn[6], KQV_merged_contiguous);
+      cur = ne_add_inplace(ctx0, cur, model.layers[il].attn[7]);
     }
-#ifdef NE_TP_MODEL
+#ifdef NS_TP_MODEL
     if (enable_tp) {
       cur = ne_all_reduce(ctx0, cur);
     }
@@ -406,9 +407,9 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp
      cur = ne_mul(ctx0, cur, model.layers[il].norm[1]);
    }
 
-    if (jblas_fusion_FFN_SiLu_f32f32_support(model.layers[il].ffn[0]->data, model.layers[il].ffn[1]->data,
-                                             model.layers[il].ffn[2]->data, N, cur->ne[0],
-                                             model.layers[il].ffn[0]->ne[1], model.layers[il].ffn[1]->ne[1])) {
+    if (bestla_fusion_FFN_SiLu_f32f32_support(model.layers[il].ffn[0]->data, model.layers[il].ffn[1]->data,
+                                              model.layers[il].ffn[2]->data, N, cur->ne[0],
+                                              model.layers[il].ffn[0]->ne[1], model.layers[il].ffn[1]->ne[1])) {
      cur = ne_ffn_silu(ctx0, model.layers[il].ffn[0], model.layers[il].ffn[1], model.layers[il].ffn[2], cur);
    } else {
      struct ne_tensor* tmp = ne_mul_mat(ctx0, model.layers[il].ffn[2], cur);
@@ -417,7 +418,7 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp
       cur = ne_mul(ctx0, cur, tmp);
       cur = ne_mul_mat(ctx0, model.layers[il].ffn[1], cur);
     }
-#ifdef NE_TP_MODEL
+#ifdef NS_TP_MODEL
     // ffn2 and ffn0 use split row, ffn1 use split column
     if (enable_tp) {
       cur = ne_all_reduce(ctx0, cur);
@@ -434,7 +435,7 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp
   lctx.use_buf(ctx0, 0);
 
   // used at the end to optionally extract the embeddings
-  struct ne_tensor* embeddings = NULL;
+  struct ne_tensor* embeddings = nullptr;
   // norm
   {
     inpL = ne_rms_norm(ctx0, inpL, hparams.rms_norm_eps);
@@ -457,12 +458,9 @@ static bool llama_model_eval_internal(model_context* ctx, const model_input* inp
   ne_build_forward_expand(&gf, inpL);
   ne_graph_compute(ctx0, &gf);
 
-#ifdef NE_PERF
-  bool engine_profiling_ = (getenv("ENGINE_PROFILING") != NULL);
-  if (engine_profiling_) {
+  if (ns_log_level() == 0 || ns_log_level() == 2) {
     ne_graph_profiling(&gf);
   }
-#endif
 
   // update kv token count
   lctx.model.kv_self.n = n_cached;
diff --git a/neural_speed/models/llama/llama_yarn_utils.cpp b/neural_speed/models/llama/llama_yarn_utils.cpp
index 80d09fb84..2bd131e29 100644
--- a/neural_speed/models/llama/llama_yarn_utils.cpp
+++ b/neural_speed/models/llama/llama_yarn_utils.cpp
@@ -35,7 +35,7 @@
 #include "models/model_utils/model_config.h"
 #include "models/model_utils/model_files.h"
 #include "models/model_utils/model_types.h"
-#include "models/model_utils/model_utils.h"
+#include "models/model_utils/quant_utils.h"
 #include "models/model_utils/util.h"
 #include "models/models.h"
 
@@ -47,7 +47,7 @@ void model_load_internal(const std::string& fname, model_archs arch, model_conte
   ms->load(ctx, progress_callback, progress_callback_user_data);
 
   model_context& lctx = *ctx;
-  lctx.support_jblas_kv = true;
+  lctx.support_bestla_kv = true;
 }
 
 void Llama::init(const char* path_model, model_context* ctx, int n_gpu_layer_, bool use_mmap_, bool use_mlock_,
@@ -65,6 +65,7 @@ void Llama::init(const char* path_model, model_context* ctx, int n_gpu_layer_, b
   auto& hparams = model.hparams;
   n_ff = hparams.n_mult;
   fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
+  fprintf(stderr, "%s: n_ctx   = %u\n", __func__, hparams.max_seq_len);
   fprintf(stderr, "%s: n_embd  = %u\n", __func__, hparams.n_embd);
   fprintf(stderr, "%s: n_mult  = %u\n", __func__, hparams.n_mult);
   fprintf(stderr, "%s: n_head  = %u\n", __func__, hparams.n_head);
@@ -112,46 +113,46 @@ void Llama::load(model_context* ctx, model_progress_callback progress_callback,
   ml->ne_ctx = ne_ctx;
 
-  model.others[0] = ml->get_tensor("model.embed_tokens.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
-  model.others[1] = ml->get_tensor("model.norm.weight", {n_embd}, NE_BACKEND_CPU);
-  model.others[2] = ml->get_tensor("lm_head.weight", {n_embd, n_vocab},
-                                   n_gpu_layer > static_cast<int>(n_layer) ? MODEL_BACKEND_OFFLOAD : NE_BACKEND_CPU);
-
   const int i_gpu_start = n_layer - n_gpu_layer;
-
   model.layers.resize(n_layer);
   size_t vram_total = 0;
-  for (uint32_t i = 0; i < n_layer; ++i) {
-    const ne_backend backend = static_cast<int>(i) < i_gpu_start ? NE_BACKEND_CPU : MODEL_BACKEND_OFFLOAD;
-    auto& layer = model.layers[i];
-    std::string layers_i = "model.layers." + std::to_string(i);
-
-    // attention norm
-    layer.norm[0] = ml->get_tensor(layers_i + ".input_layernorm.weight", {n_embd}, backend);
-
-    // qkv GEMM
-    layer.attn[0] = ml->get_tensor(layers_i + ".self_attn.q_proj.weight", {n_embd, n_embd}, backend);
-    layer.attn[1] = ml->get_tensor(layers_i + ".self_attn.q_proj.bias", {n_embd}, backend);
-    layer.attn[2] = ml->get_tensor(layers_i + ".self_attn.k_proj.weight", {n_embd, n_embd / (n_head / n_head_kv)}, backend);
-    layer.attn[3] = ml->get_tensor(layers_i + ".self_attn.k_proj.bias", {n_embd}, backend);
-    layer.attn[4] = ml->get_tensor(layers_i + ".self_attn.v_proj.weight", {n_embd, n_embd / (n_head / n_head_kv)}, backend);
-    layer.attn[5] = ml->get_tensor(layers_i + ".self_attn.v_proj.bias", {n_embd}, backend);
-    layer.attn[6] = ml->get_tensor(layers_i + ".self_attn.o_proj.weight", {n_embd, n_embd}, backend);
-    layer.attn[7] = ml->get_tensor(layers_i + ".self_attn.o_proj.bias", {n_embd}, backend);
-
-
-    // ffn norm
-    layer.norm[1] = ml->get_tensor(layers_i + ".post_attention_layernorm.weight", {n_embd}, backend);
-
-    // ffn GEMM
-    layer.ffn[0] = ml->get_tensor(layers_i + ".mlp.gate_proj.weight", {n_embd, n_ff}, backend);
-    layer.ffn[1] = ml->get_tensor(layers_i + ".mlp.down_proj.weight", {n_ff, n_embd}, backend);
-    layer.ffn[2] = ml->get_tensor(layers_i + ".mlp.up_proj.weight", {n_embd, n_ff}, backend);
-
-    if (backend != NE_BACKEND_CPU) {
-      vram_total += ne_nbytes(layer.norm[0]) + ne_nbytes(layer.attn[0]) + ne_nbytes(layer.attn[1]) +
-                    ne_nbytes(layer.attn[2]) + ne_nbytes(layer.attn[3]) + ne_nbytes(layer.norm[1]) +
-                    ne_nbytes(layer.ffn[0]) + ne_nbytes(layer.ffn[1]) + ne_nbytes(layer.ffn[2]);
+  if (1) {
+    model.others[0] = ml->get_tensor("model.embed_tokens.weight", {n_embd, n_vocab}, NE_BACKEND_CPU);
+    model.others[1] = ml->get_tensor("model.norm.weight", {n_embd}, NE_BACKEND_CPU);
+    model.others[2] = ml->get_tensor("lm_head.weight", {n_embd, n_vocab},
+                                     n_gpu_layer > static_cast<int>(n_layer) ? MODEL_BACKEND_OFFLOAD : NE_BACKEND_CPU);
+
+    for (uint32_t i = 0; i < n_layer; ++i) {
+      const ne_backend backend = static_cast<int>(i) < i_gpu_start ? NE_BACKEND_CPU : MODEL_BACKEND_OFFLOAD;
+      auto& layer = model.layers[i];
+      std::string layers_i = "model.layers." + std::to_string(i);
+
+      // attention norm
+      layer.norm[0] = ml->get_tensor(layers_i + ".input_layernorm.weight", {n_embd}, backend);
+
+      // qkv GEMM
+      layer.attn[0] = ml->get_tensor(layers_i + ".self_attn.q_proj.weight", {n_embd, n_embd}, backend);
+      layer.attn[1] = ml->get_tensor(layers_i + ".self_attn.q_proj.bias", {n_embd}, backend);
+      layer.attn[2] = ml->get_tensor(layers_i + ".self_attn.k_proj.weight", {n_embd, n_embd / (n_head / n_head_kv)}, backend);
+      layer.attn[3] = ml->get_tensor(layers_i + ".self_attn.k_proj.bias", {n_embd}, backend);
+      layer.attn[4] = ml->get_tensor(layers_i + ".self_attn.v_proj.weight", {n_embd, n_embd / (n_head / n_head_kv)}, backend);
+      layer.attn[5] = ml->get_tensor(layers_i + ".self_attn.v_proj.bias", {n_embd}, backend);
+      layer.attn[6] = ml->get_tensor(layers_i + ".self_attn.o_proj.weight", {n_embd, n_embd}, backend);
+      layer.attn[7] = ml->get_tensor(layers_i + ".self_attn.o_proj.bias", {n_embd}, backend);
+
+      // ffn norm
+      layer.norm[1] = ml->get_tensor(layers_i + ".post_attention_layernorm.weight", {n_embd}, backend);
+
+      // ffn GEMM
+      layer.ffn[0] = ml->get_tensor(layers_i + ".mlp.gate_proj.weight", {n_embd, n_ff}, backend);
+      layer.ffn[1] = ml->get_tensor(layers_i + ".mlp.down_proj.weight", {n_ff, n_embd}, backend);
+      layer.ffn[2] = ml->get_tensor(layers_i + ".mlp.up_proj.weight", {n_embd, n_ff}, backend);
+
+      if (backend != NE_BACKEND_CPU) {
+        vram_total += ne_nbytes(layer.norm[0]) + ne_nbytes(layer.attn[0]) + ne_nbytes(layer.attn[1]) +
+                      ne_nbytes(layer.attn[2]) + ne_nbytes(layer.attn[3]) + ne_nbytes(layer.norm[1]) +
+                      ne_nbytes(layer.ffn[0]) + ne_nbytes(layer.ffn[1]) + ne_nbytes(layer.ffn[2]);
+      }
     }
   }
 
@@ -168,7 +169,7 @@ void Llama::load(model_context* ctx, model_progress_callback progress_callback,
     model.tensors_by_name.emplace_back(lt.name, lt.ne_tensor);
   }
 
-  ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL);
+  ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : nullptr);
 
   if (progress_callback) {
     progress_callback(1.0f, progress_callback_user_data);
@@ -182,8 +183,9 @@ void Llama::load(model_context* ctx, model_progress_callback progress_callback,
 class llama_quant_layer : public quant_layer_base {
  public:
   quant_params_internal get_layer_config(std::string layername, std::vector<int64_t> ne, ne_type type) override {
-    bool quantize = layername.rfind("weight") == layername.size() - 6;  // ends with 'weight'?
-    if (layername.find("embed_tokens") != std::string::npos) {
+    bool quantize = layername.rfind("weight") == layername.size() - 6;
+    if ((layername.find("embed_tokens") != std::string::npos) ||
+        (layername == "token_embd.weight" || layername == "tok_embeddings.weight")) {
       // special layer process, can be loaded by config file
       return quant_params_internal();  // return q4_0 to cover the usage of getrow
     }
diff --git a/neural_speed/models/model_utils/model_types.h b/neural_speed/models/model_utils/model_types.h
index d438dac33..191853916 100644
--- a/neural_speed/models/model_utils/model_types.h
+++ b/neural_speed/models/model_utils/model_types.h
@@ -469,7 +469,8 @@ class model_name_to_arch {
       {"dolly", MODEL_GPTNEOX},   {"polyglot", MODEL_GPTNEOX},  {"starcoder", MODEL_STARCODER},
       {"falcon", MODEL_FALCON},   {"bloom", MODEL_BLOOM},       {"chatglm2", MODEL_CHATGLM2},
       {"chatglm", MODEL_CHATGLM}, {"baichuan", MODEL_BAICHUAN}, {"mistral", MODEL_LLAMA},
-      {"qwen", MODEL_QWEN},       {"phi", MODEL_PHI},           {"whisper", MODEL_WHISPER}};
+      {"qwen", MODEL_QWEN},       {"phi", MODEL_PHI},           {"whisper", MODEL_WHISPER},
+      {"llama_yarn", MODEL_LLAMA} };
 };
 
 #ifdef __cplusplus
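
For reference, the dispatch rule this patch adds on the Python side (`get_model_type` in `neural_speed/__init__.py` and `convert_model` in `neural_speed/convert/__init__.py`) boils down to one check on the Hugging Face config: a `llama` checkpoint whose `rope_scaling` entry is of type `yarn` is routed to the new `llama_yarn` backend, while everything else keeps its previous mapping. A minimal standalone sketch of that check (it skips the `model_maps` aliasing step, and the checkpoint id below is only an example of a YaRN-scaled Llama):

```python
from transformers import AutoConfig

def resolve_model_type(model_name_or_path: str) -> str:
    # Simplified mirror of the rope_scaling check added in this patch.
    config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
    model_type = config.model_type
    if model_type == "llama" and getattr(config, "rope_scaling", None) is not None:
        if config.rope_scaling.get("type") == "yarn":
            model_type = "llama_yarn"  # dispatches to llama_yarn_cpp / convert_llama_yarn.py
    return model_type

# A YaRN-extended Llama-2 config resolves to "llama_yarn";
# a vanilla Llama-2 config still resolves to "llama".
print(resolve_model_type("NousResearch/Yarn-Llama-2-7b-64k"))  # example checkpoint id
```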
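End to end, nothing changes for callers: `convert_model` picks `convert_llama_yarn.py`, CMake builds `quant_llama_yarn`/`run_llama_yarn`, and the Python wrapper imports `llama_yarn_cpp` (MODEL_NAME_ID 18). A rough usage sketch through the existing high-level API follows; only `use_quant` appears in this patch, so the tokenizer plumbing and the `generate` arguments are assumptions based on the project's usual examples rather than part of this change:

```python
from transformers import AutoTokenizer
from neural_speed import Model

model_name = "NousResearch/Yarn-Llama-2-7b-64k"  # example YaRN Llama checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
inputs = tokenizer("Once upon a time", return_tensors="pt").input_ids

model = Model()
# get_model_type() now maps this config to "llama_yarn", so init() imports
# neural_speed.llama_yarn_cpp and runs the llama_yarn convert/quantize path.
model.init(model_name, use_quant=True)
outputs = model.generate(inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0]))
```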