diff --git a/converter/convert-tokenizer-llama2.py b/converter/convert-tokenizer-llama2.py
new file mode 100644
index 0000000..9856dfb
--- /dev/null
+++ b/converter/convert-tokenizer-llama2.py
@@ -0,0 +1,42 @@
+import sys
+import os
+from sentencepiece import SentencePieceProcessor
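+# tokenizer-writer.py has a dash in its filename, so it cannot be loaded with a regular import statement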
+writer = __import__('tokenizer-writer')
+
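+# Llama 2 chat template (Jinja); a leading system message is folded into the first [INST] block between <<SYS>> markers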
+chatTemplate = "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"
+
+def printUsage():
+ print('Usage: python convert-tokenizer-llama2.py <llama2FolderPath>')
+ print()
+ print('Options:')
+ print(' <llama2FolderPath> The path to the folder containing the Llama 2 tokenizer.model file')
+
+if __name__ == '__main__':
+ if (len(sys.argv) < 2):
+ printUsage()
+ exit(1)
+
+ dirPath = sys.argv[1]
+ modelPath = os.path.join(dirPath, 'tokenizer.model')
+ processor = SentencePieceProcessor(model_file=modelPath)
+
+ vocabSize = processor.vocab_size()
+ tokens = []
+ scores = []
+ for i in range(vocabSize):
+ t = processor.id_to_piece(i)
+ s = processor.get_score(i)
+ t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace
+ b = t.encode('utf-8')
+ tokens.append(b)
+ scores.append(s)
+
+ outputFileName = 'dllama_tokenizer_llama2.t'
+ with open(outputFileName, 'wb') as outputFile:
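+ # Llama 2 has no dedicated chat end-of-turn token, so chat_eos_id reuses the regular eos_id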
+ writer.writeTokenizer(outputFile, {
+ 'bos_id': processor.bos_id(),
+ 'eos_id': processor.eos_id(),
+ 'chat_eos_id': processor.eos_id(),
+ }, tokens, scores, chatTemplate.encode('utf-8'), None)
+
+ print(f'✅ Created {outputFileName}')
diff --git a/src/app.cpp b/src/app.cpp
index 517e325..018ca2b 100644
--- a/src/app.cpp
+++ b/src/app.cpp
@@ -16,6 +16,15 @@ FloatType parseFloatType(char* val) {
exit(EXIT_FAILURE);
}
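+// Maps the --chat-template argument to a ChatTemplateType; used to override the template detected from the tokenizer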
+ChatTemplateType parseChatTemplateType(char* val) {
+ if (strcmp(val, "llama2") == 0) return TEMPLATE_LLAMA2;
+ if (strcmp(val, "llama3") == 0) return TEMPLATE_LLAMA3;
+ if (strcmp(val, "zephyr") == 0) return TEMPLATE_ZEPHYR;
+ if (strcmp(val, "chatml") == 0) return TEMPLATE_CHATML;
+ throw std::runtime_error("Invalid chat template type");
+
+}
+
AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
AppArgs args;
args.mode = NULL;
@@ -31,6 +40,7 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
args.topp = 0.9f;
args.steps = 0;
args.seed = (unsigned long long)time(NULL);
+ args.chatTemplateType = TEMPLATE_UNKNOWN;
int i = 1;
if (hasMode && argc > 1) {
@@ -84,6 +94,8 @@ AppArgs AppArgs::parse(int argc, char** argv, bool hasMode) {
args.topp = atof(argv[i + 1]);
} else if (strcmp(argv[i], "--seed") == 0) {
args.seed = atoll(argv[i + 1]);
+ } else if (strcmp(argv[i], "--chat-template") == 0) {
+ args.chatTemplateType = parseChatTemplateType(argv[i + 1]);
} else {
printf("Unknown option %s\n", argv[i]);
exit(EXIT_FAILURE);
diff --git a/src/app.hpp b/src/app.hpp
index af717d6..d53e9e7 100644
--- a/src/app.hpp
+++ b/src/app.hpp
@@ -32,6 +32,7 @@ class AppArgs {
pos_t steps;
bool benchmark;
unsigned long long seed;
+ ChatTemplateType chatTemplateType;
// worker
int port;
diff --git a/src/apps/dllama-api/dllama-api.cpp b/src/apps/dllama-api/dllama-api.cpp
index 002206f..b13b6d6 100644
--- a/src/apps/dllama-api/dllama-api.cpp
+++ b/src/apps/dllama-api/dllama-api.cpp
@@ -396,7 +396,7 @@ void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer,
SocketServer* server = new SocketServer(args->port);
TokenizerChatStops stops(tokenizer);
- ChatTemplate chatTemplate(tokenizer->chatTemplate, stops.stops[0]);
+ ChatTemplate chatTemplate(args->chatTemplateType, tokenizer->chatTemplate, stops.stops[0]);
EosDetector eosDetector(tokenizer->chatEosId, stops.nStops, stops.stops, stops.maxStopLength, stops.maxStopLength);
ApiServer api(inference, tokenizer, sampler, args, spec, &eosDetector, &chatTemplate);
diff --git a/src/apps/dllama/dllama.cpp b/src/apps/dllama/dllama.cpp
index 69d7814..f28c123 100644
--- a/src/apps/dllama/dllama.cpp
+++ b/src/apps/dllama/dllama.cpp
@@ -195,7 +195,7 @@ class Chat {
void chat(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) {
TokenizerChatStops stops(tokenizer);
- ChatTemplate chatTemplate(tokenizer->chatTemplate, stops.stops[0]);
+ ChatTemplate chatTemplate(args->chatTemplateType, tokenizer->chatTemplate, stops.stops[0]);
EosDetector eosDetector(tokenizer->chatEosId, stops.nStops, stops.stops, stops.maxStopLength, stops.maxStopLength);
Chat chat(inference, tokenizer, sampler, args, spec, &eosDetector, &chatTemplate);
diff --git a/src/tokenizer-test.cpp b/src/tokenizer-test.cpp
index b4809d4..df56d92 100644
--- a/src/tokenizer-test.cpp
+++ b/src/tokenizer-test.cpp
@@ -12,13 +12,13 @@
#define EOS_ID 10000
void testChatTemplate() {
- ChatTemplate t0("{\% set loop_messages = messages \%}{\% for message in loop_messages \%}{\% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' \%}{\% if loop.index0 == 0 \%}{\% set content = bos_token + content \%}{\% endif \%}{{ content }}{\% endfor \%}{\% if add_generation_prompt \%}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{\% endif \%}", "");
+ ChatTemplate t0(TEMPLATE_UNKNOWN, "{\% set loop_messages = messages \%}{\% for message in loop_messages \%}{\% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' \%}{\% if loop.index0 == 0 \%}{\% set content = bos_token + content \%}{\% endif \%}{{ content }}{\% endfor \%}{\% if add_generation_prompt \%}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{\% endif \%}", "");
assert(t0.type == TEMPLATE_LLAMA3);
- ChatTemplate t1("{{bos_token}}{\% for message in messages \%}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{\% endfor \%}{\% if add_generation_prompt \%}{{ '<|im_start|>assistant\n' }}{\% endif \%}", "");
+ ChatTemplate t1(TEMPLATE_UNKNOWN, "{{bos_token}}{\% for message in messages \%}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{\% endfor \%}{\% if add_generation_prompt \%}{{ '<|im_start|>assistant\n' }}{\% endif \%}", "");
assert(t1.type == TEMPLATE_CHATML);
- ChatTemplate t2("{\% for message in messages \%}\n{\% if message['role'] == 'user' \%}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{\% elif message['role'] == 'system' \%}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{\% elif message['role'] == 'assistant' \%}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{\% endif \%}\n{\% if loop.last and add_generation_prompt \%}\n{{ '<|assistant|>' }}\n{\% endif \%}\n{\% endfor \%}", "");
+ ChatTemplate t2(TEMPLATE_UNKNOWN, "{\% for message in messages \%}\n{\% if message['role'] == 'user' \%}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{\% elif message['role'] == 'system' \%}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{\% elif message['role'] == 'assistant' \%}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{\% endif \%}\n{\% if loop.last and add_generation_prompt \%}\n{{ '<|assistant|>' }}\n{\% endif \%}\n{\% endfor \%}", "");
assert(t2.type == TEMPLATE_ZEPHYR);
printf("✅ ChatTemplate\n");
diff --git a/src/tokenizer.cpp b/src/tokenizer.cpp
index 74d8ab5..f81f1a7 100644
--- a/src/tokenizer.cpp
+++ b/src/tokenizer.cpp
@@ -433,27 +433,54 @@ TokenizerChatStops::~TokenizerChatStops() {
delete[] stops;
}
-ChatTemplate::ChatTemplate(const char* chatTemplate, const char* eos) {
- if (chatTemplate == NULL)
- throw std::runtime_error("The tokenizer does not include chat template");
+ChatTemplate::ChatTemplate(const ChatTemplateType type, const char* chatTemplate, const char* eos) {
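+ // A type forced via --chat-template takes precedence; otherwise the template is detected from marker strings in the tokenizer's chat template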
+ if (type == TEMPLATE_UNKNOWN) {
+ if (chatTemplate == NULL)
+ throw std::runtime_error("The tokenizer does not include chat template");
+ if (strstr(chatTemplate, "[INST]") != NULL) {
+ this->type = TEMPLATE_LLAMA2;
+ } else if (strstr(chatTemplate, "<|start_header_id|>") != NULL) {
+ this->type = TEMPLATE_LLAMA3;
+ } else if (strstr(chatTemplate, "<|user|>") != NULL) {
+ this->type = TEMPLATE_ZEPHYR;
+ } else if (strstr(chatTemplate, "<|im_start|>") != NULL) {
+ this->type = TEMPLATE_CHATML;
+ } else {
+ throw std::runtime_error("Unsupported chat template");
+ }
+ } else {
+ this->type = type;
+ }
+ this->eos = eos;
printf("⭐ chat template: ");
- if (strstr(chatTemplate, "<|start_header_id|>") != NULL) {
- type = TEMPLATE_LLAMA3;
+ if (this->type == TEMPLATE_LLAMA2) {
+ printf("llama2\n");
+ } else if (this->type == TEMPLATE_LLAMA3) {
printf("llama3\n");
- } else if (strstr(chatTemplate, "<|user|>") != NULL) {
- type = TEMPLATE_ZEPHYR;
+ } else if (this->type == TEMPLATE_ZEPHYR) {
printf("zephyr\n");
- } else if (strstr(chatTemplate, "<|im_start|>") != NULL) {
- type = TEMPLATE_CHATML;
+ } else if (this->type == TEMPLATE_CHATML) {
printf("chatml\n");
- } else throw new std::runtime_error("Not supported chat template");
- this->eos = eos;
+ }
}
std::string ChatTemplate::generate(unsigned int nMessages, ChatItem* items, bool appendGenerationPrompt) {
std::ostringstream buffer;
- if (type == TEMPLATE_LLAMA3) {
+ if (type == TEMPLATE_LLAMA2) {
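+ // Llama 2 prompt format: an optional leading system message is wrapped in <<SYS>> markers inside the first [INST] block,
+ // e.g. (hypothetical messages): [INST] <<SYS>>\nYou are helpful\n<</SYS>>\n\nHi [/INST]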
+ unsigned int i = 0;
+ if (nMessages >= 2 && items[0].role == "system" && items[1].role == "user") {
+ buffer << "[INST] <>\n" << items[0].message << "\n<>\n\n" << items[1].message << " [/INST]" << eos;
+ i += 2;
+ }
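+ // Remaining turns: user messages are wrapped in [INST] ... [/INST], assistant replies are emitted as-is; both are followed by the EOS string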
+ for (; i < nMessages; i++) {
+ if (items[i].role == "assistant") {
+ buffer << items[i].message << eos;
+ } else if (items[i].role == "user") {
+ buffer << "[INST] " << items[i].message << " [/INST]" << eos;
+ }
+ }
+ } else if (type == TEMPLATE_LLAMA3) {
for (unsigned int i = 0; i < nMessages; i++)
buffer << "<|start_header_id|>" << items[i].role << "<|end_header_id|>\n\n" << items[i].message << eos;
if (appendGenerationPrompt)
diff --git a/src/tokenizer.hpp b/src/tokenizer.hpp
index 5ad7b11..5323eae 100644
--- a/src/tokenizer.hpp
+++ b/src/tokenizer.hpp
@@ -89,9 +89,11 @@ class TokenizerChatStops {
};
enum ChatTemplateType {
- TEMPLATE_LLAMA3 = 0,
- TEMPLATE_ZEPHYR = 1,
- TEMPLATE_CHATML = 2,
+ TEMPLATE_UNKNOWN = 0,
+ TEMPLATE_LLAMA2 = 1,
+ TEMPLATE_LLAMA3 = 2,
+ TEMPLATE_ZEPHYR = 3,
+ TEMPLATE_CHATML = 4,
};
struct ChatItem {
@@ -103,7 +105,7 @@ class ChatTemplate {
public:
const char* eos;
ChatTemplateType type;
- ChatTemplate(const char* chatTemplate, const char* eos);
+ ChatTemplate(const ChatTemplateType type, const char* chatTemplate, const char* eos);
std::string generate(unsigned int nMessages, ChatItem* items, bool appendGenerationPrompt);
};