Support 01.AI (零一万物) Yi models
cgli authored and TylunasLi committed Mar 4, 2024
1 parent 2db84e3 commit 17fdbbf
Showing 3 changed files with 25 additions and 10 deletions.
11 changes: 11 additions & 0 deletions docs/llama_cookbook.md
@@ -225,6 +225,17 @@ The XVERSE-13B-Chat V1 release requires NFKC normalization of its input, which fastllm does not support yet
user_role="[|Human|]:", bot_role="\n[|AI|]:", history_sep="\n", dtype=dtype)
```

## Yi

* 01-ai/[Yi-6B-Chat](https://huggingface.co/01-ai/Yi-6B-Chat)

* 01-ai/[Yi-34B-Chat](https://huggingface.co/01-ai/Yi-34B-Chat)

```python
torch2flm.tofile(exportPath, model, tokenizer, pre_prompt="",
user_role="<|im_start|>user\n", bot_role="<|im_end|><|im_start|>assistant\n", history_sep="<|im_end|>\n", dtype=dtype)
```
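
For reference, a full conversion script built around this call could look like the sketch below. The model id, export path, and dtype are placeholders, and the `fastllm_pytools.torch2flm` import path is an assumption based on the export scripts shipped with fastllm.

```python
# Hypothetical end-to-end export for Yi-6B-Chat; paths and dtype are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from fastllm_pytools import torch2flm  # assumed import path

model_name = "01-ai/Yi-6B-Chat"
exportPath = "yi-6b-chat-fp16.flm"
dtype = "float16"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16,
                                             trust_remote_code=True).eval()

# Same chat-template arguments as documented above.
torch2flm.tofile(exportPath, model, tokenizer, pre_prompt="",
                 user_role="<|im_start|>user\n", bot_role="<|im_end|><|im_start|>assistant\n",
                 history_sep="<|im_end|>\n", dtype=dtype)
```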

### WizardCoder

* [WizardCoder-Python-7B-V1.0](https://huggingface.co/WizardLM/WizardCoder-Python-7B-V1.0)
2 changes: 2 additions & 0 deletions include/models/llama.h
@@ -86,6 +86,8 @@ namespace fastllm {
float rope_factor = 1.f;

int num_key_value_heads = num_attention_heads;

float rms_norm_eps = 1e-6;
};
}

22 changes: 12 additions & 10 deletions src/models/llama.cpp
@@ -69,6 +69,9 @@ namespace fastllm {
if (this->weight.dicts.find("max_position_embeddings") != this->weight.dicts.end()) {
max_positions = atoi(this->weight.dicts["max_position_embeddings"].c_str());
}
if (this->weight.dicts.find("rms_norm_eps") != this->weight.dicts.end()) {
rms_norm_eps = atof(this->weight.dicts["rms_norm_eps"].c_str());
}
if (this->weight.dicts.find("rope_scaling.type") != this->weight.dicts.end()) {
std::string type = this->weight.dicts["rope_scaling.type"];
if (type == "linear")
@@ -79,7 +82,6 @@
if (this->weight.dicts.find("rope_theta") != this->weight.dicts.end()) {
rope_base = atof(this->weight.dicts["rope_theta"].c_str());
}
float factor = 1.0f;
if (this->weight.dicts.find("rope_scaling.factor") != this->weight.dicts.end()) {
rope_factor = atof(this->weight.dicts["rope_scaling.factor"].c_str());
}
@@ -140,7 +142,7 @@
for (int i = 0; i < block_cnt; i++) {
ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".input_layernorm.weight"],
1e-6, attenInput);
rms_norm_eps, attenInput);
std::string qWeightName = "model.layers." + std::to_string(i) + ".self_attn.q_proj.weight";
std::string qBiasName = "model.layers." + std::to_string(i) + ".self_attn.q_proj.bias";
std::string kWeightName = "model.layers." + std::to_string(i) + ".self_attn.k_proj.weight";
@@ -253,7 +255,7 @@
Linear(attenOutput, weight[oWeightName], oBias, attenLastOutput);
AddTo(hiddenStates, attenLastOutput);
// 2. mlp
RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"], 1e-6, attenInput);
RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"], rms_norm_eps, attenInput);
Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.gate_proj.weight"], Data(), w1);
Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.up_proj.weight"], Data(), w3);
Silu(w1, w1);
@@ -274,7 +276,7 @@
int lastRet = -1;
{
auto &hiddenStates = *lastHiddenStates;
RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-6, hiddenStates);
RMSNorm(hiddenStates, weight["model.norm.weight"], rms_norm_eps, hiddenStates);
Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
if (generationConfig.output_logits && retLogits != nullptr) {
int size = logits.dims.back();
@@ -323,7 +325,7 @@
for (int i = 0; i < block_cnt; i++) {
ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".input_layernorm.weight"],
1e-6, attenInput);
rms_norm_eps, attenInput);
std::string qWeightName = "model.layers." + std::to_string(i) + ".self_attn.q_proj.weight";
std::string qBiasName = "model.layers." + std::to_string(i) + ".self_attn.q_proj.bias";
std::string kWeightName = "model.layers." + std::to_string(i) + ".self_attn.k_proj.weight";
@@ -439,7 +441,7 @@
Linear(attenOutput, weight[oWeightName], oBias, attenLastOutput);
AddTo(hiddenStates, attenLastOutput);
// 2. mlp
RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"], 1e-6, attenInput);
RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"], rms_norm_eps, attenInput);
Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.gate_proj.weight"], Data(), w1);
Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.up_proj.weight"], Data(), w3);
Silu(w1, w1);
@@ -461,7 +463,7 @@
std::vector <int> lastRet;
{
auto &hiddenStates = *lastHiddenStates;
RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-6, hiddenStates);
RMSNorm(hiddenStates, weight["model.norm.weight"], rms_norm_eps, hiddenStates);
Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
if (generationConfig.IsSimpleGreedy()) {
TopK(logits, topk, 1);
@@ -514,7 +516,7 @@
for (int i = 0; i < block_cnt; i++) {
ApplyDeviceMap(this->deviceMap, i + 1, block_cnt);
RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".input_layernorm.weight"],
1e-6, attenInput);
rms_norm_eps, attenInput);
std::string qWeightName = "model.layers." + std::to_string(i) + ".self_attn.q_proj.weight";
std::string qBiasName = "model.layers." + std::to_string(i) + ".self_attn.q_proj.bias";
std::string kWeightName = "model.layers." + std::to_string(i) + ".self_attn.k_proj.weight";
@@ -653,7 +655,7 @@
Linear(attenOutput, weight[oWeightName], oBias, attenLastOutput);
AddTo(hiddenStates, attenLastOutput);
// 2. mlp
RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"], 1e-6, attenInput);
RMSNorm(hiddenStates, this->weight["model.layers." + std::to_string(i) + ".post_attention_layernorm.weight"], rms_norm_eps, attenInput);
Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.gate_proj.weight"], Data(), w1);
Linear(attenInput, weight["model.layers." + std::to_string(i) + ".mlp.up_proj.weight"], Data(), w3);
Silu(w1, w1);
@@ -663,7 +665,7 @@
}

Data logits, curLogit;
RMSNorm(hiddenStates, weight["model.norm.weight"], 1e-6, hiddenStates);
RMSNorm(hiddenStates, weight["model.norm.weight"], rms_norm_eps, hiddenStates);
Linear(hiddenStates, weight["lm_head.weight"], Data(), logits);
std::vector <int> lastRet;
int total = 0;
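
The llama.cpp changes above replace the hard-coded 1e-6 epsilon with the `rms_norm_eps` value read from the converted weights, which matters for Yi checkpoints (their config typically sets `rms_norm_eps` to 1e-5). A quick way to smoke-test a converted model is the Python binding; this is a minimal sketch assuming `fastllm_pytools` is installed and exposes the `llm.model` / `response` API shown in the project README, with a placeholder `.flm` path.

```python
# Minimal smoke test of a converted Yi model; the .flm path is a placeholder.
from fastllm_pytools import llm  # assumed to expose llm.model / response as in the README

model = llm.model("yi-6b-chat-fp16.flm")
print(model.response("Hello, please introduce yourself."))
```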