diff --git a/examples/tokenizer-playground/src/worker.js b/examples/tokenizer-playground/src/worker.js
index e3739e572..4db09bdc0 100644
--- a/examples/tokenizer-playground/src/worker.js
+++ b/examples/tokenizer-playground/src/worker.js
@@ -22,6 +22,7 @@ self.addEventListener('message', async (event) => {
// NOTE: We just remove the StripDecoder from the llama tokenizer
switch (tokenizer.constructor.name) {
case 'LlamaTokenizer':
+ case 'Grok1Tokenizer':
// tokenizer.decoder.decoders.at(-1).constructor.name === 'StripDecoder'
tokenizer.decoder.decoders.pop();
break;
diff --git a/package-lock.json b/package-lock.json
index 6945ebd61..d117b733f 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -9,7 +9,7 @@
"version": "2.16.0",
"license": "Apache-2.0",
"dependencies": {
- "@huggingface/jinja": "^0.2.1",
+ "@huggingface/jinja": "^0.2.2",
"onnxruntime-web": "1.14.0",
"sharp": "^0.32.0"
},
@@ -745,9 +745,9 @@
}
},
"node_modules/@huggingface/jinja": {
- "version": "0.2.1",
- "resolved": "https://registry.npmjs.org/@huggingface/jinja/-/jinja-0.2.1.tgz",
- "integrity": "sha512-HxjVCll8oGfgUQmN91NYWCjfuaQ5mYZkc/BB1gjfp28q3s48yiB5jUEV7BvaRdIAb/+14cNdX8TIdalFykwywA==",
+ "version": "0.2.2",
+ "resolved": "https://registry.npmjs.org/@huggingface/jinja/-/jinja-0.2.2.tgz",
+ "integrity": "sha512-/KPde26khDUIPkTGU82jdtTW9UAuvUTumCAbFs/7giR0SxsvZC4hru51PBvpijH6BVkHcROcvZM/lpy5h1jRRA==",
"engines": {
"node": ">=18"
}
diff --git a/package.json b/package.json
index 446fd5f16..4f69df2c1 100644
--- a/package.json
+++ b/package.json
@@ -40,7 +40,7 @@
"dependencies": {
"onnxruntime-web": "1.14.0",
"sharp": "^0.32.0",
- "@huggingface/jinja": "^0.2.1"
+ "@huggingface/jinja": "^0.2.2"
},
"optionalDependencies": {
"onnxruntime-node": "1.14.0"
diff --git a/src/tokenizers.js b/src/tokenizers.js
index 9692cf3b0..5b58e37c0 100644
--- a/src/tokenizers.js
+++ b/src/tokenizers.js
@@ -2519,6 +2519,18 @@ export class PreTrainedTokenizer extends Callable {
this.legacy = false;
this.chat_template = tokenizerConfig.chat_template ?? null;
+ if (Array.isArray(this.chat_template)) {
+ // Chat templates are stored as lists of dicts with fixed key names;
+ // we reconstruct them into a single dict while loading.
+ const chat_template = Object.create(null);
+ for (const { name, template } of this.chat_template) {
+ if (typeof name !== 'string' || typeof template !== 'string') {
+ throw new Error('Chat template must be a list of objects with "name" and "template" properties');
+ }
+ chat_template[name] = template;
+ }
+ this.chat_template = chat_template;
+ }
this._compiled_template_cache = new Map();
}
@@ -2995,6 +3007,7 @@ export class PreTrainedTokenizer extends Callable {
* @param {number} [options.max_length=null] Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is false.
* If not specified, the tokenizer's `max_length` attribute will be used as a default.
* @param {boolean} [options.return_tensor=true] Whether to return the output as a Tensor or an Array. Has no effect if tokenize is false.
+ * @param {Object} [options.tokenizer_kwargs={}] Additional options to pass to the tokenizer.
* @returns {string | Tensor | number[]| number[][]} The tokenized output.
*/
apply_chat_template(conversation, {
@@ -3005,9 +3018,37 @@ export class PreTrainedTokenizer extends Callable {
truncation = false,
max_length = null,
return_tensor = true,
+ tokenizer_kwargs = {},
+ ...kwargs
} = {}) {
- chat_template ??= this.chat_template ?? this.default_chat_template;
+ // First, handle the cases when the model has a dict of multiple templates
+ if (
+ (this.chat_template && typeof this.chat_template === 'object') ||
+ (this.chat_template === null && this.default_chat_template && typeof this.default_chat_template === 'object')
+ ) {
+ const template_dict = this.chat_template ?? this.default_chat_template; // Guaranteed to be a non-null object
+
+ if (chat_template !== null && Object.hasOwn(template_dict, chat_template)) {
+ // The user can pass the name of a template to the chat template argument instead of an entire template
+ chat_template = template_dict[chat_template];
+ } else if (chat_template === null && 'default' in template_dict) {
+ chat_template = template_dict['default'];
+ } else if (chat_template === null) {
+ throw Error(
+ `This model has multiple chat templates with no default specified! Please either pass a chat ` +
+ `template or the name of the template you wish to use to the 'chat_template' argument. Available ` +
+ `template names are ${Object.keys(template_dict).sort()}.`
+ )
+ }
+ } else {
+ // These are the cases when the model has a single template
+ // priority: `chat_template` argument > `tokenizer.chat_template` > `tokenizer.default_chat_template`
+ chat_template ??= this.chat_template ?? this.default_chat_template;
+ }
+ if (typeof chat_template !== 'string') {
+ throw Error(`chat_template must be a string, but got ${typeof chat_template}`);
+ }
// Compilation function uses a cache to avoid recompiling the same template
let compiledTemplate = this._compiled_template_cache.get(chat_template);
@@ -3029,6 +3070,7 @@ export class PreTrainedTokenizer extends Callable {
add_generation_prompt: add_generation_prompt,
...special_tokens_map,
+ ...kwargs,
});
if (tokenize) {
@@ -3038,6 +3080,7 @@ export class PreTrainedTokenizer extends Callable {
truncation,
max_length,
return_tensor,
+ ...tokenizer_kwargs,
}).input_ids;
}
@@ -3208,6 +3251,8 @@ export class GemmaTokenizer extends PreTrainedTokenizer {
_default_chat_template = "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"
}
+export class Grok1Tokenizer extends PreTrainedTokenizer { }
+
/**
* Helper function to build translation inputs for an `NllbTokenizer` or `M2M100Tokenizer`.
* @param {PreTrainedTokenizer} self The tokenizer instance.
@@ -4263,6 +4308,9 @@ export class VitsTokenizer extends PreTrainedTokenizer {
this.decoder = new VitsDecoder({});
}
}
+
+export class CohereTokenizer extends PreTrainedTokenizer { }
+
/**
* Helper class which is used to instantiate pretrained tokenizers with the `from_pretrained` function.
* The chosen tokenizer class is determined by the type specified in the tokenizer config.
@@ -4314,6 +4362,8 @@ export class AutoTokenizer {
VitsTokenizer,
Qwen2Tokenizer,
GemmaTokenizer,
+ Grok1Tokenizer,
+ CohereTokenizer,
// Base case:
PreTrainedTokenizer,
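For reference, a minimal sketch (plain JavaScript, not part of the patch) of the tokenizer_config.json shape the new constructor branch handles; the template strings below are made up for illustration:

const tokenizerConfig = {
    // Named chat templates are serialized as a list of { name, template } objects.
    chat_template: [
        { name: 'default', template: '{% for message in messages %}{{ message.content }}{% endfor %}' },
        { name: 'rag', template: '{{ documents | length }} document(s) provided.' },
    ],
};

// Equivalent of the reconstruction step added above:
const templates = Object.create(null);
for (const { name, template } of tokenizerConfig.chat_template) {
    templates[name] = template;
}
// templates is now { default: '...', rag: '...' }, keyed by template name.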
diff --git a/tests/tokenizers.test.js b/tests/tokenizers.test.js
index 40fed05d1..8b92c6702 100644
--- a/tests/tokenizers.test.js
+++ b/tests/tokenizers.test.js
@@ -350,6 +350,42 @@ describe('Chat templates', () => {
compare(input_ids, [1, 733, 16289, 28793, 22557, 28725, 910, 460, 368, 28804, 733, 28748, 16289, 28793, 28737, 28742, 28719, 2548, 1598, 28723, 1602, 541, 315, 1316, 368, 3154, 28804, 2, 28705, 733, 16289, 28793, 315, 28742, 28715, 737, 298, 1347, 805, 910, 10706, 5752, 1077, 3791, 28808, 733, 28748, 16289, 28793])
});
+ it('should support multiple chat templates', async () => {
+
+ const tokenizer = await AutoTokenizer.from_pretrained("Xenova/c4ai-command-r-v01-tokenizer")
+
+ // define conversation input:
+ const conversation = [
+ { role: "user", content: "Whats the biggest penguin in the world?" }
+ ]
+ // define documents to ground on:
+ const documents = [
+ { title: "Tall penguins", text: "Emperor penguins are the tallest growing up to 122 cm in height." },
+ { title: "Penguin habitats", text: "Emperor penguins only live in Antarctica." }
+ ]
+
+ // render the RAG prompt as a string:
+ const grounded_generation_prompt = tokenizer.apply_chat_template(
+ conversation,
+ {
+ chat_template: "rag",
+ tokenize: false,
+ add_generation_prompt: true,
+
+ documents,
+ citation_mode: "accurate", // or "fast"
+ }
+ )
+ expect(grounded_generation_prompt).toEqual(
+ "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># Safety Preamble\nThe instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.\n\n" +
+ "# System Preamble\n## Basic Rules\nYou are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.\n\n" +
+ "# User Preamble\n## Task and Context\nYou help people answer their questions and other requests interactively. You will be asked a very wide array of requests on all kinds of topics. You will be equipped with a wide range of search engines or similar tools to help you, which you use to research your answer. You should focus on serving the user's needs as best you can, which will be wide-ranging.\n\n## Style Guide\nUnless the user asks for a different style of answer, you should answer in full sentences, using proper grammar and spelling.<|END_OF_TURN_TOKEN|>" +
+ "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>Whats the biggest penguin in the world?<|END_OF_TURN_TOKEN|>" +
+ "<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>\nDocument: 0\ntitle: Tall penguins\ntext: Emperor penguins are the tallest growing up to 122 cm in height.\n\nDocument: 1\ntitle: Penguin habitats\ntext: Emperor penguins only live in Antarctica.\n<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Carefully perform the following instructions, in order, starting each with a new line.\nFirstly, Decide which of the retrieved documents are relevant to the user's last input by writing 'Relevant Documents:' followed by comma-separated list of document numbers. If none are relevant, you should instead write 'None'.\nSecondly, Decide which of the retrieved documents contain facts that should be cited in a good answer to the user's last input by writing 'Cited Documents:' followed a comma-separated list of document numbers. If you dont want to cite any of them, you should instead write 'None'.\nThirdly, Write 'Answer:' followed by a response to the user's last input in high quality natural english. Use the retrieved documents to help you. Do not insert any citations or grounding markup.\nFinally, Write 'Grounded answer:' followed by a response to the user's last input in high quality natural english. Use the symbols and to indicate when a fact comes from a document in the search result, e.g my fact for a fact from document 0.<|END_OF_TURN_TOKEN|>" +
+ "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
+ );
+ });
+
it('should support user-defined chat template', async () => {
const tokenizer = await AutoTokenizer.from_pretrained("Xenova/llama-tokenizer");
@@ -395,7 +431,7 @@ describe('Chat templates', () => {
.replaceAll('USE_DEFAULT_PROMPT', true)
.replaceAll('DEFAULT_SYSTEM_MESSAGE', 'You are a helpful, respectful and honest assistant.');
- const text = await tokenizer.apply_chat_template(chat, { tokenize: false, return_tensor: false, chat_template });
+ const text = tokenizer.apply_chat_template(chat, { tokenize: false, return_tensor: false, chat_template });
expect(text).toEqual("[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant.\n<</SYS>>\n\nHello, how are you? [/INST] I'm doing great. How can I help you today? [INST] I'd like to show off how chat templating works! [/INST]");
@@ -412,7 +448,7 @@ describe('Chat templates', () => {
for (let { messages, add_generation_prompt, tokenize, target } of tests) {
- const generated = await tokenizer.apply_chat_template(messages, {
+ const generated = tokenizer.apply_chat_template(messages, {
tokenize,
add_generation_prompt,
return_tensor: false,
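For completeness, a hedged usage sketch that exercises the new options together (the model id and template name mirror the test above; the tokenizer_kwargs value is illustrative, not required by the patch). With tokenize left at its default of true, the result is the input_ids tensor; pass tokenize: false to get the rendered string as in the test:

import { AutoTokenizer } from '@xenova/transformers';

const tokenizer = await AutoTokenizer.from_pretrained('Xenova/c4ai-command-r-v01-tokenizer');

// Select a named template from the template dict and forward extra variables
// (documents, citation_mode) into the Jinja render context via the rest kwargs.
const input_ids = tokenizer.apply_chat_template(
    [{ role: 'user', content: 'Whats the biggest penguin in the world?' }],
    {
        chat_template: 'rag',
        add_generation_prompt: true,
        documents: [{ title: 'Tall penguins', text: 'Emperor penguins are the tallest.' }],
        citation_mode: 'accurate',
        // Extra options forwarded to the final tokenization call (illustrative):
        tokenizer_kwargs: { add_special_tokens: false },
    },
);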