diff --git a/README.md b/README.md index e3eae60bb015b..08661831526fa 100644 --- a/README.md +++ b/README.md @@ -122,6 +122,7 @@ Typically finetunes of the base models below are supported as well. - [x] [Command-R](https://huggingface.co/CohereForAI/c4ai-command-r-v01) - [x] [SEA-LION](https://huggingface.co/models?search=sea-lion) - [x] [GritLM-7B](https://huggingface.co/GritLM/GritLM-7B) + [GritLM-8x7B](https://huggingface.co/GritLM/GritLM-8x7B) +- [x] [OLMo](https://allenai.org/olmo) (instructions for supporting more models: [HOWTO-add-model.md](./docs/HOWTO-add-model.md)) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index c14186abbc2a6..358dba8ed9d90 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2636,6 +2636,66 @@ def set_gguf_parameters(self): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) +@Model.register("OlmoForCausalLM") +@Model.register("OLMoForCausalLM") +class OlmoModel(Model): + model_arch = gguf.MODEL_ARCH.OLMO + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_layer_norm_eps(1e-5) + if "clip_qkv" in self.hparams is not None: + self.gguf_writer.add_clamp_kqv(self.hparams["clip_qkv"]) + + # Same as super class, but permuting q_proj, k_proj + # Copied from: LlamaModel + def write_tensors(self): + block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer"))) + tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count) + n_head = self.hparams.get("num_attention_heads") + n_kv_head = self.hparams.get("num_key_value_heads") + for name, data_torch in self.get_tensors(): + old_dtype = data_torch.dtype + + # convert any unsupported data types to float32 + if data_torch.dtype not in (torch.float16, torch.float32): + data_torch = data_torch.to(torch.float32) + + data = data_torch.numpy() + + if name.endswith("q_proj.weight"): + data = permute(data, n_head, n_head) + if name.endswith("k_proj.weight"): + data = permute(data, n_head, n_kv_head) + + data = data.squeeze() + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print(f"Can not map tensor {name!r}") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if self.ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # 1d tensors need to be converted to float32 + if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if self.ftype == 1 and data_dtype == np.float32 and n_dims == 2: + data = data.astype(np.float16) + + print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}") + + self.gguf_writer.add_tensor(new_name, data) + + ###### CONVERSION LOGIC ###### diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index feae03e109173..ba24065a891c7 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -135,6 +135,7 @@ class MODEL_ARCH(IntEnum): XVERSE = auto() COMMAND_R = auto() DBRX = auto() + OLMO = auto() class MODEL_TENSOR(IntEnum): @@ -210,6 +211,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.DBRX: "dbrx", + MODEL_ARCH.OLMO: "olmo", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -695,6 +697,17 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.OLMO: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], # TODO } diff --git a/llama.cpp b/llama.cpp index 18e473c095cd5..fa7c022f29130 100644 --- a/llama.cpp +++ b/llama.cpp @@ -222,6 +222,7 @@ enum llm_arch { LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_DBRX, + LLM_ARCH_OLMO, LLM_ARCH_UNKNOWN, }; @@ -256,6 +257,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_DBRX, "dbrx" }, + { LLM_ARCH_OLMO, "olmo" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -990,6 +992,20 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_OLMO, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -4070,6 +4086,18 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_OLMO: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); + + switch (hparams.n_layer) { + case 22: model.type = e_model::MODEL_1B; break; + case 32: model.type = e_model::MODEL_7B; break; + case 80: model.type = e_model::MODEL_70B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; default: (void)0; } @@ -5666,6 +5694,37 @@ static bool llm_load_tensors( layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + } + } break; + case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed + { + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false); + // if output is NULL, init from the input tok embed + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + ml.n_created--; // artificial tensor + ml.size_data += ggml_nbytes(model.output); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); @@ -10096,6 +10155,139 @@ struct llm_build_context { return gf; } + + // ref: https://allenai.org/olmo + // based on the original build_llama() function, changes: + // * non-parametric layer norm + // * clamp qkv + // * removed bias + // * removed MoE + struct ggml_cgraph * build_olmo() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); + + // mutable variable, needed during the last layer of the computation to skip unused tokens + int32_t n_tokens = this->n_tokens; + + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + struct ggml_tensor * inp_pos = build_inp_pos(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, + NULL, NULL, + LLM_NORM, cb, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (hparams.f_clamp_kqv > 0.0f) { + Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Qcur, "Qcur", il); + } + + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (hparams.f_clamp_kqv > 0.0f) { + Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Kcur, "Kcur", il); + } + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (hparams.f_clamp_kqv > 0.0f) { + Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur", il); + + Kcur = ggml_rope_custom( + ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, + n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Kcur, "Kcur", il); + + cur = llm_build_kv(ctx0, model, hparams, kv_self, gf, + model.layers[il].wo, nullptr, + Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + n_tokens = n_outputs; + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = llm_build_norm(ctx0, ffn_inp, hparams, + NULL, NULL, + LLM_NORM, cb, il); + cb(cur, "ffn_norm", il); + + cur = llm_build_ffn(ctx0, cur, + model.layers[il].ffn_up, NULL, + model.layers[il].ffn_gate, NULL, + model.layers[il].ffn_down, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + ggml_tensor * layer_dir = lctx.cvec.tensor_for(il); + if (layer_dir != nullptr) { + cur = ggml_add(ctx0, cur, layer_dir); + } + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, + NULL, NULL, + LLM_NORM, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } }; static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { @@ -10301,6 +10493,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_dbrx(); } break; + case LLM_ARCH_OLMO: + { + result = llm.build_olmo(); + } break; default: GGML_ASSERT(false); } @@ -15154,6 +15350,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_MINICPM: case LLM_ARCH_XVERSE: case LLM_ARCH_COMMAND_R: + case LLM_ARCH_OLMO: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2