From 95094bdbca006b03365b268ccc8768a17b1417d6 Mon Sep 17 00:00:00 2001 From: cgli Date: Sun, 15 Dec 2024 14:53:52 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=94=B9Jinja=20elif=E9=80=BB?= =?UTF-8?q?=E8=BE=91=EF=BC=8C=E4=BF=AE=E5=A4=8DQwen2.5=E5=A4=9A=E8=BD=AE?= =?UTF-8?q?=EF=BC=9B=E6=94=AF=E6=8C=81Jinja=E8=B4=9F=E6=95=B0=E4=B8=8B?= =?UTF-8?q?=E6=A0=87=EF=BC=8C=E6=94=AF=E6=8C=81Qwen1.5=E6=9C=80=E5=88=9D?= =?UTF-8?q?=E7=89=88=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/models.md | 4 ++-- src/template.cpp | 45 +++++++++++++++++++++++++++++++++++---------- 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/docs/models.md b/docs/models.md index 4052dcb..84d125e 100644 --- a/docs/models.md +++ b/docs/models.md @@ -10,7 +10,7 @@ * **离线转换** (convert offline) 将原始模型转换为.flm格式的模型,一些[模型](#flm模型库)已经转换好。 -* **加载后转换(不推荐)** +* **加载后转换(不推荐)** (convert on-the-fly (NOT RECOMMENDED)) 将原始模型加载为HuggingFace模型,再通过`from_hf()`方法,转换并加速,这种方法内存占用大且速度慢,目前不再推荐。 ## 支持模型一览 Model List @@ -74,7 +74,7 @@ | Qwen/Qwen2.5-32B-Instruct | √ | √ | ✔ | | Qwen/Qwen2.5-72B-Instruct | | √ | ✔ | -> 注3: 需要更新,检查 `tokenizer_config.json` 是否为最新版本 +> 注3: ~~需要更新,检查 `tokenizer_config.json` 是否为最新版本~~ ### DeepSeek系列 diff --git a/src/template.cpp b/src/template.cpp index ddacaa6..3fd8bc3 100644 --- a/src/template.cpp +++ b/src/template.cpp @@ -8,10 +8,10 @@ namespace fastllm { bool JinjaVar::BoolValue() const { if (this->type == JinjaInt) { return (this->intValue != 0); - } else if (this->stringValue == "true") { - return true; } else if (this->stringValue == "false") { return false; + } else if (this->type == JinjaString) { + return true; } else if (this->type == JinjaArray) { return !this->arrayValue.empty(); } @@ -22,8 +22,11 @@ namespace fastllm { JinjaVar& JinjaVar::operator[] (const JinjaVar &b) { if (this->type == JinjaArray) { AssertInFastLLM(b.type == JinjaInt, "Jinja Error: subscript for array should be integer."); - AssertInFastLLM(b.intValue < this->arrayValue.size(), "Jinja error: subscript out of range."); - return this->arrayValue[b.intValue]; + long long value = b.intValue; + if (value < 0) + value = b.intValue + this->arrayValue.size(); + AssertInFastLLM(value < this->arrayValue.size(), "Jinja error: subscript out of range."); + return this->arrayValue[value]; } else if (this->type == JinjaDict) { return this->dictValue[b.DirectValue()]; } else { @@ -452,8 +455,29 @@ namespace fastllm { } vars.pop_back(); vars.push_back(a.type == JinjaVar::JinjaNone ? JinjaVar(1) : JinjaVar(!a.BoolValue())); + } else if (it.type == JinjaToken::JinjaTokenSub) { + AssertInFastLLM(vars.size() > 0, "Jinja Error: expression '-' error."); + JinjaVar a = vars.back(); + if (a.type == JinjaVar::JinjaNone) + a = local[a]; + AssertInFastLLM(a.type == JinjaVar::JinjaInt || a.type == JinjaVar::JinjaFloat, "Jinja Error: expression '-' error."); + if (vars.size() > 1) { + JinjaVar b = vars[vars.size() - 2]; + if (b.type == JinjaVar::JinjaNone) + b = local[b]; + if (b.type == JinjaVar::JinjaInt || b.type == JinjaVar::JinjaFloat) { + vars.pop_back(); + vars.pop_back(); + vars.push_back(JinjaBinaryOp(a, b, it.type)); + continue; + } + } + vars.pop_back(); + if (a.type == JinjaVar::JinjaInt) + vars.push_back(JinjaVar(-a.intValue)); + else + vars.push_back(JinjaVar(-a.floatValue)); } else if (it.type == JinjaToken::JinjaTokenAdd || - it.type == JinjaToken::JinjaTokenSub || it.type == JinjaToken::JinjaTokenMul || it.type == JinjaToken::JinjaTokenDiv || it.type == JinjaToken::JinjaTokenMod || @@ -531,7 +555,7 @@ namespace fastllm { std::string iterId = curBlock.tokens[1].value; JinjaVar exp = ComputeExpression(var, curBlock.tokens, 3, curBlock.tokens.size()); JinjaVar original = var[iterId]; - var["loop"] = {{"index", 1}, {"index0", 0}, {"first", 1}}; + var["loop"] = {{"index", 1}, {"index0", 0}, {"first", 1}, {"last", 0}}; if (exp.type == JinjaVar::JinjaArray) { for (auto &it : exp.arrayValue) { var[iterId] = it; @@ -539,6 +563,7 @@ namespace fastllm { var["loop"]["index"].intValue++; var["loop"]["index0"].intValue++; var["loop"]["first"].intValue = 0; + var["loop"]["last"].intValue = (var["loop"]["index"].intValue == exp.arrayValue.size()); } } else if (exp.type == JinjaVar::JinjaDict) { for (auto &it : exp.dictValue) { @@ -547,6 +572,7 @@ namespace fastllm { var["loop"]["index"].intValue++; var["loop"]["index0"].intValue++; var["loop"]["first"].intValue = 0; + var["loop"]["last"].intValue = (var["loop"]["index"].intValue == exp.arrayValue.size()); } } else { ErrorInFastLLM(exp.Dump() + " is not iterable"); @@ -561,13 +587,12 @@ namespace fastllm { if (blocks[j].type == JinjaBlock::JinjaBlockType::JinjaBlockIf) { cnt++; } else if (blocks[j].type == JinjaBlock::JinjaBlockType::JinjaBlockElse) { - if (cnt == 0) { + if (cnt == 0 && elsePos == -1) { elsePos = j; } } else if (blocks[j].type == JinjaBlock::JinjaBlockType::JinjaBlockElseIf) { - if (cnt == 0) { - endPos = j; - break; + if (cnt == 0 && elsePos == -1) { + elsePos = j; } } else if (blocks[j].type == JinjaBlock::JinjaBlockType::JinjaBlockEndIf) { if ((cnt--) == 0) {