Skip to content

Commit

Permalink
更改Jinja elif逻辑,修复Qwen2.5多轮;支持Jinja负数下标,支持Qwen1.5最初版本
Browse files Browse the repository at this point in the history
  • Loading branch information
cgli committed Dec 15, 2024
1 parent ee2b970 commit 95094bd
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 12 deletions.
4 changes: 2 additions & 2 deletions docs/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
* **离线转换** (convert offline)
将原始模型转换为.flm格式的模型,一些[模型](#flm模型库)已经转换好。

* **加载后转换(不推荐)**
* **加载后转换(不推荐)** (convert on-the-fly (NOT RECOMMENDED))
将原始模型加载为HuggingFace模型,再通过`from_hf()`方法,转换并加速,这种方法内存占用大且速度慢,目前不再推荐。

## 支持模型一览 Model List
Expand Down Expand Up @@ -74,7 +74,7 @@
| Qwen/Qwen2.5-32B-Instruct ||||
| Qwen/Qwen2.5-72B-Instruct | |||

> 注3: 需要更新,检查 `tokenizer_config.json` 是否为最新版本
> 注3: ~~需要更新,检查 `tokenizer_config.json` 是否为最新版本~~
### DeepSeek系列

Expand Down
45 changes: 35 additions & 10 deletions src/template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ namespace fastllm {
bool JinjaVar::BoolValue() const {
if (this->type == JinjaInt) {
return (this->intValue != 0);
} else if (this->stringValue == "true") {
return true;
} else if (this->stringValue == "false") {
return false;
} else if (this->type == JinjaString) {
return true;
} else if (this->type == JinjaArray) {
return !this->arrayValue.empty();
}
Expand All @@ -22,8 +22,11 @@ namespace fastllm {
JinjaVar& JinjaVar::operator[] (const JinjaVar &b) {
if (this->type == JinjaArray) {
AssertInFastLLM(b.type == JinjaInt, "Jinja Error: subscript for array should be integer.");
AssertInFastLLM(b.intValue < this->arrayValue.size(), "Jinja error: subscript out of range.");
return this->arrayValue[b.intValue];
long long value = b.intValue;
if (value < 0)
value = b.intValue + this->arrayValue.size();
AssertInFastLLM(value < this->arrayValue.size(), "Jinja error: subscript out of range.");
return this->arrayValue[value];
} else if (this->type == JinjaDict) {
return this->dictValue[b.DirectValue()];
} else {
Expand Down Expand Up @@ -452,8 +455,29 @@ namespace fastllm {
}
vars.pop_back();
vars.push_back(a.type == JinjaVar::JinjaNone ? JinjaVar(1) : JinjaVar(!a.BoolValue()));
} else if (it.type == JinjaToken::JinjaTokenSub) {
AssertInFastLLM(vars.size() > 0, "Jinja Error: expression '-' error.");
JinjaVar a = vars.back();
if (a.type == JinjaVar::JinjaNone)
a = local[a];
AssertInFastLLM(a.type == JinjaVar::JinjaInt || a.type == JinjaVar::JinjaFloat, "Jinja Error: expression '-' error.");
if (vars.size() > 1) {
JinjaVar b = vars[vars.size() - 2];
if (b.type == JinjaVar::JinjaNone)
b = local[b];
if (b.type == JinjaVar::JinjaInt || b.type == JinjaVar::JinjaFloat) {
vars.pop_back();
vars.pop_back();
vars.push_back(JinjaBinaryOp(a, b, it.type));
continue;
}
}
vars.pop_back();
if (a.type == JinjaVar::JinjaInt)
vars.push_back(JinjaVar(-a.intValue));
else
vars.push_back(JinjaVar(-a.floatValue));
} else if (it.type == JinjaToken::JinjaTokenAdd ||
it.type == JinjaToken::JinjaTokenSub ||
it.type == JinjaToken::JinjaTokenMul ||
it.type == JinjaToken::JinjaTokenDiv ||
it.type == JinjaToken::JinjaTokenMod ||
Expand Down Expand Up @@ -531,14 +555,15 @@ namespace fastllm {
std::string iterId = curBlock.tokens[1].value;
JinjaVar exp = ComputeExpression(var, curBlock.tokens, 3, curBlock.tokens.size());
JinjaVar original = var[iterId];
var["loop"] = {{"index", 1}, {"index0", 0}, {"first", 1}};
var["loop"] = {{"index", 1}, {"index0", 0}, {"first", 1}, {"last", 0}};
if (exp.type == JinjaVar::JinjaArray) {
for (auto &it : exp.arrayValue) {
var[iterId] = it;
Parse(i + 1, endPos, var, ret);
var["loop"]["index"].intValue++;
var["loop"]["index0"].intValue++;
var["loop"]["first"].intValue = 0;
var["loop"]["last"].intValue = (var["loop"]["index"].intValue == exp.arrayValue.size());
}
} else if (exp.type == JinjaVar::JinjaDict) {
for (auto &it : exp.dictValue) {
Expand All @@ -547,6 +572,7 @@ namespace fastllm {
var["loop"]["index"].intValue++;
var["loop"]["index0"].intValue++;
var["loop"]["first"].intValue = 0;
var["loop"]["last"].intValue = (var["loop"]["index"].intValue == exp.arrayValue.size());
}
} else {
ErrorInFastLLM(exp.Dump() + " is not iterable");
Expand All @@ -561,13 +587,12 @@ namespace fastllm {
if (blocks[j].type == JinjaBlock::JinjaBlockType::JinjaBlockIf) {
cnt++;
} else if (blocks[j].type == JinjaBlock::JinjaBlockType::JinjaBlockElse) {
if (cnt == 0) {
if (cnt == 0 && elsePos == -1) {
elsePos = j;
}
} else if (blocks[j].type == JinjaBlock::JinjaBlockType::JinjaBlockElseIf) {
if (cnt == 0) {
endPos = j;
break;
if (cnt == 0 && elsePos == -1) {
elsePos = j;
}
} else if (blocks[j].type == JinjaBlock::JinjaBlockType::JinjaBlockEndIf) {
if ((cnt--) == 0) {
Expand Down

0 comments on commit 95094bd

Please sign in to comment.